Repository: geldata/gel Branch: master Commit: 85191063b4db Files: 1427 Total size: 21.0 MB Directory structure: gitextract_9xzfwvrg/ ├── .editorconfig ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.md │ │ ├── config.yml │ │ └── feature_request.md │ ├── Makefile │ ├── aws-aurora/ │ │ ├── .gitignore │ │ └── main.tf │ ├── aws-rds/ │ │ ├── .gitignore │ │ ├── .terraform.lock.hcl │ │ ├── main.tf │ │ ├── outputs.tf │ │ └── variables.tf │ ├── do-database/ │ │ ├── .gitignore │ │ ├── .terraform.lock.hcl │ │ ├── main.tf │ │ └── outputs.tf │ ├── gcp-cloud-sql/ │ │ ├── .gitignore │ │ ├── .terraform.lock.hcl │ │ └── main.tf │ ├── heroku-postgres/ │ │ ├── .gitignore │ │ └── main.tf │ ├── scripts/ │ │ ├── docs/ │ │ │ └── preview-deploy.js │ │ └── patches/ │ │ ├── compute-ipu-versions.py │ │ ├── compute-versions.py │ │ ├── create-databases.py │ │ └── test-downgrade.py │ ├── workflows/ │ │ ├── .gitattributes │ │ ├── build.dryrun.yml │ │ ├── build.ls-nightly.yml │ │ ├── build.nightly.yml │ │ ├── build.release.yml │ │ ├── build.testing.yml │ │ ├── docs-preview-deploy.yml │ │ ├── docs.yml │ │ ├── pull-request-meta.yml │ │ ├── tests.ha.yml │ │ ├── tests.inplace.yml │ │ ├── tests.inplace7x.yml │ │ ├── tests.managed-pg.yml │ │ ├── tests.patches.yml │ │ ├── tests.pg-versions.yml │ │ ├── tests.pool.yml │ │ ├── tests.reflection.yml │ │ └── tests.yml │ └── workflows.src/ │ ├── build.dryrun.tpl.yml │ ├── build.inc.yml │ ├── build.ls-nightly.tpl.yml │ ├── build.ls.targets.yml │ ├── build.nightly.tpl.yml │ ├── build.release.tpl.yml │ ├── build.targets.yml │ ├── build.testing.tpl.yml │ ├── render.py │ ├── tests.ha.targets.yml │ ├── tests.ha.tpl.yml │ ├── tests.inc.yml │ ├── tests.inplace.targets.yml │ ├── tests.inplace.tpl.yml │ ├── tests.inplace7x.targets.yml │ ├── tests.inplace7x.tpl.yml │ ├── tests.managed-pg.targets.yml │ ├── tests.managed-pg.tpl.yml │ ├── tests.patches.targets.yml │ ├── tests.patches.tpl.yml │ ├── tests.pg-versions.targets.yml │ ├── tests.pg-versions.tpl.yml │ ├── 
tests.pool.targets.yml │ ├── tests.pool.tpl.yml │ ├── tests.reflection.targets.yml │ ├── tests.reflection.tpl.yml │ ├── tests.targets.yml │ └── tests.tpl.yml ├── .gitignore ├── .gitmodules ├── .mailmap ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.rst ├── Cargo.toml ├── LICENSE ├── MANIFEST.in ├── Makefile ├── NOTICE ├── README.md ├── build_backend.py ├── dev-notes/ │ ├── concurrent-indexes.py │ ├── inplace-upgrades.md │ ├── newtype-checklist.md │ └── release-process.md ├── docs/ │ ├── .gitignore │ ├── Makefile │ ├── cloud/ │ │ ├── cli.rst │ │ ├── deploy/ │ │ │ ├── fly.rst │ │ │ ├── index.rst │ │ │ ├── netlify.rst │ │ │ ├── railway.rst │ │ │ ├── render.rst │ │ │ └── vercel.rst │ │ ├── http_gql.rst │ │ ├── index.rst │ │ ├── migrate_from.rst │ │ └── web.rst │ ├── conf.py │ ├── index.rst │ ├── intro/ │ │ ├── branches.rst │ │ ├── cli.rst │ │ ├── clients.rst │ │ ├── edgeql.rst │ │ ├── guides/ │ │ │ ├── ai/ │ │ │ │ ├── edgeql.rst │ │ │ │ ├── index.rst │ │ │ │ └── python.rst │ │ │ ├── drizzle/ │ │ │ │ ├── index.rst │ │ │ │ └── nextjs.rst │ │ │ └── index.rst │ │ ├── index.rst │ │ ├── install_table.rst │ │ ├── installation.rst │ │ ├── instances.rst │ │ ├── localdev.rst │ │ ├── migrations.rst │ │ ├── projects.rst │ │ ├── quickstart/ │ │ │ ├── ai/ │ │ │ │ ├── fastapi.rst │ │ │ │ └── index.rst │ │ │ ├── connecting/ │ │ │ │ ├── fastapi.rst │ │ │ │ ├── index.rst │ │ │ │ └── nextjs.rst │ │ │ ├── index.rst │ │ │ ├── inheritance/ │ │ │ │ ├── fastapi.rst │ │ │ │ ├── index.rst │ │ │ │ └── nextjs.rst │ │ │ ├── modeling/ │ │ │ │ ├── fastapi.rst │ │ │ │ ├── index.rst │ │ │ │ └── nextjs.rst │ │ │ ├── overview/ │ │ │ │ ├── fastapi.rst │ │ │ │ ├── index.rst │ │ │ │ └── nextjs.rst │ │ │ ├── setup/ │ │ │ │ ├── fastapi.rst │ │ │ │ ├── index.rst │ │ │ │ └── nextjs.rst │ │ │ └── working/ │ │ │ ├── fastapi.rst │ │ │ ├── index.rst │ │ │ └── nextjs.rst │ │ ├── schema.rst │ │ └── tutorials/ │ │ ├── ai_fastapi_searchbot.rst │ │ ├── gel_drizzle_booknotes.rst │ │ └── index.rst │ ├── redirects │ ├── 
redirects.js │ ├── reference/ │ │ ├── ai/ │ │ │ ├── extai.rst │ │ │ ├── extvectorstore.rst │ │ │ ├── http.rst │ │ │ ├── index.rst │ │ │ ├── javascript.rst │ │ │ ├── python.rst │ │ │ └── vectorstore_python.rst │ │ ├── auth/ │ │ │ ├── built_in_ui.rst │ │ │ ├── email_password.rst │ │ │ ├── http.rst │ │ │ ├── index.rst │ │ │ ├── magic_link.rst │ │ │ ├── oauth.rst │ │ │ ├── webauthn.rst │ │ │ └── webhooks.rst │ │ ├── datamodel/ │ │ │ ├── access_policies.rst │ │ │ ├── aliases.rst │ │ │ ├── annotations.rst │ │ │ ├── branches.rst │ │ │ ├── comparison.rst │ │ │ ├── computeds.rst │ │ │ ├── constraints.rst │ │ │ ├── extensions.rst │ │ │ ├── functions.rst │ │ │ ├── future.rst │ │ │ ├── globals.rst │ │ │ ├── index.rst │ │ │ ├── indexes.rst │ │ │ ├── inheritance.rst │ │ │ ├── introspection/ │ │ │ │ ├── casts.rst │ │ │ │ ├── colltypes.rst │ │ │ │ ├── constraints.rst │ │ │ │ ├── functions.rst │ │ │ │ ├── index.rst │ │ │ │ ├── indexes.rst │ │ │ │ ├── mutation_rewrites.rst │ │ │ │ ├── objects.rst │ │ │ │ ├── operators.rst │ │ │ │ ├── scalars.rst │ │ │ │ └── triggers.rst │ │ │ ├── linkprops.rst │ │ │ ├── links.rst │ │ │ ├── migrations.rst │ │ │ ├── modules.rst │ │ │ ├── mutation_rewrites.rst │ │ │ ├── objects.rst │ │ │ ├── permissions.rst │ │ │ ├── primitives.rst │ │ │ ├── properties.rst │ │ │ └── triggers.rst │ │ ├── edgeql/ │ │ │ ├── analyze.rst │ │ │ ├── delete.rst │ │ │ ├── for.rst │ │ │ ├── group.rst │ │ │ ├── index.rst │ │ │ ├── insert.rst │ │ │ ├── literals.rst │ │ │ ├── parameters.rst │ │ │ ├── path_resolution.rst │ │ │ ├── paths.rst │ │ │ ├── select.rst │ │ │ ├── sets.rst │ │ │ ├── transactions.rst │ │ │ ├── types.rst │ │ │ ├── update.rst │ │ │ └── with.rst │ │ ├── index.rst │ │ ├── reference/ │ │ │ ├── edgeql/ │ │ │ │ ├── analyze.rst │ │ │ │ ├── cardinality.rst │ │ │ │ ├── casts.csv │ │ │ │ ├── casts.rst │ │ │ │ ├── delete.rst │ │ │ │ ├── describe.rst │ │ │ │ ├── eval.rst │ │ │ │ ├── for.rst │ │ │ │ ├── functions.rst │ │ │ │ ├── group.rst │ │ │ │ ├── index.rst │ │ │ │ ├── 
insert.rst │ │ │ │ ├── lexical.rst │ │ │ │ ├── paths.rst │ │ │ │ ├── select.rst │ │ │ │ ├── sess_reset_alias.rst │ │ │ │ ├── sess_set_alias.rst │ │ │ │ ├── shapes.rst │ │ │ │ ├── tx_commit.rst │ │ │ │ ├── tx_rollback.rst │ │ │ │ ├── tx_sp_declare.rst │ │ │ │ ├── tx_sp_release.rst │ │ │ │ ├── tx_sp_rollback.rst │ │ │ │ ├── tx_start.rst │ │ │ │ ├── update.rst │ │ │ │ ├── volatility.rst │ │ │ │ └── with.rst │ │ │ └── index.rst │ │ ├── running/ │ │ │ ├── admin/ │ │ │ │ ├── configure.rst │ │ │ │ ├── index.rst │ │ │ │ ├── roles.rst │ │ │ │ ├── statistics_update.rst │ │ │ │ └── vacuum.rst │ │ │ ├── backend_ha.rst │ │ │ ├── configuration.rst │ │ │ ├── deployment/ │ │ │ │ ├── aws_aurora_ecs.rst │ │ │ │ ├── azure_flexibleserver.rst │ │ │ │ ├── bare_metal.rst │ │ │ │ ├── digitalocean.rst │ │ │ │ ├── docker.rst │ │ │ │ ├── fly_io.rst │ │ │ │ ├── gcp.rst │ │ │ │ ├── index.rst │ │ │ │ └── note_cloud_reset_password.rst │ │ │ ├── http.rst │ │ │ ├── index.rst │ │ │ └── local.rst │ │ ├── stdlib/ │ │ │ ├── abstract.rst │ │ │ ├── array.rst │ │ │ ├── bool.rst │ │ │ ├── bytes.rst │ │ │ ├── cfg.rst │ │ │ ├── constraint_table.rst │ │ │ ├── constraints.rst │ │ │ ├── datetime.rst │ │ │ ├── deprecated.rst │ │ │ ├── enum.rst │ │ │ ├── fts.rst │ │ │ ├── generic.rst │ │ │ ├── index.rst │ │ │ ├── json.rst │ │ │ ├── math.rst │ │ │ ├── math_funcops_table.rst │ │ │ ├── net.rst │ │ │ ├── numbers.rst │ │ │ ├── objects.rst │ │ │ ├── pg_trgm.rst │ │ │ ├── pg_unaccent.rst │ │ │ ├── pgcrypto.rst │ │ │ ├── pgvector.rst │ │ │ ├── postgis.rst │ │ │ ├── range.rst │ │ │ ├── sequence.rst │ │ │ ├── set.rst │ │ │ ├── string.rst │ │ │ ├── sys.rst │ │ │ ├── tuple.rst │ │ │ ├── type.rst │ │ │ └── uuid.rst │ │ └── using/ │ │ ├── cli/ │ │ │ ├── gel.rst │ │ │ ├── gel_analyze.rst │ │ │ ├── gel_branch/ │ │ │ │ ├── gel_branch_create.rst │ │ │ │ ├── gel_branch_drop.rst │ │ │ │ ├── gel_branch_list.rst │ │ │ │ ├── gel_branch_merge.rst │ │ │ │ ├── gel_branch_rebase.rst │ │ │ │ ├── gel_branch_rename.rst │ │ │ │ ├── 
gel_branch_switch.rst │ │ │ │ ├── gel_branch_wipe.rst │ │ │ │ └── index.rst │ │ │ ├── gel_cli_upgrade.rst │ │ │ ├── gel_cloud/ │ │ │ │ ├── gel_cloud_login.rst │ │ │ │ ├── gel_cloud_logout.rst │ │ │ │ ├── gel_cloud_secretkey/ │ │ │ │ │ ├── edgedb_cloud_secretkey_create.rst │ │ │ │ │ ├── edgedb_cloud_secretkey_list.rst │ │ │ │ │ ├── edgedb_cloud_secretkey_revoke.rst │ │ │ │ │ └── index.rst │ │ │ │ └── index.rst │ │ │ ├── gel_configure.rst │ │ │ ├── gel_connopts.rst │ │ │ ├── gel_database/ │ │ │ │ ├── gel_database_create.rst │ │ │ │ ├── gel_database_drop.rst │ │ │ │ ├── gel_database_wipe.rst │ │ │ │ └── index.rst │ │ │ ├── gel_describe/ │ │ │ │ ├── gel_describe_object.rst │ │ │ │ ├── gel_describe_schema.rst │ │ │ │ └── index.rst │ │ │ ├── gel_dump.rst │ │ │ ├── gel_extension/ │ │ │ │ ├── index.rst │ │ │ │ ├── install.rst │ │ │ │ ├── list-available.rst │ │ │ │ ├── list.rst │ │ │ │ └── uninstall.rst │ │ │ ├── gel_info.rst │ │ │ ├── gel_init.rst │ │ │ ├── gel_instance/ │ │ │ │ ├── gel_instance_create.rst │ │ │ │ ├── gel_instance_credentials.rst │ │ │ │ ├── gel_instance_destroy.rst │ │ │ │ ├── gel_instance_link.rst │ │ │ │ ├── gel_instance_list.rst │ │ │ │ ├── gel_instance_logs.rst │ │ │ │ ├── gel_instance_reset_password.rst │ │ │ │ ├── gel_instance_restart.rst │ │ │ │ ├── gel_instance_revert.rst │ │ │ │ ├── gel_instance_start.rst │ │ │ │ ├── gel_instance_status.rst │ │ │ │ ├── gel_instance_stop.rst │ │ │ │ ├── gel_instance_unlink.rst │ │ │ │ ├── gel_instance_upgrade.rst │ │ │ │ └── index.rst │ │ │ ├── gel_list.rst │ │ │ ├── gel_migrate.rst │ │ │ ├── gel_migration/ │ │ │ │ ├── gel_migration_apply.rst │ │ │ │ ├── gel_migration_create.rst │ │ │ │ ├── gel_migration_edit.rst │ │ │ │ ├── gel_migration_extract.rst │ │ │ │ ├── gel_migration_log.rst │ │ │ │ ├── gel_migration_status.rst │ │ │ │ ├── gel_migration_upgrade_check.rst │ │ │ │ └── index.rst │ │ │ ├── gel_project/ │ │ │ │ ├── gel_project_info.rst │ │ │ │ ├── gel_project_init.rst │ │ │ │ ├── gel_project_unlink.rst │ │ │ │ 
├── gel_project_upgrade.rst │ │ │ │ └── index.rst │ │ │ ├── gel_query.rst │ │ │ ├── gel_restore.rst │ │ │ ├── gel_server/ │ │ │ │ ├── gel_server_info.rst │ │ │ │ ├── gel_server_install.rst │ │ │ │ ├── gel_server_list_versions.rst │ │ │ │ ├── gel_server_uninstall.rst │ │ │ │ └── index.rst │ │ │ ├── gel_ui.rst │ │ │ ├── gel_watch.rst │ │ │ ├── index.rst │ │ │ └── network.rst │ │ ├── clients.rst │ │ ├── connection.rst │ │ ├── datetime.rst │ │ ├── graphql/ │ │ │ ├── cheatsheet.rst │ │ │ ├── graphql.rst │ │ │ ├── index.rst │ │ │ ├── introspection.rst │ │ │ └── mutations.rst │ │ ├── http.rst │ │ ├── index.rst │ │ ├── js/ │ │ │ ├── client.rst │ │ │ ├── datatypes.rst │ │ │ ├── generation.rst │ │ │ ├── index.rst │ │ │ ├── interfaces.rst │ │ │ ├── queries.rst │ │ │ └── querybuilder.rst │ │ ├── projects.rst │ │ ├── python/ │ │ │ ├── api/ │ │ │ │ ├── advanced.rst │ │ │ │ ├── codegen.rst │ │ │ │ └── types.rst │ │ │ ├── client.rst │ │ │ └── index.rst │ │ └── sql_adapter.rst │ └── resources/ │ ├── changelog/ │ │ ├── 1_0_a2.rst │ │ ├── 1_0_a3.rst │ │ ├── 1_0_a4.rst │ │ ├── 1_0_a5.rst │ │ ├── 1_0_a6.rst │ │ ├── 1_0_a7.rst │ │ ├── 1_0_b1.rst │ │ ├── 1_0_b2.rst │ │ ├── 1_0_b3.rst │ │ ├── 1_0_rc1.rst │ │ ├── 1_0_rc2.rst │ │ ├── 1_0_rc3.rst │ │ ├── 1_0_rc4.rst │ │ ├── 1_0_rc5.rst │ │ ├── 1_x.rst │ │ ├── 2_x.rst │ │ ├── 3_x.rst │ │ ├── 4_x.rst │ │ ├── 5_x.rst │ │ ├── 6_x.rst │ │ ├── 7_x.rst │ │ ├── deprecation.rst │ │ └── index.rst │ ├── cheatsheets/ │ │ ├── admin.rst │ │ ├── aliases.rst │ │ ├── annotations.rst │ │ ├── boolean.rst │ │ ├── cli.rst │ │ ├── delete.rst │ │ ├── functions.rst │ │ ├── index.rst │ │ ├── insert.rst │ │ ├── objects.rst │ │ ├── repl.rst │ │ ├── select.rst │ │ └── update.rst │ ├── guides/ │ │ ├── contributing/ │ │ │ ├── code.rst │ │ │ ├── documentation.rst │ │ │ └── index.rst │ │ ├── datamigrations/ │ │ │ ├── index.rst │ │ │ └── postgres.rst │ │ ├── index.rst │ │ ├── migrations/ │ │ │ ├── guide.rst │ │ │ ├── index.rst │ │ │ └── tips.rst │ │ └── tutorials/ │ │ ├── 
chatgpt_bot.rst │ │ ├── cloudflare_workers.rst │ │ ├── graphql_apis_with_strawberry.rst │ │ ├── index.rst │ │ ├── jupyter_notebook.rst │ │ ├── nextjs_app_router.rst │ │ ├── nextjs_pages_router.rst │ │ ├── rest_apis_with_fastapi.rst │ │ ├── rest_apis_with_flask.rst │ │ └── trpc.rst │ ├── index.rst │ ├── protocol/ │ │ ├── dataformats.rst │ │ ├── dump_format.rst │ │ ├── errors.rst │ │ ├── index.rst │ │ ├── messages.rst │ │ └── typedesc.rst │ └── upgrading.rst ├── edb/ │ ├── .gitignore │ ├── README.md │ ├── __init__.py │ ├── _edgeql_parser.pyi │ ├── api/ │ │ ├── errors.txt │ │ └── types.txt │ ├── buildmeta.py │ ├── cli/ │ │ ├── .gitignore │ │ ├── __init__.py │ │ └── __main__.py │ ├── common/ │ │ ├── __init__.py │ │ ├── _typing_inspect.py │ │ ├── adapter.py │ │ ├── assert_data_shape.py │ │ ├── ast/ │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── codegen.py │ │ │ ├── transformer.py │ │ │ └── visitor.py │ │ ├── asyncutil.py │ │ ├── asyncwatcher.py │ │ ├── binwrapper.py │ │ ├── checked.py │ │ ├── colorsys.py │ │ ├── compiler.py │ │ ├── debug.py │ │ ├── devmode.py │ │ ├── english.py │ │ ├── enum.py │ │ ├── exceptions.py │ │ ├── levenshtein.py │ │ ├── log.py │ │ ├── lru.py │ │ ├── markup/ │ │ │ ├── __init__.py │ │ │ ├── elements/ │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ ├── code.py │ │ │ │ ├── doc.py │ │ │ │ └── lang.py │ │ │ ├── format.py │ │ │ ├── renderers/ │ │ │ │ ├── __init__.py │ │ │ │ ├── styles.py │ │ │ │ └── terminal.py │ │ │ └── serializer/ │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── code.py │ │ │ └── logging.py │ │ ├── ordered.py │ │ ├── ordered.pyi │ │ ├── parametric.py │ │ ├── parsing.py │ │ ├── prometheus.py │ │ ├── retryloop.py │ │ ├── secretkey.py │ │ ├── signalctl.py │ │ ├── span.py │ │ ├── struct.py │ │ ├── supervisor.py │ │ ├── term.py │ │ ├── token_bucket.py │ │ ├── topological.py │ │ ├── traceback.py │ │ ├── turbo_uuid.pyi │ │ ├── typeutils.py │ │ ├── typing_inspect.py │ │ ├── uuidgen.py │ │ ├── value_dispatch.py │ │ ├── verutils.py │ 
│ ├── view_patterns.py │ │ ├── windowedsum.py │ │ └── xdedent.py │ ├── edgeql/ │ │ ├── __init__.py │ │ ├── ast.py │ │ ├── codegen.py │ │ ├── compiler/ │ │ │ ├── __init__.py │ │ │ ├── astutils.py │ │ │ ├── casts.py │ │ │ ├── clauses.py │ │ │ ├── config.py │ │ │ ├── config_desc.py │ │ │ ├── conflicts.py │ │ │ ├── context.py │ │ │ ├── dispatch.py │ │ │ ├── eta_expand.py │ │ │ ├── expr.py │ │ │ ├── func.py │ │ │ ├── group.py │ │ │ ├── inference/ │ │ │ │ ├── __init__.py │ │ │ │ ├── cardinality.py │ │ │ │ ├── context.py │ │ │ │ ├── multiplicity.py │ │ │ │ ├── utils.py │ │ │ │ └── volatility.py │ │ │ ├── normalization.py │ │ │ ├── options.py │ │ │ ├── pathctx.py │ │ │ ├── policies.py │ │ │ ├── polyres.py │ │ │ ├── schemactx.py │ │ │ ├── setgen.py │ │ │ ├── stmt.py │ │ │ ├── stmtctx.py │ │ │ ├── triggers.py │ │ │ ├── tuple_args.py │ │ │ ├── typegen.py │ │ │ └── viewgen.py │ │ ├── declarative.py │ │ ├── desugar_group.py │ │ ├── parser/ │ │ │ ├── __init__.py │ │ │ └── grammar/ │ │ │ ├── .gitignore │ │ │ ├── __init__.py │ │ │ ├── commondl.py │ │ │ ├── config.py │ │ │ ├── ddl.py │ │ │ ├── expressions.py │ │ │ ├── keywords.py │ │ │ ├── precedence.py │ │ │ ├── sdl.py │ │ │ ├── session.py │ │ │ ├── start.py │ │ │ ├── statements.py │ │ │ └── tokens.py │ │ ├── qltypes.py │ │ ├── quote.py │ │ ├── tokenizer.py │ │ ├── tracer.py │ │ └── utils.py │ ├── edgeql-parser/ │ │ ├── Cargo.toml │ │ ├── edgeql-parser-derive/ │ │ │ ├── Cargo.toml │ │ │ └── src/ │ │ │ └── lib.rs │ │ ├── edgeql-parser-python/ │ │ │ ├── Cargo.toml │ │ │ ├── src/ │ │ │ │ ├── errors.rs │ │ │ │ ├── hash.rs │ │ │ │ ├── keywords.rs │ │ │ │ ├── lib.rs │ │ │ │ ├── normalize.rs │ │ │ │ ├── parser.rs │ │ │ │ ├── position.rs │ │ │ │ ├── pynormalize.rs │ │ │ │ ├── tokenizer.rs │ │ │ │ └── unpack.rs │ │ │ └── tests/ │ │ │ └── normalize.rs │ │ ├── src/ │ │ │ ├── ast.rs │ │ │ ├── expr.rs │ │ │ ├── hash.rs │ │ │ ├── helpers/ │ │ │ │ ├── bytes.rs │ │ │ │ ├── mod.rs │ │ │ │ └── strings.rs │ │ │ ├── keywords.rs │ │ │ ├── lib.rs │ │ │ 
├── parser/ │ │ │ │ ├── cst.rs │ │ │ │ ├── custom_errors.rs │ │ │ │ ├── mod.rs │ │ │ │ └── spec.rs │ │ │ ├── position.rs │ │ │ ├── preparser.rs │ │ │ ├── schema_file.rs │ │ │ ├── tokenizer.rs │ │ │ └── validation.rs │ │ └── tests/ │ │ ├── expr.rs │ │ ├── preparser.rs │ │ └── tokenizer.rs │ ├── errors/ │ │ ├── __init__.py │ │ └── base.py │ ├── graphql/ │ │ ├── .gitignore │ │ ├── __init__.py │ │ ├── _patch_core.py │ │ ├── codegen.py │ │ ├── compiler.py │ │ ├── errors.py │ │ ├── explore.py │ │ ├── extension.pyx │ │ ├── tokenizer.py │ │ ├── translator.py │ │ └── types.py │ ├── graphql-rewrite/ │ │ ├── Cargo.toml │ │ ├── _graphql_rewrite.pyi │ │ ├── src/ │ │ │ ├── lib.rs │ │ │ ├── py_entry.rs │ │ │ ├── py_exception.rs │ │ │ ├── py_token.rs │ │ │ ├── rewrite.rs │ │ │ └── token_vec.rs │ │ └── tests/ │ │ └── rewrite.rs │ ├── ir/ │ │ ├── __init__.py │ │ ├── ast.py │ │ ├── astexpr.py │ │ ├── pathid.py │ │ ├── scopetree.py │ │ ├── staeval.py │ │ ├── statypes.py │ │ ├── typeutils.py │ │ └── utils.py │ ├── language_server/ │ │ ├── __init__.py │ │ ├── completion.py │ │ ├── definition.py │ │ ├── main.py │ │ ├── parsing.py │ │ ├── project.py │ │ ├── schema.py │ │ ├── server.py │ │ └── utils.py │ ├── lib/ │ │ ├── __init__.py │ │ ├── _testmode.edgeql │ │ ├── cal.edgeql │ │ ├── cfg.edgeql │ │ ├── enc.edgeql │ │ ├── ext/ │ │ │ ├── ai.edgeql │ │ │ ├── auth.edgeql │ │ │ ├── edgeqlhttp.edgeql │ │ │ ├── graphql.edgeql │ │ │ ├── notebook.edgeql │ │ │ ├── pg_trgm.edgeql │ │ │ ├── pg_unaccent.edgeql │ │ │ ├── pgcrypto.edgeql │ │ │ └── pgvector.edgeql │ │ ├── fts.edgeql │ │ ├── math.edgeql │ │ ├── net.edgeql │ │ ├── pg.edgeql │ │ ├── schema.edgeql │ │ ├── std/ │ │ │ ├── 00-prelude.edgeql │ │ │ ├── 10-scalars.edgeql │ │ │ ├── 15-attrs.edgeql │ │ │ ├── 17-abstractops.edgeql │ │ │ ├── 20-genericfuncs.edgeql │ │ │ ├── 25-booloperators.edgeql │ │ │ ├── 25-enumoperators.edgeql │ │ │ ├── 25-numoperators.edgeql │ │ │ ├── 25-setoperators.edgeql │ │ │ ├── 26-bitwisefuncs.edgeql │ │ │ ├── 
30-arrayfuncs.edgeql │ │ │ ├── 30-bytesfuncs.edgeql │ │ │ ├── 30-datetimefuncs.edgeql │ │ │ ├── 30-jsonfuncs.edgeql │ │ │ ├── 30-regexpfuncs.edgeql │ │ │ ├── 30-sequencefuncs.edgeql │ │ │ ├── 30-strfuncs.edgeql │ │ │ ├── 30-uuidfuncs.edgeql │ │ │ ├── 31-rangefuncs.edgeql │ │ │ ├── 50-constraints.edgeql │ │ │ ├── 60-baseobject.edgeql │ │ │ └── 70-converters.edgeql │ │ └── sys.edgeql │ ├── load_ext/ │ │ └── main.py │ ├── pgsql/ │ │ ├── __init__.py │ │ ├── ast.py │ │ ├── codegen.py │ │ ├── common.py │ │ ├── compiler/ │ │ │ ├── ARCHITECTURE.md │ │ │ ├── __init__.py │ │ │ ├── aliases.py │ │ │ ├── astutils.py │ │ │ ├── clauses.py │ │ │ ├── config.py │ │ │ ├── context.py │ │ │ ├── dispatch.py │ │ │ ├── dml.py │ │ │ ├── enums.py │ │ │ ├── expr.py │ │ │ ├── group.py │ │ │ ├── output.py │ │ │ ├── pathctx.py │ │ │ ├── relctx.py │ │ │ ├── relgen.py │ │ │ ├── shapecomp.py │ │ │ └── stmt.py │ │ ├── dbops/ │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── catalogs.py │ │ │ ├── composites.py │ │ │ ├── config.py │ │ │ ├── constraints.py │ │ │ ├── databases.py │ │ │ ├── ddl.py │ │ │ ├── domains.py │ │ │ ├── enums.py │ │ │ ├── extensions.py │ │ │ ├── functions.py │ │ │ ├── indexes.py │ │ │ ├── operators.py │ │ │ ├── ranges.py │ │ │ ├── roles.py │ │ │ ├── schemas.py │ │ │ ├── sequences.py │ │ │ ├── tables.py │ │ │ ├── triggers.py │ │ │ ├── types.py │ │ │ └── views.py │ │ ├── debug.py │ │ ├── delta.py │ │ ├── delta_ext_ai.py │ │ ├── deltadbops.py │ │ ├── deltafts.py │ │ ├── inheritance.py │ │ ├── keywords.py │ │ ├── metaschema.py │ │ ├── params.py │ │ ├── parser/ │ │ │ ├── .gitignore │ │ │ ├── __init__.py │ │ │ ├── ast_builder.py │ │ │ ├── exceptions.py │ │ │ ├── parser.pxd │ │ │ └── parser.pyx │ │ ├── patches.py │ │ ├── patches_6x.py │ │ ├── resolver/ │ │ │ ├── __init__.py │ │ │ ├── command.py │ │ │ ├── context.py │ │ │ ├── dispatch.py │ │ │ ├── expr.py │ │ │ ├── range_functions.py │ │ │ ├── range_var.py │ │ │ ├── relation.py │ │ │ ├── sql_introspection.py │ │ │ └── static.py │ │ ├── 
schemamech.py │ │ ├── trampoline.py │ │ └── types.py │ ├── protocol/ │ │ ├── .gitignore │ │ ├── README │ │ ├── __init__.py │ │ ├── enums.py │ │ ├── messages.py │ │ ├── protocol.pxd │ │ ├── protocol.pyi │ │ ├── protocol.pyx │ │ └── render_utils.py │ ├── schema/ │ │ ├── __init__.py │ │ ├── _types.py │ │ ├── abc.py │ │ ├── annos.py │ │ ├── casts.py │ │ ├── constraints.py │ │ ├── database.py │ │ ├── ddl.py │ │ ├── defines.py │ │ ├── delta.py │ │ ├── expr.py │ │ ├── expraliases.py │ │ ├── extensions.py │ │ ├── functions.py │ │ ├── futures.py │ │ ├── globals.py │ │ ├── indexes.py │ │ ├── inheriting.py │ │ ├── links.py │ │ ├── migrations.py │ │ ├── modules.py │ │ ├── name.py │ │ ├── objects.py │ │ ├── objtypes.py │ │ ├── operators.py │ │ ├── ordering.py │ │ ├── permissions.py │ │ ├── pointers.py │ │ ├── policies.py │ │ ├── properties.py │ │ ├── pseudo.py │ │ ├── referencing.py │ │ ├── reflection/ │ │ │ ├── __init__.py │ │ │ ├── reader.py │ │ │ ├── structure.py │ │ │ └── writer.py │ │ ├── rewrites.py │ │ ├── roles.py │ │ ├── scalars.py │ │ ├── schema.py │ │ ├── sources.py │ │ ├── std.py │ │ ├── triggers.py │ │ ├── types.py │ │ ├── unknown_pointers.py │ │ ├── utils.py │ │ └── version.py │ ├── server/ │ │ ├── .gitignore │ │ ├── __init__.py │ │ ├── _rust_native/ │ │ │ ├── Cargo.toml │ │ │ └── src/ │ │ │ └── lib.rs │ │ ├── args.py │ │ ├── auth.py │ │ ├── bootstrap.py │ │ ├── cache/ │ │ │ ├── __init__.py │ │ │ ├── stmt_cache.pxd │ │ │ └── stmt_cache.pyx │ │ ├── compiler/ │ │ │ ├── __init__.py │ │ │ ├── compiler.py │ │ │ ├── config.py │ │ │ ├── dbstate.py │ │ │ ├── ddl.py │ │ │ ├── enums.py │ │ │ ├── errormech.py │ │ │ ├── explain/ │ │ │ │ ├── __init__.py │ │ │ │ ├── casefold.py │ │ │ │ ├── coarse_grained.py │ │ │ │ ├── fine_grained.py │ │ │ │ ├── ir_analyze.py │ │ │ │ ├── pg_tree.py │ │ │ │ └── to_json.py │ │ │ ├── rpc.pxd │ │ │ ├── rpc.pyi │ │ │ ├── rpc.pyx │ │ │ ├── sertypes.py │ │ │ ├── sql.py │ │ │ └── status.py │ │ ├── compiler_pool/ │ │ │ ├── __init__.py │ │ │ ├── amsg.py 
│ │ │ ├── multitenant_worker.py │ │ │ ├── pool.py │ │ │ ├── queue.py │ │ │ ├── server.py │ │ │ ├── state.py │ │ │ ├── worker.py │ │ │ └── worker_proc.py │ │ ├── config/ │ │ │ ├── __init__.py │ │ │ ├── ops.py │ │ │ ├── spec.py │ │ │ └── types.py │ │ ├── connpool/ │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ ├── pool.py │ │ │ ├── pool2.py │ │ │ └── rolavg.py │ │ ├── consul.py │ │ ├── daemon/ │ │ │ ├── __init__.py │ │ │ ├── daemon.py │ │ │ ├── exceptions.py │ │ │ ├── lib.py │ │ │ └── pidfile.py │ │ ├── dbview/ │ │ │ ├── __init__.py │ │ │ ├── dbview.pxd │ │ │ ├── dbview.pyi │ │ │ └── dbview.pyx │ │ ├── defines.py │ │ ├── ha/ │ │ │ ├── __init__.py │ │ │ ├── adaptive.py │ │ │ ├── base.py │ │ │ └── stolon.py │ │ ├── http.py │ │ ├── inplace_upgrade.py │ │ ├── instdata.py │ │ ├── logsetup.py │ │ ├── main.py │ │ ├── metrics.py │ │ ├── multitenant.py │ │ ├── net_worker.py │ │ ├── pgcluster.py │ │ ├── pgcon/ │ │ │ ├── __init__.py │ │ │ ├── connect.py │ │ │ ├── cpythonx.pxd │ │ │ ├── errors.py │ │ │ ├── pgcon.pxd │ │ │ ├── pgcon.pyi │ │ │ ├── pgcon.pyx │ │ │ ├── pgcon_sql.pxd │ │ │ ├── pgcon_sql.pyx │ │ │ └── rust_transport.py │ │ ├── pgconnparams.py │ │ ├── protocol/ │ │ │ ├── __init__.py │ │ │ ├── ai_ext.py │ │ │ ├── args_ser.pxd │ │ │ ├── args_ser.pyx │ │ │ ├── auth/ │ │ │ │ ├── __init__.py │ │ │ │ └── scram.py │ │ │ ├── auth_ext/ │ │ │ │ ├── __init__.py │ │ │ │ ├── _static/ │ │ │ │ │ ├── interactions.js │ │ │ │ │ ├── styles.css │ │ │ │ │ ├── utils.js │ │ │ │ │ ├── webauthn-authenticate.js │ │ │ │ │ └── webauthn-register.js │ │ │ │ ├── apple.py │ │ │ │ ├── azure.py │ │ │ │ ├── base.py │ │ │ │ ├── config.py │ │ │ │ ├── data.py │ │ │ │ ├── discord.py │ │ │ │ ├── email.py │ │ │ │ ├── email_password.py │ │ │ │ ├── errors.py │ │ │ │ ├── github.py │ │ │ │ ├── google.py │ │ │ │ ├── http.py │ │ │ │ ├── jwt.py │ │ │ │ ├── local.py │ │ │ │ ├── magic_link.py │ │ │ │ ├── oauth.py │ │ │ │ ├── otc.py │ │ │ │ ├── pkce.py │ │ │ │ ├── slack.py │ │ │ │ ├── ui/ │ │ │ │ │ ├── __init__.py │ │ 
│ │ │ ├── components.py │ │ │ │ │ └── util.py │ │ │ │ ├── util.py │ │ │ │ ├── webauthn.py │ │ │ │ └── webhook.py │ │ │ ├── auth_helpers.pxd │ │ │ ├── auth_helpers.pyx │ │ │ ├── binary.pxd │ │ │ ├── binary.pyx │ │ │ ├── consts.pxi │ │ │ ├── cpythonx.pxd │ │ │ ├── edgeql_ext.pyx │ │ │ ├── execute.pxd │ │ │ ├── execute.pyi │ │ │ ├── execute.pyx │ │ │ ├── frontend.pxd │ │ │ ├── frontend.pyx │ │ │ ├── metrics.py │ │ │ ├── notebook_ext.pxd │ │ │ ├── notebook_ext.pyx │ │ │ ├── pg_ext.pxd │ │ │ ├── pg_ext.pyx │ │ │ ├── protocol.pxd │ │ │ ├── protocol.pyi │ │ │ ├── protocol.pyx │ │ │ ├── request_scheduler.py │ │ │ ├── server_info.py │ │ │ ├── system_api.py │ │ │ └── ui_ext.pyx │ │ ├── rust_async_channel.py │ │ ├── server.py │ │ ├── service_manager.py │ │ ├── smtp.py │ │ └── tenant.py │ ├── testbase/ │ │ ├── __init__.py │ │ ├── asyncutils.py │ │ ├── cluster.py │ │ ├── connection.py │ │ ├── experimental_interpreter.py │ │ ├── http.py │ │ ├── lang.py │ │ ├── proc.py │ │ ├── protocol/ │ │ │ ├── __init__.py │ │ │ └── test.py │ │ ├── serutils.py │ │ └── server.py │ └── tools/ │ ├── __init__.py │ ├── __main__.py │ ├── ast_inheritance_graph.py │ ├── cli.py │ ├── config.py │ ├── dflags.py │ ├── docs/ │ │ ├── __init__.py │ │ ├── cli.py │ │ ├── edb.py │ │ ├── eql.py │ │ ├── go.py │ │ ├── graphql.py │ │ ├── js.py │ │ ├── sdl.py │ │ └── shared.py │ ├── edb.py │ ├── experimental_interpreter/ │ │ ├── back_to_ql.py │ │ ├── basis/ │ │ │ ├── 80-interpreter-internal.edgeql │ │ │ ├── built_ins.py │ │ │ ├── builtin_bin_ops.py │ │ │ ├── errors.py │ │ │ ├── reserved_ops.py │ │ │ ├── server_funcs.py │ │ │ └── std_funcs.py │ │ ├── data/ │ │ │ ├── casts.py │ │ │ ├── data_ops.py │ │ │ ├── deduplication_insert.py │ │ │ ├── expr_ops.py │ │ │ ├── expr_to_str.py │ │ │ ├── module_ops.py │ │ │ ├── path_factor.py │ │ │ ├── query_ops.py │ │ │ ├── type_ops.py │ │ │ └── val_to_json.py │ │ ├── db_interface.py │ │ ├── edb_entry.py │ │ ├── elab_schema.py │ │ ├── elaboration.py │ │ ├── errors.py │ │ ├── 
evaluation.py │ │ ├── evaluation_tools/ │ │ │ └── storage_coercion.py │ │ ├── helper_funcs.py │ │ ├── interpreter_logging.py │ │ ├── logs.py │ │ ├── new_interpreter.py │ │ ├── post_processing_tools/ │ │ │ ├── insert_select_optimization.py │ │ │ └── post_processing.py │ │ ├── schema/ │ │ │ ├── ddl_processing.py │ │ │ ├── function_elaboration.py │ │ │ ├── library_discovery.py │ │ │ └── subtyping_resolution.py │ │ ├── sqlite/ │ │ │ └── sqlite_adapter.py │ │ └── type_checking_tools/ │ │ ├── cast_checking.py │ │ ├── dml_checking.py │ │ ├── function_checking.py │ │ ├── inheritance_populate.py │ │ ├── module_check_tools.py │ │ ├── name_resolution.py │ │ ├── schema_checking.py │ │ └── typechecking.py │ ├── fake_ai_server.py │ ├── gen_cast_table.py │ ├── gen_errors.py │ ├── gen_meta_grammars.py │ ├── gen_rust_ast.py │ ├── gen_sql_introspection.py │ ├── gen_test_dumps.py │ ├── gen_types.py │ ├── inittestdb.py │ ├── ls.py │ ├── ls_forbidden_functions.py │ ├── mypy/ │ │ ├── __init__.py │ │ └── plugin.py │ ├── parser_demo.py │ ├── profiling/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── cli.py │ │ ├── profiler.py │ │ ├── svg_helpers.js │ │ └── tracing_singledispatch.py │ ├── pygments/ │ │ ├── __init__.py │ │ ├── edgeql/ │ │ │ ├── __init__.py │ │ │ └── meta.py │ │ └── graphql/ │ │ └── __init__.py │ ├── railroad_diagram.py │ ├── redo_metaschema.py │ ├── rm_data_dir.py │ ├── test/ │ │ ├── __init__.py │ │ ├── cpython_state.py │ │ ├── decorators.py │ │ ├── loader.py │ │ ├── mproc_fixes.py │ │ ├── results.py │ │ ├── runner.py │ │ └── styles.py │ ├── test_extension.py │ ├── toy_eval_model.py │ └── wipe.py ├── edb_stat_statements/ │ ├── .gitignore │ ├── Makefile │ ├── edb_stat_statements--1.0.sql │ ├── edb_stat_statements.c │ ├── edb_stat_statements.control │ ├── expected/ │ │ ├── cleanup.out │ │ ├── cursors.out │ │ ├── dml.out.17 │ │ ├── dml.out.18 │ │ ├── entry_timestamp.out │ │ ├── extended.out │ │ ├── level_tracking.out.17 │ │ ├── level_tracking.out.18 │ │ ├── oldextversions.out │ 
│ ├── parallel.out.17 │ │ ├── parallel.out.18 │ │ ├── planning.out │ │ ├── privileges.out │ │ ├── select.out │ │ ├── user_activity.out │ │ ├── utility.out.16 │ │ ├── utility.out.17 │ │ ├── wal.out.17 │ │ └── wal.out.18 │ ├── sql/ │ │ ├── cleanup.sql │ │ ├── cursors.sql │ │ ├── dml.sql │ │ ├── entry_timestamp.sql │ │ ├── extended.sql │ │ ├── level_tracking.sql │ │ ├── oldextversions.sql │ │ ├── parallel.sql │ │ ├── planning.sql │ │ ├── privileges.sql │ │ ├── select.sql │ │ ├── user_activity.sql │ │ ├── utility.sql │ │ └── wal.sql │ └── t/ │ └── 010_restart.pl ├── pyproject.toml ├── rust/ │ ├── conn_pool/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── algo.rs │ │ ├── bin/ │ │ │ └── optimizer.rs │ │ ├── block.rs │ │ ├── conn.rs │ │ ├── drain.rs │ │ ├── lib.rs │ │ ├── metrics.rs │ │ ├── pool.rs │ │ ├── python.rs │ │ ├── test/ │ │ │ ├── mod.rs │ │ │ └── spec.rs │ │ └── waitqueue.rs │ ├── gel-http/ │ │ ├── Cargo.toml │ │ └── src/ │ │ ├── cache.rs │ │ ├── lib.rs │ │ └── python.rs │ ├── pgrust/ │ │ ├── Cargo.toml │ │ └── src/ │ │ ├── errors/ │ │ │ ├── edgedb.rs │ │ │ └── mod.rs │ │ ├── lib.rs │ │ └── python/ │ │ └── mod.rs │ └── pyo3_util/ │ ├── Cargo.toml │ └── src/ │ ├── channel.rs │ ├── lib.rs │ └── logging.rs ├── rust-toolchain.toml ├── setup.py └── tests/ ├── __init__.py ├── certs/ │ ├── .gitignore │ ├── ca.cert.pem │ ├── ca.conf │ ├── ca.crl.pem │ ├── ca.key.pem │ ├── client.cert.pem │ ├── client.key.pem │ ├── client.key.protected.pem │ ├── client_ca.cert.pem │ ├── client_ca.key.pem │ ├── gen.py │ ├── gen.sh │ ├── server.cert.pem │ └── server.key.pem ├── common/ │ ├── __init__.py │ ├── test_ast.py │ ├── test_asyncutil.py │ ├── test_checked.py │ ├── test_debug.py │ ├── test_lru.py │ ├── test_markup.py │ ├── test_parametric.py │ ├── test_prometheus.py │ ├── test_signalctl.py │ ├── test_struct.py │ ├── test_supervisor.py │ ├── test_term.py │ ├── test_token_bucket.py │ ├── test_value_dispatch.py │ ├── test_windowedsum.py │ └── test_xdedent.py ├── dumps/ │ ├── 
dump01/ │ │ ├── 1_4.dump │ │ ├── 2_0.dump │ │ ├── 3_0.dump │ │ ├── 4_0.dump │ │ └── 6_0.dump │ ├── dump02/ │ │ ├── 1_4.dump │ │ ├── 2_0.dump │ │ ├── 3_0.dump │ │ ├── 4_0.dump │ │ └── 6_0.dump │ ├── dump03/ │ │ ├── 1_4.dump │ │ ├── 2_0.dump │ │ ├── 3_0.dump │ │ ├── 4_0.dump │ │ └── 6_0.dump │ ├── dumpv2/ │ │ ├── 2_0.dump │ │ ├── 3_0.dump │ │ ├── 4_0.dump │ │ └── 6_0.dump │ ├── dumpv3/ │ │ ├── .gitignore │ │ ├── 3_0.dump │ │ ├── 4_0.dump │ │ └── 6_0.dump │ ├── dumpv4/ │ │ ├── .gitignore │ │ ├── 3_0.dump │ │ ├── 4_0.dump │ │ └── 6_0.dump │ ├── dumpv5/ │ │ ├── .gitignore │ │ └── 6_0.dump │ ├── dumpv6/ │ │ ├── .gitignore │ │ └── 6_0.dump │ └── dumpv7/ │ └── .gitignore ├── edgeql/ │ ├── __init__.py │ └── test_quote.py ├── extension-testing/ │ ├── .gitignore │ ├── ext_test/ │ │ ├── MANIFEST.toml │ │ ├── Makefile │ │ ├── get_sum.edgeql │ │ └── sql/ │ │ ├── Makefile │ │ ├── get_sum--0.0.1.sql │ │ ├── get_sum.c │ │ └── get_sum.control │ └── exts.mk ├── inplace-testing/ │ ├── prep-upgrades.py │ ├── test-old.sh │ ├── test.sh │ └── upgrade.patch ├── patch-testing/ │ ├── test.sh │ └── upgrade.patch ├── schemas/ │ ├── advtypes.esdl │ ├── cards.esdl │ ├── cards_ir_inference.esdl │ ├── cards_setup.edgeql │ ├── casts.esdl │ ├── casts_setup.edgeql │ ├── constraints.esdl │ ├── constraints_migration/ │ │ ├── schema.esdl │ │ └── updated_schema.esdl │ ├── dump01_default.esdl │ ├── dump01_setup.edgeql │ ├── dump01_test.esdl │ ├── dump02_default.esdl │ ├── dump02_setup.edgeql │ ├── dump03_default.esdl │ ├── dump03_setup.edgeql │ ├── dump_v2_default.esdl │ ├── dump_v2_setup.edgeql │ ├── dump_v3_default.esdl │ ├── dump_v3_setup.edgeql │ ├── dump_v4_default.esdl │ ├── dump_v4_setup.edgeql │ ├── dump_v5_default.esdl │ ├── dump_v5_setup.edgeql │ ├── dump_v6_default.esdl │ ├── dump_v6_setup.edgeql │ ├── dump_v7_default.esdl │ ├── dump_v7_setup.edgeql │ ├── enums.esdl │ ├── explain.esdl │ ├── explain_bug5758.esdl │ ├── explain_bug5791.esdl │ ├── explain_setup.edgeql │ ├── ext_ai.esdl │ ├── 
fts.esdl │ ├── fts_setup.edgeql │ ├── graphql.esdl │ ├── graphql_other.esdl │ ├── graphql_schema.esdl │ ├── graphql_schema_other.esdl │ ├── graphql_schema_other_deep.esdl │ ├── graphql_setup.edgeql │ ├── insert.esdl │ ├── interpreter_disambiguation.esdl │ ├── interpreter_disambiguation_setup.edgeql │ ├── inventory.esdl │ ├── inventory_setup.edgeql │ ├── issues.esdl │ ├── issues_coalesce_setup.edgeql │ ├── issues_filter_setup.edgeql │ ├── issues_setup.edgeql │ ├── json.esdl │ ├── json_setup.edgeql │ ├── link_tgt_del.esdl │ ├── link_tgt_del_migrated.esdl │ ├── links_1.esdl │ ├── links_1_migrated.esdl │ ├── movies.esdl │ ├── movies_setup.edgeql │ ├── pg_dump01_default.esdl │ ├── pg_dump01_setup.edgeql │ ├── pg_dump02_default.esdl │ ├── pg_dump02_setup.edgeql │ ├── pg_trgm.esdl │ ├── pg_trgm_setup.edgeql │ ├── pg_unaccent.esdl │ ├── pgvector.esdl │ ├── pgvector_setup.edgeql │ ├── smoke_test_interp.esdl │ ├── smoke_test_interp_setup.edgeql │ ├── tree.esdl │ ├── tree_setup.edgeql │ ├── updates.edgeql │ ├── updates.esdl │ ├── volatility.esdl │ └── volatility_setup.edgeql ├── test_api_errors.py ├── test_backend_connect.py ├── test_backend_ha.py ├── test_constraints.py ├── test_database.py ├── test_docs.py ├── test_docs_sphinx_ext.py ├── test_dump01.py ├── test_dump02.py ├── test_dump03.py ├── test_dump_basic.py ├── test_dump_v2.py ├── test_dump_v3.py ├── test_dump_v4.py ├── test_dump_v5.py ├── test_dump_v6.py ├── test_dump_v7.py ├── test_edgeql_advtypes.py ├── test_edgeql_calls.py ├── test_edgeql_casts.py ├── test_edgeql_coalesce.py ├── test_edgeql_data_migration.py ├── test_edgeql_datatypes.py ├── test_edgeql_ddl.py ├── test_edgeql_delete.py ├── test_edgeql_enums.py ├── test_edgeql_explain.py ├── test_edgeql_expr_aliases.py ├── test_edgeql_expressions.py ├── test_edgeql_ext_pg_trgm.py ├── test_edgeql_ext_pg_unaccent.py ├── test_edgeql_ext_pgcrypto.py ├── test_edgeql_extensions.py ├── test_edgeql_filter.py ├── test_edgeql_for.py ├── test_edgeql_fts.py ├── 
test_edgeql_fts_schema.py ├── test_edgeql_functions.py ├── test_edgeql_functions_inline.py ├── test_edgeql_globals.py ├── test_edgeql_group.py ├── test_edgeql_insert.py ├── test_edgeql_internal_group.py ├── test_edgeql_introspection.py ├── test_edgeql_ir_card_inference.py ├── test_edgeql_ir_mult_inference.py ├── test_edgeql_ir_pathid.py ├── test_edgeql_ir_scopetree.py ├── test_edgeql_ir_type_inference.py ├── test_edgeql_ir_volatility_inference.py ├── test_edgeql_json.py ├── test_edgeql_linkatoms.py ├── test_edgeql_linkprops.py ├── test_edgeql_net_schema.py ├── test_edgeql_permissions.py ├── test_edgeql_policies.py ├── test_edgeql_rewrites.py ├── test_edgeql_scope.py ├── test_edgeql_select.py ├── test_edgeql_select_interpreter.py ├── test_edgeql_sql_codegen.py ├── test_edgeql_syntax.py ├── test_edgeql_sys.py ├── test_edgeql_tree.py ├── test_edgeql_triggers.py ├── test_edgeql_tutorial.py ├── test_edgeql_update.py ├── test_edgeql_userddl.py ├── test_edgeql_vector.py ├── test_edgeql_volatility.py ├── test_eval_model.py ├── test_eval_model_group.py ├── test_eval_model_new_interpreter.py ├── test_ext_ai.py ├── test_http.py ├── test_http_auth.py ├── test_http_edgeql.py ├── test_http_ext_auth.py ├── test_http_graphql_mutation.py ├── test_http_graphql_query.py ├── test_http_graphql_schema.py ├── test_http_notebook.py ├── test_http_std_net.py ├── test_indexes.py ├── test_interpreter_disambiguation.py ├── test_language_server.py ├── test_link_target_delete.py ├── test_pg_dump.py ├── test_pgext.py ├── test_profiling.py ├── test_protocol.py ├── test_schema.py ├── test_schema_syntax.py ├── test_server_auth.py ├── test_server_compiler.py ├── test_server_concurrency.py ├── test_server_config.py ├── test_server_ops.py ├── test_server_param_conversions.py ├── test_server_permissions.py ├── test_server_pool.py ├── test_server_proto.py ├── test_server_request_scheduler.py ├── test_server_unit.py ├── test_session.py ├── test_sourcecode.py ├── test_sql_dml.py ├── test_sql_parse.py ├── 
test_sql_query.py └── test_tracer.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .editorconfig ================================================ root = true [*] trim_trailing_whitespace = true insert_final_newline = true [*.{py,pyx,pxd,pxi,h}] indent_size = 4 indent_style = space [*.yml] indent_size = 2 indent_style = space [edb_stat_statements/*.{c,h,l,y,pl,pm}] indent_style = tab indent_size = tab tab_width = 4 ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: Bug report about: Create a report to help us improve --- - Gel Version: - Gel CLI Version: - OS Version: Steps to Reproduce: 1. 2. Schema: ================================================ FILE: .github/ISSUE_TEMPLATE/config.yml ================================================ blank_issues_enabled: false contact_links: - name: Problem with Gel UI url: https://github.com/geldata/gel-ui/issues about: If you've found a bug or have a feature request for Gel UI, please open an issue in the gel-ui repo. - name: Long question or idea url: https://github.com/geldata/gel/discussions about: Ask long-form questions and discuss ideas. - name: Quick questions or chat url: https://www.geldata.com/p/discord about: Ask quick questions or simply chat on the Gel Discord server. 
================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.md ================================================ --- name: Feature request about: Suggest a feature for Gel --- ================================================ FILE: .github/Makefile ================================================ .PHONY: all ROOT = $(dir $(realpath $(firstword $(MAKEFILE_LIST)))) all: \ workflows/build.nightly.yml \ workflows/build.release.yml \ workflows/build.testing.yml \ workflows/build.dryrun.yml \ workflows/build.ls-nightly.yml \ workflows/tests.yml \ workflows/tests.pool.yml \ workflows/tests.managed-pg.yml \ workflows/tests.ha.yml \ workflows/tests.pg-versions.yml \ workflows/tests.patches.yml \ workflows/tests.inplace.yml \ workflows/tests.inplace7x.yml \ workflows/tests.reflection.yml \ workflows/build.%.yml: workflows.src/build.%.tpl.yml workflows.src/build.%.targets.yml workflows.src/build.inc.yml $(ROOT)/workflows.src/render.py --workflow=build build.$* build.$*.targets.yml workflows/tests.yml: workflows.src/tests.tpl.yml workflows.src/tests.targets.yml workflows.src/tests.inc.yml $(ROOT)/workflows.src/render.py --workflow=test tests tests.targets.yml workflows/tests.%.yml: workflows.src/tests.%.tpl.yml workflows.src/tests.%.targets.yml workflows.src/tests.inc.yml $(ROOT)/workflows.src/render.py --workflow=test tests.$* tests.$*.targets.yml ================================================ FILE: .github/aws-aurora/.gitignore ================================================ /.terraform/ /terraform.tfstate* /.terraform.lock.hcl /.terraform.tfstate.lock.info ================================================ FILE: .github/aws-aurora/main.tf ================================================ variable "vpc_id" { description = "VPC ID" } variable "sg_id" { description = "security group ID" } variable "password" { description = "password, provide through your ENV variables" } module "aurora" { source = "terraform-aws-modules/rds-aurora/aws" version 
= "~> 5.0" name = "aws-aurora-instance" engine = "aurora-postgresql" engine_version = "13.4" instance_type = "db.r6g.large" vpc_id = var.vpc_id db_subnet_group_name = "default" replica_count = 1 create_security_group = false vpc_security_group_ids = [var.sg_id] storage_encrypted = true apply_immediately = true username = "edbtest" password = var.password create_random_password = false enabled_cloudwatch_logs_exports = ["postgresql"] publicly_accessible = true skip_final_snapshot = true tags = { Environment = "dev" Terraform = "true" } } output "rds_cluster_endpoint" { description = "The cluster endpoint" value = module.aurora.rds_cluster_endpoint } ================================================ FILE: .github/aws-rds/.gitignore ================================================ /.terraform/ /terraform.tfstate* ================================================ FILE: .github/aws-rds/.terraform.lock.hcl ================================================ # This file is maintained automatically by "terraform init". # Manual edits may be lost in future updates. 
provider "registry.terraform.io/hashicorp/aws" { version = "3.31.0" hashes = [ "h1:Wou3ZnO10ZvN+n1iwyuaxn3zyGMFj9KYL+9IFb0gGkw=", "zh:07f5b2f4cfaa25e26a4062ac675e3e5aaf65bb21b94b8fd7f30d576398e7410f", "zh:08a2154ad29ae130ea9e46948b7b332ec4b45321b4852b45ba60adcfd049f8d6", "zh:35ed643c2b999021ad56b49f7d9d3a77c98d152477fe54b5c8a68f696bb1a0b7", "zh:3a8dc51b4be1c04130fd76cda4280019020b276336d307e7074ad52f35d4fdda", "zh:3c910c4f25e3ffd6d84f051c32161f03d1843753cd545e769757d7b42d654003", "zh:5d23f316f89937cbda36207271bbe150f633298f96d4644fd02063fc6bf0c28f", "zh:61fedb2915c5188c6550677a10acb955f32834bbe99ba0cafb2a118be282827b", "zh:65076a6899c0781ce95064d47d587ad07f80becd1510e4c475e4554131caec09", "zh:acca833c2d9985e46298323222285b370ea7cf5299b131dbdfc7c3e66fa32401", "zh:c212cf8ba7fdf64e75accf7e745f76d2349b00553ebd928cc6cafbfda99d97b7", "zh:cd3f5e89ac5f5cf3f8fed3aca4cc50261d537b60a3490feaddf9ba2f06e5e7aa", ] } ================================================ FILE: .github/aws-rds/main.tf ================================================ resource "aws_db_instance" "default" { allocated_storage = 10 engine = "postgres" engine_version = "13.4" instance_class = "db.m6g.large" name = "edbtest" username = "edbtest" password = var.password parameter_group_name = "default.postgres13" skip_final_snapshot = true auto_minor_version_upgrade = false publicly_accessible = true vpc_security_group_ids = [var.sg_id] } ================================================ FILE: .github/aws-rds/outputs.tf ================================================ output "db_instance_id" { value = aws_db_instance.default.id } output "db_instance_address" { value = aws_db_instance.default.address } ================================================ FILE: .github/aws-rds/variables.tf ================================================ variable "sg_id" { description = "security group ID" } variable "password" { description = "password, provide through your ENV variables" } 
================================================ FILE: .github/do-database/.gitignore ================================================ /.terraform/ /terraform.tfstate* ================================================ FILE: .github/do-database/.terraform.lock.hcl ================================================ # This file is maintained automatically by "terraform init". # Manual edits may be lost in future updates. provider "registry.terraform.io/digitalocean/digitalocean" { version = "2.6.0" constraints = "2.6.0" hashes = [ "h1:P1C7e6RlhLpi6KuE/sMruDdM5zZisJwMuKGbnxg8tAw=", "zh:088c2a4eb9579947d50d8bcd722e75f2f1839acae302c8d43133b1da9926dae3", "zh:323ba833d011371ca6d953752b133c0acad6462176cd2f804077a5f9d892cd2e", "zh:3fbc64f1fabe57b6df49511c0d8753f1bbf776d5824ba060a51961d2a4265097", "zh:4c90a933e23288ee2db2228e4e30055882d91bed831c2191cbecd849b27e44cb", "zh:62f1cf4c82e5fcaf1a17e39cb96638f006b303758813a6c5ecb08bc93cd93364", "zh:68ad1354e9f925477dc41e658e84a4996ba662920bbc61a2680235b94811169b", "zh:9119b573c59429c2dfacb7d95b39c4e021783b8281ecd68f1621ad4a17c112cd", "zh:9c15e3660f2399c25ee3ad53bd54927a6529d1393a54f1e1c2a523e0369dea46", "zh:bc88f68bf6a6b5e803734f06731e31d61a5977ed1a638bfe102a54094c4d4030", "zh:c2b013a5d7e60b31211b0f8c0dd898840b8f1aa7225318da05def33b5edb9388", "zh:e46e21f6ffa7aac11ade8ab4b87a28ac405ef40a35793cef1f1fd6db6d8e5a0a", "zh:e879643369e03abc192fbcf7ab06611bb8f36d37ceb5641ba05d58869f10ab7c", "zh:ee9b56400e545ce1805842b795179a004313b8a947bd8f3490f5c5a0cb7703e5", "zh:fb44861ae0b58b594aa4e565e0ed06bce939753b14a20b4abd3e8276e839e7a7", ] } ================================================ FILE: .github/do-database/main.tf ================================================ terraform { required_providers { digitalocean = { source = "digitalocean/digitalocean" version = "2.6.0" } } } variable "do_token" {} provider "digitalocean" { token = var.do_token } resource "digitalocean_database_cluster" "default" { name = "edbtest" engine = "pg" version = "13" size 
= "db-s-4vcpu-8gb" region = "nyc1" node_count = 1 } ================================================ FILE: .github/do-database/outputs.tf ================================================ output "db_instance_address" { value = digitalocean_database_cluster.default.host } output "db_instance_port" { value = digitalocean_database_cluster.default.port } output "db_instance_user" { value = digitalocean_database_cluster.default.user } output "db_instance_password" { value = digitalocean_database_cluster.default.password sensitive = true } output "db_instance_database" { value = digitalocean_database_cluster.default.database } ================================================ FILE: .github/gcp-cloud-sql/.gitignore ================================================ /.terraform/ /terraform.tfstate* ================================================ FILE: .github/gcp-cloud-sql/.terraform.lock.hcl ================================================ # This file is maintained automatically by "terraform init". # Manual edits may be lost in future updates. 
provider "registry.terraform.io/hashicorp/google" { version = "3.62.0" hashes = [ "h1:FgfQz6EhKglcoU7vu1srYqEQFXy1Dti9MoZCxW8HL/w=", "zh:26e44482924c9d22624054dcebf23c89b102aee6b5c66675747cf2f7274cf703", "zh:518ebd73eb8f286f60a0c74970cd4e06883962c4af57f2899bc790d89e04038f", "zh:814036d49d5034cf26fd2239fc57075b42982e1f76ab703fa1cd7609802d979f", "zh:822dce72d1a77e1418b0e9187b4fe6f3e47b38ea5e51b81e5912074a8be3a7b7", "zh:981fc6780e1e9c756390727b94ebd822490f7504a05a26c818922da5635ff9b8", "zh:9a1a7e76ac6c37922261bdb148052fcdcbaf1f521ade68e26b430c106f1974b1", "zh:cb67b6abed58b6d1b789a72690154fcf35707f65c3fca1936bf72c0c819a03dd", "zh:cb87e8425b0eb97d80627243a37a67f0f81640499416ad32f1b786cc9d78c6f4", "zh:d3754c3f05dc9bbd4933b45676144c2dd456de775bff0252c058e0cff94b8f21", "zh:e2d8b0a78d698e92035e339782b299108d6021768ea4d97d150106c524f84ca1", ] } ================================================ FILE: .github/gcp-cloud-sql/main.tf ================================================ variable "password" {} provider "google" { region = "us-east1" } resource "google_sql_database_instance" "default" { database_version = "POSTGRES_13" deletion_protection = false settings { tier = "db-custom-1-3840" ip_configuration { authorized_networks { value = "0.0.0.0/0" } } } } resource "google_sql_user" "users" { instance = google_sql_database_instance.default.name name = "postgres" password = var.password deletion_policy = "ABANDON" } output "db_instance_address" { value = google_sql_database_instance.default.public_ip_address } ================================================ FILE: .github/heroku-postgres/.gitignore ================================================ /.terraform/ /terraform.tfstate* /.terraform.lock.hcl /.terraform.tfstate.lock.info ================================================ FILE: .github/heroku-postgres/main.tf ================================================ terraform { required_providers { heroku = { source = "heroku/heroku" version = "~> 4.0" } } } resource "heroku_addon" 
"database" { app = "edgedb-heroku-ci" plan = "heroku-postgresql:mini" config = { version = "14" } } output "heroku_postgres_dsn" { value = heroku_addon.database.config_var_values.DATABASE_URL sensitive = true } ================================================ FILE: .github/scripts/docs/preview-deploy.js ================================================ const DOCS_SITE_REPO = { org: "edgedb", repo: "edgedb.com", ref: "master", }; module.exports = async ({ github, context }) => { const { VERCEL_TOKEN, VERCEL_TEAM_ID } = process.env; if (!VERCEL_TOKEN || !VERCEL_TEAM_ID) { throw new Error( `cannot run docs preview deploy workflow, ` + `VERCEL_TOKEN or VERCEL_TEAM_ID secrets are missing` ); } const prBranch = context.payload.pull_request.head.ref; const commitSHA = context.payload.pull_request.head.sha; const shortCommitSHA = commitSHA.slice(0, 8); const existingComments = ( await github.rest.issues.listComments({ owner: context.repo.owner, repo: context.repo.repo, issue_number: context.issue.number, }) ).data; const commentHeader = `### Docs preview deploy\n`; let commentMessage = commentHeader; let updateComment = existingComments.find( (c) => c.performed_via_github_app?.slug === "github-actions" && c.body?.startsWith(commentHeader) ); let deploymentError = null; let deployment; try { deployment = await vercelFetch("https://api.vercel.com/v13/deployments", { name: "edgedb-docs", gitSource: { type: "github", ...DOCS_SITE_REPO, }, projectSettings: { buildCommand: `EDGEDB_REPO_BRANCH=${prBranch} EDGEDB_REPO_SHA=${commitSHA} yarn vercel-build`, }, }); commentMessage += `\n🔄 Deploying docs preview for commit ${shortCommitSHA}:\n\n`; } catch (e) { deploymentError = e; commentMessage += `\n❌ Failed to deploy docs preview for commit ${shortCommitSHA}:\n\n\`\`\`\n${e.message}\n\`\`\``; } commentMessage += `\n\n(Last updated: ${formatDatetime(new Date())})`; if (updateComment) { await github.rest.issues.updateComment({ owner: context.repo.owner, repo: context.repo.repo, 
comment_id: updateComment.id, body: commentMessage, }); } else { updateComment = ( await github.rest.issues.createComment({ owner: context.repo.owner, repo: context.repo.repo, issue_number: context.issue.number, body: commentMessage, }) ).data; } if (deploymentError) { throw new Error(`Docs preview deployment failed: ${deploymentError.message}`); } let i = 0; while (i < 40) { await sleep(15_000); i++; const status = ( await vercelFetch( `https://api.vercel.com/v13/deployments/${deployment.id}` ) ).status; const latestComment = await github.rest.issues.getComment({ owner: context.repo.owner, repo: context.repo.repo, comment_id: updateComment.id, }); if (!latestComment.data.body.includes(shortCommitSHA)) { console.log("Skipping further updates, new deployment has started"); return; } if (status === "READY" || status === "ERROR" || status === "CANCELED") { await github.rest.issues.updateComment({ owner: context.repo.owner, repo: context.repo.repo, comment_id: updateComment.id, body: `${commentHeader}${ status === "READY" ? `\n✅ Successfully deployed docs preview for commit ${shortCommitSHA}:` : `\n❌ Docs preview deployment ${ status === "CANCELED" ? 
"was canceled" : "failed" } for commit ${shortCommitSHA}:` }\n\n\n\n(Last updated: ${formatDatetime( new Date() )})`, }); if (status !== "READY") { throw new Error( `Docs preview deployment failed with status ${status}: https://${deployment.url}` ); } return; } } await github.rest.issues.updateComment({ owner: context.repo.owner, repo: context.repo.repo, comment_id: updateComment.id, body: `${commentHeader} ❌ Timed out waiting for deployment status to succeed or fail for commit ${shortCommitSHA}:\n\n\n\n(Last updated: ${formatDatetime(new Date())})`, }); throw new Error("Timed out waiting for deployment status to succeed or fail"); }; async function vercelFetch(url, body) { const { VERCEL_TOKEN, VERCEL_TEAM_ID } = process.env; const _url = new URL(url); url = `${_url.origin}${_url.pathname}?${new URLSearchParams({ teamId: VERCEL_TEAM_ID, })}`; let res; try { res = await fetch(url, { body: body ? JSON.stringify(body) : undefined, headers: { Authorization: `Bearer ${VERCEL_TOKEN}`, "Content-Type": body ? "application/json" : undefined, }, method: body ? 
"post" : "get", }); } catch (e) { throw new Error(`vercel api request failed: ${e}`); } if (res.ok) { return await res.json(); } else { let body; try { body = await res.text(); } catch (e) { // ignore } throw new Error( `vercel api request failed: ${res.status} ${res.statusText}, ${body}` ); } } function formatDatetime(date) { return date.toLocaleString("en-US", { year: "numeric", month: "short", day: "numeric", hour: "numeric", minute: "numeric", second: "numeric", hourCycle: "h24", timeZoneName: "short", }); } function sleep(milliseconds) { return new Promise((resolve) => setTimeout(resolve, milliseconds)); } ================================================ FILE: .github/scripts/patches/compute-ipu-versions.py ================================================ # Compute prior minor versions to test upgrading from import json import os import pathlib import re import sys from urllib import request sys.path.append(str(pathlib.Path(__file__).parent.parent.parent.parent)) import edb.buildmeta base = 'https://packages.geldata.com' u = f'{base}/archive/.jsonindexes/x86_64-unknown-linux-gnu.json' data = json.loads(request.urlopen(u).read()) u = f'{base}/archive/.jsonindexes/x86_64-unknown-linux-gnu.testing.json' data_testing = json.loads(request.urlopen(u).read()) version = edb.buildmeta.EDGEDB_MAJOR_VERSION - 1 versions = [] prerelease_versions = [] for obj in data['packages'] + data_testing['packages']: if ( obj['basename'] == 'gel-server' and obj['version_details']['major'] == version and ( not obj['version_details']['prerelease'] or obj['version_details']['prerelease'][0]['phase'] in ('beta', 'rc') ) ): l = ( versions if not obj['version_details']['prerelease'] else prerelease_versions ) l.append(( obj['version'], obj['basename'], base + obj['installrefs'][0]['ref'], )) prerelease_versions.sort(key=lambda x: x[0]) if not versions: # Some 7.x prerelease versions are busted due to having taken # extension patches that we don't intend to bundle with 8.x. 
# Only look at the last. versions = prerelease_versions[-1:] versions.sort(key=lambda x: x[0]) if len(versions) > 3: # We want to try 6.0 and 6.2 versions = [versions[0], versions[2], versions[-1]] elif len(versions) > 1: versions = [versions[0], versions[-1]] matrix = { "include": [ {"edgedb-version": v, "edgedb-url": url, "edgedb-basename": base} for v, base, url in versions ] } print("matrix:", matrix) if output := os.getenv('GITHUB_OUTPUT'): with open(output, 'a') as f: print(f'matrix={json.dumps(matrix)}', file=f) ================================================ FILE: .github/scripts/patches/compute-versions.py ================================================ # Compute prior minor versions to test upgrading from import json import os import re from urllib import request base = 'https://packages.edgedb.com' u = f'{base}/archive/.jsonindexes/x86_64-unknown-linux-gnu.json' data = json.loads(request.urlopen(u).read()) u = f'{base}/archive/.jsonindexes/x86_64-unknown-linux-gnu.testing.json' data_testing = json.loads(request.urlopen(u).read()) branch = os.getenv('GITHUB_BASE_REF') or os.getenv('GITHUB_REF_NAME') print("BRANCH", branch) version = int(re.findall(r'\d+', branch)[0]) versions = [] for obj in data['packages'] + data_testing['packages']: if ( obj['basename'] in {'gel-server', 'edgedb-server'} and obj['version_details']['major'] == version and ( not obj['version_details']['prerelease'] or obj['version_details']['prerelease'][0]['phase'] in ('beta', 'rc') ) ): versions.append(( obj['version'], obj['basename'], base + obj['installrefs'][0]['ref'], )) matrix = { "include": [ {"edgedb-version": v, "edgedb-url": url, "edgedb-basename": base, "make-dbs": mk} for v, base, url in versions for mk in [True, False] ] } print("matrix:", matrix) if output := os.getenv('GITHUB_OUTPUT'): with open(output, 'a') as f: print(f'matrix={json.dumps(matrix)}', file=f) ================================================ FILE: .github/scripts/patches/create-databases.py 
================================================ # Create databases on the older edgedb version import edgedb import subprocess import sys cmd = [ sys.argv[1], '-D' 'test-dir', '--testmode', '--security', 'insecure_dev_mode', '--port', '10000', ] proc = subprocess.Popen(cmd) try: db = edgedb.create_client( host='localhost', port=10000, tls_security='insecure' ) for name in [ 'json', 'functions', 'expressions', 'casts', 'policies', 'vector', 'scope', 'httpextauth', ]: db.execute(f'create database {name};') # For the scope database, let's actually migrate to it. This # will test that the migrations can still work after the upgrade. db2 = edgedb.create_client( host='localhost', port=10000, tls_security='insecure', database='scope' ) with open("tests/schemas/cards.esdl") as f: body = f.read() db2.execute(f''' START MIGRATION TO {{ module default {{ {body} }} }}; POPULATE MIGRATION; COMMIT MIGRATION; ''') # Put something in the query cache db2.query(r''' SELECT User { name, id } ORDER BY User.name; ''') db2.close() # Compile a query from the CLI. # (At one point, having a cached query with proto version 1 caused # trouble...) cli_base = [ 'gel', 'query', '-H', 'localhost', '-P', '10000', '-b', 'json', '--tls-security', 'insecure', ] subprocess.run( [*cli_base, 'select 1+1'], check=True, ) # For the httpextauth database, create the proper extensions, so # that patching of the auth extension in place can get tested. 
db2 = edgedb.create_client( host='localhost', port=10000, tls_security='insecure', database='httpextauth' ) db2.execute(f''' create extension pgcrypto; create extension auth; ''') db2.close() finally: proc.terminate() proc.wait() ================================================ FILE: .github/scripts/patches/test-downgrade.py ================================================ # Test downgrading a database after an upgrade import edgedb import os import subprocess import json version = os.getenv('EDGEDB_VERSION') cmd = [ f'edgedb-server-{version}/bin/edgedb-server', '-D' 'test-dir', '--testmode', '--security', 'insecure_dev_mode', '--port', '10000', ] proc = subprocess.Popen(cmd) db = edgedb.create_client( host='localhost', port=10000, tls_security='insecure', database='policies', ) try: # Test that a basic query works res = json.loads(db.query_json(''' select Issue { name, number, watchers: {name} } filter .number = "1" ''')) expected = [{ "name": "Release EdgeDB", "number": "1", "watchers": [{"name": "Yury"}], }] assert res == expected, res finally: proc.terminate() proc.wait() ================================================ FILE: .github/workflows/.gitattributes ================================================ *.yml linguist-generated=true ================================================ FILE: .github/workflows/build.dryrun.yml ================================================ name: Package Build Dry Run on: workflow_dispatch: inputs: gelpkg_ref: description: "gel-pkg git ref used to build the packages" default: "master" metapkg_ref: description: "metapkg git ref used to build the packages" default: "master" jobs: prep: runs-on: ubuntu-latest outputs: if_debian_buster_x86_64: ${{ steps.scm.outputs.if_debian_buster_x86_64 }} if_debian_buster_aarch64: ${{ steps.scm.outputs.if_debian_buster_aarch64 }} if_debian_bullseye_x86_64: ${{ steps.scm.outputs.if_debian_bullseye_x86_64 }} if_debian_bullseye_aarch64: ${{ steps.scm.outputs.if_debian_bullseye_aarch64 }} 
if_debian_bookworm_x86_64: ${{ steps.scm.outputs.if_debian_bookworm_x86_64 }} if_debian_bookworm_aarch64: ${{ steps.scm.outputs.if_debian_bookworm_aarch64 }} if_ubuntu_focal_x86_64: ${{ steps.scm.outputs.if_ubuntu_focal_x86_64 }} if_ubuntu_focal_aarch64: ${{ steps.scm.outputs.if_ubuntu_focal_aarch64 }} if_ubuntu_jammy_x86_64: ${{ steps.scm.outputs.if_ubuntu_jammy_x86_64 }} if_ubuntu_jammy_aarch64: ${{ steps.scm.outputs.if_ubuntu_jammy_aarch64 }} if_ubuntu_noble_x86_64: ${{ steps.scm.outputs.if_ubuntu_noble_x86_64 }} if_ubuntu_noble_aarch64: ${{ steps.scm.outputs.if_ubuntu_noble_aarch64 }} if_centos_8_x86_64: ${{ steps.scm.outputs.if_centos_8_x86_64 }} if_centos_8_aarch64: ${{ steps.scm.outputs.if_centos_8_aarch64 }} if_rockylinux_9_x86_64: ${{ steps.scm.outputs.if_rockylinux_9_x86_64 }} if_rockylinux_9_aarch64: ${{ steps.scm.outputs.if_rockylinux_9_aarch64 }} if_linux_x86_64: ${{ steps.scm.outputs.if_linux_x86_64 }} if_linux_aarch64: ${{ steps.scm.outputs.if_linux_aarch64 }} if_linuxmusl_x86_64: ${{ steps.scm.outputs.if_linuxmusl_x86_64 }} if_linuxmusl_aarch64: ${{ steps.scm.outputs.if_linuxmusl_aarch64 }} if_macos_x86_64: ${{ steps.scm.outputs.if_macos_x86_64 }} if_macos_aarch64: ${{ steps.scm.outputs.if_macos_aarch64 }} steps: - uses: actions/checkout@v4 - name: Determine SCM revision id: scm shell: bash run: | rev=$(git rev-parse HEAD) jq_filter='.packages[] | select(.basename == "gel-server") | select(.architecture == $ARCH) | .version_details.metadata.scm_revision | . as $rev | select(($rev != null) and ($REV | startswith($rev)))' key="debian-buster-x86_64" val=true idx_file=buster.nightly.json url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! 
-e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="debian-buster-aarch64" val=true idx_file=buster.nightly.json url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! -e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="debian-bullseye-x86_64" val=true idx_file=bullseye.nightly.json url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! -e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="debian-bullseye-aarch64" val=true idx_file=bullseye.nightly.json url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! -e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="debian-bookworm-x86_64" val=true idx_file=bookworm.nightly.json url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! 
        -e "$tmp_file" ]; then
          curl --fail -o $tmp_file -s $url || true
        fi
        if [ -e "$tmp_file" ]; then
          out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter")
          if [ -n "$out" ]; then
            echo "Skip rebuilding existing ${key}"
            val=false
          fi
        fi
        echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT

        key="debian-bookworm-aarch64"
        val=true
        idx_file=bookworm.nightly.json
        url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file
        tmp_file="/tmp/$idx_file"
        if [ ! -e "$tmp_file" ]; then
          curl --fail -o $tmp_file -s $url || true
        fi
        if [ -e "$tmp_file" ]; then
          out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter")
          if [ -n "$out" ]; then
            echo "Skip rebuilding existing ${key}"
            val=false
          fi
        fi
        echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT

        key="ubuntu-focal-x86_64"
        val=true
        idx_file=focal.nightly.json
        url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file
        tmp_file="/tmp/$idx_file"
        if [ ! -e "$tmp_file" ]; then
          curl --fail -o $tmp_file -s $url || true
        fi
        if [ -e "$tmp_file" ]; then
          out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter")
          if [ -n "$out" ]; then
            echo "Skip rebuilding existing ${key}"
            val=false
          fi
        fi
        echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT

        key="ubuntu-focal-aarch64"
        val=true
        idx_file=focal.nightly.json
        url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file
        tmp_file="/tmp/$idx_file"
        if [ ! -e "$tmp_file" ]; then
          curl --fail -o $tmp_file -s $url || true
        fi
        if [ -e "$tmp_file" ]; then
          out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter")
          if [ -n "$out" ]; then
            echo "Skip rebuilding existing ${key}"
            val=false
          fi
        fi
        echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT

        key="ubuntu-jammy-x86_64"
        val=true
        idx_file=jammy.nightly.json
        url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file
        tmp_file="/tmp/$idx_file"
        if [ ! -e "$tmp_file" ]; then
          curl --fail -o $tmp_file -s $url || true
        fi
        if [ -e "$tmp_file" ]; then
          out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter")
          if [ -n "$out" ]; then
            echo "Skip rebuilding existing ${key}"
            val=false
          fi
        fi
        echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT

        key="ubuntu-jammy-aarch64"
        val=true
        idx_file=jammy.nightly.json
        url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file
        tmp_file="/tmp/$idx_file"
        if [ ! -e "$tmp_file" ]; then
          curl --fail -o $tmp_file -s $url || true
        fi
        if [ -e "$tmp_file" ]; then
          out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter")
          if [ -n "$out" ]; then
            echo "Skip rebuilding existing ${key}"
            val=false
          fi
        fi
        echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT

        key="ubuntu-noble-x86_64"
        val=true
        idx_file=noble.nightly.json
        url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file
        tmp_file="/tmp/$idx_file"
        if [ ! -e "$tmp_file" ]; then
          curl --fail -o $tmp_file -s $url || true
        fi
        if [ -e "$tmp_file" ]; then
          out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter")
          if [ -n "$out" ]; then
            echo "Skip rebuilding existing ${key}"
            val=false
          fi
        fi
        echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT

        key="ubuntu-noble-aarch64"
        val=true
        idx_file=noble.nightly.json
        url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file
        tmp_file="/tmp/$idx_file"
        if [ ! -e "$tmp_file" ]; then
          curl --fail -o $tmp_file -s $url || true
        fi
        if [ -e "$tmp_file" ]; then
          out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter")
          if [ -n "$out" ]; then
            echo "Skip rebuilding existing ${key}"
            val=false
          fi
        fi
        echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT

        key="centos-8-x86_64"
        val=true
        idx_file=el8.nightly.json
        url=https://packages.edgedb.com/rpm/.jsonindexes/$idx_file
        tmp_file="/tmp/$idx_file"
        if [ ! -e "$tmp_file" ]; then
          curl --fail -o $tmp_file -s $url || true
        fi
        if [ -e "$tmp_file" ]; then
          out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter")
          if [ -n "$out" ]; then
            echo "Skip rebuilding existing ${key}"
            val=false
          fi
        fi
        echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT

        key="centos-8-aarch64"
        val=true
        idx_file=el8.nightly.json
        url=https://packages.edgedb.com/rpm/.jsonindexes/$idx_file
        tmp_file="/tmp/$idx_file"
        if [ ! -e "$tmp_file" ]; then
          curl --fail -o $tmp_file -s $url || true
        fi
        if [ -e "$tmp_file" ]; then
          out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter")
          if [ -n "$out" ]; then
            echo "Skip rebuilding existing ${key}"
            val=false
          fi
        fi
        echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT

        key="rockylinux-9-x86_64"
        val=true
        idx_file=el9.nightly.json
        url=https://packages.edgedb.com/rpm/.jsonindexes/$idx_file
        tmp_file="/tmp/$idx_file"
        if [ ! -e "$tmp_file" ]; then
          curl --fail -o $tmp_file -s $url || true
        fi
        if [ -e "$tmp_file" ]; then
          out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter")
          if [ -n "$out" ]; then
            echo "Skip rebuilding existing ${key}"
            val=false
          fi
        fi
        echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT

        key="rockylinux-9-aarch64"
        val=true
        idx_file=el9.nightly.json
        url=https://packages.edgedb.com/rpm/.jsonindexes/$idx_file
        tmp_file="/tmp/$idx_file"
        if [ ! -e "$tmp_file" ]; then
          curl --fail -o $tmp_file -s $url || true
        fi
        if [ -e "$tmp_file" ]; then
          out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter")
          if [ -n "$out" ]; then
            echo "Skip rebuilding existing ${key}"
            val=false
          fi
        fi
        echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT

        key="linux-x86_64"
        val=true
        idx_file=x86_64-unknown-linux-gnu.nightly.json
        url=https://packages.edgedb.com/archive/.jsonindexes/$idx_file
        tmp_file="/tmp/$idx_file"
        if [ ! -e "$tmp_file" ]; then
          curl --fail -o $tmp_file -s $url || true
        fi
        if [ -e "$tmp_file" ]; then
          out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter")
          if [ -n "$out" ]; then
            echo "Skip rebuilding existing ${key}"
            val=false
          fi
        fi
        echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT

        key="linux-aarch64"
        val=true
        idx_file=aarch64-unknown-linux-gnu.nightly.json
        url=https://packages.edgedb.com/archive/.jsonindexes/$idx_file
        tmp_file="/tmp/$idx_file"
        if [ ! -e "$tmp_file" ]; then
          curl --fail -o $tmp_file -s $url || true
        fi
        if [ -e "$tmp_file" ]; then
          out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter")
          if [ -n "$out" ]; then
            echo "Skip rebuilding existing ${key}"
            val=false
          fi
        fi
        echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT

        key="linuxmusl-x86_64"
        val=true
        idx_file=x86_64-unknown-linux-musl.nightly.json
        url=https://packages.edgedb.com/archive/.jsonindexes/$idx_file
        tmp_file="/tmp/$idx_file"
        if [ ! -e "$tmp_file" ]; then
          curl --fail -o $tmp_file -s $url || true
        fi
        if [ -e "$tmp_file" ]; then
          out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter")
          if [ -n "$out" ]; then
            echo "Skip rebuilding existing ${key}"
            val=false
          fi
        fi
        echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT

        key="linuxmusl-aarch64"
        val=true
        idx_file=aarch64-unknown-linux-musl.nightly.json
        url=https://packages.edgedb.com/archive/.jsonindexes/$idx_file
        tmp_file="/tmp/$idx_file"
        if [ ! -e "$tmp_file" ]; then
          curl --fail -o $tmp_file -s $url || true
        fi
        if [ -e "$tmp_file" ]; then
          out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter")
          if [ -n "$out" ]; then
            echo "Skip rebuilding existing ${key}"
            val=false
          fi
        fi
        echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT

        key="macos-x86_64"
        val=true
        idx_file=x86_64-unknown-linux-gnu.nightly.json
        url=https://packages.edgedb.com/archive/.jsonindexes/$idx_file
        tmp_file="/tmp/$idx_file"
        if [ ! -e "$tmp_file" ]; then
          curl --fail -o $tmp_file -s $url || true
        fi
        if [ -e "$tmp_file" ]; then
          out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter")
          if [ -n "$out" ]; then
            echo "Skip rebuilding existing ${key}"
            val=false
          fi
        fi
        echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT

        key="macos-aarch64"
        val=true
        idx_file=aarch64-unknown-linux-gnu.nightly.json
        url=https://packages.edgedb.com/archive/.jsonindexes/$idx_file
        tmp_file="/tmp/$idx_file"
        if [ ! -e "$tmp_file" ]; then
          curl --fail -o $tmp_file -s $url || true
        fi
        if [ -e "$tmp_file" ]; then
          out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter")
          if [ -n "$out" ]; then
            echo "Skip rebuilding existing ${key}"
            val=false
          fi
        fi
        echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT

  build-debian-buster-x86_64:
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    needs: prep
    if: needs.prep.outputs.if_debian_buster_x86_64 == 'true'
    steps:
    - name: Build
      uses: docker://ghcr.io/geldata/gelpkg-build-debian-buster:latest
      env:
        PACKAGE: "edgedbpkg.edgedb:Gel"
        SRC_REF: "${{ github.sha }}"
        PKG_REVISION: ""
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "buster"
        EXTRA_OPTIMIZATIONS: "true"
        METAPKG_GIT_CACHE: disabled
        GEL_PKG_REF: ${{ inputs.gelpkg_ref }}
        METAPKG_REF: ${{ inputs.metapkg_ref }}
    - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
      with:
        name: builds-debian-buster-x86_64
        path: artifacts/debian-buster

  build-debian-buster-aarch64:
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    needs: prep
    if: needs.prep.outputs.if_debian_buster_aarch64 == 'true'
    steps:
    - name: Build
      uses: docker://ghcr.io/geldata/gelpkg-build-debian-buster:latest
      env:
        PACKAGE: "edgedbpkg.edgedb:Gel"
        SRC_REF: "${{ github.sha }}"
        PKG_REVISION: ""
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "buster"
        EXTRA_OPTIMIZATIONS: "true"
        METAPKG_GIT_CACHE: disabled
        GEL_PKG_REF: ${{ inputs.gelpkg_ref }}
        METAPKG_REF: ${{ inputs.metapkg_ref }}
    - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
      with:
        name: builds-debian-buster-aarch64
        path: artifacts/debian-buster

  build-debian-bullseye-x86_64:
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    needs: prep
    if: needs.prep.outputs.if_debian_bullseye_x86_64 == 'true'
    steps:
    - name: Build
      uses: docker://ghcr.io/geldata/gelpkg-build-debian-bullseye:latest
      env:
        PACKAGE: "edgedbpkg.edgedb:Gel"
        SRC_REF: "${{ github.sha }}"
        PKG_REVISION: ""
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "bullseye"
        EXTRA_OPTIMIZATIONS: "true"
        METAPKG_GIT_CACHE: disabled
        GEL_PKG_REF: ${{ inputs.gelpkg_ref }}
        METAPKG_REF: ${{ inputs.metapkg_ref }}
    - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
      with:
        name: builds-debian-bullseye-x86_64
        path: artifacts/debian-bullseye

  build-debian-bullseye-aarch64:
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    needs: prep
    if: needs.prep.outputs.if_debian_bullseye_aarch64 == 'true'
    steps:
    - name: Build
      uses: docker://ghcr.io/geldata/gelpkg-build-debian-bullseye:latest
      env:
        PACKAGE: "edgedbpkg.edgedb:Gel"
        SRC_REF: "${{ github.sha }}"
        PKG_REVISION: ""
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "bullseye"
        EXTRA_OPTIMIZATIONS: "true"
        METAPKG_GIT_CACHE: disabled
        GEL_PKG_REF: ${{ inputs.gelpkg_ref }}
        METAPKG_REF: ${{ inputs.metapkg_ref }}
    - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
      with:
        name: builds-debian-bullseye-aarch64
        path: artifacts/debian-bullseye

  build-debian-bookworm-x86_64:
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    needs: prep
    if: needs.prep.outputs.if_debian_bookworm_x86_64 == 'true'
    steps:
    - name: Build
      uses: docker://ghcr.io/geldata/gelpkg-build-debian-bookworm:latest
      env:
        PACKAGE: "edgedbpkg.edgedb:Gel"
        SRC_REF: "${{ github.sha }}"
        PKG_REVISION: ""
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "bookworm"
        EXTRA_OPTIMIZATIONS: "true"
        METAPKG_GIT_CACHE: disabled
        GEL_PKG_REF: ${{ inputs.gelpkg_ref }}
        METAPKG_REF: ${{ inputs.metapkg_ref }}
    - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
      with:
        name: builds-debian-bookworm-x86_64
        path: artifacts/debian-bookworm

  build-debian-bookworm-aarch64:
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    needs: prep
    if: needs.prep.outputs.if_debian_bookworm_aarch64 == 'true'
    steps:
    - name: Build
      uses: docker://ghcr.io/geldata/gelpkg-build-debian-bookworm:latest
      env:
        PACKAGE: "edgedbpkg.edgedb:Gel"
        SRC_REF: "${{ github.sha }}"
        PKG_REVISION: ""
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "bookworm"
        EXTRA_OPTIMIZATIONS: "true"
        METAPKG_GIT_CACHE: disabled
        GEL_PKG_REF: ${{ inputs.gelpkg_ref }}
        METAPKG_REF: ${{ inputs.metapkg_ref }}
    - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
      with:
        name: builds-debian-bookworm-aarch64
        path: artifacts/debian-bookworm

  build-ubuntu-focal-x86_64:
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    needs: prep
    if: needs.prep.outputs.if_ubuntu_focal_x86_64 == 'true'
    steps:
    - name: Build
      uses: docker://ghcr.io/geldata/gelpkg-build-ubuntu-focal:latest
      env:
        PACKAGE: "edgedbpkg.edgedb:Gel"
        SRC_REF: "${{ github.sha }}"
        PKG_REVISION: ""
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "ubuntu"
        PKG_PLATFORM_VERSION: "focal"
        EXTRA_OPTIMIZATIONS: "true"
        METAPKG_GIT_CACHE: disabled
        GEL_PKG_REF: ${{ inputs.gelpkg_ref }}
        METAPKG_REF: ${{ inputs.metapkg_ref }}
    - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
      with:
        name: builds-ubuntu-focal-x86_64
        path: artifacts/ubuntu-focal

  build-ubuntu-focal-aarch64:
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    needs: prep
    if: needs.prep.outputs.if_ubuntu_focal_aarch64 == 'true'
    steps:
    - name: Build
      uses: docker://ghcr.io/geldata/gelpkg-build-ubuntu-focal:latest
      env:
        PACKAGE: "edgedbpkg.edgedb:Gel"
        SRC_REF: "${{ github.sha }}"
        PKG_REVISION: ""
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "ubuntu"
        PKG_PLATFORM_VERSION: "focal"
        EXTRA_OPTIMIZATIONS: "true"
        METAPKG_GIT_CACHE: disabled
        GEL_PKG_REF: ${{ inputs.gelpkg_ref }}
        METAPKG_REF: ${{ inputs.metapkg_ref }}
    - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
      with:
        name: builds-ubuntu-focal-aarch64
        path: artifacts/ubuntu-focal

  build-ubuntu-jammy-x86_64:
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    needs: prep
    if: needs.prep.outputs.if_ubuntu_jammy_x86_64 == 'true'
    steps:
    - name: Build
      uses: docker://ghcr.io/geldata/gelpkg-build-ubuntu-jammy:latest
      env:
        PACKAGE: "edgedbpkg.edgedb:Gel"
        SRC_REF: "${{ github.sha }}"
        PKG_REVISION: ""
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "ubuntu"
        PKG_PLATFORM_VERSION: "jammy"
        EXTRA_OPTIMIZATIONS: "true"
        METAPKG_GIT_CACHE: disabled
        GEL_PKG_REF: ${{ inputs.gelpkg_ref }}
        METAPKG_REF: ${{ inputs.metapkg_ref }}
    - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
      with:
        name: builds-ubuntu-jammy-x86_64
        path: artifacts/ubuntu-jammy

  build-ubuntu-jammy-aarch64:
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    needs: prep
    if: needs.prep.outputs.if_ubuntu_jammy_aarch64 == 'true'
    steps:
    - name: Build
      uses: docker://ghcr.io/geldata/gelpkg-build-ubuntu-jammy:latest
      env:
        PACKAGE: "edgedbpkg.edgedb:Gel"
        SRC_REF: "${{ github.sha }}"
        PKG_REVISION: ""
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "ubuntu"
        PKG_PLATFORM_VERSION: "jammy"
        EXTRA_OPTIMIZATIONS: "true"
        METAPKG_GIT_CACHE: disabled
        GEL_PKG_REF: ${{ inputs.gelpkg_ref }}
        METAPKG_REF: ${{ inputs.metapkg_ref }}
    - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
      with:
        name: builds-ubuntu-jammy-aarch64
        path: artifacts/ubuntu-jammy

  build-ubuntu-noble-x86_64:
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    needs: prep
    if: needs.prep.outputs.if_ubuntu_noble_x86_64 == 'true'
    steps:
    - name: Build
      uses: docker://ghcr.io/geldata/gelpkg-build-ubuntu-noble:latest
      env:
        PACKAGE: "edgedbpkg.edgedb:Gel"
        SRC_REF: "${{ github.sha }}"
        PKG_REVISION: ""
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "ubuntu"
        PKG_PLATFORM_VERSION: "noble"
        EXTRA_OPTIMIZATIONS: "true"
        METAPKG_GIT_CACHE: disabled
        GEL_PKG_REF: ${{ inputs.gelpkg_ref }}
        METAPKG_REF: ${{ inputs.metapkg_ref }}
    - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
      with:
        name: builds-ubuntu-noble-x86_64
        path: artifacts/ubuntu-noble

  build-ubuntu-noble-aarch64:
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    needs: prep
    if: needs.prep.outputs.if_ubuntu_noble_aarch64 == 'true'
    steps:
    - name: Build
      uses: docker://ghcr.io/geldata/gelpkg-build-ubuntu-noble:latest
      env:
        PACKAGE: "edgedbpkg.edgedb:Gel"
        SRC_REF: "${{ github.sha }}"
        PKG_REVISION: ""
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "ubuntu"
        PKG_PLATFORM_VERSION: "noble"
        EXTRA_OPTIMIZATIONS: "true"
        METAPKG_GIT_CACHE: disabled
        GEL_PKG_REF: ${{ inputs.gelpkg_ref }}
        METAPKG_REF: ${{ inputs.metapkg_ref }}
    - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
      with:
        name: builds-ubuntu-noble-aarch64
        path: artifacts/ubuntu-noble

  build-centos-8-x86_64:
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    needs: prep
    if: needs.prep.outputs.if_centos_8_x86_64 == 'true'
    steps:
    - name: Build
      uses: docker://ghcr.io/geldata/gelpkg-build-centos-8:latest
      env:
        PACKAGE: "edgedbpkg.edgedb:Gel"
        SRC_REF: "${{ github.sha }}"
        PKG_REVISION: ""
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "centos"
        PKG_PLATFORM_VERSION: "8"
        EXTRA_OPTIMIZATIONS: "true"
        METAPKG_GIT_CACHE: disabled
        GEL_PKG_REF: ${{ inputs.gelpkg_ref }}
        METAPKG_REF: ${{ inputs.metapkg_ref }}
    - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
      with:
        name: builds-centos-8-x86_64
        path: artifacts/centos-8

  build-centos-8-aarch64:
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    needs: prep
    if: needs.prep.outputs.if_centos_8_aarch64 == 'true'
    steps:
    - name: Build
      uses: docker://ghcr.io/geldata/gelpkg-build-centos-8:latest
      env:
        PACKAGE: "edgedbpkg.edgedb:Gel"
        SRC_REF: "${{ github.sha }}"
        PKG_REVISION: ""
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "centos"
        PKG_PLATFORM_VERSION: "8"
        EXTRA_OPTIMIZATIONS: "true"
        METAPKG_GIT_CACHE: disabled
        GEL_PKG_REF: ${{ inputs.gelpkg_ref }}
        METAPKG_REF: ${{ inputs.metapkg_ref }}
    - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
      with:
        name: builds-centos-8-aarch64
        path: artifacts/centos-8

  build-rockylinux-9-x86_64:
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    needs: prep
    if: needs.prep.outputs.if_rockylinux_9_x86_64 == 'true'
    steps:
    - name: Build
      uses: docker://ghcr.io/geldata/gelpkg-build-rockylinux-9:latest
      env:
        PACKAGE: "edgedbpkg.edgedb:Gel"
        SRC_REF: "${{ github.sha }}"
        PKG_REVISION: ""
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "rockylinux"
        PKG_PLATFORM_VERSION: "9"
        EXTRA_OPTIMIZATIONS: "true"
        METAPKG_GIT_CACHE: disabled
        GEL_PKG_REF: ${{ inputs.gelpkg_ref }}
        METAPKG_REF: ${{ inputs.metapkg_ref }}
    - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
      with:
        name: builds-rockylinux-9-x86_64
        path: artifacts/rockylinux-9

  build-rockylinux-9-aarch64:
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    needs: prep
    if: needs.prep.outputs.if_rockylinux_9_aarch64 == 'true'
    steps:
    - name: Build
      uses: docker://ghcr.io/geldata/gelpkg-build-rockylinux-9:latest
      env:
        PACKAGE: "edgedbpkg.edgedb:Gel"
        SRC_REF: "${{ github.sha }}"
        PKG_REVISION: ""
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "rockylinux"
        PKG_PLATFORM_VERSION: "9"
        EXTRA_OPTIMIZATIONS: "true"
        METAPKG_GIT_CACHE: disabled
        GEL_PKG_REF: ${{ inputs.gelpkg_ref }}
        METAPKG_REF: ${{ inputs.metapkg_ref }}
    - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
      with:
        name: builds-rockylinux-9-aarch64
        path: artifacts/rockylinux-9

  build-linux-x86_64:
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    needs: prep
    if: needs.prep.outputs.if_linux_x86_64 == 'true'
    steps:
    - name: Build
      uses: docker://ghcr.io/geldata/gelpkg-build-linux-x86_64:latest
      env:
        PACKAGE: "edgedbpkg.edgedb:Gel"
        SRC_REF: "${{ github.sha }}"
        PKG_REVISION: ""
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "linux"
        PKG_PLATFORM_VERSION: "x86_64"
        EXTRA_OPTIMIZATIONS: "true"
        BUILD_GENERIC: true
        METAPKG_GIT_CACHE: disabled
        GEL_PKG_REF: ${{ inputs.gelpkg_ref }}
        METAPKG_REF: ${{ inputs.metapkg_ref }}
    - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
      with:
        name: builds-linux-x86_64
        path: artifacts/linux-x86_64

  build-linux-aarch64:
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    needs: prep
    if: needs.prep.outputs.if_linux_aarch64 == 'true'
    steps:
    - name: Build
      uses: docker://ghcr.io/geldata/gelpkg-build-linux-aarch64:latest
      env:
        PACKAGE: "edgedbpkg.edgedb:Gel"
        SRC_REF: "${{ github.sha }}"
        PKG_REVISION: ""
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "linux"
        PKG_PLATFORM_VERSION: "aarch64"
        EXTRA_OPTIMIZATIONS: "true"
        BUILD_GENERIC: true
        METAPKG_GIT_CACHE: disabled
        GEL_PKG_REF: ${{ inputs.gelpkg_ref }}
        METAPKG_REF: ${{ inputs.metapkg_ref }}
    - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
      with:
        name: builds-linux-aarch64
        path: artifacts/linux-aarch64

  build-linuxmusl-x86_64:
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    needs: prep
    if: needs.prep.outputs.if_linuxmusl_x86_64 == 'true'
    steps:
    - name: Build
      uses: docker://ghcr.io/geldata/gelpkg-build-linuxmusl-x86_64:latest
      env:
        PACKAGE: "edgedbpkg.edgedb:Gel"
        SRC_REF: "${{ github.sha }}"
        PKG_REVISION: ""
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "linux"
        PKG_PLATFORM_VERSION: "x86_64"
        EXTRA_OPTIMIZATIONS: "true"
        BUILD_GENERIC: true
        PKG_PLATFORM_LIBC: "musl"
        METAPKG_GIT_CACHE: disabled
        GEL_PKG_REF: ${{ inputs.gelpkg_ref }}
        METAPKG_REF: ${{ inputs.metapkg_ref }}
    - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
      with:
        name: builds-linuxmusl-x86_64
        path: artifacts/linuxmusl-x86_64

  build-linuxmusl-aarch64:
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    needs: prep
    if: needs.prep.outputs.if_linuxmusl_aarch64 == 'true'
    steps:
    - name: Build
      uses: docker://ghcr.io/geldata/gelpkg-build-linuxmusl-aarch64:latest
      env:
        PACKAGE: "edgedbpkg.edgedb:Gel"
        SRC_REF: "${{ github.sha }}"
        PKG_REVISION: ""
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "linux"
        PKG_PLATFORM_VERSION: "aarch64"
        EXTRA_OPTIMIZATIONS: "true"
        BUILD_GENERIC: true
        PKG_PLATFORM_LIBC: "musl"
        METAPKG_GIT_CACHE: disabled
        GEL_PKG_REF: ${{ inputs.gelpkg_ref }}
        METAPKG_REF: ${{ inputs.metapkg_ref }}
    - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
      with:
        name: builds-linuxmusl-aarch64
        path: artifacts/linuxmusl-aarch64

  build-macos-x86_64:
    runs-on: ['macos-13']
    needs: prep
    if: needs.prep.outputs.if_macos_x86_64 == 'true'
    steps:
    - name: Update Homebrew before installing Rust toolchain
      run: |
        # Homebrew renamed `rustup-init` to `rustup`:
        # https://github.com/Homebrew/homebrew-core/pull/177840
        # But the GitHub Action runner is not updated with this change yet.
        # This caused the later `brew update` in step `Build` to relink Rust
        # toolchain executables, overwriting the custom toolchain installed by
        # `dsherret/rust-toolchain-file`. So let's just run `brew update` early.
        brew update
    - uses: actions/checkout@v4
      if: true
      with:
        sparse-checkout: |
          rust-toolchain.toml
        sparse-checkout-cone-mode: false
    - name: Install Rust toolchain
      uses: dsherret/rust-toolchain-file@v1
      if: true
    - uses: actions/checkout@v4
      with:
        repository: edgedb/edgedb-pkg
        ref: master
        path: edgedb-pkg
    - name: Set up Python
      uses: actions/setup-python@v5
      if: true
      with:
        python-version: "3.12"
    - name: Set up NodeJS
      uses: actions/setup-node@v4
      if: true
      with:
        node-version: '20'
    - name: Install dependencies
      if: true
      run: |
        env HOMEBREW_NO_AUTO_UPDATE=1 brew install libmagic
    - name: Install an alias
      # This is probably not strictly needed, but sentencepiece build script reports
      # errors without it.
      if: true
      run: |
        printf '#!/bin/sh\n\nexec sysctl -n hw.logicalcpu' > /usr/local/bin/nproc
        chmod +x /usr/local/bin/nproc
    - name: Build
      env:
        PACKAGE: "edgedbpkg.edgedb:Gel"
        SRC_REF: "${{ github.sha }}"
        PKG_REVISION: ""
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "macos"
        PKG_PLATFORM_VERSION: "x86_64"
        PKG_PLATFORM_ARCH: "x86_64"
        EXTRA_OPTIMIZATIONS: "true"
        METAPKG_GIT_CACHE: disabled
        BUILD_GENERIC: true
        CMAKE_POLICY_VERSION_MINIMUM: '3.5'
        GEL_PKG_REF: ${{ inputs.gelpkg_ref }}
        METAPKG_REF: ${{ inputs.metapkg_ref }}
      run: |
        edgedb-pkg/integration/macos/build.sh
    - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
      with:
        name: builds-macos-x86_64
        path: artifacts/macos-x86_64

  build-macos-aarch64:
    runs-on: ['macos-14']
    needs: prep
    if: needs.prep.outputs.if_macos_aarch64 == 'true'
    steps:
    - name: Update Homebrew before installing Rust toolchain
      run: |
        # Homebrew renamed `rustup-init` to `rustup`:
        # https://github.com/Homebrew/homebrew-core/pull/177840
        # But the GitHub Action runner is not updated with this change yet.
        # This caused the later `brew update` in step `Build` to relink Rust
        # toolchain executables, overwriting the custom toolchain installed by
        # `dsherret/rust-toolchain-file`. So let's just run `brew update` early.
        brew update
    - uses: actions/checkout@v4
      if: true
      with:
        sparse-checkout: |
          rust-toolchain.toml
        sparse-checkout-cone-mode: false
    - name: Install Rust toolchain
      uses: dsherret/rust-toolchain-file@v1
      if: true
    - uses: actions/checkout@v4
      with:
        repository: edgedb/edgedb-pkg
        ref: master
        path: edgedb-pkg
    - name: Set up Python
      uses: actions/setup-python@v5
      if: true
      with:
        python-version: "3.12"
    - name: Set up NodeJS
      uses: actions/setup-node@v4
      if: true
      with:
        node-version: '20'
    - name: Install dependencies
      if: true
      run: |
        env HOMEBREW_NO_AUTO_UPDATE=1 brew install libmagic
    - name: Install an alias
      # This is probably not strictly needed, but sentencepiece build script reports
      # errors without it.
      if: true
      run: |
        printf '#!/bin/sh\n\nexec sysctl -n hw.logicalcpu' > /usr/local/bin/nproc
        chmod +x /usr/local/bin/nproc
    - name: Build
      env:
        PACKAGE: "edgedbpkg.edgedb:Gel"
        SRC_REF: "${{ github.sha }}"
        PKG_REVISION: ""
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "macos"
        PKG_PLATFORM_VERSION: "aarch64"
        PKG_PLATFORM_ARCH: "aarch64"
        EXTRA_OPTIMIZATIONS: "true"
        METAPKG_GIT_CACHE: disabled
        BUILD_GENERIC: true
        CMAKE_POLICY_VERSION_MINIMUM: '3.5'
        GEL_PKG_REF: ${{ inputs.gelpkg_ref }}
        METAPKG_REF: ${{ inputs.metapkg_ref }}
      run: |
        edgedb-pkg/integration/macos/build.sh
    - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
      with:
        name: builds-macos-aarch64
        path: artifacts/macos-aarch64

  test-debian-buster-x86_64:
    needs: [build-debian-buster-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-debian-buster-x86_64
        path: artifacts/debian-buster
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-debian-buster:latest
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "buster"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by Github.
        PKG_TEST_JOBS: 0

  test-debian-buster-aarch64:
    needs: [build-debian-buster-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-debian-buster-aarch64
        path: artifacts/debian-buster
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-debian-buster:latest
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "buster"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by Github.
        PKG_TEST_JOBS: 0

  test-debian-bullseye-x86_64:
    needs: [build-debian-bullseye-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-debian-bullseye-x86_64
        path: artifacts/debian-bullseye
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-debian-bullseye:latest
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "bullseye"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by Github.
        PKG_TEST_JOBS: 0

  test-debian-bullseye-aarch64:
    needs: [build-debian-bullseye-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-debian-bullseye-aarch64
        path: artifacts/debian-bullseye
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-debian-bullseye:latest
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "bullseye"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by Github.
        PKG_TEST_JOBS: 0

  test-debian-bookworm-x86_64:
    needs: [build-debian-bookworm-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-debian-bookworm-x86_64
        path: artifacts/debian-bookworm
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-debian-bookworm:latest
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "bookworm"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by Github.
        PKG_TEST_JOBS: 0

  test-debian-bookworm-aarch64:
    needs: [build-debian-bookworm-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-debian-bookworm-aarch64
        path: artifacts/debian-bookworm
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-debian-bookworm:latest
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "bookworm"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by Github.
        PKG_TEST_JOBS: 0

  test-ubuntu-focal-x86_64:
    needs: [build-ubuntu-focal-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-ubuntu-focal-x86_64
        path: artifacts/ubuntu-focal
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-ubuntu-focal:latest
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "ubuntu"
        PKG_PLATFORM_VERSION: "focal"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by Github.
        PKG_TEST_JOBS: 0

  test-ubuntu-focal-aarch64:
    needs: [build-ubuntu-focal-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-ubuntu-focal-aarch64
        path: artifacts/ubuntu-focal
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-ubuntu-focal:latest
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "ubuntu"
        PKG_PLATFORM_VERSION: "focal"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by Github.
        PKG_TEST_JOBS: 0

  test-ubuntu-jammy-x86_64:
    needs: [build-ubuntu-jammy-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-ubuntu-jammy-x86_64
        path: artifacts/ubuntu-jammy
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-ubuntu-jammy:latest
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "ubuntu"
        PKG_PLATFORM_VERSION: "jammy"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by Github.
        PKG_TEST_JOBS: 0

  test-ubuntu-jammy-aarch64:
    needs: [build-ubuntu-jammy-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-ubuntu-jammy-aarch64
        path: artifacts/ubuntu-jammy
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-ubuntu-jammy:latest
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "ubuntu"
        PKG_PLATFORM_VERSION: "jammy"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by Github.
        PKG_TEST_JOBS: 0

  test-ubuntu-noble-x86_64:
    needs: [build-ubuntu-noble-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-ubuntu-noble-x86_64
        path: artifacts/ubuntu-noble
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-ubuntu-noble:latest
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "ubuntu"
        PKG_PLATFORM_VERSION: "noble"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by Github.
        PKG_TEST_JOBS: 0

  test-ubuntu-noble-aarch64:
    needs: [build-ubuntu-noble-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-ubuntu-noble-aarch64
        path: artifacts/ubuntu-noble
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-ubuntu-noble:latest
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "ubuntu"
        PKG_PLATFORM_VERSION: "noble"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by Github.
        PKG_TEST_JOBS: 0

  test-centos-8-x86_64:
    needs: [build-centos-8-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-centos-8-x86_64
        path: artifacts/centos-8
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-centos-8:latest
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "centos"
        PKG_PLATFORM_VERSION: "8"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by Github.
        PKG_TEST_JOBS: 0

  test-centos-8-aarch64:
    needs: [build-centos-8-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-centos-8-aarch64
        path: artifacts/centos-8
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-centos-8:latest
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "centos"
        PKG_PLATFORM_VERSION: "8"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by Github.
        PKG_TEST_JOBS: 0

  test-rockylinux-9-x86_64:
    needs: [build-rockylinux-9-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-rockylinux-9-x86_64
        path: artifacts/rockylinux-9
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-rockylinux-9:latest
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "rockylinux"
        PKG_PLATFORM_VERSION: "9"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by Github.
        PKG_TEST_JOBS: 0

  test-rockylinux-9-aarch64:
    needs: [build-rockylinux-9-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-rockylinux-9-aarch64
        path: artifacts/rockylinux-9
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-rockylinux-9:latest
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "rockylinux"
        PKG_PLATFORM_VERSION: "9"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by Github.
        PKG_TEST_JOBS: 0

  test-linux-x86_64:
    needs: [build-linux-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-linux-x86_64
        path: artifacts/linux-x86_64
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-linux-x86_64:latest
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "linux"
        PKG_PLATFORM_VERSION: "x86_64"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by Github.
        PKG_TEST_JOBS: 0

  test-linux-aarch64:
    needs: [build-linux-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-linux-aarch64
        path: artifacts/linux-aarch64
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-linux-aarch64:latest
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "linux"
        PKG_PLATFORM_VERSION: "aarch64"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by Github.
        PKG_TEST_JOBS: 0

  test-linuxmusl-x86_64:
    needs: [build-linuxmusl-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-linuxmusl-x86_64
        path: artifacts/linuxmusl-x86_64
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-linuxmusl-x86_64:latest
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "linux"
        PKG_PLATFORM_VERSION: "x86_64"
        PKG_PLATFORM_LIBC: "musl"
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by Github.
        PKG_TEST_JOBS: 0

  test-linuxmusl-aarch64:
    needs: [build-linuxmusl-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-linuxmusl-aarch64
        path: artifacts/linuxmusl-aarch64
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-linuxmusl-aarch64:latest
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "linux"
        PKG_PLATFORM_VERSION: "aarch64"
        PKG_PLATFORM_LIBC: "musl"
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by Github.
PKG_TEST_JOBS: 0 test-macos-x86_64: needs: [build-macos-x86_64] runs-on: ['macos-13'] steps: - uses: actions/checkout@v4 with: repository: edgedb/edgedb-pkg ref: master path: edgedb-pkg - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-macos-x86_64 path: artifacts/macos-x86_64 - name: Test env: PKG_SUBDIST: "nightly" PKG_PLATFORM: "macos" PKG_PLATFORM_VERSION: "x86_64" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " test_dump*.py test_backend_*.py test_database.py test_server_*.py test_edgeql_ddl.py test_session.py " run: | # Bump shmmax and shmall to avoid test failures. sudo sysctl -w kern.sysv.shmmax=12582912 sudo sysctl -w kern.sysv.shmall=12582912 edgedb-pkg/integration/macos/test.sh test-macos-aarch64: needs: [build-macos-aarch64] runs-on: ['macos-14'] steps: - uses: actions/checkout@v4 with: repository: edgedb/edgedb-pkg ref: master path: edgedb-pkg - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-macos-aarch64 path: artifacts/macos-aarch64 - name: Test env: PKG_SUBDIST: "nightly" PKG_PLATFORM: "macos" PKG_PLATFORM_VERSION: "aarch64" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " run: | edgedb-pkg/integration/macos/test.sh workflow-notifications: if: failure() && github.event_name != 'pull_request' name: Notify in Slack on failures needs: - prep - build-debian-buster-x86_64 - test-debian-buster-x86_64 - build-debian-buster-aarch64 - test-debian-buster-aarch64 - build-debian-bullseye-x86_64 - test-debian-bullseye-x86_64 - build-debian-bullseye-aarch64 - test-debian-bullseye-aarch64 - build-debian-bookworm-x86_64 - test-debian-bookworm-x86_64 - build-debian-bookworm-aarch64 - test-debian-bookworm-aarch64 - build-ubuntu-focal-x86_64 - test-ubuntu-focal-x86_64 - build-ubuntu-focal-aarch64 - test-ubuntu-focal-aarch64 - build-ubuntu-jammy-x86_64 - test-ubuntu-jammy-x86_64 - build-ubuntu-jammy-aarch64 - test-ubuntu-jammy-aarch64 - 
build-ubuntu-noble-x86_64 - test-ubuntu-noble-x86_64 - build-ubuntu-noble-aarch64 - test-ubuntu-noble-aarch64 - build-centos-8-x86_64 - test-centos-8-x86_64 - build-centos-8-aarch64 - test-centos-8-aarch64 - build-rockylinux-9-x86_64 - test-rockylinux-9-x86_64 - build-rockylinux-9-aarch64 - test-rockylinux-9-aarch64 - build-linux-x86_64 - test-linux-x86_64 - build-linux-aarch64 - test-linux-aarch64 - build-linuxmusl-x86_64 - test-linuxmusl-x86_64 - build-linuxmusl-aarch64 - test-linuxmusl-aarch64 - build-macos-x86_64 - test-macos-x86_64 - build-macos-aarch64 - test-macos-aarch64 runs-on: ubuntu-latest permissions: actions: 'read' steps: - name: Slack Workflow Notification uses: Gamesight/slack-workflow-status@26a36836c887f260477432e4314ec3490a84f309 with: repo_token: ${{secrets.GITHUB_TOKEN}} slack_webhook_url: ${{secrets.ACTIONS_SLACK_WEBHOOK_URL}} name: 'Workflow notifications' icon_emoji: ':hammer:' include_jobs: 'on-failure' ================================================ FILE: .github/workflows/build.ls-nightly.yml ================================================ name: 'ls: Build and Publish Nightly Packages' on: schedule: - cron: "0 1 * * *" workflow_dispatch: inputs: gelpkg_ref: description: "gel-pkg git ref used to build the packages" default: "master" metapkg_ref: description: "metapkg git ref used to build the packages" default: "master" push: branches: - nightly jobs: prep: runs-on: ubuntu-latest outputs: if_linux_x86_64: ${{ steps.scm.outputs.if_linux_x86_64 }} if_linux_aarch64: ${{ steps.scm.outputs.if_linux_aarch64 }} if_linuxmusl_x86_64: ${{ steps.scm.outputs.if_linuxmusl_x86_64 }} if_linuxmusl_aarch64: ${{ steps.scm.outputs.if_linuxmusl_aarch64 }} if_macos_x86_64: ${{ steps.scm.outputs.if_macos_x86_64 }} if_macos_aarch64: ${{ steps.scm.outputs.if_macos_aarch64 }} steps: - uses: actions/checkout@v4 - name: Determine SCM revision id: scm shell: bash run: | rev=$(git rev-parse HEAD) jq_filter='.packages[] | select(.basename == "gel-ls") | 
select(.architecture == $ARCH) | .version_details.metadata.scm_revision | . as $rev | select(($rev != null) and ($REV | startswith($rev)))' key="linux-x86_64" val=true idx_file=x86_64-unknown-linux-gnu.nightly.json url=https://packages.edgedb.com/archive/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! -e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="linux-aarch64" val=true idx_file=aarch64-unknown-linux-gnu.nightly.json url=https://packages.edgedb.com/archive/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! -e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="linuxmusl-x86_64" val=true idx_file=x86_64-unknown-linux-musl.nightly.json url=https://packages.edgedb.com/archive/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! -e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="linuxmusl-aarch64" val=true idx_file=aarch64-unknown-linux-musl.nightly.json url=https://packages.edgedb.com/archive/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! 
-e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="macos-x86_64" val=true idx_file=x86_64-unknown-linux-gnu.nightly.json url=https://packages.edgedb.com/archive/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! -e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="macos-aarch64" val=true idx_file=aarch64-unknown-linux-gnu.nightly.json url=https://packages.edgedb.com/archive/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! -e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT build-linux-x86_64: runs-on: ['self-hosted', 'linux', 'x64'] needs: prep if: needs.prep.outputs.if_linux_x86_64 == 'true' steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-linux-x86_64:latest env: PACKAGE: "edgedbpkg.edgedb_ls:EdgeDBLanguageServer" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "x86_64" EXTRA_OPTIMIZATIONS: "true" BUILD_GENERIC: true METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-linux-x86_64 path: artifacts/linux-x86_64 build-linux-aarch64: runs-on: ['self-hosted', 'linux', 'arm64'] needs: prep if: 
needs.prep.outputs.if_linux_aarch64 == 'true' steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-linux-aarch64:latest env: PACKAGE: "edgedbpkg.edgedb_ls:EdgeDBLanguageServer" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "aarch64" EXTRA_OPTIMIZATIONS: "true" BUILD_GENERIC: true METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-linux-aarch64 path: artifacts/linux-aarch64 build-linuxmusl-x86_64: runs-on: ['self-hosted', 'linux', 'x64'] needs: prep if: needs.prep.outputs.if_linuxmusl_x86_64 == 'true' steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-linuxmusl-x86_64:latest env: PACKAGE: "edgedbpkg.edgedb_ls:EdgeDBLanguageServer" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "x86_64" EXTRA_OPTIMIZATIONS: "true" BUILD_GENERIC: true PKG_PLATFORM_LIBC: "musl" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-linuxmusl-x86_64 path: artifacts/linuxmusl-x86_64 build-linuxmusl-aarch64: runs-on: ['self-hosted', 'linux', 'arm64'] needs: prep if: needs.prep.outputs.if_linuxmusl_aarch64 == 'true' steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-linuxmusl-aarch64:latest env: PACKAGE: "edgedbpkg.edgedb_ls:EdgeDBLanguageServer" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "aarch64" EXTRA_OPTIMIZATIONS: "true" BUILD_GENERIC: true PKG_PLATFORM_LIBC: "musl" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # 
v4.4.0 with: name: builds-linuxmusl-aarch64 path: artifacts/linuxmusl-aarch64 build-macos-x86_64: runs-on: ['macos-13'] needs: prep if: needs.prep.outputs.if_macos_x86_64 == 'true' steps: - name: Update Homebrew before installing Rust toolchain run: | # Homebrew renamed `rustup-init` to `rustup`: # https://github.com/Homebrew/homebrew-core/pull/177840 # But the GitHub Action runner is not updated with this change yet. # This caused the later `brew update` in step `Build` to relink Rust # toolchain executables, overwriting the custom toolchain installed by # `dsherret/rust-toolchain-file`. So let's just run `brew update` early. brew update - uses: actions/checkout@v4 if: true with: sparse-checkout: | rust-toolchain.toml sparse-checkout-cone-mode: false - name: Install Rust toolchain uses: dsherret/rust-toolchain-file@v1 if: true - uses: actions/checkout@v4 with: repository: edgedb/edgedb-pkg ref: master path: edgedb-pkg - name: Set up Python uses: actions/setup-python@v5 if: true with: python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 if: true with: node-version: '20' - name: Install dependencies if: true run: | env HOMEBREW_NO_AUTO_UPDATE=1 brew install libmagic - name: Install an alias # This is probably not strictly needed, but sentencepiece build script reports # errors without it. 
if: true run: | printf '#!/bin/sh\n\nexec sysctl -n hw.logicalcpu' > /usr/local/bin/nproc chmod +x /usr/local/bin/nproc - name: Build env: PACKAGE: "edgedbpkg.edgedb_ls:EdgeDBLanguageServer" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "macos" PKG_PLATFORM_VERSION: "x86_64" PKG_PLATFORM_ARCH: "x86_64" EXTRA_OPTIMIZATIONS: "true" METAPKG_GIT_CACHE: disabled BUILD_GENERIC: true CMAKE_POLICY_VERSION_MINIMUM: '3.5' GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} run: | edgedb-pkg/integration/macos/build.sh - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-macos-x86_64 path: artifacts/macos-x86_64 build-macos-aarch64: runs-on: ['macos-14'] needs: prep if: needs.prep.outputs.if_macos_aarch64 == 'true' steps: - name: Update Homebrew before installing Rust toolchain run: | # Homebrew renamed `rustup-init` to `rustup`: # https://github.com/Homebrew/homebrew-core/pull/177840 # But the GitHub Action runner is not updated with this change yet. # This caused the later `brew update` in step `Build` to relink Rust # toolchain executables, overwriting the custom toolchain installed by # `dsherret/rust-toolchain-file`. So let's just run `brew update` early. brew update - uses: actions/checkout@v4 if: true with: sparse-checkout: | rust-toolchain.toml sparse-checkout-cone-mode: false - name: Install Rust toolchain uses: dsherret/rust-toolchain-file@v1 if: true - uses: actions/checkout@v4 with: repository: edgedb/edgedb-pkg ref: master path: edgedb-pkg - name: Set up Python uses: actions/setup-python@v5 if: true with: python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 if: true with: node-version: '20' - name: Install dependencies if: true run: | env HOMEBREW_NO_AUTO_UPDATE=1 brew install libmagic - name: Install an alias # This is probably not strictly needed, but sentencepiece build script reports # errors without it. 
if: true run: | printf '#!/bin/sh\n\nexec sysctl -n hw.logicalcpu' > /usr/local/bin/nproc chmod +x /usr/local/bin/nproc - name: Build env: PACKAGE: "edgedbpkg.edgedb_ls:EdgeDBLanguageServer" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "macos" PKG_PLATFORM_VERSION: "aarch64" PKG_PLATFORM_ARCH: "aarch64" EXTRA_OPTIMIZATIONS: "true" METAPKG_GIT_CACHE: disabled BUILD_GENERIC: true CMAKE_POLICY_VERSION_MINIMUM: '3.5' GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} run: | edgedb-pkg/integration/macos/build.sh - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-macos-aarch64 path: artifacts/macos-aarch64 publish-linux-x86_64: needs: [build-linux-x86_64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linux-x86_64 path: artifacts/linux-x86_64 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "x86_64" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-linux-x86_64: needs: [publish-linux-x86_64] runs-on: ['self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linux-x86_64 path: artifacts/linux-x86_64 - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: linux-x86_64 - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-linux-x86_64:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "x86_64" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: 
"${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-linux-aarch64: needs: [build-linux-aarch64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linux-aarch64 path: artifacts/linux-aarch64 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "aarch64" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-linux-aarch64: needs: [publish-linux-aarch64] runs-on: ['self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linux-aarch64 path: artifacts/linux-aarch64 - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: linux-aarch64 - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-linux-aarch64:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "aarch64" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-linuxmusl-x86_64: needs: [build-linuxmusl-x86_64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linuxmusl-x86_64 path: artifacts/linuxmusl-x86_64 - 
name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "x86_64" PKG_PLATFORM_LIBC: "musl" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-linuxmusl-x86_64: needs: [publish-linuxmusl-x86_64] runs-on: ['self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linuxmusl-x86_64 path: artifacts/linuxmusl-x86_64 - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: linuxmusl-x86_64 - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-linuxmusl-x86_64:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "x86_64" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-linuxmusl-aarch64: needs: [build-linuxmusl-aarch64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linuxmusl-aarch64 path: artifacts/linuxmusl-aarch64 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "aarch64" PKG_PLATFORM_LIBC: "musl" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-linuxmusl-aarch64: needs: [publish-linuxmusl-aarch64] runs-on: ['self-hosted', 'linux', 'arm64'] steps: - uses: 
actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linuxmusl-aarch64 path: artifacts/linuxmusl-aarch64 - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: linuxmusl-aarch64 - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-linuxmusl-aarch64:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "aarch64" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-macos-x86_64: needs: [build-macos-x86_64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-macos-x86_64 path: artifacts/macos-x86_64 - uses: actions/checkout@v4 with: repository: edgedb/edgedb-pkg ref: master path: edgedb-pkg - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: macos-x86_64 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PKG_PLATFORM: "macos" PKG_PLATFORM_VERSION: "x86_64" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" publish-macos-aarch64: needs: [build-macos-aarch64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-macos-aarch64 path: artifacts/macos-aarch64 - uses: actions/checkout@v4 with: repository: edgedb/edgedb-pkg ref: master path: edgedb-pkg - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: macos-aarch64 - 
name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PKG_PLATFORM: "macos" PKG_PLATFORM_VERSION: "aarch64" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" workflow-notifications: if: failure() && github.event_name != 'pull_request' name: Notify in Slack on failures needs: - prep - build-linux-x86_64 - publish-linux-x86_64 - check-published-linux-x86_64 - build-linux-aarch64 - publish-linux-aarch64 - check-published-linux-aarch64 - build-linuxmusl-x86_64 - publish-linuxmusl-x86_64 - check-published-linuxmusl-x86_64 - build-linuxmusl-aarch64 - publish-linuxmusl-aarch64 - check-published-linuxmusl-aarch64 - build-macos-x86_64 - publish-macos-x86_64 - build-macos-aarch64 - publish-macos-aarch64 runs-on: ubuntu-latest permissions: actions: 'read' steps: - name: Slack Workflow Notification uses: Gamesight/slack-workflow-status@26a36836c887f260477432e4314ec3490a84f309 with: repo_token: ${{secrets.GITHUB_TOKEN}} slack_webhook_url: ${{secrets.ACTIONS_SLACK_WEBHOOK_URL}} name: 'Workflow notifications' icon_emoji: ':hammer:' include_jobs: 'on-failure' ================================================ FILE: .github/workflows/build.nightly.yml ================================================ name: Build Test and Publish Nightly Packages on: schedule: - cron: "0 1 * * *" workflow_dispatch: inputs: gelpkg_ref: description: "gel-pkg git ref used to build the packages" default: "master" metapkg_ref: description: "metapkg git ref used to build the packages" default: "master" push: branches: - nightly jobs: prep: runs-on: ubuntu-latest outputs: if_debian_buster_x86_64: ${{ steps.scm.outputs.if_debian_buster_x86_64 }} if_debian_buster_aarch64: ${{ steps.scm.outputs.if_debian_buster_aarch64 }} if_debian_bullseye_x86_64: ${{ steps.scm.outputs.if_debian_bullseye_x86_64 }} if_debian_bullseye_aarch64: ${{ steps.scm.outputs.if_debian_bullseye_aarch64 }} if_debian_bookworm_x86_64: ${{ 
steps.scm.outputs.if_debian_bookworm_x86_64 }} if_debian_bookworm_aarch64: ${{ steps.scm.outputs.if_debian_bookworm_aarch64 }} if_ubuntu_focal_x86_64: ${{ steps.scm.outputs.if_ubuntu_focal_x86_64 }} if_ubuntu_focal_aarch64: ${{ steps.scm.outputs.if_ubuntu_focal_aarch64 }} if_ubuntu_jammy_x86_64: ${{ steps.scm.outputs.if_ubuntu_jammy_x86_64 }} if_ubuntu_jammy_aarch64: ${{ steps.scm.outputs.if_ubuntu_jammy_aarch64 }} if_ubuntu_noble_x86_64: ${{ steps.scm.outputs.if_ubuntu_noble_x86_64 }} if_ubuntu_noble_aarch64: ${{ steps.scm.outputs.if_ubuntu_noble_aarch64 }} if_centos_8_x86_64: ${{ steps.scm.outputs.if_centos_8_x86_64 }} if_centos_8_aarch64: ${{ steps.scm.outputs.if_centos_8_aarch64 }} if_rockylinux_9_x86_64: ${{ steps.scm.outputs.if_rockylinux_9_x86_64 }} if_rockylinux_9_aarch64: ${{ steps.scm.outputs.if_rockylinux_9_aarch64 }} if_linux_x86_64: ${{ steps.scm.outputs.if_linux_x86_64 }} if_linux_aarch64: ${{ steps.scm.outputs.if_linux_aarch64 }} if_linuxmusl_x86_64: ${{ steps.scm.outputs.if_linuxmusl_x86_64 }} if_linuxmusl_aarch64: ${{ steps.scm.outputs.if_linuxmusl_aarch64 }} if_macos_x86_64: ${{ steps.scm.outputs.if_macos_x86_64 }} if_macos_aarch64: ${{ steps.scm.outputs.if_macos_aarch64 }} steps: - uses: actions/checkout@v4 - name: Determine SCM revision id: scm shell: bash run: | rev=$(git rev-parse HEAD) jq_filter='.packages[] | select(.basename == "gel-server") | select(.architecture == $ARCH) | .version_details.metadata.scm_revision | . as $rev | select(($rev != null) and ($REV | startswith($rev)))' key="debian-buster-x86_64" val=true idx_file=buster.nightly.json url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! 
-e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="debian-buster-aarch64" val=true idx_file=buster.nightly.json url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! -e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="debian-bullseye-x86_64" val=true idx_file=bullseye.nightly.json url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! -e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="debian-bullseye-aarch64" val=true idx_file=bullseye.nightly.json url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! -e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="debian-bookworm-x86_64" val=true idx_file=bookworm.nightly.json url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! 
-e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="debian-bookworm-aarch64" val=true idx_file=bookworm.nightly.json url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! -e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="ubuntu-focal-x86_64" val=true idx_file=focal.nightly.json url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! -e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="ubuntu-focal-aarch64" val=true idx_file=focal.nightly.json url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! -e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="ubuntu-jammy-x86_64" val=true idx_file=jammy.nightly.json url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! 
-e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="ubuntu-jammy-aarch64" val=true idx_file=jammy.nightly.json url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! -e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="ubuntu-noble-x86_64" val=true idx_file=noble.nightly.json url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! -e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="ubuntu-noble-aarch64" val=true idx_file=noble.nightly.json url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! -e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="centos-8-x86_64" val=true idx_file=el8.nightly.json url=https://packages.edgedb.com/rpm/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! 
-e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="centos-8-aarch64" val=true idx_file=el8.nightly.json url=https://packages.edgedb.com/rpm/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! -e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="rockylinux-9-x86_64" val=true idx_file=el9.nightly.json url=https://packages.edgedb.com/rpm/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! -e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="rockylinux-9-aarch64" val=true idx_file=el9.nightly.json url=https://packages.edgedb.com/rpm/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! -e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="linux-x86_64" val=true idx_file=x86_64-unknown-linux-gnu.nightly.json url=https://packages.edgedb.com/archive/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! 
-e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="linux-aarch64" val=true idx_file=aarch64-unknown-linux-gnu.nightly.json url=https://packages.edgedb.com/archive/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! -e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="linuxmusl-x86_64" val=true idx_file=x86_64-unknown-linux-musl.nightly.json url=https://packages.edgedb.com/archive/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! -e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="linuxmusl-aarch64" val=true idx_file=aarch64-unknown-linux-musl.nightly.json url=https://packages.edgedb.com/archive/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! -e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="macos-x86_64" val=true idx_file=x86_64-unknown-linux-gnu.nightly.json url=https://packages.edgedb.com/archive/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! 
-e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "x86_64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT key="macos-aarch64" val=true idx_file=aarch64-unknown-linux-gnu.nightly.json url=https://packages.edgedb.com/archive/.jsonindexes/$idx_file tmp_file="/tmp/$idx_file" if [ ! -e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "aarch64" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT build-debian-buster-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep if: needs.prep.outputs.if_debian_buster_x86_64 == 'true' steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-debian-buster:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "buster" EXTRA_OPTIMIZATIONS: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-debian-buster-x86_64 path: artifacts/debian-buster build-debian-buster-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep if: needs.prep.outputs.if_debian_buster_aarch64 == 'true' steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-debian-buster:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "buster" EXTRA_OPTIMIZATIONS: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: 
actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-debian-buster-aarch64 path: artifacts/debian-buster build-debian-bullseye-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep if: needs.prep.outputs.if_debian_bullseye_x86_64 == 'true' steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-debian-bullseye:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bullseye" EXTRA_OPTIMIZATIONS: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-debian-bullseye-x86_64 path: artifacts/debian-bullseye build-debian-bullseye-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep if: needs.prep.outputs.if_debian_bullseye_aarch64 == 'true' steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-debian-bullseye:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bullseye" EXTRA_OPTIMIZATIONS: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-debian-bullseye-aarch64 path: artifacts/debian-bullseye build-debian-bookworm-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep if: needs.prep.outputs.if_debian_bookworm_x86_64 == 'true' steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-debian-bookworm:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bookworm" EXTRA_OPTIMIZATIONS: "true" 
METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-debian-bookworm-x86_64 path: artifacts/debian-bookworm build-debian-bookworm-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep if: needs.prep.outputs.if_debian_bookworm_aarch64 == 'true' steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-debian-bookworm:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bookworm" EXTRA_OPTIMIZATIONS: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-debian-bookworm-aarch64 path: artifacts/debian-bookworm build-ubuntu-focal-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep if: needs.prep.outputs.if_ubuntu_focal_x86_64 == 'true' steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-ubuntu-focal:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "focal" EXTRA_OPTIMIZATIONS: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-ubuntu-focal-x86_64 path: artifacts/ubuntu-focal build-ubuntu-focal-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep if: needs.prep.outputs.if_ubuntu_focal_aarch64 == 'true' steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-ubuntu-focal:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" 
PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "focal" EXTRA_OPTIMIZATIONS: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-ubuntu-focal-aarch64 path: artifacts/ubuntu-focal build-ubuntu-jammy-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep if: needs.prep.outputs.if_ubuntu_jammy_x86_64 == 'true' steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-ubuntu-jammy:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "jammy" EXTRA_OPTIMIZATIONS: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-ubuntu-jammy-x86_64 path: artifacts/ubuntu-jammy build-ubuntu-jammy-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep if: needs.prep.outputs.if_ubuntu_jammy_aarch64 == 'true' steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-ubuntu-jammy:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "jammy" EXTRA_OPTIMIZATIONS: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-ubuntu-jammy-aarch64 path: artifacts/ubuntu-jammy build-ubuntu-noble-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep if: needs.prep.outputs.if_ubuntu_noble_x86_64 == 'true' steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-ubuntu-noble:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ 
github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "noble" EXTRA_OPTIMIZATIONS: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-ubuntu-noble-x86_64 path: artifacts/ubuntu-noble build-ubuntu-noble-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep if: needs.prep.outputs.if_ubuntu_noble_aarch64 == 'true' steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-ubuntu-noble:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "noble" EXTRA_OPTIMIZATIONS: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-ubuntu-noble-aarch64 path: artifacts/ubuntu-noble build-centos-8-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep if: needs.prep.outputs.if_centos_8_x86_64 == 'true' steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-centos-8:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "centos" PKG_PLATFORM_VERSION: "8" EXTRA_OPTIMIZATIONS: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-centos-8-x86_64 path: artifacts/centos-8 build-centos-8-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep if: needs.prep.outputs.if_centos_8_aarch64 == 'true' steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-centos-8:latest env: PACKAGE: 
"edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "centos" PKG_PLATFORM_VERSION: "8" EXTRA_OPTIMIZATIONS: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-centos-8-aarch64 path: artifacts/centos-8 build-rockylinux-9-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep if: needs.prep.outputs.if_rockylinux_9_x86_64 == 'true' steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-rockylinux-9:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "rockylinux" PKG_PLATFORM_VERSION: "9" EXTRA_OPTIMIZATIONS: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-rockylinux-9-x86_64 path: artifacts/rockylinux-9 build-rockylinux-9-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep if: needs.prep.outputs.if_rockylinux_9_aarch64 == 'true' steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-rockylinux-9:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "rockylinux" PKG_PLATFORM_VERSION: "9" EXTRA_OPTIMIZATIONS: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-rockylinux-9-aarch64 path: artifacts/rockylinux-9 build-linux-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep if: needs.prep.outputs.if_linux_x86_64 == 'true' steps: - name: Build uses: 
docker://ghcr.io/geldata/gelpkg-build-linux-x86_64:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "x86_64" EXTRA_OPTIMIZATIONS: "true" BUILD_GENERIC: true METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-linux-x86_64 path: artifacts/linux-x86_64 build-linux-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep if: needs.prep.outputs.if_linux_aarch64 == 'true' steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-linux-aarch64:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "aarch64" EXTRA_OPTIMIZATIONS: "true" BUILD_GENERIC: true METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-linux-aarch64 path: artifacts/linux-aarch64 build-linuxmusl-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep if: needs.prep.outputs.if_linuxmusl_x86_64 == 'true' steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-linuxmusl-x86_64:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "x86_64" EXTRA_OPTIMIZATIONS: "true" BUILD_GENERIC: true PKG_PLATFORM_LIBC: "musl" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-linuxmusl-x86_64 path: artifacts/linuxmusl-x86_64 build-linuxmusl-aarch64: runs-on: ['package-builder', 'self-hosted', 
'linux', 'arm64'] needs: prep if: needs.prep.outputs.if_linuxmusl_aarch64 == 'true' steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-linuxmusl-aarch64:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "aarch64" EXTRA_OPTIMIZATIONS: "true" BUILD_GENERIC: true PKG_PLATFORM_LIBC: "musl" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-linuxmusl-aarch64 path: artifacts/linuxmusl-aarch64 build-macos-x86_64: runs-on: ['macos-13'] needs: prep if: needs.prep.outputs.if_macos_x86_64 == 'true' steps: - name: Update Homebrew before installing Rust toolchain run: | # Homebrew renamed `rustup-init` to `rustup`: # https://github.com/Homebrew/homebrew-core/pull/177840 # But the GitHub Action runner is not updated with this change yet. # This caused the later `brew update` in step `Build` to relink Rust # toolchain executables, overwriting the custom toolchain installed by # `dsherret/rust-toolchain-file`. So let's just run `brew update` early. brew update - uses: actions/checkout@v4 if: true with: sparse-checkout: | rust-toolchain.toml sparse-checkout-cone-mode: false - name: Install Rust toolchain uses: dsherret/rust-toolchain-file@v1 if: true - uses: actions/checkout@v4 with: repository: edgedb/edgedb-pkg ref: master path: edgedb-pkg - name: Set up Python uses: actions/setup-python@v5 if: true with: python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 if: true with: node-version: '20' - name: Install dependencies if: true run: | env HOMEBREW_NO_AUTO_UPDATE=1 brew install libmagic - name: Install an alias # This is probably not strictly needed, but sentencepiece build script reports # errors without it. 
if: true run: | printf '#!/bin/sh\n\nexec sysctl -n hw.logicalcpu' > /usr/local/bin/nproc chmod +x /usr/local/bin/nproc - name: Build env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "macos" PKG_PLATFORM_VERSION: "x86_64" PKG_PLATFORM_ARCH: "x86_64" EXTRA_OPTIMIZATIONS: "true" METAPKG_GIT_CACHE: disabled BUILD_GENERIC: true CMAKE_POLICY_VERSION_MINIMUM: '3.5' GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} run: | edgedb-pkg/integration/macos/build.sh - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-macos-x86_64 path: artifacts/macos-x86_64 build-macos-aarch64: runs-on: ['macos-14'] needs: prep if: needs.prep.outputs.if_macos_aarch64 == 'true' steps: - name: Update Homebrew before installing Rust toolchain run: | # Homebrew renamed `rustup-init` to `rustup`: # https://github.com/Homebrew/homebrew-core/pull/177840 # But the GitHub Action runner is not updated with this change yet. # This caused the later `brew update` in step `Build` to relink Rust # toolchain executables, overwriting the custom toolchain installed by # `dsherret/rust-toolchain-file`. So let's just run `brew update` early. brew update - uses: actions/checkout@v4 if: true with: sparse-checkout: | rust-toolchain.toml sparse-checkout-cone-mode: false - name: Install Rust toolchain uses: dsherret/rust-toolchain-file@v1 if: true - uses: actions/checkout@v4 with: repository: edgedb/edgedb-pkg ref: master path: edgedb-pkg - name: Set up Python uses: actions/setup-python@v5 if: true with: python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 if: true with: node-version: '20' - name: Install dependencies if: true run: | env HOMEBREW_NO_AUTO_UPDATE=1 brew install libmagic - name: Install an alias # This is probably not strictly needed, but sentencepiece build script reports # errors without it. 
if: true run: | printf '#!/bin/sh\n\nexec sysctl -n hw.logicalcpu' > /usr/local/bin/nproc chmod +x /usr/local/bin/nproc - name: Build env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "nightly" PKG_PLATFORM: "macos" PKG_PLATFORM_VERSION: "aarch64" PKG_PLATFORM_ARCH: "aarch64" EXTRA_OPTIMIZATIONS: "true" METAPKG_GIT_CACHE: disabled BUILD_GENERIC: true CMAKE_POLICY_VERSION_MINIMUM: '3.5' GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} run: | edgedb-pkg/integration/macos/build.sh - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-macos-aarch64 path: artifacts/macos-aarch64 test-debian-buster-x86_64: needs: [build-debian-buster-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-buster-x86_64 path: artifacts/debian-buster - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-debian-buster:latest env: PKG_SUBDIST: "nightly" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "buster" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-debian-buster-aarch64: needs: [build-debian-buster-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-buster-aarch64 path: artifacts/debian-buster - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-debian-buster:latest env: PKG_SUBDIST: "nightly" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "buster" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. 
PKG_TEST_JOBS: 0 test-debian-bullseye-x86_64: needs: [build-debian-bullseye-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bullseye-x86_64 path: artifacts/debian-bullseye - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-debian-bullseye:latest env: PKG_SUBDIST: "nightly" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bullseye" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-debian-bullseye-aarch64: needs: [build-debian-bullseye-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bullseye-aarch64 path: artifacts/debian-bullseye - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-debian-bullseye:latest env: PKG_SUBDIST: "nightly" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bullseye" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-debian-bookworm-x86_64: needs: [build-debian-bookworm-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bookworm-x86_64 path: artifacts/debian-bookworm - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-debian-bookworm:latest env: PKG_SUBDIST: "nightly" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bookworm" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. 
PKG_TEST_JOBS: 0 test-debian-bookworm-aarch64: needs: [build-debian-bookworm-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bookworm-aarch64 path: artifacts/debian-bookworm - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-debian-bookworm:latest env: PKG_SUBDIST: "nightly" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bookworm" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-ubuntu-focal-x86_64: needs: [build-ubuntu-focal-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-focal-x86_64 path: artifacts/ubuntu-focal - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-ubuntu-focal:latest env: PKG_SUBDIST: "nightly" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "focal" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-ubuntu-focal-aarch64: needs: [build-ubuntu-focal-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-focal-aarch64 path: artifacts/ubuntu-focal - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-ubuntu-focal:latest env: PKG_SUBDIST: "nightly" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "focal" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. 
PKG_TEST_JOBS: 0 test-ubuntu-jammy-x86_64: needs: [build-ubuntu-jammy-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-jammy-x86_64 path: artifacts/ubuntu-jammy - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-ubuntu-jammy:latest env: PKG_SUBDIST: "nightly" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "jammy" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-ubuntu-jammy-aarch64: needs: [build-ubuntu-jammy-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-jammy-aarch64 path: artifacts/ubuntu-jammy - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-ubuntu-jammy:latest env: PKG_SUBDIST: "nightly" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "jammy" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-ubuntu-noble-x86_64: needs: [build-ubuntu-noble-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-noble-x86_64 path: artifacts/ubuntu-noble - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-ubuntu-noble:latest env: PKG_SUBDIST: "nightly" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "noble" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. 
PKG_TEST_JOBS: 0 test-ubuntu-noble-aarch64: needs: [build-ubuntu-noble-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-noble-aarch64 path: artifacts/ubuntu-noble - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-ubuntu-noble:latest env: PKG_SUBDIST: "nightly" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "noble" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-centos-8-x86_64: needs: [build-centos-8-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-centos-8-x86_64 path: artifacts/centos-8 - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-centos-8:latest env: PKG_SUBDIST: "nightly" PKG_PLATFORM: "centos" PKG_PLATFORM_VERSION: "8" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-centos-8-aarch64: needs: [build-centos-8-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-centos-8-aarch64 path: artifacts/centos-8 - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-centos-8:latest env: PKG_SUBDIST: "nightly" PKG_PLATFORM: "centos" PKG_PLATFORM_VERSION: "8" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. 
PKG_TEST_JOBS: 0 test-rockylinux-9-x86_64: needs: [build-rockylinux-9-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-rockylinux-9-x86_64 path: artifacts/rockylinux-9 - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-rockylinux-9:latest env: PKG_SUBDIST: "nightly" PKG_PLATFORM: "rockylinux" PKG_PLATFORM_VERSION: "9" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-rockylinux-9-aarch64: needs: [build-rockylinux-9-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-rockylinux-9-aarch64 path: artifacts/rockylinux-9 - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-rockylinux-9:latest env: PKG_SUBDIST: "nightly" PKG_PLATFORM: "rockylinux" PKG_PLATFORM_VERSION: "9" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-linux-x86_64: needs: [build-linux-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linux-x86_64 path: artifacts/linux-x86_64 - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "x86_64" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. 
        PKG_TEST_JOBS: 0

  test-linux-aarch64:
    needs: [build-linux-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']

    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-linux-aarch64
        path: artifacts/linux-aarch64

    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-linux-aarch64:latest
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "linux"
        PKG_PLATFORM_VERSION: "aarch64"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by GitHub.
        PKG_TEST_JOBS: 0

  test-linuxmusl-x86_64:
    needs: [build-linuxmusl-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']

    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-linuxmusl-x86_64
        path: artifacts/linuxmusl-x86_64

    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-linuxmusl-x86_64:latest
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "linux"
        PKG_PLATFORM_VERSION: "x86_64"
        PKG_PLATFORM_LIBC: "musl"
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by GitHub.
        PKG_TEST_JOBS: 0

  test-linuxmusl-aarch64:
    needs: [build-linuxmusl-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']

    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-linuxmusl-aarch64
        path: artifacts/linuxmusl-aarch64

    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-linuxmusl-aarch64:latest
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "linux"
        PKG_PLATFORM_VERSION: "aarch64"
        PKG_PLATFORM_LIBC: "musl"
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by GitHub.
        PKG_TEST_JOBS: 0

  test-macos-x86_64:
    needs: [build-macos-x86_64]
    runs-on: ['macos-13']

    steps:
    - uses: actions/checkout@v4
      with:
        repository: edgedb/edgedb-pkg
        ref: master
        path: edgedb-pkg

    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-macos-x86_64
        path: artifacts/macos-x86_64

    - name: Test
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "macos"
        PKG_PLATFORM_VERSION: "x86_64"
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " test_dump*.py test_backend_*.py test_database.py test_server_*.py test_edgeql_ddl.py test_session.py "
      run: |
        # Bump shmmax and shmall to avoid test failures.
        sudo sysctl -w kern.sysv.shmmax=12582912
        sudo sysctl -w kern.sysv.shmall=12582912
        edgedb-pkg/integration/macos/test.sh

  test-macos-aarch64:
    needs: [build-macos-aarch64]
    runs-on: ['macos-14']

    steps:
    - uses: actions/checkout@v4
      with:
        repository: edgedb/edgedb-pkg
        ref: master
        path: edgedb-pkg

    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-macos-aarch64
        path: artifacts/macos-aarch64

    - name: Test
      env:
        PKG_SUBDIST: "nightly"
        PKG_PLATFORM: "macos"
        PKG_PLATFORM_VERSION: "aarch64"
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
      run: |
        edgedb-pkg/integration/macos/test.sh

  publish-debian-buster-x86_64:
    needs: [test-debian-buster-x86_64]
    runs-on: ubuntu-latest

    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-debian-buster-x86_64
        path: artifacts/debian-buster

    - name: Publish
      uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest
      env:
        PKG_SUBDIST: "nightly"
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "buster"
        PKG_PLATFORM_LIBC: ""
        PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}"

  check-published-debian-buster-x86_64:
    needs: [publish-debian-buster-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']

    steps:
    - uses:
actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-buster-x86_64 path: artifacts/debian-buster - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: debian-buster - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-debian-buster:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "buster" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-debian-buster-aarch64: needs: [test-debian-buster-aarch64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-buster-aarch64 path: artifacts/debian-buster - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "buster" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-debian-buster-aarch64: needs: [publish-debian-buster-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-buster-aarch64 path: artifacts/debian-buster - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: debian-buster - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-debian-buster:latest env: PKG_NAME: "${{ 
steps.describe.outputs.name }}" PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "buster" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-debian-bullseye-x86_64: needs: [test-debian-bullseye-x86_64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bullseye-x86_64 path: artifacts/debian-bullseye - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bullseye" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-debian-bullseye-x86_64: needs: [publish-debian-bullseye-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bullseye-x86_64 path: artifacts/debian-bullseye - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: debian-bullseye - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-debian-bullseye:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bullseye" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ 
steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-debian-bullseye-aarch64: needs: [test-debian-bullseye-aarch64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bullseye-aarch64 path: artifacts/debian-bullseye - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bullseye" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-debian-bullseye-aarch64: needs: [publish-debian-bullseye-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bullseye-aarch64 path: artifacts/debian-bullseye - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: debian-bullseye - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-debian-bullseye:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bullseye" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-debian-bookworm-x86_64: needs: [test-debian-bookworm-x86_64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bookworm-x86_64 path: artifacts/debian-bookworm - name: Publish uses: 
docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bookworm" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-debian-bookworm-x86_64: needs: [publish-debian-bookworm-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bookworm-x86_64 path: artifacts/debian-bookworm - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: debian-bookworm - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-debian-bookworm:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bookworm" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-debian-bookworm-aarch64: needs: [test-debian-bookworm-aarch64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bookworm-aarch64 path: artifacts/debian-bookworm - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bookworm" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-debian-bookworm-aarch64: needs: [publish-debian-bookworm-aarch64] runs-on: 
['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bookworm-aarch64 path: artifacts/debian-bookworm - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: debian-bookworm - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-debian-bookworm:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bookworm" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-ubuntu-focal-x86_64: needs: [test-ubuntu-focal-x86_64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-focal-x86_64 path: artifacts/ubuntu-focal - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "focal" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-ubuntu-focal-x86_64: needs: [publish-ubuntu-focal-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-focal-x86_64 path: artifacts/ubuntu-focal - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: ubuntu-focal - name: Test Published uses: 
docker://ghcr.io/geldata/gelpkg-testpublished-ubuntu-focal:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "focal" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-ubuntu-focal-aarch64: needs: [test-ubuntu-focal-aarch64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-focal-aarch64 path: artifacts/ubuntu-focal - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "focal" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-ubuntu-focal-aarch64: needs: [publish-ubuntu-focal-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-focal-aarch64 path: artifacts/ubuntu-focal - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: ubuntu-focal - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-ubuntu-focal:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "focal" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ 
steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-ubuntu-jammy-x86_64: needs: [test-ubuntu-jammy-x86_64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-jammy-x86_64 path: artifacts/ubuntu-jammy - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "jammy" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-ubuntu-jammy-x86_64: needs: [publish-ubuntu-jammy-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-jammy-x86_64 path: artifacts/ubuntu-jammy - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: ubuntu-jammy - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-ubuntu-jammy:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "jammy" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-ubuntu-jammy-aarch64: needs: [test-ubuntu-jammy-aarch64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-jammy-aarch64 path: artifacts/ubuntu-jammy - name: Publish uses: 
docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "jammy" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-ubuntu-jammy-aarch64: needs: [publish-ubuntu-jammy-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-jammy-aarch64 path: artifacts/ubuntu-jammy - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: ubuntu-jammy - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-ubuntu-jammy:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "jammy" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-ubuntu-noble-x86_64: needs: [test-ubuntu-noble-x86_64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-noble-x86_64 path: artifacts/ubuntu-noble - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "noble" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-ubuntu-noble-x86_64: needs: [publish-ubuntu-noble-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] 
steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-noble-x86_64 path: artifacts/ubuntu-noble - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: ubuntu-noble - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-ubuntu-noble:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "noble" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-ubuntu-noble-aarch64: needs: [test-ubuntu-noble-aarch64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-noble-aarch64 path: artifacts/ubuntu-noble - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "noble" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-ubuntu-noble-aarch64: needs: [publish-ubuntu-noble-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-noble-aarch64 path: artifacts/ubuntu-noble - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: ubuntu-noble - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-ubuntu-noble:latest env: PKG_NAME: "${{ 
steps.describe.outputs.name }}" PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "noble" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-centos-8-x86_64: needs: [test-centos-8-x86_64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-centos-8-x86_64 path: artifacts/centos-8 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "centos" PKG_PLATFORM_VERSION: "8" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-centos-8-x86_64: needs: [publish-centos-8-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-centos-8-x86_64 path: artifacts/centos-8 - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: centos-8 - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-centos-8:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "centos" PKG_PLATFORM_VERSION: "8" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version 
}} publish-centos-8-aarch64: needs: [test-centos-8-aarch64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-centos-8-aarch64 path: artifacts/centos-8 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "centos" PKG_PLATFORM_VERSION: "8" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-centos-8-aarch64: needs: [publish-centos-8-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-centos-8-aarch64 path: artifacts/centos-8 - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: centos-8 - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-centos-8:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "centos" PKG_PLATFORM_VERSION: "8" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-rockylinux-9-x86_64: needs: [test-rockylinux-9-x86_64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-rockylinux-9-x86_64 path: artifacts/rockylinux-9 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "rockylinux" PKG_PLATFORM_VERSION: 
"9" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-rockylinux-9-x86_64: needs: [publish-rockylinux-9-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-rockylinux-9-x86_64 path: artifacts/rockylinux-9 - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: rockylinux-9 - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-rockylinux-9:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "rockylinux" PKG_PLATFORM_VERSION: "9" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-rockylinux-9-aarch64: needs: [test-rockylinux-9-aarch64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-rockylinux-9-aarch64 path: artifacts/rockylinux-9 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "rockylinux" PKG_PLATFORM_VERSION: "9" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-rockylinux-9-aarch64: needs: [publish-rockylinux-9-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-rockylinux-9-aarch64 path: artifacts/rockylinux-9 - name: Describe id: describe uses: 
edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: rockylinux-9 - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-rockylinux-9:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "rockylinux" PKG_PLATFORM_VERSION: "9" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-linux-x86_64: needs: [test-linux-x86_64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linux-x86_64 path: artifacts/linux-x86_64 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "x86_64" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-linux-x86_64: needs: [publish-linux-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linux-x86_64 path: artifacts/linux-x86_64 - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: linux-x86_64 - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-linux-x86_64:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "x86_64" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ 
steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-linux-aarch64: needs: [test-linux-aarch64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linux-aarch64 path: artifacts/linux-aarch64 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "aarch64" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-linux-aarch64: needs: [publish-linux-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linux-aarch64 path: artifacts/linux-aarch64 - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: linux-aarch64 - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-linux-aarch64:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "aarch64" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-linuxmusl-x86_64: needs: [test-linuxmusl-x86_64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linuxmusl-x86_64 path: 
artifacts/linuxmusl-x86_64 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "x86_64" PKG_PLATFORM_LIBC: "musl" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-linuxmusl-x86_64: needs: [publish-linuxmusl-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linuxmusl-x86_64 path: artifacts/linuxmusl-x86_64 - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: linuxmusl-x86_64 - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-linuxmusl-x86_64:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "x86_64" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-linuxmusl-aarch64: needs: [test-linuxmusl-aarch64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linuxmusl-aarch64 path: artifacts/linuxmusl-aarch64 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "aarch64" PKG_PLATFORM_LIBC: "musl" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-linuxmusl-aarch64: needs: [publish-linuxmusl-aarch64] runs-on: 
['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linuxmusl-aarch64 path: artifacts/linuxmusl-aarch64 - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: linuxmusl-aarch64 - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-linuxmusl-aarch64:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "nightly" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "aarch64" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-macos-x86_64: needs: [test-macos-x86_64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-macos-x86_64 path: artifacts/macos-x86_64 - uses: actions/checkout@v4 with: repository: edgedb/edgedb-pkg ref: master path: edgedb-pkg - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: macos-x86_64 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PKG_PLATFORM: "macos" PKG_PLATFORM_VERSION: "x86_64" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" publish-macos-aarch64: needs: [test-macos-aarch64] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-macos-aarch64 path: artifacts/macos-aarch64 - uses: actions/checkout@v4 with: repository: edgedb/edgedb-pkg ref: master path: edgedb-pkg - name: Describe id: describe uses: 
edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: macos-aarch64 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "nightly" PKG_PLATFORM: "macos" PKG_PLATFORM_VERSION: "aarch64" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" publish-docker: needs: - check-published-debian-bookworm-x86_64 - check-published-debian-bookworm-aarch64 runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 with: repository: geldata/gel-docker ref: master path: dockerfile - name: Login to Docker Hub uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} - name: Login to GitHub Container Registry uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0 with: registry: ghcr.io username: "edgedb-ci" password: ${{ secrets.GITHUB_CI_BOT_TOKEN }} - env: VERSION_SLOT: "${{ needs.check-published-debian-bookworm-x86_64.outputs.version-slot }}" VERSION_CORE: "${{ needs.check-published-debian-bookworm-x86_64.outputs.version-core }}" CATALOG_VERSION: "${{ needs.check-published-debian-bookworm-x86_64.outputs.catalog-version }}" PKG_SUBDIST: "nightly" id: tags run: | set -e url='https://registry.hub.docker.com/v2/repositories/geldata/gel/tags?page_size=100' repo_tags=$( while [ -n "$url" ]; do resp=$(curl -L -s "$url") echo "$resp" | jq -r '."results"[]["name"]' url=$(echo "$resp" | jq -r ".next") if [ "$url" = "null" ] || [ -z "$url" ]; then break fi done | grep "^[[:digit:]]\+.*" | grep -v "alpha\|beta\|rc" || : ) tags=() if [ "$PKG_SUBDIST" = "nightly" ]; then tags+=( "nightly" "nightly_${VERSION_SLOT}_cv${CATALOG_VERSION}" ) else tags+=( "$VERSION_CORE" ) top=$(printf "%s\n%s\n" "$VERSION_CORE" "$repo_tags" \ | grep "^${VERSION_SLOT}[\.-]" \ | sort --version-sort --reverse | head -n 1) if [ "$top" == "$VERSION_CORE" ]; then tags+=( "$VERSION_SLOT" ) fi if [ -z "$PKG_SUBDIST" ];
then top=$(printf "%s\n%s\n" "$VERSION_CORE" "$repo_tags" \ | sort --version-sort --reverse | head -n 1) if [ "$top" == "$VERSION_CORE" ]; then tags+=( "latest" ) fi fi fi fq_tags=() images=("geldata/gel" "ghcr.io/geldata/gel") for image in "${images[@]}"; do fq_tags+=("${tags[@]/#/${image}:}") done IFS=, echo "tags=${fq_tags[*]}" >> $GITHUB_OUTPUT - name: Set up QEMU uses: docker/setup-qemu-action@29109295f81e9208d7d86ff1c6c12d2833863392 # v3.6.0 - name: Set up Docker Buildx uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0 - name: Build and Publish Docker Image uses: docker/build-push-action@471d1dc4e07e5cdedd4c2171150001c434f0b7a4 # v6.10.0 with: push: true provenance: mode=max tags: "${{ steps.tags.outputs.tags }}" context: dockerfile build-args: | version=${{ needs.check-published-debian-bookworm-x86_64.outputs.version-slot }} exact_version=${{ needs.check-published-debian-bookworm-x86_64.outputs.version-core }} subdist=nightly platforms: linux/amd64,linux/arm64 workflow-notifications: if: failure() && github.event_name != 'pull_request' name: Notify in Slack on failures needs: - prep - build-debian-buster-x86_64 - test-debian-buster-x86_64 - publish-debian-buster-x86_64 - check-published-debian-buster-x86_64 - build-debian-buster-aarch64 - test-debian-buster-aarch64 - publish-debian-buster-aarch64 - check-published-debian-buster-aarch64 - build-debian-bullseye-x86_64 - test-debian-bullseye-x86_64 - publish-debian-bullseye-x86_64 - check-published-debian-bullseye-x86_64 - build-debian-bullseye-aarch64 - test-debian-bullseye-aarch64 - publish-debian-bullseye-aarch64 - check-published-debian-bullseye-aarch64 - build-debian-bookworm-x86_64 - test-debian-bookworm-x86_64 - publish-debian-bookworm-x86_64 - check-published-debian-bookworm-x86_64 - build-debian-bookworm-aarch64 - test-debian-bookworm-aarch64 - publish-debian-bookworm-aarch64 - check-published-debian-bookworm-aarch64 - build-ubuntu-focal-x86_64 - 
test-ubuntu-focal-x86_64 - publish-ubuntu-focal-x86_64 - check-published-ubuntu-focal-x86_64 - build-ubuntu-focal-aarch64 - test-ubuntu-focal-aarch64 - publish-ubuntu-focal-aarch64 - check-published-ubuntu-focal-aarch64 - build-ubuntu-jammy-x86_64 - test-ubuntu-jammy-x86_64 - publish-ubuntu-jammy-x86_64 - check-published-ubuntu-jammy-x86_64 - build-ubuntu-jammy-aarch64 - test-ubuntu-jammy-aarch64 - publish-ubuntu-jammy-aarch64 - check-published-ubuntu-jammy-aarch64 - build-ubuntu-noble-x86_64 - test-ubuntu-noble-x86_64 - publish-ubuntu-noble-x86_64 - check-published-ubuntu-noble-x86_64 - build-ubuntu-noble-aarch64 - test-ubuntu-noble-aarch64 - publish-ubuntu-noble-aarch64 - check-published-ubuntu-noble-aarch64 - build-centos-8-x86_64 - test-centos-8-x86_64 - publish-centos-8-x86_64 - check-published-centos-8-x86_64 - build-centos-8-aarch64 - test-centos-8-aarch64 - publish-centos-8-aarch64 - check-published-centos-8-aarch64 - build-rockylinux-9-x86_64 - test-rockylinux-9-x86_64 - publish-rockylinux-9-x86_64 - check-published-rockylinux-9-x86_64 - build-rockylinux-9-aarch64 - test-rockylinux-9-aarch64 - publish-rockylinux-9-aarch64 - check-published-rockylinux-9-aarch64 - build-linux-x86_64 - test-linux-x86_64 - publish-linux-x86_64 - check-published-linux-x86_64 - build-linux-aarch64 - test-linux-aarch64 - publish-linux-aarch64 - check-published-linux-aarch64 - build-linuxmusl-x86_64 - test-linuxmusl-x86_64 - publish-linuxmusl-x86_64 - check-published-linuxmusl-x86_64 - build-linuxmusl-aarch64 - test-linuxmusl-aarch64 - publish-linuxmusl-aarch64 - check-published-linuxmusl-aarch64 - build-macos-x86_64 - test-macos-x86_64 - publish-macos-x86_64 - build-macos-aarch64 - test-macos-aarch64 - publish-macos-aarch64 - publish-docker runs-on: ubuntu-latest permissions: actions: 'read' steps: - name: Slack Workflow Notification uses: Gamesight/slack-workflow-status@26a36836c887f260477432e4314ec3490a84f309 with: repo_token: ${{secrets.GITHUB_TOKEN}} slack_webhook_url: 
${{secrets.ACTIONS_SLACK_WEBHOOK_URL}} name: 'Workflow notifications' icon_emoji: ':hammer:' include_jobs: 'on-failure' ================================================ FILE: .github/workflows/build.release.yml ================================================ name: Build Test and Publish a Release on: workflow_dispatch: inputs: gelpkg_ref: description: "gel-pkg git ref used to build the packages" default: "master" metapkg_ref: description: "metapkg git ref used to build the packages" default: "master" jobs: prep: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 build-debian-buster-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-debian-buster:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "buster" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-debian-buster-x86_64 path: artifacts/debian-buster build-debian-buster-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-debian-buster:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "buster" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-debian-buster-aarch64 path: artifacts/debian-buster build-debian-bullseye-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep steps: - name: Build uses: 
docker://ghcr.io/geldata/gelpkg-build-debian-bullseye:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bullseye" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-debian-bullseye-x86_64 path: artifacts/debian-bullseye build-debian-bullseye-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-debian-bullseye:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bullseye" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-debian-bullseye-aarch64 path: artifacts/debian-bullseye build-debian-bookworm-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-debian-bookworm:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bookworm" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-debian-bookworm-x86_64 path: artifacts/debian-bookworm build-debian-bookworm-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-debian-bookworm:latest 
env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bookworm" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-debian-bookworm-aarch64 path: artifacts/debian-bookworm build-ubuntu-focal-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-ubuntu-focal:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "focal" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-ubuntu-focal-x86_64 path: artifacts/ubuntu-focal build-ubuntu-focal-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-ubuntu-focal:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "focal" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-ubuntu-focal-aarch64 path: artifacts/ubuntu-focal build-ubuntu-jammy-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-ubuntu-jammy:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_PLATFORM: 
"ubuntu" PKG_PLATFORM_VERSION: "jammy" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-ubuntu-jammy-x86_64 path: artifacts/ubuntu-jammy build-ubuntu-jammy-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-ubuntu-jammy:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "jammy" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-ubuntu-jammy-aarch64 path: artifacts/ubuntu-jammy build-ubuntu-noble-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-ubuntu-noble:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "noble" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-ubuntu-noble-x86_64 path: artifacts/ubuntu-noble build-ubuntu-noble-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-ubuntu-noble:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "noble" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" 
METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-ubuntu-noble-aarch64 path: artifacts/ubuntu-noble build-centos-8-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-centos-8:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_PLATFORM: "centos" PKG_PLATFORM_VERSION: "8" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-centos-8-x86_64 path: artifacts/centos-8 build-centos-8-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-centos-8:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_PLATFORM: "centos" PKG_PLATFORM_VERSION: "8" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-centos-8-aarch64 path: artifacts/centos-8 build-rockylinux-9-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-rockylinux-9:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_PLATFORM: "rockylinux" PKG_PLATFORM_VERSION: "9" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: 
actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-rockylinux-9-x86_64 path: artifacts/rockylinux-9 build-rockylinux-9-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-rockylinux-9:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_PLATFORM: "rockylinux" PKG_PLATFORM_VERSION: "9" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-rockylinux-9-aarch64 path: artifacts/rockylinux-9 build-linux-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-linux-x86_64:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "x86_64" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" BUILD_GENERIC: true METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-linux-x86_64 path: artifacts/linux-x86_64 build-linux-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-linux-aarch64:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "aarch64" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" BUILD_GENERIC: true METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: 
builds-linux-aarch64 path: artifacts/linux-aarch64 build-linuxmusl-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-linuxmusl-x86_64:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "x86_64" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" BUILD_GENERIC: true PKG_PLATFORM_LIBC: "musl" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-linuxmusl-x86_64 path: artifacts/linuxmusl-x86_64 build-linuxmusl-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-linuxmusl-aarch64:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "aarch64" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" BUILD_GENERIC: true PKG_PLATFORM_LIBC: "musl" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-linuxmusl-aarch64 path: artifacts/linuxmusl-aarch64 build-macos-x86_64: runs-on: ['macos-13'] needs: prep steps: - name: Update Homebrew before installing Rust toolchain run: | # Homebrew renamed `rustup-init` to `rustup`: # https://github.com/Homebrew/homebrew-core/pull/177840 # But the GitHub Action runner is not updated with this change yet. # This caused the later `brew update` in step `Build` to relink Rust # toolchain executables, overwriting the custom toolchain installed by # `dsherret/rust-toolchain-file`. So let's just run `brew update` early. 
brew update - uses: actions/checkout@v4 if: true with: sparse-checkout: | rust-toolchain.toml sparse-checkout-cone-mode: false - name: Install Rust toolchain uses: dsherret/rust-toolchain-file@v1 if: true - uses: actions/checkout@v4 with: repository: edgedb/edgedb-pkg ref: master path: edgedb-pkg - name: Set up Python uses: actions/setup-python@v5 if: true with: python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 if: true with: node-version: '20' - name: Install dependencies if: true run: | env HOMEBREW_NO_AUTO_UPDATE=1 brew install libmagic - name: Install an alias # This is probably not strictly needed, but sentencepiece build script reports # errors without it. if: true run: | printf '#!/bin/sh\n\nexec sysctl -n hw.logicalcpu' > /usr/local/bin/nproc chmod +x /usr/local/bin/nproc - name: Build env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" BUILD_IS_RELEASE: "true" PKG_REVISION: "" PKG_PLATFORM: "macos" PKG_PLATFORM_VERSION: "x86_64" PKG_PLATFORM_ARCH: "x86_64" EXTRA_OPTIMIZATIONS: "true" METAPKG_GIT_CACHE: disabled BUILD_GENERIC: true CMAKE_POLICY_VERSION_MINIMUM: '3.5' GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} run: | edgedb-pkg/integration/macos/build.sh - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-macos-x86_64 path: artifacts/macos-x86_64 build-macos-aarch64: runs-on: ['macos-14'] needs: prep steps: - name: Update Homebrew before installing Rust toolchain run: | # Homebrew renamed `rustup-init` to `rustup`: # https://github.com/Homebrew/homebrew-core/pull/177840 # But the GitHub Action runner is not updated with this change yet. # This caused the later `brew update` in step `Build` to relink Rust # toolchain executables, overwriting the custom toolchain installed by # `dsherret/rust-toolchain-file`. So let's just run `brew update` early. 
brew update - uses: actions/checkout@v4 if: true with: sparse-checkout: | rust-toolchain.toml sparse-checkout-cone-mode: false - name: Install Rust toolchain uses: dsherret/rust-toolchain-file@v1 if: true - uses: actions/checkout@v4 with: repository: edgedb/edgedb-pkg ref: master path: edgedb-pkg - name: Set up Python uses: actions/setup-python@v5 if: true with: python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 if: true with: node-version: '20' - name: Install dependencies if: true run: | env HOMEBREW_NO_AUTO_UPDATE=1 brew install libmagic - name: Install an alias # This is probably not strictly needed, but sentencepiece build script reports # errors without it. if: true run: | printf '#!/bin/sh\n\nexec sysctl -n hw.logicalcpu' > /usr/local/bin/nproc chmod +x /usr/local/bin/nproc - name: Build env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" BUILD_IS_RELEASE: "true" PKG_REVISION: "" PKG_PLATFORM: "macos" PKG_PLATFORM_VERSION: "aarch64" PKG_PLATFORM_ARCH: "aarch64" EXTRA_OPTIMIZATIONS: "true" METAPKG_GIT_CACHE: disabled BUILD_GENERIC: true CMAKE_POLICY_VERSION_MINIMUM: '3.5' GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} run: | edgedb-pkg/integration/macos/build.sh - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-macos-aarch64 path: artifacts/macos-aarch64 test-debian-buster-x86_64: needs: [build-debian-buster-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-buster-x86_64 path: artifacts/debian-buster - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-debian-buster:latest env: PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "buster" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by 
Github. PKG_TEST_JOBS: 0 test-debian-buster-aarch64: needs: [build-debian-buster-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-buster-aarch64 path: artifacts/debian-buster - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-debian-buster:latest env: PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "buster" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-debian-bullseye-x86_64: needs: [build-debian-bullseye-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bullseye-x86_64 path: artifacts/debian-bullseye - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-debian-bullseye:latest env: PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bullseye" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-debian-bullseye-aarch64: needs: [build-debian-bullseye-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bullseye-aarch64 path: artifacts/debian-bullseye - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-debian-bullseye:latest env: PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bullseye" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. 
PKG_TEST_JOBS: 0 test-debian-bookworm-x86_64: needs: [build-debian-bookworm-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bookworm-x86_64 path: artifacts/debian-bookworm - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-debian-bookworm:latest env: PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bookworm" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-debian-bookworm-aarch64: needs: [build-debian-bookworm-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bookworm-aarch64 path: artifacts/debian-bookworm - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-debian-bookworm:latest env: PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bookworm" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-ubuntu-focal-x86_64: needs: [build-ubuntu-focal-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-focal-x86_64 path: artifacts/ubuntu-focal - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-ubuntu-focal:latest env: PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "focal" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. 
PKG_TEST_JOBS: 0 test-ubuntu-focal-aarch64: needs: [build-ubuntu-focal-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-focal-aarch64 path: artifacts/ubuntu-focal - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-ubuntu-focal:latest env: PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "focal" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-ubuntu-jammy-x86_64: needs: [build-ubuntu-jammy-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-jammy-x86_64 path: artifacts/ubuntu-jammy - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-ubuntu-jammy:latest env: PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "jammy" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-ubuntu-jammy-aarch64: needs: [build-ubuntu-jammy-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-jammy-aarch64 path: artifacts/ubuntu-jammy - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-ubuntu-jammy:latest env: PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "jammy" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. 
PKG_TEST_JOBS: 0 test-ubuntu-noble-x86_64: needs: [build-ubuntu-noble-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-noble-x86_64 path: artifacts/ubuntu-noble - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-ubuntu-noble:latest env: PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "noble" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-ubuntu-noble-aarch64: needs: [build-ubuntu-noble-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-noble-aarch64 path: artifacts/ubuntu-noble - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-ubuntu-noble:latest env: PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "noble" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-centos-8-x86_64: needs: [build-centos-8-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-centos-8-x86_64 path: artifacts/centos-8 - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-centos-8:latest env: PKG_PLATFORM: "centos" PKG_PLATFORM_VERSION: "8" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. 
        PKG_TEST_JOBS: 0

  test-centos-8-aarch64:
    needs: [build-centos-8-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-centos-8-aarch64
        path: artifacts/centos-8
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-centos-8:latest
      env:
        PKG_PLATFORM: "centos"
        PKG_PLATFORM_VERSION: "8"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by GitHub.
        PKG_TEST_JOBS: 0

  test-rockylinux-9-x86_64:
    needs: [build-rockylinux-9-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-rockylinux-9-x86_64
        path: artifacts/rockylinux-9
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-rockylinux-9:latest
      env:
        PKG_PLATFORM: "rockylinux"
        PKG_PLATFORM_VERSION: "9"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by GitHub.
        PKG_TEST_JOBS: 0

  test-rockylinux-9-aarch64:
    needs: [build-rockylinux-9-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-rockylinux-9-aarch64
        path: artifacts/rockylinux-9
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-rockylinux-9:latest
      env:
        PKG_PLATFORM: "rockylinux"
        PKG_PLATFORM_VERSION: "9"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by GitHub.
        PKG_TEST_JOBS: 0

  test-linux-x86_64:
    needs: [build-linux-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-linux-x86_64
        path: artifacts/linux-x86_64
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-linux-x86_64:latest
      env:
        PKG_PLATFORM: "linux"
        PKG_PLATFORM_VERSION: "x86_64"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by GitHub.
        PKG_TEST_JOBS: 0

  test-linux-aarch64:
    needs: [build-linux-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-linux-aarch64
        path: artifacts/linux-aarch64
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-linux-aarch64:latest
      env:
        PKG_PLATFORM: "linux"
        PKG_PLATFORM_VERSION: "aarch64"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by GitHub.
        PKG_TEST_JOBS: 0

  test-linuxmusl-x86_64:
    needs: [build-linuxmusl-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-linuxmusl-x86_64
        path: artifacts/linuxmusl-x86_64
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-linuxmusl-x86_64:latest
      env:
        PKG_PLATFORM: "linux"
        PKG_PLATFORM_VERSION: "x86_64"
        PKG_PLATFORM_LIBC: "musl"
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by GitHub.
        PKG_TEST_JOBS: 0

  test-linuxmusl-aarch64:
    needs: [build-linuxmusl-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-linuxmusl-aarch64
        path: artifacts/linuxmusl-aarch64
    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-linuxmusl-aarch64:latest
      env:
        PKG_PLATFORM: "linux"
        PKG_PLATFORM_VERSION: "aarch64"
        PKG_PLATFORM_LIBC: "musl"
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in workflow
        # jobs getting killed arbitrarily by GitHub.
        PKG_TEST_JOBS: 0

  test-macos-x86_64:
    needs: [build-macos-x86_64]
    runs-on: ['macos-13']
    steps:
    - uses: actions/checkout@v4
      with:
        repository: edgedb/edgedb-pkg
        ref: master
        path: edgedb-pkg
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-macos-x86_64
        path: artifacts/macos-x86_64
    - name: Test
      env:
        PKG_PLATFORM: "macos"
        PKG_PLATFORM_VERSION: "x86_64"
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " test_dump*.py test_backend_*.py test_database.py test_server_*.py test_edgeql_ddl.py test_session.py "
      run: |
        # Bump shmmax and shmall to avoid test failures.
        sudo sysctl -w kern.sysv.shmmax=12582912
        sudo sysctl -w kern.sysv.shmall=12582912
        edgedb-pkg/integration/macos/test.sh

  test-macos-aarch64:
    needs: [build-macos-aarch64]
    runs-on: ['macos-14']
    steps:
    - uses: actions/checkout@v4
      with:
        repository: edgedb/edgedb-pkg
        ref: master
        path: edgedb-pkg
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-macos-aarch64
        path: artifacts/macos-aarch64
    - name: Test
      env:
        PKG_PLATFORM: "macos"
        PKG_PLATFORM_VERSION: "aarch64"
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
      run: |
        edgedb-pkg/integration/macos/test.sh

  collect:
    needs:
    - test-debian-buster-x86_64
    - test-debian-buster-aarch64
    - test-debian-bullseye-x86_64
    - test-debian-bullseye-aarch64
    - test-debian-bookworm-x86_64
    - test-debian-bookworm-aarch64
    - test-ubuntu-focal-x86_64
    - test-ubuntu-focal-aarch64
    - test-ubuntu-jammy-x86_64
    - test-ubuntu-jammy-aarch64
    - test-ubuntu-noble-x86_64
    - test-ubuntu-noble-aarch64
    - test-centos-8-x86_64
    - test-centos-8-aarch64
    - test-rockylinux-9-x86_64
    - test-rockylinux-9-aarch64
    - test-linux-x86_64
    - test-linux-aarch64
    - test-linuxmusl-x86_64
    - test-linuxmusl-aarch64
    - test-macos-x86_64
    - test-macos-aarch64
    runs-on: ubuntu-latest
    steps:
    - run: echo 'All build+tests passed, ready to publish now!'

  publish-debian-buster-x86_64:
    needs: [collect]
    runs-on: ubuntu-latest
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-debian-buster-x86_64
        path: artifacts/debian-buster
    - name: Publish
      uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest
      env:
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "buster"
        PKG_PLATFORM_LIBC: ""
        PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}"

  check-published-debian-buster-x86_64:
    needs: [publish-debian-buster-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-debian-buster-x86_64
        path: artifacts/debian-buster
    - name: Describe
      id: describe
      uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master
      with:
        target: debian-buster
    - name: Test Published
      uses: docker://ghcr.io/geldata/gelpkg-testpublished-debian-buster:latest
      env:
        PKG_NAME: "${{ steps.describe.outputs.name }}"
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "buster"
        PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}"
        PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}"
    outputs:
      version-slot: ${{ steps.describe.outputs.version-slot }}
      version-core: ${{ steps.describe.outputs.version-core }}
      catalog-version: ${{ steps.describe.outputs.catalog-version }}

  publish-debian-buster-aarch64:
    needs: [collect]
    runs-on: ubuntu-latest
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-debian-buster-aarch64
        path: artifacts/debian-buster
    - name: Publish
      uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest
      env:
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "buster"
        PKG_PLATFORM_LIBC: ""
        PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}"

  check-published-debian-buster-aarch64:
    needs: [publish-debian-buster-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-debian-buster-aarch64
        path: artifacts/debian-buster
    - name: Describe
      id: describe
      uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master
      with:
        target: debian-buster
    - name: Test Published
      uses: docker://ghcr.io/geldata/gelpkg-testpublished-debian-buster:latest
      env:
        PKG_NAME: "${{ steps.describe.outputs.name }}"
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "buster"
        PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}"
        PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}"
    outputs:
      version-slot: ${{ steps.describe.outputs.version-slot }}
      version-core: ${{ steps.describe.outputs.version-core }}
      catalog-version: ${{ steps.describe.outputs.catalog-version }}

  publish-debian-bullseye-x86_64:
    needs: [collect]
    runs-on: ubuntu-latest
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-debian-bullseye-x86_64
        path: artifacts/debian-bullseye
    - name: Publish
      uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest
      env:
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "bullseye"
        PKG_PLATFORM_LIBC: ""
        PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}"

  check-published-debian-bullseye-x86_64:
    needs: [publish-debian-bullseye-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-debian-bullseye-x86_64
        path: artifacts/debian-bullseye
    - name: Describe
      id: describe
      uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master
      with:
        target: debian-bullseye
    - name: Test Published
      uses: docker://ghcr.io/geldata/gelpkg-testpublished-debian-bullseye:latest
      env:
        PKG_NAME: "${{ steps.describe.outputs.name }}"
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "bullseye"
        PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}"
        PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}"
    outputs:
      version-slot: ${{ steps.describe.outputs.version-slot }}
      version-core: ${{ steps.describe.outputs.version-core }}
      catalog-version: ${{ steps.describe.outputs.catalog-version }}

  publish-debian-bullseye-aarch64:
    needs: [collect]
    runs-on: ubuntu-latest
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-debian-bullseye-aarch64
        path: artifacts/debian-bullseye
    - name: Publish
      uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest
      env:
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "bullseye"
        PKG_PLATFORM_LIBC: ""
        PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}"

  check-published-debian-bullseye-aarch64:
    needs: [publish-debian-bullseye-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-debian-bullseye-aarch64
        path: artifacts/debian-bullseye
    - name: Describe
      id: describe
      uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master
      with:
        target: debian-bullseye
    - name: Test Published
      uses: docker://ghcr.io/geldata/gelpkg-testpublished-debian-bullseye:latest
      env:
        PKG_NAME: "${{ steps.describe.outputs.name }}"
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "bullseye"
        PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}"
        PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}"
    outputs:
      version-slot: ${{ steps.describe.outputs.version-slot }}
      version-core: ${{ steps.describe.outputs.version-core }}
      catalog-version: ${{ steps.describe.outputs.catalog-version }}

  publish-debian-bookworm-x86_64:
    needs: [collect]
    runs-on: ubuntu-latest
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-debian-bookworm-x86_64
        path: artifacts/debian-bookworm
    - name: Publish
      uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest
      env:
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "bookworm"
        PKG_PLATFORM_LIBC: ""
        PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}"

  check-published-debian-bookworm-x86_64:
    needs: [publish-debian-bookworm-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-debian-bookworm-x86_64
        path: artifacts/debian-bookworm
    - name: Describe
      id: describe
      uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master
      with:
        target: debian-bookworm
    - name: Test Published
      uses: docker://ghcr.io/geldata/gelpkg-testpublished-debian-bookworm:latest
      env:
        PKG_NAME: "${{ steps.describe.outputs.name }}"
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "bookworm"
        PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}"
        PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}"
    outputs:
      version-slot: ${{ steps.describe.outputs.version-slot }}
      version-core: ${{ steps.describe.outputs.version-core }}
      catalog-version: ${{ steps.describe.outputs.catalog-version }}

  publish-debian-bookworm-aarch64:
    needs: [collect]
    runs-on: ubuntu-latest
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-debian-bookworm-aarch64
        path: artifacts/debian-bookworm
    - name: Publish
      uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest
      env:
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "bookworm"
        PKG_PLATFORM_LIBC: ""
        PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}"

  check-published-debian-bookworm-aarch64:
    needs: [publish-debian-bookworm-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-debian-bookworm-aarch64
        path: artifacts/debian-bookworm
    - name: Describe
      id: describe
      uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master
      with:
        target: debian-bookworm
    - name: Test Published
      uses: docker://ghcr.io/geldata/gelpkg-testpublished-debian-bookworm:latest
      env:
        PKG_NAME: "${{ steps.describe.outputs.name }}"
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "bookworm"
        PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}"
        PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}"
    outputs:
      version-slot: ${{ steps.describe.outputs.version-slot }}
      version-core: ${{ steps.describe.outputs.version-core }}
      catalog-version: ${{ steps.describe.outputs.catalog-version }}

  publish-ubuntu-focal-x86_64:
    needs: [collect]
    runs-on: ubuntu-latest
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-ubuntu-focal-x86_64
        path: artifacts/ubuntu-focal
    - name: Publish
      uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest
      env:
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "ubuntu"
        PKG_PLATFORM_VERSION: "focal"
        PKG_PLATFORM_LIBC: ""
        PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}"

  check-published-ubuntu-focal-x86_64:
    needs: [publish-ubuntu-focal-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-ubuntu-focal-x86_64
        path: artifacts/ubuntu-focal
    - name: Describe
      id: describe
      uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master
      with:
        target: ubuntu-focal
    - name: Test Published
      uses: docker://ghcr.io/geldata/gelpkg-testpublished-ubuntu-focal:latest
      env:
        PKG_NAME: "${{ steps.describe.outputs.name }}"
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "ubuntu"
        PKG_PLATFORM_VERSION: "focal"
        PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}"
        PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}"
    outputs:
      version-slot: ${{ steps.describe.outputs.version-slot }}
      version-core: ${{ steps.describe.outputs.version-core }}
      catalog-version: ${{ steps.describe.outputs.catalog-version }}

  publish-ubuntu-focal-aarch64:
    needs: [collect]
    runs-on: ubuntu-latest
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-ubuntu-focal-aarch64
        path: artifacts/ubuntu-focal
    - name: Publish
      uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest
      env:
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "ubuntu"
        PKG_PLATFORM_VERSION: "focal"
        PKG_PLATFORM_LIBC: ""
        PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}"

  check-published-ubuntu-focal-aarch64:
    needs: [publish-ubuntu-focal-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-ubuntu-focal-aarch64
        path: artifacts/ubuntu-focal
    - name: Describe
      id: describe
      uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master
      with:
        target: ubuntu-focal
    - name: Test Published
      uses: docker://ghcr.io/geldata/gelpkg-testpublished-ubuntu-focal:latest
      env:
        PKG_NAME: "${{ steps.describe.outputs.name }}"
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "ubuntu"
        PKG_PLATFORM_VERSION: "focal"
        PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}"
        PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}"
    outputs:
      version-slot: ${{ steps.describe.outputs.version-slot }}
      version-core: ${{ steps.describe.outputs.version-core }}
      catalog-version: ${{ steps.describe.outputs.catalog-version }}

  publish-ubuntu-jammy-x86_64:
    needs: [collect]
    runs-on: ubuntu-latest
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-ubuntu-jammy-x86_64
        path: artifacts/ubuntu-jammy
    - name: Publish
      uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest
      env:
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "ubuntu"
        PKG_PLATFORM_VERSION: "jammy"
        PKG_PLATFORM_LIBC: ""
        PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}"

  check-published-ubuntu-jammy-x86_64:
    needs: [publish-ubuntu-jammy-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-ubuntu-jammy-x86_64
        path: artifacts/ubuntu-jammy
    - name: Describe
      id: describe
      uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master
      with:
        target: ubuntu-jammy
    - name: Test Published
      uses: docker://ghcr.io/geldata/gelpkg-testpublished-ubuntu-jammy:latest
      env:
        PKG_NAME: "${{ steps.describe.outputs.name }}"
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "ubuntu"
        PKG_PLATFORM_VERSION: "jammy"
        PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}"
        PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}"
    outputs:
      version-slot: ${{ steps.describe.outputs.version-slot }}
      version-core: ${{ steps.describe.outputs.version-core }}
      catalog-version: ${{ steps.describe.outputs.catalog-version }}

  publish-ubuntu-jammy-aarch64:
    needs: [collect]
    runs-on: ubuntu-latest
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-ubuntu-jammy-aarch64
        path: artifacts/ubuntu-jammy
    - name: Publish
      uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest
      env:
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "ubuntu"
        PKG_PLATFORM_VERSION: "jammy"
        PKG_PLATFORM_LIBC: ""
        PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}"

  check-published-ubuntu-jammy-aarch64:
    needs: [publish-ubuntu-jammy-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-ubuntu-jammy-aarch64
        path: artifacts/ubuntu-jammy
    - name: Describe
      id: describe
      uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master
      with:
        target: ubuntu-jammy
    - name: Test Published
      uses: docker://ghcr.io/geldata/gelpkg-testpublished-ubuntu-jammy:latest
      env:
        PKG_NAME: "${{ steps.describe.outputs.name }}"
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "ubuntu"
        PKG_PLATFORM_VERSION: "jammy"
        PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}"
        PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}"
    outputs:
      version-slot: ${{ steps.describe.outputs.version-slot }}
      version-core: ${{ steps.describe.outputs.version-core }}
      catalog-version: ${{ steps.describe.outputs.catalog-version }}

  publish-ubuntu-noble-x86_64:
    needs: [collect]
    runs-on: ubuntu-latest
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-ubuntu-noble-x86_64
        path: artifacts/ubuntu-noble
    - name: Publish
      uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest
      env:
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "ubuntu"
        PKG_PLATFORM_VERSION: "noble"
        PKG_PLATFORM_LIBC: ""
        PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}"

  check-published-ubuntu-noble-x86_64:
    needs: [publish-ubuntu-noble-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-ubuntu-noble-x86_64
        path: artifacts/ubuntu-noble
    - name: Describe
      id: describe
      uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master
      with:
        target: ubuntu-noble
    - name: Test Published
      uses: docker://ghcr.io/geldata/gelpkg-testpublished-ubuntu-noble:latest
      env:
        PKG_NAME: "${{ steps.describe.outputs.name }}"
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "ubuntu"
        PKG_PLATFORM_VERSION: "noble"
        PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}"
        PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}"
    outputs:
      version-slot: ${{ steps.describe.outputs.version-slot }}
      version-core: ${{ steps.describe.outputs.version-core }}
      catalog-version: ${{ steps.describe.outputs.catalog-version }}

  publish-ubuntu-noble-aarch64:
    needs: [collect]
    runs-on: ubuntu-latest
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-ubuntu-noble-aarch64
        path: artifacts/ubuntu-noble
    - name: Publish
      uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest
      env:
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "ubuntu"
        PKG_PLATFORM_VERSION: "noble"
        PKG_PLATFORM_LIBC: ""
        PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}"

  check-published-ubuntu-noble-aarch64:
    needs: [publish-ubuntu-noble-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-ubuntu-noble-aarch64
        path: artifacts/ubuntu-noble
    - name: Describe
      id: describe
      uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master
      with:
        target: ubuntu-noble
    - name: Test Published
      uses: docker://ghcr.io/geldata/gelpkg-testpublished-ubuntu-noble:latest
      env:
        PKG_NAME: "${{ steps.describe.outputs.name }}"
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "ubuntu"
        PKG_PLATFORM_VERSION: "noble"
        PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}"
        PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}"
    outputs:
      version-slot: ${{ steps.describe.outputs.version-slot }}
      version-core: ${{ steps.describe.outputs.version-core }}
      catalog-version: ${{ steps.describe.outputs.catalog-version }}

  publish-centos-8-x86_64:
    needs: [collect]
    runs-on: ubuntu-latest
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-centos-8-x86_64
        path: artifacts/centos-8
    - name: Publish
      uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest
      env:
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "centos"
        PKG_PLATFORM_VERSION: "8"
        PKG_PLATFORM_LIBC: ""
        PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}"

  check-published-centos-8-x86_64:
    needs: [publish-centos-8-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-centos-8-x86_64
        path: artifacts/centos-8
    - name: Describe
      id: describe
      uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master
      with:
        target: centos-8
    - name: Test Published
      uses: docker://ghcr.io/geldata/gelpkg-testpublished-centos-8:latest
      env:
        PKG_NAME: "${{ steps.describe.outputs.name }}"
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "centos"
        PKG_PLATFORM_VERSION: "8"
        PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}"
        PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}"
    outputs:
      version-slot: ${{ steps.describe.outputs.version-slot }}
      version-core: ${{ steps.describe.outputs.version-core }}
      catalog-version: ${{ steps.describe.outputs.catalog-version }}

  publish-centos-8-aarch64:
    needs: [collect]
    runs-on: ubuntu-latest
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-centos-8-aarch64
        path: artifacts/centos-8
    - name: Publish
      uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest
      env:
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "centos"
        PKG_PLATFORM_VERSION: "8"
        PKG_PLATFORM_LIBC: ""
        PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}"

  check-published-centos-8-aarch64:
    needs: [publish-centos-8-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-centos-8-aarch64
        path: artifacts/centos-8
    - name: Describe
      id: describe
      uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master
      with:
        target: centos-8
    - name: Test Published
      uses: docker://ghcr.io/geldata/gelpkg-testpublished-centos-8:latest
      env:
        PKG_NAME: "${{ steps.describe.outputs.name }}"
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "centos"
        PKG_PLATFORM_VERSION: "8"
        PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}"
        PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}"
    outputs:
      version-slot: ${{ steps.describe.outputs.version-slot }}
      version-core: ${{ steps.describe.outputs.version-core }}
      catalog-version: ${{ steps.describe.outputs.catalog-version }}

  publish-rockylinux-9-x86_64:
    needs: [collect]
    runs-on: ubuntu-latest
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-rockylinux-9-x86_64
        path: artifacts/rockylinux-9
    - name: Publish
      uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest
      env:
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "rockylinux"
        PKG_PLATFORM_VERSION: "9"
        PKG_PLATFORM_LIBC: ""
        PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}"

  check-published-rockylinux-9-x86_64:
    needs: [publish-rockylinux-9-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-rockylinux-9-x86_64
        path: artifacts/rockylinux-9
    - name: Describe
      id: describe
      uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master
      with:
        target: rockylinux-9
    - name: Test Published
      uses: docker://ghcr.io/geldata/gelpkg-testpublished-rockylinux-9:latest
      env:
        PKG_NAME: "${{ steps.describe.outputs.name }}"
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "rockylinux"
        PKG_PLATFORM_VERSION: "9"
        PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}"
        PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}"
    outputs:
      version-slot: ${{ steps.describe.outputs.version-slot }}
      version-core: ${{ steps.describe.outputs.version-core }}
      catalog-version: ${{ steps.describe.outputs.catalog-version }}

  publish-rockylinux-9-aarch64:
    needs: [collect]
    runs-on: ubuntu-latest
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-rockylinux-9-aarch64
        path: artifacts/rockylinux-9
    - name: Publish
      uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest
      env:
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "rockylinux"
        PKG_PLATFORM_VERSION: "9"
        PKG_PLATFORM_LIBC: ""
        PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}"

  check-published-rockylinux-9-aarch64:
    needs: [publish-rockylinux-9-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-rockylinux-9-aarch64
        path: artifacts/rockylinux-9
    - name: Describe
      id: describe
      uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master
      with:
        target: rockylinux-9
    - name: Test Published
      uses: docker://ghcr.io/geldata/gelpkg-testpublished-rockylinux-9:latest
      env:
        PKG_NAME: "${{ steps.describe.outputs.name }}"
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "rockylinux"
        PKG_PLATFORM_VERSION: "9"
        PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}"
        PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}"
    outputs:
      version-slot: ${{ steps.describe.outputs.version-slot }}
      version-core: ${{ steps.describe.outputs.version-core }}
      catalog-version: ${{ steps.describe.outputs.catalog-version }}

  publish-linux-x86_64:
    needs: [collect]
    runs-on: ubuntu-latest
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-linux-x86_64
        path: artifacts/linux-x86_64
    - name: Publish
      uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest
      env:
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "linux"
        PKG_PLATFORM_VERSION: "x86_64"
        PKG_PLATFORM_LIBC: ""
        PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}"

  check-published-linux-x86_64:
    needs: [publish-linux-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-linux-x86_64
        path: artifacts/linux-x86_64
    - name: Describe
      id: describe
      uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master
      with:
        target: linux-x86_64
    - name: Test Published
      uses: docker://ghcr.io/geldata/gelpkg-testpublished-linux-x86_64:latest
      env:
        PKG_NAME: "${{ steps.describe.outputs.name }}"
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "linux"
        PKG_PLATFORM_VERSION: "x86_64"
        PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}"
        PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}"
    outputs:
      version-slot: ${{ steps.describe.outputs.version-slot }}
      version-core: ${{ steps.describe.outputs.version-core }}
      catalog-version: ${{ steps.describe.outputs.catalog-version }}

  publish-linux-aarch64:
    needs: [collect]
    runs-on: ubuntu-latest
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-linux-aarch64
        path: artifacts/linux-aarch64
    - name: Publish
      uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest
      env:
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "linux"
        PKG_PLATFORM_VERSION: "aarch64"
        PKG_PLATFORM_LIBC: ""
        PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}"

  check-published-linux-aarch64:
    needs: [publish-linux-aarch64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64']
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-linux-aarch64
        path: artifacts/linux-aarch64
    - name: Describe
      id: describe
      uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master
      with:
        target: linux-aarch64
    - name: Test Published
      uses: docker://ghcr.io/geldata/gelpkg-testpublished-linux-aarch64:latest
      env:
        PKG_NAME: "${{ steps.describe.outputs.name }}"
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "linux"
        PKG_PLATFORM_VERSION: "aarch64"
        PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}"
        PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}"
    outputs:
      version-slot: ${{ steps.describe.outputs.version-slot }}
      version-core: ${{ steps.describe.outputs.version-core }}
      catalog-version: ${{ steps.describe.outputs.catalog-version }}

  publish-linuxmusl-x86_64:
    needs: [collect]
    runs-on: ubuntu-latest
    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-linuxmusl-x86_64
        path: artifacts/linuxmusl-x86_64
    - name: Publish
      uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest
      env:
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "linux"
        PKG_PLATFORM_VERSION: "x86_64"
        PKG_PLATFORM_LIBC: "musl"
        PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}"

  check-published-linuxmusl-x86_64:
    needs: [publish-linuxmusl-x86_64]
    runs-on:
['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linuxmusl-x86_64 path: artifacts/linuxmusl-x86_64 - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: linuxmusl-x86_64 - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-linuxmusl-x86_64:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "x86_64" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-linuxmusl-aarch64: needs: [collect] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linuxmusl-aarch64 path: artifacts/linuxmusl-aarch64 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "aarch64" PKG_PLATFORM_LIBC: "musl" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-linuxmusl-aarch64: needs: [publish-linuxmusl-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linuxmusl-aarch64 path: artifacts/linuxmusl-aarch64 - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: linuxmusl-aarch64 - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-linuxmusl-aarch64:latest env: PKG_NAME: "${{ 
steps.describe.outputs.name }}" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "aarch64" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-macos-x86_64: needs: [collect] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-macos-x86_64 path: artifacts/macos-x86_64 - uses: actions/checkout@v4 with: repository: edgedb/edgedb-pkg ref: master path: edgedb-pkg - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: macos-x86_64 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_PLATFORM: "macos" PKG_PLATFORM_VERSION: "x86_64" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" publish-macos-aarch64: needs: [collect] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-macos-aarch64 path: artifacts/macos-aarch64 - uses: actions/checkout@v4 with: repository: edgedb/edgedb-pkg ref: master path: edgedb-pkg - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: macos-aarch64 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_PLATFORM: "macos" PKG_PLATFORM_VERSION: "aarch64" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" publish-docker: needs: - check-published-debian-bookworm-x86_64 - check-published-debian-bookworm-aarch64 runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 with: repository: geldata/gel-docker ref: master path: dockerfile - name: Login to Docker Hub uses: 
docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0
      with:
        username: ${{ secrets.DOCKER_USERNAME }}
        password: ${{ secrets.DOCKER_PASSWORD }}

    - name: Login to GitHub Container Registry
      uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0
      with:
        registry: ghcr.io
        username: "edgedb-ci"
        password: ${{ secrets.GITHUB_CI_BOT_TOKEN }}

    - env:
        VERSION_SLOT: "${{ needs.check-published-debian-bookworm-x86_64.outputs.version-slot }}"
        VERSION_CORE: "${{ needs.check-published-debian-bookworm-x86_64.outputs.version-core }}"
        CATALOG_VERSION: "${{ needs.check-published-debian-bookworm-x86_64.outputs.catalog-version }}"
        PKG_SUBDIST: ""
      id: tags
      run: |
        set -e

        url='https://registry.hub.docker.com/v2/repositories/geldata/gel/tags?page_size=100'
        repo_tags=$(
          while [ -n "$url" ]; do
            resp=$(curl -L -s "$url")
            url=$(echo "$resp" | jq -r ".next")
            if [ "$url" = "null" ] || [ -z "$url" ]; then
              break
            fi
            echo "$resp" | jq -r '."results"[]["name"]'
          done | grep "^[[:digit:]]\+.*" | grep -v "alpha\|beta\|rc" || :
        )

        tags=()

        if [ "$PKG_SUBDIST" = "nightly" ]; then
          tags+=( "nightly" "nightly_${VERSION_SLOT}_cv${CATALOG_VERSION}" )
        else
          tags+=( "$VERSION_CORE" )
          top=$(printf "%s\n%s\n" "$VERSION_CORE" "$repo_tags" \
                | grep "^${VERSION_SLOT}[\.-]" \
                | sort --version-sort --reverse | head -n 1)
          if [ "$top" == "$VERSION_CORE" ]; then
            tags+=( "$VERSION_SLOT" )
          fi

          if [ -z "$PKG_SUBDIST" ]; then
            top=$(printf "%s\n%s\n" "$VERSION_CORE" "$repo_tags" \
                  | sort --version-sort --reverse | head -n 1)
            if [ "$top" == "$VERSION_CORE" ]; then
              tags+=( "latest" )
            fi
          fi
        fi

        fq_tags=()
        images=("geldata/gel" "ghcr.io/geldata/gel")
        for image in "${images[@]}"; do
          fq_tags+=("${tags[@]/#/${image}:}")
        done

        IFS=,
        echo "tags=${fq_tags[*]}" >> $GITHUB_OUTPUT

    - name: Set up QEMU
      uses: docker/setup-qemu-action@29109295f81e9208d7d86ff1c6c12d2833863392 # v3.6.0

    - name: Set up Docker Buildx
      uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0

    - name: Build and Publish
Docker Image uses: docker/build-push-action@471d1dc4e07e5cdedd4c2171150001c434f0b7a4 # v6.10.0 with: push: true provenance: mode=max tags: "${{ steps.tags.outputs.tags }}" context: dockerfile build-args: | version=${{ needs.check-published-debian-bookworm-x86_64.outputs.version-slot }} exact_version=${{ needs.check-published-debian-bookworm-x86_64.outputs.version-core }} platforms: linux/amd64,linux/arm64 workflow-notifications: if: failure() && github.event_name != 'pull_request' name: Notify in Slack on failures needs: - prep - collect - build-debian-buster-x86_64 - test-debian-buster-x86_64 - publish-debian-buster-x86_64 - check-published-debian-buster-x86_64 - build-debian-buster-aarch64 - test-debian-buster-aarch64 - publish-debian-buster-aarch64 - check-published-debian-buster-aarch64 - build-debian-bullseye-x86_64 - test-debian-bullseye-x86_64 - publish-debian-bullseye-x86_64 - check-published-debian-bullseye-x86_64 - build-debian-bullseye-aarch64 - test-debian-bullseye-aarch64 - publish-debian-bullseye-aarch64 - check-published-debian-bullseye-aarch64 - build-debian-bookworm-x86_64 - test-debian-bookworm-x86_64 - publish-debian-bookworm-x86_64 - check-published-debian-bookworm-x86_64 - build-debian-bookworm-aarch64 - test-debian-bookworm-aarch64 - publish-debian-bookworm-aarch64 - check-published-debian-bookworm-aarch64 - build-ubuntu-focal-x86_64 - test-ubuntu-focal-x86_64 - publish-ubuntu-focal-x86_64 - check-published-ubuntu-focal-x86_64 - build-ubuntu-focal-aarch64 - test-ubuntu-focal-aarch64 - publish-ubuntu-focal-aarch64 - check-published-ubuntu-focal-aarch64 - build-ubuntu-jammy-x86_64 - test-ubuntu-jammy-x86_64 - publish-ubuntu-jammy-x86_64 - check-published-ubuntu-jammy-x86_64 - build-ubuntu-jammy-aarch64 - test-ubuntu-jammy-aarch64 - publish-ubuntu-jammy-aarch64 - check-published-ubuntu-jammy-aarch64 - build-ubuntu-noble-x86_64 - test-ubuntu-noble-x86_64 - publish-ubuntu-noble-x86_64 - check-published-ubuntu-noble-x86_64 - 
build-ubuntu-noble-aarch64
      - test-ubuntu-noble-aarch64
      - publish-ubuntu-noble-aarch64
      - check-published-ubuntu-noble-aarch64
      - build-centos-8-x86_64
      - test-centos-8-x86_64
      - publish-centos-8-x86_64
      - check-published-centos-8-x86_64
      - build-centos-8-aarch64
      - test-centos-8-aarch64
      - publish-centos-8-aarch64
      - check-published-centos-8-aarch64
      - build-rockylinux-9-x86_64
      - test-rockylinux-9-x86_64
      - publish-rockylinux-9-x86_64
      - check-published-rockylinux-9-x86_64
      - build-rockylinux-9-aarch64
      - test-rockylinux-9-aarch64
      - publish-rockylinux-9-aarch64
      - check-published-rockylinux-9-aarch64
      - build-linux-x86_64
      - test-linux-x86_64
      - publish-linux-x86_64
      - check-published-linux-x86_64
      - build-linux-aarch64
      - test-linux-aarch64
      - publish-linux-aarch64
      - check-published-linux-aarch64
      - build-linuxmusl-x86_64
      - test-linuxmusl-x86_64
      - publish-linuxmusl-x86_64
      - check-published-linuxmusl-x86_64
      - build-linuxmusl-aarch64
      - test-linuxmusl-aarch64
      - publish-linuxmusl-aarch64
      - check-published-linuxmusl-aarch64
      - build-macos-x86_64
      - test-macos-x86_64
      - publish-macos-x86_64
      - build-macos-aarch64
      - test-macos-aarch64
      - publish-macos-aarch64
      - publish-docker
    runs-on: ubuntu-latest
    permissions:
      actions: 'read'
    steps:
    - name: Slack Workflow Notification
      uses: Gamesight/slack-workflow-status@26a36836c887f260477432e4314ec3490a84f309
      with:
        repo_token: ${{secrets.GITHUB_TOKEN}}
        slack_webhook_url: ${{secrets.ACTIONS_SLACK_WEBHOOK_URL}}
        name: 'Workflow notifications'
        icon_emoji: ':hammer:'
        include_jobs: 'on-failure'

================================================
FILE: .github/workflows/build.testing.yml
================================================
name: Build Test and Publish a Testing Release

on:
  workflow_dispatch:
    inputs:
      gelpkg_ref:
        description: "gel-pkg git ref used to build the packages"
        default: "master"
      metapkg_ref:
        description: "metapkg git ref used to build the packages"
        default: "master"

jobs:
  prep:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v4
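
  # Sketch (not part of the upstream workflow): `workflow_dispatch` is the only
  # trigger declared for this file, so a run has to be started manually.
  # Assuming the GitHub CLI is installed and authenticated, the two inputs
  # declared above could be passed as shown here. `<branch-or-tag>` is a
  # placeholder, and both values are simply the declared defaults:
  #
  #   gh workflow run build.testing.yml \
  #     --ref <branch-or-tag> \
  #     -f gelpkg_ref=master \
  #     -f metapkg_ref=master
  #
  # Omitting a `-f` flag leaves that input at its default of "master".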
build-debian-buster-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-debian-buster:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "testing" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "buster" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-debian-buster-x86_64 path: artifacts/debian-buster build-debian-buster-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-debian-buster:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "testing" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "buster" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-debian-buster-aarch64 path: artifacts/debian-buster build-debian-bullseye-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-debian-bullseye:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "testing" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bullseye" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-debian-bullseye-x86_64 path: artifacts/debian-bullseye 
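
  # Sketch (not part of the upstream workflow): the artifact `name` uploaded by
  # a build-* job is the handle its matching test-* job uses to fetch the
  # package. For the build-debian-bullseye-x86_64 job above, the consuming step
  # in test-debian-bullseye-x86_64 later in this file reads:
  #
  #   - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
  #     with:
  #       name: builds-debian-bullseye-x86_64
  #       path: artifacts/debian-bullseye
  #
  # The x86_64 and aarch64 builds of one distribution share the same `path`
  # but use distinct artifact names, so the two architectures stay separate.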
build-debian-bullseye-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-debian-bullseye:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "testing" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bullseye" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-debian-bullseye-aarch64 path: artifacts/debian-bullseye build-debian-bookworm-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-debian-bookworm:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "testing" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bookworm" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-debian-bookworm-x86_64 path: artifacts/debian-bookworm build-debian-bookworm-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-debian-bookworm:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "testing" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bookworm" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-debian-bookworm-aarch64 path: 
artifacts/debian-bookworm build-ubuntu-focal-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-ubuntu-focal:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "testing" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "focal" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-ubuntu-focal-x86_64 path: artifacts/ubuntu-focal build-ubuntu-focal-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-ubuntu-focal:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "testing" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "focal" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-ubuntu-focal-aarch64 path: artifacts/ubuntu-focal build-ubuntu-jammy-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-ubuntu-jammy:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "testing" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "jammy" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-ubuntu-jammy-x86_64 path: artifacts/ubuntu-jammy 
build-ubuntu-jammy-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-ubuntu-jammy:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "testing" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "jammy" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-ubuntu-jammy-aarch64 path: artifacts/ubuntu-jammy build-ubuntu-noble-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-ubuntu-noble:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "testing" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "noble" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-ubuntu-noble-x86_64 path: artifacts/ubuntu-noble build-ubuntu-noble-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-ubuntu-noble:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "testing" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "noble" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-ubuntu-noble-aarch64 path: artifacts/ubuntu-noble build-centos-8-x86_64: runs-on: 
['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-centos-8:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "testing" PKG_PLATFORM: "centos" PKG_PLATFORM_VERSION: "8" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-centos-8-x86_64 path: artifacts/centos-8 build-centos-8-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-centos-8:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "testing" PKG_PLATFORM: "centos" PKG_PLATFORM_VERSION: "8" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-centos-8-aarch64 path: artifacts/centos-8 build-rockylinux-9-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-rockylinux-9:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "testing" PKG_PLATFORM: "rockylinux" PKG_PLATFORM_VERSION: "9" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-rockylinux-9-x86_64 path: artifacts/rockylinux-9 build-rockylinux-9-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep steps: - 
name: Build uses: docker://ghcr.io/geldata/gelpkg-build-rockylinux-9:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "testing" PKG_PLATFORM: "rockylinux" PKG_PLATFORM_VERSION: "9" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-rockylinux-9-aarch64 path: artifacts/rockylinux-9 build-linux-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-linux-x86_64:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "testing" PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "x86_64" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" BUILD_GENERIC: true METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-linux-x86_64 path: artifacts/linux-x86_64 build-linux-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-linux-aarch64:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "testing" PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "aarch64" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" BUILD_GENERIC: true METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-linux-aarch64 path: artifacts/linux-aarch64 build-linuxmusl-x86_64: runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] needs: prep steps: - name: Build uses: 
docker://ghcr.io/geldata/gelpkg-build-linuxmusl-x86_64:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "testing" PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "x86_64" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" BUILD_GENERIC: true PKG_PLATFORM_LIBC: "musl" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-linuxmusl-x86_64 path: artifacts/linuxmusl-x86_64 build-linuxmusl-aarch64: runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] needs: prep steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-linuxmusl-aarch64:latest env: PACKAGE: "edgedbpkg.edgedb:Gel" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" PKG_SUBDIST: "testing" PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "aarch64" EXTRA_OPTIMIZATIONS: "true" BUILD_IS_RELEASE: "true" BUILD_GENERIC: true PKG_PLATFORM_LIBC: "musl" METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-linuxmusl-aarch64 path: artifacts/linuxmusl-aarch64 build-macos-x86_64: runs-on: ['macos-13'] needs: prep steps: - name: Update Homebrew before installing Rust toolchain run: | # Homebrew renamed `rustup-init` to `rustup`: # https://github.com/Homebrew/homebrew-core/pull/177840 # But the GitHub Action runner is not updated with this change yet. # This caused the later `brew update` in step `Build` to relink Rust # toolchain executables, overwriting the custom toolchain installed by # `dsherret/rust-toolchain-file`. So let's just run `brew update` early. 
        brew update

    - uses: actions/checkout@v4
      if: true
      with:
        sparse-checkout: |
          rust-toolchain.toml
        sparse-checkout-cone-mode: false

    - name: Install Rust toolchain
      uses: dsherret/rust-toolchain-file@v1
      if: true

    - uses: actions/checkout@v4
      with:
        repository: edgedb/edgedb-pkg
        ref: master
        path: edgedb-pkg

    - name: Set up Python
      uses: actions/setup-python@v5
      if: true
      with:
        python-version: "3.12"

    - name: Set up NodeJS
      uses: actions/setup-node@v4
      if: true
      with:
        node-version: '20'

    - name: Install dependencies
      if: true
      run: |
        env HOMEBREW_NO_AUTO_UPDATE=1 brew install libmagic

    - name: Install an alias
      # This is probably not strictly needed, but sentencepiece build script reports
      # errors without it.
      if: true
      run: |
        printf '#!/bin/sh\n\nexec sysctl -n hw.logicalcpu' > /usr/local/bin/nproc
        chmod +x /usr/local/bin/nproc

    - name: Build
      env:
        PACKAGE: "edgedbpkg.edgedb:Gel"
        SRC_REF: "${{ github.sha }}"
        BUILD_IS_RELEASE: "true"
        PKG_REVISION: ""
        PKG_SUBDIST: "testing"
        PKG_PLATFORM: "macos"
        PKG_PLATFORM_VERSION: "x86_64"
        PKG_PLATFORM_ARCH: "x86_64"
        EXTRA_OPTIMIZATIONS: "true"
        METAPKG_GIT_CACHE: disabled
        BUILD_GENERIC: true
        CMAKE_POLICY_VERSION_MINIMUM: '3.5'
        GEL_PKG_REF: ${{ inputs.gelpkg_ref }}
        METAPKG_REF: ${{ inputs.metapkg_ref }}
      run: |
        edgedb-pkg/integration/macos/build.sh

    - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
      with:
        name: builds-macos-x86_64
        path: artifacts/macos-x86_64

  build-macos-aarch64:
    runs-on: ['macos-14']
    needs: prep

    steps:
    - name: Update Homebrew before installing Rust toolchain
      run: |
        # Homebrew renamed `rustup-init` to `rustup`:
        #   https://github.com/Homebrew/homebrew-core/pull/177840
        # But the GitHub Action runner is not updated with this change yet.
        # This caused the later `brew update` in step `Build` to relink Rust
        # toolchain executables, overwriting the custom toolchain installed by
        # `dsherret/rust-toolchain-file`. So let's just run `brew update` early.
        brew update

    - uses: actions/checkout@v4
      if: true
      with:
        sparse-checkout: |
          rust-toolchain.toml
        sparse-checkout-cone-mode: false

    - name: Install Rust toolchain
      uses: dsherret/rust-toolchain-file@v1
      if: true

    - uses: actions/checkout@v4
      with:
        repository: edgedb/edgedb-pkg
        ref: master
        path: edgedb-pkg

    - name: Set up Python
      uses: actions/setup-python@v5
      if: true
      with:
        python-version: "3.12"

    - name: Set up NodeJS
      uses: actions/setup-node@v4
      if: true
      with:
        node-version: '20'

    - name: Install dependencies
      if: true
      run: |
        env HOMEBREW_NO_AUTO_UPDATE=1 brew install libmagic

    - name: Install an alias
      # This is probably not strictly needed, but sentencepiece build script reports
      # errors without it.
      if: true
      run: |
        printf '#!/bin/sh\n\nexec sysctl -n hw.logicalcpu' > /usr/local/bin/nproc
        chmod +x /usr/local/bin/nproc

    - name: Build
      env:
        PACKAGE: "edgedbpkg.edgedb:Gel"
        SRC_REF: "${{ github.sha }}"
        BUILD_IS_RELEASE: "true"
        PKG_REVISION: ""
        PKG_SUBDIST: "testing"
        PKG_PLATFORM: "macos"
        PKG_PLATFORM_VERSION: "aarch64"
        PKG_PLATFORM_ARCH: "aarch64"
        EXTRA_OPTIMIZATIONS: "true"
        METAPKG_GIT_CACHE: disabled
        BUILD_GENERIC: true
        CMAKE_POLICY_VERSION_MINIMUM: '3.5'
        GEL_PKG_REF: ${{ inputs.gelpkg_ref }}
        METAPKG_REF: ${{ inputs.metapkg_ref }}
      run: |
        edgedb-pkg/integration/macos/build.sh

    - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
      with:
        name: builds-macos-aarch64
        path: artifacts/macos-aarch64

  test-debian-buster-x86_64:
    needs: [build-debian-buster-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']

    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
      with:
        name: builds-debian-buster-x86_64
        path: artifacts/debian-buster

    - name: Test
      uses: docker://ghcr.io/geldata/gelpkg-test-debian-buster:latest
      env:
        PKG_SUBDIST: "testing"
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "buster"
        PKG_PLATFORM_LIBC: ""
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
        # edb test with -j higher than 1 seems to result in
workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-debian-buster-aarch64: needs: [build-debian-buster-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-buster-aarch64 path: artifacts/debian-buster - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-debian-buster:latest env: PKG_SUBDIST: "testing" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "buster" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-debian-bullseye-x86_64: needs: [build-debian-bullseye-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bullseye-x86_64 path: artifacts/debian-bullseye - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-debian-bullseye:latest env: PKG_SUBDIST: "testing" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bullseye" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. 
PKG_TEST_JOBS: 0 test-debian-bullseye-aarch64: needs: [build-debian-bullseye-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bullseye-aarch64 path: artifacts/debian-bullseye - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-debian-bullseye:latest env: PKG_SUBDIST: "testing" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bullseye" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-debian-bookworm-x86_64: needs: [build-debian-bookworm-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bookworm-x86_64 path: artifacts/debian-bookworm - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-debian-bookworm:latest env: PKG_SUBDIST: "testing" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bookworm" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-debian-bookworm-aarch64: needs: [build-debian-bookworm-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bookworm-aarch64 path: artifacts/debian-bookworm - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-debian-bookworm:latest env: PKG_SUBDIST: "testing" PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bookworm" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. 
PKG_TEST_JOBS: 0 test-ubuntu-focal-x86_64: needs: [build-ubuntu-focal-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-focal-x86_64 path: artifacts/ubuntu-focal - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-ubuntu-focal:latest env: PKG_SUBDIST: "testing" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "focal" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-ubuntu-focal-aarch64: needs: [build-ubuntu-focal-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-focal-aarch64 path: artifacts/ubuntu-focal - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-ubuntu-focal:latest env: PKG_SUBDIST: "testing" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "focal" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-ubuntu-jammy-x86_64: needs: [build-ubuntu-jammy-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-jammy-x86_64 path: artifacts/ubuntu-jammy - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-ubuntu-jammy:latest env: PKG_SUBDIST: "testing" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "jammy" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. 
PKG_TEST_JOBS: 0 test-ubuntu-jammy-aarch64: needs: [build-ubuntu-jammy-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-jammy-aarch64 path: artifacts/ubuntu-jammy - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-ubuntu-jammy:latest env: PKG_SUBDIST: "testing" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "jammy" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-ubuntu-noble-x86_64: needs: [build-ubuntu-noble-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-noble-x86_64 path: artifacts/ubuntu-noble - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-ubuntu-noble:latest env: PKG_SUBDIST: "testing" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "noble" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-ubuntu-noble-aarch64: needs: [build-ubuntu-noble-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-noble-aarch64 path: artifacts/ubuntu-noble - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-ubuntu-noble:latest env: PKG_SUBDIST: "testing" PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "noble" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. 
PKG_TEST_JOBS: 0 test-centos-8-x86_64: needs: [build-centos-8-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-centos-8-x86_64 path: artifacts/centos-8 - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-centos-8:latest env: PKG_SUBDIST: "testing" PKG_PLATFORM: "centos" PKG_PLATFORM_VERSION: "8" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-centos-8-aarch64: needs: [build-centos-8-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-centos-8-aarch64 path: artifacts/centos-8 - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-centos-8:latest env: PKG_SUBDIST: "testing" PKG_PLATFORM: "centos" PKG_PLATFORM_VERSION: "8" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-rockylinux-9-x86_64: needs: [build-rockylinux-9-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-rockylinux-9-x86_64 path: artifacts/rockylinux-9 - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-rockylinux-9:latest env: PKG_SUBDIST: "testing" PKG_PLATFORM: "rockylinux" PKG_PLATFORM_VERSION: "9" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. 
PKG_TEST_JOBS: 0 test-rockylinux-9-aarch64: needs: [build-rockylinux-9-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-rockylinux-9-aarch64 path: artifacts/rockylinux-9 - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-rockylinux-9:latest env: PKG_SUBDIST: "testing" PKG_PLATFORM: "rockylinux" PKG_PLATFORM_VERSION: "9" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-linux-x86_64: needs: [build-linux-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linux-x86_64 path: artifacts/linux-x86_64 - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-linux-x86_64:latest env: PKG_SUBDIST: "testing" PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "x86_64" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-linux-aarch64: needs: [build-linux-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linux-aarch64 path: artifacts/linux-aarch64 - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-linux-aarch64:latest env: PKG_SUBDIST: "testing" PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "aarch64" PKG_PLATFORM_LIBC: "" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. 
PKG_TEST_JOBS: 0 test-linuxmusl-x86_64: needs: [build-linuxmusl-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linuxmusl-x86_64 path: artifacts/linuxmusl-x86_64 - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-linuxmusl-x86_64:latest env: PKG_SUBDIST: "testing" PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "x86_64" PKG_PLATFORM_LIBC: "musl" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-linuxmusl-aarch64: needs: [build-linuxmusl-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linuxmusl-aarch64 path: artifacts/linuxmusl-aarch64 - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-linuxmusl-aarch64:latest env: PKG_SUBDIST: "testing" PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "aarch64" PKG_PLATFORM_LIBC: "musl" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " " # edb test with -j higher than 1 seems to result in workflow # jobs getting killed arbitrarily by Github. PKG_TEST_JOBS: 0 test-macos-x86_64: needs: [build-macos-x86_64] runs-on: ['macos-13'] steps: - uses: actions/checkout@v4 with: repository: edgedb/edgedb-pkg ref: master path: edgedb-pkg - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-macos-x86_64 path: artifacts/macos-x86_64 - name: Test env: PKG_SUBDIST: "testing" PKG_PLATFORM: "macos" PKG_PLATFORM_VERSION: "x86_64" PKG_TEST_SELECT: "" PKG_TEST_EXCLUDE: "" PKG_TEST_FILES: " test_dump*.py test_backend_*.py test_database.py test_server_*.py test_edgeql_ddl.py test_session.py " run: | # Bump shmmax and shmall to avoid test failures. 
        sudo sysctl -w kern.sysv.shmmax=12582912
        sudo sysctl -w kern.sysv.shmall=12582912
        edgedb-pkg/integration/macos/test.sh

  test-macos-aarch64:
    needs: [build-macos-aarch64]
    runs-on: ['macos-14']

    steps:
    - uses: actions/checkout@v4
      with:
        repository: edgedb/edgedb-pkg
        ref: master
        path: edgedb-pkg

    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16  # v4.1.8
      with:
        name: builds-macos-aarch64
        path: artifacts/macos-aarch64

    - name: Test
      env:
        PKG_SUBDIST: "testing"
        PKG_PLATFORM: "macos"
        PKG_PLATFORM_VERSION: "aarch64"
        PKG_TEST_SELECT: ""
        PKG_TEST_EXCLUDE: ""
        PKG_TEST_FILES: " "
      run: |
        edgedb-pkg/integration/macos/test.sh

  collect:
    needs:
    - test-debian-buster-x86_64
    - test-debian-buster-aarch64
    - test-debian-bullseye-x86_64
    - test-debian-bullseye-aarch64
    - test-debian-bookworm-x86_64
    - test-debian-bookworm-aarch64
    - test-ubuntu-focal-x86_64
    - test-ubuntu-focal-aarch64
    - test-ubuntu-jammy-x86_64
    - test-ubuntu-jammy-aarch64
    - test-ubuntu-noble-x86_64
    - test-ubuntu-noble-aarch64
    - test-centos-8-x86_64
    - test-centos-8-aarch64
    - test-rockylinux-9-x86_64
    - test-rockylinux-9-aarch64
    - test-linux-x86_64
    - test-linux-aarch64
    - test-linuxmusl-x86_64
    - test-linuxmusl-aarch64
    - test-macos-x86_64
    - test-macos-aarch64
    runs-on: ubuntu-latest
    steps:
    - run: echo 'All build+tests passed, ready to publish now!'

  publish-debian-buster-x86_64:
    needs: [collect]
    runs-on: ubuntu-latest

    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16  # v4.1.8
      with:
        name: builds-debian-buster-x86_64
        path: artifacts/debian-buster

    - name: Publish
      uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest
      env:
        PKG_SUBDIST: "testing"
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "buster"
        PKG_PLATFORM_LIBC: ""
        PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}"

  check-published-debian-buster-x86_64:
    needs: [publish-debian-buster-x86_64]
    runs-on: ['package-builder', 'self-hosted', 'linux', 'x64']

    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16  # v4.1.8
      with:
        name: builds-debian-buster-x86_64
        path: artifacts/debian-buster

    - name: Describe
      id: describe
      uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master
      with:
        target: debian-buster

    - name: Test Published
      uses: docker://ghcr.io/geldata/gelpkg-testpublished-debian-buster:latest
      env:
        PKG_NAME: "${{ steps.describe.outputs.name }}"
        PKG_SUBDIST: "testing"
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "debian"
        PKG_PLATFORM_VERSION: "buster"
        PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}"
        PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}"

    outputs:
      version-slot: ${{ steps.describe.outputs.version-slot }}
      version-core: ${{ steps.describe.outputs.version-core }}
      catalog-version: ${{ steps.describe.outputs.catalog-version }}

  publish-debian-buster-aarch64:
    needs: [collect]
    runs-on: ubuntu-latest

    steps:
    - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16  # v4.1.8
      with:
        name: builds-debian-buster-aarch64
        path: artifacts/debian-buster

    - name: Publish
      uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest
      env:
        PKG_SUBDIST: "testing"
        PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/
        PKG_PLATFORM: "debian"
PKG_PLATFORM_VERSION: "buster" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-debian-buster-aarch64: needs: [publish-debian-buster-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-buster-aarch64 path: artifacts/debian-buster - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: debian-buster - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-debian-buster:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "buster" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-debian-bullseye-x86_64: needs: [collect] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bullseye-x86_64 path: artifacts/debian-bullseye - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bullseye" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-debian-bullseye-x86_64: needs: [publish-debian-bullseye-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bullseye-x86_64 path: artifacts/debian-bullseye - 
name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: debian-bullseye - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-debian-bullseye:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bullseye" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-debian-bullseye-aarch64: needs: [collect] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bullseye-aarch64 path: artifacts/debian-bullseye - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bullseye" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-debian-bullseye-aarch64: needs: [publish-debian-bullseye-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bullseye-aarch64 path: artifacts/debian-bullseye - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: debian-bullseye - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-debian-bullseye:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: 
"bullseye" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-debian-bookworm-x86_64: needs: [collect] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bookworm-x86_64 path: artifacts/debian-bookworm - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bookworm" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-debian-bookworm-x86_64: needs: [publish-debian-bookworm-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bookworm-x86_64 path: artifacts/debian-bookworm - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: debian-bookworm - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-debian-bookworm:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bookworm" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-debian-bookworm-aarch64: needs: [collect] runs-on: ubuntu-latest steps: - uses: 
actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bookworm-aarch64 path: artifacts/debian-bookworm - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bookworm" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-debian-bookworm-aarch64: needs: [publish-debian-bookworm-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-debian-bookworm-aarch64 path: artifacts/debian-bookworm - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: debian-bookworm - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-debian-bookworm:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "debian" PKG_PLATFORM_VERSION: "bookworm" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-ubuntu-focal-x86_64: needs: [collect] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-focal-x86_64 path: artifacts/ubuntu-focal - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "focal" PKG_PLATFORM_LIBC: "" 
PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-ubuntu-focal-x86_64: needs: [publish-ubuntu-focal-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-focal-x86_64 path: artifacts/ubuntu-focal - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: ubuntu-focal - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-ubuntu-focal:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "focal" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-ubuntu-focal-aarch64: needs: [collect] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-focal-aarch64 path: artifacts/ubuntu-focal - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "focal" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-ubuntu-focal-aarch64: needs: [publish-ubuntu-focal-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-focal-aarch64 path: artifacts/ubuntu-focal - name: Describe id: describe uses: 
edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: ubuntu-focal - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-ubuntu-focal:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "focal" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-ubuntu-jammy-x86_64: needs: [collect] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-jammy-x86_64 path: artifacts/ubuntu-jammy - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "jammy" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-ubuntu-jammy-x86_64: needs: [publish-ubuntu-jammy-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-jammy-x86_64 path: artifacts/ubuntu-jammy - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: ubuntu-jammy - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-ubuntu-jammy:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "jammy" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" 
PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-ubuntu-jammy-aarch64: needs: [collect] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-jammy-aarch64 path: artifacts/ubuntu-jammy - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "jammy" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-ubuntu-jammy-aarch64: needs: [publish-ubuntu-jammy-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-jammy-aarch64 path: artifacts/ubuntu-jammy - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: ubuntu-jammy - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-ubuntu-jammy:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "jammy" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-ubuntu-noble-x86_64: needs: [collect] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: 
builds-ubuntu-noble-x86_64 path: artifacts/ubuntu-noble - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "noble" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-ubuntu-noble-x86_64: needs: [publish-ubuntu-noble-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-noble-x86_64 path: artifacts/ubuntu-noble - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: ubuntu-noble - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-ubuntu-noble:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "noble" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-ubuntu-noble-aarch64: needs: [collect] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-noble-aarch64 path: artifacts/ubuntu-noble - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "noble" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-ubuntu-noble-aarch64: needs: [publish-ubuntu-noble-aarch64] 
runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-ubuntu-noble-aarch64 path: artifacts/ubuntu-noble - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: ubuntu-noble - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-ubuntu-noble:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "ubuntu" PKG_PLATFORM_VERSION: "noble" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-centos-8-x86_64: needs: [collect] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-centos-8-x86_64 path: artifacts/centos-8 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "centos" PKG_PLATFORM_VERSION: "8" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-centos-8-x86_64: needs: [publish-centos-8-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-centos-8-x86_64 path: artifacts/centos-8 - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: centos-8 - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-centos-8:latest env: PKG_NAME: "${{ 
steps.describe.outputs.name }}" PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "centos" PKG_PLATFORM_VERSION: "8" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-centos-8-aarch64: needs: [collect] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-centos-8-aarch64 path: artifacts/centos-8 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "centos" PKG_PLATFORM_VERSION: "8" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-centos-8-aarch64: needs: [publish-centos-8-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-centos-8-aarch64 path: artifacts/centos-8 - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: centos-8 - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-centos-8:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "centos" PKG_PLATFORM_VERSION: "8" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} 
publish-rockylinux-9-x86_64: needs: [collect] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-rockylinux-9-x86_64 path: artifacts/rockylinux-9 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "rockylinux" PKG_PLATFORM_VERSION: "9" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-rockylinux-9-x86_64: needs: [publish-rockylinux-9-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-rockylinux-9-x86_64 path: artifacts/rockylinux-9 - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: rockylinux-9 - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-rockylinux-9:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "rockylinux" PKG_PLATFORM_VERSION: "9" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-rockylinux-9-aarch64: needs: [collect] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-rockylinux-9-aarch64 path: artifacts/rockylinux-9 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "rockylinux" 
PKG_PLATFORM_VERSION: "9" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-rockylinux-9-aarch64: needs: [publish-rockylinux-9-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-rockylinux-9-aarch64 path: artifacts/rockylinux-9 - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: rockylinux-9 - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-rockylinux-9:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "rockylinux" PKG_PLATFORM_VERSION: "9" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-linux-x86_64: needs: [collect] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linux-x86_64 path: artifacts/linux-x86_64 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "x86_64" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-linux-x86_64: needs: [publish-linux-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linux-x86_64 path: artifacts/linux-x86_64 - name: Describe id: describe uses: 
edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: linux-x86_64 - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-linux-x86_64:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "x86_64" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-linux-aarch64: needs: [collect] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linux-aarch64 path: artifacts/linux-aarch64 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "aarch64" PKG_PLATFORM_LIBC: "" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-linux-aarch64: needs: [publish-linux-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linux-aarch64 path: artifacts/linux-aarch64 - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: linux-aarch64 - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-linux-aarch64:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "aarch64" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ 
steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-linuxmusl-x86_64: needs: [collect] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linuxmusl-x86_64 path: artifacts/linuxmusl-x86_64 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "x86_64" PKG_PLATFORM_LIBC: "musl" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-linuxmusl-x86_64: needs: [publish-linuxmusl-x86_64] runs-on: ['package-builder', 'self-hosted', 'linux', 'x64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linuxmusl-x86_64 path: artifacts/linuxmusl-x86_64 - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: linuxmusl-x86_64 - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-linuxmusl-x86_64:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "x86_64" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-linuxmusl-aarch64: needs: [collect] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linuxmusl-aarch64 path: 
artifacts/linuxmusl-aarch64 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "aarch64" PKG_PLATFORM_LIBC: "musl" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published-linuxmusl-aarch64: needs: [publish-linuxmusl-aarch64] runs-on: ['package-builder', 'self-hosted', 'linux', 'arm64'] steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-linuxmusl-aarch64 path: artifacts/linuxmusl-aarch64 - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: linuxmusl-aarch64 - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-linuxmusl-aarch64:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" PKG_SUBDIST: "testing" PACKAGE_SERVER: sftp://uploader@package-upload.edgedb.net:22/ PKG_PLATFORM: "linux" PKG_PLATFORM_VERSION: "aarch64" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}" outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} publish-macos-x86_64: needs: [collect] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-macos-x86_64 path: artifacts/macos-x86_64 - uses: actions/checkout@v4 with: repository: edgedb/edgedb-pkg ref: master path: edgedb-pkg - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: macos-x86_64 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "testing" PKG_PLATFORM: "macos" PKG_PLATFORM_VERSION: "x86_64" PACKAGE_UPLOAD_SSH_KEY: "${{ 
secrets.PACKAGE_UPLOAD_SSH_KEY }}" publish-macos-aarch64: needs: [collect] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-macos-aarch64 path: artifacts/macos-aarch64 - uses: actions/checkout@v4 with: repository: edgedb/edgedb-pkg ref: master path: edgedb-pkg - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: macos-aarch64 - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: PKG_SUBDIST: "testing" PKG_PLATFORM: "macos" PKG_PLATFORM_VERSION: "aarch64" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" publish-docker: needs: - check-published-debian-bookworm-x86_64 - check-published-debian-bookworm-aarch64 runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 with: repository: geldata/gel-docker ref: master path: dockerfile - name: Login to Docker Hub uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} - name: Login to GitHub Container Registry uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0 with: registry: ghcr.io username: "edgedb-ci" password: ${{ secrets.GITHUB_CI_BOT_TOKEN }} - env: VERSION_SLOT: "${{ needs.check-published-debian-bookworm-x86_64.outputs.version-slot }}" VERSION_CORE: "${{ needs.check-published-debian-bookworm-x86_64.outputs.version-core }}" CATALOG_VERSION: "${{ needs.check-published-debian-bookworm-x86_64.outputs.catalog-version }}" PKG_SUBDIST: "testing" id: tags run: | set -e url='https://registry.hub.docker.com/v2/repositories/geldata/gel/tags?page_size=100' repo_tags=$( while [ -n "$url" ]; do resp=$(curl -L -s "$url") echo "$resp" | jq -r '."results"[]["name"]' url=$(echo "$resp" | jq -r ".next") if [ "$url" = "null" ] || [ -z "$url" ]; then break fi done | grep "^[[:digit:]]\+.*" | grep -v "alpha\|beta\|rc"
|| : ) tags=() if [ "$PKG_SUBDIST" = "nightly" ]; then tags+=( "nightly" "nightly_${VERSION_SLOT}_cv${CATALOG_VERSION}" ) else tags+=( "$VERSION_CORE" ) top=$(printf "%s\n%s\n" "$VERSION_CORE" "$repo_tags" \ | grep "^${VERSION_SLOT}[\.-]" \ | sort --version-sort --reverse | head -n 1) if [ "$top" == "$VERSION_CORE" ]; then tags+=( "$VERSION_SLOT" ) fi if [ -z "$PKG_SUBDIST" ]; then top=$(printf "%s\n%s\n" "$VERSION_CORE" "$repo_tags" \ | sort --version-sort --reverse | head -n 1) if [ "$top" == "$VERSION_CORE" ]; then tags+=( "latest" ) fi fi fi fq_tags=() images=("geldata/gel" "ghcr.io/geldata/gel") for image in "${images[@]}"; do fq_tags+=("${tags[@]/#/${image}:}") done IFS=, echo "tags=${fq_tags[*]}" >> $GITHUB_OUTPUT - name: Set up QEMU uses: docker/setup-qemu-action@29109295f81e9208d7d86ff1c6c12d2833863392 # v3.6.0 - name: Set up Docker Buildx uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0 - name: Build and Publish Docker Image uses: docker/build-push-action@471d1dc4e07e5cdedd4c2171150001c434f0b7a4 # v6.10.0 with: push: true provenance: mode=max tags: "${{ steps.tags.outputs.tags }}" context: dockerfile build-args: | version=${{ needs.check-published-debian-bookworm-x86_64.outputs.version-slot }} exact_version=${{ needs.check-published-debian-bookworm-x86_64.outputs.version-core }} subdist=testing platforms: linux/amd64,linux/arm64 workflow-notifications: if: failure() && github.event_name != 'pull_request' name: Notify in Slack on failures needs: - prep - collect - build-debian-buster-x86_64 - test-debian-buster-x86_64 - publish-debian-buster-x86_64 - check-published-debian-buster-x86_64 - build-debian-buster-aarch64 - test-debian-buster-aarch64 - publish-debian-buster-aarch64 - check-published-debian-buster-aarch64 - build-debian-bullseye-x86_64 - test-debian-bullseye-x86_64 - publish-debian-bullseye-x86_64 - check-published-debian-bullseye-x86_64 - build-debian-bullseye-aarch64 - test-debian-bullseye-aarch64 - 
publish-debian-bullseye-aarch64 - check-published-debian-bullseye-aarch64 - build-debian-bookworm-x86_64 - test-debian-bookworm-x86_64 - publish-debian-bookworm-x86_64 - check-published-debian-bookworm-x86_64 - build-debian-bookworm-aarch64 - test-debian-bookworm-aarch64 - publish-debian-bookworm-aarch64 - check-published-debian-bookworm-aarch64 - build-ubuntu-focal-x86_64 - test-ubuntu-focal-x86_64 - publish-ubuntu-focal-x86_64 - check-published-ubuntu-focal-x86_64 - build-ubuntu-focal-aarch64 - test-ubuntu-focal-aarch64 - publish-ubuntu-focal-aarch64 - check-published-ubuntu-focal-aarch64 - build-ubuntu-jammy-x86_64 - test-ubuntu-jammy-x86_64 - publish-ubuntu-jammy-x86_64 - check-published-ubuntu-jammy-x86_64 - build-ubuntu-jammy-aarch64 - test-ubuntu-jammy-aarch64 - publish-ubuntu-jammy-aarch64 - check-published-ubuntu-jammy-aarch64 - build-ubuntu-noble-x86_64 - test-ubuntu-noble-x86_64 - publish-ubuntu-noble-x86_64 - check-published-ubuntu-noble-x86_64 - build-ubuntu-noble-aarch64 - test-ubuntu-noble-aarch64 - publish-ubuntu-noble-aarch64 - check-published-ubuntu-noble-aarch64 - build-centos-8-x86_64 - test-centos-8-x86_64 - publish-centos-8-x86_64 - check-published-centos-8-x86_64 - build-centos-8-aarch64 - test-centos-8-aarch64 - publish-centos-8-aarch64 - check-published-centos-8-aarch64 - build-rockylinux-9-x86_64 - test-rockylinux-9-x86_64 - publish-rockylinux-9-x86_64 - check-published-rockylinux-9-x86_64 - build-rockylinux-9-aarch64 - test-rockylinux-9-aarch64 - publish-rockylinux-9-aarch64 - check-published-rockylinux-9-aarch64 - build-linux-x86_64 - test-linux-x86_64 - publish-linux-x86_64 - check-published-linux-x86_64 - build-linux-aarch64 - test-linux-aarch64 - publish-linux-aarch64 - check-published-linux-aarch64 - build-linuxmusl-x86_64 - test-linuxmusl-x86_64 - publish-linuxmusl-x86_64 - check-published-linuxmusl-x86_64 - build-linuxmusl-aarch64 - test-linuxmusl-aarch64 - publish-linuxmusl-aarch64 - check-published-linuxmusl-aarch64 - 
build-macos-x86_64 - test-macos-x86_64 - publish-macos-x86_64 - build-macos-aarch64 - test-macos-aarch64 - publish-macos-aarch64 - publish-docker runs-on: ubuntu-latest permissions: actions: 'read' steps: - name: Slack Workflow Notification uses: Gamesight/slack-workflow-status@26a36836c887f260477432e4314ec3490a84f309 with: repo_token: ${{secrets.GITHUB_TOKEN}} slack_webhook_url: ${{secrets.ACTIONS_SLACK_WEBHOOK_URL}} name: 'Workflow notifications' icon_emoji: ':hammer:' include_jobs: 'on-failure' ================================================ FILE: .github/workflows/docs-preview-deploy.yml ================================================ name: Docs Preview Deploy on: pull_request: paths: - "docs/**" jobs: deploy: runs-on: ubuntu-latest permissions: write-all steps: - uses: actions/checkout@v4 - uses: actions/github-script@v7 env: VERCEL_TOKEN: ${{ secrets.VERCEL_TOKEN }} VERCEL_TEAM_ID: ${{ secrets.VERCEL_TEAM_ID }} with: script: | const script = require('./.github/scripts/docs/preview-deploy.js'); await script({github, context}); ================================================ FILE: .github/workflows/docs.yml ================================================ name: Deploy Documentation Changes on: push: branches: - master - release/** paths: - "docs/**" workflow_dispatch: jobs: deploy: runs-on: ubuntu-latest steps: - name: Trigger vercel deploy hook run: curl \ --fail-with-body \ --request POST \ ${{ secrets.VERCEL_DOC_DEPLOY_URL_HOOK }} ================================================ FILE: .github/workflows/pull-request-meta.yml ================================================ name: Pull Request Meta on: pull_request: types: [opened, edited, synchronize] concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number }} cancel-in-progress: true jobs: test-pr: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - name: Verify that postgres/ was not changed unintentionally env: PR_TITLE: ${{ 
github.event.pull_request.title }} shell: bash run: | required_prefix="Update bundled PostgreSQL" if [[ "$PR_TITLE" == $required_prefix* ]]; then exit 0 fi if git diff --quiet \ ${{ github.event.pull_request.base.sha }} \ ${{ github.event.pull_request.head.sha }} -- postgres/ then echo 'all ok' else echo "postgres/ submodule has been changed,"\ "but PR title does not indicate that" echo "(it should start with '$required_prefix')" exit 1 fi ================================================ FILE: .github/workflows/tests.ha.yml ================================================ name: High Availability Tests on: workflow_dispatch: inputs: {} workflow_run: workflows: ["Tests"] types: - completed jobs: build: runs-on: ubuntu-latest if: github.event.workflow_run.conclusion == 'success' || github.event_name == 'workflow_dispatch' steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - uses: actions/checkout@v4 with: fetch-depth: 50 submodules: true - name: Set up Python uses: actions/setup-python@v5 id: setup-python with: python-version: '3.12.2' cache: 'pip' cache-dependency-path: | pyproject.toml # The below is technically a lie as we are technically not # inside a virtual env, but there is really no reason to bother # actually creating and activating one as below works just fine. 
- name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Cached requirements.txt uses: actions/cache@v4 id: requirements-cache with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Compute requirements.txt if: steps.requirements-cache.outputs.cache-hit != 'true' run: | python -m pip install pip-tools pip-compile --no-strip-extras --all-build-deps \ --extra test,language-server \ --output-file requirements.txt pyproject.toml - name: Install Python dependencies run: | python -c "import sys; print(sys.prefix)" python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 # Our HA tests currently only work on Postgres 14 (see #6332), # so check it out before we compute our build cache keys. 
- name: Switch back to Postgres 14 shell: bash run: | set -e cd postgres # Fetch postgres 14, since the clone was shallow git fetch origin REL_14_8 --depth=1 # For whatever reason the tag doesn't get fetched, so find it # at FETCH_HEAD git checkout FETCH_HEAD - name: Compute cache keys env: GIST_TOKEN: ${{ secrets.CI_BOT_GIST_TOKEN }} run: | mkdir -p shared-artifacts if [ "$(uname)" = "Darwin" ]; then find /usr/lib -type f -name 'lib*' -exec stat -f '%N %z' {} + | sort | shasum -a 256 | cut -d ' ' -f1 > shared-artifacts/lib_cache_key.txt else find /usr/lib -type f -name 'lib*' -printf '%P %s\n' | sort | sha256sum | cut -d ' ' -f1 > shared-artifacts/lib_cache_key.txt fi python setup.py -q ci_helper --type rust >shared-artifacts/rust_cache_key.txt python setup.py -q ci_helper --type ext >shared-artifacts/ext_cache_key.txt python setup.py -q ci_helper --type parsers >shared-artifacts/parsers_cache_key.txt python setup.py -q ci_helper --type postgres >shared-artifacts/postgres_git_rev.txt python setup.py -q ci_helper --type libpg_query >shared-artifacts/libpg_query_git_rev.txt echo 'f8cd94309eaccbfba5dea7835b88c78377608a37' >shared-artifacts/stolon_git_rev.txt python setup.py -q ci_helper --type bootstrap >shared-artifacts/bootstrap_cache_key.txt echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo LIBPG_QUERY_GIT_REV=$(cat shared-artifacts/libpg_query_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV - name: Upload shared artifacts uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: shared-artifacts path: shared-artifacts retention-days: 1 # Restore binary cache - name: Handle cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: 
edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} restore-keys: | edb-rust-v4- - name: Handle cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Handle cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Handle cached libpg_query build uses: actions/cache@v4 id: libpg-query-cache with: path: edb/pgsql/parser/libpg_query/libpg_query.a key: edb-libpg_query-v1-${{ env.LIBPG_QUERY_GIT_REV }} # Install system dependencies for building - name: Install system deps if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: steps.rust-cache.outputs.cache-hit != 'true' uses: dsherret/rust-toolchain-file@v1 # Build Rust extensions - name: Handle Rust extensions build cache uses: actions/cache@v4 if: steps.rust-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/rust/extensions key: edb-rust-build-v1-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} restore-keys: | edb-rust-build-v1- - name: Build Rust extensions env: CARGO_HOME: ${{ env.BUILD_TEMP }}/rust/extensions/cargo_home CACHE_HIT: ${{ steps.rust-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p build/rust_extensions rsync -av ./build/rust_extensions/ ${BUILD_LIB}/ python setup.py -v build_rust rsync -av ${BUILD_LIB}/ 
build/rust_extensions/ rm -rf ${BUILD_LIB} fi rsync -av ./build/rust_extensions/edb/ ./edb/ # Build libpg_query - name: Build libpg_query if: | steps.libpg-query-cache.outputs.cache-hit != 'true' && steps.ext-cache.outputs.cache-hit != 'true' run: | python setup.py build_libpg_query # Build extensions - name: Handle Cython extensions build cache uses: actions/cache@v4 if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: CACHE_HIT: ${{ steps.ext-cache.outputs.cache-hit }} BUILD_EXT_MODE: py-only run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p ./build/extensions rsync -av ./build/extensions/ ${BUILD_LIB}/ BUILD_EXT_MODE=py-only python setup.py -v build_ext rsync -av ${BUILD_LIB}/ ./build/extensions/ rm -rf ${BUILD_LIB} fi rsync -av ./build/extensions/edb/ ./edb/ # Build parsers - name: Handle compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} restore-keys: | edb-parsers-v3- - name: Build parsers env: CACHE_HIT: ${{ steps.parsers-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p ./build/lib rsync -av ./build/lib/ ${BUILD_LIB}/ python setup.py -v build_parsers rsync -av ${BUILD_LIB}/ ./build/lib/ rm -rf ${BUILD_LIB} fi rsync -av ./build/lib/edb/ ./edb/ # Build PostgreSQL - name: Build PostgreSQL env: CACHE_HIT: ${{ steps.postgres-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" == "true" ]]; then cp build/postgres/install/stamp build/postgres/ else python setup.py build_postgres cp build/postgres/stamp build/postgres/install/ fi # Build Stolon - name: Set up Go if: steps.stolon-cache.outputs.cache-hit != 'true' uses: actions/setup-go@v2 with: go-version: 1.16 - uses: actions/checkout@v4 if: steps.stolon-cache.outputs.cache-hit != 
'true' with: repository: edgedb/stolon path: build/stolon ref: ${{ env.STOLON_GIT_REV }} fetch-depth: 0 submodules: false - name: Build Stolon if: steps.stolon-cache.outputs.cache-hit != 'true' run: | mkdir -p build/stolon/bin/ curl -fsSL https://releases.hashicorp.com/consul/1.10.1/consul_1.10.1_linux_amd64.zip | zcat > build/stolon/bin/consul chmod +x build/stolon/bin/consul cd build/stolon && make # Install edgedb-server and populate egg-info - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and don't want them to be reinstalled in an "isolated env". pip install --no-build-isolation --no-deps -e .[test,docs] # Refresh the bootstrap cache - name: Handle bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} restore-keys: | edb-bootstrap-v2- - name: Bootstrap EdgeDB Server if: steps.bootstrap-cache.outputs.cache-hit != 'true' run: | edb server --bootstrap-only ha-test: needs: build runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - uses: actions/checkout@v4 with: fetch-depth: 50 submodules: true - name: Set up Python uses: actions/setup-python@v5 id: setup-python with: python-version: '3.12.2' cache: 'pip' cache-dependency-path: | pyproject.toml # The below is technically a lie as we are technically not # inside a virtual env, but there is really no reason to bother # actually creating and activating one as below works just fine. 
- name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Download requirements.txt uses: actions/cache@v4 with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Install Python dependencies run: | python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 # Restore the artifacts and environment variables - name: Download shared artifacts uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: shared-artifacts path: shared-artifacts - name: Set environment variables run: | echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV # Restore build cache - name: Restore cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} - name: Restore cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} - name: Restore cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ 
hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Restore cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Restore bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} - name: Stop if we cannot retrieve the cache if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.parsers-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.bootstrap-cache.outputs.cache-hit != 'true' run: | echo ::error::Cannot retrieve build cache. exit 1 - name: Validate cached binaries run: | # Validate Stolon ./build/stolon/bin/stolon-sentinel --version || exit 1 ./build/stolon/bin/stolon-keeper --version || exit 1 ./build/stolon/bin/stolon-proxy --version || exit 1 # Validate PostgreSQL ./build/postgres/install/bin/postgres --version || exit 1 ./build/postgres/install/bin/pg_config --version || exit 1 - name: Restore cache into the source tree run: | rsync -av ./build/rust_extensions/edb/ ./edb/ rsync -av ./build/extensions/edb/ ./edb/ rsync -av ./build/lib/edb/ ./edb/ cp build/postgres/install/stamp build/postgres/ - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and don't want them to be reinstalled in an "isolated env". 
pip install --no-build-isolation --no-deps -e .[test,docs] # Run the test - name: Test env: SHARD: ${{ matrix.shard }} EDGEDB_TEST_HA: 1 EDGEDB_TEST_CONSUL_PATH: build/stolon/bin/consul EDGEDB_TEST_STOLON_CTL: build/stolon/bin/stolonctl EDGEDB_TEST_STOLON_SENTINEL: build/stolon/bin/stolon-sentinel EDGEDB_TEST_STOLON_KEEPER: build/stolon/bin/stolon-keeper run: | edb test -j1 -v -k test_ha_ workflow-notifications: if: failure() && github.event_name != 'pull_request' name: Notify in Slack on failures needs: - build - ha-test runs-on: ubuntu-latest permissions: actions: 'read' steps: - name: Slack Workflow Notification uses: Gamesight/slack-workflow-status@26a36836c887f260477432e4314ec3490a84f309 with: repo_token: ${{secrets.GITHUB_TOKEN}} slack_webhook_url: ${{secrets.ACTIONS_SLACK_WEBHOOK_URL}} name: 'Workflow notifications' icon_emoji: ':hammer:' include_jobs: 'on-failure' ================================================ FILE: .github/workflows/tests.inplace.yml ================================================ name: Tests of in-place upgrades and patching on: schedule: - cron: "0 3 * * *" workflow_dispatch: inputs: {} push: branches: - "A-inplace*" jobs: build: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - uses: actions/checkout@v4 with: fetch-depth: 50 submodules: true - name: Set up Python uses: actions/setup-python@v5 id: setup-python with: python-version: '3.12.2' cache: 'pip' cache-dependency-path: | pyproject.toml # The below is technically a lie as we are technically not # inside a virtual env, but there is really no reason to bother # actually creating and activating one as below works just fine. 
- name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Cached requirements.txt uses: actions/cache@v4 id: requirements-cache with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Compute requirements.txt if: steps.requirements-cache.outputs.cache-hit != 'true' run: | python -m pip install pip-tools pip-compile --no-strip-extras --all-build-deps \ --extra test,language-server \ --output-file requirements.txt pyproject.toml - name: Install Python dependencies run: | python -c "import sys; print(sys.prefix)" python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 - name: Compute cache keys run: | mkdir -p shared-artifacts if [ "$(uname)" = "Darwin" ]; then find /usr/lib -type f -name 'lib*' -exec stat -f '%N %z' {} + | sort | shasum -a 256 | cut -d ' ' -f1 > shared-artifacts/lib_cache_key.txt else find /usr/lib -type f -name 'lib*' -printf '%P %s\n' | sort | sha256sum | cut -d ' ' -f1 > shared-artifacts/lib_cache_key.txt fi python setup.py -q ci_helper --type rust >shared-artifacts/rust_cache_key.txt python setup.py -q ci_helper --type ext >shared-artifacts/ext_cache_key.txt python setup.py -q ci_helper --type parsers >shared-artifacts/parsers_cache_key.txt python setup.py -q ci_helper --type postgres >shared-artifacts/postgres_git_rev.txt python setup.py -q ci_helper --type libpg_query >shared-artifacts/libpg_query_git_rev.txt echo 'f8cd94309eaccbfba5dea7835b88c78377608a37' >shared-artifacts/stolon_git_rev.txt python setup.py -q ci_helper --type bootstrap >shared-artifacts/bootstrap_cache_key.txt echo POSTGRES_GIT_REV=$(cat 
shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo LIBPG_QUERY_GIT_REV=$(cat shared-artifacts/libpg_query_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV - name: Upload shared artifacts uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: shared-artifacts path: shared-artifacts retention-days: 1 # Restore binary cache - name: Handle cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} restore-keys: | edb-rust-v4- - name: Handle cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Handle cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Handle cached libpg_query build uses: actions/cache@v4 id: libpg-query-cache with: path: edb/pgsql/parser/libpg_query/libpg_query.a key: edb-libpg_query-v1-${{ env.LIBPG_QUERY_GIT_REV }} # Install system dependencies for building - name: Install system deps if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: steps.rust-cache.outputs.cache-hit != 'true' uses: 
dsherret/rust-toolchain-file@v1 # Build Rust extensions - name: Handle Rust extensions build cache uses: actions/cache@v4 if: steps.rust-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/rust/extensions key: edb-rust-build-v1-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} restore-keys: | edb-rust-build-v1- - name: Build Rust extensions env: CARGO_HOME: ${{ env.BUILD_TEMP }}/rust/extensions/cargo_home CACHE_HIT: ${{ steps.rust-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p build/rust_extensions rsync -av ./build/rust_extensions/ ${BUILD_LIB}/ python setup.py -v build_rust rsync -av ${BUILD_LIB}/ build/rust_extensions/ rm -rf ${BUILD_LIB} fi rsync -av ./build/rust_extensions/edb/ ./edb/ # Build libpg_query - name: Build libpg_query if: | steps.libpg-query-cache.outputs.cache-hit != 'true' && steps.ext-cache.outputs.cache-hit != 'true' run: | python setup.py build_libpg_query # Build extensions - name: Handle Cython extensions build cache uses: actions/cache@v4 if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: CACHE_HIT: ${{ steps.ext-cache.outputs.cache-hit }} BUILD_EXT_MODE: py-only run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p ./build/extensions rsync -av ./build/extensions/ ${BUILD_LIB}/ BUILD_EXT_MODE=py-only python setup.py -v build_ext rsync -av ${BUILD_LIB}/ ./build/extensions/ rm -rf ${BUILD_LIB} fi rsync -av ./build/extensions/edb/ ./edb/ # Build parsers - name: Handle compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} restore-keys: | edb-parsers-v3- - name: Build parsers env: CACHE_HIT: ${{ steps.parsers-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf 
${BUILD_LIB} mkdir -p ./build/lib rsync -av ./build/lib/ ${BUILD_LIB}/ python setup.py -v build_parsers rsync -av ${BUILD_LIB}/ ./build/lib/ rm -rf ${BUILD_LIB} fi rsync -av ./build/lib/edb/ ./edb/ # Build PostgreSQL - name: Build PostgreSQL env: CACHE_HIT: ${{ steps.postgres-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" == "true" ]]; then cp build/postgres/install/stamp build/postgres/ else python setup.py build_postgres cp build/postgres/stamp build/postgres/install/ fi # Build Stolon - name: Set up Go if: steps.stolon-cache.outputs.cache-hit != 'true' uses: actions/setup-go@v2 with: go-version: 1.16 - uses: actions/checkout@v4 if: steps.stolon-cache.outputs.cache-hit != 'true' with: repository: edgedb/stolon path: build/stolon ref: ${{ env.STOLON_GIT_REV }} fetch-depth: 0 submodules: false - name: Build Stolon if: steps.stolon-cache.outputs.cache-hit != 'true' run: | mkdir -p build/stolon/bin/ curl -fsSL https://releases.hashicorp.com/consul/1.10.1/consul_1.10.1_linux_amd64.zip | zcat > build/stolon/bin/consul chmod +x build/stolon/bin/consul cd build/stolon && make # Install edgedb-server and populate egg-info - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and don't want them to be reinstalled in an "isolated env". pip install --no-build-isolation --no-deps -e .[test,docs] # Refresh the bootstrap cache - name: Handle bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} restore-keys: | edb-bootstrap-v2- - name: Bootstrap EdgeDB Server if: steps.bootstrap-cache.outputs.cache-hit != 'true' run: | edb server --bootstrap-only test-inplace: runs-on: ubuntu-latest needs: build strategy: fail-fast: false matrix: include: - flags: tests: - flags: --rollback-and-test tests: # Do the reapply test on a smaller selection of tests, since # it is slower. 
- flags: --rollback-and-reapply tests: -k test_link_on_target_delete -k test_edgeql_select -k test_dump steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - uses: actions/checkout@v4 with: fetch-depth: 50 submodules: true - name: Set up Python uses: actions/setup-python@v5 id: setup-python with: python-version: '3.12.2' cache: 'pip' cache-dependency-path: | pyproject.toml # The below is technically a lie as we are technically not # inside a virtual env, but there is really no reason to bother # actually creating and activating one as below works just fine. - name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Download requirements.txt uses: actions/cache@v4 with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Install Python dependencies run: | python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 # Restore the artifacts and environment variables - name: Download shared artifacts uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: shared-artifacts path: shared-artifacts - name: Set environment variables run: | echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV # Restore build cache - name: Restore cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: edb-rust-v4-${{ 
hashFiles('shared-artifacts/rust_cache_key.txt') }} - name: Restore cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} - name: Restore cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Restore cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Restore bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} - name: Stop if we cannot retrieve the cache if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.parsers-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.bootstrap-cache.outputs.cache-hit != 'true' run: | echo ::error::Cannot retrieve build cache. 
        exit 1

    - name: Validate cached binaries
      run: |
        # Validate Stolon
        ./build/stolon/bin/stolon-sentinel --version || exit 1
        ./build/stolon/bin/stolon-keeper --version || exit 1
        ./build/stolon/bin/stolon-proxy --version || exit 1

        # Validate PostgreSQL
        ./build/postgres/install/bin/postgres --version || exit 1
        ./build/postgres/install/bin/pg_config --version || exit 1

    - name: Restore cache into the source tree
      run: |
        rsync -av ./build/rust_extensions/edb/ ./edb/
        rsync -av ./build/extensions/edb/ ./edb/
        rsync -av ./build/lib/edb/ ./edb/
        cp build/postgres/install/stamp build/postgres/

    - name: Install edgedb-server
      env:
        BUILD_EXT_MODE: skip
      run: |
        # --no-build-isolation because we have explicitly installed all deps
        # and don't want them to be reinstalled in an "isolated env".
        pip install --no-build-isolation --no-deps -e .[test,docs]

    # Run the test
    # TODO: Would it be better to split this up into multiple jobs?
    - name: Test performing in-place upgrades
      run: |
        ./tests/inplace-testing/test.sh ${{ matrix.flags }} vt ${{ matrix.tests }}

  test-patches:
    runs-on: ubuntu-latest
    needs: build

    steps:
    - uses: actions/checkout@v4
      with:
        fetch-depth: 0
        submodules: false

    - uses: actions/checkout@v4
      with:
        fetch-depth: 50
        submodules: true

    - name: Set up Python
      uses: actions/setup-python@v5
      id: setup-python
      with:
        python-version: '3.12.2'
        cache: 'pip'
        cache-dependency-path: |
          pyproject.toml

    # The below is technically a lie as we are technically not
    # inside a virtual env, but there is really no reason to bother
    # actually creating and activating one as below works just fine.
- name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Download requirements.txt uses: actions/cache@v4 with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Install Python dependencies run: | python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 # Restore the artifacts and environment variables - name: Download shared artifacts uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: shared-artifacts path: shared-artifacts - name: Set environment variables run: | echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV # Restore build cache - name: Restore cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} - name: Restore cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} - name: Restore cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ 
hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Restore cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Restore bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} - name: Stop if we cannot retrieve the cache if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.parsers-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.bootstrap-cache.outputs.cache-hit != 'true' run: | echo ::error::Cannot retrieve build cache. exit 1 - name: Validate cached binaries run: | # Validate Stolon ./build/stolon/bin/stolon-sentinel --version || exit 1 ./build/stolon/bin/stolon-keeper --version || exit 1 ./build/stolon/bin/stolon-proxy --version || exit 1 # Validate PostgreSQL ./build/postgres/install/bin/postgres --version || exit 1 ./build/postgres/install/bin/pg_config --version || exit 1 - name: Restore cache into the source tree run: | rsync -av ./build/rust_extensions/edb/ ./edb/ rsync -av ./build/extensions/edb/ ./edb/ rsync -av ./build/lib/edb/ ./edb/ cp build/postgres/install/stamp build/postgres/ - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and don't want them to be reinstalled in an "isolated env". 
        pip install --no-build-isolation --no-deps -e .[test,docs]

    - name: Test performing in-place upgrades
      run: |
        ./tests/patch-testing/test.sh test-dir -k test_link_on_target_delete -k test_edgeql_select -k test_edgeql_scope -k test_dump

  compute-versions:
    runs-on: ubuntu-latest
    outputs:
      matrix: ${{ steps.set-matrix.outputs.matrix }}
    steps:
    - uses: actions/checkout@v4
    - id: set-matrix
      name: Compute versions to run on
      run: python3 .github/scripts/patches/compute-ipu-versions.py

  test:
    runs-on: ubuntu-latest
    needs: [build, compute-versions]
    strategy:
      fail-fast: false
      matrix: ${{fromJSON(needs.compute-versions.outputs.matrix)}}

    steps:
    - uses: actions/checkout@v4
      with:
        fetch-depth: 0
        submodules: false

    - uses: actions/checkout@v4
      with:
        fetch-depth: 50
        submodules: true

    - name: Set up Python
      uses: actions/setup-python@v5
      id: setup-python
      with:
        python-version: '3.12.2'
        cache: 'pip'
        cache-dependency-path: |
          pyproject.toml

    # The below is technically a lie as we are technically not
    # inside a virtual env, but there is really no reason to bother
    # actually creating and activating one as below works just fine.
- name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Download requirements.txt uses: actions/cache@v4 with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Install Python dependencies run: | python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 # Restore the artifacts and environment variables - name: Download shared artifacts uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: shared-artifacts path: shared-artifacts - name: Set environment variables run: | echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV # Restore build cache - name: Restore cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} - name: Restore cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} - name: Restore cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ 
hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Restore cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Restore bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} - name: Stop if we cannot retrieve the cache if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.parsers-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.bootstrap-cache.outputs.cache-hit != 'true' run: | echo ::error::Cannot retrieve build cache. exit 1 - name: Validate cached binaries run: | # Validate Stolon ./build/stolon/bin/stolon-sentinel --version || exit 1 ./build/stolon/bin/stolon-keeper --version || exit 1 ./build/stolon/bin/stolon-proxy --version || exit 1 # Validate PostgreSQL ./build/postgres/install/bin/postgres --version || exit 1 ./build/postgres/install/bin/pg_config --version || exit 1 - name: Restore cache into the source tree run: | rsync -av ./build/rust_extensions/edb/ ./edb/ rsync -av ./build/extensions/edb/ ./edb/ rsync -av ./build/lib/edb/ ./edb/ cp build/postgres/install/stamp build/postgres/ - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and don't want them to be reinstalled in an "isolated env". 
        pip install --no-build-isolation --no-deps -e .[test,docs]

    # Run the test
    - name: Download an earlier database version
      run: |
        wget -q "${{ matrix.edgedb-url }}"
        tar xzf ${{ matrix.edgedb-basename }}-${{ matrix.edgedb-version }}.tar.gz

    - name: Make sure a CLI named "edgedb" exists (sigh)
      run: |
        ln -s gel $(dirname $(which gel))/edgedb

    - name: Test inplace upgrades from previous major version
      run: |
        ./tests/inplace-testing/test-old.sh vt ${{ matrix.edgedb-basename }}-${{ matrix.edgedb-version }}

  workflow-notifications:
    if: failure() && github.event_name != 'pull_request'
    name: Notify in Slack on failures
    needs:
      - build
      - test-inplace
      - test-patches
    runs-on: ubuntu-latest
    permissions:
      actions: 'read'
    steps:
      - name: Slack Workflow Notification
        uses: Gamesight/slack-workflow-status@26a36836c887f260477432e4314ec3490a84f309
        with:
          repo_token: ${{secrets.GITHUB_TOKEN}}
          slack_webhook_url: ${{secrets.ACTIONS_SLACK_WEBHOOK_URL}}
          name: 'Workflow notifications'
          icon_emoji: ':hammer:'
          include_jobs: 'on-failure'

================================================
FILE: .github/workflows/tests.inplace7x.yml
================================================
name: Tests of in-place upgrades to 7.x

on:
  schedule:
    - cron: "0 3 * * *"
  workflow_dispatch:
    inputs: {}
  push:
    branches:
      - "A-inplace*"

jobs:
  build:
    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v4
      with:
        fetch-depth: 0
        submodules: false
        ref: release/7.x

    - uses: actions/checkout@v4
      with:
        fetch-depth: 50
        submodules: true
        ref: release/7.x

    - name: Set up Python
      uses: actions/setup-python@v5
      id: setup-python
      with:
        python-version: '3.12.2'
        cache: 'pip'
        cache-dependency-path: |
          pyproject.toml

    # The below is technically a lie as we are technically not
    # inside a virtual env, but there is really no reason to bother
    # actually creating and activating one as below works just fine.
- name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Cached requirements.txt uses: actions/cache@v4 id: requirements-cache with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Compute requirements.txt if: steps.requirements-cache.outputs.cache-hit != 'true' run: | python -m pip install pip-tools pip-compile --no-strip-extras --all-build-deps \ --extra test,language-server \ --output-file requirements.txt pyproject.toml - name: Install Python dependencies run: | python -c "import sys; print(sys.prefix)" python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 - name: Compute cache keys run: | mkdir -p shared-artifacts if [ "$(uname)" = "Darwin" ]; then find /usr/lib -type f -name 'lib*' -exec stat -f '%N %z' {} + | sort | shasum -a 256 | cut -d ' ' -f1 > shared-artifacts/lib_cache_key.txt else find /usr/lib -type f -name 'lib*' -printf '%P %s\n' | sort | sha256sum | cut -d ' ' -f1 > shared-artifacts/lib_cache_key.txt fi python setup.py -q ci_helper --type rust >shared-artifacts/rust_cache_key.txt python setup.py -q ci_helper --type ext >shared-artifacts/ext_cache_key.txt python setup.py -q ci_helper --type parsers >shared-artifacts/parsers_cache_key.txt python setup.py -q ci_helper --type postgres >shared-artifacts/postgres_git_rev.txt python setup.py -q ci_helper --type libpg_query >shared-artifacts/libpg_query_git_rev.txt echo 'f8cd94309eaccbfba5dea7835b88c78377608a37' >shared-artifacts/stolon_git_rev.txt python setup.py -q ci_helper --type bootstrap >shared-artifacts/bootstrap_cache_key.txt echo POSTGRES_GIT_REV=$(cat 
shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo LIBPG_QUERY_GIT_REV=$(cat shared-artifacts/libpg_query_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV - name: Upload shared artifacts uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: shared-artifacts path: shared-artifacts retention-days: 1 # Restore binary cache - name: Handle cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} restore-keys: | edb-rust-v4- - name: Handle cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Handle cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Handle cached libpg_query build uses: actions/cache@v4 id: libpg-query-cache with: path: edb/pgsql/parser/libpg_query/libpg_query.a key: edb-libpg_query-v1-${{ env.LIBPG_QUERY_GIT_REV }} # Install system dependencies for building - name: Install system deps if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: steps.rust-cache.outputs.cache-hit != 'true' uses: 
dsherret/rust-toolchain-file@v1 # Build Rust extensions - name: Handle Rust extensions build cache uses: actions/cache@v4 if: steps.rust-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/rust/extensions key: edb-rust-build-v1-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} restore-keys: | edb-rust-build-v1- - name: Build Rust extensions env: CARGO_HOME: ${{ env.BUILD_TEMP }}/rust/extensions/cargo_home CACHE_HIT: ${{ steps.rust-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p build/rust_extensions rsync -av ./build/rust_extensions/ ${BUILD_LIB}/ python setup.py -v build_rust rsync -av ${BUILD_LIB}/ build/rust_extensions/ rm -rf ${BUILD_LIB} fi rsync -av ./build/rust_extensions/edb/ ./edb/ # Build libpg_query - name: Build libpg_query if: | steps.libpg-query-cache.outputs.cache-hit != 'true' && steps.ext-cache.outputs.cache-hit != 'true' run: | python setup.py build_libpg_query # Build extensions - name: Handle Cython extensions build cache uses: actions/cache@v4 if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: CACHE_HIT: ${{ steps.ext-cache.outputs.cache-hit }} BUILD_EXT_MODE: py-only run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p ./build/extensions rsync -av ./build/extensions/ ${BUILD_LIB}/ BUILD_EXT_MODE=py-only python setup.py -v build_ext rsync -av ${BUILD_LIB}/ ./build/extensions/ rm -rf ${BUILD_LIB} fi rsync -av ./build/extensions/edb/ ./edb/ # Build parsers - name: Handle compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} restore-keys: | edb-parsers-v3- - name: Build parsers env: CACHE_HIT: ${{ steps.parsers-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf 
${BUILD_LIB} mkdir -p ./build/lib rsync -av ./build/lib/ ${BUILD_LIB}/ python setup.py -v build_parsers rsync -av ${BUILD_LIB}/ ./build/lib/ rm -rf ${BUILD_LIB} fi rsync -av ./build/lib/edb/ ./edb/ # Build PostgreSQL - name: Build PostgreSQL env: CACHE_HIT: ${{ steps.postgres-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" == "true" ]]; then cp build/postgres/install/stamp build/postgres/ else python setup.py build_postgres cp build/postgres/stamp build/postgres/install/ fi # Build Stolon - name: Set up Go if: steps.stolon-cache.outputs.cache-hit != 'true' uses: actions/setup-go@v2 with: go-version: 1.16 - uses: actions/checkout@v4 if: steps.stolon-cache.outputs.cache-hit != 'true' with: repository: edgedb/stolon path: build/stolon ref: ${{ env.STOLON_GIT_REV }} fetch-depth: 0 submodules: false - name: Build Stolon if: steps.stolon-cache.outputs.cache-hit != 'true' run: | mkdir -p build/stolon/bin/ curl -fsSL https://releases.hashicorp.com/consul/1.10.1/consul_1.10.1_linux_amd64.zip | zcat > build/stolon/bin/consul chmod +x build/stolon/bin/consul cd build/stolon && make # Install edgedb-server and populate egg-info - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and don't want them to be reinstalled in an "isolated env". pip install --no-build-isolation --no-deps -e .[test,docs] # Refresh the bootstrap cache - name: Handle bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} restore-keys: | edb-bootstrap-v2- - name: Bootstrap EdgeDB Server if: steps.bootstrap-cache.outputs.cache-hit != 'true' run: | edb server --bootstrap-only test-inplace: runs-on: ubuntu-latest needs: build strategy: fail-fast: false matrix: include: - flags: tests: - flags: --rollback-and-test tests: # Do the reapply test on a smaller selection of tests, since # it is slower. 
- flags: --rollback-and-reapply tests: -k test_link_on_target_delete -k test_edgeql_select -k test_dump steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false ref: release/7.x - uses: actions/checkout@v4 with: fetch-depth: 50 submodules: true ref: release/7.x - name: Set up Python uses: actions/setup-python@v5 id: setup-python with: python-version: '3.12.2' cache: 'pip' cache-dependency-path: | pyproject.toml # The below is technically a lie as we are technically not # inside a virtual env, but there is really no reason to bother # actually creating and activating one as below works just fine. - name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Download requirements.txt uses: actions/cache@v4 with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Install Python dependencies run: | python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 # Restore the artifacts and environment variables - name: Download shared artifacts uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: shared-artifacts path: shared-artifacts - name: Set environment variables run: | echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV # Restore build cache - name: Restore cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: 
build/rust_extensions key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} - name: Restore cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} - name: Restore cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Restore cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Restore bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} - name: Stop if we cannot retrieve the cache if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.parsers-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.bootstrap-cache.outputs.cache-hit != 'true' run: | echo ::error::Cannot retrieve build cache. 
exit 1 - name: Validate cached binaries run: | # Validate Stolon ./build/stolon/bin/stolon-sentinel --version || exit 1 ./build/stolon/bin/stolon-keeper --version || exit 1 ./build/stolon/bin/stolon-proxy --version || exit 1 # Validate PostgreSQL ./build/postgres/install/bin/postgres --version || exit 1 ./build/postgres/install/bin/pg_config --version || exit 1 - name: Restore cache into the source tree run: | rsync -av ./build/rust_extensions/edb/ ./edb/ rsync -av ./build/extensions/edb/ ./edb/ rsync -av ./build/lib/edb/ ./edb/ cp build/postgres/install/stamp build/postgres/ - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and don't want them to be reinstalled in an "isolated env". pip install --no-build-isolation --no-deps -e .[test,docs] # Run the test # TODO: Would it be better to split this up into multiple jobs? - name: Test performing in-place upgrades run: | ./tests/inplace-testing/test.sh ${{ matrix.flags }} vt ${{ matrix.tests }} compute-versions: runs-on: ubuntu-latest outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false ref: release/7.x - id: set-matrix name: Compute versions to run on run: python3 .github/scripts/patches/compute-ipu-versions.py test: runs-on: ubuntu-latest needs: [build, compute-versions] strategy: fail-fast: false matrix: ${{fromJSON(needs.compute-versions.outputs.matrix)}} steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false ref: release/7.x - uses: actions/checkout@v4 with: fetch-depth: 50 submodules: true ref: release/7.x - name: Set up Python uses: actions/setup-python@v5 id: setup-python with: python-version: '3.12.2' cache: 'pip' cache-dependency-path: | pyproject.toml # The below is technically a lie as we are technically not # inside a virtual env, but there is really no reason to bother # actually creating and activating one as below works just 
fine. - name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Download requirements.txt uses: actions/cache@v4 with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Install Python dependencies run: | python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 # Restore the artifacts and environment variables - name: Download shared artifacts uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: shared-artifacts path: shared-artifacts - name: Set environment variables run: | echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV # Restore build cache - name: Restore cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} - name: Restore cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} - name: Restore cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ 
hashFiles('shared-artifacts/lib_cache_key.txt') }}

    - name: Restore cached Stolon build
      uses: actions/cache@v4
      id: stolon-cache
      with:
        path: build/stolon/bin
        key: edb-stolon-v2-${{ env.STOLON_GIT_REV }}

    - name: Restore bootstrap cache
      uses: actions/cache@v4
      id: bootstrap-cache
      with:
        path: build/cache
        key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }}

    - name: Stop if we cannot retrieve the cache
      if: |
        steps.rust-cache.outputs.cache-hit != 'true' ||
        steps.ext-cache.outputs.cache-hit != 'true' ||
        steps.parsers-cache.outputs.cache-hit != 'true' ||
        steps.postgres-cache.outputs.cache-hit != 'true' ||
        steps.stolon-cache.outputs.cache-hit != 'true' ||
        steps.bootstrap-cache.outputs.cache-hit != 'true'
      run: |
        echo ::error::Cannot retrieve build cache.
        exit 1

    - name: Validate cached binaries
      run: |
        # Validate Stolon
        ./build/stolon/bin/stolon-sentinel --version || exit 1
        ./build/stolon/bin/stolon-keeper --version || exit 1
        ./build/stolon/bin/stolon-proxy --version || exit 1

        # Validate PostgreSQL
        ./build/postgres/install/bin/postgres --version || exit 1
        ./build/postgres/install/bin/pg_config --version || exit 1

    - name: Restore cache into the source tree
      run: |
        rsync -av ./build/rust_extensions/edb/ ./edb/
        rsync -av ./build/extensions/edb/ ./edb/
        rsync -av ./build/lib/edb/ ./edb/
        cp build/postgres/install/stamp build/postgres/

    - name: Install edgedb-server
      env:
        BUILD_EXT_MODE: skip
      run: |
        # --no-build-isolation because we have explicitly installed all deps
        # and don't want them to be reinstalled in an "isolated env".
        pip install --no-build-isolation --no-deps -e .[test,docs]

    # Run the test

    - name: Download an earlier database version
      run: |
        wget -q "${{ matrix.edgedb-url }}"
        tar xzf ${{ matrix.edgedb-basename }}-${{ matrix.edgedb-version }}.tar.gz

    - name: Make sure a CLI named "edgedb" exists (sigh)
      run: |
        ln -s gel $(dirname $(which gel))/edgedb

    - name: Test inplace upgrades from previous major version
      run: |
        ./tests/inplace-testing/test-old.sh vt ${{ matrix.edgedb-basename }}-${{ matrix.edgedb-version }}

  workflow-notifications:
    if: failure() && github.event_name != 'pull_request'
    name: Notify in Slack on failures
    needs:
      - build
      - test-inplace
    runs-on: ubuntu-latest
    permissions:
      actions: 'read'
    steps:
      - name: Slack Workflow Notification
        uses: Gamesight/slack-workflow-status@26a36836c887f260477432e4314ec3490a84f309
        with:
          repo_token: ${{secrets.GITHUB_TOKEN}}
          slack_webhook_url: ${{secrets.ACTIONS_SLACK_WEBHOOK_URL}}
          name: 'Workflow notifications'
          icon_emoji: ':hammer:'
          include_jobs: 'on-failure'

================================================
FILE: .github/workflows/tests.managed-pg.yml
================================================
name: Tests on Managed PostgreSQL

on:
  schedule:
    - cron: "0 3 * * 6"
  workflow_dispatch:
    inputs: {}
  push:
    branches:
      - cloud-test

jobs:
  build:
    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v4
      with:
        fetch-depth: 0
        submodules: false

    - uses: actions/checkout@v4
      with:
        fetch-depth: 50
        submodules: true

    - name: Set up Python
      uses: actions/setup-python@v5
      id: setup-python
      with:
        python-version: '3.12.2'
        cache: 'pip'
        cache-dependency-path: |
          pyproject.toml

    # The below is technically a lie as we are technically not
    # inside a virtual env, but there is really no reason to bother
    # actually creating and activating one as below works just fine.
- name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Cached requirements.txt uses: actions/cache@v4 id: requirements-cache with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Compute requirements.txt if: steps.requirements-cache.outputs.cache-hit != 'true' run: | python -m pip install pip-tools pip-compile --no-strip-extras --all-build-deps \ --extra test,language-server \ --output-file requirements.txt pyproject.toml - name: Install Python dependencies run: | python -c "import sys; print(sys.prefix)" python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 - name: Compute cache keys run: | mkdir -p shared-artifacts if [ "$(uname)" = "Darwin" ]; then find /usr/lib -type f -name 'lib*' -exec stat -f '%N %z' {} + | sort | shasum -a 256 | cut -d ' ' -f1 > shared-artifacts/lib_cache_key.txt else find /usr/lib -type f -name 'lib*' -printf '%P %s\n' | sort | sha256sum | cut -d ' ' -f1 > shared-artifacts/lib_cache_key.txt fi python setup.py -q ci_helper --type rust >shared-artifacts/rust_cache_key.txt python setup.py -q ci_helper --type ext >shared-artifacts/ext_cache_key.txt python setup.py -q ci_helper --type parsers >shared-artifacts/parsers_cache_key.txt python setup.py -q ci_helper --type postgres >shared-artifacts/postgres_git_rev.txt python setup.py -q ci_helper --type libpg_query >shared-artifacts/libpg_query_git_rev.txt echo 'f8cd94309eaccbfba5dea7835b88c78377608a37' >shared-artifacts/stolon_git_rev.txt python setup.py -q ci_helper --type bootstrap >shared-artifacts/bootstrap_cache_key.txt echo POSTGRES_GIT_REV=$(cat 
shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo LIBPG_QUERY_GIT_REV=$(cat shared-artifacts/libpg_query_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV - name: Upload shared artifacts uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: shared-artifacts path: shared-artifacts retention-days: 1 # Restore binary cache - name: Handle cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} restore-keys: | edb-rust-v4- - name: Handle cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Handle cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Handle cached libpg_query build uses: actions/cache@v4 id: libpg-query-cache with: path: edb/pgsql/parser/libpg_query/libpg_query.a key: edb-libpg_query-v1-${{ env.LIBPG_QUERY_GIT_REV }} # Install system dependencies for building - name: Install system deps if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: steps.rust-cache.outputs.cache-hit != 'true' uses: 
dsherret/rust-toolchain-file@v1 # Build Rust extensions - name: Handle Rust extensions build cache uses: actions/cache@v4 if: steps.rust-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/rust/extensions key: edb-rust-build-v1-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} restore-keys: | edb-rust-build-v1- - name: Build Rust extensions env: CARGO_HOME: ${{ env.BUILD_TEMP }}/rust/extensions/cargo_home CACHE_HIT: ${{ steps.rust-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p build/rust_extensions rsync -av ./build/rust_extensions/ ${BUILD_LIB}/ python setup.py -v build_rust rsync -av ${BUILD_LIB}/ build/rust_extensions/ rm -rf ${BUILD_LIB} fi rsync -av ./build/rust_extensions/edb/ ./edb/ # Build libpg_query - name: Build libpg_query if: | steps.libpg-query-cache.outputs.cache-hit != 'true' && steps.ext-cache.outputs.cache-hit != 'true' run: | python setup.py build_libpg_query # Build extensions - name: Handle Cython extensions build cache uses: actions/cache@v4 if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: CACHE_HIT: ${{ steps.ext-cache.outputs.cache-hit }} BUILD_EXT_MODE: py-only run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p ./build/extensions rsync -av ./build/extensions/ ${BUILD_LIB}/ BUILD_EXT_MODE=py-only python setup.py -v build_ext rsync -av ${BUILD_LIB}/ ./build/extensions/ rm -rf ${BUILD_LIB} fi rsync -av ./build/extensions/edb/ ./edb/ # Build parsers - name: Handle compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} restore-keys: | edb-parsers-v3- - name: Build parsers env: CACHE_HIT: ${{ steps.parsers-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf 
${BUILD_LIB} mkdir -p ./build/lib rsync -av ./build/lib/ ${BUILD_LIB}/ python setup.py -v build_parsers rsync -av ${BUILD_LIB}/ ./build/lib/ rm -rf ${BUILD_LIB} fi rsync -av ./build/lib/edb/ ./edb/ # Build PostgreSQL - name: Build PostgreSQL env: CACHE_HIT: ${{ steps.postgres-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" == "true" ]]; then cp build/postgres/install/stamp build/postgres/ else python setup.py build_postgres cp build/postgres/stamp build/postgres/install/ fi # Build Stolon - name: Set up Go if: steps.stolon-cache.outputs.cache-hit != 'true' uses: actions/setup-go@v2 with: go-version: 1.16 - uses: actions/checkout@v4 if: steps.stolon-cache.outputs.cache-hit != 'true' with: repository: edgedb/stolon path: build/stolon ref: ${{ env.STOLON_GIT_REV }} fetch-depth: 0 submodules: false - name: Build Stolon if: steps.stolon-cache.outputs.cache-hit != 'true' run: | mkdir -p build/stolon/bin/ curl -fsSL https://releases.hashicorp.com/consul/1.10.1/consul_1.10.1_linux_amd64.zip | zcat > build/stolon/bin/consul chmod +x build/stolon/bin/consul cd build/stolon && make # Install edgedb-server and populate egg-info - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and don't want them to be reinstalled in an "isolated env". 
pip install --no-build-isolation --no-deps -e .[test,docs] # Refresh the bootstrap cache - name: Handle bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} restore-keys: | edb-bootstrap-v2- - name: Bootstrap EdgeDB Server if: steps.bootstrap-cache.outputs.cache-hit != 'true' run: | edb server --bootstrap-only setup-aws-rds: runs-on: ubuntu-latest outputs: pghost: ${{ steps.pghost.outputs.stdout }} defaults: run: working-directory: .github/aws-rds steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - name: Setup Terraform uses: hashicorp/setup-terraform@633666f66e0061ca3b725c73b2ec20cd13a8fdd1 # v2.0.3 - name: Initialize Terraform run: terraform init - name: Configure AWS Credentials uses: aws-actions/configure-aws-credentials@v1 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} aws-region: us-east-2 - name: Setup AWS RDS env: TF_VAR_sg_id: ${{ secrets.AWS_SECURITY_GROUP }} TF_VAR_password: ${{ secrets.AWS_RDS_PASSWORD }} run: | terraform apply -auto-approve - name: Store Terraform state if: ${{ always() }} uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: aws-rds-tfstate path: .github/aws-rds/terraform.tfstate retention-days: 1 - name: Get RDS host id: pghost run: | terraform output -raw db_instance_address test-aws-rds: runs-on: ubuntu-latest needs: [setup-aws-rds, build] steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - uses: actions/checkout@v4 with: fetch-depth: 50 submodules: true - name: Set up Python uses: actions/setup-python@v5 id: setup-python with: python-version: '3.12.2' cache: 'pip' cache-dependency-path: | pyproject.toml # The below is technically a lie as we are technically not # inside a virtual env, but there is really no reason to bother # actually creating and activating one as 
below works just fine. - name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Download requirements.txt uses: actions/cache@v4 with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Install Python dependencies run: | python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 # Restore the artifacts and environment variables - name: Download shared artifacts uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: shared-artifacts path: shared-artifacts - name: Set environment variables run: | echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV # Restore build cache - name: Restore cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} - name: Restore cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} - name: Restore cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ 
env.POSTGRES_GIT_REV }}-${{ hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Restore cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Restore bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} - name: Stop if we cannot retrieve the cache if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.parsers-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.bootstrap-cache.outputs.cache-hit != 'true' run: | echo ::error::Cannot retrieve build cache. exit 1 - name: Validate cached binaries run: | # Validate Stolon ./build/stolon/bin/stolon-sentinel --version || exit 1 ./build/stolon/bin/stolon-keeper --version || exit 1 ./build/stolon/bin/stolon-proxy --version || exit 1 # Validate PostgreSQL ./build/postgres/install/bin/postgres --version || exit 1 ./build/postgres/install/bin/pg_config --version || exit 1 - name: Restore cache into the source tree run: | rsync -av ./build/rust_extensions/edb/ ./edb/ rsync -av ./build/extensions/edb/ ./edb/ rsync -av ./build/lib/edb/ ./edb/ cp build/postgres/install/stamp build/postgres/ - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and don't want them to be reinstalled in an "isolated env". 
pip install --no-build-isolation --no-deps -e .[test,docs] # Run the test - name: Test env: EDGEDB_TEST_BACKEND_DSN: postgres://edbtest:${{ secrets.AWS_RDS_PASSWORD }}@${{ needs.setup-aws-rds.outputs.pghost }}/postgres run: | edb server --bootstrap-only --backend-dsn=$EDGEDB_TEST_BACKEND_DSN --testmode edb test -j2 -v --backend-dsn=$EDGEDB_TEST_BACKEND_DSN teardown-aws-rds: runs-on: ubuntu-latest needs: test-aws-rds if: ${{ always() }} defaults: run: working-directory: .github/aws-rds steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - name: Setup Terraform uses: hashicorp/setup-terraform@633666f66e0061ca3b725c73b2ec20cd13a8fdd1 # v2.0.3 - name: Initialize Terraform run: terraform init - name: Configure AWS Credentials uses: aws-actions/configure-aws-credentials@v1 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} aws-region: us-east-2 - name: Restore Terraform state uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: aws-rds-tfstate path: .github/aws-rds - name: Destroy AWS RDS run: terraform destroy -auto-approve env: TF_VAR_sg_id: ${{ secrets.AWS_SECURITY_GROUP }} TF_VAR_password: ${{ secrets.AWS_RDS_PASSWORD }} - name: Overwrite Terraform state uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: aws-rds-tfstate path: .github/aws-rds/terraform.tfstate retention-days: 1 setup-do-database: runs-on: ubuntu-latest defaults: run: working-directory: .github/do-database steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - name: Setup Terraform uses: hashicorp/setup-terraform@633666f66e0061ca3b725c73b2ec20cd13a8fdd1 # v2.0.3 - name: Initialize Terraform run: terraform init - name: Setup DigitalOcean Database env: TF_VAR_do_token: ${{ secrets.DIGITALOCEAN_TOKEN }} run: | terraform apply -auto-approve - name: Store Terraform state if: ${{ always() }} uses: 
actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: do-database-tfstate path: .github/do-database/terraform.tfstate retention-days: 1 test-do-database: runs-on: ubuntu-latest needs: [setup-do-database, build] steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - uses: actions/checkout@v4 with: fetch-depth: 50 submodules: true - name: Set up Python uses: actions/setup-python@v5 id: setup-python with: python-version: '3.12.2' cache: 'pip' cache-dependency-path: | pyproject.toml # The below is technically a lie as we are technically not # inside a virtual env, but there is really no reason to bother # actually creating and activating one as below works just fine. - name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Download requirements.txt uses: actions/cache@v4 with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Install Python dependencies run: | python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 # Restore the artifacts and environment variables - name: Download shared artifacts uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: shared-artifacts path: shared-artifacts - name: Set environment variables run: | echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV # Restore build 
cache - name: Restore cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} - name: Restore cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} - name: Restore cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Restore cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Restore bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} - name: Stop if we cannot retrieve the cache if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.parsers-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.bootstrap-cache.outputs.cache-hit != 'true' run: | echo ::error::Cannot retrieve build cache. 
exit 1 - name: Validate cached binaries run: | # Validate Stolon ./build/stolon/bin/stolon-sentinel --version || exit 1 ./build/stolon/bin/stolon-keeper --version || exit 1 ./build/stolon/bin/stolon-proxy --version || exit 1 # Validate PostgreSQL ./build/postgres/install/bin/postgres --version || exit 1 ./build/postgres/install/bin/pg_config --version || exit 1 - name: Restore cache into the source tree run: | rsync -av ./build/rust_extensions/edb/ ./edb/ rsync -av ./build/extensions/edb/ ./edb/ rsync -av ./build/lib/edb/ ./edb/ cp build/postgres/install/stamp build/postgres/ - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and don't want them to be reinstalled in an "isolated env". pip install --no-build-isolation --no-deps -e .[test,docs] - name: Setup Terraform uses: hashicorp/setup-terraform@v1 - name: Initialize Terraform working-directory: .github/do-database run: terraform init - name: Restore Terraform state uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: do-database-tfstate path: .github/do-database - name: Get Database host id: pghost working-directory: .github/do-database run: | terraform output -raw db_instance_address - name: Get Database port id: pgport working-directory: .github/do-database run: | terraform output -raw db_instance_port - name: Get Database user id: pguser working-directory: .github/do-database run: | terraform output -raw db_instance_user - name: Get Database password id: pgpass working-directory: .github/do-database run: | terraform output -raw db_instance_password - name: Get Database dbname id: pgdatabase working-directory: .github/do-database run: | terraform output -raw db_instance_database # Run the test - name: Test env: EDGEDB_TEST_BACKEND_DSN: postgres://${{ steps.pguser.outputs.stdout }}:${{ steps.pgpass.outputs.stdout }}@${{ steps.pghost.outputs.stdout }}:${{ steps.pgport.outputs.stdout }}/${{ 
steps.pgdatabase.outputs.stdout }} run: | edb server --bootstrap-only --backend-dsn=$EDGEDB_TEST_BACKEND_DSN --testmode edb test -j2 -v --backend-dsn=$EDGEDB_TEST_BACKEND_DSN teardown-do-database: runs-on: ubuntu-latest needs: test-do-database if: ${{ always() }} defaults: run: working-directory: .github/do-database steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - name: Setup Terraform uses: hashicorp/setup-terraform@633666f66e0061ca3b725c73b2ec20cd13a8fdd1 # v2.0.3 - name: Initialize Terraform run: terraform init - name: Restore Terraform state uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: do-database-tfstate path: .github/do-database - name: Destroy DigitalOcean Database run: terraform destroy -auto-approve env: TF_VAR_do_token: ${{ secrets.DIGITALOCEAN_TOKEN }} - name: Overwrite Terraform state uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: do-database-tfstate path: .github/do-database/terraform.tfstate retention-days: 1 setup-gcp-cloud-sql: runs-on: ubuntu-latest outputs: pghost: ${{ steps.pghost.outputs.stdout }} defaults: run: working-directory: .github/gcp-cloud-sql steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - name: Setup Terraform uses: hashicorp/setup-terraform@633666f66e0061ca3b725c73b2ec20cd13a8fdd1 # v2.0.3 - name: Initialize Terraform run: terraform init - name: Configure GCP Credentials uses: google-github-actions/setup-gcloud@main with: service_account_key: ${{ secrets.GCP_SA_KEY }} export_default_credentials: true - name: Setup GCP Cloud SQL env: TF_VAR_password: ${{ secrets.AWS_RDS_PASSWORD }} run: | terraform apply -auto-approve - name: Store Terraform state if: ${{ always() }} uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: gcp-cloud-sql-tfstate path: .github/gcp-cloud-sql/terraform.tfstate retention-days: 1 - name: Get Cloud SQL host id: pghost run: | 
terraform output -raw db_instance_address test-gcp-cloud-sql: runs-on: ubuntu-latest needs: [setup-gcp-cloud-sql, build] steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - uses: actions/checkout@v4 with: fetch-depth: 50 submodules: true - name: Set up Python uses: actions/setup-python@v5 id: setup-python with: python-version: '3.12.2' cache: 'pip' cache-dependency-path: | pyproject.toml # The below is technically a lie as we are technically not # inside a virtual env, but there is really no reason to bother # actually creating and activating one as below works just fine. - name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Download requirements.txt uses: actions/cache@v4 with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Install Python dependencies run: | python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 # Restore the artifacts and environment variables - name: Download shared artifacts uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: shared-artifacts path: shared-artifacts - name: Set environment variables run: | echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV # Restore build cache - name: Restore cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: 
edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} - name: Restore cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} - name: Restore cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Restore cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Restore bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} - name: Stop if we cannot retrieve the cache if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.parsers-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.bootstrap-cache.outputs.cache-hit != 'true' run: | echo ::error::Cannot retrieve build cache. 
exit 1 - name: Validate cached binaries run: | # Validate Stolon ./build/stolon/bin/stolon-sentinel --version || exit 1 ./build/stolon/bin/stolon-keeper --version || exit 1 ./build/stolon/bin/stolon-proxy --version || exit 1 # Validate PostgreSQL ./build/postgres/install/bin/postgres --version || exit 1 ./build/postgres/install/bin/pg_config --version || exit 1 - name: Restore cache into the source tree run: | rsync -av ./build/rust_extensions/edb/ ./edb/ rsync -av ./build/extensions/edb/ ./edb/ rsync -av ./build/lib/edb/ ./edb/ cp build/postgres/install/stamp build/postgres/ - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and don't want them to be reinstalled in an "isolated env". pip install --no-build-isolation --no-deps -e .[test,docs] # Run the test - name: Test env: EDGEDB_TEST_BACKEND_DSN: postgres://postgres:${{ secrets.AWS_RDS_PASSWORD }}@${{ needs.setup-gcp-cloud-sql.outputs.pghost }}/postgres run: | edb server --bootstrap-only --backend-dsn=$EDGEDB_TEST_BACKEND_DSN --testmode edb test -j2 -v --backend-dsn=$EDGEDB_TEST_BACKEND_DSN teardown-gcp-cloud-sql: runs-on: ubuntu-latest needs: test-gcp-cloud-sql if: ${{ always() }} defaults: run: working-directory: .github/gcp-cloud-sql steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - name: Setup Terraform uses: hashicorp/setup-terraform@633666f66e0061ca3b725c73b2ec20cd13a8fdd1 # v2.0.3 - name: Initialize Terraform run: terraform init - name: Configure GCP Credentials uses: google-github-actions/setup-gcloud@main with: service_account_key: ${{ secrets.GCP_SA_KEY }} export_default_credentials: true - name: Restore Terraform state uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: gcp-cloud-sql-tfstate path: .github/gcp-cloud-sql - name: Destroy GCP Cloud SQL run: terraform destroy -auto-approve env: TF_VAR_password: ${{ secrets.AWS_RDS_PASSWORD }} - name: 
Overwrite Terraform state uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: gcp-cloud-sql-tfstate path: .github/gcp-cloud-sql/terraform.tfstate retention-days: 1 setup-aws-aurora: runs-on: ubuntu-latest outputs: pghost: ${{ steps.pghost.outputs.stdout }} defaults: run: working-directory: .github/aws-aurora steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - name: Setup Terraform uses: hashicorp/setup-terraform@633666f66e0061ca3b725c73b2ec20cd13a8fdd1 # v2.0.3 - name: Initialize Terraform run: terraform init - name: Configure AWS Credentials uses: aws-actions/configure-aws-credentials@v1 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} aws-region: us-east-2 - name: Setup AWS RDS Aurora env: TF_VAR_sg_id: ${{ secrets.AWS_SECURITY_GROUP }} TF_VAR_password: ${{ secrets.AWS_RDS_PASSWORD }} TF_VAR_vpc_id: ${{ secrets.AWS_VPC_ID }} run: | terraform apply -auto-approve - name: Store Terraform state if: ${{ always() }} uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: aws-aurora-tfstate path: .github/aws-aurora/terraform.tfstate retention-days: 1 - name: Get RDS Aurora host id: pghost run: | terraform output -raw rds_cluster_endpoint test-aws-aurora: runs-on: ubuntu-latest needs: [setup-aws-aurora, build] steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - uses: actions/checkout@v4 with: fetch-depth: 50 submodules: true - name: Set up Python uses: actions/setup-python@v5 id: setup-python with: python-version: '3.12.2' cache: 'pip' cache-dependency-path: | pyproject.toml # The below is technically a lie as we are technically not # inside a virtual env, but there is really no reason to bother # actually creating and activating one as below works just fine. 
- name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Download requirements.txt uses: actions/cache@v4 with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Install Python dependencies run: | python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 # Restore the artifacts and environment variables - name: Download shared artifacts uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: shared-artifacts path: shared-artifacts - name: Set environment variables run: | echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV # Restore build cache - name: Restore cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} - name: Restore cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} - name: Restore cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ 
hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Restore cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Restore bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} - name: Stop if we cannot retrieve the cache if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.parsers-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.bootstrap-cache.outputs.cache-hit != 'true' run: | echo ::error::Cannot retrieve build cache. exit 1 - name: Validate cached binaries run: | # Validate Stolon ./build/stolon/bin/stolon-sentinel --version || exit 1 ./build/stolon/bin/stolon-keeper --version || exit 1 ./build/stolon/bin/stolon-proxy --version || exit 1 # Validate PostgreSQL ./build/postgres/install/bin/postgres --version || exit 1 ./build/postgres/install/bin/pg_config --version || exit 1 - name: Restore cache into the source tree run: | rsync -av ./build/rust_extensions/edb/ ./edb/ rsync -av ./build/extensions/edb/ ./edb/ rsync -av ./build/lib/edb/ ./edb/ cp build/postgres/install/stamp build/postgres/ - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and don't want them to be reinstalled in an "isolated env". 
pip install --no-build-isolation --no-deps -e .[test,docs] # Run the test - name: Test env: EDGEDB_TEST_BACKEND_DSN: postgres://edbtest:${{ secrets.AWS_RDS_PASSWORD }}@${{ needs.setup-aws-aurora.outputs.pghost }}/postgres run: | edb server --bootstrap-only --backend-dsn=$EDGEDB_TEST_BACKEND_DSN --testmode edb test -j1 -v --backend-dsn=$EDGEDB_TEST_BACKEND_DSN teardown-aws-aurora: runs-on: ubuntu-latest needs: test-aws-aurora if: ${{ always() }} defaults: run: working-directory: .github/aws-aurora steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - name: Setup Terraform uses: hashicorp/setup-terraform@633666f66e0061ca3b725c73b2ec20cd13a8fdd1 # v2.0.3 - name: Initialize Terraform run: terraform init - name: Configure AWS Credentials uses: aws-actions/configure-aws-credentials@v1 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} aws-region: us-east-2 - name: Restore Terraform state uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: aws-aurora-tfstate path: .github/aws-aurora - name: Destroy AWS RDS Aurora run: terraform destroy -auto-approve env: TF_VAR_sg_id: ${{ secrets.AWS_SECURITY_GROUP }} TF_VAR_password: ${{ secrets.AWS_RDS_PASSWORD }} TF_VAR_vpc_id: ${{ secrets.AWS_VPC_ID }} - name: Overwrite Terraform state uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: aws-aurora-tfstate path: .github/aws-aurora/terraform.tfstate retention-days: 1 setup-heroku-postgres: runs-on: ubuntu-latest defaults: run: working-directory: .github/heroku-postgres steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - name: Setup Terraform uses: hashicorp/setup-terraform@633666f66e0061ca3b725c73b2ec20cd13a8fdd1 # v2.0.3 - name: Initialize Terraform run: terraform init - name: Setup Heroku Postgres env: HEROKU_API_KEY: ${{ secrets.HEROKU_API_KEY }} HEROKU_EMAIL: ${{ secrets.HEROKU_EMAIL }} run: 
| terraform apply -auto-approve - name: Store Terraform state if: ${{ always() }} uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: heroku-postgres-tfstate path: .github/heroku-postgres/terraform.tfstate retention-days: 1 test-heroku-postgres: runs-on: ubuntu-latest needs: [setup-heroku-postgres, build] steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - uses: actions/checkout@v4 with: fetch-depth: 50 submodules: true - name: Set up Python uses: actions/setup-python@v5 id: setup-python with: python-version: '3.12.2' cache: 'pip' cache-dependency-path: | pyproject.toml # The below is technically a lie as we are technically not # inside a virtual env, but there is really no reason to bother # actually creating and activating one as below works just fine. - name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Download requirements.txt uses: actions/cache@v4 with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Install Python dependencies run: | python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 # Restore the artifacts and environment variables - name: Download shared artifacts uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: shared-artifacts path: shared-artifacts - name: Set environment variables run: | echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> 
$GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV # Restore build cache - name: Restore cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} - name: Restore cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} - name: Restore cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Restore cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Restore bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} - name: Stop if we cannot retrieve the cache if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.parsers-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.bootstrap-cache.outputs.cache-hit != 'true' run: | echo ::error::Cannot retrieve build cache. 
exit 1 - name: Validate cached binaries run: | # Validate Stolon ./build/stolon/bin/stolon-sentinel --version || exit 1 ./build/stolon/bin/stolon-keeper --version || exit 1 ./build/stolon/bin/stolon-proxy --version || exit 1 # Validate PostgreSQL ./build/postgres/install/bin/postgres --version || exit 1 ./build/postgres/install/bin/pg_config --version || exit 1 - name: Restore cache into the source tree run: | rsync -av ./build/rust_extensions/edb/ ./edb/ rsync -av ./build/extensions/edb/ ./edb/ rsync -av ./build/lib/edb/ ./edb/ cp build/postgres/install/stamp build/postgres/ - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and don't want them to be reinstalled in an "isolated env". pip install --no-build-isolation --no-deps -e .[test,docs] - name: Setup Terraform uses: hashicorp/setup-terraform@v1 - name: Initialize Terraform working-directory: .github/heroku-postgres run: terraform init - name: Restore Terraform state uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: heroku-postgres-tfstate path: .github/heroku-postgres - name: Get Heroku Postgres DSN id: pgdsn working-directory: .github/heroku-postgres run: | terraform output -raw heroku_postgres_dsn # Run the test - name: Test env: EDGEDB_TEST_BACKEND_VENDOR: heroku-postgres EDGEDB_TEST_BACKEND_DSN: ${{ steps.pgdsn.outputs.stdout }} run: | edb server --bootstrap-only --backend-dsn=$EDGEDB_TEST_BACKEND_DSN --testmode edb test -j1 -v --backend-dsn=$EDGEDB_TEST_BACKEND_DSN teardown-heroku-postgres: runs-on: ubuntu-latest needs: test-heroku-postgres if: ${{ always() }} defaults: run: working-directory: .github/heroku-postgres steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - name: Setup Terraform uses: hashicorp/setup-terraform@633666f66e0061ca3b725c73b2ec20cd13a8fdd1 # v2.0.3 - name: Initialize Terraform run: terraform init - name: Configure AWS Credentials 
uses: aws-actions/configure-aws-credentials@v1 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} aws-region: us-east-2 - name: Restore Terraform state uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: heroku-postgres-tfstate path: .github/heroku-postgres - name: Destroy Heroku Postgres run: terraform destroy -auto-approve env: HEROKU_API_KEY: ${{ secrets.HEROKU_API_KEY }} HEROKU_EMAIL: ${{ secrets.HEROKU_EMAIL }} - name: Overwrite Terraform state uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: heroku-postgres-tfstate path: .github/heroku-postgres/terraform.tfstate retention-days: 1 workflow-notifications: if: failure() && github.event_name != 'pull_request' name: Notify in Slack on failures needs: - setup-aws-rds - test-aws-rds - teardown-aws-rds - setup-do-database - test-do-database - teardown-do-database - setup-gcp-cloud-sql - test-gcp-cloud-sql - teardown-gcp-cloud-sql - setup-aws-aurora - test-aws-aurora - teardown-aws-aurora - setup-heroku-postgres - test-heroku-postgres - teardown-heroku-postgres runs-on: ubuntu-latest permissions: actions: 'read' steps: - name: Slack Workflow Notification uses: Gamesight/slack-workflow-status@26a36836c887f260477432e4314ec3490a84f309 with: repo_token: ${{secrets.GITHUB_TOKEN}} slack_webhook_url: ${{secrets.ACTIONS_SLACK_WEBHOOK_URL}} name: 'Workflow notifications' icon_emoji: ':hammer:' include_jobs: 'on-failure' ================================================ FILE: .github/workflows/tests.patches.yml ================================================ name: Tests of patching old EdgeDB Versions on: workflow_dispatch: inputs: {} pull_request: branches: - release/* push: branches: - patch-test* - release/* jobs: build: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - uses: actions/checkout@v4 with: fetch-depth: 50 
submodules: true - name: Set up Python uses: actions/setup-python@v5 id: setup-python with: python-version: '3.12.2' cache: 'pip' cache-dependency-path: | pyproject.toml # The below is technically a lie as we are technically not # inside a virtual env, but there is really no reason to bother # actually creating and activating one as below works just fine. - name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Cached requirements.txt uses: actions/cache@v4 id: requirements-cache with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Compute requirements.txt if: steps.requirements-cache.outputs.cache-hit != 'true' run: | python -m pip install pip-tools pip-compile --no-strip-extras --all-build-deps \ --extra test,language-server \ --output-file requirements.txt pyproject.toml - name: Install Python dependencies run: | python -c "import sys; print(sys.prefix)" python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 - name: Compute cache keys run: | mkdir -p shared-artifacts if [ "$(uname)" = "Darwin" ]; then find /usr/lib -type f -name 'lib*' -exec stat -f '%N %z' {} + | sort | shasum -a 256 | cut -d ' ' -f1 > shared-artifacts/lib_cache_key.txt else find /usr/lib -type f -name 'lib*' -printf '%P %s\n' | sort | sha256sum | cut -d ' ' -f1 > shared-artifacts/lib_cache_key.txt fi python setup.py -q ci_helper --type rust >shared-artifacts/rust_cache_key.txt python setup.py -q ci_helper --type ext >shared-artifacts/ext_cache_key.txt python setup.py -q ci_helper --type parsers >shared-artifacts/parsers_cache_key.txt python setup.py -q ci_helper --type 
postgres >shared-artifacts/postgres_git_rev.txt python setup.py -q ci_helper --type libpg_query >shared-artifacts/libpg_query_git_rev.txt echo 'f8cd94309eaccbfba5dea7835b88c78377608a37' >shared-artifacts/stolon_git_rev.txt python setup.py -q ci_helper --type bootstrap >shared-artifacts/bootstrap_cache_key.txt echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo LIBPG_QUERY_GIT_REV=$(cat shared-artifacts/libpg_query_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV - name: Upload shared artifacts uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: shared-artifacts path: shared-artifacts retention-days: 1 # Restore binary cache - name: Handle cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} restore-keys: | edb-rust-v4- - name: Handle cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Handle cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Handle cached libpg_query build uses: actions/cache@v4 id: libpg-query-cache with: path: edb/pgsql/parser/libpg_query/libpg_query.a key: edb-libpg_query-v1-${{ env.LIBPG_QUERY_GIT_REV }} # Install system dependencies for building - name: Install system deps if: | steps.rust-cache.outputs.cache-hit != 'true' 
|| steps.ext-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: steps.rust-cache.outputs.cache-hit != 'true' uses: dsherret/rust-toolchain-file@v1 # Build Rust extensions - name: Handle Rust extensions build cache uses: actions/cache@v4 if: steps.rust-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/rust/extensions key: edb-rust-build-v1-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} restore-keys: | edb-rust-build-v1- - name: Build Rust extensions env: CARGO_HOME: ${{ env.BUILD_TEMP }}/rust/extensions/cargo_home CACHE_HIT: ${{ steps.rust-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p build/rust_extensions rsync -av ./build/rust_extensions/ ${BUILD_LIB}/ python setup.py -v build_rust rsync -av ${BUILD_LIB}/ build/rust_extensions/ rm -rf ${BUILD_LIB} fi rsync -av ./build/rust_extensions/edb/ ./edb/ # Build libpg_query - name: Build libpg_query if: | steps.libpg-query-cache.outputs.cache-hit != 'true' && steps.ext-cache.outputs.cache-hit != 'true' run: | python setup.py build_libpg_query # Build extensions - name: Handle Cython extensions build cache uses: actions/cache@v4 if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: CACHE_HIT: ${{ steps.ext-cache.outputs.cache-hit }} BUILD_EXT_MODE: py-only run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p ./build/extensions rsync -av ./build/extensions/ ${BUILD_LIB}/ BUILD_EXT_MODE=py-only python setup.py -v build_ext rsync -av ${BUILD_LIB}/ ./build/extensions/ rm -rf ${BUILD_LIB} fi rsync -av ./build/extensions/edb/ ./edb/ # Build parsers - name: Handle 
compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} restore-keys: | edb-parsers-v3- - name: Build parsers env: CACHE_HIT: ${{ steps.parsers-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p ./build/lib rsync -av ./build/lib/ ${BUILD_LIB}/ python setup.py -v build_parsers rsync -av ${BUILD_LIB}/ ./build/lib/ rm -rf ${BUILD_LIB} fi rsync -av ./build/lib/edb/ ./edb/ # Build PostgreSQL - name: Build PostgreSQL env: CACHE_HIT: ${{ steps.postgres-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" == "true" ]]; then cp build/postgres/install/stamp build/postgres/ else python setup.py build_postgres cp build/postgres/stamp build/postgres/install/ fi # Build Stolon - name: Set up Go if: steps.stolon-cache.outputs.cache-hit != 'true' uses: actions/setup-go@v2 with: go-version: 1.16 - uses: actions/checkout@v4 if: steps.stolon-cache.outputs.cache-hit != 'true' with: repository: edgedb/stolon path: build/stolon ref: ${{ env.STOLON_GIT_REV }} fetch-depth: 0 submodules: false - name: Build Stolon if: steps.stolon-cache.outputs.cache-hit != 'true' run: | mkdir -p build/stolon/bin/ curl -fsSL https://releases.hashicorp.com/consul/1.10.1/consul_1.10.1_linux_amd64.zip | zcat > build/stolon/bin/consul chmod +x build/stolon/bin/consul cd build/stolon && make # Install edgedb-server and populate egg-info - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and don't want them to be reinstalled in an "isolated env". 
pip install --no-build-isolation --no-deps -e .[test,docs] # Refresh the bootstrap cache - name: Handle bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} restore-keys: | edb-bootstrap-v2- - name: Bootstrap EdgeDB Server if: steps.bootstrap-cache.outputs.cache-hit != 'true' run: | edb server --bootstrap-only compute-versions: runs-on: ubuntu-latest outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} steps: - uses: actions/checkout@v4 - id: set-matrix name: Compute versions to run on run: python3 .github/scripts/patches/compute-versions.py test: runs-on: ubuntu-latest needs: [build, compute-versions] strategy: fail-fast: false matrix: ${{fromJSON(needs.compute-versions.outputs.matrix)}} steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - uses: actions/checkout@v4 with: fetch-depth: 50 submodules: true - name: Set up Python uses: actions/setup-python@v5 id: setup-python with: python-version: '3.12.2' cache: 'pip' cache-dependency-path: | pyproject.toml # The below is technically a lie as we are technically not # inside a virtual env, but there is really no reason to bother # actually creating and activating one as below works just fine. 
- name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Download requirements.txt uses: actions/cache@v4 with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Install Python dependencies run: | python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 # Restore the artifacts and environment variables - name: Download shared artifacts uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: shared-artifacts path: shared-artifacts - name: Set environment variables run: | echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV # Restore build cache - name: Restore cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} - name: Restore cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} - name: Restore cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ 
hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Restore cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Restore bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} - name: Stop if we cannot retrieve the cache if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.parsers-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.bootstrap-cache.outputs.cache-hit != 'true' run: | echo ::error::Cannot retrieve build cache. exit 1 - name: Validate cached binaries run: | # Validate Stolon ./build/stolon/bin/stolon-sentinel --version || exit 1 ./build/stolon/bin/stolon-keeper --version || exit 1 ./build/stolon/bin/stolon-proxy --version || exit 1 # Validate PostgreSQL ./build/postgres/install/bin/postgres --version || exit 1 ./build/postgres/install/bin/pg_config --version || exit 1 - name: Restore cache into the source tree run: | rsync -av ./build/rust_extensions/edb/ ./edb/ rsync -av ./build/extensions/edb/ ./edb/ rsync -av ./build/lib/edb/ ./edb/ cp build/postgres/install/stamp build/postgres/ - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and don't want them to be reinstalled in an "isolated env". 
pip install --no-build-isolation --no-deps -e .[test,docs] # Run the test - name: Download an earlier database version and set up an instance run: | wget -q "${{ matrix.edgedb-url }}" tar xzf ${{ matrix.edgedb-basename }}-${{ matrix.edgedb-version }}.tar.gz ${{ matrix.edgedb-basename }}-${{ matrix.edgedb-version }}/bin/edgedb-server -D test-dir --bootstrap-only --testmode - name: Create databases on the older version if: ${{ matrix.make-dbs }} run: python3 .github/scripts/patches/create-databases.py ${{ matrix.edgedb-basename }}-${{ matrix.edgedb-version }}/bin/edgedb-server - name: Run tests with instance created on an older version run: | # Run the server explicitly first to do the upgrade, since edb test # has timeouts. edb server --bootstrap-only --data-dir test-dir # Should we run *all* the tests? edb test -j2 -v --data-dir test-dir tests/test_edgeql_json.py tests/test_edgeql_casts.py tests/test_edgeql_functions.py tests/test_edgeql_expressions.py tests/test_edgeql_policies.py tests/test_edgeql_vector.py tests/test_edgeql_scope.py tests/test_http_ext_auth.py - name: Test downgrading a database after an upgrade if: ${{ !contains(matrix.edgedb-version, '-rc') && !contains(matrix.edgedb-version, '-beta') }} env: EDGEDB_VERSION: ${{ matrix.edgedb-version }} run: python3 .github/scripts/patches/test-downgrade.py workflow-notifications: if: failure() && github.event_name != 'pull_request' name: Notify in Slack on failures needs: - build - compute-versions - test runs-on: ubuntu-latest permissions: actions: 'read' steps: - name: Slack Workflow Notification uses: Gamesight/slack-workflow-status@26a36836c887f260477432e4314ec3490a84f309 with: repo_token: ${{secrets.GITHUB_TOKEN}} slack_webhook_url: ${{secrets.ACTIONS_SLACK_WEBHOOK_URL}} name: 'Workflow notifications' icon_emoji: ':hammer:' include_jobs: 'on-failure' ================================================ FILE: .github/workflows/tests.pg-versions.yml ================================================ name: Tests
on PostgreSQL Versions on: schedule: - cron: "0 3 * * *" workflow_dispatch: inputs: {} push: branches: - pg-test jobs: build: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - uses: actions/checkout@v4 with: fetch-depth: 50 submodules: true - name: Set up Python uses: actions/setup-python@v5 id: setup-python with: python-version: '3.12.2' cache: 'pip' cache-dependency-path: | pyproject.toml # The below is technically a lie as we are technically not # inside a virtual env, but there is really no reason to bother # actually creating and activating one as below works just fine. - name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Cached requirements.txt uses: actions/cache@v4 id: requirements-cache with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Compute requirements.txt if: steps.requirements-cache.outputs.cache-hit != 'true' run: | python -m pip install pip-tools pip-compile --no-strip-extras --all-build-deps \ --extra test,language-server \ --output-file requirements.txt pyproject.toml - name: Install Python dependencies run: | python -c "import sys; print(sys.prefix)" python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 - name: Compute cache keys run: | mkdir -p shared-artifacts if [ "$(uname)" = "Darwin" ]; then find /usr/lib -type f -name 'lib*' -exec stat -f '%N %z' {} + | sort | shasum -a 256 | cut -d ' ' -f1 > shared-artifacts/lib_cache_key.txt else find /usr/lib -type f -name 'lib*' -printf '%P %s\n' | sort | sha256sum | cut -d ' ' -f1 > shared-artifacts/lib_cache_key.txt fi python 
setup.py -q ci_helper --type rust >shared-artifacts/rust_cache_key.txt python setup.py -q ci_helper --type ext >shared-artifacts/ext_cache_key.txt python setup.py -q ci_helper --type parsers >shared-artifacts/parsers_cache_key.txt python setup.py -q ci_helper --type postgres >shared-artifacts/postgres_git_rev.txt python setup.py -q ci_helper --type libpg_query >shared-artifacts/libpg_query_git_rev.txt echo 'f8cd94309eaccbfba5dea7835b88c78377608a37' >shared-artifacts/stolon_git_rev.txt python setup.py -q ci_helper --type bootstrap >shared-artifacts/bootstrap_cache_key.txt echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo LIBPG_QUERY_GIT_REV=$(cat shared-artifacts/libpg_query_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV - name: Upload shared artifacts uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: shared-artifacts path: shared-artifacts retention-days: 1 # Restore binary cache - name: Handle cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} restore-keys: | edb-rust-v4- - name: Handle cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Handle cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Handle cached libpg_query build uses: 
actions/cache@v4 id: libpg-query-cache with: path: edb/pgsql/parser/libpg_query/libpg_query.a key: edb-libpg_query-v1-${{ env.LIBPG_QUERY_GIT_REV }} # Install system dependencies for building - name: Install system deps if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: steps.rust-cache.outputs.cache-hit != 'true' uses: dsherret/rust-toolchain-file@v1 # Build Rust extensions - name: Handle Rust extensions build cache uses: actions/cache@v4 if: steps.rust-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/rust/extensions key: edb-rust-build-v1-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} restore-keys: | edb-rust-build-v1- - name: Build Rust extensions env: CARGO_HOME: ${{ env.BUILD_TEMP }}/rust/extensions/cargo_home CACHE_HIT: ${{ steps.rust-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p build/rust_extensions rsync -av ./build/rust_extensions/ ${BUILD_LIB}/ python setup.py -v build_rust rsync -av ${BUILD_LIB}/ build/rust_extensions/ rm -rf ${BUILD_LIB} fi rsync -av ./build/rust_extensions/edb/ ./edb/ # Build libpg_query - name: Build libpg_query if: | steps.libpg-query-cache.outputs.cache-hit != 'true' && steps.ext-cache.outputs.cache-hit != 'true' run: | python setup.py build_libpg_query # Build extensions - name: Handle Cython extensions build cache uses: actions/cache@v4 if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: CACHE_HIT: ${{ steps.ext-cache.outputs.cache-hit }} BUILD_EXT_MODE: py-only run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf 
${BUILD_LIB} mkdir -p ./build/extensions rsync -av ./build/extensions/ ${BUILD_LIB}/ BUILD_EXT_MODE=py-only python setup.py -v build_ext rsync -av ${BUILD_LIB}/ ./build/extensions/ rm -rf ${BUILD_LIB} fi rsync -av ./build/extensions/edb/ ./edb/ # Build parsers - name: Handle compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} restore-keys: | edb-parsers-v3- - name: Build parsers env: CACHE_HIT: ${{ steps.parsers-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p ./build/lib rsync -av ./build/lib/ ${BUILD_LIB}/ python setup.py -v build_parsers rsync -av ${BUILD_LIB}/ ./build/lib/ rm -rf ${BUILD_LIB} fi rsync -av ./build/lib/edb/ ./edb/ # Build PostgreSQL - name: Build PostgreSQL env: CACHE_HIT: ${{ steps.postgres-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" == "true" ]]; then cp build/postgres/install/stamp build/postgres/ else python setup.py build_postgres cp build/postgres/stamp build/postgres/install/ fi # Build Stolon - name: Set up Go if: steps.stolon-cache.outputs.cache-hit != 'true' uses: actions/setup-go@v2 with: go-version: 1.16 - uses: actions/checkout@v4 if: steps.stolon-cache.outputs.cache-hit != 'true' with: repository: edgedb/stolon path: build/stolon ref: ${{ env.STOLON_GIT_REV }} fetch-depth: 0 submodules: false - name: Build Stolon if: steps.stolon-cache.outputs.cache-hit != 'true' run: | mkdir -p build/stolon/bin/ curl -fsSL https://releases.hashicorp.com/consul/1.10.1/consul_1.10.1_linux_amd64.zip | zcat > build/stolon/bin/consul chmod +x build/stolon/bin/consul cd build/stolon && make # Install edgedb-server and populate egg-info - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and don't want them to be reinstalled in an "isolated env". 
pip install --no-build-isolation --no-deps -e .[test,docs] # Refresh the bootstrap cache - name: Handle bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} restore-keys: | edb-bootstrap-v2- - name: Bootstrap EdgeDB Server if: steps.bootstrap-cache.outputs.cache-hit != 'true' run: | edb server --bootstrap-only test: runs-on: ubuntu-latest needs: build strategy: fail-fast: false matrix: postgres-version: [ 17 ] single-mode: - '' # These are very broken. Disabling them for now until we # decide whether to fix them or give up. # - 'NOCREATEDB NOCREATEROLE' # - 'CREATEDB NOCREATEROLE' multi-tenant-mode: [ '' ] include: - postgres-version: 14 single-mode: '' multi-tenant-mode: '' - postgres-version: 15 single-mode: '' multi-tenant-mode: '' - postgres-version: 16 single-mode: '' multi-tenant-mode: '' - postgres-version: 17 single-mode: '' multi-tenant-mode: 'remote-compiler' - postgres-version: 17 single-mode: '' multi-tenant-mode: 'multi-tenant' services: postgres: image: pgvector/pgvector:0.7.4-pg${{ matrix.postgres-version }} env: POSTGRES_PASSWORD: postgres options: >- --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 --name postgres ports: - 5432:5432 steps: - name: Trust pgvector extension uses: docker://docker with: args: docker exec postgres sed -i $a\trusted=true /usr/share/postgresql/${{ matrix.postgres-version }}/extension/vector.control - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - uses: actions/checkout@v4 with: fetch-depth: 50 submodules: true - name: Set up Python uses: actions/setup-python@v5 id: setup-python with: python-version: '3.12.2' cache: 'pip' cache-dependency-path: | pyproject.toml # The below is technically a lie as we are technically not # inside a virtual env, but there is really no reason to bother # actually creating and activating one as below works just fine. 
- name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Download requirements.txt uses: actions/cache@v4 with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Install Python dependencies run: | python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 # Restore the artifacts and environment variables - name: Download shared artifacts uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: shared-artifacts path: shared-artifacts - name: Set environment variables run: | echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV # Restore build cache - name: Restore cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} - name: Restore cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} - name: Restore cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ 
hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Restore cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Restore bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} - name: Stop if we cannot retrieve the cache if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.parsers-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.bootstrap-cache.outputs.cache-hit != 'true' run: | echo ::error::Cannot retrieve build cache. exit 1 - name: Validate cached binaries run: | # Validate Stolon ./build/stolon/bin/stolon-sentinel --version || exit 1 ./build/stolon/bin/stolon-keeper --version || exit 1 ./build/stolon/bin/stolon-proxy --version || exit 1 # Validate PostgreSQL ./build/postgres/install/bin/postgres --version || exit 1 ./build/postgres/install/bin/pg_config --version || exit 1 - name: Restore cache into the source tree run: | rsync -av ./build/rust_extensions/edb/ ./edb/ rsync -av ./build/extensions/edb/ ./edb/ rsync -av ./build/lib/edb/ ./edb/ cp build/postgres/install/stamp build/postgres/ - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and don't want them to be reinstalled in an "isolated env". 
pip install --no-build-isolation --no-deps -e .[test,docs] # Run the test - name: Setup single mode role and database if: ${{ matrix.single-mode }} shell: python run: | import asyncio import subprocess from edb.server.pgcluster import get_pg_bin_dir async def main(): psql = await get_pg_bin_dir() / "psql" dsn = "postgres://postgres:postgres@localhost/postgres" script = """\ CREATE ROLE singles; ALTER ROLE singles WITH LOGIN PASSWORD 'test' NOSUPERUSER ${{ matrix.single-mode }}; CREATE DATABASE singles OWNER singles; REVOKE ALL ON DATABASE singles FROM PUBLIC; GRANT CONNECT ON DATABASE singles TO singles; GRANT ALL ON DATABASE singles TO singles; """ subprocess.run( [str(psql), dsn], check=True, text=True, input=script, ) asyncio.run(main()) - name: Test env: EDGEDB_TEST_POSTGRES_VERSION: ${{ matrix.postgres-version }} run: | if [[ "${{ matrix.single-mode }}" ]]; then export EDGEDB_TEST_BACKEND_DSN=postgres://singles:test@localhost/singles else export EDGEDB_TEST_BACKEND_DSN=postgres://postgres:postgres@localhost/postgres fi if [[ "${{ matrix.multi-tenant-mode }}" == "remote-compiler" ]]; then export EDGEDB_TEST_REMOTE_COMPILER=localhost:5660 export _EDGEDB_SERVER_COMPILER_POOL_SECRET=secret __EDGEDB_DEVMODE=1 edgedb-server compiler --pool-size 2 & fi edb server --bootstrap-only --backend-dsn=$EDGEDB_TEST_BACKEND_DSN --testmode if [[ "${{ matrix.multi-tenant-mode }}" == "multi-tenant" ]]; then export EDGEDB_SERVER_MULTITENANT_CONFIG_FILE=/tmp/edb.mt.json echo "{\"localhost\":{\"instance-name\":\"localtest\",\"backend-dsn\":\"$EDGEDB_TEST_BACKEND_DSN\",\"admin\":true,\"max-backend-connections\":10}}" > /tmp/edb.mt.json fi if [[ "${{ matrix.single-mode }}" == *"NOCREATEDB"* ]]; then edb test -j1 -v --backend-dsn=$EDGEDB_TEST_BACKEND_DSN else edb test -j2 -v --backend-dsn=$EDGEDB_TEST_BACKEND_DSN fi workflow-notifications: if: failure() && github.event_name != 'pull_request' name: Notify in Slack on failures needs: - build - test runs-on: ubuntu-latest permissions: 
actions: 'read' steps: - name: Slack Workflow Notification uses: Gamesight/slack-workflow-status@26a36836c887f260477432e4314ec3490a84f309 with: repo_token: ${{secrets.GITHUB_TOKEN}} slack_webhook_url: ${{secrets.ACTIONS_SLACK_WEBHOOK_URL}} name: 'Workflow notifications' icon_emoji: ':hammer:' include_jobs: 'on-failure' ================================================ FILE: .github/workflows/tests.pool.yml ================================================ name: Pool Simulation Test on: push: branches: - master - pool-test paths: - 'edb/server/connpool/**' - 'edb/server/conn_pool/**' - 'tests/test_server_pool.py' - '.github/workflows/tests.pool.yml' pull_request: branches: - master paths: - 'edb/server/connpool/**' - 'edb/server/conn_pool/**' - 'tests/test_server_pool.py' - '.github/workflows/tests.pool.yml' jobs: test: runs-on: ubuntu-latest concurrency: pool-test steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - uses: actions/checkout@v4 with: fetch-depth: 50 submodules: true - name: Set up Python uses: actions/setup-python@v5 id: setup-python with: python-version: '3.12.2' cache: 'pip' cache-dependency-path: | pyproject.toml # The below is technically a lie as we are technically not # inside a virtual env, but there is really no reason to bother # actually creating and activating one as below works just fine.
- name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Cached requirements.txt uses: actions/cache@v4 id: requirements-cache with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Compute requirements.txt if: steps.requirements-cache.outputs.cache-hit != 'true' run: | python -m pip install pip-tools pip-compile --no-strip-extras --all-build-deps \ --extra test,language-server \ --output-file requirements.txt pyproject.toml - name: Install Python dependencies run: | python -c "import sys; print(sys.prefix)" python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 - name: Compute cache keys run: | mkdir -p shared-artifacts if [ "$(uname)" = "Darwin" ]; then find /usr/lib -type f -name 'lib*' -exec stat -f '%N %z' {} + | sort | shasum -a 256 | cut -d ' ' -f1 > shared-artifacts/lib_cache_key.txt else find /usr/lib -type f -name 'lib*' -printf '%P %s\n' | sort | sha256sum | cut -d ' ' -f1 > shared-artifacts/lib_cache_key.txt fi python setup.py -q ci_helper --type rust >shared-artifacts/rust_cache_key.txt python setup.py -q ci_helper --type ext >shared-artifacts/ext_cache_key.txt python setup.py -q ci_helper --type parsers >shared-artifacts/parsers_cache_key.txt python setup.py -q ci_helper --type postgres >shared-artifacts/postgres_git_rev.txt python setup.py -q ci_helper --type libpg_query >shared-artifacts/libpg_query_git_rev.txt echo 'f8cd94309eaccbfba5dea7835b88c78377608a37' >shared-artifacts/stolon_git_rev.txt python setup.py -q ci_helper --type bootstrap >shared-artifacts/bootstrap_cache_key.txt echo POSTGRES_GIT_REV=$(cat 
shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo LIBPG_QUERY_GIT_REV=$(cat shared-artifacts/libpg_query_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV - name: Upload shared artifacts uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: shared-artifacts path: shared-artifacts retention-days: 1 # Restore binary cache - name: Handle cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} restore-keys: | edb-rust-v4- - name: Handle cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Handle cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Handle cached libpg_query build uses: actions/cache@v4 id: libpg-query-cache with: path: edb/pgsql/parser/libpg_query/libpg_query.a key: edb-libpg_query-v1-${{ env.LIBPG_QUERY_GIT_REV }} # Install system dependencies for building - name: Install system deps if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: steps.rust-cache.outputs.cache-hit != 'true' uses: 
dsherret/rust-toolchain-file@v1 # Build Rust extensions - name: Handle Rust extensions build cache uses: actions/cache@v4 if: steps.rust-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/rust/extensions key: edb-rust-build-v1-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} restore-keys: | edb-rust-build-v1- - name: Build Rust extensions env: CARGO_HOME: ${{ env.BUILD_TEMP }}/rust/extensions/cargo_home CACHE_HIT: ${{ steps.rust-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p build/rust_extensions rsync -av ./build/rust_extensions/ ${BUILD_LIB}/ python setup.py -v build_rust rsync -av ${BUILD_LIB}/ build/rust_extensions/ rm -rf ${BUILD_LIB} fi rsync -av ./build/rust_extensions/edb/ ./edb/ # Build libpg_query - name: Build libpg_query if: | steps.libpg-query-cache.outputs.cache-hit != 'true' && steps.ext-cache.outputs.cache-hit != 'true' run: | python setup.py build_libpg_query # Build extensions - name: Handle Cython extensions build cache uses: actions/cache@v4 if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: CACHE_HIT: ${{ steps.ext-cache.outputs.cache-hit }} BUILD_EXT_MODE: py-only run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p ./build/extensions rsync -av ./build/extensions/ ${BUILD_LIB}/ BUILD_EXT_MODE=py-only python setup.py -v build_ext rsync -av ${BUILD_LIB}/ ./build/extensions/ rm -rf ${BUILD_LIB} fi rsync -av ./build/extensions/edb/ ./edb/ # Build parsers - name: Handle compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} restore-keys: | edb-parsers-v3- - name: Build parsers env: CACHE_HIT: ${{ steps.parsers-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf 
${BUILD_LIB} mkdir -p ./build/lib rsync -av ./build/lib/ ${BUILD_LIB}/ python setup.py -v build_parsers rsync -av ${BUILD_LIB}/ ./build/lib/ rm -rf ${BUILD_LIB} fi rsync -av ./build/lib/edb/ ./edb/ # Build PostgreSQL - name: Build PostgreSQL env: CACHE_HIT: ${{ steps.postgres-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" == "true" ]]; then cp build/postgres/install/stamp build/postgres/ else python setup.py build_postgres cp build/postgres/stamp build/postgres/install/ fi # Build Stolon - name: Set up Go if: steps.stolon-cache.outputs.cache-hit != 'true' uses: actions/setup-go@v2 with: go-version: 1.16 - uses: actions/checkout@v4 if: steps.stolon-cache.outputs.cache-hit != 'true' with: repository: edgedb/stolon path: build/stolon ref: ${{ env.STOLON_GIT_REV }} fetch-depth: 0 submodules: false - name: Build Stolon if: steps.stolon-cache.outputs.cache-hit != 'true' run: | mkdir -p build/stolon/bin/ curl -fsSL https://releases.hashicorp.com/consul/1.10.1/consul_1.10.1_linux_amd64.zip | zcat > build/stolon/bin/consul chmod +x build/stolon/bin/consul cd build/stolon && make # Install edgedb-server and populate egg-info - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and don't want them to be reinstalled in an "isolated env". pip install --no-build-isolation --no-deps -e .[test,docs] # Refresh the bootstrap cache - name: Handle bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} restore-keys: | edb-bootstrap-v2- - name: Bootstrap EdgeDB Server if: steps.bootstrap-cache.outputs.cache-hit != 'true' run: | edb server --bootstrap-only - uses: actions/checkout@v4 if: startsWith(github.ref, 'refs/heads') with: repository: edgedb/edgedb-pool-simulation path: pool-simulation token: ${{ secrets.GITHUB_CI_BOT_TOKEN }} - name: Run the pool simulation test env: PYTHONPATH: . 
SIMULATION_CI: yes TIME_SCALE: 10 run: | mkdir -p pool-simulation/reports python tests/test_server_pool.py - uses: EndBug/add-and-commit@v7.0.0 if: ${{ always() }} continue-on-error: true with: branch: main cwd: pool-simulation author_name: github-actions author_email: 41898282+github-actions[bot]@users.noreply.github.com ================================================ FILE: .github/workflows/tests.reflection.yml ================================================ name: Tests with reflection validation on: schedule: - cron: "0 3 * * *" workflow_dispatch: inputs: {} push: branches: - "REFL-*" jobs: build: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - uses: actions/checkout@v4 with: fetch-depth: 50 submodules: true - name: Set up Python uses: actions/setup-python@v5 id: setup-python with: python-version: '3.12.2' cache: 'pip' cache-dependency-path: | pyproject.toml # The below is technically a lie as we are technically not # inside a virtual env, but there is really no reason to bother # actually creating and activating one as below works just fine. 
- name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Cached requirements.txt uses: actions/cache@v4 id: requirements-cache with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Compute requirements.txt if: steps.requirements-cache.outputs.cache-hit != 'true' run: | python -m pip install pip-tools pip-compile --no-strip-extras --all-build-deps \ --extra test,language-server \ --output-file requirements.txt pyproject.toml - name: Install Python dependencies run: | python -c "import sys; print(sys.prefix)" python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 - name: Compute cache keys env: GIST_TOKEN: ${{ secrets.CI_BOT_GIST_TOKEN }} run: | mkdir -p shared-artifacts if [ "$(uname)" = "Darwin" ]; then find /usr/lib -type f -name 'lib*' -exec stat -f '%N %z' {} + | sort | shasum -a 256 | cut -d ' ' -f1 > shared-artifacts/lib_cache_key.txt else find /usr/lib -type f -name 'lib*' -printf '%P %s\n' | sort | sha256sum | cut -d ' ' -f1 > shared-artifacts/lib_cache_key.txt fi python setup.py -q ci_helper --type rust >shared-artifacts/rust_cache_key.txt python setup.py -q ci_helper --type ext >shared-artifacts/ext_cache_key.txt python setup.py -q ci_helper --type parsers >shared-artifacts/parsers_cache_key.txt python setup.py -q ci_helper --type postgres >shared-artifacts/postgres_git_rev.txt python setup.py -q ci_helper --type libpg_query >shared-artifacts/libpg_query_git_rev.txt echo 'f8cd94309eaccbfba5dea7835b88c78377608a37' >shared-artifacts/stolon_git_rev.txt python setup.py -q ci_helper --type bootstrap >shared-artifacts/bootstrap_cache_key.txt 
echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo LIBPG_QUERY_GIT_REV=$(cat shared-artifacts/libpg_query_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV - name: Upload shared artifacts uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: shared-artifacts path: shared-artifacts retention-days: 1 # Restore binary cache - name: Handle cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} restore-keys: | edb-rust-v4- - name: Handle cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Handle cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Handle cached libpg_query build uses: actions/cache@v4 id: libpg-query-cache with: path: edb/pgsql/parser/libpg_query/libpg_query.a key: edb-libpg_query-v1-${{ env.LIBPG_QUERY_GIT_REV }} # Install system dependencies for building - name: Install system deps if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: 
steps.rust-cache.outputs.cache-hit != 'true' uses: dsherret/rust-toolchain-file@v1 # Build Rust extensions - name: Handle Rust extensions build cache uses: actions/cache@v4 if: steps.rust-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/rust/extensions key: edb-rust-build-v1-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} restore-keys: | edb-rust-build-v1- - name: Build Rust extensions env: CARGO_HOME: ${{ env.BUILD_TEMP }}/rust/extensions/cargo_home CACHE_HIT: ${{ steps.rust-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p build/rust_extensions rsync -av ./build/rust_extensions/ ${BUILD_LIB}/ python setup.py -v build_rust rsync -av ${BUILD_LIB}/ build/rust_extensions/ rm -rf ${BUILD_LIB} fi rsync -av ./build/rust_extensions/edb/ ./edb/ # Build libpg_query - name: Build libpg_query if: | steps.libpg-query-cache.outputs.cache-hit != 'true' && steps.ext-cache.outputs.cache-hit != 'true' run: | python setup.py build_libpg_query # Build extensions - name: Handle Cython extensions build cache uses: actions/cache@v4 if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: CACHE_HIT: ${{ steps.ext-cache.outputs.cache-hit }} BUILD_EXT_MODE: py-only run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p ./build/extensions rsync -av ./build/extensions/ ${BUILD_LIB}/ BUILD_EXT_MODE=py-only python setup.py -v build_ext rsync -av ${BUILD_LIB}/ ./build/extensions/ rm -rf ${BUILD_LIB} fi rsync -av ./build/extensions/edb/ ./edb/ # Build parsers - name: Handle compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} restore-keys: | edb-parsers-v3- - name: Build parsers env: CACHE_HIT: ${{ steps.parsers-cache.outputs.cache-hit }} run: | if 
[[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p ./build/lib rsync -av ./build/lib/ ${BUILD_LIB}/ python setup.py -v build_parsers rsync -av ${BUILD_LIB}/ ./build/lib/ rm -rf ${BUILD_LIB} fi rsync -av ./build/lib/edb/ ./edb/ # Build PostgreSQL - name: Build PostgreSQL env: CACHE_HIT: ${{ steps.postgres-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" == "true" ]]; then cp build/postgres/install/stamp build/postgres/ else python setup.py build_postgres cp build/postgres/stamp build/postgres/install/ fi # Build Stolon - name: Set up Go if: steps.stolon-cache.outputs.cache-hit != 'true' uses: actions/setup-go@v2 with: go-version: 1.16 - uses: actions/checkout@v4 if: steps.stolon-cache.outputs.cache-hit != 'true' with: repository: edgedb/stolon path: build/stolon ref: ${{ env.STOLON_GIT_REV }} fetch-depth: 0 submodules: false - name: Build Stolon if: steps.stolon-cache.outputs.cache-hit != 'true' run: | mkdir -p build/stolon/bin/ curl -fsSL https://releases.hashicorp.com/consul/1.10.1/consul_1.10.1_linux_amd64.zip | zcat > build/stolon/bin/consul chmod +x build/stolon/bin/consul cd build/stolon && make # Install edgedb-server and populate egg-info - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and don't want them to be reinstalled in an "isolated env". 
pip install --no-build-isolation --no-deps -e .[test,docs] # Refresh the bootstrap cache - name: Handle bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} restore-keys: | edb-bootstrap-v2- - name: Bootstrap EdgeDB Server if: steps.bootstrap-cache.outputs.cache-hit != 'true' run: | edb server --bootstrap-only test: needs: build runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - uses: actions/checkout@v4 with: fetch-depth: 50 submodules: true - name: Set up Python uses: actions/setup-python@v5 id: setup-python with: python-version: '3.12.2' cache: 'pip' cache-dependency-path: | pyproject.toml # The below is technically a lie as we are technically not # inside a virtual env, but there is really no reason to bother # actually creating and activating one as below works just fine. - name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Download requirements.txt uses: actions/cache@v4 with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Install Python dependencies run: | python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 # Restore the artifacts and environment variables - name: Download shared artifacts uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: shared-artifacts path: shared-artifacts - name: Set environment variables run: | echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat 
shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV # Restore build cache - name: Restore cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} - name: Restore cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} - name: Restore cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Restore cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Restore bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} - name: Stop if we cannot retrieve the cache if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.parsers-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.bootstrap-cache.outputs.cache-hit != 'true' run: | echo ::error::Cannot retrieve build cache. 
exit 1 - name: Validate cached binaries run: | # Validate Stolon ./build/stolon/bin/stolon-sentinel --version || exit 1 ./build/stolon/bin/stolon-keeper --version || exit 1 ./build/stolon/bin/stolon-proxy --version || exit 1 # Validate PostgreSQL ./build/postgres/install/bin/postgres --version || exit 1 ./build/postgres/install/bin/pg_config --version || exit 1 - name: Restore cache into the source tree run: | rsync -av ./build/rust_extensions/edb/ ./edb/ rsync -av ./build/extensions/edb/ ./edb/ rsync -av ./build/lib/edb/ ./edb/ cp build/postgres/install/stamp build/postgres/ - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and don't want them to be reinstalled in an "isolated env". pip install --no-build-isolation --no-deps -e .[test,docs] # Run the test - name: Test env: EDGEDB_DEBUG_DELTA_VALIDATE_REFLECTION: 1 run: | edb test -j2 -v workflow-notifications: if: failure() && github.event_name != 'pull_request' name: Notify in Slack on failures needs: - build - test runs-on: ubuntu-latest permissions: actions: 'read' steps: - name: Slack Workflow Notification uses: Gamesight/slack-workflow-status@26a36836c887f260477432e4314ec3490a84f309 with: repo_token: ${{secrets.GITHUB_TOKEN}} slack_webhook_url: ${{secrets.ACTIONS_SLACK_WEBHOOK_URL}} name: 'Workflow notifications' icon_emoji: ':hammer:' include_jobs: 'on-failure' ================================================ FILE: .github/workflows/tests.yml ================================================ name: Tests on: push: branches: - master - ci - "release/*" pull_request: branches: - '**' schedule: - cron: "0 */3 * * *" jobs: build: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - uses: actions/checkout@v4 with: fetch-depth: 50 submodules: true - name: Set up Python uses: actions/setup-python@v5 id: setup-python with: python-version: '3.12.2' cache: 'pip' cache-dependency-path: | 
pyproject.toml # The below is technically a lie as we are technically not # inside a virtual env, but there is really no reason to bother # actually creating and activating one as below works just fine. - name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Cached requirements.txt uses: actions/cache@v4 id: requirements-cache with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Compute requirements.txt if: steps.requirements-cache.outputs.cache-hit != 'true' run: | python -m pip install pip-tools pip-compile --no-strip-extras --all-build-deps \ --extra test,language-server \ --output-file requirements.txt pyproject.toml - name: Install Python dependencies run: | python -c "import sys; print(sys.prefix)" python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 - name: Compute cache keys and download the running times log env: GIST_TOKEN: ${{ secrets.CI_BOT_GIST_TOKEN }} run: | mkdir -p shared-artifacts if [ "$(uname)" = "Darwin" ]; then find /usr/lib -type f -name 'lib*' -exec stat -f '%N %z' {} + | sort | shasum -a 256 | cut -d ' ' -f1 > shared-artifacts/lib_cache_key.txt else find /usr/lib -type f -name 'lib*' -printf '%P %s\n' | sort | sha256sum | cut -d ' ' -f1 > shared-artifacts/lib_cache_key.txt fi python setup.py -q ci_helper --type rust >shared-artifacts/rust_cache_key.txt python setup.py -q ci_helper --type ext >shared-artifacts/ext_cache_key.txt python setup.py -q ci_helper --type parsers >shared-artifacts/parsers_cache_key.txt python setup.py -q ci_helper --type postgres >shared-artifacts/postgres_git_rev.txt python setup.py -q ci_helper 
--type libpg_query >shared-artifacts/libpg_query_git_rev.txt echo 'f8cd94309eaccbfba5dea7835b88c78377608a37' >shared-artifacts/stolon_git_rev.txt python setup.py -q ci_helper --type bootstrap >shared-artifacts/bootstrap_cache_key.txt echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo LIBPG_QUERY_GIT_REV=$(cat shared-artifacts/libpg_query_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV curl \ -H "Accept: application/vnd.github.v3+json" \ -u edgedb-ci:$GIST_TOKEN \ https://api.github.com/gists/8b722a65397f7c4c0df72f5394efa04c \ | jq '.files."time_stats.csv".raw_url' \ | xargs curl > shared-artifacts/time_stats.csv - name: Upload shared artifacts uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: shared-artifacts path: shared-artifacts retention-days: 1 # Restore binary cache - name: Handle cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} restore-keys: | edb-rust-v4- - name: Handle cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Handle cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Handle cached libpg_query build uses: actions/cache@v4 id: libpg-query-cache with: path: edb/pgsql/parser/libpg_query/libpg_query.a key: 
edb-libpg_query-v1-${{ env.LIBPG_QUERY_GIT_REV }} # Install system dependencies for building - name: Install system deps if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: steps.rust-cache.outputs.cache-hit != 'true' uses: dsherret/rust-toolchain-file@v1 # Build Rust extensions - name: Handle Rust extensions build cache uses: actions/cache@v4 if: steps.rust-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/rust/extensions key: edb-rust-build-v1-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} restore-keys: | edb-rust-build-v1- - name: Build Rust extensions env: CARGO_HOME: ${{ env.BUILD_TEMP }}/rust/extensions/cargo_home CACHE_HIT: ${{ steps.rust-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p build/rust_extensions rsync -av ./build/rust_extensions/ ${BUILD_LIB}/ python setup.py -v build_rust rsync -av ${BUILD_LIB}/ build/rust_extensions/ rm -rf ${BUILD_LIB} fi rsync -av ./build/rust_extensions/edb/ ./edb/ # Build libpg_query - name: Build libpg_query if: | steps.libpg-query-cache.outputs.cache-hit != 'true' && steps.ext-cache.outputs.cache-hit != 'true' run: | python setup.py build_libpg_query # Build extensions - name: Handle Cython extensions build cache uses: actions/cache@v4 if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: CACHE_HIT: ${{ steps.ext-cache.outputs.cache-hit }} BUILD_EXT_MODE: py-only run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p ./build/extensions rsync -av ./build/extensions/ ${BUILD_LIB}/ 
BUILD_EXT_MODE=py-only python setup.py -v build_ext rsync -av ${BUILD_LIB}/ ./build/extensions/ rm -rf ${BUILD_LIB} fi rsync -av ./build/extensions/edb/ ./edb/ # Build parsers - name: Handle compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} restore-keys: | edb-parsers-v3- - name: Build parsers env: CACHE_HIT: ${{ steps.parsers-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p ./build/lib rsync -av ./build/lib/ ${BUILD_LIB}/ python setup.py -v build_parsers rsync -av ${BUILD_LIB}/ ./build/lib/ rm -rf ${BUILD_LIB} fi rsync -av ./build/lib/edb/ ./edb/ # Build PostgreSQL - name: Build PostgreSQL env: CACHE_HIT: ${{ steps.postgres-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" == "true" ]]; then cp build/postgres/install/stamp build/postgres/ else python setup.py build_postgres cp build/postgres/stamp build/postgres/install/ fi # Build Stolon - name: Set up Go if: steps.stolon-cache.outputs.cache-hit != 'true' uses: actions/setup-go@v2 with: go-version: 1.16 - uses: actions/checkout@v4 if: steps.stolon-cache.outputs.cache-hit != 'true' with: repository: edgedb/stolon path: build/stolon ref: ${{ env.STOLON_GIT_REV }} fetch-depth: 0 submodules: false - name: Build Stolon if: steps.stolon-cache.outputs.cache-hit != 'true' run: | mkdir -p build/stolon/bin/ curl -fsSL https://releases.hashicorp.com/consul/1.10.1/consul_1.10.1_linux_amd64.zip | zcat > build/stolon/bin/consul chmod +x build/stolon/bin/consul cd build/stolon && make # Install edgedb-server and populate egg-info - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and don't want them to be reinstalled in an "isolated env". 
pip install --no-build-isolation --no-deps -e .[test,docs] # Refresh the bootstrap cache - name: Handle bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} restore-keys: | edb-bootstrap-v2- - name: Bootstrap EdgeDB Server if: steps.bootstrap-cache.outputs.cache-hit != 'true' run: | edb server --bootstrap-only cargo-test: needs: build runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - uses: actions/checkout@v4 with: fetch-depth: 50 submodules: true - name: Set up Python uses: actions/setup-python@v5 id: setup-python with: python-version: '3.12.2' cache: 'pip' cache-dependency-path: | pyproject.toml # The below is technically a lie as we are technically not # inside a virtual env, but there is really no reason to bother # actually creating and activating one as below works just fine. - name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Download requirements.txt uses: actions/cache@v4 with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Install Python dependencies run: | python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 - name: Download cache key uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: shared-artifacts path: shared-artifacts - name: Generate environment variables run: | echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV - name: Handle Rust extensions build cache uses: actions/cache@v4 id: rust-cache 
with: path: ${{ env.BUILD_TEMP }}/rust/extensions key: edb-rust-build-v1-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} - name: Install Rust toolchain uses: dsherret/rust-toolchain-file@v1 - name: Cargo test env: CARGO_TARGET_DIR: ${{ env.BUILD_TEMP }}/rust/extensions CARGO_HOME: ${{ env.BUILD_TEMP }}/rust/extensions/cargo_home run: cargo test --all-features python-test: needs: build runs-on: ubuntu-latest strategy: fail-fast: false matrix: shard: [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, ] steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - uses: actions/checkout@v4 with: fetch-depth: 50 submodules: true - name: Set up Python uses: actions/setup-python@v5 id: setup-python with: python-version: '3.12.2' cache: 'pip' cache-dependency-path: | pyproject.toml # The below is technically a lie as we are technically not # inside a virtual env, but there is really no reason to bother # actually creating and activating one as below works just fine. 
- name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Download requirements.txt uses: actions/cache@v4 with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Install Python dependencies run: | python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 # Restore the artifacts and environment variables - name: Download shared artifacts uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: shared-artifacts path: shared-artifacts - name: Set environment variables run: | echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV # Restore build cache - name: Restore cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} - name: Restore cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} - name: Restore cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ 
hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Restore cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Restore bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} - name: Stop if we cannot retrieve the cache if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.parsers-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.bootstrap-cache.outputs.cache-hit != 'true' run: | echo ::error::Cannot retrieve build cache. exit 1 - name: Validate cached binaries run: | # Validate Stolon ./build/stolon/bin/stolon-sentinel --version || exit 1 ./build/stolon/bin/stolon-keeper --version || exit 1 ./build/stolon/bin/stolon-proxy --version || exit 1 # Validate PostgreSQL ./build/postgres/install/bin/postgres --version || exit 1 ./build/postgres/install/bin/pg_config --version || exit 1 - name: Restore cache into the source tree run: | rsync -av ./build/rust_extensions/edb/ ./edb/ rsync -av ./build/extensions/edb/ ./edb/ rsync -av ./build/lib/edb/ ./edb/ cp build/postgres/install/stamp build/postgres/ - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and don't want them to be reinstalled in an "isolated env". 
pip install --no-build-isolation --no-deps -e .[test,docs] # Run the test - name: Install Rust toolchain uses: dsherret/rust-toolchain-file@v1 - name: Test env: SHARD: ${{ matrix.shard }} EDGEDB_TEST_REPEATS: 1 run: | mkdir -p results/ cp shared-artifacts/time_stats.csv results/running_times_${SHARD}.csv edb test --jobs 2 --verbose --shard ${SHARD}/16 \ --running-times-log=results/running_times_${SHARD}.csv \ --result-log=results/result_${SHARD}.json - name: Upload test results uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 if: ${{ always() }} with: name: python-test-results-${{ matrix.shard }} path: results retention-days: 1 python-test-list: needs: build runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - uses: actions/checkout@v4 with: fetch-depth: 50 submodules: true - name: Set up Python uses: actions/setup-python@v5 id: setup-python with: python-version: '3.12.2' cache: 'pip' cache-dependency-path: | pyproject.toml # The below is technically a lie as we are technically not # inside a virtual env, but there is really no reason to bother # actually creating and activating one as below works just fine. 
- name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Download requirements.txt uses: actions/cache@v4 with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Install Python dependencies run: | python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 # Restore the artifacts and environment variables - name: Download shared artifacts uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: shared-artifacts path: shared-artifacts - name: Set environment variables run: | echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV # Restore build cache - name: Restore cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} - name: Restore cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} - name: Restore cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ 
hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Restore cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Restore bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} - name: Stop if we cannot retrieve the cache if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.parsers-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.bootstrap-cache.outputs.cache-hit != 'true' run: | echo ::error::Cannot retrieve build cache. exit 1 - name: Validate cached binaries run: | # Validate Stolon ./build/stolon/bin/stolon-sentinel --version || exit 1 ./build/stolon/bin/stolon-keeper --version || exit 1 ./build/stolon/bin/stolon-proxy --version || exit 1 # Validate PostgreSQL ./build/postgres/install/bin/postgres --version || exit 1 ./build/postgres/install/bin/pg_config --version || exit 1 - name: Restore cache into the source tree run: | rsync -av ./build/rust_extensions/edb/ ./edb/ rsync -av ./build/extensions/edb/ ./edb/ rsync -av ./build/lib/edb/ ./edb/ cp build/postgres/install/stamp build/postgres/ - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and don't want them to be reinstalled in an "isolated env". 
pip install --no-build-isolation --no-deps -e .[test,docs] # List tests and upload - name: Generate complete list of tests for verification env: SHARD: ${{ matrix.shard }} EDGEDB_TEST_REPEATS: 1 run: | edb test --list > shared-artifacts/all_tests.txt - name: Upload list of tests uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: test-list path: shared-artifacts retention-days: 1 test-conclusion: needs: [cargo-test, python-test, python-test-list] runs-on: ubuntu-latest if: ${{ always() }} steps: - name: Set up Python uses: actions/setup-python@v5 with: python-version: '3.12.2' - name: Install Python deps run: | python -m pip install requests click - uses: actions/checkout@v4 with: submodules: false - name: Download python-test results uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: pattern: python-test-results-* merge-multiple: true path: results # Render results and exit if they were unsuccessful - name: Render results run: | python edb/tools/test/results.py 'results/result_*.json' - name: Download shared artifacts uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: shared-artifacts path: shared-artifacts - name: Download test list uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: test-list path: shared-artifacts - name: Merge stats and verify tests completion shell: python env: GIST_TOKEN: ${{ secrets.CI_BOT_GIST_TOKEN }} GIT_REF: ${{ github.ref }} run: | import csv import glob import io import os import requests orig = {} new = {} all_tests = set() with open("shared-artifacts/time_stats.csv") as f: for name, t, c in csv.reader(f): assert name not in orig, "duplicate test name in original stats!" orig[name] = (t, int(c)) with open("shared-artifacts/all_tests.txt") as f: for line in f: assert line.strip() not in all_tests, "duplicate test name in this run!"
all_tests.add(line.strip()) for new_file in glob.glob("results/running_times_*.csv"): with open(new_file) as f: for name, t, c in csv.reader(f): if int(c) > orig.get(name, (0, 0))[1]: if name.startswith("setup::"): new[name] = (t, c) else: assert name not in new, f"duplicate test! {name}" new[name] = (t, c) all_tests.remove(name) assert not all_tests, "Tests not run! \n" + "\n".join(all_tests) if os.environ["GIT_REF"] == "refs/heads/master": buf = io.StringIO() writer = csv.writer(buf) orig.update(new) for k, v in sorted(orig.items()): writer.writerow((k,) + v) resp = requests.patch( "https://api.github.com/gists/8b722a65397f7c4c0df72f5394efa04c", headers={"Accept": "application/vnd.github.v3+json"}, auth=("edgedb-ci", os.environ["GIST_TOKEN"]), json={"files": {"time_stats.csv": {"content": buf.getvalue()}}}, ) resp.raise_for_status() workflow-notifications: if: failure() && github.event_name != 'pull_request' name: Notify in Slack on failures needs: - test-conclusion runs-on: ubuntu-latest permissions: actions: 'read' steps: - name: Slack Workflow Notification uses: Gamesight/slack-workflow-status@26a36836c887f260477432e4314ec3490a84f309 with: repo_token: ${{secrets.GITHUB_TOKEN}} slack_webhook_url: ${{secrets.ACTIONS_SLACK_WEBHOOK_URL}} name: 'Workflow notifications' icon_emoji: ':hammer:' include_jobs: 'on-failure' ================================================ FILE: .github/workflows.src/build.dryrun.tpl.yml ================================================ <% from "build.inc.yml" import workflow, workflow_dispatch -%> name: Package Build Dry Run on: <<- workflow_dispatch() >> jobs: <<- workflow(package, targets, [], subdist="nightly") ->> ================================================ FILE: .github/workflows.src/build.inc.yml ================================================ <% macro workflow_dispatch() %> workflow_dispatch: inputs: gelpkg_ref: description: "gel-pkg git ref used to build the packages" default: "master" metapkg_ref: description: "metapkg git 
ref used to build the packages" default: "master" <%- endmacro %> <% macro workflow(package, targets, publications, subdist="", publish_all=False) %> prep: runs-on: ubuntu-latest <% if subdist == "nightly" %> outputs: <% for tgt in targets.linux + targets.macos %> if_<< tgt.name.replace('-', '_') >>: ${{ steps.scm.outputs.if_<< tgt.name.replace('-', '_') >> }} <% endfor %> <% endif %> steps: - uses: actions/checkout@v4 <% if subdist == "nightly" %> - name: Determine SCM revision id: scm shell: bash run: | rev=$(git rev-parse HEAD) jq_filter='.packages[] | select(.basename == "<< package.basename >>") | select(.architecture == $ARCH) | .version_details.metadata.scm_revision | . as $rev | select(($rev != null) and ($REV | startswith($rev)))' <% for tgt in targets.linux + targets.macos %> key="<< tgt.name >>" val=true <% if tgt.family == "debian" %> idx_file=<< tgt.platform_version >>.nightly.json url=https://packages.edgedb.com/apt/.jsonindexes/$idx_file <% elif tgt.family == "redhat" %> idx_file=el<< tgt.platform_version >>.nightly.json url=https://packages.edgedb.com/rpm/.jsonindexes/$idx_file <% elif tgt.family == "generic" %> idx_file=<< tgt.platform_version >>-unknown-linux-<< "{}".format(tgt.platform_libc) if tgt.platform_libc else "gnu" >>.nightly.json url=https://packages.edgedb.com/archive/.jsonindexes/$idx_file <% elif tgt.platform == "macos" %> idx_file=<< tgt.platform_version >>-apple-darwin.nightly.json url=https://packages.edgedb.com/archive/.jsonindexes/$idx_file <% endif %> tmp_file="/tmp/$idx_file" if [ ! 
-e "$tmp_file" ]; then curl --fail -o $tmp_file -s $url || true fi if [ -e "$tmp_file" ]; then out=$(< "$tmp_file" jq -r --arg REV "$rev" --arg ARCH "<< tgt.arch >>" "$jq_filter") if [ -n "$out" ]; then echo "Skip rebuilding existing ${key}" val=false fi fi echo if_${key//-/_}="$val" >> $GITHUB_OUTPUT <% endfor %> <% endif %> <%- for tgt in targets.linux %> <%- set plat_id = tgt.platform + ("{}".format(tgt.platform_libc) if tgt.platform_libc else "") + ("-{}".format(tgt.platform_version) if tgt.platform_version else "") %> build-<< tgt.name >>: runs-on: << tgt.runs_on if tgt.runs_on else "ubuntu-latest" >> needs: prep <% if subdist == "nightly" %> if: needs.prep.outputs.if_<< tgt.name.replace('-', '_') >> == 'true' <% endif %> steps: - name: Build uses: docker://ghcr.io/geldata/gelpkg-build-<< plat_id >>:latest env: PACKAGE: "<< package.name >>" SRC_REF: "${{ github.sha }}" PKG_REVISION: "" <%- if subdist != "" %> PKG_SUBDIST: "<< subdist >>" <%- endif %> PKG_PLATFORM: "<< tgt.platform >>" PKG_PLATFORM_VERSION: "<< tgt.platform_version >>" EXTRA_OPTIMIZATIONS: "true" <%- if subdist != "nightly" %> BUILD_IS_RELEASE: "true" <%- endif %> <%- if tgt.family == "generic" %> BUILD_GENERIC: true <%- endif %> <%- if tgt.platform_libc %> PKG_PLATFORM_LIBC: "<< tgt.platform_libc >>" <%- endif %> METAPKG_GIT_CACHE: disabled GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-<< tgt.name >> path: artifacts/<< plat_id >> <%- endfor %> <%- for tgt in targets.macos %> <%- set plat_id = tgt.platform + ("{}".format(tgt.platform_libc) if tgt.platform_libc else "") + ("-{}".format(tgt.platform_version) if tgt.platform_version else "") %> build-<< tgt.name >>: runs-on: << tgt.runs_on if tgt.runs_on else "macos-latest" >> needs: prep <% if subdist == "nightly" %> if: needs.prep.outputs.if_<< tgt.name.replace('-', '_') >> == 'true' <% endif %> steps: - name: 
Update Homebrew before installing Rust toolchain run: | # Homebrew renamed `rustup-init` to `rustup`: # https://github.com/Homebrew/homebrew-core/pull/177840 # But the GitHub Action runner is not updated with this change yet. # This caused the later `brew update` in step `Build` to relink Rust # toolchain executables, overwriting the custom toolchain installed by # `dsherret/rust-toolchain-file`. So let's just run `brew update` early. brew update - uses: actions/checkout@v4 if: << 'false' if tgt.runs_on and 'self-hosted' in tgt.runs_on else 'true' >> with: sparse-checkout: | rust-toolchain.toml sparse-checkout-cone-mode: false - name: Install Rust toolchain uses: dsherret/rust-toolchain-file@v1 if: << 'false' if tgt.runs_on and 'self-hosted' in tgt.runs_on else 'true' >> - uses: actions/checkout@v4 with: repository: edgedb/edgedb-pkg ref: master path: edgedb-pkg - name: Set up Python uses: actions/setup-python@v5 if: << 'false' if tgt.runs_on and 'self-hosted' in tgt.runs_on else 'true' >> with: python-version: "3.12" - name: Set up NodeJS uses: actions/setup-node@v4 if: << 'false' if tgt.runs_on and 'self-hosted' in tgt.runs_on else 'true' >> with: node-version: '20' - name: Install dependencies if: << 'false' if tgt.runs_on and 'self-hosted' in tgt.runs_on else 'true' >> run: | env HOMEBREW_NO_AUTO_UPDATE=1 brew install libmagic - name: Install an alias # This is probably not strictly needed, but sentencepiece build script reports # errors without it. 
if: << 'false' if tgt.runs_on and 'self-hosted' in tgt.runs_on else 'true' >> run: | printf '#!/bin/sh\n\nexec sysctl -n hw.logicalcpu' > /usr/local/bin/nproc chmod +x /usr/local/bin/nproc - name: Build env: PACKAGE: "<< package.name >>" SRC_REF: "${{ github.sha }}" <%- if subdist != "nightly" %> BUILD_IS_RELEASE: "true" <%- endif %> PKG_REVISION: "" <%- if subdist != "" %> PKG_SUBDIST: "<< subdist >>" <%- endif %> PKG_PLATFORM: "<< tgt.platform >>" PKG_PLATFORM_VERSION: "<< tgt.platform_version >>" PKG_PLATFORM_ARCH: "<< tgt.arch if tgt.arch else '' >>" EXTRA_OPTIMIZATIONS: "true" METAPKG_GIT_CACHE: disabled <%- if tgt.family == "generic" %> BUILD_GENERIC: true <%- endif %> CMAKE_POLICY_VERSION_MINIMUM: '3.5' GEL_PKG_REF: ${{ inputs.gelpkg_ref }} METAPKG_REF: ${{ inputs.metapkg_ref }} run: | edgedb-pkg/integration/macos/build.sh - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: builds-<< tgt.name >> path: artifacts/<< plat_id >> <%- endfor %> <%- if package.name != "edgedbpkg.edgedb_ls:EdgeDBLanguageServer" %> <%- for tgt in targets.linux %> <%- set plat_id = tgt.platform + ("{}".format(tgt.platform_libc) if tgt.platform_libc else "") + ("-{}".format(tgt.platform_version) if tgt.platform_version else "") %> test-<< tgt.name >>: needs: [build-<< tgt.name >>] runs-on: << tgt.runs_on if tgt.runs_on else "ubuntu-latest" >> steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-<< tgt.name >> path: artifacts/<< plat_id >> - name: Test uses: docker://ghcr.io/geldata/gelpkg-test-<< plat_id >>:latest env: <%- if subdist != "" %> PKG_SUBDIST: "<< subdist >>" <%- endif %> PKG_PLATFORM: "<< tgt.platform >>" PKG_PLATFORM_VERSION: "<< tgt.platform_version >>" PKG_PLATFORM_LIBC: "<< tgt.platform_libc >>" PKG_TEST_SELECT: "<< package.test.select >>" PKG_TEST_EXCLUDE: "<< package.test.exclude >>" PKG_TEST_FILES: "<< package.test.files >> << tgt.test.files >>" # edb test with -j 
higher than 1 seems to result in workflow # jobs getting killed arbitrarily by GitHub. PKG_TEST_JOBS: << 0 if tgt.runs_on and 'self-hosted' in tgt.runs_on else 1 >> <%- endfor %> <%- for tgt in targets.macos %> <%- set plat_id = tgt.platform + ("{}".format(tgt.platform_libc) if tgt.platform_libc else "") + ("-{}".format(tgt.platform_version) if tgt.platform_version else "") %> test-<< tgt.name >>: needs: [build-<< tgt.name >>] runs-on: << tgt.runs_on if tgt.runs_on else "macos-latest" >> steps: - uses: actions/checkout@v4 with: repository: edgedb/edgedb-pkg ref: master path: edgedb-pkg - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-<< tgt.name >> path: artifacts/<< plat_id >> - name: Test env: <%- if subdist != "" %> PKG_SUBDIST: "<< subdist >>" <%- endif %> PKG_PLATFORM: "<< tgt.platform >>" PKG_PLATFORM_VERSION: "<< tgt.platform_version >>" PKG_TEST_SELECT: "<< package.test.select >>" PKG_TEST_EXCLUDE: "<< package.test.exclude >>" PKG_TEST_FILES: "<< package.test.files >> << tgt.test.files >>" run: | <%- if tgt.platform_version == "x86_64" %> # Bump shmmax and shmall to avoid test failures. sudo sysctl -w kern.sysv.shmmax=12582912 sudo sysctl -w kern.sysv.shmall=12582912 <%- endif %> edgedb-pkg/integration/macos/test.sh <%- endfor %> <%- endif %> <%- if publish_all %> collect: needs: <%- for tgt in targets.linux + targets.macos %> - test-<< tgt.name >> <%- endfor %> runs-on: ubuntu-latest steps: - run: echo 'All build+tests passed, ready to publish now!'
<%- endif %> <%- for tgt in targets.linux %> <%- set plat_id = tgt.platform + ("{}".format(tgt.platform_libc) if tgt.platform_libc else "") + ("-{}".format(tgt.platform_version) if tgt.platform_version else "") %> <%- for publish in publications %> publish<< publish.suffix>>-<< tgt.name >>: needs: [<% if publish_all %>collect<% elif package.name != "edgedbpkg.edgedb_ls:EdgeDBLanguageServer" %>test-<< tgt.name >><% else %>build-<< tgt.name >><% endif %>] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-<< tgt.name >> path: artifacts/<< plat_id >> - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: <%- if subdist != "" %> PKG_SUBDIST: "<< subdist >>" <%- endif %> <%- if publish.server != "" %> PACKAGE_SERVER: << publish.server >> <%- endif %> PKG_PLATFORM: "<< tgt.platform >>" PKG_PLATFORM_VERSION: "<< tgt.platform_version >>" PKG_PLATFORM_LIBC: "<< tgt.platform_libc >>" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" check-published<< publish.suffix >>-<< tgt.name >>: needs: [publish<< publish.suffix >>-<< tgt.name >>] runs-on: << tgt.runs_on if tgt.runs_on else "ubuntu-latest" >> steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-<< tgt.name >> path: artifacts/<< plat_id >> - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: << plat_id >> - name: Test Published uses: docker://ghcr.io/geldata/gelpkg-testpublished-<< plat_id >>:latest env: PKG_NAME: "${{ steps.describe.outputs.name }}" <%- if subdist != "" %> PKG_SUBDIST: "<< subdist >>" <%- endif %> <%- if publish.server != "" %> PACKAGE_SERVER: << publish.server >> <%- endif %> PKG_PLATFORM: "<< tgt.platform >>" PKG_PLATFORM_VERSION: "<< tgt.platform_version >>" PKG_INSTALL_REF: "${{ steps.describe.outputs.install-ref }}" PKG_VERSION_SLOT: "${{ steps.describe.outputs.version-slot }}"
outputs: version-slot: ${{ steps.describe.outputs.version-slot }} version-core: ${{ steps.describe.outputs.version-core }} catalog-version: ${{ steps.describe.outputs.catalog-version }} <%- endfor %> <%- endfor %> <%- if publications %> <%- for tgt in targets.macos %> <%- set plat_id = tgt.platform + ("{}".format(tgt.platform_libc) if tgt.platform_libc else "") + ("-{}".format(tgt.platform_version) if tgt.platform_version else "") %> publish-<< tgt.name >>: needs: [<% if publish_all %>collect<% elif package.name != "edgedbpkg.edgedb_ls:EdgeDBLanguageServer" %>test-<< tgt.name >><% else %>build-<< tgt.name >><% endif %>] runs-on: ubuntu-latest steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: builds-<< tgt.name >> path: artifacts/<< plat_id >> - uses: actions/checkout@v4 with: repository: edgedb/edgedb-pkg ref: master path: edgedb-pkg - name: Describe id: describe uses: edgedb/edgedb-pkg/integration/actions/describe-artifact@master with: target: << plat_id >> - name: Publish uses: docker://ghcr.io/geldata/gelpkg-upload-linux-x86_64:latest env: <%- if subdist != "" %> PKG_SUBDIST: "<< subdist >>" <%- endif %> PKG_PLATFORM: "<< tgt.platform >>" PKG_PLATFORM_VERSION: "<< tgt.platform_version >>" PACKAGE_UPLOAD_SSH_KEY: "${{ secrets.PACKAGE_UPLOAD_SSH_KEY }}" <%- endfor %> <%- endif %> <%- set docker_tgts = targets.linux | selectattr("docker_arch") | list %> <%- if docker_tgts and publications %> <%- set pub_outputs = "needs.check-published-" + (docker_tgts|first)["name"] + ".outputs" %> publish-docker: needs: <%- for tgt in docker_tgts %> - check-published-<< tgt.name >> <%- endfor %> runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 with: repository: geldata/gel-docker ref: master path: dockerfile - name: Login to Docker Hub uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} - name: Login to 
GitHub Container Registry uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0 with: registry: ghcr.io username: "edgedb-ci" password: ${{ secrets.GITHUB_CI_BOT_TOKEN }} - env: VERSION_SLOT: "${{ << pub_outputs >>.version-slot }}" VERSION_CORE: "${{ << pub_outputs >>.version-core }}" CATALOG_VERSION: "${{ << pub_outputs >>.catalog-version }}" PKG_SUBDIST: "<< subdist >>" id: tags run: | set -e url='https://registry.hub.docker.com/v2/repositories/geldata/gel/tags?page_size=100' repo_tags=$( while [ -n "$url" ]; do resp=$(curl -L -s "$url") url=$(echo "$resp" | jq -r ".next") if [ "$url" = "null" ] || [ -z "$url" ]; then break fi echo "$resp" | jq -r '."results"[]["name"]' done | grep "^[[:digit:]]\+.*" | grep -v "alpha\|beta\|rc" || : ) tags=() if [ "$PKG_SUBDIST" = "nightly" ]; then tags+=( "nightly" "nightly_${VERSION_SLOT}_cv${CATALOG_VERSION}" ) else tags+=( "$VERSION_CORE" ) top=$(printf "%s\n%s\n" "$VERSION_CORE" "$repo_tags" \ | grep "^${VERSION_SLOT}[\.-]" \ | sort --version-sort --reverse | head -n 1) if [ "$top" == "$VERSION_CORE" ]; then tags+=( "$VERSION_SLOT" ) fi if [ -z "$PKG_SUBDIST" ]; then top=$(printf "%s\n%s\n" "$VERSION_CORE" "$repo_tags" \ | sort --version-sort --reverse | head -n 1) if [ "$top" == "$VERSION_CORE" ]; then tags+=( "latest" ) fi fi fi fq_tags=() images=("geldata/gel" "ghcr.io/geldata/gel") for image in "${images[@]}"; do fq_tags+=("${tags[@]/#/${image}:}") done IFS=, echo "tags=${fq_tags[*]}" >> $GITHUB_OUTPUT - name: Set up QEMU uses: docker/setup-qemu-action@29109295f81e9208d7d86ff1c6c12d2833863392 # v3.6.0 - name: Set up Docker Buildx uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0 - name: Build and Publish Docker Image uses: docker/build-push-action@471d1dc4e07e5cdedd4c2171150001c434f0b7a4 # v6.10.0 with: push: true provenance: mode=max tags: "${{ steps.tags.outputs.tags }}" context: dockerfile build-args: | version=${{ << pub_outputs >>.version-slot }} 
exact_version=${{ << pub_outputs >>.version-core }} <%- if subdist != "" %> subdist=<< subdist >> <%- endif %> platforms: << docker_tgts|map(attribute="docker_arch")|join(",") >> <%- endif %> workflow-notifications: if: failure() && github.event_name != 'pull_request' name: Notify in Slack on failures needs: - prep <%- if publish_all %> - collect <%- else %> <%- endif %> <%- for tgt in targets.linux %> - build-<< tgt.name >> <%- if package.name != "edgedbpkg.edgedb_ls:EdgeDBLanguageServer" %> - test-<< tgt.name >> <%- endif %> <%- for publish in publications %> - publish<< publish.suffix>>-<< tgt.name >> - check-published<< publish.suffix>>-<< tgt.name >> <%- endfor %> <%- endfor %> <%- for tgt in targets.macos %> - build-<< tgt.name >> <%- if package.name != "edgedbpkg.edgedb_ls:EdgeDBLanguageServer" %> - test-<< tgt.name >> <%- endif %> <%- for publish in publications %> - publish<< publish.suffix>>-<< tgt.name >> <%- endfor %> <%- endfor %> <%- if docker_tgts and publications %> - publish-docker <%- endif %> runs-on: ubuntu-latest permissions: actions: 'read' steps: - name: Slack Workflow Notification uses: Gamesight/slack-workflow-status@26a36836c887f260477432e4314ec3490a84f309 with: repo_token: ${{secrets.GITHUB_TOKEN}} slack_webhook_url: ${{secrets.ACTIONS_SLACK_WEBHOOK_URL}} name: 'Workflow notifications' icon_emoji: ':hammer:' include_jobs: 'on-failure' <%- endmacro %> ================================================ FILE: .github/workflows.src/build.ls-nightly.tpl.yml ================================================ <% from "build.inc.yml" import workflow, workflow_dispatch -%> name: 'ls: Build and Publish Nightly Packages' on: schedule: - cron: "0 1 * * *" <<- workflow_dispatch() >> push: branches: - nightly jobs: <<- workflow(package, targets, publications, subdist="nightly") ->> ================================================ FILE: .github/workflows.src/build.ls.targets.yml ================================================ publications: - name: prod 
suffix: "" server: sftp://uploader@package-upload.edgedb.net:22/ package: name: edgedbpkg.edgedb_ls:EdgeDBLanguageServer basename: gel-ls tests: files: "test_language_server.py" targets: linux: - name: linux-x86_64 arch: x86_64 platform: linux platform_version: x86_64 family: generic runs_on: [self-hosted, linux, x64] - name: linux-aarch64 arch: aarch64 platform: linux platform_version: aarch64 family: generic runs_on: [self-hosted, linux, arm64] - name: linuxmusl-x86_64 arch: x86_64 platform: linux platform_version: x86_64 platform_libc: musl family: generic runs_on: [self-hosted, linux, x64] - name: linuxmusl-aarch64 arch: aarch64 platform: linux platform_version: aarch64 platform_libc: musl family: generic runs_on: [self-hosted, linux, arm64] macos: - name: macos-x86_64 arch: x86_64 platform: macos platform_version: x86_64 family: generic runs_on: [macos-13] - name: macos-aarch64 arch: aarch64 platform: macos platform_version: aarch64 family: generic runs_on: [macos-14] ================================================ FILE: .github/workflows.src/build.nightly.tpl.yml ================================================ <% from "build.inc.yml" import workflow, workflow_dispatch -%> name: Build Test and Publish Nightly Packages on: schedule: - cron: "0 1 * * *" <<- workflow_dispatch() >> push: branches: - nightly jobs: <<- workflow(package, targets, publications, subdist="nightly") ->> ================================================ FILE: .github/workflows.src/build.release.tpl.yml ================================================ <% from "build.inc.yml" import workflow, workflow_dispatch -%> name: Build Test and Publish a Release on: <<- workflow_dispatch() >> jobs: <<- workflow(package, targets, publications, subdist="", publish_all=True) ->> ================================================ FILE: .github/workflows.src/build.targets.yml ================================================ publications: - name: prod suffix: "" server: 
sftp://uploader@package-upload.edgedb.net:22/ package: name: "edgedbpkg.edgedb:Gel" basename: gel-server targets: linux: - name: debian-buster-x86_64 arch: x86_64 platform: debian platform_version: buster family: debian runs_on: [package-builder, self-hosted, linux, x64] - name: debian-buster-aarch64 arch: aarch64 platform: debian platform_version: buster family: debian runs_on: [package-builder, self-hosted, linux, arm64] - name: debian-bullseye-x86_64 arch: x86_64 platform: debian platform_version: bullseye family: debian runs_on: [package-builder, self-hosted, linux, x64] - name: debian-bullseye-aarch64 arch: aarch64 platform: debian platform_version: bullseye family: debian runs_on: [package-builder, self-hosted, linux, arm64] - name: debian-bookworm-x86_64 arch: x86_64 platform: debian platform_version: bookworm family: debian runs_on: [package-builder, self-hosted, linux, x64] docker_arch: linux/amd64 - name: debian-bookworm-aarch64 arch: aarch64 platform: debian platform_version: bookworm family: debian runs_on: [package-builder, self-hosted, linux, arm64] docker_arch: linux/arm64 - name: ubuntu-focal-x86_64 arch: x86_64 platform: ubuntu platform_version: focal family: debian runs_on: [package-builder, self-hosted, linux, x64] - name: ubuntu-focal-aarch64 arch: aarch64 platform: ubuntu platform_version: focal family: debian runs_on: [package-builder, self-hosted, linux, arm64] - name: ubuntu-jammy-x86_64 arch: x86_64 platform: ubuntu platform_version: jammy family: debian runs_on: [package-builder, self-hosted, linux, x64] - name: ubuntu-jammy-aarch64 arch: aarch64 platform: ubuntu platform_version: jammy family: debian runs_on: [package-builder, self-hosted, linux, arm64] - name: ubuntu-noble-x86_64 arch: x86_64 platform: ubuntu platform_version: noble family: debian runs_on: [package-builder, self-hosted, linux, x64] - name: ubuntu-noble-aarch64 arch: aarch64 platform: ubuntu platform_version: noble family: debian runs_on: [package-builder, self-hosted, 
linux, arm64] - name: centos-8-x86_64 arch: x86_64 platform: centos platform_version: 8 family: redhat runs_on: [package-builder, self-hosted, linux, x64] - name: centos-8-aarch64 arch: aarch64 platform: centos platform_version: 8 family: redhat runs_on: [package-builder, self-hosted, linux, arm64] - name: rockylinux-9-x86_64 arch: x86_64 platform: rockylinux platform_version: 9 family: redhat runs_on: [package-builder, self-hosted, linux, x64] - name: rockylinux-9-aarch64 arch: aarch64 platform: rockylinux platform_version: 9 family: redhat runs_on: [package-builder, self-hosted, linux, arm64] - name: linux-x86_64 arch: x86_64 platform: linux platform_version: x86_64 family: generic runs_on: [package-builder, self-hosted, linux, x64] - name: linux-aarch64 arch: aarch64 platform: linux platform_version: aarch64 family: generic runs_on: [package-builder, self-hosted, linux, arm64] - name: linuxmusl-x86_64 arch: x86_64 platform: linux platform_version: x86_64 platform_libc: musl family: generic runs_on: [package-builder, self-hosted, linux, x64] - name: linuxmusl-aarch64 arch: aarch64 platform: linux platform_version: aarch64 platform_libc: musl family: generic runs_on: [package-builder, self-hosted, linux, arm64] macos: - name: macos-x86_64 arch: x86_64 platform: macos platform_version: x86_64 family: generic runs_on: [macos-13] # Run fewer tests on x86_64, since the test runner is very slow. 
test: files: > test_dump*.py test_backend_*.py test_database.py test_server_*.py test_edgeql_ddl.py test_session.py - name: macos-aarch64 arch: aarch64 platform: macos platform_version: aarch64 family: generic runs_on: [macos-14] ================================================ FILE: .github/workflows.src/build.testing.tpl.yml ================================================ <% from "build.inc.yml" import workflow, workflow_dispatch -%> name: Build Test and Publish a Testing Release on: <<- workflow_dispatch() >> jobs: <<- workflow(package, targets, publications, subdist="testing", publish_all=True) ->> ================================================ FILE: .github/workflows.src/render.py ================================================ #!/usr/bin/env python3 import argparse import pathlib import sys import jinja2 import yaml env = jinja2.Environment( variable_start_string='<<', variable_end_string='>>', block_start_string='<%', block_end_string='%>', loader=jinja2.FileSystemLoader(pathlib.Path(__file__).parent), ) def die(msg): print(msg, file=sys.stderr) sys.exit(1) def _expand_test_spec(target): if "test" not in target: target["test"] = { "include": "", "exclude": "", "files": "" } for key in {"include", "exclude", "files"}: if key not in target["test"]: target["test"][key] = "" def _render(tpl_path, data): with open(tpl_path) as f: tpl = env.from_string(f.read()) return tpl.render(**data) def main(): parser = argparse.ArgumentParser() parser.add_argument('--workflow', choices=["build", "test"], required=True) parser.add_argument('template') parser.add_argument('datafile') args = parser.parse_args() tplfile = f'{args.template}.tpl.yml' path = pathlib.Path(__file__).parent / tplfile if not path.exists(): die(f'template does not exist: {tplfile}') datapath = pathlib.Path(__file__).parent / args.datafile if datapath.exists(): with open(datapath) as f: data = yaml.load(f, Loader=yaml.SafeLoader) else: data = {} if args.workflow == "build": package = 
data.get("package") if not package or not isinstance(package, dict): die(f"invalid package: specification in {datapath}") if not package.get("name"): die(f"missing package.name in {datapath}") _expand_test_spec(package) targets = data.get("targets") if not targets or not isinstance(targets, dict): die(f"invalid targets: specification in {datapath}") for target_list in targets.values(): for target in target_list: _expand_test_spec(target) output = _render(path, data) target = ( pathlib.Path(__file__).parent.parent / 'workflows' / f'{args.template}.yml' ) with open(target, 'w') as f: print(output, file=f) if __name__ == '__main__': main() ================================================ FILE: .github/workflows.src/tests.ha.targets.yml ================================================ data: ================================================ FILE: .github/workflows.src/tests.ha.tpl.yml ================================================ <% from "tests.inc.yml" import build, calc_cache_key, restore_cache -%> name: High Availability Tests on: workflow_dispatch: inputs: {} workflow_run: workflows: ["Tests"] types: - completed jobs: build: runs-on: ubuntu-latest if: github.event.workflow_run.conclusion == 'success' || github.event_name == 'workflow_dispatch' steps: <%- call build() -%> # Our HA tests currently only work on Postgres 14 (see #6332), # so check it out before we compute our build cache keys. 
- name: Switch back to Postgres 14 shell: bash run: | set -e cd postgres # Fetch postgres 14, since the clone was shallow git fetch origin REL_14_8 --depth=1 # For whatever reason the tag doesn't get fetched, so find it # at FETCH_HEAD git checkout FETCH_HEAD - name: Compute cache keys env: GIST_TOKEN: ${{ secrets.CI_BOT_GIST_TOKEN }} run: | << calc_cache_key()|indent >> <%- endcall %> ha-test: needs: build runs-on: ubuntu-latest steps: <<- restore_cache() >> # Run the test - name: Test env: SHARD: ${{ matrix.shard }} EDGEDB_TEST_HA: 1 EDGEDB_TEST_CONSUL_PATH: build/stolon/bin/consul EDGEDB_TEST_STOLON_CTL: build/stolon/bin/stolonctl EDGEDB_TEST_STOLON_SENTINEL: build/stolon/bin/stolon-sentinel EDGEDB_TEST_STOLON_KEEPER: build/stolon/bin/stolon-keeper run: | edb test -j1 -v -k test_ha_ workflow-notifications: if: failure() && github.event_name != 'pull_request' name: Notify in Slack on failures needs: - build - ha-test runs-on: ubuntu-latest permissions: actions: 'read' steps: - name: Slack Workflow Notification uses: Gamesight/slack-workflow-status@26a36836c887f260477432e4314ec3490a84f309 with: repo_token: ${{secrets.GITHUB_TOKEN}} slack_webhook_url: ${{secrets.ACTIONS_SLACK_WEBHOOK_URL}} name: 'Workflow notifications' icon_emoji: ':hammer:' include_jobs: 'on-failure' ================================================ FILE: .github/workflows.src/tests.inc.yml ================================================ <% macro init(ref='') -%> - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false <%- if ref != "" %> ref: << ref >> <%- endif %> - uses: actions/checkout@v4 with: fetch-depth: 50 submodules: true <%- if ref != "" %> ref: << ref >> <%- endif %> - name: Set up Python uses: actions/setup-python@v5 id: setup-python with: python-version: '3.12.2' cache: 'pip' cache-dependency-path: | pyproject.toml # The below is technically a lie as we are technically not # inside a virtual env, but there is really no reason to bother # actually creating and activating 
one as below works just fine. - name: Export $VIRTUAL_ENV run: | venv="$(python -c 'import sys; sys.stdout.write(sys.prefix)')" echo "VIRTUAL_ENV=${venv}" >> $GITHUB_ENV - name: Set up uv cache uses: actions/cache@v4 with: path: ~/.cache/uv key: uv-cache-${{ runner.os }}-py-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} <%- endmacro %> <% macro build(ref="") %> << init(ref) >> - name: Cached requirements.txt uses: actions/cache@v4 id: requirements-cache with: path: requirements.txt key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Compute requirements.txt if: steps.requirements-cache.outputs.cache-hit != 'true' run: | python -m pip install pip-tools pip-compile --no-strip-extras --all-build-deps \ --extra test,language-server \ --output-file requirements.txt pyproject.toml - name: Install Python dependencies run: | python -c "import sys; print(sys.prefix)" python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 << caller() >> - name: Upload shared artifacts uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: shared-artifacts path: shared-artifacts retention-days: 1 # Restore binary cache - name: Handle cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} restore-keys: | edb-rust-v4- - name: Handle cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Handle cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Handle cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: 
build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Handle cached libpg_query build uses: actions/cache@v4 id: libpg-query-cache with: path: edb/pgsql/parser/libpg_query/libpg_query.a key: edb-libpg_query-v1-${{ env.LIBPG_QUERY_GIT_REV }} # Install system dependencies for building - name: Install system deps if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' run: | sudo apt-get update sudo apt-get install -y uuid-dev libreadline-dev bison flex libprotobuf-c-dev - name: Install Rust toolchain if: steps.rust-cache.outputs.cache-hit != 'true' uses: dsherret/rust-toolchain-file@v1 # Build Rust extensions - name: Handle Rust extensions build cache uses: actions/cache@v4 if: steps.rust-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/rust/extensions key: edb-rust-build-v1-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} restore-keys: | edb-rust-build-v1- - name: Build Rust extensions env: CARGO_HOME: ${{ env.BUILD_TEMP }}/rust/extensions/cargo_home CACHE_HIT: ${{ steps.rust-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p build/rust_extensions rsync -av ./build/rust_extensions/ ${BUILD_LIB}/ python setup.py -v build_rust rsync -av ${BUILD_LIB}/ build/rust_extensions/ rm -rf ${BUILD_LIB} fi rsync -av ./build/rust_extensions/edb/ ./edb/ # Build libpg_query - name: Build libpg_query if: | steps.libpg-query-cache.outputs.cache-hit != 'true' && steps.ext-cache.outputs.cache-hit != 'true' run: | python setup.py build_libpg_query # Build extensions - name: Handle Cython extensions build cache uses: actions/cache@v4 if: steps.ext-cache.outputs.cache-hit != 'true' with: path: ${{ env.BUILD_TEMP }}/edb key: edb-ext-build-v4-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Build Cython extensions env: CACHE_HIT: ${{ 
steps.ext-cache.outputs.cache-hit }} BUILD_EXT_MODE: py-only run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p ./build/extensions rsync -av ./build/extensions/ ${BUILD_LIB}/ BUILD_EXT_MODE=py-only python setup.py -v build_ext rsync -av ${BUILD_LIB}/ ./build/extensions/ rm -rf ${BUILD_LIB} fi rsync -av ./build/extensions/edb/ ./edb/ # Build parsers - name: Handle compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} restore-keys: | edb-parsers-v3- - name: Build parsers env: CACHE_HIT: ${{ steps.parsers-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" != "true" ]]; then rm -rf ${BUILD_LIB} mkdir -p ./build/lib rsync -av ./build/lib/ ${BUILD_LIB}/ python setup.py -v build_parsers rsync -av ${BUILD_LIB}/ ./build/lib/ rm -rf ${BUILD_LIB} fi rsync -av ./build/lib/edb/ ./edb/ # Build PostgreSQL - name: Build PostgreSQL env: CACHE_HIT: ${{ steps.postgres-cache.outputs.cache-hit }} run: | if [[ "$CACHE_HIT" == "true" ]]; then cp build/postgres/install/stamp build/postgres/ else python setup.py build_postgres cp build/postgres/stamp build/postgres/install/ fi # Build Stolon - name: Set up Go if: steps.stolon-cache.outputs.cache-hit != 'true' uses: actions/setup-go@v2 with: go-version: 1.16 - uses: actions/checkout@v4 if: steps.stolon-cache.outputs.cache-hit != 'true' with: repository: edgedb/stolon path: build/stolon ref: ${{ env.STOLON_GIT_REV }} fetch-depth: 0 submodules: false - name: Build Stolon if: steps.stolon-cache.outputs.cache-hit != 'true' run: | mkdir -p build/stolon/bin/ curl -fsSL https://releases.hashicorp.com/consul/1.10.1/consul_1.10.1_linux_amd64.zip | zcat > build/stolon/bin/consul chmod +x build/stolon/bin/consul cd build/stolon && make # Install edgedb-server and populate egg-info - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and 
don't want them to be reinstalled in an "isolated env". pip install --no-build-isolation --no-deps -e .[test,docs] # Refresh the bootstrap cache - name: Handle bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} restore-keys: | edb-bootstrap-v2- - name: Bootstrap EdgeDB Server if: steps.bootstrap-cache.outputs.cache-hit != 'true' run: | edb server --bootstrap-only <%- endmacro %> <% macro calc_cache_key() -%> mkdir -p shared-artifacts if [ "$(uname)" = "Darwin" ]; then find /usr/lib -type f -name 'lib*' -exec stat -f '%N %z' {} + | sort | shasum -a 256 | cut -d ' ' -f1 > shared-artifacts/lib_cache_key.txt else find /usr/lib -type f -name 'lib*' -printf '%P %s\n' | sort | sha256sum | cut -d ' ' -f1 > shared-artifacts/lib_cache_key.txt fi python setup.py -q ci_helper --type rust >shared-artifacts/rust_cache_key.txt python setup.py -q ci_helper --type ext >shared-artifacts/ext_cache_key.txt python setup.py -q ci_helper --type parsers >shared-artifacts/parsers_cache_key.txt python setup.py -q ci_helper --type postgres >shared-artifacts/postgres_git_rev.txt python setup.py -q ci_helper --type libpg_query >shared-artifacts/libpg_query_git_rev.txt echo 'f8cd94309eaccbfba5dea7835b88c78377608a37' >shared-artifacts/stolon_git_rev.txt python setup.py -q ci_helper --type bootstrap >shared-artifacts/bootstrap_cache_key.txt echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo LIBPG_QUERY_GIT_REV=$(cat shared-artifacts/libpg_query_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV <%- endmacro %> <% macro install_python_requirements() %> - name: Download requirements.txt uses: actions/cache@v4 with: path: requirements.txt 
key: edb-requirements-${{ hashFiles('pyproject.toml') }} - name: Install Python dependencies run: | python -m pip install uv~=0.1.0 && uv pip install -U -r requirements.txt # 80.9.0 breaks our sphinx, and it keeps sneaking in uv pip install setuptools==80.8.0 <%- endmacro %> <% macro restore_cache(ref="") %> << init(ref) >> << install_python_requirements() >> # Restore the artifacts and environment variables - name: Download shared artifacts uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: shared-artifacts path: shared-artifacts - name: Set environment variables run: | echo POSTGRES_GIT_REV=$(cat shared-artifacts/postgres_git_rev.txt) >> $GITHUB_ENV echo STOLON_GIT_REV=$(cat shared-artifacts/stolon_git_rev.txt) >> $GITHUB_ENV echo BUILD_LIB=$(python setup.py -q ci_helper --type build_lib) >> $GITHUB_ENV echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV # Restore build cache - name: Restore cached Rust extensions uses: actions/cache@v4 id: rust-cache with: path: build/rust_extensions key: edb-rust-v4-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} - name: Restore cached Cython extensions uses: actions/cache@v4 id: ext-cache with: path: build/extensions key: edb-ext-v6-${{ hashFiles('shared-artifacts/ext_cache_key.txt') }} - name: Restore compiled parsers cache uses: actions/cache@v4 id: parsers-cache with: path: build/lib key: edb-parsers-v3-${{ hashFiles('shared-artifacts/parsers_cache_key.txt') }} - name: Restore cached PostgreSQL build uses: actions/cache@v4 id: postgres-cache with: path: build/postgres/install key: edb-postgres-v3-${{ env.POSTGRES_GIT_REV }}-${{ hashFiles('shared-artifacts/lib_cache_key.txt') }} - name: Restore cached Stolon build uses: actions/cache@v4 id: stolon-cache with: path: build/stolon/bin key: edb-stolon-v2-${{ env.STOLON_GIT_REV }} - name: Restore bootstrap cache uses: actions/cache@v4 id: bootstrap-cache with: path: build/cache key: 
edb-bootstrap-v2-${{ hashFiles('shared-artifacts/bootstrap_cache_key.txt') }} - name: Stop if we cannot retrieve the cache if: | steps.rust-cache.outputs.cache-hit != 'true' || steps.ext-cache.outputs.cache-hit != 'true' || steps.parsers-cache.outputs.cache-hit != 'true' || steps.postgres-cache.outputs.cache-hit != 'true' || steps.stolon-cache.outputs.cache-hit != 'true' || steps.bootstrap-cache.outputs.cache-hit != 'true' run: | echo ::error::Cannot retrieve build cache. exit 1 - name: Validate cached binaries run: | # Validate Stolon ./build/stolon/bin/stolon-sentinel --version || exit 1 ./build/stolon/bin/stolon-keeper --version || exit 1 ./build/stolon/bin/stolon-proxy --version || exit 1 # Validate PostgreSQL ./build/postgres/install/bin/postgres --version || exit 1 ./build/postgres/install/bin/pg_config --version || exit 1 - name: Restore cache into the source tree run: | rsync -av ./build/rust_extensions/edb/ ./edb/ rsync -av ./build/extensions/edb/ ./edb/ rsync -av ./build/lib/edb/ ./edb/ cp build/postgres/install/stamp build/postgres/ - name: Install edgedb-server env: BUILD_EXT_MODE: skip run: | # --no-build-isolation because we have explicitly installed all deps # and don't want them to be reinstalled in an "isolated env". 
pip install --no-build-isolation --no-deps -e .[test,docs] <%- endmacro %> <% macro setup_terraform() -%> - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false - name: Setup Terraform uses: hashicorp/setup-terraform@633666f66e0061ca3b725c73b2ec20cd13a8fdd1 # v2.0.3 - name: Initialize Terraform run: terraform init <%- endmacro %> ================================================ FILE: .github/workflows.src/tests.inplace.targets.yml ================================================ data: ================================================ FILE: .github/workflows.src/tests.inplace.tpl.yml ================================================ <% from "tests.inc.yml" import build, calc_cache_key, restore_cache -%> name: Tests of in-place upgrades and patching on: schedule: - cron: "0 3 * * *" workflow_dispatch: inputs: {} push: branches: - "A-inplace*" jobs: build: runs-on: ubuntu-latest steps: <%- call build() -%> - name: Compute cache keys run: | << calc_cache_key()|indent >> <%- endcall %> test-inplace: runs-on: ubuntu-latest needs: build strategy: fail-fast: false matrix: include: - flags: tests: - flags: --rollback-and-test tests: # Do the reapply test on a smaller selection of tests, since # it is slower. - flags: --rollback-and-reapply tests: -k test_link_on_target_delete -k test_edgeql_select -k test_dump steps: <<- restore_cache() >> # Run the test # TODO: Would it be better to split this up into multiple jobs? 
- name: Test performing in-place upgrades run: | ./tests/inplace-testing/test.sh ${{ matrix.flags }} vt ${{ matrix.tests }} test-patches: runs-on: ubuntu-latest needs: build steps: <<- restore_cache() >> - name: Test performing in-place upgrades run: | ./tests/patch-testing/test.sh test-dir -k test_link_on_target_delete -k test_edgeql_select -k test_edgeql_scope -k test_dump compute-versions: runs-on: ubuntu-latest outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} steps: - uses: actions/checkout@v4 - id: set-matrix name: Compute versions to run on run: python3 .github/scripts/patches/compute-ipu-versions.py test: runs-on: ubuntu-latest needs: [build, compute-versions] strategy: fail-fast: false matrix: ${{fromJSON(needs.compute-versions.outputs.matrix)}} steps: <<- restore_cache() >> # Run the test - name: Download an earlier database version run: | wget -q "${{ matrix.edgedb-url }}" tar xzf ${{ matrix.edgedb-basename }}-${{ matrix.edgedb-version }}.tar.gz - name: Make sure a CLI named "edgedb" exists (sigh) run: | ln -s gel $(dirname $(which gel))/edgedb - name: Test inplace upgrades from previous major version run: | ./tests/inplace-testing/test-old.sh vt ${{ matrix.edgedb-basename }}-${{ matrix.edgedb-version }} workflow-notifications: if: failure() && github.event_name != 'pull_request' name: Notify in Slack on failures needs: - build - test-inplace - test-patches runs-on: ubuntu-latest permissions: actions: 'read' steps: - name: Slack Workflow Notification uses: Gamesight/slack-workflow-status@26a36836c887f260477432e4314ec3490a84f309 with: repo_token: ${{secrets.GITHUB_TOKEN}} slack_webhook_url: ${{secrets.ACTIONS_SLACK_WEBHOOK_URL}} name: 'Workflow notifications' icon_emoji: ':hammer:' include_jobs: 'on-failure' ================================================ FILE: .github/workflows.src/tests.inplace7x.targets.yml ================================================ data: ================================================ FILE: 
.github/workflows.src/tests.inplace7x.tpl.yml ================================================ <% from "tests.inc.yml" import build, calc_cache_key, restore_cache -%> name: Tests of in-place upgrades to 7.x on: schedule: - cron: "0 3 * * *" workflow_dispatch: inputs: {} push: branches: - "A-inplace*" jobs: build: runs-on: ubuntu-latest steps: <%- call build("release/7.x") -%> - name: Compute cache keys run: | << calc_cache_key()|indent >> <%- endcall %> test-inplace: runs-on: ubuntu-latest needs: build strategy: fail-fast: false matrix: include: - flags: tests: - flags: --rollback-and-test tests: # Do the reapply test on a smaller selection of tests, since # it is slower. - flags: --rollback-and-reapply tests: -k test_link_on_target_delete -k test_edgeql_select -k test_dump steps: <<- restore_cache("release/7.x") >> # Run the test # TODO: Would it be better to split this up into multiple jobs? - name: Test performing in-place upgrades run: | ./tests/inplace-testing/test.sh ${{ matrix.flags }} vt ${{ matrix.tests }} compute-versions: runs-on: ubuntu-latest outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} steps: - uses: actions/checkout@v4 with: fetch-depth: 0 submodules: false ref: release/7.x - id: set-matrix name: Compute versions to run on run: python3 .github/scripts/patches/compute-ipu-versions.py test: runs-on: ubuntu-latest needs: [build, compute-versions] strategy: fail-fast: false matrix: ${{fromJSON(needs.compute-versions.outputs.matrix)}} steps: <<- restore_cache("release/7.x") >> # Run the test - name: Download an earlier database version run: | wget -q "${{ matrix.edgedb-url }}" tar xzf ${{ matrix.edgedb-basename }}-${{ matrix.edgedb-version }}.tar.gz - name: Make sure a CLI named "edgedb" exists (sigh) run: | ln -s gel $(dirname $(which gel))/edgedb - name: Test inplace upgrades from previous major version run: | ./tests/inplace-testing/test-old.sh vt ${{ matrix.edgedb-basename }}-${{ matrix.edgedb-version }} workflow-notifications: if: 
failure() && github.event_name != 'pull_request' name: Notify in Slack on failures needs: - build - test-inplace runs-on: ubuntu-latest permissions: actions: 'read' steps: - name: Slack Workflow Notification uses: Gamesight/slack-workflow-status@26a36836c887f260477432e4314ec3490a84f309 with: repo_token: ${{secrets.GITHUB_TOKEN}} slack_webhook_url: ${{secrets.ACTIONS_SLACK_WEBHOOK_URL}} name: 'Workflow notifications' icon_emoji: ':hammer:' include_jobs: 'on-failure' ================================================ FILE: .github/workflows.src/tests.managed-pg.targets.yml ================================================ data: ================================================ FILE: .github/workflows.src/tests.managed-pg.tpl.yml ================================================ <% from "tests.inc.yml" import build, calc_cache_key, restore_cache, setup_terraform -%> <% macro setup_aws_creds() -%> - name: Configure AWS Credentials uses: aws-actions/configure-aws-credentials@v1 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} aws-region: us-east-2 <%- endmacro -%> <% macro setup_gcp_creds() -%> - name: Configure GCP Credentials uses: google-github-actions/setup-gcloud@main with: service_account_key: ${{ secrets.GCP_SA_KEY }} export_default_credentials: true <%- endmacro -%> name: Tests on Managed PostgreSQL on: schedule: - cron: "0 3 * * 6" workflow_dispatch: inputs: {} push: branches: - cloud-test jobs: build: runs-on: ubuntu-latest steps: <%- call build() -%> - name: Compute cache keys run: | << calc_cache_key()|indent >> <%- endcall %> setup-aws-rds: runs-on: ubuntu-latest outputs: pghost: ${{ steps.pghost.outputs.stdout }} defaults: run: working-directory: .github/aws-rds steps: << setup_terraform()|indent(2) >> << setup_aws_creds()|indent(2) >> - name: Setup AWS RDS env: TF_VAR_sg_id: ${{ secrets.AWS_SECURITY_GROUP }} TF_VAR_password: ${{ secrets.AWS_RDS_PASSWORD }} run: | terraform apply -auto-approve 
- name: Store Terraform state if: ${{ always() }} uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: aws-rds-tfstate path: .github/aws-rds/terraform.tfstate retention-days: 1 - name: Get RDS host id: pghost run: | terraform output -raw db_instance_address test-aws-rds: runs-on: ubuntu-latest needs: [setup-aws-rds, build] steps: <<- restore_cache() >> # Run the test - name: Test env: EDGEDB_TEST_BACKEND_DSN: postgres://edbtest:${{ secrets.AWS_RDS_PASSWORD }}@${{ needs.setup-aws-rds.outputs.pghost }}/postgres run: | edb server --bootstrap-only --backend-dsn=$EDGEDB_TEST_BACKEND_DSN --testmode edb test -j2 -v --backend-dsn=$EDGEDB_TEST_BACKEND_DSN teardown-aws-rds: runs-on: ubuntu-latest needs: test-aws-rds if: ${{ always() }} defaults: run: working-directory: .github/aws-rds steps: << setup_terraform()|indent(2) >> << setup_aws_creds()|indent(2) >> - name: Restore Terraform state uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: aws-rds-tfstate path: .github/aws-rds - name: Destroy AWS RDS run: terraform destroy -auto-approve env: TF_VAR_sg_id: ${{ secrets.AWS_SECURITY_GROUP }} TF_VAR_password: ${{ secrets.AWS_RDS_PASSWORD }} - name: Overwrite Terraform state uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: aws-rds-tfstate path: .github/aws-rds/terraform.tfstate retention-days: 1 setup-do-database: runs-on: ubuntu-latest defaults: run: working-directory: .github/do-database steps: << setup_terraform()|indent(2) >> - name: Setup DigitalOcean Database env: TF_VAR_do_token: ${{ secrets.DIGITALOCEAN_TOKEN }} run: | terraform apply -auto-approve - name: Store Terraform state if: ${{ always() }} uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: do-database-tfstate path: .github/do-database/terraform.tfstate retention-days: 1 test-do-database: runs-on: ubuntu-latest needs: [setup-do-database, build] steps: 
<<- restore_cache() >> - name: Setup Terraform uses: hashicorp/setup-terraform@v1 - name: Initialize Terraform working-directory: .github/do-database run: terraform init - name: Restore Terraform state uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: do-database-tfstate path: .github/do-database - name: Get Database host id: pghost working-directory: .github/do-database run: | terraform output -raw db_instance_address - name: Get Database port id: pgport working-directory: .github/do-database run: | terraform output -raw db_instance_port - name: Get Database user id: pguser working-directory: .github/do-database run: | terraform output -raw db_instance_user - name: Get Database password id: pgpass working-directory: .github/do-database run: | terraform output -raw db_instance_password - name: Get Database dbname id: pgdatabase working-directory: .github/do-database run: | terraform output -raw db_instance_database # Run the test - name: Test env: EDGEDB_TEST_BACKEND_DSN: postgres://${{ steps.pguser.outputs.stdout }}:${{ steps.pgpass.outputs.stdout }}@${{ steps.pghost.outputs.stdout }}:${{ steps.pgport.outputs.stdout }}/${{ steps.pgdatabase.outputs.stdout }} run: | edb server --bootstrap-only --backend-dsn=$EDGEDB_TEST_BACKEND_DSN --testmode edb test -j2 -v --backend-dsn=$EDGEDB_TEST_BACKEND_DSN teardown-do-database: runs-on: ubuntu-latest needs: test-do-database if: ${{ always() }} defaults: run: working-directory: .github/do-database steps: << setup_terraform()|indent(2) >> - name: Restore Terraform state uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: do-database-tfstate path: .github/do-database - name: Destroy DigitalOcean Database run: terraform destroy -auto-approve env: TF_VAR_do_token: ${{ secrets.DIGITALOCEAN_TOKEN }} - name: Overwrite Terraform state uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: do-database-tfstate path: 
.github/do-database/terraform.tfstate retention-days: 1 setup-gcp-cloud-sql: runs-on: ubuntu-latest outputs: pghost: ${{ steps.pghost.outputs.stdout }} defaults: run: working-directory: .github/gcp-cloud-sql steps: << setup_terraform()|indent(2) >> << setup_gcp_creds()|indent(2) >> - name: Setup GCP Cloud SQL env: TF_VAR_password: ${{ secrets.AWS_RDS_PASSWORD }} run: | terraform apply -auto-approve - name: Store Terraform state if: ${{ always() }} uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: gcp-cloud-sql-tfstate path: .github/gcp-cloud-sql/terraform.tfstate retention-days: 1 - name: Get Cloud SQL host id: pghost run: | terraform output -raw db_instance_address test-gcp-cloud-sql: runs-on: ubuntu-latest needs: [setup-gcp-cloud-sql, build] steps: <<- restore_cache() >> # Run the test - name: Test env: EDGEDB_TEST_BACKEND_DSN: postgres://postgres:${{ secrets.AWS_RDS_PASSWORD }}@${{ needs.setup-gcp-cloud-sql.outputs.pghost }}/postgres run: | edb server --bootstrap-only --backend-dsn=$EDGEDB_TEST_BACKEND_DSN --testmode edb test -j2 -v --backend-dsn=$EDGEDB_TEST_BACKEND_DSN teardown-gcp-cloud-sql: runs-on: ubuntu-latest needs: test-gcp-cloud-sql if: ${{ always() }} defaults: run: working-directory: .github/gcp-cloud-sql steps: << setup_terraform()|indent(2) >> << setup_gcp_creds()|indent(2) >> - name: Restore Terraform state uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: gcp-cloud-sql-tfstate path: .github/gcp-cloud-sql - name: Destroy GCP Cloud SQL run: terraform destroy -auto-approve env: TF_VAR_password: ${{ secrets.AWS_RDS_PASSWORD }} - name: Overwrite Terraform state uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: gcp-cloud-sql-tfstate path: .github/gcp-cloud-sql/terraform.tfstate retention-days: 1 setup-aws-aurora: runs-on: ubuntu-latest outputs: pghost: ${{ steps.pghost.outputs.stdout }} defaults: run: working-directory: 
.github/aws-aurora steps: << setup_terraform()|indent(2) >> << setup_aws_creds()|indent(2) >> - name: Setup AWS RDS Aurora env: TF_VAR_sg_id: ${{ secrets.AWS_SECURITY_GROUP }} TF_VAR_password: ${{ secrets.AWS_RDS_PASSWORD }} TF_VAR_vpc_id: ${{ secrets.AWS_VPC_ID }} run: | terraform apply -auto-approve - name: Store Terraform state if: ${{ always() }} uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: aws-aurora-tfstate path: .github/aws-aurora/terraform.tfstate retention-days: 1 - name: Get RDS Aurora host id: pghost run: | terraform output -raw rds_cluster_endpoint test-aws-aurora: runs-on: ubuntu-latest needs: [setup-aws-aurora, build] steps: <<- restore_cache() >> # Run the test - name: Test env: EDGEDB_TEST_BACKEND_DSN: postgres://edbtest:${{ secrets.AWS_RDS_PASSWORD }}@${{ needs.setup-aws-aurora.outputs.pghost }}/postgres run: | edb server --bootstrap-only --backend-dsn=$EDGEDB_TEST_BACKEND_DSN --testmode edb test -j1 -v --backend-dsn=$EDGEDB_TEST_BACKEND_DSN teardown-aws-aurora: runs-on: ubuntu-latest needs: test-aws-aurora if: ${{ always() }} defaults: run: working-directory: .github/aws-aurora steps: << setup_terraform()|indent(2) >> << setup_aws_creds()|indent(2) >> - name: Restore Terraform state uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: aws-aurora-tfstate path: .github/aws-aurora - name: Destroy AWS RDS Aurora run: terraform destroy -auto-approve env: TF_VAR_sg_id: ${{ secrets.AWS_SECURITY_GROUP }} TF_VAR_password: ${{ secrets.AWS_RDS_PASSWORD }} TF_VAR_vpc_id: ${{ secrets.AWS_VPC_ID }} - name: Overwrite Terraform state uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: aws-aurora-tfstate path: .github/aws-aurora/terraform.tfstate retention-days: 1 setup-heroku-postgres: runs-on: ubuntu-latest defaults: run: working-directory: .github/heroku-postgres steps: << setup_terraform()|indent(2) >> - name: Setup Heroku Postgres 
env: HEROKU_API_KEY: ${{ secrets.HEROKU_API_KEY }} HEROKU_EMAIL: ${{ secrets.HEROKU_EMAIL }} run: | terraform apply -auto-approve - name: Store Terraform state if: ${{ always() }} uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: heroku-postgres-tfstate path: .github/heroku-postgres/terraform.tfstate retention-days: 1 test-heroku-postgres: runs-on: ubuntu-latest needs: [setup-heroku-postgres, build] steps: <<- restore_cache() >> - name: Setup Terraform uses: hashicorp/setup-terraform@v1 - name: Initialize Terraform working-directory: .github/heroku-postgres run: terraform init - name: Restore Terraform state uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: heroku-postgres-tfstate path: .github/heroku-postgres - name: Get Heroku Postgres DSN id: pgdsn working-directory: .github/heroku-postgres run: | terraform output -raw heroku_postgres_dsn # Run the test - name: Test env: EDGEDB_TEST_BACKEND_VENDOR: heroku-postgres EDGEDB_TEST_BACKEND_DSN: ${{ steps.pgdsn.outputs.stdout }} run: | edb server --bootstrap-only --backend-dsn=$EDGEDB_TEST_BACKEND_DSN --testmode edb test -j1 -v --backend-dsn=$EDGEDB_TEST_BACKEND_DSN teardown-heroku-postgres: runs-on: ubuntu-latest needs: test-heroku-postgres if: ${{ always() }} defaults: run: working-directory: .github/heroku-postgres steps: << setup_terraform()|indent(2) >> << setup_aws_creds()|indent(2) >> - name: Restore Terraform state uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: heroku-postgres-tfstate path: .github/heroku-postgres - name: Destroy Heroku Postgres run: terraform destroy -auto-approve env: HEROKU_API_KEY: ${{ secrets.HEROKU_API_KEY }} HEROKU_EMAIL: ${{ secrets.HEROKU_EMAIL }} - name: Overwrite Terraform state uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: heroku-postgres-tfstate path: .github/heroku-postgres/terraform.tfstate 
retention-days: 1 workflow-notifications: if: failure() && github.event_name != 'pull_request' name: Notify in Slack on failures needs: - setup-aws-rds - test-aws-rds - teardown-aws-rds - setup-do-database - test-do-database - teardown-do-database - setup-gcp-cloud-sql - test-gcp-cloud-sql - teardown-gcp-cloud-sql - setup-aws-aurora - test-aws-aurora - teardown-aws-aurora - setup-heroku-postgres - test-heroku-postgres - teardown-heroku-postgres runs-on: ubuntu-latest permissions: actions: 'read' steps: - name: Slack Workflow Notification uses: Gamesight/slack-workflow-status@26a36836c887f260477432e4314ec3490a84f309 with: repo_token: ${{secrets.GITHUB_TOKEN}} slack_webhook_url: ${{secrets.ACTIONS_SLACK_WEBHOOK_URL}} name: 'Workflow notifications' icon_emoji: ':hammer:' include_jobs: 'on-failure' ================================================ FILE: .github/workflows.src/tests.patches.targets.yml ================================================ data: ================================================ FILE: .github/workflows.src/tests.patches.tpl.yml ================================================ <% from "tests.inc.yml" import build, calc_cache_key, restore_cache -%> name: Tests of patching old EdgeDB Versions on: workflow_dispatch: inputs: {} pull_request: branches: - release/* push: branches: - patch-test* - release/* jobs: build: runs-on: ubuntu-latest steps: <%- call build() -%> - name: Compute cache keys run: | << calc_cache_key()|indent >> <%- endcall %> compute-versions: runs-on: ubuntu-latest outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} steps: - uses: actions/checkout@v4 - id: set-matrix name: Compute versions to run on run: python3 .github/scripts/patches/compute-versions.py test: runs-on: ubuntu-latest needs: [build, compute-versions] strategy: fail-fast: false matrix: ${{fromJSON(needs.compute-versions.outputs.matrix)}} steps: <<- restore_cache() >> # Run the test - name: Download an earlier database version and set up an instance run: | wget -q 
"${{ matrix.edgedb-url }}" tar xzf ${{ matrix.edgedb-basename }}-${{ matrix.edgedb-version }}.tar.gz ${{ matrix.edgedb-basename }}-${{ matrix.edgedb-version }}/bin/edgedb-server -D test-dir --bootstrap-only --testmode - name: Create databases on the older version if: ${{ matrix.make-dbs }} run: python3 .github/scripts/patches/create-databases.py ${{ matrix.edgedb-basename }}-${{ matrix.edgedb-version }}/bin/edgedb-server - name: Run tests with instance created on an older version run: | # Run the server explicitly first to do the upgrade, since edb test # has timeouts. edb server --bootstrap-only --data-dir test-dir # Should we run *all* the tests? edb test -j2 -v --data-dir test-dir tests/test_edgeql_json.py tests/test_edgeql_casts.py tests/test_edgeql_functions.py tests/test_edgeql_expressions.py tests/test_edgeql_policies.py tests/test_edgeql_vector.py tests/test_edgeql_scope.py tests/test_http_ext_auth.py - name: Test downgrading a database after an upgrade if: ${{ !contains(matrix.edgedb-version, '-rc') && !contains(matrix.edgedb-version, '-beta') }} env: EDGEDB_VERSION: ${{ matrix.edgedb-version }} run: python3 .github/scripts/patches/test-downgrade.py workflow-notifications: if: failure() && github.event_name != 'pull_request' name: Notify in Slack on failures needs: - build - compute-versions - test runs-on: ubuntu-latest permissions: actions: 'read' steps: - name: Slack Workflow Notification uses: Gamesight/slack-workflow-status@26a36836c887f260477432e4314ec3490a84f309 with: repo_token: ${{secrets.GITHUB_TOKEN}} slack_webhook_url: ${{secrets.ACTIONS_SLACK_WEBHOOK_URL}} name: 'Workflow notifications' icon_emoji: ':hammer:' include_jobs: 'on-failure' ================================================ FILE: .github/workflows.src/tests.pg-versions.targets.yml ================================================ data: ================================================ FILE: .github/workflows.src/tests.pg-versions.tpl.yml ================================================ 
<% from "tests.inc.yml" import build, calc_cache_key, restore_cache, setup_terraform -%> name: Tests on PostgreSQL Versions on: schedule: - cron: "0 3 * * *" workflow_dispatch: inputs: {} push: branches: - pg-test jobs: build: runs-on: ubuntu-latest steps: <%- call build() -%> - name: Compute cache keys run: | << calc_cache_key()|indent >> <%- endcall %> test: runs-on: ubuntu-latest needs: build strategy: fail-fast: false matrix: postgres-version: [ 17 ] single-mode: - '' # These are very broken. Disabling them for now until we # decide whether to fix them or give up. # - 'NOCREATEDB NOCREATEROLE' # - 'CREATEDB NOCREATEROLE' multi-tenant-mode: [ '' ] include: - postgres-version: 14 single-mode: '' multi-tenant-mode: '' - postgres-version: 15 single-mode: '' multi-tenant-mode: '' - postgres-version: 16 single-mode: '' multi-tenant-mode: '' - postgres-version: 17 single-mode: '' multi-tenant-mode: 'remote-compiler' - postgres-version: 17 single-mode: '' multi-tenant-mode: 'multi-tenant' services: postgres: image: pgvector/pgvector:0.7.4-pg${{ matrix.postgres-version }} env: POSTGRES_PASSWORD: postgres options: >- --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 --name postgres ports: - 5432:5432 steps: - name: Trust pgvector extension uses: docker://docker with: args: docker exec postgres sed -i $a\trusted=true /usr/share/postgresql/${{ matrix.postgres-version }}/extension/vector.control <<- restore_cache() >> # Run the test - name: Setup single mode role and database if: ${{ matrix.single-mode }} shell: python run: | import asyncio import subprocess from edb.server.pgcluster import get_pg_bin_dir async def main(): psql = await get_pg_bin_dir() / "psql" dsn = "postgres://postgres:postgres@localhost/postgres" script = """\ CREATE ROLE singles; ALTER ROLE singles WITH LOGIN PASSWORD 'test' NOSUPERUSER ${{ matrix.single-mode }}; CREATE DATABASE singles OWNER singles; REVOKE ALL ON DATABASE singles FROM PUBLIC; GRANT CONNECT ON 
DATABASE singles TO singles; GRANT ALL ON DATABASE singles TO singles; """ subprocess.run( [str(psql), dsn], check=True, text=True, input=script, ) asyncio.run(main()) - name: Test env: EDGEDB_TEST_POSTGRES_VERSION: ${{ matrix.postgres-version }} run: | if [[ "${{ matrix.single-mode }}" ]]; then export EDGEDB_TEST_BACKEND_DSN=postgres://singles:test@localhost/singles else export EDGEDB_TEST_BACKEND_DSN=postgres://postgres:postgres@localhost/postgres fi if [[ "${{ matrix.multi-tenant-mode }}" == "remote-compiler" ]]; then export EDGEDB_TEST_REMOTE_COMPILER=localhost:5660 export _EDGEDB_SERVER_COMPILER_POOL_SECRET=secret __EDGEDB_DEVMODE=1 edgedb-server compiler --pool-size 2 & fi edb server --bootstrap-only --backend-dsn=$EDGEDB_TEST_BACKEND_DSN --testmode if [[ "${{ matrix.multi-tenant-mode }}" == "multi-tenant" ]]; then export EDGEDB_SERVER_MULTITENANT_CONFIG_FILE=/tmp/edb.mt.json echo "{\"localhost\":{\"instance-name\":\"localtest\",\"backend-dsn\":\"$EDGEDB_TEST_BACKEND_DSN\",\"admin\":true,\"max-backend-connections\":10}}" > /tmp/edb.mt.json fi if [[ "${{ matrix.single-mode }}" == *"NOCREATEDB"* ]]; then edb test -j1 -v --backend-dsn=$EDGEDB_TEST_BACKEND_DSN else edb test -j2 -v --backend-dsn=$EDGEDB_TEST_BACKEND_DSN fi workflow-notifications: if: failure() && github.event_name != 'pull_request' name: Notify in Slack on failures needs: - build - test runs-on: ubuntu-latest permissions: actions: 'read' steps: - name: Slack Workflow Notification uses: Gamesight/slack-workflow-status@26a36836c887f260477432e4314ec3490a84f309 with: repo_token: ${{secrets.GITHUB_TOKEN}} slack_webhook_url: ${{secrets.ACTIONS_SLACK_WEBHOOK_URL}} name: 'Workflow notifications' icon_emoji: ':hammer:' include_jobs: 'on-failure' ================================================ FILE: .github/workflows.src/tests.pool.targets.yml ================================================ data: ================================================ FILE: .github/workflows.src/tests.pool.tpl.yml 
================================================ <% from "tests.inc.yml" import build, calc_cache_key -%> name: Pool Simulation Test on: push: branches: - master - pool-test paths: - 'edb/server/connpool/**' - 'edb/server/conn_pool/**' - 'tests/test_server_pool.py' - '.github/workflows/tests.pool.yml' pull_request: branches: - master paths: - 'edb/server/connpool/**' - 'edb/server/conn_pool/**' - 'tests/test_server_pool.py' - '.github/workflows/tests.pool.yml' jobs: test: runs-on: ubuntu-latest concurrency: pool-test steps: <%- call build() -%> - name: Compute cache keys run: | << calc_cache_key()|indent >> <%- endcall %> - uses: actions/checkout@v4 if: startsWith(github.ref, 'refs/heads') with: repository: edgedb/edgedb-pool-simulation path: pool-simulation token: ${{ secrets.GITHUB_CI_BOT_TOKEN }} - name: Run the pool simulation test env: PYTHONPATH: . SIMULATION_CI: yes TIME_SCALE: 10 run: | mkdir -p pool-simulation/reports python tests/test_server_pool.py - uses: EndBug/add-and-commit@v7.0.0 if: ${{ always() }} continue-on-error: true with: branch: main cwd: pool-simulation author_name: github-actions author_email: 41898282+github-actions[bot]@users.noreply.github.com ================================================ FILE: .github/workflows.src/tests.reflection.targets.yml ================================================ data: ================================================ FILE: .github/workflows.src/tests.reflection.tpl.yml ================================================ <% from "tests.inc.yml" import build, calc_cache_key, restore_cache -%> name: Tests with reflection validation on: schedule: - cron: "0 3 * * *" workflow_dispatch: inputs: {} push: branches: - "REFL-*" jobs: build: runs-on: ubuntu-latest steps: <%- call build() -%> - name: Compute cache keys env: GIST_TOKEN: ${{ secrets.CI_BOT_GIST_TOKEN }} run: | << calc_cache_key()|indent >> <%- endcall %> test: needs: build runs-on: ubuntu-latest steps: <<- restore_cache() >> # Run the test - name: Test 
env: EDGEDB_DEBUG_DELTA_VALIDATE_REFLECTION: 1 run: | edb test -j2 -v workflow-notifications: if: failure() && github.event_name != 'pull_request' name: Notify in Slack on failures needs: - build - test runs-on: ubuntu-latest permissions: actions: 'read' steps: - name: Slack Workflow Notification uses: Gamesight/slack-workflow-status@26a36836c887f260477432e4314ec3490a84f309 with: repo_token: ${{secrets.GITHUB_TOKEN}} slack_webhook_url: ${{secrets.ACTIONS_SLACK_WEBHOOK_URL}} name: 'Workflow notifications' icon_emoji: ':hammer:' include_jobs: 'on-failure' ================================================ FILE: .github/workflows.src/tests.targets.yml ================================================ data: ================================================ FILE: .github/workflows.src/tests.tpl.yml ================================================ <% from "tests.inc.yml" import init, build, calc_cache_key, install_python_requirements, restore_cache -%> name: Tests on: push: branches: - master - ci - "release/*" pull_request: branches: - '**' schedule: - cron: "0 */3 * * *" jobs: build: runs-on: ubuntu-latest steps: <%- call build() -%> - name: Compute cache keys and download the running times log env: GIST_TOKEN: ${{ secrets.CI_BOT_GIST_TOKEN }} run: | << calc_cache_key()|indent >> curl \ -H "Accept: application/vnd.github.v3+json" \ -u edgedb-ci:$GIST_TOKEN \ https://api.github.com/gists/8b722a65397f7c4c0df72f5394efa04c \ | jq '.files."time_stats.csv".raw_url' \ | xargs curl > shared-artifacts/time_stats.csv <%- endcall %> cargo-test: needs: build runs-on: ubuntu-latest steps: << init() >> << install_python_requirements() >> - name: Download cache key uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: shared-artifacts path: shared-artifacts - name: Generate environment variables run: | echo BUILD_TEMP=$(python setup.py -q ci_helper --type build_temp) >> $GITHUB_ENV - name: Handle Rust extensions build cache uses: actions/cache@v4 
id: rust-cache with: path: ${{ env.BUILD_TEMP }}/rust/extensions key: edb-rust-build-v1-${{ hashFiles('shared-artifacts/rust_cache_key.txt') }} - name: Install Rust toolchain uses: dsherret/rust-toolchain-file@v1 - name: Cargo test env: CARGO_TARGET_DIR: ${{ env.BUILD_TEMP }}/rust/extensions CARGO_HOME: ${{ env.BUILD_TEMP }}/rust/extensions/cargo_home run: cargo test --all-features python-test: needs: build runs-on: ubuntu-latest strategy: fail-fast: false matrix: shard: [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, ] steps: <<- restore_cache() >> # Run the test - name: Install Rust toolchain uses: dsherret/rust-toolchain-file@v1 - name: Test env: SHARD: ${{ matrix.shard }} EDGEDB_TEST_REPEATS: 1 run: | mkdir -p results/ cp shared-artifacts/time_stats.csv results/running_times_${SHARD}.csv edb test --jobs 2 --verbose --shard ${SHARD}/16 \ --running-times-log=results/running_times_${SHARD}.csv \ --result-log=results/result_${SHARD}.json - name: Upload test results uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 if: ${{ always() }} with: name: python-test-results-${{ matrix.shard }} path: results retention-days: 1 python-test-list: needs: build runs-on: ubuntu-latest steps: <<- restore_cache() >> # List tests and upload - name: Generate complete list of tests for verification env: SHARD: ${{ matrix.shard }} EDGEDB_TEST_REPEATS: 1 run: | edb test --list > shared-artifacts/all_tests.txt - name: Upload list of tests uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: name: test-list path: shared-artifacts retention-days: 1 test-conclusion: needs: [cargo-test, python-test, python-test-list] runs-on: ubuntu-latest if: ${{ always() }} steps: - name: Set up Python uses: actions/setup-python@v5 with: python-version: '3.12.2' - name: Install Python deps run: | python -m pip install requests click - uses: actions/checkout@v4 with: submodules: false - name: Download python-test results uses: 
actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: pattern: python-test-results-* merge-multiple: true path: results # Render results and exit if they were unsuccessful - name: Render results run: | python edb/tools/test/results.py 'results/result_*.json' - name: Download shared artifacts uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: shared-artifacts path: shared-artifacts - name: Download test list uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: test-list path: shared-artifacts - name: Merge stats and verify tests completion shell: python env: GIST_TOKEN: ${{ secrets.CI_BOT_GIST_TOKEN }} GIT_REF: ${{ github.ref }} run: | import csv import glob import io import os import requests orig = {} new = {} all_tests = set() with open("shared-artifacts/time_stats.csv") as f: for name, t, c in csv.reader(f): assert name not in orig, "duplicate test name in original stats!" orig[name] = (t, int(c)) with open("shared-artifacts/all_tests.txt") as f: for line in f: assert line not in all_tests, "duplicate test name in this run!" all_tests.add(line.strip()) for new_file in glob.glob("results/running_times_*.csv"): with open(new_file) as f: for name, t, c in csv.reader(f): if int(c) > orig.get(name, (0, 0))[1]: if name.startswith("setup::"): new[name] = (t, c) else: assert name not in new, f"duplicate test! {name}" new[name] = (t, c) all_tests.remove(name) assert not all_tests, "Tests not run! 
\n" + "\n".join(all_tests) if os.environ["GIT_REF"] == "refs/heads/master": buf = io.StringIO() writer = csv.writer(buf) orig.update(new) for k, v in sorted(orig.items()): writer.writerow((k,) + v) resp = requests.patch( "https://api.github.com/gists/8b722a65397f7c4c0df72f5394efa04c", headers={"Accept": "application/vnd.github.v3+json"}, auth=("edgedb-ci", os.environ["GIST_TOKEN"]), json={"files": {"time_stats.csv": {"content": buf.getvalue()}}}, ) resp.raise_for_status() workflow-notifications: if: failure() && github.event_name != 'pull_request' name: Notify in Slack on failures needs: - test-conclusion runs-on: ubuntu-latest permissions: actions: 'read' steps: - name: Slack Workflow Notification uses: Gamesight/slack-workflow-status@26a36836c887f260477432e4314ec3490a84f309 with: repo_token: ${{secrets.GITHUB_TOKEN}} slack_webhook_url: ${{secrets.ACTIONS_SLACK_WEBHOOK_URL}} name: 'Workflow notifications' icon_emoji: ':hammer:' include_jobs: 'on-failure' ================================================ FILE: .gitignore ================================================ *._* *.pyc *.pyo *.o *.so *.dylib .vscode/ .zed/ .helix/ *~ .#* .*.swp .DS_Store \#*# /test*.py /.local /perf.data* /build /target /tmp __pycache__/ .d8_history /.venv /.eggs /*.egg /*.egg-info /dist /.cache docs/_build /AUTHORS /ChangeLog /tests/dumps/**/*_dev*.dump /edb/_buildmeta.py /.coverage* /htmlcov *.code-workspace /.pytest_cache /.mypy_cache /.vagga /.dmypy.json /compile_commands.json /pyrightconfig.json ================================================ FILE: .gitmodules ================================================ [submodule "postgres"] path = postgres url = https://github.com/geldata/postgres.git ignore = untracked [submodule "edb/server/pgproto"] path = edb/server/pgproto url = https://github.com/MagicStack/py-pgproto.git [submodule "edb/pgsql/parser/libpg_query"] path = edb/pgsql/parser/libpg_query url = https://github.com/geldata/libpg_query.git 
================================================ FILE: .mailmap ================================================ Elvis Pranskevichus Yury Selivanov Yury Selivanov Yuri Selivanov Victor Petrovykh Victor Petrovykh Vicor Petrovykh ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Contributor Covenant Code of Conduct ## Our Pledge In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. ## Our Standards Examples of behavior that contributes to creating a positive environment include: * Using welcoming and inclusive language * Being respectful of differing viewpoints and experiences * Gracefully accepting constructive criticism * Focusing on what is best for the community * Showing empathy towards other community members Examples of unacceptable behavior by participants include: * The use of sexualized language or imagery and unwelcome sexual attention or advances * Trolling, insulting/derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or electronic address, without explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Our Responsibilities Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 
Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. ## Scope This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. ## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at coc@edgedb.com. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html [homepage]: https://www.contributor-covenant.org For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq ================================================ FILE: CONTRIBUTING.rst ================================================ How to contribute to Gel ======================== Thank you for contributing to Gel! 
We love our open source community and want to foster a healthy contributor ecosystem. To make sure the project can continue to improve quickly, we have a few guidelines designed to make it easier for your contributions to make it into the project. These are guidelines rather than hard rules. If you want to submit a pull request that strays from these, it might be a good idea to start a discussion about it first. Otherwise, it's possible your pull request might not be merged. All contributions ----------------- - **Avoid making pull requests that do not have an associated Github Issue.** This could be an already existing issue or one you create yourself when you discover the problem. This will allow the team to help you scope your solution, warn you of potential gotchas, or give you a heads-up on solutions that are likely not feasible. It's a good idea to mention in the issue that you'd like to contribute code to resolve the issue. **If you're fixing something trivial like a typo,** an associated issue isn't necessary. - **Write good commit messages.** The subject of your commit message — that's the first line — should tell us *what* you did. The body of your message — that's the rest of it — should tell us *why* you did it (unless that's self-evident). Contributing code -------------------------- - **Pull requests without thorough testing are not likely to be merged.** If you're not sure if yours is well-tested enough, go ahead and submit. We can help guide you to the finish line. Contributing documentation -------------------------- - **Avoid changes that don't fix an obvious mistake or add clarity.** This is subjective, but try to look at your changes with a critical eye. Do they fix errors in the original like misspellings or typos? Do they make existing prose more clear or accessible while maintaining accuracy? If you answered "yes" to either of those questions, this might be a great addition to our docs! 
  If not, consider starting a discussion instead to see if your changes might
  be the exception to this guideline before submitting.

- **Keep commits and pull requests small.** We get it. It's more convenient
  to throw all your changes into a single pull request or even into a single
  commit. The problem is that, if some of the changes are good and others
  don't quite work, having everything in one bucket makes it harder to filter
  out the great changes from those that need more work.

- **Make spelling and grammar fixes in a separate pull request from any
  content changes.** These changes are quick to check and important to anyone
  reading the docs. We want to make sure they hit the live documentation as
  quickly as possible without being bogged down by other changes that require
  more intensive review.

Please see Gel's guide for `building documentation `_ from source.

Documentation style
~~~~~~~~~~~~~~~~~~~

- **Lines should be no longer than 79 characters.**

- **Remove trailing whitespace or whitespace on empty lines.**

- **Surround references to parameter names with asterisks.** You may be
  tempted to surround parameter names with double backticks (````param````).
  We avoid that in favor of ``*param*``, in order to distinguish between
  parameter references and inline code (which *should* be surrounded by
  double backticks).

- **Gel is singular.** Choose "Gel is" over "Gel are" and "Gel does" over
  "Gel do."

- **Use American English spellings.** Choose "color" over "colour" and
  "organize" over "organise."

- **Use the Oxford comma.** When delineating a series, place a comma between
  each item in the series, even the one with the conjunction. Use "eggs,
  bacon, and juice" rather than "eggs, bacon and juice."

- **Write in the simplest prose that is still accurate and expresses
  everything you need to convey.** You may be tempted to write documentation
  that sounds like a computer science textbook. Sometimes that's necessary,
  but in most cases, it isn't.
Prioritize accuracy first and accessibility a close second. - **Be careful using words that have a special meaning in the context of Gel.** In casual speech or writing, you might talk about a "set" of something in a generic sense. Using the word this way in Gel documentation might easily be interpreted as a reference to Gel's `sets `. Avoid this kind of casual usage of key terms. ================================================ FILE: Cargo.toml ================================================ [workspace] members = [ "edb/edgeql-parser", "edb/edgeql-parser/edgeql-parser-derive", "edb/edgeql-parser/edgeql-parser-python", "edb/graphql-rewrite", "edb/server/_rust_native", "rust/conn_pool", "rust/gel-http", "rust/pgrust", "rust/pyo3_util", ] resolver = "2" [workspace.dependencies] pyo3 = { version = "0.26", features = ["extension-module", "serde", "macros"] } tokio = { version = "1", features = ["rt", "rt-multi-thread", "macros", "time", "sync", "net", "io-util"] } tracing = "0.1.40" tracing-subscriber = { version = "0.3.20", features = ["registry", "env-filter"] } gel-auth = { version = "=0.1.6" } gel-stream = { version = "=0.4.3" } gel-protocol = { version = "=0.8.5" } gel-jwt = { version = "=0.1.4" } gel-db-protocol = { version = "=0.1.2" } gel-pg-protocol = { version = "=0.1.1" } gel-pg-captive = { version = "=0.1.1" } gel-dsn = { version = "=0.2.14" } conn_pool = { path = "rust/conn_pool" } pgrust = { path = "rust/pgrust" } gel-http = { path = "rust/gel-http" } pyo3_util = { path = "rust/pyo3_util" } [profile.release] debug = true lto = true [workspace.lints.rust] unexpected_cfgs = { level = "warn", check-cfg = ['cfg(never)'] } [patch.crates-io] openssl-probe = { git = "https://github.com/edgedb/openssl-probe/", rev = "e5ed593600d1f8128629565d349682f54b3a8b57" } ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS 
FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
================================================ FILE: MANIFEST.in ================================================ recursive-include edb *.edgeql *.esdl *.py *.txt recursive-include tests *.edgeql *.esdl *.py include LICENSE README.md logo.svg recursive-include edb/edgeql-parser * recursive-include edb/edgeql-parser/edgeql-parser-python * recursive-include edb/server/protocol/auth_ext/_static * ================================================ FILE: Makefile ================================================ .PHONY: build docs cython postgres postgres-ext pygments build-reqs .DEFAULT_GOAL := build SPHINXOPTS:="-W -n" BUILD_REQS_SCRIPT='print("\x00".join(__import__("build").ProjectBuilder(".").build_system_requires))' build-reqs: python -m pip install --no-build-isolation build python -c $(BUILD_REQS_SCRIPT) | xargs -0 python -m pip install --no-build-isolation cython: build-reqs find edb -name '*.pyx' | xargs touch BUILD_EXT_MODE=py-only python setup.py build_ext --inplace # Just rebuild actually changed cython. This *should* work, since # that is how build systems are supposed to be, but it sometimes # fails in annoying ways. 
cython-fast: build-reqs BUILD_EXT_MODE=py-only python setup.py build_ext --inplace rust: build-reqs BUILD_EXT_MODE=rust-only python setup.py build_ext --inplace cli: build-reqs python setup.py build_cli docs: build-reqs find docs -name '*.rst' | xargs touch $(MAKE) -C docs html SPHINXOPTS=$(SPHINXOPTS) BUILDDIR="../build" postgres: build-reqs python setup.py build_postgres parsers: python setup.py build_parsers --inplace libpg-query: python setup.py build_libpg_query ui: build-reqs python setup.py build_ui pygments: build-reqs out=$$(edb gen-meta-grammars edgeql) && \ echo "$$out" > edb/tools/pygments/edgeql/meta.py casts: build-reqs out=$$(edb gen-cast-table) && \ echo "$$out" > docs/reference/edgeql/casts.csv build: build-reqs find edb -name '*.pyx' | xargs touch pip install --upgrade --editable .[docs,test,language-server] clean: git clean -Xfd -e "!/*.code-workspace" -e "!/*.vscode" ================================================ FILE: NOTICE ================================================ EdgeDB Copyright 2008-present EdgeDB Inc. This product includes software developed by EdgeDB Inc (https://www.edgedb.com/). ================================================ FILE: README.md ================================================

Gel


Learn: build an app with Gel   •   Website   •   Docs   •   Blog   •   Discord   •   Twitter



What is Gel?

Gel is a new kind of database
that takes the best parts of
relational databases, graph
databases, and ORMs. We call it
a graph-relational database.



🧩 Types, not tables 🧩


Schema is the foundation of your application. It should be something you can read, write, and understand.

Forget foreign keys; tabular data modeling is a relic of an older age, and it [isn't compatible](https://en.wikipedia.org/wiki/Object%E2%80%93relational_impedance_mismatch) with modern languages. Instead, Gel thinks about schema the same way you do: as **object types** containing **properties** connected by **links**.

```esdl
type Person {
  required name: str;
}

type Movie {
  required title: str;
  multi actors: Person;
}
```

This example is intentionally simple, but Gel supports everything you'd expect from your database: a strict type system, indexes, constraints, computed properties, stored procedures...the list goes on. Plus it gives you some shiny new features too: link properties, schema mixins, and best-in-class JSON support. Read the [schema docs](https://docs.geldata.com/reference/datamodel) for details.

🌳 Objects, not rows 🌳


Gel's super-powered query language EdgeQL is a ground-up redesign of SQL. EdgeQL queries produce rich, structured objects, not flat lists of rows. Deeply fetching related objects is painless...bye, bye, JOINs.

```esdl
select Movie {
  title,
  actors: {
    name
  }
}
filter .title = "The Matrix"
```

EdgeQL queries are also _composable_; you can use one EdgeQL query as an expression inside another. This property makes things like _subqueries_ and _nested mutations_ a breeze.

```esdl
insert Movie {
  title := "The Matrix Resurrections",
  actors := (
    select Person
    filter .name in {
      'Keanu Reeves',
      'Carrie-Anne Moss',
      'Laurence Fishburne'
    }
  )
}
```

There's a lot more to EdgeQL: a comprehensive standard library, computed properties, polymorphic queries, `with` blocks, transactions, and much more. Read the [EdgeQL docs](https://docs.geldata.com/reference/edgeql) for the full picture.
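To make the contrast with flat rows concrete, here is a small Python sketch of the regrouping step an application typically performs on SQL JOIN output to reach the nested shape that the `select Movie { title, actors: { name } }` query above returns directly. It is not part of Gel or its client libraries; the row layout and the `nest_movies` helper are assumptions made up for illustration.

```python
# Hypothetical flat result of a SQL JOIN: one row per (movie, actor) pair.
FLAT_ROWS = [
    ("The Matrix", "Keanu Reeves"),
    ("The Matrix", "Carrie-Anne Moss"),
    ("The Matrix", "Laurence Fishburne"),
]


def nest_movies(rows):
    """Fold (title, actor_name) rows into one nested object per movie."""
    movies = {}  # dicts keep insertion order, so movies stay in row order
    for title, actor_name in rows:
        movie = movies.setdefault(title, {"title": title, "actors": []})
        movie["actors"].append({"name": actor_name})
    return list(movies.values())


if __name__ == "__main__":
    for movie in nest_movies(FLAT_ROWS):
        print(movie["title"], [actor["name"] for actor in movie["actors"]])
```

With EdgeQL the shape in the query describes this nesting up front, so a helper like this should not be needed in application code.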

🦋 More than a mapper 🦋


While Gel solves the same problems as ORM libraries, it's so much more. It's a full-fledged database with a [powerful and elegant query language](https://docs.geldata.com/reference/edgeql), a [migrations system](https://docs.geldata.com/learn/migrations), a [suite of client libraries](https://docs.geldata.com/reference/clients) in different languages, a [command line tool](https://docs.geldata.com/learn/cli), and a managed [cloud service](https://geldata.com/cloud). The goal is to rethink every aspect of how developers model, migrate, manage, and query their database.

Here's a taste-test of Gel's next-level developer experience: you can install our CLI, spin up an instance, and open an interactive EdgeQL shell with just three commands.

```
$ curl --proto '=https' --tlsv1.2 -sSf https://geldata.com/sh | sh
$ edgedb project init
$ edgedb
edgedb> select "Hello world!"
```

Windows users: use this PowerShell command to install the CLI.

```
PS> iwr https://geldata.com/ps1 -useb | iex
```
## Get started

To start learning about Gel, check out the following resources:

- **[The quickstart](https://docs.geldata.com/learn/quickstart/overview/nextjs)**. If you're just starting out, the 10-minute quickstart guide is the fastest way to get up and running.
- **[Gel Cloud 🌤️](https://www.geldata.com/cloud)**. The most effortless way to host your Gel database in the cloud.
- **The docs.** Jump straight into the docs for [schema modeling](https://docs.geldata.com/reference/datamodel) or [EdgeQL](https://docs.geldata.com/reference/edgeql)!
## Contributing

PRs are always welcome! To get started, follow [this guide](https://docs.geldata.com/resources/guides/contributing) to build Gel from source on your local machine.

[File an issue 👉](https://github.com/geldata/gel/issues/new/choose)
[Start a Discussion 👉](https://github.com/geldata/gel/discussions/new)
[Join the Discord 👉](https://discord.gg/gel)
## License

The code in this repository is developed and distributed under the Apache 2.0 license. See [LICENSE](LICENSE) for details.

================================================
FILE: build_backend.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# This is a straight proxy to setuptools.build_meta backend that exists
# solely because someone thought that in-tree build dependencies should
# require this.

from setuptools.build_meta import *  # noqa

================================================
FILE: dev-notes/concurrent-indexes.py
================================================
#!/usr/bin/env python3

import gel


def create_concurrent_indexes(db, msg_callback=print):
    '''Actually create all "create concurrently" indexes

    The protocol here is to find all the indexes that need to be created,
    and create them with `administer concurrent_index_build()`.
    It's possible that the database will shut down after an index
    creation but before the metadata is updated, in which case
    we might rerun the command later, which is harmless.

    If we stick with this ADMINISTER-based scheme, I figure this code
    would live in the CLI.
''' indexes = db.query(''' select schema::Index { id, expr, subject_name := .{index.id}") ''') def main(): with gel.create_client() as db: create_concurrent_indexes(db) if __name__ == '__main__': main() ================================================ FILE: dev-notes/inplace-upgrades.md ================================================ The inplace upgrade system adds three new flags to edgedb-server. They may (though probably usually won't) be specified together. If any of them is specified, the server will exit after performing the in-place upgrade operations instead of starting up. * ``-inplace-upgrade-prepare `` -- "prepare" an inplace upgrade, using schema information provided in ````. (More about this later.) This will create the new standard library (in a namespace), populate the schema tables with user schemas, and prepare (but not execute) any irreversible scripts for updating the standard library trampolines and fixing up user-defined functions. This operation should not do anything that cannot be backed out. It may be run while an older version of the server is still live. If this is interrupted, crashes, or fails, it *will leave a partially prepared database*. To deal with this, see the next command. The file should be in the format produced by ``tests/inplace-testing/prep-upgrades.py``: a JSON object where the keys are branch names and the values are the results of executing ``administer prepare_upgrade()``. * ``-inplace-upgrade-rollback`` -- Rolls back a prepared upgrade. This works by deleting everything in the newly created schemas. It can rollback partially prepared upgrades. It may be run while an older version of the server is still live. * ``-inplace-upgrade-finalize`` -- Finalizes a prepared upgrade by fully flipping the database to the new version. This flips standard library trampolines, patches user-defined functions, and deletes the old standard library. The old version must not be running. (Though there is not a clear way to enforce this.) 
Finalize does a dry run of each branch's upgrade inside a transaction before making any changes. If this fails, the upgrade may be broken (due to a bug or an incompatibility), and it may still be rolled back. If finalize fails *after* the dry run, once it has started actually finalizing branches, then it *may not* be rolled back. Because all of the upgrades were tested in (reverted) transactions, this *should* only happen in the case of interruption or postgres crash, and it should be safe to retry the finalize. If finalize emitted a message of the form "Finished pivoting branch ''", then the upgrade may not be rolled back; the only way out is through. Rollback will refuse to operate in this case. ----- Suggested procedure: 0.5. Make a backup 1. ``edgedb query 'configure instance set force_database_error := $${"type": "AvailabilityError", "message": "DDL is disabled due to in-place upgrade.", "_scopes": ["ddl"]}$$;'``. This will disable all DDL commands to the database, while leaving it running for both read and write queries. 2. ``tests/inplace-testing/prep-upgrades.py > "upgrade.json"``. This will dump the information needed for upgrade. 3. ``edgedb-server --backend-dsn="$DSN" --inplace-upgrade-prepare upgrade.json``. This will prepare the upgrade. 4. Stop the old edgedb server. 4.5. Make a backup 5. ``edgedb-server --backend-dsn="$DSN" --inplace-upgrade-finalize``. This will finalize the upgrade. 6. Start the new server. 7. ``edgedb query 'configure instance reset force_database_error'`` If there is a failure in step 3 or step 5 *before* a branch has finished pivoting, then it can be rolled back with ``edgedb-server --backend-dsn="$DSN" --inplace-upgrade-rollback``. If there is a failure after a branch has been pivoted, then there is nothing to do but retry it. (And restore from a backup if that doesn't work. That would be a bug, and one that has slipped past at least one line of defence.) 
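The command-line steps and the rollback-versus-retry rule above can be sketched in Python. This is an illustrative sketch, not part of the server or the test suite: the function names `upgrade_commands` and `recovery_action` and the example DSN are hypothetical, and the sketch only builds command lines; it does not run them. Backups, the dump in step 2, and stopping or starting servers are left to the operator.

```python
import shlex

# Query text copied from steps 1 and 7 of the suggested procedure.
DISABLE_DDL = (
    'configure instance set force_database_error := '
    '$${"type": "AvailabilityError", '
    '"message": "DDL is disabled due to in-place upgrade.", '
    '"_scopes": ["ddl"]}$$;'
)
ENABLE_DDL = 'configure instance reset force_database_error'

# Prefix of the message finalize emits once a branch has been pivoted.
PIVOT_MARKER = 'Finished pivoting branch'


def upgrade_commands(dsn, upgrade_file='upgrade.json'):
    """Build (but do not run) the argv lists for steps 1, 3, 5 and 7,
    plus the rollback command. Hypothetical helper for illustration."""
    server = ['edgedb-server', f'--backend-dsn={dsn}']
    return {
        'disable_ddl': ['edgedb', 'query', DISABLE_DDL],
        'prepare': server + ['--inplace-upgrade-prepare', upgrade_file],
        'finalize': server + ['--inplace-upgrade-finalize'],
        'rollback': server + ['--inplace-upgrade-rollback'],
        'enable_ddl': ['edgedb', 'query', ENABLE_DDL],
    }


def recovery_action(failed_step, server_output):
    """Decide what to do after a failed prepare or finalize.

    Rollback is only possible while no branch has finished pivoting;
    once the pivot message has appeared, the only option is to retry.
    """
    if failed_step not in ('prepare', 'finalize'):
        raise ValueError(f'unexpected step: {failed_step!r}')
    if failed_step == 'finalize' and PIVOT_MARKER in server_output:
        return 'retry'
    return 'rollback'


if __name__ == '__main__':
    # Example DSN is a placeholder, not a real connection string.
    for name, argv in upgrade_commands('postgres://localhost/example').items():
        print(f'{name}: {shlex.join(argv)}')
```

Keeping the decision in a pure function makes the "only way out is through" rule easy to check before wiring it into any automation.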
---- Testing notes: Currently, we can only inplace upgrade between full major versions, since we use the major version number to distinguish between the namespaced stdlibs. For testing inplace upgrades, we have a test that applies a patch that bumps the major version number and catalog. TODO: Maybe we should use the catalog number instead, which will make it easier to test between different nightlies. ================================================ FILE: dev-notes/newtype-checklist.md ================================================ This is a checklist of steps needed to add a new type to EdgeDB, along with links to examples of PRs doing the tasks. Core database range PRs: * https://github.com/edgedb/edgedb/pull/3983 * https://github.com/edgedb/edgedb/pull/4020 Core database cal::duration PRs: * https://github.com/edgedb/edgedb/pull/3948 - [ ] JSON handling - [ ] Implement JSON casts if the default Postgres behavior won't work - [ ] Update output.serialize_expr_to_json(), if the default won't work * range: https://github.com/edgedb/edgedb/pull/4008 - [ ] If any new functions or constructors have an implementation that is not just purely a call to a strict function, make sure to test with inputs that are NULL at runtime! Probably the easiest way to generate NULL-at-runtime values is `$0` and then passing in `{}`. * range test example and bugfix: https://github.com/edgedb/edgedb/pull/4207/ - [ ] For compound types, add a schema class in edb/schema/types.py and - [ ] Add mapping to pgsql types in edb/pgsql/types.py - [ ] Add implementations of any relevant functions/operations to `edb/lib`. - [ ] For compound types, add a type descriptor and code for encoding it in edb/server/compiler/sertypes.py. * range: https://github.com/edgedb/edgedb/pull/4016 - [ ] For new scalar types, add it to edb/api/types.txt and edb/graphql/types.py. Run `edb gen-types`.
- [ ] Update all of the first-party language drivers (or get their owners to) - [ ] Python (Fantix/Elvis/Sully) * cal::date_duration: https://github.com/edgedb/edgedb-python/pull/335 * range: https://github.com/edgedb/edgedb-python/pull/332/ - [ ] Go (Frederick) * cal::date_duration: https://github.com/edgedb/edgedb-go/pull/232 - [ ] Javascript (James/Colin) * cal::date_duration: https://github.com/edgedb/edgedb-js/pull/373/ * range: https://github.com/edgedb/edgedb-js/pull/377 - [ ] Rust/CLI (Paul) * This requires updating both the Rust bindings to support the new type and the CLI to properly print it * cal::date_duration: https://github.com/edgedb/edgedb-rust/pull/146, https://github.com/edgedb/edgedb-cli/pull/759 * range: https://github.com/edgedb/edgedb-rust/pull/145, https://github.com/edgedb/edgedb-cli/pull/755 - [ ] Add a field of the new type to the `dump` test for the new version - [ ] Write tests. ================================================ FILE: dev-notes/release-process.md ================================================ # Instructions for releasing a new version Deprecates release instruction from [RFC 2](https://github.com/edgedb/rfcs/blob/master/text/0002-edgedb-release-process.rst). EdgeDB packages are published on https://packages.edgedb.com. They are built in GitHub Actions pipelines, using https://github.com/edgedb/edgedb-pkg. Releases are built from a release branch associated with a major version (i.e. "release/4.x"). At feature freeze, we create this branch. From that moment on, all additional commits will have to be cherry-picked to this branch. Before the major version, we publish "testing releases": - "alpha" (i.e. `v4.0a1`, `v4.0a2`), - "beta" (i.e. `v4.0b1`, `v4.0b2`), - "release candidates" (`v4.0rc1`) that we might promote into the final release. ## Internal Communication Announce on team slack when you are beginning to prepare a release, when a release build has been kicked off, and when the release has succeeded.
Update the thread with any problems and attempted resolutions. Communicate in the other direction as well: make sure the release manager knows of any pending work that you want in a release. "b1", "rc1", and ".0" releases are big deals. Make sure to get signoff before releasing. ## edgedb-ui On release branches, `edgedb-ui` should be pinned to the associated branch. This can be done in `setup.py` with the variable `EDGEDBGUI_COMMIT`. For example, on branch `release/4.x`, it is pinned to `edgedb-ui`'s branch `4.x`. This means any release off `release/4.x` will contain latest commits from `edgedb-ui`'s branch `4.x`. ## Preparing commits for a release For each major release `N`, we have two GitHub labels: `to-backport-N.x` and `backported-N.x`. PRs that need to be backported should be labelled with `to-backport-N.x` for each of the target versions. Once a PR is backported, `to-backport-N.x` should be removed and `backported-N.x` added. Tracking both states makes it easy to tell what needs to be backported and what has been backported. (Historical note: previously we had simply a `backport-N.x` label. This made it easy to ensure that everything that got labelled with `backport` actually got backported, but there was not an at-a-glance way to see if something *had* been backported. Even looking at the issue didn't always tell you, since sometimes we labelled things as `backport` and then thought better of it.) ### Technical helpers The `gh` command line makes a bunch of these operations simple. 
To enumerate all pending backports for a branch: ```bash gh pr list --state all -l to-backport-5.x ``` To adjust labels to mark a PR as backported: ```bash gh pr edit --remove-label to-backport-N.x --add-label backported-N.x ``` A helper shell script to cherry-pick a commit using its PR number: ```bash # this won't work if a PR is not squashed into a single commit function cp-pr { git cherry-pick $(gh pr view $1 --json mergeCommit --jq .mergeCommit.oid) } ``` ### What to backport? Sometimes, people will forget to label the PR to be back-ported, so a good practice is to list all commits since the last release: ``` git show releases/4.x # to see the last commit that has been cherry-picked git log master # find the hash of that commit on master git log hash_of_that_commit..master > ../to-backport.txt ``` Now, one can go through the list and see if the commits are worth back-porting. A few pointers: - Don't backport new features, unless it is high-priority for some reason. - Don't backport docs, since the website is built from master. - Don't backport refactors, since they might introduce bugs and there is no point in improving the codebase of a branch we are not developing on anymore. Disregard this rule early on after the fork of the release branch, since porting refactors will decrease the chances of merge conflicts with other commits later on. - Don't backport "build" commits (updating of build deps, refactoring of the release pipeline), since that might trigger problems in the release process. - If a PR changes: - any of the schema objects (i.e. adding a field to `s_types.Type`) or - a std library object (i.e. changing implementation of `std::round`), - metaschema (i.e. changing a pg function `edgedb.range_to_jsonb`), ... a "patch" needs to be added into `pgsql/patches.py`. This is needed, because minor releases don't require a "dump and restore", so we must apply these changes to existing user databases.
Patches must be tested using this GHA workflow: https://github.com/edgedb/edgedb/actions/workflows/tests-patches.yml ## Release pipeline When you have your commits ready, tag the commit and push: ``` # git tag --sign v4.5 # git push origin releases/4.x --follow-tags ``` Then open the GitHub Actions page and run one of these pipelines: - https://github.com/edgedb/edgedb/actions/workflows/testing.yml - https://github.com/edgedb/edgedb/actions/workflows/release.yml This will kick off a GHA workflow that should take ~3 hours. It will build, test and publish for each of the supported platforms. It will not publish any packages if any of the tests fail. Sometimes, tests will be flaky and just need to be re-run. You can do that with a button top-right. ## Changelog Each major release has a changelog page in the docs (i.e. `docs/changelog/4_x.rst`). It should contain explanations of the new features, which are usually composed by our dev-rel team. Each minor release is just a subsection in the page, as a list of back-ported PRs. Any PRs that fix internal stuff (like our test framework) or are not user facing should not be included in the changelog. Don't forget to include commits released from `edgedb-ui`. These changes need to land on master branch and are not needed on the release branch, so the best course of action is to open a PR to master after kicking off the release pipeline. After that PR is merged, the website needs to be deployed, for the changelog to land on the website (ping dev rel team).
A helper function to generate changelog is:

```python
# I keep this in ../compose-changelog.py

import json
import re
import sys

import requests

BASE_URL = 'https://api.github.com/repos/edgedb/edgedb/compare'


def main():
    if len(sys.argv) < 2:
        print('pass a sha1 hash as a first argument')
        sys.exit(1)

    from_hash = sys.argv[1]
    # Assumption: compare against master when no second ref is given;
    # previously to_hash was left undefined in that case.
    to_hash = sys.argv[2] if len(sys.argv) > 2 else 'master'

    r = requests.get(f'{BASE_URL}/{from_hash}...{to_hash}')
    data = json.loads(r.text)

    for commit in data['commits']:
        message = commit['commit']['message']
        first_line = message.partition('\n\n')[0]
        if commit.get('author'):
            username = '@{}'.format(commit['author']['login'])
        else:
            username = commit['commit']['author']['name']
        sha = commit["sha"][:8]

        # Pick up the PR/issue number referenced in the commit message.
        m = re.search(r'\#(?P<num>\d+)\b', message)
        if m:
            issue_num = m.group('num')
        else:
            issue_num = None

        # Drop the trailing "(#1234)" from the title line.
        first_line = re.sub(r'\(\#(?P<num>\d+)\)', '', first_line)

        print(f'* {first_line}')
        # print(f'  (by {username} in {sha}', end='')
        if issue_num:
            print(f'  (:eql:gh:`#{issue_num}`)')
        print()


if __name__ == '__main__':
    main()
```

```bash
python ../compose-changelog.py v4.5 v4.6 >> docs/changelog/4_x.rst
```

## After the release

The release pipelines will make the new version available at https://packages.edgedb.com. This is enough for it to be installable using the CLI, but other methods of installation need to be kicked off manually:

- our cloud team needs to deploy separate _cloud wizardly groups_,
- docker image needs to be published to https://hub.docker.com,
- Digital Ocean image needs to be published by Frederick,

================================================
FILE: docs/.gitignore
================================================
/_build

================================================
FILE: docs/Makefile
================================================
# Makefile for Sphinx documentation
#

# You can set these variables from the command line.
SPHINXOPTS = SPHINXBUILD = sphinx-build PAPER = BUILDDIR = _build # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) endif # Internal variables. PAPEROPT_a4 = -D latex_paper_size=a4 PAPEROPT_letter = -D latex_paper_size=letter ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . # the i18n builder cannot share the environment and doctrees with the others I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest coverage gettext help: @echo "Please use \`make ' where is one of" @echo " html to make standalone HTML files" @echo " dirhtml to make HTML files named index.html in directories" @echo " singlehtml to make a single large HTML file" @echo " pickle to make pickle files" @echo " json to make JSON files" @echo " htmlhelp to make HTML files and a HTML help project" @echo " qthelp to make HTML files and a qthelp project" @echo " applehelp to make an Apple Help Book" @echo " devhelp to make HTML files and a Devhelp project" @echo " epub to make an epub" @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" @echo " latexpdf to make LaTeX files and run them through pdflatex" @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" @echo " text to make text files" @echo " man to make manual pages" @echo " texinfo to make Texinfo files" @echo " info to make Texinfo files and run them through makeinfo" @echo " gettext to make PO message catalogs" @echo " changes to make an overview of all 
changed/added/deprecated items" @echo " xml to make Docutils-native XML files" @echo " pseudoxml to make pseudoxml-XML files for display purposes" @echo " linkcheck to check all external links for integrity" @echo " doctest to run all doctests embedded in the documentation (if enabled)" @echo " coverage to run coverage check of the documentation (if enabled)" clean: rm -rf $(BUILDDIR)/* html: $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." dirhtml: $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." singlehtml: $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml @echo @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." pickle: $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle @echo @echo "Build finished; now you can process the pickle files." json: $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json @echo @echo "Build finished; now you can process the JSON files." htmlhelp: $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp @echo @echo "Build finished; now you can run HTML Help Workshop with the" \ ".hhp project file in $(BUILDDIR)/htmlhelp." qthelp: $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp @echo @echo "Build finished; now you can run "qcollectiongenerator" with the" \ ".qhcp project file in $(BUILDDIR)/qthelp, like this:" @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/EdgeDB.qhcp" @echo "To view the help file:" @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/EdgeDB.qhc" applehelp: $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp @echo @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." @echo "N.B. You won't be able to view it unless you put it in" \ "~/Library/Documentation/Help or install it in your application" \ "bundle." 
devhelp: $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp @echo @echo "Build finished." @echo "To view the help file:" @echo "# mkdir -p $$HOME/.local/share/devhelp/EdgeDB" @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/EdgeDB" @echo "# devhelp" epub: $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub @echo @echo "Build finished. The epub file is in $(BUILDDIR)/epub." latex: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." @echo "Run \`make' in that directory to run these through (pdf)latex" \ "(use \`make latexpdf' here to do that automatically)." latexpdf: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo "Running LaTeX files through pdflatex..." $(MAKE) -C $(BUILDDIR)/latex all-pdf @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." latexpdfja: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo "Running LaTeX files through platex and dvipdfmx..." $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." text: $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text @echo @echo "Build finished. The text files are in $(BUILDDIR)/text." man: $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man @echo @echo "Build finished. The manual pages are in $(BUILDDIR)/man." texinfo: $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo @echo @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." @echo "Run \`make' in that directory to run these through makeinfo" \ "(use \`make info' here to do that automatically)." info: $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo @echo "Running Texinfo files through makeinfo..." make -C $(BUILDDIR)/texinfo info @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." gettext: $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale @echo @echo "Build finished. 
The message catalogs are in $(BUILDDIR)/locale." changes: $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes @echo @echo "The overview file is in $(BUILDDIR)/changes." linkcheck: $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck @echo @echo "Link check complete; look for any errors in the above output " \ "or in $(BUILDDIR)/linkcheck/output.txt." doctest: $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest @echo "Testing of doctests in the sources finished, look at the " \ "results in $(BUILDDIR)/doctest/output.txt." coverage: $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage @echo "Testing of coverage in the sources finished, look at the " \ "results in $(BUILDDIR)/coverage/python.txt." xml: $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml @echo @echo "Build finished. The XML files are in $(BUILDDIR)/xml." pseudoxml: $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml @echo @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." ================================================ FILE: docs/cloud/cli.rst ================================================ .. _ref_guide_cloud_cli: === CLI === :edb-alt-title: Using Gel Cloud via the CLI To use |Gel| Cloud via the CLI, first log in using :ref:`ref_cli_gel_cloud_login`. .. note:: This is the way you'll log in interactively on your development machine, but when interacting with Gel Cloud via a script or in CI, you'll instead set the :gelenv:`SECRET_KEY` environment variable to your secret key. Generate a secret key in the Gel Cloud UI or by running :ref:`ref_cli_gel_cloud_secretkey_create`. The :gelcmd:`cloud login` and :gelcmd:`cloud logout` commands are not intended for use in this context. Once your login is successful, you will be able to create an instance using either :ref:`ref_cli_gel_instance_create` or :ref:`ref_cli_gel_project_init`, depending on whether you also want to create a local project linked to your instance. 
* :ref:`ref_cli_gel_instance_create` with an instance name of ``/``. .. code-block:: bash $ gel instance create / * :ref:`ref_cli_gel_project_init` with the ``--server-instance`` option. Set the server instance name to ``/``. .. code-block:: bash $ gel project init \ --server-instance / Alternatively, you can run :gelcmd:`project init` *without* the ``--server-instance`` option and enter an instance name in the ``/`` format when prompted interactively. .. note:: Please be aware of the following restrictions on |Gel| Cloud instance names: * can contain only Latin alpha-numeric characters or ``-`` * cannot start with a dash (``-``) or contain double dashes (``--``) * maximum instance name length is 61 characters minus the length of your organization name (i.e., length of organization name + length of instance name must be fewer than 62 characters) To use :gelcmd:`instance create`: .. code-block:: bash $ gel instance create / To use :gelcmd:`project init`: .. code-block:: bash $ gel project init \ --server-instance / Alternatively, you can run :gelcmd:`project init` *without* the ``--server-instance`` option and enter an instance name in the ``/`` format when prompted interactively. ================================================ FILE: docs/cloud/deploy/fly.rst ================================================ .. _ref_guide_cloud_deploy_fly: ====== Fly.io ====== :edb-alt-title: Deploying applications built on Gel Cloud to Fly.io 1. Install the `Fly.io CLI `_ 2. Log in to Fly.io with ``flyctl auth login`` 3. Run ``flyctl launch`` to create a new app on Fly.io and configure it. It will ask you to select a region and a name for your app. When done it will create a ``fly.toml`` file and a ``Dockerfile`` in your project directory. 4. Set :gelenv:`INSTANCE` and :gelenv:`SECRET_KEY` as secrets in your Fly.io app. For **runtime secrets**, you can do this by running the following commands: .. 
code-block:: bash $ flyctl secrets set GEL_INSTANCE $ flyctl secrets set GEL_SECRET_KEY `Read more about Fly.io runtime secrets `_. For **build secrets**, you can do this by modifying the ``Dockerfile`` to mount the secrets as environment variables. .. code-block:: dockerfile-diff :caption: Dockerfile # Build application - RUN pnpm run build + RUN --mount=type=secret,id=GEL_INSTANCE \ + --mount=type=secret,id=GEL_SECRET_KEY \ + GEL_INSTANCE="$(cat /run/secrets/GEL_INSTANCE)" \ + GEL_SECRET_KEY="$(cat /run/secrets/GEL_SECRET_KEY)" \ + pnpm run build `Read more about Fly.io build secrets `_. 5. Deploy your app to Fly.io .. code-block:: bash $ flyctl deploy If your app requires build secrets, you can pass them as arguments to the ``deploy`` command: .. code-block:: bash $ flyctl deploy --build-secret GEL_INSTANCE="" \ --build-secret GEL_SECRET_KEY="" ================================================ FILE: docs/cloud/deploy/index.rst ================================================ .. _ref_guide_cloud_deploy: ============= Deploy an app ============= :edb-alt-title: Deploying applications built on Gel Cloud For your production deployment, generate a dedicated secret key for your instance with :ref:`ref_cli_gel_cloud_secretkey_create` or via the web UI's "Secret Keys" pane in your instance dashboard. Create two environment variables accessible to your production application: * :gelenv:`SECRET_KEY`- contains the secret key you generated * :gelenv:`INSTANCE`- the name of your Gel Cloud instance (``/``) If you use one of these platforms, try the platform's guide for platform-specific instructions: .. toctree:: :maxdepth: 1 vercel netlify fly railway render ================================================ FILE: docs/cloud/deploy/netlify.rst ================================================ .. _ref_guide_cloud_deploy_netlify: ======= Netlify ======= :edb-alt-title: Deploying applications built on Gel Cloud to Netlify .. 
note:: This guide assumes the Git deployment method on Netlify, but you may also deploy your site using other methods. Just make sure the Gel Cloud environment variables are set, and your app should have connectivity to your instance. 1. Push project to GitHub or some other Git remote repository 2. Create and make note of a secret key for your Gel Cloud instance 3. On your Netlify Team Overview view under Sites, click Import from Git 4. Import your project's repository 5. Configure the build settings appropriately for your app 6. Click the Add environment variable button 7. Use the New variable button to add two variables: - :gelenv:`INSTANCE` containing your Gel Cloud instance name (in ``/`` format) - :gelenv:`SECRET_KEY` containing the secret key you created and noted previously. 8. Click Deploy .. image:: images/cloud-netlify-config.png :width: 100% :alt: A screenshot of the Netlify deployment configuration view highlighting the environment variables section where a user will need to set the necessary variables for Gel Cloud instance connection. ================================================ FILE: docs/cloud/deploy/railway.rst ================================================ .. _ref_guide_cloud_deploy_railway: ======= Railway ======= :edb-alt-title: Deploying applications built on Gel Cloud to Railway 1. Push project to GitHub or some other Git remote repository 2. Create and make note of a secret key for your Gel Cloud instance 3. From Railway's dashboard, click the "New Project" button 4. Select the repository you want to deploy 5. Click the "Add variables" button to add the following environment variables: - :gelenv:`INSTANCE` containing your Gel Cloud instance name (in ``/`` format) - :gelenv:`SECRET_KEY` containing the secret key you created and noted previously. 6. Click "Deploy" .. 
image:: images/cloud-railway-config.png :width: 100% :alt: A screenshot of the Railway deployment configuration view highlighting the environment variables section where a user will need to set the necessary variables for Gel Cloud instance connection. ================================================ FILE: docs/cloud/deploy/render.rst ================================================ .. _ref_guide_cloud_deploy_render: ====== Render ====== :edb-alt-title: Deploying applications built on Gel Cloud to Render 1. Push project to GitHub or some other Git remote repository 2. Create and make note of a secret key for your Gel Cloud instance 3. From Render's dashboard, click "New > Web Service" 4. Import your project's repository 5. In the setup page, scroll down to the "Environment Variables" section and add the following environment variables: - :gelenv:`INSTANCE` containing your Gel Cloud instance name (in ``/`` format) - :gelenv:`SECRET_KEY` containing the secret key you created and noted previously. 6. Click Deploy .. image:: images/cloud-render-config.png :width: 100% :alt: A screenshot of the Render deployment configuration view highlighting the environment variables section where a user will need to set the necessary variables for Gel Cloud instance connection. ================================================ FILE: docs/cloud/deploy/vercel.rst ================================================ .. _ref_guide_cloud_deploy_vercel: ====== Vercel ====== :edb-alt-title: Deploying applications built on Gel Cloud to Vercel 1. Push project to GitHub or some other Git remote repository 2. Create and make note of a secret key for your Gel Cloud instance 3. From Vercel's Overview tab, click Add New > Project 4. Import your project's repository 5. In "Configure Project," expand "Environment Variables" to add two variables: - :gelenv:`INSTANCE` containing your Gel Cloud instance name (in ``/`` format) - :gelenv:`SECRET_KEY` containing the secret key you created and noted previously. 
6. Click Deploy .. image:: images/cloud-vercel-config.png :width: 100% :alt: A screenshot of the Vercel deployment configuration view highlighting the environment variables section where a user will need to set the necessary variables for |Gel| Cloud instance connection. ================================================ FILE: docs/cloud/http_gql.rst ================================================ .. _ref_guide_cloud_http_gql: =================== HTTP & GraphQL APIs =================== :edb-alt-title: Querying Gel Cloud over HTTP and GraphQL Using |Gel| Cloud via HTTP and GraphQL works the same as :ref:`using any other |Gel| instance `. The two differences are in **how to discover your instance's URL** and **authentication**. Enabling ======== |Gel| Cloud can expose an HTTP endpoint for EdgeQL queries. Since HTTP is a stateless protocol, no :ref:`DDL ` or :ref:`transaction commands `, can be executed using this endpoint. Only one query per request can be executed. In order to set up HTTP access to the database add the following to the schema: .. code-block:: sdl using extension edgeql_http; Then create a new migration and apply it using :ref:`ref_cli_gel_migration_create` and :ref:`ref_cli_gel_migrate`, respectively. Your instance can now receive EdgeQL queries over HTTP at ``https://:/branch//edgeql``. Instance URL ============ To determine the URL of a |Gel| Cloud instance, find the host by running :gelcmd:`instance credentials -I /`. Use the ``host`` and ``port`` from that table in the URL format above this note. Change the protocol to ``https`` since Gel Cloud instances are secured with TLS. Your instance can now receive EdgeQL queries over HTTP at ``https://:/branch//edgeql``. Authentication ============== To authenticate to your |Gel| Cloud instance, first create a secret key using the Gel Cloud UI or :ref:`ref_cli_gel_cloud_secretkey_create`. Use the secret key as your token with the bearer authentication method. 
Here is an example showing how you might send the query ``select Person {*};`` using cURL: .. lint-off .. code-block:: bash $ curl -G https://:/branch/main/edgeql \ -H "Authorization: Bearer \ --data-urlencode "query=select Person {*};" .. lint-on Usage ===== Usage of the HTTP and GraphQL APIs is identical on a |Gel| Cloud instance. Reference the HTTP and GraphQL documentation for more information. HTTP ---- - :ref:`Overview ` - :ref:`ref_edgeql_protocol` - :ref:`ref_edgeql_http_health_checks` GraphQL ------- - :ref:`Overview ` - :ref:`ref_graphql_overview` - :ref:`ref_graphql_mutations` - :ref:`ref_graphql_introspection` - :ref:`ref_cheatsheet_graphql` ================================================ FILE: docs/cloud/index.rst ================================================ .. _ref_guide_cloud: ===== Cloud ===== :edb-alt-title: Using Gel Cloud .. toctree:: :maxdepth: 2 :hidden: cli web http_gql deploy/index deploy/vercel deploy/netlify deploy/fly deploy/render deploy/railway migrate_from |Gel| Cloud is a fully managed, effortless cloud database service, engineered to let you deploy your database instantly and connect from anywhere with near-zero configuration. Connecting your app =================== Try a guide for connecting your app running on your platform of choice: .. TODO: render these with icons * :ref:`Vercel ` * :ref:`Netlify ` * :ref:`Fly.io ` * :ref:`Render ` * :ref:`Railway ` To connect your apps running on other platforms, generate a dedicated secret key for your instance with :gelcmd:`cloud secretkey create` or via the web UI's “Secret Keys” pane in your instance dashboard. Create two environment variables accessible to your production application: * :gelenv:`SECRET_KEY` - contains the secret key you generated * :gelenv:`INSTANCE` - the name of your |Gel| Cloud instance (``/``) Two ways to use Gel Cloud ========================= 1. CLI ^^^^^^ Log in to |Gel| Cloud via the CLI: .. 
code-block:: bash $ gel cloud login This will open a browser window and allow you to log in via GitHub. Now, create your |Gel| Cloud instance the same way you would create a local instance: .. code-block:: bash $ gel instance create / or .. code-block:: bash $ gel project init \ --server-instance / 2. GUI ^^^^^^ Create your instance at `cloud.geldata.com `_ by clicking on “Create new instance” in the “Instances” tab. ..
Complete the following form to configure your instance. You can access your instance via the CLI using the name ``/`` or via the GUI. Useful Gel Cloud commands ========================= Get REPL ^^^^^^^^ .. code-block:: bash $ gel \ -I / Run migrations ^^^^^^^^^^^^^^ .. code-block:: bash $ gel migrate \ -I / Update your instance ^^^^^^^^^^^^^^^^^^^^ .. code-block:: bash $ gel instance upgrade \ --to-version \ -I / Manual full backup ^^^^^^^^^^^^^^^^^^ .. code-block:: bash $ gel dump \ --all --format dir \ -I / \ Full restore ^^^^^^^^^^^^ .. code-block:: bash $ gel restore \ --all \ -I / \ .. note:: Restoring works only to an empty database. Questions? Problems? Bugs? ========================== Thank you for helping us make the best way to host your |Gel| instances even better! * Please join us on `our Discord `_ to ask questions. * If you're experiencing a service interruption, check `our status page `_ for information on what may be causing it. * Report any bugs you find by `submitting a support ticket `_. Note: when using |Gel| Cloud through the CLI, setting the ``RUST_LOG`` environment variable to ``info``, ``debug``, or ``trace`` may provide additional debugging information which will be useful to include with your ticket. ================================================ FILE: docs/cloud/migrate_from.rst ================================================ .. _ref_migrate_from: ======================================= Migrating from Gel Cloud to Self-Hosted ======================================= :edb-alt-title: Migrating from Gel Cloud to Self-Hosted Gel |Gel| Cloud is sunsetting at the end of January 2026. To ensure your applications continue to run smoothly, you should migrate your data to a self-hosted |Gel| instance. This guide outlines the process of migrating your production data. We strongly recommend performing a "dry run" with your staging or development environment first to familiarize yourself with the workflow. 
Phase 1: Preparation ==================== 1. Spin up your self-hosted deployment -------------------------------------- While you can host |Gel| on any infrastructure that supports Docker or binary installations, we recommend using a managed cloud provider (such as Fly.io, AWS, or GCP with managed Postgres) for production reliability. See our :ref:`self-hosted deployment guides ` for step-by-step instructions on deploying to various platforms. Ensure your new instance is: * Running the same version of |Gel| as your Cloud instance (or newer). * Configured with a persistent volume for data or connected to a managed Postgres instance. 2. Retrieve connection parameters --------------------------------- Once your new instance is live, you need its DSN (Data Source Name) or individual connection parameters. Each :ref:`deployment guide ` outlines the best way to retrieve the various connection parameters from your specific setup. You will typically need: * **Host**: The domain or IP of your new instance. * **Port**: Default is ``5656``. * **User**: Default is |admin|. * **Password**: The password you set during initialization. * **TLS CA**: The certificate used to secure the connection (unless using a public CA or ``--trust-tls-cert``). The DSN format is: .. code-block:: text gel://:@:/ All components except the scheme are optional. See :ref:`ref_reference_connection_dsn` for more details. Phase 2: The Migration (Cutover) ================================ To ensure data consistency, you must prevent new writes to your database during the transfer. 1. Enable maintenance mode -------------------------- Before touching the data, put your application into maintenance mode. This ensures that no new records are created in |Gel| Cloud after you've started the dump. * **Web Apps**: Point your load balancer to a static "Maintenance" page. * **Background Jobs**: Stop all workers, cron jobs, or queues that interact with the database. 2. 
Perform the migration ------------------------ You can use the |Gel| CLI to move data directly from your Cloud instance to your new self-hosted instance. .. code-block:: bash # 1. Dump from Gel Cloud to directory $ gel dump --instance / \ --all --format=dir \ production_dump # 2. Restore to self-hosted from dump $ gel restore --dsn --all production_dump .. note:: If your self-hosted instance uses a self-signed TLS certificate, you may need to add ``--tls-security insecure`` to the restore command, or first retrieve the TLS certificate and set it via :gelenv:`TLS_CA`. Phase 3: Verification and Go-Live ================================= 1. Update application environment variables ------------------------------------------- Update your application's configuration to point to the new instance. Replace the Gel Cloud-specific connection environment variables :gelenv:`INSTANCE` and :gelenv:`SECRET_KEY` with the new connection details: * :gelenv:`DSN`: :geluri:`user:password@host:port/branch` * :gelenv:`TLS_CA`: The TLS certificate content (if your instance uses a self-signed certificate) * :gelenv:`CLIENT_TLS_SECURITY`: Set to ``insecure`` if you need to skip TLS verification (not recommended for production) 2. Sanity check --------------- Before turning off maintenance mode: * Run a few :gelcmd:`query` commands against the new instance to verify data integrity. * Check that your schema migrated correctly: :gelcmd:`migrate --status`. * Launch a local instance of your app connected to the new production DB to ensure connection logic is sound. 3. Disable maintenance mode --------------------------- Once verified, restart your application servers and background workers. Monitor your logs closely for any connection or permission errors. Post-Migration Note =================== Once you are 100% certain your data is safe and your app is stable on the new host, you can de-provision your |Gel| Cloud instance.
Remember that all |Gel| Cloud data will be deleted after the January 2026 deadline. ================================================ FILE: docs/cloud/web.rst ================================================ .. _ref_guide_cloud_web: ======= Web GUI ======= :edb-alt-title: Using Gel Cloud via the web GUI If you'd prefer, you can also manage your account via `the Gel Cloud web-based GUI `_. The first time you access the web UI, you will be prompted to log in. Once you log in with your account, you'll be on the "Instances" tab of the front page which shows your instance list. The other two tabs allow you to manage your organization settings and billing. Instances --------- If this is your first time accessing Gel Cloud, this list will be empty. To create an instance, click "Create new instance." This will pop up a modal allowing you to name your instance and specify the version of Gel and the region for the instance. Once the instance has been created, you'll see the instance dashboard which allows you to monitor your instance, navigate to the management page for its branches, and create secret keys. You'll also see instructions in the bottom-right for linking your |Gel| CLI to your Gel Cloud account. You do this by running the CLI command :gelcmd:`cloud login`. This will make all of your Gel Cloud instances accessible via the CLI. You can manage them just as you would other remote Gel instances. If you want to manage a branch of your database, click through on the instance's name from the top right of the instance dashboard. If you just created a database, the branch management view will be mostly empty except for a button offering to create a sample branch. Once you have a schema created and some data in a database, this view will offer you similar tools to those in our local UI. You'll be able to access a REPL, edit complex queries or build them graphically, inspect your schema, and browse your data. 
Org Settings ------------ This tab allows you to add GitHub organizations for which you are an admin. If you don't see your organization's name here, you may need to update your `org settings`_ in GitHub to allow |Gel| Cloud to read your list of organizations, and then refresh the org list. .. lint-off .. _org settings: https://docs.github.com/en/organizations/managing-oauth-access-to-your-organizations-data/approving-oauth-apps-for-your-organization .. lint-on Billing ------- On this page you can manage your account type and payment methods, and set your email for receiving billing info. Optionally, you can also save your payment info using `Link `_, `Stripe's `_ fast-checkout solution. ================================================ FILE: docs/conf.py ================================================ # -*- coding: utf-8 -*- # # EdgeDB documentation build configuration file, created by # sphinx-quickstart on Wed Aug 3 17:58:14 2016. # # This file is execfile()d with the current directory set to its # containing dir. # # Note that not all possible configuration values are present in this # autogenerated file. # # All configuration values have a default; values that are commented out # serve to show the default. # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # sys.path.insert(0, os.path.abspath('.')) # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. # needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones.
extensions = [ 'sphinx.ext.autodoc', 'sphinx.ext.todo', 'sphinx.ext.viewcode', 'edb.tools.docs', 'sphinxcontrib.asyncio', 'sphinx.ext.intersphinx', 'sphinx_code_tabs', ] intersphinx_mapping = {'python': ('https://docs.python.org/3', None)} # Add any paths that contain templates here, relative to this directory. templates_path = [] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # source_suffix = ['.rst', '.md'] source_suffix = '.rst' # The encoding of source files. # source_encoding = 'utf-8-sig' # The master toctree document. master_doc = 'index' # General information about the project. project = u'EdgeDB' copyright = u'2016, magicstack' author = u'magicstack' # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. version = '0.5.0' # The full version, including alpha/beta/rc tags. release = '0.5.0' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: # today = '' # Else, today_fmt is used as the format for a strftime call. # today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. exclude_patterns = [] # The reST default role (used for this markup: `text`) to use for all # documents. # default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. # add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). 
# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. # show_authors = False # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' # A list of ignored prefixes for module index sorting. # modindex_common_prefix = [] # If true, keep warnings as "system message" paragraphs in the built documents. # keep_warnings = False suppress_warnings = ['image.not_readable'] # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = False primary_domain = None # -- Options for HTML output ---------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. html_theme = 'alabaster' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. # html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. # html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". # html_title = None # A shorter title for the navigation bar. Default is the same as html_title. # html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. # html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. # html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = [] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. 
These files are copied # directly to the root of the documentation. # html_extra_path = [] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. # html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. # html_use_smartypants = True # Custom sidebar templates, maps document names to template names. html_sidebars = { '**': [ 'globaltoc.html', 'searchbox.html', ] } # Additional templates that should be rendered to pages, maps page names to # template names. # html_additional_pages = {} # If false, no module index is generated. # html_domain_indices = True # If false, no index is generated. # html_use_index = True # If true, the index is split into individual pages for each letter. # html_split_index = False # If true, links to the reST sources are added to the pages. # html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. # html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. # html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. # html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). # html_file_suffix = None # Language to be used for generating the HTML full-text search index. # Sphinx supports the following languages: # 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' # 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr' # html_search_language = 'en' # A dictionary with options for the search language support, empty by default. # Now only 'ja' uses this config value # html_search_options = {'type': 'default'} # The name of a javascript file (relative to the configuration directory) that # implements a search results scorer. 
If empty, the default will be used. # html_search_scorer = 'scorer.js' # Output file base name for HTML help builder. htmlhelp_basename = 'EdgeDBdoc' # -- Options for LaTeX output --------------------------------------------- latex_elements = { # The paper size ('letterpaper' or 'a4paper'). # 'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). # 'pointsize': '10pt', # Additional stuff for the LaTeX preamble. # 'preamble': '', # Latex figure (float) alignment # 'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ (master_doc, 'EdgeDB.tex', u'EdgeDB Documentation', u'magicstack', 'manual'), ] # The name of an image file (relative to this directory) to place at the top of # the title page. # latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. # latex_use_parts = False # If true, show page references after internal links. # latex_show_pagerefs = False # If true, show URL addresses after external links. # latex_show_urls = False # Documents to append as an appendix to all manuals. # latex_appendices = [] # If false, no module index is generated. # latex_domain_indices = True # -- Options for manual page output --------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ (master_doc, 'edgedb', u'EdgeDB Documentation', [author], 1) ] # If true, show URL addresses after external links. # man_show_urls = False # -- Options for Texinfo output ------------------------------------------- # Grouping the document tree into Texinfo files. 
List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ (master_doc, 'EdgeDB', u'EdgeDB Documentation', author, 'EdgeDB', 'One line description of project.', 'Miscellaneous'), ] # Documents to append as an appendix to all manuals. # texinfo_appendices = [] # If false, no module index is generated. # texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. # texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. # texinfo_no_detailmenu = False # config for srclink srclink_project = 'https://github.com/edgedb/edgedb' srclink_src_path = 'doc/' srclink_branch = 'doc' ================================================ FILE: docs/index.rst ================================================ .. _index_toplevel: ================= Gel documentation ================= Welcome to the |Gel| |version| documentation. .. !!! DO NOT CHANGE :maxdepth: egdedb.com/docs depends on it !!! .. toctree:: :maxdepth: 5 :includehidden: intro/index reference/index resources/index cloud/index ================================================ FILE: docs/intro/branches.rst ================================================ .. _ref_intro_branches: ======== Branches ======== |Gel's| branches make it easy to prototype app features that impact your database schema, even in cases where those features are never released. You can create a branch in your Gel database that corresponds to a feature branch in your VCS. When you're done, either :ref:`merge ` that branch into your main branch or :ref:`drop ` it leaving your original schema intact. .. note:: The procedure we will describe should be adaptable to any VCS offering branching and rebasing, but in order to make the examples concrete and easy-to-follow, we'll be demonstrating how Gel branches interact with Git branches. You may adapt these examples to your VCS of choice. 1. 
Create a new feature branch ------------------------------ Create a feature branch in your VCS and switch to it. Then, create and switch to a corresponding branch in Gel using the CLI. .. code-block:: bash $ gel branch create feature Creating branch 'feature'... OK: CREATE BRANCH $ gel branch switch feature Switching from 'main' to 'feature' .. note:: You can alternatively create and switch in one shot using :gelcmd:`branch switch -c feature`. 2. Build your feature --------------------- Write your code and make any schema changes your feature requires. 3. Pull any changes on "main" ----------------------------- .. note:: This step is optional. If you know your |main| code branch is current and all migrations in that code branch have already been applied to your |main| database branch, feel free to skip it. We need to make sure that merging our feature branch onto |main| is a simple fast-forward. The next two steps take care of that. Switch back to your |main| code branch. Run ``git pull`` to pull down any new changes. If any of these are schema changes, use :gelcmd:`branch switch main` to switch back to your |main| database branch and apply the new schema with :gelcmd:`migrate`. Once this is done, you can switch back to your feature branches in your VCS and |Gel|. 4. Rebase your feature branch on "main" --------------------------------------- .. note:: If you skipped the previous step, you can skip this one too. This is only necessary if you had to pull down new changes on |main|. For your code branch, first make sure you're on ``feature`` and then run the rebase: .. code-block:: bash $ git rebase main Now, do the same for your database, also from ``feature``: .. code-block:: bash $ gel branch rebase main 5. Merge ``feature`` onto "main" -------------------------------- Switch back to both |main| branches and merge ``feature``. .. code-block:: bash $ git switch main Switched to branch 'main' $ git merge feature .. 
code-block:: bash $ gel branch switch main Switching from 'feature' to 'main' $ gel branch merge feature Now, your feature and its schema have been successfully merged! 🎉 Further reading ^^^^^^^^^^^^^^^ - :ref:`Branches CLI ` Further information can be found in the `branches RFC `_, which describes the design of the migration system. ================================================ FILE: docs/intro/cli.rst ================================================ .. _ref_intro_cli: .. _ref_admin_install: ======= The CLI ======= The |gelcmd| command line tool is an integral part of the developer workflow of building with Gel. Below are instructions for installing it. Installation ------------ To get started with Gel, the first step is to install the |gelcmd| CLI. **Linux or macOS** .. code-block:: bash $ curl --proto '=https' --tlsv1.2 -sSf https://www.geldata.com/sh | sh **Windows PowerShell** .. note:: Gel on Windows requires WSL 2 because the Gel server runs on Linux. .. code-block:: powershell PS> iwr https://www.geldata.com/ps1 -useb | iex Follow the prompts on screen to complete the installation. The script will download the |gelcmd| command built for your OS and add a path to it to your shell environment. Then test the installation: .. code-block:: bash $ gel --version Gel CLI x.x+abcdefg .. note:: If you encounter a ``command not found`` error, you may need to open a fresh shell window. See ``help`` commands --------------------- The entire CLI is self-documenting. Once it's installed, run :gelcmd:`--help` to see a breakdown of all the commands and options. .. code-block:: bash $ gel --help Usage: gel [OPTIONS] [COMMAND] Commands: Options: Connection Options (gel --help-connect to see full list): Cloud Connection Options: The majority of CLI commands perform some action against a *particular* Gel instance.
As such, there is a standard set of flags that are used to specify *which instance* should be the target of the command, plus additional information like TLS certificates. The following command documents these flags. .. code-block:: bash $ gel --help-connect Connection Options (full list): -I, --instance Instance name (use `gel instance list` to list local, remote and Cloud instances available to you) --dsn DSN for Gel to connect to (overrides all other options except password) --credentials-file Path to JSON file to read credentials from -H, --host Gel instance host -P, --port Port to connect to Gel --unix-path A path to a Unix socket for Gel connection When the supplied path is a directory, the actual path will be computed using the `--port` and `--admin` parameters. ... If you ever want to see documentation for a particular command (:gelcmd:`migration create`) or group of commands (:gelcmd:`instance`), just append the ``--help`` flag. .. code-block:: bash $ gel instance --help Manage local Gel instances Usage: gel instance Commands: create Initialize a new Gel instance list Show all instances status Show status of an instance start Start an instance stop Stop an instance ... Upgrade the CLI --------------- To upgrade to the latest version: .. code-block:: bash $ gel cli upgrade ================================================ FILE: docs/intro/clients.rst ================================================ .. _ref_intro_clients: ================ Client Libraries ================ |Gel| implements libraries for popular languages that make it easier to work with Gel. These libraries provide a common set of functionality. - *Instantiating clients.* Most libraries implement a ``Client`` class that internally manages a pool of physical connections to your Gel instance. - *Resolving connections.* All client libraries implement a standard protocol for determining how to connect to your database.
In most cases, this will involve checking for special environment variables like :gelenv:`DSN` or, in the case of Gel Cloud instances, :gelenv:`INSTANCE` and :gelenv:`SECRET_KEY`. (More on this in :ref:`the Connection section below `.) - *Executing queries.* A ``Client`` will provide some methods for executing queries against your database. Under the hood, these queries are executed using Gel's efficient binary protocol. .. note:: For some use cases, you may not need a client library. Gel allows you to execute :ref:`queries over HTTP `. This is slower than the binary protocol and lacks support for transactions and rich data types, but may be suitable if a client library isn't available for your language of choice. Available libraries =================== To execute queries from your application code, use one of :ref:`Gel's client libraries `. Usage ===== To follow along with the guide below, first create a new directory and initialize a project. .. code-block:: bash $ mkdir myproject $ cd myproject $ gel project init Configure the environment as needed for your preferred language. .. tabs:: .. code-tab:: bash :caption: Node.js $ npm init -y $ tsc --init # (TypeScript only) $ touch index.ts .. code-tab:: bash :caption: Deno $ touch index.ts .. code-tab:: bash :caption: Python $ python -m venv venv $ source venv/bin/activate $ touch main.py .. code-tab:: bash :caption: Rust $ cargo init .. code-tab:: bash :caption: Go $ go mod init example/quickstart $ touch hello.go .. code-tab:: bash :caption: .NET $ dotnet new console -o . -f net6.0 Install the Gel client library. .. tabs:: .. code-tab:: bash :caption: Node.js $ npm install gel # npm $ yarn add gel # yarn .. code-tab:: txt :caption: Deno n/a .. code-tab:: bash :caption: Python $ pip install gel .. code-tab:: toml :caption: Rust # Cargo.toml [dependencies] gel-tokio = "0.5.0" # Additional dependency tokio = { version = "1.28.1", features = ["macros", "rt-multi-thread"] } ..
code-tab:: bash :caption: Go $ go get github.com/geldata/gel-go .. code-tab:: bash :caption: .NET $ dotnet add package Gel.Net.Driver Copy and paste the following simple script. This script initializes a ``Client`` instance. Clients manage an internal pool of connections to your database and provide a set of methods for executing queries. .. note:: Note that we aren't passing connection information (say, a connection URL) when creating a client. The client libraries can detect that they are inside a project directory and connect to the project-linked instance automatically. For details on configuring connections, refer to the :ref:`Connection ` section below. .. lint-off .. tabs:: .. code-tab:: typescript :caption: Node.js import {createClient} from 'gel'; const client = createClient(); client.querySingle(`select random()`).then((result) => { console.log(result); }); .. code-tab:: python from gel import create_client client = create_client() result = client.query_single("select random()") print(result) .. code-tab:: rust // src/main.rs #[tokio::main] async fn main() { let conn = gel_tokio::create_client() .await .expect("Client initiation"); let val = conn .query_required_single::("select random()", &()) .await .expect("Returning value"); println!("Result: {}", val); } .. code-tab:: go // hello.go package main import ( "context" "fmt" "log" "github.com/geldata/gel-go" ) func main() { ctx := context.Background() client, err := gel.CreateClient(ctx, gel.Options{}) if err != nil { log.Fatal(err) } defer client.Close() var result float64 err = client. QuerySingle(ctx, "select random();", &result) if err != nil { log.Fatal(err) } fmt.Println(result) } .. code-tab:: csharp :caption: .NET using Gel; var client = new GelClient(); var result = await client.QuerySingleAsync("select random();"); Console.WriteLine(result); .. 
code-tab:: elixir :caption: Elixir # lib/gel_quickstart.ex defmodule GelQuickstart do def run do {:ok, client} = Gel.start_link() result = Gel.query_single!(client, "select random()") IO.inspect(result) end end .. lint-on Finally, execute the file. .. tabs:: .. code-tab:: bash :caption: Node.js $ npx tsx index.ts .. code-tab:: bash :caption: Deno $ deno run --allow-all --unstable index.deno.ts .. code-tab:: bash :caption: Python $ python main.py .. code-tab:: bash :caption: Rust $ cargo run .. code-tab:: bash :caption: Go $ go run . .. code-tab:: bash :caption: .NET $ dotnet run .. code-tab:: bash :caption: Elixir $ mix run -e GelQuickstart.run You should see a random number printed to the console. This number was generated inside your Gel instance using EdgeQL's built-in :eql:func:`random` function. .. _ref_intro_clients_connection: Connection ========== All client libraries implement a standard protocol for determining how to connect to your database. Using projects -------------- In development, we recommend :ref:`initializing a project ` in the root of your codebase. .. code-block:: bash $ gel project init Once the project is initialized, any code that uses an official client library will automatically connect to the project-linked instance, with no need for environment variables or hard-coded credentials. Follow the :ref:`Using projects ` guide to get started. Using environment variables --------------------------- .. _ref_intro_clients_connection_cloud: For Gel Cloud ^^^^^^^^^^^^^ In production, connection information can be securely passed to the client library via environment variables. For Gel Cloud instances, the recommended variables to set are :gelenv:`INSTANCE` and :gelenv:`SECRET_KEY`. Set :gelenv:`INSTANCE` to ``/`` where ```` is the name you set when you created the Gel Cloud instance. If you have not yet created a secret key, you can do so in the Gel Cloud UI or by running :ref:`ref_cli_gel_cloud_secretkey_create` via the CLI.
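A deployment can check that both variables are present before the application starts. The following is a minimal sketch, not part of any Gel client library: it assumes the :gelenv:`INSTANCE` and :gelenv:`SECRET_KEY` roles render as ``GEL_INSTANCE`` and ``GEL_SECRET_KEY``, that an instance name has the two-part ``org/instance`` shape, and the helper name ``check_cloud_env`` is hypothetical.

```python
import os

# Assumed rendered names of the two Gel Cloud variables described above.
REQUIRED = ("GEL_INSTANCE", "GEL_SECRET_KEY")


def check_cloud_env(env) -> list:
    """Return a list of human-readable problems; an empty list means OK."""
    problems = [f"{name} is not set" for name in REQUIRED if not env.get(name)]

    instance = env.get("GEL_INSTANCE")
    if instance:
        # Assumption: the instance name is exactly two non-empty parts
        # separated by a single slash.
        org, sep, name = instance.partition("/")
        if not (org and sep and name) or "/" in name:
            problems.append("GEL_INSTANCE should look like org/instance")
    return problems


if __name__ == "__main__":
    for problem in check_cloud_env(os.environ):
        print(problem)
```

Running the check at startup surfaces a missing or malformed variable as a clear message instead of a connection error later on.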
For self-hosted instances ^^^^^^^^^^^^^^^^^^^^^^^^^ For self-hosted remote instances, you most commonly set a value for the :gelenv:`DSN` environment variable. .. note:: If environment variables like :gelenv:`DSN` are defined inside a project directory, the environment variables will take precedence. A DSN is also known as a "connection string" and takes the following form: :geluri:`:@:`. Each element of the DSN is optional; in fact, |geluri| is technically a valid DSN. Any unspecified element will default to the following values. .. list-table:: * - ```` - ``localhost`` * - ```` - ``5656`` * - ```` - |admin| * - ```` - ``null`` A typical DSN may look like this: :geluri:`admin:PASSWORD@db.domain.com:8080`. DSNs can also contain the following query parameters. .. list-table:: * - ``branch`` - The database branch to connect to within the given instance. Defaults to |main|. * - ``tls_security`` - The TLS security mode. Accepts the following values. - ``"strict"`` (**default**): verify certificates and hostnames - ``"no_host_verification"``: verify certificates only - ``"insecure"``: trust self-signed certificates * - ``tls_ca_file`` - A filesystem path pointing to a CA root certificate. This is usually only necessary when attempting to connect via TLS to a remote instance with a self-signed certificate. These parameters can be added to any DSN using web-standard query string notation: :geluri:`user:pass@example.com:8080?branch=my_branch&tls_security=insecure`. For a more comprehensive guide to DSNs, see the :ref:`DSN Specification `. Using multiple environment variables ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ If needed for your deployment pipeline, each element of the DSN can be specified independently. - :gelenv:`HOST` - :gelenv:`PORT` - :gelenv:`USER` - :gelenv:`PASSWORD` - :gelenv:`BRANCH` - :gelenv:`TLS_CA_FILE` - :gelenv:`CLIENT_TLS_SECURITY` .. note:: If a value for :gelenv:`DSN` is defined, it will override these variables!
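The DSN rules described in this section (optional elements, default values, and query parameters) can be sketched with Python's standard library. This is an illustration only and not how the client libraries resolve connections: the ``parse_dsn`` helper is hypothetical, the ``gel://`` scheme and the ``admin`` default user are assumed renderings of |geluri| and |admin|, and the ``main`` default branch is an assumed rendering of |main|.

```python
from urllib.parse import parse_qs, urlsplit

# Defaults taken from the table in this section; "admin" is an assumption
# about how the default user is rendered.
DEFAULTS = {"host": "localhost", "port": 5656, "user": "admin"}


def parse_dsn(dsn: str) -> dict:
    """Split a gel:// DSN into its parts, filling in the documented defaults."""
    parts = urlsplit(dsn)
    if parts.scheme != "gel":
        raise ValueError(f"unsupported scheme: {parts.scheme!r}")

    # parse_qs returns a list per key; keep the last value given.
    query = {key: values[-1] for key, values in parse_qs(parts.query).items()}

    # The branch may appear as a query parameter or as the path component.
    branch = query.get("branch") or parts.path.lstrip("/") or "main"

    return {
        "host": parts.hostname or DEFAULTS["host"],
        "port": parts.port or DEFAULTS["port"],
        "user": parts.username or DEFAULTS["user"],
        "password": parts.password,  # None when omitted
        "branch": branch,
        "tls_security": query.get("tls_security", "strict"),
    }


if __name__ == "__main__":
    print(parse_dsn("gel://user:pass@example.com:8080?branch=my_branch"))
```

Because every element is optional, passing the bare scheme ``gel://`` yields only the default values.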
Other mechanisms ---------------- :gelenv:`CREDENTIALS_FILE` A path to a ``.json`` file containing connection information. In some scenarios (including local Docker development) it's useful to represent connection information with files. .. code-block:: json { "host": "localhost", "port": 10700, "user": "testuser", "password": "testpassword", "branch": "main", "tls_cert_data": "-----BEGIN CERTIFICATE-----\nabcdef..." } :gelenv:`INSTANCE` (local/Gel Cloud only) The name of an instance. Useful only for local or Gel Cloud instances. .. note:: For more on Gel Cloud instances, see the :ref:`Gel Cloud instance connection section ` above. Reference --------- These are the most common ways to connect to an instance; however, Gel supports several other options for advanced use cases. For a complete reference on connection configuration, see :ref:`Reference > Connection Parameters `. ================================================ FILE: docs/intro/edgeql.rst ================================================ .. _ref_intro_edgeql: EdgeQL ====== EdgeQL is the query language of Gel. It's intended as a spiritual successor to SQL that solves some of its biggest design limitations. This page is intended as a rapid-fire overview so you can hit the ground running with |Gel|. Refer to the linked pages for more in-depth documentation. .. note:: The examples below also demonstrate how to express the query with the :ref:`TypeScript client's ` query builder, which lets you express arbitrary EdgeQL queries in a code-first, typesafe way. Scalar literals ^^^^^^^^^^^^^^^ |Gel| has a rich primitive type system consisting of the following data types. ..
list-table:: * - Strings - ``str`` * - Booleans - ``bool`` * - Numbers - ``int16`` ``int32`` ``int64`` ``float32`` ``float64`` ``bigint`` ``decimal`` * - UUID - ``uuid`` * - JSON - ``json`` * - Dates and times - ``datetime`` ``cal::local_datetime`` ``cal::local_date`` ``cal::local_time`` * - Durations - ``duration`` ``cal::relative_duration`` ``cal::date_duration`` * - Binary data - ``bytes`` * - Auto-incrementing counters - ``sequence`` * - Enums - ``enum`` Basic literals can be declared using familiar syntax. .. tabs:: .. code-tab:: edgeql-repl db> select "I ❤️ EdgeQL"; # str {'I ❤️ EdgeQL'} db> select false; # bool {false} db> select 42; # int64 {42} db> select 3.14; # float64 {3.14} db> select 12345678n; # bigint {12345678n} db> select 15.0e+100n; # decimal {15.0e+100n} db> select b'bina\\x01ry'; # bytes {b'bina\\x01ry'} .. code-tab:: typescript e.str("I ❤️ EdgeQL") // string e.bool(false) // boolean e.int64(42) // number e.float64(3.14) // number e.bigint(BigInt(12345678)) // bigint e.decimal("1234.4567") // n/a (not supported by JS clients) e.bytes(Buffer.from("bina\\x01ry")) // Buffer Other type literals are declared by *casting* an appropriately structured string. .. tabs:: .. code-tab:: edgeql-repl db> select <uuid>'a5ea6360-75bd-4c20-b69c-8f317b0d2857'; {a5ea6360-75bd-4c20-b69c-8f317b0d2857} db> select <datetime>'1999-03-31T15:17:00Z'; {<datetime>'1999-03-31T15:17:00Z'} db> select <duration>'5 hours 4 minutes 3 seconds'; {<duration>'5:04:03'} db> select <cal::relative_duration>'2 years 18 days'; {<cal::relative_duration>'P2Y18D'} .. code-tab:: typescript e.uuid("a5ea6360-75bd-4c20-b69c-8f317b0d2857") // string e.datetime("1999-03-31T15:17:00Z") // Date e.duration("5 hours 4 minutes 3 seconds") // gel.Duration (custom class) e.cal.relative_duration("2 years 18 days") // gel.RelativeDuration (custom class) Primitive data can be composed into arrays and tuples, which can themselves be nested. .. tabs:: ..
code-tab:: edgeql-repl db> select ['hello', 'world']; {['hello', 'world']} db> select ('Apple', 7, true); {('Apple', 7, true)} # unnamed tuple db> select (fruit := 'Apple', quantity := 3.14, fresh := true); {(fruit := 'Apple', quantity := 3.14, fresh := true)} # named tuple db> select <json>["this", "is", "an", "array"]; {"[\"this\", \"is\", \"an\", \"array\"]"} .. code-tab:: typescript e.array(["hello", "world"]); // string[] e.tuple(["Apple", 7, true]); // [string, number, boolean] e.tuple({fruit: "Apple", quantity: 3.14, fresh: true}); // {fruit: string; quantity: number; fresh: boolean} e.json(["this", "is", "an", "array"]); // unknown |Gel| also supports a special ``json`` type for representing unstructured data. Primitive data structures can be converted to JSON using a type cast (``<json>``). Alternatively, a properly JSON-encoded string can be converted to ``json`` with the built-in ``to_json`` function. Indexing a ``json`` value returns another ``json`` value. .. tabs:: .. code-tab:: edgeql-repl gel> select <json>5; {"5"} gel> select <json>[1,2,3]; {"[1, 2, 3]"} gel> select to_json('[{ "name": "Peter Parker" }]'); {"[{\"name\": \"Peter Parker\"}]"} gel> select to_json('[{ "name": "Peter Parker" }]')[0]['name']; {"\"Peter Parker\""} .. code-tab:: typescript /* The result of a query returning `json` is represented with `unknown` in TypeScript. */ e.json(5); // => unknown e.json([1, 2, 3]); // => unknown e.to_json('[{ "name": "Peter Parker" }]'); // => unknown e.to_json('[{ "name": "Peter Parker" }]')[0]["name"]; // => unknown Refer to :ref:`Docs > EdgeQL > Literals ` for complete docs. Functions and operators ^^^^^^^^^^^^^^^^^^^^^^^ |Gel| provides a rich standard library of functions to operate on and manipulate various data types. .. tabs:: .. code-tab:: edgeql-repl db> select str_upper('oh hi mark'); {'OH HI MARK'} db> select len('oh hi mark'); {10} db> select uuid_generate_v1mc(); {c68e3836-0d59-11ed-9379-fb98e50038bb} db> select contains(['a', 'b', 'c'], 'd'); {false} ..
code-tab:: typescript e.str_upper("oh hi mark"); // string e.len("oh hi mark"); // number e.uuid_generate_v1mc(); // string e.contains(["a", "b", "c"], "d"); // boolean Similarly, it provides a comprehensive set of built-in operators. .. tabs:: .. code-tab:: edgeql-repl db> select not true; {false} db> select exists 'hi'; {true} db> select 2 + 2; {4} db> select 'Hello' ++ ' world!'; {'Hello world!'} db> select '😄' if true else '😢'; {'😄'} db> select <duration>'5 minutes' + <duration>'2 hours'; {<duration>'2:05:00'} .. code-tab:: typescript e.op("not", e.bool(true)); // boolean e.op("exists", e.set("hi")); // boolean e.op("exists", e.cast(e.str, e.set())); // boolean e.op(e.int64(2), "+", e.int64(2)); // number e.op(e.str("Hello"), "++", e.str(" world!")); // string e.op(e.str("😄"), "if", e.bool(true), "else", e.str("😢")); // string e.op(e.duration("5 minutes"), "+", e.duration("2 hours")) See :ref:`Docs > Standard Library ` for reference documentation on all built-in types, including the functions and operators that apply to them. Insert an object ^^^^^^^^^^^^^^^^ Objects are created using ``insert``. The ``insert`` statement relies on developer-friendly syntax like curly braces and the ``:=`` operator. .. tabs:: .. code-tab:: edgeql insert Movie { title := 'Doctor Strange 2', release_year := 2022 }; .. code-tab:: typescript const query = e.insert(e.Movie, { title: 'Doctor Strange 2', release_year: 2022 }); const result = await query.run(client); // {id: string} // by default INSERT only returns // the id of the new object See :ref:`Docs > EdgeQL > Insert `. Nested inserts ^^^^^^^^^^^^^^ One of EdgeQL's greatest features is that it's easy to compose. Nested inserts are easily achieved with subqueries. .. tabs:: .. code-tab:: edgeql insert Movie { title := 'Doctor Strange 2', release_year := 2022, director := (insert Person { name := 'Sam Raimi' }) }; ..
code-tab:: typescript const query = e.insert(e.Movie, { title: 'Doctor Strange 2', release_year: 2022, director: e.insert(e.Person, { name: 'Sam Raimi' }) }); const result = await query.run(client); // {id: string} // by default INSERT only returns // the id of the new object Select objects ^^^^^^^^^^^^^^ Use a *shape* to define which properties to ``select`` from the given object type. .. tabs:: .. code-tab:: edgeql select Movie { id, title }; .. code-tab:: typescript const query = e.select(e.Movie, () => ({ id: true, title: true })); const result = await query.run(client); // {id: string; title: string; }[] // To select all properties of an object, use the // spread operator with the special "*" property: const allProps = e.select(e.Movie, () => ({ ...e.Movie['*'] })); Fetch linked objects with a nested shape. .. tabs:: .. code-tab:: edgeql select Movie { id, title, actors: { name } }; .. code-tab:: typescript const query = e.select(e.Movie, () => ({ id: true, title: true, actors: { name: true, } })); const result = await query.run(client); // {id: string; title: string, actors: {name: string}[]}[] See :ref:`Docs > EdgeQL > Select > Shapes `. Filtering, ordering, and pagination ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ The ``select`` statement can be augmented with ``filter``, ``order by``, ``offset``, and ``limit`` clauses (in that order). .. tabs:: .. code-tab:: edgeql select Movie { id, title } filter .release_year > 2017 order by .title offset 10 limit 10; .. code-tab:: typescript const query = e.select(e.Movie, (movie) => ({ id: true, title: true, filter: e.op(movie.release_year, ">", 2017), order_by: movie.title, offset: 10, limit: 10, })); const result = await query.run(client); // {id: string; title: string}[] Note that you reference properties of the object to include in your ``select`` by prepending the property name with a period: ``.release_year``. This is known as *leading dot notation*. Every new set of curly braces introduces a new scope.
You can add ``filter``, ``limit``, and ``offset`` clauses to nested shapes. .. tabs:: .. code-tab:: edgeql select Movie { title, actors: { name } filter .name ilike 'chris%' } filter .title ilike '%avengers%'; .. code-tab:: typescript const query = e.select(e.Movie, movie => ({ title: true, actors: actor => ({ name: true, filter: e.op(actor.name, "ilike", "chris%"), }), filter: e.op(movie.title, "ilike", "%avengers%"), })); const result = await query.run(client); // {title: string; actors: {name: string}[]}[] See :ref:`Filtering `, :ref:`Ordering `, and :ref:`Pagination `. Query composition ^^^^^^^^^^^^^^^^^ We've seen how to ``insert`` and ``select``. How do we do both in one query? Answer: query composition. EdgeQL's syntax is designed to be *composable*, like any good programming language. .. tabs:: .. code-tab:: edgeql select ( insert Movie { title := 'The Marvels' } ) { id, title }; .. code-tab:: typescript const newMovie = e.insert(e.Movie, { title: "The Marvels" }); const query = e.select(newMovie, () => ({ id: true, title: true })); const result = await query.run(client); // {id: string; title: string} We can clean up this query by pulling out the ``insert`` statement into a ``with`` block. A ``with`` block is useful for composing complex multi-step queries, like a script. .. tabs:: .. code-tab:: edgeql with new_movie := (insert Movie { title := 'The Marvels' }) select new_movie { id, title }; .. code-tab:: typescript /* Same as above. In the query builder, explicit ``with`` blocks aren't necessary! Just assign your EdgeQL subqueries to variables and compose them as you like. The query builder automatically converts your top-level query to an EdgeQL expression with proper ``with`` blocks. */ Computed properties ^^^^^^^^^^^^^^^^^^^ Selection shapes can contain computed properties. .. tabs:: .. code-tab:: edgeql select Movie { title, title_upper := str_upper(.title), cast_size := count(.actors) }; ..
code-tab:: typescript e.select(e.Movie, movie => ({ title: true, title_upper: e.str_upper(movie.title), cast_size: e.count(movie.actors) })) // {title: string; title_upper: string; cast_size: number}[] A common use for computed properties is to query a link in reverse; this is known as a *backlink* and it has special syntax. .. tabs:: .. code-tab:: edgeql select Person { name, acted_in := .<actors[is Movie] { title } }; .. code-tab:: typescript e.select(e.Person, person => ({ name: true, acted_in: e.select(person["<actors[is Movie]"], () => ({ title: true, })), })); // {name: string; acted_in: {title: string}[];}[] See :ref:`Docs > EdgeQL > Select > Computed fields ` and :ref:`Docs > EdgeQL > Select > Backlinks `. Update objects ^^^^^^^^^^^^^^ The ``update`` statement accepts a ``filter`` clause up-front, followed by a ``set`` shape indicating how the matching objects should be updated. .. tabs:: .. code-tab:: edgeql update Movie filter .title = "Doctor Strange 2" set { title := "Doctor Strange in the Multiverse of Madness" }; .. code-tab:: typescript const query = e.update(e.Movie, (movie) => ({ filter: e.op(movie.title, '=', 'Doctor Strange 2'), set: { title: 'Doctor Strange in the Multiverse of Madness', }, })); const result = await query.run(client); // {id: string} When updating links, the set of linked objects can be added to with ``+=``, subtracted from with ``-=``, or overwritten with ``:=``. .. tabs:: .. code-tab:: edgeql update Movie filter .title = "Doctor Strange 2" set { actors += (select Person filter .name = "Rachel McAdams") }; .. code-tab:: typescript e.update(e.Movie, (movie) => ({ filter: e.op(movie.title, '=', 'Doctor Strange 2'), set: { actors: { "+=": e.select(e.Person, person => ({ filter: e.op(person.name, "=", "Rachel McAdams") })) } }, })); See :ref:`Docs > EdgeQL > Update `. Delete objects ^^^^^^^^^^^^^^ The ``delete`` statement can contain ``filter``, ``order by``, ``offset``, and ``limit`` clauses. .. tabs:: .. code-tab:: edgeql delete Movie filter .title ilike "the avengers%" limit 3; ..
code-tab:: typescript const query = e.delete(e.Movie, (movie) => ({ filter: e.op(movie.title, 'ilike', "the avengers%"), limit: 3, })); const result = await query.run(client); // {id: string}[] See :ref:`Docs > EdgeQL > Delete `. Query parameters ^^^^^^^^^^^^^^^^ You can reference query parameters in your queries with ``$`` notation. Since EdgeQL is a strongly typed language, all query parameters must be prefixed with a *type cast* to indicate the expected type. .. note:: Scalars like ``str``, ``int64``, and ``json`` are supported. Tuples, arrays, and object types are not. .. tabs:: .. code-tab:: edgeql insert Movie { title := <str>$title, release_year := <int64>$release_year }; .. code-tab:: typescript const query = e.params({ title: e.str, release_year: e.int64 }, ($) => { return e.insert(e.Movie, { title: $.title, release_year: $.release_year, }); }); const result = await query.run(client, { title: 'Thor: Love and Thunder', release_year: 2022, }); // {id: string} All client libraries provide a dedicated API for specifying parameters when executing a query. .. tabs:: .. code-tab:: javascript import {createClient} from "gel"; const client = createClient(); const result = await client.query(`select <str>$param`, { param: "Play it, Sam." }); // => "Play it, Sam." .. code-tab:: python import gel client = gel.create_async_client() async def main(): result = await client.query("select <str>$param", param="Play it, Sam") # => "Play it, Sam" .. code-tab:: go package main import ( "context" "log" "github.com/geldata/gel-go" ) func main() { ctx := context.Background() client, err := gel.CreateClient(ctx, gel.Options{}) if err != nil { log.Fatal(err) } defer client.Close() var ( param string = "Play it, Sam." result string ) query := "select <str>$0" err = client.Query(ctx, query, &result, param) // ... } ..
code-tab:: rust // [dependencies] // gel-tokio = "0.5.0" // tokio = { version = "1.28.1", features = ["macros", "rt-multi-thread"] } #[tokio::main] async fn main() { let conn = gel_tokio::create_client() .await .expect("Client initiation"); let param = "Play it, Sam."; let val = conn .query_required_single::<String, _>("select <str>$0", &(param,)) .await .expect("Returning value"); println!("{val}"); } See :ref:`Docs > EdgeQL > Parameters `. Subqueries ^^^^^^^^^^ Unlike SQL, EdgeQL is *composable*; queries can be naturally nested. This is useful, for instance, when performing nested mutations. .. tabs:: .. code-tab:: edgeql with dr_strange := (select Movie filter .title = "Doctor Strange"), benedicts := (select Person filter .name in { 'Benedict Cumberbatch', 'Benedict Wong' }) update dr_strange set { actors += benedicts }; .. code-tab:: typescript // select Doctor Strange const drStrange = e.select(e.Movie, movie => ({ filter: e.op(movie.title, '=', "Doctor Strange") })); // select actors const actors = e.select(e.Person, person => ({ filter: e.op(person.name, 'in', e.set( 'Benedict Cumberbatch', 'Benedict Wong' )) })); // add actors to cast of drStrange const query = e.update(drStrange, ()=>({ set: { actors: { "+=": actors } } })); We can also use subqueries to fetch properties of an object we just inserted. .. tabs:: .. code-tab:: edgeql with new_movie := (insert Movie { title := "Avengers: The Kang Dynasty", release_year := 2025 }) select new_movie { title, release_year }; .. code-tab:: typescript // "with" blocks are added automatically // in the generated query! const newMovie = e.insert(e.Movie, { title: "Avengers: The Kang Dynasty", release_year: 2025 }); const query = e.select(newMovie, ()=>({ title: true, release_year: true, })); const result = await query.run(client); // {title: string; release_year: number;} See :ref:`Docs > EdgeQL > Select > Subqueries `. Polymorphic queries ^^^^^^^^^^^^^^^^^^^ Consider the following schema. ..
code-block:: sdl abstract type Content { required title: str; } type Movie extending Content { release_year: int64; } type TVShow extending Content { num_seasons: int64; } We can ``select`` the abstract type ``Content`` to simultaneously fetch all objects that extend it, and use the ``[is ]`` syntax to select properties from known subtypes. .. tabs:: .. code-tab:: edgeql select Content { title, [is TVShow].num_seasons, [is Movie].release_year }; .. code-tab:: typescript const query = e.select(e.Content, (content) => ({ title: true, ...e.is(e.Movie, {release_year: true}), ...e.is(e.TVShow, {num_seasons: true}), })); /* { title: string; release_year: number | null; num_seasons: number | null; }[] */ See :ref:`Docs > EdgeQL > Select > Polymorphic queries `. Grouping objects ^^^^^^^^^^^^^^^^ Unlike SQL, EdgeQL provides a top-level ``group`` statement to compute groupings of objects. .. tabs:: .. code-tab:: edgeql group Movie { title, actors: { name }} by .release_year; .. code-tab:: typescript e.group(e.Movie, (movie) => { const release_year = movie.release_year; return { title: true, by: {release_year}, }; }); /* { grouping: string[]; key: { release_year: number | null }; elements: { title: string; }[]; }[] */ See :ref:`Docs > EdgeQL > Group `. ================================================ FILE: docs/intro/guides/ai/edgeql.rst ================================================ .. _ref_ai_guide_edgeql: ================ Gel AI in EdgeQL ================ :edb-alt-title: How to set up Gel AI in EdgeQL |Gel| AI brings vector search capabilities and retrieval-augmented generation directly into the database. Enable and configure the extension ================================== .. edb:split-section:: AI is a |Gel| extension. To enable it, we will need to add the extension to the app’s schema: .. code-block:: sdl using extension ai; .. edb:split-section:: |Gel| AI uses external APIs in order to get vectors and LLM completions. 
For it to work, we need to configure an API provider and specify its API key. Let's open the EdgeQL REPL and run the following query: .. code-block:: edgeql configure current database insert ext::ai::OpenAIProviderConfig { secret := 'sk-....', }; Now our |Gel| application can take advantage of OpenAI's API to implement AI capabilities. .. note:: |Gel| AI comes with its own :ref:`UI ` that can be used to configure providers, set up prompts and test them in a sandbox. .. note:: Most API providers require you to set up an account and charge money for model use. Add vectors and perform similarity search ========================================= .. edb:split-section:: Before we start introducing AI capabilities, let's set up our database with a schema and populate it with some data (we're going to be helping Komi-san keep track of her friends). .. code-block:: sdl module default { type Friend { required name: str { constraint exclusive; }; summary: str; # A brief description of personality and role relationship_to_komi: str; # Relationship with Komi defining_trait: str; # Primary character trait or quirk } } .. edb:split-section:: Here's a shell command you can paste and run that will populate the database with some sample data. .. code-block:: bash :class: collapsible $ cat << 'EOF' > populate_db.edgeql insert Friend { name := 'Tadano Hitohito', summary := 'An extremely average high school boy with a remarkable ability to read the atmosphere and understand others\' feelings, especially Komi\'s.', relationship_to_komi := 'First friend and love interest', defining_trait := 'Perceptiveness', }; insert Friend { name := 'Osana Najimi', summary := 'An extremely outgoing person who claims to have been everyone\'s childhood friend.
Gender: Najimi.', relationship_to_komi := 'Second friend and social catalyst', defining_trait := 'Universal childhood friend', }; insert Friend { name := 'Yamai Ren', summary := 'An intense and sometimes obsessive classmate who is completely infatuated with Komi.', relationship_to_komi := 'Self-proclaimed guardian and admirer', defining_trait := 'Obsessive devotion', }; insert Friend { name := 'Katai Makoto', summary := 'A intimidating-looking but shy student who shares many communication problems with Komi.', relationship_to_komi := 'Fellow communication-challenged friend', defining_trait := 'Scary appearance but gentle nature', }; insert Friend { name := 'Nakanaka Omoharu', summary := 'A self-proclaimed wielder of dark powers who acts like an anime character and is actually just a regular gaming enthusiast.', relationship_to_komi := 'Gaming buddy and chuunibyou friend', defining_trait := 'Chuunibyou tendencies', }; EOF $ gel query -f populate_db.edgeql .. edb:split-section:: In order to get |Gel| to produce embedding vectors, we need to create a special ``deferred index`` on the type we would like to perform similarity search on. More specifically, we need to specify an EdgeQL expression that produces a string that we're going to create an embedding vector for. This is how we would set up an index if we wanted to perform similarity search on ``Friend.summary``: .. code-block:: sdl-diff module default { type Friend { required name: str { constraint exclusive; }; summary: str; # A brief description of personality and role relationship_to_komi: str; # Relationship with Komi defining_trait: str; # Primary character trait or quirk + deferred index ext::ai::index(embedding_model := 'text-embedding-3-small') + on (.summary); } } .. edb:split-section:: But actually, in our case it would be better if we could similarity search across all properties at the same time. 
We can define the index on a more complex expression - like a concatenation of string properties - like this: .. code-block:: sdl-diff module default { type Friend { required name: str { constraint exclusive; }; summary: str; # A brief description of personality and role relationship_to_komi: str; # Relationship with Komi defining_trait: str; # Primary character trait or quirk deferred index ext::ai::index(embedding_model := 'text-embedding-3-small') - on (.summary); + on ( + .name ++ ' ' ++ .summary ++ ' ' + ++ .relationship_to_komi ++ ' ' + ++ .defining_trait + ); } } .. edb:split-section:: Once we're done modifying the schema, we need to apply the changes by going through a migration: .. code-block:: bash $ gel migration create $ gel migrate .. edb:split-section:: That's it! |Gel| will make necessary API requests in the background and create an index that will enable us to perform efficient similarity search like this: .. code-block:: edgeql select ext::ai::search(Friend, query_vector); .. edb:split-section:: Note that this function accepts an embedding vector as the second argument, not a text string. This means that in order to similarity search for a string, we need to create a vector embedding for it using the same model as we used to create the index. |Gel| offers an HTTP endpoint ``/ai/embeddings`` that can handle it for us. All we need to do is to pass the vector it produces into the search query: .. note:: Note that we're passing our login and password in order to authenticate the request. We can find those using the CLI: :gelcmd:`instance credentials --json`. Learn about all the other ways you can authenticate a request :ref:`here `. ..
code-block:: bash $ curl --user user:password \ --json '{"input": "Who helps Komi make friends?", "model": "text-embedding-3-small"}' \ http://localhost:<port>/branch/main/ai/embeddings | # request the embedding jq -r '.data[0].embedding' | # extract the embedding out of the JSON tr -d '\n' | # remove newlines sed 's/^\[//;s/\]$//' | # remove square brackets awk '{print "select ext::ai::search(Friend, <array<float32>>[" $0 "]);"}' | # assemble the query gel query --file - # pass the query into Gel CLI Use the built-in RAG ==================== One more feature |Gel| AI offers is built-in retrieval-augmented generation, also known as RAG. .. edb:split-section:: |Gel| comes preconfigured to be able to process our text query, perform similarity search across the index we just created, pass the results to an LLM and return a response. We can access the built-in RAG using the ``/ai/rag`` HTTP endpoint: .. code-block:: bash $ curl --user user:password --json '{ "query": "Who helps Komi make friends?", "model": "gpt-4-turbo-preview", "context": {"query":"select Friend"} }' http://localhost:<port>/branch/main/ai/rag .. edb:split-section:: We can also stream the response like this: .. code-block:: bash-diff $ curl --user user:password --json '{ "query": "Who helps Komi make friends?", "model": "gpt-4-turbo-preview", "context": {"query":"select Friend"}, + "stream": true }' http://localhost:<port>/branch/main/ai/rag Keep going! =========== You are now sufficiently equipped to use |Gel| AI in your applications. If you'd like to build something on your own, make sure to check out the :ref:`Reference manual ` in order to learn the details about using different APIs and models, configuring prompts or using the UI. Make sure to also check out the |Gel| AI bindings in :ref:`Python ` and :ref:`JavaScript ` if those languages are relevant to you.
And if you would like more guidance for how |Gel| AI can be fit into an application, take a look at the :ref:`FastAPI Gel AI Tutorial `, where we're building a search bot using features you learned about above. ================================================ FILE: docs/intro/guides/ai/index.rst ================================================ .. edb:env-switcher:: ========= Adding AI ========= .. toctree:: :maxdepth: 1 :hidden: edgeql python ================================================ FILE: docs/intro/guides/ai/python.rst ================================================ .. _ref_ai_guide_python: ================ Gel AI in Python ================ :edb-alt-title: How to set up Gel AI in Python .. edb:split-section:: |Gel| AI brings vector search capabilities and retrieval-augmented generation directly into the database. It's integrated into the |Gel| Python binding via the ``gel.ai`` module. .. code-block:: bash $ pip install 'gel[ai]' Enable and configure the extension ================================== .. edb:split-section:: AI is a |Gel| extension. To enable it, we will need to add the extension to the app’s schema: .. code-block:: sdl using extension ai; .. edb:split-section:: |Gel| AI uses external APIs in order to get vectors and LLM completions. For it to work, we need to configure an API provider and specify its API key. Let's open the EdgeQL REPL and run the following query: .. code-block:: edgeql configure current database insert ext::ai::OpenAIProviderConfig { secret := 'sk-....', }; Now our |Gel| application can take advantage of OpenAI's API to implement AI capabilities. .. note:: |Gel| AI comes with its own :ref:`UI ` that can be used to configure providers, set up prompts and test them in a sandbox. .. note:: Most API providers require you to set up an account and charge money for model use. Add vectors =========== ..
edb:split-section:: Before we start introducing AI capabilities, let's set up our database with a schema and populate it with some data (we're going to be helping Komi-san keep track of her friends). .. code-block:: sdl module default { type Friend { required name: str { constraint exclusive; }; summary: str; # A brief description of personality and role relationship_to_komi: str; # Relationship with Komi defining_trait: str; # Primary character trait or quirk } } .. edb:split-section:: Here's a shell command you can paste and run that will populate the database with some sample data. .. code-block:: bash :class: collapsible $ cat << 'EOF' > populate_db.edgeql insert Friend { name := 'Tadano Hitohito', summary := 'An extremely average high school boy with a remarkable ability to read the atmosphere and understand others\' feelings, especially Komi\'s.', relationship_to_komi := 'First friend and love interest', defining_trait := 'Perceptiveness', }; insert Friend { name := 'Osana Najimi', summary := 'An extremely outgoing person who claims to have been everyone\'s childhood friend. 
Gender: Najimi.', relationship_to_komi := 'Second friend and social catalyst', defining_trait := 'Universal childhood friend', }; insert Friend { name := 'Yamai Ren', summary := 'An intense and sometimes obsessive classmate who is completely infatuated with Komi.', relationship_to_komi := 'Self-proclaimed guardian and admirer', defining_trait := 'Obsessive devotion', }; insert Friend { name := 'Katai Makoto', summary := 'A intimidating-looking but shy student who shares many communication problems with Komi.', relationship_to_komi := 'Fellow communication-challenged friend', defining_trait := 'Scary appearance but gentle nature', }; insert Friend { name := 'Nakanaka Omoharu', summary := 'A self-proclaimed wielder of dark powers who acts like an anime character and is actually just a regular gaming enthusiast.', relationship_to_komi := 'Gaming buddy and chuunibyou friend', defining_trait := 'Chuunibyou tendencies', }; EOF $ gel query -f populate_db.edgeql .. edb:split-section:: In order to get |Gel| to produce embedding vectors, we need to create a special ``deferred index`` on the type we would like to perform similarity search on. More specifically, we need to specify an EdgeQL expression that produces a string that we're going to create an embedding vector for. This is how we would set up an index if we wanted to perform similarity search on ``Friend.summary``: .. code-block:: sdl-diff module default { type Friend { required name: str { constraint exclusive; }; summary: str; # A brief description of personality and role relationship_to_komi: str; # Relationship with Komi defining_trait: str; # Primary character trait or quirk + deferred index ext::ai::index(embedding_model := 'text-embedding-3-small') + on (.summary); } } .. edb:split-section:: But actually, in our case it would be better if we could similarity search across all properties at the same time. 
We can define the index on a more complex expression - like a concatenation of string properties - like this: .. code-block:: sdl-diff module default { type Friend { required name: str { constraint exclusive; }; summary: str; # A brief description of personality and role relationship_to_komi: str; # Relationship with Komi defining_trait: str; # Primary character trait or quirk deferred index ext::ai::index(embedding_model := 'text-embedding-3-small') - on (.summary); + on ( + .name ++ ' ' ++ .summary ++ ' ' + ++ .relationship_to_komi ++ ' ' + ++ .defining_trait + ); } } .. edb:split-section:: Once we're done modifying the schema, we need to apply the changes by going through a migration: .. code-block:: bash $ gel migration create $ gel migrate That's it! |Gel| will make necessary API requests in the background and create an index that will enable us to perform efficient similarity search. Perform similarity search in Python =================================== .. edb:split-section:: In order to run queries against the index we just created, we need to create a |Gel| client and pass it to a |Gel| AI instance. .. code-block:: python import gel import gel.ai gel_client = gel.create_client() gel_ai = gel.ai.create_rag_client(gel_client) text = "Who helps Komi make friends?" vector = gel_ai.generate_embeddings( text, "text-embedding-3-small", ) gel_client.query( "select ext::ai::search(Friend, <array<float32>>$embedding_vector)", embedding_vector=vector, ) .. edb:split-section:: We are going to execute a query that calls a single function: ``ext::ai::search(, )``. That function accepts an embedding vector as the second argument, not a text string. This means that in order to similarity search for a string, we need to create a vector embedding for it using the same model as we used to create the index. The |Gel| AI binding in Python comes with a ``generate_embeddings`` function that does exactly that: ..
code-block:: python-diff import gel import gel.ai gel_client = gel.create_client() gel_ai = gel.ai.create_rag_client(gel_client) + text = "Who helps Komi make friends?" + vector = gel_ai.generate_embeddings( + text, + "text-embedding-3-small", + ) .. edb:split-section:: Now we can plug that vector directly into our query to get similarity search results: .. code-block:: python-diff import gel import gel.ai gel_client = gel.create_client() gel_ai = gel.ai.create_rag_client(gel_client) text = "Who helps Komi make friends?" vector = gel_ai.generate_embeddings( text, "text-embedding-3-small", ) + gel_client.query( + "select ext::ai::search(Friend, <array<float32>>$embedding_vector)", + embedding_vector=vector, + ) Use the built-in RAG ==================== One more feature |Gel| AI offers is built-in retrieval-augmented generation, also known as RAG. .. edb:split-section:: |Gel| comes preconfigured to be able to process our text query, perform similarity search across the index we just created, pass the results to an LLM and return a response. In order to access the built-in RAG, we need to start by selecting an LLM and passing its name to the |Gel| AI instance constructor: .. code-block:: python-diff import gel import gel.ai gel_client = gel.create_client() gel_ai = gel.ai.create_rag_client( gel_client, + model="gpt-4-turbo-preview" ) .. edb:split-section:: Now we can access the RAG using the ``query_rag`` function like this: .. code-block:: python-diff import gel import gel.ai gel_client = gel.create_client() gel_ai = gel.ai.create_rag_client( gel_client, model="gpt-4-turbo-preview" ) + gel_ai.query_rag( + "Who helps Komi make friends?", + context="Friend", + ) .. edb:split-section:: We can also stream the response like this: .. code-block:: python-diff import gel import gel.ai gel_client = gel.create_client() gel_ai = gel.ai.create_rag_client( gel_client, model="gpt-4-turbo-preview" ) - gel_ai.query_rag( + gel_ai.stream_rag( "Who helps Komi make friends?", context="Friend", ) Keep going!
=========== You are now sufficiently equipped to use |Gel| AI in your applications. If you'd like to build something on your own, make sure to check out the :ref:`Reference manual ` for the AI extension in order to learn the details about using different APIs and models, configuring prompts or using the UI. Make sure to take a look at the :ref:`Python binding reference `, too. And if you would like more guidance for how |Gel| AI can be fit into an application, take a look at the :ref:`FastAPI Gel AI Tutorial `, where we're building a search bot using features you learned about above. ================================================ FILE: docs/intro/guides/drizzle/index.rst ================================================ .. edb:env-switcher:: ================== Adding Drizzle ORM ================== .. toctree:: :maxdepth: 1 :hidden: nextjs ================================================ FILE: docs/intro/guides/drizzle/nextjs.rst ================================================ .. _ref_guide_gel_drizzle: ====================== Drizzle ORM in Next.js ====================== |Gel| integrates seamlessly with Drizzle ORM, providing a type-safe and intuitive way to interact with your database in TypeScript applications. Enable Drizzle in your Gel project ================================== .. edb:split-section:: To integrate Drizzle with your Gel project, you'll need to install the necessary dependencies: .. code-block:: bash $ npm install drizzle-orm $ npm install -D drizzle-kit .. edb:split-section:: Next, create a Drizzle configuration file in your project root to tell Drizzle how to work with your Gel database: .. code-block:: typescript :caption: drizzle.config.ts import { defineConfig } from 'drizzle-kit'; export default defineConfig({ dialect: 'gel', }); Sync your Gel schema with Drizzle ================================= .. edb:split-section:: Before using Drizzle with your Gel database, you'll need to let Drizzle introspect your schema. 
This step generates TypeScript files that Drizzle can use to interact with your database. .. code-block:: bash $ npx drizzle-kit pull .. edb:split-section:: This command will create a schema file based on your Gel database. The file will typically look something like this: .. code-block:: typescript :caption: drizzle/schema.ts :class: collapsible import { gelTable, uniqueIndex, uuid, smallint, text, timestamp, relations } from "drizzle-orm/gel-core" import { sql } from "drizzle-orm" export const books = gelTable("Book", { id: uuid().default(sql`uuid_generate_v4()`).primaryKey().notNull(), title: text().notNull(), author: text(), year: smallint(), genre: text(), read_date: timestamp(), }, (table) => [ uniqueIndex("books_pkey").using("btree", table.id.asc().nullsLast().op("uuid_ops")), ]); export const notes = gelTable("Note", { id: uuid().default(sql`uuid_generate_v4()`).primaryKey().notNull(), text: text().notNull(), created_at: timestamp().default(sql`datetime_current()`), book_id: uuid().notNull(), }, (table) => [ uniqueIndex("notes_pkey").using("btree", table.id.asc().nullsLast().op("uuid_ops")), ]); Keep Drizzle in sync with Gel ============================= .. edb:split-section:: To keep your Drizzle schema in sync with your Gel schema, add a hook to your ``gel.toml`` file. This hook will automatically run ``drizzle-kit pull`` after each migration: .. code-block:: toml :caption: gel.toml [hooks] after_migration_apply = [ "npx drizzle-kit pull" ] With this hook in place, your Drizzle schema will automatically update whenever you apply Gel migrations. Create a database client ======================== .. edb:split-section:: Now, let's create a database client that you can use throughout your application: .. 
code-block:: typescript :caption: src/db/index.ts import { drizzle } from 'drizzle-orm/gel'; import { createClient } from 'gel'; // Import the schema and relations generated by `drizzle-kit pull` import * as schema from '@/drizzle/schema'; import * as relations from '@/drizzle/relations'; // Initialize Gel client const gelClient = createClient(); // Create Drizzle instance export const db = drizzle({ client: gelClient, schema: { ...schema, ...relations }, }); // Helper types for use in our application export type Book = typeof schema.books.$inferSelect; export type NewBook = typeof schema.books.$inferInsert; export type Note = typeof schema.notes.$inferSelect; export type NewNote = typeof schema.notes.$inferInsert; Perform database operations with Drizzle ======================================== For more detailed information on querying and other operations, refer to the `Drizzle documentation `_. .. edb:split-section:: Drizzle provides a clean, type-safe API for database operations. Here are some examples of common operations: **Selecting data:** .. code-block:: typescript import { eq } from 'drizzle-orm'; import { db } from '@/db'; import { books, notes } from '@/drizzle/schema'; // Get all books with their notes const allBooks = await db.query.books.findMany({ with: { notes: true, }, }); // Get a specific book const book = await db.query.books.findFirst({ where: eq(books.id, id), with: { notes: true }, }); .. edb:split-section:: **Inserting data:** .. code-block:: typescript // Insert a new book (`returning()` yields an array of inserted rows) const [newBook] = await db.insert(books).values({ title: 'The Great Gatsby', author: 'F. Scott Fitzgerald', year: 1925, genre: 'Novel', }).returning(); // Insert a note for a book const [newNote] = await db.insert(notes).values({ text: 'A classic novel about the American Dream', book_id: newBook.id, }).returning(); **Bulk inserting data:** ..
code-block:: typescript // Insert multiple books at once const newBooks = await db.insert(books).values([ { title: '1984', author: 'George Orwell', year: 1949, genre: 'Dystopian', }, { title: 'To Kill a Mockingbird', author: 'Harper Lee', year: 1960, genre: 'Fiction', }, { title: 'Pride and Prejudice', author: 'Jane Austen', year: 1813, genre: 'Romance', }, ]).returning(); .. edb:split-section:: **Updating data:** .. code-block:: typescript // Update a book const updatedBook = await db.update(books) .set({ title: 'Updated Title', author: 'Updated Author', }) .where(eq(books.id, bookId)) .returning(); .. edb:split-section:: **Deleting data:** .. code-block:: typescript // Delete a note await db.delete(notes).where(eq(notes.id, noteId)); Using Drizzle with Next.js ========================== .. edb:split-section:: In a Next.js application, you can use your Drizzle client in API routes and server components. Here's an example of an API route that gets all books: .. code-block:: typescript :caption: src/app/api/books/route.ts import { NextResponse } from 'next/server'; import { db } from '@/db'; export async function GET() { try { const allBooks = await db.query.books.findMany({ with: { notes: true }, }); return NextResponse.json(allBooks); } catch (error) { console.error('Error fetching books:', error); return NextResponse.json( { error: 'Failed to fetch books' }, { status: 500 } ); } } .. edb:split-section:: And here's an example of using Drizzle in a server component: .. code-block:: typescript :caption: src/app/books/page.tsx import { db } from '@/db'; import BookCard from '@/components/BookCard'; export default async function BooksPage() { const books = await db.query.books.findMany({ with: { notes: true }, }); return (
<div> {books.map((book) => ( <BookCard key={book.id} book={book} /> ))} </div>
); } Keep going! =========== You are now ready to use Gel with Drizzle in your applications. This integration gives you the best of both worlds: Gel's powerful features and Drizzle's type-safe, intuitive API. For a complete example of using Gel with Drizzle in a Next.js application, check out our `Book Notes app example `_. You can also find a detailed tutorial on building a Book Notes app with Gel, Drizzle, and Next.js in our :ref:`documentation `. ================================================ FILE: docs/intro/guides/index.rst ================================================ ====== Guides ====== .. toctree:: :maxdepth: 1 ai/index drizzle/index ================================================ FILE: docs/intro/index.rst ================================================ .. _ref_intro: ============== Welcome to Gel ============== .. toctree:: :maxdepth: 3 :hidden: installation quickstart/index tutorials/index guides/index localdev cli instances projects schema migrations branches edgeql clients |Gel| is a next-generation `graph-relational database `_ designed as a spiritual successor to the relational database. It inherits the strengths of SQL databases: type safety, performance, reliability, and transactionality. But instead of modeling data in a relational (tabular) way, Gel represents data with *object types* containing *properties* and *links* to other objects. It leverages this object-oriented model to provide a superpowered query language that solves some of SQL's biggest usability problems. How to read the docs ^^^^^^^^^^^^^^^^^^^^ |Gel| is a complex system, but we've structured the documentation so you can learn it in "phases". You only need to learn as much as you need to start building your application. - **Get Started** — Start with the :ref:`quickstart `. It walks through Gel's core workflows: how to install Gel, create an instance, write a simple schema, execute a migration, write some simple queries, and use the client libraries. 
The rest of the section goes deeper on each of these subjects. - **Schema** — A set of pages that break down the concepts and syntax of Gel's schema definition language (SDL). This starts with a rundown of Gel's primitive type system (:ref:`Primitives `), followed by a description of :ref:`Object Types ` and the things they can contain: links, properties, indexes, access policies, and more. - **EdgeQL** — A set of pages that break down Gel's query language, EdgeQL. It starts with a rundown of how to declare :ref:`literal values `, then introduces some key EdgeQL concepts like sets, paths, and type casts. With the basics established, it proceeds to break down all of EdgeQL's top-level statements: ``select``, ``insert``, and so on. - **Guides** — Contains collections of guides on topics that are peripheral to Gel itself: how to deploy to various cloud providers, how to integrate with various frameworks, and how to introspect the schema to build code-generation tools on top of Gel. - **Standard Library** — This section contains an encyclopedic breakdown of Gel's built-in types and the functions/operators that can be used with them. We didn't want to clutter the **EdgeQL** section with all the nitty-gritty on each of these. If you're looking for a particular function (say, a ``replace``), go to the Standard Library page for the relevant type (in this case, :ref:`String `), and peruse the table for what you're looking for (:eql:func:`str_replace`). - **Client Libraries** The documentation for Gel's set of official client libraries for JavaScript/TypeScript, Python, Go, and Rust. All client libraries implement Gel's binary protocol and provide a standard interface for executing queries. If you're using another language, you can execute queries :ref:`over HTTP `. This section also includes documentation for Gel's :ref:`GraphQL ` endpoint. - **CLI** Complete reference for the |gelcmd| command-line tool.
The CLI is self-documenting—add the ``--help`` flag after any command to print the relevant documentation—so you shouldn't need to reference this section often. - **Reference** The *Reference* section contains a complete breakdown of Gel's *syntax* (for both EdgeQL and SDL), *internals* (like the binary protocol and dump file format), and *configuration settings*. Usually you'll only need to reference these once you're an advanced user. - **Changelog** Detailed changelogs for each successive version of Gel, including any breaking changes, new features, bugfixes, and links to Tooling ^^^^^^^ To actually build apps with Gel, you'll need to know more than SDL and EdgeQL. - **CLI** — The most commonly used CLI functionality is covered in the :ref:`Quickstart `. For additional details, we have dedicated guides for :ref:`Migrations ` and :ref:`Projects `. A full CLI reference is available under :ref:`CLI `. - **Client Libraries** — To actually execute queries, you'll use one of our client libraries for JavaScript, Go, or Python; find your preferred library under :ref:`Client Libraries `. If you're using another language, you can still use Gel! You can execute :ref:`queries via HTTP `. - **Deployment** — To publish a Gel-backed application, you'll need to deploy Gel. Refer to :ref:`Guides > Deployment ` for step-by-step deployment guides for all major cloud hosting platforms, as well as instructions for self-hosting with Docker. .. .. eql:react-element:: DocsNavTable |Gel| features: .. class:: ticklist - strict, strongly typed schema; - powerful and clean query language; - ability to easily work with complex hierarchical data; - built-in support for schema migrations. |Gel| is not a graph database: the data is stored and queried using relational database techniques. Unlike most graph databases, Gel maintains a strict schema. |Gel| is not a document database, but inserting and querying hierarchical document-like data is trivial.
|Gel| is not a traditional object database: despite the classification, it is not an implementation of OOP persistence. ================================================ FILE: docs/intro/install_table.rst ================================================ .. tabs:: .. code-tab:: bash :caption: bash $ curl https://www.geldata.com/sh --proto "=https" -sSf1 | sh .. code-tab:: powershell :caption: Powershell PS> irm https://www.geldata.com/ps1 | iex .. code-tab:: bash :caption: Homebrew $ brew install geldata/tap/gel-cli .. code-tab:: bash :caption: Nixpkgs $ nix-shell -p gel .. code-tab:: bash :caption: JavaScript $ npx gel --version .. code-tab:: bash :caption: Python $ uvx gel --version ================================================ FILE: docs/intro/installation.rst ================================================ .. _ref_cli_gel_install: ============ Installation ============ We provide a :ref:`CLI for managing and interacting with local and remote databases `. If you're using JavaScript or Python, our client libraries will handle downloading and running the CLI for you using tools like ``npx`` and ``uvx``. For everyone else, or if you wish to install the CLI globally, you can install using our bash installer or your operating system's package manager. .. include:: ./install_table.rst ================================================ FILE: docs/intro/instances.rst ================================================ .. _ref_intro_instances: ========= Instances ========= Let's get to the good stuff. You can spin up a Gel instance with a single command. .. code-block:: bash $ gel instance create my_instance This creates a new instance named ``my_instance`` that runs the latest stable version of Gel. (Gel itself will be automatically installed if it isn't already.) Alternatively, you can request a particular version with ``--version``. ..
code-block:: bash $ gel instance create my_instance --version 6.1 $ gel instance create my_instance --version nightly We can execute a query against our new instance with :gelcmd:`query`. Specify which instance to connect to by passing an instance name into the ``-I`` flag. .. code-block:: bash $ gel query "select 3.14" -I my_instance 3.14 Managing instances ^^^^^^^^^^^^^^^^^^ Instances can be stopped, started, restarted, and destroyed. .. code-block:: bash $ gel instance stop -I my_instance $ gel instance start -I my_instance $ gel instance restart -I my_instance $ gel instance destroy -I my_instance Listing instances ^^^^^^^^^^^^^^^^^ To list all instances on your machine: .. code-block:: bash $ gel instance list ┌────────┬──────────────────┬──────────┬────────────────┬──────────┐ │ Kind │ Name │ Port │ Version │ Status │ ├────────┼──────────────────┼──────────┼────────────────┼──────────┤ │ local │ my_instance │ 10700 │ x.x+cc4f3b5 │ active │ │ local │ my_instance_2 │ 10701 │ x.x+cc4f3b5 │ active │ │ local │ my_instance_3 │ 10702 │ x.x+cc4f3b5 │ active │ └────────┴──────────────────┴──────────┴────────────────┴──────────┘ Further reference ^^^^^^^^^^^^^^^^^ For complete documentation on managing instances with the CLI (upgrading, viewing logs, etc.), refer to the :ref:`gel instance ` reference or view the help text in your shell: .. code-block:: bash $ gel instance --help ================================================ FILE: docs/intro/localdev.rst ================================================ ================= Local Development ================= One of Gel's most powerful features is its seamless support for local development. The Gel CLI makes it incredibly easy to spin up a local instance, manage it, access GUI, and iterate on your schema quickly and safely. This guide outlines the flexible options available for your local development workflow. 
If you're using JavaScript or Python, our client libraries will automatically handle the installation for you using tools like ``npx`` and ``uvx``. For other environments or to install the CLI globally, you can use one of the following methods: .. include:: ./install_table.rst Initialize your local instance ============================== It's easy to get started with a local Gel instance. Navigate to the root of your project repository and run: .. code-block:: bash $ gel init This creates a database tied to the current directory and to the :ref:`gel.toml ` file in it, which simplifies connection configuration and installation for you. The command is an alias for :ref:`gel project init `. To conserve resources, Gel automatically puts inactive local development instances to sleep. This means you can have multiple instances running without them draining your system's resources when not in use. Iterate on your schema ====================== Gel simplifies the process of evolving your data model. You can apply changes from your Gel schema files directly to your running local instance without needing to create a separate migration file for every minor adjustment. There are two primary ways to apply schema changes during development: 1. **Automatic updates with** :gelcmd:`watch --migrate` For a hands-off approach, you can use the ``watch`` command. This starts a process that monitors your Gel schema files for changes and automatically migrates your local instance as soon as you save them: .. code-block:: bash $ gel watch --migrate This is ideal for rapid iteration when you want to see your schema changes reflected immediately. 2. **Manual updates with** :gelcmd:`migrate --dev-mode` If you prefer more explicit control, or don't want a background process running, you can apply schema changes manually: ..
code-block:: bash $ gel migrate --dev-mode This command performs the same action as :gelcmd:`watch --migrate`, applying the current state of your schema files to the local instance, but only when you explicitly run it. Finalizing changes ================== Once you're satisfied with the schema changes you've made iteratively, you'll want to create a migration file which will be committed to version control and shared with others. This new migration file will encapsulate all the modifications made since your last migration. 1. **Create the migration file** .. code-block:: bash $ gel migration create This command inspects the differences between your last migration file and the current state of your database schema, then generates a new migration file reflecting these changes. 2. **Align your local instance** After creating the migration, run the following command to align your local instance's migration history with this new migration: .. code-block:: bash $ gel migrate --dev-mode This command effectively "fast-forwards" your local instance. From its perspective, it will appear as though all the iterative changes were applied as part of this single, new migration. This keeps your local development environment consistent with the migration history you'll use in other environments (like staging or production). Undoing destructive changes =========================== Mistakes happen! You might accidentally make a destructive schema change. Fortunately, Gel has your back. Every time you migrate your schema (either via :gelcmd:`watch --migrate` or :gelcmd:`migrate --dev-mode`), a backup of your local instance is automatically taken. This behavior can be disabled by setting the environment variable ``GEL_AUTO_BACKUP_MODE`` to ``disabled``. If you need to roll back to a previous state: 1. **Stop any active migration processes**: Ensure :gelcmd:`watch --migrate` is not running. 2. **Find the backup ID**: Look through your shell's scrollback history.
You'll find messages indicating backups were made, along with their IDs. Identify the ID of the backup created before the destructive change. You can also use the :gelcmd:`instance listbackups` command to list all backups for this instance. 3. **Restore the instance** .. code-block:: bash $ gel instance restore <backup-id> -I <instance-name> Replace ``<backup-id>`` with the actual ID and ``<instance-name>`` with the name of your instance (e.g., ``my_project``). This will restore both your data and schema to the state at that backup point. Once restored, you can make the intended schema changes and then restart :gelcmd:`watch --migrate` or use :gelcmd:`migrate --dev-mode` as preferred. Keeping code in sync ==================== Many Gel language bindings offer code generation capabilities (e.g., query builders, typed query functions). This generated code needs to stay synchronized with your schema. Gel provides a system of hooks and watchers that you can configure in your |gel.toml| file to automate this. These hooks can trigger codegen scripts when: * The schema changes (using the "schema.change.after" hook). * Specific files are edited (using watch scripts). Here's an example |gel.toml| configuration for a TypeScript project. It runs a query builder generator and a queries generator at the appropriate times: .. code-block:: toml [instance] server-version = "6.7" [hooks] "schema.change.after" = "npx @gel/generate edgeql-js && npx @gel/generate queries" [watch] "src/queries/**/*.edgeql" = "npx @gel/generate queries" Explanation: * ``[hooks] / "schema.change.after"``: When any schema change is successfully applied, we run the query builder generator (to reflect schema structure changes) and the queries generator (to update based on new or modified types). * ``[watch] / "src/queries/**/*.edgeql"``: If any ``.edgeql`` files within the ``src/queries/`` directory (or its subdirectories) are modified, the command ``npx @gel/generate queries`` is executed.
This ensures that your typed query functions are always up-to-date with your EdgeQL query definitions. By configuring these hooks and watchers, you can maintain a smooth workflow where your generated code automatically adapts to changes in your schema and query files. ================================================ FILE: docs/intro/migrations.rst ================================================ .. _ref_intro_migrations: ========== Migrations ========== .. index:: fill_expr, cast_expr |Gel's| baked-in migration system lets you painlessly evolve your schema throughout the development process. If you want to work along with this guide, start a new project with :ref:`ref_cli_gel_project_init`. This will create a new instance and some empty schema files to get you started. 1. Start the ``watch`` command ------------------------------ The easiest way to work with your schema in development is by running :gelcmd:`watch --migrate`. This long-running task will monitor your schema files and automatically apply schema changes in your database as you work. .. code-block:: bash $ gel watch --migrate Hint: --migrate will apply any changes from your schema files to the database. When ready to commit your changes, use: 1) `gel migration create` to write those changes to a migration file, 2) `gel migrate --dev-mode` to replace all synced changes with the migration. Monitoring /home/instancename for changes in: --migrate: gel migration apply --dev-mode If you get output similar to the output above, you're ready to get started! 2. Write an initial schema -------------------------- By convention, your Gel schema is defined inside one or more |.gel| files that live in a directory called ``dbschema`` in the root directory of your codebase. .. code-block:: . ├── dbschema │ └── default.gel # schema file (written by you) └── gel.toml The schema itself is written using Gel's schema definition language. 
Edit your :dotgel:`dbschema/default` and add the following schema inside your ``module default`` block: .. code-block:: sdl type User { required name: str; } type Post { required title: str; required author: User; } It's common to keep your entire schema in a single file, and many users stick with the :dotgel:`default` file that is created when you start a project. However, it's also possible to split your schema across a number of |.gel| files. Once you save your initial schema, assuming it is valid, the ``watch`` command will pick it up and apply it to your database. 3. Edit your schema files ------------------------- As your application evolves, directly edit your schema files to reflect your desired data model. Try updating your :dotgel:`dbschema/default` to add a ``Comment`` type: .. code-block:: sdl-diff type User { required name: str; } type Post { required title: str; required author: User; } + type Comment { + required content: str; + } When you save your changes, ``watch`` will immediately begin applying your new schema to the database. .. note:: If your schema cannot be applied, the ``watch`` command will generate an error. If you're using one of our client bindings as you update your schema with ``watch``, you will see the error there the next time you execute a query using that client binding. If things aren't working the way you expect after making a schema change, take a look at the ``watch`` console to find out why. Once you have the schema the way you want it, and you're ready to lock it in and commit it to version control, it's time to generate a migration. 4. Generate a migration ----------------------- To generate a migration that reflects all your changes, run :gelcmd:`migration create`. .. code-block:: bash $ gel migration create The CLI reads your schema file and sends it to the active Gel instance. The instance compares the file's contents to its current schema state and determines a migration plan.
**The migration plan is generated by the database itself.** This plan is then presented to you interactively; each detected schema change will be individually presented to you for approval. For each prompt, you have a variety of commands at your disposal. Type ``y`` to approve, ``n`` to reject, ``q`` to cancel the migration, or ``?`` for a breakdown of some more advanced options. .. code-block:: bash $ gel migration create did you create object type 'default::Comment'? [y,n,l,c,b,s,q,?] > y did you create object type 'default::User'? [y,n,l,c,b,s,q,?] > y did you create object type 'default::Post'? [y,n,l,c,b,s,q,?] > y Created dbschema/migrations/00001.edgeql, id: .. _ref_intro_migrations_wo_iteration: Migration without iteration --------------------------- If you want to change the schema, but you already know exactly what you want to change and don't need to iterate on your schema — you want to lock in the migration right away — :gelcmd:`watch` might not be the tool you reach for. Instead, you might use this method: 1. Edit your schema files 2. Create your migration with :gelcmd:`migration create` 3. Apply your migration with :gelcmd:`migrate` Since you're not using ``watch``, the schema changes are not applied when you save your schema files. As a result, we need to tack an extra step on the end of the process of applying the migration. That's handled by :gelcmd:`migrate`. .. code-block:: bash $ gel migrate Applied m1virjowa... (00002.edgeql) Once your migration is applied, you'll see the schema changes reflected in your database. Data migrations --------------- Depending on how the schema was changed, data in your database may prevent |Gel| from applying your schema changes. Imagine we added a required ``body`` property to our ``Post`` type: .. 
code-block:: sdl-diff type User { required name: str; } type Post { required title: str; + required body: str; required author: User; } type Comment { required content: str; } If we hadn't added any ``Post`` objects to our database before this, everything would have worked fine, but it's likely that, in testing out our schema, we *did* add a ``Post`` object. It does not have a ``body`` property, but now we've told the database this property is required on all ``Post`` objects. The database can't apply this change because existing data would break it. We have a couple of options here. We could delete all the offending objects. .. code-block:: edgeql-repl db> delete Post; { default::Post {id: a4a0a40c-d9f5-11ed-8912-1397f7af9fdf}, default::Post {id: cc051bea-d9f5-11ed-a26d-2b64b6b273a4} } Now, if we save the schema again, :gelcmd:`watch` will be able to apply it. If we have data in here we don't want to lose though, that's not a good option. In that case, we might drop back to creating and applying the migration outside of :gelcmd:`watch`. To start, run :gelcmd:`migration create`. The interactive plan generator will ask you for an EdgeQL expression to map the contents of your database to the new schema. .. code-block:: bash $ gel migration create did you create property 'body' of object type 'default::Post'? [y,n,l,c,b,s,q,?] > y Please specify an expression to populate existing objects in order to make property 'body' of object type 'default::Post' required: fill_expr> Because the ``body`` property does not currently exist, the database contains ``Post`` objects without it. The expression you provide will be used to *assign a body* to any ``Post`` object that doesn't have one. We'll just provide a simple default: ``'No content'``. .. code-block:: fill_expr> 'No content' Created dbschema/migrations/00002.edgeql, id: m1pjiibv4sa4cao7txpgsbuw2erctmacyrj4qmn45ggapsaztmvxfa Nice! It accepted our answer and created a new migration file ``00002.edgeql``. 
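To make the effect of a fill expression concrete, here is a minimal conceptual sketch in plain Python. It is not how Gel implements migrations; the ``Post`` dataclass and the ``apply_fill_expr`` helper are hypothetical names used only for illustration. It models the behavior described above: every existing object that lacks a value for the newly required property receives the result of the fill expression, while objects that already have a value are left alone.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class Post:
    """Hypothetical in-memory stand-in for a stored ``Post`` object."""
    title: str
    body: Optional[str] = None  # None models "property not set"


def apply_fill_expr(posts: list[Post], fill_expr: Callable[[Post], str]) -> None:
    """Assign a body to every post that lacks one (illustration only)."""
    for post in posts:
        if post.body is None:
            post.body = fill_expr(post)


posts = [Post("First post"), Post("Second post", body="Already written")]

# Similar in spirit to answering the prompt with: fill_expr> 'No content'
apply_fill_expr(posts, lambda post: "No content")

# After the fill, every post has a body, so the property can be required.
assert all(post.body is not None for post in posts)
```

Because the callable receives each object, the same idea extends to fill expressions that depend on other properties of the object, for example deriving a body from the title.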
Let's see what the newly created ``00002.edgeql`` file contains. .. code-block:: edgeql CREATE MIGRATION m1pjiibv4sa4cao7txpgsbuw2erctmacyrj4qmn45ggapsaztmvxfa ONTO m1nlvzbm7buwktkp4vu4shylq6zp2shruokbbssyeidqmmmfqz77yq { ALTER TYPE default::Post { CREATE REQUIRED PROPERTY body: std::str { SET REQUIRED USING ('No content'); }; }; }; We have a ``CREATE MIGRATION`` block containing an ``ALTER TYPE`` statement to create ``Post.body`` as a ``required`` property. We can see that our fill expression (``'No content'``) is included directly in the migration file. Note that we could have provided an *arbitrary EdgeQL expression*! The following EdgeQL features are often useful: .. list-table:: * - ``assert_exists`` - This is an "escape hatch" function that tells Gel to assume the input has *at least* one element. .. code-block:: fill_expr> assert_exists(.body) If you provide a ``fill_expr`` like the one above, you must separately ensure that all posts have a ``body`` before executing the migration; otherwise it will fail. * - ``assert_single`` - This tells Gel to assume the input has *at most* one element. This will throw an error if the argument is a set containing more than one element. This is useful if you are changing a property from ``multi`` to ``single``. .. code-block:: fill_expr> assert_single(.sheep) * - type casts - Useful when converting a property to a different type. .. code-block:: cast_expr> .xp Further reading ^^^^^^^^^^^^^^^ - :ref:`Guide to schema migrations ` - :ref:`Migration tips ` Further information can be found in the :ref:`CLI reference ` or the `Beta 1 blog post `_, which describes the design of the migration system. ================================================ FILE: docs/intro/projects.rst ================================================ .. _ref_intro_projects: ======== Projects ======== It can be inconvenient to pass the ``-I`` flag every time you wish to run a CLI command. ..
code-block:: bash $ gel migration create -I my_instance That's one of the reasons we introduced the concept of a *Gel project*. A project is a directory on your file system that is associated ("linked") with a Gel instance. .. note:: Projects are intended to make *local development* easier! They only exist on your local machine and are managed with the CLI. When deploying Gel for production, you will typically pass connection information to the client library using environment variables. When you're inside a project, all CLI commands will be applied against the *linked instance* by default (no CLI flags required). .. code-block:: bash $ gel migration create The same is true for all Gel client libraries (discussed in more depth in the :ref:`Clients ` section). If the following file lives inside a Gel project directory, ``createClient`` will discover the project and connect to its linked instance with no additional configuration. .. code-block:: typescript // clientTest.js import {createClient} from 'gel'; const client = createClient(); await client.query("select 5"); Initializing ^^^^^^^^^^^^ To initialize a project, create a new directory and run :gelcmd:`project init` inside it. You'll see something like this: .. code-block:: bash $ gel project init No `gel.toml` found in this repo or above. Do you want to initialize a new project? [Y/n] > Y Specify the name of Gel instance to use with this project [default: my_instance]: > my_instance Checking Gel versions... Specify the version of Gel to use with this project [default: x.x]: > # (left blank for default) ... Successfully installed x.x+cc4f3b5 Initializing Gel instance... Applying migrations... Everything is up to date. Revision initial Project initialized. To connect to my_instance, run `gel` This command does a couple of important things. 1. It spins up a new Gel instance called ``my_instance``. 2. If no |gel.toml| file exists, it will create one.
This is a configuration file that marks a given directory as a Gel project. Learn more about it in the :ref:`gel.toml reference `. .. code-block:: toml [instance] server-version = "6.0" 3. If no ``dbschema`` directory exists, it will be created, along with an empty :dotgel:`default` file which will contain your schema. If a ``dbschema`` directory exists and contains a subdirectory called ``migrations``, those migrations will be applied against the new instance. Every project maps one-to-one to a particular Gel instance. From inside a project directory, you can run :gelcmd:`project info` to see information about the current project. .. code-block:: bash $ gel project info ┌───────────────┬──────────────────────────────────────────┐ │ Instance name │ my_instance │ │ Project root │ /path/to/project │ └───────────────┴──────────────────────────────────────────┘ Connection ^^^^^^^^^^ As long as you are inside the project directory, all CLI commands will be executed against the project-linked instance. For instance, you can simply run |gelcmd| to open a REPL. .. code-block:: bash $ gel Gel x.x+cc4f3b5 (repl x.x+da2788e) Type \help for help, \quit to quit. my_instance:main> select "Hello world!"; By contrast, if you leave the project directory, the CLI will no longer know which instance to connect to. You can solve this by specifying an instance name with the ``-I`` flag. .. code-block:: bash $ cd ~ $ gel gel error: no `gel.toml` found and no connection options are specified Hint: Run `gel project init` or use any of `-H`, `-P`, `-I` arguments to specify connection parameters. See `--help` for details $ gel -I my_instance Gel x.x+cc4f3b5 (repl x.x+da2788e) Type \help for help, \quit to quit. my_instance:main> Similarly, client libraries will auto-connect to the project's linked instance without additional configuration. Using remote instances ^^^^^^^^^^^^^^^^^^^^^^ You may want to initialize a project that points to a remote Gel instance.
This is a valid use case that Gel fully supports. Before running :gelcmd:`project init`, you just need to create an alias for the remote instance using :gelcmd:`instance link`, like so: .. lint-off .. code-block:: bash $ gel instance link Specify server host [default: localhost]: > 192.168.4.2 Specify server port [default: 5656]: > 10818 Specify database user [default: admin]: > admin Specify branch [default: main]: > main Unknown server certificate: SHA1:c38a7a90429b033dfaf7a81e08112a9d58d97286. Trust? [y/N] > y Password for 'admin': Specify a new instance name for the remote server [default: abcd]: > staging_db Successfully linked to remote instance. To connect run: gel -I staging_db .. lint-on After receiving the necessary connection information, this command links the remote instance to a local alias ``"staging_db"``. You can use this as an instance name in CLI commands. .. code-block:: $ gel -I staging_db gel> To initialize a project that uses the remote instance, provide this alias when prompted for an instance name during the :gelcmd:`project init` workflow. Unlinking ^^^^^^^^^ An instance can be unlinked from a project. This leaves the instance running but effectively "uninitializes" the project. The |gel.toml| and ``dbschema`` are left untouched. .. code-block:: bash $ gel project unlink If you wish to delete the instance as well, use the ``-D`` flag. .. code-block:: bash $ gel project unlink -D Upgrading ^^^^^^^^^ An instance linked to a project can be upgraded with the :gelcmd:`project upgrade` command. (A standalone instance, not linked to a project, is upgraded with :gelcmd:`instance upgrade` instead.) .. code-block:: bash $ gel project upgrade --to-latest $ gel project upgrade --to-nightly $ gel project upgrade --to-version x.x See info ^^^^^^^^ You can see the location of a project and the name of its linked instance. ..
code-block:: bash $ gel project info ┌───────────────┬──────────────────────────────────────────┐ │ Instance name │ my_app │ │ Project root │ /path/to/my_app │ └───────────────┴──────────────────────────────────────────┘ ================================================ FILE: docs/intro/quickstart/ai/fastapi.rst ================================================ .. _ref_quickstart_ai: ====================== Using the built-in RAG ====================== .. edb:split-section:: In this section we'll learn about |Gel's| built-in vector search and retrieval-augmented generation capabilities. We'll be continuing from where we left off in the :ref:`main quickstart `. Feel free to browse the complete flashcards app code in this `repo `_. In this tutorial we'll focus on creating a ``/fetch_similar`` endpoint for looking up flashcards similar to a text search query, as well as a ``/fetch_rag`` endpoint that's going to enable us to talk to an LLM about the content of our flashcard deck. We're going to start with the same schema we left off with in the primary quickstart. .. code-block:: sdl :caption: dbschema/default.gel module default { abstract type Timestamped { required created_at: datetime { default := datetime_of_statement(); }; required updated_at: datetime { default := datetime_of_statement(); }; } type Deck extending Timestamped { required name: str; description: str; multi cards: Card { constraint exclusive; on target delete allow; }; }; type Card extending Timestamped { required order: int64; required front: str; required back: str; } } .. edb:split-section:: AI-related features in |Gel| come packaged in the extension called ``ai``. Let's enable it by adding the following line on top of the :dotgel:`dbschema/default` and running a migration. This does a few things. First, it enables us to use features from the extension by prefixing them with ``ext::ai::``. ..
code-block:: sdl-diff :caption: dbschema/default.gel + using extension ai; module default { abstract type Timestamped { required created_at: datetime { default := datetime_of_statement(); }; required updated_at: datetime { default := datetime_of_statement(); }; } type Deck extending Timestamped { required name: str; description: str; multi cards: Card { constraint exclusive; on target delete allow; }; }; type Card extending Timestamped { required order: int64; required front: str; required back: str; } } .. edb:split-section:: This enabled us to use features in the ``ext::ai::`` namespace. Here's a notable one: ``ProviderConfig``, which we can use to configure our API keys. |Gel| supports a variety of external APIs for creating embedding vectors for text and fetching LLM completions. Let's configure an API key for OpenAI by running the following query in the REPL: .. note:: Once the extension is active, we can also access the dedicated AI tab in the UI. There we can manage provider configurations and try out different RAG configurations in the Playground. .. code-block:: edgeql-repl db> configure current database insert ext::ai::OpenAIProviderConfig { secret := 'sk-....', }; .. edb:split-section:: One last thing before we move on. Let's add some sample data to give the embedding model something to work with. You can copy and run this command in the terminal, or come up with your own sample data. .. code-block:: edgeql :class: collapsible $ cat << 'EOF' | gel query --file - with deck := ( insert Deck { name := 'Smelly Cheeses', description := 'To impress everyone with stinky cheese trivia.' } ) for card_data in {( 1, 'Époisses de Bourgogne', 'Known as the "king of cheeses", this French cheese is so pungent it\'s banned on public transport in France. Washed in brandy, it becomes increasingly funky as it ages. Orange-red rind, creamy interior.' ), ( 2, 'Vieux-Boulogne', 'Officially the smelliest cheese in the world according to scientific studies.
This northern French cheese has a reddish-orange rind from being washed in beer. Smooth, creamy texture with a powerful aroma.' ), ( 3, 'Durian Cheese', 'This Malaysian creation combines durian fruit with cheese, creating what some consider the ultimate "challenging" dairy product. Combines the pungency of blue cheese with durian\'s notorious aroma.' ), ( 4, 'Limburger', 'German cheese famous for its intense smell, often compared to foot odor due to the same bacteria. Despite its reputation, has a surprisingly mild taste with notes of mushroom and grass.' ), ( 5, 'Roquefort', 'The "king of blue cheeses", aged in limestone caves in southern France. Contains Penicillium roqueforti mold. Strong, tangy, and salty with a crumbly texture. Legend says it was discovered when a shepherd left his lunch in a cave.' ), ( 6, 'What makes washed-rind cheeses so smelly?', 'The process of washing cheese rinds in brine, alcohol, or other solutions promotes the growth of Brevibacterium linens, the same bacteria responsible for human body odor. This bacteria contributes to both the orange color and distinctive aroma.' ), ( 7, 'Stinking Bishop', 'Named after the Stinking Bishop pear (not a religious figure). This English cheese is washed in perry made from these pears. Known for its powerful aroma and sticky, pink-orange rind. Gained fame after being featured in Wallace & Gromit.' )} union ( insert Card { deck := deck, order := card_data.0, front := card_data.1, back := card_data.2 } ); EOF .. edb:split-section:: Now we can finally start producing embedding vectors. Since |Gel| is fully aware of when your data gets inserted, updated and deleted, it's perfectly equipped to handle all the tedious work of keeping those vectors up to date. All that's left for us is to create a special ``deferred index`` on the data we would like to perform similarity search on. .. 
code-block:: sdl-diff :caption: dbschema/default.gel using extension ai; module default { abstract type Timestamped { required created_at: datetime { default := datetime_of_statement(); }; required updated_at: datetime { default := datetime_of_statement(); }; } type Deck extending Timestamped { required name: str; description: str; multi cards: Card { constraint exclusive; on target delete allow; }; }; type Card extending Timestamped { required order: int64; required front: str; required back: str; + deferred index ext::ai::index(embedding_model := 'text-embedding-3-small') + on (.front ++ ' ' ++ .back); } } .. edb:split-section:: It's time to start running queries. Let's begin by creating the ``/fetch_similar`` endpoint we mentioned earlier. Its job is going to be to find 3 flashcards that are the most similar to the provided text query. We can use this endpoint to implement a "recommended flashcards" feature on the frontend. The AI extension contains a function called ``ext::ai::search(Type, embedding_vector)`` that we can use to do our fetch. Note that the second argument is an embedding vector, not a text query. To transform our text query into a vector, we will use the ``generate_embeddings`` function from the ``ai`` module of |Gel|'s Python binding. Gathered together, here are the modifications we need to make to the ``main.py`` file: .. code-block:: python-diff :caption: main.py import gel + import gel.ai from fastapi import FastAPI client = gel.create_async_client() app = FastAPI() + @app.get("/fetch_similar") + async def fetch_similar_cards(query: str): + rag = await gel.ai.create_async_rag_client(client, model="gpt-4-turbo-preview") + embedding_vector = await rag.generate_embeddings( + query, model="text-embedding-3-small" + ) + similar_cards = await client.query( + "select ext::ai::search(Card, <array<float32>>$embedding_vector)", + embedding_vector=embedding_vector, + ) + return similar_cards ..
edb:split-section:: Let's test the endpoint to see that everything works the way we expect. .. code-block:: bash $ curl -X 'GET' \ 'http://localhost:8000/fetch_similar?query=the%20stinkiest%20cheese' \ -H 'accept: application/json' .. edb:split-section:: Finally, let's create the second endpoint we mentioned, called ``/fetch_rag``. We'll be able to use this one to, for example, ask an LLM to quiz us on the contents of our deck. The RAG feature is represented in the Python binding with the ``query_rag`` method of the ``GelRAG`` class. To use it, we're going to instantiate the class and call the method... And that's it! .. code-block:: python-diff :caption: main.py import gel import gel.ai from fastapi import FastAPI client = gel.create_async_client() app = FastAPI() @app.get("/fetch_similar") async def fetch_similar_cards(query: str): rag = await gel.ai.create_async_rag_client(client, model="gpt-4-turbo-preview") embedding_vector = await rag.generate_embeddings( query, model="text-embedding-3-small" ) similar_cards = await client.query( "select ext::ai::search(Card, <array<float32>>$embedding_vector)", embedding_vector=embedding_vector, ) return similar_cards + @app.get("/fetch_rag") + async def fetch_rag_response(query: str): + rag = await gel.ai.create_async_rag_client(client, model="gpt-4-turbo-preview") + response = await rag.query_rag( + message=query, + context=gel.ai.QueryContext(query="select Card"), + ) + return response .. edb:split-section:: Let's test the endpoint to see if it works: .. code-block:: bash $ curl -X 'GET' \ 'http://localhost:8000/fetch_rag?query=what%20cheese%20smells%20like%20feet' \ -H 'accept: application/json' Congratulations! We've now implemented AI features in our flashcards app. Of course, there's more to learn when it comes to using the AI extension. Make sure to check out the :ref:`Reference manual `, or build an LLM-powered search bot from the ground up with the :ref:`FastAPI Gel AI tutorial `.
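For scripted checks, the two ``curl`` calls from this section can also be made from Python. The following is a minimal sketch using only the standard library. The ``BASE_URL`` value mirrors the ``curl`` examples, and the helper names ``build_url`` and ``fetch_json`` are assumptions introduced for illustration; they are not part of the |Gel| Python binding.

```python
import json
from urllib.parse import quote, urlencode
from urllib.request import urlopen

# Assumed address of the locally running FastAPI app (same as the curl examples).
BASE_URL = "http://localhost:8000"


def build_url(endpoint: str, query: str) -> str:
    """Build an endpoint URL, percent-encoding spaces as %20."""
    params = urlencode({"query": query}, quote_via=quote)
    return f"{BASE_URL}/{endpoint}?{params}"


def fetch_json(endpoint: str, query: str):
    """Send a GET request and decode the JSON body (requires the app to be running)."""
    with urlopen(build_url(endpoint, query)) as response:
        return json.load(response)


# Only builds and prints the URL, so it works without a running server.
print(build_url("fetch_similar", "the stinkiest cheese"))
```

With the app running, ``fetch_json("fetch_rag", "what cheese smells like feet")`` would issue the same request as the second ``curl`` command.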
================================================ FILE: docs/intro/quickstart/ai/index.rst ================================================ .. edb:env-switcher:: ========= Adding AI ========= .. toctree:: :maxdepth: 1 fastapi ================================================ FILE: docs/intro/quickstart/connecting/fastapi.rst ================================================ .. _ref_quickstart_fastapi_connecting: ========================== Connecting to the database ========================== .. edb:split-section:: Before diving into the application, let's take a quick look at how to connect to the database from your code. We will initialize a client and use it to make a simple, static query to the database, and log the result to the console. .. note:: Notice that the ``create_async_client`` function isn't being passed any connection details. With |Gel|, you do not need to come up with your own scheme for how to build the correct database connection credentials and worry about leaking them into your code. You simply use |Gel| "projects" for local development, and set the appropriate environment variables in your deployment environments, and the ``create_async_client`` function knows what to do! .. edb:split-point:: .. code-block:: python :caption: ./test.py import gel import asyncio async def main(): client = gel.create_async_client() result = await client.query_single("select 'Hello from Gel!';") print(result) asyncio.run(main()) .. code-block:: sh $ python test.py Hello from Gel! .. edb:split-section:: In Python, we write EdgeQL queries directly as strings. This gives us the full power and expressiveness of EdgeQL while maintaining type safety through Gel's strict schema. Let's try inserting a few ``Deck`` objects into the database and then selecting them back. .. edb:split-point:: ..
code-block:: python-diff :caption: ./test.py import gel import asyncio async def main(): client = gel.create_async_client() - result = await client.query_single("select 'Hello from Gel!';") - print(result) + await client.query(""" + insert Deck { name := "I am one" } + """) + + await client.query(""" + insert Deck { name := "I am two" } + """) + + decks = await client.query(""" + select Deck { + id, + name + } + """) + + for deck in decks: + print(f"ID: {deck.id}, Name: {deck.name}") + + await client.query("delete Deck") asyncio.run(main()) .. code-block:: sh $ python test.py ID: f4cd3e6c-ea75-11ef-83ec-037350ea8a6e, Name: I am one ID: f4cf27ae-ea75-11ef-83ec-3f7b2fceab24, Name: I am two ================================================ FILE: docs/intro/quickstart/connecting/index.rst ================================================ .. edb:env-switcher:: ========================== Connecting to the database ========================== .. toctree:: :maxdepth: 3 :hidden: nextjs fastapi ================================================ FILE: docs/intro/quickstart/connecting/nextjs.rst ================================================ .. _ref_quickstart_connecting: ========================== Connecting to the database ========================== .. edb:split-section:: Before diving into the application, let's take a quick look at how to connect to the database from your code. We will initialize a client and use it to make a simple, static query to the database, and log the result to the console. .. note:: Notice that the ``createClient`` function isn't being passed any connection details. With |Gel|, you do not need to come up with your own scheme for how to build the correct database connection credentials and worry about leaking them into your code. You simply use |Gel| "projects" for local development, and set the appropriate environment variables in your deployment environments, and the ``createClient`` function knows what to do! .. edb:split-point:: ..
code-block:: typescript :caption: ./test.ts import { createClient } from "gel"; const client = createClient(); async function main() { console.log(await client.query("select 'Hello from Gel!';")); } main().then( () => process.exit(0), (err) => { console.error(err); process.exit(1); } ); .. code-block:: sh $ npx tsx test.ts [ 'Hello from Gel!' ] .. edb:split-section:: With TypeScript, there are three ways to run a query: use a string EdgeQL query, use the ``queries`` generator to turn a string of EdgeQL into a TypeScript function, or use the query builder API to build queries dynamically in a type-safe manner. In this tutorial, you will use the TypeScript query builder API. This query builder must be regenerated any time the schema changes, so a hook has been added to the ``gel.toml`` file to generate the query builder any time the schema is updated. Moving beyond this simple query, use the query builder API to insert a few ``Deck`` objects into the database, and then select them back. .. edb:split-point:: .. code-block:: typescript-diff :caption: ./test.ts import { createClient } from "gel"; + import e from "@/dbschema/edgeql-js"; const client = createClient(); async function main() { console.log(await client.query("select 'Hello from Gel!';")); + await e.insert(e.Deck, { name: "I am one" }).run(client); + + await e.insert(e.Deck, { name: "I am two" }).run(client); + + const decks = await e + .select(e.Deck, () => ({ + id: true, + name: true, + })) + .run(client); + + console.table(decks); + + await e.delete(e.Deck).run(client); } main().then( () => process.exit(0), (err) => { console.error(err); process.exit(1); } ); .. code-block:: sh $ npx tsx test.ts [ 'Hello from Gel!' 
] ┌─────────┬────────────────────────────────────────┬────────────┐ │ (index) │ id │ name │ ├─────────┼────────────────────────────────────────┼────────────┤ │ 0 │ 'f4cd3e6c-ea75-11ef-83ec-037350ea8a6e' │ 'I am one' │ │ 1 │ 'f4cf27ae-ea75-11ef-83ec-3f7b2fceab24' │ 'I am two' │ └─────────┴────────────────────────────────────────┴────────────┘ Now that you know how to connect to the database, you will see that we have provided an initialized ``Client`` object in the ``/lib/gel.ts`` module. Throughout the rest of the tutorial, you will import this ``Client`` object and use it to make queries. ================================================ FILE: docs/intro/quickstart/index.rst ================================================ ========== Quickstart ========== .. toctree:: :maxdepth: 1 :hidden: overview/index setup/index modeling/index connecting/index working/index inheritance/index ai/index ================================================ FILE: docs/intro/quickstart/inheritance/fastapi.rst ================================================ .. _ref_quickstart_fastapi_inheritance: ======================== Adding shared properties ======================== .. edb:split-section:: One common pattern in applications is to add shared properties to the schema that are used by multiple objects. For example, you might want to add a ``created_at`` and ``updated_at`` property to every object in your schema. You can do this by adding an abstract type and using it as a mixin for your other object types. .. code-block:: sdl-diff :caption: dbschema/default.gel module default { + abstract type Timestamped { + required created_at: datetime { + default := datetime_of_statement(); + }; + required updated_at: datetime { + default := datetime_of_statement(); + }; + } + - type Deck { + type Deck extending Timestamped { required name: str; description: str; cards := ( select . y did you alter object type 'default::Card'? [y,n,l,c,b,s,q,?] > y did you alter object type 'default::Deck'? 
[y,n,l,c,b,s,q,?] > y Created /home/strinh/projects/flashcards/dbschema/migrations/00004-m1d2m5n.edgeql, id: m1d2m5n5ajkalyijrxdliioyginonqbtfzihvwdfdmfwodunszstya $ gel migrate Applying m1d2m5n5ajkalyijrxdliioyginonqbtfzihvwdfdmfwodunszstya (00004-m1d2m5n.edgeql) ... parsed ... applied .. edb:split-section:: Update the ``get_decks`` query to sort the decks by ``updated_at`` in descending order. .. code-block:: python-diff :caption: main.py @app.get("/decks", response_model=List[Deck]) async def get_decks(): decks = await client.query(""" select Deck { id, name, description, cards := ( select .cards { id, front, back } order by .order ) } + order by .updated_at desc """) return decks ================================================ FILE: docs/intro/quickstart/inheritance/index.rst ================================================ .. edb:env-switcher:: ======================== Adding shared properties ======================== .. toctree:: :maxdepth: 3 :hidden: nextjs fastapi ================================================ FILE: docs/intro/quickstart/inheritance/nextjs.rst ================================================ .. _ref_quickstart_inheritance: ======================== Adding shared properties ======================== .. edb:split-section:: One common pattern in applications is to add shared properties to the schema that are used by multiple objects. For example, you might want to add a ``created_at`` and ``updated_at`` property to every object in your schema. You can do this by adding an abstract type and using it as a mixin for your other object types. .. 
code-block:: sdl-diff :caption: dbschema/default.gel module default { + abstract type Timestamped { + required created_at: datetime { + default := datetime_of_statement(); + }; + required updated_at: datetime { + default := datetime_of_statement(); + }; + } + - type Deck { + type Deck extending Timestamped { required name: str; description: str; multi cards: Card { constraint exclusive; on target delete allow; }; }; - type Card { + type Card extending Timestamped { required order: int64; required front: str; required back: str; } } .. edb:split-section:: Since you don't have historical data for when these objects were actually created or modified, the migration will fall back to the default values set in the ``Timestamped`` type. .. code-block:: sh $ npx gel migration create did you create object type 'default::Timestamped'? [y,n,l,c,b,s,q,?] > y did you alter object type 'default::Card'? [y,n,l,c,b,s,q,?] > y did you alter object type 'default::Deck'? [y,n,l,c,b,s,q,?] > y Created /home/strinh/projects/flashcards/dbschema/migrations/00004-m1d2m5n.edgeql, id: m1d2m5n5ajkalyijrxdliioyginonqbtfzihvwdfdmfwodunszstya $ npx gel migrate Applying m1d2m5n5ajkalyijrxdliioyginonqbtfzihvwdfdmfwodunszstya (00004-m1d2m5n.edgeql) ... parsed ... applied Generating query builder... Detected tsconfig.json, generating TypeScript files. To override this, use the --target flag. Run `npx @gel/generate --help` for full options. Introspecting database schema... Generating runtime spec... Generating cast maps... Generating scalars... Generating object types... Generating function types... Generating operators... Generating set impl... Generating globals... Generating index... Writing files to ./dbschema/edgeql-js Generation complete! 🤘 .. edb:split-section:: Update the ``getDecks`` query to sort the decks by ``updated_at`` in descending order. .. 
code-block:: typescript-diff :caption: app/queries.ts import { client } from "@/lib/gel"; import e from "@/dbschema/edgeql-js"; export async function getDecks() { const decks = await e.select(e.Deck, (deck) => ({ id: true, name: true, description: true, cards: e.select(deck.cards, (card) => ({ id: true, front: true, back: true, order_by: card.order, })), + order_by: { + expression: deck.updated_at, + direction: e.DESC, + }, })).run(client); return decks; } .. edb:split-section:: Now when you look at the data in the UI, you will see the new properties on each of your object types. .. image:: images/timestamped.png ================================================ FILE: docs/intro/quickstart/modeling/fastapi.rst ================================================ .. _ref_quickstart_fastapi_modeling: ================= Modeling the data ================= .. edb:split-section:: The flashcards application has a simple data model, but it's interesting enough to utilize many unique features of the |Gel| schema language. Looking at the mock data in the example JSON file ``./deck-edgeql.json``, you can see this structure in the JSON. There is a ``Card`` class that describes a single flashcard, which contains two required string properties: ``front`` and ``back``. Each ``Deck`` object has zero or more ``Card`` objects in a list. .. code-block:: python from typing import List, Optional from pydantic import BaseModel class CardBase(BaseModel): front: str back: str class Card(CardBase): id: str class DeckBase(BaseModel): name: str description: Optional[str] = None class Deck(DeckBase): id: str cards: List[Card] .. edb:split-section:: Starting with this simple model, add these types to the :dotgel:`dbschema/default` schema file. As you can see, the types closely mirror the JSON mock data. Also of note, the link between ``Card`` and ``Deck`` objects creates a "1-to-n" relationship, where each ``Deck`` object has a link to zero or more ``Card`` objects.
When you query the ``Deck.cards`` link, the cards will be unordered, so the ``Card`` type needs an explicit ``order`` property to allow sorting them at query time. By default, when you try to delete an object that is linked to another object, the database will prevent you from doing so. We want to support removing a ``Card``, so we define a deletion policy on the ``cards`` link that allows deleting the target of this link. .. code-block:: sdl-diff :caption: dbschema/default.gel module default { + type Card { + required order: int64; + required front: str; + required back: str; + }; + + type Deck { + required name: str; + description: str; + multi cards: Card { + constraint exclusive; + on target delete allow; + }; + }; }; .. edb:split-section:: Congratulations! This first version of the data model's schema is *stored in a file on disk*. Now you need to signal the database to actually create types for ``Deck`` and ``Card`` in the database. To make |Gel| do that, you need to do two quick steps: 1. **Create a migration**: a "migration" is a file containing a set of low level instructions that define how the database schema should change. It records any additions, modifications, or deletions to your schema in a way that the database can understand. .. note:: When you are changing existing schema, the CLI migration tool might ask questions to ensure that it understands your changes exactly. Since the existing schema was empty, the CLI will skip asking any questions and simply create the migration file. 2. **Apply the migration**: This executes the migration file on the database, instructing |Gel| to implement the recorded changes in the database. Essentially, this step updates the database structure to match your defined schema, ensuring that the ``Deck`` and ``Card`` types are created and ready for use. .. 
code-block:: sh $ uvx gel migration create Created ./dbschema/migrations/00001-m125ajr.edgeql, id: m125ajrbqp7ov36s7aniefxc376ofxdlketzspy4yddd3hrh4lxmla $ uvx gel migrate Applying m125ajrbqp7ov36s7aniefxc376ofxdlketzspy4yddd3hrh4lxmla (00001-m125ajr.edgeql) ... parsed ... applied .. edb:split-section:: Take a look at the schema you've generated in the built-in database UI. Use this tool to visualize your data model and see the object types and links you've defined. .. edb:split-point:: .. code-block:: sh $ uvx gel ui .. image:: images/schema-ui.png ================================================ FILE: docs/intro/quickstart/modeling/index.rst ================================================ .. edb:env-switcher:: ================= Modeling the data ================= .. toctree:: :maxdepth: 3 :hidden: nextjs fastapi ================================================ FILE: docs/intro/quickstart/modeling/nextjs.rst ================================================ .. _ref_quickstart_modeling: ================= Modeling the data ================= .. edb:split-section:: The flashcards application has a simple data model, but it's interesting enough to utilize many unique features of the |Gel| schema language. Looking at the mock data in the example JSON file ``./deck-edgeql.json``, you can see this structure in the JSON. There is a ``Card`` type that describes a single flashcard, which contains two required string properties: ``front`` and ``back``. Each ``Deck`` object has zero or more ``Card`` objects in an array. .. code-block:: typescript interface Card { front: string; back: string; } interface Deck { name: string; description: string | null; cards: Card[]; } .. edb:split-section:: Starting with this simple model, add these types to the :dotgel:`dbschema/default` schema file. As you can see, the types closely mirror the JSON mock data. 
Also of note, the link between ``Card`` and ``Deck`` objects creates a "1-to-n" relationship, where each ``Deck`` object has a link to zero or more ``Card`` objects. When you query the ``Deck.cards`` link, the cards will be unordered, so the ``Card`` type needs an explicit ``order`` property to allow sorting them at query time. By default, when you try to delete an object that is linked to another object, the database will prevent you from doing so. We want to support removing a ``Card``, so we define a deletion policy on the ``cards`` link that allows deleting the target of this link. .. code-block:: sdl-diff :caption: dbschema/default.gel module default { + type Card { + required order: int64; + required front: str; + required back: str; + }; + + type Deck { + required name: str; + description: str; + multi cards: Card { + constraint exclusive; + on target delete allow; + }; + }; }; .. edb:split-section:: Congratulations! This first version of the data model's schema is *stored in a file on disk*. Now you need to signal the database to actually create types for ``Deck`` and ``Card`` in the database. To make |Gel| do that, you need to do two quick steps: 1. **Create a migration**: a "migration" is a file containing a set of low level instructions that define how the database schema should change. It records any additions, modifications, or deletions to your schema in a way that the database can understand. .. note:: When you are changing existing schema, the CLI migration tool might ask questions to ensure that it understands your changes exactly. Since the existing schema was empty, the CLI will skip asking any questions and simply create the migration file. 2. **Apply the migration**: This executes the migration file on the database, instructing |Gel| to implement the recorded changes in the database. Essentially, this step updates the database structure to match your defined schema, ensuring that the ``Deck`` and ``Card`` types are created and ready for use. .. 
note:: Notice that after the migration is applied, the CLI will automatically run the script to generate the query builder. This is a convenience feature that is enabled by the ``schema.update.after`` hook in the ``gel.toml`` file. .. code-block:: sh $ npx gel migration create Created ./dbschema/migrations/00001-m125ajr.edgeql, id: m125ajrbqp7ov36s7aniefxc376ofxdlketzspy4yddd3hrh4lxmla $ npx gel migrate Applying m125ajrbqp7ov36s7aniefxc376ofxdlketzspy4yddd3hrh4lxmla (00001-m125ajr.edgeql) ... parsed ... applied Generating query builder... Detected tsconfig.json, generating TypeScript files. To override this, use the --target flag. Run `npx @gel/generate --help` for full options. Introspecting database schema... Generating runtime spec... Generating cast maps... Generating scalars... Generating object types... Generating function types... Generating operators... Generating set impl... Generating globals... Generating index... Writing files to ./dbschema/edgeql-js Generation complete! 🤘 .. edb:split-section:: Take a look at the schema you've generated in the built-in database UI. Use this tool to visualize your data model and see the object types and links you've defined. .. edb:split-point:: .. code-block:: sh $ npx gel ui .. image:: images/schema-ui.png ================================================ FILE: docs/intro/quickstart/overview/fastapi.rst ================================================ .. _ref_quickstart_fastapi: ========== Quickstart ========== Welcome to the quickstart tutorial! In this tutorial, you will update a FastAPI backend for a Flashcards application to use |Gel| as your data layer. The application will let users build and manage their own study decks, with each flashcard featuring customizable text on both sides - making it perfect for studying, memorization practice, or creating educational games. 
Don't worry if you're new to |Gel| - you will be up and running with a working FastAPI backend and a local |Gel| database in just about **5 minutes**. From there, you will replace the static mock data with a |Gel| powered data layer in roughly 30-45 minutes. By the end of this tutorial, you will be comfortable with: * Creating and updating a database schema * Running migrations to evolve your data * Writing EdgeQL queries * Building an app backed by |Gel| Features of the flashcards app ------------------------------ * Create, edit, and delete decks * Add/remove cards with front/back content * Clean, type-safe schema with |Gel| Requirements ------------ Before you start, you need: * Basic familiarity with Python and FastAPI * Python 3.8+ on a Unix-like OS (Linux, macOS, or WSL) * A code editor you love Why |Gel| for FastAPI? ---------------------- * **Type Safety**: Catch data errors before runtime * **Rich Modeling**: Use object types and links to model relations * **Modern Tooling**: Python-friendly schemas and migrations * **Performance**: Efficient queries for complex data * **Developer Experience**: An intuitive query language (EdgeQL) Need Help? ---------- If you run into issues while following this tutorial: - Check the `Gel documentation `_ - Visit our `community Discord `_ - File an issue on `GitHub `_ ================================================ FILE: docs/intro/quickstart/overview/index.rst ================================================ .. edb:env-switcher:: ======== Overview ======== .. toctree:: :maxdepth: 3 :hidden: nextjs fastapi ================================================ FILE: docs/intro/quickstart/overview/nextjs.rst ================================================ .. _gel-js-quickstart: .. _ref_quickstart: ========== Quickstart ========== Welcome to the quickstart tutorial! In this tutorial, you will update a simple Next.js application to use |Gel| as your data layer. 
The application will let users build and manage their own study decks, with each flashcard featuring customizable text on both sides - making it perfect for studying, memorization practice, or creating educational games. Don't worry if you're new to |Gel| - you will be up and running with a working Next.js application and a local |Gel| database in just about **5 minutes**. From there, you will replace the static mock data with a |Gel| powered data layer in roughly 30-45 minutes. By the end of this tutorial, you will be comfortable with: * Creating and updating a database schema * Running migrations to evolve your data * Writing EdgeQL queries in text and via a TypeScript query builder * Building an app backed by |Gel| Features of the flashcards app ------------------------------ * Create, edit, and delete decks * Add/remove cards with front/back content * Simple Next.js + Tailwind UI * Clean, type-safe schema with |Gel| Requirements ------------ Before you start, you need: * Basic familiarity with TypeScript, Next.js, and React * Node.js 20+ on a Unix-like OS (Linux, macOS, or WSL) * A code editor you love Why |Gel| for Next.js? ---------------------- * **Type Safety**: Catch data errors before runtime * **Rich Modeling**: Use object types and links to model relations * **Modern Tooling**: TypeScript-friendly schemas and migrations * **Performance**: Efficient queries for complex data * **Developer Experience**: An intuitive query language (EdgeQL) Need Help? ---------- If you run into issues while following this tutorial: * Check the `Gel documentation `_ * Visit our `community Discord `_ * File an issue on `GitHub `_ ================================================ FILE: docs/intro/quickstart/setup/fastapi.rst ================================================ .. _ref_quickstart_fastapi_setup: =========================== Setting up your environment =========================== .. 
edb:split-section:: Use git to clone the `FastAPI starter template `_ into a new directory called ``flashcards``. This will create a fully configured FastAPI project and a local |Gel| instance with an empty schema. You will see the database instance being created and the project being initialized. You are now ready to start building the application. .. code-block:: sh $ git clone \ git@github.com:geldata/quickstart-fastapi.git \ flashcards $ cd flashcards $ python -m venv venv $ source venv/bin/activate # or venv\Scripts\activate on Windows $ pip install -r requirements.txt $ uvx gel project init .. edb:split-section:: Explore the empty database by starting our REPL from the project root. .. code-block:: sh $ uvx gel .. edb:split-section:: Try the following queries which will work without any schema defined. .. code-block:: edgeql-repl db> select 42; {42} db> select sum({1, 2, 3}); {6} db> with cards := { ... ( ... front := "What is the highest mountain in the world?", ... back := "Mount Everest", ... ), ... ( ... front := "Which ocean contains the deepest trench on Earth?", ... back := "The Pacific Ocean", ... ), ... } ... select cards order by random() limit 1; { ( front := "What is the highest mountain in the world?", back := "Mount Everest", ) } .. edb:split-section:: Fun! You will create a proper data model for the application in the next step, but for now, take a look around the project we have. Here are the files that integrate |Gel|: - ``gel.toml``: The configuration file for the |Gel| project instance. Notice that we have a ``schema.update.after`` hook (under ``[hooks]``) that will run ``uvx gel-py`` after migrations are applied. This will run the code generator that you will use later to get fully type-safe queries you can run from your FastAPI backend. More details on that to come! - ``dbschema/``: This directory contains the schema for the database, and later supporting files like migrations, and generated code.
- :dotgel:`dbschema/default`: The default schema file that you'll use to define your data model. It is empty for now, but you'll add your data model to this file in the next step. .. tabs:: .. code-tab:: toml :caption: gel.toml [instance] server-version = "6.11" [hooks] schema.update.after = "uvx gel-py" .. code-tab:: sdl :caption: dbschema/default.gel module default { } ================================================ FILE: docs/intro/quickstart/setup/index.rst ================================================ .. edb:env-switcher:: =========================== Setting up your environment =========================== .. toctree:: :maxdepth: 3 :hidden: nextjs fastapi ================================================ FILE: docs/intro/quickstart/setup/nextjs.rst ================================================ .. _ref_quickstart_setup: =========================== Setting up your environment =========================== .. edb:split-section:: Use git to clone `the Next.js starter template `_ into a new directory called ``flashcards``. This will create a fully configured Next.js project and a local |Gel| instance with an empty schema. You will see the database instance being created and the project being initialized. You are now ready to start building the application. .. code-block:: sh $ git clone \ git@github.com:geldata/quickstart-nextjs.git \ flashcards $ cd flashcards $ npm install $ npx gel project init .. edb:split-section:: Explore the empty database by starting our REPL from the project root. .. code-block:: sh $ npx gel .. edb:split-section:: Try the following queries which will work without any schema defined. .. code-block:: edgeql-repl db> select 42; {42} db> select sum({1, 2, 3}); {6} db> with cards := { ... ( ... front := "What is the highest mountain in the world?", ... back := "Mount Everest", ... ), ... ( ... front := "Which ocean contains the deepest trench on Earth?", ... back := "The Pacific Ocean", ... ), ... } ... 
select cards order by random() limit 1; { ( front := "What is the highest mountain in the world?", back := "Mount Everest", ) } .. edb:split-section:: Fun! You will create a proper data model for the application in the next step, but for now, take a look around the project you've just created. Most of the project files will be familiar if you've worked with Next.js before. Here are the files that integrate |Gel|: - ``gel.toml``: The configuration file for the |Gel| project instance. Notice that we have a ``schema.update.after`` hook (under ``[hooks]``) that will run ``npx @gel/generate edgeql-js`` after migrations are applied. This will generate the query builder code that you'll use to interact with the database. More details on that to come! - ``dbschema/``: This directory contains the schema for the database, and later supporting files like migrations, and generated code. - :dotgel:`dbschema/default`: The default schema file that you'll use to define your data model. It is empty for now, but you'll add your data model to this file in the next step. - ``lib/gel.ts``: A utility module that exports the |Gel| client, which you'll use to interact with the database. .. tabs:: .. code-tab:: toml :caption: gel.toml [instance] server-version = "6.11" [hooks] schema.update.after = "npx @gel/generate edgeql-js" .. code-tab:: sdl :caption: dbschema/default.gel module default { } .. code-tab:: typescript :caption: lib/gel.ts import { createClient } from "gel"; export const client = createClient(); ================================================ FILE: docs/intro/quickstart/working/fastapi.rst ================================================ .. _ref_quickstart_fastapi_working: ===================== Working with the data ===================== In this section, you will update the existing FastAPI application to use |Gel| to store and query data, instead of a JSON file.
Having a working application with mock data allows you to focus on learning how |Gel| works, without getting bogged down by the details of the application. Bulk importing of data ====================== .. edb:split-section:: First, update the imports and Pydantic models to use UUID instead of string for ID fields, since this is what |Gel| returns. You also need to initialize the |Gel| client and import the asyncio module to work with async functions. .. code-block:: python-diff :caption: main.py from fastapi import FastAPI, HTTPException from pydantic import BaseModel from typing import List, Optional - import json - from pathlib import Path + from uuid import UUID + from gel import create_async_client + import asyncio app = FastAPI(title="Flashcards API") # Pydantic models class CardBase(BaseModel): front: str back: str class Card(CardBase): - id: str + id: UUID class DeckBase(BaseModel): name: str description: Optional[str] = None class DeckCreate(DeckBase): cards: List[CardBase] class Deck(DeckBase): - id: str + id: UUID cards: List[Card] - DATA_DIR = Path(__file__).parent / "data" - DECKS_FILE = DATA_DIR / "decks.json" + client = create_async_client() .. edb:split-section:: Next, update the deck import operation to use |Gel| to create the deck and cards. The operation creates cards first, then creates a deck with links to the cards. Finally, it fetches the newly created deck with all required fields. .. note:: Notice the ``{ ** }`` in the query. This is a shorthand for selecting all fields of the object. It's useful when you want to return the entire object without specifying each field. In our case, we want to return the entire deck object with all the nested fields. .. 
code-block:: python-diff :caption: main.py from fastapi import FastAPI, HTTPException from pydantic import BaseModel from typing import List, Optional from uuid import UUID from gel import create_async_client import asyncio app = FastAPI(title="Flashcards API") # Pydantic models class CardBase(BaseModel): front: str back: str class Card(CardBase): id: UUID class DeckBase(BaseModel): name: str description: Optional[str] = None class DeckCreate(DeckBase): cards: List[CardBase] class Deck(DeckBase): id: UUID cards: List[Card] client = create_async_client() - DATA_DIR.mkdir(exist_ok=True) - if not DECKS_FILE.exists(): - DECKS_FILE.write_text("[]") - def read_decks() -> List[Deck]: - content = DECKS_FILE.read_text() - data = json.loads(content) - return [Deck(**deck) for deck in data] - - def write_decks(decks: List[Deck]) -> None: - data = [deck.model_dump() for deck in decks] - DECKS_FILE.write_text(json.dumps(data, indent=2)) @app.post("/decks/import", response_model=Deck) async def import_deck(deck: DeckCreate): - decks = read_decks() - new_deck = Deck( - id=str(uuid.uuid4()), - name=deck.name, - description=deck.description, - cards=[Card(id=str(uuid.uuid4()), **card.model_dump()) - for card in deck.cards] - ) - decks.append(new_deck) - write_decks(decks) - return new_deck + card_ids = [] + for i, card in enumerate(deck.cards): + created_card = await client.query_single(""" + insert Card { + front := <str>$front, + back := <str>$back, + order := <int64>$order + } + """, front=card.front, back=card.back, order=i) + card_ids.append(created_card.id) + + new_deck = await client.query_single(""" + select( + insert Deck { + name := <str>$name, + description := <optional str>$description, + cards := ( + select Card + filter contains(<array<uuid>>$card_ids, .id) + ) + } + ) { ** } + """, name=deck.name, description=deck.description, + card_ids=card_ids) + + return new_deck .. edb:split-section:: The above works but isn't atomic - if any single query fails, you could end up with partial data. Let's wrap it in a transaction: ..
code-block:: python-diff :caption: main.py @app.post("/decks/import", response_model=Deck) async def import_deck(deck: DeckCreate): + async for tx in client.transaction(): + async with tx: card_ids = [] for i, card in enumerate(deck.cards): - created_card = await client.query_single( + created_card = await tx.query_single( """ insert Card { front := <str>$front, back := <str>$back, order := <int64>$order } """, front=card.front, back=card.back, order=i, ) card_ids.append(created_card.id) - new_deck = await client.query_single(""" + new_deck = await tx.query_single(""" select( insert Deck { name := <str>$name, description := <optional str>$description, cards := ( select Card filter .id IN array_unpack(<array<uuid>>$card_ids) ) } ) { ** } """, name=deck.name, description=deck.description, card_ids=card_ids, ) return new_deck .. edb:split-section:: One of the most powerful features of EdgeQL is the ability to compose complex queries in a way that is both readable and efficient. Use this super-power to create a single query that inserts the deck and cards, along with their links, in one efficient query. This new query uses a ``for`` expression to iterate over the set of cards, and sets the ``Deck.cards`` link to the result of inserting each card. This is logically equivalent to the previous approach, but is more efficient since it inserts the deck and cards in a single query. ..
code-block:: python-diff :caption: main.py @app.post("/decks/import", response_model=Deck) async def import_deck(deck: DeckCreate): - async for tx in client.transaction(): - async with tx: - card_ids = [] - for i, card in enumerate(deck.cards): - created_card = await tx.query_single( - """ - insert Card { - front := <str>$front, - back := <str>$back, - order := <int64>$order - } - """, - front=card.front, - back=card.back, - order=i, - ) - card_ids.append(created_card.id) - - new_deck = await tx.query_single(""" - select( - insert Deck { - name := <str>$name, - description := <optional str>$description, - cards := ( - select Card - filter .id IN array_unpack(<array<uuid>>$card_ids) - ) - } - ) { ** } - """, - name=deck.name, - description=deck.description, - card_ids=card_ids, - ) + cards_data = [(c.front, c.back, i) for i, c in enumerate(deck.cards)] + + new_deck = await client.query_single(""" + select( + with cards := <array<tuple<str, str, int64>>>$cards_data + insert Deck { + name := <str>$name, + description := <optional str>$description, + cards := ( + for card in array_unpack(cards) + insert Card { + front := card.0, + back := card.1, + order := card.2 + } + ) + } + ) { ** } + """, name=deck.name, description=deck.description, + cards_data=cards_data) return new_deck Updating data ============= .. edb:split-section:: Next, update the deck operations. The update operation needs to handle partial updates of name and description: ..
code-block:: python-diff :caption: main.py @app.put("/decks/{deck_id}", response_model=Deck) async def update_deck(deck_id: UUID, deck_update: DeckBase): - decks = read_decks() - deck = next((deck for deck in decks if deck.id == deck_id), None) - if not deck: - raise HTTPException(status_code=404, detail="Deck not found") - - deck.name = deck_update.name - deck.description = deck_update.description - write_decks(decks) - return deck + # Build update sets based on provided fields + sets = [] + params = {"id": deck_id} + + if deck_update.name is not None: + sets.append("name := <str>$name") + params["name"] = deck_update.name + + if deck_update.description is not None: + sets.append("description := <str>$description") + params["description"] = deck_update.description + + if not sets: + return await get_deck(deck_id) + + updated_deck = await client.query_single(f""" + with updated := ( + update Deck + filter .id = <uuid>$id + set {{ {', '.join(sets)} }} + ) + select updated {{ ** }} + """, **params) + + if not updated_deck: + raise HTTPException(status_code=404, detail="Deck not found") + + return updated_deck Adding linked data ================== .. edb:split-section:: Now, update the add card operation to use |Gel|. This operation will insert a new ``Card`` object and update the ``Deck.cards`` set to include the new ``Card`` object. Notice that the ``order`` property is set by selecting the maximum ``order`` property of this ``Deck.cards`` set and incrementing it by 1. The EdgeQL syntax for adding an object to a set of links is the ``+=`` operator, as in ``cards += new_card``. You can think of this as a shortcut for setting the link set to the current set plus the new object. ..
code-block:: python-diff :caption: main.py @app.post("/decks/{deck_id}/cards", response_model=Card) async def add_card(deck_id: UUID, card: CardBase): - decks = read_decks() - deck = next((deck for deck in decks if deck.id == deck_id), None) - if not deck: - raise HTTPException(status_code=404, detail="Deck not found") - - new_card = Card(id=str(uuid.uuid4()), **card.model_dump()) - deck.cards.append(new_card) - write_decks(decks) - return new_card + new_card = await client.query_single( + """ + with + deck := (select Deck filter .id = <uuid>$id), + order := (max(deck.cards.order) + 1), + new_card := ( + insert Card { + front := <str>$front, + back := <str>$back, + order := order, + } + ), + updated := ( + update deck + set { + cards += new_card + } + ), + select new_card { ** } + """, + id=deck_id, + front=card.front, + back=card.back, + ) + + if not new_card: + raise HTTPException(status_code=404, detail="Deck not found") + + return new_card Deleting linked data ==================== .. edb:split-section:: As the next step, update the card deletion operation to use |Gel| to remove a card from a deck: .. code-block:: python-diff :caption: main.py @app.delete("/cards/{card_id}") async def delete_card(card_id: str): - decks = read_decks() - deck = next((deck for deck in decks if deck.id == deck_id), None) - if not deck: - raise HTTPException(status_code=404, detail="Deck not found") - - deck.cards = [card for card in deck.cards if card.id != card_id] - write_decks(decks) + deleted = await client.query_single(""" + delete Card filter .id = <uuid>$card_id + """, card_id=card_id) + + if not deleted: + raise HTTPException(status_code=404, detail="Card not found") + return {"message": "Card deleted"} Querying data ============= .. edb:split-section:: Finally, update the query endpoints to fetch data from |Gel|: ..
code-block:: python-diff :caption: main.py @app.get("/decks", response_model=List[Deck]) async def get_decks(): - return read_decks() + decks = await client.query(""" + select Deck { + id, + name, + description, + cards := ( + select .cards { + id, + front, + back + } + order by .order + ) + } + """) + return decks @app.get("/decks/{deck_id}", response_model=Deck) async def get_deck(deck_id: UUID): - decks = read_decks() - deck = next((deck for deck in decks if deck.id == deck_id), None) - if not deck: - raise HTTPException(status_code=404, detail=f"Deck with id {deck_id} not found") - return deck + deck = await client.query_single(""" + select Deck { + id, + name, + description, + cards := ( + select .cards { + id, + front, + back + } + order by .order + ) + } + filter .id = <uuid>$id + """, id=deck_id) + + if not deck: + raise HTTPException( + status_code=404, + detail=f"Deck with id {deck_id} not found" + ) + + return deck .. edb:split-section:: You can now run your FastAPI application with: .. code-block:: sh $ uvicorn main:app --reload .. edb:split-section:: The API documentation will be available at http://localhost:8000/docs. You can use this interface to test your endpoints and import the sample flashcard deck. .. image:: images/flashcards-api.png ================================================ FILE: docs/intro/quickstart/working/index.rst ================================================ .. edb:env-switcher:: ===================== Working with the data ===================== .. toctree:: :maxdepth: 3 :hidden: nextjs fastapi ================================================ FILE: docs/intro/quickstart/working/nextjs.rst ================================================ .. _ref_quickstart_working: ===================== Working with the data ===================== In this section, you will update the existing application to use |Gel| to store and query data, instead of a static JSON file.
Having a working application with mock data allows you to focus on learning how |Gel| works, without getting bogged down by the details of the application. Bulk importing of data ====================== .. edb:split-section:: Begin by updating the server action to import a deck with cards. Loop through each card in the deck and insert it, building an array of IDs as you go. This array of IDs will be used to set the ``cards`` link on the ``Deck`` object after all cards have been inserted. The array of card IDs is initially an array of strings. To satisfy the |Gel| type system, which expects the ``id`` property of ``Card`` objects to be a ``uuid`` rather than a ``str``, you need to cast the array of strings to an array of UUIDs. Use the ``e.literal(e.array(e.uuid), cardIds)`` function to perform this casting. The function ``e.contains(cardIdsLiteral, c.id)`` from our standard library checks if a value is present in an array and returns a boolean. When inserting the ``Deck`` object, set the ``cards`` to the result of selecting only the ``Card`` objects whose ``id`` is included in the ``cardIds`` array. .. 
code-block:: typescript-diff :caption: app/actions.ts "use server"; - import { readFile, writeFile } from "node:fs/promises"; + import { client } from "@/lib/gel"; + import e from "@/dbschema/edgeql-js"; import { revalidatePath } from "next/cache"; - import { RawJSONDeck, Deck } from "@/lib/models"; + import { RawJSONDeck } from "@/lib/models"; export async function importDeck(formData: FormData) { const file = formData.get("file") as File; const rawDeck = JSON.parse(await file.text()) as RawJSONDeck; const deck = { ...rawDeck, - id: crypto.randomUUID(), - cards: rawDeck.cards.map((card) => ({ + cards: rawDeck.cards.map((card, index) => ({ ...card, - id: crypto.randomUUID(), + order: index, })), }; - - const existingDecks = JSON.parse( - await readFile("./decks.json", "utf-8") - ) as Deck[]; - - await writeFile( - "./decks.json", - JSON.stringify([...existingDecks, deck], null, 2) - ); + const cardIds: string[] = []; + for (const card of deck.cards) { + const createdCard = await e + .insert(e.Card, { + front: card.front, + back: card.back, + order: card.order, + }) + .run(client); + + cardIds.push(createdCard.id); + } + + const cardIdsLiteral = e.literal(e.array(e.uuid), cardIds); + + await e.insert(e.Deck, { + name: deck.name, + description: deck.description, + cards: e.select(e.Card, (c) => ({ + filter: e.contains(cardIdsLiteral, c.id), + })), + }).run(client); revalidatePath("/"); } .. edb:split-section:: This works, but you might notice that it is not atomic. For instance, if one of the ``Card`` objects fails to insert, the entire operation will fail and the ``Deck`` will not be inserted, but some data will still linger. To make this operation atomic, update the ``importDeck`` action to use a transaction. .. 
code-block:: typescript-diff :caption: app/actions.ts "use server"; import { client } from "@/lib/gel"; import e from "@/dbschema/edgeql-js"; import { revalidatePath } from "next/cache"; import { RawJSONDeck } from "@/lib/models"; export async function importDeck(formData: FormData) { const file = formData.get("file") as File; const rawDeck = JSON.parse(await file.text()) as RawJSONDeck; const deck = { ...rawDeck, cards: rawDeck.cards.map((card, index) => ({ ...card, order: index, })), }; + await client.transaction(async (tx) => { const cardIds: string[] = []; for (const card of deck.cards) { const createdCard = await e .insert(e.Card, { front: card.front, back: card.back, order: card.order, }) - .run(client); + .run(tx); cardIds.push(createdCard.id); } const cardIdsLiteral = e.literal(e.array(e.uuid), cardIds); await e.insert(e.Deck, { name: deck.name, description: deck.description, cards: e.select(e.Card, (c) => ({ filter: e.contains(cardIdsLiteral, c.id), })), - }).run(client); + }).run(tx); + }); revalidatePath("/"); } .. edb:split-section:: You might think this is as good as it gets, and many ORMs will create a similar set of queries. However, with the query builder, you can improve this by crafting a single query that inserts the ``Deck`` and ``Card`` objects, along with their links, in one efficient query. The first thing to notice is that the ``e.params`` function is used to define parameters for your query instead of embedding literal values directly. This approach eliminates the need for casting, as was necessary with the ``cardIds`` array. By defining the ``cards`` parameter as an array of tuples, you ensure full type safety with both TypeScript and the database. Another key feature of this query builder expression is the ``e.for(e.array_unpack(params.cards), (card) => {...})`` construct. This expression converts the array of tuples into a set of tuples and generates a set containing an expression for each element. 
Essentially, you assign the ``Deck.cards`` set of ``Card`` objects to the result of inserting each element from the ``cards`` array. This is similar to what you were doing before by selecting all ``Card`` objects by their ``id``, but is more efficient since you are inserting the ``Deck`` and all ``Card`` objects in one query. .. code-block:: typescript-diff :caption: app/actions.ts "use server"; import { client } from "@/lib/gel"; import e from "@/dbschema/edgeql-js"; import { revalidatePath } from "next/cache"; import { RawJSONDeck } from "@/lib/models"; export async function importDeck(formData: FormData) { const file = formData.get("file") as File; const rawDeck = JSON.parse(await file.text()) as RawJSONDeck; const deck = { ...rawDeck, cards: rawDeck.cards.map((card, index) => ({ ...card, order: index, })), }; - await client.transaction(async (tx) => { - const cardIds: string[] = []; - for (const card of deck.cards) { - const createdCard = await e - .insert(e.Card, { - front: card.front, - back: card.back, - order: card.order, - }) - .run(tx); - - cardIds.push(createdCard.id); - } - - const cardIdsLiteral = e.literal(e.array(e.uuid), cardIds); - - await e.insert(e.Deck, { - name: deck.name, - description: deck.description, - cards: e.select(e.Card, (c) => ({ - filter: e.contains(cardIdsLiteral, c.id), - })), - }).run(tx); - }); + await e + .params( + { + name: e.str, + description: e.optional(e.str), + cards: e.array(e.tuple({ front: e.str, back: e.str, order: e.int64 })), + }, + (params) => + e.insert(e.Deck, { + name: params.name, + description: params.description, + cards: e.for(e.array_unpack(params.cards), (card) => + e.insert(e.Card, { + front: card.front, + back: card.back, + order: card.order, + }) + ), + }) + ) + .run(client, deck); revalidatePath("/"); } Updating data ============= .. edb:split-section:: Next, you will update the Server Actions for each ``Deck`` object: ``updateDeck``, ``addCard``, and ``deleteCard``. 
Start with ``updateDeck``, which is the most complex because it is dynamic. You can set the ``name`` field, the ``description`` field, or both in a single update. Use the dynamic nature of the query builder to generate separate queries based on which fields are present in the form data. This may seem a bit intimidating at first, but the key to making this query dynamic lies in the ``nameSet`` and ``descriptionSet`` variables. These variables conditionally add the ``name`` or ``description`` fields to the ``set`` parameter of the ``update`` call. .. code-block:: typescript-diff :caption: app/deck/[id]/actions.ts "use server"; import { revalidatePath } from "next/cache"; import { readFile, writeFile } from "node:fs/promises"; + import { client } from "@/lib/gel"; + import e from "@/dbschema/edgeql-js"; import { Deck } from "@/lib/models"; export async function updateDeck(formData: FormData) { const id = formData.get("id"); const name = formData.get("name"); const description = formData.get("description"); if ( typeof id !== "string" || (typeof name !== "string" && typeof description !== "string") ) { return; } - const decks = JSON.parse( - await readFile("./decks.json", "utf-8") - ) as Deck[]; - decks[index].name = name ?? decks[index].name; + const nameSet = typeof name === "string" ? { name } : {}; - decks[index].description = description ?? decks[index].description; + const descriptionSet = + typeof description === "string" ?
{ description: description || null } : {}; + await e + .update(e.Deck, (d) => ({ + filter_single: e.op(d.id, "=", e.uuid(id)), + set: { + ...nameSet, + ...descriptionSet, + }, + })).run(client); - await writeFile("./decks.json", JSON.stringify(decks, null, 2)); revalidatePath(`/deck/${id}`); } export async function addCard(formData: FormData) { const deckId = formData.get("deckId"); const front = formData.get("front"); const back = formData.get("back"); if ( typeof deckId !== "string" || typeof front !== "string" || typeof back !== "string" ) { return; } const decks = JSON.parse(await readFile("./decks.json", "utf-8")) as Deck[]; const deck = decks.find((deck) => deck.id === deckId); if (!deck) { return; } deck.cards.push({ front, back, id: crypto.randomUUID() }); await writeFile("./decks.json", JSON.stringify(decks, null, 2)); revalidatePath(`/deck/${deckId}`); } export async function deleteCard(formData: FormData) { const cardId = formData.get("cardId"); if (typeof cardId !== "string") { return; } const decks = JSON.parse(await readFile("./decks.json", "utf-8")) as Deck[]; const deck = decks.find((deck) => deck.cards.some((card) => card.id === cardId)); if (!deck) { return; } deck.cards = deck.cards.filter((card) => card.id !== cardId); await writeFile("./decks.json", JSON.stringify(decks, null, 2)); revalidatePath(`/`); } Adding linked data ================== .. edb:split-section:: For the ``addCard`` action, you need to insert a new ``Card`` object and update the ``Deck.cards`` set to include the new ``Card`` object. Notice that the ``order`` property is set by selecting the maximum ``order`` property of this ``Deck.cards`` set and incrementing it by 1. The syntax for adding an object to a set of links is ``{ "+=": object }``. You can think of this as a shortcut for setting the link set to the current set plus the new object. .. 
code-block:: typescript-diff :caption: app/deck/[id]/actions.ts "use server"; import { revalidatePath } from "next/cache"; import { readFile, writeFile } from "node:fs/promises"; import { client } from "@/lib/gel"; import e from "@/dbschema/edgeql-js"; import { Deck } from "@/lib/models"; export async function updateDeck(formData: FormData) { const id = formData.get("id"); const name = formData.get("name"); const description = formData.get("description"); if ( typeof id !== "string" || (typeof name !== "string" && typeof description !== "string") ) { return; } const nameSet = typeof name === "string" ? { name } : {}; const descriptionSet = typeof description === "string" ? { description: description || null } : {}; await e .update(e.Deck, (d) => ({ filter_single: e.op(d.id, "=", e.uuid(id)), set: { ...nameSet, ...descriptionSet, }, })).run(client); revalidatePath(`/deck/${id}`); } export async function addCard(formData: FormData) { const deckId = formData.get("deckId"); const front = formData.get("front"); const back = formData.get("back"); if ( typeof deckId !== "string" || typeof front !== "string" || typeof back !== "string" ) { return; } - const decks = JSON.parse(await readFile("./decks.json", "utf-8")) as Deck[]; - - const deck = decks.find((deck) => deck.id === deckId); - if (!deck) { - return; - } - - deck.cards.push({ front, back, id: crypto.randomUUID() }); - await writeFile("./decks.json", JSON.stringify(decks, null, 2)); + await e + .params( + { + front: e.str, + back: e.str, + deckId: e.uuid, + }, + (params) => { + const deck = e.assert_exists( + e.select(e.Deck, (d) => ({ + filter_single: e.op(d.id, "=", params.deckId), + })) + ); + + const order = e.cast(e.int64, e.max(deck.cards.order)); + const card = e.insert(e.Card, { + front: params.front, + back: params.back, + order: e.op(order, "+", 1), + }); + return e.update(deck, (d) => ({ + set: { + cards: { + "+=": card + }, + }, + })) + } + ) + .run(client, { + front, + back, + deckId, + }); 
revalidatePath(`/deck/${deckId}`); } export async function deleteCard(formData: FormData) { const cardId = formData.get("cardId"); if (typeof cardId !== "string") { return; } const decks = JSON.parse(await readFile("./decks.json", "utf-8")) as Deck[]; const deck = decks.find((deck) => deck.cards.some((card) => card.id === cardId)); if (!deck) { return; } deck.cards = deck.cards.filter((card) => card.id !== cardId); await writeFile("./decks.json", JSON.stringify(decks, null, 2)); revalidatePath(`/`); } Deleting linked data ==================== .. edb:split-section:: For the ``deleteCard`` action, delete the ``Card`` object. Based on the deletion policy you set up earlier in the schema, the object is deleted from the database and removed from the ``Deck.cards`` set. .. code-block:: typescript-diff :caption: app/deck/[id]/actions.ts "use server"; import { revalidatePath } from "next/cache"; - import { readFile, writeFile } from "node:fs/promises"; import { client } from "@/lib/gel"; import e from "@/dbschema/edgeql-js"; import { Deck } from "@/lib/models"; export async function updateDeck(formData: FormData) { const id = formData.get("id"); const name = formData.get("name"); const description = formData.get("description"); if ( typeof id !== "string" || (typeof name !== "string" && typeof description !== "string") ) { return; } const nameSet = typeof name === "string" ? { name } : {}; const descriptionSet = typeof description === "string" ?
{ description: description || null } : {}; await e .update(e.Deck, (d) => ({ filter_single: e.op(d.id, "=", e.uuid(id)), set: { ...nameSet, ...descriptionSet, }, })).run(client); revalidatePath(`/deck/${id}`); } export async function addCard(formData: FormData) { const deckId = formData.get("deckId"); const front = formData.get("front"); const back = formData.get("back"); if ( typeof deckId !== "string" || typeof front !== "string" || typeof back !== "string" ) { return; } await e .params( { front: e.str, back: e.str, deckId: e.uuid, }, (params) => { const deck = e.assert_exists( e.select(e.Deck, (d) => ({ filter_single: e.op(d.id, "=", params.deckId), })) ); const order = e.cast(e.int64, e.max(deck.cards.order)); const card = e.insert(e.Card, { front: params.front, back: params.back, order: e.op(order, "+", 1), }); return e.update(deck, (d) => ({ set: { cards: { "+=": card }, }, })) } ) .run(client, { front, back, deckId, }); revalidatePath(`/deck/${deckId}`); } export async function deleteCard(formData: FormData) { const cardId = formData.get("cardId"); if (typeof cardId !== "string") { return; } - const decks = JSON.parse(await readFile("./decks.json", "utf-8")) as Deck[]; - const deck = decks.find((deck) => deck.cards.some((card) => card.id === cardId)); - if (!deck) { - return; - } - - deck.cards = deck.cards.filter((card) => card.id !== cardId); - await writeFile("./decks.json", JSON.stringify(decks, null, 2)); + await e + .params({ id: e.uuid }, (params) => + e.delete(e.Card, (c) => ({ + filter_single: e.op(c.id, "=", params.id), + })) + ) + .run(client, { id: cardId }); + revalidatePath(`/`); } Querying data ============= .. edb:split-section:: Next, update the two ``queries.ts`` methods: ``getDecks`` and ``getDeck``. .. tabs:: .. 
code-tab:: typescript-diff :caption: app/queries.ts - import { readFile } from "node:fs/promises"; + import { client } from "@/lib/gel"; + import e from "@/dbschema/edgeql-js"; - - import { Deck } from "@/lib/models"; export async function getDecks() { - const decks = JSON.parse(await readFile("./decks.json", "utf-8")) as Deck[]; + const decks = await e.select(e.Deck, (deck) => ({ + id: true, + name: true, + description: true, + cards: e.select(deck.cards, (card) => ({ + id: true, + front: true, + back: true, + order_by: card.order, + })), + })).run(client); return decks; } .. code-tab:: typescript-diff :caption: app/deck/[id]/queries.ts - import { readFile } from "node:fs/promises"; - import { Deck } from "@/lib/models"; + import { client } from "@/lib/gel"; + import e from "@/dbschema/edgeql-js"; export async function getDeck({ id }: { id: string }) { - const decks = JSON.parse(await readFile("./decks.json", "utf-8")) as Deck[]; - return decks.find((deck) => deck.id === id) ?? null; + return await e + .select(e.Deck, (deck) => ({ + filter_single: e.op(deck.id, "=", e.uuid(id)), + id: true, + name: true, + description: true, + cards: e.select(deck.cards, (card) => ({ + id: true, + front: true, + back: true, + order_by: card.order, + })), + })) + .run(client); } .. edb:split-section:: In a terminal, run the Next.js development server. .. code-block:: sh $ npm run dev .. edb:split-section:: A static JSON file to seed your database with a deck of trivia cards is included in the project. Open your browser and navigate to the app at ``_. Use the "Import JSON" button to import this JSON file into your database. .. image:: images/flashcards-import.png ================================================ FILE: docs/intro/schema.rst ================================================ .. _ref_intro_schema: ====== Schema ====== This page is intended as a rapid-fire overview of Gel's schema definition language (SDL) so you can hit the ground running with Gel. 
Refer to the linked pages for more in-depth documentation! Scalar types ------------ |Gel| implements a rigorous type system containing the following primitive types. .. list-table:: * - Strings - ``str`` * - Booleans - ``bool`` * - Numbers - ``int16`` ``int32`` ``int64`` ``float32`` ``float64`` ``bigint`` ``decimal`` * - UUID - ``uuid`` * - JSON - ``json`` * - Dates and times - ``datetime`` ``cal::local_datetime`` ``cal::local_date`` ``cal::local_time`` * - Durations - ``duration`` ``cal::relative_duration`` ``cal::date_duration`` * - Binary data - ``bytes`` * - Auto-incrementing counters - ``sequence`` * - Enums - ``enum`` These primitives can be combined into arrays, tuples, and ranges. .. list-table:: * - Arrays - ``array<str>`` * - Tuples (unnamed) - ``tuple<str, int64, bool>`` * - Tuples (named) - ``tuple<name: str, age: int64>`` * - Ranges - ``range<float64>`` Collectively, *primitive* and *collection* types comprise Gel's *scalar type system*. Object types ------------ Object types are analogous to tables in SQL. They can contain **properties**, which can be of any scalar type, and **links**, which can point to any object type. Properties ---------- Declare a property by naming it and setting its type. .. code-block:: sdl type Movie { title: str; } The ``property`` keyword can be omitted for non-computed properties. See :ref:`Schema > Object types `. Required vs optional ^^^^^^^^^^^^^^^^^^^^ Properties are optional by default. Use the ``required`` keyword to make them required. .. code-block:: sdl type Movie { required title: str; # required release_year: int64; # optional } See :ref:`Schema > Properties `. Constraints ^^^^^^^^^^^ Add a pair of curly braces after the property to define additional information, including constraints. .. code-block:: sdl type Movie { required title: str { constraint exclusive; constraint min_len_value(8); constraint regexp(r'^[A-Za-z0-9 ]+$'); } } See :ref:`Schema > Constraints `.
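As a rough illustration of what the three constraints on ``title`` mean, here is a plain-Python sketch of the same checks. It is only an analogy: Gel enforces constraints inside the database, and ``check_title`` and ``existing_titles`` are hypothetical names invented for this example.

```python
import re

# Same character class as the SDL regexp constraint; fullmatch() plays the
# role of the ^ and $ anchors.
TITLE_PATTERN = re.compile(r"[A-Za-z0-9 ]+")


def check_title(title: str, existing_titles: set[str]) -> list[str]:
    """Return the names of the constraints that `title` would violate.

    Hypothetical helper: it mirrors `exclusive`, `min_len_value(8)` and
    `regexp(...)` from the schema above, but runs in application code.
    """
    violations = []
    if title in existing_titles:  # exclusive: no two movies share a title
        violations.append("exclusive")
    if len(title) < 8:  # min_len_value(8): at least 8 characters
        violations.append("min_len_value")
    if not TITLE_PATTERN.fullmatch(title):  # regexp: letters, digits, spaces
        violations.append("regexp")
    return violations
```

In a real application there is no need to duplicate these checks: an insert or update that violates a constraint is rejected by the database itself.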
Computed properties ^^^^^^^^^^^^^^^^^^^ Object types can contain *computed properties* that correspond to EdgeQL expressions. The expression is evaluated dynamically whenever the property is queried. .. code-block:: sdl type Movie { required title: str; uppercase_title := str_upper(.title); } See :ref:`Schema > Computeds `. Links ----- Object types can have links to other object types. .. code-block:: sdl type Movie { required title: str; director: Person; } type Person { required name: str; } The ``link`` keyword can be omitted for non-computed links since Gel v3. Use the ``required`` and ``multi`` keywords to specify the cardinality of the relation. .. code-block:: sdl type Movie { required title: str; cinematographer: Person; # zero or one required director: Person; # exactly one multi writers: Person; # zero or more required multi actors: Person; # one or more } type Person { required name: str; } To define a one-to-one relation, use an ``exclusive`` constraint. .. code-block:: sdl type Movie { required title: str; required stats: MovieStats { constraint exclusive; }; } type MovieStats { required budget: int64; required box_office: int64; } See :ref:`Schema > Links `. Computed links ^^^^^^^^^^^^^^ Objects can contain "computed links": stored expressions that return a set of objects. Computed links are dynamically computed when they are referenced in queries. The example below defines a computed link that returns all movies with the same title. .. code-block:: sdl type Movie { required title: str; multi actors: Person; # returns all movies with same title multi same_title := ( with t := .title select detached Movie filter .title = t ) } Backlinks ^^^^^^^^^ A common use case for computed links is *backlinks*. .. code-block:: sdl type Movie { required title: str; multi actors: Person; } type Person { required name: str; multi acted_in := .<actors[is Movie]; } ``.<actors`` This follows the ``actors`` link in reverse: it returns every object that points to the current ``Person`` through a link named ``actors``, that is, the reverse of ``Movie.actors -> Person``. ``[is Movie]`` This is a *type filter* that filters out all objects that aren't ``Movie`` objects.
A backlink still works without this filter, but it could then return objects of other types besides ``Movie``. See :ref:`Schema > Computeds > Backlinks `. Constraints ----------- Constraints can also be defined at the *object level*. .. code-block:: sdl type BlogPost { title: str; author: User; constraint exclusive on ((.title, .author)); } Constraints can contain exceptions; these are called *partial constraints*. .. code-block:: sdl type BlogPost { title: str; published: bool; constraint exclusive on (.title) except (not .published); } Indexes ------- Use ``index on`` to define indexes on an object type. .. code-block:: sdl type Movie { required title: str; required release_year: int64; index on (.title); # simple index index on ((.title, .release_year)); # composite index index on (str_trim(str_lower(.title))); # computed index } The ``id`` property, all links, and all properties with ``exclusive`` constraints are automatically indexed. See :ref:`Schema > Indexes `. Schema mixins ------------- Object types can be declared as ``abstract``. Non-abstract types can *extend* abstract types. .. code-block:: sdl abstract type Content { required title: str; } type Movie extending Content { required release_year: int64; } type TVShow extending Content { required num_seasons: int64; } Multiple inheritance is supported. .. code-block:: sdl abstract type HasTitle { required title: str; } abstract type HasReleaseYear { required release_year: int64; } type Movie extending HasTitle, HasReleaseYear { sequel_to: Movie; } See :ref:`Schema > Object types > Inheritance `. Polymorphism ------------ Links can correspond to abstract types. These are known as *polymorphic links*. ..
code-block:: sdl abstract type Content { required title: str; } type Movie extending Content { required release_year: int64; } type TVShow extending Content { required num_seasons: int64; } type Franchise { required name: str; multi entries: Content; } See :ref:`Schema > Links > Polymorphism ` and :ref:`EdgeQL > Select > Polymorphic queries `. ================================================ FILE: docs/intro/tutorials/ai_fastapi_searchbot.rst ================================================ .. _ref_guide_fastapi_gelai_searchbot: =============================== Build a Search Bot with FastAPI =============================== :edb-alt-title: Building a search bot with memory using FastAPI and Gel AI In this tutorial we're going to walk you through building a chat bot with search capabilities using Gel and `FastAPI `_. FastAPI is a framework designed to help you build web apps *fast*. Gel is a data layer designed to help you figure out storage in your application - also *fast*. By the end of this tutorial, you will have tried out different aspects of using those two together. We will start by creating an app with FastAPI, adding web search capabilities, and then putting search results through a language model to get a human-friendly answer. After that, we'll use Gel to implement chat history so that the bot remembers previous interactions with the user. We'll finish it off with semantic search-based cross-chat memory. 1. Initialize the project ========================= .. edb:split-section:: We're going to start by installing `uv `_ - a Python package manager that's going to simplify environment management for us. You can follow their `installation instructions `_ or simply run: .. code-block:: bash $ curl -LsSf https://astral.sh/uv/install.sh | sh .. edb:split-section:: Once that is done, we can use uv to create scaffolding for our project following the `documentation `_: .. code-block:: bash $ uv init searchbot \ && cd searchbot .. 
edb:split-section:: For now, we know we're going to need Gel and FastAPI, so let's add those following uv's instructions on `managing dependencies `_, as well as FastAPI's `installation docs `_. Running ``uv sync`` after that will create our virtual environment in a ``.venv`` directory and ensure it's ready. As the last step, we'll activate the environment and get started. .. note:: Every time you open a new terminal session, you should source the environment before running ``python``, ``gel`` or ``fastapi`` commands. .. code-block:: bash $ uv add "fastapi[standard]" \ && uv add gel \ && uv sync \ && source .venv/bin/activate 2. Get started with FastAPI =========================== .. edb:split-section:: At this stage we need to follow FastAPI's `tutorial `_ to create the foundation of our app. We're going to make a minimal web API with one endpoint that takes in a user query as an input and echoes it as an output. First, let's make a directory called ``app`` in our project root, and put an empty ``__init__.py`` there. .. code-block:: bash $ mkdir app && touch app/__init__.py .. edb:split-section:: Now let's create a file called ``main.py`` inside the ``app`` directory and put the "Hello World" example in it: .. code-block:: python :caption: app/main.py from fastapi import FastAPI app = FastAPI() @app.get("/") async def root(): return {"message": "Hello World"} .. edb:split-section:: To start the server, we'll run: .. code-block:: bash $ fastapi dev app/main.py .. edb:split-section:: Once the server gets up and running, we can make sure it works using FastAPI's built-in UI at _, or manually with ``curl``: .. code-block:: bash $ curl -X 'GET' \ 'http://127.0.0.1:8000/' \ -H 'accept: application/json' {"message":"Hello World"} .. edb:split-section:: Now, to create the search endpoint we mentioned earlier, we need to pass our query as a parameter to it. We'd prefer to have it in the request's body since user messages can be long. 
In FastAPI land, this is done by creating a Pydantic schema and making it the type of the input parameter. `Pydantic `_ is a data validation library for Python. It has many features, but we don't actually need to know about them for now. All we need to know is that FastAPI uses Pydantic types to automatically figure out schemas for `input `_, as well as `output `_. Let's add the following to our ``main.py``: .. code-block:: python :caption: app/main.py from pydantic import BaseModel class SearchTerms(BaseModel): query: str class SearchResult(BaseModel): response: str | None = None .. edb:split-section:: Now, we can define our endpoint. We'll set the two classes we just created as the new endpoint's argument and return type. .. code-block:: python :caption: app/main.py @app.post("/search") async def search(search_terms: SearchTerms) -> SearchResult: return SearchResult(response=search_terms.query) .. edb:split-section:: Same as before, we can test the endpoint using the UI, or by sending a request with ``curl``: .. code-block:: bash $ curl -X 'POST' \ 'http://127.0.0.1:8000/search' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ "query": "string" }' { "response": "string" } 3. Implement web search ======================= Now that we have our web app infrastructure in place, let's add some substance to it by implementing web search capabilities. .. edb:split-section:: There are many powerful, feature-rich products for LLM-driven web search. But in this tutorial we're going to use a much more reliable source of real-world information: comment threads on `Hacker News `_. Their `web API `_ is free of charge and doesn't require an account. Below is a simple function that requests a full-text search for a string query and extracts a nice sampling of comment threads from each of the stories that came up in the result. We are not going to cover this code sample in too much depth.
Feel free to grab it and save it to ``app/web.py``, or make your own. Notice that we've created another Pydantic type called ``WebSource`` to store our web search results. There's no framework-related reason for that; it's just nicer than passing dictionaries around. .. code-block:: python :caption: app/web.py :class: collapsible import requests from pydantic import BaseModel from datetime import datetime import html class WebSource(BaseModel): """Type that stores search results.""" url: str | None = None title: str | None = None text: str | None = None def extract_comment_thread( comment: dict, max_depth: int = 3, current_depth: int = 0, max_children: int = 3, ) -> list[str]: """ Recursively extract comments from a thread up to max_depth. Returns a list of formatted comment strings. """ if not comment or current_depth > max_depth: return [] results = [] # Get timestamp, author and the body of the comment, # then pad it with spaces so that it's offset appropriately for its depth if comment.get("text"): timestamp = datetime.fromisoformat(comment["created_at"].replace("Z", "+00:00")) author = comment["author"] text = html.unescape(comment["text"]) formatted_comment = f"[{timestamp.strftime('%Y-%m-%d %H:%M')}] {author}: {text}" results.append((" " * current_depth) + formatted_comment) # If there are child comments, we are going to extract them too, # and add them to the list. if comment.get("children"): for child in comment["children"][:max_children]: child_comments = extract_comment_thread(child, max_depth, current_depth + 1, max_children) results.extend(child_comments) return results def fetch_web_sources(query: str, limit: int = 5) -> list[WebSource]: """ For a given query perform a full-text search for stories on Hacker News. From each of the matched stories extract the comment thread and format it into a single string. For each story return its title, url and comment thread.
""" search_url = "http://hn.algolia.com/api/v1/search_by_date?numericFilters=num_comments>0" # Search for stories response = requests.get( search_url, params={ "query": query, "tags": "story", "hitsPerPage": limit, "page": 0, }, ) response.raise_for_status() search_result = response.json() # For each search hit fetch and process the story web_sources = [] for hit in search_result.get("hits", []): item_url = f"https://hn.algolia.com/api/v1/items/{hit['story_id']}" response = requests.get(item_url) response.raise_for_status() item_result = response.json() site_url = f"https://news.ycombinator.com/item?id={hit['story_id']}" title = hit["title"] comments = extract_comment_thread(item_result) text = "\n".join(comments) if len(comments) > 0 else None web_sources.append( WebSource(url=site_url, title=title, text=text) ) return web_sources if __name__ == "__main__": web_sources = fetch_web_sources("edgedb", limit=5) for source in web_sources: print(source.url) print(source.title) print(source.text) .. edb:split-section:: One more note: this snippet comes with an extra dependency called ``requests``, which is a library for making HTTP requests. Let's add it by running: .. code-block:: bash $ uv add requests .. edb:split-section:: Now, we can test our web search on its own by running it like this: .. code-block:: bash $ python3 app/web.py .. edb:split-section:: It's time to reflect the new capabilities in our web app. .. code-block:: python :caption: app/main.py from .web import fetch_web_sources, WebSource async def search_web(query: str) -> list[WebSource]: raw_sources = fetch_web_sources(query, limit=5) return [s for s in raw_sources if s.text is not None] .. edb:split-section:: Now we can update the ``/search`` endpoint as follows: .. 
code-block:: python-diff :caption: app/main.py class SearchResult(BaseModel): response: str | None = None + sources: list[WebSource] | None = None @app.post("/search") async def search(search_terms: SearchTerms) -> SearchResult: + web_sources = await search_web(search_terms.query) - return SearchResult(response=search_terms.query) + return SearchResult( + response=search_terms.query, sources=web_sources + ) 4. Connect to the LLM ===================== Now that we're capable of scraping text from search results, we can forward those results to the LLM to get a nice-looking summary. .. edb:split-section:: There are a million different LLMs accessible via a web API (`one `_, `two `_, `three `_, `four `_ to name a few), so feel free to choose whichever you prefer. In this tutorial we will roll with OpenAI, primarily for how ubiquitous it is. To keep things somewhat provider-agnostic, we're going to get completions via raw HTTP requests. Let's grab API descriptions from OpenAI's `API documentation `_, and set up LLM generation like this: .. code-block:: python :caption: app/main.py import os import requests from dotenv import load_dotenv _ = load_dotenv() def get_llm_completion(system_prompt: str, messages: list[dict[str, str]]) -> str: api_key = os.getenv("OPENAI_API_KEY") url = "https://api.openai.com/v1/chat/completions" headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} response = requests.post( url, headers=headers, json={ "model": "gpt-4o-mini", "messages": [ {"role": "developer", "content": system_prompt}, *messages, ], }, ) response.raise_for_status() result = response.json() return result["choices"][0]["message"]["content"] .. edb:split-section:: Note that this cloud LLM API (and many others) requires a secret key to be set as an environment variable. A common way to manage those is to use the ``python-dotenv`` library in combination with a ``.env`` file. Feel free to browse `the readme `_ to learn more.
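One thing to keep in mind: ``os.getenv`` returns ``None`` when the variable is missing, so a forgotten key only shows up later as a failed request. Below is a minimal sketch of a fail-fast lookup that uses only the standard library. The ``require_env`` helper is a hypothetical name of our own, not part of ``python-dotenv`` or FastAPI.

```python
import os


def require_env(name: str) -> str:
    """Return the value of an environment variable, or fail with a clear error.

    Hypothetical helper for this tutorial; call it after load_dotenv() has run
    so that values from the .env file are already in the environment.
    """
    value = os.getenv(name)
    if not value:
        raise RuntimeError(f"Environment variable {name} is not set")
    return value
```

If we wanted to use it, ``api_key = require_env("OPENAI_API_KEY")`` could stand in for the ``os.getenv`` call in ``get_llm_completion``.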
Create a file called ``.env`` in the root directory and put your API key in there: .. code-block:: .env :caption: .env OPENAI_API_KEY="sk-..." .. edb:split-section:: Don't forget to add the new dependency to the environment: .. code-block:: bash $ uv add python-dotenv .. edb:split-section:: And now we can integrate this LLM-related code with the rest of the app. First, let's set up a function that prepares LLM inputs: .. code-block:: python :caption: app/main.py async def generate_answer( query: str, web_sources: list[WebSource], ) -> SearchResult: system_prompt = ( "You are a helpful assistant that answers user's questions" + " by finding relevant information in Hacker News threads." + " When answering the question, describe conversations that people have around the subject," + " provided to you as a context, or say i don't know if they are completely irrelevant." ) prompt = f"User search query: {query}\n\nWeb search results:\n" for i, source in enumerate(web_sources): prompt += f"Result {i} (URL: {source.url}):\n" prompt += f"{source.text}\n\n" messages = [{"role": "user", "content": prompt}] llm_response = get_llm_completion( system_prompt=system_prompt, messages=messages, ) search_result = SearchResult( response=llm_response, sources=web_sources, ) return search_result .. edb:split-section:: Then we can plug that function into the ``/search`` endpoint: .. code-block:: python-diff :caption: app/main.py @app.post("/search") async def search(search_terms: SearchTerms) -> SearchResult: web_sources = await search_web(search_terms.query) + search_result = await generate_answer(search_terms.query, web_sources) + return search_result - return SearchResult( - response=search_terms.query, sources=web_sources - ) .. edb:split-section:: And now we can test the result as usual. .. code-block:: bash $ curl -X 'POST' \ 'http://127.0.0.1:8000/search' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ "query": "gel" }' 5.
Use Gel to implement chat history ==================================== So far we've built an application that can take in a query, fetch some Hacker News threads for it, sift through them using an LLM, and generate a nice summary. However, right now it's hardly user-friendly since you have to speak in keywords and basically start over every time you want to refine the query. To enable a more organic multi-turn interaction, we need to add chat history and infer the query from the context of the entire conversation. Now's a good time to introduce Gel. .. edb:split-section:: In case you need installation instructions, take a look at the :ref:`Quickstart `. Once the Gel CLI is installed on your system, initialize the project like this: .. code-block:: bash $ gel project init --non-interactive This command is going to put some project scaffolding inside our app, spin up a local instance of Gel, and then link the two together. From now on, all Gel-related things that happen inside our project directory are going to be automatically run on the correct database instance, with no need to worry about connection incantations. Defining the schema ------------------- The database :ref:`schema ` in Gel is defined declaratively. The :gelcmd:`project init` command has created a file called :dotgel:`dbschema/default`, which we're going to use to define our types. .. edb:split-section:: We obviously want to keep track of the messages, so we need to represent those in the schema. By convention established in the LLM space, each message is going to have a role in addition to the message content itself. We can also get Gel to automatically keep track of each message's creation time by adding a property called ``timestamp`` and setting its :ref:`default value ` to the output of the :ref:`datetime_current() ` function. Finally, LLM messages in our search bot have source URLs associated with them. Let's keep track of those too, by adding a :ref:`multi-property `. ..
code-block:: sdl :caption: dbschema/default.esdl type Message { role: str; body: str; timestamp: datetime { default := datetime_current(); } multi sources: str; } .. edb:split-section:: Messages are grouped together into a chat, so let's add that entity to our schema too. .. code-block:: sdl :caption: dbschema/default.esdl type Chat { multi messages: Message; } .. edb:split-section:: And chats all belong to a certain user, making up their chat history. One other thing we'd like to keep track of about our users is their username, and it would make sense for us to make sure that it's unique by using an ``exclusive`` :ref:`constraint `. .. code-block:: sdl :caption: dbschema/default.esdl type User { name: str { constraint exclusive; } multi chats: Chat; } .. edb:split-section:: We're going to keep our schema super simple. One cool thing about Gel is that it will enable us to easily implement advanced features such as authentication or AI down the road, but we're gonna come back to that later. For now, this is the entire schema we came up with: .. code-block:: sdl :caption: dbschema/default.esdl module default { type Message { role: str; body: str; timestamp: datetime { default := datetime_current(); } multi sources: str; } type Chat { multi messages: Message; } type User { name: str { constraint exclusive; } multi chats: Chat; } } .. edb:split-section:: Let's use the :gelcmd:`migration create` CLI command, followed by :gelcmd:`migrate` in order to migrate to our new schema and proceed to writing some queries. .. code-block:: bash $ gel migration create $ gel migrate .. edb:split-section:: Now that our schema is applied, let's quickly populate the database with some fake data in order to be able to test the queries. We're going to explore writing queries in a bit, but for now you can just run the following command in the shell: ..
code-block:: bash :class: collapsible $ mkdir app/sample_data && cat << 'EOF' > app/sample_data/inserts.edgeql # Create users first insert User { name := 'alice', }; insert User { name := 'bob', }; # Insert chat histories for Alice update User filter .name = 'alice' set { chats := { (insert Chat { messages := { (insert Message { role := 'user', body := 'What are the main differences between GPT-3 and GPT-4?', timestamp := '2024-01-07T10:00:00Z', sources := {'arxiv:2303.08774', 'openai.com/research/gpt-4'} }), (insert Message { role := 'assistant', body := 'The key differences include improved reasoning capabilities, better context understanding, and enhanced safety features...', timestamp := '2024-01-07T10:00:05Z', sources := {'openai.com/blog/gpt-4-details', 'arxiv:2303.08774'} }) } }), (insert Chat { messages := { (insert Message { role := 'user', body := 'Can you explain what policy gradient methods are in RL?', timestamp := '2024-01-08T14:30:00Z', sources := {'Sutton-Barto-RL-Book-Ch13', 'arxiv:1904.12901'} }), (insert Message { role := 'assistant', body := 'Policy gradient methods are a class of reinforcement learning algorithms that directly optimize the policy...', timestamp := '2024-01-08T14:30:10Z', sources := {'Sutton-Barto-RL-Book-Ch13', 'spinning-up.openai.com'} }) } }) } }; # Insert chat histories for Bob update User filter .name = 'bob' set { chats := { (insert Chat { messages := { (insert Message { role := 'user', body := 'What are the pros and cons of different sharding strategies?', timestamp := '2024-01-05T16:15:00Z', sources := {'martin-kleppmann-ddia-ch6', 'aws.amazon.com/sharding-patterns'} }), (insert Message { role := 'assistant', body := 'The main sharding strategies include range-based, hash-based, and directory-based sharding...', timestamp := '2024-01-05T16:15:08Z', sources := {'martin-kleppmann-ddia-ch6', 'mongodb.com/docs/sharding'} }), (insert Message { role := 'user', body := 'Could you elaborate on hash-based sharding?', timestamp := 
'2024-01-05T16:16:00Z', sources := {'mongodb.com/docs/sharding'} }) } }) } }; EOF .. edb:split-section:: This created the ``app/sample_data/inserts.edgeql`` file, which we can now execute using the CLI like this: .. code-block:: bash $ gel query -f app/sample_data/inserts.edgeql {"id": "862de904-de39-11ef-9713-4fab09220c4a"} {"id": "862e400c-de39-11ef-9713-2f81f2b67013"} {"id": "862de904-de39-11ef-9713-4fab09220c4a"} {"id": "862e400c-de39-11ef-9713-2f81f2b67013"} .. edb:split-section:: The :gelcmd:`query` command is one of many ways we can execute a query in Gel. Now that we've done it, there's stuff in the database. Let's verify it by running: .. code-block:: bash $ gel query "select User { name };" {"name": "alice"} {"name": "bob"} Writing queries --------------- With the schema in place, it's time to focus on getting the data in and out of the database. In this tutorial we're going to write queries using :ref:`EdgeQL ` and then use :ref:`codegen ` to generate typesafe functions that we can plug directly into our Python code. If you are completely unfamiliar with EdgeQL, now is a good time to check out the basics before proceeding. .. edb:split-section:: Let's move on. First, we'll create a directory inside ``app`` called ``queries``. This is where we're going to put all of the EdgeQL-related stuff. We're going to start by writing a query that fetches all of the users. In ``queries`` create a file named ``get_users.edgeql`` and put the following query in there: .. code-block:: edgeql :caption: app/queries/get_users.edgeql select User { name }; .. edb:split-section:: Now run the code generator from the shell: .. code-block:: bash $ gel-py .. edb:split-section:: It's going to automatically locate the ``.edgeql`` file and generate types for it. We can inspect the generated code in ``app/queries/get_users_async_edgeql.py``. Once that is done, let's use those types to create the endpoint in ``main.py``: .. 
code-block:: python :caption: app/main.py from gel import create_async_client from .queries.get_users_async_edgeql import get_users as get_users_query, GetUsersResult gel_client = create_async_client() @app.get("/users") async def get_users() -> list[GetUsersResult]: return await get_users_query(gel_client) .. edb:split-section:: Let's verify that it works as expected: .. code-block:: bash $ curl -X 'GET' \ 'http://127.0.0.1:8000/users' \ -H 'accept: application/json' [ { "id": "862de904-de39-11ef-9713-4fab09220c4a", "name": "alice" }, { "id": "862e400c-de39-11ef-9713-2f81f2b67013", "name": "bob" } ] .. edb:split-section:: While we're at it, let's also implement the option to fetch a user by their username. In order to do that, we need to write a new query in a separate file ``app/queries/get_user_by_name.edgeql``: .. code-block:: edgeql :caption: app/queries/get_user_by_name.edgeql select User { name } filter .name = <str>$name; .. edb:split-section:: After that, we will run the code generator again by calling ``gel-py``. In the app, we are going to reuse the same endpoint that fetches the list of all users. From now on, if the user calls it without any arguments (e.g. ``http://127.0.0.1:8000/users``), they are going to receive the list of all users, same as before. But if they pass a username as a query argument like this: ``http://127.0.0.1:8000/users?username=bob``, the system will attempt to fetch a user named ``bob``. In order to achieve this, we're going to need to add a ``Query``-type argument to our endpoint function. You can learn more about how to configure this type of argument in `FastAPI's docs `_. Its default value is going to be ``None``, which will enable us to implement our conditional logic: .. 
code-block:: python :caption: app/main.py from fastapi import Query, HTTPException from http import HTTPStatus from .queries.get_user_by_name_async_edgeql import ( get_user_by_name as get_user_by_name_query, GetUserByNameResult, ) @app.get("/users") async def get_users( username: str = Query(None), ) -> list[GetUsersResult] | GetUserByNameResult: """List all users or get a user by their username""" if username: user = await get_user_by_name_query(gel_client, name=username) if not user: raise HTTPException( HTTPStatus.NOT_FOUND, detail={"error": f"Error: user {username} does not exist."}, ) return user else: return await get_users_query(gel_client) .. edb:split-section:: And once again, let's verify that everything works: .. code-block:: bash $ curl -X 'GET' \ 'http://127.0.0.1:8000/users?username=alice' \ -H 'accept: application/json' { "id": "862de904-de39-11ef-9713-4fab09220c4a", "name": "alice" } .. edb:split-section:: Finally, let's also implement the option to add a new user. For this, just as before, we'll create a new file ``app/queries/create_user.edgeql``, add a query to it and run code generation. Note that in this query we've wrapped the ``insert`` in a ``select`` statement. This is a common pattern in EdgeQL that you can use whenever you want an insert to return something other than the ID of the newly created object. .. code-block:: edgeql :caption: app/queries/create_user.edgeql select( insert User { name := <str>$username } ) { name } .. edb:split-section:: In order to integrate this query into our app, we're going to add a new endpoint. Note that this one has the same path, ``/users``, but handles the POST HTTP method. .. 
code-block:: python :caption: app/main.py from gel import ConstraintViolationError from .queries.create_user_async_edgeql import ( create_user as create_user_query, CreateUserResult, ) @app.post("/users", status_code=HTTPStatus.CREATED) async def post_user(username: str = Query()) -> CreateUserResult: try: return await create_user_query(gel_client, username=username) except ConstraintViolationError: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, detail={"error": f"Username '{username}' already exists."}, ) .. edb:split-section:: Once more, let's verify that the new endpoint works as expected: .. code-block:: bash $ curl -X 'POST' \ 'http://127.0.0.1:8000/users?username=charlie' \ -H 'accept: application/json' \ -d '' { "id": "20372a1a-ded5-11ef-9a08-b329b578c45c", "name": "charlie" } .. edb:split-section:: This wraps things up for our user-related functionality. Of course, we now need to deal with Chats and Messages, too. We're not going to go in depth for those, since the process would be quite similar to what we've just done. Instead, feel free to implement those endpoints yourself as an exercise, or copy the code below if you are in a rush. .. code-block:: bash :class: collapsible $ echo 'select Chat { messages: { role, body, sources }, user := .$username;' > app/queries/get_chats.edgeql && echo 'select Chat { messages: { role, body, sources }, user := .$username and .id = $chat_id;' > app/queries/get_chat_by_id.edgeql && echo 'with new_chat := (insert Chat) select ( update User filter .name = $username set { chats := assert_distinct(.chats union new_chat) } ) { new_chat_id := new_chat.id }' > app/queries/create_chat.edgeql && echo 'with user := (select User filter .name = $username), chat := ( select Chat filter .$chat_id ) select Message { role, body, sources, chat := . 
app/queries/get_messages.edgeql && echo 'with user := (select User filter .name = $username), update Chat filter .id = $chat_id and .$message_role, body := $message_body, sources := array_unpack(>$sources) } )) }' > app/queries/add_message.edgeql .. edb:split-section:: And these are the endpoint definitions, provided in bulk. .. code-block:: python :caption: app/main.py :class: collapsible from .queries.get_chats_async_edgeql import get_chats as get_chats_query, GetChatsResult from .queries.get_chat_by_id_async_edgeql import ( get_chat_by_id as get_chat_by_id_query, GetChatByIdResult, ) from .queries.get_messages_async_edgeql import ( get_messages as get_messages_query, GetMessagesResult, ) from .queries.create_chat_async_edgeql import ( create_chat as create_chat_query, CreateChatResult, ) from .queries.add_message_async_edgeql import ( add_message as add_message_query, ) @app.get("/chats") async def get_chats( username: str = Query(), chat_id: str = Query(None) ) -> list[GetChatsResult] | GetChatByIdResult: """List user's chats or get a chat by username and id""" if chat_id: chat = await get_chat_by_id_query( gel_client, username=username, chat_id=chat_id ) if not chat: raise HTTPException( HTTPStatus.NOT_FOUND, detail={"error": f"Chat {chat_id} for user {username} does not exist."}, ) return chat else: return await get_chats_query(gel_client, username=username) @app.post("/chats", status_code=HTTPStatus.CREATED) async def post_chat(username: str) -> CreateChatResult: return await create_chat_query(gel_client, username=username) @app.get("/messages") async def get_messages( username: str = Query(), chat_id: str = Query() ) -> list[GetMessagesResult]: """Fetch all messages from a chat""" return await get_messages_query(gel_client, username=username, chat_id=chat_id) .. edb:split-section:: For the ``post_messages`` function we're going to do something a little bit different though. 
Since this is now the primary way for the user to add their queries to the system, it functionally supersedes the ``/search`` endpoint we made before. To this end, this function is where we're going to handle saving messages, retrieving chat history, invoking web search and generating the answer. .. code-block:: python-diff :caption: app/main.py - @app.post("/search") - async def search(search_terms: SearchTerms) -> SearchResult: - web_sources = await search_web(search_terms.query) - search_result = await generate_answer(search_terms.query, web_sources) - return search_result + @app.post("/messages", status_code=HTTPStatus.CREATED) + async def post_messages( + search_terms: SearchTerms, + username: str = Query(), + chat_id: str = Query(), + ) -> SearchResult: + chat_history = await get_messages_query( + gel_client, username=username, chat_id=chat_id + ) + _ = await add_message_query( + gel_client, + username=username, + message_role="user", + message_body=search_terms.query, + sources=[], + chat_id=chat_id, + ) + search_query = search_terms.query + web_sources = await search_web(search_query) + search_result = await generate_answer( + search_terms.query, chat_history, web_sources + ) + _ = await add_message_query( + gel_client, + username=username, + message_role="assistant", + message_body=search_result.response, + sources=[s.url for s in search_result.sources], + chat_id=chat_id, + ) + return search_result .. edb:split-section:: Let's not forget to modify the ``generate_answer`` function, so it can also be history-aware. .. code-block:: python-diff :caption: app/main.py async def generate_answer( query: str, + chat_history: list[GetMessagesResult], web_sources: list[WebSource], ) -> SearchResult: system_prompt = ( "You are a helpful assistant that answers user's questions" + " by finding relevant information in HackerNews threads." 
+ " When answering the question, describe conversations that people have around the subject," + " provided to you as a context, or say i don't know if they are completely irrelevant." ) prompt = f"User search query: {query}\n\nWeb search results:\n" for i, source in enumerate(web_sources): prompt += f"Result {i} (URL: {source.url}):\n" prompt += f"{source.text}\n\n" - messages = [{"role": "user", "content": prompt}] + messages = [ + {"role": message.role, "content": message.body} for message in chat_history + ] + messages.append({"role": "user", "content": prompt}) llm_response = get_llm_completion( system_prompt=system_prompt, messages=messages, ) search_result = SearchResult( response=llm_response, sources=web_sources, ) return search_result .. edb:split-section:: Ok, this should be it for setting up the chat history. Let's test it. First, we are going to start a new chat for our user: .. code-block:: bash $ curl -X 'POST' \ 'http://127.0.0.1:8000/chats?username=charlie' \ -H 'accept: application/json' \ -d '' { "id": "20372a1a-ded5-11ef-9a08-b329b578c45c", "new_chat_id": "544ef3f2-ded8-11ef-ba16-f7f254b95e36" } .. edb:split-section:: Next, let's add a couple messages and wait for the bot to respond: .. code-block:: bash $ curl -X 'POST' \ 'http://127.0.0.1:8000/messages?username=charlie&chat_id=544ef3f2-ded8-11ef-ba16-f7f254b95e36' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ "query": "best database in existence" }' $ curl -X 'POST' \ 'http://127.0.0.1:8000/messages?username=charlie&chat_id=544ef3f2-ded8-11ef-ba16-f7f254b95e36' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ "query": "gel" }' .. edb:split-section:: Finally, let's check that the messages we saw are in fact stored in the chat history: .. 
code-block:: bash $ curl -X 'GET' \ 'http://127.0.0.1:8000/messages?username=charlie&chat_id=544ef3f2-ded8-11ef-ba16-f7f254b95e36' \ -H 'accept: application/json' In reality this workflow would've been handled by the frontend, providing the user with a nice interface to interact with. But even without one our chatbot is almost functional by now. Generating a Google search query -------------------------------- Congratulations! We just got done implementing multi-turn conversations for our search bot. However, there's still one crucial piece missing. Right now we're simply forwarding the user's message straight to the full-text search. But what happens if their message is a followup that cannot be used as a standalone search query? Ideally, we should infer the search query from the entire conversation, and use that to perform the search. Let's implement an extra step in which the LLM is going to produce a query for us based on the entire chat history. That way we can be sure we're progressively working on our query rather than rewriting it from scratch every time. .. edb:split-section:: This is what we need to do: every time the user submits a message, we need to fetch the chat history and extract a search query from it using the LLM. The other steps are going to be the same as before. Let's make the following modifications to ``main.py``: first we need to create a function that prepares LLM inputs for the search query inference. .. code-block:: python :caption: app/main.py async def generate_search_query( query: str, message_history: list[GetMessagesResult] ) -> str: system_prompt = ( "You are a helpful assistant." + " Your job is to extract a keyword search query" + " from a chat between an AI and a human." + " Make sure it's a single most relevant keyword to maximize matching." + " Only provide the query itself as your response." 
) formatted_history = "\n---\n".join( [ f"{message.role}: {message.body} (sources: {message.sources})" for message in message_history ] ) prompt = f"Chat history: {formatted_history}\n\nUser message: {query} \n\n" llm_response = get_llm_completion( system_prompt=system_prompt, messages=[{"role": "user", "content": prompt}] ) return llm_response .. edb:split-section:: And now we can use this function in ``post_messages`` in order to get our search query: .. code-block:: python-diff :caption: app/main.py class SearchResult(BaseModel): response: str | None = None + search_query: str | None = None sources: list[WebSource] | None = None @app.post("/messages", status_code=HTTPStatus.CREATED) async def post_messages( search_terms: SearchTerms, username: str = Query(), chat_id: str = Query(), ) -> SearchResult: # 1. Fetch chat history chat_history = await get_messages_query( gel_client, username=username, chat_id=chat_id ) # 2. Add incoming message to Gel _ = await add_message_query( gel_client, username=username, message_role="user", message_body=search_terms.query, sources=[], chat_id=chat_id, ) # 3. Generate a query and perform googling - search_query = search_terms.query + search_query = await generate_search_query(search_terms.query, chat_history) + web_sources = await search_web(search_query) # 5. Generate answer search_result = await generate_answer( search_terms.query, chat_history, web_sources, ) + search_result.search_query = search_query # add search query to the output + # to see what the bot is searching for # 6. Add LLM response to Gel _ = await add_message_query( gel_client, username=username, message_role="assistant", message_body=search_result.response, sources=[s.url for s in search_result.sources], chat_id=chat_id, ) # 7. Send result back to the client return search_result .. edb:split-section:: Done! We've now fully integrated the chat history into our app and enabled natural language conversations. 
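If you're curious what the LLM actually sees during the query-inference step, here is a minimal standalone sketch of the prompt assembly. It assumes a hypothetical ``Message`` dataclass standing in for the generated ``GetMessagesResult`` type, and the helper name ``build_query_prompt`` and the example URL are made up for illustration. It only mirrors the string formatting from ``generate_search_query``, so no database or LLM is needed to run it.

```python
from dataclasses import dataclass, field


@dataclass
class Message:
    """Hypothetical stand-in for the generated GetMessagesResult type."""

    role: str
    body: str
    sources: list[str] = field(default_factory=list)


def build_query_prompt(query: str, message_history: list[Message]) -> str:
    """Assemble the prompt the same way generate_search_query does."""
    # Each stored message becomes one line; messages are separated by "---"
    formatted_history = "\n---\n".join(
        f"{message.role}: {message.body} (sources: {message.sources})"
        for message in message_history
    )
    return f"Chat history: {formatted_history}\n\nUser message: {query} \n\n"


# Made-up history, for illustration only
history = [
    Message("user", "what are people saying about gel"),
    Message("assistant", "Here is a summary...", ["https://example.com/thread"]),
]

print(build_query_prompt("do they like it or not", history))
```

Printing the prompt like this is a cheap way to check that a followup such as "do they like it or not" reaches the LLM together with enough context to be turned into a standalone keyword.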
As before, let's quickly test out the improvements before moving on: .. code-block:: bash $ curl -X 'POST' \ 'http://localhost:8000/messages?username=alice&chat_id=d4eed420-e903-11ef-b8a7-8718abdafbe1' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ "query": "what are people saying about gel" }' $ curl -X 'POST' \ 'http://localhost:8000/messages?username=alice&chat_id=d4eed420-e903-11ef-b8a7-8718abdafbe1' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ "query": "do they like it or not" }' 6. Use Gel's advanced features to create a RAG ============================================== At this point we have a decent search bot that can refine a search query over multiple turns of a conversation. It's time to add the final touch: we can make the bot remember previous similar interactions with the user using retrieval-augmented generation (RAG). To achieve this we need to implement similarity search across message history: we're going to create a vector embedding for every message in the database using a neural network. Every time we generate a Google search query, we're also going to use it to search for similar messages in the user's message history, and inject the corresponding chat into the prompt. That way the search bot will be able to quickly "remember" similar interactions with the user and use them to understand what they are looking for. Gel enables us to implement such a system with only minor modifications to the schema. .. edb:split-section:: We begin by enabling the ``ai`` extension by adding the following line at the top of :dotgel:`dbschema/default`: .. code-block:: sdl-diff :caption: dbschema/default.esdl + using extension ai; .. edb:split-section:: ... and do the migration: .. code-block:: bash $ gel migration create $ gel migrate .. edb:split-section:: Next, we need to configure the API key in Gel for whatever embedding provider we're going to be using. 
As per documentation, let's open up the CLI by typing ``gel`` and run the following command (assuming we're using OpenAI): .. code-block:: edgeql-repl searchbot:main> configure current database insert ext::ai::OpenAIProviderConfig { secret := 'sk-....', }; OK: CONFIGURE DATABASE .. edb:split-section:: In order to get Gel to automatically keep track of creating and updating message embeddings, all we need to do is create a deferred index like this. Don't forget to run a migration one more time! .. code-block:: sdl-diff type Message { role: str; body: str; timestamp: datetime { default := datetime_current(); } multi sources: str; + deferred index ext::ai::index(embedding_model := 'text-embedding-3-small') + on (.body); } .. edb:split-section:: And we're done! Gel is going to cook in the background for a while and generate embedding vectors for our queries. To make sure nothing broke we can follow Gel's AI documentation and take a look at instance logs: .. code-block:: bash $ gel instance logs -I searchbot | grep api.openai.com INFO 50121 searchbot 2025-01-30T14:39:53.364 httpx: HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK" .. edb:split-section:: It's time to create the second half of the similarity search - the search query. The query needs to fetch ``k`` chats in which there're messages that are most similar to our current message. This can be a little difficult to visualize in your head, so here's the query itself: .. code-block:: edgeql :caption: app/queries/search_chats.edgeql with user := (select User filter .name = $username), chats := ( select Chat filter .$current_chat_id ) select chats { distance := min( ext::ai::search( .messages, >$embedding, ).distance, ), messages: { role, body, sources } } order by .distance limit $limit; .. edb:split-section:: .. note:: Before we can integrate this query into our Python app, we also need to add a new dependency for the Python binding: ``httpx-sse``. 
It enables streaming outputs, which we're not going to use right now, but we won't be able to create the AI client without it. Let's place the query in ``app/queries/search_chats.edgeql``, run the codegen and modify our ``post_messages`` endpoint to keep track of those similar chats. .. code-block:: python-diff :caption: app/main.py + from edgedb.ai import create_async_ai, AsyncEdgeDBAI + from .queries.search_chats_async_edgeql import ( + search_chats as search_chats_query, + ) class SearchResult(BaseModel): response: str | None = None search_query: str | None = None sources: list[WebSource] | None = None + similar_chats: list[str] | None = None @app.post("/messages", status_code=HTTPStatus.CREATED) async def post_messages( search_terms: SearchTerms, username: str = Query(), chat_id: str = Query(), ) -> SearchResult: # 1. Fetch chat history chat_history = await get_messages_query( gel_client, username=username, chat_id=chat_id ) # 2. Add incoming message to Gel _ = await add_message_query( gel_client, username=username, message_role="user", message_body=search_terms.query, sources=[], chat_id=chat_id, ) # 3. Generate a query and perform googling search_query = await generate_search_query(search_terms.query, chat_history) web_sources = await search_web(search_query) + # 4. Fetch similar chats + db_ai: AsyncEdgeDBAI = await create_async_ai(gel_client, model="gpt-4o-mini") + embedding = await db_ai.generate_embeddings( + search_query, model="text-embedding-3-small" + ) + similar_chats = await search_chats_query( + gel_client, + username=username, + current_chat_id=chat_id, + embedding=embedding, + limit=1, + ) # 5. Generate answer search_result = await generate_answer( search_terms.query, chat_history, web_sources, + similar_chats, ) search_result.search_query = search_query # add search query to the output # to see what the bot is searching for # 6. 
Add LLM response to Gel _ = await add_message_query( gel_client, username=username, message_role="assistant", message_body=search_result.response, sources=[s.url for s in search_result.sources], chat_id=chat_id, ) # 7. Send result back to the client return search_result .. edb:split-section:: Finally, the answer generator needs to get updated one more time, since we need to inject the additional messages into the prompt. .. code-block:: python-diff :caption: app/main.py async def generate_answer( query: str, chat_history: list[GetMessagesResult], web_sources: list[WebSource], + similar_chats: list[list[GetMessagesResult]], ) -> SearchResult: system_prompt = ( "You are a helpful assistant that answers user's questions" + " by finding relevant information in HackerNews threads." + " When answering the question, describe conversations that people have around the subject, provided to you as a context, or say i don't know if they are completely irrelevant." + + " You can reference previous conversation with the user that" + + " are provided to you, if they are relevant, by explicitly referring" + + " to them by saying as we discussed in the past." 
) prompt = f"User search query: {query}\n\nWeb search results:\n" for i, source in enumerate(web_sources): prompt += f"Result {i} (URL: {source.url}):\n" prompt += f"{source.text}\n\n" + prompt += "Similar chats with the same user:\n" + formatted_chats = [] + for i, chat in enumerate(similar_chats): + formatted_chat = f"Chat {i}: \n" + for message in chat.messages: + formatted_chat += f"{message.role}: {message.body}\n" + formatted_chats.append(formatted_chat) + prompt += "\n".join(formatted_chats) messages = [ {"role": message.role, "content": message.body} for message in chat_history ] messages.append({"role": "user", "content": prompt}) llm_response = get_llm_completion( system_prompt=system_prompt, messages=messages, ) search_result = SearchResult( response=llm_response, sources=web_sources, + similar_chats=formatted_chats, ) return search_result .. edb:split-section:: And one last time, let's check to make sure everything works: .. code-block:: bash $ curl -X 'POST' \ 'http://localhost:8000/messages?username=alice&chat_id=d4eed420-e903-11ef-b8a7-8718abdafbe1' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ "query": "remember that cool db i was talking to you about?" }' Keep going! =========== This tutorial is over, but this app surely could use way more features! Basic functionality like deleting messages, a user interface or real web search, sure. But also authentication or access policies -- Gel will let you set those up in minutes. Thanks! ================================================ FILE: docs/intro/tutorials/gel_drizzle_booknotes.rst ================================================ .. 
_ref_guide_gel_drizzle_booknotes: ==================================================== Build a Book Notes App with Drizzle ==================================================== :edb-alt-title: Building a book notes app using Gel, Drizzle ORM, and Next.js In this tutorial we're going to walk you through building a Book Notes application that lets you keep track of books you've read along with your personal notes. We'll be using Gel as the database, Drizzle as the ORM layer, and Next.js as our full-stack framework. Gel is a data layer designed to supercharge PostgreSQL with a graph-like object model, access control, Auth, and many other features. It provides a unified schema and tooling experience across multiple languages, making it ideal for projects with diverse tech stacks. With Gel, you get access to EdgeQL, which eliminates n+1 query problems, supports automatic embeddings, and offers a seamless developer experience. Drizzle, on the other hand, is a TypeScript ORM that offers type safety and a great developer experience. By combining Gel with Drizzle, you can leverage Gel's powerful features while using Drizzle as a familiar ORM layer to interact with your database. This approach is perfect for developers who want to start learning Gel or prefer using Drizzle for their projects. Next.js is a React framework for building production-ready web applications with features like server components, built-in routing, and API routes. By the end of this tutorial, you will see how these technologies work together to create a modern, full-stack web application with a great developer experience. .. note:: The complete source code for this tutorial is available in our `Gel Examples repository `_. We will start by creating a Gel schema, setting up Drizzle, and then building a Next.js application with API routes and a simple UI to manage your book collection and notes. 1. Initialize the project ========================= .. 
edb:split-section:: Let's start with setting up our project. We'll create a new Next.js application, install the necessary dependencies, and initialize a Gel project. Here's a summary of the setup steps we'll follow: 1. Create a Next.js application 2. Install Gel and related packages 3. Initialize a Gel project 4. Update schema and apply migrations 5. Install and set up Drizzle 6. Pull schema into Drizzle 7. Configure hooks in gel.toml .. edb:split-section:: First, let's create a new Next.js application. You can use the ``create-next-app`` command to set up a new project. When prompted, choose TypeScript, ESLint, Tailwind CSS, and the App Router. You can skip the default import alias configuration. This will create a new Next.js application with the necessary configuration files and dependencies. .. note:: Make sure you have Node.js and npm installed on your machine. .. code-block:: bash # Step 1: Create a Next.js application $ npx create-next-app@latest book-notes-app # When prompted, choose: # ✔ Would you like to use TypeScript? Yes # ✔ Would you like to use ESLint? Yes # ✔ Would you like to use Tailwind CSS? Yes # ✔ Would you like to use `src/` directory? Yes # ✔ Would you like to use App Router? Yes # ✔ Would you like to use Turbopack for `next dev`? No # ✔ Would you like to customize the default import alias (@/*)? No $ cd book-notes-app .. edb:split-section:: Next, let's install the Gel library. We'll need ``gel`` for database access. .. code-block:: bash $ npm i gel .. edb:split-section:: Now, we'll initialize a Gel project. This will create the necessary configuration files and set up a local Gel instance. .. code-block:: bash $ npx gel project init 2. Define the Gel schema ======================== Now that we have our project environment set up, let's define our database schema. For our Book Notes app, we'll create two main types: 1. ``Book`` - to store information about books 2. 
``Note`` - to store notes associated with each book Let's edit the :dotgel:`dbschema/default` file that was created during initialization. .. edb:split-section:: Our schema defines two types: - ``Book`` with properties like title, author, publication year, genre, and read date. - ``Note`` with text content and a timestamp, linked to a specific book. The relationship is defined such that a book can have multiple notes, and each note belongs to exactly one book. We're using a computed link ``notes`` to allow easy access to a book's notes. .. code-block:: sdl :caption: :dotgel:`dbschema/default` module default { type Book { required title: str; author: str; year: int16; genre: str; read_date: datetime; # Relationship to notes multi notes := . [run ``drizzle-kit pull``]. .. code-block:: typescript :caption: drizzle/schema.ts :class: collapsible import { gelTable, uniqueIndex, uuid, text, timestamptz, smallint, foreignKey } from "drizzle-orm/gel-core" import { sql } from "drizzle-orm" export const book = gelTable("Book", { id: uuid().default(sql`uuid_generate_v4()`).primaryKey().notNull(), author: text(), genre: text(), readDate: timestamptz("read_date"), title: text().notNull(), year: smallint(), }, (table) => [ uniqueIndex("5f1d3546-1943-11f0-be08-df1707d45eaa;schemaconstr").using("btree", table.id.asc().nullsLast().op("uuid_ops")), ]); export const note = gelTable("Note", { id: uuid().default(sql`uuid_generate_v4()`).primaryKey().notNull(), bookId: uuid("book_id").notNull(), createdAt: timestamptz("created_at").default(sql`(clock_timestamp())`), text: text().notNull(), }, (table) => [ uniqueIndex("5f1e4652-1943-11f0-a4a0-f1f912666606;schemaconstr").using("btree", table.id.asc().nullsLast().op("uuid_ops")), foreignKey({ columns: [table.bookId], foreignColumns: [book.id], name: "Note_fk_book" }), ]); .. edb:split-section:: Finally, we need to update the hooks in our ``gel.toml`` file to ensure that our Drizzle schema stays in sync with our Gel schema. 
Every time we apply a migration, we want to run the Drizzle pull command to update the TypeScript files. .. code-block:: toml-diff :caption: gel.toml + [hooks] + after_migration_apply = [ + "npx drizzle-kit pull" + ] 4. Creating the database client ================================ .. edb:split-section:: Now that we have our schema set up, let's create a database client that we can use throughout our application. This client will connect to our Gel database using Drizzle. .. code-block:: typescript :caption: src/db/index.ts import { drizzle } from 'drizzle-orm/gel'; import { createClient } from 'gel'; import * as schema from '../../drizzle/schema'; import * as relations from '../../drizzle/relations'; // Initialize Gel client const gelClient = createClient(); // Create Drizzle instance export const db = drizzle({ client: gelClient, schema: { ...schema, ...relations, } }); // Helper types for use in our application export type Book = typeof schema.book.$inferSelect; export type NewBook = typeof schema.book.$inferInsert; export interface BookWithNotes extends Book { notes: Note[]; }; export type Note = typeof schema.note.$inferSelect; export type NewNote = typeof schema.note.$inferInsert; 5. Implementing API Routes =========================== Next, let's implement the API routes for our book notes application. With Next.js, we can create API endpoints in the ``app/api`` directory to handle HTTP requests. .. edb:split-section:: We'll start by creating a route for managing all books. This will handle fetching all books and adding new books. The ``GET`` method will return a list of all books, while the ``POST`` method will allow us to add a new book. We'll also include error handling for both methods. In both, we'll use Drizzle ORM to interact with the database. .. 
code-block:: typescript :caption: app/api/books/route.ts import { NextResponse } from 'next/server'; import { db } from '@/src/db'; import { book } from '@/drizzle/schema'; export async function GET() { try { const allBooks = await db.query.book.findMany({ with: { notes: true, }, }); return NextResponse.json(allBooks); } catch (error) { console.error('Error fetching books:', error); return NextResponse.json( { error: 'Failed to fetch books' }, { status: 500 } ); } } export async function POST(request: Request) { try { const body = await request.json(); const result = await db.insert(book).values({ title: body.title, author: body.author, year: body.year, genre: body.genre, readDate: body.read_date ? new Date(body.read_date) : null, }).returning(); return NextResponse.json(result[0], { status: 201 }); } catch (error) { console.error('Error adding book:', error); return NextResponse.json( { error: 'Failed to add book' }, { status: 500 } ); } } .. edb:split-section:: Next, let's create a route for managing a specific book by its ID. This route handles fetching, updating, and deleting a single book. - The ``GET`` method fetches a specific book by its ID. - The ``PUT`` method updates the book details based on the request body. - The ``DELETE`` method deletes the book and all its associated notes. We'll also include error handling for each method. ..
code-block:: typescript :caption: src/app/api/books/[id]/route.ts import { NextResponse } from 'next/server'; import { db } from '@/src/db'; import { book, note } from '@/drizzle/schema'; import { eq } from 'drizzle-orm'; export async function GET( request: Request, { params }: { params: Promise<{ id: string }> } ) { const { id } = await params; try { const requestedBook = await db.query.book.findFirst({ where: eq(book.id, id), with: { notes: true, }, }); if (!requestedBook) { return NextResponse.json( { error: 'Book not found' }, { status: 404 } ); } return NextResponse.json(requestedBook); } catch (error) { console.error('Error fetching book:', error); return NextResponse.json( { error: 'Failed to fetch book' }, { status: 500 } ); } } export async function PUT( request: Request, { params }: { params: Promise<{ id: string }> } ) { const { id } = await params; try { const body = await request.json(); const result = await db.update(book) .set({ title: body.title, author: body.author, year: body.year, genre: body.genre, readDate: body.read_date ? new Date(body.read_date) : null, }) .where(eq(book.id, id)) .returning(); if (result.length === 0) { return NextResponse.json( { error: 'Book not found' }, { status: 404 } ); } return NextResponse.json(result[0]); } catch (error) { console.error('Error updating book:', error); return NextResponse.json( { error: 'Failed to update book' }, { status: 500 } ); } } export async function DELETE( request: Request, { params }: { params: Promise<{ id: string }> } ) { const { id } = await params; try { // First delete associated notes await db.delete(note).where(eq(note.bookId, id)); // Then delete the book const result = await db.delete(book) .where(eq(book.id, id)) .returning(); if (result.length === 0) { return NextResponse.json( { error: 'Book not found' }, { status: 404 } ); } return NextResponse.json({ success: true }); } catch (error) { console.error('Error deleting book:', error); return NextResponse.json( { error: 'Failed to delete book' }, {
status: 500 } ); } } .. edb:split-section:: Now, let's create a route for adding notes to a book. This endpoint will handle the creation of new notes for a specific book. The ``POST`` method will accept a request body with the note text and the book ID. .. code-block:: typescript :caption: src/app/api/books/[id]/notes/route.ts import { NextResponse } from 'next/server'; import { db } from '@/src/db'; import { note } from '@/drizzle/schema'; export async function POST( request: Request, { params }: { params: Promise<{ id: string }> } ) { const { id } = await params; try { const body = await request.json(); const result = await db.insert(note).values({ text: body.text, bookId: id, }).returning(); return NextResponse.json(result[0], { status: 201 }); } catch (error) { console.error('Error adding note:', error); return NextResponse.json( { error: 'Failed to add note' }, { status: 500 } ); } } .. edb:split-section:: Finally, let's create a route for updating and deleting individual notes. This will handle the ``PUT`` and ``DELETE`` methods for a specific note. The ``PUT`` method will update the note text, while the ``DELETE`` method will delete the note. .. 
code-block:: typescript :caption: src/app/api/notes/[id]/route.ts import { NextResponse } from 'next/server'; import { db } from '@/src/db'; import { note } from '@/drizzle/schema'; import { eq } from 'drizzle-orm'; export async function PUT( request: Request, { params }: { params: Promise<{ id: string }> } ) { const { id } = await params; try { const body = await request.json(); const result = await db.update(note) .set({ text: body.text, }) .where(eq(note.id, id)) .returning(); if (result.length === 0) { return NextResponse.json( { error: 'Note not found' }, { status: 404 } ); } return NextResponse.json(result[0]); } catch (error) { console.error('Error updating note:', error); return NextResponse.json( { error: 'Failed to update note' }, { status: 500 } ); } } export async function DELETE( request: Request, { params }: { params: Promise<{ id: string }> } ) { const { id } = await params; try { const result = await db.delete(note) .where(eq(note.id, id)) .returning(); if (result.length === 0) { return NextResponse.json( { error: 'Note not found' }, { status: 404 } ); } return NextResponse.json({ success: true }); } catch (error) { console.error('Error deleting note:', error); return NextResponse.json( { error: 'Failed to delete note' }, { status: 500 } ); } } .. edb:split-section:: We can test our API routes using a tool like Postman or cURL. Let's start the development server and test the routes. .. code-block:: bash $ npm run dev .. edb:split-section:: You can now access the API routes at ``http://localhost:3000/api`` (or the port specified in your environment). For example, the books route is available at ``http://localhost:3000/api/books``. To fetch all books, you can use the following cURL command: .. code-block:: bash $ curl -X GET http://localhost:3000/api/books .. edb:split-section:: To add a new book, you can use the following cURL command: ..
code-block:: bash $ curl -X POST http://localhost:3000/api/books \ -H "Content-Type: application/json" \ -d '{"title": "The Great Gatsby", "author": "F. Scott Fitzgerald", "year": 1925, "genre": "Fiction", "read_date": "2023-10-01"}' .. edb:split-section:: Or to create a new note for a book, you can use the following cURL command (replace ```` with the actual book ID): .. code-block:: bash $ curl -X POST http://localhost:3000/api/books//notes \ -H "Content-Type: application/json" \ -d '{"text": "This is a great book!"}' 6. Building the UI ================== Now that we have our API routes in place, we can build a user interface for our book notes application. We'll use Tailwind CSS for styling, which was included when we created our Next.js application. We won't go into extensive UI details, but here's a basic implementation for the home page that lists all books. .. edb:split-section:: We'll start by creating a home page that fetches and displays all books from our API. This page will also include a link to add a new book. We'll use the ``useEffect`` hook to fetch the books when the component mounts. We'll also handle loading states and error handling. The home page will display a list of books with their titles, authors, publication years, genres, and the number of notes associated with each book. .. code-block:: typescript :caption: app/page.tsx :class: collapsible 'use client'; import { useState, useEffect } from 'react'; import Link from 'next/link'; import { BookWithNotes } from '@/db'; export default function Home() { const [books, setBooks] = useState([]); const [loading, setLoading] = useState(true); useEffect(() => { async function fetchBooks() { try { const response = await fetch('/api/books'); if (!response.ok) throw new Error('Failed to fetch books'); const data = await response.json(); setBooks(data); } catch (error) { console.error('Error:', error); } finally { setLoading(false); } } fetchBooks(); }, []); if (loading) { return (

Loading...

); } return (

My Book Notes

Add New Book
{books.length === 0 ? (

No books found. Add your first book!

) : ( books.map((book) => (

{book.title}

{book.author && (

by {book.author}

)} {book.year && (

Published: {book.year}

)} {book.genre && (

Genre: {book.genre}

)}

{book.notes?.length || 0} notes

)) )}
); } .. edb:split-section:: Next, let's create a form component for adding and editing books. This will be used in both the "Add Book" page and the "Edit Book" page. .. code-block:: typescript :caption: src/components/BookForm.tsx :class: collapsible 'use client'; import { useState, FormEvent } from 'react'; import { useRouter } from 'next/navigation'; import { Book } from '../db'; interface BookFormProps { book?: Book; isEditing?: boolean; } export default function BookForm({ book, isEditing = false }: BookFormProps) { const router = useRouter(); const [title, setTitle] = useState(book?.title || ''); const [author, setAuthor] = useState(book?.author || ''); const [year, setYear] = useState(book?.year?.toString() || ''); const [genre, setGenre] = useState(book?.genre || ''); const [readDate, setReadDate] = useState( book?.readDate ? new Date(book.readDate).toISOString().split('T')[0] : '' ); const handleSubmit = async (e: FormEvent) => { e.preventDefault(); const bookData = { title, author, year: year ? parseInt(year) : undefined, genre, read_date: readDate || undefined, }; try { if (isEditing && book) { // Update existing book await fetch(`/api/books/${book.id}`, { method: 'PUT', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify(bookData), }); } else { // Create new book await fetch('/api/books', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify(bookData), }); } router.push('/'); router.refresh(); } catch (error) { console.error('Error saving book:', error); } }; return (
setTitle(e.target.value)} required className="w-full px-4 py-2 bg-gray-800 text-white border border-gray-700 rounded focus:outline-none focus:ring-2 focus:ring-blue-500 transition duration-150" />
setAuthor(e.target.value)} className="w-full px-4 py-2 bg-gray-800 text-white border border-gray-700 rounded focus:outline-none focus:ring-2 focus:ring-blue-500 transition duration-150" />
setYear(e.target.value)} className="w-full px-4 py-2 bg-gray-800 text-white border border-gray-700 rounded focus:outline-none focus:ring-2 focus:ring-blue-500 transition duration-150" />
setGenre(e.target.value)} className="w-full px-4 py-2 bg-gray-800 text-white border border-gray-700 rounded focus:outline-none focus:ring-2 focus:ring-blue-500 transition duration-150" />
setReadDate(e.target.value)} className="w-full px-4 py-2 bg-gray-800 text-white border border-gray-700 rounded focus:outline-none focus:ring-2 focus:ring-blue-500 transition duration-150" />
); } .. edb:split-section:: Now, let's create the "Add Book" page that uses our form component. .. code-block:: typescript :caption: app/books/add/page.tsx 'use client'; import BookForm from "@/src/components/BookForm"; export default function AddBookPage() { return (

Add New Book

); } .. edb:split-section:: Let's also create a page to view book details and manage notes. .. code-block:: typescript :caption: src/app/books/[id]/page.tsx :class: collapsible 'use client'; import { useState, useEffect, FormEvent, use } from 'react'; import { useRouter } from 'next/navigation'; import Link from 'next/link'; import { BookWithNotes } from '@/src/db'; export default function BookDetailPage({ params }: { params: Promise<{ id: string }> }) { const { id } = use(params); const router = useRouter(); const [book, setBook] = useState(null); const [loading, setLoading] = useState(true); const [noteText, setNoteText] = useState(''); useEffect(() => { async function fetchBook() { try { const response = await fetch(`/api/books/${id}`); if (!response.ok) { if (response.status === 404) { router.push('/'); return; } throw new Error('Failed to fetch book'); } const data = await response.json(); setBook(data); } catch (error) { console.error('Error:', error); } finally { setLoading(false); } } fetchBook(); }, [id, router]); const handleAddNote = async (e: FormEvent) => { e.preventDefault(); if (!noteText.trim()) return; try { const response = await fetch(`/api/books/${id}/notes`, { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ text: noteText }), }); if (!response.ok) throw new Error('Failed to add note'); const newNote = await response.json(); setBook(prev => prev ? { ...prev, notes: [...prev.notes, newNote] } : null); setNoteText(''); } catch (error) { console.error('Error adding note:', error); } }; const handleDeleteNote = async (noteId: string) => { try { const response = await fetch(`/api/notes/${noteId}`, { method: 'DELETE', }); if (!response.ok) throw new Error('Failed to delete note'); setBook(prev => prev ? 
{ ...prev, notes: prev.notes.filter(note => note.id !== noteId) } : null); } catch (error) { console.error('Error deleting note:', error); } }; const handleDeleteBook = async () => { if (!confirm('Are you sure you want to delete this book and all its notes?')) { return; } try { const response = await fetch(`/api/books/${id}`, { method: 'DELETE', }); if (!response.ok) throw new Error('Failed to delete book'); router.push('/'); } catch (error) { console.error('Error deleting book:', error); } }; if (loading) { return (

Loading...

); } if (!book) { return (

Book not found.

Back to All Books
); } return (
← Back

{book.title}

Edit
{book.author && (

by {book.author}

)}
{book.year &&

Published: {book.year}

} {book.genre &&

Genre: {book.genre}

} {book.readDate && (

Read on: {new Date(book.readDate).toLocaleDateString()}

)}

Notes

setNoteText(e.target.value)} placeholder="Add a new note..." className="flex-grow px-4 py-2 bg-gray-700 text-white placeholder-gray-400 border border-gray-600 rounded-l focus:outline-none focus:ring-2 focus:ring-blue-500 transition duration-150" />
{book.notes.length === 0 ? (

No notes yet. Add your first note above.

) : (
    {book.notes.map((note) => (
  • {note.text}

    {note.createdAt && (

    {new Date(note.createdAt).toLocaleString()}

    )}
  • ))}
)}
); } .. edb:split-section:: For a complete application, you would also need to implement an edit page for books. Here's a simplified example: .. code-block:: typescript :caption: src/app/books/[id]/edit/page.tsx :class: collapsible 'use client'; import { useState, useEffect, use } from 'react'; import { useRouter } from 'next/navigation'; import { Book } from '@/src/db'; import BookForm from '@/src/components/BookForm'; export default function EditBookPage({ params }: { params: Promise<{ id: string }> }) { const router = useRouter(); const { id } = use(params); const [book, setBook] = useState(null); const [loading, setLoading] = useState(true); useEffect(() => { async function fetchBook() { try { const response = await fetch(`/api/books/${id}`); if (!response.ok) { if (response.status === 404) { router.push('/'); return; } throw new Error('Failed to fetch book'); } const data = await response.json(); setBook(data); } catch (error) { console.error('Error:', error); } finally { setLoading(false); } } fetchBook(); }, [id, router]); if (loading) { return (

Loading...

); } if (!book) { return (

Book not found.

); } return (

Edit Book

); } .. edb:split-section:: These UI components provide a basic but functional user interface for our Book Notes application. Tailwind CSS helps us create a clean and responsive design with minimal effort. Since we're focusing on the Gel and Drizzle integration, we won't detail every UI component, but the pattern is consistent throughout the application: - We use React hooks for state management (useState, useEffect) - We call our API endpoints to fetch and modify data - We use Tailwind CSS classes for styling the components - We implement client-side navigation with Next.js's useRouter 7. Testing the application =========================== .. edb:split-section:: Now that we have built our API routes and basic UI, let's test our application. Start the development server: .. code-block:: bash $ npm run dev .. edb:split-section:: Navigate to http://localhost:3000 in your browser, and you should see your Book Notes application. Try performing these operations to ensure everything is working correctly: 1. Adding a new book 2. Viewing book details 3. Adding notes to a book 4. Editing book information 5. Deleting notes 6. Deleting a book (which should also delete its notes) If you encounter any issues, check your browser's developer console and the terminal running your Next.js server for error messages. 8. Next steps ============== Congratulations! You've built a Book Notes application using Gel, Drizzle, and Next.js. This tutorial demonstrated how these technologies can work together to create a full-stack application. Here are some ideas for extending the application: 1. **Add authentication**: Implement user authentication to allow multiple users to have their own book collections. 2. **Advanced filtering**: Add the ability to filter books by genre, author, or reading status. 3. **Book statistics**: Create a dashboard with statistics about your reading habits. 4. **Reading goals**: Implement a feature to set and track reading goals. 5. 
**Book recommendations**: Add a feature to recommend books based on what you've already read. 6. **Import/Export**: Allow users to import or export their book data. 7. **Search functionality**: Implement full-text search across books and notes. To further explore the capabilities of Gel and Drizzle, you can check out these resources: - `Gel Documentation `_ - `Drizzle ORM Documentation `_ - `Next.js Documentation `_ Remember, you can find the complete source code for this tutorial in our `Gel Examples repository `_. Happy coding! ================================================ FILE: docs/intro/tutorials/index.rst ================================================ ========= Tutorials ========= .. toctree:: :maxdepth: 2 ai_fastapi_searchbot gel_drizzle_booknotes ================================================ FILE: docs/redirects ================================================ /guides/cloud -> /cloud /cli/edgedb -> docs/cli/gel /cli/edgedb_analyze -> docs/cli/gel_analyze /cli/edgedb_branch/edgedb_branch_create -> docs/cli/gel_branch/gel_branch_create /cli/edgedb_branch/edgedb_branch_drop -> docs/cli/gel_branch/gel_branch_drop /cli/edgedb_branch/edgedb_branch_list -> docs/cli/gel_branch/gel_branch_list /cli/edgedb_branch/edgedb_branch_merge -> docs/cli/gel_branch/gel_branch_merge /cli/edgedb_branch/edgedb_branch_rebase -> docs/cli/gel_branch/gel_branch_rebase /cli/edgedb_branch/edgedb_branch_rename -> docs/cli/gel_branch/gel_branch_rename /cli/edgedb_branch/edgedb_branch_switch -> docs/cli/gel_branch/gel_branch_switch /cli/edgedb_branch/edgedb_branch_wipe -> docs/cli/gel_branch/gel_branch_wipe /cli/edgedb_branch/index -> docs/cli/gel_branch/index /cli/edgedb_cli_upgrade -> docs/cli/gel_cli_upgrade /cli/edgedb_cloud/edgedb_cloud_login -> docs/cli/gel_cloud/gel_cloud_login /cli/edgedb_cloud/edgedb_cloud_logout -> docs/cli/gel_cloud/gel_cloud_logout /cli/edgedb_cloud/edgedb_cloud_secretkey/edgedb_cloud_secretkey_create -> 
docs/cli/gel_cloud/gel_cloud_secretkey/edgedb_cloud_secretkey_create /cli/edgedb_cloud/edgedb_cloud_secretkey/edgedb_cloud_secretkey_list -> docs/cli/gel_cloud/gel_cloud_secretkey/edgedb_cloud_secretkey_list /cli/edgedb_cloud/edgedb_cloud_secretkey/edgedb_cloud_secretkey_revoke -> docs/cli/gel_cloud/gel_cloud_secretkey/edgedb_cloud_secretkey_revoke /cli/edgedb_cloud/edgedb_cloud_secretkey/index -> docs/cli/gel_cloud/gel_cloud_secretkey/index /cli/edgedb_cloud/index -> docs/cli/gel_cloud/index /cli/edgedb_configure -> docs/cli/gel_configure /cli/edgedb_connopts -> docs/cli/gel_connopts /cli/edgedb_database/edgedb_database_create -> docs/cli/gel_database/gel_database_create /cli/edgedb_database/edgedb_database_drop -> docs/cli/gel_database/gel_database_drop /cli/edgedb_database/edgedb_database_wipe -> docs/cli/gel_database/gel_database_wipe /cli/edgedb_database/index -> docs/cli/gel_database/index /cli/edgedb_describe/edgedb_describe_object -> docs/cli/gel_describe/gel_describe_object /cli/edgedb_describe/edgedb_describe_schema -> docs/cli/gel_describe/gel_describe_schema /cli/edgedb_describe/index -> docs/cli/gel_describe/index /cli/edgedb_dump -> docs/cli/gel_dump /cli/edgedb_info -> docs/cli/gel_info /cli/edgedb_instance/edgedb_instance_create -> docs/cli/gel_instance/gel_instance_create /cli/edgedb_instance/edgedb_instance_credentials -> docs/cli/gel_instance/gel_instance_credentials /cli/edgedb_instance/edgedb_instance_destroy -> docs/cli/gel_instance/gel_instance_destroy /cli/edgedb_instance/edgedb_instance_link -> docs/cli/gel_instance/gel_instance_link /cli/edgedb_instance/edgedb_instance_list -> docs/cli/gel_instance/gel_instance_list /cli/edgedb_instance/edgedb_instance_logs -> docs/cli/gel_instance/gel_instance_logs /cli/edgedb_instance/edgedb_instance_reset_password -> docs/cli/gel_instance/gel_instance_reset_password /cli/edgedb_instance/edgedb_instance_restart -> docs/cli/gel_instance/gel_instance_restart /cli/edgedb_instance/edgedb_instance_revert -> 
docs/cli/gel_instance/gel_instance_revert /cli/edgedb_instance/edgedb_instance_start -> docs/cli/gel_instance/gel_instance_start /cli/edgedb_instance/edgedb_instance_status -> docs/cli/gel_instance/gel_instance_status /cli/edgedb_instance/edgedb_instance_stop -> docs/cli/gel_instance/gel_instance_stop /cli/edgedb_instance/edgedb_instance_unlink -> docs/cli/gel_instance/gel_instance_unlink /cli/edgedb_instance/edgedb_instance_upgrade -> docs/cli/gel_instance/gel_instance_upgrade /cli/edgedb_instance/index -> docs/cli/gel_instance/index /cli/edgedb_list -> docs/cli/gel_list /cli/edgedb_migrate -> docs/cli/gel_migrate /cli/edgedb_migration/edgedb_migration_apply -> docs/cli/gel_migration/gel_migration_apply /cli/edgedb_migration/edgedb_migration_create -> docs/cli/gel_migration/gel_migration_create /cli/edgedb_migration/edgedb_migration_edit -> docs/cli/gel_migration/gel_migration_edit /cli/edgedb_migration/edgedb_migration_extract -> docs/cli/gel_migration/gel_migration_extract /cli/edgedb_migration/edgedb_migration_log -> docs/cli/gel_migration/gel_migration_log /cli/edgedb_migration/edgedb_migration_status -> docs/cli/gel_migration/gel_migration_status /cli/edgedb_migration/edgedb_migration_upgrade_check -> docs/cli/gel_migration/gel_migration_upgrade_check /cli/edgedb_migration/index -> docs/cli/gel_migration/index /cli/edgedb_project/edgedb_project_info -> docs/cli/gel_project/gel_project_info /cli/edgedb_project/edgedb_project_init -> docs/cli/gel_project/gel_project_init /cli/edgedb_project/edgedb_project_unlink -> docs/cli/gel_project/gel_project_unlink /cli/edgedb_project/edgedb_project_upgrade -> docs/cli/gel_project/gel_project_upgrade /cli/edgedb_project/index -> docs/cli/gel_project/index /cli/edgedb_query -> docs/cli/gel_query /cli/edgedb_restore -> docs/cli/gel_restore /cli/edgedb_server/edgedb_server_info -> docs/cli/gel_server/gel_server_info /cli/edgedb_server/edgedb_server_install -> docs/cli/gel_server/gel_server_install 
/cli/edgedb_server/edgedb_server_list_versions -> docs/cli/gel_server/gel_server_list_versions /cli/edgedb_server/edgedb_server_uninstall -> docs/cli/gel_server/gel_server_uninstall /cli/edgedb_server/index -> docs/cli/gel_server/index /cli/edgedb_ui -> docs/cli/gel_ui /cli/edgedb_watch -> docs/cli/gel_watch ================================================ FILE: docs/redirects.js ================================================ // See https://nextjs.org/docs/app/api-reference/config/next-config-js/redirects module.exports = [ { source: "/changelog/:path*", destination: "/resources/changelog/:path*", permanent: false, }, { source: "/guides/cheatsheet/:path*", destination: "/resources/cheatsheets/:path*", permanent: true, }, { source: "/guides/ai/:path*", destination: "/ai/:path*", permanent: true, }, { source: "/changelog/6_x", destination: "/resources/changelog/6_x", permanent: true, }, { source: "/database/:path*", destination: "/reference/:path*", permanent: false, }, { source: "/guides/:path*", destination: "/resources/guides/:path*", permanent: false, }, { source: "/reference/reference/bindings/datetime", destination: "/reference/using/datetime", permanent: false, }, { source: "/reference/reference/bindings", destination: "/reference/using", permanent: false, }, { source: "/reference/clients/go/:path*", destination: "https://pkg.go.dev/github.com/geldata/gel-go", permanent: false, }, { source: "/reference/clients/rust/:path*", destination: "https://docs.rs/gel-tokio", permanent: false, }, { source: "/reference/libraries/dotnet/:path*", destination: "https://github.com/geldata/gel-net", permanent: false, }, { source: "/reference/libraries/elixir/:path*", destination: "https://hexdocs.pm/gel", permanent: false, }, { source: "/reference/libraries/java/:path*", destination: "https://github.com/geldata/gel-java", permanent: false, }, // Use the further redirects to get to the correct point in the consolidated docs { source: "/reference/libraries/js/:path*", 
destination: "/reference/clients/js/:path*", permanent: false, }, { source: "/reference/libraries/python/:path*", destination: "/reference/clients/python/:path*", permanent: false, }, { source: "/reference/libraries/:path*", destination: "/reference/using", permanent: false, }, { source: "/reference/clients/js/delete#delete", destination: "/reference/using/js/querybuilder", permanent: false, }, { source: "/reference/clients/js/driver", destination: "/reference/using/js", permanent: false, }, { source: "/reference/clients/js/for", destination: "/reference/using/js/querybuilder#for", permanent: false, }, { source: "/reference/clients/js/funcops", destination: "/reference/using/js/querybuilder#functions-and-operators", permanent: false, }, { source: "/reference/clients/js/group", destination: "/reference/using/js/querybuilder#group", permanent: false, }, { source: "/reference/clients/js/insert", destination: "/reference/using/js/querybuilder#insert", permanent: false, }, { source: "/reference/clients/js/literals", destination: "/reference/using/js/querybuilder#types-and-literals", permanent: false, }, { source: "/reference/clients/js/objects", destination: "/reference/using/js/querybuilder#objects-and-paths", permanent: false, }, { source: "/reference/clients/js/parameters", destination: "/reference/using/js/querybuilder#parameters", permanent: false, }, { source: "/reference/clients/js/select", destination: "/reference/using/js/querybuilder#select", permanent: false, }, { source: "/reference/clients/js/types", destination: "/reference/using/js/querybuilder#types-and-literals", permanent: false, }, { source: "/reference/clients/js/update", destination: "/reference/using/js/querybuilder#update", permanent: false, }, { source: "/reference/clients/js/with", destination: "/reference/using/js/querybuilder#with-blocks", permanent: false, }, { source: "/reference/clients/js/reference", destination: "/reference/using/js/client#client-reference", permanent: false, }, { source: 
"/reference/reference/connection", destination: "/reference/using/connection", permanent: false, }, { source: "/reference/reference/dsn", destination: "/reference/using/connection#dsn", permanent: false, }, { source: "/reference/clients/python/api/asyncio_client", destination: "/reference/using/python/client#asyncio-client", permanent: false, }, { source: "/reference/clients/python/api/blocking_client", destination: "/reference/using/python/client#blocking-client", permanent: false, }, { source: "/reference/clients/python/installation", destination: "/reference/using/python#installation", permanent: false, }, { source: "/reference/clients/python/usage", destination: "/reference/using/python#basic-usage", permanent: false, }, { source: "/reference/clients/http/health-checks", destination: "/reference/running/http#health-checks", permanent: false, }, { source: "/reference/clients/http/protocol", destination: "/reference/using/http", permanent: false, }, { source: "/reference/clients/:path*", destination: "/reference/using/:path*", permanent: false, }, { source: "/reference/reference/configuration", destination: "/reference/running/configuration", permanent: false, }, { source: "/reference/reference/environment", destination: "/reference/running/configuration#environment-variables", permanent: false, }, { source: "/reference/reference/gel_toml", destination: "/reference/using/projects#gel-toml", permanent: false, }, { source: "/reference/reference/http", destination: "/reference/running/http", permanent: false, }, { source: "/reference/reference/projects", destination: "/reference/using/projects", permanent: false, }, { source: "/reference/reference/protocol/:path*", destination: "/resources/protocol/:path*", permanent: false, }, { source: "/reference/reference/admin/databases", destination: "/reference/datamodel/branches", permanent: false, }, { source: "/reference/reference/admin/:path*", destination: "/reference/running/admin/:path*", permanent: false, }, { source: 
"/reference/reference/backend-ha", destination: "/reference/running/backend-ha", permanent: false, }, { source: "/resources/guides/deployment/:path*", destination: "/reference/running/deployment/:path*", permanent: false, }, { source: "/reference/reference/postgis", destination: "/reference/stdlib/postgis", permanent: true, }, { source: "/reference/cli/:path*", destination: "/reference/using/cli/:path*", permanent: false, }, { source: "/cli/:path*", destination: "/reference/using/cli/:path*", permanent: false, }, ]; ================================================ FILE: docs/reference/ai/extai.rst ================================================ .. _ref_ai_extai_reference: ======= ext::ai ======= This reference documents the |Gel| ``ext::ai`` extension components, configuration options, and database APIs. Enabling the Extension ====================== The AI extension can be enabled using the :ref:`extension ` mechanism: .. code-block:: sdl using extension ai; Configuration ============= The AI extension can be configured using ``configure session`` or ``configure current branch``: .. code-block:: edgeql configure current branch set ext::ai::Config::indexer_naptime := 'PT30S'; Configuration Properties ------------------------ * ``indexer_naptime``: Duration Specifies minimum delay between deferred ``ext::ai::index`` indexer runs. View current configuration: .. code-block:: edgeql select cfg::Config.extensions[is ext::ai::Config]{*}; Reset configuration: .. code-block:: edgeql configure current branch reset ext::ai::Config::indexer_naptime; .. _ref_ai_extai_reference_ui: UI == The AI section of the UI can be accessed via the sidebar after the extension has been enabled in the schema. It provides ways to manage provider configurations and RAG prompts, as well as try out different settings in the playground. Playground tab -------------- Provides an interactive environment for testing and configuring the built-in RAG. .. 
image:: images/ui_playground.png :alt: Screenshot of the Playground tab of the UI depicting an empty message window and three input fields set with default values. :width: 100% Components: * Message window: Displays conversation history between the user and the LLM. * Model: Dropdown menu for selecting the text generation model. * Prompt: Dropdown menu for selecting the RAG prompt template. * Context Query: Input field for entering an EdgeQL expression returning a set of objects with AI indexes. Prompts tab ----------- Provides ways to manage system prompts used in the built-in RAG. .. image:: images/ui_prompts.png :alt: Screenshot of the Prompts tab of the UI depicting an expanded prompt configuration menu. :width: 100% Providers tab ------------- Enables management of API configurations for AI API providers. .. image:: images/ui_providers.png :alt: Screenshot of the Providers tab of the UI depicting an expanded provider configuration menu. :width: 100% .. _ref_ai_extai_reference_index: Index ===== The ``ext::ai::index`` creates a deferred semantic similarity index of an expression on a type. .. code-block:: sdl-diff module default { type Astronomy { content: str; + deferred index ext::ai::index(embedding_model := 'text-embedding-3-small') + on (.content); } }; Parameters: * ``embedding_model``- The name of the model to use for embedding generation as a string. * ``distance_function``- The function to use for determining semantic similarity. Default: ``ext::ai::DistanceFunction.Cosine`` * ``index_type``- The type of index to create. Currently the only option is the default: ``ext::ai::IndexType.HNSW``. * ``index_parameters``- A named tuple of additional index parameters: * ``m``- The maximum number of edges of each node in the graph. Increasing can increase the accuracy of searches at the cost of index size. Default: ``32`` * ``ef_construction``- Dictates the depth and width of the search when building the index. 
Higher values can lead to better connections and more accurate results at the cost of time and resource usage when building the index. Default: ``100`` * ``dimensions``: int64 (Optional) - Embedding dimensions * ``truncate_to_max``: bool (Default: False) Functions ========= .. list-table:: :class: funcoptable * - :eql:func:`ext::ai::to_context` - :eql:func-desc:`ext::ai::to_context` * - :eql:func:`ext::ai::search` - :eql:func-desc:`ext::ai::search` ------------ .. eql:function:: ext::ai::to_context(object: anyobject) -> str Returns the indexed expression value for an object with an ``ext::ai::index``. **Example**: Schema: .. code-block:: sdl module default { type Astronomy { topic: str; content: str; deferred index ext::ai::index(embedding_model := 'text-embedding-3-small') on (.topic ++ ' ' ++ .content); } }; Data: .. code-block:: edgeql-repl db> insert Astronomy { ... topic := 'Mars', ... content := 'Skies on Mars are red.' ... } db> insert Astronomy { ... topic := 'Earth', ... content := 'Skies on Earth are blue.' ... } Results of calling ``to_context``: .. code-block:: edgeql-repl db> select ext::ai::to_context(Astronomy); {'Mars Skies on Mars are red.', 'Earth Skies on Earth are blue.'} ------------ .. eql:function:: ext::ai::search( \ object: anyobject, \ query: array \ ) -> optional tuple ext::ai::search( \ object: anyobject, \ query: str \ ) -> optional tuple Searches objects using their :ref:`ai::index `. Returns tuples of (object, distance). .. versionadded:: 7.0 If the ``query`` is a ``str``, the ai extension will make an embedding request to the provider and use the result to compute distances. To prevent unwanted provider calls, this functionality may only be used by roles with the :eql:permission:`ext::ai::perm::provider_call` permission. .. code-block:: edgeql-repl db> with query := >$query ... 
select ext::ai::search(Knowledge, query); { ( object := default::Knowledge {id: 9af0d0e8-0880-11ef-9b6b-4335855251c4}, distance := 0.20410746335983276 ), ( object := default::Knowledge {id: eeacf638-07f6-11ef-b9e9-57078acfce39}, distance := 0.7843298847773637 ), ( object := default::Knowledge {id: f70863c6-07f6-11ef-b9e9-3708318e69ee}, distance := 0.8560434728860855 ), } Scalar and Object Types ======================= Provider Configuration Types ---------------------------- .. list-table:: :class: funcoptable * - :eql:type:`ext::ai::ProviderAPIStyle` - Enum defining supported API styles * - :eql:type:`ext::ai::ProviderConfig` - Abstract base configuration for AI providers. Provider configurations are required for AI indexes and RAG functionality. Example provider configuration: .. code-block:: edgeql configure current database insert ext::ai::OpenAIProviderConfig { secret := 'sk-....', }; .. note:: All provider types require the ``secret`` property to be set to a string containing the secret provided by the AI vendor. .. note:: ``ext::ai::CustomProviderConfig`` requires the ``api_style`` property to be set. --------- .. eql:type:: ext::ai::ProviderAPIStyle Enum defining supported API styles: * ``OpenAI`` * ``Anthropic`` --------- .. eql:type:: ext::ai::ProviderConfig Abstract base configuration for AI providers. Properties: * ``name``: str (Required) - Unique provider identifier * ``display_name``: str (Required) - Human-readable name * ``api_url``: str (Required) - Provider API endpoint * ``client_id``: str (Optional) - Provider-supplied client ID * ``secret``: str (Required) - Provider API secret * ``api_style``: ProviderAPIStyle (Required) - Provider's API style Provider-specific types: * ``ext::ai::OpenAIProviderConfig`` * ``ext::ai::MistralProviderConfig`` * ``ext::ai::AnthropicProviderConfig`` * ``ext::ai::CustomProviderConfig`` Each inherits from :eql:type:`ext::ai::ProviderConfig` with provider-specific defaults. Model Types ----------- .. 
list-table:: :class: funcoptable * - :eql:type:`ext::ai::Model` - Abstract base type for AI models. * - :eql:type:`ext::ai::EmbeddingModel` - Abstract type for embedding models. * - :eql:type:`ext::ai::TextGenerationModel` - Abstract type for text generation models. .. _ref_ai_extai_reference_embedding_models: Embedding models ^^^^^^^^^^^^^^^^ OpenAI (`documentation `__) * ``text-embedding-3-small`` * ``text-embedding-3-large`` * ``text-embedding-ada-002`` Mistral (`documentation `__) * ``mistral-embed`` Ollama (`documentation `__) * ``nomic-embed-text`` * ``bge-m3`` .. _ref_ai_extai_reference_text_generation_models: Text generation models ^^^^^^^^^^^^^^^^^^^^^^ OpenAI (`documentation `__) * ``gpt-3.5-turbo`` * ``gpt-4-turbo-preview`` Mistral (`documentation `__) * ``mistral-small-latest`` * ``mistral-medium-latest`` * ``mistral-large-latest`` Anthropic (`documentation `__) * ``claude-3-haiku-20240307`` * ``claude-3-sonnet-20240229`` * ``claude-3-opus-20240229`` Ollama (`documentation `__) * ``llama3.2`` * ``llama3.3`` When using RAG, it is possible to specify a text generation model using a URI that combines the provider name (in lower case) and the model name. - e.g. ``"openai:gpt-5"`` - e.g. ``"anthropic:claude-opus-4-20250514"`` Using this form allows text generation from models which are not explicitly instantiated as a :eql:type:`ext::ai::TextGenerationModel`. --------- .. eql:type:: ext::ai::Model Abstract base type for AI models. Annotations: * ``model_name`` - Model identifier * ``model_provider`` - Provider identifier --------- .. eql:type:: ext::ai::EmbeddingModel Abstract type for embedding models. Annotations: * ``embedding_model_max_input_tokens`` - Maximum tokens per input * ``embedding_model_max_batch_tokens`` - Maximum tokens per batch. Default: ``'8191'``. * ``embedding_model_max_batch_size`` - Maximum inputs per batch. Optional. 
* ``embedding_model_max_output_dimensions`` - Maximum embedding dimensions * ``embedding_model_supports_shortening`` - Input shortening support flag --------- .. eql:type:: ext::ai::TextGenerationModel Abstract type for text generation models. Annotations: * ``text_gen_model_context_window`` - Model's context window size Indexing Types -------------- .. list-table:: :class: funcoptable * - :eql:type:`ext::ai::DistanceFunction` - Enum for similarity metrics. * - :eql:type:`ext::ai::IndexType` - Enum for index implementations. --------- .. eql:type:: ext::ai::DistanceFunction Enum for similarity metrics. * ``Cosine`` * ``InnerProduct`` * ``L2`` --------- .. eql:type:: ext::ai::IndexType Enum for index implementations. * ``HNSW`` Prompt Types ------------ .. list-table:: :class: funcoptable * - :eql:type:`ext::ai::ChatParticipantRole` - Enum for chat roles. * - :eql:type:`ext::ai::ChatPromptMessage` - Type for chat prompt messages. * - :eql:type:`ext::ai::ChatPrompt` - Type for chat prompt configuration. Example custom prompt configuration: .. code-block:: edgeql insert ext::ai::ChatPrompt { name := 'test-prompt', messages := ( insert ext::ai::ChatPromptMessage { participant_role := ext::ai::ChatParticipantRole.System, content := "Your message content" } ) }; --------- .. eql:type:: ext::ai::ChatParticipantRole Enum for chat roles. * ``System`` * ``User`` * ``Assistant`` * ``Tool`` --------- .. eql:type:: ext::ai::ChatPromptMessage Type for chat prompt messages. Properties: * ``participant_role``: ChatParticipantRole (Required) * ``participant_name``: str (Optional) * ``content``: str (Required) --------- .. eql:type:: ext::ai::ChatPrompt Type for chat prompt configuration. Properties: * ``name``: str (Required) * ``messages``: set of ChatPromptMessage (Required) Permissions =========== .. _ref_ai_extai_reference_permissions: .. versionadded:: 7.0 .. 
list-table:: :class: funcoptable * - :eql:permission:`ext::ai::perm::provider_call` * - :eql:permission:`ext::ai::perm::chat_prompt_read` * - :eql:permission:`ext::ai::perm::chat_prompt_write` --------- .. eql:permission:: ext::ai::perm::provider_call Gives permission to make AI provider calls. Required to call :eql:func:`ext::ai::search` using text directly instead of an already generated embedding. --------- .. eql:permission:: ext::ai::perm::chat_prompt_read Gives permission to read chat prompt configuration. --------- .. eql:permission:: ext::ai::perm::chat_prompt_write Gives permission to modify chat prompt configuration. ================================================ FILE: docs/reference/ai/extvectorstore.rst ================================================ :orphan: .. _ref_extvectorstore_reference: ================ ext::vectorstore ================ The ``ext::vectorstore`` extension package provides simplified vectorstore workflows for |Gel|, built on top of the pgvector integration. It includes predefined vector dimensions and a base schema for vector storage records. Enabling the extension ====================== The extension package can be installed using the :gelcmd:`extension` CLI command: .. code-block:: bash $ gel extension install vectorstore It can be enabled using the :ref:`extension ` mechanism: .. code-block:: sdl using extension vectorstore; The Vectorstore extension is designed to be used in combination with the :ref:`Vectorstore Python binding ` or other integrations, rather than on its own. Types ===== Vector Types ------------ The extension provides two pre-defined vector types with different dimensions: - ``ext::vectorstore::vector_1024``: 1024-dimensional vector - ``ext::vectorstore::vector_1536``: 1536-dimensional vector All vector types extend ``ext::pgvector::vector`` with their respective dimensions. Record Types ------------ .. 
eql:type:: ext::vectorstore::BaseRecord Abstract type that defines the basic structure for vector storage records. Properties: * ``collection: str`` (required): Identifies the collection the record belongs to * ``text: str``: Associated text content * ``embedding: ext::pgvector::vector``: The vector embedding * ``external_id: str``: External identifier with unique constraint * ``metadata: json``: Additional metadata in JSON format .. eql:type:: ext::vectorstore::DefaultRecord Extends :eql:type:`ext::vectorstore::BaseRecord` with specific configurations. Properties: * Inherits all properties from :eql:type:`ext::vectorstore::BaseRecord` * Specializes ``embedding`` to use ``vector_1536`` type * Includes an HNSW cosine similarity index on the embedding with: * ``m = 16`` * ``ef_construction = 128`` ================================================ FILE: docs/reference/ai/http.rst ================================================ .. _ref_ai_http_reference: ======== HTTP API ======== .. note:: All |Gel| server HTTP endpoints require :ref:`authentication `, such as `HTTP Basic Authentication `_ with Gel username and password. Embeddings ========== ``POST``: ``https://:/branch//ai/embeddings`` Generates text embeddings using the specified embeddings model. Request headers --------------- * ``Content-Type: application/json`` (required) Request body ------------ * ``inputs`` (array of strings, or single string, required): The text items to use as the basis for embeddings generation. * ``model`` (string, required): The name of the embedding model to use. You may use any of the supported :ref:`embedding models `. * ``dimensions`` (number, optional): The number of dimensions to truncate to. * ``user`` (string, optional): A user identifier for the request. Example request --------------- .. 
code-block:: bash $ curl --user : --json '{\ "inputs": ["What color is the sky on Mars?"],\ "model": "text-embedding-3-small"\ }' http://localhost:10931/branch/main/ai/embeddings Response -------- * **HTTP status**: 200 OK * **Content-Type**: application/json * **Body**: .. code-block:: json { "object": "list", "data": [ { "object": "embedding", "index": 0, "embedding": [-0.009434271, 0.009137661] } ], "model": "text-embedding-3-small", "usage": { "prompt_tokens": 8, "total_tokens": 8 } } .. note:: The ``embedding`` property is shown here with only two values for brevity, but an actual response would contain many more values. Error response -------------- * **HTTP status**: 400 Bad Request * **Content-Type**: application/json * **Body**: .. code-block:: json { "message": "missing or empty required \"model\" value in request", "type": "BadRequestError" } RAG === ``POST``: ``https://:/branch//ai/rag`` Performs retrieval-augmented text generation using the specified model based on the provided text query and the database content selected using similarity search. Request headers --------------- * ``Content-Type: application/json`` (required) Request body ------------ * ``context`` (object, required): Settings that define the context of the query. * ``query`` (string, required): Specifies an expression to determine the relevant objects and index to serve as context for text generation. You may set this to any expression that produces a set of objects, even if it is not a standalone query. * ``variables`` (object, optional): A dictionary of variables for use in the context query. * ``globals`` (object, optional): A dictionary of globals for use in the context query. * ``max_object_count`` (number, optional): Maximum number of objects to retrieve; default is 5. * ``model`` (string, required): The name of the text generation model to use. It is possible to specify the model name as a URI, eg. ``openai:gpt-5``. See: :ref:`text generation models `. 
* ``query`` (string, required): The query string used as the basis for text generation. * ``stream`` (boolean, optional): Specifies whether the response should be streamed. Defaults to false. * ``prompt`` (object, optional): Settings that define a prompt. Omit to use the default prompt. * ``name`` (string, optional): Name of predefined prompt. * ``id`` (string, optional): ID of predefined prompt. * ``custom`` (array of objects, optional): Custom prompt messages, each containing a ``role`` and ``content``. If no ``name`` or ``id`` was provided, the custom messages provided here become the prompt. If one of those was provided, these messages will be added to that existing prompt. * ``role`` (string): "system", "user", "assistant", or "tool". * ``content`` (string | array): Content of the message. * For ``role: "system"``: Must be a string. * For ``role: "user"``: Must be an array of content blocks, e.g., ``[{"type": "text", "text": "..."}]``. * For ``role: "assistant"``: Must be a string (the assistant's text response). May optionally include ``tool_calls``. * For ``role: "tool"``: Must be a string (the result of the tool call). Requires ``tool_call_id``. * ``tool_call_id`` (string, optional): Identifier for the tool call whose result this message represents (required if ``role: "tool"``). * ``tool_calls`` (array, optional): Array of tool calls requested by the assistant (used if ``role: "assistant"``). Each object should follow the format: ``{"id": "...", "type": "function", "function": {"name": "...", "arguments": "..."}}``. Arguments should be a JSON string. * ``temperature`` (number, optional): Sampling temperature. * ``top_p`` (number, optional): Nucleus sampling parameter. * ``max_tokens`` (number, optional): Maximum tokens to generate. * ``seed`` (number, optional): Random seed. * ``safe_prompt`` (boolean, optional): Enable safety features. * ``top_k`` (number, optional): Top-k sampling parameter. * ``logit_bias`` (object, optional): Token biasing. 
* ``logprobs`` (number, optional): Return token log probabilities. * ``user`` (string, optional): User identifier. * ``tools`` (array, optional): A list of tools the model may call. Each tool has a ``type`` ("function") and a ``function`` object with ``name``, ``description`` (optional), and ``parameters`` (JSON schema). Example: ``[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}]`` Example request --------------- .. code-block:: bash $ curl --user : --json '{\ "query": "What color is the sky on Mars?",\ "model": "gpt-4-turbo-preview",\ "context": {"query":"Knowledge"}\ }' http://:/branch/main/ai/rag Response -------- * **HTTP status**: 200 OK * **Content-Type**: application/json * **Body**: A JSON object containing the RAG response details. .. code-block:: json { "id": "chatcmpl-xxxxxxxxxxxxxxxxxxxxxxxxxxxxx", "model": "gpt-4-turbo-preview", "text": "The sky on Mars typically appears butterscotch or reddish due to the fine dust particles suspended in the atmosphere.", "finish_reason": "stop", "usage": { "prompt_tokens": 50, "completion_tokens": 30, "total_tokens": 80 }, "logprobs": null, "tool_calls": null } * ``id`` (string): Unique identifier for the chat completion. * ``model`` (string): The model used for the chat completion. * ``text`` (string | null): The main text content of the response message. * ``finish_reason`` (string | null): The reason the model stopped generating tokens (e.g., "stop", "length", "tool_calls"). * ``usage`` (object | null): Token usage statistics for the request. * ``logprobs`` (object | null): Log probability information for the generated tokens (if requested). * ``tool_calls`` (array | null): Any tool calls requested by the model. Each element contains ``id``, ``type`` ("function"), ``name``, and ``args`` (parsed JSON object). 
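The request and response shapes described above can be exercised without a running server. The following is a minimal sketch, not part of any Gel client library: the helper names ``build_rag_request`` and ``summarize_rag_response`` are hypothetical, and the sample payload only mirrors the fields documented in this section (``text``, ``finish_reason``, ``usage``, ``tool_calls``).

```python
import json


def build_rag_request(query, model, context_query,
                      max_object_count=None, stream=False):
    """Assemble a JSON-serializable body for the RAG endpoint.

    Hypothetical helper: only the keys documented above are used.
    """
    context = {"query": context_query}
    if max_object_count is not None:
        context["max_object_count"] = max_object_count
    return {
        "query": query,
        "model": model,
        "context": context,
        "stream": stream,
    }


def summarize_rag_response(body):
    """Extract commonly used fields from a non-streaming RAG response.

    ``usage`` and ``tool_calls`` are documented as nullable, so both
    are treated as optional here.
    """
    usage = body.get("usage") or {}
    tool_calls = body.get("tool_calls") or []
    return {
        "text": body.get("text"),
        "finish_reason": body.get("finish_reason"),
        "total_tokens": usage.get("total_tokens"),
        "tool_names": [call["name"] for call in tool_calls],
    }


# Sample payload modeled on the documented response shape.
SAMPLE_RESPONSE = json.loads("""
{
  "id": "chatcmpl-example",
  "model": "gpt-4-turbo-preview",
  "text": "The sky on Mars typically appears butterscotch or reddish.",
  "finish_reason": "stop",
  "usage": {"prompt_tokens": 50, "completion_tokens": 30, "total_tokens": 80},
  "logprobs": null,
  "tool_calls": null
}
""")

if __name__ == "__main__":
    request_body = build_rag_request(
        "What color is the sky on Mars?",
        "gpt-4-turbo-preview",
        "Knowledge",
    )
    print(json.dumps(request_body))
    print(summarize_rag_response(SAMPLE_RESPONSE))
```

Treating ``usage`` and ``tool_calls`` as possibly ``null`` keeps the parsing code from failing on responses where the provider omits them.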
Error response -------------- * **HTTP status**: 400 Bad Request * **Content-Type**: application/json * **Body**: .. code-block:: json { "message": "missing required 'query' in request 'context' object", "type": "BadRequestError" } Streaming response (SSE) ------------------------ When the ``stream`` parameter is set to ``true``, the server uses `Server-Sent Events `__ (SSE) to stream responses. Here is a detailed breakdown of the typical sequence and structure of events in a streaming response: * **HTTP Status**: 200 OK * **Content-Type**: text/event-stream * **Cache-Control**: no-cache The stream consists of a sequence of six event types, each encapsulating part of the response in a structured format: 1. **Message start** * Event type: ``message_start`` * Data: Starts a message, specifying identifiers, roles, and initial usage. .. code-block:: json { "type": "message_start", "message": { "id": "", "role": "assistant", "model": "", "usage": { "prompt_tokens": 10 } } } 2. **Content block start** * Event type: ``content_block_start`` * Data: Marks the beginning of a new content block (either text or a tool call). .. code-block:: json { "type": "content_block_start", "index": 0, "content_block": { "type": "text", "text": "" } } Or for a tool call: .. code-block:: json { "type": "content_block_start", "index": 0, "content_block": { "id": "", "type": "tool_use", "name": "", "args": "{..." } } 3. **Content block delta** * Event type: ``content_block_delta`` * Data: Incrementally updates the content, appending more text or tool arguments. Includes logprobs if requested. .. code-block:: json { "type": "content_block_delta", "index": 0, "delta": { "type": "text_delta", "text": "The" }, "logprobs": null } Or for tool arguments: .. code-block:: json { "type": "content_block_delta", "index": 0, "delta": { "type": "tool_call_delta", "args": "{\"location" } } Subsequent ``content_block_delta`` events add more text/arguments to the message. 4. 
**Content block stop** * Event type: ``content_block_stop`` * Data: Marks the end of a content block. .. code-block:: json { "type": "content_block_stop", "index": 0 } 5. **Message delta** * Event type: ``message_delta`` * Data: Provides final message-level updates like the stop reason and final usage statistics. .. code-block:: json { "type": "message_delta", "delta": { "stop_reason": "stop" }, "usage": { "prompt_tokens": 10 } } 6. **Message stop** * Event type: ``message_stop`` * Data: Marks the end of the message. .. code-block:: json {"type": "message_stop"} Each event is sent as a separate SSE message, formatted as shown above. The connection is closed after all events are sent, signaling the end of the stream. **Example SSE response** .. code-block:: text :class: collapsible event: message_start data: {"type": "message_start", "message": {"id": "chatcmpl-9MzuQiF0SxUjFLRjIdT3mTVaMWwiv", "role": "assistant", "model": "gpt-4-0125-preview", "usage": {"prompt_tokens": 10}}} event: content_block_start data: {"type": "content_block_start","index":0,"content_block":{"type":"text","text":""}} event: content_block_delta data: {"type": "content_block_delta","index":0,"delta":{"type": "text_delta", "text": "The"}, "logprobs": null} event: content_block_delta data: {"type": "content_block_delta","index":0,"delta":{"type": "text_delta", "text": " skies"}, "logprobs": null} event: content_block_delta data: {"type": "content_block_delta","index":0,"delta":{"type": "text_delta", "text": " on"}, "logprobs": null} event: content_block_delta data: {"type": "content_block_delta","index":0,"delta":{"type": "text_delta", "text": " Mars"}, "logprobs": null} event: content_block_delta data: {"type": "content_block_delta","index":0,"delta":{"type": "text_delta", "text": " are"}, "logprobs": null} event: content_block_delta data: {"type": "content_block_delta","index":0,"delta":{"type": "text_delta", "text": " red"}, "logprobs": null} event: content_block_delta data: {"type": 
"content_block_delta","index":0,"delta":{"type": "text_delta", "text": "."}, "logprobs": null} event: content_block_stop data: {"type": "content_block_stop","index":0} event: message_delta data: {"type": "message_delta", "delta": {"stop_reason": "stop"}, "usage": {"completion_tokens": 7, "total_tokens": 17}} event: message_stop data: {"type": "message_stop"} ================================================ FILE: docs/reference/ai/index.rst ================================================ .. _ref_ai_overview: == AI == :edb-alt-title: Using Gel AI .. toctree:: :hidden: :maxdepth: 3 extai http python javascript |Gel| AI is a set of tools designed to enable you to ship AI-enabled apps with practically no effort. 1. ``ext::ai``: this Gel extension automatically generates embeddings for your data. Works with OpenAI, Mistral AI, Anthropic, and any other provider with a compatible API. 2. Python library: ``gel.ai``. Access all Gel AI features straight from your Python application. 3. JavaScript library: ``@gel/ai``. Access all Gel AI features right from your JavaScript backend application. .. 2. ``ext::vectorstore``: this extension is designed to replicate workflows that might be familiar to you from vectorstore-style databases. Powered by ``pgvector``, it allows you to store and search for embedding vectors, and integrates with popular AI frameworks. ================================================ FILE: docs/reference/ai/javascript.rst ================================================ .. _ref_ai_javascript_reference: ============== JavaScript API ============== ``@gel/ai`` is a wrapper around the :ref:`AI extension ` in |Gel|. .. tabs:: .. code-tab:: bash :caption: npm $ npm install @gel/ai .. code-tab:: bash :caption: yarn $ yarn add @gel/ai .. code-tab:: bash :caption: pnpm $ pnpm add @gel/ai .. code-tab:: bash :caption: bun $ bun add @gel/ai Overview ======== The AI package is built on top of the regular |Gel| client objects. **Example**: .. 
code-block:: typescript import { createClient } from "gel"; import { createRAGClient } from "@gel/ai"; const client = createClient(); const gpt4Ai = createRAGClient(client, { model: "gpt-4-turbo-preview", }); const astronomyAi = gpt4Ai.withContext({ query: "Astronomy" }); console.log( await astronomyAi.queryRag("What color is the sky on Mars?") ); Factory functions ================= .. js:function:: createRAGClient( \ client: Client, \ options: Partial = {} \ ): RAGClient Creates an instance of ``RAGClient`` with the specified client and options. :param client: A |Gel| client instance. :param string options.model: Required. Specifies the AI model to use. This could be a version of GPT or any other model supported by |Gel| AI. It is possible to specify the model name as a URI, e.g. ``openai:gpt-5``. See: :ref:`text generation models`. :param options.prompt: Optional. Defines the input prompt for the AI model. The prompt can be a simple string, an ID referencing a stored prompt, or a custom prompt structure that includes roles and content for more complex interactions. The default is the built-in system prompt. Core classes ============ .. js:class:: RAGClient Instances of ``RAGClient`` offer methods for client configuration and utilizing RAG. :ivar client: An instance of |Gel| client. .. js:method:: withConfig(options: Partial): RAGClient Returns a new ``RAGClient`` instance with updated configuration options. :param string options.model: Required. Specifies the AI model to use. This could be a version of GPT or any other model supported by |Gel| AI. :param options.prompt: Optional. Defines the input prompt for the AI model. The prompt can be a simple string, an ID referencing a stored prompt, or a custom prompt structure that includes roles and content for more complex interactions. The default is the built-in system prompt. .. js:method:: withContext(context: Partial): RAGClient Returns a new ``RAGClient`` instance with an updated query context. 
:param string context.query: Required. Specifies an expression to determine the relevant objects and index to serve as context for text generation. You may set this to any expression that produces a set of objects, even if it is not a standalone query. :param string context.variables: Optional. Variable settings required for the context query. :param string context.globals: Optional. Globals required for the context query. :param number context.max_object_count: Optional. A maximum number of objects to return from the context query. .. js:method:: async queryRag( \ message: string, \ context: QueryContext = this.context \ ): Promise Sends a query with context to the configured AI model and returns the response as a string. :param string message: Required. The message to be sent to the text generation provider's API. :param string context.query: Required. Specifies an expression to determine the relevant objects and index to serve as context for text generation. You may set this to any expression that produces a set of objects, even if it is not a standalone query. :param string context.variables: Optional. Variable settings required for the context query. :param string context.globals: Optional. Globals required for the context query. :param number context.max_object_count: Optional. A maximum number of objects to return from the context query. .. js:method:: async streamRag( \ message: string, \ context: QueryContext = this.context \ ): AsyncIterable & PromiseLike Can be used in two ways: - as **an async iterator** - if you want to process streaming data in real-time as it arrives, ideal for handling long-running streams. - as **a Promise that resolves to a full Response object** - you have complete control over how you want to handle the stream, this might be useful when you want to manipulate the raw stream or parse it in a custom way. :param string message: Required. The message to be sent to the text generation provider's API. 
:param string context.query: Required. Specifies an expression to determine the relevant objects and index to serve as context for text generation. You may set this to any expression that produces a set of objects, even if it is not a standalone query. :param string context.variables: Optional. Variable settings required for the context query. :param string context.globals: Optional. Globals required for the context query. :param number context.max_object_count: Optional. A maximum number of objects to return from the context query. .. js:method:: async generateEmbeddings( \ inputs: string[], \ model: string \ ): Promise Generates embeddings for the array of strings. :param string[] inputs: Required. Strings array to generate embeddings for. :param string model: Required. Specifies the AI model to use. ================================================ FILE: docs/reference/ai/python.rst ================================================ .. _ref_ai_python_reference: ========== Python API ========== The ``gel.ai`` package is an optional binding of the :ref:`AI extension ` in |Gel|. .. code-block:: bash $ pip install 'gel[ai]' Blocking and async API ====================== The AI binding is built on top of the regular |Gel| client objects, providing both blocking and asynchronous versions of its API. **Blocking client example**: .. code-block:: python import gel import gel.ai client = gel.create_client() gpt4ai = gel.ai.create_rag_client( client, model="gpt-4-turbo-preview" ) astronomy_ai = gpt4ai.with_context( query="Astronomy" ) print( astronomy_ai.query_rag("What color is the sky on Mars?") ) for data in astronomy_ai.stream_rag("What color is the sky on Mars?"): print(data) **Async client example**: .. 
code-block:: python import gel import gel.ai import asyncio client = gel.create_async_client() async def main(): gpt4ai = await gel.ai.create_async_rag_client( client, model="gpt-4-turbo-preview" ) astronomy_ai = gpt4ai.with_context( query="Astronomy" ) query = "What color is the sky on Mars?" print( await astronomy_ai.query_rag(query) ) # or streamed async for data in astronomy_ai.stream_rag(query): print(data) asyncio.run(main()) Factory functions ================= .. py:function:: create_rag_client(client, **kwargs) -> RAGClient Creates an instance of ``RAGClient`` with the specified client and options. This function ensures that the client is connected before initializing the AI with the specified options. :param client: A |Gel| client instance. :param kwargs: Keyword arguments that are passed to the ``RAGOptions`` data class to configure AI-specific options. These options are: * ``model``: The name of the model to be used. (required) * ``prompt``: An optional prompt to guide the model's behavior. ``None`` will result in the client using the default prompt. (default: ``None``) .. py:function:: create_async_rag_client(client, **kwargs) -> AsyncRAGClient Creates an instance of ``AsyncRAGClient`` with the specified client and options. This function ensures that the client is connected asynchronously before initializing the AI with the specified options. :param client: An asynchronous |Gel| client instance. :param kwargs: Keyword arguments that are passed to the ``RAGOptions`` data class to configure AI-specific options. These options are: * ``model``: The name of the model to be used. It is possible to specify the model name as a URI, e.g. ``openai:gpt-5``. See: :ref:`text generation models `. (required) * ``prompt``: An optional prompt to guide the model's behavior. (default: ``None``) Core classes ============ .. py:class:: BaseRAGClient The base class for |Gel| AI clients. 
This class handles the initialization and configuration of AI clients and provides methods to modify their configuration and context dynamically. Both the blocking and async AI client classes inherit from this one, so these methods are available on an AI client of either type. :ivar options: An instance of :py:class:`RAGOptions`, storing the RAG options. :ivar context: An instance of :py:class:`QueryContext`, storing the context for AI queries. :ivar client_cls: A placeholder for the client class, should be implemented by subclasses. :param client: An instance of |Gel| client, which could be either a synchronous or asynchronous client. :param options: AI options to be used with the client. :param kwargs: Keyword arguments to initialize the query context. .. py:method:: with_config(**kwargs) Creates a new instance of the same class with modified configuration options. This method uses the current instance's configuration as a base and applies the changes specified in ``kwargs``. :param kwargs: Keyword arguments that specify the changes to the AI configuration. These changes are passed to the ``derive`` method of the current configuration options object. Possible keywords include: * ``model``: Specifies the AI model to be used. This must be a string. * ``prompt``: An optional prompt to guide the model's behavior. This is optional and defaults to None. .. py:method:: with_context(**kwargs) Creates a new instance of the same class with a modified context. This method preserves the current AI options and client settings, but uses the modified context specified by ``kwargs``. :param kwargs: Keyword arguments that specify the changes to the context. These changes are passed to the ``derive`` method of the current context object. Possible keywords include: * ``query``: The database query string. * ``variables``: A dictionary of variables used in the query. * ``globals``: A dictionary of global settings affecting the query. 
* ``max_object_count``: An optional integer to limit the number of objects returned by the query. .. py:class:: RAGClient A synchronous class for creating |Gel| AI clients. This class provides methods to send queries and receive responses using both blocking and streaming communication modes synchronously. :ivar client: An instance of ``httpx.Client`` used for making HTTP requests synchronously. .. py:method:: query_rag(message, context=None) -> str Sends a request to the AI provider and returns the response as a string. This method uses a blocking HTTP POST request. It raises an HTTP exception if the request fails. :param message: The query string to be sent to the AI model. :param context: An optional ``QueryContext`` object to provide additional context for the query. If not provided, uses the default context of this AI client instance. .. py:method:: stream_rag(message, context=None) Opens a connection to the AI provider to stream query responses. This method yields data as it is received, utilizing Server-Sent Events (SSE) to handle streaming data. It raises an HTTP exception if the request fails. :param message: The query string to be sent to the AI model. :param context: An optional ``QueryContext`` object to provide additional context for the query. If not provided, uses the default context of this AI client instance. .. py:method:: generate_embeddings(*inputs: str, model: str) -> list[float] Generates embeddings for input texts. :param inputs: Input texts. :param model: The embedding model to use. .. py:class:: AsyncRAGClient An asynchronous class for creating |Gel| AI clients. This class provides methods to send queries and receive responses using both blocking and streaming communication modes asynchronously. :ivar client: An instance of ``httpx.AsyncClient`` used for making HTTP requests asynchronously. .. py:method:: query_rag(message, context=None) -> str :noindex: Sends an async request to the AI provider, returns the response as a string.
This method is asynchronous and should be awaited. It raises an HTTP exception if the request fails. :param message: The query string to be sent to the AI model. :param context: An optional ``QueryContext`` object to provide additional context for the query. If not provided, uses the default context of this AI client instance. .. py:method:: stream_rag(message, context=None) :noindex: Opens an async connection to the AI provider to stream query responses. This method yields data as it is received, using asynchronous Server-Sent Events (SSE) to handle streaming data. This is an asynchronous generator method and should be used in an async for loop. It raises an HTTP exception if the connection fails. :param message: The query string to be sent to the AI model. :param context: An optional ``QueryContext`` object to provide additional context for the query. If not provided, uses the default context of this AI client instance. .. py:method:: generate_embeddings(*inputs: str, model: str) -> list[float] :noindex: Generates embeddings for input texts. :param inputs: Input texts. :param model: The embedding model to use Configuration classes ===================== .. py:class:: ChatParticipantRole An enumeration of roles used when defining a custom text generation prompt. :cvar SYSTEM: Represents a system-level entity or process. :cvar USER: Represents a human user participating in the chat. :cvar ASSISTANT: Represents an AI assistant. :cvar TOOL: Represents a tool or utility used within the chat context. .. py:class:: Custom A single message in a custom text generation prompt. :ivar role: The role of the chat participant. Must be an instance of :py:class:`ChatParticipantRole`. :ivar content: The content associated with the role, expressed as a string. .. py:class:: Prompt The metadata and content of a text generation prompt. :ivar name: An optional name identifying the prompt. :ivar id: An optional unique identifier for the prompt. 
:ivar custom: An optional list of :py:class:`Custom` objects, each providing role-specific content within the prompt. .. py:class:: RAGOptions A data class for RAG options, specifying model and prompt settings. :ivar model: The name of the AI model. :ivar prompt: An optional :py:class:`Prompt` providing additional guiding information for the model. :method derive(kwargs): Creates a new instance of :py:class:`RAGOptions` by merging existing options with provided keyword arguments. Returns a new :py:class:`RAGOptions` instance with updated attributes. :param kwargs: Keyword arguments to update the current AI options. Possible keywords include: * ``model`` (str): Update the model name. * ``prompt`` (:py:class:`Prompt`): Update or set a new prompt object. .. py:class:: QueryContext A data class defining the context for a query to an AI model. :ivar query: The base query string. :ivar variables: An optional dictionary of variables used in the query. :ivar globals: An optional dictionary of global settings affecting the query. :ivar max_object_count: An optional integer specifying the maximum number of objects the query should return. :method derive(kwargs): Creates a new instance of :py:class:`QueryContext` by merging existing context with provided keyword arguments. Returns a new :py:class:`QueryContext` instance with updated attributes. :param kwargs: Keyword arguments to update the current query context. Possible keywords include: * ``query`` (str): Update the query string. * ``variables`` (dict): Update or set new variables for the query. * ``globals`` (dict): Update or set new global settings for the query. * ``max_object_count`` (int): Update the limit on the number of objects returned by the query. .. py:class:: RAGRequest A data class defining a request to a text generation model. :ivar model: The name of the AI model to query. :ivar prompt: An optional :py:class:`Prompt` associated with the request. 
:ivar context: The :py:class:`QueryContext` defining the query context. :ivar query: The specific query string to be sent to the model. :ivar stream: A boolean indicating whether the response should be streamed (True) or returned in a single response (False). :method to_httpx_request(): Converts the RAGRequest into a dictionary suitable for making an HTTP request using the httpx library. ================================================ FILE: docs/reference/ai/vectorstore_python.rst ================================================ :orphan: .. _ref_ai_vectorstore_python: ====================== Vectorstore Python API ====================== Core Classes ============ .. py:class:: GelVectorstore A framework-agnostic interface for interacting with |Gel's| ext::vectorstore. This class provides methods for storing, retrieving, and searching vector embeddings. It follows vector database conventions and supports different embedding models. Args: * ``embedding_model`` (:py:class:`BaseEmbeddingModel`): The embedding model used to generate vectors. * ``collection_name`` (str): The name of the collection. * ``record_type`` (str): The schema type (table name) for storing records. * ``client_config`` (dict | None): The config for the |Gel| client. .. py:method:: add_items(self, items: list[InsertItem]) Add multiple items to the vector store in a single transaction. Embeddings will be generated and stored for all items. Args: * ``items`` (list[:py:class:`InsertItem`]): List of items to add. Each contains: * ``text`` (str): The text content to be embedded * ``metadata`` (dict[str, Any]): Additional data to store Returns: * List of database record IDs for the inserted items. .. py:method:: add_vectors(self, records: list[InsertRecord]) Add pre-computed vector embeddings to the store. Use this method when you have already generated embeddings and want to store them directly without re-computing them. Args: * ``records`` (list[:py:class:`InsertRecord`]): List of records. 
Each contains: * ``embedding`` (list[float]): Pre-computed embeddings * ``text`` (Optional[str]): Original text content * ``metadata`` (dict[str, Any]): Additional data to store Returns: * List of database record IDs for the inserted items. .. py:method:: delete(self, ids: list[uuid.UUID]) Delete records from the vector store by their IDs. Args: * ``ids`` (list[uuid.UUID]): List of record IDs to delete. Returns: * List of deleted record IDs. .. py:method:: get_by_ids(self, ids: list[uuid.UUID]) -> list[Record] Retrieve specific records by their IDs. Args: * ``ids`` (list[uuid.UUID]): List of record IDs to retrieve. Returns: * List of retrieved records. Each result contains: * ``id`` (uuid.UUID): The record's unique identifier * ``text`` (Optional[str]): The original text content * ``embedding`` (Optional[list[float]]): The stored vector embedding * ``metadata`` (Optional[dict[str, Any]]): Any associated metadata .. py:method:: search_by_item(self, item: Any, filters: Optional[CompositeFilter] = None, limit: Optional[int] = 4) -> list[SearchResult] Search for similar items in the vector store. This method: 1. Generates an embedding for the input item 2. Finds records with similar embeddings 3. Optionally filters results based on metadata 4. Returns the most similar items up to the specified limit Args: * ``item`` (Any): The query item to find similar matches for. Must be compatible with the embedding model's target_type. * ``filters`` (Optional[:py:class:`CompositeFilter`]): Metadata-based filters to use. * ``limit`` (Optional[int]): Max number of results to return. Defaults to 4. Returns: * List of similar items, ordered by similarity. Each result contains: * ``id`` (uuid.UUID): The record's unique identifier * ``text`` (Optional[str]): The original text content * ``embedding`` (list[float]): The stored vector embedding * ``metadata`` (Optional[dict[str, Any]]): Any associated metadata * ``cosine_similarity`` (float): Similarity score (higher is more similar) .. 
py:method:: search_by_vector(self, vector: list[float], filter_expression: str = "", limit: Optional[int] = 4) -> list[SearchResult] Search using a pre-computed vector embedding. Useful when you have already computed the embedding or want to search with a modified/combined embedding vector. Args: * ``vector`` (list[float]): The query embedding to search with. Must match the dimensionality of stored embeddings. * ``filter_expression`` (str): Filter expression for metadata filtering. * ``limit`` (Optional[int]): Max number of results to return. Defaults to 4. Returns: * List of similar items, ordered by similarity. Each result contains: * ``id`` (uuid.UUID): The record's unique identifier * ``text`` (Optional[str]): The original text content * ``embedding`` (list[float]): The stored vector embedding * ``metadata`` (Optional[dict[str, Any]]): Any associated metadata * ``cosine_similarity`` (float): Similarity score (higher is more similar) .. py:method:: update_record(self, record: Record) -> Optional[uuid.UUID] Update an existing record in the vector store. Only specified fields will be updated. If text is provided but not embedding, a new embedding will be automatically generated. Args: * ``record`` (:py:class:`Record`): * ``id`` (uuid.UUID): The ID of the record to update * ``text`` (Optional[str]): New text content. If provided without embedding, a new embedding will be generated. * ``embedding`` (Optional[list[float]]): New vector embedding. * ``metadata`` (Optional[dict[str, Any]]): New metadata to store with the record. Completely replaces existing metadata. Returns: * The updated record's ID if found and updated, None if no record was found with the given ID. Raises: * ValueError: If no fields are specified for update. .. py:class:: BaseEmbeddingModel Abstract base class for embedding models. Any embedding model used with :py:class:`GelVectorstore` must implement this interface. The model is expected to convert input data (text, images, etc.) 
into a numerical vector representation. .. py:method:: __call__(self, item) -> list[float] Convert an input item into a list of floating-point values (vector embedding). Must be implemented in subclasses. Args: * ``item``: Input item to be converted to an embedding Returns: * list[float]: Vector embedding of the input item .. py:method:: dimensions(self) -> int Return the number of dimensions in the embedding vector. Must be implemented in subclasses. Returns: * int: Number of dimensions in the embedding vector .. py:method:: target_type(self) -> TypeVar Return the expected data type of the input (e.g., str for text, image for vision models). Must be implemented in subclasses. Returns: * TypeVar: Expected input data type Data Classes ============ .. py:class:: InsertItem An item whose embedding will be created and stored alongside the item in the vector store. Args: * ``text`` (str): The text content to be embedded * ``metadata`` (dict[str, Any]): Additional data to store. Defaults to empty dict. .. py:class:: InsertRecord A record to be added to the vector store with embedding pre-computed. Args: * ``embedding`` (list[float]): Pre-computed embeddings * ``text`` (str | None): Original text content. Defaults to None. * ``metadata`` (dict[str, Any]): Additional data to store. Defaults to empty dict. .. py:class:: Record A record retrieved from the vector store, or an update record. Custom ``__init__`` so we can detect which fields the user passed (even if they pass None or {}). Args: * ``id`` (uuid.UUID): The record's unique identifier * ``text`` (str | None): The text content. Defaults to None. * ``embedding`` (list[float] | None): The vector embedding. Defaults to None. * ``metadata`` (dict[str, Any]): Additional data stored with the record. Defaults to empty dict. .. py:class:: SearchResult A search result from the vector store. Inherits from :py:class:`Record` Args: * ``cosine_similarity`` (float): Similarity score for the search result. Defaults to 0.0. 
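To make the ``cosine_similarity`` score on :py:class:`SearchResult` more concrete, the following is a minimal sketch of how cosine similarity between two embedding vectors can be computed in plain Python. It is an illustration only and does not reflect how the vectorstore computes the score internally; the function name ``cosine_similarity`` is hypothetical and not part of the ``gel`` package.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Return the cosine of the angle between two equal-length vectors.

    Hypothetical helper for illustration; not part of the gel package.
    """
    if len(a) != len(b):
        raise ValueError("vectors must have the same number of dimensions")

    # Dot product and Euclidean norms of both vectors.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))

    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")

    return dot / (norm_a * norm_b)
```

Vectors pointing in the same direction score close to 1, orthogonal vectors score 0, and opposite vectors score close to -1, which is why a higher score indicates a more similar record.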
Metadata Filtering ================== .. py:class:: FilterOperator Enumeration of supported filter operators for metadata filtering. Values: * ``EQ``: Equal to (=) * ``NE``: Not equal to (!=) * ``GT``: Greater than (>) * ``LT``: Less than (<) * ``GTE``: Greater than or equal to (>=) * ``LTE``: Less than or equal to (<=) * ``IN``: Value in array * ``NOT_IN``: Value not in array * ``LIKE``: Pattern matching * ``ILIKE``: Case-insensitive pattern matching * ``ANY``: Any array element matches * ``ALL``: All array elements match * ``CONTAINS``: String contains value * ``EXISTS``: Field exists .. py:class:: FilterCondition Enumeration of conditions for combining multiple filters. Values: * ``AND``: All conditions must be true * ``OR``: Any condition must be true .. py:class:: MetadataFilter Represents a single metadata filter condition. Args: * ``key`` (str): The metadata field key to filter on * ``value`` (int | float | str): The value to compare against * ``operator`` (:py:class:`FilterOperator`): The comparison operator. Defaults to FilterOperator.EQ. .. py:class:: CompositeFilter Allows grouping multiple MetadataFilter instances using AND/OR conditions. Args: * ``filters`` (list[:py:class:`CompositeFilter` | :py:class:`MetadataFilter`]): List of filters to combine * ``condition`` (:py:class:`FilterCondition`): How to combine the filters. Defaults to FilterCondition.AND. .. py:function:: get_filter_clause(filters: CompositeFilter) -> str Get the filter clause for a given CompositeFilter. Args: * ``filters`` (:py:class:`CompositeFilter`): The composite filter to convert to a clause Returns: * str: The filter clause string for use in queries Raises: * ValueError: If an unknown operator or condition is encountered ================================================ FILE: docs/reference/auth/built_in_ui.rst ================================================ .. 
_ref_guide_auth_built_in_ui: =========== Built-in UI =========== :edb-alt-title: Integrating Gel Auth's built-in UI To use the built-in UI for Gel Auth, enable it by clicking the "Enable UI" button under "Login UI" in the configuration section of the |Gel| UI. Set these configuration values: - ``redirect_to``: Once the authentication flow is complete, Gel will redirect the user's browser back to this URL in your application's backend. - ``redirect_to_on_signup``: If this is a new user, Gel will redirect the user's browser back to this URL in your application's backend. - ``app_name``: Used in the built-in UI to show the user the application's name in a few important places. - ``logo_url``: If provided, will show in the built-in UI as part of the page design. - ``dark_logo_url``: If provided and the user's system has indicated that they prefer a dark UI, this will show instead of ``logo_url`` in the built-in UI as part of the page design. - ``brand_color``: If provided, used in the built-in UI as part of the page design. Example Implementation ====================== We will demonstrate the various steps below by building a Node.js HTTP server in a single file that we will use to simulate a typical web application. .. note:: We are in the process of publishing helper libraries that you can use with popular languages and web frameworks. The details below show the inner workings of how data is exchanged with the Auth extension from a web app using HTTP. You can use this as a guide to integrate with your application written in any language that can send and receive HTTP requests. We secure authentication tokens and other sensitive data by using PKCE (Proof Key for Code Exchange). Start the PKCE flow ------------------- Your application server generates 32 random bytes and Base64 URL-encodes them (the encoded string is 43 characters long). This string is called the ``verifier``. You need to store this value for the duration of the flow.
One way to keep this state is to set an HttpOnly cookie containing the ``verifier`` when the browser starts the flow, and then read it back from the cookie store at the end of the flow. Take this ``verifier`` string, hash it with SHA256, and then base64url encode the resulting digest. This new string is called the ``challenge``. .. note:: Since ``=`` is not a URL-safe character, if your Base64-URL encoding function adds padding, you should remove the padding before hashing the ``verifier`` to derive the ``challenge`` or when providing the ``verifier`` or ``challenge`` in your requests. .. note:: If you are familiar with PKCE, you will notice some differences from how RFC 7636 defines PKCE. Our authentication flow is not an OAuth flow, but rather a strict server-to-server flow with Proof Key for Code Exchange added for additional security to avoid leaking the authentication token. Here are some differences between PKCE as defined in RFC 7636 and our implementation: - We do not support the ``plain`` value for ``code_challenge_method``, and therefore do not read that value if provided in requests. - Our parameters omit the ``code_`` prefix; however, we do support ``code_challenge`` and ``code_verifier`` as aliases, preferring ``challenge`` and ``verifier`` if present. .. code-block:: javascript import http from "node:http"; import { URL } from "node:url"; import crypto from "node:crypto"; /** * You can get this value by running `gel instance credentials`.
* Value should be: * `${protocol}://${host}:${port}/branch/${branch}/ext/auth/ */ const GEL_AUTH_BASE_URL = process.env.GEL_AUTH_BASE_URL; const SERVER_PORT = 3000; /** * Generate a random Base64 url-encoded string, and derive a "challenge" * string from that string to use as proof that the request for a token * later is made from the same user agent that made the original request * * @returns {Object} The verifier and challenge strings */ const generatePKCE = () => { const verifier = crypto.randomBytes(32).toString("base64url"); const challenge = crypto .createHash("sha256") .update(verifier) .digest("base64url"); return { verifier, challenge }; }; .. note:: For |EdgeDB| versions before 5.0, the value for :gelenv:`AUTH_BASE_URL` in the above snippet should have the form: ``${protocol}://${host}:${port}/db/${database}/ext/auth/`` Link to built-in UI ------------------- Next, provide links in your web application to the ``/auth/ui/signin`` and ``/auth/ui/signup`` routes. Those routes will generate the ``verifier`` and ``challenge`` strings, save the ``verifier`` in a cookie and redirect the user to the built-in UI with the ``challenge`` in the search parameters. .. lint-off .. code-block:: javascript /** * In Node, the `req.url` is only the path and query string portion of * a URL. In order to generate a full URL, we need to build the protocol * and host from other parts of the request. * * One reason we like to use `URL` objects here is to easily parse the * `URLSearchParams` from the request, and rather than do more error * prone string manipulation, we build a `URL`. * * @param {Request} req * @returns {URL} */ const getRequestUrl = (req) => { const protocol = req.socket.encrypted ?
"https" : "http"; return new URL(req.url, `${protocol}://${req.headers.host}`); }; const server = http.createServer(async (req, res) => { const requestUrl = getRequestUrl(req); switch (requestUrl.pathname) { case "/auth/ui/signin": { await handleUiSignIn(req, res); break; } case "/auth/ui/signup": { await handleUiSignUp(req, res); break; } case "/auth/callback": { await handleCallback(req, res); break; } default: { res.writeHead(404); res.end("Not found"); break; } } }); /** * Redirects browser requests to Gel Auth UI sign in page with the * PKCE challenge, and saves PKCE verifier in an HttpOnly cookie. * * @param {Request} req * @param {Response} res */ const handleUiSignIn = async (req, res) => { const { verifier, challenge } = generatePKCE(); const redirectUrl = new URL("ui/signin", GEL_AUTH_BASE_URL); redirectUrl.searchParams.set("challenge", challenge); res.writeHead(301, { "Set-Cookie": `gel-pkce-verifier=${verifier}; HttpOnly; Path=/; Secure; SameSite=Strict`, Location: redirectUrl.href, }); res.end(); }; /** * Redirects browser requests to Gel Auth UI sign up page with the * PKCE challenge, and saves PKCE verifier in an HttpOnly cookie. * * @param {Request} req * @param {Response} res */ const handleUiSignUp = async (req, res) => { const { verifier, challenge } = generatePKCE(); const redirectUrl = new URL("ui/signup", GEL_AUTH_BASE_URL); redirectUrl.searchParams.set("challenge", challenge); res.writeHead(301, { "Set-Cookie": `gel-pkce-verifier=${verifier}; HttpOnly; Path=/; Secure; SameSite=Strict`, Location: redirectUrl.href, }); res.end(); }; server.listen(SERVER_PORT, () => { console.log(`HTTP server listening on port ${SERVER_PORT}...`); }); .. lint-on Retrieve ``auth_token`` ----------------------- At the very end of the flow, the Gel server will redirect the user's browser to the ``redirect_to`` address with a single query parameter: ``code``. This route should be a server route that has access to the ``verifier``. 
You then take that ``code`` and look up the ``verifier`` in the ``gel-pkce-verifier`` cookie (``edgedb-pkce-verifier`` with |EdgeDB| <= 5), and make a request to the Gel Auth extension to exchange these two pieces of data for an ``auth_token``. .. lint-off .. code-block:: javascript /** * Handles the PKCE callback and exchanges the `code` and `verifier` * for an auth_token, setting the auth_token as an HttpOnly cookie. * * @param {Request} req * @param {Response} res */ const handleCallback = async (req, res) => { const requestUrl = getRequestUrl(req); const code = requestUrl.searchParams.get("code"); if (!code) { const error = requestUrl.searchParams.get("error"); res.statusCode = 400; res.end( `OAuth callback is missing 'code'. \ OAuth provider responded with error: ${error}`, ); return; } const cookies = req.headers.cookie?.split("; "); const verifier = cookies ?.find((cookie) => cookie.startsWith("gel-pkce-verifier=")) ?.split("=")[1]; if (!verifier) { res.statusCode = 400; res.end( `Could not find 'verifier' in the cookie store. Is this the \ same user agent/browser that started the authorization flow?`, ); return; } const codeExchangeUrl = new URL("token", GEL_AUTH_BASE_URL); codeExchangeUrl.searchParams.set("code", code); codeExchangeUrl.searchParams.set("verifier", verifier); const codeExchangeResponse = await fetch(codeExchangeUrl.href, { method: "GET", }); if (!codeExchangeResponse.ok) { const text = await codeExchangeResponse.text(); res.statusCode = 400; res.end(`Error from the auth server: ${text}`); return; } const { auth_token } = await codeExchangeResponse.json(); res.writeHead(204, { "Set-Cookie": `gel-auth-token=${auth_token}; HttpOnly; Path=/; Secure; SameSite=Strict`, }); res.end(); }; .. lint-on :ref:`Back to the Gel Auth guide ` ================================================ FILE: docs/reference/auth/email_password.rst ================================================ ..
_ref_guide_auth_email_password: ================== Email and password ================== :edb-alt-title: Integrating Gel Auth's email and password provider As an alternative to the :ref:`built-in UI `, you can create your own UI that calls your own web application backend. UI considerations ================= Similar to how the built-in UI works, you can query the database configuration to discover which providers are configured and dynamically build the UI. .. code-block:: edgeql select cfg::Config.extensions[is ext::auth::AuthConfig].providers { name, [is ext::auth::OAuthProviderConfig].display_name, }; The ``name`` is a unique string that identifies the Identity Provider. OAuth providers also have a ``display_name`` that you can use as a label for links or buttons. In later steps, you'll be providing this ``name`` as the ``provider`` in various endpoints. Example implementation ====================== We will demonstrate the various steps below by building a Node.js HTTP server in a single file that we will use to simulate a typical web application. For this example, we will require email verification to demonstrate the full flow, but you can configure your provider to not require verification by setting the ``require_verification`` setting to ``false``. .. note:: The details below show the inner workings of how data is exchanged with the Auth extension from a web app using HTTP. You can use this as a guide to integrate with your application written in any language that can send and receive HTTP requests. Start the PKCE flow ------------------- We secure authentication tokens and other sensitive data by using PKCE (Proof Key for Code Exchange). Your application server generates 32 random bytes and Base64 URL-encodes them (the encoded string is 43 characters long). This string is called the ``verifier``. You need to store this value for the duration of the flow.
One way to keep this state is to set an HttpOnly cookie containing the ``verifier`` when the browser starts the flow, and then read it back from the cookie store at the end of the flow. Take this ``verifier`` string, hash it with SHA256, and then base64url encode the resulting digest. This new string is called the ``challenge``. .. note:: Since ``=`` is not a URL-safe character, if your Base64-URL encoding function adds padding, you should remove the padding before hashing the ``verifier`` to derive the ``challenge`` or when providing the ``verifier`` or ``challenge`` in your requests. .. note:: If you are familiar with PKCE, you will notice some differences from how RFC 7636 defines PKCE. Our authentication flow is not an OAuth flow, but rather a strict server-to-server flow with Proof Key for Code Exchange added for additional security to avoid leaking the authentication token. Here are some differences between PKCE as defined in RFC 7636 and our implementation: - We do not support the ``plain`` value for ``code_challenge_method``, and therefore do not read that value if provided in requests. - Our parameters omit the ``code_`` prefix; however, we do support ``code_challenge`` and ``code_verifier`` as aliases, preferring ``challenge`` and ``verifier`` if present. .. lint-off .. code-block:: javascript import http from "node:http"; import { URL } from "node:url"; import crypto from "node:crypto"; /** * You can get this value by running `gel instance credentials`.
* Value should be: * `${protocol}://${host}:${port}/branch/${branch}/ext/auth/ */ const GEL_AUTH_BASE_URL = process.env.GEL_AUTH_BASE_URL; const SERVER_PORT = 3000; /** * Generate a random Base64 url-encoded string, and derive a "challenge" * string from that string to use as proof that the request for a token * later is made from the same user agent that made the original request * * @returns {Object} The verifier and challenge strings */ const generatePKCE = () => { const verifier = crypto.randomBytes(32).toString("base64url"); const challenge = crypto .createHash("sha256") .update(verifier) .digest("base64url"); return { verifier, challenge }; }; .. lint-on .. note:: For |EdgeDB| versions before 5.0, the value for :gelenv:`AUTH_BASE_URL` in the above snippet should have the form: ``${protocol}://${host}:${port}/db/${database}/ext/auth/`` Sign-in and sign-up ------------------- Next, we implement routes that handle registering a new user and authenticating an existing user. .. lint-off .. code-block:: javascript const server = http.createServer(async (req, res) => { const requestUrl = getRequestUrl(req); switch (requestUrl.pathname) { case "/auth/signup": { await handleSignUp(req, res); break; } case "/auth/signin": { await handleSignIn(req, res); break; } case "/auth/verify": { await handleVerify(req, res); break; } case "/auth/send-password-reset-email": { await handleSendPasswordResetEmail(req, res); break; } case "/auth/ui/reset-password": { await handleUiResetPassword(req, res); break; } case "/auth/reset-password": { await handleResetPassword(req, res); break; } default: { res.writeHead(404); res.end("Not found"); break; } } }); /** * Handles sign up with email and password. 
* * @param {Request} req * @param {Response} res */ const handleSignUp = async (req, res) => { let body = ""; req.on("data", (chunk) => { body += chunk.toString(); }); req.on("end", async () => { const pkce = generatePKCE(); const { email, password, provider } = JSON.parse(body); if (!email || !password || !provider) { res.statusCode = 400; res.end( `Request body malformed. Expected JSON body with 'email', 'password', and 'provider' keys, but got: ${body}`, ); return; } const registerUrl = new URL("register", GEL_AUTH_BASE_URL); const registerResponse = await fetch(registerUrl.href, { method: "post", headers: { "Content-Type": "application/json", }, body: JSON.stringify({ challenge: pkce.challenge, email, password, provider, verify_url: `http://localhost:${SERVER_PORT}/auth/verify`, }), }); if (!registerResponse.ok) { const text = await registerResponse.text(); res.statusCode = 400; res.end(`Error from the auth server: ${text}`); return; } const registerJson = await registerResponse.json(); if ("code" in registerJson) { // No verification required, we can immediately get an auth token const tokenUrl = new URL("token", GEL_AUTH_BASE_URL); tokenUrl.searchParams.set("code", registerJson.code); tokenUrl.searchParams.set("verifier", pkce.verifier); const tokenResponse = await fetch(tokenUrl.href, { method: "get", }); if (!tokenResponse.ok) { const text = await tokenResponse.text(); res.statusCode = 400; res.end(`Error from the auth server: ${text}`); return; } const { auth_token } = await tokenResponse.json(); res.writeHead(204, { "Set-Cookie": `gel-auth-token=${auth_token}; HttpOnly; Path=/; Secure; SameSite=Strict`, }); res.end(); } else { // Verification required, we need to render a notice to the user // to check their email for a verification link res.writeHead(200, { "Content-Type": "text/html" }); res.end(`

           <p>Please check your email for a verification link.</p>

         `);
       }
     });
   };

   /**
    * Handles sign in with email and password.
    *
    * @param {Request} req
    * @param {Response} res
    */
   const handleSignIn = async (req, res) => {
     let body = "";
     req.on("data", (chunk) => {
       body += chunk.toString();
     });
     req.on("end", async () => {
       const pkce = generatePKCE();
       const { email, password, provider } = JSON.parse(body);
       if (!email || !password || !provider) {
         // Node's http.ServerResponse exposes `statusCode`, not `status`
         res.statusCode = 400;
         res.end(
           `Request body malformed. Expected JSON body with 'email', 'password', and 'provider' keys, but got: ${body}`,
         );
         return;
       }

       const authenticateUrl = new URL("authenticate", GEL_AUTH_BASE_URL);
       const authenticateResponse = await fetch(authenticateUrl.href, {
         method: "post",
         headers: {
           "Content-Type": "application/json",
         },
         body: JSON.stringify({
           challenge: pkce.challenge,
           email,
           password,
           provider,
         }),
       });

       if (!authenticateResponse.ok) {
         const text = await authenticateResponse.text();
         res.statusCode = 400;
         res.end(`Error from the auth server: ${text}`);
         return;
       }

       const authenticateJson = await authenticateResponse.json();
       if ("code" in authenticateJson) {
         // User is verified, we can get an auth token
         const tokenUrl = new URL("token", GEL_AUTH_BASE_URL);
         tokenUrl.searchParams.set("code", authenticateJson.code);
         tokenUrl.searchParams.set("verifier", pkce.verifier);
         const tokenResponse = await fetch(tokenUrl.href, {
           method: "get",
         });

         if (!tokenResponse.ok) {
           const text = await tokenResponse.text();
           res.statusCode = 400;
           res.end(`Error from the auth server: ${text}`);
           return;
         }

         const { auth_token } = await tokenResponse.json();
         res.writeHead(204, {
           "Set-Cookie": `gel-auth-token=${auth_token}; HttpOnly; Path=/; Secure; SameSite=Strict`,
         });
         res.end();
       } else {
         // Verification required, we need to render a notice to the user
         // to check their email for a verification link
         res.writeHead(200, { "Content-Type": "text/html" });
         res.end(`

           <p>Please check your email for a verification link.</p>

         `);
       }
     });
   };

.. lint-on

Email verification
------------------

When a new user signs up, by default we require them to verify their email
address before allowing the application to get an authentication token. To
handle the verification flow, we implement an endpoint:

.. note::

   If your Email/Password provider uses the **Code** verification method,
   the verification email contains a one-time code rather than a link. In
   that case, prompt the user for the code and call ``POST /verify`` with:

   - **provider**: ``builtin::local_emailpassword``
   - **email** and **code**
   - optionally a **challenge** and **redirect_to** to receive a PKCE code
     or a redirect upon success

   The Link-based example below continues to work when the provider uses
   the Link method.

.. note::

   💡 If you would like to allow users to still log in, but offer limited
   access to your application, you can check the associated
   ``ext::auth::EmailPasswordFactor`` for the ``ext::auth::Identity`` to see
   if the ``verified_at`` property is some time in the past. You'll need to
   set the ``require_verification`` setting in the provider configuration to
   ``false``.

.. lint-off

.. code-block:: javascript

   /**
    * Handles the link in the email verification flow.
    *
    * @param {Request} req
    * @param {Response} res
    */
   const handleVerify = async (req, res) => {
     const requestUrl = getRequestUrl(req);
     const verification_token = requestUrl.searchParams.get("verification_token");
     if (!verification_token) {
       // Node's http.ServerResponse exposes `statusCode`, not `status`
       res.statusCode = 400;
       res.end(
         `Verify request is missing 'verification_token' search param.
         The verification email is malformed.`,
       );
       return;
     }

     const verifyUrl = new URL("verify", GEL_AUTH_BASE_URL);
     const verifyResponse = await fetch(verifyUrl.href, {
       method: "post",
       headers: {
         "Content-Type": "application/json",
       },
       body: JSON.stringify({
         verification_token,
         provider: "builtin::local_emailpassword",
       }),
     });

     if (!verifyResponse.ok) {
       const text = await verifyResponse.text();
       // Node's http.ServerResponse exposes `statusCode`, not `status`
       res.statusCode = 400;
       res.end(`Error from the auth server: ${text}`);
       return;
     }

     const { code } = await verifyResponse.json();

     const cookies = req.headers.cookie?.split("; ");
     const verifier = cookies
       ?.find((cookie) => cookie.startsWith("gel-pkce-verifier="))
       ?.split("=")[1];
     if (verifier) {
       // Email verification flow is continuing from the original
       // user agent/browser, so we can immediately get an auth token
       const tokenUrl = new URL("token", GEL_AUTH_BASE_URL);
       tokenUrl.searchParams.set("code", code);
       tokenUrl.searchParams.set("verifier", verifier);
       const tokenResponse = await fetch(tokenUrl.href, {
         method: "get",
       });

       if (!tokenResponse.ok) {
         const text = await tokenResponse.text();
         res.statusCode = 400;
         res.end(`Error from the auth server: ${text}`);
         return;
       }

       const { auth_token } = await tokenResponse.json();
       res.writeHead(204, {
         "Set-Cookie": `gel-auth-token=${auth_token}; HttpOnly; Path=/; Secure; SameSite=Strict`,
       });
       res.end();
       return;
     }

     // Email verification flow is continuing from a different user agent/browser,
     // so we need to render a notice to the user to sign in, which will either
     // complete the PKCE flow or start a new one
     res.statusCode = 200;
     res.end(
       `

         <p>Email verified! Please sign in to continue.</p>

       `,
     );
   };

.. lint-on

Create a User object
--------------------

For some applications, you may want to create a custom ``User`` type in the
default module to attach application-specific information. You can tie this
to an ``ext::auth::Identity`` by using the ``identity_id`` returned during
the sign-up flow.

.. note::

   For this example, we'll assume you have a one-to-one relationship between
   ``User`` objects and ``ext::auth::Identity`` objects. In your own
   application, you may instead decide to have a one-to-many relationship.

Given this ``User`` type:

.. code-block:: sdl

   type User {
     email: str;
     name: str;

     required identity: ext::auth::Identity {
       constraint exclusive;
     };
   }

You can update the ``handleSignUp`` function like this to create a new
``User`` object:

.. lint-off

.. code-block:: javascript-diff

     const registerJson = await registerResponse.json();
   + if ("identity_id" in registerJson) {
   +   await client.query(`
   +     with
   +       identity := <ext::auth::Identity><uuid>$identity_id,
   +       emailFactor := (
   +         select ext::auth::EmailFactor filter .identity = identity
   +       ),
   +     insert User {
   +       email := emailFactor.email,
   +       identity := identity
   +     };
   +   `, { identity_id: registerJson.identity_id });
   + }
   +
     if ("code" in registerJson) {

.. lint-on

Password reset
--------------

To allow users to reset their password, we implement three endpoints. The
first one sends the reset email. The second is the HTML form that is rendered
when the user follows the link in their email. The final one is the endpoint
that updates the password and logs in the user.

.. note::

   If your provider is configured for the **Code** method for password
   reset, the email will contain a one-time code instead of a reset
   link/token. In that case:

   - Call ``POST /reset-password`` with **email**, **code**, **password**
     and optionally **challenge**.
   - If you include a **challenge**, the response will include a PKCE
     ``code`` that you can exchange at ``POST /token`` to log the user in
     immediately.
   - If you omit **challenge**, the response will indicate success without a
     PKCE code and you should ask the user to sign in.

.. lint-off

.. code-block:: javascript

   /**
    * Request a password reset for an email.
    *
    * @param {Request} req
    * @param {Response} res
    */
   const handleSendPasswordResetEmail = async (req, res) => {
     let body = "";
     req.on("data", (chunk) => {
       body += chunk.toString();
     });
     req.on("end", async () => {
       const { email } = JSON.parse(body);
       const reset_url = `http://localhost:${SERVER_PORT}/auth/ui/reset-password`;
       const provider = "builtin::local_emailpassword";
       const pkce = generatePKCE();

       const sendResetUrl = new URL("send-reset-email", GEL_AUTH_BASE_URL);
       const sendResetResponse = await fetch(sendResetUrl.href, {
         method: "post",
         headers: {
           "Content-Type": "application/json",
         },
         body: JSON.stringify({
           email,
           provider,
           reset_url,
           challenge: pkce.challenge,
         }),
       });

       if (!sendResetResponse.ok) {
         const text = await sendResetResponse.text();
         // Node's http.ServerResponse exposes `statusCode`, not `status`
         res.statusCode = 400;
         res.end(`Error from auth server: ${text}`);
         return;
       }

       const { email_sent } = await sendResetResponse.json();

       res.writeHead(200, {
         "Set-Cookie": `gel-pkce-verifier=${pkce.verifier}; HttpOnly; Path=/; Secure; SameSite=Strict`,
       });
       res.end(`Reset email sent to '${email_sent}'`);
     });
   };

   /**
    * Render a simple reset password UI
    *
    * @param {Request} req
    * @param {Response} res
    */
   const handleUiResetPassword = async (req, res) => {
     // req.url is only a path, so resolve it with the getRequestUrl helper
     // used by the other handlers; `new URL(req.url)` would throw here
     const url = getRequestUrl(req);
     const reset_token = url.searchParams.get("reset_token");
     res.writeHead(200, { "Content-Type": "text/html" });
     res.end(`
     `);
   };

   /**
    * Send new password with reset token to Gel Auth.
    *
    * @param {Request} req
    * @param {Response} res
    */
   const handleResetPassword = async (req, res) => {
     let body = "";
     req.on("data", (chunk) => {
       body += chunk.toString();
     });
     req.on("end", async () => {
       const { reset_token, password } = JSON.parse(body);
       if (!reset_token || !password) {
         // Node's http.ServerResponse exposes `statusCode`, not `status`
         res.statusCode = 400;
         res.end(
           `Request body malformed. Expected JSON body with 'reset_token' and 'password' keys, but got: ${body}`
         );
         return;
       }
       const provider = "builtin::local_emailpassword";
       // Optional chaining keeps a missing cookie from throwing,
       // so the 400 branch below stays reachable
       const cookies = req.headers.cookie?.split("; ");
       const verifier = cookies
         ?.find((cookie) => cookie.startsWith("gel-pkce-verifier="))
         ?.split("=")[1];
       if (!verifier) {
         res.statusCode = 400;
         res.end(
           `Could not find 'verifier' in the cookie store. Is this the same user agent/browser that started the authorization flow?`
         );
         return;
       }
       const resetUrl = new URL("reset-password", GEL_AUTH_BASE_URL);
       const resetResponse = await fetch(resetUrl.href, {
         method: "post",
         headers: {
           "Content-Type": "application/json",
         },
         body: JSON.stringify({
           reset_token,
           provider,
           password,
         }),
       });
       if (!resetResponse.ok) {
         const text = await resetResponse.text();
         res.statusCode = 400;
         res.end(`Error from the auth server: ${text}`);
         return;
       }
       const { code } = await resetResponse.json();
       const tokenUrl = new URL("token", GEL_AUTH_BASE_URL);
       tokenUrl.searchParams.set("code", code);
       tokenUrl.searchParams.set("verifier", verifier);
       const tokenResponse = await fetch(tokenUrl.href, {
         method: "get",
       });
       if (!tokenResponse.ok) {
         const text = await tokenResponse.text();
         res.statusCode = 400;
         res.end(`Error from the auth server: ${text}`);
         return;
       }
       const { auth_token } = await tokenResponse.json();
       res.writeHead(204, {
         "Set-Cookie": `gel-auth-token=${auth_token}; HttpOnly; Path=/; Secure; SameSite=Strict`,
       });
       res.end();
     });
   };

.. lint-on

:ref:`Back to the Gel Auth guide <ref_guide_auth>`

================================================
FILE: docs/reference/auth/http.rst
================================================
.. _ref_auth_http:

========
HTTP API
========

Your application server will interact with the Gel extension primarily by
sending HTTP requests to the Gel server. This page describes the HTTP API
exposed by the Gel server. For more in-depth guidance about integrating Gel
Auth into your application, including a reference example, see
:ref:`ref_guide_auth`.

The following sections are organized by authentication type.

Responses
=========

Responses typically include a JSON object that includes a ``code`` property
that can be exchanged for an access token by providing the matching PKCE
verifier associated with the ``code``. Some endpoints can be configured to
return responses as redirects and include response data in the redirect
location's query string.

General
=======

POST /token
-----------

Exchanges a PKCE authorization code (obtained from a successful registration,
authentication, or email verification flow that included a PKCE challenge)
for a session token.

**Request Parameters (Query String):**

* ``code`` (string, required): The PKCE authorization code that was
  previously issued.
* ``verifier`` (string, required, also accepts ``code_verifier``): The PKCE
  code verifier string (plaintext, typically 43-128 characters) that was
  originally used to generate the ``code_challenge``.

**Response:**

1. **Successful Token Exchange:**

   * This occurs if the ``code`` is valid, and the provided ``verifier``
     correctly matches the ``challenge`` associated with the ``code``.
   * The PKCE ``code`` is consumed and cannot be reused.
   * A 200 OK response is returned with a JSON body containing the session
     token and identity information:

     ..
code-block:: json { "auth_token": "your_new_session_jwt", "identity_id": "the_users_identity_id", "provider_token": "optional_oauth_provider_access_token", "provider_refresh_token": "optional_oauth_provider_refresh_token", "provider_id_token": "optional_oauth_provider_id_token" } .. note:: ``provider_token``, ``provider_refresh_token``, and ``provider_id_token`` are only populated if the PKCE flow originated from an interaction with an external OAuth provider that returned these tokens. 2. **PKCE Verification Failed:** * The ``code`` was found, but the ``verifier`` did not match the stored challenge. * An HTTP error response 403 Forbidden with a JSON body indicating ``PKCEVerificationFailed``. 3. **Unknown Code:** * The provided ``code`` was not found. * An HTTP error response 403 Forbidden with a JSON body indicating "NoIdentityFound". 4. **Code found, but not associated with an Identity:** * The ``code`` was found, but it is not associated with a user identity. * An HTTP error response 400 Bad Request with a JSON body indicating "InvalidData". 5. **Invalid Verifier Length:** * The ``verifier`` string is shorter than 43 characters or longer than 128 characters. * An HTTP 400 Bad Request response with a JSON body detailing the length requirement. 6. **Missing Parameters:** * Either ``code`` or ``verifier`` (or ``code_verifier``) is missing from the query string. * An HTTP 400 Bad Request response with a JSON body indicating the missing parameter. Email and password ================== POST /register -------------- Register a new user with email and password. **Request Body (JSON):** * ``email`` (string, required): The user's email address. * ``password`` (string, required): The user's desired password. * ``provider`` (string, required): The name of the provider to use: ``builtin::local_emailpassword`` * ``challenge`` (string, optional): A PKCE code challenge. 
This is required if the provider is configured with ``require_verification: false`` since registering will also authenticate and authentication is protected by a PKCE code exchange. * ``redirect_to`` (string, optional): A URL to redirect to upon successful registration. * ``verify_url`` (string, optional): The base URL for the email verification link. If not provided, it defaults to ``/ui/verify``, the built-in UI endpoint for verifying email addresses. The verification token will be appended as a query parameter to this URL. * ``redirect_on_failure`` (string, optional): A URL to redirect to if registration fails. .. note:: The verification email sent after registration depends on your provider's verification method: - **Code**: users receive a one-time code and must call ``POST /verify`` with ``provider``, ``email`` and the ``code``. - **Link**: users receive a verification link that carries a ``verification_token`` and must call ``POST /verify`` with ``provider`` and the ``verification_token`` (often done by following the link). **Response:** The behavior of the response depends on the request parameters and server-side provider configuration (specifically, ``require_verification``). 1. **Successful Registration with Email Verification Required:** * This occurs if the provider has ``require_verification: true``. * If ``redirect_to`` is provided in the request: * A 302 redirect to the ``redirect_to`` URL occurs. * The redirect URL will include ``identity_id`` and ``verification_email_sent_at`` as query parameters. * If ``redirect_to`` is NOT provided: * A 201 Created response is returned with a JSON body: .. code-block:: json { "identity_id": "...", "verification_email_sent_at": "YYYY-MM-DDTHH:MM:SS.ffffffZ" } 2. **Successful Registration with Email Verification NOT Required (PKCE Flow):** * This occurs if the provider has ``require_verification: false``. The ``challenge`` parameter is mandatory in the request. 
* If ``redirect_to`` is provided in the request: * A 302 redirect to the ``redirect_to`` URL occurs. * The redirect URL will include ``code`` (the PKCE authorization code) and ``provider`` as query parameters. * If ``redirect_to`` is NOT provided: * A 201 Created response is returned with a JSON body: .. code-block:: json { "code": "...", "provider": "..." } 3. **Registration Failure:** * If ``redirect_on_failure`` is provided in the request and is an allowed URL: * A 302 redirect to the ``redirect_on_failure`` URL occurs. * The redirect URL will include ``error`` (a description of the error) and ``email`` (the submitted email) as query parameters. * Otherwise (no ``redirect_on_failure`` or it's not allowed): * An HTTP error response (e.g., 400 Bad Request, 500 Internal Server Error) is returned with a JSON body describing the error. For example: .. code-block:: json { "message": "Error description", "type": "ErrorType", "code": "ERROR_CODE" } **Common Error Scenarios:** * Missing ``provider`` in the request. * Missing ``challenge`` in the request when the provider has ``require_verification: false``. * Email already exists. * Invalid password (e.g., too short, if policies are enforced). POST /authenticate ------------------ Authenticate a user using email and password. **Request Body (JSON):** * ``email`` (string, required): The user's email address. * ``password`` (string, required): The user's password. * ``provider`` (string, required): The name of the provider to use: ``builtin::local_emailpassword`` * ``challenge`` (string, required): A PKCE code challenge. * ``redirect_to`` (string, optional): A URL to redirect to upon successful authentication. * ``redirect_on_failure`` (string, optional): A URL to redirect to if authentication fails. If not provided, but ``redirect_to`` is, ``redirect_to`` will be used as the fallback for failure redirection. 
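As a rough illustration of the request body described above, the following sketch assembles the arguments for a ``POST /authenticate`` call. The ``buildAuthenticateRequest`` helper, the port ``10702``, and the ``main`` branch name are assumptions made for this example and are not part of the Gel API; only the field names come from the parameter list.

```javascript
const crypto = require("node:crypto");

// Assumption: a local instance on port 10702 with a branch named "main".
const GEL_AUTH_BASE_URL = "http://localhost:10702/branch/main/ext/auth/";

// Hypothetical helper (not part of Gel): builds the URL and fetch() options
// for POST /authenticate from the documented body fields.
const buildAuthenticateRequest = ({ email, password, challenge, redirectTo }) => {
  const body = {
    provider: "builtin::local_emailpassword",
    email,
    password,
    challenge,
  };
  // redirect_to is optional; leave it out to receive a JSON response.
  if (redirectTo) body.redirect_to = redirectTo;
  return {
    url: new URL("authenticate", GEL_AUTH_BASE_URL).href,
    init: {
      method: "post",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify(body),
    },
  };
};

// PKCE pair: keep the verifier on your server, send only the challenge.
const verifier = crypto.randomBytes(32).toString("base64url");
const challenge = crypto
  .createHash("sha256")
  .update(verifier)
  .digest("base64url");

const request = buildAuthenticateRequest({
  email: "user@example.com",
  password: "correct horse battery staple",
  challenge,
});
console.log(request.url);
// A real call would then be: await fetch(request.url, request.init)
```

Building the request separately from sending it keeps the payload easy to inspect in tests. The ``verifier`` must be kept until the returned ``code`` is exchanged at the token endpoint.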
**Response:** The behavior of the response depends on the request parameters and the outcome of the authentication attempt. 1. **Successful Authentication:** * A PKCE authorization code is generated and associated with the user's session. * If ``redirect_to`` is provided in the request: * A 302 redirect to the ``redirect_to`` URL occurs. * The redirect URL will include a ``code`` (the PKCE authorization code) as a query parameter. * If ``redirect_to`` is NOT provided: * A 200 OK response is returned with a JSON body: .. code-block:: json { "code": "..." } 2. **Authentication Failure (e.g., invalid credentials, user not found):** * If ``redirect_on_failure`` (or ``redirect_to`` as a fallback) is provided in the request and is an allowed URL: * A 302 redirect to this URL occurs. * The redirect URL will include ``error`` (a description of the error) and ``email`` (the submitted email) as query parameters. * Otherwise (no applicable redirect URL or it's not allowed): * An HTTP error response (e.g., 400, 401) is returned with a JSON body describing the error. For example: .. code-block:: json { "message": "Invalid credentials", "type": "InvalidCredentialsError", "code": "INVALID_CREDENTIALS" } 3. **Email Verification Required:** * This occurs if the provider is configured with ``require_verification: true`` and the user has not yet verified their email address. * The response follows the same logic as **Authentication Failure**: * If ``redirect_on_failure`` (or ``redirect_to``) is provided, a redirect occurs with an error like "VerificationRequired". * Otherwise, an HTTP error (often 403 Forbidden) is returned with a JSON body indicating that email verification is required. **Common Error Scenarios:** * Missing required fields in the request: ``email``, ``password``, ``provider``, or ``challenge``. * Invalid email or password. * User account does not exist. * User account exists but email is not verified (if ``require_verification: true`` for the provider). 
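The response cases above can be handled along these lines when no redirect URL is supplied. This is a minimal sketch: ``interpretAuthenticateResponse`` and ``buildTokenUrl`` are hypothetical helpers, and the base URL (port ``10702``, branch ``main``) is an assumed local setup. Note that the example error body also carries a ``code`` key, so the sketch checks the HTTP status before treating ``code`` as a PKCE authorization code.

```javascript
// Assumption: a local instance on port 10702 with a branch named "main".
const AUTH_BASE_URL = "http://localhost:10702/branch/main/ext/auth/";

// Hypothetical helper: classify a JSON response from POST /authenticate.
// Error bodies may also contain a "code" key (an error code), so the
// status is checked first.
const interpretAuthenticateResponse = (status, json) => {
  if (status >= 200 && status < 300 && typeof json.code === "string") {
    return { kind: "pkce_code", code: json.code };
  }
  return {
    kind: "error",
    status,
    type: json.type ?? "Unknown",
    message: json.message ?? "",
  };
};

// Hypothetical helper: build the token-exchange URL from the PKCE code
// and the verifier that matches the challenge sent earlier.
const buildTokenUrl = (code, verifier) => {
  const url = new URL("token", AUTH_BASE_URL);
  url.searchParams.set("code", code);
  url.searchParams.set("verifier", verifier);
  return url.href;
};

const success = interpretAuthenticateResponse(200, { code: "abc123" });
const failure = interpretAuthenticateResponse(401, {
  message: "Invalid credentials",
  type: "InvalidCredentialsError",
  code: "INVALID_CREDENTIALS",
});

if (success.kind === "pkce_code") {
  console.log(buildTokenUrl(success.code, "example-verifier"));
}
console.log(failure.kind, failure.type);
```

A "verification required" failure arrives through the same error branch, so an application can inspect ``type`` there to decide whether to prompt the user to verify their email.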
POST /send-reset-email ---------------------- Send a password reset email to a user. **Request Body (JSON):** * ``provider`` (string, required): The name of the provider: ``builtin::local_emailpassword``. * ``email`` (string, required): The email address of the user requesting the password reset. * ``reset_url`` (string, required): The base URL for the password reset page (used for the Link method). The ``reset_token`` will be appended as a query parameter. This URL must be an allowed redirect URI in the server configuration. * ``challenge`` (string, required): A PKCE code challenge. For the Link method it is embedded in the ``reset_token``; for the Code method it can be re-used later when completing the reset to obtain a PKCE code. * ``redirect_to`` (string, optional): A URL to redirect to after the reset email has been successfully queued for sending. * ``redirect_on_failure`` (string, optional): A URL to redirect to if there's an error during the process. If not provided, but ``redirect_to`` is, ``redirect_to`` will be used as the fallback for failure redirection. .. note:: The email sent depends on your provider's configuration: - **Link**: a reset link is sent containing a ``reset_token``; the user should then call ``POST /reset-password`` with this token. - **Code**: a one-time code is sent to the email address; the user should then call ``POST /reset-password`` with ``email`` and ``code`` (and optionally ``challenge`` to receive a PKCE code). **Response:** The endpoint always attempts to respond in a way that does not reveal whether an email address is registered or not. 1. **Reset Email Queued (or User Not Found):** * If the user exists, a password reset email is generated and sent. * If the user does not exist, the server simulates a successful send to prevent email enumeration attacks. * If ``redirect_to`` is provided in the request: * A 302 redirect to the ``redirect_to`` URL occurs. 
* The redirect URL will include ``email_sent`` (the email address provided in the request) as a query parameter. * If ``redirect_to`` is NOT provided: * A 200 OK response is returned with a JSON body: .. code-block:: json { "email_sent": "user@example.com" } 2. **Failure (e.g., ``reset_url`` not allowed, SMTP server error):** * This occurs for errors not related to whether the user exists, such as configuration issues or mail server problems. * If ``redirect_on_failure`` (or ``redirect_to`` as a fallback) is provided in the request and is an allowed URL: * A 302 redirect to this URL occurs. * The redirect URL will include ``error`` (a description of the error) and ``email`` (the submitted email) as query parameters. * Otherwise (no applicable redirect URL or it's not allowed): * An HTTP error response (e.g., 400 Bad Request, 500 Internal Server Error) is returned with a JSON body describing the error. **Common Error Scenarios (leading to the Failure response):** * Missing required fields in the request: ``provider``, ``email``, ``reset_url``, or ``challenge``. * The provided ``reset_url`` is not in the server's list of allowed redirect URIs. * Internal server error during email dispatch (e.g., SMTP configuration issues). POST /reset-password -------------------- Resets a user's password using a reset token and a new password. This endpoint completes the password reset flow initiated by ``POST /send-reset-email``. **Request Body (JSON):** * ``provider`` (string, required): The name of the provider: ``builtin::local_emailpassword``. * ``password`` (string, required): The new password for the user's account. Choose one of the following modes: - **Token mode (Link method)** * ``reset_token`` (string, required): The token that was emailed to the user. - **Code mode** * ``email`` (string, required): The user's email address. * ``code`` (string, required): The one-time code sent by email. 
* ``challenge`` (string, optional): If provided, a PKCE authorization code will be generated upon success. Optional for both modes: * ``redirect_to`` (string, optional): A URL to redirect to after the password has been successfully reset. If provided and a PKCE code is generated, it will be appended as a query parameter. * ``redirect_on_failure`` (string, optional): A URL to redirect to if the password reset process fails. If not provided, but ``redirect_to`` is, ``redirect_to`` will be used as the fallback. **Response:** - **Token mode (Link method)** * The ``reset_token`` is validated, and the user's password is updated. * A PKCE authorization ``code`` is generated using the challenge embedded in the token. * If ``redirect_to`` is provided, a 302 redirect occurs with ``code`` appended; otherwise, a 200 OK JSON response is returned with ``{"code": "..."}``. - **Code mode** * The ``email``/``code`` are validated, and the user's password is updated. * If a ``challenge`` is provided, a PKCE authorization ``code`` is generated. * If ``redirect_to`` is provided and a PKCE code was generated, a 302 redirect occurs with ``code`` appended; if ``challenge`` was not provided, a 200 OK JSON response is returned with ``{"status": "password_reset"}``. - **Failure (invalid inputs or server error)** * If ``redirect_on_failure`` (or ``redirect_to`` as a fallback) is provided and is an allowed URL, a 302 redirect occurs with an ``error`` parameter (and submitted ``reset_token``/``email`` where applicable). * Otherwise, an HTTP error response is returned with a JSON error body (e.g., 400, 403, 500). **Common Error Scenarios:** * Missing required fields in the request: ``provider``, ``reset_token``, or ``password``. * The ``reset_token`` is malformed, has an invalid signature, or is expired. * Internal server error during the password update process. Email verification ================== These endpoints apply to the Email and password provider, as well as the WebAuthn provider. 
Verification emails are sent even if you do not *require* verification. The difference between requiring verification and not is that if you require verification, the user must verify their email before they can authenticate. If you do not require verification, the user can authenticate without verifying their email. POST /verify ------------ Verify a user's email address. Supports both Link and Code methods. **Request Body (JSON):** * ``provider`` (string, required): The provider name, e.g., ``builtin::local_emailpassword`` or ``builtin::local_webauthn``. Choose exactly one verification mode: - **Link mode** * ``verification_token`` (string, required): The JWT sent to the user (typically via an email link) to verify their email. - **Code mode** * ``email`` (string, required): The user's email address to verify. * ``code`` (string, required): The one-time code sent via email. * ``challenge`` (string, optional, also accepts ``code_challenge``): If provided, a PKCE authorization code will be generated upon success. * ``redirect_to`` (string, optional): If provided, a redirect response will be sent upon success. This URL must be in the server's list of allowed redirect URIs. **Response:** - **Link mode** The primary action is to validate the ``verification_token`` and mark the associated email as verified. The exact response depends on the contents of the ``verification_token`` (it may include a PKCE challenge and/or a redirect URL specified during its creation): 1. With challenge and redirect URL in token * A PKCE authorization code is generated using the challenge from the token. * A 302 redirect to the URL specified in the token (``maybe_redirect_to``) occurs, with ``code`` appended as a query parameter. 2. With challenge only in token * A PKCE authorization code is generated using the challenge from the token. * A 200 OK response is returned with a JSON body: .. code-block:: json { "code": "generated_pkce_code" } 3. 
With redirect URL only in token * A 302 redirect to the URL specified in the token (``maybe_redirect_to``) occurs (no ``code`` is added). 4. No challenge or redirect URL in token * A 204 No Content response is returned. 5. Invalid or expired token * A 403 Forbidden response is returned with a JSON body (e.g., token expired). - **Code mode** After validating ``email`` and ``code`` and marking the email as verified, behavior depends on optional ``challenge`` and ``redirect_to``: 1. ``challenge`` and ``redirect_to`` provided * A PKCE authorization code is generated and a 302 redirect to ``redirect_to`` occurs with ``code`` appended as a query parameter. 2. Only ``challenge`` provided * A PKCE authorization code is generated and a 200 OK response is returned with a JSON body: .. code-block:: json { "code": "generated_pkce_code" } 3. Only ``redirect_to`` provided * A 302 redirect to ``redirect_to`` occurs (no PKCE code is generated). 4. Neither provided * A 204 No Content response is returned. **Common Error Scenarios:** * Missing ``provider`` or ``verification_token`` in the request (results in HTTP 400). * The ``verification_token`` is malformed, has an invalid signature, or is expired (results in HTTP 403). * An internal error occurs while trying to update the email verification status (results in HTTP 500). POST /resend-verification-email ------------------------------- Resend a verification email to a user. This can be useful if the original email was lost or the token expired. **Request Body (JSON):** The request must include ``provider`` and a way to identify the user's email factor. * ``provider`` (string, required): The provider name, e.g., ``builtin::local_emailpassword`` or ``builtin::local_webauthn``. Then, choose **one** of the following methods to specify the user: * **Method 1: Using an existing Verification Token** * ``verification_token`` (string): An old (even expired) verification token. 
The system will extract necessary details (like ``identity_id``, original ``verify_url``, ``challenge``, and ``redirect_to``) from this token to generate a new one. * **Method 2: Using Email Address (for Email/Password provider)** * ``email`` (string, required if ``provider`` is ``builtin::local_emailpassword`` and ``verification_token`` is not used): The user's email address. * ``verify_url`` (string, optional): The base URL for the new verification link. Defaults to the server's configured UI verify path (e.g., ``/ui/verify``). * ``challenge`` (string, optional, also accepts ``code_challenge``): A PKCE code challenge to be embedded in the new verification token. * ``redirect_to`` (string, optional): A URL to redirect to after successful verification using the new token. This URL must be in the server's list of allowed redirect URIs. * **Method 3: Using WebAuthn Credential ID (for WebAuthn provider)** * ``credential_id`` (string, required if ``provider`` is ``builtin::local_webauthn`` and ``verification_token`` is not used): The Base64 encoded WebAuthn credential ID. * ``verify_url`` (string, optional): As above. * ``challenge`` (string, optional, also accepts ``code_challenge``): As above. * ``redirect_to`` (string, optional): As above. This URL must be in the server's list of allowed redirect URIs. **Response:** The endpoint aims to prevent email enumeration by always returning a successful status code if the request format is valid, regardless of whether the user or email factor was found. 1. **Verification Email Queued (or User/Email Factor Not Found):** * If the user/email factor is found, a new verification email with a fresh token is generated and sent. * If the user/email factor is not found (based on the provided identifier), the server simulates a successful send. * A 200 OK response is returned. The response body is typically empty. 2. 
**Failure (Invalid Request or Server Error):** * If the request is malformed (e.g., unsupported ``provider``, ``redirect_to`` URL not allowed, missing required fields for the chosen identification method), an HTTP 400 Bad Request with a JSON error body is returned. * If an internal server error occurs (e.g., SMTP issues), an HTTP 500 Internal Server Error with a JSON error body is returned. **Common Error Scenarios:** * Unsupported ``provider`` name. * Missing ``verification_token`` when it's the chosen method, or missing ``email`` / ``credential_id`` for other methods. * Providing a ``redirect_to`` URL that is not in the allowed list. * Internal SMTP errors preventing email dispatch. .. note:: If the provider uses the **Code** verification method, the resend email will contain a one-time code instead of a link. In this case, ``verify_url``, ``challenge``, and ``redirect_to`` are not included in the email and are only relevant for the Link method. OAuth ===== POST /authorize --------------- Initiate an OAuth authorization flow. **Request Parameters (Query String):** * ``provider`` (string, required): The name of the OAuth provider to use (e.g., ``builtin::oauth::google``). * ``redirect_to`` (string, required): The URL to redirect to after a successful OAuth flow completes and a PKCE code is obtained. This URL must be in the server's list of allowed redirect URIs. * ``challenge`` (string, required, also accepts ``code_challenge``): A PKCE code challenge generated by your application. * ``redirect_to_on_signup`` (string, optional): An alternative URL to redirect to after a *new* user successfully completes the OAuth flow. If not provided, ``redirect_to`` will be used for both new and existing users. This URL must also be in the server's list of allowed redirect URIs. * ``callback_url`` (string, optional): The URL the OAuth provider should redirect back to after the user authorizes the application. If not provided, it defaults to ``/callback``. 
This URL must be in the server's list of allowed redirect URIs. **Response:** 1. **Successful Authorization Initiation:** * The server generates a PKCE challenge record and prepares for the OAuth flow. * A 302 Found redirect response is returned. * The ``Location`` header will contain the authorization URL provided by the external OAuth identity provider. The user's browser will be directed to this URL to begin the OAuth provider's authentication/authorization process. **Common Error Scenarios:** * Missing required fields in the query string: ``provider``, ``redirect_to``, or ``challenge``. * The provided ``redirect_to``, ``redirect_to_on_signup``, or ``callback_url`` is not in the server's list of allowed redirect URIs. * Configuration error on the server (e.g., the specified provider is not configured). POST /callback -------------- Handle the redirect from the OAuth provider. This endpoint is typically called by the OAuth provider after the user has completed the authentication and authorization process on the provider's site. It processes the response from the provider, exchanges the authorization code for Gel session information (and potentially provider tokens), and redirects the user back to the application. This endpoint accepts parameters either in the query string (for GET requests) or in the request body as ``application/x-www-form-urlencoded`` (for POST requests). **Request Parameters (Query String or Form Data):** * ``state`` (string, required): The state parameter originally sent in the ``POST /authorize`` request. This is a signed JWT containing information needed to complete the flow (like provider name, redirect URLs, and the PKCE challenge). * ``code`` (string, optional): The authorization code provided by the OAuth identity provider. This is present on successful authorization. * ``error`` (string, optional): An error code provided by the OAuth identity provider, if authorization failed. 
* ``error_description`` (string, optional): A human-readable description of the error provided by the OAuth identity provider. **Response:** 1. **Successful Callback and Token Exchange:** * This occurs when the OAuth provider returns a ``code``, and the ``state`` is valid. * The server exchanges the OAuth code for identity information and potentially provider access/refresh tokens. * The identity is linked to the PKCE challenge provided in the original ``state``. * A 302 Found redirect response is returned. * The ``Location`` header will contain the ``redirect_to`` (or ``redirect_to_on_signup`` if applicable) URL specified in the original ``state`` parameter. * The redirect URL will include the Gel PKCE authorization ``code`` and the ``provider`` name as query parameters (e.g., ``https://app.example.com/success?code=gel_pkce_code&provider=oauth_provider_name``). This PKCE code can then be exchanged for a session token via ``POST /token``. 2. **OAuth Provider Returned an Error:** * This occurs when the OAuth provider redirects back with an ``error`` parameter. * A 302 Found redirect response is returned. * The ``Location`` header will contain the ``redirect_to`` URL specified in the original ``state`` parameter. * The redirect URL will include the ``error`` and optionally ``error_description`` and the user's ``email`` (if available and relevant) as query parameters. **Common Error Scenarios (before redirect):** * Missing ``state`` parameter in the request. * Invalid or malformed ``state`` token. * The OAuth provider did not return either a ``code`` or an ``error``. * Errors during the server's exchange of the OAuth code with the provider (these typically result in an HTTP error response from this endpoint rather than a redirect with an error). WebAuthn ======== POST /webauthn/register ----------------------- Register a new WebAuthn credential for a user. This typically follows a call to ``GET /webauthn/register/options`` where the registration options were obtained. 
**Request Body (JSON):** * ``provider`` (string, required): The name of the WebAuthn provider to use: ``builtin::local_webauthn``. * ``challenge`` (string, required): A PKCE code challenge. This challenge will be linked to the identity upon successful registration if email verification is not required. * ``email`` (string, required): The user's email address associated with the WebAuthn credential. * ``credentials`` (string, required): The credential data obtained from the client-side WebAuthn API (``navigator.credentials.create()``). This should be a JSON string. * ``verify_url`` (string, required): The base URL for the email verification link that will be emailed to the user if email verification is required. * ``user_handle`` (string, optional): The Base64 URL encoded user handle generated during the options request. This can also be passed via a cookie named ``edgedb-webauthn-registration-user-handle``. **Request Cookies:** * ``edgedb-webauthn-registration-user-handle`` (string, optional): The Base64 URL encoded user handle generated during the options request. If present, this overrides the ``user_handle`` in the request body. **Response:** The response depends on whether the WebAuthn provider is configured to require email verification or not. 1. **Successful Registration with Email Verification Required:** * A 201 Created response is returned with a JSON body: .. code-block:: json { "identity_id": "...", "verification_email_sent_at": "YYYY-MM-DDTHH:MM:SS.ffffffZ" } * The ``edgedb-webauthn-registration-user-handle`` cookie is cleared. 2. **Successful Registration with Email Verification NOT Required (PKCE Flow):** * A 201 Created response is returned with a JSON body: .. code-block:: json { "code": "...", "provider": "builtin::local_webauthn" } * The ``edgedb-webauthn-registration-user-handle`` cookie is cleared. The returned ``code`` can be exchanged for a session token at the ``POST /token`` endpoint. 
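As a rough illustration of the two success shapes described above, a client could branch on which keys appear in the 201 response body. This is a minimal sketch, not part of any Gel client library; the function name ``next_step_after_register`` and the labels it returns are hypothetical and exist only for this example.

```python
from typing import Any, Mapping


def next_step_after_register(body: Mapping[str, Any]) -> str:
    """Classify a 201 response body from POST /webauthn/register.

    The returned labels are hypothetical, chosen for illustration only.
    """
    # PKCE flow: a code is returned that can be exchanged at POST /token.
    if "code" in body:
        return "exchange_code"
    # Verification flow: the server reports when the email was sent.
    if "verification_email_sent_at" in body:
        return "await_email_verification"
    raise ValueError("unrecognized registration response")
```

Checking for ``code`` first mirrors the documented bodies: only the PKCE response carries a ``code``, and only the verification response carries ``verification_email_sent_at``.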
**Common Error Scenarios:** * Missing required fields in the request body or user handle (either in body or cookie). * Invalid or malformed ``credentials`` or ``user_handle`` data. * The specified ``verify_url`` is not in the server's list of allowed redirect URIs. * Errors during the WebAuthn registration process on the server (e.g., credential already registered). * Configuration error on the server (e.g., WebAuthn provider not configured). POST /webauthn/authenticate --------------------------- Authenticate a user using an existing WebAuthn credential. This typically follows a call to ``GET /webauthn/authenticate/options`` where the authentication options were obtained. **Request Body (JSON):** * ``provider`` (string, required): The name of the WebAuthn provider to use: ``builtin::local_webauthn``. * ``challenge`` (string, required): A PKCE code challenge. This challenge will be linked to the authenticated identity upon successful authentication. * ``email`` (string, required): The user's email address associated with the WebAuthn credential they are attempting to use. * ``assertion`` (string, required): The assertion data obtained from the client-side WebAuthn API (``navigator.credentials.get()``). This should be a JSON string. **Response:** 1. **Successful Authentication:** * This occurs when the provided ``assertion`` successfully verifies the user's identity based on the provided ``email``. * If email verification is required for the provider, the user's email must also be verified. * A PKCE authorization ``code`` is generated and linked to the authenticated identity using the provided ``challenge``. * A 200 OK response is returned with a JSON body: .. code-block:: json { "code": "..." } * The returned ``code`` can be exchanged for a session token at the ``POST /token`` endpoint. 2. 
**Authentication Failure:** * This occurs if the provided ``assertion`` does not match the registered credential for the given email, the email is not found, or if email verification is required but the email is not verified. * An HTTP error response (e.g., 401 Unauthorized or 403 Forbidden) is returned with a JSON body describing the error (e.g., "Failed to authenticate WebAuthn", "VerificationRequired"). **Common Error Scenarios:** * Missing required fields in the request body: ``challenge``, ``email``, or ``assertion``. * Invalid or malformed ``assertion`` data. * No WebAuthn credential found for the provided email. * WebAuthn authentication failed (e.g., invalid signature). * Email verification is required for the provider, but the user's email is not verified. * Configuration error on the server (e.g., WebAuthn provider not configured). GET /webauthn/register/options ------------------------------ Get the necessary options from the server to initiate a WebAuthn registration ceremony on the client side (using ``navigator.credentials.create()``). **Request Parameters (Query String):** * ``email`` (string, required): The user's email address for whom registration options are being requested. **Response:** 1. **Successful Options Retrieval:** * A 200 OK response is returned. * The ``Content-Type`` header is ``application/json``. * The response body contains a JSON object with the WebAuthn registration options, compatible with the Web Authentication API (``PublicKeyCredentialCreationOptions``). * A cookie named ``edgedb-webauthn-registration-user-handle`` is set containing the Base64 URL encoded user handle generated by the server. This cookie is needed for the subsequent ``POST /webauthn/register`` request. **Common Error Scenarios:** * Missing required ``email`` query parameter. * Configuration error on the server (e.g., WebAuthn provider not configured). * Errors during the generation of registration options on the server. 
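The options request above takes a single ``email`` query parameter, which needs to be percent-encoded. The following sketch builds that URL with the Python standard library. The helper name ``register_options_url`` and the example base URL are assumptions for illustration; substitute the auth extension base URL of your own instance.

```python
from urllib.parse import urlencode


def register_options_url(auth_base_url: str, email: str) -> str:
    """Build the URL for GET /webauthn/register/options.

    auth_base_url is assumed to be the auth extension's base URL,
    for example "http://localhost:10702/db/main/ext/auth" (hypothetical).
    """
    # urlencode percent-encodes reserved characters such as "@".
    query = urlencode({"email": email})
    return f"{auth_base_url.rstrip('/')}/webauthn/register/options?{query}"
```

The response to this request also sets the ``edgedb-webauthn-registration-user-handle`` cookie, so an HTTP client used for the follow-up ``POST /webauthn/register`` call should preserve cookies or pass ``user_handle`` in the body.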
GET /webauthn/authenticate/options ---------------------------------- Get the necessary options from the server to initiate a WebAuthn authentication ceremony on the client side (using ``navigator.credentials.get()``). **Request Parameters (Query String):** * ``email`` (string, required): The user's email address for whom authentication options are being requested. The server will look up associated WebAuthn credentials based on this email. **Response:** 1. **Successful Options Retrieval:** * A 200 OK response is returned. * The ``Content-Type`` header is ``application/json``. * The response body contains a JSON object with the WebAuthn authentication options, compatible with the Web Authentication API (``PublicKeyCredentialRequestOptions``). These options will include information about the user's registered credentials to challenge the client. **Common Error Scenarios:** * Missing required ``email`` query parameter. * Configuration error on the server (e.g., WebAuthn provider not configured). * Errors during the generation of authentication options on the server (e.g., no credentials found for the email). Magic link ========== POST /magic-link/register ------------------------- Registers a new user with a magic link credential and sends a magic link email to their email address. **Request Body (JSON or application/x-www-form-urlencoded):** The required fields depend on the provider's verification method. - **Code method** * ``email`` (string, required): The user's email address. * ``redirect_to`` (string, optional): A URL to redirect to after the email has been queued. If omitted, the request must accept ``application/json``. - **Link method** * ``email`` (string, required): The user's email address. * ``challenge`` (string, required): A PKCE code challenge that will be embedded in the magic link token. * ``callback_url`` (string, required): The URL that the user will be redirected to after clicking the magic link in the email. 
A PKCE authorization ``code`` will be appended to this URL. This URL must be in the server's list of allowed redirect URIs. * ``redirect_on_failure`` (string, required): A URL to redirect to if there's an error during the registration or email sending process. Error details will be appended as query parameters. This URL must be in the server's list of allowed redirect URIs. * ``redirect_to`` (string, optional): A URL to redirect to *after* the server has successfully queued the email for sending (before the user clicks the link). If provided, a JSON response will not be returned, and parameters like ``email_sent`` (or ``code=true`` in Code method) will be appended as query parameters. This URL must be in the server's list of allowed redirect URIs. * ``link_url`` (string, optional): The base URL for the magic link itself (the endpoint the link in the email will point to). If not provided, it defaults to ``/magic-link/authenticate``. This URL must be in the server's list of allowed redirect URIs. **Response:** The endpoint attempts to prevent email enumeration by always returning a success status if the request format is valid. - **Code method** * If the request accepts ``application/json`` and ``redirect_to`` is not provided, a 200 OK JSON response is returned: .. code-block:: json { "code": "true", "signup": "true", "email": "user@example.com" } * If ``redirect_to`` is provided, a 302 Found redirect occurs to ``redirect_to`` with ``code=true``, ``signup=true`` and ``email`` as query parameters. - **Link method** * If the request accepts ``application/json`` and ``redirect_to`` is not provided, a 200 OK JSON response is returned: .. code-block:: json { "email_sent": "user@example.com" } * If ``redirect_to`` is provided, a 302 Found redirect occurs to ``redirect_to`` with ``email_sent`` as a query parameter. 
- **Failure** * If an error occurs before a redirect would occur and the request accepts JSON, an HTTP error response (e.g., 400 Bad Request) is returned with a JSON body. * Otherwise, if ``redirect_on_failure`` was provided (Link method), a 302 Found redirect occurs to that URL with ``error`` and ``email`` query parameters. **Common Error Scenarios (leading to failure responses):** * Missing required fields in the request body: ``provider``, ``email``, ``challenge``, ``callback_url``, or ``redirect_on_failure``. * The provided ``callback_url``, ``redirect_on_failure``, ``redirect_to``, or ``link_url`` is not in the server's list of allowed redirect URIs. * Unsupported ``provider`` name. * Internal server error during email dispatch (e.g., SMTP issues). POST /magic-link/email ---------------------- Sends a magic link email to a user with an *existing* magic link credential. This is similar to ``POST /magic-link/register`` but does not attempt to create a new identity if the email is not found (though it still simulates a successful send to prevent enumeration). **Request Body (JSON or application/x-www-form-urlencoded):** The required fields depend on the provider's verification method. - **Code method** * ``email`` (string, required): The user's email address. * ``redirect_to`` (string, optional): A URL to redirect to after the email has been queued. If omitted, the response will be JSON. - **Link method** * ``email`` (string, required): The user's email address. * ``challenge`` (string, required): A PKCE code challenge that will be embedded in the magic link token. * ``callback_url`` (string, required): The URL that the user will be redirected to after clicking the magic link in the email. A PKCE authorization ``code`` will be appended to this URL. This URL must be in the server's list of allowed redirect URIs. * ``redirect_on_failure`` (string, required): A URL to redirect to if there's an error during the email sending process. 
Error details will be appended as query parameters. This URL must be in the server's list of allowed redirect URIs. * ``redirect_to`` (string, optional): A URL to redirect to *after* the server has successfully queued the email for sending (before the user clicks the link). If provided, a JSON response will not be returned. * ``link_url`` (string, optional): The base URL for the magic link itself. If not provided, it defaults to ``/magic-link/authenticate``. This URL must be in the server's list of allowed redirect URIs. **Response:** The endpoint attempts to prevent email enumeration by always returning a success status if the request format is valid, even if the email address is not found. - **Code method** * If ``redirect_to`` is NOT provided, a 200 OK JSON response is returned: .. code-block:: json { "code": "true", "email": "user@example.com" } * If ``redirect_to`` is provided, a 302 Found redirect occurs to the ``redirect_to`` URL with ``code=true`` and ``email`` as query parameters. - **Link method** * If ``redirect_to`` is NOT provided, a 200 OK JSON response is returned: .. code-block:: json { "email_sent": "user@example.com" } * If ``redirect_to`` is provided, a 302 Found redirect occurs to the ``redirect_to`` URL with ``email_sent`` as a query parameter. - **Failure** * If an error happens and a ``redirect_on_failure`` URL was provided (Link method), a 302 Found redirect is returned to that URL with ``error`` and the submitted ``email`` as query parameters. Otherwise, an HTTP error response is returned with a JSON body. **Common Error Scenarios (leading to failure responses):** * Missing required fields in the request body: ``provider``, ``email``, ``challenge``, ``callback_url``, or ``redirect_on_failure``. * The provided ``callback_url``, ``redirect_on_failure``, ``redirect_to``, or ``link_url`` is not in the server's list of allowed redirect URIs. * Unsupported ``provider`` name. * Internal server error during email dispatch (e.g., SMTP issues). 
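To make the Link method fields above concrete, here is a minimal sketch that generates a PKCE verifier and challenge and assembles the JSON request body. It assumes the server expects the standard S256 transform from RFC 7636 (the challenge is the unpadded base64url encoding of the SHA-256 hash of the verifier), which this section does not state explicitly. The function names are hypothetical, and ``provider`` is passed in rather than hardcoded because the provider name is not given here.

```python
import base64
import hashlib
import secrets


def make_pkce_pair() -> tuple[str, str]:
    """Return (verifier, challenge), assuming the S256 method of RFC 7636."""
    # 32 random bytes, base64url-encoded without padding.
    verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).rstrip(b"=").decode()
    digest = hashlib.sha256(verifier.encode("ascii")).digest()
    challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
    return verifier, challenge


def magic_link_email_body(
    provider: str,
    email: str,
    challenge: str,
    callback_url: str,
    redirect_on_failure: str,
) -> dict[str, str]:
    """Assemble the Link-method body for POST /magic-link/email."""
    return {
        "provider": provider,
        "email": email,
        "challenge": challenge,
        # Both URLs must be in the server's allowed redirect list.
        "callback_url": callback_url,
        "redirect_on_failure": redirect_on_failure,
    }
```

The application keeps the verifier private (for example in a server-side session) and sends only the challenge; the verifier is needed later when exchanging the PKCE ``code`` for a session token.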
POST /magic-link/authenticate ----------------------------- Authenticates a user by validating a magic link token received from an email. This endpoint is typically the target of the magic link URL sent to the user. This endpoint supports both Link and Code methods. **Link method (Query String):** * ``token`` (string, required): The magic link token (a signed JWT) extracted from the magic link URL. This token contains the identity ID, the original PKCE challenge, and the callback URL. * ``redirect_on_failure`` (string, optional): A URL to redirect to if the authentication process fails (e.g., invalid or expired token). Error details will be appended as query parameters. If not provided, an HTTP error response will be returned on failure. **Code method (JSON body):** * ``email`` (string, required): The user's email address. * ``code`` (string, required): The one-time code sent via email. * ``callback_url`` (string, required): The URL to redirect to after successful authentication. Must be an allowed redirect URI. * ``challenge`` (string, required): A PKCE code challenge. A PKCE authorization ``code`` will be generated upon success. **Response:** - **Link method** * If the provided ``token`` is valid, the user's email factor is marked as verified and a PKCE authorization ``code`` is generated using the challenge embedded in the token. A 302 Found redirect is returned to the token's ``callback_url`` with ``code`` appended. * On failure, if ``redirect_on_failure`` is provided, a 302 redirect occurs to that URL with an ``error`` parameter; otherwise, an HTTP error response is returned with a JSON body. - **Code method** * On success, the one-time code is validated, the email factor is marked as verified, and a PKCE authorization ``code`` is generated using the provided ``challenge``. A 302 Found redirect occurs to ``callback_url`` with ``code`` appended. 
* On failure, if a ``redirect_on_failure`` query parameter is present, a 302 redirect occurs to that URL with an ``error`` parameter; otherwise, a 400 Bad Request JSON response is returned with an error body. **Common Error Scenarios (leading to failure responses):** * Missing required ``token`` query parameter. * The provided ``token`` is malformed, has an invalid signature, or is expired. * Internal server error during the authentication or email verification process. * The ``callback_url`` extracted from the token is not in the server's list of allowed redirect URIs (this should ideally be caught earlier, but could potentially manifest here). ================================================ FILE: docs/reference/auth/index.rst ================================================ .. _ref_guide_auth: ==== Auth ==== .. toctree:: :hidden: :maxdepth: 3 http built_in_ui email_password oauth magic_link webauthn webhooks :edb-alt-title: Using Gel Auth |Gel| Auth is a batteries-included authentication solution for your app built into the Gel server. Here's how you can integrate it with your app. Enable extension in your schema =============================== Auth is a Gel extension. To enable it, you will need to add the extension to your app's schema: .. code-block:: sdl using extension auth; Extension configuration ======================= The best and easiest way to configure the extension for your database is to use the built-in UI. To access it, run :gelcmd:`ui`. If you have the extension enabled in your schema as shown above and have migrated that schema change, you will see the "Auth Admin" icon in the left-hand toolbar. .. image:: images/ui-auth.png :alt: The Gel local development server UI highlighting the auth admin icon in the left-hand toolbar. The icon is two nested shield outlines, the inner being a light pink color and the outer being a light blue when selected. 
:width: 100% The auth admin UI exposes these values: app_name -------- The name of your application to be shown on the login screen when using the built-in UI. logo_url -------- A URL to an image of your logo. This is also used to customize the built-in UI. dark_logo_url ------------- A URL to an image of your logo for use with a dark theme. This is also used to customize the built-in UI. brand_color ----------- Your brand color as a hex string. This will be used as the accent color in the built-in UI. auth_signing_key ---------------- The extension uses JSON Web Tokens (JWTs) internally for many operations. ``auth_signing_key`` is the value that is used as a symmetric key for signing the JWTs. At the moment, the JWTs are not considered "public" API, so there is no need to save this value for your own application use. It is exposed mainly to allow rotation. To configure via query or script: .. lint-off .. code-block:: edgeql CONFIGURE CURRENT BRANCH SET ext::auth::AuthConfig::auth_signing_key := 'F2KHaJfHi9Dzd8+6DI7FB9IFIoJXnhz2rzG/UzCRE7jTtYxqgTHHydc8xnN6emDB3tlR99FvPsyJfcVLVcQ5odSQpceDXplBOP+N14+EBy2mV6rA/7W7azIEKebtr9TVKrpBTMTOLAXo08ZnA6lvjn0VMs95za6Pta7VW62hjcb8jy6yxulvvU5SWnwa0x2z401K0pLK7byDD5eNqgTl40YaeOGoQ0iCkSmGxvLxyQgCIz2IU0zUbBwC9bQsTDORvflunruJznHuMxwbfYo/czQIIGuawU0H+G3GJZ3hecZLQlvwYCyLF37PFQVrcNMtUuGyDy2OyYtYHru2GW5B7Q'; .. lint-on token_time_to_live ------------------ This value controls the expiration time on the authentication token's JSON Web Token. This is effectively the "session" time. To configure via query or script: .. code-block:: edgeql CONFIGURE CURRENT BRANCH SET ext::auth::AuthConfig::token_time_to_live := <duration>"336 hours"; allowed_redirect_urls --------------------- This value is a set of strings that we use to ensure we only redirect to domains that are under the control of the application using the Auth extension. We compare any ``redirect_to`` URLs against this list.
A URL is considered a "match" if the URL is exactly the same as one on the list, or is a sub-path of a URL on the list. For example, if the set includes ``https://example.com/myapp``: .. list-table:: :header-rows: 1 * - URL - Match * - ``https://example.com/myapp`` - ✅ * - ``https://example.com/myapp/auth`` - ✅ * - ``https://example.com/myapp/auth/verify`` - ✅ * - ``https://example.com/myapp/somewhere/else`` - ✅ * - ``http://example.com/myapp`` - Does not match the protocol * - ``https://example.com:443/myapp`` - Does not match the port * - ``https://auth.example.com/myapp`` - Does not match the subdomain * - ``https://example.com/different/subpath`` - Does not match the pathname or extend it .. note:: 💡 We always allow redirects to the auth extension itself, so you do not need to add it explicitly if, for instance, you are always using the built-in UI. To configure via query or script: .. code-block:: edgeql CONFIGURE CURRENT BRANCH SET ext::auth::AuthConfig::allowed_redirect_urls := { 'https://example.com', 'https://example.com/auth', 'https://localhost:3000', 'https://localhost:3000/auth' }; Webhooks ======== The auth extension supports sending webhooks for a variety of auth events. You can use these webhooks to, for instance, send a fully customized email for email verification, or password reset instead of our built-in email verification and password reset emails. You could also use them to trigger analytics events, start an email drip campaign, create an audit log, or trigger other side effects in your application. See the :ref:`webhooks documentation ` for more details on how to configure and use webhooks. Configuring SMTP ================ For email-based factors, you can configure SMTP to allow the extension to send emails on your behalf. You should either configure SMTP, or webhooks for the relevant events. The easiest way to configure SMTP is to use the built-in UI. 
Here is an example of configuring SMTP for local development using EdgeQL directly, using something like `Mailpit `__. .. note:: Gel Cloud users, rejoice! If you are using Gel Cloud, you can use the built-in development SMTP provider without any configuration. This special provider is already configured for development usage and is ready to send emails while you are developing your application. This provider is tuned specifically for development: it is rate limited and the sender is hardcoded. Do not use it in production, it will not work for that purpose. .. code-block:: edgeql # Create a new SMTP provider: # configure current branch insert cfg::SMTPProviderConfig { # This name must be unique and is used to reference the provider name := 'local_mailpit', sender := '"Display Name" ', host := 'localhost', port := 1025, username := 'smtpuser', password := 'smtppassword', security := 'STARTTLSOrPlainText', validate_certs := false, timeout_per_email := '60 seconds', timeout_per_attempt := '15 seconds', }; # Set this provider as the current email provider by name: # configure current branch set current_email_provider_name := 'local_mailpit'; .. note:: The ``sender`` property follows the `RFC 5322 `_ specification, so you can include a display name in the email address or use a bare email address. Including a display name is recommended as it provides a more user-friendly experience. Email clients will show the display name (e.g., "Display Name" from the example above) instead of just the raw email address in the sender field, making your emails appear more professional and trustworthy to recipients. Enabling authentication providers ================================= In order to use the auth extension, you'll need to enable at least one of these authentication providers. Providers can be added from the "Providers" section of the admin auth UI by clicking "Add Provider." 
This will add a form to the UI allowing for selection of the provider and configuration of the values described below. You can also enable providers via query. We'll demonstrate how in each section below. .. _ref_guide_auth_overview_email_password: Email and password ------------------ - ``require_verification``: (Default: ``true``) If ``true``, your application will not be able to retrieve an authentication token until the user has verified their email. If ``false``, your application can retrieve an authentication token, but a verification email will still be sent. Regardless of this setting, you can always decide to limit access or specific features in your application by testing if ``ext::auth::EmailPasswordFactor.verified_at`` is set to a date in the past on the ``ext::auth::LocalIdentity``. To enable via query or script: .. code-block:: edgeql CONFIGURE CURRENT BRANCH INSERT ext::auth::EmailPasswordProviderConfig { require_verification := false, }; .. note:: ``require_verification`` defaults to ``true``. If you use the Email and Password provider, in addition to the ``require_verification`` configuration, you'll need to configure SMTP to allow |Gel| to send email verification and password reset emails on your behalf or set up webhooks for the relevant events: - ``ext::auth::WebhookEvent.EmailVerificationRequested`` - ``ext::auth::WebhookEvent.PasswordResetRequested`` Here is an example of setting a local SMTP server, in this case using a product called `Mailpit `__ which is great for testing in development: .. code-block:: edgeql CONFIGURE CURRENT BRANCH INSERT cfg::SMTPProviderConfig { sender := 'hello@example.com', host := 'localhost', port := 1025, security := 'STARTTLSOrPlainText', validate_certs := false, }; Here is an example of setting up webhooks for the email verification and password reset events: .. 
code-block:: edgeql CONFIGURE CURRENT BRANCH INSERT ext::auth::WebhookConfig { url := 'https://example.com/auth/webhook', events := { ext::auth::WebhookEvent.EmailVerificationRequested, ext::auth::WebhookEvent.PasswordResetRequested, } }; OAuth ----- We currently support six different OAuth providers: .. lint-off - `Apple `__ - `Azure (Microsoft) `__ - `GitHub `__ - `Google `__ - `Discord `__ - `Slack `__ .. lint-on The instructions for creating an app for each provider can be found on each provider's developer documentation website, which is linked above. The important things you'll need to find and make note of for your configuration are the **client ID** and **secret**. Once you select the OAuth provider in the configuration UI, you will need to provide those values and the ``additional_scope``: - ``client_id`` This is assigned to you by the Identity Provider when you create an app with them. - ``secret`` This is created by the Identity Provider when you create an app with them. - ``additional_scope`` We request certain scope from the Identity Provider to fulfill our minimal data needs. You can pass additional scope here in a space-separated string and we will request that additional scope when getting the authentication token from the Identity Provider. .. note:: We return this authentication token with this scope from the Identity Provider when we return our own authentication token. You'll also need to set a callback URL in each provider's interface. To build this callback URL, you will need the hostname, port, and branch name of your database. The branch name is |main| by default. The hostname and port can be found running this CLI command: .. code-block:: bash $ gel instance credentials This will output a table that includes the hostnames and ports of all your instances. Grab those from the row corresponding to the correct instance for use in your callback URL, which takes on this format: .. 
code-block:: http[s]://{gel_host}[:port]/db/{db_name}/ext/auth/callback To enable the Azure OAuth provider via query or script: .. code-block:: edgeql CONFIGURE CURRENT BRANCH INSERT ext::auth::AzureOAuthProvider { secret := 'cccccccccccccccccccccccccccccccc', client_id := '1597b3fc-b67d-4d2b-b38f-acc256341dbc', additional_scope := 'offline_access', }; To enable any of the others, change ``AzureOAuthProvider`` in the example above to one of the other providers: - ``AppleOAuthProvider`` - ``DiscordOAuthProvider`` - ``GitHubOAuthProvider`` - ``GoogleOAuthProvider`` - ``SlackOAuthProvider`` Generic OpenID Connect providers -------------------------------- .. versionadded:: 6.0 Generic OpenID Connect providers are now supported. In order to use them, you will need to insert an ``ext::auth::OpenIDConnectProvider`` configuration object with a few additional properties: - ``name``: A unique string identifying the provider. - ``display_name``: A human-readable name for the provider. - ``issuer_url``: The issuer URL of the provider. This must be the domain of the provider's authorization server and will be used to drive the OpenID Connect flow. - ``logo_url``: (optional) A URL to an image of the provider's logo. This is used in the built-in UI to display the correct logo. Inherited from ``ext::auth::OAuthProviderConfig``: - ``client_id``: The client ID of the provider. - ``secret``: The client secret of the provider. - ``additional_scope``: (optional) A space-separated string of additional scopes to request from the provider. Here is an example of enabling the Google OpenID Connect provider (note, for Google, you can simply use the existing Google provider, but this is for illustration purposes): .. lint-off .. 
code-block:: edgeql CONFIGURE CURRENT BRANCH INSERT ext::auth::OpenIDConnectProvider { name := 'google', display_name := 'Google', issuer_url := 'https://accounts.google.com', logo_url := 'https://www.google.com/images/branding/googlelogo/1x/googlelogo_color_272x92dp.png', client_id := '1234567890', secret := '1234567890', }; .. lint-on Magic link ---------- Magic link offers the following settings: - ``verification_method``: ``Link`` (default) or ``Code``. - ``Link``: users receive a link and are redirected back with a PKCE ``code``. - ``Code``: users receive a one-time code. Collect the code and call ``POST /magic-link/authenticate`` with ``email``, ``code``, ``callback_url``, and the PKCE ``challenge`` to receive a PKCE ``code``. - ``token_time_to_live``: determines how long a magic link (or one-time code) remains valid after sending. Since magic links rely on email, you must also configure SMTP or webhooks. For local testing, you can use the same method used for SMTP previously for :ref:`the email and password provider `. Here is an example of setting a local SMTP server, in this case using a product called `Mailpit `__ which is great for testing in development: .. code-block:: edgeql CONFIGURE CURRENT BRANCH INSERT cfg::SMTPProviderConfig { sender := 'hello@example.com', host := 'localhost', port := 1025, security := 'STARTTLSOrPlainText', validate_certs := false, }; Here is an example of setting up webhooks for the magic link events: .. code-block:: edgeql CONFIGURE CURRENT BRANCH INSERT ext::auth::WebhookConfig { url := 'https://example.com/auth/webhook', events := { ext::auth::WebhookEvent.MagicLinkRequested, } }; WebAuthn -------- - ``relying_party_origin``: This is the URL of the web application handling the WebAuthn request. If you're using the built-in UI, it's the origin of the Gel web server. 
- ``require_verification``: (Default: ``true``) If ``true``, your application will not be able to retrieve an authentication token until the user has verified their email. If ``false``, your application can retrieve an authentication token, but a verification email will still be sent. Regardless of this setting, you can always decide to limit access or specific features in your application by testing if ``ext::auth::WebAuthnFactor.verified_at`` is set to a date in the past on the ``ext::auth::LocalIdentity``. .. note:: You will need to configure SMTP or webhooks. For local testing, you can use Mailpit as described in :ref:`the email/password section `. .. note:: You will need to configure CORS to allow the client-side script to call the Gel Auth extension's endpoints from the web browser. You can do this by updating the ``cors_allow_origins`` configuration in the Gel server configuration. Here is an example of setting a local SMTP server, in this case using a product called `Mailpit `__ which is great for testing in development: .. code-block:: edgeql CONFIGURE CURRENT BRANCH INSERT cfg::SMTPProviderConfig { sender := 'hello@example.com', host := 'localhost', port := 1025, security := 'STARTTLSOrPlainText', validate_certs := false, }; Here is an example of setting up webhooks for the WebAuthn events: .. code-block:: edgeql CONFIGURE CURRENT BRANCH INSERT ext::auth::WebhookConfig { url := 'https://example.com/auth/webhook', events := { ext::auth::WebhookEvent.EmailVerificationRequested, } }; Integrating your application ============================ In the end, what we want to end up with is an authentication token created by Gel that we can set as a global in any authenticated queries executed from our application, which will set a computed global linked to an ``ext::auth::Identity``. .. note:: 💡 If you want your own ``User`` type that contains application specific information like name, preferences, etc, you can link to this ``ext::auth::Identity`` to do so. 
You can then use the ``ext::auth::Identity`` (or custom ``User`` type) to define access policies and make authenticated queries. Select your method for detailed configuration: .. toctree:: :maxdepth: 3 built_in_ui email_password oauth magic_link webauthn Example usage ============= Here's an example schema that we can use to show how you would use the ``auth_token`` you get back from Gel to make queries against a protected resource, in this case being able to insert a ``Post``. .. code-block:: sdl using extension auth; module default { global current_user := ( assert_single(( select User filter .identity = global ext::auth::ClientTokenIdentity )) ); type User { required name: str; required identity: ext::auth::Identity; } type Post { required text: str; required author: User; access policy author_has_full_access allow all using (.author ?= global current_user); access policy others_read_only allow select; } } Let's now insert a ``Post``. .. lint-off .. code-block:: tsx const client = createClient().withGlobals({ "ext::auth::client_token": auth_token, }); const inserted = await client.querySingle( ` insert Post { text := <str>$text, author := global current_user, }`, { text: 'if your grave doesnt say "rest in peace" on it you are automatically drafted into the skeleton war' } ); .. lint-on The same client can also delete the post, since the author has full access through the ``current_user`` global: .. code-block:: tsx await client.query(`delete Post filter .id = <uuid>$id`, { id: inserted.id }); ================================================ FILE: docs/reference/auth/magic_link.rst ================================================ .. _ref_guide_auth_magic_link: ================ Magic Link Auth ================ :edb-alt-title: Integrating Gel Auth's Magic Link provider Magic Link is a passwordless authentication method that allows users to log in via a unique, time-sensitive link sent to their email. This guide will walk you through integrating Magic Link authentication with your application using Gel Auth.
Enable Magic Link provider ========================== Before you can use Magic Link authentication, you need to enable the Magic Link provider in your Gel Auth configuration. This can be done through the Gel UI under the "Providers" section. Magic Link flow =============== The Magic Link authentication flow involves three main steps: 1. **Sending a Magic Link Email**: Your application requests Gel Auth to send a magic link to the user's email. 2. **User Clicks Magic Link**: The user receives the email and clicks on the magic link. 3. **Authentication and Token Retrieval**: The magic link directs the user to your application, which then authenticates the user and retrieves an authentication token from Gel Auth. .. note:: Your Magic Link provider can be configured to send either a **Link** or a one-time **Code**: - **Code**: Instead of clicking a link, users receive a one-time code. Prompt for the code and call ``POST /magic-link/authenticate`` with ``email``, ``code``, ``callback_url`` and the PKCE ``challenge``. On success, you will be redirected to ``callback_url`` with a PKCE ``code`` you can exchange at ``POST /token``. - **Link**: Users click a link that includes a token. Your app handles the token-based callback as shown below. UI considerations ================= Similar to how the built-in UI works, you can query the database configuration to discover which providers are configured and dynamically build the UI. .. code-block:: edgeql select cfg::Config.extensions[is ext::auth::AuthConfig].providers { name, [is ext::auth::OAuthProviderConfig].display_name, }; The ``name`` is a unique string that identifies the Identity Provider. OAuth providers also have a ``display_name`` that you can use as a label for links or buttons. Example implementation ====================== We will demonstrate the various steps below by building a NodeJS HTTP server in a single file that we will use to simulate a typical web application. .. 
note:: The details below show the inner workings of how data is exchanged with the Auth extension from a web app using HTTP. You can use this as a guide to integrate with your application written in any language that can send and receive HTTP requests. Start the PKCE flow ------------------- We secure authentication tokens and other sensitive data by using PKCE (Proof Key for Code Exchange). Your application server generates 32 random bytes and Base64 URL-encodes them (43 characters after encoding); this string is called the ``verifier``. You need to store this value for the duration of the flow. One way to keep this state is to set an HttpOnly cookie when the browser starts the flow, and then read the value back from the cookie store at the end of the flow. Take this ``verifier`` string, hash it with SHA256, and then base64url encode the resulting digest. This new string is called the ``challenge``. .. lint-off .. code-block:: javascript import http from "node:http"; import { URL } from "node:url"; import crypto from "node:crypto"; /** * You can get this value by running `gel instance credentials`. * Value should be: * `${protocol}://${host}:${port}/branch/${branch}/ext/auth/` */ const GEL_AUTH_BASE_URL = process.env.GEL_AUTH_BASE_URL; const SERVER_PORT = 3000; /** * Generate a random Base64 url-encoded string, and derive a "challenge" * string from that string to use as proof that the request for a token * later is made from the same user agent that made the original request * * @returns {Object} The verifier and challenge strings */ const generatePKCE = () => { const verifier = crypto.randomBytes(32).toString("base64url"); const challenge = crypto .createHash("sha256") .update(verifier) .digest("base64url"); return { verifier, challenge }; }; .. lint-on Routing ------- Let's set up the routes we will use to handle the magic link authentication flow. We will then detail each route handler in the following sections. .. lint-off .. 
code-block:: javascript const server = http.createServer(async (req, res) => { const requestUrl = getRequestUrl(req); switch (requestUrl.pathname) { case "/auth/magic-link/callback": { await handleCallback(req, res); break; } case "/auth/magic-link/signup": { await handleSignUp(req, res); break; } case "/auth/magic-link/send": { await handleSendMagicLink(req, res); break; } default: { res.writeHead(404); res.end("Not found"); break; } } }); .. lint-on Sign up ------- .. lint-off .. code-block:: javascript /** * Send magic link to new user's email for sign up. * * @param {Request} req * @param {Response} res */ const handleSignUp = async (req, res) => { let body = ""; req.on("data", (chunk) => { body += chunk.toString(); }); req.on("end", async () => { const pkce = generatePKCE(); const { email, provider } = JSON.parse(body); if (!email || !provider) { res.statusCode = 400; res.end( `Request body malformed. Expected JSON body with 'email' and 'provider' keys, but got: ${body}`, ); return; } const registerUrl = new URL("magic-link/register", GEL_AUTH_BASE_URL); const callbackUrl = new URL("auth/magic-link/callback", `http://localhost:${SERVER_PORT}`); const registerResponse = await fetch(registerUrl.href, { method: "post", headers: { "Content-Type": "application/json", }, body: JSON.stringify({ challenge: pkce.challenge, email, provider, callback_url: callbackUrl.href, // The following endpoint will be called if there is an error // processing the magic link, such as expiration or malformed token, // etc. redirect_on_failure: `http://localhost:${SERVER_PORT}/auth_error.html`, }), }); if (!registerResponse.ok) { const text = await registerResponse.text(); res.statusCode = 400; res.end(`Error from the auth server: ${text}`); return; } res.writeHead(204, { "Set-Cookie": `gel-pkce-verifier=${pkce.verifier}; HttpOnly; Path=/; Secure; SameSite=Strict`, }); res.end(); }); }; .. 
lint-on Sign in ------- Signing in with a magic link simply involves telling the Gel Auth server to send a magic link to the user's email. The user will then click on the link to authenticate. .. lint-off .. code-block:: javascript /** * Send magic link to existing user's email for sign in. * * @param {Request} req * @param {Response} res */ const handleSendMagicLink = async (req, res) => { let body = ""; req.on("data", (chunk) => { body += chunk.toString(); }); req.on("end", async () => { const pkce = generatePKCE(); const { email, provider } = JSON.parse(body); if (!email || !provider) { res.statusCode = 400; res.end( `Request body malformed. Expected JSON body with 'email' and 'provider' keys, but got: ${body}`, ); return; } const emailUrl = new URL("magic-link/email", GEL_AUTH_BASE_URL); const callbackUrl = new URL("auth/magic-link/callback", `http://localhost:${SERVER_PORT}`); const authenticateResponse = await fetch(emailUrl.href, { method: "post", headers: { "Content-Type": "application/json", }, body: JSON.stringify({ challenge: pkce.challenge, email, provider, callback_url: callbackUrl.href, }), }); if (!authenticateResponse.ok) { const text = await authenticateResponse.text(); res.statusCode = 400; res.end(`Error from the auth server: ${text}`); return; } res.writeHead(204, { "Set-Cookie": `gel-pkce-verifier=${pkce.verifier}; HttpOnly; Path=/; Secure; SameSite=Strict`, }); res.end(); }); }; .. lint-on Callback -------- Once the user clicks on the magic link (Link method), they will be redirected back to your application with a ``code`` query parameter. Your application will then exchange this code for an authentication token. If the provider uses the Code method, you will instead collect the one-time code from the user and call ``POST /magic-link/authenticate`` with ``email``, ``code``, ``callback_url`` and the PKCE ``challenge``. On success you will be redirected to ``callback_url`` with a PKCE ``code`` to exchange at ``POST /token``. .. lint-off .. 
code-block:: javascript /** * Handles the PKCE callback and exchanges the `code` and `verifier` * for an auth_token, setting the auth_token as an HttpOnly cookie. * * @param {Request} req * @param {Response} res */ const handleCallback = async (req, res) => { const requestUrl = getRequestUrl(req); const code = requestUrl.searchParams.get("code"); if (!code) { const error = requestUrl.searchParams.get("error"); res.statusCode = 400; res.end( `Magic link callback is missing 'code'. Provider responded with error: ${error}`, ); return; } const cookies = req.headers.cookie?.split("; "); const verifier = cookies ?.find((cookie) => cookie.startsWith("gel-pkce-verifier=")) ?.split("=")[1]; if (!verifier) { res.statusCode = 400; res.end( `Could not find 'verifier' in the cookie store. Is this the same user agent/browser that started the authorization flow?`, ); return; } const codeExchangeUrl = new URL("token", GEL_AUTH_BASE_URL); codeExchangeUrl.searchParams.set("code", code); codeExchangeUrl.searchParams.set("verifier", verifier); const codeExchangeResponse = await fetch(codeExchangeUrl.href, { method: "GET", }); if (!codeExchangeResponse.ok) { const text = await codeExchangeResponse.text(); res.statusCode = 400; res.end(`Error from the auth server: ${text}`); return; } const { auth_token } = await codeExchangeResponse.json(); res.writeHead(204, { "Set-Cookie": `gel-auth-token=${auth_token}; HttpOnly; Path=/; Secure; SameSite=Strict`, }); res.end(); }; .. lint-on Create a User object -------------------- For some applications, you may want to create a custom ``User`` type in the default module to attach application-specific information. You can tie this to an ``ext::auth::Identity`` by using the ``identity_id`` returned during the sign-up flow. .. note:: For this example, we'll assume you have a one-to-one relationship between ``User`` objects and ``ext::auth::Identity`` objects. In your own application, you may instead decide to have a one-to-many relationship.
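The route handlers in this guide call ``getRequestUrl(req)`` without defining it. A minimal sketch of such a helper, assuming a Node.js ``http.IncomingMessage`` and plain HTTP during local development (the fallback host is a placeholder chosen for illustration, not something the Auth extension requires):

```javascript
/**
 * Hypothetical helper (not part of Gel Auth): build an absolute URL object
 * from an incoming Node.js request so handlers can read `pathname` and
 * `searchParams`. `req.url` only holds the path and query string, so a
 * base URL has to be supplied.
 */
const getRequestUrl = (req) => {
  // Assumption: local development over HTTP. Behind a TLS-terminating
  // proxy the protocol would have to be derived differently.
  const host = req.headers.host ?? "localhost:3000";
  return new URL(req.url, `http://${host}`);
};
```

With this in place, ``getRequestUrl(req).searchParams.get("code")`` returns the PKCE ``code`` in the callback handler, and ``pathname`` drives the routing ``switch``.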
Given this ``User`` type: .. code-block:: sdl type User { email: str; name: str; required identity: ext::auth::Identity { constraint exclusive; }; } We need to update two parts of the sign-up flow. First, we need to signal to the callback that this particular callback is for a sign-up, which we do by setting the ``isSignUp`` query parameter to ``true``. Second, we need to create a new ``User`` object and attach it to the ``ext::auth::Identity`` object. .. tabs:: .. code-tab:: javascript-diff :caption: handleSignUp const handleSignUp = async (req, res) => { let body = ""; req.on("data", (chunk) => { body += chunk.toString(); }); req.on("end", async () => { const pkce = generatePKCE(); const { email, provider } = JSON.parse(body); if (!email || !provider) { res.statusCode = 400; res.end( `Request body malformed. Expected JSON body with 'email' and 'provider' keys, but got: ${body}`, ); return; } const registerUrl = new URL("magic-link/register", GEL_AUTH_BASE_URL); const callbackUrl = new URL("auth/magic-link/callback", `http://localhost:${SERVER_PORT}`); + callbackUrl.searchParams.set("isSignUp", "true"); const registerResponse = await fetch(registerUrl.href, { method: "post", headers: { "Content-Type": "application/json", }, body: JSON.stringify({ challenge: pkce.challenge, email, provider, callback_url: callbackUrl.href, // The following endpoint will be called if there is an error // processing the magic link, such as expiration or malformed token, // etc. redirect_on_failure: `http://localhost:${SERVER_PORT}/auth_error.html`, }), }); if (!registerResponse.ok) { const text = await registerResponse.text(); res.statusCode = 400; res.end(`Error from the auth server: ${text}`); return; } res.writeHead(204, { "Set-Cookie": `gel-pkce-verifier=${pkce.verifier}; HttpOnly; Path=/; Secure; SameSite=Strict`, }); res.end(); }); }; .. 
code-tab:: javascript-diff :caption: handleCallback const handleCallback = async (req, res) => { const requestUrl = getRequestUrl(req); const code = requestUrl.searchParams.get("code"); if (!code) { const error = requestUrl.searchParams.get("error"); res.statusCode = 400; res.end( `Magic link callback is missing 'code'. Provider responded with error: ${error}`, ); return; } const cookies = req.headers.cookie?.split("; "); const verifier = cookies ?.find((cookie) => cookie.startsWith("gel-pkce-verifier=")) ?.split("=")[1]; if (!verifier) { res.statusCode = 400; res.end( `Could not find 'verifier' in the cookie store. Is this the same user agent/browser that started the authorization flow?`, ); return; } const codeExchangeUrl = new URL("token", GEL_AUTH_BASE_URL); codeExchangeUrl.searchParams.set("code", code); codeExchangeUrl.searchParams.set("verifier", verifier); const codeExchangeResponse = await fetch(codeExchangeUrl.href, { method: "GET", }); if (!codeExchangeResponse.ok) { const text = await codeExchangeResponse.text(); res.statusCode = 400; res.end(`Error from the auth server: ${text}`); return; } - const { auth_token } = await codeExchangeResponse.json(); + const { + auth_token, + identity_id + } = await codeExchangeResponse.json(); + if (requestUrl.searchParams.get("isSignUp") === "true") { + await client.query(` + with + identity := <ext::auth::Identity><uuid>$identity_id, + emailFactor := ( + select ext::auth::EmailFactor filter .identity = identity + ), + insert User { + email := emailFactor.email, + identity := identity + }; + `, { identity_id }); + } + res.writeHead(204, { "Set-Cookie": `gel-auth-token=${auth_token}; HttpOnly; Path=/; Secure; SameSite=Strict`, }); res.end(); }; :ref:`Back to the Gel Auth guide ` ================================================ FILE: docs/reference/auth/oauth.rst ================================================ .. 
_ref_guide_auth_oauth: ===== OAuth ===== :edb-alt-title: Integrating Gel Auth's OAuth provider Along with using the :ref:`built-in UI `, you can also create your own UI that calls your own web application backend. UI considerations ================= Similar to how the built-in UI works, you can query the database configuration to discover which providers are configured and dynamically build the UI. .. code-block:: edgeql select cfg::Config.extensions[is ext::auth::AuthConfig].providers { name, [is ext::auth::OAuthProviderConfig].display_name, }; The ``name`` is a unique string that identifies the Identity Provider. OAuth providers also have a ``display_name`` that you can use as a label for links or buttons. In later steps, you'll be providing this ``name`` as the ``provider`` in various endpoints. Example implementation ====================== We will demonstrate the various steps below by building a NodeJS HTTP server in a single file that we will use to simulate a typical web application. .. note:: We are in the process of publishing helper libraries that you can use with popular languages and web frameworks. The details below show the inner workings of how data is exchanged with the Auth extension from a web app using HTTP. You can use this as a guide to integrate with your application written in any language that can send and receive HTTP requests. We secure authentication tokens and other sensitive data by using PKCE (Proof Key for Code Exchange). Start the PKCE flow ------------------- Your application server generates 32 random bytes and Base64 URL-encodes them (43 characters after encoding); this string is called the ``verifier``. You need to store this value for the duration of the flow. One way to keep this state is to set an HttpOnly cookie when the browser starts the flow, and then read the value back from the cookie store at the end of the flow.
Take this ``verifier`` string, hash it with SHA256, and then base64url encode the resulting digest. This new string is called the ``challenge``. .. note:: Since ``=`` is not a URL-safe character, if your Base64-URL encoding function adds padding, you should remove the padding before hashing the ``verifier`` to derive the ``challenge`` or when providing the ``verifier`` or ``challenge`` in your requests. .. note:: If you are familiar with PKCE, you will notice some differences from how RFC 7636 defines PKCE. Our authentication flow is not an OAuth flow, but rather a strict server-to-server flow with Proof Key for Code Exchange added for additional security to avoid leaking the authentication token. Here are some differences between PKCE as defined in RFC 7636 and our implementation: - We do not support the ``plain`` value for ``code_challenge_method``, and therefore do not read that value if provided in requests. - Our parameters omit the ``code_`` prefix; however, we do support ``code_challenge`` and ``code_verifier`` as aliases, preferring ``challenge`` and ``verifier`` if present. .. code-block:: javascript import http from "node:http"; import { URL } from "node:url"; import crypto from "node:crypto"; /** * You can get this value by running `gel instance credentials`. * Value should be: * `${protocol}://${host}:${port}/branch/${branch}/ext/auth/` */ const GEL_AUTH_BASE_URL = process.env.GEL_AUTH_BASE_URL; const SERVER_PORT = 3000; /** * Generate a random Base64 url-encoded string, and derive a "challenge" * string from that string to use as proof that the request for a token * later is made from the same user agent that made the original request * * @returns {Object} The verifier and challenge strings */ const generatePKCE = () => { const verifier = crypto.randomBytes(32).toString("base64url"); const challenge = crypto .createHash("sha256") .update(verifier) .digest("base64url"); return { verifier, challenge }; }; .. 
note:: For |EdgeDB| versions before 5.0, the value for :gelenv:`AUTH_BASE_URL` in the above snippet should have the form: ``${protocol}://${host}:${port}/db/${database}/ext/auth/`` Redirect users to Identity Provider ----------------------------------- Next, we implement a route at ``/auth/authorize`` that the application should link to when signing in with a particular Identity Provider. We will redirect the end user's browser to the Identity Provider with the proper setup. .. lint-off .. code-block:: javascript const server = http.createServer(async (req, res) => { const requestUrl = getRequestUrl(req); switch (requestUrl.pathname) { case "/auth/authorize": { await handleAuthorize(req, res); break; } case "/auth/callback": { await handleCallback(req, res); break; } default: { res.writeHead(404); res.end("Not found"); break; } } }); /** * Redirects OAuth requests to Gel Auth OAuth authorize redirect * with the PKCE challenge, and saves PKCE verifier in an HttpOnly * cookie for later retrieval. * * @param {Request} req * @param {Response} res */ const handleAuthorize = async (req, res) => { const requestUrl = getRequestUrl(req); const provider = requestUrl.searchParams.get("provider"); if (!provider) { res.statusCode = 400; res.end("Must provide a 'provider' value in search parameters"); return; } const pkce = generatePKCE(); const redirectUrl = new URL("authorize", GEL_AUTH_BASE_URL); redirectUrl.searchParams.set("provider", provider); redirectUrl.searchParams.set("challenge", pkce.challenge); redirectUrl.searchParams.set( "redirect_to", `http://localhost:${SERVER_PORT}/auth/callback` ); redirectUrl.searchParams.set( "redirect_to_on_signup", `http://localhost:${SERVER_PORT}/auth/callback?isSignUp=true` ); res.writeHead(302, { "Set-Cookie": `gel-pkce-verifier=${pkce.verifier}; HttpOnly; Path=/; Secure; SameSite=Strict`, Location: redirectUrl.href, }); res.end(); }; .. 
lint-on Retrieve ``auth_token`` ----------------------- At the very end of the flow, the Gel server will redirect the user's browser to the ``redirect_to`` address with a single query parameter: ``code``. This route should be a server route that has access to the ``verifier``. You then take that ``code`` and look up the ``verifier`` in the ``gel-pkce-verifier`` cookie (``edgedb-pkce-verifier`` with |EdgeDB| <= 5), and make a request to the Gel Auth extension to exchange these two pieces of data for an ``auth_token``. .. lint-off .. code-block:: javascript /** * Handles the PKCE callback and exchanges the `code` and `verifier` * for an auth_token, setting the auth_token as an HttpOnly cookie. * * @param {Request} req * @param {Response} res */ const handleCallback = async (req, res) => { const requestUrl = getRequestUrl(req); const code = requestUrl.searchParams.get("code"); if (!code) { const error = requestUrl.searchParams.get("error"); res.statusCode = 400; res.end( `OAuth callback is missing 'code'. OAuth provider responded with error: ${error}` ); return; } const cookies = req.headers.cookie?.split("; "); const verifier = cookies ?.find((cookie) => cookie.startsWith("gel-pkce-verifier=")) ?.split("=")[1]; if (!verifier) { res.statusCode = 400; res.end( `Could not find 'verifier' in the cookie store. Is this the same user agent/browser that started the authorization flow?` ); return; } const codeExchangeUrl = new URL("token", GEL_AUTH_BASE_URL); codeExchangeUrl.searchParams.set("code", code); codeExchangeUrl.searchParams.set("verifier", verifier); const codeExchangeResponse = await fetch(codeExchangeUrl.href, { method: "GET", }); if (!codeExchangeResponse.ok) { const text = await codeExchangeResponse.text(); res.statusCode = 400; res.end(`Error from the auth server: ${text}`); return; } const { auth_token } = await codeExchangeResponse.json(); res.writeHead(204, { "Set-Cookie": `gel-auth-token=${auth_token}; HttpOnly; Path=/; Secure; SameSite=Strict`, }); res.end(); }; .. 
lint-on Creating a User object ---------------------- For some applications, you may want to create a custom ``User`` type in the default module to attach application-specific information. You can tie this to an ``ext::auth::Identity`` by using the ``auth_token`` in our ``ext::auth::client_token`` global and inserting your ``User`` object with a link to the ``Identity``. .. note:: For this example, we'll assume you have a one-to-one relationship between ``User`` objects and ``ext::auth::Identity`` objects. In your own application, you may instead decide to have a one-to-many relationship. Given this ``User`` type: .. code-block:: sdl type User { email: str; name: str; required identity: ext::auth::Identity { constraint exclusive; }; } You can update the callback function like this to create a new ``User`` object when the callback succeeds. Recall that in our ``handleAuthorize`` route handler, we added a separate callback route for when the extension adds a new Identity which sets a search parameter on the URL to ``isSignUp=true``: .. code-block:: javascript-diff const { auth_token } = await codeExchangeResponse.json(); + + const isSignUp = requestUrl.searchParams.get("isSignUp"); + if (isSignUp === "true") { + const authedClient = client.withGlobals({ + "ext::auth::client_token": auth_token, + }); + await authedClient.query(` + insert User { + identity := (global ext::auth::ClientTokenIdentity) + }; + `); + } + res.writeHead(204, { "Set-Cookie": `gel-auth-token=${auth_token}; HttpOnly; Path=/; Secure; SameSite=Strict`, }); Using an OpenID Connect id_token -------------------------------- For some providers that implement OpenID Connect, we also return an ``id_token`` in the response. This token will have been validated by the extension to ensure that it has been signed by the provider and that the token has not expired. You can use this token to get additional information about the user from the provider to enrich your ``User`` object. .. 
code-block:: javascript-diff - const { auth_token } = await codeExchangeResponse.json(); + const { auth_token, id_token } = await codeExchangeResponse.json(); const isSignUp = requestUrl.searchParams.get("isSignUp"); if (isSignUp === "true") { + const { email, name, locale } = id_token ? await decodeJwt(id_token) : { email: null, name: null, locale: null }; const authedClient = client.withGlobals({ "ext::auth::client_token": auth_token, }); await authedClient.query(` insert User { + email := <optional str>$email, + name := <optional str>$name, + locale := <optional str>$locale, identity := (global ext::auth::ClientTokenIdentity) }; - `); + `, { email, name, locale }); } The snippet above assumes a ``decodeJwt`` helper (for example, from a JWT library) and a ``locale`` property added to the ``User`` type. Making authenticated requests to the OAuth resource server ---------------------------------------------------------- Along with the ``auth_token`` which represents the authenticated user's identity within your system, for OAuth providers, we also return a ``provider_token`` (and optionally a ``provider_refresh_token``) that you can use to make requests to the OAuth provider's resource server on behalf of the user. Here is an example of getting the user's profile information from Google utilizing OpenID Connect and the ``provider_token``: .. code-block:: javascript /** * Get the user's profile information from Google */ async function getUserProfile(providerToken) { const discoveryResponse = await fetch( "https://accounts.google.com/.well-known/openid-configuration" ); const discoveryDocument = await discoveryResponse.json(); const userInfoResponse = await fetch(discoveryDocument.userinfo_endpoint, { headers: { Authorization: `Bearer ${providerToken}`, Accept: "application/json", }, }); return await userInfoResponse.json(); } Then in our callback handler, we can use the ``provider_token`` to get the user's profile information and save it into our ``User`` object when we create it: .. 
code-block:: javascript-diff - const { auth_token } = await codeExchangeResponse.json(); + const { auth_token, provider_token } = await codeExchangeResponse.json(); const isSignUp = requestUrl.searchParams.get("isSignUp"); if (isSignUp === "true") { + const profile = await getUserProfile(provider_token); const authedClient = client.withGlobals({ "ext::auth::client_token": auth_token, }); await authedClient.query( ` + with + email := <optional str>$email, + name := <optional str>$name, insert User { + email := email, + name := name, identity := (global ext::auth::ClientTokenIdentity) }; - `); + `, + { email: profile.email, name: profile.name } + ); } res.writeHead(204, { "Set-Cookie": `gel-auth-token=${auth_token}; HttpOnly; Path=/; Secure; SameSite=Strict`, }); :ref:`Back to the Gel Auth guide ` ================================================ FILE: docs/reference/auth/webauthn.rst ================================================ .. _ref_guide_auth_webauthn: ======== WebAuthn ======== :edb-alt-title: Integrating Gel Auth's WebAuthn provider WebAuthn, short for Web Authentication, is a web standard published by the World Wide Web Consortium (W3C) for secure and passwordless authentication on the web. It allows users to log in using biometrics, mobile devices, or FIDO2 security keys instead of traditional passwords. This guide will walk you through integrating WebAuthn authentication with your application using Gel Auth. Why choose WebAuthn? ==================== WebAuthn provides a more secure and user-friendly alternative to passwords and SMS-based OTPs. By leveraging public key cryptography, it significantly reduces the risk of phishing, man-in-the-middle, and replay attacks. For application developers, integrating WebAuthn can enhance security while improving the user experience with seamless, passwordless logins. What is a Passkey? 
================== While WebAuthn focuses on authenticating users through cryptographic credentials, Passkeys extend this concept by enabling users to easily access their credentials across devices, including those they haven't used before, without the need for a password. Passkeys are built on the WebAuthn framework and aim to simplify the user experience further by leveraging cloud synchronization of credentials. Many operating systems and password managers have added support for Passkeys, making it easier for users to manage their credentials across devices. Gel Auth's WebAuthn provider supports Passkeys, allowing users to log in to your application using their Passkeys. Security considerations ======================= For maximum flexibility, Gel Auth's WebAuthn provider allows multiple WebAuthn credentials per email. This means that it's very important to verify the email before trusting a WebAuthn credential. This can be done by setting the ``require_verification`` option to ``true`` (which is the default) in your WebAuthn provider configuration. Alternatively, you can check the verification status of the factor directly. WebAuthn flow ============= The WebAuthn authentication flow involves a coordinated effort between the server and a client-side script. Unlike the other authentication methods outlined elsewhere in this guide, it requires a client-side script that accesses web browser APIs, specifically the Web Authentication API, to interact with the user's authenticator device or passkey. At a high level, the sign-up ceremony involves the following steps: 1. The user initiates the sign-up process by providing their email address. 2. The server generates a JSON object that is used to configure the WebAuthn registration ceremony. 3. The client takes that JSON object, and using the Web Authentication API, interacts with the user's authenticator device to create a new credential. 4. 
The client sends the credential back to the server. 5. The server verifies the credential and associates it with the user's email address. The sign-in ceremony is similar, but instead of creating a new credential, the client uses the Web Authentication API to authenticate the user with an existing credential. Example implementation ====================== We will demonstrate the various steps below by building a NodeJS HTTP server in a single file that we will use to simulate a typical web application. .. note:: The details below show the inner workings of how data is exchanged with the Auth extension from a web app using HTTP. You can use this as a guide to integrate with your application written in any language that can send and receive HTTP requests. Start the PKCE flow ------------------- We secure authentication tokens and other sensitive data by using PKCE (Proof Key for Code Exchange). Your application server generates 32 random bytes and Base64 URL-encodes them (which yields a 43-character string), called the ``verifier``. You need to store this value for the duration of the flow. One way to accomplish this bit of state is to use an HttpOnly cookie when the browser makes a request to the server for this value, which you can then use to retrieve it from the cookie store at the end of the flow. Take this ``verifier`` string, hash it with SHA256, and then base64url encode the resulting digest. This new string is called the ``challenge``. .. lint-off .. code-block:: javascript import http from "node:http"; import { URL } from "node:url"; import crypto from "node:crypto"; /** * You can get this value by running `gel instance credentials`. 
* Value should be: * `${protocol}://${host}:${port}/branch/${branch}/ext/auth/` */ const GEL_AUTH_BASE_URL = process.env.GEL_AUTH_BASE_URL; const SERVER_PORT = 3000; /** * Generate a random Base64 url-encoded string, and derive a "challenge" * string from that string to use as proof that the request for a token * later is made from the same user agent that made the original request * * @returns {Object} The verifier and challenge strings */ const generatePKCE = () => { const verifier = crypto.randomBytes(32).toString("base64url"); const challenge = crypto .createHash("sha256") .update(verifier) .digest("base64url"); return { verifier, challenge }; }; .. lint-on Routing ------- Let's set up the routes we will use to handle the WebAuthn flow. We will then detail each route handler in the following sections. .. lint-off .. code-block:: javascript /** * In Node, `req.url` holds only the path and query string, so we build * a full URL from the request's Host header. */ const getRequestUrl = (req) => new URL(req.url, `http://${req.headers.host}`); const server = http.createServer(async (req, res) => { const requestUrl = getRequestUrl(req); switch (requestUrl.pathname) { case "/auth/webauthn/register/options": { await handleRegisterOptions(req, res); break; } case "/auth/webauthn/register": { await handleRegister(req, res); break; } case "/auth/webauthn/authenticate/options": { await handleAuthenticateOptions(req, res); break; } case "/auth/webauthn/authenticate": { await handleAuthenticate(req, res); break; } case "/auth/webauthn/verify": { await handleVerify(req, res); break; } default: { res.writeHead(404); res.end("Not found"); break; } } }); server.listen(SERVER_PORT, () => { console.log(`HTTP server listening on port ${SERVER_PORT}...`); }); .. lint-on Handle register and authenticate options ---------------------------------------- The first step in the WebAuthn flow is to get the options for registering a new credential or authenticating an existing credential. 
The server generates a JSON object that is used to configure the WebAuthn registration or authentication ceremony. The Gel Auth extension provides these endpoints directly, so you can either proxy the request to the Auth extension or redirect the user to the Auth extension's URL. 
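For comparison, the redirect option could look roughly like the sketch below. It is not part of the original example: ``buildOptionsUrl`` and ``redirectToOptions`` are hypothetical helpers, and whether a redirect suits your application depends on how the client issues the request (a cross-origin ``fetch``, for instance, is subject to CORS).

.. code-block:: javascript

  // Hypothetical helper (not part of the Auth extension API): build the
  // URL of the Auth extension's options endpoint for a given email.
  // `kind` is either "register" or "authenticate".
  const buildOptionsUrl = (authBaseUrl, kind, email) => {
    const url = new URL(`webauthn/${kind}/options`, authBaseUrl);
    url.searchParams.set("email", email);
    return url.href;
  };

  // Hypothetical handler body: answer with a 302 redirect instead of
  // proxying the request through our own server.
  const redirectToOptions = (res, authBaseUrl, kind, email) => {
    res.writeHead(302, { Location: buildOptionsUrl(authBaseUrl, kind, email) });
    res.end();
  };
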
We'll show the proxy option here. .. lint-off .. code-block:: javascript const handleRegisterOptions = async (req, res) => { let body = ""; req.on("data", (chunk) => { body += chunk.toString(); }); req.on("end", async () => { const { email } = JSON.parse(body); if (!email) { res.statusCode = 400; res.end( `Request body malformed. Expected JSON body with 'email' key, but got: ${body}`, ); return; } const registerUrl = new URL( "webauthn/register/options", GEL_AUTH_BASE_URL ); registerUrl.searchParams.set("email", email); const registerResponse = await fetch(registerUrl.href); if (!registerResponse.ok) { const text = await registerResponse.text(); res.statusCode = 400; res.end(`Error from the auth server: ${text}`); return; } const registerData = await registerResponse.json(); res.writeHead(200, { "Content-Type": "application/json" }); res.end(JSON.stringify(registerData)); }); }; const handleAuthenticateOptions = async (req, res) => { let body = ""; req.on("data", (chunk) => { body += chunk.toString(); }); req.on("end", async () => { const { email } = JSON.parse(body); if (!email) { res.statusCode = 400; res.end( `Request body malformed. Expected JSON body with 'email' key, but got: ${body}`, ); return; } const authenticateUrl = new URL( "webauthn/authenticate/options", GEL_AUTH_BASE_URL ); authenticateUrl.searchParams.set("email", email); const authenticateResponse = await fetch(authenticateUrl.href); if (!authenticateResponse.ok) { const text = await authenticateResponse.text(); res.statusCode = 400; res.end(`Error from the auth server: ${text}`); return; } const authenticateData = await authenticateResponse.json(); res.writeHead(200, { "Content-Type": "application/json" }); res.end(JSON.stringify(authenticateData)); }); }; .. lint-on Register a new credential ------------------------- The client script will call the Web Authentication API to create a new credential payload and send it to this endpoint. 
This endpoint's job will be to forward the serialized credential payload to the Gel Auth extension for verification, and then associate the credential with the user's email address. .. lint-off .. code-block:: javascript const handleRegister = async (req, res) => { let body = ""; req.on("data", (chunk) => { body += chunk.toString(); }); req.on("end", async () => { const { challenge, verifier } = generatePKCE(); const { email, provider, credentials, verify_url, user_handle } = JSON.parse(body); if (!email || !provider || !credentials || !verify_url || !user_handle) { res.statusCode = 400; res.end( `Request body malformed. Expected JSON body with 'email', 'provider', 'credentials', 'verify_url', and 'user_handle' keys, but got: ${body}`, ); return; } const registerUrl = new URL("webauthn/register", GEL_AUTH_BASE_URL); const registerResponse = await fetch(registerUrl.href, { method: "post", headers: { "Content-Type": "application/json", }, body: JSON.stringify({ provider, email, credentials, verify_url, user_handle, challenge, }), }); if (!registerResponse.ok) { const text = await registerResponse.text(); res.statusCode = 400; res.end(`Error from the auth server: ${text}`); return; } const registerData = await registerResponse.json(); if ("code" in registerData) { const tokenUrl = new URL("token", GEL_AUTH_BASE_URL); tokenUrl.searchParams.set("code", registerData.code); tokenUrl.searchParams.set("verifier", verifier); const tokenResponse = await fetch(tokenUrl.href, { method: "get", }); if (!tokenResponse.ok) { const text = await tokenResponse.text(); res.statusCode = 400; res.end(`Error from the auth server: ${text}`); return; } const { auth_token } = await tokenResponse.json(); res.writeHead(204, { "Set-Cookie": `gel-auth-token=${auth_token}; HttpOnly; Path=/; Secure; SameSite=Strict`, }); res.end(); } else { res.writeHead(204, { "Set-Cookie": `gel-pkce-verifier=${verifier}; HttpOnly; Path=/; Secure; SameSite=Strict`, }); res.end(); } }); }; .. 
lint-on Authenticate with an existing credential ---------------------------------------- The client script will call the Web Authentication API to authenticate with an existing credential and send the assertion to this endpoint. This endpoint's job will be to forward the serialized assertion to the Gel Auth extension for verification. .. lint-off .. code-block:: javascript const handleAuthenticate = async (req, res) => { let body = ""; req.on("data", (chunk) => { body += chunk.toString(); }); req.on("end", async () => { const { challenge, verifier } = generatePKCE(); const { email, provider, assertion } = JSON.parse(body); if (!email || !provider || !assertion) { res.statusCode = 400; res.end( `Request body malformed. Expected JSON body with 'email', 'provider', and 'assertion' keys, but got: ${body}`, ); return; } const authenticateUrl = new URL("webauthn/authenticate", GEL_AUTH_BASE_URL); const authenticateResponse = await fetch(authenticateUrl.href, { method: "post", headers: { "Content-Type": "application/json", }, body: JSON.stringify({ provider, email, assertion, challenge, }), }); if (!authenticateResponse.ok) { const text = await authenticateResponse.text(); res.statusCode = 400; res.end(`Error from the auth server: ${text}`); return; } const authenticateData = await authenticateResponse.json(); if ("code" in authenticateData) { const tokenUrl = new URL("token", GEL_AUTH_BASE_URL); tokenUrl.searchParams.set("code", authenticateData.code); tokenUrl.searchParams.set("verifier", verifier); const tokenResponse = await fetch(tokenUrl.href, { method: "get", }); if (!tokenResponse.ok) { const text = await tokenResponse.text(); res.statusCode = 400; res.end(`Error from the auth server: ${text}`); return; } const { auth_token } = await tokenResponse.json(); res.writeHead(204, { "Set-Cookie": `gel-auth-token=${auth_token}; HttpOnly; Path=/; Secure; SameSite=Strict`, }); res.end(); } else { res.writeHead(400, { "Content-Type": "application/json" }); 
res.end(JSON.stringify({ error: "Email must be verified before being 
able to authenticate." })); } }); }; .. lint-on Handle email verification ------------------------- When a new user signs up, by default we require them to verify their email address before allowing the application to get an authentication token. To handle the verification flow, we implement an endpoint: .. note:: If your WebAuthn provider uses the **Code** verification method, the verification email contains a one-time code rather than a link. In that case, prompt the user for the code and call ``POST /verify`` with: - **provider**: ``builtin::local_webauthn`` - **email** and **code** - optionally a **challenge** and **redirect_to** to receive a PKCE code or a redirect upon success The Link-based example below continues to work when the provider uses the Link method. .. note:: 💡 If you would like to allow users to still log in, but offer limited access to your application, you can check the associated ``ext::auth::WebAuthnFactor`` for the ``ext::auth::Identity`` to see if the ``verified_at`` property is some time in the past. You'll need to set the ``require_verification`` setting in the provider configuration to ``false``. .. lint-off .. code-block:: javascript /** * Handles the link in the email verification flow. * * @param {Request} req * @param {Response} res */ const handleVerify = async (req, res) => { const requestUrl = getRequestUrl(req); const verification_token = requestUrl.searchParams.get("verification_token"); if (!verification_token) { res.statusCode = 400; res.end( `Verify request is missing 'verification_token' search param. The verification email is malformed.`, ); return; } const cookies = req.headers.cookie?.split("; "); const verifier = cookies ?.find((cookie) => cookie.startsWith("gel-pkce-verifier=")) ?.split("=")[1]; if (!verifier) { res.statusCode = 400; res.end( `Could not find 'verifier' in the cookie store. 
Is this the same user agent/browser that started the authorization flow?`, ); return; } const verifyUrl = new URL("verify", GEL_AUTH_BASE_URL); const verifyResponse = await fetch(verifyUrl.href, { method: "post", headers: { "Content-Type": "application/json", }, body: JSON.stringify({ verification_token, verifier, provider: "builtin::local_webauthn", }), }); if (!verifyResponse.ok) { const text = await verifyResponse.text(); res.statusCode = 400; res.end(`Error from the auth server: ${text}`); return; } const { code } = await verifyResponse.json(); const tokenUrl = new URL("token", GEL_AUTH_BASE_URL); tokenUrl.searchParams.set("code", code); tokenUrl.searchParams.set("verifier", verifier); const tokenResponse = await fetch(tokenUrl.href, { method: "get", }); if (!tokenResponse.ok) { const text = await tokenResponse.text(); res.statusCode = 400; res.end(`Error from the auth server: ${text}`); return; } const { auth_token } = await tokenResponse.json(); res.writeHead(204, { "Set-Cookie": `gel-auth-token=${auth_token}; HttpOnly; Path=/; Secure; SameSite=Strict`, }); res.end(); }; .. lint-on Client-side script ------------------ On the client side, you will need to write a script that retrieves the options from the Gel Auth extension, calls the Web Authentication API, and sends the resulting credential or assertion to the server. Writing out the low-level handling of serialization and deserialization of the WebAuthn data is beyond the scope of this guide, but we publish a WebAuthn client library that you can use to simplify this process. The library is available on npm as part of our ``@gel/auth-core`` library. Here is an example of how you might set up a form with appropriate click handlers to perform the WebAuthn sign-in and sign-up ceremonies. .. lint-off .. 
code-block:: javascript import { WebAuthnClient } from "@gel/auth-core/webauthn"; const webAuthnClient = new WebAuthnClient({ signupOptionsUrl: "http://localhost:3000/auth/webauthn/register/options", signupUrl: "http://localhost:3000/auth/webauthn/register", signinOptionsUrl: "http://localhost:3000/auth/webauthn/authenticate/options", signinUrl: "http://localhost:3000/auth/webauthn/authenticate", verifyUrl: "http://localhost:3000/auth/webauthn/verify", }); document.addEventListener("DOMContentLoaded", () => { const signUpButton = document.querySelector("button#sign-up"); const signInButton = document.querySelector("button#sign-in"); const emailInput = document.querySelector("input#email"); if (signUpButton) { signUpButton.addEventListener("click", async (event) => { event.preventDefault(); const email = emailInput.value.trim(); if (!email) { throw new Error("No email provided"); } try { await webAuthnClient.signUp(email); window.location = "http://localhost:3000/signup-success"; } catch (err) { console.error(err); window.location = "http://localhost:3000/signup-error"; } }); } if (signInButton) { signInButton.addEventListener("click", async (event) => { event.preventDefault(); const email = emailInput.value.trim(); if (!email) { throw new Error("No email provided"); } try { await webAuthnClient.signIn(email); window.location = "http://localhost:3000"; } catch (err) { console.error(err); window.location = "http://localhost:3000/signup-error"; } }); } }); .. lint-on ================================================ FILE: docs/reference/auth/webhooks.rst ================================================ .. _ref_auth_webhooks: ======== Webhooks ======== The auth extension supports sending webhooks for a variety of auth events. You can use these webhooks to, for instance, send a fully customized email for email verification or password reset instead of our built-in email verification and password reset emails. 
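As a rough illustration of the custom-email use case, a receiver could branch on the ``event_type`` field of the payload. This is a sketch under assumptions: ``describeAuthEvent`` is a hypothetical helper, the returned ``action`` names are made up for illustration, and the field names (``verification_token``, ``reset_token``) are taken from the example payloads in the events reference below.

.. code-block:: javascript

  // Hypothetical helper: decide what to do with a parsed webhook payload.
  // The `action` strings are illustrative, not part of the Auth extension.
  const describeAuthEvent = (payload) => {
    switch (payload.event_type) {
      case "EmailVerificationRequested":
        return {
          action: "send-verification-email",
          token: payload.verification_token,
        };
      case "PasswordResetRequested":
        return {
          action: "send-password-reset-email",
          token: payload.reset_token,
        };
      default:
        // Every other event type is ignored in this sketch.
        return { action: "ignore", token: null };
    }
  };
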
You could also use them to trigger analytics events, start an email drip campaign, create an audit log, or trigger other side effects in your application. If you are using webhooks to send emails, be sure not to also configure an SMTP provider. Otherwise, we will send the email via SMTP and also send the webhook, which will trigger your custom email-sending behavior. .. warning:: We send webhooks with no durability or reliability guarantees, so you should always provide a mechanism for retrying delivery of any critical events, such as email verification and password reset. We detail how to resend these events in the relevant sections on the various authentication flows. Configuration ============= You can configure webhooks with the UI or via query. The URLs you register as webhooks must be unique across all webhooks configured for each branch. If you want to send multiple events to the same URL, you can do so by adding multiple ``ext::auth::WebhookEvent`` values to the ``events`` set, like in this example. .. code-block:: edgeql configure current branch insert ext::auth::WebhookConfig { url := 'https://example.com/auth/webhook', events := { ext::auth::WebhookEvent.EmailVerificationRequested, ext::auth::WebhookEvent.PasswordResetRequested, }, # Optional, only needed if you want to verify the webhook request signing_secret_key := '1234567890', }; When you receive a webhook, you'll look at the ``event_type`` field to determine which event corresponds to this webhook request and handle it accordingly. Checking webhook signatures =========================== You can provide a signing key, which you will need to generate and save in a place that your application will have access to. The extension will then add an ``x-ext-auth-signature-sha256`` header to the request, which you can use to verify the request by comparing it to the hex-encoded HMAC-SHA256 of the request body, computed with your signing key. 
Here is an example of how you might verify the signature in a Node.js application: .. 
code-block:: typescript import assert from "node:assert"; /** * Assert that if the request contains a signature header, that the signature * is valid for the request body. Will return false if there is no signature * header. * * @param {Request} request - The request to verify. * @param {string} signingKey - The key to use to verify the signature. * @returns {boolean} - True if the signature is present and valid, false if * the signature is not present at all. * @throws {AssertionError} - If the signature is present but invalid. */ async function assertSignature( request: Request, signingKey: string, ): Promise<boolean> { const signatureHeader = request.headers.get('x-ext-auth-signature-sha256'); if (!signatureHeader) { return false; } const requestBody = await request.text(); const encoder = new TextEncoder(); const data = encoder.encode(requestBody); const key = await crypto.subtle.importKey( 'raw', encoder.encode(signingKey), { name: 'HMAC', hash: 'SHA-256' }, false, ['sign'] ); const signature = await crypto.subtle.sign('HMAC', key, data); const signatureHex = Buffer.from(signature).toString('hex'); assert.strictEqual( signatureHeader, signatureHex, "Signature header is set, but the signature is invalid" ); return true; } Troubleshooting webhooks ======================== If you are having trouble receiving webhooks, you might need to look for any responses from the requests that are being scheduled by the :ref:`std::net::http ` module. You can list all of the :eql:type:`net::http::ScheduledRequest` objects, and any returned responses with the following query: .. code-block:: edgeql select net::http::ScheduledRequest { **, response: { ** } } Events reference ================ Common fields for all events: * ``event_type``: (string) This will be a literal string containing the name of the event. You can use this to determine which event occurred. * ``event_id``: (string) A unique identifier to help disambiguate events of the same type. 
* ``timestamp``: (string) The ISO 8601 timestamp of when the event was triggered. Identity created ^^^^^^^^^^^^^^^^ When a new ``ext::auth::Identity`` object is created, like when a new user signs up, or an existing user adds a new factor, this event is triggered. **Example payload:** .. code-block:: text POST http://localhost:8000/auth/webhook Content-type: application/json x-ext-auth-signature-sha256: 1234567890 { "event_type": "IdentityCreated", "event_id": "1234567890", "timestamp": "2021-01-01T00:00:00Z", "identity_id": "identity123" } Identity authenticated ^^^^^^^^^^^^^^^^^^^^^^ When an ``ext::auth::Identity`` object is authenticated, like when a user logs in, this event is triggered. **Example payload:** .. code-block:: text POST http://localhost:8000/auth/webhook Content-type: application/json x-ext-auth-signature-sha256: 1234567890 { "event_type": "IdentityAuthenticated", "event_id": "1234567890", "timestamp": "2021-01-01T00:00:00Z", "identity_id": "identity123" } Email factor created ^^^^^^^^^^^^^^^^^^^^ When a new ``ext::auth::EmailFactor`` object is created, like when a user adds a new email factor, this event is triggered. **Example payload:** .. code-block:: text POST http://localhost:8000/auth/webhook Content-type: application/json x-ext-auth-signature-sha256: 1234567890 { "event_type": "EmailFactorCreated", "event_id": "1234567890", "timestamp": "2021-01-01T00:00:00Z", "identity_id": "identity123", "email_factor_id": "emailfactor123" } Email verified ^^^^^^^^^^^^^^ When a user verifies their email address, this event is triggered. **Example payload:** .. 
code-block:: text POST http://localhost:8000/auth/webhook Content-type: application/json x-ext-auth-signature-sha256: 1234567890 { "event_type": "EmailVerified", "event_id": "1234567890", "timestamp": "2021-01-01T00:00:00Z", "identity_id": "identity123", "email_factor_id": "emailfactor123" } Email verification requested ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ When a user requests to verify their email address, like when they first sign up, or requests to resend the verification email, this event is triggered. **Example payload:** .. code-block:: text POST http://localhost:8000/auth/webhook Content-type: application/json x-ext-auth-signature-sha256: 1234567890 { "event_type": "EmailVerificationRequested", "event_id": "1234567890", "timestamp": "2021-01-01T00:00:00Z", "identity_id": "identity123", "verification_token": "verificationtoken123" } Password reset requested ^^^^^^^^^^^^^^^^^^^^^^^^ When a user requests to reset their password, this event is triggered. **Example payload:** .. code-block:: text POST http://localhost:8000/auth/webhook Content-type: application/json x-ext-auth-signature-sha256: 1234567890 { "event_type": "PasswordResetRequested", "event_id": "1234567890", "timestamp": "2021-01-01T00:00:00Z", "identity_id": "identity123", "reset_token": "resettoken123" } Magic link requested ^^^^^^^^^^^^^^^^^^^^ When a user requests to send a magic link email, like for signing in, or signing up for the first time, this event is triggered. **Example payload:** .. 
code-block:: text POST http://localhost:8000/auth/webhook Content-type: application/json x-ext-auth-signature-sha256: 1234567890 { "event_type": "MagicLinkRequested", "event_id": "1234567890", "timestamp": "2021-01-01T00:00:00Z", "identity_id": "identity123", "email_factor_id": "emailfactor123", "magic_link_token": "magiclinktoken123", "magic_link_url": "http://localhost:8000/auth/magic-link?token=magiclinktoken123" } ================================================ FILE: docs/reference/datamodel/access_policies.rst ================================================ .. _ref_datamodel_access_policies: =============== Access Policies =============== .. index:: object-level security, row-level security, RLS Object types in |Gel| can contain security policies that restrict the set of objects that can be selected, inserted, updated, or deleted by a particular query. This is known as *object-level security* and is similar in function to SQL's row-level security. When no access policies are defined, object-level security is not activated: any properly authenticated client can carry out any operation on any object in the database. Access policies allow you to ensure that the database itself handles access control logic rather than having to implement it in every application or service that connects to your database. Access policies can greatly simplify your backend code, centralizing access control logic in a single place. They can also be extremely useful for implementing AI agentic flows, where you want to have guardrails around your data that agents can't break. We'll illustrate access policies in this document with this simple schema: .. code-block:: sdl type User { required email: str { constraint exclusive; } } type BlogPost { required title: str; required author: User; } .. warning:: Once a policy is added to a particular object type, **all operations** (``select``, ``insert``, ``delete``, ``update``, etc.) 
on any object of that type are now *disallowed by default* unless specifically allowed by an access policy! See :ref:`resolution order ` below for details. Global variables ================ Global variables are a convenient way to set up the context for your access policies. Gel's global variables are tightly integrated with Gel's data model, client APIs, EdgeQL and SQL, and the tooling around them. Global variables in Gel are not pre-defined. Users are free to define as many globals in their schema as they want to represent the business logic of their application. A common scenario is storing a ``current_user`` global representing the user executing queries. We'd like to have a slightly more complex example showing that you can use more than one global variable. Let's do that: * We'll use one *global* ``uuid`` to represent the identity of the user executing the query. * We'll have the ``Country`` *enum* to represent the type of country that the user is currently in. The enum represents three types of countries: those where the service has not been rolled out, those with read-only access, and those with full access. * We'll use the ``current_country`` *global* to represent the user's current country. In our *example schema*, we want *country* to be context-specific: the same user who can access certain content in one country might not be able to in another country (let's imagine that's due to different country-specific legal frameworks). Here is an illustration: .. code-block:: sdl-diff + scalar type Country extending enum<Full, ReadOnly, None>; + global current_user: uuid; + required global current_country: Country { + default := Country.None + } type User { required email: str { constraint exclusive; } } type BlogPost { required title: str; required author: User; } You can set and reset these globals in Gel client libraries, for example: .. tabs:: .. 
code-tab:: typescript import createClient from 'gel'; const client = createClient(); // 'authedClient' will share the network connection with 'client', // but will have the 'current_user' global set. const authedClient = client.withGlobals({ current_user: '2141a5b4-5634-4ccc-b835-437863534c51', }); const result = await authedClient.query( `select global current_user;`); console.log(result); .. code-tab:: python from gel import create_client client = create_client().with_globals({ 'current_user': '580cc652-8ab8-4a20-8db9-4c79a4b1fd81' }) result = client.query(""" select global current_user; """) print(result) .. code-tab:: go package main import ( "context" "fmt" "log" "github.com/geldata/gel-go" ) func main() { ctx := context.Background() client, err := gel.CreateClient(ctx, gel.Options{}) if err != nil { log.Fatal(err) } defer client.Close() id, err := gel.ParseUUID("2141a5b4-5634-4ccc-b835-437863534c51") if err != nil { log.Fatal(err) } var result gel.UUID err = client. WithGlobals(map[string]interface{}{"current_user": id}). QuerySingle(ctx, "SELECT global current_user;", &result) if err != nil { log.Fatal(err) } fmt.Println(result) } .. code-tab:: rust use gel_protocol::{ model::Uuid, value::{EnumValue, Value} }; let client = gel_tokio::create_client() .await .expect("Client should init") .with_globals_fn(|c| { c.set( "current_user", Value::Uuid( Uuid::parse_str("2141a5b4-5634-4ccc-b835-437863534c51") .expect("Uuid should have parsed"), ), ); c.set( "current_country", Value::Enum(EnumValue::from("Full")) ); }); client .query_required_single::<Uuid, _>("select global current_user;", &()) .await .expect("Returning value"); Defining policies ================= A policy example for our simple blog schema might look like: .. 
code-block:: sdl-diff global current_user: uuid; required global current_country: Country { default := Country.None } scalar type Country extending enum<Full, ReadOnly, None>; type User { required email: str { constraint exclusive; } } type BlogPost { required title: str; required author: User; + access policy author_has_full_access + allow all + using (global current_user ?= .author.id + and global current_country ?= Country.Full) { + errmessage := "User does not have full access"; + } + access policy author_has_read_access + allow select + using (global current_user ?= .author.id + and global current_country ?= Country.ReadOnly); } Explanation: - ``access policy <name>`` introduces a new policy in an object type. - ``allow all`` grants ``select``, ``insert``, ``update``, and ``delete`` access if the condition passes. We also used a separate policy to allow only ``select`` in some cases. - ``using (<expr>)`` is a boolean filter restricting the set of objects to which the policy applies. (We used the ``?=`` operator, which compares possibly empty values, to handle empty sets gracefully.) - ``errmessage`` is an optional custom message to display in case of a write violation. Let's run some experiments in the REPL: .. code-block:: edgeql-repl db> insert User { email := "test@example.com" }; {default::User {id: be44b326-03db-11ed-b346-7f1594474966}} db> set global current_user := ... "be44b326-03db-11ed-b346-7f1594474966"; OK: SET GLOBAL db> set global current_country := Country.Full; OK: SET GLOBAL db> insert BlogPost { ... title := "My post", ... author := (select User filter .id = global current_user) ... }; {default::BlogPost {id: e76afeae-03db-11ed-b346-fbb81f537ca6}} Because the user is in a "full access" country and the current user ID matches the author, the new blog post is permitted. When the same user sets ``global current_country := Country.ReadOnly;``: .. 
code-block:: edgeql-repl db> set global current_country := Country.ReadOnly; OK: SET GLOBAL db> select BlogPost; {default::BlogPost {id: e76afeae-03db-11ed-b346-fbb81f537ca6}} db> insert BlogPost { ... title := "My second post", ... author := (select User filter .id = global current_user) ... }; gel error: AccessPolicyError: access policy violation on insert of default::BlogPost (User does not have full access) Finally, let's unset ``current_user`` and see how many blog posts are returned when we count them. .. code-block:: edgeql-repl db> set global current_user := {}; OK: SET GLOBAL db> select BlogPost; {} db> select count(BlogPost); {0} Both queries come back empty: we can only ``select`` the *posts* written by the *user* specified by ``current_user``. When ``current_user`` has no value or differs from the ``.author.id`` of every existing ``BlogPost`` object, we can't read any posts. Once ``current_user`` is set to the author's ID again and ``current_country`` is ``Country.Full``, that user can write new blog posts. **The bottom line:** access policies use global variables to define a "subgraph" of data that is visible to your queries. Policy types ============ .. api-index:: select, insert, delete, update read, update write, all The types of policy rules map to the statement type in EdgeQL: - ``select``: Controls which objects are visible to any query. - ``insert``: Post-insert check. If the inserted object violates the policy, the operation fails. - ``delete``: Controls which objects can be deleted. - ``update read``: Pre-update check on which objects can be updated at all. - ``update write``: Post-update check for how objects can be updated. - ``all``: Shorthand for granting or denying ``select, insert, update, delete``. Resolution order ================ If multiple policies apply (some are ``allow`` and some are ``deny``), the logic is: 1. If there are no policies, access is allowed. 2. 
All ``allow`` policies collectively form a *union* / *or* of allowed sets. 3. All ``deny`` policies *subtract* from that union, overriding allows! 4. The final set of objects is the intersection of the above logic for each operation: ``select, insert, update read, update write, delete``. By default, once you define any policy on an object type, you must explicitly allow the operations you need. This is a common **pitfall** when you are starting out with access policies (but you will develop an intuition for this quickly). Let's look at an example: .. code-block:: sdl global current_user_id: uuid; global current_user := ( select User filter .id = global current_user_id ); type User { required email: str { constraint exclusive; } required is_admin: bool { default := false }; access policy admin_only allow all using (global current_user.is_admin ?? false); } type BlogPost { required title: str; author: User; access policy author_has_full_access allow all using (global current_user ?= .author); } In the above schema only admins will see a non-empty ``author`` link when running ``select BlogPost { author }``. Why? Because only admins can see ``User`` objects at all: the ``admin_only`` policy is the only one defined on the ``User`` type! This means that instead of making ``BlogPost`` visible to its author, all non-admin authors won't be able to see their own posts. The above issue can be remedied by making the current user able to see their own ``User`` record. Interaction between policies ============================ Policy expressions themselves do not take other policies into account (since |EdgeDB| 3). This makes it easier to reason about policies. Custom error messages ===================== When an ``insert`` or ``update write`` violates an access policy, Gel will raise a generic ``AccessPolicyError``: .. code-block:: gel error: AccessPolicyError: access policy violation on insert of ..
note:: Restricted access is represented either as an error message or an empty set, depending on the filtering order of the operation. The operations ``select``, ``delete``, or ``update read`` filter up front, and thus you simply won't get the data that is being restricted. Other operations (``insert`` and ``update write``) will return an error message. If multiple policies are in effect, it can be helpful to define a distinct ``errmessage`` in your policy: .. code-block:: sdl-diff global current_user_id: uuid; global current_user := ( select User filter .id = global current_user_id ); type User { required email: str { constraint exclusive; }; required is_admin: bool { default := false }; access policy admin_only allow all + using (global current_user.is_admin ?? false) { + errmessage := 'Only admins may query Users' + }; } type BlogPost { required title: str; author: User; access policy author_has_full_access allow all + using (global current_user ?= .author) { + errmessage := 'BlogPosts may only be queried by their authors' + }; } Now if you attempt, for example, a ``User`` insert as a non-admin user, you will receive this error: .. code-block:: gel error: AccessPolicyError: access policy violation on insert of default::User (Only admins may query Users) Disabling policies ================== .. api-index:: apply_access_policies You may disable all access policies by setting the ``apply_access_policies`` :ref:`configuration parameter ` to ``false``. You may also temporarily disable access policies using the Gel UI configuration checkbox (or via :gelcmd:`ui`), which only applies to your UI session. More examples ============= Here are some additional patterns: 1. Publicly visible blog posts, only writable by the author: .. 
code-block:: sdl-diff global current_user: uuid; type User { required email: str { constraint exclusive; } } type BlogPost { required title: str; required author: User; + required published: bool { default := false }; access policy author_has_full_access allow all using (global current_user ?= .author.id); + access policy visible_if_published + allow select + using (.published); } 2. Visible to friends, only modifiable by the author: .. code-block:: sdl-diff global current_user: uuid; type User { required email: str { constraint exclusive; } + multi friends: User; } type BlogPost { required title: str; required author: User; access policy author_has_full_access allow all using (global current_user ?= .author.id); + access policy friends_can_read + allow select + using ((global current_user in .author.friends.id) ?? false); } 3. Publicly visible except to those blocked by the author: .. code-block:: sdl-diff type User { required email: str { constraint exclusive; } + multi blocked: User; } type BlogPost { required title: str; required author: User; access policy author_has_full_access allow all using (global current_user ?= .author.id); + access policy anyone_can_read + allow select; + access policy exclude_blocked + deny select + using ((global current_user in .author.blocked.id) ?? false); } 4. "Disappearing" posts that become invisible after 24 hours: .. code-block:: sdl-diff type User { required email: str { constraint exclusive; } } type BlogPost { required title: str; required author: User; + required created_at: datetime { + default := datetime_of_statement() # non-volatile + } access policy author_has_full_access allow all using (global current_user ?= .author.id); + access policy hide_after_24hrs + allow select + using ( + datetime_of_statement() - .created_at < <duration>'24 hours' + ); } Super constraints ================= Access policies can act like "super constraints."
For instance, a policy on ``insert`` or ``update write`` can do a post-write validity check, rejecting the operation if a certain condition is not met. E.g. here's a policy that limits the number of blog posts a ``User`` can post: .. code-block:: sdl-diff type User { required email: str { constraint exclusive; } + multi posts := . 500); } .. _ref_eql_sdl_access_policies: .. _ref_eql_sdl_access_policies_syntax: Declaring access policies ========================= .. api-index:: access policy, when, allow, deny, all, select, insert, delete, update, update read, update write, using, errmessage This section describes the syntax to declare access policies in your schema. Syntax ------ .. sdl:synopsis:: access policy [ when () ] { allow | deny } [, ... ] [ using () ] [ "{" [ errmessage := value ; ] [ ] "}" ] ; # where is one of all select insert delete update [{ read | write }] Where: :eql:synopsis:`` The name of the access policy. :eql:synopsis:`when ()` Specifies which objects this policy applies to. The :eql:synopsis:`` has to be a :eql:type:`bool` expression. When omitted, it is assumed that this policy applies to all objects of a given type. :eql:synopsis:`allow` Indicates that qualifying objects should allow access under this policy. :eql:synopsis:`deny` Indicates that qualifying objects should *not* allow access under this policy. This flavor supersedes any :eql:synopsis:`allow` policy and can be used to selectively deny access to a subset of objects that otherwise explicitly allows accessing them. :eql:synopsis:`all` Apply the policy to all actions. It is exactly equivalent to listing :eql:synopsis:`select`, :eql:synopsis:`insert`, :eql:synopsis:`delete`, :eql:synopsis:`update` actions explicitly. :eql:synopsis:`select` Apply the policy to all selection queries. Note that any object that cannot be selected, cannot be modified either. This makes :eql:synopsis:`select` the most basic "visibility" policy. 
:eql:synopsis:`insert` Apply the policy to all inserted objects. If a newly inserted object would violate this policy, an error is produced instead. :eql:synopsis:`delete` Apply the policy to all objects about to be deleted. If an object does not allow access under this kind of policy, it is not going to be considered by any :eql:stmt:`delete` command. Note that any object that cannot be selected, cannot be modified either. :eql:synopsis:`update read` Apply the policy to all objects selected for an update. If an object does not allow access under this kind of policy, it is not visible and cannot be updated. Note that any object that cannot be selected, cannot be modified either. :eql:synopsis:`update write` Apply the policy to all objects at the end of an update. If an updated object violates this policy, an error is produced instead. Note that any object that cannot be selected, cannot be modified either. :eql:synopsis:`update` This is just a shorthand for :eql:synopsis:`update read` and :eql:synopsis:`update write`. Note that any object that cannot be selected, cannot be modified either. :eql:synopsis:`using ` Specifies what the policy is with respect to a given eligible (based on :eql:synopsis:`when` clause) object. The :eql:synopsis:`` has to be a :eql:type:`bool` expression. The specific meaning of this value also depends on whether this policy flavor is :eql:synopsis:`allow` or :eql:synopsis:`deny`. The expression must be :ref:`Stable `. When omitted, it is assumed that this policy applies to all eligible objects of a given type. :eql:synopsis:`set errmessage := ` Set a custom error message of :eql:synopsis:`` that is displayed when this access policy prevents a write action. :sdl:synopsis:`` Set access policy :ref:`annotation ` to a given *value*. Any sub-type extending a type inherits all of its access policies. You can define additional access policies on sub-types. ..
_ref_eql_ddl_access_policies: DDL commands ============ This section describes the low-level DDL commands for creating, altering, and dropping access policies. You typically don't need to use these commands directly, but knowing about them is useful for reviewing migrations. Create access policy -------------------- :eql-statement: Define a new object access policy on a type: .. eql:synopsis:: [ with [, ...] ] { create | alter } type "{" [ ... ] create access policy [ when () ; ] { allow | deny } action [, action ... ; ] [ using () ; ] [ "{" [ set errmessage := value ; ] [ create annotation := value ; ] "}" ] "}" # where is one of all select insert delete update [{ read | write }] See the meaning of each parameter in the `Declaring access policies`_ section. The following subcommands are allowed in the ``create access policy`` block: :eql:synopsis:`set errmessage := ` Set a custom error message of :eql:synopsis:`` that is displayed when this access policy prevents a write action. :eql:synopsis:`create annotation := ` Set access policy annotation :eql:synopsis:`` to :eql:synopsis:``. See :eql:stmt:`create annotation` for details. Alter access policy ------------------- :eql-statement: Modify an existing access policy: .. eql:synopsis:: [ with [, ...] ] alter type "{" [ ... ] alter access policy "{" [ when () ; ] [ reset when ; ] { allow | deny } [, ... ; ] [ using () ; ] [ set errmessage := value ; ] [ reset expression ; ] [ create annotation := ; ] [ alter annotation := ; ] [ drop annotation ; ] "}" "}" You can change the policy's condition, actions, or error message, or add/drop annotations. The parameters describing the access policy are identical to the parameters used by ``create access policy``. There are a handful of additional subcommands that are allowed in the ``alter access policy`` block: :eql:synopsis:`reset when` Clear the :eql:synopsis:`when ()` so that the policy applies to all objects of a given type. This is equivalent to ``when (true)``.
:eql:synopsis:`reset expression` Clear the :eql:synopsis:`using ()` so that the policy always passes. This is equivalent to ``using (true)``. :eql:synopsis:`alter annotation ;` Alter access policy annotation :eql:synopsis:``. See :eql:stmt:`alter annotation` for details. :eql:synopsis:`drop annotation ;` Remove access policy annotation :eql:synopsis:``. See :eql:stmt:`drop annotation` for details. All the subcommands allowed in the ``create access policy`` block are also valid subcommands for ``alter access policy`` block. Drop access policy ------------------ :eql-statement: Remove an existing policy: .. eql:synopsis:: [ with [, ...] ] alter type "{" [ ... ] drop access policy ; "}" ================================================ FILE: docs/reference/datamodel/aliases.rst ================================================ .. _ref_datamodel_aliases: ======= Aliases ======= .. index:: alias, virtual type You can think of *aliases* as a way to give schema names to arbitrary EdgeQL expressions. You can later refer to aliases in queries and in other aliases. Aliases are functionally equivalent to expression aliases defined in EdgeQL statements in :ref:`with block `, but are available to all queries using the schema and can be introspected. Like computed properties, the aliased expression is evaluated on the fly whenever the alias is referenced. Scalar alias ============ .. code-block:: sdl # in your schema: alias digits := {0,1,2,3,4,5,6,7,8,9}; Later, in some query: .. code-block:: edgeql select count(digits); Object type alias ================= The name of a given object type (e.g. ``User``) is itself a pointer to the *set of all User objects*. After declaring the alias below, you can use ``User`` and ``UserAlias`` interchangeably: .. code-block:: sdl alias UserAlias := User; Object type alias with computeds ================================ Object type aliases can include a *shape* that declares additional computed properties or links: .. 
code-block:: sdl type Post { required title: str; } alias PostWithTrimmedTitle := Post { trimmed_title := str_trim(.title) } Later, in some query: .. code-block:: edgeql select PostWithTrimmedTitle { trimmed_title }; Arbitrary expressions ===================== Aliases can correspond to any arbitrary EdgeQL expression, including entire queries. .. code-block:: sdl # Tuple alias alias Color := ("Purple", 128, 0, 128); # Named tuple alias alias GameInfo := ( name := "Li Europan Lingues", country := "Iceland", date_published := 2023, creators := ( (name := "Bob Bobson", age := 20), (name := "Trina Trinadóttir", age := 25), ), ); type BlogPost { required title: str; required is_published: bool; } # Query alias alias PublishedPosts := ( select BlogPost filter .is_published = true ); .. note:: All aliases are reflected in the database's built-in :ref:`GraphQL schema `. .. _ref_eql_sdl_aliases: .. _ref_eql_sdl_aliases_syntax: Defining aliases ================ .. api-index:: alias Syntax ------ Define a new alias corresponding to the :ref:`more explicit DDL commands `. .. sdl:synopsis:: alias := ; alias "{" using ; [ ] "}" ; Where: :eql:synopsis:`` The name (optionally module-qualified) of an alias to be created. :eql:synopsis:`` The aliased expression. Must be a :ref:`Stable ` EdgeQL expression. The valid SDL sub-declarations are listed below: :sdl:synopsis:`` Set alias :ref:`annotation ` to a given *value*. .. _ref_eql_ddl_aliases: DDL commands ============ This section describes the low-level DDL commands for creating and dropping aliases. You typically don't need to use these commands directly, but knowing about them is useful for reviewing migrations. Create alias ------------ :eql-statement: :eql-haswith: Define a new alias in the schema. .. eql:synopsis:: [ with [, ...] ] create alias := ; [ with [, ...] ] create alias "{" using ; [ create annotation := ; ... 
] "}" ; # where is: [ := ] module Parameters ^^^^^^^^^^ Most sub-commands and options of this command are identical to the :ref:`SDL alias declaration `, with some additional features listed below: :eql:synopsis:`[ := ] module ` An optional list of module alias declarations to be used in the alias definition. :eql:synopsis:`create annotation := ;` An optional list of annotation values for the alias. See :eql:stmt:`create annotation` for details. Example ^^^^^^^ Create a new alias: .. code-block:: edgeql create alias Superusers := ( select User filter User.groups.name = 'Superusers' ); Drop alias ---------- :eql-statement: :eql-haswith: Remove an alias from the schema. .. eql:synopsis:: [ with [, ...] ] drop alias ; Parameters ^^^^^^^^^^ *alias-name* The name (optionally qualified with a module name) of an existing expression alias. Example ^^^^^^^ Remove an alias: .. code-block:: edgeql drop alias SuperUsers; ================================================ FILE: docs/reference/datamodel/annotations.rst ================================================ .. _ref_datamodel_annotations: .. _ref_eql_sdl_annotations: =========== Annotations =========== *Annotations* are named values associated with schema items and are designed to hold arbitrary schema-level metadata represented as a :eql:type:`str` (unstructured text). Users can store JSON-encoded data in annotations if they need to store more complex metadata. Standard annotations ==================== .. api-index:: title, description, deprecated There are a number of annotations defined in the standard library. The following are the annotations which can be set on any schema item: - ``std::title`` - ``std::description`` - ``std::deprecated`` For example, consider the following declaration: .. 
code-block:: sdl type Status { annotation title := 'Activity status'; annotation description := 'All possible user activities'; required name: str { constraint exclusive } } And the ``std::deprecated`` annotation can be used to mark deprecated items (e.g., :eql:func:`str_rpad`) and to provide some information such as what should be used instead. User-defined annotations ======================== To declare a custom annotation type beyond the three built-ins, add an abstract annotation type to your schema. A custom annotation could be used to attach arbitrary JSON-encoded data to your schema—potentially useful for introspection and code generation. .. code-block:: sdl abstract annotation admin_note; type Status { annotation admin_note := 'system-critical'; } .. _ref_eql_sdl_annotations_syntax: Declaring annotations ===================== .. api-index:: abstract, inheritable, annotation This section describes the syntax to use annotations in your schema. Syntax ------ .. sdl:synopsis:: # Abstract annotation form: abstract [ inheritable ] annotation [ "{" ; [...] "}" ] ; # Concrete annotation (same as ) form: annotation := ; Description ^^^^^^^^^^^ There are two forms of annotation declarations: abstract and concrete. The *abstract annotation* form is used for declaring new kinds of annotation in a module. The *concrete annotation* declarations are used as sub-declarations for all other declarations in order to actually annotate them. The annotation declaration options are as follows: :eql:synopsis:`abstract` If specified, the annotation will be *abstract*. :eql:synopsis:`inheritable` If specified, the annotation will be *inheritable*. The annotations are non-inheritable by default. That is, if a schema item has an annotation defined on it, the descendants of that schema item will not automatically inherit the annotation. Normal inheritance behavior can be turned on by declaring the annotation with the ``inheritable`` qualifier. 
This is only valid for *abstract annotation*. :eql:synopsis:`` The name (optionally module-qualified) of the annotation. :eql:synopsis:`` Any string value that the specified annotation is intended to have for the given context. The only valid SDL sub-declarations are *concrete annotations*: :sdl:synopsis:`` Annotations can also have annotations. Set the *annotation* of the enclosing annotation to a specific value. .. _ref_eql_ddl_annotations: DDL commands ============ This section describes the low-level DDL commands for creating, altering, and dropping annotations and abstract annotations. You typically don't need to use these commands directly, but knowing about them is useful for reviewing migrations. Create abstract annotation -------------------------- :eql-statement: Define a new annotation. .. eql:synopsis:: [ with [, ...] ] create abstract [ inheritable ] annotation [ "{" create annotation := ; [...] "}" ] ; Description ^^^^^^^^^^^ The command ``create abstract annotation`` defines a new annotation for use in the current Gel database. If *name* is qualified with a module name, then the annotation is created in that module, otherwise it is created in the current module. The annotation name must be distinct from that of any existing schema item in the module. The annotations are non-inheritable by default. That is, if a schema item has an annotation defined on it, the descendants of that schema item will not automatically inherit the annotation. Normal inheritance behavior can be turned on by declaring the annotation with the ``inheritable`` qualifier. Most sub-commands and options of this command are identical to the :ref:`SDL annotation declaration `. There's only one subcommand that is allowed in the ``create annotation`` block: :eql:synopsis:`create annotation := ` Annotations can also have annotations. Set the :eql:synopsis:`` of the enclosing annotation to a specific :eql:synopsis:``. See :eql:stmt:`create annotation` for details. 
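As a minimal sketch of the nested subcommand described above, the following statement declares an inheritable annotation and annotates it with the standard ``title`` annotation. The name ``admin_note`` and the title text are illustrative assumptions, not part of any real schema.

.. code-block:: edgeql

    # Hypothetical annotation name, used for illustration only.
    create abstract inheritable annotation admin_note {
        # Annotate the annotation itself with the built-in std::title.
        create annotation title := 'Administrator note';
    };

Because the annotation is declared ``inheritable``, descendants of any schema item annotated with it would inherit the value as well.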
Example ^^^^^^^ Declare an annotation ``extrainfo``: .. code-block:: edgeql create abstract annotation extrainfo; Alter abstract annotation ------------------------- :eql-statement: Change the definition of an annotation. .. eql:synopsis:: alter abstract annotation [ "{" ] ; [...] [ "}" ]; # where is one of rename to create annotation := alter annotation := drop annotation Description ^^^^^^^^^^^ :eql:synopsis:`alter abstract annotation` changes the definition of an abstract annotation. Parameters ^^^^^^^^^^ :eql:synopsis:`` The name (optionally module-qualified) of the annotation to alter. The following subcommands are allowed in the ``alter abstract annotation`` block: :eql:synopsis:`rename to ` Change the name of the annotation to :eql:synopsis:``. :eql:synopsis:`alter annotation := ` Annotations can also have annotations. Change :eql:synopsis:`` to a specific :eql:synopsis:``. See :eql:stmt:`alter annotation` for details. :eql:synopsis:`drop annotation ` Annotations can also have annotations. Remove annotation :eql:synopsis:``. See :eql:stmt:`drop annotation` for details. All the subcommands allowed in the ``create abstract annotation`` block are also valid subcommands for ``alter annotation`` block. Example ^^^^^^^ Rename an annotation: .. code-block:: edgeql alter abstract annotation extrainfo rename to extra_info; Drop abstract annotation ------------------------ :eql-statement: Drop a schema annotation. .. eql:synopsis:: [ with [, ...] ] drop abstract annotation ; Description ^^^^^^^^^^^ The command ``drop abstract annotation`` removes an existing schema annotation from the database schema. Note that the ``inheritable`` qualifier is not necessary in this statement. Example ^^^^^^^ Drop the annotation ``extra_info``: .. code-block:: edgeql drop abstract annotation extra_info; Create annotation ----------------- :eql-statement: Define an annotation value for a given schema item. .. 
eql:synopsis:: create annotation := Description ^^^^^^^^^^^ The command ``create annotation`` defines an annotation for a schema item. :eql:synopsis:`` refers to the name of a defined annotation, and :eql:synopsis:`` must be a constant EdgeQL expression evaluating into a string. This statement can only be used as a subcommand in another DDL statement. Example ^^^^^^^ Create an object type ``User`` and set its ``title`` annotation to ``"User type"``. .. code-block:: edgeql create type User { create annotation title := "User type"; }; Alter annotation ---------------- :eql-statement: Alter an annotation value for a given schema item. .. eql:synopsis:: alter annotation := Description ^^^^^^^^^^^ The command ``alter annotation`` alters an annotation value on a schema item. :eql:synopsis:`` refers to the name of a defined annotation, and :eql:synopsis:`` must be a constant EdgeQL expression evaluating into a string. This statement can only be used as a subcommand in another DDL statement. Example ^^^^^^^ Alter an object type ``User`` and alter the value of its previously set ``title`` annotation to ``"User type"``. .. code-block:: edgeql alter type User { alter annotation title := "User type"; }; Drop annotation --------------- :eql-statement: Remove an annotation from a given schema item. .. eql:synopsis:: drop annotation ; Description ^^^^^^^^^^^ The command ``drop annotation`` removes an annotation value from a schema item. :eql:synopsis:`` refers to the name of a defined annotation. The annotation value does not have to exist on a schema item. This statement can only be used as a subcommand in another DDL statement. Example ^^^^^^^ Drop the ``title`` annotation from the ``User`` object type: .. code-block:: edgeql alter type User { drop annotation title; }; .. 
list-table:: :class: seealso * - **See also** * - :ref:`Cheatsheets > Annotations ` * - :ref:`Introspection > Object types ` ================================================ FILE: docs/reference/datamodel/branches.rst ================================================ .. _ref_datamodel_branches: .. _ref_datamodel_databases: .. versionadded:: 5.0 ======== Branches ======== Gel's |branches| are equivalent to PostgreSQL's *databases* and map to them directly. Gel comes with tooling to help manage branches and build a development workflow around them. E.g. when developing locally you can map your Gel branches to your Git branches, and when using Gel Cloud and GitHub you can have a branch per PR. CLI commands ============ Refer to the :ref:`gel branch ` command group for details on the CLI commands for managing branches. .. _ref_admin_branches: DDL commands ============ These are low-level commands that are used to create, alter, and drop branches. You can use them when experimenting in the REPL, or if you want to create your own tools to manage Gel branches. Create empty branch ------------------- :eql-statement: Create a new branch without schema or data. .. eql:synopsis:: create empty branch ; Description ^^^^^^^^^^^ The command ``create empty branch`` creates a new Gel branch without schema or data, aside from standard schemas. Example ^^^^^^^ Create a new empty branch: .. code-block:: edgeql create empty branch newbranch; Create schema branch -------------------- :eql-statement: Create a new branch copying the schema (without data) of an existing branch. .. eql:synopsis:: create schema branch from ; Description ^^^^^^^^^^^ The command ``create schema branch`` creates a new Gel branch with schema copied from an already existing branch. Example ^^^^^^^ Create a new schema branch: .. code-block:: edgeql create schema branch feature from main; Create data branch ------------------ :eql-statement: Create a new branch copying the schema and data of an existing branch. ..
eql:synopsis:: create data branch from ; Description ^^^^^^^^^^^ The command ``create data branch`` creates a new Gel branch with schema and data copied from an already existing branch. Example ^^^^^^^ Create a new data branch: .. code-block:: edgeql create data branch feature from main; Drop branch ----------- :eql-statement: Remove a branch. .. eql:synopsis:: drop branch ; Description ^^^^^^^^^^^ The command ``drop branch`` removes an existing branch. It cannot be executed while there are existing connections to the target branch. .. warning:: Executing ``drop branch`` removes data permanently and cannot be undone. Example ^^^^^^^ Remove a branch: .. code-block:: edgeql drop branch appdb; Alter branch ------------ :eql-statement: Rename a branch. .. eql:synopsis:: alter branch rename to ; Description ^^^^^^^^^^^ The command ``alter branch … rename`` changes the name of an existing branch. It cannot be executed while there are existing connections to the target branch. Example ^^^^^^^ Rename a branch: .. code-block:: edgeql alter branch featuer rename to feature; Database (deprecated) ===================== Versions of Gel prior to 5.0 used the term *database* to refer to branches. Create database --------------- :eql-statement: Create a new database. .. eql:synopsis:: create database ; Description ^^^^^^^^^^^ The command ``create database`` creates a new Gel database. The new database will be created with all standard schemas prepopulated. Examples ^^^^^^^^ Create a new database: .. code-block:: edgeql create database appdb; Drop database ------------- :eql-statement: Remove a database. .. eql:synopsis:: drop database ; Description ^^^^^^^^^^^ The command ``drop database`` removes an existing database. It cannot be executed while there are existing connections to the target database. .. warning:: Executing ``drop database`` removes data permanently and cannot be undone. Examples ^^^^^^^^ Remove a database: .. 
code-block:: edgeql drop database appdb; ================================================ FILE: docs/reference/datamodel/comparison.rst ================================================ .. _ref_datamodel_comparison: =============== vs SQL and ORMs =============== |Gel's| approach to schema modeling builds upon the foundation of SQL while taking cues from modern tools like ORM libraries. Let's see how it stacks up. .. _ref_datamodel_sql_comparison: Comparison to SQL ----------------- When using SQL databases, there's no convenient representation of the schema. Instead, the schema exists only as a series of ``{CREATE|ALTER|DROP} {TABLE|COLUMN}`` commands, usually spread across several SQL migration scripts. There's no simple way to see the current state of your schema at a glance. Moreover, SQL stores data in a *relational* way. Connections between tables are represented with foreign key constraints and ``JOIN`` operations are required to query across tables. .. code-block:: CREATE TABLE people ( id uuid PRIMARY KEY, name text ); CREATE TABLE movies ( id uuid PRIMARY KEY, title text, director_id uuid REFERENCES people(id) ); In |Gel|, connections between object types are represented with :ref:`Links `. .. code-block:: sdl type Movie { required title: str; required director: Person; } type Person { required name: str; } This approach makes it simple to write queries that traverse this link, no JOINs required. .. code-block:: edgeql select Movie { title, director: { name } } .. _ref_datamodel_orm_comparison: Comparison to ORMs ------------------ Object-relational mapping libraries are popular for a reason. They provide a way to model your schema and write queries in a way that feels natural in the context of modern, object-oriented programming languages. But ORMs have downsides too. - **Lock-in**. Your schema is strongly coupled to the ORM library you are using. More generally, this also locks you into using a particular programming language.
- Most ORMs have more **limited querying capabilities** than the query languages they abstract. - Many ORMs produce **suboptimal queries** that can have serious performance implications. - **Migrations** can be difficult. Since most ORMs aim to be the single source of truth for your schema, they necessarily must provide some sort of migration tool. These migration tools are maintained by the contributors to the ORM library, not the maintainers of the database itself. Quality control and long-term maintenance are not always guaranteed. From the beginning, Gel was designed to incorporate the best aspects of ORMs (declarative modeling, object-oriented APIs, and intuitive querying) without the drawbacks. ================================================ FILE: docs/reference/datamodel/computeds.rst ================================================ .. _ref_datamodel_computed: ========= Computeds ========= :edb-alt-title: Computed properties and links .. api-index:: := .. important:: This section assumes a basic understanding of EdgeQL. If you aren't familiar with it, feel free to skip this page for now. Object types can contain *computed* properties and links. Computed properties and links are not persisted in the database. Instead, they are evaluated *on the fly* whenever that field is queried. In |EdgeDB| versions prior to 4.0, computed properties must be declared with the ``property`` keyword and computed links with the ``link`` keyword. .. code-block:: sdl type Person { name: str; all_caps_name := str_upper(__source__.name); } Computed fields are associated with an EdgeQL expression. This expression can be an *arbitrary* EdgeQL query. This expression is evaluated whenever the field is referenced in a query. .. note:: Computed fields don't need to be pre-defined in your schema; you can drop them into individual queries as well. They behave in exactly the same way. For more information, see the :ref:`EdgeQL > Select > Computeds `. ..
warning:: :ref:`Volatile and modifying ` expressions are not allowed in computed properties defined in schema. This means that, for example, your schema-defined computed property cannot call :eql:func:`datetime_current`, but it *can* call :eql:func:`datetime_of_transaction` or :eql:func:`datetime_of_statement`. This does *not* apply to computed properties outside of schema. .. _ref_dot_notation: Leading dot notation -------------------- .. api-index:: __source__ The example above used the special keyword ``__source__`` to refer to the current object; it's analogous to ``this`` or ``self`` in many object-oriented languages. However, explicitly using ``__source__`` is optional here; inside the scope of an object type declaration, you can omit it entirely and use the ``.`` shorthand. .. code-block:: sdl type Person { first_name: str; last_name: str; full_name := .first_name ++ ' ' ++ .last_name; } Type and cardinality inference ------------------------------ The type and cardinality of a computed field is *inferred* from the expression. There's no need for the modifier keywords you use for non-computed fields (like ``multi`` and ``required``). However, it's common to specify them anyway; it makes the schema more readable and acts as a sanity check: if the provided EdgeQL expression disagrees with the modifiers, an error will be thrown the next time you try to :ref:`create a migration `. .. code-block:: sdl type Person { first_name: str; # this is invalid, because first_name is not a required property required first_name_upper := str_upper(.first_name); } Common use cases ---------------- Filtering ^^^^^^^^^ If you find yourself writing the same ``filter`` expression repeatedly in queries, consider defining a computed field that encapsulates the filter. .. code-block:: sdl type Club { multi members: Person; multi active_members := ( select .members filter .is_active = true ) } type Person { name: str; is_active: bool; } .. 
_ref_datamodel_links_backlinks: Backlinks ^^^^^^^^^ Backlinks are one of the most common use cases for computed links. In |Gel|, links are *directional*; they have a source and a target. Often it's convenient to traverse a link in the *reverse* direction. .. code-block:: sdl type BlogPost { title: str; author: User; } type User { name: str; multi blog_posts := .<author[is BlogPost]; } .. list-table:: :class: seealso * - :ref:`SDL > Links ` * - :ref:`DDL > Links ` * - :ref:`SDL > Properties ` * - :ref:`DDL > Properties ` ================================================ FILE: docs/reference/datamodel/constraints.rst ================================================ .. _ref_datamodel_constraints: .. _ref_eql_sdl_constraints: =========== Constraints =========== .. index:: validation Constraints give users fine-grained control to ensure data consistency. They can be defined on :ref:`properties `, :ref:`links`, :ref:`object types `, and :ref:`custom scalars `. .. _ref_datamodel_constraints_builtin: Standard constraints ==================== .. api-index:: exclusive, expression on, one_of, max_value, max_ex_value, min_value, min_ex_value, max_len_value, min_len_value, regexp |Gel| includes a number of standard ready-to-use constraints: .. include:: ../stdlib/constraint_table.rst Constraints on properties ========================= Example: enforce all ``User`` objects to have a unique ``username`` no longer than 25 characters: .. code-block:: sdl type User { required username: str { # usernames must be unique constraint exclusive; # max length (built-in) constraint max_len_value(25); }; } .. _ref_datamodel_constraints_objects: Constraints on object types =========================== .. api-index:: __subject__ Constraints can be defined on object types. This is useful when the constraint logic must reference multiple links or properties. Example: enforce that the magnitude of ``ConstrainedVector`` objects is no more than 5: ..
code-block:: sdl type ConstrainedVector { required x: float64; required y: float64; constraint expression on ( (.x ^ 2 + .y ^ 2) ^ 0.5 <= 5 # or, long form: `(__subject__.x ^ 2 + __subject__.y ^ 2) ^ 0.5 <= 5` ); } The ``expression`` constraint is used here to define custom constraint logic. Inside constraints, the keyword ``__subject__`` can be used to reference the *value* being constrained. .. note:: Note that inside an object type declaration, you can omit ``__subject__`` and simply refer to properties with the :ref:`leading dot notation ` (e.g. ``.property``). .. note:: Also note that constraint expressions are fairly restricted. Due to how constraints are implemented, you can only reference ``single`` (non-multi) properties and links defined on the object type: .. code-block:: sdl # Not valid! type User { required username: str; multi friends: User; # ❌ constraints cannot contain paths with more than one hop constraint expression on ('bob' in .friends.username); } Abstract constraints ==================== .. api-index:: abstract constraint You can re-use constraints across multiple object types by declaring them as abstract constraints. Example: .. code-block:: sdl abstract constraint min_value(min: anytype) { errmessage := 'Minimum allowed value for {__subject__} is {min}.'; using (__subject__ >= min); } # use it like this: scalar type posint64 extending int64 { constraint min_value(0); } # or like this: type User { required age: int16 { constraint min_value(12); }; } Computed constraints ==================== Constraints can be defined on computed properties: .. code-block:: sdl type User { required username: str; required clean_username := str_trim(str_lower(.username)); constraint exclusive on (.clean_username); } Composite constraints ===================== .. api-index:: constraint exclusive on To define a composite constraint, create an ``exclusive`` constraint on a tuple of properties or links. ..
code-block:: sdl type User { username: str; } type BlogPost { title: str; author: User; constraint exclusive on ((.title, .author)); } .. _ref_datamodel_constraints_partial: Partial constraints =================== .. api-index:: constraint exclusive on, except Constraints on object types can be made partial, so that they are not enforced when the specified ``except`` condition is met. .. code-block:: sdl type User { required username: str; deleted: bool; # Usernames must be unique unless marked deleted constraint exclusive on (.username) except (.deleted); } Constraints on links ==================== You can constrain links such that a given object can only be linked once by using :eql:constraint:`exclusive`: .. code-block:: sdl type User { required name: str; # Make sure none of the "owned" items belong # to any other user. multi owns: Item { constraint exclusive; } } Link property constraints ========================= You can also add constraints for :ref:`link properties `: .. code-block:: sdl type User { name: str; multi friends: User { strength: float64; constraint expression on ( @strength >= 0 ); } } Link's "@source" and "@target" ============================== .. api-index:: @source, @target You can create a composite exclusive constraint on the object linking/linked *and* a link property by using ``@source`` or ``@target`` respectively. Here's a schema for a library book management app that tracks books and who has checked them out: .. code-block:: sdl type Book { required title: str; } type User { name: str; multi checked_out: Book { date: cal::local_date; # Ensures a given Book can be checked out # only once on a given day. constraint exclusive on ((@target, @date)); } } Here, the constraint ensures that no book can be checked out to two ``User``\s on the same ``@date``. In this example demonstrating ``@source``, we've created a schema to track player picks in a color-based memory game: .. 
code-block:: sdl type Player { required name: str; multi picks: Color { order: int16; constraint exclusive on ((@source, @order)); } } type Color { required name: str; } This constraint ensures that a single ``Player`` cannot pick two ``Color``\s at the same ``@order``. Constraints on custom scalars ============================= Custom scalar types can be constrained. .. code-block:: sdl scalar type username extending str { constraint regexp(r'^[A-Za-z0-9_]{4,20}$'); } Note: you can't use :eql:constraint:`exclusive` constraints on custom scalar types, as the concept of exclusivity is only defined in the context of a given object type. Use :eql:constraint:`expression` constraints to declare custom constraints using arbitrary EdgeQL expressions. The example below uses the built-in :eql:func:`str_trim` function. .. code-block:: sdl scalar type title extending str { constraint expression on ( __subject__ = str_trim(__subject__) ); } Constraints and inheritance =========================== .. api-index:: delegated constraint If you define a constraint on a type and then extend that type, the constraint will *not* be applied individually to each extending type. Instead, it will apply globally across all the types that inherited the constraint. .. code-block:: sdl type User { required name: str { constraint exclusive; } } type Administrator extending User; type Moderator extending User; .. code-block:: edgeql-repl gel> insert Administrator { .... name := 'Jan' .... }; {default::Administrator {id: 7aeaa146-f5a5-11ed-a598-53ddff476532}} gel> insert Moderator { .... name := 'Jan' .... }; gel error: ConstraintViolationError: name violates exclusivity constraint Detail: value of property 'name' of object type 'default::Moderator' violates exclusivity constraint gel> insert User { .... name := 'Jan' .... 
}; gel error: ConstraintViolationError: name violates exclusivity constraint Detail: value of property 'name' of object type 'default::User' violates exclusivity constraint As this example demonstrates, if an object of one extending type has a value for a property that is exclusive, an object of a *different* extending type cannot have the same value. If that's not what you want, you can instead delegate the constraint to the inheriting types by prepending the ``delegated`` keyword to the constraint. The constraint would then be applied just as if it were declared individually on each of the inheriting types. .. code-block:: sdl type User { required name: str { delegated constraint exclusive; } } type Administrator extending User; type Moderator extending User; .. code-block:: edgeql-repl gel> insert Administrator { .... name := 'Jan' .... }; {default::Administrator {id: 7aeaa146-f5a5-11ed-a598-53ddff476532}} gel> insert User { .... name := 'Jan' .... }; {default::User {id: a6e3fdaf-c44b-4080-b39f-6a07496de66b}} gel> insert Moderator { .... name := 'Jan' .... }; {default::Moderator {id: d3012a3f-0f16-40a8-8884-7203f393b63d}} gel> insert Moderator { .... name := 'Jan' .... }; gel error: ConstraintViolationError: name violates exclusivity constraint Detail: value of property 'name' of object type 'default::Moderator' violates exclusivity constraint With the addition of ``delegated`` to the constraints, the inserts were successful for each of the types. We did not hit a constraint violation until we tried to insert a second ``Moderator`` object with the same name as the existing one. .. _ref_eql_sdl_constraints_syntax: Declaring constraints ===================== This section describes the syntax to declare constraints in your schema. Syntax ------ .. sdl:synopsis:: [{abstract | delegated}] constraint [ ( [] [, ...] ) ] [ on ( ) ] [ except ( ) ] [ extending [, ...] ] "{" [ using ; ] [ errmessage := ; ] [ ] [ ... 
] "}" ; # where is: [ : ] { | } Description ^^^^^^^^^^^ This declaration defines a new constraint with the following options: :eql:synopsis:`abstract` If specified, the constraint will be *abstract*. :eql:synopsis:`delegated` If specified, the constraint is defined as *delegated*, which means that it will not be enforced on the type it's declared on, and the enforcement will be delegated to the subtypes of this type. This is particularly useful for :eql:constraint:`exclusive` constraints in abstract types. This is only valid for *concrete constraints*. :eql:synopsis:`` The name (optionally module-qualified) of the new constraint. :eql:synopsis:`` An optional list of constraint arguments. For an *abstract constraint* :eql:synopsis:`` optionally specifies the argument name and :eql:synopsis:`` specifies the argument type. For a *concrete constraint* :eql:synopsis:`` optionally specifies the argument name and :eql:synopsis:`` specifies the argument value. The argument value specification must match the parameter declaration of the abstract constraint. :eql:synopsis:`on ( )` An optional expression defining the *subject* of the constraint. If not specified, the subject is the value of the schema item on which the concrete constraint is defined. The expression must refer to the original subject of the constraint as ``__subject__``. The expression must be :ref:`Immutable `, but may refer to ``__subject__`` and its properties and links. Note also that ```` itself has to be parenthesized. :eql:synopsis:`except ( )` An optional expression defining a condition to create exceptions to the constraint. If ```` evaluates to ``true``, the constraint is ignored for the current subject. If it evaluates to ``false`` or ``{}``, the constraint applies normally. ``except`` may only be declared on object constraints, and otherwise follows the same rules as ``on``. :eql:synopsis:`extending [, ...]` If specified, declares the *parent* constraints for this abstract constraint. 
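The options described above can be combined in a single schema. The following is a minimal sketch rather than a canonical schema: the ``Named`` and ``Account`` types and their properties are hypothetical names chosen for illustration. It shows a ``delegated`` constraint on an abstract type, a subject expression given with ``on``, and an ``except`` condition:

.. code-block:: sdl

    abstract type Named {
        required name: str {
            # enforced separately within each concrete subtype
            delegated constraint exclusive;
        }
    }

    type Account extending Named {
        required email: str;
        archived: bool;

        # the subject is the lower-cased email;
        # accounts marked as archived are exempt
        constraint exclusive on (str_lower(.email)) except (.archived);
    }

Because ``except`` is only allowed on object constraints, the partial constraint is declared at the type level instead of inside the ``email`` property block.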
The valid SDL sub-declarations are listed below: :eql:synopsis:`using ` A boolean expression that returns ``true`` for valid data and ``false`` for invalid data. The expression may refer to the subject of the constraint as ``__subject__``. This declaration is only valid for *abstract constraints*. :eql:synopsis:`errmessage := ` An optional string literal defining the error message template that is raised when the constraint is violated. The template is a formatted string that may refer to constraint context variables in curly braces. The template may refer to the following: - ``$argname`` -- the value of the specified constraint argument - ``__subject__`` -- the value of the ``title`` annotation of the scalar type, property or link on which the constraint is defined. If the content of curly braces does not match any variables, the curly braces are emitted as-is. They can also be escaped by using double curly braces. :sdl:synopsis:`` Set constraint :ref:`annotation ` to a given *value*. .. _ref_eql_ddl_constraints: DDL commands ============ This section describes the low-level DDL commands for creating and dropping constraints and abstract constraints. You typically don't need to use these commands directly, but knowing about them is useful for reviewing migrations. Create abstract constraint -------------------------- :eql-statement: :eql-haswith: Define a new abstract constraint. .. eql:synopsis:: [ with [ := ] module ] create abstract constraint [ ( [] [, ...] ) ] [ on ( ) ] [ extending [, ...] ] "{" ; [...] "}" ; # where is: [ : ] # where is one of using set errmessage := create annotation := Description ^^^^^^^^^^^ The command ``create abstract constraint`` defines a new abstract constraint. If *name* is qualified with a module name, then the constraint is created in that module, otherwise it is created in the current module. The constraint name must be distinct from that of any existing schema item in the module. 
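As a sketch of how an argument and the error message template fit together in DDL, consider the hypothetical constraint below. The name ``not_below`` and its ``bound`` argument are illustrative assumptions and not part of the standard library; the body follows the same pattern as the ``min_value`` example shown earlier on this page:

.. code-block:: edgeql

    create abstract constraint not_below(bound: anytype) {
        # valid data must be greater than or equal to the argument
        using (__subject__ >= bound);
        # {__subject__} and {bound} are substituted when the
        # constraint is violated
        set errmessage := '{__subject__} must be at least {bound}.';
    };

Since the name is unqualified, the constraint would be created in the current module.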
Parameters ^^^^^^^^^^ Most sub-commands and options of this command are identical to the :ref:`SDL constraint declaration `, with some additional features listed below: :eql:synopsis:`[ := ] module ` An optional list of module alias declarations to be used in the migration definition. When *module-alias* is not specified, *module-name* becomes the effective current module and is used to resolve all unqualified names. :eql:synopsis:`set errmessage := ` An optional string literal defining the error message template that is raised when the constraint is violated. Other than a slight syntactical difference this is the same as the corresponding SDL declaration. :eql:synopsis:`create annotation := ;` Set constraint annotation ```` to ````. See :eql:stmt:`create annotation` for details. Example ^^^^^^^ Create an abstract constraint "uppercase" which checks if the subject is a string in upper case: .. code-block:: edgeql create abstract constraint uppercase { create annotation title := "Upper case constraint"; using (str_upper(__subject__) = __subject__); set errmessage := "{__subject__} is not in upper case"; }; Alter abstract constraint ------------------------- :eql-statement: :eql-haswith: Alter the definition of an abstract constraint. .. eql:synopsis:: [ with [ := ] module ] alter abstract constraint "{" ; [...] "}" ; # where is one of rename to using set errmessage := reset errmessage create annotation := alter annotation := drop annotation Description ^^^^^^^^^^^ The command ``alter abstract constraint`` changes the definition of an abstract constraint item. *name* must be a name of an existing abstract constraint, optionally qualified with a module name. Parameters ^^^^^^^^^^ :eql:synopsis:`[ := ] module ` An optional list of module alias declarations to be used in the migration definition. When *module-alias* is not specified, *module-name* becomes the effective current module and is used to resolve all unqualified names. 
:eql:synopsis:`` The name (optionally module-qualified) of the constraint to alter. Subcommands allowed in the ``alter abstract constraint`` block: :eql:synopsis:`rename to ` Change the name of the constraint to *newname*. All concrete constraints inheriting from this constraint are also renamed. :eql:synopsis:`alter annotation := ` Alter constraint annotation ````. See :eql:stmt:`alter annotation` for details. :eql:synopsis:`drop annotation ` Remove annotation ````. See :eql:stmt:`drop annotation` for details. :eql:synopsis:`reset errmessage` Remove the error message from this abstract constraint. The error message specified in the base abstract constraint will be used instead. All subcommands allowed in a ``create abstract constraint`` block are also valid here. Example ^^^^^^^ Rename the abstract constraint "uppercase" to "upper_case": .. code-block:: edgeql alter abstract constraint uppercase rename to upper_case; Drop abstract constraint ------------------------ :eql-statement: :eql-haswith: Remove an abstract constraint from the schema. .. eql:synopsis:: [ with [ := ] module ] drop abstract constraint ; Description ^^^^^^^^^^^ The command ``drop abstract constraint`` removes an existing abstract constraint item from the database schema. If any schema items depending on this constraint exist, the operation is refused. Parameters ^^^^^^^^^^ :eql:synopsis:`[ := ] module ` An optional list of module alias declarations to be used in the migration definition. :eql:synopsis:`` The name (optionally module-qualified) of the constraint to remove. Example ^^^^^^^ Drop abstract constraint ``upper_case``: .. code-block:: edgeql drop abstract constraint upper_case; Create constraint ----------------- :eql-statement: Define a concrete constraint on the specified schema item. .. eql:synopsis:: [ with [ := ] module ] create [ delegated ] constraint [ ( [] [, ...] ) ] [ on ( ) ] [ except ( ) ] "{" ; [...] 
"}" ; # where is: [ : ] # where is one of set errmessage := create annotation := Description ^^^^^^^^^^^ The command ``create constraint`` defines a new concrete constraint. It can only be used in the context of :eql:stmt:`create scalar`, :eql:stmt:`alter scalar`, :eql:stmt:`create property`, :eql:stmt:`alter property`, :eql:stmt:`create link`, or :eql:stmt:`alter link`. *name* must be a name (optionally module-qualified) of a previously defined abstract constraint. Parameters ^^^^^^^^^^ Most sub-commands and options of this command are identical to the :ref:`SDL constraint declaration `, with some additional features listed below: :eql:synopsis:`[ := ] module ` An optional list of module alias declarations to be used in the migration definition. :eql:synopsis:`set errmessage := ` An optional string literal defining the error message template that is raised when the constraint is violated. Other than a slight syntactical difference, this is the same as the corresponding SDL declaration. :eql:synopsis:`create annotation := ;` An optional list of annotations for the constraint. See :eql:stmt:`create annotation` for details. Example ^^^^^^^ Create a "score" property on the "User" type with a minimum value constraint: .. code-block:: edgeql alter type User create property score: int64 { create constraint min_value(0) }; Create a Vector with a maximum magnitude: .. code-block:: edgeql create type Vector { create required property x: float64; create required property y: float64; create constraint expression ON ( __subject__.x^2 + __subject__.y^2 < 25 ); } Alter constraint ---------------- :eql-statement: Alter the definition of a concrete constraint on the specified schema item. .. eql:synopsis:: [ with [ := ] module [, ...] ] alter constraint [ ( [] [, ...] ) ] [ on ( ) ] [ except ( ) ] "{" ; [ ... ] "}" ; # -- or -- [ with [ := ] module [, ...] ] alter constraint [ ( [] [, ...] 
) ] [ on ( ) ] ; # where is one of: set delegated set not delegated set errmessage := reset errmessage create annotation := alter annotation drop annotation Description ^^^^^^^^^^^ The command ``alter constraint`` changes the definition of a concrete constraint. Both single- and multi-command forms are supported. Parameters ^^^^^^^^^^ :eql:synopsis:`[ := ] module ` An optional list of module alias declarations for the migration. :eql:synopsis:`` The name (optionally module-qualified) of the concrete constraint that is being altered. :eql:synopsis:`` A list of constraint arguments as specified at the time of ``create constraint``. :eql:synopsis:`on ( )` An expression defining the *subject* of the constraint as specified at the time of ``create constraint``. The following subcommands are allowed in the ``alter constraint`` block: :eql:synopsis:`set delegated` Mark the constraint as *delegated*, which means it will not be enforced on the type it's declared on, and enforcement is delegated to subtypes. Useful for :eql:constraint:`exclusive` constraints. :eql:synopsis:`set not delegated` Mark the constraint as *not delegated*, so it is enforced globally across the type and any extending types. :eql:synopsis:`rename to ` Change the name of the constraint to ````. :eql:synopsis:`alter annotation ` Alter a constraint annotation. :eql:synopsis:`drop annotation ` Remove a constraint annotation. :eql:synopsis:`reset errmessage` Remove the error message from this constraint, reverting to that of the abstract constraint, if any. All subcommands allowed in ``create constraint`` are also valid in ``alter constraint``. Example ^^^^^^^ Change the error message on the minimum value constraint on the property "score" of the "User" type: .. 
code-block:: edgeql alter type User alter property score alter constraint min_value(0) set errmessage := 'Score cannot be negative'; Drop constraint --------------- :eql-statement: :eql-haswith: Remove a concrete constraint from the specified schema item. .. eql:synopsis:: [ with [ := ] module [, ...] ] drop constraint [ ( [] [, ...] ) ] [ on ( ) ] [ except ( ) ] ; Description ^^^^^^^^^^^ The command ``drop constraint`` removes the specified constraint from its containing schema item. Parameters ^^^^^^^^^^ :eql:synopsis:`[ := ] module ` Optional module alias declarations for the migration definition. :eql:synopsis:`` The name (optionally module-qualified) of the concrete constraint to remove. :eql:synopsis:`` A list of constraint arguments as specified at the time of ``create constraint``. :eql:synopsis:`on ( )` Expression defining the *subject* of the constraint as specified at the time of ``create constraint``. Example ^^^^^^^ Remove constraint "min_value" from the property "score" of the "User" type: .. code-block:: edgeql alter type User alter property score drop constraint min_value(0); .. list-table:: :class: seealso * - **See also** * - :ref:`Introspection > Constraints ` * - :ref:`Standard Library > Constraints ` ================================================ FILE: docs/reference/datamodel/extensions.rst ================================================ .. _ref_datamodel_extensions: ========== Extensions ========== .. api-index:: using extension Extensions are the way |Gel| can be extended with more functionality. They can add new types, scalars, functions, etc., but, more importantly, they can add new ways of interacting with the database. Built-in extensions =================== .. 
api-index:: edgeql_http, graphql, auth, ai, pg_trgm, pg_unaccent, pgcrypto, pgvector There are a few built-in extensions available: - ``edgeql_http``: enables :ref:`EdgeQL over HTTP `, - ``graphql``: enables :ref:`GraphQL `, - ``auth``: enables :ref:`Gel Auth `, - ``ai``: enables :ref:`ext::ai module `, - ``pg_trgm``: enables ``ext::pg_trgm``, which re-exports `pgtrgm `__, - ``pg_unaccent``: enables ``ext::pg_unaccent``, which re-exports `unaccent `__, - ``pgcrypto``: enables ``ext::pgcrypto``, which re-exports `pgcrypto `__, - ``pgvector``: enables ``ext::pgvector``, which re-exports `pgvector `__, .. _ref_datamodel_using_extension: To enable these extensions, add a ``using`` statement at the top level of your schema: .. code-block:: sdl using extension auth; # or / and using extension ai; Standalone extensions ===================== .. api-index:: postgis Additionally, standalone extension packages can be installed on local project-managed instances via the CLI, with ``postgis`` being a notable example. List installed extensions: .. code-block:: bash $ gel extension list ┌─────────┬─────────┐ │ Name │ Version │ └─────────┴─────────┘ List available extensions: .. code-block:: bash $ gel extension list-available ┌─────────┬───────────────┐ │ Name │ Version │ │ postgis │ 3.4.3+6b82d77 │ └─────────┴───────────────┘ Install the ``postgis`` extension: .. code-block:: bash $ gel extension install postgis Found extension package: postgis version 3.4.3+6b82d77 00:00:03 [====================] 22.49 MiB/22.49 MiB Extension 'postgis' installed successfully. Check that extension is installed: .. code-block:: bash $ gel extension list ┌─────────┬───────────────┐ │ Name │ Version │ │ postgis │ 3.4.3+6b82d77 │ └─────────┴───────────────┘ After installing extensions, make sure to restart your instance: .. code-block:: bash $ gel instance restart Standalone extensions can now be declared in the schema, same as built-in extensions: .. code-block:: sdl using extension postgis; .. 
note:: To restore a dump that uses a standalone extension, that extension must be installed before the restore process. .. _ref_eql_sdl_extensions: Using extensions ================ Syntax ------ .. sdl:synopsis:: using extension ";" Extension declaration must be outside any :ref:`module block ` since extensions affect the entire database and not a specific module. .. _ref_eql_ddl_extensions: DDL commands ============ This section describes the low-level DDL commands for creating and dropping extensions. You typically don't need to use these commands directly, but knowing about them is useful for reviewing migrations. create extension ---------------- :eql-statement: Enable a particular extension for the current schema. .. eql:synopsis:: create extension ";" Description ^^^^^^^^^^^ The command ``create extension`` enables the specified extension for the current :versionreplace:`database;5.0:branch`. Examples ^^^^^^^^ Enable :ref:`GraphQL ` extension for the current schema: .. code-block:: edgeql create extension graphql; Enable :ref:`EdgeQL over HTTP ` extension for the current :versionreplace:`database;5.0:branch`: .. code-block:: edgeql create extension edgeql_http; drop extension -------------- :eql-statement: Disable an extension. .. eql:synopsis:: drop extension ";" The command ``drop extension`` disables a currently active extension for the current |branch|. Examples ^^^^^^^^ Disable :ref:`GraphQL ` extension for the current schema: .. code-block:: edgeql drop extension graphql; Disable :ref:`EdgeQL over HTTP ` extension for the current :versionreplace:`database;5.0:branch`: .. code-block:: edgeql drop extension edgeql_http; ================================================ FILE: docs/reference/datamodel/functions.rst ================================================ .. _ref_datamodel_functions: .. _ref_eql_sdl_functions: ========= Functions ========= .. 
note:: This page documents how to define custom functions; however, |Gel| also provides a large library of built-in functions and operators. These are documented in :ref:`Standard Library `. User-defined Functions ====================== Gel allows you to define custom functions. For example, consider a function that adds an exclamation mark ``'!'`` at the end of the string: .. code-block:: sdl function exclamation(word: str) -> str using (word ++ '!'); This function accepts a :eql:type:`str` as an argument and produces a :eql:type:`str` as output as well. .. code-block:: edgeql-repl test> select exclamation({'Hello', 'World'}); {'Hello!', 'World!'} .. _ref_datamodel_functions_modifying: Sets as arguments ================= Calling a user-defined function on a set will always apply it :ref:`*element-wise* `. .. code-block:: sdl function magnitude(x: float64) -> float64 using ( math::sqrt(sum(x * x)) ); .. code-block:: edgeql-repl db> select magnitude({3, 4}); {3, 4} In order to pass in multiple arguments at once, arguments should be packed into arrays: .. code-block:: sdl function magnitude(xs: array<float64>) -> float64 using ( with x := array_unpack(xs) select math::sqrt(sum(x * x)) ); .. code-block:: edgeql-repl db> select magnitude([3, 4]); {5} Multiple packed arrays can be passed into such a function, which will then be applied element-wise. .. code-block:: edgeql-repl db> select magnitude({[3, 4], [5, 12]}); {5, 13} Modifying Functions =================== .. versionadded:: 6.0 User-defined functions can contain DML (i.e., :ref:`insert `, :ref:`update `, :ref:`delete `) to make changes to existing data. These functions have a :ref:`modifying ` volatility. .. code-block:: sdl function add_user(name: str) -> User using ( insert User { name := name, joined_at := std::datetime_current(), } ); ..
code-block:: edgeql-repl db> select add_user('Jan') {name, joined_at}; {default::User {name: 'Jan', joined_at: '2024-12-11T11:49:47Z'}} Unlike other functions, the arguments of modifying functions **must** have a :ref:`cardinality ` of ``One``. .. code-block:: edgeql-repl db> select add_user({'Feb','Mar'}); gel error: QueryError: possibly more than one element passed into modifying function db> select add_user({}); gel error: QueryError: possibly an empty set passed as non-optional argument into modifying function Optional arguments can still accept empty sets. For example, if ``add_user`` was defined as: .. code-block:: sdl function add_user(name: str, joined_at: optional datetime) -> User using ( insert User { name := name, joined_at := joined_at ?? std::datetime_current(), } ); then the following queries are valid: .. code-block:: edgeql-repl db> select add_user('Apr', <datetime>{}) {name, joined_at}; {default::User {name: 'Apr', joined_at: '2024-12-11T11:50:51Z'}} db> select add_user('May', <datetime>'2024-12-11T12:00:00-07:00') {name, joined_at}; {default::User {name: 'May', joined_at: '2024-12-11T19:00:00Z'}} In order to insert or update a multi parameter, the desired arguments should be aggregated into an array as described above: .. code-block:: sdl function add_user(name: str, nicknames: array<str>) -> User using ( insert User { name := name, nicknames := array_unpack(nicknames), } ); .. _ref_eql_sdl_functions_syntax: Declaring functions =================== .. api-index:: function, using, ->, variadic, named only, set of, optional, volatility, Immutable, Stable, Volatile, Modifying This section describes the syntax to declare a function in your schema. Syntax ------ .. sdl:synopsis:: function ([ ] [, ... ]) -> using ( ); function ([ ] [, ... ]) -> using ; function ([ ] [, ... ]) -> "{" [ ] [ volatility := {'Immutable' | 'Stable' | 'Volatile' | 'Modifying'} ] [ using ( ) ; ] [ using ; ] [ ...
] "}" ; # where is: [ ] : [ ] [ = ] # is: [ { variadic | named only } ] # is: [ { set of | optional } ] # and is: [ ] Description ^^^^^^^^^^^ This declaration defines a new **function** with the following options: :eql:synopsis:`` The name (optionally module-qualified) of the function to create. :eql:synopsis:`` The kind of an argument: ``variadic`` or ``named only``. If not specified, the argument is called *positional*. The ``variadic`` modifier indicates that the function takes an arbitrary number of arguments of the specified type. The passed arguments will be passed as an array of the argument type. Positional arguments cannot follow a ``variadic`` argument. ``variadic`` parameters cannot have a default value. The ``named only`` modifier indicates that the argument can only be passed using that specific name. Positional arguments cannot follow a ``named only`` argument. :eql:synopsis:`` The name of an argument. If ``named only`` modifier is used this argument *must* be passed using this name only. .. _ref_sdl_function_typequal: :eql:synopsis:`` The type qualifier: ``set of`` or ``optional``. The ``set of`` qualifier indicates that the function is taking the argument as a *whole set*, as opposed to being called on the input product element-by-element. User defined functions can not use ``set of`` arguments. The ``optional`` qualifier indicates that the function will be called if the argument is an empty set. The default behavior is to return an empty set if the argument is not marked as ``optional``. :eql:synopsis:`` The data type of the function's arguments (optionally module-qualified). :eql:synopsis:`` An expression to be used as default value if the parameter is not specified. The expression has to be of a type compatible with the type of the argument. .. _ref_sdl_function_rettype: :eql:synopsis:`` The return data type (optionally module-qualified). The ``set of`` modifier indicates that the function will return a non-singleton set. 
The ``optional`` qualifier indicates that the function may return an empty set. The valid SDL sub-declarations are listed below: :eql:synopsis:`volatility := {'Immutable' | 'Stable' | 'Volatile' | 'Modifying'}` Function volatility determines how aggressively the compiler can optimize its invocations. If not explicitly specified the function volatility is :ref:`inferred ` from the function body. * An ``Immutable`` function cannot modify the database and is guaranteed to return the same results given the same arguments *in all statements*. * A ``Stable`` function cannot modify the database and is guaranteed to return the same results given the same arguments *within a single statement*. * A ``Volatile`` function cannot modify the database and can return different results on successive calls with the same arguments. * A ``Modifying`` function can modify the database and can return different results on successive calls with the same arguments. :eql:synopsis:`using ( )` Specifies the body of the function. :eql:synopsis:`` is an arbitrary EdgeQL expression. :eql:synopsis:`using ` A verbose version of the :eql:synopsis:`using` clause that allows specifying the language of the function body. * :eql:synopsis:`` is the name of the language that the function is implemented in. Currently can only be ``edgeql``. * :eql:synopsis:`` is a string constant defining the function. It is often helpful to use :ref:`dollar quoting ` to write the function definition string. :sdl:synopsis:`` Set function :ref:`annotation ` to a given *value*. The function name must be distinct from that of any existing function with the same argument types in the same module. Functions of different argument types can share a name, in which case the functions are called *overloaded functions*. .. _ref_eql_ddl_functions: DDL commands ============ This section describes the low-level DDL commands for creating, altering, and dropping functions. 
You typically don't need to use these commands directly, but knowing about them is useful for reviewing migrations. Create function --------------- :eql-statement: :eql-haswith: Define a new function. .. eql:synopsis:: [ with [, ...] ] create function ([ ] [, ... ]) -> using ( ); [ with [, ...] ] create function ([ ] [, ... ]) -> using ; [ with [, ...] ] create function ([ ] [, ... ]) -> "{" [, ...] "}" ; # where is: [ ] : [ ] [ = ] # is: [ { variadic | named only } ] # is: [ { set of | optional } ] # and is: [ ] # and is one of set volatility := {'Immutable' | 'Stable' | 'Volatile' | 'Modifying'} ; create annotation := ; using ( ) ; using ; Description ^^^^^^^^^^^ The command ``create function`` defines a new function. If *name* is qualified with a module name, then the function is created in that module, otherwise it is created in the current module. The function name must be distinct from that of any existing function with the same argument types in the same module. Functions of different argument types can share a name, in which case the functions are called *overloaded functions*. Parameters ^^^^^^^^^^ Most sub-commands and options of this command are identical to the :ref:`SDL function declaration `, with some additional features listed below: :eql:synopsis:`set volatility := {'Immutable' | 'Stable' | 'Volatile' | 'Modifying'}` Function volatility determines how aggressively the compiler can optimize its invocations. Other than a slight syntactical difference this is the same as the corresponding SDL declaration. :eql:synopsis:`create annotation := ` Set the function's :eql:synopsis:`` to :eql:synopsis:``. See :eql:stmt:`create annotation` for details. Examples ^^^^^^^^ Define a function returning the sum of its arguments: .. code-block:: edgeql create function mysum(a: int64, b: int64) -> int64 using ( select a + b ); The same, but using a variadic argument and an explicit language: .. 
code-block:: edgeql create function mysum(variadic argv: int64) -> int64 using edgeql $$ select sum(array_unpack(argv)) $$; Define a function using the block syntax: .. code-block:: edgeql create function mysum(a: int64, b: int64) -> int64 { using ( select a + b ); create annotation title := "My sum function."; }; Alter function -------------- :eql-statement: :eql-haswith: Change the definition of a function. .. eql:synopsis:: [ with [, ...] ] alter function ([ ] [, ... ]) "{" [, ...] "}" # where is: [ ] : [ ] [ = ] # and is one of set volatility := {'Immutable' | 'Stable' | 'Volatile' | 'Modifying'} ; reset volatility ; rename to ; create annotation := ; alter annotation := ; drop annotation ; using ( ) ; using ; Description ^^^^^^^^^^^ The command ``alter function`` changes the definition of a function. The command allows changing annotations, the volatility level, and other attributes. Subcommands ^^^^^^^^^^^ The following subcommands are allowed in the ``alter function`` block in addition to the commands common to ``create function``: :eql:synopsis:`reset volatility` Remove explicitly specified volatility in favor of the volatility inferred from the function body. :eql:synopsis:`rename to ` Change the name of the function to *newname*. :eql:synopsis:`alter annotation ;` Alter function :eql:synopsis:``. See :eql:stmt:`alter annotation` for details. :eql:synopsis:`drop annotation ;` Remove function :eql:synopsis:``. See :eql:stmt:`drop annotation` for details. Example ^^^^^^^ .. 
code-block:: edgeql create function mysum(a: int64, b: int64) -> int64 { using ( select a + b ); create annotation title := "My sum function."; }; alter function mysum(a: int64, b: int64) { set volatility := 'Immutable'; drop annotation title; }; alter function mysum(a: int64, b: int64) { using ( select (a + b) * 100 ) }; Drop function ------------- :eql-statement: :eql-haswith: Remove a function. .. eql:synopsis:: [ with [, ...] ] drop function ([ ] [, ... ]); # where is: [ ] : [ ] [ = ] Description ^^^^^^^^^^^ The command ``drop function`` removes the definition of an existing function. The argument types to the function must be specified, since there can be different functions with the same name. Parameters ^^^^^^^^^^ :eql:synopsis:`` The name (optionally module-qualified) of an existing function. :eql:synopsis:`` The name of an argument used in the function definition. :eql:synopsis:`` The mode of an argument: ``set of`` or ``optional`` or ``variadic``. :eql:synopsis:`` The data type(s) of the function's arguments (optionally module-qualified), if any. Example ^^^^^^^ Remove the ``mysum`` function: .. code-block:: edgeql drop function mysum(a: int64, b: int64); .. list-table:: :class: seealso * - **See also** * - :ref:`Reference > Function calls ` * - :ref:`Introspection > Functions ` * - :ref:`Cheatsheets > Functions ` ================================================ FILE: docs/reference/datamodel/future.rst ================================================ .. _ref_datamodel_future: =============== Future behavior =============== .. api-index:: using future This article explains what the ``using future ...;`` statement means in your schema. Our goal is to make |Gel| the best database system in the world, which requires us to keep evolving. Usually, we can add new functionality while preserving backward compatibility, but on rare occasions we must implement changes that require elaborate transitions. 
To handle these cases, we introduce *future* behavior, which lets you try out upcoming features before a major release. Sometimes enabling a future is necessary to fix current issues; other times it offers a safe and easy way to ensure your codebase remains compatible. This approach provides more time to adopt a new feature and identify any resulting bugs. Any time a behavior is available as a ``future``, all new :ref:`projects ` enable it by default for empty databases. You can remove a ``future`` from your schema if absolutely necessary, but doing so is discouraged. Existing projects are unaffected by default, so you must manually add the ``future`` specification to gain early access. Flags ===== .. api-index:: simple_scoping, warn_old_scoping, nonrecursive_access_policies At the moment there are three ``future`` flags available: - ``simple_scoping`` Introduced in |Gel| 6.0, this flag simplifies the scoping rules for path expressions. Read more about it in detail in :ref:`ref_eql_path_resolution`. - ``warn_old_scoping`` Introduced in |Gel| 6.0, this flag will emit a warning when a query is detected to depend on the old scoping rules. This is an intermediate step towards enabling the ``simple_scoping`` flag in existing large codebases. Read more about this flag in :ref:`ref_warn_old_scoping`. .. _ref_datamodel_access_policies_nonrecursive: .. _nonrecursive: - ``nonrecursive_access_policies``: makes access policies non-recursive. This flag is no longer used because the behavior has been enabled by default since |EdgeDB| 4. The flag was helpful to ease the transition from EdgeDB 3.x to 4.x. Since |EdgeDB| 3.0, access policy restrictions do **not** apply to any access policy expression. This means that when reasoning about access policies it is no longer necessary to take other policies into account. Instead, all data is visible for the purpose of *defining* an access policy. 
This change was made to simplify reasoning about access policies and to allow certain patterns to be expressed efficiently. Since those who have access to modifying the schema can remove unwanted access policies, no additional security is provided by applying access policies to each other's expressions. .. _ref_eql_sdl_future: Declaring future flags ====================== Syntax ------ Declare that the current schema enables a particular future behavior. .. sdl:synopsis:: using future ";" Description ^^^^^^^^^^^ Future behavior declaration must be outside any :ref:`module block ` since this behavior affects the entire database and not a specific module. Example ^^^^^^^ .. code-block:: sdl-invalid using future simple_scoping; .. _ref_eql_ddl_future: DDL commands ============ This section describes the low-level DDL commands for creating and dropping future flags. You typically don't need to use these commands directly, but knowing about them is useful for reviewing migrations. Create future ------------- :eql-statement: Enable a particular future behavior for the current schema. .. eql:synopsis:: create future ";" The command ``create future`` enables the specified future behavior for the current branch. Example ^^^^^^^ .. code-block:: edgeql create future simple_scoping; Drop future ----------- :eql-statement: Disable a particular future behavior for the current schema. .. eql:synopsis:: drop future ";" Description ^^^^^^^^^^^ The command ``drop future`` disables a currently active future behavior for the current branch. However, this is only possible for versions of |Gel| when the behavior in question is not officially introduced. Once a particular behavior is introduced as the standard behavior in a |Gel| release, it cannot be disabled. Example ^^^^^^^ .. code-block:: edgeql drop future warn_old_scoping; ================================================ FILE: docs/reference/datamodel/globals.rst ================================================ .. 
_ref_datamodel_globals: ======= Globals ======= Schemas in Gel can contain typed *global variables*. These create a mechanism for specifying session-level context that can be referenced in queries, access policies, triggers, and elsewhere with the ``global`` keyword. Here's a very common example of a global variable representing the current user ID: .. code-block:: sdl global current_user_id: uuid; .. tabs:: .. code-tab:: edgeql select User { id, posts: { title, content } } filter .id = global current_user_id; .. code-tab:: python # In a non-trivial example, `global current_user_id` would # be used indirectly in an access policy or some other context. await client.with_globals({'current_user_id': user_id}).query(''' select User { id, posts: { title, content } } filter .id = global current_user_id; ''') .. code-tab:: typescript // In a non-trivial example, `global current_user_id` would // be used indirectly in an access policy or some other context. await client.withGlobals({current_user_id: user_id}).query(` select User { id, posts: { title, content } } filter .id = global current_user_id; `) Setting global variables ======================== Global variables are set at session level or when initializing a client. The exact API depends on which client library you're using, but the general behavior and principles are the same across all libraries. .. tabs:: .. code-tab:: typescript import createClient from 'gel'; const baseClient = createClient(); // returns a new Client instance that shares the underlying // network connection with `baseClient`, but sends the configured // globals along with all queries run through it: const clientWithGlobals = baseClient.withGlobals({ current_user_id: '2141a5b4-5634-4ccc-b835-437863534c51', }); const result = await clientWithGlobals.query( `select global current_user_id;` ); .. 
code-tab:: python from gel import create_client base_client = create_client() # returns a new Client instance that shares the underlying # network connection with `base_client`, but sends the configured # globals along with all queries run through it: client = base_client.with_globals({ 'current_user_id': '580cc652-8ab8-4a20-8db9-4c79a4b1fd81' }) result = client.query(""" select global current_user_id; """) .. code-tab:: go package main import ( "context" "fmt" "log" "github.com/geldata/gel-go" ) func main() { ctx := context.Background() client, err := gel.CreateClient(ctx, gel.Options{}) if err != nil { log.Fatal(err) } defer client.Close() id, err := gel.ParseUUID("2141a5b4-5634-4ccc-b835-437863534c51") if err != nil { log.Fatal(err) } var result gel.UUID err = client. WithGlobals(map[string]interface{}{"current_user_id": id}). QuerySingle(ctx, "select global current_user_id;", &result) if err != nil { log.Fatal(err) } fmt.Println(result) } .. code-tab:: rust use uuid::Uuid; let client = gel_tokio::create_client().await.expect("Client init"); let client_with_globals = client.with_globals_fn(|c| { c.set( "current_user_id", Value::Uuid( Uuid::parse_str("2141a5b4-5634-4ccc-b835-437863534c51") .expect("Uuid should have parsed"), ), ) }); let val: Uuid = client_with_globals .query_required_single("select global current_user_id;", &()) .await .expect("Returning value"); println!("Result: {val}"); .. code-tab:: edgeql set global current_user_id := <uuid>'2141a5b4-5634-4ccc-b835-437863534c51'; Cardinality =========== A global variable can be declared with one of two cardinalities: - ``single`` (the default): At most one value. - ``multi``: A set of values. Only valid for computed global variables. In addition, a global can be marked ``required`` or ``optional`` (the default). If marked ``required``, a default value must be provided. Computed globals ================ .. api-index:: global, := Global variables can also be computed. 
The value of computed globals is dynamically computed when they are referenced in queries. .. code-block:: sdl required global now := datetime_of_transaction(); The provided expression will be computed at the start of each query in which the global is referenced. There's no need to provide an explicit type; the type is inferred from the computed expression. Computed globals can also be object-typed and have ``multi`` cardinality. For example: .. code-block:: sdl global current_user_id: uuid; # object-typed global global current_user := ( select User filter .id = global current_user_id ); # multi global global current_user_friends := (global current_user).friends; Referencing globals =================== .. api-index:: global Unlike query parameters, globals can be referenced *inside your schema declarations*: .. code-block:: sdl type User { name: str; is_self := (.id = global current_user_id) }; This is particularly useful when declaring :ref:`access policies `: .. code-block:: sdl type Person { required name: str; access policy my_policy allow all using (.id = global current_user_id); } Refer to :ref:`Access Policies ` for complete documentation. .. _ref_eql_sdl_globals: .. _ref_eql_sdl_globals_syntax: Declaring globals ================= .. api-index:: required, optional, single, multi, global, :=, :, default This section describes the syntax to declare a global variable in your schema. Syntax ------ Define a new global variable in SDL, corresponding to the more explicit DDL commands described later: .. sdl:synopsis:: # Global variable declaration: [{required | optional}] [single] global : [ "{" [ default := ; ] [ ] ... "}" ] # Computed global variable declaration: [{required | optional}] [{single | multi}] global := ; Description ^^^^^^^^^^^ There are two different forms of ``global`` declarations, as shown in the syntax synopsis above: 1. A *settable* global (defined with ``: ``) which can be changed using a session-level :ref:`set ` command. 2. 
A *computed* global (defined with ``:= ``), which cannot be directly set but instead derives its value from the provided expression. The following options are available: :eql:synopsis:`required` If specified, the global variable is considered *required*. It is an error for this variable to have an empty value. If a global variable is declared *required*, it must also declare a *default* value. :eql:synopsis:`optional` The global variable is considered *optional*, i.e. it is possible for the variable to have an empty value. (This is the default.) :eql:synopsis:`multi` Specifies that the global variable may have a set of values. Only *computed* global variables can have this qualifier. :eql:synopsis:`single` Specifies that the global variable must have at most a *single* value. It is assumed that a global variable is ``single`` if neither ``multi`` nor ``single`` is specified. All non-computed global variables must be *single*. :eql:synopsis:`` The name of the global variable. It can be fully-qualified with the module name, or it is assumed to belong to the module in which it appears. :eql:synopsis:`` The type must be a valid :ref:`type expression ` denoting a non-abstract scalar or a container type. :eql:synopsis:` := ` Defines a *computed* global variable. The provided expression must be a :ref:`Stable ` EdgeQL expression. It can refer to other global variables. The type of a *computed* global variable is not limited to scalar and container types; it can also be an object type. The valid SDL sub-declarations are: :eql:synopsis:`default := ` Specifies the default value for the global variable as an EdgeQL expression. The default value is used in a session if the value was not explicitly specified by the client, or was reset with the :ref:`reset ` command. :sdl:synopsis:`` Set global variable :ref:`annotation ` to a given *value*. Examples -------- Declare a new global variable: .. 
code-block:: sdl global current_user_id: uuid; global current_user := ( select User filter .id = global current_user_id ); Set the global variable to a specific value using :ref:`session-level commands `: .. code-block:: edgeql set global current_user_id := <uuid>'00ea8eaa-02f9-11ed-a676-6bd11cc6c557'; Use the computed global variable that is based on the value that was just set: .. code-block:: edgeql select global current_user { name }; :ref:`Reset ` the global variable to its default value: .. code-block:: edgeql reset global current_user_id; .. _ref_eql_ddl_globals: DDL commands ============ This section describes the low-level DDL commands for creating, altering, and dropping globals. You typically don't need to use these commands directly, but knowing about them is useful for reviewing migrations. Create global ------------- :eql-statement: :eql-haswith: Declare a new global variable using DDL. .. eql:synopsis:: [ with [, ...] ] create [{required | optional}] [single] global : [ "{" ; [...] "}" ] ; # Computed global variable form: [ with [, ...] ] create [{required | optional}] [{single | multi}] global := ; # where is one of set default := create annotation := Description ^^^^^^^^^^^ As with SDL, there are two different forms of ``global`` declaration: - A global variable that can be :ref:`set ` in a session. - A *computed* global that is derived from an expression (and so cannot be directly set in a session). The subcommands mirror those in SDL: :eql:synopsis:`set default := ` Specifies the default value for the global variable as an EdgeQL expression. The default value is used by the session if the value was not explicitly specified or was reset with the :ref:`reset ` command. :eql:synopsis:`create annotation := ` Assign an annotation to the global variable. See :eql:stmt:`create annotation` for details. Examples ^^^^^^^^ Define a new global variable ``current_user_id``: .. 
code-block:: edgeql create global current_user_id: uuid; Define a new *computed* global variable ``current_user`` based on the previously defined ``current_user_id``: .. code-block:: edgeql create global current_user := ( select User filter .id = global current_user_id ); Alter global ------------ :eql-statement: :eql-haswith: Change the definition of a global variable. .. eql:synopsis:: [ with [, ...] ] alter global [ "{" ; [...] "}" ] ; # where is one of set default := reset default rename to set required set optional reset optionality set single set multi reset cardinality set type reset to default using () create annotation := alter annotation := drop annotation Description ^^^^^^^^^^^ The command :eql:synopsis:`alter global` changes the definition of a global variable. It can modify default values, rename the global, or change other attributes like optionality, cardinality, computed expressions, etc. Examples ^^^^^^^^ Set the ``description`` annotation of global variable ``current_user``: .. code-block:: edgeql alter global current_user create annotation description := 'Current User as specified by the global ID'; Make the ``current_user_id`` global variable ``required``: .. code-block:: edgeql alter global current_user_id { set required; # A required global variable MUST have a default value. set default := <uuid>'00ea8eaa-02f9-11ed-a676-6bd11cc6c557'; } Drop global ----------- :eql-statement: :eql-haswith: Remove a global variable from the schema. .. eql:synopsis:: [ with [, ...] ] drop global ; Description ^^^^^^^^^^^ The command :eql:synopsis:`drop global` removes the specified global variable from the schema. Example ^^^^^^^ Remove the ``current_user`` global variable: .. code-block:: edgeql drop global current_user; ================================================ FILE: docs/reference/datamodel/index.rst ================================================ .. versioned-section:: .. _ref_datamodel_index: ====== Schema ====== .. 
toctree:: :maxdepth: 3 :hidden: objects properties links computeds primitives indexes constraints inheritance aliases globals access_policies permissions functions triggers mutation_rewrites linkprops modules migrations branches extensions annotations future comparison introspection/index |Gel| schema is a high-level description of your application's data model. In the schema, you define your types, links, access policies, functions, triggers, constraints, indexes, and more. Gel schema is strictly typed and is high-level enough to be mapped directly to mainstream programming languages and back. .. _ref_eql_sdl: Schema Definition Language ========================== Schemas are declared in |.gel| files using the *schema definition language* (SDL). Example: .. code-block:: sdl # dbschema/default.gel type Movie { required title: str; required director: Person; } type Person { required name: str; } .. important:: Syntax highlighter packages/extensions for |.gel| files are available for `Visual Studio Code `_, `Sublime Text `_, `Atom `_, and `Vim `_. Migrations and DDL ================== Gel's baked-in migration system lets you painlessly evolve your schema over time. Just update the contents of your |.gel| file(s) and use the |Gel| CLI to *create* and *apply* migrations. .. code-block:: bash $ gel migration create Created dbschema/migrations/00001.edgeql $ gel migrate Applied dbschema/migrations/00001.edgeql Migrations are sequences of *data definition language* (DDL) commands. DDL is a low-level language that tells the database exactly how to change the schema. You typically won't need to write any DDL by hand; the Gel server will generate it for you. 
For a full guide on migrations, refer to the :ref:`Creating and applying migrations ` guide or the :ref:`migrations reference ` section. .. _ref_datamodel_terminology: .. _ref_datamodel_instances: Instances, branches, and modules ================================ Gel is like a stack of containers: * The *instance* is the running Gel process. Every instance has one or more |branches|. Instances can be created, started, stopped, and destroyed locally with :ref:`gel project ` or low-level :ref:`gel instance ` commands. * A *branch* is where your schema and data live. Branches map to PostgreSQL databases. Like instances, branches can be conveniently created, removed, and switched with the :ref:`gel branch ` commands. Read more about branches in the :ref:`branches reference `. * A *module* is a collection of types, functions, and other definitions. The default module is called ``default``. Modules are used to organize your schema logically. Read more about modules in the :ref:`modules reference `. ================================================ FILE: docs/reference/datamodel/indexes.rst ================================================ .. _ref_datamodel_indexes: ======= Indexes ======= .. index:: performance, postgres query planner An index is a data structure used internally to speed up filtering, ordering, and grouping operations in |Gel|. Indexes help accomplish this in two key ways: - They are pre-sorted, which saves time on costly sort operations on rows. - They can be used by the query planner to filter out irrelevant rows. .. note:: The Postgres query planner decides when to use indexes for a query. In some cases—e.g. when tables are small—it may be faster to scan the whole table rather than use an index. In such scenarios, the index might be ignored. For more information on how the planner decides this, see `the Postgres query planner documentation `_. 
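One way to see what the planner does with a particular query is to prefix the query with ``analyze``, which reports a performance analysis of that query. The sketch below assumes a hypothetical ``User`` type with a ``name`` property; both names are illustrative only:

.. code-block:: edgeql

    # Hypothetical type and property names, used for illustration.
    analyze select User filter .name = 'Alice';

Running the same ``analyze`` query before and after adding an index is a simple way to judge whether the index is being used.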
Tradeoffs ========= While improving query performance, indexes also increase disk and memory usage and can slow down insertions and updates. Creating too many indexes may be detrimental; only index properties you often filter, order, or group by. .. important:: **Foreign and primary keys** In SQL databases, indexes are commonly used to index *primary keys* and *foreign keys*. Gel's analog to a SQL primary key is the ``id`` field automatically created for each object, while a link in Gel is the analog to a SQL foreign key. Both of these are automatically indexed. Moreover, any property with an :eql:constraint:`exclusive` constraint is also automatically indexed. Index on a property =================== Most commonly, indexes are declared within object type declarations and reference a particular property. The index can be used to speed up queries that reference that property in a filter, order by, or group by clause: .. code-block:: sdl type User { required name: str; index on (.name); } By indexing on ``User.name``, the query planner will have access to that index when planning queries using the ``name`` property. This may result in better performance as the database can look up a name in the index instead of scanning through all ``User`` objects sequentially—though ultimately it's up to the Postgres query planner whether to use the index. To see if an index helps, compare query plans by adding :ref:`analyze ` to your queries. .. note:: Even if your database is small now, you may benefit from an index as it grows. Index on an expression ====================== Indexes may be defined using an arbitrary *singleton* expression that references multiple properties of the enclosing object type. .. important:: A singleton expression is an expression that's guaranteed to return *at most one* element. As such, you can't index on a ``multi`` property. Example: .. 
code-block:: sdl type User { required first_name: str; required last_name: str; index on (str_lower(.first_name ++ ' ' ++ .last_name)); } Index on multiple properties ============================ A *composite index* references multiple properties. This can speed up queries that filter, order, or group on multiple properties at once. .. note:: An index on multiple properties may also be used in queries where only a single property in the index is referenced. In many traditional database systems, placing the most frequently used columns first in the composite index can improve the likelihood of its use. Read `the Postgres documentation on multicolumn indexes `_ to learn more about how the query planner uses these indexes. In |Gel|, a composite index is created by indexing on a ``tuple`` of properties: .. code-block:: sdl type User { required name: str; required email: str; index on ((.name, .email)); } Index on a link property ======================== Link properties can also be indexed. The special placeholder ``__subject__`` refers to the source object in a link property expression: .. code-block:: sdl abstract link friendship { strength: float64; index on (__subject__@strength); } type User { multi friends: User { extending friendship; }; } Exclude objects from an index ============================= When specifying an index, you can provide an optional ``except`` clause to exclude objects from the index. This is known as creating a *partial index*. Partial indexes are particularly useful in scenarios where you frequently query a subset of data that meets certain criteria, while consistently excluding other data. For example, if you often filter on a property but always exclude objects with a specific value for another property, a partial index can optimize these queries by indexing only the relevant subset of data, thus improving query performance and reducing index size. .. 
code-block:: sdl type User { required name: str; required email: str; archived_at: datetime; index on (.name) except (exists .archived_at); } Specify a Postgres index type ============================= .. api-index:: pg::hash, pg::btree, pg::gin, pg::gist, pg::spgist, pg::brin .. versionadded:: 3.0 Gel exposes Postgres index types that can be used directly in schemas via the ``pg`` module: - ``pg::hash`` : Index based on a 32-bit hash of the value - ``pg::btree`` : B-tree index (can help with sorted data retrieval) - ``pg::gin`` : Inverted index for multi-element data (arrays, JSON) - ``pg::gist`` : Generalized Search Tree for range and geometric searches - ``pg::spgist`` : Space-partitioned GiST - ``pg::brin`` : Block Range INdex Example: .. code-block:: sdl type User { required name: str; index pg::spgist on (.name); } .. _ref_datamodel_indexes_concurrent: Concurrent index building ========================= When creating an index, the object type will be locked for writes. This means that until the index is created, all ``insert``, ``update`` and ``delete`` queries will be put on hold. On types containing many objects, this can span minutes or even hours. To avoid this, index building can be deferred from migration application to a later time. To do this, set the ``build_concurrently`` index property to ``true``: .. code-block:: sdl type User { name: str; index on (.name) { build_concurrently := true; }; } When this schema is applied to an instance, the index will be created, but it will not yet be active. The migration will not attempt to read any objects to build the index. As the last step of :gelcmd:`migration apply` (and :gelcmd:`migrate`), the index will actually be built. During this time, the object type will not be locked for reads or writes. This means that the migration will hold locks for significantly less time and allow the index to be created while new writes are applied to the database.
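A concurrently built index declared in SDL reaches the instance through a migration. As a rough sketch, assuming the generated migration expresses the declaration with the ``set build_concurrently`` subcommand of ``create index`` (the exact DDL in a real migration file may differ), the relevant statement could look like this:

.. code-block:: edgeql

    # Hypothetical migration fragment for the User type shown above.
    alter type User {
      create index on (.name) {
        set build_concurrently := true;
      };
    };

Seeing ``set build_concurrently := true`` while reviewing a migration suggests that the index build is deferred to the last step of :gelcmd:`migration apply` instead of being performed while the object type is locked for writes.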
To apply migrations without building indexes at all, run :gelcmd:`migration apply --no-index-build`. This allows index building to be triggered at a later time by running :gelcmd:`migration apply` again. Until the index is built, it will not be used to speed up queries. For tradeoffs of concurrent index building, refer to the `PostgreSQL documentation `_. Annotate an index ================= Indexes can include annotations: .. code-block:: sdl type User { name: str; index on (.name) { annotation description := 'Indexing all users by name.'; }; } .. _ref_eql_sdl_indexes: Declaring indexes ================= .. api-index:: index on, except This section describes the syntax for declaring indexes in your schema. Syntax ------ .. sdl:synopsis:: index on ( ) [ except ( ) ] [ "{" "}" ] ; .. rubric:: Description - :sdl:synopsis:`on ( )` The expression to index. It must be :ref:`Immutable ` but may refer to the indexed object's properties/links. The expression itself must be parenthesized. - :eql:synopsis:`except ( )` An optional condition. If the ``except`` expression evaluates to ``true``, the object is omitted from the index; if ``false`` or empty, it is included. - :sdl:synopsis:`` Allows setting an index :ref:`annotation ` to a given value. - :sdl:synopsis:`build_concurrently := ` Allows the index to be built :ref:`after the migration is applied ` to the instance. .. _ref_eql_ddl_indexes: DDL commands ============ This section describes the low-level DDL commands for creating, altering, and dropping indexes. You typically don't need to use these commands directly, but knowing about them is useful for reviewing migrations. Create index ------------ :eql-statement: .. eql:synopsis:: create index on ( ) [ except ( ) ] [ "{" ; [...] "}" ] ; # where is one of create annotation := Creates a new index for a given object type or link using *index-expr*. - Most parameters/options match those in :ref:`Declaring indexes `. - Allowed subcommand: :eql:synopsis:`create annotation := ` Assign an annotation to this index.
See :eql:stmt:`create annotation` for details. - :eql:synopsis:`set build_concurrently := ` Allows the index to be built :ref:`after the migration is applied ` to the instance. Example: .. code-block:: edgeql create type User { create property name: str { set default := ''; }; create index on (.name); }; Alter index ----------- :eql-statement: Alter the definition of an index. .. eql:synopsis:: alter index on ( ) [ except ( ) ] [ "{" ; [...] "}" ] ; # where is one of create annotation := alter annotation := drop annotation The command ``alter index`` is used to change the :ref:`annotations ` of an index. The *index-expr* is used to identify the index to be altered. :sdl:synopsis:`on ( )` The specific expression for which the index is made. Note that the expression itself has to be parenthesized. The following subcommands are allowed in the ``alter index`` block: :eql:synopsis:`create annotation := ` Set index :eql:synopsis:`` to :eql:synopsis:``. See :eql:stmt:`create annotation` for details. :eql:synopsis:`alter annotation ;` Alter index :eql:synopsis:``. See :eql:stmt:`alter annotation` for details. :eql:synopsis:`drop annotation ;` Remove index annotation :eql:synopsis:``. See :eql:stmt:`drop annotation` for details. Example: .. code-block:: edgeql alter type User { alter index on (.name) { create annotation title := 'User name index'; }; }; Drop index ---------- :eql-statement: Remove an index from a given schema item. .. eql:synopsis:: drop index on ( ) [ except ( ) ] ; Removes an index from a schema item. - :sdl:synopsis:`on ( )` identifies the indexed expression. This statement can only be used as a subdefinition in another DDL statement. Example: .. code-block:: edgeql alter type User { drop index on (.name); }; .. list-table:: :class: seealso * - **See also** - :ref:`Introspection > Indexes ` ================================================ FILE: docs/reference/datamodel/inheritance.rst ================================================ ..
_ref_datamodel_inheritance: =========== Inheritance =========== .. index:: extending, extends, subtype, supertype, parent type, child type Inheritance is a crucial aspect of schema modeling in Gel. Schema items can *extend* one or more parent types. When extending, the child (subclass) inherits the definition of its parents (superclass). You can declare ``abstract`` object types, properties, links, constraints, and annotations. - :ref:`Objects ` - :ref:`Properties ` - :ref:`Links ` - :ref:`Constraints ` - :ref:`Annotations ` .. _ref_datamodel_inheritance_objects: Object types ------------ .. api-index:: abstract type, extending Object types can *extend* other object types. The extending type (AKA the *subtype*) inherits all links, properties, indexes, constraints, etc. from its *supertypes*. .. code-block:: sdl abstract type Animal { species: str; } type Dog extending Animal { breed: str; } Both abstract and concrete object types can be extended. Whether to make a type abstract or concrete is a fairly simple decision: if you need to be able to insert objects of the type, make it a concrete type. If objects of the type should never be inserted and it exists only to be extended, make it an abstract one. In the schema below the ``Animal`` type is now concrete and can be inserted, which was not the case in the example above. The new ``CanBark`` type however is abstract and thus the database will not have any individual ``CanBark`` objects. .. code-block:: sdl abstract type CanBark { required bark_sound: str; } type Animal { species: str; } type Dog extending Animal, CanBark { breed: str; } For details on querying polymorphic data, see :ref:`EdgeQL > Select > Polymorphic queries `. When using the SQL adapter, see :ref:`SQL adapter ` for information about using ``ONLY`` to query parent tables without including child objects. .. 
_ref_datamodel_inheritance_multiple: Multiple Inheritance ^^^^^^^^^^^^^^^^^^^^ Object types can :ref:`extend more than one type `; this is called *multiple inheritance*. This mechanism allows building complex object types out of combinations of more basic types. .. code-block:: sdl abstract type HasName { first_name: str; last_name: str; } abstract type HasEmail { email: str; } type Person extending HasName, HasEmail { profession: str; } .. _ref_datamodel_overloading: Overloading ^^^^^^^^^^^ .. api-index:: overloaded An object type can overload an inherited property or link. All overloaded declarations must be prefixed with the ``overloaded`` keyword to avoid unintentional overloads. .. code-block:: sdl abstract type Person { name: str; multi friends: Person; } type Student extending Person { overloaded name: str { constraint exclusive; } overloaded multi friends: Student; } An overloaded field cannot *generalize* the associated type; it can only make it *more specific* by setting the type to a subtype of the original or adding additional constraints. .. _ref_datamodel_inheritance_props: Properties ---------- .. api-index:: abstract property, readonly Properties can be *concrete* (the default) or *abstract*. Abstract properties are declared independently of a source or target, can contain :ref:`annotations `, and can be marked as ``readonly``. .. code-block:: sdl abstract property title_prop { annotation title := 'A title.'; readonly := false; } .. _ref_datamodel_inheritance_links: Links ----- .. api-index:: abstract link It's possible to define ``abstract`` links that aren't tied to a particular *source* or *target*. Abstract links can be marked as readonly and contain annotations, property declarations, constraints, and indexes. .. code-block:: sdl abstract link link_with_strength { strength: float64; index on (__subject__@strength); } type Person { multi friends: Person { extending link_with_strength; }; } ..
_ref_datamodel_inheritance_constraints: Constraints ----------- .. api-index:: abstract constraint, using, errmessage Use ``abstract`` to declare reusable, user-defined constraint types. .. code-block:: sdl abstract constraint in_range(min: anyreal, max: anyreal) { errmessage := 'Value must be in range [{min}, {max}].'; using (min <= __subject__ and __subject__ < max); } type Player { points: int64 { constraint in_range(0, 100); } } .. _ref_datamodel_inheritance_annotations: Annotations ----------- .. api-index:: abstract annotation, inheritable EdgeQL supports three annotation types by default: ``title``, ``description``, and ``deprecated``. Use ``abstract annotation`` to declare custom user-defined annotation types. .. code-block:: sdl abstract annotation admin_note; type Status { annotation admin_note := 'system-critical'; # more properties } By default, annotations defined on abstract types, properties, and links will not be inherited by their subtypes. To override this behavior, use the ``inheritable`` modifier. .. code-block:: sdl abstract inheritable annotation admin_note; ================================================ FILE: docs/reference/datamodel/introspection/casts.rst ================================================ .. _ref_datamodel_introspection_casts: ===== Casts ===== This section describes introspection of Gel :eql:op:`type casts `. Features like whether the casts are implicit can be discovered by introspecting ``schema::Cast``. Introspection of the ``schema::Cast``: .. code-block:: edgeql-repl db> with module schema ... select ObjectType { ... name, ... links: { ... name, ... }, ... properties: { ... name, ... } ... } ... 
filter .name = 'schema::Cast'; { Object { name: 'schema::Cast', links: { Object { name: '__type__' }, Object { name: 'from_type' }, Object { name: 'to_type' } }, properties: { Object { name: 'allow_assignment' }, Object { name: 'allow_implicit' }, Object { name: 'id' }, Object { name: 'name' } } } } Introspection of the possible casts from ``std::int64`` to other types: .. code-block:: edgeql-repl db> with module schema ... select Cast { ... allow_assignment, ... allow_implicit, ... to_type: { name }, ... } ... filter .from_type.name = 'std::int64' ... order by .to_type.name; { Object { allow_assignment: false, allow_implicit: true, to_type: Object { name: 'std::bigint' } }, Object { allow_assignment: false, allow_implicit: true, to_type: Object { name: 'std::decimal' } }, Object { allow_assignment: true, allow_implicit: false, to_type: Object { name: 'std::float32' } }, Object { allow_assignment: false, allow_implicit: true, to_type: Object { name: 'std::float64' } }, Object { allow_assignment: true, allow_implicit: false, to_type: Object { name: 'std::int16' } }, Object { allow_assignment: true, allow_implicit: false, to_type: Object { name: 'std::int32' } }, Object { allow_assignment: false, allow_implicit: false, to_type: Object { name: 'std::json' } }, Object { allow_assignment: false, allow_implicit: false, to_type: Object { name: 'std::str' } } } The ``allow_implicit`` property tells whether this is an *implicit cast* in all contexts (such as when determining the type of a set of mixed literals or resolving the argument types of functions or operators if there's no exact match). For example, a literal ``1`` is an :eql:type:`int64` and it is implicitly cast into a :eql:type:`bigint` or :eql:type:`float64` if it is added to a set containing either one of those types: .. code-block:: edgeql-repl db> select {1, 2n}; {1n, 2n} db> select {1, 2.0}; {1.0, 2.0} What happens if there's no implicit cast between a couple of scalars in this type of example? 
Gel checks whether there is a scalar type into which all of the set elements can be implicitly cast: .. code-block:: edgeql-repl db> select introspect (typeof {<int64>1, <float32>2}).name; {'std::float64'} The scalar types :eql:type:`int64` and :eql:type:`float32` cannot be implicitly cast into each other, but they both can be implicitly cast into :eql:type:`float64`. The ``allow_assignment`` property indicates whether the cast is applied implicitly during assignment even when a more general *implicit cast* is not allowed. For example, consider the following type: .. code-block:: sdl type Example { property p_int16: int16; property p_float32: float32; property p_json: json; } .. code-block:: edgeql-repl db> insert Example { ... p_int16 := 1, ... p_float32 := 2 ... }; {Object { id: '...' }} db> insert Example { ... p_json := 3 # assignment cast to json not allowed ... }; InvalidPropertyTargetError: invalid target for property 'p_json' of object type 'default::Example': 'std::int64' (expecting 'std::json') ================================================ FILE: docs/reference/datamodel/introspection/colltypes.rst ================================================ .. _ref_datamodel_introspection_collection_types: ================ Collection types ================ This section describes introspection of :ref:`collection types `. Array ----- Introspection of the ``schema::Array``: .. code-block:: edgeql-repl db> with module schema ... select ObjectType { ... name, ... links: { ... name, ... }, ... properties: { ... name, ... } ... } ... filter .name = 'schema::Array'; { Object { name: 'schema::Array', links: { Object { name: '__type__' }, Object { name: 'element_type' } }, properties: { Object { name: 'id' }, Object { name: 'name' } } } } For a type with an :eql:type:`array` property, consider the following: .. code-block:: sdl type User { required property name: str; property favorites: array<str>; } Introspection of the ``User`` with emphasis on properties: ..
code-block:: edgeql-repl db> with module schema ... select ObjectType { ... name, ... properties: { ... name, ... target: { ... name, ... [is Array].element_type: { name }, ... }, ... }, ... } ... filter .name = 'default::User'; { Object { name: 'default::User', properties: { Object { name: 'favorites', target: Object { name: 'array', element_type: Object { name: 'std::str' } } }, ... } } } Tuple ----- Introspection of the ``schema::Tuple``: .. code-block:: edgeql-repl db> with module schema ... select ObjectType { ... name, ... links: { ... name, ... }, ... properties: { ... name, ... } ... } ... filter .name = 'schema::Tuple'; { Object { name: 'schema::Tuple', links: { Object { name: '__type__' }, Object { name: 'element_types' } }, properties: { Object { name: 'id' }, Object { name: 'name' } } } } For example, below is an introspection of the return type of the :eql:func:`sys::get_version` function: .. code-block:: edgeql-repl db> with module schema ... select `Function` { ... return_type[is Tuple]: { ... element_types: { ... name, ... type: { name } ... } order by .num ... } ... } ... filter .name = 'sys::get_version'; { Object { return_type: Object { element_types: { Object { name: 'major', type: Object { name: 'std::int64' } }, Object { name: 'minor', type: Object { name: 'std::int64' } }, Object { name: 'stage', type: Object { name: 'sys::VersionStage' } }, Object { name: 'stage_no', type: Object { name: 'std::int64' } }, Object { name: 'local', type: Object { name: 'array' } } } } } } ================================================ FILE: docs/reference/datamodel/introspection/constraints.rst ================================================ .. _ref_datamodel_introspection_constraints: =========== Constraints =========== This section describes introspection of :ref:`constraints `. Introspection of the ``schema::Constraint``: .. code-block:: edgeql-repl db> with module schema ... select ObjectType { ... name, ... links: { ... name, ... }, ... 
properties: { ... name, ... } ... } ... filter .name = 'schema::Constraint'; { Object { name: 'schema::Constraint', links: { Object { name: '__type__' }, Object { name: 'args' }, Object { name: 'annotations' }, Object { name: 'bases' }, Object { name: 'ancestors' }, Object { name: 'params' }, Object { name: 'return_type' }, Object { name: 'subject' } }, properties: { Object { name: 'errmessage' }, Object { name: 'expr' }, Object { name: 'finalexpr' }, Object { name: 'id' }, Object { name: 'abstract' }, Object { name: 'name' }, Object { name: 'return_typemod' }, Object { name: 'subjectexpr' } } } } Consider the following schema: .. code-block:: sdl scalar type maxex_100 extending int64 { constraint max_ex_value(100); } Introspection of the scalar ``maxex_100`` with focus on the constraint: .. code-block:: edgeql-repl db> with module schema ... select ScalarType { ... name, ... constraints: { ... name, ... expr, ... annotations: { name, @value }, ... subject: { name }, ... params: { name, @value, type: { name } }, ... return_typemod, ... return_type: { name }, ... errmessage, ... }, ... } ... filter .name = 'default::maxex_100'; { Object { name: 'default::maxex_100', constraints: { Object { name: 'std::max_ex_value', expr: '(__subject__ <= max)', annotations: {}, subject: Object { name: 'default::maxex_100' }, params: { Object { name: 'max', type: Object { name: 'anytype' }, @value: '100' } }, return_typemod: 'SingletonType', return_type: Object { name: 'std::bool' } errmessage: '{__subject__} must be less ...', } } } } .. list-table:: :class: seealso * - **See also** * - :ref:`Schema > Constraints ` * - :ref:`SDL > Constraints ` * - :ref:`DDL > Constraints ` * - :ref:`Standard Library > Constraints ` ================================================ FILE: docs/reference/datamodel/introspection/functions.rst ================================================ .. 
_ref_datamodel_introspection_functions: ========= Functions ========= This section describes introspection of :ref:`functions `. Introspection of the ``schema::Function``: .. code-block:: edgeql-repl db> with module schema ... select ObjectType { ... name, ... links: { ... name, ... }, ... properties: { ... name, ... } ... } ... filter .name = 'schema::Function'; { Object { name: 'schema::Function', links: { Object { name: '__type__' }, Object { name: 'annotations' }, Object { name: 'params' }, Object { name: 'return_type' } }, properties: { Object { name: 'id' }, Object { name: 'name' }, Object { name: 'return_typemod' } } } } Since ``params`` are quite important to functions, here's their structure: .. code-block:: edgeql-repl db> with module schema ... select ObjectType { ... name, ... links: { ... name, ... }, ... properties: { ... name, ... } ... } ... filter .name = 'schema::Parameter'; { Object { name: 'schema::Parameter', links: { Object { name: '__type__' }, Object { name: 'type' } }, properties: { Object { name: 'default' }, Object { name: 'id' }, Object { name: 'kind' }, Object { name: 'name' }, Object { name: 'num' }, Object { name: 'typemod' } } } } Introspection of the built-in :eql:func:`count`: .. code-block:: edgeql-repl db> with module schema ... select `Function` { ... name, ... annotations: { name, @value }, ... params: { ... kind, ... name, ... num, ... typemod, ... type: { name }, ... default, ... }, ... return_typemod, ... return_type: { name }, ... } ... filter .name = 'std::count'; { Object { name: 'std::count', annotations: {}, params: { Object { kind: 'PositionalParam', name: 's', num: 0, typemod: 'SetOfType', type: Object { name: 'anytype' }, default: {} } }, return_typemod: 'SingletonType', return_type: Object { name: 'std::int64' } } } .. 
list-table:: :class: seealso * - **See also** * - :ref:`Schema > Functions ` * - :ref:`SDL > Functions ` * - :ref:`DDL > Functions ` * - :ref:`Reference > Function calls ` * - :ref:`Cheatsheets > Functions ` ================================================ FILE: docs/reference/datamodel/introspection/index.rst ================================================ .. _ref_datamodel_introspection: Introspection ============= .. index:: schema module .. api-index:: describe, introspect, typeof All of the schema information in Gel is stored in the ``schema`` :ref:`module ` and is accessible via *introspection queries*. All introspection types themselves extend :eql:type:`BaseObject`, so they are also subject to introspection :ref:`as object types `. The following query will give a list of all the types used in introspection: .. code-block:: edgeql select name := schema::ObjectType.name filter name like 'schema::%'; There are also a couple of ways of getting the introspection type of a particular expression. Any :eql:type:`Object` has a ``__type__`` link to the ``schema::ObjectType``. For scalars there are the :eql:op:`introspect` and :eql:op:`typeof` operators that can be used to get the type of an expression. Finally, the command :eql:stmt:`describe` can be used to get information about Gel types in a variety of human-readable formats. .. toctree:: :maxdepth: 3 :hidden: objects scalars colltypes functions triggers mutation_rewrites indexes constraints operators casts ================================================ FILE: docs/reference/datamodel/introspection/indexes.rst ================================================ .. _ref_datamodel_introspection_indexes: ======= Indexes ======= This section describes introspection of :ref:`indexes `. Introspection of the ``schema::Index``: .. code-block:: edgeql-repl db> with module schema ... select ObjectType { ... name, ... links: { ... name, ... }, ... properties: { ... name, ... } ... } ... 
filter .name = 'schema::Index'; { Object { name: 'schema::Index', links: {Object { name: '__type__' }}, properties: { Object { name: 'expr' }, Object { name: 'id' }, Object { name: 'name' } } } } Consider the following schema: .. code-block:: sdl abstract type Addressable { property address: str; } type User extending Addressable { # define some properties and a link required property name: str; multi link friends: User; # define an index for User based on name index on (.name); } Introspection of the ``User.name`` index: .. code-block:: edgeql-repl db> with module schema ... select Index { ... expr, ... } ... filter .expr like '%.name'; { Object { expr: '.name' } } For introspection of the index within the context of its host type, see :ref:`object type introspection `. .. list-table:: :class: seealso * - **See also** * - :ref:`Schema > Indexes ` * - :ref:`SDL > Indexes ` * - :ref:`DDL > Indexes ` ================================================ FILE: docs/reference/datamodel/introspection/mutation_rewrites.rst ================================================ .. _ref_datamodel_introspection_mutation_rewrites: ================= Mutation rewrites ================= This section describes introspection of :ref:`mutation rewrites `. Introspection of the ``schema::Rewrite``: .. code-block:: edgeql-repl db> select schema::ObjectType { ... name, ... links: { ... name ... }, ... properties: { ... name ... } ... 
} filter .name = 'schema::Rewrite'; { schema::ObjectType { name: 'schema::Rewrite', links: { schema::Link {name: 'subject'}, schema::Link {name: '__type__'}, schema::Link {name: 'ancestors'}, schema::Link {name: 'bases'}, schema::Link {name: 'annotations'} }, properties: { schema::Property {name: 'inherited_fields'}, schema::Property {name: 'computed_fields'}, schema::Property {name: 'builtin'}, schema::Property {name: 'internal'}, schema::Property {name: 'name'}, schema::Property {name: 'id'}, schema::Property {name: 'abstract'}, schema::Property {name: 'is_abstract'}, schema::Property {name: 'final'}, schema::Property {name: 'is_final'}, schema::Property {name: 'kind'}, schema::Property {name: 'expr'}, }, }, } Introspection of all properties in the ``default`` schema with a mutation rewrite: .. code-block:: edgeql-repl db> select schema::ObjectType { ... name, ... properties := ( ... select .properties { ... name, ... rewrites: { ... kind ... } ... } filter exists .rewrites ... ) ... } filter .name ilike 'default::%' ... and exists .properties.rewrites; { schema::ObjectType { name: 'default::Post', properties: { schema::Property { name: 'created', rewrites: { schema::Rewrite { kind: Insert } } }, schema::Property { name: 'modified', rewrites: { schema::Rewrite { kind: Insert }, schema::Rewrite { kind: Update } } }, }, }, } Introspection of all rewrites, including the type of query (``kind``), rewrite expression, and the object and property they are on: .. code-block:: edgeql-repl db> select schema::Rewrite { ... subject := ( ... select .subject { ... name, ... source: { ... name ... } ... } ... ), ... kind, ... expr ... 
}; { schema::Rewrite { subject: schema::Property { name: 'created', source: schema::ObjectType { name: 'default::Post' } }, kind: Insert, expr: 'std::datetime_of_statement()' }, schema::Rewrite { subject: schema::Property { name: 'modified', source: schema::ObjectType { name: 'default::Post' } }, kind: Insert, expr: 'std::datetime_of_statement()' }, schema::Rewrite { subject: schema::Property { name: 'modified', source: schema::ObjectType { name: 'default::Post' } }, kind: Update, expr: 'std::datetime_of_statement()' }, } Introspection of all rewrites on a ``default::Post`` property named ``modified``: .. code-block:: edgeql-repl db> select schema::Rewrite {kind, expr} ... filter .subject.source.name = 'default::Post' ... and .subject.name = 'modified'; { schema::Rewrite { kind: Insert, expr: 'std::datetime_of_statement()' }, schema::Rewrite { kind: Update, expr: 'std::datetime_of_statement()' } } .. list-table:: :class: seealso * - **See also** * - :ref:`Schema > Mutation rewrites ` * - :ref:`SDL > Mutation rewrites ` * - :ref:`DDL > Mutation rewrites ` ================================================ FILE: docs/reference/datamodel/introspection/objects.rst ================================================ .. _ref_datamodel_introspection_object_types: ============ Object types ============ This section describes introspection of :ref:`object types `. Introspection of the ``schema::ObjectType``: .. code-block:: edgeql-repl db> with module schema ... select ObjectType { ... name, ... links: { ... name, ... }, ... properties: { ... name, ... } ... } ... 
filter .name = 'schema::ObjectType'; { Object { name: 'schema::ObjectType', links: { Object { name: '__type__' }, Object { name: 'annotations' }, Object { name: 'bases' }, Object { name: 'constraints' }, Object { name: 'indexes' }, Object { name: 'links' }, Object { name: 'ancestors' }, Object { name: 'pointers' }, Object { name: 'properties' } }, properties: { Object { name: 'id' }, Object { name: 'abstract' }, Object { name: 'name' } } } } Consider the following schema: .. code-block:: sdl abstract type Addressable { address: str; } type User extending Addressable { # define some properties and a link required name: str; multi friends: User; # define an index for User based on name index on (.name); } Introspection of ``User``: .. code-block:: edgeql-repl db> with module schema ... select ObjectType { ... name, ... abstract, ... bases: { name }, ... ancestors: { name }, ... annotations: { name, @value }, ... links: { ... name, ... cardinality, ... required, ... target: { name }, ... }, ... properties: { ... name, ... cardinality, ... required, ... target: { name }, ... }, ... constraints: { name }, ... indexes: { expr }, ... } ... 
filter .name = 'default::User'; { Object { name: 'default::User', abstract: false, bases: {Object { name: 'default::Addressable' }}, ancestors: { Object { name: 'std::BaseObject' }, Object { name: 'std::Object' }, Object { name: 'default::Addressable' } }, annotations: {}, links: { Object { name: '__type__', cardinality: 'One', required: {}, target: Object { name: 'schema::Type' } }, Object { name: 'friends', cardinality: 'Many', required: false, target: Object { name: 'default::User' } } }, properties: { Object { name: 'address', cardinality: 'One', required: false, target: Object { name: 'std::str' } }, Object { name: 'id', cardinality: 'One', required: true, target: Object { name: 'std::uuid' } }, Object { name: 'name', cardinality: 'One', required: true, target: Object { name: 'std::str' } } }, constraints: {}, indexes: { Object { expr: '.name' } } } } .. list-table:: :class: seealso * - **See also** * - :ref:`Schema > Object types ` * - :ref:`SDL > Object types ` * - :ref:`DDL > Object types ` * - :ref:`Cheatsheets > Object types ` ================================================ FILE: docs/reference/datamodel/introspection/operators.rst ================================================ .. _ref_datamodel_introspection_operators: ========= Operators ========= This section describes introspection of Gel operators. Much like functions, operators have parameters and return types as well as a few other features. Introspection of the ``schema::Operator``: .. code-block:: edgeql-repl db> with module schema ... select ObjectType { ... name, ... links: { ... name, ... }, ... properties: { ... name, ... } ... } ... 
filter .name = 'schema::Operator'; { Object { name: 'schema::Operator', links: { Object { name: '__type__' }, Object { name: 'annotations' }, Object { name: 'params' }, Object { name: 'return_type' } }, properties: { Object { name: 'id' }, Object { name: 'name' }, Object { name: 'operator_kind' }, Object { name: 'return_typemod' } } } } Since ``params`` are quite important to operators, here's their structure: .. code-block:: edgeql-repl db> with module schema ... select ObjectType { ... name, ... links: { ... name, ... }, ... properties: { ... name, ... } ... } ... filter .name = 'schema::Parameter'; { Object { name: 'schema::Parameter', links: { Object { name: '__type__' }, Object { name: 'type' } }, properties: { Object { name: 'default' }, Object { name: 'id' }, Object { name: 'kind' }, Object { name: 'name' }, Object { name: 'num' }, Object { name: 'typemod' } } } } Introspection of the :eql:op:`and` operator: .. code-block:: edgeql-repl db> with module schema ... select Operator { ... name, ... operator_kind, ... annotations: { name, @value }, ... params: { ... kind, ... name, ... num, ... typemod, ... type: { name }, ... default, ... }, ... return_typemod, ... return_type: { name }, ... } ... filter .name = 'std::AND'; { Object { name: 'std::AND', operator_kind: 'Infix', annotations: {}, params: { Object { kind: 'PositionalParam', name: 'a', num: 0, typemod: 'SingletonType', type: Object { name: 'std::bool' }, default: {} }, Object { kind: 'PositionalParam', name: 'b', num: 1, typemod: 'SingletonType', type: Object { name: 'std::bool' }, default: {} } }, return_typemod: 'SingletonType', return_type: Object { name: 'std::bool' } } } ================================================ FILE: docs/reference/datamodel/introspection/scalars.rst ================================================ .. _ref_datamodel_introspection_scalar_types: ============ Scalar types ============ This section describes introspection of :ref:`scalar types `. 
Introspection of the ``schema::ScalarType``: .. code-block:: edgeql-repl db> with module schema ... select ObjectType { ... name, ... links: { ... name, ... }, ... properties: { ... name, ... } ... } ... filter .name = 'schema::ScalarType'; { Object { name: 'schema::ScalarType', links: { Object { name: '__type__' }, Object { name: 'annotations' }, Object { name: 'bases' }, Object { name: 'constraints' }, Object { name: 'ancestors' } }, properties: { Object { name: 'default' }, Object { name: 'enum_values' }, Object { name: 'id' }, Object { name: 'abstract' }, Object { name: 'name' } } } } Introspection of the built-in scalar :eql:type:`str`: .. code-block:: edgeql-repl db> with module schema ... select ScalarType { ... name, ... default, ... enum_values, ... abstract, ... bases: { name }, ... ancestors: { name }, ... annotations: { name, @value }, ... constraints: { name }, ... } ... filter .name = 'std::str'; { Object { name: 'std::str', default: {}, enum_values: {}, abstract: {}, bases: {Object { name: 'std::anyscalar' }}, ancestors: {Object { name: 'std::anyscalar' }}, annotations: {}, constraints: {} } } For an :ref:`enumerated scalar type `, consider the following: .. code-block:: sdl scalar type Color extending enum; Introspection of the enum scalar ``Color``: .. code-block:: edgeql-repl db> with module schema ... select ScalarType { ... name, ... default, ... enum_values, ... abstract, ... bases: { name }, ... ancestors: { name }, ... annotations: { name, @value }, ... constraints: { name }, ... } ... filter .name = 'default::Color'; { Object { name: 'default::Color', default: {}, enum_values: ['Red', 'Green', 'Blue'], abstract: {}, bases: {Object { name: 'std::anyenum' }}, ancestors: { Object { name: 'std::anyscalar' }, Object { name: 'std::anyenum' } }, annotations: {}, constraints: {} } } ================================================ FILE: docs/reference/datamodel/introspection/triggers.rst ================================================ .. 
_ref_datamodel_introspection_triggers: ========= Triggers ========= This section describes introspection of :ref:`triggers `. Introspection of ``schema::Trigger``: .. code-block:: edgeql-repl db> with module schema ... select ObjectType { ... name, ... links: { ... name, ... }, ... properties: { ... name, ... } ... } filter .name = 'schema::Trigger'; { schema::ObjectType { name: 'schema::Trigger', links: { schema::Link {name: 'subject'}, schema::Link {name: '__type__'}, schema::Link {name: 'ancestors'}, schema::Link {name: 'bases'}, schema::Link {name: 'annotations'} }, properties: { schema::Property {name: 'inherited_fields'}, schema::Property {name: 'computed_fields'}, schema::Property {name: 'builtin'}, schema::Property {name: 'internal'}, schema::Property {name: 'name'}, schema::Property {name: 'id'}, schema::Property {name: 'abstract'}, schema::Property {name: 'is_abstract'}, schema::Property {name: 'final'}, schema::Property {name: 'is_final'}, schema::Property {name: 'timing'}, schema::Property {name: 'kinds'}, schema::Property {name: 'scope'}, schema::Property {name: 'expr'}, }, }, } Introspection of a trigger named ``log_insert`` on the ``User`` type: .. lint-off .. code-block:: edgeql-repl db> with module schema ... select Trigger { ... name, ... kinds, ... timing, ... scope, ... expr, ... subject: { ... name ... } ... } filter .name = 'log_insert'; { schema::Trigger { name: 'log_insert', kinds: {Insert}, timing: After, scope: Each, expr: 'insert default::Log { action := \'insert\', target_name := __new__.name }', subject: schema::ObjectType {name: 'default::User'}, }, } .. lint-on .. list-table:: :class: seealso * - **See also** * - :ref:`Schema > Triggers ` * - :ref:`SDL > Triggers ` * - :ref:`DDL > Triggers ` ================================================ FILE: docs/reference/datamodel/linkprops.rst ================================================ .. _ref_datamodel_linkprops: =============== Link properties =============== .. 
index:: link property, linkprops, link table, relations .. api-index:: @ Links, like objects, can also contain **properties**. These are used to store metadata about the link. Due to how they're persisted under the hood, link properties have the additional constraint of always being ``single``. Link properties require non-trivial syntax, so they are considered an advanced feature. In many cases, regular properties should be used instead. To paraphrase a famous quote: "Link properties are like a parachute, you don't need them very often, but when you do, they can be clutch." .. note:: In practice, link properties are best used with many-to-many relationships (``multi`` links without any exclusive constraints). For one-to-one, one-to-many, and many-to-one relationships the same data should be stored in object properties instead. .. versionchanged:: 7.0 Link properties can now be made required. Declaration =========== Let's create a ``Person.friends`` link with a ``strength`` property corresponding to the strength of the friendship. .. code-block:: sdl type Person { required name: str { constraint exclusive }; multi friends: Person { strength: float64; } } Constraints =========== Now let's ensure that the ``@strength`` property is always non-negative: .. code-block:: sdl type Person { required name: str { constraint exclusive }; multi friends: Person { strength: float64; constraint expression on ( __subject__@strength >= 0 ); } } Indexes ======= To add an index on a link property, we have to refactor our code and define an abstract link ``friendship`` that will contain the ``strength`` property with an index on it: .. 
code-block:: sdl abstract link friendship { required strength: float64; index on (__subject__@strength); } type Person { required name: str { constraint exclusive }; multi friends: Person { extending friendship; }; } Conceptualizing link properties =============================== A way to conceptualize the difference between a regular property and a link property is that regular properties are used to construct an object, while link properties are used to construct the link between objects. For example, here the ``name`` and ``email`` properties are used to construct a ``Person`` object: .. code-block:: edgeql insert Person { name := "Jane", email := "jane@jane.com" } Now let's insert a ``Person`` object linking it to another ``Person`` object setting the ``@strength`` property to the link between them: .. code-block:: edgeql insert Person { name := "Bob", email := "bob@bob.com", friends := ( insert Person { name := "Jane", email := "jane@jane.com", @strength := 3.14 } ) } So we're not using ``@strength`` to construct a particular ``Person`` object, but to quantify a link between two ``Person`` objects. Inserting ========= What if we want to insert a ``Person`` object while linking it to another ``Person`` that's already in the database? The ``@strength`` property then will be specified in the *shape* of a ``select`` subquery: .. code-block:: edgeql insert Person { name := "Bob", friends := ( select detached Person { @strength := 3.14 } filter .name = "Alice" ) } .. note:: We are using the :eql:op:`detached` operator to unbind the ``Person`` reference from the scope of the ``insert`` query. When doing a nested insert, link properties can be directly included in the inner ``insert`` subquery: .. 
code-block:: edgeql insert Person { name := "Bob", friends := ( insert Person { name := "Jane", @strength := 3.14 } ) } Similarly, ``with`` can be used to capture an expression returning an object type, after which a link property can be added when linking it to another object type: .. code-block:: edgeql with alice := ( insert Person { name := "Alice" } unless conflict on .name else ( select Person filter .name = "Alice" limit 1 ) ) insert Person { name := "Bob", friends := alice { @strength := 3.14 } }; Updating ======== .. code-block:: edgeql update Person filter .name = "Bob" set { friends += ( select .friends { @strength := 3.7 } filter .name = "Alice" ) }; The example updates the ``@strength`` property of Bob's ``friends`` link to Alice to 3.7. In the context of multi links, the ``+=`` operator works like an insert/update operator. To update one or more links in a multi link, you can select from the currently linked objects, as the example does. Use a ``detached`` selection if you want to insert/update a wider selection of linked objects instead. Selecting ========= To select a link property, prefix its name with ``@`` inside the select *shape*. Keep in mind that you're not selecting a property on an object with this syntax, but rather on the link, in this case ``friends``: .. code-block:: edgeql-repl gel> select Person { .... name, .... friends: { .... name, .... @strength .... } .... }; { default::Person {name: 'Alice', friends: {}}, default::Person { name: 'Bob', friends: { default::Person {name: 'Alice', @strength: 3.7} } }, } Unions ====== A link property cannot be referenced in a set union *except* in the case of a :ref:`for loop `. That means this will *not* work: .. code-block:: edgeql # 🚫 Does not work insert Movie { title := 'The Incredible Hulk', actors := {( select Person { @character_name := 'The Hulk' } filter .name = 'Mark Ruffalo' ), ( select Person { @character_name := 'Iron Man' } filter .name = 'Robert Downey Jr.' 
)} }; That query will produce an error: ``QueryError: invalid reference to link property in top level shape`` You can use this workaround instead: .. code-block:: edgeql # ✅ Works! insert Movie { title := 'The Incredible Hulk', actors := assert_distinct(( with characters := { ('The Hulk', 'Mark Ruffalo'), ('Iron Man', 'Robert Downey Jr.') } for character in characters union ( select Person { @character_name := character.0 } filter .name = character.1 ) )) }; Note that we are also required to wrap the ``actors`` query with :eql:func:`assert_distinct` here to assure the compiler that the result set is distinct. With computed backlinks ======================= Specifying link properties of a computed backlink in your shape is also supported. If you have this schema: .. code-block:: sdl type Person { required name: str; multi follows: Person { followed: datetime { default := datetime_of_statement(); }; }; multi link followers := .` * - :ref:`Properties in schema ` * - :ref:`Properties with DDL ` ================================================ FILE: docs/reference/datamodel/links.rst ================================================ .. _ref_datamodel_links: ===== Links ===== Links define a relationship between two :ref:`object types ` in Gel. Links in |Gel| are incredibly powerful and flexible. They can be used to model relationships of any cardinality, can be traversed in both directions, can be polymorphic, can have constraints, and many other things. Links are directional ===================== Links are *directional*: they have a **source** (the type on which they are declared) and a **target** (the type they point to). E.g. the following schema defines a link from ``Person`` to ``Person`` and a link from ``Company`` to ``Person``: .. code-block:: sdl type Person { link best_friend: Person; } type Company { multi link employees: Person; } The ``employees`` link's source is ``Company`` and its target is ``Person``. The ``link`` keyword is optional, and can be omitted. 
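As a minimal sketch of that last point, the same schema can be written without the ``link`` keyword, in the shorter style used by the other examples on this page:

.. code-block:: sdl

  type Person {
    # same as: link best_friend: Person;
    best_friend: Person;
  }

  type Company {
    # same as: multi link employees: Person;
    multi employees: Person;
  }
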
Link cardinality ================ .. api-index:: single, multi All links have a cardinality: either ``single`` or ``multi``. The default is ``single`` (a "to-one" link). Use the ``multi`` keyword to declare a "to-many" link: .. code-block:: sdl type Person { multi friends: Person; } Required links ============== .. index:: not null .. api-index:: required, optional All links are either ``optional`` or ``required``; the default is ``optional``. Use the ``required`` keyword to declare a required link. A required link must point to *at least one* target instance, and if the cardinality of the required link is ``single``, it must point to *exactly one* target instance. In this scenario, every ``Person`` must have *exactly one* ``best_friend``: .. code-block:: sdl type Person { required best_friend: Person; } Links with cardinality ``multi`` can also be ``required``; ``required multi`` links must point to *at least one* target object: .. code-block:: sdl type Person { name: str; } type GroupChat { required multi members: Person; } Attempting to create a ``GroupChat`` with no members would fail. Exclusive constraints ===================== .. api-index:: constraint exclusive You can add an ``exclusive`` constraint to a link to guarantee that no other instances can link to the same target(s): .. code-block:: sdl type Person { name: str; } type GroupChat { required multi members: Person { constraint exclusive; } } With ``exclusive`` on ``GroupChat.members``, two ``GroupChat`` objects cannot link to the same ``Person``; put differently, no ``Person`` can be a ``member`` of multiple ``GroupChat`` objects. Backlinks ========= .. api-index:: .< In Gel you can traverse links in reverse to find objects that link to the object. You can do that directly in your query. E.g. for this example schema: .. 
code-block:: sdl type Author { name: str; } type Article { title: str; multi authors: Author; } You can find all articles by "John Doe" by traversing the ``authors`` link in reverse: .. code-block:: edgeql select Author { articles := .` to learn more. .. _ref_guide_one_to_many: One-to-many ----------- Conceptually, one-to-many and many-to-one relationships are identical; the "directionality" is a matter of perspective. Here, the same "shirt owner" relationship is represented with a ``multi`` link: .. code-block:: sdl type Person { required name: str; multi shirts: Shirt { # ensures a one-to-many relationship constraint exclusive; } } type Shirt { required color: str; } .. note:: Don't forget the ``exclusive`` constraint! Without it, the relationship becomes many-to-many. Under the hood, a ``multi`` link is stored in an intermediate `association table `_, whereas a ``single`` link is stored as a column in the object type where it is declared. .. note:: Choosing a link direction can be tricky. Should you model this relationship as one-to-many (with a ``multi`` link) or as many-to-one (with a ``single`` link and a backlink)? A general rule of thumb: - Use a ``multi`` link if the relationship is relatively stable and not updated frequently, and the set of related objects is typically small. For example, a list of postal addresses in a user profile. - Otherwise, prefer a single link from one object type and a computed backlink on the other. This can be more efficient and is generally recommended for 1:N relations: .. code-block:: sdl type Post { required author: User; } type User { multi posts := (.; }; } type Movie { required title: str; } Alternatively, you might introduce a dedicated type: .. code-block:: sdl type User { required name: str; multi watch_history := .`. Inserting and updating link properties -------------------------------------- To add a link with a link property, include the property name (prefixed by ``@``) in the shape: .. 
code-block:: edgeql insert Person { name := "Bob", family_members := ( select detached Person { @relationship := "sister" } filter .name = "Alice" ) }; Updating a link's property on an **existing** link is similar. You can select the link from within the object being updated: .. code-block:: edgeql update Person filter .name = "Bob" set { family_members := ( select .family_members { @relationship := "step-sister" } filter .name = "Alice" ) }; .. warning:: A link property cannot be referenced in a set union *except* in the case of a :ref:`for loop `. For instance: .. code-block:: edgeql # 🚫 Does not work insert Movie { title := 'The Incredible Hulk', characters := { ( select Person { @character_name := 'The Hulk' } filter .name = 'Mark Ruffalo' ), ( select Person { @character_name := 'Abomination' } filter .name = 'Tim Roth' ) } }; will produce an error ``QueryError: invalid reference to link property in top level shape``. One workaround is to insert them via a ``for`` loop, combined with :eql:func:`assert_distinct`: .. code-block:: edgeql # ✅ Works! insert Movie { title := 'The Incredible Hulk', characters := assert_distinct(( with actors := { ('The Hulk', 'Mark Ruffalo'), ('Abomination', 'Tim Roth') }, for actor in actors union ( select Person { @character_name := actor.0 } filter .name = actor.1 ) )) }; Querying link properties ------------------------ To query a link property, add the link property's name (prefixed with ``@``) in the shape: .. code-block:: edgeql-repl db> select Person { ... name, ... family_members: { ... name, ... @relationship ... } ... }; .. note:: In the results above, Bob has a *step-sister* property on the link to Alice, but Alice does not automatically have a property describing Bob. Changes to link properties are not mirrored on the "backlink" side unless explicitly updated, because link properties cannot be required. .. 
note:: For a full guide on modeling, inserting, updating, and querying link properties, see the :ref:`Using Link Properties ` guide. .. _ref_datamodel_link_deletion: Deletion policies ================= .. api-index:: on target delete, on source delete, restrict, delete source, allow, deferred restrict, delete target, if orphan Links can declare their own **deletion policy** for when the **target** or **source** is deleted. Target deletion --------------- The clause ``on target delete`` determines the action when the target object is deleted: - ``restrict`` (default) — raises an exception if the target is deleted. - ``delete source`` — deletes the source when the target is deleted (a cascade). - ``allow`` — removes the target from the link if the target is deleted. - ``deferred restrict`` — like ``restrict`` but defers the error until the end of the transaction if the object remains linked. .. code-block:: sdl type MessageThread { title: str; } type Message { content: str; chat: MessageThread { on target delete delete source; } } .. _ref_datamodel_links_source_deletion: Source deletion --------------- The clause ``on source delete`` determines the action when the **source** is deleted: - ``allow`` — deletes the source, removing the link to the target. - ``delete target`` — unconditionally deletes the target as well. - ``delete target if orphan`` — deletes the target if and only if it's no longer linked by any other object *via the same link*. .. code-block:: sdl type MessageThread { title: str; multi messages: Message { on source delete delete target; } } type Message { content: str; } You can add ``if orphan`` if you'd like to avoid deleting a target that remains linked elsewhere via the **same** link name. .. code-block:: sdl-diff type MessageThread { title: str; multi messages: Message { - on source delete delete target; + on source delete delete target if orphan; } } .. 
note:: The ``if orphan`` qualifier **does not** apply globally across all links in the database or even all links from the same type. If another link *by a different name* or *with a different on-target-delete* policy points at the same object, it *doesn't* prevent the object from being considered "orphaned" for the link that includes ``if orphan``. .. _ref_datamodel_link_polymorphic: Polymorphic links ================= Links can be **polymorphic**, i.e., have an ``abstract`` target. In the example below, we have an abstract type ``Person`` with concrete subtypes ``Hero`` and ``Villain``: .. code-block:: sdl abstract type Person { name: str; } type Hero extending Person { # additional fields } type Villain extending Person { # additional fields } A polymorphic link can target any non-abstract subtype: .. code-block:: sdl type Movie { title: str; multi characters: Person; } When querying a polymorphic link, you can filter by a specific subtype, cast the link to a subtype, etc. See :ref:`Polymorphic Queries ` for details. Abstract links ============== .. api-index:: abstract link It's possible to define ``abstract`` links that aren't tied to a particular source or target, and then extend them in concrete object types. This can help eliminate repetitive declarations: .. code-block:: sdl abstract link link_with_strength { strength: float64; index on (__subject__@strength); } type Person { multi friends: Person { extending link_with_strength; }; } .. _ref_eql_sdl_links_overloading: Overloading =========== .. api-index:: overloaded When an inherited link is modified (by adding more constraints or changing its target type, etc.), the ``overloaded`` keyword is required. This prevents unintentional overloading due to name clashes: .. code-block:: sdl abstract type Friendly { # this type can have "friends" multi friends: Friendly; } type User extending Friendly { # overload the link target to be specifically User overloaded multi friends: User; # ... 
other links and properties } .. _ref_eql_sdl_links: .. _ref_eql_sdl_links_syntax: Declaring links =============== This section describes the syntax for declaring links in your schema. Syntax ------ .. sdl:synopsis:: # Concrete link form used inside type declaration: [ overloaded ] [{required | optional}] [{single | multi}] [ link ] : [ "{" [ extending [, ...] ; ] [ default := ; ] [ readonly := {true | false} ; ] [ on target delete ; ] [ on source delete ; ] [ ] [ ] [ ] ... "}" ] # Computed link form used inside type declaration: [{required | optional}] [{single | multi}] [ link ] := ; # Computed link form used inside type declaration (extended): [ overloaded ] [{required | optional}] [{single | multi}] link [: ] [ "{" using () ; [ extending [, ...] ; ] [ ] [ ] ... "}" ] # Abstract link form: abstract link [ "{" [ extending [, ...] ; ] [ readonly := {true | false} ; ] [ ] [ ] [ ] [ ] ... "}" ] There are several forms of link declaration, as shown in the syntax synopsis above: - the first form is the canonical definition form; - the second and third forms are used for defining a :ref:`computed link `; - and the last form is used to define an abstract link. The following options are available: :eql:synopsis:`overloaded` If specified, indicates that the link is inherited and that some feature of it may be altered in the current object type. It is an error to declare a link as *overloaded* if it is not inherited. :eql:synopsis:`required` If specified, the link is considered *required* for the parent object type. It is an error for an object to have a required link resolve to an empty value. Child links **always** inherit the *required* attribute, i.e. it is not possible to make a required link non-required by extending it. :eql:synopsis:`optional` This is the default qualifier assumed when no qualifier is specified, but it can also be specified explicitly. The link is considered *optional* for the parent object type, i.e. it is possible for the link to resolve to an empty value.
:eql:synopsis:`multi` Specifies that there may be more than one instance of this link in an object, in other words, ``Object.link`` may resolve to a set of a size greater than one. :eql:synopsis:`single` Specifies that there may be at most *one* instance of this link in an object, in other words, ``Object.link`` may resolve to a set of a size not greater than one. ``single`` is assumed if neither the ``multi`` nor the ``single`` qualifier is specified. :eql:synopsis:`extending [, ...]` Optional clause specifying the *parents* of the new link item. Use of ``extending`` creates a persistent schema relationship between the new link and its parents. Schema modifications to the parent(s) propagate to the child. If the same *property* name exists in more than one parent, or is explicitly defined in the new link and at least one parent, then the data types of the property targets must be *compatible*. If there is no conflict, the link properties are merged to form a single property in the new link item. :eql:synopsis:`` The type must be a valid :ref:`type expression ` denoting an object type. The valid SDL sub-declarations are listed below: :eql:synopsis:`default := ` Specifies the default value for the link as an EdgeQL expression. The default value is used in an ``insert`` statement if an explicit value for this link is not specified. The expression must be :ref:`Stable `. :eql:synopsis:`readonly := {true | false}` If ``true``, the link is considered *read-only*. Modifications of this link using ``update`` are prohibited once an object is created. Any :ref:`overloaded links ` **must** preserve the original *read-only* value. Changes to this link **will** occur if a link is deleted and the appropriate :ref:`deletion policy ` allows it. :sdl:synopsis:`` Set link :ref:`annotation ` to a given *value*. :sdl:synopsis:`` Define a concrete :ref:`property ` on the link. :sdl:synopsis:`` Define a concrete :ref:`constraint ` on the link.
:sdl:synopsis:`` Define an :ref:`index ` for this abstract link. Note that this index can only refer to link properties. .. _ref_eql_ddl_links: DDL commands ============ This section describes the low-level DDL commands for creating, altering, and dropping links. You typically don't need to use these commands directly, but knowing about them is useful for reviewing migrations. Create link ----------- :eql-statement: :eql-haswith: Define a new link. .. eql:synopsis:: [ with [, ...] ] {create|alter} type "{" [ ... ] create [{required | optional}] [{single | multi}] link [ extending [, ...] ]: [ "{" ; [...] "}" ] ; [ ... ] "}" # Computed link form: [ with [, ...] ] {create|alter} type "{" [ ... ] create [{required | optional}] [{single | multi}] link := ; [ ... ] "}" # Abstract link form: [ with [, ...] ] create abstract link [::] [extending [, ...]] [ "{" ; [...] "}" ] # where is one of set default := set readonly := {true | false} create annotation := create property ... create constraint ... on target delete on source delete reset on target delete create index on Description ^^^^^^^^^^^ The combinations of ``create type ... create link`` and ``alter type ... create link`` define a new concrete link for a given object type, in DDL form. There are three forms of ``create link``: 1. The canonical definition form (specifying a target type). 2. The computed link form (declaring a link via an expression). 3. The abstract link form (declaring a module-level link). Parameters ^^^^^^^^^^^ Most sub-commands and options mirror those found in the :ref:`SDL link declaration `. In DDL form: - ``set default := `` specifies a default value. - ``set readonly := {true | false}`` makes the link read-only or not. - ``create annotation := `` adds an annotation. - ``create property ...`` defines a property on the link. - ``create constraint ...`` defines a constraint on the link. - ``on target delete `` and ``on source delete `` specify deletion policies. 
- ``reset on target delete`` resets the target deletion policy to default or inherited. - ``create index on `` creates an index on the link. Examples ^^^^^^^^ .. code-block:: edgeql alter type User { create multi link friends: User }; .. code-block:: edgeql alter type User { create link special_group := ( select __source__.friends filter .town = __source__.town ) }; .. code-block:: edgeql create abstract link orderable { create property weight: std::int64 }; alter type User { create multi link interests extending orderable: Interest }; Alter link ---------- :eql-statement: :eql-haswith: Changes the definition of a link. .. eql:synopsis:: [ with [, ...] ] {create|alter} type "{" [ ... ] alter link [ "{" ] ; [...] [ "}" ]; [ ... ] "}" [ with [, ...] ] alter abstract link [::] [ "{" ] ; [...] [ "}" ]; # where is one of set default := reset default set readonly := {true | false} reset readonly rename to extending ... set required set optional reset optionality set single set multi reset cardinality set type [using ()] reset type using () create annotation := alter annotation := drop annotation create property ... alter property ... drop property ... create constraint ... alter constraint ... drop constraint ... on target delete on source delete create index on drop index on Description ^^^^^^^^^^^ This command modifies an existing link on a type. It can also be used on an abstract link at the module level. Parameters ^^^^^^^^^^ - ``rename to `` changes the link's name. - ``extending ...`` changes or adds link parents. - ``set required`` / ``set optional`` changes the link optionality. - ``reset optionality`` reverts optionality to default or inherited value. - ``set single`` / ``set multi`` changes cardinality. - ``reset cardinality`` reverts cardinality to default or inherited value. - ``set type [using ()]`` changes the link's target type. - ``reset type`` reverts the link's type to inherited. - ``using ()`` changes the expression of a computed link. 
- ``create annotation``, ``alter annotation``, ``drop annotation`` manage annotations. - ``create property``, ``alter property``, ``drop property`` manage link properties. - ``create constraint``, ``alter constraint``, ``drop constraint`` manage link constraints. - ``on target delete `` and ``on source delete `` manage deletion policies. - ``reset on target delete`` reverts the target deletion policy. - ``create index on `` / ``drop index on `` manage indexes on link properties. Examples ^^^^^^^^ .. code-block:: edgeql alter type User { alter link friends create annotation title := "Friends"; }; .. code-block:: edgeql alter abstract link orderable rename to sorted; .. code-block:: edgeql alter type User { alter link special_group using ( # at least one of the friend's interests # must match the user's select __source__.friends filter .interests IN __source__.interests ); }; Drop link --------- :eql-statement: :eql-haswith: Removes the specified link from the schema. .. eql:synopsis:: [ with [, ...] ] alter type "{" [ ... ] drop link [ ... ] "}" [ with [, ...] ] drop abstract link []:: Description ^^^^^^^^^^^ - ``alter type ... drop link `` removes the link from an object type. - ``drop abstract link `` removes an abstract link from the schema. Examples ^^^^^^^^ .. code-block:: edgeql alter type User drop link friends; .. code-block:: edgeql drop abstract link orderable; .. list-table:: :class: seealso * - **See also** - :ref:`Introspection > Object types ` ================================================ FILE: docs/reference/datamodel/migrations.rst ================================================ .. _ref_datamodel_migrations: ========== Migrations ========== |Gel's| baked-in migration system lets you painlessly evolve your schema over time. Just update the contents of your |.gel| file(s) and use the |Gel| CLI to *create* and *apply* migrations. .. 
code-block:: bash $ gel migration create Created dbschema/migrations/00001.edgeql $ gel migrate Applied dbschema/migrations/00001.edgeql Refer to the :ref:`creating and applying migrations ` guide for more information on how to use the migration system. This document describes how migrations are implemented. The migrations flow =================== The migration flow is as follows: 1. The user edits the |.gel| files in the ``dbschema`` directory. This makes the schema described in the |.gel| files **different** from the actual schema in the database. 2. The user runs the :gelcmd:`migration create` command to create a new migration (a sequence of low-level DDL commands). * The CLI reads the |.gel| files and sends them to the |Gel| server, to analyze the changes. * The |Gel| server generates a migration plan and sends it back to the CLI. * The migration plan might require clarification from the user. If so, the CLI and the |Gel| server will go back and forth presenting the user with a sequence of questions, until the migration plan is clear and approved by the user. 3. The CLI writes the migration plan to a new file in the ``dbschema/migrations`` directory. 4. The user runs the :gelcmd:`migrate` command to apply the migration to the database. 5. The user checks in the updated |.gel| files and the new ``dbschema/migrations`` migration file (created by :gelcmd:`migration create`) into version control. Command line tools ================== The two most important commands are: * :gelcmd:`migration create` * :gelcmd:`migrate` Automatic migrations ==================== Sometimes when you're prototyping something new you don't want to spend time worrying about migrations. There's no data to lose and not much code that depends on the schema just yet. For this use case you can use the :gelcmd:`watch --migrate` command, which will monitor your |.gel| files and automatically create and apply migrations for you in the background. .. 
_ref_eql_ddl: Data definition language (DDL) ============================== The migration plan is a sequence of DDL commands. DDL commands are low-level instructions that describe the changes to the schema. SDL and your |.gel| files are like a 3D printer: you design the final shape, and the system puts a database together for you. Using DDL is like building a house the traditional way: to add a window, you first need a frame; to have a frame, you need a wall; and so on. If your schema looks like this: .. code-block:: sdl type User { required name: str; } then the corresponding DDL might look like this: .. code-block:: edgeql create type User { create required property name: str; } There are some circumstances where users might want to use DDL directly. But in most cases you just need to learn how to read them to understand the migration plan. Luckily, the DDL and SDL syntaxes were designed in tandem and are very similar. Most documentation pages on Gel's schema have a section about DDL commands, e.g. :ref:`object types DDL `. .. _ref_eql_ddl_migrations: Migration DDL commands ====================== Migrations themselves are a sequence of special DDL commands. Like all DDL commands, ``start migration`` and other migration commands are considered low-level. Users are encouraged to use the built-in :ref:`migration tools ` instead. However, if you want to implement your own migration tools, this section will give you a good understanding of how Gel migrations work under the hood. Start migration --------------- :eql-statement: Start a migration block. .. eql:synopsis:: start migration to "{" ; [ ... ] "}" ; Parameters ^^^^^^^^^^ :eql:synopsis:`` Complete schema text (content of all |.gel| files) defined with the declarative :ref:`Gel schema definition language `. Description ^^^^^^^^^^^ The command ``start migration`` defines a migration of the schema to a new state. The target schema state is described using :ref:`SDL ` and describes the entire schema. 
This is important to remember when adding to an existing schema: every existing schema object, as well as the new ones, must be included in the ``start migration`` command. Objects that aren't included in the command will be removed from the new schema (which may result in data loss). This command also starts a transaction block if not inside a transaction already. While inside a migration block, EdgeQL statements are not executed immediately; they are instead recorded as part of the migration script. Aside from normal EdgeQL commands, the following special migration commands are available: * :eql:stmt:`describe current migration` -- return a list of statements currently recorded as part of the migration; * :eql:stmt:`populate migration` -- auto-populate the migration with system-generated DDL statements to achieve the target schema state; * :eql:stmt:`abort migration` -- abort the migration block and discard the migration; * :eql:stmt:`commit migration` -- commit the migration by executing the migration script statements and recording the migration into the system migration log. Example ^^^^^^^ Create a new migration to a target schema specified by the Gel Schema syntax: .. code-block:: edgeql start migration to { module default { type User { property username: str; }; }; }; .. _ref_eql_ddl_migrations_create: create migration ---------------- :eql-statement: Create a new migration using an explicit EdgeQL script. .. eql:synopsis:: create migration "{" ; [ ... ] "}" ; Parameters ^^^^^^^^^^ :eql:synopsis:`` Any valid EdgeQL statement, except ``database``, ``branch``, ``role``, ``configure``, ``migration``, or ``transaction`` statements. Description ^^^^^^^^^^^ The command ``create migration`` executes all the nested EdgeQL commands and records the migration into the system migration log. Example ^^^^^^^ Create a new migration using an explicit DDL script: .. 
code-block:: edgeql create migration { create type default::User { create property username: str; } }; Abort migration --------------- :eql-statement: Abort the current migration block and discard the migration. .. eql:synopsis:: abort migration ; Description ^^^^^^^^^^^ The command ``abort migration`` is used to abort a migration block started by :eql:stmt:`start migration`. Issuing ``abort migration`` outside of a migration block is an error. Example ^^^^^^^ Start a migration block and then abort it: .. code-block:: edgeql start migration to { module default { type User; }; }; abort migration; Populate migration ------------------ :eql-statement: Populate the current migration with system-generated statements. .. eql:synopsis:: populate migration ; Description ^^^^^^^^^^^ The command ``populate migration`` is used within a migration block started by :eql:stmt:`start migration` to automatically fill the migration with system-generated statements to achieve the desired target schema state. If the system is unable to automatically find a satisfactory sequence of statements to perform the migration, an error is returned. Issuing ``populate migration`` outside of a migration block is also an error. .. warning:: The statements generated by ``populate migration`` may drop schema objects, which may result in data loss. Make sure to inspect the generated migration using :eql:stmt:`describe current migration` before running :eql:stmt:`commit migration`! Example ^^^^^^^ Start a migration block and populate it with auto-generated statements. .. code-block:: edgeql start migration to { module default { type User; }; }; populate migration; Describe current migration -------------------------- :eql-statement: Describe the migration in the current migration block. .. 
eql:synopsis:: describe current migration [ as {ddl | json} ]; Description ^^^^^^^^^^^ The command ``describe current migration`` generates a description of the migration in the current migration block in the specified output format: :eql:synopsis:`as ddl` Show a sequence of statements currently recorded as part of the migration using valid :ref:`DDL ` syntax. The output will indicate if the current migration is fully defined, i.e. the recorded statements bring the schema to the state specified by :eql:stmt:`start migration`. :eql:synopsis:`as json` Provide a machine-readable description of the migration using the following JSON format: .. code-block:: { // Name of the parent migration "parent": "", // Whether the confirmed DDL makes the migration complete, // i.e. there are no more statements to issue. "complete": {true|false}, // List of confirmed migration statements "confirmed": [ "", ... ], // The variants of the next statement // suggested by the system to advance // the migration script. "proposed": { "statements": [{ "text": "" }], "required-user-input": [ { "placeholder": "", "prompt": "" }, ... ], "confidence": (0..1), // confidence coefficient "prompt": "", "prompt_id": "", // Whether the operation is considered to be non-destructive. "data_safe": {true|false} } } Where: :eql:synopsis:`` Regular statement text. :eql:synopsis:`` Statement text template with interpolation points using the ``\(name)`` syntax. :eql:synopsis:`` The name of an interpolation variable in the statement text template for which the user prompt is given. :eql:synopsis:`` The text of a user prompt for an interpolation variable. :eql:synopsis:`` Prompt for the proposed migration step. :eql:synopsis:`` An opaque string identifier for a particular operation prompt. The client should not repeat prompts with the same prompt id. Commit migration ---------------- :eql-statement: Commit the current migration to the database. .. 
eql:synopsis:: commit migration ; Description ^^^^^^^^^^^ The command ``commit migration`` executes all the commands defined by the current migration and records the migration as the most recent migration in the database. Issuing ``commit migration`` outside of a migration block initiated by :eql:stmt:`start migration` is an error. Example ^^^^^^^ Create and execute the current migration: .. code-block:: edgeql commit migration; Reset schema to initial ----------------------- :eql-statement: Reset the database schema to its initial state. .. eql:synopsis:: reset schema to initial ; .. warning:: This command will drop all entities and, as a consequence, all data. You won't want to use this statement on a production instance unless you want to lose all that instance's data. Migration rewrites DDL commands =============================== Migration rewrites allow you to change the migration history as long as your final schema matches the current database schema. Start migration rewrite ----------------------- Start a migration rewrite. .. eql:synopsis:: start migration rewrite ; Once the migration rewrite is started, you can run any arbitrary DDL until you are ready to :ref:`commit ` your new migration history. The most useful DDL in this context will be :ref:`create migration ` statements, which will allow you to create a sequence of migrations that will become your new migration history. Declare savepoint ----------------- Establish a new savepoint within the current migration rewrite. .. eql:synopsis:: declare savepoint ; Parameters ^^^^^^^^^^ :eql:synopsis:`` The name which will be used to identify the new savepoint if you need to later release it or roll back to it. Release savepoint ----------------- Destroys a savepoint previously defined in the current migration rewrite. .. eql:synopsis:: release savepoint ; Parameters ^^^^^^^^^^ :eql:synopsis:`` The name of the savepoint to be released. 
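As a rough sketch of how the savepoint commands fit into a migration rewrite, the following session declares a savepoint, records one migration, and then releases the savepoint. The savepoint name ``before_user`` and the ``User`` type are made up for illustration, and the final statement assumes that the rewritten history produces exactly the schema currently in the database; otherwise the commit is expected to fail.

.. code-block:: edgeql

    start migration rewrite;

    # Hypothetical savepoint name, chosen for this example only.
    declare savepoint before_user;

    # Record a migration as part of the new history.
    create migration {
        create type default::User {
            create property username: str;
        }
    };

    # The savepoint is no longer needed; the recorded migration is kept.
    release savepoint before_user;

    # Assumes the rewritten history matches the current database schema.
    commit migration rewrite;

Releasing a savepoint only discards the savepoint itself. To undo the work recorded after it, roll back to the savepoint instead of releasing it.
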
Rollback to savepoint --------------------- Rollback to the named savepoint. .. eql:synopsis:: rollback to savepoint ; All changes made after the savepoint are discarded. The savepoint remains valid and can be rolled back to again later, if needed. Parameters ^^^^^^^^^^ :eql:synopsis:`` The name of the savepoint to roll back to. Rollback -------- Rollback the entire migration rewrite. .. eql:synopsis:: rollback ; All updates made within the transaction are discarded. .. _ref_eql_ddl_migrations_rewrites_commit: Commit migration rewrite ------------------------ Commit a migration rewrite. .. eql:synopsis:: commit migration rewrite ; ================================================ FILE: docs/reference/datamodel/modules.rst ================================================ .. _ref_datamodel_modules: .. _ref_eql_sdl_modules: ======= Modules ======= Each |branch| has a schema consisting of several **modules**, each with a unique name. Modules can be used to organize large schemas into logical units. In practice, though, most users put their entire schema inside a single module called ``default``. .. code-block:: sdl module default { # declare types here } .. _ref_name_resolution: Name resolution =============== When you define a module that references schema objects from another module, you must use a *fully-qualified* name in the form ``other_module_name::object_name``: .. 
code-block:: sdl module A { type User extending B::AbstractUser; } module B { abstract type AbstractUser { required name: str; } } Reserved module names ===================== The following module names are reserved by |Gel| and contain pre-defined types, utility functions, and operators: * ``std``: standard types, functions, and operators in the :ref:`standard library ` * ``math``: algebraic and statistical :ref:`functions ` * ``cal``: local (non-timezone-aware) and relative date/time :ref:`types and functions ` * ``schema``: types describing the :ref:`introspection ` schema * ``sys``: system-wide entities, such as user roles and :ref:`databases ` * ``cfg``: configuration and settings Modules are containers ====================== They can contain types, functions, and other modules. Here's an example of an empty module: .. code-block:: sdl module my_module {} And here's an example of a module with a type: .. code-block:: sdl module my_module { type User { required name: str; } } Nested modules ============== .. code-block:: sdl module dracula { type Person { required name: str; multi places_visited: City; strength: int16; } module combat { function fight( one: dracula::Person, two: dracula::Person ) -> str using ( (one.name ?? 'Fighter 1') ++ ' wins!' IF (one.strength ?? 0) > (two.strength ?? 0) ELSE (two.name ?? 'Fighter 2') ++ ' wins!' ); } } You can chain together module names in a fully-qualified name to traverse a tree of nested modules. For example, to call the ``fight`` function in the nested module example above, you would use ``dracula::combat::fight()``. Declaring modules ================= This section describes the syntax to declare a module in your schema. Syntax ------ .. sdl:synopsis:: module "{" [ ] ... "}" Define a nested module: .. sdl:synopsis:: module "{" [ ] module "{" [ ] "}" ... 
"}" Description ^^^^^^^^^^^ The module block declaration defines a new module similar to the :eql:stmt:`create module` command, but it also allows putting the module content as nested declarations: :sdl:synopsis:`` Define various schema items that belong to this module. Unlike :eql:stmt:`create module`, a module block with the same name can appear multiple times in an SDL document. In that case all blocks with the same name are merged into a single module under that name. For example: .. code-block:: sdl module my_module { abstract type Named { required name: str; } } module my_module { type User extending Named; } The above is equivalent to: .. code-block:: sdl module my_module { abstract type Named { required name: str; } type User extending Named; } Typically, in the documentation examples of SDL the *module block* is omitted and instead its contents are described without assuming which specific module they belong to. It's also possible to declare modules implicitly. In this style, SDL declaration uses a :ref:`fully-qualified name ` for the item that is being declared. The *module* part of the *fully-qualified* name implies that a module by that name will be automatically created in the schema. The following declaration is equivalent to the previous examples, but it declares module ``my_module`` implicitly: .. code-block:: sdl abstract type my_module::Named { required name: str; } type my_module::User extending my_module::Named; A module block can be nested inside another module block to create a nested module. If you want to reference an entity in a nested module by its fully-qualified name, you will need to include all of the containing modules' names: ``::::`` .. _ref_eql_ddl_modules: DDL commands ============ This section describes the low-level DDL commands for creating and dropping modules. You typically don't need to use these commands directly, but knowing about them is useful for reviewing migrations. 
Create module ------------- :eql-statement: Create a new module. .. eql:synopsis:: create module [ :: ] [ if not exists ]; There's a :ref:`corresponding SDL declaration ` for a module, although in SDL a module declaration is likely to also include that module's content. Description ^^^^^^^^^^^ The command ``create module`` defines a new module for the current :versionreplace:`database;5.0:branch`. The name of the new module must be distinct from any existing module in the current :versionreplace:`database;5.0:branch`. Unlike :ref:`SDL module declaration ` the ``create module`` command does not have sub-commands; module contents are created separately. Parameters ^^^^^^^^^^ :eql:synopsis:`if not exists` Normally, creating a module that already exists is an error, but with this flag the command will succeed. It is useful for scripts that add something to a module or, if the module is missing, the module is created as well. Examples ^^^^^^^^ Create a new module: .. code-block:: edgeql create module payments; Create a new nested module: .. code-block:: edgeql create module payments::currencies; Drop module ----------- :eql-statement: Remove a module. .. eql:synopsis:: drop module ; Description ^^^^^^^^^^^ The command ``drop module`` removes an existing empty module from the current :versionreplace:`database;5.0:branch`. If the module contains any schema items, this command will fail. Examples ^^^^^^^^ Remove a module: .. code-block:: edgeql drop module payments; ================================================ FILE: docs/reference/datamodel/mutation_rewrites.rst ================================================ .. _ref_datamodel_mutation_rewrites: ================= Mutation rewrites ================= .. index:: modify, modification Mutation rewrites allow you to intercept database mutations (i.e., :ref:`inserts ` and/or :ref:`updates `) and set the value of a property or link to the result of an expression you define. They can be defined in your schema. 
Mutation rewrites are complementary to :ref:`triggers `. While triggers are unable to modify the triggering object, mutation rewrites are built for that purpose. Example: last modified ====================== Here's an example of a mutation rewrite that updates a property of a ``Post`` type to reflect the time of the most recent modification: .. code-block:: sdl type Post { required title: str; required body: str; modified: datetime { rewrite insert, update using (datetime_of_statement()) } } Every time a ``Post`` is updated, the mutation rewrite will be triggered, updating the ``modified`` property: .. code-block:: edgeql-repl db> insert Post { ... title := 'One wierd trick to fix all your spelling errors' ... }; {default::Post {id: 19e024dc-d3b5-11ed-968c-37f5d0159e5f}} db> select Post {title, modified}; { default::Post { title: 'One wierd trick to fix all your spelling errors', modified: '2023-04-05T13:23:49.488335Z', }, } db> update Post ... filter .id = '19e024dc-d3b5-11ed-968c-37f5d0159e5f' ... set {title := 'One weird trick to fix all your spelling errors'}; {default::Post {id: 19e024dc-d3b5-11ed-968c-37f5d0159e5f}} db> select Post {title, modified}; { default::Post { title: 'One weird trick to fix all your spelling errors', modified: '2023-04-05T13:25:04.119641Z', }, } In some cases, you will want different rewrites depending on the type of query. Here, we will add an ``insert`` rewrite and an ``update`` rewrite: .. code-block:: sdl type Post { required title: str; required body: str; created: datetime { rewrite insert using (datetime_of_statement()) } modified: datetime { rewrite update using (datetime_of_statement()) } } With this schema, inserts will set the ``Post`` object's ``created`` property while updates will set the ``modified`` property: .. code-block:: edgeql-repl db> insert Post { ... title := 'One wierd trick to fix all your spelling errors' ... 
}; {default::Post {id: 19e024dc-d3b5-11ed-968c-37f5d0159e5f}} db> select Post {title, created, modified}; { default::Post { title: 'One wierd trick to fix all your spelling errors', created: '2023-04-05T13:23:49.488335Z', modified: {}, }, } db> update Post ... filter .id = '19e024dc-d3b5-11ed-968c-37f5d0159e5f' ... set {title := 'One weird trick to fix all your spelling errors'}; {default::Post {id: 19e024dc-d3b5-11ed-968c-37f5d0159e5f}} db> select Post {title, created, modified}; { default::Post { title: 'One weird trick to fix all your spelling errors', created: '2023-04-05T13:23:49.488335Z', modified: '2023-04-05T13:25:04.119641Z', }, } .. note:: Each property may have a single ``insert`` and a single ``update`` mutation rewrite rule, or they may have a single rule that covers both. Mutation context ================ .. api-index:: rewrite, __subject__, __specified__, __old__ Inside the rewrite rule's expression, you have access to a few special values: * ``__subject__`` refers to the object type with the new property and link values. * ``__specified__`` is a named tuple with a key for each property or link in the type and a boolean value indicating whether this value was explicitly set in the mutation. * ``__old__`` refers to the object type with the previous property and link values (available for update-only mutation rewrites). Here are some examples of the special values in use. Maybe your blog hosts articles about particularly controversial topics. You could use ``__subject__`` to enforce a "cooling off" period before publishing a blog post: .. code-block:: sdl type Post { required title: str; required body: str; publish_time: datetime { rewrite insert, update using ( __subject__.publish_time ?? datetime_of_statement() + cal::to_relative_duration(days := 10) ) } } Here we take the post's ``publish_time`` if set or the time the statement is executed and add 10 days to it. 
That should give our authors time to consider if they want to make any changes before a post goes live. You can omit ``__subject__`` in many cases and achieve the same thing: .. code-block:: sdl-diff type Post { required title: str; required body: str; publish_time: datetime { rewrite insert, update using ( - __subject__.publish_time ?? datetime_of_statement() + + .publish_time ?? datetime_of_statement() + cal::to_relative_duration(days := 10) ) } } but only if the path prefix has not changed. In the following schema, for example, the ``__subject__`` in the rewrite rule is required, because in the context of the nested ``select`` query, the leading dot resolves from the ``User`` path: .. code-block:: sdl type Post { required title: str; required body: str; author_email: str; author_name: str { rewrite insert, update using ( (select User {name} filter .email = __subject__.author_email).name ) } } type User { name: str; email: str; } .. note:: Learn more about how this works in our documentation on :ref:`path resolution `. Using ``__specified__``, we can determine which fields were specified in the mutation. This would allow us to track when a single property was last modified as in the ``title_modified`` property in this schema: .. code-block:: sdl type Post { required title: str; required body: str; title_modified: datetime { rewrite update using ( datetime_of_statement() if __specified__.title else __old__.title_modified ) } } ``__specified__.title`` will be ``true`` if that value was set as part of the update, and this rewrite mutation rule will update ``title_modified`` to ``datetime_of_statement()`` in that case. Another way you might use this is to set a default value but allow overriding: .. 
code-block:: sdl type Post { required title: str; required body: str; modified: datetime { rewrite update using ( datetime_of_statement() if not __specified__.modified else .modified ) } } Here, we rewrite ``modified`` on updates to ``datetime_of_statement()`` unless ``modified`` was set in the update. In that case, we allow the specified value to be set. This is different from a :ref:`default ` value because the rewrite happens on each update whereas a default value is applied only on insert of a new object. One shortcoming in using ``__specified__`` to decide whether to update the ``modified`` property is that we still don't know whether the value changed, only that it was specified in the query. It's possible the value specified was the same as the existing value. You'd need to check the value itself to decide if it has changed. This is easy enough for a single value, but what if you want a global ``modified`` property that is updated only if any of the properties or links were changed? That could get cumbersome quickly for an object of any complexity. Instead, you might try casting ``__subject__`` and ``__old__`` to ``json`` and comparing them: .. code-block:: sdl type Post { required title: str; required body: str; modified: datetime { rewrite update using ( datetime_of_statement() if <json>__subject__ {**} != <json>__old__ {**} else __old__.modified ) } } Lastly, if we want to add an ``author`` property that can be set for each write and keep a history of all the authors, we can do this with the help of ``__old__``: .. code-block:: sdl type Post { required title: str; required body: str; author: str; all_authors: array<str> { default := <array<str>>[]; rewrite update using ( __old__.all_authors ++ [__subject__.author] ); } } On insert, our ``all_authors`` property will get initialized to an empty array of strings. We will rewrite updates to concatenate that array with an array containing the new author value. 
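As a rough illustration, the statements below could be used to exercise the ``all_authors`` rewrite. They assume the ``Post`` schema above, with ``all_authors`` declared as an array of strings; the title and author values are made up for this example.

.. code-block:: edgeql

    # Insert: no rewrite runs, so all_authors falls back to its default.
    insert Post {
        title := 'Release notes',
        body := 'Draft',
        author := 'Alice'
    };

    # Update: the rewrite appends the new author to the previous array.
    update Post
    filter .title = 'Release notes'
    set { author := 'Bob' };

    # Inspect the current author and the recorded history.
    select Post { author, all_authors }
    filter .title = 'Release notes';

Because the rewrite is declared for ``update`` only, the author supplied at insert time is not added to the history under this schema. Declaring the rewrite for ``insert`` as well would be one way to capture it, but that is a design choice outside the schema shown here.
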
Cached computed =============== Mutation rewrites can be used to effectively create a cached computed value, as demonstrated with the ``byline`` property in this schema: .. code-block:: sdl type Post { required title: str; required body: str; author: str; created: datetime { rewrite insert using (datetime_of_statement()) } byline: str { rewrite insert, update using ( 'by ' ++ __subject__.author ++ ' on ' ++ to_str(__subject__.created, 'Mon DD, YYYY') ) } } The ``byline`` property will be updated on each insert or update, but the value will not need to be calculated at read time like a proper :ref:`computed property `. .. _ref_eql_sdl_mutation_rewrites: .. _ref_eql_sdl_mutation_rewrites_syntax: Declaring mutation rewrites =========================== .. api-index:: rewrite insert, rewrite update, using This section describes the syntax to declare mutation rewrites in your schema. Syntax ------ Define a new mutation rewrite corresponding to the :ref:`more explicit DDL commands `. .. sdl:synopsis:: rewrite {insert | update} [, ...] using Mutation rewrites must be defined inside a property or link block. Description ^^^^^^^^^^^ This declaration defines a new mutation rewrite with the following options: :eql:synopsis:`insert | update [, ...]` The query type (or types) the rewrite runs on. Separate multiple values with commas to invoke the same rewrite for multiple types of queries. :eql:synopsis:`` The expression to be evaluated to produce the new value of the property or link. .. _ref_eql_ddl_mutation_rewrites: DDL commands ============ This section describes the low-level DDL commands for creating and dropping mutation rewrites. You typically don't need to use these commands directly, but knowing about them is useful for reviewing migrations. Create rewrite -------------- :eql-statement: Define a new mutation rewrite. When creating a new property or link: .. eql:synopsis:: {create | alter} type "{" create { property | link } : "{" create rewrite {insert | update} [, ...] 
using "}" ; "}" ; When altering an existing property or link: .. eql:synopsis:: {create | alter} type "{" alter { property | link } "{" create rewrite {insert | update} [, ...] using "}" ; "}" ; Description ^^^^^^^^^^^ The command ``create rewrite`` nested under ``create type`` or ``alter type`` and then under ``create property/link`` or ``alter property/link`` defines a new mutation rewrite for the given property or link on the given object. Parameters ^^^^^^^^^^ :eql:synopsis:`` The name (optionally module-qualified) of the type containing the rewrite. :eql:synopsis:`` The name (optionally module-qualified) of the property or link being rewritten. :eql:synopsis:`insert | update [, ...]` The query type (or types) that are rewritten. Separate multiple values with commas to invoke the same rewrite for multiple types of queries. Examples ^^^^^^^^ Declare two mutation rewrites on new properties: one that sets a ``created`` property when a new object is inserted and one that sets a ``modified`` property on each update: .. code-block:: edgeql alter type User { create property created: datetime { create rewrite insert using (datetime_of_statement()); }; create property modified: datetime { create rewrite update using (datetime_of_statement()); }; }; Drop rewrite ------------ :eql-statement: Drop a mutation rewrite. .. eql:synopsis:: alter type "{" alter property "{" drop rewrite {insert | update} ; "}" ; "}" ; Description ^^^^^^^^^^^ The command ``drop rewrite`` inside an ``alter type`` block and further inside an ``alter property`` block removes the definition of an existing mutation rewrite on the specified property or link of the specified type. Parameters ^^^^^^^^^^ :eql:synopsis:`` The name (optionally module-qualified) of the type containing the rewrite. :eql:synopsis:`` The name (optionally module-qualified) of the property or link being rewritten. :eql:synopsis:`insert | update [, ...]` The query type (or types) that are rewritten. 
Separate multiple values with commas to invoke the same rewrite for multiple types of queries. Example ^^^^^^^ Remove the ``insert`` rewrite of the ``created`` property on the ``User`` type: .. code-block:: edgeql alter type User { alter property created { drop rewrite insert; }; }; .. list-table:: :class: seealso * - **See also** * - :ref:`Introspection > Mutation rewrites ` ================================================ FILE: docs/reference/datamodel/objects.rst ================================================ .. _ref_datamodel_object_types: ============ Object Types ============ .. index:: tables, models *Object types* are the primary components of a Gel schema. They are analogous to SQL *tables* or ORM *models*, and consist of :ref:`properties ` and :ref:`links `. Properties ========== Properties are used to attach primitive/scalar data to an object type. For the full documentation on properties, see :ref:`ref_datamodel_props`. .. code-block:: sdl type Person { email: str; } Using in a query: .. code-block:: edgeql select Person { email }; Links ===== Links are used to define relationships between object types. For the full documentation on links, see :ref:`ref_datamodel_links`. .. code-block:: sdl type Person { email: str; best_friend: Person; } Using in a query: .. code-block:: edgeql select Person { email, best_friend: { email } }; ID == .. index:: uuid, primary key There's no need to manually declare a primary key on your object types. All object types automatically contain a property ``id`` of type ``UUID`` that's *required*, *globally unique*, *readonly*, and has an index on it. The ``id`` is assigned upon creation and cannot be changed. Using in a query: .. code-block:: edgeql select Person { id }; select Person { email } filter .id = '123e4567-e89b-...'; Abstract types ============== .. index:: abstract, inheritance Object types can either be *abstract* or *non-abstract*. By default all object types are non-abstract. 
You can't create or store instances of abstract types (a.k.a. mixins), but they're a useful way to share functionality and structure among other object types. .. code-block:: sdl abstract type HasName { first_name: str; last_name: str; } .. _ref_datamodel_objects_inheritance: .. _ref_eql_sdl_object_types_inheritance: Inheritance =========== .. index:: extending, extends, subtypes, supertypes Object types can *extend* other object types. The extending type (AKA the *subtype*) inherits all links, properties, indexes, constraints, etc. from its *supertypes*. .. code-block:: sdl abstract type HasName { first_name: str; last_name: str; } type Person extending HasName { email: str; best_friend: Person; } Using in a query: .. code-block:: edgeql select Person { first_name, email, best_friend: { last_name } }; .. _ref_datamodel_objects_multiple_inheritance: Multiple Inheritance ==================== Object types can extend more than one type — that's called *multiple inheritance*. This mechanism allows building complex object types out of combinations of more basic types. .. note:: Gel's multiple inheritance should not be confused with the multiple inheritance of C++ or Python, where the complexity usually arises from fine-grained mixing of logic. Gel's multiple inheritance is structural and allows for natural composition. .. code-block:: sdl-diff abstract type HasName { first_name: str; last_name: str; } + abstract type HasEmail { + email: str; + } - type Person extending HasName { + type Person extending HasName, HasEmail { - email: str; best_friend: Person; } If multiple supertypes share links or properties, those properties must be of the same type and cardinality. .. _ref_eql_sdl_object_types: .. _ref_eql_sdl_object_types_syntax: Defining object types ===================== .. api-index:: abstract, type, extending This section describes the syntax to declare object types in your schema. Syntax ------ .. sdl:synopsis:: [abstract] type [extending [, ...] 
] [ "{" [ ] [ ] [ ] [ ] [ ] ... "}" ] Description ^^^^^^^^^^^ This declaration defines a new object type with the following options: :eql:synopsis:`abstract` If specified, the created type will be *abstract*. :eql:synopsis:`` The name (optionally module-qualified) of the new type. :eql:synopsis:`extending [, ...]` Optional clause specifying the *supertypes* of the new type. Use of ``extending`` creates a persistent type relationship between the new subtype and its supertype(s). Schema modifications to the supertype(s) propagate to the subtype. References to supertypes in queries will also include objects of the subtype. If the same *link* name exists in more than one supertype, or is explicitly defined in the subtype and at least one supertype, then the data types of the link targets must be *compatible*. If there is no conflict, the links are merged to form a single link in the new type. These sub-declarations are allowed in the ``Type`` block: :sdl:synopsis:`` Set object type :ref:`annotation ` to a given *value*. :sdl:synopsis:`` Define a concrete :ref:`property ` for this object type. :sdl:synopsis:`` Define a concrete :ref:`link ` for this object type. :sdl:synopsis:`` Define a concrete :ref:`constraint ` for this object type. :sdl:synopsis:`` Define an :ref:`index ` for this object type. .. _ref_eql_ddl_object_types: DDL commands ============ This section describes the low-level DDL commands for creating, altering, and dropping object types. You typically don't need to use these commands directly, but knowing about them is useful for reviewing migrations. Create type ----------- :eql-statement: :eql-haswith: Define a new object type. .. eql:synopsis:: [ with [, ...] ] create [abstract] type [ extending [, ...] ] [ "{" ; [...] "}" ] ; # where is one of create annotation := create link ... create property ... create constraint ... create index on Description ^^^^^^^^^^^ The command ``create type`` defines a new object type for use in the current |branch|. 
If *name* is qualified with a module name, then the type is created in that module, otherwise it is created in the current module. The type name must be distinct from that of any existing schema item in the module. Parameters ^^^^^^^^^^ Most sub-commands and options of this command are identical to the :ref:`SDL object type declaration `, with some additional features listed below: :eql:synopsis:`with [, ...]` Alias declarations. The ``with`` clause allows specifying module aliases that can be referenced by the command. See :ref:`ref_eql_statements_with` for more information. The following subcommands are allowed in the ``create type`` block: :eql:synopsis:`create annotation := ` Set object type :eql:synopsis:`` to :eql:synopsis:``. See :eql:stmt:`create annotation` for details. :eql:synopsis:`create link ...` Define a new link for this object type. See :eql:stmt:`create link` for details. :eql:synopsis:`create property ...` Define a new property for this object type. See :eql:stmt:`create property` for details. :eql:synopsis:`create constraint ...` Define a concrete constraint for this object type. See :eql:stmt:`create constraint` for details. :eql:synopsis:`create index on ` Define a new :ref:`index ` using *index-expr* for this object type. See :eql:stmt:`create index` for details. Example ^^^^^^^ Create an object type ``User``: .. code-block:: edgeql create type User { create property name: str; }; Alter type ---------- :eql-statement: :eql-haswith: Change the definition of an object type. .. eql:synopsis:: [ with [, ...] ] alter type [ "{" ; [...] "}" ] ; [ with [, ...] ] alter type ; # where is one of rename to extending [, ...] create annotation := alter annotation := drop annotation create link ... alter link ... drop link ... create property ... alter property ... drop property ... create constraint ... alter constraint ... drop constraint ... 
create index on drop index on Description ^^^^^^^^^^^ The command ``alter type`` changes the definition of an object type. *name* must be a name of an existing object type, optionally qualified with a module name. Parameters ^^^^^^^^^^ :eql:synopsis:`with [, ...]` Alias declarations. The ``with`` clause allows specifying module aliases that can be referenced by the command. See :ref:`ref_eql_statements_with` for more information. :eql:synopsis:`` The name (optionally module-qualified) of the type being altered. :eql:synopsis:`extending [, ...]` Alter the supertype list. The full syntax of this subcommand is: .. eql:synopsis:: extending [, ...] [ first | last | before | after ] This subcommand makes the type a subtype of the specified list of supertypes. The requirements for the parent-child relationship are the same as when creating an object type. It is possible to specify the position in the parent list using the following optional keywords: * ``first`` -- insert parent(s) at the beginning of the parent list, * ``last`` -- insert parent(s) at the end of the parent list, * ``before `` -- insert parent(s) before an existing *parent*, * ``after `` -- insert parent(s) after an existing *parent*. :eql:synopsis:`alter annotation ;` Alter object type annotation :eql:synopsis:``. See :eql:stmt:`alter annotation` for details. :eql:synopsis:`drop annotation ` Remove object type :eql:synopsis:``. See :eql:stmt:`drop annotation` for details. :eql:synopsis:`alter link ...` Alter the definition of a link for this object type. See :eql:stmt:`alter link` for details. :eql:synopsis:`drop link ` Remove a link item from this object type. See :eql:stmt:`drop link` for details. :eql:synopsis:`alter property ...` Alter the definition of a property item for this object type. See :eql:stmt:`alter property` for details. :eql:synopsis:`drop property ` Remove a property item from this object type. See :eql:stmt:`drop property` for details. 
:eql:synopsis:`alter constraint ...` Alter the definition of a constraint for this object type. See :eql:stmt:`alter constraint` for details. :eql:synopsis:`drop constraint ;` Remove a constraint from this object type. See :eql:stmt:`drop constraint` for details. :eql:synopsis:`drop index on ` Remove an :ref:`index ` defined as *index-expr* from this object type. See :eql:stmt:`drop index` for details. All the subcommands allowed in the ``create type`` block are also valid subcommands for the ``alter type`` block. Example ^^^^^^^ Alter the ``User`` object type to make ``name`` required: .. code-block:: edgeql alter type User { alter property name { set required; } }; Drop type --------- :eql-statement: :eql-haswith: Remove the specified object type from the schema. .. eql:synopsis:: drop type ; Description ^^^^^^^^^^^ The command ``drop type`` removes the specified object type from the schema. All subordinate schema items defined on this type, such as links and indexes, are removed as well. Example ^^^^^^^ Remove the ``User`` object type: .. code-block:: edgeql drop type User; .. list-table:: :class: seealso * - **See also** * - :ref:`Introspection > Object types ` * - :ref:`Cheatsheets > Object types ` ================================================ FILE: docs/reference/datamodel/permissions.rst ================================================ .. _ref_datamodel_permissions: .. versionadded:: 7.0 =========== Permissions =========== .. index:: RBAC, role based access control, capability *Permissions* are the mechanism for limiting access to the database based on provided connection credentials. Each :ref:`role ` has a set of granted permissions. .. code-block:: edgeql create role alice { set password := 'wonderland'; set permissions := { sys::perm::data_modification, default::can_see_secrets }; }; Permissions are either :ref:`built-in ` or :ref:`defined in schema `. Some language features or functions require the current role to have certain permissions.
For example, to use ``insert``, ``update`` or ``delete``, the current role is required to have ``sys::perm::data_modification``. Additionally, the permissions of the current role can be accessed via :ref:`global variables` of the same name: .. code-block:: edgeql select global sys::perm::data_modification; Note that roles are instance-wide objects, which means that they exist independently of branches and their schemas. As a result, a role's permissions apply to all branches. Roles that are qualified as *superuser* are implicitly granted :ref:`all permissions`. Built-in permissions ==================== .. _ref_datamodel_permissions_built_in: :eql:synopsis:`sys::perm::data_modification` Required for using ``insert``, ``update`` or ``delete`` statements. :eql:synopsis:`sys::perm::ddl` Required for modification of schema. This includes applying migrations, and issuing bare DDL commands (e.g. ``create type Post;``). It does not include global instance commands, such as ``create branch`` or ``create role``. These are only allowed to *superuser* roles. :eql:synopsis:`sys::perm::branch_config` Required for issuing ``configure current branch``. :eql:synopsis:`sys::perm::sql_session_config` Required for issuing ``SET`` and ``RESET`` SQL commands. :eql:synopsis:`sys::perm::analyze` Required for issuing ``analyze ...`` queries. :eql:synopsis:`sys::perm::query_stats_read` Required for reading ``sys::QueryStats``. :eql:synopsis:`sys::perm::approximate_count` Required for accessing ``sys::approximate_count()``. :eql:synopsis:`cfg::perm::configure_timeout` Required for setting various timeouts, for example ``session_idle_transaction_timeout`` and ``query_execution_timeout``. :eql:synopsis:`cfg::perm::configure_apply_access_policies` Required for disabling access policies. :eql:synopsis:`cfg::perm::configure_allow_user_specified_id` Required for setting ``allow_user_specified_id``. :eql:synopsis:`std::net::perm::http_write` Required for issuing HTTP requests.
:eql:synopsis:`std::net::perm::http_read` Required for reading status of issued HTTP requests and responses. Permissions for the :ref:`auth ` extension: :eql:synopsis:`ext::auth::perm::auth_read` :eql:synopsis:`ext::auth::perm::auth_write` :eql:synopsis:`ext::auth::perm::auth_read_user` Permissions for the ``ai`` extension are described in the :ref:`AI extension reference `. Custom permissions ================== .. _ref_datamodel_permissions_custom: Custom permissions can be defined in schema, to fit the security model of each application. .. code-block:: sdl module default { permission data_export; } These permissions can be assigned to roles, similar to built-in permissions: .. code-block:: edgeql alter role warehouse { set permissions := {default::data_export}; }; .. note:: Role permissions are instance-wide. If an unrelated branch defines ``default::data_export``, the ``warehouse`` role will receive it as well. This happens even if the unrelated branch adds the permission after ``alter role``. Additionally, a role may be given permissions which do not yet exist in any schema. This is useful for creating roles before any schemas are applied. To check if the current database connection's role has a permission, use the :ref:`global variable` with the same name as the permission. This global is a boolean and cannot be manually set. .. code-block:: edgeql select global default::data_export; In combination with access policies, permissions can be used to limit read or write access to any type: .. code-block:: sdl type AuditLog { property event: str; access policy only_export_can_read allow select using (global data_export); access policy anyone_can_insert allow insert; } In this example, we have type ``AuditLog`` into which all roles are allowed to insert new log entries. But reading is allowed only to roles that possess the ``data_export`` permission (or are qualified as a *superuser*).
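As a minimal sketch of how the ``AuditLog`` example behaves end to end, the following assumes a hypothetical role named ``auditor`` (it does not appear elsewhere in this reference) and a placeholder password. It uses only the ``create role`` and ``select global`` forms shown earlier on this page. Since ``create role`` is a superuser-only command, the first statement is assumed to be run by a superuser.

.. code-block:: edgeql

    # Run as a superuser. "auditor" is a hypothetical role name
    # introduced only for this illustration.
    create role auditor {
        set password := 'change_me';
        set permissions := {default::data_export};
    };

    # Run while connected as "auditor". The permission global is
    # true for this role, so the "only_export_can_read" policy
    # allows the select on AuditLog.
    select global default::data_export;
    select AuditLog { event };

For a role that lacks ``data_export``, the global is expected to evaluate to ``false`` and the same ``select`` on ``AuditLog`` returns an empty set, because objects not permitted by a ``select`` policy are filtered out of query results.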
Common patterns =============== Public readonly database ------------------------ A Gel server can be exposed to the public internet, with clients connecting directly from browsers. Let's assume that we only want to grant read access to the public browser client. In such scenarios, it is recommended to create a separate role that will be used by the JavaScript client (e.g. ``webapp``) and not grant it any permissions. This way, it will not be able to issue ``DROP TYPE`` or ``DELETE`` commands, but will be able to read all data in the database. More importantly, it will not be able to configure ``apply_access_policies`` to ``false`` to bypass our restrictions. If we want to limit that access further, for example limit read access to type ``Secret``, we can use a schema like this: .. code-block:: sdl permission server_access; type Secret { access policy all_access allow select, insert, update, delete using (global server_access); }; Because the ``webapp`` role will not possess the ``server_access`` permission, it will not be able to read (or modify) ``Secret``. For other, trusted clients, which should be able to access ``Secret``, we have to use a *superuser* role, or some other role with the ``server_access`` permission: .. code-block:: edgeql create role api_server { set password := 'strong_password'; set permissions := {sys::perm::data_modification, default::server_access}; }; Public partially writable database ---------------------------------- A similar example to the previous one is a public database, with a JavaScript client that needs write access to some, but not all, object types. In such scenarios, it is recommended to create a separate role for it (e.g. ``webapp``) and assign it the ``sys::perm::data_modification`` permission. Such a role will be able to connect to the database, read all data and modify all types. For obvious reasons, this is undesirable, since client credentials could be extracted and used to delete all data in the database. To further limit access, access policies must be used on every object type: ..
code-block:: sdl permission server_access; type Posts { # read-only access policy everyone_can_read allow select using (true); access policy server_can_do_everything allow select, insert, update, delete using (global server_access); } type Events { # insert-only access policy everyone_can_insert allow insert using (true); access policy server_can_do_everything allow select, insert, update, delete using (global server_access); } type Secrets { # no access access policy server_can_do_everything allow select, insert, update, delete using (global server_access); }; Again, we can then use a superuser role for the server to fully access the database, or set up a separate role with the ``server_access`` permission. Restricting branches -------------------- To control access by branches instead of by object type, we can use the ``Role.branches`` setting. For example, let's assume we have an instance with ``staging`` and ``prod`` branches. We want the role ``dev`` to have full access to ``staging``, but not ``prod``. .. code-block:: edgeql create role dev { set password := 'strong_password'; set branches := {'staging'}; }; For more about this, see :ref:`Roles `. Superuser permissions ===================== .. _ref_datamodel_permissions_superuser: Roles with *superuser* status are exempt from permission checks and have full access over the instance. This includes some commands that are not covered by any permission and are thus allowed *only* to *superuser* roles. These commands include: * :eql:synopsis:`ROLE` commands * :eql:synopsis:`BRANCH` commands * :eql:synopsis:`EXTENSION PACKAGE` commands * :eql:synopsis:`CONFIGURE INSTANCE` command * :eql:synopsis:`DESCRIBE` command * :eql:synopsis:`ADMINISTER` command .. list-table:: :class: seealso * - **See also** * - :ref:`Schema > Access policies ` * - :ref:`Running Gel > Administration > Roles ` ================================================ FILE: docs/reference/datamodel/primitives.rst ================================================ ..
_ref_datamodel_primitives: ========== Primitives ========== |Gel| has a robust type system consisting of primitive and object types. Primitive types are used to declare *properties* on object types, as query and function arguments, as well as in other contexts. .. _ref_datamodel_scalars: Built-in scalar types ===================== Gel comes with a range of built-in scalar types, such as: * String: :eql:type:`str` * Boolean: :eql:type:`bool` * Various numeric types: :eql:type:`int16`, :eql:type:`int32`, :eql:type:`int64`, :eql:type:`float32`, :eql:type:`float64`, :eql:type:`bigint`, :eql:type:`decimal` * JSON: :eql:type:`json` * UUID: :eql:type:`uuid` * Date/time: :eql:type:`datetime`, :eql:type:`duration`, :eql:type:`cal::local_datetime`, :eql:type:`cal::local_date`, :eql:type:`cal::local_time`, :eql:type:`cal::relative_duration`, :eql:type:`cal::date_duration` * Miscellaneous: :eql:type:`sequence`, :eql:type:`bytes`, etc. Custom scalars ============== You can extend built-in scalars with additional constraints or annotations. Here's an example of a non-negative custom ``int64`` variant: .. code-block:: sdl scalar type posint64 extending int64 { constraint min_value(0); } .. _ref_datamodel_enums: Enums ===== Enum types are created by extending the abstract :eql:type:`enum` type, e.g.: .. code-block:: sdl scalar type Color extending enum; type Shirt { color: Color; } which can be queried with: .. code-block:: edgeql select Shirt filter .color = Color.Red; For a full reference on enum types, see the :ref:`Enum docs `. .. _ref_datamodel_arrays: Arrays ====== Arrays store zero or more primitive values of the same type in an ordered list. Arrays cannot contain object types or other arrays, but can contain virtually any other type. ..
code-block:: sdl type Person { str_array: array; json_array: array; tuple_array: array>; # INVALID: arrays of object types not allowed: # friends: array # INVALID: arrays cannot be nested: # nested_array: array> # VALID: arrays can contain tuples with arrays in them nested_array_via_tuple: array>> } Array syntax in EdgeQL is very intuitive (indexing starts at ``0``): .. code-block:: edgeql select [1, 2, 3]; select [1, 2, 3][1] = 2; # true For a full reference on array types, see the :ref:`Array docs `. .. _ref_datamodel_tuples: Tuples ====== Like arrays, tuples are ordered sequences of primitive data. Unlike arrays, each element of a tuple can have a distinct type. Tuple elements can be *any type*, including primitives, objects, arrays, and other tuples. .. code-block:: sdl type Person { unnamed_tuple: tuple; nested_tuple: tuple>>; tuple_of_arrays: tuple, array>; } Optionally, you can assign a *key* to each element of the tuple. Tuples containing explicit keys are known as *named tuples*. You must assign keys to all elements (or none of them). .. code-block:: sdl type BlogPost { metadata: tuple; } Named and unnamed tuples are the same data structure under the hood. You can add, remove, and change keys in a tuple type after it's been declared. For details, see :ref:`Tuples `. .. note:: When you query an *unnamed* tuple using one of EdgeQL's :ref:`client libraries `, its value is converted to a list/array. When you fetch a named tuple, it is converted into an object/dictionary/hashmap depending on the language. .. _ref_datamodel_ranges: Ranges ====== Ranges represent some interval of values. The intervals can be bound or unbound on either end. They can also be empty, containing no values. Only some scalar types have corresponding range types: - Numeric ranges: ``range``, ``range``, ``range``, ``range``, ``range`` - Date/time ranges: ``range``, ``range``, ``range`` Example: .. 
code-block:: sdl type DieRoll { values: range; } For a full reference on ranges, functions, and operators, see the :ref:`Range docs `. Sequences ========= To represent an auto-incrementing integer property, declare a custom scalar that extends the abstract ``sequence`` type. Creating a sequence type initializes a global ``int64`` counter that auto-increments whenever a new object is created. All properties that point to the same sequence type will share the counter. .. code-block:: sdl scalar type ticket_number extending sequence; type Ticket { number: ticket_number; rendered_number := 'TICKET-\(.number)'; } For a full reference on sequences, see the :ref:`Sequence docs `. .. _ref_eql_sdl_scalars: .. _ref_eql_sdl_scalars_syntax: Declaring scalars ================= This section describes the syntax to declare a custom scalar type in your schema. Syntax ------ .. sdl:synopsis:: [abstract] scalar type [extending [, ...] ] [ "{" [ ] [ ] ... "}" ] Description ^^^^^^^^^^^ This declaration defines a new scalar type with the following options: :eql:synopsis:`abstract` If specified, the created scalar type will be *abstract*. :eql:synopsis:`` The name (optionally module-qualified) of the new scalar type. :eql:synopsis:`extending ` Optional clause specifying the *supertype* of the new type. If :eql:synopsis:`` is an :eql:type:`enumerated type ` declaration then an enumerated scalar type is defined. Use of ``extending`` creates a persistent type relationship between the new subtype and its supertype(s). Schema modifications to the supertype(s) propagate to the subtype. The valid SDL sub-declarations are listed below: :sdl:synopsis:`` Set scalar type :ref:`annotation ` to a given *value*. :sdl:synopsis:`` Define a concrete :ref:`constraint ` for this scalar type. .. _ref_eql_ddl_scalars: DDL commands ============ This section describes the low-level DDL commands for creating, altering, and dropping scalar types.
You typically don't need to use these commands directly, but knowing about them is useful for reviewing migrations. Create scalar ------------- :eql-statement: :eql-haswith: Define a new scalar type. .. eql:synopsis:: [ with [, ...] ] create [abstract] scalar type [ extending ] [ "{" ; [...] "}" ] ; # where is one of create annotation := create constraint ... Description ^^^^^^^^^^^ The command ``create scalar type`` defines a new scalar type for use in the current |branch|. If *name* is qualified with a module name, then the type is created in that module, otherwise it is created in the current module. The type name must be distinct from that of any existing schema item in the module. If the ``abstract`` keyword is specified, the created type will be *abstract*. All non-abstract scalar types must have an underlying core implementation. For user-defined scalar types this means that ``create scalar type`` must have another non-abstract scalar type as its *supertype*. The most common use of ``create scalar type`` is to define a scalar subtype with constraints. Most sub-commands and options of this command are identical to the :ref:`SDL scalar type declaration `. The following subcommands are allowed in the ``create scalar type`` block: :eql:synopsis:`create annotation := ;` Set scalar type's :eql:synopsis:`` to :eql:synopsis:``. See :eql:stmt:`create annotation` for details. :eql:synopsis:`create constraint ...` Define a new constraint for this scalar type. See :eql:stmt:`create constraint` for details. Examples ^^^^^^^^ Create a new non-negative integer type: .. code-block:: edgeql create scalar type posint64 extending int64 { create constraint min_value(0); }; Create a new enumerated type: .. code-block:: edgeql create scalar type Color extending enum; Alter scalar ------------ :eql-statement: :eql-haswith: Alter the definition of a scalar type. .. eql:synopsis:: [ with [, ...] ] alter scalar type "{" ; [...] "}" ; # where is one of rename to extending ... 
create annotation := alter annotation := drop annotation create constraint ... alter constraint ... drop constraint ... Description ^^^^^^^^^^^ The command ``alter scalar type`` changes the definition of a scalar type. *name* must be a name of an existing scalar type, optionally qualified with a module name. The following subcommands are allowed in the ``alter scalar type`` block: :eql:synopsis:`rename to ;` Change the name of the scalar type to *newname*. :eql:synopsis:`extending ...` Alter the supertype list. It works the same way as in :eql:stmt:`alter type`. :eql:synopsis:`alter annotation ;` Alter scalar type :eql:synopsis:``. See :eql:stmt:`alter annotation` for details. :eql:synopsis:`drop annotation ` Remove scalar type's :eql:synopsis:`` from :eql:synopsis:``. See :eql:stmt:`drop annotation` for details. :eql:synopsis:`alter constraint ...` Alter the definition of a constraint for this scalar type. See :eql:stmt:`alter constraint` for details. :eql:synopsis:`drop constraint ` Remove a constraint from this scalar type. See :eql:stmt:`drop constraint` for details. All the subcommands allowed in the ``create scalar type`` block are also valid subcommands for ``alter scalar type`` block. Examples ^^^^^^^^ Define a new constraint on a scalar type: .. code-block:: edgeql alter scalar type posint64 { create constraint max_value(100); }; Add one more label to an enumerated type: .. code-block:: edgeql alter scalar type Color extending enum; Drop scalar ----------- :eql-statement: :eql-haswith: Remove a scalar type. .. eql:synopsis:: [ with [, ...] ] drop scalar type ; Description ^^^^^^^^^^^ The command ``drop scalar type`` removes a scalar type. Parameters ^^^^^^^^^^ *name* The name (optionally qualified with a module name) of an existing scalar type. Example ^^^^^^^ Remove a scalar type: .. 
code-block:: edgeql drop scalar type posint64; ================================================ FILE: docs/reference/datamodel/properties.rst ================================================ .. _ref_datamodel_props: ========== Properties ========== .. index:: property, primitive types, fields, columns Properties are used to associate primitive data with an :ref:`object type ` or :ref:`link `. .. code-block:: sdl type Player { property email: str; points: int64; is_online: bool; } Properties are associated with a *name* (e.g. ``email``) and a primitive type (e.g. ``str``). The term *primitive type* is an umbrella term that encompasses :ref:`scalar types ` like ``str``, :ref:`arrays ` and :ref:`tuples `, :ref:`and more `. Properties can be declared using the ``property`` keyword if that improves readability, or the keyword can be omitted. Required properties =================== .. index:: not null .. api-index:: required, optional Properties can be either ``optional`` (the default) or ``required``. E.g. here we have a ``User`` type that's guaranteed to have an ``email``, but ``name`` is optional and can be empty: .. code-block:: sdl type User { required email: str; optional name: str; } Since the ``optional`` keyword is the default, we can omit it: .. code-block:: sdl type User { required email: str; name: str; } .. _ref_datamodel_props_cardinality: Cardinality =========== .. api-index:: single, multi Properties have a **cardinality**: * ``prop: type``, short for ``single prop: type``, can hold either zero or one value (that's the default). * ``multi prop: type`` can hold an *unordered set* of values, which can be zero, one, or more values of type ``type``. For example: .. code-block:: sdl type User { # "single" keyword isn't necessary here: # properties are single by default single name: str; # an unordered set of strings multi nicknames: str; # an unordered set of string arrays multi set_of_arrays: array; } multi vs.
arrays ================ ``multi`` properties are stored differently than arrays under the hood. Essentially they are stored in a separate table ``(owner_id, value)``. .. rubric:: Pros of multi properties vs. arrays * ``multi`` properties allow efficient search and mutation of large sets. Arrays are much slower for those operations. * ``multi`` properties can have indexes and constraints applied to individual elements; arrays, in general, cannot. * It's easier to aggregate sets and operate on them than on arrays. In many cases arrays would require :ref:`unpacking them into a set ` first. .. rubric:: Cons of multi properties vs. arrays * On small sets, arrays are faster to retrieve. * It's easier to retain the original order in arrays. Arrays are ordered, but sets are not. .. _ref_datamodel_props_default_values: Default values ============== .. api-index:: default Properties can have a default value. This default can be a static value or an arbitrary EdgeQL expression, which will be evaluated upon insertion. .. code-block:: sdl type Player { required points: int64 { default := 0; } required latitude: float64 { default := (360 * random() - 180); } } Readonly properties =================== .. index:: immutable .. api-index:: readonly Properties can be marked as ``readonly``. In the example below, the ``User.external_id`` property can be set at the time of creation but not modified thereafter. .. code-block:: sdl type User { required external_id: uuid { readonly := true; } } Constraints =========== .. api-index:: constraint Properties can be augmented with constraints. The example below showcases a subset of Gel's built-in constraints. ..
code-block:: sdl type BlogPost { title: str { constraint exclusive; # all post titles must be unique constraint min_len_value(8); constraint max_len_value(30); constraint regexp(r'^[A-Za-z0-9 ]+$'); } status: str { constraint one_of('Draft', 'InReview', 'Published'); } upvotes: int64 { constraint min_value(0); constraint max_value(9999); } } You can constrain properties with arbitrary :ref:`EdgeQL ` expressions returning ``bool``. To reference the value of the property, use the special scope keyword ``__subject__``. .. code-block:: sdl type BlogPost { title: str { constraint expression on ( __subject__ = str_trim(__subject__) ); } } The constraint above guarantees that ``BlogPost.title`` doesn't contain any leading or trailing whitespace by checking that the raw string is equal to the trimmed version. It uses the built-in :eql:func:`str_trim` function. For a full reference of built-in constraints, see the :ref:`Constraints reference `. Annotations =========== .. index:: metadata Properties can contain annotations, small human-readable notes. The built-in annotations are ``title``, ``description``, and ``deprecated``. You may also declare :ref:`custom annotation types `. .. code-block:: sdl type User { email: str { annotation title := 'Email address'; } } Abstract properties =================== .. api-index:: abstract property Properties can be *concrete* (the default) or *abstract*. Abstract properties are declared independent of a source or target, can contain :ref:`annotations `, constraints, indexes, and can be marked as ``readonly``. .. code-block:: sdl abstract property email_prop { annotation title := 'An email address'; readonly := true; } type Student { # inherits annotations and "readonly := true" email: str { extending email_prop; }; } Overloading properties ====================== Any time we want to amend an inherited property (e.g. to add a constraint), the ``overloaded`` keyword must be used. 
This is to prevent unintentional overloading due to a name clash: .. code-block:: sdl abstract type Named { optional name: str; } type User extending Named { # make "name" required overloaded required name: str; } .. _ref_eql_sdl_props: .. _ref_eql_sdl_props_syntax: Declaring properties ==================== Syntax ------ This section describes the syntax to declare properties in your schema. .. sdl:synopsis:: # Concrete property form used inside type declaration: [ overloaded ] [{required | optional}] [{single | multi}] [ property ] : [ "{" [ extending [, ...] ; ] [ default := ; ] [ readonly := {true | false} ; ] [ ] [ ] ... "}" ] # Computed property form used inside type declaration: [{required | optional}] [{single | multi}] [ property ] := ; # Computed property form used inside type declaration (extended): [ overloaded ] [{required | optional}] [{single | multi}] property [: ] [ "{" using () ; [ extending [, ...] ; ] [ ] [ ] ... "}" ] # Abstract property form: abstract property [::] [ "{" [extending [, ...] ; ] [ readonly := {true | false} ; ] [ ] ... "}" ] Description ^^^^^^^^^^^ There are several forms of ``property`` declaration, as shown in the syntax synopsis above. The first form is the canonical definition form, the second and third forms are used for defining a :ref:`computed property `, and the last one is a form to define an ``abstract property``. The abstract form allows declaring the property directly inside a :ref:`module `. Concrete property forms are always used as sub-declarations for an :ref:`object type ` or a :ref:`link `. The following options are available: :eql:synopsis:`overloaded` If specified, indicates that the property is inherited and that some feature of it may be altered in the current object type. It is an error to declare a property as *overloaded* if it is not inherited. :eql:synopsis:`required` If specified, the property is considered *required* for the parent object type. 
It is an error for an object to have a required property resolve to an empty value. Child properties **always** inherit the *required* attribute, i.e. it is not possible to make a required property non-required by extending it. :eql:synopsis:`optional` This is the default qualifier assumed when no qualifier is specified, but it can also be specified explicitly. The property is considered *optional* for the parent object type, i.e. it is possible for the property to resolve to an empty value. :eql:synopsis:`multi` Specifies that there may be more than one instance of this property in an object, in other words, ``Object.property`` may resolve to a set of a size greater than one. :eql:synopsis:`single` Specifies that there may be at most *one* instance of this property in an object, in other words, ``Object.property`` may resolve to a set of a size not greater than one. ``single`` is assumed if neither the ``multi`` nor the ``single`` qualifier is specified. :eql:synopsis:`extending [, ...]` Optional clause specifying the *parents* of the new property item. Use of ``extending`` creates a persistent schema relationship between the new property and its parents. Schema modifications to the parent(s) propagate to the child. :eql:synopsis:`` The type must be a valid :ref:`type expression ` denoting a non-abstract scalar or a container type. The valid SDL sub-declarations are listed below: :eql:synopsis:`default := ` Specifies the default value for the property as an EdgeQL expression. The default value is used in an ``insert`` statement if an explicit value for this property is not specified. The expression must be :ref:`Stable `. :eql:synopsis:`readonly := {true | false}` If ``true``, the property is considered *read-only*. Modifications of this property are prohibited once an object is created. All of the derived properties **must** preserve the original *read-only* value. :sdl:synopsis:`` Set property :ref:`annotation ` to a given *value*.
:sdl:synopsis:`` Define a concrete :ref:`constraint ` on the property. .. _ref_eql_ddl_props: DDL commands ============ This section describes the low-level DDL commands for creating, altering, and dropping properties. You typically don't need to use these commands directly, but knowing about them is useful for reviewing migrations. .. _ref_eql_ddl_props_syntax: Create property --------------- :eql-statement: :eql-haswith: Define a new property. .. eql:synopsis:: [ with [, ...] ] {create|alter} {type|link} "{" [ ... ] create [{required | optional}] [{single | multi}] property [ extending [, ...] ] : [ "{" ; [...] "}" ] ; [ ... ] "}" # Computed property form: [ with [, ...] ] {create|alter} {type|link} "{" [ ... ] create [{required | optional}] [{single | multi}] property := ; [ ... ] "}" # Abstract property form: [ with [, ...] ] create abstract property [::] [extending [, ...]] [ "{" ; [...] "}" ] # where is one of set default := set readonly := {true | false} create annotation := create constraint ... Parameters ^^^^^^^^^^ Most sub-commands and options of this command are identical to the :ref:`SDL property declaration `. The following subcommands are allowed in the ``create property`` block: :eql:synopsis:`set default := ` Specifies the default value for the property as an EdgeQL expression. Other than a slight syntactical difference, this is the same as the corresponding SDL declaration. :eql:synopsis:`set readonly := {true | false}` Specifies whether the property is considered *read-only*. Other than a slight syntactical difference, this is the same as the corresponding SDL declaration. :eql:synopsis:`create annotation := ` Set property :eql:synopsis:`` to :eql:synopsis:``. See :eql:stmt:`create annotation` for details. :eql:synopsis:`create constraint` Define a concrete constraint on the property. See :eql:stmt:`create constraint` for details. Examples ^^^^^^^^ Define a new property ``address`` on the ``User`` object type: ..
code-block:: edgeql alter type User { create property address: str }; Define a new :ref:`computed property ` ``number_of_connections`` on the ``User`` object type counting the number of interests: .. code-block:: edgeql alter type User { create property number_of_connections := count(.interests) }; Define a new abstract link ``orderable`` with a ``weight`` property: .. code-block:: edgeql create abstract link orderable { create property weight: std::int64 }; Alter property -------------- :eql-statement: :eql-haswith: Change the definition of a property. .. eql:synopsis:: [ with [, ...] ] {create | alter} {type | link} "{" [ ... ] alter property [ "{" ] ; [...] [ "}" ]; [ ... ] "}" [ with [, ...] ] alter abstract property [::] [ "{" ] ; [...] [ "}" ]; # where is one of set default := reset default set readonly := {true | false} reset readonly rename to extending ... set required [using ( [using () create annotation := alter annotation := drop annotation create constraint ... alter constraint ... drop constraint ... Parameters ^^^^^^^^^^ :eql:synopsis:`` The name of an object type or link on which the property is defined. May be optionally qualified with a module name. :eql:synopsis:`` The unqualified name of the property to modify. :eql:synopsis:`` Optional name of the module to create or alter the abstract property in. If not specified, the current module is used. The following subcommands are allowed in the ``alter property`` block: :eql:synopsis:`rename to ` Change the name of the property to :eql:synopsis:``. All concrete properties inheriting from this property are also renamed. :eql:synopsis:`extending ...` Alter the property parent list. The full syntax of this subcommand is: .. eql:synopsis:: extending [, ...] [ first | last | before | after ] This subcommand makes the property a child of the specified list of parent property items. The requirements for the parent-child relationship are the same as when creating a property.
It is possible to specify the position in the parent list using the following optional keywords: * ``first`` -- insert parent(s) at the beginning of the parent list, * ``last`` -- insert parent(s) at the end of the parent list, * ``before `` -- insert parent(s) before an existing *parent*, * ``after `` -- insert parent(s) after an existing *parent*. :eql:synopsis:`set required [using ( [using (`. The optional ``using`` clause specifies a conversion expression that computes the new property value from the old. The conversion expression must return a singleton set and is evaluated on each element of ``multi`` properties. A ``using`` clause must be provided if there is no implicit or assignment cast from old to new type. :eql:synopsis:`reset type` Reset the type of the property to the type inherited from properties of the same name in supertypes. It is an error to ``reset type`` on a property that is not inherited. :eql:synopsis:`using ()` Change the expression of a :ref:`computed property `. Only valid for concrete properties. :eql:synopsis:`alter annotation ;` Alter property annotation :eql:synopsis:``. See :eql:stmt:`alter annotation` for details. :eql:synopsis:`drop annotation ;` Remove property annotation :eql:synopsis:``. See :eql:stmt:`drop annotation` for details. :eql:synopsis:`alter constraint ...` Alter the definition of a constraint for this property. See :eql:stmt:`alter constraint` for details. :eql:synopsis:`drop constraint ;` Remove a constraint from this property. See :eql:stmt:`drop constraint` for details. :eql:synopsis:`reset default` Remove the default value from this property, or reset it to the value inherited from a supertype, if the property is inherited. :eql:synopsis:`reset readonly` Set property writability to the default value (writable), or, if the property is inherited, to the value inherited from properties in supertypes. 
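
As a minimal sketch of the two ``reset`` subcommands, assuming a hypothetical ``User`` type whose ``address`` property currently has an explicit default and is marked ``readonly``:

.. code-block:: edgeql

    alter type User {
      alter property address {
        # Drop the explicit default (or fall back to an inherited one).
        reset default;
        # Make the property writable again, unless a supertype
        # declares it read-only.
        reset readonly;
      };
    };
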
All the subcommands allowed in the ``create property`` block are also valid subcommands for the ``alter property`` block. Examples ^^^^^^^^ Set the ``title`` annotation of property ``address`` of object type ``User`` to ``"Home address"``: .. code-block:: edgeql alter type User { alter property address create annotation title := "Home address"; }; Add a maximum-length constraint to property ``address`` of object type ``User``: .. code-block:: edgeql alter type User { alter property address { create constraint max_len_value(500); }; }; Rename the property ``weight`` of link ``orderable`` to ``sort_by``: .. code-block:: edgeql alter abstract link orderable { alter property weight rename to sort_by; }; Redefine the :ref:`computed property ` ``number_of_connections`` to be the number of friends: .. code-block:: edgeql alter type User { alter property number_of_connections using ( count(.friends) ) }; Drop property ------------- :eql-statement: :eql-haswith: Remove a property from the schema. .. eql:synopsis:: [ with [, ...] ] {create|alter} type "{" [ ... ] drop property [ ... ] "}" [ with [, ...] ] drop abstract property ; Example ^^^^^^^ Remove property ``address`` from type ``User``: .. code-block:: edgeql alter type User { drop property address; }; ================================================ FILE: docs/reference/datamodel/triggers.rst ================================================ .. _ref_datamodel_triggers: .. _ref_eql_sdl_triggers: ======== Triggers ======== Triggers allow you to define an expression to be executed whenever a given query type is run on an object type. The original query will *trigger* your pre-defined expression to run in a transaction along with the original query. These can be defined in your schema. Important notes =============== Triggers are an advanced feature and have some caveats that you should be aware of.
Consider using mutation rewrites -------------------------------- Triggers cannot be used to *modify* the object that set off the trigger, although they can be used with :eql:func:`assert` to do *validation* on that object. If you need to modify the object, you can use :ref:`mutation rewrites `. Unified trigger query execution ------------------------------- All queries within triggers, along with the initial triggering query, are compiled into a single combined SQL query under the hood. Keep this in mind when designing triggers that modify existing records. If multiple ``update`` queries within your triggers target the same object, only one of these queries will ultimately be executed. To ensure all desired updates on an object are applied, consolidate them into a single ``update`` query within one trigger, instead of distributing them across multiple updates. Multi-stage trigger execution ----------------------------- In some cases, a trigger can cause another trigger to fire. When this happens, Gel completes all the triggers fired by the initial query before kicking off a new "stage" of triggers. In the second stage, any triggers fired by the initial stage of triggers will fire. Gel will continue adding trigger stages until all triggers are complete. The exception to this is when triggers would cause a loop or would cause the same trigger to be run in two different stages. These triggers will generate an error. Data visibility --------------- Any query in your trigger will return the state of the database *after* the triggering query. If this query's results include the object that flipped the trigger, the results will contain that object in the same state as ``__new__``. Example: audit log ================== Here's an example that creates a simple **audit log** type so that we can keep track of what's happening to our users in a database. First, we will create a ``Log`` type: .. 
code-block:: sdl type Log { action: str; timestamp: datetime { default := datetime_current(); } target_name: str; change: str; } With the ``Log`` type in place, we can write some triggers that will automatically create ``Log`` objects for any insert, update, or delete queries on the ``Person`` type: .. code-block:: sdl type Person { required name: str; trigger log_insert after insert for each do ( insert Log { action := 'insert', target_name := __new__.name } ); trigger log_update after update for each do ( insert Log { action := 'update', target_name := __new__.name, change := __old__.name ++ '->' ++ __new__.name } ); trigger log_delete after delete for each do ( insert Log { action := 'delete', target_name := __old__.name } ); } In a trigger's expression, we have access to the ``__old__`` and/or ``__new__`` variables which capture the object before and after the query. Triggers on ``update`` can use both variables. Triggers on ``delete`` can use ``__old__``. Triggers on ``insert`` can use ``__new__``. Now, whenever we run a query, we get a log entry as well: .. code-block:: edgeql-repl db> insert Person {name := 'Jonathan Harker'}; {default::Person {id: b4d4e7e6-bd19-11ed-8363-1737d8d4c3c3}} db> select Log {action, timestamp, target_name, change}; { default::Log { action: 'insert', timestamp: '2023-03-07T18:56:02.403817Z', target_name: 'Jonathan Harker', change: {} } } db> update Person filter .name = 'Jonathan Harker' ... 
set {name := 'Mina Murray'}; {default::Person {id: b4d4e7e6-bd19-11ed-8363-1737d8d4c3c3}} db> select Log {action, timestamp, target_name, change}; { default::Log { action: 'insert', timestamp: '2023-03-07T18:56:02.403817Z', target_name: 'Jonathan Harker', change: {} }, default::Log { action: 'update', timestamp: '2023-03-07T18:56:39.520889Z', target_name: 'Mina Murray', change: 'Jonathan Harker->Mina Murray' }, } db> delete Person filter .name = 'Mina Murray'; {default::Person {id: b4d4e7e6-bd19-11ed-8363-1737d8d4c3c3}} db> select Log {action, timestamp, target_name, change}; { default::Log { action: 'insert', timestamp: '2023-03-07T18:56:02.403817Z', target_name: 'Jonathan Harker', change: {} }, default::Log { action: 'update', timestamp: '2023-03-07T18:56:39.520889Z', target_name: 'Mina Murray', change: 'Jonathan Harker->Mina Murray' }, default::Log { action: 'delete', timestamp: '2023-03-07T19:00:52.636084Z', target_name: 'Mina Murray', change: {} }, } Our audit logging works, but the update logs have a major shortcoming: they log an update even when nothing changes. Any time an ``update`` query runs, we get a log, even if the values are the same. We can prevent that by using the trigger's ``when`` clause to run the trigger conditionally. Here's a rework of our ``update`` logging trigger: .. code-block:: sdl-invalid trigger log_update after update for each when (__old__.name != __new__.name) do ( insert Log { action := 'update', target_name := __new__.name, change := __old__.name ++ '->' ++ __new__.name } ); If this object were more complicated and we had many properties to compare, we could use a ``json`` cast to compare them all in one shot: .. code-block:: sdl-invalid trigger log_update after update for each when (<json>__old__ {**} != <json>__new__ {**}) do ( insert Log { action := 'update', target_name := __new__.name, change := __old__.name ++ '->' ++ __new__.name } ); You might find that one log entry per row is too granular or too noisy for your use case.
In that case, a ``for all`` trigger may be a better fit. Here's a schema that changes the ``Log`` type so that each object can log multiple writes by making ``target_name`` and ``change`` :ref:`multi properties ` and switches to ``for all`` triggers: .. code-block:: sdl-diff type Log { action: str; timestamp: datetime { default := datetime_current(); } - target_name: str; - change: str; + multi target_name: str; + multi change: str; } type Person { required name: str; - trigger log_insert after insert for each do ( + trigger log_insert after insert for all do ( insert Log { action := 'insert', target_name := __new__.name } ); - trigger log_update after update for each do ( + trigger log_update after update for all do ( insert Log { action := 'update', target_name := __new__.name, change := __old__.name ++ '->' ++ __new__.name } ); - trigger log_delete after delete for each do ( + trigger log_delete after delete for all do ( insert Log { action := 'delete', target_name := __old__.name } ); } Under this new schema, each query matching the trigger gets a single ``Log`` object instead of one ``Log`` object per row: .. code-block:: edgeql-repl db> for name in {'Jonathan Harker', 'Mina Murray', 'Dracula'} ... union ( ... insert Person {name := name} ... ); { default::Person {id: 3836f9c8-d393-11ed-9638-3793d3a39133}, default::Person {id: 38370a8a-d393-11ed-9638-d3e9b92ca408}, default::Person {id: 38370abc-d393-11ed-9638-5390f3cbd375}, } db> select Log {action, timestamp, target_name, change}; { default::Log { action: 'insert', timestamp: '2023-03-07T19:12:21.113521Z', target_name: {'Jonathan Harker', 'Mina Murray', 'Dracula'}, change: {}, }, } db> for change in { ... (old_name := 'Jonathan Harker', new_name := 'Jonathan'), ... (old_name := 'Mina Murray', new_name := 'Mina') ... } ... union ( ... update Person filter .name = change.old_name set { ... name := change.new_name ... } ... 
); { default::Person {id: 3836f9c8-d393-11ed-9638-3793d3a39133}, default::Person {id: 38370a8a-d393-11ed-9638-d3e9b92ca408}, } db> select Log {action, timestamp, target_name, change}; { default::Log { action: 'insert', timestamp: '2023-04-05T09:21:17.514089Z', target_name: {'Jonathan Harker', 'Mina Murray', 'Dracula'}, change: {}, }, default::Log { action: 'update', timestamp: '2023-04-05T09:35:30.389571Z', target_name: {'Jonathan', 'Mina'}, change: {'Jonathan Harker->Jonathan', 'Mina Murray->Mina'}, }, } Example: validation =================== .. index:: trigger, validate, assert Triggers may also be used for validation by calling :eql:func:`assert` inside the trigger. In this example, the ``Person`` type has two multi links to other ``Person`` objects named ``friends`` and ``enemies``. These two links should be mutually exclusive, so we have written a trigger to make sure there are no common objects linked in both. .. code-block:: sdl type Person { required name: str; multi friends: Person; multi enemies: Person; trigger prohibit_frenemies after insert, update for each do ( assert( not exists (__new__.friends intersect __new__.enemies), message := "Invalid frenemies", ) ) } With this trigger in place, it is impossible to link the same ``Person`` as both a friend and an enemy of any other person. .. code-block:: edgeql-repl db> insert Person {name := 'Quincey Morris'}; {default::Person {id: e4a55480-d2de-11ed-93bd-9f4224fc73af}} db> insert Person {name := 'Dracula'}; {default::Person {id: e7f2cff0-d2de-11ed-93bd-279780478afb}} db> update Person ... filter .name = 'Quincey Morris' ... set { ... enemies := ( ... select detached Person filter .name = 'Dracula' ... ) ... }; {default::Person {id: e4a55480-d2de-11ed-93bd-9f4224fc73af}} db> update Person ... filter .name = 'Quincey Morris' ... set { ... friends := ( ... select detached Person filter .name = 'Dracula' ... ) ... 
}; gel error: GelError: Invalid frenemies Example: logging ================ Declare a trigger that inserts a ``Log`` object for each new ``User`` object: .. code-block:: sdl type User { required name: str; trigger log_insert after insert for each do ( insert Log { action := 'insert', target_name := __new__.name } ); } Declare a trigger that inserts a ``Log`` object conditionally when an update query makes a change to a ``User`` object: .. code-block:: sdl type User { required name: str; trigger log_update after update for each when (<json>__old__ {**} != <json>__new__ {**}) do ( insert Log { action := 'update', target_name := __new__.name, change := __old__.name ++ '->' ++ __new__.name } ); } .. _ref_eql_sdl_triggers_syntax: Declaring triggers ================== .. api-index:: trigger, after insert, after update, after delete, for each, for all, when, do, __new__, __old__ This section describes the syntax to declare a trigger in your schema. Syntax ------ .. sdl:synopsis:: type "{" trigger after {insert | update | delete} [, ...] for {each | all} [ when () ] do "}" Description ----------- This declaration defines a new trigger with the following options: :eql:synopsis:`` The name (optionally module-qualified) of the type to be triggered on. :eql:synopsis:`` The name of the trigger. :eql:synopsis:`insert | update | delete [, ...]` The query type (or types) to trigger on. Separate multiple values with commas to invoke the same trigger for multiple types of queries. :eql:synopsis:`each` The expression will be evaluated once per modified object. Within the expression, ``__new__`` and ``__old__`` each refer to a single object. :eql:synopsis:`all` The expression will be evaluated once for the entire query, even if multiple objects were modified. Within the expression, ``__new__`` and ``__old__`` refer to sets of the modified objects. .. versionadded:: 4.0 :eql:synopsis:`when ()` Optionally provide a condition for the trigger.
If the condition is met, the trigger will run. If not, the trigger is skipped. :eql:synopsis:`` The expression to be evaluated when the trigger is invoked. The trigger name must be distinct from that of any existing trigger on the same type. .. _ref_eql_ddl_triggers: DDL commands ============ This section describes the low-level DDL commands for creating and dropping triggers. You typically don't need to use these commands directly, but knowing about them is useful for reviewing migrations. Create trigger -------------- :eql-statement: :ref:`Define ` a new trigger. .. eql:synopsis:: {create | alter} type "{" create trigger after {insert | update | delete} [, ...] for {each | all} [ when () ] do "}" Description ^^^^^^^^^^^ The command ``create trigger`` nested under ``create type`` or ``alter type`` defines a new trigger for a given object type. The trigger name must be distinct from that of any existing trigger on the same type. Parameters ^^^^^^^^^^ The options of this command are identical to the :ref:`SDL trigger declaration `. Example ^^^^^^^ Declare a trigger that inserts a ``Log`` object for each new ``User`` object: .. code-block:: edgeql alter type User { create trigger log_insert after insert for each do ( insert Log { action := 'insert', target_name := __new__.name } ); }; .. versionadded:: 4.0 Declare a trigger that inserts a ``Log`` object conditionally when an update query makes a change to a ``User`` object: .. code-block:: edgeql alter type User { create trigger log_update after update for each when (<json>__old__ {**} != <json>__new__ {**}) do ( insert Log { action := 'update', target_name := __new__.name, change := __old__.name ++ '->' ++ __new__.name } ); }; Drop trigger ------------ :eql-statement: Remove a trigger. .. eql:synopsis:: alter type "{" drop trigger ; "}" Description ^^^^^^^^^^^ The command ``drop trigger`` inside an ``alter type`` block removes the definition of an existing trigger on the specified type.
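
This section lists only ``create trigger`` and ``drop trigger``, so one way to change a trigger's definition is to drop it and create it again. The following is a minimal sketch, assuming a ``User`` type with a ``log_insert`` trigger and a ``Log`` type already exist. The ``'created'`` action value is purely illustrative. The two statements are kept separate here; whether they can be combined into a single ``alter type`` block is not covered by this reference.

.. code-block:: edgeql

    # Remove the existing definition.
    alter type User {
      drop trigger log_insert;
    };

    # Re-create the trigger under the same name with a new expression.
    alter type User {
      create trigger log_insert after insert for each do (
        insert Log {
          action := 'created',
          target_name := __new__.name
        }
      );
    };
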
Parameters ^^^^^^^^^^ :eql:synopsis:`` The name (optionally module-qualified) of the type being triggered on. :eql:synopsis:`` The name of the trigger. Example ^^^^^^^ Remove the ``log_insert`` trigger on the ``User`` type: .. code-block:: edgeql alter type User { drop trigger log_insert; }; .. list-table:: :class: seealso * - **See also** * - :ref:`Introspection > Triggers ` ================================================ FILE: docs/reference/edgeql/analyze.rst ================================================ .. _ref_eql_analyze: Analyze ======= .. index:: explain, performance, postgres query planner .. api-index:: analyze Prefix an EdgeQL query with ``analyze`` to run a performance analysis of that query. .. code-block:: edgeql-repl db> analyze select Hero { ... name, ... secret_identity, ... villains: { ... name, ... nemesis: { ... name ... } ... } ... }; ──────────────────────────────────────── Query ──────────────────────────────────────── analyze select ➊ Hero {name, secret_identity, ➋ villains: {name, ➌ nemesis: {name}}}; ──────────────────────── Coarse-grained Query Plan ──────────────────────── │ Time Cost Loops Rows Width │ Relations ➊ root │ 0.0 69709.48 1.0 0.0 32 │ Hero ╰──➋ .villains │ 0.0 92.9 0.0 0.0 32 │ Villain, Hero.villains ╰──➌ .nemesis │ 0.0 8.18 0.0 0.0 32 │ Hero .. note:: In addition to using the ``analyze`` statement in the CLI or UI's REPL, you may also run performance analysis via our CLI's :ref:`analyze command ` and the UI's query builder (accessible by running :ref:`ref_cli_gel_ui` to invoke your instance's UI) by prepending your query with ``analyze``. This method offers helpful visualizations that make it easy to understand your query's performance. After analyzing a query, you may run the ``\expand`` command in the REPL to see more fine-grained performance metrics on the previously analyzed query. ..
list-table:: :class: seealso * - **See also** * - :ref:`CLI > gel analyze ` * - :ref:`Reference > EdgeQL > analyze ` ================================================ FILE: docs/reference/edgeql/delete.rst ================================================ .. _ref_eql_delete: Delete ====== .. api-index:: delete The ``delete`` command is used to delete objects from the database. .. code-block:: edgeql delete Hero filter .name = 'Iron Man'; Clauses ------- Deletion statements support ``filter``, ``order by``, ``offset``, and ``limit`` clauses. See :ref:`EdgeQL > Select ` for full documentation on these clauses. .. code-block:: edgeql delete Hero filter .name ilike 'the %' order by .name offset 10 limit 5; Link deletion ------------- .. api-index:: ConstraintViolationError Every link is associated with a *link deletion policy*. By default, it isn't possible to delete an object linked to by another. .. code-block:: edgeql-repl db> delete Hero filter .name = "Yelena Belova"; ConstraintViolationError: deletion of default::Hero (af7076e0-3e98-11ec-abb3-b3435bbe7c7e) is prohibited by link target policy {} This deletion failed because Yelena is still in the ``characters`` list of the Black Widow movie. We must destroy this link before Yelena can be deleted. .. code-block:: edgeql-repl db> update Movie ... filter .title = "Black Widow" ... set { ... characters -= (select Hero filter .name = "Yelena Belova") ... }; {default::Movie {id: af706c7c-3e98-11ec-abb3-4bbf3f18a61a}} db> delete Hero filter .name = "Yelena Belova"; {default::Hero {id: af7076e0-3e98-11ec-abb3-b3435bbe7c7e}} To avoid this behavior, we could update the ``Movie.characters`` link to use the ``allow`` deletion policy. .. code-block:: sdl-diff type Movie { required title: str { constraint exclusive }; required release_year: int64; - multi characters: Person; + multi characters: Person { + on target delete allow; + }; } Cascading deletes ^^^^^^^^^^^^^^^^^ .. index:: deletion policy .. 
api-index:: delete source, delete target If a link uses the ``delete source`` policy, then deleting a *target* of the link will also delete the object that links to it (the *source*). This behavior can be used to implement cascading deletes; be careful with this power! The full list of deletion policies is documented at :ref:`Schema > Links `. Return value ------------ .. index:: returning A ``delete`` statement returns the set of deleted objects. You can pass this set into ``select`` to fetch properties and links of the (now-deleted) objects. This is the last moment this data will be available before being permanently deleted. .. code-block:: edgeql-repl db> with movie := (delete Movie filter .title = "Untitled") ... select movie {id, title}; {default::Movie { id: b11303c6-40ac-11ec-a77d-d393cdedde83, title: 'Untitled', }} .. list-table:: :class: seealso * - **See also** * - :ref:`Reference > Commands > Delete ` * - :ref:`Cheatsheets > Deleting data ` ================================================ FILE: docs/reference/edgeql/for.rst ================================================ .. _ref_eql_for: For === .. api-index:: for in, union EdgeQL supports a top-level ``for`` statement. These "for loops" iterate over each element of some input set, execute some expression with it, and merge the results into a single output set. .. code-block:: edgeql-repl db> for number in {0, 1, 2, 3} ... union ( ... select { number, number + 0.5 } ... ); {0, 0.5, 1, 1.5, 2, 2.5, 3, 3.5} This statement iterates through each number in the set. Inside the loop, the ``number`` variable is bound to a singleton set. The inner expression is executed for every element of the input set, and the results of each execution are merged into a single output set. .. note:: The ``union`` keyword is required prior to |EdgeDB| 5.0 and is intended to indicate explicitly that the results of each loop execution are ultimately merged. .. 
versionadded:: 5.0 If the body of ``for`` is a statement (``select``, ``insert``, ``update``, ``delete``, ``group``, or ``with``), ``union`` and the parentheses surrounding the statement are no longer required: .. code-block:: edgeql-repl db> for number in {0, 1, 2, 3} ... select { number, number + 0.5 } {0, 0.5, 1, 1.5, 2, 2.5, 3, 3.5} Bulk inserts ------------ The ``for`` statement is commonly used for bulk inserts. .. code-block:: edgeql-repl db> for hero_name in {'Sersi', 'Ikaris', 'Thena'} ... union ( ... insert Hero { name := hero_name } ... ); { default::Hero {id: d7d7e0f6-40ae-11ec-87b1-3f06bed494b9}, default::Hero {id: d7d7f870-40ae-11ec-87b1-f712a4efc3a5}, default::Hero {id: d7d7f8c0-40ae-11ec-87b1-6b8685d56610} } This statement iterates through each name in the set of names. Inside the loop, ``hero_name`` is bound to a ``str`` singleton, so it can be assigned to ``Hero.name``. Instead of literal sets, it's common to use a :ref:`json ` parameter for bulk inserts. This value is then "unpacked" into a set of ``json`` elements and used inside the ``for`` loop: .. code-block:: edgeql-repl db> with ... raw_data := <json>$data, ... for item in json_array_unpack(raw_data) union ( ... insert Hero { name := <str>item['name'] } ... ); Parameter $data: [{"name":"Sersi"},{"name":"Ikaris"},{"name":"Thena"}] { default::Hero {id: d7d7e0f6-40ae-11ec-87b1-3f06bed494b9}, default::Hero {id: d7d7f870-40ae-11ec-87b1-f712a4efc3a5}, default::Hero {id: d7d7f8c0-40ae-11ec-87b1-6b8685d56610} } A similar approach can be used for bulk updates. .. _ref_eql_for_conditional_dml: Conditional DML --------------- .. api-index:: for, if else, unless conflict .. versionadded:: 4.0 DML is now supported in ``if..else``. Prior to version 4.0, DML (i.e., :ref:`insert `, :ref:`update `, :ref:`delete `) is not supported in :eql:op:`if..else`. If you need to do one of these conditionally on an older version, you can use a ``for`` loop as a workaround. For example, you might want to write this conditional: ..
code-block:: # 🚫 Does not work with admin := (select User filter .role = 'admin') select admin if exists admin else (insert User {role := 'admin'}); Because of the lack of support for DML in a conditional, this query will fail. Here's how you can accomplish the same thing using the workaround: .. code-block:: edgeql # ✅ Works! with admin := (select User filter .role = 'admin'), new := (for _ in (select () filter not exists admin) union ( insert User {role := 'admin'} )), select {admin, new}; The ``admin`` alias represents the condition we want to test for. In this case, "do we have a ``User`` object with a value of ``admin`` for the ``role`` property?" In the ``new`` alias, we write a ``for`` loop with a ``select`` query that will produce a set with a single value if that object we queried for does *not* exist. (You can use ``exists`` instead of ``not exists`` in the nested ``select`` inside the ``for`` loop if you don't want to invert the condition.) A set with a single value results in a single iteration of the ``for`` loop. Inside that loop, we run our conditional DML — in this case to insert an admin user. Then we ``select`` both aliases to execute both of their queries. The query will return the ``User`` object. This in effect gives us a query that will insert a ``User`` object with a ``role`` of ``admin`` if none exists or return that object if it *does* exist. .. note:: If you're trying to conditionally run DML in response to a violation of an exclusivity constraint, you don't need this workaround. You should use :ref:`unless conflict ` instead. .. list-table:: :class: seealso * - **See also** * - :ref:`Reference > Commands > For ` ================================================ FILE: docs/reference/edgeql/group.rst ================================================ .. _ref_eql_group: Group ===== .. index:: analytics, aggregate .. api-index:: group by, group using by, key, grouping, elements, rollup, cube EdgeQL supports a top-level ``group`` statement. 
This is used to partition sets into subsets based on some parameters. These subsets can then be additionally aggregated to provide some analytics. The most basic format is just using the bare :eql:stmt:`group` to group a set of objects by some property: .. code-block:: edgeql-repl db> group Movie by .release_year; { { key: {release_year: 2016}, grouping: {'release_year'}, elements: { default::Movie {title: 'Captain America: Civil War'}, default::Movie {title: 'Doctor Strange'}, }, }, { key: {release_year: 2017}, grouping: {'release_year'}, elements: { default::Movie {title: 'Spider-Man: Homecoming'}, default::Movie {title: 'Thor: Ragnarok'}, }, }, { key: {release_year: 2018}, grouping: {'release_year'}, elements: {default::Movie {title: 'Ant-Man and the Wasp'}}, }, { key: {release_year: 2019}, grouping: {'release_year'}, elements: {default::Movie {title: 'Spider-Man: No Way Home'}}, }, { key: {release_year: 2021}, grouping: {'release_year'}, elements: {default::Movie {title: 'Black Widow'}}, }, ... } Notice that the result of ``group`` is a set of :ref:`free objects ` with three fields: * ``key``: another free object containing the specific value of the grouping parameter for a given subset. * ``grouping``: set of names of grouping parameters, i.e. the specific names that also appear in the ``key`` free object. * ``elements``: the actual subset of values that match the ``key``. In the ``group`` statement, referring to the property in the ``by`` clause **must** be done by using the leading dot shorthand ``.release_year``. The property name then shows up in ``grouping`` and ``key`` to indicate the defining characteristics of the particular result. Alternatively, we can give it an alias in an optional ``using`` clause and then that alias can be used in the ``by`` clause and will appear in the results: .. code-block:: edgeql-repl db> group Movie {title} ... 
using year := .release_year by year; { { key: {year: 2016}, grouping: {'year'}, elements: { default::Movie {title: 'Captain America: Civil War'}, default::Movie {title: 'Doctor Strange'}, }, }, { key: {year: 2017}, grouping: {'year'}, elements: { default::Movie {title: 'Spider-Man: Homecoming'}, default::Movie {title: 'Thor: Ragnarok'}, }, }, { key: {year: 2018}, grouping: {'year'}, elements: {default::Movie {title: 'Ant-Man and the Wasp'}}, }, { key: {year: 2019}, grouping: {'year'}, elements: {default::Movie {title: 'Spider-Man: No Way Home'}}, }, { key: {year: 2021}, grouping: {'year'}, elements: {default::Movie {title: 'Black Widow'}}, }, ... } The ``using`` clause is perfect for defining a more complex expression to group things by. For example, instead of grouping by the ``release_year`` we can group by the release decade: .. code-block:: edgeql-repl db> group Movie {title} ... using decade := .release_year // 10 ... by decade; { { { key: {decade: 200}, grouping: {'decade'}, elements: { default::Movie {title: 'Spider-Man'}, default::Movie {title: 'Spider-Man 2'}, default::Movie {title: 'Spider-Man 3'}, default::Movie {title: 'Iron Man'}, default::Movie {title: 'The Incredible Hulk'}, }, }, { key: {decade: 201}, grouping: {'decade'}, elements: { default::Movie {title: 'Iron Man 2'}, default::Movie {title: 'Thor'}, default::Movie {title: 'Captain America: The First Avenger'}, default::Movie {title: 'The Avengers'}, default::Movie {title: 'Iron Man 3'}, default::Movie {title: 'Thor: The Dark World'}, default::Movie {title: 'Captain America: The Winter Soldier'}, default::Movie {title: 'Ant-Man'}, default::Movie {title: 'Captain America: Civil War'}, default::Movie {title: 'Doctor Strange'}, default::Movie {title: 'Spider-Man: Homecoming'}, default::Movie {title: 'Thor: Ragnarok'}, default::Movie {title: 'Ant-Man and the Wasp'}, default::Movie {title: 'Spider-Man: No Way Home'}, }, }, { key: {decade: 202}, grouping: {'decade'}, elements: {default::Movie {title: 
'Black Widow'}}, }, } It's also possible to group by more than one parameter, so we can group by whether the movie ``title`` contains a colon *and* the decade in which it was released. Additionally, let's only consider more recent movies, say, released after 2015, so that we're not overwhelmed by all the combinations of results: .. code-block:: edgeql-repl db> with ... # Apply the group query only to more recent movies ... M := (select Movie filter .release_year > 2015) ... group M {title} ... using ... decade := .release_year // 10, ... has_colon := .title like '%:%' ... by decade, has_colon; { { key: {decade: 201, has_colon: false}, grouping: {'decade', 'has_colon'}, elements: { default::Movie {title: 'Ant-Man and the Wasp'}, default::Movie {title: 'Doctor Strange'}, }, }, { key: {decade: 201, has_colon: true}, grouping: {'decade', 'has_colon'}, elements: { default::Movie {title: 'Captain America: Civil War'}, default::Movie {title: 'Spider-Man: No Way Home'}, default::Movie {title: 'Thor: Ragnarok'}, default::Movie {title: 'Spider-Man: Homecoming'}, }, }, { key: {decade: 202, has_colon: false}, grouping: {'decade', 'has_colon'}, elements: {default::Movie {title: 'Black Widow'}}, }, } Once we break a set into partitions, we can also use :ref:`aggregate ` functions to provide some analytics about the data. For example, for the above partitioning (by decade and presence of ``:`` in the ``title``) we can calculate how many movies are in each subset as well as the average number of words in the movie titles: .. code-block:: edgeql-repl db> with ... # Apply the group query only to more recent movies ... M := (select Movie filter .release_year > 2015), ... groups := ( ... group M {title} ... using ... decade := .release_year // 10 - 200, ... has_colon := .title like '%:%' ... by decade, has_colon ... ) ... select groups { ... key := .key {decade, has_colon}, ... count := count(.elements), ... avg_words := math::mean( ... len(str_split(.elements.title, ' '))) ... 
}; { {key: {decade: 1, has_colon: false}, count: 2, avg_words: 3}, {key: {decade: 1, has_colon: true}, count: 4, avg_words: 3}, {key: {decade: 2, has_colon: false}, count: 1, avg_words: 2}, } .. note:: It is possible to produce results that are grouped in multiple different ways using :ref:`grouping sets `. This may be useful in more sophisticated analytics. .. list-table:: :class: seealso * - **See also** * - :ref:`Reference > Commands > Group ` ================================================ FILE: docs/reference/edgeql/index.rst ================================================ .. versioned-section:: .. _ref_edgeql: ====== EdgeQL ====== .. toctree:: :maxdepth: 3 :hidden: literals sets paths types parameters select insert update delete for group with analyze path_resolution transactions EdgeQL is a next-generation query language designed to match SQL in power and surpass it in terms of clarity, brevity, and intuitiveness. It's used to query the database, insert/update/delete data, modify/introspect the schema, manage transactions, and more. Design goals ------------ EdgeQL is a spiritual successor to SQL designed with a few core principles in mind. **Compatible with modern languages**. A jaw-dropping amount of effort has been spent attempting to `bridge the gap `_ between the *relational* paradigm of SQL and the *object-oriented* nature of modern programming languages. Gel sidesteps this problem by modeling data in an *object-relational* way. **Strongly typed**. EdgeQL is *inextricably tied* to Gel's rigorous object-oriented type system. The type of all expressions is statically inferred by Gel. **Designed for programmers**. EdgeQL prioritizes syntax over keywords; it uses ``{ curly braces }`` to define scopes/structures and the *assignment operator* ``:=`` to set values. The result is a query language that looks more like code and less like word soup. .. All told, EdgeQL syntax contains roughly 180 .. 
reserved keywords; by comparison Postgres-flavored SQL contains `469 .. `_. .. **Compiles to SQL**. All EdgeQL queries, no matter how complex, compile to a .. single PostgreSQL query under the hood. With the exception of ``group by``, .. EdgeQL is equivalent to SQL in terms of power and expressivity. **Easy deep querying**. Gel's object-relational nature makes it painless to write deep, performant queries that traverse links, no ``JOINs`` required. **Composable**. `Unlike SQL `_, EdgeQL's syntax is readily composable; queries can be cleanly nested without worrying about Cartesian explosion. ================================================ FILE: docs/reference/edgeql/insert.rst ================================================ .. _ref_eql_insert: Insert ====== .. api-index:: insert, := The ``insert`` command is used to create instances of object types. The code samples on this page assume the following schema: .. code-block:: sdl module default { abstract type Person { required name: str { constraint exclusive }; } type Hero extending Person { secret_identity: str; multi villains := . insert Hero { ... name := "Spider-Man", ... secret_identity := "Peter Parker" ... }; {default::Hero {id: b0fbe9de-3e90-11ec-8c12-ffa2d5f0176a}} Similar to :ref:`selecting fields ` in ``select``, ``insert`` statements include a *shape* specified with ``curly braces``; the values of properties/links are assigned with the ``:=`` operator. Optional links or properties can be omitted entirely, as well as those with a ``default`` value (like ``id``). .. code-block:: edgeql-repl db> insert Hero { ... name := "Spider-Man" ... # secret_identity is omitted ... }; {default::Hero {id: b0fbe9de-3e90-11ec-8c12-ffa2d5f0176a}} You can only ``insert`` instances of concrete (non-abstract) object types. .. code-block:: edgeql-repl db> insert Person { ... name := "The Man With No Name" ... 
}; error: QueryError: cannot insert into abstract object type 'default::Person' By default, ``insert`` returns only the inserted object's ``id`` as seen in the examples above. If you want to get additional data back, you may wrap your ``insert`` with a ``select`` and apply a shape specifying any properties and links you want returned: .. code-block:: edgeql-repl db> select (insert Hero { ... name := "Spider-Man" ... # secret_identity is omitted ... }) {id, name}; { default::Hero { id: b0fbe9de-3e90-11ec-8c12-ffa2d5f0176a, name: "Spider-Man" } } You can use :ref:`ref_eql_with` to tidy this up if you prefer: .. code-block:: edgeql-repl db> with NewHero := (insert Hero { ... name := "Spider-Man" ... # secret_identity is omitted ... }) ... select NewHero { ... id, ... name, ... } { default::Hero { id: b0fbe9de-3e90-11ec-8c12-ffa2d5f0176a, name: "Spider-Man" } } .. _ref_eql_insert_links: Inserting links --------------- EdgeQL's composable syntax makes link insertion painless. Below, we insert "Spider-Man: No Way Home" and include all known heroes and villains as ``characters`` (which is basically true). .. code-block:: edgeql-repl db> insert Movie { ... title := "Spider-Man: No Way Home", ... release_year := 2021, ... characters := ( ... select Person ... filter .name in { ... 'Spider-Man', ... 'Doctor Strange', ... 'Doc Ock', ... 'Green Goblin' ... } ... ) ... }; {default::Movie {id: 9b1cf9e6-3e95-11ec-95a2-138eeb32759c}} To assign to the ``Movie.characters`` link, we're using a *subquery*. This subquery is executed and resolves to a set of type ``Person``, which is assignable to ``characters``. Note that the inner ``select Person`` statement is wrapped in parentheses; this is required for all subqueries in EdgeQL. Now let's assign to a *single link*. .. code-block:: edgeql-repl db> insert Villain { ... name := "Doc Ock", ... nemesis := (select Hero filter .name = "Spider-Man") ... 
}; This query is valid because the inner subquery is guaranteed to return at most one ``Hero`` object, due to the uniqueness constraint on ``Hero.name``. If you are filtering on a non-exclusive property, use ``assert_single`` to guarantee that the subquery will return zero or one results. If more than one result is returned, this query will fail at runtime. .. code-block:: edgeql-repl db> insert Villain { ... name := "Doc Ock", ... nemesis := assert_single(( ... select Hero ... filter .secret_identity = "Peter B. Parker" ... )) ... }; .. _ref_eql_insert_nested: Nested inserts -------------- Just as we used subqueries to populate links with existing objects, we can also execute *nested inserts*. .. code-block:: edgeql-repl db> insert Villain { ... name := "The Mandarin", ... nemesis := (insert Hero { ... name := "Shang-Chi", ... secret_identity := "Shaun" ... }) ... }; {default::Villain {id: d47888a0-3e7b-11ec-af13-fb68c8777851}} Now let's write a nested insert for a ``multi`` link. .. code-block:: edgeql-repl db> insert Movie { ... title := "Black Widow", ... release_year := 2021, ... characters := { ... (select Hero filter .name = "Black Widow"), ... (insert Hero { name := "Yelena Belova"}), ... (insert Villain { ... name := "Dreykov", ... nemesis := (select Hero filter .name = "Black Widow") ... }) ... } ... }; {default::Movie {id: af706c7c-3e98-11ec-abb3-4bbf3f18a61a}} We are using :ref:`set literal syntax ` to construct a set literal containing several ``select`` and ``insert`` subqueries. This set contains a mix of ``Hero`` and ``Villain`` objects; since these are both subtypes of ``Person`` (the expected type of ``Movie.characters``), this is valid. You also can't *assign* to a computed property or link; these fields don't actually exist in the database. .. code-block:: edgeql-repl db> insert Hero { ... name := "Ant-Man", ... villains := (select Villain) ... 
}; error: QueryError: modification of computed link 'villains' of object type 'default::Hero' is prohibited .. _ref_eql_insert_with: With block ---------- .. api-index:: with In the previous query, we selected Black Widow twice: once in the ``characters`` set and again as the ``nemesis`` of Dreykov. In circumstances like this, pulling a subquery into a ``with`` block lets you avoid duplication. .. code-block:: edgeql-repl db> with black_widow := (select Hero filter .name = "Black Widow") ... insert Movie { ... title := "Black Widow", ... release_year := 2021, ... characters := { ... black_widow, ... (insert Hero { name := "Yelena Belova"}), ... (insert Villain { ... name := "Dreykov", ... nemesis := black_widow ... }) ... } ... }; {default::Movie {id: af706c7c-3e98-11ec-abb3-4bbf3f18a61a}} The ``with`` block can contain an arbitrary number of clauses; later clauses can reference earlier ones. .. code-block:: edgeql-repl db> with ... black_widow := (select Hero filter .name = "Black Widow"), ... yelena := (insert Hero { name := "Yelena Belova"}), ... dreykov := (insert Villain {name := "Dreykov", nemesis := black_widow}) ... insert Movie { ... title := "Black Widow", ... release_year := 2021, ... characters := { black_widow, yelena, dreykov } ... }; {default::Movie {id: af706c7c-3e98-11ec-abb3-4bbf3f18a61a}} .. _ref_eql_insert_conflicts: Conflicts --------- .. api-index:: unless conflict on, else |Gel| provides a general-purpose mechanism for gracefully handling possible exclusivity constraint violations. Consider a scenario where we are trying to ``insert`` Eternals (the ``Movie``), but we can't remember if it already exists in the database. .. code-block:: edgeql-repl db> insert Movie { ... title := "Eternals", ... release_year := 2021 ... } ... unless conflict on .title ... else (select Movie); {default::Movie {id: af706c7c-3e98-11ec-abb3-4bbf3f18a61a}} This query attempts to ``insert`` Eternals. 
If it already exists in the database, it will violate the uniqueness constraint on ``Movie.title``, causing a *conflict* on the ``title`` field. The ``else`` clause is then executed and returned instead. In essence, ``unless conflict`` lets us "catch" exclusivity conflicts and provide a fallback expression. .. note:: Note that the ``else`` clause is simply ``select Movie``. There's no need to apply additional filters on ``Movie``; in the context of the ``else`` clause, ``Movie`` is bound to the conflicting object. .. note:: Using ``unless conflict`` on :ref:`multi properties ` is only supported in 2.10 and later. .. _ref_eql_upsert: Upserts ^^^^^^^ There are no limitations on what the ``else`` clause can contain; it can be any EdgeQL expression, including an :ref:`update ` statement. This lets you express *upsert* logic in a single EdgeQL query. .. code-block:: edgeql-repl db> with ... title := "Eternals", ... release_year := 2021 ... insert Movie { ... title := title, ... release_year := release_year ... } ... unless conflict on .title ... else ( ... update Movie set { release_year := release_year } ... ); {default::Movie {id: f1bf5ac0-3e9d-11ec-b78d-c7dfb363362c}} When a conflict occurs during the initial ``insert``, the statement falls back to the ``update`` statement in the ``else`` clause. This updates the ``release_year`` of the conflicting object. .. note:: It can be useful to know the outcome of an upsert. Here's an example showing how you can return that: .. code-block:: edgeql-repl db> with ... title := "Eternals", ... release_year := 2021, ... movie := ( ... insert Movie { ... title := title, ... release_year := release_year ... } ... unless conflict on .title ... else ( ... update Movie set { release_year := release_year } ... ) ... ) ... select movie { ... is_new := (movie not in Movie) ... }; {default::Movie {is_new: true}} This technique exploits the fact that a ``select`` will not return an object inserted in the same query. 
We know that, if the record exists, we updated it. If it does not, we inserted it. By wrapping your upsert in a ``select`` and putting a shape on it that queries for the object and returns whether or not it exists (as ``is_new``, in this example), you can easily see whether the object was inserted or updated. If you want to also return some of the ``Movie`` object's data, drop additional property names into the shape alongside ``is_new``. If you're on 3.0+, you can add ``Movie.*`` to the shape alongside ``is_new`` to get back all of the ``Movie`` object's properties. You could even silo the data off, keeping it separate from the ``is_new`` computed value like this: .. code-block:: edgeql-repl db> with ... title := "Eternals", ... release_year := 2021, ... movie := ( ... insert Movie { ... title := title, ... release_year := release_year ... } ... unless conflict on .title ... else ( ... update Movie set { release_year := release_year } ... ) ... ) ... select { ... data := (select movie {*}), ... is_new := (movie not in Movie) ... }; { { data: { default::Movie { id: 6880d0ba-62ca-11ee-9608-635818746433, release_year: 2021, title: 'Eternals' } }, is_new: false } } Suppressing failures ^^^^^^^^^^^^^^^^^^^^ .. api-index:: unless conflict The ``else`` clause is optional; when omitted, the ``insert`` statement will return an *empty set* if a conflict occurs. This is a common way to prevent ``insert`` queries from failing on constraint violations. .. code-block:: edgeql-repl db> insert Hero { name := "The Wasp" } # initial insert ... unless conflict; {default::Hero {id: 35b97a92-3e9b-11ec-8e39-6b9695d671ba}} db> insert Hero { name := "The Wasp" } # The Wasp now exists ... unless conflict; {} .. _ref_eql_insert_bulk: Bulk inserts ------------ Bulk inserts are performed by passing in a JSON array as a :ref:`query parameter `, :eql:func:`unpacking ` it, and using a :ref:`for loop ` to insert the objects. .. code-block:: edgeql-repl db> with ... raw_data := $data, ... 
for item in json_array_unpack(raw_data) union ( ... insert Hero { name := item['name'] } ... ); Parameter $data: [{"name":"Sersi"},{"name":"Ikaris"},{"name":"Thena"}] { default::Hero {id: 35b97a92-3e9b-11ec-8e39-6b9695d671ba}, default::Hero {id: 35b97a92-3e9b-11ec-8e39-6b9695d671ba}, default::Hero {id: 35b97a92-3e9b-11ec-8e39-6b9695d671ba}, ... } .. list-table:: :class: seealso * - **See also** * - :ref:`Reference > Commands > Insert ` * - :ref:`Cheatsheets > Inserting data ` ================================================ FILE: docs/reference/edgeql/literals.rst ================================================ .. _ref_eql_literals: Literals ======== .. index:: primitive types EdgeQL is *inextricably tied* to Gel's rigorous type system. Below is an overview of how to declare a literal value of each *primitive type*. Click a link in the left column to jump to the associated section. .. list-table:: * - :ref:`String ` - ``str`` * - :ref:`Boolean ` - ``bool`` * - :ref:`Numbers ` - ``int16`` ``int32`` ``int64`` ``float32`` ``float64`` ``bigint`` ``decimal`` * - :ref:`UUID ` - ``uuid`` * - :ref:`Enums ` - ``enum`` * - :ref:`Dates and times ` - ``datetime`` ``duration`` ``cal::local_datetime`` ``cal::local_date`` ``cal::local_time`` ``cal::relative_duration`` * - :ref:`Durations ` - ``duration`` ``cal::relative_duration`` ``cal::date_duration`` * - :ref:`Ranges ` - ``range`` * - :ref:`Bytes ` - ``bytes`` * - :ref:`Arrays ` - ``array`` * - :ref:`Tuples ` - ``tuple`` or ``tuple`` * - :ref:`JSON ` - ``json`` .. _ref_eql_literal_strings: Strings ------- .. index:: unicode, quotes, raw strings, escape character .. api-index:: str, r'', r"", $$, $§label§$, \\§char§ The :eql:type:`str` type is a variable-length string of Unicode characters. A string can be declared with either single or double quotes. .. 
code-block:: edgeql-repl db> select 'I ❤️ EdgeQL'; {'I ❤️ EdgeQL'} db> select "hello there!"; {'hello there!'} db> select 'hello\nthere!'; {'hello there!'} db> select 'hello ... there!'; {'hello there!'} db> select r'hello ... there!'; # multiline {'hello there!'} There is a special syntax for declaring "raw strings". Raw strings treat the backslash ``\`` as a literal character instead of an escape character. .. code-block:: edgeql-repl db> select r'hello\nthere'; # raw string {r'hello\\nthere'} db> select $$one ... two ... three$$; # multiline raw string {'one two three'} db> select $label$You can add an interstitial label ... if you need to use "$$" in your string.$label$; { 'You can add an interstitial label if you need to use "$$" in your string.', } EdgeQL contains a set of built-in functions and operators for searching, comparing, and manipulating strings. .. code-block:: edgeql-repl db> select 'hellothere'[5:10]; {'there'} db> select 'hello' ++ 'there'; {'hellothere'} db> select len('hellothere'); {10} db> select str_trim(' hello there '); {'hello there'} db> select str_split('hello there', ' '); {['hello', 'there']} For a complete reference on strings, see :ref:`Standard Library > String ` or click an item below. .. 
list-table:: * - Indexing and slicing - :eql:op:`str[i] ` :eql:op:`str[from:to] ` * - Concatenation - :eql:op:`str ++ str ` * - Utilities - :eql:func:`len` * - Transformation functions - :eql:func:`str_split` :eql:func:`str_lower` :eql:func:`str_upper` :eql:func:`str_title` :eql:func:`str_pad_start` :eql:func:`str_pad_end` :eql:func:`str_trim` :eql:func:`str_trim_start` :eql:func:`str_trim_end` :eql:func:`str_repeat` * - Comparison operators - :eql:op:`= ` :eql:op:`\!= ` :eql:op:`?= ` :eql:op:`?!= ` :eql:op:`\< ` :eql:op:`\> ` :eql:op:`\<= ` :eql:op:`\>= ` * - Search - :eql:func:`contains` :eql:func:`find` * - Pattern matching and regexes - :eql:op:`str like pattern ` :eql:op:`str ilike pattern ` :eql:func:`re_match` :eql:func:`re_match_all` :eql:func:`re_replace` :eql:func:`re_test` .. _ref_eql_literal_boolean: Booleans -------- .. api-index:: bool The :eql:type:`bool` type represents a true/false value. .. code-block:: edgeql-repl db> select true; {true} db> select false; {false} |Gel| provides a set of operators that operate on boolean values. .. list-table:: * - Comparison operators - :eql:op:`= ` :eql:op:`\!= ` :eql:op:`?= ` :eql:op:`?!= ` :eql:op:`\< ` :eql:op:`\> ` :eql:op:`\<= ` :eql:op:`\>= ` * - Logical operators - :eql:op:`or` :eql:op:`and` :eql:op:`not` * - Aggregation - :eql:func:`all` :eql:func:`any` .. _ref_eql_literal_numbers: Numbers ------- There are several numerical types in Gel's type system. .. list-table:: * - :eql:type:`int16` - 16-bit integer * - :eql:type:`int32` - 32-bit integer * - :eql:type:`int64` - 64-bit integer * - :eql:type:`float32` - 32-bit floating point number * - :eql:type:`float64` - 64-bit floating point number * - :eql:type:`bigint` - Arbitrary precision integer. * - :eql:type:`decimal` - Arbitrary precision number. Number literals that *do not* contain a decimal are interpreted as ``int64``. Numbers containing decimals are interpreted as ``float64``. 
The ``n`` suffix designates a number with *arbitrary precision*: either ``bigint`` or ``decimal``. ====================================== ============================= Syntax Inferred type ====================================== ============================= :eql:code:`select 3;` :eql:type:`int64` :eql:code:`select 3.14;` :eql:type:`float64` :eql:code:`select 314e-2;` :eql:type:`float64` :eql:code:`select 42n;` :eql:type:`bigint` :eql:code:`select 42.0n;` :eql:type:`decimal` :eql:code:`select 42e+100n;` :eql:type:`decimal` ====================================== ============================= To declare an ``int16``, ``int32``, or ``float32``, you must provide an explicit type cast. For details on type casting, see :ref:`Casting `. ====================================== ============================= Syntax Type ====================================== ============================= :eql:code:`select <int16>1234;` :eql:type:`int16` :eql:code:`select <int32>123456;` :eql:type:`int32` :eql:code:`select <float32>123.456;` :eql:type:`float32` ====================================== ============================= EdgeQL includes a full set of arithmetic and comparison operators. Parentheses can be used to indicate the order of operations or visually group subexpressions; this is true across all EdgeQL queries. .. code-block:: edgeql-repl db> select 5 > 2; {true} db> select 2 + 2; {4} db> select 2 ^ 10; {1024} db> select (1 + 1) * 2 / (3 + 8); {0.36363636363636365} EdgeQL provides a comprehensive set of built-in functions and operators on numerical data. .. list-table:: * - Comparison operators - :eql:op:`= ` :eql:op:`\!= ` :eql:op:`?= ` :eql:op:`?!= ` :eql:op:`\< ` :eql:op:`\> ` :eql:op:`\<= ` :eql:op:`\>= ` * - Arithmetic - :eql:op:`+ ` :eql:op:`- ` :eql:op:`- ` :eql:op:`* ` :eql:op:`/
` :eql:op:`// ` :eql:op:`% ` :eql:op:`^ ` * - Statistics - :eql:func:`sum` :eql:func:`min` :eql:func:`max` :eql:func:`math::mean` :eql:func:`math::stddev` :eql:func:`math::stddev_pop` :eql:func:`math::var` :eql:func:`math::var_pop` * - Math - :eql:func:`round` :eql:func:`math::abs` :eql:func:`math::ceil` :eql:func:`math::floor` :eql:func:`math::ln` :eql:func:`math::lg` :eql:func:`math::log` * - Random number - :eql:func:`random` .. _ref_eql_literal_uuid: UUID ---- The :eql:type:`uuid` type is commonly used to represent object identifiers. UUID literals must be explicitly cast from a string value matching the UUID specification. .. code-block:: edgeql-repl db> select <uuid>'a5ea6360-75bd-4c20-b69c-8f317b0d2857'; {a5ea6360-75bd-4c20-b69c-8f317b0d2857} Generate a random UUID. .. code-block:: edgeql-repl db> select uuid_generate_v1mc(); {b4d94e6c-3845-11ec-b0f4-93e867a589e7} .. _ref_eql_literal_enum: Enums ----- .. api-index:: scalar type, extending enum Enum types must be :ref:`declared in your schema `. .. code-block:: sdl scalar type Color extending enum; Once the type is declared, an enum literal can be written with dot notation, or by casting an appropriate string literal: .. code-block:: edgeql-repl db> select Color.Red; {Red} db> select <Color>"Red"; {Red} .. _ref_eql_literal_dates: Dates and times --------------- |Gel's| type system contains several temporal types. .. list-table:: * - :eql:type:`datetime` - Timezone-aware point in time * - :eql:type:`cal::local_datetime` - Date and time without timezone * - :eql:type:`cal::local_date` - Date type * - :eql:type:`cal::local_time` - Time type All temporal literals are declared by casting an appropriately formatted string. .. 
code-block:: edgeql-repl db> select <datetime>'1999-03-31T15:17:00Z'; {<datetime>'1999-03-31T15:17:00Z'} db> select <datetime>'1999-03-31T17:17:00+02'; {<datetime>'1999-03-31T15:17:00Z'} db> select <cal::local_datetime>'1999-03-31T15:17:00'; {<cal::local_datetime>'1999-03-31T15:17:00'} db> select <cal::local_date>'1999-03-31'; {<cal::local_date>'1999-03-31'} db> select <cal::local_time>'15:17:00'; {<cal::local_time>'15:17:00'} EdgeQL supports a set of functions and operators on datetime types. .. list-table:: * - Comparison operators - :eql:op:`= ` :eql:op:`\!= ` :eql:op:`?= ` :eql:op:`?!= ` :eql:op:`\< ` :eql:op:`\> ` :eql:op:`\<= ` :eql:op:`\>= ` * - Arithmetic - :eql:op:`dt + dt ` :eql:op:`dt - dt ` * - String parsing - :eql:func:`to_datetime` :eql:func:`cal::to_local_datetime` :eql:func:`cal::to_local_date` :eql:func:`cal::to_local_time` * - Component extraction - :eql:func:`datetime_get` :eql:func:`cal::time_get` :eql:func:`cal::date_get` * - Truncation - :eql:func:`datetime_truncate` * - System timestamps - :eql:func:`datetime_current` :eql:func:`datetime_of_transaction` :eql:func:`datetime_of_statement` .. _ref_eql_literal_durations: Durations --------- |Gel's| type system contains three duration types. .. list-table:: * - :eql:type:`duration` - Exact duration * - :eql:type:`cal::relative_duration` - Duration in relative units * - :eql:type:`cal::date_duration` - Duration in months and days only Exact durations ^^^^^^^^^^^^^^^ The :eql:type:`duration` type represents *exact* durations that can be expressed as some fixed number of microseconds. It can be negative and it supports units of ``microseconds``, ``milliseconds``, ``seconds``, ``minutes``, and ``hours``. .. code-block:: edgeql-repl db> select <duration>'45.6 seconds'; {<duration>'0:00:45.6'} db> select <duration>'-15 microseconds'; {<duration>'-0:00:00.000015'} db> select <duration>'5 hours 4 minutes 3 seconds'; {<duration>'5:04:03'} db> select <duration>'8760 hours'; # about a year {<duration>'8760:00:00'} All temporal units beyond ``hour`` no longer correspond to a fixed duration of time; the length of a day/month/year/etc. changes based on daylight saving time, the month in question, leap years, etc.
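As a small sketch of this distinction (not taken from the examples above, and shown without output), an exact ``duration`` shifts a ``datetime`` by a fixed amount of elapsed time, while a unit larger than ``hours`` is expected to be rejected by the ``duration`` cast:

.. code-block:: edgeql

    # 24 hours is always the same amount of elapsed time.
    select <datetime>'2022-01-01T00:00:00Z' + <duration>'24 hours';

    # Expected to raise an error: "day" is not a fixed-length unit,
    # so it is not among the units that duration supports.
    select <duration>'1 day';

The second query is included only to illustrate the restriction; the exact error text depends on the server version.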
Relative durations ^^^^^^^^^^^^^^^^^^ By contrast, the :eql:type:`cal::relative_duration` type represents a "calendar" duration, like ``1 month``. Because months have different numbers of days, ``1 month`` doesn't correspond to a fixed number of milliseconds, but it's often a useful quantity to represent recurring events, postponements, etc. .. note:: The ``cal::relative_duration`` type supports the same units as ``duration``, plus ``days``, ``weeks``, ``months``, ``years``, ``decades``, ``centuries``, and ``millennia``. To declare relative duration literals: .. code-block:: edgeql-repl db> select <cal::relative_duration>'15 milliseconds'; {<cal::relative_duration>'PT.015S'} db> select <cal::relative_duration>'2 months 3 weeks 45 minutes'; {<cal::relative_duration>'P2M21DT45M'} db> select <cal::relative_duration>'-7 millennia'; {<cal::relative_duration>'P-7000Y'} Date durations ^^^^^^^^^^^^^^ The :eql:type:`cal::date_duration` type represents spans consisting of some number of *months* and *days*. This type is primarily intended to simplify logic involving :eql:type:`cal::local_date` values. .. code-block:: edgeql-repl db> select <cal::date_duration>'5 days'; {<cal::date_duration>'P5D'} db> select <cal::local_date>'2022-06-25' + <cal::date_duration>'5 days'; {<cal::local_date>'2022-06-30'} db> select <cal::local_date>'2022-06-30' - <cal::local_date>'2022-06-25'; {<cal::date_duration>'P5D'} EdgeQL supports a set of functions and operators on duration types. .. list-table:: * - Comparison operators - :eql:op:`= ` :eql:op:`\!= ` :eql:op:`?= ` :eql:op:`?!= ` :eql:op:`\< ` :eql:op:`\> ` :eql:op:`\<= ` :eql:op:`\>= ` * - Arithmetic - :eql:op:`dt + dt ` :eql:op:`dt - dt ` * - Duration string parsing - :eql:func:`to_duration` :eql:func:`cal::to_relative_duration` :eql:func:`cal::to_date_duration` * - Component extraction - :eql:func:`duration_get` * - Conversion - :eql:func:`duration_truncate` :eql:func:`cal::duration_normalize_hours` :eql:func:`cal::duration_normalize_days` .. _ref_eql_ranges: Ranges ------ .. api-index:: range, inc_lower, inc_upper, empty Ranges represent intervals of orderable scalar values. A range comprises a lower bound, an upper bound, and two boolean flags indicating whether each bound is inclusive.
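The four parts of a range can be inspected individually. The following is a sketch that assumes the standard library accessors ``range_get_lower``, ``range_get_upper``, ``range_is_inclusive_lower``, and ``range_is_inclusive_upper``, applied to a range built with the ``range`` constructor; output is omitted:

.. code-block:: edgeql

    # Lower and upper bounds of the range.
    select range_get_lower(range(1, 10));
    select range_get_upper(range(1, 10));

    # Boolean flags telling whether each bound is inclusive.
    select range_is_inclusive_lower(range(1, 10));
    select range_is_inclusive_upper(range(1, 10));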
Create a range literal with the ``range`` constructor function. .. code-block:: edgeql-repl db> select range(1, 10); {range(1, 10, inc_lower := true, inc_upper := false)} db> select range(2.2, 3.3); {range(2.2, 3.3, inc_lower := true, inc_upper := false)} Ranges can be *empty*, when the upper and lower bounds are equal. .. code-block:: edgeql-repl db> select range(1, 1); {range({}, empty := true)} Ranges can be *unbounded*. An empty set is used to indicate the lack of a particular upper or lower bound. .. code-block:: edgeql-repl db> select range(4, {}); {range(4, {})} db> select range({}, 4); {range({}, 4)} db> select range({}, {}); {range({}, {})} To compute the set of concrete values defined by a range literal, use ``range_unpack``. An empty range will unpack to the empty set. Unbounded ranges cannot be unpacked. .. code-block:: edgeql-repl db> select range_unpack(range(0, 10)); {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} db> select range_unpack(range(1, 1)); {} db> select range_unpack(range(0, {})); gel error: InvalidValueError: cannot unpack an unbounded range .. _ref_eql_literal_bytes: Bytes ----- .. index:: binary, raw byte strings .. api-index:: b'', b"", rb'', br'', rb"", br"" The ``bytes`` type represents raw binary data. .. code-block:: edgeql-repl db> select b'bina\\x01ry'; {b'bina\\x01ry'} There is a special syntax for declaring "raw byte strings". Raw byte strings treat the backslash ``\`` as a literal character instead of an escape character. .. code-block:: edgeql-repl db> select rb'hello\nthere'; {b'hello\\nthere'} db> select br'\'; {b'\\'} .. _ref_eql_literal_array: Arrays ------ .. index:: collection, lists An array is an *ordered* collection of values of the *same type*. For example: .. code-block:: edgeql-repl db> select [1, 2, 3]; {[1, 2, 3]} db> select ['hello', 'world']; {['hello', 'world']} db> select [(1, 2), (100, 200)]; {[(1, 2), (100, 200)]} EdgeQL provides a set of functions and operators on arrays. .. 
list-table:: * - Indexing and slicing - :eql:op:`array[i] ` :eql:op:`array[from:to] ` :eql:func:`array_get` * - Concatenation - :eql:op:`array ++ array ` * - Comparison operators - :eql:op:`= ` :eql:op:`\!= ` :eql:op:`?= ` :eql:op:`?!= ` :eql:op:`\< ` :eql:op:`\> ` :eql:op:`\<= ` :eql:op:`\>= ` * - Utilities - :eql:func:`len` :eql:func:`array_join` * - Search - :eql:func:`contains` :eql:func:`find` * - Conversion to/from sets - :eql:func:`array_agg` :eql:func:`array_unpack` See :ref:`Standard Library > Array ` for a complete reference on array data types. .. _ref_eql_literal_tuple: Tuples ------ A tuple is a *fixed-length*, *ordered* collection of values, each of which may have a *different type*. The elements of a tuple can be of any type, including scalars, arrays, other tuples, and object types. .. code-block:: edgeql-repl db> select ('Apple', 7, true); {('Apple', 7, true)} Optionally, you can assign a key to each element of a tuple. These are known as *named tuples*. You must assign keys to all or none of the elements; you can't mix-and-match. .. code-block:: edgeql-repl db> select (fruit := 'Apple', quantity := 3.14, fresh := true); {(fruit := 'Apple', quantity := 3.14, fresh := true)} Indexing tuples ^^^^^^^^^^^^^^^ Tuple elements can be accessed with dot notation. Under the hood, there's no difference between named and unnamed tuples. Named tuples support key-based and numerical indexing. .. code-block:: edgeql-repl db> select (1, 3.14, 'red').0; {1} db> select (1, 3.14, 'red').2; {'red'} db> select (name := 'george', age := 12).name; {'george'} db> select (name := 'george', age := 12).0; {'george'} .. important:: When you query an *unnamed* tuple using one of EdgeQL's :ref:`client libraries `, its value is converted to a list/array. When you fetch a *named tuple*, it is converted to an object/dictionary/hashmap. For a full reference on tuples, see :ref:`Standard Library > Tuple `. .. 
_ref_eql_literal_json: JSON ---- The :eql:type:`json` scalar type is a stringified representation of structured data. JSON literals are declared by explicitly casting other values or passing a properly formatted JSON string into :eql:func:`to_json`. Any type can be converted into JSON except :eql:type:`bytes`. .. code-block:: edgeql-repl db> select <json>5; {'5'} db> select <json>"a string"; {'"a string"'} db> select <json>["this", "is", "an", "array"]; {'["this", "is", "an", "array"]'} db> select <json>("unnamed tuple", 2); {'["unnamed tuple", 2]'} db> select <json>(name := "named tuple", count := 2); {'{ "name": "named tuple", "count": 2 }'} db> select to_json('{"a": 2, "b": 5}'); {'{"a": 2, "b": 5}'} JSON values support indexing operators. The resulting value is also of type ``json``. .. code-block:: edgeql-repl db> select to_json('{"a": 2, "b": 5}')['a']; {2} db> select to_json('["a", "b", "c"]')[2]; {'"c"'} EdgeQL supports a set of functions and operators on ``json`` values. Refer to the :ref:`Standard Library > JSON ` or click an item below for detailed documentation. .. list-table:: * - Indexing - :eql:op:`json[i] ` :eql:op:`json[from:to] ` :eql:op:`json[name] ` :eql:func:`json_get` * - Merging - :eql:op:`json ++ json ` * - Comparison operators - :eql:op:`= ` :eql:op:`\!= ` :eql:op:`?= ` :eql:op:`?!= ` :eql:op:`\< ` :eql:op:`\> ` :eql:op:`\<= ` :eql:op:`\>= ` * - Conversion to/from strings - :eql:func:`to_json` :eql:func:`to_str` * - Conversion to/from sets - :eql:func:`json_array_unpack` :eql:func:`json_object_unpack` * - Introspection - :eql:func:`json_typeof` ================================================ FILE: docs/reference/edgeql/parameters.rst ================================================ .. _ref_eql_params: Parameters ========== .. index:: query params, query arguments, query args, input .. api-index:: $, <§type§>$ :edb-alt-title: Query Parameters EdgeQL queries can reference parameters with ``$`` notation. The values of these parameters are supplied externally. .. 
code-block:: edgeql select $var; select $a + $b; select BlogPost filter .id = $blog_id; Note that we provided an explicit type cast before the parameter. This is required, as it enables Gel to enforce the provided types at runtime. Parameters can be named or unnamed tuples. .. code-block:: edgeql select >$var; select >$var; select >$var; select >$var; Usage with clients ------------------ REPL ^^^^ When you include a parameter reference in a Gel REPL, you'll be prompted interactively to provide a value or values. .. code-block:: edgeql-repl db> select 'I ❤️ ' ++ $var ++ '!'; Parameter $var: Gel {'I ❤️ Gel!'} Python ^^^^^^ .. code-block:: python await client.query( "select 'I ❤️ ' ++ $var ++ '!';", var="lamp") await client.query( "select $date;", date=datetime.today()) JavaScript ^^^^^^^^^^ .. code-block:: javascript await client.query("select 'I ❤️ ' ++ $name ++ '!';", { name: "rock and roll" }); await client.query("select $date;", { date: new Date() }); Go ^^ .. code-block:: go var result string err = db.QuerySingle(ctx, `select 'I ❤️ ' ++ $var ++ '!';"`, &result, "Golang") var date time.Time err = db.QuerySingle(ctx, `select $date;`, &date, time.Now()) Refer to the Datatypes page of your preferred :ref:`client library ` to learn more about mapping between Gel types and language-native types. .. _ref_eql_params_types: Parameter types and JSON ------------------------ In Gel, parameters can also be tuples. If you need to pass complex structures as parameters, use Gel's built-in :ref:`JSON ` functionality. .. code-block:: edgeql-repl db> with data := $data ... insert Movie { ... title := data['title'], ... release_year := data['release_year'], ... }; Parameter $data: {"title": "The Marvels", "release_year": 2023} {default::Movie {id: 8d286cfe-3c0a-11ec-aa68-3f3076ebd97f}} Arrays can be "unpacked" into sets and assigned to ``multi`` links or properties. .. 
code-block:: edgeql with friends := ( select User filter .id in array_unpack(>$friend_ids) ) insert User { name := $name, friends := friends, }; .. _ref_eql_params_optional: Optional parameters ------------------- .. api-index:: $ By default, query parameters are ``required``; the query will fail if the parameter value is an empty set. You can use an ``optional`` modifier inside the type cast if the parameter is optional. .. code-block:: edgeql-repl db> select $name; Parameter $name (Ctrl+D for empty set `{}`): {} .. note:: The ```` type cast is also valid (though redundant) syntax. .. code-block:: edgeql select $name; Default parameter values ------------------------ .. api-index:: ?? When using optional parameters, you may want to provide a default value to use in case the parameter is not passed. You can do this by using the :eql:op:`?? (coalesce) ` operator. .. code-block:: edgeql-repl db> select 'Hello ' ++ $name ?? 'there'; Parameter $name (Ctrl+D for empty set `{}`): Gel {'Hello Gel'} db> select 'Hello ' ++ $name ?? 'there'; Parameter $name (Ctrl+D for empty set `{}`): {'Hello there'} What can be parameterized? -------------------------- Any data manipulation language (DML) statement can be parameterized: ``select``, ``insert``, ``update``, and ``delete``. Since parameters can only be scalars, arrays of scalars, and tuples of scalars, only parts of the query that would be one of those types can be parameterized. This excludes parts of the query like the type being queried and the property to order by. .. note:: You can parameterize ``order by`` for a limited number of options by using :eql:op:`if..else`: .. code-block:: edgeql select Movie {*} order by (.title if $order_by = 'title' else {}) then (.release_year if $order_by = 'release_year' else {}); If a user running this query enters ``title`` as the parameter value, ``Movie`` objects will be sorted by their ``title`` property. 
If they enter ``release_year``, they will be sorted by the ``release_year`` property. Since the ``if`` and ``else`` result clauses need to be of compatible types, your ``else`` expressions should be an empty set of the same type as the property. Schema definition language (SDL) and :ref:`configure ` statements **cannot** be parameterized. Data definition language (DDL) has limited support for parameters, but it's not a recommended pattern. Some of the limitations might be lifted in future versions. ================================================ FILE: docs/reference/edgeql/path_resolution.rst ================================================ .. _ref_eql_path_resolution: ============ Path scoping ============ .. index:: using future simple_scoping, using future warn_old_scoping Beginning with Gel 6.0, we are phasing out our historical (and somewhat notorious) :ref:`"path scoping" algorithm ` in favor of a much simpler algorithm that nevertheless behaves identically on *most* idiomatic EdgeQL queries. Gel 6.0 will contain features to support migration to and testing of the new semantics. We expect the migration to be relatively painless for most users. Discussion of rationale for this change is available in `the RFC `_. New path scoping ---------------- .. versionadded:: 6.0 When applying a shape to a path (or to a path that has shapes applied to it already), the path will be bound inside computed pointers in that shape: .. code-block:: edgeql-repl db> select User { ... name := User.first_name ++ ' ' ++ User.last_name ... } {User {name: 'Peter Parker'}, User {name: 'Tony Stark'}} When doing ``SELECT``, ``UPDATE``, or ``DELETE``, if the subject is a path, optionally with shapes applied to it, the path will be bound in ``FILTER`` and ``ORDER BY`` clauses: .. code-block:: edgeql-repl db> select User { ... name := User.first_name ++ ' ' ++ User.last_name ... } ... 
filter User.first_name = 'Peter' {User {name: 'Peter Parker'}} However, when a path is used multiple times in "sibling" contexts, a cross-product will be computed: .. code-block:: edgeql-repl db> select User.first_name ++ ' ' ++ User.last_name; {'Peter Parker', 'Peter Stark', 'Tony Parker', 'Tony Stark'} If you want to produce one value per ``User``, you can rewrite the query with a ``FOR`` to make the intention explicit: .. code-block:: edgeql-repl db> for u in User ... select u.first_name ++ ' ' ++ u.last_name; {'Peter Parker', 'Tony Stark'} The most idiomatic way to fetch such data in EdgeQL, however, remains: .. code-block:: edgeql-repl db> select User { name := .first_name ++ ' ' ++ .last_name } {User {name: 'Peter Parker'}, User {name: 'Tony Stark'}} (And, of course, you probably `shouldn't have first_name and last_name properties anyway `_) Path scoping configuration -------------------------- .. versionadded:: 6.0 Gel 6.0 introduces a new :ref:`future feature ` named ``simple_scoping`` alongside a configuration setting also named ``simple_scoping``. The presence of the future feature determines which behavior is used inside expressions within the schema, and also serves as the default if the configuration value is not set. The configuration setting allows overriding the presence or absence of the feature. For concreteness, here are all of the possible combinations of whether ``using future simple_scoping`` is set and the value of the configuration value ``simple_scoping``: .. list-table:: :widths: 25 25 25 25 :header-rows: 1 * - Future exists? - Config value - Query is simply scoped - Schema is simply scoped * - No - ``{}`` - No - No * - No - ``true`` - Yes - No * - No - ``false`` - No - No * - Yes - ``{}`` - Yes - Yes * - Yes - ``true`` - Yes - Yes * - Yes - ``false`` - No - Yes .. _ref_warn_old_scoping: Warning on old scoping ---------------------- .. 
versionadded:: 6.0 To make the migration process safer, we have also introduced a ``warn_old_scoping`` :ref:`future feature ` and config setting. When active, the server will emit a warning to the client when a query is detected to depend on the old scoping behavior. The behavior of warnings can be configured in client bindings, but by default they are logged. The check is known to sometimes produce false positives, on queries that will not actually have changed behavior, but is intended to not have false negatives. Recommended upgrade plan ------------------------ .. versionadded:: 6.0 The safest approach is to first get your entire schema and application working with ``warn_old_scoping`` without producing any warnings. Once that is done, it should be safe to switch to ``simple_scoping`` without changes in behavior. If you are very confident in your test coverage, though, you can try skipping dealing with ``warn_old_scoping`` and go straight to ``simple_scoping``. There are many different potential migration strategies. One that should work well: 1. Run ``CONFIGURE CURRENT DATABASE SET warn_old_scoping := true`` 2. Try running all of your queries against the database. 3. Fix any that produce warnings. 4. Adjust your schema until setting ``using future warn_old_scoping`` works without producing warnings. If you wish to proceed incrementally with steps 2 and 3, you can configure ``warn_old_scoping`` in your clients, having it enabled for queries that you have verified work with it and disabled for queries that have not yet been verified or updated. .. _ref_eql_old_path_resolution: Legacy path scoping ------------------- This section describes the path scoping algorithm used exclusively until |EdgeDB| 5.0 and by default in |Gel| 6.0. It will be removed in Gel 7.0. Element-wise operations with multiple arguments in Gel are generally applied to the :ref:`cartesian product ` of all the input sets. .. 
code-block:: edgeql-repl db> select {'aaa', 'bbb'} ++ {'ccc', 'ddd'}; {'aaaccc', 'aaaddd', 'bbbccc', 'bbbddd'} However, in cases where multiple element-wise arguments share a common path (``User.`` in this example), Gel factors out the common path rather than using cartesian multiplication. .. code-block:: edgeql-repl db> select User.first_name ++ ' ' ++ User.last_name; {'Mina Murray', 'Jonathan Harker', 'Lucy Westenra', 'John Seward'} We assume this is what you want, but if your goal is to get the cartesian product, you can accomplish it one of three ways. You could use :eql:op:`detached`. .. code-block:: edgeql-repl gel> select User.first_name ++ ' ' ++ detached User.last_name; { 'Mina Murray', 'Mina Harker', 'Mina Westenra', 'Mina Seward', 'Jonathan Murray', 'Jonathan Harker', 'Jonathan Westenra', 'Jonathan Seward', 'Lucy Murray', 'Lucy Harker', 'Lucy Westenra', 'Lucy Seward', 'John Murray', 'John Harker', 'John Westenra', 'John Seward', } You could use :ref:`with ` to attach a different symbol to your set of ``User`` objects. .. code-block:: edgeql-repl gel> with U := User .... select U.first_name ++ ' ' ++ User.last_name; { 'Mina Murray', 'Mina Harker', 'Mina Westenra', 'Mina Seward', 'Jonathan Murray', 'Jonathan Harker', 'Jonathan Westenra', 'Jonathan Seward', 'Lucy Murray', 'Lucy Harker', 'Lucy Westenra', 'Lucy Seward', 'John Murray', 'John Harker', 'John Westenra', 'John Seward', } Or you could leverage the effect scopes have on path resolution. More on that :ref:`in the Scopes section `. The reason ``with`` works here even though the alias ``U`` refers to the exact same set is that we only assume you want the path factored in this way when you use the same *symbol* to refer to a set. This means operations with ``User.first_name`` and ``User.last_name`` *do* get the common path factored while ``U.first_name`` and ``User.last_name`` *do not* and are resolved with cartesian multiplication. 
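
For comparison, here is a sketch of the opposite case under the legacy algorithm. Both paths start from the same alias, so they share a symbol. The common prefix is therefore expected to be factored, and the query should yield one name per user rather than a cartesian product.

.. code-block:: edgeql

    # Both paths are rooted at the same symbol ``U``.
    with U := User
    select U.first_name ++ ' ' ++ U.last_name;
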
That may leave you still wondering why ``U`` and ``User`` did not get a common path factored. ``U`` is just an alias of ``select User`` and ``User`` is the same symbol that we use in our name query. That's true, but |Gel| doesn't factor in this case because of the queries' scopes. .. _ref_eql_path_resolution_scopes: Scopes ------ Scopes change the way path resolution works. Two sibling select queries — that is, queries at the same level — do not have their paths factored even when they use a common symbol. .. code-block:: edgeql-repl gel> select ((select User.first_name), (select User.last_name)); { ('Mina', 'Murray'), ('Mina', 'Harker'), ('Mina', 'Westenra'), ('Mina', 'Seward'), ('Jonathan', 'Murray'), ('Jonathan', 'Harker'), ('Jonathan', 'Westenra'), ('Jonathan', 'Seward'), ('Lucy', 'Murray'), ('Lucy', 'Harker'), ('Lucy', 'Westenra'), ('Lucy', 'Seward'), ('John', 'Murray'), ('John', 'Harker'), ('John', 'Westenra'), ('John', 'Seward'), } Common symbols in nested scopes *are* factored when they use the same symbol. In this example, the nested queries both use the same ``User`` symbol as the top-level query. As a result, the ``User`` in those queries refers to a single object because it has been factored. .. code-block:: edgeql-repl gel> select User { .... name := (select User.first_name) ++ ' ' ++ (select User.last_name) .... }; { default::User {name: 'Mina Murray'}, default::User {name: 'Jonathan Harker'}, default::User {name: 'Lucy Westenra'}, default::User {name: 'John Seward'}, } If you have two common scopes and only *one* of them is in a nested scope, the paths are still factored. .. code-block:: edgeql-repl gel> select (Person.name, count(Person.friends)); {('Fran', 3), ('Bam', 2), ('Emma', 3), ('Geoff', 1), ('Tyra', 1)} In this example, ``count``, like all aggregate functions, creates a nested scope, but this doesn't prevent the paths from being factored, as you can see from the results.
If the paths were *not* factored, the friend count would be the same for all the result tuples and it would reflect the total number of ``Person`` objects that are in *all* ``friends`` links rather than the number of ``Person`` objects that are in the named ``Person`` object's ``friends`` link. If you have two aggregate functions creating *sibling* nested scopes, the paths are *not* factored. .. code-block:: edgeql-repl gel> select (array_agg(distinct Person.name), count(Person.friends)); {(['Fran', 'Bam', 'Emma', 'Geoff'], 3)} This query selects a tuple containing two nested scopes. Here, |Gel| assumes you want an array of all unique names and a count of the total number of people who are anyone's friend. Clauses & Nesting ^^^^^^^^^^^^^^^^^ Most clauses are nested and are subjected to the same rules described above: common symbols are factored and assumed to refer to the same object as the outer query. This is because clauses like :ref:`filter ` and :ref:`order by ` need to be applied to each value in the result. The :ref:`offset ` and :ref:`limit ` clauses are not nested in the scope because they need to be applied globally to the entire result set of your query. .. _rfc: https://github.com/geldata/rfcs/blob/master/text/1027-no-factoring.rst ================================================ FILE: docs/reference/edgeql/paths.rst ================================================ .. _ref_eql_paths: ===== Paths ===== .. index:: links, relations A *path expression* (or simply a *path*) represents a set of values that are reachable by traversing a given sequence of links or properties from some source set of objects. Consider the following schema: .. code-block:: sdl type User { required email: str; multi friends: User; } type BlogPost { required title: str; required author: User; } type Comment { required text: str; required author: User; } A few simple inserts will allow some experimentation with paths. Start with a first user: .. 
code-block:: edgeql-repl db> insert User { ... email := "user1@me.com", ... }; Along comes another user who adds the first user as a friend: .. code-block:: edgeql-repl db> insert User { ... email := "user2@me.com", ... friends := (select detached User filter .email = "user1@me.com") ... }; The first user reciprocates, adding the new user as a friend: .. code-block:: edgeql-repl db> update User filter .email = "user1@me.com" ... set { ... friends += (select detached User filter .email = "user2@me.com") ... }; The second user writes a blog post about how nice Gel is: .. code-block:: edgeql-repl db> insert BlogPost { ... title := "Gel is awesome", ... author := assert_single((select User filter .email = "user2@me.com")) ... }; And the first user follows it up with a comment below the post: .. code-block:: edgeql-repl db> insert Comment { ... text := "Nice post, user2!", ... author := assert_single((select User filter .email = "user1@me.com")) ... }; The simplest path is simply ``User``. This is a :ref:`set reference ` that refers to all ``User`` objects in the database. .. code-block:: edgeql select User; Paths can traverse links. The path below refers to *all Users who are the friend of another User*. .. code-block:: edgeql select User.friends; Paths can traverse to an arbitrary depth in a series of nested links. Both ``select`` queries below end up showing the author of the ``BlogPost``. The second query returns the friends of the friends of the author of the ``BlogPost``, which in this case is just the author. .. code-block:: edgeql select BlogPost.author; # The author select BlogPost.author.friends.friends; # The author again Paths can terminate with a property reference. .. code-block:: edgeql select BlogPost.title; # all blog post titles select BlogPost.author.email; # all author emails select User.friends.email; # all friends' emails .. _ref_eql_paths_backlinks: Backlinks --------- .. 
api-index:: .< All examples thus far have traversed links in the *forward direction*, however it's also possible to traverse links *backwards* with ``.<`` notation. These are called **backlinks**. Starting from each user, the path below traverses all *incoming* links labeled ``author`` and returns the union of their sources. .. code-block:: edgeql select User.` operator: ``[is Foo]``: .. code-block:: edgeql # BlogPost objects that link to the user via a link named author select User.` with ``@`` notation. To demonstrate this, let's add a property to the ``User. friends`` link: .. code-block:: sdl-diff type User { required email: str; - multi friends: User; + multi friends: User { + since: cal::local_date; + } } The following represents a set of all dates on which friendships were formed. .. code-block:: edgeql select User.friends@since; Path roots ---------- For simplicity, all examples above use set references like ``User`` as the root of the path; however, the root can be *any expression* returning object types. Below, the root of the path is a *subquery*. .. code-block:: edgeql-repl db> with gel_lovers := ( ... select BlogPost filter .title ilike "Gel is awesome" ... ) ... select gel_lovers.author; This expression returns a set of all ``Users`` who have written a blog post titled "Gel is awesome". For a full syntax definition, see the :ref:`Reference > Paths `. ================================================ FILE: docs/reference/edgeql/select.rst ================================================ .. _ref_eql_select: Select ====== .. api-index:: select The ``select`` command retrieves or computes a set of values. We've already seen simple queries that select primitive values. .. code-block:: edgeql-repl db> select 'hello world'; {'hello world'} db> select [1, 2, 3]; {[1, 2, 3]} db> select {1, 2, 3}; {1, 2, 3} With the help of a ``with`` block, we can add filters, ordering, and pagination clauses. .. code-block:: edgeql-repl db> with x := {1, 2, 3, 4, 5} ... 
select x ... filter x >= 3; {3, 4, 5} db> with x := {1, 2, 3, 4, 5} ... select x ... order by x desc; {5, 4, 3, 2, 1} db> with x := {1, 2, 3, 4, 5} ... select x ... offset 1 limit 3; {2, 3, 4} These queries can also be rewritten to use inline aliases, like so: .. code-block:: edgeql-repl db> select x := {1, 2, 3, 4, 5} ... filter x >= 3; .. _ref_eql_select_objects: Selecting objects ----------------- However most queries are selecting *objects* that live in the database. For demonstration purposes, the queries below assume the following schema: .. code-block:: sdl module default { abstract type Person { required name: str { constraint exclusive }; } type Hero extending Person { secret_identity: str; multi villains := . insert Hero { ... name := "Spider-Man", ... secret_identity := "Peter Parker" ... }; {default::Hero {id: 6be1c9c6...}} db> insert Hero { ... name := "Iron Man", ... secret_identity := "Tony Stark" ... }; {default::Hero {id: 6bf7115a... }} db> for n in { "Sandman", "Electro", "Green Goblin", "Doc Ock" } ... union ( ... insert Villain { ... name := n, ... nemesis := (select Hero filter .name = "Spider-Man") ... }); { default::Villain {id: 6c22bdf0...}, default::Villain {id: 6c22c3d6...}, default::Villain {id: 6c22c46c...}, default::Villain {id: 6c22c502...}, } db> insert Villain { ... name := "Obadiah Stane", ... nemesis := (select Hero filter .name = "Iron Man") ... }; {default::Villain {id: 6c42c4ec...}} db> insert Movie { ... title := "Spider-Man: No Way Home", ... release_year := 2021, ... characters := (select Person filter .name in ... { "Spider-Man", "Sandman", "Electro", "Green Goblin", "Doc Ock" }) ... }; {default::Movie {id: 6c60c28a...}} db> insert Movie { ... title := "Iron Man", ... release_year := 2008, ... characters := (select Person filter .name in ... { "Iron Man", "Obadiah Stane" }) ... }; {default::Movie {id: 6d1f430e...}} Let's start by selecting all ``Villain`` objects in the database. In this example, there are only five. 
Remember, ``Villain`` is a :ref:`reference ` to the set of all Villain objects. .. code-block:: edgeql-repl db> select Villain; { default::Villain {id: 6c22bdf0...}, default::Villain {id: 6c22c3d6...}, default::Villain {id: 6c22c46c...}, default::Villain {id: 6c22c502...}, default::Villain {id: 6c42c4ec...}, } .. note:: For the sake of readability, the ``id`` values have been truncated. By default, this only returns the ``id`` of each object. If serialized to JSON, this result would look like this: .. code-block:: [ {"id": "6c22bdf0-5c03-11ee-99ff-dfaea4d947ce"}, {"id": "6c22c3d6-5c03-11ee-99ff-734255881e5d"}, {"id": "6c22c46c-5c03-11ee-99ff-c79f24cf638b"}, {"id": "6c22c502-5c03-11ee-99ff-cbacc3918129"}, {"id": "6c42c4ec-5c03-11ee-99ff-872c9906a467"} ] .. _ref_eql_shapes: Shapes ------ .. api-index:: select, { } To specify which properties to select, we attach a **shape** to ``Villain``. A shape can be attached to any object type expression in EdgeQL. .. code-block:: edgeql-repl db> select Villain { id, name }; { default::Villain {id: 6c22bdf0..., name: 'Sandman'}, default::Villain {id: 6c22c3d6..., name: 'Electro'}, default::Villain {id: 6c22c46c..., name: 'Green Goblin'}, default::Villain {id: 6c22c502..., name: 'Doc Ock'}, default::Villain {id: 6c42c4ec..., name: 'Obadiah Stane'}, } Nested shapes ^^^^^^^^^^^^^ Nested shapes can be used to fetch linked objects and their properties. Here we fetch all ``Villain`` objects and their nemeses. .. code-block:: edgeql-repl db> select Villain { ... name, ... nemesis: { name } ... }; { default::Villain { name: 'Sandman', nemesis: default::Hero {name: 'Spider-Man'}, }, ... } In the context of EdgeQL, computed links like ``Hero.villains`` are treated identically to concrete/non-computed links like ``Villain.nemesis``. .. code-block:: edgeql-repl db> select Hero { ... name, ... villains: { name } ... 
}; { default::Hero { name: 'Spider-Man', villains: { default::Villain {name: 'Sandman'}, default::Villain {name: 'Electro'}, default::Villain {name: 'Green Goblin'}, default::Villain {name: 'Doc Ock'}, }, }, ... } .. _ref_eql_select_splats: Splats ^^^^^^ .. index:: select *, select all .. api-index:: *, **, §type§.*, §type§.**, [is §type§].*, [is §type§].** Splats allow you to select all properties of a type using the asterisk (``*``) or all properties of the type and a single level of linked types with a double asterisk (``**``). .. edb:youtube-embed:: 9-I1qjIp3KI Splats will help you more easily select all properties when using the REPL. You can select all of an object's properties using the single splat: .. code-block:: edgeql-repl db> select Movie {*}; { default::Movie { id: 6c60c28a-5c03-11ee-99ff-dfa425012a05, release_year: 2021, title: 'Spider-Man: No Way Home', }, default::Movie { id: 6d1f430e-5c03-11ee-99ff-e731e8da06d9, release_year: 2008, title: 'Iron Man' }, } or you can select all of an object's properties and the properties of a single level of nested objects with the double splat: .. code-block:: edgeql-repl db> select Movie {**}; { default::Movie { id: 6c60c28a-5c03-11ee-99ff-dfa425012a05, release_year: 2021, title: 'Spider-Man: No Way Home', characters: { default::Hero { id: 6be1c9c6-5c03-11ee-99ff-63b1127d75f2, name: 'Spider-Man' }, default::Villain { id: 6c22bdf0-5c03-11ee-99ff-dfaea4d947ce, name: 'Sandman' }, default::Villain { id: 6c22c3d6-5c03-11ee-99ff-734255881e5d, name: 'Electro' }, default::Villain { id: 6c22c46c-5c03-11ee-99ff-c79f24cf638b, name: 'Green Goblin' }, default::Villain { id: 6c22c502-5c03-11ee-99ff-cbacc3918129, name: 'Doc Ock' }, }, }, default::Movie { id: 6d1f430e-5c03-11ee-99ff-e731e8da06d9, release_year: 2008, title: 'Iron Man', characters: { default::Hero { id: 6bf7115a-5c03-11ee-99ff-c79c07f0e2db, name: 'Iron Man' }, default::Villain { id: 6c42c4ec-5c03-11ee-99ff-872c9906a467, name: 'Obadiah Stane' }, }, }, } .. 
note:: Splats are not yet supported in function bodies. The splat expands all properties defined on the type as well as inherited properties: .. code-block:: edgeql-repl db> select Hero {*}; { default::Hero { id: 6be1c9c6-5c03-11ee-99ff-63b1127d75f2, name: 'Spider-Man', secret_identity: 'Peter Parker' }, default::Hero { id: 6bf7115a-5c03-11ee-99ff-c79c07f0e2db, name: 'Iron Man', secret_identity: 'Tony Stark' }, } The splat here expands the heroes' names even though the ``name`` property is not defined on the ``Hero`` type but on the ``Person`` type it extends. If we want to select heroes but get only properties defined on the ``Person`` type, we can do this instead: .. code-block:: edgeql-repl db> select Hero {Person.*}; { default::Hero { id: 6be1c9c6-5c03-11ee-99ff-63b1127d75f2, name: 'Spider-Man' }, default::Hero { id: 6bf7115a-5c03-11ee-99ff-c79c07f0e2db, name: 'Iron Man' }, } If there are links on our ``Person`` type, we can use ``Person.**`` in a similar fashion to get all properties and one level of linked object properties, but only for links and properties that are defined on the ``Person`` type. You can use the splat to expand properties using a :ref:`type intersection `. Maybe we want to select all ``Person`` objects with their names but also get any properties defined on the ``Hero`` for those ``Person`` objects which are also ``Hero`` objects: .. code-block:: edgeql-repl db> select Person { ... name, ... [is Hero].* ... 
}; { default::Hero { name: 'Spider-Man', id: 6be1c9c6-5c03-11ee-99ff-63b1127d75f2, secret_identity: 'Peter Parker' }, default::Hero { name: 'Iron Man' id: 6bf7115a-5c03-11ee-99ff-c79c07f0e2db, secret_identity: 'Tony Stark' }, default::Villain { name: 'Sandman', id: 6c22bdf0-5c03-11ee-99ff-dfaea4d947ce, secret_identity: {} }, default::Villain { name: 'Electro', id: 6c22c3d6-5c03-11ee-99ff-734255881e5d, secret_identity: {} }, default::Villain { name: 'Green Goblin', id: 6c22c46c-5c03-11ee-99ff-c79f24cf638b, secret_identity: {} }, default::Villain { name: 'Doc Ock', id: 6c22c502-5c03-11ee-99ff-cbacc3918129, secret_identity: {} }, default::Villain { name: 'Obadiah Stane', id: 6c42c4ec-5c03-11ee-99ff-872c9906a467, secret_identity: {} }, } The double splat also works with type intersection expansion to expand both properties and links on the specified type. .. code-block:: edgeql-repl db> select Person { ... name, ... [is Hero].** ... }; { default::Villain { name: 'Sandman', id: 6c22bdf0-5c03-11ee-99ff-dfaea4d947ce, secret_identity: {}, villains: {} }, default::Villain { name: 'Electro', id: 6c22c3d6-5c03-11ee-99ff-734255881e5d, secret_identity: {}, villains: {} }, default::Villain { name: 'Green Goblin', id: 6c22c46c-5c03-11ee-99ff-c79f24cf638b, secret_identity: {}, villains: {} }, default::Villain { name: 'Doc Ock', id: 6c22c502-5c03-11ee-99ff-cbacc3918129, secret_identity: {}, villains: {} }, default::Villain { name: 'Obadiah Stane', id: 6c42c4ec-5c03-11ee-99ff-872c9906a467, secret_identity: {}, villains: {} }, default::Hero { name: 'Spider-Man', id: 6be1c9c6-5c03-11ee-99ff-63b1127d75f2, secret_identity: 'Peter Parker', villains: { default::Villain { name: 'Electro', id: 6c22c3d6-5c03-11ee-99ff-734255881e5d }, default::Villain { name: 'Sandman', id: 6c22bdf0-5c03-11ee-99ff-dfaea4d947ce }, default::Villain { name: 'Doc Ock', id: 6c22c502-5c03-11ee-99ff-cbacc3918129 }, default::Villain { name: 'Green Goblin', id: 6c22c46c-5c03-11ee-99ff-c79f24cf638b }, }, }, } With this 
query, we get ``name`` for each ``Person`` and all the properties and one level of links on the ``Hero`` objects. We don't get ``Villain`` objects' nemeses because that link is not covered by our double splat, which only expands ``Hero`` links. If the ``Villain`` type had properties defined on it, we wouldn't get those with this query either. .. _ref_eql_select_filter: Filtering --------- .. index:: where .. api-index:: filter To filter the set of selected objects, use a ``filter <expr>`` clause. The ``<expr>`` that follows the ``filter`` keyword can be *any boolean expression*. To reference the ``name`` property of the ``Villain`` objects being selected, we use ``Villain.name``. .. code-block:: edgeql-repl db> select Villain {id, name} ... filter Villain.name = "Doc Ock"; {default::Villain {id: 6c22c502..., name: 'Doc Ock'}} .. note:: This query contains two occurrences of ``Villain``. The first (outer) is passed as the argument to ``select`` and refers to the set of all ``Villain`` objects. The *inner* occurrence, however, is inside the *scope* of the ``select`` statement and refers to the *object being selected*. Writing ``Villain.name`` is a little clunky, so EdgeQL provides a shorthand: just drop ``Villain`` entirely and simply use ``.name``. Since we are selecting a set of Villains, it's clear from context that ``.name`` must refer to a link/property of the ``Villain`` type. In other words, we are in the **scope** of the ``Villain`` type. .. code-block:: edgeql-repl db> select Villain {name} ... filter .name = "Doc Ock"; {default::Villain {name: 'Doc Ock'}} .. warning:: When using comparison operators like ``=`` or ``!=``, or boolean operators ``and``, ``or``, and ``not``, keep in mind that these operators will produce an empty set if an operand is an empty set. Check out :ref:`our boolean cheatsheet ` for more info and help on how to mitigate this if you know your operands may be an empty set.
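A minimal sketch of one way to guard against this, assuming the sample schema used on this page, where ``secret_identity`` is an optional property of ``Hero``. The coalescing operator ``??`` substitutes a fallback value for an empty operand, and the optional comparison operator ``?!=`` treats an empty set as a comparable value. Both filters below are meant to keep heroes that have no ``secret_identity`` set, which a plain ``!=`` would silently drop.

.. code-block:: edgeql

    # Substitute an empty string before comparing.
    select Hero { name }
    filter (.secret_identity ?? '') != 'Peter Parker';

    # Or compare with the optional-aware operator.
    select Hero { name }
    filter .secret_identity ?!= 'Peter Parker';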
Filtering by ID ^^^^^^^^^^^^^^^ To filter by ``id``, remember to cast the desired ID to :ref:`uuid `: .. code-block:: edgeql-repl db> select Villain {id, name} ... filter .id = <uuid>"6c22c502-5c03-11ee-99ff-cbacc3918129"; { default::Villain { id: '6c22c502-5c03-11ee-99ff-cbacc3918129', name: 'Doc Ock' } } Nested filters ^^^^^^^^^^^^^^ Filters can be added at every level of shape nesting. The query below applies a filter to both the selected ``Hero`` objects and their linked ``villains``. .. code-block:: edgeql-repl db> select Hero { ... name, ... villains: { ... name ... } filter .name like "%O%" ... } filter .name ilike "%man"; { default::Hero { name: 'Spider-Man', villains: { default::Villain { name: 'Doc Ock' } } }, default::Hero { name: 'Iron Man', villains: { default::Villain { name: 'Obadiah Stane' } } }, } Note that the *scope* changes inside nested shapes. When we use ``.name`` in the outer ``filter``, it refers to the name of the hero. But when we use ``.name`` in the nested ``villains`` shape, the scope has changed to ``Villain``. Filtering on a known backlink ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Another handy use for backlinks is to filter and find items when doing a ``select`` (or an ``update`` or other operation, of course). This can work as a nice shortcut when you have the ID of one object that links to a second object without a link back to the first. Spider-Man's villains always have a grudging respect for him, and their names can be displayed to reflect that if we know the ID of a movie that they starred in. Note the ability to :ref:`cast from a uuid ` to an object type. .. code-block:: edgeql-repl db> select Villain filter .<characters = ... <Movie><uuid>'6c60c28a-5c03-11ee-99ff-dfa425012a05' { ... name := .name ++ ', who got to see Spider-Man!' ...
}; { 'Obadiah Stane', 'Sandman, who got to see Spider-Man!', 'Electro, who got to see Spider-Man!', 'Green Goblin, who got to see Spider-Man!', 'Doc Ock, who got to see Spider-Man!', } In other words, "select every ``Villain`` object that the ``Movie`` object of this ID links to via a link called ``characters``". A backlink is naturally not required, however. The same operation without traversing a backlink would look like this: .. code-block:: edgeql-repl db> with movie := <Movie><uuid> ... '6c60c28a-5c03-11ee-99ff-dfa425012a05', ... select movie.characters[is Villain] { ... name := .name ++ ', who got to see Spider-Man!' ... }; .. _ref_eql_select_order: Filtering, ordering, and limiting of links ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Clauses like ``filter``, ``order by``, and ``limit`` can be used on links. If no properties of a link are selected, you can place the clauses directly inside the shape: .. code-block:: edgeql select User { likes order by .title desc limit 10 }; If properties are selected, place the clauses after the link's shape: .. code-block:: edgeql select User { likes: { id, title } order by .title desc limit 10 }; Ordering -------- .. index:: sorting .. api-index:: order by, asc, desc, then, empty first, empty last Order the result of a query with an ``order by`` clause. .. code-block:: edgeql-repl db> select Villain { name } ... order by .name; { default::Villain {name: 'Doc Ock'}, default::Villain {name: 'Electro'}, default::Villain {name: 'Green Goblin'}, default::Villain {name: 'Obadiah Stane'}, default::Villain {name: 'Sandman'}, } The expression provided to ``order by`` may be *any* singleton expression, primitive or otherwise. .. note:: In Gel all values are orderable. Objects are compared using their ``id``; tuples and arrays are compared element-by-element from left to right. By extension, the generic comparison operators :eql:op:`= `, :eql:op:`\< `, :eql:op:`\> `, etc. can be used with any two expressions of the same type.
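As a small illustration of the element-by-element rule (a sketch with literal values chosen for this example, not output copied from a live session), each comparison below is decided by the first position at which the two operands differ:

.. code-block:: edgeql

    # Tuples: the first elements differ, so they decide the result.
    select (2008, 'Iron Man') < (2021, 'Iron Man');

    # Arrays: the first two elements tie, so the third decides.
    select [1, 2, 3] < [1, 2, 4];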
You can also order by multiple expressions and specify the *direction* with an ``asc`` (default) or ``desc`` modifier. .. note:: When ordering by multiple expressions, arrays, or tuples, the leftmost expression/element is compared. If these elements are the same, the next element is used to "break the tie", and so on. If all elements are the same, the order is not well defined. .. code-block:: edgeql-repl db> select Movie { title, release_year } ... order by ... .release_year desc then ... str_trim(.title) desc; { default::Movie {title: 'Spider-Man: No Way Home', release_year: 2021}, ... default::Movie {title: 'Iron Man', release_year: 2008}, } When ordering by multiple expressions, each expression is separated with the ``then`` keyword. For a full reference on ordering, including how empty values are handled, see :ref:`Reference > Commands > Select `. .. _ref_eql_select_pagination: Pagination ---------- .. api-index:: limit, offset |Gel| supports ``limit`` and ``offset`` clauses. These are typically used in conjunction with ``order by`` to maintain a consistent ordering across pagination queries. .. code-block:: edgeql-repl db> select Villain { name } ... order by .name ... offset 2 ... limit 2; { default::Villain {name: 'Obadiah Stane'}, default::Villain {name: 'Sandman'}, } The expressions passed to ``limit`` and ``offset`` can be any singleton ``int64`` expression. This query fetches all Villains except the last (sorted by name). .. code-block:: edgeql-repl db> select Villain {name} ... order by .name ... limit count(Villain) - 1; { default::Villain {name: 'Doc Ock'}, default::Villain {name: 'Electro'}, default::Villain {name: 'Green Goblin'}, default::Villain {name: 'Obadiah Stane'}, # no Sandman } You may pass the empty set to ``limit`` or ``offset``. Passing the empty set is effectively the same as excluding ``limit`` or ``offset`` from your query (i.e., no limit or no offset). 
This is useful if you need to parameterize ``limit`` and/or ``offset`` but may still need to execute your query without providing one or the other. .. code-block:: edgeql-repl db> select Villain {name} ... order by .name ... offset <optional int64>$offset ... limit <optional int64>$limit; Parameter $offset (Ctrl+D for empty set `{}`): Parameter $limit (Ctrl+D for empty set `{}`): { default::Villain {name: 'Doc Ock'}, default::Villain {name: 'Electro'}, ... } .. note:: If you parameterize ``limit`` and ``offset`` and want to reserve the option to pass the empty set, make sure those parameters are ``optional`` as shown in the example above. .. _ref_eql_select_computeds: Computed fields --------------- .. api-index:: := Shapes can contain *computed fields*. These are EdgeQL expressions that are computed on the fly during the execution of the query. As with other clauses, we can use :ref:`leading dot notation ` (e.g. ``.name``) to refer to the properties and links of the object type currently *in scope*. .. code-block:: edgeql-repl db> select Villain { ... name, ... name_upper := str_upper(.name) ... }; { default::Villain { id: 6c22bdf0..., name: 'Sandman', name_upper: 'SANDMAN', }, ... } As with nested filters, the *current scope* changes inside nested shapes. .. code-block:: edgeql-repl db> select Villain { ... id, ... name, ... name_upper := str_upper(.name), ... nemesis: { ... secret_identity, ... real_name_upper := str_upper(.secret_identity) ... } ... }; { default::Villain { id: 6c22bdf0..., name: 'Sandman', name_upper: 'SANDMAN', nemesis: default::Hero { secret_identity: 'Peter Parker', real_name_upper: 'PETER PARKER', }, }, ... } .. _ref_eql_select_backlinks: Backlinks --------- .. api-index:: .< Fetching backlinks is a common use case for computed fields. To demonstrate this, let's fetch a list of all movies starring a particular Hero. .. code-block:: edgeql-repl db> select Hero { ... name, ... movies := . Paths `.
Instead of re-declaring backlinks inside every query where they're needed, it's common to add them directly into your schema as computed links. .. code-block:: sdl-diff abstract type Person { required name: str { constraint exclusive; }; + multi movies := . select Villain { ... name, ... nemesis_name := .nemesis.name, ... movies_with_nemesis := ( ... select Movie { title } ... filter Villain.nemesis in .characters ... ) ... }; { default::Villain { name: 'Sandman', nemesis_name: 'Spider-Man', movies_with_nemesis: { default::Movie {title: 'Spider-Man: No Way Home'} } }, ... } .. _ref_eql_select_polymorphic: Polymorphic queries ------------------- .. index:: polymorphism All queries thus far have referenced concrete object types: ``Hero`` and ``Villain``. However, both of these types extend the abstract type ``Person``, from which they inherit the ``name`` property. Polymorphic sets ^^^^^^^^^^^^^^^^ It's possible to directly query all ``Person`` objects; the resulting set will be a mix of ``Hero`` and ``Villain`` objects (and possibly other subtypes of ``Person``, should they be declared). .. code-block:: edgeql-repl db> select Person { name }; { default::Hero {name: 'Spider-Man'}, default::Hero {name: 'Iron Man'}, default::Villain {name: 'Doc Ock'}, default::Villain {name: 'Obadiah Stane'}, ... } You may also encounter such "mixed sets" when querying a link that points to an abstract type (such as ``Movie.characters``) or a :eql:op:`union type `. .. code-block:: edgeql-repl db> select Movie { ... title, ... characters: { ... name ... } ... } ... filter .title = "Iron Man"; { default::Movie { title: 'Iron Man', characters: { default::Villain {name: 'Obadiah Stane'}, default::Hero {name: 'Iron Man'} } } } Polymorphic fields ^^^^^^^^^^^^^^^^^^ .. api-index:: [is §type§]. We can fetch different properties *conditional* on the subtype of each object by prefixing property/link references with ``[is <type>]``. This is known as a **polymorphic query**. ..
code-block:: edgeql-repl db> select Person { ... name, ... secret_identity := [is Hero].secret_identity, ... number_of_villains := count([is Hero].villains), ... nemesis := [is Villain].nemesis { ... name ... } ... }; { ... default::Villain { name: 'Obadiah Stane', secret_identity: {}, number_of_villains: 0, nemesis: default::Hero { name: 'Iron Man' } }, default::Hero { name: 'Spider-Man', secret_identity: 'Peter Parker', number_of_villains: 4, nemesis: {} }, ... } This syntax might look familiar; it's the :ref:`type intersection ` again. In effect, this operator conditionally returns the value of the referenced field only if the object matches a particular type. If the match fails, an empty set is returned. The line ``secret_identity := [is Hero].secret_identity`` is a bit redundant, since the computed property has the same name as the polymorphic field. In these cases, EdgeQL supports a shorthand. .. code-block:: edgeql-repl db> select Person { ... name, ... [is Hero].secret_identity, ... [is Villain].nemesis: { ... name ... } ... }; { ... default::Villain { name: 'Obadiah Stane', secret_identity: {}, nemesis: default::Hero {name: 'Iron Man'} }, default::Hero { name: 'Spider-Man', secret_identity: 'Peter Parker', nemesis: {} }, ... } Filtering polymorphic links ^^^^^^^^^^^^^^^^^^^^^^^^^^^ Relatedly, it's possible to filter polymorphic links by subtype. Below, we exclusively fetch the ``Movie.characters`` of type ``Hero``. .. code-block:: edgeql-repl db> select Movie { ... title, ... characters[is Hero]: { ... secret_identity ... }, ... }; { default::Movie { title: 'Spider-Man: No Way Home', characters: {default::Hero {secret_identity: 'Peter Parker'}}, }, default::Movie { title: 'Iron Man', characters: {default::Hero {secret_identity: 'Tony Stark'}}, }, ... 
} Accessing types in polymorphic queries ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ While the type of an object is displayed alongside the results of polymorphic queries run in the REPL, this is simply a convenience of the REPL and not a property that can be accessed. This is particularly noticeable if you cast an object to ``json``, making it impossible to determine the type if the query is polymorphic. First, the result of a query as the REPL presents it with type annotations displayed: .. code-block:: edgeql-repl db> select Person limit 1; {default::Villain {id: 6c22bdf0-5c03-11ee-99ff-dfaea4d947ce}} Note the type ``default::Villain``, which is displayed for the user's convenience but is not actually part of the data returned. This is the same query with the result cast as ``json`` to show only the data returned: .. code-block:: edgeql-repl db> select <json>Person limit 1; {Json("{\"id\": \"6c22bdf0-5c03-11ee-99ff-dfaea4d947ce\"}")} .. note:: We will continue to cast subsequent examples in this section to ``json``, not because this is required for any of the functionality being demonstrated, but to remove the convenience type annotations provided by the REPL and make it easier to see what data is actually being returned by the query. The type of an object is found inside ``__type__`` which is a link that carries various information about the object's type, including its ``name``. .. code-block:: edgeql-repl db> select <json>Person { ... __type__: { ... name ... } ... } limit 1; {Json("{\"__type__\": {\"name\": \"default::Villain\"}}")} This information can be pulled into the top level by assigning a name to the ``name`` property inside ``__type__``: .. code-block:: edgeql-repl db> select <json>Person { type := .__type__.name } limit 1; {Json("{\"type\": \"default::Villain\"}")} There is nothing magical about ``__type__``; it is a simple link to an object of the type ``ObjectType`` which contains all of the available information about the type of the current object.
The splat operator can be used to see this object's makeup, while the double splat operator produces too much output to show on this page. Playing around with the splat and double splat operator inside ``__type__`` is a quick way to get some insight into the internals of Gel. .. code-block:: edgeql-repl db> select Person.__type__ {*} limit 1; { schema::ObjectType { id: 48be3a94-5bf3-11ee-bd60-0b44b607e31d, name: 'default::Hero', internal: false, builtin: false, computed_fields: [], final: false, is_final: false, abstract: false, is_abstract: false, inherited_fields: [], from_alias: false, is_from_alias: false, expr: {}, compound_type: false, is_compound_type: false, }, } .. _ref_eql_select_free_objects: Free objects ------------ .. index:: ad hoc type To select several values simultaneously, you can "bundle" them into a "free object". Free objects are a set of key-value pairs that can contain any expression. Here, the term "free" is used to indicate that the object in question is not an instance of a particular *object type*; instead, it's constructed ad hoc inside the query. .. code-block:: edgeql-repl db> select { ... my_string := "This is a string", ... my_number := 42, ... several_numbers := {1, 2, 3}, ... all_heroes := Hero { name } ... }; { { my_string: 'This is a string', my_number: 42, several_numbers: {1, 2, 3}, all_heroes: { default::Hero {name: 'Spider-Man'}, default::Hero {name: 'Iron Man'}, }, }, } Note that the result is a *singleton* but each key corresponds to a set of values, which may have any cardinality. .. _ref_eql_select_with: With block ---------- All top-level EdgeQL statements (``select``, ``insert``, ``update``, and ``delete``) can be prefixed with a ``with`` block. These blocks let you declare standalone expressions that can be used in your query. .. code-block:: edgeql-repl db> with hero_name := "Iron Man" ... select Hero { secret_identity } ... 
filter .name = hero_name; {default::Hero {secret_identity: 'Tony Stark'}} For full documentation on ``with``, see :ref:`EdgeQL > With `. .. list-table:: :class: seealso * - **See also** * - :ref:`Reference > Commands > Select ` * - :ref:`Cheatsheets > Selecting data ` ================================================ FILE: docs/reference/edgeql/sets.rst ================================================ .. _ref_eql_sets: Sets ==== .. _ref_eql_everything_is_a_set: Everything is a set ------------------- .. index:: multiset, cardinality, empty set, singleton All values in EdgeQL are actually **sets**: a collection of values of a given **type**. All elements of a set must have the same type. The number of items in a set is known as its **cardinality**. A set with a cardinality of zero is referred to as an **empty set**. A set with a cardinality of one is known as a **singleton**. .. _ref_eql_set_constructor: Constructing sets ----------------- .. api-index:: {§expr [\, ...]§}, union Set literals are declared with *set constructor* syntax: a comma-separated list of values inside a set of ``{curly braces}``. .. code-block:: edgeql-repl db> select {"set", "of", "strings"}; {"set", "of", "strings"} db> select {1, 2, 3}; {1, 2, 3} In actuality, curly braces are a syntactic sugar for the :eql:op:`union` operator. The previous examples are perfectly equivalent to the following: .. code-block:: edgeql-repl db> select "set" union "of" union "strings"; {"set", "of", "strings"} db> select 1 union 2 union 3; {1, 2, 3} A consequence of this is that nested sets are *flattened*. .. code-block:: edgeql-repl db> select {1, {2, {3, 4}}}; {1, 2, 3, 4} db> select 1 union (2 union (3 union 4)); {1, 2, 3, 4} All values in a set must have the same type. For convenience, Gel will *implicitly cast* values to other types, as long as there is no loss of information (e.g. converting a ``int16`` to an ``int64``). For a full reference, see the casting table in :ref:`Standard Library > Casts `. .. 
code-block:: edgeql-repl db> select {1, 1.5}; {1.0, 1.5} db> select {1, 1234.5678n}; {1.0n, 1234.5678n} Attempting to declare a set containing elements of *incompatible* types is not permitted. .. code-block:: edgeql-repl db> select {"apple", 3.14}; error: QueryError: set constructor has arguments of incompatible types 'std::str' and 'std::float64' .. note:: Types are considered *compatible* if one can be implicitly cast into the other. For reference on implicit castability, see :ref:`Standard Library > Casts `. .. _ref_eql_set_literals_are_singletons: Literals are singletons ----------------------- Literal syntax like ``6`` or ``"hello world"`` is just a shorthand for declaring a *singleton* of a given type. This is why the literals we created in the previous section were printed inside braces: to indicate that these values are *actually sets*. .. code-block:: edgeql-repl db> select 6; {6} db> select "hello world"; {"hello world"} Wrapping a literal in curly braces does not change the meaning of the expression. For instance, ``"hello world"`` is *exactly equivalent* to ``{"hello world"}``. .. code-block:: edgeql-repl db> select {"hello world"}; {"hello world"} db> select "hello world" = {"hello world"}; {true} You can retrieve the cardinality of a set with the :eql:func:`count` function. .. code-block:: edgeql-repl db> select count('aaa'); {1} db> select count({'aaa', 'bbb'}); {2} .. _ref_eql_empty_sets: Empty sets ---------- .. index:: null, exists The reason EdgeQL introduced the concept of *sets* is to eliminate the concept of ``null``. In SQL databases ``null`` is a special value denoting the absence of data; in Gel the absence of data is just an empty set. .. note:: Why is the existence of NULL a problem? Put simply, it's an edge case that permeates all of SQL and is often handled inconsistently in different circumstances. A number of specific inconsistencies are documented in detail in the `We Can Do Better Than SQL `_ post on the Gel blog. 
For broader context, see Tony Hoare's talk `"The Billion Dollar Mistake" `_. Declaring empty sets isn't as simple as ``{}``; in EdgeQL, all expressions are *strongly typed*, including empty sets. With nonempty sets (like ``{1, 2, 3}``), the type is inferred from the set's contents (``int64``). But with empty sets this isn't possible, so an *explicit cast* is required. .. code-block:: edgeql-repl db> select {}; error: QueryError: expression returns value of indeterminate type ┌─ query:1:8 │ 1 │ select {}; │ ^^ Consider using an explicit type cast. db> select <int64>{}; {} db> select <str>{}; {} db> select count(<str>{}); {0} You can check whether or not a set is *empty* with the :eql:op:`exists` operator. .. code-block:: edgeql-repl db> select exists <str>{}; {false} db> select exists {'not', 'empty'}; {true} .. _ref_eql_set_references: Set references -------------- .. index:: pointer, alias, with A set reference is a *pointer* to a set of values. Most commonly, this is the name of an :ref:`object type ` you've declared in your schema. .. code-block:: edgeql-repl db> select User; { default::User {id: 9d2ce01c-35e8-11ec-acc3-83b1377efea0}, default::User {id: b0e0dd0c-35e8-11ec-acc3-abf1752973be}, } db> select count(User); {2} It may also be an *alias*, which can be defined in a :ref:`with block ` or as an :ref:`alias declaration ` in your schema. .. note:: In the example above, the ``User`` object type was declared inside the ``default`` module. If it were in a non-``default`` module (say, ``my_module``), we would need to use its *fully-qualified* name. .. code-block:: edgeql-repl db> select my_module::User; .. _ref_eql_set_distinct: Multisets --------- .. api-index:: distinct Technically sets in Gel are actually *multisets*, because they can contain duplicates of the same element. To eliminate duplicates, use the :eql:op:`distinct` set operator. .. code-block:: edgeql-repl db> select {'aaa', 'aaa', 'aaa'}; {'aaa', 'aaa', 'aaa'} db> select distinct {'aaa', 'aaa', 'aaa'}; {'aaa'} ..
_ref_eql_set_in: Checking membership ------------------- .. api-index:: §element§ in §set§ Use the :eql:op:`in` operator to check whether a set contains a particular element. .. code-block:: edgeql-repl db> select 'aaa' in {'aaa', 'bbb', 'ccc'}; {true} db> select 'ddd' in {'aaa', 'bbb', 'ccc'}; {false} .. _ref_eql_set_union: Merging sets ------------ .. api-index:: union Use the :eql:op:`union` operator to merge two sets. .. code-block:: edgeql-repl db> select 'aaa' union 'bbb' union 'ccc'; {'aaa', 'bbb', 'ccc'} db> select {1, 2} union {3.1, 4.4}; {1.0, 2.0, 3.1, 4.4} Finding common members ---------------------- .. api-index:: intersect Use the :eql:op:`intersect` operator to find common members between two sets. .. code-block:: edgeql-repl db> select {1, 2, 3, 4, 5} intersect {3, 4, 5, 6, 7}; {3, 5, 4} db> select {'a', 'b', 'c', 'd', 'e'} intersect {'c', 'd', 'e', 'f', 'g'}; {'e', 'd', 'c'} If set members are repeated in both sets, they will be repeated in the set produced by :eql:op:`intersect` the same number of times they are repeated in both of the operand sets. .. code-block:: edgeql-repl db> select {0, 1, 1, 1, 2, 3, 3} intersect {1, 3, 3, 3, 3, 3}; {1, 3, 3} In this example, ``1`` appears three times in the first set but only once in the second, so it appears only once in the result. ``3`` appears twice in the first set and five times in the second. Both ``3`` appearances in the first set are overlapped by ``3`` appearances in the second, so they both end up in the resulting set. Removing common members ----------------------- .. api-index:: except Use the :eql:op:`except` operator to leave only the members in the first set that do not appear in the second set. .. 
code-block:: edgeql-repl db> select {1, 2, 3, 4, 5} except {3, 4, 5, 6, 7}; {1, 2} db> select {'a', 'b', 'c', 'd', 'e'} except {'c', 'd', 'e', 'f', 'g'}; {'b', 'a'} When :eql:op:`except` eliminates a common member that is repeated, it never eliminates more than the number of instances of that member appearing in the second set. .. code-block:: edgeql-repl db> select {0, 1, 1, 1, 2, 3, 3} except {1, 3, 3, 3, 3, 3}; {0, 1, 1, 2} In this example, both sets share the member ``1``. The first set contains three of them while the second contains only one. The result retains two ``1`` members from the first set since the sets shared only a single ``1`` in common. The second set has five ``3`` members to the first set's two, so both of the first set's ``3`` members are eliminated from the resulting set. .. _ref_eql_set_coalesce: Coalescing ---------- .. index:: empty set, default values, optional .. api-index:: ?? Occasionally in queries, you need to handle the case where a set is empty. This can be achieved with the coalescing operator :eql:op:`?? `. This is commonly used to provide default values for optional :ref:`query parameters `. .. code-block:: edgeql-repl db> select 'value' ?? 'default'; {'value'} db> select <str>{} ?? 'default'; {'default'} .. note:: Coalescing is an example of a function/operator with :ref:`optional inputs `. By default, passing an empty set into a function/operator will "short circuit" the operation and return an empty set. However, it's possible to mark inputs as *optional*, in which case the operation will be defined over empty sets. Another example is :eql:func:`count`, which returns ``{0}`` when an empty set is passed as input. .. _ref_eql_set_type_filter: Inheritance ----------- .. index:: type intersection, backlinks .. api-index:: §expr§[is §type§] |Gel| schemas support :ref:`inheritance `; types (usually object types) can extend one or more other types.
For instance, you may declare an abstract object type ``Media`` that is extended by ``Movie`` and ``TVShow``. .. code-block:: sdl abstract type Media { required title: str; } type Movie extending Media { release_year: int64; } type TVShow extending Media { num_seasons: int64; } A set of type ``Media`` may contain both ``Movie`` and ``TVShow`` objects. .. code-block:: edgeql-repl db> select Media; { default::Movie {id: 9d2ce01c-35e8-11ec-acc3-83b1377efea0}, default::Movie {id: 3bfe4900-3743-11ec-90ee-cb73d2740820}, default::TVShow {id: b0e0dd0c-35e8-11ec-acc3-abf1752973be}, } We can use the *type intersection* operator ``[is <type>]`` to restrict the elements of a set by subtype. .. code-block:: edgeql-repl db> select Media[is Movie]; { default::Movie {id: 9d2ce01c-35e8-11ec-acc3-83b1377efea0}, default::Movie {id: 3bfe4900-3743-11ec-90ee-cb73d2740820}, } db> select Media[is TVShow]; { default::TVShow {id: b0e0dd0c-35e8-11ec-acc3-abf1752973be} } Type filters are commonly used in conjunction with :ref:`backlinks `. .. _ref_eql_set_aggregate: Aggregate vs element-wise operations ------------------------------------ .. index:: cartesian product EdgeQL provides a large library of built-in functions and operators for handling data structures. It's useful to consider functions/operators as either *aggregate* or *element-wise*. .. note:: This is an over-simplification, but it's a useful mental model when just starting out with Gel. For a more complete guide, see :ref:`Reference > Cardinality `. *Aggregate* operations are applied to the set *as a whole*; they accept a set with arbitrary cardinality and return a *singleton* (or perhaps an empty set if the input was also empty). .. code-block:: edgeql-repl db> select count({'aaa', 'bbb'}); {2} db> select sum({1, 2, 3}); {6} db> select min({1, 2, 3}); {1} Element-wise operations are applied on *each element* of a set. ..
code-block:: edgeql-repl db> select str_upper({'aaa', 'bbb'}); {'AAA', 'BBB'} db> select {1, 2, 3} ^ 2; {1, 4, 9} db> select str_split({"hello world", "hi again"}, " "); {["hello", "world"], ["hi", "again"]} When an *element-wise* operation accepts two or more inputs, the operation is applied to all possible combinations of inputs; in other words, the operation is applied to the *Cartesian product* of the inputs. .. code-block:: edgeql-repl db> select {'aaa', 'bbb'} ++ {'ccc', 'ddd'}; {'aaaccc', 'aaaddd', 'bbbccc', 'bbbddd'} Accordingly, operations involving an empty set typically return an empty set. In contrast, aggregate operations like :eql:func:`count` are able to operate on empty sets. .. code-block:: edgeql-repl db> select <str>{} ++ 'ccc'; {} db> select count(<str>{}); {0} For a more complete discussion of cardinality, see :ref:`Reference > Cardinality `. .. _ref_eql_set_array_conversion: Conversion to/from arrays ------------------------- .. api-index:: array_unpack, array_agg Both arrays and sets are collections of values that share a type. EdgeQL provides ways to convert one into the other. .. note:: Remember that *all values* in EdgeQL are sets; an array literal is just a singleton set of arrays. So here, "converting" a set into an array means converting a set of type ``x`` into another set with cardinality ``1`` (a singleton) and type ``array<x>``. .. code-block:: edgeql-repl db> select array_unpack([1,2,3]); {1, 2, 3} db> select array_agg({1,2,3}); {[1, 2, 3]} Arrays are an *ordered collection*, whereas sets are generally unordered (unless explicitly sorted with an ``order by`` clause in a :ref:`select ` statement). Element-wise scalar operations in the standard library cannot be applied to arrays, so sets of scalars are typically easier to manipulate, search, and transform than arrays. ..
code-block:: edgeql-repl db> select str_trim({' hello', 'world '}); {'hello', 'world'} db> select str_trim([' hello', 'world ']); error: QueryError: function "str_trim(arg0: array<std::str>)" does not exist Some :ref:`aggregate ` operations have analogs that operate on arrays. For instance, the set function :eql:func:`count` is analogous to the array function :eql:func:`len`. Reference --------- .. list-table:: * - Set operators - :eql:op:`distinct` :eql:op:`in` :eql:op:`union` :eql:op:`exists` :eql:op:`if..else` :eql:op:`?? ` :eql:op:`detached` :eql:op:`[is type] ` * - Utility functions - :eql:func:`count` :eql:func:`enumerate` * - Cardinality assertion - :eql:func:`assert_distinct` :eql:func:`assert_single` :eql:func:`assert_exists` ================================================ FILE: docs/reference/edgeql/transactions.rst ================================================ .. _ref_eql_transactions: Transactions ============ .. api-index:: start transaction, declare savepoint, release savepoint, rollback to savepoint, rollback, commit EdgeQL supports atomic transactions. The transaction API consists of several commands: :eql:stmt:`start transaction` Start a transaction, specifying the isolation level, access mode (``read only`` vs ``read write``), and deferrability. :eql:stmt:`declare savepoint` Establish a new savepoint within the current transaction. A savepoint is an intermediate point in a transaction flow that provides the ability to partially roll back a transaction. :eql:stmt:`release savepoint` Destroy a savepoint previously defined in the current transaction. :eql:stmt:`rollback to savepoint` Roll back to the named savepoint. All changes made after the savepoint are discarded. The savepoint remains valid and can be rolled back to again later, if needed. :eql:stmt:`rollback` Roll back the entire transaction. All updates made within the transaction are discarded. :eql:stmt:`commit` Commit the transaction.
All changes made by the transaction become visible to others and will persist if a crash occurs. Client libraries ---------------- There is rarely a reason to use these commands directly. All Gel client libraries provide dedicated transaction APIs that handle transaction creation under the hood. Examples below show a transaction that sends 10 cents from the account of a ``BankCustomer`` called ``'Customer1'`` to a ``BankCustomer`` called ``'Customer2'``. The equivalent Gel schema and queries are: .. code-block:: module default { type BankCustomer { required name: str; required bank_balance: int64; } } update BankCustomer filter .name = 'Customer1' set { bank_balance := .bank_balance -10 }; update BankCustomer filter .name = 'Customer2' set { bank_balance := .bank_balance +10 } TypeScript/JS ^^^^^^^^^^^^^ Using an EdgeQL query string: .. code-block:: typescript client.transaction(async tx => { await tx.execute(`update BankCustomer filter .name = 'Customer1' set { bank_balance := .bank_balance -10 }`); await tx.execute(`update BankCustomer filter .name = 'Customer2' set { bank_balance := .bank_balance +10 }`); }); Using the querybuilder: .. code-block:: typescript const query1 = e.update(e.BankCustomer, () => ({ filter_single: { name: "Customer1" }, set: { bank_balance: { "-=": 10 } }, })); const query2 = e.update(e.BankCustomer, () => ({ filter_single: { name: "Customer2" }, set: { bank_balance: { "+=": 10 } }, })); client.transaction(async (tx) => { await query1.run(tx); await query2.run(tx); }); Full documentation at :ref:`Client Libraries > TypeScript/JS `. Python ^^^^^^ .. code-block:: python async for tx in client.transaction(): async with tx: await tx.execute("""update BankCustomer filter .name = 'Customer1' set { bank_balance := .bank_balance -10 };""") await tx.execute("""update BankCustomer filter .name = 'Customer2' set { bank_balance := .bank_balance +10 };""") Full documentation at :ref:`Client Libraries > Python `. Golang ^^^^^^ .. 
code-block:: go err = client.Tx(ctx, func(ctx context.Context, tx *gel.Tx) error { query1 := `update BankCustomer filter .name = 'Customer1' set { bank_balance := .bank_balance -10 };` if e := tx.Execute(ctx, query1); e != nil { return e } query2 := `update BankCustomer filter .name = 'Customer2' set { bank_balance := .bank_balance +10 };` if e := tx.Execute(ctx, query2); e != nil { return e } return nil }) if err != nil { log.Fatal(err) } Full documentation at `Client Libraries > Go `_. Rust ^^^^ .. code-block:: rust let balance_change_query = "update BankCustomer filter .name = $0 set { bank_balance := .bank_balance + $1 }"; client .transaction(|mut conn| async move { conn.execute(balance_change_query, &("Customer1", -10)) .await .expect("Execute should have worked"); conn.execute(balance_change_query, &("Customer2", 10)) .await .expect("Execute should have worked"); Ok(()) }) .await .expect("Transaction should have worked"); .. XXX: Add Rust docs .. Full documentation at :ref:`Client Libraries > Rust `. .. _gel-go: https://pkg.go.dev/github.com/geldata/gel-go ================================================ FILE: docs/reference/edgeql/types.rst ================================================ .. _ref_eql_types: ===== Types ===== The foundation of EdgeQL is Gel's rigorous type system. There is a set of EdgeQL operators and functions for changing, introspecting, and filtering by types. .. _ref_eql_types_names: Type expressions ---------------- .. api-index:: array<§type§>, tuple<§type [\, ...]§> Type expressions are exactly what they sound like: EdgeQL expressions that refer to a type. Most commonly, these are simply the *names* of established types: ``str``, ``int64``, ``BlogPost``, etc. Arrays and tuples have a dedicated type syntax. .. list-table:: * - **Type** - **Syntax** * - Array - ``array`` * - Tuple (unnamed) - ``tuple`` * - Tuple (named) - ``tuple`` For additional details on type syntax, see :ref:`Schema > Primitive Types `. .. 
_ref_eql_types_typecast: Type casting ------------ .. index:: casts, find object by id .. api-index:: <§type§>§expr§ Type casting is used to convert primitive values into another type. Casts are indicated with angle brackets containing a type expression. .. code-block:: edgeql-repl db> select <str>10; {"10"} db> select <bigint>10; {10n} db> select <array<str>>[1, 2, 3]; {['1', '2', '3']} db> select <tuple<str, int64, bigint>>(1, 2, 3); {('1', 2, 3n)} Type casts are useful for declaring literals for types like ``datetime``, ``uuid``, and ``int16`` that don't have a dedicated syntax. .. code-block:: edgeql-repl db> select <datetime>'1999-03-31T15:17:00Z'; {<datetime>'1999-03-31T15:17:00Z'} db> select <int16>42; {42} db> select <uuid>'89381587-705d-458f-b837-860822e1b219'; {89381587-705d-458f-b837-860822e1b219} There are limits to what values can be cast to a certain type. In some cases two types are entirely incompatible, like ``bool`` and ``int64``; in other cases, the source data must be in a particular format, like casting ``str`` to ``datetime``. For a comprehensive table of castability, see :ref:`Standard Library > Casts `. Type casts can only be used on primitive expressions, not object type expressions. Every object stored in the database is strongly and immutably typed; you can't simply convert an object to an object of a different type. .. code-block:: edgeql-repl db> select <BlogPost>10; QueryError: cannot cast 'std::int64' to 'default::BlogPost' db> select <int64>'asdf'; InvalidValueError: invalid input syntax for type std::int64: "asdf" db> select <int16>100000000000000n; NumericOutOfRangeError: std::int16 out of range .. lint-off You can cast a UUID into an object: .. code-block:: edgeql-repl db> select <Hero><uuid>'01d9cc22-b776-11ed-8bef-73f84c7e91e7'; {default::Hero {id: 01d9cc22-b776-11ed-8bef-73f84c7e91e7}} If you try to cast a UUID that no object of the type has as its ``id`` property, you'll get an error: .. 
code-block:: edgeql-repl db> select 'aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa'; gel error: CardinalityViolationError: 'default::Hero' with id 'aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa' does not exist .. lint-on .. _ref_eql_types_intersection: Type intersections ------------------ .. api-index:: [is §type§] All elements of a given set have the same type; however, in the context of *sets of objects*, this type might be ``abstract`` and contain elements of multiple concrete subtypes. For instance, a set of ``Media`` objects may contain both ``Movie`` and ``TVShow`` objects. .. code-block:: edgeql-repl db> select Media; { default::Movie {id: 9d2ce01c-35e8-11ec-acc3-83b1377efea0}, default::Movie {id: 3bfe4900-3743-11ec-90ee-cb73d2740820}, default::TVShow {id: b0e0dd0c-35e8-11ec-acc3-abf1752973be}, } We can use the *type intersection* operator to restrict the elements of a set by subtype. .. code-block:: edgeql-repl db> select Media[is Movie]; { default::Movie {id: 9d2ce01c-35e8-11ec-acc3-83b1377efea0}, default::Movie {id: 3bfe4900-3743-11ec-90ee-cb73d2740820}, } Logically, this computes the intersection of the ``Media`` and ``Movie`` sets; since only ``Movie`` objects occur in both sets, this can be conceptualized as a "filter" that removes all elements that aren't of type ``Movie``. .. Type unions .. ----------- .. You can create a type union with the pipe operator: :eql:op:`type | type .. `. This is mostly commonly used for object types. .. .. code-block:: edgeql-repl .. db> select 5 is int32 | int64; .. {true} .. db> select Media is Movie | TVShow; .. {true, true, true} Type checking ------------- .. api-index:: §expr§ is §type§, §expr§ is not §type§ The ``[is foo]`` "type intersection" syntax should not be confused with the *type checking* operator :eql:op:`is`. .. code-block:: edgeql-repl db> select 5 is int64; {true} db> select {3.14, 2.718} is not int64; {true, true} db> select Media is Movie; {true, true, false} The ``typeof`` operator ----------------------- .. 
api-index:: typeof §expr§ The type of any expression can be extracted with the :eql:op:`typeof` operator. This can be used in any expression that expects a type. .. code-block:: edgeql-repl db> select '100'; {100} db> select "tuna" is typeof "trout"; {true} Introspection ------------- The entire type system of Gel is *stored inside Gel*. All types are introspectable as instances of the ``schema::Type`` type. For a set of introspection examples, see :ref:`Guides > Introspection `. ================================================ FILE: docs/reference/edgeql/update.rst ================================================ .. _ref_eql_update: Update ====== .. api-index:: update, filter, set The ``update`` command is used to update existing objects. .. code-block:: edgeql-repl db> update Hero ... filter .name = "Hawkeye" ... set { name := "Ronin" }; {default::Hero {id: d476b12e-3e7b-11ec-af13-2717f3dc1d8a}} If you omit the ``filter`` clause, all objects will be updated. This is useful for updating values across all objects of a given type. The example below cleans up all ``Hero.name`` values by trimming whitespace and converting them to title case. .. code-block:: edgeql-repl db> update Hero ... set { name := str_trim(str_title(.name)) }; {default::Hero {id: d476b12e-3e7b-11ec-af13-2717f3dc1d8a}} Syntax ^^^^^^ The structure of the ``update`` statement (``update...filter...set``) is an intentional inversion of SQL's ``UPDATE...SET...WHERE`` syntax. Curiously, in SQL, the ``where`` clauses typically occur *last* despite being applied before the ``set`` statement. EdgeQL is structured to reflect this; first, a target set is specified, then filters are applied, then the data is updated. Updating properties ------------------- To explicitly unset a property that is not required, set it to an empty set. .. code-block:: edgeql update Person filter .id = $id set { middle_name := {} }; Updating links -------------- .. 
api-index:: :=, +=, -= When updating links, the ``:=`` operator will *replace* the set of linked values. .. code-block:: edgeql-repl db> update Movie ... filter .title = "Black Widow" ... set { ... characters := ( ... select Person ... filter .name in { "Black Widow", "Yelena", "Dreykov" } ... ) ... }; {default::Movie {id: af706c7c-3e98-11ec-abb3-4bbf3f18a61a}} db> select Movie { num_characters := count(.characters) } ... filter .title = "Black Widow"; {default::Movie {num_characters: 3}} To add additional linked items, use the ``+=`` operator. .. code-block:: edgeql-repl db> update Movie ... filter .title = "Black Widow" ... set { ... characters += (insert Villain {name := "Taskmaster"}) ... }; {default::Movie {id: af706c7c-3e98-11ec-abb3-4bbf3f18a61a}} db> select Movie { num_characters := count(.characters) } ... filter .title = "Black Widow"; {default::Movie {num_characters: 4}} To remove items, use ``-=``. .. code-block:: edgeql-repl db> update Movie ... filter .title = "Black Widow" ... set { ... characters -= Villain # remove all villains ... }; {default::Movie {id: af706c7c-3e98-11ec-abb3-4bbf3f18a61a}} db> select Movie { num_characters := count(.characters) } ... filter .title = "Black Widow"; {default::Movie {num_characters: 2}} Returning data on update ------------------------ By default, ``update`` returns only the ``id`` of each updated object, as seen in the examples above. If you want to get additional data back, you may wrap your ``update`` with a ``select`` and apply a shape specifying any properties and links you want returned: .. code-block:: edgeql-repl db> select (update Hero ... filter .name = "Hawkeye" ... set { name := "Ronin" } ... ) {id, name}; { default::Hero { id: d476b12e-3e7b-11ec-af13-2717f3dc1d8a, name: "Ronin" } } With blocks ----------- All top-level EdgeQL statements (``select``, ``insert``, ``update``, and ``delete``) can be prefixed with a ``with`` block. This is useful for updating the results of a complex query. .. 
code-block:: edgeql-repl db> with people := ( ... select Person ... order by .name ... offset 3 ... limit 3 ... ) ... update people ... set { name := str_trim(.name) }; { default::Hero {id: d4764c66-3e7b-11ec-af13-df1ba5b91187}, default::Hero {id: d7d7e0f6-40ae-11ec-87b1-3f06bed494b9}, default::Villain {id: d477a836-3e7b-11ec-af13-4fea611d1c31}, } .. note:: You can pass any object-type expression into ``update``, including polymorphic ones (as above). You can also use ``with`` to make returning additional data from an update more readable: .. code-block:: edgeql-repl db> with UpdatedHero := (update Hero ... filter .name = "Hawkeye" ... set { name := "Ronin" } ... ) ... select UpdatedHero { ... id, ... name ... }; { default::Hero { id: d476b12e-3e7b-11ec-af13-2717f3dc1d8a, name: "Ronin" } } See also -------- For documentation on performing *upsert* operations, see :ref:`EdgeQL > Insert > Upserts `. .. list-table:: * - :ref:`Reference > Commands > Update ` * - :ref:`Cheatsheets > Updating data ` ================================================ FILE: docs/reference/edgeql/with.rst ================================================ .. _ref_eql_with: With ==== .. index:: composition, composing queries, composable, CTE, common table expressions, subquery, subqueries .. api-index:: with All top-level EdgeQL statements (``select``, ``insert``, ``update``, and ``delete``) can be prefixed by a ``with`` block. These blocks contain declarations of standalone expressions that can be used in your query. .. code-block:: edgeql-repl db> with my_str := "hello world" ... select str_title(my_str); {'Hello World'} The ``with`` clause can contain more than one variable. Earlier variables can be referenced by later ones. Taken together, it becomes possible to write "script-like" queries that execute several statements in sequence. .. code-block:: edgeql-repl db> with a := 5, ... b := 2, ... c := a ^ b ... 
select c; {25} Subqueries ^^^^^^^^^^ There's no limit to the complexity of computed expressions. EdgeQL is fully composable; queries can simply be embedded inside each other. The following query fetches a list of all movies featuring at least one of the original six Avengers. .. code-block:: edgeql-repl db> with avengers := (select Hero filter .name in { ... 'Iron Man', ... 'Black Widow', ... 'Captain America', ... 'Thor', ... 'Hawkeye', ... 'The Hulk' ... }) ... select Movie {title} ... filter avengers in .characters; { default::Movie {title: 'Iron Man'}, default::Movie {title: 'The Incredible Hulk'}, default::Movie {title: 'Iron Man 2'}, default::Movie {title: 'Thor'}, default::Movie {title: 'Captain America: The First Avenger'}, ... } .. _ref_eql_with_params: Query parameters ^^^^^^^^^^^^^^^^ A common use case for ``with`` clauses is the initialization of :ref:`query parameters `. .. code-block:: edgeql with user_id := <uuid>$user_id select User { name } filter .id = user_id; For a full reference on using query parameters, see :ref:`EdgeQL > Parameters `. Module alias ^^^^^^^^^^^^ .. api-index:: with, as module Another use of ``with`` is to provide aliases for modules. This can be useful for long queries which reuse many objects or functions from the same module. .. code-block:: edgeql with http as module std::net::http select http::ScheduledRequest filter .method = http::Method.POST; If the aliased module does not exist at the top level, but does exist as a part of the ``std`` module, that will be used automatically. .. code-block:: edgeql with http as module net::http # <- omitting std select http::ScheduledRequest filter .method = http::Method.POST; Module selection ^^^^^^^^^^^^^^^^ .. index:: fully-qualified names .. api-index:: with module By default, the *active module* is ``default``, so all schema objects inside this module can be referenced by their *short name*, e.g. ``User``, ``BlogPost``, etc.
To reference objects in other modules, we must use fully-qualified names (``default::Hero``). However, ``with`` clauses also provide a mechanism for changing the *active module* on a per-query basis. .. code-block:: edgeql-repl db> with module schema ... select ObjectType; This ``with module`` clause changes the default module to schema, so we can refer to ``schema::ObjectType`` (a built-in Gel type) as simply ``ObjectType``. As with module aliases, if the active module does not exist at the top level, but does exist as part of the ``std`` module, that will be used automatically. .. code-block:: edgeql-repl db> with module math select abs(-1); {1} .. list-table:: :class: seealso * - **See also** * - :ref:`Reference > Commands > With ` ================================================ FILE: docs/reference/index.rst ================================================ ========= Reference ========= .. toctree:: :maxdepth: 3 :hidden: using/index running/index datamodel/index edgeql/index stdlib/index ai/index auth/index reference/index Learn three components, and you know |Gel|: how to work with :ref:`schema `, how to write queries with :ref:`EdgeQL `, and what's available to you in our :ref:`standard library `. Start in those sections if you're new to |Gel|. Move over to our :ref:`reference ` when you're ready to dive deep into the internals, syntax, and other advanced topics. Schema ------ |Gel| schemas are declared using our schema definition language (SDL). .. code-block:: sdl module default { type Book { required title: str; release_year: int16; author: Person; } type Person { required name: str; } } The example schema above defines two types: Book and Person, each with a property or two. Book also contains a link to the author, which is a link to objects of the Person type. Learn more about how to define your schema using SDL in the :ref:`schema ` section. 
EdgeQL ------ EdgeQL is a next-generation query language designed to match SQL in power and surpass it in terms of clarity, brevity, and intuitiveness. .. code-block:: edgeql-repl db> select Book { ... title, ... release_year, ... author: { ... name ... } ... } order by .title; { default::Book { title: '1984', release_year: 1949, author: default::Person { name: 'George Orwell' } }, default::Book { title: 'Americanah', release_year: 2013, author: default::Person { name: 'Chimamanda Ngozi Adichie' } }, ... } You can use EdgeQL to easily return nested data structures just by putting a shape with a link on an object as shown above. Standard library ---------------- |Gel| comes with a rigorously defined type system consisting of scalar types, collection types (like arrays and tuples), and object types. It also includes a library of built-in functions and operators for working with each datatype, alongside some additional utilities and extensions. .. code-block:: edgeql-repl db> select count(Book); {16} db> select Book { ... title, ... title_length := len(.title) ... } order by .title_length; { default::Book { title: 'Sula', title_length: 4 }, default::Book { title: '1984', title_length: 4 }, default::Book { title: 'Beloved', title_length: 7 }, default::Book { title: 'The Fellowship of the Ring', title_length: 26 }, default::Book { title: 'One Hundred Years of Solitude', title_length: 29 }, } db> select math::stddev(len(Book.title)); {7.298401651503339} Cheatsheets ----------- Learn to do various common tasks using the many tools included with |Gel|.
Querying ^^^^^^^^ * :ref:`Select ` * :ref:`Insert ` * :ref:`Update ` * :ref:`Delete ` * :ref:`via GraphQL ` Schema ^^^^^^ * :ref:`Booleans ` * :ref:`Object Types ` * :ref:`Functions ` * :ref:`Aliases ` * :ref:`Annotations ` * :ref:`Link Properties ` Admin ^^^^^ * :ref:`CLI ` * :ref:`REPL ` * :ref:`Admin ` ================================================ FILE: docs/reference/reference/edgeql/analyze.rst ================================================ .. _ref_eql_statements_analyze: Analyze ======= :eql-statement: ``analyze`` -- trigger performance analysis of the appended query .. eql:synopsis:: analyze ; # where is any EdgeQL query Description ----------- ``analyze`` returns a table with performance metrics broken down by node. You may prepend the ``analyze`` keyword in either of our REPLs (CLI or :ref:`UI `) or you may prepend in the UI's query builder for a helpful visualization of your query's performance. After any ``analyze`` in a REPL, run the ``\expand`` command to see fine-grained performance analysis of the previously analyzed query. Example ------- .. code-block:: edgeql-repl db> analyze select Hero { ... name, ... secret_identity, ... villains: { ... name, ... nemesis: { ... name ... } ... } ... }; ──────────────────────────────────────── Query ──────────────────────────────────────── analyze select ➊ Hero {name, secret_identity, ➋ villains: {name, ➌ nemesis: {name}}}; ──────────────────────── Coarse-grained Query Plan ──────────────────────── │ Time Cost Loops Rows Width │ Relations ➊ root │ 0.0 69709.48 1.0 0.0 32 │ Hero ╰──➋ .villains │ 0.0 92.9 0.0 0.0 32 │ Villain, Hero.villains ╰──➌ .nemesis │ 0.0 8.18 0.0 0.0 32 │ Hero .. list-table:: :class: seealso * - **See also** * - :ref:`CLI > gel analyze ` * - :ref:`EdgeQL > Analyze ` ================================================ FILE: docs/reference/reference/edgeql/cardinality.rst ================================================ .. 
_ref_reference_cardinality: Cardinality =========== The number of items in a set is known as its **cardinality**. A set with a cardinality of zero is referred to as an **empty set**. A set with a cardinality of one is known as a **singleton**. Terminology ----------- The term **cardinality** is used to refer to either the *exact* number of elements in a given set or a *range* of possible values. Internally, Gel tracks 5 different cardinality ranges: ``Empty`` (zero elements), ``One`` (a singleton set), ``AtMostOne`` (zero or one elements), ``AtLeastOne`` (one or more elements), and ``Many`` (any number of elements). |Gel| uses this information to statically check queries for validity. For instance, when assigning to a ``required multi`` link, the value being assigned *must* have a cardinality of ``One`` or ``AtLeastOne`` (as empty sets are not permitted). .. _ref_reference_cardinality_functions_operators: Functions and operators ----------------------- It's often useful to think of Gel functions/operators as either *element-wise* or *aggregate*. Element-wise operations are applied to *each item* in a set. Aggregate operations operate on sets *as a whole*. .. note:: This is a simplification, but it's a useful mental model when getting started with Gel. .. _ref_reference_cardinality_aggregate: Aggregate operations ^^^^^^^^^^^^^^^^^^^^ An example of an aggregate function is :eql:func:`count`. It returns the number of elements in a given set. Regardless of the size of the input set, the result is a singleton integer. .. code-block:: edgeql-repl db> select count('hello'); {1} db> select count({'this', 'is', 'a', 'set'}); {4} db> select count({}); {0} Another example is :eql:func:`array_agg`, which converts a *set* of elements into a singleton array. .. code-block:: edgeql-repl db> select array_agg({1,2,3}); {[1, 2, 3]} .. 
_ref_reference_cardinality_elementwise: Element-wise operations ^^^^^^^^^^^^^^^^^^^^^^^ By contrast, the :eql:func:`len` function is element-wise: it computes the length of each string inside a set of strings; as such, it converts a set of :eql:type:`str` into an equally-sized set of :eql:type:`int64`. .. code-block:: edgeql-repl db> select len('hello'); {5} db> select len({'hello', 'world'}); {5, 5} .. _ref_reference_cardinality_cartesian: Cartesian products ^^^^^^^^^^^^^^^^^^ For element-wise operations that accept multiple arguments, the operation is applied to the Cartesian product of all the input sets. .. code-block:: edgeql-repl db> select {'aaa', 'bbb'} ++ {'ccc', 'ddd'}; {'aaaccc', 'aaaddd', 'bbbccc', 'bbbddd'} db> select {true, false} or {true, false}; {true, true, true, false} By extension, if any of the input sets are empty, the result of applying an element-wise function is also empty. In effect, when Gel detects an empty set, it "short-circuits" and returns an empty set without applying the operation. .. code-block:: edgeql-repl db> select {} ++ {'ccc', 'ddd'}; {} db> select {} or {true, false}; {} .. note:: Certain functions and operators avoid this "short-circuit" behavior by marking their inputs as :ref:`optional `. A notable example of an operator with optional inputs is the :eql:op:`?? ` operator. .. code-block:: edgeql-repl db> select {} ?? 'default'; {'default'} Per-input cardinality ^^^^^^^^^^^^^^^^^^^^^ Ultimately, the distinction between "aggregate vs element-wise" operations is a false one. Consider the :eql:op:`in` operation. .. code-block:: edgeql-repl db> select {1, 4} in {1, 2, 3}; {true, false} This operator takes two inputs. If it were "element-wise", we would expect the cardinality of the above operation to be the Cartesian product of the input cardinalities: ``2 x 3 = 6``. If it were aggregate, we'd expect a singleton output. Instead, the cardinality is ``2``.
This operator is element-wise with respect to the first input and aggregate with respect to the second. The "element-wise vs aggregate" concept isn't determined on a per-function/per-operator basis; it is determined on a per-input basis. Type qualifiers ^^^^^^^^^^^^^^^ When defining functions, all inputs are element-wise by default. The ``set of`` :ref:`type qualifier ` is used to designate an input as *aggregate*. Currently this modifier is not supported for user-defined functions, but it is used by certain standard library functions. Similarly, the ``optional`` qualifier marks the input as optional; an operation will be executed even if an optional input is empty, whereas passing an empty set for a "standard" (non-optional) element-wise input will always result in an empty set. The *output* of a function :ref:`can also be annotated ` with ``set of`` and ``optional`` qualifiers. Cardinality computation ^^^^^^^^^^^^^^^^^^^^^^^ To compute the number of times a function/operator will be invoked, take the cardinality of each input, apply the following transformation based on its type qualifier (or lack thereof), and multiply the resulting values: .. code-block:: element-wise: N -> N optional: N -> max(1, N) aggregate: N -> 1 The final result is the union of the results of each invocation; as such, its cardinality depends on the *values returned* by each invocation.
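The transformation rules above can be sketched in Python. This is only an illustrative model, not part of Gel or any client library: the names ``Qualifier``, ``effective_count``, and ``invocation_count`` are hypothetical, and the sketch assumes that the number of invocations is the product of the transformed per-input values, in line with the Cartesian-product behavior described earlier.

```python
from enum import Enum
from math import prod


class Qualifier(Enum):
    """Hypothetical labels for the three kinds of input described above."""
    ELEMENT_WISE = "element-wise"
    OPTIONAL = "optional"
    AGGREGATE = "set of"


def effective_count(qualifier: Qualifier, n: int) -> int:
    """Apply the per-input transformation to an input of cardinality n."""
    if qualifier is Qualifier.ELEMENT_WISE:
        return n              # N -> N
    if qualifier is Qualifier.OPTIONAL:
        return max(1, n)      # N -> max(1, N)
    return 1                  # aggregate: N -> 1


def invocation_count(inputs: list[tuple[Qualifier, int]]) -> int:
    """Assumed model: invocations = product of the transformed values."""
    return prod(effective_count(q, n) for q, n in inputs)


# {'aaa', 'bbb'} ++ {'ccc', 'ddd'}: both inputs are element-wise.
concat = invocation_count([(Qualifier.ELEMENT_WISE, 2),
                           (Qualifier.ELEMENT_WISE, 2)])

# {1, 4} in {1, 2, 3}: element-wise on the left, aggregate on the right.
membership = invocation_count([(Qualifier.ELEMENT_WISE, 2),
                               (Qualifier.AGGREGATE, 3)])

# An empty element-wise input short-circuits the whole operation.
short_circuit = invocation_count([(Qualifier.ELEMENT_WISE, 0),
                                  (Qualifier.ELEMENT_WISE, 2)])

# An empty optional input still triggers one invocation.
optional_empty = invocation_count([(Qualifier.OPTIONAL, 0),
                                   (Qualifier.ELEMENT_WISE, 1)])
```

Under this model the concatenation is invoked four times and the ``in`` check twice, matching the result sizes shown in the examples above. Keep in mind that this counts invocations only; the size of the final set still depends on what each invocation returns.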
================================================ FILE: docs/reference/reference/edgeql/casts.csv ================================================ from \ to,:eql:type:`json `,:eql:type:`str `,:eql:type:`float32 `,:eql:type:`float64 `,:eql:type:`int16 `,:eql:type:`int32 `,:eql:type:`int64 `,:eql:type:`bigint `,:eql:type:`decimal `,:eql:type:`bool `,:eql:type:`bytes `,:eql:type:`uuid `,:eql:type:`datetime `,:eql:type:`duration `,:eql:type:`local_date `,:eql:type:`local_datetime `,:eql:type:`local_time `,:eql:type:`relative_duration `,:eql:type:`date_duration `,:eql:type:`enum`,object :eql:type:`json `,,``<>``,``<>``,``<>``,``<>``,``<>``,``<>``,``<>``,``<>``,``<>``,``<>``,``<>``,``<>``,``<>``,``<>``,``<>``,``<>``,``<>``,``<>``,``<>``, :eql:type:`str `,``<>``,,``<>``,``<>``,``<>``,``<>``,``<>``,``<>``,``<>``,``<>``,,``<>``,``<>``,``<>``,``<>``,``<>``,``<>``,``<>``,``<>``,``:=``, :eql:type:`float32 `,``<>``,``<>``,,impl,``<>*``,``<>*``,``<>*``,``<>*``,``<>``,,,,,,,,,,,, :eql:type:`float64 `,``<>``,``<>``,``:=``,,``<>*``,``<>*``,``<>*``,``<>*``,``<>``,,,,,,,,,,,, :eql:type:`int16 `,``<>``,``<>``,impl,impl,,impl,impl,impl,impl,,,,,,,,,,,, :eql:type:`int32 `,``<>``,``<>``,,impl,``<>``,,impl,impl,impl,,,,,,,,,,,, :eql:type:`int64 `,``<>``,``<>``,``:=``,impl,``:=``,``:=``,,impl,impl,,,,,,,,,,,, :eql:type:`bigint `,,,,,,,,,impl,,,,,,,,,,,, :eql:type:`decimal `,``<>``,``<>``,``<>``,``<>``,``<>``,``<>``,``<>``,``<>``,,,,,,,,,,,,, :eql:type:`bool `,``<>``,``<>``,,,,,,,,,,,,,,,,,,, :eql:type:`bytes `,``<>``,,,,,,,,,,,,,,,,,,,, :eql:type:`uuid `,``<>``,``<>``,,,,,,,,,,,,,,,,,,,``<>`` :eql:type:`datetime `,``<>``,``<>``,,,,,,,,,,,,,,,,,,, :eql:type:`duration `,``<>``,``<>``,,,,,,,,,,,,,,,,``<>``,,, :eql:type:`local_date `,``<>``,``<>``,,,,,,,,,,,,,,impl,,,,, :eql:type:`local_datetime `,``<>``,``<>``,,,,,,,,,,,,,``<>``,,``<>``,,,, :eql:type:`local_time `,``<>``,``<>``,,,,,,,,,,,,,,,,,,, :eql:type:`relative_duration `,``<>``,``<>``,,,,,,,,,,,,``<>``,,,,,``<>``,, 
:eql:type:`date_duration `,``<>``,``<>``,,,,,,,,,,,,,,,,impl,,, :eql:type:`enum`,``<>``,``<>``,,,,,,,,,,,,,,,,,,, object,``<>``,,,,,,,,,,,,,,,,,,,, ================================================ FILE: docs/reference/reference/edgeql/casts.rst ================================================ .. _ref_eql_casts: ===== Casts ===== There are different ways that casts appear in EdgeQL. Explicit Casts -------------- A type cast expression converts the specified value to another value of the specified type: .. eql:synopsis:: "<" ">" The :eql:synopsis:`` must be a valid :ref:`type expression ` denoting a non-abstract scalar or a container type. For example, the following expression casts an integer value into a string: .. code-block:: edgeql-repl db> select 10; {"10"} See the :eql:op:`type cast operator ` section for more information on type casting rules. .. _ref_uuid_casting: .. lint-off You can cast a UUID into an object: .. code-block:: edgeql-repl db> select '01d9cc22-b776-11ed-8bef-73f84c7e91e7'; {default::Hero {id: 01d9cc22-b776-11ed-8bef-73f84c7e91e7}} If you try to cast a UUID that no object of the type has as its ``id`` property, you'll get an error: .. code-block:: edgeql-repl db> select 'aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa'; gel error: CardinalityViolationError: 'default::Hero' with id 'aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa' does not exist .. lint-on Assignment Casts ---------------- *Assignment casts* happen when inserting new objects. Numeric types will often be automatically cast into the specific type corresponding to the property they are assigned to. This is to avoid extra typing when dealing with numeric value using fewer bits: .. code-block:: edgeql # Automatically cast a literal 42 (which is int64 # by default) into an int16 value. insert MyObject { int16_val := 42 }; If *assignment* casting is supported for a given pair of types, *explicit* casting of those types is also supported. 
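One way to picture this rule is as a set of nested tiers: a cast that is allowed in a stricter context is also allowed in every looser one. The Python sketch below is a hypothetical model, not Gel's implementation; ``CastLevel``, ``CASTS``, and ``cast_allowed`` are names invented for illustration. The three sample entries are taken from the casting table (``impl``, ``:=``, and ``<>`` respectively), and the ``IMPLICIT`` tier refers to the implicit casts covered in the next section.

```python
from enum import IntEnum
from typing import Optional


class CastLevel(IntEnum):
    """Hypothetical ordering: a higher value means the cast works in more contexts."""
    EXPLICIT = 1     # "<>" in the casting table: only with an explicit cast
    ASSIGNMENT = 2   # ":=": also applied automatically on assignment
    IMPLICIT = 3     # "impl": also applied automatically inside expressions


# A few (source, target) entries copied from the casting table.
CASTS: dict[tuple[str, str], CastLevel] = {
    ("int16", "int64"): CastLevel.IMPLICIT,
    ("int64", "int16"): CastLevel.ASSIGNMENT,
    ("str", "int64"): CastLevel.EXPLICIT,
}


def cast_allowed(source: str, target: str, context: CastLevel) -> bool:
    """Return True if source can be cast to target in the given context.

    A cast registered at a given level is usable in that context and in
    every less automatic one, which is the nesting rule stated above.
    """
    level: Optional[CastLevel] = CASTS.get((source, target))
    return level is not None and level >= context


# int64 -> int16 happens automatically on assignment, and explicitly too,
# but never implicitly inside an expression.
on_assignment = cast_allowed("int64", "int16", CastLevel.ASSIGNMENT)
explicitly = cast_allowed("int64", "int16", CastLevel.EXPLICIT)
implicitly = cast_allowed("int64", "int16", CastLevel.IMPLICIT)
```

Pairs with no entry in the table, such as ``bool`` to ``int64``, are rejected in every context under this model.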
Implicit Casts -------------- *Implicit casts* happen automatically whenever the value type doesn't match the expected type in an expression. This is mostly supported for numeric casts that don't incur any potential information loss (in form of truncation), so typically from a less precise type, to a more precise one. The :eql:type:`int64` to :eql:type:`float64` is a notable exception, which can suffer from truncation of significant digits for very large integer values. There are a few scenarios when *implicit casts* can occur: 1) Passing arguments that don't match exactly the types in the function signature: .. code-block:: edgeql-repl db> with x := 12.34 ... select math::ceil(x); {13} The function :eql:func:`math::ceil` only takes :eql:type:`int64`, :eql:type:`float64`, :eql:type:`bigint`, or :eql:type:`decimal` as its argument. So the :eql:type:`float32` value will be *implicitly cast* into a :eql:type:`float64` in order to match a valid signature. 2) Using operands that don't match exactly the types in the operator signature (this works the same way as for functions): .. code-block:: edgeql-repl db> select 1 + 2.3; {3.3} The operator :eql:op:`+ ` is defined only for operands of the same type, so in the expression above the :eql:type:`int64` value ``1`` is *implicitly cast* into a :eql:type:`float64` in order to match the other operand and produce a valid signature. 3) Mixing different numeric types in a set: .. code-block:: edgeql-repl db> select {1, 2.3, 4.5} is float64; {true, true, true} All elements in a set have to be of the same type, so the values are cast into :eql:type:`float64` as that happens to be the common type to which all the set elements can be *implicitly cast*. This would work out the same way if :eql:op:`union` was used instead: .. 
code-block:: edgeql-repl db> select (1 union 2.3 union 4.5) is float64; {true, true, true} If *implicit* casting is supported for a given pair of types, *assignment* and *explicit* casting of those types is also supported. .. _ref_eql_casts_table: Casting Table ------------- .. note:: The UUID-to-object cast is only available since |EdgeDB| 3.0+. .. This file is automatically generated by `make casts`: .. csv-table:: :file: casts.csv :class: vertheadertable - ``<>`` - can be cast explicitly - ``:=`` - assignment cast is supported - ``impl`` - implicit cast is supported - ``*``- When casting a float type to an integer type, the fractional value naturally cannot be preserved after the cast. When executing this cast, we round to the nearest integer, rounding ties to the nearest even (e.g., 1.5 is rounded up to 2; 2.5 is also rounded to 2). ================================================ FILE: docs/reference/reference/edgeql/delete.rst ================================================ .. _ref_eql_statements_delete: Delete ====== :eql-statement: :eql-haswith: ``delete`` -- remove objects from a database. .. eql:synopsis:: [ with [, ...] ] delete [ filter ] [ order by [direction] [then ...] ] [ offset ] [ limit ] ; :eql:synopsis:`with` Alias declarations. The ``with`` clause allows specifying module aliases as well as expression aliases that can be referenced by the ``delete`` statement. See :ref:`ref_eql_statements_with` for more information. :eql:synopsis:`delete ...` The entire :eql:synopsis:`delete ...` statement is syntactic sugar for ``delete (select ...)``. Therefore, the base :eql:synopsis:`` and the following :eql:synopsis:`filter`, :eql:synopsis:`order by`, :eql:synopsis:`offset`, and :eql:synopsis:`limit` clauses shape the set to be deleted the same way an explicit :eql:stmt:`select` would. Output ~~~~~~ On successful completion, a ``delete`` statement returns the set of deleted objects. Examples ~~~~~~~~ Here's a simple example of deleting a specific user: .. 
code-block:: edgeql with module example delete User filter User.name = 'Alice Smith'; And here's the equivalent ``delete (select ...)`` statement: .. code-block:: edgeql with module example delete (select User filter User.name = 'Alice Smith'); .. list-table:: :class: seealso * - **See also** * - :ref:`EdgeQL > Delete ` * - :ref:`Cheatsheets > Deleting data ` ================================================ FILE: docs/reference/reference/edgeql/describe.rst ================================================ .. _ref_eql_statements_describe: Describe ======== :eql-statement: ``describe`` -- provide a human-readable description of a schema or a schema object .. eql:synopsis:: describe schema [ as {ddl | sdl | text [ verbose ]} ]; describe [ as {ddl | sdl | text [ verbose ]} ]; # where is one of object annotation constraint function link module property scalar type type Description ----------- ``describe`` generates a human-readable description of a schema object. The output of a ``describe`` command is a :eql:type:`str`, although it cannot be used as an expression in queries. There are three output formats to choose from: :eql:synopsis:`as ddl` Provide a valid :ref:`DDL ` definition. The :ref:`DDL ` generated is a complete valid definition of the particular schema object assuming all the other referenced schema objects already exist. This is the default format. :eql:synopsis:`as sdl` Provide an :ref:`SDL ` definition. The :ref:`SDL ` generated is a complete valid definition of the particular schema object assuming all the other referenced schema objects already exist. :eql:synopsis:`as text [verbose]` Provide a human-oriented definition. The human-oriented definition generated is similar to :ref:`SDL `, but it includes all the details that are inherited (if any). The :eql:synopsis:`verbose` mode enables displaying additional details, such as :ref:`annotations ` and :ref:`constraints `, which are otherwise omitted.
When the ``describe`` command is used with the :eql:synopsis:`schema` the result is a definition of the entire database schema. Only the :eql:synopsis:`as ddl` option is available for schema description. The ``describe`` command can specify the type of schema object that it should generate the description of: :eql:synopsis:`object ` Match any module level schema object with the specified *name*. This is the most general use of the ``describe`` command. It does not match :ref:`modules ` (and other globals that cannot be uniquely identified just by the name). :eql:synopsis:`annotation ` Match only :ref:`annotations ` with the specified *name*. :eql:synopsis:`constraint ` Match only :ref:`constraints ` with the specified *name*. :eql:synopsis:`function ` Match only :ref:`functions ` with the specified *name*. :eql:synopsis:`link ` Match only :ref:`links ` with the specified *name*. :eql:synopsis:`module ` Match only :ref:`modules ` with the specified *name*. :eql:synopsis:`property ` Match only :ref:`properties ` with the specified *name*. :eql:synopsis:`scalar type ` Match only :ref:`scalar types ` with the specified *name*. :eql:synopsis:`type ` Match only :ref:`object types ` with the specified *name*. Examples -------- Consider the following schema: .. code-block:: sdl abstract type Named { required name: str { delegated constraint exclusive; } } type User extending Named { required email: str { annotation title := 'Contact email'; } } Here are some examples of a ``describe`` command: .. 
code-block:: edgeql-repl db> describe object User; { "create type default::User extending default::Named { create required single property email -> std::str { create annotation std::title := 'Contact email'; }; };" } db> describe object User as sdl; { "type default::User extending default::Named { required single property email -> std::str { annotation std::title := 'Contact email'; }; };" } db> describe object User as text; { 'type default::User extending default::Named { required single link __type__ -> schema::Type { readonly := true; }; required single property email -> std::str; required single property id -> std::uuid { readonly := true; }; required single property name -> std::str; };' } db> describe object User as text verbose; { "type default::User extending default::Named { required single link __type__ -> schema::Type { readonly := true; }; required single property email -> std::str { annotation std::title := 'Contact email'; }; required single property id -> std::uuid { readonly := true; constraint std::exclusive; }; required single property name -> std::str { constraint std::exclusive; }; };" } db> describe schema; { "create module default if not exists; create abstract type default::Named { create required single property name -> std::str { create delegated constraint std::exclusive; }; }; create type default::User extending default::Named { create required single property email -> std::str { create annotation std::title := 'Contact email'; }; };" } The ``describe`` command also warns you if there are standard library matches that are masked by some user-defined object. Consider the following schema: .. code-block:: sdl module default { function len(v: tuple) -> float64 using ( select (v.0 ^ 2 + v.1 ^ 2) ^ 0.5 ); } So within the ``default`` module the user-defined function ``len`` (computing the length of a vector) masks the built-ins: .. 
code-block:: edgeql-repl db> describe function len as text; { 'function default::len(v: tuple) -> std::float64 using (select (((v.0 ^ 2) + (v.1 ^ 2)) ^ 0.5) ); # The following builtins are masked by the above: # function std::len(array: array) -> std::int64 { # volatility := \'Immutable\'; # annotation std::description := \'A polymorphic function to calculate a "length" of its first argument.\'; # using sql $$ # SELECT cardinality("array")::bigint # $$ # ;}; # function std::len(bytes: std::bytes) -> std::int64 { # volatility := \'Immutable\'; # annotation std::description := \'A polymorphic function to calculate a "length" of its first argument.\'; # using sql $$ # SELECT length("bytes")::bigint # $$ # ;}; # function std::len(str: std::str) -> std::int64 { # volatility := \'Immutable\'; # annotation std::description := \'A polymorphic function to calculate a "length" of its first argument.\'; # using sql $$ # SELECT char_length("str")::bigint # $$ # ;};', } ================================================ FILE: docs/reference/reference/edgeql/eval.rst ================================================ .. _ref_eql_fundamentals_queries: ==================== Evaluation algorithm ==================== EdgeQL is a functional language in the sense that every expression is a composition of one or more queries. Queries can be *explicit*, such as a :eql:stmt:`select` statement, or *implicit*, as dictated by the semantics of a function, operator or a statement clause. An implicit ``select`` subquery is assumed in the following situations: - expressions passed as an argument for an aggregate function parameter or operand; - the right side of the assignment operator (``:=``) in expression aliases and :ref:`shape element declarations `; - the majority of statement clauses. A nested query is called a *subquery*. Here, the phrase "*appearing directly in the query*" means "appearing in the query itself rather than in any of its subqueries". .. 
_ref_eql_fundamentals_eval_algo: A query is evaluated recursively using the following procedure: 1. Make a list of simple paths (i.e., paths that begin with a set reference) appearing directly in the query. For every path in the list, find all paths which begin with the same set reference and treat their longest common prefix as an equivalent set reference. Example: .. code-block:: edgeql select ( User.firstname, User.friends.firstname, User.friends.lastname, Issue.priority.name, Issue.number, Status.name ); In the above query, the longest common prefixes are: ``User``, ``User.friends``, ``Issue``, and ``Status.name``. 2. Make a *query input list* of all unique set references which appear directly in the query (including the common path prefixes identified above). The set references and path prefixes in this list are called *input set references*, and the sets they represent are called *input sets*. Order this list such that any input set reference comes before any other input set reference for which it is a prefix (sorting lexicographically works). 3. Compute a set of *input tuples*. - Begin with a set containing a single empty tuple. - For each input set reference, we compute a *dependent* Cartesian product of the input tuple set (``X``) so far and the input set ``Y`` being considered. In this dependent product, we pair each tuple ``x`` in the input tuple set ``X`` with each element of the subset of the input set ``Y`` corresponding to the tuple ``x``. (In the example above, computing the dependent product of ``User`` and ``User.friends`` would pair each user with all of their friends.) (Mathematically, ``X' = {(x, y) | x ∈ X, y ∈ f(x)}``, if ``f(x)`` selects the appropriate subset.) The set produced becomes the new input tuple set and we continue down the list.
- As a caveat to the above, if an input set appears exclusively as an :ref:`optional ` argument, it produces pairs with a placeholder value ``Missing`` instead of an empty Cartesian product in the above set. (Mathematically, this corresponds to having ``f(x) = {Missing}`` whenever it would otherwise produce an empty set.) 4. Iterate over the set of input tuples, and on every iteration: - in the query and its subqueries, replace each input set reference with the corresponding value from the input tuple or an empty set if the value is ``Missing``; - evaluate the query expression in the order of precedence using the following rules: * subqueries are evaluated recursively from step 1; * a function or an operator is evaluated in a loop over a Cartesian product of its non-aggregate arguments (empty ``optional`` arguments are excluded from the product); aggregate arguments are passed as a whole set; the results of the invocations are collected to form a single set. 5. Collect the results of all iterations to obtain the final result set. ================================================ FILE: docs/reference/reference/edgeql/for.rst ================================================ .. _ref_eql_statements_for: For === :eql-statement: :eql-haswith: ``for``--compute a union of subsets based on values of another set .. eql:synopsis:: [ with [, ...] ] for in union ; :eql:synopsis:`for in ` The ``for`` clause has this general form: .. TODO: rewrite this .. eql:synopsis:: for in where :eql:synopsis:`` is a :ref:`literal `, a :ref:`function call `, a :ref:`set constructor `, a :ref:`path `, or any parenthesized expression or statement. :eql:synopsis:`union ` The ``union`` clause of the ``for`` statement has this general form: .. TODO: rewrite this .. eql:synopsis:: union Here, :eql:synopsis:`` is an arbitrary expression that is evaluated for every element in a set produced by evaluating the ``for`` clause. The results of the evaluation are appended to the result set. .. 
_ref_eql_forstatement: Usage of ``for`` statement ++++++++++++++++++++++++++ The ``for`` statement has some powerful features that deserve to be considered in detail separately. However, the common core is that ``for`` iterates over elements of some arbitrary expression. Then for each element of the iterator some set is computed and combined via a :eql:op:`union` with the other such computed sets. The simplest use case is when the iterator is given by a set expression and it follows the general form of ``for x in A ...``: .. code-block:: edgeql with module example # the iterator is an explicit set of tuples, so x is an # element of this set, i.e. a single tuple for x in { (name := 'Alice', theme := 'fire'), (name := 'Bob', theme := 'rain'), (name := 'Carol', theme := 'clouds'), (name := 'Dave', theme := 'forest') } # typically this is used with an INSERT, DELETE or UPDATE union ( insert User { name := x.name, theme := x.theme, } ); Since ``x`` is an element of a set it is guaranteed to be a non-empty singleton in all of the expressions used by the ``union`` and later clauses of ``for``. Another variation of this usage of ``for`` is a bulk ``update``. There are cases when a bulk update involves a lot of external data that cannot be derived from the objects being updated. That is a good use case for a ``for`` statement. .. code-block:: edgeql # Here's an example of an update that is awkward to # express without the use of FOR statement with module example update User filter .name in {'Alice', 'Bob', 'Carol', 'Dave'} set { theme := 'red' if .name = 'Alice' else 'star' if .name = 'Bob' else 'dark' if .name = 'Carol' else 'strawberry' }; # Using a FOR statement, the above update becomes simpler to # express or review for a human.
with module example for x in { (name := 'Alice', theme := 'red'), (name := 'Bob', theme := 'star'), (name := 'Carol', theme := 'dark'), (name := 'Dave', theme := 'strawberry') } union ( update User filter .name = x.name set { theme := x.theme } ); When updating data that mostly or completely depends on the objects being updated there's no need to use the ``for`` statement and it is not advised to use it for performance reasons. .. code-block:: edgeql with module example update User filter .name in {'Alice', 'Bob', 'Carol', 'Dave'} set { theme := 'halloween' }; # The above can be accomplished with a for statement, # but it is not recommended. with module example for x in {'Alice', 'Bob', 'Carol', 'Dave'} union ( update User filter .name = x set { theme := 'halloween' } ); Another example of using a ``for`` statement is working with link properties. Specifying the link properties either at creation time or in a later step with an update is often simpler with a ``for`` statement helping to associate the link target to the link property in an intuitive manner. .. code-block:: edgeql # Expressing this without for statement is fairly tedious. with module example, U2 := User for x in { ( name := 'Alice', friends := [('Bob', 'coffee buff'), ('Carol', 'dog person')] ), ( name := 'Bob', friends := [('Alice', 'movie buff'), ('Dave', 'cat person')] ) } union ( update User filter .name = x.name set { friends := assert_distinct( ( for f in array_unpack(x.friends) union ( select U2 {@nickname := f.1} filter U2.name = f.0 ) ) ) } ); .. list-table:: :class: seealso * - **See also** * - :ref:`EdgeQL > For ` ================================================ FILE: docs/reference/reference/edgeql/functions.rst ================================================ .. _ref_reference_function_call: Function calls ============== |Gel| provides a number of functions in the :ref:`standard library `. It is also possible for users to :ref:`define their own ` functions. 
The syntax for a function call is as follows: .. eql:synopsis:: "(" [ [, , ...]] ")" # where is: | := Here :eql:synopsis:`` is a possibly qualified name of a function, and :eql:synopsis:`` is an *expression* optionally prefixed with an argument name and the assignment operator (``:=``) for :ref:`named only ` arguments. For example, the following computes the length of a string ``'foo'``: .. code-block:: edgeql-repl db> select len('foo'); {3} And here's an example of using a *named only* argument to provide a default value: .. code-block:: edgeql-repl db> select array_get(['hello', 'world'], 10, default := 'n/a'); {'n/a'} .. list-table:: :class: seealso * - **See also** * - :ref:`Schema > Functions ` * - :ref:`SDL > Functions ` * - :ref:`DDL > Functions ` * - :ref:`Introspection > Functions ` * - :ref:`Cheatsheets > Functions ` ================================================ FILE: docs/reference/reference/edgeql/group.rst ================================================ .. _ref_eql_statements_group: Group ===== :eql-statement: :eql-haswith: ``group``--partition a set into subsets based on one or more keys .. eql:synopsis:: [ with [, ...] ] group [ := ] [ using := , [, ...] ] by , ... ; # where a is one of { , ... } ROLLUP( , ... ) CUBE( , ... ) # where a is one of () ( , ... ) # where a is one of . :eql:synopsis:`group ` The ``group`` clause sets up the input set that will be operated on. Much like in :eql:stmt:`select` it is possible to define an ad-hoc alias at this stage to refer to the starting set concisely. :eql:synopsis:`using := ` The ``using`` clause defines one or more aliases which can then be used as part of the grouping key. If the :eql:synopsis:`by` clause only refers to :eql:synopsis:`.` the ``using`` clause is optional. :eql:synopsis:`by ` The ``by`` clause specifies which parameters will be used to partition the starting set.
There are only two basic components for defining :eql:synopsis:``: references to :eql:synopsis:`` defined in the :eql:synopsis:`using` clause or by references to the short-path format of :eql:synopsis:`.`. The :eql:synopsis:`.` has to refer to properties or links immediately present on the type of starting set. The basic building blocks can also be combined by using parentheses ``( )`` to indicate that partitioning will happen based on several parameters at once. It is also possible to specify *grouping sets*, which are denoted using curly braces ``{ }``. The results will contain different partitioning based on each of the grouping set elements. When there are multiple top-level grouping-elements then the cartesian product of them is taken to determine the grouping set. Thus ``a, {b, c}`` is equivalent to ``{(a, b), (a, c)}`` grouping sets. :eql:synopsis:`ROLLUP` and :eql:synopsis:`CUBE` are a shorthand to specify particular grouping sets. :eql:synopsis:`ROLLUP` groups by all prefixes of a list of elements, so ``ROLLUP (a, b, c)`` is equivalent to ``{(), (a), (a, b), (a, b, c)}``. :eql:synopsis:`CUBE` groups by all elements of the power set, so ``CUBE (a, b)`` is equivalent to ``{(), (a), (b), (a, b)}``. Output ------ The ``group`` statement partitions a starting set into subsets based on some specified parameters. The output is organized into a set of :ref:`free objects ` of the following structure: .. eql:synopsis:: { "key": { := [, ...] }, "grouping": , "elements": , } :eql:synopsis:`"key"` The :eql:synopsis:`"key"` contains another :ref:`free object `, which contains all the aliases or field names used as the key together with the specific values these parameters take for this particular subset. :eql:synopsis:`"grouping"` The :eql:synopsis:`"grouping"` contains a :eql:type:`str` set of all the names of the parameters used as the key for this particular subset. 
This is especially useful when using grouping sets and the parameters used in the key are not the same for all partitionings. :eql:synopsis:`"elements"` The :eql:synopsis:`"elements"` contains the actual subset of values that match the :eql:synopsis:`"key"`. Examples -------- Here's a simple example without using any aggregation or any further processing: .. code-block:: edgeql-repl db> group Movie {title} by .release_year; { { key: {release_year: 2016}, grouping: {'release_year'}, elements: { default::Movie {title: 'Captain America: Civil War'}, default::Movie {title: 'Doctor Strange'}, }, }, { key: {release_year: 2017}, grouping: {'release_year'}, elements: { default::Movie {title: 'Spider-Man: Homecoming'}, default::Movie {title: 'Thor: Ragnarok'}, }, }, { key: {release_year: 2018}, grouping: {'release_year'}, elements: {default::Movie {title: 'Ant-Man and the Wasp'}}, }, { key: {release_year: 2019}, grouping: {'release_year'}, elements: {default::Movie {title: 'Spider-Man: No Way Home'}}, }, { key: {release_year: 2021}, grouping: {'release_year'}, elements: {default::Movie {title: 'Black Widow'}}, }, ... } Or we can group by an expression instead, such as whether the title starts with a vowel or not: .. code-block:: edgeql-repl db> with ... # Apply the group query only to more recent movies ... M := (select Movie filter .release_year > 2015) ... group M {title} ... using vowel := re_test('(?i)^[aeiou]', .title) ... 
by vowel; { { key: {vowel: false}, grouping: {'vowel'}, elements: { default::Movie {title: 'Thor: Ragnarok'}, default::Movie {title: 'Doctor Strange'}, default::Movie {title: 'Spider-Man: Homecoming'}, default::Movie {title: 'Captain America: Civil War'}, default::Movie {title: 'Black Widow'}, default::Movie {title: 'Spider-Man: No Way Home'}, }, }, { key: {vowel: true}, grouping: {'vowel'}, elements: {default::Movie {title: 'Ant-Man and the Wasp'}}, }, } It is also possible to group scalars instead of objects, in which case you need to define an ad-hoc alias to refer to the scalar set in order to specify how it will be grouped: .. code-block:: edgeql-repl db> with ... # Apply the group query only to more recent movies ... M := (select Movie filter .release_year > 2015) ... group T := M.title ... using vowel := re_test('(?i)^[aeiou]', T) ... by vowel; { { key: {vowel: false}, grouping: {'vowel'}, elements: { 'Captain America: Civil War', 'Doctor Strange', 'Spider-Man: Homecoming', 'Thor: Ragnarok', 'Spider-Man: No Way Home', 'Black Widow', }, }, { key: {vowel: true}, grouping: {'vowel'}, elements: {'Ant-Man and the Wasp'} }, } Often the results of ``group`` are immediately used in a :eql:stmt:`select` statement to provide some kind of analytical results: .. code-block:: edgeql-repl db> with ... # Apply the group query only to more recent movies ... M := (select Movie filter .release_year > 2015), ... groups := ( ... group M {title} ... using vowel := re_test('(?i)^[aeiou]', .title) ... by vowel ... ) ... select groups { ... starts_with_vowel := .key.vowel, ... count := count(.elements), ... mean_title_length := ... round(math::mean(len(.elements.title))) ... }; { {starts_with_vowel: false, count: 6, mean_title_length: 18}, {starts_with_vowel: true, count: 1, mean_title_length: 20}, } It's possible to group by more than one parameter. For example, we can add the release decade to whether the ``title`` starts with a vowel: .. code-block:: edgeql-repl db> with ... 
# Apply the group query only to more recent movies ... M := (select Movie filter .release_year > 2015), ... groups := ( ... group M {title} ... using ... vowel := re_test('(?i)^[aeiou]', .title), ... decade := .release_year // 10 ... by vowel, decade ... ) ... select groups { ... key := .key {vowel, decade}, ... count := count(.elements), ... mean_title_length := ... math::mean(len(.elements.title)) ... }; { { key: {vowel: false, decade: 201}, count: 5, mean_title_length: 19.8, }, { key: {vowel: false, decade: 202}, count: 1, mean_title_length: 11, }, { key: {vowel: true, decade: 201}, count: 1, mean_title_length: 20 }, } Having more than one grouping parameter opens up the possibility of using *grouping sets* to see the way grouping parameters interact with the analytics we're gathering: .. code-block:: edgeql-repl db> with ... # Apply the group query only to more recent movies ... M := (select Movie filter .release_year > 2015), ... groups := ( ... group M {title} ... using ... vowel := re_test('(?i)^[aeiou]', .title), ... decade := .release_year // 10 ... by CUBE(vowel, decade) ... ) ... select groups { ... key := .key {vowel, decade}, ... grouping, ... count := count(.elements), ... mean_title_length := ... (math::mean(len(.elements.title))) ... 
} order by array_agg(.grouping); { { key: {vowel: {}, decade: {}}, grouping: {}, count: 7, mean_title_length: 18.571428571428573, }, { key: {vowel: {}, decade: 202}, grouping: {'decade'}, count: 1, mean_title_length: 11, }, { key: {vowel: {}, decade: 201}, grouping: {'decade'}, count: 6, mean_title_length: 19.833333333333332, }, { key: {vowel: true, decade: {}}, grouping: {'vowel'}, count: 1, mean_title_length: 20, }, { key: {vowel: false, decade: {}}, grouping: {'vowel'}, count: 6, mean_title_length: 18.333333333333332, }, { key: {vowel: false, decade: 201}, grouping: {'vowel', 'decade'}, count: 5, mean_title_length: 19.8, }, { key: {vowel: true, decade: 201}, grouping: {'vowel', 'decade'}, count: 1, mean_title_length: 20, }, { key: {vowel: false, decade: 202}, grouping: {'vowel', 'decade'}, count: 1, mean_title_length: 11, }, } .. list-table:: :class: seealso * - **See also** * - :ref:`EdgeQL > Group ` ================================================ FILE: docs/reference/reference/edgeql/index.rst ================================================ .. _ref_eql_statements: EdgeQL ======== Statements in EdgeQL are a kind of an *expression* that has one or more ``clauses`` and is used to retrieve or modify data in a database. Query statements: * :eql:stmt:`select` Retrieve data from a database and compute arbitrary expressions. * :eql:stmt:`for` Compute an expression for every element of an input set and concatenate the results. * :eql:stmt:`group` Group data into subsets by keys. Data modification statements: * :eql:stmt:`insert` Create new object in a database. * :eql:stmt:`update` Update objects in a database. * :eql:stmt:`delete` Remove objects from a database. Transaction control statements: * :eql:stmt:`start transaction` Start a transaction. * :eql:stmt:`commit` Commit the current transaction. * :eql:stmt:`rollback` Abort the current transaction. * :eql:stmt:`declare savepoint` Declare a savepoint within the current transaction. 
* :eql:stmt:`rollback to savepoint` Roll back to a savepoint within the current transaction. * :eql:stmt:`release savepoint` Release a previously declared savepoint. Session state control statements: * :eql:stmt:`set` and :eql:stmt:`reset`. Introspection command: * :eql:stmt:`describe`. Performance analysis statement: * :eql:stmt:`analyze`. .. toctree:: :maxdepth: 3 :hidden: lexical eval shapes paths casts functions cardinality volatility select insert update delete for group with analyze tx_start tx_commit tx_rollback tx_sp_declare tx_sp_release tx_sp_rollback sess_set_alias sess_reset_alias describe ================================================ FILE: docs/reference/reference/edgeql/insert.rst ================================================ .. _ref_eql_statements_insert: Insert ====== :eql-statement: :eql-haswith: ``insert`` -- create a new object in a database .. eql:synopsis:: [ with [ , ... ] ] insert [ ] [ unless conflict [ on [ else ] ] ] ; Description ----------- ``insert`` inserts a new object into a database. When evaluating an ``insert`` statement, *expression* is used solely to determine the *type* of the inserted object and is not evaluated in any other way. If a value for a *required* link is evaluated to an empty set, an error is raised. It is possible to insert multiple objects by putting the ``insert`` into a :eql:stmt:`for` statement. See :ref:`ref_eql_forstatement` for more details. :eql:synopsis:`with` Alias declarations. The ``with`` clause allows specifying module aliases as well as expression aliases that can be referenced by the ``insert`` statement. See :ref:`ref_eql_statements_with` for more information. :eql:synopsis:`` An arbitrary expression that determines the type of the object to be inserted. .. eql:synopsis:: insert [ "{" := [, ...] "}" ] .. _ref_eql_statements_conflict: :eql:synopsis:`unless conflict [ on ]` :index: unless conflict Handler of conflicts.
This clause makes it possible to handle specific conflicts arising during execution of ``insert`` without producing an error. If the conflict arises due to exclusive constraints on the properties specified by *property-expr*, then instead of failing with an error the ``insert`` statement produces an empty set (or an alternative result). The exclusive constraint on *property-expr* cannot be defined on a parent type. The specified *property-expr* may be either a reference to a property (or link) or a tuple of references to properties (or links). A caveat, however, is that ``unless conflict`` will not prevent conflicts between multiple DML operations in the same query; inserting two conflicting objects (through use of ``for`` or simply with two ``insert`` statements) will cause a constraint error. Example: .. code-block:: edgeql insert User { email := 'user@example.org' } unless conflict on .email .. code-block:: edgeql insert User { first := 'Jason', last := 'Momoa' } unless conflict on (.first, .last) :eql:synopsis:`else ` Alternative result in case of conflict. This clause can only appear after an ``unless conflict`` clause. Any valid expression can be specified as the *alternative*. When a conflict arises, the result of the ``insert`` becomes the *alternative* expression (instead of the default ``{}``). In order to refer to the conflicting object in the *alternative* expression, the name used in the ``insert`` must be used (see :ref:`example below `). Outputs ------- The result of an ``insert`` statement used as an *expression* is a singleton set containing the inserted object. Examples -------- Here's a simple example of an ``insert`` statement creating a new user: .. code-block:: edgeql with module example insert User { name := 'Bob Johnson' }; ``insert`` is not only a statement but also an expression, and as such it has a value: the set of objects that have been created. .. 
code-block:: edgeql with module example insert Issue { number := '100', body := 'Fix errors in insert', owner := ( select User filter User.name = 'Bob Johnson' ) }; It is possible to create nested objects in a single ``insert`` statement as an atomic operation. .. code-block:: edgeql with module example insert Issue { number := '101', body := 'Nested insert', owner := ( insert User { name := 'Nested User' } ) }; The above statement will create a new ``Issue`` as well as a new ``User`` as the owner of the ``Issue``. It will also return the new ``Issue`` linked to the new ``User`` if the statement is used as an expression. It is also possible to create new objects based on some existing data either provided as an explicit list (possibly automatically generated by some tool) or a query. A ``for`` statement is the basis for this use-case and ``insert`` is simply the expression in the ``union`` clause. .. code-block:: edgeql # example of a bulk insert of users based on explicitly provided # data with module example for x in {'Alice', 'Bob', 'Carol', 'Dave'} union (insert User { name := x }); # example of a bulk insert of issues based on a query with module example, Elvis := (select User filter .name = 'Elvis'), Open := (select Status filter .name = 'Open') for Q in (select User filter .name ilike 'A%') union (insert Issue { name := Q.name + ' access problem', body := 'This user was affected by recent system glitch', owner := Elvis, status := Open }); .. _ref_eql_statements_insert_unless: There's an important use-case where it is necessary to either insert a new object or update an existing one identified with some key. This is what the ``unless conflict`` clause allows: .. code-block:: edgeql with module people select ( insert Person { name := "Łukasz Langa", is_admin := true } unless conflict on .name else ( update Person set { is_admin := true } ) ) { name, is_admin }; .. note:: Statements in EdgeQL represent an atomic interaction with the database. 
From the point of view of a statement, all side-effects (such as database updates) happen after the statement is executed. So as far as each statement is concerned, it is some purely functional expression evaluated on some specific input (database state). .. list-table:: :class: seealso * - **See also** * - :ref:`EdgeQL > Insert ` * - :ref:`Cheatsheets > Inserting data ` ================================================ FILE: docs/reference/reference/edgeql/lexical.rst ================================================ .. _ref_eql_lexical: Lexical structure ================= Every EdgeQL command is composed of a sequence of *tokens*, terminated by a semicolon (``;``). The types of valid tokens as well as their order are determined by the syntax of the particular command. EdgeQL is case-sensitive except for *keywords* (in the examples the keywords are written in upper case as a matter of convention). There are several kinds of tokens: *keywords*, *identifiers*, *literals* (constants) and *symbols* (operators and punctuation). Tokens are normally separated by whitespace (space, tab, newline) or comments. Identifiers ----------- There are two ways of writing identifiers in EdgeQL: plain and quoted. Plain identifiers are similar to those in many other languages: they are alphanumeric with underscores and cannot start with a digit. Quoted identifiers start and end with a *backtick* ```quoted.identifier``` and can contain any characters inside, with a few exceptions. They must not start with an at sign (``@``) or contain a double colon (``::``). If there's a need to include a backtick character as part of the identifier name, a double-backtick sequence (``````) should be used: ```quoted``identifier``` will result in the actual identifier being ``quoted`identifier``. ..
productionlist:: edgeql identifier: `plain_ident` | `quoted_ident` plain_ident: `ident_first` `ident_rest`* ident_first: ident_rest: quoted_ident: "`" `qident_first` `qident_rest`* "`" qident_first: qident_rest: Quoted identifiers are usually needed to represent module names that contain a dot (``.``) or to distinguish *names* from *reserved keywords* (for instance to allow referring to a link named "order" as ```order```). .. _ref_eql_lexical_names: Names and keywords ------------------ .. TODO:: This section needs a significant update. There are a number of *reserved* and *unreserved* keywords in EdgeQL. Every identifier that is not a *reserved* keyword is a valid *name*. *Names* are used to refer to concepts, links, link properties, etc. .. TODO: update this for "branch" .. productionlist:: edgeql short_name: `not_keyword_ident` | `quoted_ident` not_keyword_ident: keyword: `reserved_keyword` | `unreserved_keyword` reserved_keyword: case insensitive sequence matching any : of the following : "AGGREGATE" | "ALTER" | "AND" | : "ANY" | "COMMIT" | "CREATE" | : "DELETE" | "DETACHED" | "DISTINCT" | : "DROP" | "ELSE" | "EMPTY" | "EXISTS" | : "FALSE" | "FILTER" | "FUNCTION" | : "GET" | "GROUP" | "IF" | "ILIKE" | : "IN" | "INSERT" | "IS" | "LIKE" | : "LIMIT" | "MODULE" | "NOT" | "OFFSET" | : "OR" | "ORDER" | "OVER" | : "PARTITION" | "ROLLBACK" | "SELECT" | : "SET" | "SINGLETON" | "START" | "TRUE" | : "UPDATE" | "UNION" | "WITH" unreserved_keyword: case insensitive sequence matching any : of the following : "ABSTRACT" | "ACTION" | "AFTER" | : "ARRAY" | "AS" | "ASC" | "ATOM" | : "ANNOTATION" | "BEFORE" | "BY" | : "CONCEPT" | "CONSTRAINT" | : "DATABASE" | "DESC" | "EVENT" | : "EXTENDING" | "FINAL" | "FIRST" | : "FOR" | "FROM" | "INDEX" | : "INITIAL" | "LAST" | "LINK" | : "MAP" | "MIGRATION" | "OF" | "ON" | : "POLICY" | "PROPERTY" | : "REQUIRED" | "RENAME" | "TARGET" | : "THEN" | "TO" | "TRANSACTION" | : "TUPLE" | "VALUE" | "VIEW" Fully-qualified names consist of a module, 
``::``, and a short name. They can be used in most places where a short name can appear (such as paths and shapes). .. productionlist:: edgeql name: `short_name` | `fq_name` fq_name: `short_name` "::" `short_name` | : `short_name` "::" `unreserved_keyword` .. _ref_eql_lexical_const: Constants --------- A number of scalar types have literal constant expressions. .. _ref_eql_lexical_str: Strings ^^^^^^^ Production rules for :eql:type:`str` literals: .. productionlist:: edgeql string: `str` | `raw_str` str: "'" `str_content`* "'" | '"' `str_content`* '"' raw_str: "r'" `raw_content`* "'" | : 'r"' `raw_content`* '"' | : `dollar_quote` `raw_content`* `dollar_quote` raw_content: dollar_quote: "$" `q_char0`? `q_char`* "$" q_char0: "A"..."Z" | "a"..."z" | "_" q_char: "A"..."Z" | "a"..."z" | "_" | "0"..."9" str_content: | `unicode` | `str_escapes` unicode: str_escapes: The inclusion of "high ASCII" character in :token:`edgeql:q_char` in practice reflects the ability to use some of the letters with diacritics like ``ò`` or ``ü`` in the dollar-quote delimiter. Here's a list of valid :token:`edgeql:str_escapes`: .. 
_ref_eql_lexical_str_escapes: +--------------------+---------------------------------------------+ | Escape Sequence | Meaning | +====================+=============================================+ | ``\[newline]`` | Backslash and all whitespace up to next | | | non-whitespace character is ignored | +--------------------+---------------------------------------------+ | ``\\`` | Backslash (\\) | +--------------------+---------------------------------------------+ | ``\'`` | Single quote (') | +--------------------+---------------------------------------------+ | ``\"`` | Double quote (") | +--------------------+---------------------------------------------+ | ``\b`` | ASCII backspace (``\x08``) | +--------------------+---------------------------------------------+ | ``\f`` | ASCII form feed (``\x0C``) | +--------------------+---------------------------------------------+ | ``\n`` | ASCII newline (``\x0A``) | +--------------------+---------------------------------------------+ | ``\r`` | ASCII carriage return (``\x0D``) | +--------------------+---------------------------------------------+ | ``\t`` | ASCII tabulation (``\x09``) | +--------------------+---------------------------------------------+ | ``\xhh`` | Character with hex value hh | +--------------------+---------------------------------------------+ | ``\uhhhh`` | Character with 16-bit hex value hhhh | +--------------------+---------------------------------------------+ | ``\Uhhhhhhhh`` | Character with 32-bit hex value hhhhhhhh | +--------------------+---------------------------------------------+ Here's some examples of regular strings using escape sequences .. code-block:: edgeql-repl db> select 'hello ... world'; {'hello world'} db> select "hello\nworld"; {'hello world'} db> select 'hello \ ... world'; {'hello world'} db> select 'https://geldata.com/\ ... docs/edgeql/lexical\ ... 
#constants'; {'https://geldata.com/docs/edgeql/lexical#constants'} db> select 'hello \\ world'; {'hello \ world'} db> select 'hello \'world\''; {"hello 'world'"} db> select 'hello \x77orld'; {'hello world'} db> select 'hello \u0077orld'; {'hello world'} .. _ref_eql_lexical_raw: Raw strings don't have any specially interpreted symbols; they contain all the symbols between the quotes exactly as typed. .. code-block:: edgeql-repl db> select r'hello \\ world'; {'hello \\ world'} db> select r'hello \ ... world'; {'hello \ world'} db> select r'hello ... world'; {'hello world'} .. _ref_eql_lexical_dollar_quoting: Dollar-quoted String Constants ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ A special case of raw strings are *dollar-quoted* strings. They allow using either kind of quote symbols ``'`` or ``"`` as part of the string content without the quotes terminating the string. In fact, because the *dollar-quote* delimiter sequences can have arbitrary alphanumeric additional fillers, it is always possible to surround any content with *dollar-quotes* in an unambiguous manner: .. code-block:: edgeql-repl db> select $$hello ... world$$; {'hello world'} db> select $$hello\nworld$$; {'hello\nworld'} db> select $$"hello" 'world'$$; {"\"hello\" 'world'"} db> select $a$hello$$world$$$a$; {'hello$$world$$'} More specifically, a delimiter: * Must start with an ASCII letter or underscore * Has following characters that can be digits 0-9, underscores or ASCII letters .. _ref_eql_lexical_bytes: Bytes ^^^^^ Production rules for :eql:type:`bytes` literals: .. productionlist:: edgeql bytes: "b'" `bytes_content`* "'" | 'b"' `bytes_content`* '"' bytes_content: | `ascii` | `bytes_escapes` ascii: bytes_escapes: Here's a list of valid :token:`edgeql:bytes_escapes`: .. 
_ref_eql_lexical_bytes_escapes: +--------------------+---------------------------------------------+ | Escape Sequence | Meaning | +====================+=============================================+ | ``\\`` | Backslash (\\) | +--------------------+---------------------------------------------+ | ``\'`` | Single quote (') | +--------------------+---------------------------------------------+ | ``\"`` | Double quote (") | +--------------------+---------------------------------------------+ | ``\b`` | ASCII backspace (``\x08``) | +--------------------+---------------------------------------------+ | ``\f`` | ASCII form feed (``\x0C``) | +--------------------+---------------------------------------------+ | ``\n`` | ASCII newline (``\x0A``) | +--------------------+---------------------------------------------+ | ``\r`` | ASCII carriage return (``\x0D``) | +--------------------+---------------------------------------------+ | ``\t`` | ASCII tabulation (``\x09``) | +--------------------+---------------------------------------------+ | ``\xhh`` | Character with hex value hh | +--------------------+---------------------------------------------+ Integers ^^^^^^^^ There are two kinds of integer constants: limited size (:eql:type:`int64`) and unlimited size (:eql:type:`bigint`). Unlimited size integer :eql:type:`bigint` literals are written like regular integer literals, but with an ``n`` suffix. The production rules are as follows: .. productionlist:: edgeql bigint: `integer` "n" integer: "0" | `non_zero` `digit`* non_zero: "1"..."9" digit: "0"..."9" By default all integer literals are interpreted as :eql:type:`int64`, while an explicit cast can be used to convert them to :eql:type:`int16` or :eql:type:`int32`: .. code-block:: edgeql-repl db> select 0; {0} db> select 123; {123} db> select 456; {456} db> select 789; {789} Examples of :eql:type:`bigint` literals: ..
code-block:: edgeql-repl db> select 123n; {123n} db> select 12345678901234567890n; {12345678901234567890n} Real Numbers ^^^^^^^^^^^^ Just as for integers, there are two kinds of real number constants: limited precision (:eql:type:`float64`) and unlimited precision (:eql:type:`decimal`). The :eql:type:`decimal` constants have the same lexical structure as :eql:type:`float64`, but with an ``n`` suffix: .. productionlist:: edgeql decimal: `float` "n" float: `float_wo_dec` | `float_w_dec` float_wo_dec: `integer_part` `exp` float_w_dec: `integer_part` "." `decimal_part`? `exp`? integer_part: "0" | `non_zero` `digit`* decimal_part: `digit`+ exp: "e" ("+" | "-")? `digit`+ By default all float literals are interpreted as :eql:type:`float64`, while an explicit cast can be used to convert them to :eql:type:`float32`: .. code-block:: edgeql-repl db> select 0.1; {0.1} db> select 12.3; {12.3} db> select 1e3; {1000.0} db> select 1.2e-3; {0.0012} db> select 12.3; {12.3} Examples of :eql:type:`decimal` literals: .. code-block:: edgeql-repl db> select 12.3n; {12.3n} db> select 12345678901234567890.12345678901234567890n; {12345678901234567890.12345678901234567890n} db> select 12345678901234567890.12345678901234567890e-3n; {12345678901234567.89012345678901234567890n} Punctuation ----------- EdgeQL uses ``;`` as a statement separator. It is idempotent, so multiple repetitions of ``;`` don't have any additional effect. Comments -------- Comments start with a ``#`` character that is not otherwise part of a string literal and end at the end of line. Semantically, a comment is equivalent to whitespace. .. productionlist:: edgeql comment: "#" Operators --------- EdgeQL operators listed in order of precedence from lowest to highest: .. 
list-table:: :widths: auto :header-rows: 1 * - operator * - :eql:op:`union` * - :eql:op:`if..else` * - :eql:op:`or` * - :eql:op:`and` * - :eql:op:`not` * - :eql:op:`=`, :eql:op:`\!=`, :eql:op:`?=`, :eql:op:`?\!=` * - :eql:op:`\<`, :eql:op:`>`, :eql:op:`\<=`, :eql:op:`>=` * - :eql:op:`like`, :eql:op:`ilike` * - :eql:op:`in`, :eql:op:`not in ` * - :eql:op:`is`, :eql:op:`is not ` * - :eql:op:`+`, :eql:op:`-`, :eql:op:`++` * - :eql:op:`*`, :eql:op:`/
`, :eql:op:`//`, :eql:op:`%` * - :eql:op:`?? ` * - :eql:op:`distinct`, unary :eql:op:`-` * - :eql:op:`^` * - :eql:op:`type cast ` * - :eql:op:`array[] `, :eql:op:`str[] `, :eql:op:`json[] `, :eql:op:`bytes[] ` * - :eql:kw:`detached` ================================================ FILE: docs/reference/reference/edgeql/paths.rst ================================================ .. _ref_reference_paths: ===== Paths ===== A *path expression* (or simply a *path*) represents a set of values that are reachable when traversing a given sequence of links or properties from some source set. The result of a path expression depends on whether it terminates with a link or property reference. a) if a path *does not* end with a property reference, then it represents a unique set of objects reachable from the set at the root of the path; b) if a path *does* end with a property reference, then it represents a list of property values for every element in the unique set of objects reachable from the set at the root of the path. The syntactic form of a path is: .. eql:synopsis:: [ ... ] # where is: The individual path components are: :eql:synopsis:`` Any valid expression. :eql:synopsis:`` It can be one of the following: - ``.`` for an outgoing link reference - ``.<`` for an incoming or :ref:`backlink ` reference - ``@`` for a link property reference :eql:synopsis:`` This must be a valid link or link property name. ================================================ FILE: docs/reference/reference/edgeql/select.rst ================================================ .. _ref_eql_statements_select: Select ====== :eql-statement: :eql-haswith: ``select``--retrieve or compute a set of values. .. eql:synopsis:: [ with [, ...] ] select [ filter ] [ order by [direction] [then ...] ] [ offset ] [ limit ] ; :eql:synopsis:`filter ` The optional ``filter`` clause, where :eql:synopsis:`` is any expression that has a result of type :eql:type:`bool`. 
The condition is evaluated for every element in the set produced by the ``select`` clause. The result of the evaluation of the ``filter`` clause is a set of boolean values. If at least one value in this set is ``true``, the input element is included, otherwise it is eliminated from the output. .. _ref_reference_select_order: :eql:synopsis:`order by [direction] [then ...]` The optional ``order by`` clause has this general form: .. eql:synopsis:: order by [ asc | desc ] [ empty { first | last } ] [ then ... ] The ``order by`` clause produces a result set sorted according to the specified expression or expressions, which are evaluated for every element of the input set. If two elements are equal according to the leftmost *expression*, they are compared according to the next expression and so on. If two elements are equal according to all expressions, the resulting order is undefined. Each *expression* can be an arbitrary expression that results in a value of an *orderable type*. Primitive types are orderable, object types are not. Additionally, the result of each expression must be an empty set or a singleton. Using an expression that may produce more elements is a compile-time error. An optional ``asc`` or ``desc`` keyword can be added after any *expression*. If not specified ``asc`` is assumed by default. If ``empty last`` is specified, then input values that produce an empty set when evaluating an *expression* are sorted *after* all other values; if ``empty first`` is specified, then they are sorted *before* all other values. If neither is specified, ``empty first`` is assumed when ``asc`` is specified or implied, and ``empty last`` when ``desc`` is specified. :eql:synopsis:`offset ` The optional ``offset`` clause, where :eql:synopsis:`` is a *singleton expression* of an integer type. This expression is evaluated once and its result is used to skip the first *element-count* elements of the input set while producing the output. 
If *element-count* evaluates to an empty set, it is equivalent to ``offset 0``, which is equivalent to omitting the ``offset`` clause. If *element-count* evaluates to a value that is larger than the cardinality of the input set, an empty set is produced as the result. :eql:synopsis:`limit ` The optional ``limit`` clause, where :eql:synopsis:`` is a *singleton expression* of an integer type. This expression is evaluated once and its result is used to include only the first *element-count* elements of the input set while producing the output. If *element-count* evaluates to an empty set, it is equivalent to specifying no ``limit`` clause. Description ----------- ``select`` retrieves or computes a set of values. The data flow of a ``select`` block can be conceptualized like this: .. eql:synopsis:: with module example # select clause select # compute a set of things # optional clause filter # filter the computed set # optional clause order by # define ordering of the filtered set # optional clause offset # slice the filtered/ordered set # optional clause limit # slice the filtered/ordered set Please note that the ``order by`` clause defines ordering that can only be relied upon if the resulting set is not used in any other operation. ``select``, ``offset`` and ``limit`` clauses are the only exceptions to that rule, as they preserve the inherent ordering of the underlying set. The first clause is ``select``. It indicates that ``filter``, ``order by``, ``offset``, or ``limit`` clauses may follow an expression, i.e. it makes an expression into a ``select`` statement. Without any of the optional clauses a ``(select Expr)`` is completely equivalent to ``Expr`` for any expression ``Expr``. Consider an example using the ``filter`` optional clause: .. code-block:: edgeql with module example select User { name, owned := (select User.
Select ` * - :ref:`Cheatsheets > Selecting data ` ================================================ FILE: docs/reference/reference/edgeql/sess_reset_alias.rst ================================================ .. _ref_eql_statements_session_reset_alias: Reset ===== :eql-statement: ``reset`` -- reset one or multiple session-level parameters .. eql:synopsis:: reset module ; reset alias ; reset alias * ; reset global ; Description ----------- This command allows resetting one or many configuration parameters of the current session. Variations ---------- :eql:synopsis:`reset module` Reset the default module name back to "default" for the current session. For example, if a module ``foo`` contains type ``FooType``, the following is how the ``set`` and ``reset`` commands can be used to alias it: .. code-block:: edgeql # Set the default module to "foo" for the current session. set module foo; # This query is now equivalent to "select foo::FooType". select FooType; # Reset the default module for the current session. reset module; # This query will now produce an error. select FooType; :eql:synopsis:`reset alias ` Reset :eql:synopsis:`` for the current session. For example: .. code-block:: edgeql # Alias the "std" module as "foo". set alias foo as module std; # Now "std::min()" can be called as "foo::min()" in # the current session. select foo::min({1}); # Reset the alias. reset alias foo; # Now this query will error out, as there is no # module "foo". select foo::min({1}); :eql:synopsis:`reset alias *` Reset all aliases defined in the current session. This command affects aliases set with :eql:stmt:`set alias ` and :eql:stmt:`set module `. The default module will be set to "default". Example: .. code-block:: edgeql # Reset all custom aliases for the current session. reset alias *; :eql:synopsis:`reset global ` Reset the global variable *name* to its default value or ``{}`` if the variable has no default value and is ``optional``. Examples -------- .. 
code-block:: edgeql reset module; reset alias foo; reset alias *; reset global current_user_id; .. list-table:: :class: seealso * - **See also** * - :ref:`Reference > EdgeQL > Set ` ================================================ FILE: docs/reference/reference/edgeql/sess_set_alias.rst ================================================ .. _ref_eql_statements_session_set_alias: Set === :eql-statement: ``set`` -- set one or multiple session-level parameters .. eql:synopsis:: set module ; set alias as module ; set global := ; Description ----------- This command allows altering the configuration of the current session. Variations ---------- :eql:synopsis:`set module ` Set the default module for the current session to *module*. For example, if a module ``foo`` contains type ``FooType``, the following is how the type can be referred to: .. code-block:: edgeql # Use the fully-qualified name. select foo::FooType; # Use the WITH clause to define the default module # for the query. with module foo select foo::FooType; # Set the default module for the current session ... set module foo; # ... and use an unqualified name. select FooType; :eql:synopsis:`set alias as module ` Define :eql:synopsis:`` for the :eql:synopsis:``. For example: .. code-block:: edgeql # Use the fully-qualified name. select foo::FooType; # Use the WITH clause to define a custom alias # for the "foo" module. with bar as module foo select bar::FooType; # Define "bar" as an alias for the "foo" module for # the current session ... set alias bar as module foo; # ... and use "bar" instead of "foo". select bar::FooType; :eql:synopsis:`set global := ` Set the global variable *name* to the specified value. For example: .. code-block:: edgeql # Set the global variable "current_user_id". set global current_user_id := '00ea8eaa-02f9-11ed-a676-6bd11cc6c557'; # We can now use that value in a query. select User { name } filter .id = global current_user_id; Examples -------- ..
code-block:: edgeql set module foo; set alias foo as module std; set global current_user_id := '00ea8eaa-02f9-11ed-a676-6bd11cc6c557'; .. list-table:: :class: seealso * - **See also** * - :ref:`Reference > EdgeQL > Reset ` ================================================ FILE: docs/reference/reference/edgeql/shapes.rst ================================================ .. _ref_reference_shapes: ====== Shapes ====== A *shape* is a powerful syntactic construct that can be used to describe type variants in queries, to describe data in ``insert`` and ``update`` statements, and to specify the format of statement output. Shapes always follow an expression, and are a list of *shape elements* enclosed in curly braces: .. eql:synopsis:: "{" [, ...] "}" A shape element has the following syntax: .. eql:synopsis:: [ "[" is "]" ] If an optional :eql:synopsis:`` filter is used, :eql:synopsis:`` will only apply to those objects in the :eql:synopsis:`` set that are instances of :eql:synopsis:``. :eql:synopsis:`` is one of the following: - a name of an existing link or property of a type produced by :eql:synopsis:``; - a declaration of a computed link or property in the form .. eql:synopsis :: [@] := - a *subshape* in the form .. eql:synopsis :: : [ "[" is "]" ] "{" ... "}" The :eql:synopsis:`` is the name of an existing link or property, and :eql:synopsis:`` is an optional object type that specifies the type of target objects selected or inserted, depending on the context. Shaping Query Results ===================== At the end of the day, EdgeQL has two jobs that are similar, yet distinct: 1) Express the values that we want computed. 2) Arrange the values into a particular shape that we want. Consider the task of getting "names of users and all of the friends' names associated with the given user" in a database defined by the following schema: ..
code-block:: sdl type User { required name: str; multi friends: User; } If we only concern ourselves with getting the values, then a reasonable solution to this might be: .. code-block:: edgeql-repl db> select (User.name, User.friends.name ?? ''); { ('Alice', 'Cameron'), ('Alice', 'Dana'), ('Billie', 'Dana'), ('Cameron', ''), ('Dana', 'Alice'), ('Dana', 'Billie'), ('Dana', 'Cameron'), } This particular solution is very similar to what one might get using SQL. It's equivalent to a table with "user name" and "friend name" columns. It gets the job done, albeit with some redundant repeating of "user names". We can improve things a little and reduce the repetition by aggregating all the friend names into an array: .. code-block:: edgeql-repl db> select (User.name, array_agg(User.friends.name)); { ('Alice', ['Cameron', 'Dana']), ('Billie', ['Dana']), ('Cameron', []), ('Dana', ['Alice', 'Billie', 'Cameron']), } This achieves a couple of things: it's easier to see which friends belong to which user and we no longer need the placeholder ``''`` for those users who don't have friends. The recommended way to get this information in Gel, however, is to use *shapes*, because they mimic the structure of the data and the output: .. code-block:: edgeql-repl db> select User { ... name, ... friends: { ... name ... } ... }; { default::User { name: 'Alice', friends: { default::User {name: 'Cameron'}, default::User {name: 'Dana'}, }, }, default::User {name: 'Billie', friends: {default::User {name: 'Dana'}}}, default::User {name: 'Cameron', friends: {}}, default::User { name: 'Dana', friends: { default::User {name: 'Alice'}, default::User {name: 'Billie'}, default::User {name: 'Cameron'}, }, }, } So far the expression for the data that we wanted was also acceptable for structuring the output, but what if that's not the case? Let's add a condition and only show those users who have friends with either the letter "i" or "o" in their names: .. code-block:: edgeql-repl db> select User { ... 
name, ... friends: { ... name ... } ... } filter .friends.name ilike '%i%' or .friends.name ilike '%o%'; { default::User { name: 'Alice', friends: { default::User {name: 'Cameron'}, default::User {name: 'Dana'}, }, }, default::User { name: 'Dana', friends: { default::User {name: 'Alice'}, default::User {name: 'Billie'}, default::User {name: 'Cameron'}, }, }, } That ``filter`` is getting a bit bulky, so perhaps we can just factor these flags out as part of the shape's computed properties: .. code-block:: edgeql-repl db> select User { ... name, ... friends: { ... name ... }, ... has_i := .friends.name ilike '%i%', ... has_o := .friends.name ilike '%o%', ... } filter .has_i or .has_o; { default::User { name: 'Alice', friends: { default::User {name: 'Cameron'}, default::User {name: 'Dana'}, }, has_i: {false, false}, has_o: {true, false}, }, default::User { name: 'Dana', friends: { default::User {name: 'Alice'}, default::User {name: 'Billie'}, default::User {name: 'Cameron'}, }, has_i: {true, true, false}, has_o: {false, false, true}, }, } It looks like this refactoring came at the cost of putting extra things into the output. In this case we don't want our intermediate calculations to actually show up in the output, so what can we do? In |Gel| the output structure is determined *only* by the expression appearing in the top-level :eql:stmt:`select`. This means that we can move our intermediate calculations into the :eql:kw:`with` block: .. code-block:: edgeql-repl db> with U := ( ... select User { ... has_i := .friends.name ilike '%i%', ... has_o := .friends.name ilike '%o%', ... } ... ) ... select U { ... name, ... friends: { ... name ... }, ... 
} filter .has_i or .has_o; { default::User { name: 'Alice', friends: { default::User {name: 'Cameron'}, default::User {name: 'Dana'}, }, }, default::User { name: 'Dana', friends: { default::User {name: 'Alice'}, default::User {name: 'Billie'}, default::User {name: 'Cameron'}, }, }, } This way we can use ``has_i`` and ``has_o`` in our query without leaking them into the output. General Shaping Rules ===================== In Gel typically all shapes appearing in the top-level :eql:stmt:`select` should be reflected in the output. This also applies to shapes no matter where and how they are nested. Aside from other shapes, this includes nesting in arrays: .. code-block:: edgeql-repl db> select array_agg(User {name}); { [ default::User {name: 'Alice'}, default::User {name: 'Billie'}, default::User {name: 'Cameron'}, default::User {name: 'Dana'}, ], } ... or tuples: .. code-block:: edgeql-repl db> select enumerate(User {name}); { (0, default::User {name: 'Alice'}), (1, default::User {name: 'Billie'}), (2, default::User {name: 'Cameron'}), (3, default::User {name: 'Dana'}), } You can safely access a tuple element and expect the output shape to be intact: .. code-block:: edgeql-repl db> select enumerate(User{name}).1; { default::User {name: 'Alice'}, default::User {name: 'Billie'}, default::User {name: 'Cameron'}, default::User {name: 'Dana'}, } Accessing array elements or working with slices also preserves output shape and is analogous to using ``offset`` and ``limit`` when working with sets: .. code-block:: edgeql-repl db> select array_agg(User {name})[2]; {default::User {name: 'Cameron'}} Losing Shapes ============= There are some situations where shape information gets completely or partially discarded. Any such operation also prevents the altered shape from appearing in the output altogether. In order for the shape to be preserved, the original expression type must be preserved. 
This means that :eql:op:`union` can alter the shape, because the result of a :eql:op:`union` is a :eql:op:`union type `. So you can still refer to the common properties, but not to the properties that appeared in the shape. As mentioned above, since :eql:op:`union` potentially alters the expression shape it never preserves output shape, even when the underlying type wasn't altered: .. code-block:: edgeql-repl db> select User{name} union User{name}; { default::User {id: 7769045a-27bf-11ec-94ea-3f6c0ae59eb3}, default::User {id: 7b42ed20-27bf-11ec-94ea-7700ec77834e}, default::User {id: 7fcedbc4-27bf-11ec-94ea-73dcb6f297a4}, default::User {id: 82f52646-27bf-11ec-94ea-3718ffb8dd15}, default::User {id: 7769045a-27bf-11ec-94ea-3f6c0ae59eb3}, default::User {id: 7b42ed20-27bf-11ec-94ea-7700ec77834e}, default::User {id: 7fcedbc4-27bf-11ec-94ea-73dcb6f297a4}, default::User {id: 82f52646-27bf-11ec-94ea-3718ffb8dd15}, } Listing several items inside a set ``{ ... }`` functions identically to a :eql:op:`union` and so will also produce a union type and remove shape from output. Another subtle way for a type union to remove the shape from the output is by the :eql:op:`?? ` and the :eql:op:`if..else` operators. Both of them determine the result type as the union of the left and right operands: .. code-block:: edgeql-repl db> select {} ?? User {name}; { default::User {id: 7769045a-27bf-11ec-94ea-3f6c0ae59eb3}, default::User {id: 7b42ed20-27bf-11ec-94ea-7700ec77834e}, default::User {id: 7fcedbc4-27bf-11ec-94ea-73dcb6f297a4}, default::User {id: 82f52646-27bf-11ec-94ea-3718ffb8dd15}, } Shapes survive array creation (either via :eql:func:`array_agg` or by using ``[ ... ]``), but they follow the same rules as for :eql:op:`union` for array :eql:op:`concatenation `. Basically the element type of the resulting array must be a union type and thus all shape information is lost: .. 
code-block:: edgeql-repl db> select array_agg(User{name}) ++ array_agg(User{name}); { [ default::User {id: 7769045a-27bf-11ec-94ea-3f6c0ae59eb3}, default::User {id: 7b42ed20-27bf-11ec-94ea-7700ec77834e}, default::User {id: 7fcedbc4-27bf-11ec-94ea-73dcb6f297a4}, default::User {id: 82f52646-27bf-11ec-94ea-3718ffb8dd15}, default::User {id: 7769045a-27bf-11ec-94ea-3f6c0ae59eb3}, default::User {id: 7b42ed20-27bf-11ec-94ea-7700ec77834e}, default::User {id: 7fcedbc4-27bf-11ec-94ea-73dcb6f297a4}, default::User {id: 82f52646-27bf-11ec-94ea-3718ffb8dd15}, ], } .. note:: The :eql:stmt:`for` statement preserves the shape given inside the ``union`` clause, effectively applying the shape to its entire result. ================================================ FILE: docs/reference/reference/edgeql/tx_commit.rst ================================================ .. Portions Copyright (c) 2019 MagicStack Inc. and the Gel authors. Portions Copyright (c) 1996-2018, PostgreSQL Global Development Group Portions Copyright (c) 1994, The Regents of the University of California Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement is hereby granted, provided that the above copyright notice and this paragraph and the following two paragraphs appear in all copies. IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. 
THE SOFTWARE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF CALIFORNIA HAS NO OBLIGATIONS TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. .. _ref_eql_statements_commit_tx: Commit ====== :eql-statement: ``commit`` -- commit the current transaction .. eql:synopsis:: commit ; Example ------- Commit the current transaction: .. code-block:: edgeql commit; Description ----------- The ``commit`` command commits the current transaction. All changes made by the transaction become visible to others and are guaranteed to be durable if a crash occurs. .. list-table:: :class: seealso * - **See also** * - :ref:`Reference > EdgeQL > Start transaction ` * - :ref:`Reference > EdgeQL > Rollback ` * - :ref:`Reference > EdgeQL > Declare savepoint ` * - :ref:`Reference > EdgeQL > Rollback to savepoint ` * - :ref:`Reference > EdgeQL > Release savepoint ` ================================================ FILE: docs/reference/reference/edgeql/tx_rollback.rst ================================================ .. Portions Copyright (c) 2019 MagicStack Inc. and the Gel authors. Portions Copyright (c) 1996-2018, PostgreSQL Global Development Group Portions Copyright (c) 1994, The Regents of the University of California Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement is hereby granted, provided that the above copyright notice and this paragraph and the following two paragraphs appear in all copies. IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF CALIFORNIA HAS NO OBLIGATIONS TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. .. _ref_eql_statements_rollback_tx: Rollback ======== :eql-statement: ``rollback`` -- abort the current transaction .. eql:synopsis:: rollback ; Example ------- Abort the current transaction: .. code-block:: edgeql rollback; Description ----------- The ``rollback`` command rolls back the current transaction and causes all updates made by the transaction to be discarded. .. list-table:: :class: seealso * - **See also** * - :ref:`Reference > EdgeQL > Start transaction ` * - :ref:`Reference > EdgeQL > Commit ` * - :ref:`Reference > EdgeQL > Declare savepoint ` * - :ref:`Reference > EdgeQL > Rollback to savepoint ` * - :ref:`Reference > EdgeQL > Release savepoint ` ================================================ FILE: docs/reference/reference/edgeql/tx_sp_declare.rst ================================================ .. Portions Copyright (c) 2019 MagicStack Inc. and the Gel authors. Portions Copyright (c) 1996-2018, PostgreSQL Global Development Group Portions Copyright (c) 1994, The Regents of the University of California Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement is hereby granted, provided that the above copyright notice and this paragraph and the following two paragraphs appear in all copies. IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF CALIFORNIA HAS NO OBLIGATIONS TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. .. _ref_eql_statements_declare_savepoint: Declare savepoint ================= :eql-statement: ``declare savepoint`` -- declare a savepoint within the current transaction .. eql:synopsis:: declare savepoint ; Description ----------- ``declare savepoint`` establishes a new savepoint within the current transaction. A savepoint is a special mark inside a transaction that allows all commands that are executed after it was established to be rolled back, restoring the transaction state to what it was at the time of the savepoint. It is an error to declare a savepoint outside of a transaction. Example ------- .. code-block:: edgeql # Will select no objects: select test::TestObject { name }; start transaction; insert test::TestObject { name := 'q1' }; insert test::TestObject { name := 'q2' }; # Will select two TestObjects with names 'q1' and 'q2' select test::TestObject { name }; declare savepoint f1; insert test::TestObject { name:='w1' }; # Will select three TestObjects with names # 'q1', 'q2', and 'w1' select test::TestObject { name }; rollback to savepoint f1; # Will select two TestObjects with names 'q1' and 'q2' select test::TestObject { name }; rollback; .. list-table:: :class: seealso * - **See also** * - :ref:`Reference > EdgeQL > Start transaction ` * - :ref:`Reference > EdgeQL > Commit ` * - :ref:`Reference > EdgeQL > Rollback ` * - :ref:`Reference > EdgeQL > Rollback to savepoint ` * - :ref:`Reference > EdgeQL > Release savepoint ` ================================================ FILE: docs/reference/reference/edgeql/tx_sp_release.rst ================================================ .. 
Portions Copyright (c) 2019 MagicStack Inc. and the Gel authors. Portions Copyright (c) 1996-2018, PostgreSQL Global Development Group Portions Copyright (c) 1994, The Regents of the University of California Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement is hereby granted, provided that the above copyright notice and this paragraph and the following two paragraphs appear in all copies. IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF CALIFORNIA HAS NO OBLIGATIONS TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. .. _ref_eql_statements_release_savepoint: Release savepoint ================= :eql-statement: ``release savepoint`` -- release a previously declared savepoint .. eql:synopsis:: release savepoint ; Description ----------- ``release savepoint`` destroys a savepoint previously defined in the current transaction. Destroying a savepoint makes it unavailable as a rollback point, but it has no other user visible behavior. It does not undo the effects of commands executed after the savepoint was established. (To do that, see :eql:stmt:`rollback to savepoint`.) ``release savepoint`` also destroys all savepoints that were established after the named savepoint was established. Example ------- .. code-block:: edgeql start transaction; # ... declare savepoint f1; # ... release savepoint f1; # ... rollback; .. 
list-table:: :class: seealso * - **See also** * - :ref:`Reference > EdgeQL > Start transaction ` * - :ref:`Reference > EdgeQL > Commit ` * - :ref:`Reference > EdgeQL > Rollback ` * - :ref:`Reference > EdgeQL > Declare savepoint ` * - :ref:`Reference > EdgeQL > Rollback to savepoint ` ================================================ FILE: docs/reference/reference/edgeql/tx_sp_rollback.rst ================================================ .. Portions Copyright (c) 2019 MagicStack Inc. and the Gel authors. Portions Copyright (c) 1996-2018, PostgreSQL Global Development Group Portions Copyright (c) 1994, The Regents of the University of California Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement is hereby granted, provided that the above copyright notice and this paragraph and the following two paragraphs appear in all copies. IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF CALIFORNIA HAS NO OBLIGATIONS TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. .. _ref_eql_statements_rollback_savepoint: Rollback to savepoint ===================== :eql-statement: ``rollback to savepoint`` -- roll back to a savepoint within the current transaction .. eql:synopsis:: rollback to savepoint ; Description ----------- Roll back all commands that were executed after the savepoint was established. 
The savepoint remains valid and can be rolled back to again later, if needed. ``rollback to savepoint`` implicitly destroys all savepoints that were established after the named savepoint. Example ------- .. code-block:: edgeql start transaction; # ... declare savepoint f1; # ... rollback to savepoint f1; # ... rollback; .. list-table:: :class: seealso * - **See also** * - :ref:`Reference > EdgeQL > Start transaction ` * - :ref:`Reference > EdgeQL > Commit ` * - :ref:`Reference > EdgeQL > Rollback ` * - :ref:`Reference > EdgeQL > Declare savepoint ` * - :ref:`Reference > EdgeQL > Release savepoint ` ================================================ FILE: docs/reference/reference/edgeql/tx_start.rst ================================================ .. Portions Copyright (c) 2019 MagicStack Inc. and the Gel authors. Portions Copyright (c) 1996-2018, PostgreSQL Global Development Group Portions Copyright (c) 1994, The Regents of the University of California Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement is hereby granted, provided that the above copyright notice and this paragraph and the following two paragraphs appear in all copies. IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF CALIFORNIA HAS NO OBLIGATIONS TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. .. 
_ref_eql_statements_start_tx: Start transaction ================= :eql-statement: ``start transaction`` -- start a transaction .. eql:synopsis:: start transaction [ , ... ] ; # where is one of: isolation repeatable read isolation serializable read write | read only deferrable | not deferrable Description ----------- This command starts a new transaction block. Any Gel command outside of an explicit transaction block starts an implicit transaction block; the transaction is then automatically committed if the command was executed successfully, or automatically rolled back if there was an error. This behavior is often called "autocommit". When isolation is not specified, it defaults to ``serializable``. Parameters ---------- The :eql:synopsis:`` can be one of the following: :eql:synopsis:`isolation serializable` All statements in the current transaction can only see data changes that were committed before the first query or data modification statement was executed within this transaction. If a pattern of reads and writes among concurrent serializable transactions creates a situation that could not have occurred in any serial (one-at-a-time) execution of those transactions, one of them will be rolled back with a serialization failure. This level is the default isolation level. Note that, compared to ``repeatable read``, the ``serializable`` level has a significantly higher probability of resulting in serialization failures, which require the whole transaction to be retried. If acceptable, consider using ``repeatable read`` or :ref:`prefer repeatable read `. :eql:synopsis:`isolation repeatable read` All statements in the current transaction can only see data changes that were committed before the first query or data modification statement was executed within this transaction. Compared to ``serializable``, this level is less likely to result in serialization failures. It is however possible for this level to allow serialization anomalies. 
A serialization anomaly is an outcome of a series of transactions that would not be possible if they were executed serially instead of concurrently. For example, assume ``type X { is_selected: bool }`` and the following queries: .. code-block:: edgeql # transaction A: unselect all selected update X filter .is_selected set { is_selected := false }; # transaction B: select all unselected update X filter not .is_selected set { is_selected := true }; Running these two transactions serially would result in either all ``X`` being selected or none being selected. But if executed concurrently, even with the ``repeatable read`` isolation level, we can end up with some ``X`` being selected and some not. To avoid this, we can use the ``serializable`` isolation level. :eql:synopsis:`read write` Sets the transaction access mode to read/write. This is the default. :eql:synopsis:`read only` Sets the transaction access mode to read-only. Any data modifications with :eql:stmt:`insert`, :eql:stmt:`update`, or :eql:stmt:`delete` are disallowed. Schema mutations via :ref:`DDL ` are also disallowed. :eql:synopsis:`deferrable` The transaction can be set to deferrable mode only when it is ``serializable`` and ``read only``. When all three of these properties are selected for a transaction, the transaction may block when first acquiring its snapshot, after which it is able to run without the normal overhead of a ``serializable`` transaction and without any risk of contributing to or being canceled by a serialization failure. This mode is well suited for long-running reports or backups. Examples -------- Start a new transaction and roll it back: .. code-block:: edgeql start transaction; select 'Hello World!'; rollback; Start a serializable deferrable transaction: .. code-block:: edgeql start transaction isolation serializable, read only, deferrable; .. 
_prefer_repeatable_read: Prefer repeatable read ---------------------- In addition to the isolation levels above, some client libraries also support ``PreferRepeatableRead`` as a transaction isolation level. In this mode, the server will analyze the query and use ``repeatable read`` isolation level if it can. When it cannot, it will use ``serializable`` isolation level. Client libraries that currently support this mode: * TypeScript/JS * Python * Go .. list-table:: :class: seealso * - **See also** * - :ref:`Reference > EdgeQL > Commit ` * - :ref:`Reference > EdgeQL > Rollback ` * - :ref:`Reference > EdgeQL > Declare savepoint ` * - :ref:`Reference > EdgeQL > Rollback to savepoint ` * - :ref:`Reference > EdgeQL > Release savepoint ` ================================================ FILE: docs/reference/reference/edgeql/update.rst ================================================ .. _ref_eql_statements_update: Update ====== :eql-statement: :eql-haswith: ``update`` -- update objects in a database .. eql:synopsis:: [ with [, ...] ] update [ filter ] set ; ``update`` changes the values of the specified links in all objects selected by *update-selector-expr* and, optionally, filtered by *filter-expr*. :eql:synopsis:`with` Alias declarations. The ``with`` clause allows specifying module aliases as well as expression aliases that can be referenced by the ``update`` statement. See :ref:`ref_eql_statements_with` for more information. :eql:synopsis:`update ` An arbitrary expression returning a set of objects to be updated. :eql:synopsis:`filter ` An expression of type :eql:type:`bool` used to filter the set of updated objects. :eql:synopsis:`` is an expression that has a result of type :eql:type:`bool`. Only objects that satisfy the filter expression will be updated. See the description of the ``filter`` clause of the :eql:stmt:`select` statement for more information. :eql:synopsis:`set ` A shape expression with the new values for the links of the updated object. 
There are three possible assignment operations permitted within the ``set`` shape: .. eql:synopsis:: set { := [, ...] } set { += [, ...] } set { -= [, ...] } The most basic assignment is the ``:=``, which just sets the :eql:synopsis:`` to the specified :eql:synopsis:``. The ``+=`` and ``-=`` either add or remove the set of values specified by the :eql:synopsis:`` from the *current* value of the :eql:synopsis:``. Output ~~~~~~ On successful completion, an ``update`` statement returns the set of updated objects. Examples ~~~~~~~~ Here are a couple of examples of the ``update`` statement with simple assignments using ``:=``: .. code-block:: edgeql # update the user with the name 'Alice Smith' with module example update User filter .name = 'Alice Smith' set { name := 'Alice J. Smith' }; # update all users whose name starts with 'Bob' with module example update User filter .name like 'Bob%' set { name := User.name ++ '*' }; For usage of ``+=`` and ``-=`` consider the following ``Post`` type: .. code-block:: sdl # ... Assume some User type is already defined type Post { required title: str; required body: str; # A "tags" property containing a set of strings multi tags: str; author: User; } The following queries add or remove tags from some user's posts: .. code-block:: edgeql with module example update Post filter .author.name = 'Alice Smith' set { # add tags tags += {'example', 'edgeql'} }; with module example update Post filter .author.name = 'Alice Smith' set { # remove a tag, if it exists tags -= 'todo' }; The statement ``for in `` allows certain bulk updates to be expressed more clearly. See :ref:`ref_eql_forstatement` for more details. .. list-table:: :class: seealso * - **See also** * - :ref:`EdgeQL > Update ` * - :ref:`Cheatsheets > Updating data ` ================================================ FILE: docs/reference/reference/edgeql/volatility.rst ================================================ .. 
_ref_reference_volatility: Volatility ========== The **volatility** of an expression refers to how its value may change across successive evaluations. Expressions may have one of the following volatilities, in order of increasing volatility: * ``Immutable``: The expression cannot modify the database and is guaranteed to have the same value *in all statements*. * ``Stable``: The expression cannot modify the database and is guaranteed to have the same value *within a single statement*. * ``Volatile``: The expression cannot modify the database and can have different values on successive evaluations. * ``Modifying``: The expression can modify the database and can have different values on successive evaluations. Expressions ----------- All :ref:`primitives `, :ref:`ranges `, and :ref:`multiranges ` are ``Immutable``. :ref:`Arrays `, :ref:`tuples `, and :ref:`sets ` have the volatility of their most volatile component. :ref:`Globals ` are always ``Stable``, even computed globals with an immutable expression. Objects and shapes ^^^^^^^^^^^^^^^^^^ :ref:`Objects ` are generally ``Stable`` except: * Objects with a :ref:`shape ` containing a more volatile computed pointer will have the volatility of its most volatile component. * :ref:`Free objects ` have the volatility of their most volatile component. They may be ``Immutable``. An object's non-computed pointers are ``Stable``. Its computed pointers have the volatility of their expressions. Any DML (i.e., :ref:`insert `, :ref:`update `, :ref:`delete `) is ``Modifying``. Functions and operators ^^^^^^^^^^^^^^^^^^^^^^^ Unless explicitly specified, a :ref:`function's ` volatility will be inferred from its body expression. A function call's volatility is the highest of the volatilities of its body expression and its call arguments. Given: .. 
code-block:: sdl # Immutable function plus_primitive(x: float64) -> float64 using (x + 1); # Stable global one := 1; function plus_global(x: float64) -> float64 using (x + global one); # Volatile function plus_random(x: float64) -> float64 using (x + random()); # Modifying type One { val := 1; }; function plus_insert(x: float64) -> float64 using (x + (insert One).val); Some example operator and function calls: .. code-block:: 1 + 1: Immutable 1 + global one: Stable global one + random(): Volatile (insert One).val: Modifying plus_primitive(1): Immutable plus_global(1): Stable plus_random(global one): Volatile plus_insert(random()): Modifying Restrictions ------------ Some features restrict the volatility of expressions. An expression with a lower volatility than the stated one can always be used. :ref:`Index ` expressions must be ``Immutable``. Within the index, pointers to the indexed object are treated as immutable. :ref:`Constraint ` expressions must be ``Immutable``. Within the constraint, the ``__subject__`` and its pointers are treated as immutable. :ref:`Access policies ` must be ``Stable``. :ref:`Aliases `, :ref:`globals `, and :ref:`computed pointers ` in the schema must be ``Stable``. The :ref:`cartesian product ` of a ``Volatile`` or ``Modifying`` expression is not allowed. .. code-block:: edgeql-repl db> SELECT {1, 2} + random() QueryError: can not take cross product of volatile operation ``Modifying`` expressions are not allowed in a non-scalar argument to a function, except for :ref:`standard set functions `. The non-optional parameters of ``Modifying`` :ref:`functions ` must have a :ref:`cardinality ` of ``One``. Optional parameters must have a cardinality of ``AtMostOne``. ================================================ FILE: docs/reference/reference/edgeql/with.rst ================================================ .. _ref_eql_statements_with: With block ========== .. eql:keyword:: with The ``with`` block in EdgeQL is used to define aliases. 
The expression aliases are evaluated in the lexical scope they appear in, not the scope where their alias is used. This means that refactoring queries using aliases must be done with care so as not to alter the query semantics. Specifying a module +++++++++++++++++++ .. eql:keyword:: module Used inside a ``with`` block to specify module names. One of the more basic and common uses of the ``with`` block is to specify the default module that is used in a query. The ``with module `` construct indicates that whenever an identifier is used without any module specified explicitly, the module will default to ```` and then fall back to built-ins from the ``std`` module. The following queries are exactly equivalent: .. code-block:: edgeql with module example select User { name, owned := (select User. With ` ================================================ FILE: docs/reference/reference/index.rst ================================================ .. versioned-section:: .. _ref_reference_index: ========= Reference ========= .. toctree:: :maxdepth: 3 :hidden: edgeql/index This section contains comprehensive reference documentation on the internals of |Gel|, the binary protocol, the formal syntax of EdgeQL, and more. ================================================ FILE: docs/reference/running/admin/configure.rst ================================================ .. _ref_eql_statements_configure: Configure ========= :eql-statement: ``configure`` -- change a server configuration parameter .. eql:synopsis:: configure {session | current branch | instance} set := ; configure instance insert ; configure {session | current branch | instance} reset ; configure {current branch | instance} reset [ filter ] ; .. note:: Prior to |Gel| and |EdgeDB| 5.0 *branches* were called *databases*. ``configure current branch`` used to be called ``configure current database``, which is still supported for backwards compatibility. Description ----------- This command allows altering the server configuration. 
The effects of :eql:synopsis:`configure session` last until the end of the current session. Some configuration parameters cannot be modified by :eql:synopsis:`configure session` and can only be set by :eql:synopsis:`configure instance`. :eql:synopsis:`configure current branch` is used to configure an individual Gel branch within a server instance with the changes persisted across server restarts. :eql:synopsis:`configure instance` is used to configure the entire Gel instance with the changes persisted across server restarts. This variant acts directly on the file system and cannot be rolled back, so it cannot be used in a transaction block. The :eql:synopsis:`configure instance insert` variant is used for composite configuration parameters, such as ``Auth``. Parameters ---------- :eql:synopsis:`` The name of a primitive configuration parameter. Available configuration parameters are described in the :ref:`ref_std_cfg` section. :eql:synopsis:`` The name of a composite configuration value class. Available configuration classes are described in the :ref:`ref_std_cfg` section. :eql:synopsis:`` An expression that returns a value of type :eql:type:`std::bool`. Only configuration objects matching this condition will be affected. Examples -------- Set the ``listen_addresses`` parameter: .. code-block:: edgeql configure instance set listen_addresses := {'127.0.0.1', '::1'}; Set the ``query_work_mem`` parameter: .. code-block:: edgeql configure instance set query_work_mem := '4MiB'; Add a Trust authentication method for "my_user": .. code-block:: edgeql configure instance insert Auth { priority := 1, method := (insert Trust), user := 'my_user' }; Remove all Trust authentication methods: .. code-block:: edgeql configure instance reset Auth filter Auth.method is Trust; ================================================ FILE: docs/reference/running/admin/index.rst ================================================ .. 
_ref_admin: Administration ============== Administrative commands for managing Gel: * :ref:`configure ` Configure server behavior. * :ref:`role ` Create, remove, or alter a role. .. versionadded:: 5.0 New administrative commands were added in the |EdgeDB| 5 release: * :ref:`branch ` Create, remove, or alter a branch. * :ref:`administer statistics_update() ` Update internal statistics about data. * :ref:`administer vacuum() ` Reclaim storage space. .. toctree:: :maxdepth: 3 :hidden: configure roles statistics_update vacuum ================================================ FILE: docs/reference/running/admin/roles.rst ================================================ .. _ref_admin_roles: ===== Roles ===== :edb-alt-title: Roles This section describes the administrative commands pertaining to *roles*. Create role =========== :eql-statement: Create a role. .. eql:synopsis:: create superuser role [ extending [, ...] ] "{" ; [...] "}" ; # where is one of set password := Description ----------- The command ``create role`` defines a new database role. :eql:synopsis:`superuser` If specified, the created role will have the *superuser* status, and will be exempt from :ref:`all permission checks`. Prior to version 7.0, the ``superuser`` qualifier was mandatory, i.e. it was not possible to create non-superuser roles. :eql:synopsis:`` The name of the role to create. :eql:synopsis:`extending [, ...]` If specified, declares the parent roles for this role. The role inherits all the privileges of the parents. The following subcommands are allowed in the ``create role`` block: :eql:synopsis:`set password := ` Set the password for the role. .. versionadded:: 7.0 :eql:synopsis:`set permissions := ` Set :ref:`permissions ` for the role. The value is a set of identifiers of either built-in permissions or permissions defined in the schema. Roles also gain the permissions of their base roles. Roles that are *superusers* are implicitly granted all permissions, so setting this does not have any effect. 
Note that permission names are not validated and it is possible to reference a permission that does not yet exist in any schema. :eql:synopsis:`set branches := ` Configure a set of branches that this role is allowed to access. When connecting to an instance branch that is not in this set, the connection will be refused. If set to ``'*'``, this role can connect to all branches of the instance. Defaults to ``'*'``. Examples -------- Create a new role: .. code-block:: edgeql create role alice { set password := 'wonderland'; set permissions := { sys::perm::data_modification, sys::perm::query_stats, cfg::perm::configure_timeouts, cfg::perm::configure_apply_access_policies, ext::auth::perm::auth_read, ext::auth::perm::auth_write, }; set branches := {'main', 'staging'}; }; Alter role ========== :eql-statement: Alter an existing role. .. eql:synopsis:: alter role "{" ; [...] "}" ; # where is one of rename to set password := extending ... Description ----------- The command ``alter role`` changes the settings of an existing role. :eql:synopsis:`` The name of the role to alter. The following subcommands are allowed in the ``alter role`` block: :eql:synopsis:`rename to ` Change the name of the role to *newname*. :eql:synopsis:`extending ...` Alter the role parent list. The full syntax of this subcommand is: .. eql:synopsis:: extending [, ...] [ first | last | before | after ] This subcommand makes the role a child of the specified list of parent roles. The role inherits all the privileges of the parents. It is possible to specify the position in the parent list using the following optional keywords: * ``first`` -- insert parent(s) at the beginning of the parent list, * ``last`` -- insert parent(s) at the end of the parent list, * ``before `` -- insert parent(s) before an existing *parent*, * ``after `` -- insert parent(s) after an existing *parent*. .. versionadded:: 7.0 :eql:synopsis:`set permissions := ` Set :ref:`permissions ` for the role. 
The value is a set of identifiers of either built-in permissions or permissions defined in the schema. Roles that are *superusers* are implicitly granted all permissions, so setting this does not have any effect. Note that permission names are not validated and it is possible to reference a permission that does not yet exist in the schema. :eql:synopsis:`set branches := ` Configure a set of branches that this role is allowed to access. When connecting to an instance branch that is not in this set, the connection will be refused. If set to ``'*'``, this role can connect to all branches of the instance. Defaults to ``'*'``. Examples -------- Alter a role: .. code-block:: edgeql alter role alice { set password := 'new password'; set branches := {'*'}; }; Drop role ========= :eql-statement: Remove a role. .. eql:synopsis:: drop role ; Description ----------- The command ``drop role`` removes an existing role. Examples -------- Remove a role: .. code-block:: edgeql drop role alice; ================================================ FILE: docs/reference/running/admin/statistics_update.rst ================================================ .. versionadded:: 6.0 .. _ref_admin_statistics_update: ============================== administer statistics_update() ============================== :eql-statement: Update internal statistics about data. .. eql:synopsis:: administer statistics_update "(" [ [, ...]] ")" Description ----------- Updates statistics about the contents of data in the current branch. Subsequently, the query planner uses these statistics to help determine the most efficient execution plans for queries. :eql:synopsis:`` If a type name or a path to a link or property is specified, that data will be targeted for statistics update. If omitted, all user-accessible data will be analyzed. Examples -------- Update the statistics on type ``SomeType``: .. code-block:: edgeql administer statistics_update(SomeType); Update statistics of type ``SomeType`` and the link ``OtherType.ptr``: .. 
code-block:: edgeql administer statistics_update(SomeType, OtherType.ptr); Update statistics on everything that is user-accessible in the database: .. code-block:: edgeql administer statistics_update(); ================================================ FILE: docs/reference/running/admin/vacuum.rst ================================================ .. versionadded:: 5.0 .. _ref_admin_vacuum: ====== Vacuum ====== :eql-statement: Reclaim storage space. .. eql:synopsis:: administer vacuum "(" [ [, ...]] [, full := {true | false}] [, statistics_update := {true | false}] ")" Description ----------- Cleans and reclaims storage by removing obsolete data. :eql:synopsis:`` If a type name or a path to a link or property is specified, that data will be targeted for the vacuum operation. If omitted, all user-accessible data will be targeted. :eql:synopsis:`full := {true | false}` If set to ``true``, an exclusive lock is obtained and reclaimed space is returned to the operating system. If set to ``false`` or if not set, the command can operate alongside normal reading and writing of the database and reclaimed space is kept available for reuse in the database, reducing the rate of growth of the database. :eql:synopsis:`statistics_update := {true | false}` If set to ``true``, updates statistics used by the planner to determine the most efficient way to execute queries on specified data. See also :ref:`administer statistics_update() `. Examples -------- Vacuum the type ``SomeType``: .. code-block:: edgeql administer vacuum(SomeType); Vacuum the type ``SomeType`` and the link ``OtherType.ptr`` and return reclaimed space to the operating system: .. code-block:: edgeql administer vacuum(SomeType, OtherType.ptr, full := true); Vacuum everything that is user-accessible in the database: .. code-block:: edgeql administer vacuum(); ================================================ FILE: docs/reference/running/backend_ha.rst ================================================ ..
_ref_backend_ha: Backend high-availability ========================= High availability is a sophisticated and systematic challenge, especially for databases. To address the problem, Gel server now supports selected highly-available backend Postgres clusters, namely in 2 categories: * API-based HA * Adaptive HA without API When the backend HA feature is enabled in Gel, Gel server will try its best to detect and react to backend failovers, whether a proper API is available or not. During backend failover, no frontend connections will be closed; instead, all incoming queries will fail with a retryable error until failover has completed successfully. If the query originates from a client that supports retrying transactions, these queries may be retried by the client until the backend connection is restored and the query can be properly resolved. API-based HA ------------ |Gel| server accepts different types of backends by looking into the protocol of the ``--backend-dsn`` command-line parameter. Gel supports the following DSN protocols currently: * ``stolon+consul+http://`` * ``stolon+consul+https://`` When using these protocols, Gel builds the actual DSN of the cluster's leader node by calling the corresponding API using credentials in the ``--backend-dsn`` and subscribes to that API for failover events. Once failover is detected, Gel drops all backend connections and routes all new backend connections to the new leader node. `Stolon `_ is an open-source cloud native PostgreSQL manager for PostgreSQL high availability. Currently, Gel supports using a Stolon cluster as the backend in a Consul-based setup, where Gel acts as a Stolon proxy. This way, you only need to manage Stolon sentinels and keepers, plus a Consul deployment. To use a Stolon cluster, run Gel server with a DSN, like so: .. 
code-block:: bash $ gel-server \ --backend-dsn stolon+consul+http://localhost:8500/my-cluster |Gel| will connect to the Consul HTTP service at ``localhost:8500``, and subscribe to the updates of the cluster named ``my-cluster``. Using a regular ``postgres://`` DSN disables API-based HA. Adaptive HA ----------- |Gel| also supports DNS-based generic HA backends. This may be a cloud database with multi-AZ failover or some custom HA Postgres cluster that keeps a DNS name always resolved to the leader node. Adaptive HA can be enabled with a switch in addition to a regular backend DSN: .. code-block:: bash $ gel-server \ --backend-dsn postgres://xxx.rds.amazonaws.com \ --enable-backend-adaptive-ha Once enabled, Gel server will keep track of unusual backend events like unexpected disconnects or Postgres shutdown notifications. When a threshold is reached, Gel considers the backend to be in the "failover" state. It then drops all current backend connections and tries to re-establish connections using the same backend DSN. Because Gel doesn't cache resolved DNS values, the new connections will be established with the new leader node. Under the hood of adaptive HA, Gel maintains a state machine to avoid endless switch-overs in an unstable network. State changes only happen when certain conditions are met. **Set of possible states:** * ``Healthy`` - all is good * ``Unhealthy`` - a staging state before failover * ``Failover`` - backend failover is in process **Rules of state switches:** ``Unhealthy`` -> ``Healthy`` * Successfully connected to a non-hot-standby backend. ``Unhealthy`` -> ``Failover`` * More than 60% (configurable with environment variable :gelenv:`SERVER_BACKEND_ADAPTIVE_HA_DISCONNECT_PERCENT`) of existing pgcons are "unexpectedly disconnected" (number of existing pgcons is captured at the moment we change to ``Unhealthy`` state, and maintained on "expected disconnects" too).
* (and) In ``Unhealthy`` state for more than 30 seconds (:gelenv:`SERVER_BACKEND_ADAPTIVE_HA_UNHEALTHY_MIN_TIME`). * (and) sys_pgcon is down. * (or) Postgres shutdown/hot-standby notification received. ``Healthy`` -> ``Unhealthy`` * Any unexpected disconnect. ``Healthy`` -> ``Failover`` * Postgres shutdown/hot-standby notification received. ``Failover`` -> ``Healthy`` * Successfully connected to a non-hot-standby backend. * (and) sys_pgcon is healthy. ("pgcon" is a code name for backend connections, and "sys_pgcon" is a special backend connection which Gel uses to talk to the "Gel system database".) ================================================ FILE: docs/reference/running/configuration.rst ================================================ .. _ref_admin_config: ============= Configuration ============= The behavior of the Gel server is configurable with sensible defaults. Some configuration can be set on the running instance using configuration parameters, while other configuration is set at startup using environment variables or command line arguments to the |gel-server| binary. Configuration parameters ======================== |Gel| exposes a number of configuration parameters that affect its behavior. In this section we review the ways to change the server configuration, as well as detail each available configuration parameter. EdgeQL ------ The :eql:stmt:`configure` command can be used to set the configuration parameters using EdgeQL. For example, you can use the CLI REPL to set the ``listen_addresses`` parameter: .. code-block:: edgeql-repl gel> configure instance set listen_addresses := {'127.0.0.1', '::1'}; CONFIGURE: OK CLI --- The :ref:`ref_cli_gel_configure` command allows modifying the system configuration from a terminal or a script: .. code-block:: bash $ gel configure set listen_addresses 127.0.0.1 ::1 Configuration parameters ======================== :edb-alt-title: Available Configuration Parameters .. 
_ref_admin_config_connection: Connection settings ------------------- .. api-index:: listen_addresses, listen_port, cors_allow_origins :eql:synopsis:`listen_addresses: multi str` Specifies the TCP/IP address(es) on which the server is to listen for connections from client applications. If the list is empty, the server does not listen on any IP interface at all. :eql:synopsis:`listen_port: int16` The TCP port the server listens on; ``5656`` by default. Note that the same port number is used for all IP addresses the server listens on. :eql:synopsis:`cors_allow_origins: multi str` Origins that will be calling the server that need Cross-Origin Resource Sharing (CORS) support. Can use ``*`` to allow any origin. When HTTP clients make a preflight request to the server, the origins allowed here will be added to the ``Access-Control-Allow-Origin`` header in the response. Resource usage -------------- .. api-index:: effective_io_concurrency, query_work_mem, shared_buffers :eql:synopsis:`effective_io_concurrency: int64` Sets the number of concurrent disk I/O operations that can be executed simultaneously. Corresponds to the PostgreSQL configuration parameter of the same name. :eql:synopsis:`query_work_mem: cfg::memory` The amount of memory used by internal query operations such as sorting. Corresponds to the PostgreSQL ``work_mem`` configuration parameter. :eql:synopsis:`shared_buffers: cfg::memory` The amount of memory the database uses for shared memory buffers. Corresponds to the PostgreSQL configuration parameter of the same name. Changing this value requires server restart. Query planning -------------- .. api-index:: default_statistics_target, effective_cache_size :eql:synopsis:`default_statistics_target: int64` Sets the default data statistics target for the planner. Corresponds to the PostgreSQL configuration parameter of the same name. 
:eql:synopsis:`effective_cache_size: cfg::memory` Sets the planner's assumption about the effective size of the disk cache that is available to a single query. Corresponds to the PostgreSQL configuration parameter of the same name. Query cache ----------- .. versionadded:: 5.0 .. api-index:: auto_rebuild_query_cache, query_cache_mode, cfg::QueryCacheMode :eql:synopsis:`auto_rebuild_query_cache: bool` Determines whether to recompile the existing query cache to SQL any time DDL is executed. :eql:synopsis:`query_cache_mode: cfg::QueryCacheMode` Allows the developer to set where the query cache is stored. Possible values: * ``cfg::QueryCacheMode.InMemory``- All query cache is lost on server restart. This mirrors pre-5.0 |EdgeDB| behavior. * ``cfg::QueryCacheMode.RegInline``- The in-memory query cache is also stored in the database as-is so it can be restored on restart. * ``cfg::QueryCacheMode.Default``- Allow the server to select the best caching option. Currently, it will select ``InMemory`` for arm64 Linux and ``RegInline`` for everything else. * ``cfg::QueryCacheMode.PgFunc``- Wraps queries into stored functions in Postgres and reduces backend request size and preparation time. Query behavior -------------- .. api-index:: allow_bare_ddl, cfg::AllowBareDDL, apply_access_policies, apply_access_policies_pg, force_database_error :eql:synopsis:`allow_bare_ddl: cfg::AllowBareDDL` Allows for running bare DDL outside a migration. Possible values are ``cfg::AllowBareDDL.AlwaysAllow`` and ``cfg::AllowBareDDL.NeverAllow``. When you create an instance, this is set to ``cfg::AllowBareDDL.AlwaysAllow`` until you run a migration. At that point it is set to ``cfg::AllowBareDDL.NeverAllow`` because it's generally a bad idea to mix migrations with bare DDL. .. _ref_std_cfg_apply_access_policies: :eql:synopsis:`apply_access_policies: bool` Determines whether access policies should be applied when running queries. 
Setting this to ``false`` effectively puts you into super-user mode, ignoring any access policies that might otherwise limit you on the instance. .. note:: This setting can also be conveniently accessed via the "Config" dropdown menu at the top of the Gel UI (accessible by running the CLI command :gelcmd:`ui` from within a project). The setting will apply only to your UI session, so you won't have to remember to re-enable it when you're done. :eql:synopsis:`apply_access_policies_pg -> bool` Determines whether access policies should be applied when running queries over SQL adapter. Defaults to ``false``. :eql:synopsis:`force_database_error -> str` A hook to force all queries to produce an error. Defaults to 'false'. .. note:: This parameter takes a ``str`` instead of a ``bool`` to allow more verbose messages when all queries are forced to fail. The database will attempt to deserialize this ``str`` into a JSON object that must include a ``type`` (which must be a Gel :ref:`error type ` name), and may also include ``message``, ``hint``, and ``details`` which can be set ad-hoc by the user. For example, the following is valid input: ``'{ "type": "QueryError", "message": "Did not work", "hint": "Try doing something else", "details": "Indeed, something went really wrong" }'`` As is this: ``'{ "type": "UnknownParameterError" }'`` Transaction behavior -------------------- .. api-index:: default_transaction_isolation, default_transaction_access_mode, default_transaction_deferrable .. versionadded:: 6.0 These settings will affect both explicit transactions as well as the implicit transactions that each query runs in. :eql:synopsis:`default_transaction_isolation -> sys::TransactionIsolation` Controls the default isolation level of each new transaction, including implicit transactions. Defaults to ``sys::TransactionIsolation.Serializable``. 
* ``sys::TransactionIsolation.RepeatableRead`` * ``sys::TransactionIsolation.Serializable`` (default) :eql:synopsis:`default_transaction_access_mode -> sys::TransactionAccessMode` Controls the default read-only status of each new transaction, including implicit transactions. Defaults to ``sys::TransactionAccessMode.ReadWrite``. * ``sys::TransactionAccessMode.ReadOnly`` * ``sys::TransactionAccessMode.ReadWrite`` (default) :eql:synopsis:`default_transaction_deferrable -> sys::TransactionDeferrability` Controls the default deferrable status of each new transaction. It currently has no effect on read-write transactions. Defaults to ``sys::TransactionDeferrability.NotDeferrable``. * ``sys::TransactionDeferrability.Deferrable`` * ``sys::TransactionDeferrability.NotDeferrable`` (default) .. _ref_std_cfg_client_connections: Client connections ------------------ .. api-index:: allow_user_specified_id, session_idle_timeout, session_idle_transaction_timeout, query_execution_timeout :eql:synopsis:`allow_user_specified_id: bool` Makes it possible to set the ``.id`` property when inserting new objects. .. warning:: Enabling this feature introduces some security vulnerabilities: 1. An unprivileged user can discover ids that already exist in the database by trying to insert new values and noting when there is a constraint violation on ``.id`` even if the user doesn't have access to the relevant table. 2. It allows re-using object ids for a different object type, which the application might not expect. Additionally, enabling it can have serious performance implications because, on every ``insert``, every object type must be checked for collisions. As a result, we don't recommend enabling this. If you need to preserve UUIDs from an external source on your objects, it's best to create a new property to store these UUIDs. If you will need to filter on this external UUID property, you may add an :ref:`index ` or exclusive constraint on it.
:eql:synopsis:`session_idle_timeout -> std::duration` Sets the timeout for how long client connections can stay inactive before being forcefully closed by the server. Time spent on waiting for query results doesn't count as idling. E.g. if the session idle timeout is set to 1 minute it would be OK to run a query that takes 2 minutes to compute; to limit the query execution time use the ``query_execution_timeout`` setting. The default is 60 seconds. Setting it to ``'0'`` disables the mechanism. Setting the timeout to less than ``2`` seconds is not recommended. Note that the actual time an idle connection can live can be up to two times longer than the specified timeout. This is a system-level config setting. :eql:synopsis:`session_idle_transaction_timeout -> std::duration` Sets the timeout for how long client connections can stay inactive while in a transaction. The default is 10 seconds. Setting it to ``'0'`` disables the mechanism. .. note:: For ``session_idle_transaction_timeout`` and ``query_execution_timeout``, values under 1ms are rounded down to zero, which will disable the timeout. In order to set a timeout, please set a duration of 1ms or greater. ``session_idle_timeout`` can take values below 1ms. :eql:synopsis:`query_execution_timeout -> std::duration` Sets a time limit on how long a query can be run. Setting it to ``'0'`` disables the mechanism. The timeout isn't enabled by default. .. note:: For ``session_idle_transaction_timeout`` and ``query_execution_timeout``, values under 1ms are rounded down to zero, which will disable the timeout. In order to set a timeout, please set a duration of 1ms or greater. ``session_idle_timeout`` can take values below 1ms. .. _ref_reference_environment: .. _ref_reference_envvar_variants: Environment variables ===================== Certain behaviors of the Gel server are configured at startup. This configuration can be set with environment variables. 
The variables documented on this page are supported when using the |gel-server| binary or the official :ref:`Docker image `. Some environment variables (noted below) support ``_FILE`` and ``_ENV`` variants. - The ``_FILE`` variant expects its value to be a file name. The file's contents will be read and used as the value. - The ``_ENV`` variant expects its value to be the name of another environment variable. The value of the other environment variable is then used as the final value. This is convenient in deployment scenarios where relevant values are auto-populated into fixed environment variables. .. note:: For |Gel| versions before 6.0 the prefix for all environment variables is ``EDGEDB_`` instead of ``GEL_``. GEL_DEBUG_HTTP_INJECT_CORS -------------------------- Set to ``1`` to have Gel send appropriate CORS headers with HTTP responses. .. note:: This is set to ``1`` by default for Gel Cloud instances. .. _ref_reference_envvar_admin_ui: GEL_SERVER_ADMIN_UI ------------------- Set to ``enabled`` to enable the web-based administrative UI for the instance. Maps directly to the |gel-server| flag ``--admin-ui``. GEL_SERVER_ALLOW_INSECURE_BINARY_CLIENTS ---------------------------------------- .. warning:: Deprecated Use :gelenv:`SERVER_BINARY_ENDPOINT_SECURITY` instead. Specifies the security mode of the server's binary endpoint. When set to ``1``, non-TLS connections are allowed. Not set by default. .. warning:: Disabling TLS is not recommended in production. GEL_SERVER_ALLOW_INSECURE_HTTP_CLIENTS -------------------------------------- .. warning:: Deprecated Use :gelenv:`SERVER_HTTP_ENDPOINT_SECURITY` instead. Specifies the security mode of the server's HTTP endpoint. When set to ``1``, non-TLS connections are allowed. Not set by default. .. warning:: Disabling TLS is not recommended in production. ..
_ref_reference_docker_gel_server_backend_dsn: GEL_SERVER_BACKEND_DSN / _FILE / _ENV ------------------------------------- Specifies a PostgreSQL connection string in the `URI format`_. If set, the PostgreSQL cluster specified by the URI is used instead of the builtin PostgreSQL server. Cannot be specified alongside :gelenv:`SERVER_DATADIR`. Maps directly to the |gel-server| flag ``--backend-dsn``. The ``_FILE`` and ``_ENV`` variants are also supported. .. _URI format: https://www.postgresql.org/docs/13/libpq-connect.html#id-1.7.3.8.3.6 GEL_SERVER_MAX_BACKEND_CONNECTIONS ---------------------------------- The maximum number of connections this Gel instance can make to the backend PostgreSQL cluster. If not set, Gel detects and calculates the limit: RAM/100MiB for local Postgres, or ``pg_settings.max_connections`` for remote Postgres, minus the value of ``--reserved-pg-connections``. GEL_SERVER_BINARY_ENDPOINT_SECURITY ----------------------------------- Specifies the security mode of the server's binary endpoint. When set to ``optional``, non-TLS connections are allowed. Default is ``tls``. .. warning:: Disabling TLS is not recommended in production. GEL_SERVER_BIND_ADDRESS / _FILE / _ENV -------------------------------------- Specifies the network interface on which Gel will listen. Maps directly to the |gel-server| flag ``--bind-address``. The ``_FILE`` and ``_ENV`` variants are also supported. GEL_SERVER_BOOTSTRAP_COMMAND ---------------------------- Useful to fine-tune initial user creation and other initial setup. Maps directly to the |gel-server| flag ``--bootstrap-command``. The ``_FILE`` and ``_ENV`` variants are also supported. .. note:: A create branch statement (i.e., :eql:stmt:`create empty branch`, :eql:stmt:`create schema branch`, or :eql:stmt:`create data branch`) cannot be combined in a block with any other statements.
Since all statements in :gelenv:`SERVER_BOOTSTRAP_COMMAND` run in a single block, it cannot be used to create a branch and, for example, create a user on that branch. For Docker deployments, you can instead write :ref:`custom scripts to run before migrations `. These are placed in ``/gel-bootstrap.d/``. By writing your ``create branch`` statements in one ``.edgeql`` file each placed in ``/gel-bootstrap.d/`` and other statements in their own file, you can create branches and still run other EdgeQL statements to bootstrap your instance. Note that for |EdgeDB| versions prior to 5.0, paths contain "edgedb" instead of "gel", so ``/gel-bootstrap.d/`` becomes ``/edgedb-bootstrap.d/``. GEL_SERVER_BOOTSTRAP_ONLY ------------------------- When set, bootstrap the database cluster and exit. Not set by default. .. _ref_reference_docker_gel_server_datadir: GEL_SERVER_DATADIR ------------------ Specifies a path where the database files are located. Default is ``/var/lib/gel/data``. Cannot be specified alongside :gelenv:`SERVER_BACKEND_DSN`. Maps directly to the |gel-server| flag ``--data-dir``. GEL_SERVER_DEFAULT_AUTH_METHOD / _FILE / _ENV --------------------------------------------- Optionally specifies the authentication method used by the server instance. Supported values are ``SCRAM`` (the default) and ``Trust``. When set to ``Trust``, the database will allow complete unauthenticated access for all who have access to the database port. This is often useful when setting an admin password on an instance that lacks one. Use at your own risk and only for development and testing. The ``_FILE`` and ``_ENV`` variants are also supported. GEL_SERVER_HTTP_ENDPOINT_SECURITY --------------------------------- Specifies the security mode of the server's HTTP endpoint. When set to ``optional``, non-TLS connections are allowed. Default is ``tls``. .. warning:: Disabling TLS is not recommended in production. GEL_SERVER_INSTANCE_NAME ------------------------ Specify the server instance name. 
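
As a rough illustration of how the variables on this page combine, the following sketch starts the server against an external Postgres cluster using the ``_ENV`` variant of :gelenv:`SERVER_BACKEND_DSN`. The ``DATABASE_URL`` variable, its value, and the instance name are hypothetical placeholders chosen for this example, not names defined by Gel:

.. code-block:: bash

    # Hypothetical: the hosting platform populates DATABASE_URL for us.
    $ export DATABASE_URL="postgres://user:secret@db.example.com:5432/postgres"
    # Reference that variable by name instead of copying its value.
    $ export GEL_SERVER_BACKEND_DSN_ENV=DATABASE_URL
    # Hypothetical instance name.
    $ export GEL_SERVER_INSTANCE_NAME=my_instance
    # Generate a self-signed certificate, since none is provided here.
    $ export GEL_SERVER_TLS_CERT_MODE=generate_self_signed
    $ gel-server

Because :gelenv:`SERVER_BACKEND_DSN` cannot be combined with :gelenv:`SERVER_DATADIR`, the sketch deliberately leaves the data directory unset.
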
GEL_SERVER_JWS_KEY_FILE ----------------------- Specifies a path to a file containing a public key in PEM format used to verify JWT signatures. The file could also contain a private key to sign JWT for local testing. GEL_SERVER_LOG_LEVEL -------------------- Set the logging level. Default is ``info``. Other possible values are ``debug``, ``warn``, ``error``, and ``silent``. GEL_SERVER_PORT / _FILE / _ENV ------------------------------ Specifies the network port on which Gel will listen. Default is ``5656``. Maps directly to the |gel-server| flag ``--port``. The ``_FILE`` and ``_ENV`` variants are also supported. GEL_SERVER_RUNSTATE_DIR ----------------------- Specifies a path where Gel will place its Unix socket and other transient files. Maps directly to the |gel-server| flag ``--runstate-dir``. GEL_SERVER_SECURITY ------------------- When set to ``insecure_dev_mode``, sets :gelenv:`SERVER_DEFAULT_AUTH_METHOD` to ``Trust``, and :gelenv:`SERVER_TLS_CERT_MODE` to ``generate_self_signed`` (unless an explicit TLS certificate is specified). Finally, if this option is set, the server will accept plaintext HTTP connections. Maps directly to the |gel-server| flag ``--security``. .. warning:: Disabling TLS is not recommended in production. GEL_SERVER_TLS_CERT_FILE ------------------------ The TLS certificate file, exclusive with :gelenv:`SERVER_TLS_CERT_MODE=generate_self_signed`. Maps directly to the |gel-server| flag ``--tls-cert-file``. GEL_SERVER_TLS_KEY_FILE ----------------------- The TLS private key file, exclusive with :gelenv:`SERVER_TLS_CERT_MODE=generate_self_signed`. Maps directly to the |gel-server| flag ``--tls-key-file``. GEL_SERVER_TLS_CERT_MODE / _FILE / _ENV --------------------------------------- Specifies what to do when the TLS certificate and key are either not specified or are missing. 
- When set to ``require_file``, the TLS certificate and key must be specified in the :gelenv:`SERVER_TLS_CERT` and :gelenv:`SERVER_TLS_KEY` variables and both must exist. - When set to ``generate_self_signed`` a new self-signed certificate and private key will be generated and placed in the path specified by :gelenv:`SERVER_TLS_CERT` and :gelenv:`SERVER_TLS_KEY`, if those are set. Otherwise, the generated certificate and key are stored as ``edbtlscert.pem`` and ``edbprivkey.pem`` in :gelenv:`SERVER_DATADIR`, or, if :gelenv:`SERVER_DATADIR` is not set, they will be placed in ``/etc/ssl/gel``. Default is ``generate_self_signed`` when :gelenv:`SERVER_SECURITY=insecure_dev_mode`. Otherwise, the default is ``require_file``. Maps directly to the |gel-server| flag ``--tls-cert-mode``. The ``_FILE`` and ``_ENV`` variants are also supported. Docker image specific variables =============================== These variables are only used by the Docker image. Setting these variables outside that context will have no effect. GEL_DOCKER_ABORT_CODE --------------------- If the process fails, the arguments are logged to stderr and the script is terminated with this exit code. Default is ``1``. GEL_DOCKER_APPLY_MIGRATIONS --------------------------- The container will attempt to apply migrations in ``dbschema/migrations`` unless this variable is set to ``never``. **Values**: ``always`` (default), ``never`` GEL_DOCKER_BOOTSTRAP_TIMEOUT_SEC -------------------------------- Sets the number of seconds to wait for instance bootstrapping to complete before timing out. Default is ``300``. GEL_DOCKER_LOG_LEVEL -------------------- Change the logging level for the docker container. **Values**: ``trace``, ``debug``, ``info`` (default), ``warn``, ``error`` GEL_DOCKER_SHOW_GENERATED_CERT ------------------------------ Shows the generated TLS certificate in console output. **Values**: ``always`` (default), ``never`` GEL_SERVER_BINARY ----------------- Sets the Gel server binary to run. 
Default is |gel-server|. GEL_SERVER_BOOTSTRAP_COMMAND_FILE --------------------------------- Run the script when initializing the database. The script is run by the default user within the default |branch|. May be used with or without :gelenv:`SERVER_BOOTSTRAP_ONLY`. GEL_SERVER_COMPILER_POOL_MODE ----------------------------- Selects how the compiler pool scales. ``fixed`` means the pool will not scale and sticks to :gelenv:`SERVER_COMPILER_POOL_SIZE`, while ``on_demand`` means the pool will maintain at least 1 worker and automatically scale up (to :gelenv:`SERVER_COMPILER_POOL_SIZE` workers) and down with demand. **Values**: ``fixed``, ``on_demand`` Default is ``fixed`` in production mode and ``on_demand`` in development mode. GEL_SERVER_COMPILER_POOL_SIZE ----------------------------- When :gelenv:`SERVER_COMPILER_POOL_MODE` is ``fixed``, this setting is the exact size of the compiler pool. When :gelenv:`SERVER_COMPILER_POOL_MODE` is ``on_demand``, this will serve as the maximum size of the compiler pool. GEL_SERVER_EMIT_SERVER_STATUS ----------------------------- Instruct the server to emit changes in status to *DEST*, where *DEST* is a URI specifying a file (``file://``), or a file descriptor (``fd://``). If the URI scheme is not specified, ``file://`` is assumed. GEL_SERVER_EXTRA_ARGS --------------------- Additional arguments to pass when starting the Gel server. GEL_SERVER_PASSWORD / _FILE / _ENV ---------------------------------- The password for the default superuser account (or the user specified in :gelenv:`SERVER_USER`) will be set to this value. If no value is provided, a password will not be set, unless set via :gelenv:`SERVER_BOOTSTRAP_COMMAND`. (If a value for :gelenv:`SERVER_BOOTSTRAP_COMMAND` is provided, this variable will be ignored.) The ``_FILE`` and ``_ENV`` variants are also supported.
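
A minimal sketch of the ``_FILE`` variant with the Docker image, assuming the password is kept in a local file that is mounted read-only into the container. The file name, the mount path ``/run/secrets/gel_password``, and the password itself are hypothetical; only the ``geldata/gel`` image name and the environment variables come from this page:

.. code-block:: bash

    # Hypothetical secret file; keep it out of version control.
    $ printf '%s' 'a-strong-password' > ./gel_password.txt
    # Mount the file and point GEL_SERVER_PASSWORD_FILE at its path
    # inside the container.
    $ docker run \
        -p 5656:5656 \
        -v "$(pwd)/gel_password.txt:/run/secrets/gel_password:ro" \
        -e GEL_SERVER_PASSWORD_FILE=/run/secrets/gel_password \
        -e GEL_SERVER_TLS_CERT_MODE=generate_self_signed \
        geldata/gel

Passing the path rather than the value keeps the password out of the container's environment listing. Remember that this variable is ignored when :gelenv:`SERVER_BOOTSTRAP_COMMAND` is also provided.
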
GEL_SERVER_PASSWORD_HASH / _FILE / _ENV --------------------------------------- A variant of :gelenv:`SERVER_PASSWORD`, where the specified value is a hashed password verifier instead of plain text. If :gelenv:`SERVER_BOOTSTRAP_COMMAND` is set, this variable will be ignored. The ``_FILE`` and ``_ENV`` variants are also supported. GEL_SERVER_TENANT_ID -------------------- Specifies the tenant ID of this server. When using multiple Gel instances with one Postgres cluster, each Gel instance must have a unique tenant ID. Must be an alphanumeric ASCII string, maximum 10 characters long. Defaults to "E" if not set. GEL_SERVER_UID -------------- Specifies the ID of the user which should run the server binary. Default is ``1``. GEL_SERVER_USER --------------- If set to anything other than the default username |admin|, the username specified will be created. The user defined here will be the one assigned the password set in :gelenv:`SERVER_PASSWORD` or the hash set in :gelenv:`SERVER_PASSWORD_HASH`. ================================================ FILE: docs/reference/running/deployment/aws_aurora_ecs.rst ================================================ .. _ref_guide_deployment_aws_aurora_ecs: === AWS === :edb-alt-title: Deploying Gel to AWS .. note:: We recommend using our `helm chart `_ to deploy Gel on AWS EKS. The CloudFormation guide below does not configure TLS certificates correctly. .. _helm-chart: https://github.com/geldata/helm-charts/blob/main /charts/gel-server/README.md In this guide we show how to deploy Gel on AWS using Amazon Aurora and Elastic Container Service. .. include:: ./note_cloud_reset_password.rst Prerequisites ============= * AWS account with billing enabled (or a `free trial `_) * (optional) ``aws`` CLI (`install `_) .. _aws-trial: https://aws.amazon.com/free ..
_awscli-install: https://docs.aws.amazon.com /cli/latest/userguide/getting-started-install.html Quick Install with CloudFormation ================================= We maintain a `CloudFormation template `_ for easy automated deployment of Gel in your AWS account. The template deploys Gel to a new ECS service and connects it to a newly provisioned Aurora PostgreSQL cluster. The created instance has a public IP address with TLS configured and is protected by a password you provide. CloudFormation Web Portal ------------------------- Click `here `_ to start the deployment process using CloudFormation portal and follow the prompts. You'll be prompted to provide a value for the following parameters: - ``DockerImage``: defaults to the latest version (``geldata/gel``), or you can specify a particular tag from the ones published to `Docker Hub `_. - ``InstanceName``: ⚠️ Due to limitations with AWS, this must be 22 characters or less! - ``SuperUserPassword``: this will be used as the password for the new Gel instance. Keep track of the value you provide. Once the deployment is complete, follow these steps to find the host name that has been assigned to your Gel instance: .. lint-off 1. Open the AWS Console and navigate to CloudFormation > Stacks. Click on the newly created Stack. 2. Wait for the status to read ``CREATE_COMPLETE``—it can take 15 minutes or more. 3. Once deployment is complete, click the ``Outputs`` tab. The value of ``PublicHostname`` is the hostname at which your Gel instance is publicly available. 4. Copy the hostname and run the following command to open a REPL to your instance. .. code-block:: bash $ gel --dsn gel://admin:@ --tls-security insecure Gel x.x Type \help for help, \quit to quit. gel> .. lint-on To make changes to your Gel deployment like upgrading the Gel version or enabling the UI you can follow the CloudFormation `Updating a stack `_ instructions. 
Search for ``ContainerDefinitions`` in the template and you will find where Gel's :ref:`environment variables ` are defined. To upgrade the Gel version specify a `docker image tag `_ with the image name ``geldata/gel`` in the second step of the update workflow. CloudFormation CLI ------------------ Alternatively, if you prefer to use AWS CLI, run the following command in your terminal: .. code-block:: bash $ aws cloudformation create-stack \ --stack-name Gel \ --template-url \ https://gel-deployment.s3.us-east-2.amazonaws.com/gel-aurora.yml \ --capabilities CAPABILITY_NAMED_IAM \ --parameters ParameterKey=SuperUserPassword,ParameterValue= .. _cf-template: https://github.com/geldata/gel-deploy/tree/dev/aws-cf .. _cf-deploy: https://console.aws.amazon.com /cloudformation/home#/stacks/new?stackName=Gel&templateURL= https%3A%2F%2Fgel-deployment.s3.us-east-2.amazonaws.com%2Fgel-aurora.yml .. _aws_console: https://console.aws.amazon.com /ec2/v2/home#NIC:search=ec2-security-group .. _stack-update: https://docs.aws.amazon.com /AWSCloudFormation/latest/UserGuide/cfn-whatis-howdoesitwork.html .. _docker-tags: https://hub.docker.com/r/geldata/gel/tags Connecting your application =========================== To connect your application to the Gel instance, you'll need to provide connection parameters. Gel client libraries can be configured using either a DSN (connection string) or individual environment variables. Obtaining connection parameters ------------------------------- Your connection requires the following components: - **Host**: The ``PublicHostname`` value from the CloudFormation Stack's ``Outputs`` tab. - **Port**: ``5656`` (the default Gel port) - **Username**: |admin| (the default superuser) - **Password**: The ``SuperUserPassword`` you specified during deployment - **Branch**: |main| (the default branch) Construct the DSN using these values: .. code-block:: bash $ GEL_DSN="gel://admin:@:5656" Obtaining the TLS certificate ----------------------------- .. 
warning:: The CloudFormation template does not configure TLS certificates correctly. We recommend using ``--tls-security insecure`` for testing, but for production you should use our `helm chart `_ or configure TLS manually. To connect securely, your application needs the server's TLS certificate. For self-signed certificates, you can retrieve the certificate by connecting to the instance and extracting it: .. code-block:: bash $ gel --dsn $GEL_DSN --tls-security insecure \ query "SELECT sys::get_tls_certificate()" Store this certificate and provide it to your application via the :gelenv:`TLS_CA` or :gelenv:`TLS_CA_FILE` environment variable. Using in your application ------------------------- Set these environment variables where you deploy your application: .. code-block:: bash GEL_DSN="gel://admin:@:5656" # For self-signed certificates: GEL_CLIENT_TLS_SECURITY=insecure # Or with a proper TLS certificate: GEL_TLS_CA="" Gel's client libraries will automatically read these environment variables. Local development with the CLI ------------------------------ To make your remote instance easier to work with during local development, create an alias using :gelcmd:`instance link`. .. note:: The command groups :gelcmd:`instance` and :gelcmd:`project` are not intended to manage production instances. .. code-block:: bash $ gel instance link \ --dsn $GEL_DSN \ --non-interactive \ --trust-tls-cert \ my_aws_instance You can now refer to the remote instance using the alias ``my_aws_instance``. Use this alias wherever an instance name is expected: .. code-block:: bash $ gel -I my_aws_instance Gel x.x Type \help for help, \quit to quit. gel> Or apply migrations: .. code-block:: bash $ gel -I my_aws_instance migrate Health Checks ============= Using an HTTP client, you can perform health checks to monitor the status of your Gel instance. Learn how to use them with our :ref:`health checks guide `. 
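As a rough illustration of such a check, the sketch below builds a status URL and probes it with ``curl``. It assumes the ``/server/status/alive`` and ``/server/status/ready`` endpoints described in the health checks guide, and that the instance answers over HTTPS on port ``5656``. The function names ``gel_health_url`` and ``gel_health_check`` are made up for this example and are not part of the Gel CLI.

```bash
# Hypothetical helper: build the health check URL for a host, port and probe.
# Port defaults to 5656 and probe defaults to "alive" (assumed endpoint names).
gel_health_url() {
  local host="$1" port="${2:-5656}" probe="${3:-alive}"
  printf 'https://%s:%s/server/status/%s\n' "$host" "$port" "$probe"
}

# Hypothetical helper: succeed only if the endpoint answers with a 2xx status.
# -k skips TLS verification (testing only), -f makes HTTP errors fail,
# -s silences progress output, --max-time bounds the wait to 5 seconds.
gel_health_check() {
  curl -ksf --max-time 5 -o /dev/null "$(gel_health_url "$@")"
}

# Usage (replace the host with the PublicHostname stack output):
#   gel_health_check my-gel-host.example.com 5656 ready && echo "ready"
```

The ``-k`` flag is used here only because the CloudFormation deployment does not present a trusted certificate; drop it once the instance serves a certificate your client trusts.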
================================================ FILE: docs/reference/running/deployment/azure_flexibleserver.rst ================================================ .. _ref_guide_deployment_azure_flexibleserver: ===== Azure ===== :edb-alt-title: Deploying Gel to Azure In this guide we show how to deploy Gel using Azure's `Postgres Flexible Server `_ as the backend. .. include:: ./note_cloud_reset_password.rst Prerequisites ============= * Valid Azure Subscription with billing enabled or credits (`free trial <azure-trial_>`_). * Azure CLI (`install <azure-install_>`_). .. _azure-trial: https://azure.microsoft.com/en-us/free/ .. _azure-install: https://docs.microsoft.com/en-us/cli/azure/install-azure-cli Provision a Gel instance ========================= Log in to your Microsoft Azure account. .. code-block:: bash $ az login Create a new resource group. .. code-block:: bash $ GROUP=my-group-name $ az group create --name $GROUP --location westus Provision a PostgreSQL server. .. note:: If you already have a database provisioned you can skip this step. For convenience, assign a value to the ``PG_SERVER_NAME`` environment variable; we'll use this variable in multiple later commands. .. code-block:: bash $ PG_SERVER_NAME=postgres-for-gel Use the ``read`` command to securely assign a value to the ``PASSWORD`` environment variable. .. code-block:: bash $ echo -n "> " && read -s PASSWORD Then create a Postgres Flexible Server. .. code-block:: bash $ az postgres flexible-server create \ --resource-group $GROUP \ --name $PG_SERVER_NAME \ --location westus \ --admin-user gel_admin \ --admin-password $PASSWORD \ --sku-name Standard_D2s_v3 \ --version 14 \ --yes .. note:: If you get an error saying ``"Specified server name is already used."``, change the value of ``PG_SERVER_NAME`` and rerun the command. Allow other Azure services access to the Postgres instance. ..
code-block:: bash $ az postgres flexible-server firewall-rule create \ --resource-group $GROUP \ --name $PG_SERVER_NAME \ --rule-name allow-azure-internal \ --start-ip-address 0.0.0.0 \ --end-ip-address 0.0.0.0 |Gel| requires Postgres' ``uuid-ossp`` extension, which must be enabled explicitly. .. code-block:: bash $ az postgres flexible-server parameter set \ --resource-group $GROUP \ --server-name $PG_SERVER_NAME \ --name azure.extensions \ --value uuid-ossp Azure is not able to reliably pull Docker images `because of rate limits <azure-cli-issue_>`_, so you will need to provide Docker Hub login credentials to create a container. If you don't already have a Docker Hub account you can create one `here `_. .. _azure-cli-issue: https://github.com/Azure/azure-cli/issues/29300 .. code-block:: bash $ echo -n "docker user> " && read -s DOCKER_USER $ echo -n "docker password> " && read -s DOCKER_PASSWORD Start a Gel container. .. code-block:: bash $ PG_HOST=$( az postgres flexible-server list \ --resource-group $GROUP \ --query "[?name=='$PG_SERVER_NAME'].fullyQualifiedDomainName | [0]" \ --output tsv ) $ DSN="postgresql://gel_admin:$PASSWORD@$PG_HOST/postgres?sslmode=require" $ az container create \ --registry-username $DOCKER_USER \ --registry-password $DOCKER_PASSWORD \ --registry-login-server index.docker.io \ --os-type Linux \ --cpu 1 \ --memory 1 \ --resource-group $GROUP \ --name gel-container-group \ --image geldata/gel \ --dns-name-label geldb \ --ports 5656 \ --secure-environment-variables \ "GEL_SERVER_PASSWORD=$PASSWORD" \ "GEL_SERVER_BACKEND_DSN=$DSN" \ --environment-variables \ GEL_SERVER_TLS_CERT_MODE=generate_self_signed Persist the SSL certificate. We have configured Gel to generate a self-signed SSL certificate when it starts. However, if the container is restarted a new certificate would be generated.
To preserve the certificate across failures or reboots copy the certificate files and use their contents in the :gelenv:`SERVER_TLS_KEY` and :gelenv:`SERVER_TLS_CERT` environment variables. .. code-block:: bash $ key="$( az container exec \ --resource-group $GROUP \ --name gel-container-group \ --exec-command "cat /tmp/gel/edbprivkey.pem" \ | tr -d "\r" )" $ cert="$( az container exec \ --resource-group $GROUP \ --name gel-container-group \ --exec-command "cat /tmp/gel/edbtlscert.pem" \ | tr -d "\r" )" $ az container delete \ --resource-group $GROUP \ --name gel-container-group \ --yes $ az container create \ --registry-username $DOCKER_USER \ --registry-password $DOCKER_PASSWORD \ --registry-login-server index.docker.io \ --os-type Linux \ --cpu 1 \ --memory 1 \ --resource-group $GROUP \ --name gel-container-group \ --image geldata/gel \ --dns-name-label geldb \ --ports 5656 \ --secure-environment-variables \ "GEL_SERVER_PASSWORD=$PASSWORD" \ "GEL_SERVER_BACKEND_DSN=$DSN" \ "GEL_SERVER_TLS_KEY=$key" \ --environment-variables \ "GEL_SERVER_TLS_CERT=$cert" Connecting your application =========================== To connect your application to the Gel instance, you'll need to provide connection parameters. Gel client libraries can be configured using either a DSN (connection string) or individual environment variables. Obtaining connection parameters ------------------------------- Your connection requires the following components: - **Host**: The FQDN of your Azure container instance. Retrieve it with: .. code-block:: bash $ az container list \ --resource-group $GROUP \ --query "[?name=='gel-container-group'].ipAddress.fqdn | [0]" \ --output tsv - **Port**: ``5656`` (the default Gel port) - **Username**: |admin| (the default superuser) - **Password**: The password you set in the ``$PASSWORD`` variable - **Branch**: |main| (the default branch) Construct the DSN using these values: .. 
code-block:: bash $ GEL_HOST=$(az container list \ --resource-group $GROUP \ --query "[?name=='gel-container-group'].ipAddress.fqdn | [0]" \ --output tsv) $ GEL_DSN="gel://admin:$PASSWORD@$GEL_HOST:5656" Obtaining the TLS certificate ----------------------------- Since we configured Gel with a self-signed TLS certificate, your application needs the certificate to connect securely. Retrieve it from the container: .. code-block:: bash $ az container exec \ --resource-group $GROUP \ --name gel-container-group \ --exec-command "cat /tmp/gel/edbtlscert.pem" \ | tr -d "\r" > gel-tls-cert.pem Alternatively, you can retrieve it using the Gel CLI: .. code-block:: bash $ gel --dsn $GEL_DSN --tls-security insecure \ query "SELECT sys::get_tls_certificate()" > gel-tls-cert.pem Using in your application ------------------------- Set these environment variables where you deploy your application: .. code-block:: bash GEL_DSN="gel://admin:<password>@<hostname>:5656" # For self-signed certificates, either trust the cert: GEL_TLS_CA_FILE="/path/to/gel-tls-cert.pem" # Or (for development only) disable TLS verification: GEL_CLIENT_TLS_SECURITY=insecure Gel's client libraries will automatically read these environment variables. Local development with the CLI ------------------------------ To make your remote instance easier to work with during local development, create an alias using :gelcmd:`instance link`. .. note:: The command groups :gelcmd:`instance` and :gelcmd:`project` are not intended to manage production instances. .. code-block:: bash $ printf '%s' "$PASSWORD" | gel instance link \ --dsn $GEL_DSN \ --password-from-stdin \ --non-interactive \ --trust-tls-cert \ my_azure_instance You can now refer to the remote instance using the alias ``my_azure_instance``. Use this alias wherever an instance name is expected: .. code-block:: bash $ gel -I my_azure_instance Gel x.x Type \help for help, \quit to quit. gel> Or apply migrations: ..
code-block:: bash $ gel -I my_azure_instance migrate Health Checks ============= Using an HTTP client, you can perform health checks to monitor the status of your Gel instance. Learn how to use them with our :ref:`health checks guide `. ================================================ FILE: docs/reference/running/deployment/bare_metal.rst ================================================ .. _ref_guide_deployment_bare_metal: ========== Bare Metal ========== :edb-alt-title: Deploying Gel to a Bare Metal Server In this guide we show how to deploy Gel to bare metal using your system's package manager and systemd. .. include:: ./note_cloud_reset_password.rst Install the Gel Package ======================= The steps for installing the Gel package will be slightly different depending on your Linux distribution. Once you have the package installed you can jump to :ref:`ref_guide_deployment_bare_metal_enable_unit`. Debian/Ubuntu LTS ----------------- Import the Gel packaging key. .. code-block:: bash $ sudo mkdir -p /usr/local/share/keyrings && \ sudo curl --proto '=https' --tlsv1.2 -sSf \ -o /usr/local/share/keyrings/gel-keyring.gpg \ https://packages.geldata.com/keys/gel-keyring.gpg Add the Gel package repository. .. code-block:: bash $ echo deb '[signed-by=/usr/local/share/keyrings/gel-keyring.gpg]' \ https://packages.geldata.com/apt \ $(grep "VERSION_CODENAME=" /etc/os-release | cut -d= -f2) main \ | sudo tee /etc/apt/sources.list.d/gel.list .. note:: For non-LTS releases of Debian/Ubuntu (e.g. Ubuntu Oracular), you can install the package built for the latest LTS release, because these packages are usually forward compatible. To do this, replace the ``$(grep ...)`` expression with the codename of the latest LTS release (e.g. ``noble``). Install the Gel package. .. code-block:: bash $ sudo apt-get update && sudo apt-get install gel-6 CentOS/RHEL 7/8 --------------- Add the Gel package repository. ..
code-block:: bash $ curl --proto '=https' --tlsv1.2 -sSfL \ https://packages.geldata.com/rpm/gel-rhel.repo \ | sudo tee /etc/yum.repos.d/gel.repo Install the Gel package. .. code-block:: bash $ sudo yum install gel-6 Disable SELinux. .. code-block:: bash $ sudo sed -i 's/SELINUX=enforcing/SELINUX=disabled/' /etc/selinux/config $ sudo reboot .. _ref_guide_deployment_bare_metal_enable_unit: Enable a systemd unit ===================== The Gel package comes bundled with a systemd unit that is disabled by default. You can start the server by enabling the unit. .. code-block:: bash $ sudo systemctl enable --now gel-server-6 This will start the server on port 5656, and the data directory will be ``/var/lib/gel/6/data``. .. warning:: |gel-server| cannot be run as root. Set environment variables ========================= To set environment variables when running Gel with ``systemctl``, edit the unit: .. code-block:: bash $ sudo systemctl edit --full gel-server-6 This opens a ``systemd`` unit file. Set the desired environment variables under the ``[Service]`` section. View the supported environment variables at :ref:`Reference > Environment Variables `. .. code-block:: toml [Service] Environment="GEL_SERVER_TLS_CERT_MODE=generate_self_signed" Environment="GEL_SERVER_ADMIN_UI=enabled" Save the file and exit, then restart the service. .. code-block:: bash $ sudo systemctl restart gel-server-6 Set a password ============== There is no default password. To set one, you will first need to get the Unix socket directory. You can find this by looking at your systemd unit file. .. code-block:: bash $ sudo systemctl cat gel-server-6 Set a password by connecting from localhost. ..
code-block:: bash $ echo -n "> " && read -s PASSWORD $ RUNSTATE_DIR=$(systemctl show gel-server-6 -P ExecStart | \ grep -o -m 1 -- "--runstate-dir=[^ ]\+" | \ awk -F "=" '{print $2}') $ sudo gel --port 5656 --tls-security insecure --admin \ --unix-path $RUNSTATE_DIR \ query "ALTER ROLE admin SET password := '$PASSWORD'" The server listens on localhost by default. To accept connections on other addresses, change ``listen_addresses``: .. code-block:: bash $ gel --port 5656 --tls-security insecure --password query \ "CONFIGURE INSTANCE SET listen_addresses := {'0.0.0.0'};" The listen port can be changed from the default ``5656`` if your deployment scenario requires a different value. .. code-block:: bash $ gel --port 5656 --tls-security insecure --password query \ "CONFIGURE INSTANCE SET listen_port := 1234;" You may need to restart the server after changing the listen port or addresses. .. code-block:: bash $ sudo systemctl restart gel-server-6 Connecting your application =========================== To connect your application to the Gel instance, you'll need to provide connection parameters. Gel client libraries can be configured using either a DSN (connection string) or individual environment variables. Obtaining connection parameters ------------------------------- Your connection requires the following components: - **Host**: The IP address or hostname of your server (e.g., ``localhost``, ``192.168.1.100``, or ``gel.example.com``) - **Port**: ``5656`` by default, or the custom port if you changed it with ``CONFIGURE INSTANCE SET listen_port`` - **Username**: |admin| (the default superuser) - **Password**: The password you set with ``ALTER ROLE admin SET password`` - **Branch**: |main| (the default branch) Construct the DSN using these values: .. 
code-block:: bash $ GEL_DSN="gel://admin:<password>@<hostname>:5656" Obtaining the TLS certificate ----------------------------- If you configured Gel with ``GEL_SERVER_TLS_CERT_MODE=generate_self_signed``, your application needs the certificate to connect securely.
The generated certificate is stored in the data directory. You can find it at: .. code-block:: bash $ cat /var/lib/gel/6/data/edbtlscert.pem Alternatively, retrieve it using the Gel CLI: .. code-block:: bash $ gel --dsn $GEL_DSN --tls-security insecure \ query "SELECT sys::get_tls_certificate()" Using in your application ------------------------- Set these environment variables where you deploy your application: .. code-block:: bash GEL_DSN="gel://admin:@:5656" # For self-signed certificates, provide the CA cert: GEL_TLS_CA_FILE="/path/to/edbtlscert.pem" # Or embed the certificate content directly: GEL_TLS_CA="" Gel's client libraries will automatically read these environment variables. Local development with the CLI ------------------------------ To make your instance easier to work with during local development, create an alias using :gelcmd:`instance link`. .. note:: The command groups :gelcmd:`instance` and :gelcmd:`project` are not intended to manage production instances. .. code-block:: bash $ gel instance link \ --dsn $GEL_DSN \ --non-interactive \ --trust-tls-cert \ my_bare_metal_instance You can now refer to the instance using the alias ``my_bare_metal_instance``. Use this alias wherever an instance name is expected: .. code-block:: bash $ gel -I my_bare_metal_instance Gel x.x Type \help for help, \quit to quit. gel> Or apply migrations: .. code-block:: bash $ gel -I my_bare_metal_instance migrate Upgrading Gel ============= When you want to upgrade to the newest point release upgrade the package and restart the ``gel-server-6`` unit. Debian/Ubuntu LTS ----------------- .. code-block:: bash $ sudo apt-get update && sudo apt-get install --only-upgrade gel-6 $ sudo systemctl restart gel-server-6 CentOS/RHEL 7/8 --------------- .. code-block:: bash $ sudo yum update gel-6 $ sudo systemctl restart gel-server-6 Health Checks ============= Using an HTTP client, you can perform health checks to monitor the status of your Gel instance. 
Learn how to use them with our :ref:`health checks guide `. ================================================ FILE: docs/reference/running/deployment/digitalocean.rst ================================================ .. _ref_guide_deployment_digitalocean: ============ DigitalOcean ============ :edb-alt-title: Deploying Gel to DigitalOcean Create a droplet and use the :ref:`ref_guide_deployment_bare_metal` guide to install the Gel server. ================================================ FILE: docs/reference/running/deployment/docker.rst ================================================ .. _ref_guide_deployment_docker: ====== Docker ====== :edb-alt-title: Deploying Gel with Docker .. include:: ./note_cloud_reset_password.rst When to use the "geldata/gel" Docker image ========================================== .. _geldata/gel: https://hub.docker.com/r/geldata/gel This image is primarily intended to be used directly when there is a requirement to use Docker containers, such as in production, or in a development setup that involves multiple containers orchestrated by Docker Compose or a similar tool. Otherwise, using the :ref:`ref_cli_gel_server` CLI on the host system is the recommended way to install and run Gel servers. How to use this image ===================== The simplest way to run the image (without data persistence) is this: .. code-block:: bash $ docker run --name gel -d \ -e GEL_SERVER_SECURITY=insecure_dev_mode \ geldata/gel See the :ref:`ref_guides_deployment_docker_customization` section below for the meaning of the :gelenv:`SERVER_SECURITY` variable and other options. Then, to authenticate to the Gel instance and store the credentials in a Docker volume, run: .. code-block:: bash $ docker run -it --rm --link=gel \ -e GEL_SERVER_PASSWORD=secret \ -v gel-cli-config:/.config/edgedb geldata/gel-cli \ -H gel instance link my_instance \ --tls-security insecure \ --non-interactive Now, to open an interactive shell to the database instance, run this: ..
code-block:: bash $ docker run -it --rm --link=gel \ -v gel-cli-config:/.config/edgedb geldata/gel-cli \ -I my_instance Data Persistence ================ If you want the contents of the database to survive container restarts, you must mount a persistent volume at the path specified by :gelenv:`SERVER_DATADIR` (``/var/lib/gel/data`` by default). For example: .. code-block:: bash $ docker run \ --name gel \ -e GEL_SERVER_PASSWORD=secret \ -e GEL_SERVER_TLS_CERT_MODE=generate_self_signed \ -v /my/data/directory:/var/lib/gel/data \ -d geldata/gel Note that on Windows you must use a Docker volume instead: .. code-block:: bash $ docker volume create --name=gel-data $ docker run \ --name gel \ -e GEL_SERVER_PASSWORD=secret \ -e GEL_SERVER_TLS_CERT_MODE=generate_self_signed \ -v gel-data:/var/lib/gel/data \ -d geldata/gel It is also possible to run a ``gel`` container on a remote PostgreSQL cluster specified by :gelenv:`SERVER_BACKEND_DSN`. See below for details. Schema Migrations ================= A derived image may include application schema and migrations in ``/dbschema``, in which case the container will attempt to apply the schema migrations found in ``/dbschema/migrations``, unless the :gelenv:`DOCKER_APPLY_MIGRATIONS` environment variable is set to ``never``. Docker Compose ============== A simple ``docker-compose`` configuration might look like this. With a ``docker-compose.yaml`` containing: .. code-block:: yaml services: gel: image: geldata/gel environment: GEL_SERVER_SECURITY: insecure_dev_mode volumes: - "./dbschema:/dbschema" ports: - "5656:5656" Once there is a :ref:`schema ` in ``dbschema/`` a migration can be created with: .. code-block:: bash $ gel --tls-security=insecure -P 5656 migration create Alternatively, if you don't have the Gel CLI installed on your host machine, you can use the CLI bundled with the server container: .. code-block:: bash $ docker compose exec gel \ gel --tls-security=insecure -P 5656 migration create .. 
_ref_guides_deployment_docker_customization: Configuration ============= The Docker image supports the same set of environment variables as the Gel server process, which are documented under :ref:`Reference > Environment Variables `. |Gel| containers can be additionally configured using initialization scripts and some Docker-specific environment variables, documented below. .. note:: Some variables support ``_ENV`` and ``_FILE`` :ref:`variants ` to support more advanced configurations. .. _ref_guides_deployment_docker_initial_setup: Initial configuration --------------------- When a Gel container starts on the specified data directory or remote Postgres cluster for the first time, initial instance setup is performed. This is called the *bootstrap phase*. The following environment variables affect the bootstrap only and have no effect on subsequent container runs. .. note:: For |EdgeDB| versions before 6.0 (Gel) the prefix for all environment variables is ``EDGEDB_`` instead of ``GEL_``. GEL_SERVER_BOOTSTRAP_COMMAND ............................ Useful to fine-tune initial user and branch creation, and other initial setup. If neither the :gelenv:`SERVER_BOOTSTRAP_COMMAND` variable nor the :gelenv:`SERVER_BOOTSTRAP_SCRIPT_FILE` variable is explicitly specified, the container will look for ``/gel-bootstrap.edgeql`` in the container (which can be placed in a derived image). Maps directly to the |gel-server| flag ``--bootstrap-command``. The ``*_FILE`` and ``*_ENV`` variants are also supported. GEL_SERVER_BOOTSTRAP_SCRIPT_FILE ................................ Deprecated in image version 2.8: use :gelenv:`SERVER_BOOTSTRAP_COMMAND_FILE` instead. Run the script when initializing the database. The script is run by the default user within the default branch. GEL_SERVER_PASSWORD ................... The password for the default superuser account will be set to this value. If no value is provided a password will not be set, unless set via :gelenv:`SERVER_BOOTSTRAP_COMMAND`.
(If a value for :gelenv:`SERVER_BOOTSTRAP_COMMAND` is provided, this variable will be ignored.) The ``*_FILE`` and ``*_ENV`` variants are also supported. GEL_SERVER_PASSWORD_HASH ........................ A variant of :gelenv:`SERVER_PASSWORD`, where the specified value is a hashed password verifier instead of plain text. If :gelenv:`SERVER_BOOTSTRAP_COMMAND` is set, this variable will be ignored. The ``*_FILE`` and ``*_ENV`` variants are also supported. GEL_SERVER_GENERATE_SELF_SIGNED_CERT .................................... .. warning:: Deprecated: use :gelenv:`SERVER_TLS_CERT_MODE=generate_self_signed` instead. Set this option to ``1`` to tell the server to automatically generate a self-signed certificate with key file in the :gelenv:`SERVER_DATADIR` (if present, see below), and echo the certificate content in the logs. If the certificate file exists, the server will use it instead of generating a new one. Self-signed certificates are usually used in development and testing; you should likely provide your own certificate and key file with the variables below. GEL_SERVER_TLS_CERT/GEL_SERVER_TLS_KEY ...................................... The TLS certificate and private key data, exclusive with :gelenv:`SERVER_TLS_CERT_MODE=generate_self_signed`. The ``*_FILE`` and ``*_ENV`` variants are also supported. Custom scripts in "/docker-entrypoint.d/" ......................................... To perform additional initialization, a derived image may include one or more executable files in ``/docker-entrypoint.d/``, which will get executed by the container entrypoint *before* any other processing takes place. Runtime configuration --------------------- GEL_DOCKER_LOG_LEVEL .................... Determines the log verbosity level in the entrypoint script. Valid levels are ``trace``, ``debug``, ``info``, ``warning``, and ``error``. The default is ``info``. ..
_ref_guide_deployment_docker_custom_bootstrap_scripts: Custom scripts in "/gel-bootstrap.d/" and "/gel-bootstrap-late.d" ................................................................. To perform additional initialization, a derived image may include one or more ``*.edgeql`` or ``*.sh`` scripts, which are executed in addition to and *after* the initialization specified by the environment variables above or the ``/gel-bootstrap.edgeql`` script. Parts in ``/gel-bootstrap.d`` are executed *before* any schema migrations are applied, and parts in ``/gel-bootstrap-late.d`` are executed *after* the schema migrations have been applied. .. note:: Best practice for naming your script files when you will have multiple script files to run on bootstrap is to prepend the filenames with ``01-``, ``02-``, and so on to indicate your desired order of execution. Connecting your application =========================== To connect your application to the Gel instance, you'll need to provide connection parameters. Gel client libraries can be configured using either a DSN (connection string) or individual environment variables. Obtaining connection parameters ------------------------------- Your connection requires the following components: - **Host**: The container hostname or IP address. In Docker Compose, this is the service name (e.g., ``gel``). For standalone containers, use ``localhost`` if on the same host, or the container's IP/hostname. - **Port**: ``5656`` (the default Gel port, unless remapped with ``-p``) - **Username**: |admin| (the default superuser) - **Password**: The value of :gelenv:`SERVER_PASSWORD` you set when starting the container - **Branch**: |main| (the default branch) Construct the DSN using these values: .. code-block:: bash $ GEL_DSN="gel://admin:<password>@<hostname>:5656" For a Docker Compose setup with the service named ``gel``: ..
code-block:: bash $ GEL_DSN="gel://admin:secret@gel:5656" Obtaining the TLS certificate ----------------------------- If you configured Gel with ``GEL_SERVER_TLS_CERT_MODE=generate_self_signed``, your application needs the certificate to connect securely. Retrieve the certificate from the running container: .. code-block:: bash $ docker exec <container> cat /var/lib/gel/data/edbtlscert.pem Or using the Gel utility script: .. code-block:: bash $ docker exec <container> \ gel-show-secrets.sh --format=raw GEL_SERVER_TLS_CERT Alternatively, retrieve it using the Gel CLI: .. code-block:: bash $ gel --dsn $GEL_DSN --tls-security insecure \ query "SELECT sys::get_tls_certificate()" If you mounted a persistent volume at :gelenv:`SERVER_DATADIR`, the certificate is also available as ``edbtlscert.pem`` in the mounted directory. Using in your application ------------------------- Set these environment variables in your application container: .. code-block:: yaml # docker-compose.yaml example services: app: image: your-app environment: GEL_DSN: "gel://admin:secret@gel:5656" # For self-signed certificates: GEL_CLIENT_TLS_SECURITY: "insecure" # Or provide the CA certificate: # GEL_TLS_CA: "" For production, we recommend providing the TLS certificate rather than disabling TLS verification: .. code-block:: yaml services: app: image: your-app environment: GEL_DSN: "gel://admin:${GEL_PASSWORD}@gel:5656" GEL_TLS_CA_FILE: "/certs/gel-ca.pem" volumes: - ./certs:/certs:ro Gel's client libraries will automatically read these environment variables. Local development with the CLI ------------------------------ To make your Gel container easier to work with during local development, create an alias using :gelcmd:`instance link`. .. note:: The command groups :gelcmd:`instance` and :gelcmd:`project` are not intended to manage production instances. From your host machine, link to the container: ..
code-block:: bash $ gel instance link \ --dsn gel://admin:secret@localhost:5656 \ --non-interactive \ --trust-tls-cert \ my_docker_instance You can now refer to the instance using the alias ``my_docker_instance``. Use this alias wherever an instance name is expected: .. code-block:: bash $ gel -I my_docker_instance Gel x.x Type \help for help, \quit to quit. gel> Or apply migrations: .. code-block:: bash $ gel -I my_docker_instance migrate Health Checks ============= Using an HTTP client, you can perform health checks to monitor the status of your Gel instance. Learn how to use them with our :ref:`health checks guide `. ================================================ FILE: docs/reference/running/deployment/fly_io.rst ================================================ .. _ref_guide_deployment_fly_io: ====== Fly.io ====== :edb-alt-title: Deploying Gel to Fly.io In this guide we show how to deploy Gel using a `Fly.io `_ PostgreSQL cluster as the backend. The deployment consists of two apps: one running Postgres and the other running Gel. .. include:: ./note_cloud_reset_password.rst Prerequisites ============= * Fly.io account * ``flyctl`` CLI (`install `_) .. _flyctl-install: https://fly.io/docs/getting-started/installing-flyctl/ Provision a Fly.io app for Gel ============================== Every Fly.io app must have a globally unique name, including service VMs like Postgres and Gel. Pick a name and assign it to a local environment variable called ``EDB_APP``. In the command below, replace ``myorg-gel`` with a name of your choosing. .. code-block:: bash $ EDB_APP=myorg-gel $ flyctl apps create --name $EDB_APP New app created: myorg-gel Now let's use the ``read`` command to securely assign a value to the ``PASSWORD`` environment variable. .. code-block:: bash $ echo -n "> " && read -s PASSWORD Now let's assign this password to a Fly `secret `_, plus a few other secrets that we'll need. There are a couple more environment variables we need to set: .. 
code-block:: bash $ flyctl secrets set \ GEL_SERVER_PASSWORD="$PASSWORD" \ GEL_SERVER_BACKEND_DSN_ENV=DATABASE_URL \ GEL_SERVER_TLS_CERT_MODE=generate_self_signed \ GEL_SERVER_PORT=8080 \ --app $EDB_APP Secrets are staged for the first deployment Let's discuss what's going on with all these secrets. - The :gelenv:`SERVER_BACKEND_DSN_ENV` tells the Gel container where to look for the PostgreSQL connection string (more on that below) - The :gelenv:`SERVER_TLS_CERT_MODE` tells Gel to auto-generate a self-signed TLS certificate. You may instead choose to provision a custom TLS certificate. In this case, you should instead create two other secrets: assign your certificate to :gelenv:`SERVER_TLS_CERT` and your private key to :gelenv:`SERVER_TLS_KEY`. - Lastly, :gelenv:`SERVER_PORT` tells Gel to listen on port 8080 instead of the default 5656, because Fly.io prefers ``8080`` for its default health checks. Finally, let's configure the VM size, as Gel requires a little more memory than the default Fly.io VM size provides. Put this in a file called ``fly.toml`` in your current directory: .. code-block:: toml [build] image = "geldata/gel" [[vm]] memory = "512mb" cpus = 1 cpu-kind = "shared" Create a PostgreSQL cluster =========================== Now we need to provision a PostgreSQL cluster and attach it to the Gel app. .. note:: If you have an existing PostgreSQL cluster in your Fly.io organization, you can skip to the attachment step. Create a new PostgreSQL cluster. This may take a few minutes to complete. .. code-block:: bash $ PG_APP=myorg-postgres $ flyctl pg create --name $PG_APP --vm-size shared-cpu-1x ? Select region: sea (Seattle, Washington (US)) ? Specify the initial cluster size: 1 ?
Volume size (GB): 10 Creating postgres cluster myorg-postgres in organization personal Postgres cluster myorg-postgres created Username: postgres Password: Hostname: myorg-postgres.internal Proxy Port: 5432 PG Port: 5433 Save your credentials in a secure place, you won't be able to see them again! Monitoring Deployment ... --> v0 deployed successfully In the output, you'll notice a line that says ``Machine is created``. The ID in that line is the ID of the virtual machine created for your Postgres cluster. We now need to use that ID to scale the cluster since the ``shared-cpu-1x`` VM doesn't have enough memory by default. Scale it with this command: .. code-block:: bash $ flyctl machine update --memory 1024 --app $PG_APP -y Searching for image 'flyio/postgres:14.6' remotely... image found: img_0lq747j0ym646x35 Image: registry-1.docker.io/flyio/postgres:14.6 Image size: 361 MB Updating machine Waiting for to become healthy (started, 3/3) Machine updated successfully! ==> Monitoring health checks Waiting for to become healthy (started, 3/3) ... With the VM scaled sufficiently, we can now attach the PostgreSQL cluster to the Gel app: .. code-block:: bash $ PG_ROLE=myorg_gel $ flyctl pg attach "$PG_APP" \ --database-user "$PG_ROLE" \ --app $EDB_APP Postgres cluster myorg-postgres is now attached to myorg-gel The following secret was added to myorg-gel: DATABASE_URL=postgres://... Lastly, Gel needs the ability to create Postgres databases and roles, so let's adjust the permissions on the role that Gel will use to connect to Postgres: .. code-block:: bash $ echo "alter role \"$PG_ROLE\" createrole createdb; \quit" \ | flyctl pg connect --app $PG_APP ... ALTER ROLE .. _ref_guide_deployment_fly_io_start_gel: Start Gel ========= Everything is set! Time to start Gel. .. code-block:: bash $ flyctl deploy --remote-only --app $EDB_APP ... Finished launching new machines ------- ✔ Machine e286630dce9638 [app] was created ------- That's it! 
You can now start using the Gel instance located at :geluri:`myorg-gel.internal` in your Fly.io apps. If deploy did not succeed: 1. make sure you've created the ``fly.toml`` file. 2. re-run the ``deploy`` command 3. check the logs for more information: ``flyctl logs --app $EDB_APP`` Persist the generated TLS certificate ===================================== Now we need to persist the auto-generated TLS certificate to make sure it survives Gel app restarts. (If you've provided your own certificate, skip this step). .. code-block:: bash $ EDB_SECRETS="GEL_SERVER_TLS_KEY GEL_SERVER_TLS_CERT" $ flyctl ssh console --app $EDB_APP -C \ "gel-show-secrets.sh --format=toml $EDB_SECRETS" \ | tr -d '\r' | flyctl secrets import --app $EDB_APP Connecting your application =========================== To connect your application to the Gel instance, you'll need to provide connection parameters. Gel client libraries can be configured using either a DSN (connection string) or individual environment variables. Obtaining connection parameters ------------------------------- Your connection requires the following components: - **Host (internal)**: ``$EDB_APP.internal`` — Fly uses this synthetic TLD for inter-app communication (e.g., ``myorg-gel.internal``) - **Host (external)**: ``$EDB_APP.fly.dev`` — for connections from outside Fly.io (requires exposing the port, see below) - **Port**: ``8080``, which we configured earlier with :gelenv:`SERVER_PORT` - **Username**: |admin| (the default superuser) - **Password**: The value you assigned to ``$PASSWORD`` - **Branch**: |main| (the default branch) Construct the DSN for internal Fly.io connections: .. code-block:: bash $ GEL_DSN=gel://admin:$PASSWORD@$EDB_APP.internal:8080 Consider writing it to a file to ensure the DSN looks correct. Remember to delete the file after you're done. (Printing this value to the terminal with ``echo`` is insecure and can leak your password into shell logs.) .. 
code-block:: bash $ echo $GEL_DSN > dsn.txt $ open dsn.txt $ rm dsn.txt Obtaining the TLS certificate ----------------------------- If you need secure TLS connections (required for external access), retrieve the server's TLS certificate: .. code-block:: bash $ flyctl ssh console -a $EDB_APP \ -C "gel-show-secrets.sh --format=raw GEL_SERVER_TLS_CERT" Save this to a file or set it as a secret in your application. From a Fly.io app ----------------- To connect to this instance from another Fly app (say, an app that runs your API server), set the value of the :gelenv:`DSN` secret inside that app. .. code-block:: bash $ flyctl secrets set \ GEL_DSN=$GEL_DSN \ --app my-other-fly-app We'll also set another variable that will disable Gel's TLS checks. Inter-application communication is secured by Fly so TLS isn't vital in this case; configuring TLS certificates is also beyond the scope of this guide. .. code-block:: bash $ flyctl secrets set GEL_CLIENT_TLS_SECURITY=insecure \ --app my-other-fly-app You can also set these values as environment variables inside your ``fly.toml`` file, but using Fly's built-in `secrets `_ functionality is recommended. From external application ------------------------- If you need to access Gel from outside the Fly.io network, you'll need to configure the Fly.io proxy to let external connections in. Let's make sure the ``[[services]]`` section in our ``fly.toml`` looks something like this: .. code-block:: toml [[services]] http_checks = [] internal_port = 8080 processes = ["app"] protocol = "tcp" script_checks = [] [services.concurrency] hard_limit = 25 soft_limit = 20 type = "connections" [[services.ports]] port = 5656 [[services.tcp_checks]] grace_period = "1s" interval = "15s" restart_limit = 0 timeout = "2s" In the same directory, :ref:`redeploy the Gel app `. This makes the Gel port available to the outside world. You can now access the instance from any host via the following public DSN: :geluri:`admin:$PASSWORD@$EDB_APP.fly.dev`.
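A password that contains characters such as ``@``, ``:`` or ``/`` would make a DSN assembled by plain string interpolation ambiguous. The following is a minimal Python sketch of assembling the DSN from its components with the credentials percent-encoded. The helper name ``build_gel_dsn`` is hypothetical and not part of any Gel library, and the sketch assumes the client parses the DSN as a standard URL and percent-decodes the user and password.

```python
from urllib.parse import quote


def build_gel_dsn(user, password, host, port=None, branch=None):
    """Assemble a gel:// DSN from its parts (hypothetical helper).

    User, password and branch are percent-encoded so that reserved
    characters such as '@', ':' and '/' cannot be mistaken for URL
    delimiters. Assumes the client decodes them again when parsing.
    """
    userinfo = f"{quote(user, safe='')}:{quote(password, safe='')}"
    netloc = f"{userinfo}@{host}"
    if port is not None:
        netloc += f":{port}"
    path = f"/{quote(branch, safe='')}" if branch else ""
    return f"gel://{netloc}{path}"


if __name__ == "__main__":
    import os

    # Write the result to a file or a secret store rather than printing
    # it; here we only confirm that a DSN was produced.
    dsn = build_gel_dsn(
        "admin",
        os.environ.get("PASSWORD", ""),
        "myorg-gel.fly.dev",
    )
    print("DSN assembled, length:", len(dsn))
```

Leaving out ``port`` relies on the default Gel port, which matches the ``5656`` exposed in the ``[[services.ports]]`` section above.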
To secure communication between the server and the client, you will also need to set the :gelenv:`TLS_CA` environment secret in your application. You can securely obtain the certificate content by running: .. code-block:: bash $ flyctl ssh console -a $EDB_APP \ -C "gel-show-secrets.sh --format=raw GEL_SERVER_TLS_CERT" Local development with the CLI ------------------------------ To access the Gel instance from your local development machine, install the WireGuard `VPN `_ and create a tunnel, as described in Fly's `Private Networking `_ docs. Once it's up and running, use :gelcmd:`instance link` to create a local alias to the remote instance. .. note:: The command groups :gelcmd:`instance` and :gelcmd:`project` are not intended to manage production instances. .. code-block:: bash $ gel instance link \ --dsn $GEL_DSN \ --non-interactive \ --trust-tls-cert \ my_fly_instance Authenticating to gel://admin@myorg-gel.internal:8080/main Successfully linked to remote instance. To connect run: gel -I my_fly_instance You can now refer to the remote instance using the alias ``my_fly_instance``. Use this alias wherever an instance name is expected: .. code-block:: bash $ gel -I my_fly_instance Gel x.x Type \help for help, \quit to quit. gel> Or apply migrations: .. code-block:: bash $ gel -I my_fly_instance migrate .. _vpn: https://fly.io/docs/reference/private-networking/#private-network-vpn Health Checks ============= Using an HTTP client, you can perform health checks to monitor the status of your Gel instance. Learn how to use them with our :ref:`health checks guide `. ================================================ FILE: docs/reference/running/deployment/gcp.rst ================================================ .. _ref_guide_deployment_gcp: ============ Google Cloud ============ :edb-alt-title: Deploying Gel to Google Cloud In this guide we show how to deploy Gel on GCP using Cloud SQL and Kubernetes. .. 
include:: ./note_cloud_reset_password.rst Prerequisites ============= * Google Cloud account with billing enabled (or a `free trial `_) * ``gcloud`` CLI (`install `_) * ``kubectl`` CLI (`install `_) .. _gcp-trial: https://cloud.google.com/free/ .. _gcloud-intsll: https://cloud.google.com/sdk/ .. _kubectl-install: https://kubernetes.io/docs/tasks/tools/install-kubectl/ Make sure you are logged into Google Cloud. .. code-block:: bash $ gcloud init Create a project ================ Set the ``PROJECT`` environment variable to the project name you'd like to use. Google Cloud allows only letters, numbers, and hyphens. .. code-block:: bash $ PROJECT=gel Then create a project with this name. Skip this step if your project already exists. .. code-block:: bash $ gcloud projects create $PROJECT Then enable the requisite APIs. .. code-block:: bash $ gcloud services enable \ container.googleapis.com \ sqladmin.googleapis.com \ iam.googleapis.com \ --project=$PROJECT Provision a Postgres instance ============================= Use the ``read`` command to securely assign a value to the ``PASSWORD`` environment variable. .. code-block:: bash $ echo -n "> " && read -s PASSWORD Then create a Cloud SQL instance and set the password. .. code-block:: bash $ gcloud sql instances create ${PROJECT}-postgres \ --database-version=POSTGRES_17 \ --edition=enterprise \ --cpu=1 \ --memory=3840MiB \ --region=us-west2 \ --project=$PROJECT $ gcloud sql users set-password postgres \ --instance=${PROJECT}-postgres \ --password=$PASSWORD \ --project=$PROJECT Create a Kubernetes cluster =========================== Create an empty Kubernetes cluster inside your project. .. code-block:: bash $ gcloud container clusters create ${PROJECT}-k8s \ --zone=us-west2-a \ --num-nodes=1 \ --project=$PROJECT Configure service account ========================= Create a new service account, configure its permissions, and generate a ``credentials.json`` file. .. 
code-block:: bash $ gcloud iam service-accounts create ${PROJECT}-account \ --project=$PROJECT $ MEMBER="${PROJECT}-account@${PROJECT}.iam.gserviceaccount.com" $ gcloud projects add-iam-policy-binding $PROJECT \ --member=serviceAccount:${MEMBER} \ --role=roles/cloudsql.admin \ --project=$PROJECT $ gcloud iam service-accounts keys create credentials.json \ --iam-account=${MEMBER} Then install the GKE authentication plugin for ``kubectl`` and store ``credentials.json`` and the database connection details as Kubernetes secrets. .. code-block:: bash $ gcloud components install gke-gcloud-auth-plugin $ kubectl create secret generic cloudsql-instance-credentials \ --from-file=credentials.json=credentials.json $ INSTANCE_CONNECTION_NAME=$( gcloud sql instances describe ${PROJECT}-postgres \ --format="value(connectionName)" \ --project=$PROJECT ) $ DSN="postgresql://postgres:${PASSWORD}@127.0.0.1:5432" $ kubectl create secret generic cloudsql-db-credentials \ --from-literal=dsn=$DSN \ --from-literal=password=$PASSWORD \ --from-literal=instance=${INSTANCE_CONNECTION_NAME}=tcp:5432 Deploy Gel ========== Download the starter Gel Kubernetes configuration file. This file specifies a persistent volume, a container running a `Cloud SQL authorization proxy `_, and a container to run `Gel itself `_. It relies on the secrets we declared in the previous step. .. code-block:: bash $ wget "https://raw.githubusercontent.com\ /geldata/gel-deploy/dev/gcp/deployment.yaml" $ kubectl apply -f deployment.yaml Ensure the pods are running. .. code-block:: bash $ kubectl get pods NAME READY STATUS RESTARTS AGE gel-977b8fdf6-jswlw 0/2 ContainerCreating 0 16s The ``READY 0/2`` tells us that neither of the pod's two containers has finished booting. Re-run the command until the ``READY`` column shows ``2/2``. If there are errors, you can check Gel's logs with: .. 
code-block:: bash $ kubectl logs deployment/gel --container gel Persist TLS Certificate ======================= Now that our Gel instance is up and running, we need to download a local copy of its self-signed TLS certificate (which it generated on startup) and pass it as a secret into Kubernetes. Then we'll redeploy the pods. .. code-block:: bash $ kubectl create secret generic cloudsql-tls-credentials \ --from-literal=tlskey="$( kubectl exec deploy/gel -c=gel -- \ gel-show-secrets.sh --format=raw GEL_SERVER_TLS_KEY )" \ --from-literal=tlscert="$( kubectl exec deploy/gel -c=gel -- \ gel-show-secrets.sh --format=raw GEL_SERVER_TLS_CERT )" $ kubectl delete -f deployment.yaml $ kubectl apply -f deployment.yaml Expose Gel ========== .. code-block:: bash $ kubectl expose deploy/gel --type LoadBalancer Connecting your application =========================== To connect your application to the Gel instance, you'll need to provide connection parameters. Gel client libraries can be configured using either a DSN (connection string) or individual environment variables. Obtaining connection parameters ------------------------------- Your connection requires the following components: - **Host**: The ``EXTERNAL-IP`` of the LoadBalancer service - **Port**: ``5656`` (the default Gel port) - **Username**: |admin| (the default superuser) - **Password**: The value you assigned to ``$PASSWORD`` - **Branch**: |main| (the default branch) Get the public-facing IP address of your database: .. code-block:: bash $ kubectl get service NAME TYPE CLUSTER-IP EXTERNAL-IP PORT(S) gel LoadBalancer 5656:30841/TCP Copy the ``EXTERNAL-IP`` associated with the ``gel`` service and construct your instance's :ref:`DSN `: .. code-block:: bash $ GEL_IP= $ GEL_DSN="gel://admin:${PASSWORD}@${GEL_IP}" To print the final DSN, you can ``echo`` it. Note that you should only run this command on a computer you trust, like a personal laptop or sandboxed environment. .. 
code-block:: bash $ echo $GEL_DSN Obtaining the TLS certificate ----------------------------- Since we configured Gel with a self-signed TLS certificate, your application needs the certificate to connect securely. Retrieve it from the running pod: .. code-block:: bash $ kubectl exec deploy/gel -c=gel -- \ gel-show-secrets.sh --format=raw GEL_SERVER_TLS_CERT \ > gel-tls-cert.pem Alternatively, retrieve it using the Gel CLI: .. code-block:: bash $ gel --dsn $GEL_DSN --tls-security insecure \ query "SELECT sys::get_tls_certificate()" > gel-tls-cert.pem Test your connection by opening a REPL: .. code-block:: bash $ gel --dsn $GEL_DSN --tls-security insecure Gel x.x (repl x.x) Type \help for help, \quit to quit. gel> select "hello world!"; Local development with the CLI ------------------------------ To make your remote instance easier to work with during local development, create an alias using :gelcmd:`instance link`. .. note:: The command groups :gelcmd:`instance` and :gelcmd:`project` are not intended to manage production instances. .. code-block:: bash $ echo $PASSWORD | gel instance link \ --dsn $GEL_DSN \ --password-from-stdin \ --non-interactive \ --trust-tls-cert \ my_gcp_instance You can now refer to the remote instance using the alias ``my_gcp_instance``. Use this alias wherever an instance name is expected: .. code-block:: bash $ gel -I my_gcp_instance Gel x.x Type \help for help, \quit to quit. gel> Or apply migrations: .. code-block:: bash $ gel -I my_gcp_instance migrate Using in your application ------------------------- Set these environment variables where you deploy your application: .. code-block:: bash GEL_DSN="gel://admin:@:5656" # For self-signed certificates, provide the CA cert: GEL_TLS_CA_FILE="/path/to/gel-tls-cert.pem" # Or embed the certificate content directly: GEL_TLS_CA="" # Or (for development only) disable TLS verification: # GEL_CLIENT_TLS_SECURITY=insecure Gel's client libraries will automatically read these environment variables. 
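Before starting the application it can help to confirm that these variables are set the way you expect. Below is a minimal Python sketch of such a preflight check, using only the standard library. The function name ``summarize_gel_env`` and the shape of its result are hypothetical and not part of any Gel client library; the sketch assumes the DSN follows the ``gel://user:password@host:port`` form shown above and deliberately leaves the password out of its result.

```python
import os
from urllib.parse import unquote, urlsplit

DEFAULT_GEL_PORT = 5656  # the default Gel port used in this guide


def summarize_gel_env(env=None):
    """Return a password-free summary of the Gel connection settings.

    Hypothetical helper: it only inspects environment variables and
    never opens a connection.
    """
    env = os.environ if env is None else env
    dsn = env.get("GEL_DSN")
    if not dsn:
        raise ValueError("GEL_DSN is not set")
    parts = urlsplit(dsn)
    if parts.scheme != "gel":
        raise ValueError(f"unexpected DSN scheme: {parts.scheme!r}")
    return {
        "host": parts.hostname,
        "port": parts.port or DEFAULT_GEL_PORT,
        "user": unquote(parts.username) if parts.username else None,
        # "unset" means the client library applies its own default.
        "tls_security": env.get("GEL_CLIENT_TLS_SECURITY", "unset"),
        # True when a CA certificate is supplied inline or as a file.
        "has_ca": bool(env.get("GEL_TLS_CA") or env.get("GEL_TLS_CA_FILE")),
    }


if __name__ == "__main__":
    try:
        print(summarize_gel_env())
    except ValueError as exc:
        print("Gel environment problem:", exc)
```

A summary showing ``tls_security`` as ``insecure`` together with ``has_ca`` as ``False`` is a hint that a development-only setting has reached an environment where the certificate should be provided instead.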
Health Checks ============= Using an HTTP client, you can perform health checks to monitor the status of your Gel instance. Learn how to use them with our :ref:`health checks guide `. ================================================ FILE: docs/reference/running/deployment/index.rst ================================================ .. _ref_guide_deployment: ========== Deployment ========== |Gel| can be hosted on all major cloud hosting platforms. The guides below demonstrate how to spin up both a managed PostgreSQL instance and a container running Gel `in Docker `_. .. note:: Minimum requirements As a rule of thumb, the Gel Docker container requires 1GB RAM! Images with insufficient RAM may experience unexpected issues during startup. When using an external PostgreSQL instance Gel must connect with the PostgreSQL superuser. .. toctree:: :maxdepth: 1 aws_aurora_ecs azure_flexibleserver digitalocean fly_io gcp docker bare_metal ================================================ FILE: docs/reference/running/deployment/note_cloud_reset_password.rst ================================================ .. note:: Gel Cloud: Reset the default password for the admin role If you want to dump an existing Gel Cloud instance and restore it to a new self-managed instance, you need to change the automatically generated password for the default admin role - ``edgedb`` or ``admin``. The administrator role name and its password used in the dump/restore process must be the same in both the instance dumped from and the instance restored to for the Gel tooling to continue functioning properly. To change the default password in the Cloud instance, execute the following query in the instance: .. code-block:: edgeql ALTER ROLE admin { set password := 'new_password' }; ================================================ FILE: docs/reference/running/http.rst ================================================ .. 
_ref_reference_http_api: ======== HTTP API ======== Gel provides HTTP endpoints that allow you to monitor the health and performance of your instance. You can use these endpoints to check if your instance is alive and ready to receive queries, as well as to collect metrics about its operation. Your branch's URL takes the form of ``http://:``. Here's how to determine your local Gel instance's HTTP server URL: - The ``hostname`` will be ``localhost`` - Find the ``port`` by running :gelcmd:`instance list`. This will print a table of all Gel instances on your machine, including their associated port number. To determine the URL of a remote instance you have linked with the CLI, you can get both the hostname and port of the instance from the "Port" column of the :gelcmd:`instance list` table (formatted as ``:``). .. _ref_edgeql_http_health_checks: .. _ref_reference_health_checks: .. _ref_guide_deployment_health_checks: Health Checks ============= |Gel| exposes endpoints to check for aliveness and readiness of your database instance. Aliveness --------- Check if your instance is alive. .. code-block:: http://:/server/status/alive If your instance is alive, it will respond with a ``200`` status code and ``"OK"`` as the payload. Otherwise, it will respond with a ``50x`` or a network error. Readiness --------- Check if your instance is ready to receive queries. .. code-block:: http://:/server/status/ready If your instance is ready, it will respond with a ``200`` status code and ``"OK"`` as the payload. Otherwise, it will respond with a ``50x`` or a network error. .. _ref_observability: Observability ============= Retrieve instance metrics. .. code-block:: http://:/metrics All Gel instances expose a Prometheus-compatible endpoint available via GET request. The following metrics are made available. System ------ ``compiler_process_spawns_total`` **Counter.** Total number of compiler processes spawned. 
``compiler_processes_current`` **Gauge.** Current number of active compiler processes. ``branches_current`` **Gauge.** Current number of branches. Backend connections and performance ----------------------------------- ``backend_connections_total`` **Counter.** Total number of backend connections established. ``backend_connections_current`` **Gauge.** Current number of active backend connections. ``backend_connection_establishment_errors_total`` **Counter.** Number of times the server could not establish a backend connection. ``backend_connection_establishment_latency`` **Histogram.** Time it takes to establish a backend connection, in seconds. ``backend_query_duration`` **Histogram.** Time it takes to run a query on a backend connection, in seconds. Client connections ------------------ ``client_connections_total`` **Counter.** Total number of clients. ``client_connections_current`` **Gauge.** Current number of active clients. ``client_connections_idle_total`` **Counter.** Total number of forcefully closed idle client connections. ``client_connection_duration`` **Histogram.** Time a client connection is open. Queries and compilation ----------------------- ``edgeql_query_compilations_total`` **Counter.** Number of compiled/cached queries or scripts since instance startup. A query is compiled and then cached on first use, increasing the ``path="compiler"`` parameter. Subsequent uses of the same query only use the cache, thus only increasing the ``path="cache"`` parameter. ``edgeql_query_compilation_duration`` Deprecated in favor of ``query_compilation_duration[interface="edgeql"]``. **Histogram.** Time it takes to compile an EdgeQL query or script, in seconds. ``graphql_query_compilations_total`` **Counter.** Number of compiled/cached GraphQL queries since instance startup. A query is compiled and then cached on first use, increasing the ``path="compiler"`` parameter. 
Subsequent uses of the same query only use the cache, thus only increasing the ``path="cache"`` parameter. ``sql_queries_total`` **Counter.** Number of SQL queries since instance startup. ``sql_compilations_total`` **Counter.** Number of SQL compilations since instance startup. ``query_compilation_duration`` **Histogram.** Time it takes to compile a query or script, in seconds. ``queries_per_connection`` **Histogram.** Number of queries per connection. ``query_size`` **Histogram.** Number of bytes in a query, where the label ``interface=edgeql`` means the size of an EdgeQL query, ``=graphql`` for a GraphQL query, ``=sql`` for a readonly SQL query from the user, and ``=compiled`` for a backend SQL query compiled and issued by the server. Auth Extension -------------- ``auth_api_calls_total`` **Counter.** Number of API calls to the Auth extension. ``auth_ui_renders_total`` **Counter.** Number of UI pages rendered by the Auth extension. ``auth_providers`` **Histogram.** Number of Auth providers configured. ``auth_successful_logins_total`` **Counter.** Number of successful logins in the Auth extension. Errors ------ ``background_errors_total`` **Counter.** Number of unhandled errors in background server routines. ``transaction_serialization_errors_total`` **Counter.** Number of transaction serialization errors. ``connection_errors_total`` **Counter.** Number of network connection errors. ================================================ FILE: docs/reference/running/index.rst ================================================ .. _ref_running_index: =========== Running Gel =========== .. toctree:: :maxdepth: 2 :hidden: local deployment/index configuration http backend_ha admin/index This section provides comprehensive guidance for deploying and managing Gel database instances. 
Get your instance running ========================= While running local project instances with the CLI is low-maintenance and easy to start with, Gel also makes it easy to configure your production deployment through: * Environment variables * Configuration files * Server CLI arguments These configuration mechanisms allow you to tailor your Gel deployment to your specific needs. Operations teams can control everything from memory usage and connection limits to TLS and network settings, ensuring your Gel deployment aligns with organizational policies and infrastructure constraints while maintaining optimal performance. And keep it humming =================== Running Gel in production requires effective management and monitoring capabilities. Gel provides a comprehensive set of tools for these purposes: - **HTTP endpoints** for health checks and metrics collection that integrate with your monitoring infrastructure - **Administrative commands** for: - Role management (creating, altering, and dropping roles) - Branch operations (creating, dropping, and managing branches) - Database maintenance operations like vacuum for reclaiming space - Updating internal statistics used by the query planner These tools enable you to maintain optimal performance, manage access control, and troubleshoot issues in your Gel deployments. The reference documentation for these features is organized in the sections that follow. ================================================ FILE: docs/reference/running/local.rst ================================================ .. _ref_running_local: ================= Local development ================= See the :ref:`ref_guide_using_projects` page for information on how to use projects to manage local development. ================================================ FILE: docs/reference/stdlib/abstract.rst ================================================ .. 
_ref_std_abstract_types: ============== Abstract Types ============== Abstract types are used to describe polymorphic functions, otherwise known as "generic functions," which can be called on a broad range of types. ---------- .. eql:type:: anytype :index: any anytype A generic type. It is a placeholder used in cases where no specific type requirements are needed, such as defining polymorphic parameters in functions and operators. ---------- .. eql:type:: std::anyobject :index: any anytype object A generic object. Similarly to :eql:type:`anytype`, this type is used to denote a generic object. This is useful when defining polymorphic parameters in functions and operators as it conforms to whatever type is actually passed. This is different from :eql:type:`BaseObject`, which is the parent type of every object but only has an ``id`` property, making access to other properties and links harder. ---------- .. eql:type:: std::anyscalar :index: any anytype scalar An abstract base scalar type. All scalar types are derived from this type. ---------- .. eql:type:: std::anyenum :index: any anytype enum An abstract base enumerated type. All :eql:type:`enum` types are derived from this type. ---------- .. eql:type:: anytuple :index: any anytype anytuple A generic tuple. Similarly to :eql:type:`anytype`, this type is used to denote a generic tuple without detailing its component types. This is useful when defining polymorphic parameters in functions and operators. Abstract Numeric Types ====================== These abstract numeric types extend :eql:type:`anyscalar`. .. eql:type:: std::anyint :index: any anytype int An abstract base scalar type for :eql:type:`int16`, :eql:type:`int32`, and :eql:type:`int64`. ---------- .. eql:type:: std::anyfloat :index: any anytype float An abstract base scalar type for :eql:type:`float32` and :eql:type:`float64`. ---------- .. 
eql:type:: std::anyreal :index: any anytype An abstract base scalar type for :eql:type:`anyint`, :eql:type:`anyfloat`, and :eql:type:`decimal`. Abstract Range Types ==================== There are some types that can be used to construct :ref:`ranges `. These scalar types are distinguished by the following abstract types: .. eql:type:: std::anypoint :index: any anypoint anyrange Abstract base type for all valid ranges. Abstract base scalar type for :eql:type:`int32`, :eql:type:`int64`, :eql:type:`float32`, :eql:type:`float64`, :eql:type:`decimal`, :eql:type:`datetime`, :eql:type:`cal::local_datetime`, and :eql:type:`cal::local_date`. ---------- .. eql:type:: std::anydiscrete :index: any anydiscrete anyrange discrete An abstract base type for all valid *discrete* ranges. This is an abstract base scalar type for :eql:type:`int32`, :eql:type:`int64`, and :eql:type:`cal::local_date`. ---------- .. eql:type:: std::anycontiguous :index: any anycontiguous anyrange An abstract base type for all valid *contiguous* ranges. This is an abstract base scalar type for :eql:type:`float32`, :eql:type:`float64`, :eql:type:`decimal`, :eql:type:`datetime`, and :eql:type:`cal::local_datetime`. ================================================ FILE: docs/reference/stdlib/array.rst ================================================ .. _ref_std_array: ====== Arrays ====== :edb-alt-title: Array Functions and Operators .. list-table:: :class: funcoptable * - :eql:op:`array[i] ` - :eql:op-desc:`arrayidx` * - :eql:op:`array[from:to] ` - :eql:op-desc:`arrayslice` * - :eql:op:`array ++ array ` - :eql:op-desc:`arrayplus` * - :eql:op:`= ` :eql:op:`\!= ` :eql:op:`?= ` :eql:op:`?!= ` :eql:op:`\< ` :eql:op:`\> ` :eql:op:`\<= ` :eql:op:`\>= ` - Comparison operators * - :eql:func:`len` - Returns the number of elements in the array. * - :eql:func:`contains` - Checks if an element is in the array. * - :eql:func:`find` - Finds the index of an element in the array. 
* - :eql:func:`array_join` - Renders an array to a string or byte-string. * - :eql:func:`array_fill` - :eql:func-desc:`array_fill` * - :eql:func:`array_replace` - :eql:func-desc:`array_replace` * - :eql:func:`array_set` - :eql:func-desc:`array_set` * - :eql:func:`array_insert` - :eql:func-desc:`array_insert` * - :eql:func:`array_agg` - :eql:func-desc:`array_agg` * - :eql:func:`array_get` - :eql:func-desc:`array_get` * - :eql:func:`array_unpack` - :eql:func-desc:`array_unpack` Arrays store expressions of the *same type* in an ordered list. .. _ref_std_array_constructor: Constructing arrays ^^^^^^^^^^^^^^^^^^^ An array constructor is an expression that consists of a sequence of comma-separated expressions *of the same type* enclosed in square brackets. It produces an array value: .. eql:synopsis:: "[" [, ...] "]" For example: .. code-block:: edgeql-repl db> select [1, 2, 3]; {[1, 2, 3]} db> select [('a', 1), ('b', 2), ('c', 3)]; {[('a', 1), ('b', 2), ('c', 3)]} Empty arrays ^^^^^^^^^^^^ You can also create an empty array, but it must be done by providing the type information using type casting. Gel cannot infer the type of an empty array created otherwise. For example: .. code-block:: edgeql-repl db> select []; QueryError: expression returns value of indeterminate type Hint: Consider using an explicit type cast. ### select []; ### ^ db> select >[]; {[]} Reference ^^^^^^^^^ .. eql:type:: std::array An ordered list of values of the same type. Array indexing starts at zero. An array can contain any type except another array. In Gel, arrays are always one-dimensional. An array type is created implicitly when an :ref:`array constructor ` is used: .. code-block:: edgeql-repl db> select [1, 2]; {[1, 2]} The array types themselves are denoted by ``array`` followed by their sub-type in angle brackets. These may appear in cast operations: .. 
code-block:: edgeql-repl db> select <array<str>>[1, 4, 7]; {['1', '4', '7']} db> select <array<bigint>>[1, 4, 7]; {[1n, 4n, 7n]} Array types may also appear in schema declarations: .. code-block:: sdl type Person { str_array: array<str>; json_array: array<json>; } See also the list of standard :ref:`array functions `, as well as :ref:`generic functions ` such as :eql:func:`len`. ---------- .. eql:operator:: arrayidx: array [ int64 ] -> anytype .. api-index:: §array§[§int§] Accesses the array element at a given index. Example: .. code-block:: edgeql-repl db> select [1, 2, 3][0]; {1} db> select [(x := 1, y := 1), (x := 2, y := 3.3)][1]; {(x := 2, y := 3.3)} This operator also allows accessing elements from the end of the array using a negative index: .. code-block:: edgeql-repl db> select [1, 2, 3][-1]; {3} Referencing a non-existent array element will result in an error: .. code-block:: edgeql-repl db> select [1, 2, 3][4]; InvalidValueError: array index 4 is out of bounds ---------- .. eql:operator:: arrayslice: array [ int64 : int64 ] -> anytype .. api-index:: §array§[§int§:§int§] Produces a sub-array from an existing array. Omitting the lower bound of an array slice will default to a lower bound of zero. Omitting the upper bound will default the upper bound to the length of the array. The lower bound of an array slice is inclusive while the upper bound is not. Examples: .. code-block:: edgeql-repl db> select [1, 2, 3][0:2]; {[1, 2]} db> select [1, 2, 3][2:]; {[3]} db> select [1, 2, 3][:1]; {[1]} db> select [1, 2, 3][:-2]; {[1]} Referencing an array slice beyond the array boundaries will result in an empty array (unlike a direct reference to a specific index). Slicing with a lower bound less than the minimum index or an upper bound greater than the maximum index is functionally equivalent to not specifying those bounds for your slice: .. code-block:: edgeql-repl db> select [1, 2, 3][1:20]; {[2, 3]} db> select [1, 2, 3][10:20]; {[]} --------- .. eql:operator:: arrayplus: array ++ array -> array .. 
index:: concatenate, join, add .. api-index:: §array §++§ array§ Concatenates two arrays of the same type into one. .. code-block:: edgeql-repl db> select [1, 2, 3] ++ [99, 98]; {[1, 2, 3, 99, 98]} ---------- .. eql:function:: std::array_agg(s: set of anytype) -> array .. index:: aggregate Returns an array made from all of the input set elements. The ordering of the input set will be preserved if specified: .. code-block:: edgeql-repl db> select array_agg({2, 3, 5}); {[2, 3, 5]} db> select array_agg(User.name order by User.name); {['Alice', 'Bob', 'Joe', 'Sam']} ---------- .. eql:function:: std::array_get(array: array, \ index: int64, \ named only default: anytype = {} \ ) -> optional anytype Returns the element of a given *array* at the specified *index*. If the index is out of the array's bounds, the *default* argument or ``{}`` (empty set) will be returned. This works the same as the :eql:op:`array indexing operator `, except that if the index is out of bounds, an empty set of the array element's type is returned instead of raising an exception: .. code-block:: edgeql-repl db> select array_get([2, 3, 5], 1); {3} db> select array_get([2, 3, 5], 100); {} db> select array_get([2, 3, 5], 100, default := 42); {42} ---------- .. eql:function:: std::array_unpack(array: array) -> set of anytype Returns the elements of an array as a set. .. note:: The ordering of the returned set is not guaranteed. However, if it is wrapped in a call to :eql:func:`enumerate`, the assigned indexes are guaranteed to match the array. .. code-block:: edgeql-repl db> select array_unpack([2, 3, 5]); {3, 2, 5} db> select enumerate(array_unpack([2, 3, 5])); {(1, 3), (0, 2), (2, 5)} ---------- .. eql:function:: std::array_join(array: array, delimiter: str) -> str std::array_join(array: array, \ delimiter: bytes) -> bytes .. index:: array_to_string, implode Renders an array to a string or byte-string. Join a string array into a single string using a specified *delimiter*: .. 
code-block:: edgeql-repl db> select array_join(['one', 'two', 'three'], ', '); {'one, two, three'} Similarly, an array of :eql:type:`bytes` can be joined as a single value using a specified *delimiter*: .. code-block:: edgeql-repl db> select array_join([b'\x01', b'\x02', b'\x03'], b'\xff'); {b'\x01\xff\x02\xff\x03'} ---------- .. eql:function:: std::array_fill(val: anytype, n: int64) -> array Returns an array of the specified size, filled with the provided value. Create an array of size *n* where every element has the value *val*. .. code-block:: edgeql-repl db> select array_fill(0, 5); {[0, 0, 0, 0, 0]} db> select array_fill('n/a', 3); {['n/a', 'n/a', 'n/a']} ---------- .. eql:function:: std::array_replace(array: array, \ old: anytype, \ new: anytype) \ -> array Returns an array with all occurrences of one value replaced by another. Return an array where every *old* value is replaced with *new*. .. code-block:: edgeql-repl db> select array_replace([1, 1, 2, 3, 5], 1, 99); {[99, 99, 2, 3, 5]} db> select array_replace(['h', 'e', 'l', 'l', 'o'], 'l', 'L'); {['h', 'e', 'L', 'L', 'o']} ---------- .. eql:function:: std::array_set(array: array, \ idx: int64, \ val: anytype) \ -> array .. versionadded:: 6.0 Returns an array with a value at a specific index replaced by another. Return an array where the value at the index indicated by *idx* is replaced with *val*. .. code-block:: edgeql-repl db> select array_set(['hello', 'world'], 0, 'goodbye'); {['goodbye', 'world']} db> select array_set([1, 1, 2, 3], 1, 99); {[1, 99, 2, 3]} ---------- .. eql:function:: std::array_insert(array: array, \ idx: int64, \ val: anytype) \ -> array .. versionadded:: 6.0 Returns an array with a value inserted at a specific index. Return an array where the value *val* is inserted at the index indicated by *idx*. .. 
code-block:: edgeql-repl db> select array_insert(['the', 'brown', 'fox'], 1, 'quick'); {['the', 'quick', 'brown', 'fox']} db> select array_insert([1, 1, 2, 3], 1, 99); {[1, 99, 1, 2, 3]} ================================================ FILE: docs/reference/stdlib/bool.rst ================================================ .. _ref_std_logical: ======== Booleans ======== :edb-alt-title: Boolean Functions and Operators .. list-table:: :class: funcoptable * - :eql:type:`bool` - Boolean type * - :eql:op:`bool or bool ` - :eql:op-desc:`or` * - :eql:op:`bool and bool ` - :eql:op-desc:`and` * - :eql:op:`not bool ` - :eql:op-desc:`not` * - :eql:op:`= ` :eql:op:`\!= ` :eql:op:`?= ` :eql:op:`?!= ` :eql:op:`\< ` :eql:op:`\> ` :eql:op:`\<= ` :eql:op:`\>= ` - Comparison operators * - :eql:func:`all` - :eql:func-desc:`all` * - :eql:func:`any` - :eql:func-desc:`any` * - :eql:func:`assert` - :eql:func-desc:`assert` ---------- .. eql:type:: std::bool A boolean type of either ``true`` or ``false``. EdgeQL has case-insensitive keywords and that includes the boolean literals: .. 
code-block:: edgeql-repl db> select (True, true, TRUE); {(true, true, true)} db> select (False, false, FALSE); {(false, false, false)} These basic operators will always result in a boolean type value (although, for some of them, that value may be the empty set if an operand is the empty set): - :eql:op:`= ` - :eql:op:`\!= ` - :eql:op:`?= ` - :eql:op:`?!= ` - :eql:op:`in` - :eql:op:`not in ` - :eql:op:`\< ` - :eql:op:`\> ` - :eql:op:`\<= ` - :eql:op:`\>= ` - :eql:op:`like` - :eql:op:`ilike` These operators will result in a boolean type value even if the right operand is the empty set: - :eql:op:`in` - :eql:op:`not in ` These operators will always result in a boolean ``true`` or ``false`` value, even if either operand is the empty set: - :eql:op:`?= ` - :eql:op:`?!= ` These operators will produce the empty set if either operand is the empty set: - :eql:op:`= ` - :eql:op:`\!= ` - :eql:op:`\< ` - :eql:op:`\> ` - :eql:op:`\<= ` - :eql:op:`\>= ` - :eql:op:`like` - :eql:op:`ilike` If you need to use these operators and it's possible one or both operands will be the empty set, you can ensure a ``bool`` result by :eql:op:`coalescing `. With ``=`` and ``!=``, you can use their respective dedicated coalescing operators, ``?=`` and ``?!=``. See each individual operator for an example. Some boolean operator examples: .. code-block:: edgeql-repl db> select true and 2 < 3; {true} db> select '!' IN {'hello', 'world'}; {false} It's possible to get a boolean by casting a :eql:type:`str` or :eql:type:`json` value into it: .. code-block:: edgeql-repl db> select <bool>('true'); {true} db> select <bool>to_json('false'); {false} :ref:`Filter clauses ` must always evaluate to a boolean: .. code-block:: edgeql select User filter .name ilike 'alice'; ---------- .. eql:operator:: or: bool or bool -> bool .. api-index:: or Evaluates ``true`` if either boolean is ``true``. .. code-block:: edgeql-repl db> select false or true; {true} .. 
warning:: When either operand in an ``or`` is an empty set, the result will not be a ``bool`` but instead an empty set. .. code-block:: edgeql-repl db> select true or {}; {} If one of the operands in an ``or`` operation could be an empty set, you may want to use the :eql:op:`coalesce` operator (``??``) on that side to ensure you will still get a ``bool`` result. .. code-block:: edgeql-repl db> select true or ({} ?? false); {true} ---------- .. eql:operator:: and: bool and bool -> bool .. api-index:: and Evaluates ``true`` if both booleans are ``true``. .. code-block:: edgeql-repl db> select false and true; {false} .. warning:: When either operand in an ``and`` is an empty set, the result will not be a ``bool`` but instead an empty set. .. code-block:: edgeql-repl db> select true and {}; {} If one of the operands in an ``and`` operation could be an empty set, you may want to use the :eql:op:`coalesce` operator (``??``) on that side to ensure you will still get a ``bool`` result. .. code-block:: edgeql-repl db> select true and ({} ?? false); {false} ---------- .. eql:operator:: not: not bool -> bool .. api-index:: not Logically negates a given boolean value. .. code-block:: edgeql-repl db> select not false; {true} .. warning:: When the operand in a ``not`` is an empty set, the result will not be a ``bool`` but instead an empty set. .. code-block:: edgeql-repl db> select not {}; {} If the operand in a ``not`` operation could be an empty set, you may want to use the :eql:op:`coalesce` operator (``??``) on that side to ensure you will still get a ``bool`` result. .. code-block:: edgeql-repl db> select not ({} ?? false); {true} ---------- The ``and`` and ``or`` operators are commutative. 
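A minimal sketch of what commutativity means here, assuming plain ``bool`` literals as operands: swapping the two sides of ``and`` or ``or`` leaves the result unchanged.

.. code-block:: edgeql-repl

    db> select (true and false) = (false and true);
    {true}
    db> select (true or false) = (false or true);
    {true}
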
The truth tables are as follows: +-------+-------+---------------+--------------+--------------+ | a | b | a ``and`` b | a ``or`` b | ``not`` a | +=======+=======+===============+==============+==============+ | true | true | true | true | false | +-------+-------+---------------+--------------+--------------+ | true | false | false | true | false | +-------+-------+---------------+--------------+--------------+ | false | true | false | true | true | +-------+-------+---------------+--------------+--------------+ | false | false | false | false | true | +-------+-------+---------------+--------------+--------------+ ---------- The operators ``and``/``or`` and the functions :eql:func:`all`/:eql:func:`any` differ in the way they handle an empty set (``{}``). Both ``and`` and ``or`` operators apply to the cross-product of their operands. If either operand is an empty set, the result will also be an empty set. For example: .. code-block:: edgeql-repl db> select {true, false} and {}; {} db> select true and {}; {} Operating on an empty set with :eql:func:`all`/:eql:func:`any` does *not* return an empty set: .. code-block:: edgeql-repl db> select all({}); {true} db> select any({}); {false} :eql:func:`all` returns ``true`` because the empty set contains no false values. :eql:func:`any` returns ``false`` because the empty set contains no true values. The :eql:func:`all` and :eql:func:`any` functions are generalized to apply to sets of values, including ``{}``. 
Thus they have the following truth table: +-------+-------+-----------------+-----------------+ | a | b | ``all({a, b})`` | ``any({a, b})`` | +=======+=======+=================+=================+ | true | true | true | true | +-------+-------+-----------------+-----------------+ | true | false | false | true | +-------+-------+-----------------+-----------------+ | {} | true | true | true | +-------+-------+-----------------+-----------------+ | {} | false | false | false | +-------+-------+-----------------+-----------------+ | false | true | false | true | +-------+-------+-----------------+-----------------+ | false | false | false | false | +-------+-------+-----------------+-----------------+ | true | {} | true | true | +-------+-------+-----------------+-----------------+ | false | {} | false | false | +-------+-------+-----------------+-----------------+ | {} | {} | true | false | +-------+-------+-----------------+-----------------+ Since :eql:func:`all` and :eql:func:`any` apply to sets as a whole, missing values (represented by ``{}``) are just that - missing. They don't affect the overall result. To understand the last line in the above truth table it's useful to remember that ``all({a, b}) = all(a) and all(b)`` and ``any({a, b}) = any(a) or any(b)``. For more customized handling of ``{}``, use the :eql:op:`?? ` operator. ---------- .. eql:function:: std::assert( \ input: bool, \ named only message: optional str = {} \ ) -> bool Checks that the input bool is ``true``. If the input bool is ``false``, ``assert`` raises a ``QueryAssertionError``. Otherwise, this function returns ``true``. .. code-block:: edgeql-repl db> select assert(true); {true} db> select assert(false); gel error: QueryAssertionError: assertion failed db> select assert(false, message := 'value is not true'); gel error: QueryAssertionError: value is not true ``assert`` can be used in triggers to create more powerful constraints. 
In this schema, the ``Person`` type has both ``friends`` and ``enemies`` links. You may not want a ``Person`` to be both a friend and an enemy of the same ``Person``. ``assert`` can be used inside a trigger to easily prohibit this. .. code-block:: sdl type Person { required name: str; multi friends: Person; multi enemies: Person; trigger prohibit_frenemies after insert, update for each do ( assert( not exists (__new__.friends intersect __new__.enemies), message := "Invalid frenemies", ) ) } With this trigger in place, it is impossible to link the same ``Person`` as both a friend and an enemy of any other person. .. code-block:: edgeql-repl db> insert Person {name := 'Quincey Morris'}; {default::Person {id: e4a55480-d2de-11ed-93bd-9f4224fc73af}} db> insert Person {name := 'Dracula'}; {default::Person {id: e7f2cff0-d2de-11ed-93bd-279780478afb}} db> update Person ... filter .name = 'Quincey Morris' ... set { ... enemies := ( ... select detached Person filter .name = 'Dracula' ... ) ... }; {default::Person {id: e4a55480-d2de-11ed-93bd-9f4224fc73af}} db> update Person ... filter .name = 'Quincey Morris' ... set { ... friends := ( ... select detached Person filter .name = 'Dracula' ... ) ... }; gel error: GelError: Invalid frenemies In the following examples, the ``size`` properties of the ``File`` objects are ``1024``, ``1024``, and ``131,072``. .. code-block:: edgeql-repl db> for obj in (select File) ... union (assert(obj.size <= 128*1024, message := 'file too big')); {true, true, true} db> for obj in (select File) ... union (assert(obj.size <= 64*1024, message := 'file too big')); gel error: QueryAssertionError: file too big You may call ``assert`` in the ``order by`` clause of your ``select`` statement. This will ensure it is called only on objects that pass your filter. .. code-block:: edgeql-repl db> select File { name, size } ... 
order by assert(.size <= 128*1024, message := "file too big"); { default::File {name: 'File 2', size: 1024}, default::File {name: 'Asdf 3', size: 1024}, default::File {name: 'File 1', size: 131072}, } db> select File { name, size } ... order by assert(.size <= 64*1024, message := "file too big"); gel error: QueryAssertionError: file too big db> select File { name, size } ... filter .size <= 64*1024 ... order by assert(.size <= 64*1024, message := "file too big"); { default::File {name: 'File 2', size: 1024}, default::File {name: 'Asdf 3', size: 1024} } ================================================ FILE: docs/reference/stdlib/bytes.rst ================================================ .. _ref_std_bytes: ===== Bytes ===== :edb-alt-title: Bytes Functions and Operators .. list-table:: :class: funcoptable * - :eql:type:`bytes` - Byte sequence * - :eql:type:`Endian` - An enum for indicating integer value encoding. * - :eql:op:`bytes[i] ` - :eql:op-desc:`bytesidx` * - :eql:op:`bytes[from:to] ` - :eql:op-desc:`bytesslice` * - :eql:op:`bytes ++ bytes ` - :eql:op-desc:`bytesplus` * - :eql:op:`= ` :eql:op:`\!= ` :eql:op:`?= ` :eql:op:`?!= ` :eql:op:`\< ` :eql:op:`\> ` :eql:op:`\<= ` :eql:op:`\>= ` - Comparison operators * - :eql:func:`len` - Returns the number of bytes. * - :eql:func:`contains` - Checks if the byte sequence contains a given subsequence. * - :eql:func:`find` - Finds the index of the first occurrence of a subsequence. 
* - :eql:func:`to_bytes` - :eql:func-desc:`to_bytes` * - :eql:func:`to_str` - :eql:func-desc:`to_str` * - :eql:func:`to_int16` - :eql:func-desc:`to_int16` * - :eql:func:`to_int32` - :eql:func-desc:`to_int32` * - :eql:func:`to_int64` - :eql:func-desc:`to_int64` * - :eql:func:`to_uuid` - :eql:func-desc:`to_uuid` * - :eql:func:`bytes_get_bit` - :eql:func-desc:`bytes_get_bit` * - :eql:func:`bit_count` - :eql:func-desc:`bit_count` * - :eql:func:`enc::base64_encode` - :eql:func-desc:`enc::base64_encode` * - :eql:func:`enc::base64_decode` - :eql:func-desc:`enc::base64_decode` ---------- .. eql:type:: std::bytes A sequence of bytes representing raw data. Bytes can be represented as a literal using this syntax: ``b''``. .. code-block:: edgeql-repl db> select b'Hello, world'; {b'Hello, world'} db> select b'Hello,\x20world\x01'; {b'Hello, world\x01'} There are also some :ref:`generic ` functions that can operate on bytes: .. code-block:: edgeql-repl db> select contains(b'qwerty', b'42'); {false} Bytes are rendered as base64-encoded strings in JSON. When you cast a ``bytes`` value into JSON, that's what you'll get. In order to :eql:op:`cast ` a :eql:type:`json` value into bytes, it must be a base64-encoded string. .. code-block:: edgeql-repl db> select <json>b'Hello Gel!'; {"\"SGVsbG8gR2VsIQ==\""} db> select <bytes>to_json("\"SGVsbG8gR2VsIQ==\""); {b'Hello Gel!'} ---------- .. eql:type:: std::Endian .. versionadded:: 5.0 An enum for indicating integer value encoding. This enum is used by the :eql:func:`to_int16`, :eql:func:`to_int32`, :eql:func:`to_int64` and the :eql:func:`to_bytes` converters working with :eql:type:`bytes` and integers. ``Endian.Big`` stands for big-endian encoding going from most significant byte to least. ``Endian.Little`` stands for little-endian encoding going from least to most significant byte. .. 
code-block:: edgeql-repl db> select to_bytes(<int32>16908295, Endian.Big); {b'\x01\x02\x00\x07'} db> select to_int32(b'\x01\x02\x00\x07', Endian.Big); {16908295} db> select to_bytes(<int32>16908295, Endian.Little); {b'\x07\x00\x02\x01'} db> select to_int32(b'\x07\x00\x02\x01', Endian.Little); {16908295} ---------- .. eql:operator:: bytesidx: bytes [ int64 ] -> bytes .. api-index:: §bytes§[§int§] Accesses a byte at a given index. Examples: .. code-block:: edgeql-repl db> select b'binary \x01\x02\x03\x04 ftw!'[2]; {b'n'} db> select b'binary \x01\x02\x03\x04 ftw!'[8]; {b'\x02'} ---------- .. eql:operator:: bytesslice: bytes [ int64 : int64 ] -> bytes .. api-index:: §bytes§[§int§:§int§] Produces a bytes sub-sequence from an existing bytes value. Examples: .. code-block:: edgeql-repl db> select b'\x01\x02\x03\x04 ftw!'[2:-1]; {b'\x03\x04 ftw'} db> select b'some bytes'[2:-3]; {b'me by'} --------- .. eql:operator:: bytesplus: bytes ++ bytes -> bytes .. index:: join, add .. api-index:: §bytes §++§ bytes§ Concatenates two bytes values into one. .. code-block:: edgeql-repl db> select b'\x01\x02' ++ b'\x03\x04'; {b'\x01\x02\x03\x04'} --------- .. TODO: Function signatures except the first need to be revealed only for v5+ .. eql:function:: std::to_bytes(s: str) -> bytes std::to_bytes(j: json) -> bytes std::to_bytes(val: int16, endian: Endian) -> bytes std::to_bytes(val: int32, endian: Endian) -> bytes std::to_bytes(val: int64, endian: Endian) -> bytes std::to_bytes(val: uuid) -> bytes .. versionadded:: 4.0 .. index:: encode, stringencoder Converts a given value into binary representation as :eql:type:`bytes`. The strings get converted using UTF-8 encoding: .. code-block:: edgeql-repl db> select to_bytes('テキスト'); {b'\xe3\x83\x86\xe3\x82\xad\xe3\x82\xb9\xe3\x83\x88'} The json values get converted as strings using UTF-8 encoding: .. 
code-block:: edgeql-repl db> select to_bytes(to_json('{"a": 1}')); {b'{"a": 1}'} The integer values can be encoded as big-endian (most significant byte comes first) byte strings: .. code-block:: edgeql-repl db> select to_bytes(<int16>31, Endian.Big); {b'\x00\x1f'} db> select to_bytes(<int32>31, Endian.Big); {b'\x00\x00\x00\x1f'} db> select to_bytes(123456789123456789, Endian.Big); {b'\x01\xb6\x9bK\xac\xd0_\x15'} .. note:: Due to underlying implementation details, using big-endian encoding results in slightly faster performance of ``to_bytes`` when converting integers. The UUID values are converted to the underlying string of 16 bytes: .. code-block:: edgeql-repl db> select to_bytes(<uuid>'1d70c86e-cc92-11ee-b4c7-a7aa0a34e2ae'); {b'\x1dp\xc8n\xcc\x92\x11\xee\xb4\xc7\xa7\xaa\n4\xe2\xae'} To perform the reverse conversion there are corresponding functions: :eql:func:`to_str`, :eql:func:`to_int16`, :eql:func:`to_int32`, :eql:func:`to_int64`, :eql:func:`to_uuid`. --------- .. eql:function:: std::bytes_get_bit(bytes: bytes, nth: int64) -> int64 Returns the specified bit of the :eql:type:`bytes` value. When looking for the *nth* bit, this function will enumerate bits from least to most significant in each byte. .. code-block:: edgeql-repl db> for n in {0, 1, 2, 3, 4, 5, 6, 7, ... 8, 9, 10, 11, 12, 13, 14, 15} ... union bytes_get_bit(b'ab', n); {1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0} --------- .. eql:function:: enc::base64_encode(b: bytes) -> str .. versionadded:: 4.0 Returns a Base64-encoded :eql:type:`str` of the :eql:type:`bytes` value. .. code-block:: edgeql-repl db> select enc::base64_encode(b'hello'); {'aGVsbG8='} --------- .. eql:function:: enc::base64_decode(s: str) -> bytes .. versionadded:: 4.0 Returns the :eql:type:`bytes` of a Base64-encoded :eql:type:`str`. Raises an ``InvalidValueError`` if the input is not valid Base64. .. 
code-block:: edgeql-repl db> select enc::base64_decode('aGVsbG8='); {b'hello'} db> select enc::base64_decode('aGVsbG8'); gel error: InvalidValueError: invalid base64 end sequence ================================================ FILE: docs/reference/stdlib/cfg.rst ================================================ .. _ref_std_cfg: ====== Config ====== The ``cfg`` module contains a set of types and scalars used for configuring |Gel|. .. list-table:: :class: funcoptable * - **Type** - **Description** * - :eql:type:`cfg::AbstractConfig` - The abstract base type for all configuration objects. The properties of this type define the set of configuration settings supported by Gel. * - :eql:type:`cfg::Config` - The main configuration object. The properties of this object reflect the overall configuration setting from instance level all the way to session level. * - :eql:type:`cfg::DatabaseConfig` - The database configuration object. It reflects all the applicable configuration at the Gel database level. * - :eql:type:`cfg::BranchConfig` - The database branch configuration object. It reflects all the applicable configuration at the Gel branch level. * - :eql:type:`cfg::InstanceConfig` - The instance configuration object. * - :eql:type:`cfg::ExtensionConfig` - The abstract base type for all extension configuration objects. Each extension can define the necessary configuration settings by extending this type and adding the extension-specific properties. * - :eql:type:`cfg::Auth` - An object type representing an authentication profile. * - :eql:type:`cfg::ConnectionTransport` - An enum type representing the different protocols that Gel speaks. * - :eql:type:`cfg::AuthMethod` - An abstract object type representing a method of authentication * - :eql:type:`cfg::Trust` - A subclass of ``AuthMethod`` indicating an "always trust" policy (no authentication). * - :eql:type:`cfg::SCRAM` - A subclass of ``AuthMethod`` indicating password-based authentication. 
* - :eql:type:`cfg::Password` - A subclass of ``AuthMethod`` indicating basic password-based authentication. * - :eql:type:`cfg::JWT` - A subclass of ``AuthMethod`` indicating token-based authentication. * - :eql:type:`cfg::memory` - A scalar type for storing a quantity of memory storage. .. eql:type:: cfg::AbstractConfig An abstract type representing the configuration of an instance or database. The properties of this object type represent the set of configuration options supported by Gel (listed above). ---------- .. eql:type:: cfg::Config The main configuration object type. This type will have only one object instance. The ``cfg::Config`` object represents the sum total of the current Gel configuration. It reflects the result of applying instance, branch, and session level configuration. Examining this object is the recommended way of determining the current configuration. Here's an example of checking and disabling :ref:`access policies `: .. code-block:: edgeql-repl db> select cfg::Config.apply_access_policies; {true} db> configure session set apply_access_policies := false; OK: CONFIGURE SESSION db> select cfg::Config.apply_access_policies; {false} ---------- .. eql:type:: cfg::BranchConfig .. versionadded:: 5.0 The branch-level configuration object type. This type will have only one object instance. The ``cfg::BranchConfig`` object represents the state of the branch and instance-level Gel configuration. For overall configuration state please refer to the :eql:type:`cfg::Config` instead. ---------- .. eql:type:: cfg::InstanceConfig The instance-level configuration object type. This type will have only one object instance. The ``cfg::InstanceConfig`` object represents the state of only instance-level Gel configuration. For overall configuration state please refer to the :eql:type:`cfg::Config` instead. ---------- .. eql:type:: cfg::ExtensionConfig .. versionadded:: 5.0 An abstract type representing extension configuration. 
Every extension is expected to define its own extension-specific config object type extending ``cfg::ExtensionConfig``. Any necessary extension configuration settings should be represented as properties of this concrete config type. Up to three instances of the extension-specific config type will be created, each of them with a ``required single link cfg`` to the :eql:type:`cfg::Config`, :eql:type:`cfg::DatabaseConfig`, or :eql:type:`cfg::InstanceConfig` object depending on the configuration level. The :eql:type:`cfg::AbstractConfig` exposes a corresponding computed multi-backlink called ``extensions``. For example, the :ref:`ext::pgvector ` extension exposes ``probes`` as a configurable parameter via the ``ext::pgvector::Config`` object: .. code-block:: edgeql-repl db> configure session ... set ext::pgvector::Config::probes := 5; OK: CONFIGURE SESSION db> select cfg::Config.extensions[is ext::pgvector::Config]{*}; { ext::pgvector::Config { id: 12b5c70f-0bb8-508a-845f-ca3d41103b6f, probes: 5, ef_search: 40, }, } ---------- .. eql:type:: cfg::Auth An object type designed to specify a client authentication profile. .. code-block:: edgeql-repl db> configure instance insert ... Auth {priority := 0, method := (insert Trust)}; OK: CONFIGURE INSTANCE Below are the properties of the ``Auth`` class. :eql:synopsis:`priority -> int64` The priority of the authentication rule. The lower this number, the higher the priority. :eql:synopsis:`user -> multi str` The name(s) of the database role(s) this rule applies to. If set to ``'*'``, then it applies to all roles. :eql:synopsis:`method -> cfg::AuthMethod` The name of the authentication method type. Expects an instance of :eql:type:`cfg::AuthMethod`; valid values are: ``Trust`` for no authentication and ``SCRAM`` for SCRAM-SHA-256 password authentication. :eql:synopsis:`comment -> optional str` An optional comment for the authentication rule. --------- .. 
eql:type:: cfg::ConnectionTransport An enum listing the various protocols that Gel can speak. Possible values are: .. list-table:: :class: funcoptable * - **Value** - **Description** * - ``cfg::ConnectionTransport.TCP`` - Gel binary protocol * - ``cfg::ConnectionTransport.TCP_PG`` - Postgres protocol for the :ref:`SQL query mode ` * - ``cfg::ConnectionTransport.HTTP`` - Gel binary protocol :ref:`tunneled over HTTP ` * - ``cfg::ConnectionTransport.SIMPLE_HTTP`` - :ref:`EdgeQL over HTTP ` and :ref:`GraphQL ` endpoints --------- .. eql:type:: cfg::AuthMethod An abstract object class that represents an authentication method. It currently has four concrete subclasses, each of which represents an available authentication method: :eql:type:`cfg::SCRAM`, :eql:type:`cfg::JWT`, :eql:type:`cfg::Password`, and :eql:type:`cfg::Trust`. :eql:synopsis:`transports -> multi cfg::ConnectionTransport` Which connection transports this method applies to. The subclasses have their own defaults for this. ------- .. eql:type:: cfg::Trust The ``cfg::Trust`` indicates an "always-trust" policy. When active, it disables password-based authentication. .. code-block:: edgeql-repl db> configure instance insert ... Auth {priority := 0, method := (insert Trust)}; OK: CONFIGURE INSTANCE ------- .. eql:type:: cfg::SCRAM ``cfg::SCRAM`` indicates password-based authentication. It uses a challenge-response scheme to avoid transmitting the password directly. This policy is implemented via ``SCRAM-SHA-256``. It is available for the ``TCP``, ``TCP_PG``, and ``HTTP`` transports and is the default for ``TCP`` and ``TCP_PG``. .. code-block:: edgeql-repl db> configure instance insert ... Auth {priority := 0, method := (insert SCRAM)}; OK: CONFIGURE INSTANCE ------- .. eql:type:: cfg::JWT ``cfg::JWT`` uses a JWT signed by the server to authenticate. It is available for the ``TCP``, ``HTTP``, and ``SIMPLE_HTTP`` transports and is the default for ``HTTP``. ------- .. 
eql:type:: cfg::Password ``cfg::Password`` indicates simple password-based authentication. Unlike :eql:type:`cfg::SCRAM`, this policy transmits the password over the (encrypted) channel. It is implemented using HTTP Basic Authentication over TLS. This policy is available only for the ``SIMPLE_HTTP`` transport, where it is the default. ------- .. eql:type:: cfg::memory A scalar type representing a quantity of memory storage. As with ``uuid``, ``datetime``, and several other types, ``cfg::memory`` values are declared by casting from an appropriately formatted string. .. code-block:: edgeql-repl db> select <cfg::memory>'1B'; # 1 byte {<cfg::memory>'1B'} db> select <cfg::memory>'5KiB'; # 5 kibibytes {<cfg::memory>'5KiB'} db> select <cfg::memory>'128MiB'; # 128 mebibytes {<cfg::memory>'128MiB'} The numerical component of the value must be a non-negative integer; the units must be one of ``B|KiB|MiB|GiB|TiB|PiB``. We're using the explicit ``KiB`` unit notation (1024 bytes) instead of ``kB`` (which is ambiguous, and may mean 1000 or 1024 bytes). ================================================ FILE: docs/reference/stdlib/constraint_table.rst ================================================ .. 
list-table:: * - :eql:constraint:`exclusive` - Enforce uniqueness (disallow duplicate values) * - :eql:constraint:`expression` - Custom constraint expression (followed by keyword ``on``) * - :eql:constraint:`one_of` - A list of allowable values * - :eql:constraint:`max_value` - Maximum value numerically/lexicographically * - :eql:constraint:`max_ex_value` - Maximum value numerically/lexicographically (exclusive range) * - :eql:constraint:`min_value` - Minimum value numerically/lexicographically * - :eql:constraint:`min_ex_value` - Minimum value numerically/lexicographically (exclusive range) * - :eql:constraint:`max_len_value` - Maximum length (``str`` only) * - :eql:constraint:`min_len_value` - Minimum length (``str`` only) * - :eql:constraint:`regexp` - Regex constraint (``str`` only) ================================================ FILE: docs/reference/stdlib/constraints.rst ================================================ .. _ref_std_constraints: =========== Constraints =========== .. include:: constraint_table.rst .. eql:constraint:: std::expression on (expr) A constraint based on an arbitrary expression returning a boolean. The ``expression`` constraint may be used as in this example to create a custom scalar type: .. code-block:: sdl scalar type StartsWithA extending str { constraint expression on (__subject__[0] = 'A'); } Example of using an ``expression`` constraint based on two object properties to restrict maximum magnitude for a vector: .. code-block:: sdl type Vector { required x: float64; required y: float64; constraint expression on ( __subject__.x^2 + __subject__.y^2 < 25 ); } .. eql:constraint:: std::one_of(variadic members: anytype) Specifies a list of allowed values. Example: .. code-block:: sdl scalar type Status extending str { constraint one_of ('Open', 'Closed', 'Merged'); } .. eql:constraint:: std::max_value(max: anytype) Specifies the maximum allowed value. Example: .. 
code-block:: sdl scalar type Max100 extending int64 { constraint max_value(100); } .. eql:constraint:: std::max_ex_value(max: anytype) Specifies a non-inclusive upper bound for the value. Example: .. code-block:: sdl scalar type Under100 extending int64 { constraint max_ex_value(100); } In this example, in contrast to the ``max_value`` constraint, a value of the ``Under100`` type cannot be ``100`` since the valid range of ``max_ex_value`` does not include the value specified in the constraint. .. eql:constraint:: std::max_len_value(max: int64) Specifies the maximum allowed length of a value. Example: .. code-block:: sdl scalar type Username extending str { constraint max_len_value(30); } .. eql:constraint:: std::min_value(min: anytype) Specifies the minimum allowed value. Example: .. code-block:: sdl scalar type NonNegative extending int64 { constraint min_value(0); } .. eql:constraint:: std::min_ex_value(min: anytype) Specifies a non-inclusive lower bound for the value. Example: .. code-block:: sdl scalar type PositiveFloat extending float64 { constraint min_ex_value(0); } In this example, in contrast to the ``min_value`` constraint, a value of the ``PositiveFloat`` type cannot be ``0`` since the valid range of ``min_ex_value`` does not include the value specified in the constraint. .. eql:constraint:: std::min_len_value(min: int64) Specifies the minimum allowed length of a value. Example: .. code-block:: sdl scalar type EmailAddress extending str { constraint min_len_value(3); } .. eql:constraint:: std::regexp(pattern: str) Limits values to strings matching a regular expression. Example: .. code-block:: sdl scalar type LettersOnly extending str { constraint regexp(r'[A-Za-z]*'); } See our documentation on :ref:`regular expression patterns ` for more information on those. .. eql:constraint:: std::exclusive Specifies that the link or property value must be exclusive (unique).
When applied to a ``multi`` link or property, the exclusivity constraint guarantees that for every object, the set of values held by a link or property does not intersect with any other such set in any other object of this type. This constraint is only valid for concrete links and properties. Scalar type definitions cannot include this constraint. This constraint has an additional effect of creating an implicit :ref:`index ` on a property. This means that there's no need to add explicit indexes for properties with this constraint. Example: .. code-block:: sdl type User { # Make sure user names are unique. required name: str { constraint exclusive; } # Already indexed, don't need to do this: # index on (.name) # Make sure none of the "owned" items belong # to any other user. multi owns: Item { constraint exclusive; } } Sometimes it may be necessary to create a type where each *combination* of properties is unique. This can be achieved by defining an ``exclusive`` constraint for the combination, rather than on each property: .. code-block:: sdl type UniqueCoordinates { required x: int64; required y: int64; # Each combination of x and y must be unique. constraint exclusive on ( (.x, .y) ); } Any possible expression can appear in the ``on ()`` clause of the ``exclusive`` constraint as long as it adheres to the following: * The expression can only contain references to the immediate properties or links of the type. * No :ref:`backlinks ` or long paths are allowed. * Only ``Immutable`` functions are allowed in the constraint expression. .. list-table:: :class: seealso * - **See also** * - :ref:`Schema > Constraints ` * - :ref:`SDL > Constraints ` * - :ref:`DDL > Constraints ` * - :ref:`Introspection > Constraints ` ================================================ FILE: docs/reference/stdlib/datetime.rst ================================================ .. 
_ref_std_datetime: =============== Dates and Times =============== :edb-alt-title: Types, Functions, and Operators for Dates and Times .. list-table:: :class: funcoptable * - :eql:type:`datetime` - Timezone-aware point in time * - :eql:type:`duration` - Absolute time span * - :eql:type:`cal::local_datetime` - Date and time w/o timezone * - :eql:type:`cal::local_date` - Date type * - :eql:type:`cal::local_time` - Time type * - :eql:type:`cal::relative_duration` - Relative time span * - :eql:type:`cal::date_duration` - Relative time span in days * - :eql:op:`dt + dt ` - :eql:op-desc:`dtplus` * - :eql:op:`dt - dt ` - :eql:op-desc:`dtminus` * - :eql:op:`= ` :eql:op:`\!= ` :eql:op:`?= ` :eql:op:`?!= ` :eql:op:`\< ` :eql:op:`\> ` :eql:op:`\<= ` :eql:op:`\>= ` - Comparison operators * - :eql:func:`to_str` - Render a date/time value to a string. * - :eql:func:`to_datetime` - :eql:func-desc:`to_datetime` * - :eql:func:`cal::to_local_datetime` - :eql:func-desc:`cal::to_local_datetime` * - :eql:func:`cal::to_local_date` - :eql:func-desc:`cal::to_local_date` * - :eql:func:`cal::to_local_time` - :eql:func-desc:`cal::to_local_time` * - :eql:func:`to_duration` - :eql:func-desc:`to_duration` * - :eql:func:`cal::to_relative_duration` - :eql:func-desc:`cal::to_relative_duration` * - :eql:func:`cal::to_date_duration` - :eql:func-desc:`cal::to_date_duration` * - :eql:func:`datetime_get` - :eql:func-desc:`datetime_get` * - :eql:func:`cal::time_get` - :eql:func-desc:`cal::time_get` * - :eql:func:`cal::date_get` - :eql:func-desc:`cal::date_get` * - :eql:func:`duration_get` - :eql:func-desc:`duration_get` * - :eql:func:`datetime_truncate` - :eql:func-desc:`datetime_truncate` * - :eql:func:`duration_truncate` - :eql:func-desc:`duration_truncate` * - :eql:func:`datetime_current` - :eql:func-desc:`datetime_current` * - :eql:func:`datetime_of_transaction` - :eql:func-desc:`datetime_of_transaction` * - :eql:func:`datetime_of_statement` - :eql:func-desc:`datetime_of_statement` * - 
:eql:func:`cal::duration_normalize_hours` - :eql:func-desc:`cal::duration_normalize_hours` * - :eql:func:`cal::duration_normalize_days` - :eql:func-desc:`cal::duration_normalize_days` .. _ref_std_datetime_intro: |Gel| offers two ways of representing date/time values: * a timezone-aware :eql:type:`std::datetime` type; * a set of "local" date/time types, not attached to any particular timezone: :eql:type:`cal::local_datetime`, :eql:type:`cal::local_date`, and :eql:type:`cal::local_time`. There are also two different ways of measuring duration: * :eql:type:`duration` for using absolute and unambiguous units; * :eql:type:`cal::relative_duration` for using fuzzy units like years, months and days in addition to the absolute units. All related operators, functions, and type casts are designed to maintain a strict separation between timezone-aware and "local" date/time values. |Gel| stores and outputs timezone-aware values in UTC format. .. note:: All date/time types are restricted to years between 1 and 9999, including the years 1 and 9999. Although many systems support ISO 8601 date/time formatting in theory, in practice the formatting before year 1 and after 9999 tends to be inconsistent. As such, dates outside this range are not reliably portable. .. _ref_std_datetime_timezones: Timezones --------- For timezone string literals, you may specify timezones in one of two ways: * IANA (Olson) timezone database name (e.g. ``America/New_York``) * A time zone abbreviation (e.g. ``EDT`` for Eastern Daylight Time) See the `relevant section from the PostgreSQL documentation `_ for more detail about how time zones affect the behavior of date/time functionality. .. note:: The IANA timezone database is maintained by Paul Eggert for the IANA. You can find a `GitHub repository with the latest timezone data here `_, and the `list of timezone names here `_. ---------- .. eql:type:: std::datetime Represents a timezone-aware moment in time. 
All dates must correspond to dates that exist in the proleptic Gregorian calendar. :eql:op:`Casting ` is a simple way to obtain a :eql:type:`datetime` value in an expression: .. code-block:: edgeql select '2018-05-07T15:01:22.306916+00'; select '2018-05-07T15:01:22+00'; When casting ``datetime`` from strings, the string must follow the ISO 8601 format with a timezone included. .. code-block:: edgeql-repl db> select 'January 01 2019 UTC'; InvalidValueError: invalid input syntax for type std::datetime: 'January 01 2019 UTC' Hint: Please use ISO8601 format. Alternatively "to_datetime" function provides custom formatting options. db> select '2019-01-01T15:01:22'; InvalidValueError: invalid input syntax for type std::datetime: '2019-01-01T15:01:22' Hint: Please use ISO8601 format. Alternatively "to_datetime" function provides custom formatting options. All ``datetime`` values are restricted to the range from year 1 to 9999. For more information regarding interacting with this type, see :eql:func:`datetime_get`, :eql:func:`to_datetime`, and :eql:func:`to_str`. ---------- .. eql:type:: cal::local_datetime A type for representing a date and time without a timezone. :eql:op:`Casting ` is a simple way to obtain a :eql:type:`cal::local_datetime` value in an expression: .. code-block:: edgeql select '2018-05-07T15:01:22.306916'; select '2018-05-07T15:01:22'; When casting ``cal::local_datetime`` from strings, the string must follow the ISO 8601 format without timezone: .. code-block:: edgeql-repl db> select '2019-01-01T15:01:22+00'; InvalidValueError: invalid input syntax for type cal::local_datetime: '2019-01-01T15:01:22+00' Hint: Please use ISO8601 format. Alternatively "cal::to_local_datetime" function provides custom formatting options. db> select 'January 01 2019'; InvalidValueError: invalid input syntax for type cal::local_datetime: 'January 01 2019' Hint: Please use ISO8601 format. Alternatively "cal::to_local_datetime" function provides custom formatting options. 
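As a rough analogy outside of Gel, Python's standard library draws the same line between ISO 8601 strings that carry a UTC offset and those that do not. The sketch below is only an illustration: ``classify_iso`` is a hypothetical helper, it relies on Python's ``datetime.fromisoformat`` rather than on Gel's parser, and the two parsers do not accept exactly the same set of inputs.

```python
from datetime import datetime


def classify_iso(text: str) -> str:
    """Name the Gel type whose cast would suit an ISO 8601 string.

    Hypothetical helper: it uses Python's parser, not Gel's.
    """
    parsed = datetime.fromisoformat(text)  # raises ValueError on bad input
    if parsed.tzinfo is None:
        # No UTC offset in the string: a "local" value.
        return "cal::local_datetime"
    # An explicit offset is present: a timezone-aware value.
    return "datetime"


print(classify_iso("2019-01-01T15:01:22"))        # no offset
print(classify_iso("2019-01-01T15:01:22+00:00"))  # explicit offset
```

A free-form string such as ``'January 01 2019'`` makes ``fromisoformat`` raise ``ValueError``, which loosely mirrors the ``InvalidValueError`` shown above.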
All ``cal::local_datetime`` values are restricted to the range from year 1 to 9999. For more information regarding interacting with this type, see :eql:func:`datetime_get`, :eql:func:`cal::to_local_datetime`, and :eql:func:`to_str`. ---------- .. eql:type:: cal::local_date A type for representing a date without a timezone. :eql:op:`Casting ` is a simple way to obtain a :eql:type:`cal::local_date` value in an expression: .. code-block:: edgeql select <cal::local_date>'2018-05-07'; When casting ``cal::local_date`` from strings, the string must follow the ISO 8601 date format. For more information regarding interacting with this type, see :eql:func:`cal::date_get`, :eql:func:`cal::to_local_date`, and :eql:func:`to_str`. ---------- .. eql:type:: cal::local_time A type for representing a time without a timezone. :eql:op:`Casting ` is a simple way to obtain a :eql:type:`cal::local_time` value in an expression: .. code-block:: edgeql select <cal::local_time>'15:01:22.306916'; select <cal::local_time>'15:01:22'; When casting ``cal::local_time`` from strings, the string must follow the ISO 8601 time format. For more information regarding interacting with this type, see :eql:func:`cal::time_get`, :eql:func:`cal::to_local_time`, and :eql:func:`to_str`. ---------- .. _ref_datetime_duration: .. eql:type:: std::duration A type for representing a span of time. A :eql:type:`duration` is a fixed number of seconds and microseconds and isn't adjusted by timezone, length of month, or anything else in datetime calculations. When converting from a string, only units of ``'microseconds'``, ``'milliseconds'``, ``'seconds'``, ``'minutes'``, and ``'hours'`` are valid: .. code-block:: edgeql-repl db> select <duration>'45.6 seconds'; {'0:00:45.6'} db> select <duration>'15 milliseconds'; {'0:00:00.015'} db> select <duration>'48 hours 45 minutes'; {'48:45:00'} db> select <duration>'11 months'; gel error: InvalidValueError: invalid input syntax for type std::duration: '11 months' Hint: Units bigger than hours cannot be used for std::duration.
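The "fixed number of seconds and microseconds" rule can be sketched outside of Gel with Python's ``timedelta``, which is likewise an absolute span. This is an illustration under stated assumptions, not Gel's parser: ``parse_duration`` and ``_UNITS`` are hypothetical names, and the regular expression accepts only simple ``<number> <unit>`` pairs.

```python
import re
from datetime import timedelta

# Units a std::duration string may use, per the text above.
_UNITS = {
    "microsecond": timedelta(microseconds=1),
    "millisecond": timedelta(milliseconds=1),
    "second": timedelta(seconds=1),
    "minute": timedelta(minutes=1),
    "hour": timedelta(hours=1),
}


def parse_duration(text: str) -> timedelta:
    """Hypothetical helper: sum '<number> <unit>' pairs into a timedelta."""
    pairs = re.findall(r"(-?\d+(?:\.\d+)?)\s*([a-z]+)", text.lower())
    if not pairs:
        raise ValueError(f"no '<number> <unit>' pairs in {text!r}")
    total = timedelta()
    for number, unit in pairs:
        # Accept both singular and plural spellings of a unit.
        key = unit[:-1] if unit.endswith("s") else unit
        if key not in _UNITS:
            # Days, months, years, ... have no fixed length in seconds.
            raise ValueError(f"unit {unit!r} is not valid for a fixed duration")
        total += float(number) * _UNITS[key]
    return total


print(parse_duration("48 hours 45 minutes"))  # 2 days, 0:45:00
```

Rejecting ``'11 months'`` here follows the same reasoning as the hint above: a month has no single length in seconds, so it cannot be part of an absolute duration.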
All date/time types support the ``+`` and ``-`` arithmetic operations with durations: .. code-block:: edgeql-repl db> select '2019-01-01T00:00:00Z' - '24 hours'; {'2018-12-31T00:00:00+00:00'} db> select '22:00' + '1 hour'; {'23:00:00'} For more information regarding interacting with this type, see :eql:func:`to_duration`, and :eql:func:`to_str` and date/time :eql:op:`operators `. ---------- .. eql:type:: cal::relative_duration A type for representing a relative span of time. Unlike :eql:type:`std::duration`, ``cal::relative_duration`` is an imprecise form of measurement. When months and days are used, the same relative duration could have a different absolute duration depending on the date you're measuring from. For example 2020 was a leap year and had 366 days. Notice how the number of hours in each year below is different: .. code-block:: edgeql-repl db> with ... first_day_of_2020 := '2020-01-01T00:00:00Z', ... one_year := '1 year', ... first_day_of_next_year := first_day_of_2020 + one_year ... select first_day_of_next_year - first_day_of_2020; {'8784:00:00'} db> with ... first_day_of_2019 := '2019-01-01T00:00:00Z', ... one_year := '1 year', ... first_day_of_next_year := first_day_of_2019 + one_year ... select first_day_of_next_year - first_day_of_2019; {'8760:00:00'} When converting from a string, only the following units are valid: - ``'microseconds'`` - ``'milliseconds'`` - ``'seconds'`` - ``'minutes'`` - ``'hours'`` - ``'days'`` - ``'weeks'`` - ``'months'`` - ``'years'`` - ``'decades'`` - ``'centuries'`` - ``'millennia'`` Examples of units usage: .. code-block:: edgeql select '45.6 seconds'; select '15 milliseconds'; select '3 weeks 45 minutes'; select '-7 millennia'; All date/time types support the ``+`` and ``-`` arithmetic operations with ``relative_duration``: .. code-block:: edgeql-repl db> select '2019-01-01T00:00:00Z' - ... '3 years'; {'2016-01-01T00:00:00+00:00'} db> select '22:00' + ... 
'1 hour'; {'23:00:00'} If an arithmetic operation results in a day that doesn't exist in the given month, the last day of the month will be used instead: .. code-block:: edgeql-repl db> select "2021-01-31T15:00:00" + ... "1 month"; {'2021-02-28T15:00:00'} For arithmetic operations involving a ``cal::relative_duration`` consisting of multiple components (units), higher-order components are applied first followed by lower-order components. .. code-block:: edgeql-repl db> select "2021-04-30T15:00:00" + ... "1 month 1 day"; {'2021-05-31T15:00:00'} If you add the same components split into separate durations, adding the higher-order units first followed by the lower-order units, the calculation produces the same result as in the previous example: .. code-block:: edgeql-repl db> select "2021-04-30T15:00:00" + ... "1 month" + ... "1 day"; {'2021-05-31T15:00:00'} When the order of operations is reversed, the result may be different for some corner cases: .. code-block:: edgeql-repl db> select "2021-04-30T15:00:00" + ... "1 day" + ... "1 month"; {'2021-06-01T15:00:00'} .. rubric:: Gotchas Due to the implementation of ``relative_duration`` logic, arithmetic operations may behave counterintuitively. **Non-associative** .. code-block:: edgeql-repl db> select '2021-01-31T00:00:00' + ... '1 month' + ... '1 month'; {'2021-03-28T00:00:00'} db> select '2021-01-31T00:00:00' + ... ('1 month' + ... '1 month'); {'2021-03-31T00:00:00'} **Lossy** .. code-block:: edgeql-repl db> with m := '1 month' ... select '2021-01-31' + m ... = ... '2021-01-30' + m; {true} **Asymmetric** .. code-block:: edgeql-repl db> with m := '1 month' ... select '2021-01-31' + m - m; {'2021-01-28'} **Non-monotonic** .. code-block:: edgeql-repl db> with m := '1 month' ... select '2021-01-31T01:00:00' + m ... < ... '2021-01-30T23:00:00' + m; {true} db> with m := '2 month' ... select '2021-01-31T01:00:00' + m ... < ... 
'2021-01-30T23:00:00' + m; {false} For more information regarding interacting with this type, see :eql:func:`cal::to_relative_duration`, and :eql:func:`to_str` and date/time :eql:op:`operators `. ---------- .. eql:type:: cal::date_duration A type for representing a span of time in days. This type is similar to :eql:type:`cal::relative_duration`, except it only uses 2 units: months and days. It is the result of subtracting one :eql:type:`cal::local_date` from another. The purpose of this type is to allow performing ``+`` and ``-`` operations on a :eql:type:`cal::local_date` and to produce a :eql:type:`cal::local_date` as the result: .. code-block:: edgeql-repl db> select '2022-06-30' - ... '2022-06-25'; {'P5D'} db> select '2022-06-25' + ... '5 days'; {'2022-06-30'} db> select '2022-06-25' - ... '5 days'; {'2022-06-20'} When converting from a string, only the following units are valid: - ``'days'``, - ``'weeks'``, - ``'months'``, - ``'years'``, - ``'decades'``, - ``'centuries'``, - ``'millennia'``. .. code-block:: edgeql select '45 days'; select '3 weeks 5 days'; select '-7 millennia'; In most cases, ``date_duration`` is fully compatible with :eql:type:`cal::relative_duration` and shares the same general behavior and caveats. Gel will apply type coercion in the event it expects a :eql:type:`cal::relative_duration` and finds a ``cal::date_duration`` instead. For more information regarding interacting with this type, see :eql:func:`cal::to_date_duration` and date/time :eql:op:`operators `. ---------- .. 
eql:operator:: dtplus: datetime + duration -> datetime datetime + cal::relative_duration \ -> datetime duration + duration -> duration duration + cal::relative_duration \ -> cal::relative_duration cal::relative_duration + cal::relative_duration \ -> cal::relative_duration cal::local_datetime + cal::relative_duration \ -> cal::local_datetime cal::local_datetime + duration \ -> cal::local_datetime cal::local_time + cal::relative_duration \ -> cal::local_time cal::local_time + duration -> cal::local_time cal::local_date + cal::date_duration \ -> cal::local_date cal::date_duration + cal::date_duration \ -> cal::date_duration cal::local_date + cal::relative_duration \ -> cal::local_datetime cal::local_date + duration -> cal::local_datetime .. api-index:: §datetime | duration §+§ datetime | duration§ Adds a duration to a date/time value or to another duration. This operator is commutative. .. code-block:: edgeql-repl db> select <cal::local_time>'22:00' + <duration>'1 hour'; {'23:00:00'} db> select <duration>'1 hour' + <cal::local_time>'22:00'; {'23:00:00'} db> select <duration>'1 hour' + <duration>'2 hours'; {10800s} ---------- .. 
eql:operator:: dtminus: duration - duration -> duration datetime - datetime -> duration datetime - duration -> datetime datetime - cal::relative_duration -> datetime cal::relative_duration - cal::relative_duration \ -> cal::relative_duration cal::local_datetime - cal::local_datetime \ -> cal::relative_duration cal::local_datetime - cal::relative_duration \ -> cal::local_datetime cal::local_datetime - duration \ -> cal::local_datetime cal::local_time - cal::local_time \ -> cal::relative_duration cal::local_time - cal::relative_duration \ -> cal::local_time cal::local_time - duration -> cal::local_time cal::date_duration - cal::date_duration \ -> cal::date_duration cal::local_date - cal::local_date \ -> cal::date_duration cal::local_date - cal::date_duration \ -> cal::local_date cal::local_date - cal::relative_duration \ -> cal::local_datetime cal::local_date - duration -> cal::local_datetime duration - cal::relative_duration \ -> cal::relative_duration cal::relative_duration - duration\ -> cal::relative_duration .. api-index:: §datetime | duration §-§ datetime | duration§ Subtracts two compatible datetime or duration values. .. code-block:: edgeql-repl db> select '2019-01-01T01:02:03+00' - ... '24 hours'; {'2018-12-31T01:02:03Z'} db> select '2019-01-01T01:02:03+00' - ... '2019-02-01T01:02:03+00'; {-2678400s} db> select '1 hour' - ... '2 hours'; {-3600s} When subtracting a :eql:type:`cal::local_date` type from another, the result is given as a whole number of days using the :eql:type:`cal::date_duration` type: .. code-block:: edgeql-repl db> select '2022-06-25' - ... '2019-02-01'; {'P1240D'} .. note:: Subtraction doesn't make sense for some type combinations. You couldn't subtract a point in time from a duration, so neither can Gel (although the inverse — subtracting a duration from a point in time — is perfectly fine). You also couldn't subtract a timezone-aware datetime from a local one or vice versa. 
If you attempt any of these, Gel will raise an exception as shown in these examples. When subtracting a date/time object from a time interval, an exception will be raised: .. code-block:: edgeql-repl db> select '1 day' - ... '2019-01-01T01:02:03+00'; QueryError: operator '-' cannot be applied to operands ... An exception will also be raised when trying to subtract a timezone-aware :eql:type:`std::datetime` type from :eql:type:`cal::local_datetime` or vice versa: .. code-block:: edgeql-repl db> select '2019-01-01T01:02:03+00' - ... '2019-02-01T01:02:03'; QueryError: operator '-' cannot be applied to operands... db> select '2019-02-01T01:02:03' - ... '2019-01-01T01:02:03+00'; QueryError: operator '-' cannot be applied to operands... ---------- .. eql:function:: std::datetime_current() -> datetime .. index:: now Returns the server's current date and time. .. code-block:: edgeql-repl db> select datetime_current(); {'2018-05-14T20:07:11.755827Z'} This function is volatile since it always returns the current time when it is called. As a result, it cannot be used in :ref:`computed properties defined in schema `. This does *not* apply to computed properties outside of schema. ---------- .. eql:function:: std::datetime_of_transaction() -> datetime .. index:: now Returns the date and time of the start of the current transaction. This function is non-volatile since it returns the current time when the transaction is started, not when the function is called. As a result, it can be used in :ref:`computed properties ` defined in schema. ---------- .. eql:function:: std::datetime_of_statement() -> datetime .. index:: now Returns the date and time of the start of the current statement. This function is non-volatile since it returns the current time when the statement is started, not when the function is called. As a result, it can be used in :ref:`computed properties ` defined in schema. ---------- .. 
eql:function:: std::datetime_get(dt: datetime, el: str) -> float64 std::datetime_get(dt: cal::local_datetime, \ el: str) -> float64 Returns the element of a date/time given a unit name. You may pass any of these unit names for *el*: - ``'epochseconds'`` - the number of seconds since 1970-01-01 00:00:00 UTC (Unix epoch) for :eql:type:`datetime` or local time for :eql:type:`cal::local_datetime`. It can be negative. - ``'century'`` - the century according to the Gregorian calendar - ``'day'`` - the day of the month (1-31) - ``'decade'`` - the decade (year divided by 10 and rounded down) - ``'dow'`` - the day of the week from Sunday (0) to Saturday (6) - ``'doy'`` - the day of the year (1-366) - ``'hour'`` - the hour (0-23) - ``'isodow'`` - the ISO day of the week from Monday (1) to Sunday (7) - ``'isoyear'`` - the ISO 8601 week-numbering year that the date falls in. See the ``'week'`` element for more details. - ``'microseconds'`` - the seconds including fractional value expressed as microseconds - ``'millennium'`` - the millennium. The third millennium started on Jan 1, 2001. - ``'milliseconds'`` - the seconds including fractional value expressed as milliseconds - ``'minutes'`` - the minutes (0-59) - ``'month'`` - the month of the year (1-12) - ``'quarter'`` - the quarter of the year (1-4) - ``'seconds'`` - the seconds, including fractional value from 0 up to and not including 60 - ``'week'`` - the number of the ISO 8601 week-numbering week of the year. ISO weeks are defined to start on Mondays and the first week of a year must contain Jan 4 of that year. - ``'year'`` - the year .. code-block:: edgeql-repl db> select datetime_get( ... '2018-05-07T15:01:22.306916+00', ... 'epochseconds'); {1525705282.306916} db> select datetime_get( ... '2018-05-07T15:01:22.306916+00', ... 'year'); {2018} db> select datetime_get( ... '2018-05-07T15:01:22.306916+00', ... 'quarter'); {2} db> select datetime_get( ... '2018-05-07T15:01:22.306916+00', ... 
'doy'); {127} db> select datetime_get( ... <datetime>'2018-05-07T15:01:22.306916+00', ... 'hour'); {15} ---------- .. eql:function:: cal::time_get(dt: cal::local_time, el: str) -> float64 Returns the element of a time value given a unit name. You may pass any of these unit names for *el*: - ``'midnightseconds'`` - ``'hour'`` - ``'microseconds'`` - ``'milliseconds'`` - ``'minutes'`` - ``'seconds'`` For a full description of what these elements extract, see :eql:func:`datetime_get`. .. code-block:: edgeql-repl db> select cal::time_get( ... <cal::local_time>'15:01:22.306916', 'minutes'); {1} db> select cal::time_get( ... <cal::local_time>'15:01:22.306916', 'milliseconds'); {22306.916} ---------- .. eql:function:: cal::date_get(dt: local_date, el: str) -> float64 Returns the element of a date given a unit name. The :eql:type:`cal::local_date` scalar has the following elements available for extraction: - ``'century'`` - the century according to the Gregorian calendar - ``'day'`` - the day of the month (1-31) - ``'decade'`` - the decade (year divided by 10 and rounded down) - ``'dow'`` - the day of the week from Sunday (0) to Saturday (6) - ``'doy'`` - the day of the year (1-366) - ``'isodow'`` - the ISO day of the week from Monday (1) to Sunday (7) - ``'isoyear'`` - the ISO 8601 week-numbering year that the date falls in. See the ``'week'`` element for more details. - ``'millennium'`` - the millennium. The third millennium started on Jan 1, 2001. - ``'month'`` - the month of the year (1-12) - ``'quarter'`` - the quarter of the year (1-4) - ``'week'`` - the number of the ISO 8601 week-numbering week of the year. ISO weeks are defined to start on Mondays and the first week of a year must contain Jan 4 of that year. - ``'year'`` - the year .. code-block:: edgeql-repl db> select cal::date_get( ... <cal::local_date>'2018-05-07', 'century'); {21} db> select cal::date_get( ... <cal::local_date>'2018-05-07', 'year'); {2018} db> select cal::date_get( ... <cal::local_date>'2018-05-07', 'month'); {5} db> select cal::date_get( ... 
<cal::local_date>'2018-05-07', 'doy'); {127} ---------- .. eql:function:: std::duration_get(dt: duration, el: str) -> float64 std::duration_get(dt: cal::relative_duration, \ el: str) -> float64 std::duration_get(dt: cal::date_duration, \ el: str) -> float64 Returns the element of a duration given a unit name. You may pass any of these unit names as ``el``: - ``'millennium'`` - number of 1000-year chunks rounded down - ``'century'`` - number of centuries rounded down - ``'decade'`` - number of decades rounded down - ``'year'`` - number of years rounded down - ``'quarter'`` - remaining quarters after whole years are accounted for - ``'month'`` - number of months left over after whole years are accounted for - ``'day'`` - number of days recorded in the duration - ``'hour'`` - number of hours - ``'minutes'`` - remaining minutes after whole hours are accounted for - ``'seconds'`` - remaining seconds, including fractional value after whole minutes are accounted for - ``'milliseconds'`` - remaining seconds including fractional value expressed as milliseconds - ``'microseconds'`` - remaining seconds including fractional value expressed as microseconds .. note:: Only for units ``'month'`` or larger or for units ``'hour'`` or smaller will you receive a total across multiple units expressed in the original duration. See *Gotchas* below for details. Additionally, it's possible to convert a given duration into seconds: - ``'totalseconds'`` - the number of seconds represented by the duration. It will be approximate for :eql:type:`cal::relative_duration` and :eql:type:`cal::date_duration` for units ``'month'`` or larger because a month is assumed to be 30 days exactly. The :eql:type:`duration` scalar has only ``'hour'`` and smaller units available for extraction. The :eql:type:`cal::relative_duration` scalar has all of the units available for extraction. The :eql:type:`cal::date_duration` scalar only has ``'day'`` and larger units available for extraction. .. 
code-block:: edgeql-repl db> select duration_get( ... '400 months', 'year'); {33} db> select duration_get( ... '400 months', 'month'); {4} db> select duration_get( ... '1 month 20 days 30 hours', ... 'day'); {20} db> select duration_get( ... '30 hours', 'hour'); {30} db> select duration_get( ... '1 month 20 days 30 hours', ... 'hour'); {30} db> select duration_get('30 hours', 'hour'); {30} db> select duration_get( ... '1 month 20 days 30 hours', ... 'totalseconds'); {4428000} db> select duration_get( ... '30 hours', 'totalseconds'); {108000} .. rubric:: Gotchas This function will provide you with a calculated total for the unit passed as ``el``, but only within the given "size class" of the unit. These size classes exist because they are logical breakpoints that we can't reliably convert values across. A month might be 30 days long, or it might be 28 or 29 or 31. A day is generally 24 hours, but with daylight savings, it might be longer or shorter. As a result, it's impossible to convert across these lines in a way that works in every situation. For some use cases, assuming a 30 day month works fine. For others, it might not. The size classes are as follows: - ``'month'`` and larger - ``'day'`` - ``'hour'`` and smaller For example, if you specify ``'day'`` as your ``el`` argument, the function will return only the number of days expressed as ``N days`` in your duration. It will not add another day to the returned count for every 24 hours (defined as ``24 hours``) in the duration, nor will it consider the months' constituent day counts in the returned value. Specifying ``'decade'`` for ``el`` will total up all decades represented in units ``'month'`` and larger, but it will not add a decade's worth of days to the returned value as an additional decade. In this example, the duration represents more than a day's time, but since ``'day'`` and ``'hour'`` are in different size classes, the extra day stemming from the duration's hours is not added. .. 
code-block:: edgeql-repl db> select duration_get( ... '1 day 36 hours', 'day'); {1} In this counterexample, both the decades and months are pooled together since they are in the same size class. The return value is 5: the 2 ``'decades'`` and the 3 decades in ``'400 months'``. .. code-block:: edgeql-repl db> select duration_get( ... '2 decades 400 months', 'decade'); {5} If a unit from a smaller size class would contribute to your desired unit's total, it is not added. .. code-block:: edgeql-repl db> select duration_get( ... '1 year 400 days', 'year'); {1} When you request a unit in the smallest size class, it will be pooled with other durations in the same size class. .. code-block:: edgeql-repl db> select duration_get( ... '20 hours 3600 seconds', 'hour'); {21} Seconds and smaller units always return remaining time in that unit after accounting for the next larger unit. .. code-block:: edgeql-repl db> select duration_get( ... '20 hours 3600 seconds', 'seconds'); {0} db> select duration_get( ... '20 hours 3630 seconds', 'seconds'); {30} Normalization and truncation may help you deal with this. If your use case allows for making assumptions about the duration of a month or a day, you can make those conversions for yourself using the :eql:func:`cal::duration_normalize_hours` or :eql:func:`cal::duration_normalize_days` functions. If you got back a duration as a result of a datetime calculation and don't need the level of granularity you have, you can truncate the value with :eql:func:`duration_truncate`. ---------- .. eql:function:: std::datetime_truncate(dt: datetime, unit: str) -> datetime Truncates the input datetime to a particular precision. The valid units in order of decreasing precision are: - ``'microseconds'`` - ``'milliseconds'`` - ``'seconds'`` - ``'minutes'`` - ``'hours'`` - ``'days'`` - ``'weeks'`` - ``'months'`` - ``'quarters'`` - ``'years'`` - ``'decades'`` - ``'centuries'`` .. code-block:: edgeql-repl db> select datetime_truncate( ... 
'2018-05-07T15:01:22.306916+00', 'years'); {'2018-01-01T00:00:00Z'} db> select datetime_truncate( ... '2018-05-07T15:01:22.306916+00', 'quarters'); {'2018-04-01T00:00:00Z'} db> select datetime_truncate( ... '2018-05-07T15:01:22.306916+00', 'days'); {'2018-05-07T00:00:00Z'} db> select datetime_truncate( ... '2018-05-07T15:01:22.306916+00', 'hours'); {'2018-05-07T15:00:00Z'} ---------- .. eql:function:: std::duration_truncate(dt: duration, unit: str) -> duration std::duration_truncate(dt: cal::relative_duration, \ unit: str) -> cal::relative_duration Truncates the input duration to a particular precision. The valid units for :eql:type:`duration` are: - ``'microseconds'`` - ``'milliseconds'`` - ``'seconds'`` - ``'minutes'`` - ``'hours'`` In addition to the above the following are also valid for :eql:type:`cal::relative_duration`: - ``'days'`` - ``'weeks'`` - ``'months'`` - ``'years'`` - ``'decades'`` - ``'centuries'`` .. code-block:: edgeql-repl db> select duration_truncate( ... '15:01:22', 'hours'); {'15:00:00'} db> select duration_truncate( ... '15:01:22.306916', 'minutes'); {'15:01:00'} db> select duration_truncate( ... '400 months', 'years'); {'P33Y'} db> select duration_truncate( ... '400 months', 'decades'); {'P30Y'} ---------- .. eql:function:: std::to_datetime(s: str, fmt: optional str={}) -> datetime std::to_datetime(local: cal::local_datetime, zone: str) \ -> datetime std::to_datetime(year: int64, month: int64, day: int64, \ hour: int64, min: int64, sec: float64, zone: str) \ -> datetime std::to_datetime(epochseconds: decimal) -> datetime std::to_datetime(epochseconds: float64) -> datetime std::to_datetime(epochseconds: int64) -> datetime .. index:: parse datetime Create a :eql:type:`datetime` value. The :eql:type:`datetime` value can be parsed from the input :eql:type:`str` *s*. By default, the input is expected to conform to ISO 8601 format. However, the optional argument *fmt* can be used to override the :ref:`input format ` to other forms. .. 
code-block:: edgeql-repl db> select to_datetime('2018-05-07T15:01:22.306916+00'); {'2018-05-07T15:01:22.306916Z'} db> select to_datetime('2018-05-07T15:01:22+00'); {'2018-05-07T15:01:22Z'} db> select to_datetime('May 7th, 2018 15:01:22 +00', ... 'Mon DDth, YYYY HH24:MI:SS TZH'); {'2018-05-07T15:01:22Z'} Alternatively, the :eql:type:`datetime` value can be constructed from a :eql:type:`cal::local_datetime` value: .. code-block:: edgeql-repl db> select to_datetime( ... '2019-01-01T01:02:03', 'HKT'); {'2018-12-31T17:02:03Z'} Another way to construct a :eql:type:`datetime` value is to specify it in terms of its component parts: year, month, day, hour, min, sec, and :ref:`zone `. .. code-block:: edgeql-repl db> select to_datetime( ... 2018, 5, 7, 15, 1, 22.306916, 'UTC'); {'2018-05-07T15:01:22.306916000Z'} Finally, it is also possible to convert a Unix timestamp to a :eql:type:`datetime`: .. code-block:: edgeql-repl db> select to_datetime(1590595184.584); {'2020-05-27T15:59:44.584000000Z'} ------------ .. eql:function:: cal::to_local_datetime(s: str, fmt: optional str={}) \ -> local_datetime cal::to_local_datetime(dt: datetime, zone: str) \ -> local_datetime cal::to_local_datetime(year: int64, month: int64, \ day: int64, hour: int64, min: int64, sec: float64) \ -> local_datetime .. index:: parse local_datetime Create a :eql:type:`cal::local_datetime` value. Similar to :eql:func:`to_datetime`, the :eql:type:`cal::local_datetime` value can be parsed from the input :eql:type:`str` *s* with an optional *fmt* argument or it can be given in terms of its component parts: *year*, *month*, *day*, *hour*, *min*, *sec*. For more details on formatting see :ref:`here `. .. code-block:: edgeql-repl db> select cal::to_local_datetime('2018-05-07T15:01:22.306916'); {'2018-05-07T15:01:22.306916'} db> select cal::to_local_datetime('May 7th, 2018 15:01:22', ... 'Mon DDth, YYYY HH24:MI:SS'); {'2018-05-07T15:01:22'} db> select cal::to_local_datetime( ...
2018, 5, 7, 15, 1, 22.306916); {'2018-05-07T15:01:22.306916'} A timezone-aware :eql:type:`datetime` type can be converted to local datetime in the specified :ref:`timezone `: .. code-block:: edgeql-repl db> select cal::to_local_datetime( ... '2018-12-31T22:00:00+08', ... 'America/Chicago'); {'2018-12-31T08:00:00'} db> select cal::to_local_datetime( ... '2018-12-31T22:00:00+08', ... 'CST'); {'2018-12-31T08:00:00'} ------------ .. eql:function:: cal::to_local_date(s: str, fmt: optional str={}) \ -> cal::local_date cal::to_local_date(dt: datetime, zone: str) \ -> cal::local_date cal::to_local_date(year: int64, month: int64, \ day: int64) -> cal::local_date .. index:: parse local_date Create a :eql:type:`cal::local_date` value. Similar to :eql:func:`to_datetime`, the :eql:type:`cal::local_date` value can be parsed from the input :eql:type:`str` *s* with an optional *fmt* argument or it can be given in terms of its component parts: *year*, *month*, *day*. For more details on formatting see :ref:`here `. .. code-block:: edgeql-repl db> select cal::to_local_date('2018-05-07'); {'2018-05-07'} db> select cal::to_local_date('May 7th, 2018', 'Mon DDth, YYYY'); {'2018-05-07'} db> select cal::to_local_date(2018, 5, 7); {'2018-05-07'} A timezone-aware :eql:type:`datetime` type can be converted to local date in the specified :ref:`timezone `: .. code-block:: edgeql-repl db> select cal::to_local_date( ... '2018-12-31T22:00:00+08', ... 'America/Chicago'); {'2019-01-01'} ------------ .. eql:function:: cal::to_local_time(s: str, fmt: optional str={}) \ -> local_time cal::to_local_time(dt: datetime, zone: str) \ -> local_time cal::to_local_time(hour: int64, min: int64, sec: float64) \ -> local_time .. index:: parse local_time Create a :eql:type:`cal::local_time` value. 
Similar to :eql:func:`to_datetime`, the :eql:type:`cal::local_time` value can be parsed from the input :eql:type:`str` *s* with an optional *fmt* argument or it can be given in terms of its component parts: *hour*, *min*, *sec*. For more details on formatting see :ref:`here `. .. code-block:: edgeql-repl db> select cal::to_local_time('15:01:22.306916'); {'15:01:22.306916'} db> select cal::to_local_time('03:01:22pm', 'HH:MI:SSam'); {'15:01:22'} db> select cal::to_local_time(15, 1, 22.306916); {'15:01:22.306916'} A timezone-aware :eql:type:`datetime` type can be converted to local time in the specified :ref:`timezone `: .. code-block:: edgeql-repl db> select cal::to_local_time( ... '2018-12-31T22:00:00+08', ... 'America/Los_Angeles'); {'06:00:00'} ------------ .. eql:function:: std::to_duration( \ named only hours: int64=0, \ named only minutes: int64=0, \ named only seconds: float64=0, \ named only microseconds: int64=0 \ ) -> duration .. index:: parse duration Create a :eql:type:`duration` value. This function uses ``named only`` arguments to create a :eql:type:`duration` value. The available duration fields are: *hours*, *minutes*, *seconds*, *microseconds*. .. code-block:: edgeql-repl db> select to_duration(hours := 1, ... minutes := 20, ... seconds := 45); {4845s} db> select to_duration(seconds := 4845); {4845s} .. eql:function:: std::duration_to_seconds(cur: duration) -> decimal Return the duration as the total number of seconds in the interval. .. code-block:: edgeql-repl db> select duration_to_seconds('1 hour'); {3600.000000n} db> select duration_to_seconds('10 second 123 ms'); {10.123000n} ------------ .. eql:function:: cal::to_relative_duration( \ named only years: int64=0, \ named only months: int64=0, \ named only days: int64=0, \ named only hours: int64=0, \ named only minutes: int64=0, \ named only seconds: float64=0, \ named only microseconds: int64=0 \ ) -> cal::relative_duration ..
index:: parse relative_duration Create a :eql:type:`cal::relative_duration` value. This function uses ``named only`` arguments to create a :eql:type:`cal::relative_duration` value. The available duration fields are: *years*, *months*, *days*, *hours*, *minutes*, *seconds*, *microseconds*. .. code-block:: edgeql-repl db> select cal::to_relative_duration(years := 5, minutes := 1); {'P5YT1M'} db> select cal::to_relative_duration(months := 3, days := 27); {'P3M27D'} ------------ .. eql:function:: cal::to_date_duration( \ named only years: int64=0, \ named only months: int64=0, \ named only days: int64=0 \ ) -> cal::date_duration .. index:: parse date_duration Create a :eql:type:`cal::date_duration` value. This function uses ``named only`` arguments to create a :eql:type:`cal::date_duration` value. The available duration fields are: *years*, *months*, *days*. .. code-block:: edgeql-repl db> select cal::to_date_duration(years := 1, days := 3); {'P1Y3D'} db> select cal::to_date_duration(days := 12); {'P12D'} ------------ .. eql:function:: cal::duration_normalize_hours( \ dur: cal::relative_duration \ ) -> cal::relative_duration .. index:: justify_hours Convert 24-hour chunks into days. This function converts all 24-hour chunks into day units. The resulting :eql:type:`cal::relative_duration` is guaranteed to have fewer than 24 hours in total in the units smaller than days. .. code-block:: edgeql-repl db> select cal::duration_normalize_hours( ... '1312 hours'); {'P54DT16H'} This is a lossless operation because 24 hours are always equal to 1 day in :eql:type:`cal::relative_duration` units. This is sometimes used together with :eql:func:`cal::duration_normalize_days`. ------------ .. eql:function:: cal::duration_normalize_days( \ dur: cal::relative_duration \ ) -> cal::relative_duration cal::duration_normalize_days( \ dur: cal::date_duration \ ) -> cal::date_duration .. index:: justify_days Convert 30-day chunks into months.
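As a rough illustration of the arithmetic involved, the following Python sketch folds whole 30-day chunks into months. It is not Gel's implementation: the ``normalize_days`` helper is hypothetical, it handles only non-negative values, and carrying 12 months into a year is an assumption made to mirror how the resulting duration is displayed.

```python
def normalize_days(months: int, days: int) -> tuple[int, int, int]:
    """Fold whole 30-day chunks into months (hypothetical helper).

    Returns (years, months, days). Only non-negative inputs are handled;
    the 12-months-per-year carry is an assumption for display purposes.
    """
    # Every full block of 30 days becomes one month; the rest stays as days.
    extra_months, remaining_days = divmod(days, 30)
    total_months = months + extra_months
    # Express the month total as years plus leftover months.
    years, remaining_months = divmod(total_months, 12)
    return years, remaining_months, remaining_days


if __name__ == "__main__":
    # 1312 days -> 43 months and 22 days -> 3 years, 7 months, 22 days
    print(normalize_days(0, 1312))
```

For 1312 days this yields 43 months plus 22 days, which is displayed as 3 years, 7 months, and 22 days. Because a month is treated as exactly 30 days, the result is an approximation of the original span.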
This function converts all 30-day chunks into month units. The resulting :eql:type:`cal::relative_duration` or :eql:type:`cal::date_duration` is guaranteed to have less than 30 day units. .. code-block:: edgeql-repl db> select cal::duration_normalize_days( ... '1312 days'); {'P3Y7M22D'} db> select cal::duration_normalize_days( ... '1312 days'); {'P3Y7M22D'} This function is a form of approximation and does not preserve the exact duration. This is often used together with :eql:func:`cal::duration_normalize_hours`. ================================================ FILE: docs/reference/stdlib/deprecated.rst ================================================ .. _ref_std_deprecated: ========== Deprecated ========== :edb-alt-title: Deprecated Functions .. list-table:: :class: funcoptable * - :eql:func:`str_lpad` - :eql:func-desc:`str_lpad` * - :eql:func:`str_rpad` - :eql:func-desc:`str_rpad` * - :eql:func:`str_ltrim` - :eql:func-desc:`str_ltrim` * - :eql:func:`str_rtrim` - :eql:func-desc:`str_rtrim` ---------- .. eql:function:: std::str_lpad(string: str, n: int64, fill: str = ' ') -> str Return the input *string* left-padded to the length *n*. .. warning:: This function is deprecated. Use :eql:func:`std::str_pad_start` instead. If the *string* is longer than *n*, then it is truncated to the first *n* characters. Otherwise, the *string* is padded on the left up to the total length *n* using *fill* characters (space by default). .. code-block:: edgeql-repl db> select str_lpad('short', 10); {' short'} db> select str_lpad('much too long', 10); {'much too l'} db> select str_lpad('short', 10, '.:'); {'.:.:.short'} ---------- .. eql:function:: std::str_rpad(string: str, n: int64, fill: str = ' ') -> str Return the input *string* right-padded to the length *n*. .. warning:: This function is deprecated. Use :eql:func:`std::str_pad_end` instead. If the *string* is longer than *n*, then it is truncated to the first *n* characters. 
Otherwise, the *string* is padded on the right up to the total length *n* using *fill* characters (space by default). .. code-block:: edgeql-repl db> select str_rpad('short', 10); {'short '} db> select str_rpad('much too long', 10); {'much too l'} db> select str_rpad('short', 10, '.:'); {'short.:.:.'} ---------- .. eql:function:: std::str_ltrim(string: str, trim: str = ' ') -> str Return the input string with all leftmost *trim* characters removed. .. warning:: This function is deprecated. Use :eql:func:`std::str_trim_start` instead. If the *trim* specifies more than one character they will be removed from the beginning of the *string* regardless of the order in which they appear. .. code-block:: edgeql-repl db> select str_ltrim(' data'); {'data'} db> select str_ltrim('.....data', '.:'); {'data'} db> select str_ltrim(':::::data', '.:'); {'data'} db> select str_ltrim(':...:data', '.:'); {'data'} db> select str_ltrim('.:.:.data', '.:'); {'data'} ---------- .. eql:function:: std::str_rtrim(string: str, trim: str = ' ') -> str Return the input string with all rightmost *trim* characters removed. .. warning:: This function is deprecated. Use :eql:func:`std::str_trim_end` instead. If the *trim* specifies more than one character they will be removed from the end of the *string* regardless of the order in which they appear. .. code-block:: edgeql-repl db> select str_rtrim('data '); {'data'} db> select str_rtrim('data.....', '.:'); {'data'} db> select str_rtrim('data:::::', '.:'); {'data'} db> select str_rtrim('data:...:', '.:'); {'data'} db> select str_rtrim('data.:.:.', '.:'); {'data'} ---------- .. eql:type:: cfg::DatabaseConfig The branch-level configuration object type. As of |EdgeDB| 5.0, this config object represents database *branch* and instance-level configuration. **Use the identical** :eql:type:`cfg::BranchConfig` instead. ================================================ FILE: docs/reference/stdlib/enum.rst ================================================ .. 
_ref_std_enum: ===== Enums ===== :edb-alt-title: Enum Type .. eql:type:: std::enum An enumerated type is a data type consisting of an ordered list of values. An enum type can be declared in a schema by using the following syntax: .. code-block:: sdl scalar type Color extending enum<Red, Green, Blue>; Enum values can then be accessed directly: .. code-block:: edgeql-repl db> select Color.Red is Color; {true} :eql:op:`Casting ` can be used to obtain an enum value in an expression: .. code-block:: edgeql-repl db> select 'Red' is Color; {false} db> select <Color>'Red' is Color; {true} db> select <Color>'Red' = Color.Red; {true} .. note:: The enum values in EdgeQL are string-like in that they can contain any characters that strings can. This is different from some other languages where enum values are identifier-like and thus cannot contain some characters. For example, when working with GraphQL, enum values that contain characters that aren't allowed in identifiers cannot be properly reflected. To address this, consider using only identifier-like enum values in cases where such compatibility is needed. Currently, enum values cannot be longer than 63 characters. ================================================ FILE: docs/reference/stdlib/fts.rst ================================================ .. _ref_std_fts: .. versionadded:: 4.0 ================ Full-text Search ================ The ``fts`` built-in module contains various tools that enable full-text search functionality in Gel. .. note:: Since full-text search is a natural language search, it may not be ideal for your use case, particularly if you want to find partial matches. In that case, you may want to look instead at :ref:`ref_ext_pgtrgm`. ..
list-table:: :class: funcoptable * - :eql:type:`fts::Language` - Common languages :eql:type:`enum` * - :eql:type:`fts::PGLanguage` - Postgres languages :eql:type:`enum` * - :eql:type:`fts::Weight` - Weight category :eql:type:`enum` * - :eql:type:`fts::document` - Opaque document type * - :eql:func:`fts::search` - :eql:func-desc:`fts::search` * - :eql:func:`fts::with_options` - :eql:func-desc:`fts::with_options` When considering FTS functionality our goal was to come up with an interface that could support different backend FTS providers. To achieve that we've identified the following components to the FTS functionality: 1) Valid FTS targets must be indexed. 2) The expected language should be specified at the time of creating an index. 3) It should be possible to mark document parts as having different relevance. 4) It should be possible to assign custom weights at runtime so as to make searching more flexible. 5) The search query should be close to what people are already used to. To address (1) we introduce a special ``fts::index``. The presence of this index in a type declaration indicates that the type in question can be subject to full-text search. This is an unusual index as it actually affects the results of :eql:func:`fts::search` function. This is unlike most indexes which only affect the performance and not the actual results. Another special feature of ``fts::index`` is that at most one such index can be declared per type. If a type inherits this index from a parent and also declares its own, only the new index applies and fully overrides the ``fts::index`` inherited from the parent type. This means that when dealing with a hierarchy of full-text-searchable types, each type can customize what gets searched as needed. The language (2) is defined as part of the ``fts::index on`` expression. A special function :eql:func:`fts::with_options` is used for that purpose: .. 
code-block:: sdl type Item { required available: bool { default := false; }; required name: str; description: str; index fts::index on ( fts::with_options( .name, language := fts::Language.eng ) ); } The above declaration specifies that ``Item`` is full-text-searchable, specifically by examining the ``name`` property (and ignoring ``description``) and assuming that the contents of that property are in English. Marking different parts of the document as having different relevance (3) can also be done by the :eql:func:`fts::with_options` function: .. code-block:: sdl type Item { required available: bool { default := false; }; required name: str; description: str; index fts::index on (( fts::with_options( .name, language := fts::Language.eng, weight_category := fts::Weight.A, ), fts::with_options( .description, language := fts::Language.eng, weight_category := fts::Weight.B, ) )); } The schema now indicates that both the ``name`` and ``description`` properties of ``Item`` are full-text-searchable. Additionally, the ``name`` and ``description`` have potentially different relevance. By default :eql:func:`fts::search` assumes that the weight categories ``A``, ``B``, ``C``, and ``D`` have the following weights: ``[1, 0.5, 0.25, 0.125]``. As a result, the relevance score of each successive category is half that of the previous one. Consider the following: .. code-block:: edgeql-repl gel> select Item{name, description}; { default::Item {name: 'Canned corn', description: {}}, default::Item { name: 'Candy corn', description: 'A great Halloween treat', }, default::Item { name: 'Sweet', description: 'Treat made with corn sugar', }, } gel> with res := ( .... select fts::search(Item, 'corn treat', language := 'eng') .... ) .... select res.object {name, description, score := res.score} ....
order by res.score desc; { default::Item { name: 'Candy corn', description: 'A great Halloween treat', score: 0.4559453, }, default::Item { name: 'Canned corn', description: {}, score: 0.30396354, }, default::Item { name: 'Sweet', description: 'Treat made with corn sugar', score: 0.30396354, }, } As you can see, the highest scoring match came from an ``Item`` that had the search terms appear in both ``name`` and ``description``. It is also apparent that matching a single term from the search query in the ``name`` property scores the same as matching two terms in ``description`` as we would expect based on their weight categories. We can, however, customize the weights (4) to change this trend: .. code-block:: edgeql-repl gel> with res := ( .... select fts::search( .... Item, 'corn treat', .... language := 'eng', .... weights := [0.2, 1], .... ) .... ) .... select res.object {name, description, score := res.score} .... order by res.score desc; { default::Item { name: 'Sweet', description: 'Treat made with corn sugar', score: 0.6079271, }, default::Item { name: 'Candy corn', description: 'A great Halloween treat', score: 0.36475626, }, default::Item { name: 'Canned corn', description: {}, score: 0.06079271, }, } We can even use custom weights to completely ignore one of the properties (e.g. ``name``) in our search, although we also need to add a filter based on the score to make this work properly: .. code-block:: edgeql-repl gel> with res := ( .... select fts::search( .... Item, 'corn treat', .... language := 'eng', .... weights := [0, 1], .... ) .... ) .... select res.object {name, description, score := res.score} .... filter res.score > 0 .... order by res.score desc; { default::Item { name: 'Sweet', description: 'Treat made with corn sugar', score: 0.6079271, }, default::Item { name: 'Candy corn', description: 'A great Halloween treat', score: 0.30396354, }, } Finally, the search query supports features for fine-tuning (5). 
By default, all search terms are desirable, but ultimately optional. You can enclose a term or even a phrase in ``"..."`` to indicate that this specific term is of increased importance and should appear in all matches: .. code-block:: edgeql-repl gel> with res := ( .... select fts::search( .... Item, '"corn sugar"', .... language := 'eng', .... ) .... ) .... select res.object {name, description, score := res.score} .... order by res.score desc; { default::Item { name: 'Sweet', description: 'Treat made with corn sugar', score: 0.4955161, }, } Only one ``Item`` contains the phrase "corn sugar" and incomplete matches are omitted. The search query can also use ``AND`` (using upper-case to indicate that it is a query modifier and not part of the query) to indicate whether terms are required or optional: .. code-block:: edgeql-repl gel> with res := ( .... select fts::search( .... Item, 'sweet AND treat', .... language := 'eng', .... ) .... ) .... select res.object {name, description, score := res.score} .... order by res.score desc; { default::Item { name: 'Sweet', description: 'Treat made with corn sugar', score: 0.70076555, }, } Adding a ``!`` in front of a search term marks it as something that the matching object *must not* contain: .. code-block:: edgeql-repl gel> with res := ( .... select fts::search( .... Item, '!treat', .... language := 'eng', .... ) .... ) .... select res.object {name, description, score := res.score} .... order by res.score desc; { default::Item { name: 'Canned corn', description: {}, score: 0, }, } ---------- .. eql:type:: fts::Language An :eql:type:`enum` for representing commonly supported languages. When indexing an object for full-text search it is important to specify the expected language by :eql:func:`fts::with_options` function. 
This particular :eql:type:`enum` represents languages that are common across several possible [future] backend implementations and thus are "safe" even if the backend implementation switches from one of the options to another. This generic enum is the recommended way of specifying the language. The following `ISO 639-3 <iso639_>`_ language identifiers are available: ``ara``, ``hye``, ``eus``, ``cat``, ``dan``, ``nld``, ``eng``, ``fin``, ``fra``, ``deu``, ``ell``, ``hin``, ``hun``, ``ind``, ``gle``, ``ita``, ``nor``, ``por``, ``ron``, ``rus``, ``spa``, ``swe``, ``tur``. ---------- .. eql:type:: fts::PGLanguage An :eql:type:`enum` for representing languages supported by PostgreSQL. When indexing an object for full-text search it is important to specify the expected language by using the :eql:func:`fts::with_options` function. This particular :eql:type:`enum` represents languages that are available in the PostgreSQL implementation of full-text search. The following `ISO 639-3 <iso639_>`_ language identifiers are available: ``ara``, ``hye``, ``eus``, ``cat``, ``dan``, ``nld``, ``eng``, ``fin``, ``fra``, ``deu``, ``ell``, ``hin``, ``hun``, ``ind``, ``gle``, ``ita``, ``lit``, ``npi``, ``nor``, ``por``, ``ron``, ``rus``, ``srp``, ``spa``, ``swe``, ``tam``, ``tur``, ``yid``. Additionally, the ``xxx_simple`` identifier is available to represent the ``pg_catalog.simple`` language setting. Unless you need some particular language setting that is not available in the :eql:type:`fts::Language`, it is recommended that you use the more general language enum instead. ---------- .. eql:type:: fts::Weight An :eql:type:`enum` for representing weight categories. When indexing an object for full-text search different properties of this object may have different significance. To account for that, they can be assigned different weight categories by using the :eql:func:`fts::with_options` function. There are four available weight categories: ``A``, ``B``, ``C``, and ``D``. ---------- ..
eql:type:: fts::document An opaque transient type used in ``fts::index``. This type is technically what the ``fts::index`` expects as a valid ``on`` expression. It cannot be directly instantiated and can only be produced as the result of applying the special :eql:func:`fts::with_options` function. Thus this type only appears in full-text search index definitions and cannot appear as either a property type or anywhere in regular queries. ------------ .. eql:function:: fts::search( \ object: anyobject, \ query: str, \ named only language: str = fts::Language.eng, \ named only weights: optional array = {}, \ ) -> optional tuple Perform full-text search on a target object. This function applies the search ``query`` to the specified object. If a match is found, the result will consist of a tuple with the matched ``object`` and the corresponding ``score``. A higher ``score`` indicates a better match. In case no match is found, the function will return an empty set ``{}``. Likewise, ``{}`` is returned if the ``object`` has no ``fts::index`` defined for it. The ``language`` parameter can be specified in order to match the expected indexed language. In case of mismatch there is a big chance that the query will not produce the expected results. The optional ``weights`` parameter can be passed in order to provide custom weights to the different weight groups. By default, the weights are ``[1, 0.5, 0.25, 0.125]`` representing groups of diminishing significance. ------------ .. eql:function:: fts::with_options( \ text: str, \ NAMED ONLY language: anyenum, \ NAMED ONLY weight_category: optional fts::Weight = \ fts::Weight.A, \ ) -> fts::document Assign language and weight category to a document portion. This is a special function that can only appear inside ``fts::index`` expressions. The ``text`` expression specifies the portion of the document that will be indexed and available for full-text search. 
The ``language`` parameter specifies the expected language of the ``text`` expression. This affects how the index accounts for grammatical variants of a given word (e.g. how plural and singular forms are determined, etc.). The ``weight_category`` parameter assigns one of four available weight categories to the ``text`` expression: ``A``, ``B``, ``C``, or ``D``. By themselves, the categories simply group together portions of the document so that these groups can be ascribed different significance by the :eql:func:`fts::search` function. By default it is assumed that each successive category is half as significant as the previous, starting with ``A`` as the most significant. However, these default weights can be overridden when making a call to :eql:func:`fts::search`. .. _iso639: https://iso639-3.sil.org/code_tables/639/data ================================================ FILE: docs/reference/stdlib/generic.rst ================================================ .. _ref_std_generic: ======= Generic ======= :edb-alt-title: Generic Functions and Operators .. list-table:: :class: funcoptable * - :eql:op:`anytype = anytype ` - :eql:op-desc:`eq` * - :eql:op:`anytype != anytype ` - :eql:op-desc:`neq` * - :eql:op:`anytype ?= anytype ` - :eql:op-desc:`coaleq` * - :eql:op:`anytype ?!= anytype ` - :eql:op-desc:`coalneq` * - :eql:op:`anytype \< anytype ` - :eql:op-desc:`lt` * - :eql:op:`anytype \> anytype ` - :eql:op-desc:`gt` * - :eql:op:`anytype \<= anytype ` - :eql:op-desc:`lteq` * - :eql:op:`anytype \>= anytype ` - :eql:op-desc:`gteq` * - :eql:func:`len` - :eql:func-desc:`len` * - :eql:func:`contains` - :eql:func-desc:`contains` * - :eql:func:`find` - :eql:func-desc:`find` .. note:: In EdgeQL, any value can be compared to another as long as their types are compatible. ----------- .. eql:operator:: eq: anytype = anytype -> bool .. index:: comparison .. api-index:: = Compares two values for equality. .. 
code-block:: edgeql-repl db> select 3 = 3.0; {true} db> select 3 = 3.14; {false} db> select [1, 2] = [1, 2]; {true} db> select (1, 2) = (x := 1, y := 2); {true} db> select (x := 1, y := 2) = (a := 1, b := 2); {true} db> select 'hello' = 'world'; {false} .. warning:: When either operand in an equality comparison is an empty set, the result will not be a ``bool`` but instead an empty set. .. code-block:: edgeql-repl db> select true = {}; {} If one of the operands in an equality comparison could be an empty set, you may want to use the :eql:op:`coalescing equality ` operator (``?=``) instead. ---------- .. eql:operator:: neq: anytype != anytype -> bool .. index:: not equal, comparison .. api-index:: != Compares two values for inequality. .. code-block:: edgeql-repl db> select 3 != 3.0; {false} db> select 3 != 3.14; {true} db> select [1, 2] != [2, 1]; {true} db> select (1, 2) != (x := 1, y := 2); {false} db> select (x := 1, y := 2) != (a := 1, b := 2); {false} db> select 'hello' != 'world'; {true} .. warning:: When either operand in an inequality comparison is an empty set, the result will not be a ``bool`` but instead an empty set. .. code-block:: edgeql-repl db> select true != {}; {} If one of the operands in an inequality comparison could be an empty set, you may want to use the :eql:op:`coalescing inequality ` operator (``?!=``) instead. ---------- .. eql:operator:: coaleq: optional anytype ?= optional anytype -> bool .. index:: coalesce equal, comparison, empty set .. api-index:: ?= Compares two (potentially empty) values for equality. This works the same as a regular :eql:op:`=` operator, but also allows comparing an empty ``{}`` set. Two empty sets are considered equal. .. code-block:: edgeql-repl db> select {1} ?= {1.0}; {true} db> select {1} ?= {}; {false} db> select {} ?= {}; {true} ---------- .. eql:operator:: coalneq: optional anytype ?!= optional anytype -> bool .. index:: coalesce not equal, comparison ..
api-index:: ?!= Compares two (potentially empty) values for inequality. This works the same as a regular :eql:op:`!=` operator, but also allows comparing an empty ``{}`` set. Two empty sets are considered equal, so comparing them for inequality yields ``false``. .. code-block:: edgeql-repl db> select {2} ?!= {2}; {false} db> select {1} ?!= {}; {true} db> select {} ?!= {}; {false} ---------- .. eql:operator:: lt: anytype < anytype -> bool .. index:: comparison .. api-index:: < Less than operator. The operator returns ``true`` if the value of the left expression is less than the value of the right expression: .. code-block:: edgeql-repl db> select 1 < 2; {true} db> select 2 < 2; {false} db> select 'hello' < 'world'; {true} db> select (1, 'hello') < (1, 'world'); {true} .. warning:: When either operand in a comparison is an empty set, the result will not be a ``bool`` but instead an empty set. .. code-block:: edgeql-repl db> select 1 < {}; {} If one of the operands in a comparison could be an empty set, you may want to coalesce the result of the comparison with ``false`` to ensure your result is boolean. .. code-block:: edgeql-repl db> select (1 < {}) ?? false; {false} ---------- .. eql:operator:: gt: anytype > anytype -> bool .. index:: comparison .. api-index:: > Greater than operator. The operator returns ``true`` if the value of the left expression is greater than the value of the right expression: .. code-block:: edgeql-repl db> select 1 > 2; {false} db> select 3 > 2; {true} db> select 'hello' > 'world'; {false} db> select (1, 'hello') > (1, 'world'); {false} .. warning:: When either operand in a comparison is an empty set, the result will not be a ``bool`` but instead an empty set. .. code-block:: edgeql-repl db> select 1 > {}; {} If one of the operands in a comparison could be an empty set, you may want to coalesce the result of the comparison with ``false`` to ensure your result is boolean. .. code-block:: edgeql-repl db> select (1 > {}) ?? false; {false} ---------- ..
eql:operator:: lteq: anytype <= anytype -> bool .. index:: comparison .. api-index:: <= Less or equal operator. The operator returns ``true`` if the value of the left expression is less than or equal to the value of the right expression: .. code-block:: edgeql-repl db> select 1 <= 2; {true} db> select 2 <= 2; {true} db> select 3 <= 2; {false} db> select 'hello' <= 'world'; {true} db> select (1, 'hello') <= (1, 'world'); {true} .. warning:: When either operand in a comparison is an empty set, the result will not be a ``bool`` but instead an empty set. .. code-block:: edgeql-repl db> select 1 <= {}; {} If one of the operands in a comparison could be an empty set, you may want to coalesce the result of the comparison with ``false`` to ensure your result is boolean. .. code-block:: edgeql-repl db> select (1 <= {}) ?? false; {false} ---------- .. eql:operator:: gteq: anytype >= anytype -> bool .. index:: comparison .. api-index:: >= Greater or equal operator. The operator returns ``true`` if the value of the left expression is greater than or equal to the value of the right expression: .. code-block:: edgeql-repl db> select 1 >= 2; {false} db> select 2 >= 2; {true} db> select 3 >= 2; {true} db> select 'hello' >= 'world'; {false} db> select (1, 'hello') >= (1, 'world'); {false} .. warning:: When either operand in a comparison is an empty set, the result will not be a ``bool`` but instead an empty set. .. code-block:: edgeql-repl db> select 1 >= {}; {} If one of the operands in a comparison could be an empty set, you may want to coalesce the result of the comparison with ``false`` to ensure your result is boolean. .. code-block:: edgeql-repl db> select (1 >= {}) ?? false; {false} ---------- .. eql:function:: std::len(value: str) -> int64 std::len(value: bytes) -> int64 std::len(value: array) -> int64 .. index:: length, count Returns the number of elements of a given value. This function works with the :eql:type:`str`, :eql:type:`bytes` and :eql:type:`array` types: .. 
code-block:: edgeql-repl db> select len('foo'); {3} db> select len(b'bar'); {3} db> select len([2, 5, 7]); {3} ---------- .. eql:function:: std::contains(haystack: str, needle: str) -> bool std::contains(haystack: bytes, needle: bytes) -> bool std::contains(haystack: array, needle: anytype) \ -> bool std::contains(haystack: range, \ needle: range) \ -> std::bool std::contains(haystack: range, \ needle: anypoint) \ -> std::bool std::contains(haystack: multirange, \ needle: multirange) \ -> std::bool std::contains(haystack: multirange, \ needle: range) \ -> std::bool std::contains(haystack: multirange, \ needle: anypoint) \ -> std::bool .. index:: find, strpos, includes Returns true if the given sub-value exists within the given value. When *haystack* is a :eql:type:`str` or a :eql:type:`bytes` value, this function will return ``true`` if it contains *needle* as a subsequence within it or ``false`` otherwise: .. code-block:: edgeql-repl db> select contains('qwerty', 'we'); {true} db> select contains(b'qwerty', b'42'); {false} When *haystack* is an :eql:type:`array`, the function will return ``true`` if the array contains the element specified as *needle* or ``false`` otherwise: .. code-block:: edgeql-repl db> select contains([2, 5, 7, 2, 100], 2); {true} When *haystack* is a :ref:`range `, the function will return ``true`` if it contains either the specified sub-range or element. The function will return ``false`` otherwise. .. code-block:: edgeql-repl db> select contains(range(1, 10), range(2, 5)); {true} db> select contains(range(1, 10), range(2, 15)); {false} db> select contains(range(1, 10), 2); {true} db> select contains(range(1, 10), 10); {false} When *haystack* is a :ref:`multirange `, the function will return ``true`` if it contains either the specified multirange, sub-range or element. The function will return ``false`` otherwise. .. code-block:: edgeql-repl db> select contains( ... multirange([ ... range(1, 4), range(7), ... ]), ... multirange([ ... 
range(1, 2), range(8, 10), ... ]), ... ); {true} db> select contains( ... multirange([ ... range(1, 4), range(8, 10), ... ]), ... range(8), ... ); {false} db> select contains( ... multirange([ ... range(1, 4), range(8, 10), ... ]), ... 3, ... ); {true} When *haystack* is :ref:`JSON `, the function will return ``true`` if the JSON data contains the element specified as *needle* or ``false`` otherwise: .. code-block:: edgeql-repl db> with haystack := to_json('{ ... "city": "Baerlon", ... "city": "Caemlyn" ... }'), ... needle := to_json('{ ... "city": "Caemlyn" ... }'), ... select contains(haystack, needle); {true} ---------- .. eql:function:: std::find(haystack: str, needle: str) -> int64 std::find(haystack: bytes, needle: bytes) -> int64 std::find(haystack: array, needle: anytype, \ from_pos: int64=0) -> int64 .. index:: find, strpos Returns the index of a given sub-value in a given value. When *haystack* is a :eql:type:`str` or a :eql:type:`bytes` value, the function will return the index of the first occurrence of *needle* in it. When *haystack* is an :eql:type:`array`, this will return the index of the first occurrence of the element passed as *needle*. For :eql:type:`array` inputs it is also possible to provide an optional *from_pos* argument to specify the position from which to start the search. If *needle* is not found, the function returns ``-1``. .. code-block:: edgeql-repl db> select find('qwerty', 'we'); {1} db> select find(b'qwerty', b'42'); {-1} db> select find([2, 5, 7, 2, 100], 2); {0} db> select find([2, 5, 7, 2, 100], 2, 1); {3} ================================================ FILE: docs/reference/stdlib/index.rst ================================================ .. versioned-section:: .. _ref_std: ================ Standard Library ================ .. 
toctree:: :maxdepth: 3 :hidden: generic set type math string bool numbers json uuid enum datetime array tuple range bytes sequence objects abstract constraints net fts sys cfg pgcrypto pg_trgm pg_unaccent pgvector postgis deprecated |Gel| comes with a rigorously defined type system consisting of **scalar types**, **collection types** (like arrays and tuples), and **object types**. There is also a library of built-in functions and operators for working with each datatype. .. _ref_datamodel_typesystem: Scalar Types ------------ .. _ref_datamodel_scalar_types: *Scalar types* store primitive data. - :ref:`Strings ` - :ref:`Numbers ` - :ref:`Booleans ` - :ref:`Dates and times ` - :ref:`Enums ` - :ref:`JSON ` - :ref:`UUID ` - :ref:`Bytes ` - :ref:`Sequences ` - :ref:`Abstract types `: these are the types that undergird the scalar hierarchy. .. _ref_datamodel_collection_types: Collection Types ---------------- *Collection types* are special generic types used to group homogeneous or heterogeneous data. - :ref:`Arrays ` - :ref:`Tuples ` Range Types ----------- - :ref:`Range ` - :ref:`Multirange ` Object Types ------------ - :ref:`Object Types ` Types and Sets -------------- - :ref:`Sets ` - :ref:`Types ` - :ref:`Casting ` Utilities --------- - :ref:`Math ` - :ref:`Comparison ` - :ref:`Constraints ` - :ref:`Full-text Search ` - :ref:`System ` Extensions ---------- - :ref:`ext::pgvector ` ================================================ FILE: docs/reference/stdlib/json.rst ================================================ .. _ref_std_json: ==== JSON ==== :edb-alt-title: JSON Functions and Operators .. 
list-table:: :class: funcoptable * - :eql:type:`json` - JSON scalar type * - :eql:op:`json[i] ` - :eql:op-desc:`jsonidx` * - :eql:op:`json[from:to] ` - :eql:op-desc:`jsonslice` * - :eql:op:`json ++ json ` - :eql:op-desc:`jsonplus` * - :eql:op:`json[name] ` - :eql:op-desc:`jsonobjdest` * - :eql:op:`= ` :eql:op:`\!= ` :eql:op:`?= ` :eql:op:`?!= ` :eql:op:`\< ` :eql:op:`\> ` :eql:op:`\<= ` :eql:op:`\>= ` - Comparison operators * - :eql:func:`to_json` - :eql:func-desc:`to_json` * - :eql:func:`to_str` - Render JSON value to a string. * - :eql:func:`json_get` - :eql:func-desc:`json_get` * - :eql:func:`json_set` - :eql:func-desc:`json_set` * - :eql:func:`json_array_unpack` - :eql:func-desc:`json_array_unpack` * - :eql:func:`json_object_pack` - :eql:func-desc:`json_object_pack` * - :eql:func:`json_object_unpack` - :eql:func-desc:`json_object_unpack` * - :eql:func:`json_typeof` - :eql:func-desc:`json_typeof` .. _ref_std_json_construction: Constructing JSON Values ------------------------ JSON in Gel is a :ref:`scalar type `. This type doesn't have its own literal, and instead can be obtained by either casting a value to the :eql:type:`json` type, or by using the :eql:func:`to_json` function: .. code-block:: edgeql-repl db> select to_json('{"hello": "world"}'); {Json("{\"hello\": \"world\"}")} db> select 'hello world'; {Json("\"hello world\"")} Any value in Gel can be cast to a :eql:type:`json` type as well: .. code-block:: edgeql-repl db> select 2019; {Json("2019")} db> select cal::to_local_date(datetime_current(), 'UTC'); {Json("\"2022-11-21\"")} The :eql:func:`json_object_pack` function provides one more way to construct JSON. It constructs a JSON object from an array of key/value tuples: .. code-block:: edgeql-repl db> select json_object_pack({("hello", "world")}); {Json("{\"hello\": \"world\"}")} Additionally, any :eql:type:`Object` in Gel can be cast as a :eql:type:`json` type. This produces the same JSON value as the JSON-serialized result of that said object. 
Furthermore, this result will be the same as the output of a :eql:stmt:`select expression `/ :eql:stmt:`offset ``` which renders as this: :eql:stmt:`the select statement { setPrompt(e.target.value); }} >
{question && (

{question}

)} {(isLoading && ) || (error &&

{error}

) || (answer &&

{answer}

)}
); } function ReturnIcon({ className }: { className?: string }) { return ( ); } function LoadingDots() { return (
); } .. lint-on We have created an input field where the user can enter a question. The text the user types in the input field is captured as ``prompt``. ``question`` is the submitted prompt that we show under the input when the user submits their question. We clear the input and delete the prompt when the user submits it, but keep the ``question`` value so the user can reference it. Let's look at the fleshed-out form submission handler function that we stubbed in earlier: .. code-block:: typescript :caption: app/page.tsx const handleSubmit = ( e: KeyboardEvent | React.MouseEvent ) => { e.preventDefault(); setIsLoading(true); setQuestion(prompt); setAnswer(""); setPrompt(""); generateAnswer(prompt); }; When the user submits a question, we set the ``isLoading`` state to ``true`` and show the loading indicator. We clear the prompt state and set the question state. We also clear the answer state because it may still hold the answer to a previous question, and we want to start with an empty one. At this point we want to create a server-sent event and send a request to our ``api/generate-answer`` route. We will do this inside the ``generateAnswer`` function. The browser-native SSE API doesn't allow the client to send a payload to the server; the client is only able to open a connection to the server to begin receiving events from it via a GET request. In order for the client to be able to send a payload via a POST request to open the SSE connection, we will use the `sse.js `_ package, so let's install it. .. code-block:: bash $ npm install sse.js This package doesn't have a corresponding types package, so we need to add the types manually. Let's create a new folder named ``types`` in the project root and an ``sse.d.ts`` file inside it. .. code-block:: bash $ mkdir types && touch types/sse.d.ts Open ``sse.d.ts`` and add this code: .. 
code-block:: typescript :caption: types/sse.d.ts type SSEOptions = EventSourceInit & { payload?: string; }; declare module "sse.js" { class SSE extends EventSource { constructor(url: string | URL, sseOptions?: SSEOptions); stream(): void; } } This extends the native ``EventSource`` by adding a payload to the constructor. We also added the ``stream`` function, which is used to activate the stream in the sse.js library. This allows us to import ``SSE`` in ``page.tsx`` and use it to open a connection to our handler route while also sending the user's query. .. code-block:: typescript-diff "use client"; - import { useState } from "react"; + import { useState, useRef } from "react"; + import { SSE } from "sse.js"; import { errors } from "./constants"; export default function Home() { + const eventSourceRef = useRef(); + const [prompt, setPrompt] = useState(""); const [question, setQuestion] = useState(""); const [answer, setAnswer] = useState(""); const [isLoading, setIsLoading] = useState(false); const [error, setError] = useState(undefined); const handleSubmit = () => {}; + + const generateAnswer = async (query: string) => { + if (eventSourceRef.current) eventSourceRef.current.close(); + + const eventSource = new SSE(`api/generate-answer`, { + payload: JSON.stringify({ query }), + }); + eventSourceRef.current = eventSource; + + eventSource.onerror = handleError; + eventSource.onmessage = handleMessage; + eventSource.stream(); + }; + + handleError() { /* … */ } + handleMessage() { /* … */ } // … Note that we save a reference to the ``eventSource`` object. We need this in case a user submits a new question while the answer to the previous one is still being assembled on the client. If we don't close the existing connection to the server before opening the new one, this could cause problems since two connections will be open and trying to receive data. We opened a connection to the server, and we are now ready to receive events from it.
We just need to write handlers for those events so the UI knows what to do with them. We will get the answer as part of a message event, and if an error is returned, the server will send an error event to the client. Let's break down these handlers. .. code-block:: typescript :caption: app/page.tsx // … function handleError(err: any) { setIsLoading(false); const errMessage = err.data === errors.flagged ? errors.flagged : errors.default; setError(errMessage); } function handleMessage(e: MessageEvent) { try { setIsLoading(false); if (e.data === "[DONE]") return; const chunkResponse = JSON.parse(e.data); const chunk = chunkResponse.choices[0].delta?.content || ""; setAnswer((answer) => answer + chunk); } catch (err) { handleError(err); } } When we get the message event, we extract the data from it and add it to the ``answer`` state until we receive all chunks. This is indicated when the data is equal to ``[DONE]``, meaning the whole answer has been received and the connection to the server will be closed. There is no data to be parsed in this case, so we return instead of trying to parse it. (An error will be thrown if we try to parse it in this case.) The completed UI ---------------- Put all that together, and you have this (which can be copy/pasted to ``app/page.tsx``): .. lint-off .. 
code-block:: typescript :caption: app/page.tsx "use client"; import { useState, useRef } from "react"; import { SSE } from "sse.js"; import { errors } from "./constants"; export default function Home() { const eventSourceRef = useRef(); const [prompt, setPrompt] = useState(""); const [question, setQuestion] = useState(""); const [answer, setAnswer] = useState(""); const [isLoading, setIsLoading] = useState(false); const [error, setError] = useState(undefined); const handleSubmit = ( e: KeyboardEvent | React.MouseEvent ) => { e.preventDefault(); setIsLoading(true); setQuestion(prompt); setAnswer(""); setPrompt(""); generateAnswer(prompt); }; const generateAnswer = async (query: string) => { if (eventSourceRef.current) eventSourceRef.current.close(); const eventSource = new SSE(`api/generate-answer`, { payload: JSON.stringify({ query }), }); eventSourceRef.current = eventSource; eventSource.onerror = handleError; eventSource.onmessage = handleMessage; eventSource.stream(); }; function handleError(err: any) { setIsLoading(false); const errMessage = err.data === errors.flagged ? errors.flagged : errors.default; setError(errMessage); } function handleMessage(e: MessageEvent) { try { setIsLoading(false); if (e.data === "[DONE]") return; const chunkResponse = JSON.parse(e.data); const chunk = chunkResponse.choices[0].delta?.content || ""; setAnswer((answer) => answer + chunk); } catch (err) { handleError(err); } } return (
{ setPrompt(e.target.value); }} >
{question && (

{question}

)} {(isLoading && ) || (error &&

{error}

) || (answer &&

{answer}

)}
); } function ReturnIcon({ className }: { className?: string }) { return ( ); } function LoadingDots() { return (
); } .. lint-on With that, the UI can now get answers from the Next.js route. The build is complete, and it's time to try it out! Testing it out ============== You should now be able to run the project to test it. .. code-block:: bash $ npm run dev If you used our example documentation, the chatbot will know a few things about EdgeQL along with whatever it was trained on. Some questions you might try: - "What is EdgeQL?" - "Who is EdgeQL for?" - "How should I get started with EdgeQL?" If you don't like the responses you're getting, here are a few things you might try tweaking: - ``systemMessage`` in the ``createFullPrompt`` function in ``app/api/generate-answer/route.ts`` - ``temperature`` in the ``getOpenAiAnswer`` in ``app/api/generate-answer/route.ts`` - the ``matchThreshold`` value passed to the query from the ``getContext`` function in ``app/api/generate-answer/route.ts`` You can see the finished source code for this build in `our examples repo on GitHub `_. You might also find our actual implementation interesting. You'll find it in `our website repo `_. Pay close attention to the contents of `buildTools/gpt `_, where the embedding generation happens and `components/gpt `_, which contains most of the UI for our chatbot. If you have trouble with the build or just want to hang out with other Gel users, please join `our awesome community on Discord `_! ================================================ FILE: docs/resources/guides/tutorials/cloudflare_workers.rst ================================================ .. _ref_guide_cloudflare_workers: ================== Cloudflare Workers ================== :edb-alt-title: Using Gel in Cloudflare Workers This guide demonstrates how to integrate Gel with Cloudflare Workers to build serverless applications that can interact with Gel. 
It covers the following: - Setting up a new Cloudflare Worker project - Configuring Gel - Using Gel in a Cloudflare Worker - Deploying the Worker to Cloudflare You can use this project as a reference: `Gel Cloudflare Workers Example`_. Prerequisites ------------- `Sign up for a Cloudflare account`_ to later deploy your worker. Ensure you have the following installed: - `Node.js`_ - :ref:`Gel CLI ` or just ``$ npx gel`` .. _Sign up for a Cloudflare account: https://dash.cloudflare.com/sign-up .. _Node.js: https://nodejs.org/en/ Setup and configuration ----------------------- Initialize a New Cloudflare Worker Project =========================================== Use the `create-cloudflare`_ package to create a new Cloudflare Worker project. .. _create-cloudflare: https://www.npmjs.com/package/create-cloudflare .. code-block:: bash $ npm create cloudflare@latest # or pnpm, yarn, bun # or $ npx create-cloudflare@latest Answer the prompts to create a new project. Pick the *"Hello World" Worker* template to get started. You'll be asked if you want to deploy your project to Cloudflare. If you say yes, you'll need to sign in (if you haven't already). If you don't want to deploy right away, switch to the project folder you just made to start writing your code. When you're ready to deploy your project to Cloudflare, you can run ``npx wrangler deploy`` to push it. .. note:: Using Wrangler CLI If you prefer using `Wrangler`_ to set up your worker, you can use the :code:`wrangler generate` command to create a new project. .. _Wrangler: https://developers.cloudflare.com/workers/cli-wrangler Configure Gel ============= You can use `Gel Cloud`_ for a managed service or run Gel locally. .. _`Gel Cloud`: https://www.geldata.com/cloud **Local Gel Setup (Optional for Gel Cloud Users)** If you're running Gel locally, you can use the following command to create a new instance: .. 
code-block:: bash $ gel project init It creates a |gel.toml| config file and a schema file :code:`dbschema/default.gel`. It also spins up a Gel instance and associates it with the current directory. As long as you're inside the project directory, all CLI commands will be executed against this instance. You can run |gelcmd| in your terminal to open an interactive REPL to your instance. .. code-block:: bash $ gel # or $ npx gel **Install the Gel npm package** .. code-block:: bash $ npm install gel # or pnpm, yarn, bun **Extend The Default Schema (Optional)** You can extend the default schema, :code:`dbschema/default.gel`, to define your data model, and then try it out in the Cloudflare Worker code. Add new types to the schema file: .. code-block:: sdl module default { type Movie { required title: str { constraint exclusive; }; multi actors: Person; } type Person { required name: str; } } Then apply the schema to your Gel instance: .. code-block:: bash $ gel migration create $ gel migrate Using Gel in a Cloudflare Worker ================================ Open the :code:`index.ts` file from the :code:`src` directory in your project, and remove the default code. To interact with your **local Gel instance**, use the following code: .. code-block:: typescript import * as gel from "gel"; export default { async fetch( _request: Request, env: Env, ctx: ExecutionContext, ): Promise { const client = gel.createHttpClient({ tlsSecurity: "insecure", dsn: "", }); const movies = await client.query(`select Movie { title }`); return new Response(JSON.stringify(movies, null, 2), { headers: { "content-type": "application/json;charset=UTF-8", }, }); }, } satisfies ExportedHandler; .. note:: Gel DSN Replace :code:`` with your Gel DSN. You can obtain your Gel DSN from the command line by running: .. code-block:: bash $ gel instance credentials --insecure-dsn .. note:: tlsSecurity The :code:`tlsSecurity` option is set to :code:`insecure` to allow connections to a local Gel instance.
This lets you test your Cloudflare Worker locally. **Don't use this option in production.** **Client Setup with Gel Cloud** If you're using Gel Cloud, you can instead use the following code to set up the client: .. code-block:: typescript const client = gel.createHttpClient({ instanceName: env.GEL_INSTANCE, secretKey: env.GEL_SECRET_KEY, }); .. note:: Environment variables You can obtain :gelenv:`INSTANCE` and :gelenv:`SECRET_KEY` values from the Gel Cloud dashboard. You will need to set the :gelenv:`INSTANCE` and :gelenv:`SECRET_KEY` environment variables in your Cloudflare Worker project. Add the following to your :code:`wrangler.toml` file: .. code-block:: toml [vars] GEL_INSTANCE = "your-gel-instance" GEL_SECRET_KEY = "your-gel-secret-key" Next, you can run :code:`wrangler types` to generate the types for your environment variables. **Running the Worker** .. note:: Adding polyfills for Node.js The :code:`gel` package currently uses Node.js built-in modules that are not available in the Cloudflare Worker environment. You have to add the following line to your :code:`wrangler.toml` file to include the polyfills: .. code-block:: toml node_compat = true To run the worker locally, use the following command: .. code-block:: bash $ npm run dev # or pnpm, yarn, bun This will start a local server at :code:`http://localhost:8787`. Run :code:`curl http://localhost:8787` to see the response. **Deploying the Worker to Cloudflare** To deploy the worker to Cloudflare, use the following command: .. code-block:: bash $ npm run deploy # or pnpm, yarn, bun This will deploy the worker to Cloudflare and provide you with a URL to access your worker. Wrapping up =========== Congratulations! You have successfully integrated Gel with Cloudflare Workers. Here's a minimal starter project that you can use as a reference: `Gel Cloudflare Workers Example`_. 
Check out the `Cloudflare Workers documentation`_ for more information and to learn about the various features and capabilities of Cloudflare Workers. .. _`Gel Cloudflare Workers Example`: https://github.com/geldata/gel-examples/tree/main/cloudflare-workers .. _`Cloudflare Workers documentation`: https://developers.cloudflare.com/workers ================================================ FILE: docs/resources/guides/tutorials/graphql_apis_with_strawberry.rst ================================================ ========== Strawberry ========== :edb-alt-title: Building a GraphQL API with Gel and Strawberry |Gel| allows you to query your database with GraphQL via the built-in GraphQL extension. It enables you to expose GraphQL-driven CRUD APIs for all object types, their properties, links, and aliases. This opens up the scope for creating backend-less applications where the users will directly communicate with the database. You can learn more about that in the :ref:`GraphQL ` section of the docs. However, as of now, Gel is not ready to be used as a standalone backend. You shouldn't expose your Gel instance directly to the application's frontend; this is insecure and will give all users full read/write access to your database. So, in this tutorial, we'll see how you can quickly create a simple GraphQL API without using the built-in extension, which will give the users restricted access to the database schema. Also, we'll implement HTTP basic authentication and demonstrate how you can write your own GraphQL validators and resolvers. This tutorial assumes you're already familiar with GraphQL terms like schema, query, mutation, resolver, validator, etc, and have used GraphQL with some other technology before. We'll build the same movie organization system that we used in the Flask :ref:`tutorial ` and expose the objects and relationships as a GraphQL API. Using the GraphQL interface, you'll be able to fetch, create, update, and delete movie and actor objects in the database. 
`Strawberry `_ is a Python library that takes a code-first approach where you'll write your object schema as Python classes. This allows us to focus more on how you can integrate Gel into your workflow and less on the idiosyncrasies of GraphQL itself. We'll also use the Gel client to communicate with the database, `FastAPI `_ to build the authentication layer, and Uvicorn as the webserver. Prerequisites ============= Before we start, make sure you have :ref:`installed ` the |gelcmd| command-line tool. Here, we'll use Python 3.10 and a few of its latest features while building the APIs. A working version of this tutorial can be found `on Github `_. Install the dependencies ^^^^^^^^^^^^^^^^^^^^^^^^ To follow along, clone the repository and head over to the ``strawberry-gql`` directory. .. code-block:: bash $ git clone git@github.com:geldata/gel-examples.git $ cd gel-examples/strawberry-gql Create a Python 3.10 virtual environment, activate it, and install the dependencies with this command: .. code-block:: bash $ python3.10 -m venv .venv $ source .venv/bin/activate $ pip install gel fastapi strawberry-graphql uvicorn[standard] Initialize the database ^^^^^^^^^^^^^^^^^^^^^^^ Now, let's initialize a Gel project. From the project's root directory: .. code-block:: bash $ gel project init Initializing project... Specify the name of Gel instance to use with this project [default: strawberry_crud]: > strawberry_crud Do you want to start instance automatically on login? [y/n] > y Checking Gel versions... Once you've answered the prompts, a new Gel instance called ``strawberry_crud`` will be created and started. Connect to the database ^^^^^^^^^^^^^^^^^^^^^^^ Let's test that we can connect to the newly started instance. To do so, run: .. code-block:: bash $ gel You should be connected to the database instance and able to see a prompt similar to this: :: Gel x.x (repl x.x) Type \help for help, \quit to quit. gel> You can start writing queries here. 
However, the database is currently empty. Let's start designing the data model. Schema design ============= The movie organization system will have two object types—**movies** and **actors**. Each *movie* can have links to multiple *actors*. The goal is to create a GraphQL API suite that'll allow us to fetch, create, update, and delete the objects while maintaining their relationships. |Gel| allows us to declaratively define the structure of the objects. The schema lives inside a |.gel| file in the ``dbschema`` directory. It's common to declare the entire schema in a single file :dotgel:`dbschema/default`. This is how our datatypes look: .. code-block:: sdl # dbschema/default.gel module default { abstract type Auditable { property created_at -> datetime { readonly := true; default := datetime_current(); } } type Actor extending Auditable { required property name -> str { constraint max_len_value(50); } property age -> int16 { constraint min_value(0); constraint max_value(100); } property height -> int16 { constraint min_value(0); constraint max_value(300); } } type Movie extending Auditable { required property name -> str { constraint max_len_value(50); } property year -> int16{ constraint min_value(1850); }; multi link actors -> Actor; } } Here, we've defined an ``abstract`` type called ``Auditable`` to take advantage of Gel's schema mixin system. This allows us to add a ``created_at`` property to multiple types without repeating ourselves. The ``Actor`` type extends ``Auditable`` and inherits the ``created_at`` property as a result. This property is auto-filled via the ``datetime_current`` function. Along with the inherited property, the ``Actor`` type also defines a few additional properties called ``name``, ``age``, and ``height``. The constraints on the properties make sure that actor names can't be longer than 50 characters, age must be between 0 and 100 years, and finally, height must be between 0 and 300 centimeters.
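The constraints on ``Actor`` are enforced by the database, but it can be useful to check the same bounds in application code before a query is sent, so that a resolver can report every problem at once. The following is a minimal sketch under that assumption; ``ActorInput`` and ``validate_actor`` are hypothetical names introduced here for illustration and are not part of the tutorial's code or of the Gel client:

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class ActorInput:
    """Hypothetical input model mirroring the ``Actor`` properties."""

    name: str
    age: int | None = None
    height: int | None = None


def validate_actor(actor: ActorInput) -> list[str]:
    """Return the violated constraints; an empty list means valid input."""
    problems: list[str] = []
    # Mirrors ``constraint max_len_value(50)`` on ``name``.
    if len(actor.name) > 50:
        problems.append("name must be at most 50 characters")
    # Mirrors ``min_value(0)`` and ``max_value(100)`` on ``age``.
    if actor.age is not None and not 0 <= actor.age <= 100:
        problems.append("age must be between 0 and 100")
    # Mirrors ``min_value(0)`` and ``max_value(300)`` on ``height``.
    if actor.height is not None and not 0 <= actor.height <= 300:
        problems.append("height must be between 0 and 300")
    return problems
```

A check like this does not replace the schema constraints: the database remains the source of truth, and the bounds here would need to be kept in sync with the schema by hand.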
We also define a ``Movie`` type that extends the ``Auditable`` abstract type. It also contains some additional concrete properties and links: ``name``, ``year``, and an optional multi-link called ``actors`` which refers to the ``Actor`` objects. Build the GraphQL API ===================== The API endpoints are defined in the ``app`` directory. The directory structure looks as follows: :: app ├── __init__.py ├── main.py └── schema.py The ``schema.py`` module contains the code that defines the GraphQL schema and builds the queries and mutations for ``Actor`` and ``Movie`` objects. The ``main.py`` module then registers the GraphQL schema, adds the authentication layer, and exposes the API to the webserver. Write the GraphQL schema ^^^^^^^^^^^^^^^^^^^^^^^^^^^ Along with the database schema, to expose Gel's object relational model as a GraphQL API, you'll also have to define a GraphQL schema that mirrors the object structure in the database. Strawberry allows us to express this schema via type annotated Python classes. We define the Strawberry schema in the ``schema.py`` file as follows: .. code-block:: python # strawberry-gql/app/schema.py from __future__ import annotations import json # will be used later for serialization import gel import strawberry client = gel.create_async_client() @strawberry.type class Actor: name: str | None age: int | None = None height: int | None = None @strawberry.type class Movie: name: str | None year: int | None = None actors: list[Actor] | None = None Here, the GraphQL schema mimics our database schema. Similar to the ``Actor`` and ``Movie`` types in the Gel schema, here, both the ``Actor`` and ``Movie`` models have three attributes. Likewise, the ``actors`` attribute in the ``Movie`` model represents the link between movies and actors. Query actors ^^^^^^^^^^^^ In this section, we'll write the resolver to create the queries that'll allow us to fetch the actor objects from the database.
You'll need to write the query resolvers as methods in a class decorated with the ``@strawberry.type`` decorator. Each method will also need to be decorated with the ``@strawberry.field`` decorator to mark them as resolvers. Resolvers can be either sync or async. In this particular case, we'll write asynchronous resolvers that'll act in a non-blocking manner. The query to fetch the actors is built in the ``schema.py`` file as follows: .. code-block:: python # strawberry-gql/app/schema.py ... @strawberry.type class Query: @strawberry.field async def get_actors( self, filter_name: str | None = None ) -> list[Actor]: if filter_name: actors_json = await client.query_json( """ select Actor {name, age, height} filter .name=$filter_name """, filter_name=filter_name, ) else: actors_json = await client.query_json( """ select Actor {name, age, height} """ ) actors = json.loads(actors_json) return [ Actor(name, age, height) for (name, age, height) in ( d.values() for d in actors ) ] # Register the Query. schema = strawberry.Schema(query=Query) Here, the ``get_actors`` resolver method accepts an optional ``filter_name`` parameter and returns a list of ``Actor`` type objects. The optional ``filter_name`` parameter allows us to build the capability of filtering the actor objects by name. Inside the method, we use the Gel client to asynchronously query the data. The ``client.query_json`` method returns JSON serialized data which we use to create the ``Actor`` instances. Finally, we return the list of actor instances and the rest of the work is done by Strawberry. Then in the last line of the above snippet, we register the ``Query`` class to build the ``Schema`` instance. Afterward, in the ``main.py`` module, we use FastAPI to expose the ``/graphql`` endpoint. Also, we add a basic HTTP authentication layer to demonstrate how you can easily protect your GraphQL endpoint by leveraging FastAPI's dependency injection system. Here's how the content of the ``main.py`` looks: .. 
code-block:: python # strawberry-gql/app/main.py from __future__ import annotations import secrets from typing import Literal from fastapi import ( Depends, FastAPI, HTTPException, Request, Response, status ) from fastapi.security import HTTPBasic, HTTPBasicCredentials from strawberry.fastapi import GraphQLRouter from app.schema import schema app = FastAPI() router = GraphQLRouter(schema) security = HTTPBasic() def auth( credentials: HTTPBasicCredentials = Depends(security) ) -> Literal[True]: """Simple HTTP Basic Auth.""" correct_username = secrets.compare_digest( credentials.username, "ubuntu" ) correct_password = secrets.compare_digest( credentials.password, "debian" ) if not (correct_username and correct_password): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect email or password", headers={"WWW-Authenticate": "Basic"}, ) return True @router.api_route("/", methods=["GET", "POST"]) async def graphql(request: Request) -> Response: return await router.handle_graphql(request=request) app.include_router( router, prefix="/graphql", dependencies=[Depends(auth)] ) First, we initialize the ``FastAPI`` app instance which will communicate with the Uvicorn webserver. Then we attach the initialized ``schema`` instance to the ``GraphQLRouter``. The ``HTTPBasic`` class provides the machinery required to add the authentication layer. The ``auth`` function houses the implementation details of how we're comparing the incoming and expected username and passwords as well as how the webserver is going to handle unauthorized requests. The ``graphql`` handler function is the one that handles the incoming HTTP requests. Finally, the router instance and the security handler are registered to the app instance via the ``app.include_router`` method. We can now start querying the ``/graphql`` endpoint. We'll use the built-in GraphiQL interface to perform the queries. Before that, let's start the Uvicorn webserver first. Run: .. 
code-block:: bash $ uvicorn app.main:app --port 5000 --reload This exposes the webserver in port 5000. Now, in your browser, go to `http://localhost:5000/graphql `_. Here, you'll find that the HTTP basic auth requires us to provide the username and password. .. image:: /docs/tutorials/strawberry/http_basic.png :alt: HTTP basic auth prompt :width: 100% Currently, the allowed username and password is ``ubuntu`` and ``debian`` respectively. Provide the credentials and you'll be taken to a page that looks like this: .. image:: /docs/tutorials/strawberry/graphiql.png :alt: GraphiQL interface :width: 100% You can write your GraphQL queries here. Let's write a query that'll fetch all the actors in the database and show all three of their attributes. The following query does that: .. code-block:: graphql query ActorQuery { getActors { age height name } } The following response will appear on the right panel of the GraphiQL explorer: .. image:: /docs/tutorials/strawberry/query_actors.png :alt: Query actors :width: 100% Since as of now, the database doesn't have any data, the payload is returning an empty list. Let's write a mutation and create some actors. Mutate actors ^^^^^^^^^^^^^^^ Mutations are also written in the ``schema.py`` file. To write a mutation, you'll have to create a separate class where you'll write the mutation resolvers. The resolver methods will need to be decorated with the ``@strawberry.mutation`` decorator. You can write the mutation that'll create an actor object in the database as follows: .. code-block:: python # strawberry-gql/app/schema.py ... 
    @strawberry.type
    class Mutation:
        @strawberry.mutation
        async def create_actor(
            self, name: str,
            age: int | None = None,
            height: int | None = None
        ) -> Actor:
            actor_json = await client.query_single_json(
                """
                with new_actor := (
                    insert Actor {
                        name := <str>$name,
                        age := <optional int16>$age,
                        height := <optional int16>$height
                    }
                )
                select new_actor {name, age, height}
                """,
                name=name,
                age=age,
                height=height,
            )

            actor = json.loads(actor_json)
            return Actor(
                name=actor.get("name"),
                age=actor.get("age"),
                height=actor.get("height"),
            )

    # Mutation class needs to be registered here.
    schema = strawberry.Schema(query=Query, mutation=Mutation)

Creating a mutation also includes data validation. By type annotating the
``Mutation`` class, we're implicitly asking Strawberry to perform data
validation on the incoming request payload. Strawberry will raise an HTTP 400
error if the validation fails. Let's create an actor. Submit the following
GraphQL mutation in the GraphiQL interface:

.. code-block:: graphql

    mutation ActorMutation {
      __typename
      createActor(
        name: "Robert Downey Jr.",
        age: 57,
        height: 173
      ) {
        age
        height
        name
      }
    }

In the above mutation, ``name`` is a required field and the remaining two are
optional fields. This mutation will create an actor named
``Robert Downey Jr.`` and show all three attributes (``name``, ``age``, and
``height``) of the created actor in the response payload. Here's the response:

.. image:: /docs/tutorials/strawberry/create_actor.png
    :alt: Create an actor
    :width: 100%

Now that we've created an actor object, we can run the previously created
query to fetch the actors. Running the ``ActorQuery`` will give you the
following response:

.. image:: /docs/tutorials/strawberry/query_actors_2.png
    :alt: Query actors
    :width: 100%

You can also filter actors by their names. To do so, you'd leverage the
``filterName`` parameter of the ``getActors`` resolver:

.. code-block:: graphql

    query ActorQuery {
      __typename
      getActors(filterName: "Robert Downey Jr.") {
        age
        height
        name
      }
    }

This will only display the filtered results.
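
The same filtered query can also be sent from a script rather than from the GraphiQL explorer. The following is a minimal sketch using only the Python standard library. It assumes the Uvicorn server started earlier is listening on port 5000 and still accepts the ``ubuntu``/``debian`` credentials. ``build_actor_request`` and ``fetch_actors`` are hypothetical helper names, not part of the tutorial's code. Depending on how the router handles trailing slashes, you may need to post to ``/graphql/`` instead.

```python
import base64
import json
import urllib.request

# Assumption: the tutorial's Uvicorn server is running locally on port 5000.
GRAPHQL_URL = "http://localhost:5000/graphql"

# The filter is passed as a GraphQL variable instead of being inlined.
ACTOR_QUERY = """
query ActorQuery($name: String) {
  getActors(filterName: $name) {
    name
    age
    height
  }
}
"""


def build_actor_request(filter_name=None, username="ubuntu", password="debian"):
    """Build (but do not send) an authenticated POST request for getActors."""
    payload = json.dumps(
        {"query": ACTOR_QUERY, "variables": {"name": filter_name}}
    ).encode("utf-8")
    # HTTP Basic auth: base64 of "username:password".
    token = base64.b64encode(f"{username}:{password}".encode("utf-8"))
    return urllib.request.Request(
        GRAPHQL_URL,
        data=payload,
        headers={
            "Content-Type": "application/json",
            "Authorization": "Basic " + token.decode("ascii"),
        },
        method="POST",
    )


def fetch_actors(filter_name=None):
    """Send the request and return the list under data.getActors."""
    with urllib.request.urlopen(build_actor_request(filter_name)) as resp:
        body = json.load(resp)
    return body["data"]["getActors"]
```

Calling ``fetch_actors("Robert Downey Jr.")`` against a running server should return the same objects that GraphiQL shows. Building the request separately from sending it makes it possible to inspect the payload and headers without a live server.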
Similarly, you can write the mutations to update and delete actors. Their
implementations can be found in the ``schema.py`` file. Check out the
``update_actors`` and ``delete_actors`` resolvers to learn more about their
implementation details. You can update one or more attributes of an actor with
the following mutation:

.. code-block:: graphql

    mutation ActorMutation {
      __typename
      updateActors(filterName: "Robert Downey Jr.", age: 60) {
        name
        age
        height
      }
    }

Running this mutation will update the ``age`` of ``Robert Downey Jr.``. First,
we filter the objects that we want to mutate via the ``filterName`` parameter
and then we update the relevant attributes; in this case, we updated the
``age`` of the object. Finally, we show all the fields in the return payload.

Use the GraphiQL explorer to interactively play with the full API suite.

Query movies
^^^^^^^^^^^^

In the ``schema.py`` file, the query to fetch movies is constructed as
follows:

.. code-block:: python

    # strawberry-gql/app/schema.py
    ...

    @strawberry.type
    class Query:
        ...

        @strawberry.field
        async def get_movies(
            self, filter_name: str | None = None,
        ) -> list[Movie]:

            if filter_name:
                movies_json = await client.query_json(
                    """
                    select Movie {name, year, actors: {name, age, height}}
                    filter .name = <str>$filter_name
                    """,
                    filter_name=filter_name,
                )
            else:
                movies_json = await client.query_json(
                    """
                    select Movie {name, year, actors: {name, age, height}}
                    """
                )

            # Deserialize.
            movies = json.loads(movies_json)
            for idx, movie in enumerate(movies):
                # Build one Actor per linked actor object.
                actors = [Actor(**d) for d in movie.get("actors", [])]
                movies[idx] = Movie(
                    name=movie.get("name"),
                    year=movie.get("year"),
                    actors=actors,
                )
            return movies

Similar to the actor query, this also allows you to either fetch all movies or
filter them by name. Execute the following query to see the movies in the
database:

.. 
code-block:: graphql

    query MovieQuery {
      __typename
      getMovies {
        actors {
          age
          height
          name
        }
        name
        year
      }
    }

This will return an empty list since the database doesn't contain any movies.
In the next section, we'll create a mutation to create the movies and query
them afterward.

Mutate movies
^^^^^^^^^^^^^

Before running any query to fetch the movies, let's see how you'd construct a
mutation that allows you to create movies. You can build the mutation similar
to how we've constructed the create actor mutation. It looks like this:

.. code-block:: python

    # strawberry-gql/app/schema.py
    ...

    @strawberry.type
    class Mutation:
        ...

        @strawberry.mutation
        async def create_movie(
            self,
            name: str,
            year: int | None = None,
            actor_names: list[str] | None = None,
        ) -> Movie:
            movie_json = await client.query_single_json(
                """
                with
                    name := <str>$name,
                    year := <optional int16>$year,
                    actor_names := <optional array<str>>$actor_names,
                    new_movie := (
                        insert Movie {
                            name := name,
                            year := year,
                            actors := (
                                select detached Actor
                                filter .name in array_unpack(actor_names)
                            )
                        }
                    )
                select new_movie {
                    name,
                    year,
                    actors: {name, age, height}
                }
                """,
                name=name,
                year=year,
                actor_names=actor_names,
            )

            movie = json.loads(movie_json)
            # Build one Actor per linked actor object.
            actors = [Actor(**d) for d in movie.get("actors", [])]
            return Movie(
                name=movie.get("name"),
                year=movie.get("year"),
                actors=actors,
            )

You can submit a request to this mutation to create a movie. While creating a
movie, you must provide the name of the movie as it's a required field. You
can also optionally provide the ``year`` the movie was released and an array
containing the names of the actors. If the values of the ``actor_names`` field
match any existing actor in the database, the above snippet makes sure that
the movie will be linked with the corresponding actors. In the GraphiQL
explorer, run the following mutation to create a movie named ``Avengers`` and
link the actor ``Robert Downey Jr.`` with the movie:

.. 
code-block:: graphql

    mutation MovieMutation {
      __typename
      createMovie(
        name: "Avengers",
        actorNames: ["Robert Downey Jr."],
        year: 2012
      ) {
        actors {
          name
        }
      }
    }

It'll return:

.. image:: /docs/tutorials/strawberry/create_movie.png
    :alt: Create a movie
    :width: 100%

Now you can fetch the movies with a simple query like this one:

.. code-block:: graphql

    query MovieQuery {
      __typename
      getMovies {
        name
        year
        actors {
          name
        }
      }
    }

You'll then see an output similar to this:

.. image:: /docs/tutorials/strawberry/query_movies.png
    :alt: Query movies
    :width: 100%

Take a look at the ``update_movies`` and ``delete_movies`` resolvers to gain
more insights into the implementation details of those mutations.

Conclusion
==========

In this tutorial, you've seen how you can use Strawberry and Gel together to
quickly build a fully-featured GraphQL API. You have also seen how FastAPI
allows you to add an authentication layer and serve the API in a secure
manner. One thing to keep in mind is that, ideally, you'd only use GraphQL if
you're interfacing with something that already expects a GraphQL API.
Otherwise, EdgeQL is always going to be more powerful and expressive than
GraphQL's query syntax.

================================================
FILE: docs/resources/guides/tutorials/index.rst
================================================
.. _ref_guide_tutorials:

=================
Using Gel with...
=================

.. toctree::
    :maxdepth: 1

    nextjs_app_router
    nextjs_pages_router
    rest_apis_with_fastapi
    rest_apis_with_flask
    jupyter_notebook
    graphql_apis_with_strawberry
    chatgpt_bot
    cloudflare_workers
    trpc
    Bun

================================================
FILE: docs/resources/guides/tutorials/jupyter_notebook.rst
================================================
.. _ref_guide_jupyter_notebook:

================
Jupyter Notebook
================

:edb-alt-title: Using Gel with Jupyter Notebook

1. `Install Jupyter Notebook `__

2. Install the Gel Python library with ``pip install gel``

3. 
Set the appropriate :ref:`connection environment variables ` required for your Gel instance **For Gel Cloud instances** - :gelenv:`INSTANCE`- your instance name (``/``) - :gelenv:`SECRET_KEY`- a secret key with permissions for the selected instance. .. note:: You may create a secret key with the CLI by running :gelcmd:`cloud secretkey create` or in the `Gel Cloud UI `__. **For other remote instances** - :gelenv:`DSN`- the DSN of your remote instance .. note:: DSNs take the following format: :geluri:`:@:/`. Omit any segment, and Gel will fall back to a default value listed in :ref:`our DSN specification ` **For local Gel instances** - :gelenv:`INSTANCE`- your instance name - :gelenv:`USER` & :gelenv:`PASSWORD` .. note :: Usernames and passwords Gel creates an |admin| user by default, but the password is randomized. You may set the password for this role by running ``alter role admin { set password := ''; };`` or you may create a new role using ``create superuser role { set password := ''; };``. 4. Start your notebook by running ``jupyter notebook``. Make sure this process runs in the same environment that contains the variables you set in step 3. 5. Create a new notebook. 6. In one of your notebook's blocks, import the Gel library and run a query. .. code-block:: python import gel client = gel.create_client() def main(): query = "SELECT 1 + 1;" # Swap in any query you want result = client.query(query) print(result[0]) main() client.close() ================================================ FILE: docs/resources/guides/tutorials/nextjs_app_router.rst ================================================ .. _ref_guide_nextjs_app_router: ==================== Next.js (App Router) ==================== :edb-alt-title: Building a simple blog application with Gel and Next.js (App Router) We're going to build a simple blog application with `Next.js `_ and Gel. Let's start by scaffolding our app with Next.js's ``create-next-app`` tool. 
You'll be prompted to provide a name (we'll use ``nextjs-blog``) for your app
and choose project options. For this tutorial, we'll go with the recommended
settings including TypeScript, App Router, and **opting out** of the ``src/``
directory.

.. code-block:: bash

    $ npx create-next-app@latest
    ✔ Would you like to use TypeScript? Yes
    ✔ Would you like to use ESLint? Yes
    ✔ Would you like to use Tailwind CSS? Yes
    ✔ Would you like to use src/ directory? No
    ✔ Would you like to use App Router? (recommended) Yes
    ✔ Would you like to customize the default import alias (@/*) Yes

The scaffolding tool will create a simple Next.js app and install its
dependencies. Once it's done, you can navigate to the app's directory and
start the development server.

.. code-block:: bash

    $ cd nextjs-blog
    $ npm run dev # or yarn dev or pnpm dev or bun run dev

When the dev server starts, it will log out a local URL. Visit that URL to see
the default Next.js homepage. At this point the app's file structure looks
like this:

.. code-block::

    README.md
    tsconfig.json
    package.json
    next.config.js
    next-env.d.ts
    postcss.config.js
    tailwind.config.js
    app
    ├── page.tsx
    ├── layout.tsx
    ├── globals.css
    └── favicon.ico
    public
    ├── next.svg
    └── vercel.svg

There's an async function ``Home`` defined in ``app/page.tsx`` that renders
the homepage. It's a `Server Component `_ which lets you integrate server-side
logic directly into your React components. Server Components are executed on
the server and can fetch data from a database or an API. We'll use this
feature to load blog posts from a Gel database.

Updating the homepage
---------------------

Let's start by implementing a simple homepage for our blog application using
static data. Replace the contents of ``app/page.tsx`` with the following.

.. 
code-block:: tsx :caption: app/page.tsx import Link from 'next/link' type Post = { id: string title: string content: string } export default async function Home() { const posts: Post[] = [ { id: 'post1', title: 'This one weird trick makes using databases fun', content: 'Use Gel', }, { id: 'post2', title: 'How to build a blog with Gel and Next.js', content: "Let's start by scaffolding our app with `create-next-app`.", }, ] return (

        <div>
          <h1>Posts</h1>
          <ul>
            {posts.map((post) => (
              <li key={post.id}>
                <Link href={`/post/${post.id}`}>{post.title}</Link>
              </li>
            ))}
          </ul>
        </div>
) } After saving, you can refresh the page to see the blog posts. Clicking on a post title will take you to a page that doesn't exist yet. We'll create that page later in the tutorial. Initializing Gel ---------------- Now let's spin up a database for the app. You have two options to initialize a Gel project: using ``$ npx gel`` without installing the CLI, or installing the gel CLI directly. In this tutorial, we'll use the first option. If you prefer to install the CLI, see the :ref:`Gel CLI guide ` for more information. From the application's root directory, run the following command: .. code-block:: bash $ npx gel project init No `gel.toml` found in `~/nextjs-blog` or above Do you want to initialize a new project? [Y/n] > Y Specify the name of Gel instance to use with this project [default: nextjs_blog]: > nextjs_blog Checking Gel versions... Specify the version of Gel to use with this project [default: x.x]: > ┌─────────────────────┬──────────────────────────────────────────────┐ │ Project directory │ ~/nextjs-blog │ │ Project config │ ~/nextjs-blog/gel.toml │ │ Schema dir (empty) │ ~/nextjs-blog/dbschema │ │ Installation method │ portable package │ │ Start configuration │ manual │ │ Version │ x.x │ │ Instance name │ nextjs_blog │ └─────────────────────┴──────────────────────────────────────────────┘ Initializing Gel instance... Applying migrations... Everything is up to date. Revision initial. Project initialized. This process has spun up a Gel instance called ``nextjs_blog`` and associated it with your current directory. As long as you're inside that directory, CLI commands and client libraries will be able to connect to the linked instance automatically, without additional configuration. To test this, run the |gelcmd| command to open a REPL to the linked instance. .. code-block:: bash $ gel Gel x.x (repl x.x) Type \help for help, \quit to quit. gel> select 2 + 2; {4} > From inside this REPL, we can execute EdgeQL queries against our database. 
But there's not much we can do currently, since our database is schemaless.
Let's change that.

The project initialization process also created a new subdirectory in our
project called ``dbschema``. This is the folder that contains everything
pertaining to Gel. Currently it looks like this:

.. code-block::

    dbschema
    ├── default.gel
    └── migrations

The :dotgel:`default` file will contain our schema. The ``migrations``
directory is currently empty, but will contain our migration files. Let's
update the contents of :dotgel:`default` with the following simple blog
schema.

.. code-block:: sdl
    :caption: dbschema/default.gel

    module default {
      type BlogPost {
        required title: str;
        required content: str {
          default := ""
        }
      }
    }

.. note::

    Gel lets you split up your schema into different ``modules`` but it's
    common to keep your entire schema in the ``default`` module.

Save the file, then let's create our first migration.

.. code-block:: bash

    $ npx gel migration create
    did you create object type 'default::BlogPost'? [y,n,l,c,b,s,q,?]
    > y
    Created ./dbschema/migrations/00001.edgeql

The ``dbschema/migrations`` directory now contains a migration file called
``00001.edgeql``. Currently though, we haven't applied this migration against
our database. Let's do that.

.. code-block:: bash

    $ npx gel migrate
    Applied m1fee6oypqpjrreleos5hmivgfqg6zfkgbrowx7sw5jvnicm73hqdq (00001.edgeql)

Our database now has a schema consisting of the ``BlogPost`` type. We can
create some sample data from the REPL. Run the |gelcmd| command to re-open the
REPL.

.. code-block:: bash

    $ gel
    Gel x.x (repl x.x)
    Type \help for help, \quit to quit.
    gel>

Then execute the following ``insert`` statements.

.. code-block:: edgeql-repl

    gel> insert BlogPost {
    ....   title := "This one weird trick makes using databases fun",
    ....   content := "Use Gel"
    .... };
    {default::BlogPost {id: 7f301d02-c780-11ec-8a1a-a34776e884a0}}
    gel> insert BlogPost {
    ....   title := "How to build a blog with Gel and Next.js",
    ....  
content := "Let's start by scaffolding our app..." .... }; {default::BlogPost {id: 88c800e6-c780-11ec-8a1a-b3a3020189dd}} Loading posts with React Server Components ------------------------------------------ Now that we have a couple posts in the database, let's load them into our Next.js app. To do that, we'll need the ``gel`` client library. Let's install that from NPM: .. code-block:: bash $ npm install gel # or 'yarn add gel' or 'pnpm add gel' or 'bun add gel' Then go to the ``app/page.tsx`` file to replace the static data with the blogposts fetched from the database. To fetch these from the homepage, we'll create a Gel client and use the ``.query()`` method to fetch all the posts in the database with a ``select`` statement. .. code-block:: tsx-diff :caption: app/page.tsx import Link from 'next/link' + import { createClient } from 'gel'; type Post = { id: string title: string content: string } + const client = createClient(); export default async function Home() { - const posts: Post[] = [ - { - id: 'post1', - title: 'This one weird trick makes using databases fun', - content: 'Use Gel', - }, - { - id: 'post2', - title: 'How to build a blog with Gel and Next.js', - content: "Start by scaffolding our app with `create-next-app`.", - }, - ] + const posts = await client.query(`\ + select BlogPost { + id, + title, + content + };`) return (

          <div>
            <h1>Posts</h1>
            <ul>
              {posts.map((post) => (
                <li key={post.id}>
                  <Link href={`/post/${post.id}`}>{post.title}</Link>
                </li>
              ))}
            </ul>
          </div>
) } When you refresh the page, you should see the blog posts. Generating the query builder ---------------------------- Since we're using TypeScript, it makes sense to use Gel's powerful query builder. This provides a schema-aware client API that makes writing strongly typed EdgeQL queries easy and painless. The result type of our queries will be automatically inferred, so we won't need to manually type something like ``type Post = { id: string; ... }``. First, install the generator to your project. .. code-block:: bash $ npm install --save-dev @gel/generate $ # or yarn add --dev @gel/generate $ # or pnpm add --dev @gel/generate $ # or bun add --dev @gel/generate Then generate the query builder with the following command. .. code-block:: bash $ npx @gel/generate edgeql-js Generating query builder... Detected tsconfig.json, generating TypeScript files. To override this, use the --target flag. Run `npx @gel/generate --help` for full options. Introspecting database schema... Writing files to ./dbschema/edgeql-js Generation complete! 🤘 Checking the generated query builder into version control is not recommended. Would you like to update .gitignore to ignore the query builder directory? The following line will be added: dbschema/edgeql-js [y/n] (leave blank for "y") > y This command introspected the schema of our database and generated some code in the ``dbschema/edgeql-js`` directory. It also asked us if we wanted to add the generated code to our ``.gitignore``; typically it's not good practice to include generated files in version control. Back in ``app/page.tsx``, let's update our code to use the query builder instead. .. 
code-block:: typescript-diff :caption: app/page.tsx import Link from 'next/link' import { createClient } from 'gel'; + import e from '@/dbschema/edgeql-js'; - type Post = { - id: string - title: string - content: string - } const client = createClient(); export default async function Home() { - const posts = await client.query(`\ - select BlogPost { - id, - title, - content - };`) + const selectPosts = e.select(e.BlogPost, () => ({ + id: true, + title: true, + content: true, + })); + const posts = await selectPosts.run(client); return (

          <div>
            <h1>Posts</h1>
            <ul>
              {posts.map((post) => (
                <li key={post.id}>
                  <Link href={`/post/${post.id}`}>{post.title}</Link>
                </li>
              ))}
            </ul>
          </div>
        )
      }

Instead of writing our query as a plain string, we're now using the query
builder to declare our query in a code-first way. As you can see, we import
the query builder as a single default import ``e`` from the
``dbschema/edgeql-js`` directory.

Now, when we update our ``selectPosts`` query, the type of our dynamically
loaded ``posts`` variable will update automatically, so there's no need to
keep our type definitions in sync with our API logic!

Rendering blog posts
--------------------

Our homepage renders a list of links to each of our blog posts, but we haven't
implemented the page that actually displays the posts. Let's create a new page
at ``app/post/[id]/page.tsx``. This is a `dynamic route `_ that includes an
``id`` URL parameter. We'll use this parameter to fetch the appropriate post
from the database.

Add the following code in ``app/post/[id]/page.tsx``:

.. code-block:: tsx
    :caption: app/post/[id]/page.tsx

    import { createClient } from 'gel'
    import e from '@/dbschema/edgeql-js'
    import Link from 'next/link'

    const client = createClient()

    export default async function Post({ params }: { params: { id: string } }) {
      const post = await e
        .select(e.BlogPost, (post) => ({
          id: true,
          title: true,
          content: true,
          filter_single: e.op(post.id, '=', e.uuid(params.id)),
        }))
        .run(client)

      if (!post) {
        return <div>Post not found</div>
      }

      return (
        <div>
          <h1>{post.title}</h1>
          <p>{post.content}</p>
        </div>
) } We are again using a Server Component to fetch the post from the database. This time, we're using the ``filter_single`` method to filter the ``BlogPost`` type by its ``id``. We're also using the ``uuid`` function from the query builder to convert the ``id`` parameter to a UUID. Now, click on one of the blog post links on the homepage. This should bring you to ``/post/``. Deploying to Vercel ------------------- You can deploy a Gel instance on the Gel Cloud or on your preferred cloud provider. We'll cover both options here. With Gel Cloud ============== **#1 Deploy Gel** First, sign up for an account at `cloud.geldata.com `_ and create a new instance. Create and make note of a secret key for your Gel Cloud instance. You can create a new secret key from the "Secret Keys" tab in the Gel Cloud console. We'll need this later to connect to the database from Vercel. Run the following command to migrate the project to the Gel Cloud: .. code-block:: bash $ npx gel migrate -I / .. note:: Alternatively, if you want to restore your data from a local instance to the cloud, you can use the :gelcmd:`dump` and :gelcmd:`restore` commands. .. code-block:: bash $ npx gel dump $ npx gel restore -I / The migrations and schema will be automatically applied to the cloud instance. **#2 Set up a `prebuild` script** Add the following ``prebuild`` script to your ``package.json``. When Vercel initializes the build, it will trigger this script which will generate the query builder. The ``npx @gel/generate edgeql-js`` command will read the value of the :gelenv:`SECRET_KEY` and :gelenv:`INSTANCE` variables, connect to the database, and generate the query builder before Vercel starts building the project. .. code-block:: javascript-diff // package.json "scripts": { "dev": "next dev", "build": "next build", "start": "next start", "lint": "next lint", + "prebuild": "npx @gel/generate edgeql-js" }, **#3 Deploy to Vercel** Push your project to GitHub or some other Git remote repository. 
Then deploy this app to Vercel with the button below. .. XXX -- update URL .. lint-off .. image:: https://vercel.com/button :width: 150px :target: https://vercel.com/new/git/external?repository-url=https://github.com/geldata/gel-examples/tree/main/nextjs-blog&project-name=nextjs-edgedb-blog&repository-name=nextjs-edgedb-blog&env=EDGEDB_DSN,EDGEDB_CLIENT_TLS_SECURITY .. lint-on In "Configure Project," expand "Environment Variables" to add two variables: - :gelenv:`INSTANCE` containing your Gel Cloud instance name (in ``/`` format) - :gelenv:`SECRET_KEY` containing the secret key you created and noted previously. **#4 View the application** Once deployment has completed, view the application at the deployment URL supplied by Vercel. With other cloud providers =========================== **#1 Deploy Gel** Check out the following guides for deploying Gel to your preferred cloud provider: - :ref:`AWS ` - :ref:`Google Cloud ` - :ref:`Azure ` - :ref:`DigitalOcean ` - :ref:`Fly.io ` - :ref:`Docker ` (cloud-agnostic) **#2 Find your instance's DSN** The DSN is also known as a connection string. It will have the format :geluri:`username:password@hostname:port`. The exact instructions for this depend on which cloud you are deploying to. **#3 Apply migrations** Use the DSN to apply migrations against your remote instance. .. code-block:: bash $ npx gel migrate --dsn --tls-security insecure .. note:: You have to disable TLS checks with ``--tls-security insecure``. All Gel instances use TLS by default, but configuring it is out of scope of this project. Once you've applied the migrations, consider creating some sample data in your database. Open a REPL and ``insert`` some blog posts: .. code-block:: bash $ npx gel --dsn --tls-security insecure Gel x.x (repl x.x) Type \help for help, \quit to quit. 
gel> insert BlogPost { title := "Test post" }; {default::BlogPost {id: c00f2c9a-cbf5-11ec-8ecb-4f8e702e5789}} **#4 Set up a `prebuild` script** Add the following ``prebuild`` script to your ``package.json``. When Vercel initializes the build, it will trigger this script which will generate the query builder. The ``npx @gel/generate edgeql-js`` command will read the value of the :gelenv:`DSN` variable, connect to the database, and generate the query builder before Vercel starts building the project. .. code-block:: javascript-diff // package.json "scripts": { "dev": "next dev", "build": "next build", "start": "next start", "lint": "next lint", + "prebuild": "npx @gel/generate edgeql-js" }, **#5 Deploy to Vercel** Deploy this app to Vercel with the button below. .. lint-off .. image:: https://vercel.com/button :width: 150px :target: https://vercel.com/new/git/external?repository-url=https://github.com/geldata/gel-examples/tree/main/nextjs-blog&project-name=nextjs-edgedb-blog&repository-name=nextjs-edgedb-blog&env=EDGEDB_DSN,EDGEDB_CLIENT_TLS_SECURITY .. lint-on When prompted: - Set :gelenv:`DSN` to your database's DSN - Set :gelenv:`CLIENT_TLS_SECURITY` to ``insecure``. This will disable Gel's default TLS checks; configuring TLS is beyond the scope of this tutorial. .. XXX -- update URL .. image:: https://www.geldata.com/docs/tutorials/nextjs/env.png :alt: Setting environment variables in Vercel :width: 100% **#6 View the application** Once deployment has completed, view the application at the deployment URL supplied by Vercel. Wrapping up ----------- This tutorial demonstrates how to work with Gel in a Next.js app, using the App Router. We've created a simple blog application that loads posts from a database and displays them on the homepage. We've also created a dynamic route that fetches a single post from the database and displays it on a separate page. The next step is to add a ``/newpost`` page with a form for writing new blog posts and saving them into Gel. 
That's left as an exercise for the reader. To see the final code for this tutorial, refer to `github.com/geldata/gel-examples/tree/main/nextjs-blog `_. ================================================ FILE: docs/resources/guides/tutorials/nextjs_pages_router.rst ================================================ .. _ref_guide_nextjs_pages_router: ====================== Next.js (Pages Router) ====================== :edb-alt-title: Building a simple blog application with Gel and Next.js (Pages Router) We're going to build a simple blog application with `Next.js `_ and Gel. Let's start by scaffolding our app with Next.js's ``create-next-app`` tool. We'll be using TypeScript for this tutorial. .. code-block:: bash $ npx create-next-app --typescript nextjs-blog This will take a minute to run. The scaffolding tool is creating a simple Next.js app and installing all our dependencies for us. Once it's complete, let's navigate into the directory and start the dev server. .. code-block:: bash $ cd nextjs-blog $ yarn dev Open `localhost:3000 `_ to see the default Next.js homepage. At this point the app's file structure looks like this: .. code-block:: README.md tsconfig.json package.json next.config.js next-env.d.ts pages ├── _app.tsx ├── api │ └── hello.ts └── index.tsx public ├── favicon.ico └── vercel.svg styles ├── Home.module.css └── globals.css There's a custom App component defined in ``pages/_app.tsx`` that loads some global CSS, plus the homepage at ``pages/index.tsx`` and a single API route at ``pages/api/hello.ts``. The ``styles`` and ``public`` directories contain some other assets. Updating the homepage --------------------- Let's start by implementing a simple homepage for our blog application using static data. Replace the contents of ``pages/index.tsx`` with the following. .. 
code-block:: tsx // pages/index.tsx import type {NextPage} from 'next'; import Head from 'next/head'; import styles from '../styles/Home.module.css'; type Post = { id: string; title: string; content: string; }; const HomePage: NextPage = () => { const posts: Post[] = [ { id: 'post1', title: 'This one weird trick makes using databases fun', content: 'Use Gel', }, { id: 'post2', title: 'How to build a blog with Gel and Next.js', content: "Let's start by scaffolding our app with `create-next-app`.", }, ]; return (
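    <div className={styles.container}>
      <Head>
        <title>My Blog</title>
      </Head>
      <main className={styles.main}>
        <h1>Blog</h1>
        {posts.map((post) => { return (
          <a href={`/post/${post.id}`} key={post.id}>
            <p>{post.title}</p>
          </a>
        ); })}
      </main>
    </div>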
); }; export default HomePage; After saving, Next.js should hot-reload, and the homepage should look something like this. .. image:: /docs/tutorials/nextjs/basic_home.png :alt: Basic blog homepage with static content :width: 100% Initializing Gel ---------------- Now let's spin up a database for the app. You have two options to initialize a Gel project: using ``$ npx gel`` without installing the CLI, or installing the gel CLI directly. In this tutorial, we'll use the first option. If you prefer to install the CLI, see the :ref:`Gel CLI guide ` for more information. From the application's root directory, run the following command: .. code-block:: bash $ npx gel project init No `gel.toml` found in `~/nextjs-blog` or above Do you want to initialize a new project? [Y/n] > Y Specify the name of Gel instance to use with this project [default: nextjs_blog]: > nextjs_blog Checking Gel versions... Specify the version of Gel to use with this project [default: x.x]: > ┌─────────────────────┬──────────────────────────────────────────────┐ │ Project directory │ ~/nextjs-blog │ │ Project config │ ~/nextjs-blog/gel.toml │ │ Schema dir (empty) │ ~/nextjs-blog/dbschema │ │ Installation method │ portable package │ │ Start configuration │ manual │ │ Version │ x.x │ │ Instance name │ nextjs_blog │ └─────────────────────┴──────────────────────────────────────────────┘ Initializing Gel instance... Applying migrations... Everything is up to date. Revision initial. Project initialized. This process has spun up a Gel instance called ``nextjs_blog`` and "linked" it with your current directory. As long as you're inside that directory, CLI commands and client libraries will be able to connect to the linked instance automatically, without additional configuration. To test this, run the |gelcmd| command to open a REPL to the linked instance. .. code-block:: bash $ gel Gel x.x (repl x.x) Type \help for help, \quit to quit.
gel> select 2 + 2; {4} gel> From inside this REPL, we can execute EdgeQL queries against our database. But there's not much we can do currently, since our database has no schema yet. Let's change that. The project initialization process also created a new subdirectory in our project called ``dbschema``. This is the folder that contains everything pertaining to Gel. Currently it looks like this: .. code-block:: dbschema ├── default.gel └── migrations The :dotgel:`default` file will contain our schema. The ``migrations`` directory is currently empty, but will contain our migration files. Let's update the contents of :dotgel:`default` with the following simple blog schema. .. code-block:: sdl # dbschema/default.gel module default { type BlogPost { required property title -> str; required property content -> str { default := "" }; } } .. note:: Gel lets you split up your schema into different ``modules`` but it's common to keep your entire schema in the ``default`` module. Save the file, then let's create our first migration. .. code-block:: bash $ npx gel migration create did you create object type 'default::BlogPost'? [y,n,l,c,b,s,q,?] > y Created ./dbschema/migrations/00001.edgeql The ``dbschema/migrations`` directory now contains a migration file called ``00001.edgeql``. We haven't applied this migration against our database yet, though. Let's do that. .. code-block:: bash $ npx gel migrate Applied m1fee6oypqpjrreleos5hmivgfqg6zfkgbrowx7sw5jvnicm73hqdq (00001.edgeql) Our database now has a schema consisting of the ``BlogPost`` type. We can create some sample data from the REPL. Run the |gelcmd| command to re-open the REPL. .. code-block:: bash $ gel Gel x.x (repl x.x) Type \help for help, \quit to quit. gel> Then execute the following ``insert`` statements. .. code-block:: edgeql-repl gel> insert BlogPost { .... title := "This one weird trick makes using databases fun", .... content := "Use Gel" ....
}; {default::BlogPost {id: 7f301d02-c780-11ec-8a1a-a34776e884a0}} gel> insert BlogPost { .... title := "How to build a blog with Gel and Next.js", .... content := "Let's start by scaffolding our app..." .... }; {default::BlogPost {id: 88c800e6-c780-11ec-8a1a-b3a3020189dd}} Loading posts with an API route ------------------------------- Now that we have a couple posts in the database, let's load them dynamically with a Next.js `API route `_. To do that, we'll need the ``gel`` client library. Let's install that from NPM: .. code-block:: bash $ npm install gel Then create a new file at ``pages/api/post.ts`` and copy in the following code. .. code-block:: typescript // pages/api/post.ts import type {NextApiRequest, NextApiResponse} from 'next'; import {createClient} from 'gel'; export const client = createClient(); export default async function handler( req: NextApiRequest, res: NextApiResponse ) { const posts = await client.query(`select BlogPost { id, title, content };`); res.status(200).json(posts); } This file initializes a Gel client, which manages a pool of connections to the database and provides an API for executing queries. We're using the ``.query()`` method to fetch all the posts in the database with a simple ``select`` statement. If you visit `localhost:3000/api/post `_ in your browser, you should see a plaintext JSON representation of the blog posts we inserted earlier. To fetch these from the homepage, we'll use ``useState``, ``useEffect``, and the built-in ``fetch`` API. At the top of the ``HomePage`` component in ``pages/index.tsx``, replace the static data and add the missing imports. .. 
code-block:: tsx-diff // pages/index.tsx + import {useState, useEffect} from 'react'; type Post = { id: string; title: string; content: string; }; const HomePage: NextPage = () => { - const posts: Post[] = [ - { - id: 'post1', - title: 'This one weird trick makes using databases fun', - content: 'Use Gel', - }, - { - id: 'post2', - title: 'How to build a blog with Gel and Next.js', - content: "Let's start by scaffolding our app...", - }, - ]; + const [posts, setPosts] = useState<Post[] | null>(null); + useEffect(() => { + fetch(`/api/post`) + .then((result) => result.json()) + .then(setPosts); + }, []); + if (!posts) return

<p>Loading...</p>; return <main>...</main>
; } When you refresh the page, you should briefly see a ``Loading...`` indicator before the homepage renders the (dynamically loaded!) blog posts. Generating the query builder ---------------------------- Since we're using TypeScript, it makes sense to use Gel's powerful query builder. This provides a schema-aware client API that makes writing strongly typed EdgeQL queries easy and painless. The result type of our queries will be automatically inferred, so we won't need to manually type something like ``type Post = { id: string; ... }``. First, install the generator to your project. .. code-block:: bash $ yarn add --dev @gel/generate Then generate the query builder with the following command. .. code-block:: bash $ npx @gel/generate edgeql-js Generating query builder... Detected tsconfig.json, generating TypeScript files. To override this, use the --target flag. Run `npx @gel/generate --help` for full options. Introspecting database schema... Writing files to ./dbschema/edgeql-js Generation complete! 🤘 Checking the generated query builder into version control is not recommended. Would you like to update .gitignore to ignore the query builder directory? The following line will be added: dbschema/edgeql-js [y/n] (leave blank for "y") > y This command introspected the schema of our database and generated some code in the ``dbschema/edgeql-js`` directory. It also asked us if we wanted to add the generated code to our ``.gitignore``; typically it's not good practice to include generated files in version control. Back in ``pages/api/post.ts``, let's update our code to use the query builder instead. .. 
code-block:: typescript-diff // pages/api/post.ts import type {NextApiRequest, NextApiResponse} from 'next'; import {createClient} from 'gel'; + import e, {$infer} from '../../dbschema/edgeql-js'; export const client = createClient(); + const selectPosts = e.select(e.BlogPost, () => ({ + id: true, + title: true, + content: true, + })); + export type Posts = $infer<typeof selectPosts>; export default async function handler( req: NextApiRequest, res: NextApiResponse ) { - const posts = await client.query(`select BlogPost { - id, - title, - content - };`); + const posts = await selectPosts.run(client); res.status(200).json(posts); } Instead of writing our query as a plain string, we're now using the query builder to declare our query in a code-first way. As you can see, we import the query builder as a single default import ``e`` from the ``dbschema/edgeql-js`` directory. We're also using a utility called ``$infer`` to extract the inferred type of this query. In VSCode you can hover over ``Posts`` to see what this type is. .. image:: /docs/tutorials/nextjs/inference.png :alt: Inferred type of posts query :width: 100% Back in ``pages/index.tsx``, let's update our code to use the inferred ``Posts`` type instead of our manual type declaration. .. code-block:: typescript-diff // pages/index.tsx import type {NextPage} from 'next'; import Head from 'next/head'; import {useEffect, useState} from 'react'; import styles from '../styles/Home.module.css'; + import {Posts} from "./api/post"; - type Post = { - id: string; - title: string; - content: string; - }; const HomePage: NextPage = () => { + const [posts, setPosts] = useState<Posts | null>(null); // ... } Now, when we update our ``selectPosts`` query, the type of our dynamically loaded ``posts`` variable will update automatically, so there's no need to keep our type definitions in sync with our API logic!
Rendering blog posts -------------------- Our homepage renders a list of links to each of our blog posts, but we haven't implemented the page that actually displays the posts. Let's create a new page at ``pages/post/[id].tsx``. This is a `dynamic route `_ that includes an ``id`` URL parameter. We'll use this parameter to fetch the appropriate post from the database. Create ``pages/post/[id].tsx`` and add the following code. We're using ``getServerSideProps`` to load the blog post data server-side, to avoid loading spinners and ensure the page loads fast. .. code-block:: tsx import React from 'react'; import {GetServerSidePropsContext, InferGetServerSidePropsType} from 'next'; import {client} from '../api/post'; import e from '../../dbschema/edgeql-js'; export const getServerSideProps = async ( context?: GetServerSidePropsContext ) => { const post = await e .select(e.BlogPost, (post) => ({ id: true, title: true, content: true, filter_single: e.op( post.id, '=', e.uuid(context!.params!.id as string) ), })) .run(client); return {props: {post: post!}}; }; export type GetPost = InferGetServerSidePropsType<typeof getServerSideProps>; const Post: React.FC<GetPost> = (props) => { return (

    <main>
      <h1>{props.post.title}</h1>
      <p>{props.post.content}</p>
    </main>

); }; export default Post; Inside ``getServerSideProps`` we're extracting the ``id`` parameter from ``context.params`` and using it in our EdgeQL query. The query is a ``select`` query that fetches the ``id``, ``title``, and ``content`` of the post with a matching ``id``. We're using Next's ``InferGetServerSidePropsType`` utility to extract the inferred type of our query and pass it into ``React.FC``. Now, if we update our query, the type of the component props will automatically update too. In fact, this entire application is end-to-end typesafe. Now, click on one of the blog post links on the homepage. This should bring you to ``/post/``, which should display something like this: .. image:: /docs/tutorials/nextjs/post.png :alt: Basic blog homepage with static content :width: 100% Deploying to Vercel ------------------- **#1 Deploy Gel** First deploy a Gel instance on your preferred cloud provider: - :ref:`AWS ` - :ref:`Azure ` - :ref:`DigitalOcean ` - :ref:`Fly.io ` - :ref:`Google Cloud ` or use a cloud-agnostic deployment method: - :ref:`Docker ` - :ref:`Bare metal ` **#2. Find your instance's DSN** The DSN is also known as a connection string. It will have the format :geluri:`username:password@hostname:port`. The exact instructions for this depend on which cloud you are deploying to. **#3 Apply migrations** Use the DSN to apply migrations against your remote instance. .. code-block:: bash $ npx gel migrate --dsn --tls-security insecure .. note:: You have to disable TLS checks with ``--tls-security insecure``. All Gel instances use TLS by default, but configuring it is out of scope of this project. Once you've applied the migrations, consider creating some sample data in your database. Open a REPL and ``insert`` some blog posts: .. code-block:: bash $ npx gel --dsn --tls-security insecure Gel x.x (repl x.x) Type \help for help, \quit to quit. 
gel> insert BlogPost { title := "Test post" }; {default::BlogPost {id: c00f2c9a-cbf5-11ec-8ecb-4f8e702e5789}} **#4 Set up a `prebuild` script** Add the following ``prebuild`` script to your ``package.json``. When Vercel initializes the build, it will trigger this script which will generate the query builder. The ``npx @gel/generate edgeql-js`` command will read the value of the :gelenv:`DSN` variable, connect to the database, and generate the query builder before Vercel starts building the project. .. code-block:: javascript-diff // package.json "scripts": { "dev": "next dev", "build": "next build", "start": "next start", "lint": "next lint", + "prebuild": "npx @gel/generate edgeql-js" }, **#5 Deploy to Vercel** Deploy this app to Vercel with the button below. .. XXX -- update URL .. lint-off .. image:: https://vercel.com/button :width: 150px :target: https://vercel.com/new/git/external?repository-url=https://github.com/geldata/gel-examples/tree/main/nextjs-blog&project-name=nextjs-edgedb-blog&repository-name=nextjs-edgedb-blog&env=GEL_DSN,GEL_CLIENT_TLS_SECURITY .. lint-on When prompted: - Set :gelenv:`DSN` to your database's DSN - Set :gelenv:`CLIENT_TLS_SECURITY` to ``insecure``. This will disable Gel's default TLS checks; configuring TLS is beyond the scope of this tutorial. .. image:: /docs/tutorials/nextjs/env.png :alt: Setting environment variables in Vercel :width: 100% **#6 View the application** Once deployment has completed, view the application at the deployment URL supplied by Vercel. Wrapping up ----------- Admittedly this isn't the prettiest blog of all time, or the most feature-complete. But this tutorial demonstrates how to work with Gel in a Next.js app, including data fetching with API routes and ``getServerSideProps``. The next step is to add a ``/newpost`` page with a form for writing new blog posts and saving them into Gel. That's left as an exercise for the reader. 
To see the final code for this tutorial, refer to `github.com/geldata/gel-examples/tree/main/nextjs-blog `_. ================================================ FILE: docs/resources/guides/tutorials/rest_apis_with_fastapi.rst ================================================ .. _ref_guide_rest_apis_with_fastapi: ======= FastAPI ======= :edb-alt-title: Building a REST API with Gel and FastAPI Because FastAPI encourages and facilitates strong typing, it's a natural pairing with Gel. Our Python code generation generates not only typed query functions but result types you can use to annotate your endpoint handler functions. |Gel| can help you quickly build REST APIs in Python without getting into the rigmarole of using ORM libraries to handle your data effectively. Here, we'll be using `FastAPI `_ to expose the API endpoints and Gel to store the content. We'll build a simple event management system where you'll be able to fetch, create, update, and delete *events* and *event hosts* via RESTful API endpoints. Prerequisites ============= Before we start, make sure you've :ref:`installed ` the |gelcmd| command line tool. For this tutorial, we'll use Python 3.10 to take advantage of the asynchronous I/O paradigm to communicate with the database more efficiently. You can use newer versions of Python if you prefer, but you may need to adjust the code accordingly. If you want to skip ahead, the completed source code for this API can be found `in our examples repo `_. If you want to check out an example with Gel Auth, you can find that in the same repo in the `fastapi-crud-auth folder `_. Create a project directory ^^^^^^^^^^^^^^^^^^^^^^^^^^ To get started, create a directory for your project and change into it. .. code-block:: bash $ mkdir fastapi-crud $ cd fastapi-crud Install the dependencies ^^^^^^^^^^^^^^^^^^^^^^^^ Create a Python virtual environment, activate it, and install the dependencies with this command (in Linux/macOS; see the following note for help with Windows): .. 
code-block:: bash $ python -m venv myvenv $ source myvenv/bin/activate $ pip install gel fastapi 'httpx[cli]' uvicorn .. note:: Make sure you run ``source myvenv/bin/activate`` any time you want to come back to this project to activate its virtual environment. If not, you may start working under your system's default Python environment which could be the incorrect version or not have the dependencies installed. If you want to confirm you're using the right environment, run ``which python``. You should see that the current ``python`` is inside your venv directory. .. note:: The commands will differ for Windows/Powershell users; `this guide `_ provides instructions for working with virtual environments across a range of OSes, including Windows. Initialize the database ^^^^^^^^^^^^^^^^^^^^^^^ Now, let's initialize a Gel project. From the project's root directory: .. code-block:: bash $ gel project init No `gel.toml` found in `` or above Do you want to initialize a new project? [Y/n] > Y Specify the name of Gel instance to use with this project [default: fastapi_crud]: > fastapi_crud Checking Gel versions... Specify the version of Gel to use with this project [default: 2.7]: > 2.7 Once you've answered the prompts, a new Gel instance called ``fastapi_crud`` will be created and started. If you see ``Project initialized``, you're ready. Connect to the database ^^^^^^^^^^^^^^^^^^^^^^^ Let's test that we can connect to the newly started instance. To do so, run: .. code-block:: bash $ gel You should see this prompt indicating you are now connected to your new database instance: :: Gel x.x (repl x.x) Type \help for help, \quit to quit. gel> You can start writing queries here. Since this database is empty, that won't get you very far, so let's start designing our data model instead. Schema design ============= The event management system will have two entities: **events** and **users**. Each *event* can have an optional link to a *user* who is that event's host. 
The goal is to create API endpoints that'll allow us to fetch, create, update, and delete the entities while maintaining their relationships. |Gel| allows us to declaratively define the structure of the entities. If you've worked with SQLAlchemy or Django ORM, you might refer to these declarative schema definitions as *models*. In Gel we call them "object types". The schema lives inside |.gel| files in the ``dbschema`` directory. It's common to declare the entire schema in a single file :dotgel:`dbschema/default`. This file is created for you when you run :gelcmd:`project init`, but you'll need to fill it with your schema. This is what our datatypes look like: .. code-block:: sdl :caption: dbschema/default.gel module default { abstract type Auditable { required created_at: datetime { readonly := true; default := datetime_current(); } } type User extending Auditable { required name: str { constraint exclusive; constraint max_len_value(50); }; } type Event extending Auditable { required name: str { constraint exclusive; constraint max_len_value(50); } address: str; schedule: datetime; link host: User; } } Here, we've defined an ``abstract`` type called ``Auditable`` to take advantage of Gel's schema mixin system. This allows us to add a ``created_at`` property to multiple types without repeating ourselves. Abstract types don't have any concrete footprints in the database, as they don't hold any actual data. Their only job is to propagate properties, links, and constraints to the types that extend them. The ``User`` type extends ``Auditable`` and inherits the ``created_at`` property as a result. Since ``created_at`` has a ``default`` value, it's auto-filled with the return value of the ``datetime_current`` function. Along with the property conveyed to it by the extended type, the ``User`` type defines its own concrete required property called ``name``. 
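As a loose Python analogy for the mixin behaviour described above (this is not how Gel implements it), a dataclass can inherit an auto-filled field from a base class in much the same way ``User`` inherits ``created_at`` from ``Auditable``. The class names mirror the schema; everything else in this sketch is illustrative.

```python
from dataclasses import dataclass, field
from datetime import datetime, timezone


@dataclass
class Auditable:
    # Stand-in for the abstract type: the caller never supplies created_at,
    # it is filled in automatically, like `default := datetime_current()`.
    created_at: datetime = field(init=False)

    def __post_init__(self) -> None:
        self.created_at = datetime.now(timezone.utc)


@dataclass
class User(Auditable):
    # The only property User declares itself; created_at is inherited.
    name: str


user = User(name="Mina Murray")
print(user.name, user.created_at.isoformat())
```

The caller only passes ``name``; the timestamp arrives through the base class, which is the same division of labour the schema sets up.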
We impose two constraints on this property: names should be unique (``constraint exclusive``) and no longer than 50 characters (``constraint max_len_value(50)``). We also define an ``Event`` type that extends the ``Auditable`` abstract type. It contains its own concrete properties and links: ``address``, ``schedule``, and an optional link called ``host`` that corresponds to a ``User``. Run a migration =============== With the schema created, it's time to lock it in. The first step is to create a migration. .. code-block:: bash $ gel migration create When this step is successful, you'll see ``Created dbschema/migrations/00001.edgeql``. Now run the migration we just created. .. code-block:: bash $ gel migrate Once this is done, you'll see ``Applied`` along with the migration's ID. I like to go one step further in verifying success and see the schema applied to my database. To do that, first fire up the Gel console: .. code-block:: bash $ gel In the console, type ``\ds`` (for "describe schema"). If everything worked, we should see output very close to the schema we added in the :dotgel:`default` file: :: module default { abstract type Auditable { required property created_at: std::datetime { default := (std::datetime_current()); readonly := true; }; }; type Event extending default::Auditable { link host: default::User; property address: std::str; required property name: std::str { constraint std::exclusive; constraint std::max_len_value(50); }; property schedule: std::datetime; }; type User extending default::Auditable { required property name: std::str { constraint std::exclusive; constraint std::max_len_value(50); }; }; }; Build the API endpoints ======================= With the schema established, we're ready to start building out the app. Let's start by creating an ``app`` directory inside our project: ..
code-block:: bash $ mkdir app Within this ``app`` directory, we're going to create three modules: ``events.py`` and ``users.py`` which represent the events and users APIs respectively, and ``main.py`` that registers all the endpoints and exposes them to the `uvicorn `_ webserver. We also need an ``__init__.py`` to mark this directory as a package so we can easily import from it. Go ahead and create that file now in your editor or via the command line like this (from the project root): .. code-block:: bash $ touch app/__init__.py We'll work on the users API first since it's the simpler of the two. Users API ^^^^^^^^^ We want this app to be type safe, end to end. To achieve this, instead of hard-coding string queries into the app, we'll use code generation to generate typesafe functions from queries we write in ``.edgeql`` files. These files are simple text files containing the queries we want our app to be able to run. The code generator will search through our project for all files with the ``.edgeql`` extension and generate those functions for us as individual Python modules. When you installed the Gel client (via ``pip install gel``), the code generator was installed alongside it, so you're already ready to go. We just need to write those queries! We'll write queries for one endpoint at a time to start so you can see how the pieces fit together. To keep things organized, create a new directory inside ``app`` called ``queries``. Create a new file in ``app/queries`` named ``get_users.edgeql`` and open it in your editor. Write the query into this file. It's the same one we would have written inline in our Python code as shown in the code block above: .. code-block:: edgeql :caption: app/queries/get_users.edgeql select User {name, created_at}; We need one more query to finish off this endpoint. Create another file inside ``app/queries`` named ``get_user_by_name.edgeql`` and open it in your editor. Add this query: .. 
code-block:: edgeql :caption: app/queries/get_user_by_name.edgeql select User {name, created_at} filter User.name = <str>$name Save that file and get ready to kick off the magic that is code generation! 🪄 .. code-block:: bash $ gel-py Found Gel project: Processing /app/queries/get_user_by_name.edgeql Processing /app/queries/get_users.edgeql Generating /app/queries/get_user_by_name.py Generating /app/queries/get_users.py The code generator creates one module per query file by default and places them at the same path as the query files. With the code generated, we're ready to write an endpoint. Let's create the ``GET /users`` endpoint so that we can request the ``User`` objects saved in the database. Create a new file ``app/users.py``, open it in your editor, and add the following code: .. lint-off .. code-block:: python :caption: app/users.py from __future__ import annotations import datetime from http import HTTPStatus from typing import List import gel from fastapi import APIRouter, HTTPException, Query from pydantic import BaseModel from .queries import get_user_by_name_async_edgeql as get_user_by_name_qry from .queries import get_users_async_edgeql as get_users_qry router = APIRouter() client = gel.create_async_client() class RequestData(BaseModel): name: str @router.get("/users") async def get_users( name: str = Query(None, max_length=50) ) -> List[get_users_qry.GetUsersResult] | get_user_by_name_qry.GetUserByNameResult: if not name: users = await get_users_qry.get_users(client) return users else: user = await get_user_by_name_qry.get_user_by_name(client, name=name) return user .. lint-on We've imported the generated code and aliased it (using ``as``) to make the module names we use in our code a bit neater. The ``APIRouter`` instance does the actual work of exposing the API. We also create an async Gel client instance to communicate with the database. By default, this API will return a list of all users, but you can also filter the user objects by name.
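Stripped of FastAPI and the database, the branching in ``get_users`` amounts to the following. This is only a sketch: ``FAKE_USERS``, ``fetch_all_users``, and ``fetch_users_by_name`` are hypothetical in-memory stand-ins for the generated query functions and are not part of Gel's API.

```python
from typing import Dict, List, Optional

# Hypothetical in-memory stand-in for the User objects in the database.
FAKE_USERS: List[Dict[str, str]] = [
    {"name": "Jonathan Harker"},
    {"name": "Count Dracula"},
]


def fetch_all_users() -> List[Dict[str, str]]:
    """Stand-in for the generated get_users() query function."""
    return list(FAKE_USERS)


def fetch_users_by_name(name: str) -> List[Dict[str, str]]:
    """Stand-in for the generated get_user_by_name() query function."""
    return [user for user in FAKE_USERS if user["name"] == name]


def get_users(name: Optional[str] = None) -> List[Dict[str, str]]:
    # No name supplied: return every user. Otherwise filter by name.
    if not name:
        return fetch_all_users()
    return fetch_users_by_name(name)


print(get_users())
print(get_users("Count Dracula"))
print(get_users("Nobody"))
```

Note that a name matching no user simply produces an empty result here rather than an error.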
We have the ``RequestData`` class to handle the data an API consumer will need to send in case they want to get only a single user. The types we're using in the return annotation have been generated by the |Gel| code generation based on the queries we wrote and our database's schema. Note that we're also calling the appropriate generated function based on whether or not the API consumer passes an argument for ``name``. This nearly gets us there but not quite. We have one potential outcome not accounted for: a query for a user by name that returns no results. In that case, we'll want to return a 404 (not found). To fix it, we'll check in the else case whether we got anything back from the single user query. If not, we'll go ahead and raise an exception. This will send the 404 (not found) response to the user. .. lint-off .. code-block:: python :caption: app/users.py ... if not name: users = await get_users_qry.get_users(client) return users else: user = await get_user_by_name_qry.get_user_by_name(client, name=name) if not user: raise HTTPException( status_code=HTTPStatus.NOT_FOUND, detail={"error": f"Username '{name}' does not exist."}, ) return user ... .. lint-on To summarize, in the ``get_users`` function, we use our generated code to perform asynchronous queries via the ``gel`` client. Then we return the query results. Afterward, the JSON serialization part is taken care of by FastAPI. Before we can use this endpoint, we need to expose it to the server. We'll do that in the ``main.py`` module. Create ``app/main.py`` and open it in your editor. Here's the content of the module: .. code-block:: python :caption: app/main.py from __future__ import annotations from fastapi import FastAPI from starlette.middleware.cors import CORSMiddleware from app import users fast_api = FastAPI() # Set all CORS enabled origins. 
fast_api.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) fast_api.include_router(users.router) Here, we import everything we need, including our own ``users`` module containing the router and endpoint logic for the users API. We instantiate the API, give it a permissive CORS configuration, and give it the users router. To test the endpoint, go to the project root and run: .. code-block:: bash $ uvicorn app.main:fast_api --port 5001 --reload This will start a ``uvicorn`` server and you'll be able to start making requests against it. Earlier, we installed the `HTTPx `_ client library to make HTTP requests programmatically. It also comes with a neat command-line tool that we'll use to test our API. While the ``uvicorn`` server is running, bring up a new console. Activate your virtual environment by running ``source myvenv/bin/activate`` and run: .. code-block:: bash $ httpx -m GET http://localhost:5001/users You'll see the following output on the console: :: HTTP/1.1 200 OK date: Sat, 16 Apr 2022 22:58:11 GMT server: uvicorn content-length: 2 content-type: application/json [] .. note:: If you find yourself with a result you don't expect when making a request to your API, switch over to the uvicorn server console. You should find a traceback that will point you to the problem area in your code. If you see this result, that means the API is working! It's not especially useful though. Our request yields an empty list because the database is currently empty. Let's create the ``POST /users`` endpoint in ``app/users.py`` to start saving users in the database. Before we do that though, let's go ahead and create the new query we'll need. Create and open ``app/queries/create_user.edgeql`` and fill it with this query: .. code-block:: edgeql :caption: app/queries/create_user.edgeql select (insert User { name := <str>$name }) { name, created_at }; ..
note:: We're running our ``insert`` inside a ``select`` here so that we can return the ``name`` and ``created_at`` properties. If we just ran the ``insert`` bare, it would return only the ``id``. Save the file and run ``gel-py`` to generate the new function. Now, we're ready to open ``app/users.py`` again and add the POST endpoint. First, import the generated function for the new query: .. code-block:: python :caption: app/users.py ... from .queries import create_user_async_edgeql as create_user_qry from .queries import get_user_by_name_async_edgeql as get_user_by_name_qry from .queries import get_users_async_edgeql as get_users_qry ... Then write the endpoint to call that function: .. lint-off .. code-block:: python :caption: app/users.py ... @router.post("/users", status_code=HTTPStatus.CREATED) async def post_user(user: RequestData) -> create_user_qry.CreateUserResult: try: created_user = await create_user_qry.create_user(client, name=user.name) except gel.errors.ConstraintViolationError: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, detail={"error": f"Username '{user.name}' already exists."}, ) return created_user .. lint-on In the above snippet, we ingest data with the shape dictated by the ``RequestData`` model and return a payload of the query results. The ``try...except`` block gracefully handles the situation where the API consumer might try to create multiple users with the same name. A successful request will yield the status code HTTP 201 (created) along with the new user's ``id``, ``name``, and ``created_at`` as JSON. To test it out, make a request as follows: .. code-block:: bash $ httpx -m POST http://localhost:5001/users \ --json '{"name" : "Jonathan Harker"}' The output should look similar to this: :: HTTP/1.1 201 Created ... { "id": "53771f56-6f57-11ed-8729-572f5fba7ddc", "name": "Jonathan Harker", "created_at": "2022-04-16T23:09:30.929664+00:00" } .. 
note:: Since IDs are generated, your ``id`` values probably won't match the values in this guide. This is not a problem. If you try to make the same request again, it'll throw an HTTP 400 (bad request) error: :: HTTP/1.1 400 Bad Request ... { "detail": { "error": "Username 'Jonathan Harker' already exists." } } Before we move on to the next step, create 2 more users called ``Count Dracula`` and ``Mina Murray``. Once you've done that, we can move on to the next step of building the ``PUT /users`` endpoint to update existing user data. We'll start again with the query. Create a new file in ``app/queries`` named ``update_user.edgeql``. Open it in your editor and enter this query: .. code-block:: edgeql :caption: app/queries/update_user.edgeql select ( update User filter .name = $current_name set {name := $new_name} ) {name, created_at}; Save the file and generate again using ``gel-py``. Now, we'll import that and add the endpoint over in ``app/users.py``. .. lint-off .. code-block:: python :caption: app/users.py ... from .queries import create_user_async_edgeql as create_user_qry from .queries import get_user_by_name_async_edgeql as get_user_by_name_qry from .queries import get_users_async_edgeql as get_users_qry from .queries import update_user_async_edgeql as update_user_qry ... @router.put("/users") async def put_user( user: RequestData, current_name: str ) -> update_user_qry.UpdateUserResult: try: updated_user = await update_user_qry.update_user( client, new_name=user.name, current_name=current_name, ) except gel.errors.ConstraintViolationError: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, detail={"error": f"Username '{user.name}' already exists."}, ) if not updated_user: raise HTTPException( status_code=HTTPStatus.NOT_FOUND, detail={"error": f"User '{current_name}' was not found."}, ) return updated_user .. lint-on Not much new happening here. We wrote our query with a ``current_name`` parameter for finding the user to be updated. 
The ``user`` argument will give us the changes to make to that user, which in this case can only be the ``name`` since that's the only property a user has. We pull the name out of ``user`` and pass it as our ``new_name`` argument to the generated function. The endpoint calls the generated function passing the client and those two values, and the user is updated. We've accounted for the possibility of a user trying to change a user's name to a new name that conflicts with a different user. That will return a 400 (bad request) error. We've also accounted for the possibility of a user trying to update a user that doesn't exist, which will return a 404 (not found). Let's save everything and test this out. .. code-block:: bash $ httpx -m PUT http://localhost:5001/users \ -p 'current_name' 'Jonathan Harker' \ --json '{"name" : "Dr. Van Helsing"}' This will return: :: HTTP/1.1 200 OK ... [ { "id": "53771f56-6f57-11ed-8729-572f5fba7ddc", "name": "Dr. Van Helsing", "created_at": "2022-04-16T23:09:30.929664+00:00" } ] If you try to change the name of a user to match that of an existing user, the endpoint will throw an HTTP 400 (bad request) error: .. code-block:: bash $ httpx -m PUT http://localhost:5001/users \ -p 'current_name' 'Count Dracula' \ --json '{"name" : "Dr. Van Helsing"}' This returns: :: HTTP/1.1 400 Bad Request ... { "detail": { "error": "Username 'Dr. Van Helsing' already exists." } } Since we've verified that endpoint is working, let's move on to the ``DELETE /users`` endpoint. It'll allow us to query the name of the targeted object to delete it. Start by creating ``app/queries/delete_user.edgeql`` and filling it with this query: .. code-block:: edgeql :caption: app/queries/delete_user.edgeql select ( delete User filter .name = $name ) {name, created_at}; Generate the new function by again running ``gel-py``. Then re-open ``app/users.py``. This endpoint's code will look similar to the endpoints we've already written: .. lint-off .. 
code-block:: python :caption: app/users.py ... from .queries import create_user_async_edgeql as create_user_qry from .queries import delete_user_async_edgeql as delete_user_qry from .queries import get_user_by_name_async_edgeql as get_user_by_name_qry from .queries import get_users_async_edgeql as get_users_qry from .queries import update_user_async_edgeql as update_user_qry ... @router.delete("/users") async def delete_user(name: str) -> delete_user_qry.DeleteUserResult: try: deleted_user = await delete_user_qry.delete_user( client, name=name, ) except gel.errors.ConstraintViolationError: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, detail={"error": "User attached to an event. Cannot delete."}, ) if not deleted_user: raise HTTPException( status_code=HTTPStatus.NOT_FOUND, detail={"error": f"User '{name}' was not found."}, ) return deleted_user .. lint-on This endpoint will simply delete the requested user if the user isn't attached to any event. If the targeted object *is* attached to an event, the API will throw an HTTP 400 (bad request) error and refuse to delete the object. To test it out by deleting ``Count Dracula``, on your console, run: .. code-block:: bash $ httpx -m DELETE http://localhost:5001/users \ -p 'name' 'Count Dracula' If it worked, you should see this result: :: HTTP/1.1 200 OK ... [ { "id": "e6837562-6f55-11ed-8744-ff1b295ed864", "name": "Count Dracula", "created_at": "2022-04-16T23:23:56.630101+00:00" } ] With that, you've written the entire users API! Now, we move onto the events API which is slightly more complex. (Nothing you can't handle though. 😁) Events API ^^^^^^^^^^ Let's start with the ``POST /events`` endpoint, and then we'll fetch the objects created via POST using the ``GET /events`` endpoint. First, we need a query. Create a file ``app/queries/create_event.edgeql`` and drop this query into it: .. 
code-block:: edgeql :caption: app/queries/create_event.edgeql with name := $name, address := $address, schedule := $schedule, host_name := $host_name select ( insert Event { name := name, address := address, schedule := schedule, host := assert_single( (select detached User filter .name = host_name) ) } ) {name, address, schedule, host: {name}}; Run ``gel-py`` to generate a function from that query. Create a file in ``app`` named ``events.py`` and open it in your editor. It's time to code up the endpoint to use that freshly generated query. .. lint-off .. code-block:: python :caption: app/events.py from __future__ import annotations from http import HTTPStatus from typing import List import gel from fastapi import APIRouter, HTTPException, Query from pydantic import BaseModel from .queries import create_event_async_edgeql as create_event_qry router = APIRouter() client = gel.create_async_client() class RequestData(BaseModel): name: str address: str schedule: str host_name: str @router.post("/events", status_code=HTTPStatus.CREATED) async def post_event(event: RequestData) -> create_event_qry.CreateEventResult: try: created_event = await create_event_qry.create_event( client, name=event.name, address=event.address, schedule=event.schedule, host_name=event.host_name, ) except gel.errors.InvalidValueError: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, detail={ "error": "Invalid datetime format. " "Datetime string must look like this: " "'2010-12-27T23:59:59-07:00'", }, ) except gel.errors.ConstraintViolationError: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, detail={"error": f"Event name '{event.name}' already exists."}, ) return created_event .. lint-on Like the ``POST /users`` endpoint, the incoming and outgoing shapes of the ``POST /events`` endpoint's data are defined by the ``RequestData`` model and the generated ``CreateEventResult`` model respectively.
The ``post_event`` function asynchronously inserts the data into the database and returns the fields defined in the ``select`` query we wrote earlier, along with the new event's ``id``. The exception handling logic rejects invalid incoming data. For example, just as before in the users API, the events API will complain if you try to create multiple events with the same name. Also, the field ``schedule`` accepts data as an `ISO 8601 `_ timestamp string. Values that don't conform to that format will result in an HTTP 400 (bad request) error. It's almost time to test, but before we can do that, we need to expose this new API in ``app/main.py``. Open that file, and update the import on line 6 to also import ``events``: .. code-block:: python :caption: app/main.py ... from app import users, events ... Drop down to the bottom of ``main.py`` and include the events router: .. code-block:: python :caption: app/main.py ... fast_api.include_router(events.router) Let's try it out. Here's how you'd create an event: .. code-block:: bash $ httpx -m POST http://localhost:5001/events \ --json '{ "name":"Resuscitation", "address":"Britain", "schedule":"1889-07-27T23:59:59-07:00", "host_name":"Mina Murray" }' If everything worked, you'll see output like this: :: HTTP/1.1 201 Created ... { "id": "0b1847f4-6f3d-11ed-9f27-6fcdf20ffe22", "name": "Resuscitation", "address": "Britain", "schedule": "1889-07-28T06:59:59+00:00", "host": { "name": "Mina Murray" } } To speed this up a bit, we'll go ahead and write all the remaining queries in one shot. Then we can flip back to ``app/events.py`` and code up all the endpoints. Start by creating a file in ``app/queries`` named ``get_events.edgeql``. This one is really straightforward: .. code-block:: edgeql :caption: app/queries/get_events.edgeql select Event {name, address, schedule, host : {name}}; Save that one and create ``app/queries/get_event_by_name.edgeql`` with this query: ..
code-block:: edgeql :caption: app/queries/get_event_by_name.edgeql select Event { name, address, schedule, host : {name} } filter .name = $name; Those two will handle queries for ``GET /events``. Next, create ``app/queries/update_event.edgeql`` with this query: .. code-block:: edgeql :caption: app/queries/update_event.edgeql with current_name := $current_name, new_name := $name, address := $address, schedule := $schedule, host_name := $host_name select ( update Event filter .name = current_name set { name := new_name, address := address, schedule := schedule, host := (select User filter .name = host_name) } ) {name, address, schedule, host: {name}}; That query will handle PUT requests. The last method left is DELETE. Create ``app/queries/delete_event.edgeql`` and put this query in it: .. code-block:: edgeql :caption: app/queries/delete_event.edgeql select ( delete Event filter .name = $name ) {name, address, schedule, host : {name}}; Run ``gel-py`` to generate the new functions. Open ``app/events.py`` so we can start getting these functions implemented in the API! We'll start by coding GET. Import the newly generated queries and write the GET endpoint in ``events.py``: .. lint-off .. code-block:: python :caption: app/events.py ... from .queries import create_event_async_edgeql as create_event_qry from .queries import delete_event_async_edgeql as delete_event_qry from .queries import get_event_by_name_async_edgeql as get_event_by_name_qry from .queries import get_events_async_edgeql as get_events_qry from .queries import update_event_async_edgeql as update_event_qry ... 
@router.get("/events") async def get_events( name: str = Query(None, max_length=50) ) -> List[get_events_qry.GetEventsResult] | get_event_by_name_qry.GetEventByNameResult: if not name: events = await get_events_qry.get_events(client) return events else: event = await get_event_by_name_qry.get_event_by_name(client, name=name) if not event: raise HTTPException( status_code=HTTPStatus.NOT_FOUND, detail={"error": f"Event '{name}' does not exist."}, ) return event .. lint-on Save that file and test it like this: .. code-block:: bash $ httpx -m GET http://localhost:5001/events We should get back an array containing all our events (which, at the moment, is just the one): :: HTTP/1.1 200 OK ... [ { "id": "0b1847f4-6f3d-11ed-9f27-6fcdf20ffe22", "name": "Resuscitation", "address": "Britain", "schedule": "1889-07-28T06:59:59+00:00", "host": { "name": "Mina Murray" } } ] You can also use the ``GET /events`` endpoint to return a single event object by name. To locate the ``Resuscitation`` event, you'd use the ``name`` parameter with the GET API as follows: .. code-block:: bash $ httpx -m GET http://localhost:5001/events \ -p 'name' 'Resuscitation' That'll return a result that looks like the response we just got without the ``name`` parameter, except that it's a single object instead of an array. :: HTTP/1.1 200 OK ... { "id": "0b1847f4-6f3d-11ed-9f27-6fcdf20ffe22", "name": "Resuscitation", "address": "Britain", "schedule": "1889-07-28T06:59:59+00:00", "host": { "name": "Mina Murray" } } If we'd had multiple events, the response to our first test would have given us all of them. Let's finish off the events API with the PUT and DELETE endpoints. Open ``app/events.py`` and add this code: .. lint-off .. code-block:: python :caption: app/events.py ... 
@router.put("/events") async def put_event( event: RequestData, current_name: str ) -> update_event_qry.UpdateEventResult: try: updated_event = await update_event_qry.update_event( client, current_name=current_name, name=event.name, address=event.address, schedule=event.schedule, host_name=event.host_name, ) except gel.errors.InvalidValueError: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, detail={ "error": "Invalid datetime format. " "Datetime string must look like this: '2010-12-27T23:59:59-07:00'", }, ) except gel.errors.ConstraintViolationError: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, detail={"error": f"Event name '{event.name}' already exists."}, ) if not updated_event: raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail={"error": f"Update event '{event.name}' failed."}, ) return updated_event @router.delete("/events") async def delete_event(name: str) -> delete_event_qry.DeleteEventResult: deleted_event = await delete_event_qry.delete_event(client, name=name) if not deleted_event: raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail={"error": f"Delete event '{name}' failed."}, ) return deleted_event .. lint-on The events API is now ready to handle updates and deletion. Let's try out a cool alternative way to test these new endpoints. Browse the endpoints using the native OpenAPI doc ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ FastAPI automatically generates OpenAPI schema from the API endpoints and uses those to build the API docs. While the ``uvicorn`` server is running, go to your browser and head over to `http://localhost:5001/docs `_. You should see an API navigator like this: .. image:: /docs/tutorials/fastapi/openapi.png :alt: FastAPI docs navigator :width: 100% This documentation allows you to play with the APIs interactively. Let's try to make a request to the ``PUT /events``. Click on the API that you want to try and then click on the **Try it out** button. 
You can do it in the UI as follows: .. image:: /docs/tutorials/fastapi/put.png :alt: FastAPI docs PUT events API :width: 100% Clicking the **execute** button will make the request and return the following payload: .. image:: /docs/tutorials/fastapi/put_result.png :alt: FastAPI docs PUT events API result :width: 100% You can do the same to test ``DELETE /events``, just make sure you give it whatever name you set for the event in your previous test of the PUT method. Integrating Gel Auth ==================== |Gel| Auth provides a built-in authentication solution that is deeply integrated with the Gel server. This section outlines how to enable and configure Gel Auth in your application schema, manage authentication providers, and set key configuration parameters. Setting up Gel Auth ^^^^^^^^^^^^^^^^^^^ To start using Gel Auth, you must first enable it in your schema. Add the following to your schema definition: .. code-block:: sdl using extension auth; Once added, make sure to apply the schema changes by migrating your database schema. .. code-block:: bash $ gel migration create $ gel migrate Configuring Gel Auth -------------------- The configuration of Gel Auth involves setting various parameters to secure and tailor authentication to your needs. For now, we'll focus on the essential parameters to get started. You can configure these settings through a Python script, which is recommended for scalability, or you can use the Gel UI for a more user-friendly approach. **Auth Signing Key** This key is used to sign the JWTs for internal operations. Although it's not necessary for your application's functionality, it's essential for secure token handling. To generate a secure key, you can use OpenSSL or Python with the following commands: Using OpenSSL: .. code-block:: bash $ openssl rand -base64 32 Using Python: .. code-block:: python import secrets print(secrets.token_urlsafe(32)) Once you have generated your key, configure it in Gel like this: .. 
code-block:: edgeql CONFIGURE CURRENT BRANCH SET ext::auth::AuthConfig::auth_signing_key := ''; **Allowed redirect URLs** This configuration ensures that redirections are limited to domains under your control. The ``allowed_redirect_urls`` setting specifies URLs that the Auth extension can safely redirect to after authentication. A URL must exactly match or be a sub-path of a URL in the list to be considered valid. To configure this in your application: .. code-block:: edgeql CONFIGURE CURRENT BRANCH SET ext::auth::AuthConfig::allowed_redirect_urls := { 'http://localhost:8000', 'http://localhost:8000/auth' }; Enabling authentication providers --------------------------------- You need to configure at least one authentication provider to use Gel Auth. This can be done via the Gel UI or directly through queries. In this example, we'll configure an email and password provider. You can add it with the following query: .. code-block:: edgeql CONFIGURE CURRENT BRANCH INSERT ext::auth::EmailPasswordProviderConfig { require_verification := false, }; .. note:: ``require_verification`` defaults to ``true``. In this example, we're setting it to ``false`` to simplify the setup. In a production environment, you should set it to ``true`` to ensure that users verify their email addresses before they can log in. If you use the Email and Password provider, in addition to the ``require_verification`` configuration, you'll need to configure SMTP to allow |Gel| to send email verification and password reset emails on your behalf. Here is an example of setting up a local SMTP server, in this case using a product called `Mailpit `__ which is great for testing in development: ..
code-block:: edgeql CONFIGURE CURRENT BRANCH SET ext::auth::SMTPConfig::sender := 'hello@example.com'; CONFIGURE CURRENT BRANCH SET ext::auth::SMTPConfig::host := 'localhost'; CONFIGURE CURRENT BRANCH SET ext::auth::SMTPConfig::port := 1025; CONFIGURE CURRENT BRANCH SET ext::auth::SMTPConfig::security := 'STARTTLSOrPlainText'; CONFIGURE CURRENT BRANCH SET ext::auth::SMTPConfig::validate_certs := false; You can query the database configuration to discover which providers are configured with the following query: .. code-block:: edgeql select cfg::Config.extensions[is ext::auth::AuthConfig].providers { name, [is ext::auth::OAuthProviderConfig].display_name, }; Implementing authentication with FastAPI ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Below, we provide a detailed guide to setting up authentication using FastAPI, including both sign-in and sign-up functionalities. PKCE flow for enhanced security ------------------------------- The PKCE (Proof Key for Code Exchange) flow enhances security in server-to-server authentication by generating a unique verifier and its corresponding challenge. First, your server creates a 32-byte Base64 URL-encoded verifier, stores it in an HttpOnly cookie, hashes it with SHA256, and then encodes it to form the challenge. This implementation ensures enhanced security by preventing token leakage and is tailored specifically for server-to-server interactions. Add the following code to your FastAPI application to generate the PKCE: .. 
code-block:: python :caption: app/auth.py import secrets import hashlib import base64 def generate_pkce(): verifier = secrets.token_urlsafe(32) challenge = hashlib.sha256(verifier.encode()).digest() challenge_base64 = base64.urlsafe_b64encode(challenge).decode('utf-8').rstrip('=') return verifier, challenge_base64 User registration and authentication ------------------------------------ Next, we're going to create endpoints in FastAPI to handle user registration (sign-up) and user login (sign-in): **Sign-up endpoint** .. code-block:: python :caption: app/auth.py import os from fastapi import APIRouter, HTTPException, Request from fastapi.responses import JSONResponse import httpx router = APIRouter() # Value should be: # {protocol}://${host}:${port}/branch/${branch}/ext/auth/ GEL_AUTH_BASE_URL = os.getenv('GEL_AUTH_BASE_URL') @router.post("/auth/signup") async def handle_signup(request: Request): body = await request.json() email = body.get("email") password = body.get("password") if not email or not password: raise HTTPException(status_code=400, detail="Missing email or password") verifier, challenge = generate_pkce() register_url = f"{GEL_AUTH_BASE_URL}/register" register_response = httpx.post(register_url, json={ "challenge": challenge, "email": email, "password": password, "provider": "builtin::local_emailpassword", "verify_url": "http://localhost:8000/auth/verify", }) if register_response.status_code != 200 and register_response.status_code != 201: return JSONResponse(status_code=400, content={"message": "Registration failed"}) code = register_response.json().get("code") token_url = f"{GEL_AUTH_BASE_URL}/token" token_response = httpx.get(token_url, params={"code": code, "verifier": verifier}) if token_response.status_code != 200: return JSONResponse(status_code=400, content={"message": "Token exchange failed"}) auth_token = token_response.json().get("auth_token") response = JSONResponse(content={"message": "User registered"}) response.set_cookie(key="gel-auth-token",
value=auth_token, httponly=True, secure=True, samesite='strict') return response The sign-up endpoint sends a POST request to the Gel Auth server to register a new user. It also sets the auth token as an HttpOnly cookie in the response. **Sign-in endpoint** .. code-block:: python :caption: app/auth.py @router.post("/auth/signin") async def handle_signin(request: Request): body = await request.json() email = body.get("email") password = body.get("password") provider = body.get("provider") if not email or not password or not provider: raise HTTPException(status_code=400, detail="Missing email, password, or provider.") verifier, challenge = generate_pkce() authenticate_url = f"{GEL_AUTH_BASE_URL}/authenticate" response = httpx.post(authenticate_url, json={ "challenge": challenge, "email": email, "password": password, "provider": provider, }) if response.status_code != 200: return JSONResponse(status_code=400, content={"message": "Authentication failed"}) code = response.json().get("code") token_url = f"{GEL_AUTH_BASE_URL}/token" token_response = httpx.get(token_url, params={"code": code, "verifier": verifier}) if token_response.status_code != 200: return JSONResponse(status_code=400, content={"message": "Token exchange failed"}) auth_token = token_response.json().get("auth_token") response = JSONResponse(content={"message": "Authentication successful"}) response.set_cookie(key="gel-auth-token", value=auth_token, httponly=True, secure=True, samesite='strict') return response The sign-in endpoint sends a POST request to the Gel Auth server to authenticate a user. It then retrieves the code from the response and exchanges it for an auth token. The token is set as an HttpOnly cookie in the response. **Add the auth endpoints to the FastAPI application** Finally, add the auth endpoints to the FastAPI application: .. 
code-block:: python-diff :caption: app/main.py + from app import auth + fast_api.include_router(auth.router) Creating a new user in the sign-up endpoint ------------------------------------------- Now, let's automatically create a new user in the database when a user signs up. We'll use the ``create_user_async_edgeql`` query we generated earlier to achieve this, but we'll need to modify it slightly to link it to the |Gel| Auth identity. First, let's update the Gel schema to add a new ``identity`` link to the ``User`` type to store the Gel Auth identity, along with a new ``current_user`` global. .. code-block:: sdl-diff :caption: dbschema/default.gel + global current_user := assert_single( + (( + select User + filter .identity = global ext::auth::ClientTokenIdentity + )) + ); type User extending Auditable { + required identity: ext::auth::Identity; required name: str { constraint exclusive; constraint max_len_value(50); }; } After updating the schema, run the following commands to apply the changes: .. code-block:: bash $ gel migration create $ gel migrate Next, update the ``create_user_async_edgeql`` query to include the identity: .. code-block:: edgeql-diff :caption: app/queries/create_user.edgeql select ( insert User { name := $name, + identity := $identity_id, }) { name, created_at, }; Run ``gel-py`` to generate the new function. Now, let's update the sign-up endpoint to create a new user in the database. We need to do a few things: 1. Import ``gel``. 2. Create a Gel client. 3. Get the identity ID from the Gel Auth server response. 4. Create a new user in the database using the ``create_user_async_edgeql`` query. ..
code-block:: python-diff + import gel + from .queries import create_user_async_edgeql as create_user_qry + client = gel.create_async_client() @router.post("/auth/signup") async def handle_signup(request: Request): body = await request.json() email = body.get("email") + name = body.get("name") password = body.get("password") - if not email or not password: + if not email or not password or not name: - raise HTTPException(status_code=400, detail="Missing email or password.") + raise HTTPException(status_code=400, detail="Missing email, password, or name.") verifier, challenge = generate_pkce() register_url = f"{GEL_AUTH_BASE_URL}/register" register_response = httpx.post(register_url, json={ "challenge": challenge, "email": email, "password": password, "provider": "builtin::local_emailpassword", "verify_url": "http://localhost:8000/auth/verify", }) if register_response.status_code != 200 and register_response.status_code != 201: return JSONResponse(status_code=400, content={"message": "Registration failed"}) code = register_response.json().get("code") token_url = f"{GEL_AUTH_BASE_URL}/token" token_response = httpx.get(token_url, params={"code": code, "verifier": verifier}) if token_response.status_code != 200: return JSONResponse(status_code=400, content={"message": "Token exchange failed"}) auth_token = token_response.json().get("auth_token") + identity_id = token_response.json().get("identity_id") + try: + created_user = await create_user_qry.create_user(client, name=name, identity_id=identity_id) + except gel.errors.ConstraintViolationError: + raise HTTPException( + status_code=400, + detail={"error": f"User with email '{email}' already exists."}, + ) response = JSONResponse(content={"message": "User registered"}) response.set_cookie(key="gel-auth-token", value=auth_token, httponly=True, secure=True, samesite='strict') return response You can now test the sign-up endpoint by sending a POST request to ``http://localhost:8000/auth/signup`` with the following payload: ..
code-block:: json { "email": "jonathan@example.com", "name": "Jonathan Harker", "password": "password" } If the request is successful, you should see a response with the message ``User registered``. Wrapping up =========== Now you have a fully functioning events API in FastAPI backed by Gel. If you want to see all the source code for the completed project, you'll find it in `our examples repo `_. We also have a separate example that demonstrates how to integrate Gel Auth with FastAPI in the same repo. Check it out `here `_. If you're stuck or if you just want to show off what you've built, come talk to us `on Discord `_. It's a great community of helpful folks, all passionate about being part of the next generation of databases. ================================================ FILE: docs/resources/guides/tutorials/rest_apis_with_flask.rst ================================================ .. _ref_guide_rest_apis_with_flask: ===== Flask ===== :edb-alt-title: Building a REST API with Gel and Flask The Gel Python client makes it easy to integrate Gel into your preferred web development stack. In this tutorial, we'll see how you can quickly start building RESTful APIs with `Flask `_ and |Gel|. We'll build a simple movie organization system where you'll be able to fetch, create, update, and delete *movies* and *movie actors* via RESTful API endpoints. Prerequisites ============= Before we start, make sure you've :ref:`installed ` the |gelcmd| command-line tool. Here, we'll use Python 3.10 and a few of its latest features while building the APIs. A working version of this tutorial can be found `on Github `_. Install the dependencies ^^^^^^^^^^^^^^^^^^^^^^^^ To follow along, clone the repository and head over to the ``flask-crud`` directory. .. code-block:: bash $ git clone git@github.com:geldata/gel-examples.git $ cd gel-examples/flask-crud Create a Python 3.10 virtual environment, activate it, and install the dependencies with this command: .. 
code-block:: bash $ python -m venv myvenv $ source myvenv/bin/activate $ pip install gel flask 'httpx[cli]' Initialize the database ^^^^^^^^^^^^^^^^^^^^^^^ Now, let's initialize a Gel project. From the project's root directory: .. code-block:: bash $ gel project init Initializing project... Specify the name of Gel instance to use with this project [default: flask_crud]: > flask_crud Do you want to start instance automatically on login? [y/n] > y Checking Gel versions... Once you've answered the prompts, a new Gel instance called ``flask_crud`` will be created and started. Connect to the database ^^^^^^^^^^^^^^^^^^^^^^^ Let's test that we can connect to the newly started instance. To do so, run: .. code-block:: bash $ gel You should be connected to the database instance and able to see a prompt similar to this: :: Gel x.x (repl x.x) Type \help for help, \quit to quit. gel> You can start writing queries here. However, the database is currently empty. Let's start designing the data model. Schema design ============= The movie organization system will have two object types: **movies** and **actors**. Each *movie* can have links to multiple *actors*. The goal is to create API endpoints that'll allow us to fetch, create, update, and delete the objects while maintaining their relationships. |Gel| allows us to declaratively define the structure of the objects. The schema lives inside a |.gel| file in the ``dbschema`` directory. It's common to declare the entire schema in a single file :dotgel:`dbschema/default`. This is how our datatypes look: ..
code-block:: sdl # dbschema/default.gel module default { abstract type Auditable { property created_at -> datetime { readonly := true; default := datetime_current(); } } type Actor extending Auditable { required property name -> str { constraint max_len_value(50); } property age -> int16 { constraint min_value(0); constraint max_value(100); } property height -> int16 { constraint min_value(0); constraint max_value(300); } } type Movie extending Auditable { required property name -> str { constraint max_len_value(50); } property year -> int16 { constraint min_value(1850); }; multi link actors -> Actor; } } Here, we've defined an ``abstract`` type called ``Auditable`` to take advantage of Gel's schema mixin system. This allows us to add a ``created_at`` property to multiple types without repeating ourselves. The ``Actor`` type extends ``Auditable`` and inherits the ``created_at`` property as a result. This property is auto-filled via the ``datetime_current`` function. Along with the inherited property, the ``Actor`` type also defines a few additional properties called ``name``, ``age``, and ``height``. The constraints on the properties make sure that actor names can't be longer than 50 characters, age must be between 0 and 100 years, and height must be between 0 and 300 centimeters. We also define a ``Movie`` type that extends the ``Auditable`` abstract type. It also contains some additional concrete properties and links: ``name``, ``year``, and an optional multi-link called ``actors`` which refers to the ``Actor`` objects. Build the API endpoints ======================= The API endpoints are defined in the ``app`` directory. The directory structure looks as follows: :: app ├── __init__.py ├── actors.py ├── main.py └── movies.py The ``actors.py`` and ``movies.py`` modules contain the code to build the ``Actor`` and ``Movie`` APIs respectively. The ``main.py`` module then registers all the endpoints and exposes them to the webserver.
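The endpoints that write ``Actor`` objects need to respect the constraints declared in the schema section. As a rough sketch only, those constraints could be mirrored in a small pure-Python helper that an endpoint calls before it touches the database. The ``validate_actor`` function and the three constants are hypothetical names introduced for illustration; they are not part of the example repository or of the Gel client.

```python
from __future__ import annotations

# Hypothetical helper, not part of the tutorial's code base. The limits
# below mirror the constraints on ``Actor`` in dbschema/default.gel.
NAME_MAX_LEN = 50
AGE_RANGE = (0, 100)  # years
HEIGHT_RANGE = (0, 300)  # centimeters


def validate_actor(payload: dict) -> str | None:
    """Return an error message, or None if the payload fits the schema."""
    name = payload.get("name")
    if not name:
        return "Field 'name' is required."
    if len(name) > NAME_MAX_LEN:
        return f"Field 'name' cannot be longer than {NAME_MAX_LEN} characters."

    # ``age`` and ``height`` are optional, so they are only checked when
    # present. Comparing against None (not truthiness) means 0 is checked too.
    for field, (low, high) in (("age", AGE_RANGE), ("height", HEIGHT_RANGE)):
        value = payload.get(field)
        if value is not None and not low <= value <= high:
            return f"Field '{field}' must be between {low} and {high}."
    return None
```

Keeping the limits in one place like this is a design choice rather than a requirement: the database enforces the same constraints regardless, so the helper only lets the API answer with a readable HTTP 400 (bad request) before a query is sent.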
Fetch actors ^^^^^^^^^^^^ Since the ``Actor`` type is simpler, we'll start with that. Let's create a ``GET /actors`` endpoint so that we can see the ``Actor`` objects saved in the database. You can create the API in Flask like this: .. code-block:: python # flask-crud/app/actors.py from __future__ import annotations import json from http import HTTPStatus import gel from flask import Blueprint, request actor = Blueprint("actor", __name__) client = gel.create_client() @actor.route("/actors", methods=["GET"]) def get_actors() -> tuple[dict, int]: filter_name = request.args.get("filter_name") if not filter_name: actors = client.query_json( """ select Actor { name, age, height } """ ) else: actors = client.query_json( """ select Actor { name, age, height } filter .name = <str>$filter_name """, filter_name=filter_name, ) response_payload = {"result": json.loads(actors)} return response_payload, HTTPStatus.OK The ``Blueprint`` instance does the actual work of exposing the API. We also create a blocking Gel client instance to communicate with the database. By default, this API will return a list of actors, but you can also filter the objects by name. In the ``get_actors`` function, we perform the database query via the ``gel`` client. Here, the ``client.query_json`` method conveniently returns JSON-serialized objects. We deserialize the returned data into the ``response_payload`` dictionary and then return it. Flask then takes care of the final JSON serialization. This endpoint is exposed to the server in the ``main.py`` module. Here's the content of the module: .. code-block:: python # flask-crud/app/main.py from __future__ import annotations from flask import Flask from app.actors import actor from app.movies import movie app = Flask(__name__) app.register_blueprint(actor) app.register_blueprint(movie) To test the endpoint, go to the ``flask-crud`` directory and run: .. 
code-block:: bash $ export FLASK_APP=app.main:app && flask run --reload This will start the development server and make it accessible via port 5000. Earlier, we installed the `HTTPx `_ client library to make HTTP requests programmatically. It also comes with a neat command-line tool that we'll use to test our API. While the development server is running, on a new console, run: .. code-block:: bash $ httpx -m GET http://localhost:5000/actors You'll see the following output on the console: :: HTTP/1.1 200 OK Server: Werkzeug/2.1.1 Python/3.10.4 Date: Wed, 27 Apr 2022 18:58:38 GMT Content-Type: application/json Content-Length: 2 { "result": [] } Our request yielded an empty list because the database is currently empty. Let's create the ``POST /actors`` endpoint to start saving actors in the database. Create actor ^^^^^^^^^^^^ The POST endpoint can be built similarly: .. code-block:: python # flask-crud/app/actors.py ... @actor.route("/actors", methods=["POST"]) def post_actor() -> tuple[dict, int]: incoming_payload = request.json # Data validation. if not incoming_payload: return { "error": "Bad request" }, HTTPStatus.BAD_REQUEST if not (name := incoming_payload.get("name")): return { "error": "Field 'name' is required." }, HTTPStatus.BAD_REQUEST if len(name) > 50: return { "error": "Field 'name' cannot be longer than 50 " "characters." }, HTTPStatus.BAD_REQUEST if age := incoming_payload.get("age"): if not 0 <= age <= 100: return { "error": "Field 'age' must be between 0 " "and 100." }, HTTPStatus.BAD_REQUEST if height := incoming_payload.get("height"): if not 0 <= height <= 300: return { "error": "Field 'height' must be between 0 and " "300 cm." }, HTTPStatus.BAD_REQUEST # Create object. 
actor = client.query_single_json( """ with name := <str>$name, age := <optional int16>$age, height := <optional int16>$height select ( insert Actor { name := name, age := age, height := height } ){ name, age, height }; """, name=name, age=age, height=height, ) response_payload = {"result": json.loads(actor)} return response_payload, HTTPStatus.CREATED In the above snippet, we perform data validation in the conditional blocks and then make the query to create the object in the database. For now, we'll only allow creating a single object per request. The ``client.query_single_json`` method ensures that we're creating and returning only one object. Inside the query string, notice how we're using ``<optional>`` casts to deal with the optional fields. If the user doesn't provide a value for an optional field like ``age`` or ``height``, it'll default to ``null``. To test it out, make a request as follows: .. code-block:: bash $ httpx -m POST http://localhost:5000/actors \ -j '{"name" : "Robert Downey Jr."}' The output should look similar to this: :: HTTP/1.1 201 CREATED ... { "result": { "age": null, "height": null, "name": "Robert Downey Jr." } } Before we move on to the next step, create 2 more actors called ``Chris Evans`` and ``Natalie Portman``. Now that we have some data in the database, let's make a ``GET`` request to see the objects: .. code-block:: bash $ httpx -m GET http://localhost:5000/actors The response looks as follows: :: HTTP/1.1 200 OK ... { "result": [ { "age": null, "height": null, "name": "Robert Downey Jr." }, { "age": null, "height": null, "name": "Chris Evans" }, { "age": null, "height": null, "name": "Natalie Portman" } ] } You can filter the output of the ``GET /actors`` endpoint by ``name``. To do so, use the ``filter_name`` query parameter like this: .. code-block:: bash $ httpx -m GET http://localhost:5000/actors \ -p filter_name "Robert Downey Jr." Doing this will only display the data of a single object: :: HTTP/1.1 200 OK { "result": [ { "age": null, "height": null, "name": "Robert Downey Jr." 
} ] } Once you've done that, we can move on to the next step of building the ``PUT /actors`` endpoint to update the actor data. Update actor ^^^^^^^^^^^^ It can be built like this: .. code-block:: python # flask-crud/app/actors.py # ... @actor.route("/actors", methods=["PUT"]) def put_actors() -> tuple[dict, int]: incoming_payload = request.json filter_name = request.args.get("filter_name") # Data validation. if not incoming_payload: return { "error": "Bad request" }, HTTPStatus.BAD_REQUEST if not filter_name: return { "error": "Query parameter 'filter_name' must " "be provided", }, HTTPStatus.BAD_REQUEST if (name:=incoming_payload.get("name")) and len(name) > 50: return { "error": "Field 'name' cannot be longer than " "50 characters." }, HTTPStatus.BAD_REQUEST if age := incoming_payload.get("age"): if age <= 0: return { "error": "Field 'age' cannot be less than " "or equal to 0." }, HTTPStatus.BAD_REQUEST if height := incoming_payload.get("height"): if not 0 <= height <= 300: return { "error": "Field 'height' must be between 0 " "and 300 cm." }, HTTPStatus.BAD_REQUEST # Update object. actors = client.query_json( """ with filter_name := <str>$filter_name, name := <optional str>$name, age := <optional int16>$age, height := <optional int16>$height select ( update Actor filter .name = filter_name set { name := name ?? .name, age := age ?? .age, height := height ?? .height } ){ name, age, height };""", filter_name=filter_name, name=name, age=age, height=height, ) response_payload = {"result": json.loads(actors)} return response_payload, HTTPStatus.OK Here, we isolate the object that we want to update by filtering the actors with the ``filter_name`` parameter. For example, if you wanted to update the properties of ``Robert Downey Jr.``, the value of the ``filter_name`` query parameter would be ``Robert Downey Jr.``. The coalesce operator ``??`` in the query string lets the API user update only some properties of the target object while the other properties keep their existing values. 
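To make the effect of ``??`` concrete, here is a rough Python analogy of the update logic. It is only an illustration of the semantics, not how Gel evaluates the query, and the names ``coalesce`` and ``apply_update`` are made up for this sketch:

```python
def coalesce(new_value, current_value):
    """Mimic ``new ?? current``: fall back to the current value when the
    new one is absent (None stands in for an empty set here)."""
    return current_value if new_value is None else new_value


def apply_update(current: dict, payload: dict) -> dict:
    """Return a copy of ``current`` where only the fields present in
    ``payload`` are replaced."""
    return {
        field: coalesce(payload.get(field), value)
        for field, value in current.items()
    }


stored = {"name": "Robert Downey Jr.", "age": None, "height": None}
updated = apply_update(stored, {"age": 57, "height": 173})
```

In this sketch ``updated`` keeps the stored ``name`` because the payload does not supply one, which is the same behavior the ``set`` clause above relies on.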
The following command updates the ``age`` and ``height`` of ``Robert Downey Jr.``. .. code-block:: bash $ httpx -m PUT http://localhost:5000/actors \ -p filter_name "Robert Downey Jr." \ -j '{"age": 57, "height": 173}' This will return: :: HTTP/1.1 200 OK ... { "result": [ { "age": 57, "height": 173, "name": "Robert Downey Jr." } ] } Delete actor ^^^^^^^^^^^^ Another API that we'll need to cover is the ``DELETE /actors`` endpoint. It lets us look up the target object by name and delete it. The code looks similar to the ones you've already seen: .. code-block:: python # flask-crud/app/actors.py ... @actor.route("/actors", methods=["DELETE"]) def delete_actors() -> tuple[dict, int]: if not (filter_name := request.args.get("filter_name")): return { "error": "Query parameter 'filter_name' must " "be provided", }, HTTPStatus.BAD_REQUEST try: actors = client.query_json( """select ( delete Actor filter .name = <str>$filter_name ) {name} """, filter_name=filter_name, ) except gel.errors.ConstraintViolationError: return ( { "error": f"Cannot delete '{filter_name}'. " "Actor is associated with at least one movie." }, HTTPStatus.BAD_REQUEST, ) response_payload = {"result": json.loads(actors)} return response_payload, HTTPStatus.OK This endpoint deletes the requested actor if the actor isn't attached to any movie. If the targeted object is attached to a movie, the API returns an HTTP 400 (bad request) error and refuses to delete the object. To delete ``Natalie Portman``, on your console, run: .. code-block:: bash $ httpx -m DELETE http://localhost:5000/actors \ -p filter_name "Natalie Portman" That'll return: :: HTTP/1.1 200 OK ... { "result": [ { "name": "Natalie Portman" } ] } Now let's move on to building the ``Movie`` API. Create movie ^^^^^^^^^^^^ Here's how we'll implement the ``POST /movies`` endpoint: .. 
code-block:: python # flask-crud/app/movies.py from __future__ import annotations import json from http import HTTPStatus import gel from flask import Blueprint, request movie = Blueprint("movie", __name__) client = gel.create_client() @movie.route("/movies", methods=["POST"]) def post_movie() -> tuple[dict, int]: incoming_payload = request.json # Data validation. if not incoming_payload: return { "error": "Bad request" }, HTTPStatus.BAD_REQUEST if not (name := incoming_payload.get("name")): return { "error": "Field 'name' is required." }, HTTPStatus.BAD_REQUEST if len(name) > 50: return { "error": "Field 'name' cannot be longer than " "50 characters." }, HTTPStatus.BAD_REQUEST if year := incoming_payload.get("year"): if year < 1850: return { "error": "Field 'year' cannot be less " "than 1850." }, HTTPStatus.BAD_REQUEST actor_names = incoming_payload.get("actor_names") # Create object. movie = client.query_single_json( """ with name := <str>$name, year := <optional int16>$year, actor_names := <optional array<str>>$actor_names select ( insert Movie { name := name, year := year, actors := ( select Actor filter .name in array_unpack(actor_names) ) } ){ name, year, actors: {name, age, height} }; """, name=name, year=year, actor_names=actor_names, ) response_payload = {"result": json.loads(movie)} return response_payload, HTTPStatus.CREATED Like the ``POST /actors`` API, conditional blocks validate the shape of the incoming data and the ``client.query_single_json`` method creates the object in the database. EdgeQL allows us to insert an object and select its fields in a single query. One thing that's different here is that the ``POST /movies`` API also accepts an optional field called ``actor_names`` where the user can provide an array of actor names. The backend will associate the actors with the movie object if those actors exist in the database. Here's how you'd create a movie: .. lint-off .. 
code-block:: bash $ httpx -m POST http://localhost:5000/movies \ -j '{ "name": "The Avengers", "year": 2012, "actor_names": [ "Robert Downey Jr.", "Chris Evans" ] }' .. lint-on That'll return: :: HTTP/1.1 201 CREATED ... { "result": { "actors": [ { "age": null, "height": null, "name": "Chris Evans" }, { "age": 57, "height": 173, "name": "Robert Downey Jr." } ], "name": "The Avengers", "year": 2012 } } Additional movie endpoints ^^^^^^^^^^^^^^^^^^^^^^^^^^ The implementation of the ``GET /movie``, ``PATCH /movie`` and ``DELETE /movie`` endpoints are provided in the sample codebase in ``app/movies.py``. But try to write them on your own using the Actor endpoints as a starting point! Once you're done, you should be able to fetch a movie by its title from your database with the ``filter_name`` parameter and the GET API as follows: .. code-block:: bash $ httpx -m GET http://localhost:5000/movies \ -p 'filter_name' 'The Avengers' That'll return: :: HTTP/1.1 200 OK ... { "result": [ { "actors": [ { "age": null, "name": "Chris Evans" }, { "age": 57, "name": "Robert Downey Jr." } ], "name": "The Avengers", "year": 2012 } ] } Conclusion ========== While building REST APIs, the Gel client allows you to leverage Gel with any microframework of your choice. Whether it's `FastAPI `_, `Flask `_, `AIOHTTP `_, `Starlette `_, or `Tornado `_, the core workflow is quite similar to the one demonstrated above; you'll query and serialize data with the client and then return the payload for your framework to process. ================================================ FILE: docs/resources/guides/tutorials/trpc.rst ================================================ .. _ref_guide_trpc: ==== tRPC ==== :edb-alt-title: Integrating Gel with tRPC This guide explains how to integrate **Gel** with **tRPC** for a modern, type-safe API. We'll cover setting up database interactions, API routing, and implementing authentication, all while ensuring type safety across the client and server. 
You can reference the following repositories for more context: - `create-t3-turbo-gel `_ - A monorepo template using the `T3 stack `_, `Turborepo `_, and Gel. - `LookFeel Project `_ - A real-world example using **Gel** and **tRPC**. Step 1: Gel setup ================= |Gel| will serve as the database layer for your application. Install and initialize Gel -------------------------- To initialize **Gel**, run the following command using your preferred package manager: .. code-block:: bash $ pnpm dlx gel project init # or `npx gel project init` This will create a Gel project and set up a schema to start with. Define the Gel Schema --------------------- The previous command generated a schema file in the ``dbschema`` directory. Here's an example schema that defines a ``User`` model: .. code-block:: sdl :caption: dbschema/default.gel module default { type User { required name: str; required email: str; } } Apply schema migrations ----------------------- Once schema changes are made, apply migrations with: .. code-block:: bash $ pnpm dlx gel migration create # or npx gel migration create $ pnpm dlx gel migration apply # or npx gel migration apply Step 2: Configure Gel Client ============================ To interact with **Gel** from your application, you need to configure the client. Install Gel Client ------------------ First, install the **Gel** client using your package manager: .. code-block:: bash $ pnpm add gel $ # or yarn add gel $ # or npm install gel $ # or bun add gel Then, create a client instance in a ``gel.ts`` file: .. code-block:: typescript :caption: src/gel.ts import { createClient } from 'gel'; const gelClient = createClient(); export default gelClient; This client will be used to interact with the database and execute queries. Step 3: tRPC setup ================== **tRPC** enables type-safe communication between the frontend and backend. Install tRPC dependencies ------------------------- Install the required tRPC dependencies: .. 
code-block:: bash $ pnpm add @trpc/server @trpc/client $ # or yarn add @trpc/server @trpc/client $ # or npm install @trpc/server @trpc/client $ # or bun add @trpc/server @trpc/client If you're using React and would like to use React Query with tRPC, also install a wrapper around the `@tanstack/react-query `_. .. code-block:: bash $ pnpm add @trpc/react-query $ # or yarn add @trpc/react-query $ # or npm install @trpc/react-query $ # or bun add @trpc/react-query Define the tRPC Router ----------------------- Here's how to define a simple tRPC query that interacts with **Gel**: .. code-block:: typescript :caption: server/routers/_app.ts import { initTRPC } from '@trpc/server'; import gelClient from './gel'; const t = initTRPC.create(); export const appRouter = t.router({ getUsers: t.procedure.query(async () => { const users = await gelClient.query('SELECT User { name, email }'); return users; }), }); export type AppRouter = typeof appRouter; This example defines a query that fetches user data from Gel, ensuring type safety in both the query and response. Step 4: Use tRPC Client ======================== Now that the server is set up, you can use the tRPC client to interact with the API from the frontend. We will demonstrate how to integrate tRPC with **Next.js** and **Express**. With Next.js ------------ If you're working with **Next.js**, here's how to integrate **tRPC**: Create a tRPC API Handler ~~~~~~~~~~~~~~~~~~~~~~~~~ Inside ``api/trpc/[trpc].ts``, create the following handler to connect **tRPC** with Next.js: .. code-block:: typescript :caption: pages/api/trpc/[trpc].ts import { createNextApiHandler } from '@trpc/server/adapters/next'; import { appRouter } from '../../../server/routers/_app'; export default createNextApiHandler({ router: appRouter, }); Create a tRPC Client ~~~~~~~~~~~~~~~~~~~~ Next, create a **tRPC** client to interact with the API: .. 
code-block:: typescript :caption: utils/trpc.ts import { createTRPCReact } from "@trpc/react-query"; import type { AppRouter } from '../server/routers/_app'; export const trpc = createTRPCReact<AppRouter>(); Client-Side Usage in Next.js ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ You can then use **tRPC** hooks to query the API from the client: .. code-block:: typescript :caption: components/UsersComponent.tsx import { trpc } from '../utils/trpc'; const UsersComponent = () => { const { data, isLoading } = trpc.getUsers.useQuery(); if (isLoading) return
<div>Loading...</div>
; return ( <div>
{data?.map(user => (
<p key={user.email}>
{user.name}
</p>
))} </div>
); }; export default UsersComponent; Alternative Path: Use tRPC with Express --------------------------------------- If you're not using **Next.js**, here's how you can integrate **tRPC** with **Express**. Set up Express server with tRPC ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Here's how you can create an Express server and integrate **tRPC**: .. code-block:: typescript import express from 'express'; import { appRouter } from './routers/_app'; import * as trpcExpress from '@trpc/server/adapters/express'; const app = express(); app.use( '/trpc', trpcExpress.createExpressMiddleware({ router: appRouter, }) ); app.listen(4000, () => { console.log('Server is running on port 4000'); }); Client-side usage ----------------- In non-Next.js apps, use the tRPC client to interact with the server: .. code-block:: typescript import { createTRPCClient, httpBatchLink } from '@trpc/client'; import type { AppRouter } from './routers/_app'; const trpc = createTRPCClient<AppRouter>({ links: [ httpBatchLink({ url: 'http://localhost:4000/trpc', }), ], }); async function fetchUsers() { const users = await trpc.getUsers.query(); console.log(users); } Step 5: Set up authentication with Gel Auth =========================================== In this section, we will cover how to integrate **Gel Auth** with **tRPC** and context in both **Next.js** and **Express** environments. This will ensure that user authentication is handled securely and that both server-side and client-side tRPC calls can access the user's session. Gel Auth with tRPC and tRPC context in Next.js ---------------------------------------------- In **Next.js**, integrating **Gel Auth** with **tRPC** involves creating a context that provides the user session and Gel client to the tRPC API. 1. **Initialize Gel Client and Auth** First, initialize the **Gel** client and **Gel Auth**: .. 
code-block:: typescript import { createClient } from "gel"; import createAuth from "@gel/auth-nextjs/app"; // Initialize Gel client export const gelClient = createClient(); // Initialize Gel Auth export const auth = createAuth(gelClient, { baseUrl: process.env.VERCEL_ENV === "production" ? "https://production.yourapp.com" : "http://localhost:3000", }); 2. **Create tRPC Context** The **tRPC** context provides the Gel Auth session to the tRPC procedures: .. code-block:: typescript :caption: src/trpc.ts import { initTRPC } from '@trpc/server'; import { headers } from "next/headers"; import { auth } from "src/gel.ts"; // Create tRPC context with session and Gel client export const createTRPCContext = async () => { const session = await auth.getSession(); // Retrieve session from Gel Auth return { session, // Pass the session to the context }; }; // Initialize tRPC with context const t = initTRPC.context().create({}); 3. **Use tRPC Context in API Handler** In **Next.js**, set up an API handler to connect your **tRPC router** with the context: .. code-block:: typescript :caption: pages/api/trpc/[trpc].ts import { createNextApiHandler } from '@trpc/server/adapters/next'; import { createTRPCContext } from 'src/trpc.ts'; import { appRouter } from 'src/routers/_app'; export default createNextApiHandler({ router: appRouter, // Your tRPC router createContext: createTRPCContext, }); 4. **Example tRPC Procedure** You can now write procedures in your tRPC router, making use of the **Gel Auth** session and the **Gel** client: .. 
code-block:: typescript export const appRouter = t.router({ getUserData: t.procedure.query(async ({ ctx }) => { if (!(await ctx.session.isSignedIn())) { throw new Error("Not authenticated"); } // Fetch data from Gel using the authenticated client const userData = await ctx.session.client.query(` select User { name, email } `); return userData; }), }); Gel Auth with tRPC and Context in Express ----------------------------------------- In **Express**, the process involves setting up middleware to manage the authentication and context for tRPC procedures. 1. **Initialize Gel Client and Auth for Express** Just like in **Next.js**, you first initialize the **Gel** client and **Gel Auth**: .. code-block:: typescript import { createClient } from "gel"; import createExpressAuth from "@gel/auth-express"; // Initialize Gel client const gelClient = createClient(); // Initialize Gel Auth for Express export const auth = createExpressAuth(gelClient, { baseUrl: `http://localhost:${process.env.PORT || 3000}`, }); 2. **Create tRPC Context Middleware for Express** In **Express**, create middleware to pass the authenticated session and Gel client to the tRPC context: .. code-block:: typescript import { type AuthRequest, type Response, type NextFunction } from "express"; // Middleware to set up tRPC context in Express export const createTRPCContextMiddleware = async ( req: AuthRequest, res: Response, next: NextFunction ) => { const session = req.auth?.session(); // Get authenticated session req.context = { session, // Add session to context gelClient, // Add Gel client to context }; next(); }; 3. **Set up tRPC Router in Express** Use the **tRPC router** in **Express** by including the context middleware and **Gel Auth** middleware: .. 
code-block:: typescript import express from "express"; import { appRouter } from "./path-to-router"; import { auth } from "./path-to-auth"; import { createTRPCContextMiddleware } from "./path-to-context"; import { createExpressMiddleware } from "@trpc/server/adapters/express"; const app = express(); // Gel Auth middleware to handle sessions app.use(auth.middleware); // Custom middleware to pass tRPC context app.use(createTRPCContextMiddleware); // tRPC route setup app.use( "/trpc", createExpressMiddleware({ router: appRouter, createContext: (req) => req.context, // Use context from middleware }) ); app.listen(4000, () => { console.log('Server running on port 4000'); }); 4. **Example tRPC Procedure in Express** Once the context is set, you can define tRPC procedures that use both the session and Gel client: .. code-block:: typescript export const appRouter = t.router({ getUserData: t.procedure.query(async ({ ctx }) => { if (!(await ctx.session.isSignedIn())) { throw new Error("Not authenticated"); } // Fetch data from Gel using the authenticated client const userData = await ctx.session.client.query(` select User { name, email } `); return userData; }), }); Conclusion ---------- By integrating **Gel Auth** into the tRPC context, you ensure that authenticated sessions are securely passed to API procedures, enabling user authentication and protecting routes. You can also reference these projects for further examples: - `create-t3-turbo-gel `_ - A monorepo template using the `T3 stack `_, `Turborepo `_, and Gel. - `LookFeel Project `_ - A real-world example using **Gel** and **tRPC**. ================================================ FILE: docs/resources/index.rst ================================================ ========= Resources ========= .. 
toctree:: :maxdepth: 3 :hidden: upgrading guides/index protocol/index cheatsheets/index changelog/index ================================================ FILE: docs/resources/protocol/dataformats.rst ================================================ .. _ref_proto_dataformats: ================= Data wire formats ================= This section describes the data wire format of standard Gel types. .. _ref_protocol_fmt_array: Sets and array<> ================ Set and array values are represented as the following structure: .. code-block:: c struct SetOrArrayValue { // Number of dimensions, currently must // always be 0 or 1. 0 indicates an empty set or array. int32 ndims; // Reserved. int32 reserved0; // Reserved. int32 reserved1; // Dimension data. Dimension dimensions[ndims]; // Element data, the number of elements // in this array is the sum of dimension sizes: // sum((d.upper - d.lower + 1) for d in dimensions) Element elements[]; }; struct Dimension { // Upper dimension bound, inclusive, // number of elements in the dimension // relative to the lower bound. int32 upper; // Lower dimension bound, always 1. int32 lower; }; struct Element { // Encoded element data length in bytes. int32 length; // Element data. uint8 data[length]; }; Note: zero-length arrays (and sets) are represented as a 12-byte value where ``ndims`` is equal to zero, regardless of the shape in the type descriptor. Sets of arrays are a special case. Every array within a set is wrapped in an Envelope. The full structure follows: .. code-block:: c struct SetOfArrayValue { // Number of dimensions, currently must // always be 0 or 1. 0 indicates an empty set. int32 ndims; // Reserved. int32 reserved0; // Reserved. int32 reserved1; // Dimension data. Same layout as above. 
Dimension dimensions[ndims]; // Envelope data, the number of elements // in this array is the sum of dimension sizes: // sum((d.upper - d.lower + 1) for d in dimensions) Envelope elements[]; }; struct Envelope { // Encoded envelope element length in bytes. int32 length; // Number of elements, currently must // always be 1. int32 nelems; // Reserved. int32 reserved; // Element data. Same layout as above. Element element[nelems]; }; .. _ref_protocol_fmt_tuple: tuple<>, namedtuple<>, and object<> ==================================== Tuple, namedtuple, and object values are represented as the following structure: .. code-block:: c struct TupleOrNamedTupleOrObjectValue { // Number of elements int32 nelems; // Element data. Element elements[nelems]; }; struct Element { // Reserved. int32 reserved; // Encoded element data length in bytes. int32 length; // Element data. uint8 data[length]; }; Note that for objects, ``Element.length`` can be set to ``-1``, which means an empty set. .. _ref_protocol_fmt_sparse_obj: Sparse Objects ============== Sparse object values are represented as the following structure: .. code-block:: c struct SparseObjectValue { // Number of elements int32 nelems; // Element data. Element elements[nelems]; }; struct Element { // Index of the element in the input shape. int32 index; // Encoded element data length in bytes. int32 length; // Element data. uint8 data[length]; }; .. _ref_protocol_fmt_range: Ranges ====== Range values are represented as the following structure: .. code-block:: c struct Range { // A bit mask of range definition. uint8 flags; // Lower boundary data. Boundary lower; // Upper boundary data. Boundary upper; }; struct Boundary { // Encoded boundary data length in bytes. int32 length; // Boundary data. uint8 data[length]; }; enum RangeFlag { // Empty range. EMPTY = 0x0001; // Included lower boundary. LB_INC = 0x0002; // Included upper boundary. UB_INC = 0x0004; // Infinity (excluded) lower boundary. 
LB_INF = 0x0008; // Infinity (excluded) upper boundary. UB_INF = 0x0010; }; .. _ref_protocol_fmt_uuid: std::uuid ========= :eql:type:`std::uuid` values are represented as a sequence of 16 unsigned byte values. For example, the UUID value ``b9545c35-1fe7-485f-a6ea-f8ead251abd3`` is represented as: .. code-block:: c 0xb9 0x54 0x5c 0x35 0x1f 0xe7 0x48 0x5f 0xa6 0xea 0xf8 0xea 0xd2 0x51 0xab 0xd3 .. _ref_protocol_fmt_str: std::str ======== :eql:type:`std::str` values are represented as a UTF-8 encoded byte string. For example, the ``str`` value ``'Hello! 🙂'`` is encoded as: .. code-block:: c 0x48 0x65 0x6c 0x6c 0x6f 0x21 0x20 0xf0 0x9f 0x99 0x82 .. _ref_protocol_fmt_bytes: std::bytes ========== :eql:type:`std::bytes` values are represented as is. .. _ref_protocol_fmt_int16: std::int16 ========== :eql:type:`std::int16` values are represented as two bytes, most significant byte first. For example, the ``int16`` value ``6556`` is represented as: .. code-block:: c 0x19 0x9c .. _ref_protocol_fmt_int32: std::int32 ========== :eql:type:`std::int32` values are represented as four bytes, most significant byte first. For example, the ``int32`` value ``655665`` is represented as: .. code-block:: c 0x00 0x0a 0x01 0x31 .. _ref_protocol_fmt_int64: std::int64 ========== :eql:type:`std::int64` values are represented as eight bytes, most significant byte first. For example, the ``int64`` value ``123456789987654321`` is represented as: .. code-block:: c 0x01 0xb6 0x9b 0x4b 0xe0 0x52 0xfa 0xb1 .. _ref_protocol_fmt_float32: std::float32 ============ :eql:type:`std::float32` values are represented as an IEEE 754-2008 binary 32-bit value, most significant byte first. For example, the ``float32`` value ``-15.625`` is represented as: .. code-block:: c 0xc1 0x7a 0x00 0x00 .. _ref_protocol_fmt_float64: std::float64 ============ :eql:type:`std::float64` values are represented as an IEEE 754-2008 binary 64-bit value, most significant byte first. 
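The fixed-width integer and floating-point encodings described in these sections are plain big-endian layouts, so they can be reproduced with Python's standard ``struct`` module. The following is a small sketch for experimentation, not part of any Gel client. The ``encode_*`` function names are invented for this illustration:

```python
import struct


def encode_int16(value: int) -> bytes:
    return struct.pack(">h", value)   # 2 bytes, big-endian, signed


def encode_int32(value: int) -> bytes:
    return struct.pack(">i", value)   # 4 bytes, big-endian, signed


def encode_int64(value: int) -> bytes:
    return struct.pack(">q", value)   # 8 bytes, big-endian, signed


def encode_float32(value: float) -> bytes:
    return struct.pack(">f", value)   # IEEE 754 binary32, big-endian


def encode_float64(value: float) -> bytes:
    return struct.pack(">d", value)   # IEEE 754 binary64, big-endian


# Reproduces the int16 and float32 byte sequences listed above.
print(encode_int16(6556).hex(" "))       # 19 9c
print(encode_float32(-15.625).hex(" "))  # c1 7a 00 00
```

The ``>`` prefix in each format string selects big-endian byte order, which is what "most significant byte first" means in the descriptions above.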
For example, the ``float64`` value ``-15.625`` is represented as: .. code-block:: c 0xc0 0x2f 0x40 0x00 0x00 0x00 0x00 0x00 .. _ref_protocol_fmt_decimal: std::decimal ============ :eql:type:`std::decimal` values are represented as the following structure: .. code-block:: c struct Decimal { // Number of digits in digits[], can be 0. uint16 ndigits; // Weight of first digit. int16 weight; // Sign of the value uint16 sign; // Value display scale. uint16 dscale; // base-10000 digits. uint16 digits[ndigits]; }; enum DecimalSign { // Positive value. POS = 0x0000; // Negative value. NEG = 0x4000; }; Decimal values are represented as a sequence of base-10000 *digits*. The first digit is multiplied by 10000 raised to the power *weight*, i.e. there might be up to *weight* + 1 base-10000 digits before the decimal point. Trailing zeros may be absent. It is possible to have negative weight. *dscale*, or display scale, is the nominal precision expressed as number of base-10 digits after the decimal point. It is always non-negative. *dscale* may be more than the number of physically present fractional digits, implying significant trailing zeroes. The *digits* array is padded with trailing zeros to the next group of four decimal digits (meaning that the integer and fractional parts always occupy distinct base-10000 digits). For example, the decimal value ``-15000.6250000`` is represented as: .. code-block:: c // ndigits 0x00 0x04 // weight 0x00 0x01 // sign 0x40 0x00 // dscale 0x00 0x07 // digits 0x00 0x01 0x13 0x88 0x18 0x6a 0x00 0x00 .. _ref_protocol_fmt_bool: std::bool ========= :eql:type:`std::bool` values are represented as an int8 with only two valid values: ``0x01`` for ``true`` and ``0x00`` for ``false``. .. _ref_protocol_fmt_datetime: std::datetime ============= :eql:type:`std::datetime` values are represented as a 64-bit integer, most significant byte first. The value is the number of *microseconds* between the encoded datetime and January 1st 2000, 00:00 UTC. 
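As a rough illustration of this encoding, the following Python sketch counts the microseconds since the 2000-01-01 epoch and packs them as a big-endian 64-bit integer. The ``GEL_EPOCH`` constant and the ``encode_datetime`` function are names chosen for this example and are not taken from a Gel client library:

```python
import struct
from datetime import datetime, timezone

# Epoch used by the wire format: January 1st 2000, 00:00 UTC.
GEL_EPOCH = datetime(2000, 1, 1, tzinfo=timezone.utc)


def encode_datetime(value: datetime) -> bytes:
    """Encode a timezone-aware datetime as a big-endian int64 holding
    the number of microseconds since GEL_EPOCH."""
    delta = value - GEL_EPOCH
    # Integer arithmetic avoids float rounding for large values.
    micros = (delta.days * 86_400 + delta.seconds) * 1_000_000 + delta.microseconds
    return struct.pack(">q", micros)


encoded = encode_datetime(datetime(2019, 5, 6, 12, 0, tzinfo=timezone.utc))
```

Passing a naive ``datetime`` (one without ``tzinfo``) raises ``TypeError`` in this sketch, because Python refuses to subtract naive and aware values.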
A Unix timestamp can be converted into a Gel ``datetime`` value by subtracting 946684800 (the Unix timestamp of January 1st 2000, 00:00 UTC) and scaling the result to microseconds: .. code-block:: c edb_datetime = (unix_ts - 946684800) * 1000000 For example, the ``datetime`` value ``'2019-05-06T12:00+00:00'`` is encoded as: .. code-block:: c 0x00 0x02 0x2b 0x35 0x9b 0xc4 0x10 0x00 See the :ref:`client libraries ` section for more info about how to handle different precision when encoding data. .. _ref_protocol_fmt_local_datetime: cal::local_datetime =================== :eql:type:`cal::local_datetime` values are represented as a 64-bit integer, most significant byte first. The value is the number of *microseconds* between the encoded datetime and January 1st 2000, 00:00. For example, the ``local_datetime`` value ``'2019-05-06T12:00'`` is encoded as: .. code-block:: c 0x00 0x02 0x2b 0x35 0x9b 0xc4 0x10 0x00 See the :ref:`client libraries ` section for more info about how to handle different precision when encoding data. .. _ref_protocol_fmt_local_date: cal::local_date =============== :eql:type:`cal::local_date` values are represented as a 32-bit integer, most significant byte first. The value is the number of *days* between the encoded date and January 1st 2000. For example, the ``local_date`` value ``'2019-05-06'`` is encoded as: .. code-block:: c 0x00 0x00 0x1b 0x99 .. _ref_protocol_fmt_local_time: cal::local_time =============== :eql:type:`cal::local_time` values are represented as a 64-bit integer, most significant byte first. The value is the number of *microseconds* since midnight. For example, the ``local_time`` value ``'12:10'`` is encoded as: .. code-block:: c 0x00 0x00 0x00 0x0a 0x32 0xae 0xf6 0x00 See the :ref:`client libraries ` section for more info about how to handle different precision when encoding data. .. _ref_protocol_fmt_duration: std::duration ============= The :eql:type:`std::duration` values are represented as the following structure: ..
code-block:: c struct Duration { int64 microseconds; // deprecated, is always 0 int32 days; // deprecated, is always 0 int32 months; }; For example, the ``duration`` value ``'48 hours 45 minutes 7.6 seconds'`` is encoded as: .. code-block:: c // microseconds 0x00 0x00 0x00 0x28 0xdd 0x11 0x72 0x80 // days 0x00 0x00 0x00 0x00 // months 0x00 0x00 0x00 0x00 See the :ref:`client libraries ` section for more info about how to handle different precision when encoding data. .. _ref_protocol_fmt_relative_duration: cal::relative_duration ====================== The :eql:type:`cal::relative_duration` values are represented as the following structure: .. code-block:: c struct Duration { int64 microseconds; int32 days; int32 months; }; For example, the ``cal::relative_duration`` value ``'2 years 7 months 16 days 48 hours 45 minutes 7.6 seconds'`` is encoded as: .. code-block:: c // microseconds 0x00 0x00 0x00 0x28 0xdd 0x11 0x72 0x80 // days 0x00 0x00 0x00 0x10 // months 0x00 0x00 0x00 0x1f See the :ref:`client libraries ` section for more info about how to handle different precision when encoding data. .. _ref_protocol_fmt_date_duration: cal::date_duration ================== :eql:type:`cal::date_duration` values are represented as the following structure: .. code-block:: c struct DateDuration { int64 reserved; int32 days; int32 months; }; For example, the ``cal::date_duration`` value ``'1 years 2 days'`` is encoded as: .. code-block:: c // reserved 0x00 0x00 0x00 0x00 0x00 0x00 0x00 0x00 // days 0x00 0x00 0x00 0x02 // months 0x00 0x00 0x00 0x0c .. _ref_protocol_fmt_json: std::json ========= :eql:type:`std::json` values are represented as the following structure: .. code-block:: c struct JSON { uint8 format; uint8 jsondata[]; }; *format* is currently always ``1``, and *jsondata* is a UTF-8 encoded JSON string. .. _ref_protocol_fmt_bigint: std::bigint =========== :eql:type:`std::bigint` values are represented as the following structure: .. 
code-block:: c struct BigInt { // Number of digits in digits[], can be 0. uint16 ndigits; // Weight of first digit. int16 weight; // Sign of the value uint16 sign; // Reserved value, must be zero uint16 reserved; // base-10000 digits. uint16 digits[ndigits]; }; enum BigIntSign { // Positive value. POS = 0x0000; // Negative value. NEG = 0x4000; }; BigInt values are represented as a sequence of base-10000 *digits*. The first digit is assumed to be multiplied by 10000 raised to the power of *weight*, i.e. there might be up to *weight* + 1 base-10000 digits. Trailing zeros may be absent. For example, the bigint value ``-15000`` is represented as: .. code-block:: c // ndigits 0x00 0x02 // weight 0x00 0x01 // sign 0x40 0x00 // reserved 0x00 0x00 // digits 0x00 0x01 0x13 0x88 .. _ref_protocol_fmt_memory: cfg::memory =========== :eql:type:`cfg::memory` values are represented as a number of *bytes* encoded as a 64-bit integer, most significant byte first. For example, the ``cfg::memory`` value ``123MiB`` is represented as: .. code-block:: c 0x00 0x00 0x00 0x00 0x07 0xb0 0x00 0x00 ================================================ FILE: docs/resources/protocol/dump_format.rst ================================================ Dump file format ================ This description uses the same :ref:`conventions ` as the protocol description. General Structure ----------------- A dump file is structured as follows: 1. Dump file format marker ``\xFF\xD8\x00\x00\xD8EDGEDB\x00DUMP\x00`` (17 bytes) 2. Format version number ``\x00\x00\x00\x00\x00\x00\x00\x01`` (8 bytes) 3. Header block 4. Any number of data blocks General Dump Block ------------------ Both header and data blocks are formatted as follows: .. code-block:: c struct DumpHeader { int8 mtype; // SHA1 hash sum of block data byte sha1sum[20]; // Length of message contents in bytes, // including self. int32 message_length; // Block data. Should be treated as opaque by a client.
byte data[message_length]; } Upon receiving a protocol dump data message, the dump client should: * Replace packet type: * ``@`` (0x40) → ``H`` (0x48) * ``=`` (0x3d) → ``D`` (0x44) * Prepend SHA1 checksum to the block * Append the entire dump protocol message disregarding the first byte (the message type). Header Block ------------ Format: .. code-block:: c struct DumpHeader { // Message type ('H') int8 mtype = 0x48; // SHA1 hash sum of block data byte sha1sum[20]; // Length of message contents in bytes, // including self. int32 message_length; // A set of message headers. Headers headers; // Protocol version of the dump int16 major_ver; int16 minor_ver; // Schema data string schema_ddl; // Type identifiers int32 num_types; TypeInfo types[num_types]; // Object descriptors int32 num_descriptors; ObjectDesc descriptors[num_descriptors] }; struct TypeInfo { string type_name; string type_class; byte type_id[16]; } struct ObjectDesc { byte object_id[16]; bytes description; int16 num_dependencies; byte dependency_id[num_dependencies][16]; } Known headers: * 101 ``BLOCK_TYPE`` -- block type, always "I" * 102 ``SERVER_TIME`` -- server time when dump is started as a floating point unix timestamp stringified * 103 ``SERVER_VERSION`` -- full version of server as string * 105 ``SERVER_CATALOG_VERSION`` -- the catalog version of the server, as a 64-bit integer. The catalog version is an identifier that is incremented whenever a change is made to the database layout or standard library. Data Block ---------- Format: .. code-block:: c struct DumpBlock { // Message type ('=') int8 mtype = 0x3d; // Length of message contents in bytes, // including self. int32 message_length; // A set of message headers. 
Headers headers; } Known headers: * 101 ``BLOCK_TYPE`` -- block type, always "D" * 110 ``BLOCK_ID`` -- block identifier (16 bytes of UUID) * 111 ``BLOCK_NUM`` -- integer block index stringified * 112 ``BLOCK_DATA`` -- the actual block data ================================================ FILE: docs/resources/protocol/errors.rst ================================================ .. _ref_protocol_errors: ====== Errors ====== Errors inheritance ================== Each error in Gel consists of a code, a name, and optionally tags. Errors in Gel can inherit from other errors. This is denoted by matching code prefixes. For example, ``TransactionConflictError`` (``0x_05_03_01_00``) is the parent error for ``TransactionSerializationError`` (``0x_05_03_01_01``) and ``TransactionDeadlockError`` (``0x_05_03_01_02``). The matching prefix here is ``0x_05_03_01``. When the Gel client expects a more general error and Gel returns a more specific error that inherits from the general error, the check in the client must take this into account. This can be expressed with the bitwise AND operation (the ``&`` operator in most programming languages): .. code-block:: (expected_error_code & server_error_code) == expected_error_code Note that although it is not explicitly stated in the ``edb/api/errors.txt`` file, each inherited error must contain all tags of the parent error. Given that, ``TransactionSerializationError`` and ``TransactionDeadlockError``, for example, must contain the ``SHOULD_RETRY`` tag that is defined for ``TransactionConflictError``. .. _ref_protocol_error_codes: Error codes =========== Error codes and names as specified in ``edb/api/errors.txt``: .. raw:: text :file: errors.txt ================================================ FILE: docs/resources/protocol/index.rst ================================================ .. _ref_protocol_overview: =============== Binary protocol =============== |Gel| uses a message-based binary protocol for communication between clients and servers.
The protocol is supported over TCP/IP. .. toctree:: :maxdepth: 3 :hidden: messages errors typedesc dataformats dump_format .. _ref_protocol_connecting: Connecting to Gel ================= The Gel binary protocol has two modes of operation: sockets and HTTP tunnelling. When connecting to Gel, the client can specify an accepted `ALPN Protocol`_ to use. If the client does not specify an ALPN protocol, HTTP tunnelling is assumed. Sockets ------- When using the ``edgedb-binary`` ALPN protocol, the client and server communicate over a raw TCP/IP socket, following the :ref:`message format ` and :ref:`message flow ` described below. .. _ref_http_tunnelling: HTTP Tunnelling --------------- HTTP tunnelling differs in a few ways: * Authentication is handled at ``/auth/token``. * Query execution is handled at ``/branch/{BRANCH}``. .. note:: Prior to |Gel| and |EdgeDB| 5.0 *branches* were called *databases*. If you're making a request against an older version of |EdgeDB| you should change ``/branch/`` options to ``/db/``. * Transactions are not supported. The :ref:`authentication ` phase is handled by sending ``GET`` requests to ``/auth/token`` with the ``Authorization`` header containing the authorization payload with the format: .. code-block:: Authorization: {AUTH METHOD} data={PAYLOAD} The client then reads the ``www-authenticate`` response header with the following format: .. code-block:: www-authenticate: {AUTH METHOD} {AUTH PAYLOAD} The auth payload's format is described by the auth method, usually ``SCRAM-SHA-256``. If the auth method differs from the requested method, the client should abort the authentication attempt. Once the :ref:`authentication ` phase is complete, the final response's body will contain an authorization token used to authenticate the HTTP connection. 
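The header formats above can be handled with plain string operations. The sketch below is illustrative only: the helper names are hypothetical, the payload is treated as an opaque string (its contents are defined by the auth method), and the exact, case-sensitive comparison of method names is an assumption rather than documented behavior.

.. code-block:: python

    def build_authorization(method: str, payload: str) -> str:
        """Build the value of the ``Authorization`` request header."""
        # Mirrors the documented shape: {AUTH METHOD} data={PAYLOAD}
        return f"{method} data={payload}"


    def parse_www_authenticate(value: str, requested_method: str) -> str:
        """Return the auth payload from a ``www-authenticate`` header value.

        Raises ValueError if the server answered with a different auth
        method, in which case the client should abort authentication.
        """
        # Documented shape: {AUTH METHOD} {AUTH PAYLOAD}
        method, _, payload = value.strip().partition(" ")
        if method != requested_method:
            raise ValueError(
                f"unexpected auth method {method!r}, "
                f"requested {requested_method!r}"
            )
        return payload.strip()

A client would call ``build_authorization`` before each ``GET`` request to ``/auth/token`` and ``parse_www_authenticate`` on each response, passing the payload on to its SCRAM implementation.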
The client then sends any following message to ``/branch/{BRANCH}`` (or ``/db/{DATABASE}`` if you're using a version of |EdgeDB| < 5) with the following headers: * ``X-EdgeDB-User``: The username specified in the :ref:`connection parameters `. * ``Authorization``: The authorization token received from the :ref:`authentication ` phase, prefixed by ``Bearer``. * ``Content-Type``: Always ``application/x.edgedb.v_1_0.binary``. The response should be checked to match the content type, and the body should be parsed as the :ref:`message format ` described below; multiple messages can be included in the response body, and should be parsed in order. .. _ALPN Protocol: https://github.com/geldata/rfcs/blob/master/text/1008-tls-and-alpn.rst#alpn-and-protocol-changes .. _ref_protocol_conventions: Conventions and data Types ========================== The message format descriptions in this section use C-like struct definitions to describe their layout. The structs are *packed*, i.e. there are never any alignment gaps. The following data types are used in the descriptions: .. list-table:: :class: funcoptable * - ``int8`` - 8-bit integer * - ``int16`` - 16-bit integer, most significant byte first * - ``int32`` - 32-bit integer, most significant byte first * - ``int64`` - 64-bit integer, most significant byte first * - ``uint8`` - 8-bit unsigned integer * - ``uint16`` - 16-bit unsigned integer, most significant byte first * - ``uint32`` - 32-bit unsigned integer, most significant byte first * - ``uint64`` - 64-bit unsigned integer, most significant byte first * - ``int8<T>`` or ``uint8<T>`` - an 8-bit signed or unsigned integer enumeration, where *T* denotes the name of the enumeration * - ``string`` - a UTF-8 encoded text string prefixed with its byte length as ``uint32`` * - ``bytes`` - a byte string prefixed with its length as ``uint32`` * - ``KeyValue`` - .. eql:struct:: edb.protocol.KeyValue * - ``Annotation`` - ..
eql:struct:: edb.protocol.Annotation * - ``uuid`` - an array of 16 bytes with no length prefix, equivalent to ``byte[16]`` .. _ref_message_format: Message Format ============== All messages in the Gel wire protocol have the following format: .. code-block:: c struct { uint8 message_type; int32 payload_length; uint8 payload[payload_length - 4]; }; The server and the client *MUST NOT* fragment messages, i.e. the complete message must be sent before starting a new message. It is advised that the whole message be buffered before initiating a network call (but this requirement is neither observable nor enforceable at the other side). It is also common to buffer the whole message on the receiver side before starting to process it. Errors ====== At any point the server may send an :ref:`ref_protocol_msg_error` indicating an error condition. This is implied in the message flow documentation, and only successful paths are explicitly documented. The handling of the ``ErrorResponse`` message depends on the connection phase, as well as the severity of the error. If the server is not able to recover from an error, the connection is closed immediately after an ``ErrorResponse`` message is sent. Logs ==== Similarly to ``ErrorResponse``, the server may send a :ref:`ref_protocol_msg_log` message. The client should handle the message and continue as before. .. _ref_message_flow: Message Flow ============ There are two main phases in the lifetime of a Gel connection: the connection phase, and the command phase. The connection phase is responsible for negotiating the protocol and connection parameters, including authentication. The command phase is the regular operation phase where the server is processing queries sent by the client. Connection Phase ---------------- To begin a session, a client opens a connection to the server, and sends the :ref:`ref_protocol_msg_client_handshake`. The server responds in one of three ways: 1. One of the authentication messages (see :ref:`below `); 2.
:ref:`ref_protocol_msg_server_handshake` followed by one of the authentication messages; 3. :ref:`ref_protocol_msg_error` which indicates an invalid client handshake message. :ref:`ref_protocol_msg_server_handshake` is only sent if the requested connection parameters cannot be fully satisfied; the server responds to offer the protocol parameters it is willing to support. Client may proceed by noting lower protocol version and/or absent extensions. Client *MUST* close the connection if protocol version is unsupported. Server *MUST* send subset of the extensions received in :ref:`ref_protocol_msg_client_handshake` (i.e. it never adds extra ones). While it's not required by the protocol specification itself, Gel server currently requires setting the following params in :ref:`ref_protocol_msg_client_handshake`: * ``user`` -- username for authentication * ``branch`` -- branch to connect to .. _ref_authentication: Authentication -------------- The server then initiates the authentication cycle by sending an authentication request message, to which the client must respond with an appropriate authentication response message. The following messages are sent by the server in the authentication cycle: :ref:`ref_protocol_msg_auth_ok` Authentication is successful. :ref:`ref_protocol_msg_auth_sasl` The client must now initiate a SASL negotiation, using one of the SASL mechanisms listed in the message. The client will send an :ref:`ref_protocol_msg_auth_sasl_initial_response` with the name of the selected mechanism, and the first part of the SASL data stream in response to this. If further messages are needed, the server will respond with :ref:`ref_protocol_msg_auth_sasl_continue`. :ref:`ref_protocol_msg_auth_sasl_continue` This message contains challenge data from the previous step of SASL negotiation (:ref:`ref_protocol_msg_auth_sasl`, or a previous :ref:`ref_protocol_msg_auth_sasl_continue`). The client must respond with an :ref:`ref_protocol_msg_auth_sasl_response` message. 
:ref:`ref_protocol_msg_auth_sasl_final` SASL authentication has completed with additional mechanism-specific data for the client. The server will next send :ref:`ref_protocol_msg_auth_ok` to indicate successful authentication, or an :ref:`ref_protocol_msg_error` to indicate failure. This message is sent only if the SASL mechanism specifies additional data to be sent from server to client at completion. If the frontend does not support the authentication method requested by the server, then it should immediately close the connection. Once the server has confirmed successful authentication with :ref:`ref_protocol_msg_auth_ok`, it then sends one or more of the following messages: :ref:`ref_protocol_msg_server_key_data` This message provides per-connection secret-key data that the client must save if it wants to be able to issue certain requests later. The client should not respond to this message. :ref:`ref_protocol_msg_server_parameter_status` This message informs the frontend about the setting of certain server parameters. The client can ignore this message, or record the settings for its future use. The client should not respond to this message. The connection phase ends when the server sends the first :ref:`ref_protocol_msg_ready_for_command` message, indicating the start of a command cycle. Command Phase ------------- In the command phase, the server expects the client to send one of the following messages: :ref:`ref_protocol_msg_parse` Instructs the server to parse the provided command or commands for execution. The server responds with a :ref:`ref_protocol_msg_command_data_description` containing the :ref:`type descriptor ` data necessary to perform data I/O for this command. :ref:`ref_protocol_msg_execute` Execute the provided command or commands. This message expects the client to declare a correct :ref:`type descriptor ` identifier for command arguments. 
If the declared input type descriptor does not match the expected value, a :ref:`ref_protocol_msg_command_data_description` message is returned followed by a ``ParameterTypeMismatchError`` in an ``ErrorResponse`` message. If the declared output type descriptor does not match, the server will send a :ref:`ref_protocol_msg_command_data_description` prior to sending any :ref:`ref_protocol_msg_data` messages. The client could attach state data in both messages. When doing so, the client must also set a correct :ref:`type descriptor ` identifier for the state data. If the declared state type descriptor does not match the expected value, a :ref:`ref_protocol_msg_state_data_description` message is returned followed by a ``StateMismatchError`` in an ``ErrorResponse`` message. However, the special type id of zero ``00000000-0000-0000-0000-000000000000`` for empty/default state is always a match. Each of the messages could contain one or more EdgeQL commands separated by a semicolon (``;``). If more than one EdgeQL command is found in a single message, the server will treat the commands as an EdgeQL script. EdgeQL scripts are always atomic, they will be executed in an implicit transaction block if no explicit transaction is currently active. Therefore, EdgeQL scripts have limitations on the kinds of EdgeQL commands they can contain: * Transaction control commands are not allowed, like ``start transaction``, ``commit``, ``declare savepoint``, or ``rollback to savepoint``. * Non-transactional commands, like ``create branch``, ``configure instance``, or ``create database`` are not allowed. In the command phase, the server can be in one of the three main states: * *idle*: server is waiting for a command; * *busy*: server is executing a command; * *error*: server encountered an error and is discarding incoming messages. Whenever a server switches to the *idle* state, it sends a :ref:`ref_protocol_msg_ready_for_command` message. 
Whenever a server encounters an error, it sends an :ref:`ref_protocol_msg_error` message and switches into the *error* state. To switch a server from the *error* state into the *idle* state, a :ref:`ref_protocol_msg_sync` message must be sent by the client. .. _ref_protocol_dump_flow: Dump Flow --------- The backup flow is as follows: 1. Client sends :ref:`ref_protocol_msg_dump` message 2. Server sends :ref:`ref_protocol_msg_dump_header` message 3. Server sends one or more :ref:`ref_protocol_msg_dump_block` messages 4. Server sends :ref:`ref_protocol_msg_command_complete` message Usually the client should send :ref:`ref_protocol_msg_sync` after the ``Dump`` message to finish the implicit transaction. .. _ref_protocol_restore_flow: Restore Flow ------------ The restore procedure fills the |branch| the client is connected to with the schema and data from the dump file. The flow is as follows: 1. Client sends :ref:`ref_protocol_msg_restore` message with the dump header block 2. Server sends :ref:`ref_protocol_msg_restore_ready` message as a confirmation that it has accepted the header, restored the schema, and is ready to receive data blocks 3. Client sends one or more :ref:`ref_protocol_msg_restore_block` messages 4. Client sends :ref:`ref_protocol_msg_restore_eof` message 5. Server sends :ref:`ref_protocol_msg_command_complete` message Note: :ref:`ref_protocol_msg_error` may be sent from the server at any time. In case of error, :ref:`ref_protocol_msg_sync` must be sent and all subsequent messages ignored until :ref:`ref_protocol_msg_ready_for_command` is received. The restore protocol doesn't require a :ref:`ref_protocol_msg_sync` message except for error cases. Termination =========== The normal termination procedure is that the client sends a :ref:`ref_protocol_msg_terminate` message and immediately closes the connection. On receipt of this message, the server cleans up the connection resources and closes the connection.
In some cases the server might disconnect without a client request to do so. In such cases the server will attempt to send an :ref:`ref_protocol_msg_error` or a :ref:`ref_protocol_msg_log` message to indicate the reason for the disconnection. ================================================ FILE: docs/resources/protocol/messages.rst ================================================ ======== Messages ======== .. list-table:: :class: funcoptable * - **Server Messages** - * - :ref:`ref_protocol_msg_auth_ok` - Authentication is successful. * - :ref:`ref_protocol_msg_auth_sasl` - SASL authentication is required. * - :ref:`ref_protocol_msg_auth_sasl_continue` - SASL authentication challenge. * - :ref:`ref_protocol_msg_auth_sasl_final` - SASL authentication final message. * - :ref:`ref_protocol_msg_command_complete` - Successful completion of a command. * - :ref:`ref_protocol_msg_command_data_description` - Description of command data input and output. * - :ref:`ref_protocol_msg_state_data_description` - Description of state data. * - :ref:`ref_protocol_msg_data` - Command result data element. * - :ref:`ref_protocol_msg_dump_header` - Initial message of the database backup protocol * - :ref:`ref_protocol_msg_dump_block` - Single chunk of database backup data * - :ref:`ref_protocol_msg_error` - Server error. * - :ref:`ref_protocol_msg_log` - Server log message. * - :ref:`ref_protocol_msg_server_parameter_status` - Server parameter value. * - :ref:`ref_protocol_msg_ready_for_command` - Server is ready for a command. * - :ref:`ref_protocol_msg_restore_ready` - Successful response to the :ref:`ref_protocol_msg_restore` message * - :ref:`ref_protocol_msg_server_handshake` - Initial server connection handshake. * - :ref:`ref_protocol_msg_server_key_data` - Opaque token identifying the server connection. * - **Client Messages** - * - :ref:`ref_protocol_msg_auth_sasl_initial_response` - SASL authentication initial response. 
* - :ref:`ref_protocol_msg_auth_sasl_response` - SASL authentication response. * - :ref:`ref_protocol_msg_client_handshake` - Initial client connection handshake. * - :ref:`ref_protocol_msg_dump` - Initiate database backup * - :ref:`ref_protocol_msg_parse` - Parse EdgeQL command(s). * - :ref:`ref_protocol_msg_execute` - Parse and/or execute a query. * - :ref:`ref_protocol_msg_restore` - Initiate database restore * - :ref:`ref_protocol_msg_restore_block` - Next block of database dump * - :ref:`ref_protocol_msg_restore_eof` - End of database dump * - :ref:`ref_protocol_msg_sync` - Provide an explicit synchronization point. * - :ref:`ref_protocol_msg_terminate` - Terminate the connection. .. _ref_protocol_msg_error: ErrorResponse ============= Sent by: server. Format: .. eql:struct:: edb.protocol.ErrorResponse .. eql:struct:: edb.protocol.ErrorSeverity See the :ref:`list of error codes ` for all possible error codes. Known attributes: * 0x0001 ``HINT``: ``str`` -- error hint. * 0x0002 ``DETAILS``: ``str`` -- error details. * 0x0101 ``SERVER_TRACEBACK``: ``str`` -- error traceback from server (is only sent in dev mode). * 0xFFF1 ``POSITION_START`` -- byte offset of the start of the error span. * 0xFFF2 ``POSITION_END`` -- byte offset of the end of the error span. * 0xFFF3 ``LINE_START`` -- one-based line number of the start of the error span. * 0xFFF4 ``COLUMN_START`` -- one-based column number of the start of the error span. * 0xFFF5 ``UTF16_COLUMN_START`` -- zero-based column number in UTF-16 encoding of the start of the error span. * 0xFFF6 ``LINE_END`` -- one-based line number of the end of the error span. * 0xFFF7 ``COLUMN_END`` -- one-based column number of the end of the error span. * 0xFFF8 ``UTF16_COLUMN_END`` -- zero-based column number in UTF-16 encoding of the end of the error span. * 0xFFF9 ``CHARACTER_START`` -- zero-based offset of the start of the error span in terms of Unicode code points.
* 0xFFFA ``CHARACTER_END`` -- zero-based offset of the end of the error span in terms of Unicode code points. Notes: 1. The error span is the range of characters (or equivalent bytes) of the original query that the compiler points to as the source of the error. 2. ``COLUMN_*`` is defined in terms of the width of characters as defined by Unicode Standard Annex #11, in other words, the column number in the text if rendered with a monospace font, e.g. in a terminal. 3. ``UTF16_COLUMN_*`` is defined as the number of UTF-16 code units (i.e. two-byte units) that precede the target character on the same line. 4. ``*_END`` points to the next character after the last character of the error span. .. _ref_protocol_msg_log: LogMessage ========== Sent by: server. Format: .. eql:struct:: edb.protocol.LogMessage .. eql:struct:: edb.protocol.MessageSeverity See the :ref:`list of error codes ` for all possible log message codes. .. _ref_protocol_msg_ready_for_command: ReadyForCommand =============== Sent by: server. Format: .. eql:struct:: edb.protocol.ReadyForCommand .. eql:struct:: edb.protocol.TransactionState .. eql:struct:: edb.protocol.Annotation .. _ref_protocol_msg_restore_ready: RestoreReady ============ Sent by: server. Initial :ref:`ref_protocol_msg_restore` message accepted, ready to receive data. See :ref:`ref_protocol_restore_flow`. Format: .. eql:struct:: edb.protocol.RestoreReady .. eql:struct:: edb.protocol.Annotation .. _ref_protocol_msg_command_complete: CommandComplete =============== Sent by: server. Format: .. eql:struct:: edb.protocol.CommandComplete .. eql:struct:: edb.protocol.Annotation .. _ref_protocol_msg_dump: Dump ==== Sent by: client. Initiates a database backup. See :ref:`ref_protocol_dump_flow`. Format: .. eql:struct:: edb.protocol.Dump .. eql:struct:: edb.protocol.Annotation .. eql:struct:: edb.protocol.DumpFlag Use: * ``DUMP_SECRETS`` to include secrets in the backup. By default, secrets are not included. ..
_ref_protocol_msg_command_data_description: CommandDataDescription ====================== Sent by: server. Format: .. eql:struct:: edb.protocol.CommandDataDescription .. eql:struct:: edb.protocol.enums.Cardinality .. eql:struct:: edb.protocol.Annotation The format of the *input_typedesc* and *output_typedesc* fields is described in the :ref:`ref_proto_typedesc` section. .. _ref_protocol_msg_state_data_description: StateDataDescription ==================== Sent by: server. Format: .. eql:struct:: edb.protocol.StateDataDescription The format of the *typedesc* fields is described in the :ref:`ref_proto_typedesc` section. .. _ref_protocol_msg_sync: Sync ==== Sent by: client. Format: .. eql:struct:: edb.protocol.Sync .. _ref_protocol_msg_restore: Restore ======= Sent by: client. Initiate restore to the current |branch|. See :ref:`ref_protocol_restore_flow`. Format: .. eql:struct:: edb.protocol.Restore .. _ref_protocol_msg_restore_block: RestoreBlock ============ Sent by: client. Send dump file data block. See :ref:`ref_protocol_restore_flow`. Format: .. eql:struct:: edb.protocol.RestoreBlock .. _ref_protocol_msg_restore_eof: RestoreEof ========== Sent by: client. Notify server that dump is fully uploaded. See :ref:`ref_protocol_restore_flow`. Format: .. eql:struct:: edb.protocol.RestoreEof .. _ref_protocol_msg_execute: Execute ======= Sent by: client. Format: .. eql:struct:: edb.protocol.Execute .. eql:struct:: edb.protocol.OutputFormat .. eql:struct:: edb.protocol.Annotation Use: * ``BINARY`` to return data encoded in binary. * ``JSON`` to return data as a single row with a single field that contains the result set as a single JSON array. * ``JSON_ELEMENTS`` to return a single JSON string per top-level set element. This can be used to iterate over a large result set efficiently. * ``NONE`` to prevent the server from returning data, even if the EdgeQL command does.
The data in *arguments* must be encoded as a :ref:`tuple value ` described by a type descriptor identified by *input_typedesc_id*. .. eql:struct:: edb.protocol.enums.Cardinality .. _ref_protocol_msg_parse: Parse ===== Sent by: client. .. eql:struct:: edb.protocol.Parse .. eql:struct:: edb.protocol.Capability .. eql:struct:: edb.protocol.Annotation See RFC1004_ for more information on capability flags. .. eql:struct:: edb.protocol.CompilationFlag Use: * ``0x0000_0000_0000_0001`` (``INJECT_OUTPUT_TYPE_IDS``) -- if set, all returned objects have a ``__tid__`` property set to their type ID (equivalent to having an implicit ``__tid__ := .__type__.id`` computed property). * ``0x0000_0000_0000_0002`` (``INJECT_OUTPUT_TYPE_NAMES``) -- if set, all returned objects have a ``__tname__`` property set to their type name (equivalent to having an implicit ``__tname__ := .__type__.name`` computed property). Note that specifying this flag might slow down queries. * ``0x0000_0000_0000_0004`` (``INJECT_OUTPUT_OBJECT_IDS``) -- if set, all returned objects have an ``id`` property set to their identifier, even if not specified explicitly in the output shape. .. eql:struct:: edb.protocol.OutputFormat Use: * ``BINARY`` to return data encoded in binary. * ``JSON`` to return data as a single row with a single field that contains the result set as a single JSON array. * ``JSON_ELEMENTS`` to return a single JSON string per top-level set element. This can be used to iterate over a large result set efficiently. * ``NONE`` to prevent the server from returning data, even if the EdgeQL statement does. .. eql:struct:: edb.protocol.enums.Cardinality .. _ref_protocol_msg_data: Data ==== Sent by: server. Format: .. eql:struct:: edb.protocol.Data .. eql:struct:: edb.protocol.DataElement The exact encoding of ``DataElement.data`` is defined by the query output :ref:`type descriptor `. Wire formats for the standard scalar types and collections are documented in :ref:`ref_proto_dataformats`. ..
_ref_protocol_msg_dump_header: Dump Header =========== Sent by: server. Initial message of the database backup protocol. See :ref:`ref_protocol_dump_flow`. Format: .. eql:struct:: edb.protocol.DumpHeader .. eql:struct:: edb.protocol.DumpTypeInfo .. eql:struct:: edb.protocol.DumpObjectDesc Known attributes: * 101 ``BLOCK_TYPE`` -- block type, always "I" * 102 ``SERVER_TIME`` -- server time when the dump started, as a stringified floating-point Unix timestamp * 103 ``SERVER_VERSION`` -- full version of the server as a string * 105 ``SERVER_CATALOG_VERSION`` -- the catalog version of the server, as a 64-bit integer. The catalog version is an identifier that is incremented whenever a change is made to the database layout or standard library. .. _ref_protocol_msg_dump_block: Dump Block ========== Sent by: server. The actual protocol data in the backup protocol. See :ref:`ref_protocol_dump_flow`. Format: .. eql:struct:: edb.protocol.DumpBlock Known attributes: * 101 ``BLOCK_TYPE`` -- block type, always "D" * 110 ``BLOCK_ID`` -- block identifier (16 bytes of UUID) * 111 ``BLOCK_NUM`` -- stringified integer block index * 112 ``BLOCK_DATA`` -- the actual block data .. _ref_protocol_msg_server_key_data: ServerKeyData ============= Sent by: server. Format: .. eql:struct:: edb.protocol.ServerKeyData .. _ref_protocol_msg_server_parameter_status: ParameterStatus =============== Sent by: server. Format: .. eql:struct:: edb.protocol.ParameterStatus Known statuses: * ``suggested_pool_concurrency`` -- suggested default size for client connection pools. Serialized as a UTF-8 encoded string. * ``system_config`` -- a set of instance-level configuration settings exposed to clients on connection. Serialized as: .. eql:struct:: edb.protocol.ParameterStatus_SystemConfig Where ``DataElement`` is defined in the same way as for the :ref:`Data ` message: .. eql:struct:: edb.protocol.DataElement .. _ref_protocol_msg_client_handshake: ClientHandshake =============== Sent by: client. Format: ..
eql:struct:: edb.protocol.ClientHandshake .. eql:struct:: edb.protocol.ConnectionParam .. eql:struct:: edb.protocol.ProtocolExtension The ``ClientHandshake`` message is the first message sent by the client upon connecting to the server. It is the first phase of protocol negotiation, where the client sends the requested protocol version and extensions. Currently, the only defined ``major_ver`` is ``1``, and ``minor_ver`` is ``0``. No protocol extensions are currently defined. The server always responds with the :ref:`ref_protocol_msg_server_handshake`. .. _ref_protocol_msg_server_handshake: ServerHandshake =============== Sent by: server. Format: .. eql:struct:: edb.protocol.ServerHandshake .. eql:struct:: edb.protocol.ProtocolExtension The ``ServerHandshake`` message is a direct response to the :ref:`ref_protocol_msg_client_handshake` message and is sent by the server when it does not support the protocol version or protocol extensions requested by the client. It contains the maximum protocol version supported by the server, considering the version requested by the client. It also contains the intersection of the client-requested and server-supported protocol extensions. Any requested extensions not listed in the ``ServerHandshake`` message are considered unsupported. .. _ref_protocol_msg_auth_ok: AuthenticationOK ================ Sent by: server. Format: .. eql:struct:: edb.protocol.AuthenticationOK The ``AuthenticationOK`` message is sent by the server once it considers the authentication to be successful. .. _ref_protocol_msg_auth_sasl: AuthenticationSASL ================== Sent by: server. Format: .. eql:struct:: edb.protocol.AuthenticationRequiredSASLMessage The ``AuthenticationSASL`` message is sent by the server if it determines that a SASL-based authentication method is required in order to connect using the connection parameters specified in the :ref:`ref_protocol_msg_client_handshake`.
The message contains a list of *authentication methods* supported by the server in the order preferred by the server. .. note:: At the moment, the only SASL authentication method supported by Gel is ``SCRAM-SHA-256`` (`RFC 7677 `_). The client must select an appropriate authentication method from the list returned by the server and send an :ref:`ref_protocol_msg_auth_sasl_initial_response`. One or more server-challenge and client-response messages follow. Each server-challenge is sent in an :ref:`ref_protocol_msg_auth_sasl_continue`, followed by a response from the client in an :ref:`ref_protocol_msg_auth_sasl_response` message. The particulars of the messages are mechanism-specific. Finally, when the authentication exchange is completed successfully, the server sends an :ref:`ref_protocol_msg_auth_sasl_final`, followed immediately by an :ref:`ref_protocol_msg_auth_ok`. .. _ref_protocol_msg_auth_sasl_continue: AuthenticationSASLContinue ========================== Sent by: server. Format: .. eql:struct:: edb.protocol.AuthenticationSASLContinue .. _ref_protocol_msg_auth_sasl_final: AuthenticationSASLFinal ======================= Sent by: server. Format: .. eql:struct:: edb.protocol.AuthenticationSASLFinal .. _ref_protocol_msg_auth_sasl_initial_response: AuthenticationSASLInitialResponse ================================= Sent by: client. Format: .. eql:struct:: edb.protocol.AuthenticationSASLInitialResponse .. _ref_protocol_msg_auth_sasl_response: AuthenticationSASLResponse ========================== Sent by: client. Format: .. eql:struct:: edb.protocol.AuthenticationSASLResponse .. _ref_protocol_msg_terminate: Terminate ========= Sent by: client. Format: .. eql:struct:: edb.protocol.Terminate .. _RFC1004: https://github.com/geldata/rfcs/blob/master/text/1004-transactions-api.rst ================================================ FILE: docs/resources/protocol/typedesc.rst ================================================ ..
_ref_proto_typedesc: ================ Type descriptors ================ This section describes how type information for query input and results is encoded. Specifically, this is needed to decode the server response to the :ref:`ref_protocol_msg_command_data_description` message. The type descriptor is essentially a list of type information *blocks*: * each *block* encodes one type; * *blocks* can reference other *blocks*. While parsing the *blocks*, a database driver can assemble an *encoder* or a *decoder* of the Gel binary data. An *encoder* is used to encode objects, native to the driver's runtime, to binary data that Gel can decode and work with. A *decoder* is used to decode data from Gel native format to data types native to the driver. There is one special type with *type id* of zero: ``00000000-0000-0000-0000-000000000000``. The describe result of this type contains zero *blocks*. It's used when a statement returns no meaningful results, e.g. the ``CREATE BRANCH example`` statement. It is also used to represent the input descriptor when a query does not receive any arguments, or the state descriptor for an empty/default state. .. versionadded:: 6.0 Added ``SQLRecordDescriptor``. Descriptor and type IDs ======================= The descriptor and type IDs in structures below are intended to be semi-stable unique identifiers of a type. Fundamental types have globally stable known IDs, and type IDs for schema-defined types (i.e. with ``schema_defined = true``) persist. Ephemeral type ids are derived from type structure and are not guaranteed to be stable, but are still useful as cache keys. Set Descriptor ============== .. code-block:: c struct SetDescriptor { // Indicates that this is a Set value descriptor. uint8 tag = 0; // Descriptor ID. uuid id; // Set element type descriptor index. uint16 type; }; Set values are encoded on the wire as :ref:`single-dimensional arrays `. Scalar Type Descriptor ====================== .. 
code-block:: c struct ScalarTypeDescriptor { // Indicates that this is a // Scalar Type descriptor. uint8 tag = 3; // Schema type ID. uuid id; // Schema type name. string name; // Whether the type is defined in the schema // or is ephemeral. bool schema_defined; // Number of ancestor scalar types. uint16 ancestors_count; // Indexes of ancestor scalar type descriptors // in ancestor resolution order (C3). uint16 ancestors[ancestors_count]; }; The descriptor IDs for fundamental scalar types are constant. The following table lists all Gel fundamental type descriptor IDs: .. list-table:: :header-rows: 1 * - ID - Type * - ``00000000-0000-0000-0000-000000000100`` - :ref:`std::uuid ` * - ``00000000-0000-0000-0000-000000000101`` - :ref:`std::str ` * - ``00000000-0000-0000-0000-000000000102`` - :ref:`std::bytes ` * - ``00000000-0000-0000-0000-000000000103`` - :ref:`std::int16 ` * - ``00000000-0000-0000-0000-000000000104`` - :ref:`std::int32 ` * - ``00000000-0000-0000-0000-000000000105`` - :ref:`std::int64 ` * - ``00000000-0000-0000-0000-000000000106`` - :ref:`std::float32 ` * - ``00000000-0000-0000-0000-000000000107`` - :ref:`std::float64 ` * - ``00000000-0000-0000-0000-000000000108`` - :ref:`std::decimal ` * - ``00000000-0000-0000-0000-000000000109`` - :ref:`std::bool ` * - ``00000000-0000-0000-0000-00000000010A`` - :ref:`std::datetime ` * - ``00000000-0000-0000-0000-00000000010E`` - :ref:`std::duration ` * - ``00000000-0000-0000-0000-00000000010F`` - :ref:`std::json ` * - ``00000000-0000-0000-0000-00000000010B`` - :ref:`cal::local_datetime ` * - ``00000000-0000-0000-0000-00000000010C`` - :ref:`cal::local_date ` * - ``00000000-0000-0000-0000-00000000010D`` - :ref:`cal::local_time ` * - ``00000000-0000-0000-0000-000000000110`` - :ref:`std::bigint ` * - ``00000000-0000-0000-0000-000000000111`` - :ref:`cal::relative_duration ` * - ``00000000-0000-0000-0000-000000000112`` - :ref:`cal::date_duration ` * - ``00000000-0000-0000-0000-000000000130`` - :ref:`cfg::memory ` Tuple Type 
Descriptor ===================== .. code-block:: c struct TupleTypeDescriptor { // Indicates that this is a // Tuple Type descriptor. uint8 tag = 4; // Schema type ID. uuid id; // Schema type name. string name; // Whether the type is defined in the schema // or is ephemeral. bool schema_defined; // Number of ancestor scalar types. uint16 ancestors_count; // Indexes of ancestor scalar type descriptors // in ancestor resolution order (C3). uint16 ancestors[ancestors_count]; // The number of elements in tuple. uint16 element_count; // Indexes of element type descriptors. uint16 element_types[element_count]; }; An empty tuple type descriptor has an ID of ``00000000-0000-0000-0000-0000000000FF``. Named Tuple Type Descriptor =========================== .. code-block:: c struct NamedTupleTypeDescriptor { // Indicates that this is a // Named Tuple Type descriptor. uint8 tag = 5; // Schema type ID. uuid id; // Schema type name. string name; // Whether the type is defined in the schema // or is ephemeral. bool schema_defined; // Number of ancestor scalar types. uint16 ancestors_count; // Indexes of ancestor scalar type descriptors // in ancestor resolution order (C3). uint16 ancestors[ancestors_count]; // The number of elements in tuple. uint16 element_count; // Indexes of element descriptors. TupleElement elements[element_count]; }; struct TupleElement { // Field name. string name; // Field type descriptor index. int16 type; }; Array Type Descriptor ===================== .. code-block:: c struct ArrayTypeDescriptor { // Indicates that this is an // Array Type descriptor. uint8 tag = 6; // Schema type ID. uuid id; // Schema type name. string name; // Whether the type is defined in the schema // or is ephemeral. bool schema_defined; // Number of ancestor scalar types. uint16 ancestors_count; // Indexes of ancestor scalar type descriptors // in ancestor resolution order (C3). uint16 ancestors[ancestors_count]; // Array element type. 
uint16 type; // The number of array dimensions, at least 1. uint16 dimension_count; // Sizes of array dimensions, -1 indicates // unbound dimension. int32 dimensions[dimension_count]; }; Enumeration Type Descriptor =========================== .. code-block:: c struct EnumerationTypeDescriptor { // Indicates that this is an // Enumeration Type descriptor. uint8 tag = 7; // Schema type ID. uuid id; // Schema type name. string name; // Whether the type is defined in the schema // or is ephemeral. bool schema_defined; // Number of ancestor scalar types. uint16 ancestors_count; // Indexes of ancestor scalar type descriptors // in ancestor resolution order (C3). uint16 ancestors[ancestors_count]; // The number of enumeration members. uint16 member_count; // Names of enumeration members. string members[member_count]; }; Range Type Descriptor ===================== .. code-block:: c struct RangeTypeDescriptor { // Indicates that this is a // Range Type descriptor. uint8 tag = 9; // Schema type ID. uuid id; // Schema type name. string name; // Whether the type is defined in the schema // or is ephemeral. bool schema_defined; // Number of ancestor scalar types. uint16 ancestors_count; // Indexes of ancestor scalar type descriptors // in ancestor resolution order (C3). uint16 ancestors[ancestors_count]; // Range type descriptor index. uint16 type; }; Ranges are encoded on the wire as :ref:`ranges `. Object Type Descriptor ====================== .. code-block:: c struct ObjectTypeDescriptor { // Indicates that this is an // object type descriptor. uint8 tag = 10; // Schema type ID. uuid id; // Schema type name (can be empty for ephemeral free object types). string name; // Whether the type is defined in the schema // or is ephemeral. bool schema_defined; }; Compound Type Descriptor ======================== .. code-block:: c struct CompoundTypeDescriptor { // Indicates that this is a // compound type descriptor. uint8 tag = 11; // Schema type ID. uuid id; // Schema type name. 
string name; // Whether the type is defined in the schema // or is ephemeral. bool schema_defined; // Compound type operation, see TypeOperation below. uint8 op; // Number of compound type components. uint16 component_count; // Compound type component type descriptor indexes. uint16 components[component_count]; }; enum TypeOperation { // Foo | Bar UNION = 1; // Foo & Bar INTERSECTION = 2; }; Object Output Shape Descriptor ============================== .. code-block:: c struct ObjectShapeDescriptor { // Indicates that this is an // Object Shape descriptor. uint8 tag = 1; // Descriptor ID. uuid id; // Whether this is an ephemeral free shape; // if true, then `type` would always be 0 // and should not be interpreted. bool ephemeral_free_shape; // Object type descriptor index. uint16 type; // Number of elements in shape. uint16 element_count; // Array of shape elements. ShapeElement elements[element_count]; }; struct ShapeElement { // Field flags: // 1 << 0: the field is implicit // 1 << 1: the field is a link property // 1 << 2: the field is a link uint32 flags; // The cardinality of the shape element. uint8 cardinality; // Element name. string name; // Element type descriptor index. uint16 type; // Source schema type descriptor index // (useful for polymorphic queries). uint16 source_type; }; .. eql:struct:: edb.protocol.enums.Cardinality Objects are encoded on the wire as :ref:`tuples `. Input Shape Descriptor ====================== .. code-block:: c struct InputShapeDescriptor { // Indicates that this is an // Input Shape descriptor. uint8 tag = 8; // Descriptor ID. uuid id; // Number of elements in shape. uint16 element_count; // Shape elements. InputShapeElement elements[element_count]; }; struct InputShapeElement { // Field flags, currently always zero. uint32 flags; // The cardinality of the shape element. uint8 cardinality; // Element name. string name; // Element type descriptor index.
uint16 type; }; Input objects are encoded on the wire as :ref:`sparse objects `. Type Annotation Text Descriptor =============================== .. code-block:: c struct TypeAnnotationDescriptor { // Indicates that this is a // Type Annotation descriptor. uint8 tag = 127; // Index of the descriptor the // annotation is for. uint16 descriptor; // Annotation key. string key; // Annotation value. string value; }; SQL Record Descriptor ===================== .. code-block:: c struct SQLRecordDescriptor { // Indicates that this is a // SQL Record descriptor. uint8 tag = 13; // Descriptor ID. uuid id; // Number of elements in record. uint16 element_count; // Array of record elements. SQLRecordElement elements[element_count]; }; struct SQLRecordElement { // Element name. string name; // Element type descriptor index. uint16 type; }; ================================================ FILE: docs/resources/upgrading.rst ================================================ .. _ref_upgrading: ================= Upgrading from v5 ================= With the release of Gel v6, we have introduced a number of changes that affect your workflow. The most obvious change is that the CLI and client libraries are now named after Gel, not EdgeDB. There are also a number of smaller changes and enhancements that are worth understanding as you bring your EdgeDB database up to date with the latest release. CLI === .. lint-off For a few versions we've been shipping an alias to the ``edgedb`` CLI named ``gel`` as we've been working on the rename. For the most part, you can now just use ``gel`` instead of ``edgedb`` as the CLI name. Make sure you are using the latest version of the CLI by running :gelcmd:`cli upgrade`. If you see a note about not being able to upgrade, you can try running ``edgedb cli upgrade`` and then :gelcmd:`cli upgrade`. Don't forget to update any scripts that use the ``edgedb`` CLI to use ``gel`` instead. ..
lint-on Project Configuration File ========================== .. lint-off The Gel CLI and client libraries use a configuration file to configure various things about your project, such as the location of the schema directory and the target version of the Gel server. Previously, this file was named ``edgedb.toml``, but it is now named |gel.toml|. .. lint-on In addition to the name change, we have also renamed the TOML table for configuring the server version from ``[edgedb]`` to ``[instance]``. .. tabs:: .. code-tab:: toml :caption: (Before) edgedb.toml [edgedb] server-version = "5.7" .. code-tab:: toml-diff :caption: (After) gel.toml - [edgedb] + [instance] server-version = "5.7" We continue to support the old file and table names, but we recommend updating to the new names as soon as possible. There are also a number of useful new workflow features in the CLI that are configured in this file and are worth exploring. See the `announcement blog post `_ for more details. Client Libraries ================ We've started publishing our various client libraries under Gel-flavored names, and will only be publishing to these new packages going forward. .. list-table:: :header-rows: 1 * - Language - New Package * - Python - `gel on PyPI `_ * - TypeScript - `gel on npm `_ * - Go - `gel-go on GitHub `_ * - Rust - `gel-rust on GitHub `_ If you're using the TypeScript client library, you can use our codemod to automatically update your codebase to point at the new packages: .. code-block:: bash $ npx @gel/codemod@latest Code generation =============== Some of the languages we support include code generation tools that can generate code from your schema. Here is a table of how those tools have been renamed: .. list-table:: :header-rows: 1 * - Language - Previous - Current * - Python - ``edgedb-py`` - ``gel-py`` * - TypeScript - ``@edgedb/generate`` - ``@gel/generate`` Check your project task runners and update them accordingly.
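As a rough illustration of updating task runner scripts, the following Python sketch rewrites the two code generation tool names from the table above in a piece of text. This is not an official migration tool: the ``rename_codegen_tools`` helper and the ``RENAMES`` mapping are hypothetical names introduced here, and the sketch only handles the tool names themselves, not their options or output.

```python
import re

# Old -> new code generation tool names, taken from the table above.
RENAMES = {
    "edgedb-py": "gel-py",
    "@edgedb/generate": "@gel/generate",
}

# Match an old name only when it is not part of a longer identifier,
# so that e.g. "edgedb-python" is left untouched.
_PATTERN = re.compile(
    r"(?<![\w@/-])(?:"
    + "|".join(re.escape(old) for old in RENAMES)
    + r")(?![\w-])"
)


def rename_codegen_tools(text: str) -> str:
    """Return `text` with the old tool names replaced (hypothetical helper)."""
    return _PATTERN.sub(lambda match: RENAMES[match.group(0)], text)
```

Such a helper could be applied to the contents of a ``Makefile`` or a ``package.json`` scripts section; review the resulting diff before committing, since the new tools may differ in more than name.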
Upgrading instances =================== To take advantage of the new features in Gel v6, you'll need to upgrade your instances to the latest version. Cloud instances --------------- If you're using a hosted instance on Gel Cloud, you can upgrade your instance by clicking on the "Upgrade" button in the Gel Cloud console, or with the CLI. .. code-block:: bash $ gel instance upgrade --to-latest Local instances --------------- If you have local instances that you've initialized with the CLI using :gelcmd:`project init`, you can upgrade them easily with the CLI. .. code-block:: bash $ gel project upgrade --to-latest This will upgrade the project instance to the latest version of Gel and also update the |gel.toml| server-version value to the latest version. Remote instances ---------------- To upgrade a remote instance, we recommend the following dump-and-restore process: 1. Gel v6.0 supports PostgreSQL 14 or above. Verify your PostgreSQL version before upgrading Gel. If you're using Postgres 13 or below, upgrade Postgres first. 2. Spin up an empty 6.0 instance. You can use one of our :ref:`deployment guides `. For Debian/Ubuntu, when adding the Gel package repository, use this command: .. code-block:: bash $ echo deb [signed-by=/usr/local/share/keyrings/gel-keyring.gpg] \ https://packages.geldata.com/apt \ $(grep "VERSION_CODENAME=" /etc/os-release | cut -d= -f2) main \ | sudo tee /etc/apt/sources.list.d/gel.list $ sudo apt-get update && sudo apt-get install gel-6 For CentOS/RHEL, use this installation command: .. code-block:: bash $ sudo yum install gel-6 In any required ``systemctl`` commands, replace ``edgedb-server-5`` with ``gel-server-6``. For Docker setups, use the ``6`` or other appropriate tag. .. note:: The new instance will have a different DSN, including a different port number. Take note of the full DSN of the new instance as you'll need it to restore your database, and update your application to use the new DSN in further steps. 3.
Take your application offline, then dump your v5.x database with the CLI: .. code-block:: bash $ gel dump --dsn --all --format dir my_database.dump/ This will dump the schema and contents of your current database to a directory on your local disk called ``my_database.dump``. The directory name isn't important. 4. Restore to the new, empty v6 instance from the dump: .. code-block:: bash $ gel restore --all my_database.dump/ --dsn Once the restore is complete, update your application to connect to the new instance. This process will involve some downtime, specifically during steps 3 and 4. GitHub Action ============= We publish a GitHub action for accessing a Gel instance in your GitHub Actions workflows. This action has been updated to work with Gel v6. If you're using the action in your workflow, update it to use the latest version. .. code-block:: yaml-diff - - uses: edgedb/setup-edgedb@v1 + - uses: geldata/setup-gel@v1 - - run: edgedb query 'select sys::get_version_as_str()' + - run: gel query 'select sys::get_version_as_str()' ================================================ FILE: edb/.gitignore ================================================ # Ignore Cython debug files *.html # Grammar spec file encoded as BitCode *.bc # Grammar spec file encoded as EBNF *.ebnf ================================================ FILE: edb/README.md ================================================ ## Directory overview Here is a list of most of the important directories in EdgeDB, along with some of the key files and subdirectories in them. This list is *partial*, focused on the compiler, and ordered conceptually. - `schema/` - Representation of the schema, and implementation of schema modifications (both SDL migrations and DDL statements). The schema is an immutable object (implemented as a bunch of immutable maps), and making changes to it produces a new schema object.
Objects stored in the schema are represented in the compiler as proxy objects, and fetching attributes from them requires passing in a schema. - `edgeql/` - EdgeQL frontend tools - AST, parser, first stage compiler, etc. - `edgeql/ast.py` - EdgeQL AST - `edgeql/parser/` - Parser. Uses https://github.com/MagicStack/parsing - `edgeql/compiler/` - Compiler from EdgeQL to our IR - `edgeql/tracer.py` and `edgeql/declarative.py` - Analysis to convert SDL schema descriptions to DDL that will create them. - `ir/` - Intermediate Representation (IR) and tools for operating on it - `edgeql/ir/pathid.py` - Definition of "path ids", which are used to identify sets - `edgeql/ir/ast.py` - Primary AST of intermediate representation. The IR contains no direct references to schema objects; information from the schema that is needed in the IR needs to be explicitly placed there. There are `TypeRef` and `PointerRef` objects that do this for types and pointers. - `edgeql/ir/scopetree.py` - Representation of "scope tree", which computes and tracks where sets are "bound". The IR output of the compiler consists of both an IR AST and a scope tree, needed to interpret it. - `pgsql/` - PostgreSQL backend tools - AST, codegen, second stage compiler, etc. - `pgsql/ast.py` - SQL AST. The AST contains both the information for the actual SQL AST and a large collection of metadata that is used during the compilation process. - `pgsql/codegen.py` - SQL codegen. Converts AST to SQL. - `pgsql/compiler/` - IR to SQL compiler. - `pgsql/delta.py` - Generates SQL DDL from delta trees. - `lib/` - Definition of EdgeDB's standard library - `lib/schema.edgeql` - Definition of the parts of the schema that are exposed publicly.
- `graphql/` - GraphQL to EdgeQL compiler - `server/` - Implementation of the EdgeDB server and protocol handling - `server/compiler` - The interface between the server and the compiler ================================================ FILE: edb/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # DO NOT ADD ANYTHING TO THIS FILE: # we might want to make "edb" a namespace # package at some point. ================================================ FILE: edb/_edgeql_parser.pyi ================================================ import typing class SyntaxError(Exception): ... class ParserResult: out: typing.Optional[CSTNode | list[OpaqueToken]] errors: list[ tuple[ str, tuple[int, typing.Optional[int]], typing.Optional[str], typing.Optional[str], ] ] def pack(self) -> bytes: ... class Hasher: @staticmethod def start_migration(parent_id: str) -> Hasher: ... def add_source(self, data: str) -> None: ... def make_migration_id(self) -> str: ... unreserved_keywords: frozenset[str] partial_reserved_keywords: frozenset[str] future_reserved_keywords: frozenset[str] current_reserved_keywords: frozenset[str] class Entry: key: bytes tokens: list[OpaqueToken] extra_blobs: list[bytes] first_extra: typing.Optional[int] extra_counts: list[int] def get_variables(self) -> dict[str, typing.Any]: ... 
def pack(self) -> bytes: ... def normalize(text: str) -> Entry: ... def parse( start_token_name: str, tokens: list[OpaqueToken] ) -> tuple[ ParserResult, list[tuple[type, typing.Callable]] ]: ... def suggest_next_keywords( start_token_name: str, tokens: list[OpaqueToken] ) -> tuple[list[str], bool]: ... def preload_spec(spec_filepath: str) -> None: ... def save_spec(spec_json: str, dst: str) -> None: ... class CSTNode: production: typing.Optional[Production] terminal: typing.Optional[Terminal] class Production: id: int args: list[CSTNode] start: int | None end: int | None class Terminal: text: str value: typing.Any start: int end: int class SourcePoint: line: int zero_based_line: int column: int utf16column: int offset: int char_offset: int @staticmethod def from_offsets( data: bytes, offsets: list[int] ) -> list[SourcePoint]: ... @staticmethod def from_lines_cols( data: bytes, lines_cols: list[tuple[int, int]] ) -> list[SourcePoint]: ... def offset_of_line(text: str, target: int) -> int: ... class OpaqueToken: def span_start(self) -> int: ... def span_end(self) -> int: ... def is_ident(self) -> bool: ... def tokenize(s: str) -> ParserResult: ... def unpickle_token(bytes: bytes) -> OpaqueToken: ... def unpack(serialized: bytes) -> Entry | list[OpaqueToken]: ... ================================================ FILE: edb/api/errors.txt ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # #### 0x_01_00_00_00 InternalServerError #### 0x_02_00_00_00 UnsupportedFeatureError #### 0x_03_00_00_00 ProtocolError 0x_03_01_00_00 BinaryProtocolError 0x_03_01_00_01 UnsupportedProtocolVersionError 0x_03_01_00_02 TypeSpecNotFoundError 0x_03_01_00_03 UnexpectedMessageError 0x_03_02_00_00 InputDataError 0x_03_02_01_00 ParameterTypeMismatchError 0x_03_02_02_00 StateMismatchError #SHOULD_RETRY 0x_03_03_00_00 ResultCardinalityMismatchError 0x_03_04_00_00 CapabilityError 0x_03_04_01_00 UnsupportedCapabilityError 0x_03_04_02_00 DisabledCapabilityError 0x_03_04_03_00 UnsafeIsolationLevelError #### 0x_04_00_00_00 QueryError 0x_04_01_00_00 InvalidSyntaxError 0x_04_01_01_00 EdgeQLSyntaxError 0x_04_01_02_00 SchemaSyntaxError 0x_04_01_03_00 GraphQLSyntaxError 0x_04_02_00_00 InvalidTypeError 0x_04_02_01_00 InvalidTargetError 0x_04_02_01_01 InvalidLinkTargetError 0x_04_02_01_02 InvalidPropertyTargetError 0x_04_03_00_00 InvalidReferenceError 0x_04_03_00_01 UnknownModuleError 0x_04_03_00_02 UnknownLinkError 0x_04_03_00_03 UnknownPropertyError 0x_04_03_00_04 UnknownUserError 0x_04_03_00_05 UnknownDatabaseError 0x_04_03_00_06 UnknownParameterError 0x_04_03_00_07 DeprecatedScopingError 0x_04_04_00_00 SchemaError 0x_04_05_00_00 SchemaDefinitionError 0x_04_05_01_00 InvalidDefinitionError 0x_04_05_01_01 InvalidModuleDefinitionError 0x_04_05_01_02 InvalidLinkDefinitionError 0x_04_05_01_03 InvalidPropertyDefinitionError 0x_04_05_01_04 InvalidUserDefinitionError 0x_04_05_01_05 InvalidDatabaseDefinitionError 0x_04_05_01_06 InvalidOperatorDefinitionError 0x_04_05_01_07 InvalidAliasDefinitionError 0x_04_05_01_08 InvalidFunctionDefinitionError 0x_04_05_01_09 InvalidConstraintDefinitionError 0x_04_05_01_0A InvalidCastDefinitionError 0x_04_05_02_00 DuplicateDefinitionError 0x_04_05_02_01 DuplicateModuleDefinitionError 0x_04_05_02_02 DuplicateLinkDefinitionError 0x_04_05_02_03 
DuplicatePropertyDefinitionError 0x_04_05_02_04 DuplicateUserDefinitionError 0x_04_05_02_05 DuplicateDatabaseDefinitionError 0x_04_05_02_06 DuplicateOperatorDefinitionError 0x_04_05_02_07 DuplicateViewDefinitionError 0x_04_05_02_08 DuplicateFunctionDefinitionError 0x_04_05_02_09 DuplicateConstraintDefinitionError 0x_04_05_02_0A DuplicateCastDefinitionError 0x_04_05_02_0B DuplicateMigrationError #### 0x_04_06_00_00 SessionTimeoutError 0x_04_06_01_00 IdleSessionTimeoutError #SHOULD_RETRY 0x_04_06_02_00 QueryTimeoutError 0x_04_06_0A_00 TransactionTimeoutError 0x_04_06_0A_01 IdleTransactionTimeoutError #### 0x_05_00_00_00 ExecutionError 0x_05_01_00_00 InvalidValueError 0x_05_01_00_01 DivisionByZeroError 0x_05_01_00_02 NumericOutOfRangeError 0x_05_01_00_03 AccessPolicyError 0x_05_01_00_04 QueryAssertionError 0x_05_02_00_00 IntegrityError 0x_05_02_00_01 ConstraintViolationError 0x_05_02_00_02 CardinalityViolationError 0x_05_02_00_03 MissingRequiredError 0x_05_03_00_00 TransactionError 0x_05_03_01_00 TransactionConflictError #SHOULD_RETRY 0x_05_03_01_01 TransactionSerializationError 0x_05_03_01_02 TransactionDeadlockError 0x_05_03_01_03 QueryCacheInvalidationError 0x_05_04_00_00 WatchError #### 0x_06_00_00_00 ConfigurationError #### 0x_07_00_00_00 AccessError 0x_07_01_00_00 AuthenticationError #### 0x_08_00_00_00 AvailabilityError 0x_08_00_00_01 BackendUnavailableError #SHOULD_RETRY 0x_08_00_00_02 ServerOfflineError #SHOULD_RECONNECT #SHOULD_RETRY 0x_08_00_00_03 UnknownTenantError #SHOULD_RECONNECT #SHOULD_RETRY 0x_08_00_00_04 ServerBlockedError #### 0x_09_00_00_00 BackendError 0x_09_00_01_00 UnsupportedBackendFeatureError #### 0x_F0_00_00_00 LogMessage 0x_F0_01_00_00 WarningMessage 0x_F0_02_00_00 StatusMessage 0x_F0_02_00_01 MigrationStatusMessage #### Suggested errors for Gel clients 0x_FF_00_00_00 ClientError 0x_FF_01_00_00 ClientConnectionError 0x_FF_01_01_00 ClientConnectionFailedError 0x_FF_01_01_01 ClientConnectionFailedTemporarilyError #SHOULD_RECONNECT 
#SHOULD_RETRY 0x_FF_01_02_00 ClientConnectionTimeoutError #SHOULD_RECONNECT #SHOULD_RETRY 0x_FF_01_03_00 ClientConnectionClosedError #SHOULD_RECONNECT #SHOULD_RETRY 0x_FF_02_00_00 InterfaceError 0x_FF_02_01_00 QueryArgumentError 0x_FF_02_01_01 MissingArgumentError 0x_FF_02_01_02 UnknownArgumentError 0x_FF_02_01_03 InvalidArgumentError 0x_FF_03_00_00 NoDataError 0x_FF_04_00_00 InternalClientError ================================================ FILE: edb/api/types.txt ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # Use `edb gen-types` to regenerate `edb/schema/_types.py` based on this file. 
# Base Scalar Types 00000000-0000-0000-0000-000000000001 anytype 00000000-0000-0000-0000-000000000002 anytuple 00000000-0000-0000-0000-000000000003 anyobject 00000000-0000-0000-0000-0000000000F0 std 00000000-0000-0000-0000-0000000000FF empty-tuple 00000000-0000-0000-0000-000000000100 std::uuid 00000000-0000-0000-0000-000000000101 std::str 00000000-0000-0000-0000-000000000102 std::bytes 00000000-0000-0000-0000-000000000103 std::int16 00000000-0000-0000-0000-000000000104 std::int32 00000000-0000-0000-0000-000000000105 std::int64 00000000-0000-0000-0000-000000000106 std::float32 00000000-0000-0000-0000-000000000107 std::float64 00000000-0000-0000-0000-000000000108 std::decimal 00000000-0000-0000-0000-000000000109 std::bool 00000000-0000-0000-0000-00000000010A std::datetime 00000000-0000-0000-0000-00000000010E std::duration 00000000-0000-0000-0000-00000000010F std::json 00000000-0000-0000-0000-000000000110 std::bigint 00000000-0000-0000-0000-00000000010B std::cal::local_datetime 00000000-0000-0000-0000-00000000010C std::cal::local_date 00000000-0000-0000-0000-00000000010D std::cal::local_time 00000000-0000-0000-0000-000000000111 std::cal::relative_duration 00000000-0000-0000-0000-000000000112 std::cal::date_duration 00000000-0000-0000-0000-000000000130 cfg::memory 00000000-0000-0000-0000-000001000001 std::pg::json 00000000-0000-0000-0000-000001000002 std::pg::timestamptz 00000000-0000-0000-0000-000001000003 std::pg::timestamp 00000000-0000-0000-0000-000001000004 std::pg::date 00000000-0000-0000-0000-000001000005 std::pg::interval ================================================ FILE: edb/buildmeta.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Optional, Mapping, Sequence, NamedTuple, TypedDict, cast, ) # DO NOT put any imports here other than from stdlib # or modules from edb.common that themselves have only stdlib imports. import base64 import datetime import hashlib import importlib.util import json import logging import os import pathlib import pickle import platform import re import subprocess import sys import tempfile from edb.common import debug from edb.common import devmode from edb.common import verutils # Increment this whenever the database layout or stdlib changes. # # WARNING: DO NOT INCREMENT THIS WHEN BACKPORTING CHANGES TO A RELEASE BRANCH. # The merge conflict there is a nice reminder that you probably need # to write a patch in edb/pgsql/patches.py, and then you should preserve # the old value. EDGEDB_CATALOG_VERSION = 2025_11_03_00_00 EDGEDB_MAJOR_VERSION = 8 class MetadataError(Exception): pass class BackendVersion(NamedTuple): major: int minor: int micro: int releaselevel: str serial: int string: str class VersionMetadata(TypedDict): build_date: datetime.datetime | None build_hash: str | None scm_revision: str | None source_date: datetime.datetime | None target: str | None def get_build_metadata_value(prop: str) -> str: env_val = os.environ.get(f'_GEL_BUILDMETA_{prop}') if env_val: return env_val env_val = os.environ.get(f'_EDGEDB_BUILDMETA_{prop}') if env_val: return env_val try: from . 
import _buildmeta # type: ignore return getattr(_buildmeta, prop) except (ImportError, AttributeError): raise MetadataError( f'could not find {prop} in Gel distribution metadata') from None def _get_devmode_pg_config_path() -> pathlib.Path: root = pathlib.Path(__file__).parent.parent.resolve() pg_config = root / 'build' / 'postgres' / 'install' / 'bin' / 'pg_config' if not pg_config.is_file(): try: pg_config = pathlib.Path( get_build_metadata_value('PG_CONFIG_PATH')) except MetadataError: pass if not pg_config.is_file(): raise MetadataError('DEV mode: Could not find PostgreSQL build, ' 'run `pip install -e .`') return pg_config def get_pg_config_path() -> pathlib.Path: if devmode.is_in_dev_mode(): pg_config = _get_devmode_pg_config_path() else: try: pg_config = pathlib.Path( get_build_metadata_value('PG_CONFIG_PATH')) except MetadataError: pg_config = _get_devmode_pg_config_path() else: if not pg_config.is_file(): raise MetadataError( f'invalid pg_config path: {pg_config!r}: file does not ' f'exist or is not a regular file') return pg_config _pg_version_regex = re.compile( r"(Postgre[^\s]*)?\s*" r"(?P<major>[0-9]+)\.?" r"((?P<minor>[0-9]+)\.?)?" r"(?P<micro>[0-9]+)?" r"(?P<releaselevel>[a-z]+)?" r"(?P<serial>[0-9]+)?"
) def parse_pg_version(version_string: str) -> BackendVersion: version_match = _pg_version_regex.search(version_string) if version_match is None: raise ValueError( f"malformed Postgres version string: {version_string!r}") version = version_match.groupdict() return BackendVersion( major=int(version["major"]), minor=0, micro=int(version.get("minor") or 0), releaselevel=version.get("releaselevel") or "final", serial=int(version.get("serial") or 0), string=version_string, ) _bundled_pg_version: Optional[BackendVersion] = None def get_pg_version() -> BackendVersion: global _bundled_pg_version if _bundled_pg_version is not None: return _bundled_pg_version pg_config = subprocess.run( [get_pg_config_path()], capture_output=True, text=True, check=True, ) for line in pg_config.stdout.splitlines(): k, eq, v = line.partition('=') if eq and k.strip().lower() == 'version': v = v.strip() parsed_ver = parse_pg_version(v) _bundled_pg_version = BackendVersion( major=parsed_ver.major, minor=parsed_ver.minor, micro=parsed_ver.micro, releaselevel=parsed_ver.releaselevel, serial=parsed_ver.serial, string=v, ) return _bundled_pg_version else: raise MetadataError( "could not find version information in pg_config output") def get_runstate_path(data_dir: pathlib.Path) -> pathlib.Path: if devmode.is_in_dev_mode(): return data_dir else: runstate_dir = get_build_metadata_value('RUNSTATE_DIR') if runstate_dir is not None: return pathlib.Path(runstate_dir) else: return data_dir def get_shared_data_dir_path() -> pathlib.Path: if devmode.is_in_dev_mode(): return devmode.get_dev_mode_cache_dir() # type: ignore[return-value] else: return pathlib.Path(get_build_metadata_value('SHARED_DATA_DIR')) def get_extension_dir_path() -> pathlib.Path: # TODO: Do we want a special metadata value?? 
return get_shared_data_dir_path() / "extensions" def hash_dirs( dirs: Sequence[tuple[str, str]], *, extra_files: Optional[Sequence[str | pathlib.Path]]=None, extra_data: Optional[bytes] = None, ) -> bytes: def hash_dir(dirname, ext, paths): with os.scandir(dirname) as it: for entry in it: if entry.is_file() and entry.name.endswith(ext): paths.append(entry.path) elif entry.is_dir(): hash_dir(entry.path, ext, paths) paths: list[str] = [] for dirname, ext in dirs: hash_dir(dirname, ext, paths) if extra_files: for extra_file in extra_files: if isinstance(extra_file, pathlib.Path): extra_file = str(extra_file.resolve()) paths.append(extra_file) h = hashlib.sha1() # sha1 is the fastest one. for path in sorted(paths): with open(path, 'rb') as f: h.update(f.read()) h.update(str(sys.version_info[:2]).encode()) if extra_data is not None: h.update(extra_data) return h.digest() def read_data_cache( cache_key: bytes, path: str, *, pickled: bool=True, source_dir: Optional[pathlib.Path] = None, ) -> Any: if source_dir is None: source_dir = get_shared_data_dir_path() full_path = source_dir / path if full_path.exists(): with open(full_path, 'rb') as f: src_hash = f.read(len(cache_key)) if src_hash == cache_key or debug.flags.bootstrap_cache_yolo: if pickled: data = f.read() try: return pickle.loads(data) except Exception: logging.exception(f'could not unpickle {path}') else: return f.read() def write_data_cache( obj: Any, cache_key: bytes, path: str, *, pickled: bool = True, target_dir: Optional[pathlib.Path] = None, ): if target_dir is None: target_dir = get_shared_data_dir_path() full_path = target_dir / path try: with tempfile.NamedTemporaryFile( mode='wb', dir=full_path.parent, delete=False) as f: f.write(cache_key) if pickled: pickle.dump(obj, file=f, protocol=pickle.HIGHEST_PROTOCOL) else: f.write(obj) except Exception: try: os.unlink(f.name) except OSError: pass finally: raise else: os.rename(f.name, full_path) def get_version() -> verutils.Version: if 
devmode.is_in_dev_mode(): root = pathlib.Path(__file__).parent.parent.resolve() version = verutils.parse_version(get_version_from_scm(root)) else: vertuple: list[Any] = list(get_build_metadata_value('VERSION')) vertuple[2] = verutils.VersionStage(vertuple[2]) version = verutils.Version(*vertuple) return version _version_dict: Optional[Mapping[str, Any]] = None def get_version_build_id( v: verutils.Version, short: bool = True, ) -> tuple[str, ...]: parts = [] if v.local: if short: build_hash = None build_kind = None for segment in v.local: if segment[0] == "s": build_hash = segment[1:] elif segment[0] == "b": build_kind = segment[1:] if build_kind == "official": if build_hash: parts.append(build_hash) elif build_kind: parts.append(build_kind) else: parts.extend(v.local) return tuple(parts) def get_version_dict() -> Mapping[str, Any]: global _version_dict if _version_dict is None: ver = get_version() _version_dict = { 'major': ver.major, 'minor': ver.minor, 'stage': ver.stage.name.lower(), 'stage_no': ver.stage_no, 'local': get_version_build_id(ver), } return _version_dict _version_json: Optional[str] = None def get_version_json() -> str: global _version_json if _version_json is None: _version_json = json.dumps(get_version_dict()) return _version_json def get_version_string(short: bool = True) -> str: v = get_version() string = f'{v.major}.{v.minor}' if v.stage is not verutils.VersionStage.FINAL: string += f'-{v.stage.name.lower()}.{v.stage_no}' build_id = get_version_build_id(v, short=short) if build_id: string += "+" + ".".join(build_id) return string def get_version_metadata() -> VersionMetadata: v = get_version() pfx_map = { "b": "build_type", "r": "build_date", "s": "build_hash", "g": "scm_revision", "d": "source_date", "t": "target", } result = {} for segment in v.local: key = pfx_map.get(segment[0]) if key: raw_val = segment[1:] val: str | datetime.datetime if key == "target": val = _decode_build_target(raw_val) elif key in {"build_date", "source_date"}: val = 
_decode_build_date(raw_val) else: val = raw_val result[key] = val return cast(VersionMetadata, result) def _decode_build_target(val: str) -> str: return ( base64.b32decode(val + "=" * (-len(val) % 8), casefold=True).decode() ) def _decode_build_date(val: str) -> datetime.datetime: return datetime.datetime.strptime(val, r"%Y%m%d%H%M").replace( tzinfo=datetime.timezone.utc) def get_version_from_scm(root: pathlib.Path) -> str: pretend = os.environ.get('SETUPTOOLS_SCM_PRETEND_VERSION') if pretend: return pretend posint = r'(0|[1-9]\d*)' pep440_version_re = re.compile( rf""" ^ (?P<major>{posint}) \. (?P<minor>{posint}) ( \. (?P<micro>{posint}) )? ( (?P<prekind>a|b|rc) (?P<preval>{posint}) )? $ """, re.X, ) proc = subprocess.run( ['git', 'tag', '--list', 'v*'], stdout=subprocess.PIPE, universal_newlines=True, check=True, cwd=root, ) all_tags = { v[1:] for v in proc.stdout.strip().split('\n') if pep440_version_re.match(v[1:]) } proc = subprocess.run( ['git', 'tag', '--points-at', 'HEAD'], stdout=subprocess.PIPE, universal_newlines=True, check=True, cwd=root, ) head_tags = { v[1:] for v in proc.stdout.strip().split('\n') if pep440_version_re.match(v[1:]) } if all_tags & head_tags: tag = max(head_tags) else: tag = max(all_tags) m = pep440_version_re.match(tag) assert m is not None major = EDGEDB_MAJOR_VERSION minor = m.group('minor') micro = m.group('micro') or '' microkind = '.' if micro else '' prekind = m.group('prekind') or '' preval = m.group('preval') or '' if os.environ.get("EDGEDB_BUILD_IS_RELEASE"): # Release build. ver = f'{major}.{minor}{microkind}{micro}{prekind}{preval}' else: # Dev/nightly build.
microkind = '' micro = '' minor = '0' incremented_ver = f'{major}.{minor}{microkind}{micro}' proc = subprocess.run( ['git', 'rev-list', '--count', 'HEAD'], stdout=subprocess.PIPE, universal_newlines=True, check=True, cwd=root, ) commits_on_branch = proc.stdout.strip() ver = f'{incremented_ver}.dev{commits_on_branch}' proc = subprocess.run( ['git', 'rev-parse', '--verify', '--quiet', 'HEAD^{commit}'], stdout=subprocess.PIPE, universal_newlines=True, check=True, cwd=root, ) commitish = proc.stdout.strip() env = dict(os.environ) env['TZ'] = 'UTC' proc = subprocess.run( ['git', 'show', '-s', '--format=%cd', '--date=format-local:%Y%m%d%H', commitish], stdout=subprocess.PIPE, universal_newlines=True, check=True, cwd=root, env=env, ) rev_date = proc.stdout.strip() catver = EDGEDB_CATALOG_VERSION full_version = f'{ver}+d{rev_date}.g{commitish[:9]}.cv{catver}' build_target = os.environ.get("EDGEDB_BUILD_TARGET") if build_target: # Check that build target is encoded correctly _decode_build_target(build_target) else: plat = sys.platform if plat == "win32": plat = "windows" ident = [ platform.machine(), "pc" if plat == "windows" else "apple" if plat == "darwin" else "unknown", plat, ] if hasattr(platform, "libc_ver"): libc, _ = platform.libc_ver() if libc == "glibc": ident.append("gnu") elif libc == "musl": ident.append("musl") build_target = base64.b32encode( "-".join(ident).encode()).decode().rstrip("=").lower() build_date = os.environ.get("EDGEDB_BUILD_DATE") if build_date: # Validate _decode_build_date(build_date) else: now = datetime.datetime.now(tz=datetime.timezone.utc) build_date = now.strftime(r"%Y%m%d%H%M") version_line = f'{full_version}.r{build_date}.t{build_target}' if not os.environ.get("EDGEDB_BUILD_OFFICIAL"): build_type = "local" else: build_type = "official" version_line += f'.b{build_type}' version_hash = hashlib.sha256(version_line.encode("utf-8")).hexdigest() full_version = f"{version_line}.s{version_hash[:7]}" return full_version def get_cache_src_dirs(): 
find_spec = importlib.util.find_spec edgeql = pathlib.Path(find_spec('edb.edgeql').origin).parent return ( (pathlib.Path(find_spec('edb.schema').origin).parent, '.py'), (edgeql / 'compiler', '.py'), (edgeql / 'parser', '.py'), (pathlib.Path(find_spec('edb.lib').origin).parent, '.edgeql'), (pathlib.Path(find_spec('edb.pgsql.metaschema').origin).parent, '.py'), ) def get_default_tenant_id() -> str: return 'E' def get_version_line() -> str: ver_meta = get_version_metadata() extras = [] source = "" if build_date := ver_meta["build_date"]: nice_date = build_date.strftime("%Y-%m-%dT%H:%MZ") source += f" on {nice_date}" if ver_meta["scm_revision"]: source += f" from revision {ver_meta['scm_revision']}" if source_date := ver_meta["source_date"]: nice_date = source_date.strftime("%Y-%m-%dT%H:%MZ") source += f" ({nice_date})" if source: extras.append(f", built{source}") if ver_meta["target"]: extras.append(f"for {ver_meta['target']}") return get_version_string() + " ".join(extras) ================================================ FILE: edb/cli/.gitignore ================================================ /edgedb /gel ================================================ FILE: edb/cli/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import Optional, NoReturn import os import sys def rustcli(*, args: Optional[list[str]]=None) -> NoReturn: if args is None: args = [*sys.argv] os.execvpe('gel', args, os.environ) ================================================ FILE: edb/cli/__main__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Stub to allow invoking builtin `edgedb` CLI as `python -m edb.cli`.""" from __future__ import annotations import sys from edb import cli if __name__ == '__main__': sys.exit(cli.rustcli()) ================================================ FILE: edb/common/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # # DO NOT ADD ANYTHING TO THIS FILE. # Importing packages like `edb.common.debug` must not # cause any side-effects. ================================================ FILE: edb/common/_typing_inspect.py ================================================ # The MIT License (MIT) # # Portions Copyright (c) 2017-2019 Ivan Levkivskyi # Portions Copyright (c) 2021 MagicStack Inc. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. """ This is a micro-implementation of a subset of `typing-inspect` API that Gel relies on that only works on Python 3.9+. 
""" from __future__ import annotations import collections from types import GenericAlias, UnionType # type: ignore from typing import _GenericAlias # type: ignore from typing import Any, ClassVar, Generic, Optional, TypeVar, Union __all__ = [ "is_classvar", "is_typevar", "is_generic_type", "is_union_type", "is_tuple_type", "get_args", "get_generic_bases", "get_parameters", "get_origin", ] def is_classvar(t) -> bool: return t is ClassVar or _is_genericalias(t) and t.__origin__ is ClassVar def is_typevar(t) -> bool: return type(t) is TypeVar def is_generic_type(t) -> bool: return ( isinstance(t, type) and issubclass(t, Generic) # type: ignore or _is_genericalias(t) and t.__origin__ not in (Union, tuple, ClassVar, collections.abc.Callable) ) def is_union_type(t) -> bool: return ( t is Union or (_is_genericalias(t) and t.__origin__ is Union) or isinstance(t, UnionType) ) def is_tuple_type(t) -> bool: return ( t is tuple or _is_genericalias(t) and t.__origin__ is tuple or isinstance(t, type) and issubclass(t, Generic) # type: ignore and issubclass(t, tuple) ) def get_args(t, evaluate: bool = True) -> Any: if evaluate is not None and not evaluate: raise ValueError("evaluate can only be True in Python >= 3.7") if _is_genericalias(t) or isinstance(t, UnionType): res = t.__args__ if ( get_origin(t) is collections.abc.Callable and res[0] is not Ellipsis ): res = (list(res[:-1]), res[-1]) return res return () def get_generic_bases(t) -> tuple[type, ...]: return getattr(t, "__orig_bases__", ()) def get_parameters(t) -> tuple[TypeVar, ...]: if ( _is_genericalias(t) or isinstance(t, type) and issubclass(t, Generic) # type: ignore and t is not Generic ): return t.__parameters__ else: return () def get_origin(t) -> Optional[type]: if _is_genericalias(t): return t.__origin__ if t.__origin__ is not ClassVar else None if t is Generic: return Generic # type: ignore return None def _is_genericalias(t) -> bool: return isinstance(t, (GenericAlias, _GenericAlias)) 
================================================ FILE: edb/common/adapter.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any, Optional class AdapterError(Exception): pass _adapters: dict[Any, dict[type, Adapter]] = {} class Adapter(type): __edb_adaptee__: Optional[type] def __new__[Adapter_T: Adapter]( mcls: type[Adapter_T], name: str, bases: tuple[type, ...], clsdict: dict[str, Any], *, adapts: Optional[type] = None, **kwargs: Any, ) -> Adapter_T: if adapts is not None: bases = bases + (adapts,) clsdict['__edb_adaptee__'] = adapts result = super().__new__(mcls, name, bases, clsdict, **kwargs) if adapts is not None: assert issubclass(mcls, Adapter) and mcls is not Adapter try: adapters = _adapters[mcls] except KeyError: adapters = _adapters[mcls] = {} assert adapts not in adapters adapters[adapts] = result return result def __init__( cls, name: str, bases: tuple[type, ...], clsdict: dict[str, Any], *, adapts: Optional[type] = None, **kwargs: Any, ): super().__init__(name, bases, clsdict, **kwargs) @classmethod def _match_adapter( mcls, obj: type, adaptee: type, adapter: Adapter, ) -> Optional[Adapter]: if issubclass(obj, adapter) and obj is not adapter: # mypy bug below return obj # type: ignore elif issubclass(obj, adaptee): return adapter 
else: return None @classmethod def _get_adapter( mcls, reversed_mro: tuple[type, ...], ) -> Optional[Adapter]: adapters = _adapters.get(mcls) if adapters is None: return None result = None seen: set[Adapter] = set() for base in reversed_mro: for adaptee, adapter in adapters.items(): found = mcls._match_adapter(base, adaptee, adapter) if found and found not in seen: result = found seen.add(found) return result @classmethod def get_adapter(mcls, obj: Any) -> Optional[Adapter]: mro = obj.__mro__ reversed_mro = tuple(reversed(mro)) result = mcls._get_adapter(reversed_mro) if result is not None: return result for mc in mcls.__subclasses__(mcls): result = mc._get_adapter(reversed_mro) if result is not None: return result return None @classmethod def adapt[T](mcls, obj: T) -> T: adapter = mcls.get_adapter(obj.__class__) if adapter is None: raise AdapterError( 'could not find {}.{} adapter for {}'.format( mcls.__module__, mcls.__name__, obj.__class__.__name__ ) ) elif adapter is not obj.__class__: # type: ignore return adapter.adapt(obj) else: return obj def get_adaptee(cls) -> type: adaptee = cls.__edb_adaptee__ if adaptee is None: raise LookupError(f'adapter {cls} has no adaptee type') return adaptee def has_adaptee(cls) -> bool: return cls.__edb_adaptee__ is not None ================================================ FILE: edb/common/assert_data_shape.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import datetime import decimal import math import pprint import uuid import unittest import edgedb class bag(list): """Wrapper for list that tells assert_query_result to ignore order""" def __repr__(self): return f'bag({list.__repr__(self)})' def sort_results(results, sort): if sort is True: sort = lambda x: x # don't bother sorting empty things if results: # sort can be either a key function or a dict if isinstance(sort, dict): # the keys in the dict indicate the fields that # actually must be sorted for key, val in sort.items(): # '.' is a special key referring to the base object if key == '.': sort_results(results, val) else: if isinstance(results, list): for r in results: sort_results(r[key], val) else: sort_results(results[key], val) else: results.sort(key=sort) def assert_data_shape( data, shape, fail, message=None, from_sql=False, rel_tol=None, abs_tol=None, ): try: import asyncpg from asyncpg import types as pgtypes except ImportError: if from_sql: raise unittest.SkipTest( 'SQL tests skipped: asyncpg not installed') base_fail = fail rel_tol = 1e-04 if rel_tol is None else rel_tol abs_tol = 1e-15 if abs_tol is None else abs_tol def fail(msg): base_fail(f'{msg}\nshape: {shape!r}\ndata: {data!r}') _void = object() def _format_path(path): if path: return 'PATH: ' + ''.join(str(p) for p in path) else: return 'PATH: <top-level>' def _assert_type_shape(path, data, shape): if shape in (int, float): if not isinstance(data, shape): fail( f'{message}: expected {shape}, got {data!r} ' f'{_format_path(path)}') else: try: shape(data) except (ValueError, TypeError): fail( f'{message}: expected {shape}, got {data!r} ' f'{_format_path(path)}') def _assert_dict_shape(path, data, shape): if not isinstance(data, dict): fail( f'{message}: expected dict ' f'{_format_path(path)}') # TODO: should we also check that there aren't *extra* keys # (other than
id, __tname__?) for sk, sv in shape.items(): if not data or sk not in data: fail( f'{message}: key {sk!r} ' f'is missing\n{pprint.pformat(data)} ' f'{_format_path(path)}') _assert_generic_shape(path + (f'["{sk}"]',), data[sk], sv) def _list_shape_iter(shape): last_shape = _void for item in shape: if item is Ellipsis: if last_shape is _void: raise ValueError( 'invalid shape spec: Ellipsis cannot be the ' 'first element') while True: yield last_shape last_shape = item yield item def _assert_list_shape(path, data, shape): if not isinstance(data, (list, tuple)): fail( f'{message}: expected list got {type(data)} ' f'{_format_path(path)}') if not data and shape: fail( f'{message}: expected non-empty list got {type(data)} ' f'{_format_path(path)}') shape_iter = _list_shape_iter(shape) _data_count = 0 for _data_count, el in enumerate(data): try: el_shape = next(shape_iter) except StopIteration: fail( f'{message}: unexpected trailing elements in list ' f'{_format_path(path)}') _assert_generic_shape( path + (f'[{_data_count}]',), el, el_shape) if len(shape) > _data_count + 1: if shape[_data_count + 1] is not Ellipsis: fail( f'{message}: expecting more elements in list ' f'{_format_path(path)}') def _assert_set_shape(path, data, shape): if not isinstance(data, (list, set)): fail( f'{message}: expected list or set ' f'{_format_path(path)}') if not data and shape: fail( f'{message}: expected non-empty set ' f'{_format_path(path)}') shape_iter = _list_shape_iter(sorted(shape)) _data_count = 0 for _data_count, el in enumerate(sorted(data)): try: el_shape = next(shape_iter) except StopIteration: fail( f'{message}: unexpected trailing elements in set ' f'{_format_path(path)}') _assert_generic_shape( path + (f'{{{_data_count}}}',), el, el_shape) if len(shape) > _data_count + 1: if Ellipsis not in shape: fail( f'{message}: expecting more elements in set ' f'{_format_path(path)}') def _assert_bag_shape(path, data, shape): # A bag is treated like a set except that we want it to
work # on objects, which can't be hashed or sorted. if not isinstance(data, (list, set)): fail( f'{message}: expected list or set ' f'{_format_path(path)}') if Ellipsis in shape: raise ValueError( f"{message}: can't use ellipsis in set/bag shape") data = list(data) if len(data) > len(shape): fail( f'{message}: too many elements in list ' f'{_format_path(path)}') # this is all very O(n^2) but n should be small for el_shape in shape: for data_count, el in enumerate(data): try: _assert_generic_shape( path + (f'[{data_count}]',), el, el_shape) except AssertionError: # oh well pass else: data.pop(data_count) break else: fail( f'{message}: missing elements in list ' f'{_format_path(path)}: {el_shape!r}') def _assert_generic_shape(path, data, shape): if from_sql: if isinstance(shape, bag): return _assert_bag_shape(path, data, shape) elif isinstance(shape, list): # NULL is acceptable substitute for the empty set, so we'll # assume that in our tests None satisfies the [] expected # result. if data is not None or len(shape) > 0: return _assert_list_shape(path, data, shape) elif isinstance(shape, tuple): assert isinstance(data, asyncpg.Record) return _assert_list_shape( path, [d for d in data.values()], shape) elif isinstance(shape, set): return _assert_set_shape(path, data, shape) elif isinstance(shape, dict): assert isinstance(data, asyncpg.Record) # If the record has "target" pop the "id" from the expected # results as we expect it to be a "target" duplicate. rec = {k: v for k, v in data.items()} if 'target' in rec: if 'id' in shape and shape['id'] == shape.get('target'): shape.pop('id') return _assert_dict_shape(path, rec, shape) elif isinstance(shape, type): return _assert_type_shape(path, data, shape) elif isinstance(shape, float): if not math.isclose(data, shape, rel_tol=rel_tol, abs_tol=abs_tol): fail( f'{message}: not isclose({data}, {shape}) ' f'{_format_path(path)}') elif isinstance(shape, uuid.UUID): # If data comes from SQL, we expect UUID. 
if data != shape: fail( f'{message}: {data!r} != {shape!r} ' f'{_format_path(path)}') elif isinstance(shape, (str, int, bytes, datetime.timedelta, decimal.Decimal)): if data != shape: fail( f'{message}: {data!r} != {shape!r} ' f'{_format_path(path)}') elif isinstance(shape, edgedb.RelativeDuration): if data != datetime.timedelta( days=shape.months * 30 + shape.days, microseconds=shape.microseconds, ): fail( f'{message}: {data!r} != {shape!r} ' f'{_format_path(path)}') elif isinstance(shape, edgedb.DateDuration): if data != datetime.timedelta( days=shape.months * 30 + shape.days, ): fail( f'{message}: {data!r} != {shape!r} ' f'{_format_path(path)}') elif isinstance(shape, edgedb.Range): if data != pgtypes.Range( lower=shape.lower, upper=shape.upper, lower_inc=shape.inc_lower, upper_inc=shape.inc_upper, empty=shape.is_empty(), ): fail( f'{message}: {data!r} != {shape!r} ' f'{_format_path(path)}') elif isinstance(shape, edgedb.EnumValue): if data != str(shape): fail( f'{message}: {data!r} != {shape!r} ' f'{_format_path(path)}') elif shape is None: if data is not None: fail( f'{message}: {data!r} is expected to be None ' f'{_format_path(path)}') else: if data != shape: fail( f'{message}: ({type(data)}) {data!r} != ' f'({type(shape)}) {shape!r} ' f'{_format_path(path)}') else: if isinstance(shape, bag): return _assert_bag_shape(path, data, shape) elif isinstance(shape, (list, tuple)): return _assert_list_shape(path, data, shape) elif isinstance(shape, set): return _assert_set_shape(path, data, shape) elif isinstance(shape, dict): return _assert_dict_shape(path, data, shape) elif isinstance(shape, type): return _assert_type_shape(path, data, shape) elif isinstance(shape, float): if math.isnan(shape): if not math.isnan(data): fail( f'NaN mismatch {_format_path(path)}' ) elif not math.isclose(data, shape, rel_tol=rel_tol, abs_tol=abs_tol): fail( f'{message}: not isclose({data}, {shape}) ' f'{_format_path(path)}') elif isinstance(shape, uuid.UUID): # We expect a str from 
JSON. if data != str(shape): fail( f'{message}: {data!r} != {shape!r} ' f'{_format_path(path)}') elif isinstance(shape, (str, int, bytes, datetime.timedelta, decimal.Decimal)): if data != shape: fail( f'{message}: {data!r} != {shape!r} ' f'{_format_path(path)}') elif isinstance(shape, edgedb.RelativeDuration): if data != shape: fail( f'{message}: {data!r} != {shape!r} ' f'{_format_path(path)}') elif isinstance(shape, edgedb.DateDuration): if data != shape: fail( f'{message}: {data!r} != {shape!r} ' f'{_format_path(path)}') elif shape is None: if data is not None: fail( f'{message}: {data!r} is expected to be None ' f'{_format_path(path)}') else: raise ValueError(f'unsupported shape type {shape}') message = message or 'data shape differs' return _assert_generic_shape((), data, shape) ================================================ FILE: edb/common/ast/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from .base import * # NOQA from .visitor import * # NOQA from .transformer import * # NOQA from .codegen import * # NOQA ================================================ FILE: edb/common/ast/base.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. 
and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import copy import collections.abc import functools import re import sys from typing import ( Any, Callable, cast, get_type_hints, TYPE_CHECKING, AbstractSet # NoQA ) from edb.common import debug from edb.common import markup from edb.common import typing_inspect class ASTError(Exception): pass class _Field: def __init__( self, name, type_, default, factory, field_hidden=False, field_meta=False, ): self.name = name self.type = type_ self.default = default self.factory = factory self.hidden = field_hidden self.meta = field_meta class _FieldSpec: def __init__(self, factory): self.factory = factory def field[T](*, factory: Callable[[], T]) -> T: return cast(T, _FieldSpec(factory=factory)) def _check_type_passthrough(type_, value, raise_error): pass def _check_type_real(type_, value, raise_error): if type_ is None: return if typing_inspect.is_union_type(type_): for t in typing_inspect.get_args(type_, evaluate=True): try: _check_type(t, value, raise_error) except TypeError: pass else: break else: raise_error(str(type_), value) elif typing_inspect.is_tuple_type(type_): _check_tuple_type(type_, value, raise_error, tuple) elif typing_inspect.is_generic_type(type_): ot = typing_inspect.get_origin(type_) if ot in (list, list, collections.abc.Sequence): _check_container_type(type_, value, raise_error, list) elif ot in (set, set): _check_container_type(type_, value, raise_error, set) 
elif ot in (frozenset, frozenset): _check_container_type(type_, value, raise_error, frozenset) elif ot in (dict, dict): _check_mapping_type(type_, value, raise_error, dict) elif ot is not None: raise TypeError(f'unsupported typing type: {type_!r}') elif type_ is not Any: if value is not None and not isinstance(value, type_): raise_error(type_.__name__, value) if debug.flags.typecheck: _check_type = _check_type_real else: _check_type = _check_type_passthrough class AST: # These use type comments because type annotations are interpreted # by the AST system and so annotating them would interfere! __ast_frozen_fields__ = frozenset() # type: AbstractSet[str] # Class setup stuff: @classmethod def _collect_direct_fields(cls): dct = cls.__dict__ cls.__abstract_node__ = bool(dct.get('__abstract_node__')) cls.__rust_ignore__ = bool(dct.get('__rust_ignore__')) if '__annotations__' not in dct: cls._direct_fields = [] return cls globalns = sys.modules[cls.__module__].__dict__.copy() globalns[cls.__name__] = cls try: while True: try: annos = get_type_hints(cls, globalns) except NameError as e: # Forward type declaration. Generally, we try # to avoid these as much as possible, but when # there's a cycle it's better to have correct # static type analysis even though the runtime # validation infrastructure does not support # cyclic references. # XXX: This is a horrible hack, need to find # a better way. 
m = re.match(r"name '(\w+)' is not defined", e.args[0]) if not m: raise globalns[m.group(1)] = AST else: break except Exception: raise RuntimeError( f'unable to resolve type annotations for ' f'{cls.__module__}.{cls.__qualname__}') if annos: annos = {k: v for k, v in annos.items() if k in dct['__annotations__']} hidden = () if '__ast_hidden__' in dct: hidden = set(dct['__ast_hidden__']) meta = () if '__ast_meta__' in dct: meta = set(dct['__ast_meta__']) fields = [] for f_name, f_type in annos.items(): if f_type is object: f_type = None factory = None if f_name in dct: f_default = dct[f_name] if isinstance(f_default, _FieldSpec): factory = f_default.factory f_default = None delattr(cls, f_name) else: f_default = None f_hidden = f_name in hidden f_meta = f_name in meta fields.append(_Field( f_name, f_type, f_default, factory, f_hidden, f_meta )) cls._direct_fields = fields return cls def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) cls._collect_direct_fields() fields = collections.OrderedDict() for parent in reversed(cls.__mro__): lst = getattr(parent, '_direct_fields', []) for field in lst: fields[field.name] = field cls._fields = fields cls._field_factories = tuple( (k, v.factory) for k, v in fields.items() if v.factory and not isinstance(getattr(cls, k, None), property) ) # Push the default values down in the MRO for k, v in cls._fields.items(): if ( not v.factory and not isinstance(getattr(cls, k, None), property) and k not in cls.__dict__ ): setattr(cls, k, v.default) @classmethod def get_field(cls, name): return cls._fields.get(name) # Actual object level code def __init__(self, **kwargs): if type(self).__abstract_node__: raise ASTError( f'cannot instantiate abstract AST node ' f'{self.__class__.__name__!r}') # Make kwargs directly into our __dict__ for field_name, factory in self._field_factories: if field_name not in kwargs: kwargs[field_name] = factory() should_check_types = __debug__ and _check_type is _check_type_real if 
should_check_types: for k, v in kwargs.items(): self.check_field_type(self._fields[k], v) self.__dict__ = kwargs def __copy__(self): copied = self._init_copy() for field, value in iter_fields(self, include_meta=True): try: object.__setattr__(copied, field, value) except AttributeError: # don't mind not setting getter_only attrs. continue return copied def __deepcopy__(self, memo): copied = self._init_copy() for field, value in iter_fields(self, include_meta=True): object.__setattr__(copied, field, copy.deepcopy(value, memo)) return copied def _init_copy(self): return self.__class__() def replace[T](self: T, **changes) -> T: copied = copy.copy(self) for field, value in changes.items(): object.__setattr__(copied, field, value) return copied def _checked_setattr(self, name, value): super().__setattr__(name, value) field = self._fields.get(name) if field: self.check_field_type(field, value) if name in self.__ast_frozen_fields__: raise TypeError(f'cannot set immutable {name} on {self!r}') if __debug__ and _check_type is _check_type_real: __setattr__ = _checked_setattr def check_field_type(self, field, value): def raise_error(field_type_name, value): raise TypeError( '%s.%s.%s: expected %s but got %s' % ( self.__class__.__module__, self.__class__.__name__, field.name, field_type_name, value.__class__.__name__)) _check_type(field.type, value, raise_error) def dump(self, *, meta=True): markup.dump(self, _ast_include_meta=meta) class ImmutableASTMixin: __frozen = False # This uses type comments because type annotations are interpreted # by the AST system and so annotating them would interfere! __ast_mutable_fields__ = frozenset() # type: AbstractSet[str] def __init__(self, **kwargs): super().__init__(**kwargs) self.__frozen = True # mypy gets mad about this if there isn't a __setattr__ in AST. # I don't know why. 
if not TYPE_CHECKING: def __setattr__(self, name, value): if self.__frozen and name not in self.__ast_mutable_fields__: raise TypeError(f'cannot set {name} on immutable {self!r}') else: super().__setattr__(name, value) @markup.serializer.serializer.register(AST) def serialize_to_markup(ast, *, ctx): node = markup.elements.lang.TreeNode(id=id(ast), name=type(ast).__name__) include_meta = ctx.kwargs.get('_ast_include_meta', True) exclude_unset = ctx.kwargs.get('_ast_exclude_unset', True) if debug.flags.ast_span: s = getattr(ast, 'span', None) if s: node.add_child(label='span', node=markup.serialize(str(s), ctx=ctx)) fields = iter_fields( ast, include_meta=include_meta, exclude_unset=exclude_unset) for fieldname, field in fields: if ast._fields[fieldname].hidden: continue if field is None: if ast._fields[fieldname].meta: continue node.add_child(label=fieldname, node=markup.serialize(field, ctx=ctx)) return node @functools.lru_cache(1024) def _is_ast_node_type(cls): return issubclass(cls, AST) def is_ast_node(obj): return _is_ast_node_type(obj.__class__) _marker = object() def iter_fields(node, *, include_meta=True, exclude_unset=False): exclude_meta = not include_meta for field_name, field in node._fields.items(): if exclude_meta and field.meta: continue field_val = getattr(node, field_name, _marker) if field_val is _marker: continue if exclude_unset: if field.factory: default = field.factory() else: default = field.default if field_val == default: continue yield field_name, field_val def _is_optional(type_): return (typing_inspect.is_union_type(type_) and type(None) in typing_inspect.get_args(type_, evaluate=True)) def _check_container_type(type_, value, raise_error, instance_type): if not isinstance(value, instance_type): raise_error(str(type_), value) type_args = typing_inspect.get_args(type_, evaluate=True) eltype = type_args[0] for el in value: _check_type(eltype, el, raise_error) def _check_tuple_type(type_, value, raise_error, instance_type): if not 
isinstance(value, instance_type): raise_error(str(type_), value) eltype = None ellipsis = False type_args = typing_inspect.get_args(type_, evaluate=True) for i, el in enumerate(value): if not ellipsis: new_eltype = type_args[i] if new_eltype is Ellipsis: ellipsis = True else: eltype = new_eltype if eltype is not None: _check_type(eltype, el, raise_error) def _check_mapping_type(type_, value, raise_error, instance_type): if not isinstance(value, instance_type): raise_error(str(type_), value) type_args = typing_inspect.get_args(type_, evaluate=True) ktype = type_args[0] vtype = type_args[1] for k, v in value.items(): _check_type(ktype, k, raise_error) if not k and not _is_optional(ktype): raise RuntimeError('empty key in map') _check_type(vtype, v, raise_error) ================================================ FILE: edb/common/ast/codegen.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any, Optional, Iterable, Sequence from dataclasses import dataclass import itertools import textwrap from . 
import base from .visitor import NodeVisitor @dataclass(kw_only=True, eq=False, match_args=False, slots=True, frozen=True) class Options: indent_with: str = ' ' * 4 add_line_information: bool = False pretty: bool = True class SourceGenerator(NodeVisitor): """Generate source code from an AST tree.""" result: list[str] def __init__( self, indent_with: str = ' ' * 4, add_line_information: bool = False, pretty: bool = True ) -> None: self.result = [] self.indent_with = indent_with self.add_line_information = add_line_information self.indentation = 0 self.char_indentation = 0 self.new_lines = 0 self.current_line = 1 self.pretty = pretty def node_visit(self, node: base.AST) -> None: method = 'visit_' + node.__class__.__name__ visitor = getattr(self, method, self.generic_visit) return visitor(node) def visit_indented( self, node: base.AST, indent: bool = True, nest: bool = False ) -> None: if nest: self.write("(") if indent: self.new_lines = 1 self.char_indentation += 1 res = self.visit(node) if indent: self.char_indentation -= 1 if nest: self.write(")") self.new_lines = 1 return res def write(self, *x: str, delimiter: Optional[str] = None) -> None: if not x: return if self.new_lines: if self.result and self.pretty: self.current_line += self.new_lines self.result.append('\n' * self.new_lines) if self.pretty: self.result.append(self.indent_with * self.indentation) self.result.append(' ' * self.char_indentation) else: self.result.append(' ') self.new_lines = 0 if delimiter: self.result.append(x[0]) chain = itertools.chain.from_iterable chunks: Iterable[str] = chain((delimiter, v) for v in x[1:]) else: chunks = x for chunk in chunks: if not isinstance(chunk, str): raise ValueError( 'invalid text chunk in codegen: {!r}'.format(chunk)) self.result.append(chunk) def visit_list( self, items: Sequence[base.AST], *, separator: str = ',', terminator: Optional[str] = None, newlines: bool = True, **kwargs: Any ) -> None: # terminator overrides separator setting # separator = 
terminator if terminator is not None else separator size = len(items) for i, item in enumerate(items): self.visit(item, **kwargs) # type: ignore if i < size - 1 or terminator is not None: self.write(separator) if newlines: self.new_lines = 1 else: self.write(' ') def newline(self, node=None, extra=0): self.new_lines = max(self.new_lines, 1 + extra) if node is not None and self.add_line_information: self.write('# line: %s' % node.lineno) self.new_lines = 1 def finish(self) -> str: return ''.join(self.result) @classmethod def to_source( cls, node: base.AST | Sequence[base.AST], indent_with: str = ' ' * 4, add_line_information: bool = False, pretty: bool = True, **kwargs: Any ) -> str: generator = cls(indent_with, add_line_information, # type: ignore pretty=pretty, **kwargs) generator.visit(node) return generator.finish() def indent_text(self, text: str) -> str: return textwrap.indent(text, self.indent_with * self.indentation) ================================================ FILE: edb/common/ast/transformer.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from edb.common import typeutils from . import base from . import visitor class NodeTransformer(visitor.NodeVisitor): """Walks the abstract syntax tree and allows modification of nodes. 
The `NodeTransformer` will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is ``None``, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place. Here is an example transformer that rewrites all occurrences of name lookups (``foo``) to ``data['foo']``:: class RewriteName(NodeTransformer): def visit_Name(self, node): return copy_location(Subscript( value=Name(id='data', ctx=Load()), slice=Index(value=Str(s=node.id)), ctx=node.ctx ), node) Keep in mind that if the node you're operating on has child nodes you must either transform the child nodes yourself or call the :meth:`generic_visit` method for the node first. For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node. Usually you use the transformer like this:: node = YourTransformer().visit(node) """ def generic_visit(self, node): if isinstance(node, base.ImmutableASTMixin): changes = {} for field, old_value in base.iter_fields(node, include_meta=False): field_spec = node._fields[field] if self.skip_hidden and field_spec.hidden: continue if field in self.extra_skips: continue old_value = getattr(node, field, None) if typeutils.is_container(old_value): new_values = old_value.__class__(self.visit(old_value)) changes[field] = old_value.__class__(new_values) elif isinstance(old_value, base.AST): new_node = self.visit(old_value) if new_node is not old_value: changes[field] = new_node node = node.replace(**changes) else: for field, old_value in base.iter_fields(node, include_meta=False): field_spec = node._fields[field] if self.skip_hidden and field_spec.hidden: continue if field in self.extra_skips: continue old_value = getattr(node, field, None) if typeutils.is_container(old_value): new_values = 
old_value.__class__(self.visit(old_value)) setattr(node, field, old_value.__class__(new_values)) elif isinstance(old_value, base.AST): new_node = self.visit(old_value) if new_node is not old_value: setattr(node, field, new_node) return node ================================================ FILE: edb/common/ast/visitor.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( AbstractSet, Any, Callable, Collection, Optional, Iterable ) from edb.common import typeutils from . import base class SkipNode(Exception): pass def find_children[_T]( node: base.AST | Collection[base.AST], type: type[_T], test_func: Optional[Callable[[_T], bool]] = None, terminate_early=False, extra_skips: AbstractSet[str] = frozenset(), extra_skip_types: tuple[type, ...] 
= (), ) -> list[_T]: visited = set() result = [] def _find_children(node): if isinstance(node, extra_skip_types): return False elif isinstance(node, (tuple, list, set, frozenset)): for n in node: if _find_children(n): return True return False elif isinstance(node, dict): for n in node.values(): if _find_children(n): return True return False elif not base.is_ast_node(node): return False if node in visited: return False else: visited.add(node) try: if isinstance(node, type) and (not test_func or test_func(node)): result.append(node) if terminate_early: return True except SkipNode: return False for field, value in base.iter_fields(node, include_meta=False): field_spec = node._fields[field] if field_spec.hidden or field_spec.name in extra_skips: continue if _find_children(value): return True return False _find_children(node) return result class NodeVisitor: """Walk the AST and call a visitor function for every node found. This class is meant to be subclassed, with the subclass adding visitor methods. Per default the visitor functions for the nodes are ``'visit_'`` + class name of the node. So a `TryFinally` node visit function would be `visit_TryFinally`. This behavior can be changed by overriding the `visit` method. If no visitor function exists for a node (return value `None`) the `generic_visit` visitor is used instead. Don't use the `NodeVisitor` if you want to apply changes to nodes during traversing. For this a special visitor exists (`NodeTransformer`) that allows modifications. 
""" skip_hidden = False extra_skips: AbstractSet[str] = frozenset() def __init__(self, *, context=None, memo=None): if memo is not None: self._memo = memo else: self._memo = {} self._context = context @property def memo(self): return self._memo @classmethod def run(cls, node, **kwargs): visitor = cls(**kwargs) return visitor.visit(node) def container_visit(self, node) -> dict[Any, Any] | Iterable[Any]: def _visit_element(elem): if base.is_ast_node(elem) or typeutils.is_container(elem): return self.visit(elem) else: return elem result: dict[Any, Any] | Iterable[Any] if isinstance(node, dict): result = {} for key, value in node.items(): result[key] = _visit_element(value) elif isinstance(node, tuple): result = () for elem in node: result += (_visit_element(elem),) else: result = [] for elem in node: result.append(_visit_element(elem)) return result def repeated_node_visit(self, node): result = self.memo[node] if result is None: return node else: return result def node_visit(self, node): if node in self.memo: return self.repeated_node_visit(node) else: self.memo[node] = None for cls in node.__class__.__mro__: method = 'visit_' + cls.__name__ visitor = getattr(self, method, None) if visitor is not None: break else: visitor = self.generic_visit result = visitor(node) self.memo[node] = result return result def visit(self, node): if typeutils.is_container(node): return self.container_visit(node) elif base.is_ast_node(node): return self.node_visit(node) def generic_visit(self, node, *, combine_results=None): field_results = [] for field, value in base.iter_fields(node, include_meta=False): field_spec = node._fields[field] if self.skip_hidden and field_spec.hidden: continue if field in self.extra_skips: continue res = self.visit(value) if res is not None: field_results.append(res) if combine_results is not None: return combine_results(field_results) else: return self.combine_field_results(field_results) def combine_field_results(self, results): return results def 
nodes_equal(n1, n2): if type(n1) is not type(n2): return False for field, _value in base.iter_fields(n1, include_meta=False): if not n1._fields[field].hidden: n1v = getattr(n1, field) n2v = getattr(n2, field) if typeutils.is_container(n1v): n1v = list(n1v) if typeutils.is_container(n2v): n2v = list(n2v) else: return False if len(n1v) != len(n2v): return False for i, item1 in enumerate(n1v): try: item2 = n2v[i] except IndexError: return False if base.is_ast_node(item1): if not nodes_equal(item1, item2): return False else: if item1 != item2: return False elif base.is_ast_node(n1v): if not nodes_equal(n1v, n2v): return False else: if n1v != n2v: return False return True ================================================ FILE: edb/common/asyncutil.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2018-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Awaitable, Callable, cast, overload, Self, TypeVar, ) import asyncio import inspect import warnings async def deferred_shield[T](arg: Awaitable[T]) -> T: '''Wait for a future, deferring cancellation until it is complete. 
If you do await deferred_shield(something()) it is approximately equivalent to await something() except that if the coroutine containing it is cancelled, something() is protected from cancellation, and *additionally* CancelledError is not raised in the caller until something() completes. This can be useful if something() contains something that shouldn't be interrupted but also can't be safely left running asynchronously. ''' task = asyncio.ensure_future(arg) ex = None while not task.done(): try: await asyncio.shield(task) except asyncio.CancelledError as cex: if ex is not None: cex.__context__ = ex ex = cex except Exception: if ex: raise ex from None raise if ex: raise ex return task.result() async def debounce[T]( input: Callable[[], Awaitable[T]], output: Callable[[list[T]], Awaitable[None]], *, max_wait: float, delay_amt: float, max_batch_size: int, ) -> None: '''Debounce and batch async events. Loops forever unless an operation fails, so should probably be run from a task. The basic algorithm is that if an event comes in less than `delay_amt` after the previous one, then instead of sending it immediately, we wait an additional `delay_amt` from then. If we are already waiting, any message also extends the wait, up to `max_wait`. Also, cap the maximum batch size to `max_batch_size`. ''' # I think the algorithm reads more clearly with the params # capitalized as constants, though we don't want them like that in # the argument list, so reassign them. MAX_WAIT, DELAY_AMT, MAX_BATCH_SIZE = max_wait, delay_amt, max_batch_size loop = asyncio.get_running_loop() batch = [] last_signal = -MAX_WAIT target_time = None while True: try: if target_time is None: v = await input() else: async with asyncio.timeout_at(target_time): v = await input() except TimeoutError: t = loop.time() else: batch.append(v) t = loop.time() # If we aren't currently waiting, and we got a # notification recently, arrange to wait a bit before # sending it. 
if ( target_time is None and t - last_signal < DELAY_AMT ): target_time = t + DELAY_AMT # If we were already waiting, wait a little longer, though # not longer than MAX_WAIT. elif ( target_time is not None ): target_time = min( max(t + DELAY_AMT, target_time), last_signal + MAX_WAIT, ) # Skip sending the event if we need to wait longer. if ( target_time is not None and t < target_time and len(batch) < MAX_BATCH_SIZE ): continue await output(batch) batch = [] last_signal = t target_time = None _Owner = TypeVar("_Owner") HandlerFunction = Callable[[], Awaitable[None]] HandlerMethod = Callable[[Any], Awaitable[None]] class ExclusiveTask: """Manages to run a repeatable task once at a time.""" _handler: HandlerFunction _task: asyncio.Task | None _scheduled: bool _stop_requested: bool def __init__(self, handler: HandlerFunction) -> None: self._handler = handler self._task = None self._scheduled = False self._stop_requested = False @property def scheduled(self) -> bool: return self._scheduled async def _run(self) -> None: if self._scheduled and not self._stop_requested: self._scheduled = False else: return try: await self._handler() finally: if self._scheduled and not self._stop_requested: self._task = asyncio.create_task(self._run()) else: self._task = None def schedule(self) -> None: """Schedule to run the task as soon as possible. If already scheduled, nothing happens; it won't queue up. If the task is already running, it will be scheduled to run again as soon as the running task is done. """ if not self._stop_requested: self._scheduled = True if self._task is None: self._task = asyncio.create_task(self._run()) async def stop(self) -> None: """Cancel scheduled task and wait for the running one to finish. After an ExclusiveTask is stopped, no more new schedules are allowed. Note: "cancel scheduled task" only means setting self._scheduled to False; if an asyncio task is scheduled, stop() will still wait for it. 
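The scheduling rules described above (repeated schedule() calls coalesce, and a call made while the handler is running leads to exactly one more run) can be illustrated with a simplified, self-contained sketch. `MiniExclusiveTask` and `demo` are hypothetical names used only for illustration; the sketch leaves out stop() and the stop-request flag.

```python
import asyncio


class MiniExclusiveTask:
    # Hypothetical, simplified stand-in: runs a handler at most once at a time.
    def __init__(self, handler):
        self._handler = handler
        self._task = None
        self._scheduled = False

    async def _run(self):
        self._scheduled = False
        try:
            await self._handler()
        finally:
            # If schedule() was called while running, run once more.
            if self._scheduled:
                self._task = asyncio.create_task(self._run())
            else:
                self._task = None

    def schedule(self):
        self._scheduled = True
        if self._task is None:
            self._task = asyncio.create_task(self._run())


async def demo():
    runs = 0

    async def handler():
        nonlocal runs
        runs += 1
        await asyncio.sleep(0.01)

    task = MiniExclusiveTask(handler)
    task.schedule()
    task.schedule()          # coalesced with the first call
    await asyncio.sleep(0)   # let the first run start
    task.schedule()          # requested mid-run: one extra run follows
    task.schedule()          # coalesced with the previous call
    await asyncio.sleep(0.1)
    return runs


run_count = asyncio.run(demo())
print(run_count)
```

Four schedule() calls result in two handler runs here: the first two collapse into one run, and the two made while that run is in progress collapse into a single follow-up run.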
""" self._scheduled = False self._stop_requested = True if self._task is not None: await self._task class ExclusiveTaskProperty: _method: HandlerMethod _name: str | None def __init__( self, method: HandlerMethod, *, slot: str | None = None ) -> None: self._method = method self._name = slot def __set_name__(self, owner: type[_Owner], name: str) -> None: if (slots := getattr(owner, "__slots__", None)) is not None: if self._name is None: raise TypeError("missing slot in @exclusive_task()") if self._name not in slots: raise TypeError( f"slot {self._name!r} must be defined in __slots__" ) if self._name is None: self._name = name @overload def __get__(self, instance: None, owner: type[_Owner]) -> Self: ... @overload def __get__( self, instance: _Owner, owner: type[_Owner] ) -> ExclusiveTask: ... def __get__( self, instance: _Owner | None, owner: type[_Owner] ) -> ExclusiveTask | Self: # getattr on the class if instance is None: return self assert self._name is not None # getattr on an object with __dict__ if (d := getattr(instance, "__dict__", None)) is not None: if rv := d.get(self._name, None): return rv rv = ExclusiveTask(self._method.__get__(instance, owner)) d[self._name] = rv return rv # getattr on an object with __slots__ else: if rv := getattr(instance, self._name, None): return rv rv = ExclusiveTask(self._method.__get__(instance, owner)) setattr(instance, self._name, rv) return rv ExclusiveTaskDecorator = Callable[ [HandlerFunction | HandlerMethod], ExclusiveTask | ExclusiveTaskProperty ] def _exclusive_task( handler: HandlerFunction | HandlerMethod, *, slot: str | None ) -> ExclusiveTask | ExclusiveTaskProperty: sig = inspect.signature(handler) params = list(sig.parameters.values()) if len(params) == 0: handler = cast(HandlerFunction, handler) if slot is not None: warnings.warn( "slot is specified but unused in @exclusive_task()", stacklevel=2, ) return ExclusiveTask(handler) elif len(params) == 1 and params[0].kind in ( inspect.Parameter.POSITIONAL_ONLY, 
inspect.Parameter.POSITIONAL_OR_KEYWORD, ): handler = cast(HandlerMethod, handler) return ExclusiveTaskProperty(handler, slot=slot) else: raise TypeError("bad signature") @overload def exclusive_task(handler: HandlerFunction) -> ExclusiveTask: ... @overload def exclusive_task( handler: HandlerMethod, *, slot: str | None = None ) -> ExclusiveTaskProperty: ... @overload def exclusive_task(*, slot: str | None = None) -> ExclusiveTaskDecorator: ... def exclusive_task( handler: HandlerFunction | HandlerMethod | None = None, *, slot: str | None = None, ) -> ExclusiveTask | ExclusiveTaskProperty | ExclusiveTaskDecorator: """Convert an async function into an ExclusiveTask. This decorator can be applied to either top-level functions or methods in a class. In the latter case, the exclusiveness is bound to each object of the owning class. If the owning class defines __slots__, you must also define an extra slot to store the exclusive state and tell exclusive_task() by providing the `slot` argument. """ if handler is None: def decorator( handler: HandlerFunction | HandlerMethod, ) -> ExclusiveTask | ExclusiveTaskProperty: return _exclusive_task(handler, slot=slot) return decorator return _exclusive_task(handler, slot=slot) ================================================ FILE: edb/common/asyncwatcher.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright EdgeDB Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Optional import asyncio import logging from . import retryloop logger = logging.getLogger("edb.server.asyncwatcher") class AsyncWatcherProtocol(asyncio.Protocol): def __init__( self, watcher: AsyncWatcher, ) -> None: self._transport: Optional[asyncio.Transport] = None self._watcher = watcher def connection_made(self, transport: asyncio.BaseTransport) -> None: self._transport = transport # type: ignore [assignment] self.request() def connection_lost(self, exc: Optional[Exception]) -> None: self._watcher.incr_metrics_counter("watch-disconnect") self._watcher.on_connection_lost() def request(self) -> None: raise NotImplementedError def close(self) -> None: raise NotImplementedError class AsyncWatcher: def __init__(self) -> None: self._waiter: Optional[asyncio.Future] = None self._stop_waiter: Optional[asyncio.Future] = None self._protocol: Optional[AsyncWatcherProtocol] = None self._watching = False self._retry_attempt = 0 self._backoff = retryloop.exp_backoff() async def start_watching(self) -> bool: if not self._watching: self._watching = True try: self._protocol = await self._start_watching() return True except BaseException: self.incr_metrics_counter("watch-start-err") self._watching = False raise return False async def retry_watching(self) -> None: self._retry_attempt += 1 delay = self._backoff(self._retry_attempt) await asyncio.sleep(delay) try: await self.start_watching() except Exception: logger.warning( "%s failed to start watching, will retry.", type(self).__name__, exc_info=True, ) asyncio.create_task(self.retry_watching()) def stop_watching(self) -> None: self._watching = False protocol, self._protocol = self._protocol, None if protocol is not None: self._stop_waiter = asyncio.get_running_loop().create_future() protocol.close() async def wait_stopped_watching(self) -> None: if self._stop_waiter is 
not None: await self._stop_waiter def on_connection_lost(self) -> None: self._protocol = None if self._watching: self.stop_watching() asyncio.create_task(self.retry_watching()) else: waiter, self._stop_waiter = self._stop_waiter, None if waiter is not None: waiter.set_result(None) def on_update(self, data: bytes) -> None: self._retry_attempt = 0 self._on_update(data) def _on_update(self, data: bytes) -> None: raise NotImplementedError async def _start_watching(self) -> AsyncWatcherProtocol: raise NotImplementedError def consume_tokens(self, tokens: int) -> float: # For rate limit - tries to consume the given number of tokens, returns # non-zero values as seconds to wait if unsuccessful return 0 def incr_metrics_counter(self, event: str, value: float = 1.0) -> None: pass ================================================ FILE: edb/common/binwrapper.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2019-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
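#
# Editorial sketch, not part of the upstream file: the BinWrapper class that
# follows frames byte strings with a 32-bit big-endian length prefix, and its
# nullable reader treats a signed length of -1 as "no value". The helpers
# below redo that framing with the standard library only. The names
# `write_len32_prefixed` and `read_nullable_len32_prefixed` are hypothetical
# free functions chosen for illustration; they are not the class's API.

```python
import io
import struct

UI32 = struct.Struct('!L')  # network byte order, unsigned 32-bit
I32 = struct.Struct('!l')   # network byte order, signed 32-bit


def write_len32_prefixed(buf, payload):
    """Write a 4-byte length followed by the payload (illustrative)."""
    buf.write(UI32.pack(len(payload)))
    buf.write(payload)


def read_nullable_len32_prefixed(buf):
    """Read one frame; a length of -1 stands for a missing value."""
    (size,) = I32.unpack(buf.read(4))
    if size == -1:
        return None
    data = buf.read(size)
    if len(data) != size:
        raise BufferError(f'cannot read bytes with len={size}')
    return data


buf = io.BytesIO()
write_len32_prefixed(buf, b'gel')
buf.write(I32.pack(-1))  # a NULL marker instead of a payload
buf.seek(0)
first = read_nullable_len32_prefixed(buf)
second = read_nullable_len32_prefixed(buf)
```

# Reading the length as signed is what makes the -1 marker possible; the
# writer above only ever emits non-negative lengths.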
# from __future__ import annotations import io import struct class BinWrapper: """A utility binary-reader wrapper over any io.BytesIO object.""" i64 = struct.Struct('!q') i32 = struct.Struct('!l') i16 = struct.Struct('!h') i8 = struct.Struct('!b') ui64 = struct.Struct('!Q') ui32 = struct.Struct('!L') ui16 = struct.Struct('!H') ui8 = struct.Struct('!B') def __init__(self, buf: io.BytesIO) -> None: self.buf = buf def write_ui64(self, val: int) -> None: self.buf.write(self.ui64.pack(val)) def write_ui32(self, val: int) -> None: self.buf.write(self.ui32.pack(val)) def write_ui16(self, val: int) -> None: self.buf.write(self.ui16.pack(val)) def write_ui8(self, val: int) -> None: self.buf.write(self.ui8.pack(val)) def write_i64(self, val: int) -> None: self.buf.write(self.i64.pack(val)) def write_i32(self, val: int) -> None: self.buf.write(self.i32.pack(val)) def write_i16(self, val: int) -> None: self.buf.write(self.i16.pack(val)) def write_i8(self, val: int) -> None: self.buf.write(self.i8.pack(val)) def write_len32_prefixed_bytes(self, val: bytes) -> None: self.write_ui32(len(val)) self.buf.write(val) def write_bytes(self, val: bytes) -> None: self.buf.write(val) def read_ui64(self) -> int: data = self.buf.read(8) return self.ui64.unpack(data)[0] def read_ui32(self) -> int: data = self.buf.read(4) return self.ui32.unpack(data)[0] def read_ui16(self) -> int: data = self.buf.read(2) return self.ui16.unpack(data)[0] def read_ui8(self) -> int: data = self.buf.read(1) return self.ui8.unpack(data)[0] def read_i64(self) -> int: data = self.buf.read(8) return self.i64.unpack(data)[0] def read_i32(self) -> int: data = self.buf.read(4) return self.i32.unpack(data)[0] def read_i16(self) -> int: data = self.buf.read(2) return self.i16.unpack(data)[0] def read_i8(self) -> int: data = self.buf.read(1) return self.i8.unpack(data)[0] def read_bytes(self, size: int) -> bytes: data = self.buf.read(size) if len(data) != size: raise BufferError(f'cannot read bytes with len={size}') return 
data def read_len32_prefixed_bytes(self) -> bytes: size = self.read_ui32() return self.read_bytes(size) def read_nullable_len32_prefixed_bytes(self) -> bytes | None: size = self.read_i32() if size == -1: return None else: return self.read_bytes(size) def tell(self) -> int: return self.buf.tell() ================================================ FILE: edb/common/checked.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2011-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, ClassVar, Optional, AbstractSet, Iterable, Iterator, MutableMapping, MutableSequence, MutableSet, Sequence, cast, overload, ) import collections.abc import itertools import types from edb.common import debug from edb.common import parametric __all__ = [ "CheckedList", "CheckedDict", "CheckedSet", "FrozenCheckedList", "FrozenCheckedSet", ] class ParametricContainer: types: ClassVar[Optional[tuple[type, ...]]] = None def __reduce__(self) -> tuple[Any, ...]: assert self.types is not None, f'missing parameters in {type(self)}' cls: type[ParametricContainer] = self.__class__ container = getattr(self, "_container", ()) if cls.__name__.endswith("]"): # Parametrized type. cls = cls.__bases__[0] else: # A subclass of a parametrized type. 
return cls, (container,) args = self.types[0] if len(self.types) == 1 else self.types return cls.__restore__, (args, container) @classmethod def __restore__( cls, params: tuple[type, ...], data: Iterable[Any] ) -> ParametricContainer: return cls[params](data) # type: ignore class AbstractCheckedList[T]: type: type _container: list[T] @classmethod def _check_type(cls, value: Any) -> T: """Ensure `value` is of type T and return it.""" if not isinstance(value, cls.type): raise ValueError( f"{cls!r} accepts only values of type {cls.type!r}, " f"got {type(value)!r}" ) return cast(T, value) def __init__(self, iterable: Iterable[T] = ()) -> None: pass def __lt__(self, other: list[T]) -> bool: return self._container < self._cast(other) def __le__(self, other: list[T]) -> bool: return self._container <= self._cast(other) def __gt__(self, other: list[T]) -> bool: return self._container > self._cast(other) def __ge__(self, other: list[T]) -> bool: return self._container >= self._cast(other) def _cast(self, other: list[T]) -> list[T]: if isinstance(other, (CheckedList, FrozenCheckedList)): return other._container return other def __eq__(self, other: object) -> bool: if isinstance(other, (CheckedList, FrozenCheckedList)): other = other._container return self._container == other def __str__(self) -> str: return repr(self._container) def __repr__(self) -> str: return f"{_type_repr(type(self))}({repr(self._container)})" class FrozenCheckedList[T]( ParametricContainer, parametric.SingleParametricType[T], AbstractCheckedList[T], Sequence[T], ): def __init__(self, iterable: Iterable[T] = ()) -> None: super().__init__() self._container = [self._check_type(element) for element in iterable] self._hash_cache = -1 def __hash__(self) -> int: if self._hash_cache == -1: self._hash_cache = hash(tuple(self._container)) return self._hash_cache # # Sequence # @overload def __getitem__(self, index: int) -> T: ... @overload def __getitem__(self, index: slice) -> FrozenCheckedList[T]: ... 
def __getitem__(self, index: int | slice) -> Any: if isinstance(index, slice): return self.__class__(self._container[index]) return self._container[index] def __len__(self) -> int: return len(self._container) # # List-specific # def __add__(self, other: Iterable[T]) -> FrozenCheckedList[T]: return self.__class__(itertools.chain(self, other)) def __radd__(self, other: Iterable[T]) -> FrozenCheckedList[T]: return self.__class__(itertools.chain(other, self)) def __mul__(self, n: int) -> FrozenCheckedList[T]: return self.__class__(self._container * n) __rmul__ = __mul__ class CheckedList[T]( ParametricContainer, parametric.SingleParametricType[T], AbstractCheckedList[T], MutableSequence[T], ): def __init__(self, iterable: Iterable[T] = ()) -> None: super().__init__() self._container = [self._check_type(element) for element in iterable] # # Sequence # @overload def __getitem__(self, index: int) -> T: ... @overload def __getitem__(self, index: slice) -> CheckedList[T]: ... def __getitem__(self, index: int | slice) -> Any: if isinstance(index, slice): return self.__class__(self._container[index]) return self._container[index] # # MutableSequence # @overload def __setitem__(self, index: int, value: T) -> None: ... @overload def __setitem__(self, index: slice, value: Iterable[T]) -> None: ... def __setitem__(self, index: int | slice, value: Any) -> None: if isinstance(index, int): self._container[index] = self._check_type(value) return _slice = index self._container[_slice] = filter(self._check_type, value) @overload def __delitem__(self, index: int) -> None: ... @overload def __delitem__(self, index: slice) -> None: ... 
def __delitem__(self, index: int | slice) -> None: del self._container[index] def insert(self, index: int, value: T) -> None: self._container.insert(index, self._check_type(value)) def __len__(self) -> int: return len(self._container) # # List-specific # def __add__(self, other: Iterable[T]) -> CheckedList[T]: return self.__class__(itertools.chain(self, other)) def __radd__(self, other: Iterable[T]) -> CheckedList[T]: return self.__class__(itertools.chain(other, self)) def __iadd__(self, other: Iterable[T]) -> CheckedList[T]: self._container.extend(filter(self._check_type, other)) return self def __mul__(self, n: int) -> CheckedList[T]: return self.__class__(self._container * n) __rmul__ = __mul__ def __imul__(self, n: int) -> CheckedList[T]: self._container *= n return self def sort(self, *, key: Any = None, reverse: bool = False) -> None: self._container.sort(key=key, reverse=reverse) class AbstractCheckedSet[T](AbstractSet[T]): type: type _container: AbstractSet[T] def __init__(self, iterable: Iterable[T] = ()) -> None: pass @classmethod def _check_type(cls, value: Any) -> T: """Ensure `value` is of type T and return it.""" if not isinstance(value, cls.type): raise ValueError( f"{cls!r} accepts only values of type {cls.type!r}, " f"got {type(value)!r}" ) return cast(T, value) def _cast(self, other: Any) -> AbstractSet[T]: if isinstance(other, (FrozenCheckedSet, CheckedSet)): return other._container if isinstance(other, collections.abc.Set): return other return set(other) def __eq__(self, other: object) -> bool: if isinstance(other, (CheckedSet, FrozenCheckedSet)): other = other._container return self._container == other def __str__(self) -> str: return repr(self._container) def __repr__(self) -> str: return f"{_type_repr(type(self))}({repr(self._container)})" # # collections.abc.Set aka typing.AbstractSet # def __contains__(self, value: Any) -> bool: return value in self._container def __iter__(self) -> Iterator[T]: return iter(self._container) def __len__(self) 
-> int: return len(self._container) # # Specific to set() and frozenset() # def issubset(self, other: AbstractSet[Any]) -> bool: return self.__le__(other) def issuperset(self, other: AbstractSet[Any]) -> bool: return self.__ge__(other) class FrozenCheckedSet[T]( ParametricContainer, parametric.SingleParametricType[T], AbstractCheckedSet[T], ): def __init__(self, iterable: Iterable[T] = ()) -> None: super().__init__() self._container = {self._check_type(element) for element in iterable} self._hash_cache = -1 def __hash__(self) -> int: if self._hash_cache == -1: self._hash_cache = hash(frozenset(self._container)) return self._hash_cache # # Replaced mixins of collections.abc.Set # # NOTE: The type ignores on function signatures below are because we are # deliberately breaking the Liskov Substitute Principle: we want the type # checker to warn the user if a checked set of a type is __or__'d, or # __and__'d with a set of an incompatible type. If the user wanted this, # they should convert the checked set into a regular set or a differently # typed checked set first. def __and__(self, other: AbstractSet[T]) -> FrozenCheckedSet[T]: other_set = self._cast(other) for elem in other_set: # We need the explicit type check to reject nonsensical # & operations that must always result in an empty new set. self._check_type(elem) return self.__class__(other_set & self._container) __rand__ = __and__ def __or__( # type: ignore self, other: AbstractSet[T] ) -> FrozenCheckedSet[T]: other_set = self._cast(other) return self.__class__(other_set | self._container) __ror__ = __or__ def __sub__(self, other: AbstractSet[T]) -> FrozenCheckedSet[T]: other_set = self._cast(other) for elem in other_set: # We need the explicit type check to reject nonsensical # - operations that always return the original checked set. 
            self._check_type(elem)
        return self.__class__(self._container - other_set)

    def __rsub__(self, other: AbstractSet[T]) -> FrozenCheckedSet[T]:
        other_set = self._cast(other)
        return self.__class__(other_set - self._container)

    def __xor__(  # type: ignore
        self, other: AbstractSet[T]
    ) -> FrozenCheckedSet[T]:
        other_set = self._cast(other)
        return self.__class__(self._container ^ other_set)

    __rxor__ = __xor__

    #
    # Specific to set() and frozenset()
    #

    union = __or__
    intersection = __and__
    difference = __sub__
    symmetric_difference = __xor__


class CheckedSet[T](
    ParametricContainer,
    parametric.SingleParametricType[T],
    AbstractCheckedSet[T],
    MutableSet[T],
):

    _container: set[T]

    def __init__(self, iterable: Iterable[T] = ()) -> None:
        super().__init__()
        self._container = {self._check_type(element) for element in iterable}

    #
    # Replaced mixins of collections.abc.Set
    #

    # NOTE: The type ignores on function signatures below are because we are
    # deliberately breaking the Liskov Substitution Principle: we want the
    # type checker to warn the user if a checked set of a type is __or__'d, or
    # __and__'d with a set of an incompatible type. If the user wanted this,
    # they should convert the checked set into a regular set or a differently
    # typed checked set first.

    def __and__(self, other: AbstractSet[T]) -> CheckedSet[T]:
        other_set = self._cast(other)
        for elem in other_set:
            # We need the explicit type check to reject nonsensical
            # & operations that must always result in an empty new set.
            self._check_type(elem)
        return self.__class__(other_set & self._container)

    __rand__ = __and__

    def __or__(self, other: AbstractSet[T]) -> CheckedSet[T]:  # type: ignore
        other_set = self._cast(other)
        return self.__class__(other_set | self._container)

    __ror__ = __or__

    def __sub__(self, other: AbstractSet[T]) -> CheckedSet[T]:
        other_set = self._cast(other)
        for elem in other_set:
            # We need the explicit type check to reject nonsensical
            # - operations that always return the original checked set.
self._check_type(elem) return self.__class__(self._container - other_set) def __rsub__(self, other: AbstractSet[T]) -> CheckedSet[T]: other_set = self._cast(other) return self.__class__(other_set - self._container) def __xor__(self, other: AbstractSet[T]) -> CheckedSet[T]: # type: ignore other_set = self._cast(other) return self.__class__(self._container ^ other_set) __rxor__ = __xor__ # # MutableSet # def add(self, value: T) -> None: self._container.add(self._check_type(value)) def discard(self, value: T) -> None: self._container.discard(self._check_type(value)) # # Replaced mixins of collections.abc.MutableSet # def __ior__(self, other: AbstractSet[T]) -> CheckedSet[T]: # type: ignore self._container |= set(filter(self._check_type, other)) return self def __iand__(self, other: AbstractSet[T]) -> CheckedSet[T]: # We do the type check here to reject nonsensical # & operations that always clear the checked set. self._container &= set(filter(self._check_type, other)) return self def __ixor__(self, other: AbstractSet[T]) -> CheckedSet[T]: # type: ignore self._container ^= set(filter(self._check_type, other)) return self def __isub__(self, other: AbstractSet[T]) -> CheckedSet[T]: # We do the type check here to reject nonsensical # - operations that could never affect the checked set. 
        self._container -= set(filter(self._check_type, other))
        return self

    #
    # Specific to set() and frozenset()
    #

    union = __or__
    intersection = __and__
    difference = __sub__
    symmetric_difference = __xor__

    #
    # Specific to set()
    #

    update = __ior__
    intersection_update = __iand__
    difference_update = __isub__
    symmetric_difference_update = __ixor__


def _type_repr(obj: Any) -> str:
    if isinstance(obj, type):
        if obj.__module__ == "builtins":
            return obj.__qualname__
        return f"{obj.__module__}.{obj.__qualname__}"
    if isinstance(obj, types.FunctionType):
        return obj.__name__
    return repr(obj)


class AbstractCheckedDict[K, V]:
    keytype: type
    valuetype: type
    _container: dict[K, V]

    @classmethod
    def _check_key_type(cls, key: Any) -> K:
        """Ensure `key` is of type K and return it."""
        if not isinstance(key, cls.keytype):
            raise KeyError(
                f"{cls!r} accepts only keys of type {cls.keytype!r}, "
                f"got {type(key)!r}"
            )
        return cast(K, key)

    @classmethod
    def _check_value_type(cls, value: Any) -> V:
        """Ensure `value` is of type V and return it."""
        if not isinstance(value, cls.valuetype):
            # Both fragments must be f-strings, otherwise the placeholders
            # in the second one are emitted literally.
            raise ValueError(
                f"{cls!r} accepts only values of type "
                f"{cls.valuetype!r}, got {type(value)!r}"
            )
        return cast(V, value)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, CheckedDict):
            other = other._container
        return self._container == other

    def __str__(self) -> str:
        return repr(self._container)

    def __repr__(self) -> str:
        return f"{_type_repr(type(self))}({repr(self._container)})"


class CheckedDict[K, V](
    ParametricContainer,
    parametric.KeyValueParametricType[K, V],
    AbstractCheckedDict[K, V],
    MutableMapping[K, V],
):
    def __init__(self, *args: Any, **kwargs: V) -> None:
        super().__init__()
        self._container = {}
        if len(args) == 1:
            self.update(args[0])
        if len(args) > 1:
            raise ValueError(
                f"{type(self)!r} expected at most 1 argument, got {len(args)}"
            )
        if len(kwargs):
            # Mypy is right below that the type of kwargs is Dict[str, V]
            # but we are deliberately letting this through for it to blow up
            # on runtime type
checking if K is not a string. self.update(kwargs) # type: ignore # # collections.abc.Mapping # def __getitem__(self, key: K) -> V: return self._container[key] def __iter__(self) -> Iterator[K]: return iter(self._container) def __len__(self) -> int: return len(self._container) # # collections.abc.MutableMapping # def __setitem__(self, key: K, value: V) -> None: self._check_key_type(key) self._container[key] = self._check_value_type(value) def __delitem__(self, key: K) -> None: del self._container[key] # # Dict-specific # @classmethod def fromkeys( cls, iterable: Iterable[K], value: Optional[V] = None ) -> CheckedDict[K, V]: new: CheckedDict[K, V] = cls() for key in iterable: new[cls._check_key_type(key)] = cls._check_value_type(value) return new def _identity[T](cls: type, value: T) -> T: return value _type_checking = { CheckedList: ["_check_type"], CheckedDict: ["_check_key_type", "_check_value_type"], CheckedSet: ["_check_type"], FrozenCheckedList: ["_check_type"], FrozenCheckedSet: ["_check_type"], } def disable_typechecks() -> None: for type_, methods in _type_checking.items(): for method in methods: setattr(type_, method, _identity) def enable_typechecks() -> None: for type_, methods in _type_checking.items(): for method in methods: try: delattr(type_, method) except AttributeError: continue if not debug.flags.typecheck: disable_typechecks() ================================================ FILE: edb/common/colorsys.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2011-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """An extension to standard library module :mod:`colorsys`. Contains additional functions, with the most notable - :func:`rgb_distance`. """ from __future__ import annotations from math import sqrt as _sqrt from colorsys import ( rgb_to_yiq, yiq_to_rgb, rgb_to_hls, hls_to_rgb, rgb_to_hsv, hsv_to_rgb ) __all__ = 'rgb_to_yiq', 'yiq_to_rgb', 'rgb_to_hls', 'hls_to_rgb', \ 'rgb_to_hsv', 'hsv_to_rgb', 'rgb_to_xyz', 'xyz_to_lab', \ 'rgb_distance', 'Color' class Color: colors = { 'aliceblue': '#f0f8ff', 'antiquewhite': '#faebd7', 'aqua': '#00ffff', 'aquamarine': '#7fffd4', 'azure': '#f0ffff', 'beige': '#f5f5dc', 'bisque': '#ffe4c4', 'black': '#000000', 'blanchedalmond': '#ffebcd', 'blue': '#0000ff', 'blueviolet': '#8a2be2', 'brown': '#a52a2a', 'burlywood': '#deb887', 'cadetblue': '#5f9ea0', 'chartreuse': '#7fff00', 'chocolate': '#d2691e', 'coral': '#ff7f50', 'cornflowerblue': '#6495ed', 'cornsilk': '#fff8dc', 'crimson': '#dc143c', 'cyan': '#00ffff', 'darkblue': '#00008b', 'darkcyan': '#008b8b', 'darkgoldenrod': '#b8860b', 'darkgray': '#a9a9a9', 'darkgreen': '#006400', 'darkkhaki': '#bdb76b', 'darkmagenta': '#8b008b', 'darkolivegreen': '#556b2f', 'darkorange': '#ff8c00', 'darkorchid': '#9932cc', 'darkred': '#8b0000', 'darksalmon': '#e9967a', 'darkseagreen': '#8fbc8f', 'darkslateblue': '#483d8b', 'darkslategray': '#2f4f4f', 'darkturquoise': '#00ced1', 'darkviolet': '#9400d3', 'deeppink': '#ff1493', 'deepskyblue': '#00bfff', 'dimgray': '#696969', 'dodgerblue': '#1e90ff', 'firebrick': '#b22222', 'floralwhite': '#fffaf0', 'forestgreen': '#228b22', 'fuchsia': '#ff00ff', 
'gainsboro': '#dcdcdc', 'ghostwhite': '#f8f8ff', 'gold': '#ffd700', 'goldenrod': '#daa520', 'gray': '#808080', 'green': '#008000', 'greenyellow': '#adff2f', 'honeydew': '#f0fff0', 'hotpink': '#ff69b4', 'indianred': '#cd5c5c', 'indigo': '#4b0082', 'ivory': '#fffff0', 'khaki': '#f0e68c', 'lavender': '#e6e6fa', 'lavenderblush': '#fff0f5', 'lawngreen': '#7cfc00', 'lemonchiffon': '#fffacd', 'lightblue': '#add8e6', 'lightcoral': '#f08080', 'lightcyan': '#e0ffff', 'lightgoldenrodyellow': '#fafad2', 'lightgreen': '#90ee90', 'lightgrey': '#d3d3d3', 'lightpink': '#ffb6c1', 'lightsalmon': '#ffa07a', 'lightseagreen': '#20b2aa', 'lightskyblue': '#87cefa', 'lightslategray': '#778899', 'lightsteelblue': '#b0c4de', 'lightyellow': '#ffffe0', 'lime': '#00ff00', 'limegreen': '#32cd32', 'linen': '#faf0e6', 'magenta': '#ff00ff', 'maroon': '#800000', 'mediumaquamarine': '#66cdaa', 'mediumblue': '#0000cd', 'mediumorchid': '#ba55d3', 'mediumpurple': '#9370db', 'mediumseagreen': '#3cb371', 'mediumslateblue': '#7b68ee', 'mediumspringgreen': '#00fa9a', 'mediumturquoise': '#48d1cc', 'mediumvioletred': '#c71585', 'midnightblue': '#191970', 'mintcream': '#f5fffa', 'mistyrose': '#ffe4e1', 'moccasin': '#ffe4b5', 'navajowhite': '#ffdead', 'navy': '#000080', 'oldlace': '#fdf5e6', 'olive': '#808000', 'olivedrab': '#6b8e23', 'orange': '#ffa500', 'orangered': '#ff4500', 'orchid': '#da70d6', 'palegoldenrod': '#eee8aa', 'palegreen': '#98fb98', 'paleturquoise': '#afeeee', 'palevioletred': '#db7093', 'papayawhip': '#ffefd5', 'peachpuff': '#ffdab9', 'peru': '#cd853f', 'pink': '#ffc0cb', 'plum': '#dda0dd', 'powderblue': '#b0e0e6', 'purple': '#800080', 'red': '#ff0000', 'rosybrown': '#bc8f8f', 'royalblue': '#4169e1', 'saddlebrown': '#8b4513', 'salmon': '#fa8072', 'sandybrown': '#f4a460', 'seagreen': '#2e8b57', 'seashell': '#fff5ee', 'sienna': '#a0522d', 'silver': '#c0c0c0', 'skyblue': '#87ceeb', 'slateblue': '#6a5acd', 'slategray': '#708090', 'snow': '#fffafa', 'springgreen': '#00ff7f', 'steelblue': 
'#4682b4', 'tan': '#d2b48c', 'teal': '#008080', 'thistle': '#d8bfd8', 'tomato': '#ff6347', 'turquoise': '#40e0d0', 'violet': '#ee82ee', 'wheat': '#f5deb3', 'white': '#ffffff', 'whitesmoke': '#f5f5f5', 'yellow': '#ffff00', 'yellowgreen': '#9acd32' } def __init__(self, r, g, b, a=1.0): r = int(r) g = int(g) b = int(b) a = float(a) if r > 255 or r < 0 or g < 0 or g > 255 or b < 0 or b > 255: raise ValueError('color component should belong to [0, 255]') if a < 0 or a > 1: raise ValueError('alpha component should belong to [0, 1]') self.r, self.g, self.b, self.a = r, g, b, a @classmethod def from_color(cls, color): return cls(color.r, color.g, color.b, color.a) @classmethod def from_string(cls, value, alpha=1.0): if not value.startswith('#'): if value == 'transparent': return cls(0, 0, 0, 0) else: try: value = cls.colors[str(value)] except KeyError: raise ValueError('Unknown color name') value = value[1:] try: if len(value) == 3: r, g, b = [int(x * 2, 16) for x in value] elif len(value) == 6: r, g, b = [int(value[i:i + 2], 16) for i in range(0, 6, 2)] else: raise ValueError except ValueError: raise ValueError('Invalid color value') return cls(r, g, b, a=alpha) @classmethod def from_hls(cls, h, l, s, alpha=1.0): # NoQA: E741 return cls(*(int(c * 255) for c in hls_to_rgb(h, l, s)), a=alpha) def rgb_channels(self, *, as_floats=False): if as_floats: return (self.r / 255.0, self.g / 255.0, self.b / 255.0) else: return (self.r, self.g, self.b) def rgba_channels(self, *, as_floats=False): if as_floats: return (self.r / 255.0, self.g / 255.0, self.b / 255.0, self.a) else: return (self.r, self.g, self.b, self.a) def hls_channels(self): return rgb_to_hls(*(c / 255 for c in self.rgb_channels())) # Relative to RGB max white XYZ_MAX_X = 95.047 XYZ_MAX_Y = 100.0 XYZ_MAX_Z = 108.883 def rgb_to_xyz(r, g, b): """Converts RGB color to XYZ :param float r: Red value in ``0..1`` range :param float g: Green value in ``0..1`` range :param float b: Blue value in ``0..1`` range :returns: ``(x, 
y, z)``, all values normalized to the ``(0..1, 0..1, 0..1)`` range """ # Formulae from http://www.easyrgb.com/index.php?X=MATH if r > 0.04045: r = ((r + 0.055) / 1.055) ** 2.4 else: r /= 12.92 if g > 0.04045: g = ((g + 0.055) / 1.055) ** 2.4 else: g /= 12.92 if b > 0.04045: b = ((b + 0.055) / 1.055) ** 2.4 else: b /= 12.92 r *= 100.0 g *= 100.0 b *= 100.0 x = min((r * 0.4124 + g * 0.3576 + b * 0.1805) / XYZ_MAX_X, 1.0) y = min((r * 0.2126 + g * 0.7152 + b * 0.0722) / XYZ_MAX_Y, 1.0) z = min((r * 0.0193 + g * 0.1192 + b * 0.9505) / XYZ_MAX_Z, 1.0) return (x, y, z) _1_3 = 1.0 / 3.0 _16_116 = 16.0 / 116.0 def xyz_to_lab(x, y, z): """Converts XYZ color to LAB :param float x: Value from ``0..1`` :param float y: Value from ``0..1`` :param float z: Value from ``0..1`` :returns: ``(L, a, b)``, values in range ``(0..100, -127..128, -127..128)`` """ # Formulae from http://www.easyrgb.com/index.php?X=MATH if x > 0.008856: x **= _1_3 else: x = (7.787 * x) + _16_116 if y > 0.008856: y **= _1_3 else: y = (7.787 * y) + _16_116 if z > 0.008856: z **= _1_3 else: z = (7.787 * z) + _16_116 lum = 116.0 * y - 16.0 a = 500 * (x - y) b = 200 * (y - z) return (lum, a, b) def rgb_distance(r1, g1, b1, r2, g2, b2): """Calculates numerical distance between two colors in RGB color space. The distance is calculated by CIE94 formula. :params: Two colors with ``r, g, b`` values in ``0..1`` range :returns: A number in ``0..100`` range. The lesser - the closer colors are. 
""" # Formulae from wikipedia article re CIE94 L1, A1, B1 = xyz_to_lab(*rgb_to_xyz(r1, b1, g1)) L2, A2, B2 = xyz_to_lab(*rgb_to_xyz(r2, b2, g2)) dL = L1 - L2 C1 = _sqrt(A1 * A1 + B1 * B1) C2 = _sqrt(A2 * A2 + B2 * B2) dCab = C1 - C2 dA = A1 - A2 dB = B1 - B2 dEab = _sqrt(dL ** 2 + dA ** 2 + dB ** 2) dHab = _sqrt(max(dEab ** 2 - dL ** 2 - dCab ** 2, 0.0)) dE = _sqrt((dL ** 2) + ((dCab / (1 + 0.045 * C1)) ** 2) + ( dHab / (1 + 0.015 * C1)) ** 2) return dE ================================================ FILE: edb/common/compiler.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import ( Any, Optional, ContextManager, Self, ) import collections import re class ContextLevel: _stack: CompilerContext[Self] def __init__(self, prevlevel: Optional[Self], mode: Any) -> None: pass def on_pop( self: Self, prevlevel: Optional[Self], ) -> None: pass def new( self: Self, mode: Any=None, ) -> CompilerContextManager[Self]: return self._stack.new(mode, self) def reenter( self: Self, ) -> CompilerReentryContextManager[Self]: return CompilerReentryContextManager(self._stack, self) class CompilerContextManager[ContextLevel_T: ContextLevel]( ContextManager[ContextLevel_T] ): def __init__( self, context: CompilerContext[ContextLevel_T], mode: Any, prevlevel: Optional[ContextLevel_T], ) -> None: self.context = context self.mode = mode self.prevlevel = prevlevel def __enter__(self) -> ContextLevel_T: return self.context.push(self.mode, self.prevlevel) def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: self.context.pop() class CompilerReentryContextManager[ContextLevel_T: ContextLevel]( ContextManager[ContextLevel_T] ): def __init__( self, context: CompilerContext[ContextLevel_T], level: ContextLevel_T, ) -> None: self.context = context self.level = level def __enter__(self) -> ContextLevel_T: return self.context._push(None, initial=self.level) def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: self.context.pop() class CompilerContext[ContextLevel_T: ContextLevel]: stack: list[ContextLevel_T] ContextLevelClass: type[ContextLevel_T] default_mode: Any def __init__(self, initial: ContextLevel_T) -> None: self.stack = [] self._push(None, initial=initial) def push( self, mode: Any, prevlevel: Optional[ContextLevel_T] = None, ) -> ContextLevel_T: return self._push(mode, prevlevel) def _push( self, mode: Any, prevlevel: Optional[ContextLevel_T] = None, *, initial: Optional[ContextLevel_T] = None, ) -> ContextLevel_T: if initial is not None: level = initial else: if 
prevlevel is None: prevlevel = self.current elif prevlevel is not self.current: # In the past, we always used self.current as the # previous level and simply ignored the prevlevel # parameter. Actually using prevlevel makes more sense # and has fewer gotchas, but enough code had grown to # depend on the old behavior that changing it required # asserting that they were the same. We can consider # dropping the assertion if it proves tedious. raise AssertionError( 'Calling new() on a context other than the current one') level = self.ContextLevelClass(prevlevel, mode) level._stack = self self.stack.append(level) return level def pop(self) -> None: level = self.stack.pop() level.on_pop(self.stack[-1] if self.stack else None) def new( self, mode: Any = None, prevlevel: Optional[ContextLevel_T] = None, ) -> CompilerContextManager[ContextLevel_T]: if mode is None: mode = self.default_mode return CompilerContextManager(self, mode, prevlevel) @property def current(self) -> ContextLevel_T: return self.stack[-1] class SimpleCounter: counts: collections.defaultdict[str, int] def __init__(self) -> None: self.counts = collections.defaultdict(int) def nextval(self, name: str = 'default') -> int: self.counts[name] += 1 return self.counts[name] class AliasGenerator(SimpleCounter): def get(self, hint: str = '') -> str: if not hint: hint = 'v' m = re.search(r'~\d+$', hint) if m: hint = hint[:m.start()] idx = self.nextval(hint) alias = f'{hint}~{idx}' return alias ================================================ FILE: edb/common/debug.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Debug flags and output facilities. Example code using this module: if debug.flags.some_sql_flag: debug.header('SQL') debug.dump(sql_ast) Use `debug.header()`, `debug.print()`, `debug.dump()` and `debug.dump_code()` functions as opposed to using 'print' built-in directly. This gives us flexibility to redirect debug output if needed. """ from __future__ import annotations import builtins import contextlib import os import sys import time import warnings # Don't import anything from "edb.*" as it will wreck coverage. __all__ = () # Don't. class FlagsMeta(type): def __new__(mcls, name, bases, dct): flags = {} for flagname, flag in dct.items(): if not isinstance(flag, Flag): continue flag.name = flagname flags[flagname] = flag dct[flagname] = flag.default dct['_items'] = flags return super().__new__(mcls, name, bases, dct) def __iter__(cls): return iter(cls._items.values()) class Flag: def __init__(self, *, doc: str, default: bool=False): self.name = None self.doc = doc self.default = default class flags(metaclass=FlagsMeta): pgsql_parser = Flag( doc="Debug SQL parser.") bootstrap = Flag( doc="Debug server catalog bootstrap.") bootstrap_cache_yolo = Flag( doc="Disable bootstrap cache consistency check.") edgeql_parser = Flag( doc="Debug EdgeQL parser (rebuild grammar verbosely).") edgeql_compile = Flag( doc="Dump EdgeQL/IR/SQL ASTs.") edgeql_compile_edgeql_text = Flag( doc="Dump EdgeQL Text (subset of `edgeql_compile').") edgeql_compile_edgeql_ast = Flag( doc="Dump EdgeQL AST (subset of `edgeql_compile').") edgeql_compile_scope = Flag( doc="Dump EdgeQL scope tree 
(subset of `edgeql_compile').") edgeql_compile_ir = Flag( doc="Dump EdgeQL IR (subset of `edgeql_compile').") edgeql_compile_sql_ast = Flag( doc="Dump generated SQL AST (subset of `edgeql_compile').") edgeql_compile_sql_ast_meta = Flag( doc="Whether to include the metadata fields when dumping the SQL AST.") edgeql_compile_sql_text = Flag( doc="Dump generated SQL text (subset of `edgeql_compile').") edgeql_compile_sql_reordered_text = Flag( doc="Dump generated SQL-like text that might better reflect scoping.") edgeql_explain = Flag( doc="Dump extra debug info when doing EXPLAIN") edgeql_disable_normalization = Flag( doc="Disable EdgeQL normalization (constant extraction etc)") graphql_compile = Flag( doc="Debug GraphQL compiler.") sdl_loading = Flag( doc="Print applied DDL when loading SDL.") delta_plan = Flag( doc="Print expanded delta command tree prior to processing.") delta_pgsql_plan = Flag( doc="Print delta command tree annotated with DB ops.") delta_execute = Flag( doc="Output SQL commands as executed during migration.") delta_execute_ddl = Flag( doc="Output just the DDL commands as executed during migration.") delta_validate_reflection = Flag( doc="Whether to do expensive validation of reflection correctness.") server = Flag( doc="Print server errors.") server_proto = Flag( doc="Print server protocol querying messages.") server_clobber_pg_conns = Flag( doc="Discard Postgres connections when releasing them to the pool.") edgeql_text_in_sql = Flag( doc="Include the EdgeQL query text in the SQL sent to Postgres.") print_locals = Flag( doc="Include values of local variables in tracebacks.") disable_qcache = Flag( doc="Disable server query cache. 
Parse/Execute will always recompile.") typecheck = Flag( doc="Perform runtime type checking.") pgserver = Flag( doc="Show PostgreSQL server logs and log all statements.") log_metrics = Flag( doc="Log verbose statistics on connections and compiler behavior.") disable_docs_edgeql_validation = Flag( doc="Disable validation of edgeql in docs (for site build)") pydebug_listen = Flag( doc="Enable listening for Debug Adapter Protocol connections. " "Requires pydebug to be installed." ) sql_input = Flag( doc="Enable logging of SQL incoming requests (pg compiler input)." ) sql_output = Flag( doc="Enable logging of SQL requests, compiled to the internal SQL " "(pg compiler output)." ) sql_text_in_sql = Flag( doc="Include the original SQL query text in the SQL sent to Postgres." ) zombodb = Flag(doc="Enables zombodb and disables Postgres FTS") ast_span = Flag(doc="Enables spans in markup of ASTs") @contextlib.contextmanager def timeit(title='block'): st = time.monotonic() try: yield finally: print(f'{title} took {time.monotonic() - st:.4f}s') def header(*args): print('=' * 80) print(*args) print('=' * 80) def dump(*args, **kwargs): from . import markup as _markup _markup.dump(*args, **kwargs) def dumps(*args, **kwargs): from . import markup as _markup return _markup.dumps(*args, **kwargs) def dump_code(*args, **kwargs): from . import markup as _markup _markup.dump_code(*args, **kwargs) def dump_sql(sql, *args, **kwargs): import edb.pgsql.codegen dump_code( edb.pgsql.codegen.generate_source(sql, *args, **kwargs), lexer='SQL' ) def dump_edgeql(eql, *args, **kwargs): import edb.edgeql.codegen dump_code(edb.edgeql.codegen.generate_source(eql, *args, **kwargs)) def set_trace(**kwargs): """Debugger hook that works inside worker processes. Set PYTHONBREAKPOINT=edb.common.debug.set_trace, and this will be triggered by `breakpoint()`. 
Unfortunately readline doesn't work when not using stdin itself, so try running the server wrapped with `rlwrap`. """ from pdb import Pdb new_stdin = open("/dev/tty", "r") Pdb(stdin=new_stdin, stdout=sys.stdout).set_trace( sys._getframe().f_back, **kwargs) def print(*args): builtins.print(*args) def init_debug_flags(): prefix = 'EDGEDB_DEBUG_' for env_name, env_val in os.environ.items(): if not env_name.startswith(prefix): continue name = env_name[len(prefix):].lower() if not hasattr(flags, name): warnings.warn(f'Unknown debug flag: {env_name!r}', stacklevel=2) continue value = env_val.strip() not in {'', '0'} setattr(flags, name, value) init_debug_flags() ================================================ FILE: edb/common/devmode.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2011-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import Optional, NamedTuple import contextlib import json import logging import os import pathlib logger = logging.getLogger('edb.devmode.cache') class CoverageConfig(NamedTuple): config: str datadir: str paths: list[str] def to_json(self) -> str: return json.dumps(self._asdict()) @classmethod def from_json(cls, js: str): dct = json.loads(js) return cls(**dct) def save_to_environ(self): os.environ.update({ 'EDGEDB_TEST_COVERAGE': self.to_json() }) @classmethod def from_environ(cls) -> Optional['CoverageConfig']: config = os.environ.get('EDGEDB_TEST_COVERAGE') if config is None: return None else: return cls.from_json(config) @classmethod def new_custom_coverage_object(cls, **conf): import coverage cov = coverage.Coverage(**conf) cov._warn_no_data = False cov._warn_unimported_source = False cov._warn_preimported_source = False return cov def new_coverage_object(self): return self.new_custom_coverage_object( config_file=self.config, source=self.paths, data_file=os.path.join(self.datadir, f'cov-{os.getpid()}'), ) @classmethod def start_coverage_if_requested(cls): cov_config = cls.from_environ() if cov_config is not None: cov = cov_config.new_coverage_object() cov.start() return cov else: return None @classmethod @contextlib.contextmanager def enable_coverage_if_requested(cls): cov_config = cls.from_environ() if cov_config is None: yield else: cov = cov_config.new_coverage_object() cov.start() try: yield finally: cov.stop() cov.save() def enable_dev_mode(enabled: bool = True): os.environ['__EDGEDB_DEVMODE'] = '1' if enabled else '' def is_in_dev_mode() -> bool: devmode = os.environ.get('__EDGEDB_DEVMODE', '0') return devmode.lower() not in ('0', '', 'false') def get_dev_mode_cache_dir() -> pathlib.Path: if is_in_dev_mode(): root = pathlib.Path(__file__).parent.parent.parent cache_dir = (root / 'build' / 'cache') cache_dir.mkdir(exist_ok=True) return cache_dir else: raise RuntimeError('server is not running in dev mode') 
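# Illustrative sketch, not part of devmode.py: how an environment flag parsed
# the way is_in_dev_mode() parses '__EDGEDB_DEVMODE' treats common values.
# The helper name `env_flag_enabled` and the variable '__EXAMPLE_DEVMODE' are
# hypothetical and exist only for this example.

```python
import os


def env_flag_enabled(name: str) -> bool:
    """Hypothetical helper using the same parsing rule as is_in_dev_mode()."""
    # An unset variable falls back to '0', which counts as disabled.
    value = os.environ.get(name, '0')
    # '0', the empty string and 'false' (any letter case) mean "off";
    # every other value means "on".
    return value.lower() not in ('0', '', 'false')


# Try a few representative values with a throwaway variable name.
results = {}
for raw in ('1', '', '0', 'False', 'yes'):
    os.environ['__EXAMPLE_DEVMODE'] = raw
    results[raw] = env_flag_enabled('__EXAMPLE_DEVMODE')

# Clean up, then check the "unset" case as well.
del os.environ['__EXAMPLE_DEVMODE']
results['<unset>'] = env_flag_enabled('__EXAMPLE_DEVMODE')
```

# Under this rule enable_dev_mode(False), which stores an empty string,
# reads back as disabled, while any value outside the "off" set enables it.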
def get_dev_mode_data_dir() -> pathlib.Path: data_dir_env = os.environ.get("EDGEDB_SERVER_DEV_DIR") if data_dir_env: data_dir = pathlib.Path(data_dir_env) else: root = pathlib.Path(__file__).parent.parent.parent data_dir = root / "tmp" / "devdatadir" return data_dir ================================================ FILE: edb/common/english.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2009-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations def add_a(word): article = 'an' if word[0] in 'aeiou' else 'a' return f'{article} {word}' ================================================ FILE: edb/common/enum.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations import enum import functools class StrEnum(str, enum.Enum): """A version of string enum with reasonable __str__.""" def __str__(self): return self._value_ @functools.total_ordering class OrderedEnumMixin(): @classmethod @functools.lru_cache(None) def _index_of(cls, value): return list(cls).index(value) def __lt__(self, other): if self.__class__ is other.__class__: return self._index_of(self) < self._index_of(other) return NotImplemented ================================================ FILE: edb/common/exceptions.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2010-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations import sys def _get_contexts(ex, *, auto_init=False): try: return ex.__sx_error_contexts__ except AttributeError: if auto_init: cs = ex.__sx_error_contexts__ = {} return cs else: return {} def add_context(ex, context): assert isinstance(context, ExceptionContext) contexts = _get_contexts(ex, auto_init=True) cls = context.__class__ if cls in contexts: raise ValueError( 'context {}.{} is already present in ' 'exception'.format(cls.__module__, cls.__name__)) contexts[cls] = context def replace_context(ex, context): contexts = _get_contexts(ex, auto_init=True) contexts[context.__class__] = context def get_context(ex, context_class): contexts = _get_contexts(ex) try: return contexts[context_class] except KeyError as ex: raise LookupError( '{} context class is not ' 'found'.format(context_class)) from ex def iter_contexts(ex, ctx_class=None): contexts = _get_contexts(ex) if ctx_class is None: return iter(contexts.values()) else: assert issubclass(ctx_class, ExceptionContext) return ( context for context in contexts.values() if isinstance(context, ctx_class)) class ExceptionContext: title = 'Exception Context' class DefaultExceptionContext(ExceptionContext): title = 'Details' def __init__(self, hint=None, details=None): super().__init__() self.details = details self.hint = hint _old_excepthook = sys.excepthook def _is_internal_error(exc): if isinstance(exc, ExceptionGroup): return any(_is_internal_error(e) for e in exc.exceptions) # This is pretty cheesy but avoids needing to import our edgedb # exceptions or do anything elaborate with contexts. return type(exc).__name__ == 'InternalServerError' def excepthook(exctype, exc, tb): try: from edb.common import markup markup.dump(exc, file=sys.stderr) if _is_internal_error(exc): # TODO(rename): change URL once we can print( f'This is most likely a bug in Gel. 
' f'Please consider opening an issue ticket ' f'at https://github.com/edgedb/edgedb/issues/new' f'?template=bug_report.md' ) except Exception as ex: print('!!! exception in edb.excepthook !!!', file=sys.stderr) # Attach the original exception as a context to top of the new chain, # but only if it's not already there. Take some care to avoid looping # forever. visited = set() parent = ex while parent.__cause__ or ( not parent.__suppress_context__ and parent.__context__): if (parent in visited or parent.__context__ is exc or parent.__cause__ is exc): break visited.add(parent) parent = parent.__cause__ or parent.__context__ parent.__context__ = exc parent.__cause__ = None _old_excepthook(type(ex), ex, ex.__traceback__) def install_excepthook(): sys.excepthook = excepthook def uninstall_excepthook(): sys.excepthook = _old_excepthook ================================================ FILE: edb/common/levenshtein.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2011-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations def distance(s: str, t: str) -> int: """Calculates Levenshtein distance between s and t.""" m, n = len(s), len(t) if m > n: s, t = t, s m, n = n, m ri = list(range(m + 1)) for i in range(1, n + 1): ri_1, ri = ri, [i] + [0] * m for j in range(1, m + 1): ri[j] = min(ri_1[j] + 1, ri[j - 1] + 1, ri_1[j - 1] + int(s[j - 1] != t[i - 1])) return ri[m] ================================================ FILE: edb/common/log.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2024-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations # DON'T IMPORT asyncio or any package that creates their own logger here, # or the "tenant" value cannot be injected. import contextvars import logging current_tenant = contextvars.ContextVar("current_tenant", default="-") class EdgeDBLogger(logging.Logger): def makeRecord( self, name, level, fn, lno, msg, args, exc_info, func=None, extra=None, sinfo=None, ): # Unlike the standard Logger class, we allow overwriting # all attributes of the log record with stuff from *extra*. 
factory = logging.getLogRecordFactory() rv = factory(name, level, fn, lno, msg, args, exc_info, func, sinfo) rv.__dict__["tenant"] = current_tenant.get() if extra is not None: rv.__dict__.update(extra) return rv def early_setup(): logging.setLoggerClass(EdgeDBLogger) ================================================ FILE: edb/common/lru.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import collections.abc import functools from typing import Callable, Optional from types import MethodType class LRUMapping(collections.abc.MutableMapping): # We use an OrderedDict for LRU implementation. Operations: # # * We use a simple `__setitem__` to push a new entry: # `entries[key] = new_entry` # That will push `new_entry` to the *end* of the entries dict. # # * When we have a cache hit, we call # `entries.move_to_end(key, last=True)` # to move the entry to the *end* of the entries dict. # # * When we need to remove entries to maintain `max_size`, we call # `entries.popitem(last=False)` # to remove an entry from the *beginning* of the entries dict. # # So new entries and hits are always promoted to the end of the # entries dict, whereas the unused one will group in the # beginning of it. 
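# Illustrative sketch, not part of lru.py: the three OrderedDict operations
# described in the comment above, shown on a standalone dict. The names
# `entries`, `max_size`, `put` and `get` are hypothetical and only serve
# this example.

```python
import collections

entries = collections.OrderedDict()
max_size = 2


def put(key, value):
    # Insert (or overwrite) and promote the entry to the *end*.
    entries[key] = value
    entries.move_to_end(key, last=True)
    # Evict from the *beginning* once the size limit is exceeded.
    if len(entries) > max_size:
        entries.popitem(last=False)


def get(key):
    # A cache hit promotes the entry to the *end* as well.
    value = entries[key]
    entries.move_to_end(key, last=True)
    return value


put('a', 1)
put('b', 2)
get('a')      # 'a' becomes the most recently used entry
put('c', 3)   # exceeds max_size, so the least recently used key 'b' is evicted
remaining = list(entries)
```

# After these calls only 'a' and 'c' remain, with 'c' the most recent.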
def __init__(self, *, maxsize): if maxsize <= 0: raise ValueError( f'maxsize is expected to be greater than 0, got {maxsize}' ) self._dict = collections.OrderedDict() self._maxsize = maxsize def __getitem__(self, key): o = self._dict[key] self._dict.move_to_end(key, last=True) return o def __setitem__(self, key, o): if key in self._dict: self._dict[key] = o self._dict.move_to_end(key, last=True) else: self._dict[key] = o if len(self._dict) > self._maxsize: self._dict.popitem(last=False) def __delitem__(self, key): del self._dict[key] def __contains__(self, key): return key in self._dict def __len__(self): return len(self._dict) def __iter__(self): return iter(self._dict) class _NoPickle: def __init__(self, obj): self.obj = obj def __bool__(self): return bool(self.obj) def __getstate__(self): return () def __setstate__(self, _d): self.obj = None def lru_method_cache[Tf: Callable]( maxsize: int | None = 128, ) -> Callable[[Tf], Tf]: """A version of lru_cache for methods that shouldn't leak memory. Basically the idea is that we generate a per-object lru-cached partially applied method. Since pickling an lru_cache of a lambda or a functools.partial doesn't work, we wrap it in a _NoPickle object that doesn't pickle its contents. 
""" def transformer(f: Tf) -> Tf: key = f'__{f.__name__}_cached' @functools.wraps(f) def func(self, *args, **kwargs): _m = getattr(self, key, None) if not _m: _m = _NoPickle( functools.lru_cache(maxsize)(functools.partial(f, self)) ) setattr(self, key, _m) return _m.obj(*args, **kwargs) return func # type: ignore return transformer def method_cache[Tf: Callable](f: Tf) -> Tf: return lru_method_cache(None)(f) def clear_method_cache[Tf](method: Tf) -> None: assert isinstance(method, MethodType) key = f'__{method.__func__.__name__}_cached' _m: Optional[_NoPickle] = getattr(method.__self__, key, None) if _m is not None: _m.obj.cache_clear() _LRU_CACHES: list[functools._lru_cache_wrapper] = [] def per_job_lru_cache[Tf: Callable]( maxsize: int | None = 128, ) -> Callable[[Tf], Tf]: """A version of lru_cache that can be cleared en masse. All the caches will be tracked and calling clear_lru_caches() will clear them all. """ def transformer(f: Tf) -> Tf: wrapped = functools.lru_cache(maxsize)(f) _LRU_CACHES.append(wrapped) return wrapped # type: ignore return transformer def clear_lru_caches(): for cache in _LRU_CACHES: cache.cache_clear() ================================================ FILE: edb/common/markup/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2011-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations import abc import sys from edb.common import exceptions from . import elements, serializer, renderers from .serializer import serialize from .serializer import base as _base_serializer from .serializer.base import Context # noqa from .elements.base import Markup # noqa class MarkupCapableMixin: def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) if hasattr(cls, 'as_markup'): serializer.serializer.register(cls)(cls.as_markup) class MarkupExceptionContext( exceptions.ExceptionContext, MarkupCapableMixin, ): @abc.abstractclassmethod # type: ignore def as_markup(cls, *, ctx): pass def _serialize(obj, trim=True, kwargs=None): ctx = _base_serializer.Context(trim=trim, kwargs=kwargs) try: return serialize(obj, ctx=ctx) finally: ctx.reset() def dumps(obj, header=None, trim=True): markup = _serialize(obj, trim=trim) if header is not None: markup = elements.doc.Section(title=header, body=[markup]) return renderers.terminal.renders(markup) def _dump(markup, header, file): if header is not None: markup = elements.doc.Section(title=header, body=[markup]) renderers.terminal.render(markup, file=file) def dump(*objs, file=None, trim=True, marker=None, **kwargs): for obj in objs: if marker: markup = elements.doc.Marker(text=marker) renderers.terminal.render(markup, file=file, ensure_newline=False) markup = _serialize(obj, trim=trim, kwargs=kwargs) _dump(markup, None, file) def dump_code(code: str, *, lexer='python', header=None, file=None): markup = serializer.serialize_code(code, lexer=lexer) _dump(markup, header, file) def dump_callstack(f=None, *, limit=None, header=None, file=None, trim=True): if f is None: try: raise ZeroDivisionError except ZeroDivisionError: f = sys.exc_info()[2].tb_frame.f_back if limit is None: limit = getattr(sys, 'tracebacklimit', None) result = [] i = 0 start_frame = f ctx = _base_serializer.Context(trim=trim) while f is not None and (limit is None or i < limit): 
result.append(_base_serializer.serialize_callstack_point(f, ctx=ctx)) f = f.f_back i += 1 result.reverse() markup = elements.lang.Traceback(items=result, id=id(start_frame)) _dump(markup, header, file) ================================================ FILE: edb/common/markup/elements/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2011-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from . import base, lang, doc, code # NOQA ================================================ FILE: edb/common/markup/elements/base.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2011-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from edb.common.struct import RTStruct, StructMeta, Field from edb.common import checked class MarkupMeta(StructMeta): def __new__(mcls, name, bases, dct, ns=None, **kwargs): cls = super().__new__(mcls, name, bases, dct, **kwargs) cls._markup_ns = ns ns_name = [name] for base in cls.__mro__: try: base_ns = base._markup_ns except AttributeError: pass else: if base_ns is not None: ns_name.append(base_ns) cls._markup_name = '.'.join(reversed(ns_name)) cls._markup_name_safe = '_'.join(reversed(ns_name)) return cls def __init__(cls, name, bases, dct, ns=None, **kwargs): super().__init__(name, bases, dct, **kwargs) def __instancecheck__(cls, inst): # We make OverflowBarier and SerializationError be instanceof # and subclassof any Markup class. This avoids errors when # they are being added to various CheckedList & CheckedDict # collections. parent_check = type(RTStruct).__instancecheck__ if parent_check(cls, inst): return True return type(inst) in (OverflowBarier, SerializationError) def __subclasscheck__(cls, subcls): parent_check = type(RTStruct).__subclasscheck__ if parent_check(cls, subcls): return True return subcls in (OverflowBarier, SerializationError) class Markup(RTStruct, metaclass=MarkupMeta, use_slots=True): """Base class for all markup elements.""" MarkupList = checked.CheckedList[Markup] MarkupMapping = checked.CheckedDict[str, Markup] class OverflowBarier(Markup): """Represents that the nesting level of objects was too big.""" class SerializationError(Markup): """An error during object serialization occurred.""" text = Field(str) cls = Field(str) ================================================ FILE: edb/common/markup/elements/code.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2011-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from edb.common import checked from edb.common import struct from . import base class BaseCode(base.Markup, ns='code'): pass class Token(BaseCode): val = struct.Field(str) class Code(BaseCode): tokens = struct.Field( checked.CheckedList[Token], default=None, coerce=True ) class Whitespace(Token): pass class Comment(Token): pass class Keyword(Token): pass class Type(Token): pass class Operator(Token): pass class Name(Token): pass class Constant(Name): pass class BuiltinName(Name): pass class FunctionName(Name): pass class ClassName(Name): pass class Decorator(Token): pass class Attribute(Token): pass class Tag(Token): pass class Literal(Token): pass class String(Literal): pass class Number(Literal): pass class Punctuation(Token): """Characters ',', ':', '[', etc.""" class Error(Token): pass ================================================ FILE: edb/common/markup/elements/doc.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2011-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import difflib from edb.common import checked from edb.common.struct import Field from . import base class DocMarkup(base.Markup, ns='doc'): pass class Marker(DocMarkup): text = Field(str) class Section(DocMarkup): title = Field(str, coerce=True, default=None) body = Field(base.MarkupList, coerce=True) collapsed = Field(bool, coerce=True, default=False) class SubNode(DocMarkup): body = Field(base.Markup) class Text(DocMarkup): text = Field(str) class SourceCode(DocMarkup): text = Field(str) class Diff(DocMarkup): lines = Field(checked.CheckedList[str], coerce=True) @classmethod def get_diff( cls, a, b, fromfile='', tofile='', fromfiledate='', tofiledate='', n=10 ): lines = difflib.unified_diff( a, b, fromfile, tofile, fromfiledate, tofiledate, n) lines = [line.rstrip() for line in lines] if lines: return cls(lines=lines) else: return Text(text='No differences') class ValueDiff(DocMarkup): before = Field(str) after = Field(str) comment = Field(str, default=None) ================================================ FILE: edb/common/markup/elements/lang.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2011-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import linecache from edb.common.struct import Field from edb.common import checked from . import base class LangMarkup(base.Markup, ns='lang'): pass class Number(LangMarkup): num = Field(str, default=None, coerce=True) class String(LangMarkup): str = Field(str, default=None, coerce=True) class MultilineString(LangMarkup): str = Field(str, default=None, coerce=True) class Ref(LangMarkup): ref = Field(int, coerce=True) refname = Field(str, default=None) def __repr__(self): return '<{} {} {}>'.format('Ref', self.refname, self.ref) class BaseObject(LangMarkup): """Base language object with ``id``, but without ``attributes``.""" id = Field(int, default=None, coerce=True) class Object(BaseObject): class_module = Field(str) classname = Field(str) repr = Field(str, default=None) attributes = Field(base.MarkupMapping, default=None, coerce=True) class List(BaseObject): items = Field( # type: ignore[assignment] base.MarkupList, default=base.MarkupList, coerce=True) trimmed = Field(bool, default=False) brackets = Field(str, default="[]") class Dict(BaseObject): items = Field( # type: ignore[assignment] base.MarkupMapping, default=base.MarkupMapping, coerce=True) trimmed = Field(bool, default=False) class TreeNodeChild(BaseObject): label = Field(str, default=None) node = Field(base.Markup) TreeNodeChildrenList = checked.CheckedList[TreeNodeChild] class TreeNode(BaseObject): name = Field(str) children = Field( TreeNodeChildrenList, default=TreeNodeChildrenList, coerce=True) def add_child(self, *, label=None, node): 
self.children.append(TreeNodeChild(label=label, node=node)) class NoneConstantType(LangMarkup): pass class TrueConstantType(LangMarkup): pass class FalseConstantType(LangMarkup): pass class Constants: none = NoneConstantType() true = TrueConstantType() false = FalseConstantType() class TracebackPoint(BaseObject): name = Field(str, default=None) filename = Field(str, default=None) lineno = Field(int, default=None) colno = Field(int, default=None) end_colno = Field(int, default=None) address = Field(str, default=None) context = Field(bool, default=False) lines = Field(checked.CheckedList[str], default=None, coerce=True) line_numbers = Field(checked.CheckedList[int], default=None, coerce=True) locals = Field(Dict, default=None) def load_source(self, window=3, lines=None): self.lines = self.line_numbers = None if (self.lineno and ((self.filename and not self.filename.startswith('<') and not self.filename.endswith('>')) or lines)): lineno = self.lineno if not lines: linecache.checkcache(self.filename) sourcelines = linecache.getlines(self.filename, globals()) else: sourcelines = lines lines = [] line_numbers = [] start = max(1, lineno - window) end = min(len(sourcelines), lineno + window) + 1 for i in range(start, end): lines.append(sourcelines[i - 1].rstrip()) line_numbers.append(i) if lines: self.lines = checked.CheckedList[str](lines) self.line_numbers = checked.CheckedList[int](line_numbers) TracebackPointList = checked.CheckedList[TracebackPoint] class Traceback(BaseObject): items = Field( # type: ignore[assignment] TracebackPointList, default=TracebackPointList, coerce=True) class ExceptionContext(BaseObject): title = Field(str, default='Context') body = Field(base.MarkupList, coerce=True) ExceptionContextList = checked.CheckedList[ExceptionContext] class _Exception(Object): pass class Exception(_Exception): msg = Field(str) # NB: Traceback is just an exception context # contexts = Field(ExceptionContextList, default=None, coerce=True) context = Field(_Exception, 
None) cause = Field(_Exception, None) ================================================ FILE: edb/common/markup/format.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2012-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations def xrepr(obj, *, max_len=None): """Extended ``builtins.repr`` function. Examples: .. code-block:: pycon >>> xrepr('1234567890', max_len=7) '12'... :param int max_len: When defined, limits the maximum length of the resulting string representation. :returns str: """ result = str(repr(obj)) if max_len is not None and len(result) > max_len: ext = '...' if result[0] in ('"', "'"): ext = result[0] + ext elif result[0] == '<': ext = '>' + ext result = result[:(max_len - len(ext))] + ext return result ================================================ FILE: edb/common/markup/renderers/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2011-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from . import terminal # NOQA ================================================ FILE: edb/common/markup/renderers/styles.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2011-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from edb.common.term import Style16, Style256 class StylesTable: def __getattr__(self, key): # If we're querying some non-existing style, pretend it's empty # return Style16() class Dark16(StylesTable): Style = Style16 id = Style() bracket = Style() header1 = Style(color='white') header2 = Style(color='black', bold=True) exc_title = Style(color='red', bold=True) marker = Style(color='red', bold=True) tb_name = Style(color='yellow') tb_filename = Style() tb_lineno = Style() tb_current_line = Style() tb_code = Style() tb_pos_caret = Style(color='white', bold=True) attribute = Style(color='black', bold=True) key = Style(color='yellow') tree_node = Style(color='red', bold=True) constant = Style(color='cyan') literal = Style(color='green') ref = Style(color='red') unknown_object = Style(color='blue', bold=True) serialization_error = unknown_markup = overflow = Style( color='white', bgcolor='red') diff_anno = Style(color='white') diff_after = Style(color='green') diff_before = Style(color='red') code = Style(color='white') code_decorator = Style(color='black', bold=True) code_comment = attribute code_string = literal code_number = Style(color='green') code_constant = constant code_punctuation = bracket code_keyword = constant code_decl_name = tree_node code_tag = code_keyword code_attribute = Style(color=attribute.color) class Dark256(StylesTable): Style = Style256 id = Style(color='#3d6559') bracket = Style(color='#a7a963') header1 = Style(color='#656565') header2 = Style(color='#474747') exc_title = Style(color='#d84903', bold=True) marker = Style(color='#bbb', bgcolor='#582a70', bold=True) tb_name = Style(color='#5f5f87', bold=True) tb_filename = Style() tb_lineno = Style() tb_current_line = Style(color='#fff', bold=True) tb_code = Style() tb_pos_caret = Style(color='white', bold=True) attribute = Style(color='#565656', bold=True) key = Style(color='#5f875f', bold=True) tree_node = Style(color='#bc74d7', bold=True) constant = 
Style(color='#1dbdd0') literal = Style(color='#4aa336') ref = Style(color='#586c9e') unknown_object = Style(color='#707070') unknown_markup = overflow = Style(color='white', bgcolor='#84345a') serialization_error = Style(color='white', bgcolor='#900') diff_anno = Style(color='#777') diff_after = Style(color='#4aa336') diff_before = Style(color='#A00') code = Style(color='#aaa') code_decorator = Style(color='#af5f00') code_comment = attribute code_string = literal code_number = Style(color='#af5f5f') code_constant = constant code_punctuation = bracket code_keyword = constant code_decl_name = tree_node code_tag = code_keyword code_attribute = Style(color=attribute.color) ================================================ FILE: edb/common/markup/renderers/terminal.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2011-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import sys import contextlib from edb.common import term from edb.common.markup.format import xrepr from .. import elements from . 
import styles as styles_module LINE_BREAK = 1 FOLDABLE_LINES_START = 2 FOLDABLE_LINES_END = 3 NON_FOLDED_SPACE = 4 FOLDED_SPACE = 5 INDENT = 10 INDENT_NO_NL = 11 DEDENT = 20 DEDENT_NO_NL = 21 NEW_LINE = 30 DATA = 100 HEADER = 101 class Buffer: def __init__( self, *, max_width=None, styled=False, indentation=0, indent_with=' ' * 4, ): self.data = [] self.indentation = indentation self.indent_with = indent_with self.max_width = max_width self.styled = styled def new_line(self, lines=1): for _ in range(lines): self.data.append((NEW_LINE, )) @contextlib.contextmanager def indent(self, auto_new_line=True): if auto_new_line: self.data.append((INDENT, )) yield self.data.append((DEDENT, )) else: self.data.append((INDENT_NO_NL, )) yield self.data.append((DEDENT_NO_NL, )) def non_folded_space(self, space=' '): self.data.append((NON_FOLDED_SPACE, space)) def folded_space(self, space=' '): self.data.append((FOLDED_SPACE, space)) @contextlib.contextmanager def foldable_lines(self): self.data.append((FOLDABLE_LINES_START, )) yield self.data.append((FOLDABLE_LINES_END, )) @contextlib.contextmanager def non_foldable_lines(self): self.data.append((FOLDABLE_LINES_END, )) yield self.data.append((FOLDABLE_LINES_START, )) def mark_line_break(self): self.data.append((LINE_BREAK, )) def write(self, s, style=None): st = None if self.styled and style is not None and not style.empty: st = style self.data.append((DATA, str(s), st)) def header(self, s, style=None, level=1): st = None if self.styled and style is not None and not style.empty: st = style self.data.append((HEADER, str(s), st, level)) def flush(self): data = self.data self.data = None indentation = self.indentation indent_with = self.indent_with indent_with_len = len(indent_with) max_width = self.max_width result = [] folded_mode = 0 offset = 0 def check_folded_fit(pos, data, width): _len = 0 smlines = 0 smlines_max = 0 for item in data[pos:]: code = item[0] if code == FOLDABLE_LINES_START: smlines += 1 smlines_max += 1 elif code ==
FOLDABLE_LINES_END: smlines -= 1 if not smlines: break elif code == DATA or code == HEADER: _len += len(item[1]) elif code == LINE_BREAK: _len += 1 elif code == FOLDED_SPACE: _len += 1 if _len > width: return 0 if _len < width: return smlines_max else: return 0 for pos, item in enumerate(data): el = item[0] if el == INDENT: indentation += 1 if not folded_mode: result.append('\n' + indent_with * indentation) offset = indent_with_len * indentation elif el == DEDENT: indentation -= 1 if not folded_mode: result.append('\n' + indent_with * indentation) offset = indent_with_len * indentation elif el == INDENT_NO_NL: indentation += 1 elif el == DEDENT_NO_NL: indentation -= 1 elif el == NEW_LINE: if not folded_mode: result.append('\n' + indent_with * indentation) offset = indent_with_len * indentation elif el == FOLDABLE_LINES_START: if (not folded_mode) and (max_width is not None) and ( max_width - offset > 20): folded_mode = check_folded_fit( pos, data, max_width - offset) elif el == FOLDABLE_LINES_END: if folded_mode: folded_mode -= 1 elif el == LINE_BREAK: if folded_mode: result.append(' ') offset += 1 else: result.append('\n' + indent_with * indentation) offset = indent_with_len * indentation elif el == NON_FOLDED_SPACE: if not folded_mode: result.append(item[1]) elif el == FOLDED_SPACE: if folded_mode: result.append(item[1]) elif el == DATA or el == HEADER: # ``item[1]`` -- text to output, ``item[2]`` -- its style # ``item[3]`` -- its level, for headers # text, style = item[1], item[2] if el == HEADER: text = ' {} '.format(text) strlevel = '=' if item[3] == 0 else '-' if self.max_width: width = self.max_width - offset text = '{{str:{strlevel}^{width:d}s}}'.format( strlevel=strlevel, width=width).format(str=text) else: text = strlevel * 4 + text + strlevel * 4 if style is None: result.append(text) else: # If there's a style object - let's apply it # result.append(style.apply(text)) offset += len(text) elif el == HEADER: # ``item[1]`` -- text to output, ``item[2]`` -- 
its style, # _, text, style, _level = el if item[2] is None: result.append(item[1]) else: # If there's a style object - let's apply it # result.append(item[2].apply(item[1])) offset += len(item[1]) else: raise AssertionError(f"Unexpected element: {el}") return ''.join(result) class BaseRenderer: def __init__(self, *, indent_with=' ' * 4, max_width=None, styles=None): self.renderers_cache = {} self.buffer = Buffer( max_width=max_width, styled=styles, indent_with=indent_with) self.max_width = max_width self.styles = styles or styles_module.StylesTable() def _render(self, markup): cls = markup.__class__ renderer = None if not issubclass(cls, elements.base.Markup): return self._render_unknown(markup) try: renderer = self.renderers_cache[cls] except KeyError: cls_name = markup.__class__._markup_name_safe try: renderer = getattr(self, '_render_{}'.format(cls_name)) except AttributeError: for base in markup.__class__.__mro__: if issubclass(base, elements.base.Markup): try: renderer = getattr( self, '_render_{}'.format(base._markup_name_safe)) except AttributeError: pass else: self.renderers_cache[cls] = renderer break else: self.renderers_cache[cls] = renderer if renderer is None: raise Exception('no renderer found for {!r}'.format(markup)) return renderer(markup) def _render_header(self, str, style=None, level=1): self.buffer.header(str, style=style, level=level) def _render_unknown(self, element): self.buffer.write( xrepr(element, max_len=120), style=self.styles.unknown_markup) def _render_Markup(self, element): self.buffer.write( xrepr(element, max_len=120), style=self.styles.unknown_markup) def _render_OverflowBarier(self, element): self.buffer.write('<...>', style=self.styles.overflow) def _render_SerializationError(self, element): self.buffer.write( 'Exception during serialization to markup: <{}: {}>'.format( element.cls, element.text), style=self.styles.serialization_error) @classmethod def renders(cls, markup, styles=None, max_width=None): renderer = 
cls(max_width=max_width, styles=styles) renderer._render(markup) return renderer.buffer.flush() class DocRenderer(BaseRenderer): def _render_doc_Text(self, element): self.buffer.write(element.text) def _render_doc_SourceCode(self, element): self.buffer.write(element.text) def _render_doc_Marker(self, element): self.buffer.write(element.text, style=self.styles.marker) self.buffer.write(' ') def _render_doc_SubNode(self, element): with self.buffer.indent(): self._render(element.body) def _render_doc_Section(self, element): if element.title: self._render_header(element.title, style=self.styles.header1) self.buffer.new_line(2) for el in element.body: self._render(el) self.buffer.new_line() def _render_doc_ValueDiff(self, element): self.buffer.write(element.before, style=self.styles.diff_before) self.buffer.write(' | ') self.buffer.write(element.after, style=self.styles.diff_after) if element.comment: self.buffer.write( f' # {element.comment}', style=self.styles.code_comment) def _render_doc_Diff(self, element): total_lines = len(element.lines) for linenum, line in enumerate(element.lines): style = None if line.startswith('+'): style = self.styles.diff_after elif line.startswith('-'): style = self.styles.diff_before elif line.startswith('@@'): style = self.styles.diff_anno self.buffer.write(line, style=style) if linenum < total_lines - 1: self.buffer.new_line() class LangRenderer(BaseRenderer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.ex_depth = 0 def _render_lang_TreeNode(self, element): with self.buffer.foldable_lines(): self.buffer.write(element.name, style=self.styles.tree_node) self.buffer.non_folded_space() if element.id: self.buffer.write( '<0x{:x}>'.format(int(element.id)), style=self.styles.id) self.buffer.non_folded_space() self.buffer.write('(', style=self.styles.id) child_count = len(element.children) if child_count: key = lambda child: (len(child.label) if child.label else 0) longest_lbl = max(element.children, 
key=key).label padding = min(len(longest_lbl) if longest_lbl else 0, 20) with self.buffer.indent(): for idx, child in enumerate(element.children): if child.label: self.buffer.write( child.label, style=self.styles.attribute) self.buffer.non_folded_space( ' ' * (max(0, padding - len(child.label)) + 1)) self.buffer.write('=') self.buffer.non_folded_space() self._render(child.node) if idx < (child_count - 1): self.buffer.write(',') self.buffer.mark_line_break() self.buffer.write(')', style=self.styles.id) def _render_lang_Ref(self, element): self.buffer.write( '<Ref {}>'.format(element.refname), style=self.styles.ref) def _render_lang_List(self, element): with self.buffer.foldable_lines(): self.buffer.write(element.brackets[0], style=self.styles.bracket) item_count = len(element.items) if item_count: with self.buffer.indent(): for idx, item in enumerate(element.items): self._render(item) if idx < (item_count - 1): self.buffer.write(',') self.buffer.mark_line_break() if element.trimmed: self.buffer.write('...') self.buffer.write(element.brackets[1], style=self.styles.bracket) def _render_mapping_(self, mapping, trimmed=False): self.buffer.write('{', style=self.styles.bracket) item_count = len(mapping) if item_count: with self.buffer.indent(): for idx, (key, value) in enumerate(mapping.items()): self.buffer.write(key, style=self.styles.key) self.buffer.write(': ') self._render(value) if idx < (item_count - 1): self.buffer.write(',') self.buffer.mark_line_break() if trimmed: self.buffer.write('...') self.buffer.write('}', style=self.styles.bracket) def _render_lang_Dict(self, element): with self.buffer.foldable_lines(): self._render_mapping_(element.items, trimmed=element.trimmed) def _render_lang_Object(self, element): if element.attributes or element.repr is None: self.buffer.write( '<{}.{} at 0x{:x}'.format( element.class_module, element.classname, element.id), style=self.styles.unknown_object) if element.attributes: self.buffer.write(' ')
self._render_mapping_(element.attributes) self.buffer.write('>', style=self.styles.unknown_object) else: self.buffer.write(element.repr, style=self.styles.unknown_object) def _render_lang_String(self, element): self.buffer.write( xrepr(element.str, max_len=120), style=self.styles.literal) def _render_lang_MultilineString(self, element): with self.buffer.non_foldable_lines(): for line in element.str.splitlines(): self.buffer.new_line() self.buffer.write( line, style=self.styles.literal ) self.buffer.data.append((DEDENT_NO_NL, )) self.buffer.new_line() self.buffer.data.append((INDENT_NO_NL, )) def _render_lang_Number(self, element): self.buffer.write(element.num, style=self.styles.literal) def _render_lang_NoneConstantType(self, element): self.buffer.write('None', self.styles.constant) def _render_lang_TrueConstantType(self, element): self.buffer.write('True', self.styles.constant) def _render_lang_FalseConstantType(self, element): self.buffer.write('False', self.styles.constant) def _render_lang_TracebackPoint(self, element): with self.buffer.indent(False): self.buffer.new_line() self.buffer.write(element.filename, style=self.styles.tb_filename) if element.lineno: self.buffer.write(', line ') self.buffer.write(element.lineno, style=self.styles.tb_lineno) if element.address: self.buffer.write(', at ') self.buffer.write(element.address, style=self.styles.tb_lineno) self.buffer.write(', in ') self.buffer.write(element.name, style=self.styles.tb_name) with self.buffer.indent(False): self.buffer.new_line() if element.lines and element.line_numbers: for lineno, line in zip( element.line_numbers, element.lines): if lineno == element.lineno: if element.context: stripped_spaces = 0 stripped_line = line else: stripped_spaces = len(line) - len( line.lstrip()) stripped_line = line.strip() self.buffer.write( '> ', style=self.styles.tb_current_line) self.buffer.write( stripped_line or '???', style=self.styles.tb_code) if element.colno: # Render column caret _caret_indent = ' ' * 
( element.colno - stripped_spaces) self.buffer.new_line() self.buffer.write(' ', style=self.styles.code) self.buffer.write( _caret_indent + '^', style=self.styles.tb_pos_caret) if element.end_colno is not None: cnt = element.end_colno - element.colno - 1 self.buffer.write( '^' * cnt, style=self.styles.tb_pos_caret) self.buffer.new_line() if not element.context: break elif element.context: self.buffer.write('| ', style=self.styles.code) self.buffer.write( line.rstrip(), style=self.styles.code) self.buffer.new_line() else: if not element.context: self.buffer.write('???', style=self.styles.tb_code) if element.locals: self.buffer.new_line(2) self.buffer.write('Locals: ', style=self.styles.attribute) self._render(element.locals) self.buffer.new_line() def _render_lang_Traceback(self, element): for item in element.items: self._render(item) def _render_lang_ExceptionContext(self, element): self.buffer.new_line(2) self._render_header(element.title, level=2, style=self.styles.header2) self.buffer.new_line() if element.body: for el in element.body: self._render(el) def _render_lang_Exception(self, element): self.ex_depth += 1 try: if self.ex_depth == 1: msg = 'Exception occurred' if element.msg: msg = '{}: {}'.format(msg, element.msg) self._render_header(msg, style=self.styles.header1) self.buffer.new_line(2) if (element.cause or element.context) is not None: if element.cause is None: self._render(element.context) msg = ('During handling of the above exception, ' 'another exception occurred') else: self._render(element.cause) msg = ('The above exception was the direct cause ' 'of the following exception') self.buffer.new_line(2) self._render_header(msg, style=self.styles.header1) self.buffer.new_line(2) if element.class_module == 'builtins': excclass = element.classname else: excclass = '{}.{}'.format( element.class_module, element.classname) base_excline = '{}: {}'.format(excclass, element.msg) self.buffer.write( '{}. 
{}'.format(self.ex_depth, base_excline), style=self.styles.exc_title) if element.contexts: for context in element.contexts: self._render(context) self.buffer.new_line() self.buffer.new_line() self.buffer.write(base_excline, style=self.styles.exc_title) finally: self.ex_depth -= 1 class CodeRenderer(BaseRenderer): def _write_code_token(self, val, style): parts = val.split('\n') for chunk in parts[:-1]: self.buffer.write(chunk, style=style) self.buffer.new_line() self.buffer.write(parts[-1], style=style) def _render_code_Token(self, element): self._write_code_token(element.val, style=self.styles.code) def _render_code_Comment(self, element): self._write_code_token(element.val, style=self.styles.code_comment) def _render_code_Decorator(self, element): self._write_code_token(element.val, style=self.styles.code_decorator) def _render_code_String(self, element): self._write_code_token(element.val, style=self.styles.code_string) def _render_code_Number(self, element): self._write_code_token(element.val, style=self.styles.code_number) def _render_code_ClassName(self, element): self._write_code_token(element.val, style=self.styles.code_decl_name) def _render_code_FunctionName(self, element): self._write_code_token(element.val, style=self.styles.code_decl_name) def _render_code_Constant(self, element): self._write_code_token(element.val, style=self.styles.code_constant) def _render_code_Keyword(self, element): self._write_code_token(element.val, style=self.styles.code_keyword) def _render_code_Punctuation(self, element): self._write_code_token(element.val, style=self.styles.code_punctuation) def _render_code_Tag(self, element): self._write_code_token(element.val, style=self.styles.code_tag) def _render_code_Attribute(self, element): self._write_code_token(element.val, style=self.styles.code_attribute) def _render_code_Code(self, element): if len(element.tokens) > 20: with self.buffer.indent(): for token in element.tokens: self._render(token) else: for token in 
element.tokens: self._render(token) class Renderer(DocRenderer, LangRenderer, CodeRenderer): pass renders = Renderer.renders def render(markup, *, ensure_newline=True, file=None, renderer=Renderer): if file is None: file = sys.stdout try: fileno = file.fileno() except OSError: # This is a hack to try to get nice colorized dump output over # a remote-pdb connection. If the output is redirected to # something without fileno, use what ought to be stdout's fileno # to decide on color, etc. fileno = 1 max_width = term.size(fileno)[1] style_table = None if term.use_colors(fileno): max_colors = term.max_colors() if max_colors > 255: style_table = styles_module.Dark256 elif max_colors > 6: style_table = styles_module.Dark16 rendered = renderer.renders( markup, styles=style_table, max_width=max_width) if ensure_newline and not rendered.endswith('\n'): rendered += '\n' print(rendered, file=file, end='') ================================================ FILE: edb/common/markup/serializer/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2011-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations class settings: censor_sensitive_vars = True censor_list = ['secret', 'password'] from .base import serialize, serializer, serialize_traceback_point # NOQA from .base import Context # NOQA from .base import no_ref_detect # NOQA from .code import serialize_code # NOQA from . import logging # NOQA ================================================ FILE: edb/common/markup/serializer/base.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2011-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import collections import collections.abc import decimal import functools import types import weakref from .. import elements from edb.common import exceptions from edb.common.markup.format import xrepr from edb.common import debug from . import settings #: Maximum level of nested structures that we can serialize. #: If we reach it - we'll just stop traversing the objects #: tree at that point and yield 'elements.base.OverflowBarier' #: OVERFLOW_BARIER = 100 # XXX Configurable? #: Maximum number of total 'serialize' calls. #: If we reach it - we'll just stop traversing the objects #: tree at that point and yield 'elements.base.OverflowBarier' #: RUN_OVERFLOW_BARIER = 5000 # XXX Configurable? 
__all__ = 'serialize', def no_ref_detect[T](func: T) -> T: """Serializer decorated with ``no_ref_detect`` will be executed without prior checking the memo if object was already serialized""" func.no_ref_detect = True # type: ignore return func @functools.singledispatch def serializer(obj, *, ctx): """Markup serializers dispatcher""" raise NotImplementedError class Context: """Markup serialization context. Holds the ``memo`` set, which is used to avoid serializing objects that already have been serialized, and ``depth`` - recursion depth""" def __init__(self, trim=True, kwargs=None): self.reset() self.trim = trim self.kwargs = kwargs or {} if settings.censor_sensitive_vars: self.censor_set = set(settings.censor_list) else: self.censor_set = set() def censored(self, key): return key in self.censor_set def reset(self): self.memo = set() self.keep_alive = [] self.level = 0 self.run_cnt = 0 def serialize(obj, *, ctx): """Serialize arbitrary python object to Markup elements""" tobj = type(obj) sr = serializer.dispatch(tobj) if sr is serializer: raise LookupError(f'unable to find serializer for object {obj!r}') if (sr is serialize_unknown_object and hasattr(tobj, '__dataclass_fields__')): sr = serialize_dataclass ctx.level += 1 ctx.run_cnt += 1 try: if ctx.level >= OVERFLOW_BARIER or ctx.run_cnt >= RUN_OVERFLOW_BARIER: return elements.base.OverflowBarier() ref_detect = True try: # Was the serializer decorated with ``@no_ref_detect``? # ref_detect = not sr.no_ref_detect except AttributeError: pass if ref_detect: # OK, so if we've already serialized obj, don't do that again, just # return ``markup.Ref`` element. 
# obj_id = id(obj) if obj_id in ctx.memo: return elements.lang.Ref(ref=obj_id, refname=repr(obj)) else: ctx.memo.add(obj_id) ctx.keep_alive.append(obj) try: return sr(obj, ctx=ctx) except Exception as ex: return elements.base.SerializationError( text=str(ex), cls='{}.{}'.format( ex.__class__.__module__, ex.__class__.__name__)) finally: ctx.level -= 1 @no_ref_detect def _serialize_traceback_point( obj, frame, lineno, *, ctx, include_source=True, source_window_size=2, include_locals=False, point_cls=elements.lang.TracebackPoint, ): assert source_window_size >= 0 name = frame.f_code.co_name filename = frame.f_code.co_filename locals = None if include_locals or debug.flags.print_locals: locals = serialize(dict(frame.f_locals), ctx=ctx) if filename.startswith('.'): frame_fn = frame.f_globals.get('__file__') if frame_fn and frame_fn.endswith(filename[2:]): filename = frame_fn point = point_cls( name=name, lineno=lineno, filename=filename, locals=locals, id=id(obj)) if include_source: point.load_source(window=source_window_size) return point @no_ref_detect def serialize_traceback_point( obj, *, ctx, include_source=True, source_window_size=2, include_locals=False, point_cls=elements.lang.TracebackPoint, ): assert isinstance(obj, types.TracebackType) return _serialize_traceback_point( obj, obj.tb_frame, obj.tb_lineno, ctx=ctx, include_source=include_source, source_window_size=source_window_size, include_locals=include_locals, point_cls=point_cls) @no_ref_detect def serialize_callstack_point( obj, *, ctx, include_source=True, source_window_size=2, include_locals=False, point_cls=elements.lang.TracebackPoint, ): assert isinstance(obj, types.FrameType) return _serialize_traceback_point( obj, obj, obj.f_lineno, ctx=ctx, include_source=include_source, source_window_size=source_window_size, include_locals=include_locals, point_cls=point_cls) @serializer.register(types.TracebackType) def serialize_traceback(obj, *, ctx): result = [] current = obj while current is not None: 
result.append(serialize_traceback_point(current, ctx=ctx)) current = current.tb_next return elements.lang.Traceback(items=result, id=id(obj)) @serializer.register(BaseException) def serialize_exception(obj, *, ctx): cause = context = None if obj.__cause__ is not None and obj.__cause__ is not obj: cause = serialize(obj.__cause__, ctx=ctx) elif ( not obj.__suppress_context__ and obj.__context__ is not None and obj.__context__ is not obj): context = serialize(obj.__context__, ctx=ctx) details_context = None contexts = [] for ex_context in exceptions.iter_contexts(obj): if isinstance(ex_context, exceptions.DefaultExceptionContext): details_context = ex_context else: contexts.append(serialize(ex_context, ctx=ctx)) obj_traceback = obj.__traceback__ if obj_traceback: traceback = elements.lang.ExceptionContext( title='Traceback', body=[serialize(obj_traceback, ctx=ctx)]) if isinstance(obj, SyntaxError): point = elements.lang.TracebackPoint( name='', lineno=obj.lineno, colno=obj.offset, filename=obj.filename or '') point.load_source() traceback.body[0].items.append(point) contexts.append(traceback) if details_context is not None: contexts.append(serialize(details_context, ctx=ctx)) markup = elements.lang.Exception( class_module=obj.__class__.__module__, classname=obj.__class__.__name__, msg=str(obj), contexts=contexts, cause=cause, context=context, id=id(obj)) if isinstance(obj, BaseExceptionGroup): markup = elements.doc.Section( body=[ markup, elements.doc.Section( title='Grouped exceptions', body=[ elements.doc.SubNode(body=serializer(sub, ctx=ctx)) for sub in obj.exceptions ] ) ], ) return markup @serializer.register(exceptions.ExceptionContext) def serialize_generic_exception_context(obj, *, ctx): msg = 'No markup serializer for {!r} context'.format(obj) return elements.lang.ExceptionContext( title=obj.title, body=[elements.doc.Text(text=msg)]) @serializer.register(exceptions.DefaultExceptionContext) def serialize_default_exception_context(obj, *, ctx): body = [] if 
obj.details: txt = elements.doc.Text(text='Details: {}'.format(obj.details)) body.append(elements.doc.Section(body=[txt])) if obj.hint: txt = elements.doc.Text(text='Hint: {}'.format(obj.hint)) body.append(elements.doc.Section(body=[txt])) return elements.lang.ExceptionContext(title=obj.title, body=body) @serializer.register(type(None)) @no_ref_detect def serialize_none(obj, *, ctx): return elements.lang.Constants.none @serializer.register(bool) @no_ref_detect def serialize_bool(obj, *, ctx): if obj: return elements.lang.Constants.true else: return elements.lang.Constants.false @serializer.register(int) @serializer.register(float) @serializer.register(decimal.Decimal) @no_ref_detect def serialize_number(obj, *, ctx): return elements.lang.Number(num=obj) @serializer.register(str) @no_ref_detect def serialize_str(obj, *, ctx): return elements.lang.String(str=obj) @serializer.register(collections.UserList) @serializer.register(list) @serializer.register(tuple) @serializer.register(collections.abc.Set) @serializer.register(weakref.WeakSet) @serializer.register(set) @serializer.register(frozenset) @no_ref_detect def serialize_sequence(obj, *, ctx, trim_at=100): els = [] cnt = 0 trim = ctx.trim if isinstance(obj, tuple): brackets = "()" elif isinstance(obj, (collections.abc.Set, weakref.WeakSet, set, frozenset)): brackets = "{}" else: brackets = "[]" for cnt, item in enumerate(obj): els.append(serialize(item, ctx=ctx)) if trim and cnt >= trim_at: break return elements.lang.List( items=els, id=id(obj), brackets=brackets, trimmed=(trim and cnt >= trim_at)) @serializer.register(dict) @serializer.register(collections.abc.Mapping) @no_ref_detect def serialize_mapping(obj, *, ctx, trim_at=100): map = collections.OrderedDict() cnt = 0 trim = ctx.trim for cnt, (key, value) in enumerate(obj.items()): if not isinstance(key, str): key = repr(key) if ctx.censored(key) and value is not None: value = '********' map[key] = serialize(value, ctx=ctx) if trim and cnt >= trim_at: break 
return elements.lang.Dict( items=map, id=id(obj), trimmed=(trim and cnt >= trim_at)) def serialize_dataclass(obj, *, ctx): fields = type(obj).__dataclass_fields__ node = elements.lang.TreeNode( id=id(obj), name=f'{type(obj).__name__}') for fieldname, field in fields.items(): try: val = getattr(obj, fieldname) except AttributeError: continue if not field.repr: continue node.add_child( label=fieldname, node=serialize(val, ctx=ctx)) return node @serializer.register(object) @no_ref_detect def serialize_unknown_object(obj, *, ctx): return elements.lang.Object( id=id(obj), class_module=type(obj).__module__, classname=type(obj).__name__, repr=xrepr(obj, max_len=200)) def _serialize_known_object(obj, attrs, *, ctx): map = collections.OrderedDict() for attr in attrs: map[attr] = serialize(getattr(obj, attr, None), ctx=ctx) return elements.lang.Object( id=id(obj), class_module=obj.__class__.__module__, classname=obj.__class__.__name__, attributes=map) ================================================ FILE: edb/common/markup/serializer/code.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2011-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from edb.common.markup.elements import code as code_el try: from pygments import token, lexers except ImportError: # No pygments def serialize_code(code, lexer='does not matter'): return code_el.Code(tokens=[code_el.Token(val=code)]) else: import functools _TOKEN_MAP = { token.Token: code_el.Token, token.Whitespace: code_el.Whitespace, token.Comment: code_el.Comment, token.Keyword: code_el.Keyword, token.Keyword.Type: code_el.Type, token.Keyword.Constant: code_el.Constant, token.Operator: code_el.Operator, token.Operator.Word: code_el.Keyword, token.Name: code_el.Name, token.Name.Builtin: code_el.BuiltinName, token.Name.Function: code_el.FunctionName, token.Name.Class: code_el.ClassName, token.Name.Constant: code_el.Constant, token.Name.Decorator: code_el.Decorator, token.Name.Attribute: code_el.Attribute, token.Name.Tag: code_el.Tag, token.Name.Builtin.Pseudo: code_el.Constant, token.Punctuation: code_el.Punctuation, token.String: code_el.String, token.Number: code_el.Number, token.Error: code_el.Error } @functools.lru_cache(100) def get_code_class(token_type): cls = _TOKEN_MAP.get(token_type) while cls is None: token_type = token_type[:-1] cls = _TOKEN_MAP.get(token_type) if cls is None: cls = code_el.Token return cls class MarkupFormatter: def format(self, tokens): result = [] for token_type, value in tokens: cls = get_code_class(token_type) result.append(cls(val=value)) return code_el.Code(tokens=result) def serialize_code(code, lexer='python'): lexer = lexers.get_lexer_by_name(lexer, stripall=True) return MarkupFormatter().format(lexer.get_tokens(code)) ================================================ FILE: edb/common/markup/serializer/logging.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2011-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import logging from . import base @base.serializer.register(logging.LogRecord) # type: ignore def serialize_logging_record(obj, *, ctx): return base._serialize_known_object( obj, (attr for attr in dir(obj) if not attr.startswith('_')), ctx=ctx) ================================================ FILE: edb/common/ordered.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import ( Any, Optional, Hashable, Iterable, Iterator, MutableSet, ) import collections import collections.abc class OrderedSet[K: Hashable](MutableSet[K]): map: dict[K, None] def __init__(self, iterable: Optional[Iterable[K]] = None) -> None: if iterable is not None: self.map = {v: None for v in iterable} else: self.map = {} def add(self, item: K) -> None: self.map[item] = None def discard(self, item: K) -> None: self.map.pop(item, None) def update(self, iterable: Iterable[K]) -> None: for item in iterable: self.map[item] = None def replace(self, existing: K, new: K) -> None: if existing not in self.map: raise LookupError(f'{existing!r} is not in set') self.map[existing] = None difference_update = collections.abc.MutableSet.__isub__ symmetric_difference_update = collections.abc.MutableSet.__ixor__ intersection_update = collections.abc.MutableSet.__iand__ def __len__(self) -> int: return len(self.map) def __contains__(self, item: Any) -> bool: return item in self.map def __iter__(self) -> Iterator[K]: return iter(self.map) def __reversed__(self) -> Iterator[K]: return reversed(self.map.keys()) def __repr__(self) -> str: if not self: return '%s()' % (self.__class__.__name__, ) return '%s(%r)' % (self.__class__.__name__, list(self)) def __eq__(self, other: Any) -> bool: if isinstance(other, self.__class__): return len(self) == len(other) and self.map == other.map elif other is None: return False else: return not self.isdisjoint(other) def copy(self) -> OrderedSet[K]: return self.__class__(self) def clear(self) -> None: self.map.clear() ================================================ FILE: edb/common/ordered.pyi ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2011-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ This stub file is needed so that __and__, __or__, __sub__, __xor__, and so on properly return the instance of the *current class*, not the abstract versions. """ from __future__ import annotations from typing import ( AbstractSet, Any, Hashable, Iterable, Iterator, MutableSet, Optional, ) class OrderedSet[_H: Hashable](MutableSet[_H]): def __init__(self, iterable: Optional[Iterable[_H]] = None) -> None: ... def __and__(self, s: AbstractSet[Any]) -> OrderedSet[_H]: ... def __or__[_T](self, s: AbstractSet[_T]) -> OrderedSet[_H | _T]: ... def __sub__(self, s: AbstractSet[Any]) -> OrderedSet[_H]: ... def __xor__[_T](self, s: AbstractSet[_T]) -> OrderedSet[_H | _T]: ... def __ior__[_S](self, s: AbstractSet[_S]) -> OrderedSet[_H | _S]: ... def __iand__[_T](self, s: AbstractSet[Any]) -> OrderedSet[_T]: ... def __ixor__[_S](self, s: AbstractSet[_S]) -> OrderedSet[_H | _S]: ... def __isub__[_T](self, s: AbstractSet[Any]) -> OrderedSet[_H]: ... difference_update = MutableSet.__isub__ symmetric_difference_update = MutableSet.__ixor__ intersection_update = MutableSet.__iand__ def add(self, item: _H) -> None: ... def discard(self, item: _H) -> None: ... def update(self, s: Iterable[_H]) -> None: ... def replace(self, existing: _H, new: _H) -> None: ... def __len__(self) -> int: ... def __contains__(self, item: Any) -> bool: ... def __iter__(self) -> Iterator[_H]: ... def __reversed__(self) -> Iterator[_H]: ... 
def copy(self) -> OrderedSet[_H]: ... def clear(self) -> None: ... ================================================ FILE: edb/common/parametric.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2011-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, ClassVar, Generic, Optional, TypeVar, get_type_hints, ) import functools import types import sys from edb.common import typing_inspect __all__ = [ "ParametricType", "SingleParametricType", "KeyValueParametricType", ] T = TypeVar("T") V = TypeVar("V") try: from types import GenericAlias except ImportError: from typing import _GenericAlias as GenericAlias # type: ignore class ParametricType: types: ClassVar[Optional[tuple[type, ...]]] = None orig_args: ClassVar[Optional[tuple[type, ...]]] = None _forward_refs: ClassVar[dict[str, tuple[int, str]]] = {} _type_param_map: ClassVar[dict[Any, str]] = {} _non_type_params: ClassVar[dict[int, type]] = {} def __init_subclass__(cls) -> None: super().__init_subclass__() if cls.types is not None: return elif ParametricType in cls.__bases__: cls._init_parametric_base() elif any(issubclass(b, ParametricType) for b in cls.__bases__): cls._init_parametric_user() @classmethod def _init_parametric_base(cls) -> None: """Initialize a direct subclass of ParametricType""" # Direct subclasses of ParametricType must declare # 
ClassVar attributes corresponding to the Generic type vars. # For example: # class P(ParametricType, Generic[T, V]): # t: ClassVar[Type[T]] # v: ClassVar[Type[V]] params = getattr(cls, '__parameters__', None) if not params: raise TypeError( f'{cls} must be declared as Generic' ) mod = sys.modules[cls.__module__] annos = get_type_hints(cls, mod.__dict__) param_map = {} for attr, t in annos.items(): if not typing_inspect.is_classvar(t): continue args = typing_inspect.get_args(t) # ClassVar constructor should have the check, but be extra safe. assert len(args) == 1 arg = args[0] if typing_inspect.get_origin(arg) is not type: continue arg_args = typing_inspect.get_args(arg) # Likewise, rely on Type checking its stuff in the constructor assert len(arg_args) == 1 if not typing_inspect.is_typevar(arg_args[0]): continue if arg_args[0] in params: param_map[arg_args[0]] = attr for param in params: if param not in param_map: raise TypeError( f'{cls.__name__}: missing ClassVar for' f' generic parameter {param}' ) cls._type_param_map = param_map @classmethod def _init_parametric_user(cls) -> None: """Initialize an indirect descendant of ParametricType.""" # For ParametricType grandchildren we have to deal with possible # TypeVar remapping and generally check for type sanity. 
ob = getattr(cls, '__orig_bases__', ()) generic_params: list[type] = [] for b in ob: if ( isinstance(b, type) and not isinstance(b, GenericAlias) and issubclass(b, ParametricType) and b is not ParametricType ): raise TypeError( f'{cls.__name__}: missing one or more type arguments for' f' base {b.__name__!r}' ) if not typing_inspect.is_generic_type(b): continue org = typing_inspect.get_origin(b) if not isinstance(org, type): continue if not issubclass(org, ParametricType): generic_params.extend(getattr(b, '__parameters__', ())) continue base_params = getattr(org, '__parameters__', ()) base_non_type_params = getattr(org, '_non_type_params', {}) args = typing_inspect.get_args(b) expected = len(base_params) if len(args) != expected: raise TypeError( f'{b.__name__} expects {expected} type arguments' f' got {len(args)}' ) base_map = dict(cls._type_param_map) subclass_map = {} for i, arg in enumerate(args): if i in base_non_type_params: continue if not typing_inspect.is_typevar(arg): raise TypeError( f'{b.__name__} expects all arguments to be' f' TypeVars' ) base_typevar = base_params[i] attr = base_map.get(base_typevar) if attr is not None: subclass_map[arg] = attr if len(subclass_map) != len(base_map): raise TypeError( f'{cls.__name__}: missing one or more type arguments for' f' base {org.__name__!r}' ) cls._type_param_map = subclass_map cls._non_type_params = { i: p for i, p in enumerate(generic_params) if p not in cls._type_param_map } def __init__(self) -> None: if self._forward_refs: raise TypeError( f"{type(self)!r} unresolved type parameters" ) if self.types is None: raise TypeError( f"{type(self)!r} must be parametrized to instantiate" ) super().__init__() @classmethod @functools.lru_cache() def __class_getitem__( cls, params: type | str | tuple[type | str, ...] ) -> type[ParametricType]: """Return a dynamic subclass parametrized with `params`. 
We cannot use `_GenericAlias` provided by `Generic[T]` because the default `__class_getitem__` on `_GenericAlias` is not a real type and so it doesn't retain information on generics on the class. Even on the object, it adds the relevant `__orig_class__` link too late, after `__init__()` is called. That means we wouldn't be able to type-check in the initializer using built-in `Generic[T]`. """ if cls.types is not None: raise TypeError(f"{cls!r} is already parametrized") if not isinstance(params, tuple): params = (params,) all_params = params type_params = [] for i, param in enumerate(all_params): if i not in cls._non_type_params: type_params.append(param) params_str = ", ".join(_type_repr(a) for a in all_params) name = f"{cls.__name__}[{params_str}]" bases = (cls,) type_dict: dict[str, Any] = { "types": tuple(type_params), "orig_args": all_params, "__module__": cls.__module__, } forward_refs: dict[str, tuple[int, str]] = {} tuple_to_attr: dict[int, str] = {} if cls._type_param_map: gen_params = getattr(cls, '__parameters__', ()) for i, gen_param in enumerate(gen_params): attr = cls._type_param_map.get(gen_param) if attr: tuple_to_attr[i] = attr expected = len(gen_params) actual = len(params) if expected != actual: raise TypeError( f"type {cls.__name__!r} expects {expected} type" f" parameter{'s' if expected != 1 else ''}," f" got {actual}" ) for i, attr in tuple_to_attr.items(): type_dict[attr] = all_params[i] if not all(isinstance(param, type) for param in type_params): if all( type(param) is TypeVar # type: ignore[comparison-overlap] for param in type_params ): # All parameters are type variables: return the regular generic # alias to allow proper subclassing. 
generic = super(ParametricType, cls) return generic.__class_getitem__(all_params) # type: ignore else: forward_refs = { param: (i, tuple_to_attr[i]) for i, param in enumerate(type_params) if isinstance(param, str) } if not forward_refs: raise TypeError( f"{cls!r} expects types as type parameters") result = type(name, bases, type_dict) assert issubclass(result, ParametricType) result._forward_refs = forward_refs return result @classmethod def is_fully_resolved(cls) -> bool: return not cls._forward_refs @classmethod def resolve_types(cls, globalns: dict[str, Any]) -> None: if cls.types is None: raise TypeError( f"{cls!r} is not parametrized" ) if not cls._forward_refs: return types = list(cls.types) for ut, (idx, attr) in cls._forward_refs.items(): t = eval(ut, globalns, {}) if isinstance(t, type) and not isinstance(t, GenericAlias): types[idx] = t setattr(cls, attr, t) else: raise TypeError( f"{cls!r} expects types as type parameters, got {t!r:.100}" ) cls.types = tuple(types) cls._forward_refs = {} @classmethod def is_anon_parametrized(cls) -> bool: return cls.__name__.endswith(']') def __reduce__(self) -> tuple[Any, ...]: raise NotImplementedError( 'must implement explicit __reduce__ for ParametricType subclass' ) class SingleParametricType(ParametricType, Generic[T]): # noqa: UP046 # We ignore UP046 (Generic[T]) because Python 3.12.2 typing.get_type_hints # has problems with resolving `T` when it is defined using the new syntax. type: ClassVar[type[T]] # type: ignore class KeyValueParametricType(ParametricType, Generic[T, V]): # noqa: UP046 # We ignore UP046 (Generic[T]) because Python 3.12.2 typing.get_type_hints # has problems with resolving `T` when it is defined using the new syntax. 
keytype: ClassVar[type[T]] # type: ignore valuetype: ClassVar[type[V]] # type: ignore def _type_repr(obj: Any) -> str: if isinstance(obj, type): if obj.__module__ == "builtins": return obj.__qualname__ return f"{obj.__module__}.{obj.__qualname__}" if isinstance(obj, types.FunctionType): return obj.__name__ return repr(obj) ================================================ FILE: edb/common/parsing.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2010-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import ( Any, Callable, Optional, cast ) import json import logging import os import sys import types import parsing from edb.common import debug, span Span = span.Span logger = logging.getLogger('edb.common.parsing') class ParserSpecIncompatibleError(Exception): pass class Token(parsing.Token): token_map: dict[Any, Any] = {} _token: str = "" def __init_subclass__( cls, *, token=None, lextoken=None, is_internal=False, **kwargs ): super().__init_subclass__(**kwargs) if is_internal: return if token is None: if not cls.__name__.startswith('T_'): raise Exception( 'Token class names must either start with T_ or have ' 'a token parameter') token = cls.__name__[2:] if lextoken is None: lextoken = token cls._token = token Token.token_map[lextoken] = cls if not cls.__doc__: doc = '%%token %s' % token prec = Precedence.for_token(token) if prec: doc += ' [%s]' % prec.__name__ cls.__doc__ = doc def __init__(self, val, clean_value, span=None): super().__init__() self.val = val self.clean_value = clean_value self.span = span def __repr__(self): return '<Token %s "%s">' % (self.__class__._token, self.val) def inline(argument_index: int): """ When added to grammar productions, it makes the method equivalent to: self.val = kids[argument_index].val """ def decorator(func: Any): func.inline_index = argument_index return func return decorator class Nonterm(parsing.Nonterm): span: Span def __init_subclass__(cls, *, is_internal=False, **kwargs): """Add docstrings to class and reduce functions If no class docstring is present, set it to '%nonterm'. If any reduce function (i.e. of the form `reduce(_\\w+)+`) does not have a docstring, a new one is generated based on the function name. See https://github.com/MagicStack/parsing for more information.
Keyword arguments: is_internal -- internal classes do not need docstrings and processing can be skipped """ super().__init_subclass__(**kwargs) if is_internal: return if not cls.__doc__: cls.__doc__ = '%nonterm' for name, attr in cls.__dict__.items(): if (name.startswith('reduce_') and isinstance(attr, types.FunctionType)): if attr.__doc__ is None: tokens = name.split('_') if name == 'reduce_empty': tokens = ['reduce', ''] doc = r'%reduce {}'.format(' '.join(tokens[1:])) prec = getattr(attr, '__parsing_precedence__', None) if prec is not None: doc += ' [{}]'.format(prec) inline_index = getattr(attr, 'inline_index', None) attr = lambda self, *args, meth=attr: meth(self, *args) attr.__doc__ = doc attr.inline_index = inline_index setattr(cls, name, attr) class ListNonterm(Nonterm, is_internal=True): def __init_subclass__( cls, *, element, separator=None, is_internal=False, allow_trailing_separator=False, **kwargs, ): """Create reductions for list classes. If trailing separator is not allowed, the class can handle all reductions directly. L := E L := L S E If trailing separator is allowed, create an inner class to handle all non-trailing reductions. Then the class handles the trailing separator. I := E I := I S E L := I L := I S The inner class is added to the same module as the class. 
""" if not is_internal: if not allow_trailing_separator: # directly handle the list ListNonterm.add_list_reductions( cls, element=element, separator=separator ) else: # create inner list class and add to same module mod = sys.modules[cls.__module__] def inner_cls_exec(ns): ns['__module__'] = mod.__name__ return ns inner_cls_name = cls.__name__ + 'Inner' inner_cls_kwds = dict(element=element, separator=separator) inner_cls = types.new_class(inner_cls_name, (ListNonterm,), inner_cls_kwds, inner_cls_exec) setattr(mod, inner_cls_name, inner_cls) # create reduce_inner function separator_name = ListNonterm.component_name(separator) setattr(cls, 'reduce_{}'.format(inner_cls_name), lambda self, inner: ( ListNonterm._reduce_inner(self, inner) )) setattr(cls, 'reduce_{}_{}'.format(inner_cls_name, separator_name), lambda self, inner, sep: ( ListNonterm._reduce_inner(self, inner) )) # reduce functions must be present before calling superclass super().__init_subclass__(is_internal=is_internal, **kwargs) def __iter__(self): return iter(self.val) def __len__(self): return len(self.val) @staticmethod def add_list_reductions(cls, *, element, separator=None): element_name = ListNonterm.component_name(element) separator_name = ListNonterm.component_name(separator) if separator_name: tail_prod = lambda self, lst, sep, el: ( ListNonterm._reduce_list(self, lst, el) ) tail_prod_name = 'reduce_{}_{}_{}'.format( cls.__name__, separator_name, element_name) else: tail_prod = lambda self, lst, el: ( ListNonterm._reduce_list(self, lst, el) ) tail_prod_name = 'reduce_{}_{}'.format( cls.__name__, element_name) setattr(cls, tail_prod_name, tail_prod) setattr(cls, 'reduce_' + element_name, lambda self, el: ListNonterm._reduce_el(self, el)) @staticmethod def component_name(component: type) -> Optional[str]: if component is None: return None elif issubclass(component, Token): return component._token elif issubclass(component, Nonterm): return component.__name__ else: raise Exception( 'List component 
must be a Token or Nonterm') @staticmethod def _reduce_list(self, lst, el): if el.val is None: tail = [] else: tail = [el.val] self.val = lst.val + tail @staticmethod def _reduce_el(self, el): if el.val is None: tail = [] else: tail = [el.val] self.val = tail @staticmethod def _reduce_inner(self, inner): self.val = inner.val def precedence(precedence): """Decorator to set production precedence.""" def decorator(func): func.__parsing_precedence__ = precedence.__name__ return func return decorator class Precedence(parsing.Precedence): token_prec_map: dict[Any, Any] = {} last: dict[Any, Any] = {} def __init_subclass__( cls, *, assoc, tokens=None, prec_group=None, rel_to_last='>', is_internal=False, **kwargs, ): super().__init_subclass__(**kwargs) if is_internal: return if not cls.__doc__: doc = '%%%s' % assoc last = Precedence.last.get(prec_group) if last: doc += ' %s%s' % (rel_to_last, last.__name__) cls.__doc__ = doc if tokens: for token in tokens: existing = None try: existing = Precedence.token_prec_map[token] except KeyError: Precedence.token_prec_map[token] = cls else: raise Exception( 'token {} has already been set precedence {}'.format( token, existing)) Precedence.last[prec_group] = cls @classmethod def for_token(cls, token_name): return Precedence.token_prec_map.get(token_name) def load_parser_spec(mod: types.ModuleType) -> parsing.Spec: return parsing.Spec( mod, skinny=not debug.flags.edgeql_parser, logFile=_localpath(mod, "log"), verbose=bool(debug.flags.edgeql_parser), ) def _localpath(mod, type): return os.path.join( os.path.dirname(mod.__file__), mod.__name__.rpartition('.')[2] + '.' 
+ type) def load_spec_productions( production_names: list[tuple[str, str]], mod: types.ModuleType ) -> list[tuple[type, Callable]]: productions: list[tuple[Any, Callable]] = [] for class_name, method_name in production_names: cls = mod.__dict__.get(class_name, None) if not cls: # for NontermStart productions.append((parsing.Nonterm(), lambda *args: None)) continue method = cls.__dict__[method_name] productions.append((cls, method)) return productions def spec_to_json(spec: parsing.Spec) -> str: # Converts a ParserSpec into JSON. Called from edgeql-parser Rust crate. assert spec.pureLR token_map: dict[str, str] = { v._token: c for c, v in Token.token_map.items() } # productions productions_all: set[Any] = set() for st_actions in spec.actions(): for _, acts in st_actions.items(): act = cast(Any, acts[0]) if 'ReduceAction' in str(type(act)): prod = act.production productions_all.add(prod) productions, production_id = sort_productions(productions_all) # actions actions = [] for st_actions in spec.actions(): out_st_actions = [] for tok, acts in st_actions.items(): act = cast(Any, acts[0]) str_tok = token_map.get(str(tok), str(tok)) if 'ShiftAction' in str(type(act)): action_obj: Any = { 'Shift': int(act.nextState) } else: prod = act.production action_obj = { 'Reduce': { 'production_id': production_id[prod], 'non_term': str(prod.lhs), 'cnt': len(prod.rhs), } } out_st_actions.append((str_tok, action_obj)) out_st_actions.sort(key=lambda item: item[0]) actions.append(out_st_actions) # goto goto = [] for st_goto in spec.goto(): out_goto = [] for nterm, action in st_goto.items(): out_goto.append((str(nterm), action)) goto.append(out_goto) # inlines inlines = [] for prod in productions: id = production_id[prod] inline = getattr(prod.method, 'inline_index', None) if inline is not None: assert isinstance(inline, int) inlines.append((id, inline)) res = { 'actions': actions, 'goto': goto, 'start': str(spec.start_sym()), 'inlines': inlines, 'production_names': 
list(map(production_name, productions)), } return json.dumps(res) def sort_productions( productions_all: set[Any], ) -> tuple[list[Any], dict[Any, int]]: productions = list(productions_all) productions.sort(key=production_name) productions_id = {prod: id for id, prod in enumerate(productions)} return (productions, productions_id) def production_name(prod: Any) -> tuple[str, ...]: return tuple(prod.qualified.split('.')[-2:]) ================================================ FILE: edb/common/prometheus.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2021-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """An implementation of prometheus metrics protocol. Key differences from the official "prometheus_client" package: 1. Not thread-safe. We're an async application and don't use threads so there's no need for thread-safety. 2. The code is as simple as possible; no complicated polymorphism that can slow down metrics collection in runtime. 3. It's more than 4x faster, likely because of (1) and (2). 4. No global state. All metrics are explicitly contained in an explicitly created "registry" instance. 5. This code can be potentially cythonized (or mypyc-ified) for extra performance. 6. The tests (tests/common/test_prometheus.py) ensure that the output is exactly equal to what prometheus_client generates. It's a bug otherwise. 
For more details, see: * Open Metrics standard: https://github.com/OpenObservability/OpenMetrics/blob/main/specification/OpenMetrics.md * Prometheus documentation: https://prometheus.io/docs/practices/naming/ * Prometheus official Python client: https://github.com/prometheus/client_python """ from __future__ import annotations import bisect import enum import functools import math import time import typing __all__ = ('Registry', 'Unit', 'calc_buckets') def calc_buckets( start: float, upper_bound: float, /, *, increment_ratio: float = 1.20 ) -> tuple[float, ...]: """Calculate histogram buckets on a logarithmic scale.""" # See https://amplitude.com/blog/2014/08/06/optimal-streaming-histograms # for more details. # (Says a long-standing comment, but this isn't what that post recommends!) result: list[float] = [] while start <= upper_bound: result.append(start) start *= increment_ratio return tuple(result) def per_order_buckets( start: float, end: float, *, base: float = 10.0, entries_per_order: int = 4, ) -> tuple[float, ...]: # See https://amplitude.com/blog/2014/08/06/optimal-streaming-histograms # for more details. # (Actually, for this one.)
result: list[float] = [start] next = start * base while next <= end: for i in range(1, entries_per_order): val = next / entries_per_order * i if val > result[-1]: result.append(val) result.append(next) next *= base return tuple(result) class Unit(enum.Enum): # https://prometheus.io/docs/practices/naming/#base-units SECONDS = 'seconds' CELSIUS = 'celsius' METERS = 'meters' BYTES = 'bytes' RATIO = 'ratio' VOLTS = 'volts' AMPERES = 'amperes' JOULES = 'joules' GRAMS = 'grams' class Registry: _metrics: list[BaseMetric] _metrics_names: set[str] _prefix: str | None def __init__(self, *, prefix: str | None = None): self._metrics = [] self._metrics_names = set() self._prefix = prefix def _add_metric(self, metric: BaseMetric) -> None: name = metric.get_name() if name in self._metrics_names: raise ValueError( f'a metric with a name {name!r} has already been registered') self._metrics.append(metric) self._metrics_names.add(name) def now(self) -> float: return time.time() def set_info(self, name: str, desc: str, /, **kwargs: str) -> None: self._add_metric(Info(self, name, desc, **kwargs)) def new_counter( self, name: str, desc: str, /, *, unit: Unit | None = None, ) -> Counter: counter = Counter(self, name, desc, unit) self._add_metric(counter) return counter def new_labeled_counter( self, name: str, desc: str, /, *, labels: tuple[str, ...], unit: Unit | None = None, ) -> LabeledCounter: counter = LabeledCounter(self, name, desc, unit, labels=labels) self._add_metric(counter) return counter def new_gauge( self, name: str, desc: str, /, *, unit: Unit | None = None, ) -> Gauge: gauge = Gauge(self, name, desc, unit) self._add_metric(gauge) return gauge def new_labeled_gauge( self, name: str, desc: str, /, *, unit: Unit | None = None, labels: tuple[str, ...], ) -> LabeledGauge: gauge = LabeledGauge(self, name, desc, unit, labels=labels) self._add_metric(gauge) return gauge def new_histogram( self, name: str, desc: str, /, *, unit: Unit | None = None, buckets: typing.Sequence[float] 
| None = None, ) -> Histogram: hist = Histogram(self, name, desc, unit, buckets=buckets) self._add_metric(hist) return hist def new_labeled_histogram( self, name: str, desc: str, /, *, unit: Unit | None = None, buckets: typing.Sequence[float] | None = None, labels: tuple[str, ...], ) -> LabeledHistogram: hist = LabeledHistogram( self, name, desc, unit, buckets=buckets, labels=labels ) self._add_metric(hist) return hist def generate(self, **label_filters: str) -> str: buffer: list[str] = [] for metric in self._metrics: metric._generate(buffer, **label_filters) buffer.append('') return '\n'.join(buffer) class BaseMetric: _type: str _name: str _desc: str _unit: Unit | None _created: float _registry: Registry PROHIBITED_SUFFIXES = ( '_count', '_created', '_total', '_sum', '_bucket', '_gcount', '_gsum', '_info', ) PROHIBITED_PREFIXES = ( '_', 'python_', 'prometheus_', ) PROHIBITED_LABELS = ( 'quantile', 'le' ) def __init__( self, registry: Registry, name: str, desc: str, unit: Unit | None = None, /, ) -> None: self._registry = registry name = self._augment_metric_name(name) self._validate_name(name) if unit is not None: name += '_' + unit.value self._name = name self._desc = desc self._unit = unit self._created = registry.now() def _augment_metric_name(self, name: str) -> str: if self._registry._prefix is not None: name = f'{self._registry._prefix}_{name}' return name def get_name(self) -> str: return self._name def _validate_name(self, name: str) -> None: if (name.startswith(self.PROHIBITED_PREFIXES) or name.endswith(self.PROHIBITED_SUFFIXES)): raise ValueError(f'invalid metrics name: {name!r}') def _validate_label_names(self, labels: tuple[str, ...]) -> None: for label in labels: if label.startswith('_') or label in self.PROHIBITED_LABELS: raise ValueError(f'invalid label name: {label!r}') def _validate_label_values( self, labels: tuple[str, ...], values: tuple[str, ...] 
) -> None: if len(values) != len(labels): raise ValueError( f'missing values for labels: {labels[len(values):]!r}') for name, val in zip(labels, values): if not val: raise ValueError(f'empty value for label {name!r}') def _make_label_filter( self, labels: tuple[str, ...], label_filters: dict[str, str], ) -> typing.Callable[[tuple[str, ...]], bool]: if not label_filters: return lambda _: True try: label_by_idx = [ (labels.index(label), label_val) for label, label_val in label_filters.items() ] except ValueError: return lambda _: False def label_filter(label_values: tuple[str, ...]) -> bool: for idx, label_val in label_by_idx: if label_values[idx] != label_val: return False return True return label_filter def _generate(self, buffer: list[str], **label_filters: str) -> None: raise NotImplementedError class Info(BaseMetric): _type = 'info' _name: str _desc: str _registry: Registry _labels: dict[str, str] def __init__(self, *args: typing.Any, **labels: str) -> None: super().__init__(*args) self._validate_label_names(tuple(labels.keys())) self._labels = labels def _generate(self, buffer: list[str], **label_filters: str) -> None: if label_filters: return desc = _format_desc(self._desc) buffer.append(f'# HELP {self._name}_info {desc}') buffer.append(f'# TYPE {self._name}_info gauge') fmt_label = ','.join( f'{label}="{_format_label_val(value)}"' for label, value in self._labels.items() ) buffer.append(f'{self._name}_info{{{fmt_label}}} 1.0') class BaseCounter(BaseMetric): _type = 'counter' _suffix = '_total' _render_created = True _value: float def __init__(self, *args: typing.Any) -> None: super().__init__(*args) self._value = 0 def inc(self, value: float = 1.0) -> None: if value < 0: raise ValueError( 'counter cannot be incremented with a negative value') self._value += value def _generate(self, buffer: list[str], **label_filters: str) -> None: if label_filters: return desc = _format_desc(self._desc) buffer.append(f'# HELP {self._name}{self._suffix} {desc}') 
buffer.append(f'# TYPE {self._name}{self._suffix} {self._type}') buffer.append(f'{self._name}{self._suffix} {float(self._value)}') if self._render_created: buffer.append(f'# HELP {self._name}_created {desc}') buffer.append(f'# TYPE {self._name}_created gauge') buffer.append(f'{self._name}_created {float(self._created)}') class BaseLabeledCounter(BaseMetric): _type = 'counter' _suffix = '_total' _render_created = True _labels: tuple[str, ...] _metric_values: dict[tuple[str, ...], float] _metric_created: dict[tuple[str, ...], float] def __init__(self, *args: typing.Any, labels: tuple[str, ...]) -> None: super().__init__(*args) self._validate_label_names(labels) self._labels = labels self._metric_values = {} self._metric_created = {} def inc(self, value: float = 1.0, *labels: str) -> None: self._validate_label_values(self._labels, labels) if value < 0: raise ValueError( 'counter cannot be incremented with a negative value') try: self._metric_values[labels] += value except KeyError: self._metric_values[labels] = value self._metric_created[labels] = self._registry.now() def _generate(self, buffer: list[str], **label_filters: str) -> None: desc = _format_desc(self._desc) buffer.append(f'# HELP {self._name}{self._suffix} {desc}') buffer.append(f'# TYPE {self._name}{self._suffix} {self._type}') filter_func = self._make_label_filter(self._labels, label_filters) for labels, value in self._metric_values.items(): if not filter_func(labels): continue fmt_label = ','.join( f'{label}="{_format_label_val(label_val)}"' for label, label_val in zip(self._labels, labels) ) buffer.append( f'{self._name}{self._suffix}{{{fmt_label}}} {float(value)}' ) if self._render_created and self._metric_values: buffer.append(f'# HELP {self._name}_created {desc}') buffer.append(f'# TYPE {self._name}_created gauge') for labels, value in self._metric_created.items(): if not filter_func(labels): continue fmt_label = ','.join( f'{label}="{_format_label_val(label_val)}"' for label, label_val in 
zip(self._labels, labels) ) buffer.append( f'{self._name}_created{{{fmt_label}}} {float(value)}' ) def clear(self, label_filter: typing.Callable[..., bool]) -> None: for label in list(self._metric_values): if label_filter(*label): self._metric_values.pop(label) self._metric_created.pop(label, None) class _TotalMixin(BaseMetric): def _augment_metric_name(self, name: str) -> str: name = super()._augment_metric_name(name) if not name.endswith('_total'): raise TypeError('counter metric names require the "_total" suffix') name = name[:-len('_total')] return name class Counter(_TotalMixin, BaseCounter): pass class LabeledCounter(_TotalMixin, BaseLabeledCounter): pass class Gauge(BaseCounter): _type = 'gauge' _render_created = False _suffix = '' def inc(self, value: float = 1.0) -> None: self._value += value def dec(self, value: float = 1.0) -> None: self._value -= value def set(self, value: float) -> None: self._value = value class LabeledGauge(BaseLabeledCounter): _type = 'gauge' _render_created = False _suffix = '' def inc(self, value: float = 1.0, *labels: str) -> None: self._validate_label_values(self._labels, labels) try: self._metric_values[labels] += value except KeyError: self._metric_values[labels] = value self._metric_created[labels] = self._registry.now() def dec(self, value: float = 1.0, *labels: str) -> None: self.inc(-value, *labels) def set(self, value: float = 1.0, *labels: str) -> None: self._validate_label_values(self._labels, labels) self._metric_values[labels] = value try: self._metric_created[labels] except KeyError: self._metric_created[labels] = self._registry.now() class BaseHistogram(BaseMetric): _type = 'histogram' _buckets: list[float] # Default buckets that many standard prometheus client libraries use.
DEFAULT_BUCKETS = [ 0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, ] def __init__( self, *args: typing.Any, buckets: typing.Sequence[float] | None = None ) -> None: if buckets is None: buckets = list(self.DEFAULT_BUCKETS) else: buckets = list(buckets) # always copy: "+inf" may be appended below, and the class-level default must not be mutated if buckets != sorted(buckets): raise ValueError('*buckets* must be sorted') if len(buckets) < 2: raise ValueError('*buckets* must have at least 2 numbers') if not math.isinf(buckets[-1]): buckets += [float('+inf')] super().__init__(*args) self._buckets = buckets class Histogram(BaseHistogram): _values: list[float] _sum: float def __init__( self, *args: typing.Any, buckets: typing.Sequence[float] | None = None ) -> None: super().__init__(*args, buckets=buckets) self._sum = 0.0 self._values = [0.0] * len(self._buckets) def observe(self, value: float) -> None: idx = bisect.bisect_left(self._buckets, value) self._values[idx] += 1.0 self._sum += value def _generate(self, buffer: list[str], **label_filters: str) -> None: if label_filters: return desc = _format_desc(self._desc) buffer.append(f'# HELP {self._name} {desc}') buffer.append(f'# TYPE {self._name} histogram') accum = 0.0 for buck, val in zip(self._buckets, self._values): accum += val if math.isinf(buck): if buck > 0: buckf = '+Inf' else: buckf = '-Inf' else: buckf = str(buck) buffer.append(f'{self._name}_bucket{{le="{buckf}"}} {accum}') buffer.append(f'{self._name}_count {accum}') buffer.append(f'{self._name}_sum {self._sum}') buffer.append(f'# HELP {self._name}_created {desc}') buffer.append(f'# TYPE {self._name}_created gauge') buffer.append(f'{self._name}_created {float(self._created)}') class LabeledHistogram(BaseHistogram): _labels: tuple[str, ...]
_metric_values: dict[tuple[str, ...], list[float | list[float]]] _metric_created: dict[tuple[str, ...], float] def __init__( self, *args: typing.Any, buckets: typing.Sequence[float] | None = None, labels: tuple[str, ...], ) -> None: super().__init__(*args, buckets=buckets) self._labels = labels self._metric_values = {} self._metric_created = {} def observe(self, value: float, *labels: str) -> None: self._validate_label_values(self._labels, labels) try: metric = self._metric_values[labels] except KeyError: metric = [0.0, [0.0] * len(self._buckets)] self._metric_values[labels] = metric self._metric_created[labels] = self._registry.now() idx = bisect.bisect_left(self._buckets, value) metric[1][idx] += 1.0 # type: ignore metric[0] += value # type: ignore def _generate(self, buffer: list[str], **label_filters: str) -> None: desc = _format_desc(self._desc) buffer.append(f'# HELP {self._name} {desc}') buffer.append(f'# TYPE {self._name} histogram') filter_func = self._make_label_filter(self._labels, label_filters) for labels, values in self._metric_values.items(): if not filter_func(labels): continue fmt_label = ','.join( f'{label}="{_format_label_val(label_val)}"' for label, label_val in zip(self._labels, labels) ) accum = 0.0 for buck, val in zip(self._buckets, values[1]): # type: ignore accum += val if math.isinf(buck): if buck > 0: buckf = '+Inf' else: buckf = '-Inf' else: buckf = str(buck) buffer.append( f'{self._name}_bucket{{le="{buckf}",{fmt_label}}} {accum}' ) buffer.append(f'{self._name}_count{{{fmt_label}}} {accum}') buffer.append(f'{self._name}_sum{{{fmt_label}}} {values[0]}') if self._metric_values: buffer.append(f'# HELP {self._name}_created {desc}') buffer.append(f'# TYPE {self._name}_created gauge') for labels, value in self._metric_created.items(): if not filter_func(labels): continue fmt_label = ','.join( f'{label}="{_format_label_val(label_val)}"' for label, label_val in zip(self._labels, labels) ) buffer.append( f'{self._name}_created{{{fmt_label}}} 
{float(value)}' ) @functools.lru_cache(maxsize=1024) def _format_desc(desc: str) -> str: return desc.replace('\\', r'\\').replace('\n', r'\n') @functools.lru_cache(maxsize=1024) def _format_label_val(desc: str) -> str: return ( desc.replace('\\', r'\\').replace('\n', r'\n').replace('"', r'\"') ) ================================================ FILE: edb/common/retryloop.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2022-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Callable, Optional, ) import asyncio import random import re import time import types def const_backoff(delay: float) -> Callable[[int], float]: return lambda _: delay def exp_backoff( *, factor: float = 0.1, jitter_scale: float = 0.001, ) -> Callable[[int], float]: def _f(i: int) -> float: delay: int = 2 ** i return delay * factor + random.randrange(100) * jitter_scale return _f class RetryLoop: def __init__( self, *, backoff: Callable[[int], float] = const_backoff(0.5), timeout: float, ignore: type[Exception] | tuple[type[Exception], ...] | None = None, ignore_regexp: str | None = None, wait_for: type[Exception] | tuple[type[Exception], ...] 
| None = None, wait_for_regexp: str | None = None, retry_cb: Callable[[Optional[BaseException]], None] | None = None, ) -> None: self._iteration = 0 self._backoff = backoff self._timeout = timeout self._ignore = ignore if ignore_regexp is None: self._ignore_regexp = None else: self._ignore_regexp = re.compile(ignore_regexp) self._wait_for = wait_for if wait_for_regexp is None: self._wait_for_regexp = None else: self._wait_for_regexp = re.compile(wait_for_regexp) self._started_at = 0.0 self._stop_request = False self._retry_cb = retry_cb def __aiter__(self) -> RetryLoop: return self async def __anext__(self) -> RetryIteration: if self._stop_request: raise StopAsyncIteration if self._started_at == 0: # First run self._started_at = time.monotonic() else: # Second or greater run -- delay before yielding delay = self._backoff(self._iteration) await asyncio.sleep(delay) self._iteration += 1 return RetryIteration(self) class RetryIteration: def __init__(self, loop: RetryLoop) -> None: self._loop = loop async def __aenter__(self) -> RetryIteration: return self async def __aexit__( self, et: type[BaseException], e: BaseException, _tb: types.TracebackType, ) -> bool: elapsed = time.monotonic() - self._loop._started_at if ( self._loop._ignore is not None or self._loop._ignore_regexp is not None ): # Mode 1: Try until we don't get errors matching `ignore` if et is None: self._loop._stop_request = True return False # Propagate if it's not the error we expected. if self._loop._ignore is not None: if not isinstance(e, self._loop._ignore): return False if self._loop._ignore_regexp is not None: if not self._loop._ignore_regexp.search(str(e)): return False if elapsed > self._loop._timeout: # Propagate -- we've run it enough times. return False if self._loop._retry_cb is not None: self._loop._retry_cb(e) # Ignore the exception until next run. 
return True else: # Mode 2: Try until we fail with an error matching `wait_for` assert ( self._loop._wait_for is not None or self._loop._wait_for_regexp is not None ) if et is not None: if ( self._loop._wait_for is None or isinstance(e, self._loop._wait_for) ) and ( self._loop._wait_for_regexp is None or self._loop._wait_for_regexp.search(str(e)) ): # We're done, we've got what we waited for. self._loop._stop_request = True return True else: # Propagate, it's not the error we expected. return False if elapsed > self._loop._timeout: raise TimeoutError( f'exception matching {self._loop._wait_for!r} ' f'has not happened in {self._loop._timeout} seconds') # Ignore the exception until next run. return True ================================================ FILE: edb/common/secretkey.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright EdgeDB Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
# from __future__ import annotations from typing import Iterable import pathlib import uuid from datetime import datetime, timedelta def generate_tls_cert( tls_cert_file: pathlib.Path, tls_key_file: pathlib.Path, listen_hosts: Iterable[str] ) -> None: from cryptography import x509 from cryptography.hazmat import backends from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.x509 import oid backend = backends.default_backend() private_key = rsa.generate_private_key( public_exponent=65537, key_size=2048, backend=backend ) subject = x509.Name( [x509.NameAttribute(oid.NameOID.COMMON_NAME, "Gel Server")] ) certificate = ( x509.CertificateBuilder() .subject_name(subject) .public_key(private_key.public_key()) .serial_number(int(uuid.uuid4())) .issuer_name(subject) .not_valid_before( datetime.today() - timedelta(days=1) ) .not_valid_after( datetime.today() + timedelta(weeks=1000) ) .add_extension( x509.SubjectAlternativeName( [ x509.DNSName(name) for name in listen_hosts if name not in {'0.0.0.0', '::'} ] ), critical=False, ) .sign( private_key=private_key, algorithm=hashes.SHA256(), backend=backend, ) ) with tls_cert_file.open("wb") as f: f.write(certificate.public_bytes(encoding=serialization.Encoding.PEM)) tls_cert_file.chmod(0o644) with tls_key_file.open("wb") as f: f.write( private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, encryption_algorithm=serialization.NoEncryption(), ) ) tls_key_file.chmod(0o600) ================================================ FILE: edb/common/signalctl.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2021-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import asyncio import functools import signal as mod_signal import warnings def _release_waiter(waiter, *args): if not waiter.done(): waiter.set_result(None) class SignalError(Exception): def __init__(self, signo): self.signo = signo def __str__(self): if isinstance(self.signo, mod_signal.Signals): return self.signo._name_ else: return str(self.signo) class SignalHandler: def __init__(self, callback, signals, controller): self._cancelled = False self._callback = callback self._signals = signals self._controller = controller for signal in signals: controller._register_waiter(signal, self) def done(self): return self._cancelled def cancelled(self): return self._cancelled def set_result(self, result): asyncio.get_running_loop().call_soon(self._callback, result) def cancel(self): self._cancelled = True for signal in self._signals: self._controller._discard_waiter(signal, self) class SignalController: _registry: dict[ asyncio.AbstractEventLoop, dict[int, set[SignalController]], ] = {} _waiters: dict[int, set[asyncio.Future]] def __init__(self, *signals): self._signals = signals self._loop = asyncio.get_running_loop() self._waiters = {} def __enter__(self): registry = self._registry.setdefault(self._loop, {}) for signal in self._signals: controllers = registry.setdefault(signal, set()) if not controllers: self._loop.add_signal_handler( signal, self._signal_callback, signal ) controllers.add(self) return 
self def __exit__(self, exc_type, exc_val, exc_tb): handlers = [ waiter for waiters in self._waiters.values() for waiter in waiters if isinstance(waiter, SignalHandler) ] for handler in handlers: handler.cancel() if self._waiters: warnings.warn( "SignalController exited before wait_for() completed.", stacklevel=1, ) registry = self._registry[self._loop] for signal in self._signals: controllers = registry[signal] controllers.discard(self) if not controllers: del registry[signal] self._loop.remove_signal_handler(signal) if not registry: del self._registry[self._loop] def _on_signal(self, signal): for waiter in self._waiters.get(signal, []): if not waiter.done(): waiter.set_result(signal) def _register_waiter(self, signal, waiter): self._waiters.setdefault(signal, set()).add(waiter) def _discard_waiter(self, signal, waiter): waiters = self._waiters.get(signal) if waiters: waiters.discard(waiter) if not waiters: del self._waiters[signal] async def wait_for(self, fut, *, cancel_on=None): fut = asyncio.ensure_future(fut) # early check: if for any reason fut is already done, just return if fut.done(): return fut.result() # by default, capture all signals configured in this controller if cancel_on is None: cancel_on = self._signals cancelled_by = None outer_cancelled_at_last = False # The design here: we'll wait on a separate Future "waiter" for clean # cancellation. The waiter might be woken up by 3 different events: # 1. The given "fut" is done # 2. A signal is captured # 3. The "waiter" is cancelled by outer code. # For 2, we'll cancel the given "fut" and record the signal in # cancelled_by as a __context__ chain to raise in the next step; for 3, # we cancel the given "fut" and propagate the CancelledError later. # # The complexity of this design is: because our cancellation might be # intercepted in the "fut" code - e.g. 
a finally block or except block # that traps (and hopefully re-raises) the CancelledError or # BaseException, we need a loop here to ensure all the nested blocks # are exhaustively executed until the "fut" is done, meanwhile the # signals may keep hitting the "fut" code blocks, and "wait_for" is # ready to handle them properly, and return all the SignalError objects # in a __context__ chain preserving the order as they happen. while not fut.done(): waiter = self._loop.create_future() cb = functools.partial(_release_waiter, waiter) fut.add_done_callback(cb) for signal in cancel_on: self._register_waiter(signal, waiter) try: try: signal = await waiter except asyncio.CancelledError: # Event 3: cancelled by outer code. if not fut.done(): fut.cancel() outer_cancelled_at_last = True else: # Event 2: "fut" is still running, which means that # "waiter" was woken up by a signal. if not fut.done(): assert signal is not None fut.cancel() err = SignalError(signal) err.__context__ = cancelled_by cancelled_by = err outer_cancelled_at_last = False # Event 1: "fut" is done - exit the loop naturally. finally: fut.remove_done_callback(cb) # In any case, the "waiter" is done at this time, it needs to # be removed from the signal callback chain, even if we still # need to wait for the signal in the next loop, with a new # "waiter" object. for signal in cancel_on: self._discard_waiter(signal, waiter) # Now that the "fut" is done, let's check its result. It may end up in # 3 different scenarios, listed below inline: try: # 1. "fut" finished happily without raising errors (event 1), just # return the result. Even if we've previously recorded signals # (event 2) or cancellations (event 3), it's now handled by the # user, and we shall simply dispose the recorded cancelled_by. return fut.result() except asyncio.CancelledError as ex: # 2. "fut" is cancelled - this usually means we caught a signal, # but it could also be other reasons, see below. 
if cancelled_by is not None: # Event 2 happened at least once if outer_cancelled_at_last: # If event 3 is the last event, the outer code is probably # expecting a CancelledError, e.g. asyncio.wait_for(). # Therefore, we just raise it with signal errors attached. ex.__context__ = cancelled_by raise else: # If event 2 is the last event, simply raise the grouped # signal errors, attaching the CancelledError to reveal # where the signals hit the user code. We cannot raise # directly here because cancelled_by.__context__ may have # previously-captured signal errors. cancelled_by.__cause__ = ex else: # Neither event 2 nor 3 happened, the user code cancelled # itself, simply propagate the same error. raise except Exception as e: # 3. For any other errors, we just raise it with the signal errors # attached as __context__ if event 2 happened. if cancelled_by is not None: e.__context__ = cancelled_by raise assert cancelled_by is not None raise cancelled_by def add_handler(self, callback, signals=None) -> SignalHandler: if signals is None: signals = self._signals return SignalHandler(callback, signals, self) @classmethod def _signal_callback(cls, signal): registry = cls._registry.get(asyncio.get_running_loop()) if not registry: return controllers = registry.get(signal) if not controllers: return for controller in controllers: controller._on_signal(signal) ================================================ FILE: edb/common/span.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """This module contains tools for maintaining parser context. Maintaining parser context explicitly is significant overhead and can be difficult in the face of changing AST structures. Certain parser productions require various nesting or unnesting of previously parsed nodes. Both of these operations can result in the parser context not being correctly updated. The tools in this module attempt to automatically maintain parser context based on the context information found in lexer tokens. The general approach is to infer context information by propagating known contexts through the AST structure. 
""" from __future__ import annotations from typing import Iterable, Any, Optional import re import bisect import edb._edgeql_parser as ql_parser from edb.common import ast from edb.common import markup from edb.common import typeutils NEW_LINE = re.compile(rb'\r\n?|\n') class Span(markup.MarkupExceptionContext): ''' Parser Source Context ''' def __init__( self, filename: Optional[str], buffer: str, start: int, end: int, *, context_lines=1, ): assert start is not None assert end is not None self.filename = filename self.buffer = buffer self.start = start self.end = end self.context_lines = context_lines self._points = None @classmethod def empty(cls) -> Span: return Span( filename=None, buffer='', start=0, end=0, ) def __str__(self): if self.filename: return f'{self.filename}:{self.start}..{self.end}' return f'{self.start}..{self.end}' def __getstate__(self): dic = self.__dict__.copy() dic['_points'] = None return dic def _calc_points(self): # HACK: If we don't have an actual buffer (probably because we # are recompiling after a schema change), just fake something # long enough. Line numbers will be wrong but positions will # still be right... buffer = self.buffer.encode('utf-8') if self.buffer else b' ' * self.end self._points = ql_parser.SourcePoint.from_offsets( buffer, [self.start, self.end] ) @property def start_point(self): if self._points is None: self._calc_points() return self._points[0] @property def end_point(self): if self._points is None: self._calc_points() return self._points[1] @classmethod @markup.serializer.no_ref_detect def as_markup(cls, self, *, ctx): me = markup.elements start = self.start_point # TODO: do more with end? 
end = self.end_point buf_bytes = self.buffer.encode('utf-8') offset = 0 buf_lines = [] line_offsets = [0] for match in NEW_LINE.finditer(buf_bytes): buf_lines.append(buf_bytes[offset : match.start()].decode('utf-8')) offset = match.end() line_offsets.append(offset) line_no = bisect.bisect_right(line_offsets, start.offset) - 1 context_start = max(0, line_no - self.context_lines) context_end = min(line_no + self.context_lines + 1, len(buf_lines)) endcol = end.column if start.line == end.line else None tbp = me.lang.TracebackPoint( name=self.filename, filename=self.filename, lineno=start.line, colno=start.column, end_colno=endcol, lines=buf_lines[context_start:context_end], # Line numbers are 1 indexed here line_numbers=list(range(context_start + 1, context_end + 1)), context=True, ) return me.lang.ExceptionContext(title=self.title, body=[tbp]) def _get_span(items, *, reverse=False) -> Optional[Span]: ctx = None items = reversed(items) if reverse else items # find non-empty start and end # for item in items: if isinstance(item, (list, tuple)): ctx = _get_span(item, reverse=reverse) if ctx: return ctx else: ctx = getattr(item, 'span', None) if ctx: return ctx return None def get_span(*kids: list[ast.AST]): start_ctx = _get_span(kids) end_ctx = _get_span(kids, reverse=True) if not start_ctx or not end_ctx: return None return Span( filename=start_ctx.filename, buffer=start_ctx.buffer, start=start_ctx.start, end=end_ctx.end, ) def merge_spans(spans: Iterable[Span]) -> Span | None: span_list = list(spans) if not span_list: return None span_list.sort(key=lambda x: (x.start, x.end)) # assume same name and buffer apply to all # return Span( filename=span_list[0].filename, buffer=span_list[0].buffer, start=span_list[0].start, end=span_list[-1].end, ) class SpanPropagator(ast.NodeVisitor): """Propagate span from children to root. It is assumed that if a node has a span, all of its children also have correct span. 
For a node that has no span, its span is derived as a superset of all of the spans of its descendants. If full_pass is True, nodes with span will still recurse into children and their new span will also be superset of the existing span. """ def __init__(self, default=None, full_pass=False): super().__init__() self._default = default self._full_pass = full_pass def repeated_node_visit(self, node): return self.memo[node] def container_visit(self, node) -> list[Span | None]: span_list: list[Span | None] = [] for el in node: if isinstance(el, ast.AST) or typeutils.is_container(el): span = self.visit(el) if not span: pass elif isinstance(span, (list, tuple)): span_list.extend(span) elif isinstance(span, dict): span_list.extend(span.values()) else: span_list.append(span) return span_list def generic_visit(self, node): # base case: we already have span if not self._full_pass and getattr(node, 'span', None) is not None: return node.span # recurse into children fields span_list = self.container_visit(v for _, v in ast.iter_fields(node)) # also include own span (this can only happen in full_pass) if existing := getattr(node, 'span', None): span_list.append(existing) # merge spans into one node.span = merge_spans(s for s in span_list if s) or self._default return node.span class SpanValidator(ast.NodeVisitor): def generic_visit(self, node): if getattr(node, 'span', None) is None: from edb.edgeql import ast as qlast # some nodes are allowed to not have span, because they are not # always produced by the parser (i.e. ShapeOperation is created as # a default value in the ast node) if not isinstance(node, (qlast.ShapeOperation, qlast.Options)): raise RuntimeError('node {} has no span'.format(node)) super().generic_visit(node) # Finds the node in AST by position within the source. # It returns the first node whose span contains the target offset in a # post-order traversal of the AST. # To be exact, it returns a path from that node to the tree root. # Or None if not found. 
Path is never empty. def find_by_source_position[T: ast.AST]( node: T, target_offset: int ) -> list[T] | None: finder = SpanFinder(target_offset) finder.visit(node) return finder.found_path class SpanFinder(ast.NodeVisitor): target_offset: int found_path: list[Any] | None def __init__(self, target_offset: int): super().__init__() self.target_offset = target_offset self.found_path = None def generic_visit(self, node, *, combine_results=None) -> Any: if self.found_path is not None: return has_span = False if node_span := getattr(node, 'span', None): has_span = True if not span_contains(node_span, self.target_offset): return super().generic_visit(node) if self.found_path is None: if has_span: self.found_path = [node] else: self.found_path.append(node) def span_contains(span: Span, target_offset: int) -> bool: return span.start <= target_offset and target_offset <= span.end ================================================ FILE: edb/common/struct.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2009-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Callable, cast, Final, Iterable, Iterator, Mapping, Optional, Self, ) import collections import enum from . 
import checked class ProtoField: __slots__ = () class NoDefaultT(enum.Enum): NoDefault = 0 NoDefault: Final = NoDefaultT.NoDefault class Field[T](ProtoField): """``Field`` objects: attributes of :class:`Struct`.""" __slots__ = ('name', 'type', 'default', 'coerce', 'formatters', 'frozen') name: str def __init__( self, type_: type[T], default: T | NoDefaultT = NoDefault, *, coerce: bool = False, str_formatter: Callable[[T], str] = str, repr_formatter: Callable[[T], str] = repr, frozen: bool = False, ) -> None: """ :param type_: The type of the value in the field. :param default: Default field value. If not specified, the field is considered required, and failing to specify its value when initializing a ``Struct`` raises :exc:`TypeError`. If ``default`` is the field type itself, it is called with no arguments to produce the default value. :param bool coerce: If ``True``, coerce the field's value to its type. :param bool frozen: If ``True``, the field cannot be reassigned after initialization. """ self.type = type_ self.default = default self.coerce = coerce self.frozen = frozen self.formatters = {'str': str_formatter, 'repr': repr_formatter} def copy(self) -> Field[T]: return self.__class__( self.type, self.default, coerce=self.coerce, str_formatter=self.formatters['str'], repr_formatter=self.formatters['repr'], frozen=self.frozen) def adapt(self, value: Any) -> T: # cast() below due to https://github.com/python/mypy/issues/7920 ctype = cast(type, self.type) if not isinstance(value, ctype): value = ctype(value) # Type ignore below because with ctype we lost information that # it is indeed a Type[T].
return value # type: ignore @property def required(self) -> bool: return self.default is NoDefault class StructMeta(type): _fields: dict[str, Field[Any]] _sorted_fields: dict[str, Field[Any]] def __new__[StructMeta_T: StructMeta]( mcls: type[StructMeta_T], name: str, bases: tuple[type, ...], clsdict: dict[str, Any], *, use_slots: bool = True, **kwargs: Any, ) -> StructMeta_T: fields = {} myfields = {} for k, v in clsdict.items(): if not isinstance(v, ProtoField): continue if not isinstance(v, Field): raise TypeError( f'cannot create {name} class: struct.Field expected, ' f'got {type(v)}') v.name = k myfields[k] = v if '__slots__' not in clsdict: if use_slots is None: for base in bases: sa = '{}.{}_slots'.format(base.__module__, base.__name__) if isinstance(base, StructMeta) and hasattr(base, sa): use_slots = True break if use_slots: clsdict['__slots__'] = tuple(myfields.keys()) for key in myfields.keys(): del clsdict[key] cls = super().__new__(mcls, name, bases, clsdict, **kwargs) if use_slots: sa = '{}.{}_slots'.format(cls.__module__, cls.__name__) setattr(cls, sa, True) for parent in reversed(cls.__mro__): if parent is cls: fields.update(myfields) elif isinstance(parent, StructMeta): fields.update(parent.get_ownfields()) for field in fields.values(): if field.coerce and not issubclass(cls, RTStruct): raise TypeError( f'{cls.__name__}.{field.name} cannot be declared ' f'with coerce=True: {cls.__name__} is not an RTStruct', ) if field.frozen and not issubclass(cls, RTStruct): raise TypeError( f'{cls.__name__}.{field.name} cannot be declared ' f'with frozen=True: {cls.__name__} is not an RTStruct', ) cls._fields = fields cls._sorted_fields = collections.OrderedDict( sorted(fields.items(), key=lambda e: e[0])) fa = '{}.{}_fields'.format(cls.__module__, cls.__name__) setattr(cls, fa, myfields) return cls def get_field(cls, name: str) -> Optional[Field[Any]]: return cls._fields.get(name) def get_fields(cls, sorted: bool = False) -> dict[str, Field[Any]]: return 
cls._sorted_fields if sorted else cls._fields def get_ownfields(cls) -> dict[str, Field[Any]]: return getattr( # type: ignore cls, '{}.{}_fields'.format(cls.__module__, cls.__name__)) class Struct(metaclass=StructMeta): """A base class allowing implementation of attribute object protocols. Each struct has a collection of ``Field`` objects, which should be defined as class attributes of the ``Struct`` subclass. Unlike ``collections.namedtuple``, ``Struct`` is much easier to mix in and define. Furthermore, fields are strictly typed and can be declared as required. By default, Struct will reject attributes that have not been declared as fields. A ``MixedStruct`` subclass does not have this restriction. .. code-block:: pycon >>> from edb.common.struct import Struct, Field >>> class MyStruct(Struct): ... name = Field(str) ... description = Field(str, default=None) ... >>> MyStruct(name='Spam') >>> MyStruct(name='Ham', description='Good Ham') If ``use_slots`` is set to ``True`` in a class signature, ``__slots__`` will be used to create dictless instances, with reduced memory footprint: .. code-block:: pycon >>> class S1(Struct, use_slots=True): ... foo = Field(str, None) >>> class S2(S1): ... bar = Field(str, None) >>> S2().foo = '1' >>> S2().bar = '2' >>> S2().spam = '2' AttributeError: 'S2' object has no attribute 'spam' """ def __init__(self, **kwargs: Any) -> None: """ :raises: TypeError if invalid field value was provided or a value was not provided for a field without a default value.
""" self._check_init_argnames(kwargs) self._init_fields(kwargs) def __setstate__(self, state: Mapping[str, Any]) -> None: if isinstance(state, tuple) and len(state) == 2: state, slotstate = state else: slotstate = None if state: self.update(**state) if slotstate: self.update(**slotstate) def update(self, *args: Any, **kwargs: Any) -> None: """Update the field values.""" values: dict[str, Any] = {} values.update(*args, **kwargs) self._check_init_argnames(values) for k, v in values.items(): setattr(self, k, v) def setdefaults(self) -> list[str]: """Initialize unset fields with default values.""" fields_set = [] for field_name, field in self.__class__._fields.items(): value = getattr(self, field_name) if value is None and field.default is not None: value = self._getdefault(field_name, field) self.set_default_value(field_name, value) fields_set.append(field_name) return fields_set def set_default_value(self, field_name: str, value: Any) -> None: setattr(self, field_name, value) def formatfields( self, formatter: str = 'str', ) -> Iterator[tuple[str, str]]: """Return an iterator over fields formatted using `formatter`.""" for name, field in self.__class__._fields.items(): formatter_obj = field.formatters.get(formatter) if formatter_obj: yield (name, formatter_obj(getattr(self, name))) def _copy_and_replace[Struct_T: Struct]( self, cls: type[Struct_T], **replacements: Any, ) -> Struct_T: args = {f: getattr(self, f) for f in cls._fields.keys()} if replacements: args.update(replacements) return cls(**args) def copy_with_class[Struct_T: Struct]( self, cls: type[Struct_T] ) -> Struct_T: return self._copy_and_replace(cls) def copy(self: Self) -> Self: return self.copy_with_class(type(self)) def replace(self: Self, **replacements: Any) -> Self: return self._copy_and_replace(type(self), **replacements) def items(self) -> Iterator[tuple[str, Any]]: for field in self.__class__._fields: yield field, getattr(self, field, None) def as_tuple(self) -> tuple[Any, ...]: result = [] for 
field in self.__class__._fields: result.append(getattr(self, field, None)) return tuple(result) __copy__ = copy def __iter__(self) -> Iterator[str]: return iter(self.__class__._fields) def __str__(self) -> str: fields = ', '.join( f'{name}={value}' for name, value in self.formatfields('str') ) if fields: fields = f' {fields}' return f'<{self.__class__.__name__}{fields} at {id(self):#x}>' def __repr__(self) -> str: fields = ', '.join( f'{name}={value}' for name, value in self.formatfields('repr') ) if fields: fields = f' {fields}' return f'<{self.__class__.__name__}{fields} at {id(self):#x}>' def _init_fields( self, values: Mapping[str, Any], ) -> None: for field_name, field in self.__class__._fields.items(): value = values.get(field_name) if value is None and field.default is not None: value = self._getdefault(field_name, field) setattr(self, field_name, value) def _check_init_argnames(self, args: Iterable[str]) -> None: extra = set(args) - set(self.__class__._fields) - {'_in_init_'} if extra: fmt = '{} {} invalid argument{} for struct {}.{}' plural = len(extra) > 1 msg = fmt.format( ', '.join(extra), 'are' if plural else 'is an', 's' if plural else '', self.__class__.__module__, self.__class__.__name__) raise TypeError(msg) def _getdefault[T]( self, field_name: str, field: Field[T], ) -> T: ftype = field.type if field.default == ftype: value = field.default() # type: ignore elif field.default is NoDefault: raise TypeError( '%s.%s.%s is required' % ( self.__class__.__module__, self.__class__.__name__, field_name)) else: value = field.default return value # type: ignore def get_field_value(self, field_name: str) -> Any: try: return self.__dict__[field_name] except KeyError as e: field = self.__class__.get_field(field_name) if field is None: raise TypeError( f'{field_name} is not a valid field in this struct') try: return self._getdefault(field_name, field) except TypeError: raise e class RTStruct(Struct): """A variant of Struct with runtime type validation""" 
__slots__ = ('_in_init_',) def __init__(self, **kwargs: Any) -> None: """ :raises: TypeError if invalid field value was provided or a value was not provided for a field without a default value. """ self._check_init_argnames(kwargs) self._in_init_ = True try: self._init_fields(kwargs) finally: self._in_init_ = False def __setstate__(self, state: Mapping[str, Any]) -> None: self._in_init_ = True try: super().__setstate__(state) finally: self._in_init_ = False def __setattr__(self, name: str, value: Any) -> None: field = type(self)._fields.get(name) if field is not None: value = self._check_field_type(field, name, value) if field.frozen and not self._in_init_: raise ValueError(f'cannot assign to frozen field {name!r}') super().__setattr__(name, value) def _check_field_type[T](self, field: Field[T], name: str, value: Any) -> T: if (field.type and value is not None and not isinstance(value, field.type)): if field.coerce: ftype = field.type if issubclass(ftype, (checked.AbstractCheckedList, checked.AbstractCheckedSet)): val_list = [] for v in value: if v is not None and not isinstance(v, ftype.type): v = ftype.type(v) val_list.append(v) value = val_list elif issubclass(ftype, checked.CheckedDict): val_dict = {} for k, v in value.items(): if k is not None and not isinstance(k, ftype.keytype): k = ftype.keytype(k) if (v is not None and not isinstance(v, ftype.valuetype)): v = ftype.valuetype(v) val_dict[k] = v value = val_dict try: return ftype(value) # type: ignore except Exception as ex: raise TypeError( 'cannot coerce {!r} value {!r} ' 'to {}'.format(name, value, ftype)) from ex raise TypeError( '{}.{}.{}: expected {} but got {!r}'.format( self.__class__.__module__, self.__class__.__name__, name, field.type.__name__, value)) return value # type: ignore class MixedStructMeta(StructMeta): def __new__( mcls, name: str, bases: tuple[type, ...], clsdict: dict[str, Any], *, use_slots: bool = False, **kwargs: Any, ) -> MixedStructMeta: return super().__new__( mcls, name, 
bases, clsdict, use_slots=use_slots, **kwargs, ) class MixedStruct(Struct, metaclass=MixedStructMeta): def _check_init_argnames(self, args: Iterable[Any]) -> None: pass class MixedRTStruct(RTStruct, metaclass=MixedStructMeta): def _check_init_argnames(self, args: Iterable[Any]) -> None: pass ================================================ FILE: edb/common/supervisor.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2018-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import Optional import asyncio import itertools class Supervisor: def __init__(self, *, _name, _loop, _private): if _name is None: self._name = f'sup#{_name_counter()}' else: self._name = str(_name) self._loop = _loop self._unfinished_tasks = 0 self._cancelled = False self._tasks = set() self._errors = [] self._base_error = None self._on_completed_fut = None @classmethod async def create(cls, *, name: Optional[str] = None): loop = asyncio.get_running_loop() return cls(_loop=loop, _name=name, _private=True) def __repr__(self): msg = f'= 0 if self._on_completed_fut is not None and not self._unfinished_tasks: if not self._on_completed_fut.done(): self._on_completed_fut.set_result(True) if task.cancelled(): return exc = task.exception() if exc is None: return self._errors.append(exc) if self._is_base_error(exc) and self._base_error is None: self._base_error = exc self._cancel() def _cancel(self): self._cancelled = True for t in self._tasks: if not t.done(): t.cancel() def _is_base_error(self, exc): assert isinstance(exc, BaseException) return not isinstance(exc, Exception) _name_counter = itertools.count(1).__next__ ================================================ FILE: edb/common/term.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2011-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# """A collection of functions and classes to simplify output to terminal.""" from __future__ import annotations from typing import Optional import os import sys import fcntl import termios import struct import functools from edb.common.colorsys import rgb_distance as color_distance from edb.common.colorsys import Color def isatty(fileno): return os.isatty(fileno) _COLORS: Optional[int] = None _colorize = 'auto' def set_colorization_option(option): global _colorize _colorize = option def max_colors(): """Maximum number of colors the current terminal supports. :returns: Integer. For instance, for 'xterm' it is usually 256 .. note:: Uses :mod:`curses` """ global _COLORS if _COLORS is None: try: import curses try: curses.setupterm() _COLORS = curses.tigetnum('colors') except (OSError, curses.error): pass except ImportError: pass if _COLORS is None: _COLORS = 1 return _COLORS def supports_colors(fileno): """Check if ``fileno`` file-descriptor supports colored output. :param int fileno: file-descriptor :returns: bool """ return ( isatty(fileno) and os.getenv('TERM') != 'dumb' and os.getenv('ANSI_COLORS_DISABLED') is None) def size(fileno): """Current terminal height and width (lines and columns). :param int fileno: file-descriptor :returns: Tuple of two integers - lines and columns respectively. ``(None, None)`` if ``fileno`` is not a terminal """ if not isatty(fileno): return None, None try: size = struct.unpack( '2h', fcntl.ioctl(fileno, termios.TIOCGWINSZ, '    ')) except Exception: size = (os.getenv('LINES', 25), os.getenv('COLUMNS', 80)) return size def use_colors(fileno=None): """Check whether to use colored output. Checks ``shell.MainCommand.colorize`` config setting and ``fileno`` for being capable of displaying colors. :param int fileno: File-descriptor.
If ``None``, checks on ``sys.stdout`` :returns bool: Whether you can or can not use color terminal output """ if _colorize == 'on': return True if _colorize == 'off': return False assert _colorize == 'auto' if fileno is None: try: fileno = sys.stdout.fileno() except OSError: return False return supports_colors(fileno) # XTerm 256 colors table. # _MAP256 = { 16: '#000000', 17: '#00005f', 18: '#000087', 19: '#0000af', 20: '#0000d7', 21: '#0000ff', 22: '#005f00', 23: '#005f5f', 24: '#005f87', 25: '#005faf', 26: '#005fd7', 27: '#005fff', 28: '#008700', 29: '#00875f', 30: '#008787', 31: '#0087af', 32: '#0087d7', 33: '#0087ff', 34: '#00af00', 35: '#00af5f', 36: '#00af87', 37: '#00afaf', 38: '#00afd7', 39: '#00afff', 40: '#00d700', 41: '#00d75f', 42: '#00d787', 43: '#00d7af', 44: '#00d7d7', 45: '#00d7ff', 46: '#00ff00', 47: '#00ff5f', 48: '#00ff87', 49: '#00ffaf', 50: '#00ffd7', 51: '#00ffff', 52: '#5f0000', 53: '#5f005f', 54: '#5f0087', 55: '#5f00af', 56: '#5f00d7', 57: '#5f00ff', 58: '#5f5f00', 59: '#5f5f5f', 60: '#5f5f87', 61: '#5f5faf', 62: '#5f5fd7', 63: '#5f5fff', 64: '#5f8700', 65: '#5f875f', 66: '#5f8787', 67: '#5f87af', 68: '#5f87d7', 69: '#5f87ff', 70: '#5faf00', 71: '#5faf5f', 72: '#5faf87', 73: '#5fafaf', 74: '#5fafd7', 75: '#5fafff', 76: '#5fd700', 77: '#5fd75f', 78: '#5fd787', 79: '#5fd7af', 80: '#5fd7d7', 81: '#5fd7ff', 82: '#5fff00', 83: '#5fff5f', 84: '#5fff87', 85: '#5fffaf', 86: '#5fffd7', 87: '#5fffff', 88: '#870000', 89: '#87005f', 90: '#870087', 91: '#8700af', 92: '#8700d7', 93: '#8700ff', 94: '#875f00', 95: '#875f5f', 96: '#875f87', 97: '#875faf', 98: '#875fd7', 99: '#875fff', 100: '#878700', 101: '#87875f', 102: '#878787', 103: '#8787af', 104: '#8787d7', 105: '#8787ff', 106: '#87af00', 107: '#87af5f', 108: '#87af87', 109: '#87afaf', 110: '#87afd7', 111: '#87afff', 112: '#87d700', 113: '#87d75f', 114: '#87d787', 115: '#87d7af', 116: '#87d7d7', 117: '#87d7ff', 118: '#87ff00', 119: '#87ff5f', 120: '#87ff87', 121: '#87ffaf', 122: '#87ffd7', 123: 
'#87ffff', 124: '#af0000', 125: '#af005f', 126: '#af0087', 127: '#af00af', 128: '#af00d7', 129: '#af00ff', 130: '#af5f00', 131: '#af5f5f', 132: '#af5f87', 133: '#af5faf', 134: '#af5fd7', 135: '#af5fff', 136: '#af8700', 137: '#af875f', 138: '#af8787', 139: '#af87af', 140: '#af87d7', 141: '#af87ff', 142: '#afaf00', 143: '#afaf5f', 144: '#afaf87', 145: '#afafaf', 146: '#afafd7', 147: '#afafff', 148: '#afd700', 149: '#afd75f', 150: '#afd787', 151: '#afd7af', 152: '#afd7d7', 153: '#afd7ff', 154: '#afff00', 155: '#afff5f', 156: '#afff87', 157: '#afffaf', 158: '#afffd7', 159: '#afffff', 160: '#d70000', 161: '#d7005f', 162: '#d70087', 163: '#d700af', 164: '#d700d7', 165: '#d700ff', 166: '#d75f00', 167: '#d75f5f', 168: '#d75f87', 169: '#d75faf', 170: '#d75fd7', 171: '#d75fff', 172: '#d78700', 173: '#d7875f', 174: '#d78787', 175: '#d787af', 176: '#d787d7', 177: '#d787ff', 178: '#d7af00', 179: '#d7af5f', 180: '#d7af87', 181: '#d7afaf', 182: '#d7afd7', 183: '#d7afff', 184: '#d7d700', 185: '#d7d75f', 186: '#d7d787', 187: '#d7d7af', 188: '#d7d7d7', 189: '#d7d7ff', 190: '#d7ff00', 191: '#d7ff5f', 192: '#d7ff87', 193: '#d7ffaf', 194: '#d7ffd7', 195: '#d7ffff', 196: '#ff0000', 197: '#ff005f', 198: '#ff0087', 199: '#ff00af', 200: '#ff00d7', 201: '#ff00ff', 202: '#ff5f00', 203: '#ff5f5f', 204: '#ff5f87', 205: '#ff5faf', 206: '#ff5fd7', 207: '#ff5fff', 208: '#ff8700', 209: '#ff875f', 210: '#ff8787', 211: '#ff87af', 212: '#ff87d7', 213: '#ff87ff', 214: '#ffaf00', 215: '#ffaf5f', 216: '#ffaf87', 217: '#ffafaf', 218: '#ffafd7', 219: '#ffafff', 220: '#ffd700', 221: '#ffd75f', 222: '#ffd787', 223: '#ffd7af', 224: '#ffd7d7', 225: '#ffd7ff', 226: '#ffff00', 227: '#ffff5f', 228: '#ffff87', 229: '#ffffaf', 230: '#ffffd7', 231: '#ffffff', 232: '#080808', 233: '#121212', 234: '#1c1c1c', 235: '#262626', 236: '#303030', 237: '#3a3a3a', 238: '#444444', 239: '#4e4e4e', 240: '#585858', 241: '#606060', 242: '#666666', 243: '#767676', 244: '#808080', 245: '#8a8a8a', 246: '#949494', 247: '#9e9e9e', 248: 
'#a8a8a8', 249: '#b2b2b2', 250: '#bcbcbc', 251: '#c6c6c6', 252: '#d0d0d0', 253: '#dadada', 254: '#e4e4e4', 255: '#eeeeee' } def _is_opt_getter(name: str): return lambda self: self._is_opt(name) def _set_opt_setter(name: str): return lambda self, value: self._set_opt(name, value) class AbstractStyle: """Encapsulates information about text-style. For instance, what color should text be, should it be underlined or bold etc. Use instances of :class:`Style16` or :class:`Style256`, this class is abstract. """ __slots__ = ( '_opts', '_color', '_bgcolor', '_term_prefix', '_term_postfix') _opts_table = { 'bold': '1', 'faint': '2', 'italic': '3', 'underline': '4', 'blink': '5', 'overline': '6', 'reverse': '7' } _ropts_table = {v: k for k, v in _opts_table.items()} def __init__( self, *, color=None, bgcolor=None, bold=False, faint=False, italic=False, underline=False, overline=False, reverse=False, ): self._opts = set() self._color = None self._bgcolor = None self.color = color self.bgcolor = bgcolor self.bold = bold self.faint = faint self.italic = italic self.underline = underline self.overline = overline self.reverse = reverse def _filter_color(self, color): raise NotImplementedError def _get_color(self): return self._rcolor_table[self._color] def _set_color(self, color): self._color = self._filter_color(color) self._recalc() color = property(_get_color, _set_color) def _get_bgcolor(self): return self._rcolor_table[self._bgcolor] def _set_bgcolor(self, color): self._bgcolor = self._filter_color(color) self._recalc() bgcolor = property(_get_bgcolor, _set_bgcolor) @property def empty(self): return not bool(self._term_prefix) def _is_opt(self, name: str) -> bool: assert name in self._opts_table return self._opts_table[name] in self._opts def _set_opt(self, name, value): try: tr_name = self._opts_table[name] except KeyError: raise ValueError('unknown style option {!r}'.format(name)) if value: self._opts.add(tr_name) else: if tr_name in self._opts: self._opts.discard(tr_name) 
self._recalc() bold = property(_is_opt_getter('bold'), _set_opt_setter('bold')) faint = property(_is_opt_getter('faint'), _set_opt_setter('faint')) italic = property(_is_opt_getter('italic'), _set_opt_setter('italic')) underline = property( _is_opt_getter('underline'), _set_opt_setter('underline')) blink = property(_is_opt_getter('blink'), _set_opt_setter('blink')) overline = property( _is_opt_getter('overline'), _set_opt_setter('overline')) reverse = property(_is_opt_getter('reverse'), _set_opt_setter('reverse')) def _recalc(self): cmd = [] if self._color is not None: if self._color > 15: cmd.append('38;5;{}'.format(self._color)) else: cmd.append('3{}'.format(self._color)) if self._bgcolor is not None: if self._bgcolor > 15: cmd.append('48;5;{}'.format(self._bgcolor)) else: cmd.append('4{}'.format(self._bgcolor)) cmd.extend(self._opts) if cmd: self._term_prefix = '\x1B[{}m'.format(';'.join(cmd)) self._term_postfix = '\x1B[0m' else: self._term_prefix = '' self._term_postfix = '' def apply(self, str): """Wrap ``str`` in this style's ANSI escape sequences. The result can be printed to a terminal that supports styling. """ return self._term_prefix + str + self._term_postfix class Style16(AbstractStyle): """16-color style.""" _color_table = { 'black': 0, 'red': 1, 'green': 2, 'yellow': 3, 'blue': 4, 'magenta': 5, 'cyan': 6, 'white': 7 } _rcolor_table = {v: k for k, v in _color_table.items()} def _filter_color(self, color): if color is None: return None try: return self._color_table[color] except KeyError as ex: raise ValueError('unknown color {!r}'.format(color)) from ex class Style256(AbstractStyle): """256-color style. Accepts any rgb color in hex format, for instance: .. code-block:: pycon >>> Style256(color='#abcdef') Or by css name: .. code-block:: pycon >>> Style256(color='chocolate') If a color is outside the standard xterm 256-color palette, the closest color in the palette is used.
""" _color_table = {v: k for k, v in _MAP256.items()} _rcolor_table = _MAP256 _rgb_color_table = { Color.from_string(v).rgb_channels(as_floats=True): k for k, v in _MAP256.items() } @staticmethod @functools.lru_cache(500) def _filter_color(color): if color is None: return None try: return Style256._color_table[color] except KeyError: pass c = Color.from_string(color).rgb_channels(as_floats=True) return min( Style256._rgb_color_table.items(), key=lambda item: color_distance(item[0][0], item[0][1], item[0][2], *c))[1] class StylesTable: """Base class for simple style tables.""" def __getattr__(self, key): # If we're querying some non-existing style, pretend it's empty # return Style16() def dump(self): for name, style in self.__class__.__dict__.items(): if isinstance(style, AbstractStyle): print(style.apply(name)) ================================================ FILE: edb/common/token_bucket.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2023-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#

import time


class TokenBucket:
    _capacity: float
    _token_per_sec: float
    _tokens: float
    _last_fill_time: float

    def __init__(self, capacity: float, token_per_sec: float):
        self._capacity = capacity
        self._token_per_sec = token_per_sec
        self._tokens = capacity
        self._last_fill_time = time.monotonic()

    def consume(self, tokens: int) -> float:
        if tokens <= 0:
            return True

        now = time.monotonic()
        tokens_to_add = (now - self._last_fill_time) * self._token_per_sec
        self._tokens = min(self._capacity, self._tokens + tokens_to_add)
        self._last_fill_time = now
        left = self._tokens - tokens
        if left >= 0:
            self._tokens -= tokens
            return 0
        else:
            return -left / (tokens * self._token_per_sec)


================================================
FILE: edb/common/topological.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations
from typing import (
    Any,
    Optional,
    Protocol,
    Iterable,
    Iterator,
    Mapping,
    MutableSet,
    TYPE_CHECKING,
)

from collections import defaultdict

from edb.common.ordered import OrderedSet


class UnresolvedReferenceError(Exception):
    pass


class CycleError(Exception):
    def __init__(
        self,
        msg: str,
        item: Any,
        path: tuple[Any, ...] = (),
    ) -> None:
        super().__init__(msg)
        self.item = item
        self.path = path


class DepGraphEntry[K, V, T]:

    #: The graph node
    item: V
    #: An optional set of dependencies for the graph node as lookup keys.
    deps: MutableSet[K]
    #: An optional set of *weak* dependencies for the graph node as
    #: lookup keys. The difference from regular deps is that weak deps
    #: that cause cycles are ignored. Essentially, weak deps dictate
    #: a _preference_ in order rather than a requirement.
    weak_deps: MutableSet[K]
    merge: Optional[MutableSet[K]]
    loop_control: MutableSet[K]
    extra: Optional[T]

    def __init__(
        self,
        item: V,
        deps: Optional[MutableSet[K]] = None,
        merge: Optional[MutableSet[K]] = None,
        loop_control: Optional[MutableSet[K]] = None,
        extra: Optional[T] = None,
        weak_deps: Optional[MutableSet[K]] = None,
    ) -> None:
        self.item = item
        if deps is None:
            deps = set()
        self.deps = deps
        self.merge = merge
        if loop_control is None:
            loop_control = set()
        self.loop_control = loop_control
        self.extra = extra
        if weak_deps is None:
            weak_deps = set()
        self.weak_deps = weak_deps


def sort_ex[K, V, T](
    graph: Mapping[K, DepGraphEntry[K, V, T]],
    *,
    allow_unresolved: bool = False,
) -> Iterator[tuple[K, DepGraphEntry[K, V, T]]]:

    adj: dict[K, OrderedSet[K]] = defaultdict(OrderedSet)
    weak_adj: dict[K, OrderedSet[K]] = defaultdict(OrderedSet)
    loop_control: dict[K, OrderedSet[K]] = defaultdict(OrderedSet)

    for item_name, item in graph.items():
        if item.weak_deps:
            for dep in item.weak_deps:
                if dep in graph:
                    weak_adj[item_name].add(dep)
                elif not allow_unresolved:
                    raise UnresolvedReferenceError(
                        'reference to an undefined item {} in {}'.format(
                            dep, item_name))

        if item.merge is not None:
            for merge in item.merge:
                if merge in graph:
                    adj[item_name].add(merge)
                elif not allow_unresolved:
                    raise UnresolvedReferenceError(
                        'reference to an undefined item {} in {}'.format(
                            merge, item_name))

        if item.deps:
            for dep in item.deps:
                if dep in graph:
                    adj[item_name].add(dep)
                elif not allow_unresolved:
                    raise UnresolvedReferenceError(
                        'reference to an undefined item {} in {}'.format(
                            dep, item_name))

        if item.loop_control:
            for ctrl in item.loop_control:
                if ctrl in graph:
                    loop_control[item_name].add(ctrl)
                elif not allow_unresolved:
                    raise UnresolvedReferenceError(
                        'reference to an undefined item {} in {}'.format(
                            ctrl, item_name))

    visiting: OrderedSet[K] = OrderedSet()
    visiting_weak: MutableSet[K] = set()
    visited = set()
    order = []

    def visit(
        item: K,
        for_control: bool = False,
        weak_link: bool = False,
    ) -> None:
        if item in visiting:
            # Separate the matching item from the rest of the visiting
            # set for error reporting.
            vis_list = tuple(visiting - {item})
            cycle_item = item if len(vis_list) == 0 else vis_list[-1]
            raise CycleError(
                f"dependency cycle between {cycle_item!r} "
                f"and {item!r}",
                path=vis_list,
                item=item,
            )
        if item not in visited:
            visiting.add(item)
            if weak_link:
                visiting_weak.add(item)

            try:
                for n in weak_adj[item]:
                    try:
                        visit(n, weak_link=True)
                    except CycleError:
                        if len(visiting_weak) == 0:
                            pass
                        else:
                            raise
                for n in adj[item]:
                    visit(n, weak_link=weak_link)
                for n in loop_control[item]:
                    visit(n, weak_link=weak_link, for_control=True)
                if not for_control:
                    order.append(item)
                    visited.add(item)
            except CycleError:
                if len(visiting_weak) == 1:
                    pass
                else:
                    raise
            finally:
                visiting.remove(item)
                if weak_link:
                    visiting_weak.remove(item)

    for key in graph:
        visit(key)

    return ((key, graph[key]) for key in order)


def sort[K, V, T](
    graph: Mapping[K, DepGraphEntry[K, V, T]],
    *,
    allow_unresolved: bool = False,
) -> tuple[V, ...]:
    items = sort_ex(graph, allow_unresolved=allow_unresolved)
    return tuple(i[1].item for i in items)


if TYPE_CHECKING:
    class MergeFunction[V](Protocol):

        def __call__(
            self,
            item: V,
            parent: V,
            **kwargs: Any,
        ) -> V:
            ...


def normalize[K, V, T](
    graph: Mapping[K, DepGraphEntry[K, V, T]],
    merger: MergeFunction[V],
    **merger_kwargs: Any,
) -> Iterable[V]:
    merged: dict[K, V] = {}

    for name, item in sort_ex(graph):
        merge = item.merge
        if merge:
            for m in merge:
                merger(item.item, merged[m], **merger_kwargs)

        merged.setdefault(name, item.item)

    return merged.values()


================================================
FILE: edb/common/traceback.py
================================================
# mypy: disable-error-code="attr-defined"

# Portions copyright 2019-present MagicStack Inc. and the EdgeDB authors.
# Portions copyright 2001-2019 Python Software Foundation.
# License: PSFL.

"""
Provides stack trace formatting that prints `{filename}:{line}`, instead of
`"{filename}", line {line}`.

Stolen from Python's traceback module.
"""

import traceback
import typing

from contextlib import suppress


StackSummaryLike = (
    traceback.StackSummary | list[tuple[str, typing.Any, str, typing.Any]]
)


def format_exception(e: BaseException) -> str:
    exctype = type(e)
    value = e
    tb = e.__traceback__
    tb_e = traceback.TracebackException(
        exctype, value, tb, compact=True
    )
    tb_e.stack = StandardStackSummary(tb_e.stack)
    return '\n'.join(tb_e.format())


def format_stack_summary(stack: StackSummaryLike) -> list[str]:
    return _format_stack_summary(_into_list_of_frames(stack))


class StandardStackSummary(traceback.StackSummary):
    def format(self) -> list[str]:
        return format_stack_summary(self)


def _into_list_of_frames(a_list: StackSummaryLike):
    """
    Create a StackSummary object from a supplied list of
    FrameSummary objects or old-style list of tuples.
    """
    # While doing a fast-path check for isinstance(a_list, StackSummary) is
    # appealing, idlelib.run.cleanup_traceback and other similar code may
    # break this by making arbitrary frames plain tuples, so we need to
    # check on a frame by frame basis.
    result = []
    for frame in a_list:
        if isinstance(frame, traceback.FrameSummary):
            result.append(frame)
        else:
            filename, lineno, name, line = frame
            result.append(
                traceback.FrameSummary(filename, lineno, name, line=line)
            )
    return result


def _format_stack_summary(stack: list[traceback.FrameSummary]):
    """Format the stack ready for printing.

    Returns a list of strings ready for printing. Each string in the
    resulting list corresponds to a single frame from the stack.
    Each string ends in a newline; the strings may contain internal
    newlines as well, for those items with source text lines.

    For long sequences of the same frame and line, the first few
    repetitions are shown, followed by a summary line stating the exact
    number of further repetitions.
    """
    result = []
    last_file = None
    last_line = None
    last_name = None
    count = 0
    for frame_summary in stack:
        formatted_frame = _format_frame_summary(frame_summary)
        if formatted_frame is None:
            continue
        if (
            last_file is None
            or last_file != frame_summary.filename
            or last_line is None
            or last_line != frame_summary.lineno
            or last_name is None
            or last_name != frame_summary.name
        ):
            if count > traceback._RECURSIVE_CUTOFF:
                count -= traceback._RECURSIVE_CUTOFF
                result.append(
                    f' [Previous line repeated {count} more '
                    f'time{"s" if count > 1 else ""}]\n'
                )
            last_file = frame_summary.filename
            last_line = frame_summary.lineno
            last_name = frame_summary.name
            count = 0
        count += 1
        if count > traceback._RECURSIVE_CUTOFF:
            continue
        result.append(formatted_frame)

    if count > traceback._RECURSIVE_CUTOFF:
        count -= traceback._RECURSIVE_CUTOFF
        result.append(
            f' [Previous line repeated {count} more '
            f'time{"s" if count > 1 else ""}]\n'
        )
    return result


def _format_frame_summary(frame: traceback.FrameSummary):
    """Format the lines for a single FrameSummary.

    Returns a string representing one frame involved in the stack. This
    gets called for every frame to be printed in the stack summary.
    """
    row = [f' {frame.filename}:{frame.lineno}, in {frame.name}\n']
    if frame.line:
        stripped_line = frame.line.strip()
        row.append(' {}\n'.format(stripped_line))

        orig_line_len = len(frame._original_line)
        frame_line_len = len(frame.line.lstrip())
        stripped_characters = orig_line_len - frame_line_len
        if frame.colno is not None and frame.end_colno is not None:
            start_offset = (
                traceback._byte_offset_to_character_offset(
                    frame._original_line, frame.colno
                )
                + 1
            )
            end_offset = (
                traceback._byte_offset_to_character_offset(
                    frame._original_line, frame.end_colno
                )
                + 1
            )

            anchors = None
            if frame.lineno == frame.end_lineno:
                with suppress(Exception):
                    anchors = (
                        traceback._extract_caret_anchors_from_line_segment(
                            frame._original_line[
                                start_offset - 1 : end_offset - 1
                            ]
                        )
                    )
            else:
                end_offset = stripped_characters + len(stripped_line)

            # show indicators if primary char doesn't span the frame line
            if end_offset - start_offset < len(stripped_line) or (
                anchors
                and anchors.right_start_offset - anchors.left_end_offset > 0
            ):
                row.append(' ')
                row.append(' ' * (start_offset - stripped_characters))

                if anchors:
                    row.append(anchors.primary_char * (anchors.left_end_offset))
                    row.append(
                        anchors.secondary_char
                        * (anchors.right_start_offset - anchors.left_end_offset)
                    )
                    row.append(
                        anchors.primary_char
                        * (
                            end_offset
                            - start_offset
                            - anchors.right_start_offset
                        )
                    )
                else:
                    row.append('^' * (end_offset - start_offset))

                row.append('\n')

    if frame.locals:
        for name, value in sorted(frame.locals.items()):
            row.append(' {name} = {value}\n'.format(name=name, value=value))

    return ''.join(row)


================================================
FILE: edb/common/turbo_uuid.pyi
================================================
from __future__ import annotations

import uuid


class UUID(uuid.UUID):
    def __init__(self, inp: bytes | str) -> None:
        ...


================================================
FILE: edb/common/typeutils.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2011-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations
from typing import Any, Callable, Optional, Sequence

import collections.abc
import functools


def chain_decorators[TC: Callable](
    funcs: Sequence[Callable[[TC], TC]]
) -> Callable[[TC], TC]:
    def f(func: TC) -> TC:
        for dec in reversed(funcs):
            func = dec(func)
        return func

    return f


def downcast[T](typ: type[T], x: Any) -> T:
    assert isinstance(x, typ)
    return x


def not_none[T](x: Optional[T]) -> T:
    assert x is not None
    return x


@functools.lru_cache(1024)
def _is_container_type(cls):
    return (
        issubclass(cls, (collections.abc.Container)) and
        not issubclass(cls, (str, bytes, bytearray, memoryview)) and
        # not namedtuple, either
        not (issubclass(cls, tuple) and hasattr(cls, '_fields'))
    )


@functools.lru_cache(1024)
def _is_iterable_type(cls):
    return (
        issubclass(cls, collections.abc.Iterable)
    )


def is_container(obj):
    cls = obj.__class__
    return _is_container_type(cls) and _is_iterable_type(cls)


def is_container_type(type_):
    return isinstance(type_, type) and _is_container_type(type_)


================================================
FILE: edb/common/typing_inspect.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2011-present MagicStack Inc. and the EdgeDB authors.
# Portions copyright 2017-2020 Ivan Levkivskyi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

try:
    # this will fail on Python 3.8 because it doesn't have `types.GenericAlias`
    from edb.common._typing_inspect import *  # NoQA
except ImportError:
    from typing_inspect import *  # NoQA


================================================
FILE: edb/common/uuidgen.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

import hashlib
import os
import re
import uuid

from . import turbo_uuid


UUID = turbo_uuid.UUID


UUID_RE_S = r'[a-f0-9]{8}-?[a-f0-9]{4}-?[a-f0-9]{4}-?[a-f0-9]{4}-?[a-f0-9]{12}'
UUID_RE = re.compile(UUID_RE_S, re.I)


def uuid1mc() -> uuid.UUID:
    """Generate a v1 UUID using a pseudo-random multicast node address."""

    # Note: cannot use pgproto.UUID since it's UUID v1
    node = int.from_bytes(os.urandom(6), byteorder='little') | (1 << 40)
    return UUID(uuid.uuid1(node=node).bytes)


# type-ignores below because the first argument to uuid.UUID is a string
# called `hex` which is not something that pgproto.UUID supports.

def uuid4() -> uuid.UUID:
    """Generate a random UUID."""
    return UUID(uuid.uuid4().bytes)


def uuid5_bytes(namespace: uuid.UUID, name: bytes | bytearray) -> uuid.UUID:
    """Generate a UUID from the SHA-1 hash of a namespace UUID and a name."""
    # Do the hashing ourselves because the stdlib version only supports str
    hasher = hashlib.sha1(namespace.bytes)
    hasher.update(name)
    return UUID(uuid.UUID(bytes=hasher.digest()[:16], version=5).bytes)


def uuid5(namespace: uuid.UUID, name: str) -> uuid.UUID:
    """Generate a UUID from the SHA-1 hash of a namespace UUID and a name."""
    return uuid5_bytes(namespace, name.encode("utf-8"))


def from_bytes(data: bytes) -> uuid.UUID:
    return UUID(data)


================================================
FILE: edb/common/value_dispatch.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2021-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations
from typing import Any, Callable, Protocol, Iterable

import functools
import inspect
import types


class _ValueDispatchCallable[_T](Protocol):
    registry: types.MappingProxyType[Any, Callable[..., _T]]

    def register(
        self,
        val: Any,
    ) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
        ...

    def register_for_all(
        self,
        val: Iterable[Any],
    ) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
        ...

    def __call__(__self, *args: Any, **kwargs: Any) -> _T:
        ...


def value_dispatch[_T](func: Callable[..., _T]) -> _ValueDispatchCallable[_T]:
    """Like singledispatch() but dispatches by value of the first arg.

    Example:

      @value_dispatch
      def eat(fruit):
          return f"I don't want a {fruit}..."

      @eat.register('apple')
      def _eat_apple(fruit):
          return "I love apples!"

      @eat.register('eggplant')
      @eat.register('squash')
      def _eat_what(fruit):
          return f"I didn't know {fruit} is a fruit!"

    An alternative to applying multiple `register` decorators is to
    use the `register_for_all` helper:

      @eat.register_for_all({'eggplant', 'squash'})
      def _eat_what(fruit):
          return f"I didn't know {fruit} is a fruit!"
    """

    registry: dict[Any, Callable[..., _T]] = {}

    @functools.wraps(func)
    def wrapper(arg0: Any, *args: Any, **kwargs: Any) -> _T:
        try:
            delegate = registry[arg0]
        except KeyError:
            pass
        else:
            return delegate(arg0, *args, **kwargs)

        return func(arg0, *args, **kwargs)

    def register(
        value: Any,
    ) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
        if inspect.isfunction(value):
            raise TypeError(
                "value_dispatch.register() decorator requires a value")

        def wrap(func: Callable[..., _T]) -> Callable[..., _T]:
            if value in registry:
                raise ValueError(
                    f'@value_dispatch: there is already a handler '
                    f'registered for {value!r}'
                )
            registry[value] = func
            return func
        return wrap

    def register_for_all(
        values: Iterable[Any],
    ) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
        def wrap(func: Callable[..., _T]) -> Callable[..., _T]:
            for value in values:
                if value in registry:
                    raise ValueError(
                        f'@value_dispatch: there is already a handler '
                        f'registered for {value!r}'
                    )
                registry[value] = func
            return func
        return wrap

    wrapper.register = register  # type: ignore [attr-defined]
    wrapper.register_for_all = register_for_all  # type: ignore [attr-defined]
    return wrapper  # type: ignore [return-value]


================================================
FILE: edb/common/verutils.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2020-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations
from typing import Any, NamedTuple

import enum
import re


VERSION_PATTERN = re.compile(r"""
    ^
    (?P<release>[0-9]+(?:\.[0-9]+)*)
    (?P<pre>
        [-\.]?
        (?P<pre_l>(a|b|c|rc|alpha|beta|dev))
        [\.]?
        (?P<pre_n>[0-9]+)?
    )?
    (?:\+(?P<local>[a-z0-9]+(?:[\.][a-z0-9]+)*))?
    $
""", re.X)


class VersionStage(enum.IntEnum):
    DEV = 0
    ALPHA = 10
    BETA = 20
    RC = 30
    FINAL = 40


class Version(NamedTuple):
    major: int
    minor: int
    stage: VersionStage
    stage_no: int
    local: tuple[str, ...]

    def __str__(self):
        ver = f'{self.major}.{self.minor}'
        if self.stage is not VersionStage.FINAL:
            ver += f'-{self.stage.name.lower()}.{self.stage_no}'
        if self.local:
            ver += '+' + '.'.join(self.local)

        return ver


def parse_version(ver: str) -> Version:
    v = VERSION_PATTERN.match(ver)
    if v is None:
        raise ValueError(f'cannot parse version: {ver}')
    local: list[str] = []
    if v.group('pre'):
        pre_l = v.group('pre_l')
        if pre_l in {'a', 'alpha'}:
            stage = VersionStage.ALPHA
        elif pre_l in {'b', 'beta'}:
            stage = VersionStage.BETA
        elif pre_l in {'c', 'rc'}:
            stage = VersionStage.RC
        elif pre_l in {'dev'}:
            stage = VersionStage.DEV
        else:
            raise ValueError(f'cannot determine release stage from {ver}')

        stage_no = int(v.group('pre_n'))
    else:
        stage = VersionStage.FINAL
        stage_no = 0
    if v.group('local'):
        local.extend(v.group('local').split('.'))

    release = [int(r) for r in v.group('release').split('.')]

    return Version(
        major=release[0],
        minor=release[1],
        stage=stage,
        stage_no=stage_no,
        local=tuple(local),
    )


def from_json(data: dict[str, Any]) -> Version:
    return Version(
        data['major'],
        data['minor'],
        VersionStage[data['stage'].upper()],
        data['stage_no'],
        tuple(data['local']),
    )
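

# Illustrative sketch, not part of verutils: how the named groups of a
# pattern shaped like VERSION_PATTERN split a version string. The group
# names are taken from the v.group(...) calls in parse_version();
# `PATTERN` and `split_version` are hypothetical names used only here.

```python
import re

# Same structure as VERSION_PATTERN; re.X lets the pattern span lines.
PATTERN = re.compile(r"""
    ^
    (?P<release>[0-9]+(?:\.[0-9]+)*)
    (?P<pre>
        [-\.]?
        (?P<pre_l>(a|b|c|rc|alpha|beta|dev))
        [\.]?
        (?P<pre_n>[0-9]+)?
    )?
    (?:\+(?P<local>[a-z0-9]+(?:[\.][a-z0-9]+)*))?
    $
""", re.X)


def split_version(ver: str) -> dict:
    """Hypothetical helper: return the raw named groups for *ver*."""
    m = PATTERN.match(ver)
    if m is None:
        raise ValueError(f'cannot parse version: {ver}')
    return m.groupdict()


parts = split_version('6.0-rc.2+abc.def')
final = split_version('5.1')
no_stage_number = split_version('1.0-dev')
```

# In this sketch a string such as '1.0-dev' still matches, but its
# `pre_n` group is None, so code that converts the stage number to an
# int would need to handle that case.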


================================================
FILE: edb/common/view_patterns.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Hacky implementation of "view patterns" with Python match

A "view pattern" is one that does some transformation on the data
being matched before attempting to match it. This can be super useful,
as it allows writing "helper functions" for pattern matching.

We provide a class, ViewPattern, that can be subclassed with custom
`match` methods that perform a transformation on the scrutinee,
returning a transformed value or raising NoMatch if a match is not
possible.

For example, you could write:
  @dataclasses.dataclass
  class IntPair:
      lhs: int
      rhs: int

  class sum_view(ViewPattern[int], targets=(IntPair,)):
      @staticmethod
      def match(obj: object) -> int:
          match obj:
              case IntPair(lhs, rhs):
                  return lhs + rhs
          raise view_patterns.NoMatch

and then write code like:

  match IntPair(lhs=10, rhs=15):
      case sum_view(10):
          print("NO!")
      case sum_view(25):
          print("YES!")

----

To understand how this is implemented, we first discuss how pattern
matching a value `v` against a pattern like `C()` is performed:
 1. isinstance(v, C) is called. If it is False, the match fails
 2. C.__match_args__ is fetched; it should contain a tuple of
    attribute names to be used for positional matching.
 3. In our case, there should be only one attribute in it, `attr`,
    and v.attr is fetched. If fetching v.attr raises AttributeError,
    the match fails.

Our implementation strategy, then, is:
 a. Overload C's isinstance check by implementing `__instancecheck__`
    in a metaclass. Return True if the instance is an instance of
    one of the target classes.
 b. Make C's __match_args__ `('_view_result_',)`
 c. Arrange for `_view_result_` on the matched object to
    call match and return that value. If match raises NoMatch, transform
    it into AttributeError, so that the match fails.

Calling match from the *getter* lets us avoid the need to save the
value somewhere between steps a and c, but requires us to install one
method per view in the scrutinee's class.

Hopefully Python will add __match__ and we can delete all this code!
"""


class NoMatch(Exception):
    pass


class ViewPatternMeta(type):
    def __new__(mcls, name, bases, clsdict, *, targets=(), **kwargs):
        cls = super().__new__(mcls, name, bases, clsdict, **kwargs)

        @property  # type: ignore
        def _view_result_getter(self):
            try:
                return cls.match(self)
            except NoMatch:
                raise AttributeError

        fname = f'_view_result_{cls.__module__}.{cls.__qualname__}'
        mangled = fname.replace("___", "___3_").replace(".", "___")

        cls.__match_args__ = (mangled,)  # type: ignore
        cls._view_result_getter = _view_result_getter
        cls._targets = targets

        # Install the getter onto all target classes
        for target in targets:
            setattr(target, mangled, _view_result_getter)

        return cls

    def __instancecheck__(self, instance):
        return isinstance(instance, self._targets)


class ViewPattern[_T](metaclass=ViewPatternMeta):
    __match_args__ = ('result',)
    result: _T

    @classmethod
    def match(cls, obj: object) -> _T:
        raise NoMatch


================================================
FILE: edb/common/windowedsum.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2020-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

import collections
import time


class WindowedSum:
    """Keeps the sum of incremented values from the last minute.

    The sum is kept with second precision.

    >>> s = WindowedSum()
    >>> s += 1
    >>> s += 1
    >>> time.sleep(30)
    >>> s += 1
    >>> s += 1
    >>> int(s)
    4
    >>> time.sleep(30)
    >>> int(s)
    2
    """

    def __init__(self) -> None:
        self._maxlen = 60
        init: float = 0
        self._buckets = collections.deque([init], maxlen=self._maxlen)
        self._last_shift_at = 0.0

    def __iadd__(self, val: float) -> WindowedSum:
        self.shift()
        self._buckets[-1] += val
        return self

    def __int__(self) -> int:
        self.shift()
        return int(sum(self._buckets))

    def __float__(self) -> float:
        self.shift()
        return float(sum(self._buckets))

    def shift(self) -> None:
        now = time.monotonic()
        shift_by = int(min(now - self._last_shift_at, self._maxlen))
        if shift_by:
            self._buckets.extend(shift_by * [0])
            self._last_shift_at = now
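

# Illustrative sketch, not part of windowedsum: the same bucket-shifting
# idea with an injectable clock, so the one-minute window from the
# docstring can be exercised without sleeping. `FakeClock`,
# `SketchWindowedSum` and the `clock` parameter are assumptions made for
# this example only.

```python
import collections


class FakeClock:
    """Hypothetical stand-in for time.monotonic() that can be advanced."""

    def __init__(self) -> None:
        self.now = 1000.0

    def __call__(self) -> float:
        return self.now


class SketchWindowedSum:
    def __init__(self, clock, maxlen: int = 60) -> None:
        self._clock = clock
        self._maxlen = maxlen
        # One bucket per second; the deque drops the oldest on overflow.
        self._buckets = collections.deque([0.0], maxlen=maxlen)
        self._last_shift_at = 0.0

    def add(self, val: float) -> None:
        self._shift()
        self._buckets[-1] += val

    def total(self) -> float:
        self._shift()
        return sum(self._buckets)

    def _shift(self) -> None:
        now = self._clock()
        shift_by = int(min(now - self._last_shift_at, self._maxlen))
        if shift_by:
            # Appending zeros pushes expired buckets out of the window.
            self._buckets.extend(shift_by * [0.0])
            self._last_shift_at = now


clock = FakeClock()
s = SketchWindowedSum(clock)
s.add(1)
s.add(1)
clock.now += 30
s.add(1)
s.add(1)
total_after_30s = s.total()
clock.now += 30
total_after_60s = s.total()
```

# With the fake clock the totals follow the docstring example: 4 after
# thirty seconds, then 2 once the first two increments leave the window.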


================================================
FILE: edb/common/xdedent.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2011-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Library for building nicely indented output using f-strings.

textwrap.dedent allows removing extra indentation, but it performs
poorly when strings get interpolated in before dedenting, especially
if those strings were produced at a different level of indentation.

The `escape` function escapes a string for interpolation. Notionally,
an interpolated escaped string has all of its leading indentation
stripped, and when it is interpolated in, lines after the first are
indented at the level the interpolated string appears in the output.

Interpolating an escaped `LINE_BLANK` deletes a newline that appears
directly before it. This can be useful when a branch might produce
nothing, but it is interpolated unconditionally.

The `xdedent` function takes a string with interpolated escaped
strings and properly formats it.


The system uses escape delimiters for maintaining a nesting structure
in strings that the user produces. The `xdedent` function then parses
apart the nesting structure and interprets it. Obviously, as with all
schemes for in-band signalling, all hell can break loose if the
signals appear in the input data unescaped.

Our signal sequences contain a null byte and both kinds of quote
character, so you should be fine unless the untrusted data:
 * has null bytes and
 * does not have any kind of quote character escaped in it somehow

"""

from __future__ import annotations


import textwrap
from typing import Any

# Delimiters marking the start and end of an escaped fragment.
_LEFT_ESCAPE = "\0'\"<<{<[<[{{"
_RIGHT_ESCAPE = "\0'\"}}]>]>}>>"
_ESCAPE_LEN = len(_LEFT_ESCAPE)
assert len(_RIGHT_ESCAPE) == _ESCAPE_LEN

LINE_BLANK = _LEFT_ESCAPE[:-1] + "||||||" + _RIGHT_ESCAPE[1:]


def escape(s: str) -> str:
    return _LEFT_ESCAPE + s.strip('\n') + _RIGHT_ESCAPE


Rep = list[str | list[Any]]


def _parse(s: str, start: int) -> tuple[Rep, int]:
    frags: Rep = []
    while start < len(s):
        nleft = s.find(_LEFT_ESCAPE, start)
        nright = s.find(_RIGHT_ESCAPE, start)
        if nleft == nright == -1:
            frags.append(s[start:])
            start = len(s)
        elif nleft != -1 and nleft < nright:
            if nleft > start:
                frags.append(s[start:nleft])
            subfrag, start = _parse(s, nleft + _ESCAPE_LEN)
            # If it is the special magic line blanking fragment,
            # delete up through the last newline. Otherwise collect it.
            if subfrag == [LINE_BLANK] and frags and isinstance(frags[-1], str):
                frags[-1] = frags[-1].rsplit('\n', 1)[0]
            else:
                frags.append(subfrag)
        else:
            assert nright >= 0
            frags.append(s[start:nright])
            start = nright + _ESCAPE_LEN
            break

    return frags, start


def _format_rep(rep: Rep) -> str:
    # cpython does some really dubious things to make appending in place
    # to a string efficient, and we depend on them here
    out_str = ""

    # TODO: I think there ought to be a more complicated algorithm
    # that builds a list of lines + indentation metadata and then
    # fixes it all up in one go?

    for frag in rep:
        if isinstance(frag, str):
            out_str += frag
        else:
            fixed_frag = _format_rep(frag)
            # If there is a newline in the final result, we need to indent
            # it to our current position on the current line.
            if '\n' in fixed_frag:
                last_nl = out_str.rfind('\n')
                indent = (
                    len(out_str) if last_nl < 0
                    else len(out_str) - last_nl - 1
                )
                # Indent all the lines but the first (since that goes
                # onto our current line)
                fixed_frag = textwrap.indent(fixed_frag, ' ' * indent)[indent:]

            out_str += fixed_frag

    return textwrap.dedent(out_str).removesuffix('\n')


def xdedent(s: str) -> str:
    # unlike regular dedent, xdedent trims a leading newline
    s = s.removeprefix('\n')
    parsed, _ = _parse(s, 0)
    res = _format_rep(parsed)
    assert _LEFT_ESCAPE not in res
    return res


================================================
FILE: edb/edgeql/__init__.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

from . import ast  # NOQA
from .tokenizer import Source, NormalizedSource  # NOQA
from .codegen import generate_source  # NOQA
from .parser import parse_fragment, parse_block, parse_query  # NOQA
from .parser.grammar import keywords  # NOQA
from .quote import quote_literal, quote_ident  # NOQA


================================================
FILE: edb/edgeql/ast.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

# Do not import "from typing *"; this module contains
# AST classes that name-clash with classes from the typing module.

import typing

from edb.common import enum as s_enum
from edb.common import ast, span

from . import qltypes

Span = span.Span

DDLCommand_T = typing.TypeVar(
    'DDLCommand_T',
    bound='DDLCommand',
    covariant=True,
)

ObjectDDL_T = typing.TypeVar(
    'ObjectDDL_T',
    bound='ObjectDDL',
    covariant=True,
)


Base_T = typing.TypeVar(
    'Base_T',
    bound='Base',
)


class SortOrder(s_enum.StrEnum):
    Asc = 'ASC'
    Desc = 'DESC'


SortAsc = SortOrder.Asc
SortDesc = SortOrder.Desc
SortDefault = SortAsc


class NonesOrder(s_enum.StrEnum):
    First = 'first'
    Last = 'last'


NonesFirst = NonesOrder.First
NonesLast = NonesOrder.Last


class CardinalityModifier(s_enum.StrEnum):
    Optional = 'OPTIONAL'
    Required = 'REQUIRED'


class DescribeGlobal(s_enum.StrEnum):
    Schema = 'SCHEMA'
    DatabaseConfig = 'DATABASE CONFIG'
    InstanceConfig = 'INSTANCE CONFIG'
    Roles = 'ROLES'

    def to_edgeql(self) -> str:
        return self.value


class Base(ast.AST):
    __abstract_node__ = True
    __ast_hidden__ = {'span', 'system_comment'}

    span: typing.Optional[Span] = None

    # System-generated comment.
    system_comment: typing.Optional[str] = None

    def dump_edgeql(self) -> None:
        from edb.common.debug import dump_edgeql

        dump_edgeql(self)


class GrammarEntryPoint(Base):
    """Mixin denoting nodes that are entry points for EdgeQL grammar"""
    __abstract_node__ = True


class OptionValue(Base):
    """An option value resulting from a syntax."""
    __abstract_node__ = True

    name: str


class OptionFlag(OptionValue):

    val: bool


class Options(Base):

    options: dict[str, OptionValue] = ast.field(factory=dict)

    def get_flag(self, k: str) -> OptionFlag:
        try:
            flag = self[k]
        except KeyError:
            return OptionFlag(name=k, val=False)
        else:
            assert isinstance(flag, OptionFlag)
            return flag

    def __getitem__(self, k: str) -> OptionValue:
        return self.options[k]

    def __iter__(self) -> typing.Iterator[str]:
        return iter(self.options)

    def __len__(self) -> int:
        return len(self.options)


class Expr(GrammarEntryPoint, Base):
    """Abstract parent for all query expressions."""

    __abstract_node__ = True


class Placeholder(Expr):
    """An interpolation placeholder used in expression templates."""

    name: str


class SortExpr(Base):
    path: Expr
    direction: typing.Optional[SortOrder] = None
    nones_order: typing.Optional[NonesOrder] = None


class Alias(Base):
    __abstract_node__ = True


class AliasedExpr(Alias):
    alias: str
    expr: Expr


class ModuleAliasDecl(Alias):
    module: str
    alias: typing.Optional[str]


class GroupingAtom(Base):
    __abstract_node__ = True


class BaseObjectRef(Base):
    __abstract_node__ = True


class ObjectRef(BaseObjectRef, GroupingAtom):
    name: str
    module: typing.Optional[str] = None
    itemclass: typing.Optional[qltypes.SchemaObjectClass] = None


class PseudoObjectRef(BaseObjectRef):
    '''anytype, anytuple or anyobject'''
    name: str


class Anchor(Expr):
    '''Identifier that resolves to some pre-compiled expression.
       For example, in shapes, the anchor __subject__ refers to the object
       that the shape is defined on.
    '''
    __abstract_node__ = True
    name: str


class IRAnchor(Anchor):
    has_dml: bool = False
    # Whether, when the anchor is referenced, to move the entire
    # referenced scope tree of the anchor to wherever it is referenced.
    #
    # This is important when the anchor is being used to substitute an
    # expression in, being used only once, and we want it to behave
    # like it was written at the point it is being substituted.
    #
    # (Sometimes we have anchors that get used repeatedly and which we
    # *want* to have be bound above, basically. I'd like to get rid of
    # all of those uses, though.)
    # (And also the scope tree.)
    move_scope: bool = False


class SpecialAnchor(Anchor):
    pass


class Cursor(Expr):
    '''A special node that halts compilation and returns all names visible in
       the current scope. Used for LSP completions.
    '''


class DetachedExpr(Expr):  # DETACHED Expr
    expr: Expr
    preserve_path_prefix: bool = False


class GlobalExpr(Expr):  # GLOBAL Name
    name: ObjectRef


class Index(Base):
    index: Expr


class Slice(Base):
    start: typing.Optional[Expr]
    stop: typing.Optional[Expr]


class Indirection(Expr):
    arg: Expr
    indirection: list[Index | Slice]


class BinOp(Expr):
    left: Expr
    op: str
    right: Expr

    rebalanced: bool = False
    set_constructor: bool = False


class WindowSpec(Base):
    orderby: list[SortExpr]
    partition: list[Expr]


class FunctionCall(Expr):
    func: tuple[str, str] | str
    args: list[Expr] = ast.field(factory=list)
    kwargs: dict[str, Expr] = ast.field(factory=dict)
    window: typing.Optional[WindowSpec] = None


class StrInterpFragment(Base):
    expr: Expr
    suffix: str


class StrInterp(Expr):
    prefix: str
    interpolations: list[StrInterpFragment]


class BaseConstant(Expr):
    """Constant (a literal value)."""
    __abstract_node__ = True


class Constant(BaseConstant):
    """Constant whose value we can store in a string."""
    kind: ConstantKind
    value: str

    @classmethod
    def string(cls, value: str, span: Span | None = None) -> Constant:
        return Constant(kind=ConstantKind.STRING, value=value, span=span)

    @classmethod
    def boolean(cls, b: bool, span: Span | None = None) -> Constant:
        return Constant(
            kind=ConstantKind.BOOLEAN, value=str(b).lower(), span=span
        )

    @classmethod
    def integer(cls, i: int) -> Constant:
        return Constant(kind=ConstantKind.INTEGER, value=str(i))

    @classmethod
    def float(cls, f: float) -> Constant:
        return Constant(kind=ConstantKind.FLOAT, value=str(f))

    @classmethod
    def make(cls, n: object) -> Constant:
        if isinstance(n, str):
            return cls.string(n)
        elif isinstance(n, bool):
            return cls.boolean(n)
        elif isinstance(n, int):
            return cls.integer(n)
        elif isinstance(n, float):
            return cls.float(n)
        else:
            raise AssertionError('unsupported constant type')


class ConstantKind(s_enum.StrEnum):
    STRING = 'STRING'
    BOOLEAN = 'BOOLEAN'
    INTEGER = 'INTEGER'
    FLOAT = 'FLOAT'
    BIGINT = 'BIGINT'
    DECIMAL = 'DECIMAL'


class BytesConstant(BaseConstant):
    value: bytes

    @classmethod
    def from_python(cls, s: bytes) -> BytesConstant:
        return BytesConstant(value=s)


class QueryParameter(Expr):
    name: str


class FunctionParameter(Expr):
    name: str


class UnaryOp(Expr):
    op: str
    operand: Expr


class TypeExpr(Base):
    __abstract_node__ = True

    name: typing.Optional[str] = None  # name is used for types in named tuples


class TypeOf(TypeExpr):
    expr: Expr


class TypeExprLiteral(TypeExpr):
    # Literal type exprs are used in enum declarations.
    val: Constant


class TypeName(TypeExpr):
    maintype: BaseObjectRef
    subtypes: typing.Optional[list[TypeExpr]] = None
    dimensions: typing.Optional[list[int]] = None


class TypeOpName(s_enum.StrEnum):
    OR = '|'
    AND = '&'


class TypeOp(TypeExpr):
    __rust_box__ = {'left', 'right'}

    left: TypeExpr
    op: TypeOpName
    right: TypeExpr


class FuncParamDecl(Base):
    name: str
    type: TypeExpr
    typemod: qltypes.TypeModifier = qltypes.TypeModifier.SingletonType
    kind: qltypes.ParameterKind
    default: typing.Optional[Expr] = None


class IsOp(Expr):
    left: Expr
    op: str
    right: TypeExpr


class TypeIntersection(Base):
    type: TypeExpr


class Ptr(Base):
    name: str
    direction: typing.Optional[str] = None
    # @ptr has type 'property'
    # .?>ptr has type 'optional'
    type: typing.Optional[typing.Literal['optional', 'property']] = None


class Splat(Base):
    """Represents a splat operation (expansion to all props/links) in shapes"""

    #: Expansion depth
    depth: int
    #: Source type expression, e.g in Type.**
    type: typing.Optional[TypeExpr] = None
    #: Type intersection on the source which would result
    #: in polymorphic expansion, e.g. [is Type].**
    intersection: typing.Optional[TypeIntersection] = None


PathElement = Expr | Ptr | TypeIntersection | ObjectRef | Splat


class Path(Expr, GroupingAtom):
    steps: list[PathElement]
    partial: bool = False
    allow_factoring: bool = False


class TypeCast(Expr):
    expr: Expr
    type: TypeExpr
    cardinality_mod: typing.Optional[CardinalityModifier] = None


class Introspect(Expr):
    type: TypeExpr


class IfElse(Expr):
    condition: Expr
    if_expr: Expr
    else_expr: Expr
    # Just affects pretty-printing
    python_style: bool = False


class TupleElement(Base):
    # This stores the name in another node instead of as a str just so
    # that the name can have a separate source context.
    name: Ptr
    val: Expr


class NamedTuple(Expr):
    elements: list[TupleElement]


class Tuple(Expr):
    elements: list[Expr]


class Array(Expr):
    elements: list[Expr]


class Set(Expr):
    elements: list[Expr]


# Statements
#

class Command(Base):
    """
    A top-level node that is evaluated by our server and
    cannot be a part of a sub expression.
    """

    __abstract_node__ = True
    aliases: typing.Optional[list[Alias]] = None


class Commands(GrammarEntryPoint, Base):
    commands: list[Command]


class SessionSetAliasDecl(Command):
    decl: ModuleAliasDecl


class SessionResetAliasDecl(Command):
    alias: str


class SessionResetModule(Command):
    pass


class SessionResetAllAliases(Command):
    pass


SessionCommand = (
    SessionSetAliasDecl
    | SessionResetAliasDecl
    | SessionResetModule
    | SessionResetAllAliases
)


class ShapeOp(s_enum.StrEnum):
    APPEND = 'APPEND'
    SUBTRACT = 'SUBTRACT'
    ASSIGN = 'ASSIGN'
    MATERIALIZE = 'MATERIALIZE'  # This is an internal implementation artifact


# Need indirection over ShapeOp to preserve the source context.
class ShapeOperation(Base):
    op: ShapeOp


class ShapeOrigin(s_enum.StrEnum):
    EXPLICIT = 'EXPLICIT'
    DEFAULT = 'DEFAULT'
    SPLAT_EXPANSION = 'SPLAT_EXPANSION'
    MATERIALIZATION = 'MATERIALIZATION'


class ShapeElement(Expr):
    expr: Path
    elements: typing.Optional[list[ShapeElement]] = None
    compexpr: typing.Optional[Expr] = None
    cardinality: typing.Optional[qltypes.SchemaCardinality] = None
    required: typing.Optional[bool] = None
    operation: ShapeOperation = ShapeOperation(op=ShapeOp.ASSIGN)
    origin: ShapeOrigin = ShapeOrigin.EXPLICIT

    where: typing.Optional[Expr] = None

    orderby: typing.Optional[list[SortExpr]] = None

    offset: typing.Optional[Expr] = None
    limit: typing.Optional[Expr] = None


class Shape(Expr):
    expr: typing.Optional[Expr]
    elements: list[ShapeElement]
    allow_factoring: bool = False


class Query(Expr, GrammarEntryPoint, Command):
    __abstract_node__ = True

    aliases: typing.Optional[list[Alias]] = None


class SelectQuery(Query):
    result_alias: typing.Optional[str] = None
    result: Expr

    where: typing.Optional[Expr] = None

    orderby: typing.Optional[list[SortExpr]] = None

    offset: typing.Optional[Expr] = None
    limit: typing.Optional[Expr] = None

    # This is a hack, indicating that rptr should be forwarded through
    # this select. Used when we generate implicit selects that need to
    # not interfere with linkprops.
    rptr_passthrough: bool = False

    implicit: bool = False


class GroupingIdentList(GroupingAtom, Base):
    elements: tuple[GroupingAtom, ...]


class GroupingElement(Base):
    __abstract_node__ = True


class GroupingSimple(GroupingElement):
    element: GroupingAtom


class GroupingSets(GroupingElement):
    sets: list[GroupingElement]


class GroupingOperation(GroupingElement):
    oper: str
    elements: list[GroupingAtom]


class GroupQuery(Query):
    subject_alias: typing.Optional[str] = None
    using: typing.Optional[list[AliasedExpr]]
    by: list[GroupingElement]

    subject: Expr


class InternalGroupQuery(Query):
    subject_alias: typing.Optional[str] = None
    using: typing.Optional[list[AliasedExpr]]
    by: list[GroupingElement]

    subject: Expr

    group_alias: str
    grouping_alias: typing.Optional[str]
    from_desugaring: bool = False

    result_alias: typing.Optional[str] = None
    result: Expr

    where: typing.Optional[Expr] = None

    orderby: typing.Optional[list[SortExpr]] = None


class InsertQuery(Query):
    subject: ObjectRef
    shape: list[ShapeElement]
    unless_conflict: typing.Optional[
        tuple[typing.Optional[Expr], typing.Optional[Expr]]
    ] = None


class UpdateQuery(Query):
    shape: list[ShapeElement]

    subject: Expr

    where: typing.Optional[Expr] = None


class DeleteQuery(Query):
    subject: Expr

    where: typing.Optional[Expr] = None

    orderby: typing.Optional[list[SortExpr]] = None

    offset: typing.Optional[Expr] = None
    limit: typing.Optional[Expr] = None


class ForQuery(Query):
    from_desugaring: bool = False
    has_union: bool = True  # whether UNION was used in the syntax

    optional: bool = False
    iterator: Expr
    iterator_alias: str

    result_alias: typing.Optional[str] = None
    result: Expr


# Transactions
#


class Transaction(Base):
    '''Abstract parent for all transaction operations.'''

    __abstract_node__ = True


class StartTransaction(Transaction):
    isolation: typing.Optional[qltypes.TransactionIsolationLevel] = None
    access: typing.Optional[qltypes.TransactionAccessMode] = None
    deferrable: typing.Optional[qltypes.TransactionDeferMode] = None


class CommitTransaction(Transaction):
    pass


class RollbackTransaction(Transaction):
    pass


class DeclareSavepoint(Transaction):

    name: str


class RollbackToSavepoint(Transaction):

    name: str


class ReleaseSavepoint(Transaction):

    name: str


# DDL
#


class DDL(Base):
    '''A mixin denoting DDL nodes.'''
    __abstract_node__ = True


class Position(DDL):
    ref: typing.Optional[ObjectRef] = None
    position: str


class DDLOperation(DDL):
    '''A change to schema'''

    __abstract_node__ = True
    commands: list[DDLOperation] = ast.field(factory=list)


class DDLCommand(DDLOperation, Command):
    __abstract_node__ = True


class DDLQuery(DDLCommand):
    '''A query wrapped in DDL. Appears in migrations.'''
    query: Query


class NonTransactionalDDLCommand(DDLCommand):
    __abstract_node__ = True


class AlterAddInherit(DDLOperation):
    position: typing.Optional[Position] = None
    bases: list[TypeName]


class AlterDropInherit(DDLOperation):
    bases: list[TypeName]


class OnTargetDelete(DDLOperation):
    cascade: typing.Optional[qltypes.LinkTargetDeleteAction]


class OnSourceDelete(DDLOperation):
    cascade: typing.Optional[qltypes.LinkSourceDeleteAction]


class SetField(DDLOperation):
    name: str
    value: Expr | TypeExpr | None
    #: Indicates that this AST originated from a special DDL syntax
    #: rather than from a generic `SET field := value` statement, and
    #: so must not be subject to the "allow_ddl_set" constraint.
    #: This attribute is also considered by the codegen to emit appropriate
    #: syntax.
    special_syntax: bool = False


class SetPointerType(SetField):
    name: str = 'target'
    special_syntax: bool = True
    value: typing.Optional[TypeExpr]
    cast_expr: typing.Optional[Expr] = None


class SetPointerCardinality(SetField):
    name: str = 'cardinality'
    special_syntax: bool = True
    conv_expr: typing.Optional[Expr] = None


class SetPointerOptionality(SetField):
    name: str = 'required'
    special_syntax: bool = True
    fill_expr: typing.Optional[Expr] = None


class ObjectDDL(DDLCommand):
    __abstract_node__ = True

    name: ObjectRef


class CreateObject(ObjectDDL):
    __abstract_node__ = True

    abstract: bool = False
    sdl_alter_if_exists: bool = False
    create_if_not_exists: bool = False


class AlterObject(ObjectDDL):
    __abstract_node__ = True


class DropObject(ObjectDDL):
    __abstract_node__ = True


class CreateExtendingObject(CreateObject):
    __abstract_node__ = True

    # final is not currently implemented, and the syntax is not
    # supported except in old dumps. We track it only to allow us to
    # error on it.
    final: bool = False
    bases: list[TypeName]


class Rename(ObjectDDL):
    new_name: ObjectRef

    @property
    def name(self) -> ObjectRef:  # type: ignore[override]  # mypy bug?
        return self.new_name


class NestedQLBlock(DDL):

    commands: list[DDLOperation]
    text: typing.Optional[str] = None


class MigrationCommand(DDLCommand):

    __abstract_node__ = True


class CreateMigration(CreateObject, MigrationCommand, GrammarEntryPoint):

    body: NestedQLBlock
    parent: typing.Optional[ObjectRef] = None
    metadata_only: bool = False

    # Sometimes the target SDL of a migration can be known in advance.
    # eg. when doing `start migration to`
    target_sdl: typing.Optional[str] = None


class CommittedSchema(DDL):
    pass


class StartMigration(MigrationCommand):

    target: Schema | CommittedSchema


class AbortMigration(MigrationCommand):
    pass


class PopulateMigration(MigrationCommand):
    pass


class AlterCurrentMigrationRejectProposed(MigrationCommand):
    pass


class DescribeCurrentMigration(MigrationCommand):

    language: qltypes.DescribeLanguage


class CommitMigration(MigrationCommand):
    pass


class AlterMigration(AlterObject, MigrationCommand):
    pass


class DropMigration(DropObject, MigrationCommand):
    pass


class ResetSchema(MigrationCommand):

    target: ObjectRef


class StartMigrationRewrite(MigrationCommand):
    pass


class AbortMigrationRewrite(MigrationCommand):
    pass


class CommitMigrationRewrite(MigrationCommand):
    pass


class UnqualifiedObjectCommand(ObjectDDL):

    __abstract_node__ = True


class GlobalObjectCommand(UnqualifiedObjectCommand):

    __abstract_node__ = True


class BranchType(s_enum.StrEnum):
    EMPTY = 'EMPTY'
    SCHEMA = 'SCHEMA'
    DATA = 'DATA'
    TEMPLATE = 'TEMPLATE'


class DatabaseCommand(GlobalObjectCommand, NonTransactionalDDLCommand):

    __abstract_node__ = True
    flavor: qltypes.SchemaObjectClass = qltypes.SchemaObjectClass.BRANCH


class CreateDatabase(CreateObject, DatabaseCommand):

    template: typing.Optional[ObjectRef] = None
    branch_type: BranchType


class AlterDatabase(AlterObject, DatabaseCommand):
    force: bool = False


class DropDatabase(DropObject, DatabaseCommand):
    force: bool = False


class ExtensionPackageCommand(GlobalObjectCommand):

    __abstract_node__ = True
    version: Constant


class CreateExtensionPackage(CreateObject, ExtensionPackageCommand):

    body: NestedQLBlock


class DropExtensionPackage(DropObject, ExtensionPackageCommand):
    pass


class ExtensionPackageMigrationCommand(GlobalObjectCommand):
    __abstract_node__ = True


class CreateExtensionPackageMigration(
    CreateObject, ExtensionPackageMigrationCommand
):
    from_version: Constant
    to_version: Constant
    body: NestedQLBlock


class DropExtensionPackageMigration(
    DropObject, ExtensionPackageMigrationCommand
):
    from_version: Constant
    to_version: Constant


class ExtensionCommand(UnqualifiedObjectCommand):
    __abstract_node__ = True


class CreateExtension(CreateObject, ExtensionCommand):
    version: typing.Optional[Constant] = None


class AlterExtension(DropObject, ExtensionCommand):
    version: typing.Optional[Constant] = None
    to_version: Constant


class DropExtension(DropObject, ExtensionCommand):
    version: typing.Optional[Constant] = None


class FutureCommand(UnqualifiedObjectCommand):

    __abstract_node__ = True


class CreateFuture(CreateObject, FutureCommand):
    pass


class DropFuture(DropObject, FutureCommand):
    pass


class ModuleCommand(UnqualifiedObjectCommand):

    __abstract_node__ = True


class CreateModule(ModuleCommand, CreateObject):
    pass


class AlterModule(ModuleCommand, AlterObject):
    pass


class DropModule(ModuleCommand, DropObject):
    pass


class RoleCommand(GlobalObjectCommand):
    __abstract_node__ = True


class CreateRole(CreateObject, RoleCommand):
    superuser: bool = False
    bases: list[TypeName]


class AlterRole(AlterObject, RoleCommand):
    pass


class DropRole(DropObject, RoleCommand):
    pass


class AnnotationCommand(ObjectDDL):

    __abstract_node__ = True


class CreateAnnotation(CreateExtendingObject, AnnotationCommand):
    type: typing.Optional[TypeExpr]
    inheritable: bool


class AlterAnnotation(AlterObject, AnnotationCommand):
    pass


class DropAnnotation(DropObject, AnnotationCommand):
    pass


class PseudoTypeCommand(ObjectDDL):

    __abstract_node__ = True


class CreatePseudoType(CreateObject, PseudoTypeCommand):
    pass


class ScalarTypeCommand(ObjectDDL):

    __abstract_node__ = True


class CreateScalarType(CreateExtendingObject, ScalarTypeCommand):
    pass


class AlterScalarType(AlterObject, ScalarTypeCommand):
    pass


class DropScalarType(DropObject, ScalarTypeCommand):
    pass


class PropertyCommand(ObjectDDL):

    __abstract_node__ = True


class CreateProperty(CreateExtendingObject, PropertyCommand):
    pass


class AlterProperty(AlterObject, PropertyCommand):
    pass


class DropProperty(DropObject, PropertyCommand):
    pass


class CreateConcretePointer(CreateObject):
    __abstract_node__ = True

    is_required: typing.Optional[bool] = None
    declared_overloaded: bool = False
    target: typing.Optional[Expr | TypeExpr]
    cardinality: qltypes.SchemaCardinality
    bases: list[TypeName]


class CreateConcreteUnknownPointer(CreateConcretePointer):
    pass


class AlterConcreteUnknownPointer(AlterObject, PropertyCommand):
    pass


class CreateConcreteProperty(CreateConcretePointer, PropertyCommand):
    pass


class AlterConcreteProperty(AlterObject, PropertyCommand):
    pass


class DropConcreteProperty(DropObject, PropertyCommand):
    pass


class ObjectTypeCommand(ObjectDDL):

    __abstract_node__ = True


class CreateObjectType(CreateExtendingObject, ObjectTypeCommand):
    pass


class AlterObjectType(AlterObject, ObjectTypeCommand):
    pass


class DropObjectType(DropObject, ObjectTypeCommand):
    pass


class AliasCommand(ObjectDDL):

    __abstract_node__ = True


class CreateAlias(CreateObject, AliasCommand):
    pass


class AlterAlias(AlterObject, AliasCommand):
    pass


class DropAlias(DropObject, AliasCommand):
    pass


class GlobalCommand(ObjectDDL):

    __abstract_node__ = True


class CreateGlobal(CreateObject, GlobalCommand):
    is_required: typing.Optional[bool] = None
    target: typing.Optional[Expr | TypeExpr]
    cardinality: typing.Optional[qltypes.SchemaCardinality]


class AlterGlobal(AlterObject, GlobalCommand):
    pass


class DropGlobal(DropObject, GlobalCommand):
    pass


class SetGlobalType(SetField):
    name: str = 'target'
    special_syntax: bool = True
    value: typing.Optional[TypeExpr]
    cast_expr: typing.Optional[Expr] = None
    reset_value: bool = False


class PermissionCommand(ObjectDDL):

    __abstract_node__ = True


class CreatePermission(CreateObject, PermissionCommand):
    pass


class AlterPermission(AlterObject, PermissionCommand):
    pass


class DropPermission(DropObject, PermissionCommand):
    pass


class LinkCommand(ObjectDDL):

    __abstract_node__ = True


class CreateLink(CreateExtendingObject, LinkCommand):
    pass


class AlterLink(AlterObject, LinkCommand):
    pass


class DropLink(DropObject, LinkCommand):
    pass


class CreateConcreteLink(
    CreateExtendingObject,
    CreateConcretePointer,
    LinkCommand,
):
    pass


class AlterConcreteLink(AlterObject, LinkCommand):
    pass


class DropConcreteLink(DropObject, LinkCommand):
    pass


class ConstraintCommand(ObjectDDL):

    __abstract_node__ = True


class CreateConstraint(
    CreateExtendingObject,
    ConstraintCommand,
):
    subjectexpr: typing.Optional[Expr]
    abstract: bool = True
    params: list[FuncParamDecl] = ast.field(factory=list)


class AlterConstraint(AlterObject, ConstraintCommand):
    pass


class DropConstraint(DropObject, ConstraintCommand):
    pass


class ConcreteConstraintOp(ConstraintCommand):

    __abstract_node__ = True
    args: list[Expr]
    subjectexpr: typing.Optional[Expr]
    except_expr: typing.Optional[Expr] = None


class CreateConcreteConstraint(ConcreteConstraintOp, CreateObject):
    delegated: bool = False


class AlterConcreteConstraint(ConcreteConstraintOp, AlterObject):
    pass


class DropConcreteConstraint(ConcreteConstraintOp, DropObject):
    pass


class IndexType(DDL):
    name: ObjectRef
    args: list[Expr] = ast.field(factory=list)
    kwargs: dict[str, Expr] = ast.field(factory=dict)


class IndexCommand(ObjectDDL):

    __abstract_node__ = True


class IndexCode(DDL):
    language: Language
    code: str


class CreateIndex(
    CreateExtendingObject,
    IndexCommand,
):
    kwargs: dict[str, Expr] = ast.field(factory=dict)
    index_types: list[IndexType]
    code: typing.Optional[IndexCode] = None
    params: list[FuncParamDecl] = ast.field(factory=list)


class AlterIndex(AlterObject, IndexCommand):
    pass


class DropIndex(DropObject, IndexCommand):
    pass


class IndexMatchCommand(ObjectDDL):

    __abstract_node__ = True
    valid_type: TypeName


class CreateIndexMatch(CreateObject, IndexMatchCommand):
    pass
    # XXX: we might want to have a code field to potentially customize the
    # default index code (to account for operator classes and similar custom
    # type-based syntax)


class DropIndexMatch(DropObject, IndexMatchCommand):
    pass


class ConcreteIndexCommand(IndexCommand):

    __abstract_node__ = True
    kwargs: dict[str, Expr] = ast.field(factory=dict)
    expr: Expr
    except_expr: typing.Optional[Expr] = None
    deferred: bool = False


class CreateConcreteIndex(ConcreteIndexCommand, CreateObject):
    pass


class AlterConcreteIndex(ConcreteIndexCommand, AlterObject):
    pass


class DropConcreteIndex(ConcreteIndexCommand, DropObject):
    pass


class CreateAnnotationValue(AnnotationCommand, CreateObject):
    value: Expr


class AlterAnnotationValue(AnnotationCommand, AlterObject):
    value: typing.Optional[Expr]


class DropAnnotationValue(AnnotationCommand, DropObject):
    pass


class AccessPolicyCommand(ObjectDDL):

    __abstract_node__ = True


class CreateAccessPolicy(CreateObject, AccessPolicyCommand):
    condition: typing.Optional[Expr]
    action: qltypes.AccessPolicyAction
    access_kinds: list[qltypes.AccessKind]
    expr: typing.Optional[Expr]


class SetAccessPerms(DDLOperation):
    access_kinds: list[qltypes.AccessKind]
    action: qltypes.AccessPolicyAction


class AlterAccessPolicy(AlterObject, AccessPolicyCommand):
    pass


class DropAccessPolicy(DropObject, AccessPolicyCommand):
    pass


class TriggerCommand(ObjectDDL):

    __abstract_node__ = True


class CreateTrigger(CreateObject, TriggerCommand):
    timing: qltypes.TriggerTiming
    kinds: list[qltypes.TriggerKind]
    scope: qltypes.TriggerScope
    expr: Expr
    condition: typing.Optional[Expr]


class AlterTrigger(AlterObject, TriggerCommand):
    pass


class DropTrigger(DropObject, TriggerCommand):
    pass


class RewriteCommand(ObjectDDL):
    """
    Mutation rewrite command.

    Note that kinds are basically identifiers of the command, so they need to
    be present for all commands.

    List of kinds is converted into multiple commands when creating delta
    commands in `_cmd_tree_from_ast`.
    """

    __abstract_node__ = True

    kinds: list[qltypes.RewriteKind]


class CreateRewrite(CreateObject, RewriteCommand):
    expr: Expr


class AlterRewrite(AlterObject, RewriteCommand):
    pass


class DropRewrite(DropObject, RewriteCommand):
    pass


class Language(s_enum.StrEnum):
    SQL = 'SQL'
    EdgeQL = 'EdgeQL'


class FunctionCode(DDL):
    language: Language = Language.EdgeQL
    code: typing.Optional[str] = None
    nativecode: typing.Optional[Expr] = None
    from_function: typing.Optional[str] = None
    from_expr: bool = False


class FunctionCommand(DDLCommand):

    __abstract_node__ = True
    params: list[FuncParamDecl] = ast.field(factory=list)


class CreateFunction(CreateObject, FunctionCommand):

    returning: TypeExpr
    code: FunctionCode
    nativecode: typing.Optional[Expr]
    returning_typemod: qltypes.TypeModifier = qltypes.TypeModifier.SingletonType


class AlterFunction(AlterObject, FunctionCommand):

    code: FunctionCode = FunctionCode  # type: ignore
    nativecode: typing.Optional[Expr]


class DropFunction(DropObject, FunctionCommand):
    pass


class OperatorCode(DDL):
    language: Language
    from_operator: typing.Optional[tuple[str, ...]]
    from_function: typing.Optional[tuple[str, ...]]
    from_expr: bool
    code: typing.Optional[str]


class OperatorCommand(DDLCommand):

    __abstract_node__ = True
    kind: qltypes.OperatorKind
    params: list[FuncParamDecl] = ast.field(factory=list)


class CreateOperator(CreateObject, OperatorCommand):
    returning: TypeExpr
    returning_typemod: qltypes.TypeModifier = qltypes.TypeModifier.SingletonType
    code: OperatorCode


class AlterOperator(AlterObject, OperatorCommand):
    pass


class DropOperator(DropObject, OperatorCommand):
    pass


class CastCode(DDL):
    language: Language
    from_function: str
    from_expr: bool
    from_cast: bool
    code: str


class CastCommand(ObjectDDL):

    __abstract_node__ = True
    from_type: TypeName
    to_type: TypeName


class CreateCast(CreateObject, CastCommand):
    code: CastCode
    allow_implicit: bool
    allow_assignment: bool


class AlterCast(AlterObject, CastCommand):
    pass


class DropCast(DropObject, CastCommand):
    pass


class OptionalExpr(Expr):
    """Internally used in ELSE clause of IF statement."""

    expr: Expr


#
# Config
#


class ConfigOp(Base):
    __abstract_node__ = True
    name: ObjectRef
    scope: qltypes.ConfigScope


class ConfigSet(ConfigOp):

    expr: Expr


class ConfigInsert(ConfigOp):

    shape: list[ShapeElement]


class ConfigReset(ConfigOp):
    where: typing.Optional[Expr] = None


#
# Describe
#


class DescribeStmt(Command):

    language: qltypes.DescribeLanguage
    object: ObjectRef | DescribeGlobal
    options: Options


#
# Explain
#


class ExplainStmt(Command):

    args: typing.Optional[NamedTuple]
    query: Query


#
# Administer
#


class AdministerStmt(Command):

    expr: FunctionCall


#
# SDL
#


class SDL(Base):
    '''A mixin denoting SDL nodes.'''

    __abstract_node__ = True


class ModuleDeclaration(SDL):
    # The 'name' is treated the same as in CreateModule, for consistency,
    # since this declaration also implies creating a module.
    name: ObjectRef
    declarations: list[ObjectDDL | ModuleDeclaration]


class Schema(SDL, GrammarEntryPoint, Base):
    declarations: list[ObjectDDL | ModuleDeclaration]


#
# These utility functions work on EdgeQL AST nodes
#


def get_ddl_field_command(
    ddlcmd: DDLOperation,
    name: str,
) -> typing.Optional[SetField]:
    """Return the first SetField subcommand named *name*, or None."""
    for cmd in ddlcmd.commands:
        if isinstance(cmd, SetField) and cmd.name == name:
            return cmd

    return None


def get_ddl_field_value(
    ddlcmd: DDLOperation,
    name: str,
) -> Expr | TypeExpr | None:
    """Return the value of the SetField named *name*, or None if absent."""
    cmd = get_ddl_field_command(ddlcmd, name)
    return cmd.value if cmd is not None else None


def get_ddl_subcommand(
    ddlcmd: DDLOperation,
    cmdtype: type[DDLOperation],
) -> typing.Optional[DDLOperation]:
    """Return the first subcommand that is an instance of *cmdtype*."""
    for cmd in ddlcmd.commands:
        if isinstance(cmd, cmdtype):
            return cmd

    return None


def has_ddl_subcommand(
    ddlcmd: DDLOperation,
    cmdtype: type[DDLOperation],
) -> bool:
    """Return True if *ddlcmd* has a subcommand of type *cmdtype*."""
    return bool(get_ddl_subcommand(ddlcmd, cmdtype))
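

# Illustrative sketch, not part of the upstream API: the helpers above all do
# a first-match linear scan over `ddlcmd.commands`.  `_DemoSetField` and
# `_demo_first_field` are hypothetical stand-ins so the pattern can be tried
# without building real AST nodes.
class _DemoSetField:
    def __init__(self, name: str, value: object) -> None:
        self.name = name
        self.value = value


def _demo_first_field(
    commands: list[object],
    name: str,
) -> '_DemoSetField | None':
    # The first matching entry wins; later duplicates are never inspected.
    for cmd in commands:
        if isinstance(cmd, _DemoSetField) and cmd.name == name:
            return cmd

    return None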


ReturningQuery = SelectQuery | ForQuery | InternalGroupQuery


FilteringQuery = (
    SelectQuery | DeleteQuery | ShapeElement | UpdateQuery | ConfigReset
)


SubjectQuery = DeleteQuery | UpdateQuery | GroupQuery


OffsetLimitQuery = SelectQuery | DeleteQuery | ShapeElement


BasedOn = (
    AlterAddInherit
    | AlterDropInherit
    | CreateExtendingObject
    | CreateRole
    | CreateConcretePointer
)

CallableObjectCommand = (
    CreateConstraint | CreateIndex | FunctionCommand | OperatorCommand
)

# A node that can have a WITH block
Statement = Query | Command


================================================
FILE: edb/edgeql/codegen.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations
from typing import (
    Any,
    Callable,
    Optional,
    TypeVar,
    AbstractSet,
    Sequence,
    Match,
    TYPE_CHECKING,
)

import itertools
import re

from edb import errors
from edb.common.ast import codegen, base
from edb.common import typeutils

from . import ast as qlast
from . import quote as edgeql_quote
from . import qltypes


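# Bytes that need escaping: backslash, single quote, control and 0x7e-0xff.
# The backslash is doubled in the pattern so that the regex character class
# contains a literal backslash (matching the b'\\' entry in _ESCAPES).
_BYTES_ESCAPE_RE = re.compile(b'[\\\\\'\x00-\x1f\x7e-\xff]')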
_NON_PRINTABLE_RE = re.compile(
    r'[\u0000-\u0008\u000B\u000C\u000E-\u001F\u007F\u0080-\u009F\n]')
_ESCAPES = {
    b'\\': b'\\\\',
    b'\'': b'\\\'',
    b'\t': b'\\t',
    b'\n': b'\\n',
}


if TYPE_CHECKING:
    import enum
    Enum_T = TypeVar('Enum_T', bound=enum.Enum)


def _bytes_escape(match: Match[bytes]) -> bytes:
    # Use the short escape when one is defined, else fall back to \xNN.
    char = match.group(0)
    try:
        return _ESCAPES[char]
    except KeyError:
        return b'\\x%02x' % char[0]


def param_to_str(ident: str) -> str:
    """Render a parameter reference: '$' followed by the quoted name."""
    return '$' + edgeql_quote.quote_ident(
        ident, allow_reserved=True, allow_num=True)


def ident_to_str(ident: str, allow_num: bool = False) -> str:
    """Quote each '::'-separated part of *ident* and rejoin with '::'."""
    return '::'.join([
        edgeql_quote.quote_ident(part, allow_num=allow_num)
        for part in ident.split('::')
    ])


class EdgeQLSourceGeneratorError(errors.InternalServerError):
    pass


class EdgeSchemaSourceGeneratorError(errors.InternalServerError):
    pass
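

# Illustrative sketch, not used by the generator itself: the class below
# picks a handler by the node's exact class name (`visit_<ClassName>`) and
# falls back to `generic_visit`, which raises.  Every name inside
# `_demo_visitor_dispatch` is a hypothetical stand-in, not an EdgeDB type.
def _demo_visitor_dispatch(uppercase: bool = False) -> list[str]:
    class Select:
        pass

    class Unknown:
        pass

    class Generator:
        def __init__(self) -> None:
            self.out: list[str] = []

        def visit(self, node: object) -> None:
            # Same lookup shape as EdgeQLSourceGenerator.visit.
            method = 'visit_' + node.__class__.__name__
            getattr(self, method, self.generic_visit)(node)

        def generic_visit(self, node: object) -> None:
            raise LookupError(
                f'No method to generate code for {node.__class__.__name__}')

        def visit_Select(self, node: object) -> None:
            # Mirrors _kw_case: keywords follow the `uppercase` switch.
            self.out.append('SELECT' if uppercase else 'select')

    gen = Generator()
    gen.visit(Select())
    try:
        gen.visit(Unknown())
    except LookupError as exc:
        gen.out.append(str(exc))
    return gen.out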


class EdgeQLSourceGenerator(codegen.SourceGenerator):

    def __init__(
        self,
        *args: Any,
        sdlmode: bool = False,
        descmode: bool = False,
        # Uppercase keywords for backwards compatibility with older migrations.
        uppercase: bool = False,
        unsorted: bool = False,
        limit_ref_classes:
            Optional[AbstractSet[qltypes.SchemaObjectClass]] = None,
        **kwargs: Any
    ) -> None:
        super().__init__(*args, **kwargs)
        self.sdlmode = sdlmode
        self.descmode = descmode
        self.uppercase = uppercase
        self.unsorted = unsorted
        self.limit_ref_classes = limit_ref_classes

    def visit(
        self, node: qlast.Base | list[qlast.Base], **kwargs: Any
    ) -> None:
        if isinstance(node, list):
            self.visit_list(node, terminator=';')
        else:
            method = 'visit_' + node.__class__.__name__
            visitor = getattr(self, method, self.generic_visit)
            visitor(node, **kwargs)

    def _kw_case(self, *kws: str) -> str:
        kwstring = ' '.join(kws)
        if self.uppercase:
            kwstring = kwstring.upper()
        else:
            kwstring = kwstring.lower()
        return kwstring

    def _write_keywords(self, *kws: str) -> None:
        self.write(self._kw_case(*kws))

    def _needs_parentheses(self, node: Any) -> bool:
        # The "parent" attribute is set by calling `_fix_parent_links`
        # before traversing the AST.  Since it's not an attribute that
        # can be inferred by static typing we ignore typing for this
        # function.
        parent: Optional[qlast.Base] = node._parent
        return (
            parent is not None
            and not isinstance(parent, (qlast.Commands, qlast.DDL))
            # Non-union FOR bodies can't have parens
            and not (
                isinstance(parent, qlast.ForQuery)
                and not parent.has_union
                and parent.result is node
            )
        )

    def generic_visit(
        self, node: qlast.Base, *args: Any, **kwargs: Any
    ) -> None:
        # SDL and non-SDL nodes are currently reported the same way.
        raise EdgeQLSourceGeneratorError(
            f'No method to generate code for {node.__class__.__name__}'
        )

    def _block_ws(self, change: int, newlines: bool = True) -> None:
        """Indent by `change` and break the line, or emit a space instead."""
        if newlines:
            self.indentation += change
            self.new_lines = 1
        else:
            self.write(' ')

    def _visit_aliases(self, node: qlast.Statement) -> None:
        if node.aliases:
            self._write_keywords('WITH')
            self._block_ws(1)
            self.visit_list(node.aliases)
            self._block_ws(-1)

    def _visit_filter(
        self, node: qlast.FilteringQuery, newlines: bool = True
    ) -> None:
        if node.where:
            self._write_keywords('FILTER')
            self._block_ws(1, newlines)
            self.visit(node.where)
            self._block_ws(-1, newlines)

    def _visit_order(
        self,
        node: qlast.SelectQuery | qlast.DeleteQuery,
        newlines: bool = True,
    ) -> None:
        if node.orderby:
            self._write_keywords('ORDER BY')
            self._block_ws(1, newlines)
            self.visit_list(
                node.orderby,
                separator=self._kw_case(' THEN'), newlines=newlines
            )
            self._block_ws(-1, newlines)

    def _visit_offset_limit(
        self, node: qlast.OffsetLimitQuery, newlines: bool = True
    ) -> None:
        if node.offset is not None:
            self._write_keywords('OFFSET')
            self._block_ws(1, newlines)
            self.visit(node.offset)
            self._block_ws(-1, newlines)
        if node.limit is not None:
            self._write_keywords('LIMIT')
            self._block_ws(1, newlines)
            self.visit(node.limit)
            self._block_ws(-1, newlines)

    def visit_Commands(self, node: qlast.Commands) -> None:
        self.visit_list(node.commands, separator=';', terminator=';')

    def visit_AliasedExpr(self, node: qlast.AliasedExpr) -> None:
        self.write(ident_to_str(node.alias))
        self.write(' := ')
        self._block_ws(1)

        self.visit(node.expr)

        self._block_ws(-1)

    def visit_InsertQuery(self, node: qlast.InsertQuery) -> None:
        # need to parenthesise when INSERT appears as an expression
        parenthesise = self._needs_parentheses(node)

        if parenthesise:
            self.write('(')
        self._visit_aliases(node)
        self._write_keywords('INSERT')
        self._block_ws(1)
        self.visit(node.subject)
        self._block_ws(-1)

        if node.shape:
            self.indentation += 1
            self._visit_shape(node.shape)
            self.indentation -= 1

        if node.unless_conflict:
            on_expr, else_expr = node.unless_conflict
            self._write_keywords('UNLESS CONFLICT')

            if on_expr:
                self._write_keywords(' ON ')
                self.visit(on_expr)

                if else_expr:
                    self._write_keywords(' ELSE ')
                    self.visit(else_expr)

        if parenthesise:
            self.write(')')

    def visit_UpdateQuery(self, node: qlast.UpdateQuery) -> None:
        # need to parenthesise when UPDATE appears as an expression
        parenthesise = self._needs_parentheses(node)

        if parenthesise:
            self.write('(')
        self._visit_aliases(node)
        self._write_keywords('UPDATE')
        self._block_ws(1)
        self.visit(node.subject)
        self._block_ws(-1)

        self._visit_filter(node)

        self.new_lines = 1
        self._write_keywords('SET ')
        self._visit_shape(node.shape, always_emit_braces=True)

        if parenthesise:
            self.write(')')

    def visit_DeleteQuery(self, node: qlast.DeleteQuery) -> None:
        # need to parenthesise when DELETE appears as an expression
        parenthesise = self._needs_parentheses(node)

        if parenthesise:
            self.write('(')

        self._visit_aliases(node)

        self._write_keywords('DELETE')
        self._block_ws(1)
        self.visit(node.subject)
        self._block_ws(-1)
        self._visit_filter(node)
        self._visit_order(node)
        self._visit_offset_limit(node)
        if parenthesise:
            self.write(')')

    def visit_SelectQuery(self, node: qlast.SelectQuery) -> None:
        # XXX: need to parenthesise when SELECT appears as an expression,
        # the actual passed value is ignored.
        parenthesise = self._needs_parentheses(node)
        if node.implicit:
            parenthesise = parenthesise and bool(node.aliases)

        if parenthesise:
            self.write('(')

        if not node.implicit or node.aliases:
            self._visit_aliases(node)
            self._write_keywords('SELECT')
            self._block_ws(1)

        if node.result_alias:
            self.write(node.result_alias, ' := ')
        self.visit(node.result)
        if not node.implicit or node.aliases:
            self._block_ws(-1)
        else:
            self.write(' ')
        self._visit_filter(node)
        self._visit_order(node)
        self._visit_offset_limit(node)
        if parenthesise:
            self.write(')')

    def visit_ForQuery(self, node: qlast.ForQuery) -> None:
        # need to parenthesize when FOR appears as an expression
        parenthesise = self._needs_parentheses(node)

        if parenthesise:
            self.write('(')

        self._visit_aliases(node)

        self._write_keywords('FOR ')
        self.write(ident_to_str(node.iterator_alias))
        self._write_keywords(' IN ')
        self.visit(node.iterator)
        # guarantee a newline here
        self.new_lines = 1
        if node.has_union:
            self._write_keywords('UNION ')
            self._block_ws(1)
            self.visit(node.result)
            self.indentation -= 1
        else:
            self.visit(node.result)

        if parenthesise:
            self.write(')')

    def visit_GroupingIdentList(self, atom: qlast.GroupingIdentList) -> None:
        self.write('(')
        self.visit_list(atom.elements, newlines=False)
        self.write(')')

    def visit_GroupingSimple(self, node: qlast.GroupingSimple) -> None:
        self.visit(node.element)

    def visit_GroupingSets(self, node: qlast.GroupingSets) -> None:
        self.write('{')
        self.visit_list(node.sets, newlines=False)
        self.write('}')

    def visit_GroupingOperation(self, node: qlast.GroupingOperation) -> None:
        self._write_keywords(node.oper)
        self.write(' (')
        self.visit_list(node.elements, newlines=False)
        self.write(')')

    def visit_GroupQuery(
        self,
        node: qlast.GroupQuery | qlast.InternalGroupQuery,
        no_paren: bool = False
    ) -> None:
        # need to parenthesise when GROUP appears as an expression
        parenthesise = self._needs_parentheses(node) and not no_paren

        if parenthesise:
            self.write('(')

        self._visit_aliases(node)

        if isinstance(node, qlast.InternalGroupQuery):
            self._write_keywords('FOR ')
        self._write_keywords('GROUP')
        self._block_ws(1)
        if node.subject_alias:
            self.write(ident_to_str(node.subject_alias), ' := ')
        self.visit(node.subject)
        self._block_ws(-1)
        if node.using:
            self._write_keywords('USING')
            self._block_ws(1)
            self.visit_list(node.using, newlines=False)
            self._block_ws(-1)
        self._write_keywords('BY ')
        self.visit_list(node.by)

        if parenthesise:
            self.write(')')

    def visit_InternalGroupQuery(self, node: qlast.InternalGroupQuery) -> None:
        parenthesise = self._needs_parentheses(node)
        if parenthesise:
            self.write('(')

        self.visit_GroupQuery(node, no_paren=True)
        self._block_ws(0)
        self._write_keywords('IN ')
        self.write(ident_to_str(node.group_alias))
        if node.grouping_alias:
            self.write(', ')
            self.write(ident_to_str(node.grouping_alias))
        self.write(' ')
        self._block_ws(0)
        self._write_keywords('UNION ')
        self.visit(node.result)

        if node.where:
            self._write_keywords(' FILTER ')
            self.visit(node.where)

        if node.orderby:
            self._write_keywords(' ORDER BY ')
            self.visit_list(
                node.orderby,
                separator=self._kw_case(' THEN'), newlines=False
            )

        if parenthesise:
            self.write(')')

    def visit_ModuleAliasDecl(self, node: qlast.ModuleAliasDecl) -> None:
        if node.alias:
            self.write(ident_to_str(node.alias))
            self._write_keywords(' AS ')
        self._write_keywords('MODULE ')
        self.write(ident_to_str(node.module))

    def visit_SortExpr(self, node: qlast.SortExpr) -> None:
        self.visit(node.path)
        if node.direction:
            self.write(' ')
            self.write(node.direction)
        if node.nones_order:
            self._write_keywords(' EMPTY ')
            self.write(node.nones_order.upper())

    def visit_DetachedExpr(self, node: qlast.DetachedExpr) -> None:
        self._write_keywords('DETACHED ')
        self.visit(node.expr)

    def visit_GlobalExpr(self, node: qlast.GlobalExpr) -> None:
        self._write_keywords('GLOBAL ')
        self.visit(node.name)

    def visit_StrInterp(self, node: qlast.StrInterp) -> None:
        self.write("'")
        self.write(edgeql_quote.escape_string(node.prefix))
        for fragment in node.interpolations:
            self.write("\\(")
            self.visit(fragment.expr)
            self.write(")")
            self.write(edgeql_quote.escape_string(fragment.suffix))
        self.write("'")

    def visit_UnaryOp(self, node: qlast.UnaryOp) -> None:
        op = str(node.op).upper()
        self.write(op)
        if op.isalnum():
            self.write(' (')
        self.visit(node.operand)
        if op.isalnum():
            self.write(')')

    def visit_BinOp(self, node: qlast.BinOp) -> None:
        self.write('(')
        self.visit(node.left)
        self.write(' ' + str(node.op).upper() + ' ')
        self.visit(node.right)
        self.write(')')

    def visit_IsOp(self, node: qlast.IsOp) -> None:
        self.write('(')
        self.visit(node.left)
        self.write(' ' + str(node.op).upper() + ' ')
        self.visit(node.right)
        self.write(')')

    def visit_TypeOp(self, node: qlast.TypeOp) -> None:
        self.write('(')
        self.visit(node.left)
        self.write(' ' + str(node.op).upper() + ' ')
        self.visit(node.right)
        self.write(')')

    def visit_IfElse(self, node: qlast.IfElse) -> None:
        parent = node._parent  # type: ignore
        parenthesize = not (
            isinstance(parent, qlast.SelectQuery)
            and parent.implicit
            and isinstance(parent._parent, qlast.Commands)  # type: ignore
        )
        if parenthesize:
            self.write('(')
        if node.python_style:
            self.visit(node.if_expr)
            self._write_keywords(' IF ')
            self.visit(node.condition)
            self._write_keywords(' ELSE ')
            self.visit(node.else_expr)
        else:
            self._write_keywords('IF ')
            self.visit(node.condition)
            self._write_keywords(' THEN ')
            self.visit(node.if_expr)
            self._write_keywords(' ELSE ')
            self.visit(node.else_expr)
        if parenthesize:
            self.write(')')

    def visit_Tuple(self, node: qlast.Tuple) -> None:
        self.write('(')
        count = len(node.elements)
        self.visit_list(node.elements, newlines=False)
        if count == 1:
            self.write(',')

        self.write(')')

    def visit_Set(self, node: qlast.Set) -> None:
        self.write('{')
        self.visit_list(node.elements, newlines=False)
        self.write('}')

    def visit_Array(self, node: qlast.Array) -> None:
        self.write('[')
        self.visit_list(node.elements, newlines=False)
        self.write(']')

    def visit_NamedTuple(self, node: qlast.NamedTuple) -> None:
        self.write('(')
        self._block_ws(1)
        self.visit_list(node.elements, newlines=True, separator=',')
        self._block_ws(-1)
        self.write(')')

    def visit_TupleElement(self, node: qlast.TupleElement) -> None:
        self.visit(node.name)
        self.write(' := ')
        self.visit(node.val)

    def visit_Path(self, node: qlast.Path) -> None:
        for i, e in enumerate(node.steps):
            if i > 0 or node.partial:
                if (getattr(e, 'type', None) != 'property'
                        and not isinstance(e, qlast.TypeIntersection)):
                    self.write('.')

            if i == 0:
                if isinstance(
                    e,
                    (
                        qlast.ObjectRef,
                        qlast.Anchor,
                        qlast.Splat,
                        qlast.Ptr,
                        qlast.Set,
                        qlast.Tuple,
                        qlast.NamedTuple,
                        qlast.TypeIntersection,
                        qlast.QueryParameter,
                        qlast.FunctionParameter,
                    ),
                ):
                    self.visit(e)
                else:
                    self.write('(')
                    self.visit(e)
                    self.write(')')
            else:
                self.visit(e)

    def visit_Shape(self, node: qlast.Shape) -> None:
        if node.expr is not None:
            self.visit(node.expr)
            self.write(' ')
        self._visit_shape(node.elements)

    def _visit_shape(
        self,
        shape: Sequence[qlast.ShapeElement],
        always_emit_braces: bool = False,
    ) -> None:
        if shape or always_emit_braces:
            self.write('{')
            self._block_ws(1)
            self.visit_list(shape)
            self._block_ws(-1)
            self.write('}')

    def visit_Ptr(self, node: qlast.Ptr, *, quote: bool = True) -> None:
        if node.type == 'property':
            self.write('@')
        elif node.type == 'optional':
            self.write('?>')
        elif node.direction and node.direction != '>':
            self.write(node.direction)

        self.write(ident_to_str(node.name, allow_num=True))

    def visit_Splat(self, node: qlast.Splat) -> None:
        if node.type is not None:
            self.visit(node.type)
        if node.intersection is not None:
            self.visit(node.intersection)
        if node.type is not None or node.intersection is not None:
            self.write('.')
        if node.depth == 1:
            self.write('*')
        elif node.depth == 2:
            self.write('**')
        else:
            raise AssertionError(f"unexpected splat depth: {node.depth}")

    def visit_TypeIntersection(self, node: qlast.TypeIntersection) -> None:
        self._write_keywords('[IS ')
        self.visit(node.type)
        self.write(']')

    def visit_ShapeElement(self, node: qlast.ShapeElement) -> None:
        # PathSpec can only contain LinkExpr or LinkPropExpr,
        # and must not be quoted.

        quals = []
        if node.required is not None:
            if node.required:
                quals.append('required')
            else:
                quals.append('optional')

        if node.cardinality:
            quals.append(node.cardinality.as_ptr_qual())

        if quals:
            self.write(*quals, delimiter=' ')
            self.write(' ')

        if len(node.expr.steps) == 1:
            self.visit(node.expr)
        else:
            self.visit(node.expr.steps[0])
            if not isinstance(node.expr.steps[1], qlast.TypeIntersection):
                self.write('.')
            self.visit(node.expr.steps[1])
            if len(node.expr.steps) == 3:
                self.visit(node.expr.steps[2])

        if not node.compexpr and node.elements:
            self.write(': ')
            self._visit_shape(node.elements)

        if node.where:
            self._write_keywords(' FILTER ')
            self.visit(node.where)

        if node.orderby:
            self._write_keywords(' ORDER BY ')
            self.visit_list(
                node.orderby,
                separator=self._kw_case(' THEN'), newlines=False
            )

        if node.offset:
            self._write_keywords(' OFFSET ')
            self.visit(node.offset)

        if node.limit:
            self._write_keywords(' LIMIT ')
            self.visit(node.limit)

        if node.compexpr:
            if node.operation is None:
                raise AssertionError(
                    'ShapeElement.operation is unexpectedly None'
                )

            if node.operation.op is qlast.ShapeOp.ASSIGN:
                self.write(' := ')
            elif node.operation.op is qlast.ShapeOp.APPEND:
                self.write(' += ')
            elif node.operation.op is qlast.ShapeOp.SUBTRACT:
                self.write(' -= ')
            else:
                raise NotImplementedError(
                    f'unexpected shape operation: {node.operation.op!r}'
                )
            self.visit(node.compexpr)

    def visit_QueryParameter(self, node: qlast.QueryParameter) -> None:
        self.write(param_to_str(node.name))

    def visit_FunctionParameter(self, node: qlast.FunctionParameter) -> None:
        self.write(param_to_str(node.name))

    def visit_Placeholder(self, node: qlast.Placeholder) -> None:
        self.write('\\(')
        self.write(node.name)
        self.write(')')

    def visit_Constant(self, node: qlast.Constant) -> None:
        if node.kind == qlast.ConstantKind.STRING:
            if not _NON_PRINTABLE_RE.search(node.value):
                for d in ("'", '"', '$$'):
                    if d not in node.value:
                        if '\\' in node.value and d != '$$':
                            self.write('r', d, node.value, d)
                        else:
                            self.write(d, node.value, d)
                        return
                self.write(edgeql_quote.dollar_quote_literal(node.value))
                return
            self.write(repr(node.value))
        else:
            self.write(node.value)

    def visit_BytesConstant(self, node: qlast.BytesConstant) -> None:
        val = _BYTES_ESCAPE_RE.sub(_bytes_escape, node.value)
        self.write("b'", val.decode('utf-8', 'backslashreplace'), "'")

    def visit_FunctionCall(self, node: qlast.FunctionCall) -> None:
        if isinstance(node.func, tuple):
            self.write(
                f'{ident_to_str(node.func[0])}::{ident_to_str(node.func[1])}')
        else:
            self.write(ident_to_str(node.func))

        self.write('(')

        for i, arg in enumerate(node.args):
            if i > 0:
                self.write(', ')
            self.visit(arg)

        if node.kwargs:
            if node.args:
                self.write(', ')

            for i, (name, arg) in enumerate(node.kwargs.items()):
                if i > 0:
                    self.write(', ')
                self.write(f'{edgeql_quote.quote_ident(name)} := ')
                self.visit(arg)

        self.write(')')

        if node.window:
            self._write_keywords(' OVER (')
            self._block_ws(1)

            if node.window.partition:
                self._write_keywords('PARTITION BY ')
                self.visit_list(node.window.partition, newlines=False)
                self.new_lines = 1

            if node.window.orderby:
                self._write_keywords('ORDER BY ')
                self.visit_list(
                    node.window.orderby, separator=self._kw_case(' THEN'))

            self._block_ws(-1)
            self.write(')')

    def visit_PseudoObjectRef(self, node: qlast.PseudoObjectRef) -> None:
        self.write(node.name)

    def visit_TypeCast(self, node: qlast.TypeCast) -> None:
        self.write('<')
        if node.cardinality_mod is qlast.CardinalityModifier.Optional:
            self.write('optional ')
        self.visit(node.type)
        self.write('>')
        self.visit(node.expr)

    def visit_Indirection(self, node: qlast.Indirection) -> None:
        self.write('(')
        self.visit(node.arg)
        self.write(')')
        for indirection in node.indirection:
            self.visit(indirection)

    def visit_Slice(self, node: qlast.Slice) -> None:
        self.write('[')
        if node.start:
            self.visit(node.start)
        self.write(':')
        if node.stop:
            self.visit(node.stop)
        self.write(']')

    def visit_Index(self, node: qlast.Index) -> None:
        self.write('[')
        self.visit(node.index)
        self.write(']')

    def visit_ObjectRef(self, node: qlast.ObjectRef) -> None:
        if node.itemclass:
            self.write(node.itemclass)
            self.write(' ')
        if node.module:
            self.write(ident_to_str(node.module))
            self.write('::')
        self.write(ident_to_str(node.name))

    def visit_SpecialAnchor(self, node: qlast.Anchor) -> None:
        self.write(node.name)

    def visit_IRAnchor(self, node: qlast.Anchor) -> None:
        self.write(node.name)

    def visit_TypeExprLiteral(self, node: qlast.TypeExprLiteral) -> None:
        self.visit(node.val)

    def visit_TypeName(self, node: qlast.TypeName) -> None:
        parenthesize = (
            isinstance(
                node._parent,  # type: ignore
                (qlast.IsOp, qlast.TypeOp, qlast.Introspect),
            )
            and node.subtypes is not None
        )
        if parenthesize:
            self.write('(')
        if node.name is not None:
            self.write(ident_to_str(node.name), ': ')

        self.visit(node.maintype)
        if node.subtypes is not None:
            self.write('<')
            self.visit_list(node.subtypes, newlines=False)
            if node.dimensions is not None:
                for dim in node.dimensions:
                    if dim is None:
                        self.write('[]')
                    else:
                        self.write('[', str(dim), ']')
            self.write('>')
        if parenthesize:
            self.write(')')

    def visit_Introspect(self, node: qlast.Introspect) -> None:
        self.write('INTROSPECT ')
        self.visit(node.type)

    def visit_TypeOf(self, node: qlast.TypeOf) -> None:
        self.write('TYPEOF ')
        self.visit(node.expr)

    # DDL nodes

    def visit_DDLQuery(self, node: qlast.DDLQuery) -> None:
        self.visit(node.query)

    def visit_Position(self, node: qlast.Position) -> None:
        self.write(node.position)
        if node.ref:
            self.write(' ')
            self.visit(node.ref)

    def _ddl_visit_bases(self, node: qlast.BasedOn) -> None:
        if node.bases:
            self._write_keywords(' EXTENDING ')
            self.visit_list(node.bases, newlines=False)

    PointerNode = TypeVar(
        'PointerNode',
        qlast.CreateConcretePointer,
        qlast.CreateLink,
        qlast.CreateProperty
    )

    def _ddl_add_pointer_bases(
        self,
        node: PointerNode,
    ) -> PointerNode:
        # We very carefully strained EXTENDING clauses out of subcommands
        # when parsing, but now that we're printing, we want to print them
        # back in the commands block. Do that by "faking" an extending
        # node in the subcommands of a scoped copy of this node.
        if node.bases:
            return node.replace(commands=(
                [qlast.AlterAddInherit(bases=node.bases)] + node.commands
            ))
        else:
            return node

    def _ddl_clean_up_commands(
        self,
        commands: Sequence[qlast.Base],
    ) -> Sequence[qlast.Base]:
        # Always omit orig_* fields (such as orig_expr) from the output,
        # since TEXT output already uses the original expression.
        return [
            c for c in commands
            if (
                not isinstance(c, qlast.SetField)
                or not c.name.startswith('orig_')
            )
        ]

    def _ddl_visit_body(
        self,
        commands: Sequence[qlast.Base],
        group_by_system_comment: bool = False,
        *,
        allow_short: bool = False
    ) -> None:
        if self.limit_ref_classes:
            commands = [
                c for c in commands
                if (
                    not isinstance(c, qlast.ObjectDDL)
                    or c.name.itemclass in self.limit_ref_classes
                )
            ]

        commands = self._ddl_clean_up_commands(commands)
        if len(commands) == 1 and allow_short and not (
            isinstance(commands[0], qlast.ObjectDDL)
            and not isinstance(commands[0], qlast.Rename)
        ):
            self.write(' ')
            self.visit(commands[0])
        elif len(commands) > 0:
            self.write(' {')
            self._block_ws(1)

            if group_by_system_comment:
                sort_key = lambda c: (
                    c.system_comment or '',
                    c.name.name if isinstance(c.name, qlast.ObjectRef)
                    else c.name
                )
                group_key = lambda c: c.system_comment or ''
                if not self.unsorted:
                    commands = sorted(commands, key=sort_key)
                groups = itertools.groupby(commands, group_key)
                for i, (comment, items) in enumerate(groups):
                    if i > 0:
                        self.new_lines = 2
                    if comment:
                        self.write('#')
                        self.new_lines = 1
                        self.write(f'# {comment}')
                        self.new_lines = 1
                        self.write('#')
                        self.new_lines = 1
                    self.visit_list(list(items), terminator=';')
            elif self.descmode or self.sdlmode:
                def sort_desc_or_sdl(
                    c: qlast.Base,
                ) -> tuple[str, ...]:
                    # The sort key is a tuple of parts of the command which will
                    # be rendered to text.
                    #
                    # Commands will be ordered generally as:
                    # 1. General DDL Operations
                    # 2. Set Field Operations
                    # 3. Object DDL Operations
                    #
                    # Empty strings sort before any non-empty string, so
                    # padding the key with them achieves this ordering.
                    #
                    # General DDL Operations are sorted by command class name.
                    # This works because these commands can each appear once per
                    # body.
                    # eg. ('', '', '', 'AlterAddInherit')
                    #
                    # Set Field Operations are sorted by field name.
                    # eg. ('', '', 'readonly')
                    #
                    # Object DDL Operations are sorted first by itemclass then
                    # name.
                    # eg. ('TYPE', 'Foo')
                    #
                    # For constraints and indexes, the expression and except
                    # expression are included.
                    # eg. ('CONSTRAINT', 'exclusive', '.a', '.b')

                    if isinstance(c, qlast.ObjectDDL):
                        if isinstance(c, qlast.ConcreteConstraintOp):
                            subject_expr = (
                                self.generate_isolated_text(c.subjectexpr)
                                if c.subjectexpr is not None else
                                ''
                            )
                            except_expr = (
                                self.generate_isolated_text(c.except_expr)
                                if c.except_expr is not None else
                                ''
                            )
                            return (
                                typeutils.not_none(c.name.itemclass),
                                c.name.name,
                                subject_expr,
                                except_expr,
                            )

                        if isinstance(c, qlast.ConcreteIndexCommand):
                            expr = (
                                self.generate_isolated_text(c.expr)
                                if c.expr is not None else
                                ''
                            )
                            except_expr = (
                                self.generate_isolated_text(c.except_expr)
                                if c.except_expr is not None else
                                ''
                            )
                            return (
                                typeutils.not_none(c.name.itemclass),
                                c.name.name,
                                expr,
                                except_expr,
                            )

                        return (c.name.itemclass or '', c.name.name)

                    if isinstance(c, qlast.SetField):
                        return ('', '', c.name)

                    return ('', '', '', c.__class__.__name__)

                if not self.unsorted:
                    commands = sorted(commands, key=sort_desc_or_sdl)

                self.visit_list(list(commands), terminator=';')

            else:
                self.visit_list(list(commands), terminator=';')

            self._block_ws(-1)
            self.write('}')

    def _visit_CreateObject(
        self,
        node: qlast.CreateObject,
        *object_keywords: str,
        after_name: Optional[Callable[[], None]] = None,
        render_commands: bool = True,
        unqualified: bool = False,
        named: bool = True,
        group_by_system_comment: bool = False,
    ) -> None:
        self._visit_aliases(node)
        if self.sdlmode:
            self.write(*[kw.lower() for kw in object_keywords], delimiter=' ')
        else:
            self._write_keywords('CREATE', *object_keywords)
        if named:
            self.write(' ')
            if unqualified or not node.name.module:
                self.write(ident_to_str(node.name.name))
            else:
                self.write(ident_to_str(node.name.module), '::',
                           ident_to_str(node.name.name))
        if after_name:
            after_name()
        if node.create_if_not_exists and not self.sdlmode:
            self._write_keywords(' IF NOT EXISTS')

        commands = node.commands
        if commands and render_commands:
            self._ddl_visit_body(
                commands,
                group_by_system_comment=group_by_system_comment,
            )

    def _visit_AlterObject(
        self,
        node: qlast.AlterObject,
        *object_keywords: str,
        allow_short: bool = True,
        after_name: Optional[Callable[[], None]] = None,
        unqualified: bool = False,
        named: bool = True,
        ignored_cmds: Optional[AbstractSet[qlast.DDLOperation]] = None,
        group_by_system_comment: bool = False,
    ) -> None:
        self._visit_aliases(node)
        if self.sdlmode:
            self.write(*[kw.lower() for kw in object_keywords], delimiter=' ')
        else:
            self._write_keywords('ALTER', *object_keywords)
        if named:
            self.write(' ')
            if unqualified or not node.name.module:
                self.write(ident_to_str(node.name.name))
            else:
                self.write(ident_to_str(node.name.module), '::',
                           ident_to_str(node.name.name))
        if after_name:
            after_name()

        commands = node.commands
        if ignored_cmds:
            commands = [cmd for cmd in commands
                        if cmd not in ignored_cmds]

        if commands:
            self._ddl_visit_body(
                commands,
                group_by_system_comment=group_by_system_comment,
                allow_short=allow_short,
            )

    def _visit_DropObject(
        self,
        node: qlast.DropObject,
        *object_keywords: str,
        unqualified: bool = False,
        after_name: Optional[Callable[[], None]] = None,
        named: bool = True,
    ) -> None:
        self._visit_aliases(node)
        self._write_keywords('DROP', *object_keywords)
        if named:
            self.write(' ')
            if unqualified or not node.name.module:
                self.write(ident_to_str(node.name.name))
            else:
                self.write(ident_to_str(node.name.module), '::',
                           ident_to_str(node.name.name))
        if after_name:
            after_name()
        if node.commands:
            self.write(' {')
            self._block_ws(1)
            self.visit_list(node.commands, terminator=';')
            self._block_ws(-1)
            self.write('}')

    def visit_Rename(self, node: qlast.Rename) -> None:
        self._write_keywords('RENAME TO ')
        self.visit(node.new_name)

    def visit_AlterAddInherit(self, node: qlast.AlterAddInherit) -> None:
        if node.bases:
            self._write_keywords('EXTENDING ')
            self.visit_list(node.bases)
            if node.position is not None:
                self.write(' ')
                self.visit(node.position)

    def visit_AlterDropInherit(self, node: qlast.AlterDropInherit) -> None:
        if node.bases:
            self._write_keywords('DROP EXTENDING ')
            self.visit_list(node.bases)

    def visit_CreateDatabase(self, node: qlast.CreateDatabase) -> None:
        if node.flavor == qltypes.SchemaObjectClass.BRANCH:
            if node.branch_type == qlast.BranchType.EMPTY:
                self._visit_CreateObject(node, 'EMPTY BRANCH')
            else:

                def after_name() -> None:
                    self._write_keywords(' FROM ')
                    assert node.template
                    self.visit(node.template)
                self._visit_CreateObject(
                    node, f'{node.branch_type} BRANCH', after_name=after_name)
        elif node.flavor == qltypes.SchemaObjectClass.DATABASE:
            self._visit_CreateObject(node, 'DATABASE')
        else:
            raise EdgeQLSourceGeneratorError(
                f'unknown branch command flavor: {node.flavor!r}'
            )

    def visit_AlterDatabase(self, node: qlast.AlterDatabase) -> None:
        self._visit_AlterObject(node, node.flavor)

    def visit_DropDatabase(self, node: qlast.DropDatabase) -> None:
        self._visit_DropObject(node, node.flavor)

    def visit_CreateRole(self, node: qlast.CreateRole) -> None:
        after_name = lambda: self._ddl_visit_bases(node)
        keywords = []
        if node.superuser:
            keywords.append('SUPERUSER')
        keywords.append('ROLE')
        self._visit_CreateObject(node, *keywords, after_name=after_name)

    def visit_AlterRole(self, node: qlast.AlterRole) -> None:
        self._visit_AlterObject(node, 'ROLE')

    def visit_DropRole(self, node: qlast.DropRole) -> None:
        self._visit_DropObject(node, 'ROLE')

    def visit_CreateExtensionPackage(
        self,
        node: qlast.CreateExtensionPackage,
    ) -> None:
        self._write_keywords('CREATE EXTENSION PACKAGE')
        self.write(' ')
        self.write(ident_to_str(node.name.name))
        self._write_keywords(' VERSION ')
        self.visit(node.version)
        if node.body.text:
            self.write(' {')
            self._block_ws(1)
            self.write(self.indent_text(node.body.text))
            self._block_ws(-1)
            self.write('}')
        elif node.body.commands:
            self._ddl_visit_body(node.body.commands)

    def visit_DropExtensionPackage(
        self,
        node: qlast.DropExtensionPackage,
    ) -> None:
        def after_name() -> None:
            self._write_keywords(' VERSION ')
            self.visit(node.version)

        self._visit_DropObject(node, 'EXTENSION PACKAGE', after_name=after_name)

    def visit_CreateExtensionPackageMigration(
        self,
        node: qlast.CreateExtensionPackageMigration,
    ) -> None:
        self._write_keywords('CREATE EXTENSION PACKAGE')
        self.write(' ')
        self.write(ident_to_str(node.name.name))
        self._write_keywords(' MIGRATION FROM VERSION ')
        self.visit(node.from_version)
        self._write_keywords(' TO ')
        self.visit(node.to_version)

        if node.body.text:
            self.write(' {')
            self._block_ws(1)
            self.write(self.indent_text(node.body.text))
            self._block_ws(-1)
            self.write('}')
        elif node.body.commands:
            self._ddl_visit_body(node.body.commands)

    def visit_DropExtensionPackageMigration(
        self,
        node: qlast.DropExtensionPackageMigration,
    ) -> None:
        self._write_keywords('DROP EXTENSION PACKAGE')
        self.write(' ')
        self.write(ident_to_str(node.name.name))
        self._write_keywords(' MIGRATION FROM VERSION ')
        self.visit(node.from_version)
        self._write_keywords(' TO ')
        self.visit(node.to_version)

    def visit_CreateExtension(
        self,
        node: qlast.CreateExtension,
    ) -> None:
        if self.sdlmode or self.descmode:
            self._write_keywords('using extension')
        else:
            self._write_keywords('CREATE EXTENSION')
        self.write(' ')
        self.write(ident_to_str(node.name.name))
        if node.version is not None:
            self._write_keywords(' version ')
            self.visit(node.version)
        if node.commands:
            self._ddl_visit_body(node.commands)

    def visit_AlterExtension(self, node: qlast.AlterExtension) -> None:
        self._write_keywords('ALTER EXTENSION')
        self.write(' ')
        self.write(ident_to_str(node.name.name))
        self._write_keywords(' TO VERSION ')
        self.visit(node.to_version)

    def visit_DropExtension(
        self,
        node: qlast.DropExtension,
    ) -> None:
        self._visit_DropObject(node, 'EXTENSION')

    def visit_CreateFuture(
        self,
        node: qlast.CreateFuture,
    ) -> None:
        if self.sdlmode or self.descmode:
            self._write_keywords('using future')
        else:
            self._write_keywords('CREATE FUTURE')
        self.write(' ')
        self.write(ident_to_str(node.name.name))

    def visit_DropFuture(
        self,
        node: qlast.DropFuture,
    ) -> None:
        self._visit_DropObject(node, 'FUTURE')

    def visit_CreateMigration(self, node: qlast.CreateMigration) -> None:
        self._write_keywords('CREATE')
        if node.metadata_only:
            self._write_keywords(' APPLIED')
        self._write_keywords(' MIGRATION')
        if node.name is not None:
            self.write(' ')
            self.write(ident_to_str(node.name.name))
            self._write_keywords(' ONTO ')
            if node.parent is not None:
                self.write(ident_to_str(node.parent.name))
            else:
                self._write_keywords('initial')
        if node.body.text:
            self.write(' {')
            self._block_ws(1)
            self.write(self.indent_text(node.body.text))
            self._block_ws(-1)
            self.write('}')
        elif node.commands or node.body.commands:
            commands = [*node.commands, *node.body.commands]
            self._ddl_visit_body(commands)

    def visit_StartMigration(self, node: qlast.StartMigration) -> None:
        if isinstance(node.target, qlast.CommittedSchema):
            self._write_keywords('START MIGRATION TO COMMITTED SCHEMA')
        else:
            self._write_keywords('START MIGRATION TO {')
            self.new_lines = 1
            self.indentation += 1
            self.visit(node.target)
            self.indentation -= 1
            self.new_lines = 1
            self.write('}')

    def visit_CommitMigration(self, node: qlast.CommitMigration) -> None:
        self._write_keywords('COMMIT MIGRATION')

    def visit_AbortMigration(self, node: qlast.AbortMigration) -> None:
        self._write_keywords('ABORT MIGRATION')

    def visit_PopulateMigration(self, node: qlast.PopulateMigration) -> None:
        self._write_keywords('POPULATE MIGRATION')

    def visit_StartMigrationRewrite(
        self, node: qlast.StartMigrationRewrite
    ) -> None:
        self._write_keywords('START MIGRATION REWRITE')

    def visit_CommitMigrationRewrite(
        self, node: qlast.CommitMigrationRewrite
    ) -> None:
        self._write_keywords('COMMIT MIGRATION REWRITE')

    def visit_AbortMigrationRewrite(
        self, node: qlast.AbortMigrationRewrite
    ) -> None:
        self._write_keywords('ABORT MIGRATION REWRITE')

    def visit_DescribeCurrentMigration(
        self,
        node: qlast.DescribeCurrentMigration,
    ) -> None:
        self._write_keywords('DESCRIBE CURRENT MIGRATION AS ')
        self.write(node.language.upper())

    def visit_AlterCurrentMigrationRejectProposed(
        self,
        node: qlast.AlterCurrentMigrationRejectProposed,
    ) -> None:
        self._write_keywords('ALTER CURRENT MIGRATION REJECT PROPOSED')

    def visit_AlterMigration(self, node: qlast.AlterMigration) -> None:
        self._visit_AlterObject(node, 'MIGRATION')

    def visit_DropMigration(self, node: qlast.DropMigration) -> None:
        self._visit_DropObject(node, 'MIGRATION')

    def visit_ResetSchema(self, node: qlast.ResetSchema) -> None:
        self._write_keywords(f'RESET SCHEMA TO {node.target}')

    def visit_CreateModule(self, node: qlast.CreateModule) -> None:
        self._visit_CreateObject(node, 'MODULE')
        # Hack to handle the SDL version of this with an empty block.
        if self.sdlmode and not node.commands:
            self.write('{}')

    def visit_AlterModule(self, node: qlast.AlterModule) -> None:
        self._visit_AlterObject(node, 'MODULE')

    def visit_DropModule(self, node: qlast.DropModule) -> None:
        self._visit_DropObject(node, 'MODULE')

    def visit_CreateAlias(self, node: qlast.CreateAlias) -> None:
        if (
            len(node.commands) == 1
            and isinstance(node.commands[0], qlast.SetField)
            and node.commands[0].name == 'expr'
        ):

            self._visit_CreateObject(node, 'ALIAS', render_commands=False)
            self.write(' := (')
            self.new_lines = 1
            self.indentation += 1
            expr = node.commands[0].value
            assert expr is not None
            self.visit(expr)
            self.indentation -= 1
            self.new_lines = 1
            self.write(')')
        else:
            self._visit_CreateObject(node, 'ALIAS')

    def visit_AlterAlias(self, node: qlast.AlterAlias) -> None:
        self._visit_AlterObject(node, 'ALIAS')

    def visit_DropAlias(self, node: qlast.DropAlias) -> None:
        self._visit_DropObject(node, 'ALIAS')

    def visit_SetField(self, node: qlast.SetField) -> None:
        if node.special_syntax:
            if node.name == 'expr':
                if node.value is None:
                    self._write_keywords('RESET', 'EXPRESSION')
                else:
                    self._write_keywords('USING')
                    self.write(' (')
                    self.visit(node.value)
                    self.write(')')
            elif node.name == 'condition':
                if node.value is None:
                    self._write_keywords('RESET', 'WHEN')
                else:
                    self._write_keywords('WHEN')
                    self.write(' (')
                    self.visit(node.value)
                    self.write(')')
            elif node.name == 'target':
                if node.value is None:
                    self._write_keywords('RESET', 'TYPE')
                else:
                    self._write_keywords('SET', 'TYPE ')
                    self.visit(node.value)
            else:
                keywords = self._process_special_set(node)
                self.write(*keywords, delimiter=' ')
        elif node.value:
            if not self.sdlmode:
                self._write_keywords('SET ')
            self.write(f'{node.name} := ')
            if not isinstance(node.value, (qlast.BaseConstant, qlast.Set)):
                self.write('(')
            self.visit(node.value)
            if not isinstance(node.value, (qlast.BaseConstant, qlast.Set)):
                self.write(')')
        elif not self.sdlmode:
            self._write_keywords('RESET ')
            self.write(node.name)

    def _eval_bool_expr(
        self,
        expr: qlast.Expr | qlast.TypeExpr,
    ) -> bool:
        if (not isinstance(expr, qlast.Constant)
            or expr.kind != qlast.ConstantKind.BOOLEAN
        ):
            raise AssertionError(f'expected BooleanConstant, got {expr!r}')
        return expr.value == 'true'

    def _eval_enum_expr(
        self,
        expr: qlast.Expr | qlast.TypeExpr,
        enum_type: type[Enum_T],
    ) -> Enum_T:
        if (
            not isinstance(expr, qlast.Constant)
            or expr.kind != qlast.ConstantKind.STRING
        ):
            raise AssertionError(f'expected StringConstant, got {expr!r}')
        return enum_type(expr.value)

    def _process_special_set(self, node: qlast.SetField) -> list[str]:

        keywords: list[str] = []
        fname = node.name

        if fname == 'required':
            if node.value is None:
                keywords.extend(('RESET', 'OPTIONALITY'))
            elif self._eval_bool_expr(node.value):
                keywords.extend(('SET', 'REQUIRED'))
            else:
                keywords.extend(('SET', 'OPTIONAL'))
        elif fname == 'abstract':
            if node.value is None:
                keywords.extend(('RESET', 'ABSTRACT'))
            elif self._eval_bool_expr(node.value):
                keywords.extend(('SET', 'ABSTRACT'))
            else:
                keywords.extend(('SET', 'NOT', 'ABSTRACT'))
        elif fname == 'delegated':
            if node.value is None:
                keywords.extend(('RESET', 'DELEGATED'))
            elif self._eval_bool_expr(node.value):
                keywords.extend(('SET', 'DELEGATED'))
            else:
                keywords.extend(('SET', 'NOT', 'DELEGATED'))
        elif fname == 'cardinality':
            if node.value is None:
                keywords.extend(('RESET', 'CARDINALITY'))
            elif node.value:
                value = self._eval_enum_expr(
                    node.value, qltypes.SchemaCardinality)
                keywords.extend(('SET', value.to_edgeql()))
        elif fname == 'owned':
            if node.value is None:
                keywords.extend(('DROP', 'OWNED'))
            elif self._eval_bool_expr(node.value):
                keywords.extend(('SET', 'OWNED'))
            else:
                keywords.extend(('DROP', 'OWNED'))
        elif fname == 'deferred':
            if node.value is None:
                keywords.extend(('RESET', 'DEFERRED'))
            elif self._eval_bool_expr(node.value):
                keywords.extend(('SET', 'DEFERRED'))
            else:
                keywords.extend(('DROP', 'DEFERRED'))
        else:
            raise EdgeQLSourceGeneratorError(
                f'unknown special field: {fname!r}')

        return keywords

    def visit_CreateAnnotation(self, node: qlast.CreateAnnotation) -> None:
        after_name = lambda: self._ddl_visit_bases(node)
        if node.inheritable:
            tag = 'ABSTRACT INHERITABLE ANNOTATION'
        else:
            tag = 'ABSTRACT ANNOTATION'
        self._visit_CreateObject(node, tag, after_name=after_name)

    def visit_AlterAnnotation(self, node: qlast.AlterAnnotation) -> None:
        self._visit_AlterObject(node, 'ABSTRACT ANNOTATION')

    def visit_DropAnnotation(self, node: qlast.DropAnnotation) -> None:
        self._visit_DropObject(node, 'ABSTRACT ANNOTATION')

    def visit_CreateAnnotationValue(
        self, node: qlast.CreateAnnotationValue
    ) -> None:
        if self.sdlmode:
            self._write_keywords('annotation ')
        else:
            self._write_keywords('CREATE ANNOTATION ')
        self.visit(node.name)
        self.write(' := ')
        self.visit(node.value)

    def visit_AlterAnnotationValue(
        self, node: qlast.AlterAnnotationValue
    ) -> None:
        self._write_keywords('ALTER ANNOTATION ')
        self.visit(node.name)
        self.write(' ')
        if node.value:
            self.write(':= ')
            self.visit(node.value)
        else:
            # The command should be a DROP OWNED
            assert len(node.commands) == 1
            self.visit(node.commands[0])

    def visit_DropAnnotationValue(
        self, node: qlast.DropAnnotationValue
    ) -> None:
        self._write_keywords('DROP ANNOTATION ')
        self.visit(node.name)

    def visit_CreateConstraint(self, node: qlast.CreateConstraint) -> None:
        def after_name() -> None:
            if node.params:
                self.write('(')
                self.visit_list(node.params, newlines=False)
                self.write(')')
            if node.subjectexpr:
                self._write_keywords(' ON ')
                self.write('(')
                self.visit(node.subjectexpr)
                self.write(')')

            self._ddl_visit_bases(node)

        self._visit_CreateObject(
            node, 'ABSTRACT CONSTRAINT', after_name=after_name
        )

    def visit_AlterConstraint(self, node: qlast.AlterConstraint) -> None:
        self._visit_AlterObject(node, 'ABSTRACT CONSTRAINT')

    def visit_DropConstraint(self, node: qlast.DropConstraint) -> None:
        self._visit_DropObject(node, 'ABSTRACT CONSTRAINT')

    def _after_constraint(self, node: qlast.ConcreteConstraintOp) -> None:
        if node.args:
            self.write('(')
            self.visit_list(node.args, newlines=False)
            self.write(')')
        if node.subjectexpr:
            self._write_keywords(' ON ')
            self.write('(')
            self.visit(node.subjectexpr)
            self.write(')')
        if node.except_expr:
            self._write_keywords(' EXCEPT ')
            self.write('(')
            self.visit(node.except_expr)
            self.write(')')

    def visit_CreateConcreteConstraint(
        self, node: qlast.CreateConcreteConstraint
    ) -> None:
        keywords = []
        if node.delegated:
            keywords.append('DELEGATED')
        keywords.append('CONSTRAINT')
        self._visit_CreateObject(
            node, *keywords, after_name=lambda: self._after_constraint(node)
        )

    def visit_AlterConcreteConstraint(
        self, node: qlast.AlterConcreteConstraint
    ) -> None:
        self._visit_AlterObject(
            node,
            'CONSTRAINT',
            allow_short=False,
            after_name=lambda: self._after_constraint(node),
        )

    def visit_DropConcreteConstraint(
        self, node: qlast.DropConcreteConstraint
    ) -> None:
        self._visit_DropObject(
            node, 'CONSTRAINT', after_name=lambda: self._after_constraint(node)
        )

    def _format_access_kinds(self, kinds: list[qltypes.AccessKind]) -> str:
        # Canonicalize the order, since the schema loses track of it.
        kinds = [k for k in list(qltypes.AccessKind) if k in kinds]
        if kinds == list(qltypes.AccessKind):
            return 'all'
        skinds = ', '.join(str(kind).lower() for kind in kinds)
        # Insert a space after "update" ("updateread" -> "update read")...
        skinds = skinds.replace("update", "update ")
        # ...and collapse the pair into a plain "update" when both are present.
        skinds = skinds.replace("update read, update write", "update")
        return skinds

    def visit_CreateAccessPolicy(self, node: qlast.CreateAccessPolicy) -> None:
        def after_name() -> None:
            if node.condition:
                self._block_ws(1)
                self._write_keywords('WHEN ')
                self.write('(')
                self.visit(node.condition)
                self.write(')')
                self._block_ws(-1)
            self._block_ws(1)
            self._write_keywords(str(node.action) + ' ')
            if node.access_kinds:
                self._write_keywords(
                    self._format_access_kinds(node.access_kinds) + ' ')
            if node.expr:
                self._write_keywords('USING ')
                self.write('(')
                self.visit(node.expr)
                self.write(')')

        keywords = ['ACCESS', 'POLICY']
        self._visit_CreateObject(
            node, *keywords, after_name=after_name, unqualified=True)
        # after_name leaves the indentation one level deeper so that
        # subcommands get double indented; restore it here.
        self.indentation -= 1

    def visit_SetAccessPerms(self, node: qlast.SetAccessPerms) -> None:
        self._write_keywords(str(node.action) + ' ')
        self._write_keywords(self._format_access_kinds(node.access_kinds))

    def visit_AlterAccessPolicy(self, node: qlast.AlterAccessPolicy) -> None:
        self._visit_AlterObject(node, 'ACCESS POLICY', unqualified=True)

    def visit_DropAccessPolicy(self, node: qlast.DropAccessPolicy) -> None:
        self._visit_DropObject(node, 'ACCESS POLICY', unqualified=True)

    def _format_trigger_kinds(self, kinds: list[qltypes.TriggerKind]) -> str:
        # Canonicalize the order, since the schema loses track
        kinds = [k for k in list(qltypes.TriggerKind) if k in kinds]
        skinds = ', '.join(str(kind).lower() for kind in kinds)
        return skinds

    def visit_CreateTrigger(self, node: qlast.CreateTrigger) -> None:
        def after_name() -> None:
            self._block_ws(1)
            self._write_keywords(str(node.timing) + ' ')
            self._write_keywords(
                self._format_trigger_kinds(node.kinds) + ' ')

            self._block_ws(0)
            self._write_keywords('FOR ' + str(node.scope) + ' ')

            if node.condition:
                self._block_ws(1)
                self._write_keywords('WHEN ')
                self.write('(')
                self.visit(node.condition)
                self.write(')')
                self._block_ws(-1)

            self._write_keywords('DO ')
            self.write('(')
            self.visit(node.expr)
            self.write(')')

        keywords = []
        keywords.extend(['TRIGGER'])
        self._visit_CreateObject(
            node, *keywords, after_name=after_name, unqualified=True)
        # This is left hanging from after_name, so that subcommands
        # get double indented
        self.indentation -= 1

    def visit_AlterTrigger(self, node: qlast.AlterTrigger) -> None:
        self._visit_AlterObject(node, 'TRIGGER', unqualified=True)

    def visit_DropTrigger(self, node: qlast.DropTrigger) -> None:
        self._visit_DropObject(node, 'TRIGGER', unqualified=True)

    def _format_rewrite_kinds(self, kinds: list[qltypes.RewriteKind]) -> str:
        # Canonicalize the order, since the schema loses track
        kinds = [k for k in list(qltypes.RewriteKind) if k in kinds]
        skinds = ', '.join(str(kind).lower() for kind in kinds)
        return skinds

    def visit_CreateRewrite(self, node: qlast.CreateRewrite) -> None:
        def an() -> None:
            self._block_ws(1)
            self._write_keywords(self._format_rewrite_kinds(node.kinds) + ' ')

            self._block_ws(0)

            self._write_keywords('USING ')
            self.write('(')
            self.visit(node.expr)
            self.write(')')

        keywords = []
        keywords.extend(['REWRITE'])
        self._visit_CreateObject(
            node, *keywords, after_name=an, unqualified=True, named=False
        )
        # This is left hanging from after_name, so that subcommands
        # get double indented
        self.indentation -= 1

    def visit_AlterRewrite(self, node: qlast.AlterRewrite) -> None:
        def an() -> None:
            self._block_ws(1)
            self._write_keywords(self._format_rewrite_kinds(node.kinds) + ' ')

        self._visit_AlterObject(
            node, 'REWRITE', after_name=an, unqualified=True, named=False
        )

    def visit_DropRewrite(self, node: qlast.DropRewrite) -> None:
        def an() -> None:
            self._block_ws(1)
            self._write_keywords(self._format_rewrite_kinds(node.kinds) + ' ')

        self._visit_DropObject(
            node, 'REWRITE', after_name=an, unqualified=True, named=False
        )

    def visit_CreateScalarType(self, node: qlast.CreateScalarType) -> None:
        keywords = []
        if node.abstract:
            keywords.append('ABSTRACT')
        keywords.append('SCALAR')
        keywords.append('TYPE')

        after_name = lambda: self._ddl_visit_bases(node)
        self._visit_CreateObject(node, *keywords, after_name=after_name)

    def visit_AlterScalarType(self, node: qlast.AlterScalarType) -> None:
        self._visit_AlterObject(node, 'SCALAR TYPE')

    def visit_DropScalarType(self, node: qlast.DropScalarType) -> None:
        self._visit_DropObject(node, 'SCALAR TYPE')

    def visit_CreatePseudoType(self, node: qlast.CreatePseudoType) -> None:
        keywords = []
        keywords.append('PSEUDO')
        keywords.append('TYPE')
        self._visit_CreateObject(node, *keywords)

    def visit_CreateProperty(self, node: qlast.CreateProperty) -> None:
        node = self._ddl_add_pointer_bases(node)
        self._visit_CreateObject(node, 'ABSTRACT PROPERTY')

    def visit_AlterProperty(self, node: qlast.AlterProperty) -> None:
        self._visit_AlterObject(node, 'ABSTRACT PROPERTY')

    def visit_DropProperty(self, node: qlast.DropProperty) -> None:
        self._visit_DropObject(node, 'ABSTRACT PROPERTY')

    def visit_CreateConcreteProperty(
        self, node: qlast.CreateConcreteProperty
    ) -> None:
        self.visit_CreateConcretePointer(node, kind='PROPERTY')

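    def _process_AlterConcretePointer_for_SDL(
        self,
        node: qlast.AlterObject,
    ) -> tuple[list[str], frozenset[qlast.DDLOperation]]:
        # Collect the qualifier keywords (OPTIONAL/REQUIRED/SINGLE/MULTI)
        # implied by special-syntax SetField commands, so that SDL output
        # can render them before the pointer name.  Every such command is
        # returned in the second element so the caller can skip it.
        keywords = []
        specials = set()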

        for command in node.commands:
            if isinstance(command, qlast.SetField) and command.special_syntax:
                kw = self._process_special_set(command)
                specials.add(command)
                if kw[0] == 'SET':
                    keywords.append(kw[1])

        order = ['OPTIONAL', 'REQUIRED', 'SINGLE', 'MULTI']
        keywords.sort(key=lambda i: order.index(i))

        return keywords, frozenset(specials)

    def visit_AlterConcreteProperty(
        self, node: qlast.AlterConcreteProperty
    ) -> None:
        self.visit_AlterConcretePointer(node, kind='PROPERTY')

    def visit_DropConcreteProperty(
        self, node: qlast.DropConcreteProperty
    ) -> None:
        self._visit_DropObject(node, 'PROPERTY', unqualified=True)

    def visit_CreateLink(self, node: qlast.CreateLink) -> None:
        node = self._ddl_add_pointer_bases(node)
        self._visit_CreateObject(node, 'ABSTRACT LINK')

    def visit_AlterLink(self, node: qlast.AlterLink) -> None:
        self._visit_AlterObject(node, 'ABSTRACT LINK')

    def visit_DropLink(self, node: qlast.DropLink) -> None:
        self._visit_DropObject(node, 'ABSTRACT LINK')

    def visit_CreateConcretePointer(
        self,
        node: qlast.CreateConcretePointer,
        kind: Optional[str],
    ) -> None:
        keywords = []

        if self.sdlmode and node.declared_overloaded:
            keywords.append('OVERLOADED')
            if node.is_required:
                keywords.append('REQUIRED')
        else:
            if node.is_required is True:
                keywords.append("REQUIRED")
            elif node.is_required is False:
                keywords.append("OPTIONAL")
            # else: node.is_required is None
        if node.cardinality:
            keywords.append(node.cardinality.as_ptr_qual().upper())
        if kind:
            keywords.append(kind)

        def after_name() -> None:
            if node.target is not None:
                if isinstance(node.target, qlast.TypeExpr):
                    self.write(': ')
                    self.visit(node.target)
                elif pure_computable:
                    # computable
                    self.write(' := (')
                    self.visit(node.target)
                    self.write(')')

        node = self._ddl_add_pointer_bases(node)

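        # Render the short form (no `{ ... }` block) when there are no
        # subcommands, or when the only subcommand sets `expr` on a
        # pointer whose target is not a type expression.
        pure_computable = (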
            len(node.commands) == 0
            or (
                len(node.commands) == 1
                and isinstance(node.commands[0], qlast.SetField)
                and node.commands[0].name == 'expr'
                and not isinstance(node.target, qlast.TypeExpr)
            )
        )

        self._visit_CreateObject(
            node,
            *keywords,
            after_name=after_name,
            unqualified=True,
            render_commands=not pure_computable,
        )

    def visit_CreateConcreteUnknownPointer(
        self, node: qlast.CreateConcreteUnknownPointer
    ) -> None:
        self.visit_CreateConcretePointer(node, kind=None)

    def visit_AlterConcreteUnknownPointer(
        self, node: qlast.AlterConcreteUnknownPointer
    ) -> None:
        self.visit_AlterConcretePointer(node, kind=None)

    def visit_CreateConcreteLink(self, node: qlast.CreateConcreteLink) -> None:
        self.visit_CreateConcretePointer(node, kind='LINK')

    def visit_AlterConcretePointer(
        self,
        node: qlast.AlterObject,
        kind: Optional[str],
    ) -> None:
        keywords = []
        ignored_cmds: set[qlast.DDLOperation] = set()

        after_name: Optional[Callable[[], None]]

        if self.sdlmode:
            if (not self.descmode
                    or not node.system_comment
                    or 'inherited from' not in node.system_comment):
                keywords.append('OVERLOADED')
            quals, ignored_cmds_r = self._process_AlterConcretePointer_for_SDL(
                node)
            keywords.extend(quals)
            ignored_cmds.update(ignored_cmds_r)

            type_cmd = None
            inherit_cmd = None
            for cmd in node.commands:
                if isinstance(cmd, qlast.SetPointerType):
                    ignored_cmds.add(cmd)
                    type_cmd = cmd
                elif isinstance(cmd, qlast.AlterAddInherit):
                    ignored_cmds.add(cmd)
                    inherit_cmd = cmd

            def after_name() -> None:
                if inherit_cmd:
                    self._ddl_visit_bases(inherit_cmd)
                if type_cmd is not None:
                    self.write(' -> ')
                    assert type_cmd.value
                    self.visit(type_cmd.value)
        else:
            after_name = None

        if kind:
            keywords.append(kind)
        self._visit_AlterObject(
            node,
            *keywords,
            ignored_cmds=ignored_cmds,
            allow_short=False,
            unqualified=True,
            after_name=after_name,
        )

    def visit_AlterConcreteLink(self, node: qlast.AlterConcreteLink) -> None:
        self.visit_AlterConcretePointer(node, kind='LINK')

    def visit_DropConcreteLink(self, node: qlast.DropConcreteLink) -> None:
        self._visit_DropObject(node, 'LINK', unqualified=True)

    def visit_SetPointerType(self, node: qlast.SetPointerType) -> None:
        if node.value is None:
            self._write_keywords('RESET TYPE')
        else:
            self._write_keywords('SET TYPE ')
            self.visit(node.value)
            if node.cast_expr is not None:
                self._write_keywords(' USING (')
                self.visit(node.cast_expr)
                self.write(')')

    def visit_SetPointerCardinality(
        self,
        node: qlast.SetPointerCardinality,
    ) -> None:
        if node.value is None:
            self._write_keywords('RESET CARDINALITY')
        else:
            value = self._eval_enum_expr(node.value, qltypes.SchemaCardinality)
            self._write_keywords('SET ')
            self.write(value.to_edgeql())
        if node.conv_expr is not None:
            self._write_keywords(' USING (')
            self.visit(node.conv_expr)
            self.write(')')

    def visit_SetPointerOptionality(
        self,
        node: qlast.SetPointerOptionality,
    ) -> None:
        if node.value is None:
            self._write_keywords('RESET OPTIONALITY')
        else:
            if self._eval_bool_expr(node.value):
                self._write_keywords('SET REQUIRED')
            else:
                self._write_keywords('SET OPTIONAL')
            if node.fill_expr is not None:
                self._write_keywords(' USING (')
                self.visit(node.fill_expr)
                self.write(')')

    def visit_OnTargetDelete(self, node: qlast.OnTargetDelete) -> None:
        if node.cascade is None:
            self._write_keywords('RESET ON TARGET DELETE')
        else:
            self._write_keywords('ON TARGET DELETE', node.cascade.to_edgeql())

    def visit_OnSourceDelete(self, node: qlast.OnSourceDelete) -> None:
        if node.cascade is None:
            self._write_keywords('RESET ON SOURCE DELETE')
        else:
            self._write_keywords('ON SOURCE DELETE', node.cascade.to_edgeql())

    def visit_CreateObjectType(self, node: qlast.CreateObjectType) -> None:
        keywords = []

        if node.abstract:
            keywords.append('ABSTRACT')
        keywords.append('TYPE')

        after_name = lambda: self._ddl_visit_bases(node)
        self._visit_CreateObject(node, *keywords, after_name=after_name)

    def visit_AlterObjectType(self, node: qlast.AlterObjectType) -> None:
        self._visit_AlterObject(node, 'TYPE')

    def visit_DropObjectType(self, node: qlast.DropObjectType) -> None:
        self._visit_DropObject(node, 'TYPE')

    def _after_index(self, node: qlast.ConcreteIndexCommand) -> None:
        if node.kwargs:
            self.write('(')
            for i, (name, arg) in enumerate(node.kwargs.items()):
                if i > 0:
                    self.write(', ')
                self.write(f'{edgeql_quote.quote_ident(name)} := ')
                self.visit(arg)
            self.write(')')

        self._write_keywords(' ON ')
        self.write('(')
        self.visit(node.expr)
        self.write(')')

        if node.except_expr:
            self._write_keywords(' EXCEPT ')
            self.write('(')
            self.visit(node.except_expr)
            self.write(')')

    def visit_IndexType(self, node: qlast.IndexType) -> None:
        self.visit(node.name)

        if node.kwargs:
            self.write('(')
            for i, (name, arg) in enumerate(node.kwargs.items()):
                if i > 0:
                    self.write(', ')
                self.write(f'{edgeql_quote.quote_ident(name)} := ')
                self.visit(arg)
            self.write(')')

    def visit_CreateIndex(self, node: qlast.CreateIndex) -> None:
        def after_name() -> None:
            if node.params:
                self.write('(')
                self.visit_list(node.params, newlines=False)
                self.write(')')

            if node.kwargs:
                self.write('(')
                for i, (name, arg) in enumerate(node.kwargs.items()):
                    if i > 0:
                        self.write(', ')
                    self.write(f'{edgeql_quote.quote_ident(name)} := ')
                    self.visit(arg)
                self.write(')')

            if node.index_types:
                self._write_keywords(' USING ')
                self.visit_list(node.index_types, newlines=False)

            self._ddl_visit_bases(node)

        self._visit_CreateObject(node, 'ABSTRACT INDEX', after_name=after_name)

    def visit_AlterIndex(self, node: qlast.AlterIndex) -> None:
        self._visit_AlterObject(node, 'ABSTRACT INDEX')

    def visit_DropIndex(self, node: qlast.DropIndex) -> None:
        self._visit_DropObject(node, 'ABSTRACT INDEX')

    def visit_IndexCode(self, node: qlast.IndexCode) -> None:
        self._write_keywords('USING', node.language)
        self.write(edgeql_quote.dollar_quote_literal(node.code))

    def visit_CreateConcreteIndex(
        self, node: qlast.CreateConcreteIndex
    ) -> None:
        keywords = ['DEFERRED', 'INDEX'] if node.deferred else ['INDEX']
        self._visit_CreateObject(
            node,
            *keywords,
            named=node.name.name != 'idx',
            after_name=lambda: self._after_index(node),
        )

    def visit_AlterConcreteIndex(self, node: qlast.AlterConcreteIndex) -> None:
        self._visit_AlterObject(
            node, 'INDEX', named=node.name.name != 'idx',
            after_name=lambda: self._after_index(node))

    def visit_DropConcreteIndex(self, node: qlast.DropConcreteIndex) -> None:
        self._visit_DropObject(
            node, 'INDEX', named=node.name.name != 'idx',
            after_name=lambda: self._after_index(node))

    def visit_CreateIndexMatch(self, node: qlast.CreateIndexMatch) -> None:
        def after_name() -> None:
            self.visit(node.valid_type)
            self._write_keywords(' using ')
            self.visit(node.name)

        self._visit_CreateObject(
            node, 'index match', 'for',
            named=False, after_name=after_name,
        )

    def visit_DropIndexMatch(self, node: qlast.DropIndexMatch) -> None:
        def after_name() -> None:
            self.visit(node.valid_type)
            self._write_keywords(' using ')
            self.visit(node.name)
        self._visit_DropObject(
            node, 'index match', 'for',
            named=False, after_name=after_name,
        )

    def visit_CreateOperator(self, node: qlast.CreateOperator) -> None:
        def after_name() -> None:
            self.write('(')
            self.visit_list(node.params, newlines=False)
            self.write(')')
            self.write(' -> ')
            self.write(node.returning_typemod.to_edgeql(), ' ')
            self.visit(node.returning)

            if node.abstract:
                return

            if node.commands:
                self.write(' {')
                self._block_ws(1)
                commands = self._ddl_clean_up_commands(node.commands)
                self.visit_list(commands, terminator=';')
                self.new_lines = 1
            else:
                self.write(' ')

            if node.code.from_operator:
                from_clause = f'USING {node.code.language} OPERATOR '
                self._write_keywords(from_clause)
                op, *types = node.code.from_operator
                op_str = op
                if types:
                    op_str += f'({",".join(types)})'
                self.write(f'{op_str!r}', ';')
            if node.code.from_function:
                from_clause = f'USING {node.code.language} FUNCTION '
                self._write_keywords(from_clause)
                op, *types = node.code.from_function
                op_str = op
                if types:
                    op_str += f'({",".join(types)})'
                self.write(f'{op_str!r}', ';')
            if node.code.from_expr:
                from_clause = f'USING {node.code.language} EXPRESSION'
                self._write_keywords(from_clause, ';')
            elif node.code.code:
                from_clause = f'USING {node.code.language} '
                self._write_keywords(from_clause)
                self.write(
                    edgeql_quote.dollar_quote_literal(
                        node.code.code),
                    ';'
                )

            self._block_ws(-1)
            if node.commands:
                self.write('}')

        op_type = []
        if node.abstract:
            op_type.append('ABSTRACT')
        if node.kind:
            op_type.append(node.kind.upper())
        op_type.append('OPERATOR')

        self._visit_CreateObject(node, *op_type, after_name=after_name,
                                 render_commands=False)

    def visit_AlterOperator(self, node: qlast.AlterOperator) -> None:
        def after_name() -> None:
            self.write('(')
            self.visit_list(node.params, newlines=False)
            self.write(')')

        op_type = []
        if node.kind:
            op_type.append(node.kind.upper())
        op_type.append('OPERATOR')
        self._visit_AlterObject(node, *op_type, after_name=after_name)

    def visit_DropOperator(self, node: qlast.DropOperator) -> None:
        def after_name() -> None:
            self.write('(')
            self.visit_list(node.params, newlines=False)
            self.write(')')

        op_type = []
        if node.kind:
            op_type.append(node.kind.upper())
        op_type.append('OPERATOR')
        self._visit_DropObject(node, *op_type, after_name=after_name)

    def _function_after_name(
        self, node: qlast.CreateFunction | qlast.AlterFunction
    ) -> None:
        self.write('(')
        self.visit_list(node.params, newlines=False)
        self.write(')')
        if isinstance(node, qlast.CreateFunction):
            self.write(' -> ')
            self._write_keywords(node.returning_typemod.to_edgeql(), '')
            self.visit(node.returning)

        if node.commands:
            self.write(' {')
            self._block_ws(1)
            commands = self._ddl_clean_up_commands(node.commands)
            self.visit_list(commands, terminator=';')
            self.new_lines = 1
        else:
            self.write(' ')

        had_using = True
        if node.code.from_function:
            from_clause = f'USING {node.code.language} FUNCTION '
            self._write_keywords(from_clause)
            self.write(f'{node.code.from_function!r}')
        elif node.code.language is qlast.Language.EdgeQL:
            if node.nativecode:
                self._write_keywords('USING')
                self.write(' (')
                self.visit(node.nativecode)
                self.write(')')
            elif node.code.code:
                self._write_keywords('USING')
                self.write(f' ({node.code.code})')
            else:
                had_using = False
        else:
            if node.code.from_expr:
                from_clause = f'USING {node.code.language} EXPRESSION'
                self._write_keywords(from_clause)
            elif node.code.code:
                from_clause = f'USING {node.code.language} '
                self._write_keywords(from_clause)
                self.write(
                    edgeql_quote.dollar_quote_literal(
                        node.code.code))
            else:
                from_clause = f'USING {node.code.language} '
                self._write_keywords(from_clause)

        if node.commands:
            self._block_ws(-1)
            if had_using:
                self.write(';')
            self.write('}')

    def visit_CreateFunction(self, node: qlast.CreateFunction) -> None:
        self._visit_CreateObject(
            node, 'FUNCTION',
            after_name=lambda: self._function_after_name(node),
            render_commands=False)

    def visit_AlterFunction(self, node: qlast.AlterFunction) -> None:
        # _function_after_name renders the parameter list and the
        # subcommands itself, so all commands are ignored here.
        self._visit_AlterObject(
            node, 'FUNCTION',
            after_name=lambda: self._function_after_name(node),
            ignored_cmds=set(node.commands))

    def visit_DropFunction(self, node: qlast.DropFunction) -> None:
        def after_name() -> None:
            self.write('(')
            self.visit_list(node.params, newlines=False)
            self.write(')')
        self._visit_DropObject(node, 'FUNCTION', after_name=after_name)

    def visit_FuncParamDecl(self, node: qlast.FuncParamDecl) -> None:
        kind = node.kind.to_edgeql()
        if kind:
            self._write_keywords(kind, '')

        if node.name is not None:
            self.write(ident_to_str(node.name), ': ')

        typemod = node.typemod.to_edgeql()
        if typemod:
            self._write_keywords(typemod, '')

        self.visit(node.type)

        if node.default:
            self.write(' = ')
            self.visit(node.default)

    def visit_CreateCast(self, node: qlast.CreateCast) -> None:
        def after_name() -> None:
            self.write(' ')
            self.visit(node.from_type)
            self._write_keywords(' to ')
            self.visit(node.to_type)

            self.write(' {')
            self._block_ws(1)

            if node.commands:
                commands = self._ddl_clean_up_commands(node.commands)
                self.visit_list(commands, terminator=';')
                self.new_lines = 1

            from_clause = f'USING {node.code.language} '
            code = ''

            if node.code.from_function:
                from_clause += 'FUNCTION'
                code = f'{node.code.from_function!r}'
            elif node.code.from_cast:
                from_clause += 'CAST'
            elif node.code.from_expr:
                from_clause += 'EXPRESSION'
            elif node.code.code:
                code = edgeql_quote.dollar_quote_literal(node.code.code)

            self._write_keywords(from_clause)
            if code:
                self.write(' ', code)
            self.write(';')
            self.new_lines = 1

            if node.allow_assignment:
                self._write_keywords('ALLOW ASSIGNMENT;')
                self.new_lines = 1
            if node.allow_implicit:
                self._write_keywords('ALLOW IMPLICIT;')
                self.new_lines = 1

            self._block_ws(-1)
            self.write('}')

        self._visit_CreateObject(
            node, 'CAST', 'FROM',
            named=False, after_name=after_name, render_commands=False
        )

    def visit_AlterCast(self, node: qlast.AlterCast) -> None:
        def after_name() -> None:
            self._write_keywords('FROM ')
            self.visit(node.from_type)
            self._write_keywords(' TO ')
            self.visit(node.to_type)
        self._visit_AlterObject(
            node,
            'CAST',
            named=False,
            after_name=after_name,
        )

    def visit_DropCast(self, node: qlast.DropCast) -> None:
        def after_name() -> None:
            self._write_keywords('FROM ')
            self.visit(node.from_type)
            self._write_keywords(' TO ')
            self.visit(node.to_type)
        self._visit_DropObject(
            node,
            'CAST',
            named=False,
            after_name=after_name,
        )

    def visit_SetGlobalType(self, node: qlast.SetGlobalType) -> None:
        if node.value is None:
            self._write_keywords('RESET TYPE')
        else:
            self._write_keywords('SET TYPE ')
            self.visit(node.value)
            if node.cast_expr is not None:
                self._write_keywords(' USING (')
                self.visit(node.cast_expr)
                self.write(')')
            elif node.reset_value:
                self._write_keywords(' RESET TO DEFAULT')

    def visit_CreateGlobal(self, node: qlast.CreateGlobal) -> None:
        keywords = []
        if node.is_required is True:
            keywords.append('REQUIRED')
        elif node.is_required is False:
            keywords.append('OPTIONAL')
        if node.cardinality:
            keywords.append(node.cardinality.as_ptr_qual().upper())
        keywords.append('GLOBAL')

        pure_computable = (
            len(node.commands) == 0
            or (
                len(node.commands) == 1
                and isinstance(node.commands[0], qlast.SetField)
                and node.commands[0].name == 'expr'
                and not isinstance(node.target, qlast.TypeExpr)
            )
        )

        def after_name() -> None:
            if node.target is not None:
                if isinstance(node.target, qlast.TypeExpr):
                    self.write(' -> ')
                    self.visit(node.target)
                elif pure_computable:
                    # computable
                    self.write(' := (')
                    self.visit(node.target)
                    self.write(')')

        self._visit_CreateObject(
            node, *keywords, after_name=after_name,
            render_commands=not pure_computable)

    def visit_AlterGlobal(self, node: qlast.AlterGlobal) -> None:
        self._visit_AlterObject(node, 'GLOBAL')

    def visit_DropGlobal(self, node: qlast.DropGlobal) -> None:
        self._visit_DropObject(node, 'GLOBAL')

    def visit_CreatePermission(self, node: qlast.CreatePermission) -> None:
        self._visit_CreateObject(node, 'PERMISSION')

    def visit_AlterPermission(self, node: qlast.AlterPermission) -> None:
        self._visit_AlterObject(node, 'PERMISSION')

    def visit_DropPermission(self, node: qlast.DropPermission) -> None:
        self._visit_DropObject(node, 'PERMISSION')

    def visit_ConfigSet(self, node: qlast.ConfigSet) -> None:
        if node.scope == qltypes.ConfigScope.GLOBAL:
            self._write_keywords('SET GLOBAL ')
        else:
            self._write_keywords('CONFIGURE ')
            self.write(node.scope.to_edgeql())
            self._write_keywords(' SET ')
        self.visit(node.name)
        self.write(' := ')
        self.visit(node.expr)

    def visit_ConfigInsert(self, node: qlast.ConfigInsert) -> None:
        self._write_keywords('CONFIGURE ')
        self.write(node.scope.to_edgeql())
        self._write_keywords(' INSERT ')
        self.visit(node.name)
        self.indentation += 1
        self._visit_shape(node.shape)
        self.indentation -= 1

    def visit_ConfigReset(self, node: qlast.ConfigReset) -> None:
        if node.scope == qltypes.ConfigScope.GLOBAL:
            self._write_keywords('RESET GLOBAL ')
        else:
            self._write_keywords('CONFIGURE ')
            self.write(node.scope.to_edgeql())
            self._write_keywords(' RESET ')
        self.visit(node.name)
        self._visit_filter(node)

    def visit_SessionSetAliasDecl(
        self, node: qlast.SessionSetAliasDecl
    ) -> None:
        self._write_keywords('SET ')
        if node.decl.alias:
            self._write_keywords('ALIAS ')
        self.visit_ModuleAliasDecl(node.decl)

    def visit_SessionResetAllAliases(
        self, node: qlast.SessionResetAllAliases
    ) -> None:
        self._write_keywords('RESET ALIAS *')

    def visit_SessionResetModule(self, node: qlast.SessionResetModule) -> None:
        self._write_keywords('RESET MODULE')

    def visit_SessionResetAliasDecl(
        self, node: qlast.SessionResetAliasDecl
    ) -> None:
        self._write_keywords('RESET ALIAS ')
        self.write(node.alias)

    def visit_StartTransaction(self, node: qlast.StartTransaction) -> None:
        self._write_keywords('START TRANSACTION')

        mods = []

        if node.isolation is not None:
            mods.append(f'ISOLATION {node.isolation.value}')

        if node.access is not None:
            mods.append(node.access.value)

        if node.deferrable is not None:
            mods.append(node.deferrable.value)

        if mods:
            self._write_keywords(' ' + ', '.join(mods))

    def visit_RollbackTransaction(
        self, node: qlast.RollbackTransaction
    ) -> None:
        self._write_keywords('ROLLBACK')

    def visit_CommitTransaction(self, node: qlast.CommitTransaction) -> None:
        self._write_keywords('COMMIT')

    def visit_DeclareSavepoint(self, node: qlast.DeclareSavepoint) -> None:
        self._write_keywords('DECLARE SAVEPOINT ')
        self.write(node.name)

    def visit_RollbackToSavepoint(
        self, node: qlast.RollbackToSavepoint
    ) -> None:
        self._write_keywords('ROLLBACK TO SAVEPOINT ')
        self.write(node.name)

    def visit_ReleaseSavepoint(self, node: qlast.ReleaseSavepoint) -> None:
        self._write_keywords('RELEASE SAVEPOINT ')
        self.write(node.name)

    def visit_DescribeStmt(self, node: qlast.DescribeStmt) -> None:
        self._write_keywords('DESCRIBE ')
        if isinstance(node.object, qlast.DescribeGlobal):
            self.write(node.object.to_edgeql())
        else:
            self.visit(node.object)
        if node.language:
            self._write_keywords(' AS ')
            self.write(node.language)
        if node.options:
            self.write(' ')
            self.visit(node.options)

    def visit_Options(self, node: qlast.Options) -> None:
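        first = True
        for opt in node.options.values():
            if isinstance(opt, qlast.OptionFlag) and not opt.val:
                continue
            # Separate consecutive options with a single space; the
            # caller writes any whitespace before the first option.
            if not first:
                self.write(' ')
            first = False

            self.write(opt.name)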

    # SDL nodes
    def copy_generator(self) -> EdgeQLSourceGenerator:
        return self.__class__(
            indent_with=self.indent_with,
            add_line_information=self.add_line_information,
            pretty=self.pretty,
            unsorted=self.unsorted,
            sdlmode=True,
            descmode=self.descmode,
            limit_ref_classes=self.limit_ref_classes
        )

    def generate_isolated_text(self, node: qlast.Base) -> str:
        sdl_codegen = self.copy_generator()
        sdl_codegen.visit(node)
        return ''.join(sdl_codegen.result)

    def visit_Schema(self, node: qlast.Schema) -> None:
        sdl_codegen = self.copy_generator()
        sdl_codegen.indentation = self.indentation
        sdl_codegen.current_line = self.current_line
        sdl_codegen.visit_list(node.declarations, terminator=';')
        self.result.extend(sdl_codegen.result)

    def visit_ModuleDeclaration(self, node: qlast.ModuleDeclaration) -> None:
        self._write_keywords('module ')
        # the name is always unqualified here
        self.write(ident_to_str(node.name.name))
        self.write('{')
        self._block_ws(1)
        self.visit_list(node.declarations, terminator=';')
        self._block_ws(-1)
        self.write('}')

    @classmethod
    def to_source(  # type: ignore
        cls,
        node: qlast.Base | Sequence[qlast.Base],
        indent_with: str = ' ' * 4,
        add_line_information: bool = False,
        pretty: bool = True,
        sdlmode: bool = False,
        descmode: bool = False,
        # Uppercase keywords for backwards compatibility with older migrations.
        uppercase: bool = False,
        limit_ref_classes:
            Optional[AbstractSet[qltypes.SchemaObjectClass]] = None,
        unsorted: bool = False,
    ) -> str:
        if isinstance(node, (list, tuple)):
            for n in node:
                _fix_parent_links(n)
        else:
            assert isinstance(node, qlast.Base)
            _fix_parent_links(node)

        return super().to_source(
            node,
            indent_with,
            add_line_information,
            pretty,
            sdlmode=sdlmode,
            descmode=descmode,
            uppercase=uppercase,
            limit_ref_classes=limit_ref_classes,
            unsorted=unsorted,
        )


def _fix_parent_links(node: qlast.Base) -> qlast.Base:
    # NOTE: Do not use this legacy function in new code!
    # Using AST.parent is an anti-pattern. Instead write code
    # that uses singledispatch and maintains a proper context.

    node._parent = None  # type: ignore

    for _field, value in base.iter_fields(node):
        if isinstance(value, dict):
            for n in value.values():
                if base.is_ast_node(n):
                    _fix_parent_links(n)
                    n._parent = node

        elif typeutils.is_container(value):
            for n in value:
                if base.is_ast_node(n):
                    _fix_parent_links(n)
                    n._parent = node

        elif base.is_ast_node(value):
            _fix_parent_links(value)
            value._parent = node

    return node


generate_source = EdgeQLSourceGenerator.to_source


================================================
FILE: edb/edgeql/compiler/__init__.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""EdgeQL to IR compiler.

The purpose of this compilation phase is to produce a canonical, self-contained
representation of an EdgeQL expression, aka the IR.  The validity of the
expression and other schema-level checks and resolutions happen at this stage.
Once the IR generation is successful, the expression is considered valid.

The finalized IR consists of two tree structures: the expression tree and
the scope tree.

The *expression tree* is, essentially, another AST form that generally
resembles the overall shape of the original EdgeQL AST annotated with type
information and other metadata that is necessary to compile the IR into the
backend query language.  The *scope tree* tracks the visibility of variables
and determines how the aggregation functions are arranged in the expression.

Every EdgeQL expression is essentially a giant functional map-reduce
construct, or, Pythonically, a bunch of nested set comprehensions.
In those terms, the expression tree encodes expressions inside comprehensions,
and the scope tree determines how the comprehensions are nested, and at which
comprehension level the variables are defined.

The :mod:`ir.ast` and the :mod:`ir.scopetree` modules have more comments on
the organization of the IR expression and scope trees, respectively.

Operation
---------

The compiler has several entry points, all of which are in this file.  Each entry
point sets the compilation context and then calls the generic compilation
dispatch.  The compilation process is a straightforward EdgeQL AST traversal,
where most AST nodes have a dedicated handler function, and the routing is
done by singledispatch based on the AST node type.
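
The dispatch scheme described above can be illustrated with a minimal,
self-contained sketch.  The node classes and ``compile_node`` below are
hypothetical stand-ins chosen for illustration; they are not the actual
``qlast`` nodes or the real ``dispatch.compile`` function:

```python
import functools
from dataclasses import dataclass


# Hypothetical AST node classes (illustration only).
@dataclass
class Node:
    pass


@dataclass
class Constant(Node):
    value: int


@dataclass
class BinOp(Node):
    left: Node
    op: str
    right: Node


@functools.singledispatch
def compile_node(node: Node, *, ctx: dict) -> str:
    # Fallback used when no handler is registered for the node type.
    raise NotImplementedError(f'no handler for {type(node).__name__}')


@compile_node.register
def _(node: Constant, *, ctx: dict) -> str:
    return str(node.value)


@compile_node.register
def _(node: BinOp, *, ctx: dict) -> str:
    # Handlers recurse through the same dispatch function, so the
    # traversal follows the shape of the tree.
    left = compile_node(node.left, ctx=ctx)
    right = compile_node(node.right, ctx=ctx)
    return f'({left} {node.op} {right})'
```

Each handler receives the node plus a context argument, mirroring the
``compile(tree, ctx=ctx)`` call shape used by the entry points in this file.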

Context
-------

The compilation context object is passed to the vast majority of the compiler
functions and contains the information necessary to correctly process an AST
node in a given situation.  It is organized as a stack that resembles a
ChainMap, albeit the elements are objects instead of dicts, and the chaining
logic is controlled by the context itself.  See context.py for details.
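
As a rough illustration of a ChainMap-like stack whose elements are
objects, consider the sketch below.  The ``Context`` and ``ContextLevel``
classes here are simplified, hypothetical versions written for this
example; the real definitions live in context.py and differ in detail:

```python
from typing import Optional


class ContextLevel:
    # Hypothetical level object: each new level decides for itself
    # which attributes are copied from the previous level.
    def __init__(self, prevlevel: Optional['ContextLevel'] = None) -> None:
        if prevlevel is None:
            self.modaliases: dict = {}
            self.implicit_limit = 0
        else:
            # Copied, so changes in this level do not leak outwards.
            self.modaliases = dict(prevlevel.modaliases)
            # Plain value, inherited as is.
            self.implicit_limit = prevlevel.implicit_limit


class Context:
    # Hypothetical stack of levels; the top of the stack is "current".
    def __init__(self) -> None:
        self.stack = [ContextLevel()]

    @property
    def current(self) -> ContextLevel:
        return self.stack[-1]

    def push(self) -> ContextLevel:
        level = ContextLevel(self.current)
        self.stack.append(level)
        return level

    def pop(self) -> None:
        self.stack.pop()
```

Pushing a level before descending into a sub-expression and popping it
afterwards keeps local changes, such as an extra module alias, from
affecting the enclosing expression.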

Organization
------------

The compiler code is organized into the following modules (in rough order
of control flow):

__init__.py
    This file.  Contains the compiler entry points that initialize
    the compilation context and call into compilation dispatch.

stmt.py
    Handlers for statement expressions, like ``SELECT``, ``INSERT``.

expr.py
    Handlers for the majority of expressions that aren't statements or
    that are handled elsewhere.

func.py
    Handlers for function calls and operator expressions.

casts.py
    Handlers for type cast expressions.

clauses.py
    Handlers for common statement clauses like ``FILTER`` and ``ORDER BY``.

polyres.py
    Logic for function, operator, and cast lookup via multiple dispatch
    and generic type specialization.

config.py
    Handlers for ``CONFIGURE`` commands.

setgen.py
    Functions to generate ``ir.ast.Set`` nodes and process path expressions.

viewgen.py
    Functions that process shape expressions into view types.

typegen.py
    Helpers for type expressions.

context.py
    Compilation context definition.

stmtctx.py
    Functions to set up the overall compilation context as well as finalize
    the result IR.

pathctx.py
    PathId and scope helpers.

schemactx.py
    Helpers that interface with the schema, such as object lookup and
    derivation.

astutils.py
    Various helpers for EdgeQL AST analysis.

dispatch.py
    Compiler singledispatch decorator (separate module for ease of import).

"""


from __future__ import annotations
from typing import (
    Any,
    Callable,
    Optional,
    AbstractSet,
    Mapping,
    Sequence,
    cast,
    overload,
    TYPE_CHECKING,
)

# WARNING: this package is in a tight import loop with various modules
# in edb.schema, so no direct imports from either this package or
# edb.schema are allowed at the top-level.  If absolutely necessary,
# use the lazy-loading mechanism.

import functools

from edb.edgeql import ast as qlast
from edb.edgeql import codegen as qlcodegen
from edb.edgeql import qltypes
from edb.edgeql import parser as qlparser

from edb.common import debug

from .options import CompilerOptions as CompilerOptions  # "as" for reexport

if TYPE_CHECKING:
    from edb.schema import schema as s_schema

    from edb.ir import ast as irast
    from edb.ir import staeval as ireval

    from . import dispatch as dispatch_mod
    from . import inference as inference_mod
    from . import normalization as norm_mod
    from . import stmtctx as stmtctx_mod
else:
    # Modules will be loaded lazily in _load().
    dispatch_mod = None
    inference_mod = None
    irast = None
    ireval = None
    norm_mod = None
    stmtctx_mod = None


#: Compiler modules lazy-load guard.
_LOADED = False


def compiler_entrypoint[Tf: Callable[..., Any]](func: Tf) -> Tf:
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if not _LOADED:
            _load()
        return func(*args, **kwargs)

    return cast(Tf, wrapper)


@overload
def compile_ast_to_ir(
    tree: qlast.Expr | qlast.Command,
    schema: s_schema.Schema,
    *,
    script_info: Optional[irast.ScriptInfo] = None,
    options: Optional[CompilerOptions] = None,
) -> irast.Statement:
    pass


@overload
def compile_ast_to_ir(
    tree: qlast.ConfigOp,
    schema: s_schema.Schema,
    *,
    script_info: Optional[irast.ScriptInfo] = None,
    options: Optional[CompilerOptions] = None,
) -> irast.ConfigCommand:
    pass


@overload
def compile_ast_to_ir(
    tree: qlast.Base,
    schema: s_schema.Schema,
    *,
    script_info: Optional[irast.ScriptInfo] = None,
    options: Optional[CompilerOptions] = None,
) -> irast.Statement | irast.ConfigCommand:
    pass


@compiler_entrypoint
def compile_ast_to_ir(
    tree: qlast.Base,
    schema: s_schema.Schema,
    *,
    script_info: Optional[irast.ScriptInfo] = None,
    options: Optional[CompilerOptions] = None,
) -> irast.Statement | irast.ConfigCommand:
    """Compile given EdgeQL AST into Gel IR.

    This is the normal compiler entry point.  It assumes that *tree*
    represents a complete statement.

    Args:
        tree:
            EdgeQL AST.

        schema:
            Schema instance.  Must contain definitions for objects
            referenced by the AST *tree*.

        options:
            An optional :class:`edgeql.compiler.options.CompilerOptions`
            instance specifying compilation options.

        script_info:
            An optional :class:`ir.ast.ScriptInfo` describing the script
            that *tree* belongs to.  If omitted, it is computed from
            *tree* alone.

    Returns:
        An instance of :class:`ir.ast.Statement`, or an instance of
        :class:`ir.ast.ConfigCommand` if *tree* is a configuration command.
    """
    if options is None:
        options = CompilerOptions()

    if debug.flags.edgeql_compile or debug.flags.edgeql_compile_edgeql_text:
        debug.header('EdgeQL Text')
        debug.dump_code(qlcodegen.generate_source(tree, pretty=True))

    if debug.flags.edgeql_compile or debug.flags.edgeql_compile_edgeql_ast:
        debug.header('Compiler Options')
        debug.dump(options.__dict__)
        debug.header('EdgeQL AST')
        debug.dump(tree, schema=schema)

    ctx = stmtctx_mod.init_context(schema=schema, options=options)

    if isinstance(tree, qlast.Expr) and ctx.implicit_limit:
        tree = qlast.SelectQuery(result=tree, implicit=True)
        tree.limit = qlast.Constant.integer(ctx.implicit_limit)

    if not script_info:
        script_info = stmtctx_mod.preprocess_script([tree], ctx=ctx)

    ctx.env.script_params = script_info.params

    ir_set = dispatch_mod.compile(tree, ctx=ctx)
    ir_expr = stmtctx_mod.fini_expression(ir_set, ctx=ctx)

    if debug.flags.edgeql_compile or debug.flags.edgeql_compile_scope:
        debug.header('Scope Tree')
        print(ctx.path_scope.pdebugformat())

        # Also build and dump a mapping from scope ids to
        # paths that appear directly at them.
        scopes: dict[int, set[irast.PathId]] = {
            k: set() for k in
            sorted(node.unique_id
                   for node in ctx.path_scope.descendants
                   if node.unique_id)
        }
        for ir_set in ctx.env.set_types:
            if ir_set.path_scope_id and ir_set.path_scope_id in scopes:
                scopes[ir_set.path_scope_id].add(ir_set.path_id)
        debug.dump(scopes)

    if debug.flags.edgeql_compile or debug.flags.edgeql_compile_ir:
        debug.header('Gel IR')
        debug.dump(ir_expr, schema=getattr(ir_expr, 'schema', None))

    return ir_expr


@compiler_entrypoint
def compile_ast_fragment_to_ir(
    tree: qlast.Base,
    schema: s_schema.Schema,
    *,
    options: Optional[CompilerOptions] = None,
) -> irast.Statement:
    """Compile given EdgeQL AST fragment into Gel IR.

    Unlike :func:`~compile_ast_to_ir` above, this does not assume
    that the AST *tree* is a complete statement.  The expression
    doesn't even have to resolve to a specific type.

    Args:
        tree:
            EdgeQL AST fragment.

        schema:
            Schema instance.  Must contain definitions for objects
            referenced by the AST *tree*.

        options:
            An optional :class:`edgeql.compiler.options.CompilerOptions`
            instance specifying compilation options.

    Returns:
        An instance of :class:`ir.ast.Statement`.
    """
    if options is None:
        options = CompilerOptions()

    ctx = stmtctx_mod.init_context(schema=schema, options=options)
    ir_set = dispatch_mod.compile(tree, ctx=ctx)

    result_type = ctx.env.set_types[ir_set]

    return irast.Statement(
        expr=ir_set,
        schema=ctx.env.schema,
        stype=result_type,
        dml_exprs=ctx.env.dml_exprs,
        views={},
        params=[],
        globals=[],
        required_permissions=set(),
        server_param_conversions=[],
        server_param_conversion_params=[],
        # These values are nonsensical, but ideally the caller does not care
        cardinality=qltypes.Cardinality.UNKNOWN,
        multiplicity=qltypes.Multiplicity.EMPTY,
        volatility=qltypes.Volatility.Volatile,
        view_shapes={},
        view_shapes_metadata={},
        schema_refs=frozenset(),
        schema_ref_exprs=None,
        scope_tree=ctx.path_scope,
        type_rewrites={},
        singletons=[],
        triggers=(),
        warnings=tuple(ctx.env.warnings),
        unsafe_isolation_dangers=tuple(ctx.env.unsafe_isolation_dangers),
    )


@compiler_entrypoint
def preprocess_script(
    stmts: Sequence[qlast.Base],
    schema: s_schema.Schema,
    *,
    options: CompilerOptions,
) -> irast.ScriptInfo:
    ctx = stmtctx_mod.init_context(schema=schema, options=options)
    return stmtctx_mod.preprocess_script(stmts, ctx=ctx)


def evaluate_to_python_val(
    expr: str,
    schema: s_schema.Schema,
    *,
    modaliases: Optional[Mapping[Optional[str], str]] = None,
) -> Any:
    """Evaluate the given EdgeQL string as a constant expression.

    Args:
        expr:
            EdgeQL expression as a string.

        schema:
            Schema instance.  Must contain definitions for objects
            referenced by *expr*.

        modaliases:
            Module name resolution table.  Useful when this EdgeQL
            expression is part of some other construct, such as a
            DDL statement.

    Returns:
        The result of the evaluation as a Python value.

    Raises:
        If the expression is not constant, or is otherwise not supported by
        the const evaluator, the function will raise
        :exc:`ir.staeval.UnsupportedExpressionError`.
    """
    tree = qlparser.parse_fragment(expr)
    return evaluate_ast_to_python_val(tree, schema, modaliases=modaliases)


def evaluate_ir_statement_to_python_val(
    ir: irast.Statement,
) -> Any:
    """Evaluate the given EdgeQL IR AST as a constant expression.

    Args:
        ir:
            EdgeQL IR Statement AST.

    Returns:
        The result of the evaluation as a Python value.

    Raises:
        If the expression is not constant, or is otherwise not supported by
        the const evaluator, the function will raise
        :exc:`ir.staeval.UnsupportedExpressionError`.
    """
    return ireval.evaluate_to_python_val(ir.expr, schema=ir.schema)


def evaluate_ast_to_python_val_and_ir(
    tree: qlast.Base,
    schema: s_schema.Schema,
    *,
    modaliases: Optional[Mapping[Optional[str], str]] = None,
) -> tuple[Any, irast.Statement]:
    """Evaluate the given EdgeQL AST as a constant expression.

    Args:
        tree:
            EdgeQL AST.

        schema:
            Schema instance.  Must contain definitions for objects
            referenced by AST *tree*.

        modaliases:
            Module name resolution table.  Useful when this EdgeQL
            expression is part of some other construct, such as a
            DDL statement.

    Returns:
        The result of the evaluation as a Python value and the associated IR.

    Raises:
        If the expression is not constant, or is otherwise not supported by
        the const evaluator, the function will raise
        :exc:`ir.staeval.UnsupportedExpressionError`.
    """
    if modaliases is None:
        modaliases = {}
    ir = compile_ast_fragment_to_ir(
        tree,
        schema,
        options=CompilerOptions(
            modaliases=modaliases,
        ),
    )
    return ireval.evaluate_to_python_val(ir.expr, schema=ir.schema), ir


def evaluate_ast_to_python_val(
    tree: qlast.Base,
    schema: s_schema.Schema,
    *,
    modaliases: Optional[Mapping[Optional[str], str]] = None,
) -> Any:
    """Evaluate the given EdgeQL AST as a constant expression.

    Args:
        tree:
            EdgeQL AST.

        schema:
            Schema instance.  Must contain definitions for objects
            referenced by AST *tree*.

        modaliases:
            Module name resolution table.  Useful when this EdgeQL
            expression is part of some other construct, such as a
            DDL statement.

    Returns:
        The result of the evaluation as a Python value.

    Raises:
        If the expression is not constant, or is otherwise not supported by
        the const evaluator, the function will raise
        :exc:`ir.staeval.UnsupportedExpressionError`.
    """
    return evaluate_ast_to_python_val_and_ir(
        tree, schema, modaliases=modaliases
    )[0]


@compiler_entrypoint
def compile_constant_tree_to_ir(
    const: qlast.BaseConstant,
    schema: s_schema.Schema,
    *,
    styperef: Optional[irast.TypeRef] = None,
) -> irast.Expr:
    """Compile an EdgeQL constant into an IR ConstExpr.

    Args:
        const:
            An EdgeQL AST representing a constant.

        schema:
            A schema instance.  Must contain the definition of the
            constant type.

        styperef:
            Optionally overrides an IR type descriptor for the returned
            ConstExpr.  If not specified, the inferred type of the constant
            is used.

    Returns:
        An instance of :class:`ir.ast.ConstExpr` representing the
        constant.
    """
    ctx = stmtctx_mod.init_context(schema=schema, options=CompilerOptions())
    if not isinstance(const, qlast.BaseConstant):
        raise ValueError(f'unexpected input: {const!r} is not a constant')

    ir_set = dispatch_mod.compile(const, ctx=ctx)
    assert isinstance(ir_set, irast.Set)
    result = ir_set.expr
    assert isinstance(result, irast.BaseConstant)
    if styperef is not None and result.typeref.id != styperef.id:
        result = type(result)(value=result.value, typeref=styperef)

    return result


@compiler_entrypoint
def normalize(
    tree: qlast.Base,
    *,
    schema: s_schema.Schema,
    modaliases: Mapping[Optional[str], str],
    localnames: AbstractSet[str] = frozenset(),
) -> None:
    """Normalize the given AST *tree* by explicitly qualifying identifiers.

    This helper takes an arbitrary EdgeQL AST tree together with the current
    module alias mapping and produces an equivalent expression, in which
    all identifiers representing schema object references are properly
    qualified with the module name.

    NOTE: the tree is mutated *in-place*.
    """
    return norm_mod.normalize(
        tree,
        schema=schema,
        modaliases=modaliases,
        localnames=localnames,
    )


@compiler_entrypoint
def renormalize_compat(
    tree: qlast.Base_T,
    orig_text: str,
    *,
    schema: s_schema.Schema,
    localnames: AbstractSet[str] = frozenset(),
) -> qlast.Base_T:
    """Renormalize an expression normalized with imprint_expr_context().

    This helper takes the original, unmangled expression, an EdgeQL AST
    tree of the same expression mangled with `imprint_expr_context()`
    (which injects extra WITH MODULE clauses), and produces a normalized
    expression with explicitly qualified identifiers instead.  Old dumps
    are the main user of this facility.
    """
    return norm_mod.renormalize_compat(
        tree,
        orig_text,
        schema=schema,
        localnames=localnames,
    )


def _load() -> None:
    """Load the compiler modules.  This is done once per process."""

    global _LOADED
    global dispatch_mod, inference_mod, irast, ireval, norm_mod, stmtctx_mod

    from edb.ir import ast as _irast
    from edb.ir import staeval as _ireval

    from . import expr as _expr_compiler  # NOQA
    from . import config as _config_compiler  # NOQA
    from . import stmt as _stmt_compiler  # NOQA

    from . import dispatch
    from . import inference
    from . import normalization
    from . import stmtctx

    dispatch_mod = dispatch
    inference_mod = inference
    irast = _irast
    ireval = _ireval
    norm_mod = normalization
    stmtctx_mod = stmtctx
    _LOADED = True


================================================
FILE: edb/edgeql/compiler/astutils.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""EdgeQL compiler helpers for AST classification and basic transforms."""


from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional, TYPE_CHECKING

from edb.common import ast
from edb.common import view_patterns

from edb.edgeql import ast as qlast
from edb.edgeql import qltypes

from edb.schema import name as sn
from edb.schema import functions as s_func

if TYPE_CHECKING:
    from . import context


def extend_binop(
    binop: Optional[qlast.Expr],
    *exprs: qlast.Expr,
    op: str = 'AND',
) -> qlast.Expr:
    exprlist = list(exprs)

    if binop is None:
        result = exprlist.pop(0)
    else:
        result = binop

    for expr in exprlist:
        if expr is not None and expr is not result:
            result = qlast.BinOp(
                left=result,
                right=expr,
                op=op,
            )

    return result


def ensure_ql_query(expr: qlast.Expr) -> qlast.Query:

    # a sanity check added after refactoring AST
    assert isinstance(expr, qlast.Expr)

    if not isinstance(expr, qlast.Query):
        expr = qlast.SelectQuery(
            result=expr,
            implicit=True,
        )
    return expr


def ensure_ql_select(expr: qlast.Expr) -> qlast.SelectQuery:
    if not isinstance(expr, qlast.SelectQuery):
        expr = qlast.SelectQuery(
            result=expr,
            implicit=True,
        )
    return expr


def is_ql_empty_set(expr: qlast.Expr) -> bool:
    return isinstance(expr, qlast.Set) and len(expr.elements) == 0


def is_ql_empty_array(expr: qlast.Expr) -> bool:
    return isinstance(expr, qlast.Array) and len(expr.elements) == 0


def is_nontrivial_shape_element(shape_el: qlast.ShapeElement) -> bool:
    return bool(
        shape_el.where
        or shape_el.orderby
        or shape_el.offset
        or shape_el.limit
        or shape_el.compexpr
        or (
            shape_el.elements and
            any(is_nontrivial_shape_element(el) for el in shape_el.elements)
        )
    )


def extend_path(expr: qlast.Expr, field: str) -> qlast.Path:
    step = qlast.Ptr(name=field)

    if isinstance(expr, qlast.Path):
        return qlast.Path(
            steps=[*expr.steps, step],
            partial=expr.partial,
        )
    else:
        return qlast.Path(steps=[expr, step])


@dataclass
class Params:
    cast_params: list[
        tuple[qlast.TypeCast, dict[Optional[str], str]]
    ] = field(default_factory=list)
    shaped_params: list[
        tuple[qlast.QueryParameter, qlast.Shape]
    ] = field(default_factory=list)
    loose_params: list[qlast.QueryParameter] = field(default_factory=list)


class FindParams(ast.NodeVisitor):
    """Visitor to find all the parameters.

    The annoying bit is that we also need all the modaliases.
    """
    def __init__(self, modaliases: dict[Optional[str], str]) -> None:
        super().__init__()
        self.params: Params = Params()
        self.modaliases = modaliases

    def visit_Command(self, n: qlast.Command) -> None:
        self._visit_with_stmt(n)

    def visit_Query(self, n: qlast.Query) -> None:
        self._visit_with_stmt(n)

    def _visit_with_stmt(self, n: qlast.Statement) -> None:
        old = self.modaliases
        for with_entry in (n.aliases or ()):
            if isinstance(with_entry, qlast.ModuleAliasDecl):
                self.modaliases = self.modaliases.copy()
                self.modaliases[with_entry.alias] = with_entry.module
            else:
                self.visit(with_entry)

        # The memoization will prevent us from redoing the aliases
        self.generic_visit(n)
        self.modaliases = old

    def visit_TypeCast(self, n: qlast.TypeCast) -> None:
        if isinstance(n.expr, qlast.QueryParameter):
            self.params.cast_params.append((n, self.modaliases))
        elif isinstance(n.expr, qlast.Shape):
            if isinstance(n.expr.expr, qlast.QueryParameter):
                self.params.shaped_params.append((n.expr.expr, n.expr))
            else:
                self.generic_visit(n)
        else:
            self.generic_visit(n)

    def visit_QueryParameter(self, n: qlast.QueryParameter) -> None:
        self.params.loose_params.append(n)

    def visit_CreateFunction(self, n: qlast.CreateFunction) -> None:
        pass

    def visit_CreateConstraint(self, n: qlast.CreateConstraint) -> None:
        pass


def find_parameters(
    ql: qlast.Base, modaliases: dict[Optional[str], str]
) -> Params:
    """Get all query parameters"""
    v = FindParams(modaliases)
    v.visit(ql)
    return v.params


class alias_view(
    view_patterns.ViewPattern[tuple[str, list[qlast.PathElement]]],
    targets=(qlast.Base,),
):
    @staticmethod
    def match(obj: object) -> tuple[str, list[qlast.PathElement]]:
        match obj:
            case qlast.Path(
                steps=[qlast.ObjectRef(module=None, name=alias), *rest],
                partial=False,
            ):
                return alias, rest
        raise view_patterns.NoMatch


def contains_dml(
    ql_expr: qlast.Base,
    *,
    ctx: context.ContextLevel,
) -> bool:
    """Check whether an expression contains any DML in a subtree."""
    # If this ends up being a perf problem, we can use a visitor
    # directly and cache.
    dml_types = (qlast.InsertQuery, qlast.UpdateQuery, qlast.DeleteQuery)
    if isinstance(ql_expr, dml_types):
        return True

    res = ast.find_children(
        ql_expr, qlast.Base,
        lambda x: (
            isinstance(x, dml_types)
            or (isinstance(x, qlast.IRAnchor) and x.has_dml)
            or (
                isinstance(x, qlast.FunctionCall)
                and any(
                    (
                        func.get_volatility(ctx.env.schema)
                        == qltypes.Volatility.Modifying
                    )
                    for func in _get_functions_from_call(x, ctx=ctx)
                )
            )
        ),
        terminate_early=True,
    )

    return bool(res)


def _get_functions_from_call(
    expr: qlast.FunctionCall,
    *,
    ctx: context.ContextLevel,
) -> tuple[s_func.Function, ...]:
    funcname: sn.Name
    if isinstance(expr.func, str):
        funcname = sn.UnqualName(expr.func)
    else:
        funcname = sn.QualName(*expr.func)

    return s_func.lookup_functions(
        funcname,
        default=(),
        module_aliases=ctx.modaliases,
        schema=ctx.env.schema,
    )


================================================
FILE: edb/edgeql/compiler/casts.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""EdgeQL compiler routines for type casts."""


from __future__ import annotations

import json
from typing import (
    Optional,
    Iterable,
    Mapping,
    cast,
    TYPE_CHECKING,
)

from edb import errors

from edb.common import parsing

from edb.ir import ast as irast
from edb.ir import utils as irutils

from edb.schema import casts as s_casts
from edb.schema import constraints as s_constr
from edb.schema import functions as s_func
from edb.schema import indexes as s_indexes
from edb.schema import name as sn
from edb.schema import scalars as s_scalars
from edb.schema import types as s_types
from edb.schema import utils as s_utils
from edb.schema import name as s_name

from edb.edgeql import ast as qlast
from edb.edgeql import qltypes as ft

from . import astutils
from . import context
from . import dispatch
from . import pathctx
from . import polyres
from . import setgen
from . import typegen
from . import viewgen

if TYPE_CHECKING:
    from edb.schema import schema as s_schema


def compile_cast(
    ir_expr: irast.Set | irast.Expr,
    new_stype: s_types.Type,
    *,
    span: Optional[parsing.Span],
    ctx: context.ContextLevel,
    cardinality_mod: Optional[qlast.CardinalityModifier] = None,
) -> irast.Set:

    if new_stype.is_polymorphic(ctx.env.schema) and span is not None:
        # If we have no span we don't know whether this is a direct cast
        # or some implicit cast being processed.
        raise errors.QueryError(
            f'cannot cast into generic type '
            f'{new_stype.get_displayname(ctx.env.schema)!r}',
            hint="Please ensure you don't use generic "
                 '"any" types or abstract scalars.',
            span=span)

    if (
        isinstance(ir_expr, irast.Set)
        and isinstance(ir_expr.expr, irast.EmptySet)
    ):
        # For the common case of casting an empty set, we simply
        # generate a new empty set node of the requested type.
        return setgen.new_empty_set(
            stype=new_stype,
            alias=ir_expr.path_id.target_name_hint.name,
            ctx=ctx,
            span=ir_expr.span)

    if isinstance(new_stype, s_types.Array) and (
        irutils.is_untyped_empty_array_expr(ir_expr)
        or (
            isinstance(ir_expr, irast.Set)
            and irutils.is_untyped_empty_array_expr(
                irutils.unwrap_set(ir_expr).expr)
        )
    ):
        # Ditto for empty arrays.
        new_typeref = typegen.type_to_typeref(new_stype, ctx.env)
        return setgen.ensure_set(
            irast.Array(elements=[], typeref=new_typeref), ctx=ctx)

    ir_set = setgen.ensure_set(ir_expr, ctx=ctx)
    orig_stype = setgen.get_set_type(ir_set, ctx=ctx)

    if new_stype.is_polymorphic(ctx.env.schema):
        raise errors.QueryError(
            'expression returns value of indeterminate type',
            span=span)

    if (orig_stype == new_stype and
            cardinality_mod is not qlast.CardinalityModifier.Required):
        return ir_set
    if orig_stype.is_object_type() and new_stype.is_object_type():
        # Object types cannot be cast between themselves,
        # as cast is a _constructor_ operation, and the only
        # valid way to construct an object is to INSERT it.
        raise errors.QueryError(
            f'cannot cast object type '
            f'{orig_stype.get_displayname(ctx.env.schema)!r} '
            f'to {new_stype.get_displayname(ctx.env.schema)!r}, use '
            f'`...[IS {new_stype.get_displayname(ctx.env.schema)}]` instead',
            span=span)

    # The only valid object type cast other than <uuid> is from anytype,
    # and thus it must be an empty set.
    if (
        orig_stype.is_any(ctx.env.schema)
        and new_stype.is_object_type()
    ):
        return setgen.new_empty_set(
            stype=new_stype,
            ctx=ctx,
            span=ir_expr.span)

    uuid_t = ctx.env.get_schema_type_and_track(sn.QualName('std', 'uuid'))
    if (
        orig_stype.issubclass(ctx.env.schema, uuid_t)
        and new_stype.is_object_type()
    ):
        return _find_object_by_id(ir_expr, new_stype, ctx=ctx)

    json_t = ctx.env.get_schema_type_and_track(sn.QualName('std', 'json'))
    if (
        isinstance(ir_set.expr, irast.Array)
        and (
            isinstance(new_stype, s_types.Array)
            or new_stype.issubclass(ctx.env.schema, json_t)
        )
    ):
        cast_element = ('array', None)
        if ctx.collection_cast_info is not None:
            ctx.collection_cast_info.path_elements.append(cast_element)

        result = _cast_array_literal(
            ir_set, orig_stype, new_stype, span=span, ctx=ctx)

        if ctx.collection_cast_info is not None:
            ctx.collection_cast_info.path_elements.pop()

        return result

    if orig_stype.is_tuple(ctx.env.schema):
        return _cast_tuple(
            ir_set, orig_stype, new_stype, span=span, ctx=ctx)

    if isinstance(orig_stype, s_types.Array):
        if not s_types.is_type_compatible(
            orig_stype, new_stype, schema=ctx.env.schema
        ) and (
            not isinstance(new_stype, s_types.Array)
            and isinstance(
                (el_type := orig_stype.get_subtypes(ctx.env.schema)[0]),
                s_scalars.ScalarType,
            )
        ):
            # We're not casting to another array, so for purposes of matching
            # the right cast we want to reduce orig_stype to an array of the
            # built-in base type as that's what the cast will actually
            # expect.
            ir_set = _cast_to_base_array(
                ir_set, el_type, orig_stype, ctx=ctx)

        if isinstance(new_stype, s_types.Array):
            cast_element = ('array', None)
            if ctx.collection_cast_info is not None:
                ctx.collection_cast_info.path_elements.append(cast_element)

            result = _cast_array(
                ir_set, orig_stype, new_stype, span=span, ctx=ctx)

            if ctx.collection_cast_info is not None:
                ctx.collection_cast_info.path_elements.pop()

            return result

        else:
            return _cast_array(
                ir_set, orig_stype, new_stype, span=span, ctx=ctx)

    if isinstance(orig_stype, s_types.Range):
        if s_types.is_type_compatible(
            orig_stype, new_stype, schema=ctx.env.schema
        ):
            # Casting between compatible types is unnecessary. It is important
            # to catch things like RangeExprAlias and Range being of the same
            # type and not needing a cast.
            return ir_set
        else:
            if isinstance(new_stype, s_types.MultiRange):
                # For multirange target type we might need to first upcast the
                # range into corresponding multirange and then do a separate
                # cast for the subtype.
                if (
                    (ost := orig_stype.get_subtypes(schema=ctx.env.schema)) !=
                        new_stype.get_subtypes(schema=ctx.env.schema)
                ):
                    ctx.env.schema, mr_stype = \
                        s_types.MultiRange.from_subtypes(ctx.env.schema, ost)
                    ir_set = _inheritance_cast_to_ir(
                        ir_set, orig_stype, mr_stype,
                        cardinality_mod=cardinality_mod, ctx=ctx)
                    return _cast_multirange(
                        ir_set, mr_stype, new_stype, span=span, ctx=ctx)

                else:
                    # The subtypes match, so this is a direct upcast from
                    # range to multirange.
                    return _inheritance_cast_to_ir(
                        ir_set, orig_stype, new_stype,
                        cardinality_mod=cardinality_mod, ctx=ctx)

            return _cast_range(
                ir_set, orig_stype, new_stype, span=span, ctx=ctx)

    if orig_stype.is_multirange():
        if s_types.is_type_compatible(
            orig_stype, new_stype, schema=ctx.env.schema
        ):
            # Casting between compatible types is unnecessary. It is important
            # to catch things like MultiRangeExprAlias and MultiRange being of
            # the same type and not needing a cast.
            return ir_set
        else:
            return _cast_multirange(
                ir_set, orig_stype, new_stype, span=span, ctx=ctx)

    if orig_stype.issubclass(ctx.env.schema, new_stype):
        # The new type is a supertype of the old type,
        # and is always a wider domain, so we simply reassign
        # the stype.
        return _inheritance_cast_to_ir(
            ir_set, orig_stype, new_stype,
            cardinality_mod=cardinality_mod, ctx=ctx)

    if (
        new_stype.issubclass(ctx.env.schema, orig_stype)
        or _has_common_concrete_scalar(orig_stype, new_stype, ctx=ctx)
    ):
        # The new type is a subtype, or a sibling type sharing a common
        # ancestor, so it may have a more restrictive domain;
        # generate a cast call.
        return _inheritance_cast_to_ir(
            ir_set, orig_stype, new_stype,
            cardinality_mod=cardinality_mod, ctx=ctx)

    if (
        new_stype.issubclass(ctx.env.schema, json_t)
        and ir_set.path_id.is_objtype_path()
    ):
        # JSON casts of objects are special: we want the full shape
        # and not just an identity.
        viewgen.late_compile_view_shapes(ir_set, ctx=ctx)
    elif orig_stype.issubclass(ctx.env.schema, json_t):

        if base_stype := _get_concrete_scalar_base(new_stype, ctx):
            # Casts from json to custom scalars may have special handling.
            # So we turn the type cast json->x into json->base and base->x.
            base_ir = compile_cast(ir_expr, base_stype, span=span, ctx=ctx)

            return compile_cast(
                base_ir,
                new_stype,
                cardinality_mod=cardinality_mod,
                span=span,
                ctx=ctx,
            )

        elif isinstance(
            new_stype, s_types.Array
        ) and not new_stype.get_subtypes(ctx.env.schema)[0].issubclass(
            ctx.env.schema, json_t
        ):
            # Turn casts from json->array<T> into json->array<json>
            # and array<json>->array<T>.
            ctx.env.schema, json_array_typ = s_types.Array.from_subtypes(
                ctx.env.schema, [json_t]
            )
            json_array_ir = compile_cast(
                ir_expr,
                json_array_typ,
                cardinality_mod=cardinality_mod,
                span=span,
                ctx=ctx,
            )
            return compile_cast(
                json_array_ir, new_stype, span=span, ctx=ctx
            )

        elif isinstance(new_stype, s_types.Tuple):
            return _cast_json_to_tuple(
                ir_set,
                orig_stype,
                new_stype,
                cardinality_mod,
                span=span,
                ctx=ctx,
            )

        elif isinstance(new_stype, s_types.Range):
            return _cast_json_to_range(
                ir_set,
                orig_stype,
                new_stype,
                cardinality_mod,
                span=span,
                ctx=ctx,
            )

        elif isinstance(new_stype, s_types.MultiRange):
            return _cast_json_to_multirange(
                ir_set,
                orig_stype,
                new_stype,
                cardinality_mod,
                span=span,
                ctx=ctx,
            )

    # Constraints and indexes require an immutable expression, but the pg
    # cast is only stable. In this specific case, we use a cast wrapper
    # function that is declared to be immutable.
    if orig_stype.is_enum(ctx.env.schema) or new_stype.is_enum(ctx.env.schema):
        objctx = ctx.env.options.schema_object_context
        if objctx in (s_constr.Constraint, s_indexes.Index):

            str_typ = ctx.env.schema.get(
                sn.QualName("std", "str"),
                type=s_types.Type,
            )
            orig_str = orig_stype.issubclass(ctx.env.schema, str_typ)
            new_str = new_stype.issubclass(ctx.env.schema, str_typ)
            if orig_str or new_str:
                return _cast_enum_str_immutable(
                    ir_expr, orig_stype, new_stype, ctx=ctx
                )

    return _compile_cast(
        ir_expr,
        orig_stype,
        new_stype,
        cardinality_mod=cardinality_mod,
        span=span,
        ctx=ctx,
    )


def _has_common_concrete_scalar(
    orig_stype: s_types.Type,
    new_stype: s_types.Type,
    *,
    ctx: context.ContextLevel,
) -> bool:
    schema = ctx.env.schema
    return bool(
        isinstance(orig_stype, s_scalars.ScalarType)
        and isinstance(new_stype, s_scalars.ScalarType)
        and (orig_base := orig_stype.maybe_get_topmost_concrete_base(schema))
        and (new_base := new_stype.maybe_get_topmost_concrete_base(schema))
        and orig_base == new_base
    )


def _get_concrete_scalar_base(
    stype: s_types.Type, ctx: context.ContextLevel
) -> Optional[s_types.Type]:
    """Returns None if stype is not scalar or if it is already topmost"""

    if stype.is_enum(ctx.env.schema):
        return ctx.env.get_schema_type_and_track(sn.QualName('std', 'str'))

    if not isinstance(stype, s_scalars.ScalarType):
        return None
    if topmost := stype.maybe_get_topmost_concrete_base(ctx.env.schema):
        if topmost != stype:
            return topmost
    return None


def _compile_cast(
    ir_expr: irast.Set | irast.Expr,
    orig_stype: s_types.Type,
    new_stype: s_types.Type,
    *,
    span: Optional[parsing.Span],
    ctx: context.ContextLevel,
    cardinality_mod: Optional[qlast.CardinalityModifier],
) -> irast.Set:

    ir_set = setgen.ensure_set(ir_expr, ctx=ctx)
    cast = _find_cast(orig_stype, new_stype, span=span, ctx=ctx)

    if cast is None:
        raise errors.QueryError(
            f'cannot cast '
            f'{orig_stype.get_displayname(ctx.env.schema)!r} to '
            f'{new_stype.get_displayname(ctx.env.schema)!r}',
            span=span or ir_set.span)

    return _cast_to_ir(ir_set, cast, orig_stype, new_stype,
                       cardinality_mod, ctx=ctx)


def _cast_to_ir(
    ir_set: irast.Set,
    cast: s_casts.Cast,
    orig_stype: s_types.Type,
    new_stype: s_types.Type,
    cardinality_mod: Optional[qlast.CardinalityModifier] = None,
    *,
    ctx: context.ContextLevel,
) -> irast.Set:

    orig_typeref = typegen.type_to_typeref(orig_stype, env=ctx.env)
    new_typeref = typegen.type_to_typeref(new_stype, env=ctx.env)
    cast_name = cast.get_name(ctx.env.schema)
    cast_ir = irast.TypeCast(
        expr=ir_set,
        from_type=orig_typeref,
        to_type=new_typeref,
        cardinality_mod=cardinality_mod,
        cast_name=cast_name,
        sql_function=cast.get_from_function(ctx.env.schema),
        sql_cast=cast.get_from_cast(ctx.env.schema),
        sql_expr=bool(cast.get_code(ctx.env.schema)),
        error_message_context=cast_message_context(ctx),
    )

    return setgen.ensure_set(cast_ir, ctx=ctx)


def _inheritance_cast_to_ir(
    ir_set: irast.Set,
    orig_stype: s_types.Type,
    new_stype: s_types.Type,
    *,
    cardinality_mod: Optional[qlast.CardinalityModifier],
    ctx: context.ContextLevel,
) -> irast.Set:

    orig_typeref = typegen.type_to_typeref(orig_stype, env=ctx.env)
    new_typeref = typegen.type_to_typeref(new_stype, env=ctx.env)
    cast_ir = irast.TypeCast(
        expr=ir_set,
        from_type=orig_typeref,
        to_type=new_typeref,
        cardinality_mod=cardinality_mod,
        cast_name=None,
        sql_function=None,
        sql_cast=True,
        sql_expr=False,
        error_message_context=cast_message_context(ctx),
    )

    return setgen.ensure_set(cast_ir, ctx=ctx)


class CastParamListWrapper(s_func.ParameterLikeList):

    def __init__(self, params: Iterable[s_func.ParameterDesc]) -> None:
        self._params = tuple(params)

    def get_by_name(
        self,
        schema: s_schema.Schema,
        name: str,
    ) -> s_func.ParameterDesc:
        raise NotImplementedError

    def as_str(self, schema: s_schema.Schema) -> str:
        raise NotImplementedError

    def find_named_only(
        self,
        schema: s_schema.Schema,
    ) -> Mapping[str, s_func.ParameterDesc]:
        return {}

    def find_variadic(
        self,
        schema: s_schema.Schema,
    ) -> Optional[s_func.ParameterDesc]:
        return None

    def has_polymorphic(
        self,
        schema: s_schema.Schema,
    ) -> bool:
        return False

    def has_objects(
        self,
        schema: s_schema.Schema,
    ) -> bool:
        return False

    def has_set_of(
        self,
        schema: s_schema.Schema,
    ) -> bool:
        return False

    def objects(
        self,
        schema: s_schema.Schema,
    ) -> tuple[s_func.ParameterDesc, ...]:
        return self._params

    def has_required_params(self, schema: s_schema.Schema) -> bool:
        return True

    def get_in_canonical_order(
        self,
        schema: s_schema.Schema,
    ) -> tuple[s_func.ParameterDesc, ...]:
        return self._params


class CastCallableWrapper(s_func.CallableLike):
    # A wrapper around a cast object to make it quack like a callable
    # for the purposes of polymorphic resolution.
    def __init__(self, cast: s_casts.Cast) -> None:
        self._cast = cast

    def has_inlined_defaults(self, schema: s_schema.Schema) -> bool:
        return False

    def get_params(
        self,
        schema: s_schema.Schema,
    ) -> s_func.ParameterLikeList:
        from_type_param = s_func.ParameterDesc(
            num=0,
            name=sn.UnqualName('val'),
            type=self._cast.get_from_type(schema).as_shell(schema),
            typemod=ft.TypeModifier.SingletonType,
            kind=ft.ParameterKind.PositionalParam,
            default=None,
        )

        to_type_param = s_func.ParameterDesc(
            num=0,
            name=sn.UnqualName('_'),
            type=self._cast.get_to_type(schema).as_shell(schema),
            typemod=ft.TypeModifier.SingletonType,
            kind=ft.ParameterKind.PositionalParam,
            default=None,
        )

        return CastParamListWrapper((from_type_param, to_type_param))

    def get_return_type(self, schema: s_schema.Schema) -> s_types.Type:
        return self._cast.get_to_type(schema)

    def get_return_typemod(self, schema: s_schema.Schema) -> ft.TypeModifier:
        return ft.TypeModifier.SingletonType

    def get_verbosename(self, schema: s_schema.Schema) -> str:
        return self._cast.get_verbosename(schema)

    def get_abstract(self, schema: s_schema.Schema) -> bool:
        return False


def _find_cast(
    orig_stype: s_types.Type,
    new_stype: s_types.Type,
    *,
    span: Optional[parsing.Span],
    ctx: context.ContextLevel,
) -> Optional[s_casts.Cast]:

    # Don't try to pick up casts when there is a direct subtyping
    # relationship.
    if (orig_stype.issubclass(ctx.env.schema, new_stype)
            or new_stype.issubclass(ctx.env.schema, orig_stype)
            or _has_common_concrete_scalar(orig_stype, new_stype, ctx=ctx)):
        return None

    casts = ctx.env.schema.get_casts_to_type(new_stype)
    if not casts and isinstance(new_stype, s_types.InheritingType):
        ancestors = new_stype.get_ancestors(ctx.env.schema)
        for t in ancestors.objects(ctx.env.schema):
            casts = ctx.env.schema.get_casts_to_type(t)
            if casts:
                break
        else:
            return None

    dummy_set = irast.DUMMY_SET
    args = [
        (orig_stype, dummy_set),
        (new_stype, dummy_set),
    ]

    matched = polyres.find_callable(
        (CastCallableWrapper(c) for c in casts), args=args, kwargs={}, ctx=ctx)

    if len(matched) == 1:
        return cast(CastCallableWrapper, matched[0].func)._cast
    elif len(matched) > 1:
        raise errors.QueryError(
            f'cannot unambiguously cast '
            f'{orig_stype.get_displayname(ctx.env.schema)!r} '
            f'to {new_stype.get_displayname(ctx.env.schema)!r}',
            span=span)
    else:
        return None


def _cast_json_to_tuple(
    ir_set: irast.Set,
    orig_stype: s_types.Type,
    new_stype: s_types.Tuple,
    cardinality_mod: Optional[qlast.CardinalityModifier],
    *,
    span: Optional[parsing.Span],
    ctx: context.ContextLevel,
) -> irast.Set:

    with ctx.new() as subctx:
        subctx.allow_factoring()
        pathctx.register_set_in_scope(ir_set, ctx=subctx)

        subctx.anchors = subctx.anchors.copy()
        source_path = subctx.create_anchor(ir_set, 'a')

        # Top-level json->tuple casts should produce an empty set on
        # null inputs, but error on missing fields or null subelements
        allow_null = cardinality_mod != qlast.CardinalityModifier.Required

        # Only json arrays or objects can be cast to tuple.
        # If not in the top level cast, raise an exception here
        json_object_args: list[qlast.Expr] = [
            source_path,
            qlast.Constant.boolean(allow_null),
        ]
        if error_message_context := cast_message_context(subctx):
            json_object_args.append(qlast.Constant.string(
                json.dumps({
                    "error_message_context": error_message_context
                })
            ))

        # Don't validate NULLs. They are filtered out with the json nulls.
        json_objects = qlast.IfElse(
            condition=qlast.UnaryOp(
                op='EXISTS',
                operand=source_path,
            ),
            if_expr=qlast.FunctionCall(
                func=('__std__', '__tuple_validate_json'),
                args=json_object_args,
            ),
            else_expr=qlast.TypeCast(
                expr=qlast.Set(elements=[]),
                type=typegen.type_to_ql_typeref(orig_stype, ctx=ctx),
            ),
        )

        json_objects_ir = dispatch.compile(json_objects, ctx=subctx)

    with ctx.new() as subctx:
        pathctx.register_set_in_scope(json_objects_ir, ctx=subctx)
        subctx.anchors = subctx.anchors.copy()
        source_path = subctx.create_anchor(json_objects_ir, 'a')

        # Filter out json nulls and postgres NULLs.
        # Nulls at the top level cast can be ignored.
        filtered = qlast.SelectQuery(
            result=source_path,
            where=qlast.BinOp(
                left=qlast.FunctionCall(
                    func=('__std__', 'json_typeof'), args=[source_path]
                ),
                op='!=',
                right=qlast.Constant.string('null'),
            ),
        )
        filtered_ir = dispatch.compile(filtered, ctx=subctx)
        source_path = subctx.create_anchor(filtered_ir, 'a')

        # TODO: try using jsonb_to_record instead of a bunch of
        # json_get calls and see if that is faster.
        elements = []
        for new_el_name, new_st in new_stype.iter_subtypes(ctx.env.schema):
            cast_element = ('tuple', new_el_name)
            if subctx.collection_cast_info is not None:
                subctx.collection_cast_info.path_elements.append(cast_element)

            json_get_kwargs: dict[str, qlast.Expr] = {}
            if error_message_context := cast_message_context(subctx):
                json_get_kwargs['detail'] = qlast.Constant.string(
                    json.dumps({
                        "error_message_context": error_message_context
                    })
                )
            val_e = qlast.FunctionCall(
                func=('__std__', '__json_get_not_null'),
                args=[
                    source_path,
                    qlast.Constant.string(new_el_name),
                ],
                kwargs=json_get_kwargs
            )

            val = dispatch.compile(val_e, ctx=subctx)

            val = compile_cast(
                val, new_st,
                cardinality_mod=qlast.CardinalityModifier.Required,
                ctx=subctx, span=span)

            if subctx.collection_cast_info is not None:
                subctx.collection_cast_info.path_elements.pop()

            elements.append(irast.TupleElement(name=new_el_name, val=val))

        return setgen.new_tuple_set(
            elements,
            named=new_stype.is_named(ctx.env.schema),
            ctx=subctx,
        )


def _cast_tuple(
    ir_set: irast.Set,
    orig_stype: s_types.Type,
    new_stype: s_types.Type,
    *,
    span: Optional[parsing.Span],
    ctx: context.ContextLevel,
) -> irast.Set:

    assert isinstance(orig_stype, s_types.Tuple)

    # Make sure the source tuple expression is pinned in the scope,
    # so that we don't generate a cross-product of it by evaluating
    # the tuple indirections.
    pathctx.register_set_in_scope(ir_set, ctx=ctx)

    direct_cast = _find_cast(orig_stype, new_stype, span=span, ctx=ctx)
    orig_subtypes = dict(orig_stype.iter_subtypes(ctx.env.schema))

    if direct_cast is not None:
        # Direct casting to non-tuple involves casting each tuple
        # element and also keeping the cast around the whole tuple.
        # This is to trigger the downstream logic of casting
        # objects (in elements of the tuple).
        elements = []
        for n in orig_subtypes:
            val = setgen.tuple_indirection_set(
                ir_set,
                source=orig_stype,
                ptr_name=n,
                ctx=ctx,
            )
            val_type = setgen.get_set_type(val, ctx=ctx)
            # Element cast
            cast_element = ('tuple', n)
            if ctx.collection_cast_info is not None:
                ctx.collection_cast_info.path_elements.append(cast_element)

            val = compile_cast(val, new_stype, ctx=ctx, span=span)

            if ctx.collection_cast_info is not None:
                ctx.collection_cast_info.path_elements.pop()

            elements.append(irast.TupleElement(name=n, val=val))

        new_tuple = setgen.new_tuple_set(
            elements,
            named=orig_stype.is_named(ctx.env.schema),
            ctx=ctx,
        )

        return _cast_to_ir(
            new_tuple, direct_cast, orig_stype, new_stype, ctx=ctx)

    if not new_stype.is_tuple(ctx.env.schema):
        raise errors.QueryError(
            f'cannot cast {orig_stype.get_displayname(ctx.env.schema)!r} '
            f'to {new_stype.get_displayname(ctx.env.schema)!r}',
            span=span)

    assert isinstance(new_stype, s_types.Tuple)
    new_subtypes = list(new_stype.iter_subtypes(ctx.env.schema))
    if len(orig_subtypes) != len(new_subtypes):
        raise errors.QueryError(
            f'cannot cast {orig_stype.get_displayname(ctx.env.schema)!r} '
            f'to {new_stype.get_displayname(ctx.env.schema)!r}: '
            f'the number of elements is not the same',
            span=span)

    # For tuple-to-tuple casts we generate a new tuple
    # to simplify things on sqlgen side.
    elements = []
    for i, n in enumerate(orig_subtypes):
        val = setgen.tuple_indirection_set(
            ir_set,
            source=orig_stype,
            ptr_name=n,
            ctx=ctx,
        )
        val_type = setgen.get_set_type(val, ctx=ctx)
        new_el_name, new_st = new_subtypes[i]
        if val_type != new_st:
            # Element cast
            cast_element = ('tuple', new_el_name)
            if ctx.collection_cast_info is not None:
                ctx.collection_cast_info.path_elements.append(cast_element)

            val = compile_cast(val, new_st, ctx=ctx, span=span)

            if ctx.collection_cast_info is not None:
                ctx.collection_cast_info.path_elements.pop()

        elements.append(irast.TupleElement(name=new_el_name, val=val))

    return setgen.new_tuple_set(
        elements,
        named=new_stype.is_named(ctx.env.schema),
        ctx=ctx,
    )


def _cast_range(
    ir_set: irast.Set,
    orig_stype: s_types.Type,
    new_stype: s_types.Type,
    *,
    span: Optional[parsing.Span],
    ctx: context.ContextLevel,
) -> irast.Set:

    assert isinstance(orig_stype, s_types.Range)

    direct_cast = _find_cast(orig_stype, new_stype, span=span, ctx=ctx)
    if direct_cast is not None:
        return _cast_to_ir(
            ir_set, direct_cast, orig_stype, new_stype, ctx=ctx
        )

    if not new_stype.is_range():
        raise errors.QueryError(
            f'cannot cast {orig_stype.get_displayname(ctx.env.schema)!r} '
            f'to {new_stype.get_displayname(ctx.env.schema)!r}',
            span=span)
    assert isinstance(new_stype, s_types.Range)
    el_type = new_stype.get_subtypes(ctx.env.schema)[0]
    orig_el_type = orig_stype.get_subtypes(ctx.env.schema)[0]
    ql_el_type = typegen.type_to_ql_typeref(el_type, ctx=ctx)

    el_cast = _find_cast(orig_el_type, el_type, span=span, ctx=ctx)
    if el_cast is None:
        raise errors.QueryError(
            f'cannot cast {orig_stype.get_displayname(ctx.env.schema)!r} '
            f'to {new_stype.get_displayname(ctx.env.schema)!r}',
            span=span)

    with ctx.new() as subctx:
        subctx.allow_factoring()
        subctx.anchors = subctx.anchors.copy()
        source_path = subctx.create_anchor(ir_set, 'a')

        cast = qlast.FunctionCall(
            func=('__std__', 'range'),
            args=[
                qlast.TypeCast(
                    expr=qlast.FunctionCall(
                        func=('__std__', 'range_get_lower'),
                        args=[source_path],
                    ),
                    type=ql_el_type,
                ),
                qlast.TypeCast(
                    expr=qlast.FunctionCall(
                        func=('__std__', 'range_get_upper'),
                        args=[source_path],
                    ),
                    type=ql_el_type,
                ),
            ],
            kwargs={
                "inc_lower": qlast.FunctionCall(
                    func=('__std__', 'range_is_inclusive_lower'),
                    args=[source_path],
                ),
                "inc_upper": qlast.FunctionCall(
                    func=('__std__', 'range_is_inclusive_upper'),
                    args=[source_path],
                ),
                "empty": qlast.FunctionCall(
                    func=('__std__', 'range_is_empty'),
                    args=[source_path],
                ),
            }
        )

        if el_type.contains_json(subctx.env.schema):
            subctx.implicit_limit = 0

        return dispatch.compile(cast, ctx=subctx)


def _cast_multirange(
    ir_set: irast.Set,
    orig_stype: s_types.Type,
    new_stype: s_types.Type,
    *,
    span: Optional[parsing.Span],
    ctx: context.ContextLevel,
) -> irast.Set:

    assert isinstance(orig_stype, s_types.MultiRange)

    direct_cast = _find_cast(orig_stype, new_stype, span=span, ctx=ctx)
    if direct_cast is not None:
        return _cast_to_ir(
            ir_set, direct_cast, orig_stype, new_stype, ctx=ctx
        )

    if not new_stype.is_multirange():
        raise errors.QueryError(
            f'cannot cast {orig_stype.get_displayname(ctx.env.schema)!r} '
            f'to {new_stype.get_displayname(ctx.env.schema)!r}',
            span=span)
    assert isinstance(new_stype, s_types.MultiRange)
    el_type = new_stype.get_subtypes(ctx.env.schema)[0]
    orig_el_type = orig_stype.get_subtypes(ctx.env.schema)[0]

    el_cast = _find_cast(orig_el_type, el_type, span=span, ctx=ctx)
    if el_cast is None:
        raise errors.QueryError(
            f'cannot cast {orig_stype.get_displayname(ctx.env.schema)!r} '
            f'to {new_stype.get_displayname(ctx.env.schema)!r}',
            span=span)

    ctx.env.schema, new_range_type = s_types.Range.from_subtypes(
        ctx.env.schema, [el_type])
    ql_range_type = typegen.type_to_ql_typeref(new_range_type, ctx=ctx)
    with ctx.new() as subctx:
        subctx.allow_factoring()
        subctx.anchors = subctx.anchors.copy()
        pathctx.register_set_in_scope(ir_set, ctx=subctx)
        source_path = subctx.create_anchor(ir_set, 'a')

        # multirange(
        #     array_agg(
        #         <range<el_type>>multirange_unpack(orig)
        #     )
        # )
        cast = qlast.FunctionCall(
            func=('__std__', 'multirange'),
            args=[
                qlast.FunctionCall(
                    func=('__std__', 'array_agg'),
                    args=[
                        qlast.TypeCast(
                            expr=qlast.FunctionCall(
                                func=('__std__', 'multirange_unpack'),
                                args=[source_path],
                            ),
                            type=ql_range_type,
                        ),
                    ],
                ),
            ],
        )

        if el_type.contains_json(subctx.env.schema):
            subctx.implicit_limit = 0

        return dispatch.compile(cast, ctx=subctx)


def _cast_json_to_range(
    ir_set: irast.Set,
    orig_stype: s_types.Type,
    new_stype: s_types.Range,
    cardinality_mod: Optional[qlast.CardinalityModifier],
    *,
    span: Optional[parsing.Span],
    ctx: context.ContextLevel,
) -> irast.Set:

    with ctx.new() as subctx:
        subctx.anchors = subctx.anchors.copy()
        source_path = subctx.create_anchor(ir_set, 'a')

        check_args: list[qlast.Expr] = [source_path]
        if error_message_context := cast_message_context(subctx):
            check_args.append(qlast.Constant.string(
                json.dumps({
                    "error_message_context": error_message_context
                })
            ))
        check = qlast.FunctionCall(
            func=('__std__', '__range_validate_json'),
            args=check_args
        )

        check_ir = dispatch.compile(check, ctx=subctx)
        source_path = subctx.create_anchor(check_ir, 'b')

        range_el_t = new_stype.get_element_type(ctx.env.schema)
        ql_range_el_t = typegen.type_to_ql_typeref(range_el_t, ctx=subctx)
        bool_t = ctx.env.get_schema_type_and_track(sn.QualName('std', 'bool'))
        ql_bool_t = typegen.type_to_ql_typeref(bool_t, ctx=subctx)

        def compile_with_range_element(
            expr: qlast.Expr,
            element_name: str,
        ) -> irast.Set:
            cast_element = ('range', element_name)
            if subctx.collection_cast_info is not None:
                subctx.collection_cast_info.path_elements.append(cast_element)

            expr_ir = dispatch.compile(expr, ctx=subctx)

            if subctx.collection_cast_info is not None:
                subctx.collection_cast_info.path_elements.pop()

            return expr_ir

        lower: qlast.Expr = qlast.TypeCast(
            expr=qlast.FunctionCall(
                func=('__std__', 'json_get'),
                args=[
                    source_path,
                    qlast.Constant.string('lower'),
                ],
            ),
            type=ql_range_el_t,
        )
        lower_ir = compile_with_range_element(lower, 'lower')
        lower = subctx.create_anchor(lower_ir, 'lower')

        upper: qlast.Expr = qlast.TypeCast(
            expr=qlast.FunctionCall(
                func=('__std__', 'json_get'),
                args=[
                    source_path,
                    qlast.Constant.string('upper'),
                ],
            ),
            type=ql_range_el_t,
        )
        upper_ir = compile_with_range_element(upper, 'upper')
        upper = subctx.create_anchor(upper_ir, 'upper')

        inc_lower: qlast.Expr = qlast.TypeCast(
            expr=qlast.FunctionCall(
                func=('__std__', 'json_get'),
                args=[
                    source_path,
                    qlast.Constant.string('inc_lower'),
                ],
                kwargs={
                    'default': qlast.FunctionCall(
                        func=('__std__', 'to_json'),
                        args=[qlast.Constant.string("true")],
                    ),
                },
            ),
            type=ql_bool_t,
        )
        inc_lower_ir = compile_with_range_element(inc_lower, 'inc_lower')
        inc_lower = subctx.create_anchor(inc_lower_ir, 'inc_lower')

        inc_upper: qlast.Expr = qlast.TypeCast(
            expr=qlast.FunctionCall(
                func=('__std__', 'json_get'),
                args=[
                    source_path,
                    qlast.Constant.string('inc_upper'),
                ],
                kwargs={
                    'default': qlast.FunctionCall(
                        func=('__std__', 'to_json'),
                        args=[qlast.Constant.string("false")],
                    ),
                },
            ),
            type=ql_bool_t,
        )
        inc_upper_ir = compile_with_range_element(inc_upper, 'inc_upper')
        inc_upper = subctx.create_anchor(inc_upper_ir, 'inc_upper')

        empty: qlast.Expr = qlast.TypeCast(
            expr=qlast.FunctionCall(
                func=('__std__', 'json_get'),
                args=[
                    source_path,
                    qlast.Constant.string('empty'),
                ],
                kwargs={
                    'default': qlast.FunctionCall(
                        func=('__std__', 'to_json'),
                        args=[qlast.Constant.string("false")],
                    ),
                },
            ),
            type=ql_bool_t,
        )
        empty_ir = compile_with_range_element(empty, 'empty')
        empty = subctx.create_anchor(empty_ir, 'empty')

        cast = qlast.FunctionCall(
            func=('__std__', 'range'),
            args=[lower, upper],
            # inc_lower and inc_upper are required to be present for
            # non-empty casts from json, and this is checked in
            # __range_validate_json. We still need to provide default
            # arguments when fetching them, though, since if those
            # arguments to range are {} it will cause {"empty": true}
            # to evaluate to {}.
            kwargs={
                "inc_lower": inc_lower,
                "inc_upper": inc_upper,
                "empty": empty,
            }
        )

        return dispatch.compile(cast, ctx=subctx)
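

# Illustrative sketch only, not part of this module: the defaults that
# _cast_json_to_range applies when fetching the optional JSON keys can be
# mirrored with plain Python dicts. The helper name range_args_from_json is
# hypothetical; the real validation (__range_validate_json) and the casts of
# 'lower'/'upper' to the range element type are deliberately omitted.

```python
import json
from typing import Any


def range_args_from_json(text: str) -> dict[str, Any]:
    """Collect range() arguments from a JSON object, applying defaults.

    Hypothetical helper: it only mirrors the default values used above
    (inc_lower -> true, inc_upper -> false, empty -> false).
    """
    data = json.loads(text)
    return {
        'lower': data.get('lower'),
        'upper': data.get('upper'),
        # Defaults keep these arguments non-empty, so that an input such
        # as {"empty": true} still yields a complete argument set.
        'inc_lower': data.get('inc_lower', True),
        'inc_upper': data.get('inc_upper', False),
        'empty': data.get('empty', False),
    }


empty_args = range_args_from_json('{"empty": true}')
bounded_args = range_args_from_json(
    '{"lower": 1, "upper": 5, "inc_lower": true, "inc_upper": false}'
)
```

# In this sketch, empty_args still carries concrete booleans for both
# inc_lower and inc_upper even though the input only mentions "empty".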


def _cast_json_to_multirange(
    ir_set: irast.Set,
    orig_stype: s_types.Type,
    new_stype: s_types.MultiRange,
    cardinality_mod: Optional[qlast.CardinalityModifier],
    *,
    span: Optional[parsing.Span],
    ctx: context.ContextLevel,
) -> irast.Set:

    ctx.env.schema, new_range_type = s_types.Range.from_subtypes(
        ctx.env.schema, new_stype.get_subtypes(ctx.env.schema))
    ctx.env.schema, new_array_type = s_types.Array.from_subtypes(
        ctx.env.schema, [new_range_type])
    ql_array_range_type = typegen.type_to_ql_typeref(new_array_type, ctx=ctx)
    with ctx.new() as subctx:
        # We effectively want to do the following:
        # multirange(<array<range<subtype>>>a)
        subctx.anchors = subctx.anchors.copy()
        source_path = subctx.create_anchor(ir_set, 'a')

        cast = qlast.FunctionCall(
            func=('__std__', 'multirange'),
            args=[
                qlast.TypeCast(
                    expr=source_path,
                    type=ql_array_range_type,
                ),
            ],
        )

        return dispatch.compile(cast, ctx=subctx)


def _cast_to_base_array(
    ir_set: irast.Set,
    el_stype: s_scalars.ScalarType,
    orig_stype: s_types.Array,
    ctx: context.ContextLevel,
    cardinality_mod: Optional[qlast.CardinalityModifier]=None
) -> irast.Set:

    base_stype = el_stype.get_base_for_cast(ctx.env.schema)
    assert isinstance(base_stype, s_types.Type)
    ctx.env.schema, new_stype = s_types.Array.from_subtypes(
        ctx.env.schema, [base_stype])

    return _inheritance_cast_to_ir(
        ir_set, orig_stype, new_stype,
        cardinality_mod=cardinality_mod, ctx=ctx)


def _cast_array(
    ir_set: irast.Set,
    orig_stype: s_types.Type,
    new_stype: s_types.Type,
    *,
    span: Optional[parsing.Span],
    ctx: context.ContextLevel,
) -> irast.Set:

    assert isinstance(orig_stype, s_types.Array)

    direct_cast = _find_cast(orig_stype, new_stype, span=span, ctx=ctx)

    if direct_cast is None:
        if not new_stype.is_array():
            raise errors.QueryError(
                f'cannot cast {orig_stype.get_displayname(ctx.env.schema)!r} '
                f'to {new_stype.get_displayname(ctx.env.schema)!r}',
                span=span)
        assert isinstance(new_stype, s_types.Array)
        el_type = new_stype.get_subtypes(ctx.env.schema)[0]
    elif new_stype.is_json(ctx.env.schema):
        el_type = new_stype
    else:
        # We're casting an array into something that's not an array (e.g. a
        # vector), so we don't need to match element types.
        return _cast_to_ir(
            ir_set, direct_cast, orig_stype, new_stype, ctx=ctx)

    orig_el_type = orig_stype.get_subtypes(ctx.env.schema)[0]

    el_cast = _find_cast(orig_el_type, el_type, span=span, ctx=ctx)

    if el_cast is not None and el_cast.get_from_cast(ctx.env.schema):
        # Simple cast
        return _cast_to_ir(
            ir_set, el_cast, orig_stype, new_stype, ctx=ctx)
    else:
        with ctx.new() as subctx:
            subctx.allow_factoring()

            subctx.anchors = subctx.anchors.copy()
            source_path = subctx.create_anchor(ir_set, 'a')

            unpacked = qlast.FunctionCall(
                func=('__std__', 'array_unpack'),
                args=[source_path],
            )

            enumerated = dispatch.compile(
                qlast.FunctionCall(
                    func=('__std__', 'enumerate'),
                    args=[unpacked],
                ),
                ctx=subctx,
            )

            enumerated_ref = subctx.create_anchor(enumerated, 'e')

            elements = qlast.FunctionCall(
                func=('__std__', 'array_agg'),
                args=[
                    qlast.SelectQuery(
                        result=qlast.TypeCast(
                            expr=astutils.extend_path(enumerated_ref, '1'),
                            type=typegen.type_to_ql_typeref(
                                el_type,
                                ctx=subctx,
                            ),
                            cardinality_mod=qlast.CardinalityModifier.Required,
                            span=span,
                        ),
                        orderby=[
                            qlast.SortExpr(
                                path=astutils.extend_path(enumerated_ref, '0'),
                                direction=qlast.SortOrder.Asc,
                            ),
                        ],
                    ),
                ],
            )

            # Force the elements to be correlated with whatever the
            # anchor was. (Doing it this way ensures a NULL check,
            # and just registering it in the scope would not.)
            correlated_elements = astutils.extend_path(
                qlast.Tuple(elements=[source_path, elements]), '1'
            )
            correlated_query = qlast.SelectQuery(result=correlated_elements)

            if el_type.contains_json(subctx.env.schema):
                subctx.implicit_limit = 0

            array_ir = dispatch.compile(correlated_query, ctx=subctx)
            assert isinstance(array_ir, irast.Set)

            if direct_cast is not None:
                ctx.env.schema, array_stype = s_types.Array.from_subtypes(
                    ctx.env.schema, [el_type])
                return _cast_to_ir(
                    array_ir, direct_cast, array_stype, new_stype, ctx=ctx
                )
            else:
                return array_ir


def _cast_array_literal(
    ir_set: irast.Set,
    orig_stype: s_types.Type,
    new_stype: s_types.Type,
    *,
    span: Optional[parsing.Span],
    ctx: context.ContextLevel,
) -> irast.Set:

    assert isinstance(ir_set.expr, irast.Array)

    orig_typeref = typegen.type_to_typeref(orig_stype, env=ctx.env)
    new_typeref = typegen.type_to_typeref(new_stype, env=ctx.env)
    direct_cast = _find_cast(orig_stype, new_stype, span=span, ctx=ctx)

    if direct_cast is None:
        if not new_stype.is_array():
            raise errors.QueryError(
                f'cannot cast {orig_stype.get_displayname(ctx.env.schema)!r} '
                f'to {new_stype.get_displayname(ctx.env.schema)!r}',
                span=span) from None
        assert isinstance(new_stype, s_types.Array)
        el_type = new_stype.get_subtypes(ctx.env.schema)[0]
        intermediate_stype = orig_stype

    else:
        el_type = new_stype
        ctx.env.schema, intermediate_stype = s_types.Array.from_subtypes(
            ctx.env.schema, [el_type])

    intermediate_typeref = typegen.type_to_typeref(
        intermediate_stype, env=ctx.env)
    casted_els = []
    for el in ir_set.expr.elements:
        el = compile_cast(el, el_type,
                          cardinality_mod=qlast.CardinalityModifier.Required,
                          ctx=ctx, span=span)
        casted_els.append(el)

    new_array = setgen.ensure_set(
        irast.Array(elements=casted_els, typeref=intermediate_typeref),
        ctx=ctx)

    if direct_cast is not None:
        return _cast_to_ir(
            new_array, direct_cast, intermediate_stype, new_stype, ctx=ctx)

    else:
        cast_ir = irast.TypeCast(
            expr=new_array,
            from_type=orig_typeref,
            to_type=new_typeref,
            sql_cast=True,
            sql_expr=False,
            span=span,
            error_message_context=cast_message_context(ctx),
        )

    return setgen.ensure_set(cast_ir, ctx=ctx)


def _cast_enum_str_immutable(
    ir_expr: irast.Set | irast.Expr,
    orig_stype: s_types.Type,
    new_stype: s_types.Type,
    *,
    ctx: context.ContextLevel,
) -> irast.Set:
    """
    Compiles cast between an enum and std::str
    under the assumption that this expression must be immutable.
    """

    if new_stype.is_enum(ctx.env.schema):
        enum_stype = new_stype
        suffix = "_from_str"
    else:
        enum_stype = orig_stype
        suffix = "_into_str"

    name: s_name.Name = enum_stype.get_name(ctx.env.schema)
    name = cast(s_name.QualName, name)
    cast_name = s_name.QualName(
        module=name.module, name=str(enum_stype.id) + suffix
    )

    orig_typeref = typegen.type_to_typeref(orig_stype, env=ctx.env)
    new_typeref = typegen.type_to_typeref(new_stype, env=ctx.env)

    cast_ir = irast.TypeCast(
        expr=setgen.ensure_set(ir_expr, ctx=ctx),
        from_type=orig_typeref,
        to_type=new_typeref,
        cardinality_mod=None,
        cast_name=cast_name,
        sql_function=None,
        sql_cast=False,
        sql_expr=True,
        error_message_context=cast_message_context(ctx),
    )

    return setgen.ensure_set(cast_ir, ctx=ctx)


def _find_object_by_id(
    ir_expr: irast.Set | irast.Expr,
    new_stype: s_types.Type,
    *,
    ctx: context.ContextLevel,
) -> irast.Set:
    with ctx.new() as subctx:
        subctx.anchors = subctx.anchors.copy()

        ir_set = setgen.ensure_set(ir_expr, ctx=subctx)
        uuid_anchor = subctx.create_anchor(ir_set, name='a')

        object_name = s_utils.name_to_ast_ref(
            new_stype.get_name(ctx.env.schema)
        )

        select_id = qlast.SelectQuery(
            result=qlast.DetachedExpr(expr=qlast.Path(steps=[object_name])),
            where=qlast.BinOp(
                left=qlast.Path(
                    steps=[qlast.Ptr(name='id', direction='>')],
                    partial=True,
                ),
                op='=',
                right=qlast.Path(steps=[qlast.ObjectRef(name='_id')]),
            ),
        )

        error_message = qlast.BinOp(
            left=qlast.Constant.string(
                value=(
                    repr(new_stype.get_displayname(ctx.env.schema))
                    + ' with id \''
                )
            ),
            op='++',
            right=qlast.BinOp(
                left=qlast.TypeCast(
                    expr=qlast.Path(steps=[qlast.ObjectRef(name='_id')]),
                    type=qlast.TypeName(maintype=qlast.ObjectRef(name='str')),
                ),
                op='++',
                right=qlast.Constant.string('\' does not exist'),
            ),
        )

        exists_ql = qlast.FunctionCall(
            func='assert_exists',
            args=[select_id],
            kwargs={'message': error_message},
        )

        for_query = qlast.ForQuery(
            iterator=uuid_anchor, iterator_alias='_id', result=exists_ql
        )

        return dispatch.compile(for_query, ctx=subctx)


def cast_message_context(ctx: context.ContextLevel) -> Optional[str]:
    if (
        ctx.collection_cast_info is not None
        and ctx.collection_cast_info.path_elements
    ):
        from_name = (
            ctx.collection_cast_info.from_type.get_displayname(ctx.env.schema)
        )
        to_name = (
            ctx.collection_cast_info.to_type.get_displayname(ctx.env.schema)
        )
        path_msg = ''.join(
            _collection_element_message_context(path_element)
            for path_element in ctx.collection_cast_info.path_elements
        )
        return (
            f"while casting '{from_name}' to '{to_name}', {path_msg}"
        )
    else:
        return None


def _collection_element_message_context(
    path_element: tuple[str, Optional[str]]
) -> str:
    if path_element[0] == 'tuple':
        return f"at tuple element '{path_element[1]}', "
    elif path_element[0] == 'array':
        return 'in array elements, '
    elif path_element[0] == 'range':
        return f"in range parameter '{path_element[1]}', "
    else:
        raise NotImplementedError
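

# Illustrative sketch only, not part of this module: how the per-element
# fragments above compose into a single error-message context. The function
# names and the type display names used here are hypothetical stand-ins for
# ctx.collection_cast_info and the schema lookups.

```python
from typing import Optional


def element_context(path_element: tuple[str, Optional[str]]) -> str:
    """Render one path element the same way as the function above."""
    kind, name = path_element
    if kind == 'tuple':
        return f"at tuple element '{name}', "
    elif kind == 'array':
        return 'in array elements, '
    elif kind == 'range':
        return f"in range parameter '{name}', "
    raise NotImplementedError(kind)


def message_context(
    from_name: str,
    to_name: str,
    path_elements: list[tuple[str, Optional[str]]],
) -> Optional[str]:
    """Join the fragments, or return None when there is no path."""
    if not path_elements:
        return None
    path_msg = ''.join(element_context(p) for p in path_elements)
    return f"while casting '{from_name}' to '{to_name}', {path_msg}"


example = message_context(
    'array<std::json>',
    'array<range<std::int64>>',
    [('array', None), ('range', 'lower')],
)
```

# Each fragment ends with ', ' so that the caller can append the underlying
# error text directly after the returned context.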


================================================
FILE: edb/edgeql/compiler/clauses.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""EdgeQL compiler functions to process shared clauses."""


from __future__ import annotations

from typing import Optional, Sequence

from edb.edgeql import ast as qlast
from edb.ir import ast as irast

from edb import errors
from edb.ir import utils as irutils
from edb.schema import name as sn
from edb.schema import operators as s_oper

from . import context
from . import dispatch
from . import polyres
from . import schemactx
from . import setgen
from . import typegen
from . import pathctx


def compile_where_clause(
    where: Optional[qlast.Base], *, ctx: context.ContextLevel
) -> Optional[irast.Set]:

    if where is None:
        return None

    if ctx.partial_path_prefix:
        pathctx.register_set_in_scope(ctx.partial_path_prefix, ctx=ctx)

    with ctx.newscope(fenced=True) as subctx:
        subctx.expr_exposed = context.Exposure.UNEXPOSED
        subctx.path_scope.unnest_fence = True
        subctx.disallow_dml = "in a FILTER clause"
        ir_expr = dispatch.compile(where, ctx=subctx)
        bool_t = ctx.env.get_schema_type_and_track(sn.QualName('std', 'bool'))
        ir_set = setgen.scoped_set(ir_expr, typehint=bool_t, ctx=subctx)

    return ir_set


def adjust_nones_order(
    ir_sortexpr: irast.Set,
    sort: qlast.SortExpr,
    *,
    ctx: context.ContextLevel,
) -> Optional[qlast.NonesOrder]:
    if sort.nones_order:
        return sort.nones_order

    # If we are doing an ORDER BY on a required property that has an
    # exclusive constraint and no nones_order specified, we want to
    # default to EMPTY LAST (or EMPTY FIRST for DESC).  Since the
    # property is required, this doesn't have a semantic impact, but
    # our exclusive constraints (sigh.) use a UNIQUE constraint,
    # which is always NULLS LAST.
    #
    # Postgres seems to *sometimes* be able to use the indexes without
    # this intervention, but not always?
    # See #8035.
    ir = irutils.unwrap_set(ir_sortexpr)
    expr = ir.expr
    if (
        isinstance(expr, irast.Pointer)
        and expr.source == ctx.partial_path_prefix
        and expr.dir_cardinality
        and not expr.dir_cardinality.can_be_zero()
        and isinstance(expr.ptrref, irast.PointerRef)
        and (ptr := typegen.ptrcls_from_ptrref(
            expr.ptrref, ctx=ctx,
        ))
        and bool(ptr.get_exclusive_constraints(ctx.env.schema))
    ):
        if sort.direction == qlast.SortOrder.Desc:
            return qlast.NonesOrder.First
        else:
            return qlast.NonesOrder.Last

    return None
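

# Illustrative sketch only, not part of this module: the decision made by
# adjust_nones_order, reduced to plain values. The enums and the
# required_and_exclusive flag are hypothetical stand-ins for qlast.SortOrder,
# qlast.NonesOrder and the pointer/constraint checks performed above.

```python
import enum
from typing import Optional


class SortOrder(enum.Enum):
    Asc = 'ASC'
    Desc = 'DESC'


class NonesOrder(enum.Enum):
    First = 'first'
    Last = 'last'


def pick_nones_order(
    explicit: Optional[NonesOrder],
    direction: SortOrder,
    required_and_exclusive: bool,
) -> Optional[NonesOrder]:
    # An explicit EMPTY FIRST / EMPTY LAST always wins.
    if explicit is not None:
        return explicit
    # For a required pointer with an exclusive constraint, pick the
    # ordering that matches the NULLS LAST unique index described above.
    if required_and_exclusive:
        if direction is SortOrder.Desc:
            return NonesOrder.First
        return NonesOrder.Last
    # Otherwise leave the choice unspecified.
    return None


asc_default = pick_nones_order(None, SortOrder.Asc, True)
desc_default = pick_nones_order(None, SortOrder.Desc, True)
```

# Because the pointer is required, the chosen value never changes query
# results in this scenario; it only affects which ordering is requested.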


def compile_orderby_clause(
    sortexprs: Optional[Sequence[qlast.SortExpr]], *, ctx: context.ContextLevel
) -> Optional[list[irast.SortExpr]]:

    if not sortexprs:
        return None

    result: list[irast.SortExpr] = []

    if ctx.partial_path_prefix:
        pathctx.register_set_in_scope(ctx.partial_path_prefix, ctx=ctx)

    with ctx.new() as subctx:
        subctx.expr_exposed = context.Exposure.UNEXPOSED
        subctx.disallow_dml = "in an ORDER BY clause"
        for sortexpr in sortexprs:
            with subctx.newscope(fenced=True) as exprctx:
                exprctx.path_scope.unnest_fence = True
                ir_sortexpr = dispatch.compile(sortexpr.path, ctx=exprctx)
                ir_sortexpr = setgen.scoped_set(
                    ir_sortexpr, force_reassign=True, ctx=exprctx)
                ir_sortexpr.span = sortexpr.span

                # Check that the sortexpr type is actually orderable
                # with either '>' or '<' based on the DESC or ASC sort
                # order.
                env = exprctx.env
                sort_type = setgen.get_set_type(ir_sortexpr, ctx=ctx)
                # Postgres by default treats ASC as using '<' and DESC
                # as using '>'. We should do the same.
                if sortexpr.direction == qlast.SortDesc:
                    op_name = '>'
                else:
                    op_name = '<'
                opers = s_oper.lookup_operators(
                    op_name,
                    module_aliases=exprctx.modaliases,
                    schema=env.schema
                )

                # Verify that exactly one comparison operator is defined
                # for a pair of sort_type operands.
                matched = polyres.find_callable(
                    opers,
                    args=[(sort_type, ir_sortexpr), (sort_type, ir_sortexpr)],
                    kwargs={},
                    ctx=exprctx)
                if len(matched) != 1:
                    sort_type_name = schemactx.get_material_type(
                        sort_type, ctx=ctx).get_displayname(env.schema)
                    if len(matched) == 0:
                        raise errors.QueryError(
                            f'type {sort_type_name!r} cannot be used in '
                            f'ORDER BY clause because ordering is not '
                            f'defined for it',
                            span=sortexpr.span)

                    elif len(matched) > 1:
                        raise errors.QueryError(
                            f'type {sort_type_name!r} cannot be used in '
                            f'ORDER BY clause because ordering is '
                            f'ambiguous for it',
                            span=sortexpr.span)

            result.append(
                irast.SortExpr(
                    expr=ir_sortexpr,
                    direction=sortexpr.direction,
                    nones_order=adjust_nones_order(
                        ir_sortexpr,
                        sortexpr,
                        ctx=ctx,
                    ),
                ))

    return result


def compile_limit_offset_clause(
    expr: Optional[qlast.Base], *, ctx: context.ContextLevel
) -> Optional[irast.Set]:
    if expr is None:
        ir_set = None
    else:
        with ctx.newscope(fenced=True) as subctx:
            subctx.expr_exposed = context.Exposure.UNEXPOSED
            # Clear out the partial_path_prefix, since we aren't in
            # the scope of the select subject
            subctx.partial_path_prefix = None

            ir_expr = dispatch.compile(expr, ctx=subctx)
            int_t = ctx.env.get_schema_type_and_track(
                sn.QualName('std', 'int64'))
            ir_set = setgen.scoped_set(
                ir_expr, force_reassign=True, typehint=int_t, ctx=subctx)
            ir_set.span = expr.span

    return ir_set


================================================
FILE: edb/edgeql/compiler/config.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""CONFIGURE statement compilation functions."""


from __future__ import annotations
from typing import Optional, NamedTuple

import json

from edb import errors

from edb.edgeql import qltypes

from edb.ir import ast as irast
from edb.ir import staeval as ireval
from edb.ir import statypes as statypes
from edb.ir import typeutils as irtyputils

from edb.schema import constraints as s_constr
from edb.schema import globals as s_globals
from edb.schema import links as s_links
from edb.schema import name as sn
from edb.schema import objtypes as s_objtypes
from edb.schema import pointers as s_pointers
from edb.schema import schema as s_schema
from edb.schema import types as s_types
from edb.schema import utils as s_utils
from edb.schema import expr as s_expr

from edb.edgeql import ast as qlast

from . import casts
from . import context
from . import dispatch
from . import setgen
from . import typegen


class SettingInfo(NamedTuple):
    param_name: str
    param_type: s_types.Type
    cardinality: qltypes.SchemaCardinality
    required: bool
    requires_restart: bool
    backend_setting: str | None
    affects_compilation: bool
    is_system_config: bool
    ptr: Optional[s_pointers.Pointer]


@dispatch.compile.register
def compile_ConfigSet(
    expr: qlast.ConfigSet,
    *,
    ctx: context.ContextLevel,
) -> irast.Set:

    info = _validate_op(expr, ctx=ctx)
    param_val = dispatch.compile(expr.expr, ctx=ctx)
    param_type = info.param_type
    val_type = setgen.get_set_type(param_val, ctx=ctx)
    compatible = s_types.is_type_compatible(
        val_type, param_type, schema=ctx.env.schema)
    if not compatible:
        if not val_type.assignment_castable_to(param_type, ctx.env.schema):
            raise errors.ConfigurationError(
                f'invalid setting value type for {info.param_name}: '
                f'{val_type.get_displayname(ctx.env.schema)!r} '
                f'(expecting {param_type.get_displayname(ctx.env.schema)!r})'
            )
        else:
            param_val = casts.compile_cast(
                param_val, param_type, span=None, ctx=ctx)

    try:
        if expr.scope != qltypes.ConfigScope.GLOBAL:
            val = ireval.evaluate_to_python_val(
                param_val, schema=ctx.env.schema)
        else:
            val = None
    except ireval.UnsupportedExpressionError as e:
        raise errors.QueryError(
            f'non-constant expression in CONFIGURE {expr.scope} SET',
            span=expr.expr.span
        ) from e
    else:
        if isinstance(val, statypes.ScalarType) and info.backend_setting:
            backend_expr = dispatch.compile(
                qlast.Constant.string(val.to_backend_str()),
                ctx=ctx,
            )
        else:
            backend_expr = None

    if info.ptr:
        _enforce_pointer_constraints(
            info.ptr, param_val, ctx=ctx, for_obj=False)

    config_set = irast.ConfigSet(
        name=info.param_name,
        cardinality=info.cardinality,
        required=info.required,
        scope=expr.scope,
        requires_restart=info.requires_restart,
        backend_setting=info.backend_setting,
        is_system_config=info.is_system_config,
        span=expr.span,
        expr=param_val,
        backend_expr=backend_expr,
    )
    return setgen.ensure_set(config_set, ctx=ctx)


@dispatch.compile.register
def compile_ConfigReset(
    expr: qlast.ConfigReset,
    *,
    ctx: context.ContextLevel,
) -> irast.Set:

    info = _validate_op(expr, ctx=ctx)
    filter_expr = expr.where
    select_ir = None

    if not info.param_type.is_object_type() and filter_expr is not None:
        raise errors.QueryError(
            'RESET of a primitive configuration parameter '
            'must not have a FILTER clause',
            span=expr.span,
        )

    elif isinstance(info.param_type, s_objtypes.ObjectType):
        param_type_name = info.param_type.get_name(ctx.env.schema)
        param_type_ref = qlast.ObjectRef(
            name=param_type_name.name,
            module=param_type_name.module,
        )
        body = qlast.Shape(
            expr=qlast.Path(steps=[param_type_ref]),
            elements=s_utils.get_config_type_shape(
                ctx.env.schema, info.param_type, path=[param_type_ref]),
        )
        # The body needs to have access to secrets, since they get put
        # into the shape and are necessary for compiling the deletion
        # code, so compile the body in a way that we allow it.
        # The filter should *not* be able to access secret pointers, though.
        with ctx.new() as sctx:
            sctx.current_schema_views += (info.param_type,)
            body_ir = dispatch.compile(body, ctx=sctx)

        with ctx.new() as sctx:
            sctx.anchors = sctx.anchors.copy()
            select = qlast.SelectQuery(
                result=sctx.create_anchor(body_ir, 'a'),
                where=filter_expr,
            )

            sctx.modaliases = ctx.modaliases.copy()
            sctx.modaliases[None] = 'cfg'
            select_ir = setgen.ensure_set(
                dispatch.compile(select, ctx=sctx), ctx=sctx)

    config_reset = irast.ConfigReset(
        name=info.param_name,
        cardinality=info.cardinality,
        scope=expr.scope,
        requires_restart=info.requires_restart,
        backend_setting=info.backend_setting,
        is_system_config=info.is_system_config,
        span=expr.span,
        selector=select_ir,
    )
    return setgen.ensure_set(config_reset, ctx=ctx)


@dispatch.compile.register
def compile_ConfigInsert(
    expr: qlast.ConfigInsert, *, ctx: context.ContextLevel
) -> irast.Set:

    info = _validate_op(expr, ctx=ctx)

    if expr.scope not in (
        qltypes.ConfigScope.INSTANCE, qltypes.ConfigScope.DATABASE
    ):
        raise errors.UnsupportedFeatureError(
            f'CONFIGURE {expr.scope} INSERT is not supported'
        )

    subject = info.param_type
    insert_stmt = qlast.InsertQuery(
        subject=s_utils.name_to_ast_ref(subject.get_name(ctx.env.schema)),
        shape=expr.shape,
    )

    _inject_tname(insert_stmt, ctx=ctx)

    with ctx.newscope(fenced=False) as subctx:
        subctx.expr_exposed = context.Exposure.EXPOSED
        subctx.modaliases = ctx.modaliases.copy()
        subctx.modaliases[None] = 'cfg'
        subctx.special_computables_in_mutation_shape |= {'_tname'}
        insert_ir = dispatch.compile(insert_stmt, ctx=subctx)
        insert_ir_set = setgen.ensure_set(insert_ir, ctx=subctx)
        assert isinstance(insert_ir_set.expr, irast.InsertStmt)
        insert_subject = insert_ir_set.expr.subject

        _validate_config_object(insert_subject, scope=expr.scope, ctx=subctx)

    return setgen.ensure_set(
        irast.ConfigInsert(
            name=info.param_name,
            cardinality=info.cardinality,
            scope=expr.scope,
            requires_restart=info.requires_restart,
            backend_setting=info.backend_setting,
            is_system_config=info.is_system_config,
            expr=insert_subject,
            span=expr.span,
        ),
        ctx=ctx,
    )


def _inject_tname(
    insert_stmt: qlast.InsertQuery, *, ctx: context.ContextLevel
) -> None:

    for el in insert_stmt.shape:
        if isinstance(el.compexpr, qlast.InsertQuery):
            _inject_tname(el.compexpr, ctx=ctx)

    assert isinstance(insert_stmt.subject, qlast.BaseObjectRef)
    insert_stmt.shape.append(
        qlast.ShapeElement(
            expr=qlast.Path(
                steps=[qlast.Ptr(name='_tname')],
            ),
            compexpr=qlast.Path(
                steps=[
                    qlast.Introspect(
                        type=qlast.TypeName(
                            maintype=insert_stmt.subject,
                        ),
                    ),
                    qlast.Ptr(name='name'),
                ],
            ),
        ),
    )


def _validate_config_object(
    expr: irast.Set, *, scope: str, ctx: context.ContextLevel
) -> None:

    for element, _ in expr.shape:
        assert isinstance(element.expr, irast.Pointer)
        if element.expr.ptrref.shortname.name == 'id':
            continue

        ptr = typegen.ptrcls_from_ptrref(
            element.expr.ptrref.real_material_ptr,
            ctx=ctx,
        )
        if isinstance(ptr, s_pointers.Pointer):
            _enforce_pointer_constraints(
                ptr, element, ctx=ctx, for_obj=True)

        if (irtyputils.is_object(element.typeref)
                and isinstance(element.expr, irast.InsertStmt)):
            _validate_config_object(element, scope=scope, ctx=ctx)


def _validate_global_op(
    expr: qlast.ConfigOp, *, ctx: context.ContextLevel
) -> SettingInfo:
    glob_name = s_utils.ast_ref_to_name(expr.name)
    glob = ctx.env.get_schema_object_and_track(
        glob_name, expr.name,
        modaliases=ctx.modaliases, type=s_globals.Global)
    assert isinstance(glob, s_globals.Global)

    fullname = glob.get_name(ctx.env.schema)
    if sn.UnqualName(fullname.module) in s_schema.STD_MODULES:
        raise errors.ConfigurationError(
            f"system global '{glob_name}' may not be explicitly specified",
            span=expr.name.span
        )

    if isinstance(expr, (qlast.ConfigSet, qlast.ConfigReset)):
        if glob.get_expr(ctx.env.schema):
            raise errors.ConfigurationError(
                f"global '{glob_name}' is computed from an expression and "
                f"cannot be modified",
                span=expr.name.span
            )

    param_type = glob.get_target(ctx.env.schema)

    return SettingInfo(param_name=str(glob.get_name(ctx.env.schema)),
                       param_type=param_type,
                       cardinality=glob.get_cardinality(ctx.env.schema),
                       required=glob.get_required(ctx.env.schema),
                       requires_restart=False,
                       backend_setting=None,
                       is_system_config=False,
                       affects_compilation=False,
                       ptr=None)


def _enforce_pointer_constraints(
    ptr: s_pointers.Pointer,
    expr: irast.Set,
    *,
    ctx: context.ContextLevel,
    for_obj: bool,
) -> None:
    constraints = ptr.get_constraints(ctx.env.schema)
    for constraint in constraints.objects(ctx.env.schema):
        if constraint.issubclass(
            ctx.env.schema,
            ctx.env.schema.get('std::exclusive', type=s_constr.Constraint),
        ):
            continue

        with ctx.detached() as sctx:
            sctx.partial_path_prefix = expr
            sctx.anchors = ctx.anchors.copy()
            sctx.anchors['__subject__'] = expr

            final_expr: Optional[s_expr.Expression] = (
                constraint.get_finalexpr(ctx.env.schema)
            )
            assert final_expr is not None and final_expr.parse() is not None
            ir = dispatch.compile(final_expr.parse(), ctx=sctx)

        result = ireval.evaluate(ir, schema=ctx.env.schema)
        assert isinstance(result, irast.BooleanConstant)
        if result.value != 'true':
            if for_obj:
                name = ptr.get_verbosename(ctx.env.schema, with_parent=True)
            else:
                name = repr(ptr.get_shortname(ctx.env.schema).name)
            raise errors.ConfigurationError(
                f'invalid setting value for {name}'
            )


def _validate_op(
    expr: qlast.ConfigOp, *, ctx: context.ContextLevel
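) -> SettingInfo:
    """Validate a CONFIGURE operation and describe the setting it targets.

    GLOBAL-scoped operations are delegated to `_validate_global_op`.
    """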
    if expr.scope == qltypes.ConfigScope.GLOBAL:
        return _validate_global_op(expr, ctx=ctx)

    cfg_host_type = None
    is_ext_config = False
    if expr.name.module:
        cfg_host_name = sn.name_from_string(expr.name.module)
        cfg_host_type = ctx.env.get_schema_type_and_track(
            cfg_host_name, default=None)
        is_ext_config = bool(cfg_host_type)

    abstract_config = ctx.env.get_schema_type_and_track(
        sn.QualName('cfg', 'AbstractConfig'))
    ext_config = ctx.env.get_schema_type_and_track(
        sn.QualName('cfg', 'ExtensionConfig'))

    if not cfg_host_type:
        cfg_host_type = abstract_config

    name = fullname = expr.name.name
    if is_ext_config:
        fullname = f'{cfg_host_type.get_name(ctx.env.schema)}::{name}'

    assert isinstance(cfg_host_type, s_objtypes.ObjectType)
    cfg_type = None
    ptr = None

    if isinstance(expr, (qlast.ConfigSet, qlast.ConfigReset)):
        if is_ext_config and expr.scope == qltypes.ConfigScope.INSTANCE:
            raise errors.ConfigurationError(
                'INSTANCE configuration of extension-defined config variables '
                'is not allowed'
            )

        # expr.name is the actual name of the property.
        ptr = cfg_host_type.maybe_get_ptr(ctx.env.schema, sn.UnqualName(name))
        if ptr is not None:
            cfg_type = ptr.get_target(ctx.env.schema)

    if cfg_type is None:
        if isinstance(expr, qlast.ConfigSet):
            raise errors.ConfigurationError(
                f'unrecognized configuration parameter {name!r}',
                span=expr.span
            )

        cfg_type = ctx.env.get_schema_type_and_track(
            s_utils.ast_ref_to_name(expr.name), default=None)
        if not cfg_type and not expr.name.module:
            # expr.name is the name of the configuration type
            cfg_type = ctx.env.get_schema_type_and_track(
                sn.QualName('cfg', name), default=None)
        if not cfg_type:
            raise errors.ConfigurationError(
                f'unrecognized configuration object {name!r}',
                span=expr.span
            )

        assert isinstance(cfg_type, s_objtypes.ObjectType)
        ptr_candidate: Optional[s_pointers.Pointer] = None

        mro = [cfg_type] + list(
            cfg_type.get_ancestors(ctx.env.schema).objects(ctx.env.schema))
        for ct in mro:
            ptrs = ctx.env.schema.get_referrers(
                ct, scls_type=s_links.Link, field_name='target')

            if ptrs:
                pointer_link = next(iter(ptrs))
                assert isinstance(pointer_link, s_links.Link)
                ptr_candidate = pointer_link
                break

        if (
            ptr_candidate is None
            or (ptr_source := ptr_candidate.get_source(ctx.env.schema)) is None
            or not ptr_source.issubclass(
                ctx.env.schema, (abstract_config, ext_config))
        ):
            raise errors.ConfigurationError(
                f'{name!r} cannot be configured directly'
            )

        ptr = ptr_candidate

        fullname = ptr.get_shortname(ctx.env.schema).name
        if ptr_source.issubclass(ctx.env.schema, ext_config):
            fullname = f'{ptr_source.get_name(ctx.env.schema)}::{fullname}'

    assert isinstance(ptr, s_pointers.Pointer)

    sys_attr = ptr.get_annotations(ctx.env.schema).get(
        ctx.env.schema, sn.QualName('cfg', 'system'), None)

    system = (
        sys_attr is not None
        and sys_attr.get_value(ctx.env.schema) == 'true'
    )

    cardinality = ptr.get_cardinality(ctx.env.schema)
    assert cardinality is not None

    restart_attr = ptr.get_annotations(ctx.env.schema).get(
        ctx.env.schema, sn.QualName('cfg', 'requires_restart'), None)

    requires_restart = (
        restart_attr is not None
        and restart_attr.get_value(ctx.env.schema) == 'true'
    )

    backend_attr = ptr.get_annotations(ctx.env.schema).get(
        ctx.env.schema, sn.QualName('cfg', 'backend_setting'), None)

    if backend_attr is not None:
        backend_setting = json.loads(backend_attr.get_value(ctx.env.schema))
    else:
        backend_setting = None

    # `cfg::system` was already looked up above to compute `system`.
    is_system_config = system

    compilation_attr = ptr.get_annotations(ctx.env.schema).get(
        ctx.env.schema, sn.QualName('cfg', 'affects_compilation'), None)

    if compilation_attr is not None:
        affects_compilation = (
            json.loads(compilation_attr.get_value(ctx.env.schema)))
    else:
        affects_compilation = False

    if system and expr.scope is not qltypes.ConfigScope.INSTANCE:
        raise errors.ConfigurationError(
            f'{name!r} is a system-level configuration parameter; '
            f'use "CONFIGURE INSTANCE"')

    return SettingInfo(param_name=fullname,
                       param_type=cfg_type,
                       cardinality=cardinality,
                       required=False,
                       requires_restart=requires_restart,
                       backend_setting=backend_setting,
                       is_system_config=is_system_config,
                       affects_compilation=affects_compilation,
                       ptr=ptr)


================================================
FILE: edb/edgeql/compiler/config_desc.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""Implementation of DESCRIBE ... CONFIG"""

from __future__ import annotations

import textwrap

from edb.edgeql import parser as qlparser
from edb.edgeql import qltypes
from edb.edgeql import quote as qlquote

from . import context
from . import dispatch

from edb.ir import ast as irast

from edb.schema import name as s_name
from edb.schema import objtypes as s_objtypes
from edb.schema import scalars as s_scalars
from edb.schema import schema as s_schema
from edb.schema import types as s_types

from edb.pgsql import common

ql = common.quote_literal


def compile_describe_config(
    scope: qltypes.ConfigScope, ctx: context.ContextLevel
) -> irast.Set:
    config_edgeql = _describe_config(
        ctx.env.schema, scope, ctx.env.options.testmode)
    config_ast = qlparser.parse_fragment(config_edgeql)

    with ctx.new() as subctx:
        subctx.allow_factoring()
        return dispatch.compile(config_ast, ctx=subctx)


def _describe_config(
    schema: s_schema.Schema,
    scope: qltypes.ConfigScope,
    testmode: bool,
) -> str:
    """Generate an EdgeQL query to render config as DDL."""

    if scope is qltypes.ConfigScope.INSTANCE:
        source = 'system override'
        config_object_name = 'cfg::InstanceConfig'
    elif scope is qltypes.ConfigScope.DATABASE:
        source = 'database'
        config_object_name = 'cfg::DatabaseConfig'
    else:
        raise AssertionError(f'unexpected configuration source: {scope!r}')

    cfg = schema.get(config_object_name, type=s_objtypes.ObjectType)
    items = []
    items.extend(_describe_config_inner(
        schema, scope, config_object_name, cfg, testmode
    ))
    ext = schema.get('cfg::ExtensionConfig', type=s_objtypes.ObjectType)
    for ext_cfg in sorted(
        ext.descendants(schema), key=lambda x: x.get_name(schema)
    ):
        items.extend(_describe_config_inner(
            schema, scope, config_object_name, ext_cfg, testmode
        ))

    testmode_check = (
        "<bool>json_get("
        "cfg::get_config_json(),'__internal_testmode','value')"
        " ?? false"
    )
    query = (
        "assert_exists(assert_single(("
        + f"FOR conf IN {{cfg::get_config_json(sources := [{ql(source)}])}} "
        + "UNION (\n"
        + (f"FOR testmode IN {{{testmode_check}}} UNION (\n"
           if testmode else "")
        + "SELECT array_join([" + ', '.join(items) + "], '')"
        + (")" if testmode else "")
        + ")"
        + ")))"
    )
    return query


def _describe_config_inner(
    schema: s_schema.Schema,
    scope: qltypes.ConfigScope,
    config_object_name: str,
    cfg: s_objtypes.ObjectType,
    testmode: bool,
) -> list[str]:
    """Generate EdgeQL expressions rendering each setting of `cfg` as DDL."""

    actual_name = str(cfg.get_name(schema))
    cast = (
        f'.extensions[is {actual_name}]' if actual_name != config_object_name
        else ''
    )

    items = []
    for ptr_name, p in sorted(
        cfg.get_pointers(schema).items(schema),
        key=lambda x: x[0],
    ):
        pn = str(ptr_name)
        if (
            pn == 'id'
            or p.get_computable(schema)
            or p.get_protected(schema)
        ):
            continue

        is_internal = (
            p.get_annotation(
                schema,
                s_name.QualName('cfg', 'internal')
            ) == 'true'
        )
        if is_internal and not testmode:
            continue

        ptype = p.get_target(schema)
        assert ptype is not None

        # Skip backlinks to the base object. They will get plenty of
        # special treatment.
        if str(ptype.get_name(schema)) == 'cfg::AbstractConfig':
            continue

        ptr_card = p.get_cardinality(schema)
        mult = ptr_card.is_multi()
        psource = f'{config_object_name}{cast}.{qlquote.quote_ident(pn)}'
        if isinstance(ptype, s_objtypes.ObjectType):
            item = textwrap.indent(
                _render_config_object(
                    schema=schema,
                    valtype=ptype,
                    value_expr=psource,
                    scope=scope,
                    join_term='',
                    level=1,
                ),
                ' ' * 4,
            )
        else:
            fn = (
                pn if actual_name == config_object_name
                else f'{actual_name}::{pn}'
            )
            renderer = (
                _render_config_redacted if p.get_secret(schema)
                else _render_config_set if mult
                else _render_config_scalar
            )
            item = textwrap.indent(
                renderer(
                    schema=schema,
                    valtype=ptype,
                    value_expr=psource,
                    name=fn,
                    scope=scope,
                    level=1,
                ),
                ' ' * 4,
            )

        fpn = f'{actual_name}::{pn}' if cast else pn

        condition = f'EXISTS json_get(conf, {ql(fpn)})'
        if is_internal:
            condition = f'({condition}) AND testmode'
        # For INSTANCE, filter out configs that are set to the default.
        # This is because we currently implement the defaults by
        # setting them with CONFIGURE INSTANCE, so we can't detect
        # defaults by seeing what is unset.
        if (
            scope == qltypes.ConfigScope.INSTANCE
            and (default := p.get_default(schema))
        ):
            condition = f'({condition}) AND {psource} ?!= ({default.text})'

        items.append(f"(\n{item}\n    IF {condition} ELSE ''\n  )")

    return items


def _render_config_value(
    *,
    schema: s_schema.Schema,
    valtype: s_types.Type,
    value_expr: str,
) -> str:
    if valtype.issubclass(
        schema,
        schema.get('std::anyreal', type=s_scalars.ScalarType),
    ):
        val = f'<str>{value_expr}'
    elif valtype.issubclass(
        schema,
        schema.get('std::bool', type=s_scalars.ScalarType),
    ):
        val = f'<str>{value_expr}'
    elif valtype.issubclass(
        schema,
        schema.get('std::duration', type=s_scalars.ScalarType),
    ):
        val = f'"<std::duration>" ++ cfg::_quote({value_expr})'
    elif valtype.issubclass(
        schema,
        schema.get('cfg::memory', type=s_scalars.ScalarType),
    ):
        val = f'"<cfg::memory>" ++ cfg::_quote({value_expr})'
    elif valtype.issubclass(
        schema,
        schema.get('std::str', type=s_scalars.ScalarType),
    ):
        val = f'cfg::_quote({value_expr})'
    elif valtype.is_enum(schema):
        tn = valtype.get_name(schema)
        val = f'"<{str(tn)}>" ++ cfg::_quote({value_expr})'
    else:
        raise AssertionError(
            f'unexpected configuration value type: '
            f'{valtype.get_displayname(schema)}'
        )

    return val


def _render_config_redacted(
    *,
    schema: s_schema.Schema,
    valtype: s_types.Type,
    value_expr: str,
    scope: qltypes.ConfigScope,
    name: str,
    level: int,
) -> str:
    if level == 1:
        return (
            f"'CONFIGURE {scope.to_edgeql()} "
            f"SET {qlquote.quote_ident(name)} := {{}};  # REDACTED\\n'"
        )
    else:
        indent = ' ' * (4 * (level - 1))
        return f"'{indent}{qlquote.quote_ident(name)} := {{}},  # REDACTED'"


def _render_config_set(
    *,
    schema: s_schema.Schema,
    valtype: s_types.Type,
    value_expr: str,
    scope: qltypes.ConfigScope,
    name: str,
    level: int,
) -> str:
    assert isinstance(valtype, s_scalars.ScalarType)
    v = _render_config_value(
        schema=schema, valtype=valtype, value_expr=value_expr)
    if level == 1:
        return (
            f"'CONFIGURE {scope.to_edgeql()} "
            f"SET {qlquote.quote_ident(name)} := {{' ++ "
            f"array_join(array_agg((SELECT _ := {v} ORDER BY _)), ', ') "
            f"++ '}};\\n'"
        )
    else:
        indent = ' ' * (4 * (level - 1))
        return (
            f"'{indent}{qlquote.quote_ident(name)} := {{' ++ "
            f"array_join(array_agg((SELECT _ := {v} ORDER BY _)), ', ') "
            f"++ '}},'"
        )


def _render_config_scalar(
    *,
    schema: s_schema.Schema,
    valtype: s_types.Type,
    value_expr: str,
    scope: qltypes.ConfigScope,
    name: str,
    level: int,
) -> str:
    assert isinstance(valtype, s_scalars.ScalarType)
    v = _render_config_value(
        schema=schema, valtype=valtype, value_expr=value_expr)
    if level == 1:
        return (
            f"'CONFIGURE {scope.to_edgeql()} "
            f"SET {qlquote.quote_ident(name)} := ' ++ {v} ++ ';\\n'"
        )
    else:
        indent = ' ' * (4 * (level - 1))
        return f"'{indent}{qlquote.quote_ident(name)} := ' ++ {v} ++ ','"


def _render_config_object(
    *,
    schema: s_schema.Schema,
    valtype: s_objtypes.ObjectType,
    value_expr: str,
    scope: qltypes.ConfigScope,
    join_term: str,
    level: int,
) -> str:
    # Generate a valid `CONFIGURE <scope> INSERT ConfigObject`
    # shape for a given configuration object type or
    # `INSERT ConfigObject` for a nested configuration type.
    sub_layouts = _describe_config_object(
        schema=schema, valtype=valtype, level=level + 1, scope=scope)
    sub_layouts_items = []
    if level == 1:
        decor = [f'CONFIGURE {scope.to_edgeql()} INSERT ', ';\\n']
    else:
        decor = ['(INSERT ', ')']

    indent = ' ' * (4 * (level - 1))

    for type_name, type_layout in sub_layouts.items():
        if type_layout:
            sub_layout_item = (
                f"'{indent}{decor[0]}{type_name} {{\\n'\n++ "
                + "\n++ ".join(type_layout)
                + f" ++ '{indent}}}{decor[1]}'"
            )
        else:
            sub_layout_item = (
                f"'{indent}{decor[0]}{type_name}{decor[1]}'"
            )

        if len(sub_layouts) > 1:
            if type_layout:
                sub_layout_item = (
                    f'(WITH item := item[IS {type_name}]'
                    f' SELECT {sub_layout_item}) '
                    f'IF item.__type__.name = {ql(str(type_name))}'
                )
            else:
                sub_layout_item = (
                    f'{sub_layout_item} '
                    f'IF item.__type__.name = {ql(str(type_name))}'
                )

        sub_layouts_items.append(sub_layout_item)

    if len(sub_layouts_items) > 1:
        sli_render = '\nELSE '.join(sub_layouts_items) + "\nELSE ''"
    else:
        sli_render = sub_layouts_items[0]

    return '\n'.join((
        "array_join(array_agg((SELECT _ := (",
        f"  FOR item IN {{ {value_expr} }}",
        "  UNION (",
        textwrap.indent(sli_render, ' ' * 4),
        "  )",
        f") ORDER BY _)), {ql(join_term)})",
    ))


def _describe_config_object(
    *,
    schema: s_schema.Schema,
    valtype: s_objtypes.ObjectType,
    level: int,
    scope: qltypes.ConfigScope,
) -> dict[s_name.QualName, list[str]]:
    cfg_types = [valtype]
    cfg_types.extend(cfg_types[0].descendants(schema))
    layouts = {}
    for cfg in cfg_types:
        items = []
        for ptr_name, p in sorted(
            cfg.get_pointers(schema).items(schema),
            key=lambda x: x[0],
        ):
            pn = str(ptr_name)
            if (
                pn == 'id'
                or p.get_protected(schema)
                or p.get_annotation(
                    schema,
                    s_name.QualName('cfg', 'internal'),
                ) == 'true'
            ):
                continue

            ptype = p.get_target(schema)
            assert ptype is not None
            if str(ptype.get_name(schema)) == 'cfg::AbstractConfig':
                continue

            ptr_card = p.get_cardinality(schema)
            mult = ptr_card.is_multi()
            psource = f'item.{qlquote.quote_ident(pn)}'

            if isinstance(ptype, s_objtypes.ObjectType):
                rval = textwrap.indent(
                    _render_config_object(
                        schema=schema,
                        valtype=ptype,
                        value_expr=psource,
                        scope=scope,
                        join_term=' UNION ',
                        level=level + 1,
                    ),
                    ' ' * 2,
                ).strip()
                indent = ' ' * (4 * (level - 1))
                item = (
                    f"'{indent}{qlquote.quote_ident(pn)} "
                    f":= (\\n'\n++ {rval} ++ '\\n{indent}),\\n'"
                )
                condition = None
            else:
                render = (
                    _render_config_redacted if p.get_secret(schema)
                    else _render_config_set if mult
                    else _render_config_scalar
                )
                item = render(
                    schema=schema,
                    valtype=ptype,
                    value_expr=psource,
                    scope=scope,
                    name=pn,
                    level=level,
                )
                if p.get_secret(schema):
                    condition = 'true'
                else:
                    condition = f'EXISTS {psource}'

            if condition is not None:
                item = f"({item} ++ '\\n' IF {condition} ELSE '')"

            items.append(item)

        layouts[cfg.get_name(schema)] = items

    return layouts


================================================
FILE: edb/edgeql/compiler/conflicts.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""Compilation of DML exclusive constraint conflict handling."""


from __future__ import annotations
from typing import Optional, Iterable, Sequence

from edb import errors

from edb.ir import ast as irast
from edb.ir import utils as irutils
from edb.ir import typeutils

from edb.schema import constraints as s_constr
from edb.schema import name as s_name
from edb.schema import links as s_links
from edb.schema import objtypes as s_objtypes
from edb.schema import pointers as s_pointers
from edb.schema import utils as s_utils
from edb.schema import expr as s_expr

from edb.edgeql import ast as qlast
from edb.edgeql import utils as qlutils
from edb.edgeql import qltypes

from . import astutils
from . import context
from . import dispatch
from . import inference
from . import pathctx
from . import schemactx
from . import setgen
from . import typegen


def _get_needed_ptrs(
    subject_typ: s_objtypes.ObjectType,
    obj_constrs: Sequence[s_constr.Constraint],
    initial_ptrs: Iterable[str],
    rewrite_kind: Optional[qltypes.RewriteKind],
    ctx: context.ContextLevel,
) -> tuple[set[str], dict[str, qlast.Expr]]:
    """Find all the pointers needed by a list of constraints and pointers.

    This chases down computed pointer definitions, rewrites, and
    constraint expressions.
    """
    needed_ptrs = set(initial_ptrs)
    for constr in obj_constrs:
        subjexpr: Optional[s_expr.Expression] = (
            constr.get_subjectexpr(ctx.env.schema)
        )
        assert subjexpr
        needed_ptrs |= qlutils.find_subject_ptrs(subjexpr.parse())
        if except_expr := constr.get_except_expr(ctx.env.schema):
            assert isinstance(except_expr, s_expr.Expression)
            needed_ptrs |= qlutils.find_subject_ptrs(except_expr.parse())

    wl = list(needed_ptrs)
    ptr_anchors = {}
    while wl:
        p = wl.pop()
        ptr = subject_typ.getptr(ctx.env.schema, s_name.UnqualName(p))
        exprs = []
        if expr := ptr.get_expr(ctx.env.schema):
            exprs.append(expr)
        if rewrite_kind and (
            rewrite := ptr.get_rewrite(ctx.env.schema, rewrite_kind)
        ):
            exprs.append(rewrite.get_expr(ctx.env.schema))
        for expr in exprs:
            assert isinstance(expr.parse(), qlast.Expr)
            ptr_anchors[p] = expr.parse()
            for ref in qlutils.find_subject_ptrs(expr.parse()):
                if ref not in needed_ptrs:
                    wl.append(ref)
                    needed_ptrs.add(ref)

    return needed_ptrs, ptr_anchors


def _get_rewrite_kind(stmt: irast.MutatingStmt) -> qltypes.RewriteKind | None:
    return (
        qltypes.RewriteKind.Insert
        if isinstance(stmt, irast.InsertStmt)
        else qltypes.RewriteKind.Update
        if isinstance(stmt, irast.UpdateStmt)
        else None
    )


def _get_rewritten_ptrs(
    stmt: irast.MutatingStmt,
    subject_typ: s_objtypes.ObjectType,
    *,
    ctx: context.ContextLevel,
) -> set[str]:
    schema = ctx.env.schema
    rewrite_kind = _get_rewrite_kind(stmt)

    rewritten = set()
    for ptr in subject_typ.get_pointers(schema).objects(schema):
        if rewrite_kind:
            rewrite = ptr.get_rewrite(ctx.env.schema, rewrite_kind)
            if rewrite:
                rewritten.add(ptr.get_shortname(schema).name)

    return rewritten


def _compile_conflict_select_for_obj_type(
    stmt: irast.MutatingStmt,
    subject_typ: s_objtypes.ObjectType,
    *,
    for_inheritance: bool,
    fake_dml_set: Optional[irast.Set],
    obj_constrs: Sequence[s_constr.Constraint],
    constrs: dict[str, tuple[s_pointers.Pointer, list[s_constr.Constraint]]],
    span: Optional[irast.Span],
    ctx: context.ContextLevel,
) -> tuple[Optional[qlast.Expr], bool]:
    """Synthesize a select of conflicting objects

    ... for a single object type. This gets called once for each ancestor
    type that provides constraints to the type being inserted.

    `obj_constrs` and `constrs` contain the object-level and
    per-pointer constraints to consider.
    """
    # We have a fake_dml_set to represent the root exactly when we are
    # compiling this for inheritance checking reasons (and not for
    # real UNLESS CONFLICTs)
    assert for_inheritance == bool(fake_dml_set)

    rewrite_kind = _get_rewrite_kind(stmt)

    # Find which pointers we need to grab
    needed_ptrs, ptr_anchors = _get_needed_ptrs(
        subject_typ, obj_constrs, constrs.keys(), rewrite_kind, ctx=ctx
    )

    # Check that no pointers in constraints are rewritten
    if rewrite_kind and not fake_dml_set:
        for p in needed_ptrs:
            ptr = subject_typ.getptr(ctx.env.schema, s_name.UnqualName(p))
            rewrite = ptr.get_rewrite(ctx.env.schema, rewrite_kind)
            if rewrite:
                raise errors.UnsupportedFeatureError(
                    "INSERT UNLESS CONFLICT cannot be used on properties or "
                    "links that have a rewrite rule specified",
                    span=span,
                )

    ctx.anchors = ctx.anchors.copy()

    # If we are given a fake_dml_set to directly represent the result
    # of our DML, use that instead of populating the result.
    # TODO: XXX: always use fake_dml_set??
    # (would we need to still disallow MUTATING properties?)
    if fake_dml_set:
        for p in needed_ptrs | {'id', '__type__'}:
            ptr = subject_typ.getptr(ctx.env.schema, s_name.UnqualName(p))
            val = setgen.extend_path(fake_dml_set, ptr, ctx=ctx)

            ptr_anchors[p] = ctx.create_anchor(val, p)

    # Find the IR corresponding to the fields we care about and
    # produce anchors for them
    ptrs_in_shape = set()
    for elem, _ in stmt.subject.shape:
        rptr = elem.expr
        name = rptr.ptrref.shortname.name
        ptrs_in_shape.add(name)
        if name in needed_ptrs and name not in ptr_anchors:
            assert rptr.expr
            # We don't properly support hoisting volatile properties out of
            # UNLESS CONFLICT, so disallow it. We *do* support handling DML
            # there, since that gets hoisted into CTEs via its own mechanism.
            # See issue #1699.
            if inference.infer_volatility(
                rptr.expr, ctx.env, exclude_dml=True
            ).is_volatile():
                assert not for_inheritance
                raise errors.UnsupportedFeatureError(
                    'INSERT UNLESS CONFLICT ON does not support volatile '
                    'properties',
                    span=span,
                )

            # We want to use the same path_scope_id as the original
            elem_set = setgen.ensure_set(rptr.expr, ctx=ctx)
            elem_set.path_scope_id = elem.path_scope_id

            # FIXME: The wrong thing will definitely happen if there are
            # volatile entries here
            ptr_anchors[name] = ctx.create_anchor(elem_set, name)

    if for_inheritance and not ptrs_in_shape:
        return None, False

    # Fill in empty sets for pointers that are needed but not present
    present_ptrs = set(ptr_anchors)
    for p in (needed_ptrs - present_ptrs):
        ptr = subject_typ.getptr(ctx.env.schema, s_name.UnqualName(p))
        typ = ptr.get_target(ctx.env.schema)
        assert typ
        ptr_anchors[p] = qlast.TypeCast(
            expr=qlast.Set(elements=[]),
            type=typegen.type_to_ql_typeref(typ, ctx=ctx))

    if not ptr_anchors:
        raise errors.QueryError(
            'INSERT UNLESS CONFLICT property requires matching shape',
            span=span,
        )

    conds: list[qlast.Expr] = []
    for ptrname, (ptr, ptr_cnstrs) in constrs.items():
        if ptrname not in present_ptrs:
            continue
        anchor = qlutils.subject_paths_substitute(
            ptr_anchors[ptrname], ptr_anchors)
        ptr_val = qlast.Path(partial=True, steps=[
            qlast.Ptr(name=ptrname)
        ])
        ptr_card = ptr.get_cardinality(ctx.env.schema)

        for cnstr in ptr_cnstrs:
            lhs: qlast.Expr = anchor
            rhs: qlast.Expr = ptr_val
            # If there is a subjectexpr, substitute our lhs and rhs in
            # for __subject__ in the subjectexpr and compare *that*
            if (subjectexpr := cnstr.get_subjectexpr(ctx.env.schema)):
                assert isinstance(subjectexpr, s_expr.Expression)
                assert isinstance(subjectexpr.parse(), qlast.Expr)
                lhs = qlutils.subject_substitute(subjectexpr.parse(), lhs)
                rhs = qlutils.subject_substitute(subjectexpr.parse(), rhs)

            conds.append(qlast.BinOp(
                op='=' if ptr_card.is_single() else 'IN',
                left=lhs, right=rhs,
            ))

    # If the type we are looking at is BaseObject, then this must be a
    # conflict check we are synthesizing for an explicit .id. We need
    # to ignore access policies in that case, since there is no
    # trigger to back us up.
    # (We can't insert directly into the abstract BaseObject, so this
    # is a safe assumption.)
    ignore_rewrites = (
        str(subject_typ.get_name(ctx.env.schema)) == 'std::BaseObject')
    if ignore_rewrites:
        assert not obj_constrs
        assert len(constrs) == 1 and len(constrs['id'][1]) == 1
    insert_subject = ctx.create_anchor(setgen.class_set(
        subject_typ, ignore_rewrites=ignore_rewrites, ctx=ctx
    ))

    for constr in obj_constrs:
        subject_expr: Optional[s_expr.Expression] = (
            constr.get_subjectexpr(ctx.env.schema)
        )
        assert subject_expr and isinstance(subject_expr.parse(), qlast.Expr)
        lhs = qlutils.subject_paths_substitute(
            subject_expr.parse(), ptr_anchors
        )
        rhs = qlutils.subject_substitute(
            subject_expr.parse(), insert_subject
        )
        op = qlast.BinOp(op='=', left=lhs, right=rhs)

        # If there is an except expr, we need to add in those checks also
        if except_expr := constr.get_except_expr(ctx.env.schema):
            assert isinstance(except_expr, s_expr.Expression)

            e_lhs = qlutils.subject_paths_substitute(
                except_expr.parse(), ptr_anchors)
            e_rhs = qlutils.subject_substitute(
                except_expr.parse(), insert_subject)

            true_ast = qlast.Constant.boolean(True)
            on = qlast.BinOp(
                op='AND',
                left=qlast.BinOp(op='?!=', left=e_lhs, right=true_ast),
                right=qlast.BinOp(op='?!=', left=e_rhs, right=true_ast),
            )
            op = qlast.BinOp(op='AND', left=op, right=on)

        conds.append(op)

    if not conds:
        return None, False

    # We use `any` to compute the disjunction here because some might
    # be empty.
    if len(conds) == 1:
        cond = conds[0]
    else:
        cond = qlast.FunctionCall(
            func='any',
            args=[qlast.Set(elements=conds)],
        )

    # For the result filtering we ignore any objects from the same type.
    if fake_dml_set:
        anchor = qlutils.subject_paths_substitute(
            ptr_anchors['__type__'], ptr_anchors)
        anchor_val = qlast.Path(steps=[anchor, qlast.Ptr(name='id')])
        ptr_val = qlast.Path(
            partial=True,
            steps=[qlast.Ptr(name='__type__'), qlast.Ptr(name='id')],
        )
        cond = qlast.BinOp(
            op='AND',
            left=cond,
            right=qlast.BinOp(op='!=', left=anchor_val, right=ptr_val),
        )

    # Produce a query that finds the conflicting objects
    select_ast = qlast.DetachedExpr(
        expr=qlast.SelectQuery(result=insert_subject, where=cond)
    )

    # If one of the pointers we care about is multi, then we have to always
    # use a conflict CTE check instead of trying to use a constraint.
    has_multi = False
    for ptrname in needed_ptrs:
        ptr = subject_typ.getptr(ctx.env.schema, s_name.UnqualName(ptrname))
        if not ptr.get_cardinality(ctx.env.schema).is_single():
            has_multi = True

    return select_ast, has_multi


def _constr_matters(
    constr: s_constr.Constraint,
    *,
    only_local: bool = False,
    ctx: context.ContextLevel,
) -> bool:
    schema = ctx.env.schema
    return (
        not constr.is_non_concrete(schema)
        and not constr.get_delegated(schema)
        and (
            # In some use sites we always process ancestor constraints
            # too, so in those cases a constraint only matters if it
            # is the "top" constraint where it actually starts
            # applying.
            not only_local
            or constr.get_owned(schema)
            or all(
                anc.get_delegated(schema) or anc.is_non_concrete(schema)
                for anc in constr.get_ancestors(schema).objects(schema)
            )
        )
    )


PointerConstraintMap = dict[
    str,
    tuple[s_pointers.Pointer, list[s_constr.Constraint]],
]
ConstraintPair = tuple[PointerConstraintMap, list[s_constr.Constraint]]
ConflictTypeMap = dict[s_objtypes.ObjectType, ConstraintPair]


def _split_constraints(
    obj_constrs: Sequence[s_constr.Constraint],
    constrs: PointerConstraintMap,
    ctx: context.ContextLevel,
) -> ConflictTypeMap:
    schema = ctx.env.schema

    type_maps: ConflictTypeMap = {}

    # Split up pointer constraints by what object types they come from
    for name, (_, p_constrs) in constrs.items():
        for p_constr in p_constrs:
            ancs = (p_constr,) + p_constr.get_ancestors(schema).objects(schema)
            for anc in ancs:
                if not _constr_matters(anc, only_local=True, ctx=ctx):
                    continue
                p_ptr = anc.get_subject(schema)
                assert isinstance(p_ptr, s_pointers.Pointer)
                obj = p_ptr.get_source(schema)
                assert isinstance(obj, s_objtypes.ObjectType)
                map, _ = type_maps.setdefault(obj, ({}, []))
                _, entry = map.setdefault(name, (p_ptr, []))
                entry.append(anc)

    # Split up object constraints by what object types they come from
    for obj_constr in obj_constrs:
        ancs = (obj_constr,) + obj_constr.get_ancestors(schema).objects(schema)
        for anc in ancs:
            if not _constr_matters(anc, only_local=True, ctx=ctx):
                continue
            obj = anc.get_subject(schema)
            assert isinstance(obj, s_objtypes.ObjectType)
            _, o_constr_entry = type_maps.setdefault(obj, ({}, []))
            o_constr_entry.append(anc)

    return type_maps


def _compile_conflict_select(
    stmt: irast.MutatingStmt,
    subject_typ: s_objtypes.ObjectType,
    *,
    for_inheritance: bool = False,
    fake_dml_set: Optional[irast.Set] = None,
    obj_constrs: Sequence[s_constr.Constraint],
    constrs: PointerConstraintMap,
    span: Optional[irast.Span],
    ctx: context.ContextLevel,
) -> tuple[irast.Set, bool, bool]:
    """Synthesize a select of conflicting objects

    This teases apart the constraints we care about based on which
    type they originate from, generates a SELECT for each type, and
    unions them together.

    `constrs` and `obj_constrs` contain the constraints to consider.
    """
    schema = ctx.env.schema

    if for_inheritance:
        type_maps = {subject_typ: (constrs, list(obj_constrs))}
    else:
        type_maps = _split_constraints(obj_constrs, constrs, ctx=ctx)

    always_check = False

    # Generate a separate query for each type
    from_parent = False
    frags = []
    for a_obj, (a_constrs, a_obj_constrs) in type_maps.items():
        frag, frag_always_check = _compile_conflict_select_for_obj_type(
            stmt, a_obj, obj_constrs=a_obj_constrs, constrs=a_constrs,
            for_inheritance=for_inheritance,
            fake_dml_set=fake_dml_set,
            span=span, ctx=ctx,
        )
        always_check |= frag_always_check
        if frag:
            if a_obj != subject_typ:
                from_parent = True
            frags.append(frag)

    always_check |= from_parent or any(
        not child.is_view(schema) for child in subject_typ.children(schema)
    )

    # Union them all together
    select_ast = qlast.Set(elements=frags)
    with ctx.new() as ectx:
        ectx.allow_factoring()

        ectx.implicit_limit = 0
        ectx.allow_endpoint_linkprops = True
        select_ir = dispatch.compile(select_ast, ctx=ectx)
        select_ir = setgen.scoped_set(
            select_ir, force_reassign=True, ctx=ectx)
        assert isinstance(select_ir, irast.Set)

    # If we have an empty set, remake it with the right type
    if isinstance(select_ir.expr, irast.EmptySet):
        select_ir = setgen.new_empty_set(stype=subject_typ, ctx=ctx)

    return select_ir, always_check, from_parent


def _get_exclusive_ptr_constraints(
    typ: s_objtypes.ObjectType,
    include_id: bool,
    *, ctx: context.ContextLevel,
) -> dict[str, tuple[s_pointers.Pointer, list[s_constr.Constraint]]]:
    schema = ctx.env.schema
    pointers = {}

    exclusive_constr = schema.get('std::exclusive', type=s_constr.Constraint)
    for ptr in typ.get_pointers(schema).objects(schema):
        ptr = ptr.get_nearest_non_derived_parent(schema)
        ex_cnstrs = [c for c in ptr.get_constraints(schema).objects(schema)
                     if c.issubclass(schema, exclusive_constr)]
        if ex_cnstrs:
            name = ptr.get_shortname(schema).name
            if name != 'id' or include_id:
                pointers[name] = ptr, ex_cnstrs

    return pointers


def compile_insert_unless_conflict(
    stmt: irast.InsertStmt,
    typ: s_objtypes.ObjectType,
    *, ctx: context.ContextLevel,
) -> irast.OnConflictClause:
    """Compile an UNLESS CONFLICT clause with no ON

    This requires synthesizing a conditional based on all the exclusive
    constraints on the object.
    """
    has_id_write = _has_explicit_id_write(stmt)
    pointers = _get_exclusive_ptr_constraints(
        typ, include_id=has_id_write, ctx=ctx)
    obj_constrs = typ.get_constraints(ctx.env.schema).objects(ctx.env.schema)

    select_ir, always_check, _ = _compile_conflict_select(
        stmt, typ,
        constrs=pointers,
        obj_constrs=obj_constrs,
        span=stmt.span, ctx=ctx)

    return irast.OnConflictClause(
        constraint=None, select_ir=select_ir, always_check=always_check,
        else_ir=None)


def compile_insert_unless_conflict_on(
    stmt: irast.InsertStmt,
    typ: s_objtypes.ObjectType,
    constraint_spec: qlast.Expr,
    else_branch: Optional[qlast.Expr],
    *, ctx: context.ContextLevel,
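) -> irast.OnConflictClause:
    """Compile an UNLESS CONFLICT ON clause, with an optional ELSE branch

    The ON argument must be a property or link of the type being
    inserted (or a tuple of them), and it must be covered by exactly
    one exclusive constraint.
    """
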
    with ctx.new() as constraint_ctx:
        constraint_ctx.partial_path_prefix = setgen.class_set(typ, ctx=ctx)

        # We compile the name here so we can analyze it, but we don't do
        # anything else with it.
        cspec_res = dispatch.compile(constraint_spec, ctx=constraint_ctx)

    # We accept a property, link, or a list of them in the form of a
    # tuple.
    if isinstance(cspec_res.expr, irast.Tuple):
        cspec_args = [elem.val for elem in cspec_res.expr.elements]
    else:
        cspec_args = [cspec_res]

    cspec_args = [irutils.unwrap_set(arg) for arg in cspec_args]

    for cspec_arg in cspec_args:
        if not isinstance(cspec_arg.expr, irast.Pointer):
            raise errors.QueryError(
                'UNLESS CONFLICT argument must be a property, link, '
                'or tuple of properties and links',
                span=constraint_spec.span,
            )

        if cspec_arg.expr.source.path_id != stmt.subject.path_id:
            raise errors.QueryError(
                'UNLESS CONFLICT argument must be a property of the '
                'type being inserted',
                span=constraint_spec.span,
            )

    schema = ctx.env.schema

    ptrs = []
    exclusive_constr = schema.get('std::exclusive', type=s_constr.Constraint)
    for cspec_arg in cspec_args:
        assert isinstance(cspec_arg.expr, irast.Pointer)
        schema, ptr = (
            typeutils.ptrcls_from_ptrref(cspec_arg.expr.ptrref, schema=schema))
        if not isinstance(ptr, s_pointers.Pointer):
            raise errors.QueryError(
                'UNLESS CONFLICT argument must be a property, link, '
                'or tuple of properties and links',
                span=constraint_spec.span,
            )

        ptr = ptr.get_nearest_non_derived_parent(schema)
        ptrs.append(ptr)

    obj_constrs = inference.cardinality.get_object_exclusive_constraints(
        typ, set(ptrs), ctx.env)

    field_constrs = []
    if len(ptrs) == 1:
        field_constrs = [
            c for c in ptrs[0].get_constraints(schema).objects(schema)
            if c.issubclass(schema, exclusive_constr)]

    all_constrs = list(obj_constrs) + field_constrs
    if len(all_constrs) != 1:
        raise errors.QueryError(
            'UNLESS CONFLICT property must have a single exclusive constraint',
            span=constraint_spec.span,
        )

    ds = {ptr.get_shortname(schema).name: (ptr, field_constrs)
          for ptr in ptrs}
    select_ir, always_check, from_anc = _compile_conflict_select(
        stmt, typ, constrs=ds, obj_constrs=list(obj_constrs),
        span=stmt.span, ctx=ctx)

    # Compile an else branch
    else_ir = None
    if else_branch:
        # TODO: We should support this, but there is some semantic and
        # implementation trickiness.
        if from_anc:
            raise errors.UnsupportedFeatureError(
                'UNLESS CONFLICT can not use ELSE when constraint is from a '
                'parent type',
                details=(
                    f"The existing object can't be exposed in the ELSE clause "
                    f"because it may not have type {typ.get_name(schema)}"),
                span=constraint_spec.span,
            )

        with ctx.new() as ectx:
            # The ELSE needs to be able to reference the subject in an
            # UPDATE, even though that would normally be prohibited.
            ectx.iterator_path_ids |= {stmt.subject.path_id}

            pathctx.ban_inserting_path(
                stmt.subject.path_id, location='else', ctx=ectx)

            # Compile else
            else_ir = dispatch.compile(
                astutils.ensure_ql_query(else_branch), ctx=ectx
            )
        assert isinstance(else_ir, irast.Set)

    return irast.OnConflictClause(
        constraint=irast.ConstraintRef(id=all_constrs[0].id),
        select_ir=select_ir,
        always_check=always_check,
        else_ir=else_ir
    )


def _has_explicit_id_write(stmt: irast.MutatingStmt) -> bool:
    for elem, _ in stmt.subject.shape:
        if elem.expr.ptrref.shortname.name == 'id':
            # We want to make sure it isn't an implicit id (which
            # won't have an expr) or a default value (which won't have
            # a span).
            #
            # ... it is at least a little dodgy to check for default
            # value by span presence.
            return elem.span is not None and elem.expr.expr is not None
    return False


def _disallow_exclusive_linkprops(
    stmt: irast.MutatingStmt,
    typ: s_objtypes.ObjectType,
    *, ctx: context.ContextLevel,
) -> None:
    # TODO: It should be possible to support this, but we don't deal
    # with it yet, so disallow it for safety reasons.
    schema = ctx.env.schema
    exclusive_constr = schema.get('std::exclusive', type=s_constr.Constraint)
    for ptr in typ.get_pointers(schema).objects(schema):
        if not isinstance(ptr, s_links.Link):
            continue
        ptr = ptr.get_nearest_non_derived_parent(schema)
        for lprop in ptr.get_pointers(schema).objects(schema):
            ex_cnstrs = [
                c for c in lprop.get_constraints(schema).objects(schema)
                if c.issubclass(schema, exclusive_constr)]
            if ex_cnstrs:
                raise errors.UnsupportedFeatureError(
                    'INSERT/UPDATE do not support exclusive constraints on '
                    'link properties when another statement in '
                    'the same query modifies a related type',
                    span=stmt.span,
                )


def _get_type_conflict_constraint_entries(
    stmt: irast.MutatingStmt,
    typ: s_objtypes.ObjectType,
    *, ctx: context.ContextLevel,
) -> list[tuple[s_constr.Constraint, ConstraintPair]]:
    # TODO: why do we return this in such a hinky way?
    rewrite_kind = _get_rewrite_kind(stmt)

    has_id_write = _has_explicit_id_write(stmt)
    pointers = _get_exclusive_ptr_constraints(
        typ, include_id=has_id_write, ctx=ctx)
    exclusive = ctx.env.schema.get('std::exclusive', type=s_constr.Constraint)
    obj_constrs = [
        constr for constr in
        typ.get_constraints(ctx.env.schema).objects(ctx.env.schema)
        if constr.issubclass(ctx.env.schema, exclusive)
    ]

    shape_ptrs = set()
    for elem, op in stmt.subject.shape:
        if op != qlast.ShapeOp.MATERIALIZE:
            shape_ptrs.add(elem.expr.ptrref.shortname.name)
    shape_ptrs |= _get_rewritten_ptrs(stmt, typ, ctx=ctx)

    # This is a little silly, but for *this* we need to do one per
    # constraint (so that we can properly identify which constraint
    # failed in the error messages)
    entries: list[tuple[s_constr.Constraint, ConstraintPair]] = []
    for name, (ptr, ptr_constrs) in pointers.items():
        for ptr_constr in ptr_constrs:
            # For updates, we only need to emit the check if we actually
            # modify a pointer used by the constraint. For inserts, though
            # everything must be in play, since constraints can depend on
            # nonexistence also.
            if (
                _constr_matters(ptr_constr, ctx=ctx)
                and (
                    isinstance(stmt, irast.InsertStmt)
                    or (
                        _get_needed_ptrs(typ, (), [name], rewrite_kind, ctx)[0]
                        & shape_ptrs
                    )
                )
            ):
                entries.append((ptr_constr, ({name: (ptr, [ptr_constr])}, [])))
    for obj_constr in obj_constrs:
        # See note above about needed ptrs check
        if (
            _constr_matters(obj_constr, ctx=ctx)
            and (
                isinstance(stmt, irast.InsertStmt)
                or (_get_needed_ptrs(
                    typ, [obj_constr], (), rewrite_kind, ctx)[0] & shape_ptrs)
            )
        ):
            entries.append((obj_constr, ({}, [obj_constr])))

    return entries


def _compile_inheritance_conflict_selects(
    stmt: irast.MutatingStmt,
    conflict: irast.MutatingStmt,
    typ: s_objtypes.ObjectType,
    subject_type: s_objtypes.ObjectType,
    *, ctx: context.ContextLevel,
) -> list[irast.OnConflictClause]:
    """Compile the selects needed to resolve multiple DML to related types

    Generate a SELECT that finds all objects of type `typ` that conflict with
    the insert `stmt`. The backend will use this to explicitly check that
    no conflicts exist, and raise an error if they do.

    This is needed because we mostly use triggers to enforce these
    cross-type exclusive constraints, and they use a snapshot
    beginning at the start of the statement.
    """
    _disallow_exclusive_linkprops(stmt, typ, ctx=ctx)
    entries = _get_type_conflict_constraint_entries(stmt, typ, ctx=ctx)

    # We need to pull from the actual result overlay,
    # since the final row can depend on things not in the query
    # (on updates always, on inserts due to rewrites).
    fake_subject = qlast.DetachedExpr(expr=qlast.Path(steps=[
        s_utils.name_to_ast_ref(subject_type.get_name(ctx.env.schema))]))

    fake_dml_set = dispatch.compile(fake_subject, ctx=ctx)

    clauses = []
    for cnstr, (p, o) in entries:
        select_ir, _, _ = _compile_conflict_select(
            stmt, typ,
            for_inheritance=True,
            fake_dml_set=fake_dml_set,
            constrs=p,
            obj_constrs=o,
            span=stmt.span, ctx=ctx)
        if isinstance(select_ir.expr, irast.EmptySet):
            continue
        cnstr_ref = irast.ConstraintRef(id=cnstr.id)
        clauses.append(
            irast.OnConflictClause(
                constraint=cnstr_ref, select_ir=select_ir, always_check=False,
                else_ir=None, else_fail=conflict,
                check_anchor=fake_dml_set.path_id)
        )
    return clauses


def compile_inheritance_conflict_checks(
    stmt: irast.MutatingStmt,
    subject_stype: s_objtypes.ObjectType,
    *, ctx: context.ContextLevel,
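) -> Optional[list[irast.OnConflictClause]]:
    """Compile conflict checks against other DML on related types

    Collect the non-DELETE DML statements in `ctx.env.dml_stmts` (plus
    `stmt` itself, if it is an UPDATE) and, for each one whose type
    shares an ancestor other than BaseObject with the subject type,
    synthesize conflict selects for that ancestor. An explicit write to
    `id` also gets a check against BaseObject.

    Returns None if no checks are needed.
    """
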
    has_id_write = _has_explicit_id_write(stmt)

    relevant_dml = [
        dml for dml in ctx.env.dml_stmts
        if not isinstance(dml, irast.DeleteStmt)
    ]
    # Updates can conflict with themselves
    if isinstance(stmt, irast.UpdateStmt):
        relevant_dml.append(stmt)

    if not relevant_dml and not has_id_write:
        return None

    assert isinstance(subject_stype, s_objtypes.ObjectType)
    modified_ancestors = set()
    base_object = ctx.env.schema.get(
        'std::BaseObject', type=s_objtypes.ObjectType)

    subject_stype = subject_stype.get_nearest_non_derived_parent(
        ctx.env.schema)
    subject_stype = schemactx.concretify(subject_stype, ctx=ctx)
    # For updates, we need to also consider all descendants, because
    # those could also have interesting constraints of their own.
    if isinstance(stmt, irast.UpdateStmt):
        subject_stypes = list(
            schemactx.get_all_concrete(subject_stype, ctx=ctx))
    else:
        subject_stypes = [subject_stype]

    for ir in relevant_dml:
        # N.B. that for updates, the update itself will be in relevant_dml,
        # since an update can conflict with itself if there are subtypes.
        # If there aren't subtypes, though, skip it.
        if ir is stmt and len(subject_stypes) == 1:
            continue

        typ = setgen.get_set_type(ir.subject, ctx=ctx)
        assert isinstance(typ, s_objtypes.ObjectType)
        typ = schemactx.concretify(typ, ctx=ctx)

        # As mentioned above, need to consider descendants of updates
        if isinstance(ir, irast.UpdateStmt):
            typs = list(schemactx.get_all_concrete(typ, ctx=ctx))
        else:
            typs = [typ]

        for typ in typs:
            for subject_stype in subject_stypes:
                # If the earlier DML has a shared ancestor that isn't
                # BaseObject and isn't the same type, then we need to
                # see if we need a conflict select.
                #
                # Note that two DMLs on the same type *can* require a
                # conflict select if at least one of them is an UPDATE
                # and there are children, but that is accounted for by
                # the above loops over all descendants when ir is an
                # UPDATE.
                if subject_stype == typ:
                    continue

                ancs = s_utils.get_class_nearest_common_ancestors(
                    ctx.env.schema, [subject_stype, typ])
                for anc in ancs:
                    if anc != base_object:
                        modified_ancestors.add((subject_stype, anc, ir))

    # If `id` is explicitly written to, synthesize a check against
    # BaseObject to ensure that it doesn't conflict with anything,
    # since we disable the trigger for id's exclusive constraint for
    # performance reasons.
    if has_id_write:
        modified_ancestors.add((subject_stype, base_object, stmt))

    conflicters = []
    for subject_stype, anc_type, ir in modified_ancestors:

        # Don't enforce any constraints for abstract object types
        if subject_stype.get_abstract(schema=ctx.env.schema):
            continue

        conflicters.extend(
            _compile_inheritance_conflict_selects(
                stmt, ir, anc_type, subject_stype, ctx=ctx
            )
        )

    return conflicters or None


def check_for_isolation_conflicts(
    stmt: irast.MutatingStmt,
    typ: s_objtypes.ObjectType,
    update_typ: Optional[s_objtypes.ObjectType] = None,
    *, ctx: context.ContextLevel,
) -> None:
    """Check for conflicts on a DML stmt that cause isolation dangers.

    Cross-table exclusive constraints are implemented by triggers that
    read the other tables looking for conflicting rows. This works
    fine in SERIALIZABLE mode, but in REPEATABLE READ mode, this can
    miss two concurrent transactions creating conflicting objects.

    Analyze the type involved in `stmt` to see if there are isolation
    dangers, and log them if so.  These will be reported to the client
    and will generate an error if the query is executed in REPEATABLE
    READ mode.

    This function is called for every subtype in an UPDATE.  In that
    case, `typ` is the subtype and `update_typ` is the base type being
    UPDATEd.
    """

    schema = ctx.env.schema

    entries = _get_type_conflict_constraint_entries(stmt, typ, ctx=ctx)
    constrs = [cnstr for cnstr, _ in entries]

    op = 'INSERT' if isinstance(stmt, irast.InsertStmt) else 'UPDATE'
    base_msg = f'{op} to {typ.get_verbosename(schema)} '

    for constr in constrs:
        subject = constr.get_subject(schema)
        assert subject

        # Find the origin type; if we are the only origin type and we
        # don't have children, then we are good.
        all_objs = []
        match constr.get_constraint_origins(schema):
            case []:
                continue
            case [root_constr, *_]:
                root_subject = root_constr.get_subject(schema)
                if isinstance(root_subject, s_pointers.Pointer):
                    root_subject_obj = root_subject.get_source(schema)
                else:
                    root_subject_obj = root_subject

                if isinstance(root_subject_obj, s_objtypes.ObjectType):
                    all_objs = list(
                        schemactx.get_all_concrete(root_subject_obj, ctx=ctx)
                    )
                    if root_subject_obj == typ and len(all_objs) == 1:
                        continue
                    if root_subject_obj == typ and constr.get_delegated(schema):
                        continue
                    # If this is an UPDATE and we are processing some
                    # subtype, and the actual type being updated is
                    # covered by this same constraint, don't report it for
                    # the child: it will be reported for an ancestor,
                    # which is less noisy.
                    if (
                        update_typ
                        and update_typ != typ
                        and update_typ in all_objs
                    ):
                        continue

        subj_vn = subject.get_verbosename(schema, with_parent=True)
        vn = f'{base_msg}affects an exclusive constraint on {subj_vn}'
        if expr := constr.get_subjectexpr(schema):
            vn += f" with expression '{expr.text}'"

        if not root_subject_obj:
            msg = (
                f"{vn} that is defined on "
                f"{root_subject.get_verbosename(schema)}"
            )
        elif root_subject_obj != typ:
            msg = (
                f"{vn} that is defined in ancestor "
                f"{root_subject_obj.get_verbosename(schema)}"
            )
        else:
            all_objs_s = ', '.join(sorted(
                f"'{o.get_displayname(schema)}'" for o in all_objs if o != typ
            ))
            msg = (
                f"{vn} that is shared with "
                f"descendant types: {all_objs_s}"
            )

        ctx.log_repeatable_read_danger(
            errors.UnsafeIsolationLevelError(msg, span=stmt.span)
        )


================================================
FILE: edb/edgeql/compiler/context.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""EdgeQL to IR compiler context."""

from __future__ import annotations
from typing import (
    Any,
    Callable,
    Literal,
    Optional,
    Mapping,
    MutableMapping,
    Sequence,
    ChainMap,
    NamedTuple,
    cast,
    overload,
    TYPE_CHECKING,
)

import collections
import dataclasses
import enum
import uuid
import weakref

from edb.common import compiler
from edb.common import ordered
from edb.common import parsing

from edb import errors

from edb.edgeql import ast as qlast
from edb.edgeql import qltypes

from edb.ir import ast as irast
from edb.ir import utils as irutils
from edb.ir import typeutils as irtyputils

from edb.schema import expraliases as s_aliases
from edb.schema import name as s_name
from edb.schema import objects as s_obj
from edb.schema import permissions as s_permissions
from edb.schema import pointers as s_pointers
from edb.schema import schema as s_schema
from edb.schema import types as s_types

from .options import GlobalCompilerOptions

if TYPE_CHECKING:
    from edb.schema import objtypes as s_objtypes
    from edb.schema import sources as s_sources


class Exposure(enum.IntEnum):
    UNEXPOSED = 0
    BINDING = 1
    EXPOSED = 2

    def __bool__(self) -> bool:
        return self == Exposure.EXPOSED


class ContextSwitchMode(enum.Enum):
    NEW = enum.auto()
    SUBQUERY = enum.auto()
    NEWSCOPE = enum.auto()
    NEWFENCE = enum.auto()
    DETACHED = enum.auto()


@dataclasses.dataclass(kw_only=True)
class ViewRPtr:
    source: s_sources.Source
    ptrcls: Optional[s_pointers.Pointer]
    ptrcls_name: Optional[s_name.QualName] = None
    base_ptrcls: Optional[s_pointers.Pointer] = None
    ptrcls_is_linkprop: bool = False
    ptrcls_is_alias: bool = False
    exprtype: s_types.ExprType = s_types.ExprType.Select
    rptr_dir: Optional[s_pointers.PointerDirection] = None


@dataclasses.dataclass(kw_only=True, frozen=True)
class SecurityContext:
    # N.B: Whether we are compiling a trigger is not included here
    # since we clear cached rewrites when compiling them in the
    # *pgsql* compiler.
    suppress_policies: bool

    def toggle_policies(self) -> SecurityContext:
        return dataclasses.replace(
            self, suppress_policies=not self.suppress_policies
        )


@dataclasses.dataclass
class ScopeInfo:
    path_scope: irast.ScopeTreeNode
    binding_kind: Optional[irast.BindingKind]
    pinned_path_id_ns: Optional[frozenset[str]] = None


class PointerRefCache(dict[irtyputils.PtrRefCacheKey, irast.BasePointerRef]):

    _rcache: dict[irast.BasePointerRef, s_pointers.PointerLike]

    def __init__(self) -> None:
        super().__init__()
        self._rcache = {}

    def __setitem__(
        self,
        key: irtyputils.PtrRefCacheKey,
        val: irast.BasePointerRef,
    ) -> None:
        super().__setitem__(key, val)
        self._rcache[val] = key

    def get_ptrcls_for_ref(
        self,
        ref: irast.BasePointerRef,
    ) -> Optional[s_pointers.PointerLike]:
        return self._rcache.get(ref)


# Volatility inference computes two volatility results:
# A basic one, and one for consumption by materialization
InferredVolatility = (
    qltypes.Volatility
    | tuple[qltypes.Volatility, qltypes.Volatility]
)


@dataclasses.dataclass(frozen=True, kw_only=True)
class ServerParamConversion:
    path_id: irast.PathId
    ir_param: irast.Param
    additional_info: tuple[str, ...]

    volatility: qltypes.Volatility

    # If the parameter is a query parameter, track its script params index.
    script_param_index: Optional[int] = None

    # If the parameter is a constant value, pass it directly to the server.
    constant_value: Optional[Any] = None


class Environment:
    """Compilation environment."""

    schema: s_schema.Schema
    """A Schema instance to use for class resolution."""

    orig_schema: s_schema.Schema
    """A Schema as it was at the start of the compilation."""

    options: GlobalCompilerOptions
    """Compiler options."""

    path_scope: irast.ScopeTreeNode
    """Overall expression path scope tree."""

    schema_view_cache: dict[
        tuple[s_types.Type, object],
        tuple[s_types.Type, irast.Set],
    ]
    """Type cache used by schema-level views."""

    query_parameters: dict[str, irast.Param]
    """A mapping of query parameters to their types.  Gets populated during
    the compilation."""

    query_globals: dict[s_name.QualName, irast.Global]
    """A mapping of query globals.  Gets populated during
    the compilation."""
    query_globals_types: dict[s_name.QualName, s_types.Type]
    """Injected dummy types for caching globals when the input
    encoding is JSON"""

    required_permissions: set[s_permissions.Permission]
    """Permissions *required* to run this query."""

    server_param_conversions: dict[
        str,
        dict[str, ServerParamConversion],
    ]
    """A mapping of query parameters and the server param conversions which are
    needed by the query.

    This indicates that the server will compute and provide an additional
    parameter based on a user provided parameter.

    Used by ext::ai::search to get embeddings from text before running a query.
    """

    server_param_conversion_calls: list[tuple[str, Optional[parsing.Span]]]
    """Used to generate errors related to server param conversions."""

    set_types: dict[irast.Set, s_types.Type]
    """A dictionary of all Set instances and their schema types."""

    type_origins: dict[s_types.Type, Optional[parsing.Span]]
    """A dictionary of notable types and their source origins.

    This is used to trace where a particular type instance originated in
    order to provide useful diagnostics for type errors.
    """

    inferred_volatility: dict[
        irast.Base,
        InferredVolatility]
    """A dictionary of expressions and their inferred volatility."""

    view_shapes: dict[
        s_types.Type | s_pointers.PointerLike,
        list[tuple[s_pointers.Pointer, qlast.ShapeOp]]
    ]
    """Object output or modification shapes."""

    pointer_derivation_map: dict[
        s_pointers.Pointer,
        list[s_pointers.Pointer],
    ]
    """A parent: children mapping of derived pointer classes."""

    pointer_specified_info: dict[
        s_pointers.Pointer,
        tuple[
            Optional[qltypes.SchemaCardinality],
            Optional[bool],
            Optional[parsing.Span],
        ],
    ]
    """Cardinality/source context for pointers with unclear cardinality."""

    view_shapes_metadata: dict[s_types.Type, irast.ViewShapeMetadata]

    schema_refs: set[s_obj.Object]
    """A set of all schema objects referenced by an expression."""

    schema_ref_exprs: Optional[dict[s_obj.Object, set[qlast.Base]]]
    """Map from all schema objects referenced to the AST referents.

    This is used for rewriting expressions in the schema after a rename.
    """

    # Caches for costly operations in edb.ir.typeutils
    ptr_ref_cache: PointerRefCache
    type_ref_cache: dict[irtyputils.TypeRefCacheKey, irast.TypeRef]

    dml_exprs: list[qlast.Base]
    """A list of DML expressions (statements and DML-containing
    functions) that appear in a function body.
    """

    dml_stmts: list[irast.MutatingStmt]
    """A list of DML statements in the query"""

    #: A list of bindings that should be assumed to be singletons.
    singletons: list[irast.PathId]

    scope_tree_nodes: MutableMapping[int, irast.ScopeTreeNode]
    """Map from unique_id to nodes."""

    materialized_sets: dict[
        s_types.Type | s_pointers.PointerLike,
        tuple[qlast.Statement, Sequence[irast.MaterializeReason]],
    ]
    """A mapping of computed sets that must be computed only once."""

    compiled_stmts: dict[qlast.Statement, irast.Stmt]
    """A mapping from input EdgeQL statements to compiled IR."""

    alias_result_view_name: Optional[s_name.QualName]
    """The name of a view being defined as an alias."""

    script_params: dict[str, irast.Param]
    """All parameter definitions from an enclosing multi-statement script.

    Used to make sure the types are consistent."""

    source_map: dict[s_pointers.PointerLike, irast.ComputableInfo]
    """A mapping of computable pointers to QL source AST and context."""

    type_rewrites: dict[
        tuple[s_types.Type, bool], irast.Set | None | Literal[True]]
    """Access policy rewrites for schema-level types.

    None indicates no rewrite, True indicates a compound type
    that had rewrites in its components.
    """

    expr_view_cache: dict[tuple[qlast.Base, s_name.Name], irast.Set]
    """Type cache used by expression-level views."""

    shape_type_cache: dict[
        tuple[
            s_objtypes.ObjectType,
            s_types.ExprType,
            tuple[qlast.ShapeElement, ...],
        ],
        s_objtypes.ObjectType,
    ]
    """Type cache for shape expressions."""

    path_scope_map: dict[irast.Set, ScopeInfo]
    """A dictionary of scope info that is appropriate for a given view."""

    dml_rewrites: dict[irast.Set, irast.Rewrites]
    """Compiled rewrites that should be attached to InsertStmt or UpdateStmt"""

    warnings: list[errors.EdgeDBError]
    """List of warnings to emit"""

    unsafe_isolation_dangers: list[errors.UnsafeIsolationLevelError]
    """List of repeatable read DML dangers"""

    policy_use_count: int
    """A count of the number of times that policies have been referenced.

    Can be used to detect if a sub-compilation referenced a policy.
    """

    def __init__(
        self,
        *,
        schema: s_schema.Schema,
        path_scope: Optional[irast.ScopeTreeNode] = None,
        alias_result_view_name: Optional[s_name.QualName] = None,
        options: Optional[GlobalCompilerOptions] = None,
    ) -> None:
        if options is None:
            options = GlobalCompilerOptions()

        if path_scope is None:
            path_scope = irast.new_scope_tree()

        self.options = options
        self.schema = schema
        self.orig_schema = schema
        self.path_scope = path_scope
        self.schema_view_cache = {}
        self.query_parameters = {}
        self.query_globals = {}
        self.query_globals_types = {}
        self.required_permissions = set()
        self.server_param_conversions = {}
        self.server_param_conversion_calls = []
        self.set_types = {}
        self.type_origins = {}
        self.inferred_volatility = {}
        self.view_shapes = collections.defaultdict(list)
        self.view_shapes_metadata = collections.defaultdict(
            irast.ViewShapeMetadata)
        self.schema_refs = set()
        self.schema_ref_exprs = {} if options.track_schema_ref_exprs else None
        self.ptr_ref_cache = PointerRefCache()
        self.type_ref_cache = {}
        self.dml_exprs = []
        self.dml_stmts = []
        self.pointer_derivation_map = collections.defaultdict(list)
        self.pointer_specified_info = {}
        self.singletons = []
        self.scope_tree_nodes = weakref.WeakValueDictionary()
        self.materialized_sets = {}
        self.compiled_stmts = {}
        self.alias_result_view_name = alias_result_view_name
        self.script_params = {}
        self.source_map = {}
        self.type_rewrites = {}
        self.shape_type_cache = {}
        self.expr_view_cache = {}
        self.path_scope_map = {}
        self.dml_rewrites = {}
        self.warnings = []
        self.unsafe_isolation_dangers = []
        self.policy_use_count = 0

    def add_schema_ref(
        self, sobj: s_obj.Object, expr: Optional[qlast.Base]
    ) -> None:
        self.schema_refs.add(sobj)
        if self.schema_ref_exprs is not None and expr:
            self.schema_ref_exprs.setdefault(sobj, set()).add(expr)

    @overload
    def get_schema_object_and_track(
        self,
        name: s_name.Name,
        expr: Optional[qlast.Base],
        *,
        modaliases: Optional[Mapping[Optional[str], str]] = None,
        type: Optional[type[s_obj.Object]] = None,
        default: s_obj.Object | s_obj.NoDefaultT = s_obj.NoDefault,
        label: Optional[str] = None,
        condition: Optional[Callable[[s_obj.Object], bool]] = None,
    ) -> s_obj.Object:
        ...

    @overload
    def get_schema_object_and_track(
        self,
        name: s_name.Name,
        expr: Optional[qlast.Base],
        *,
        modaliases: Optional[Mapping[Optional[str], str]] = None,
        type: Optional[type[s_obj.Object]] = None,
        default: s_obj.Object | s_obj.NoDefaultT | None = s_obj.NoDefault,
        label: Optional[str] = None,
        condition: Optional[Callable[[s_obj.Object], bool]] = None,
    ) -> Optional[s_obj.Object]:
        ...

    def get_schema_object_and_track(
        self,
        name: s_name.Name,
        expr: Optional[qlast.Base],
        *,
        modaliases: Optional[Mapping[Optional[str], str]] = None,
        type: Optional[type[s_obj.Object]] = None,
        default: s_obj.Object | s_obj.NoDefaultT | None = s_obj.NoDefault,
        label: Optional[str] = None,
        condition: Optional[Callable[[s_obj.Object], bool]] = None,
    ) -> Optional[s_obj.Object]:
        sobj = self.schema.get(
            name, module_aliases=modaliases, type=type,
            condition=condition, label=label,
            default=default)
        if sobj is not None and sobj is not default:
            self.add_schema_ref(sobj, expr)

            if (
                isinstance(sobj, s_types.Type)
                and sobj.get_expr(self.schema) is not None
            ):
                # If the type is derived from an ALIAS declaration,
                # make sure we record the reference to the Alias object
                # as well for correct delta ordering.
                alias_objs = self.schema.get_referrers(
                    sobj,
                    scls_type=s_aliases.Alias,
                    field_name='type',
                )
                for obj in alias_objs:
                    self.add_schema_ref(obj, expr)

        return sobj

    def get_schema_type_and_track(
        self,
        name: s_name.Name,
        expr: Optional[qlast.Base]=None,
        *,
        modaliases: Optional[Mapping[Optional[str], str]] = None,
        default: None | s_obj.Object | s_obj.NoDefaultT = s_obj.NoDefault,
        label: Optional[str]=None,
        condition: Optional[Callable[[s_obj.Object], bool]]=None,
    ) -> s_types.Type:

        stype = self.get_schema_object_and_track(
            name, expr, modaliases=modaliases, default=default, label=label,
            condition=condition, type=s_types.Type,
        )

        return cast(s_types.Type, stype)


class ContextLevel(compiler.ContextLevel):

    env: Environment
    """Compilation environment common for all context levels."""

    derived_target_module: Optional[str]
    """The name of the module for classes derived by views."""

    anchors: dict[
        str | type[qlast.SpecialAnchor],
        irast.Set,
    ]
    """A mapping of anchor variables (aliases to path expressions passed
    to the compiler programmatically).
    """

    modaliases: dict[Optional[str], str]
    """A combined list of module name aliases declared in the WITH block,
    or passed to the compiler programmatically.
    """

    view_nodes: dict[s_name.Name, s_types.Type]
    """A dictionary of newly derived Node classes representing views."""

    view_sets: dict[s_obj.Object, irast.Set]
    """A dictionary of IR expressions for views declared in the query."""

    suppress_rewrites: frozenset[s_types.Type]
    """Types to suppress using rewrites on"""

    aliased_views: ChainMap[s_name.Name, irast.Set]
    """A dictionary of views aliased in a statement body."""

    class_view_overrides: dict[uuid.UUID, s_types.Type]
    """Object mapping used by implicit view override in SELECT."""

    clause: Optional[str]
    """Statement clause the compiler is currently in."""

    toplevel_stmt: Optional[irast.Stmt]
    """Top-level statement."""

    stmt: Optional[irast.Stmt]
    """Statement node currently being built."""

    qlstmt: Optional[qlast.Statement]
    """Statement source node currently being built."""

    path_id_namespace: frozenset[str]
    """A namespace to use for all path ids."""

    pending_stmt_own_path_id_namespace: frozenset[str]
    """A path id namespace to add to the fence of the next statement."""

    pending_stmt_full_path_id_namespace: frozenset[str]
    """A set of path id namespaces to use in path ids in the next statement."""

    inserting_paths: dict[irast.PathId, Literal['body'] | Literal['else']]
    """A set of path ids that are currently being inserted."""

    view_map: ChainMap[
        irast.PathId,
        tuple[tuple[irast.PathId, irast.Set], ...],
    ]
    """Set translation map.  Used for mapping computable sources.

    When compiling a computable, we need to be able to map references to
    the source back to the correct source set.

    This maps from a namespace-stripped source path_id to the expected
    computable-internal path_id and the actual source set.

    The namespace stripping is necessary to handle the case where
    bindings have added more namespaces to the source set reference.
    (See test_edgeql_scope_computables_13.)
    """

    path_scope: irast.ScopeTreeNode
    """Path scope tree, with per-lexical-scope levels."""

    iterator_ctx: Optional[ContextLevel]
    """The context of the statement where all iterators should be placed."""

    iterator_path_ids: frozenset[irast.PathId]
    """The path ids of all in scope iterator variables"""

    scope_id_ctr: compiler.SimpleCounter
    """Path scope id counter."""

    view_rptr: Optional[ViewRPtr]
    """Pointer information for the top-level view of the substatement."""

    view_scls: Optional[s_types.Type]
    """Schema class for the top-level set of the substatement."""

    toplevel_result_view_name: Optional[s_name.QualName]
    """The name to use for the view that is the result of the top statement."""

    partial_path_prefix: Optional[irast.Set]
    """The set used as a prefix for partial paths."""

    implicit_id_in_shapes: bool
    """Whether to include the id property in object shapes implicitly."""

    implicit_tid_in_shapes: bool
    """Whether to include the type id property in object shapes implicitly."""

    implicit_tname_in_shapes: bool
    """Whether to include the type name property in object shapes
       implicitly."""

    implicit_limit: int
    """Implicit LIMIT clause in SELECT statements."""

    special_computables_in_mutation_shape: frozenset[str]
    """A set of "special" computable pointers allowed in mutation shape."""

    empty_result_type_hint: Optional[s_types.Type]
    """Type to use if the statement result expression is an empty set ctor."""

    defining_view: Optional[s_objtypes.ObjectType]
    """Whether a view is currently being defined (as opposed to being
    compiled)."""

    current_schema_views: tuple[s_types.Type, ...]
    """Which schema views are currently being compiled"""

    recompiling_schema_alias: bool
    """Whether we are currently recompiling a schema-level expression alias."""

    compiling_update_shape: bool
    """Whether an UPDATE shape is currently being compiled."""

    active_computeds: ordered.OrderedSet[s_pointers.Pointer]
    """An ordered set of the computeds currently being compiled."""

    allow_endpoint_linkprops: bool
    """Whether to allow references to endpoint linkprops (@source, @target)."""

    disallow_dml: Optional[str]
    """Set when we are currently in a place where no DML is allowed.

    If not None, it describes that place, e.g. `in a FILTER clause`.
    """

    active_rewrites: frozenset[s_objtypes.ObjectType]
    """For detecting cycles in rewrite rules"""

    active_defaults: frozenset[s_objtypes.ObjectType]
    """For detecting cycles in defaults"""

    collection_cast_info: Optional[CollectionCastInfo]
    """For generating error messages when casting to collections.

    This will be set by the outermost cast and then shared between all
    sub-casts.

    Some casts (eg. arrays) will generate select statements containing other
    type casts. These will also share the outermost cast info.
    """

    no_factoring: bool

    def __init__(
        self,
        prevlevel: Optional[ContextLevel],
        mode: ContextSwitchMode,
        *,
        env: Optional[Environment] = None,
    ) -> None:

        self.mode = mode

        if prevlevel is None:
            assert env is not None
            self.env = env
            self.derived_target_module = None
            self.aliases = compiler.AliasGenerator()
            self.anchors = {}
            self.modaliases = {}

            self.view_nodes = {}
            self.view_sets = {}
            self.suppress_rewrites = frozenset()
            self.aliased_views = collections.ChainMap()
            self.class_view_overrides = {}

            self.toplevel_stmt = None
            self.stmt = None
            self.qlstmt = None
            self.path_id_namespace = frozenset()
            self.pending_stmt_own_path_id_namespace = frozenset()
            self.pending_stmt_full_path_id_namespace = frozenset()
            self.inserting_paths = {}
            self.view_map = collections.ChainMap()
            self.path_scope = env.path_scope
            self.iterator_path_ids = frozenset()
            self.scope_id_ctr = compiler.SimpleCounter()
            self.view_scls = None
            self.expr_exposed = Exposure.UNEXPOSED

            self.partial_path_prefix = None

            self.view_rptr = None
            self.toplevel_result_view_name = None
            self.implicit_id_in_shapes = False
            self.implicit_tid_in_shapes = False
            self.implicit_tname_in_shapes = False
            self.implicit_limit = 0
            self.special_computables_in_mutation_shape = frozenset()
            self.empty_result_type_hint = None
            self.defining_view = None
            self.current_schema_views = ()
            self.compiling_update_shape = False
            self.active_computeds = ordered.OrderedSet()
            self.recompiling_schema_alias = False
            self.active_rewrites = frozenset()
            self.active_defaults = frozenset()

            self.allow_endpoint_linkprops = False
            self.disallow_dml = None
            self.no_factoring = False

            self.collection_cast_info = None

        else:
            self.env = prevlevel.env
            self.derived_target_module = prevlevel.derived_target_module
            self.aliases = prevlevel.aliases

            self.view_nodes = prevlevel.view_nodes
            self.view_sets = prevlevel.view_sets
            self.suppress_rewrites = prevlevel.suppress_rewrites

            self.iterator_path_ids = prevlevel.iterator_path_ids
            self.path_id_namespace = prevlevel.path_id_namespace
            self.pending_stmt_own_path_id_namespace = \
                prevlevel.pending_stmt_own_path_id_namespace
            self.pending_stmt_full_path_id_namespace = \
                prevlevel.pending_stmt_full_path_id_namespace
            self.inserting_paths = prevlevel.inserting_paths
            self.view_map = prevlevel.view_map
            if prevlevel.path_scope is None:
                prevlevel.path_scope = self.env.path_scope
            self.path_scope = prevlevel.path_scope
            self.scope_id_ctr = prevlevel.scope_id_ctr
            self.view_scls = prevlevel.view_scls
            self.expr_exposed = prevlevel.expr_exposed
            self.partial_path_prefix = prevlevel.partial_path_prefix
            self.toplevel_stmt = prevlevel.toplevel_stmt
            self.implicit_id_in_shapes = prevlevel.implicit_id_in_shapes
            self.implicit_tid_in_shapes = prevlevel.implicit_tid_in_shapes
            self.implicit_tname_in_shapes = prevlevel.implicit_tname_in_shapes
            self.implicit_limit = prevlevel.implicit_limit
            self.special_computables_in_mutation_shape = \
                prevlevel.special_computables_in_mutation_shape
            self.empty_result_type_hint = prevlevel.empty_result_type_hint
            self.defining_view = prevlevel.defining_view
            self.current_schema_views = prevlevel.current_schema_views
            self.compiling_update_shape = prevlevel.compiling_update_shape
            self.active_computeds = prevlevel.active_computeds
            self.recompiling_schema_alias = prevlevel.recompiling_schema_alias
            self.active_rewrites = prevlevel.active_rewrites
            self.active_defaults = prevlevel.active_defaults

            self.allow_endpoint_linkprops = prevlevel.allow_endpoint_linkprops
            self.disallow_dml = prevlevel.disallow_dml
            self.no_factoring = prevlevel.no_factoring

            self.collection_cast_info = prevlevel.collection_cast_info

            if mode == ContextSwitchMode.SUBQUERY:
                self.anchors = prevlevel.anchors.copy()
                self.modaliases = prevlevel.modaliases.copy()
                self.aliased_views = prevlevel.aliased_views.new_child()
                self.class_view_overrides = \
                    prevlevel.class_view_overrides.copy()

                self.pending_stmt_own_path_id_namespace = frozenset()
                self.pending_stmt_full_path_id_namespace = frozenset()
                self.inserting_paths = prevlevel.inserting_paths.copy()

                self.view_rptr = None
                self.view_scls = None
                self.stmt = None
                self.qlstmt = None

                self.toplevel_result_view_name = None

            elif mode == ContextSwitchMode.DETACHED:
                self.anchors = prevlevel.anchors.copy()
                self.modaliases = prevlevel.modaliases.copy()
                self.aliased_views = collections.ChainMap()
                self.view_map = collections.ChainMap()
                self.class_view_overrides = {}
                self.expr_exposed = prevlevel.expr_exposed

                self.view_nodes = {}
                self.view_sets = {}
                self.path_id_namespace = frozenset({self.aliases.get('ns')})
                self.pending_stmt_own_path_id_namespace = frozenset()
                self.pending_stmt_full_path_id_namespace = frozenset()
                self.inserting_paths = {}

                self.view_rptr = None
                self.view_scls = None
                self.stmt = prevlevel.stmt
                self.qlstmt = prevlevel.qlstmt

                self.partial_path_prefix = None

                self.toplevel_result_view_name = None
            else:
                self.anchors = prevlevel.anchors
                self.modaliases = prevlevel.modaliases
                self.aliased_views = prevlevel.aliased_views
                self.class_view_overrides = prevlevel.class_view_overrides

                self.stmt = prevlevel.stmt
                self.qlstmt = prevlevel.qlstmt

                self.view_rptr = prevlevel.view_rptr
                self.toplevel_result_view_name = \
                    prevlevel.toplevel_result_view_name

            if mode == ContextSwitchMode.NEWFENCE:
                self.path_scope = self.path_scope.attach_fence()

            if mode == ContextSwitchMode.NEWSCOPE:
                self.path_scope = self.path_scope.attach_branch()

    def subquery(self) -> compiler.CompilerContextManager[ContextLevel]:
        return self.new(ContextSwitchMode.SUBQUERY)

    def newscope(
        self,
        *,
        fenced: bool,
    ) -> compiler.CompilerContextManager[ContextLevel]:
        if fenced:
            mode = ContextSwitchMode.NEWFENCE
        else:
            mode = ContextSwitchMode.NEWSCOPE

        return self.new(mode)

    def detached(self) -> compiler.CompilerContextManager[ContextLevel]:
        return self.new(ContextSwitchMode.DETACHED)

    def create_anchor(
        self,
        ir: irast.Set,
        name: str = 'v', *,
        check_dml: bool = False,
        move_scope: bool = False,
    ) -> qlast.Path:
        alias = self.aliases.get(name)
        # TODO: We should probably always check for DML, but I'm
        # concerned about perf, since we don't cache it at all.
        has_dml = check_dml and irutils.contains_dml(ir)
        self.anchors[alias] = ir
        if move_scope:
            assert ir.path_scope_id is not None
        return qlast.Path(
            steps=[qlast.IRAnchor(
                name=alias, has_dml=has_dml, move_scope=move_scope
            )],
        )

    def maybe_create_anchor(
        self,
        ir: irast.Set | qlast.Expr,
        name: str = 'v',
    ) -> qlast.Expr:
        if isinstance(ir, irast.Set):
            return self.create_anchor(ir, name)
        else:
            return ir

    def get_security_context(self) -> SecurityContext:
        '''Compute an additional compilation cache key.

        Return an additional key for any compilation caches that may
        vary based on "security contexts" such as whether we are in an
        access policy.
        '''
        # N.B: Whether we are compiling a trigger is not included here
        # since we clear cached rewrites when compiling them in the
        # *pgsql* compiler.
        return SecurityContext(
            suppress_policies=bool(self.suppress_rewrites),
        )

    def log_warning(self, warning: errors.EdgeDBError) -> None:
        self.env.warnings.append(warning)

    def log_repeatable_read_danger(
        self, d: errors.UnsafeIsolationLevelError
    ) -> None:
        self.env.unsafe_isolation_dangers.append(d)

    def allow_factoring(self) -> None:
        self.no_factoring = False


class CompilerContext(compiler.CompilerContext[ContextLevel]):
    ContextLevelClass = ContextLevel
    default_mode = ContextSwitchMode.NEW


class CollectionCastInfo(NamedTuple):
    """For generating error messages when casting to collections."""

    from_type: s_types.Type
    to_type: s_types.Type

    path_elements: list[tuple[str, Optional[str]]]
    """Represents a path to the current collection element being cast.

    A path element is a tuple of the collection type and an optional
    element name. eg. ('tuple', 'a') or ('array', None)

    The list is shared between the outermost context and all its sub contexts.
    When casting a collection, each element's path should be pushed before
    entering the "sub-cast" and popped immediately after.

    In the event of a cast error, the list is preserved at the outermost cast.
    """


================================================
FILE: edb/edgeql/compiler/dispatch.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations


import functools

from edb.edgeql import ast as qlast
from edb.ir import ast as irast

from . import context


@functools.singledispatch
def compile(node: qlast.Base, *, ctx: context.ContextLevel) -> irast.Set:
    raise NotImplementedError(
        f'no EdgeQL compiler handler for {node.__class__}')


================================================
FILE: edb/edgeql/compiler/eta_expand.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""η-expansion of tuples and arrays.

Our shape compiler is only able to produce shape outputs for objects
at places in the program that get fairly directly routed into the
output. To compensate for this, when we have an expression like
`[User {name}][0]`, the shape output is actually computed *after* the
array indexing. This works well, but fails when the thing we need to
do late shape injection into is a collection type that cannot have a
shape put on it directly: `[(User {name}, 20)][0]`.
To solve this problem, we use η-expansion.

η-expansion is a technique coming from the study of the lambda
calculus, where it means to expand an expression `e` into `λx.ex`,
where `x` is a variable that does not appear in `e` (or, in Python-speak
`lambda x: e(x)`). Setting aside questions about what happens if `e` does
not terminate, this new expression `λx.ex` will be equivalent to `e`.
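
As a small illustration in Python (a sketch only; `eta_expand` and
`double` are names invented for this example), wrapping a function in
a fresh lambda leaves its results unchanged:

```python
from typing import Callable


def eta_expand(e: Callable[[int], int]) -> Callable[[int], int]:
    # Wrap `e` in a lambda whose parameter `x` does not occur in `e`.
    return lambda x: e(x)


def double(n: int) -> int:
    return n * 2


expanded = eta_expand(double)
# Pair up the results of the original and the expanded function.
results = [(double(n), expanded(n)) for n in range(5)]
```

Every pair in `results` holds two equal values, which is the sense in
which the expanded form is equivalent to the original.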

In the traditional untyped lambda calculus, where everything is a
function (from functions to functions), this is the whole story.
But the world of *typed* lambda calculi introduces some interesting
new considerations:
  1. Instead of it being possible to η-expand *any* expression,
     now only expressions actually of function type may be expanded.
     This allows us to define a notion of an expression being "η-long",
     which means that it is "fully η-expanded" (and that any new expansion
     would create a reducible expression, where a lambda appears directly
     on the LHS of an application).
  2. If other types are introduced, we can define notions of η-expansion
     for them as well. The key idea is that the expanded expression should
     explicitly construct an object of the desired type.

     For a pair type, for example, we would expand `e` into `(e[0], e[1])`.
     This also can be done to produce an "η-long" form: for example,
     if we have the type `Tuple[int, Tuple[int, int]]`, we would expand
     that into `(e[0], (e[1][0], e[1][1]))`.
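
A minimal Python sketch of this type-directed expansion follows. It
assumes a toy type representation, invented for the illustration, in
which a tuple type is a Python tuple of its element types and anything
else is a leaf; `eta_long` is likewise a hypothetical helper. It
returns the source text of the η-long form:

```python
def eta_long(ty: object, expr: str) -> str:
    # A tuple type is modelled as a Python tuple of its element types.
    if isinstance(ty, tuple):
        parts = [
            eta_long(sub, f'{expr}[{i}]') for i, sub in enumerate(ty)
        ]
        return '(' + ', '.join(parts) + ')'
    # Leaf (non-tuple) types need no expansion.
    return expr


# The nested pair type from the text above.
nested = eta_long((int, (int, int)), 'e')
```

Here `nested` is the string `(e[0], (e[1][0], e[1][1]))`: every tuple
in the type is rebuilt by an explicit constructor.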

This key idea, of expanding a term into one that explicitly constructs
an object of the desired type, is exactly what we need to ensure that
we can inject shapes into the output.

As a set-based query language, we also need to do some extra work to
preserve the ordering of elements.

Our core rules are:
    EXPAND_ORDERED(t, e) =
        WITH enum := enumerate(e)
        SELECT EXPAND(t, enum.1) ORDER BY enum.0

    EXPAND(tuple<t, s>, p) = (EXPAND(t, p.0), EXPAND(s, p.1))

    EXPAND(array<t>, p) =
        (p, array_agg(EXPAND_ORDERED(t, array_unpack(p)))).1

    EXPAND(non_collection_type, p) = p

They are discussed in more detail at the implementation sites.
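
As a rough analogy, the rules can be mimicked on plain Python values.
This is not the compiler's implementation: sets and arrays are both
modelled as lists, tuples as Python tuples, and the toy type encoding
(`('tuple', t, s)`, `('array', t)`, anything else a leaf) as well as
the names `expand` and `expand_ordered` are assumptions of the sketch.

```python
def expand(ty, p):
    if isinstance(ty, tuple) and ty[0] == 'tuple':
        # Tuple rule: rebuild the tuple element by element.
        return tuple(expand(sub, p[i]) for i, sub in enumerate(ty[1:]))
    if isinstance(ty, tuple) and ty[0] == 'array':
        # Array rule: unpack, expand in order, then re-aggregate.
        return expand_ordered(ty[1], p)
    # Non-collection rule: nothing to expand.
    return p


def expand_ordered(ty, e):
    # Number the elements, expand each one, and restore the
    # original order by sorting on the recorded position.
    enum = list(enumerate(e))
    ordered = sorted(enum, key=lambda pair: pair[0])
    return [expand(ty, item) for _, item in ordered]


value = [('alice', 20), ('bob', 30)]
rebuilt = expand(('array', ('tuple', 'str', 'int')), value)
```

`rebuilt` compares equal to `value` and keeps its element order, but
it is a freshly constructed list of freshly constructed tuples, which
is the property the expansion is after.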
"""


from __future__ import annotations


from edb.ir import ast as irast

from edb.schema import name as sn
from edb.schema import types as s_types

from edb.edgeql import ast as qlast

from . import astutils
from . import context
from . import dispatch
from . import setgen


# If true, we disregard the optimizations meant to avoid unnecessary
# expansions. This is useful as a bug-finding tool, since η-expansion
# found lots of bugs, but mostly in test cases that didn't *really*
# need it.
ALWAYS_EXPAND = False


def needs_eta_expansion_expr(
    ir: irast.Expr,
    stype: s_types.Type,
    *,
    ctx: context.ContextLevel,
) -> bool:
    """Determine if an expr is in need of η-expansion

    In general, any expression of an object-containing
    tuple or array type needs expansion unless it is:
        * A tuple literal
        * An empty array literal
        * A one-element array literal
        * A call to array_agg
    in which none of the arguments are sets that need expansion.
    """
    if isinstance(ir, irast.SelectStmt):
        return needs_eta_expansion(
            ir.result, has_clauses=bool(ir.where or ir.orderby), ctx=ctx)

    if isinstance(stype, s_types.Array):
        if isinstance(ir, irast.Array):
            return bool(ir.elements) and (
                len(ir.elements) != 1
                or needs_eta_expansion(ir.elements[0], ctx=ctx)
            )
        elif (
            isinstance(ir, irast.FunctionCall)
            and ir.func_shortname == sn.QualName('std', 'array_agg')
        ):
            return needs_eta_expansion(ir.args[0].expr, ctx=ctx)
        else:
            return True

    elif isinstance(stype, s_types.Tuple):
        if isinstance(ir, irast.Tuple):
            return any(
                needs_eta_expansion(el.val, ctx=ctx) for el in ir.elements
            )
        else:
            return True

    else:
        return False


def needs_eta_expansion(
    ir: irast.Set,
    *,
    has_clauses: bool = False,
    ctx: context.ContextLevel,
) -> bool:
    """Determine if a set is in need of η-expansion"""
    stype = setgen.get_set_type(ir, ctx=ctx)

    if not (
        isinstance(stype, (s_types.Array, s_types.Tuple))
        and stype.contains_object(ctx.env.schema)
    ):
        return False

    if ALWAYS_EXPAND:
        return True

    # Object-containing arrays always need to be η-expanded if they
    # might be processed by a clause. This is because the pgsql side
    # will produce *either* a value or a serialized form for
    # array_agg/array literals, but not both.
    if has_clauses and (
        (subarray := stype.find_array(ctx.env.schema))
        and subarray.contains_object(ctx.env.schema)
    ):
        return True

    # If we are directly projecting an element out of a tuple, we can just
    # look through to the relevant tuple element. This is probably not
    # an important optimization to support, but our expansion can generate
    # this idiom, so on principle I wanted to support it.
    if (
        isinstance(ir.expr, irast.TupleIndirectionPointer)
        and isinstance(ir.expr.source.expr, irast.Tuple)
    ):
        name = ir.expr.ptrref.shortname.name
        els = [x for x in ir.expr.source.expr.elements if x.name == name]
        if len(els) == 1:
            return needs_eta_expansion(els[0].val, ctx=ctx)

    if not ir.expr or (
        ir.is_binding and ir.is_binding != irast.BindingKind.Select
    ):
        return True

    return needs_eta_expansion_expr(ir.expr, stype, ctx=ctx)


def _get_alias(
    name: str, *, ctx: context.ContextLevel
) -> tuple[str, qlast.Path]:
    alias = ctx.aliases.get(name)
    return alias, qlast.Path(
        steps=[qlast.ObjectRef(name=alias)],
    )


def eta_expand_ir(
    ir: irast.Set,
    *,
    toplevel: bool = False,
    ctx: context.ContextLevel,
) -> irast.Set:
    """η-expansion of an IR set.

    Our core implementation of η-expansion operates on an AST,
    so this mostly just checks that we really want to expand
    and then sets up an anchor for the AST based implementation
    to run on.
    """
    if (
        ctx.env.options.schema_object_context
        or ctx.env.options.func_params
        or ctx.env.options.schema_view_mode
    ):
        return ir

    if not needs_eta_expansion(ir, ctx=ctx):
        return ir

    with ctx.new() as subctx:
        subctx.allow_factoring()

        subctx.anchors = subctx.anchors.copy()
        source_ref = subctx.create_anchor(ir)

        alias, path = _get_alias('eta', ctx=subctx)
        qry = qlast.SelectQuery(
            result=eta_expand_ordered(
                path, setgen.get_set_type(ir, ctx=subctx), ctx=subctx
            ),
            aliases=[
                qlast.AliasedExpr(alias=alias, expr=source_ref)
            ],
        )
        if toplevel:
            subctx.toplevel_stmt = None
        return dispatch.compile(qry, ctx=subctx)


def eta_expand_ordered(
    expr: qlast.Expr,
    stype: s_types.Type,
    *,
    ctx: context.ContextLevel,
) -> qlast.Expr:
    """Do an order-preserving η-expansion

    Unlike in the lambda calculus, edgeql is a set-based language
    with a notion of ordering, which we need to preserve.
    We do this by using enumerate and ORDER BY on it:
        EXPAND_ORDERED(t, e) =
            WITH enum := enumerate(e)
            SELECT EXPAND(t, enum.1) ORDER BY enum.0
    """
    enumerated = qlast.FunctionCall(
        func=('__std__', 'enumerate'), args=[expr]
    )

    enumerated_alias, enumerated_path = _get_alias('enum', ctx=ctx)

    element_path = astutils.extend_path(enumerated_path, '1')
    result_expr = eta_expand(element_path, stype, ctx=ctx)

    return qlast.SelectQuery(
        result=result_expr,
        orderby=[
            qlast.SortExpr(path=astutils.extend_path(enumerated_path, '0'))
        ],
        aliases=[
            qlast.AliasedExpr(alias=enumerated_alias, expr=enumerated)
        ],
    )


def eta_expand(
    path: qlast.Path,
    stype: s_types.Type,
    *,
    ctx: context.ContextLevel,
) -> qlast.Expr:
    """η-expansion of an AST path"""
    if not ALWAYS_EXPAND and not stype.contains_object(ctx.env.schema):
        # This isn't strictly right from a "fully η expanding" perspective,
        # but for our uses, we only need to make sure that objects are
        # exposed to the output, so we can skip anything not containing one.
        return path

    if isinstance(stype, s_types.Array):
        return eta_expand_array(path, stype, ctx=ctx)

    elif isinstance(stype, s_types.Tuple):
        return eta_expand_tuple(path, stype, ctx=ctx)

    else:
        return path


def eta_expand_tuple(
    path: qlast.Path,
    stype: s_types.Tuple,
    *,
    ctx: context.ContextLevel,
) -> qlast.Expr:
    """η-expansion of tuples

    η-expansion of tuple types is straightforward and traditional:
        EXPAND(tuple, p) = (EXPAND(t, p.0), EXPAND(s, p.1))
    is the case for pairs. n-ary and named cases are generalized in the
    obvious way.
    The one exception is that the expansion of the empty tuple type is
    `p` and not `()`, to ensure that the path appears in the output.
    """
    if not stype.get_subtypes(ctx.env.schema):
        return path

    els = [
        qlast.TupleElement(
            name=qlast.Ptr(name=name),
            val=eta_expand(astutils.extend_path(path, name), subtype, ctx=ctx),
        )
        for name, subtype in stype.iter_subtypes(ctx.env.schema)
    ]

    if stype.is_named(ctx.env.schema):
        return qlast.NamedTuple(elements=els)
    else:
        return qlast.Tuple(elements=[el.val for el in els])


def eta_expand_array(
    path: qlast.Path,
    stype: s_types.Array,
    *,
    ctx: context.ContextLevel,
) -> qlast.Expr:
    """η-expansion of arrays

    η-expansion of array types is a little peculiar to edgeql and less
    grounded in typed lambda calculi:
        EXPAND(array, p) =
            (p, array_agg(EXPAND_ORDERED(t, array_unpack(p)))).1

    We use a similar approach for compiling casts.

    The tuple projection trick serves to make sure that we iterate over
    `p` *outside* of the array_agg (or else all the arrays would get
    aggregated together) as well as ensuring that `p` appears in the expansion
    in a non-fenced position (or else sorting it from outside wouldn't work).

    (If it wasn't for the latter requirement, we could just use a FOR.
    I find it a little unsatisfying that our η-expansion needs to use this
    trick, and the pgsql compiler needed to be hacked to make it work.)
    """

    unpacked = qlast.FunctionCall(
        func=('__std__', 'array_unpack'), args=[path]
    )

    expanded = eta_expand_ordered(
        unpacked, stype.get_element_type(ctx.env.schema), ctx=ctx)

    agg_expr = qlast.FunctionCall(
        func=('__std__', 'array_agg'), args=[expanded]
    )

    return astutils.extend_path(
        qlast.Tuple(elements=[path, agg_expr]), '1'
    )


================================================
FILE: edb/edgeql/compiler/expr.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""EdgeQL non-statement expression compilation functions."""


from __future__ import annotations

from typing import Callable, Optional, Sequence, cast

from edb import errors

from edb.common import parsing
from edb.common import span as edb_span

from edb.ir import ast as irast
from edb.ir import typeutils as irtyputils
from edb.ir import utils as irutils

from edb.schema import constraints as s_constr
from edb.schema import functions as s_func
from edb.schema import globals as s_globals
from edb.schema import indexes as s_indexes
from edb.schema import name as sn
from edb.schema import objects as so
from edb.schema import objtypes as s_objtypes
from edb.schema import permissions as s_permissions
from edb.schema import pseudo as s_pseudo
from edb.schema import scalars as s_scalars
from edb.schema import schema as s_schema
from edb.schema import types as s_types
from edb.schema import utils as s_utils

from edb.edgeql import ast as qlast

from . import astutils
from . import casts
from . import context
from . import dispatch
from . import pathctx
from . import schemactx
from . import setgen
from . import stmt
from . import tuple_args
from . import typegen

from . import func  # NOQA


@dispatch.compile.register(qlast.OptionalExpr)
def compile__Optional(
    expr: qlast.OptionalExpr, *, ctx: context.ContextLevel
) -> irast.Set:

    result = dispatch.compile(expr.expr, ctx=ctx)

    pathctx.register_set_in_scope(result, optional=True, ctx=ctx)

    return result


@dispatch.compile.register(qlast.Path)
def compile_Path(expr: qlast.Path, *, ctx: context.ContextLevel) -> irast.Set:
    if ctx.no_factoring and not expr.allow_factoring:
        res = dispatch.compile(
            qlast.SelectQuery(
                result=expr.replace(allow_factoring=True),
                implicit=True,
            ),
            ctx=ctx,
        )
        # Mark the nodes as having been protected from factoring. At
        # the end of compilation, we see if we can eliminate the
        # scopes without inducing factoring.
        #
        # Don't do this if the head of the path is an expression
        # (instead of an ObjectRef), though, because that interacts
        # badly with function inlining in some cases??
        # (test_edgeql_functions_inline_object_06).
        # My hope is to just destroy all this machinery instead of tracking
        # that interaction down, though.
        if expr.partial or not isinstance(expr.steps[0], qlast.Expr):
            res.is_factoring_protected = True
        return res

    return stmt.maybe_add_view(setgen.compile_path(expr, ctx=ctx), ctx=ctx)


def _balance(
    elements: Sequence[qlast.Expr],
    ctor: Callable[
        [qlast.Expr, qlast.Expr, Optional[qlast.Span]],
        qlast.Expr
    ],
    span: Optional[qlast.Span],
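) -> qlast.Expr:
    """Combine elements into a balanced binary tree of `ctor` nodes.

    The sequence is split at its midpoint and each half is balanced
    recursively, so the nesting depth grows logarithmically with the
    number of elements. `elements` must contain at least two items.
    `span` is attached to the root node; each subtree gets a span merged
    from its first and last elements when both have one.
    """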
    mid = len(elements) // 2
    ls, rs = elements[:mid], elements[mid:]
    ls_span = rs_span = None
    if len(ls) > 1 and ls[0].span and ls[-1].span:
        ls_span = edb_span.merge_spans([
            ls[0].span, ls[-1].span
        ])
    if len(rs) > 1 and rs[0].span and rs[-1].span:
        rs_span = edb_span.merge_spans([
            rs[0].span, rs[-1].span])

    return ctor(
        (
            _balance(ls, ctor, ls_span)
            if len(ls) > 1 else ls[0]
        ),
        (
            _balance(rs, ctor, rs_span)
            if len(rs) > 1 else rs[0]
        ),
        span,
    )


REBALANCED_OPS = {'UNION'}
REBALANCE_THRESHOLD = 10


@dispatch.compile.register(qlast.BinOp)
def compile_BinOp(expr: qlast.BinOp, *, ctx: context.ContextLevel) -> irast.Set:
    # Rebalance some associative operations to avoid deeply nested ASTs
    if expr.op in REBALANCED_OPS and not expr.rebalanced:
        elements = collect_binop(expr)
        # Don't bother rebalancing small groups
        if len(elements) >= REBALANCE_THRESHOLD:
            balanced = _balance(
                elements,
                lambda l, r, s: qlast.BinOp(
                    left=l, right=r, op=expr.op, rebalanced=True, span=s
                ),
                expr.span
            )
            return dispatch.compile(balanced, ctx=ctx)

    if expr.op == '??' and astutils.contains_dml(expr.right, ctx=ctx):
        return _compile_dml_coalesce(expr, ctx=ctx)

    return func.compile_operator(
        expr, op_name=expr.op, qlargs=[expr.left, expr.right], ctx=ctx
    )


@dispatch.compile.register(qlast.IsOp)
def compile_IsOp(expr: qlast.IsOp, *, ctx: context.ContextLevel) -> irast.Set:
    op_node = compile_type_check_op(expr, ctx=ctx)
    return setgen.ensure_set(op_node, ctx=ctx)


@dispatch.compile.register
def compile_StrInterp(
    expr: qlast.StrInterp, *, ctx: context.ContextLevel
) -> irast.Set:
    strs: list[qlast.Expr] = []
    strs.append(qlast.Constant.string(expr.prefix))

    str_type = qlast.TypeName(
        maintype=qlast.ObjectRef(module='__std__', name='str')
    )
    for fragment in expr.interpolations:
        strs.append(qlast.TypeCast(
            expr=fragment.expr, type=str_type
        ))
        strs.append(qlast.Constant.string(fragment.suffix))

    call = qlast.FunctionCall(
        func=('__std__', 'array_join'),
        args=[qlast.Array(elements=strs), qlast.Constant.string('')],
    )

    return dispatch.compile(call, ctx=ctx)


@dispatch.compile.register(qlast.DetachedExpr)
def compile_DetachedExpr(
    expr: qlast.DetachedExpr,
    *,
    ctx: context.ContextLevel,
) -> irast.Set:
    with ctx.detached() as subctx:
        if expr.preserve_path_prefix:
            subctx.partial_path_prefix = ctx.partial_path_prefix

        ir = dispatch.compile(expr.expr, ctx=subctx)
    # Wrap the result in another set, so that the inner namespace
    # doesn't leak out into any shapes (since computable computation
    # will pull namespaces from the source path_ids.)
    return setgen.ensure_set(setgen.ensure_stmt(ir, ctx=ctx), ctx=ctx)


@dispatch.compile.register(qlast.Set)
def compile_Set(expr: qlast.Set, *, ctx: context.ContextLevel) -> irast.Set:
    # after flattening the set may still end up with 0 or 1 element,
    # which are treated as a special case
    elements = flatten_set(expr)

    if elements:
        if len(elements) == 1:
            # From the scope perspective, single-element set
            # literals are equivalent to a binary UNION with
            # an empty set, not to the element.
            return dispatch.compile(
                astutils.ensure_ql_query(elements[0]), ctx=ctx
            )
        else:
            # Turn it into a tree of UNIONs so we only blow up the nesting
            # depth logarithmically.
            # TODO: Introduce an N-ary operation that handles the whole thing?
            bigunion = _balance(
                elements,
                lambda l, r, s: qlast.BinOp(
                    left=l, op='UNION', right=r,
                    rebalanced=True, set_constructor=True, span=s
                ),
                expr.span
            )
            res = dispatch.compile(bigunion, ctx=ctx)
            if cres := try_constant_set(res):
                res = setgen.ensure_set(cres, span=res.span, ctx=ctx)
            return res
    else:
        return setgen.new_empty_set(
            alias=ctx.aliases.get('e'),
            ctx=ctx,
            span=expr.span,
        )


@dispatch.compile.register(qlast.Constant)
def compile_Constant(
    expr: qlast.Constant, *, ctx: context.ContextLevel
) -> irast.Set:
    value = expr.value

    node_cls: type[irast.BaseConstant]

    if expr.kind == qlast.ConstantKind.STRING:
        std_type = sn.QualName('std', 'str')
        node_cls = irast.StringConstant
    elif expr.kind == qlast.ConstantKind.INTEGER:
        value = value.replace("_", "")
        std_type = sn.QualName('std', 'int64')
        node_cls = irast.IntegerConstant
    elif expr.kind == qlast.ConstantKind.FLOAT:
        value = value.replace("_", "")
        std_type = sn.QualName('std', 'float64')
        node_cls = irast.FloatConstant
    elif expr.kind == qlast.ConstantKind.DECIMAL:
        assert value[-1] == 'n'
        value = value[:-1].replace("_", "")
        std_type = sn.QualName('std', 'decimal')
        node_cls = irast.DecimalConstant
    elif expr.kind == qlast.ConstantKind.BIGINT:
        assert value[-1] == 'n'
        value = value[:-1].replace("_", "")
        std_type = sn.QualName('std', 'bigint')
        node_cls = irast.BigintConstant
    elif expr.kind == qlast.ConstantKind.BOOLEAN:
        std_type = sn.QualName('std', 'bool')
        node_cls = irast.BooleanConstant
    else:
        raise RuntimeError(f'unexpected constant type: {expr.kind}')

    ct = typegen.type_to_typeref(
        ctx.env.get_schema_type_and_track(std_type),
        env=ctx.env,
    )
    ir_expr = node_cls(value=value, typeref=ct, span=expr.span)
    return setgen.ensure_set(ir_expr, ctx=ctx)


@dispatch.compile.register(qlast.BytesConstant)
def compile_BytesConstant(
    expr: qlast.BytesConstant, *, ctx: context.ContextLevel
) -> irast.Set:
    std_type = sn.QualName('std', 'bytes')

    ct = typegen.type_to_typeref(
        ctx.env.get_schema_type_and_track(std_type),
        env=ctx.env,
    )
    return setgen.ensure_set(
        irast.BytesConstant(value=expr.value, typeref=ct), ctx=ctx
    )


@dispatch.compile.register(qlast.NamedTuple)
def compile_NamedTuple(
    expr: qlast.NamedTuple, *, ctx: context.ContextLevel
) -> irast.Set:

    names = set()
    elements = []
    for el in expr.elements:
        name = el.name.name
        if name in names:
            raise errors.QueryError(
                f"named tuple has duplicate field '{name}'",
                span=el.span)
        names.add(name)

        element = irast.TupleElement(
            name=name,
            val=dispatch.compile(el.val, ctx=ctx),
        )
        elements.append(element)

    return setgen.new_tuple_set(elements, named=True, ctx=ctx)


@dispatch.compile.register(qlast.Tuple)
def compile_Tuple(expr: qlast.Tuple, *, ctx: context.ContextLevel) -> irast.Set:

    elements = []
    for i, el in enumerate(expr.elements):
        element = irast.TupleElement(
            name=str(i),
            val=dispatch.compile(el, ctx=ctx),
        )
        elements.append(element)

    return setgen.new_tuple_set(elements, named=False, ctx=ctx)


@dispatch.compile.register(qlast.Array)
def compile_Array(expr: qlast.Array, *, ctx: context.ContextLevel) -> irast.Set:
    elements = [dispatch.compile(e, ctx=ctx) for e in expr.elements]
    return setgen.new_array_set(elements, ctx=ctx, span=expr.span)


def _compile_dml_coalesce(
    expr: qlast.BinOp, *, ctx: context.ContextLevel
) -> irast.Set:
    """Transform a coalesce that contains DML into FOR loops

    The basic approach is to extract the pieces from the ?? and
    rewrite them into:
        for optional x in (LHS) union (
          {
            x,
            (for _ in (select () filter not exists x) union (RHS)),
          }
        )

    Optional for is needed because the LHS needs to be bound in a for
    in order to get put in a CTE and only executed once, but the RHS
    needs to be dependent on the LHS being empty.
    """
    with ctx.newscope(fenced=False) as subctx:
        # We have to compile it under a factoring fence to prevent
        # correlation with outside things. We can't just rely on the
        # factoring fences inserted when compiling the FORs, since we
        # are going to need to explicitly exempt the iterator
        # expression from that.
        subctx.path_scope.factoring_fence = True
        subctx.path_scope.factoring_allowlist.update(ctx.iterator_path_ids)

        ir = func.compile_operator(
            expr, op_name=expr.op, qlargs=[expr.left, expr.right], ctx=subctx)

        # Extract the IR parts from the ??
        # Note that lhs_ir will be unfenced while rhs_ir
        # will have been compiled under fences.
        match ir.expr:
            case irast.OperatorCall(args={
                0: irast.CallArg(expr=lhs_ir),
                1: irast.CallArg(expr=rhs_ir),
            }):
                pass
            case _:
                raise AssertionError('malformed DML ??')

        subctx.anchors = subctx.anchors.copy()

        alias = ctx.aliases.get('_coalesce_x')
        cond_path = qlast.Path(
            steps=[qlast.ObjectRef(name=alias)],
        )

        rhs_b = qlast.ForQuery(
            iterator_alias=ctx.aliases.get('_coalesce_dummy'),
            iterator=qlast.SelectQuery(
                result=qlast.Tuple(elements=[]),
                where=qlast.UnaryOp(
                    op='NOT',
                    operand=qlast.UnaryOp(op='EXISTS', operand=cond_path),
                ),
            ),
            result=subctx.create_anchor(
                rhs_ir, move_scope=True, check_dml=True
            ),
        )

        full = qlast.ForQuery(
            iterator_alias=alias,
            iterator=subctx.create_anchor(lhs_ir, 'b'),
            result=qlast.Set(elements=[cond_path, rhs_b]),
            optional=True,
            from_desugaring=True,
        )

        subctx.iterator_path_ids |= {lhs_ir.path_id}
        res = dispatch.compile(full, ctx=subctx)
        # Indicate that the original ?? code should determine the
        # cardinality/multiplicity.
        assert isinstance(res.expr, irast.SelectStmt)
        res.expr.card_inference_override = ir

        return res


def _compile_dml_ifelse(
    expr: qlast.IfElse, *, ctx: context.ContextLevel
) -> irast.Set:
    """Transform an IF/ELSE that contains DML into FOR loops

    The basic approach is to extract the pieces from the if/then/else and
    rewrite them into:
        for b in COND union (
          {
            (for _ in (select () filter b) union (IF_BRANCH)),
            (for _ in (select () filter not b) union (ELSE_BRANCH)),
          }
        )
    """

    with ctx.newscope(fenced=False) as subctx:
        # We have to compile it under a factoring fence to prevent
        # correlation with outside things. We can't just rely on the
        # factoring fences inserted when compiling the FORs, since we
        # are going to need to explicitly exempt the iterator
        # expression from that.
        subctx.path_scope.factoring_fence = True
        subctx.path_scope.factoring_allowlist.update(ctx.iterator_path_ids)

        ir = func.compile_operator(
            expr, op_name='std::IF',
            qlargs=[expr.if_expr, expr.condition, expr.else_expr], ctx=subctx)

        # Extract the IR parts from the IF/THEN/ELSE
        # Note that cond_ir will be unfenced while if_ir and else_ir
        # will have been compiled under fences.
        match ir.expr:
            case irast.OperatorCall(args={
                0: irast.CallArg(expr=if_ir),
                1: irast.CallArg(expr=cond_ir),
                2: irast.CallArg(expr=else_ir),
            }):
                pass
            case _:
                raise AssertionError('malformed DML IF/ELSE')

        subctx.anchors = subctx.anchors.copy()

        alias = ctx.aliases.get('_ifelse_b')
        cond_path = qlast.Path(
            steps=[qlast.ObjectRef(name=alias)],
        )

        els: list[qlast.Expr] = []

        if not isinstance(irutils.unwrap_set(if_ir).expr, irast.EmptySet):
            if_b = qlast.ForQuery(
                iterator_alias=ctx.aliases.get('_ifelse_true_dummy'),
                iterator=qlast.SelectQuery(
                    result=qlast.Tuple(elements=[]),
                    where=cond_path,
                ),
                result=subctx.create_anchor(
                    if_ir, move_scope=True, check_dml=True
                ),
            )
            els.append(if_b)

        if not isinstance(irutils.unwrap_set(else_ir).expr, irast.EmptySet):
            else_b = qlast.ForQuery(
                iterator_alias=ctx.aliases.get('_ifelse_false_dummy'),
                iterator=qlast.SelectQuery(
                    result=qlast.Tuple(elements=[]),
                    where=qlast.UnaryOp(op='NOT', operand=cond_path),
                ),
                result=subctx.create_anchor(
                    else_ir, move_scope=True, check_dml=True
                ),
            )
            els.append(else_b)

        full = qlast.ForQuery(
            iterator_alias=alias,
            iterator=subctx.create_anchor(cond_ir, 'b'),
            result=qlast.Set(elements=els) if len(els) != 1 else els[0],
        )

        subctx.iterator_path_ids |= {cond_ir.path_id}
        res = dispatch.compile(full, ctx=subctx)
        # Indicate that the original IF/ELSE code should determine the
        # cardinality/multiplicity.
        assert isinstance(res.expr, irast.SelectStmt)
        res.expr.card_inference_override = ir

        return res


@dispatch.compile.register(qlast.IfElse)
def compile_IfElse(
    expr: qlast.IfElse, *, ctx: context.ContextLevel
) -> irast.Set:

    if (
        astutils.contains_dml(expr.if_expr, ctx=ctx)
        or astutils.contains_dml(expr.else_expr, ctx=ctx)
    ):
        return _compile_dml_ifelse(expr, ctx=ctx)

    res = func.compile_operator(
        expr, op_name='std::IF',
        qlargs=[expr.if_expr, expr.condition, expr.else_expr], ctx=ctx)

    return res


@dispatch.compile.register(qlast.UnaryOp)
def compile_UnaryOp(
    expr: qlast.UnaryOp, *, ctx: context.ContextLevel
) -> irast.Set:

    return func.compile_operator(
        expr, op_name=expr.op, qlargs=[expr.operand], ctx=ctx)


def _cache_as_type_rewrite(
    target: irast.Set,
    stype: s_types.Type,
    populate: Callable[[], irast.Set],
    *,
    ctx: context.ContextLevel,
) -> irast.Set:
    key = (stype, False)
    if not ctx.env.type_rewrites.get(key):
        ctx.env.type_rewrites[key] = populate()
    rewrite_target = ctx.env.type_rewrites[key]

    # We need to have the set with expr=None, so that the rewrite
    # will be applied, but we also need to wrap it with a
    # card_inference_override so that we use the real cardinality
    # instead of assuming it is MANY.
    assert isinstance(rewrite_target, irast.Set)
    typeref = typegen.type_to_typeref(stype, env=ctx.env)
    target = setgen.new_set_from_set(
        target,
        expr=irast.TypeRoot(typeref=typeref, is_cached_global=True),
        stype=stype,
        ctx=ctx,
    )
    wrap = irast.SelectStmt(
        result=target,
        card_inference_override=rewrite_target,
    )
    return setgen.new_set_from_set(target, expr=wrap, ctx=ctx)


@dispatch.compile.register(qlast.GlobalExpr)
def compile_GlobalExpr(
    expr: qlast.GlobalExpr, *, ctx: context.ContextLevel
) -> irast.Set:
    # The expr object can be either a Permission or Global.
    # Get an Object and manually check for correct type and None.
    expr_schema_name = s_utils.ast_ref_to_name(expr.name)
    glob = ctx.env.get_schema_object_and_track(
        expr_schema_name,
        expr.name,
        default=None,
        modaliases=ctx.modaliases,
        type=so.Object,
    )

    # Check for None first.
    if glob is None:
        # If no object is found, we want to raise an error with 'global' as
        # the desired type.
        # If we let `get_schema_object_and_track` raise the error, it
        # will contain 'object' instead.
        s_schema.Schema.raise_bad_reference(
            expr_schema_name,
            module_aliases=ctx.modaliases,
            span=expr.span,
            type=s_globals.Global,
        )

    # Check for incorrect type
    if not isinstance(glob, (s_globals.Global, s_permissions.Permission)):
        s_schema.Schema.raise_wrong_type(
            expr_schema_name,
            glob.__class__,
            s_globals.Global,
            span=expr.span,
        )
        # Raise an error here so mypy knows that `glob` can only be a global
        # or permission past this point.
        raise AssertionError('should never happen')

    if (
        # Computed global
        isinstance(glob, s_globals.Global)
        and glob.is_computable(ctx.env.schema)
    ):
        obj_ref = s_utils.name_to_ast_ref(
            glob.get_target(ctx.env.schema).get_name(ctx.env.schema))
        # Wrap the reference in a subquery so that it does not get
        # factored out or go directly into the scope tree.
        qry = qlast.SelectQuery(result=qlast.Path(steps=[obj_ref]))
        target = dispatch.compile(qry, ctx=ctx)

        # If the global is single, use type_rewrites to make sure it
        # is computed only once in the SQL query.
        if glob.get_cardinality(ctx.env.schema).is_single():
            def _populate() -> irast.Set:
                with ctx.detached() as dctx:
                    # The official rewrite needs to be in a detached
                    # scope to avoid collisions; this won't really
                    # recompile the whole thing, it will hit a cache
                    # of the view.
                    return dispatch.compile(qry, ctx=dctx)

            target = _cache_as_type_rewrite(
                target,
                setgen.get_set_type(target, ctx=ctx),
                populate=_populate,
                ctx=ctx,
            )

        return target

    default_ql: Optional[qlast.Expr] = None
    if isinstance(glob, s_globals.Global):
        if default_expr := glob.get_default(ctx.env.schema):
            default_ql = default_expr.parse()

    # If we are compiling with globals suppressed but still allowed, always
    # treat it as being empty.
    if ctx.env.options.make_globals_empty:
        if isinstance(glob, s_permissions.Permission):
            return dispatch.compile(qlast.Constant.boolean(False), ctx=ctx)
        elif default_ql:
            return dispatch.compile(default_ql, ctx=ctx)
        else:
            return setgen.new_empty_set(
                stype=glob.get_target(ctx.env.schema), ctx=ctx
            )

    objctx = ctx.env.options.schema_object_context
    if objctx in (s_constr.Constraint, s_indexes.Index):
        typname = objctx.get_schema_class_displayname()
        raise errors.SchemaDefinitionError(
            f'global variables cannot be referenced from {typname}',
            span=expr.span)

    param_set: qlast.Expr | irast.Set
    present_set: qlast.Expr | irast.Set | None
    if (
        ctx.env.options.func_params is None
        and not ctx.env.options.json_parameters
    ):
        param_set, present_set = setgen.get_global_param_sets(
            glob, ctx=ctx,
        )
    else:
        param_set, present_set = setgen.get_func_global_param_sets(
            glob, ctx=ctx
        )

        if isinstance(glob, s_permissions.Permission):
            # Globals are assumed to be optional within functions. However,
            # permissions always have a value. Provide a default value to
            # reassure the cardinality checks.
            default_ql = qlast.Constant.boolean(False)

    if default_ql and not present_set:
        # If we have a default value and the global is required,
        # then we can use the param being {} as a signal to use
        # the default.
        with ctx.new() as subctx:
            subctx.anchors = subctx.anchors.copy()
            main_param = subctx.maybe_create_anchor(param_set, 'glob')
            param_set = func.compile_operator(
                expr,
                op_name='std::??',
                qlargs=[main_param, default_ql],
                ctx=subctx
            )
    elif default_ql and present_set:
        # ... but if {} is a valid value for the global, we need to
        # stick in an extra parameter to indicate whether to use
        # the default.
        with ctx.new() as subctx:
            subctx.anchors = subctx.anchors.copy()
            main_param = subctx.maybe_create_anchor(param_set, 'glob')

            present_param = subctx.maybe_create_anchor(present_set, 'present')

            param_set = func.compile_operator(
                expr,
                op_name='std::IF',
                qlargs=[main_param, present_param, default_ql],
                ctx=subctx
            )
    elif not isinstance(param_set, irast.Set):
        param_set = dispatch.compile(param_set, ctx=ctx)

    # When we are compiling a global as something we are pulling out
    # of JSON, arrange to cache it as a type rewrite. This can have
    # big performance wins.
    if (
        not ctx.env.options.schema_object_context
        and not (
            ctx.env.options.func_params is None
            and not ctx.env.options.json_parameters
        )
        # TODO: support this for permissions too?
        # OR! Don't put the permissions into the globals JSON?
        and isinstance(glob, s_globals.Global)
    ):
        name = glob.get_name(ctx.env.schema)
        if name not in ctx.env.query_globals_types:
            # HACK: We have a mechanism for caching based on type... so
            # make up a type.
            # I would like us to be less type-forward though.
            ctx.env.query_globals_types[name] = (
                schemactx.derive_view(glob.get_target(ctx.env.schema), ctx=ctx)
            )
        ty = ctx.env.query_globals_types[name]
        param_set = _cache_as_type_rewrite(
            param_set, ty, lambda: param_set, ctx=ctx
        )

    return param_set


@dispatch.compile.register(qlast.TypeCast)
def compile_TypeCast(
    expr: qlast.TypeCast, *, ctx: context.ContextLevel
) -> irast.Set:
    try:
        target_stype = typegen.ql_typeexpr_to_type(expr.type, ctx=ctx)
    except errors.InvalidReferenceError as e:
        if (
            e.hint is None
            and isinstance(expr.type, qlast.TypeName)
            and isinstance(expr.type.maintype, qlast.ObjectRef)
        ):
            s_utils.enrich_schema_lookup_error(
                e,
                s_utils.ast_ref_to_name(expr.type.maintype),
                modaliases=ctx.modaliases,
                schema=ctx.env.schema,
                suggestion_limit=1,
                item_type=s_func.Function,
                hint_text='did you mean to call'
            )
        raise

    ir_expr: irast.Set | irast.Expr

    if isinstance(expr.expr, (qlast.QueryParameter, qlast.FunctionParameter)):
        if (
            # generic types not explicitly allowed
            not ctx.env.options.allow_generic_type_output and
            # not compiling a function which handles its own generic types
            ctx.env.options.func_name is None and
            target_stype.is_polymorphic(ctx.env.schema)
        ):
            raise errors.QueryError(
                f'parameter cannot be a generic type '
                f'{target_stype.get_displayname(ctx.env.schema)!r}',
                hint="Please ensure you don't use generic "
                     '"any" types or abstract scalars.',
                span=expr.span)

        pt = typegen.ql_typeexpr_to_type(expr.type, ctx=ctx)

        param_name = expr.expr.name
        if expr.cardinality_mod:
            if expr.cardinality_mod == qlast.CardinalityModifier.Optional:
                required = False
            elif expr.cardinality_mod == qlast.CardinalityModifier.Required:
                required = True
            else:
                raise NotImplementedError(
                    f"cardinality modifier {expr.cardinality_mod}")
        else:
            required = True

        parameter_type = (
            irast.QueryParameter
            if isinstance(expr.expr, qlast.QueryParameter) else
            irast.FunctionParameter
        )

        typeref = typegen.type_to_typeref(pt, env=ctx.env)
        param = setgen.ensure_set(
            parameter_type(
                typeref=typeref,
                name=param_name,
                required=required,
                span=expr.expr.span,
            ),
            ctx=ctx,
        )

        if ex_param := ctx.env.script_params.get(param_name):
            # N.B. Accessing the schema_type from the param is unreliable
            ctx.env.schema, param_first_type = irtyputils.ir_typeref_to_type(
                ctx.env.schema, ex_param.ir_type)
            if param_first_type != pt:
                raise errors.QueryError(
                    f'parameter type '
                    f'{pt.get_displayname(ctx.env.schema)} '
                    f'does not match original type '
                    f'{param_first_type.get_displayname(ctx.env.schema)}',
                    span=expr.expr.span)

        if param_name not in ctx.env.query_parameters:
            sub_params = None
            if ex_param and ex_param.sub_params:
                sub_params = tuple_args.finish_sub_params(
                    ex_param.sub_params, ctx=ctx)

            ctx.env.query_parameters[param_name] = irast.Param(
                name=param_name,
                required=required,
                schema_type=pt,
                ir_type=typeref,
                sub_params=sub_params,
            )

        return param

    with ctx.new() as subctx:
        if target_stype.contains_json(subctx.env.schema):
            # JSON wants type shapes and acts as an output sink.
            subctx.expr_exposed = context.Exposure.EXPOSED
            subctx.implicit_limit = 0
            subctx.implicit_id_in_shapes = False
            subctx.implicit_tid_in_shapes = False
            subctx.implicit_tname_in_shapes = False

        ir_expr = dispatch.compile(expr.expr, ctx=subctx)
        orig_stype = setgen.get_set_type(ir_expr, ctx=ctx)

        use_message_context = False
        if target_stype.is_collection() and subctx.collection_cast_info is None:
            subctx.collection_cast_info = context.CollectionCastInfo(
                from_type=orig_stype,
                to_type=target_stype,
                path_elements=[]
            )

            use_message_context = (
                orig_stype.is_array() and target_stype.is_array()
                or (
                    orig_stype.is_tuple(ctx.env.schema)
                    and target_stype.is_tuple(ctx.env.schema)
                )
            )

        try:
            res = casts.compile_cast(
                ir_expr,
                target_stype,
                cardinality_mod=expr.cardinality_mod,
                ctx=subctx,
                span=expr.span,
            )

        except errors.QueryError as e:
            if (
                (message_context := casts.cast_message_context(subctx))
                and use_message_context
            ):
                e.args = (
                    (message_context + e.args[0],)
                    + e.args[1:]
                )
            raise e

    return stmt.maybe_add_view(res, ctx=ctx)


def _infer_type_introspection(
    typeref: irast.TypeRef,
    env: context.Environment,
    span: Optional[parsing.Span] = None,
) -> s_types.Type:
    if irtyputils.is_scalar(typeref):
        return env.schema.get_by_name(
            'schema::ScalarType', type=s_objtypes.ObjectType
        )
    elif irtyputils.is_object(typeref):
        return env.schema.get_by_name(
            'schema::ObjectType', type=s_objtypes.ObjectType
        )
    elif irtyputils.is_array(typeref):
        return env.schema.get_by_name(
            'schema::Array', type=s_objtypes.ObjectType
        )
    elif irtyputils.is_tuple(typeref):
        return env.schema.get_by_name(
            'schema::Tuple', type=s_objtypes.ObjectType
        )
    elif irtyputils.is_range(typeref):
        return env.schema.get_by_name(
            'schema::Range', type=s_objtypes.ObjectType
        )
    elif irtyputils.is_multirange(typeref):
        return env.schema.get_by_name(
            'schema::MultiRange', type=s_objtypes.ObjectType
        )
    else:
        raise errors.QueryError(
            'unexpected type in INTROSPECT', span=span)


@dispatch.compile.register(qlast.Introspect)
def compile_Introspect(
    expr: qlast.Introspect, *, ctx: context.ContextLevel
) -> irast.Set:

    typeref = typegen.ql_typeexpr_to_ir_typeref(expr.type, ctx=ctx)
    if typeref.material_type and not irtyputils.is_object(typeref):
        typeref = typeref.material_type
    if typeref.is_opaque_union:
        typeref = typegen.type_to_typeref(
            ctx.env.schema.get_by_name(
                'std::BaseObject', type=s_objtypes.ObjectType
            ),
            env=ctx.env,
        )

    if irtyputils.is_view(typeref):
        raise errors.QueryError(
            'cannot introspect transient type variant',
            span=expr.type.span)
    if irtyputils.is_collection(typeref):
        raise errors.QueryError(
            'cannot introspect collection types',
            span=expr.type.span)
    if irtyputils.is_generic(typeref):
        raise errors.QueryError(
            'cannot introspect generic types',
            span=expr.type.span)

    result_typeref = typegen.type_to_typeref(
        _infer_type_introspection(typeref, ctx.env, expr.span), env=ctx.env
    )
    ir = setgen.ensure_set(
        irast.TypeIntrospection(output_typeref=typeref, typeref=result_typeref),
        ctx=ctx,
    )
    return stmt.maybe_add_view(ir, ctx=ctx)


def _infer_index_type(
    expr: irast.Set | irast.Expr,
    index: irast.Set,
    *, ctx: context.ContextLevel,
) -> s_types.Type:
    env = ctx.env
    node_type = setgen.get_expr_type(expr, ctx=ctx)
    index_type = setgen.get_set_type(index, ctx=ctx)

    str_t = env.schema.get_by_name('std::str', type=s_scalars.ScalarType)
    bytes_t = env.schema.get_by_name('std::bytes', type=s_scalars.ScalarType)
    int_t = env.schema.get_by_name('std::int64', type=s_scalars.ScalarType)
    json_t = env.schema.get_by_name('std::json', type=s_scalars.ScalarType)

    result: s_types.Type

    if node_type.issubclass(env.schema, str_t):

        if not index_type.implicitly_castable_to(int_t, env.schema):
            raise errors.QueryError(
                f'cannot index string by '
                f'{index_type.get_displayname(env.schema)}, '
                f'{int_t.get_displayname(env.schema)} was expected',
                span=index.span)

        result = str_t

    elif node_type.issubclass(env.schema, bytes_t):

        if not index_type.implicitly_castable_to(int_t, env.schema):
            raise errors.QueryError(
                f'cannot index bytes by '
                f'{index_type.get_displayname(env.schema)}, '
                f'{int_t.get_displayname(env.schema)} was expected',
                span=index.span)

        result = bytes_t

    elif node_type.issubclass(env.schema, json_t):

        if not (index_type.implicitly_castable_to(int_t, env.schema) or
                index_type.implicitly_castable_to(str_t, env.schema)):

            raise errors.QueryError(
                f'cannot index json by '
                f'{index_type.get_displayname(env.schema)}, '
                f'{int_t.get_displayname(env.schema)} or '
                f'{str_t.get_displayname(env.schema)} was expected',
                span=index.span)

        result = json_t

    elif isinstance(node_type, s_types.Array):

        if not index_type.implicitly_castable_to(int_t, env.schema):
            raise errors.QueryError(
                f'cannot index array by '
                f'{index_type.get_displayname(env.schema)}, '
                f'{int_t.get_displayname(env.schema)} was expected',
                span=index.span)

        result = node_type.get_subtypes(env.schema)[0]

    elif (node_type.is_any(env.schema) or
            (node_type.is_scalar() and
                str(node_type.get_name(env.schema)) == 'std::anyscalar') and
            (index_type.implicitly_castable_to(int_t, env.schema) or
                index_type.implicitly_castable_to(str_t, env.schema))):
        result = s_pseudo.PseudoType.get(env.schema, 'anytype')

    else:
        raise errors.QueryError(
            f'index indirection cannot be applied to '
            f'{node_type.get_verbosename(env.schema)}',
            span=expr.span)

    return result


def _infer_slice_type(
    expr: irast.Set,
    start: Optional[irast.Set],
    stop: Optional[irast.Set],
    *, ctx: context.ContextLevel,
) -> s_types.Type:
    env = ctx.env
    node_type = setgen.get_set_type(expr, ctx=ctx)

    str_t = env.schema.get_by_name('std::str', type=s_scalars.ScalarType)
    int_t = env.schema.get_by_name('std::int64', type=s_scalars.ScalarType)
    json_t = env.schema.get_by_name('std::json', type=s_scalars.ScalarType)
    bytes_t = env.schema.get_by_name('std::bytes', type=s_scalars.ScalarType)

    if node_type.issubclass(env.schema, str_t):
        base_name = 'string'
    elif node_type.issubclass(env.schema, json_t):
        base_name = 'JSON array'
    elif node_type.issubclass(env.schema, bytes_t):
        base_name = 'bytes'
    elif isinstance(node_type, s_types.Array):
        base_name = 'array'
    elif node_type.is_any(env.schema):
        base_name = 'anytype'
    else:
        # the base type is not valid
        raise errors.QueryError(
            f'{node_type.get_verbosename(env.schema)} cannot be sliced',
            span=expr.span)

    for index in [start, stop]:
        if index is not None:
            index_type = setgen.get_set_type(index, ctx=ctx)

            if not index_type.implicitly_castable_to(int_t, env.schema):
                raise errors.QueryError(
                    f'cannot slice {base_name} by '
                    f'{index_type.get_displayname(env.schema)}, '
                    f'{int_t.get_displayname(env.schema)} was expected',
                    span=index.span)

    return node_type


@dispatch.compile.register(qlast.Indirection)
def compile_Indirection(
    expr: qlast.Indirection, *, ctx: context.ContextLevel
) -> irast.Set:
    node: irast.Set | irast.Expr = dispatch.compile(expr.arg, ctx=ctx)
    for indirection_el in expr.indirection:
        if isinstance(indirection_el, qlast.Index):
            idx = dispatch.compile(indirection_el.index, ctx=ctx)
            idx.span = indirection_el.index.span
            typeref = typegen.type_to_typeref(
                _infer_index_type(node, idx, ctx=ctx), env=ctx.env
            )

            node = irast.IndexIndirection(
                expr=node, index=idx, typeref=typeref, span=expr.span
            )
        elif isinstance(indirection_el, qlast.Slice):
            start: Optional[irast.Base]
            stop: Optional[irast.Base]

            if indirection_el.start:
                start = dispatch.compile(indirection_el.start, ctx=ctx)
            else:
                start = None

            if indirection_el.stop:
                stop = dispatch.compile(indirection_el.stop, ctx=ctx)
            else:
                stop = None

            node_set = setgen.ensure_set(node, ctx=ctx)
            typeref = typegen.type_to_typeref(
                _infer_slice_type(node_set, start, stop, ctx=ctx), env=ctx.env
            )
            node = irast.SliceIndirection(
                expr=node_set, start=start, stop=stop, typeref=typeref,
            )
        else:
            raise ValueError(
                f'unexpected indirection node: {indirection_el!r}'
            )

    return setgen.ensure_set(node, ctx=ctx)


def compile_type_check_op(
    expr: qlast.IsOp, *, ctx: context.ContextLevel
) -> irast.TypeCheckOp:
    # <expr> IS <type>
    left = dispatch.compile(expr.left, ctx=ctx)
    ltype = setgen.get_set_type(left, ctx=ctx)
    typeref = typegen.ql_typeexpr_to_ir_typeref(expr.right, ctx=ctx)

    if ltype.is_object_type() and not ltype.is_free_object_type(ctx.env.schema):
        # Argh, what a mess path factoring and deduplication is!  We
        # need to dereference __type__, and the left operand needs to
        # be visible in the scope when we do it, or else it will get
        # deduplicated.
        pathctx.register_set_in_scope(left, ctx=ctx)

        left = setgen.ptr_step_set(
            left, expr=None, source=ltype, ptr_name='__type__',
            span=expr.span, ctx=ctx
        )
        result = None
    else:
        if (ltype.is_collection()
                and cast(s_types.Collection, ltype).contains_object(
                    ctx.env.schema)):
            raise errors.QueryError(
                f'type checks on non-primitive collections are not supported'
            )

        ctx.env.schema, test_type = (
            irtyputils.ir_typeref_to_type(ctx.env.schema, typeref)
        )
        result = ltype.issubclass(ctx.env.schema, test_type)

    output_typeref = typegen.type_to_typeref(
        ctx.env.schema.get_by_name('std::bool', type=s_types.Type),
        env=ctx.env,
    )

    return irast.TypeCheckOp(
        left=left, right=typeref, op=expr.op, result=result,
        typeref=output_typeref)


def flatten_set(expr: qlast.Set) -> list[qlast.Expr]:
    elements = []
    for el in expr.elements:
        if isinstance(el, qlast.Set):
            elements.extend(flatten_set(el))
        else:
            elements.append(el)

    return elements


def collect_binop(expr: qlast.BinOp) -> list[qlast.Expr]:
    elements = []

    stack = [expr.right, expr.left]
    while stack:
        el = stack.pop()
        if isinstance(el, qlast.BinOp) and el.op == expr.op:
            stack.extend([el.right, el.left])
        else:
            elements.append(el)

    return elements


def try_constant_set(expr: irast.Base) -> Optional[irast.ConstantSet]:
    elements = []

    stack: list[Optional[irast.Base]] = [expr]
    while stack:
        el = stack.pop()
        if isinstance(el, irast.Set):
            stack.append(el.expr)
        elif (
            isinstance(el, irast.OperatorCall)
            and str(el.func_shortname) == 'std::UNION'
        ):
            stack.extend([el.args[1].expr.expr, el.args[0].expr.expr])
        elif el and irutils.is_trivial_select(el):
            stack.append(el.result)
        elif isinstance(el, (irast.BaseConstant, irast.BaseParameter)):
            elements.append(el)
        else:
            return None

    if elements:
        return irast.ConstantSet(
            elements=tuple(elements), typeref=elements[0].typeref
        )
    else:
        return None


class IdentCompletionException(BaseException):
    """An exception that is raised to halt the compilation and return a list of
    suggested idents to be used at the location of the qlast.Cursor node.
    """

    def __init__(self, suggestions: list[str]):
        self.suggestions = suggestions


@dispatch.compile.register(qlast.Cursor)
def compile_Cursor(
    expr: qlast.Cursor, *, ctx: context.ContextLevel
) -> irast.Set:
    suggestions = []

    # with bindings
    name: sn.Name
    for name in ctx.aliased_views.keys():
        suggestions.append(name.name)

    # names in current module
    if cur_mod := ctx.modaliases.get(None):
        obj_types = ctx.env.schema.get_objects(
            included_modules=[sn.UnqualName(cur_mod)],
            type=s_objtypes.ObjectType,
        )
        obj_type_names = [
            obj_type.get_name(ctx.env.schema).name
            for obj_type in obj_types
        ]
        obj_type_names.sort()
        suggestions.extend(obj_type_names)

    raise IdentCompletionException(suggestions)


================================================
FILE: edb/edgeql/compiler/func.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""EdgeQL compiler routines for function calls and operators."""


from __future__ import annotations
from typing import (
    Callable,
    Final,
    Optional,
    Protocol,
    Iterable,
    Sequence,
    cast,
    TYPE_CHECKING,
)

from edb import errors
from edb.common import ast
from edb.common import parsing
from edb.common.typeutils import not_none

from edb.ir import ast as irast
from edb.ir import staeval
from edb.ir import utils as irutils
from edb.ir import typeutils as irtyputils

from edb.schema import constraints as s_constr
from edb.schema import delta as sd
from edb.schema import functions as s_func
from edb.schema import globals as s_globals
from edb.schema import modules as s_mod
from edb.schema import name as sn
from edb.schema import objtypes as s_objtypes
from edb.schema import operators as s_oper
from edb.schema import permissions as s_permissions
from edb.schema import scalars as s_scalars
from edb.schema import types as s_types
from edb.schema import indexes as s_indexes
from edb.schema import schema as s_schema
from edb.schema import utils as s_utils

from edb.edgeql import ast as qlast
from edb.edgeql import qltypes as ft
from edb.edgeql import parser as qlparser

from . import casts
from . import context
from . import dispatch
from . import pathctx
from . import polyres
from . import schemactx
from . import setgen
from . import stmt
from . import typegen

if TYPE_CHECKING:
    import uuid


@dispatch.compile.register(qlast.FunctionCall)
def compile_FunctionCall(
    expr: qlast.FunctionCall, *, ctx: context.ContextLevel
) -> irast.Set:

    env = ctx.env

    funcname: sn.Name
    if isinstance(expr.func, str):
        if (
            ctx.env.options.func_params is not None
            and ctx.env.options.func_params.get_by_name(
                env.schema, expr.func
            )
        ):
            raise errors.QueryError(
                f'parameter `{expr.func}` is not callable',
                span=expr.span)

        funcname = sn.UnqualName(expr.func)
    else:
        funcname = sn.QualName(*expr.func)

    try:
        funcs = s_func.lookup_functions(
            funcname,
            module_aliases=ctx.modaliases,
            schema=env.schema,
        )
    except errors.InvalidReferenceError as e:
        s_utils.enrich_schema_lookup_error(
            e,
            funcname,
            modaliases=ctx.modaliases,
            schema=env.schema,
            suggestion_limit=1,
            item_type=s_types.Type,
            span=expr.span,
            hint_text='did you mean to cast to'
        )
        raise

    # Check for a failed lookup before iterating over the candidates.
    if funcs is None:
        raise errors.QueryError(
            f'could not resolve function name {funcname}',
            span=expr.span)

    prefer_subquery_args = any(
        func.get_prefer_subquery_args(env.schema) for func in funcs
    )

    in_polymorphic_func = (
        ctx.env.options.func_params is not None and
        ctx.env.options.func_params.has_polymorphic(env.schema)
    )

    in_abstract_constraint = (
        in_polymorphic_func and
        ctx.env.options.schema_object_context is s_constr.Constraint
    )

    typemods = polyres.find_callable_typemods(
        funcs, num_args=len(expr.args), kwargs_names=expr.kwargs.keys(),
        ctx=ctx)
    args, kwargs = compile_func_call_args(
        expr, funcname, typemods, prefer_subquery_args=prefer_subquery_args,
        ctx=ctx)
    with errors.ensure_span(expr.span):
        matched = polyres.find_callable(
            funcs, args=args, kwargs=kwargs, ctx=ctx)
    if not matched:
        alts = [f.get_signature_as_str(env.schema) for f in funcs]
        sig: list[str] = []
        # This is used to generate unique arg names.
        argnum = 0
        for argtype, _ in args:
            # Skip any name colliding with kwargs.
            while f'arg{argnum}' in kwargs:
                argnum += 1
            ty = schemactx.get_material_type(argtype, ctx=ctx)
            sig.append(
                f'arg{argnum}: {ty.get_displayname(env.schema)}'
            )
            argnum += 1
        for kwname, (kwtype, _) in kwargs.items():
            ty = schemactx.get_material_type(kwtype, ctx=ctx)
            sig.append(
                f'NAMED ONLY {kwname}: {ty.get_displayname(env.schema)}'
            )

        signature = f'{funcname}({", ".join(sig)})'

        if not funcs:
            hint = None
        elif len(alts) == 1:
            hint = f'Did you want "{alts[0]}"?'
        else:  # Multiple alternatives
            hint = (
                f'Did you want one of the following functions instead:\n' +
                ('\n'.join(alts))
            )

        raise errors.QueryError(
            f'function "{signature}" does not exist',
            hint=hint,
            span=expr.span)
    elif len(matched) > 1:
        if in_abstract_constraint:
            matched_call = matched[0]
        else:
            alts = [m.func.get_signature_as_str(env.schema) for m in matched]
            raise errors.QueryError(
                f'function {funcname} is not unique',
                hint=f'Please disambiguate between the following '
                     f'alternatives:\n' +
                     ('\n'.join(alts)),
                span=expr.span)
    else:
        matched_call = matched[0]

    func = matched_call.func
    assert isinstance(func, s_func.Function)

    if matched_call.server_param_conversions:
        for param_name, conversions in (
            matched_call.server_param_conversions.items()
        ):
            if param_name not in ctx.env.server_param_conversions:
                ctx.env.server_param_conversions[param_name] = {}
            ctx.env.server_param_conversions[param_name].update(
                conversions
            )
        ctx.env.server_param_conversion_calls.append((
            func.get_signature_as_str(env.schema),
            expr.span,
        ))

    inline_func = None
    if (
        func.get_language(ctx.env.schema) == qlast.Language.EdgeQL
        and (
            func.get_volatility(ctx.env.schema) == ft.Volatility.Modifying
            or func.get_is_inlined(ctx.env.schema)
        )
    ):
        inline_func = s_func.compile_function_inline(
            schema=ctx.env.schema,
            context=sd.CommandContext(
                schema=ctx.env.schema,
            ),
            body=not_none(func.get_nativecode(ctx.env.schema)),
            func_name=func.get_name(ctx.env.schema),
            params=func.get_params(ctx.env.schema),
            language=not_none(func.get_language(ctx.env.schema)),
            return_type=func.get_return_type(ctx.env.schema),
            return_typemod=func.get_return_typemod(ctx.env.schema),
            track_schema_ref_exprs=False,
            inlining_context=ctx,
        )

    # Record this node in the list of potential DML expressions.
    if func.get_volatility(env.schema) == ft.Volatility.Modifying:
        ctx.env.dml_exprs.append(expr)

        # This is some kind of mutation, so we need to check if it is
        # allowed.
        if ctx.env.options.in_ddl_context_name is not None:
            raise errors.SchemaDefinitionError(
                f'mutations are invalid in '
                f'{ctx.env.options.in_ddl_context_name}',
                span=expr.span,
            )
        elif (
            (dv := ctx.defining_view) is not None
            and dv.get_expr_type(ctx.env.schema) is s_types.ExprType.Select
            and not irutils.is_trivial_free_object(
                not_none(ctx.partial_path_prefix))
        ):
            # This is some shape in a regular query. DML is not
            # allowed in the computable, but it may be possible to
            # refactor it.
            raise errors.QueryError(
                f"mutations are invalid in a shape's computed expression",
                hint=(
                    f'To resolve this try to factor out the mutation '
                    f'expression into the top-level WITH block.'
                ),
                span=expr.span,
            )

    func_name = func.get_shortname(env.schema)

    matched_func_params = func.get_params(env.schema)
    variadic_param = matched_func_params.find_variadic(env.schema)
    variadic_param_type = None
    if variadic_param is not None:
        variadic_param_type = typegen.type_to_typeref(
            variadic_param.get_type(env.schema),
            env=env,
        )

    matched_func_ret_type = func.get_return_type(env.schema)
    is_polymorphic = (
        any(p.get_type(env.schema).is_polymorphic(env.schema)
            for p in matched_func_params.objects(env.schema)) and
        matched_func_ret_type.is_polymorphic(env.schema)
    )

    matched_func_initial_value = func.get_initial_value(env.schema)

    final_args, param_name_to_arg_key = finalize_args(
        matched_call,
        guessed_typemods=typemods,
        is_polymorphic=is_polymorphic,
        ctx=ctx,
    )

    # Forbid DML in non-scalar function args
    if func.get_nativecode(env.schema):
        # We are sure that there are no such functions implemented with SQL

        for arg in final_args.values():
            if arg.expr.typeref.is_scalar:
                continue
            if not irutils.contains_dml(arg.expr):
                continue
            raise errors.UnsupportedFeatureError(
                'newly created or updated objects cannot be passed to '
                'functions',
                span=arg.expr.span
            )

    if not in_abstract_constraint:
        # We cannot add strong references to functions from
        # abstract constraints, since we cannot know which
        # form of the function is actually used.
        env.add_schema_ref(func, expr)

    ctx.env.required_permissions.update(
        func.get_required_permissions(ctx.env.schema).objects(ctx.env.schema)
    )

    func_initial_value: Optional[irast.Set]

    if matched_func_initial_value is not None:
        frag = qlparser.parse_fragment(matched_func_initial_value.text)
        assert isinstance(frag, qlast.Expr)
        iv_ql = qlast.TypeCast(
            expr=frag,
            type=typegen.type_to_ql_typeref(matched_call.return_type, ctx=ctx),
        )
        func_initial_value = dispatch.compile(iv_ql, ctx=ctx)
    else:
        func_initial_value = None

    rtype = matched_call.return_type
    path_id = pathctx.get_expression_path_id(rtype, ctx=ctx)

    if rtype.is_tuple(env.schema):
        rtype = cast(s_types.Tuple, rtype)
        tuple_path_ids = []
        nested_path_ids = []
        for n, st in rtype.iter_subtypes(ctx.env.schema):
            elem_path_id = pathctx.get_tuple_indirection_path_id(
                path_id, n, st, ctx=ctx)

            if isinstance(st, s_types.Tuple):
                nested_path_ids.append([
                    pathctx.get_tuple_indirection_path_id(
                        elem_path_id, nn, sst, ctx=ctx)
                    for nn, sst in st.iter_subtypes(ctx.env.schema)
                ])

            tuple_path_ids.append(elem_path_id)
        for nested in nested_path_ids:
            tuple_path_ids.extend(nested)
    else:
        tuple_path_ids = []

    global_args = None
    if not inline_func:
        global_args = get_globals(
            expr, matched_call, candidates=funcs, ctx=ctx
        )

    volatility = (
        # Incorporate the volatility of any server param conversions
        max([
            func.get_volatility(env.schema),
            *(
                conversion.volatility
                for conversions in (
                    matched_call.server_param_conversions.values()
                )
                for conversion in conversions.values()
            )
        ])
        if matched_call.server_param_conversions else
        func.get_volatility(env.schema)
    )

    fcall = irast.FunctionCall(
        args=final_args,
        func_shortname=func_name,
        backend_name=func.get_backend_name(env.schema),
        func_polymorphic=is_polymorphic,
        func_sql_function=func.get_from_function(env.schema),
        func_sql_expr=func.get_from_expr(env.schema),
        force_return_cast=func.get_force_return_cast(env.schema),
        volatility=volatility,
        sql_func_has_out_params=func.get_sql_func_has_out_params(env.schema),
        error_on_null_result=func.get_error_on_null_result(env.schema),
        preserves_optionality=func.get_preserves_optionality(env.schema),
        preserves_upper_cardinality=func.get_preserves_upper_cardinality(
            env.schema),
        typeref=typegen.type_to_typeref(
            rtype, env=env,
        ),
        typemod=matched_call.func.get_return_typemod(env.schema),
        has_empty_variadic=(matched_call.variadic_arg_count == 0),
        variadic_param_type=variadic_param_type,
        func_initial_value=func_initial_value,
        tuple_path_ids=tuple_path_ids,
        impl_is_strict=(
            func.get_impl_is_strict(env.schema)
            # Inlined functions should always check for null arguments.
            and not inline_func
        ),
        prefer_subquery_args=func.get_prefer_subquery_args(env.schema),
        is_singleton_set_of=func.get_is_singleton_set_of(env.schema),
        global_args=global_args,
        span=expr.span,
        return_polymorphism=matched_call.return_polymorphism,
    )

    # Apply special function handling
    if special_func := _SPECIAL_FUNCTIONS.get(str(func_name)):
        res = special_func(fcall, ctx=ctx)
    elif inline_func:
        res = fcall

        # TODO: Global parameters still use the implicit globals parameter.
        # They should be directly substituted in whenever possible.

        inline_args: dict[str, irast.CallArg | irast.Set] = {}

        # Collect non-default call args to inline
        for param_shortname, arg_key in param_name_to_arg_key.items():
            if (
                isinstance(arg_key, int)
                and matched_call.variadic_arg_id is not None
                and arg_key >= matched_call.variadic_arg_id
            ):
                continue

            arg = final_args[arg_key]
            if arg.is_default:
                continue

            inline_args[param_shortname] = arg

        # Package variadic arguments into an array
        if variadic_param is not None:
            assert variadic_param_type is not None
            assert matched_call.variadic_arg_id is not None
            assert matched_call.variadic_arg_count is not None

            param_shortname = variadic_param.get_parameter_name(env.schema)
            inline_args[param_shortname] = ir_set = setgen.ensure_set(
                irast.Array(
                    elements=[
                        final_args[arg_key].expr
                        for arg_key in range(
                            matched_call.variadic_arg_id,
                            matched_call.variadic_arg_id
                            + matched_call.variadic_arg_count
                        )
                    ],
                    typeref=variadic_param_type,
                ),
                ctx=ctx,
            )

        # Compile default args if necessary
        for param in matched_func_params.objects(env.schema):
            param_shortname = param.get_parameter_name(env.schema)

            if param_shortname in inline_args:
                continue

            else:
                # Missing named only args have their default values already
                # compiled in try_bind_call_args.
                if bound_args := [
                    bound_arg
                    for bound_arg in matched_call.args
                    if (
                        isinstance(bound_arg, polyres.DefaultArg)
                        and bound_arg.name == param_shortname
                    )
                ]:
                    assert len(bound_args) == 1
                    inline_args[param_shortname] = bound_args[0].val
                    continue

                # Check if default is available
                p_default = param.get_default(env.schema)
                if p_default is None:
                    continue

                # Compile default
                assert isinstance(param, s_func.Parameter)
                p_ir_default = dispatch.compile(p_default.parse(), ctx=ctx)
                inline_args[param_shortname] = p_ir_default

        argument_inliner = ArgumentInliner(inline_args, ctx=ctx)
        res.body = argument_inliner.visit(inline_func)

    else:
        res = fcall

    if isinstance(res, irast.FunctionCall) and res.body:
        # If we are generating a special-cased inlined function call,
        # make sure to register all the arguments in the scope tree
        # to ensure that the compiled arguments get picked up when
        # compiling the body.
        for arg in res.args.values():
            pathctx.register_set_in_scope(
                arg.expr,
                optional=(
                    arg.param_typemod == ft.TypeModifier.OptionalType
                ),
                ctx=ctx,
            )

    ir_set = setgen.ensure_set(res, typehint=rtype, path_id=path_id, ctx=ctx)
    return stmt.maybe_add_view(ir_set, ctx=ctx)


class ArgumentInliner(ast.NodeTransformer):

    # Don't look through hidden nodes, they may contain references to nodes
    # which should not be modified. For example, irast.Stmt.parent_stmt.
    skip_hidden = True

    mapped_args: dict[irast.PathId, irast.PathId]
    inlined_arg_keys: list[int | str]

    def __init__(
        self,
        inline_args: dict[str, irast.CallArg | irast.Set],
        ctx: context.ContextLevel,
    ) -> None:
        super().__init__()
        self.inline_args = inline_args
        self.ctx = ctx
        self.mapped_args = {}

    def visit_Set(self, node: irast.Set) -> irast.Base:
        if (
            isinstance(node.expr, irast.FunctionParameter)
            and node.expr.name in self.inline_args
        ):
            arg = self.inline_args[node.expr.name]
            if isinstance(arg, irast.CallArg):
                # Inline param as an expr ref. The pg compiler will find the
                # appropriate rvar.
                self.mapped_args[node.path_id] = arg.expr.path_id
                inlined_param_expr = setgen.ensure_set(
                    irast.InlinedParameterExpr(
                        typeref=arg.expr.typeref,
                        required=node.expr.required,
                        is_global=node.expr.is_global,
                    ),
                    path_id=arg.expr.path_id,
                    ctx=self.ctx,
                )
                inlined_param_expr.shape = node.shape
                return inlined_param_expr
            else:
                # Directly inline the set.
                # Used for default values, which are constants.
                return arg

        elif isinstance(node.expr, irast.Pointer):
            # The set and source path ids must match in order for the pointer
            # to find the correct rvar. If a pointer's source path was modified
            # because of an inlined parameter, modify the pointer's path as
            # well.
            prev_source_path_id = node.expr.source.path_id
            result = cast(irast.Set, self.generic_visit(node))

            if prev_source_path_id in self.mapped_args:
                result = setgen.new_set_from_set(
                    result,
                    path_id=irtyputils.replace_pathid_prefix(
                        result.path_id,
                        prev_source_path_id,
                        self.mapped_args[prev_source_path_id],
                    ),
                    ctx=self.ctx,
                )
                self.mapped_args[node.path_id] = result.path_id

            return result

        return cast(irast.Base, self.generic_visit(node))

    # Don't transform pointer refs.
    # They are updated in other places, such as cardinality inference.
    def visit_PointerRef(
        self, node: irast.PointerRef
    ) -> irast.Base:
        return node

    def visit_TupleIndirectionPointerRef(
        self, node: irast.TupleIndirectionPointerRef
    ) -> irast.Base:
        return node

    def visit_SpecialPointerRef(
        self, node: irast.SpecialPointerRef
    ) -> irast.Base:
        return node

    def visit_TypeIntersectionPointerRef(
        self, node: irast.TypeIntersectionPointerRef
    ) -> irast.Base:
        return node


class _SpecialCaseFunc(Protocol):
    def __call__(
        self, call: irast.FunctionCall, *, ctx: context.ContextLevel
    ) -> irast.Expr:
        pass


_SPECIAL_FUNCTIONS: dict[str, _SpecialCaseFunc] = {}


def _special_case(name: str) -> Callable[[_SpecialCaseFunc], _SpecialCaseFunc]:
    def func(f: _SpecialCaseFunc) -> _SpecialCaseFunc:
        _SPECIAL_FUNCTIONS[name] = f
        return f

    return func


def compile_operator(
    qlexpr: qlast.Expr,
    op_name: str,
    qlargs: list[qlast.Expr],
    *,
    ctx: context.ContextLevel,
) -> irast.Set:

    env = ctx.env
    schema = env.schema
    opers = s_oper.lookup_operators(
        op_name, module_aliases=ctx.modaliases, schema=schema
    )

    if opers is None:
        raise errors.QueryError(
            'no operator matches the given name and argument types',
            span=qlexpr.span)

    typemods = polyres.find_callable_typemods(
        opers, num_args=len(qlargs), kwargs_names=set(), ctx=ctx)

    prefer_subquery_args = any(
        oper.get_prefer_subquery_args(env.schema) for oper in opers
    )

    args = []

    for ai, qlarg in enumerate(qlargs):
        arg_ir = polyres.compile_arg(
            qlarg,
            typemods[ai],
            prefer_subquery_args=prefer_subquery_args,
            ctx=ctx,
        )

        arg_type = setgen.get_set_type(arg_ir, ctx=ctx)
        if arg_type is None:
            raise errors.QueryError(
                f'could not resolve the type of operand '
                f'#{ai} of {op_name}',
                span=qlarg.span)

        args.append((arg_type, arg_ir))

    # Check if the operator is a derived operator, and if so,
    # find the origins.
    origin_op = opers[0].get_derivative_of(env.schema)
    derivative_op: Optional[s_oper.Operator]
    if origin_op is not None:
        # If this is a derived operator, there should be
        # exactly one form of it.  This is enforced at the DDL
        # level, but check again to be sure.
        if len(opers) > 1:
            raise errors.InternalServerError(
                f'more than one derived operator of the same name: {op_name}',
                span=qlexpr.span)

        derivative_op = opers[0]
        opers = s_oper.lookup_operators(origin_op, schema=schema)
        if not opers:
            raise errors.InternalServerError(
                f'cannot find the origin operator for {op_name}',
                span=qlexpr.span)
        actual_typemods = [
            param.get_typemod(schema)
            for param in derivative_op.get_params(schema).objects(schema)
        ]
    else:
        derivative_op = None
        actual_typemods = []

    matched = None
    # Some 2-operand operators are special when their operands are
    # arrays or tuples.
    if len(args) == 2:
        coll_opers = None
        # If both of the args are arrays or tuples, potentially
        # compile the operator for them differently than for other
        # combinations.
        if args[0][0].is_tuple(env.schema) and args[1][0].is_tuple(env.schema):
            # Out of the candidate operators, find the ones that
            # correspond to tuples.
            coll_opers = [
                op for op in opers
                if all(
                    param.get_type(schema).is_tuple(schema)
                    for param in op.get_params(schema).objects(schema)
                )
            ]

        elif args[0][0].is_array() and args[1][0].is_array():
            # Out of the candidate operators, find the ones that
            # correspond to arrays.
            coll_opers = [
                op for op in opers
                if all(
                    param.get_type(schema).is_array()
                    for param in op.get_params(schema).objects(schema)
                )
            ]

        # Proceed only if we have a special case of collection operators.
        if coll_opers:
            # Then check if they are recursive (i.e. validation must be
            # done recursively for the subtypes). We rely on the fact that
            # it is forbidden to define an operator that has both
            # recursive and non-recursive versions.
            if not coll_opers[0].get_recursive(schema):
                # The operator is non-recursive, so regular processing
                # is needed.
                matched = polyres.find_callable(
                    coll_opers, args=args, kwargs={}, ctx=ctx)

            else:
                # The recursive operators are usually defined as
                # being polymorphic on all parameters, and so this has
                # a side-effect of forcing both operands to be of
                # the same type (via casting) before the operator is
                # applied.  This might seem suboptimal, since there might
                # be a more specific operator for the types of the
                # elements, but the current version of Postgres
                # actually requires tuples and arrays to be of the
                # same type in comparison, so this behavior is actually
                # what we want.
                matched = polyres.find_callable(
                    coll_opers,
                    args=args,
                    kwargs={},
                    ctx=ctx,
                )

                # Now that we have an operator, we need to validate that it
                # can be applied to the tuple or array elements.
                submatched = validate_recursive_operator(
                    opers, args[0], args[1], ctx=ctx)

                if len(submatched) != 1:
                    # This is an error. We want the error message to
                    # reflect whether no matches were found or too
                    # many, so we preserve the submatches found for
                    # this purpose.
                    matched = submatched

    # No special handling match was necessary, find a normal match.
    if matched is None:
        matched = polyres.find_callable(opers, args=args, kwargs={}, ctx=ctx)

    in_polymorphic_func = (
        ctx.env.options.func_params is not None and
        ctx.env.options.func_params.has_polymorphic(env.schema)
    )

    in_abstract_constraint = (
        in_polymorphic_func and
        ctx.env.options.schema_object_context is s_constr.Constraint
    )

    if not in_polymorphic_func:
        matched = [call for call in matched
                   if not call.func.get_abstract(env.schema)]

    if len(matched) == 1:
        matched_call = matched[0]
    else:
        args_ty = [schemactx.get_material_type(a[0], ctx=ctx) for a in args]
        args_dn = [repr(a.get_displayname(env.schema)) for a in args_ty]

        if len(args_dn) == 2:
            types = f'{args_dn[0]} and {args_dn[1]}'
        else:
            types = ', '.join(args_dn)

        if not matched:
            hint = ('Consider using an explicit type cast or a conversion '
                    'function.')

            if op_name == 'std::IF':
                hint = (f"The IF and ELSE result clauses must be of "
                        f"compatible types, while the condition clause must "
                        f"be 'std::bool'. {hint}")
            elif op_name == '+':
                str_t = env.schema.get('std::str', type=s_scalars.ScalarType)
                bytes_t = env.schema.get('std::bytes',
                                         type=s_scalars.ScalarType)
                if (
                    all(t.issubclass(env.schema, str_t) for t in args_ty) or
                    all(t.issubclass(env.schema, bytes_t) for t in args_ty) or
                    all(t.is_array() for t in args_ty)
                ):
                    hint = 'Consider using the "++" operator for concatenation'

            if isinstance(qlexpr, qlast.BinOp) and qlexpr.set_constructor:
                msg = (
                    f'set constructor has arguments of incompatible types '
                    f'{types}'
                )
            else:
                msg = (
                    f'operator {str(op_name)!r} cannot be applied to '
                    f'operands of type {types}'
                )
            raise errors.InvalidTypeError(
                msg,
                hint=hint,
                span=qlexpr.span)
        elif len(matched) > 1:
            if in_abstract_constraint:
                matched_call = matched[0]
            else:
                detail = ', '.join(
                    f'`{m.func.get_verbosename(ctx.env.schema)}`'
                    for m in matched
                )
                raise errors.QueryError(
                    f'operator {str(op_name)!r} is ambiguous for '
                    f'operands of type {types}',
                    hint=f'Possible variants: {detail}.',
                    span=qlexpr.span)

    oper = matched_call.func
    assert isinstance(oper, s_oper.Operator)
    env.add_schema_ref(oper, expr=qlexpr)
    oper_name = oper.get_shortname(env.schema)
    str_oper_name = str(oper_name)

    is_singleton_set_of = oper.get_is_singleton_set_of(env.schema)

    matched_params = oper.get_params(env.schema)
    rtype = matched_call.return_type
    matched_rtype = oper.get_return_type(env.schema)

    is_polymorphic = (
        any(p.get_type(env.schema).is_polymorphic(env.schema)
            for p in matched_params.objects(env.schema)) and
        matched_rtype.is_polymorphic(env.schema)
    )

    final_args, _ = finalize_args(
        matched_call,
        actual_typemods=actual_typemods,
        guessed_typemods=typemods,
        is_polymorphic=is_polymorphic,
        ctx=ctx,
    )

    if str_oper_name in {
        'std::UNION', 'std::IF', 'std::??'
    } and rtype.is_object_type():
        # Special case for the UNION, IF and ?? operators: instead of common
        # parent type, we return a union type.
        if str_oper_name == 'std::IF':
            larg, _, rarg = (a.expr for a in final_args.values())
        else:
            larg, rarg = (a.expr for a in final_args.values())

        left_type = setgen.get_set_type(larg, ctx=ctx)
        right_type = setgen.get_set_type(rarg, ctx=ctx)
        rtype = schemactx.get_union_type(
            [left_type, right_type],
            preserve_derived=True,
            ctx=ctx,
            span=qlexpr.span
        )

    from_op = oper.get_from_operator(env.schema)
    sql_operator = None
    if (
        from_op is not None
        and oper.get_code(env.schema) is None
        and oper.get_from_function(env.schema) is None
    ):
        sql_operator = tuple(from_op)

    origin_name: Optional[sn.QualName]
    origin_module_id: Optional[uuid.UUID]
    if derivative_op is not None:
        origin_name = oper_name
        origin_module_id = env.schema.get_global(
            s_mod.Module, origin_name.module).id
        oper_name = derivative_op.get_shortname(env.schema)
        is_singleton_set_of = derivative_op.get_is_singleton_set_of(env.schema)
    else:
        origin_name = None
        origin_module_id = None

    from_func = oper.get_from_function(env.schema)
    if from_func is None:
        sql_func = None
    else:
        sql_func = tuple(from_func)

    node = irast.OperatorCall(
        args=final_args,
        func_shortname=oper_name,
        func_polymorphic=is_polymorphic,
        origin_name=origin_name,
        origin_module_id=origin_module_id,
        sql_function=sql_func,
        func_sql_expr=oper.get_from_expr(env.schema),
        sql_operator=sql_operator,
        force_return_cast=oper.get_force_return_cast(env.schema),
        volatility=oper.get_volatility(env.schema),
        operator_kind=oper.get_operator_kind(env.schema),
        typeref=typegen.type_to_typeref(rtype, env=env),
        typemod=oper.get_return_typemod(env.schema),
        tuple_path_ids=[],
        impl_is_strict=oper.get_impl_is_strict(env.schema),
        prefer_subquery_args=oper.get_prefer_subquery_args(env.schema),
        is_singleton_set_of=is_singleton_set_of,
        span=qlexpr.span,
        return_polymorphism=matched_call.return_polymorphism,
    )

    _check_free_shape_op(node, ctx=ctx)

    return stmt.maybe_add_view(
        setgen.ensure_set(node, typehint=rtype, ctx=ctx),
        ctx=ctx)


# These ops are all footguns when used with free shapes,
# so we ban them
INVALID_FREE_SHAPE_OPS: Final = {
    sn.QualName('std', x) for x in [
        'DISTINCT', '=', '!=', '?=', '?!=', 'IN', 'NOT IN',
        'assert_distinct',
    ]
}


def _check_free_shape_op(ir: irast.Call, *, ctx: context.ContextLevel) -> None:
    if ir.func_shortname not in INVALID_FREE_SHAPE_OPS:
        return

    virt_obj = ctx.env.schema.get(
        'std::FreeObject', type=s_objtypes.ObjectType)
    for arg in ir.args.values():
        typ = setgen.get_set_type(arg.expr, ctx=ctx)
        if typ.issubclass(ctx.env.schema, virt_obj):
            raise errors.QueryError(
                f'cannot use {ir.func_shortname.name} on free shape',
                span=ir.span)


def validate_recursive_operator(
    opers: Iterable[s_func.CallableObject],
    larg: tuple[s_types.Type, irast.Set],
    rarg: tuple[s_types.Type, irast.Set],
    *,
    ctx: context.ContextLevel,
) -> list[polyres.BoundCall]:

    matched: list[polyres.BoundCall] = []

    # if larg and rarg are tuples or arrays, recurse into their subtypes
    if (
        (
            larg[0].is_tuple(ctx.env.schema)
            and rarg[0].is_tuple(ctx.env.schema)
        ) or (
            larg[0].is_array()
            and rarg[0].is_array()
        )
    ):
        assert isinstance(larg[0], s_types.Collection)
        assert isinstance(rarg[0], s_types.Collection)
        for lsub, rsub in zip(larg[0].get_subtypes(ctx.env.schema),
                              rarg[0].get_subtypes(ctx.env.schema)):
            matched = validate_recursive_operator(
                opers, (lsub, larg[1]), (rsub, rarg[1]), ctx=ctx)
            if len(matched) != 1:
                # this is an error already
                break

    else:
        # we just have a pair of non-containers to compare
        matched = polyres.find_callable(
            opers, args=[larg, rarg], kwargs={}, ctx=ctx)

    return matched


def compile_func_call_args(
    expr: qlast.FunctionCall,
    funcname: sn.Name,
    typemods: dict[int | str, ft.TypeModifier],
    *,
    prefer_subquery_args: bool = False,
    ctx: context.ContextLevel
) -> tuple[
    list[tuple[s_types.Type, irast.Set]],
    dict[str, tuple[s_types.Type, irast.Set]],
]:
    args = []
    kwargs = {}

    for ai, arg in enumerate(expr.args):
        arg_ir = polyres.compile_arg(
            arg, typemods[ai], prefer_subquery_args=prefer_subquery_args,
            ctx=ctx)
        arg_type = setgen.get_set_type(arg_ir, ctx=ctx)
        if arg_type is None:
            raise errors.QueryError(
                f'could not resolve the type of positional argument '
                f'#{ai} of function {funcname}',
                span=arg.span)

        args.append((arg_type, arg_ir))

    for aname, arg in expr.kwargs.items():
        arg_ir = polyres.compile_arg(
            arg, typemods[aname], prefer_subquery_args=prefer_subquery_args,
            ctx=ctx)

        arg_type = setgen.get_set_type(arg_ir, ctx=ctx)
        if arg_type is None:
            raise errors.QueryError(
                f'could not resolve the type of named argument '
                f'${aname} of function {funcname}',
                span=arg.span)

        kwargs[aname] = (arg_type, arg_ir)

    return args, kwargs


def get_globals(
    expr: qlast.FunctionCall,
    bound_call: polyres.BoundCall,
    candidates: Sequence[s_func.Function],
    *, ctx: context.ContextLevel,
) -> list[irast.Set]:
    assert isinstance(bound_call.func, s_func.Function)

    func_language = bound_call.func.get_language(ctx.env.schema)
    if func_language is not qlast.Language.EdgeQL:
        return []

    schema = ctx.env.schema

    globs: set[s_globals.Global | s_permissions.Permission] = set()
    if bound_call.func.get_params(schema).has_objects(schema):
        # We look at all the candidates since it might be used in a
        # subtype's overload.
        # TODO: be careful and only do this in the needed cases
        for func in candidates:
            globs.update(set(func.get_used_globals(schema).objects(schema)))
            globs.update(set(func.get_used_permissions(schema).objects(schema)))
    else:
        globs.update(bound_call.func.get_used_globals(schema).objects(schema))
        globs.update(
            bound_call.func.get_used_permissions(schema).objects(schema)
        )

    if (
        (
            ctx.env.options.func_name is None
            or ctx.env.options.func_params is None
        )
        and not ctx.env.options.json_parameters
    ):
        glob_set = setgen.get_globals_as_json(
            tuple(globs), ctx=ctx, span=expr.span)
    else:
        # Make sure that we properly track the globals we use in functions
        for glob in globs:
            setgen.get_global_param(glob, ctx=ctx)

        glob_set = setgen.get_func_global_json_arg(ctx=ctx)

    return [glob_set]


def finalize_args(
    bound_call: polyres.BoundCall,
    *,
    actual_typemods: Sequence[ft.TypeModifier] = (),
    guessed_typemods: dict[int | str, ft.TypeModifier],
    is_polymorphic: bool = False,
    ctx: context.ContextLevel,
) -> tuple[dict[int | str, irast.CallArg], dict[str, int | str]]:

    args: dict[int | str, irast.CallArg] = {}
    param_name_to_arg: dict[str, int | str] = {}
    position_index: int = 0

    for i, barg in enumerate(bound_call.args):
        arg_val = barg.val
        arg_type_path_id: Optional[irast.PathId] = None
        if isinstance(barg, polyres.DefaultBitmask):
            # defaults bitmask
            param_name_to_arg['__defaults_mask__'] = -1
            args[-1] = irast.CallArg(
                expr=arg_val,
                param_typemod=ft.TypeModifier.SingletonType,
            )
            continue
        assert isinstance(barg, polyres.ValueArg)

        if actual_typemods:
            param_mod = actual_typemods[i]
        else:
            param_mod = barg.param_typemod

        if param_mod is not ft.TypeModifier.SetOfType:
            param_shortname = barg.name

            if param_shortname in bound_call.null_args:
                pathctx.register_set_in_scope(arg_val, optional=True, ctx=ctx)

            # If we guessed the argument was optional but it wasn't,
            # try to go back and make it *not* optional.
            elif (
                param_mod is ft.TypeModifier.SingletonType
                and isinstance(barg, polyres.PassedArg)
                and param_mod is not guessed_typemods[barg.arg_id]
            ):
                for child in ctx.path_scope.children:
                    if (
                        child.path_id == arg_val.path_id
                        or (
                            arg_val.path_scope_id is not None
                            and child.unique_id == arg_val.path_scope_id
                        )
                    ):
                        child.optional = False

            # Object type arguments to functions may be overloaded, so
            # we populate a path id field so that we can also pass the
            # type as an argument on the pgsql side. If the param type
            # is "anytype", though, then it can't be overloaded on
            # that argument.
            arg_type = setgen.get_set_type(arg_val, ctx=ctx)
            if (
                isinstance(arg_type, s_objtypes.ObjectType)
                and not barg.orig_param_type.is_any(ctx.env.schema)
            ):
                arg_type_path_id = pathctx.extend_path_id(
                    arg_val.path_id,
                    ptrcls=setgen.resolve_ptr(
                        arg_type, '__type__', track_ref=None, ctx=ctx
                    ),
                    ctx=ctx,
                )
        else:
            is_array_agg = (
                isinstance(bound_call.func, s_func.Function)
                and (
                    bound_call.func.get_shortname(ctx.env.schema)
                    == sn.QualName('std', 'array_agg')
                )
            )

            if (
                # Ideally, we should implicitly slice all array values,
                # but in practice, the vast majority of large arrays
                # will come from array_agg, and so we only care about
                # that.
                is_array_agg
                and ctx.expr_exposed
                and ctx.implicit_limit
                and isinstance(arg_val.expr, irast.SelectStmt)
                and arg_val.expr.limit is None
            ):
                arg_val.expr.limit = dispatch.compile(
                    qlast.Constant.integer(ctx.implicit_limit),
                    ctx=ctx,
                )

        paramtype = barg.param_type
        param_kind = barg.param_kind
        if param_kind is ft.ParameterKind.VariadicParam:
            # For variadic params, paramtype would be array,
            # and we need T to cast the arguments.
            assert isinstance(paramtype, s_types.Array)
            paramtype = list(paramtype.get_subtypes(ctx.env.schema))[0]

        # Check if we need to cast the argument value before passing
        # it to the callable.
        compatible = s_types.is_type_compatible(
            paramtype, barg.valtype, schema=ctx.env.schema,
        )

        if not compatible:
            # The callable form was chosen via an implicit cast,
            # cast the arguments so that the backend has no
            # wiggle room to apply its own (potentially different)
            # casting.
            orig_arg_val = arg_val
            arg_val = casts.compile_cast(
                arg_val, paramtype, span=None, ctx=ctx)
            if ctx.path_scope.is_optional(orig_arg_val.path_id):
                pathctx.register_set_in_scope(arg_val, optional=True, ctx=ctx)

        arg = irast.CallArg(
            expr=arg_val,
            expr_type_path_id=arg_type_path_id,
            is_default=isinstance(barg, polyres.DefaultArg),
            param_typemod=param_mod,
            polymorphism=barg.polymorphism,
        )
        param_shortname = barg.name
        if param_kind is ft.ParameterKind.NamedOnlyParam:
            args[param_shortname] = arg
            param_name_to_arg[param_shortname] = param_shortname
        else:
            args[position_index] = arg
            if (
                # Variadic args will all have the same name, but different
                # indexes. We want to take the first index.
                param_shortname not in param_name_to_arg
            ):
                param_name_to_arg[param_shortname] = position_index
            position_index += 1

    return args, param_name_to_arg


@_special_case('ext::ai::search')
def compile_ext_ai_search(
    call: irast.FunctionCall, *, ctx: context.ContextLevel
) -> irast.Expr:
    indexes = _validate_object_search_call(
        call,
        context="ext::ai::search()",
        object_arg=call.args[0],
        index_name=sn.QualName("ext::ai", "index"),
        ctx=ctx,
    )

    schema = ctx.env.schema

    index_metadata = {}
    for typeref, index in indexes.items():
        dimensions = index.must_get_json_annotation(
            schema,
            sn.QualName("ext::ai", "embedding_dimensions"),
            int,
        )
        kwargs = index.get_concrete_kwargs(schema)
        df_expr = kwargs.get("distance_function")
        if df_expr is not None:
            df = df_expr.ensure_compiled(
                schema,
                as_fragment=True,
                context=None,
            ).as_python_value()
        else:
            df = "Cosine"

        match df:
            case "Cosine":
                distance_fname = "cosine_distance"
            case "InnerProduct":
                distance_fname = "neg_inner_product"
            case "L2":
                distance_fname = "euclidean_distance"
            case _:
                raise RuntimeError(f"unsupported distance_function: {df}")

        distance_func = schema.get_by_shortname(
            s_func.Function, sn.QualName("ext::pgvector", distance_fname)
        )[0]

        index_metadata[typeref] = {
            "id": s_indexes.get_ai_index_id(schema, index),
            "dimensions": dimensions,
            "distance_function": (
                distance_func.get_shortname(schema),
                distance_func.get_backend_name(schema),
            ),
        }
    call.extras = {"index_metadata": index_metadata}

    return call


@_special_case('ext::ai::to_context')
def compile_ext_ai_to_str(
    call: irast.FunctionCall, *, ctx: context.ContextLevel
) -> irast.Expr:
    indexes = _validate_object_search_call(
        call,
        context="ext::ai::to_context()",
        object_arg=call.args[0],
        index_name=sn.QualName("ext::ai", "index"),
        ctx=ctx,
    )

    index = next(iter(indexes.values()))
    index_expr = index.get_expr(ctx.env.schema)
    assert index_expr is not None

    with ctx.detached() as subctx:
        subctx.partial_path_prefix = call.args[0].expr
        subctx.anchors["__subject__"] = call.args[0].expr
        call.body = dispatch.compile(
            qlast.FunctionCall(
                func=('__std__', 'assert_exists'),
                args=[index_expr.parse()],
                kwargs={
                    'message': qlast.Constant.string(
                        'Object context is not set.'
                    ),
                }
            ),
            ctx=subctx,
        )

    return call


@_special_case('std::fts::search')
def compile_fts_search(
    call: irast.FunctionCall, *, ctx: context.ContextLevel
) -> irast.Expr:
    _validate_object_search_call(
        call,
        context="std::fts::search()",
        object_arg=call.args[0],
        index_name=sn.QualName("std::fts", "index"),
        ctx=ctx,
    )

    return call


def _validate_object_search_call(
    call: irast.FunctionCall,
    *,
    context: str,
    object_arg: irast.CallArg,
    index_name: sn.QualName,
    ctx: context.ContextLevel,
) -> dict[irast.TypeRef, s_indexes.Index]:
    # Validate that the object type has an index of kind `index_name`.
    object_typeref = object_arg.expr.typeref
    object_typeref = object_typeref.material_type or object_typeref
    stype_id = object_typeref.id

    schema = ctx.env.schema
    span = object_arg.span

    stype = schema.get_by_id(stype_id, type=s_types.Type)
    indexes = {}

    if union_variants := stype.get_union_of(schema):
        for variant in union_variants.objects(schema):
            schema, variant = variant.material_type(schema)
            idx = _validate_has_object_index(
                variant, schema, span, context, index_name)
            indexes[typegen.type_to_typeref(variant, ctx.env)] = idx
    else:
        idx = _validate_has_object_index(
            stype, schema, span, context, index_name)
        indexes[object_typeref] = idx

    return indexes


def _validate_has_object_index(
    stype: s_types.Type,
    schema: s_schema.Schema,
    span: Optional[parsing.Span],
    context: str,
    index_name: sn.QualName,
) -> s_indexes.Index:
    if isinstance(stype, s_indexes.IndexableSubject):
        (obj_index, _) = s_indexes.get_effective_object_index(
            schema, stype, index_name
        )
    else:
        obj_index = None

    if not obj_index:
        raise errors.InvalidReferenceError(
            f"{context} requires an {index_name} index on type "
            f"'{stype.get_displayname(schema)}'",
            span=span,
        )

    return obj_index


@_special_case('std::fts::with_options')
def compile_fts_with_options(
    call: irast.FunctionCall, *, ctx: context.ContextLevel
) -> irast.Expr:
    # language has already been typechecked to be an enum
    lang = call.args['language'].expr
    assert lang.typeref
    lang_ty_id = lang.typeref.id
    lang_ty = ctx.env.schema.get_by_id(lang_ty_id, type=s_scalars.ScalarType)
    assert lang_ty

    lang_domain = set()  # languages that the fts index needs to support
    try:
        lang_const = staeval.evaluate_to_python_val(lang, ctx.env.schema)
    except staeval.UnsupportedExpressionError:
        lang_const = None

    if lang_const is not None:
        # language is constant
        # -> determine its only value at compile time
        lang_domain.add(lang_const.lower())
    else:
        # language is not constant
        # -> use all possible values of the enum
        enum_values = lang_ty.get_enum_values(ctx.env.schema)
        assert enum_values
        for enum_value in enum_values:
            lang_domain.add(enum_value.lower())

    # weight_category
    weight_expr = call.args['weight_category'].expr
    try:
        weight: str = staeval.evaluate_to_python_val(
            weight_expr, ctx.env.schema)
    except staeval.UnsupportedExpressionError:
        raise errors.InvalidValueError(
            "std::fts::search weight_category must be a constant",
            span=weight_expr.span,
        ) from None

    return irast.FTSDocument(
        text=call.args[0].expr,
        language=lang,
        language_domain=lang_domain,
        weight=weight,
        typeref=typegen.type_to_typeref(
            ctx.env.schema.get('std::fts::document', type=s_scalars.ScalarType),
            env=ctx.env,
        )
    )


@_special_case('std::_warn_on_call')
def compile_warn_on_call(
    call: irast.FunctionCall, *, ctx: context.ContextLevel
) -> irast.Expr:
    ctx.log_warning(
        errors.QueryError('Test warning please ignore', span=call.span)
    )
    return call


================================================
FILE: edb/edgeql/compiler/group.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

from typing import Any, Optional, Sequence

from edb.common import ast as ast_visitor

from edb.edgeql import qltypes
from edb.ir import ast as irast

from . import context
from . import inference
from . import setgen


class FindAggregatingUses(ast_visitor.NodeVisitor):
    """
    Find aggregated uses of a target node that can be hoisted.
    """
    skip_hidden = True
    extra_skips = frozenset(['materialized_sets'])

    def __init__(
        self,
        target: irast.PathId,
        *,
        ctx: context.ContextLevel,
    ) -> None:
        super().__init__()
        self.target = target
        self.aggregate: Optional[irast.Set] = None
        self.sightings: set[Optional[irast.Set]] = set()
        self.ctx = ctx
        # Track path IDs that we've seen. Path IDs that we are interested
        # in but haven't seen get marked as False.
        self.seen: dict[irast.PathId, bool] = {}
        self.skippable: dict[
            Optional[irast.Set], frozenset[irast.PathId]] = {}
        self.scope_tree = ctx.path_scope
        # We don't bother trying to reuse the existing inference
        # context because we make singleton assumptions that it
        # wouldn't and because ignore_computed_cards could invalidate
        # it.
        self.infctx = inference.make_ctx(ctx.env)._replace(
            singletons=frozenset({target}),
            ignore_computed_cards=True,
            # Don't update the IR with the results!
            make_updates=False,
        )

    def visit_Stmt(self, stmt: irast.Stmt) -> Any:
        # Sometimes there is sharing, so we want the official scope
        # for a node to be based on its appearance in the result,
        # not in a subquery.
        # I think it might not actually matter, though.

        old = self.aggregate

        # Can't handle ORDER/LIMIT/OFFSET which operate on the whole set
        # TODO: but often we probably could with arguments to the
        # aggregates, as long as the argument to the aggregate is just
        # a reference
        if isinstance(stmt, irast.SelectStmt) and (
            stmt.orderby or stmt.limit or stmt.offset or stmt.materialized_sets
        ):
            self.aggregate = None

        self.visit(stmt.bindings)
        if stmt.iterator_stmt:
            self.visit(stmt.iterator_stmt)
        if isinstance(stmt, (irast.MutatingStmt, irast.GroupStmt)):
            self.visit(stmt.subject)
        if isinstance(stmt, irast.GroupStmt):
            for v in stmt.using.values():
                self.visit(v)
        self.visit(stmt.result)

        res = self.generic_visit(stmt)

        self.aggregate = old

        return res

    def repeated_node_visit(self, node: irast.Base) -> None:
        if isinstance(node, irast.Set):
            self.seen[node.path_id] = True

    def visit_Set(self, node: irast.Set, skip_rptr: bool = False) -> None:
        self.seen[node.path_id] = True

        if node.path_id == self.target:
            self.sightings.add(self.aggregate)
            return

        old_scope = self.scope_tree
        if node.path_scope_id is not None:
            self.scope_tree = self.ctx.env.scope_tree_nodes[node.path_scope_id]

        # We also can't handle references inside of a semi-join,
        # because the bodies are executed one at a time, and so the
        # semi-join deduplication doesn't work.
        is_semijoin = (
            isinstance(node.expr, irast.Pointer)
            and node.path_id.is_objtype_path()
            and not self.scope_tree.is_visible(node.expr.source.path_id)
        )

        old = self.aggregate
        if is_semijoin:
            self.aggregate = None

        self.visit(node.shape)

        if isinstance(node.expr, irast.Pointer):
            sub_expr = node.expr.expr
            if not sub_expr:
                self.visit(node.expr.source)
            else:
                if node.expr.source.path_id not in self.seen:
                    self.seen[node.expr.source.path_id] = False
        else:
            sub_expr = node.expr

        if isinstance(sub_expr, irast.Call):
            self.process_call(sub_expr, node)
        else:
            self.visit(sub_expr)

        self.aggregate = old
        self.scope_tree = old_scope

    def process_call(self, node: irast.Call, ir_set: irast.Set) -> None:
        # It needs to be backed by an actual SQL function and must
        # not return SET OF
        returns_set = node.typemod == qltypes.TypeModifier.SetOfType
        calls_sql_func = (
            isinstance(node, irast.FunctionCall)
            and node.func_sql_function
        )
        for arg in node.args.values():
            typemod = arg.param_typemod
            old = self.aggregate
            # If this *returns* a set, it is going to mess things up since
            # the operation can't actually run on multiple things...

            old_seen = None

            # TODO: we would like to do better in some cases with
            # DISTINCT and the like where there are built-in features
            # to do it in a GROUP
            if returns_set:
                self.aggregate = None
            elif (
                calls_sql_func
                and typemod == qltypes.TypeModifier.SetOfType
                # Don't hoist aggregates whose outputs contain objects
                # (I think this can only be array_agg).
                #
                # We have to eta-expand to put a shape on them anyway,
                # so there's no real point, and we mishandled that
                # case in a few places.  Eventually we'll want to properly
                # be able to serialize in the first place, though.
                and not setgen.get_set_type(
                    ir_set, ctx=self.ctx).contains_object(self.ctx.env.schema)
            ):
                old_seen = self.seen
                self.seen = {}
                self.aggregate = ir_set
            self.visit(arg)
            self.aggregate = old

            force_fail = False
            if old_seen is not None:
                self.skippable[ir_set] = frozenset({
                    k for k, v in self.seen.items() if not v
                    and self.scope_tree.is_visible(k)
                })
                for k, was_seen in self.seen.items():
                    # If we referred to some visible set and also
                    # spotted the target, we can't actually compile
                    # the target separately, so ditch it.
                    if (
                        was_seen
                        and self.scope_tree.is_visible(k)
                        and ir_set in self.sightings
                    ):
                        force_fail = True
                        self.sightings.discard(ir_set)
                        self.sightings.add(None)
                    old_seen[k] = self.seen.get(k, False) | was_seen

                # If, assuming the target is single, the aggregate is
                # still multi, then we can't extract it, since that
                # would lead to actually returning multiple elements in a
                # SQL subquery.
                if (
                    ir_set in self.sightings
                    and inference.infer_cardinality(
                        arg.expr, scope_tree=self.scope_tree,
                        ctx=self.infctx).is_multi()
                ):
                    force_fail = True

                self.seen = old_seen

            if force_fail:
                self.sightings.discard(ir_set)
                self.sightings.add(None)


def infer_group_aggregates(
    irs: Sequence[irast.Base],
    *,
    ctx: context.ContextLevel,
) -> None:
    groups = ast_visitor.find_children(irs, irast.GroupStmt)
    for stmt in groups:
        visitor = FindAggregatingUses(
            stmt.group_binding.path_id,
            ctx=ctx,
        )
        visitor.visit(stmt.result)
        stmt.group_aggregate_sets = {
            k: visitor.skippable.get(k, frozenset())
            for k in visitor.sightings
        }


================================================
FILE: edb/edgeql/compiler/inference/__init__.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2015-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

__all__ = (
    'infer_cardinality',
    'infer_volatility',
    'infer_multiplicity',
    'InfCtx',
    'make_ctx',
)

from .cardinality import infer_cardinality  # NOQA
from .context import InfCtx, make_ctx  # NOQA
from .multiplicity import infer_multiplicity  # NOQA
from .volatility import infer_volatility  # NOQA


================================================
FILE: edb/edgeql/compiler/inference/cardinality.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""EdgeQL cardinality inference.

A top-down cardinality inferer that traverses the full AST populating
cardinality fields and performing cardinality checks.
"""


from __future__ import annotations
from typing import (
    Optional,
    Iterable,
    Sequence,
    NamedTuple,
)

import enum
import functools
import uuid

from edb import errors
from edb.common import parsing

from edb.edgeql import qltypes

from edb.schema import name as sn
from edb.schema import types as s_types
from edb.schema import objtypes as s_objtypes
from edb.schema import pointers as s_pointers
from edb.schema import constraints as s_constraints

from edb.ir import ast as irast
from edb.ir import utils as irutils
from edb.ir import typeutils
from edb.edgeql import ast as qlast

from . import context as inference_context
from . import utils as inf_utils
from . import volatility
from . import multiplicity

from .. import context


AT_MOST_ONE = qltypes.Cardinality.AT_MOST_ONE
ONE = qltypes.Cardinality.ONE
MANY = qltypes.Cardinality.MANY
AT_LEAST_ONE = qltypes.Cardinality.AT_LEAST_ONE


class CardinalityBound(int, enum.Enum):
    '''This enum is used to perform some of the cardinality operations.'''
    ZERO = 0
    ONE = 1
    MANY = 2

    def __add__(self, other: int) -> CardinalityBound:
        return CardinalityBound(min(int(self) + other, CB_MANY))

    def __mul__(self, other: int) -> CardinalityBound:
        return CardinalityBound(min(int(self) * other, CB_MANY))

    def as_required(self) -> bool:
        return self >= CB_ONE

    def as_schema_cardinality(self) -> qltypes.SchemaCardinality:
        if self >= CB_MANY:
            return qltypes.SchemaCardinality.Many
        else:
            return qltypes.SchemaCardinality.One

    @classmethod
    def from_required(cls, required: bool) -> CardinalityBound:
        return CB_ONE if required else CB_ZERO

    @classmethod
    def from_schema_value(
        cls, card: qltypes.SchemaCardinality
    ) -> CardinalityBound:
        if card >= qltypes.SchemaCardinality.Many:
            return CB_MANY
        else:
            return CB_ONE


CB_ZERO = CardinalityBound.ZERO
CB_ONE = CardinalityBound.ONE
CB_MANY = CardinalityBound.MANY


class CardinalityBounds(NamedTuple):
    lower: CardinalityBound
    upper: CardinalityBound


def _card_to_bounds(card: qltypes.Cardinality) -> CardinalityBounds:
    lower, upper = card.to_schema_value()
    return CardinalityBounds(
        CardinalityBound.from_required(lower),
        CardinalityBound.from_schema_value(upper),
    )


def _bounds_to_card(
    lower: CardinalityBound,
    upper: CardinalityBound,
) -> qltypes.Cardinality:
    return qltypes.Cardinality.from_schema_value(
        lower.as_required(),
        upper.as_schema_cardinality(),
    )


def _card_unzip(
    args: Iterable[qltypes.Cardinality],
) -> tuple[tuple[CardinalityBound, ...], tuple[CardinalityBound, ...]]:
    card = list(zip(*(_card_to_bounds(a) for a in args)))
    lower, upper = card if card else ((), ())
    return lower, upper


def product(arg: Iterable[CardinalityBound]) -> CardinalityBound:
    res = CB_ONE
    for x in arg:
        res *= x
    return res


def cartesian_cardinality(
    args: Iterable[qltypes.Cardinality],
) -> qltypes.Cardinality:
    '''Cardinality of Cartesian product of multiple args.'''
    lower, upper = _card_unzip(args)
    return _bounds_to_card(product(lower), product(upper))


def max_cardinality(
    args: Iterable[qltypes.Cardinality],
) -> qltypes.Cardinality:
    '''Maximum lower and upper bound of specified cardinalities.'''

    lower, upper = _card_unzip(args)
    assert lower, "cannot take max cardinality of no elements"
    return _bounds_to_card(max(lower), max(upper))


def min_cardinality(
    args: Iterable[qltypes.Cardinality],
) -> qltypes.Cardinality:
    '''Minimum lower and upper bound of specified cardinalities.'''

    lower, upper = _card_unzip(args)
    assert lower, "cannot take min cardinality of no elements"
    return _bounds_to_card(min(lower), min(upper))


def _union_cardinality(
    args: Iterable[qltypes.Cardinality],
) -> qltypes.Cardinality:
    '''Cardinality of UNION of multiple args.'''
    lower, upper = _card_unzip(args)
    return _bounds_to_card(
        sum(lower, start=CB_ZERO), sum(upper, start=CB_ZERO))


VOLATILE = qltypes.Volatility.Volatile
MODIFYING = qltypes.Volatility.Modifying


def _check_op_volatility(
    args: Sequence[irast.Base],
    cards: Sequence[qltypes.Cardinality],
    ctx: inference_context.InfCtx,
) -> None:
    vols = [volatility.infer_volatility(a, env=ctx.env) for a in args]

    # Check the rules on volatility correlation: a volatile operation
    # can't be in a cross product with any potentially multi set. We
    # check this by assuming that a volatile operation is AT_MOST_ONE
    # and making sure that the resulting Cartesian cardinality isn't
    # multi.
    for i, vol in enumerate(vols):
        if vol.is_volatile():
            cards2 = list(cards)
            cards2[i] = AT_MOST_ONE
            if cartesian_cardinality(cards2).is_multi():
                raise errors.QueryError(
                    "can not take cross product of volatile operation",
                    span=args[i].span
                )


def _common_cardinality(
    args: Sequence[irast.Base],
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    cards = [
        infer_cardinality(
            a, scope_tree=scope_tree, ctx=ctx
        ) for a in args
    ]
    _check_op_volatility(args, cards, ctx=ctx)

    return cartesian_cardinality(cards)


@functools.singledispatch
def _infer_cardinality(
    ir: irast.Base,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    raise ValueError(f'infer_cardinality: cannot handle {ir!r}')


@_infer_cardinality.register
def __infer_statement(
    ir: irast.Statement,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    return infer_cardinality(
        ir.expr, scope_tree=scope_tree, ctx=ctx)


@_infer_cardinality.register
def __infer_config_insert(
    ir: irast.ConfigInsert,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    return infer_cardinality(
        ir.expr, scope_tree=scope_tree, ctx=ctx)


@_infer_cardinality.register
def __infer_config_set(
    ir: irast.ConfigSet,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    card = infer_cardinality(
        ir.expr, scope_tree=scope_tree, ctx=ctx)
    if ir.required and card.can_be_zero():
        raise errors.QueryError(
            "possibly an empty set returned for "
            "a global declared as 'required'",
            span=ir.span,
        )
    if ir.cardinality.is_single() and not card.is_single():
        raise errors.QueryError(
            "possibly more than one element returned for "
            "a global declared as 'single'",
            span=ir.span,
        )

    return card


@_infer_cardinality.register
def __infer_config_reset(
    ir: irast.ConfigReset,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    if ir.selector:
        return infer_cardinality(
            ir.selector, scope_tree=scope_tree, ctx=ctx)
    else:
        return ONE


@_infer_cardinality.register
def __infer_empty_set(
    ir: irast.EmptySet,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    return AT_MOST_ONE


@_infer_cardinality.register
def __infer_typeref(
    ir: irast.TypeRef,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    return AT_MOST_ONE


@_infer_cardinality.register
def __infer_type_introspection(
    ir: irast.TypeIntrospection,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    return ONE


@_infer_cardinality.register
def __infer_type_root(
    ir: irast.TypeRoot,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    if typeutils.is_exactly_free_object(ir.typeref):
        return ONE
    else:
        return MANY


@_infer_cardinality.register
def __infer_cleared(
    ir: irast.RefExpr,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    return MANY


def _infer_pointer_cardinality(
    *,
    ptrcls: s_pointers.Pointer,
    ptrref: Optional[irast.BasePointerRef],
    irexpr: irast.Base,
    specified_required: Optional[bool] = None,
    specified_card: Optional[qltypes.SchemaCardinality] = None,
    is_mut_assignment: bool = False,
    shape_op: qlast.ShapeOp = qlast.ShapeOp.ASSIGN,
    source_ctx: Optional[parsing.Span] = None,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> None:

    env = ctx.env

    if specified_required is None:
        spec_lower_bound = None
    else:
        spec_lower_bound = CardinalityBound.from_required(specified_required)

    if specified_card is None:
        spec_upper_bound = None
    else:
        spec_upper_bound = CardinalityBound.from_schema_value(specified_card)

    expr_card = infer_cardinality(
        irexpr, scope_tree=scope_tree, ctx=ctx)

    ptrcls_schema_card = ptrcls.get_cardinality(env.schema)

    # Infer cardinality and convert it back to schema values of "ONE/MANY".
    if shape_op is qlast.ShapeOp.APPEND:
        # += in shape always means MANY
        inferred_card = qltypes.Cardinality.MANY
    elif shape_op is qlast.ShapeOp.SUBTRACT:
        # -= does not increase cardinality, but it may result in an empty set,
        # hence AT_MOST_ONE.
        inferred_card = qltypes.Cardinality.AT_MOST_ONE
    else:
        # Pull cardinality from the ptrcls, if it exists.
        # (This generally will have been populated by the source_map
        # handling in infer_toplevel_cardinality().)
        if ptrcls_schema_card.is_known():
            inferred_card = qltypes.Cardinality.from_schema_value(
                not expr_card.can_be_zero(), ptrcls_schema_card
            )
        else:
            inferred_card = expr_card

    if spec_upper_bound is None and spec_lower_bound is None:
        # Common case of no explicit specifier and no overloading.
        ptr_card = inferred_card
    else:
        # Verify that the explicitly specified (or inherited) cardinality is
        # within the cardinality bounds inferred from the expression. For
        # mutations, however, we punt the lower cardinality bound check to the
        # runtime DML constraint, as that produces a more meaningful error.
        inf_lower_bound, inf_upper_bound = _card_to_bounds(inferred_card)

        if spec_upper_bound is None:
            upper_bound = inf_upper_bound
        else:
            if inf_upper_bound > spec_upper_bound:
                desc = ptrcls.get_verbosename(env.schema)
                if not is_mut_assignment:
                    desc = f"computed {desc}"
                raise errors.QueryError(
                    f"possibly more than one element returned by an "
                    f"expression for a {desc} declared as 'single'",
                    span=source_ctx,
                )
            upper_bound = spec_upper_bound

        if spec_lower_bound is None:
            lower_bound = inf_lower_bound
        else:
            if inf_lower_bound < spec_lower_bound:
                if is_mut_assignment:
                    lower_bound = inf_lower_bound
                else:
                    desc = f"computed {ptrcls.get_verbosename(env.schema)}"
                    raise errors.QueryError(
                        f"possibly an empty set returned by an "
                        f"expression for a {desc} declared as 'required'",
                        span=source_ctx,
                    )
            else:
                lower_bound = spec_lower_bound

        ptr_card = _bounds_to_card(lower_bound, upper_bound)

    if (
        not ptrcls_schema_card.is_known()
        or ptrcls in ctx.env.pointer_specified_info
    ):
        if ptrcls_schema_card.is_known():
            # If we are overloading an existing pointer, take the _maximum_
            # of the cardinalities.  In practice this only means that we might
            # raise the lower bound, since the other redefinitions of bounds
            # are prohibited above and in viewgen.
            ptrcls_card = qltypes.Cardinality.from_schema_value(
                ptrcls.get_required(env.schema),
                ptrcls_schema_card,
            )
            if is_mut_assignment:
                ptr_card = cartesian_cardinality((ptrcls_card, ptr_card))
            else:
                ptr_card = max_cardinality((ptrcls_card, ptr_card))
        required, card = ptr_card.to_schema_value()
        env.schema = ptrcls.set_field_value(env.schema, 'cardinality', card)
        env.schema = ptrcls.set_field_value(env.schema, 'required', required)
        if ctx.make_updates:
            _update_cardinality_in_derived(ptrcls, env=ctx.env)

    if ptrref and ctx.make_updates:
        out_card, in_card = typeutils.cardinality_from_ptrcls(
            env.schema, ptrcls)
        assert in_card is not None
        assert out_card is not None
        ptrref.in_cardinality = in_card
        ptrref.out_cardinality = out_card


def _update_cardinality_in_derived(
    ptrcls: s_pointers.Pointer, *, env: context.Environment
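) -> None:
    """Propagate cardinality to derived pointers.

    Copies the `cardinality` and `required` fields of *ptrcls* to each
    pointer derived from it, recursively.
    """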

    children = env.pointer_derivation_map.get(ptrcls)
    if children:
        ptrcls_cardinality = ptrcls.get_cardinality(env.schema)
        ptrcls_required = ptrcls.get_required(env.schema)
        assert ptrcls_cardinality.is_known()
        for child in children:
            env.schema = child.set_field_value(
                env.schema, 'cardinality', ptrcls_cardinality)
            env.schema = child.set_field_value(
                env.schema, 'required', ptrcls_required)
            _update_cardinality_in_derived(child, env=env)


def _infer_shape(
    ir: irast.Set,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> None:
    # Mark the source of the shape as being a singleton. We can't just
    # rely on the scope tree, where it might appear as optional
    # (giving us AT_MOST_ONE instead of ONE).
    ctx = ctx._replace(singletons=ctx.singletons | {ir.path_id})

    for shape_set, shape_op in ir.shape:
        new_scope = inf_utils.get_set_scope(shape_set, scope_tree, ctx=ctx)
        rptr = shape_set.expr
        if rptr.expr:
            ptrref = rptr.ptrref

            ctx.env.schema, ptrcls = typeutils.ptrcls_from_ptrref(
                ptrref, schema=ctx.env.schema)
            assert isinstance(ptrcls, s_pointers.Pointer)
            specified_card, specified_required, _ = (
                ctx.env.pointer_specified_info.get(ptrcls,
                                                   (None, False, None)))
            assert isinstance(rptr.expr, irast.Stmt)

            _infer_pointer_cardinality(
                ptrcls=ptrcls,
                ptrref=ptrref,
                source_ctx=shape_set.span,
                irexpr=rptr.expr,
                is_mut_assignment=rptr.is_mutation,
                specified_card=specified_card,
                specified_required=specified_required,
                shape_op=shape_op,
                scope_tree=new_scope,
                ctx=ctx,
            )

        _infer_shape(shape_set, scope_tree=scope_tree,
                     ctx=ctx)


def _infer_set(
    ir: irast.Set,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:

    # First compute (or look up) the "intrinsic" cardinality of the set
    if not (result := ctx.inferred_cardinality.get(ir)):
        result = _infer_set_inner(
            ir, scope_tree=scope_tree, ctx=ctx)

        # We need to cache the main result before doing the shape,
        # since sometimes the shape will refer to the enclosing set.
        ctx.inferred_cardinality[ir] = result

        new_scope = inf_utils.get_set_scope(ir, scope_tree, ctx=ctx)
        _infer_shape(
            ir, scope_tree=new_scope, ctx=ctx)

    # With that in hand, compute the cardinality of a *reference* to the
    # set from this location in the tree.
    if ir.path_id in ctx.singletons:
        return ONE
    elif (node := inf_utils.find_visible(ir, scope_tree)) is not None:
        if not node.optional:
            return ONE
        # If the set is visible, but optional, it must have upper bound ONE
        # but we still want to compute the lower bound.
        else:
            return _bounds_to_card(_card_to_bounds(result).lower, CB_ONE)
    else:
        return result


@_infer_cardinality.register
def _infer_pointer(
    ir: irast.Pointer,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    raise AssertionError('TODO: properly infer Pointer-as-Expr')


def _infer_set_inner(
    ir: irast.Set,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    new_scope = inf_utils.get_set_scope(ir, scope_tree, ctx=ctx)

    # TODO: Complete the migration to Pointer-as-Expr.
    sub_expr = irutils.sub_expr(ir)
    if sub_expr:
        expr_card = infer_cardinality(sub_expr, scope_tree=new_scope, ctx=ctx)

    if isinstance(ir.expr, irast.Pointer) and not ir.expr.is_phony:
        ptr = ir.expr

        assert ir is not ptr.source, "self-referential pointer"
        # FIXME: The thing blocking extracting Pointer inference from
        # here is that this source inference relies on using the old
        # scope_tree. I think this is probably fixable.
        source_card = infer_cardinality(
            ptr.source, scope_tree=scope_tree, ctx=ctx,
        )

        ctx.env.schema, ptrcls = typeutils.ptrcls_from_ptrref(
            ptr.ptrref, schema=ctx.env.schema)
        if ptr.expr:
            assert isinstance(ptrcls, s_pointers.Pointer)
            _infer_pointer_cardinality(
                ptrcls=ptrcls,
                ptrref=ptr.ptrref,
                irexpr=ptr.expr,
                scope_tree=scope_tree,
                ctx=ctx,
            )

        if ptr.ptrref.union_components:
            # We use cartesian cardinality instead of union cardinality
            # because the union of pointers in this context is disjoint
            # in the sense that for any specific source only a given union
            # component is used.
            rptrref_card = cartesian_cardinality(
                c.dir_cardinality(ptr.direction)
                for c in ptr.ptrref.union_components
            )
        elif ctx.ignore_computed_cards and ptr.expr:
            rptrref_card = expr_card
        elif isinstance(ptr.ptrref, irast.TypeIntersectionPointerRef):
            rptrref_card = AT_MOST_ONE
        else:
            rptrref_card = ptr.ptrref.dir_cardinality(ptr.direction)

        card = cartesian_cardinality((source_card, rptrref_card))

        # "Optional derefs" (.?>) always produce an optional result.
        if ptr.optional_deref:
            card = cartesian_cardinality((AT_MOST_ONE, card))

    elif sub_expr is not None:
        card = expr_card
    else:
        # The only things that should be here without an expression or
        # an rptr are certain visible_binding_refs (typically from
        # GROUP). We report them as MANY, but that might be refined
        # based on the scope tree in the enclosing context.
        assert ir.is_visible_binding_ref
        card = MANY

    # If this node is an optional argument bound at this location,
    # but it can't actually be zero, clear the optionality to avoid
    # subpar codegen.
    if (
        new_scope.parent_fence
        and (node := new_scope.parent_fence.find_child(
            ir.path_id, in_branches=True
        ))
        and node.optional
        and not card.can_be_zero()
    ):
        node.optional = False

    return card


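def _typemod_to_card(typemod: qltypes.TypeModifier) -> qltypes.Cardinality:
    """Map a type modifier to the cardinality it implies.

    SET OF maps to MANY, OPTIONAL to AT_MOST_ONE, and anything else to ONE.
    """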
    return (
        MANY if typemod is qltypes.TypeModifier.SetOfType else
        AT_MOST_ONE if typemod is qltypes.TypeModifier.OptionalType else
        ONE
    )


def _standard_call_cardinality(
    ir: irast.Call,
    cards: Sequence[qltypes.Cardinality],
    *,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    # For regular functions and operators, the general rule of
    # Cartesian cardinality of arguments applies, although we still
    # have to account for the declared return cardinality, as the
    # function might be OPTIONAL or SET OF in its return type.
    #
    # We compute the Cartesian cardinality of the function's
    # _non-SET OF_ arguments and its return, but with the lower bound
    # of any optional arguments set to CB_ONE.
    non_aggregate_args = []
    non_aggregate_arg_cards = []

    for arg, card in zip(ir.args.values(), cards):
        typemod = arg.param_typemod
        if typemod is qltypes.TypeModifier.SingletonType:
            non_aggregate_args.append(arg.expr)
            non_aggregate_arg_cards.append(card)
        elif typemod is qltypes.TypeModifier.OptionalType:
            non_aggregate_args.append(arg.expr)
            non_aggregate_arg_cards.append(
                _bounds_to_card(CB_ONE, _card_to_bounds(card).upper)
            )

    _check_op_volatility(
        non_aggregate_args, non_aggregate_arg_cards, ctx=ctx)

    return cartesian_cardinality(
        non_aggregate_arg_cards + [_typemod_to_card(ir.typemod)])


@_infer_cardinality.register
def __infer_func_call(
    ir: irast.FunctionCall,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:

    for glob_arg in (ir.global_args or ()):
        infer_cardinality(glob_arg, scope_tree=scope_tree, ctx=ctx)

    cards: list[qltypes.Cardinality] = []
    arg_typemods: list[qltypes.TypeModifier] = []
    for arg in ir.args.values():
        card = infer_cardinality(arg.expr, scope_tree=scope_tree, ctx=ctx)
        cards.append(card)
        arg_typemods.append(arg.param_typemod)
        if ctx.make_updates:
            arg.cardinality = card

    if ir.preserves_optionality or ir.preserves_upper_cardinality:
        ret_lower_bound, ret_upper_bound = _card_to_bounds(
            _typemod_to_card(ir.typemod))

        # This is a generic aggregate function which preserves the
        # optionality and/or upper cardinality of its generic
        # argument.  For simplicity we are deliberately not checking
        # the parameters here as that would have been done at the time
        # of declaration.
        arg_cards = []
        force_multi = False

        for arg, card in zip(ir.args.values(), cards):
            typemod = arg.param_typemod
            if typemod is not qltypes.TypeModifier.OptionalType:
                arg_cards.append(card)
            else:
                force_multi |= card.is_multi()

        arg_card = zip(*(_card_to_bounds(card) for card in arg_cards))
        arg_lower, arg_upper = arg_card
        lower = (
            min(arg_lower) if ir.preserves_optionality else
            CB_ONE if ir.func_shortname == sn.QualName('std', 'assert_exists')
            else ret_lower_bound
        )
        upper = (CB_MANY if force_multi
                 else max(arg_upper) if ir.preserves_upper_cardinality
                 else ret_upper_bound)
        call_card = _bounds_to_card(lower, upper)

    else:
        call_card = _standard_call_cardinality(ir, cards, ctx=ctx)

    if ir.body is not None:
        body_card = infer_cardinality(ir.body, scope_tree=scope_tree, ctx=ctx)
        # Check that inline body cardinality does not disagree with
        # declared function cardinality.
        if body_card.can_be_zero() and not call_card.can_be_zero():
            raise errors.QueryError(
                'inline function body expression returns a possibly empty '
                'result while the function is not declared as returning '
                'OPTIONAL',
                span=ir.span,
            )
        if body_card.is_multi() and not call_card.is_multi():
            raise errors.QueryError(
                'inline function body expression possibly returns more '
                'than one element, while the function is not declared as '
                'returning SET OF',
                span=ir.span,
            )

    if ir.volatility == MODIFYING:
        if any(card.is_multi() for card in cards):
            raise errors.QueryError(
                'possibly more than one element passed into modifying function',
                span=ir.span,
            )

        if any(
            (
                card.can_be_zero()
                and typemod == qltypes.TypeModifier.SingletonType
            )
            for card, typemod in zip(cards, arg_typemods)
        ):
            raise errors.QueryError(
                'possibly an empty set passed as non-optional argument '
                'into modifying function',
                span=ir.span,
            )

    return call_card


@_infer_cardinality.register
def __infer_oper_call(
    ir: irast.OperatorCall,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    cards = []
    for arg in ir.args.values():
        card = infer_cardinality(arg.expr, scope_tree=scope_tree, ctx=ctx)
        cards.append(card)
        if ctx.make_updates:
            arg.cardinality = card

    if str(ir.func_shortname) == 'std::UNION':
        # UNION needs to "add up" cardinalities.
        return _union_cardinality(cards)
    elif str(ir.func_shortname) == 'std::EXCEPT':
        # EXCEPT cardinality cannot be greater than the first argument, but
        # the lower bound can be ZERO.
        _lower, upper = _card_to_bounds(cards[0])
        return _bounds_to_card(CB_ZERO, upper)
    elif str(ir.func_shortname) == 'std::INTERSECT':
        # INTERSECT takes the minimum of cardinalities and makes the lower
        # bound ZERO.
        _lower, upper = _card_to_bounds(min_cardinality(cards))
        return _bounds_to_card(CB_ZERO, upper)
    elif str(ir.func_shortname) == 'std::??':
        # Coalescing takes the maximum of both lower and upper bounds.
        return max_cardinality(cards)
    elif str(ir.func_shortname) in ('std::DISTINCT', 'std::IF'):
        return cartesian_cardinality(cards)
    else:
        return _standard_call_cardinality(ir, cards, ctx=ctx)


@_infer_cardinality.register
def __infer_const(
    ir: irast.BaseConstant,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    return ONE


@_infer_cardinality.register
def __infer_param(
    ir: irast.QueryParameter,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    return ONE if ir.required else AT_MOST_ONE


@_infer_cardinality.register
def __infer_function_param(
    ir: irast.FunctionParameter,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    return ONE if ir.required else AT_MOST_ONE


@_infer_cardinality.register
def __infer_inlined_param(
    ir: irast.InlinedParameterExpr,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    return ONE if ir.required else AT_MOST_ONE


@_infer_cardinality.register
def __infer_const_set(
    ir: irast.ConstantSet,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    return ONE if len(ir.elements) == 1 else AT_LEAST_ONE


@_infer_cardinality.register
def __infer_typecheckop(
    ir: irast.TypeCheckOp,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    return infer_cardinality(
        ir.left, scope_tree=scope_tree, ctx=ctx,
    )


@_infer_cardinality.register
def __infer_typecast(
    ir: irast.TypeCast,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    card = infer_cardinality(
        ir.expr, scope_tree=scope_tree, ctx=ctx,
    )
    # JSON values can be 'null', which produces an empty set, which we
    # need to reflect in the cardinality.
    if (
        typeutils.is_json(ir.from_type)
        and ir.cardinality_mod != qlast.CardinalityModifier.Required
    ):
        card = _bounds_to_card(CB_ZERO, _card_to_bounds(card).upper)
    return card


def _is_ptr_or_self_ref(
    set: irast.Base,
    result_expr: irast.Set,
    env: context.Environment,
) -> bool:
    if not isinstance(set, irast.Set):
        return False

    srccls = env.set_types[result_expr]
    if not isinstance(srccls, s_objtypes.ObjectType):
        return False

    if set.path_id == result_expr.path_id:
        return True

    if isinstance(set.expr, irast.Pointer):
        rptr = set.expr
        return (
            isinstance(rptr.ptrref, irast.PointerRef)
            and not rptr.ptrref.is_computable
            and _is_ptr_or_self_ref(rptr.source, result_expr, env)
        )
    elif irutils.is_implicit_wrapper(set.expr):
        return _is_ptr_or_self_ref(set.expr.result, result_expr, env)
    else:
        return False


def extract_filters(
    result_set: irast.Set,
    filter_set: irast.Set,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
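) -> Sequence[tuple[Sequence[s_pointers.Pointer], irast.Set]]:
    """Collect equality filters on pointers of *result_set*.

    Looks through `AND` for `path = expr` comparisons in *filter_set*
    where `path` is the result set itself or a chain of non-computed
    pointers rooted at it, and `expr` has at most one element. Each
    match is returned as a `(pointers, expr)` pair.
    """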

    env = ctx.env
    schema = env.schema
    scope_tree = inf_utils.get_set_scope(filter_set, scope_tree, ctx=ctx)

    expr = filter_set.expr
    if isinstance(expr, irast.OperatorCall):
        if str(expr.func_shortname) == 'std::=':
            left, right = [a.expr for a in expr.args.values()]
            op_card = _common_cardinality(
                [left, right], scope_tree=scope_tree, ctx=ctx
            )
            result_stype = env.set_types[result_set]

            if op_card.is_multi():
                pass

            elif (
                (left_matches := _is_ptr_or_self_ref(left, result_set, env))
                or _is_ptr_or_self_ref(right, result_set, env)
            ):
                # If the match was on the right, flip the args
                if not left_matches:
                    left, right = right, left

                if infer_cardinality(
                    right, scope_tree=scope_tree, ctx=ctx,
                ).is_single():
                    pointers = []
                    left_stype = env.set_types[left]
                    if left_stype == result_stype:
                        assert isinstance(left_stype, s_objtypes.ObjectType)
                        ptr = left_stype.getptr(schema, sn.UnqualName('id'))
                        pointers.append(ptr)
                    else:
                        while left.path_id != result_set.path_id:
                            if irutils.is_implicit_wrapper(left.expr):
                                left = left.expr.result
                                continue

                            assert isinstance(left.expr, irast.Pointer)
                            ptr = env.schema.get(
                                left.expr.ptrref.name,
                                type=s_pointers.Pointer
                            )
                            pointers.append(ptr)
                            left = left.expr.source
                        pointers.reverse()

                    return [(pointers, right)]

        elif str(expr.func_shortname) == 'std::AND':
            left, right = (
                irutils.unwrap_set(a.expr)
                for a in expr.args.values()
            )

            left_filters = extract_filters(
                result_set, left, scope_tree, ctx
            )
            right_filters = extract_filters(
                result_set, right, scope_tree, ctx
            )

            return [*left_filters, *right_filters]

    return []


def _all_have_exclusive(
    ptrs: Sequence[s_pointers.Pointer],
    ctx: inference_context.InfCtx,
) -> bool:
    return all(
        bool(ptr.get_exclusive_constraints(ctx.env.schema))
        for ptr in ptrs
    )


def _track_all_constraint_refs(
    ptrs: Sequence[s_pointers.Pointer],
    ctx: inference_context.InfCtx,
) -> None:
    for ptr in ptrs:
        for constr in ptr.get_exclusive_constraints(ctx.env.schema):
            # We need to track all schema refs, since an expression
            # in the schema needs to depend on any constraint
            # that affects its cardinality.
            ctx.env.add_schema_ref(constr, None)


def extract_exclusive_filters(
    result_set: irast.Set,
    filter_set: irast.Set,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> list[tuple[tuple[s_pointers.Pointer, irast.Set], ...]]:

    filtered_ptrs = extract_filters(result_set, filter_set, scope_tree, ctx)

    results: list[tuple[tuple[s_pointers.Pointer, irast.Set], ...]] = []
    if filtered_ptrs:
        schema = ctx.env.schema
        # Only look at paths where all trailing pointers are exclusive;
        # that is, if we see `.foo.bar`, `bar` must be exclusive.
        # If that's the case, then we can look at whether `.foo` is
        # exclusive or used in an exclusive object constraint.
        filtered_ptrs_map = {
            ptrs[0].get_nearest_non_derived_parent(schema): (ptrs, expr)
            for ptrs, expr in filtered_ptrs
            if _all_have_exclusive(ptrs[1:], ctx)
        }
        ptr_set = set(filtered_ptrs_map)
        # First look at each referenced pointer and see if it has
        # an exclusive constraint.
        for ptr, (ptrs, expr) in filtered_ptrs_map.items():
            if _all_have_exclusive([ptr], ctx):
                # Bingo, got an equality filter on a pointer with a
                # unique constraint
                results.append(((ptr, expr),))
                _track_all_constraint_refs(ptrs, ctx)

        # Then look at all the object exclusive constraints
        result_stype = ctx.env.set_types[result_set]
        obj_exclusives = get_object_exclusive_constraints(
            result_stype, ptr_set, ctx.env)
        for constr, obj_exc_ptrs in obj_exclusives.items():
            results.append(
                tuple((ptr, filtered_ptrs_map[ptr][1]) for ptr in obj_exc_ptrs)
            )
            ctx.env.add_schema_ref(constr, None)
            for ptr in obj_exc_ptrs:
                _track_all_constraint_refs(filtered_ptrs_map[ptr][0], ctx)

    return results


def get_object_exclusive_constraints(
    typ: s_types.Type,
    ptr_set: set[s_pointers.Pointer],
    env: context.Environment,
) -> dict[s_constraints.Constraint, frozenset[s_pointers.Pointer]]:
    """Collect any exclusive object constraints that apply.

    An object constraint applies if all of the pointers referenced
    in it are filtered on in the query.
    """

    if not isinstance(typ, s_objtypes.ObjectType):
        return {}

    schema = env.schema
    exclusive = schema.get('std::exclusive', type=s_constraints.Constraint)

    cnstrs = {}
    typ = typ.get_nearest_non_derived_parent(schema)
    for constr in typ.get_constraints(schema).objects(schema):
        if (
            constr.issubclass(schema, exclusive)
            and (subjectexpr := constr.get_subjectexpr(schema))
            # We ignore constraints with except expressions, because
            # they can't actually ensure cardinality
            and not constr.get_except_expr(schema)
            # And delegated constraints can't either
            and not constr.get_delegated(schema)
        ):
            if subjectexpr.refs is None:
                continue
            pointer_refs = frozenset({
                x for x in subjectexpr.refs.objects(schema)
                if isinstance(x, s_pointers.Pointer)
            })
            # If all of the referenced pointers are filtered on,
            # we match.
            if pointer_refs.issubset(ptr_set):
                cnstrs[constr] = pointer_refs

    return cnstrs


def _analyse_filter_clause(
    result_set: irast.Set,
    result_card: qltypes.Cardinality,
    filter_clause: irast.Set,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
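) -> qltypes.Cardinality:
    """Narrow *result_card* based on exclusive filters.

    Returns AT_MOST_ONE if the filter clause contains an equality filter
    covered by an exclusive constraint; otherwise returns *result_card*
    unchanged.
    """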
    if extract_exclusive_filters(result_set, filter_clause, scope_tree, ctx):
        return AT_MOST_ONE
    else:
        return result_card


def _infer_matset_cardinality(
    materialized_sets: Optional[dict[uuid.UUID, irast.MaterializedSet]],
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> None:
    if not materialized_sets:
        return
    if not ctx.make_updates:
        return

    for mat_set in materialized_sets.values():
        if (len(mat_set.uses) <= 1
                or mat_set.cardinality != qltypes.Cardinality.UNKNOWN):
            continue
        assert mat_set.materialized
        # Set a placeholder cardinality to prevent infinite recursion.
        mat_set.cardinality = MANY
        new_scope = inf_utils.get_set_scope(
            mat_set.materialized, scope_tree, ctx=ctx)
        mat_set.cardinality = infer_cardinality(
            mat_set.materialized, scope_tree=new_scope, ctx=ctx,
        )


def _infer_dml_check_cardinality(
    ir: irast.MutatingStmt,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> None:
    if not ctx.make_updates:
        return
    pctx = ctx._replace(singletons=ctx.singletons | {ir.result.path_id})
    for read_pol in ir.read_policies.values():
        read_pol.cardinality = infer_cardinality(
            read_pol.expr, scope_tree=scope_tree, ctx=pctx
        )

    for write_pol in ir.write_policies.values():
        for p in write_pol.policies:
            p.cardinality = infer_cardinality(
                p.expr, scope_tree=scope_tree, ctx=pctx
            )

    if ir.conflict_checks:
        for on_conflict in ir.conflict_checks:
            _infer_on_conflict_cardinality(
                on_conflict, type_has_rewrites=False,
                scope_tree=scope_tree, ctx=ctx,
            )

    if ir.rewrites:
        for rewrites in ir.rewrites.by_type.values():
            for rewrite, _ in rewrites.values():
                infer_cardinality(
                    rewrite,
                    scope_tree=scope_tree,
                    ctx=ctx,
                )


def _infer_stmt_cardinality(
    ir: irast.FilteredStmt,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    for part, _ in (ir.bindings or []):
        infer_cardinality(part, scope_tree=scope_tree, ctx=ctx)

    result = ir.subject if isinstance(ir, irast.MutatingStmt) else ir.result
    result_card = infer_cardinality(
        result,
        scope_tree=scope_tree,
        ctx=ctx,
    )
    if ir.where:
        ir.where_card = infer_cardinality(
            ir.where, scope_tree=scope_tree, ctx=ctx,
        )

        if (
            ir.where_card.is_multi()
            # Don't generate warnings against internally generated code
            and ir.where.span
        ):
            ctx.env.warnings.append(errors.QueryError(
                'possibly more than one element returned by an expression '
                'in a FILTER clause',
                hint='If this is intended, try using any()',
                span=ir.where.span,
            ))

        # Cross with AT_MOST_ONE to ensure result can be empty
        result_card = cartesian_cardinality([result_card, AT_MOST_ONE])

    if result_card.is_multi() and ir.where:
        result_mult = multiplicity.infer_multiplicity(
            result, scope_tree=scope_tree, ctx=ctx)

        # We can only apply filter clause restrictions when the result
        # is a unique set: if the set contains duplicates, a filter on
        # an exclusive pointer can still match more than one element.
        if result_mult.is_unique():
            result_card = _analyse_filter_clause(
                ir.result, result_card, ir.where, scope_tree, ctx)

    _infer_matset_cardinality(
        ir.materialized_sets, scope_tree=scope_tree, ctx=ctx)

    if isinstance(ir, irast.MutatingStmt):
        _infer_dml_check_cardinality(ir, scope_tree=scope_tree, ctx=ctx)

    return result_card


def _infer_singleton_only(
    part: irast.Set,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
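) -> qltypes.Cardinality:
    """Infer the cardinality of *part*, rejecting multi-element results.

    Raises QueryError if the expression may return more than one element.
    """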
    new_scope = inf_utils.get_set_scope(part, scope_tree, ctx=ctx)
    card = infer_cardinality(part, scope_tree=new_scope, ctx=ctx)
    if card.is_multi():
        raise errors.QueryError(
            'possibly more than one element returned by an expression '
            'where only singletons are allowed',
            span=part.span)
    return card


def _infer_on_conflict_cardinality(
    on_conflict: irast.OnConflictClause,
    *,
    type_has_rewrites: bool,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    # Note: If we start supporting ELSE without ON, we'll need to
    # factor the cardinality of this into the else_card below
    infer_cardinality(
        on_conflict.select_ir, scope_tree=scope_tree, ctx=ctx)

    card = AT_MOST_ONE
    if on_conflict.else_ir:
        else_card = infer_cardinality(
            on_conflict.else_ir, scope_tree=scope_tree, ctx=ctx)
        card = max_cardinality((card, else_card))
        if type_has_rewrites:
            card = _bounds_to_card(CB_ZERO, _card_to_bounds(card).upper)

    return card


@_infer_cardinality.register
def __infer_select_stmt(
    ir: irast.SelectStmt,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:

    if ir.iterator_stmt:
        iter_card = infer_cardinality(
            ir.iterator_stmt, scope_tree=scope_tree, ctx=ctx,
        )

    stmt_card = _infer_stmt_cardinality(ir, scope_tree=scope_tree, ctx=ctx)

    for part in [ir.limit, ir.offset] + [
            sort.expr for sort in (ir.orderby or ())]:
        if part:
            _infer_singleton_only(part, scope_tree=scope_tree, ctx=ctx)

    if ir.limit is not None:
        if (
            isinstance(ir.limit.expr, irast.IntegerConstant)
            and ir.limit.expr.value == '1'
        ):
            # Explicit LIMIT 1 clause.
            stmt_card = _bounds_to_card(
                _card_to_bounds(stmt_card).lower, CB_ONE)
        elif (
            not isinstance(ir.limit.expr, irast.IntegerConstant)
            or ir.limit.expr.value == '0'
        ):
            # LIMIT 0 or a non-static LIMIT that could be 0
            stmt_card = _bounds_to_card(
                CB_ZERO, _card_to_bounds(stmt_card).upper)

    if ir.offset is not None:
        stmt_card = _bounds_to_card(
            CB_ZERO, _card_to_bounds(stmt_card).upper)

    if ir.iterator_stmt:
        stmt_card = cartesian_cardinality((stmt_card, iter_card))

    # Finally, an explicit cardinality inference override takes precedence.
    if ir.card_inference_override:
        stmt_card = infer_cardinality(
            ir.card_inference_override, scope_tree=scope_tree, ctx=ctx)

    return stmt_card


@_infer_cardinality.register
def __infer_insert_stmt(
    ir: irast.InsertStmt,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    for part, _ in (ir.bindings or []):
        infer_cardinality(part, scope_tree=scope_tree, ctx=ctx)

    infer_cardinality(
        ir.subject, scope_tree=scope_tree, ctx=ctx
    )
    new_scope = inf_utils.get_set_scope(ir.result, scope_tree, ctx=ctx)
    infer_cardinality(
        ir.result, scope_tree=new_scope, ctx=ctx
    )

    assert not ir.iterator_stmt, "InsertStmt shouldn't ever have an iterator"

    _infer_matset_cardinality(
        ir.materialized_sets, scope_tree=scope_tree, ctx=ctx)

    _infer_dml_check_cardinality(ir, scope_tree=scope_tree, ctx=ctx)

    # INSERT without a FOR is always a singleton.
    if not ir.on_conflict:
        return ONE
    # ... except if UNLESS CONFLICT is used
    else:
        return _infer_on_conflict_cardinality(
            ir.on_conflict,
            type_has_rewrites=bool(ir.write_policies),
            scope_tree=scope_tree,
            ctx=ctx,
        )


@_infer_cardinality.register
def __infer_update_stmt(
    ir: irast.UpdateStmt,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    infer_cardinality(
        ir.result, scope_tree=scope_tree, ctx=ctx,
    )

    return _infer_stmt_cardinality(ir, scope_tree=scope_tree, ctx=ctx)


@_infer_cardinality.register
def __infer_delete_stmt(
    ir: irast.DeleteStmt,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    infer_cardinality(
        ir.result, scope_tree=scope_tree, ctx=ctx,
    )

    return _infer_stmt_cardinality(ir, scope_tree=scope_tree, ctx=ctx)


@_infer_cardinality.register
def __infer_group_stmt(
    ir: irast.GroupStmt,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    infer_cardinality(ir.subject, scope_tree=scope_tree, ctx=ctx)
    for key, (binding, _) in ir.using.items():
        binding_card = _infer_singleton_only(
            binding, scope_tree=scope_tree, ctx=ctx)
        ir.using[key] = binding, binding_card

    infer_cardinality(ir.group_binding, scope_tree=scope_tree, ctx=ctx)

    _infer_stmt_cardinality(ir, scope_tree=scope_tree, ctx=ctx)

    for part in (ir.orderby or ()):
        _infer_singleton_only(part.expr, scope_tree=scope_tree, ctx=ctx)

    return MANY


@_infer_cardinality.register
def __infer_slice(
    ir: irast.SliceIndirection,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    # slice indirection cardinality depends on the cardinality of
    # the base expression and the slice index expressions
    args: list[irast.Base] = [ir.expr]
    if ir.start is not None:
        args.append(ir.start)
    if ir.stop is not None:
        args.append(ir.stop)

    return _common_cardinality(args, scope_tree=scope_tree, ctx=ctx)


@_infer_cardinality.register
def __infer_index(
    ir: irast.IndexIndirection,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    # index indirection cardinality depends on both the cardinality of
    # the base expression and the index expression
    return _common_cardinality(
        [ir.expr, ir.index], scope_tree=scope_tree, ctx=ctx,
    )


@_infer_cardinality.register
def __infer_array(
    ir: irast.Array,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    return _common_cardinality(ir.elements, scope_tree=scope_tree, ctx=ctx)


@_infer_cardinality.register
def __infer_tuple(
    ir: irast.Tuple,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    return _common_cardinality(
        [el.val for el in ir.elements], scope_tree=scope_tree, ctx=ctx
    )


@_infer_cardinality.register
def __infer_trigger_anchor(
    ir: irast.TriggerAnchor,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    return MANY


@_infer_cardinality.register
def __infer_searchable_string(
    ir: irast.FTSDocument,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    return _common_cardinality(
        (ir.text, ir.language), scope_tree=scope_tree, ctx=ctx
    )


def infer_cardinality(
    ir: irast.Base,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inference_context.InfCtx,
) -> qltypes.Cardinality:
    key = (ir, scope_tree, ctx.singletons)
    result = ctx.inferred_cardinality.get(key)
    if result is not None:
        return result

    if isinstance(ir, irast.Set):
        result = _infer_set(
            ir, scope_tree=scope_tree, ctx=ctx,
        )
    else:
        result = _infer_cardinality(ir, scope_tree=scope_tree, ctx=ctx)

    if result not in {AT_MOST_ONE, ONE, MANY, AT_LEAST_ONE}:
        raise errors.QueryError(
            'could not determine the cardinality of '
            'set produced by expression',
            span=ir.span)

    ctx.inferred_cardinality[key] = result

    return result


def is_subset_cardinality(
    card0: qltypes.Cardinality, card1: qltypes.Cardinality
) -> bool:
    '''Determine if card0 is a subset of card1.'''
    l0, u0 = _card_to_bounds(card0)
    l1, u1 = _card_to_bounds(card1)

    return l0 >= l1 and u0 <= u1
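

# Illustrative sketch only, not used by the compiler: the subset test in
# is_subset_cardinality() amounts to interval containment on
# (lower, upper) bounds. The table below is an assumption about what
# _card_to_bounds() returns; the string keys and the names
# _SKETCH_BOUNDS and _sketch_is_subset are hypothetical.
_SKETCH_BOUNDS = {
    'AT_MOST_ONE': (0, 1),
    'ONE': (1, 1),
    'MANY': (0, float('inf')),
    'AT_LEAST_ONE': (1, float('inf')),
}


def _sketch_is_subset(card0: str, card1: str) -> bool:
    # card0 fits inside card1 when its lower bound is no smaller and
    # its upper bound is no larger.
    l0, u0 = _SKETCH_BOUNDS[card0]
    l1, u1 = _SKETCH_BOUNDS[card1]
    return l0 >= l1 and u0 <= u1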


================================================
FILE: edb/edgeql/compiler/inference/context.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2020-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations
from typing import Optional, NamedTuple

import dataclasses

from edb.ir import ast as irast
from edb.edgeql import qltypes

from .. import context


@dataclasses.dataclass(frozen=True, eq=False)
class MultiplicityInfo:
    """Extended multiplicity descriptor"""

    #: Actual multiplicity number
    own: qltypes.Multiplicity
    #: Whether this multiplicity descriptor describes
    #: part of a disjoint set.
    disjoint_union: bool = False

    def is_empty(self) -> bool:
        return self.own.is_empty()

    def is_unique(self) -> bool:
        return self.own.is_unique()

    def is_duplicate(self) -> bool:
        return self.own.is_duplicate()


class InfCtx(NamedTuple):
    env: context.Environment
    inferred_cardinality: dict[
        tuple[irast.Base, irast.ScopeTreeNode, frozenset[irast.PathId]]
        | irast.Base,
        qltypes.Cardinality,
    ]
    inferred_multiplicity: dict[
        tuple[irast.Base, irast.ScopeTreeNode, Optional[irast.PathId]],
        MultiplicityInfo,
    ]
    singletons: frozenset[irast.PathId]
    distinct_iterator: Optional[irast.PathId]
    ignore_computed_cards: bool
    # Whether to make updates to the cardinality fields in the IR/schema.
    # This is used in cases where we need to do a "hypothetical"
    # inference, but don't want to affect real state.
    make_updates: bool


def make_ctx(env: context.Environment) -> InfCtx:
    return InfCtx(
        env=env,
        inferred_cardinality={},
        inferred_multiplicity={},
        singletons=frozenset(env.singletons),
        distinct_iterator=None,
        ignore_computed_cards=False,
        make_updates=True,
    )


================================================
FILE: edb/edgeql/compiler/inference/multiplicity.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2020-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""EdgeQL multiplicity inference.

A top-down multiplicity inferer that traverses the full AST populating
multiplicity fields and performing multiplicity checks.
"""


from __future__ import annotations
from typing import Iterable

import dataclasses
import functools
import itertools

from edb.common.typeutils import downcast

from edb import errors

from edb.edgeql import ast as qlast
from edb.edgeql import qltypes

from edb.schema import objtypes as s_objtypes
from edb.schema import pointers as s_pointers

from edb.ir import ast as irast
from edb.ir import typeutils as irtyputils
from edb.ir import utils as irutils

from . import cardinality
from . import context as inf_ctx
from . import utils as inf_utils


EMPTY = inf_ctx.MultiplicityInfo(own=qltypes.Multiplicity.EMPTY)
UNIQUE = inf_ctx.MultiplicityInfo(own=qltypes.Multiplicity.UNIQUE)
DUPLICATE = inf_ctx.MultiplicityInfo(own=qltypes.Multiplicity.DUPLICATE)
DISTINCT_UNION = inf_ctx.MultiplicityInfo(
    own=qltypes.Multiplicity.UNIQUE,
    disjoint_union=True,
)


@dataclasses.dataclass(frozen=True, eq=False)
class ContainerMultiplicityInfo(inf_ctx.MultiplicityInfo):
    """Multiplicity descriptor for an expression returning a container"""

    #: Individual multiplicity values for container elements.
    elements: tuple[inf_ctx.MultiplicityInfo, ...] = ()


def _max_multiplicity(
    args: Iterable[inf_ctx.MultiplicityInfo],
) -> inf_ctx.MultiplicityInfo:
    arg_list = [a.own for a in args]
    if not arg_list:
        max_mult = qltypes.Multiplicity.UNIQUE
    else:
        max_mult = max(arg_list)

    return inf_ctx.MultiplicityInfo(own=max_mult)


def _min_multiplicity(
    args: Iterable[inf_ctx.MultiplicityInfo],
) -> inf_ctx.MultiplicityInfo:
    arg_list = [a.own for a in args]
    if not arg_list:
        min_mult = qltypes.Multiplicity.UNIQUE
    else:
        min_mult = min(arg_list)

    return inf_ctx.MultiplicityInfo(own=min_mult)


def _common_multiplicity(
    args: Iterable[irast.Base],
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    return _max_multiplicity(
        infer_multiplicity(a, scope_tree=scope_tree, ctx=ctx) for a in args)
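

# Illustrative sketch only, not used by the compiler: _max_multiplicity()
# and _min_multiplicity() rely on multiplicities being ordered. Assuming
# the order EMPTY < UNIQUE < DUPLICATE (an assumption about
# qltypes.Multiplicity), the same logic on plain integers might look
# like this. All _SKETCH_* and _sketch_* names are hypothetical.
_SKETCH_EMPTY, _SKETCH_UNIQUE, _SKETCH_DUPLICATE = 0, 1, 2


def _sketch_max_multiplicity(args: list[int]) -> int:
    # An empty argument list defaults to UNIQUE, mirroring the helpers
    # above.
    return max(args) if args else _SKETCH_UNIQUE


def _sketch_min_multiplicity(args: list[int]) -> int:
    return min(args) if args else _SKETCH_UNIQUE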


@functools.singledispatch
def _infer_multiplicity(
    ir: irast.Base,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    # return DUPLICATE
    raise ValueError(f'infer_multiplicity: cannot handle {ir!r}')


@_infer_multiplicity.register
def __infer_config_insert(
    ir: irast.ConfigInsert,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    return infer_multiplicity(
        ir.expr, scope_tree=scope_tree, ctx=ctx)


@_infer_multiplicity.register
def __infer_config_set(
    ir: irast.ConfigSet,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    return infer_multiplicity(
        ir.expr, scope_tree=scope_tree, ctx=ctx)


@_infer_multiplicity.register
def __infer_config_reset(
    ir: irast.ConfigReset,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    if ir.selector:
        return infer_multiplicity(
            ir.selector, scope_tree=scope_tree, ctx=ctx)
    else:
        return UNIQUE


@_infer_multiplicity.register
def __infer_empty_set(
    ir: irast.EmptySet,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    return EMPTY


@_infer_multiplicity.register
def __infer_type_introspection(
    ir: irast.TypeIntrospection,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    # TODO: The result is always UNIQUE, but we still want to actually
    # introspect the expression. Unfortunately, currently the
    # expression is not available at this stage.
    #
    # E.g. consider:
    #   WITH X := Foo {bar := {Bar, Bar}}
    #   SELECT INTROSPECT TYPEOF X.bar;
    return UNIQUE


@_infer_multiplicity.register
def __infer_type_root(
    ir: irast.TypeRoot,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    return UNIQUE


@_infer_multiplicity.register
def __infer_cleared(
    ir: irast.RefExpr,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    return DUPLICATE


def _infer_shape(
    ir: irast.Set,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> None:
    for shape_set, shape_op in ir.shape:
        new_scope = inf_utils.get_set_scope(shape_set, scope_tree, ctx=ctx)

        rptr = shape_set.expr
        if rptr.expr:
            expr_mult = infer_multiplicity(
                rptr.expr, scope_tree=new_scope, ctx=ctx)

            ptrref = rptr.ptrref
            if (
                expr_mult.is_duplicate()
                and shape_op is not qlast.ShapeOp.APPEND
                and shape_op is not qlast.ShapeOp.SUBTRACT
                and irtyputils.is_object(ptrref.out_target)
            ):
                ctx.env.schema, ptrcls = irtyputils.ptrcls_from_ptrref(
                    ptrref, schema=ctx.env.schema)
                assert isinstance(ptrcls, s_pointers.Pointer)
                desc = ptrcls.get_verbosename(ctx.env.schema)
                if not rptr.is_mutation:
                    desc = f"computed {desc}"
                raise errors.QueryError(
                    f'possibly not a distinct set returned by an '
                    f'expression for a {desc}',
                    hint=(
                        f'You can use assert_distinct() around the expression '
                        f'to turn this into a runtime assertion, or the '
                        f'DISTINCT operator to silently discard duplicate '
                        f'elements.'
                    ),
                    span=shape_set.span
                )

        _infer_shape(
            shape_set, scope_tree=scope_tree, ctx=ctx)


def _infer_set(
    ir: irast.Set,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    result = _infer_set_inner(
        ir, scope_tree=scope_tree, ctx=ctx
    )
    ctx.inferred_multiplicity[ir, scope_tree, ctx.distinct_iterator] = result

    # The shape doesn't affect multiplicity, but requires validation.
    _infer_shape(ir, scope_tree=scope_tree, ctx=ctx)

    return result


def _infer_set_inner(
    ir: irast.Set,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    new_scope = inf_utils.get_set_scope(ir, scope_tree, ctx=ctx)

    # TODO: Migrate to Pointer-as-Expr well, and not half-assedly.
    sub_expr = irutils.sub_expr(ir)
    if sub_expr is None:
        expr_mult = None
    else:
        expr_mult = infer_multiplicity(sub_expr, scope_tree=new_scope, ctx=ctx)

    if isinstance(ir.expr, irast.Pointer):
        ptr = ir.expr
        src_mult = infer_multiplicity(
            ptr.source, scope_tree=new_scope, ctx=ctx
        )

        if isinstance(ptr.ptrref, irast.TupleIndirectionPointerRef):
            if isinstance(src_mult, ContainerMultiplicityInfo):
                idx = irtyputils.get_tuple_element_index(ptr.ptrref)
                path_mult = src_mult.elements[idx]
            else:
                # All bets are off for tuple elements coming from
                # opaque tuples.
                path_mult = DUPLICATE
        elif not irtyputils.is_object(ir.typeref):
            # This is not an expression and is some kind of scalar, so
            # multiplicity cannot be guaranteed to be UNIQUE (most scalar
            # expressions don't have an implicit requirement to be sets)
            # unless we also have an exclusive constraint.
            if (
                expr_mult is not None
                and inf_utils.find_visible(ptr.source, new_scope) is not None
            ):
                path_mult = expr_mult
            else:
                schema = ctx.env.schema
                # We should only have some kind of path terminating in a
                # property here.
                assert isinstance(ptr.ptrref, irast.PointerRef)
                pointer = schema.get_by_id(
                    ptr.ptrref.id, type=s_pointers.Pointer
                )
                if pointer.is_exclusive(schema):
                    # Got an exclusive constraint
                    path_mult = UNIQUE
                else:
                    path_mult = DUPLICATE
        else:
            # This is some kind of a link at the end of a path.
            # Therefore the target is a proper set.
            path_mult = UNIQUE

    elif expr_mult is not None:
        path_mult = expr_mult

    else:
        # Evidently this is not a pointer, expression, or a scalar.
        # This is an object type and therefore a proper set.
        path_mult = UNIQUE

    if (
        not path_mult.is_duplicate()
        and irutils.get_path_root(ir).path_id == ctx.distinct_iterator
    ):
        path_mult = dataclasses.replace(path_mult, disjoint_union=True)

    if irtyputils.is_free_object(ir.typeref):
        path_mult = UNIQUE

    return path_mult


@_infer_multiplicity.register
def __infer_func_call(
    ir: irast.FunctionCall,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    card = cardinality.infer_cardinality(ir, scope_tree=scope_tree, ctx=ctx)
    args_mult = []
    for arg in ir.args.values():
        arg_mult = infer_multiplicity(arg.expr, scope_tree=scope_tree, ctx=ctx)
        args_mult.append(arg_mult)
        arg.multiplicity = arg_mult.own

    if ir.global_args:
        for g_arg in ir.global_args:
            _infer_set(g_arg, scope_tree=scope_tree, ctx=ctx)

    if ir.body:
        infer_multiplicity(ir.body, scope_tree=scope_tree, ctx=ctx)

    if card.is_single():
        return UNIQUE
    elif str(ir.func_shortname) == 'std::assert_distinct':
        return UNIQUE
    elif str(ir.func_shortname) == 'std::assert_exists':
        return args_mult[1]
    elif str(ir.func_shortname) == 'std::enumerate':
        # The output of enumerate is always of multiplicity UNIQUE because
        # it's a set of tuples with first elements being guaranteed to be
        # distinct.
        return ContainerMultiplicityInfo(
            own=qltypes.Multiplicity.UNIQUE,
            elements=(UNIQUE,) + tuple(args_mult),
        )
    else:
        # If the function returns a set (for any reason), all bets are off
        # and the maximum multiplicity cannot be inferred.
        return DUPLICATE


@_infer_multiplicity.register
def __infer_oper_call(
    ir: irast.OperatorCall,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    card = cardinality.infer_cardinality(ir, scope_tree=scope_tree, ctx=ctx)
    mult: list[inf_ctx.MultiplicityInfo] = []
    cards: list[qltypes.Cardinality] = []
    for arg in ir.args.values():
        cards.append(
            cardinality.infer_cardinality(
                arg.expr, scope_tree=scope_tree, ctx=ctx
            )
        )

        m = infer_multiplicity(arg.expr, scope_tree=scope_tree, ctx=ctx)
        arg.multiplicity = m.own
        mult.append(m)

    op_name = str(ir.func_shortname)

    if op_name == 'std::UNION':
        # UNION will produce multiplicity DUPLICATE unless at most one
        # of the elements is non-EMPTY (the rest coming from empty
        # sets), or all of the elements are sets of unrelated object
        # types of multiplicity at most UNIQUE, or all elements have
        # been proven to be disjoint (e.g. a UNION of INSERTs).
        result = EMPTY

        arg_type = ctx.env.set_types[ir.args[0].expr]
        if isinstance(arg_type, s_objtypes.ObjectType):
            types: list[s_objtypes.ObjectType] = [
                downcast(s_objtypes.ObjectType, ctx.env.set_types[arg.expr])
                for arg in ir.args.values()
            ]

            lineages = [
                (t,) + tuple(t.descendants(ctx.env.schema))
                for t in types
            ]
            flattened = tuple(itertools.chain.from_iterable(lineages))
            types_disjoint = len(flattened) == len(frozenset(flattened))
        else:
            types_disjoint = False

        for m in mult:
            if m.is_unique():
                if (
                    result.is_empty()
                    or types_disjoint
                    or (result.disjoint_union and m.disjoint_union)
                ):
                    result = m
                else:
                    result = DUPLICATE
                    break
            elif m.is_duplicate():
                result = DUPLICATE
                break
            else:
                # EMPTY: contributes nothing to the result
                pass

        return result

    elif op_name == 'std::EXCEPT':
        # EXCEPT will produce multiplicity no greater than that of its first
        # argument.
        return mult[0]

    elif op_name == 'std::INTERSECT':
        # INTERSECT will produce the minimum multiplicity of its arguments.
        return _min_multiplicity((mult[0], mult[1]))

    elif op_name == 'std::DISTINCT':
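        # MultiplicityInfo is declared with eq=False, so `==` compares
        # identity; test the value instead.
        if mult[0].is_empty():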
            return EMPTY
        else:
            return UNIQUE
    elif op_name == 'std::IF':
        # If the cardinality of the condition is more than ONE, then
        # the multiplicity cannot be inferred.
        if cards[1].is_single():
            # Now it's just a matter of the multiplicity of the
            # possible results.
            return _max_multiplicity((mult[0], mult[2]))
        else:
            return DUPLICATE
    elif op_name == 'std::??':
        return _max_multiplicity((mult[0], mult[1]))
    elif card.is_single():
        return UNIQUE
    elif op_name in ('std::++', 'std::+'):
        # Operators known to be injective.
        # Basically just done to avoid breaking backward compatibility
        # more than was necessary, because we used to *always* use this
        # path, which was wrong.
        result = _max_multiplicity(mult)
        if result.is_duplicate():
            return result

        # Even when arguments are of multiplicity UNIQUE, we cannot
        # exclude the possibility of the result being of multiplicity
        # DUPLICATE. We need to check that at most one argument has
        # cardinality more than ONE.

        if len([card for card in cards if card.is_multi()]) > 1:
            return DUPLICATE
        else:
            return result
    else:
        # Everything else.
        return DUPLICATE
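

# Illustrative sketch only, not used by the compiler: the UNION branch
# above treats object types as disjoint when no type appears twice
# across the operands' lineages (each type plus its descendants). The
# helper below restates that check over plain strings; the `descendants`
# mapping stands in for the schema lookup, and the function name is
# hypothetical.
def _sketch_types_disjoint(
    types: list[str],
    descendants: dict[str, tuple[str, ...]],
) -> bool:
    import itertools

    lineages = [(t,) + tuple(descendants.get(t, ())) for t in types]
    flattened = tuple(itertools.chain.from_iterable(lineages))
    # Any repeated type shrinks the frozenset, so the lengths differ.
    return len(flattened) == len(frozenset(flattened))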


@_infer_multiplicity.register
def __infer_const(
    ir: irast.BaseConstant,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    return UNIQUE


@_infer_multiplicity.register
def __infer_param(
    ir: irast.QueryParameter,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    return UNIQUE


@_infer_multiplicity.register
def __infer_function_param(
    ir: irast.FunctionParameter,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    return UNIQUE


@_infer_multiplicity.register
def __infer_inlined_param(
    ir: irast.InlinedParameterExpr,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    return UNIQUE


@_infer_multiplicity.register
def __infer_const_set(
    ir: irast.ConstantSet,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    # Is it worth doing this? It won't trigger in the common case of having
    # performed constant extraction.
    els = set()
    for el in ir.elements:
        if isinstance(el, irast.BaseConstant):
            els.add(el.value)
        else:
            return DUPLICATE

    if len(ir.elements) == len(els):
        return UNIQUE
    else:
        return DUPLICATE


@_infer_multiplicity.register
def __infer_typecheckop(
    ir: irast.TypeCheckOp,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    # Unless this is a singleton, multiplicity cannot be assumed to be UNIQUE
    card = cardinality.infer_cardinality(
        ir, scope_tree=scope_tree, ctx=ctx)

    infer_multiplicity(ir.left, scope_tree=scope_tree, ctx=ctx)

    if card is not None and card.is_single():
        return UNIQUE
    else:
        return DUPLICATE


@_infer_multiplicity.register
def __infer_typecast(
    ir: irast.TypeCast,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    return infer_multiplicity(
        ir.expr, scope_tree=scope_tree, ctx=ctx,
    )


def _infer_stmt_multiplicity(
    ir: irast.FilteredStmt,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    # WITH block bindings need to be validated; they don't have to
    # have multiplicity UNIQUE, but their sub-expressions must be valid.
    for part, _ in (ir.bindings or []):
        infer_multiplicity(part, scope_tree=scope_tree, ctx=ctx)

    subj = ir.subject if isinstance(ir, irast.MutatingStmt) else ir.result
    result = infer_multiplicity(
        subj,
        scope_tree=scope_tree,
        ctx=ctx,
    )

    if ir.where:
        infer_multiplicity(ir.where, scope_tree=scope_tree, ctx=ctx)
        filtered_ptrs = cardinality.extract_filters(
            subj, ir.where, scope_tree, ctx)
        for _, flt_expr in filtered_ptrs:
            # Check if any of the singleton filter expressions in FILTER
            # reference enclosing iterators with multiplicity UNIQUE, and
            # if so, indicate to the enclosing iterator that this UNION
            # is guaranteed to be disjoint.
            if (
                irutils.get_path_root(flt_expr).path_id
                == ctx.distinct_iterator
                or irutils.get_path_root(irutils.unwrap_set(flt_expr)).path_id
                == ctx.distinct_iterator
            ) and not infer_multiplicity(
                flt_expr, scope_tree=scope_tree, ctx=ctx
            ).is_duplicate():
                return DISTINCT_UNION

    return result


def _infer_for_multiplicity(
    ir: irast.SelectStmt,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    itset = ir.iterator_stmt
    assert itset is not None
    itexpr = itset.expr
    assert itexpr is not None
    itmult = infer_multiplicity(itset, scope_tree=scope_tree, ctx=ctx)

    if not itmult.is_duplicate():
        ctx = ctx._replace(distinct_iterator=itset.path_id)
    result_mult = infer_multiplicity(ir.result, scope_tree=scope_tree, ctx=ctx)

    if isinstance(ir.result.expr, irast.InsertStmt):
        # A union of inserts always has multiplicity UNIQUE
        return UNIQUE
    elif itmult.is_duplicate():
        return DUPLICATE
    else:
        if result_mult.disjoint_union:
            # If we know the union was disjoint with respect to this
            # FOR, then our set is unique (or possibly empty), but we
            # have to clear disjoint_union: it only held with respect
            # to this FOR, so it must not leak to enclosing scopes.
            return dataclasses.replace(result_mult, disjoint_union=False)
        else:
            return DUPLICATE
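

# Illustrative sketch only, not used by the compiler: the decision made
# at the end of _infer_for_multiplicity(), restated over plain values so
# the precedence of the three checks is easy to see. The function name,
# its parameters, and the string results are hypothetical stand-ins for
# the IR and MultiplicityInfo objects.
def _sketch_for_result(
    result_is_insert: bool,
    iterator_is_duplicate: bool,
    result_is_disjoint: bool,
    result_mult: str = 'UNIQUE',
) -> str:
    if result_is_insert:
        # A union of inserts always yields distinct objects.
        return 'UNIQUE'
    if iterator_is_duplicate:
        return 'DUPLICATE'
    if result_is_disjoint:
        # Keep the body's multiplicity (UNIQUE or EMPTY).
        return result_mult
    return 'DUPLICATE'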


@_infer_multiplicity.register
def __infer_select_stmt(
    ir: irast.SelectStmt,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:

    if ir.iterator_stmt is not None:
        stmt_mult = _infer_for_multiplicity(ir, scope_tree=scope_tree, ctx=ctx)
    else:
        stmt_mult = _infer_stmt_multiplicity(
            ir, scope_tree=scope_tree, ctx=ctx)

        clauses = (
            [ir.limit, ir.offset]
            + [sort.expr for sort in (ir.orderby or ())]
        )

        for clause in filter(None, clauses):
            new_scope = inf_utils.get_set_scope(clause, scope_tree, ctx=ctx)
            infer_multiplicity(clause, scope_tree=new_scope, ctx=ctx)

    if ir.card_inference_override:
        stmt_mult = infer_multiplicity(
            ir.card_inference_override, scope_tree=scope_tree, ctx=ctx)

    return stmt_mult


@_infer_multiplicity.register
def __infer_insert_stmt(
    ir: irast.InsertStmt,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    # WITH block bindings need to be validated; they don't have to
    # have multiplicity UNIQUE, but their sub-expressions must be valid.
    for part, _ in (ir.bindings or []):
        infer_multiplicity(part, scope_tree=scope_tree, ctx=ctx)

    # INSERT will always return a proper set, but we still want to
    # process the sub-expressions.
    infer_multiplicity(
        ir.subject, scope_tree=scope_tree, ctx=ctx
    )
    new_scope = inf_utils.get_set_scope(ir.result, scope_tree, ctx=ctx)
    infer_multiplicity(
        ir.result, scope_tree=new_scope, ctx=ctx
    )

    if ir.on_conflict:
        _infer_on_conflict_clause(
            ir.on_conflict, scope_tree=scope_tree, ctx=ctx
        )

    _infer_mutating_stmt(ir, scope_tree=scope_tree, ctx=ctx)

    return DISTINCT_UNION


@_infer_multiplicity.register
def __infer_update_stmt(
    ir: irast.UpdateStmt,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    # Presumably UPDATE will always return a proper set, even if it's
    # fed something with higher multiplicity, but we still want to
    # process the expression being updated.
    infer_multiplicity(
        ir.result, scope_tree=scope_tree, ctx=ctx,
    )
    result = _infer_stmt_multiplicity(ir, scope_tree=scope_tree, ctx=ctx)

    _infer_mutating_stmt(ir, scope_tree=scope_tree, ctx=ctx)

    if result.is_empty():
        return EMPTY
    else:
        return UNIQUE


@_infer_multiplicity.register
def __infer_delete_stmt(
    ir: irast.DeleteStmt,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    # Presumably DELETE will always return a proper set, even if it's
    # fed something with higher multiplicity, but we still want to
    # process the expression being deleted.
    infer_multiplicity(
        ir.result, scope_tree=scope_tree, ctx=ctx,
    )
    result = _infer_stmt_multiplicity(ir, scope_tree=scope_tree, ctx=ctx)

    _infer_mutating_stmt(ir, scope_tree=scope_tree, ctx=ctx)

    if result.is_empty():
        return EMPTY
    else:
        return UNIQUE


def _infer_mutating_stmt(
    ir: irast.MutatingStmt,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> None:
    if ir.conflict_checks:
        for clause in ir.conflict_checks:
            _infer_on_conflict_clause(clause, scope_tree=scope_tree, ctx=ctx)

    for write_pol in ir.write_policies.values():
        for policy in write_pol.policies:
            infer_multiplicity(policy.expr, scope_tree=scope_tree, ctx=ctx)

    for read_pol in ir.read_policies.values():
        infer_multiplicity(read_pol.expr, scope_tree=scope_tree, ctx=ctx)

    if ir.rewrites:
        for rewrites in ir.rewrites.by_type.values():
            for rewrite, _ in rewrites.values():
                infer_multiplicity(
                    rewrite,
                    scope_tree=scope_tree,
                    ctx=ctx,
                )


def _infer_on_conflict_clause(
    ir: irast.OnConflictClause,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> None:
    for part in [ir.select_ir, ir.else_ir]:
        if part:
            infer_multiplicity(part, scope_tree=scope_tree, ctx=ctx)


@_infer_multiplicity.register
def __infer_group_stmt(
    ir: irast.GroupStmt,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    infer_multiplicity(ir.subject, scope_tree=scope_tree, ctx=ctx)
    for binding, _ in ir.using.values():
        infer_multiplicity(binding, scope_tree=scope_tree, ctx=ctx)
    _infer_stmt_multiplicity(ir, scope_tree=scope_tree, ctx=ctx)

    for clause in (ir.orderby or ()):
        new_scope = inf_utils.get_set_scope(clause.expr, scope_tree, ctx=ctx)
        infer_multiplicity(clause.expr, scope_tree=new_scope, ctx=ctx)

    infer_multiplicity(ir.group_binding, scope_tree=scope_tree, ctx=ctx)
    if ir.grouping_binding:
        infer_multiplicity(ir.grouping_binding, scope_tree=scope_tree, ctx=ctx)

    for agg_set in ir.group_aggregate_sets:
        if agg_set:
            infer_multiplicity(agg_set, scope_tree=scope_tree, ctx=ctx)

    # N.B: The type is usually a free object (except in some
    # internal tests), and free objects are always unique.
    if irtyputils.is_free_object(ir.typeref):
        return UNIQUE

    return DUPLICATE


@_infer_multiplicity.register
def __infer_slice(
    ir: irast.SliceIndirection,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    # Slice indirection multiplicity is guaranteed to be UNIQUE as long
    # as the cardinality of this expression is at most one, otherwise
    # the results of slice indirection can contain values with
    # multiplicity > 1.

    infer_multiplicity(ir.expr, scope_tree=scope_tree, ctx=ctx)
    if ir.start:
        infer_multiplicity(ir.start, scope_tree=scope_tree, ctx=ctx)
    if ir.stop:
        infer_multiplicity(ir.stop, scope_tree=scope_tree, ctx=ctx)

    card = cardinality.infer_cardinality(
        ir, scope_tree=scope_tree, ctx=ctx)
    if card is not None and card.is_single():
        return UNIQUE
    else:
        return DUPLICATE


@_infer_multiplicity.register
def __infer_index(
    ir: irast.IndexIndirection,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    # Index indirection multiplicity is guaranteed to be UNIQUE as long
    # as the cardinality of this expression is at most one, otherwise
    # the results of index indirection can contain values with
    # multiplicity > 1.
    card = cardinality.infer_cardinality(
        ir, scope_tree=scope_tree, ctx=ctx)

    infer_multiplicity(ir.expr, scope_tree=scope_tree, ctx=ctx)
    infer_multiplicity(ir.index, scope_tree=scope_tree, ctx=ctx)

    if card is not None and card.is_single():
        return UNIQUE
    else:
        return DUPLICATE


@_infer_multiplicity.register
def __infer_array(
    ir: irast.Array,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    return _common_multiplicity(ir.elements, scope_tree=scope_tree, ctx=ctx)


@_infer_multiplicity.register
def __infer_tuple(
    ir: irast.Tuple,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    els = tuple(
        infer_multiplicity(el.val, scope_tree=scope_tree, ctx=ctx)
        for el in ir.elements
    )
    cards = [
        cardinality.infer_cardinality(el.val, scope_tree=scope_tree, ctx=ctx)
        for el in ir.elements
    ]

    num_many = sum(card.is_multi() for card in cards)
    new_els = []
    for el, card in zip(els, cards):
        # If more than one tuple element is many, everything has DUPLICATE
        # multiplicity.
        if num_many > 1:
            el = DUPLICATE
        # If exactly one tuple element is many, then *that* element
        # can keep its underlying multiplicity, while everything else
        # becomes DUPLICATE.
        elif num_many == 1 and card.is_single():
            el = DUPLICATE
        new_els.append(el)

    return ContainerMultiplicityInfo(
        own=_max_multiplicity(els).own,
        elements=tuple(new_els),
    )


@_infer_multiplicity.register
def __infer_trigger_anchor(
    ir: irast.TriggerAnchor,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    return UNIQUE


@_infer_multiplicity.register
def __infer_searchable_string(
    ir: irast.FTSDocument,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    return _common_multiplicity(
        (ir.text, ir.language), scope_tree=scope_tree, ctx=ctx
    )


def infer_multiplicity(
    ir: irast.Base,
    *,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> inf_ctx.MultiplicityInfo:
    assert ctx.make_updates, (
        "multiplicity inference hasn't implemented make_updates=False yet")

    result = ctx.inferred_multiplicity.get(
        (ir, scope_tree, ctx.distinct_iterator))
    if result is not None:
        return result

    # We can use cardinality as a helper in determining multiplicity,
    # since singletons have multiplicity one.
    card = cardinality.infer_cardinality(ir, scope_tree=scope_tree, ctx=ctx)

    if isinstance(ir, irast.Set):
        result = _infer_set(ir, scope_tree=scope_tree, ctx=ctx)
    else:
        result = _infer_multiplicity(ir, scope_tree=scope_tree, ctx=ctx)

    if card is not None and card.is_single() and result.is_duplicate():
        # We've validated multiplicity, so now we can just override it
        # safely.
        result = UNIQUE

    if not isinstance(result, inf_ctx.MultiplicityInfo):
        raise errors.QueryError(
            'could not determine the multiplicity of '
            'set produced by expression',
            span=ir.span)

    ctx.inferred_multiplicity[ir, scope_tree, ctx.distinct_iterator] = result

    return result


================================================
FILE: edb/edgeql/compiler/inference/utils.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2021-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Common utilities used in inferers."""


from __future__ import annotations
from typing import Optional

from edb import errors
from edb.ir import ast as irast

from . import context as inf_ctx


def get_set_scope(
    ir_set: irast.Set,
    scope_tree: irast.ScopeTreeNode,
    ctx: inf_ctx.InfCtx,
) -> irast.ScopeTreeNode:

    if ir_set.path_scope_id:
        new_scope = ctx.env.scope_tree_nodes.get(ir_set.path_scope_id)
        if new_scope is None:
            raise errors.InternalServerError(
                f'dangling scope pointer to node with uid'
                f':{ir_set.path_scope_id} in {ir_set!r}'
            )
    else:
        new_scope = scope_tree

    return new_scope


def find_visible(
    ir: irast.Set,
    scope_tree: irast.ScopeTreeNode,
) -> Optional[irast.ScopeTreeNode]:
    # We want to look one fence up from whatever our current fence is.
    # (Most of the time, scope_tree will be a fence, so this is equivalent
    # to parent_fence, but sometimes it will be a branch.)
    outer_fence = scope_tree.fence.parent_fence
    if outer_fence is not None:
        if scope_tree.namespaces:
            path_id = ir.path_id.strip_namespace(scope_tree.namespaces)
        else:
            path_id = ir.path_id

        return outer_fence.find_visible(path_id)
    else:
        return None


================================================
FILE: edb/edgeql/compiler/inference/volatility.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2019-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations
from typing import Iterable

import functools

from edb import errors

from edb.edgeql import qltypes

from edb.ir import ast as irast
from edb.ir import typeutils as irtyputils

from .. import context


InferredVolatility = context.InferredVolatility


IMMUTABLE = qltypes.Volatility.Immutable
STABLE = qltypes.Volatility.Stable
VOLATILE = qltypes.Volatility.Volatile
MODIFYING = qltypes.Volatility.Modifying


# Volatility inference computes two volatility results:
# A basic one, and one for consumption by materialization.
#
# The one for consumption by materialization differs in that it
# (counterintuitively) does not consider DML to be volatile/modifying
# (since DML has its own "materialization" mechanism).
#
# We represent this output as a pair, but for ergonomics, inference
# functions are allowed to still produce a single volatility value,
# which is normalized when necessary.


def _normalize_volatility(
    vol: InferredVolatility,
) -> tuple[qltypes.Volatility, qltypes.Volatility]:
    if not isinstance(vol, tuple):
        return (vol, vol)
    else:
        return vol


def _max_volatility(args: Iterable[InferredVolatility]) -> InferredVolatility:
    arg_list = list(args)
    if not arg_list:
        return IMMUTABLE
    else:
        nargs = [_normalize_volatility(x) for x in arg_list]
        return (
            max(x[0] for x in nargs),
            max(x[1] for x in nargs),
        )


def _common_volatility(
    args: Iterable[irast.Base],
    env: context.Environment,
) -> InferredVolatility:
    return _max_volatility(
        _infer_volatility(a, env) for a in args)


@functools.singledispatch
def _infer_volatility_inner(
    ir: irast.Base,
    env: context.Environment,
) -> InferredVolatility:
    raise ValueError(f'infer_volatility: cannot handle {ir!r}')


@_infer_volatility_inner.register(type(None))
def __infer_none(
    ir: None,
    env: context.Environment,
) -> InferredVolatility:
    # Here for debugging purposes.
    raise ValueError('invalid infer_volatility(None, env) call')


@_infer_volatility_inner.register
def __infer_statement(
    ir: irast.Statement,
    env: context.Environment,
) -> InferredVolatility:
    return _infer_volatility(ir.expr, env)


@_infer_volatility_inner.register
def __infer_config_command(
    ir: irast.ConfigCommand,
    env: context.Environment,
) -> InferredVolatility:
    return VOLATILE


@_infer_volatility_inner.register
def __infer_emptyset(
    ir: irast.EmptySet,
    env: context.Environment,
) -> InferredVolatility:
    return IMMUTABLE


@_infer_volatility_inner.register
def __infer_typeref(
    ir: irast.TypeRef,
    env: context.Environment,
) -> InferredVolatility:
    return IMMUTABLE


@_infer_volatility_inner.register
def __infer_type_introspection(
    ir: irast.TypeIntrospection,
    env: context.Environment,
) -> InferredVolatility:
    return IMMUTABLE


@_infer_volatility_inner.register
def __infer_type_root(
    ir: irast.TypeRoot,
    env: context.Environment,
) -> InferredVolatility:
    return STABLE


@_infer_volatility_inner.register
def __infer_cleared_expr(
    ir: irast.RefExpr,
    env: context.Environment,
) -> InferredVolatility:
    return IMMUTABLE


@_infer_volatility_inner.register
def _infer_pointer(
    ir: irast.Pointer,
    env: context.Environment,
) -> InferredVolatility:
    vol = _infer_volatility(ir.source, env)
    # If there's an expression on an rptr, and it comes from
    # the schema, we need to actually infer it, since it won't
    # have been processed at a shape declaration.
    if ir.expr is not None and not ir.ptrref.defined_here:
        vol = _max_volatility((
            vol,
            _infer_volatility(ir.expr, env),
        ))

    # If source is an object, then a pointer reference implies
    # a table scan, and so we can assume STABLE at the minimum.
    #
    # A single dereference of a singleton path can be IMMUTABLE,
    # though, which we need in order to enforce that indexes
    # don't call STABLE functions.
    if (
        irtyputils.is_object(ir.source.typeref)
        and ir.source.path_id not in env.singletons
    ):
        vol = _max_volatility((vol, STABLE))

    return vol


@_infer_volatility_inner.register
def __infer_set(
    ir: irast.Set,
    env: context.Environment,
) -> InferredVolatility:
    vol: InferredVolatility

    if ir.path_id in env.singletons:
        vol = IMMUTABLE
    else:
        vol = _infer_volatility(ir.expr, env)

    # Cache the best-known volatility up to this point, to prevent
    # infinite recursion.
    env.inferred_volatility[ir] = vol

    if ir.shape:
        vol = _max_volatility([
            _common_volatility(
                (el.expr.expr for el, _ in ir.shape if el.expr.expr), env
            ),
            vol,
        ])

    if ir.is_binding and ir.is_binding != irast.BindingKind.Schema:
        vol = IMMUTABLE

    return vol


@_infer_volatility_inner.register
def __infer_func_call(
    ir: irast.FunctionCall,
    env: context.Environment,
) -> InferredVolatility:
    func_volatility = (
        _infer_volatility(ir.body, env) if ir.body else ir.volatility
    )

    if ir.args:
        return _max_volatility([
            _common_volatility((arg.expr for arg in ir.args.values()), env),
            func_volatility
        ])
    else:
        return func_volatility


@_infer_volatility_inner.register
def __infer_oper_call(
    ir: irast.OperatorCall,
    env: context.Environment,
) -> InferredVolatility:
    if ir.args:
        return _max_volatility([
            _common_volatility((arg.expr for arg in ir.args.values()), env),
            ir.volatility
        ])
    else:
        return ir.volatility


@_infer_volatility_inner.register
def __infer_const(
    ir: irast.BaseConstant,
    env: context.Environment,
) -> InferredVolatility:
    return IMMUTABLE


@_infer_volatility_inner.register
def __infer_param(
    ir: irast.QueryParameter,
    env: context.Environment,
) -> InferredVolatility:
    return STABLE if ir.is_global else IMMUTABLE


@_infer_volatility_inner.register
def __infer_function_param(
    ir: irast.FunctionParameter,
    env: context.Environment,
) -> InferredVolatility:
    return STABLE if ir.is_global else IMMUTABLE


@_infer_volatility_inner.register
def __infer_inlined_param(
    ir: irast.InlinedParameterExpr,
    env: context.Environment,
) -> InferredVolatility:
    return STABLE if ir.is_global else IMMUTABLE


@_infer_volatility_inner.register
def __infer_const_set(
    ir: irast.ConstantSet,
    env: context.Environment,
) -> InferredVolatility:
    return IMMUTABLE


@_infer_volatility_inner.register
def __infer_typecheckop(
    ir: irast.TypeCheckOp,
    env: context.Environment,
) -> InferredVolatility:
    return _infer_volatility(ir.left, env)


@_infer_volatility_inner.register
def __infer_typecast(
    ir: irast.TypeCast,
    env: context.Environment,
) -> InferredVolatility:
    return _infer_volatility(ir.expr, env)


@_infer_volatility_inner.register
def __infer_select_stmt(
    ir: irast.SelectStmt,
    env: context.Environment,
) -> InferredVolatility:
    components = []

    if ir.iterator_stmt is not None:
        components.append(ir.iterator_stmt)

    components.append(ir.result)

    if ir.where is not None:
        components.append(ir.where)

    if ir.orderby:
        components.extend(o.expr for o in ir.orderby)

    if ir.offset is not None:
        components.append(ir.offset)

    if ir.limit is not None:
        components.append(ir.limit)

    if ir.bindings is not None:
        components.extend(part for part, _ in ir.bindings)

    return _common_volatility(components, env)


@_infer_volatility_inner.register
def __infer_group_stmt(
    ir: irast.GroupStmt,
    env: context.Environment,
) -> InferredVolatility:
    components = [ir.subject, ir.result] + [v for v, _ in ir.using.values()]
    return _common_volatility(components, env)


@_infer_volatility_inner.register
def __infer_trigger_anchor(
    ir: irast.TriggerAnchor,
    env: context.Environment,
) -> InferredVolatility:
    return STABLE, STABLE


@_infer_volatility_inner.register
def __infer_searchable_string(
    ir: irast.FTSDocument,
    env: context.Environment,
) -> InferredVolatility:
    return _common_volatility([ir.text, ir.language], env)


@_infer_volatility_inner.register
def __infer_dml_stmt(
    ir: irast.MutatingStmt,
    env: context.Environment,
) -> InferredVolatility:
    # For materialization purposes, DML is not volatile.  (Since it
    # has its *own* elaborate mechanism using top-level CTEs).
    return MODIFYING, STABLE


@_infer_volatility_inner.register
def __infer_slice(
    ir: irast.SliceIndirection,
    env: context.Environment,
) -> InferredVolatility:
    # slice indirection volatility depends on the volatility of
    # the base expression and the slice index expressions
    args: list[irast.Base] = [ir.expr]
    if ir.start is not None:
        args.append(ir.start)
    if ir.stop is not None:
        args.append(ir.stop)

    return _common_volatility(args, env)


@_infer_volatility_inner.register
def __infer_index(
    ir: irast.IndexIndirection,
    env: context.Environment,
) -> InferredVolatility:
    # index indirection volatility depends on both the volatility of
    # the base expression and the index expression
    return _common_volatility([ir.expr, ir.index], env)


@_infer_volatility_inner.register
def __infer_array(
    ir: irast.Array,
    env: context.Environment,
) -> InferredVolatility:
    return _common_volatility(ir.elements, env)


@_infer_volatility_inner.register
def __infer_tuple(
    ir: irast.Tuple,
    env: context.Environment,
) -> InferredVolatility:
    return _common_volatility(
        [el.val for el in ir.elements], env)


def _infer_volatility(
    ir: irast.Base,
    env: context.Environment,
) -> InferredVolatility:
    result = env.inferred_volatility.get(ir)
    if result is not None:
        return result

    result = _infer_volatility_inner(ir, env)

    env.inferred_volatility[ir] = result

    return result


def infer_volatility(
    ir: irast.Base,
    env: context.Environment,
    *,
    exclude_dml: bool=False,
) -> qltypes.Volatility:
    result = _normalize_volatility(_infer_volatility(ir, env))[exclude_dml]

    if result not in {VOLATILE, STABLE, IMMUTABLE, MODIFYING}:
        raise errors.QueryError(
            'could not determine the volatility of '
            'set produced by expression',
            span=ir.span)

    return result


================================================
FILE: edb/edgeql/compiler/normalization.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2020-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""EdgeQL expression normalization functions."""


from __future__ import annotations
from typing import (
    Any,
    Optional,
    AbstractSet,
    Mapping,
    Collection,
    cast,
)

import functools

from edb.common.ast import base

from edb.edgeql import ast as qlast
from edb.edgeql import parser as qlparser

from edb.schema import name as sn
from edb.schema import schema as s_schema
from edb.schema import functions as s_func
from edb.schema import utils as s_utils


@functools.singledispatch
def normalize(
    node: Any,
    *,
    schema: s_schema.Schema,
    modaliases: Mapping[Optional[str], str],
    localnames: AbstractSet[str] = frozenset(),
) -> None:
    raise AssertionError(f'normalize: cannot handle {node!r}')


def renormalize_compat(
    norm_qltree: qlast.Base_T,
    orig_text: str,
    *,
    schema: s_schema.Schema,
    localnames: AbstractSet[str] = frozenset(),
) -> qlast.Base_T:
    """Renormalize an expression normalized with imprint_expr_context().

    This helper takes the original, unmangled expression, an EdgeQL AST
    tree of the same expression mangled with `imprint_expr_context()`
    (which injects extra WITH MODULE clauses), and produces a normalized
    expression with explicitly qualified identifiers instead.  Old dumps
    are the main user of this facility.
    """
    orig_qltree = qlparser.parse_fragment(orig_text)

    norm_aliases: dict[Optional[str], str] = {}
    assert isinstance(norm_qltree, (
        qlast.Query, qlast.Command, qlast.DDLCommand
    ))
    for alias in (norm_qltree.aliases or ()):
        if isinstance(alias, qlast.ModuleAliasDecl):
            norm_aliases[alias.alias] = alias.module

    if isinstance(orig_qltree, (
        qlast.Query, qlast.Command, qlast.DDLCommand
    )):
        orig_aliases: dict[Optional[str], str] = {}
        for alias in (orig_qltree.aliases or ()):
            if isinstance(alias, qlast.ModuleAliasDecl):
                orig_aliases[alias.alias] = alias.module

        modaliases = {
            k: v
            for k, v in norm_aliases.items()
            if k not in orig_aliases
        }
    else:
        modaliases = norm_aliases

    normalize(
        orig_qltree,
        schema=schema,
        modaliases=modaliases,
        localnames=localnames,
    )

    assert isinstance(orig_qltree, type(norm_qltree))
    return cast(qlast.Base_T, orig_qltree)


def _normalize_recursively(
    node: qlast.Base,
    value: Any,
    *,
    schema: s_schema.Schema,
    modaliases: Mapping[Optional[str], str],
    localnames: AbstractSet[str] = frozenset(),
) -> None:
    # We only want to handle fields that need to be traversed
    # recursively: Base AST and lists. Other fields are essentially
    # expected to be processed by the more specific handlers.
    if isinstance(value, qlast.Base):
        normalize(
            value,
            schema=schema,
            modaliases=modaliases,
            localnames=localnames,
        )
    elif isinstance(value, (tuple, list)):
        if value and isinstance(value[0], qlast.Base):
            for el in value:
                normalize(
                    el,
                    schema=schema,
                    modaliases=modaliases,
                    localnames=localnames,
                )


@normalize.register
def normalize_generic(
    node: qlast.Base,
    *,
    schema: s_schema.Schema,
    modaliases: Mapping[Optional[str], str],
    localnames: AbstractSet[str] = frozenset(),
    skip: Collection[str] = frozenset(),
) -> None:
    for field, value in base.iter_fields(node):
        if field not in skip:
            _normalize_recursively(
                node,
                value,
                schema=schema,
                modaliases=modaliases,
                localnames=localnames,
            )


# This is the heart of the whole thing.
@normalize.register
def normalize_ObjectRef(
    ref: qlast.ObjectRef,
    *,
    schema: s_schema.Schema,
    modaliases: Mapping[Optional[str], str],
    localnames: AbstractSet[str] = frozenset(),
) -> None:
    if ref.name not in localnames:
        obj = schema.get(
            s_utils.ast_ref_to_name(ref),
            default=None,
            module_aliases=modaliases,
        )
        if obj is not None:
            name = obj.get_name(schema)
            assert isinstance(name, sn.QualName)
            ref.module = name.module
        elif ref.module in modaliases:
            # Even if the name was not resolved in the
            # schema it may be the name of the object
            # being defined, as such the default module
            # should be used. Names that must be ignored
            # (like aliases and parameters) have already
            # been filtered by the localnames.
            ref.module = modaliases[ref.module]


def _normalize_with_block(
    node: qlast.Query,
    *,
    field: str='aliases',
    schema: s_schema.Schema,
    modaliases: Mapping[Optional[str], str],
    localnames: AbstractSet[str] = frozenset(),
) -> tuple[Mapping[Optional[str], str], AbstractSet[str]]:

    # Update the default aliases, modaliases, and localnames.
    modaliases = dict(modaliases)
    newaliases: list[qlast.AliasedExpr | qlast.ModuleAliasDecl] = []

    aliases: Optional[list[qlast.AliasedExpr]] = getattr(node, field)
    for alias in (aliases or ()):
        if isinstance(alias, qlast.ModuleAliasDecl):
            if alias.alias:
                modaliases[alias.alias] = alias.module
            else:
                modaliases[None] = alias.module
        else:
            assert isinstance(alias, qlast.AliasedExpr)
            normalize(
                alias.expr,
                schema=schema,
                modaliases=modaliases,
                localnames=localnames,
            )
            newaliases.append(alias)
            localnames = {alias.alias} | localnames

    setattr(node, field, newaliases)

    return modaliases, localnames


def _normalize_aliased_field(
    node: qlast.SubjectQuery | qlast.ReturningQuery,
    fname: str,
    *,
    schema: s_schema.Schema,
    modaliases: Mapping[Optional[str], str],
    localnames: AbstractSet[str] = frozenset(),
) -> AbstractSet[str]:

    # Potentially the result defines an alias that is visible in other
    # clauses
    val = getattr(node, fname)
    normalize(
        val,
        schema=schema,
        modaliases=modaliases,
        localnames=localnames,
    )
    alias = getattr(node, f'{fname}_alias', None)
    if alias:
        localnames = {alias} | localnames

    return localnames


@normalize.register
def normalize_SelectQuery(
    node: qlast.SelectQuery,
    *,
    schema: s_schema.Schema,
    modaliases: Mapping[Optional[str], str],
    localnames: AbstractSet[str] = frozenset(),
) -> None:

    # Process WITH block
    modaliases, localnames = _normalize_with_block(
        node,
        schema=schema,
        modaliases=modaliases,
        localnames=localnames,
    )

    # Process the result expression
    localnames = _normalize_aliased_field(
        node,
        'result',
        schema=schema,
        modaliases=modaliases,
        localnames=localnames,
    )

    normalize_generic(
        node,
        schema=schema,
        modaliases=modaliases,
        localnames=localnames,
        skip=('aliases', 'result'),
    )


@normalize.register(qlast.InsertQuery)
@normalize.register(qlast.UpdateQuery)
@normalize.register(qlast.DeleteQuery)
def normalize_DML(
    node: qlast.InsertQuery | qlast.UpdateQuery | qlast.DeleteQuery,
    *,
    schema: s_schema.Schema,
    modaliases: Mapping[Optional[str], str],
    localnames: AbstractSet[str] = frozenset(),
) -> None:

    # Process WITH block
    modaliases, localnames = _normalize_with_block(
        node,
        schema=schema,
        modaliases=modaliases,
        localnames=localnames,
    )

    normalize_generic(
        node,
        schema=schema,
        modaliases=modaliases,
        localnames=localnames,
        skip=('aliases',),
    )


@normalize.register
def normalize_ForQuery(
    node: qlast.ForQuery,
    *,
    schema: s_schema.Schema,
    modaliases: Mapping[Optional[str], str],
    localnames: AbstractSet[str] = frozenset(),
) -> None:

    # Process WITH block
    modaliases, localnames = _normalize_with_block(
        node,
        schema=schema,
        modaliases=modaliases,
        localnames=localnames,
    )

    # Process the iterator expression
    localnames = _normalize_aliased_field(
        node,
        'iterator',
        schema=schema,
        modaliases=modaliases,
        localnames=localnames,
    )

    # Process the rest
    normalize_generic(
        node,
        schema=schema,
        modaliases=modaliases,
        localnames=localnames,
        skip=('aliases', 'iterator'),
    )


@normalize.register
def normalize_GroupQuery(
    node: qlast.GroupQuery,
    *,
    schema: s_schema.Schema,
    modaliases: Mapping[Optional[str], str],
    localnames: AbstractSet[str] = frozenset(),
) -> None:
    # Process WITH block
    modaliases, localnames = _normalize_with_block(
        node,
        schema=schema,
        modaliases=modaliases,
        localnames=localnames,
    )

    # Process the subject expression
    localnames = _normalize_aliased_field(
        node,
        'subject',
        schema=schema,
        modaliases=modaliases,
        localnames=localnames,
    )

    modaliases, localnames = _normalize_with_block(
        node,
        field='using',
        schema=schema,
        modaliases=modaliases,
        localnames=localnames,
    )

    normalize_generic(
        node,
        schema=schema,
        modaliases=modaliases,
        localnames=localnames,
        skip=('aliases', 'subject', 'using'),
    )


@normalize.register
def normalize_FunctionCall(
    node: qlast.FunctionCall,
    *,
    schema: s_schema.Schema,
    modaliases: Mapping[Optional[str], str],
    localnames: AbstractSet[str] = frozenset(),
) -> None:

    if node.func not in localnames:
        name = (
            sn.UnqualName(node.func) if isinstance(node.func, str)
            else sn.QualName(*node.func)
        )
        funcs = s_func.lookup_functions(
            name, default=tuple(), module_aliases=modaliases, schema=schema,
        )
        if funcs:
            # As long as we found some functions, they will be from
            # the same module (the first valid resolved module for the
            # function name will mask "std").
            sname = funcs[0].get_shortname(schema)
            node.func = (sname.module, sname.name)

        elif modaliases and isinstance(name, sn.QualName):
            # Even if no function was found, apply the modaliases.
            # It is possible that a function without the modalias exists but
            # we don't want to find that.
            #
            # Eg.
            # module A {
            #     function foo() -> int64 using (1);
            # }
            # module B {}
            # module default {
            #     alias query := (with A as module B select A::foo() );
            # }
            module = s_schema.apply_module_aliases(name.module, modaliases)
            if module is not None:
                node.func = (module, name.name)

        # It's odd we don't find a function, but this will be picked up
        # by the compiler with a more appropriate error message.

    for arg in node.args:
        normalize(
            arg,
            schema=schema,
            modaliases=modaliases,
            localnames=localnames,
        )

    for val in node.kwargs.values():
        normalize(
            val,
            schema=schema,
            modaliases=modaliases,
            localnames=localnames,
        )


@normalize.register
def normalize_TypeName(
    node: qlast.TypeName,
    *,
    schema: s_schema.Schema,
    modaliases: Mapping[Optional[str], str],
    localnames: AbstractSet[str] = frozenset(),
) -> None:

    # Resolve the main type
    if isinstance(node.maintype, qlast.ObjectRef):
        # This is a specific path root, resolve it.
        if (
            # maintype names 'array', 'tuple', 'range', and 'multirange'
            # specifically should also be ignored
            node.maintype.name not in {
                'array', 'tuple', 'range', 'multirange',
            }
        ):
            normalize(
                node.maintype,
                schema=schema,
                modaliases=modaliases,
                localnames=localnames,
            )

    normalize_generic(
        node,
        schema=schema,
        modaliases=modaliases,
        localnames=localnames,
        skip=('maintype',),
    )


================================================
FILE: edb/edgeql/compiler/options.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""EdgeQL compiler options."""


from __future__ import annotations
from typing import Any, Optional, Mapping, Collection, TYPE_CHECKING

from dataclasses import dataclass, field as dc_field

if TYPE_CHECKING:
    from edb.schema import functions as s_func
    from edb.schema import objects as s_obj
    from edb.schema import name as s_name
    from edb.schema import types as s_types
    from edb.schema import pointers as s_pointers
    from edb.ir import pathid
    from edb.edgeql import qltypes

    SourceOrPathId = s_types.Type | s_pointers.Pointer | pathid.PathId


@dataclass(kw_only=True)
class GlobalCompilerOptions:
    """Compiler toggles that affect compilation as a whole."""

    #: Whether to allow the expression to be of a generic type.
    allow_generic_type_output: bool = False

    #: Whether to apply various query rewrites, including access policy.
    apply_query_rewrites: bool = True

    #: Whether to apply user-specified access policies
    apply_user_access_policies: bool = True

    #: Whether to allow specifying 'id' explicitly in INSERT
    allow_user_specified_id: bool = False

    #: Force types of all parameters to std::json
    json_parameters: bool = False

    #: Use material types for pointer targets in schema views.
    schema_view_mode: bool = False

    #: True in compile_bootstrap_script().
    bootstrap_mode: bool = False

    #: Whether to track which subexpressions reference each schema object.
    track_schema_ref_exprs: bool = False

    #: If the expression is being processed in the context of a certain
    #: schema object, i.e. a constraint expression, or a pointer default,
    #: this contains the type of the schema object.
    schema_object_context: Optional[type[s_obj.Object]] = None

    #: When compiling a function body, the function name.
    func_name: Optional[s_name.QualName] = None

    #: When compiling a function body, specifies function parameter
    #: definitions.
    func_params: Optional[s_func.ParameterLikeList] = None

    #: Should the backend compiler expand inheritance CTEs in place.
    #: This is needed by EXPLAIN to maintain alias names in
    #: the query plan.
    is_explain: bool = False

    #: The name that can be used in a "DML is disallowed in ..."
    #: error. When this is not None, any DML should cause an error.
    in_ddl_context_name: Optional[str] = None

    #: Whether to just treat all globals as empty instead of compiling them.
    #: This is used when populating something using `SET default` in DDL.
    make_globals_empty: bool = False

    #: Is the compiler running in testmode
    testmode: bool = False

    #: Is the compiler running in the server's schema reflection mode
    schema_reflection_mode: bool = False

    #: Are we invoking the compiler from inside a CONFIGURE?
    in_server_config_op: bool = False

    #: Is this restoring a dump?
    dump_restore_mode: bool = False


@dataclass(kw_only=True)
class CompilerOptions(GlobalCompilerOptions):

    #: Module name aliases.
    modaliases: Mapping[Optional[str], str] = dc_field(default_factory=dict)

    #: External symbol table.
    anchors: Mapping[str, Any] = dc_field(default_factory=dict)

    #: The symbol to assume as the prefix for abbreviated paths.
    path_prefix_anchor: Optional[str] = None

    #: Module to put derived schema objects to.
    derived_target_module: Optional[str] = None

    #: The name to use for the top-level type variant.
    result_view_name: Optional[s_name.QualName] = None

    #: If > 0, inject an implicit LIMIT into every SELECT query.
    implicit_limit: int = 0

    #: Include id property in every shape implicitly.
    implicit_id_in_shapes: bool = False

    #: Include __tid__ computable (.__type__.id) in every shape implicitly.
    implicit_tid_in_shapes: bool = False

    #: Include __tname__ computable (.__type__.name) in every shape implicitly.
    implicit_tname_in_shapes: bool = False

    #: A set of schema types and links that should be treated
    #: as singletons in the context of this compilation.
    #: If a tuple is provided, the boolean argument indicates it is optional.
    singletons: Collection[
        SourceOrPathId | tuple[SourceOrPathId, bool]
    ] = frozenset()

    #: Type references that should be remapped to another type.  This
    #: is for dealing with remapping explicit type names in schema
    #: expressions to their subtypes when necessary.
    type_remaps: dict[s_obj.Object, s_obj.Object] = dc_field(
        default_factory=dict
    )

    detached: bool = False

    #: In order to prevent recursive triggers, these fields are used to track
    #: the sources of a given trigger. These will only be present if
    #: schema_object_context is set to Trigger.
    trigger_type: Optional[s_types.Type] = None
    trigger_kinds: Optional[Collection[qltypes.TriggerKind]] = None


================================================
FILE: edb/edgeql/compiler/pathctx.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""EdgeQL compiler path scope helpers."""


from __future__ import annotations

from typing import Literal, Optional, AbstractSet

from edb import errors

from edb.ir import ast as irast
from edb.ir import typeutils as irtyputils

from edb.schema import name as s_name
from edb.schema import pointers as s_pointers
from edb.schema import types as s_types

from . import context


def get_path_id(
    stype: s_types.Type,
    *,
    typename: Optional[s_name.QualName] = None,
    ctx: context.ContextLevel,
) -> irast.PathId:
    return irast.PathId.from_type(
        ctx.env.schema,
        stype,
        typename=typename,
        env=ctx.env,
        namespace=ctx.path_id_namespace)


def get_tuple_indirection_path_id(
    tuple_path_id: irast.PathId,
    element_name: str,
    element_type: s_types.Type,
    *,
    ctx: context.ContextLevel,
) -> irast.PathId:

    ctx.env.schema, src_t = irtyputils.ir_typeref_to_type(
        ctx.env.schema, tuple_path_id.target)
    ptrcls = irast.TupleIndirectionLink(
        src_t,
        element_type,
        element_name=element_name,
    )

    ptrref = irtyputils.ptrref_from_ptrcls(
        schema=ctx.env.schema,
        ptrcls=ptrcls,
        cache=ctx.env.ptr_ref_cache,
        typeref_cache=ctx.env.type_ref_cache,
    )

    return tuple_path_id.extend(ptrref=ptrref)


def get_expression_path_id(
    stype: s_types.Type,
    alias: Optional[str] = None,
    *,
    ctx: context.ContextLevel,
) -> irast.PathId:
    if alias is None:
        alias = ctx.aliases.get('expr')
    typename = s_name.QualName(module='__derived__', name=alias)
    return get_path_id(stype, typename=typename, ctx=ctx)


def register_set_in_scope(
    ir_set: irast.Set,
    *,
    path_scope: Optional[irast.ScopeTreeNode] = None,
    optional: bool = False,
    ctx: context.ContextLevel,
) -> None:
    if path_scope is None:
        path_scope = ctx.path_scope

    path_scope.attach_path(
        ir_set.path_id,
        optional=optional,
        span=ir_set.span,
        ctx=ctx,
    )


def assign_set_scope(
    ir_set: irast.Set,
    scope: Optional[irast.ScopeTreeNode],
    *,
    ctx: context.ContextLevel,
) -> irast.Set:
    if scope is None:
        ir_set.path_scope_id = None
    else:
        if scope.unique_id is None:
            scope.unique_id = ctx.scope_id_ctr.nextval()
            ctx.env.scope_tree_nodes[scope.unique_id] = scope
        ir_set.path_scope_id = scope.unique_id
        if scope.find_child(ir_set.path_id):
            raise RuntimeError('scoped set must not contain itself')

    return ir_set


def get_set_scope(
    ir_set: irast.Set,
    *,
    ctx: context.ContextLevel,
) -> Optional[irast.ScopeTreeNode]:
    if ir_set.path_scope_id is None:
        return None
    else:
        scope = ctx.env.scope_tree_nodes.get(ir_set.path_scope_id)
        if scope is None:
            raise errors.InternalServerError(
                f'dangling scope pointer to node with uid'
                f':{ir_set.path_scope_id} in {ir_set!r}'
            )
        return scope


def extend_path_id(
    path_id: irast.PathId,
    *,
    ptrcls: s_pointers.PointerLike,
    direction: s_pointers.PointerDirection = (
        s_pointers.PointerDirection.Outbound),
    ns: AbstractSet[str] = frozenset(),
    ctx: context.ContextLevel,
) -> irast.PathId:
    """A wrapper over :meth:`ir.pathid.PathId.extend` that also ensures
       the cardinality of *ptrcls* is known at the end of compilation.
    """

    ptrref = irtyputils.ptrref_from_ptrcls(
        schema=ctx.env.schema,
        ptrcls=ptrcls,
        cache=ctx.env.ptr_ref_cache,
        typeref_cache=ctx.env.type_ref_cache,
    )

    return path_id.extend(ptrref=ptrref, direction=direction, ns=ns)


def ban_inserting_path(
    path_id: irast.PathId,
    *,
    location: Literal['body'] | Literal['else'],
    ctx: context.ContextLevel,
) -> None:

    ctx.inserting_paths = ctx.inserting_paths.copy()
    ctx.inserting_paths[path_id] = location


def path_is_inserting(
    path_id: irast.PathId, *, ctx: context.ContextLevel
) -> bool:

    node = ctx.path_scope.find_visible(path_id)
    return bool(
        node
        and node.path_id
        and ctx.inserting_paths.get(node.path_id) == 'body'
    )


================================================
FILE: edb/edgeql/compiler/policies.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""EdgeQL access policy compilation."""


from __future__ import annotations

from typing import Optional

from edb.ir import ast as irast

from edb.schema import name as s_name
from edb.schema import objtypes as s_objtypes
from edb.schema import policies as s_policies
from edb.schema import schema as s_schema
from edb.schema import types as s_types
from edb.schema import expr as s_expr

from edb.edgeql import ast as qlast
from edb.edgeql import qltypes

from edb.ir import typeutils as irtyputils

from . import astutils
from . import context
from . import dispatch
from . import setgen


def should_ignore_rewrite(
    stype: s_types.Type,
    *,
    ctx: context.ContextLevel,
) -> bool:
    if not ctx.suppress_rewrites:
        return False

    if stype in ctx.suppress_rewrites:
        return True

    # If we are in any access policy at all, suppress all
    # policies except the stdlib ones.
    #
    # (Eventually we might do a generalization of this based on
    # RBAC ownership of schema objects.)
    # XXX: extension modules???
    schema = ctx.env.schema
    if (
        isinstance(stype, s_objtypes.ObjectType)
        and s_name.UnqualName(stype.get_name(schema).module)
            not in s_schema.STD_MODULES
    ):
        return True

    return False


def get_access_policies(
    stype: s_objtypes.ObjectType,
    *,
    ctx: context.ContextLevel,
) -> tuple[s_policies.AccessPolicy, ...]:
    schema = ctx.env.schema
    if not ctx.env.options.apply_query_rewrites:
        return ()

    # The apply_access_policies config flag disables user-specified
    # access policies, but not stdlib ones.
    if (
        not ctx.env.options.apply_user_access_policies
        and s_name.UnqualName(stype.get_name(schema).module)
            not in s_schema.STD_MODULES
    ):
        return ()

    return stype.get_access_policies(schema).objects(schema)


def has_own_policies(
    *,
    stype: s_objtypes.ObjectType,
    skip_from: Optional[s_objtypes.ObjectType] = None,
    ctx: context.ContextLevel,
) -> bool:
    # TODO: some kind of caching or precomputation

    schema = ctx.env.schema
    for pol in get_access_policies(stype, ctx=ctx):
        if not any(
            skip_from == base.get_subject(schema)
            for base in pol.get_bases(schema).objects(schema)
        ):
            return True

    return any(
        has_own_policies(stype=child, skip_from=stype, ctx=ctx)
        for child in stype.children(schema)
        if not irtyputils.is_excluded_cfg_view(
            child, ancestor=stype, schema=schema
        )
    )


def compile_pol(
    pol: s_policies.AccessPolicy,
    *,
    ctx: context.ContextLevel,
) -> irast.Set:
    """Compile the condition from an individual policy.

    A policy is evaluated in a context where it is allowed to access
    the *original subject type of the policy* and *all of its
    descendants*.

    Because it is based on the original source of the policy,
    we need to compile each policy separately.
    """
    schema = ctx.env.schema

    expr_field: Optional[s_expr.Expression] = pol.get_expr(schema)
    if expr_field:
        expr = expr_field.parse()
    else:
        expr = qlast.Constant.boolean(True)

    if condition := pol.get_condition(schema):
        assert isinstance(condition, s_expr.Expression)
        expr = qlast.BinOp(op='AND', left=condition.parse(), right=expr)

    # Find all descendants of the original subject of the rule
    subject = pol.get_original_subject(schema)
    descs = {subject} | {
        desc for desc in subject.descendants(schema)
        if desc.is_material_object_type(schema)
    }

    # Compile it with rewrites suppressed for the subject and all of
    # its descendants.
    with ctx.newscope(fenced=True) as _, _.detached() as dctx:
        dctx.partial_path_prefix = ctx.partial_path_prefix
        dctx.expr_exposed = context.Exposure.UNEXPOSED
        dctx.suppress_rewrites = frozenset(descs)

        return setgen.scoped_set(dispatch.compile(expr, ctx=dctx), ctx=dctx)


def get_extra_function_rewrite_filter(ctx: context.ContextLevel) -> qlast.Expr:
    # Functions need to check whether access policies are disabled,
    # which is signalled through a field in globals json object.
    # It's only populated when policies are disabled.
    #
    # We could also have done this by checking
    # cfg::Config.apply_access_policies, but that's probably slower,
    # and we have this mechanism anyway.
    json_type = qlast.TypeName(maintype=qlast.ObjectRef(
        module='__std__', name='json'))
    glob_set = setgen.get_func_global_json_arg(ctx=ctx)
    func_override = qlast.FunctionCall(
        func=('__std__', 'json_get'),
        args=[
            ctx.create_anchor(glob_set, 'a'),
            qlast.Constant.string(value="__disable_access_policies"),
        ],
        kwargs={
            'default': qlast.TypeCast(
                expr=qlast.Constant.boolean(False),
                type=json_type,
            )
        },
    )
    return qlast.TypeCast(
        expr=func_override,
        type=qlast.TypeName(maintype=qlast.ObjectRef(
            module='__std__', name='bool'))
    )


def get_rewrite_filter(
    stype: s_objtypes.ObjectType,
    *,
    mode: qltypes.AccessKind,
    ctx: context.ContextLevel,
) -> Optional[qlast.Expr]:
    schema = ctx.env.schema
    pols = get_access_policies(stype, ctx=ctx)
    if not pols:
        return None

    ctx.anchors = ctx.anchors.copy()

    allow, deny = [], []
    for pol in pols:
        if mode not in pol.get_access_kinds(schema):
            continue

        ir_set = compile_pol(pol, ctx=ctx)
        expr: qlast.Expr = ctx.create_anchor(ir_set, move_scope=True)

        is_allow = pol.get_action(schema) == qltypes.AccessPolicyAction.Allow
        if is_allow:
            allow.append(expr)
        else:
            deny.append(expr)

    if ctx.env.options.func_params is not None:
        allow.append(get_extra_function_rewrite_filter(ctx))

    if allow:
        filter_expr = astutils.extend_binop(None, *allow, op='OR')
    else:
        filter_expr = qlast.Constant.boolean(False)

    if deny:
        deny_expr = qlast.UnaryOp(
            op='NOT',
            operand=astutils.extend_binop(None, *deny, op='OR')
        )
        filter_expr = astutils.extend_binop(filter_expr, deny_expr)

    # We compile the expression again so we can do an IR based
    # analysis on it below.
    with ctx.newscope(fenced=False) as dctx:
        # HACK: to prevent filter_ir from being warning fenced
        dctx.allow_factoring()
        dctx.expr_exposed = context.Exposure.UNEXPOSED
        filter_ir = dispatch.compile(filter_expr, ctx=dctx)
        filter_expr = setgen.moveable_anchor(filter_ir, ctx=dctx)

    # This is a bad hack, but add an always false condition that
    # postgres does not *know* is always false. This prevents postgres
    # from bogusly optimizing away the entire type CTE if it can prove
    # it empty (which could then result in assert_exists on links to
    # the type not always firing).
    #
    # As an optimization, we try to only do it when the object might
    # not be referenced.
    if (
        mode == qltypes.AccessKind.Select
        and not (
            ctx.partial_path_prefix
            and _always_references_set(filter_ir, ctx.partial_path_prefix)
        )
    ):
        bogus_check = qlast.BinOp(
            op='?=',
            left=qlast.Path(partial=True, steps=[qlast.Ptr(name='id')]),
            right=qlast.TypeCast(
                type=qlast.TypeName(maintype=qlast.ObjectRef(
                    module='__std__', name='uuid')),
                expr=qlast.Set(elements=[]),
            )
        )
        filter_expr = astutils.extend_binop(filter_expr, bogus_check, op='OR')

    return filter_expr


def _always_references_set(
    ir: irast.Set | irast.Expr,
    ref: irast.Set,
    inverted: bool = False,
) -> bool:
    """Return whether *ir* "always references" *ref*

    The idea here is to check whether *ir* references *ref* in a way
    that ensures that postgres will actually look at *ref* when
    executing.

    Fortunately postgres doesn't seem to do anything too crazy here(??),
    so we mostly just have to understand how it works with boolean
    operators, IF, and coalesce.
    But we also need to be able to propagate it through other operations.

    We need *ref* to be referenced in *every* conjunct of an AND.
    We need it referenced by *at least one* disjunct of an OR.
    But because of De Morgan's laws (which postgres understands),
    OR sometimes needs to be treated like AND.

    We track the invertedness and invert the AND and OR behavior when
    underneath a NOT, kind of for fun.
    """
    if isinstance(ir, irast.Set):
        if ir is ref:
            return True
        ir = ir.expr

    match ir:
        case irast.SelectStmt(result=result):
            return _always_references_set(result, ref, inverted)

        case irast.OperatorCall(
            func_shortname=s_name.QualName('std', 'OR'), args=args
        ):
            check = all if inverted else any
            return check(
                _always_references_set(x.expr, ref, inverted)
                for x in args.values()
            )

        case irast.OperatorCall(
            func_shortname=s_name.QualName('std', 'AND'), args=args
        ):
            check = any if inverted else all
            return check(
                _always_references_set(x.expr, ref, inverted)
                for x in args.values()
            )

        case irast.OperatorCall(
            func_shortname=s_name.QualName('std', 'NOT'), args={0: arg}
        ):
            return _always_references_set(arg.expr, ref, not inverted)

        case irast.OperatorCall(
            func_shortname=s_name.QualName('std', '??'), args={0: lhs, 1: _},
        ):
            # LHS always evaluated; RHS might not be
            return _always_references_set(lhs.expr, ref, inverted)

        case irast.OperatorCall(
            func_shortname=s_name.QualName('std', 'IF'),
            args={0: t, 1: c, 2: f},
        ):
            return (
                _always_references_set(c.expr, ref, inverted)
                or (
                    _always_references_set(t.expr, ref, inverted)
                    and _always_references_set(f.expr, ref, inverted)
                )
            )

        # Any other call, we use 'any' semantics.
        case irast.Call(args=args):
            return any(
                _always_references_set(x.expr, ref, inverted)
                for x in args.values()
            )

        case irast.Pointer(expr=expr) as p:
            if expr is not None:
                return _always_references_set(expr, ref, inverted)
            else:
                return _always_references_set(p.source, ref, inverted)

        case irast.TypeCast(expr=expr):
            return _always_references_set(expr, ref, inverted)

        case _:
            return False


def try_type_rewrite(
    stype: s_objtypes.ObjectType,
    *,
    skip_subtypes: bool,
    ctx: context.ContextLevel,
) -> None:
    schema = ctx.env.schema
    rw_key = (stype, skip_subtypes)
    type_rewrites = ctx.env.type_rewrites

    # Make sure the base types in unions and intersections have their
    # rewrites compiled
    if stype.is_compound_type(schema):
        type_rewrites[rw_key] = None
        objs = (
            stype.get_union_of(schema).objects(schema) +
            stype.get_intersection_of(schema).objects(schema)
        )
        for obj in objs:
            srw_key = (obj, skip_subtypes)
            if srw_key not in type_rewrites:
                try_type_rewrite(
                    stype=obj, skip_subtypes=skip_subtypes, ctx=ctx)
                # Mark this as having a real rewrite if any parts do
                if type_rewrites[srw_key]:
                    type_rewrites[rw_key] = True
        return

    # What we *hope* to do, is to just directly select from the view
    # for our type and apply filters to it.
    #
    # Note that this is mostly optimizing the size/complexity of the
    # output *text*, by using views instead of expanding it out
    # manually.
    #
    # If some of our children have their own policies, though, we want
    # to instead union together all of our children.
    #
    # But if that is the case, and some of our children have
    # overlapping descendants, then we can't do that either, so we
    # need to explicitly list out *all* of the descendants.
    children_have_policies = not skip_subtypes and any(
        has_own_policies(stype=child, skip_from=stype, ctx=ctx)
        for child in stype.children(schema)
        if not irtyputils.is_excluded_cfg_view(
            child, ancestor=stype, schema=schema
        )
    )

    pols = get_access_policies(stype, ctx=ctx)
    if not pols and not children_have_policies:
        type_rewrites[rw_key] = None
        return

    # TODO: caching?
    children_overlap = False
    if children_have_policies:
        all_child_descs = [
            x
            for child in stype.children(schema)
            for x in [child, *child.descendants(schema)]
        ]

        # Children overlap
        child_descs = set(all_child_descs)
        if len(child_descs) != len(all_child_descs):
            children_overlap = True

    # Put a placeholder to prevent recursion.
    type_rewrites[rw_key] = None

    sets = []
    # Generate the filters for the base type we are actually considering.
    # If the type is abstract, though, and there are policies on the children,
    # then we skip it.
    if not (children_have_policies and stype.get_abstract(schema)):
        with ctx.detached() as subctx:
            # We skip looking at subtypes in two cases:
            # 1. When some children have policies of their own, and thus
            #    need to be handled separately
            # 2. When skip_subtypes was set, and so we must skip them
            base_set = setgen.class_set(
                stype=stype,
                skip_subtypes=children_have_policies or skip_subtypes,
                ctx=subctx)

            if children_have_policies:
                # If children have policies, then all of the filtered sets
                # will be generated on skip_subtypes sets, so we don't have
                # any work to do.
                filtered_set = base_set
            else:
                # Otherwise, do the actual work of filtering.
                from . import clauses

                filtered_stmt = irast.SelectStmt(result=base_set)
                subctx.anchors['__subject__'] = base_set
                subctx.partial_path_prefix = base_set
                subctx.path_scope = subctx.env.path_scope.root.attach_fence()

                filtered_stmt.where = clauses.compile_where_clause(
                    get_rewrite_filter(
                        stype, mode=qltypes.AccessKind.Select, ctx=subctx),
                    ctx=subctx)

                filtered_set = setgen.scoped_set(filtered_stmt, ctx=subctx)

            sets.append(filtered_set)

    if children_have_policies and not skip_subtypes:
        # N.B.: we don't filter here; we just generate references, and
        # they will go in their own CTEs.
        children = (
            stype.children(schema)
            if not children_overlap
            else stype.descendants(schema)
        )
        # We need to explicitly exclude cfg views to prevent them
        # from showing up in type rewrites. Normally this happens in
        # inheritance.get_inheritance_view, but needs to happen here
        # when descendants are expanded.
        children = frozenset(
            child
            for child in children
            if not irtyputils.is_excluded_cfg_view(
                child, ancestor=stype, schema=schema
            )
        )
        sets += [
            # We need to wrap it in a type override so that unioning
            # them all together works...
            setgen.expression_set(
                setgen.ensure_stmt(
                    setgen.class_set(
                        stype=child, skip_subtypes=children_overlap, ctx=ctx),
                    ctx=ctx),
                type_override=stype,
                ctx=ctx,
            )
            for child in children
            if child.is_material_object_type(schema)
        ]

    # If we have multiple sets, union them together
    rewritten_set: Optional[irast.Set]
    if len(sets) > 1:
        with ctx.new() as subctx:
            subctx.expr_exposed = context.Exposure.UNEXPOSED
            subctx.anchors = subctx.anchors.copy()
            parts: list[qlast.Expr] = [subctx.create_anchor(x) for x in sets]
            rewritten_set = dispatch.compile(
                qlast.Set(elements=parts), ctx=subctx)
    elif len(sets) > 0:
        rewritten_set = sets[0]
    else:
        rewritten_set = None

    type_rewrites[rw_key] = rewritten_set


def compile_dml_write_policies(
    stype: s_objtypes.ObjectType,
    result: irast.Set,
    mode: qltypes.AccessKind,
    *,
    ctx: context.ContextLevel,
) -> Optional[irast.WritePolicies]:
    """Compile policy filters and wrap them into irast.WritePolicies"""
    pols = get_access_policies(stype, ctx=ctx)
    if not pols:
        return None

    with ctx.detached() as _, _.newscope(fenced=True) as subctx:
        # TODO: can we make sure to always avoid generating needless
        # select filters
        _prepare_dml_policy_context(stype, result, ctx=subctx)

        schema = subctx.env.schema
        subctx.anchors = subctx.anchors.copy()

        policies = []
        for pol in pols:
            if mode not in pol.get_access_kinds(schema):
                continue

            ir_set = compile_pol(pol, ctx=subctx)

            action = pol.get_action(schema)
            name = str(pol.get_shortname(schema))

            policies.append(
                irast.WritePolicy(
                    expr=ir_set,
                    action=action,
                    name=name,
                    error_msg=pol.get_errmessage(schema),
                )
            )

        return irast.WritePolicies(policies=policies)


def compile_dml_read_policies(
    stype: s_objtypes.ObjectType,
    result: irast.Set,
    mode: qltypes.AccessKind,
    *,
    ctx: context.ContextLevel,
) -> Optional[irast.ReadPolicyExpr]:
    """Compile a policy filter for a DML statement at a particular type"""
    if not get_access_policies(stype, ctx=ctx):
        return None

    with ctx.detached() as _, _.newscope(fenced=True) as subctx:
        # TODO: can we make sure to always avoid generating needless
        # select filters
        _prepare_dml_policy_context(stype, result, ctx=subctx)

        condition = get_rewrite_filter(stype, mode=mode, ctx=subctx)
        if not condition:
            return None

        return irast.ReadPolicyExpr(
            expr=setgen.scoped_set(
                dispatch.compile(condition, ctx=subctx), ctx=subctx
            ),
        )


def _prepare_dml_policy_context(
    stype: s_objtypes.ObjectType,
    result: irast.Set,
    *,
    ctx: context.ContextLevel,
) -> None:
    # It doesn't matter whether we skip subtypes here, so don't skip
    # subtypes if it has already been compiled that way, otherwise do.
    skip_subtypes = (stype, False) not in ctx.env.type_rewrites
    result = setgen.class_set(
        stype, path_id=result.path_id, skip_subtypes=skip_subtypes, ctx=ctx
    )

    ctx.anchors['__subject__'] = result
    ctx.partial_path_prefix = result


================================================
FILE: edb/edgeql/compiler/polyres.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""EdgeQL compiler routines for polymorphic call resolution."""


from __future__ import annotations

from typing import (
    Any,
    Optional,
    AbstractSet,
    Iterable,
    Mapping,
    Sequence,
    cast,
)
import dataclasses

import hashlib
import json

from edb import errors

from edb.ir import ast as irast
from edb.ir import utils as irutils

from edb.schema import functions as s_func
from edb.schema import name as sn
from edb.schema import types as s_types
from edb.schema import pseudo as s_pseudo
from edb.schema import expr as s_expr
from edb.schema import scalars as s_scalars
from edb.schema import schema as s_schema

from edb.edgeql import ast as qlast
from edb.edgeql import qltypes as ft

from . import context
from . import dispatch
from . import pathctx
from . import setgen
from . import tuple_args
from . import typegen


@dataclasses.dataclass(kw_only=True, frozen=True)
class BoundArg:
    """The base type for bound arguments for BoundCall."""

    val: irast.Set


@dataclasses.dataclass(kw_only=True, frozen=True)
class DefaultBitmask(BoundArg):
    """The default bitmask argument, if defaults are present."""
    pass


@dataclasses.dataclass(kw_only=True, frozen=True)
class ValueArg(BoundArg):
    """A bound argument with an actual value."""

    name: str

    orig_param_type: s_types.Type
    param_type: s_types.Type
    param_typemod: ft.TypeModifier
    param_kind: ft.ParameterKind

    val: irast.Set
    valtype: s_types.Type

    polymorphism: ft.Polymorphism = ft.Polymorphism.NotUsed


@dataclasses.dataclass(kw_only=True, frozen=True)
class DefaultArg(ValueArg):
    """A bound argument whose value comes from a default."""
    pass


@dataclasses.dataclass(kw_only=True, frozen=True)
class PassedArg(ValueArg):
    """A bound argument whose value comes from a passed argument."""

    cast_distance: int
    arg_id: int | str


@dataclasses.dataclass(frozen=True)
class MissingArg:
    """A placeholder for a parameter that received no argument."""

    param: Optional[s_func.ParameterLike]
    param_type: s_types.Type


@dataclasses.dataclass(kw_only=True, frozen=True)
class BoundCall:
    """The result of binding call arguments to a callable's parameters."""

    func: s_func.CallableLike
    args: list[BoundArg]
    null_args: set[str]
    return_type: s_types.Type
    variadic_arg_id: Optional[int]
    variadic_arg_count: Optional[int]
    return_polymorphism: ft.Polymorphism = ft.Polymorphism.NotUsed

    server_param_conversions: Optional[dict[
        str,
        dict[str, context.ServerParamConversion],
    ]] = None


_VARIADIC = ft.ParameterKind.VariadicParam
_NAMED_ONLY = ft.ParameterKind.NamedOnlyParam
_POSITIONAL = ft.ParameterKind.PositionalParam

_SET_OF = ft.TypeModifier.SetOfType
_OPTIONAL = ft.TypeModifier.OptionalType
_SINGLETON = ft.TypeModifier.SingletonType


def find_callable_typemods(
    candidates: Sequence[s_func.CallableLike],
    *,
    num_args: int,
    kwargs_names: AbstractSet[str],
    ctx: context.ContextLevel,
) -> dict[int | str, ft.TypeModifier]:
    """Find the type modifiers for a callable.

    We do this early, before we've compiled/checked the arguments,
    so that we can compile the arguments with the proper fences.
    """

    typ: s_types.Type = s_pseudo.PseudoType.get(ctx.env.schema, 'anytype')
    dummy = irast.DUMMY_SET
    args = [(typ, dummy)] * num_args
    kwargs = {k: (typ, dummy) for k in kwargs_names}
    options = find_callable(
        candidates, basic_matching_only=True, args=args, kwargs=kwargs, ctx=ctx
    )

    # No options means an error is going to happen later, but for now,
    # just return some placeholders so that we can make it to the
    # error later.
    if not options:
        return {k: _SINGLETON for k in set(range(num_args)) | kwargs_names}

    fts: dict[int | str, ft.TypeModifier] = {}
    for choice in options:
        for barg in choice.args:
            if not isinstance(barg, PassedArg):
                continue
            # Use a distinct local name so that the `ft` module alias
            # is not shadowed inside this function.
            typemod = barg.param_typemod
            if barg.arg_id in fts and fts[barg.arg_id] != typemod:
                if typemod == _SET_OF or fts[barg.arg_id] == _SET_OF:
                    raise errors.QueryError(
                        f'argument could be SET OF or not in call to '
                        f'{candidates[0].get_verbosename(ctx.env.schema)}: '
                        f'seems like a stdlib bug!')
                else:
                    # If there is a mix between OPTIONAL and SINGLETON
                    # arguments in possible call sites, we just call it
                    # optional. Generated code quality will be a little
                    # worse but still correct.
                    fts[barg.arg_id] = _OPTIONAL
            else:
                fts[barg.arg_id] = typemod

    return fts


def find_callable(
    candidates: Iterable[s_func.CallableLike],
    *,
    args: list[tuple[s_types.Type, irast.Set]],
    kwargs: dict[str, tuple[s_types.Type, irast.Set]],
    basic_matching_only: bool = False,
    ctx: context.ContextLevel,
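) -> list[BoundCall]:
    """Return the candidates that best match the given arguments.

    Candidates are ranked by total implicit cast distance; ties are
    broken by total common-parent type distance. The result is empty
    if nothing matches and has several entries if the call remains
    ambiguous.
    """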

    implicit_cast_distance = None
    matched = []

    candidates = list(candidates)
    for candidate in candidates:
        call = None
        if (
            not basic_matching_only
            and (conversion := _check_server_arg_conversion(
                candidate, args, kwargs, ctx=ctx
            ))
        ):
            # If there is a server param conversion, the argument should be
            # treated as if it has already been converted.
            #
            # This means we need to check the other candidates to see if they
            # match the converted args.
            converted_args, converted_kwargs, converted_params = conversion

            for alt_candidate in candidates:
                if alt_candidate is candidate:
                    continue
                if call := try_bind_call_args(
                    converted_args,
                    converted_kwargs,
                    alt_candidate,
                    basic_matching_only,
                    ctx=ctx,
                    server_param_conversions=converted_params,
                ):
                    # A call which matches the conversion exists.
                    # Add the server param conversions to the env.
                    break

        else:
            call = try_bind_call_args(
                args, kwargs, candidate, basic_matching_only, ctx=ctx)

        if call is None:
            continue

        total_cd = sum(
            barg.cast_distance
            for barg in call.args
            if isinstance(barg, PassedArg)
        )

        if implicit_cast_distance is None:
            implicit_cast_distance = total_cd
            matched.append(call)
        elif implicit_cast_distance == total_cd:
            matched.append(call)
        elif implicit_cast_distance > total_cd:
            implicit_cast_distance = total_cd
            matched = [call]

    if len(matched) <= 1:
        # Unambiguous resolution
        return matched

    else:
        # Ambiguous resolution, try to disambiguate by
        # checking for total type distance.
        type_dist = None
        remaining = []

        for call in matched:
            call_type_dist = 0

            for barg in call.args:
                if not isinstance(barg, PassedArg):
                    # Skip injected bitmask argument.
                    continue

                arg_type_dist = barg.valtype.get_common_parent_type_distance(
                    barg.orig_param_type, ctx.env.schema
                )
                call_type_dist += arg_type_dist

            if type_dist is None:
                type_dist = call_type_dist
                remaining.append(call)
            elif type_dist == call_type_dist:
                remaining.append(call)
            elif type_dist > call_type_dist:
                type_dist = call_type_dist
                remaining = [call]

        return remaining


def try_bind_call_args(
    args: Sequence[tuple[s_types.Type, irast.Set]],
    kwargs: Mapping[str, tuple[s_types.Type, irast.Set]],
    func: s_func.CallableLike,
    basic_matching_only: bool,
    *,
    ctx: context.ContextLevel,
    server_param_conversions: Optional[
        dict[str, dict[str, context.ServerParamConversion]]
    ] = None,
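) -> Optional[BoundCall]:
    """Try to bind the given arguments to the parameters of `func`.

    Return a BoundCall on success, or None if the arguments do not
    match the signature of `func`.
    """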

    return_type = func.get_return_type(ctx.env.schema)
    is_abstract = func.get_abstract(ctx.env.schema)
    resolved_poly_base_type: Optional[s_types.Type] = None

    def _get_cast_distance(
        arg: irast.Set,
        arg_type: s_types.Type,
        param_type: s_types.Type,
    ) -> int:
        nonlocal resolved_poly_base_type
        if basic_matching_only:
            return 0

        if in_polymorphic_func:
            # Compiling a body of a polymorphic function.

            if arg_type.is_polymorphic(schema):
                if param_type.is_polymorphic(schema):
                    if arg_type.test_polymorphic(schema, param_type):
                        return 0
                    else:
                        return -1
                else:
                    if arg_type.resolve_polymorphic(schema, param_type):
                        return 0
                    else:
                        return -1

        if param_type.is_polymorphic(schema):
            if not arg_type.test_polymorphic(schema, param_type):
                return -1

            resolved = param_type.resolve_polymorphic(schema, arg_type)
            if resolved is None:
                return -1

            if resolved_poly_base_type is None:
                resolved_poly_base_type = resolved

            if resolved_poly_base_type == resolved:
                if is_abstract:
                    return s_types.MAX_TYPE_DISTANCE
                elif arg_type.is_range() and param_type.is_multirange():
                    # Ranges are implicitly cast into multiranges of the same
                    # type, so they are compatible as far as polymorphic
                    # resolution goes, but it's still 1 cast.
                    return 1
                else:
                    return 0

            ctx.env.schema, ct = (
                resolved_poly_base_type.find_common_implicitly_castable_type(
                    resolved,
                    ctx.env.schema,
                )
            )

            if ct is not None:
                # If we found a common implicitly castable type, we
                # refine our resolved_poly_base_type to be that as the
                # more general case.
                resolved_poly_base_type = ct
            else:
                # Try resolving a polymorphic argument type against the
                # resolved base type. This lets us handle cases like
                #  - if b then x else {}
                #  - if b then [1] else []
                # Though it is still unfortunately not smart enough
                # to handle the reverse case.
                if resolved.is_polymorphic(schema):
                    ct = resolved.resolve_polymorphic(
                        schema, resolved_poly_base_type)

            if ct is not None:
                return s_types.MAX_TYPE_DISTANCE if is_abstract else 0
            else:
                return -1

        if arg_type.issubclass(schema, param_type):
            return 0

        return arg_type.get_implicit_cast_distance(param_type, schema)

    schema = ctx.env.schema

    in_polymorphic_func = (
        ctx.env.options.func_params is not None and
        ctx.env.options.func_params.has_polymorphic(schema)
    )

    variadic_arg_id: Optional[int] = None
    variadic_arg_count: Optional[int] = None
    no_args_call = not args and not kwargs
    has_inlined_defaults = (
        func.has_inlined_defaults(schema)
        and not (
            isinstance(func, s_func.Function)
            and (
                func.get_volatility(schema) == ft.Volatility.Modifying
                or func.get_is_inlined(schema)
            )
        )
    )

    func_params = func.get_params(schema)

    if not func_params:
        if no_args_call:
            # Match: `func` is a function without parameters
            # being called with no arguments.
            bargs: list[BoundArg] = []
            if has_inlined_defaults:
                bytes_t = ctx.env.get_schema_type_and_track(
                    sn.QualName('std', 'bytes'))
                typeref = typegen.type_to_typeref(bytes_t, env=ctx.env)
                argval = setgen.ensure_set(
                    irast.BytesConstant(value=b'\x00', typeref=typeref),
                    typehint=bytes_t,
                    ctx=ctx)
                bargs = [
                    DefaultBitmask(
                        val=argval,
                    )
                ]
            return BoundCall(
                func=func,
                args=bargs,
                null_args=set(),
                return_type=return_type,
                variadic_arg_id=None,
                variadic_arg_count=None,
                server_param_conversions=server_param_conversions,
            )
        else:
            # No match: `func` is a function without parameters
            # being called with some arguments.
            return None

    named_only = func_params.find_named_only(schema)

    if no_args_call and func_params.has_required_params(schema):
        # A call without arguments and there is at least
        # one parameter without default.
        return None

    bound_args_prep: list[MissingArg | PassedArg] = []

    params = func_params.get_in_canonical_order(schema)
    nparams = len(params)
    nargs = len(args)
    has_missing_args = False

    ai = 0
    pi = 0
    matched_kwargs = 0

    # Bind NAMED ONLY arguments (they are compiled as the first set of
    # arguments).
    while True:
        if pi >= nparams:
            break

        param = params[pi]
        if param.get_kind(schema) is not _NAMED_ONLY:
            break

        pi += 1

        param_shortname = param.get_parameter_name(schema)
        param_type = param.get_type(schema)
        param_typemod = param.get_typemod(schema)
        param_kind = param.get_kind(schema)
        if param_shortname in kwargs:
            matched_kwargs += 1

            arg_type, arg_val = kwargs[param_shortname]
            cd = _get_cast_distance(arg_val, arg_type, param_type)
            if cd < 0:
                return None

            bound_args_prep.append(
                PassedArg(
                    name=param_shortname,
                    orig_param_type=param_type,
                    param_type=param_type,
                    param_typemod=param_typemod,
                    param_kind=param_kind,
                    val=arg_val,
                    valtype=arg_type,
                    cast_distance=cd,
                    arg_id=param_shortname,
                )
            )

        else:
            if param.get_default(schema) is None:
                # required named parameter without default and
                # without a matching argument
                return None

            has_missing_args = True
            bound_args_prep.append(MissingArg(param, param_type))

    if matched_kwargs != len(kwargs):
        # extra kwargs?
        return None

    # Bind POSITIONAL arguments (compiled to go after NAMED ONLY arguments).
    while True:
        if ai < nargs:
            arg_type, arg_val = args[ai]
            ai += 1

            if pi >= nparams:
                # too many positional arguments
                return None
            param = params[pi]
            param_shortname = param.get_parameter_name(schema)
            param_type = param.get_type(schema)
            param_typemod = param.get_typemod(schema)
            param_kind = param.get_kind(schema)
            pi += 1

            if param_kind is _NAMED_ONLY:
                # impossible condition
                raise RuntimeError('unprocessed NAMED ONLY parameter')

            if param_kind is _VARIADIC:
                param_type = cast(s_types.Array, param_type)
                var_type = param_type.get_subtypes(schema)[0]
                cd = _get_cast_distance(arg_val, arg_type, var_type)
                if cd < 0:
                    return None

                bound_args_prep.append(
                    PassedArg(
                        name=param_shortname,
                        orig_param_type=param_type,
                        param_type=param_type,
                        param_typemod=param_typemod,
                        param_kind=param_kind,
                        val=arg_val,
                        valtype=arg_type,
                        cast_distance=cd,
                        arg_id=ai - 1,
                    )
                )

                for di, (arg_type, arg_val) in enumerate(args[ai:]):
                    cd = _get_cast_distance(arg_val, arg_type, var_type)
                    if cd < 0:
                        return None

                    bound_args_prep.append(
                        PassedArg(
                            name=param_shortname,
                            orig_param_type=param_type,
                            param_type=param_type,
                            param_typemod=param_typemod,
                            param_kind=param_kind,
                            val=arg_val,
                            valtype=arg_type,
                            cast_distance=cd,
                            arg_id=ai + di,
                        )
                    )

                variadic_arg_id = ai - 1
                variadic_arg_count = nargs - ai + 1

                break

            cd = _get_cast_distance(arg_val, arg_type, param_type)
            if cd < 0:
                return None

            bound_args_prep.append(
                PassedArg(
                    name=param_shortname,
                    orig_param_type=param_type,
                    param_type=param_type,
                    param_typemod=param_typemod,
                    param_kind=param_kind,
                    val=arg_val,
                    valtype=arg_type,
                    cast_distance=cd,
                    arg_id=ai - 1,
                )
            )

        else:
            break

    # Handle the remaining POSITIONAL and VARIADIC parameters.
    for i in range(pi, nparams):
        param = params[i]
        param_type = param.get_type(schema)
        param_kind = param.get_kind(schema)

        if param_kind is _POSITIONAL:
            if param.get_default(schema) is None:
                # required positional parameter that we don't have a
                # positional argument for.
                return None

            has_missing_args = True
            bound_args_prep.append(MissingArg(param, param_type))

        elif param_kind is _VARIADIC:
            variadic_arg_id = i
            variadic_arg_count = 0

        elif param_kind is _NAMED_ONLY:
            # impossible condition
            raise RuntimeError('unprocessed NAMED ONLY parameter')

    # Populate defaults.
    defaults_mask = 0
    null_args: set[str] = set()
    bound_param_args: list[BoundArg] = []
    if has_missing_args:
        if has_inlined_defaults or named_only:
            for i, prep_barg in enumerate(bound_args_prep):
                if isinstance(prep_barg, PassedArg):
                    bound_param_args.append(prep_barg)
                    continue
                if prep_barg.param is None:
                    # Shouldn't be possible; the code above takes care of this.
                    raise RuntimeError(
                        f'failed to resolve the parameter for the arg #{i}')

                param = prep_barg.param
                param_shortname = param.get_parameter_name(schema)
                param_type = param.get_type(schema)
                param_typemod = param.get_typemod(schema)
                param_kind = param.get_kind(schema)

                null_args.add(param_shortname)

                defaults_mask |= 1 << i

                if not has_inlined_defaults:
                    param_default: Optional[s_expr.Expression] = (
                        param.get_default(schema)
                    )
                    assert param_default is not None
                    default = compile_arg(
                        param_default.parse(), param_typemod, ctx=ctx)

                empty_default = (
                    has_inlined_defaults or
                    irutils.is_empty(default)
                )

                if empty_default and not basic_matching_only:
                    default_type = None

                    if param_type.is_any(schema):
                        if resolved_poly_base_type is None:
                            raise errors.QueryError(
                                f'could not resolve "anytype" type for the '
                                f'${param_shortname} parameter')
                        else:
                            default_type = resolved_poly_base_type
                    else:
                        default_type = param_type

                else:
                    default_type = param_type

                if has_inlined_defaults:
                    default = compile_arg(
                        qlast.TypeCast(
                            expr=qlast.Set(elements=[]),
                            type=typegen.type_to_ql_typeref(
                                default_type,
                                ctx=ctx,
                            ),
                        ),
                        ft.TypeModifier.OptionalType,
                        ctx=ctx,
                    )

                default = setgen.ensure_set(
                    default,
                    typehint=default_type,
                    ctx=ctx,
                )

                bound_param_args.append(
                    DefaultArg(
                        name=param_shortname,
                        orig_param_type=param_type,
                        param_type=param_type,
                        param_typemod=param_typemod,
                        param_kind=param_kind,
                        val=default,
                        valtype=param_type,
                    )
                )

        else:
            bound_param_args = [
                barg for barg in bound_args_prep if isinstance(barg, PassedArg)
            ]
    else:
        bound_param_args = cast(list[BoundArg], bound_args_prep)

    if has_inlined_defaults:
        # If we are compiling an EdgeQL function, inject the defaults
        # bitmask as the first argument.
        bytes_t = ctx.env.get_schema_type_and_track(
            sn.QualName('std', 'bytes'))
        bm = defaults_mask.to_bytes(nparams // 8 + 1, 'little')
        typeref = typegen.type_to_typeref(bytes_t, env=ctx.env)
        bm_set = setgen.ensure_set(
            irast.BytesConstant(value=bm, typeref=typeref),
            typehint=bytes_t, ctx=ctx)
        bound_param_args.insert(
            0,
            DefaultBitmask(
                val=bm_set,
            ),
        )

    return_polymorphism = ft.Polymorphism.NotUsed
    if return_type.is_polymorphic(schema):
        return_polymorphism = ft.Polymorphism.from_schema_type(return_type)

        if resolved_poly_base_type is not None:
            ctx.env.schema, return_type = return_type.to_nonpolymorphic(
                ctx.env.schema, resolved_poly_base_type)
        elif not in_polymorphic_func and not basic_matching_only:
            return None

    # resolved_poly_base_type may be legitimately None within
    # bodies of polymorphic functions
    if resolved_poly_base_type is not None:
        for i, barg in enumerate(bound_param_args):
            if (
                isinstance(barg, ValueArg)
                and barg.param_type.is_polymorphic(schema)
            ):
                ctx.env.schema, ptype = barg.param_type.to_nonpolymorphic(
                    ctx.env.schema, resolved_poly_base_type)
                polymorphism = ft.Polymorphism.from_schema_type(barg.param_type)
                bound_param_args[i] = dataclasses.replace(
                    barg,
                    param_type=ptype,
                    polymorphism=polymorphism,
                )

    return BoundCall(
        func=func,
        args=bound_param_args,
        null_args=null_args,
        return_type=return_type,
        variadic_arg_id=variadic_arg_id,
        variadic_arg_count=variadic_arg_count,
        return_polymorphism=return_polymorphism,
        server_param_conversions=server_param_conversions,
    )


def _check_server_arg_conversion(
    func: s_func.CallableLike,
    args: list[tuple[s_types.Type, irast.Set]],
    kwargs: dict[str, tuple[s_types.Type, irast.Set]],
    *,
    ctx: context.ContextLevel,
) -> Optional[tuple[
    Sequence[tuple[s_types.Type, irast.Set]],
    Mapping[str, tuple[s_types.Type, irast.Set]],
    dict[str, dict[str, context.ServerParamConversion]],
]]:
    """Check if there is a server param conversion and get the effective args.

    Server param conversion allows the server to replace a function arg with
    another parameter which it computes before executing the query.

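    For example, when `ext::ai::search(anyobject, str)` is called, the server
    gets an embedding vector for the string argument, which it then
    substitutes into a call to `ext::ai::search(anyobject, array)`.

    If conversions are applied, returns (args, kwargs, conversions), where
    args and kwargs contain new query parameters standing in for the
    converted arguments, and conversions maps each original query parameter
    name to its conversions, keyed by conversion name. Returns None
    otherwise.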
    """
    schema = ctx.env.schema

    func_params: s_func.FuncParameterList = cast(
        s_func.FuncParameterList,
        func.get_params(schema),
    )

    if arg_conversions_json := (
        isinstance(func, s_func.Function)
        and func.get_server_param_conversions(schema)
    ):
        curr_server_param_conversions: dict[
            str,
            dict[str, context.ServerParamConversion],
        ] = {}

        arg_conversions: dict[str, str | list[str]] = json.loads(
            arg_conversions_json
        )
        for arg_name, conversion_info in arg_conversions.items():
            if isinstance(conversion_info, str):
                conversion_name = conversion_info
            else:
                conversion_name = conversion_info[0]

            # Get the arg being converted
            arg_key: int | str
            arg: tuple[s_types.Type, irast.Set]
            param, arg_key, arg = _get_arg(
                func_params,
                arg_name,
                args,
                kwargs,
                error_msg=f'Server param conversion {conversion_name} error',
                schema=schema,
            )

            if arg[1].expr is None:
                # Dummy set, do nothing
                continue

            original_type: s_types.Type = arg[0].material_type(schema)[1]
            if original_type != param.get_type(schema):
                # Wrong param type, function candidate doesn't apply.
                # TODO: Check "any" params
                return None

            is_param_query_parameter = (
                isinstance(arg[1].expr, irast.QueryParameter)
                and not arg[1].expr.is_global
            )
            is_param_ir_constant = isinstance(arg[1].expr, irast.BaseConstant)
            if (
                not original_type.is_array()
                and not is_param_query_parameter
                and not is_param_ir_constant
            ):
                raise errors.QueryError(
                    f"Argument '{arg_name}' "
                    f"must be a constant or query parameter",
                    span=arg[1].expr.span,
                )
            elif (
                original_type.is_array()
                and not is_param_query_parameter
            ):
                # Array literals are normalized as expressions.
                # For now, they are not supported as constants.
                raise errors.QueryError(
                    f"Argument '{arg_name}' must be a query parameter",
                    span=arg[1].expr.span,
                )

            # Get info about the conversion
            converted_type, additional_info, conversion_volatility = (
                _resolve_server_param_conversion(
                    func_params,
                    args,
                    kwargs,
                    conversion_name,
                    schema=schema,
                    conversion_info=(
                        conversion_info
                        if isinstance(conversion_info, list)
                        else None
                    )
                )
            )

            query_param_name: str
            constant_value: Optional[Any] = None
            if isinstance(arg[1].expr, irast.BaseConstant):
                # Only string, integer, float, and decimal constants are
                # currently supported.
                constant_expr = arg[1].expr
                if isinstance(constant_expr, irast.StringConstant):
                    constant_value = constant_expr.value
                elif isinstance(
                    constant_expr, (irast.IntegerConstant, irast.BigintConstant)
                ):
                    constant_value = int(constant_expr.value)
                elif isinstance(
                    constant_expr, (irast.FloatConstant, irast.DecimalConstant)
                ):
                    constant_value = float(constant_expr.value)
                else:
                    raise RuntimeError(
                        f'Unsupported constant argument: {arg_name}'
                    )
                # Use a hash of the text value as the name
                value_hash = (
                    hashlib.sha1(constant_expr.value.encode()).hexdigest()
                )
                query_param_name = f'const_{value_hash}'
            elif isinstance(arg[1].expr, irast.QueryParameter):
                query_param_name = arg[1].expr.name
            else:
                raise RuntimeError('Server param conversion has no parameter')

            # Create a substitute parameter set with the correct type
            existing_converted_path_id = None
            if (
                (curr_conversions := (
                    ctx.env.server_param_conversions.get(query_param_name, None)
                ))
                and (
                    existing_param_conversion := (
                        curr_conversions.get(conversion_name, None)
                    )
                )
            ):
                # If the param was converted in another call, reuse its path id
                existing_converted_path_id = existing_param_conversion.path_id

            converted_param_name = f'{query_param_name}~{conversion_name}'
            converted_required = (
                isinstance(arg[1].expr, irast.QueryParameter)
                and arg[1].expr.required
            )
            converted_typeref = typegen.type_to_typeref(
                converted_type, ctx.env
            )
            conversion_set: irast.Set = setgen.ensure_set(
                irast.QueryParameter(
                    name=converted_param_name,
                    required=converted_required,
                    typeref=converted_typeref,
                    span=arg[1].span,
                ),
                path_id=existing_converted_path_id,
                ctx=ctx,
            )

            if query_param_name not in curr_server_param_conversions:
                curr_server_param_conversions[query_param_name] = {}
            curr_conversions = (
                curr_server_param_conversions[query_param_name]
            )

            if existing_converted_path_id is None:
                # If this is the first time this conversion was applied to this
                # query param, save the conversion to be possibly reused by
                # another call.

                # Create the sub-params in case the resulting converted param
                # is a tuple. Currently, no such conversion exists, but this
                # is here to prepare for that distant future.
                sub_params = tuple_args.create_sub_params(
                    converted_param_name,
                    converted_required,
                    typeref=converted_typeref,
                    pt=converted_type,
                    is_func_param=True,
                    ctx=ctx
                )

                curr_conversions[conversion_name] = (
                    context.ServerParamConversion(
                        path_id=conversion_set.path_id,
                        ir_param=irast.Param(
                            name=converted_param_name,
                            required=converted_required,
                            schema_type=converted_type,
                            ir_type=converted_typeref,
                            sub_params=sub_params,
                        ),
                        additional_info=additional_info,
                        volatility=conversion_volatility,
                        script_param_index=(
                            list(ctx.env.script_params.keys()).index(
                                query_param_name
                            )
                            if query_param_name in ctx.env.script_params else
                            None
                        ),
                        constant_value=constant_value,
                    )
                )

                # Don't include the newly created irast.Param in
                # ctx.env.query_parameters.
                # Such parameters need to have a corresponding entry in
                # compiler.Context.Environment.script_params
                #
                # The parameters will be handled separately in fini_expression
                # and compile_ir_to_sql_tree.

            # Substitute the old arg
            if isinstance(arg_key, int):
                args = args.copy()
                args[arg_key] = (converted_type, conversion_set)
            else:
                kwargs = kwargs.copy()
                kwargs[arg_key] = (converted_type, conversion_set)

        if len(curr_server_param_conversions) != len(arg_conversions):
            # Not all conversions were applied, so this function candidate
            # doesn't apply.
            return None

        return args, kwargs, curr_server_param_conversions

    else:
        return None


def _resolve_server_param_conversion(
    func_params: s_func.FuncParameterList,
    args: list[tuple[s_types.Type, irast.Set]],
    kwargs: dict[str, tuple[s_types.Type, irast.Set]],
    conversion_name: str,
    *,
    schema: s_schema.Schema,
    conversion_info: Optional[list[str]] = None,
) -> tuple[
    s_types.Type,
    tuple[str, ...],
    ft.Volatility,
]:
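    """Describe the server param conversion named *conversion_name*.

    Returns the type of the converted parameter, any additional
    conversion-specific info, and the volatility of the conversion.
    Raises RuntimeError if the conversion name is not recognized.
    """
    converted_type: s_types.Type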
    additional_info: tuple[str, ...] = tuple()
    conversion_volatility: ft.Volatility

    if conversion_name == 'cast_int64_to_str':
        converted_type = schema.get(
            'std::str', type=s_scalars.ScalarType
        )
        conversion_volatility = ft.Volatility.Immutable

    elif conversion_name == 'cast_int64_to_str_volatile':
        converted_type = schema.get(
            'std::str', type=s_scalars.ScalarType
        )
        conversion_volatility = ft.Volatility.Volatile

    elif conversion_name == 'cast_int64_to_float64':
        converted_type = schema.get(
            'std::float64', type=s_scalars.ScalarType
        )
        conversion_volatility = ft.Volatility.Immutable

    elif conversion_name == 'join_str_array':
        assert conversion_info is not None
        separator = conversion_info[1]

        converted_type = schema.get(
            'std::str', type=s_scalars.ScalarType
        )
        additional_info = (separator,)
        conversion_volatility = ft.Volatility.Immutable

    elif conversion_name == 'ai_text_embedding':
        assert isinstance(conversion_info, list)
        object_param_name = conversion_info[1]

        converted_type = schema.get_global(
            s_types.Array,
            s_types.Array.generate_name(
                sn.QualName('std', 'float32')
            )
        )

        _, _, object_arg = _get_arg(
            func_params,
            object_param_name,
            args,
            kwargs,
            error_msg=f'Server param conversion {conversion_name} '
            f'error finding object argument',
            schema=schema,
        )

        object_type = object_arg[0].material_type(schema)[1]
        additional_info = (str(object_type.get_id(schema)),)
        conversion_volatility = ft.Volatility.Volatile

    else:
        raise RuntimeError(
            f'Unknown server param conversion: {conversion_name}'
        )

    return (
        converted_type,
        additional_info,
        conversion_volatility,
    )


def _get_arg(
    func_params: s_func.FuncParameterList,
    param_name: str,
    args: Sequence[tuple[s_types.Type, irast.Set]],
    kwargs: Mapping[str, tuple[s_types.Type, irast.Set]],
    *,
    error_msg: str,
    schema: s_schema.Schema,
) -> tuple[
    s_func.Parameter,
    int | str,
    tuple[s_types.Type, irast.Set],
]:
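    """Find the argument bound to the function parameter *param_name*.

    Returns the parameter, its key (an index into *args* for positional
    parameters, or a name in *kwargs* for named-only ones), and the
    matching ``(type, set)`` pair.  Raises RuntimeError if the parameter
    does not exist or is variadic.
    """
    param = func_params.get_by_name(name=param_name, schema=schema)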
    if param is None:
        raise RuntimeError(
            f'{error_msg}: missing param "{param_name}"'
        )

    param_kind = param.get_kind(schema)
    if param_kind == ft.ParameterKind.PositionalParam:
        param_index: int = param.get_num(schema)
        return param, param_index, args[param_index]
    elif param_kind == ft.ParameterKind.NamedOnlyParam:
        return param, param_name, kwargs[param_name]
    else:
        raise RuntimeError(
            f'{error_msg}: variadic param "{param_name}" not allowed'
        )


def compile_arg(
    arg_ql: qlast.Expr,
    typemod: ft.TypeModifier,
    *,
    prefer_subquery_args: bool=False,
    ctx: context.ContextLevel,
) -> irast.Set:
    fenced = typemod is ft.TypeModifier.SetOfType
    optional = typemod is ft.TypeModifier.OptionalType

    # Create a branch for OPTIONAL arguments, so that there is a place
    # in the scope tree to mark as optional.
    # Fenced (SET OF) arguments are instead wrapped in a SELECT below.
    #
    # We also put a branch when we are trying to compile the argument
    # into a subquery, so that things it uses get bound locally.
    branched = optional or (prefer_subquery_args and not fenced)

    new = ctx.newscope(fenced=False) if branched else ctx.new()
    with new as argctx:
        if optional:
            argctx.path_scope.mark_as_optional()

        if fenced:
            arg_ql = qlast.SelectQuery(
                result=arg_ql, span=arg_ql.span,
                implicit=True, rptr_passthrough=True)

        argctx.implicit_limit = 0

        arg_ir = dispatch.compile(arg_ql, ctx=argctx)

        if optional:
            pathctx.register_set_in_scope(arg_ir, optional=True, ctx=ctx)

            if arg_ir.path_scope_id is None:
                pathctx.assign_set_scope(arg_ir, argctx.path_scope, ctx=argctx)

        elif branched:
            arg_ir = setgen.scoped_set(arg_ir, ctx=argctx)

        return arg_ir


================================================
FILE: edb/edgeql/compiler/schemactx.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""EdgeQL compiler schema helpers."""


from __future__ import annotations

from typing import (
    Any,
    Callable,
    Optional,
    Iterable,
    Sequence,
    NamedTuple,
    cast,
)

from edb import errors

from edb.common import parsing
from edb.ir import typeutils

from edb.schema import links as s_links
from edb.schema import name as sn
from edb.schema import objects as s_obj
from edb.schema import objtypes as s_objtypes
from edb.schema import pointers as s_pointers
from edb.schema import pseudo as s_pseudo
from edb.schema import scalars as s_scalars
from edb.schema import sources as s_sources
from edb.schema import types as s_types
from edb.schema import utils as s_utils

from edb.edgeql import ast as qlast
from edb.edgeql import qltypes

from . import context


def get_schema_object(
    ref: qlast.BaseObjectRef,
    module: Optional[str]=None,
    *,
    item_type: Optional[type[s_obj.Object]]=None,
    condition: Optional[Callable[[s_obj.Object], bool]]=None,
    label: Optional[str]=None,
    ctx: context.ContextLevel,
    span: Optional[parsing.Span] = None,
) -> s_obj.Object:

    if isinstance(ref, qlast.ObjectRef):
        if span is None:
            span = ref.span
        module = ref.module
        lname = ref.name
    elif isinstance(ref, qlast.PseudoObjectRef):
        return s_pseudo.PseudoType.get(ctx.env.schema, ref.name)
    else:
        raise AssertionError(f"Unhandled BaseObjectRef subclass: {ref!r}")

    name: sn.Name
    if module:
        name = sn.QualName(module=module, name=lname)
    else:
        name = sn.UnqualName(name=lname)

    try:
        stype = ctx.env.get_schema_object_and_track(
            name=name,
            expr=ref,
            modaliases=ctx.modaliases,
            type=item_type,
            condition=condition,
            label=label,
        )

    except errors.QueryError as e:
        s_utils.enrich_schema_lookup_error(
            e,
            name,
            modaliases=ctx.modaliases,
            schema=ctx.env.schema,
            item_type=item_type,
            pointer_parent=_get_partial_path_prefix_type(ctx),
            condition=condition,
            span=span,
        )
        raise

    if stype == ctx.defining_view:
        # stype is the view in the process of being defined and as such is
        # not yet a valid schema object
        raise errors.SchemaDefinitionError(
            f'illegal self-reference in definition of {str(name)!r}',
            span=span)

    return stype


def _get_partial_path_prefix_type(
    ctx: context.ContextLevel,
) -> Optional[s_types.Type]:
    if ctx is None:
        return None
    ppp = ctx.partial_path_prefix
    if ppp is None or ppp.typeref is None:
        return None

    _, type = typeutils.ir_typeref_to_type(ctx.env.schema, ppp.typeref)
    return type


def get_schema_type(
    name: qlast.BaseObjectRef,
    module: Optional[str] = None,
    *,
    ctx: context.ContextLevel,
    label: Optional[str] = None,
    condition: Optional[Callable[[s_obj.Object], bool]] = None,
    item_type: Optional[type[s_obj.Object]] = None,
    span: Optional[parsing.Span] = None,
) -> s_types.Type:
    if item_type is None:
        item_type = s_types.Type
    obj = get_schema_object(name, module, item_type=item_type,
                            condition=condition, label=label,
                            ctx=ctx, span=span)
    assert isinstance(obj, s_types.Type)
    return obj


def resolve_schema_name(
    name: str, module: str, *, ctx: context.ContextLevel
) -> Optional[sn.QualName]:
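    """Qualify *name* using the module alias *module*.

    Returns None if *module* is not a known alias in this context.
    """
    schema_module = ctx.modaliases.get(module)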
    if schema_module is None:
        return None
    else:
        return sn.QualName(name=name, module=schema_module)


def preserve_view_shape(
    base: s_types.Type | s_pointers.Pointer,
    derived: s_types.Type | s_pointers.Pointer,
    *,
    derived_name_base: Optional[sn.Name] = None,
    ctx: context.ContextLevel,
) -> None:
    """Copy a view shape to a child type, updating the pointers."""
    new = []
    schema = ctx.env.schema
    for ptr, op in ctx.env.view_shapes[base]:
        target = ptr.get_target(ctx.env.schema)
        assert target
        schema, nptr = ptr.get_derived(
            schema, cast(s_sources.Source, derived), target,
            derived_name_base=derived_name_base)
        new.append((nptr, op))
    ctx.env.view_shapes[derived] = new
    if isinstance(base, s_types.Type) and isinstance(derived, s_types.Type):
        ctx.env.view_shapes_metadata[derived] = (
            ctx.env.view_shapes_metadata[base]).replace()

    # All of the pointers should already exist, so nothing should have
    # been created.
    assert schema is ctx.env.schema


def derive_view(
    stype: s_types.Type,
    *,
    derived_name: Optional[sn.QualName] = None,
    derived_name_quals: Optional[Sequence[str]] = (),
    preserve_shape: bool = False,
    exprtype: s_types.ExprType = s_types.ExprType.Select,
    inheritance_merge: bool = True,
    attrs: Optional[dict[str, Any]] = None,
    ctx: context.ContextLevel,
) -> s_types.Type:

    if derived_name is None:
        if isinstance(stype, s_obj.DerivableObject):
            derived_name = derive_view_name(
                stype=stype, derived_name_quals=derived_name_quals,
                ctx=ctx)
        else:
            derived_name = sn.QualName('__derived__', ctx.aliases.get('v'))

    if attrs is None:
        attrs = {}
    else:
        attrs = dict(attrs)

    attrs['expr_type'] = exprtype

    derived: s_types.Type

    if isinstance(stype, s_types.Collection):
        ctx.env.schema, derived = stype.derive_subtype(
            ctx.env.schema,
            name=derived_name,
            attrs=attrs,
        )

    elif isinstance(stype, (s_objtypes.ObjectType, s_scalars.ScalarType)):
        existing = ctx.env.schema.get(
            derived_name, default=None, type=type(stype))
        if existing is not None:
            if ctx.recompiling_schema_alias:
                # When recompiling a schema alias, we essentially
                # re-derive the already-existing objects exactly.
                derived = existing
            else:
                raise AssertionError(
                    f'{type(stype).get_schema_class_displayname()}'
                    f' {derived_name!r} already exists',
                )
        else:
            ctx.env.schema, derived = stype.derive_subtype(
                ctx.env.schema,
                name=derived_name,
                inheritance_merge=inheritance_merge,
                inheritance_refdicts={'pointers'},
                mark_derived=True,
                transient=True,
                # When compiling aliases, we can't elide
                # @source/@target pointers, which normally we would
                # when creating a view.
                preserve_endpoint_ptrs=ctx.env.options.schema_view_mode,
                attrs=attrs,
                stdmode=ctx.env.options.bootstrap_mode,
            )

        if (
            stype.is_view(ctx.env.schema)
            # XXX: Previously, the main check here was just for
            # (not stype.is_non_concrete(...)). is_non_concrete isn't really the
            # right way to figure out if something is a view, since
            # some aliases will be generic. On changing it to is_view
            # instead, though, two GROUP BY tests that grouped
            # on the result of a group broke
            # (test_edgeql_group_by_group_by_03{a,b}).
            #
            # It's probably a bug that this matters in that case, and
            # it is an accident that group bindings are named in such
            # a way that they count as being generic, but for now
            # preserve that behavior.
            and not (
                stype.is_non_concrete(ctx.env.schema)
                and (view_ir := ctx.view_sets.get(stype))
                and (scope_info := ctx.env.path_scope_map.get(view_ir))
                and scope_info.binding_kind
            )
            and isinstance(derived, s_objtypes.ObjectType)
        ):
            assert isinstance(stype, s_objtypes.ObjectType)
            scls_pointers = stype.get_pointers(ctx.env.schema)
            derived_own_pointers = derived.get_pointers(ctx.env.schema)

            for pn, ptr in derived_own_pointers.items(ctx.env.schema):
                # This is a view of a view.  Make sure query-level
                # computable expressions for pointers are carried over.
                src_ptr = scls_pointers.get(ctx.env.schema, pn)
                computable_data = (
                    ctx.env.source_map.get(src_ptr) if src_ptr else None)
                if computable_data is not None:
                    ctx.env.source_map[ptr] = computable_data

                if src_ptr in ctx.env.pointer_specified_info:
                    ctx.env.pointer_derivation_map[src_ptr].append(ptr)

    else:
        raise TypeError("unsupported type in derive_view")

    ctx.view_nodes[derived.get_name(ctx.env.schema)] = derived

    if preserve_shape and stype in ctx.env.view_shapes:
        preserve_view_shape(stype, derived, ctx=ctx)

    return derived


def derive_ptr(
    ptr: s_pointers.Pointer,
    source: s_sources.Source,
    target: Optional[s_types.Type] = None,
    *qualifiers: str,
    derived_name: Optional[sn.QualName] = None,
    derived_name_quals: Optional[Sequence[str]] = (),
    preserve_shape: bool = False,
    derive_backlink: bool = False,
    inheritance_merge: bool = True,
    attrs: Optional[dict[str, Any]] = None,
    ctx: context.ContextLevel,
) -> s_pointers.Pointer:

    if derived_name is None and ctx.derived_target_module:
        derived_name = derive_view_name(
            stype=ptr, derived_name_quals=derived_name_quals, ctx=ctx)

    if ptr.get_name(ctx.env.schema) == derived_name:
        qualifiers = qualifiers + (ctx.aliases.get('d'),)

    # If we are deriving a backlink, we just register that instead of
    # actually deriving from it.
    if derive_backlink:
        attrs = attrs.copy() if attrs else {}
        attrs['computed_link_alias'] = ptr
        attrs['computed_link_alias_is_backward'] = True
        ptr = ctx.env.schema.get('std::link', type=s_pointers.Pointer)

    ctx.env.schema, derived = ptr.derive_ref(
        ctx.env.schema,
        source,
        *qualifiers,
        target=target,
        name=derived_name,
        inheritance_merge=inheritance_merge,
        inheritance_refdicts={'pointers'},
        mark_derived=True,
        transient=True,
        # When compiling aliases, we can't elide
        # @source/@target pointers, which normally we would
        # when creating a view.
        preserve_endpoint_ptrs=ctx.env.options.schema_view_mode,
        attrs=attrs,
    )

    if not ptr.is_non_concrete(ctx.env.schema):
        if isinstance(derived, s_sources.Source):
            ptr = cast(s_links.Link, ptr)
            scls_pointers = ptr.get_pointers(ctx.env.schema)
            derived_own_pointers = derived.get_pointers(ctx.env.schema)

            # Use a separate loop variable so that *ptr* keeps referring to
            # the pointer being derived in the preserve_shape check below.
            for pn, own_ptr in derived_own_pointers.items(ctx.env.schema):
                # This is a view of a view.  Make sure query-level
                # computable expressions for pointers are carried over.
                src_ptr = scls_pointers.get(ctx.env.schema, pn)
                # mypy somehow loses the type argument in the
                # "pointers" ObjectIndex.
                assert isinstance(src_ptr, s_pointers.Pointer)
                computable_data = ctx.env.source_map.get(src_ptr)
                if computable_data is not None:
                    ctx.env.source_map[own_ptr] = computable_data

    if preserve_shape and ptr in ctx.env.view_shapes:
        preserve_view_shape(ptr, derived, ctx=ctx)

    return derived


def derive_view_name(
    stype: Optional[s_obj.DerivableObject],
    derived_name_quals: Optional[Sequence[str]] = (),
    derived_name_base: Optional[sn.Name] = None,
    *,
    ctx: context.ContextLevel,
) -> sn.QualName:
    if not derived_name_quals:
        derived_name_quals = (ctx.aliases.get('view'),)

    if ctx.derived_target_module:
        derived_name_module = ctx.derived_target_module
    else:
        derived_name_module = '__derived__'

    return s_obj.derive_name(
        ctx.env.schema,
        *derived_name_quals,
        module=derived_name_module,
        derived_name_base=derived_name_base,
        parent=stype,
    )


def get_union_type[TypeT: s_types.Type](
    types: Sequence[TypeT],
    *,
    opaque: bool = False,
    preserve_derived: bool = False,
    ctx: context.ContextLevel,
    span: Optional[parsing.Span] = None,
) -> TypeT:

    targets: Sequence[s_types.Type]
    if preserve_derived:
        targets = s_utils.simplify_union_types_preserve_derived(
            ctx.env.schema, types
        )
    else:
        targets = s_utils.simplify_union_types(
            ctx.env.schema, types
        )

    try:
        ctx.env.schema, union, _ = s_utils.ensure_union_type(
            ctx.env.schema, targets,
            opaque=opaque, transient=True)
    except errors.SchemaError as e:
        union_name = (
            '(' + ' | '.join(sorted(
            t.get_displayname(ctx.env.schema)
            for t in types
            )) + ')'
        )
        e.args = (
            (f'cannot create union {union_name} {e.args[0]}',)
            + e.args[1:]
        )
        e.set_span(span)
        raise e

    if (
        not isinstance(union, s_obj.QualifiedObject)
        or union.get_name(ctx.env.schema).module != '__derived__'
    ):
        ctx.env.add_schema_ref(union, expr=None)

    return cast(TypeT, union)


def get_intersection_type[TypeT: s_types.Type](
    types: Sequence[TypeT],
    *,
    ctx: context.ContextLevel,
) -> TypeT:

    targets: Sequence[s_types.Type]
    targets = s_utils.simplify_intersection_types(ctx.env.schema, types)
    ctx.env.schema, intersection, _ = s_utils.ensure_intersection_type(
        ctx.env.schema, targets, transient=True
    )

    if (
        not isinstance(intersection, s_obj.QualifiedObject)
        or intersection.get_name(ctx.env.schema).module != '__derived__'
    ):
        ctx.env.add_schema_ref(intersection, expr=None)

    return cast(TypeT, intersection)


def get_material_type[TypeT: s_types.Type](
    t: TypeT,
    *,
    ctx: context.ContextLevel,
) -> TypeT:

    ctx.env.schema, mtype = t.material_type(ctx.env.schema)
    return mtype


def concretify[TypeT: s_types.Type](
    t: TypeT,
    *,
    ctx: context.ContextLevel,
) -> TypeT:
    """Produce a version of t with all views removed.

    This proceeds recursively through unions and intersections,
    which can result in major simplifications with intersection types
    in particular.
    """
    t = get_material_type(t, ctx=ctx)
    if els := t.get_union_of(ctx.env.schema):
        ts = [concretify(e, ctx=ctx) for e in els.objects(ctx.env.schema)]
        return get_union_type(ts, ctx=ctx)
    if els := t.get_intersection_of(ctx.env.schema):
        ts = [concretify(e, ctx=ctx) for e in els.objects(ctx.env.schema)]
        return get_intersection_type(ts, ctx=ctx)
    return t


def get_all_concrete(
    stype: s_objtypes.ObjectType, *, ctx: context.ContextLevel
) -> set[s_objtypes.ObjectType]:
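    """Return the concrete object types that *stype* may stand for.

    Unions contribute the union of their components' results and
    intersections the intersection; any other type yields itself plus
    all of its material descendants.
    """
    if union := stype.get_union_of(ctx.env.schema):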
        return {
            x
            for t in union.objects(ctx.env.schema)
            for x in get_all_concrete(t, ctx=ctx)
        }
    elif intersection := stype.get_intersection_of(ctx.env.schema):
        return set.intersection(*(
            get_all_concrete(t, ctx=ctx)
            for t in intersection.objects(ctx.env.schema)
        ))
    return {stype} | {
        x for x in stype.descendants(ctx.env.schema)
        if x.is_material_object_type(ctx.env.schema)
    }


class TypeIntersectionResult(NamedTuple):

    stype: s_types.Type
    is_empty: bool = False
    is_subtype: bool = False


def apply_intersection(
    left: s_types.Type, right: s_types.Type, *, ctx: context.ContextLevel
) -> TypeIntersectionResult:
    """Compute an intersection of two types: *left* and *right*.

    Returns:
        A :class:`~TypeIntersectionResult` named tuple containing the
        result intersection type, whether the type system considers
        the intersection empty and whether *left* is related to *right*
        (i.e. either is a subtype of the other).
    """

    if left.issubclass(ctx.env.schema, right):
        # If the intersection type is a proper *superclass* of the
        # argument, then this is effectively a NOP.
        return TypeIntersectionResult(stype=left)

    if right.issubclass(ctx.env.schema, left):
        # The intersection type is a proper *subclass* and can be directly
        # narrowed.
        return TypeIntersectionResult(
            stype=right,
            is_empty=False,
            is_subtype=True,
        )

    if (
        left.get_is_opaque_union(ctx.env.schema)
        and (left_union := left.get_union_of(ctx.env.schema))
    ):
        # Expose any opaque union types before continuing with the intersection.
        # The schema does not yet fully implement type intersections since there
        # is no `IntersectionTypeShell`. As a result, some intersections
        # produced while compiling the standard library cannot be resolved.
        left = get_union_type(left_union.objects(ctx.env.schema), ctx=ctx)

    int_type: s_types.Type = get_intersection_type([left, right], ctx=ctx)
    is_empty: bool = (
        not s_utils.expand_type_expr_descendants(int_type, ctx.env.schema)
    )
    is_subtype: bool = int_type.issubclass(ctx.env.schema, left)

    return TypeIntersectionResult(
        stype=int_type,
        is_empty=is_empty,
        is_subtype=is_subtype,
    )


def derive_dummy_ptr(
    ptr: s_pointers.Pointer,
    *,
    ctx: context.ContextLevel,
) -> s_pointers.Pointer:
    stdobj = ctx.env.schema.get('std::BaseObject', type=s_objtypes.ObjectType)
    derived_obj_name = stdobj.get_derived_name(
        ctx.env.schema, stdobj, module='__derived__')
    derived_obj = ctx.env.schema.get(
        derived_obj_name, None, type=s_obj.QualifiedObject)
    if derived_obj is None:
        ctx.env.schema, derived_obj = stdobj.derive_subtype(
            ctx.env.schema, name=derived_obj_name)

    derived_name = ptr.get_derived_name(
        ctx.env.schema, derived_obj)

    derived: s_pointers.Pointer
    derived = cast(s_pointers.Pointer, ctx.env.schema.get(derived_name, None))
    if derived is None:
        ctx.env.schema, derived = ptr.derive_ref(
            ctx.env.schema,
            derived_obj,
            target=derived_obj,
            attrs={
                'cardinality': qltypes.SchemaCardinality.One,
            },
            name=derived_name,
            mark_derived=True,
        )

    return derived


def get_union_pointer(
    *,
    ptrname: sn.UnqualName,
    source: s_sources.Source,
    direction: s_pointers.PointerDirection,
    components: Iterable[s_pointers.Pointer],
    opaque: bool = False,
    modname: Optional[str] = None,
    ctx: context.ContextLevel,
) -> s_pointers.Pointer:

    ctx.env.schema, ptr = s_pointers.get_or_create_union_pointer(
        ctx.env.schema,
        ptrname,
        source,
        direction=direction,
        components=components,
        opaque=opaque,
        modname=modname,
        transient=True,
    )
    return ptr


================================================
FILE: edb/edgeql/compiler/setgen.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""EdgeQL set compilation functions."""


from __future__ import annotations

from typing import (
    Any,
    Callable,
    Final,
    Literal,
    Optional,
    AbstractSet,
    ContextManager,
    Iterable,
    Iterator,
    Sequence,
    NoReturn,
    TYPE_CHECKING,
)

import contextlib
import enum

from edb import errors

from edb.common import levenshtein
from edb.common.typeutils import downcast, not_none

from edb.ir import ast as irast
from edb.ir import typeutils as irtyputils
from edb.ir import utils as irutils

from edb.schema import constraints as s_constr
from edb.schema import globals as s_globals
from edb.schema import indexes as s_indexes
from edb.schema import links as s_links
from edb.schema import name as s_name
from edb.schema import objtypes as s_objtypes
from edb.schema import permissions as s_permissions
from edb.schema import pointers as s_pointers
from edb.schema import pseudo as s_pseudo
from edb.schema import scalars as s_scalars
from edb.schema import sources as s_sources
from edb.schema import types as s_types
from edb.schema import utils as s_utils
from edb.schema import expr as s_expr

from edb.edgeql import ast as qlast
from edb.edgeql import qltypes
from edb.edgeql import parser as qlparser

from . import astutils
from . import casts
from . import context
from . import dispatch
from . import inference
from . import pathctx
from . import schemactx
from . import stmtctx
from . import typegen

if TYPE_CHECKING:
    from edb.schema import objects as s_obj


PtrDir = s_pointers.PointerDirection


def new_set(
    *,
    stype: s_types.Type,
    expr: irast.Expr,
    ctx: context.ContextLevel,
    ircls: type[irast.Set] = irast.Set,
    **kwargs: Any,
) -> irast.Set:
    """Create a new ir.Set instance with given attributes.

    Absolutely all ir.Set instances must be created using this
    constructor.
    """

    ignore_rewrites: bool = kwargs.get('ignore_rewrites', False)

    skip_subtypes = False
    if isinstance(expr, irast.TypeRoot):
        skip_subtypes = expr.skip_subtypes

    rw_key = (stype, skip_subtypes)

    if not ignore_rewrites and ctx.suppress_rewrites:
        from . import policies
        ignore_rewrites = kwargs['ignore_rewrites'] = (
            policies.should_ignore_rewrite(stype, ctx=ctx))

    if (
        not ignore_rewrites
        and rw_key not in ctx.env.type_rewrites
        and isinstance(stype, s_objtypes.ObjectType)
        and ctx.env.options.apply_query_rewrites
    ):
        from . import policies
        policies.try_type_rewrite(stype, skip_subtypes=skip_subtypes, ctx=ctx)

    if (
        not ignore_rewrites
        and ctx.env.type_rewrites.get(rw_key)
    ):
        ctx.env.policy_use_count += 1

    typeref = typegen.type_to_typeref(stype, env=ctx.env)
    ir_set = ircls(typeref=typeref, expr=expr, **kwargs)
    ctx.env.set_types[ir_set] = stype
    return ir_set


def new_empty_set(
    *,
    stype: Optional[s_types.Type]=None, alias: str='e',
    ctx: context.ContextLevel,
    span: Optional[qlast.Span]=None
) -> irast.Set:
    if stype is None:
        stype = s_pseudo.PseudoType.get(ctx.env.schema, 'anytype')
        if span is not None:
            ctx.env.type_origins[stype] = span

    typeref = typegen.type_to_typeref(stype, env=ctx.env)
    path_id = pathctx.get_expression_path_id(stype, alias, ctx=ctx)
    ir_set = irast.Set(
        path_id=path_id,
        typeref=typeref,
        expr=irast.EmptySet(typeref=typeref),
    )
    ctx.env.set_types[ir_set] = stype
    return ir_set


def get_set_type(
    ir_set: irast.Set, *, ctx: context.ContextLevel
) -> s_types.Type:
    return ctx.env.set_types[ir_set]


def get_expr_type(
    ir: irast.Set | irast.Expr, *, ctx: context.ContextLevel
) -> s_types.Type:
    return typegen.type_from_typeref(ir.typeref, env=ctx.env)


class KeepCurrentT(enum.Enum):
    KeepCurrent = 0


KeepCurrent: Final = KeepCurrentT.KeepCurrent


def new_set_from_set(
        ir_set: irast.Set, *,
        merge_current_ns: bool=False,
        path_scope_id: Optional[int | KeepCurrentT]=KeepCurrent,
        path_id: Optional[irast.PathId]=None,
        stype: Optional[s_types.Type]=None,
        expr: irast.Expr | KeepCurrentT=KeepCurrent,
        span: Optional[qlast.Span]=None,
        is_binding: Optional[irast.BindingKind]=None,
        is_schema_alias: Optional[bool]=None,
        is_materialized_ref: Optional[bool]=None,
        is_visible_binding_ref: Optional[bool]=None,
        ignore_rewrites: Optional[bool]=None,
        is_factoring_protected: Optional[bool]=None,
        ctx: context.ContextLevel) -> irast.Set:
    """Create a new ir.Set from another ir.Set.

    The new Set inherits everything from the source set that
    is not overridden.

    If *merge_current_ns* is set, the new Set's path_id will be
    namespaced with the currently active scope namespace.
    """
    if path_id is None:
        path_id = ir_set.path_id
    if merge_current_ns:
        path_id = path_id.merge_namespace(ctx.path_id_namespace)
    if stype is None:
        stype = get_set_type(ir_set, ctx=ctx)
    if path_scope_id == KeepCurrent:
        path_scope_id = ir_set.path_scope_id
    if expr == KeepCurrent:
        expr = ir_set.expr
    if span is None:
        span = ir_set.span
    if is_binding is None:
        is_binding = ir_set.is_binding
    if is_schema_alias is None:
        is_schema_alias = ir_set.is_schema_alias
    if is_materialized_ref is None:
        is_materialized_ref = ir_set.is_materialized_ref
    if is_visible_binding_ref is None:
        is_visible_binding_ref = ir_set.is_visible_binding_ref
    if ignore_rewrites is None:
        ignore_rewrites = ir_set.ignore_rewrites
    if is_factoring_protected is None:
        is_factoring_protected = ir_set.is_factoring_protected
    return new_set(
        path_id=path_id,
        path_scope_id=path_scope_id,
        stype=stype,
        expr=expr,
        span=span,
        is_binding=is_binding,
        is_schema_alias=is_schema_alias,
        is_materialized_ref=is_materialized_ref,
        is_visible_binding_ref=is_visible_binding_ref,
        ignore_rewrites=ignore_rewrites,
        is_factoring_protected=is_factoring_protected,
        ircls=type(ir_set),
        ctx=ctx,
    )


def new_tuple_set(
    elements: list[irast.TupleElement],
    *,
    named: bool,
    ctx: context.ContextLevel,
) -> irast.Set:

    element_types = {el.name: get_set_type(el.val, ctx=ctx) for el in elements}
    ctx.env.schema, stype = s_types.Tuple.create(
        ctx.env.schema, element_types=element_types, named=named)
    result_path_id = pathctx.get_expression_path_id(stype, ctx=ctx)

    final_elems = []
    for elem in elements:
        elem_path_id = pathctx.get_tuple_indirection_path_id(
            result_path_id, elem.name, get_set_type(elem.val, ctx=ctx),
            ctx=ctx)
        final_elems.append(irast.TupleElement(
            name=elem.name,
            val=elem.val,
            path_id=elem_path_id,
        ))

    typeref = typegen.type_to_typeref(stype, env=ctx.env)
    tup = irast.Tuple(elements=final_elems, named=named, typeref=typeref)
    return ensure_set(tup, path_id=result_path_id,
                      type_override=stype, ctx=ctx)


def new_array_set(
    elements: Sequence[irast.Set],
    *,
    stype: Optional[s_types.Type] = None,
    ctx: context.ContextLevel,
    span: Optional[qlast.Span]=None
) -> irast.Set:

    if elements:
        element_type = typegen.infer_common_type(elements, ctx.env)
        if element_type is None:
            raise errors.QueryError('could not determine array type',
                                    span=span)
    elif stype is not None:
        # When constructing an empty array, we should skip the explicit
        # cast any time that we would skip it for an empty set, because
        # we can infer the type from the context.
        assert stype.is_array()
    else:
        element_type = s_pseudo.PseudoType.get(ctx.env.schema, 'anytype')
        if span is not None:
            ctx.env.type_origins[element_type] = span

    if stype is None:
        assert element_type
        ctx.env.schema, stype = s_types.Array.create(
            ctx.env.schema, element_type=element_type, dimensions=[-1]
        )
    typeref = typegen.type_to_typeref(stype, env=ctx.env)
    arr = irast.Array(elements=elements, typeref=typeref)
    return ensure_set(arr, type_override=stype, ctx=ctx)


def raise_self_insert_error(
    stype: s_obj.Object,
    span: Optional[qlast.Span],
    *,
    ctx: context.ContextLevel,
) -> NoReturn:
    dname = stype.get_displayname(ctx.env.schema)
    raise errors.QueryError(
        f'invalid reference to {dname}: '
        f'self-referencing INSERTs are not allowed',
        hint=(
            f'Use DETACHED if you meant to refer to an '
            f'uncorrelated {dname} set'
        ),
        span=span,
    )


def raise_invalid_property_reference(
    source: s_obj.Object,
    span: Optional[qlast.Span],
    *,
    ctx: context.ContextLevel,
) -> NoReturn:
    if isinstance(source, s_types.Type):
        source = schemactx.get_material_type(source, ctx=ctx)
    raise errors.InvalidReferenceError(
        f"invalid property reference on an expression of primitive type "
        f"'{source.get_displayname(ctx.env.schema)}'",
        span=span,
    )


def compile_path(expr: qlast.Path, *, ctx: context.ContextLevel) -> irast.Set:
    """Create an ir.Set representing the given EdgeQL path expression."""
    anchors = ctx.anchors

    if expr.partial:
        if ctx.partial_path_prefix is not None:
            path_tip = ctx.partial_path_prefix
        else:
            hint = None

            # If there are anchors, suggest one
            if anchors:
                anchor_names: list[str] = [
                    key if isinstance(key, str) else key.name
                    for key in anchors
                ]

                import edb.edgeql.codegen
                suggestion = (
                    f'{anchor_names[0]}'
                    f'{edb.edgeql.codegen.generate_source(expr)}'
                )

                if len(anchor_names) == 1:
                    hint = (
                        f'Did you mean {suggestion}?'
                    )
                else:
                    hint = (
                        f'Did you mean to use one of: {anchor_names}? '
                        f'eg. {suggestion}'
                    )

            raise errors.QueryError(
                'could not resolve partial path',
                span=expr.span,
                hint=hint
            )

    computables: list[irast.Set] = []
    path_sets: list[irast.Set] = []

    for i, step in enumerate(expr.steps):
        is_computable = False
        skip_register_set = False

        if isinstance(step, qlast.SpecialAnchor):
            path_tip = resolve_special_anchor(step, ctx=ctx)

        elif isinstance(step, qlast.IRAnchor):
            # Check if the starting path label is a known anchor
            refnode = anchors.get(step.name)
            if not refnode:
                raise AssertionError(f'anchor {step.name} is missing')
            path_tip = new_set_from_set(refnode, ctx=ctx)

            if step.move_scope:
                assert refnode.path_scope_id is not None
                node = next(iter(
                    x for x in ctx.path_scope.root.descendants
                    if x.unique_id == refnode.path_scope_id
                ))
                node.remove()
                ctx.path_scope.attach_child(node)

                skip_register_set = True

        elif isinstance(step, qlast.ObjectRef):
            if i > 0:  # pragma: no cover
                raise RuntimeError(
                    'unexpected ObjectRef as a non-first path item')

            refnode = None

            if (
                not step.module
                and s_name.UnqualName(step.name) not in ctx.aliased_views
            ):
                # Check if the starting path label is a known anchor
                refnode = anchors.get(step.name)

            if refnode is not None:
                path_tip = new_set_from_set(refnode, ctx=ctx)
            else:
                (view_set, stype) = resolve_name(step, ctx=ctx)

                if (stype.is_enum(ctx.env.schema) and
                        not stype.is_view(ctx.env.schema)):
                    return compile_enum_path(expr, source=stype, ctx=ctx)

                if (stype.get_expr_type(ctx.env.schema) is not None and
                        stype.get_name(ctx.env.schema) not in ctx.view_nodes):
                    if not stype.get_expr(ctx.env.schema):
                        raise errors.InvalidReferenceError(
                            f"cannot refer to alias link helper type "
                            f"'{stype.get_name(ctx.env.schema)}'",
                            span=step.span,
                        )

                    # This is a schema-level view, as opposed to
                    # a WITH-block or inline alias view.
                    stype = stmtctx.declare_view_from_schema(stype, ctx=ctx)

                if not view_set:
                    view_set = ctx.view_sets.get(stype)
                if view_set is not None:
                    view_scope_info = ctx.env.path_scope_map[view_set]
                    path_tip = new_set_from_set(
                        view_set,
                        merge_current_ns=(
                            view_scope_info.pinned_path_id_ns is None
                        ),
                        is_binding=view_scope_info.binding_kind,
                        span=step.span,
                        ctx=ctx,
                    )

                    maybe_materialize(stype, path_tip, ctx=ctx)

                else:
                    path_tip = class_set(stype, ctx=ctx)

                view_scls = ctx.class_view_overrides.get(stype.id)
                if (view_scls is not None
                        and view_scls != get_set_type(path_tip, ctx=ctx)):
                    path_tip = ensure_set(
                        path_tip, type_override=view_scls, ctx=ctx)

        elif isinstance(step, qlast.Ptr):
            # Pointer traversal step
            ptr_expr = step
            if ptr_expr.direction is not None:
                direction = s_pointers.PointerDirection(ptr_expr.direction)
            else:
                direction = s_pointers.PointerDirection.Outbound

            ptr_name = ptr_expr.name

            source: s_obj.Object
            ptr: s_pointers.PointerLike

            if ptr_expr.type == 'property':
                # Link property reference; the source is the
                # link immediately preceding this step in the path.

                if isinstance(path_tip.expr, irast.Pointer):
                    ptrref = path_tip.expr.ptrref
                    fake_tip = path_tip
                elif (
                    path_tip.is_binding == irast.BindingKind.For
                    and (new := irutils.unwrap_set(path_tip))
                    and isinstance(new.expr, irast.Pointer)
                ):
                    # When accessing variables bound with FOR, allow
                    # looking through to the underlying link.  N.B:
                    # This relies on the FOR bindings still having an
                    # expr that lets us look at their
                    # definition. Eventually I'd like to stop doing
                    # that, and then we'll need to store it as part of
                    # the binding/type metadata.
                    ptrref = new.expr.ptrref
                    fake_tip = new

                    ind_prefix, _ = typegen.collapse_type_intersection_rptr(
                        fake_tip,
                        ctx=ctx,
                    )
                    # Don't allow using the iterator to access
                    # linkprops if the source of the link isn't
                    # visible, because then there will be a semi-join
                    # that prevents access to the props.  (This is
                    # pretty similar to how "changes the
                    # interpretation" errors work.)
                    assert isinstance(ind_prefix.expr, irast.Pointer)
                    if not ctx.path_scope.is_visible(
                        ind_prefix.expr.source.path_id
                    ):
                        # TODO: Better message
                        raise errors.QueryError(
                            'improper reference to link property on '
                            'a non-link object',
                            span=step.span,
                        )

                    # Mark the underlying pointer as needing a link table,
                    # so that we access the mapped table to begin with.
                    assert isinstance(fake_tip.expr, irast.Pointer)
                    fake_tip.expr.force_link_table = True

                else:
                    raise errors.EdgeQLSyntaxError(
                        f"unexpected reference to link property {ptr_name!r} "
                        "outside of a path expression",
                        span=ptr_expr.span,
                    )

                # The backend can't really handle @source/@target
                # outside of the singleton mode compiler, and they
                # aren't really particularly useful outside that
                # anyway, so disallow them.
                if (
                    ptr_expr.name in ('source', 'target')
                    and not ctx.allow_endpoint_linkprops
                    and (
                        ctx.env.options.schema_object_context
                        not in (s_constr.Constraint, s_indexes.Index)
                    )
                ):
                    raise errors.QueryError(
                        f'@{ptr_expr.name} may only be used in index and '
                        'constraint definitions',
                        span=step.span)

                if isinstance(
                    ptrref, irast.TypeIntersectionPointerRef
                ):
                    ind_prefix, ptrs = typegen.collapse_type_intersection_rptr(
                        fake_tip,
                        ctx=ctx,
                    )

                    assert isinstance(ind_prefix.expr, irast.Pointer)
                    prefix_type = get_set_type(ind_prefix.expr.source, ctx=ctx)
                    assert isinstance(prefix_type, s_objtypes.ObjectType)

                    if not ptrs:
                        tip_type = get_set_type(path_tip, ctx=ctx)
                        s_vn = prefix_type.get_verbosename(ctx.env.schema)
                        t_vn = tip_type.get_verbosename(ctx.env.schema)
                        pn = ind_prefix.expr.ptrref.shortname.name
                        if direction is s_pointers.PointerDirection.Inbound:
                            s_vn, t_vn = t_vn, s_vn
                        raise errors.InvalidReferenceError(
                            f"property '{ptr_name}' does not exist because"
                            f" there are no '{pn}' links between"
                            f" {s_vn} and {t_vn}",
                            span=ptr_expr.span,
                        )

                    prefix_ptr_name = (
                        next(iter(ptrs)).get_local_name(ctx.env.schema))

                    ptr = schemactx.get_union_pointer(
                        ptrname=prefix_ptr_name,
                        source=prefix_type,
                        direction=ind_prefix.expr.direction,
                        components=ptrs,
                        ctx=ctx,
                    )
                else:
                    ptr = typegen.ptrcls_from_ptrref(
                        ptrref, ctx=ctx)

                if isinstance(ptr, s_links.Link):
                    source = ptr
                else:
                    raise errors.QueryError(
                        'improper reference to link property on '
                        'a non-link object',
                        span=step.span,
                    )
            else:
                source = get_set_type(path_tip, ctx=ctx)

            # If this is followed by type intersections, collect
            # them up, since we need them in ptr_step_set.
            upcoming_intersections = []
            for j in range(i + 1, len(expr.steps)):
                nstep = expr.steps[j]
                if (isinstance(nstep, qlast.TypeIntersection)
                        and isinstance(nstep.type, qlast.TypeName)):
                    upcoming_intersections.append(
                        schemactx.get_schema_type(
                            nstep.type.maintype, ctx=ctx))
                else:
                    break

            if isinstance(source, s_types.Tuple):
                path_tip = tuple_indirection_set(
                    path_tip, source=source, ptr_name=ptr_name,
                    span=step.span, ctx=ctx)

            else:
                path_tip = ptr_step_set(
                    path_tip, expr=step, source=source, ptr_name=ptr_name,
                    direction=direction,
                    upcoming_intersections=upcoming_intersections,
                    ignore_computable=True,
                    optional_deref=step.type == 'optional',
                    span=step.span, ctx=ctx)

                assert isinstance(path_tip.expr, irast.Pointer)
                ptrcls = typegen.ptrcls_from_ptrref(
                    path_tip.expr.ptrref, ctx=ctx)
                if _is_computable_ptr(ptrcls, path_tip.expr, ctx=ctx):
                    is_computable = True

        elif isinstance(step, qlast.TypeIntersection):
            arg_type = get_set_type(path_tip, ctx=ctx)
            if not isinstance(arg_type, s_objtypes.ObjectType):
                raise errors.QueryError(
                    f'cannot apply type intersection operator to '
                    f'{arg_type.get_verbosename(ctx.env.schema)}: '
                    f'it is not an object type',
                    span=step.span)

            typ: s_types.Type = typegen.ql_typeexpr_to_type(step.type, ctx=ctx)

            try:
                path_tip = type_intersection_set(
                    path_tip, typ, optional=False, span=step.span,
                    ctx=ctx)
            except errors.SchemaError as e:
                e.set_span(step.type.span)
                raise

        else:
            # Arbitrary expression
            if i > 0:  # pragma: no cover
                raise RuntimeError(
                    'unexpected expression as a non-first path item')

            # We need to fence this if the head is a mutating
            # statement, to make sure that the factoring allowlist
            # works right.
            is_subquery = isinstance(step, qlast.Query)
            with ctx.newscope(fenced=is_subquery) as subctx:
                subctx.view_rptr = None
                path_tip = dispatch.compile(step, ctx=subctx)

                # If the head of the path is a direct object
                # reference, wrap it in an expression set to give it a
                # new path id. This prevents the object path from being
                # spuriously visible to computable paths defined in a shape
                # at the root of a path. (See test_edgeql_select_tvariant_04
                # for an example).
                if (
                    path_tip.path_id.is_objtype_path()
                    and not path_tip.path_id.is_view_path()
                    and path_tip.path_id.src_path() is None
                ):
                    path_tip = expression_set(
                        ensure_stmt(path_tip, ctx=subctx),
                        ctx=subctx)

                if path_tip.path_id.is_type_intersection_path():
                    assert isinstance(path_tip.expr, irast.Pointer)
                    scope_set = path_tip.expr.source
                else:
                    scope_set = path_tip

                scope_set = scoped_set(scope_set, ctx=subctx)

        # We compile computables under namespaces, but we need to have
        # the source of the computable *not* under that namespace,
        # so we need to do some remapping.
        if mapped := get_view_map_remapping(path_tip.path_id, ctx):
            path_tip = new_set_from_set(
                path_tip, path_id=mapped.path_id, ctx=ctx)
            # If we are remapping a source path, then we know that
            # the path is visible, so we shouldn't recompile it
            # if it is a computable path.
            is_computable = False

        if is_computable:
            computables.append(path_tip)

        if pathctx.path_is_inserting(path_tip.path_id, ctx=ctx):
            stype = ctx.env.schema.get_by_id(
                path_tip.typeref.id, type=s_types.Type
            )
            assert stype
            raise_self_insert_error(stype, step.span, ctx=ctx)

        # Don't track this step of the path if it didn't change the set
        # (probably because of a do-nothing intersection)
        if not path_sets or path_sets[-1] != path_tip:
            path_sets.append(path_tip)

    if expr.span:
        path_tip.span = expr.span
    # Register the set in the scope tree. We only skip it when the
    # path was an IRAnchor with move_scope set, and so instead of
    # registering the set we moved its whole scoped set over.
    # (I think it would be *correct* to always register it, but we
    # get better generated code quality in some of those cases, when
    # we want the computation to occur down in a subquery.)
    if not skip_register_set:
        pathctx.register_set_in_scope(path_tip, ctx=ctx)

    for ir_set in computables:
        # Compile the computables in sibling scopes to the subpaths
        # they are computing. Note that the path head will be visible
        # from inside the computable scope. That's fine.

        scope = ctx.path_scope.find_descendant(ir_set.path_id)
        if scope is None:
            scope = ctx.path_scope.find_visible(ir_set.path_id)
        # We skip recompiling if we can't find a scope for it.
        # This whole mechanism seems a little sketchy, unfortunately.
        if scope is None:
            continue

        with ctx.new() as subctx:
            subctx.path_scope = scope
            assert isinstance(ir_set.expr, irast.Pointer)
            comp_ir_set = computable_ptr_set(
                ir_set.expr, ir_set.path_id, span=ir_set.span, ctx=subctx
            )
            i = path_sets.index(ir_set)
            if i != len(path_sets) - 1:
                prptr = path_sets[i + 1].expr
                assert isinstance(prptr, irast.Pointer)
                prptr.source = comp_ir_set
            else:
                path_tip = comp_ir_set
            path_sets[i] = comp_ir_set

    return path_tip


def resolve_name(
    name: qlast.ObjectRef, *, ctx: context.ContextLevel
) -> tuple[Optional[irast.Set], s_types.Type]:

    view_set = None
    stype = None
    if not name.module:
        view_set = ctx.aliased_views.get(s_name.UnqualName(name.name))
        if view_set:
            stype = get_set_type(view_set, ctx=ctx)
            return (view_set, stype)

    stype = schemactx.get_schema_type(
        name,
        condition=lambda o: (
            isinstance(o, s_types.Type)
            and (
                o.is_object_type() or
                o.is_view(ctx.env.schema) or
                o.is_enum(ctx.env.schema)
            )
        ),
        label='object type or alias',
        item_type=s_types.QualifiedType,
        span=name.span,
        ctx=ctx,
    )
    return (None, stype)


def resolve_special_anchor(
    anchor: qlast.SpecialAnchor, *, ctx: context.ContextLevel
) -> irast.Set:

    # '__source__' and '__subject__' can only appear as the
    # starting path label syntactically and must be pre-populated
    # by the compile() caller.

    assert isinstance(anchor, qlast.SpecialAnchor)
    token = anchor.name

    path_tip = ctx.anchors.get(token)

    if not path_tip:
        raise errors.InvalidReferenceError(
            f'{token} cannot be used in this expression',
            span=anchor.span,
        )

    return path_tip


def ptr_step_set(
    path_tip: irast.Set,
    *,
    upcoming_intersections: Sequence[s_types.Type] = (),
    source: s_obj.Object,
    expr: Optional[qlast.Base],
    ptr_name: str,
    direction: PtrDir = PtrDir.Outbound,
    span: Optional[qlast.Span],
    ignore_computable: bool = False,
    optional_deref: bool = False,
    ctx: context.ContextLevel,
) -> irast.Set:
    ptrcls, path_id_ptrcls = resolve_ptr_with_intersections(
        source,
        ptr_name,
        upcoming_intersections=upcoming_intersections,
        track_ref=expr,
        direction=direction,
        span=span,
        ctx=ctx)

    return extend_path(
        path_tip, ptrcls, direction,
        path_id_ptrcls=path_id_ptrcls,
        ignore_computable=ignore_computable,
        optional_deref=optional_deref,
        span=span,
        ctx=ctx)


def _add_target_schema_refs(
    stype: Optional[s_obj.Object],
    ctx: context.ContextLevel,
) -> None:
    """Add the appropriate schema dependencies for a pointer target.

    The only annoying bit is that we also need to handle unions/intersections.
    """
    if not isinstance(stype, s_objtypes.ObjectType):
        return
    ctx.env.add_schema_ref(stype, None)
    schema = ctx.env.schema
    for obj in (
        stype.get_union_of(schema).objects(schema) +
        stype.get_intersection_of(schema).objects(schema)
    ):
        ctx.env.add_schema_ref(obj, None)


def resolve_ptr(
    near_endpoint: s_obj.Object,
    pointer_name: str,
    *,
    direction: s_pointers.PointerDirection = (
        s_pointers.PointerDirection.Outbound
    ),
    span: Optional[qlast.Span] = None,
    track_ref: Optional[qlast.Base | Literal[False]],
    ctx: context.ContextLevel,
) -> s_pointers.Pointer:
    return resolve_ptr_with_intersections(
        near_endpoint, pointer_name,
        direction=direction, span=span,
        track_ref=track_ref, ctx=ctx)[0]


def resolve_ptr_with_intersections(
    near_endpoint: s_obj.Object,
    pointer_name: str,
    *,
    upcoming_intersections: Sequence[s_types.Type] = (),
    far_endpoints: Iterable[s_obj.Object] = (),
    direction: s_pointers.PointerDirection = (
        s_pointers.PointerDirection.Outbound
    ),
    span: Optional[qlast.Span] = None,
    track_ref: Optional[qlast.Base | Literal[False]],
    ctx: context.ContextLevel,
) -> tuple[s_pointers.Pointer, s_pointers.Pointer]:
    """Resolve a pointer, taking into account upcoming intersections.

    The key trickiness here is that *two* pointers are returned:
      * one that (for backlinks) includes just the pointers that actually
        may be used
      * one for use in path ids, that does not do that filtering, so that
        path factoring works properly.
    """

    if not isinstance(near_endpoint, s_sources.Source):
        # Reference to a property on non-object
        raise_invalid_property_reference(near_endpoint, span, ctx=ctx)

    ptr: Optional[s_pointers.Pointer] = None

    if direction is s_pointers.PointerDirection.Outbound:
        path_id_ptr = ptr = near_endpoint.maybe_get_ptr(
            ctx.env.schema,
            s_name.UnqualName(pointer_name),
        )

        # If we couldn't find anything, but the source is a computed link
        # that aliases some other link, look for a link property on
        # it. This allows us to access link properties in both
        # directions on links, including when the backlink has been
        # stuck in a computed.
        if (
            ptr is None
            and isinstance(near_endpoint, s_links.Link)
            and (back := near_endpoint.get_computed_link_alias(ctx.env.schema))
            and isinstance(back, s_links.Link)
            and (nptr := back.maybe_get_ptr(
                ctx.env.schema,
                s_name.UnqualName(pointer_name),
            ))
            # We can't handle computeds yet, since we would need to switch
            # around a bunch of stuff inside them.
            and not nptr.is_pure_computable(ctx.env.schema)
        ):
            src_type = downcast(
                s_types.Type, near_endpoint.get_source(ctx.env.schema)
            )
            if not src_type.is_view(ctx.env.schema):
                # HACK: If the source is in the standard library, and
                # not a view, we can't add a derived pointer.  For
                # consistency, just always require it be a view.
                new_source = downcast(
                    s_objtypes.ObjectType,
                    schemactx.derive_view(src_type, ctx=ctx),
                )
                new_endpoint = downcast(s_links.Link, schemactx.derive_ptr(
                    near_endpoint, new_source, ctx=ctx))
            else:
                new_endpoint = near_endpoint

            ptr = schemactx.derive_ptr(nptr, new_endpoint, ctx=ctx)
            path_id_ptr = nptr

        if ptr is not None:
            ref = ptr.get_nearest_non_derived_parent(ctx.env.schema)
            if track_ref is not False:
                ctx.env.add_schema_ref(ref, track_ref)
                _add_target_schema_refs(
                    ref.get_target(ctx.env.schema), ctx=ctx)

    else:
        assert isinstance(near_endpoint, s_types.Type)
        concrete_near_endpoint = schemactx.concretify(near_endpoint, ctx=ctx)
        ptrs = concrete_near_endpoint.getrptrs(
            ctx.env.schema, pointer_name, sources=far_endpoints)
        if ptrs:
            # If this reverse pointer access is followed by
            # intersections, we filter out any pointers that
            # couldn't be picked up by the intersections.
            # If a pointer doesn't get picked up, we look to see
            # if any of its children might.
            #
            # This both allows us to avoid creating spurious
            # dependencies when reverse links are used in schemas
            # and to generate a precise set of possible pointers.
            dep_ptrs = set()
            wl = list(ptrs)
            while wl:
                ptr = wl.pop()
                if (src := ptr.get_source(ctx.env.schema)):
                    if all(
                        src.issubclass(ctx.env.schema, typ)
                        for typ in upcoming_intersections
                    ):
                        dep_ptrs.add(ptr)
                    else:
                        wl.extend(ptr.children(ctx.env.schema))

            if track_ref is not False:
                for p in dep_ptrs:
                    p = p.get_nearest_non_derived_parent(ctx.env.schema)
                    ctx.env.add_schema_ref(p, track_ref)
                    _add_target_schema_refs(
                        p.get_source(ctx.env.schema), ctx=ctx)

            # We can only compute backlinks for non-computed pointers,
            # but we need to make sure that a computed pointer doesn't
            # break properly-filtered backlinks.
            concrete_ptrs = [
                ptr for ptr in ptrs
                if not ptr.is_pure_computable(ctx.env.schema)]

            for ptr in ptrs:
                if (
                    ptr.is_pure_computable(ctx.env.schema)
                    and (ptr in dep_ptrs or not concrete_ptrs)
                ):
                    vname = ptr.get_verbosename(ctx.env.schema,
                                                with_parent=True)
                    raise errors.InvalidReferenceError(
                        f'cannot follow backlink {pointer_name!r} because '
                        f'{vname} is computed',
                        span=span
                    )

            opaque = not far_endpoints
            concrete_ptr = schemactx.get_union_pointer(
                ptrname=s_name.UnqualName(pointer_name),
                source=near_endpoint,
                direction=direction,
                components=concrete_ptrs,
                opaque=opaque,
                modname=ctx.derived_target_module,
                ctx=ctx,
            )
            path_id_ptr = ptr = concrete_ptr
            # If we have an upcoming intersection that has actual
            # pointer targets, we want to put the filtered-down
            # version into the AST, so that we can more easily use
            # that information in compilation.  But we still need the
            # *full* union in the path_ids, for factoring.
            if dep_ptrs and upcoming_intersections:
                ptr = schemactx.get_union_pointer(
                    ptrname=s_name.UnqualName(pointer_name),
                    source=near_endpoint,
                    direction=direction,
                    components=dep_ptrs,
                    opaque=opaque,
                    modname=ctx.derived_target_module,
                    ctx=ctx,
                )

    if ptr and path_id_ptr:
        return ptr, path_id_ptr

    if isinstance(near_endpoint, s_links.Link):
        vname = near_endpoint.get_verbosename(ctx.env.schema, with_parent=True)
        msg = f'{vname} has no property {pointer_name!r}'

    elif direction == s_pointers.PointerDirection.Outbound:
        msg = (f'{near_endpoint.get_verbosename(ctx.env.schema)} '
               f'has no link or property {pointer_name!r}')

    else:
        nep_name = near_endpoint.get_displayname(ctx.env.schema)
        path = f'{nep_name}.{direction}{pointer_name}'
        msg = f'{path!r} does not resolve to any known path'

    err = errors.InvalidReferenceError(msg, span=span)

    if (
        direction is s_pointers.PointerDirection.Outbound
        # In some call sites, we call resolve_ptr "experimentally",
        # not tracking references and swallowing failures. Don't do an
        # expensive (30% of compilation time in some benchmarks!)
        # error enrichment for cases that won't really error.
        and track_ref is not False
    ):
        s_utils.enrich_schema_lookup_error(
            err,
            s_name.UnqualName(pointer_name),
            modaliases=ctx.modaliases,
            item_type=s_pointers.Pointer,
            pointer_parent=near_endpoint,
            schema=ctx.env.schema,
        )

    raise err


def _check_secret_ptr(
    ptrcls: s_pointers.Pointer,
    *,
    span: Optional[qlast.Span] = None,
    ctx: context.ContextLevel,
) -> None:
    module = ptrcls.get_name(ctx.env.schema).module

    # HACK: Workaround for #8974. Aliases/globals have expr duplicated
    # in their associated Type, and sometimes recompilation of the
    # Type is triggered.
    # Skip producing secret errors there, since we don't have the
    # result_view_name available.
    #
    # The errors will get produced when actually compiling the
    # Global/Alias itself.
    if (
        ctx.env.options.schema_object_context
        and issubclass(ctx.env.options.schema_object_context, s_types.Type)
    ):
        return

    func_name = ctx.env.options.func_name
    if func_name and func_name.module == module:
        return

    view_name = ctx.env.options.result_view_name  # type: ignore
    if view_name and view_name.module == module:
        return

    if ctx.current_schema_views:
        view_name = ctx.current_schema_views[-1].get_name(ctx.env.schema)
        if view_name.module == module:
            return

    vn = ptrcls.get_verbosename(ctx.env.schema, with_parent=True)
    raise errors.QueryError(
        f"cannot access {vn} because it is secret",
        span=span,
    )


def extend_path(
    source_set: irast.Set,
    ptrcls: s_pointers.Pointer,
    direction: PtrDir = PtrDir.Outbound,
    *,
    path_id_ptrcls: Optional[s_pointers.Pointer] = None,
    ignore_computable: bool = False,
    same_computable_scope: bool = False,
    optional_deref: bool = False,
    span: Optional[qlast.Span] = None,
    ctx: context.ContextLevel,
) -> irast.SetE[irast.Pointer]:
    """Return a Set node representing the new path tip."""

    if ptrcls.is_link_property(ctx.env.schema):
        src_path_id = source_set.path_id.ptr_path()
    else:
        if direction is not s_pointers.PointerDirection.Inbound:
            source = ptrcls.get_near_endpoint(ctx.env.schema, direction)
            assert isinstance(source, s_types.Type)
            stype = get_set_type(source_set, ctx=ctx)
            if not stype.issubclass(ctx.env.schema, source):
                # Polymorphic link reference
                source_set = type_intersection_set(
                    source_set, source, optional=True, span=span,
                    ctx=ctx)

        src_path_id = source_set.path_id

    orig_ptrcls = ptrcls

    # If there is a particular specified ptrcls for the pathid, use
    # it, otherwise use the actual ptrcls. This comes up with
    # intersections on backlinks, where we want to use a precise ptr
    # in the IR for compilation reasons but need a path_id that is
    # independent of intersections.
    path_id_ptrcls = path_id_ptrcls or ptrcls

    # Find the pointer definition site.
    # This makes it so that views don't change path ids unless they are
    # introducing some computation.
    ptrcls = ptrcls.get_nearest_defined(ctx.env.schema)
    path_id_ptrcls = path_id_ptrcls.get_nearest_defined(ctx.env.schema)

    path_id = pathctx.extend_path_id(
        src_path_id,
        ptrcls=path_id_ptrcls,
        direction=direction,
        ns=ctx.path_id_namespace,
        ctx=ctx,
    )

    if ptrcls.get_secret(ctx.env.schema):
        _check_secret_ptr(ptrcls, span=span, ctx=ctx)

    target = orig_ptrcls.get_far_endpoint(ctx.env.schema, direction)
    assert isinstance(target, s_types.Type)
    ptr = irast.Pointer(
        source=source_set,
        direction=direction,
        ptrref=typegen.ptr_to_ptrref(ptrcls, ctx=ctx),
        is_definition=False,
        optional_deref=optional_deref,
    )
    target_set = new_set(
        stype=target, path_id=path_id, span=span, expr=ptr, ctx=ctx)

    is_computable = _is_computable_ptr(ptrcls, ptr, ctx=ctx)
    if not ignore_computable and is_computable:
        target_set = computable_ptr_set(
            ptr,
            path_id,
            same_computable_scope=same_computable_scope,
            span=span,
            ctx=ctx,
        )

    assert irutils.is_set_instance(target_set, irast.Pointer)
    return target_set


def needs_rewrite_existence_assertion(
    ptrcls: s_pointers.PointerLike,
    rptr: irast.Pointer,
    *,
    ctx: context.ContextLevel,
) -> bool:
    """Determine if we need to inject an assert_exists for a pointer.

    Required pointers to types with access policies need to have an
    assert_exists added.
    """

    return bool(
        not ctx.suppress_rewrites
        # We *don't* need to do the rewrite when using .?>
        and not rptr.optional_deref
        and ptrcls.get_required(ctx.env.schema)
        and rptr.direction == PtrDir.Outbound
        and (target := ptrcls.get_target(ctx.env.schema))
        and ctx.env.type_rewrites.get((target, False))
        and ptrcls.get_shortname(ctx.env.schema).name != '__type__'
    )


def is_injected_computable_ptr(
    ptrcls: s_pointers.PointerLike,
    rptr: irast.Pointer,
    *,
    ctx: context.ContextLevel,
) -> bool:
    return (
        ctx.env.options.apply_query_rewrites
        and ptrcls not in ctx.active_computeds
        and (
            bool(ptrcls.get_schema_reflection_default(ctx.env.schema))
            or needs_rewrite_existence_assertion(ptrcls, rptr, ctx=ctx)
        )
    )


def _is_computable_ptr(
    ptrcls: s_pointers.PointerLike,
    rptr: irast.Pointer,
    *,
    ctx: context.ContextLevel,
) -> bool:
    try:
        qlexpr = ctx.env.source_map[ptrcls].qlexpr
    except KeyError:
        pass
    else:
        return qlexpr is not None

    return (
        bool(ptrcls.get_expr(ctx.env.schema))
        or is_injected_computable_ptr(ptrcls, rptr, ctx=ctx)
    )


def compile_enum_path(
    expr: qlast.Path, *, source: s_types.Type, ctx: context.ContextLevel
) -> irast.Set:

    assert isinstance(source, s_scalars.ScalarType)
    enum_values = source.get_enum_values(ctx.env.schema)
    assert enum_values

    nsteps = len(expr.steps)
    if nsteps == 1:
        raise errors.QueryError(
            f"'{source.get_displayname(ctx.env.schema)}' enum "
            f"path expression lacks an enum member name, as in "
            f"'{source.get_displayname(ctx.env.schema)}.{enum_values[0]}'",
            span=expr.steps[0].span,
        )

    step2 = expr.steps[1]
    if not isinstance(step2, qlast.Ptr):
        raise errors.QueryError(
            f"an enum member name must follow enum type name in the path, "
            f"as in "
            f"'{source.get_displayname(ctx.env.schema)}.{enum_values[0]}'",
            span=step2.span,
        )

    ptr_name = step2.name

    step2_direction = s_pointers.PointerDirection.Outbound
    if step2.direction is not None:
        step2_direction = s_pointers.PointerDirection(step2.direction)
    if step2_direction is not s_pointers.PointerDirection.Outbound:
        raise errors.QueryError(
            "enum types do not support backlink navigation",
            span=step2.span,
        )
    if step2.type == 'property':
        raise errors.QueryError(
            f"unexpected reference to link property '{ptr_name}' "
            f"outside of a path expression",
            span=step2.span,
        )

    if nsteps > 2:
        raise_invalid_property_reference(
            source, span=expr.steps[2].span, ctx=ctx
        )

    if ptr_name not in enum_values:
        rec_name = sorted(
            enum_values,
            key=lambda name: levenshtein.distance(name, ptr_name)
        )[0]
        src_name = source.get_displayname(ctx.env.schema)
        raise errors.InvalidReferenceError(
            f"'{src_name}' enum has no member called {ptr_name!r}",
            hint=f"did you mean {rec_name!r}?",
            span=step2.span,
        )

    return enum_indirection_set(
        source=source,
        ptr_name=step2.name,
        span=expr.span,
        ctx=ctx,
    )


def enum_indirection_set(
    *,
    source: s_types.Type,
    ptr_name: str,
    span: Optional[qlast.Span],
    ctx: context.ContextLevel,
) -> irast.Set:

    strref = typegen.type_to_typeref(
        ctx.env.get_schema_type_and_track(s_name.QualName('std', 'str')),
        env=ctx.env,
    )

    return casts.compile_cast(
        irast.StringConstant(value=ptr_name, typeref=strref),
        source,
        span=span,
        ctx=ctx,
    )


def tuple_indirection_set(
    path_tip: irast.Set,
    *,
    source: s_types.Type,
    ptr_name: str,
    span: Optional[qlast.Span] = None,
    ctx: context.ContextLevel,
) -> irast.Set:

    assert isinstance(source, s_types.Tuple)

    try:
        el_name = ptr_name
        el_norm_name = source.normalize_index(ctx.env.schema, el_name)
        el_type = source.get_subtype(ctx.env.schema, el_name)
    except errors.InvalidReferenceError as e:
        if span and not e.has_span():
            e.set_span(span)
        raise e

    path_id = pathctx.get_tuple_indirection_path_id(
        path_tip.path_id, el_norm_name, el_type, ctx=ctx)

    ptr = irast.TupleIndirectionPointer(
        source=path_tip,
        ptrref=downcast(irast.TupleIndirectionPointerRef, path_id.rptr()),
        direction=not_none(path_id.rptr_dir()),
    )
    ti_set = new_set(stype=el_type, path_id=path_id, expr=ptr, ctx=ctx)

    return ti_set


def type_intersection_set(
    source_set: irast.Set,
    stype: s_types.Type,
    *,
    optional: bool,
    span: Optional[qlast.Span] = None,
    ctx: context.ContextLevel,
) -> irast.Set:
    """Return an intersection of *source_set* with type *stype*."""

    arg_type = get_set_type(source_set, ctx=ctx)

    result = schemactx.apply_intersection(arg_type, stype, ctx=ctx)
    if result.stype == arg_type:
        return source_set

    rptr_specialization = []

    if (
        isinstance(source_set.expr, irast.Pointer)
        and source_set.expr.ptrref.union_components
    ):
        rptr = source_set.expr

        # This is a type intersection of a union pointer, most likely
        # a reverse link path specification.  If so, test the union
        # components against the type expression and record which
        # components match.  This information will be used later
        # when evaluating the path cardinality, as well as to
        # route link property references accordingly.
        for component in source_set.expr.ptrref.union_components:
            component_endpoint_ref = component.dir_target(rptr.direction)
            ctx.env.schema, component_endpoint = irtyputils.ir_typeref_to_type(
                ctx.env.schema, component_endpoint_ref)
            if component_endpoint.issubclass(ctx.env.schema, stype):
                assert isinstance(component, irast.PointerRef)
                rptr_specialization.append(component)
            elif stype.issubclass(ctx.env.schema, component_endpoint):
                assert isinstance(stype, s_objtypes.ObjectType)
                if rptr.direction is s_pointers.PointerDirection.Inbound:
                    narrow_ptr = stype.getptr(
                        ctx.env.schema,
                        component.shortname.get_local_name(),
                    )
                    rptr_specialization.append(
                        irtyputils.ptrref_from_ptrcls(
                            schema=ctx.env.schema,
                            ptrcls=narrow_ptr,
                            cache=ctx.env.ptr_ref_cache,
                            typeref_cache=ctx.env.type_ref_cache,
                        ),
                    )
                else:
                    assert isinstance(component, irast.PointerRef)
                    rptr_specialization.append(component)

    ptrcls = irast.TypeIntersectionLink(
        arg_type,
        result.stype,
        optional=optional,
        is_empty=result.is_empty,
        is_subtype=result.is_subtype,
        rptr_specialization=rptr_specialization,
        # The type intersection cannot increase the cardinality
        # of the input set, so semantically, the cardinality
        # of the type intersection "link" is, at most, ONE.
        cardinality=qltypes.SchemaCardinality.One,
    )

    ptrref = irtyputils.ptrref_from_ptrcls(
        schema=ctx.env.schema,
        ptrcls=ptrcls,
        cache=ctx.env.ptr_ref_cache,
        typeref_cache=ctx.env.type_ref_cache,
    )

    poly_set = new_set(
        stype=result.stype,
        path_id=source_set.path_id.extend(ptrref=ptrref),
        expr=irast.TypeIntersectionPointer(
            source=source_set,
            ptrref=downcast(irast.TypeIntersectionPointerRef, ptrref),
            direction=s_pointers.PointerDirection.Outbound,
            optional=optional,
        ),
        span=span,
        ctx=ctx,
    )

    return poly_set


def class_set(
    stype: s_types.Type,
    *,
    path_id: Optional[irast.PathId] = None,
    skip_subtypes: bool = False,
    ignore_rewrites: bool = False,
    ctx: context.ContextLevel,
) -> irast.Set:
    """Nominally, create a set representing selecting some type.

    That is, create a set with a TypeRoot expr.

    TODO(ir): In practice, a lot of call sites really want some kind
    of handle to something that will be bound elsewhere, and we should
    clean those up to use a different node.
    """

    if path_id is None:
        path_id = pathctx.get_path_id(stype, ctx=ctx)
    return new_set(
        path_id=path_id,
        stype=stype,
        ignore_rewrites=ignore_rewrites,
        expr=irast.TypeRoot(
            typeref=typegen.type_to_typeref(stype, env=ctx.env),
            skip_subtypes=skip_subtypes,
        ),
        ctx=ctx,
    )


def expression_set(
    expr: irast.Expr,
    path_id: Optional[irast.PathId] = None,
    *,
    type_override: Optional[s_types.Type] = None,
    ctx: context.ContextLevel,
) -> irast.Set:

    if isinstance(expr, irast.Set):  # pragma: no cover
        raise errors.InternalServerError(f'{expr!r} is already a Set')

    if type_override is not None:
        stype = type_override
    else:
        stype = get_expr_type(expr, ctx=ctx)

    if path_id is None:
        path_id = getattr(expr, 'path_id', None)
        if path_id is None:
            path_id = pathctx.get_expression_path_id(stype, ctx=ctx)

    return new_set(
        path_id=path_id,
        stype=stype,
        expr=expr,
        span=expr.span,
        ctx=ctx
    )


def scoped_set(
    expr: irast.Set | irast.Expr,
    *,
    type_override: Optional[s_types.Type] = None,
    typehint: Optional[s_types.Type] = None,
    path_id: Optional[irast.PathId] = None,
    force_reassign: bool = False,
    ctx: context.ContextLevel,
) -> irast.Set:

    if not isinstance(expr, irast.Set):
        ir_set = expression_set(
            expr, type_override=type_override,
            path_id=path_id, ctx=ctx)
        pathctx.assign_set_scope(ir_set, ctx.path_scope, ctx=ctx)
    else:
        if typehint is not None or type_override is not None:
            ir_set = ensure_set(
                expr, typehint=typehint,
                type_override=type_override,
                path_id=path_id, ctx=ctx)
        else:
            ir_set = expr

        if ir_set.path_scope_id is None or force_reassign:
            if ctx.path_scope.find_child(ir_set.path_id) and path_id is None:
                # Protect from scope recursion in the common case by
                # wrapping the set into a subquery.
                ir_set = expression_set(
                    ensure_stmt(ir_set, ctx=ctx),
                    type_override=type_override,
                    ctx=ctx)

            pathctx.assign_set_scope(ir_set, ctx.path_scope, ctx=ctx)

    return ir_set


def moveable_anchor(
    expr: irast.Set,
    name: str = 'v',
    *,
    check_dml: bool = False,
    ctx: context.ContextLevel,
) -> qlast.Path:
    return ctx.create_anchor(
        scoped_set(expr, ctx=ctx),
        name=name,
        check_dml=check_dml,
        move_scope=True,
    )


def ensure_set(
    expr: irast.Set | irast.Expr,
    *,
    type_override: Optional[s_types.Type] = None,
    typehint: Optional[s_types.Type] = None,
    path_id: Optional[irast.PathId] = None,
    span: Optional[qlast.Span] = None,
    ctx: context.ContextLevel,
) -> irast.Set:

    if not isinstance(expr, irast.Set):
        ir_set = expression_set(
            expr, type_override=type_override,
            path_id=path_id, ctx=ctx)
    else:
        ir_set = expr

    stype = get_set_type(ir_set, ctx=ctx)

    if type_override is not None and stype != type_override:
        ir_set = new_set_from_set(ir_set, stype=type_override, ctx=ctx)

        stype = type_override

    if span is not None:
        ir_set = new_set_from_set(ir_set, span=span, ctx=ctx)

    if (irutils.is_set_instance(ir_set, irast.EmptySet)
            and (stype is None or stype.is_any(ctx.env.schema))
            and typehint is not None):
        typegen.amend_empty_set_type(ir_set, typehint, env=ctx.env)
        stype = get_set_type(ir_set, ctx=ctx)

    if (
        typehint is not None
        and stype != typehint
        and not stype.implicitly_castable_to(typehint, ctx.env.schema)
    ):
        raise errors.QueryError(
            f'expecting expression of type '
            f'{typehint.get_displayname(ctx.env.schema)}, '
            f'got {stype.get_displayname(ctx.env.schema)}',
            span=expr.span
        )

    return ir_set


def ensure_stmt(
    expr: irast.Set | irast.Expr, *, ctx: context.ContextLevel
) -> irast.Stmt:
    if not isinstance(expr, irast.Stmt):
        expr = irast.SelectStmt(
            result=ensure_set(expr, ctx=ctx),
            implicit_wrapper=True,
        )
    return expr


def fixup_computable_source_set(
    source_set: irast.Set,
    *,
    ctx: context.ContextLevel,
) -> irast.Set:
    source_scls = get_set_type(source_set, ctx=ctx)
    # process_view() may generate computable pointer expressions
    # in the form "self.linkname".  To prevent infinite recursion,
    # self must resolve to the parent type of the view, NOT the view
    # type itself.  Similarly, when resolving computable link properties
    # make sure that we use the parent of derived ptrcls.
    if source_scls.is_view(ctx.env.schema):
        source_set_stype = source_scls.peel_view(ctx.env.schema)
        source_set = new_set_from_set(
            source_set, stype=source_set_stype, ctx=ctx)
        source_set.shape = ()

        if isinstance(source_set.expr, irast.Pointer):
            source_rptrref = source_set.expr.ptrref
            if source_rptrref.base_ptr is not None:
                source_rptrref = source_rptrref.base_ptr
            source_set.expr = source_set.expr.replace(
                ptrref=source_rptrref,
                is_definition=True,
            )
    return source_set


def computable_ptr_set(
    rptr: irast.Pointer,
    path_id: irast.PathId,
    *,
    same_computable_scope: bool = False,
    span: Optional[qlast.Span] = None,
    ctx: context.ContextLevel,
) -> irast.Set:
    """Return ir.Set for a pointer defined as a computable."""
    ptrcls = typegen.ptrcls_from_ptrref(rptr.ptrref, ctx=ctx)
    source_scls = get_set_type(rptr.source, ctx=ctx)
    source_set = fixup_computable_source_set(rptr.source, ctx=ctx)
    ptrcls_to_shadow = None

    qlctx: Optional[context.ContextLevel]

    try:
        comp_info = ctx.env.source_map[ptrcls]
        qlexpr = comp_info.qlexpr
        assert isinstance(comp_info.context, context.ContextLevel)
        qlctx = comp_info.context
        inner_source_path_id = comp_info.path_id
        path_id_ns = comp_info.path_id_ns
    except KeyError:
        comp_expr: Optional[s_expr.Expression] = ptrcls.get_expr(ctx.env.schema)
        schema_qlexpr: Optional[qlast.Expr] = None
        if comp_expr is None and ctx.env.options.apply_query_rewrites:
            assert isinstance(ptrcls, s_pointers.Pointer)
            ptrcls_n = ptrcls.get_shortname(ctx.env.schema).name
            path = qlast.Path(
                steps=[
                    qlast.SpecialAnchor(name='__source__'),
                    qlast.Ptr(
                        name=ptrcls_n,
                        direction=s_pointers.PointerDirection.Outbound,
                        type=(
                            'property'
                            if ptrcls.is_link_property(ctx.env.schema)
                            else None
                        )
                    )
                ],
            )

            schema_deflt = ptrcls.get_schema_reflection_default(ctx.env.schema)
            if schema_deflt is not None:
                schema_qlexpr = qlast.BinOp(
                    left=path,
                    right=qlparser.parse_fragment(schema_deflt),
                    op='??',
                )

            if needs_rewrite_existence_assertion(ptrcls, rptr, ctx=ctx):
                # Wrap it in a dummy select so that we can't optimize away
                # the assert_exists.
                # TODO: do something less bad
                arg = qlast.SelectQuery(
                    result=path, where=qlast.Constant.boolean(True))
                vname = ptrcls.get_verbosename(
                    ctx.env.schema, with_parent=True)
                msg = f'required {vname} is hidden by access policy'
                if ctx.active_computeds:
                    cur = next(reversed(ctx.active_computeds))
                    vname = cur.get_verbosename(
                        ctx.env.schema, with_parent=True)
                    msg += f' (while evaluating computed {vname})'

                schema_qlexpr = qlast.FunctionCall(
                    func=('__std__', 'assert_exists'),
                    args=[arg],
                    kwargs={'message': qlast.Constant.string(value=msg)},
                )

            # If this is a view, we want to shadow the underlying
            # ptrcls, since otherwise we will generate this default
            # code *twice*.
            if rptr.ptrref.base_ptr:
                ptrcls_to_shadow = typegen.ptrcls_from_ptrref(
                    rptr.ptrref.base_ptr, ctx=ctx)

        if schema_qlexpr is None:
            if comp_expr is None:
                ptrcls_sn = ptrcls.get_shortname(ctx.env.schema)
                raise errors.InternalServerError(
                    f'{ptrcls_sn!r} is not a computed pointer')

            comp_qlexpr = comp_expr.parse()
            assert isinstance(comp_qlexpr, qlast.Expr), 'expected qlast.Expr'
            schema_qlexpr = comp_qlexpr

        # NOTE: Validation of the expression type is not the concern
        # of this function. For any non-object pointer target type,
        # the default expression must be assignment-cast into that
        # type.
        target_scls = ptrcls.get_target(ctx.env.schema)
        assert target_scls is not None
        if not target_scls.is_object_type():
            schema_qlexpr = qlast.TypeCast(
                type=typegen.type_to_ql_typeref(
                    target_scls, ctx=ctx),
                expr=schema_qlexpr,
            )
        qlexpr = astutils.ensure_ql_query(schema_qlexpr)
        qlctx = None
        path_id_ns = None

    newctx: Callable[[], ContextManager[context.ContextLevel]]

    if qlctx is None:
        # Schema-level computed link or property, the context should
        # still have a source.
        newctx = _get_schema_computed_ctx(
            rptr=rptr,
            source=source_set,
            ctx=ctx)

    else:
        newctx = _get_computable_ctx(
            rptr=rptr,
            source=source_set,
            source_scls=source_scls,
            inner_source_path_id=inner_source_path_id,
            path_id_ns=path_id_ns,
            same_scope=same_computable_scope,
            qlctx=qlctx,
            ctx=ctx)

    result_stype = ptrcls.get_target(ctx.env.schema)
    base_object = ctx.env.schema.get('std::BaseObject', type=s_types.Type)
    with newctx() as subctx:
        assert isinstance(source_scls, s_sources.Source)
        assert isinstance(ptrcls, s_pointers.Pointer)

        subctx.active_computeds = subctx.active_computeds.copy()
        if ptrcls_to_shadow:
            assert isinstance(ptrcls_to_shadow, s_pointers.Pointer)
            subctx.active_computeds.add(ptrcls_to_shadow)
        subctx.active_computeds.add(ptrcls)
        if result_stype != base_object:
            subctx.view_scls = result_stype
        subctx.view_rptr = context.ViewRPtr(
            source=source_scls, ptrcls=ptrcls)
        subctx.anchors['__source__'] = source_set
        if qlctx and '__default__' in qlctx.anchors:
            subctx.anchors['__default__'] = qlctx.anchors['__default__']
        subctx.empty_result_type_hint = ptrcls.get_target(ctx.env.schema)
        subctx.partial_path_prefix = source_set
        # On a mutation, mark the expression as exposed. This corresponds with
        # a similar check on is_mutation in _normalize_view_ptr_expr.
        if (source_scls.get_expr_type(ctx.env.schema)
                != s_types.ExprType.Select):
            subctx.expr_exposed = context.Exposure.EXPOSED

        comp_ir_set = dispatch.compile(qlexpr, ctx=subctx)

    # XXX: or should we update rptr in place??
    rptr = rptr.replace(expr=comp_ir_set.expr)
    comp_ir_set = new_set_from_set(
        comp_ir_set, path_id=path_id, expr=rptr, span=span,
        merge_current_ns=True,
        ctx=ctx)

    maybe_materialize(ptrcls, comp_ir_set, ctx=ctx)

    return comp_ir_set


def _get_schema_computed_ctx(
    *, rptr: irast.Pointer, source: irast.Set, ctx: context.ContextLevel
) -> Callable[[], ContextManager[context.ContextLevel]]:

    @contextlib.contextmanager
    def newctx() -> Iterator[context.ContextLevel]:
        with ctx.detached() as subctx:
            source_scope = pathctx.get_set_scope(rptr.source, ctx=ctx)
            if source_scope and source_scope.namespaces:
                subctx.path_id_namespace |= source_scope.namespaces

            # Get the type of the actual location where the computed pointer
            # was defined in the schema: that is the type that might be
            # *referenced in the definition*, so it is the one that must
            # be used in the view map.
            ptr = typegen.ptrcls_from_ptrref(rptr.ptrref, ctx=ctx)
            assert isinstance(ptr, s_pointers.Pointer)
            ptr = ptr.maybe_get_topmost_concrete_base(ctx.env.schema) or ptr
            src = ptr.get_source(ctx.env.schema)

            # If the source is an abstract pointer, then we don't have
            # a full path to bind in the computed. Otherwise use a
            # path derived from the pointer source.
            if not (
                isinstance(src, s_pointers.Pointer)
                and src.is_non_concrete(ctx.env.schema)
            ):
                inner_path_id = not_none(irast.PathId.from_pointer(
                    ctx.env.schema, ptr, namespace=subctx.path_id_namespace,
                    env=ctx.env,
                ).src_path())

                # XXX: THIS IS DODGY - wait, this is a no-op
                remapped_source = new_set_from_set(
                    rptr.source, expr=rptr.source.expr, ctx=ctx
                )
                update_view_map(inner_path_id, remapped_source, ctx=subctx)

            yield subctx

    return newctx


def update_view_map(
    path_id: irast.PathId,
    remapped_source: irast.Set,
    *,
    ctx: context.ContextLevel
) -> None:
    ctx.view_map = ctx.view_map.new_child()
    key = path_id.strip_namespace(path_id.namespace)
    old = ctx.view_map.get(key, ())
    ctx.view_map[key] = ((path_id, remapped_source),) + old


def get_view_map_remapping(
    path_id: irast.PathId, ctx: context.ContextLevel
) -> Optional[irast.Set]:
    """Perform path_id remapping based on outer views.

    This is a little fiddly, since we may have
    picked up *additional* namespaces.
    """
    key = path_id.strip_namespace(path_id.namespace)
    entries = ctx.view_map.get(key, ())
    fixed_path_id = path_id.merge_namespace(ctx.path_id_namespace, deep=True)
    for inner_path_id, mapped in entries:
        fixed_inner = inner_path_id.merge_namespace(
            ctx.path_id_namespace, deep=True)

        if fixed_inner == fixed_path_id:
            return mapped
    return None


def remap_path_id(
    path_id: irast.PathId, ctx: context.ContextLevel
) -> irast.PathId:
    """Remap a path_id based on the view_map, one step at a time.

    This is intended to mirror what happens to paths in compile_path.
    """
    new_id = None
    hit = False
    for prefix in path_id.iter_prefixes():
        if not new_id:
            new_id = prefix
        else:
            nrptr, dir = prefix.rptr(), prefix.rptr_dir()
            assert nrptr and dir
            new_id = new_id.extend(
                ptrref=nrptr, direction=dir, ns=prefix.namespace)

        if mapped := get_view_map_remapping(new_id, ctx):
            hit = True
            new_id = mapped.path_id

    assert new_id and (new_id == path_id or hit)
    return new_id


def _get_computable_ctx(
    *,
    rptr: irast.Pointer,
    source: irast.Set,
    source_scls: s_types.Type,
    inner_source_path_id: irast.PathId,
    path_id_ns: Optional[irast.Namespace],
    same_scope: bool,
    qlctx: context.ContextLevel,
    ctx: context.ContextLevel
) -> Callable[[], ContextManager[context.ContextLevel]]:

    @contextlib.contextmanager
    def newctx() -> Iterator[context.ContextLevel]:
        with ctx.new() as subctx:
            subctx.class_view_overrides = {}
            subctx.partial_path_prefix = None

            subctx.modaliases = qlctx.modaliases.copy()
            subctx.aliased_views = qlctx.aliased_views.new_child()

            subctx.view_nodes = qlctx.view_nodes.copy()
            subctx.view_map = ctx.view_map.new_child()
            subctx.view_sets = ctx.view_sets.copy()

            source_scope = pathctx.get_set_scope(rptr.source, ctx=ctx)
            if source_scope and source_scope.namespaces:
                subctx.path_id_namespace |= source_scope.namespaces

            if path_id_ns is not None:
                subctx.path_id_namespace |= {path_id_ns}

            pending_pid_ns = {ctx.aliases.get('ns')}

            if path_id_ns is not None and same_scope:
                pending_pid_ns.add(path_id_ns)

            subctx.pending_stmt_own_path_id_namespace = (
                frozenset(pending_pid_ns))

            subns = set(pending_pid_ns)
            subns.add(ctx.aliases.get('ns'))

            # Include the namespace from the source in the namespace
            # we compile under. This helps make sure the remapping
            # lines up.
            subns |= qlctx.path_id_namespace

            subctx.pending_stmt_full_path_id_namespace = frozenset(subns)

            # If one of the sources present at the definition site is still
            # visible, make sure to hang on to the remapping.
            for entry in qlctx.view_map.values():
                for map_path_id, remapped in entry:
                    if subctx.path_scope.is_visible(map_path_id):
                        update_view_map(map_path_id, remapped, ctx=subctx)

            inner_path_id = inner_source_path_id.merge_namespace(subns)
            with subctx.new() as remapctx:
                remapctx.path_id_namespace |= subns
                # We need to run the inner_path_id through the same
                # remapping process that happens in compile_path, or
                # else the path id won't match, since the prefix will
                # get remapped first.
                inner_path_id = remap_path_id(inner_path_id, remapctx)

            # XXX: THIS IS DODGY - wait, this is a no-op
            remapped_source = new_set_from_set(
                rptr.source, expr=rptr.source.expr, ctx=ctx)
            update_view_map(inner_path_id, remapped_source, ctx=subctx)

            yield subctx

    return newctx


def maybe_materialize(
    stype: s_types.Type | s_pointers.PointerLike,
    ir: irast.Set,
    *,
    ctx: context.ContextLevel,
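) -> None:
    """Record `ir` as a use of a pending materialized set, if any.

    Looks up `stype` in `ctx.env.materialized_sets`, walking up to the
    first base of a pointer when there is no direct entry, and returns
    without doing anything if no entry is found.
    """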
    if isinstance(stype, s_pointers.PseudoPointer):
        return

    # Search for a materialized_sets entry
    while True:
        if mat_entry := ctx.env.materialized_sets.get(stype):
            break
        # Search up for parent pointers, if applicable
        if not isinstance(stype, s_pointers.Pointer):
            return
        bases = stype.get_bases(ctx.env.schema).objects(ctx.env.schema)
        if not bases:
            return
        stype = bases[0]

    # We've found an entry, populate it.
    mat_qlstmt, reason = mat_entry
    materialize_in_stmt = ctx.env.compiled_stmts[mat_qlstmt]
    if materialize_in_stmt.materialized_sets is None:
        materialize_in_stmt.materialized_sets = {}

    assert not isinstance(stype, s_pointers.PseudoPointer)
    if stype.id not in materialize_in_stmt.materialized_sets:
        materialize_in_stmt.materialized_sets[stype.id] = (
            irast.MaterializedSet(
                materialized=ir, reason=reason, use_sets=[]))

    mat_set = materialize_in_stmt.materialized_sets[stype.id]
    mat_set.use_sets.append(ir)


def should_materialize(
    ir: irast.Base,
    *,
    ptrcls: Optional[s_pointers.Pointer] = None,
    materialize_visible: bool = False,
    skipped_bindings: AbstractSet[irast.PathId] = frozenset(),
    ctx: context.ContextLevel,
) -> Sequence[irast.MaterializeReason]:
    volatility = inference.infer_volatility(ir, ctx.env, exclude_dml=True)
    reasons: list[irast.MaterializeReason] = []

    if volatility.is_volatile():
        reasons.append(irast.MaterializeVolatile())

    if not isinstance(ir, irast.Set):
        return reasons

    typ = get_set_type(ir, ctx=ctx)

    assert ir.path_scope_id is not None

    # For shape elements, we need to materialize when they reference
    # bindings that are visible from that point. This means that doing
    # WITH/FOR bindings internally is fine, but referring to
    # externally bound things needs materialization. We can't actually
    # do this visibility analysis until we are done, though, so
    # instead we just store the bindings.
    if (
        materialize_visible
        and (vis := irutils.find_potentially_visible(
            ir,
            ctx.env.scope_tree_nodes[ir.path_scope_id],
            ctx.env.scope_tree_nodes, skipped_bindings))
    ):
        reasons.append(irast.MaterializeVisible(
            sets=vis, path_scope_id=ir.path_scope_id))

    if ptrcls and ptrcls in ctx.env.source_map:
        reasons += ctx.env.source_map[ptrcls].should_materialize

    for r in should_materialize_type(typ, ctx=ctx):
        # Rewrite visibility reasons from the typ to reflect this,
        # the real bind point.
        if isinstance(r, irast.MaterializeVolatile):
            reasons.append(r)
        else:
            reasons.append(
                irast.MaterializeVisible(
                    sets=r.sets, path_scope_id=ir.path_scope_id))

    return reasons


def should_materialize_type(
    typ: s_types.Type, *, ctx: context.ContextLevel
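) -> list[irast.MaterializeReason]:
    """Collect the materialization reasons recorded for pointers of `typ`.

    Collection types are handled by recursing into their subtypes.
    """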
    schema = ctx.env.schema
    reasons: list[irast.MaterializeReason] = []
    if isinstance(
            typ, (s_objtypes.ObjectType, s_pointers.Pointer)):
        for pointer in typ.get_pointers(schema).objects(schema):
            if pointer in ctx.env.source_map:
                reasons += ctx.env.source_map[pointer].should_materialize
    elif isinstance(typ, s_types.Collection):
        for sub in typ.get_subtypes(schema):
            reasons += should_materialize_type(sub, ctx=ctx)

    return reasons


def get_global_param(
    glob: s_globals.Global | s_permissions.Permission,
    *,
    ctx: context.ContextLevel
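) -> irast.Global:
    """Return the `irast.Global` describing a global or permission.

    The entry is created and cached in `ctx.env.query_globals` on
    first use.
    """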
    name = glob.get_name(ctx.env.schema)

    if name not in ctx.env.query_globals:
        param_name = f'__edb_global_{len(ctx.env.query_globals)}__'

        if isinstance(glob, s_globals.Global):
            # Globals
            target = glob.get_target(ctx.env.schema)
            target_typeref = typegen.type_to_typeref(target, env=ctx.env)

            ctx.env.query_globals[name] = irast.Global(
                name=param_name,
                required=False,
                schema_type=target,
                ir_type=target_typeref,
                global_name=name,
                has_present_arg=glob.needs_present_arg(ctx.env.schema),
                is_permission=False,
            )

        else:
            # Permissions
            target = ctx.env.schema.get('std::bool', type=s_types.Type)
            target_typeref = typegen.type_to_typeref(target, env=ctx.env)

            ctx.env.query_globals[name] = irast.Global(
                name=param_name,
                required=True,
                schema_type=target,
                ir_type=target_typeref,
                global_name=name,
                has_present_arg=False,
                is_permission=True,
            )

    return ctx.env.query_globals[name]


def get_global_param_sets(
    glob: s_globals.Global | s_permissions.Permission,
    *,
    ctx: context.ContextLevel,
    is_implicit_global: bool = False,
) -> tuple[irast.Set, Optional[irast.Set]]:
    param = get_global_param(glob, ctx=ctx)
    default = (
        glob.get_default(ctx.env.schema)
        if isinstance(glob, s_globals.Global) else
        None
    )

    # This function is called to compile either a global expr or the global
    # params for a function call. Both are compiled as QueryParameter.
    assert ctx.env.options.func_params is None

    param_set = ensure_set(
        irast.QueryParameter(
            name=param.name,
            required=param.required and not bool(default),
            typeref=param.ir_type,
            is_implicit_global=is_implicit_global,
        ),
        ctx=ctx,
    )

    if (
        isinstance(glob, s_globals.Global)
        and glob.needs_present_arg(ctx.env.schema)
    ):
        present_set = ensure_set(
            irast.QueryParameter(
                name=param.name + "present__",
                required=True,
                typeref=typegen.type_to_typeref(
                    ctx.env.schema.get('std::bool', type=s_types.Type),
                    env=ctx.env,
                ),
                is_implicit_global=is_implicit_global,
            ),
            ctx=ctx,
        )
    else:
        present_set = None

    return param_set, present_set


def get_func_global_json_arg(*, ctx: context.ContextLevel) -> irast.Set:
    json_type = ctx.env.schema.get('std::json', type=s_types.Type)
    json_typeref = typegen.type_to_typeref(json_type, env=ctx.env)
    name = '__edb_json_globals__'

    is_func_param = ctx.env.options.func_params is not None
    parameter_type = (
        irast.FunctionParameter if is_func_param else irast.QueryParameter
    )

    # If this is because we have json params, not because we're in a
    # function, we need to register it.
    if ctx.env.options.json_parameters:
        qname = s_name.QualName('__', name)
        ctx.env.query_globals[qname] = irast.Global(
            name=name,
            required=False,
            schema_type=json_type,
            ir_type=json_typeref,
            global_name=qname,
            has_present_arg=False,
            is_permission=False,
        )

    return ensure_set(
        parameter_type(
            name=name,
            required=True,
            typeref=json_typeref,
        ),
        ctx=ctx,
    )


def get_func_global_param_sets(
    glob: s_globals.Global | s_permissions.Permission,
    *,
    ctx: context.ContextLevel,
) -> tuple[qlast.Expr, Optional[qlast.Expr]]:
    # NB: updates ctx anchors

    # Make sure that we properly track the globals we use in functions
    get_global_param(glob, ctx=ctx)

    with ctx.new() as subctx:
        name = str(glob.get_name(ctx.env.schema))

        glob_set = get_func_global_json_arg(ctx=ctx)
        glob_anchor = qlast.FunctionCall(
            func=('__std__', 'json_get'),
            args=[
                subctx.create_anchor(glob_set, 'a'),
                qlast.Constant.string(value=str(name)),
            ],
        )

        if isinstance(glob, s_globals.Global):
            target = glob.get_target(ctx.env.schema)

        else:
            # Permissions
            target = ctx.env.schema.get('std::bool', type=s_types.Type)

        type = typegen.type_to_ql_typeref(target, ctx=ctx)
        main_set = qlast.TypeCast(expr=glob_anchor, type=type)

        if (
            isinstance(glob, s_globals.Global)
            and glob.needs_present_arg(ctx.env.schema)
        ):
            present_set = qlast.UnaryOp(
                op='EXISTS',
                operand=glob_anchor,
            )
        else:
            present_set = None

    return main_set, present_set


def get_globals_as_json(
    globs: Sequence[s_globals.Global | s_permissions.Permission],
    *,
    ctx: context.ContextLevel,
    span: Optional[qlast.Span],
) -> irast.Set:
    """Build a json object that contains the values of `globs`

    The format of the object is simply
       {"<glob_name_1>": <glob_val_1>, ...},
    where values that are unset or set to {} are represented as null,
    with one catch:
       for globals that need "present" arguments (that is, optional globals
       with default values), we need to distinguish between the global
       being unset and being set to {}. In that case, we represent being
       set to {} with null and being unset by omitting it from the object.
    """
    # TODO: arrange to compute this once per query, in a CTE or some such?

    # If globals are empty, arrange to still pass in the argument but
    # don't put anything in it.
    if ctx.env.options.make_globals_empty:
        globs = ()

    objctx = ctx.env.options.schema_object_context
    is_constraint_like = objctx in (s_constr.Constraint, s_indexes.Index)
    if globs and is_constraint_like:
        assert objctx
        typname = objctx.get_schema_class_displayname()
        # XXX: or should we pass in empty globals, in this situation?
        raise errors.SchemaDefinitionError(
            f'functions that reference global variables cannot be called '
            f'from {typname}',
            span=span)

    null_expr = qlast.FunctionCall(
        func=('__std__', 'to_json'),
        args=[qlast.Constant.string(value="null")],
    )

    with ctx.new() as subctx:
        subctx.anchors = subctx.anchors.copy()
        normal_els = []
        full_objs: list[qlast.Expr] = []

        json_type = qlast.TypeName(maintype=qlast.ObjectRef(
            module='__std__', name='json'))

        for glob in globs:
            param, present = get_global_param_sets(
                glob, is_implicit_global=True, ctx=ctx)
            # The name of the global isn't syntactically a valid identifier
            # for a namedtuple element but nobody can stop us!
            name = str(glob.get_name(ctx.env.schema))

            main_param = subctx.create_anchor(param, 'a')
            tuple_el = qlast.TupleElement(
                name=qlast.Ptr(name=name),
                val=qlast.BinOp(
                    op='??',
                    left=qlast.TypeCast(expr=main_param, type=json_type),
                    right=null_expr,
                )
            )

            if not present:
                # For normal globals, just stick the element in the tuple.
                normal_els.append(tuple_el)
            else:
                # For globals with a present arg, we conditionally
                # construct a one-element object if it is present
                # and an empty object if it is not. These are
                # combined using ++.
                present_param = subctx.create_anchor(present, 'a')
                tup = qlast.TypeCast(
                    expr=qlast.NamedTuple(elements=[tuple_el]),
                    type=json_type,
                )

                full_objs.append(qlast.IfElse(
                    condition=present_param,
                    if_expr=tup,
                    else_expr=qlast.FunctionCall(
                        func=('__std__', 'to_json'),
                        args=[qlast.Constant.string(value="{}")],
                    )
                ))

        # If access policies are disabled, stick a value in the blob
        # to indicate that.  We do this using a full object so it
        # works in constraints and the like, where the tuple->json cast
        # isn't supported yet.
        if (
            not ctx.env.options.apply_user_access_policies
            or not ctx.env.options.apply_query_rewrites
        ) and not is_constraint_like:
            full_objs.append(qlast.FunctionCall(
                func=('__std__', 'to_json'),
                args=[qlast.Constant.string(
                    value='{"__disable_access_policies": true}'
                )],
            ))

        full_expr: qlast.Expr
        if not normal_els and not full_objs:
            full_expr = null_expr
        else:
            simple_obj = None
            if normal_els or not full_objs:
                simple_obj = qlast.TypeCast(
                    expr=qlast.NamedTuple(elements=normal_els),
                    type=json_type,
                )

            full_expr = astutils.extend_binop(simple_obj, *full_objs, op='++')

        return dispatch.compile(full_expr, ctx=subctx)


================================================
FILE: edb/edgeql/compiler/stmt.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""EdgeQL statement compilation routines."""


from __future__ import annotations
from typing import (
    Any,
    Optional,
    Sequence,
    cast
)

from collections import defaultdict
import textwrap
import itertools

from edb import errors
from edb.common import ast
from edb.common import span as edb_span
from edb.common.typeutils import not_none

from edb.ir import ast as irast
from edb.ir import typeutils
from edb.ir import utils as irutils

from edb.schema import ddl as s_ddl
from edb.schema import functions as s_func
from edb.schema import links as s_links
from edb.schema import properties as s_props
from edb.schema import modules as s_mod
from edb.schema import name as s_name
from edb.schema import objects as s_obj
from edb.schema import objtypes as s_objtypes
from edb.schema import pointers as s_pointers
from edb.schema import pseudo as s_pseudo
from edb.schema import schema as s_schema
from edb.schema import types as s_types
from edb.schema import utils as s_utils

from edb.edgeql import ast as qlast
from edb.edgeql import utils as qlutils
from edb.edgeql import qltypes
from edb.edgeql import desugar_group

from . import astutils
from . import clauses
from . import context
from . import config_desc
from . import dispatch
from . import inference
from . import pathctx
from . import policies
from . import setgen
from . import viewgen
from . import schemactx
from . import stmtctx
from . import typegen
from . import conflicts


def try_desugar(
    expr: qlast.Query, *, ctx: context.ContextLevel
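) -> Optional[irast.Set]:
    """Compile `expr` through the GROUP desugaring, if it applies.

    Returns None when `desugar_group.try_group_rewrite` produces no
    rewrite.
    """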
    new_syntax = desugar_group.try_group_rewrite(expr, aliases=ctx.aliases)
    if new_syntax:
        return dispatch.compile(new_syntax, ctx=ctx)
    return None


def _protect_expr(
    expr: Optional[qlast.Expr], *, ctx: context.ContextLevel
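) -> None:
    """Allow factoring for `expr` when `ctx.no_factoring` is set.

    Looks through any enclosing shapes, marking each of them, and then
    marks the underlying path if there is one.
    """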
    if ctx.no_factoring:
        while isinstance(expr, qlast.Shape):
            expr.allow_factoring = True
            expr = expr.expr
        if isinstance(expr, qlast.Path):
            expr.allow_factoring = True


@dispatch.compile.register(qlast.SelectQuery)
def compile_SelectQuery(
    expr: qlast.SelectQuery, *, ctx: context.ContextLevel
) -> irast.Set:
    if rewritten := try_desugar(expr, ctx=ctx):
        return rewritten

    _protect_expr(expr.result, ctx=ctx)

    with ctx.subquery() as sctx:
        stmt = irast.SelectStmt()
        init_stmt(stmt, expr, ctx=sctx, parent_ctx=ctx)
        if expr.implicit:
            # Make sure path prefix does not get blown away by
            # implicit subqueries.
            sctx.partial_path_prefix = ctx.partial_path_prefix
            stmt.implicit_wrapper = True

        # If there is an offset or a limit, this query was a wrapper
        # around something else, and we need to forward_rptr

        forward_rptr = (
            bool(expr.offset)
            or bool(expr.limit)
            or expr.rptr_passthrough
            # We need to preserve view_rptr if this SELECT is just
            # an implicit wrapping of a single DISTINCT, because otherwise
            # using a DISTINCT to satisfy link multiplicity requirement
            # will kill the link properties.
            #
            # This includes problems with initializing the schema itself.
            or (
                isinstance(expr.result, qlast.UnaryOp)
                and expr.result.op == 'DISTINCT'
            )
            or (
                isinstance(expr.result, qlast.FunctionCall)
                and expr.result.func in (
                    'assert_distinct', 'assert_single', 'assert_exists')
            )
        )

        stmt.result = compile_result_clause(
            expr.result,
            view_scls=ctx.view_scls,
            view_rptr=ctx.view_rptr,
            result_alias=expr.result_alias,
            view_name=ctx.toplevel_result_view_name,
            forward_rptr=forward_rptr,
            ctx=sctx)

        stmt.where = clauses.compile_where_clause(expr.where, ctx=sctx)

        stmt.orderby = clauses.compile_orderby_clause(expr.orderby, ctx=sctx)

        stmt.offset = clauses.compile_limit_offset_clause(
            expr.offset, ctx=sctx)

        stmt.limit = clauses.compile_limit_offset_clause(
            expr.limit, ctx=sctx)

        result = fini_stmt(stmt, ctx=sctx, parent_ctx=ctx)

    return result


@dispatch.compile.register(qlast.ForQuery)
def compile_ForQuery(
    qlstmt: qlast.ForQuery, *, ctx: context.ContextLevel
) -> irast.Set:
    if rewritten := try_desugar(qlstmt, ctx=ctx):
        return rewritten

    with ctx.subquery() as sctx:
        stmt = irast.SelectStmt(span=qlstmt.span)
        init_stmt(stmt, qlstmt, ctx=sctx, parent_ctx=ctx)

        # As an optimization, if the iterator is a singleton set, use
        # the element directly.
        iterator = qlstmt.iterator
        if isinstance(iterator, qlast.Set) and len(iterator.elements) == 1:
            iterator = iterator.elements[0]

        contains_dml = astutils.contains_dml(qlstmt.result, ctx=ctx)

        with sctx.new() as ectx:
            if ectx.expr_exposed:
                ectx.expr_exposed = context.Exposure.BINDING
            iterator_view = stmtctx.declare_view(
                astutils.ensure_ql_select(iterator),
                s_name.UnqualName(qlstmt.iterator_alias),
                factoring_fence=contains_dml,
                path_id_namespace=sctx.path_id_namespace,
                binding_kind=irast.BindingKind.For,
                ctx=ectx,
            )

        iterator_stmt = setgen.new_set_from_set(iterator_view, ctx=sctx)
        iterator_view.is_visible_binding_ref = True
        stmt.iterator_stmt = iterator_stmt

        iterator_type = setgen.get_set_type(iterator_stmt, ctx=ctx)
        if iterator_type.is_any(ctx.env.schema):
            raise errors.QueryError(
                'FOR statement has iterator of indeterminate type',
                span=ctx.env.type_origins.get(iterator_type),
            )

        view_scope_info = sctx.env.path_scope_map[iterator_view]

        pathctx.register_set_in_scope(
            iterator_stmt,
            path_scope=sctx.path_scope,
            optional=qlstmt.optional,
            ctx=sctx,
        )

        sctx.iterator_path_ids |= {stmt.iterator_stmt.path_id}
        node = sctx.path_scope.find_descendant(iterator_stmt.path_id)
        if node is not None:
            # If the body contains DML, then we need to prohibit
            # correlation between the iterator and the enclosing
            # query, since the correlation imposes compilation issues
            # we aren't willing to tackle.
            #
            # Do this by sticking the iterator subtree onto a branch
            # with a factoring fence.
            if contains_dml:
                node = node.attach_branch()
                node.factoring_fence = True
                node.factoring_allowlist.update(ctx.iterator_path_ids)
                node = node.attach_branch()

            node.attach_subtree(
                view_scope_info.path_scope,
                span=iterator.span,
                ctx=ctx,
            )

        # Compile the body
        with sctx.newscope(fenced=True) as bctx:
            stmt.result = setgen.scoped_set(
                compile_result_clause(
                    # Make sure it is a stmt, so that shapes inside the body
                    # get resolved there.
                    astutils.ensure_ql_query(qlstmt.result),
                    view_scls=ctx.view_scls,
                    view_rptr=ctx.view_rptr,
                    view_name=ctx.toplevel_result_view_name,
                    forward_rptr=True,
                    ctx=bctx,
                ),
                ctx=bctx,
            )

        # Inject an implicit limit if appropriate
        if ((ctx.expr_exposed or sctx.stmt is ctx.toplevel_stmt)
                and ctx.implicit_limit):
            stmt.limit = dispatch.compile(
                qlast.Constant.integer(ctx.implicit_limit),
                ctx=sctx,
            )

        result = fini_stmt(stmt, ctx=sctx, parent_ctx=ctx)

    return result


def _make_group_binding(
    stype: s_types.Type,
    alias: str,
    *,
    ctx: context.ContextLevel,
) -> irast.Set:
    """Make a binding for one of the "dummy" bindings used in group"""
    binding_type = schemactx.derive_view(
        stype,
        derived_name=s_name.QualName('__derived__', alias),
        preserve_shape=True, ctx=ctx)

    binding_set = setgen.class_set(binding_type, ctx=ctx)
    binding_set.is_visible_binding_ref = True

    name = s_name.UnqualName(alias)
    ctx.aliased_views[name] = binding_set
    ctx.view_sets[binding_type] = binding_set
    ctx.env.path_scope_map[binding_set] = context.ScopeInfo(
        path_scope=ctx.path_scope,
        binding_kind=irast.BindingKind.For,
        pinned_path_id_ns=ctx.path_id_namespace,
    )

    return binding_set


@dispatch.compile.register(qlast.InternalGroupQuery)
def compile_InternalGroupQuery(
    expr: qlast.InternalGroupQuery, *, ctx: context.ContextLevel
) -> irast.Set:
    # We disallow use of FOR GROUP except for when running in test mode.
    if not expr.from_desugaring and not ctx.env.options.testmode:
        raise errors.UnsupportedFeatureError(
            "'FOR GROUP' is an internal testing feature",
            span=expr.span,
        )

    _protect_expr(expr.subject, ctx=ctx)
    _protect_expr(expr.result, ctx=ctx)

    with ctx.subquery() as sctx:
        stmt = irast.GroupStmt(by=expr.by)
        init_stmt(stmt, expr, ctx=sctx, parent_ctx=ctx)

        with sctx.newscope(fenced=True) as topctx:
            # N.B: Subject is exposed because we want any shape on the
            # subject to be exposed on bare references to the group
            # alias.  This is frankly pretty dodgy behavior for
            # FOR GROUP to have but the real GROUP needs to
            # maintain shapes, and this is the easiest way to handle
            # that.
            stmt.subject = compile_result_clause(
                expr.subject,
                result_alias=expr.subject_alias,
                exprtype=s_types.ExprType.Group,
                ctx=topctx)

            if topctx.partial_path_prefix:
                pathctx.register_set_in_scope(
                    topctx.partial_path_prefix, ctx=topctx)

            # compile the USING
            assert expr.using is not None

            for using_entry in expr.using:
                # Fail on keys named 'id', since we can't put them
                # in the output free object.
                if using_entry.alias == 'id':
                    raise errors.UnsupportedFeatureError(
                        "may not name a grouping alias 'id'",
                        span=using_entry.span,
                    )
                elif desugar_group.key_name(using_entry.alias) == 'id':
                    raise errors.UnsupportedFeatureError(
                        "may not group by a field named id",
                        span=using_entry.expr.span,
                        hint="try 'using id_ := .id'",
                    )

                with topctx.newscope(fenced=True) as scopectx:
                    if scopectx.expr_exposed:
                        scopectx.expr_exposed = context.Exposure.BINDING
                    binding = stmtctx.declare_view(
                        using_entry.expr,
                        s_name.UnqualName(using_entry.alias),
                        binding_kind=irast.BindingKind.With,
                        path_id_namespace=scopectx.path_id_namespace,
                        ctx=scopectx,
                    )
                    binding.span = using_entry.expr.span
                    stmt.using[using_entry.alias] = (
                        setgen.new_set_from_set(binding, ctx=sctx),
                        qltypes.Cardinality.UNKNOWN)
                    binding.is_visible_binding_ref = True

            subject_stype = setgen.get_set_type(stmt.subject, ctx=topctx)
            stmt.group_binding = _make_group_binding(
                subject_stype, expr.group_alias, ctx=topctx)

            # # Compile the shape on the group binding, in case we need it
            # viewgen.late_compile_view_shapes(stmt.group_binding, ctx=topctx)

            if expr.grouping_alias:
                ctx.env.schema, grouping_stype = s_types.Array.create(
                    ctx.env.schema,
                    element_type=(
                        ctx.env.schema.get('std::str', type=s_types.Type)
                    )
                )
                stmt.grouping_binding = _make_group_binding(
                    grouping_stype, expr.grouping_alias, ctx=topctx)

        # Check that the by clause is legit
        by_refs = ast.find_children(stmt.by, qlast.ObjectRef)
        for by_ref in by_refs:
            if by_ref.name not in stmt.using:
                raise errors.InvalidReferenceError(
                    f"variable '{by_ref.name}' referenced in BY but not "
                    f"declared in USING",
                    span=by_ref.span,
                )

        # compile the output
        # newscope because we don't want the result to get assigned the
        # same statement scope as the subject and elements, which we
        # need to stick in the real GROUP BY
        with sctx.newscope(fenced=True) as bctx:
            pathctx.register_set_in_scope(
                stmt.group_binding, path_scope=bctx.path_scope, ctx=bctx
            )

            # Compile the shape on the group binding, in case we need it
            viewgen.late_compile_view_shapes(stmt.group_binding, ctx=bctx)

            node = bctx.path_scope.find_descendant(stmt.group_binding.path_id)
            not_none(node).is_group = True
            for using_value, _ in stmt.using.values():
                pathctx.register_set_in_scope(
                    using_value, path_scope=bctx.path_scope, ctx=bctx
                )

            if stmt.grouping_binding:
                pathctx.register_set_in_scope(
                    stmt.grouping_binding, path_scope=bctx.path_scope, ctx=bctx
                )

            stmt.result = compile_result_clause(
                astutils.ensure_ql_query(expr.result),
                result_alias=expr.result_alias,
                ctx=bctx)

            stmt.where = clauses.compile_where_clause(expr.where, ctx=bctx)

            stmt.orderby = clauses.compile_orderby_clause(
                expr.orderby, ctx=bctx)

        result = fini_stmt(stmt, ctx=sctx, parent_ctx=ctx)

    return result


@dispatch.compile.register(qlast.GroupQuery)
def compile_GroupQuery(
    expr: qlast.GroupQuery, *, ctx: context.ContextLevel
) -> irast.Set:
    return dispatch.compile(
        desugar_group.desugar_group(expr, ctx.aliases),
        ctx=ctx,
    )


@dispatch.compile.register(qlast.InsertQuery)
def compile_InsertQuery(
    expr: qlast.InsertQuery, *, ctx: context.ContextLevel
) -> irast.Set:

    if ctx.disallow_dml:
        raise errors.QueryError(
            f'INSERT statements cannot be used {ctx.disallow_dml}',
            hint=(
                'To resolve this try to factor out the mutation '
                'expression into the top-level WITH block.'
            ),
            span=expr.span,
        )

    # Record this node in the list of potential DML expressions.
    ctx.env.dml_exprs.append(expr)

    with ctx.subquery() as ictx:
        stmt = irast.InsertStmt(span=expr.span)
        init_stmt(stmt, expr, ctx=ictx, parent_ctx=ctx)

        with ictx.new() as ectx:
            ectx.expr_exposed = context.Exposure.UNEXPOSED
            subject = dispatch.compile(
                qlast.Path(steps=[expr.subject], allow_factoring=True), ctx=ectx
            )
        assert isinstance(subject, irast.Set)

        subject_stype = setgen.get_set_type(subject, ctx=ictx)

        # If we are INSERTing a type that we are in the ELSE block of,
        # we need to error out.
        if ictx.inserting_paths.get(subject.path_id) == 'else':
            setgen.raise_self_insert_error(
                subject_stype, expr.subject.span, ctx=ctx)

        if subject_stype.get_abstract(ctx.env.schema):
            raise errors.QueryError(
                f'cannot insert into abstract '
                f'{subject_stype.get_verbosename(ctx.env.schema)}',
                span=expr.subject.span)

        if subject_stype.is_free_object_type(ctx.env.schema):
            raise errors.QueryError(
                'free objects cannot be inserted',
                span=expr.subject.span)

        if subject_stype.is_view(ctx.env.schema):
            raise errors.QueryError(
                f'cannot insert into expression alias '
                f'{str(subject_stype.get_shortname(ctx.env.schema))!r}',
                span=expr.subject.span)

        if _is_forbidden_stdlib_type_for_mod(subject_stype, ctx):
            raise errors.QueryError(
                f'cannot insert standard library type '
                f'{subject_stype.get_displayname(ctx.env.schema)}',
                span=expr.subject.span)

        with ictx.new() as bodyctx:
            # Self-references in INSERT are prohibited.
            pathctx.ban_inserting_path(
                subject.path_id, location='body', ctx=bodyctx)

            bodyctx.class_view_overrides = ictx.class_view_overrides.copy()
            bodyctx.implicit_id_in_shapes = False
            bodyctx.implicit_tid_in_shapes = False
            bodyctx.implicit_tname_in_shapes = False
            bodyctx.implicit_limit = 0

            stmt.subject = compile_query_subject(
                subject,
                shape=expr.shape,
                view_rptr=ctx.view_rptr,
                compile_views=True,
                exprtype=s_types.ExprType.Insert,
                ctx=bodyctx,
                span=expr.span,
            )

        stmt_subject_stype = setgen.get_set_type(subject, ctx=ictx)
        assert isinstance(stmt_subject_stype, s_objtypes.ObjectType)

        stmt.conflict_checks = conflicts.compile_inheritance_conflict_checks(
            stmt, stmt_subject_stype, ctx=ictx)

        if expr.unless_conflict is not None:
            constraint_spec, else_branch = expr.unless_conflict

            if constraint_spec:
                stmt.on_conflict = conflicts.compile_insert_unless_conflict_on(
                    stmt, stmt_subject_stype, constraint_spec, else_branch,
                    ctx=ictx)
            else:
                stmt.on_conflict = conflicts.compile_insert_unless_conflict(
                    stmt, stmt_subject_stype, ctx=ictx)

        conflicts.check_for_isolation_conflicts(
            stmt, stmt_subject_stype, ctx=ictx)

        mat_stype = schemactx.get_material_type(stmt_subject_stype, ctx=ctx)
        result = setgen.class_set(
            mat_stype, path_id=stmt.subject.path_id, ctx=ctx
        )

        with ictx.new() as resultctx:
            stmt.result = compile_query_subject(
                result,
                view_scls=ctx.view_scls,
                view_name=ctx.toplevel_result_view_name,
                compile_views=ictx.stmt is ictx.toplevel_stmt,
                ctx=resultctx,
                span=expr.span,
            )

        if pol_condition := policies.compile_dml_write_policies(
            mat_stype, result, mode=qltypes.AccessKind.Insert, ctx=ictx
        ):
            stmt.write_policies[mat_stype.id] = pol_condition

        # Compute the unioned output type if needed
        if stmt.on_conflict and stmt.on_conflict.else_ir:
            final_typ = typegen.infer_common_type(
                [stmt.result, stmt.on_conflict.else_ir], ctx.env)
            if final_typ is None:
                raise errors.QueryError('could not determine INSERT type',
                                        span=stmt.span)
            stmt.final_typeref = typegen.type_to_typeref(final_typ, env=ctx.env)

        # Wrap the statement.
        result = fini_stmt(stmt, ctx=ictx, parent_ctx=ctx)

        # If we have an ELSE clause, and this is a toplevel statement,
        # we need to compile_query_subject *again* on the outer query,
        # in order to produce a view for the joined output, which we
        # need to have to generate the proper type descriptor.  This
        # feels like somewhat of a hack; I think it might be possible
        # to do something more general elsewhere.
        if (
            expr.unless_conflict
            and expr.unless_conflict[1]
            and ictx.stmt is ctx.toplevel_stmt
        ):
            with ictx.new() as resultctx:
                resultctx.expr_exposed = context.Exposure.EXPOSED
                result = compile_query_subject(
                    result,
                    view_name=ctx.toplevel_result_view_name,
                    compile_views=ictx.stmt is ctx.toplevel_stmt,
                    ctx=resultctx,
                    span=result.span,
                )

    return result


@dispatch.compile.register(qlast.UpdateQuery)
def compile_UpdateQuery(
    expr: qlast.UpdateQuery, *, ctx: context.ContextLevel
) -> irast.Set:

    if ctx.disallow_dml:
        raise errors.QueryError(
            f'UPDATE statements cannot be used {ctx.disallow_dml}',
            hint=(
                'To resolve this try to factor out the mutation '
                'expression into the top-level WITH block.'
            ),
            span=expr.span,
        )

    _protect_expr(expr.subject, ctx=ctx)

    # Record this node in the list of potential DML expressions.
    ctx.env.dml_exprs.append(expr)

    with ctx.subquery() as ictx:
        stmt = irast.UpdateStmt(
            span=expr.span,
        )
        init_stmt(stmt, expr, ctx=ictx, parent_ctx=ctx)

        with ictx.new() as ectx:
            ectx.expr_exposed = context.Exposure.UNEXPOSED
            subject = dispatch.compile(expr.subject, ctx=ectx)
        assert isinstance(subject, irast.Set)

        subj_type = setgen.get_set_type(subject, ctx=ictx)
        if not isinstance(subj_type, s_objtypes.ObjectType):
            raise errors.QueryError(
                'cannot update non-ObjectType objects',
                span=expr.subject.span
            )

        if subj_type.is_free_object_type(ctx.env.schema):
            raise errors.QueryError(
                'free objects cannot be updated',
                span=expr.subject.span)

        mat_stype = schemactx.concretify(subj_type, ctx=ctx)

        if _is_forbidden_stdlib_type_for_mod(mat_stype, ctx):
            raise errors.QueryError(
                f'cannot update standard library type '
                f'{subj_type.get_displayname(ctx.env.schema)}',
                span=expr.subject.span)

        stmt._material_type = typeutils.type_to_typeref(
            ctx.env.schema,
            mat_stype,
            include_children=True,
            include_ancestors=True,
            cache=ctx.env.type_ref_cache,
        )

        ictx.partial_path_prefix = subject

        stmt.where = clauses.compile_where_clause(expr.where, ctx=ictx)

        with ictx.new() as bodyctx:
            bodyctx.class_view_overrides = ictx.class_view_overrides.copy()
            bodyctx.implicit_id_in_shapes = False
            bodyctx.implicit_tid_in_shapes = False
            bodyctx.implicit_tname_in_shapes = False
            bodyctx.implicit_limit = 0

            stmt.subject = compile_query_subject(
                subject,
                shape=expr.shape,
                view_rptr=ctx.view_rptr,
                compile_views=True,
                exprtype=s_types.ExprType.Update,
                ctx=bodyctx,
                span=expr.span,
            )

        result = setgen.class_set(
            mat_stype, path_id=stmt.subject.path_id, ctx=ctx,
        )

        with ictx.new() as resultctx:
            stmt.result = compile_query_subject(
                result,
                view_scls=ctx.view_scls,
                view_name=ctx.toplevel_result_view_name,
                compile_views=ictx.stmt is ictx.toplevel_stmt,
                ctx=resultctx,
                span=expr.span,
            )

        for dtype in schemactx.get_all_concrete(mat_stype, ctx=ctx):
            if read_pol := policies.compile_dml_read_policies(
                dtype, result, mode=qltypes.AccessKind.UpdateRead, ctx=ictx
            ):
                stmt.read_policies[dtype.id] = read_pol
            if write_pol := policies.compile_dml_write_policies(
                dtype, result, mode=qltypes.AccessKind.UpdateWrite, ctx=ictx
            ):
                stmt.write_policies[dtype.id] = write_pol

            conflicts.check_for_isolation_conflicts(
                stmt, dtype, mat_stype, ctx=ictx)

        stmt.conflict_checks = conflicts.compile_inheritance_conflict_checks(
            stmt, mat_stype, ctx=ictx)

        result = fini_stmt(stmt, ctx=ictx, parent_ctx=ctx)

    return result


@dispatch.compile.register(qlast.DeleteQuery)
def compile_DeleteQuery(
    expr: qlast.DeleteQuery, *, ctx: context.ContextLevel
) -> irast.Set:

    if ctx.disallow_dml:
        raise errors.QueryError(
            f'DELETE statements cannot be used {ctx.disallow_dml}',
            hint=(
                'To resolve this try to factor out the mutation '
                'expression into the top-level WITH block.'
            ),
            span=expr.span,
        )

    _protect_expr(expr.subject, ctx=ctx)

    # Record this node in the list of potential DML expressions.
    ctx.env.dml_exprs.append(expr)

    with ctx.subquery() as ictx:
        stmt = irast.DeleteStmt(span=expr.span)
        # Expand the DELETE sugar into the full DELETE (SELECT ...)
        # form if there are any additional clauses.
        if any([expr.where, expr.orderby, expr.offset, expr.limit]):
            if expr.offset or expr.limit:
                subjql = qlast.SelectQuery(
                    result=qlast.SelectQuery(
                        result=expr.subject,
                        where=expr.where,
                        orderby=expr.orderby,
                        span=expr.span,
                        implicit=True,
                    ),
                    limit=expr.limit,
                    offset=expr.offset,
                    span=expr.span,
                )
            else:
                subjql = qlast.SelectQuery(
                    result=expr.subject,
                    where=expr.where,
                    orderby=expr.orderby,
                    offset=expr.offset,
                    limit=expr.limit,
                    span=expr.span,
                )

            expr = qlast.DeleteQuery(
                aliases=expr.aliases,
                span=expr.span,
                subject=subjql,
            )

        init_stmt(stmt, expr, ctx=ictx, parent_ctx=ctx)

        # DELETE Expr is a delete(SET OF X), so we need a scope fence.
        with ictx.newscope(fenced=True) as scopectx:
            scopectx.implicit_limit = 0
            scopectx.expr_exposed = context.Exposure.UNEXPOSED
            subject = setgen.scoped_set(
                dispatch.compile(expr.subject, ctx=scopectx), ctx=scopectx)

        subj_type = setgen.get_set_type(subject, ctx=ictx)
        if not isinstance(subj_type, s_objtypes.ObjectType):
            raise errors.QueryError(
                'cannot delete non-ObjectType objects',
                span=expr.subject.span
            )

        if subj_type.is_free_object_type(ctx.env.schema):
            raise errors.QueryError(
                'free objects cannot be deleted',
                span=expr.subject.span)

        mat_stype = schemactx.concretify(subj_type, ctx=ctx)

        if _is_forbidden_stdlib_type_for_mod(mat_stype, ctx):
            raise errors.QueryError(
                f'cannot delete standard library type '
                f'{subj_type.get_displayname(ctx.env.schema)}',
                span=expr.subject.span)

        stmt._material_type = typeutils.type_to_typeref(
            ctx.env.schema,
            mat_stype,
            include_children=True,
            include_ancestors=True,
            cache=ctx.env.type_ref_cache,
        )

        with ictx.new() as bodyctx:
            bodyctx.implicit_id_in_shapes = False
            bodyctx.implicit_tid_in_shapes = False
            bodyctx.implicit_tname_in_shapes = False

            stmt.subject = compile_query_subject(
                subject,
                shape=None,
                exprtype=s_types.ExprType.Delete,
                ctx=bodyctx,
                span=expr.span,
            )

        result = setgen.class_set(
            mat_stype, path_id=stmt.subject.path_id, ctx=ctx
        )

        with ictx.new() as resultctx:
            stmt.result = compile_query_subject(
                result,
                view_scls=ctx.view_scls,
                view_name=ctx.toplevel_result_view_name,
                compile_views=ictx.stmt is ictx.toplevel_stmt,
                ctx=resultctx,
                span=expr.span,
            )

        for dtype in schemactx.get_all_concrete(mat_stype, ctx=ctx):
            # Compile policies for every concrete type
            if pol_cond := policies.compile_dml_read_policies(
                dtype, result, mode=qltypes.AccessKind.Delete, ctx=ictx
            ):
                stmt.read_policies[dtype.id] = pol_cond

            schema = ctx.env.schema
            # And find any pointers to delete
            ptrs = []
            for ptr in dtype.get_pointers(schema).objects(schema):
                # If there is a pointer that has a real table and doesn't
                # have a special ON SOURCE DELETE policy, arrange to
                # delete it in the query itself.
                if not ptr.is_pure_computable(schema) and (
                    not ptr.singular(schema)
                    or ptr.has_user_defined_properties(schema)
                ) and (
                    not isinstance(ptr, s_links.Link)
                    or ptr.get_on_source_delete(schema) ==
                    s_links.LinkSourceDeleteAction.Allow
                ):
                    ptrs.append(typegen.ptr_to_ptrref(ptr, ctx=ctx))

            stmt.links_to_delete[dtype.id] = tuple(ptrs)

        result = fini_stmt(stmt, ctx=ictx, parent_ctx=ctx)

    return result


@dispatch.compile.register
def compile_DescribeStmt(
    ql: qlast.DescribeStmt, *, ctx: context.ContextLevel
) -> irast.Set:
    with ctx.subquery() as ictx:
        stmt = irast.SelectStmt()
        init_stmt(stmt, ql, ctx=ictx, parent_ctx=ctx)

        if ql.object is qlast.DescribeGlobal.Schema:
            if ql.language is qltypes.DescribeLanguage.DDL:
                # DESCRIBE SCHEMA AS DDL
                text = s_ddl.ddl_text_from_schema(
                    ctx.env.schema,
                )
            elif ql.language is qltypes.DescribeLanguage.SDL:
                # DESCRIBE SCHEMA AS SDL
                text = s_ddl.sdl_text_from_schema(
                    ctx.env.schema,
                )
            else:
                raise errors.QueryError(
                    f'cannot describe full schema as {ql.language}')

            ct = typegen.type_to_typeref(
                ctx.env.get_schema_type_and_track(
                    s_name.QualName('std', 'str')),
                env=ctx.env,
            )

            stmt.result = setgen.ensure_set(
                irast.StringConstant(value=text, typeref=ct),
                ctx=ictx,
            )

        elif ql.object is qlast.DescribeGlobal.DatabaseConfig:
            if ql.language is qltypes.DescribeLanguage.DDL:
                stmt.result = config_desc.compile_describe_config(
                    qltypes.ConfigScope.DATABASE, ctx=ictx)
            else:
                raise errors.QueryError(
                    f'cannot describe config as {ql.language}')

        elif ql.object is qlast.DescribeGlobal.InstanceConfig:
            if ql.language is qltypes.DescribeLanguage.DDL:
                stmt.result = config_desc.compile_describe_config(
                    qltypes.ConfigScope.INSTANCE, ctx=ictx)
            else:
                raise errors.QueryError(
                    f'cannot describe config as {ql.language}')

        elif ql.object is qlast.DescribeGlobal.Roles:
            if ql.language is qltypes.DescribeLanguage.DDL:
                function_call = dispatch.compile(
                    qlast.FunctionCall(
                        func=('sys', '_describe_roles_as_ddl'),
                    ),
                    ctx=ictx)
                stmt.result = function_call
            else:
                raise errors.QueryError(
                    f'cannot describe roles as {ql.language}')

        else:
            assert isinstance(ql.object, qlast.ObjectRef), ql.object
            modules = []
            items: defaultdict[str, list[s_name.Name]] = defaultdict(list)
            referenced_classes: list[s_obj.ObjectMeta] = []

            objref = ql.object
            itemclass = objref.itemclass

            if itemclass is qltypes.SchemaObjectClass.MODULE:
                mod = s_name.UnqualName(str(s_utils.ast_ref_to_name(objref)))
                if not ctx.env.schema.get_global(
                        s_mod.Module, mod, None):
                    raise errors.InvalidReferenceError(
                        f"module '{mod}' does not exist",
                        span=objref.span,
                    )

                modules.append(mod)
            else:
                itemtype: Optional[type[s_obj.Object]] = None

                name = s_utils.ast_ref_to_name(objref)
                if itemclass is not None:
                    if itemclass is qltypes.SchemaObjectClass.ALIAS:
                        # Look for underlying derived type.
                        itemtype = s_types.Type
                    else:
                        itemtype = (
                            s_obj.ObjectMeta.get_schema_metaclass_for_ql_class(
                                itemclass)
                        )

                last_exc = None
                # Search in the current namespace AND in std. We do
                # this to avoid masking a `std` object/function by one
                # in a default module.
                search_ns = [ictx.modaliases]
                # Only check 'std' separately if the current
                # modaliases don't already include it.
                if ictx.modaliases.get(None, 'std') != 'std':
                    search_ns.append({None: 'std'})

                # Try each candidate namespace in turn.
                for aliases in search_ns:
                    # Use the specific modaliases instead of the
                    # context ones.
                    with ictx.subquery() as newctx:
                        newctx.modaliases = aliases
                        # Get the default module name
                        modname = aliases[None]
                        # Is the current item a function?
                        is_function = (itemclass is
                                       qltypes.SchemaObjectClass.FUNCTION)

                        # We need to check functions if we're looking for them
                        # specifically or if this is a broad search. They are
                        # handled separately because they allow multiple
                        # matches for the same name.
                        if (itemclass is None or is_function):
                            funcs = s_func.lookup_functions(
                                name,
                                tuple(),
                                module_aliases=aliases,
                                schema=newctx.env.schema,
                            )
                            for func in funcs:
                                items[f'function_{modname}'].append(
                                    func.get_name(newctx.env.schema)
                                )

                        # Also find an object matching the name as long as
                        # it's not a function we're looking for specifically.
                        if not is_function:
                            try:
                                if itemclass is not \
                                        qltypes.SchemaObjectClass.ALIAS:
                                    condition = None
                                    label = None
                                else:
                                    condition = (
                                        lambda obj:
                                        obj.get_alias_is_persistent(
                                            ctx.env.schema
                                        )
                                    )
                                    label = 'alias'
                                obj = schemactx.get_schema_object(
                                    objref,
                                    item_type=itemtype,
                                    condition=condition,
                                    label=label,
                                    ctx=newctx,
                                )
                                items[f'other_{modname}'].append(
                                    obj.get_name(newctx.env.schema))
                            except errors.InvalidReferenceError as exc:
                                # Record the exception to be possibly
                                # raised if no matches are found
                                last_exc = exc

                # If we already have some results, suppress the exception,
                # otherwise raise the recorded exception.
                if not items and last_exc:
                    raise last_exc

                if not items:
                    raise errors.InvalidReferenceError(
                        f"{str(itemclass).lower()} '{objref.name}' "
                        f"does not exist",
                        span=objref.span,
                    )

            verbose = ql.options.get_flag('VERBOSE')

            method: Any
            if ql.language is qltypes.DescribeLanguage.DDL:
                method = s_ddl.ddl_text_from_schema
            elif ql.language is qltypes.DescribeLanguage.SDL:
                method = s_ddl.sdl_text_from_schema
            elif ql.language is qltypes.DescribeLanguage.TEXT:
                method = s_ddl.descriptive_text_from_schema
                if not verbose.val:
                    referenced_classes = [s_links.Link, s_props.Property]
            else:
                raise errors.InternalServerError(
                    f'cannot handle describe language {ql.language}'
                )

            # Based on the items found, generate the main text and a
            # potential comment about masked items.
            defmod = ictx.modaliases.get(None, 'std')
            default_items = []
            masked_items = set()
            for objtype in ['function', 'other']:
                defkey = f'{objtype}_{defmod}'
                mskkey = f'{objtype}_std'

                default_items += items.get(defkey, [])
                if defkey in items and mskkey in items:
                    # We have a match in the default module, so the
                    # std matches of the same kind are masked.
                    masked_items.update(items.get(mskkey, []))
                else:
                    default_items += items.get(mskkey, [])

            # Throw out anything in the masked set that's already in
            # the default.
            masked_items.difference_update(default_items)

            text = method(
                ctx.env.schema,
                included_modules=modules,
                included_items=default_items,
                included_ref_classes=referenced_classes,
                include_module_ddl=False,
                include_std_ddl=True,
            )
            if masked_items:
                text += ('\n\n'
                         '# The following builtins are masked by the above:'
                         '\n\n')
                masked = method(
                    ctx.env.schema,
                    included_modules=modules,
                    included_items=masked_items,
                    included_ref_classes=referenced_classes,
                    include_module_ddl=False,
                    include_std_ddl=True,
                )
                masked = textwrap.indent(masked, '# ')
                text += masked

            ct = typegen.type_to_typeref(
                ctx.env.get_schema_type_and_track(
                    s_name.QualName('std', 'str')),
                env=ctx.env,
            )

            stmt.result = setgen.ensure_set(
                irast.StringConstant(value=text, typeref=ct),
                ctx=ictx,
            )

        result = fini_stmt(stmt, ctx=ictx, parent_ctx=ctx)

    return result


@dispatch.compile.register(qlast.Shape)
def compile_Shape(
    shape: qlast.Shape, *, ctx: context.ContextLevel
) -> irast.Set:

    if ctx.no_factoring and not shape.allow_factoring:
        return dispatch.compile(
            qlast.SelectQuery(result=shape, implicit=True),
            ctx=ctx,
        )

    shape_expr = shape.expr or qlutils.FREE_SHAPE_EXPR
    with ctx.new() as subctx:
        subctx.qlstmt = astutils.ensure_ql_query(shape)
        subctx.stmt = stmt = irast.SelectStmt()
        ctx.env.compiled_stmts[subctx.qlstmt] = stmt
        subctx.class_view_overrides = subctx.class_view_overrides.copy()

        with subctx.new() as exposed_ctx:
            exposed_ctx.expr_exposed = context.Exposure.UNEXPOSED
            expr = dispatch.compile(shape_expr, ctx=exposed_ctx)

        expr_stype = setgen.get_set_type(expr, ctx=ctx)
        if not isinstance(expr_stype, s_objtypes.ObjectType):
            raise errors.QueryError(
                f'shapes cannot be applied to '
                f'{expr_stype.get_verbosename(ctx.env.schema)}',
                span=shape.span,
            )

        stmt.result = compile_query_subject(
            expr,
            shape=shape.elements,
            compile_views=False,
            ctx=subctx,
            span=expr.span)

        ir_result = setgen.ensure_set(stmt, ctx=subctx)

    return ir_result


def init_stmt(
    irstmt: irast.Stmt,
    qlstmt: qlast.Statement,
    *,
    ctx: context.ContextLevel,
    parent_ctx: context.ContextLevel,
) -> None:

    ctx.env.compiled_stmts[qlstmt] = irstmt

    irstmt.span = qlstmt.span

    if isinstance(irstmt, irast.MutatingStmt):
        # This is some kind of mutation, so we need to check if it is
        # allowed.
        if ctx.env.options.in_ddl_context_name is not None:
            raise errors.SchemaDefinitionError(
                f'mutations are invalid in '
                f'{ctx.env.options.in_ddl_context_name}',
                span=qlstmt.span,
            )
        elif (
            (dv := ctx.defining_view) is not None
            and dv.get_expr_type(ctx.env.schema) is s_types.ExprType.Select
            and not (
                # We allow DML in trivial *top-level* free objects
                ctx.partial_path_prefix
                and irutils.is_trivial_free_object(
                    irutils.unwrap_set(ctx.partial_path_prefix))
                # Find the enclosing context at the point the free object
                # was defined.
                and (outer_ctx := next((
                    x for x in reversed(ctx._stack.stack)
                    if isinstance(x, context.ContextLevel)
                    and x.partial_path_prefix != ctx.partial_path_prefix
                ), None))
                and outer_ctx.expr_exposed
            )
        ):
            # This is some shape in a regular query. DML is not
            # allowed in the computable, but it may be possible to
            # refactor it.
            raise errors.QueryError(
                "mutations are invalid in a shape's computed expression",
                hint=(
                    'To resolve this try to factor out the mutation '
                    'expression into the top-level WITH block.'
                ),
                span=qlstmt.span,
            )

    ctx.stmt = irstmt
    ctx.qlstmt = qlstmt
    if ctx.toplevel_stmt is None:
        parent_ctx.toplevel_stmt = ctx.toplevel_stmt = irstmt

    ctx.path_scope = parent_ctx.path_scope.attach_fence()

    pending_own_ns = parent_ctx.pending_stmt_own_path_id_namespace
    if pending_own_ns:
        ctx.path_scope.add_namespaces(pending_own_ns)

    pending_full_ns = parent_ctx.pending_stmt_full_path_id_namespace
    if pending_full_ns:
        ctx.path_id_namespace |= pending_full_ns

    irstmt.parent_stmt = parent_ctx.stmt

    irstmt.bindings = process_with_block(
        qlstmt, ctx=ctx, parent_ctx=parent_ctx)

    if isinstance(irstmt, irast.MutatingStmt):
        ctx.path_scope.factoring_fence = True
        ctx.path_scope.factoring_allowlist.update(ctx.iterator_path_ids)


def fini_stmt(
    irstmt: irast.Stmt | irast.Set,
    *,
    ctx: context.ContextLevel,
    parent_ctx: context.ContextLevel,
) -> irast.Set:

    view_name = parent_ctx.toplevel_result_view_name
    t = setgen.get_expr_type(irstmt, ctx=ctx)

    view: Optional[s_types.Type]
    path_id: Optional[irast.PathId]

    if isinstance(irstmt, irast.MutatingStmt):
        ctx.env.dml_stmts.append(irstmt)
        irstmt.rewrites = ctx.env.dml_rewrites.pop(irstmt.subject, None)

    if (isinstance(t, s_pseudo.PseudoType)
            and t.is_any(ctx.env.schema)):
        # Need to produce something valid. Should get caught as an
        # error later.
        view = None
        path_id = None

    elif t.get_name(ctx.env.schema) == view_name:
        # The view statement did contain a view declaration and
        # generated a view class with the requested name.
        view = t
        path_id = pathctx.get_path_id(view, ctx=parent_ctx)
    elif view_name is not None:
        # The view statement did _not_ contain a view declaration,
        # but we still want the correct path_id.
        view_obj = ctx.env.schema.get(view_name, None)
        if view_obj is not None:
            assert isinstance(view_obj, s_types.Type)
            view = view_obj
        else:
            view = schemactx.derive_view(
                t,
                derived_name=view_name,
                preserve_shape=True,
                attrs={'span': irstmt.span},
                ctx=parent_ctx
            )
        path_id = pathctx.get_path_id(view, ctx=parent_ctx)
    else:
        view = None
        path_id = None

    type_override = view
    result = setgen.scoped_set(
        irstmt, type_override=type_override, path_id=path_id, ctx=ctx)
    if irstmt.span and not result.span:
        result = setgen.new_set_from_set(
            result, span=irstmt.span, ctx=ctx)

    if view is not None:
        parent_ctx.view_sets[view] = result

    return result


def process_with_block(
    edgeql_tree: qlast.Statement,
    *,
    ctx: context.ContextLevel,
    parent_ctx: context.ContextLevel,
) -> list[tuple[irast.Set, qltypes.Volatility]]:
    if edgeql_tree.aliases is None:
        return []

    had_materialized = False
    results = []
    for with_entry in edgeql_tree.aliases:
        if isinstance(with_entry, qlast.ModuleAliasDecl):
            ctx.modaliases[with_entry.alias] = with_entry.module

        elif isinstance(with_entry, qlast.AliasedExpr):
            with ctx.new() as scopectx:
                if scopectx.expr_exposed:
                    scopectx.expr_exposed = context.Exposure.BINDING
                binding = stmtctx.declare_view(
                    with_entry.expr,
                    s_name.UnqualName(with_entry.alias),
                    binding_kind=irast.BindingKind.With,
                    ctx=scopectx,
                )
                volatility = inference.infer_volatility(
                    binding, ctx.env, exclude_dml=True
                )
                results.append((binding, volatility))

                if reason := setgen.should_materialize(binding, ctx=ctx):
                    had_materialized = True
                    typ = setgen.get_set_type(binding, ctx=ctx)
                    ctx.env.materialized_sets[typ] = edgeql_tree, reason
                    setgen.maybe_materialize(typ, binding, ctx=ctx)

        else:
            raise RuntimeError(
                f'unexpected expression in WITH block: {with_entry}')

    if had_materialized:
        # If we had to materialize, put the body of the statement into
        # its own fence, to avoid potential spurious factoring when we
        # compile view sets for materialized sets.
        # (We could just *always* do this, but don't, to avoid cluttering
        # up the scope tree more.)
        ctx.path_scope = ctx.path_scope.attach_fence()

    return results


def compile_result_clause(
    result: qlast.Expr,
    *,
    view_scls: Optional[s_types.Type] = None,
    view_rptr: Optional[context.ViewRPtr] = None,
    view_name: Optional[s_name.QualName] = None,
    exprtype: s_types.ExprType = s_types.ExprType.Select,
    result_alias: Optional[str] = None,
    forward_rptr: bool = False,
    ctx: context.ContextLevel,
) -> irast.Set:
    with ctx.new() as sctx:
        if forward_rptr:
            sctx.view_rptr = view_rptr
            # sctx.view_scls = view_scls

        if result_alias:
            # `SELECT foo := expr` is equivalent to
            # `WITH foo := expr SELECT foo`
            rexpr = astutils.ensure_ql_select(result)

            stmtctx.declare_view(
                rexpr,
                alias=s_name.UnqualName(result_alias),
                binding_kind=irast.BindingKind.Select,
                ctx=sctx,
            )

            result = qlast.Path(
                steps=[qlast.ObjectRef(name=result_alias)],
                allow_factoring=True,
            )

        result_expr: qlast.Expr
        shape: Optional[Sequence[qlast.ShapeElement]]

        if isinstance(result, qlast.Shape):
            result_expr = result.expr or qlutils.FREE_SHAPE_EXPR
            shape = result.elements
        else:
            result_expr = result
            shape = None

        if astutils.is_ql_empty_set(result_expr):
            expr = setgen.new_empty_set(
                stype=sctx.empty_result_type_hint,
                alias=ctx.aliases.get('e'),
                ctx=sctx,
                span=result_expr.span,
            )
        elif astutils.is_ql_empty_array(result_expr):
            type_hint: Optional[s_types.Type] = None
            if (
                sctx.empty_result_type_hint is not None
                and sctx.empty_result_type_hint.is_array()
            ):
                type_hint = sctx.empty_result_type_hint

            expr = setgen.new_array_set(
                [],
                stype=type_hint,
                ctx=sctx,
                span=result_expr.span,
            )
        else:
            with sctx.new() as ectx:
                if shape is not None:
                    ectx.expr_exposed = context.Exposure.UNEXPOSED
                expr = dispatch.compile(result_expr, ctx=ectx)

        ctx.partial_path_prefix = expr

        ir_result = compile_query_subject(
            expr, shape=shape, view_rptr=view_rptr, view_name=view_name,
            forward_rptr=forward_rptr,
            result_alias=result_alias,
            view_scls=view_scls,
            allow_select_shape_inject=False,
            exprtype=exprtype,
            compile_views=ctx.stmt is ctx.toplevel_stmt,
            ctx=sctx,
            span=result.span
        )

        ctx.partial_path_prefix = ir_result

    return ir_result


def compile_query_subject(
    set: irast.Set,
    *,
    shape: Optional[list[qlast.ShapeElement]] = None,
    view_rptr: Optional[context.ViewRPtr] = None,
    view_name: Optional[s_name.QualName] = None,
    result_alias: Optional[str] = None,
    view_scls: Optional[s_types.Type] = None,
    compile_views: bool = True,
    exprtype: s_types.ExprType = s_types.ExprType.Select,
    allow_select_shape_inject: bool = True,
    forward_rptr: bool = False,
    span: Optional[qlast.Span],
    ctx: context.ContextLevel,
) -> irast.Set:

    set_stype = setgen.get_set_type(set, ctx=ctx)

    set_expr = set.expr
    while isinstance(set_expr, irast.TypeIntersectionPointer):
        set_expr = set_expr.source.expr

    is_ptr_alias = (
        view_rptr is not None
        and view_rptr.ptrcls is None
        and view_rptr.ptrcls_name is not None
        and isinstance(set_expr, irast.Pointer)
        and not isinstance(set_expr.source.expr, irast.Pointer)
        and (
            view_rptr.source.get_bases(ctx.env.schema).first(ctx.env.schema).id
            == set_expr.source.typeref.id
        )
        and (
            view_rptr.ptrcls_is_linkprop
            == (set_expr.ptrref.source_ptr is not None)
        )
    )

    if is_ptr_alias:
        assert view_rptr is not None
        set_rptr = cast(irast.Pointer, set_expr)
        # We are inside an expression that defines a link alias in
        # the parent shape, i.e. Spam { alias := Spam.bar }, so
        # `Spam.alias` should be a subclass of `Spam.bar` inheriting
        # its properties.
        #
        # We also try to detect reverse aliases like `.= context.Exposure.BINDING
                and allow_select_shape_inject

                and not forward_rptr
                and viewgen.has_implicit_type_computables(
                    set_stype,
                    is_mutation=exprtype.is_mutation(),
                    ctx=ctx,
                )
                and not set_stype.is_view(ctx.env.schema)
            )
            or exprtype.is_mutation()
            or (
                exprtype == s_types.ExprType.Group
                and not set_stype.is_view(ctx.env.schema)
            )
        )
        and set_stype.is_object_type()
        and shape is None
    ):
        # Force the subject to be compiled as a view in these cases:
        # a) a __tid__ insertion is anticipated (the actual
        #    decision about this is taken by the
        #    compile_view_shapes() flow);
        #    we also skip doing this when forward_rptr is true, because
        #    generating an extra type in those cases can cause issues,
        #    and we can just do the insertion on whatever the inner thing is
        #
        #    Note that we do this when exposed or when potentially exposed
        #    because we are in a binding. This is because types that
        #    appear in bindings might get put into the output
        #    and need a __tid__ injection without having a chance to have
        #    a shape put on them.
        # b) this is a mutation without an explicit shape,
        #    such as a DELETE, because mutation subjects are
        #    always expected to be derived types.
        shape = []

    if shape is not None and view_scls is None:
        if (view_name is None and
                isinstance(result_alias, s_name.QualName)):
            view_name = result_alias

        if not isinstance(set_stype, s_objtypes.ObjectType):
            raise errors.QueryError(
                f'shapes cannot be applied to '
                f'{set_stype.get_verbosename(ctx.env.schema)}',
                span=span,
            )

        view_scls, set = viewgen.process_view(
            set,
            stype=set_stype,
            elements=shape,
            view_rptr=view_rptr,
            view_name=view_name,
            exprtype=exprtype,
            ctx=ctx,
            span=span,
        )

    if view_scls is not None:
        set = setgen.ensure_set(set, type_override=view_scls, ctx=ctx)
        set_stype = view_scls

    if compile_views:
        viewgen.late_compile_view_shapes(set, ctx=ctx)

    if (shape is not None or view_scls is not None) and len(set.path_id) == 1:
        ctx.class_view_overrides[set.path_id.target.id] = set_stype

    if shape:
        # make sure that an applied shape expands the span of the set
        set.span = edb_span.merge_spans(
            itertools.chain(
                (s.span for s in [set] if s.span),
                (el.span for el in shape if el.span)
            )
        )

    return set


def maybe_add_view(ir: irast.Set, *, ctx: context.ContextLevel) -> irast.Set:
    """Possibly wrap ir in a new view, if needed for tid/tname injection.

    This should be called by every ast leaf compilation that can originate
    an object type.
    """

    # We call compile_query_subject in order to create a new view for
    # injecting properties if needed. This will only happen if
    # expr_exposed, so stmt code paths that don't want a new view
    # created (because a shape is already specified or because they
    # want to create their own new view in their compile_query_subject
    # call) should make sure expr_exposed is false.
    #
    # The checks here are microoptimizations.
    if (
        ctx.expr_exposed >= context.Exposure.BINDING
        and ir.path_id.is_objtype_path()
    ):
        return compile_query_subject(
            ir, allow_select_shape_inject=True, compile_views=False, ctx=ctx,
            span=ir.span)
    else:
        return ir


def _is_forbidden_stdlib_type_for_mod(
    t: s_types.Type, ctx: context.ContextLevel
) -> bool:
    o = ctx.env.options
    if o.bootstrap_mode or o.schema_reflection_mode:
        return False

    schema = ctx.env.schema

    assert isinstance(t, s_objtypes.ObjectType)
    assert not t.is_view(schema)

    if intersection := t.get_intersection_of(schema):
        return all((_is_forbidden_stdlib_type_for_mod(it, ctx)
                    for it in intersection.objects(schema)))
    elif union := t.get_union_of(schema):
        return any((_is_forbidden_stdlib_type_for_mod(ut, ctx)
                    for ut in union.objects(schema)))

    name = t.get_name(schema)
    mod_name = name.get_module_name()

    if (
        mod_name == s_name.UnqualName('cfg')
        and o.in_server_config_op
    ):
        # Config ops include various internally generated statements for cfg::
        return False
    if name == s_name.QualName('std', 'Object'):
        # Allow people to mess with the baseclass of user-defined objects to
        # their hearts' content
        return False
    if mod_name == s_name.UnqualName('std::net::http'):
        # Allow users to insert net module types
        return False
    return mod_name in s_schema.STD_MODULES
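

# Illustrative sketch only, not used by the compiler: how the check above
# combines compound types. Types are modelled as nested tuples and
# `forbidden_modules` stands in for s_schema.STD_MODULES; the helper name
# and this data layout are assumptions made for the example, and the
# cfg / std::net::http exceptions are left out for brevity.
def _example_is_forbidden(t, forbidden_modules):  # hypothetical helper
    kind, payload = t
    if kind == 'intersection':
        # An intersection is forbidden only if every component is.
        return all(
            _example_is_forbidden(c, forbidden_modules) for c in payload
        )
    if kind == 'union':
        # A union is forbidden as soon as one component is.
        return any(
            _example_is_forbidden(c, forbidden_modules) for c in payload
        )
    # Otherwise `payload` is a (module, name) pair for a plain type.
    module, name = payload
    if (module, name) == ('std', 'Object'):
        return False
    return module in forbidden_modules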


================================================
FILE: edb/edgeql/compiler/stmtctx.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""EdgeQL compiler statement-level context management."""


from __future__ import annotations

from typing import (
    Any,
    Optional,
    Mapping,
    Sequence,
)

import copy
import uuid

from edb import errors

from edb.ir import ast as irast
from edb.ir import utils as irutils
from edb.ir import typeutils as irtyputils

from edb.schema import constraints as s_constr
from edb.schema import modules as s_mod
from edb.schema import name as s_name
from edb.schema import objects as s_obj
from edb.schema import objtypes as s_objtypes
from edb.schema import pointers as s_pointers
from edb.schema import rewrites as s_rewrites
from edb.schema import schema as s_schema
from edb.schema import sources as s_sources
from edb.schema import types as s_types
from edb.schema import expr as s_expr

from edb.edgeql import ast as qlast

from edb.common.ast import visitor as ast_visitor
from edb.common import ordered
from edb.common.typeutils import not_none

from . import astutils
from . import context
from . import dispatch
from . import eta_expand
from . import group
from . import inference
from . import options as coptions
from . import pathctx
from . import setgen
from . import viewgen
from . import schemactx
from . import triggers
from . import tuple_args
from . import typegen


def init_context(
    *,
    schema: s_schema.Schema,
    options: coptions.CompilerOptions,
    inlining_context: Optional[context.ContextLevel] = None,
) -> context.ContextLevel:

    if not schema.get_global(s_mod.Module, '__derived__', None):
        schema, _ = s_mod.Module.create_in_schema(
            schema,
            name=s_name.UnqualName('__derived__'),
        )

    if inlining_context:
        env = copy.copy(inlining_context.env)
        env.options = options
        env.path_scope = inlining_context.path_scope
        env.alias_result_view_name = options.result_view_name
        env.query_parameters = {}
        env.server_param_conversions = {}
        env.server_param_conversion_calls = []
        env.script_params = {}

        ctx = context.ContextLevel(
            inlining_context, mode=context.ContextSwitchMode.DETACHED
        )
        ctx.env = env

    else:
        env = context.Environment(
            schema=schema,
            options=options,
            alias_result_view_name=options.result_view_name,
        )
        ctx = context.ContextLevel(None, context.ContextSwitchMode.NEW, env=env)
    _ = context.CompilerContext(initial=ctx)

    if options.singletons:
        # The caller wants us to treat these type and pointer
        # references as singletons for the purposes of the overall
        # expression cardinality inference, so we set up the scope
        # tree in the necessary fashion.
        had_optional = False
        for singleton_ent in options.singletons:
            singleton, optional = (
                singleton_ent if isinstance(singleton_ent, tuple)
                else (singleton_ent, False)
            )
            had_optional |= optional
            path_id = compile_anchor('__', singleton, ctx=ctx).path_id
            ctx.env.path_scope.attach_path(
                path_id, optional=optional, span=None, ctx=ctx
            )
            if not optional:
                ctx.env.singletons.append(path_id)
            ctx.iterator_path_ids |= {path_id}

        # If we installed any optional singletons, run the rest of the
        # compilation under a fence to protect them.
        if had_optional:
            ctx.path_scope = ctx.path_scope.attach_fence()

    for orig, remapped in options.type_remaps.items():
        rset = compile_anchor('__', remapped, ctx=ctx)
        ctx.view_sets[orig] = rset
        ctx.env.path_scope_map[rset] = context.ScopeInfo(
            path_scope=ctx.path_scope, binding_kind=None
        )

    ctx.modaliases.update(options.modaliases)

    if options.anchors:
        with ctx.newscope(fenced=True) as subctx:
            populate_anchors(options.anchors, ctx=subctx)

    if options.path_prefix_anchor is not None:
        path_prefix = options.anchors[options.path_prefix_anchor]
        ctx.partial_path_prefix = compile_anchor(
            options.path_prefix_anchor, path_prefix, ctx=ctx)
        ctx.partial_path_prefix.anchor = options.path_prefix_anchor
        ctx.partial_path_prefix.show_as_anchor = options.path_prefix_anchor

    if options.detached:
        ctx.path_id_namespace = frozenset({ctx.aliases.get('ns')})

    if options.schema_object_context is s_rewrites.Rewrite:
        assert ctx.partial_path_prefix
        typ = setgen.get_set_type(ctx.partial_path_prefix, ctx=ctx)
        assert isinstance(typ, s_objtypes.ObjectType)
        ctx.active_rewrites |= {typ, *typ.descendants(ctx.env.schema)}

    ctx.derived_target_module = options.derived_target_module
    ctx.toplevel_result_view_name = options.result_view_name
    ctx.implicit_id_in_shapes = options.implicit_id_in_shapes
    ctx.implicit_tid_in_shapes = options.implicit_tid_in_shapes
    ctx.implicit_tname_in_shapes = options.implicit_tname_in_shapes
    ctx.implicit_limit = options.implicit_limit
    ctx.expr_exposed = context.Exposure.EXPOSED

    ctx.no_factoring = True

    return ctx


def fini_expression(
    ir: irast.Set, *, ctx: context.ContextLevel
) -> irast.Statement | irast.ConfigCommand:

    ctx.path_scope = ctx.env.path_scope

    ir = eta_expand.eta_expand_ir(ir, toplevel=True, ctx=ctx)

    if (
        isinstance(ir, irast.Set)
        and pathctx.get_set_scope(ir, ctx=ctx) is None
    ):
        ir = setgen.scoped_set(ir, ctx=ctx)

    # Compile any triggers that were triggered by the query
    ir_triggers = triggers.compile_triggers(ctx=ctx)

    # Collect all of the expressions stored in various side sets
    # that can make it into the output, so that we can make sure
    # to catch them all in our fixups and analyses.
    # IMPORTANT: Any new expressions that are sent to the backend
    # but don't appear in `ir` must be added here.
    extra_exprs: list[irast.Set] = []
    extra_exprs += [
        rw for rw in ctx.env.type_rewrites.values()
        if isinstance(rw, irast.Set)
    ]
    extra_exprs += [
        p.sub_params.decoder_ir for p in ctx.env.query_parameters.values()
        if p.sub_params and p.sub_params.decoder_ir
    ]
    extra_exprs += [
        conversion.ir_param.sub_params.decoder_ir
        for conversions in ctx.env.server_param_conversions.values()
        for conversion in conversions.values()
        if (
            conversion.ir_param.sub_params
            and conversion.ir_param.sub_params.decoder_ir
        )
    ]
    extra_exprs += [trigger.expr for stage in ir_triggers for trigger in stage]

    all_exprs = [ir] + extra_exprs

    # exprs_to_clear collects sets where we should never need to use
    # their expr in pgsql compilation, so we strip it out to make this
    # more evident in debug output. We have to do the clearing at the
    # end, because multiplicity/cardinality inference needs to be able
    # to look through those pointers.
    exprs_to_clear = _fixup_materialized_sets(all_exprs, ctx=ctx)
    for expr in all_exprs:
        exprs_to_clear.extend(_find_visible_binding_refs(expr, ctx=ctx))

    # The inference context object will be shared between
    # cardinality and multiplicity inferrers.
    inf_ctx = inference.make_ctx(env=ctx.env)
    cardinality = inference.infer_cardinality(
        ir, scope_tree=ctx.path_scope, ctx=inf_ctx
    )
    multiplicity = inference.infer_multiplicity(
        ir, scope_tree=ctx.path_scope, ctx=inf_ctx
    )

    for extra in extra_exprs:
        inference.infer_cardinality(
            extra, scope_tree=ctx.path_scope, ctx=inf_ctx)
        inference.infer_multiplicity(
            extra, scope_tree=ctx.path_scope, ctx=inf_ctx)

    # Fix up weak namespaces
    _rewrite_weak_namespaces(all_exprs, ctx)

    _collapse_factoring_protected(all_exprs, ctx)

    ctx.path_scope.validate_unique_ids()

    # Collect query parameters
    params = collect_params(ctx)
    server_param_conversions, server_param_conversion_params = (
        collect_server_param_conversions(ctx)
    )

    # ConfigSet and ConfigReset don't like being part of a Set, so bail early
    if isinstance(ir.expr, (irast.ConfigSet, irast.ConfigReset)):
        ir.expr.scope_tree = ctx.path_scope
        ir.expr.globals = list(ctx.env.query_globals.values())
        ir.expr.params = params
        ir.expr.schema = ctx.env.schema
        ir.expr.type_rewrites = _get_type_rewrites(ctx)

        if ctx.env.server_param_conversion_calls:
            func_name, func_span = ctx.env.server_param_conversion_calls[0]
            raise errors.QueryError(
                f"Function '{func_name}' is not allowed in a config statement.",
                span=func_span,
            )

        return ir.expr

    volatility = inference.infer_volatility(ir, env=ctx.env)
    expr_type = setgen.get_set_type(ir, ctx=ctx)

    in_polymorphic_func = (
        ctx.env.options.func_params is not None and
        ctx.env.options.func_params.has_polymorphic(ctx.env.schema)
    )
    if (
        not in_polymorphic_func
        and not ctx.env.options.allow_generic_type_output
    ):
        anytype = expr_type.find_generic(ctx.env.schema)
        if anytype is not None:
            raise errors.QueryError(
                'expression returns value of indeterminate type',
                hint='Consider using an explicit type cast.',
                span=ctx.env.type_origins.get(anytype))

    # Clear out exprs that we decided to omit from the IR
    for ir_set in exprs_to_clear:
        new = (
            irast.MaterializedExpr(typeref=ir_set.typeref)
            if ir_set.is_materialized_ref
            else irast.VisibleBindingExpr(typeref=ir_set.typeref)
        )
        if isinstance(ir_set.expr, irast.Pointer):
            ir_set.expr.expr = new
        else:
            ir_set.expr = new

    # Analyze GROUP statements to find aggregates that can be optimized
    group.infer_group_aggregates(all_exprs, ctx=ctx)

    # If we are producing a schema view, clean up the result types
    if ctx.env.options.schema_view_mode:
        _fixup_schema_view(ctx=ctx)

    result = irast.Statement(
        expr=ir,
        params=params,
        globals=list(ctx.env.query_globals.values()),
        required_permissions=set(ctx.env.required_permissions),
        server_param_conversions=server_param_conversions,
        server_param_conversion_params=server_param_conversion_params,
        views=ctx.view_nodes,
        scope_tree=ctx.env.path_scope,
        cardinality=cardinality,
        volatility=volatility,
        multiplicity=multiplicity.own,
        stype=expr_type,
        view_shapes={
            src: [ptr for ptr, op in ptrs if op != qlast.ShapeOp.MATERIALIZE]
            for src, ptrs in ctx.env.view_shapes.items()
            if isinstance(src, s_obj.Object)
        },
        view_shapes_metadata=ctx.env.view_shapes_metadata,
        schema=ctx.env.schema,
        schema_refs=frozenset(
            {
                r
                for r in ctx.env.schema_refs
                # filter out newly derived objects
                if ctx.env.orig_schema.has_object(r.id)
            }
        ),
        schema_ref_exprs=ctx.env.schema_ref_exprs,
        type_rewrites=_get_type_rewrites(ctx),
        dml_exprs=ctx.env.dml_exprs,
        singletons=ctx.env.singletons,
        triggers=ir_triggers,
        warnings=tuple(ctx.env.warnings),
        unsafe_isolation_dangers=tuple(ctx.env.unsafe_isolation_dangers),
    )
    return result


def _get_type_rewrites(ctx: context.ContextLevel) -> dict[
    tuple[uuid.UUID, bool], irast.Set
]:
    return {
        (typ.id, not skip_subtypes): s
        for (typ, skip_subtypes), s in ctx.env.type_rewrites.items()
        if isinstance(s, irast.Set)
    }


def collect_params(ctx: context.ContextLevel) -> list[irast.Param]:
    lparams = [
        p for p in ctx.env.query_parameters.values() if not p.is_sub_param
    ]
    if ctx.env.script_params:
        script_ordering = {k: i for i, k in enumerate(ctx.env.script_params)}
        lparams.sort(key=lambda x: script_ordering[x.name])

    params = []
    # Now flatten it out, including all sub_params, making sure subparams
    # appear in the right order.
    for p in lparams:
        params.append(p)
        if p.sub_params:
            params.extend(p.sub_params.params)
    return params
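

# Illustrative sketch only, not used by the compiler: the flattening order
# produced by collect_params(), restated over plain dicts. The helper name
# and the dict layout ({'name': ..., 'subs': [...]}) are assumptions made
# for this example; real parameters are irast.Param objects.
def _example_flatten_params(top_level, script_order=()):
    ordered_params = list(top_level)
    if script_order:
        # Mirror the script_ordering sort: parameters follow the order in
        # which the script declared them.
        rank = {name: i for i, name in enumerate(script_order)}
        ordered_params.sort(key=lambda p: rank[p['name']])

    flat = []
    for p in ordered_params:
        # Each parameter is immediately followed by its own sub-parameters.
        flat.append(p['name'])
        flat.extend(p.get('subs', ()))
    return flat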


def collect_server_param_conversions(
    ctx: context.ContextLevel
) -> tuple[
    list[irast.ServerParamConversion],
    list[irast.Param],
]:
    """Gather converted parameters for use in the ir Statement.

    Returns the ServerParamConversion entries that will eventually be sent
    to the server, along with the irast.Param objects that should be used
    to generate the pgast.
    """
    lparams = [
        (
            param_name,
            conversion_name,
            conversion,
        )
        for param_name, conversions in ctx.env.server_param_conversions.items()
        for conversion_name, conversion in conversions.items()
        if not conversion.ir_param.is_sub_param
    ]
    script_ordering = {k: i for i, k in enumerate(ctx.env.script_params)}

    # Add ordering for param conversions which don't match query params.
    # This can happen for constants.
    extra_ordering: dict[str, int] = {}
    for param_name in sorted(ctx.env.server_param_conversions.keys()):
        if param_name not in script_ordering:
            extra_ordering[param_name] = (
                len(script_ordering) + len(extra_ordering)
            )
    script_ordering.update(extra_ordering)

    lparams.sort(key=lambda x: (script_ordering[x[0]], x[1]))

    conversions = []
    params = []
    # Now flatten it out, including all sub_params, making sure subparams
    # appear in the right order.
    for param_name, conversion_name, conversion in lparams:
        conversions.append(irast.ServerParamConversion(
            param_name=param_name,
            conversion_name=conversion_name,
            additional_info=conversion.additional_info,
            script_param_index=conversion.script_param_index,
            constant_value=conversion.constant_value,
        ))
        params.append(conversion.ir_param)
        if conversion.ir_param.sub_params:
            params.extend(conversion.ir_param.sub_params.params)
    return conversions, params
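

# Illustrative sketch only, not used by the compiler: the sort key built in
# collect_server_param_conversions(), restated over plain strings. The
# helper name and its arguments are assumptions made for this example.
def _example_conversion_order(conversion_names, script_params):
    # conversion_names: iterable of (param_name, conversion_name) pairs.
    pairs = list(conversion_names)
    ordering = {name: i for i, name in enumerate(script_params)}
    # Names with no matching script parameter (e.g. constants) are ranked
    # after all script parameters, in sorted-name order.
    for name in sorted({param for param, _ in pairs}):
        if name not in ordering:
            ordering[name] = len(ordering)
    # Ties on the parameter are broken by the conversion name.
    return sorted(pairs, key=lambda pair: (ordering[pair[0]], pair[1]))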


def _fixup_materialized_sets(
    irs: Sequence[irast.Base], *, ctx: context.ContextLevel
) -> list[irast.Set]:
    # Make sure that all materialized sets have their views compiled
    skips = {'materialized_sets'}
    children = []
    for ir in irs:
        children += ast_visitor.find_children(
            ir, irast.Stmt, extra_skips=skips)

    to_clear = []
    for stmt in ordered.OrderedSet(children):
        if not stmt.materialized_sets:
            continue
        for key in list(stmt.materialized_sets):
            mat_set = stmt.materialized_sets[key]
            assert not mat_set.finalized

            if len(mat_set.uses) <= 1:
                del stmt.materialized_sets[key]
                continue

            ir_set = mat_set.materialized
            assert ir_set.path_scope_id is not None
            new_scope = ctx.env.scope_tree_nodes[ir_set.path_scope_id]
            parent = not_none(new_scope.parent)

            good_reason = False
            for x in mat_set.reason:
                if isinstance(x, irast.MaterializeVolatile):
                    good_reason = True
                elif isinstance(x, irast.MaterializeVisible):
                    reason_scope = ctx.env.scope_tree_nodes[x.path_scope_id]
                    reason_parent = not_none(reason_scope.parent)

                    # If any of the bindings that the set uses are
                    # *visible* at the definition point and *not
                    # visible* from at least one use point, we need to
                    # materialize, to make sure that the use site sees
                    # the same value for the binding as the definition
                    # point. If it's not visible, then it's just being
                    # used internally and we don't need any special
                    # work.
                    use_scopes = [
                        ctx.env.scope_tree_nodes.get(x.path_scope_id)
                        if x.path_scope_id is not None
                        else None
                        for x in mat_set.use_sets
                    ]
                    for b, _ in x.sets:
                        if (
                            reason_parent.is_visible(b, allow_group=True)
                        ) and not all(
                            use_scope and use_scope.parent
                            and use_scope.parent.is_visible(
                                b, allow_group=True)
                            for use_scope in use_scopes
                        ):
                            good_reason = True
                            break

            if not good_reason:
                del stmt.materialized_sets[key]
                continue

            # Compile the view shapes in the set
            with ctx.new() as subctx:
                subctx.implicit_tid_in_shapes = False
                subctx.implicit_tname_in_shapes = False
                subctx.path_scope = new_scope
                subctx.path_scope = parent.attach_fence()
                viewgen.late_compile_view_shapes(ir_set, ctx=subctx)

            for use_set in mat_set.use_sets:
                if use_set != mat_set.materialized:
                    use_set.is_materialized_ref = True
                    # XXX: Deleting it on linkprops breaks a bunch of
                    # linkprop related DML...
                    if not use_set.path_id.is_linkprop_path():
                        to_clear.append(use_set)

            assert (
                not any(use.src_path() for use in mat_set.uses)
                or isinstance(mat_set.materialized.expr, irast.Pointer)
            ), f"materialized ptr {mat_set.uses} missing pointer"
            mat_set.finalized = True

    return to_clear


def _find_visible_binding_refs(
    ir: irast.Base, *, ctx: context.ContextLevel
) -> list[irast.Set]:
    children = ast_visitor.find_children(
        ir, irast.Set, lambda n: n.is_visible_binding_ref)
    return children


def _try_namespace_fix(
    scope: irast.ScopeTreeNode,
    path_id: irast.PathId,
) -> irast.PathId:
    for prefix in path_id.iter_prefixes():
        replacement = scope.find_visible(prefix, allow_group=True)
        if (
            replacement and replacement.path_id
            and prefix != replacement.path_id
        ):
            new = irtyputils.replace_pathid_prefix(
                path_id, prefix, replacement.path_id)

            return new

    return path_id


def _rewrite_weak_namespaces(
    irs: Sequence[irast.Base], ctx: context.ContextLevel
) -> None:
    """Rewrite weak namespaces in path ids to be usable by the backend.

    Weak namespaces in path ids in the frontend are "relative", and
    their interpretation depends on the current scope tree node and
    the namespaces on the parent nodes. The IR->pgsql compiler does
    not do this sort of interpretation, and needs path IDs that are
    "absolute".

    To accomplish this, we go through all the path ids and rewrite
    them: using the scope tree, we try to find the binding of the path
    ID (using a prefix if necessary) and drop all namespace parts that
    don't appear in the binding.
    """

    tree = ctx.path_scope

    for node in tree.strict_descendants:
        if node.path_id:
            node.path_id = _try_namespace_fix(node, node.path_id)

    scopes = irutils.find_path_scopes(irs)

    for ir_set in ctx.env.set_types:
        path_scope_id: Optional[int] = scopes.get(ir_set)
        if path_scope_id is not None:
            # Some entries in set_types are from compiling views
            # in temporary scopes, so we need to just skip those.
            if scope := ctx.env.scope_tree_nodes.get(path_scope_id):
                ir_set.path_id = _try_namespace_fix(scope, ir_set.path_id)


def _get_all_pathids(irs: Sequence[irast.Base]) -> set[
    tuple[irast.PathId, irast.Set | None]
]:
    all_ids: set[tuple[irast.PathId, irast.Set | None]] = set()
    for ir in irs:
        for ir_set in ast_visitor.find_children(ir, irast.Set):
            all_ids.add((ir_set.path_id, ir_set))
        for arg in ast_visitor.find_children(ir, irast.CallArg):
            if arg.expr_type_path_id:
                all_ids.add((arg.expr_type_path_id, None))

    return all_ids


def _collapse_factoring_protected(
    irs: Sequence[irast.Base], ctx: context.ContextLevel
) -> None:
    """Try to remove the Selects inserted for simple_scoping.

    In simple_scoping mode, we protect certain paths by wrapping them
    in selects so that they don't participate in path factoring.

    This generates more verbose SQL in all cases and inhibits
    important optimizations in others -- in particular, our efforts to
    make ORDER BY clauses simple enough for postgres to optimize.

    To remedy this, we try to collapse away those selects and their
    fences in the scope tree by checking if removing them would lead
    to any path factoring. If not, we can drop it.

    Note that *some* new-school factoring may still have happened.
    If we have `select User filter User.name = 'Elvis'`, the outer `User`
    will be unprotected and the inner `User` will be factored out to it,
    leaving just `User.name` in a protected inner scope.
    That's fine, and we will see User.name doesn't have anything
    to factor with and remove the select that was injected.
    """
    children = []
    for ir in irs:
        children += ast_visitor.find_children(
            ir, irast.Set, lambda x: x.is_factoring_protected
        )
    all_ids = _get_all_pathids(irs)

    for ir_set in ordered.OrderedSet(children):
        if (
            ir_set.path_scope_id is None
            or not irutils.is_implicit_wrapper(ir_set.expr)
        ):
            continue

        node = ctx.env.scope_tree_nodes[ir_set.path_scope_id]
        if not (parent := node.parent):
            continue

        # If the underlying thing has already been factored fully, we
        # skip it, because it might be no-factor fenced?
        if parent.find_visible(ir_set.expr.result.path_id):
            continue

        # If collapsing this node would lead to any factoring, we
        # obviously can't do it.
        # We check by seeing if there are some factorable nodes
        # *other* than the ones we are starting from.
        if any(
            parent.find_factorable_nodes(path_id, child_to_skip=node)
            for path_id in node.get_all_paths()
        ):
            continue

        # If the path is referenced at all, we can't do it.
        # PERF: Should we build a dict with all prefixes as keys, instead
        # of this O(n*m) loop?
        if any(
            x is not ir_set and path_id.startswith(ir_set.path_id)
            for path_id, x in all_ids
        ):
            continue

        del ctx.env.scope_tree_nodes[ir_set.path_scope_id]

        # Merge the node up into its parent
        node.optional |= parent.optional
        parent.fuse_subtree(node, self_fenced=True, ctx=ctx)

        # Mark the new path as optional if the old path was optional.
        orig = None
        gparent = parent.parent
        if (
            gparent
            and (orig := gparent.find_child(ir_set.path_id))
            and parent.optional
            and orig.optional
        ):
            pathctx.register_set_in_scope(
                ir_set.expr.result, optional=True, path_scope=gparent, ctx=ctx
            )
        if orig:
            orig.remove()
        # Yeeeee haw. Replace the old set with the inner one.
        ir_set.__dict__ = ir_set.expr.result.__dict__


def _fixup_schema_view(*, ctx: context.ContextLevel) -> None:
    """Finalize schema view types for inclusion in the real schema.

    This includes setting from_alias flags and collapsing opaque
    unions to BaseObject.
    """
    for view in ctx.view_nodes.values():
        if view.is_collection():
            continue

        assert isinstance(view, s_types.InheritingType)
        _elide_derived_ancestors(view, ctx=ctx)

        if not isinstance(view, s_sources.Source):
            continue

        view_own_pointers = view.get_pointers(ctx.env.schema)
        for vptr in view_own_pointers.objects(ctx.env.schema):
            _elide_derived_ancestors(vptr, ctx=ctx)
            ctx.env.schema = vptr.set_field_value(
                ctx.env.schema,
                'from_alias',
                True,
            )

            tgt = vptr.get_target(ctx.env.schema)
            assert tgt is not None

            if (tgt.is_union_type(ctx.env.schema)
                    and tgt.get_is_opaque_union(ctx.env.schema)):
                # Opaque unions should manifest as std::BaseObject
                # in schema views.
                ctx.env.schema = vptr.set_target(
                    ctx.env.schema,
                    ctx.env.schema.get(
                        'std::BaseObject', type=s_types.Type),
                )

            if not isinstance(vptr, s_sources.Source):
                continue

            vptr_own_pointers = vptr.get_pointers(ctx.env.schema)
            for vlprop in vptr_own_pointers.objects(ctx.env.schema):
                _elide_derived_ancestors(vlprop, ctx=ctx)
                ctx.env.schema = vlprop.set_field_value(
                    ctx.env.schema,
                    'from_alias',
                    True,
                )


def _get_nearest_non_source_derived_parent(
    obj: s_obj.DerivableInheritingObjectT, ctx: context.ContextLevel
) -> s_obj.DerivableInheritingObjectT:
    """Find the nearest ancestor of obj whose "root source" is not derived"""
    schema = ctx.env.schema
    while (
        (src := s_pointers.get_root_source(obj, schema))
        and isinstance(src, s_obj.DerivableInheritingObject)
        and src.get_is_derived(schema)
    ):
        obj = obj.get_bases(schema).first(schema)
    return obj


def _elide_derived_ancestors(
    obj: s_types.InheritingType | s_pointers.Pointer,
    *,
    ctx: context.ContextLevel,
) -> None:
    """Collapse references to derived objects in bases.

    When compiling a schema view expression, make sure we don't
    expose any ephemeral derived objects, as these wouldn't be
    present in the schema outside of the compilation context.
    """

    pbase = obj.get_bases(ctx.env.schema).first(ctx.env.schema)
    new_pbase = _get_nearest_non_source_derived_parent(pbase, ctx)
    if pbase != new_pbase:
        ctx.env.schema = obj.set_field_value(
            ctx.env.schema,
            'bases',
            s_obj.ObjectList.create(ctx.env.schema, [new_pbase]),
        )

        ctx.env.schema = obj.set_field_value(
            ctx.env.schema,
            'ancestors',
            s_obj.compute_ancestors(ctx.env.schema, obj)
        )


def compile_anchor(
    name: str,
    anchor: qlast.Expr | irast.Base | s_obj.Object | irast.PathId,
    *,
    ctx: context.ContextLevel,
) -> irast.Set:

    show_as_anchor = True

    if isinstance(anchor, s_types.Type):
        # Anchors should not receive type rewrites; we are already
        # evaluating in their context.
        ctx.env.type_rewrites[anchor, False] = None
        step = setgen.class_set(anchor, ctx=ctx)

    elif (isinstance(anchor, s_pointers.Pointer) and
            not anchor.is_link_property(ctx.env.schema)):
        src = anchor.get_source(ctx.env.schema)
        if src is not None:
            assert isinstance(src, s_objtypes.ObjectType)
            ctx.env.type_rewrites[src, False] = None
            path = setgen.extend_path(
                setgen.class_set(src, ctx=ctx),
                anchor,
                s_pointers.PointerDirection.Outbound,
                ctx=ctx,
            )
        else:
            ptrcls = schemactx.derive_dummy_ptr(anchor, ctx=ctx)
            src = ptrcls.get_source(ctx.env.schema)
            assert isinstance(src, s_types.Type)
            ctx.env.type_rewrites[src, False] = None
            path = setgen.extend_path(
                setgen.class_set(src, ctx=ctx),
                ptrcls,
                s_pointers.PointerDirection.Outbound,
                ctx=ctx)

        step = path

    elif (isinstance(anchor, s_pointers.Pointer) and
            anchor.is_link_property(ctx.env.schema)):

        anchor_source = anchor.get_source(ctx.env.schema)
        assert isinstance(anchor_source, s_pointers.Pointer)
        anchor_source_source = anchor_source.get_source(ctx.env.schema)

        if anchor_source_source:
            assert isinstance(anchor_source_source, s_objtypes.ObjectType)
            path = setgen.extend_path(
                setgen.class_set(anchor_source_source, ctx=ctx),
                anchor_source,
                s_pointers.PointerDirection.Outbound,
                ctx=ctx,
            )
        else:
            ptrcls = schemactx.derive_dummy_ptr(anchor_source, ctx=ctx)
            src = ptrcls.get_source(ctx.env.schema)
            assert isinstance(src, s_types.Type)
            path = setgen.extend_path(
                setgen.class_set(src, ctx=ctx),
                ptrcls,
                s_pointers.PointerDirection.Outbound,
                ctx=ctx)

        step = setgen.extend_path(
            path,
            anchor,
            s_pointers.PointerDirection.Outbound,
            ctx=ctx)

    elif isinstance(anchor, qlast.Base):
        step = dispatch.compile(anchor, ctx=ctx)

    elif isinstance(anchor, (irast.QueryParameter, irast.FunctionParameter)):
        step = setgen.ensure_set(anchor, ctx=ctx)

    elif isinstance(anchor, irast.PathId):
        stype = typegen.type_from_typeref(anchor.target, env=ctx.env)
        step = setgen.class_set(
            stype, path_id=anchor, ignore_rewrites=True, ctx=ctx)

    else:
        raise RuntimeError(f'unexpected anchor value: {anchor!r}')

    if show_as_anchor:
        step.anchor = name
        step.show_as_anchor = name

    return step


def populate_anchors(
    anchors: Mapping[str, Any],
    *,
    ctx: context.ContextLevel,
) -> None:

    for name, val in anchors.items():
        ctx.anchors[name] = compile_anchor(name, val, ctx=ctx)


def declare_view(
    expr: qlast.Expr,
    alias: s_name.Name,
    *,
    factoring_fence: bool=False,
    fully_detached: bool=False,
    binding_kind: irast.BindingKind,
    path_id_namespace: Optional[frozenset[str]]=None,
    ctx: context.ContextLevel,
) -> irast.Set:

    pinned_pid_ns = path_id_namespace

    with ctx.newscope(fenced=True) as subctx:
        if factoring_fence:
            subctx.path_scope.factoring_fence = True
            subctx.path_scope.factoring_allowlist.update(ctx.iterator_path_ids)

        if path_id_namespace is not None:
            subctx.path_id_namespace = path_id_namespace

        if not fully_detached:
            cached_view_set = ctx.env.expr_view_cache.get((expr, alias))
            # Detach the view namespace and record the prefix
            # in the parent statement's fence node.
            view_path_id_ns = {ctx.aliases.get('ns')}
            # if view_path_id_ns == {'ns~3'}:
            #     view_path_id_ns = set()
            subctx.path_id_namespace |= view_path_id_ns
            ctx.path_scope.add_namespaces(view_path_id_ns)
        else:
            cached_view_set = None

        if ctx.stmt is not None:
            subctx.stmt = ctx.stmt.parent_stmt

        if cached_view_set is not None:
            subctx.view_scls = setgen.get_set_type(cached_view_set, ctx=ctx)
            view_name = subctx.view_scls.get_name(ctx.env.schema)
            assert isinstance(view_name, s_name.QualName)
        else:
            if (
                isinstance(alias, s_name.QualName)
                and subctx.env.options.schema_view_mode
            ):
                view_name = alias
                subctx.recompiling_schema_alias = True
            else:
                view_name = s_name.QualName(
                    module=ctx.derived_target_module or '__derived__',
                    name=s_name.get_specialized_name(
                        alias,
                        ctx.aliases.get('w')
                    )
                )

        subctx.toplevel_result_view_name = view_name

        view_set = dispatch.compile(astutils.ensure_ql_query(expr), ctx=subctx)
        assert isinstance(view_set, irast.Set)

        ctx.env.path_scope_map[view_set] = context.ScopeInfo(
            path_scope=subctx.path_scope,
            pinned_path_id_ns=pinned_pid_ns,
            binding_kind=binding_kind,
        )

        if not fully_detached:
            # The view path id _itself_ should not be in the nested namespace.
            # The fully_detached case should be handled by the caller.
            if path_id_namespace is None:
                path_id_namespace = ctx.path_id_namespace
            view_set.path_id = view_set.path_id.replace_namespace(
                path_id_namespace)

        ctx.aliased_views[alias] = view_set
        ctx.env.expr_view_cache[expr, alias] = view_set

    return view_set


def _declare_view_from_schema(
    viewcls: s_types.Type, *, ctx: context.ContextLevel
) -> tuple[s_types.Type, irast.Set]:
    # We need to include "security context" things (currently just
    # access policy state) in the cache key, here.
    #
    # See below for an optimization in the case where policies are not
    # used.
    security_context = ctx.get_security_context()
    key = viewcls, security_context
    e = ctx.env.schema_view_cache.get(key)
    if e is not None:
        return e

    orig_policy_count = ctx.env.policy_use_count

    # N.B: This takes a context, which we need to use to create a
    # subcontext to compile in, but it should avoid depending on the
    # context, because of the cache.
    with ctx.detached() as subctx:
        subctx.current_schema_views += (viewcls,)
        subctx.expr_exposed = context.Exposure.UNEXPOSED
        view_expr: s_expr.Expression | None = viewcls.get_expr(ctx.env.schema)
        assert view_expr is not None
        view_ql = view_expr.parse()
        viewcls_name = viewcls.get_name(ctx.env.schema)
        assert isinstance(view_ql, qlast.Expr), 'expected qlast.Expr'
        view_set = declare_view(
            view_ql,
            alias=viewcls_name,
            binding_kind=irast.BindingKind.Schema,
            fully_detached=True,
            ctx=subctx,
        )
        # The view path id _itself_ should not be in the nested namespace.
        view_set.path_id = view_set.path_id.replace_namespace(frozenset())
        view_set.is_schema_alias = True

        vs = subctx.aliased_views[viewcls_name]
        assert vs is not None
        vc = setgen.get_set_type(vs, ctx=ctx)

        # If policies weren't actually used, see if we already
        # compiled this global/alias with policy suppression in the
        # other state, to avoid generating two CTEs for a cached
        # global pointlessly.
        if orig_policy_count == ctx.env.policy_use_count:
            key2 = viewcls, security_context.toggle_policies()
            if key2 in ctx.env.schema_view_cache:
                vc, view_set = ctx.env.schema_view_cache[key2]

        ctx.env.schema_view_cache[key] = vc, view_set

    return vc, view_set


def declare_view_from_schema(
    viewcls: s_types.Type, *, ctx: context.ContextLevel
) -> s_types.Type:
    vc, view_set = _declare_view_from_schema(viewcls, ctx=ctx)

    viewcls_name = viewcls.get_name(ctx.env.schema)

    ctx.aliased_views[viewcls_name] = view_set
    ctx.view_nodes[vc.get_name(ctx.env.schema)] = vc
    ctx.view_sets[vc] = view_set

    return vc


def check_params(params: dict[str, irast.Param]) -> None:
    first_argname = next(iter(params))
    for param in params.values():
        # FIXME: context?
        if param.name.isdecimal() != first_argname.isdecimal():
            raise errors.QueryError(
                'cannot combine positional and named parameters '
                'in the same query')

    if first_argname.isdecimal():
        args_decnames = {int(arg) for arg in params}
        args_tpl = set(range(len(params)))
        if args_decnames != args_tpl:
            missing_args = args_tpl - args_decnames
            # Sort so that the error message is deterministic.
            missing_args_repr = ', '.join(
                f'${a}' for a in sorted(missing_args))
            raise errors.QueryError(
                f'missing {missing_args_repr} positional argument'
                f'{"s" if len(missing_args) > 1 else ""}')


def throw_on_shaped_param(
    param: qlast.QueryParameter, shape: qlast.Shape, ctx: context.ContextLevel
) -> None:
    raise errors.QueryError(
        f'cannot apply a shape to the parameter',
        hint='Consider adding parentheses around the parameter and type cast',
        span=shape.span
    )


def throw_on_loose_param(
    param: qlast.QueryParameter, ctx: context.ContextLevel
) -> None:
    if ctx.env.options.func_params is not None:
        if ctx.env.options.schema_object_context is s_constr.Constraint:
            raise errors.InvalidConstraintDefinitionError(
                f'dollar-prefixed "$parameters" cannot be used here',
                span=param.span)
        else:
            raise errors.InvalidFunctionDefinitionError(
                f'dollar-prefixed "$parameters" cannot be used here',
                span=param.span)
    raise errors.QueryError(
        f'missing a type cast before the parameter',
        span=param.span)


def preprocess_script(
    stmts: Sequence[qlast.Base], *, ctx: context.ContextLevel
) -> irast.ScriptInfo:
    """Extract parameters from all statements in a script.

    Doing this in advance makes it easy to check that they have
    consistent types.
    """
    params_lists = [
        astutils.find_parameters(stmt, ctx.modaliases)
        for stmt in stmts
    ]

    if loose_params := [
        loose for params in params_lists
        for loose in params.loose_params
    ]:
        throw_on_loose_param(loose_params[0], ctx)

    if shaped_params := [
        shaped for params in params_lists
        for shaped in params.shaped_params
    ]:
        throw_on_shaped_param(shaped_params[0][0], shaped_params[0][1], ctx)

    casts = [
        cast for params in params_lists for cast in params.cast_params
    ]
    params = {}
    for cast, modaliases in casts:
        assert isinstance(cast.expr, qlast.QueryParameter)
        name = cast.expr.name
        if name in params:
            continue
        with ctx.new() as mctx:
            mctx.modaliases = modaliases
            target_stype = typegen.ql_typeexpr_to_type(cast.type, ctx=mctx)

        if ctx.env.options.json_parameters:
            # Rule check on JSON-input parameters.
            # The actual casting of the parameter happens in
            if name.isdecimal():
                raise errors.QueryError(
                    'queries compiled to accept JSON parameters do not '
                    'accept positional parameters',
                    span=cast.expr.span)

        # For ObjectType parameters, we inject an intermediate cast to
        # uuid, so the parameter is a uuid that is then cast to the
        # ObjectType.
        if target_stype.is_object_type():
            uuid_cast = qlast.TypeCast(
                type=qlast.TypeName(maintype=qlast.ObjectRef(name='uuid')),
                expr=cast.expr,
                cardinality_mod=cast.cardinality_mod,
            )
            cast.expr = uuid_cast
            cast = cast.expr

            with ctx.new() as mctx:
                mctx.modaliases = modaliases
                target_stype = typegen.ql_typeexpr_to_type(cast.type, ctx=mctx)

        target_typeref = typegen.type_to_typeref(target_stype, env=ctx.env)
        required = cast.cardinality_mod != qlast.CardinalityModifier.Optional

        # This handles processing of tuple arguments, nested arrays, and
        # all json-mode parameters.
        sub_params = tuple_args.create_sub_params(
            name,
            required,
            typeref=target_typeref,
            pt=target_stype,
            is_func_param=False,
            ctx=ctx,
        )
        params[name] = irast.Param(
            name=name,
            required=required,
            schema_type=target_stype,
            ir_type=target_typeref,
            sub_params=sub_params,
        )

    if params:
        check_params(params)

        def _arg_key(k: tuple[str, object]) -> int:
            name = k[0]
            arg_prefix = '__edb_arg_'
            # Positional arguments should just be sorted numerically,
            # while for named arguments, injected args should be sorted and
            # need to come after normal ones. Normal named arguments can have
            # any order.
            if name.isdecimal():
                return int(name)
            elif name.startswith(arg_prefix):
                return int(k[0][len(arg_prefix):])
            else:
                return -1

        params = dict(sorted(params.items(), key=_arg_key))

    return irast.ScriptInfo(params=params, schema=ctx.env.schema)


================================================
FILE: edb/edgeql/compiler/triggers.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""EdgeQL trigger compilation."""


from __future__ import annotations

from typing import Optional, Collection

from edb import errors

from edb.ir import ast as irast

from edb.schema import name as sn
from edb.schema import objtypes as s_objtypes
from edb.schema import triggers as s_triggers
from edb.schema import types as s_types
from edb.schema import expr as s_expr

from edb.edgeql import ast as qlast
from edb.edgeql import qltypes

from . import context
from . import dispatch
from . import options
from . import schemactx
from . import setgen
from . import typegen


TRIGGER_KINDS = {
    irast.UpdateStmt: qltypes.TriggerKind.Update,
    irast.DeleteStmt: qltypes.TriggerKind.Delete,
    irast.InsertStmt: qltypes.TriggerKind.Insert,
}


def compile_trigger(
    trigger: s_triggers.Trigger,
    affected: set[tuple[s_objtypes.ObjectType, irast.MutatingStmt]],
    all_typs: set[s_objtypes.ObjectType],
    *,
    ctx: context.ContextLevel,
) -> irast.Trigger:
    schema = ctx.env.schema

    scope = trigger.get_scope(schema)
    kinds = set(trigger.get_kinds(schema))
    source = trigger.get_subject(schema)

    with ctx.detached() as tc, tc.newscope(fenced=True) as sctx:
        sctx.anchors = sctx.anchors.copy()

        anchors = {}
        new_path = irast.PathId.from_type(
            schema,
            source,
            typename=sn.QualName(
                module='__derived__', name=ctx.aliases.get('__new__')
            ),
            env=ctx.env,
        )
        new_set = setgen.class_set(
            source, path_id=new_path, ignore_rewrites=True, ctx=sctx)
        new_set.expr = irast.TriggerAnchor(typeref=new_set.typeref)

        old_set = None
        if qltypes.TriggerKind.Insert not in kinds:
            old_path = irast.PathId.from_type(
                schema,
                source,
                typename=sn.QualName(
                    module='__derived__', name=ctx.aliases.get('__old__')
                ),
                env=ctx.env,
            )
            old_set = setgen.class_set(
                source, path_id=old_path, ignore_rewrites=True, ctx=sctx)
            old_set.expr = irast.TriggerAnchor(typeref=old_set.typeref)
            anchors['__old__'] = old_set
        if qltypes.TriggerKind.Delete not in kinds:
            anchors['__new__'] = new_set

        for name, ir in anchors.items():
            if scope == qltypes.TriggerScope.Each:
                sctx.path_scope.attach_path(ir.path_id, span=None, ctx=sctx)
                sctx.iterator_path_ids |= {ir.path_id}
            sctx.anchors[name] = ir

        trigger_expr: Optional[s_expr.Expression] = trigger.get_expr(schema)
        assert trigger_expr
        trigger_ast = trigger_expr.parse()

        # A conditional trigger desugars to a FOR query that puts the
        # condition in the FILTER of a trivial SELECT.
        condition: Optional[s_expr.Expression] = trigger.get_condition(schema)
        if condition:
            trigger_ast = qlast.ForQuery(
                iterator_alias='__',
                iterator=qlast.SelectQuery(
                    result=qlast.Tuple(elements=[]),
                    where=condition.parse(),
                ),
                result=trigger_ast,
            )

        trigger_set = dispatch.compile(trigger_ast, ctx=sctx)

    typeref = typegen.type_to_typeref(source, env=ctx.env)
    taffected = {
        (typegen.type_to_typeref(t, env=ctx.env), ir) for t, ir in affected
    }
    tall = {
        typegen.type_to_typeref(t, env=ctx.env) for t in all_typs
    }

    return irast.Trigger(
        expr=trigger_set,
        kinds=kinds,
        scope=scope,
        source_type=typeref,
        affected=taffected,
        all_affected_types=tall,
        new_set=new_set,
        old_set=old_set,
    )


def compile_triggers_phase(
    dml_stmts: Collection[irast.MutatingStmt],
    defining_trigger_on: Optional[s_types.Type],
    defining_trigger_kinds: Optional[Collection[qltypes.TriggerKind]],
    *,
    ctx: context.ContextLevel,
) -> tuple[irast.Trigger, ...]:
    schema = ctx.env.schema

    trigger_map: dict[
        s_triggers.Trigger,
        tuple[
            set[tuple[s_objtypes.ObjectType, irast.MutatingStmt]],
            set[s_objtypes.ObjectType],
        ],
    ] = {}
    for stmt in dml_stmts:
        kind = TRIGGER_KINDS[type(stmt)]

        stype = schemactx.concretify(
            setgen.get_set_type(stmt.result, ctx=ctx), ctx=ctx)
        assert isinstance(stype, s_objtypes.ObjectType)
        # For updates and deletes, we need to look to see if any
        # descendant types have triggers.
        if isinstance(stmt, irast.InsertStmt):
            stypes = {stype}
        else:
            stypes = schemactx.get_all_concrete(stype, ctx=ctx)

        # Process all the types, starting with the base type
        for subtype in sorted(stypes, key=lambda t: t != stype):
            if (defining_trigger_on and defining_trigger_kinds
                and kind in defining_trigger_kinds
                and subtype.issubclass(ctx.env.schema, defining_trigger_on)
            ):
                name = str(defining_trigger_on.get_name(ctx.env.schema))
                raise errors.SchemaDefinitionError(
                    f"trigger on {name} after {kind.lower()} is recursive"
                )

            for trigger in subtype.get_relevant_triggers(kind, schema):
                mro = (trigger, *trigger.get_ancestors(schema).objects(schema))
                base = mro[-1]
                tmap, all_typs = trigger_map.setdefault(base, (set(), set()))
                # N.B: If the *base type* of the DML appears, that
                # suffices, because it covers everything, and we don't
                # need to duplicate.  This is a specific interaction
                # with how dml.compile_trigger is implemented, where
                # processing the base type of a DML naturally covers
                # all subtypes, but processing a child does not cover
                # a grandchild.
                if (stype, stmt) not in tmap:
                    tmap.add((subtype, stmt))
                all_typs.add(subtype)

    # sort these by name just to avoid weird nondeterminism
    return tuple(
        compile_trigger(trigger, affected, all_typs, ctx=ctx)
        for trigger, (affected, all_typs)
        in sorted(trigger_map.items(), key=lambda t: t[0].get_name(schema))
    )


def compile_triggers(
    *,
    ctx: context.ContextLevel,
) -> tuple[tuple[irast.Trigger, ...], ...]:
    defining_trigger = (
        ctx.env.options.schema_object_context == s_triggers.Trigger)
    defining_trigger_on = None
    defining_trigger_kinds = None
    if (
        defining_trigger and
        isinstance(ctx.env.options, options.CompilerOptions)
    ):
        defining_trigger_on = ctx.env.options.trigger_type
        defining_trigger_kinds = ctx.env.options.trigger_kinds

    ir_triggers: list[tuple[irast.Trigger, ...]] = []
    start = 0
    all_trigger_causes: set[tuple[irast.TypeRef, qltypes.TriggerKind]] = set()
    while start < len(ctx.env.dml_stmts):
        end = len(ctx.env.dml_stmts)
        compiled_triggers = compile_triggers_phase(
            ctx.env.dml_stmts[start:],
            defining_trigger_on,
            defining_trigger_kinds,
            ctx=ctx
        )
        new_causes: set[tuple[irast.TypeRef, qltypes.TriggerKind]] = {
            (affected_type, kind)
            for compiled_trigger in compiled_triggers
            for affected_type in compiled_trigger.all_affected_types
            for kind in compiled_trigger.kinds
        }

        # Any given type is allowed to have its triggers fire
        # in *one* phase of trigger execution, since the semantics get
        # a little unclear otherwise. We might relax this later.
        overlap = new_causes & all_trigger_causes
        if overlap:
            names: Collection[str] = sorted(
                f"{str(cause[0].name_hint)} after {cause[1].lower()}"
                for cause in overlap
            )
            raise errors.QueryError(
                f"trigger would need to be executed in multiple stages on "
                f"{', '.join(names)}"
            )
        all_trigger_causes |= new_causes
        ir_triggers.append(compiled_triggers)
        start = end

    return tuple(ir_triggers)


================================================
FILE: edb/edgeql/compiler/tuple_args.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""Implementation of tuple argument decoding compiler.

Postgres does not support passing records (tuples in edgeql) as query
parameters, and so we need to go to some length to work around this.

All of the trickiness here comes from the interaction with arrays;
without arrays, we could just split a tuple into multiple parameters.
The singly-nested case is also still fairly simple: turn an array of
tuples into multiple parallel arrays, such that `array<tuple<str, int64>>`
becomes `array<str>` and `array<int64>`.

The doubly-nested case, in which the tuple itself contains an array
(for example, `array<tuple<str, array<int64>>>`), is trickier:
Postgres does not allow nested arrays (except if there is an intervening
record type).

The key insight to resolve this dilemma is that a nested array type
`array<array<T>>` can be transformed into two non-nested arrays with
types `array<T>` and `array<int32>`, where the `array<T>` contains all
of the elements of the nested arrays flattened out and the
`array<int32>` contains the indexes into the flattened array indicating where
each of the nested arrays begins (followed by the length of the flattened
array, so that pairs of adjacent elements form the slice indexes into
the flattened array).

As an example, consider a parameter of type `array<tuple<str, array<int64>>>`,
with the value:
[
    ('foo', [100]),
    ('bar', [101, 102]),
    ('baz', [103, 104, 105]),
]
We will encode this into three arguments, with types `array<str>`,
`array<int32>`, and `array<int64>`, with values:
  ['foo', 'bar', 'baz']
  [0, 1, 3, 6]
  [100, 101, 102, 103, 104, 105]


The encoding algorithm is straightforward: we traverse the type and
the input data in tandem, appending data into the appropriate argument
arrays and tracking array lengths. This is implemented in our cython
protocol server, in edb.server.protocol.args_ser, and operates directly
on the wire encodings.
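
As a rough illustration only, the flattening for the running example can
be sketched in plain Python. The helper name below is hypothetical, and
the sketch works on Python lists rather than wire encodings, so it is not
the args_ser implementation:

```python
# Hypothetical helper: flatten a list of (str, list-of-int) pairs into
# three parallel arrays, mirroring the running example in this docstring.
def encode_str_intarray_pairs(value):
    names = []      # plays the role of the array<str> argument
    offsets = [0]   # start index of each nested array, plus final length
    flat = []       # all nested array elements, flattened out
    for name, nums in value:
        names.append(name)
        flat.extend(nums)
        # Record where the next nested array begins; after the last
        # iteration this is the total length of the flattened array.
        offsets.append(len(flat))
    return names, offsets, flat


names, offsets, flat = encode_str_intarray_pairs([
    ('foo', [100]),
    ('bar', [101, 102]),
    ('baz', [103, 104, 105]),
])
```

Adjacent pairs in `offsets` are the slice bounds of each nested array
within `flat`.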

The decoding needs to be done as part of the SQL query we execute, so
we generate an EdgeQL query that decodes to the proper type. The generated
code operates in a top-down manner, looping over the arrays and constructing
the value in a single pass.

The code we generate for our running example could look something like:
  with v0 := <array<str>>$0, v1 := <array<int32>>$1, v2 := <array<int64>>$2,
  select array_agg((for i in range_unpack(range(0, len(v0))) union (
    (
      v0[i],
      array_agg((
        for j in range_unpack(range(v1[i], v1[i + 1])) union (v2[j])
      )),
    )
  )))
In this case, since the nested array is simply an array of a scalar, we can
do an optimization and use slicing instead of an array_agg+for:
  with v0 := <array<str>>$0, v1 := <array<int32>>$1, v2 := <array<int64>>$2,
  select array_agg((for i in range_unpack(range(0, len(v0))) union (
    (
      v0[i],
      v2[v1[i] : v1[i + 1]],
    )
  )))
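
The slicing form of the decoder can be mirrored in plain Python as a
sanity check of the scheme. This is only a sketch with a hypothetical
helper name; the real decoder is the generated EdgeQL shown above:

```python
# Hypothetical helper: rebuild the list of (str, list-of-int) pairs from
# the three parallel arrays, using adjacent offsets as slice bounds.
def decode_str_intarray_pairs(v0, v1, v2):
    # v1 has one more element than v0: the trailing entry is the total
    # length of v2, so v1[i + 1] is always a valid index.
    return [(v0[i], v2[v1[i]:v1[i + 1]]) for i in range(len(v0))]


decoded = decode_str_intarray_pairs(
    ['foo', 'bar', 'baz'],
    [0, 1, 3, 6],
    [100, 101, 102, 103, 104, 105],
)
```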

The decoder queries will get placed in a CTE in the generated SQL.
"""

from __future__ import annotations

import dataclasses

from typing import Optional, Sequence, TYPE_CHECKING

from edb import errors
from edb.common.typeutils import not_none

from edb.ir import ast as irast
from edb.ir import typeutils as irtypeutils

from edb.schema import name as sn
from edb.schema import types as s_types
from edb.schema import utils as s_utils

from edb.edgeql import ast as qlast

from . import context
from . import dispatch
from . import typegen

if TYPE_CHECKING:
    from edb.schema import schema as s_schema

# Since we process tuple types recursively in our cython server, insert
# a recursion depth check here, to be confident that this won't blow
# our C stack. (Though in practice I would expect anything that might
# blow it to blow the python stack while compiling the translation.)
MAX_NESTING = 20


def _lmost_is_array(typ: irast.ParamTransType) -> bool:
    while isinstance(typ, irast.ParamTuple):
        _, typ = typ.typs[0]
    return isinstance(typ, irast.ParamArray)


def translate_type(
    typeref: irast.TypeRef,
    *,
    schema: s_schema.Schema,
) -> tuple[irast.ParamTransType, tuple[irast.TypeRef, ...]]:
    """Translate the type of a tuple-containing param to multiple params.

    This computes a list of parameter types, as well as a
    ParamTransType that clones the type information but augments each
    node in the type with indexes that correspond to which parameter
    data is drawn from. This is used to drive the encoder and the
    decoder generator.
    """

    typs: list[irast.TypeRef] = []

    def trans(
        typ: irast.TypeRef, in_array: bool, depth: int
    ) -> irast.ParamTransType:
        if depth > MAX_NESTING:
            raise errors.QueryError(
                'type of parameter is too deeply nested')

        start = len(typs)

        if irtypeutils.is_array(typ):
            # If our array is appearing already inside another array,
            # we need to add an extra parameter
            if in_array:
                int_typeref = schema.get(
                    sn.QualName('std', 'int32'), type=s_types.Type)
                nschema, array_styp = s_types.Array.from_subtypes(
                    schema, [int_typeref])
                typs.append(irtypeutils.type_to_typeref(
                    nschema, array_styp, cache=None))

            if irtypeutils.is_array(typ.subtypes[0]):
                # Treat nested arrays as if they are arrays of tuples of arrays
                nschema, inner_array_styp = irtypeutils.ir_typeref_to_type(
                    schema, typ.subtypes[0]
                )
                nschema, wrapper_tuple_styp = s_types.Tuple.from_subtypes(
                    nschema, {'f1': inner_array_styp}
                )
                wrapper_tuple_typ = irtypeutils.type_to_typeref(
                    nschema, wrapper_tuple_styp, cache=None
                )
                return irast.ParamArray(
                    typeref=typ,
                    idx=start,
                    typ=trans(
                        wrapper_tuple_typ, in_array=True, depth=depth + 1
                    ),
                )
            else:
                return irast.ParamArray(
                    typeref=typ,
                    idx=start,
                    typ=trans(typ.subtypes[0], in_array=True, depth=depth + 1),
                )

        elif irtypeutils.is_tuple(typ):
            return irast.ParamTuple(
                typeref=typ,
                idx=start,
                typs=tuple(
                    (
                        t.element_name,
                        trans(t, in_array=in_array, depth=depth + 1),
                    )
                    for t in typ.subtypes
                ),
            )

        else:
            nt = typ
            # If this appears in an array, the param needs to be an array
            if in_array:
                nschema, styp = irtypeutils.ir_typeref_to_type(schema, typ)
                nschema, styp = s_types.Array.from_subtypes(nschema, [styp])
                nt = irtypeutils.type_to_typeref(nschema, styp, cache=None)
            typs.append(nt)
            return irast.ParamScalar(typeref=typ, idx=start)

    t = trans(typeref, in_array=False, depth=0)
    return t, tuple(typs)


def _ref_to_ast(
    typeref: irast.TypeRef, *, ctx: context.ContextLevel
) -> qlast.TypeExpr:
    ctx.env.schema, styp = irtypeutils.ir_typeref_to_type(
        ctx.env.schema, typeref)
    return s_utils.typeref_to_ast(ctx.env.schema, styp)


def _get_alias(
    name: str, *, ctx: context.ContextLevel
) -> tuple[str, qlast.Path]:
    alias = ctx.aliases.get(name)
    return alias, qlast.Path(
        steps=[qlast.ObjectRef(name=alias)],
    )


def _plus_const(expr: qlast.Expr, val: int) -> qlast.Expr:
    if val == 0:
        return expr
    return qlast.BinOp(
        left=expr,
        op='+',
        right=qlast.Constant.integer(val),
    )


def _index(expr: qlast.Expr, idx: qlast.Expr) -> qlast.Indirection:
    return qlast.Indirection(arg=expr, indirection=[qlast.Index(index=idx)])


def _make_tuple(
    fields: Sequence[tuple[Optional[str], qlast.Expr]]
) -> qlast.NamedTuple | qlast.Tuple:
    is_named = fields and fields[0][0]
    if is_named:
        return qlast.NamedTuple(elements=[
            qlast.TupleElement(name=qlast.Ptr(name=not_none(f)), val=e)
            for f, e in fields
        ])
    else:
        return qlast.Tuple(
            elements=[e for _, e in fields]
        )


def make_decoder(
    ptyp: irast.ParamTransType,
    qparams: tuple[irast.Param, ...],
    *,
    ctx: context.ContextLevel,
) -> qlast.Expr:
    """Generate a decoder for tuple parameters.

    More details in the module docstring.
    """
    params: list[qlast.Expr] = [
        qlast.TypeCast(
            expr=qlast.QueryParameter(name=param.name),
            type=_ref_to_ast(param.ir_type, ctx=ctx),
            cardinality_mod=(
                qlast.CardinalityModifier.Optional if not param.required
                else None
            ),
        )
        for param in qparams
    ]

    def mk(typ: irast.ParamTransType, idx: Optional[qlast.Expr]) -> qlast.Expr:
        if isinstance(typ, irast.ParamScalar):
            expr = params[typ.idx]
            if idx is not None:
                expr = _index(expr, idx)
            if typ.cast_to:
                expr = qlast.TypeCast(
                    expr=expr,
                    type=_ref_to_ast(typ.cast_to, ctx=ctx),
                )
            return expr

        elif isinstance(typ, irast.ParamTuple):
            return _make_tuple([(f, mk(t, idx=idx)) for f, t in typ.typs])

        elif isinstance(typ, irast.ParamArray):
            inner_idx_alias, inner_idx = _get_alias('idx', ctx=ctx)

            lo: qlast.Expr
            hi: qlast.Expr
            if idx is None:
                lo = qlast.Constant.integer(0)
                hi = qlast.FunctionCall(
                    func=('__std__', 'len'), args=[params[typ.idx]])
                # If the leftmost element inside a toplevel array is
                # itself an array, subtract 1 from the length (since
                # array params have an extra element). We also need to
                # call `max` to prevent generating an invalid range.
                if _lmost_is_array(typ.typ):
                    hi = qlast.FunctionCall(
                        func=('__std__', 'max'), args=[
                            qlast.Set(elements=[lo, _plus_const(hi, -1)])])
            else:
                lo = _index(params[typ.idx], idx)
                hi = _index(params[typ.idx], _plus_const(idx, 1))

            # If the content is just a scalar, then we can take
            # values directly from the scalar array parameter, without
            # needing to iterate over the array directly.
            # This is an optimization, and not necessary for correctness.
            if isinstance(typ.typ, irast.ParamScalar) and not typ.typ.cast_to:
                sub = params[typ.typ.idx]
                # If we are in an array, do a slice
                if idx is not None:
                    sub = qlast.Indirection(
                        arg=sub,
                        indirection=[qlast.Slice(start=lo, stop=hi)],
                    )
                return sub

            sub_expr = mk(typ.typ, idx=inner_idx)

            loop = qlast.ForQuery(
                iterator_alias=inner_idx_alias,
                iterator=qlast.FunctionCall(
                    func=('__std__', '__pg_generate_series'),
                    args=[lo, _plus_const(hi, -1)],
                ),
                result=sub_expr,
            )
            res: qlast.Expr = qlast.FunctionCall(
                func=('__std__', 'array_agg'), args=[loop],
            )

            # If the param is optional, and we are still at the
            # top-level, insert a filter so that our aggregate doesn't
            # create something from nothing.
            if not qparams[typ.idx].required and idx is None:
                res = qlast.SelectQuery(
                    result=res,
                    where=qlast.UnaryOp(op='EXISTS', operand=params[typ.idx]),
                )

            return res

        else:
            raise AssertionError(f'bogus type {typ}')

    decoder = mk(ptyp, idx=None)

    return decoder


def create_sub_params(
    name: str,
    required: bool,
    typeref: irast.TypeRef,
    pt: s_types.Type,
    is_func_param: bool = False,
    *,
    ctx: context.ContextLevel,
) -> Optional[irast.SubParams]:
    """Create sub parameters for a new param, if needed.

    We need to do this if there is a tuple in the type.

    We do this for nested arrays as well since array<array> is handled
    as array<tuple<array>>.
    """
    json_cast = ctx.env.options.json_parameters and not is_func_param
    if not (
        (
            pt.is_tuple(ctx.env.schema)
            or pt.is_anytuple(ctx.env.schema)
            or pt.contains_array_of_array(ctx.env.schema)
            or pt.contains_array_of_tuples(ctx.env.schema)
        )
        and not ctx.env.options.func_params
    ) and not json_cast:
        return None

    pdt: irast.ParamTransType
    arg_typs: tuple[irast.TypeRef, ...]
    if json_cast:
        json = typegen.type_to_typeref(
            ctx.env.get_schema_type_and_track(sn.QualName('std', 'json')),
            env=ctx.env,
        )
        pdt = irast.ParamScalar(typeref=json, cast_to=typeref, idx=0)
        arg_typs = (json,)
    else:
        pdt, arg_typs = translate_type(typeref, schema=ctx.env.schema)

    params = tuple([
        irast.Param(
            name=f'__edb_decoded_{name}_{i}__',
            required=required,
            ir_type=arg_typeref,
            schema_type=typegen.type_from_typeref(arg_typeref, env=ctx.env),
        )
        for i, arg_typeref in enumerate(arg_typs)
    ])

    decode_ql = make_decoder(pdt, params, ctx=ctx)

    return irast.SubParams(
        trans_type=pdt, decoder_edgeql=decode_ql, params=params)


def finish_sub_params(
    subps: irast.SubParams,
    *,
    ctx: context.ContextLevel,
) -> Optional[irast.SubParams]:
    """Finalize the subparams by compiling the IR in the proper context.

    We can't just compile it when doing create_sub_params, since that is
    called from preprocessing and so is shared between queries.
    """
    with ctx.newscope(fenced=True) as subctx:
        decode_ir = dispatch.compile(subps.decoder_edgeql, ctx=subctx)

    return dataclasses.replace(subps, decoder_ir=decode_ir)


================================================
FILE: edb/edgeql/compiler/typegen.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""EdgeQL compiler type-related helpers."""


from __future__ import annotations

from typing import Optional, Sequence, cast, overload

from edb import errors

from edb.ir import ast as irast
from edb.ir import typeutils as irtyputils
from edb.ir import utils as irutils

from edb.schema import name as s_name
from edb.schema import objtypes as s_objtypes
from edb.schema import pointers as s_pointers
from edb.schema import scalars as s_scalars
from edb.schema import types as s_types
from edb.schema import utils as s_utils

from edb.edgeql import ast as qlast

from . import context
from . import dispatch
from . import schemactx
from . import setgen


def amend_empty_set_type(
    es: irast.SetE[irast.EmptySet],
    t: s_types.Type,
    env: context.Environment
) -> None:
    env.set_types[es] = t
    alias = es.path_id.target_name_hint.name
    typename = s_name.QualName(module='__derived__', name=alias)
    es.path_id = irast.PathId.from_type(
        env.schema, t, env=env, typename=typename,
        namespace=es.path_id.namespace,
    )


def infer_common_type(
    irs: Sequence[irast.Set], env: context.Environment
) -> Optional[s_types.Type]:
    if not irs:
        # There is no element to take a span from when irs is empty.
        raise errors.QueryError(
            'cannot determine common type of an empty set')

    types = []
    empties = []

    seen_object = False
    seen_scalar = False
    seen_coll = False

    for i, arg in enumerate(irs):
        if (
            isinstance(arg.expr, irast.EmptySet)
            and env.set_types[arg] is None
        ):
            empties.append(i)
            continue

        t = env.set_types[arg]
        if isinstance(t, s_types.Collection):
            seen_coll = True
        elif isinstance(t, s_scalars.ScalarType):
            seen_scalar = True
        else:
            seen_object = True
        types.append(t)

    if seen_coll + seen_scalar + seen_object > 1:
        raise errors.QueryError(
            'cannot determine common type',
            span=irs[0].span)

    if not types:
        raise errors.QueryError(
            'cannot determine common type of an empty set',
            span=irs[0].span)

    common_type = None
    if seen_scalar or seen_coll:
        it = iter(types)
        common_type = next(it)
        while True:
            next_type = next(it, None)
            if next_type is None:
                break
            env.schema, common_type = (
                common_type.find_common_implicitly_castable_type(
                    next_type,
                    env.schema,
                )
            )
            if common_type is None:
                break
    else:
        common_types = s_utils.get_class_nearest_common_ancestors(
            env.schema,
            cast(Sequence[s_types.InheritingType], types),
        )
        # We arbitrarily select the first nearest common ancestor
        common_type = common_types[0] if common_types else None

    if common_type is None:
        return None

    for i in empties:
        amend_empty_set_type(
            cast(irast.SetE[irast.EmptySet], irs[i]), common_type, env)

    return common_type


def type_to_ql_typeref(
    t: s_types.Type,
    *,
    _name: Optional[str] = None,
    ctx: context.ContextLevel,
) -> qlast.TypeExpr:
    return s_utils.typeref_to_ast(
        ctx.env.schema,
        t,
        disambiguate_std='std' in ctx.modaliases,
    )


def ql_typeexpr_to_ir_typeref(
    ql_t: qlast.TypeExpr, *, ctx: context.ContextLevel
) -> irast.TypeRef:

    stype = ql_typeexpr_to_type(ql_t, ctx=ctx)
    return irtyputils.type_to_typeref(
        ctx.env.schema, stype, cache=ctx.env.type_ref_cache
    )


def ql_typeexpr_to_type(
    ql_t: qlast.TypeExpr, *, ctx: context.ContextLevel
) -> s_types.Type:

    (op, _, types) = (
        _ql_typeexpr_get_types(ql_t, ctx=ctx)
    )
    return _ql_typeexpr_combine_types(op, types, ctx=ctx)


def _ql_typeexpr_combine_types(
        op: Optional[str], types: list[s_types.Type], *,
        ctx: context.ContextLevel
) -> s_types.Type:
    if len(types) == 1:
        return types[0]
    elif op == '|':
        return schemactx.get_union_type(types, ctx=ctx)
    elif op == '&':
        return schemactx.get_intersection_type(types, ctx=ctx)
    else:
        raise errors.InternalServerError('This should never happen')


def _ql_typeexpr_get_types(
    ql_t: qlast.TypeExpr, *, ctx: context.ContextLevel
) -> tuple[Optional[str], bool, list[s_types.Type]]:

    if isinstance(ql_t, qlast.TypeOf):
        with ctx.new() as subctx:
            # Use an empty scope tree, to avoid polluting things pointlessly
            subctx.path_scope = irast.ScopeTreeNode()
            subctx.expr_exposed = context.Exposure.UNEXPOSED
            orig_rewrites = ctx.env.type_rewrites.copy()
            ir_set = dispatch.compile(ql_t.expr, ctx=subctx)
            stype = setgen.get_set_type(ir_set, ctx=subctx)
            ctx.env.type_rewrites = orig_rewrites

        return (None, True, [stype])

    elif isinstance(ql_t, qlast.TypeOp):
        if ql_t.op in [qlast.TypeOpName.OR, qlast.TypeOpName.AND]:
            (left_op, left_leaf, left_types) = (
                _ql_typeexpr_get_types(ql_t.left, ctx=ctx)
            )
            (right_op, right_leaf, right_types) = (
                _ql_typeexpr_get_types(ql_t.right, ctx=ctx)
            )

            # We need to validate that type ops are applied only to
            # object types. So we check the base case here, when the
            # left or right operand is a single type, because if it's
            # a longer list, then we know that it was already composed
            # of "|" or "&", or it is the result of inference by
            # "typeof" and is a list of object types anyway.
            if left_leaf and not left_types[0].is_object_type():
                raise errors.UnsupportedFeatureError(
                    f"cannot use type operator '{ql_t.op}' with non-object "
                    f"type {left_types[0].get_displayname(ctx.env.schema)}",
                    span=ql_t.left.span)
            if right_leaf and not right_types[0].is_object_type():
                raise errors.UnsupportedFeatureError(
                    f"cannot use type operator '{ql_t.op}' with non-object "
                    f"type {right_types[0].get_displayname(ctx.env.schema)}",
                    span=ql_t.right.span)

            # If an operand is either a single type or uses the same operator,
            # flatten it into the result types list.
            # If an operand uses a different operator, its types should
            # be combined into a new type before appending to the result types.
            types: list[s_types.Type] = []
            types += (
                left_types
                if left_op is None or left_op == ql_t.op else
                [_ql_typeexpr_combine_types(left_op, left_types, ctx=ctx)]
            )
            types += (
                right_types
                if right_op is None or right_op == ql_t.op else
                [_ql_typeexpr_combine_types(right_op, right_types, ctx=ctx)]
            )

            return (ql_t.op, False, types)

        raise errors.UnsupportedFeatureError(
            f'type operator {ql_t.op!r} is not implemented',
            span=ql_t.span)

    elif isinstance(ql_t, qlast.TypeName):
        return (None, True, [_ql_typename_to_type(ql_t, ctx=ctx)])

    else:
        raise errors.EdgeQLSyntaxError("Unexpected type expression",
                                       span=ql_t.span)


def _ql_typename_to_type(
    ql_t: qlast.TypeName, *, ctx: context.ContextLevel
) -> s_types.Type:
    if ql_t.subtypes:
        assert isinstance(ql_t.maintype, qlast.ObjectRef)
        coll = s_types.Collection.get_class(ql_t.maintype.name)
        ct: s_types.Type

        if issubclass(coll, s_types.Tuple):
            t_subtypes = {}
            named = False
            for si, st in enumerate(ql_t.subtypes):
                if st.name:
                    named = True
                    type_name = st.name
                else:
                    type_name = str(si)

                t_subtypes[type_name] = ql_typeexpr_to_type(st, ctx=ctx)

            ctx.env.schema, ct = coll.from_subtypes(
                ctx.env.schema, t_subtypes, {'named': named})
            return ct
        else:
            a_subtypes = []
            for st in ql_t.subtypes:
                a_subtypes.append(ql_typeexpr_to_type(st, ctx=ctx))

            ctx.env.schema, ct = coll.from_subtypes(ctx.env.schema, a_subtypes)
            return ct
    else:
        return schemactx.get_schema_type(ql_t.maintype, ctx=ctx)


@overload
def ptrcls_from_ptrref(
    ptrref: irast.PointerRef,
    *,
    ctx: context.ContextLevel,
) -> s_pointers.Pointer:
    ...


@overload
def ptrcls_from_ptrref(
    ptrref: irast.TupleIndirectionPointerRef,
    *,
    ctx: context.ContextLevel,
) -> irast.TupleIndirectionLink:
    ...


@overload
def ptrcls_from_ptrref(
    ptrref: irast.TypeIntersectionPointerRef,
    *,
    ctx: context.ContextLevel,
) -> irast.TypeIntersectionLink:
    ...


@overload
def ptrcls_from_ptrref(
    ptrref: irast.BasePointerRef,
    *,
    ctx: context.ContextLevel,
) -> s_pointers.PointerLike:
    ...


def ptrcls_from_ptrref(
    ptrref: irast.BasePointerRef,
    *,
    ctx: context.ContextLevel,
) -> s_pointers.PointerLike:

    cached = ctx.env.ptr_ref_cache.get_ptrcls_for_ref(ptrref)
    if cached is not None:
        return cached

    ctx.env.schema, ptr = irtyputils.ptrcls_from_ptrref(
        ptrref, schema=ctx.env.schema)

    return ptr


def ptr_to_ptrref(
    ptrcls: s_pointers.Pointer,
    *,
    ctx: context.ContextLevel,
) -> irast.PointerRef:
    return irtyputils.ptrref_from_ptrcls(
        schema=ctx.env.schema,
        ptrcls=ptrcls,
        cache=ctx.env.ptr_ref_cache,
        typeref_cache=ctx.env.type_ref_cache,
    )


def collapse_type_intersection_rptr(
    ir_set: irast.Set,
    *,
    ctx: context.ContextLevel,
) -> tuple[irast.Set, list[s_pointers.Pointer]]:

    ind_prefix, ind_ptrs = irutils.collapse_type_intersection(ir_set)
    if not ind_ptrs:
        return ir_set, []

    rptr_specialization: set[irast.PointerRef] = set()
    for ind_ptr in ind_ptrs:
        if ind_ptr.ptrref.rptr_specialization:
            rptr_specialization.update(
                ind_ptr.ptrref.rptr_specialization)
        elif (
            not ind_ptr.ptrref.is_empty
            and isinstance(ind_ptr.source.expr, irast.Pointer)
        ):
            assert isinstance(ind_ptr.source.expr.ptrref, irast.PointerRef)
            rptr_specialization.add(ind_ptr.source.expr.ptrref)

    ptrs = [ptrcls_from_ptrref(ptrref, ctx=ctx)
            for ptrref in rptr_specialization]

    return ind_prefix, ptrs


def type_from_typeref(
    t: irast.TypeRef,
    env: context.Environment,
) -> s_types.Type:
    env.schema, styp = irtyputils.ir_typeref_to_type(env.schema, t)
    return styp


def type_to_typeref(
    t: s_types.Type,
    env: context.Environment,
) -> irast.TypeRef:
    schema = env.schema
    cache = env.type_ref_cache
    expr_type = t.get_expr_type(env.schema)
    include_children = (
        expr_type is s_types.ExprType.Update
        or expr_type is s_types.ExprType.Delete
        or isinstance(t, s_objtypes.ObjectType)
    )
    include_ancestors = (
        expr_type is s_types.ExprType.Insert
        or expr_type is s_types.ExprType.Update
        or expr_type is s_types.ExprType.Delete
    )
    return irtyputils.type_to_typeref(
        schema,
        t,
        include_children=include_children,
        include_ancestors=include_ancestors,
        cache=cache,
    )


================================================
FILE: edb/edgeql/compiler/viewgen.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""EdgeQL shape compilation functions."""


from __future__ import annotations

import collections
import dataclasses
import functools
from typing import (
    Callable,
    Optional,
    AbstractSet,
    Mapping,
    Sequence,
    NamedTuple,
    cast,
    TYPE_CHECKING,
)

from edb import errors
from edb.common import ast
from edb.common import parsing
from edb.common import topological
from edb.common.typeutils import downcast, not_none

from edb.ir import ast as irast
from edb.ir import typeutils
from edb.ir import utils as irutils
import edb.ir.typeutils as irtypeutils

from edb.schema import futures as s_futures
from edb.schema import links as s_links
from edb.schema import name as sn
from edb.schema import objtypes as s_objtypes
from edb.schema import objects as s_objects
from edb.schema import pointers as s_pointers
from edb.schema import properties as s_props
from edb.schema import types as s_types
from edb.schema import expr as s_expr

from edb.edgeql import ast as qlast
from edb.edgeql import qltypes

from . import astutils
from . import context
from . import dispatch
from . import eta_expand
from . import pathctx
from . import schemactx
from . import setgen
from . import typegen

if TYPE_CHECKING:
    from edb.schema import sources as s_sources


class ShapeElementDesc(NamedTuple):
    """Annotated QL shape element for processing convenience"""

    #: Shape element AST
    ql: qlast.ShapeElement
    #: Canonical Path AST for the shape element
    path_ql: qlast.Path
    #: The underlying pointer AST
    ptr_ql: qlast.Ptr
    #: The name of the pointer
    ptr_name: str
    #: Pointer source object
    source: s_sources.Source
    #: Target type intersection (if any)
    target_typexpr: Optional[qlast.TypeExpr]
    #: Whether the source is a type intersection
    is_polymorphic: bool
    #: Whether the pointer is a link property
    is_linkprop: bool


class EarlyShapePtr(NamedTuple):
    """Stage 1 shape processing result element"""
    ptrcls: s_pointers.Pointer
    target_set: Optional[irast.Set]
    shape_origin: qlast.ShapeOrigin
    span: Optional[parsing.Span]


class ShapePtr(NamedTuple):
    """Stage 2 shape processing result element"""
    source_set: irast.Set
    ptrcls: s_pointers.Pointer
    shape_op: qlast.ShapeOp
    target_set: Optional[irast.Set]
    span: Optional[parsing.Span]


@dataclasses.dataclass(kw_only=True, frozen=True)
class ShapeContext:
    # a helper object for passing shape compile parameters

    path_id_namespace: Optional[irast.Namespace] = None

    view_rptr: Optional[context.ViewRPtr] = None

    view_name: Optional[sn.QualName] = None

    exprtype: s_types.ExprType = s_types.ExprType.Select


def process_view(
    ir_set: irast.Set,
    *,
    stype: s_objtypes.ObjectType,
    elements: Sequence[qlast.ShapeElement],
    view_rptr: Optional[context.ViewRPtr] = None,
    view_name: Optional[sn.QualName] = None,
    exprtype: s_types.ExprType = s_types.ExprType.Select,
    ctx: context.ContextLevel,
    span: Optional[parsing.Span],
) -> tuple[s_objtypes.ObjectType, irast.Set]:

    cache_key = (stype, exprtype, tuple(elements))
    view_scls = ctx.env.shape_type_cache.get(cache_key)
    if view_scls is not None:
        return view_scls, ir_set

    # XXX: This is an unfortunate hack to ensure that "cannot
    # reference correlated set" errors get produced correctly,
    # since there needs to be an intervening branch for a
    # factoring fence to be respected.
    hackscope = ctx.path_scope.attach_branch()
    pathctx.register_set_in_scope(ir_set, path_scope=hackscope, ctx=ctx)
    hackscope.remove()
    ctx.path_scope.attach_subtree(hackscope, ctx=ctx)

    # Make a snapshot of aliased_views that can't be mutated
    # in any parent scopes.
    ctx.aliased_views = collections.ChainMap(dict(ctx.aliased_views))

    s_ctx = ShapeContext(
        path_id_namespace=None,
        view_rptr=view_rptr,
        view_name=view_name,
        exprtype=exprtype,
    )

    view_scls, ir = _process_view(
        ir_set,
        stype=stype,
        elements=elements,
        ctx=ctx,
        s_ctx=s_ctx,
        span=span,
    )

    ctx.env.shape_type_cache[cache_key] = view_scls

    return view_scls, ir


def _process_view(
    ir_set: irast.Set,
    *,
    stype: s_objtypes.ObjectType,
    elements: Optional[Sequence[qlast.ShapeElement]],
    s_ctx: ShapeContext,
    ctx: context.ContextLevel,
    span: Optional[parsing.Span],
) -> tuple[s_objtypes.ObjectType, irast.Set]:
    path_id = ir_set.path_id
    view_rptr = s_ctx.view_rptr

    view_name = s_ctx.view_name
    needs_real_name = view_name is None and ctx.env.options.schema_view_mode
    generated_name = None
    if needs_real_name and view_rptr is not None:
        # Make sure persistent schema expression aliases have properly formed
        # names as opposed to the usual mangled form of the ephemeral
        # aliases.  This is needed for introspection readability, and it
        # also helps maintain proper type names for schema
        # representations that require alphanumeric names, such as
        # GraphQL.
        #
        # We use the name of the source together with the name
        # of the inbound link to form the name, so in e.g.
        #    CREATE ALIAS V := (SELECT Foo { bar: { baz: { ... } } })
        # The name of the innermost alias would be "__V__bar__baz".
        source_name = view_rptr.source.get_name(ctx.env.schema).name
        if not source_name.startswith('__'):
            source_name = f'__{source_name}'
        if view_rptr.ptrcls_name is not None:
            ptr_name = view_rptr.ptrcls_name.name
        elif view_rptr.ptrcls is not None:
            ptr_name = view_rptr.ptrcls.get_shortname(ctx.env.schema).name
        else:
            raise errors.InternalServerError(
                '_process_view in schema mode received view_rptr with '
                'neither ptrcls_name nor ptrcls'
            )

        generated_name = f'{source_name}__{ptr_name}'
    elif needs_real_name and ctx.env.alias_result_view_name:
        # If this is a persistent schema expression but we aren't just
        # obviously sitting on an rptr
        # (e.g. CREATE ALIAS V := (Foo { x }, 10)), we create a name
        # like __V__Foo__2.
        source_name = ctx.env.alias_result_view_name.name
        type_name = stype.get_name(ctx.env.schema).name
        generated_name = f'__{source_name}__{type_name}'

    if generated_name:
        # If there are multiple, we want to stick a number on, but we'd
        # like to skip the number if there aren't.
        name = ctx.aliases.get(
            generated_name).replace('~1', '').replace('~', '__')
        view_name = sn.QualName(
            module=ctx.derived_target_module or '__derived__',
            name=name,
        )

    view_scls = schemactx.derive_view(
        stype,
        exprtype=s_ctx.exprtype,
        derived_name=view_name,
        ctx=ctx,
    )
    assert isinstance(view_scls, s_objtypes.ObjectType), view_scls
    is_mutation = s_ctx.exprtype.is_insert() or s_ctx.exprtype.is_update()
    is_defining_shape = ctx.expr_exposed or is_mutation

    ir_set = setgen.ensure_set(ir_set, type_override=view_scls, ctx=ctx)
    # Maybe rematerialize the set. The old ir_set might have already
    # been materialized, but the new version would be missing from the
    # use_sets.
    if isinstance(ir_set.expr, irast.Pointer):
        ctx.env.schema, remat_ptrcls = typeutils.ptrcls_from_ptrref(
            ir_set.expr.ptrref, schema=ctx.env.schema
        )
        setgen.maybe_materialize(remat_ptrcls, ir_set, ctx=ctx)

    if view_rptr is not None and view_rptr.ptrcls is None:
        target_scls = stype if is_mutation else view_scls
        derive_ptrcls(view_rptr, target_scls=target_scls, ctx=ctx)

    pointers: dict[s_pointers.Pointer, EarlyShapePtr] = {}

    if elements is None:
        elements = []

    shape_desc: list[ShapeElementDesc] = []
    # First, find all explicit pointers (i.e. non-splat elements)
    for shape_el in elements:
        if isinstance(shape_el.expr.steps[0], qlast.Splat):
            continue

        shape_desc.append(
            _shape_el_ql_to_shape_el_desc(
                shape_el, source=view_scls, s_ctx=s_ctx, ctx=ctx
            )
        )

    explicit_ptr_names = {
        desc.ptr_name for desc in shape_desc if not desc.is_linkprop
    }

    explicit_lprop_names = {
        desc.ptr_name for desc in shape_desc if desc.is_linkprop
    }

    # Now look for any splats and expand them.
    # Track descriptions by name and whether they are link properties.
    splat_descs: dict[tuple[str, bool], ShapeElementDesc] = {}
    for shape_el in elements:
        if not isinstance(shape_el.expr.steps[0], qlast.Splat):
            continue

        if s_ctx.exprtype is not s_types.ExprType.Select:
            raise errors.QueryError(
                "unexpected splat operator in non-SELECT shape",
                span=shape_el.expr.span,
            )

        if ctx.env.options.func_params is not None:
            raise errors.UnsupportedFeatureError(
                "splat operators in function bodies are not supported",
                span=shape_el.expr.span,
            )

        splat = shape_el.expr.steps[0]
        if splat.type is not None:
            splat_type = typegen.ql_typeexpr_to_type(splat.type, ctx=ctx)
            if not isinstance(splat_type, s_objtypes.ObjectType):
                vn = splat_type.get_verbosename(schema=ctx.env.schema)
                raise errors.QueryError(
                    f"splat operator expects an object type, got {vn}",
                    span=splat.type.span,
                )

            if not stype.issubclass(ctx.env.schema, splat_type):
                vn = stype.get_verbosename(ctx.env.schema)
                vn2 = splat_type.get_verbosename(schema=ctx.env.schema)
                raise errors.QueryError(
                    f"splat type must be {vn} or its parent type, "
                    f"got {vn2}",
                    span=splat.type.span,
                )

            if splat.intersection is not None:
                intersector_type = typegen.ql_typeexpr_to_type(
                    splat.intersection.type, ctx=ctx)
                splat_type = schemactx.apply_intersection(
                    splat_type,
                    intersector_type,
                    ctx=ctx,
                ).stype
                assert isinstance(splat_type, s_objtypes.ObjectType)

        elif splat.intersection is not None:
            splat_type = typegen.ql_typeexpr_to_type(
                splat.intersection.type, ctx=ctx)
            if not isinstance(splat_type, s_objtypes.ObjectType):
                vn = splat_type.get_verbosename(schema=ctx.env.schema)
                raise errors.QueryError(
                    f"splat operator expects an object type, got {vn}",
                    span=splat.intersection.type.span,
                )
        else:
            splat_type = stype

        if (
            view_rptr is not None
            and isinstance(view_rptr.ptrcls, s_links.Link)
        ):
            splat_rlink = view_rptr.ptrcls
        else:
            splat_rlink = None

        expanded_splat = _expand_splat(
            splat_type,
            depth=splat.depth,
            intersection=splat.intersection,
            rlink=splat_rlink,
            skip_ptrs=explicit_ptr_names,
            skip_lprops=explicit_lprop_names,
            ctx=ctx,
        )

        for splat_el in expanded_splat:
            desc = _shape_el_ql_to_shape_el_desc(
                splat_el, source=view_scls, s_ctx=s_ctx, ctx=ctx
            )
            desc_key: tuple[str, bool] = (desc.ptr_name, desc.is_linkprop)
            if old_desc := splat_descs.get(desc_key):
                # If pointers appear in multiple splats, we take the
                # one from the ancestor class. If neither class is an
                # ancestor, we reject it.
                # TODO: Accept it instead, if the types are the same.
                new_source: object = desc.source
                old_source: object = old_desc.source
                if isinstance(new_source, s_links.Link):
                    new_source = new_source.get_source(ctx.env.schema)
                assert isinstance(new_source, s_objtypes.ObjectType)
                if isinstance(old_source, s_links.Link):
                    old_source = old_source.get_source(ctx.env.schema)
                assert isinstance(old_source, s_objtypes.ObjectType)
                new_source = schemactx.concretify(new_source, ctx=ctx)
                old_source = schemactx.concretify(old_source, ctx=ctx)

                if new_source.issubclass(ctx.env.schema, old_source):
                    # Do nothing.
                    pass
                elif old_source.issubclass(ctx.env.schema, new_source):
                    # Take the new one
                    splat_descs[desc_key] = desc
                else:
                    vn1 = old_source.get_verbosename(schema=ctx.env.schema)
                    vn2 = new_source.get_verbosename(schema=ctx.env.schema)
                    raise errors.QueryError(
                        f"link or property '{desc.ptr_name}' appears in splats "
                        f"for unrelated types: {vn1} and {vn2}",
                        span=splat.span,
                    )

            else:
                splat_descs[desc_key] = desc

    shape_desc.extend(splat_descs.values())

    for shape_el_desc in shape_desc:
        with ctx.new() as scopectx:
            # when doing insert or update with a compexpr, generate
            # the anchor for __default__
            if (
                (
                    # mutating statement, ptrcls guaranteed to exist
                    (s_ctx.exprtype.is_insert() or s_ctx.exprtype.is_update())
                    # linkprops are used in non-mutating statements as part of
                    # mutating statements, ptrcls not guaranteed to exist
                    or (
                        shape_el_desc.is_linkprop
                        # check that the linkprop actually exists in the source
                        and (
                            sn.UnqualName(shape_el_desc.ptr_name) in (
                                source_ptr.get_local_name(scopectx.env.schema)
                                for source_ptr in (
                                    shape_el_desc.source
                                    .get_pointers(scopectx.env.schema)
                                    .objects(scopectx.env.schema)
                                )
                            )
                        )
                    )
                )
                and shape_el_desc.ql.compexpr is not None
                and shape_el_desc.ptr_name not in (
                    ctx.special_computables_in_mutation_shape
                )
            ):
                ptrcls = setgen.resolve_ptr(
                    shape_el_desc.source,
                    shape_el_desc.ptr_name,
                    track_ref=shape_el_desc.ptr_ql,
                    ctx=scopectx
                )

                compexpr_uses_default = False
                compexpr_default_span: Optional[parsing.Span] = None
                for path_node in (
                    ast.find_children(
                        shape_el_desc.ql.compexpr,
                        qlast.Path,
                        extra_skip_types=(qlast.Query, qlast.Shape),
                    )
                    if not isinstance(
                        shape_el_desc.ql.compexpr,
                        (qlast.Query, qlast.Shape)
                    ) else
                    ()
                ):
                    for step in path_node.steps:
                        if not isinstance(step, qlast.SpecialAnchor):
                            continue
                        if step.name != '__default__':
                            continue

                        compexpr_uses_default = True
                        compexpr_default_span = step.span
                        break

                    if compexpr_uses_default:
                        break

                if compexpr_uses_default:
                    def make_error(
                        span: Optional[parsing.Span], hint: str
                    ) -> errors.InvalidReferenceError:
                        return errors.InvalidReferenceError(
                            '__default__ cannot be used in this expression',
                            span=span,
                            hint=hint,
                        )

                    default_expr: Optional[s_expr.Expression] = (
                        ptrcls.get_default(scopectx.env.schema)
                    )
                    if default_expr is None:
                        raise make_error(
                            compexpr_default_span,
                            'No default expression exists',
                        )

                    default_ast_expr = default_expr.parse()

                    if any(
                        any(
                            (
                                isinstance(step, qlast.SpecialAnchor)
                                and step.name == '__source__'
                            )
                            for step in path_node.steps
                        )
                        for path_node in ast.find_children(
                            default_ast_expr, qlast.Path
                        )
                    ):
                        raise make_error(
                            compexpr_default_span,
                            'Default expression uses __source__',
                        )

                    if astutils.contains_dml(default_ast_expr, ctx=ctx):
                        raise make_error(
                            compexpr_default_span,
                            'Default expression uses DML',
                        )

                    default_set = dispatch.compile(
                        default_ast_expr, ctx=scopectx
                    )

                    scopectx.anchors['__default__'] = default_set

            pointer, ptr_set = _normalize_view_ptr_expr(
                ir_set,
                shape_el_desc,
                view_scls,
                path_id=path_id,
                pending_pointers=pointers,
                s_ctx=s_ctx,
                ctx=scopectx,
            )

            pointers[pointer] = EarlyShapePtr(
                pointer, ptr_set, shape_el_desc.ql.origin, shape_el_desc.ql.span
            )

    # If we are not defining a shape (so we might care about
    # materialization), look through our parent view (if one exists)
    # for materialized properties that are not present in this shape.
    # If any are found, inject them.
    # (See test_edgeql_volatility_rebind_flat_01 for an example.)
    schema = ctx.env.schema
    base = view_scls.get_bases(schema).objects(schema)[0]
    base_ptrs = (view_scls.get_pointers(schema).objects(schema)
                 if not is_defining_shape else ())
    for ptrcls in base_ptrs:
        if ptrcls in pointers or base not in ctx.env.view_shapes:
            continue
        pptr = ptrcls.get_bases(schema).objects(schema)[0]
        if (pptr, qlast.ShapeOp.MATERIALIZE) not in ctx.env.view_shapes[base]:
            continue

        # Make up a dummy shape element
        name = ptrcls.get_shortname(schema).name
        dummy_el = qlast.ShapeElement(expr=qlast.Path(
            steps=[qlast.Ptr(name=name)]))
        dummy_el_desc = _shape_el_ql_to_shape_el_desc(
            dummy_el, source=view_scls, s_ctx=s_ctx, ctx=ctx
        )

        with ctx.new() as scopectx:
            pointer, ptr_set = _normalize_view_ptr_expr(
                ir_set,
                dummy_el_desc,
                view_scls,
                path_id=path_id,
                s_ctx=s_ctx,
                ctx=scopectx,
            )

        pointers[pointer] = EarlyShapePtr(
            pointer, ptr_set, qlast.ShapeOrigin.MATERIALIZATION, None
        )

    specified_ptrs = {
        ptrcls.get_local_name(ctx.env.schema) for ptrcls in pointers
    }

    # defaults
    if s_ctx.exprtype.is_insert():
        defaults_ptrs = _gen_pointers_from_defaults(
            specified_ptrs, view_scls, ir_set, stype, s_ctx, ctx
        )
        pointers.update(defaults_ptrs)

    # rewrites
    rewrite_kind = (
        qltypes.RewriteKind.Insert
        if s_ctx.exprtype.is_insert()
        else qltypes.RewriteKind.Update
        if s_ctx.exprtype.is_update()
        else None
    )

    if rewrite_kind:
        rewrites = _compile_rewrites(
            specified_ptrs, rewrite_kind, view_scls, ir_set, stype, s_ctx, ctx
        )
        if rewrites:
            ctx.env.dml_rewrites[ir_set] = rewrites
    else:
        rewrites = None

    if s_ctx.exprtype.is_insert():
        _raise_on_missing(pointers, stype, rewrites, ctx, span=span)

    set_shape = []
    shape_ptrs: list[ShapePtr] = []

    for ptrcls, ptr_set, _, span in pointers.values():
        source: s_types.Type | s_pointers.PointerLike

        if ptrcls.is_link_property(ctx.env.schema):
            assert view_rptr is not None and view_rptr.ptrcls is not None
            source = view_rptr.ptrcls
        else:
            source = view_scls

        if is_defining_shape:
            cinfo = ctx.env.source_map.get(ptrcls)
            if cinfo is not None:
                shape_op = cinfo.shape_op
            else:
                shape_op = qlast.ShapeOp.ASSIGN
        elif ptrcls.get_computable(ctx.env.schema):
            shape_op = qlast.ShapeOp.MATERIALIZE
        else:
            continue

        ctx.env.view_shapes[source].append((ptrcls, shape_op))
        shape_ptrs.append(ShapePtr(ir_set, ptrcls, shape_op, ptr_set, span))

    rptrcls = view_rptr.ptrcls if view_rptr else None
    shape_ptrs = _get_early_shape_configuration(
        ir_set, shape_ptrs, rptrcls=rptrcls, ctx=ctx)

    # Produce the shape. The main thing here is that we need to fix up
    # all of the rptrs to properly point back at ir_set.
    for _, ptrcls, shape_op, ptr_set, ptr_span in shape_ptrs:
        if ptrcls in ctx.env.pointer_specified_info:
            _, _, ptr_span = ctx.env.pointer_specified_info[ptrcls]

        if ptr_set:
            src_path_id = path_id
            is_linkprop = ptrcls.is_link_property(ctx.env.schema)
            if is_linkprop:
                src_path_id = src_path_id.ptr_path()

            ptr_set.path_id = pathctx.extend_path_id(
                src_path_id,
                ptrcls=ptrcls,
                ns=ctx.path_id_namespace,
                ctx=ctx,
            )
            assert not isinstance(ptr_set.expr, irast.Pointer)
            ptr_set.expr = irast.Pointer(
                source=ir_set,
                expr=ptr_set.expr,
                direction=s_pointers.PointerDirection.Outbound,
                ptrref=not_none(ptr_set.path_id.rptr()),
                is_definition=True,

                is_mutation=(
                    is_mutation
                    or (
                        is_linkprop
                        and s_ctx.view_rptr is not None
                        and s_ctx.view_rptr.exprtype.is_mutation()
                    )
                ),
            )
            # XXX: We would maybe like to *not* do this when it
            # already has a context, since for explain output that
            # seems nicer, but this is what we want for producing
            # actual error messages.
            ptr_set.span = ptr_span

        else:
            # The set must be something pretty trivial, so just do it
            ptr_set = setgen.extend_path(
                ir_set,
                ptrcls,
                same_computable_scope=True,
                span=ptr_span,
                ctx=ctx,
            )

        assert irutils.is_set_instance(ptr_set, irast.Pointer)
        set_shape.append((ptr_set, shape_op))

    ir_set.shape = tuple(set_shape)

    if (view_rptr is not None and view_rptr.ptrcls is not None and
            view_scls != stype):
        ctx.env.schema = view_scls.set_field_value(
            ctx.env.schema, 'rptr', view_rptr.ptrcls)

    return view_scls, ir_set


def _shape_el_ql_to_shape_el_desc(
    shape_el: qlast.ShapeElement,
    *,
    source: s_sources.Source,
    s_ctx: ShapeContext,
    ctx: context.ContextLevel,
) -> ShapeElementDesc:
    """Look at ShapeElement AST and annotate it for more convenient handling."""

    steps = shape_el.expr.steps
    is_linkprop = False
    is_polymorphic = False
    plen = len(steps)
    target_typexpr = None
    source_intersection = []

    if plen >= 2 and isinstance(steps[-1], qlast.TypeIntersection):
        # Target type intersection: foo: Type
        target_typexpr = steps[-1].type
        plen -= 1
        steps = steps[:-1]

    if plen == 1:
        # regular shape
        lexpr = steps[0]
        assert isinstance(lexpr, qlast.Ptr)
        is_linkprop = lexpr.type == 'property'
        if is_linkprop:
            view_rptr = s_ctx.view_rptr
            if view_rptr is None or view_rptr.ptrcls is None:
                raise errors.QueryError(
                    'invalid reference to link property '
                    'in top level shape', span=lexpr.span)
            assert isinstance(view_rptr.ptrcls, s_links.Link)
            source = view_rptr.ptrcls
    elif plen == 2 and isinstance(steps[0], qlast.TypeIntersection):
        # Source type intersection: [IS Type].foo
        source_intersection = [steps[0]]
        lexpr = steps[1]
        ptype = steps[0].type
        source_spec = typegen.ql_typeexpr_to_type(ptype, ctx=ctx)
        if not isinstance(source_spec, s_objtypes.ObjectType):
            raise errors.QueryError(
                f"expected object type, got "
                f"{source_spec.get_verbosename(ctx.env.schema)}",
                span=ptype.span,
            )
        source = source_spec
        is_polymorphic = True
    else:  # pragma: no cover
        raise RuntimeError(
            f'unexpected path length in view shape: {len(steps)}')

    assert isinstance(lexpr, qlast.Ptr)
    ptrname = lexpr.name

    if target_typexpr is None:
        path_ql = qlast.Path(
            steps=[
                *source_intersection,
                lexpr,
            ],
            partial=True,
            span=shape_el.span,
        )
    else:
        path_ql = qlast.Path(
            steps=[
                *source_intersection,
                lexpr,
                qlast.TypeIntersection(type=target_typexpr),
            ],
            partial=True,
            span=shape_el.span,
        )

    return ShapeElementDesc(
        ql=shape_el,
        path_ql=path_ql,
        ptr_ql=lexpr,
        ptr_name=ptrname,
        source=source,
        target_typexpr=target_typexpr,
        is_polymorphic=is_polymorphic,
        is_linkprop=is_linkprop,
    )


def _expand_splat(
    stype: s_objtypes.ObjectType,
    *,
    depth: int,
    skip_ptrs: AbstractSet[str] = frozenset(),
    skip_lprops: AbstractSet[str] = frozenset(),
    rlink: Optional[s_links.Link] = None,
    intersection: Optional[qlast.TypeIntersection] = None,
    ctx: context.ContextLevel,
) -> list[qlast.ShapeElement]:
    """Expand a splat (possibly recursively) into a list of ShapeElements."""
    elements = []
    pointers = stype.get_pointers(ctx.env.schema)
    path: list[qlast.PathElement] = []
    if intersection is not None:
        path.append(intersection)
    for ptr in pointers.objects(ctx.env.schema):
        if ptr.get_secret(ctx.env.schema):
            continue
        splat_strat = ptr.get_splat_strategy(ctx.env.schema)
        if (
            splat_strat == qltypes.SplatStrategy.Explicit
            or (
                splat_strat == qltypes.SplatStrategy.Default
                and (
                    ptr.get_linkful(ctx.env.schema)
                    if s_futures.future_enabled(
                        ctx.env.schema, 'no_linkful_computed_splats'
                    )
                    else isinstance(ptr, s_links.Link)
                )
            )
        ):
            continue
        sname = ptr.get_shortname(ctx.env.schema)
        # Skip any dunder properties; these are injected properties like
        # __tid__ and __tname__, and we want to manage injecting them
        # ourselves, in the correct positions.
        if (
            (sname.name.startswith('__') and sname.name.endswith('__'))
            or sname.name in skip_ptrs
        ):
            continue
        step = qlast.Ptr(name=sname.name)
        # Make sure not to overwrite the id property.
        if not ptr.is_id_pointer(ctx.env.schema):
            steps = path + [step]
        else:
            steps = [step]
        elements.append(qlast.ShapeElement(
            expr=qlast.Path(steps=steps),
            origin=qlast.ShapeOrigin.SPLAT_EXPANSION,
        ))

    if rlink is not None:
        for prop in rlink.get_pointers(ctx.env.schema).objects(ctx.env.schema):
            if prop.is_endpoint_pointer(ctx.env.schema):
                continue
            assert isinstance(prop, s_props.Property), \
                "non-property pointer on link?"
            sname = prop.get_shortname(ctx.env.schema)
            if sname.name in skip_lprops:
                continue
            elements.append(
                qlast.ShapeElement(
                    expr=qlast.Path(
                        steps=[qlast.Ptr(
                            name=sname.name,
                            type='property',
                        )]
                    ),
                    origin=qlast.ShapeOrigin.SPLAT_EXPANSION,
                )
            )

    if depth > 1:
        for ptr in pointers.objects(ctx.env.schema):
            if not isinstance(ptr, s_links.Link):
                continue
            if ptr.get_secret(ctx.env.schema):
                continue
            splat_strat = ptr.get_splat_strategy(ctx.env.schema)
            if splat_strat == qltypes.SplatStrategy.Explicit:
                continue
            pn = ptr.get_shortname(ctx.env.schema)
            if (
                (pn.name.startswith('__') and pn.name.endswith('__'))
                or pn.name in skip_ptrs
            ):
                continue
            elements.append(
                qlast.ShapeElement(
                    expr=qlast.Path(steps=path + [qlast.Ptr(name=pn.name)]),
                    elements=_expand_splat(
                        ptr.get_target(ctx.env.schema),
                        rlink=ptr,
                        depth=depth - 1,
                        ctx=ctx,
                    ),
                    origin=qlast.ShapeOrigin.SPLAT_EXPANSION,
                )
            )

    return elements


def _gen_pointers_from_defaults(
    specified_ptrs: set[sn.UnqualName],
    view_scls: s_objtypes.ObjectType,
    ir_set: irast.Set,
    stype: s_objtypes.ObjectType,
    s_ctx: ShapeContext,
    ctx: context.ContextLevel,
) -> dict[s_pointers.Pointer, EarlyShapePtr]:
    path_id = ir_set.path_id
    result: list[EarlyShapePtr] = []

    if stype in ctx.active_defaults:
        vn = stype.get_verbosename(ctx.env.schema)
        raise errors.QueryError(
            f"default on property of {vn} is part of a default cycle",
        )

    scls_pointers = stype.get_pointers(ctx.env.schema)
    for pn, ptrcls in scls_pointers.items(ctx.env.schema):
        if (
            (pn in specified_ptrs or ptrcls.is_pure_computable(ctx.env.schema))
            and not ptrcls.get_protected(ctx.env.schema)
        ):
            continue

        default_expr: Optional[s_expr.Expression] = (
            ptrcls.get_default(ctx.env.schema)
        )
        if not default_expr:
            continue

        ptrcls_sn = ptrcls.get_shortname(ctx.env.schema)
        default_ql = qlast.ShapeElement(
            expr=qlast.Path(
                steps=[qlast.Ptr(name=ptrcls_sn.name)],
            ),
            compexpr=qlast.DetachedExpr(
                expr=default_expr.parse(),
                preserve_path_prefix=True,
            ),
            origin=qlast.ShapeOrigin.DEFAULT,
        )
        default_ql_desc = _shape_el_ql_to_shape_el_desc(
            default_ql, source=view_scls, s_ctx=s_ctx, ctx=ctx
        )

        with ctx.new() as scopectx:
            scopectx.active_defaults |= {stype}

            # add __source__ to anchors
            source_set = ir_set
            scopectx.path_scope.attach_path(
                source_set.path_id, span=None,
                optional=False,
                ctx=ctx,
            )
            scopectx.iterator_path_ids |= {source_set.path_id}
            scopectx.anchors['__source__'] = source_set

            pointer, ptr_set = _normalize_view_ptr_expr(
                ir_set,
                default_ql_desc,
                view_scls,
                path_id=path_id,
                from_default=True,
                s_ctx=s_ctx,
                ctx=scopectx,
            )

            result.append(EarlyShapePtr(
                pointer, ptr_set, qlast.ShapeOrigin.DEFAULT, None
            ))

    schema = ctx.env.schema

    # Toposort defaults
    # This is required because defaults may reference each other
    # (and even contain cyclical dependencies).
    # We cannot check or preprocess this at migration time, because some
    # defaults may not be used for some inserts.
    pointer_indexes = {}
    for (index, (pointer, _, _, _)) in enumerate(result):
        p = pointer.get_nearest_non_derived_parent(schema)
        pointer_indexes[p.get_name(schema).name] = index
    graph = {}
    for (index, (_, irset, _, _)) in enumerate(result):
        assert irset
        dep_pointers = ast.find_children(irset, irast.Pointer)
        dep_rptrs = (
            # pointer.target_path_id.rptr() for pointer in dep_pointers
            pointer.ptrref for pointer in dep_pointers
            if pointer.source.typeref.id == stype.id
        )
        deps = {
            pointer_indexes[rpts.name.name] for rpts in dep_rptrs
            if rpts and rpts.name.name in pointer_indexes
        }
        graph[index] = topological.DepGraphEntry(
            item=index, deps=deps, extra=False,
        )

    ordered = [
        result[i] for i in topological.sort(graph, allow_unresolved=True)
    ]

    return {v.ptrcls: v for v in ordered}


def _raise_on_missing(
    pointers: dict[s_pointers.Pointer, EarlyShapePtr],
    stype: s_objtypes.ObjectType,
    rewrites: Optional[irast.Rewrites],
    ctx: context.ContextLevel,
    span: Optional[parsing.Span],
) -> None:
    pointer_names = {
        ptr.get_local_name(ctx.env.schema) for ptr in pointers
    }

    scls_pointers = stype.get_pointers(ctx.env.schema)
    for pn, ptrcls in scls_pointers.items(ctx.env.schema):
        if pn == sn.UnqualName("__type__"):
            continue

        if pn in pointer_names or ptrcls.is_pure_computable(ctx.env.schema):
            continue

        if not ptrcls.get_required(ctx.env.schema):
            continue

        # is it rewritten?
        if rewrites:
            # (inserts must produce rewrites only for stype)
            assert len(rewrites.by_type) == 1
            if pn.name in next(iter(rewrites.by_type.values())):
                continue

        if ptrcls.is_property():
            # If the target is a sequence, there's no need
            # for an explicit value.
            ptrcls_target = ptrcls.get_target(ctx.env.schema)
            assert ptrcls_target is not None
            if ptrcls_target.issubclass(
                ctx.env.schema,
                ctx.env.schema.get(
                    "std::sequence", type=s_objects.SubclassableObject
                ),
            ):
                continue

        vn = ptrcls.get_verbosename(ctx.env.schema, with_parent=True)
        msg = f"missing value for required {vn}"
        # If this is happening in the context of DDL, report a
        # QueryError because it is weird to report an ExecutionError
        # (MissingRequiredError) when nothing is really executing.
        if ctx.env.options.schema_object_context:
            raise errors.SchemaDefinitionError(msg, span=span)
        else:
            raise errors.MissingRequiredError(msg, span=span)


@dataclasses.dataclass(kw_only=True, repr=False, eq=False)
class RewriteContext:
    specified_ptrs: set[sn.UnqualName]
    kind: qltypes.RewriteKind

    base_type: s_objtypes.ObjectType
    shape_type: s_objtypes.ObjectType


def _compile_rewrites(
    specified_ptrs: set[sn.UnqualName],
    kind: qltypes.RewriteKind,
    view_scls: s_objtypes.ObjectType,
    ir_set: irast.Set,
    stype: s_objtypes.ObjectType,
    s_ctx: ShapeContext,
    ctx: context.ContextLevel,
) -> Optional[irast.Rewrites]:
    # init
    r_ctx = RewriteContext(
        specified_ptrs=specified_ptrs,
        kind=kind,
        base_type=stype,
        shape_type=view_scls,
    )

    # Computing anchors isn't cheap, so we want to only do it once,
    # and only do it when it is necessary.
    anchors: dict[s_objtypes.ObjectType, RewriteAnchors] = {}

    def get_anchors(stype: s_objtypes.ObjectType) -> RewriteAnchors:
        if stype not in anchors:
            anchors[stype] = prepare_rewrite_anchors(
                stype, ir_set.path_id, r_ctx, s_ctx, ctx)
        return anchors[stype]

    rewrites = _compile_rewrites_for_stype(
        stype, kind, ir_set, get_anchors, s_ctx, ctx=ctx
    )

    if kind == qltypes.RewriteKind.Insert:
        type_ref = typegen.type_to_typeref(stype, ctx.env)
        rewrites_by_type = {type_ref: rewrites}

    elif kind == qltypes.RewriteKind.Update:
        # Update may also change objects that are children of stype.
        # Here we build a dict of rewrites for each descendant type, for
        # each of its pointers.

        # This dict is stored in the context and pulled into the update
        # statement later.

        rewrites_by_type = _compile_rewrites_of_children(
            stype, rewrites, kind, ir_set, get_anchors, s_ctx, ctx
        )

    else:
        raise NotImplementedError()

    schema = ctx.env.schema
    by_type: dict[irast.TypeRef, irast.RewritesOfType] = {}
    for ty, rewrites_of_type in rewrites_by_type.items():
        ty = ty.real_material_type

        by_type[ty] = {}
        for element in rewrites_of_type.values():
            target = element.target_set
            assert target

            ptrref = typegen.ptr_to_ptrref(element.ptrcls, ctx=ctx)
            actual_ptrref = irtypeutils.find_actual_ptrref(ty, ptrref)
            pn = actual_ptrref.shortname.name
            path_id = irast.PathId.from_pointer(
                schema, element.ptrcls, env=ctx.env
            )

            # construct a new set with correct path_id
            ptr_set = setgen.new_set_from_set(
                target,
                path_id=path_id,
                ctx=ctx,
            )

            # wrap the set's expression in a Pointer rooted at ir_set
            ptr_set.expr = irast.Pointer(
                source=ir_set,
                expr=ptr_set.expr,
                direction=s_pointers.PointerDirection.Outbound,
                ptrref=actual_ptrref,
                is_definition=True,
            )
            assert irutils.is_set_instance(ptr_set, irast.Pointer)

            by_type[ty][pn] = (ptr_set, ptrref.real_material_ptr)

    anc = next(iter(anchors.values()), None)
    if not anc:
        return None

    return irast.Rewrites(
        old_path_id=anc.old_set.path_id if anc.old_set else None,
        by_type=by_type,
    )


def _compile_rewrites_of_children(
    stype: s_objtypes.ObjectType,
    parent_rewrites: dict[sn.UnqualName, EarlyShapePtr],
    kind: qltypes.RewriteKind,
    ir_set: irast.Set,
    get_anchors: Callable[[s_objtypes.ObjectType], RewriteAnchors],
    s_ctx: ShapeContext,
    ctx: context.ContextLevel,
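) -> dict[irast.TypeRef, dict[sn.UnqualName, EarlyShapePtr]]:
    """Collect rewrites for stype and, recursively, for its descendants.

    Each child starts from a copy of its parent's rewrites and overrides
    them with any rewrites defined on the child itself. Children marked
    as derived are skipped.
    """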
    rewrites_for_type: dict[
        irast.TypeRef, dict[sn.UnqualName, EarlyShapePtr]
    ] = {}

    # save parent to result
    type_ref = typegen.type_to_typeref(stype, ctx.env)
    rewrites_for_type[type_ref] = parent_rewrites.copy()

    for child in stype.children(ctx.env.schema):
        if child.get_is_derived(ctx.env.schema):
            continue

        # base on parent rewrites
        child_rewrites = parent_rewrites.copy()
        # override with rewrites defined here
        rewrites_defined_here = _compile_rewrites_for_stype(
            child, kind, ir_set, get_anchors, s_ctx,
            already_defined_rewrites=child_rewrites,
            ctx=ctx
        )
        child_rewrites.update(rewrites_defined_here)

        # recurse for children
        rewrites_for_type.update(
            _compile_rewrites_of_children(
                child,
                child_rewrites,
                kind,
                ir_set,
                get_anchors,
                s_ctx,
                ctx=ctx,
            )
        )

    return rewrites_for_type


def _compile_rewrites_for_stype(
    stype: s_objtypes.ObjectType,
    kind: qltypes.RewriteKind,
    ir_set: irast.Set,
    get_anchors: Callable[[s_objtypes.ObjectType], RewriteAnchors],
    s_ctx: ShapeContext,
    *,
    already_defined_rewrites: Optional[
        Mapping[sn.UnqualName, EarlyShapePtr]] = None,
    ctx: context.ContextLevel,
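) -> dict[sn.UnqualName, EarlyShapePtr]:
    """Compile the rewrites of the given kind for the pointers of stype.

    Raises a QueryError if stype is already in ctx.active_rewrites,
    i.e. the rewrite rule is part of a cycle.
    """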
    schema = ctx.env.schema

    path_id = ir_set.path_id

    res = {}

    if stype in ctx.active_rewrites:
        vn = stype.get_verbosename(ctx.env.schema)
        raise errors.QueryError(
            f"rewrite rule on {vn} is part of a rewrite rule cycle",
        )

    scls_pointers = stype.get_pointers(schema)
    for pn, ptrcls in scls_pointers.items(schema):
        if ptrcls.is_pure_computable(schema):
            continue

        rewrite = ptrcls.get_rewrite(schema, kind)
        if not rewrite:
            continue
        rewrite_pointer = downcast(
            s_pointers.Pointer, rewrite.get_subject(schema))

        # Because rewrites are not duplicated on inherited properties, the
        # subject of this rewrite may not be on stype, but on one of its
        # ancestors. The mitigation is to pick the correct pointer from stype.
        rewrite_pointer = downcast(
            s_pointers.Pointer, stype.get_pointers(schema).get(schema, pn)
        )

        # get_rewrite searches in ancestors for rewrites, but if the rewrite
        # for that ancestor has already been compiled, skip it to avoid
        # duplicating work
        if (
            already_defined_rewrites
            and (existing := already_defined_rewrites.get(pn))
            and (existing[0].get_nearest_non_derived_parent(schema)
                 == rewrite_pointer)
        ):
            continue

        anchors = get_anchors(stype)

        rewrite_expr: Optional[s_expr.Expression] = (
            rewrite.get_expr(ctx.env.schema)
        )
        assert rewrite_expr

        with ctx.newscope(fenced=True) as scopectx:
            scopectx.active_rewrites |= {stype}

            # prepare context
            scopectx.partial_path_prefix = anchors.subject_set
            nanchors = {}
            nanchors["__specified__"] = anchors.specified_set
            nanchors["__subject__"] = anchors.subject_set
            if anchors.old_set:
                nanchors["__old__"] = anchors.old_set

            for key, anchor in nanchors.items():
                scopectx.path_scope.attach_path(
                    anchor.path_id,
                    optional=(anchor is anchors.subject_set),
                    span=None,
                    ctx=ctx,
                )
                scopectx.iterator_path_ids |= {anchor.path_id}
                scopectx.anchors[key] = anchor

            ctx.path_scope.factoring_allowlist.add(anchors.subject_set.path_id)

            # prepare expression
            ptrcls_sn = ptrcls.get_shortname(ctx.env.schema)
            shape_ql = qlast.ShapeElement(
                expr=qlast.Path(
                    steps=[qlast.Ptr(name=ptrcls_sn.name)],
                ),
                compexpr=qlast.DetachedExpr(
                    expr=rewrite_expr.parse(),
                    preserve_path_prefix=True,
                ),
            )
            shape_ql_desc = _shape_el_ql_to_shape_el_desc(
                shape_ql,
                source=anchors.rewrite_type,
                s_ctx=s_ctx,
                ctx=scopectx,
            )

            # compile as normal shape element
            pointer, ptr_set = _normalize_view_ptr_expr(
                anchors.subject_set,
                shape_ql_desc,
                anchors.rewrite_type,
                path_id=path_id,
                from_default=True,
                s_ctx=s_ctx,
                ctx=scopectx,
            )
            res[pn] = EarlyShapePtr(
                pointer, ptr_set, qlast.ShapeOrigin.DEFAULT, None
            )
    return res


@dataclasses.dataclass(kw_only=True, repr=False, eq=False)
class RewriteAnchors:
    subject_set: irast.Set
    specified_set: irast.Set
    old_set: Optional[irast.Set]

    rewrite_type: s_objtypes.ObjectType


def prepare_rewrite_anchors(
    stype: s_objtypes.ObjectType,
    subject_path_id: irast.PathId,
    r_ctx: RewriteContext,
    s_ctx: ShapeContext,
    ctx: context.ContextLevel,
) -> RewriteAnchors:
    schema = ctx.env.schema

    # init set for __subject__
    subject_set = setgen.class_set(
        stype, path_id=subject_path_id, ctx=ctx
    )

    # init reference to std::bool
    bool_type = schema.get("std::bool", type=s_types.Type)
    bool_path = irast.PathId.from_type(
        schema,
        bool_type,
        typename=sn.QualName(module="std", name="bool"),
        env=ctx.env,
    )

    # init set for __specified__
    specified_pointers: list[irast.TupleElement] = []
    for pn, _ in stype.get_pointers(schema).items(schema):
        pointer_path_id = irast.PathId.from_type(
            schema,
            bool_type,
            typename=sn.QualName(
                module="__derived__", name=ctx.aliases.get(pn.name)
            ),
            namespace=ctx.path_id_namespace,
            env=ctx.env,
        )

        specified_pointers.append(
            irast.TupleElement(
                name=pn.name,
                val=setgen.ensure_set(
                    irast.BooleanConstant(
                        value=str(pn in r_ctx.specified_ptrs),
                        typeref=bool_path.target,
                    ),
                    ctx=ctx
                ),
                path_id=pointer_path_id
            )
        )
    specified_set = setgen.new_tuple_set(
        specified_pointers, named=True, ctx=ctx
    )

    # init set for __old__
    if r_ctx.kind == qltypes.RewriteKind.Update:
        old_name = sn.QualName("__derived__", ctx.aliases.get("__old__"))
        old_path_id = irast.PathId.from_type(
            schema, stype, typename=old_name,
            namespace=ctx.path_id_namespace, env=ctx.env,
        )
        old_set = setgen.new_set(
            stype=stype, path_id=old_path_id, ctx=ctx,
            expr=irast.TriggerAnchor(
                typeref=typegen.type_to_typeref(stype, env=ctx.env)),
        )
    else:
        old_set = None

    rewrite_type = r_ctx.shape_type
    if stype != r_ctx.shape_type.get_nearest_non_derived_parent(schema):
        rewrite_type = downcast(
            s_objtypes.ObjectType,
            schemactx.derive_view(
                stype,
                exprtype=s_ctx.exprtype,
                ctx=ctx,
            )
        )
        subject_set = setgen.class_set(
            rewrite_type, path_id=subject_set.path_id, ctx=ctx)
        if old_set:
            old_set = setgen.class_set(
                rewrite_type, path_id=old_set.path_id, ctx=ctx)

    return RewriteAnchors(
        subject_set=subject_set,
        specified_set=specified_set,
        old_set=old_set,
        rewrite_type=rewrite_type,
    )


def _compile_qlexpr(
    ir_source: irast.Set,
    qlexpr: qlast.Base,
    view_scls: s_objtypes.ObjectType,
    *,
    ptrcls: Optional[s_pointers.Pointer],
    ptrsource: s_sources.Source,
    ptr_name: sn.QualName,
    is_linkprop: bool,
    should_set_partial_prefix: bool,
    s_ctx: ShapeContext,
    ctx: context.ContextLevel,
) -> tuple[irast.Set, context.ViewRPtr]:

    with ctx.newscope(fenced=True) as shape_expr_ctx:
        # Put current pointer class in context, so
        # that references to link properties in sub-SELECT
        # can be resolved.  This is necessary for proper
        # evaluation of link properties on computable links,
        # most importantly, in INSERT/UPDATE context.
        shape_expr_ctx.view_rptr = context.ViewRPtr(
            source=ptrsource if is_linkprop else view_scls,
            ptrcls=ptrcls,
            ptrcls_name=ptr_name,
            ptrcls_is_linkprop=is_linkprop,
            exprtype=s_ctx.exprtype,
        )

        shape_expr_ctx.defining_view = view_scls
        shape_expr_ctx.path_scope.unnest_fence = True
        source_set = setgen.fixup_computable_source_set(
            ir_source, ctx=shape_expr_ctx
        )

        if should_set_partial_prefix:
            shape_expr_ctx.partial_path_prefix = source_set

        if ptrcls is not None:
            if s_ctx.exprtype.is_mutation():
                shape_expr_ctx.expr_exposed = context.Exposure.EXPOSED

            shape_expr_ctx.empty_result_type_hint = \
                ptrcls.get_target(ctx.env.schema)

        shape_expr_ctx.view_map = ctx.view_map.new_child()
        setgen.update_view_map(
            source_set.path_id, source_set, ctx=shape_expr_ctx)

        irexpr = dispatch.compile(qlexpr, ctx=shape_expr_ctx)

    if ctx.expr_exposed:
        irexpr = eta_expand.eta_expand_ir(irexpr, ctx=ctx)

    return irexpr, shape_expr_ctx.view_rptr


def _normalize_view_ptr_expr(
    ir_source: irast.Set,
    shape_el_desc: ShapeElementDesc,
    view_scls: s_objtypes.ObjectType,
    *,
    path_id: irast.PathId,
    from_default: bool = False,
    pending_pointers: Mapping[s_pointers.Pointer, EarlyShapePtr] | None = None,
    s_ctx: ShapeContext,
    ctx: context.ContextLevel,
) -> tuple[s_pointers.Pointer, Optional[irast.Set]]:
    is_mutation = s_ctx.exprtype.is_insert() or s_ctx.exprtype.is_update()

    materialized = None
    qlexpr: Optional[qlast.Expr] = None
    base_ptrcls_is_alias = False
    irexpr = None

    shape_el = shape_el_desc.ql
    ptrsource = shape_el_desc.source
    ptrname = shape_el_desc.ptr_name
    is_linkprop = shape_el_desc.is_linkprop
    is_polymorphic = shape_el_desc.is_polymorphic
    target_typexpr = shape_el_desc.target_typexpr

    is_independent_polymorphic = False

    compexpr: Optional[qlast.Expr] = shape_el.compexpr
    if compexpr is None and is_mutation:
        raise errors.QueryError(
            "mutation queries must specify values with ':='",
            span=shape_el.expr.steps[-1].span,
        )

    ptrcls: Optional[s_pointers.Pointer]

    if compexpr is None:
        ptrcls = setgen.resolve_ptr(
            ptrsource,
            ptrname,
            track_ref=shape_el_desc.ptr_ql,
            ctx=ctx,
            span=shape_el.span,
        )
        real_ptrcls = None
        if is_polymorphic:
            # For polymorphic pointers, we need to see if the *real*
            # base class has the pointer, because if so we need to use
            # that when doing cardinality inference (since it may need
            # to raise an error, if it is required). If it isn't
            # present on the real type, take note of that so that we
            # suppress the inherited cardinality.
            try:
                real_ptrcls = setgen.resolve_ptr(
                    view_scls,
                    ptrname,
                    track_ref=shape_el_desc.ptr_ql,
                    ctx=ctx,
                    span=shape_el.span,
                )
            except errors.InvalidReferenceError:
                is_independent_polymorphic = True
            ptrcls = schemactx.derive_ptr(ptrcls, view_scls, ctx=ctx)
        real_ptrcls = real_ptrcls or ptrcls

        base_ptrcls = real_ptrcls.get_bases(
            ctx.env.schema).first(ctx.env.schema)
        base_ptr_is_computable = base_ptrcls in ctx.env.source_map
        ptr_name = sn.QualName(
            module='__',
            name=ptrcls.get_shortname(ctx.env.schema).name,
        )

        # Schema computables that point to opaque unions will just have
        # BaseObject as their target, but in order to properly compile
        # it, we need to know the actual type here, so we recompute it.
        # XXX: This is a hack, though, and hopefully we can fix it once
        # the computable/alias rework lands.
        is_opaque_schema_computable = (
            ptrcls.is_pure_computable(ctx.env.schema)
            and (t := ptrcls.get_target(ctx.env.schema))
            and t.get_name(ctx.env.schema) == sn.QualName('std', 'BaseObject')
        )

        base_required = base_ptrcls.get_required(ctx.env.schema)
        base_cardinality = _get_base_ptr_cardinality(base_ptrcls, ctx=ctx)
        base_is_singleton = False
        if base_cardinality is not None and base_cardinality.is_known():
            base_is_singleton = base_cardinality.is_single()

        is_nontrivial = astutils.is_nontrivial_shape_element(shape_el)
        is_obj = not_none(ptrcls.get_target(ctx.env.schema)).is_object_type()

        if (
            is_obj
            or is_nontrivial
            or shape_el.elements
            or base_ptr_is_computable
            or is_polymorphic
            or target_typexpr is not None
            or (ctx.implicit_limit and not base_is_singleton)
            or is_opaque_schema_computable
        ):
            qlexpr = shape_el_desc.path_ql
            if shape_el.elements:
                qlexpr = qlast.Shape(expr=qlexpr, elements=shape_el.elements)

            qlexpr = astutils.ensure_ql_query(qlexpr)
            assert isinstance(qlexpr, qlast.SelectQuery)
            qlexpr.where = shape_el.where
            qlexpr.orderby = shape_el.orderby

            if shape_el.offset or shape_el.limit:
                qlexpr = qlast.SelectQuery(result=qlexpr, implicit=True)
                qlexpr.offset = shape_el.offset
                qlexpr.limit = shape_el.limit

            if (
                ctx.expr_exposed
                and ctx.implicit_limit
                and not base_is_singleton
            ):
                qlexpr = qlast.SelectQuery(result=qlexpr, implicit=True)
                qlexpr.limit = qlast.Constant.integer(ctx.implicit_limit)

        if target_typexpr is not None:
            assert isinstance(target_typexpr, qlast.TypeName)
            intersector_type = schemactx.get_schema_type(
                target_typexpr.maintype, ctx=ctx)

            int_result = schemactx.apply_intersection(
                ptrcls.get_target(ctx.env.schema),  # type: ignore
                intersector_type,
                ctx=ctx,
            )

            ptr_target = int_result.stype
        else:
            _ptr_target = ptrcls.get_target(ctx.env.schema)
            assert _ptr_target
            ptr_target = _ptr_target

        ptr_required = base_required
        ptr_cardinality = base_cardinality
        if shape_el.where or is_polymorphic:
            # If the shape has a filter on it, we need to force a reinference
            # of the cardinality, to produce an error if needed.
            ptr_cardinality = None
        if ptr_cardinality is None or not ptr_cardinality.is_known():
            # We do not know the parent's pointer cardinality yet.
            ctx.env.pointer_derivation_map[base_ptrcls].append(ptrcls)
            ctx.env.pointer_specified_info[ptrcls] = (
                shape_el.cardinality, shape_el.required, shape_el.span)

        # If we generated qlexpr for the element, we process the
        # subview by just compiling the qlexpr. This is so that we can
        # figure out if it needs materialization and also so that
        # `qlexpr is not None` always implies that we did the
        # compilation.
        if qlexpr:
            irexpr, _ = _compile_qlexpr(
                ir_source,
                qlexpr,
                view_scls,
                ptrcls=ptrcls,
                ptrsource=ptrsource,
                ptr_name=ptr_name,
                is_linkprop=is_linkprop,
                should_set_partial_prefix=True,
                s_ctx=s_ctx,
                ctx=ctx,
            )
            materialized = setgen.should_materialize(
                irexpr, ptrcls=ptrcls,
                materialize_visible=True, skipped_bindings={path_id},
                ctx=ctx)
            ptr_target = setgen.get_set_type(irexpr, ctx=ctx)

    # compexpr is not None
    else:
        base_ptrcls = ptrcls = None

        if (is_mutation
                and ptrname not in ctx.special_computables_in_mutation_shape):
            # If this is a mutation, the pointer must exist.
            ptrcls = setgen.resolve_ptr(
                ptrsource, ptrname, track_ref=shape_el_desc.ptr_ql, ctx=ctx)
            if ptrcls.is_pure_computable(ctx.env.schema) and not from_default:
                ptr_vn = ptrcls.get_verbosename(ctx.env.schema,
                                                with_parent=True)
                raise errors.QueryError(
                    f'modification of computed {ptr_vn} is prohibited',
                    span=shape_el.span)

            base_ptrcls = ptrcls.get_bases(
                ctx.env.schema).first(ctx.env.schema)

            ptr_name = sn.QualName(
                module='__',
                name=ptrcls.get_shortname(ctx.env.schema).name,
            )

        else:
            ptr_name = sn.QualName(
                module='__',
                name=ptrname,
            )

            try:
                is_linkprop_mutation = (
                    is_linkprop
                    and s_ctx.view_rptr is not None
                    and s_ctx.view_rptr.exprtype.is_mutation()
                )

                ptrcls = setgen.resolve_ptr(
                    ptrsource,
                    ptrname,
                    track_ref=(
                        False if not is_linkprop_mutation
                        else shape_el_desc.ptr_ql
                    ),
                    ctx=ctx,
                )

                base_ptrcls = ptrcls.get_bases(
                    ctx.env.schema).first(ctx.env.schema)
            except errors.InvalidReferenceError:
                # If we are inside a mutation of a link property, the
                # pointer must already exist, so re-raise. Otherwise this
                # is a NEW computable pointer, which is fine.
                if is_linkprop_mutation:
                    raise

        qlexpr = astutils.ensure_ql_query(compexpr)
        # HACK: For scope tree related reasons, DML inside of free objects
        # needs to be wrapped in a SELECT. This is probably fixable.
        if irutils.is_trivial_free_object(ir_source):
            qlexpr = astutils.ensure_ql_select(qlexpr)

        if (
            ctx.expr_exposed
            and ctx.implicit_limit
        ):
            qlexpr = qlast.SelectQuery(result=qlexpr, implicit=True)
            qlexpr.limit = qlast.Constant.integer(ctx.implicit_limit)

        irexpr, sub_view_rptr = _compile_qlexpr(
            ir_source,
            qlexpr,
            view_scls,
            ptrcls=ptrcls,
            ptrsource=ptrsource,
            ptr_name=ptr_name,
            is_linkprop=is_linkprop,
            # do not set partial path prefix if in the insert
            # shape but not in defaults
            should_set_partial_prefix=(
                not s_ctx.exprtype.is_insert() or from_default),
            s_ctx=s_ctx,
            ctx=ctx,
        )
        materialized = setgen.should_materialize(
            irexpr, ptrcls=ptrcls,
            materialize_visible=True, skipped_bindings={path_id},
            ctx=ctx)
        ptr_target = setgen.get_set_type(irexpr, ctx=ctx)

        if (
            shape_el.operation.op is qlast.ShapeOp.APPEND
            or shape_el.operation.op is qlast.ShapeOp.SUBTRACT
        ):
            if not s_ctx.exprtype.is_update():
                op = (
                    '+=' if shape_el.operation.op is qlast.ShapeOp.APPEND
                    else '-='
                )
                raise errors.EdgeQLSyntaxError(
                    f"unexpected '{op}'",
                    span=shape_el.operation.span,
                )

        irexpr.span = compexpr.span

        is_inbound_alias = False
        if base_ptrcls is None:
            base_ptrcls = sub_view_rptr.base_ptrcls
            base_ptrcls_is_alias = sub_view_rptr.ptrcls_is_alias
            is_inbound_alias = (
                sub_view_rptr.rptr_dir is s_pointers.PointerDirection.Inbound)

        if ptrcls is not None:
            ctx.env.schema = ptrcls.set_field_value(
                ctx.env.schema, 'owned', True)

        ptr_cardinality = None
        ptr_required = False

        _record_created_collection_types(ptr_target, ctx)

        generic_type = ptr_target.find_generic(ctx.env.schema)
        if generic_type is not None:
            raise errors.QueryError(
                'expression returns value of indeterminate type',
                span=ctx.env.type_origins.get(generic_type),
            )

        # Validate that the insert/update expression is
        # of the correct class.
        if is_mutation and ptrcls is not None:
            base_target = ptrcls.get_target(ctx.env.schema)
            assert base_target is not None
            if ptr_target.assignment_castable_to(
                    base_target,
                    schema=ctx.env.schema):
                # Force assignment casts if the target type is not a
                # subclass of the base type and the cast is not to an
                # object type.
                if not (
                    base_target.is_object_type()
                    or s_types.is_type_compatible(
                        base_target, ptr_target, schema=ctx.env.schema
                    )
                ):
                    qlexpr = astutils.ensure_ql_query(
                        qlast.TypeCast(
                            type=typegen.type_to_ql_typeref(
                                base_target, ctx=ctx
                            ),
                            expr=compexpr,
                        )
                    )
                    ptr_target = base_target
                    # We also need to compile the cast to IR.
                    with ctx.new() as subctx:
                        subctx.anchors = subctx.anchors.copy()
                        source_path = subctx.create_anchor(irexpr, 'a')
                        cast_qlexpr = astutils.ensure_ql_query(
                            qlast.TypeCast(
                                type=typegen.type_to_ql_typeref(
                                    base_target, ctx=ctx
                                ),
                                expr=source_path,
                            )
                        )

                        # HACK: This is mad dodgy. Hide the Pointer
                        # when compiling.
                        old_expr = irexpr.expr
                        if isinstance(old_expr, irast.Pointer):
                            assert old_expr.expr
                            irexpr.expr = old_expr.expr
                        irexpr = dispatch.compile(cast_qlexpr, ctx=subctx)
                        if isinstance(old_expr, irast.Pointer):
                            old_expr.expr = irexpr.expr
                            irexpr.expr = old_expr

            else:
                expected = [
                    repr(str(base_target.get_displayname(ctx.env.schema)))
                ]

                ercls: type[errors.EdgeDBError]
                if ptrcls.is_property():
                    ercls = errors.InvalidPropertyTargetError
                else:
                    ercls = errors.InvalidLinkTargetError

                ptr_vn = ptrcls.get_verbosename(ctx.env.schema,
                                                with_parent=True)

                raise ercls(
                    f'invalid target for {ptr_vn}: '
                    f'{str(ptr_target.get_displayname(ctx.env.schema))!r} '
                    f'(expecting {" or ".join(expected)})'
                )

    # Prohibit updates of read-only pointers
    if (
        s_ctx.exprtype.is_update()
        and ptrcls
        and ptrcls.get_readonly(ctx.env.schema)
    ):
        raise errors.QueryError(
            f'cannot update {ptrcls.get_verbosename(ctx.env.schema)}: '
            f'it is declared as read-only',
            span=compexpr.span if compexpr else None,
        )

    if (
        s_ctx.exprtype.is_mutation()
        and ptrcls
        and ptrcls.get_protected(ctx.env.schema)
        and not from_default
    ):
        # 4.0 shipped with a bug where dumps included protected fields
        # in config values, so we need to suppress the error in that
        # case.  Default value injection is set up to *always* inject
        # on protected pointers.
        if ctx.env.options.dump_restore_mode:
            return ptrcls, None
        raise errors.QueryError(
            f'cannot assign to {ptrcls.get_verbosename(ctx.env.schema)}: '
            f'it is protected',
            span=compexpr.span if compexpr else None,
        )

    # Prohibit invalid operations on id
    id_access = (
        ptrcls
        and ptrcls.is_id_pointer(ctx.env.schema)
        and (
            not ctx.env.options.allow_user_specified_id
            or not s_ctx.exprtype.is_mutation()
        )
    )
    if (
        (compexpr is not None or is_polymorphic)
        and id_access and not from_default and ptrcls
    ):
        vn = ptrcls.get_verbosename(ctx.env.schema)
        if is_polymorphic:
            msg = (f'cannot access {vn} on a polymorphic '
                   f'shape element')
        else:
            msg = f'cannot assign to {vn}'
        if (
            not ctx.env.options.allow_user_specified_id
            and s_ctx.exprtype.is_mutation()
        ):
            hint = (
                'consider enabling the "allow_user_specified_id" '
                'configuration parameter to allow setting custom object ids'
            )
        else:
            hint = None

        raise errors.QueryError(msg, span=shape_el.span, hint=hint)

    # Common code for computed/not computed

    if (
        pending_pointers is not None and ptrcls is not None
        and (prev := pending_pointers.get(ptrcls)) is not None
        and prev.shape_origin is not qlast.ShapeOrigin.SPLAT_EXPANSION
    ):
        vnp = ptrcls.get_verbosename(ctx.env.schema, with_parent=True)
        raise errors.QueryError(
            f'duplicate definition of {vnp}',
            span=shape_el.span)

    if qlexpr is not None or ptrcls is None:
        src_scls: s_sources.Source

        if is_linkprop:
            # Proper checking was done when is_linkprop was defined.
            assert s_ctx.view_rptr is not None
            assert isinstance(s_ctx.view_rptr.ptrcls, s_links.Link)
            src_scls = s_ctx.view_rptr.ptrcls
        else:
            src_scls = view_scls

        if ptr_target.is_object_type():
            base = ctx.env.get_schema_object_and_track(
                sn.QualName('std', 'link'), expr=None)
        else:
            base = ctx.env.get_schema_object_and_track(
                sn.QualName('std', 'property'), expr=None)

        if base_ptrcls is not None:
            derive_from = base_ptrcls
        else:
            derive_from = base

        derived_name = schemactx.derive_view_name(
            base_ptrcls,
            derived_name_base=ptr_name,
            derived_name_quals=[str(src_scls.get_name(ctx.env.schema))],
            ctx=ctx,
        )

        existing = ctx.env.schema.get(
            derived_name, default=None, type=s_pointers.Pointer)
        if existing is not None:
            existing_target = existing.get_target(ctx.env.schema)
            assert existing_target is not None
            if ctx.recompiling_schema_alias:
                ptr_cardinality = existing.get_cardinality(ctx.env.schema)
                ptr_required = existing.get_required(ctx.env.schema)
            if ptr_target == existing_target:
                ptrcls = existing
            elif ptr_target.implicitly_castable_to(
                    existing_target, ctx.env.schema):

                ctx.env.ptr_ref_cache.pop(existing, None)
                ctx.env.schema = existing.set_target(
                    ctx.env.schema, ptr_target)
                ptrcls = existing
            else:
                vnp = existing.get_verbosename(
                    ctx.env.schema, with_parent=True)

                t1_vn = existing_target.get_verbosename(ctx.env.schema)
                t2_vn = ptr_target.get_verbosename(ctx.env.schema)

                if compexpr is not None:
                    span = compexpr.span
                else:
                    span = shape_el.expr.steps[-1].span
                raise errors.SchemaError(
                    f'cannot redefine {vnp} as {t2_vn}',
                    details=f'{vnp} is defined as {t1_vn}',
                    span=span,
                )
        else:
            ptrcls = schemactx.derive_ptr(
                derive_from, src_scls, ptr_target,
                derive_backlink=is_inbound_alias,
                derived_name=derived_name,
                ctx=ctx)

    elif ptrcls.get_target(ctx.env.schema) != ptr_target:
        ctx.env.ptr_ref_cache.pop(ptrcls, None)
        ctx.env.schema = ptrcls.set_target(ctx.env.schema, ptr_target)

    assert ptrcls is not None

    if materialized and not is_mutation and ctx.qlstmt:
        assert ptrcls not in ctx.env.materialized_sets
        ctx.env.materialized_sets[ptrcls] = ctx.qlstmt, materialized

        if irexpr:
            setgen.maybe_materialize(ptrcls, irexpr, ctx=ctx)

    if qlexpr is not None:
        ctx.env.schema = ptrcls.set_field_value(
            ctx.env.schema, 'defined_here', True
        )

    if qlexpr is not None:
        ctx.env.source_map[ptrcls] = irast.ComputableInfo(
            qlexpr=qlexpr,
            irexpr=irexpr,
            context=ctx,
            path_id=path_id,
            path_id_ns=s_ctx.path_id_namespace,
            shape_op=shape_el.operation.op,
            should_materialize=materialized or [],
        )

    if compexpr is not None or is_polymorphic or materialized:
        if (old_ptrref := ctx.env.ptr_ref_cache.get(ptrcls)):
            old_ptrref.is_computable = True

        ctx.env.schema = ptrcls.set_field_value(
            ctx.env.schema,
            'computable',
            True,
        )

        ctx.env.schema = ptrcls.set_field_value(
            ctx.env.schema,
            'owned',
            True,
        )

    if ptr_cardinality is not None:
        ctx.env.schema = ptrcls.set_field_value(
            ctx.env.schema, 'cardinality', ptr_cardinality)
        ctx.env.schema = ptrcls.set_field_value(
            ctx.env.schema, 'required', ptr_required)
    else:
        if qlexpr is None and ptrcls is not base_ptrcls:
            ctx.env.pointer_derivation_map[base_ptrcls].append(ptrcls)

        base_cardinality = None
        base_required = None
        if (
            base_ptrcls is not None
            and not base_ptrcls_is_alias
            and not is_independent_polymorphic
        ):
            base_cardinality = _get_base_ptr_cardinality(base_ptrcls, ctx=ctx)
            base_required = base_ptrcls.get_required(ctx.env.schema)

        if base_cardinality is None or not base_cardinality.is_known():
            # If the base cardinality is not known then we can't make
            # any checks here and will rely on validation in the
            # cardinality inferer.
            specified_cardinality = shape_el.cardinality
            specified_required = shape_el.required
        else:
            specified_cardinality = base_cardinality

            # Inferred optionality overrides that of the base pointer
            # if the base pointer is not `required`, hence the `is True` check.
            if shape_el.required is not None:
                specified_required = shape_el.required
            elif base_required is True:
                specified_required = base_required
            else:
                specified_required = None

            if (
                shape_el.cardinality is not None
                and base_ptrcls is not None
                and shape_el.cardinality != base_cardinality
            ):
                base_src = base_ptrcls.get_source(ctx.env.schema)
                assert base_src is not None
                base_src_name = base_src.get_verbosename(ctx.env.schema)
                raise errors.SchemaError(
                    f'cannot redefine the cardinality of '
                    f'{ptrcls.get_verbosename(ctx.env.schema)}: '
                    f'it is defined as {base_cardinality.as_ptr_qual()!r} '
                    f'in the base {base_src_name}',
                    span=compexpr.span if compexpr else None,
                )

            if (
                shape_el.required is False
                and base_ptrcls is not None
                and base_required
            ):
                base_src = base_ptrcls.get_source(ctx.env.schema)
                assert base_src is not None
                base_src_name = base_src.get_verbosename(ctx.env.schema)
                raise errors.SchemaError(
                    f'cannot redefine '
                    f'{ptrcls.get_verbosename(ctx.env.schema)} '
                    f'as optional: it is defined as required '
                    f'in the base {base_src_name}',
                    span=compexpr.span if compexpr else None,
                )

        ctx.env.pointer_specified_info[ptrcls] = (
            specified_cardinality, specified_required, shape_el.span)

        ctx.env.schema = ptrcls.set_field_value(
            ctx.env.schema, 'cardinality', qltypes.SchemaCardinality.Unknown)

    if irexpr and not irexpr.span:
        irexpr.span = shape_el.span

    return ptrcls, irexpr


def derive_ptrcls(
    view_rptr: context.ViewRPtr,
    *,
    target_scls: s_types.Type,
    ctx: context.ContextLevel
) -> s_pointers.Pointer:

    if view_rptr.ptrcls is None:
        if view_rptr.base_ptrcls is None:
            if target_scls.is_object_type():
                base = ctx.env.get_schema_object_and_track(
                    sn.QualName('std', 'link'), expr=None)
                view_rptr.base_ptrcls = cast(s_links.Link, base)
            else:
                base = ctx.env.get_schema_object_and_track(
                    sn.QualName('std', 'property'), expr=None)
                view_rptr.base_ptrcls = cast(s_props.Property, base)

        derived_name = schemactx.derive_view_name(
            view_rptr.base_ptrcls,
            derived_name_base=view_rptr.ptrcls_name,
            derived_name_quals=(
                str(view_rptr.source.get_name(ctx.env.schema)),
            ),
            ctx=ctx)

        is_inbound_alias = (
            view_rptr.rptr_dir is s_pointers.PointerDirection.Inbound)
        view_rptr.ptrcls = schemactx.derive_ptr(
            view_rptr.base_ptrcls, view_rptr.source, target_scls,
            derived_name=derived_name,
            derive_backlink=is_inbound_alias,
            ctx=ctx
        )

    else:
        view_rptr.ptrcls = schemactx.derive_ptr(
            view_rptr.ptrcls, view_rptr.source, target_scls,
            derived_name_quals=(
                str(view_rptr.source.get_name(ctx.env.schema)),
            ),
            ctx=ctx
        )

    return view_rptr.ptrcls


def _link_has_shape(
    ptrcls: s_pointers.PointerLike, *, ctx: context.ContextLevel
) -> bool:
    if not isinstance(ptrcls, s_links.Link):
        return False

    ptr_shape = {p for p, _ in ctx.env.view_shapes[ptrcls]}
    for p in ptrcls.get_pointers(ctx.env.schema).objects(ctx.env.schema):
        if p.is_special_pointer(ctx.env.schema) or p not in ptr_shape:
            continue
        else:
            return True

    return False


def _get_base_ptr_cardinality(
    ptrcls: s_pointers.Pointer,
    *,
    ctx: context.ContextLevel,
) -> Optional[qltypes.SchemaCardinality]:
    ptr_name = ptrcls.get_name(ctx.env.schema)
    if ptr_name in {
        sn.QualName('std', 'link'),
        sn.QualName('std', 'property')
    }:
        return None
    else:
        return ptrcls.get_cardinality(ctx.env.schema)


def has_implicit_tid(
    stype: s_types.Type, *, is_mutation: bool, ctx: context.ContextLevel
) -> bool:

    return (
        stype.is_object_type()
        and not stype.is_free_object_type(ctx.env.schema)
        and not is_mutation
        and ctx.implicit_tid_in_shapes
    )


def has_implicit_tname(
    stype: s_types.Type, *, is_mutation: bool, ctx: context.ContextLevel
) -> bool:

    return (
        stype.is_object_type()
        and not stype.is_free_object_type(ctx.env.schema)
        and not is_mutation
        and ctx.implicit_tname_in_shapes
    )


def has_implicit_type_computables(
    stype: s_types.Type, *, is_mutation: bool, ctx: context.ContextLevel
) -> bool:

    return (
        has_implicit_tid(stype, is_mutation=is_mutation, ctx=ctx)
        or has_implicit_tname(stype, is_mutation=is_mutation, ctx=ctx)
    )


def _inline_type_computable(
    ir_set: irast.Set,
    stype: s_objtypes.ObjectType,
    compname: str,
    propname: str,
    *,
    shape_ptrs: list[ShapePtr],
    ctx: context.ContextLevel,
) -> None:
    assert isinstance(stype, s_objtypes.ObjectType)
    # Injecting into non-view objects /almost/ works, but it fails if the
    # object is in the std library, and is dodgy always.
    # Prevent it in general to find bugs faster.
    assert stype.is_view(ctx.env.schema)

    ptr: Optional[s_pointers.Pointer]
    try:
        ptr = setgen.resolve_ptr(stype, compname, track_ref=False, ctx=ctx)
        # The pointer might exist on the base type. That doesn't count,
        # and we need to re-inject it.
        if ptr not in ctx.env.source_map:
            ptr = None
    except errors.InvalidReferenceError:
        ptr = None

    ptr_set = None
    if ptr is None:
        ql = qlast.ShapeElement(
            required=True,
            expr=qlast.Path(
                steps=[qlast.Ptr(
                    name=compname,
                    direction=s_pointers.PointerDirection.Outbound,
                )],
            ),
            compexpr=qlast.Path(
                steps=[
                    qlast.SpecialAnchor(name='__source__'),
                    qlast.Ptr(
                        name='__type__',
                        direction=s_pointers.PointerDirection.Outbound,
                    ),
                    qlast.Ptr(
                        name=propname,
                        direction=s_pointers.PointerDirection.Outbound,
                    )
                ]
            )
        )
        ql_desc = _shape_el_ql_to_shape_el_desc(
            ql, source=stype, s_ctx=ShapeContext(), ctx=ctx
        )

        with ctx.new() as scopectx:
            scopectx.anchors = scopectx.anchors.copy()
            # Use the actual base type as the root of the injection, so that
            # if a user has overridden `__type__` in a computable,
            # we see through that.
            base_stype = stype.get_nearest_non_derived_parent(ctx.env.schema)
            base_ir_set = setgen.ensure_set(
                ir_set, type_override=base_stype, ctx=scopectx)

            scopectx.anchors['__source__'] = base_ir_set
            ptr, ptr_set = _normalize_view_ptr_expr(
                base_ir_set,
                ql_desc,
                stype,
                path_id=ir_set.path_id,
                s_ctx=ShapeContext(),
                ctx=scopectx
            )

    # Even if the pointer was not created here, or was already present in
    # the shape, we set defined_here, so it is not inlined in `extend_path`.
    ctx.env.schema = ptr.set_field_value(
        ctx.env.schema, 'defined_here', True
    )

    view_shape = ctx.env.view_shapes[stype]
    view_shape_ptrs = {p for p, _ in view_shape}
    if ptr not in view_shape_ptrs:
        if ptr not in ctx.env.pointer_specified_info:
            ctx.env.pointer_specified_info[ptr] = (None, None, None)
        view_shape.insert(0, (ptr, qlast.ShapeOp.ASSIGN))
        shape_ptrs.insert(
            0, ShapePtr(ir_set, ptr, qlast.ShapeOp.ASSIGN, ptr_set, None)
        )


def _get_shape_configuration_inner(
    ir_set: irast.Set,
    shape_ptrs: list[ShapePtr],
    stype: s_types.Type,
    *,
    parent_view_type: Optional[s_types.ExprType]=None,
    ctx: context.ContextLevel
) -> None:
    is_objtype = ir_set.path_id.is_objtype_path()
    all_materialize = all(
        op == qlast.ShapeOp.MATERIALIZE for _, _, op, _, _ in shape_ptrs
    )

    if is_objtype:
        assert isinstance(stype, s_objtypes.ObjectType)

        view_type = stype.get_expr_type(ctx.env.schema)
        is_mutation = view_type in (s_types.ExprType.Insert,
                                    s_types.ExprType.Update)
        is_parent_update = parent_view_type is s_types.ExprType.Update

        implicit_id = (
            # shape is not specified at all
            not shape_ptrs
            # implicit ids are always wanted
            or (ctx.implicit_id_in_shapes and not is_mutation)
            # we are inside an UPDATE shape and this is
            # an explicit expression (link target update)
            or (is_parent_update and irutils.sub_expr(ir_set) is not None)
            or all_materialize
        )
        # We actually *always* inject an implicit id; in many cases it is
        # there only in case materialization needs it.
        implicit_op = qlast.ShapeOp.ASSIGN
        if not implicit_id:
            implicit_op = qlast.ShapeOp.MATERIALIZE

        # We want the id in this shape and it's not already there,
        # so insert it in the first position.
        pointers = stype.get_pointers(ctx.env.schema).objects(
            ctx.env.schema)
        view_shape = ctx.env.view_shapes[stype]
        view_shape_ptrs = {p for p, _ in view_shape}
        for ptr in pointers:
            if ptr.is_id_pointer(ctx.env.schema):
                if ptr not in view_shape_ptrs:
                    shape_metadata = ctx.env.view_shapes_metadata[stype]
                    view_shape.insert(0, (ptr, implicit_op))
                    shape_metadata.has_implicit_id = True
                    shape_ptrs.insert(
                        0, ShapePtr(ir_set, ptr, implicit_op, None, None)
                    )
                break

    is_mutation = parent_view_type in {
        s_types.ExprType.Insert,
        s_types.ExprType.Update
    }

    if (
        stype is not None
        and has_implicit_tid(stype, is_mutation=is_mutation, ctx=ctx)
    ):
        # HACK: Make sure set is here first, to avoid potential
        # warn_old_scoping warnings.
        pathctx.register_set_in_scope(ir_set, ctx=ctx)
        assert isinstance(stype, s_objtypes.ObjectType)
        _inline_type_computable(
            ir_set, stype, '__tid__', 'id', ctx=ctx, shape_ptrs=shape_ptrs)

    if (
        stype is not None
        and has_implicit_tname(stype, is_mutation=is_mutation, ctx=ctx)
    ):
        # HACK: Make sure set is here first, to avoid potential
        # warn_old_scoping warnings.
        pathctx.register_set_in_scope(ir_set, ctx=ctx)
        assert isinstance(stype, s_objtypes.ObjectType)
        _inline_type_computable(
            ir_set, stype, '__tname__', 'name', ctx=ctx, shape_ptrs=shape_ptrs)


def _get_early_shape_configuration(
    ir_set: irast.Set,
    in_shape_ptrs: list[ShapePtr],
    *,
    rptrcls: Optional[s_pointers.Pointer],
    parent_view_type: Optional[s_types.ExprType]=None,
    ctx: context.ContextLevel
) -> list[ShapePtr]:
    """Return the list of ShapePtr entries forming the shape of a given set.
    """

    stype = setgen.get_set_type(ir_set, ctx=ctx)

    # HACK: For some reason, all the link properties need to go last or
    # things choke in native output mode?
    shape_ptrs = sorted(
        in_shape_ptrs,
        key=lambda arg: arg.ptrcls.is_link_property(ctx.env.schema),
    )

    _get_shape_configuration_inner(
        ir_set, shape_ptrs, stype, parent_view_type=parent_view_type, ctx=ctx)

    return shape_ptrs


def _get_late_shape_configuration(
    ir_set: irast.Set,
    *,
    rptr: Optional[irast.Pointer]=None,
    parent_view_type: Optional[s_types.ExprType]=None,
    ctx: context.ContextLevel
) -> list[ShapePtr]:

    """Return the list of ShapePtr entries forming the shape of a given set.
    """

    stype = setgen.get_set_type(ir_set, ctx=ctx)

    sources: list[s_types.Type | s_pointers.PointerLike] = []
    link_view = False
    is_objtype = ir_set.path_id.is_objtype_path()

    if rptr is None:
        if isinstance(ir_set.expr, irast.Pointer):
            rptr = ir_set.expr
    elif ir_set.expr and not isinstance(ir_set.expr, irast.Pointer):
        # If we have a specified rptr but set is not a pointer itself,
        # construct a version of the set that is a pointer so it can be used
        # as the path tip for applying pointers. This ensures that
        # we can find link properties on late shapes.
        ir_set = setgen.new_set_from_set(
            ir_set, expr=rptr.replace(expr=ir_set.expr, is_phony=True), ctx=ctx
        )

    rptrcls: Optional[s_pointers.PointerLike]
    if rptr is not None:
        rptrcls = typegen.ptrcls_from_ptrref(rptr.ptrref, ctx=ctx)
    else:
        rptrcls = None

    link_view = (
        rptrcls is not None and
        not rptrcls.is_link_property(ctx.env.schema) and
        _link_has_shape(rptrcls, ctx=ctx)
    )

    if is_objtype or not link_view:
        sources.append(stype)

    if link_view:
        assert rptrcls is not None
        sources.append(rptrcls)

    shape_ptrs: list[ShapePtr] = []

    for source in sources:
        for ptr, shape_op in ctx.env.view_shapes[source]:
            shape_ptrs.append(
                ShapePtr(ir_set, ptr, shape_op, None, None)
            )

    _get_shape_configuration_inner(
        ir_set, shape_ptrs, stype, parent_view_type=parent_view_type, ctx=ctx)

    return shape_ptrs


@functools.singledispatch
def late_compile_view_shapes(
    expr: irast.Base,
    *,
    rptr: Optional[irast.Pointer] = None,
    parent_view_type: Optional[s_types.ExprType] = None,
    ctx: context.ContextLevel,
) -> None:
    """Do a late insertion of any unprocessed shapes.

    We mainly compile shapes in process_view, but late_compile_view_shapes
    is responsible for compiling implicit exposed shapes (containing
    only id) and for cases like accessing a semi-joined shape.

    """
    pass


@late_compile_view_shapes.register(irast.Set)
def _late_compile_view_shapes_in_set(
        ir_set: irast.Set, *,
        rptr: Optional[irast.Pointer] = None,
        parent_view_type: Optional[s_types.ExprType] = None,
        ctx: context.ContextLevel) -> None:

    shape_ptrs = _get_late_shape_configuration(
        ir_set, rptr=rptr, parent_view_type=parent_view_type, ctx=ctx)

    # We want to push down the shape to better correspond with where it
    # appears in the query (rather than lifting it up to the first
    # place the view_type appears---this is a little hacky, because
    # letting it be lifted up is the natural thing with our view type-driven
    # shape compilation).
    #
    # This is to avoid losing subquery distinctions (in cases
    # like test_edgeql_scope_tuple_15), and generally seems more natural.
    is_definition_or_not_pointer = (
        not isinstance(ir_set.expr, irast.Pointer) or ir_set.expr.is_definition
    )
    expr = irutils.sub_expr(ir_set)
    if (
        isinstance(expr, (irast.SelectStmt, irast.GroupStmt))
        and is_definition_or_not_pointer
        and (setgen.get_set_type(ir_set, ctx=ctx) ==
             setgen.get_set_type(expr.result, ctx=ctx))
    ):
        child = expr.result
        set_scope = pathctx.get_set_scope(ir_set, ctx=ctx)

        if shape_ptrs:
            pathctx.register_set_in_scope(ir_set, ctx=ctx)
        with ctx.new() as scopectx:
            if set_scope is not None:
                scopectx.path_scope = set_scope

            if not rptr and isinstance(ir_set.expr, irast.Pointer):
                rptr = ir_set.expr
            late_compile_view_shapes(
                child,
                rptr=rptr,
                parent_view_type=parent_view_type,
                ctx=scopectx)

        ir_set.shape_source = child if child.shape else child.shape_source
        return

    if shape_ptrs:
        pathctx.register_set_in_scope(ir_set, ctx=ctx)
        stype = setgen.get_set_type(ir_set, ctx=ctx)

        # If the shape has already been populated (because the set is
        # referenced multiple times), then we've got nothing to do.
        if ir_set.shape:
            # We want to make sure anything inside of the shape gets
            # processed, though, so we do need to look through the
            # internals.
            for element, _ in ir_set.shape:
                element_scope = pathctx.get_set_scope(element, ctx=ctx)
                with ctx.new() as scopectx:
                    if element_scope:
                        scopectx.path_scope = element_scope
                    late_compile_view_shapes(
                        element,
                        parent_view_type=stype.get_expr_type(ctx.env.schema),
                        ctx=scopectx)

            return

        shape = []
        for path_tip, ptr, shape_op, _, ptr_span in shape_ptrs:
            ptr_span = None
            if ptr in ctx.env.pointer_specified_info:
                _, _, ptr_span = ctx.env.pointer_specified_info[ptr]

            element = setgen.extend_path(
                path_tip,
                ptr,
                same_computable_scope=True,
                span=ptr_span,
                ctx=ctx,
            )

            element_scope = pathctx.get_set_scope(element, ctx=ctx)

            if element_scope is None:
                element_scope = ctx.path_scope.attach_fence()
                pathctx.assign_set_scope(element, element_scope, ctx=ctx)

            if element_scope.namespaces:
                element.path_id = element.path_id.merge_namespace(
                    element_scope.namespaces)

            with ctx.new() as scopectx:
                scopectx.path_scope = element_scope
                late_compile_view_shapes(
                    element,
                    parent_view_type=stype.get_expr_type(ctx.env.schema),
                    ctx=scopectx)

            shape.append((element, shape_op))

        ir_set.shape = tuple(shape)

    elif expr is not None:
        set_scope = pathctx.get_set_scope(ir_set, ctx=ctx)
        if set_scope is not None:
            with ctx.new() as scopectx:
                scopectx.path_scope = set_scope
                late_compile_view_shapes(expr, ctx=scopectx)
        else:
            late_compile_view_shapes(expr, ctx=ctx)

    elif isinstance(ir_set.expr, irast.TupleIndirectionPointer):
        late_compile_view_shapes(ir_set.expr.source, ctx=ctx)


@late_compile_view_shapes.register(irast.SelectStmt)
def _late_compile_view_shapes_in_select(
    stmt: irast.SelectStmt,
    *,
    rptr: Optional[irast.Pointer] = None,
    parent_view_type: Optional[s_types.ExprType] = None,
    ctx: context.ContextLevel,
) -> None:
    late_compile_view_shapes(
        stmt.result, rptr=rptr, parent_view_type=parent_view_type, ctx=ctx)


@late_compile_view_shapes.register(irast.Call)
def _late_compile_view_shapes_in_call(
    expr: irast.Call,
    *,
    rptr: Optional[irast.Pointer] = None,
    parent_view_type: Optional[s_types.ExprType] = None,
    ctx: context.ContextLevel,
) -> None:

    if expr.func_polymorphic:
        for call_arg in expr.args.values():
            arg = call_arg.expr
            arg_scope = pathctx.get_set_scope(arg, ctx=ctx)
            if arg_scope is not None:
                with ctx.new() as scopectx:
                    scopectx.path_scope = arg_scope
                    late_compile_view_shapes(arg, ctx=scopectx)
            else:
                late_compile_view_shapes(arg, ctx=ctx)


@late_compile_view_shapes.register(irast.Tuple)
def _late_compile_view_shapes_in_tuple(
    expr: irast.Tuple,
    *,
    rptr: Optional[irast.Pointer] = None,
    parent_view_type: Optional[s_types.ExprType] = None,
    ctx: context.ContextLevel,
) -> None:
    for element in expr.elements:
        late_compile_view_shapes(element.val, ctx=ctx)


@late_compile_view_shapes.register(irast.Array)
def _late_compile_view_shapes_in_array(
    expr: irast.Array,
    *,
    rptr: Optional[irast.Pointer] = None,
    parent_view_type: Optional[s_types.ExprType] = None,
    ctx: context.ContextLevel,
) -> None:
    for element in expr.elements:
        late_compile_view_shapes(element, ctx=ctx)


def _record_created_collection_types(
    type: s_types.Type, ctx: context.ContextLevel
) -> None:
    """
    Record references to implicitly defined collection types,
    so that the alias delta machinery can pick them up.
    """

    if isinstance(
        type, s_types.Collection
    ) and not ctx.env.orig_schema.get_by_id(type.id, default=None):
        for sub_type in type.get_subtypes(ctx.env.schema):
            _record_created_collection_types(sub_type, ctx)


================================================
FILE: edb/edgeql/declarative.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""SDL loader.

The purpose of this module is to take a set of SDL documents and
transform them into schema modules.  The crux of the task is to
break the SDL declarations into a correct sequence of DDL commands,
considering all possible cyclic references.  The dependency tracking
is complicated by the presence of expressions in schema definitions.
In those cases we make a best-effort tracing using a rudimentary
EdgeQL AST visitor.
"""

from __future__ import annotations
from typing import (
    Optional,
    AbstractSet,
    Iterable,
    Mapping,
    MutableSet,
    TypedDict,
    cast,
)

import copy
import functools
from collections import defaultdict

from edb import errors

from edb.common import parsing
from edb.common import topological
from edb.common import english
from edb.common.ordered import OrderedSet

from edb.edgeql import ast as qlast
from edb.edgeql import codegen as qlcodegen
from edb.edgeql import parser as qlparser
from edb.edgeql import tracer as qltracer
from edb.edgeql import utils as qlutils

from edb.schema import annos as s_anno
from edb.schema import constraints as s_constr
from edb.schema import indexes as s_indexes
from edb.schema import links as s_links
from edb.schema import name as s_name
from edb.schema import objects as s_obj
from edb.schema import objtypes as s_objtypes
from edb.schema import properties as s_props
from edb.schema import pseudo as s_pseudo
from edb.schema import scalars as s_scalars
from edb.schema import schema as s_schema
from edb.schema import sources as s_sources
from edb.schema import types as s_types
from edb.schema import utils as s_utils


class TraceContextBase:

    schema: s_schema.Schema
    module: str
    depstack: list[tuple[qlast.DDLOperation, s_name.QualName]]
    modaliases: dict[Optional[str], str]
    objects: dict[s_name.QualName, Optional[qltracer.ObjectLike]]
    pointers: dict[s_name.UnqualName, set[s_name.QualName]]
    parents: dict[s_name.QualName, set[s_name.QualName]]
    ancestors: dict[s_name.QualName, set[s_name.QualName]]
    defdeps: dict[s_name.QualName, set[s_name.QualName]]
    constraints: dict[s_name.QualName, set[s_name.QualName]]
    local_modules: AbstractSet[str]

    def __init__(
        self,
        schema: s_schema.Schema,
        local_modules: AbstractSet[str],
    ) -> None:
        self.schema = schema
        self.module = '__not_set__'
        self.depstack = []
        self.modaliases = {}
        self.objects = {}
        self.pointers = {}
        self.parents = {}
        self.ancestors = {}
        self.defdeps = defaultdict(set)
        self.constraints = defaultdict(set)
        self.local_modules = local_modules

    def set_module(self, module: str) -> None:
        self.module = module
        self.modaliases = {None: module}

    def get_local_name(
        self,
        ref: qlast.ObjectRef,
        declaration: bool=False,
    ) -> s_name.QualName:
        return qltracer.resolve_name(
            ref,
            current_module=self.module,
            schema=self.schema,
            objects=self.objects,
            modaliases=None,
            local_modules=self.local_modules,
            declaration=declaration,
        )

    def get_ref_name(self, ref: qlast.BaseObjectRef) -> s_name.QualName:
        if isinstance(ref, qlast.ObjectRef):
            return self.get_local_name(ref)
        elif isinstance(ref, qlast.PseudoObjectRef):
            # We pretend `anytype` has a fully-qualified name here, because
            # the tracing machinery really wants to work with fully-qualified
            # names and wants to distinguish between objects from the standard
            # library and the user-defined ones.
            # Ditto for `anytuple` and `anyobject`.
            return s_name.QualName('std', ref.name)
        else:
            raise TypeError(
                "ObjectRef expected "
                "(got type {!r})".format(type(ref).__name__)
            )

    def get_fq_name(
        self,
        decl: qlast.DDLOperation,
        declaration: bool=False,
    ) -> tuple[str, s_name.QualName]:
        # Get the basic name form.
        if isinstance(decl, qlast.CreateConcretePointer):
            name = decl.name.name
            parent_expected = True
        elif isinstance(decl, qlast.SetField):
            name = decl.name
            parent_expected = True
        elif isinstance(decl, qlast.ObjectDDL):
            fq_name = self.get_local_name(decl.name, declaration=declaration)
            name = str(fq_name)
            parent_expected = False
        else:
            raise AssertionError(f'unexpected DDL node: {decl!r}')

        if self.depstack:
            parent_name = self.depstack[-1][1]
            fq_name = s_name.QualName(
                module=parent_name.module,
                name=f'{parent_name.name}@{name}'
            )
        elif parent_expected:
            raise AssertionError(
                f'missing expected parent context for {decl!r}')

        # Additionally, functions and concrete constraints may need an
        # extra name piece.
        extra_name = None
        if isinstance(decl, qlast.CreateFunction):
            # Functions are defined by their name + call signature, so we
            # need to add that to the "extra_name".
            extra_name = f'({qlcodegen.generate_source(decl.params)})'

        elif isinstance(decl, qlast.CreateConcreteConstraint):
            # Concrete constraints are defined by their expr, so we need
            # to add that to the "extra_name".
            exprs = list(decl.args)
            if decl.subjectexpr:
                exprs.append(decl.subjectexpr)
            if decl.except_expr:
                # Add an extra dummy argument to distinguish between
                # ON and EXCEPT, when only one is present
                exprs.append(qlast.Set(elements=[]))
                exprs.append(decl.except_expr)

            for cmd in decl.commands:
                if isinstance(cmd, qlast.SetField) and cmd.name == "expr":
                    assert cmd.value, "sdl SetField should always have value"
                    assert isinstance(cmd.value, qlast.Expr)
                    exprs.append(cmd.value)

            extra_name = '|'.join(qlcodegen.generate_source(e) for e in exprs)

        elif isinstance(decl, qlast.CreateConcreteIndex):
            # Indexes are defined by what they are an index over, so we need
            # to add that to the "extra_name".
            extra_name = f'({qlcodegen.generate_source(decl.expr)})'
            if decl.except_expr:
                except_bit = f'({qlcodegen.generate_source(decl.except_expr)})'
                extra_name = f'{extra_name}/{except_bit}'

        if extra_name:
            fq_name = s_name.QualName(
                module=fq_name.module,
                name=f'{fq_name.name}@@{extra_name}',
            )

        return name, fq_name


def get_verbosename_from_fqname(
    fq_name: s_name.QualName,
    ctx: DepTraceContext | LayoutTraceContext,
) -> str:
    traceobj = ctx.objects[fq_name]
    assert traceobj is not None

    name = str(fq_name)
    clsname = traceobj.get_schema_class_displayname()
    ofobj = ''

    if isinstance(traceobj, qltracer.Alias):
        clsname = 'alias'
    elif isinstance(traceobj, qltracer.ObjectType):
        clsname = 'object'
    elif isinstance(traceobj, qltracer.ScalarType):
        clsname = 'scalar'
    elif isinstance(traceobj, qltracer.Function):
        name = str(fq_name).split('@@', 1)[0]
        if isinstance(ctx, DepTraceContext):
            node = ctx.ddlgraph[fq_name].item
            assert isinstance(node, qlast.FunctionCommand)
            params = ','.join(
                qlcodegen.generate_source(param, sdlmode=True)
                for param in node.params
            )
            name = f"{name}({params})"
    elif isinstance(traceobj, qltracer.Pointer):
        ofobj, name = str(fq_name).split('@', 1)
        ofobj = f" of object type '{ofobj}'"
    elif isinstance(traceobj, qltracer.AccessPolicy):
        clsname = 'access policy'
        ofobj, name = str(fq_name).split('@', 1)
        _, name = name.split('::')
        ofobj = f" of object type '{ofobj}'"
    elif isinstance(traceobj, qltracer.Trigger):
        clsname = 'trigger'
        ofobj, name = str(fq_name).split('@', 1)
        _, name = name.split('::')
        ofobj = f" of object type '{ofobj}'"
    elif isinstance(traceobj, qltracer.ConcreteIndex):
        clsname = 'index'
        ofobj, name = str(fq_name).split('@', 1)
        name, _ = name.split('@@', 1)
        if name == str(s_indexes.DEFAULT_INDEX):
            name = ''
        ofobj = f" of object type '{ofobj}'"
    elif isinstance(traceobj, qltracer.Field):
        clsname = 'field'
        obj, name = fq_name.name.rsplit('@', 1)
        ofobj = ' of ' + get_verbosename_from_fqname(
            s_name.QualName(fq_name.module, obj), ctx)

    if name:
        return f"{clsname} '{name}'{ofobj}"
    else:
        return f"{clsname}{ofobj}"


class InheritanceGraphEntry(TypedDict):

    item: qltracer.NamedObject
    deps: AbstractSet[s_name.Name]
    merge: AbstractSet[s_name.Name]


class LayoutTraceContext(TraceContextBase):

    inh_graph: dict[
        s_name.QualName,
        topological.DepGraphEntry[
            s_name.QualName,
            qltracer.NamedObject,
            bool,
        ],
    ]

    def __init__(
        self,
        schema: s_schema.Schema,
        local_modules: AbstractSet[str],
    ) -> None:
        super().__init__(schema, local_modules)
        self.inh_graph = {}


DDLGraph = dict[
    s_name.QualName,
    topological.DepGraphEntry[s_name.QualName, qlast.DDLCommand, bool],
]


class DepTraceContext(TraceContextBase):

    def __init__(
        self,
        schema: s_schema.Schema,
        ddlgraph: DDLGraph,
        objects: dict[s_name.QualName, Optional[qltracer.ObjectLike]],
        pointers: dict[s_name.UnqualName, set[s_name.QualName]],
        parents: dict[s_name.QualName, set[s_name.QualName]],
        ancestors: dict[s_name.QualName, set[s_name.QualName]],
        defdeps: dict[s_name.QualName, set[s_name.QualName]],
        constraints: dict[s_name.QualName, set[s_name.QualName]],
        local_modules: AbstractSet[str],
    ) -> None:
        super().__init__(schema, local_modules)
        self.ddlgraph = ddlgraph
        self.objects = objects
        self.pointers = pointers
        self.parents = parents
        self.ancestors = ancestors
        self.defdeps = defdeps
        self.constraints = constraints


class Dependency:
    pass


class TypeDependency(Dependency):

    texpr: qlast.TypeExpr

    def __init__(self, texpr: qlast.TypeExpr) -> None:
        self.texpr = texpr


class ExprDependency(Dependency):

    expr: qlast.Expr

    def __init__(self, expr: qlast.Expr) -> None:
        self.expr = expr


class FunctionDependency(ExprDependency):

    params: Mapping[str, qlast.TypeExpr]

    def __init__(
        self,
        expr: qlast.Expr,
        params: Mapping[str, qlast.TypeExpr],
    ) -> None:
        super().__init__(expr=expr)
        self.params = params


def sdl_to_ddl(
    schema: s_schema.Schema,
    documents: Mapping[str, list[qlast.DDLCommand]],
) -> tuple[qlast.DDLCommand, ...]:

    ddlgraph: DDLGraph = {}
    mods: list[qlast.DDLCommand] = []

    ctx = LayoutTraceContext(schema, frozenset(mod for mod in documents))

    ctx.objects[s_name.QualName('std', 'anytype')] = (
        schema.get_global(s_pseudo.PseudoType, 'anytype'))
    ctx.objects[s_name.QualName('std', 'anytuple')] = (
        schema.get_global(s_pseudo.PseudoType, 'anytuple'))
    ctx.objects[s_name.QualName('std', 'anyobject')] = (
        schema.get_global(s_pseudo.PseudoType, 'anyobject'))

    for module_name, declarations in documents.items():
        ctx.set_module(module_name)
        for decl_ast in declarations:
            if isinstance(decl_ast, qlast.CreateObject):
                _, fq_name = ctx.get_fq_name(decl_ast, declaration=True)

                if isinstance(decl_ast, qlast.CreateObjectType):
                    ctx.objects[fq_name] = qltracer.ObjectType(fq_name)
                elif isinstance(decl_ast, qlast.CreateAlias):
                    ctx.objects[fq_name] = qltracer.Alias(fq_name)
                elif isinstance(decl_ast, qlast.CreateScalarType):
                    ctx.objects[fq_name] = qltracer.ScalarType(fq_name)
                elif isinstance(decl_ast, qlast.CreateLink):
                    ctx.objects[fq_name] = qltracer.Link(
                        fq_name, source=None, target=None)
                elif isinstance(decl_ast, qlast.CreateProperty):
                    ctx.objects[fq_name] = qltracer.Property(
                        fq_name, source=None, target=None)
                elif isinstance(decl_ast, qlast.CreateFunction):
                    ctx.objects[fq_name] = qltracer.Function(fq_name)
                elif isinstance(decl_ast, qlast.CreateConstraint):
                    ctx.objects[fq_name] = qltracer.Constraint(fq_name)
                elif isinstance(decl_ast, qlast.CreateAnnotation):
                    ctx.objects[fq_name] = qltracer.Annotation(fq_name)
                elif isinstance(decl_ast, qlast.CreateGlobal):
                    ctx.objects[fq_name] = qltracer.Global(fq_name)
                elif isinstance(decl_ast, qlast.CreatePermission):
                    ctx.objects[fq_name] = qltracer.Permission(fq_name)
                elif isinstance(decl_ast, qlast.CreateIndex):
                    ctx.objects[fq_name] = qltracer.Index(fq_name)
                else:
                    raise AssertionError(
                        f'unexpected SDL declaration: {decl_ast}')

    for module_name, declarations in documents.items():
        ctx.set_module(module_name)
        for decl_ast in declarations:
            trace_layout(decl_ast, ctx=ctx)

    # compute the ancestors graph
    for obj_name in ctx.parents.keys():
        ctx.ancestors[obj_name] = get_ancestors(
            obj_name, ctx.ancestors, ctx.parents)

    topological.normalize(
        ctx.inh_graph,
        merger=_graph_merge_cb,  # type: ignore
        schema=schema,
    )

    tracectx = DepTraceContext(
        schema, ddlgraph, ctx.objects, ctx.pointers, ctx.parents, ctx.ancestors,
        ctx.defdeps, ctx.constraints, ctx.local_modules,
    )

    created_modules = set()
    for module_name, declarations in documents.items():
        tracectx.set_module(module_name)
        # module (and any enclosing modules) needs to be created
        # regardless of whether its contents are empty or not
        parts = module_name.split('::')
        for i in range(len(parts)):
            n = '::'.join(parts[:i + 1])
            if n not in created_modules:
                created_modules.add(n)
                mods.append(qlast.CreateModule(name=qlast.ObjectRef(name=n)))
        for decl_ast in declarations:
            trace_dependencies(decl_ast, ctx=tracectx)

    for ddlentry in ddlgraph.values():
        # Filter out deps that are in the schema but not in ctx.objects.
        # Deps that are in neither get left in, so that we catch the bug.
        deps = {
            x for x in ddlentry.deps
            if x in ctx.objects or not schema.get(x, default=None)
        }
        weak_deps = {
            x for x in ddlentry.weak_deps
            if x in ctx.objects or not schema.get(x, default=None)
        }

        # Before sorting normalize all ordering, to make sure that errors
        # are consistent.
        ddlentry.deps = OrderedSet(sorted(deps))
        ddlentry.weak_deps = OrderedSet(sorted(weak_deps))

    try:
        ordered = topological.sort(ddlgraph, allow_unresolved=False)
    except topological.CycleError as e:
        assert isinstance(e.item, s_name.QualName)
        node = tracectx.ddlgraph[e.item].item
        item_vn = get_verbosename_from_fqname(e.item, tracectx)

        if e.path is not None and len(e.path):
            # Recursion involving more than one schema object.
            rec_vn = get_verbosename_from_fqname(e.path[-1], tracectx)
            msg = (
                f'definition dependency cycle between {rec_vn} '
                f'and {item_vn}'
            )
        else:
            # A single schema object with a recursive definition.
            msg = f'{item_vn} is defined recursively'

        raise errors.InvalidDefinitionError(msg, span=node.span) from e

    return tuple(mods) + tuple(ordered)


def _graph_merge_cb(
    item: qltracer.NamedObject,
    parent: qltracer.NamedObject,
    *,
    schema: s_schema.Schema,
) -> qltracer.NamedObject:
    if (
        isinstance(item, (qltracer.Source, s_sources.Source))
        and isinstance(parent, (qltracer.Source, s_sources.Source))
    ):
        return _merge_items(item, parent, schema=schema)
    else:
        return item


def _merge_items(
    item: qltracer.Source_T,
    parent: qltracer.SourceLike_T,
    *,
    schema: s_schema.Schema,
) -> qltracer.Source_T:

    item_ptrs = dict(item.get_pointers(schema).items(schema))

    for pn, ptr in parent.get_pointers(schema).items(schema):
        if not isinstance(ptr, (qltracer.Pointer, s_sources.Source)):
            continue

        if pn not in item_ptrs:
            PointerType = (qltracer.Property if ptr.is_property(schema)
                           else qltracer.Link)
            ptr_copy = PointerType(
                s_name.QualName('__', pn.name),
                source=ptr.get_source(schema),
                target=ptr.get_target(schema),
            )
            ptr_copy.pointers = dict(
                ptr.get_pointers(schema).items(schema))
            item.pointers[pn] = ptr_copy
        else:
            item_ptr = item.getptr(schema, pn)
            assert isinstance(item_ptr, (qltracer.Pointer, s_sources.Source))
            PointerType = (qltracer.Property if item_ptr.is_property(schema)
                           else qltracer.Link)
            ptr_copy = PointerType(
                s_name.QualName('__', pn.name),
                source=item,
                target=item_ptr.get_target(schema),
            )
            ptr_copy.pointers = dict(
                item_ptr.get_pointers(schema).items(schema))
            item.pointers[pn] = _merge_items(ptr_copy, ptr, schema=schema)

    return item


@functools.singledispatch
def trace_layout(
    node: qlast.Base,
    *,
    ctx: LayoutTraceContext,
) -> None:
    pass


@trace_layout.register
def trace_layout_Schema(
    node: qlast.Schema,
    *,
    ctx: LayoutTraceContext,
) -> None:
    for decl in node.declarations:
        trace_layout(decl, ctx=ctx)


@trace_layout.register
def trace_layout_CreateScalarType(
    node: qlast.CreateScalarType,
    *,
    ctx: LayoutTraceContext,
) -> None:
    _trace_item_layout(node, ctx=ctx)


@trace_layout.register
def trace_layout_CreateObjectType(
    node: qlast.CreateObjectType,
    *,
    ctx: LayoutTraceContext,
) -> None:
    _trace_item_layout(node, ctx=ctx)


@trace_layout.register
def trace_layout_CreateLink(
    node: qlast.CreateLink,
    *,
    ctx: LayoutTraceContext,
) -> None:
    _trace_item_layout(node, ctx=ctx)


@trace_layout.register
def trace_layout_CreateProperty(
    node: qlast.CreateProperty,
    *,
    ctx: LayoutTraceContext,
) -> None:
    _trace_item_layout(node, ctx=ctx)


@trace_layout.register
def trace_layout_CreateConstraint(
    node: qlast.CreateConstraint,
    *,
    ctx: LayoutTraceContext,
) -> None:
    _trace_item_layout(node, ctx=ctx)


def _trace_item_layout(
    node: qlast.CreateObject,
    *,
    obj: Optional[qltracer.NamedObject] = None,
    fq_name: Optional[s_name.QualName] = None,
    ctx: LayoutTraceContext,
) -> None:
    if obj is None:
        fq_name = ctx.get_local_name(node.name)
        local_obj = ctx.objects[fq_name]
        assert isinstance(local_obj, qltracer.NamedObject)
        obj = local_obj

    assert fq_name is not None
    PointerType: type[qltracer.Pointer]

    if isinstance(node, qlast.BasedOn):
        bases = []
        # construct the parents set, used later in ancestors graph
        parents = set()

        for ref in _get_bases(node, ctx=ctx):
            bases.append(ref)

            # ignore std modules dependencies
            if ref.get_module_name() not in s_schema.STD_MODULES:
                parents.add(ref)

            if (
                ref.module not in ctx.local_modules
                and ref not in ctx.inh_graph
            ):
                base_obj = type(obj)(name=ref)
                ctx.inh_graph[ref] = topological.DepGraphEntry(item=base_obj)

                base = ctx.schema.get(ref)
                if isinstance(base, s_sources.Source):
                    assert isinstance(base_obj, qltracer.Source)
                    base_pointers = base.get_pointers(ctx.schema)
                    for pn, p in base_pointers.items(ctx.schema):
                        PointerType = (
                            qltracer.Property
                            if p.is_property() else
                            qltracer.Link
                        )
                        base_obj.pointers[pn] = PointerType(
                            s_name.QualName('__', pn.name),
                            source=base,
                            target=p.get_target(ctx.schema),
                        )

        ctx.parents[fq_name] = parents
        ctx.inh_graph[fq_name] = topological.DepGraphEntry(
            item=obj,
            deps=set(bases),
            merge=set(bases),
        )

    for decl in node.commands:
        if isinstance(decl, qlast.CreateConcretePointer):
            assert isinstance(obj, qltracer.Source)

            target: Optional[qltracer.TypeLike]
            target_expr: Optional[qlast.Expr]
            if isinstance(decl.target, qlast.TypeExpr):
                target = _resolve_type_expr(decl.target, ctx=ctx)
                target_expr = _get_expr_field(decl)
            else:
                target = None
                target_expr = decl.target

            pn = s_utils.ast_ref_to_unqualname(decl.name)

            PointerType = (
                qltracer.Property
                if isinstance(decl, qlast.CreateConcreteProperty) else
                qltracer.Link
                if isinstance(decl, qlast.CreateConcreteLink) else
                qltracer.UnknownPointer
            )
            ptr = PointerType(
                s_name.QualName('__', pn.name),
                source=obj,
                target=target,
                target_expr=target_expr,
            )
            obj.pointers[pn] = ptr
            ptr_name = s_name.QualName(
                module=fq_name.module,
                name=f'{fq_name.name}@{decl.name.name}',
            )
            ctx.objects[ptr_name] = ptr
            ctx.defdeps[fq_name].add(ptr_name)
            ctx.pointers.setdefault(pn, set()).add(ptr_name)

            _trace_item_layout(
                decl, obj=ptr, fq_name=ptr_name, ctx=ctx)

        elif isinstance(decl, qlast.CreateConcreteConstraint):
            # Validate that the constraint exists at all.
            _validate_schema_ref(decl, ctx=ctx)
            _, con_fq_name = ctx.get_fq_name(decl)

            con_name = s_name.QualName(
                module=fq_name.module,
                name=f'{fq_name.name}@{con_fq_name}',
            )
            ctx.objects[con_name] = qltracer.ConcreteConstraint(con_name)
            ctx.constraints[fq_name].add(con_name)

        elif isinstance(decl, qlast.CreateAnnotationValue):
            # Validate that the annotation exists at all.
            _validate_schema_ref(decl, ctx=ctx)
            _, anno_fq_name = ctx.get_fq_name(decl)

            anno_name = s_name.QualName(
                module=fq_name.module,
                name=f'{fq_name.name}@{anno_fq_name}',
            )
            ctx.objects[anno_name] = qltracer.AnnotationValue(anno_name)

        elif isinstance(decl, qlast.CreateAccessPolicy):
            _, pol_fq_name = ctx.get_fq_name(decl)

            pol_name = s_name.QualName(
                module=fq_name.module,
                name=f'{fq_name.name}@{pol_fq_name}',
            )
            assert isinstance(obj, qltracer.Source)
            ctx.objects[pol_name] = qltracer.AccessPolicy(pol_name, source=obj)

        # XXX: name conflict with triggers, other things??
        elif isinstance(decl, qlast.CreateTrigger):
            _, trigger_fq_name = ctx.get_fq_name(decl)

            trigger_name = s_name.QualName(
                module=fq_name.module,
                name=f'{fq_name.name}@{trigger_fq_name}',
            )
            assert isinstance(obj, qltracer.Source)
            ctx.objects[trigger_name] = qltracer.Trigger(
                trigger_name, source=obj)

        elif isinstance(decl, qlast.CreateConcreteIndex):
            # Validate that the index exists at all.
            _validate_schema_ref(decl, ctx=ctx)
            _, idx_fq_name = ctx.get_fq_name(decl)

            idx_name = s_name.QualName(
                module=fq_name.module,
                name=f'{fq_name.name}@{idx_fq_name}',
            )
            ctx.objects[idx_name] = qltracer.ConcreteIndex(idx_name)

        elif isinstance(decl, qlast.SetField):
            field_name = s_name.QualName(
                module=fq_name.module,
                name=f'{fq_name.name}@{decl.name}',
            )

            # Trivial fields don't get added to the ddlgraph, which is
            # where duplication checks are normally done, so do the
            # check here instead.
            if field_name in ctx.objects:
                vn = get_verbosename_from_fqname(field_name, ctx)
                msg = f'{vn} was already declared'
                raise errors.InvalidDefinitionError(msg, span=decl.span)

            ctx.objects[field_name] = qltracer.Field(field_name)


RECURSION_GUARD: set[s_name.QualName] = set()


def get_ancestors(
    fq_name: s_name.QualName,
    ancestors: dict[s_name.QualName, set[s_name.QualName]],
    parents: Mapping[s_name.QualName, AbstractSet[s_name.QualName]],
) -> set[s_name.QualName]:
    """Recursively compute ancestors (in place) from the parents graph."""

    # value already computed
    result = ancestors.get(fq_name, set())
    if result is RECURSION_GUARD:
        raise errors.InvalidDefinitionError(
            f'{str(fq_name)!r} is defined recursively')
    elif result:
        return result

    ancestors[fq_name] = RECURSION_GUARD

    parent_set = parents.get(fq_name, set())
    # base case: include the parents
    result = set(parent_set)
    for fq_parent in parent_set:
        # recursive step: include parents' ancestors
        result |= get_ancestors(fq_parent, ancestors, parents)

    ancestors[fq_name] = result

    return result


@functools.singledispatch
def trace_dependencies(
    node: qlast.Base,
    *,
    ctx: DepTraceContext,
) -> None:
    raise NotImplementedError(
        f"no SDL dep tracer handler for {node.__class__}")


@trace_dependencies.register
def trace_SetField(
    node: qlast.SetField,
    *,
    ctx: DepTraceContext,
) -> None:
    deps = set()
    exprs = []

    assert node.value, "sdl SetField should always have value"
    if node.name == 'default':
        assert isinstance(node.value, qlast.Expr)
        exprs.append(ExprDependency(expr=node.value))
    else:
        for dep in qltracer.trace_refs(
            node.value,
            schema=ctx.schema,
            module=ctx.module,
            objects=ctx.objects,
            pointers=ctx.pointers,
            local_modules=ctx.local_modules,
            params={},
        )[0]:
            # ignore std module dependencies
            if dep.get_module_name() not in s_schema.STD_MODULES:
                deps.add(dep)

    _register_item(node, deps=deps, hard_dep_exprs=exprs, ctx=ctx)


@trace_dependencies.register
def trace_ConcreteConstraint(
    node: qlast.CreateConcreteConstraint,
    *,
    ctx: DepTraceContext,
) -> None:
    deps = set()

    base_name = ctx.get_ref_name(node.name)
    if base_name.get_module_name() not in s_schema.STD_MODULES:
        deps.add(base_name)

    exprs = [ExprDependency(expr=arg) for arg in node.args]
    if node.subjectexpr:
        exprs.append(ExprDependency(expr=node.subjectexpr))
    if node.except_expr:
        exprs.append(ExprDependency(expr=node.except_expr))

    if (expr := _get_expr_field(node)):
        exprs.append(ExprDependency(expr=expr))

    loop_control: Optional[s_name.QualName]
    if isinstance(ctx.depstack[-1][0], qlast.AlterScalarType):
        # Scalars are tightly bound to their constraints, so
        # we must prohibit any possible reference to this scalar
        # type from within the constraint.
        loop_control = ctx.depstack[-1][1]
    else:
        loop_control = None

    _register_item(
        node,
        deps=deps,
        hard_dep_exprs=exprs,
        loop_control=loop_control,
        source=ctx.depstack[-1][1],
        subject=ctx.depstack[-1][1],
        ctx=ctx,
    )


@trace_dependencies.register
def trace_AccessPolicy(
    node: qlast.CreateAccessPolicy,
    *,
    ctx: DepTraceContext,
) -> None:
    exprs = []
    if node.expr:
        exprs.append(ExprDependency(expr=node.expr))
    if node.condition:
        exprs.append(ExprDependency(expr=node.condition))

    _register_item(
        node,
        deps=set(),
        hard_dep_exprs=exprs,
        source=ctx.depstack[-1][1],
        subject=ctx.depstack[-1][1],
        ctx=ctx,
    )


@trace_dependencies.register
def trace_Trigger(
    node: qlast.CreateTrigger,
    *,
    ctx: DepTraceContext,
) -> None:
    exprs = [ExprDependency(expr=node.expr)]
    if node.condition:
        exprs.append(ExprDependency(expr=node.condition))

    obj = ctx.depstack[-1][1]
    _register_item(
        node,
        deps=set(),
        hard_dep_exprs=exprs,
        source=obj,
        subject=obj,
        anchors={'__new__': obj, '__old__': obj},
        ctx=ctx,
    )


@trace_dependencies.register
def trace_Rewrite(
    node: qlast.CreateRewrite,
    *,
    ctx: DepTraceContext,
) -> None:
    exprs = [ExprDependency(expr=node.expr)]

    obj = ctx.depstack[-2][1]
    _register_item(
        node,
        deps=set(),
        hard_dep_exprs=exprs,
        source=obj,
        subject=obj,
        anchors={'__old__': obj},
        ctx=ctx,
    )


@trace_dependencies.register
def trace_Index(
    node: qlast.CreateConcreteIndex,
    *,
    ctx: DepTraceContext,
) -> None:
    exprs = [ExprDependency(expr=node.expr)]
    if node.except_expr:
        exprs.append(ExprDependency(expr=node.except_expr))
    deps = set()
    if node.kwargs:
        for kwarg in node.kwargs:
            # HACK: Search all objects and depend on any ext::ai annotations.
            # FIXME: Can we make this more general and less slow?
            if kwarg == "embedding_model":
                for n, v in ctx.objects.items():
                    if (
                        "@ext::ai::" in n.name
                        and isinstance(v, qltracer.AnnotationValue)
                    ):
                        deps.add(n)
    _register_item(
        node,
        deps=deps,
        hard_dep_exprs=exprs,
        source=ctx.depstack[-1][1],
        subject=ctx.depstack[-1][1],
        ctx=ctx,
    )


@trace_dependencies.register
def trace_ConcretePointer(
    node: qlast.CreateConcretePointer,
    *,
    ctx: DepTraceContext,
) -> None:
    deps: list[Dependency] = []
    if isinstance(node.target, qlast.TypeExpr):
        deps.append(TypeDependency(texpr=node.target))
    elif isinstance(node.target, qlast.Expr):
        deps.append(ExprDependency(expr=node.target))
    elif node.target is None:
        pass
    else:
        raise AssertionError(
            f'unexpected CreateConcretePointer.target: {node.target!r}')

    if (target_expr := _get_expr_field(node)):
        deps.append(ExprDependency(expr=target_expr))

    _register_item(
        node,
        hard_dep_exprs=deps,
        source=ctx.depstack[-1][1],
        ctx=ctx,
    )


@trace_dependencies.register
def trace_Alias(
    node: qlast.CreateAlias,
    *,
    ctx: DepTraceContext,
) -> None:
    hard_dep_exprs = []

    if (expr := _get_expr_field(node)):
        hard_dep_exprs.append(ExprDependency(expr=expr))

    _register_item(node, hard_dep_exprs=hard_dep_exprs, ctx=ctx)


@trace_dependencies.register
def trace_Global(
    node: qlast.CreateGlobal,
    *,
    ctx: DepTraceContext,
) -> None:
    deps: list[Dependency] = []

    if isinstance(node.target, qlast.TypeExpr):
        deps.append(TypeDependency(texpr=node.target))
    elif isinstance(node.target, qlast.Expr):
        deps.append(ExprDependency(expr=node.target))

    _register_item(node, hard_dep_exprs=deps, ctx=ctx)


@trace_dependencies.register
def trace_Permission(
    node: qlast.CreatePermission,
    *,
    ctx: DepTraceContext,
) -> None:
    deps: list[Dependency] = [
        TypeDependency(texpr=qlast.TypeName(
            maintype=qlast.ObjectRef(module='__std__', name='bool')
        ))
    ]

    _register_item(node, hard_dep_exprs=deps, ctx=ctx)


@trace_dependencies.register
def trace_Function(
    node: qlast.CreateFunction,
    *,
    ctx: DepTraceContext,
) -> None:
    # We also need to add all the signature types as dependencies
    # to make sure that DDL linearization of SDL will define the types
    # before the function.
    deps: list[Dependency] = []

    # We don't actually care to resolve these, but we do need to check for
    # tracing errors.
    for param in node.params:
        if (
            isinstance(param.type, qlast.TypeName)
            and isinstance(param.type.maintype, qlast.PseudoObjectRef)
        ):
            # generic types are handled elsewhere
            continue
        _resolve_type_expr(param.type, ctx=ctx)
    _resolve_type_expr(node.returning, ctx=ctx)

    deps.extend(TypeDependency(texpr=param.type) for param in node.params)
    deps.append(TypeDependency(texpr=node.returning))

    params: dict[str, qlast.TypeExpr] = {}
    for param in node.params:
        params[param.name] = param.type

    if node.nativecode is not None:
        deps.append(FunctionDependency(expr=node.nativecode, params=params))
    elif (
        node.code is not None
        and node.code.language is qlast.Language.EdgeQL
        and node.code.code
    ):
        # Need to parse the actual code string and use that as the dependency.
        fcode = qlparser.parse_query(node.code.code)
        assert isinstance(fcode, qlast.Expr)
        deps.append(FunctionDependency(expr=fcode, params=params))

    # XXX: hard_dep_exprs is used because it ultimately calls the
    # _get_hard_deps helper that extracts the proper dependency list
    # from types.
    _register_item(node, ctx=ctx, hard_dep_exprs=deps)


@trace_dependencies.register
def trace_default(
    node: qlast.CreateObject,
    *,
    ctx: DepTraceContext,
) -> None:
    # Generic DDL catchall
    _register_item(node, ctx=ctx)


def _clear_nonessential_subcommands(node: qlast.DDLOperation) -> None:
    node.commands = [
        cmd for cmd in node.commands
        if isinstance(cmd, qlast.SetField) and cmd.name.startswith('orig_')
    ]


def _register_item(
    decl: qlast.DDLOperation,
    *,
    deps: Optional[AbstractSet[s_name.QualName]] = None,
    hard_dep_exprs: Optional[Iterable[Dependency]] = None,
    loop_control: Optional[s_name.QualName] = None,
    anchors: Optional[Mapping[str, s_name.QualName]] = None,
    source: Optional[s_name.QualName] = None,
    subject: Optional[s_name.QualName] = None,
    ctx: DepTraceContext,
) -> None:

    name, fq_name = ctx.get_fq_name(decl)

    if fq_name in ctx.ddlgraph:
        vn = get_verbosename_from_fqname(fq_name, ctx)
        msg = f'{vn} was already declared'
        raise errors.InvalidDefinitionError(msg, span=decl.span)

    if deps:
        deps = set(deps)
    else:
        deps = set()

    weak_deps: set[s_name.QualName] = set()

    op = orig_op = copy.copy(decl)

    if ctx.depstack:
        if isinstance(op, qlast.CreateObject):
            op.sdl_alter_if_exists = True
        top_parent = parent = copy.copy(ctx.depstack[0][0])
        _clear_nonessential_subcommands(parent)
        for entry, _ in ctx.depstack[1:]:
            entry_op = copy.copy(entry)
            parent.commands.append(entry_op)
            parent = entry_op
            _clear_nonessential_subcommands(parent)

        parent.commands.append(op)
        op = top_parent
    else:
        assert isinstance(op, (qlast.Query, qlast.Command, qlast.DDLCommand))
        op.aliases = [qlast.ModuleAliasDecl(alias=None, module=ctx.module)]

    assert isinstance(op, qlast.DDLCommand)
    node = topological.DepGraphEntry(
        item=op,
        deps={n for _, n in ctx.depstack if n != loop_control},
        extra=False,
    )
    ctx.ddlgraph[fq_name] = node

    if hasattr(decl, "bases"):
        # add parents to dependencies
        parents = ctx.parents.get(fq_name)
        if parents is not None:
            deps.update(parents)

    if ctx.depstack:
        # all ancestors should be seen as dependencies
        ancestor_bases = ctx.ancestors.get(ctx.depstack[-1][1])
        if ancestor_bases:
            for ancestor_base in ancestor_bases:
                base_item = qltracer.qualify_name(ancestor_base, name)
                if base_item in ctx.objects:
                    deps.add(base_item)

    ast_subcommands = getattr(decl, 'commands', [])
    commands = []
    if ast_subcommands:
        subcmds: list[qlast.DDLOperation] = []
        for cmd in ast_subcommands:
            # include dependency on constraints or annotations if present
            if isinstance(cmd, qlast.CreateConcreteConstraint):
                cmd_name = ctx.get_local_name(cmd.name)
                if cmd_name.get_module_name() not in s_schema.STD_MODULES:
                    deps.add(cmd_name)
            elif isinstance(cmd, qlast.CreateAnnotationValue):
                cmd_name = ctx.get_local_name(cmd.name)
                if cmd_name.get_module_name() not in s_schema.STD_MODULES:
                    deps.add(cmd_name)

            if (isinstance(cmd, qlast.ObjectDDL)
                    # HACK: functions don't have alters at the moment
                    and not isinstance(decl, qlast.CreateFunction)):
                subcmds.append(cmd)
            elif (isinstance(cmd, qlast.SetField)
                  and not cmd.special_syntax
                  and not isinstance(cmd.value, qlast.BaseConstant)
                  and not isinstance(
                      op, (qlast.CreateAlias, qlast.CreateGlobal))):
                subcmds.append(cmd)
            else:
                commands.append(cmd)

        if subcmds:
            assert isinstance(decl, qlast.ObjectDDL)
            alter_name = f"Alter{decl.__class__.__name__[len('Create'):]}"
            alter_cls = getattr(qlast, alter_name)
            alter_cmd: qlast.ObjectDDL = alter_cls(name=decl.name)

            # indexes need to preserve their "on" expression
            if isinstance(decl, qlast.CreateConcreteIndex):
                assert isinstance(alter_cmd, qlast.ConcreteIndexCommand)
                alter_cmd.expr = decl.expr
                alter_cmd.kwargs = decl.kwargs

            # constraints need to preserve their "on" expression
            if isinstance(decl, qlast.CreateConcreteConstraint):
                assert isinstance(alter_cmd, qlast.ConcreteConstraintOp)
                alter_cmd.subjectexpr = decl.subjectexpr
                alter_cmd.args = decl.args

            # functions need to preserve arguments
            if isinstance(decl, qlast.CreateFunction):
                assert isinstance(alter_cmd, qlast.FunctionCommand)
                alter_cmd.params = decl.params

            if not ctx.depstack:
                alter_cmd.aliases = [
                    qlast.ModuleAliasDecl(alias=None, module=ctx.module)
                ]

            ctx.depstack.append((alter_cmd, fq_name))

            for cmd in subcmds:
                trace_dependencies(cmd, ctx=ctx)

            ctx.depstack.pop()

    if hard_dep_exprs:
        anchors = dict(anchors or {})
        if source:
            anchors['__source__'] = source
        if subject or (
            fq_name
            and not (
                isinstance(decl, qlast.SetField) and decl.name == 'default'
            )
        ):
            anchors['__subject__'] = subject or fq_name

        for expr in hard_dep_exprs:
            if isinstance(expr, TypeDependency):
                deps |= _get_hard_deps(expr.texpr, ctx=ctx)
            elif isinstance(expr, ExprDependency):
                qlexpr = expr.expr
                params: Mapping[str, qlast.TypeExpr]
                if isinstance(expr, FunctionDependency):
                    params = expr.params
                else:
                    params = {}

                strong_tdeps, weak_tdeps = qltracer.trace_refs(
                    qlexpr,
                    schema=ctx.schema,
                    module=ctx.module,
                    path_prefix=source,
                    anchors=anchors,
                    objects=ctx.objects,
                    pointers=ctx.pointers,
                    local_modules=ctx.local_modules,
                    params=params,
                )

                for tdeps, strong in (
                    (strong_tdeps, True), (weak_tdeps, False)
                ):
                    pdeps: MutableSet[s_name.QualName] = set()
                    for dep in tdeps:
                        # ignore std module dependencies
                        if dep.get_module_name() not in s_schema.STD_MODULES:
                            # First check if the dep is a pointer that's
                            # defined explicitly. If it's not explicitly
                            # defined, check for ancestors and use them
                            # instead.
                            #
                            # FIXME: Ideally we should use the closest
                            # ancestor, instead of all of them, but
                            # including all is still correct.
                            if '@' in dep.name:
                                pdeps |= _get_pointer_deps(dep, ctx=ctx)
                            else:
                                pdeps.add(dep)

                    # Handle the pre-processed deps now.
                    cdeps = deps if strong else weak_deps
                    for dep in pdeps:
                        cdeps.add(dep)

                        if isinstance(
                                decl, (qlast.CreateAlias, qlast.CreateGlobal)):
                            # If the declaration is a view, we need to be
                            # dependent on all the types and their props
                            # used in the view.
                            vdeps = {dep} | ctx.ancestors.get(dep, set())
                            for vdep in vdeps:
                                cdeps |= ctx.defdeps.get(vdep, set())

                        if (
                            isinstance(decl, (
                                qlast.CreateConcretePointer,
                                qlast.CreateGlobal))
                            and isinstance(decl.target, qlast.Expr)
                        ) or isinstance(
                            decl, (
                                qlast.CreateAccessPolicy, qlast.CreateTrigger)
                        ):
                            # If the declaration is a computable pointer/global
                            # or access policy (XXX: trigger?), we need to
                            # include the possible constraints for every
                            # dependency that it lists. This ensures that any
                            # other links/props that this computable uses have
                            # all of their constraints defined before the
                            # computable, so that the cardinality can be
                            # inferred correctly.
                            con_deps = {dep} | ctx.ancestors.get(dep, set())
                            for con_dep in con_deps:
                                cdeps |= ctx.constraints.get(con_dep, set())
            else:
                raise AssertionError(f'unexpected dependency type: {expr!r}')

    orig_op.commands = commands

    if loop_control:
        parent_node = ctx.ddlgraph[loop_control]
        parent_node.loop_control.add(fq_name)

    node.deps |= deps
    node.weak_deps |= weak_deps - {fq_name}


def _get_pointer_deps(
    pointer: s_name.QualName,
    *,
    ctx: DepTraceContext,
) -> MutableSet[s_name.QualName]:
    result: MutableSet[s_name.QualName] = set()
    owner_name, ptr_name = pointer.name.split('@', 1)
    # For every ancestor of the type, where
    # the pointer is defined, see if there are
    # ancestors of the pointer itself defined.
    for tansc in ctx.ancestors.get(
            s_name.QualName(
                module=pointer.module, name=owner_name
            ), set()):
        ptr_ansc = s_name.QualName(
            module=tansc.module,
            name=f'{tansc.name}@{ptr_name}',
        )

        # Only add the pointer's ancestor if
        # it is explicitly defined.
        if ptr_ansc in ctx.objects:
            result.add(ptr_ansc)

    # Only add the pointer if it is explicitly defined.
    if pointer in ctx.objects:
        result.add(pointer)

    # HACK: Add all pointers that have this pointer (link, actually)
    # as their prefix. The assumption is that depending on a link
    # typically also means depending on the link's properties.
    # This will *also* grab any constraints on the pointer, which
    # is important for properly doing cardinality inference
    # on expressions involving it.
    # PERF: We should avoid actually searching all the objects.
    for propname, prop in ctx.objects.items():
        if (
            str(propname).startswith(str(pointer) + '@')
            and not isinstance(prop, qltracer.Field)
        ):
            result.add(propname)

    return result


def _get_hard_deps(
    expr: qlast.TypeExpr, *, ctx: DepTraceContext
) -> MutableSet[s_name.QualName]:
    deps: MutableSet[s_name.QualName] = set()

    if isinstance(expr, qlast.TypeName):

        # Special case for `enum`.
        # Don't trace at all: neither `enum` nor `VariantA` is a resolvable
        # name. This case will fail later, saying that you need to declare
        # a new type.
        if qlutils.is_enum(expr):
            return deps

        # We care about subtype dependencies, because
        # they can either be custom scalars or illegal
        # ObjectTypes (in which case the error message will
        # depend on dependency tracing).
        if expr.subtypes:
            for subtype in expr.subtypes:
                deps |= _get_hard_deps(subtype, ctx=ctx)

        else:
            # Base case.
            name = ctx.get_ref_name(expr.maintype)
            if name.get_module_name() not in s_schema.STD_MODULES:
                deps.add(name)

    elif isinstance(expr, qlast.TypeExprLiteral):
        pass

    elif isinstance(expr, qlast.TypeOf):
        # TODO: maybe we should also recurse into the inner expr?
        pass

    elif isinstance(expr, qlast.TypeOp):
        deps |= _get_hard_deps(expr.left, ctx=ctx)
        deps |= _get_hard_deps(expr.right, ctx=ctx)

    return deps


def _get_bases(
    decl: qlast.CreateObject, *, ctx: LayoutTraceContext
) -> list[s_name.QualName]:
    """Resolve object bases from the "extends" declaration."""
    if not isinstance(decl, qlast.BasedOn):
        return []

    bases = []

    if decl.bases:
        # Explicit inheritance
        has_enums = any(qlutils.is_enum(br) for br in decl.bases)

        if has_enums:
            if len(decl.bases) > 1:
                raise errors.SchemaError(
                    "invalid scalar type definition, enumeration must "
                    "be the only supertype specified",
                    span=decl.bases[0].span,
                )

            bases = [s_name.QualName("std", "anyenum")]

        else:
            for base_ref in decl.bases:
                # Validate that the base actually exists.
                tracer_type = _get_tracer_type(decl)
                assert tracer_type is not None
                obj = _resolve_type_name(
                    base_ref.maintype,
                    tracer_type=tracer_type,
                    ctx=ctx
                )
                name = obj.get_name(ctx.schema)
                if not isinstance(name, s_name.QualName):
                    qname = s_name.QualName.from_string(name.name)
                else:
                    qname = name
                bases.append(qname)

    return bases


def _resolve_type_expr(
    texpr: qlast.TypeExpr,
    *,
    ctx: LayoutTraceContext | DepTraceContext,
) -> qltracer.TypeLike:

    if isinstance(texpr, qlast.TypeName):
        if texpr.subtypes:
            return qltracer.Type(
                name=s_name.QualName(module='__coll__', name=texpr.name or ''),
            )
        else:
            return cast(
                qltracer.TypeLike,
                _resolve_type_name(
                    texpr.maintype,
                    tracer_type=qltracer.Type,
                    ctx=ctx,
                )
            )

    elif isinstance(texpr, qlast.TypeOp):

        if texpr.op == qlast.TypeOpName.OR:
            return qltracer.UnionType([
                _resolve_type_expr(texpr.left, ctx=ctx),
                _resolve_type_expr(texpr.right, ctx=ctx),
            ])

        if texpr.op == qlast.TypeOpName.AND:
            return qltracer.IntersectionType([
                _resolve_type_expr(texpr.left, ctx=ctx),
                _resolve_type_expr(texpr.right, ctx=ctx),
            ])

        else:
            raise NotImplementedError(
                f'unsupported type operation: {texpr.op}')

    else:
        raise NotImplementedError(
            f'unsupported type expression: {texpr!r}'
        )


TRACER_TO_REAL_TYPE_MAP = {
    qltracer.Type: s_types.Type,
    qltracer.ObjectType: s_objtypes.ObjectType,
    qltracer.ScalarType: s_scalars.ScalarType,
    qltracer.Constraint: s_constr.Constraint,
    qltracer.Annotation: s_anno.Annotation,
    qltracer.Property: s_props.Property,
    qltracer.Link: s_links.Link,
    qltracer.Index: s_indexes.Index,
}


def _get_local_obj(
    refname: s_name.QualName,
    tracer_type: type[qltracer.NamedObject],
    span: Optional[parsing.Span],
    *,
    ctx: LayoutTraceContext | DepTraceContext,
) -> Optional[qltracer.NamedObject]:

    obj = ctx.objects.get(refname)

    if isinstance(obj, s_pseudo.PseudoType):
        raise errors.SchemaError(
            f'invalid type: {obj.get_verbosename(ctx.schema)} is a generic '
            f'type and they are not supported in user-defined schema',
            span=span,
        )

    elif obj is not None and not isinstance(obj, tracer_type):
        obj_type = TRACER_TO_REAL_TYPE_MAP[type(obj)]
        real_type = TRACER_TO_REAL_TYPE_MAP[tracer_type]
        raise errors.InvalidReferenceError(
            f'{str(refname)!r} exists, but is '
            f'{english.add_a(obj_type.get_schema_class_displayname())}, '
            f'not {english.add_a(real_type.get_schema_class_displayname())}',
            span=span,
        )

    return obj


def _resolve_type_name(
    ref: qlast.BaseObjectRef,
    *,
    tracer_type: type[qltracer.NamedObject],
    ctx: LayoutTraceContext | DepTraceContext,
) -> qltracer.ObjectLike:

    refname = ctx.get_ref_name(ref)
    local_obj = _get_local_obj(refname, tracer_type, ref.span, ctx=ctx)
    obj: qltracer.ObjectLike
    if local_obj is not None:
        obj = local_obj
    else:
        obj = _resolve_schema_ref(
            refname,
            type=tracer_type,
            span=ref.span,
            ctx=ctx,
        )

    return obj


def _get_tracer_type(
    decl: qlast.CreateObject,
) -> Optional[type[qltracer.NamedObject]]:

    tracer_type: Optional[type[qltracer.NamedObject]] = None

    if isinstance(decl, qlast.CreateObjectType):
        tracer_type = qltracer.ObjectType
    elif isinstance(decl, qlast.CreateScalarType):
        tracer_type = qltracer.ScalarType
    elif isinstance(decl, (qlast.CreateConstraint,
                           qlast.CreateConcreteConstraint)):
        tracer_type = qltracer.Constraint
    elif isinstance(decl, (qlast.CreateAnnotation,
                           qlast.CreateAnnotationValue)):
        tracer_type = qltracer.Annotation
    elif isinstance(decl, qlast.CreateConcreteUnknownPointer):
        tracer_type = qltracer.Pointer
    elif isinstance(decl, (qlast.CreateProperty,
                           qlast.CreateConcreteProperty)):
        tracer_type = qltracer.Property
    elif isinstance(decl, (qlast.CreateLink,
                           qlast.CreateConcreteLink)):
        tracer_type = qltracer.Link
    elif isinstance(decl, qlast.CreatePermission):
        tracer_type = qltracer.Permission
    elif isinstance(decl, (qlast.CreateIndex,
                           qlast.CreateConcreteIndex)):
        tracer_type = qltracer.Index

    return tracer_type


def _validate_schema_ref(
    decl: qlast.CreateObject,
    *,
    ctx: LayoutTraceContext,
) -> None:
    refname = ctx.get_ref_name(decl.name)
    tracer_type = _get_tracer_type(decl)
    if tracer_type is None:
        # Bail out and rely on some other validation mechanism
        return

    local_obj = _get_local_obj(refname, tracer_type, decl.span, ctx=ctx)

    if local_obj is None:
        if (tracer_type is qltracer.Index and
                refname == s_indexes.DEFAULT_INDEX):
            return

        _resolve_schema_ref(
            refname,
            type=tracer_type,
            span=decl.span,
            ctx=ctx,
        )


def _resolve_schema_ref(
    name: s_name.Name,
    type: type[qltracer.NamedObject],
    span: Optional[parsing.Span],
    *,
    ctx: LayoutTraceContext | DepTraceContext,
) -> s_obj.SubclassableObject:
    real_type = TRACER_TO_REAL_TYPE_MAP[type]
    try:
        return ctx.schema.get(name, type=real_type, span=span)
    except errors.InvalidReferenceError as e:
        s_utils.enrich_schema_lookup_error(
            e,
            name,
            schema=ctx.schema,
            modaliases=ctx.modaliases,
            item_type=real_type,
            span=span,
        )
        raise


def _get_expr_field(decl: qlast.DDLOperation) -> Optional[qlast.Expr]:
    for cmd in decl.commands:
        if isinstance(cmd, qlast.SetField) and cmd.name == "expr":
            assert cmd.value, "sdl SetField should always have value"
            assert isinstance(cmd.value, qlast.Expr)
            return cmd.value
    return None


================================================
FILE: edb/edgeql/desugar_group.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Desugar GROUP queries into internal FOR GROUP queries.

This code is called by both the model and the real implementation,
though if that starts becoming a problem it should just be abandoned.
"""

from __future__ import annotations


from typing import Optional, AbstractSet

from edb import errors

from edb.common import ast
from edb.common import ordered
from edb.common.compiler import AliasGenerator

from edb.edgeql import ast as qlast
from edb.edgeql.compiler import astutils


def key_name(s: str) -> str:
    return s.split('~')[0]


def name_path(name: str) -> qlast.Path:
    return qlast.Path(steps=[qlast.ObjectRef(name=name)])


def make_free_object(els: dict[str, qlast.Expr]) -> qlast.Shape:
    return qlast.Shape(
        expr=None,
        elements=[
            qlast.ShapeElement(
                expr=qlast.Path(steps=[qlast.Ptr(name=name)]),
                compexpr=expr
            )
            for name, expr in els.items()
        ],
    )


def collect_grouping_atoms(
    els: list[qlast.GroupingElement],
) -> AbstractSet[str]:
    atoms: ordered.OrderedSet[str] = ordered.OrderedSet()

    def _collect_atom(el: qlast.GroupingAtom) -> None:
        if isinstance(el, qlast.GroupingIdentList):
            for at in el.elements:
                _collect_atom(at)

        else:
            assert isinstance(el, qlast.ObjectRef)
            atoms.add(el.name)

    def _collect_el(el: qlast.GroupingElement) -> None:
        if isinstance(el, qlast.GroupingSets):
            for sub in el.sets:
                _collect_el(sub)
        elif isinstance(el, qlast.GroupingOperation):
            for at in el.elements:
                _collect_atom(at)
        elif isinstance(el, qlast.GroupingSimple):
            _collect_atom(el.element)
        else:
            raise AssertionError('Unknown GroupingElement')

    for el in els:
        _collect_el(el)

    return atoms


def desugar_group(
    node: qlast.GroupQuery,
    aliases: AliasGenerator,
) -> qlast.InternalGroupQuery:
    assert not isinstance(node, qlast.InternalGroupQuery)
    by_alias_map: dict[str, tuple[str, qlast.Path]] = {}

    def rewrite_atom(el: qlast.GroupingAtom) -> qlast.GroupingAtom:
        if isinstance(el, qlast.ObjectRef):
            return el
        elif isinstance(el, qlast.Path):
            assert isinstance(el.steps[0], qlast.Ptr)
            ptrname = el.steps[0].name
            ptrtype = el.steps[0].type
            if ptrname not in by_alias_map:
                alias = aliases.get(ptrname)
                by_alias_map[ptrname] = (alias, el)
            else:
                alias = by_alias_map[ptrname][0]
                aliased_el = by_alias_map[ptrname][1]
                assert isinstance(aliased_el.steps[0], qlast.Ptr)
                aliased_el_ptrtype = aliased_el.steps[0].type
                if ptrtype != aliased_el_ptrtype:
                    raise errors.QueryError(
                        "BY clause cannot refer to link property and object "
                        "property with the same name",
                        span=el.span,
                    )
            return qlast.ObjectRef(name=alias)
        else:
            assert isinstance(el, qlast.GroupingIdentList)
            return qlast.GroupingIdentList(
                span=el.span,
                elements=tuple(rewrite_atom(at) for at in el.elements),
            )

    def rewrite(el: qlast.GroupingElement) -> qlast.GroupingElement:
        if isinstance(el, qlast.GroupingSimple):
            return qlast.GroupingSimple(
                span=el.span, element=rewrite_atom(el.element))
        elif isinstance(el, qlast.GroupingSets):
            return qlast.GroupingSets(
                span=el.span, sets=[rewrite(s) for s in el.sets])
        elif isinstance(el, qlast.GroupingOperation):
            return qlast.GroupingOperation(
                span=el.span,
                oper=el.oper,
                elements=[rewrite_atom(a) for a in el.elements])
        raise AssertionError

    # The rewrite calls on the grouping elements populate by_alias_map
    # with any bindings for pointers the BY clause refers to directly.
    by = [rewrite(by_el) for by_el in node.by]

    alias_map: dict[str, tuple[str, qlast.Expr]] = {
        k: v for k, v in by_alias_map.items()
    }

    for using_clause in (node.using or ()):
        if using_clause.alias in alias_map:
            # TODO: This would be a great place to allow multiple spans!
            raise errors.QueryError(
                f"USING clause binds a variable '{using_clause.alias}' "
                f"but a property with that name is used directly in the BY "
                f"clause",
                span=alias_map[using_clause.alias][1].span,
            )
        alias_map[using_clause.alias] = (using_clause.alias, using_clause.expr)

    using = []
    for alias, path in alias_map.values():
        using.append(qlast.AliasedExpr(alias=alias, expr=path))

    actual_keys = collect_grouping_atoms(by)

    g_alias = aliases.get('g')
    grouping_alias = aliases.get('grouping')
    output_dict = {
        'key': make_free_object({
            name: name_path(alias)
            for name, (alias, _) in alias_map.items()
            if alias in actual_keys
        }),
        'grouping': qlast.FunctionCall(
            func='array_unpack',
            args=[name_path(grouping_alias)],
        ),
        'elements': name_path(g_alias),
    }
    output_shape = make_free_object(output_dict)

    return qlast.InternalGroupQuery(
        span=node.span,
        aliases=node.aliases,
        subject_alias=node.subject_alias,
        subject=node.subject,
        # rewritten parts!
        using=using,
        by=by,
        group_alias=g_alias,
        grouping_alias=grouping_alias,
        result=output_shape,
        from_desugaring=True,
    )


def _count_alias_uses(
    node: qlast.Expr,
    alias: str,
) -> int:
    uses = 0
    for child in ast.find_children(node, qlast.Path):
        match child:
            case astutils.alias_view((alias2, _)) if alias == alias2:
                uses += 1
    return uses


def try_group_rewrite(
    node: qlast.Query,
    aliases: AliasGenerator,
) -> Optional[qlast.Query]:
    """
    Try to apply some syntactic rewrites of GROUP expressions so we
    can generate better code.

    The two key desugarings are:

    * Sink a shape into the internal group result

        SELECT (GROUP ...) <shape>
        [filter-clause] [order-clause] [other clauses]
        =>
        SELECT (
          FOR GROUP ...
          UNION <igroup-body> <shape>
          [filter-clause]
          [order-clause]
        ) [other clauses]

    * Convert a FOR over a group into just an internal group (and
      a trivial FOR)

        FOR g IN (GROUP ...) UNION <body>
        =>
        FOR GROUP ...
        UNION (
            FOR g IN (<igroup-body>)
            UNION <body>
        )
    """

    # Inline trivial uses of aliases bound to a group and then
    # immediately used, so that we can apply the other optimizations.
    match node:
        case qlast.SelectQuery(
            aliases=[
                *_,
                qlast.AliasedExpr(alias=alias, expr=qlast.GroupQuery() as grp)
            ] as qaliases,
            result=qlast.Shape(
                expr=astutils.alias_view((alias2, [])),
                elements=elements,
            ) as result,
        ) if alias == alias2 and _count_alias_uses(result, alias) == 1:
            node = node.replace(
                aliases=qaliases[:-1],
                result=qlast.Shape(expr=grp, elements=elements),
            )

        case qlast.ForQuery(
            aliases=[
                *_,
                qlast.AliasedExpr(alias=alias, expr=qlast.GroupQuery() as grp)
            ] as qaliases,
            iterator=astutils.alias_view((alias2, [])),
            result=result,
        ) if alias == alias2 and _count_alias_uses(result, alias) == 0:
            node = node.replace(
                aliases=qaliases[:-1],
                iterator=grp,
            )

    # Sink shapes into the GROUP
    if (
        isinstance(node, qlast.SelectQuery)
        and isinstance(node.result, qlast.Shape)
        and isinstance(node.result.expr, qlast.GroupQuery)
    ):
        igroup = desugar_group(node.result.expr, aliases)
        igroup = igroup.replace(result=qlast.Shape(
            expr=igroup.result, elements=node.result.elements))

        # FILTER and ORDER BY get sunk into the body of the FOR GROUP
        if node.where or node.orderby:
            igroup = igroup.replace(
                # We need to move the result_alias in case
                # the FILTER or ORDER BY depends on it.
                result_alias=node.result_alias,
                where=node.where,
                orderby=node.orderby,
            )

        return node.replace(
            result=igroup, result_alias=None, where=None, orderby=None)

    # Eliminate FORs over GROUPs
    if (
        isinstance(node, qlast.ForQuery)
        and isinstance(node.iterator, qlast.GroupQuery)
    ):
        igroup = desugar_group(node.iterator, aliases)
        new_result = qlast.ForQuery(
            iterator_alias=node.iterator_alias,
            iterator=igroup.result,
            result=node.result,
        )
        return igroup.replace(result=new_result, aliases=node.aliases)

    return None


================================================
FILE: edb/edgeql/parser/__init__.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations
from typing import Any, Callable, Optional, Mapping
import pathlib

from edb import errors
from edb.common import parsing

import edb._edgeql_parser as rust_parser

from .grammar import tokens

from .. import ast as qlast
from .. import tokenizer as qltokenizer


SPEC_LOADED = False


def append_module_aliases(
    command: qlast.Command, aliases: Mapping[Optional[str], str]
):
    modaliases: list[qlast.Alias] = []
    for alias, module in aliases.items():
        decl = qlast.ModuleAliasDecl(module=module, alias=alias)
        modaliases.append(decl)

    if not command.aliases:
        command.aliases = modaliases
    else:
        command.aliases = modaliases + command.aliases


def parse_fragment(
    source: qltokenizer.Source | str,
    filename: Optional[str] = None,
) -> qlast.Expr:
    res = parse(tokens.T_STARTFRAGMENT, source, filename=filename)
    assert isinstance(res, qlast.Expr)
    return res


def parse_query(
    source: qltokenizer.Source | str,
    module_aliases: Optional[Mapping[Optional[str], str]] = None,
) -> qlast.Query:
    """Parse some EdgeQL potentially adding some module aliases.

    This will parse EdgeQL queries and expressions. If the source is an
    expression, the result will be wrapped into a SelectQuery.
    """

    tree = parse_fragment(source)
    if not isinstance(tree, qlast.Query):
        tree = qlast.SelectQuery(result=tree)

    if module_aliases:
        append_module_aliases(tree, module_aliases)

    return tree


def parse_block(
    source: qltokenizer.Source | str,
    module_aliases: Optional[Mapping[Optional[str], str]] = None,
) -> list[qlast.Command]:
    node = parse(tokens.T_STARTBLOCK, source)
    assert isinstance(node, qlast.Commands), node
    if module_aliases:
        for command in node.commands:
            append_module_aliases(command, module_aliases)
    return node.commands


def parse_migration_body_block(
    source: str,
) -> tuple[qlast.NestedQLBlock, list[qlast.SetField]]:
    # For parser-internal technical reasons, we don't have a
    # production that means "just the *inside* of a migration block
    # (without braces)", so we just hack around this by adding braces.
    # This is only really workable because we only use this in a place
    # where the source contexts don't matter anyway.
    return parse(tokens.T_STARTMIGRATION, f"{{{source}}}")


def parse_extension_package_body_block(
    source: str,
) -> tuple[qlast.NestedQLBlock, list[qlast.SetField]]:
    # For parser-internal technical reasons, we don't have a
    # production that means "just the *inside* of an extension package
    # block (without braces)", so we just hack around this by adding
    # braces. This is only really workable because we only use this in
    # a place where the source contexts don't matter anyway.
    return parse(tokens.T_STARTEXTENSION, f"{{{source}}}")


def parse_sdl(expr: str):
    return parse(tokens.T_STARTSDLDOCUMENT, expr)


def parse(
    start_token: type[tokens.Token],
    source: str | qltokenizer.Source,
    filename: Optional[str] = None,
):
    if not SPEC_LOADED:
        preload_spec()

    if isinstance(source, str):
        source = qltokenizer.Source.from_string(source)

    start_name = start_token.__name__[2:]
    result, productions = rust_parser.parse(start_name, source.tokens())

    if len(result.errors) > 0:
        # TODO: emit multiple errors

        # Heuristic to pick the error:
        # - the only Unexpected, if it is a keyword
        # - first encountered,
        # - Unexpected before Missing,
        # - original order.
        errs = result.errors
        unexpected = [e for e in errs if e[0].startswith('Unexpected')]
        if len(unexpected) == 1 and unexpected[0][0].startswith(
            'Unexpected keyword'
        ):
            error = unexpected[0]
        else:
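            # Sort by the first span component (earliest first); on ties,
            # 'Unexpected...' sorts before 'Missing...' because the second
            # letter 'n' has a higher code point than 'i'.
            errs.sort(key=lambda e: (e[1][0], -ord(e[0][1])))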
            error = errs[0]

        message, span, hint, details = error
        position = qltokenizer.inflate_position(source.text(), span)

        parsing_span = parsing.Span(
            'query',
            source.text(),
            start=position[2],
            end=position[3] or position[2],
            context_lines=10,
        )
        raise errors.EdgeQLSyntaxError(
            message,
            position=position,
            hint=hint,
            details=details,
            span=parsing_span,
        )

    assert isinstance(result.out, rust_parser.CSTNode)
    return _cst_to_ast(
        result.out,
        productions,
        source,
        filename,
    ).val


def _cst_to_ast(
    cst: rust_parser.CSTNode,
    productions: list[tuple[type, Callable]],
    source: qltokenizer.Source,
    filename: Optional[str],
) -> Any:
    # Converts CST into AST by calling methods from the grammar classes.
    #
    # This function was originally written as a simple recursion, but it
    # had to be unfolded into a loop because it was hitting the recursion
    # limit. The stack here contains all remaining things to do:
    # - a CST node means the node has to be processed and pushed onto the
    #   result stack,
    # - a production means that all args of the production have been
    #   processed and are ready to be passed to the production method.
    #   The result is pushed onto the result stack.

    stack: list[rust_parser.CSTNode | rust_parser.Production] = [cst]
    result: list[Any] = []

    while len(stack) > 0:
        node = stack.pop()

        if isinstance(node, rust_parser.CSTNode):
            # this would be the body of the original recursion function

            if terminal := node.terminal:
                # Terminal is simple: just convert to parsing.Token
                span = parsing.Span(
                    filename=filename,
                    buffer=source.text(),
                    start=terminal.start,
                    end=terminal.end,
                )
                result.append(
                    parsing.Token(terminal.text, terminal.value, span)
                )

            elif production := node.production:
                # Production needs to first process all args, then
                # call the appropriate method.
                # (this is all in reverse, because stacks)
                stack.append(production)
                args = list(production.args)
                args.reverse()
                stack.extend(args)
            else:
                raise NotImplementedError(node)

        elif isinstance(node, rust_parser.Production):
            # production args are done, get them out of result stack
            len_args = len(node.args)
            split_at = len(result) - len_args
            args = result[split_at:]
            result = result[0:split_at]

            # find correct method to call
            production_id = node.id
            non_term_type, method = productions[production_id]
            sym = non_term_type()

            # init the span onto the Nonterm object, so it can be accessed by
            # production methods to construct nodes
            if node.start is not None and node.end is not None:
                sym.span = parsing.Span(
                    filename=filename,
                    buffer=source.text(),
                    start=node.start,
                    end=node.end,
                )
            else:
                sym.span = None

            method(sym, *args)

            # a helper to set the span of each constructed node, so we don't
            # have to manually set the span of things assigned to nonterm.val
            if sym.span and isinstance(sym.val, qlast.Base):
                sym.val.span = sym.span

            # push into result stack
            result.append(sym)

    return result.pop()


def preload_spec() -> None:
    global SPEC_LOADED
    path = get_spec_filepath()
    rust_parser.preload_spec(path)
    SPEC_LOADED = True


def get_spec_filepath():
    "Returns an absolute path to the serialized grammar spec file"

    edgeql_dir = pathlib.Path(__file__).parent.parent
    return str(edgeql_dir / 'grammar.bc')


================================================
FILE: edb/edgeql/parser/grammar/.gitignore
================================================
*.log
*.pickle
*.dot


================================================
FILE: edb/edgeql/parser/grammar/__init__.py
================================================
##
# Copyright (c) 2015-present MagicStack Inc.
# All rights reserved.
#
# See LICENSE for details.
##

from __future__ import annotations

from . import start as start  # noqa


================================================
FILE: edb/edgeql/parser/grammar/commondl.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2019-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

import sys
import types
import typing


from edb.errors import EdgeQLSyntaxError

from edb.edgeql import ast as qlast
from edb.edgeql import qltypes

from edb.common import parsing
from edb.common import verutils

from . import expressions
from . import tokens

from .precedence import *  # NOQA
from .tokens import *  # NOQA
from .expressions import *  # NOQA


Nonterm = expressions.Nonterm  # type: ignore[misc]


def _parse_language(node):
    lang = node.val.upper()
    if lang == 'EDGEQL':
        return qlast.Language.EdgeQL
    if lang == 'SQL':
        return qlast.Language.SQL
    raise EdgeQLSyntaxError(
        f'{node.val} is not a valid language',
        span=node.span) from None


def _validate_declarations(
    declarations: typing.Sequence[
        qlast.ModuleDeclaration | qlast.ObjectDDL]
) -> None:
    # Check that top-level declarations either use fully-qualified
    # names or are module blocks.
    for decl in declarations:
        if (
            not isinstance(
                decl,
                (qlast.ModuleDeclaration, qlast.ExtensionCommand,
                 qlast.FutureCommand)
            ) and decl.name.module is None
        ):
            raise EdgeQLSyntaxError(
                "only fully-qualified name is allowed in "
                "top-level declaration",
                span=decl.name.span)


def extract_bases(bases, commands):
    vbases = bases
    vcommands = []
    for command in commands:
        if isinstance(command, qlast.AlterAddInherit):
            if vbases:
                raise EdgeQLSyntaxError(
                    "specifying EXTENDING twice is not allowed",
                    span=command.span)
            vbases = command.bases
        else:
            vcommands.append(command)
    return vbases, vcommands


class NewNontermHelper:
    def __init__(self, modname):
        self.name = modname

    def _new_nonterm(
        self, clsname, clsdict=None, clskwds=None, clsbases=(Nonterm,)
    ):
        if clsdict is None:
            clsdict = {}
        if clskwds is None:
            clskwds = {}
        mod = sys.modules[self.name]

        def clsexec(ns):
            ns['__module__'] = self.name
            for k, v in clsdict.items():
                ns[k] = v
            return ns

        cls = types.new_class(clsname, clsbases, clskwds, clsexec)
        setattr(mod, clsname, cls)
        return cls


class Semicolons(Nonterm):
    # one or more semicolons
    @parsing.inline(0)
    def reduce_SEMICOLON(self, tok):
        pass

    @parsing.inline(0)
    def reduce_Semicolons_SEMICOLON(self, semicolons, semicolon):
        pass


class OptSemicolons(Nonterm):
    @parsing.inline(0)
    def reduce_Semicolons(self, semicolons):
        pass

    def reduce_empty(self):
        self.val = None


class ExtendingSimple(Nonterm):
    @parsing.inline(1)
    def reduce_EXTENDING_SimpleTypeNameList(self, _, list):
        pass


class OptExtendingSimple(Nonterm):
    @parsing.inline(0)
    def reduce_ExtendingSimple(self, extending):
        pass

    def reduce_empty(self):
        self.val = []


class Extending(Nonterm):
    @parsing.inline(1)
    def reduce_EXTENDING_TypeNameList(self, _, list):
        pass


class OptExtending(Nonterm):
    @parsing.inline(0)
    def reduce_Extending(self, extending):
        pass

    def reduce_empty(self):
        self.val = []


class CreateSimpleExtending(Nonterm):
    def reduce_EXTENDING_SimpleTypeNameList(self, *kids):
        self.val = qlast.AlterAddInherit(bases=kids[1].val)


class OnExpr(Nonterm):
    # NOTE: the reason why we need parentheses around the expression
    # is to disambiguate whether the '{' following the expression is
    # meant to be a shape or a nested DDL/SDL block.
    @parsing.inline(1)
    def reduce_ON_ParenExpr(self, _, expr):
        pass


class OptOnExpr(Nonterm):
    def reduce_empty(self):
        self.val = None

    @parsing.inline(0)
    def reduce_OnExpr(self, expr):
        pass


class OptDeferred(Nonterm):
    def reduce_empty(self):
        self.val = None

    def reduce_DEFERRED(self, _):
        self.val = True


class OptExceptExpr(Nonterm):
    def reduce_empty(self):
        self.val = None

    @parsing.inline(1)
    def reduce_EXCEPT_ParenExpr(self, _, expr):
        pass


class OptConcreteConstraintArgList(Nonterm):
    @parsing.inline(1)
    def reduce_LPAREN_OptPosCallArgList_RPAREN(self, _lparen, list, _rparen):
        pass

    def reduce_empty(self):
        self.val = []


class OptDefault(Nonterm):
    def reduce_empty(self):
        self.val = None

    @parsing.inline(1)
    def reduce_EQUALS_Expr(self, _, expr):
        pass


class ParameterKind(Nonterm):
    def reduce_VARIADIC(self, *kids):
        self.val = qltypes.ParameterKind.VariadicParam

    def reduce_NAMEDONLY(self, _):
        self.val = qltypes.ParameterKind.NamedOnlyParam


class OptParameterKind(Nonterm):
    def reduce_empty(self):
        self.val = qltypes.ParameterKind.PositionalParam

    @parsing.inline(0)
    def reduce_ParameterKind(self, *kids):
        pass


class FuncDeclArgName(Nonterm):
    def reduce_Identifier(self, dp):
        self.val = dp.val
        self.span = dp.span

    def reduce_PARAMETER(self, dp):
        if dp.val[1].isdigit():
            raise EdgeQLSyntaxError(
                'numeric parameters are not supported',
                span=dp.span)
        else:
            raise EdgeQLSyntaxError(
                f"function parameters do not need a $ prefix, "
                f"rewrite as '{dp.val[1:]}'",
                span=dp.span)


class FuncDeclArg(Nonterm):
    def reduce_kwarg(self, kind, name, _, typemod, type, default):
        r"""%reduce OptParameterKind FuncDeclArgName COLON \
                OptTypeQualifier FullTypeExpr OptDefault \
        """
        self.val = qlast.FuncParamDecl(
            kind=kind.val,
            name=name.val,
            typemod=typemod.val,
            type=type.val,
            default=default.val
        )

    def reduce_OptParameterKind_FuncDeclArgName_OptDefault(
        self, kind, name, default
    ):
        raise EdgeQLSyntaxError(
            f'missing type declaration for the `{name.val}` parameter',
            span=name.span
        )


class FuncDeclArgList(parsing.ListNonterm, element=FuncDeclArg,
                      separator=tokens.T_COMMA, allow_trailing_separator=True):
    pass


class FuncDeclArgs(Nonterm):
    @parsing.inline(0)
    def reduce_FuncDeclArgList(self, list):
        pass


class ProcessFunctionParamsMixin:
    def _validate_params(self, params):
        last_pos_default_arg = None
        last_named_arg = None
        variadic_arg = None
        names = set()

        for arg in params:
            if isinstance(arg, tuple):
                # A tuple here means the item is a "param := val"
                # argument, which is not valid in a parameter list.
                raise EdgeQLSyntaxError(
                    "Unexpected ':='",
                    span=arg[1])

            if arg.name in names:
                raise EdgeQLSyntaxError(
                    f'duplicate parameter name `{arg.name}`',
                    span=arg.span)
            names.add(arg.name)

            if arg.kind is qltypes.ParameterKind.VariadicParam:
                if variadic_arg is not None:
                    raise EdgeQLSyntaxError(
                        'more than one variadic argument',
                        span=arg.span)
                elif last_named_arg is not None:
                    raise EdgeQLSyntaxError(
                        f'NAMED ONLY argument `{last_named_arg.name}` '
                        f'before VARIADIC argument `{arg.name}`',
                        span=last_named_arg.span)
                else:
                    variadic_arg = arg

                if arg.default is not None:
                    raise EdgeQLSyntaxError(
                        f'VARIADIC argument `{arg.name}` '
                        f'cannot have a default value',
                        span=arg.span)

            elif arg.kind is qltypes.ParameterKind.NamedOnlyParam:
                last_named_arg = arg

            else:
                if last_named_arg is not None:
                    raise EdgeQLSyntaxError(
                        f'positional argument `{arg.name}` '
                        f'follows NAMED ONLY argument `{last_named_arg.name}`',
                        span=arg.span)

                if variadic_arg is not None:
                    raise EdgeQLSyntaxError(
                        f'positional argument `{arg.name}` '
                        f'follows VARIADIC argument `{variadic_arg.name}`',
                        span=arg.span)

            if arg.kind is qltypes.ParameterKind.PositionalParam:
                if arg.default is None:
                    if last_pos_default_arg is not None:
                        raise EdgeQLSyntaxError(
                            f'positional argument `{arg.name}` without '
                            f'default follows positional argument '
                            f'`{last_pos_default_arg.name}` with default',
                            span=arg.span)
                else:
                    last_pos_default_arg = arg


class CreateFunctionArgs(Nonterm, ProcessFunctionParamsMixin):
    def reduce_LPAREN_RPAREN(self, _lparen, _rparen):
        self.val = []

    def reduce_LPAREN_FuncDeclArgs_RPAREN(self, _lparen, args, _rparen):
        args = args.val
        self._validate_params(args)
        self.val = args


class OptTypeQualifier(Nonterm):
    def reduce_SET_OF(self, _s, _o):
        self.val = qltypes.TypeModifier.SetOfType

    def reduce_OPTIONAL(self, _):
        self.val = qltypes.TypeModifier.OptionalType

    def reduce_empty(self):
        self.val = qltypes.TypeModifier.SingletonType


class FunctionType(Nonterm):
    @parsing.inline(0)
    def reduce_FullTypeExpr(self, expr):
        pass


class FromFunction(Nonterm):
    def reduce_USING_ParenExpr(self, _, expr):
        lang = qlast.Language.EdgeQL
        self.val = qlast.FunctionCode(
            language=lang,
            nativecode=expr.val)

    def reduce_USING_Identifier_BaseStringConstant(self, _, ident, const):
        lang = _parse_language(ident)
        code = const.val.value
        self.val = qlast.FunctionCode(language=lang, code=code)

    def reduce_USING_Identifier_FUNCTION_BaseStringConstant(
        self, _using, ident, _function, const
    ):
        lang = _parse_language(ident)
        if lang != qlast.Language.SQL:
            raise EdgeQLSyntaxError(
                f'{lang} language is not supported in USING FUNCTION clause',
                span=ident.span) from None

        self.val = qlast.FunctionCode(
            language=lang,
            from_function=const.val.value
        )

    def reduce_USING_Identifier_EXPRESSION(self, _using, ident, _expression):
        lang = _parse_language(ident)
        if lang != qlast.Language.SQL:
            raise EdgeQLSyntaxError(
                f'{lang} language is not supported in USING clause',
                span=ident.span) from None

        self.val = qlast.FunctionCode(language=lang)


class ProcessFunctionBlockMixin:
    span: parsing.Span

    def _process_function_body(self, block, *, optional_using: bool = False):
        props: dict[str, typing.Any] = {}

        commands = []
        code = None
        nativecode = None
        language = qlast.Language.EdgeQL
        from_expr = False
        from_function = None

        for node in block.val:
            if isinstance(node, qlast.FunctionCode):
                if node.from_function:
                    if from_function is not None:
                        raise EdgeQLSyntaxError(
                            'more than one USING FUNCTION clause',
                            span=node.span)
                    from_function = node.from_function
                    language = qlast.Language.SQL

                elif node.nativecode:
                    if code is not None or nativecode is not None:
                        raise EdgeQLSyntaxError(
                            'more than one USING  clause',
                            span=node.span)
                    nativecode = node.nativecode
                    language = node.language

                elif node.code:
                    if code is not None or nativecode is not None:
                        raise EdgeQLSyntaxError(
                            'more than one USING  clause',
                            span=node.span)
                    code = node.code
                    language = node.language

                else:
                    # USING SQL EXPRESSION
                    from_expr = True
                    language = qlast.Language.SQL
            else:
                commands.append(node)

        if (
            nativecode is None and
            code is None and
            from_function is None and
            not from_expr and
            not optional_using
        ):
            raise EdgeQLSyntaxError(
                'missing a USING clause',
                span=block.span)

        else:
            if from_expr and (from_function or code):
                raise EdgeQLSyntaxError(
                    'USING SQL EXPRESSION is mutually exclusive with other '
                    'USING variants',
                    span=block.span)

            props['code'] = qlast.FunctionCode(
                language=language,
                from_function=from_function,
                from_expr=from_expr,
                code=code,
                span=self.span,
            )

            props['nativecode'] = nativecode

        if commands:
            props['commands'] = commands

        return props


#
# CREATE TYPE ... { CREATE LINK ... { ON TARGET DELETE ...
#
class OnTargetDeleteStmt(Nonterm):
    def reduce_ON_TARGET_DELETE_RESTRICT(self, *_):
        self.val = qlast.OnTargetDelete(
            cascade=qltypes.LinkTargetDeleteAction.Restrict)

    def reduce_ON_TARGET_DELETE_DELETE_SOURCE(self, *_):
        self.val = qlast.OnTargetDelete(
            cascade=qltypes.LinkTargetDeleteAction.DeleteSource)

    def reduce_ON_TARGET_DELETE_ALLOW(self, *_):
        self.val = qlast.OnTargetDelete(
            cascade=qltypes.LinkTargetDeleteAction.Allow)

    def reduce_ON_TARGET_DELETE_DEFERRED_RESTRICT(self, *_):
        self.val = qlast.OnTargetDelete(
            cascade=qltypes.LinkTargetDeleteAction.DeferredRestrict)


class OnSourceDeleteStmt(Nonterm):
    def reduce_ON_SOURCE_DELETE_DELETE_TARGET(self, *_):
        self.val = qlast.OnSourceDelete(
            cascade=qltypes.LinkSourceDeleteAction.DeleteTarget)

    def reduce_ON_SOURCE_DELETE_ALLOW(self, *_):
        self.val = qlast.OnSourceDelete(
            cascade=qltypes.LinkSourceDeleteAction.Allow)

    def reduce_ON_SOURCE_DELETE_DELETE_TARGET_IF_ORPHAN(self, *_):
        self.val = qlast.OnSourceDelete(
            cascade=qltypes.LinkSourceDeleteAction.DeleteTargetIfOrphan)


class OptWhenBlock(Nonterm):
    @parsing.inline(1)
    def reduce_WHEN_ParenExpr(self, _, expr):
        pass

    def reduce_empty(self):
        self.val = None


class OptUsingBlock(Nonterm):
    @parsing.inline(1)
    def reduce_USING_ParenExpr(self, _, expr):
        pass

    def reduce_empty(self):
        self.val = None


class AccessKind(Nonterm):
    val: list[qltypes.AccessKind]

    def reduce_ALL(self, _):
        self.val = list(qltypes.AccessKind)

    def reduce_SELECT(self, _):
        self.val = [qltypes.AccessKind.Select]

    def reduce_UPDATE(self, _):
        self.val = [
            qltypes.AccessKind.UpdateRead, qltypes.AccessKind.UpdateWrite]

    def reduce_UPDATE_READ(self, _u, _r):
        self.val = [qltypes.AccessKind.UpdateRead]

    def reduce_UPDATE_WRITE(self, _u, _w):
        self.val = [qltypes.AccessKind.UpdateWrite]

    def reduce_INSERT(self, _):
        self.val = [qltypes.AccessKind.Insert]

    def reduce_DELETE(self, _):
        self.val = [qltypes.AccessKind.Delete]


class AccessKindList(parsing.ListNonterm, element=AccessKind,
                     separator=tokens.T_COMMA):
    val: list[list[qltypes.AccessKind]]


class AccessPolicyAction(Nonterm):

    def reduce_ALLOW(self, _):
        self.val = qltypes.AccessPolicyAction.Allow

    def reduce_DENY(self, _):
        self.val = qltypes.AccessPolicyAction.Deny


class TriggerTiming(Nonterm):
    def reduce_AFTER(self, *kids):
        self.val = qltypes.TriggerTiming.After

    def reduce_AFTER_COMMIT_OF(self, *kids):
        self.val = qltypes.TriggerTiming.AfterCommitOf


class TriggerKind(Nonterm):
    def reduce_UPDATE(self, *kids):
        self.val = qltypes.TriggerKind.Update

    def reduce_INSERT(self, *kids):
        self.val = qltypes.TriggerKind.Insert

    def reduce_DELETE(self, *kids):
        self.val = qltypes.TriggerKind.Delete


class TriggerKindList(parsing.ListNonterm, element=TriggerKind,
                      separator=tokens.T_COMMA):
    pass


class TriggerScope(Nonterm):
    def reduce_EACH(self, *kids):
        self.val = qltypes.TriggerScope.Each

    def reduce_ALL(self, *kids):
        self.val = qltypes.TriggerScope.All


class RewriteKind(Nonterm):
    def reduce_UPDATE(self, *kids):
        self.val = qltypes.RewriteKind.Update

    def reduce_INSERT(self, *kids):
        self.val = qltypes.RewriteKind.Insert


class RewriteKindList(parsing.ListNonterm, element=RewriteKind,
                      separator=tokens.T_COMMA):
    pass


class ExtensionVersion(Nonterm):

    def reduce_VERSION_BaseStringConstant(self, _, const):
        version = const.val

        try:
            verutils.parse_version(version.value)
        except ValueError:
            raise EdgeQLSyntaxError(
                'invalid extension version format',
                details='Expected a SemVer-compatible format.',
                span=version.span,
            ) from None

        self.val = version


class OptExtensionVersion(Nonterm):

    @parsing.inline(0)
    def reduce_ExtensionVersion(self, version):
        pass

    def reduce_empty(self):
        self.val = None


class IndexArg(Nonterm):
    def reduce_kwarg_bad_definition(self, *kids):
        r"""%reduce FuncDeclArgName COLON \
                OptTypeQualifier FullTypeExpr OptDefault \
        """
        raise EdgeQLSyntaxError(
            'index parameters have to be NAMED ONLY',
            span=kids[0].span)

    def reduce_kwarg_definition(self, kind, name, _, typemod, type, default):
        r"""%reduce ParameterKind FuncDeclArgName COLON \
                OptTypeQualifier FullTypeExpr OptDefault \
        """
        if kind.val is not qltypes.ParameterKind.NamedOnlyParam:
            raise EdgeQLSyntaxError(
                'index parameters have to be NAMED ONLY',
                span=kind.span)

        self.val = qlast.FuncParamDecl(
            kind=kind.val,
            name=name.val,
            typemod=typemod.val,
            type=type.val,
            default=default.val
        )

    def reduce_AnyIdentifier_ASSIGN_Expr(self, ident, _, expr):
        self.val = (
            ident.val,
            ident.span,
            expr.val,
        )

    def reduce_FuncDeclArgName_OptDefault(self, name, default):
        raise EdgeQLSyntaxError(
            f'missing type declaration for the `{name.val}` parameter',
            span=name.span)


class IndexArgList(parsing.ListNonterm, element=IndexArg,
                   separator=tokens.T_COMMA, allow_trailing_separator=True):
    pass


class OptIndexArgList(Nonterm):
    @parsing.inline(0)
    def reduce_IndexArgList(self, list):
        pass

    def reduce_empty(self):
        self.val = []


class IndexExtArgList(Nonterm):

    @parsing.inline(1)
    def reduce_LPAREN_OptIndexArgList_RPAREN(self, *_):
        pass


class OptIndexExtArgList(Nonterm):

    @parsing.inline(0)
    def reduce_IndexExtArgList(self, list):
        pass

    def reduce_empty(self):
        self.val = []


class ProcessIndexMixin(ProcessFunctionParamsMixin):
    def _process_arguments(self, arguments):
        kwargs = {}
        for argval in arguments:
            if isinstance(argval, qlast.FuncParamDecl):
                raise EdgeQLSyntaxError(
                    f"unexpected new parameter definition `{argval.name}`",
                    span=argval.span)

            argname, argname_ctx, arg = argval
            if argname in kwargs:
                raise EdgeQLSyntaxError(
                    f"duplicate named argument `{argname}`",
                    span=argname_ctx)

            kwargs[argname] = arg

        return kwargs

    def _process_params_or_kwargs(self, bases, arguments):
        params = []
        kwargs = dict()

        # If the definition is extending another abstract index, then we
        # cannot define new parameters, but can only supply some arguments.
        if bases:
            kwargs = self._process_arguments(arguments)
        else:
            params = arguments
            self._validate_params(params)

        return params, kwargs

    def _process_sql_body(self, block, *, optional_using: bool = False):
        props: dict[str, typing.Any] = {}

        commands = []
        code = None

        for node in block.val:
            if isinstance(node, qlast.IndexCode):
                if code is not None:
                    raise EdgeQLSyntaxError(
                        'more than one USING  clause',
                        span=node.span)
                # Remember the clause so that a second one is rejected above.
                code = node
                props['code'] = node
            else:
                commands.append(node)

        if commands:
            props['commands'] = commands

        return props


================================================
FILE: edb/edgeql/parser/grammar/config.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

from edb import errors

from edb.edgeql import ast as qlast
from edb.edgeql import qltypes

from .expressions import Nonterm
from .tokens import *  # NOQA
from .expressions import *  # NOQA


class ConfigScope(Nonterm):

    def reduce_SESSION(self, _):
        self.val = qltypes.ConfigScope.SESSION

    def reduce_CURRENT_DATABASE(self, _c, _d):
        self.val = qltypes.ConfigScope.DATABASE

    def reduce_CURRENT_BRANCH(self, _c, _d):
        self.val = qltypes.ConfigScope.DATABASE

    def reduce_SYSTEM(self, _):
        self.val = qltypes.ConfigScope.INSTANCE

    def reduce_INSTANCE(self, _):
        self.val = qltypes.ConfigScope.INSTANCE


class ConfigOp(Nonterm):
    val: qlast.ConfigOp

    def reduce_SET_NodeName_ASSIGN_Expr(self, _s, name, _a, expr):
        self.val = qlast.ConfigSet(
            name=name.val,
            expr=expr.val,
        )

    def reduce_INSERT_NodeName_Shape(self, _, name, shape):
        self.val = qlast.ConfigInsert(
            name=name.val,
            shape=shape.val,
        )

    def reduce_RESET_NodeName_OptFilterClause(self, _, name, where):
        self.val = qlast.ConfigReset(
            name=name.val,
            where=where.val,
        )


class ConfigStmt(Nonterm):

    def reduce_CONFIGURE_DATABASE_ConfigOp(self, configure, database, _config):
        raise errors.EdgeQLSyntaxError(
            f"'{configure.val} {database.val}' is invalid syntax. "
            f"Did you mean '{configure.val} "
            f"{'current' if database.val[0] == 'd' else 'CURRENT'} "
            f"{database.val}'?",
            span=database.span)

    def reduce_CONFIGURE_BRANCH_ConfigOp(self, configure, branch, _config):
        # Match the case of the suggested keyword to what the user typed.
        raise errors.EdgeQLSyntaxError(
            f"'{configure.val} {branch.val}' is invalid syntax. "
            f"Did you mean '{configure.val} "
            f"{'current' if branch.val[0] == 'b' else 'CURRENT'} "
            f"{branch.val}'?",
            span=branch.span)

    def reduce_CONFIGURE_ConfigScope_ConfigOp(self, _, scope, op):
        self.val = op.val
        self.val.scope = scope.val

    def reduce_SET_GLOBAL_NodeName_ASSIGN_Expr(self, _s, _g, name, _a, expr):
        self.val = qlast.ConfigSet(
            name=name.val,
            expr=expr.val,
            scope=qltypes.ConfigScope.GLOBAL,
        )

    def reduce_RESET_GLOBAL_NodeName(self, _r, _g, name):
        self.val = qlast.ConfigReset(
            name=name.val,
            scope=qltypes.ConfigScope.GLOBAL,
        )


================================================
FILE: edb/edgeql/parser/grammar/ddl.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2015-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

import collections
import re
import textwrap
import typing

from edb import errors
from edb.errors import EdgeQLSyntaxError

from edb.edgeql import ast as qlast
from edb.edgeql import qltypes

from edb.common import span as edb_span
from edb.common import parsing

from . import expressions
from . import commondl
from . import tokens

from .precedence import *  # NOQA
from .tokens import *  # NOQA
from .commondl import *  # NOQA

from .sdl import *  # NOQA


Nonterm = expressions.Nonterm  # type: ignore[misc]
Semicolons = commondl.Semicolons  # type: ignore[misc]


sdl_nontem_helper = commondl.NewNontermHelper(__name__)
_new_nonterm = sdl_nontem_helper._new_nonterm


class DDLStmt(Nonterm):
    val: qlast.DDLCommand

    @parsing.inline(0)
    def reduce_DatabaseStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_BranchStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_RoleStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_ExtensionPackageStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_OptWithDDLStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_MigrationStmt(self, *_):
        pass


class DDLWithBlock(Nonterm):
    @parsing.inline(0)
    def reduce_WithBlock(self, *_):
        pass


class OptWithDDLStmt(Nonterm):
    def reduce_DDLWithBlock_WithDDLStmt(self, *kids):
        self.val = kids[1].val
        self.val.aliases = kids[0].val.aliases

    @parsing.inline(0)
    def reduce_WithDDLStmt(self, *_):
        pass


class WithDDLStmt(Nonterm):
    @parsing.inline(0)
    def reduce_InnerDDLStmt(self, *_):
        pass


class InnerDDLStmt(Nonterm):

    @parsing.inline(0)
    def reduce_CreatePseudoTypeStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_CreateScalarTypeStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_AlterScalarTypeStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_DropScalarTypeStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_CreateAnnotationStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_AlterAnnotationStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_DropAnnotationStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_CreateObjectTypeStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_AlterObjectTypeStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_DropObjectTypeStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_CreateAliasStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_AlterAliasStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_DropAliasStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_CreateConstraintStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_AlterConstraintStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_DropConstraintStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_CreateLinkStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_AlterLinkStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_DropLinkStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_CreatePropertyStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_AlterPropertyStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_DropPropertyStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_CreateModuleStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_AlterModuleStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_DropModuleStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_CreateFunctionStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_AlterFunctionStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_DropFunctionStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_CreateOperatorStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_AlterOperatorStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_DropOperatorStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_CreateCastStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_AlterCastStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_CreateGlobalStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_AlterGlobalStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_DropGlobalStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_CreatePermissionStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_AlterPermissionStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_DropPermissionStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_DropCastStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_ExtensionStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_FutureStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_CreateIndexStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_AlterIndexStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_DropIndexStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_CreateIndexMatchStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_DropIndexMatchStmt(self, *_):
        pass


class PointerName(Nonterm):
    @parsing.inline(0)
    def reduce_PtrNodeName(self, *kids):
        pass

    def reduce_DUNDERTYPE(self, *kids):
        self.val = qlast.ObjectRef(name=kids[0].val)


class UnqualifiedPointerName(Nonterm):
    def reduce_PointerName(self, *kids):
        if kids[0].val.module:
            raise EdgeQLSyntaxError(
                'unexpected fully-qualified name',
                span=kids[0].val.span)
        self.val = kids[0].val


class OptIfNotExists(Nonterm):
    def reduce_IF_NOT_EXISTS(self, *kids):
        self.val = True

    def reduce_empty(self, *kids):
        self.val = False


class ProductionTpl:
    def _passthrough(self, cmd):
        self.val = cmd.val

    def _singleton_list(self, cmd):
        self.val = [cmd.val]

    def _empty(self, *kids):
        self.val = []

    def _block(self, lbrace, cmdlist, sc2, rbrace):
        self.val = cmdlist.val

    def _block2(self, lbrace, sc1, cmdlist, sc2, rbrace):
        self.val = cmdlist.val


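def commands_block(parent, *commands, opt=True, production_tpl=ProductionTpl):
    """Generate the nonterminals for a block of sub-commands.

    Creates ``<parent>Command``, ``<parent>CommandsList`` and
    ``<parent>CommandsBlock``.  With ``opt=True`` an additional
    ``Opt<parent>CommandsBlock`` (a block or nothing) is created; with
    ``opt=False`` the block may also be a single command without braces.
    """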
    if parent is None:
        parent = ''

    clsdict = collections.OrderedDict()

    # Command := Command1 | Command2 ...
    #
    for command in commands:
        clsdict['reduce_{}'.format(command.__name__)] = \
            production_tpl._passthrough

    cmd = _new_nonterm(parent + 'Command', clsdict=clsdict)

    # CommandsList := Command [; Command ...]
    cmdlist = _new_nonterm(parent + 'CommandsList',
                           clsbases=(parsing.ListNonterm,),
                           clskwds=dict(element=cmd, separator=Semicolons))

    # CommandsBlock :=
    #
    #   { [ ; ] CommandsList ; }
    clsdict = collections.OrderedDict()
    clsdict['reduce_LBRACE_' + cmdlist.__name__ + '_OptSemicolons_RBRACE'] = \
        production_tpl._block
    clsdict['reduce_LBRACE_Semicolons_' + cmdlist.__name__ +
            '_OptSemicolons_RBRACE'] = \
        production_tpl._block2
    clsdict['reduce_LBRACE_OptSemicolons_RBRACE'] = \
        production_tpl._empty
    if not opt:
        #
        #   | Command
        clsdict['reduce_{}'.format(cmd.__name__)] = \
            production_tpl._singleton_list
    cmdblock = _new_nonterm(
        parent + 'CommandsBlock',
        clsdict=clsdict,
        clsbases=(Nonterm, production_tpl),
    )

    # OptCommandsBlock := CommandsBlock | <e>
    clsdict = collections.OrderedDict()
    clsdict['reduce_{}'.format(cmdblock.__name__)] = \
        production_tpl._passthrough
    clsdict['reduce_empty'] = production_tpl._empty

    if opt:
        _new_nonterm(
            'Opt' + parent + 'CommandsBlock',
            clsdict=clsdict,
            clsbases=(Nonterm, production_tpl),
        )


class NestedQLBlockStmt(Nonterm):
    val: qlast.DDLOperation

    def reduce_Stmt(self, stmt):
        if isinstance(stmt.val, qlast.Query):
            self.val = qlast.DDLQuery(query=stmt.val)
        else:
            self.val = stmt.val

    @parsing.inline(0)
    def reduce_OptWithDDLStmt(self, *_):
        pass

    @parsing.inline(0)
    def reduce_SetFieldStmt(self, *kids):
        pass


class NestedQLBlock(ProductionTpl):

    @property
    def allowed_fields(self) -> frozenset[str]:
        raise NotImplementedError

    @property
    def result(self) -> typing.Any:
        raise NotImplementedError

    def _process_body(self, body):
        fields = []
        stmts = []
        uniq_check = set()
        for stmt in body:
            if isinstance(stmt, qlast.SetField):
                if stmt.name not in self.allowed_fields:
                    raise errors.InvalidSyntaxError(
                        f'unexpected field: {stmt.name!r}',
                        span=stmt.span,
                    )
                if stmt.name in uniq_check:
                    raise errors.InvalidSyntaxError(
                        f'duplicate `SET {stmt.name} := ...`',
                        span=stmt.span,
                    )
                uniq_check.add(stmt.name)
                fields.append(stmt)
            else:
                stmts.append(stmt)

        return fields, stmts

    def _get_text(self, body):
        # XXX: Work around the Rust lexer returning byte offsets for
        # tokens instead of character offsets.
        src_start = body.span.start
        src_end = body.span.end
        buffer = body.span.buffer.encode('utf-8')
        text = buffer[src_start:src_end].decode('utf-8').strip().strip('}{\n')
        return textwrap.dedent(text).strip('\n')

    def _block(self, lbrace, cmdlist, sc2, rbrace):
        # LBRACE NestedQLBlock OptSemicolons RBRACE
        fields, stmts = self._process_body(cmdlist.val)
        body = qlast.NestedQLBlock(commands=stmts)

        kids = [lbrace, cmdlist, sc2, rbrace]
        body.span = (
            edb_span.merge_spans(k.span for k in kids if k.span)
            or edb_span.Span.empty()
        )

        body.text = self._get_text(body)
        self.val = self.result(body=body, fields=fields)

    def _block2(self, lbrace, sc1, cmdlist, sc2, rbrace):
        # LBRACE Semicolons NestedQLBlock OptSemicolons RBRACE
        fields, stmts = self._process_body(cmdlist.val)
        body = qlast.NestedQLBlock(commands=stmts)

        kids = [lbrace, sc1, cmdlist, sc2, rbrace]
        body.span = (
            edb_span.merge_spans(k.span for k in kids if k.span)
            or edb_span.Span.empty()
        )

        body.text = self._get_text(body)
        self.val = self.result(body=body, fields=fields)

    def _empty(self, *kids):
        # LBRACE OptSemicolons RBRACE | <e>
        self.val = []
        body = qlast.NestedQLBlock(commands=[])
        body.span = (
            edb_span.merge_spans(k.span for k in kids if k.span)
            or edb_span.Span.empty()
        )
        body.text = self._get_text(body)
        self.val = self.result(body=body, fields=[])


class UsingStmt(Nonterm):

    def reduce_USING_ParenExpr(self, *kids):
        self.val = qlast.SetField(
            name='expr',
            value=kids[1].val,
            special_syntax=True,
        )

    def reduce_RESET_EXPRESSION(self, *kids):
        self.val = qlast.SetField(
            name='expr',
            value=None,
            special_syntax=True,
        )


class SetFieldStmt(Nonterm):
    # SET field := <expr>
    def reduce_SET_Identifier_ASSIGN_GenExpr(self, *kids):
        self.val = qlast.SetField(
            name=kids[1].val.lower(),
            value=kids[3].val,
        )


class ResetFieldStmt(Nonterm):
    # RESET field
    def reduce_RESET_IDENT(self, *kids):
        self.val = qlast.SetField(
            name=kids[1].val.lower(),
            value=None,
        )

    def reduce_RESET_DEFAULT(self, *kids):
        self.val = qlast.SetField(
            name='default',
            value=None,
        )


class CreateAnnotationValueStmt(Nonterm):
    def reduce_CREATE_ANNOTATION_NodeName_ASSIGN_GenExpr(self, *kids):
        self.val = qlast.CreateAnnotationValue(
            name=kids[2].val,
            value=kids[4].val,
        )


class AlterAnnotationValueStmt(Nonterm):
    def reduce_ALTER_ANNOTATION_NodeName_ASSIGN_GenExpr(self, *kids):
        self.val = qlast.AlterAnnotationValue(
            name=kids[2].val,
            value=kids[4].val,
        )

    def reduce_ALTER_ANNOTATION_NodeName_DROP_OWNED(self, *kids):
        self.val = qlast.AlterAnnotationValue(
            name=kids[2].val,
        )
        self.val.commands = [qlast.SetField(
            name='owned',
            value=qlast.Constant.boolean(False, span=self.span),
            special_syntax=True,
        )]


class DropAnnotationValueStmt(Nonterm):
    def reduce_DROP_ANNOTATION_NodeName(self, *kids):
        self.val = qlast.DropAnnotationValue(
            name=kids[2].val,
        )


class RenameStmt(Nonterm):
    def reduce_RENAME_TO_NodeName(self, *kids):
        self.val = qlast.Rename(new_name=kids[2].val)


commands_block(
    'Create',
    UsingStmt,
    SetFieldStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
)


commands_block(
    'Alter',
    UsingStmt,
    RenameStmt,
    SetFieldStmt,
    ResetFieldStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    DropAnnotationValueStmt,
    opt=False)


class AlterAbstract(Nonterm):

    def reduce_DROP_ABSTRACT(self, *kids):
        # TODO: Raise a DeprecationWarning once we have facility for that.
        self.val = qlast.SetField(
            name='abstract',
            value=qlast.Constant.boolean(False, span=self.span),
            special_syntax=True,
        )

    def reduce_SET_NOT_ABSTRACT(self, *kids):
        self.val = qlast.SetField(
            name='abstract',
            value=qlast.Constant.boolean(False, span=self.span),
            special_syntax=True,
        )

    def reduce_SET_ABSTRACT(self, *kids):
        self.val = qlast.SetField(
            name='abstract',
            value=qlast.Constant.boolean(True, span=self.span),
            special_syntax=True,
        )

    def reduce_RESET_ABSTRACT(self, *kids):
        self.val = qlast.SetField(
            name='abstract',
            value=None,
            special_syntax=True,
        )


class OptPosition(Nonterm):
    def reduce_BEFORE_NodeName(self, *kids):
        self.val = qlast.Position(ref=kids[1].val, position='BEFORE')

    def reduce_AFTER_NodeName(self, *kids):
        self.val = qlast.Position(ref=kids[1].val, position='AFTER')

    def reduce_FIRST(self, *kids):
        self.val = qlast.Position(position='FIRST')

    def reduce_LAST(self, *kids):
        self.val = qlast.Position(position='LAST')

    def reduce_empty(self, *kids):
        self.val = None


class AlterSimpleExtending(Nonterm):
    def reduce_EXTENDING_SimpleTypeNameList_OptPosition(self, *kids):
        self.val = qlast.AlterAddInherit(
            bases=kids[1].val, position=kids[2].val
        )

    def reduce_DROP_EXTENDING_SimpleTypeNameList(self, *kids):
        self.val = qlast.AlterDropInherit(bases=kids[2].val)

    @parsing.inline(0)
    def reduce_AlterAbstract(self, *kids):
        pass


class AlterExtending(Nonterm):
    def reduce_EXTENDING_TypeNameList_OptPosition(self, *kids):
        self.val = qlast.AlterAddInherit(
            bases=kids[1].val, position=kids[2].val
        )

    def reduce_DROP_EXTENDING_TypeNameList(self, *kids):
        self.val = qlast.AlterDropInherit(bases=kids[2].val)

    @parsing.inline(0)
    def reduce_AlterAbstract(self, *kids):
        pass


class AlterOwnedStmt(Nonterm):

    def reduce_DROP_OWNED(self, *kids):
        self.val = qlast.SetField(
            name='owned',
            value=qlast.Constant.boolean(False, span=self.span),
            special_syntax=True,
        )

    def reduce_SET_OWNED(self, *kids):
        self.val = qlast.SetField(
            name='owned',
            value=qlast.Constant.boolean(True, span=self.span),
            special_syntax=True,
        )


#
# DATABASE
#


class DatabaseName(Nonterm):

    def reduce_Identifier(self, kid):
        self.val = qlast.ObjectRef(module=None, name=kid.val)

    def reduce_ReservedKeyword(self, *kids):
        name = kids[0].val
        if (
            name[:2] == '__' and name[-2:] == '__' and
            name not in {'__edgedbsys__', '__edgedbtpl__'}
        ):
            # There are a few reserved keywords like __std__ and __subject__
            # that can be used in paths but are prohibited to be used
            # anywhere else. So just as the tokenizer prohibits using
            # __names__ in general, we enforce the rule here for the
            # few remaining reserved __keywords__.
            raise EdgeQLSyntaxError(
                "identifiers surrounded by double underscores are forbidden",
                span=kids[0].span)

        self.val = qlast.ObjectRef(
            module=None,
            name=name
        )


class DatabaseStmt(Nonterm):

    @parsing.inline(0)
    def reduce_CreateDatabaseStmt(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_DropDatabaseStmt(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_AlterDatabaseStmt(self, *kids):
        pass


#
# CREATE DATABASE
#


commands_block(
    'CreateDatabase',
    SetFieldStmt,
)


class CreateDatabaseStmt(Nonterm):
    def reduce_CREATE_DATABASE_regular(self, *kids):
        """%reduce CREATE DATABASE DatabaseName OptCreateDatabaseCommandsBlock
        """
        self.val = qlast.CreateDatabase(
            name=kids[2].val,
            commands=kids[3].val,
            branch_type=qlast.BranchType.EMPTY,
            flavor='DATABASE',
        )

    # TODO: This one should probably not exist, and we'll get rid of
    # it once we merge Victor's new testing.
    def reduce_CREATE_DATABASE_from_template(self, *kids):
        """%reduce
            CREATE DATABASE DatabaseName FROM AnyNodeName
            OptCreateDatabaseCommandsBlock
        """
        _, _, _name, _, _template, _commands = kids
        self.val = qlast.CreateDatabase(
            name=kids[2].val,
            commands=kids[5].val,
            branch_type=qlast.BranchType.DATA,
            template=kids[4].val,
            flavor='DATABASE',
        )


#
# DROP DATABASE
#
class DropDatabaseStmt(Nonterm):
    def reduce_DROP_DATABASE_DatabaseName(self, *kids):
        self.val = qlast.DropDatabase(
            name=kids[2].val,
            flavor='DATABASE',
        )


#
# ALTER DATABASE
#


commands_block(
    'AlterDatabase',
    RenameStmt,
    opt=False
)


class AlterDatabaseStmt(Nonterm):
    def reduce_ALTER_DATABASE_DatabaseName_AlterDatabaseCommandsBlock(
        self, *kids
    ):
        _, _, name, commands = kids
        self.val = qlast.AlterDatabase(
            name=name.val,
            commands=commands.val,
        )


#
# BRANCH
#


class BranchStmt(Nonterm):

    @parsing.inline(0)
    def reduce_CreateBranchStmt(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_DropBranchStmt(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_AlterBranchStmt(self, *kids):
        pass

#
# CREATE BRANCH
#


class CreateBranchStmt(Nonterm):
    def reduce_CREATE_EMPTY_BRANCH_DatabaseName(self, *kids):
        self.val = qlast.CreateDatabase(
            name=kids[3].val,
            branch_type=qlast.BranchType.EMPTY,
        )

    def reduce_create_schema_branch(self, *kids):
        """%reduce
            CREATE SCHEMA BRANCH DatabaseName FROM DatabaseName
        """
        self.val = qlast.CreateDatabase(
            name=kids[3].val,
            template=kids[5].val,
            branch_type=qlast.BranchType.SCHEMA,
        )

    def reduce_create_data_branch(self, *kids):
        """%reduce
            CREATE DATA BRANCH DatabaseName FROM DatabaseName
        """
        self.val = qlast.CreateDatabase(
            name=kids[3].val,
            template=kids[5].val,
            branch_type=qlast.BranchType.DATA,
        )

    def reduce_create_template_branch(self, *kids):
        """%reduce
            CREATE TEMPLATE BRANCH DatabaseName FROM DatabaseName
        """
        self.val = qlast.CreateDatabase(
            name=kids[3].val,
            template=kids[5].val,
            branch_type=qlast.BranchType.TEMPLATE,
        )


#
# DROP BRANCH
#

BranchOptionsSpec = collections.namedtuple(
    'BranchOptionsSpec', ['force'])


class BranchOptions(Nonterm):
    # This is generalizable, but we don't bother generalizing it yet.
    def reduce_empty(self, *kids):
        self.val = BranchOptionsSpec(force=False)

    def reduce_FORCE(self, *kids):
        self.val = BranchOptionsSpec(force=True)


class DropBranchStmt(Nonterm):
    def reduce_DROP_BRANCH_DatabaseName_BranchOptions(self, *kids):
        _, _, name, options = kids
        self.val = qlast.DropDatabase(
            name=name.val,
            force=options.val.force,
        )


#
# ALTER BRANCH
#


commands_block(
    'AlterBranch',
    RenameStmt,
    opt=False
)


class AlterBranchStmt(Nonterm):
    def reduce_alter_branch(self, *kids):
        """%reduce
            ALTER BRANCH DatabaseName BranchOptions AlterBranchCommandsBlock
        """
        _, _, name, options, commands = kids
        self.val = qlast.AlterDatabase(
            name=name.val,
            commands=commands.val,
            force=options.val.force,
        )


#
# EXTENSION PACKAGE
#

class ExtensionPackageStmt(Nonterm):

    @parsing.inline(0)
    def reduce_CreateExtensionPackageStmt(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_DropExtensionPackageStmt(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_CreateExtensionPackageMigrationStmt(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_DropExtensionPackageMigrationStmt(self, *kids):
        pass


#
# CREATE EXTENSION PACKAGE
#
class ExtensionPackageBody(typing.NamedTuple):

    body: qlast.NestedQLBlock
    fields: list[qlast.SetField]


class CreateExtensionPackageBodyBlock(NestedQLBlock):

    @property
    def allowed_fields(self) -> frozenset[str]:
        return frozenset(
            {'internal', 'ext_module', 'sql_extensions', 'dependencies',
             'sql_setup_script', 'sql_teardown_script'}
        )

    @property
    def result(self) -> typing.Any:
        return ExtensionPackageBody


commands_block(
    'CreateExtensionPackage',
    NestedQLBlockStmt,
    opt=True,
    production_tpl=CreateExtensionPackageBodyBlock,
)


class CreateExtensionPackageStmt(Nonterm):

    def reduce_CreateExtensionPackageStmt(self, *kids):
        r"""%reduce CREATE EXTENSIONPACKAGE ShortNodeName
                    ExtensionVersion
                    OptCreateExtensionPackageCommandsBlock
        """
        self.val = qlast.CreateExtensionPackage(
            name=kids[2].val,
            version=kids[3].val,
            body=kids[4].val.body,
            commands=kids[4].val.fields,
        )


#
# DROP EXTENSION PACKAGE
#
class DropExtensionPackageStmt(Nonterm):

    def reduce_DropExtensionPackageStmt(self, *kids):
        r"""%reduce DROP EXTENSIONPACKAGE ShortNodeName ExtensionVersion"""
        self.val = qlast.DropExtensionPackage(
            name=kids[2].val,
            version=kids[3].val,
        )


#
# CREATE EXTENSION PACKAGE MIGRATION
#

class CreateExtensionPackageMigrationBodyBlock(NestedQLBlock):

    @property
    def allowed_fields(self) -> frozenset[str]:
        return frozenset(
            {'early_sql_script', 'late_sql_script'}
        )

    @property
    def result(self) -> typing.Any:
        return ExtensionPackageBody


commands_block(
    'CreateExtensionPackageMigration',
    NestedQLBlockStmt,
    opt=True,
    production_tpl=CreateExtensionPackageMigrationBodyBlock,
)


class CreateExtensionPackageMigrationStmt(Nonterm):

    def reduce_CreateExtensionPackageMigrationStmt(self, *kids):
        r"""%reduce CREATE EXTENSIONPACKAGE ShortNodeName
                    MIGRATION FROM
                    ExtensionVersion TO
                    ExtensionVersion
                    OptCreateExtensionPackageMigrationCommandsBlock
        """
        _, _, name, _, _, from_version, _, to_version, block = kids
        self.val = qlast.CreateExtensionPackageMigration(
            name=name.val,
            from_version=from_version.val,
            to_version=to_version.val,
            body=block.val.body,
            commands=block.val.fields,
        )


#
# DROP EXTENSION PACKAGE MIGRATION
#
class DropExtensionPackageMigrationStmt(Nonterm):

    def reduce_DropExtensionPackageMigrationStmt(self, *kids):
        r"""%reduce DROP EXTENSIONPACKAGE ShortNodeName
                    MIGRATION FROM
                    ExtensionVersion TO
                    ExtensionVersion
        """
        _, _, name, _, _, from_version, _, to_version = kids

        self.val = qlast.DropExtensionPackageMigration(
            name=name.val,
            from_version=from_version.val,
            to_version=to_version.val,
        )


#
# EXTENSIONS
#


class ExtensionStmt(Nonterm):

    @parsing.inline(0)
    def reduce_CreateExtensionStmt(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_AlterExtensionStmt(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_DropExtensionStmt(self, *kids):
        pass


#
# CREATE EXTENSION
#


commands_block(
    'CreateExtension',
    SetFieldStmt,
)


class CreateExtensionStmt(Nonterm):

    def reduce_CreateExtensionStmt(self, *kids):
        r"""%reduce CREATE EXTENSION ShortNodeName OptExtensionVersion
                    OptCreateExtensionCommandsBlock
        """
        self.val = qlast.CreateExtension(
            name=kids[2].val,
            version=kids[3].val,
            commands=kids[4].val,
        )

#
# ALTER EXTENSION
#


class AlterExtensionStmt(Nonterm):

    def reduce_AlterExtensionStmt(self, *kids):
        r"""%reduce ALTER EXTENSION ShortNodeName
                    TO ExtensionVersion
        """
        _, _, name, _, ver = kids
        self.val = qlast.AlterExtension(
            name=name.val,
            to_version=ver.val,
        )


#
# DROP EXTENSION
#
class DropExtensionStmt(Nonterm):

    def reduce_DropExtensionStmt(self, *kids):
        r"""%reduce DROP EXTENSION ShortNodeName OptExtensionVersion"""
        self.val = qlast.DropExtension(
            name=kids[2].val,
            version=kids[3].val,
        )


#
# FUTURE
#


class FutureStmt(Nonterm):

    @parsing.inline(0)
    def reduce_CreateFutureStmt(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_DropFutureStmt(self, *kids):
        pass


#
# CREATE FUTURE
#


class CreateFutureStmt(Nonterm):

    def reduce_CreateFutureStmt(self, *kids):
        r"""%reduce CREATE FUTURE ShortNodeName"""
        self.val = qlast.CreateFuture(
            name=kids[2].val,
        )


#
# DROP FUTURE
#
class DropFutureStmt(Nonterm):

    def reduce_DropFutureStmt(self, *kids):
        r"""%reduce DROP FUTURE ShortNodeName"""
        self.val = qlast.DropFuture(
            name=kids[2].val,
        )


#
# ROLE
#

class RoleStmt(Nonterm):

    @parsing.inline(0)
    def reduce_CreateRoleStmt(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_AlterRoleStmt(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_DropRoleStmt(self, *kids):
        pass


class ShortTypeName(Nonterm):
    def reduce_ShortNodeName(self, name):
        self.val = qlast.TypeName(maintype=name.val)


class ShortTypeNameList(
    parsing.ListNonterm, element=ShortTypeName, separator=tokens.T_COMMA
):
    pass


#
# CREATE ROLE
#
class ShortExtending(Nonterm):
    @parsing.inline(1)
    def reduce_EXTENDING_ShortTypeNameList(self, *kids):
        pass


class OptShortExtending(Nonterm):
    @parsing.inline(0)
    def reduce_ShortExtending(self, *kids):
        pass

    def reduce_empty(self, *kids):
        self.val = []


commands_block(
    'CreateRole',
    SetFieldStmt,
)


class OptSuperuser(Nonterm):

    def reduce_SUPERUSER(self, *kids):
        self.val = True

    def reduce_empty(self, *kids):
        self.val = False


class CreateRoleStmt(Nonterm):
    def reduce_CreateRoleStmt(self, *kids):
        r"""%reduce CREATE OptSuperuser ROLE ShortNodeName
                    OptShortExtending OptIfNotExists OptCreateRoleCommandsBlock
        """
        self.val = qlast.CreateRole(
            name=kids[3].val,
            bases=kids[4].val,
            create_if_not_exists=kids[5].val,
            commands=kids[6].val,
            superuser=kids[1].val,
        )


#
# ALTER ROLE
#
class AlterRoleExtending(Nonterm):
    def reduce_EXTENDING_ShortTypeNameList_OptPosition(self, *kids):
        self.val = qlast.AlterAddInherit(
            bases=kids[1].val,
            position=kids[2].val
        )

    def reduce_DROP_EXTENDING_ShortTypeNameList(self, *kids):
        self.val = qlast.AlterDropInherit(
            bases=kids[2].val
        )


commands_block(
    'AlterRole',
    RenameStmt,
    SetFieldStmt,
    ResetFieldStmt,
    AlterRoleExtending,
    opt=False
)


class AlterRoleStmt(Nonterm):
    def reduce_ALTER_ROLE_ShortNodeName_AlterRoleCommandsBlock(self, *kids):
        self.val = qlast.AlterRole(
            name=kids[2].val,
            commands=kids[3].val,
        )


#
# DROP ROLE
#
class DropRoleStmt(Nonterm):
    def reduce_DROP_ROLE_ShortNodeName(self, *kids):
        self.val = qlast.DropRole(
            name=kids[2].val,
        )


#
# CREATE CONSTRAINT
#
class CreateConstraintStmt(Nonterm):
    def reduce_CreateConstraint(self, *kids):
        r"""%reduce CREATE ABSTRACT CONSTRAINT NodeName OptOnExpr \
                    OptExtendingSimple OptCreateCommandsBlock"""
        self.val = qlast.CreateConstraint(
            name=kids[3].val,
            subjectexpr=kids[4].val,
            bases=kids[5].val,
            commands=kids[6].val,
        )

    def reduce_CreateConstraint_CreateFunctionArgs(self, *kids):
        r"""%reduce CREATE ABSTRACT CONSTRAINT NodeName CreateFunctionArgs \
                    OptOnExpr OptExtendingSimple OptCreateCommandsBlock"""
        self.val = qlast.CreateConstraint(
            name=kids[3].val,
            params=kids[4].val,
            subjectexpr=kids[5].val,
            bases=kids[6].val,
            commands=kids[7].val,
        )


class AlterConstraintStmt(Nonterm):
    def reduce_CreateConstraint(self, *kids):
        r"""%reduce ALTER ABSTRACT CONSTRAINT NodeName \
                    AlterCommandsBlock"""
        self.val = qlast.AlterConstraint(
            name=kids[3].val,
            commands=kids[4].val,
        )


class DropConstraintStmt(Nonterm):
    def reduce_CreateConstraint(self, *kids):
        r"""%reduce DROP ABSTRACT CONSTRAINT NodeName"""
        self.val = qlast.DropConstraint(
            name=kids[3].val
        )


class OptDelegated(Nonterm):
    def reduce_DELEGATED(self, *kids):
        self.val = True

    def reduce_empty(self):
        self.val = False


class CreateConcreteConstraintStmt(Nonterm):
    def reduce_CreateConstraint(self, *kids):
        r"""%reduce CREATE OptDelegated CONSTRAINT \
                    NodeName OptConcreteConstraintArgList OptOnExpr \
                    OptExceptExpr \
                    OptCreateCommandsBlock"""
        self.val = qlast.CreateConcreteConstraint(
            delegated=kids[1].val,
            name=kids[3].val,
            args=kids[4].val,
            subjectexpr=kids[5].val,
            except_expr=kids[6].val,
            commands=kids[7].val,
        )


class SetDelegatedStmt(Nonterm):

    def reduce_SET_DELEGATED(self, *kids):
        self.val = qlast.SetField(
            name='delegated',
            value=qlast.Constant.boolean(True, span=self.span),
            special_syntax=True,
        )

    def reduce_SET_NOT_DELEGATED(self, *kids):
        self.val = qlast.SetField(
            name='delegated',
            value=qlast.Constant.boolean(False, span=self.span),
            special_syntax=True,
        )

    def reduce_RESET_DELEGATED(self, *kids):
        self.val = qlast.SetField(
            name='delegated',
            value=None,
            special_syntax=True,
        )


commands_block(
    'AlterConcreteConstraint',
    SetFieldStmt,
    ResetFieldStmt,
    SetDelegatedStmt,
    AlterOwnedStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    DropAnnotationValueStmt,
    AlterAbstract,
    opt=False
)


class AlterConcreteConstraintStmt(Nonterm):
    def reduce_CreateConstraint(self, *kids):
        r"""%reduce ALTER CONSTRAINT NodeName
                    OptConcreteConstraintArgList OptOnExpr OptExceptExpr
                    AlterConcreteConstraintCommandsBlock"""
        self.val = qlast.AlterConcreteConstraint(
            name=kids[2].val,
            args=kids[3].val,
            subjectexpr=kids[4].val,
            except_expr=kids[5].val,
            commands=kids[6].val,
        )


class DropConcreteConstraintStmt(Nonterm):
    def reduce_DropConstraint(self, *kids):
        r"""%reduce DROP CONSTRAINT NodeName
                    OptConcreteConstraintArgList OptOnExpr OptExceptExpr"""
        self.val = qlast.DropConcreteConstraint(
            name=kids[2].val,
            args=kids[3].val,
            subjectexpr=kids[4].val,
            except_expr=kids[5].val,
        )


#
# CREATE PSEUDO TYPE
#

commands_block(
    'CreatePseudoType',
    SetFieldStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
)


class CreatePseudoTypeStmt(Nonterm):

    def reduce_CreatePseudoTypeStmt(self, *kids):
        r"""%reduce
            CREATE PSEUDO TYPE NodeName OptCreatePseudoTypeCommandsBlock
        """
        self.val = qlast.CreatePseudoType(
            name=kids[3].val,
            commands=kids[4].val,
        )


#
# CREATE SCALAR TYPE
#

commands_block(
    'CreateScalarType',
    SetFieldStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    CreateConcreteConstraintStmt)


class CreateScalarTypeStmt(Nonterm):
    def reduce_CreateAbstractScalarTypeStmt(self, *kids):
        r"""%reduce \
            CREATE ABSTRACT SCALAR TYPE NodeName \
            OptExtending OptCreateScalarTypeCommandsBlock \
        """
        self.val = qlast.CreateScalarType(
            name=kids[4].val,
            abstract=True,
            bases=kids[5].val,
            commands=kids[6].val
        )

    def reduce_CreateFinalScalarTypeStmt(self, *kids):
        r"""%reduce \
            CREATE FINAL SCALAR TYPE NodeName \
            OptExtending OptCreateScalarTypeCommandsBlock \
        """
        # Old dumps (1.0-beta.3 and earlier) specify FINAL for all
        # scalar types, despite it not doing anything and being
        # undocumented. So we need to support it in the syntax, and
        # we reject it later when not reading an old dump.
        self.val = qlast.CreateScalarType(
            name=kids[4].val,
            final=True,
            bases=kids[5].val,
            commands=kids[6].val
        )

    def reduce_CreateScalarTypeStmt(self, *kids):
        r"""%reduce \
            CREATE SCALAR TYPE NodeName \
            OptExtending OptCreateScalarTypeCommandsBlock \
        """
        self.val = qlast.CreateScalarType(
            name=kids[3].val,
            bases=kids[4].val,
            commands=kids[5].val
        )


#
# ALTER SCALAR TYPE
#

commands_block(
    'AlterScalarType',
    RenameStmt,
    SetFieldStmt,
    ResetFieldStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    DropAnnotationValueStmt,
    AlterExtending,
    CreateConcreteConstraintStmt,
    AlterConcreteConstraintStmt,
    DropConcreteConstraintStmt,
    opt=False
)


class AlterScalarTypeStmt(Nonterm):
    def reduce_AlterScalarTypeStmt(self, *kids):
        r"""%reduce \
            ALTER SCALAR TYPE NodeName \
            AlterScalarTypeCommandsBlock \
        """
        self.val = qlast.AlterScalarType(
            name=kids[3].val,
            commands=kids[4].val
        )


class DropScalarTypeStmt(Nonterm):
    def reduce_DROP_SCALAR_TYPE_NodeName(self, *kids):
        self.val = qlast.DropScalarType(name=kids[3].val)


#
# CREATE ANNOTATION
#
commands_block(
    'CreateAnnotation',
    CreateAnnotationValueStmt,
)


class CreateAnnotationStmt(Nonterm):
    def reduce_CreateAnnotation(self, *kids):
        r"""%reduce CREATE ABSTRACT ANNOTATION NodeName \
                    OptCreateAnnotationCommandsBlock"""
        self.val = qlast.CreateAnnotation(
            name=kids[3].val,
            commands=kids[4].val,
            inheritable=False,
        )

    def reduce_CreateInheritableAnnotation(self, *kids):
        r"""%reduce CREATE ABSTRACT INHERITABLE ANNOTATION
                    NodeName OptCreateCommandsBlock"""
        self.val = qlast.CreateAnnotation(
            name=kids[4].val,
            commands=kids[5].val,
            inheritable=True,
        )


#
# ALTER ANNOTATION
#
commands_block(
    'AlterAnnotation',
    RenameStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    DropAnnotationValueStmt,
    opt=False,
)


class AlterAnnotationStmt(Nonterm):
    def reduce_AlterAnnotation(self, *kids):
        r"""%reduce ALTER ABSTRACT ANNOTATION NodeName \
                    AlterAnnotationCommandsBlock"""
        self.val = qlast.AlterAnnotation(
            name=kids[3].val,
            commands=kids[4].val
        )


#
# DROP ANNOTATION
#
class DropAnnotationStmt(Nonterm):
    def reduce_DropAnnotation(self, *kids):
        r"""%reduce DROP ABSTRACT ANNOTATION NodeName"""
        self.val = qlast.DropAnnotation(
            name=kids[3].val,
        )


#
# CREATE INDEX
#
commands_block(
    'CreateIndex',
    UsingStmt,
    SetFieldStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
)


commands_block(
    'AlterIndex',
    UsingStmt,
    RenameStmt,
    SetFieldStmt,
    ResetFieldStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    DropAnnotationValueStmt,
    opt=False)


class CreateIndexStmt(
    Nonterm,
    commondl.ProcessIndexMixin,
):
    def reduce_CreateIndex(self, *kids):
        r"""%reduce CREATE ABSTRACT INDEX NodeName \
                    OptExtendingSimple OptCreateIndexCommandsBlock"""
        self.val = qlast.CreateIndex(
            name=kids[3].val,
            bases=kids[4].val,
            **self._process_sql_body(kids[5])
        )

    def reduce_CreateIndex_CreateFunctionArgs(self, *kids):
        r"""%reduce CREATE ABSTRACT INDEX NodeName IndexExtArgList \
                    OptExtendingSimple OptCreateIndexCommandsBlock"""
        bases = kids[5].val
        params, kwargs = self._process_params_or_kwargs(bases, kids[4].val)

        self.val = qlast.CreateIndex(
            name=kids[3].val,
            params=params,
            kwargs=kwargs,
            bases=bases,
            **self._process_sql_body(kids[6])
        )


#
# ALTER INDEX
#
class AlterIndexStmt(Nonterm, commondl.ProcessIndexMixin):
    def reduce_AlterIndex(self, *kids):
        r"""%reduce ALTER ABSTRACT INDEX NodeName \
                    AlterIndexCommandsBlock"""
        self.val = qlast.AlterIndex(
            name=kids[3].val,
            **self._process_sql_body(kids[4])
        )


#
# DROP INDEX
#
class DropIndexStmt(Nonterm):
    def reduce_DropIndex(self, *kids):
        r"""%reduce DROP ABSTRACT INDEX NodeName"""
        self.val = qlast.DropIndex(
            name=kids[3].val
        )


#
# CREATE CONCRETE INDEX
#
class CreateConcreteIndexStmt(Nonterm, commondl.ProcessIndexMixin):
    def reduce_CreateConcreteDefaultIndex(self, *kids):
        r"""%reduce CREATE OptDeferred INDEX OnExpr OptExceptExpr
                    OptCreateCommandsBlock
        """
        self.val = qlast.CreateConcreteIndex(
            name=qlast.ObjectRef(module='__', name='idx', span=kids[2].span),
            expr=kids[3].val,
            except_expr=kids[4].val,
            deferred=kids[1].val,
            commands=kids[5].val,
        )

    def reduce_CreateConcreteIndex(self, *kids):
        r"""%reduce CREATE OptDeferred INDEX NodeName
                    OptIndexExtArgList OnExpr OptExceptExpr
                    OptCreateCommandsBlock
        """
        kwargs = self._process_arguments(kids[4].val)
        self.val = qlast.CreateConcreteIndex(
            name=kids[3].val,
            kwargs=kwargs,
            expr=kids[5].val,
            except_expr=kids[6].val,
            deferred=kids[1].val,
            commands=kids[7].val,
        )


#
# ALTER CONCRETE INDEX
#

class AlterDeferredStmt(Nonterm):
    def reduce_DROP_DEFERRED(self, *kids):
        self.val = qlast.SetField(
            name='deferred',
            value=qlast.Constant.boolean(False, span=self.span),
            special_syntax=True,
        )

    def reduce_SET_DEFERRED(self, *kids):
        self.val = qlast.SetField(
            name='deferred',
            value=qlast.Constant.boolean(True, span=self.span),
            special_syntax=True,
        )


commands_block(
    'AlterConcreteIndex',
    SetFieldStmt,
    ResetFieldStmt,
    AlterOwnedStmt,
    AlterDeferredStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    DropAnnotationValueStmt,
    opt=False)


class AlterConcreteIndexStmt(Nonterm, commondl.ProcessIndexMixin):
    def reduce_AlterConcreteIndex(self, *kids):
        r"""%reduce ALTER INDEX OnExpr OptExceptExpr \
                    AlterConcreteIndexCommandsBlock \
        """
        self.val = qlast.AlterConcreteIndex(
            name=qlast.ObjectRef(module='__', name='idx', span=kids[1].span),
            expr=kids[2].val,
            except_expr=kids[3].val,
            commands=kids[4].val,
        )

    def reduce_AlterConcreteNamedIndex(self, *kids):
        r"""%reduce ALTER INDEX NodeName OptIndexExtArgList OnExpr \
                    OptExceptExpr \
                    AlterConcreteIndexCommandsBlock \
        """
        kwargs = self._process_arguments(kids[3].val)
        self.val = qlast.AlterConcreteIndex(
            name=kids[2].val,
            kwargs=kwargs,
            expr=kids[4].val,
            except_expr=kids[5].val,
            commands=kids[6].val,
        )


commands_block(
    'DropConcreteIndex',
    SetFieldStmt,
    opt=True,
)


#
# DROP CONCRETE INDEX
#
class DropConcreteIndexStmt(Nonterm, commondl.ProcessIndexMixin):
    def reduce_DropConcreteIndex(self, *kids):
        r"""%reduce DROP INDEX OnExpr OptExceptExpr \
                    OptDropConcreteIndexCommandsBlock \
        """
        self.val = qlast.DropConcreteIndex(
            name=qlast.ObjectRef(module='__', name='idx', span=kids[1].span),
            expr=kids[2].val,
            except_expr=kids[3].val,
            commands=kids[4].val,
        )

    def reduce_DropConcreteNamedIndex(self, *kids):
        r"""%reduce DROP INDEX NodeName OptIndexExtArgList OnExpr \
                    OptExceptExpr \
                    OptDropConcreteIndexCommandsBlock \
        """
        kwargs = self._process_arguments(kids[3].val)
        self.val = qlast.DropConcreteIndex(
            name=kids[2].val,
            kwargs=kwargs,
            expr=kids[4].val,
            except_expr=kids[5].val,
            commands=kids[6].val,
        )


#
# CREATE INDEX MATCH
#
commands_block(
    'CreateIndexMatch',
    CreateAnnotationValueStmt,
)


class CreateIndexMatchStmt(Nonterm):
    def reduce_CreateIndexMatch(self, *kids):
        r"""%reduce CREATE INDEX MATCH FOR TypeName USING NodeName \
                    OptCreateIndexMatchCommandsBlock"""
        self.val = qlast.CreateIndexMatch(
            valid_type=kids[4].val,
            name=kids[6].val,
            commands=kids[7].val,
        )


#
# DROP INDEX MATCH
#
class DropIndexMatchStmt(Nonterm):
    def reduce_DropIndexMatch(self, *kids):
        r"""%reduce DROP INDEX MATCH FOR TypeName USING NodeName"""
        self.val = qlast.DropIndexMatch(
            valid_type=kids[4].val,
            name=kids[6].val,
        )


#
# CREATE REWRITE
#

commands_block(
    'CreateRewrite',
    CreateAnnotationValueStmt,
    SetFieldStmt,
)


class CreateRewriteStmt(Nonterm):
    def reduce_CreateRewrite(self, *kids):
        """%reduce
            CREATE REWRITE RewriteKindList
            USING ParenExpr
            OptCreateRewriteCommandsBlock
        """
        _, _, kinds, _, expr, commands = kids
        self.val = qlast.CreateRewrite(
            kinds=kinds.val,
            expr=expr.val,
            commands=commands.val,
        )


commands_block(
    'AlterRewrite',
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    DropAnnotationValueStmt,
    SetFieldStmt,
    ResetFieldStmt,
    UsingStmt,
    opt=False
)


class AlterRewriteStmt(Nonterm):
    def reduce_AlterRewrite(self, _a, _r, kinds, commands):
        r"""%reduce \
            ALTER REWRITE RewriteKindList \
            AlterRewriteCommandsBlock \
        """
        self.val = qlast.AlterRewrite(
            kinds=kinds.val,
            commands=commands.val,
        )


class DropRewriteStmt(Nonterm):
    def reduce_DropRewrite(self, _d, _r, kinds):
        r"""%reduce DROP REWRITE RewriteKindList"""
        self.val = qlast.DropRewrite(
            kinds=kinds.val
        )


#
# CREATE PROPERTY
#

commands_block(
    'CreateProperty',
    UsingStmt,
    SetFieldStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    commondl.CreateSimpleExtending,
)


class CreatePropertyStmt(Nonterm):
    def reduce_CreateProperty(self, *kids):
        r"""%reduce CREATE ABSTRACT PROPERTY PtrNodeName OptExtendingSimple \
                    OptCreatePropertyCommandsBlock \
        """
        vbases, vcommands = commondl.extract_bases(kids[4].val, kids[5].val)
        self.val = qlast.CreateProperty(
            name=kids[3].val,
            bases=vbases,
            commands=vcommands,
            abstract=True,
        )


#
# ALTER PROPERTY
#

commands_block(
    'AlterProperty',
    RenameStmt,
    SetFieldStmt,
    ResetFieldStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    DropAnnotationValueStmt,
    CreateRewriteStmt,
    AlterRewriteStmt,
    DropRewriteStmt,
    opt=False
)


class AlterPropertyStmt(Nonterm):
    def reduce_AlterProperty(self, *kids):
        r"""%reduce \
            ALTER ABSTRACT PROPERTY PtrNodeName \
            AlterPropertyCommandsBlock \
        """
        self.val = qlast.AlterProperty(
            name=kids[3].val,
            commands=kids[4].val
        )


#
# DROP PROPERTY
#
class DropPropertyStmt(Nonterm):
    def reduce_DropProperty(self, *kids):
        r"""%reduce DROP ABSTRACT PROPERTY PtrNodeName"""
        self.val = qlast.DropProperty(
            name=kids[3].val
        )


#
# CREATE LINK ... { CREATE PROPERTY
#

class SetRequiredInCreateStmt(Nonterm):

    def reduce_SET_REQUIRED_OptAlterUsingClause(self, *kids):
        self.val = qlast.SetPointerOptionality(
            name='required',
            value=qlast.Constant.boolean(True, span=self.span),
            special_syntax=True,
            fill_expr=kids[2].val,
        )


commands_block(
    'CreateConcreteProperty',
    UsingStmt,
    SetFieldStmt,
    SetRequiredInCreateStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    CreateConcreteConstraintStmt,
    CreateRewriteStmt,
    commondl.CreateSimpleExtending,
)


class CreateConcretePropertyStmt(Nonterm):
    def reduce_CreateRegularProperty(self, *kids):
        """%reduce
            CREATE OptPtrQuals PROPERTY UnqualifiedPointerName
            OptExtendingSimple ARROW FullTypeExpr
            OptCreateConcretePropertyCommandsBlock
        """
        vbases, vcommands = commondl.extract_bases(kids[4].val, kids[7].val)
        self.val = qlast.CreateConcreteProperty(
            name=kids[3].val,
            bases=vbases,
            is_required=kids[1].val.required,
            cardinality=kids[1].val.cardinality,
            target=kids[6].val,
            commands=vcommands,
        )

    def reduce_CreateRegularPropertyNew(self, *kids):
        """%reduce
            CREATE OptPtrQuals PROPERTY UnqualifiedPointerName
            OptExtendingSimple COLON FullTypeExpr
            OptCreateConcretePropertyCommandsBlock
        """
        vbases, vcommands = commondl.extract_bases(kids[4].val, kids[7].val)
        self.val = qlast.CreateConcreteProperty(
            name=kids[3].val,
            bases=vbases,
            is_required=kids[1].val.required,
            cardinality=kids[1].val.cardinality,
            target=kids[6].val,
            commands=vcommands,
        )

    def reduce_CreateComputableProperty(self, *kids):
        """%reduce
            CREATE OptPtrQuals PROPERTY UnqualifiedPointerName ASSIGN GenExpr
        """
        self.val = qlast.CreateConcreteProperty(
            name=kids[3].val,
            is_required=kids[1].val.required,
            cardinality=kids[1].val.cardinality,
            target=kids[5].val,
        )

    def reduce_CreateComputablePropertyWithUsing(self, *kids):
        """%reduce
            CREATE OptPtrQuals PROPERTY UnqualifiedPointerName
            OptCreateConcretePropertyCommandsBlock
        """
        cmds = kids[4].val
        target = None

        for cmd in cmds:
            if isinstance(cmd, qlast.SetField) and cmd.name == 'expr':
                if target is not None:
                    raise EdgeQLSyntaxError(
                        'computed property with more than one expression',
                        span=kids[3].span)
                target = cmd.value
            elif isinstance(cmd, qlast.AlterAddInherit):
                raise EdgeQLSyntaxError(
                    'computed property cannot specify EXTENDING',
                    span=kids[3].span)

        if target is None:
            raise EdgeQLSyntaxError(
                'computed property without expression',
                span=kids[3].span)

        self.val = qlast.CreateConcreteProperty(
            name=kids[3].val,
            is_required=kids[1].val.required,
            cardinality=kids[1].val.cardinality,
            target=target,
            commands=cmds,
        )


#
# ALTER LINK/PROPERTY
#


class OptAlterUsingClause(Nonterm):
    @parsing.inline(1)
    def reduce_USING_ParenExpr(self, *kids):
        pass

    def reduce_empty(self):
        self.val = None


class SetCardinalityStmt(Nonterm):

    def reduce_SET_SINGLE_OptAlterUsingClause(self, *kids):
        self.val = qlast.SetPointerCardinality(
            name='cardinality',
            value=qlast.Constant.string(
                qltypes.SchemaCardinality.One,
                span=kids[1].span,
            ),
            special_syntax=True,
            conv_expr=kids[2].val,
        )

    def reduce_SET_MULTI(self, *kids):
        self.val = qlast.SetPointerCardinality(
            name='cardinality',
            value=qlast.Constant.string(
                qltypes.SchemaCardinality.Many,
                span=kids[1].span,
            ),
            special_syntax=True,
        )

    def reduce_RESET_CARDINALITY_OptAlterUsingClause(self, *kids):
        self.val = qlast.SetPointerCardinality(
            name='cardinality',
            value=None,
            special_syntax=True,
            conv_expr=kids[2].val,
        )


class SetRequiredStmt(Nonterm):

    def reduce_SET_REQUIRED_OptAlterUsingClause(self, *kids):
        self.val = qlast.SetPointerOptionality(
            name='required',
            value=qlast.Constant.boolean(True, span=self.span),
            special_syntax=True,
            fill_expr=kids[2].val,
        )

    def reduce_SET_OPTIONAL(self, *kids):
        self.val = qlast.SetPointerOptionality(
            name='required',
            value=qlast.Constant.boolean(False, span=self.span),
            special_syntax=True,
        )

    def reduce_DROP_REQUIRED(self, *kids):
        # TODO: Raise a DeprecationWarning once we have a facility for that.
        self.val = qlast.SetPointerOptionality(
            name='required',
            value=qlast.Constant.boolean(False, span=self.span),
            special_syntax=True,
        )

    def reduce_RESET_OPTIONALITY(self, *kids):
        self.val = qlast.SetPointerOptionality(
            name='required',
            value=None,
            special_syntax=True,
        )


class SetPointerTypeStmt(Nonterm):

    def reduce_SETTYPE_FullTypeExpr_OptAlterUsingClause(self, *kids):
        self.val = qlast.SetPointerType(
            value=kids[1].val,
            cast_expr=kids[2].val,
        )

    def reduce_RESET_TYPE(self, *kids):
        self.val = qlast.SetPointerType(
            value=None,
        )


commands_block(
    'AlterConcreteProperty',
    UsingStmt,
    RenameStmt,
    SetFieldStmt,
    ResetFieldStmt,
    AlterOwnedStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    DropAnnotationValueStmt,
    SetPointerTypeStmt,
    SetCardinalityStmt,
    SetRequiredStmt,
    AlterSimpleExtending,
    CreateConcreteConstraintStmt,
    AlterConcreteConstraintStmt,
    DropConcreteConstraintStmt,
    CreateRewriteStmt,
    AlterRewriteStmt,
    DropRewriteStmt,
    opt=False
)


class AlterConcretePropertyStmt(Nonterm):
    def reduce_AlterProperty(self, *kids):
        r"""%reduce \
            ALTER PROPERTY UnqualifiedPointerName \
            AlterConcretePropertyCommandsBlock \
        """
        self.val = qlast.AlterConcreteProperty(
            name=kids[2].val,
            commands=kids[3].val
        )


#
# ALTER LINK ... { DROP PROPERTY
#

class DropConcretePropertyStmt(Nonterm):
    def reduce_DropProperty(self, *kids):
        r"""%reduce \
            DROP PROPERTY UnqualifiedPointerName \
        """
        self.val = qlast.DropConcreteProperty(
            name=kids[2].val
        )


#
# CREATE LINK
#

commands_block(
    'CreateLink',
    SetFieldStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    CreateConcreteConstraintStmt,
    CreateConcretePropertyStmt,
    CreateConcreteIndexStmt,
    CreateRewriteStmt,
    commondl.CreateSimpleExtending,
)


class CreateLinkStmt(Nonterm):
    def reduce_CreateLink(self, *kids):
        r"""%reduce \
            CREATE ABSTRACT LINK PtrNodeName OptExtendingSimple \
            OptCreateLinkCommandsBlock \
        """
        vbases, vcommands = commondl.extract_bases(
            kids[4].val,
            kids[5].val,
        )
        self.val = qlast.CreateLink(
            name=kids[3].val,
            bases=vbases,
            commands=vcommands,
            abstract=True,
        )


#
# ALTER LINK
#

commands_block(
    'AlterLink',
    RenameStmt,
    SetFieldStmt,
    ResetFieldStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    DropAnnotationValueStmt,
    AlterSimpleExtending,
    CreateConcreteConstraintStmt,
    AlterConcreteConstraintStmt,
    DropConcreteConstraintStmt,
    CreateConcretePropertyStmt,
    AlterConcretePropertyStmt,
    DropConcretePropertyStmt,
    CreateConcreteIndexStmt,
    AlterConcreteIndexStmt,
    DropConcreteIndexStmt,
    CreateRewriteStmt,
    AlterRewriteStmt,
    DropRewriteStmt,
    opt=False
)


class AlterLinkStmt(Nonterm):
    def reduce_AlterLink(self, *kids):
        r"""%reduce \
            ALTER ABSTRACT LINK PtrNodeName \
            AlterLinkCommandsBlock \
        """
        self.val = qlast.AlterLink(
            name=kids[3].val,
            commands=kids[4].val
        )


#
# DROP LINK
#

commands_block(
    'DropLink',
    DropConcreteConstraintStmt,
    DropConcretePropertyStmt,
    DropConcreteIndexStmt,
)


class DropLinkStmt(Nonterm):
    def reduce_DropLink(self, *kids):
        r"""%reduce \
            DROP ABSTRACT LINK PtrNodeName \
            OptDropLinkCommandsBlock \
        """
        self.val = qlast.DropLink(
            name=kids[3].val,
            commands=kids[4].val
        )


#
# CREATE TYPE ... { CREATE LINK
#

commands_block(
    'CreateConcreteLink',
    UsingStmt,
    SetFieldStmt,
    SetRequiredInCreateStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    CreateConcreteConstraintStmt,
    CreateConcretePropertyStmt,
    CreateConcreteIndexStmt,
    commondl.OnTargetDeleteStmt,
    commondl.OnSourceDeleteStmt,
    CreateRewriteStmt,
    commondl.CreateSimpleExtending,
)


class CreateConcreteLinkStmt(Nonterm):
    def reduce_CreateRegularLink(self, *kids):
        """%reduce
            CREATE OptPtrQuals LINK UnqualifiedPointerName OptExtendingSimple
            ARROW FullTypeExpr OptCreateConcreteLinkCommandsBlock
        """
        vbases, vcommands = commondl.extract_bases(kids[4].val, kids[7].val)
        self.val = qlast.CreateConcreteLink(
            name=kids[3].val,
            bases=vbases,
            is_required=kids[1].val.required,
            cardinality=kids[1].val.cardinality,
            target=kids[6].val,
            commands=vcommands,
        )

    def reduce_CreateRegularLinkNew(self, *kids):
        """%reduce
            CREATE OptPtrQuals LINK UnqualifiedPointerName OptExtendingSimple
            COLON FullTypeExpr OptCreateConcreteLinkCommandsBlock
        """
        vbases, vcommands = commondl.extract_bases(kids[4].val, kids[7].val)
        self.val = qlast.CreateConcreteLink(
            name=kids[3].val,
            bases=vbases,
            is_required=kids[1].val.required,
            cardinality=kids[1].val.cardinality,
            target=kids[6].val,
            commands=vcommands
        )

    def reduce_CreateComputableLink(self, *kids):
        """%reduce
            CREATE OptPtrQuals LINK UnqualifiedPointerName ASSIGN GenExpr
        """
        self.val = qlast.CreateConcreteLink(
            name=kids[3].val,
            is_required=kids[1].val.required,
            cardinality=kids[1].val.cardinality,
            target=kids[5].val,
        )

    def reduce_CreateComputableLinkWithUsing(self, *kids):
        """%reduce
            CREATE OptPtrQuals LINK UnqualifiedPointerName
            OptCreateConcreteLinkCommandsBlock
        """
        cmds = kids[4].val
        target = None

        for cmd in cmds:
            if isinstance(cmd, qlast.SetField) and cmd.name == 'expr':
                if target is not None:
                    raise EdgeQLSyntaxError(
                        'computed link with more than one expression',
                        span=kids[3].span)
                target = cmd.value
            elif isinstance(cmd, qlast.AlterAddInherit):
                raise EdgeQLSyntaxError(
                    'computed link cannot specify EXTENDING',
                    span=kids[3].span)

        if target is None:
            raise EdgeQLSyntaxError(
                'computed link without expression',
                span=kids[3].span)

        self.val = qlast.CreateConcreteLink(
            name=kids[3].val,
            is_required=kids[1].val.required,
            cardinality=kids[1].val.cardinality,
            target=target,
            commands=cmds,
        )


class OnTargetDeleteResetStmt(Nonterm):
    def reduce_RESET_ON_TARGET_DELETE(self, *kids):
        self.val = qlast.OnTargetDelete(cascade=None)


class OnSourceDeleteResetStmt(Nonterm):
    def reduce_RESET_ON_SOURCE_DELETE(self, *kids):
        self.val = qlast.OnSourceDelete(cascade=None)


commands_block(
    'AlterConcreteLink',
    UsingStmt,
    RenameStmt,
    SetFieldStmt,
    ResetFieldStmt,
    AlterOwnedStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    DropAnnotationValueStmt,
    SetCardinalityStmt,
    SetRequiredStmt,
    SetPointerTypeStmt,
    AlterSimpleExtending,
    CreateConcreteConstraintStmt,
    AlterConcreteConstraintStmt,
    DropConcreteConstraintStmt,
    CreateConcretePropertyStmt,
    AlterConcretePropertyStmt,
    DropConcretePropertyStmt,
    CreateConcreteIndexStmt,
    AlterConcreteIndexStmt,
    DropConcreteIndexStmt,
    commondl.OnTargetDeleteStmt,
    commondl.OnSourceDeleteStmt,
    OnTargetDeleteResetStmt,
    OnSourceDeleteResetStmt,
    CreateRewriteStmt,
    AlterRewriteStmt,
    DropRewriteStmt,
    opt=False
)


class AlterConcreteLinkStmt(Nonterm):
    def reduce_AlterLink(self, *kids):
        r"""%reduce \
            ALTER LINK UnqualifiedPointerName AlterConcreteLinkCommandsBlock \
        """
        self.val = qlast.AlterConcreteLink(
            name=kids[2].val,
            commands=kids[3].val
        )


commands_block(
    'DropConcreteLink',
    DropConcreteConstraintStmt,
    DropConcretePropertyStmt,
    DropConcreteIndexStmt,
)


class DropConcreteLinkStmt(Nonterm):
    def reduce_DropLink(self, *kids):
        r"""%reduce \
            DROP LINK UnqualifiedPointerName \
            OptDropConcreteLinkCommandsBlock \
        """
        self.val = qlast.DropConcreteLink(
            name=kids[2].val,
            commands=kids[3].val
        )


#
# CREATE ACCESS POLICY
#

commands_block(
    'CreateAccessPolicy',
    CreateAnnotationValueStmt,
    SetFieldStmt,
)


class CreateAccessPolicyStmt(Nonterm):
    def reduce_CreateAccessPolicy(self, *kids):
        """%reduce
            CREATE ACCESS POLICY UnqualifiedPointerName
            OptWhenBlock AccessPolicyAction AccessKindList
            OptUsingBlock
            OptCreateAccessPolicyCommandsBlock
        """
        self.val = qlast.CreateAccessPolicy(
            name=kids[3].val,
            condition=kids[4].val,
            action=kids[5].val,
            access_kinds=[y for x in kids[6].val for y in x],
            expr=kids[7].val,
            commands=kids[8].val,
        )


class AccessPermStmt(Nonterm):
    def reduce_AccessPolicyAction_AccessKindList(self, *kids):
        self.val = qlast.SetAccessPerms(
            action=kids[0].val,
            access_kinds=[y for x in kids[1].val for y in x],
        )


class AccessUsingStmt(Nonterm):
    def reduce_USING_ParenExpr(self, *kids):
        self.val = qlast.SetField(
            name='expr',
            value=kids[1].val,
            special_syntax=True,
        )

    def reduce_RESET_EXPRESSION(self, *kids):
        self.val = qlast.SetField(
            name='expr',
            value=None,
            special_syntax=True,
        )


class AccessWhenStmt(Nonterm):

    def reduce_WHEN_ParenExpr(self, *kids):
        self.val = qlast.SetField(
            name='condition',
            value=kids[1].val,
            special_syntax=True,
        )

    def reduce_RESET_WHEN(self, *kids):
        self.val = qlast.SetField(
            name='condition',
            value=None,
            special_syntax=True,
        )


commands_block(
    'AlterAccessPolicy',
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    DropAnnotationValueStmt,
    RenameStmt,
    AccessPermStmt,
    AccessUsingStmt,
    AccessWhenStmt,
    SetFieldStmt,
    ResetFieldStmt,
    opt=False
)


class AlterAccessPolicyStmt(Nonterm):
    def reduce_AlterAccessPolicy(self, *kids):
        r"""%reduce \
            ALTER ACCESS POLICY UnqualifiedPointerName \
            AlterAccessPolicyCommandsBlock \
        """
        self.val = qlast.AlterAccessPolicy(
            name=kids[3].val,
            commands=kids[4].val,
        )


class DropAccessPolicyStmt(Nonterm):
    def reduce_DropAccessPolicy(self, *kids):
        r"""%reduce DROP ACCESS POLICY UnqualifiedPointerName"""
        self.val = qlast.DropAccessPolicy(
            name=kids[3].val
        )


#
# CREATE TRIGGER
#

commands_block(
    'CreateTrigger',
    CreateAnnotationValueStmt,
    SetFieldStmt,
)


class CreateTriggerStmt(Nonterm):
    def reduce_CreateTrigger(self, *kids):
        """%reduce
            CREATE TRIGGER UnqualifiedPointerName
            TriggerTiming TriggerKindList
            FOR TriggerScope
            OptWhenBlock
            DO ParenExpr
            OptCreateTriggerCommandsBlock
        """
        _, _, name, timing, kinds, _, scope, when, _, expr, commands = kids
        self.val = qlast.CreateTrigger(
            name=name.val,
            timing=timing.val,
            kinds=kinds.val,
            scope=scope.val,
            expr=expr.val,
            condition=when.val,
            commands=commands.val,
        )


# TODO: commands to change timing/kind/scope?
commands_block(
    'AlterTrigger',
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    DropAnnotationValueStmt,
    RenameStmt,
    UsingStmt,
    AccessWhenStmt,
    SetFieldStmt,
    ResetFieldStmt,
    opt=False
)


class AlterTriggerStmt(Nonterm):
    def reduce_AlterTrigger(self, *kids):
        r"""%reduce \
            ALTER TRIGGER UnqualifiedPointerName \
            AlterTriggerCommandsBlock \
        """
        _, _, name, commands = kids
        self.val = qlast.AlterTrigger(
            name=name.val,
            commands=commands.val,
        )


class DropTriggerStmt(Nonterm):
    def reduce_DropTrigger(self, *kids):
        r"""%reduce DROP TRIGGER UnqualifiedPointerName"""
        _, _, name = kids
        self.val = qlast.DropTrigger(
            name=name.val
        )


#
# CREATE TYPE
#

commands_block(
    'CreateObjectType',
    SetFieldStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    CreateConcretePropertyStmt,
    AlterConcretePropertyStmt,
    CreateConcreteLinkStmt,
    AlterConcreteLinkStmt,
    CreateConcreteConstraintStmt,
    AlterConcreteConstraintStmt,
    CreateConcreteIndexStmt,
    AlterConcreteIndexStmt,
    CreateAccessPolicyStmt,
    AlterAccessPolicyStmt,
    CreateTriggerStmt,
    AlterTriggerStmt,
)


class CreateObjectTypeStmt(Nonterm):
    def reduce_CreateAbstractObjectTypeStmt(self, *kids):
        r"""%reduce \
            CREATE ABSTRACT TYPE NodeName \
            OptExtendingSimple OptCreateObjectTypeCommandsBlock \
        """
        _, _, _, name, bases, commands = kids
        self.val = qlast.CreateObjectType(
            name=name.val,
            bases=bases.val,
            abstract=True,
            commands=commands.val,
        )

    def reduce_CreateRegularObjectTypeStmt(self, *kids):
        r"""%reduce \
            CREATE TYPE NodeName \
            OptExtendingSimple OptCreateObjectTypeCommandsBlock \
        """
        _, _, name, bases, commands = kids
        self.val = qlast.CreateObjectType(
            name=name.val,
            bases=bases.val,
            abstract=False,
            commands=commands.val,
        )


#
# ALTER TYPE
#

commands_block(
    'AlterObjectType',
    RenameStmt,
    SetFieldStmt,
    ResetFieldStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    DropAnnotationValueStmt,
    AlterSimpleExtending,
    CreateConcretePropertyStmt,
    AlterConcretePropertyStmt,
    DropConcretePropertyStmt,
    CreateConcreteLinkStmt,
    AlterConcreteLinkStmt,
    DropConcreteLinkStmt,
    CreateConcreteConstraintStmt,
    AlterConcreteConstraintStmt,
    DropConcreteConstraintStmt,
    CreateConcreteIndexStmt,
    AlterConcreteIndexStmt,
    DropConcreteIndexStmt,
    CreateAccessPolicyStmt,
    AlterAccessPolicyStmt,
    DropAccessPolicyStmt,
    CreateTriggerStmt,
    AlterTriggerStmt,
    DropTriggerStmt,
    opt=False
)


class AlterObjectTypeStmt(Nonterm):
    def reduce_AlterObjectTypeStmt(self, *kids):
        r"""%reduce \
            ALTER TYPE NodeName \
            AlterObjectTypeCommandsBlock \
        """
        self.val = qlast.AlterObjectType(
            name=kids[2].val,
            commands=kids[3].val
        )


#
# DROP TYPE
#

commands_block(
    'DropObjectType',
    DropConcretePropertyStmt,
    DropConcreteLinkStmt,
    DropConcreteConstraintStmt,
    DropConcreteIndexStmt
)


class DropObjectTypeStmt(Nonterm):
    def reduce_DropObjectType(self, *kids):
        r"""%reduce \
            DROP TYPE \
            NodeName OptDropObjectTypeCommandsBlock \
        """
        self.val = qlast.DropObjectType(
            name=kids[2].val,
            commands=kids[3].val
        )


#
# CREATE ALIAS
#

commands_block(
    'CreateAlias',
    UsingStmt,
    SetFieldStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    opt=False
)


class CreateAliasStmt(Nonterm):
    def reduce_CreateAliasShortStmt(self, *kids):
        r"""%reduce
            CREATE ALIAS NodeName ASSIGN GenExpr
        """
        self.val = qlast.CreateAlias(
            name=kids[2].val,
            commands=[
                qlast.SetField(
                    name='expr',
                    value=kids[4].val,
                    special_syntax=True,
                    span=self.span,
                )
            ]
        )

    def reduce_CreateAliasRegularStmt(self, *kids):
        r"""%reduce
            CREATE ALIAS NodeName
            CreateAliasCommandsBlock
        """
        self.val = qlast.CreateAlias(
            name=kids[2].val,
            commands=kids[3].val,
        )


#
# ALTER ALIAS
#

commands_block(
    'AlterAlias',
    UsingStmt,
    RenameStmt,
    SetFieldStmt,
    ResetFieldStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    DropAnnotationValueStmt,
    opt=False
)


class AlterAliasStmt(Nonterm):
    def reduce_AlterAliasStmt(self, *kids):
        r"""%reduce
            ALTER ALIAS NodeName
            AlterAliasCommandsBlock
        """
        self.val = qlast.AlterAlias(
            name=kids[2].val,
            commands=kids[3].val
        )


#
# DROP ALIAS
#

class DropAliasStmt(Nonterm):
    def reduce_DropAlias(self, *kids):
        r"""%reduce
            DROP ALIAS NodeName
        """
        self.val = qlast.DropAlias(
            name=kids[2].val,
        )


#
# CREATE MODULE
#
class CreateModuleStmt(Nonterm):
    def reduce_CREATE_MODULE_ModuleName_OptIfNotExists_OptCreateCommandsBlock(
        self, *kids
    ):
        self.val = qlast.CreateModule(
            name=qlast.ObjectRef(
                module=None, name='::'.join(kids[2].val), span=kids[2].span
            ),
            create_if_not_exists=kids[3].val,
            commands=kids[4].val
        )


#
# ALTER MODULE
#
class AlterModuleStmt(Nonterm):
    def reduce_ALTER_MODULE_ModuleName_AlterCommandsBlock(self, *kids):
        self.val = qlast.AlterModule(
            name=qlast.ObjectRef(
                module=None, name='::'.join(kids[2].val), span=kids[2].span
            ),
            commands=kids[3].val
        )


#
# DROP MODULE
#
class DropModuleStmt(Nonterm):
    def reduce_DROP_MODULE_ModuleName(self, *kids):
        self.val = qlast.DropModule(
            name=qlast.ObjectRef(
                module=None, name='::'.join(kids[2].val), span=kids[2].span
            )
        )


#
# CREATE FUNCTION
#


commands_block(
    'CreateFunction',
    commondl.FromFunction,
    SetFieldStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    opt=False
)


class CreateFunctionStmt(Nonterm, commondl.ProcessFunctionBlockMixin):
    def reduce_CreateFunction(self, *kids):
        r"""%reduce CREATE FUNCTION NodeName CreateFunctionArgs \
                FunctionResult CreateFunctionCommandsBlock
        """
        self.val = qlast.CreateFunction(
            name=kids[2].val,
            params=kids[3].val,
            returning=kids[4].val.result_type,
            returning_typemod=kids[4].val.type_qualifier,
            **self._process_function_body(kids[5])
        )


class DropFunctionStmt(Nonterm):
    def reduce_DropFunction(self, *kids):
        r"""%reduce DROP FUNCTION NodeName CreateFunctionArgs"""
        self.val = qlast.DropFunction(
            name=kids[2].val,
            params=kids[3].val)


#
# ALTER FUNCTION
#

commands_block(
    'AlterFunction',
    commondl.FromFunction,
    SetFieldStmt,
    ResetFieldStmt,
    RenameStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    DropAnnotationValueStmt,
    opt=False
)


class AlterFunctionStmt(Nonterm, commondl.ProcessFunctionBlockMixin):
    def reduce_AlterFunctionStmt(self, *kids):
        """%reduce
           ALTER FUNCTION NodeName CreateFunctionArgs
           AlterFunctionCommandsBlock
        """
        self.val = qlast.AlterFunction(
            name=kids[2].val,
            params=kids[3].val,
            **self._process_function_body(kids[4], optional_using=True)
        )


#
# CREATE OPERATOR
#

class OperatorKind(Nonterm):

    def reduce_INFIX(self, *kids):
        self.val = qltypes.OperatorKind.Infix

    def reduce_POSTFIX(self, *kids):
        self.val = qltypes.OperatorKind.Postfix

    def reduce_PREFIX(self, *kids):
        self.val = qltypes.OperatorKind.Prefix

    def reduce_TERNARY(self, *kids):
        self.val = qltypes.OperatorKind.Ternary


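# Matches the string given to USING SQL OPERATOR / USING SQL FUNCTION:
# a name (everything up to the first "("), optionally followed by a
# parenthesized, comma-separated list of dotted type names.
# Group 1 is the name; group 2, if present, is the raw operand list.
SQL_OP_RE = r"([^(]+)(?:\(([\w\.]*(?:,\s*[\w\.]*)*)\))?"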


class OperatorCode(Nonterm):

    def reduce_USING_Identifier_OPERATOR_BaseStringConstant(self, *kids):
        lang = commondl._parse_language(kids[1])
        if lang != qlast.Language.SQL:
            raise EdgeQLSyntaxError(
                f'{lang} language is not supported in USING OPERATOR clause',
                span=kids[1].span) from None

        m = re.match(SQL_OP_RE, kids[3].val.value)
        if not m:
            raise EdgeQLSyntaxError(
                'invalid syntax for USING OPERATOR clause',
                span=kids[3].span) from None

        sql_operator = (m.group(1),)
        if m.group(2):
            sql_operator += tuple(op.strip() for op in m.group(2).split(","))

        self.val = qlast.OperatorCode(
            language=lang, from_operator=sql_operator)

    def reduce_USING_Identifier_FUNCTION_BaseStringConstant(self, *kids):
        lang = commondl._parse_language(kids[1])
        if lang != qlast.Language.SQL:
            raise EdgeQLSyntaxError(
                f'{lang} language is not supported in USING FUNCTION clause',
                span=kids[1].span) from None

        m = re.match(SQL_OP_RE, kids[3].val.value)
        if not m:
            raise EdgeQLSyntaxError(
                'invalid syntax for USING FUNCTION clause',
                span=kids[3].span) from None

        sql_function = (m.group(1),)
        if m.group(2):
            sql_function += tuple(op.strip() for op in m.group(2).split(','))

        self.val = qlast.OperatorCode(
            language=lang, from_function=sql_function)

    def reduce_USING_Identifier_BaseStringConstant(self, *kids):
        lang = commondl._parse_language(kids[1])
        if lang != qlast.Language.SQL:
            raise EdgeQLSyntaxError(
                f'{lang} language is not supported in USING clause',
                span=kids[1].span) from None

        self.val = qlast.OperatorCode(language=lang,
                                      code=kids[2].val.value)

    def reduce_USING_Identifier_EXPRESSION(self, *kids):
        lang = commondl._parse_language(kids[1])
        if lang != qlast.Language.SQL:
            raise EdgeQLSyntaxError(
                f'{lang} language is not supported in USING clause',
                span=kids[1].span) from None

        self.val = qlast.OperatorCode(language=lang)


commands_block(
    'CreateOperator',
    SetFieldStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    OperatorCode,
    opt=False
)


class OptCreateOperatorCommandsBlock(Nonterm):

    @parsing.inline(0)
    def reduce_CreateOperatorCommandsBlock(self, *kids):
        pass

    def reduce_empty(self, *kids):
        self.val = []


class CreateOperatorStmt(Nonterm):

    def reduce_CreateOperatorStmt(self, *kids):
        r"""%reduce
            CREATE OperatorKind OPERATOR NodeName CreateFunctionArgs
            FunctionResult CreateOperatorCommandsBlock
        """
        self.val = qlast.CreateOperator(
            kind=kids[1].val,
            name=kids[3].val,
            params=kids[4].val,
            returning_typemod=kids[5].val.type_qualifier,
            returning=kids[5].val.result_type,
            **self._process_operator_body(kids[6])
        )

    def reduce_CreateAbstractOperatorStmt(self, *kids):
        r"""%reduce
            CREATE ABSTRACT OperatorKind OPERATOR NodeName CreateFunctionArgs
            FunctionResult OptCreateOperatorCommandsBlock
        """
        self.val = qlast.CreateOperator(
            kind=kids[2].val,
            name=kids[4].val,
            params=kids[5].val,
            returning_typemod=kids[6].val.type_qualifier,
            returning=kids[6].val.result_type,
            abstract=True,
            **self._process_operator_body(kids[7], abstract=True)
        )

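    def _process_operator_body(self, block, abstract: bool = False):
        """Split a command block into USING clauses and other commands.

        Returns keyword arguments for qlast.CreateOperator: 'code'
        (only for non-abstract operators) and 'commands' (only if the
        block contains any non-USING commands).
        """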
        props: dict[str, typing.Any] = {}

        commands = []
        from_operator = None
        from_function = None
        from_expr = False
        code = None

        for node in block.val:
            if isinstance(node, qlast.OperatorCode):
                if abstract:
                    raise errors.InvalidOperatorDefinitionError(
                        'unexpected USING clause in abstract '
                        'operator definition',
                        span=node.span,
                    )

                if node.from_function:
                    if from_function is not None:
                        raise errors.InvalidOperatorDefinitionError(
                            'more than one USING FUNCTION clause',
                            span=node.span)
                    from_function = node.from_function

                elif node.from_operator:
                    if from_operator is not None:
                        raise errors.InvalidOperatorDefinitionError(
                            'more than one USING OPERATOR clause',
                            span=node.span)
                    from_operator = node.from_operator

                elif node.code:
                    if code is not None:
                        raise errors.InvalidOperatorDefinitionError(
                            'more than one USING <code> clause',
                            span=node.span)
                    code = node.code

                else:
                    # USING SQL EXPRESSION
                    from_expr = True
            else:
                commands.append(node)

        if not abstract:
            if (code is None and from_operator is None
                    and from_function is None
                    and not from_expr):
                raise errors.InvalidOperatorDefinitionError(
                    'CREATE OPERATOR requires at least one USING clause',
                    span=block.span)

            else:
                if from_expr and (from_operator or from_function or code):
                    raise errors.InvalidOperatorDefinitionError(
                        'USING SQL EXPRESSION is mutually exclusive with '
                        'other USING variants',
                        span=block.span)

                props['code'] = qlast.OperatorCode(
                    language=qlast.Language.SQL,
                    from_function=from_function,
                    from_operator=from_operator,
                    from_expr=from_expr,
                    code=code,
                    span=self.span,
                )

        if commands:
            props['commands'] = commands

        return props


#
# ALTER OPERATOR
#

commands_block(
    'AlterOperator',
    SetFieldStmt,
    ResetFieldStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    DropAnnotationValueStmt,
    opt=False
)


class AlterOperatorStmt(Nonterm):
    def reduce_AlterOperatorStmt(self, *kids):
        """%reduce
           ALTER OperatorKind OPERATOR NodeName CreateFunctionArgs
           AlterOperatorCommandsBlock
        """
        self.val = qlast.AlterOperator(
            kind=kids[1].val,
            name=kids[3].val,
            params=kids[4].val,
            commands=kids[5].val
        )


#
# DROP OPERATOR
#

class DropOperatorStmt(Nonterm):
    def reduce_DropOperator(self, *kids):
        """%reduce
           DROP OperatorKind OPERATOR NodeName CreateFunctionArgs
        """
        self.val = qlast.DropOperator(
            kind=kids[1].val,
            name=kids[3].val,
            params=kids[4].val,
        )


#
# CREATE CAST
#


class CastUseValue(typing.NamedTuple):

    use: str


class CastAllowedUse(Nonterm):

    def reduce_ALLOW_IMPLICIT(self, *kids):
        self.val = CastUseValue(use=kids[1].val.upper())

    def reduce_ALLOW_ASSIGNMENT(self, *kids):
        self.val = CastUseValue(use=kids[1].val.upper())


class CastCode(Nonterm):

    def reduce_USING_Identifier_FUNCTION_BaseStringConstant(self, *kids):
        lang = commondl._parse_language(kids[1])
        if lang not in {qlast.Language.SQL, qlast.Language.EdgeQL}:
            raise EdgeQLSyntaxError(
                f'{lang} language is not supported in USING FUNCTION clause',
                span=kids[1].span) from None

        self.val = qlast.CastCode(language=lang,
                                  from_function=kids[3].val.value)

    def reduce_USING_Identifier_BaseStringConstant(self, *kids):
        lang = commondl._parse_language(kids[1])
        if lang not in {qlast.Language.SQL, qlast.Language.EdgeQL}:
            raise EdgeQLSyntaxError(
                f'{lang} language is not supported in USING clause',
                span=kids[1].span) from None

        self.val = qlast.CastCode(language=lang,
                                  code=kids[2].val.value)

    def reduce_USING_Identifier_CAST(self, *kids):
        lang = commondl._parse_language(kids[1])
        if lang != qlast.Language.SQL:
            raise EdgeQLSyntaxError(
                f'{lang} language is not supported in USING CAST clause',
                span=kids[1].span) from None

        self.val = qlast.CastCode(language=lang, from_cast=True)

    def reduce_USING_Identifier_EXPRESSION(self, *kids):
        lang = commondl._parse_language(kids[1])
        if lang != qlast.Language.SQL:
            raise EdgeQLSyntaxError(
                f'{lang} language is not supported in USING EXPRESSION clause',
                span=kids[1].span) from None

        self.val = qlast.CastCode(language=lang)


commands_block(
    'CreateCast',
    SetFieldStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    CastCode,
    CastAllowedUse,
    opt=False
)


class CreateCastStmt(Nonterm):

    def reduce_CreateCastStmt(self, *kids):
        r"""%reduce
            CREATE CAST FROM TypeName TO TypeName
            CreateCastCommandsBlock
        """
        self.val = qlast.CreateCast(
            from_type=kids[3].val,
            to_type=kids[5].val,
            **self._process_cast_body(kids[6])
        )

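    def _process_cast_body(self, block):
        """Split a command block into USING/ALLOW clauses and other commands.

        Returns keyword arguments for qlast.CreateCast: 'code',
        'allow_implicit', 'allow_assignment', and 'commands' (only if
        the block contains any other commands).
        """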
        props = {}

        commands = []
        from_function = None
        from_expr = False
        from_cast = False
        allow_implicit = False
        allow_assignment = False
        code = None

        for node in block.val:
            if isinstance(node, qlast.CastCode):
                if node.from_function:
                    if from_function is not None:
                        raise EdgeQLSyntaxError(
                            'more than one USING FUNCTION clause',
                            span=node.span)
                    from_function = node.from_function

                elif node.code:
                    if code is not None:
                        raise EdgeQLSyntaxError(
                            'more than one USING <code> clause',
                            span=node.span)
                    code = node.code

                elif node.from_cast:
                    # USING SQL CAST

                    if from_cast:
                        raise EdgeQLSyntaxError(
                            'more than one USING CAST clause',
                            span=node.span)

                    from_cast = True

                else:
                    # USING SQL EXPRESSION

                    if from_expr:
                        raise EdgeQLSyntaxError(
                            'more than one USING EXPRESSION clause',
                            span=node.span)

                    from_expr = True

            elif isinstance(node, CastUseValue):

                if node.use == 'IMPLICIT':
                    allow_implicit = True
                elif node.use == 'ASSIGNMENT':
                    allow_assignment = True
                else:
                    raise EdgeQLSyntaxError(
                        'unexpected ALLOW clause',
                        span=block.span)

            else:
                commands.append(node)

        if (code is None and from_function is None
                and not from_expr and not from_cast):
            raise EdgeQLSyntaxError(
                'CREATE CAST requires at least one USING clause',
                span=block.span)

        else:
            if from_expr and (from_function or code or from_cast):
                raise EdgeQLSyntaxError(
                    'USING SQL EXPRESSION is mutually exclusive with other '
                    'USING variants',
                    span=block.span)

            if from_cast and (from_function or code or from_expr):
                raise EdgeQLSyntaxError(
                    'USING SQL CAST is mutually exclusive with other '
                    'USING variants',
                    span=block.span)

            props['code'] = qlast.CastCode(
                language=qlast.Language.SQL,
                from_function=from_function,
                from_expr=from_expr,
                from_cast=from_cast,
                code=code,
                span=self.span,
            )

            props['allow_implicit'] = allow_implicit
            props['allow_assignment'] = allow_assignment

        if commands:
            props['commands'] = commands

        return props


#
# ALTER CAST
#

commands_block(
    'AlterCast',
    SetFieldStmt,
    ResetFieldStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    DropAnnotationValueStmt,
    opt=False
)


class AlterCastStmt(Nonterm):
    def reduce_AlterCastStmt(self, *kids):
        """%reduce
           ALTER CAST FROM TypeName TO TypeName
           AlterCastCommandsBlock
        """
        self.val = qlast.AlterCast(
            from_type=kids[3].val,
            to_type=kids[5].val,
            commands=kids[6].val,
        )


#
# DROP CAST
#

class DropCastStmt(Nonterm):
    def reduce_DropCastStmt(self, *kids):
        """%reduce
           DROP CAST FROM TypeName TO TypeName
        """
        self.val = qlast.DropCast(
            from_type=kids[3].val,
            to_type=kids[5].val,
        )

#
# CREATE GLOBAL
#


commands_block(
    'CreateGlobal',
    UsingStmt,
    SetFieldStmt,
    CreateAnnotationValueStmt,
)


class CreateGlobalStmt(Nonterm):
    def reduce_CreateRegularGlobal(self, *kids):
        """%reduce
            CREATE OptPtrQuals GLOBAL NodeName
            ARROW FullTypeExpr
            OptCreateGlobalCommandsBlock
        """
        self.val = qlast.CreateGlobal(
            name=kids[3].val,
            is_required=kids[1].val.required,
            cardinality=kids[1].val.cardinality,
            target=kids[5].val,
            commands=kids[6].val,
        )

    def reduce_CreateRegularGlobalNew(self, *kids):
        """%reduce
            CREATE OptPtrQuals GLOBAL NodeName
            COLON FullTypeExpr
            OptCreateGlobalCommandsBlock
        """
        self.val = qlast.CreateGlobal(
            name=kids[3].val,
            is_required=kids[1].val.required,
            cardinality=kids[1].val.cardinality,
            target=kids[5].val,
            commands=kids[6].val,
        )

    def reduce_CreateComputableGlobal(self, *kids):
        """%reduce
            CREATE OptPtrQuals GLOBAL NodeName ASSIGN GenExpr
        """
        self.val = qlast.CreateGlobal(
            name=kids[3].val,
            is_required=kids[1].val.required,
            cardinality=kids[1].val.cardinality,
            target=kids[5].val,
        )

    def reduce_CreateComputableGlobalWithUsing(self, *kids):
        """%reduce
            CREATE OptPtrQuals GLOBAL NodeName
            OptCreateConcretePropertyCommandsBlock
        """
        cmds = kids[4].val
        target = None

        for cmd in cmds:
            if isinstance(cmd, qlast.SetField) and cmd.name == 'expr':
                if target is not None:
                    raise EdgeQLSyntaxError(
                        'computed global with more than one expression',
                        span=kids[3].span)
                target = cmd.value

        if target is None:
            raise EdgeQLSyntaxError(
                'computed global without expression',
                span=kids[3].span)

        self.val = qlast.CreateGlobal(
            name=kids[3].val,
            is_required=kids[1].val.required,
            cardinality=kids[1].val.cardinality,
            target=target,
            commands=cmds,
        )


class SetGlobalTypeStmt(Nonterm):

    def reduce_SETTYPE_FullTypeExpr_OptAlterUsingClause(self, *kids):
        self.val = qlast.SetGlobalType(
            value=kids[1].val,
            cast_expr=kids[2].val,
        )

    def reduce_SETTYPE_FullTypeExpr_RESET_TO_DEFAULT(self, *kids):
        self.val = qlast.SetGlobalType(
            value=kids[1].val,
            reset_value=True,
        )

    def reduce_RESET_TYPE(self, *kids):
        self.val = qlast.SetGlobalType(
            value=None,
        )


commands_block(
    'AlterGlobal',
    UsingStmt,
    RenameStmt,
    SetFieldStmt,
    ResetFieldStmt,
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    DropAnnotationValueStmt,
    SetGlobalTypeStmt,
    SetCardinalityStmt,
    SetRequiredStmt,
    opt=False
)


class AlterGlobalStmt(Nonterm):
    def reduce_AlterGlobal(self, *kids):
        r"""%reduce \
            ALTER GLOBAL NodeName \
            AlterGlobalCommandsBlock \
        """
        self.val = qlast.AlterGlobal(
            name=kids[2].val,
            commands=kids[3].val
        )


class DropGlobalStmt(Nonterm):
    def reduce_DropGlobal(self, *kids):
        r"""%reduce DROP GLOBAL NodeName"""
        self.val = qlast.DropGlobal(
            name=kids[2].val
        )


#
# CREATE PERMISSION
#
commands_block(
    'CreatePermission',
    CreateAnnotationValueStmt,
)


class CreatePermissionStmt(Nonterm):
    def reduce_CreatePermission(self, *kids):
        """%reduce
            CREATE PERMISSION NodeName
            OptCreatePermissionCommandsBlock
        """
        _, _, name, commands = kids
        self.val = qlast.CreatePermission(
            name=name.val,
            commands=commands.val,
        )


#
# ALTER PERMISSION
#
commands_block(
    'AlterPermission',
    CreateAnnotationValueStmt,
    AlterAnnotationValueStmt,
    DropAnnotationValueStmt,
    RenameStmt,
    opt=False
)


class AlterPermissionStmt(Nonterm):
    def reduce_AlterPermission(self, *kids):
        r"""%reduce \
            ALTER PERMISSION NodeName \
            AlterPermissionCommandsBlock \
        """
        _, _, name, commands = kids
        self.val = qlast.AlterPermission(
            name=name.val,
            commands=commands.val,
        )


#
# DROP PERMISSION
#
class DropPermissionStmt(Nonterm):
    def reduce_DropPermission(self, *kids):
        r"""%reduce DROP PERMISSION NodeName"""
        _, _, name = kids
        self.val = qlast.DropPermission(
            name=name.val
        )


#
# MIGRATIONS
#


class MigrationStmt(Nonterm):

    @parsing.inline(0)
    def reduce_CreateMigrationStmt(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_AlterMigrationStmt(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_AlterCurrentMigrationStmt(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_StartMigrationStmt(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_AbortMigrationStmt(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_PopulateMigrationStmt(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_CommitMigrationStmt(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_DropMigrationStmt(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_ResetSchemaStmt(self, *kids):
        pass


class MigrationBody(typing.NamedTuple):

    body: qlast.NestedQLBlock
    fields: list[qlast.SetField]


class CreateMigrationBodyBlock(NestedQLBlock):

    @property
    def allowed_fields(self) -> frozenset[str]:
        return frozenset({'message', 'generated_by'})

    @property
    def result(self) -> typing.Any:
        return MigrationBody


commands_block(
    'CreateMigration',
    NestedQLBlockStmt,
    opt=True,
    production_tpl=CreateMigrationBodyBlock,
)


class MigrationNameAndParent(typing.NamedTuple):

    name: typing.Optional[qlast.ObjectRef]
    parent: typing.Optional[qlast.ObjectRef]


class OptMigrationNameParentName(Nonterm):

    def reduce_ShortNodeName_ONTO_ShortNodeName(self, *kids):
        self.val = MigrationNameAndParent(
            name=kids[0].val,
            parent=kids[2].val,
        )

    def reduce_ShortNodeName(self, *kids):
        self.val = MigrationNameAndParent(
            name=kids[0].val,
            parent=None,
        )

    def reduce_empty(self):
        self.val = MigrationNameAndParent(
            name=None,
            parent=None,
        )


class CreateMigrationStmt(Nonterm):

    def reduce_CreateMigration(self, *kids):
        r"""%reduce
            CREATE MIGRATION OptMigrationNameParentName
            OptCreateMigrationCommandsBlock
        """
        self.val = qlast.CreateMigration(
            name=kids[2].val.name,
            parent=kids[2].val.parent,
            body=kids[3].val.body,
            commands=kids[3].val.fields,
        )

    def reduce_CreateAppliedMigration(self, *kids):
        r"""%reduce
            CREATE APPLIED MIGRATION OptMigrationNameParentName
            OptCreateMigrationCommandsBlock
        """
        self.val = qlast.CreateMigration(
            name=kids[3].val.name,
            parent=kids[3].val.parent,
            body=kids[4].val.body,
            metadata_only=True,
            commands=kids[4].val.fields,
        )


class StartMigrationStmt(Nonterm):

    def reduce_StartMigration(self, *kids):
        r"""%reduce START MIGRATION TO SDLCommandBlock"""

        declarations = kids[3].val
        commondl._validate_declarations(declarations)
        self.val = qlast.StartMigration(
            target=qlast.Schema(
                declarations=declarations,
                span=kids[3].span,
            ),
        )

    def reduce_StartMigrationToCommitted(self, *kids):
        r"""%reduce START MIGRATION TO COMMITTED SCHEMA"""
        self.val = qlast.StartMigration(
            target=qlast.CommittedSchema(span=self.span)
        )

    def reduce_StartMigrationRewrite(self, *kids):
        r"""%reduce START MIGRATION REWRITE"""
        self.val = qlast.StartMigrationRewrite()


class PopulateMigrationStmt(Nonterm):

    def reduce_POPULATE_MIGRATION(self, *kids):
        self.val = qlast.PopulateMigration()


class AlterCurrentMigrationStmt(Nonterm):

    def reduce_ALTER_CURRENT_MIGRATION_REJECT_PROPOSED(self, *kids):
        self.val = qlast.AlterCurrentMigrationRejectProposed()


class AbortMigrationStmt(Nonterm):

    def reduce_ABORT_MIGRATION(self, *kids):
        self.val = qlast.AbortMigration()

    def reduce_ABORT_MIGRATION_REWRITE(self, *kids):
        self.val = qlast.AbortMigrationRewrite()


class CommitMigrationStmt(Nonterm):

    def reduce_COMMIT_MIGRATION(self, *kids):
        self.val = qlast.CommitMigration()

    def reduce_COMMIT_MIGRATION_REWRITE(self, *kids):
        self.val = qlast.CommitMigrationRewrite()


commands_block(
    'AlterMigration',
    SetFieldStmt,
    ResetFieldStmt,
    opt=False,
)


class AlterMigrationStmt(Nonterm):
    def reduce_AlterMigration(self, *kids):
        r"""%reduce ALTER MIGRATION NodeName \
                    AlterMigrationCommandsBlock \
        """
        self.val = qlast.AlterMigration(
            name=kids[2].val,
            commands=kids[3].val
        )


class DropMigrationStmt(Nonterm):
    def reduce_DROP_MIGRATION_NodeName(self, *kids):
        self.val = qlast.DropMigration(
            name=kids[2].val,
        )


class ResetSchemaStmt(Nonterm):
    def reduce_ResetSchemaTo(self, *kids):
        r"""%reduce RESET SCHEMA TO NodeName"""
        self.val = qlast.ResetSchema(
            target=kids[3].val,
        )


================================================
FILE: edb/edgeql/parser/grammar/expressions.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

import collections
import typing

from edb.common import parsing, span

from edb.edgeql import ast as qlast
from edb.edgeql import qltypes

from edb import errors

from . import keywords
from . import precedence
from . import tokens

from .precedence import *  # NOQA
from .tokens import *  # NOQA


class Nonterm(parsing.Nonterm, is_internal=True):
    pass


def merge_spans(nodes: typing.Iterable[Nonterm]) -> span.Span:
    return assert_non_null(span.merge_spans(n.span for n in nodes if n.span))


def assert_non_null(span):
    assert span
    return span


class ListNonterm(parsing.ListNonterm, element=None, is_internal=True):
    pass


# We have an annoying split between "simple" ExprStmt and "annoying"
# ExprStmt. The heart of the issue is we want to allow unparenthesized
# statements in places like function arguments, but the trailing
# comma allowed in the BY clause of GROUP conflicts with the
# commas there.
#
# So instead we allow unparenthesized statements as long as they
# aren't GROUP (or a FOR ... IN ... GROUP ...).
class ExprStmt(Nonterm):
    val: qlast.Query

    @parsing.inline(0)
    def reduce_ExprStmtSimple(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_ExprStmtAnnoying(self, *kids):
        pass


class ExprStmtSimple(Nonterm):
    val: qlast.Query

    def reduce_WithBlock_ExprStmtSimpleCore(self, *kids):
        self.val = kids[1].val
        self.val.aliases = kids[0].val.aliases

    @parsing.inline(0)
    def reduce_ExprStmtSimpleCore(self, *kids):
        pass


class ExprStmtAnnoying(Nonterm):
    val: qlast.Query

    def reduce_WithBlock_ExprStmtAnnoyingCore(self, *kids):
        self.val = kids[1].val
        self.val.aliases = kids[0].val.aliases

    @parsing.inline(0)
    def reduce_ExprStmtAnnoyingCore(self, *kids):
        pass


class ExprStmtSimpleCore(Nonterm):
    val: qlast.Query

    def reduce_Select(self, *kids):
        r"%reduce SELECT OptionallyAliasedExpr \
                  OptFilterClause OptSortClause OptSelectLimit"

        offset, limit = kids[4].val

        if offset is not None or limit is not None:
            subj = qlast.SelectQuery(
                result=kids[1].val.expr,
                result_alias=kids[1].val.alias,
                where=kids[2].val,
                orderby=kids[3].val,
                implicit=True,
                span=merge_spans((kids[0], kids[3]))
            )

            self.val = qlast.SelectQuery(
                result=subj,
                offset=offset,
                limit=limit,
            )
        else:
            self.val = qlast.SelectQuery(
                result=kids[1].val.expr,
                result_alias=kids[1].val.alias,
                where=kids[2].val,
                orderby=kids[3].val,
            )

    def reduce_Insert(self, *kids):
        r'%reduce INSERT Expr OptUnlessConflictClause'

        subj = kids[1].val
        unless_conflict = kids[2].val

        if isinstance(subj, qlast.Shape):
            if not subj.expr:
                raise errors.EdgeQLSyntaxError(
                    "insert shape expressions must have a type name",
                    span=subj.span
                )
            subj_path = subj.expr
            shape = subj.elements
        else:
            subj_path = subj
            shape = []

        if isinstance(subj_path, qlast.Path) and \
                len(subj_path.steps) == 1 and \
                isinstance(subj_path.steps[0], qlast.ObjectRef):
            objtype = subj_path.steps[0]
        elif isinstance(subj_path, qlast.IfElse):
            # Insert attempted on something that looks like a conditional
            # expression. Aside from it being an error, it also seems that
            # the intent was to insert something conditionally.
            raise errors.EdgeQLSyntaxError(
                "INSERT only works with object types, not conditional "
                "expressions",
                hint=(
                    "To resolve this try surrounding the INSERT branch of "
                    "the conditional expression with parentheses. This way "
                    "the INSERT will be triggered conditionally in one of "
                    "the branches."
                ),
                span=subj_path.span)
        else:
            raise errors.EdgeQLSyntaxError(
                "INSERT only works with object types, not arbitrary "
                "expressions",
                hint=(
                    "To resolve this try to surround the entire INSERT "
                    "statement with parentheses in order to separate it "
                    "from the rest of the expression."
                ),
                span=subj_path.span)

        self.val = qlast.InsertQuery(
            subject=objtype,
            shape=shape,
            unless_conflict=unless_conflict,
        )

    def reduce_Update(self, *kids):
        "%reduce UPDATE Expr OptFilterClause SET Shape"
        self.val = qlast.UpdateQuery(
            subject=kids[1].val,
            where=kids[2].val,
            shape=kids[4].val,
        )

    def reduce_Delete(self, *kids):
        r"%reduce DELETE Expr \
                  OptFilterClause OptSortClause OptSelectLimit"
        self.val = qlast.DeleteQuery(
            subject=kids[1].val,
            where=kids[2].val,
            orderby=kids[3].val,
            offset=kids[4].val[0],
            limit=kids[4].val[1],
        )

    def reduce_ForIn(self, *kids):
        r"%reduce FOR OptionalOptional Identifier IN AtomicExpr UNION Expr"
        _, optional, iterator_alias, _, iterator, _, body = kids
        self.val = qlast.ForQuery(
            optional=optional.val,
            iterator_alias=iterator_alias.val,
            iterator=iterator.val,
            result=body.val,
        )

    def reduce_ForInStmt(self, *kids):
        r"%reduce FOR OptionalOptional Identifier IN AtomicExpr ExprStmtSimple"
        _, optional, iterator_alias, _, iterator, body = kids
        self.val = qlast.ForQuery(
            has_union=False,
            optional=optional.val,
            iterator_alias=iterator_alias.val,
            iterator=iterator.val,
            result=body.val,
        )

    def reduce_InternalGroup(self, *kids):
        r"%reduce FOR GROUP OptionallyAliasedExpr \
                  UsingClause \
                  ByClause \
                  IN Identifier OptGroupingAlias \
                  UNION OptionallyAliasedExpr \
                  OptFilterClause OptSortClause \
        "
        self.val = qlast.InternalGroupQuery(
            subject=kids[2].val.expr,
            subject_alias=kids[2].val.alias,
            using=kids[3].val,
            by=kids[4].val,
            group_alias=kids[6].val,
            grouping_alias=kids[7].val,
            result_alias=kids[9].val.alias,
            result=kids[9].val.expr,
            where=kids[10].val,
            orderby=kids[11].val,
        )
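

# Illustrative sketch (not part of the grammar): the SELECT reduction in
# ExprStmtSimpleCore wraps the query in an outer select whenever OFFSET or
# LIMIT is present, so FILTER and ORDER BY stay on an inner, implicit
# select. The snippet below mimics that shape with a hypothetical
# `SelectSketch` dataclass and `build_select` helper; neither name exists
# in edb, and the real node is qlast.SelectQuery.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class SelectSketch:
    """Hypothetical stand-in for qlast.SelectQuery (illustration only)."""
    result: Any
    where: Optional[Any] = None
    orderby: Optional[Any] = None
    offset: Optional[Any] = None
    limit: Optional[Any] = None
    implicit: bool = False


def build_select(result, where, orderby, offset, limit) -> SelectSketch:
    # With OFFSET/LIMIT present, FILTER and ORDER BY stay on an inner,
    # implicit select; the outer select only carries OFFSET/LIMIT.
    if offset is not None or limit is not None:
        inner = SelectSketch(
            result=result, where=where, orderby=orderby, implicit=True)
        return SelectSketch(result=inner, offset=offset, limit=limit)
    # Otherwise a single select node is enough.
    return SelectSketch(result=result, where=where, orderby=orderby)


paged = build_select("User", "age > 18", "name", None, 10)
plain = build_select("User", "age > 18", "name", None, None)
```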


class ExprStmtAnnoyingCore(Nonterm):
    @parsing.inline(0)
    def reduce_AnnoyingFor(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_SimpleGroup(self, *kids):
        pass


# A "generalized expression" that can be either an expression or
# *most* unparenthesized statements.
#
# (Note that a number of places that are *approximately* using this
# instead need to spell it out more explicitly because it doesn't
# exactly fit.)
class GenExpr(Nonterm):
    val: qlast.Expr

    @parsing.inline(0)
    def reduce_Expr(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_ExprStmtSimpleCore(self, *kids):
        pass


class AliasedExpr(Nonterm):
    val: qlast.AliasedExpr

    def reduce_Identifier_ASSIGN_Expr(self, *kids):
        self.val = qlast.AliasedExpr(alias=kids[0].val, expr=kids[2].val)


# NOTE: This is intentionally not an AST node, since this structure never
# makes it to the actual AST and exists solely for parser convenience.
AliasedExprSpec = collections.namedtuple(
    'AliasedExprSpec', ['alias', 'expr'], module=__name__)


class OptionallyAliasedExpr(Nonterm):
    val: AliasedExprSpec

    def reduce_AliasedExpr(self, *kids):
        val = kids[0].val
        self.val = AliasedExprSpec(alias=val.alias, expr=val.expr)

    def reduce_Expr(self, *kids):
        self.val = AliasedExprSpec(alias=None, expr=kids[0].val)


class AliasedExprList(ListNonterm, element=AliasedExpr,
                      separator=tokens.T_COMMA, allow_trailing_separator=True):
    val: list[qlast.AliasedExpr]


class GroupingIdent(Nonterm):
    val: qlast.GroupingAtom

    def reduce_Identifier(self, *kids):
        self.val = qlast.ObjectRef(name=kids[0].val)

    def reduce_DOT_Identifier(self, *kids):
        self.val = qlast.Path(
            partial=True,
            steps=[
                qlast.Ptr(
                    name=kids[1].val,
                    span=kids[1].span,
                )
            ],
        )

    def reduce_AT_Identifier(self, *kids):
        self.val = qlast.Path(
            partial=True,
            steps=[
                qlast.Ptr(
                    name=kids[1].val,
                    type='property',
                    span=kids[1].span,
                )
            ]
        )


class GroupingIdentList(ListNonterm, element=GroupingIdent,
                        separator=tokens.T_COMMA):
    val: list[qlast.GroupingAtom]


class GroupingAtom(Nonterm):
    val: qlast.GroupingAtom

    @parsing.inline(0)
    def reduce_GroupingIdent(self, *kids):
        pass

    def reduce_LPAREN_GroupingIdentList_RPAREN(self, *kids):
        self.val = qlast.GroupingIdentList(elements=kids[1].val)


class GroupingAtomList(
        ListNonterm, element=GroupingAtom, separator=tokens.T_COMMA,
        allow_trailing_separator=True):
    val: list[qlast.GroupingAtom]


class GroupingElement(Nonterm):
    val: qlast.GroupingElement

    def reduce_GroupingAtom(self, *kids):
        self.val = qlast.GroupingSimple(element=kids[0].val)

    def reduce_LBRACE_GroupingElementList_RBRACE(self, *kids):
        self.val = qlast.GroupingSets(sets=kids[1].val)

    def reduce_ROLLUP_LPAREN_GroupingAtomList_RPAREN(self, *kids):
        self.val = qlast.GroupingOperation(oper='rollup', elements=kids[2].val)

    def reduce_CUBE_LPAREN_GroupingAtomList_RPAREN(self, *kids):
        self.val = qlast.GroupingOperation(oper='cube', elements=kids[2].val)


class GroupingElementList(
        ListNonterm, element=GroupingElement, separator=tokens.T_COMMA,
        allow_trailing_separator=True):
    val: list[qlast.GroupingElement]


class OptionalOptional(Nonterm):
    val: bool

    def reduce_OPTIONAL(self, *kids):
        self.val = True

    def reduce_empty(self, *kids):
        self.val = False


class AnnoyingFor(Nonterm):
    val: qlast.ForQuery

    def reduce_ForInStmt(self, *kids):
        r"%reduce FOR OptionalOptional Identifier IN AtomicExpr \
                  ExprStmtAnnoying"
        _, optional, iterator_alias, _, iterator, body = kids
        self.val = qlast.ForQuery(
            has_union=False,
            optional=optional.val,
            iterator_alias=iterator_alias.val,
            iterator=iterator.val,
            result=body.val,
        )


class ByClause(Nonterm):
    val: list[qlast.GroupingElement]

    @parsing.inline(1)
    def reduce_BY_GroupingElementList(self, *kids):
        pass


class UsingClause(Nonterm):
    val: list[qlast.AliasedExpr]

    @parsing.inline(1)
    def reduce_USING_AliasedExprList(self, *kids):
        pass


class OptUsingClause(Nonterm):
    val: list[qlast.AliasedExpr]

    @parsing.inline(0)
    def reduce_UsingClause(self, *kids):
        pass

    def reduce_empty(self, *kids):
        self.val = None


class SimpleGroup(Nonterm):
    val: qlast.GroupQuery

    def reduce_Group(self, *kids):
        r"%reduce GROUP OptionallyAliasedExpr \
                  OptUsingClause \
                  ByClause"
        self.val = qlast.GroupQuery(
            subject=kids[1].val.expr,
            subject_alias=kids[1].val.alias,
            using=kids[2].val,
            by=kids[3].val,
        )


class OptGroupingAlias(Nonterm):
    val: typing.Optional[str]

    @parsing.inline(1)
    def reduce_COMMA_Identifier(self, *kids):
        pass

    def reduce_empty(self, *kids):
        self.val = None


FunctionResultData = collections.namedtuple(
    'FunctionResultData',
    ['type_qualifier', 'result_type'],
    module=__name__
)


class FunctionResult(Nonterm):
    def reduce_ARROW_OptTypeQualifier_FunctionType(
        self, _, type_qualifier, result_type
    ):
        self.val = FunctionResultData(
            type_qualifier=type_qualifier.val,
            result_type=result_type.val,
        )


WithBlockData = collections.namedtuple(
    'WithBlockData', ['aliases'], module=__name__)


class WithBlock(Nonterm):
    def reduce_WITH_WithDeclList(self, *kids):
        aliases = []
        for w in kids[1].val:
            aliases.append(w)
        self.val = WithBlockData(aliases=aliases)


class AliasDecl(Nonterm):
    def reduce_MODULE_ModuleName(self, *kids):
        self.val = qlast.ModuleAliasDecl(
            module='::'.join(kids[1].val)
        )

    def reduce_Identifier_AS_MODULE_ModuleName(self, *kids):
        self.val = qlast.ModuleAliasDecl(
            alias=kids[0].val,
            module='::'.join(kids[3].val)
        )

    @parsing.inline(0)
    def reduce_AliasedExpr(self, *kids):
        pass

    def reduce_Identifier_ASSIGN_ExprStmtSimple(self, *kids):
        self.val = qlast.AliasedExpr(alias=kids[0].val, expr=kids[2].val)


class WithDecl(Nonterm):
    @parsing.inline(0)
    def reduce_AliasDecl(self, *kids):
        pass


class WithDeclList(ListNonterm, element=WithDecl,
                   separator=tokens.T_COMMA, allow_trailing_separator=True):
    pass


class Shape(Nonterm):
    def reduce_LBRACE_RBRACE(self, *kids):
        self.val = []

    @parsing.inline(1)
    def reduce_LBRACE_ShapeElementList_RBRACE(self, *kids):
        pass


class FreeShape(Nonterm):
    def reduce_LBRACE_FreeComputableShapePointerList_RBRACE(self, *kids):
        self.val = qlast.Shape(elements=kids[1].val)


class OptAnySubShape(Nonterm):
    @parsing.inline(1)
    def reduce_COLON_Shape(self, *_):
        pass

    def reduce_empty(self, *kids):
        self.val = []


class ShapeElement(Nonterm):
    def reduce_ShapeElementWithSubShape(self, *kids):
        r"""%reduce ShapePointer \
             OptAnySubShape OptFilterClause OptSortClause OptSelectLimit \
        """
        self.val = kids[0].val
        self.val.elements = kids[1].val
        self.val.where = kids[2].val
        self.val.orderby = kids[3].val
        self.val.offset = kids[4].val[0]
        self.val.limit = kids[4].val[1]

    @parsing.inline(0)
    def reduce_ComputableShapePointer(self, *kids):
        pass


class ShapeElementList(ListNonterm, element=ShapeElement,
                       separator=tokens.T_COMMA, allow_trailing_separator=True):
    pass


class SimpleShapePath(Nonterm):

    def reduce_PathStepName(self, *kids):
        from edb.schema import pointers as s_pointers

        steps = [
            qlast.Ptr(
                name=kids[0].val.name,
                direction=s_pointers.PointerDirection.Outbound,
                span=kids[0].span,
            ),
        ]

        self.val = qlast.Path(steps=steps)

    def reduce_AT_PathNodeName(self, *kids):
        self.val = qlast.Path(
            steps=[
                qlast.Ptr(
                    name=kids[1].val.name,
                    type='property',
                    span=kids[1].span,
                )
            ]
        )


class SimpleShapePointer(Nonterm):

    def reduce_SimpleShapePath(self, *kids):
        self.val = qlast.ShapeElement(
            expr=kids[0].val
        )


# Shape pointers in free shapes are not allowed to be link
# properties. This is because we need to be able to distinguish
# free shapes from set literals with only one token of lookahead
# (since this is an LR(1) parser) and seeing the := after @ident would
# require two tokens of lookahead.
class FreeSimpleShapePointer(Nonterm):

    def reduce_FreeStepName(self, *kids):
        from edb.schema import pointers as s_pointers

        steps = [
            qlast.Ptr(
                name=kids[0].val.name,
                direction=s_pointers.PointerDirection.Outbound,
                span=kids[0].span,
            ),
        ]

        self.val = qlast.ShapeElement(
            expr=qlast.Path(steps=steps, span=self.span)
        )


class ShapePath(Nonterm):
    # A form of Path appearing as an element in shapes.
    #
    # one-of:
    #   __type__
    #   link
    #   @prop
    #   [IS ObjectType].link
    #   [IS Link]@prop - currently not supported
    #    (see Splat production for possible syntaxes)

    def reduce_PathStepName_OptTypeIntersection(self, *kids):
        from edb.schema import pointers as s_pointers

        steps = [
            qlast.Ptr(
                name=kids[0].val.name,
                direction=s_pointers.PointerDirection.Outbound,
                span=kids[0].span,
            ),
        ]

        if kids[1].val is not None:
            steps.append(kids[1].val)

        self.val = qlast.Path(steps=steps)

    @parsing.inline(0)
    def reduce_Splat(self, *kids):
        pass

    def reduce_AT_PathNodeName(self, *kids):
        self.val = qlast.Path(
            steps=[
                qlast.Ptr(
                    name=kids[1].val.name,
                    type='property',
                    span=kids[1].span,
                )
            ]
        )

    def reduce_TypeIntersection_DOT_PathStepName_OptTypeIntersection(
            self, *kids):
        from edb.schema import pointers as s_pointers

        steps = [
            kids[0].val,
            qlast.Ptr(
                name=kids[2].val.name,
                direction=s_pointers.PointerDirection.Outbound,
                span=kids[2].span,
            ),
        ]

        if kids[3].val is not None:
            steps.append(kids[3].val)

        self.val = qlast.Path(steps=steps)


# N.B. the production verbosity below is necessary due to conflicts,
#      as is the use of PathStepName in place of SimpleTypeName.
class Splat(Nonterm):
    def reduce_STAR(self, *kids):
        self.val = qlast.Path(steps=[
            qlast.Splat(depth=1, span=kids[0].span),
        ])

    def reduce_DOUBLESTAR(self, *kids):
        self.val = qlast.Path(steps=[
            qlast.Splat(depth=2, span=kids[0].span),
        ])

    # Type.*
    def reduce_PathStepName_DOT_STAR(self, *kids):
        self.val = qlast.Path(steps=[
            qlast.Splat(
                depth=1,
                type=qlast.TypeName(
                    maintype=kids[0].val, span=kids[0].span
                ),
                span=merge_spans(kids),
            ),
        ])

    # Type.**
    def reduce_PathStepName_DOT_DOUBLESTAR(self, *kids):
        self.val = qlast.Path(steps=[
            qlast.Splat(
                depth=2,
                type=qlast.TypeName(
                    maintype=kids[0].val, span=kids[0].span
                ),
                span=merge_spans(kids),
            ),
        ])

    # [is Foo].*
    def reduce_TypeIntersection_DOT_STAR(self, *kids):
        self.val = qlast.Path(steps=[
            qlast.Splat(
                depth=1,
                intersection=kids[0].val,
                span=merge_spans(kids),
            ),
        ])

    # [is Foo].**
    def reduce_TypeIntersection_DOT_DOUBLESTAR(self, *kids):
        self.val = qlast.Path(steps=[
            qlast.Splat(
                depth=2,
                intersection=kids[0].val,
                span=merge_spans(kids),
            ),
        ])

    # Type[is Foo].*
    def reduce_PathStepName_TypeIntersection_DOT_STAR(self, *kids):
        self.val = qlast.Path(steps=[
            qlast.Splat(
                depth=1,
                type=qlast.TypeName(
                    maintype=kids[0].val, span=kids[0].span
                ),
                intersection=kids[1].val,
                span=merge_spans(kids),
            ),
        ])

    # Type[is Foo].**
    def reduce_PathStepName_TypeIntersection_DOT_DOUBLESTAR(self, *kids):
        self.val = qlast.Path(steps=[
            qlast.Splat(
                depth=2,
                type=qlast.TypeName(
                    maintype=kids[0].val, span=kids[0].span
                ),
                intersection=kids[1].val,
                span=merge_spans(kids),
            ),
        ])

    # module::Type.*
    def reduce_PtrQualifiedNodeName_DOT_STAR(self, *kids):
        self.val = qlast.Path(steps=[
            qlast.Splat(
                type=qlast.TypeName(
                    maintype=kids[0].val, span=kids[0].span
                ),
                depth=1,
                span=merge_spans(kids),
            ),
        ])

    # module::Type.**
    def reduce_PtrQualifiedNodeName_DOT_DOUBLESTAR(self, *kids):
        self.val = qlast.Path(steps=[
            qlast.Splat(
                type=qlast.TypeName(
                    maintype=kids[0].val, span=kids[0].span
                ),
                depth=2,
                span=merge_spans(kids),
            ),
        ])

    # module::Type[is ].*
    def reduce_PtrQualifiedNodeName_TypeIntersection_DOT_STAR(self, *kids):
        self.val = qlast.Path(steps=[
            qlast.Splat(
                depth=1,
                type=qlast.TypeName(
                    maintype=kids[0].val, span=kids[0].span
                ),
                intersection=kids[1].val,
                span=merge_spans(kids),
            ),
        ])

    # module::Type[is ].**
    def reduce_PtrQualifiedNodeName_TypeIntersection_DOT_DOUBLESTAR(
        self,
        *kids,
    ):
        self.val = qlast.Path(steps=[
            qlast.Splat(
                depth=2,
                type=qlast.TypeName(
                    maintype=kids[0].val, span=kids[0].span
                ),
                intersection=kids[1].val,
                span=merge_spans(kids),
            ),
        ])

    # ().*
    def reduce_ParenTypeExpr_DOT_STAR(self, *kids):
        self.val = qlast.Path(steps=[
            qlast.Splat(
                depth=1,
                type=kids[0].val,
                span=merge_spans(kids),
            ),
        ])

    # ()[is ].*
    def reduce_ParenTypeExpr_TypeIntersection_DOT_STAR(self, *kids):
        self.val = qlast.Path(steps=[
            qlast.Splat(
                depth=1,
                type=kids[0].val,
                intersection=kids[1].val,
                span=merge_spans(kids),
            ),
        ])

    # ().**
    def reduce_ParenTypeExpr_DOT_DOUBLESTAR(self, *kids):
        self.val = qlast.Path(steps=[
            qlast.Splat(
                depth=2,
                type=kids[0].val,
                span=merge_spans(kids),
            ),
        ])

    # ()[is ].**
    def reduce_ParenTypeExpr_TypeIntersection_DOT_DOUBLESTAR(self, *kids):
        self.val = qlast.Path(steps=[
            qlast.Splat(
                depth=2,
                type=kids[0].val,
                intersection=kids[1].val,
                span=merge_spans(kids),
            ),
        ])


class ShapePointer(Nonterm):
    def reduce_ShapePath(self, *kids):
        self.val = qlast.ShapeElement(
            expr=kids[0].val
        )


class PtrQualsSpec(typing.NamedTuple):
    required: typing.Optional[bool] = None
    cardinality: typing.Optional[qltypes.SchemaCardinality] = None


class PtrQuals(Nonterm):
    def reduce_OPTIONAL(self, *kids):
        self.val = PtrQualsSpec(required=False)

    def reduce_REQUIRED(self, *kids):
        self.val = PtrQualsSpec(required=True)

    def reduce_SINGLE(self, *kids):
        self.val = PtrQualsSpec(cardinality=qltypes.SchemaCardinality.One)

    def reduce_MULTI(self, *kids):
        self.val = PtrQualsSpec(cardinality=qltypes.SchemaCardinality.Many)

    def reduce_OPTIONAL_SINGLE(self, *kids):
        self.val = PtrQualsSpec(
            required=False, cardinality=qltypes.SchemaCardinality.One)

    def reduce_OPTIONAL_MULTI(self, *kids):
        self.val = PtrQualsSpec(
            required=False, cardinality=qltypes.SchemaCardinality.Many)

    def reduce_REQUIRED_SINGLE(self, *kids):
        self.val = PtrQualsSpec(
            required=True, cardinality=qltypes.SchemaCardinality.One)

    def reduce_REQUIRED_MULTI(self, *kids):
        self.val = PtrQualsSpec(
            required=True, cardinality=qltypes.SchemaCardinality.Many)
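

# Illustrative sketch (not part of the grammar): the eight PtrQuals
# productions above are the cross product of an optional
# OPTIONAL/REQUIRED token and an optional SINGLE/MULTI token, minus the
# empty case (which OptPtrQuals handles). The snippet below expresses the
# same mapping as a function. `CardinalitySketch`, `QualsSketch` and
# `quals_from_tokens` are hypothetical names; the real types are
# qltypes.SchemaCardinality and PtrQualsSpec.

```python
import enum
import typing


class CardinalitySketch(enum.Enum):
    """Hypothetical stand-in for qltypes.SchemaCardinality."""
    One = "One"
    Many = "Many"


class QualsSketch(typing.NamedTuple):
    """Hypothetical stand-in for PtrQualsSpec."""
    required: typing.Optional[bool] = None
    cardinality: typing.Optional[CardinalitySketch] = None


_REQUIRED = {"OPTIONAL": False, "REQUIRED": True}
_CARDINALITY = {
    "SINGLE": CardinalitySketch.One,
    "MULTI": CardinalitySketch.Many,
}


def quals_from_tokens(tokens: typing.Sequence[str]) -> QualsSketch:
    rest = list(tokens)
    required = None
    cardinality = None
    # The optionality keyword, if any, always comes first.
    if rest and rest[0] in _REQUIRED:
        required = _REQUIRED[rest.pop(0)]
    # The cardinality keyword, if any, follows it.
    if rest and rest[0] in _CARDINALITY:
        cardinality = _CARDINALITY[rest.pop(0)]
    if rest:
        raise ValueError(f"unexpected qualifier tokens: {rest}")
    return QualsSketch(required=required, cardinality=cardinality)
```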


class OptPtrQuals(Nonterm):

    def reduce_empty(self, *kids):
        self.val = PtrQualsSpec()

    @parsing.inline(0)
    def reduce_PtrQuals(self, *kids):
        pass


# We have to inline the OptPtrQuals here because the parser generator
# fails to cope with a shift/reduce on a REQUIRED token, since PtrQuals
# are followed by an ident in this case (unlike in DDL, where it is followed
# by a keyword).
class ComputableShapePointer(Nonterm):

    def reduce_OPTIONAL_SimpleShapePointer_ASSIGN_GenExpr(self, *kids):
        self.val = kids[1].val
        self.val.compexpr = kids[3].val
        self.val.required = False
        self.val.operation = qlast.ShapeOperation(
            op=qlast.ShapeOp.ASSIGN,
            span=assert_non_null(kids[2].span),
        )

    def reduce_REQUIRED_SimpleShapePointer_ASSIGN_GenExpr(self, *kids):
        self.val = kids[1].val
        self.val.compexpr = kids[3].val
        self.val.required = True
        self.val.operation = qlast.ShapeOperation(
            op=qlast.ShapeOp.ASSIGN,
            span=assert_non_null(kids[2].span),
        )

    def reduce_MULTI_SimpleShapePointer_ASSIGN_GenExpr(self, *kids):
        self.val = kids[1].val
        self.val.compexpr = kids[3].val
        self.val.cardinality = qltypes.SchemaCardinality.Many
        self.val.operation = qlast.ShapeOperation(
            op=qlast.ShapeOp.ASSIGN,
            span=assert_non_null(kids[2].span),
        )

    def reduce_SINGLE_SimpleShapePointer_ASSIGN_GenExpr(self, *kids):
        self.val = kids[1].val
        self.val.compexpr = kids[3].val
        self.val.cardinality = qltypes.SchemaCardinality.One
        self.val.operation = qlast.ShapeOperation(
            op=qlast.ShapeOp.ASSIGN,
            span=assert_non_null(kids[2].span),
        )

    def reduce_OPTIONAL_MULTI_SimpleShapePointer_ASSIGN_GenExpr(self, *kids):
        self.val = kids[2].val
        self.val.compexpr = kids[4].val
        self.val.required = False
        self.val.cardinality = qltypes.SchemaCardinality.Many
        self.val.operation = qlast.ShapeOperation(
            op=qlast.ShapeOp.ASSIGN,
            span=assert_non_null(kids[3].span),
        )

    def reduce_OPTIONAL_SINGLE_SimpleShapePointer_ASSIGN_GenExpr(self, *kids):
        self.val = kids[2].val
        self.val.compexpr = kids[4].val
        self.val.required = False
        self.val.cardinality = qltypes.SchemaCardinality.One
        self.val.operation = qlast.ShapeOperation(
            op=qlast.ShapeOp.ASSIGN,
            span=assert_non_null(kids[3].span),
        )

    def reduce_REQUIRED_MULTI_SimpleShapePointer_ASSIGN_GenExpr(self, *kids):
        self.val = kids[2].val
        self.val.compexpr = kids[4].val
        self.val.required = True
        self.val.cardinality = qltypes.SchemaCardinality.Many
        self.val.operation = qlast.ShapeOperation(
            op=qlast.ShapeOp.ASSIGN,
            span=assert_non_null(kids[3].span),
        )

    def reduce_REQUIRED_SINGLE_SimpleShapePointer_ASSIGN_GenExpr(self, *kids):
        self.val = kids[2].val
        self.val.compexpr = kids[4].val
        self.val.required = True
        self.val.cardinality = qltypes.SchemaCardinality.One
        self.val.operation = qlast.ShapeOperation(
            op=qlast.ShapeOp.ASSIGN,
            span=assert_non_null(kids[3].span),
        )

    def reduce_SimpleShapePointer_ASSIGN_GenExpr(self, *kids):
        self.val = kids[0].val
        self.val.compexpr = kids[2].val
        self.val.operation = qlast.ShapeOperation(
            op=qlast.ShapeOp.ASSIGN,
            span=assert_non_null(kids[1].span),
        )

    def reduce_SimpleShapePointer_ADDASSIGN_GenExpr(self, *kids):
        self.val = kids[0].val
        self.val.compexpr = kids[2].val
        self.val.operation = qlast.ShapeOperation(
            op=qlast.ShapeOp.APPEND,
            span=assert_non_null(kids[1].span),
        )

    def reduce_SimpleShapePointer_REMASSIGN_GenExpr(self, *kids):
        self.val = kids[0].val
        self.val.compexpr = kids[2].val
        self.val.operation = qlast.ShapeOperation(
            op=qlast.ShapeOp.SUBTRACT,
            span=assert_non_null(kids[1].span),
        )
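

# Illustrative sketch (not part of the grammar): ComputableShapePointer
# maps the three assignment tokens onto shape operations (:= to ASSIGN,
# += to APPEND, -= to SUBTRACT), while the free-shape variant accepts
# only :=. The snippet below restates that rule with a hypothetical
# `ShapeOpSketch` enum and `shape_op_for` helper; the real enum is
# qlast.ShapeOp, and the token spellings used as enum values here are an
# assumption made for readability.

```python
import enum


class ShapeOpSketch(enum.Enum):
    """Hypothetical mirror of the three shape operations used above."""
    ASSIGN = ":="
    APPEND = "+="
    SUBTRACT = "-="


def shape_op_for(token: str, *, free_shape: bool = False) -> ShapeOpSketch:
    # Enum lookup by value raises ValueError for an unknown token.
    op = ShapeOpSketch(token)
    # Free shapes only allow plain assignment.
    if free_shape and op is not ShapeOpSketch.ASSIGN:
        raise ValueError(f"{token!r} is not allowed in a free shape")
    return op
```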


# This is the same as the above ComputableShapePointer, except using
# FreeSimpleShapePointer and not allowing +=/-=.
class FreeComputableShapePointer(Nonterm):
    def reduce_OPTIONAL_FreeSimpleShapePointer_ASSIGN_GenExpr(self, *kids):
        self.val = kids[1].val
        self.val.compexpr = kids[3].val
        self.val.required = False
        self.val.operation = qlast.ShapeOperation(
            op=qlast.ShapeOp.ASSIGN,
            span=assert_non_null(kids[2].span),
        )

    def reduce_REQUIRED_FreeSimpleShapePointer_ASSIGN_GenExpr(self, *kids):
        self.val = kids[1].val
        self.val.compexpr = kids[3].val
        self.val.required = True
        self.val.operation = qlast.ShapeOperation(
            op=qlast.ShapeOp.ASSIGN,
            span=assert_non_null(kids[2].span),
        )

    def reduce_MULTI_FreeSimpleShapePointer_ASSIGN_GenExpr(self, *kids):
        self.val = kids[1].val
        self.val.compexpr = kids[3].val
        self.val.cardinality = qltypes.SchemaCardinality.Many
        self.val.operation = qlast.ShapeOperation(
            op=qlast.ShapeOp.ASSIGN,
            span=assert_non_null(kids[2].span),
        )

    def reduce_SINGLE_FreeSimpleShapePointer_ASSIGN_GenExpr(self, *kids):
        self.val = kids[1].val
        self.val.compexpr = kids[3].val
        self.val.cardinality = qltypes.SchemaCardinality.One
        self.val.operation = qlast.ShapeOperation(
            op=qlast.ShapeOp.ASSIGN,
            span=assert_non_null(kids[2].span),
        )

    def reduce_OPTIONAL_MULTI_FreeSimpleShapePointer_ASSIGN_GenExpr(
        self, *kids
    ):
        self.val = kids[2].val
        self.val.compexpr = kids[4].val
        self.val.required = False
        self.val.cardinality = qltypes.SchemaCardinality.Many
        self.val.operation = qlast.ShapeOperation(
            op=qlast.ShapeOp.ASSIGN,
            span=assert_non_null(kids[3].span),
        )

    def reduce_OPTIONAL_SINGLE_FreeSimpleShapePointer_ASSIGN_GenExpr(
        self, *kids
    ):
        self.val = kids[2].val
        self.val.compexpr = kids[4].val
        self.val.required = False
        self.val.cardinality = qltypes.SchemaCardinality.One
        self.val.operation = qlast.ShapeOperation(
            op=qlast.ShapeOp.ASSIGN,
            span=assert_non_null(kids[3].span),
        )

    def reduce_REQUIRED_MULTI_FreeSimpleShapePointer_ASSIGN_GenExpr(
        self, *kids
    ):
        self.val = kids[2].val
        self.val.compexpr = kids[4].val
        self.val.required = True
        self.val.cardinality = qltypes.SchemaCardinality.Many
        self.val.operation = qlast.ShapeOperation(
            op=qlast.ShapeOp.ASSIGN,
            span=assert_non_null(kids[3].span),
        )

    def reduce_REQUIRED_SINGLE_FreeSimpleShapePointer_ASSIGN_GenExpr(
        self, *kids
    ):
        self.val = kids[2].val
        self.val.compexpr = kids[4].val
        self.val.required = True
        self.val.cardinality = qltypes.SchemaCardinality.One
        self.val.operation = qlast.ShapeOperation(
            op=qlast.ShapeOp.ASSIGN,
            span=assert_non_null(kids[3].span),
        )

    def reduce_FreeSimpleShapePointer_ASSIGN_GenExpr(self, *kids):
        self.val = kids[0].val
        self.val.compexpr = kids[2].val
        self.val.operation = qlast.ShapeOperation(
            op=qlast.ShapeOp.ASSIGN,
            span=assert_non_null(kids[1].span),
        )


class FreeComputableShapePointerList(ListNonterm,
                                     element=FreeComputableShapePointer,
                                     separator=tokens.T_COMMA,
                                     allow_trailing_separator=True):
    pass


class UnlessConflictSpecifier(Nonterm):
    def reduce_ON_Expr_ELSE_Expr(self, *kids):
        self.val = (kids[1].val, kids[3].val)

    def reduce_ON_Expr(self, *kids):
        self.val = (kids[1].val, None)

    def reduce_empty(self, *kids):
        self.val = (None, None)


class UnlessConflictCause(Nonterm):
    @parsing.inline(2)
    def reduce_UNLESS_CONFLICT_UnlessConflictSpecifier(self, *kids):
        pass


class OptUnlessConflictClause(Nonterm):
    @parsing.inline(0)
    def reduce_UnlessConflictCause(self, *kids):
        pass

    def reduce_empty(self, *kids):
        self.val = None


class FilterClause(Nonterm):
    val: qlast.Expr

    @parsing.inline(1)
    def reduce_FILTER_Expr(self, *kids):
        pass


class OptFilterClause(Nonterm):
    val: typing.Optional[qlast.Expr]

    @parsing.inline(0)
    def reduce_FilterClause(self, *kids):
        pass

    def reduce_empty(self, *kids):
        self.val = None


class SortClause(Nonterm):
    val: list[qlast.SortExpr]

    @parsing.inline(1)
    def reduce_ORDERBY_OrderbyList(self, *kids):
        pass


class OptSortClause(Nonterm):
    val: list[qlast.SortExpr]

    @parsing.inline(0)
    def reduce_SortClause(self, *kids):
        pass

    def reduce_empty(self, *kids):
        self.val = []


class OrderbyExpr(Nonterm):
    val: qlast.SortExpr

    def reduce_Expr_OptDirection_OptNonesOrder(self, *kids):
        self.val = qlast.SortExpr(path=kids[0].val,
                                  direction=kids[1].val,
                                  nones_order=kids[2].val)


class OrderbyList(ListNonterm, element=OrderbyExpr,
                  separator=tokens.T_THEN):
    val: list[qlast.SortExpr]


class OptSelectLimit(Nonterm):
    val: tuple[typing.Optional[qlast.Expr], typing.Optional[qlast.Expr]]

    @parsing.inline(0)
    def reduce_SelectLimit(self, *kids):
        pass

    def reduce_empty(self, *kids):
        self.val = (None, None)


class SelectLimit(Nonterm):
    val: tuple[typing.Optional[qlast.Expr], typing.Optional[qlast.Expr]]

    def reduce_OffsetClause_LimitClause(self, *kids):
        self.val = (kids[0].val, kids[1].val)

    def reduce_OffsetClause(self, *kids):
        self.val = (kids[0].val, None)

    def reduce_LimitClause(self, *kids):
        self.val = (None, kids[0].val)


class OffsetClause(Nonterm):
    val: qlast.Expr

    @parsing.inline(1)
    def reduce_OFFSET_Expr(self, *kids):
        pass


class LimitClause(Nonterm):
    val: qlast.Expr

    @parsing.inline(1)
    def reduce_LIMIT_Expr(self, *kids):
        pass


class OptDirection(Nonterm):
    def reduce_ASC(self, *kids):
        self.val = qlast.SortAsc

    def reduce_DESC(self, *kids):
        self.val = qlast.SortDesc

    def reduce_empty(self, *kids):
        self.val = qlast.SortDefault


class OptNonesOrder(Nonterm):
    def reduce_EMPTY_FIRST(self, *kids):
        self.val = qlast.NonesFirst

    def reduce_EMPTY_LAST(self, *kids):
        self.val = qlast.NonesLast

    def reduce_empty(self, *kids):
        self.val = None


class IndirectionEl(Nonterm):
    def reduce_LBRACKET_Expr_RBRACKET(self, *kids):
        self.val = qlast.Index(index=kids[1].val)

    def reduce_LBRACKET_Expr_COLON_Expr_RBRACKET(self, *kids):
        self.val = qlast.Slice(start=kids[1].val, stop=kids[3].val)

    def reduce_LBRACKET_Expr_COLON_RBRACKET(self, *kids):
        self.val = qlast.Slice(start=kids[1].val, stop=None)

    def reduce_LBRACKET_COLON_Expr_RBRACKET(self, *kids):
        self.val = qlast.Slice(start=None, stop=kids[2].val)


class ParenExpr(Nonterm):
    @parsing.inline(1)
    def reduce_LPAREN_Expr_RPAREN(self, *kids):
        pass

    @parsing.inline(1)
    def reduce_LPAREN_ExprStmt_RPAREN(self, *kids):
        pass


class BaseAtomicExpr(Nonterm):
    val: qlast.Expr
    # { ... } | Constant | '(' Expr ')' | FuncExpr
    # | Tuple | NamedTuple | Collection | Set
    # | '__source__' | '__subject__'
    # | '__new__' | '__old__' | '__specified__' | '__default__'
    # | NodeName | PathStep

    @parsing.inline(0)
    def reduce_FreeShape(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_Constant(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_StringInterpolation(self, *kids):
        pass

    def reduce_DUNDERSOURCE(self, kw):
        self.val = qlast.Path(
            steps=[
                qlast.SpecialAnchor(name='__source__', span=kw.span)
            ]
        )

    def reduce_DUNDERSUBJECT(self, kw):
        self.val = qlast.Path(
            steps=[
                qlast.SpecialAnchor(name='__subject__', span=kw.span)
            ]
        )

    def reduce_DUNDERNEW(self, kw):
        self.val = qlast.Path(
            steps=[
                qlast.SpecialAnchor(name='__new__', span=kw.span)
            ]
        )

    def reduce_DUNDEROLD(self, kw):
        self.val = qlast.Path(
            steps=[
                qlast.SpecialAnchor(name='__old__', span=kw.span)
            ]
        )

    def reduce_DUNDERSPECIFIED(self, kw):
        self.val = qlast.Path(
            steps=[
                qlast.SpecialAnchor(name='__specified__', span=kw.span)
            ]
        )

    def reduce_DUNDERDEFAULT(self, kw):
        self.val = qlast.Path(
            steps=[
                qlast.SpecialAnchor(name='__default__', span=kw.span)
            ]
        )

    @parsing.precedence(precedence.P_UMINUS)
    @parsing.inline(0)
    def reduce_ParenExpr(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_FuncExpr(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_Tuple(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_Collection(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_Set(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_NamedTuple(self, *kids):
        pass

    @parsing.precedence(precedence.P_DOT)
    def reduce_NodeName(self, *kids):
        self.val = qlast.Path(
            steps=[
                qlast.ObjectRef(
                    name=kids[0].val.name,
                    module=kids[0].val.module,
                    span=kids[0].span,
                )
            ]
        )

    @parsing.precedence(precedence.P_DOT)
    def reduce_PathStep(self, *kids):
        self.val = qlast.Path(steps=[kids[0].val], partial=True)


class Expr(Nonterm):
    val: qlast.Expr
    # BaseAtomicExpr
    # Path | Expr { ... }

    # | Expr '[' Expr ']'
    # | Expr '[' Expr ':' Expr ']'
    # | Expr '[' ':' Expr ']'
    # | Expr '[' Expr ':' ']'
    # | Expr '[' IS NodeName ']'

    # | '+' Expr | '-' Expr | Expr '+' Expr | Expr '-' Expr
    # | Expr '*' Expr | Expr '/' Expr | Expr '%' Expr
    # | Expr '**' Expr | Expr '<' Expr | Expr '>' Expr
    # | Expr '=' Expr
    # | Expr AND Expr | Expr OR Expr | NOT Expr
    # | Expr LIKE Expr | Expr NOT LIKE Expr
    # | Expr ILIKE Expr | Expr NOT ILIKE Expr
    # | Expr IS TypeExpr | Expr IS NOT TypeExpr
    # | INTROSPECT TypeExpr
    # | Expr IN Expr | Expr NOT IN Expr
    # | '<' TypeName '>' Expr
    # | Expr IF Expr ELSE Expr
    # | Expr ?? Expr
    # | Expr UNION Expr | Expr EXCEPT Expr | Expr INTERSECT Expr
    # | DISTINCT Expr
    # | DETACHED Expr
    # | GLOBAL NodeName
    # | EXISTS Expr

    @parsing.inline(0)
    def reduce_BaseAtomicExpr(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_Path(self, *kids):
        pass

    def reduce_Expr_Shape(self, *kids):
        self.val = qlast.Shape(expr=kids[0].val, elements=kids[1].val)

    def reduce_EXISTS_Expr(self, *kids):
        self.val = qlast.UnaryOp(op='EXISTS', operand=kids[1].val)

    def reduce_DISTINCT_Expr(self, *kids):
        self.val = qlast.UnaryOp(op='DISTINCT', operand=kids[1].val)

    def reduce_DETACHED_Expr(self, *kids):
        self.val = qlast.DetachedExpr(expr=kids[1].val)

    def reduce_GLOBAL_NodeName(self, *kids):
        self.val = qlast.GlobalExpr(name=kids[1].val)

    def reduce_Expr_IndirectionEl(self, *kids):
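        # Chained indirections such as x[0][1:2] accumulate in a single
        # Indirection node rather than nesting one node per step.
        expr = kids[0].val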
        if isinstance(expr, qlast.Indirection):
            self.val = expr
            expr.indirection.append(kids[1].val)
        else:
            self.val = qlast.Indirection(arg=expr,
                                         indirection=[kids[1].val])

    @parsing.precedence(precedence.P_UMINUS)
    def reduce_PLUS_Expr(self, *kids):
        self.val = qlast.UnaryOp(op=kids[0].val, operand=kids[1].val)

    @parsing.precedence(precedence.P_UMINUS)
    def reduce_MINUS_Expr(self, *kids):
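        # Fold a minus sign applied directly to a numeric literal into
        # the literal itself instead of emitting a UnaryOp.
        arg = kids[1].val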
        if isinstance(arg, qlast.Constant) and arg.kind in {
            qlast.ConstantKind.INTEGER,
            qlast.ConstantKind.FLOAT,
            qlast.ConstantKind.BIGINT,
            qlast.ConstantKind.DECIMAL,
        }:
            self.val = type(arg)(value=f'-{arg.value}', kind=arg.kind)
        else:
            self.val = qlast.UnaryOp(op=kids[0].val, operand=arg)

    def reduce_Expr_PLUS_Expr(self, *kids):
        self.val = qlast.BinOp(left=kids[0].val, op=kids[1].val,
                               right=kids[2].val)

    def reduce_Expr_DOUBLEPLUS_Expr(self, *kids):
        self.val = qlast.BinOp(left=kids[0].val, op=kids[1].val,
                               right=kids[2].val)

    def reduce_Expr_MINUS_Expr(self, *kids):
        self.val = qlast.BinOp(left=kids[0].val, op=kids[1].val,
                               right=kids[2].val)

    def reduce_Expr_STAR_Expr(self, *kids):
        self.val = qlast.BinOp(left=kids[0].val, op=kids[1].val,
                               right=kids[2].val)

    def reduce_Expr_SLASH_Expr(self, *kids):
        self.val = qlast.BinOp(left=kids[0].val, op=kids[1].val,
                               right=kids[2].val)

    def reduce_Expr_DOUBLESLASH_Expr(self, *kids):
        self.val = qlast.BinOp(left=kids[0].val, op=kids[1].val,
                               right=kids[2].val)

    def reduce_Expr_PERCENT_Expr(self, *kids):
        self.val = qlast.BinOp(left=kids[0].val, op=kids[1].val,
                               right=kids[2].val)

    def reduce_Expr_CIRCUMFLEX_Expr(self, *kids):
        self.val = qlast.BinOp(left=kids[0].val, op=kids[1].val,
                               right=kids[2].val)

    @parsing.precedence(precedence.P_DOUBLEQMARK_OP)
    def reduce_Expr_DOUBLEQMARK_Expr(self, *kids):
        self.val = qlast.BinOp(left=kids[0].val, op=kids[1].val,
                               right=kids[2].val)

    @parsing.precedence(precedence.P_COMPARE_OP)
    def reduce_Expr_CompareOp_Expr(self, *kids):
        self.val = qlast.BinOp(left=kids[0].val, op=kids[1].val,
                               right=kids[2].val)

    def reduce_Expr_AND_Expr(self, *kids):
        self.val = qlast.BinOp(left=kids[0].val, op=kids[1].val.upper(),
                               right=kids[2].val)

    def reduce_Expr_OR_Expr(self, *kids):
        self.val = qlast.BinOp(left=kids[0].val, op=kids[1].val.upper(),
                               right=kids[2].val)

    def reduce_NOT_Expr(self, *kids):
        self.val = qlast.UnaryOp(op=kids[0].val.upper(), operand=kids[1].val)

    def reduce_Expr_LIKE_Expr(self, *kids):
        self.val = qlast.BinOp(left=kids[0].val, op='LIKE',
                               right=kids[2].val)

    def reduce_Expr_NOT_LIKE_Expr(self, *kids):
        self.val = qlast.BinOp(left=kids[0].val, op='NOT LIKE',
                               right=kids[3].val)

    def reduce_Expr_ILIKE_Expr(self, *kids):
        self.val = qlast.BinOp(left=kids[0].val, op='ILIKE',
                               right=kids[2].val)

    def reduce_Expr_NOT_ILIKE_Expr(self, *kids):
        self.val = qlast.BinOp(left=kids[0].val, op='NOT ILIKE',
                               right=kids[3].val)

    def reduce_Expr_IS_TypeExpr(self, *kids):
        self.val = qlast.IsOp(left=kids[0].val, op='IS',
                              right=kids[2].val)

    @parsing.precedence(precedence.P_IS)
    def reduce_Expr_IS_NOT_TypeExpr(self, *kids):
        self.val = qlast.IsOp(left=kids[0].val, op='IS NOT',
                              right=kids[3].val)

    def reduce_INTROSPECT_TypeExpr(self, *kids):
        self.val = qlast.Introspect(type=kids[1].val)

    def reduce_Expr_IN_Expr(self, *kids):
        inexpr = kids[2].val
        self.val = qlast.BinOp(left=kids[0].val, op='IN',
                               right=inexpr)

    @parsing.precedence(precedence.P_IN)
    def reduce_Expr_NOT_IN_Expr(self, *kids):
        inexpr = kids[3].val
        self.val = qlast.BinOp(left=kids[0].val, op='NOT IN',
                               right=inexpr)

    @parsing.precedence(precedence.P_TYPECAST)
    def reduce_LANGBRACKET_FullTypeExpr_RANGBRACKET_Expr(
            self, *kids):
        self.val = qlast.TypeCast(
            expr=kids[3].val,
            type=kids[1].val,
            cardinality_mod=None,
        )

    @parsing.precedence(precedence.P_TYPECAST)
    def reduce_LANGBRACKET_OPTIONAL_FullTypeExpr_RANGBRACKET_Expr(
            self, *kids):
        self.val = qlast.TypeCast(
            expr=kids[4].val,
            type=kids[2].val,
            cardinality_mod=qlast.CardinalityModifier.Optional,
        )

    @parsing.precedence(precedence.P_TYPECAST)
    def reduce_LANGBRACKET_REQUIRED_FullTypeExpr_RANGBRACKET_Expr(
            self, *kids):
        self.val = qlast.TypeCast(
            expr=kids[4].val,
            type=kids[2].val,
            cardinality_mod=qlast.CardinalityModifier.Required,
        )

    def reduce_Expr_IF_Expr_ELSE_Expr(self, *kids):
        if_expr, _, condition, _, else_expr = kids
        self.val = qlast.IfElse(
            if_expr=if_expr.val,
            condition=condition.val,
            else_expr=else_expr.val,
            python_style=True,
        )

    @parsing.inline(0)
    def reduce_IfThenElseExpr(self, _):
        pass

    def reduce_Expr_UNION_Expr(self, *kids):
        self.val = qlast.BinOp(left=kids[0].val, op='UNION',
                               right=kids[2].val)

    def reduce_Expr_EXCEPT_Expr(self, *kids):
        self.val = qlast.BinOp(left=kids[0].val, op='EXCEPT',
                               right=kids[2].val)

    def reduce_Expr_INTERSECT_Expr(self, *kids):
        self.val = qlast.BinOp(left=kids[0].val, op='INTERSECT',
                               right=kids[2].val)


class IfThenElseExpr(Nonterm):
    def reduce_IF_Expr_THEN_Expr_ELSE_Expr(self, *kids):
        _, condition, _, if_expr, _, else_expr = kids
        self.val = qlast.IfElse(
            condition=condition.val,
            if_expr=if_expr.val,
            else_expr=else_expr.val,
        )


class CompareOp(Nonterm):
    @parsing.inline(0)
    @parsing.precedence(precedence.P_COMPARE_OP)
    def reduce_DISTINCTFROM(self, *_):
        pass

    @parsing.inline(0)
    @parsing.precedence(precedence.P_COMPARE_OP)
    def reduce_GREATEREQ(self, *_):
        pass

    @parsing.inline(0)
    @parsing.precedence(precedence.P_COMPARE_OP)
    def reduce_LESSEQ(self, *_):
        pass

    @parsing.inline(0)
    @parsing.precedence(precedence.P_COMPARE_OP)
    def reduce_NOTDISTINCTFROM(self, *_):
        pass

    @parsing.inline(0)
    @parsing.precedence(precedence.P_COMPARE_OP)
    def reduce_NOTEQ(self, *_):
        pass

    @parsing.inline(0)
    @parsing.precedence(precedence.P_COMPARE_OP)
    def reduce_LANGBRACKET(self, *_):
        pass

    @parsing.inline(0)
    @parsing.precedence(precedence.P_COMPARE_OP)
    def reduce_RANGBRACKET(self, *_):
        pass

    @parsing.inline(0)
    @parsing.precedence(precedence.P_COMPARE_OP)
    def reduce_EQUALS(self, *_):
        pass


class Tuple(Nonterm):
    def reduce_LPAREN_GenExpr_COMMA_OptExprList_RPAREN(self, *kids):
        self.val = qlast.Tuple(elements=[kids[1].val] + kids[3].val)

    def reduce_LPAREN_RPAREN(self, *kids):
        self.val = qlast.Tuple(elements=[])


class NamedTuple(Nonterm):
    def reduce_LPAREN_NamedTupleElementList_RPAREN(self, *kids):
        self.val = qlast.NamedTuple(elements=kids[1].val)


class NamedTupleElement(Nonterm):
    def reduce_ShortNodeName_ASSIGN_GenExpr(self, *kids):
        self.val = qlast.TupleElement(
            name=qlast.Ptr(name=kids[0].val.name, span=kids[0].span),
            val=kids[2].val
        )


class NamedTupleElementList(ListNonterm, element=NamedTupleElement,
                            separator=tokens.T_COMMA,
                            allow_trailing_separator=True):
    pass


class Set(Nonterm):
    def reduce_LBRACE_OptExprList_RBRACE(self, *kids):
        self.val = qlast.Set(elements=kids[1].val)


class Collection(Nonterm):
    def reduce_LBRACKET_OptExprList_RBRACKET(self, *kids):
        elements = kids[1].val
        self.val = qlast.Array(elements=elements)


class OptExprList(Nonterm):
    @parsing.inline(0)
    def reduce_ExprList(self, *kids):
        pass

    def reduce_empty(self, *kids):
        self.val = []


class ExprList(ListNonterm, element=GenExpr, separator=tokens.T_COMMA,
               allow_trailing_separator=True):
    val: list[qlast.Expr]


class Constant(Nonterm):
    val: qlast.Expr

    # PARAMETER
    # | BaseNumberConstant
    # | BaseStringConstant
    # | BaseBooleanConstant
    # | BaseBytesConstant

    def reduce_PARAMETER(self, param):
        self.val = qlast.QueryParameter(name=param.val[1:])

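    def reduce_PARAMETERANDTYPE(self, param):
        # Assumes the token text has the form '<lit TYPE>$NAME'; split it
        # into the literal's type name and the parameter name, which are
        # both used below.
        assert param.val.startswith('<lit ')
        type_name, param_name = param.val.removeprefix('<lit ').split('>$')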
        self.val = qlast.TypeCast(
            type=qlast.TypeName(
                maintype=qlast.ObjectRef(
                    name=type_name,
                    module='__std__'
                ),
                span=param.span,
            ),
            expr=qlast.QueryParameter(
                name=param_name,
                span=param.span,
            ),
        )

    @parsing.inline(0)
    def reduce_BaseNumberConstant(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_BaseStringConstant(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_BaseBooleanConstant(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_BaseBytesConstant(self, *kids):
        pass


class StringInterpolationTail(Nonterm):
    def reduce_Expr_STRINTERPEND(self, *kids):
        expr, lit = kids
        self.val = qlast.StrInterp(
            prefix='',
            interpolations=[
                qlast.StrInterpFragment(
                    expr=expr.val, suffix=lit.clean_value, span=self.span
                ),
            ]
        )

    def reduce_Expr_STRINTERPCONT_StringInterpolationTail(self, *kids):
        expr, lit, tail = kids
        self.val = tail.val
        self.val.interpolations.append(
            qlast.StrInterpFragment(
                expr=expr.val, suffix=lit.clean_value, span=self.span
            )
        )


class StringInterpolation(Nonterm):
    def reduce_STRINTERPSTART_StringInterpolationTail(self, *kids):
        # We produce somewhat malformed StrInterp values out of
        # StringInterpolationTail, for convenience and efficiency, and
        # fix them up here.
        # (In particular, we put the interpolations in backward.)
        lit, tail = kids
        self.val = tail.val
        self.val.prefix = lit.clean_value
        self.val.interpolations.reverse()


class BaseNumberConstant(Nonterm):
    val: qlast.Constant

    def reduce_ICONST(self, *kids):
        self.val = qlast.Constant(
            value=kids[0].val, kind=qlast.ConstantKind.INTEGER
        )

    def reduce_FCONST(self, *kids):
        self.val = qlast.Constant(
            value=kids[0].val, kind=qlast.ConstantKind.FLOAT
        )

    def reduce_NICONST(self, *kids):
        self.val = qlast.Constant(
            value=kids[0].val, kind=qlast.ConstantKind.BIGINT
        )

    def reduce_NFCONST(self, *kids):
        self.val = qlast.Constant(
            value=kids[0].val, kind=qlast.ConstantKind.DECIMAL
        )


class BaseStringConstant(Nonterm):
    val: qlast.Constant

    def reduce_SCONST(self, token):
        self.val = qlast.Constant.string(value=token.clean_value)


class BaseBytesConstant(Nonterm):
    val: qlast.BaseConstant

    def reduce_BCONST(self, bytes_tok):
        self.val = qlast.BytesConstant(value=bytes_tok.clean_value)


class BaseBooleanConstant(Nonterm):
    val: qlast.Constant

    def reduce_TRUE(self, *kids):
        self.val = qlast.Constant.boolean(True)

    def reduce_FALSE(self, *kids):
        self.val = qlast.Constant.boolean(False)


def ensure_path(expr):
    """Wrap expr in a single-step qlast.Path unless it already is one."""
    if not isinstance(expr, qlast.Path):
        expr = qlast.Path(steps=[expr])
    return expr


class Path(Nonterm):
    @parsing.precedence(precedence.P_DOT)
    def reduce_Expr_PathStep(self, *kids):
        path = ensure_path(kids[0].val)
        path.steps.append(kids[1].val)
        self.val = path


class AtomicExpr(Nonterm):
    @parsing.inline(0)
    def reduce_BaseAtomicExpr(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_AtomicPath(self, *kids):
        pass

    @parsing.precedence(precedence.P_TYPECAST)
    def reduce_LANGBRACKET_FullTypeExpr_RANGBRACKET_AtomicExpr(
            self, *kids):
        self.val = qlast.TypeCast(
            expr=kids[3].val,
            type=kids[1].val,
            cardinality_mod=None,
        )


# Duplication of Path above, but with AtomicExpr at the root
class AtomicPath(Nonterm):
    @parsing.precedence(precedence.P_DOT)
    def reduce_AtomicExpr_PathStep(self, *kids):
        path = ensure_path(kids[0].val)
        path.steps.append(kids[1].val)
        self.val = path


class PathStep(Nonterm):
    def reduce_DOT_PathStepName(self, *kids):
        from edb.schema import pointers as s_pointers

        self.val = qlast.Ptr(
            name=kids[1].val.name,
            direction=s_pointers.PointerDirection.Outbound
        )

    def reduce_DOT_ICONST(self, *kids):
        # this is a valid link-like syntax for accessing unnamed tuples
        from edb.schema import pointers as s_pointers

        self.val = qlast.Ptr(
            name=kids[1].val,
            direction=s_pointers.PointerDirection.Outbound
        )

    def reduce_DOTBW_PathStepName(self, *kids):
        from edb.schema import pointers as s_pointers

        self.val = qlast.Ptr(
            name=kids[1].val.name,
            direction=s_pointers.PointerDirection.Inbound
        )

    def reduce_DOTQ_PathStepName(self, *kids):
        from edb.schema import pointers as s_pointers

        self.val = qlast.Ptr(
            name=kids[1].val.name,
            direction=s_pointers.PointerDirection.Outbound,
            type='optional',
        )

    def reduce_AT_PathNodeName(self, *kids):
        from edb.schema import pointers as s_pointers

        self.val = qlast.Ptr(
            name=kids[1].val.name,
            direction=s_pointers.PointerDirection.Outbound,
            type='property'
        )

    @parsing.inline(0)
    def reduce_TypeIntersection(self, *kids):
        pass


class TypeIntersection(Nonterm):
    def reduce_LBRACKET_IS_FullTypeExpr_RBRACKET(self, *kids):
        self.val = qlast.TypeIntersection(
            type=kids[2].val,
        )


class OptTypeIntersection(Nonterm):
    @parsing.inline(0)
    def reduce_TypeIntersection(self, *kids):
        pass

    def reduce_empty(self):
        self.val = None


# Used in free shapes
class FreeStepName(Nonterm):
    @parsing.inline(0)
    def reduce_ShortNodeName(self, *kids):
        pass

    def reduce_DUNDERTYPE(self, *kids):
        self.val = qlast.ObjectRef(name=kids[0].val)


# Used in shapes, paths and in PROPERTY/LINK definitions.
class PathStepName(Nonterm):
    @parsing.inline(0)
    def reduce_PathNodeName(self, *kids):
        pass

    def reduce_DUNDERTYPE(self, *kids):
        self.val = qlast.ObjectRef(name=kids[0].val)


class FuncApplication(Nonterm):
    def reduce_NodeName_LPAREN_OptFuncArgList_RPAREN(self, *kids):
        module = kids[0].val.module
        func_name = kids[0].val.name
        name = func_name if not module else (module, func_name)

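        # Positional arguments must precede named ones, and each named
        # argument may be given at most once.
        last_named_seen = None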
        args = []
        kwargs = {}
        for argname, argname_ctx, arg in kids[2].val:
            if argname is not None:
                if argname in kwargs:
                    raise errors.EdgeQLSyntaxError(
                        f"duplicate named argument `{argname}`",
                        span=argname_ctx)

                last_named_seen = argname
                kwargs[argname] = arg

            else:
                if last_named_seen is not None:
                    raise errors.EdgeQLSyntaxError(
                        f"positional argument after named "
                        f"argument `{last_named_seen}`",
                        span=arg.span)
                args.append(arg)

        self.val = qlast.FunctionCall(func=name, args=args, kwargs=kwargs)


class FuncExpr(Nonterm):
    @parsing.inline(0)
    def reduce_FuncApplication(self, *kids):
        pass


class FuncCallArgExpr(Nonterm):
    def reduce_Expr(self, *kids):
        self.val = (
            None,
            None,
            kids[0].val,
        )

    def reduce_AnyIdentifier_ASSIGN_Expr(self, *kids):
        self.val = (
            kids[0].val,
            kids[0].span,
            kids[2].val,
        )

    def reduce_PARAMETER_ASSIGN_Expr(self, *kids):
        if kids[0].val[1].isdigit():
            raise errors.EdgeQLSyntaxError(
                "numeric named parameters are not supported",
                span=kids[0].span)
        else:
            raise errors.EdgeQLSyntaxError(
                f"named parameters do not need a '$' prefix, "
                f"rewrite as '{kids[0].val[1:]} := ...'",
                span=kids[0].span)


class FuncCallArg(Nonterm):
    def reduce_FuncCallArgExpr_OptFilterClause_OptSortClause(self, *kids):
        self.val = kids[0].val

        if kids[1].val or kids[2].val:
            qry = qlast.SelectQuery(
                result=self.val[2],
                where=kids[1].val,
                orderby=kids[2].val,
                implicit=True,
                span=merge_spans(kids),
            )
            self.val = (self.val[0], self.val[1], qry)

    def reduce_ExprStmtSimple(self, *kids):
        self.val = (
            None,
            None,
            kids[0].val,
        )

    def reduce_AnyIdentifier_ASSIGN_ExprStmtSimple(self, *kids):
        self.val = (
            kids[0].val,
            kids[0].span,
            kids[2].val,
        )


class FuncArgList(ListNonterm, element=FuncCallArg, separator=tokens.T_COMMA,
                  allow_trailing_separator=True):
    pass


class OptFuncArgList(Nonterm):
    @parsing.inline(0)
    def reduce_FuncArgList(self, *kids):
        pass

    def reduce_empty(self, *kids):
        self.val = []


class PosCallArg(Nonterm):
    def reduce_Expr_OptFilterClause_OptSortClause(self, *kids):
        self.val = kids[0].val
        if kids[1].val or kids[2].val:
            self.val = qlast.SelectQuery(
                result=self.val,
                where=kids[1].val,
                orderby=kids[2].val,
                implicit=True,
            )


class PosCallArgList(ListNonterm, element=PosCallArg,
                     separator=tokens.T_COMMA):
    pass


class OptPosCallArgList(Nonterm):
    @parsing.inline(0)
    def reduce_PosCallArgList(self, *kids):
        pass

    def reduce_empty(self, *kids):
        self.val = []


class Identifier(Nonterm):
    val: str  # == Token.value

    def reduce_IDENT(self, ident):
        self.val = ident.clean_value

    @parsing.inline(0)
    def reduce_UnreservedKeyword(self, *_):
        pass


class PtrIdentifier(Nonterm):
    @parsing.inline(0)
    def reduce_Identifier(self, *_):
        pass

    @parsing.inline(0)
    def reduce_PartialReservedKeyword(self, *_):
        pass


class AnyIdentifier(Nonterm):
    @parsing.inline(0)
    def reduce_PtrIdentifier(self, *kids):
        pass

    def reduce_ReservedKeyword(self, *kids):
        name = kids[0].val
        if name[:2] == '__' and name[-2:] == '__':
            # A few reserved keywords, such as __std__ and __subject__,
            # may appear in paths but nowhere else. The tokenizer already
            # rejects __names__ in general; enforce the same rule here
            # for the remaining reserved __keywords__.
            raise errors.EdgeQLSyntaxError(
                "identifiers surrounded by double underscores are forbidden",
                span=kids[0].span)

        self.val = name


class DottedIdents(
        ListNonterm, element=AnyIdentifier, separator=tokens.T_DOT):
    pass


class DotName(Nonterm):
    val: str

    def reduce_DottedIdents(self, *kids):
        self.val = '.'.join(part for part in kids[0].val)


class ModuleName(ListNonterm, element=DotName, separator=tokens.T_DOUBLECOLON):
    val: list[str]


class ColonedIdents(
        ListNonterm, element=AnyIdentifier, separator=tokens.T_DOUBLECOLON):
    pass


class QualifiedName(Nonterm):
    def reduce_Identifier_DOUBLECOLON_ColonedIdents(self, ident, _, idents):
        assert ident.val
        assert idents.val
        self.val = [ident.val, *idents.val]

    def reduce_DUNDERSTD_DOUBLECOLON_ColonedIdents(self, _s, _c, idents):
        assert idents.val
        self.val = ['__std__', *idents.val]


# this can appear anywhere
class BaseName(Nonterm):
    def reduce_Identifier(self, *kids):
        self.val = [kids[0].val]

    @parsing.inline(0)
    def reduce_QualifiedName(self, *kids):
        pass


# this can appear in link/property definitions
class PtrName(Nonterm):
    def reduce_PtrIdentifier(self, ptr_identifier):
        assert ptr_identifier.val
        self.val = [ptr_identifier.val]

    @parsing.inline(0)
    def reduce_QualifiedName(self, *_):
        pass


# Non-collection type.
class SimpleTypeName(Nonterm):
    def reduce_PtrNodeName(self, *kids):
        self.val = qlast.TypeName(maintype=kids[0].val)

    def reduce_ANYTYPE(self, *kids):
        self.val = qlast.TypeName(
            maintype=qlast.PseudoObjectRef(name='anytype', span=self.span)
        )

    def reduce_ANYTUPLE(self, *kids):
        self.val = qlast.TypeName(
            maintype=qlast.PseudoObjectRef(name='anytuple', span=self.span)
        )

    def reduce_ANYOBJECT(self, *kids):
        self.val = qlast.TypeName(
            maintype=qlast.PseudoObjectRef(name='anyobject', span=self.span)
        )


class SimpleTypeNameList(ListNonterm, element=SimpleTypeName,
                         separator=tokens.T_COMMA):
    pass


class CollectionTypeName(Nonterm):

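    def validate_subtype_list(self, lst):
        # Reject subtype lists that mix literal subtypes with type names,
        # or named subtype declarations with unnamed ones.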
        has_nonstrval = has_strval = has_items = False
        for el in lst.val:
            if isinstance(el, qlast.TypeExprLiteral):
                has_strval = True
            elif isinstance(el, qlast.TypeName):
                if el.name:
                    has_items = True
                else:
                    has_nonstrval = True

        if (has_nonstrval or has_items) and has_strval:
            # Prohibit cases like `tuple<a: int64, 'aaaa'>` and
            # `enum<bbbb, 'aaaa'>`
            raise errors.EdgeQLSyntaxError(
                "mixing string type literals and type names is not supported",
                span=lst.span)

        if has_items and has_nonstrval:
            # Prohibit cases like `tuple<a: int64, int32>`
            raise errors.EdgeQLSyntaxError(
                "mixing named and unnamed subtype declarations "
                "is not supported",
                span=lst.span)

    def reduce_NodeName_LANGBRACKET_RANGBRACKET(self, *kids):
        # Constructs like `enum<>` or `array<>` aren't legal.
        raise errors.EdgeQLSyntaxError(
            'parametrized type must have at least one argument',
            span=kids[1].span,
        )

    def reduce_NodeName_LANGBRACKET_SubtypeList_RANGBRACKET(self, *kids):
        self.validate_subtype_list(kids[2])
        self.val = qlast.TypeName(
            maintype=kids[0].val,
            subtypes=kids[2].val,
        )


class TypeName(Nonterm):
    @parsing.inline(0)
    def reduce_SimpleTypeName(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_CollectionTypeName(self, *kids):
        pass


class TypeNameList(ListNonterm, element=TypeName,
                   separator=tokens.T_COMMA):
    pass


# A type expression that is not a simple type.
class NontrivialTypeExpr(Nonterm):
    def reduce_TYPEOF_Expr(self, *kids):
        self.val = qlast.TypeOf(expr=kids[1].val)

    @parsing.inline(1)
    def reduce_LPAREN_FullTypeExpr_RPAREN(self, *kids):
        pass

    def reduce_TypeExpr_PIPE_TypeExpr(self, *kids):
        self.val = qlast.TypeOp(
            left=kids[0].val,
            op=qlast.TypeOpName.OR,
            right=kids[2].val,
        )

    def reduce_TypeExpr_AMPER_TypeExpr(self, *kids):
        self.val = qlast.TypeOp(
            left=kids[0].val,
            op=qlast.TypeOpName.AND,
            right=kids[2].val,
        )


# This is a type expression without angle brackets, so it
# can be used without parentheses in a context where the
# angle bracket has a different meaning.
class TypeExpr(Nonterm):
    @parsing.inline(0)
    def reduce_SimpleTypeName(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_NontrivialTypeExpr(self, *kids):
        pass


# A type expression enclosed in parentheses
class ParenTypeExpr(Nonterm):
    @parsing.inline(1)
    def reduce_LPAREN_FullTypeExpr_RPAREN(self, *kids):
        pass


# This is a type expression which includes collection types,
# so it can only be directly used in a context where the
# angle bracket is unambiguous.
class FullTypeExpr(Nonterm):
    @parsing.inline(0)
    def reduce_TypeName(self, *kids):
        pass

    def reduce_TYPEOF_Expr(self, *kids):
        self.val = qlast.TypeOf(expr=kids[1].val)

    @parsing.inline(1)
    def reduce_LPAREN_FullTypeExpr_RPAREN(self, *kids):
        pass

    def reduce_FullTypeExpr_PIPE_FullTypeExpr(self, *kids):
        self.val = qlast.TypeOp(
            left=kids[0].val,
            op=qlast.TypeOpName.OR,
            right=kids[2].val,
        )

    def reduce_FullTypeExpr_AMPER_FullTypeExpr(self, *kids):
        self.val = qlast.TypeOp(
            left=kids[0].val,
            op=qlast.TypeOpName.AND,
            right=kids[2].val,
        )


class Subtype(Nonterm):
    @parsing.inline(0)
    def reduce_FullTypeExpr(self, *kids):
        pass

    def reduce_Identifier_COLON_FullTypeExpr(self, *kids):
        self.val = kids[2].val
        self.val.name = kids[0].val

    def reduce_BaseStringConstant(self, *kids):
        # TODO: Raise a DeprecationWarning once we have facility for that.
        self.val = qlast.TypeExprLiteral(
            val=kids[0].val,
        )

    def reduce_BaseNumberConstant(self, *kids):
        self.val = qlast.TypeExprLiteral(
            val=kids[0].val,
        )


class SubtypeList(ListNonterm, element=Subtype, separator=tokens.T_COMMA,
                  allow_trailing_separator=True):
    pass


class NodeName(Nonterm):
    # NOTE: Generic short or fully-qualified name.
    #
    # This name is safe to be used anywhere as it starts with IDENT only.

    def reduce_BaseName(self, base_name):
        self.val = qlast.ObjectRef(
            module='::'.join(base_name.val[:-1]) or None,
            name=base_name.val[-1])


class PtrNodeName(Nonterm):
    # NOTE: Generic short or fully-qualified name.
    #
    # This name is safe to be used in most DDL and SDL definitions.

    def reduce_PtrName(self, ptr_name):
        self.val = qlast.ObjectRef(
            module='::'.join(ptr_name.val[:-1]) or None,
            name=ptr_name.val[-1])


class PtrQualifiedNodeName(Nonterm):
    def reduce_QualifiedName(self, *kids):
        self.val = qlast.ObjectRef(
            module='::'.join(kids[0].val[:-1]),
            name=kids[0].val[-1])


class ShortNodeName(Nonterm):
    # NOTE: A non-qualified name that can be an identifier or
    # UNRESERVED_KEYWORD.
    #
    # This name is used as part of paths after the DOT. It can be an
    # identifier including UNRESERVED_KEYWORD and does not need to be
    # quoted or parenthesized.

    def reduce_Identifier(self, *kids):
        self.val = qlast.ObjectRef(
            module=None,
            name=kids[0].val)


class PathNodeName(Nonterm):
    # NOTE: A non-qualified name that can be an identifier or
    # PARTIAL_RESERVED_KEYWORD.
    #
    # This name is used as part of paths after the DOT as well as in
    # definitions after LINK/POINTER. It can be an identifier including
    # PARTIAL_RESERVED_KEYWORD and does not need to be quoted or
    # parenthesized.

    def reduce_PtrIdentifier(self, *kids):
        self.val = qlast.ObjectRef(
            module=None,
            name=kids[0].val)


class AnyNodeName(Nonterm):
    # NOTE: A non-qualified name that can be ANY identifier.
    #
    # This name is used as part of paths after the DOT. It can be any
    # identifier including RESERVED_KEYWORD and UNRESERVED_KEYWORD and
    # does not need to be quoted or parenthesized.
    #
    # This is mainly used in DDL statements that have another keyword
    # completely disambiguating that what comes next is a name. It
    # CANNOT be used in Expr productions because it will cause
    # ambiguity with NodeName, etc.

    def reduce_AnyIdentifier(self, *kids):
        self.val = qlast.ObjectRef(
            module=None,
            name=kids[0].val)


class Keyword(parsing.Nonterm):
    """Base class for the different classes of keywords.

    Not a real nonterm on its own.
    """
    def __init_subclass__(
            cls, *, type, is_internal=False, **kwargs):
        super().__init_subclass__(is_internal=is_internal, **kwargs)

        if is_internal:
            return

        assert type in keywords.keyword_types

        for token in keywords.by_type[type].values():
            def method(inst, *kids):
                inst.val = kids[0].val
            method.__doc__ = "%%reduce %s" % token
            method.__name__ = 'reduce_%s' % token
            setattr(cls, method.__name__, method)


class UnreservedKeyword(Keyword,
                        type=keywords.UNRESERVED_KEYWORD):
    pass


class PartialReservedKeyword(Keyword,
                             type=keywords.PARTIAL_RESERVED_KEYWORD):
    pass


class ReservedKeyword(Keyword,
                      type=keywords.RESERVED_KEYWORD):
    pass


class SchemaObjectClassValue(typing.NamedTuple):

    itemclass: qltypes.SchemaObjectClass


class SchemaObjectClass(Nonterm):

    def reduce_ALIAS(self, *kids):
        self.val = SchemaObjectClassValue(
            itemclass=qltypes.SchemaObjectClass.ALIAS)

    def reduce_ANNOTATION(self, *kids):
        self.val = SchemaObjectClassValue(
            itemclass=qltypes.SchemaObjectClass.ANNOTATION)

    def reduce_CAST(self, *kids):
        self.val = SchemaObjectClassValue(
            itemclass=qltypes.SchemaObjectClass.CAST)

    def reduce_CONSTRAINT(self, *kids):
        self.val = SchemaObjectClassValue(
            itemclass=qltypes.SchemaObjectClass.CONSTRAINT)

    def reduce_FUNCTION(self, *kids):
        self.val = SchemaObjectClassValue(
            itemclass=qltypes.SchemaObjectClass.FUNCTION)

    def reduce_LINK(self, *kids):
        self.val = SchemaObjectClassValue(
            itemclass=qltypes.SchemaObjectClass.LINK)

    def reduce_MODULE(self, *kids):
        self.val = SchemaObjectClassValue(
            itemclass=qltypes.SchemaObjectClass.MODULE)

    def reduce_OPERATOR(self, *kids):
        self.val = SchemaObjectClassValue(
            itemclass=qltypes.SchemaObjectClass.OPERATOR)

    def reduce_PROPERTY(self, *kids):
        self.val = SchemaObjectClassValue(
            itemclass=qltypes.SchemaObjectClass.PROPERTY)

    def reduce_SCALAR_TYPE(self, *kids):
        self.val = SchemaObjectClassValue(
            itemclass=qltypes.SchemaObjectClass.SCALAR_TYPE)

    def reduce_TYPE(self, *kids):
        self.val = SchemaObjectClassValue(
            itemclass=qltypes.SchemaObjectClass.TYPE)


class SchemaItem(Nonterm):

    def reduce_SchemaObjectClass_NodeName(self, *kids):
        ref = kids[1].val
        ref.itemclass = kids[0].val.itemclass
        self.val = ref


================================================
FILE: edb/edgeql/parser/grammar/keywords.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2010-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

import re

import edb._edgeql_parser as ql_parser


keyword_types = range(1, 5)
(UNRESERVED_KEYWORD, RESERVED_KEYWORD, TYPE_FUNC_NAME_KEYWORD,
 PARTIAL_RESERVED_KEYWORD) = keyword_types

unreserved_keywords = ql_parser.unreserved_keywords
future_reserved_keywords = ql_parser.future_reserved_keywords
reserved_keywords = (
    future_reserved_keywords | ql_parser.current_reserved_keywords
)
# These keywords can be used in pretty much all the places where they are
# preceded by a reserved keyword or some other disambiguating token like `.`,
# `.<`, or `@`.
#
# In practice we mainly relax their usage as link/property names.
partial_reserved_keywords = ql_parser.partial_reserved_keywords


def _check_keywords():
    duplicate_keywords = reserved_keywords & unreserved_keywords
    if duplicate_keywords:
        raise ValueError(
            f'The following EdgeQL keywords are defined as *both* '
            f'reserved and unreserved: {duplicate_keywords!r}')


_check_keywords()


_dunder_re = re.compile(r'(?i)^__[a-z]+__$')


def tok_name(keyword):
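    '''Convert a literal keyword into a token name.

    Dunder keywords get a ``DUNDER`` prefix (``__std__`` becomes
    ``DUNDERSTD``); every other keyword is simply upper-cased.
    '''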
    if _dunder_re.match(keyword):
        return f'DUNDER{keyword[2:-2].upper()}'
    else:
        return keyword.upper()


edgeql_keywords = {k: (tok_name(k), UNRESERVED_KEYWORD)
                   for k in unreserved_keywords}
edgeql_keywords.update({k: (tok_name(k), RESERVED_KEYWORD)
                        for k in reserved_keywords})
edgeql_keywords.update({k: (tok_name(k), PARTIAL_RESERVED_KEYWORD)
                        for k in partial_reserved_keywords})


by_type: dict[int, dict] = {typ: {} for typ in keyword_types}

for val, spec in edgeql_keywords.items():
    by_type[spec[1]][val] = spec[0]


================================================
FILE: edb/edgeql/parser/grammar/precedence.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

from edb.common import parsing


class Precedence(parsing.Precedence, assoc='fail', is_internal=True):
    pass


class P_UNION(Precedence, assoc='left', tokens=('UNION', 'EXCEPT',)):
    pass


class P_INTERSECT(Precedence, assoc='left', tokens=('INTERSECT',)):
    pass


class P_IFELSE(Precedence, assoc='right', tokens=('IF', 'ELSE')):
    pass


class P_OR(Precedence, assoc='left', tokens=('OR',)):
    pass


class P_AND(Precedence, assoc='left', tokens=('AND',)):
    pass


class P_NOT(Precedence, assoc='right', tokens=('NOT',)):
    pass


class P_LIKE_ILIKE(Precedence, assoc='nonassoc', tokens=('LIKE', 'ILIKE')):
    pass


class P_IN(Precedence, assoc='nonassoc', tokens=('IN',)):
    pass


class P_IDENT(Precedence, assoc='nonassoc', tokens=('IDENT', 'PARTITION')):
    pass


class P_COMPARE_OP(
    Precedence,
    assoc='nonassoc',
    tokens=(
        'DISTINCTFROM',
        'GREATEREQ',
        'LESSEQ',
        'NOTDISTINCTFROM',
        'NOTEQ',
        'LANGBRACKET',
        'RANGBRACKET',
        'EQUALS',
    )
):
    pass


class P_IS(Precedence, assoc='nonassoc', tokens=('IS',)):
    pass


class P_ADD_OP(Precedence, assoc='left',
               tokens=('PLUS', 'MINUS', 'DOUBLEPLUS')):
    pass


class P_MUL_OP(Precedence, assoc='left',
               tokens=('STAR', 'SLASH', 'DOUBLESLASH', 'PERCENT')):
    pass


class P_DOUBLEQMARK_OP(Precedence, assoc='right', tokens=('DOUBLEQMARK',)):
    pass


class P_TYPEOF(Precedence, assoc='nonassoc', tokens=('TYPEOF',)):
    pass


class P_INTROSPECT(Precedence, assoc='nonassoc', tokens=('INTROSPECT',)):
    pass


class P_TYPEOR(Precedence, assoc='left', tokens=('PIPE',)):
    pass


class P_TYPEAND(Precedence, assoc='left', tokens=('AMPER',)):
    pass


class P_UMINUS(Precedence, assoc='right'):
    pass


class P_EXISTS(Precedence, assoc='right', tokens=('EXISTS',),
               rel_to_last='='):
    pass


class P_DISTINCT(Precedence, assoc='right', tokens=('DISTINCT',),
                 rel_to_last='='):
    pass


class P_POW_OP(Precedence, assoc='right', tokens=('CIRCUMFLEX',)):
    pass


class P_TYPECAST(Precedence, assoc='right'):
    pass


class P_BRACE(Precedence, assoc='left', tokens=('LBRACE', 'RBRACE')):
    pass


class P_BRACKET(Precedence, assoc='left', tokens=('LBRACKET', 'RBRACKET')):
    pass


class P_PAREN(Precedence, assoc='left', tokens=('LPAREN', 'RPAREN')):
    pass


class P_DOT(Precedence, assoc='left', tokens=('DOT', 'DOTBW', 'DOTQ')):
    pass


class P_DETACHED(Precedence, assoc='right', tokens=('DETACHED',)):
    pass


class P_GLOBAL(Precedence, assoc='right', tokens=('GLOBAL',)):
    pass


class P_DOUBLECOLON(Precedence, assoc='left', tokens=('DOUBLECOLON',)):
    pass


class P_AT(Precedence, assoc='left', tokens=('AT',)):
    pass


# XXX: I don't remember why this helps.

class P_REQUIRED(Precedence, assoc='right', tokens=('REQUIRED',)):
    pass


class P_MULTI(Precedence, assoc='right', tokens=('MULTI',),
              rel_to_last='='):
    pass


class P_OPTIONAL(Precedence, assoc='right', tokens=('OPTIONAL',),
                 rel_to_last='='):
    pass


class P_SINGLE(Precedence, assoc='right', tokens=('SINGLE',),
               rel_to_last='='):
    pass


================================================
FILE: edb/edgeql/parser/grammar/sdl.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2019-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

from edb.edgeql import ast as qlast

from edb.common import parsing
from edb import errors

from . import expressions
from . import commondl

from .precedence import *  # NOQA
from .tokens import *  # NOQA
from .commondl import *  # NOQA


Nonterm = expressions.Nonterm  # type: ignore[misc]
OptSemicolons = commondl.OptSemicolons  # type: ignore[misc]


sdl_nontem_helper = commondl.NewNontermHelper(__name__)
_new_nonterm = sdl_nontem_helper._new_nonterm


# top-level SDL statements
class SDLStatement(Nonterm):
    @parsing.inline(0)
    def reduce_SDLBlockStatement(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_SDLShortStatement_SEMICOLON(self, *kids):
        pass


# a list of SDL statements with optional semicolon separators
class SDLStatements(parsing.ListNonterm, element=SDLStatement,
                    separator=OptSemicolons):
    pass


# These statements have a block
class SDLBlockStatement(Nonterm):
    @parsing.inline(0)
    def reduce_ModuleDeclaration(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_ScalarTypeDeclaration(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_AnnotationDeclaration(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_ObjectTypeDeclaration(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_AliasDeclaration(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_ConstraintDeclaration(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_LinkDeclaration(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_PropertyDeclaration(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_FunctionDeclaration(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_GlobalDeclaration(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_IndexDeclaration(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_PermissionDeclaration(self, *kids):
        pass


# these statements have no {} block
class SDLShortStatement(Nonterm):

    @parsing.inline(0)
    def reduce_ExtensionRequirementDeclaration(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_FutureRequirementDeclaration(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_ScalarTypeDeclarationShort(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_AnnotationDeclarationShort(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_ObjectTypeDeclarationShort(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_AliasDeclarationShort(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_ConstraintDeclarationShort(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_LinkDeclarationShort(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_PropertyDeclarationShort(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_FunctionDeclarationShort(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_GlobalDeclarationShort(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_IndexDeclarationShort(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_PermissionDeclarationShort(self, *kids):
        pass


# A rule for an SDL block, either as part of `module` declaration or
# as top-level schema used in MIGRATION DDL.
class SDLCommandBlock(Nonterm):
    # this command block can be empty
    def reduce_LBRACE_OptSemicolons_RBRACE(self, *kids):
        self.val = []

    def reduce_statement_without_semicolons(self, _0, _1, stmt, _2):
        r"""%reduce LBRACE \
                OptSemicolons SDLShortStatement \
            RBRACE
        """
        self.val = [stmt.val]

    def reduce_statements_without_optional_trailing_semicolons(self, *kids):
        r"""%reduce LBRACE \
                OptSemicolons SDLStatements \
                OptSemicolons SDLShortStatement \
            RBRACE
        """
        _, _, stmts, _, stmt, _ = kids
        self.val = stmts.val + [stmt.val]

    @parsing.inline(2)
    def reduce_LBRACE_OptSemicolons_SDLStatements_RBRACE(self, *kids):
        pass

    @parsing.inline(2)
    def reduce_statements_without_optional_trailing_semicolons2(self, *kids):
        r"""%reduce LBRACE \
                OptSemicolons SDLStatements \
                Semicolons \
            RBRACE
        """


class SDLProductionHelper:
    def _passthrough(self, *cmds):
        self.val = cmds[0].val

    def _singleton_list(self, cmd):
        self.val = [cmd.val]

    def _empty(self, *kids):
        self.val = []

    def _block(self, lbrace, sc1, cmdl, rbrace):
        self.val = [cmdl.val]

    def _block2(self, lbrace, sc1, cmdlist, sc2, rbrace):
        self.val = cmdlist.val

    def _block3(self, lbrace, sc1, cmdlist, sc2, cmd, rbrace):
        self.val = cmdlist.val + [cmd.val]


def sdl_commands_block(parent, *commands, opt=True):
    if parent is None:
        parent = ''

    # SDLCommand := SDLCommand1 | SDLCommand2 ...
    #
    # All the "short" commands, ones that need a ";" are gathered as
    # SDLCommandShort.
    #
    # All the "block" commands, ones that have a "{...}" and don't
    # need a ";" are gathered as SDLCommandBlock.
    clsdict_b = {}
    clsdict_s = {}

    for command in commands:
        if command.__name__.endswith('Block'):
            clsdict_b[f'reduce_{command.__name__}'] = \
                SDLProductionHelper._passthrough
        else:
            clsdict_s[f'reduce_{command.__name__}'] = \
                SDLProductionHelper._passthrough

    cmd_s = _new_nonterm(f'{parent}SDLCommandShort', clsdict=clsdict_s)
    cmd_b = _new_nonterm(f'{parent}SDLCommandBlock', clsdict=clsdict_b)

    # Merged command which has minimal ";"
    #
    # SDLCommandFull := SDLCommandShort ; | SDLCommandBlock
    clsdict = {}
    clsdict[f'reduce_{cmd_s.__name__}_SEMICOLON'] = \
        SDLProductionHelper._passthrough
    clsdict[f'reduce_{cmd_b.__name__}'] = \
        SDLProductionHelper._passthrough
    cmd = _new_nonterm(f'{parent}SDLCommandFull', clsdict=clsdict)

    # SDLCommandsList := SDLCommandFull [; SDLCommandFull ...]
    cmdlist = _new_nonterm(f'{parent}SDLCommandsList',
                           clsbases=(parsing.ListNonterm,),
                           clskwds=dict(element=cmd, separator=OptSemicolons))

    # The command block is tricky: the last inner command should be able
    # to terminate without a ";" where possible.
    #
    # SDLCommandsBlock :=
    #
    #   { [ ; ] SDLCommandShort } |
    #   { [ ; ] SDLCommandsList [ ; ] } |
    #   { [ ; ] SDLCommandsList [ ; ] SDLCommandShort } |
    #   { [ ; ] }
    clsdict = {}
    clsdict[f'reduce_LBRACE_OptSemicolons_{cmd_s.__name__}_RBRACE'] = \
        SDLProductionHelper._block
    clsdict[f'reduce_LBRACE_OptSemicolons_{cmdlist.__name__}_' +
            f'OptSemicolons_RBRACE'] = \
        SDLProductionHelper._block2
    clsdict[f'reduce_LBRACE_OptSemicolons_{cmdlist.__name__}_OptSemicolons_' +
            f'{cmd_s.__name__}_RBRACE'] = \
        SDLProductionHelper._block3
    clsdict[f'reduce_LBRACE_OptSemicolons_RBRACE'] = \
        SDLProductionHelper._empty
    _new_nonterm(f'{parent}SDLCommandsBlock', clsdict=clsdict)

    if opt is False:
        #   | Command
        clsdict = {}
        clsdict[f'reduce_{cmd_s.__name__}'] = \
            SDLProductionHelper._singleton_list
        clsdict[f'reduce_{cmd_b.__name__}'] = \
            SDLProductionHelper._singleton_list
        _new_nonterm(parent + 'SingleSDLCommandBlock', clsdict=clsdict)


class Using(Nonterm):
    def reduce_USING_ParenExpr(self, *kids):
        _, paren_expr = kids
        self.val = qlast.SetField(
            name='expr',
            value=paren_expr.val,
            special_syntax=True,
        )


class SetField(Nonterm):
    # field := <expr>
    def reduce_Identifier_ASSIGN_GenExpr(self, *kids):
        identifier, _, expr = kids
        self.val = qlast.SetField(name=identifier.val, value=expr.val)


class SetAnnotation(Nonterm):
    def reduce_ANNOTATION_NodeName_ASSIGN_GenExpr(self, *kids):
        _, name, _, expr = kids
        self.val = qlast.CreateAnnotationValue(name=name.val, value=expr.val)


sdl_commands_block(
    'Create',
    Using,
    SetField,
    SetAnnotation)


class ExtensionRequirementDeclaration(Nonterm):

    def reduce_USING_EXTENSION_ShortNodeName_OptExtensionVersion(self, *kids):
        _, _, name, version = kids
        self.val = qlast.CreateExtension(
            name=name.val,
            version=version.val,
        )


class FutureRequirementDeclaration(Nonterm):

    def reduce_USING_FUTURE_ShortNodeName(self, *kids):
        _, _, name = kids
        self.val = qlast.CreateFuture(
            name=name.val,
        )


class ModuleDeclaration(Nonterm):
    def reduce_MODULE_ModuleName_SDLCommandBlock(self, _, name, block):

        # Check that the declarations inside the module block do not use
        # fully-qualified names and are not `using extension` or
        # `using future` statements.
        declarations = block.val
        for decl in declarations:
            if isinstance(decl, qlast.ExtensionCommand):
                raise errors.EdgeQLSyntaxError(
                    "'using extension' cannot be used inside a module block",
                    span=decl.span)
            elif isinstance(decl, qlast.FutureCommand):
                raise errors.EdgeQLSyntaxError(
                    "'using future' cannot be used inside a module block",
                    span=decl.span)
            elif decl.name.module is not None:
                raise errors.EdgeQLSyntaxError(
                    "fully-qualified name is not allowed in "
                    "a module declaration",
                    span=decl.name.span)

        self.val = qlast.ModuleDeclaration(
            # mirror what we do in CREATE MODULE
            name=qlast.ObjectRef(
                module=None, name='::'.join(name.val), span=name.span
            ),
            declarations=declarations,
        )


#
# Constraints
#
class ConstraintDeclaration(Nonterm):
    def reduce_CreateConstraint(self, *kids):
        r"""%reduce ABSTRACT CONSTRAINT NodeName OptOnExpr \
                    OptExtendingSimple CreateSDLCommandsBlock"""
        _, _, name, on_expr, extending, commands = kids
        self.val = qlast.CreateConstraint(
            name=name.val,
            subjectexpr=on_expr.val,
            bases=extending.val,
            commands=commands.val,
        )

    def reduce_CreateConstraint_CreateFunctionArgs(self, *kids):
        r"""%reduce ABSTRACT CONSTRAINT NodeName CreateFunctionArgs \
                    OptOnExpr OptExtendingSimple CreateSDLCommandsBlock"""
        _, _, name, args, on_expr, extending, commands = kids
        self.val = qlast.CreateConstraint(
            name=name.val,
            params=args.val,
            subjectexpr=on_expr.val,
            bases=extending.val,
            commands=commands.val,
        )


class ConstraintDeclarationShort(Nonterm):
    def reduce_CreateConstraint(self, *kids):
        r"""%reduce ABSTRACT CONSTRAINT NodeName OptOnExpr \
                    OptExtendingSimple"""
        _, _, name, on_expr, extending = kids
        self.val = qlast.CreateConstraint(
            name=name.val,
            subjectexpr=on_expr.val,
            bases=extending.val,
        )

    def reduce_CreateConstraint_CreateFunctionArgs(self, *kids):
        r"""%reduce ABSTRACT CONSTRAINT NodeName CreateFunctionArgs \
                    OptOnExpr OptExtendingSimple"""
        _, _, name, args, on_expr, extending = kids
        self.val = qlast.CreateConstraint(
            name=name.val,
            params=args.val,
            subjectexpr=on_expr.val,
            bases=extending.val,
        )


class ConcreteConstraintBlock(Nonterm):
    def reduce_CreateConstraint(self, *kids):
        r"""%reduce CONSTRAINT \
                    NodeName OptConcreteConstraintArgList OptOnExpr \
                    OptExceptExpr \
                    CreateSDLCommandsBlock"""
        _, name, arg_list, on_expr, except_expr, commands = kids
        self.val = qlast.CreateConcreteConstraint(
            name=name.val,
            args=arg_list.val,
            subjectexpr=on_expr.val,
            except_expr=except_expr.val,
            commands=commands.val,
        )

    def reduce_CreateDelegatedConstraint(self, *kids):
        r"""%reduce DELEGATED CONSTRAINT \
                    NodeName OptConcreteConstraintArgList OptOnExpr \
                    OptExceptExpr \
                    CreateSDLCommandsBlock"""
        _, _, name, arg_list, on_expr, except_expr, commands = kids
        self.val = qlast.CreateConcreteConstraint(
            delegated=True,
            name=name.val,
            args=arg_list.val,
            subjectexpr=on_expr.val,
            except_expr=except_expr.val,
            commands=commands.val,
        )


class ConcreteConstraintShort(Nonterm):
    def reduce_CreateConstraint(self, *kids):
        r"""%reduce CONSTRAINT \
                    NodeName OptConcreteConstraintArgList OptOnExpr \
                    OptExceptExpr"""
        _, name, arg_list, on_expr, except_expr = kids
        self.val = qlast.CreateConcreteConstraint(
            name=name.val,
            args=arg_list.val,
            subjectexpr=on_expr.val,
            except_expr=except_expr.val,
        )

    def reduce_CreateDelegatedConstraint(self, *kids):
        r"""%reduce DELEGATED CONSTRAINT \
                    NodeName OptConcreteConstraintArgList OptOnExpr \
                    OptExceptExpr"""
        _, _, name, arg_list, on_expr, except_expr = kids
        self.val = qlast.CreateConcreteConstraint(
            delegated=True,
            name=name.val,
            args=arg_list.val,
            subjectexpr=on_expr.val,
            except_expr=except_expr.val,
        )


#
# Scalar Types
#

sdl_commands_block(
    'CreateScalarType',
    SetField,
    SetAnnotation,
    ConcreteConstraintBlock,
    ConcreteConstraintShort,
)


class ScalarTypeDeclaration(Nonterm):
    def reduce_CreateAbstractScalarTypeStmt(self, *kids):
        r"""%reduce \
            ABSTRACT SCALAR TYPE NodeName \
            OptExtending CreateScalarTypeSDLCommandsBlock \
        """
        _, _, _, name, extending, commands = kids
        self.val = qlast.CreateScalarType(
            abstract=True,
            name=name.val,
            bases=extending.val,
            commands=commands.val,
        )

    def reduce_ScalarTypeDeclaration(self, *kids):
        r"""%reduce \
            SCALAR TYPE NodeName \
            OptExtending CreateScalarTypeSDLCommandsBlock \
        """
        _, _, name, extending, commands = kids
        self.val = qlast.CreateScalarType(
            name=name.val,
            bases=extending.val,
            commands=commands.val,
        )


class ScalarTypeDeclarationShort(Nonterm):
    def reduce_CreateAbstractScalarTypeStmt(self, *kids):
        r"""%reduce \
            ABSTRACT SCALAR TYPE NodeName \
            OptExtending \
        """
        _, _, _, name, extending = kids
        self.val = qlast.CreateScalarType(
            abstract=True,
            name=name.val,
            bases=extending.val,
        )

    def reduce_ScalarTypeDeclaration(self, *kids):
        r"""%reduce \
            SCALAR TYPE NodeName \
            OptExtending \
        """
        _, _, name, extending = kids
        self.val = qlast.CreateScalarType(
            name=name.val,
            bases=extending.val,
        )


#
# Annotations
#
class AnnotationDeclaration(Nonterm):
    def reduce_CreateAnnotation(self, *kids):
        r"""%reduce ABSTRACT ANNOTATION NodeName OptExtendingSimple \
                    CreateSDLCommandsBlock"""
        _, _, name, extending, commands = kids
        self.val = qlast.CreateAnnotation(
            abstract=True,
            name=name.val,
            bases=extending.val,
            inheritable=False,
            commands=commands.val,
        )

    def reduce_CreateInheritableAnnotation(self, *kids):
        r"""%reduce ABSTRACT INHERITABLE ANNOTATION
                    NodeName OptExtendingSimple CreateSDLCommandsBlock"""
        _, _, _, name, extending, commands = kids
        self.val = qlast.CreateAnnotation(
            abstract=True,
            name=name.val,
            bases=extending.val,
            inheritable=True,
            commands=commands.val,
        )


class AnnotationDeclarationShort(Nonterm):
    def reduce_CreateAnnotation(self, *kids):
        r"""%reduce ABSTRACT ANNOTATION NodeName OptExtendingSimple"""
        _, _, name, extending = kids
        self.val = qlast.CreateAnnotation(
            abstract=True,
            name=name.val,
            bases=extending.val,
            inheritable=False,
        )

    def reduce_CreateInheritableAnnotation(self, *kids):
        r"""%reduce ABSTRACT INHERITABLE ANNOTATION
                    NodeName OptExtendingSimple"""
        _, _, _, name, extending = kids
        self.val = qlast.CreateAnnotation(
            abstract=True,
            name=name.val,
            bases=extending.val,
            inheritable=True,
        )


#
# Indexes
#
sdl_commands_block(
    'CreateIndex',
    Using,
    SetField,
    SetAnnotation,
)


class IndexDeclaration(
    Nonterm,
    commondl.ProcessIndexMixin,
):
    def reduce_CreateIndex(self, *kids):
        r"""%reduce ABSTRACT INDEX NodeName \
                    OptExtendingSimple CreateIndexSDLCommandsBlock"""
        _, _, name, bases, commands = kids
        self.val = qlast.CreateIndex(
            name=name.val,
            bases=bases.val,
            commands=commands.val,
        )

    def reduce_CreateIndex_CreateFunctionArgs(self, *kids):
        r"""%reduce ABSTRACT INDEX NodeName IndexExtArgList \
                    OptExtendingSimple CreateIndexSDLCommandsBlock"""
        _, _, name, arg_list, bases, commands = kids
        params, kwargs = self._process_params_or_kwargs(
            bases.val, arg_list.val)
        self.val = qlast.CreateIndex(
            name=name.val,
            params=params,
            kwargs=kwargs,
            bases=bases.val,
            commands=commands.val,
        )


class IndexDeclarationShort(
    Nonterm,
    commondl.ProcessIndexMixin,
):
    def reduce_CreateIndex(self, *kids):
        r"""%reduce ABSTRACT INDEX NodeName OptExtendingSimple"""
        _, _, name, bases = kids
        self.val = qlast.CreateIndex(
            name=name.val,
            bases=bases.val,
        )

    def reduce_CreateIndex_CreateFunctionArgs(self, *kids):
        r"""%reduce ABSTRACT INDEX NodeName IndexExtArgList \
                    OptExtendingSimple"""
        _, _, name, arg_list, bases = kids
        params, kwargs = self._process_params_or_kwargs(
            bases.val, arg_list.val)
        self.val = qlast.CreateIndex(
            name=name.val,
            params=params,
            kwargs=kwargs,
            bases=bases.val,
        )


sdl_commands_block(
    'CreateConcreteIndex',
    SetField,
    SetAnnotation)


class ConcreteIndexDeclarationBlock(Nonterm, commondl.ProcessIndexMixin):
    def reduce_CreateConcreteAnonymousIndex(self, *kids):
        r"""%reduce INDEX OnExpr OptExceptExpr
                    CreateConcreteIndexSDLCommandsBlock
        """
        _, on_expr, except_expr, commands = kids
        self.val = qlast.CreateConcreteIndex(
            name=qlast.ObjectRef(module='__', name='idx', span=kids[0].span),
            expr=on_expr.val,
            except_expr=except_expr.val,
            commands=commands.val,
        )

    def reduce_CreateConcreteAnonymousDeferredIndex(self, *kids):
        r"""%reduce DEFERRED INDEX OnExpr OptExceptExpr
                    CreateConcreteIndexSDLCommandsBlock
        """
        _, _, on_expr, except_expr, commands = kids
        self.val = qlast.CreateConcreteIndex(
            name=qlast.ObjectRef(module='__', name='idx', span=kids[0].span),
            expr=on_expr.val,
            except_expr=except_expr.val,
            deferred=True,
            commands=commands.val,
        )

    def reduce_CreateConcreteIndex(self, *kids):
        r"""%reduce INDEX NodeName \
                    OnExpr OptExceptExpr \
                    CreateConcreteIndexSDLCommandsBlock \
        """
        _, name, on_expr, except_expr, commands = kids
        self.val = qlast.CreateConcreteIndex(
            name=name.val,
            expr=on_expr.val,
            except_expr=except_expr.val,
            commands=commands.val,
        )

    def reduce_CreateConcreteDeferredIndex(self, *kids):
        r"""%reduce DEFERRED INDEX NodeName \
                    OnExpr OptExceptExpr \
                    CreateConcreteIndexSDLCommandsBlock \
        """
        _, _, name, on_expr, except_expr, commands = kids
        self.val = qlast.CreateConcreteIndex(
            name=name.val,
            expr=on_expr.val,
            except_expr=except_expr.val,
            deferred=True,
            commands=commands.val,
        )

    def reduce_CreateConcreteIndexWithArgs(self, *kids):
        r"""%reduce INDEX NodeName IndexExtArgList \
                    OnExpr OptExceptExpr \
                    CreateConcreteIndexSDLCommandsBlock \
        """
        _, name, arg_list, on_expr, except_expr, commands = kids
        kwargs = self._process_arguments(arg_list.val)
        self.val = qlast.CreateConcreteIndex(
            name=name.val,
            kwargs=kwargs,
            expr=on_expr.val,
            except_expr=except_expr.val,
            commands=commands.val,
        )

    def reduce_CreateConcreteDeferredIndexWithArgs(self, *kids):
        r"""%reduce DEFERRED INDEX NodeName IndexExtArgList \
                    OnExpr OptExceptExpr \
                    CreateConcreteIndexSDLCommandsBlock \
        """
        _, _, name, arg_list, on_expr, except_expr, commands = kids
        kwargs = self._process_arguments(arg_list.val)
        self.val = qlast.CreateConcreteIndex(
            name=name.val,
            kwargs=kwargs,
            expr=on_expr.val,
            except_expr=except_expr.val,
            deferred=True,
            commands=commands.val,
        )


class ConcreteIndexDeclarationShort(Nonterm, commondl.ProcessIndexMixin):
    def reduce_INDEX_OnExpr_OptExceptExpr(self, *kids):
        _, on_expr, except_expr = kids
        self.val = qlast.CreateConcreteIndex(
            name=qlast.ObjectRef(module='__', name='idx', span=kids[0].span),
            expr=on_expr.val,
            except_expr=except_expr.val,
        )

    def reduce_DEFERRED_INDEX_OnExpr_OptExceptExpr(self, *kids):
        _, _, on_expr, except_expr = kids
        self.val = qlast.CreateConcreteIndex(
            name=qlast.ObjectRef(module='__', name='idx', span=kids[0].span),
            expr=on_expr.val,
            except_expr=except_expr.val,
            deferred=True,
        )

    def reduce_CreateConcreteIndex(self, *kids):
        r"""%reduce INDEX NodeName OnExpr OptExceptExpr
        """
        _, name, on_expr, except_expr = kids
        self.val = qlast.CreateConcreteIndex(
            name=name.val,
            expr=on_expr.val,
            except_expr=except_expr.val,
        )

    def reduce_CreateConcreteDeferredIndex(self, *kids):
        r"""%reduce DEFERRED INDEX NodeName OnExpr OptExceptExpr
        """
        _, _, name, on_expr, except_expr = kids
        self.val = qlast.CreateConcreteIndex(
            name=name.val,
            expr=on_expr.val,
            except_expr=except_expr.val,
            deferred=True,
        )

    def reduce_CreateConcreteIndexWithArgs(self, *kids):
        r"""%reduce INDEX NodeName IndexExtArgList \
                    OnExpr OptExceptExpr \
        """
        _, name, arg_list, on_expr, except_expr = kids
        kwargs = self._process_arguments(arg_list.val)
        self.val = qlast.CreateConcreteIndex(
            name=name.val,
            kwargs=kwargs,
            expr=on_expr.val,
            except_expr=except_expr.val,
        )

    def reduce_CreateConcreteDeferredIndexWithArgs(self, *kids):
        r"""%reduce DEFERRED INDEX NodeName IndexExtArgList
                    OnExpr OptExceptExpr
        """
        _, _, name, arg_list, on_expr, except_expr = kids
        kwargs = self._process_arguments(arg_list.val)
        self.val = qlast.CreateConcreteIndex(
            name=name.val,
            kwargs=kwargs,
            expr=on_expr.val,
            except_expr=except_expr.val,
            deferred=True,
        )


#
# Mutation rewrites
#
sdl_commands_block(
    'CreateRewrite',
    SetField,
    SetAnnotation
)


class RewriteDeclarationBlock(Nonterm):
    def reduce_CreateRewrite(self, _r, kinds, _u, expr, commands):
        """%reduce
            REWRITE RewriteKindList
            USING ParenExpr
            CreateRewriteSDLCommandsBlock
        """
        # The name isn't important (it gets replaced) but we need to
        # have one.
        name = '/'.join(str(kind) for kind in kinds.val)
        self.val = qlast.CreateRewrite(
            name=qlast.ObjectRef(name=name, span=kinds.span),
            kinds=kinds.val,
            expr=expr.val,
            commands=commands.val,
        )


class RewriteDeclarationShort(Nonterm):
    def reduce_CreateRewrite(self, _r, kinds, _u, expr):
        """%reduce
            REWRITE RewriteKindList
            USING ParenExpr
        """
        # The name isn't important (it gets replaced) but we need to
        # have one.
        name = '/'.join(str(kind) for kind in kinds.val)
        self.val = qlast.CreateRewrite(
            name=qlast.ObjectRef(name=name, span=kinds.span),
            kinds=kinds.val,
            expr=expr.val,
        )


#
# Unknown kind pointers (could be link or property)
#

class PtrTarget(Nonterm):

    def reduce_ARROW_FullTypeExpr(self, *kids):
        _arrow, type_expr = kids

        self.val = type_expr.val
        self.span = type_expr.val.span

    def reduce_COLON_FullTypeExpr(self, *kids):
        _, type_expr = kids
        self.val = type_expr.val
        self.span = type_expr.val.span


class OptPtrTarget(Nonterm):

    def reduce_empty(self, *kids):
        self.val = None

    @parsing.inline(0)
    def reduce_PtrTarget(self, *kids):
        pass


class ConcreteUnknownPointerBlock(Nonterm):
    def _validate(self):
        # A pointer may carry at most one 'on target delete' clause.
        on_target_delete = None
        for cmd in self.val.commands:
            if isinstance(cmd, qlast.OnTargetDelete):
                if on_target_delete:
                    raise errors.EdgeQLSyntaxError(
                        "more than one 'on target delete' specification",
                        span=cmd.span)
                else:
                    on_target_delete = cmd

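    def _extract_target(self, target, cmds, span, *, overloaded=False):
        # An explicit target takes precedence; otherwise fall back to the
        # 'expr' field set inside the commands block, if there is one.
        if target:
            return target, cmds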

        for cmd in cmds:
            if isinstance(cmd, qlast.SetField) and cmd.name == 'expr':
                if target is not None:
                    raise errors.EdgeQLSyntaxError(
                        'computed link with more than one expression',
                        span=span)
                target = cmd.value

        if not overloaded and target is None:
            raise errors.EdgeQLSyntaxError(
                'computed link without expression',
                span=span)

        return target, cmds

    def reduce_CreateRegularPointer(self, *kids):
        """%reduce
            PathNodeName OptExtendingSimple
            OptPtrTarget CreateConcreteLinkSDLCommandsBlock
        """
        name, opt_bases, opt_target, block = kids
        target, cmds = self._extract_target(
            opt_target.val, block.val, name.span)
        vbases, vcmds = commondl.extract_bases(opt_bases.val, cmds)
        self.val = qlast.CreateConcreteUnknownPointer(
            name=name.val,
            bases=vbases,
            target=target,
            commands=vcmds,
        )
        self._validate()

    def reduce_CreateRegularQualifiedPointer(self, *kids):
        """%reduce
            PtrQuals PathNodeName OptExtendingSimple
            OptPtrTarget CreateConcreteLinkSDLCommandsBlock
        """
        quals, name, opt_bases, opt_target, block = kids
        target, cmds = self._extract_target(
            opt_target.val, block.val, name.span)
        vbases, vcmds = commondl.extract_bases(opt_bases.val, cmds)
        self.val = qlast.CreateConcreteUnknownPointer(
            is_required=quals.val.required,
            cardinality=quals.val.cardinality,
            name=name.val,
            bases=vbases,
            target=target,
            commands=vcmds,
        )
        self._validate()

    def reduce_CreateOverloadedPointer(self, *kids):
        """%reduce
            OVERLOADED PathNodeName OptExtendingSimple
            OptPtrTarget CreateConcreteLinkSDLCommandsBlock
        """
        _, name, opt_bases, opt_target, block = kids
        target, cmds = self._extract_target(
            opt_target.val, block.val, name.span, overloaded=True)
        vbases, vcmds = commondl.extract_bases(opt_bases.val, cmds)
        self.val = qlast.CreateConcreteUnknownPointer(
            name=name.val,
            bases=vbases,
            declared_overloaded=True,
            is_required=None,
            cardinality=None,
            target=target,
            commands=vcmds,
        )
        self._validate()

    def reduce_CreateOverloadedQualifiedPointer(self, *kids):
        """%reduce
            OVERLOADED PtrQuals PathNodeName OptExtendingSimple
            OptPtrTarget CreateConcreteLinkSDLCommandsBlock
        """
        _, quals, name, opt_bases, opt_target, block = kids
        target, cmds = self._extract_target(
            opt_target.val, block.val, name.span, overloaded=True)
        vbases, vcmds = commondl.extract_bases(opt_bases.val, cmds)
        self.val = qlast.CreateConcreteUnknownPointer(
            name=name.val,
            bases=vbases,
            declared_overloaded=True,
            is_required=quals.val.required,
            cardinality=quals.val.cardinality,
            target=target,
            commands=vcmds,
        )
        self._validate()


class ConcreteUnknownPointerShort(Nonterm):

    def reduce_CreateRegularPointer(self, *kids):
        """%reduce
            PathNodeName OptExtendingSimple
            PtrTarget
        """
        name, opt_bases, target = kids
        self.val = qlast.CreateConcreteUnknownPointer(
            name=name.val,
            bases=opt_bases.val,
            target=target.val,
        )

    def reduce_CreateRegularQualifiedPointer(self, *kids):
        """%reduce
            PtrQuals PathNodeName OptExtendingSimple
            PtrTarget
        """
        quals, name, opt_bases, target = kids
        self.val = qlast.CreateConcreteUnknownPointer(
            name=name.val,
            bases=opt_bases.val,
            target=target.val,
            is_required=quals.val.required,
            cardinality=quals.val.cardinality,
        )

    def reduce_CreateOverloadedPointer(self, *kids):
        """%reduce
            OVERLOADED PathNodeName OptExtendingSimple
            OptPtrTarget
        """
        _, name, opt_bases, opt_target = kids
        self.val = qlast.CreateConcreteUnknownPointer(
            name=name.val,
            bases=opt_bases.val,
            declared_overloaded=True,
            is_required=None,
            cardinality=None,
            target=opt_target.val,
        )

    def reduce_CreateOverloadedQualifiedPointer(self, *kids):
        """%reduce
            OVERLOADED PtrQuals PathNodeName OptExtendingSimple
            OptPtrTarget
        """
        _, quals, name, opt_bases, opt_target = kids
        self.val = qlast.CreateConcreteUnknownPointer(
            name=name.val,
            bases=opt_bases.val,
            declared_overloaded=True,
            is_required=quals.val.required,
            cardinality=quals.val.cardinality,
            target=opt_target.val,
        )


# Unknown simple computed pointers can only go on objects, since they
# conflict with SetField on links.
class ConcreteUnknownPointerObjectShort(Nonterm):
    def reduce_CreateComputableUnknownPointer(self, *kids):
        """%reduce
            PathNodeName ASSIGN GenExpr
        """
        name, _, expr = kids
        self.val = qlast.CreateConcreteUnknownPointer(
            name=name.val,
            target=expr.val,
        )

    def reduce_CreateQualifiedComputableUnknownPointer(self, *kids):
        """%reduce
            PtrQuals PathNodeName ASSIGN GenExpr
        """
        quals, name, _, expr = kids
        self.val = qlast.CreateConcreteUnknownPointer(
            name=name.val,
            is_required=quals.val.required,
            cardinality=quals.val.cardinality,
            target=expr.val,
        )


#
# Properties
#
sdl_commands_block(
    'CreateProperty',
    Using,
    SetField,
    SetAnnotation,
    commondl.CreateSimpleExtending,
)


class PropertyDeclaration(Nonterm):
    def reduce_CreateProperty(self, *kids):
        r"""%reduce ABSTRACT PROPERTY PtrNodeName OptExtendingSimple \
                    CreatePropertySDLCommandsBlock \
        """
        _, _, name, extending, commands_block = kids

        vbases, vcommands = commondl.extract_bases(
            extending.val,
            commands_block.val
        )
        self.val = qlast.CreateProperty(
            name=name.val,
            bases=vbases,
            commands=vcommands,
            abstract=True,
        )


class PropertyDeclarationShort(Nonterm):
    def reduce_CreateProperty(self, *kids):
        r"""%reduce ABSTRACT PROPERTY PtrNodeName OptExtendingSimple"""
        _, _, name, extending = kids
        self.val = qlast.CreateProperty(
            name=name.val,
            bases=extending.val,
            abstract=True,
        )


sdl_commands_block(
    'CreateConcreteProperty',
    Using,
    SetField,
    SetAnnotation,
    ConcreteConstraintBlock,
    ConcreteConstraintShort,
    RewriteDeclarationBlock,
    RewriteDeclarationShort,
    commondl.CreateSimpleExtending,
)


class ConcretePropertyBlock(Nonterm):
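    def _extract_target(self, target, cmds, span, *, overloaded=False):
        # An explicit target takes precedence; otherwise fall back to the
        # 'expr' field set inside the commands block, if there is one.
        if target:
            return target, cmds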

        for cmd in cmds:
            if isinstance(cmd, qlast.SetField) and cmd.name == 'expr':
                if target is not None:
                    raise errors.EdgeQLSyntaxError(
                        'computed property with more than one expression',
                        span=span)
                target = cmd.value

        if not overloaded and target is None:
            raise errors.EdgeQLSyntaxError(
                'computed property without expression',
                span=span)

        return target, cmds

    def reduce_CreateRegularProperty(self, *kids):
        """%reduce
            PROPERTY PathNodeName OptExtendingSimple
            OptPtrTarget CreateConcretePropertySDLCommandsBlock
        """
        _, name, extending, target, commands_block = kids

        target, cmds = self._extract_target(
            target.val, commands_block.val, name.span
        )
        vbases, vcmds = commondl.extract_bases(extending.val, cmds)
        self.val = qlast.CreateConcreteProperty(
            name=name.val,
            bases=vbases,
            target=target,
            commands=vcmds,
        )

    def reduce_CreateRegularQualifiedProperty(self, *kids):
        """%reduce
            PtrQuals PROPERTY PathNodeName OptExtendingSimple
            OptPtrTarget CreateConcretePropertySDLCommandsBlock
        """
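        # `prop_kw` is the PROPERTY keyword token; named so that it does
        # not shadow the `property` builtin.
        (quals, prop_kw, name, extending, target, commands) = kids

        target, cmds = self._extract_target(
            target.val, commands.val, prop_kw.span
        )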
        vbases, vcmds = commondl.extract_bases(extending.val, cmds)
        self.val = qlast.CreateConcreteProperty(
            name=name.val,
            bases=vbases,
            is_required=quals.val.required,
            cardinality=quals.val.cardinality,
            target=target,
            commands=vcmds,
        )

    def reduce_CreateOverloadedProperty(self, *kids):
        """%reduce
            OVERLOADED PROPERTY PathNodeName OptExtendingSimple
            OptPtrTarget CreateConcretePropertySDLCommandsBlock
        """
        _, _, name, opt_bases, opt_target, block = kids
        target, cmds = self._extract_target(
            opt_target.val, block.val, name.span, overloaded=True)
        vbases, vcmds = commondl.extract_bases(opt_bases.val, cmds)
        self.val = qlast.CreateConcreteProperty(
            name=name.val,
            bases=vbases,
            declared_overloaded=True,
            is_required=None,
            cardinality=None,
            target=target,
            commands=vcmds,
        )

    def reduce_CreateOverloadedQualifiedProperty(self, *kids):
        """%reduce
            OVERLOADED PtrQuals PROPERTY PathNodeName OptExtendingSimple
            OptPtrTarget CreateConcretePropertySDLCommandsBlock
        """
        _, quals, _, name, opt_bases, opt_target, block = kids
        target, cmds = self._extract_target(
            opt_target.val, block.val, name.span, overloaded=True)
        vbases, vcmds = commondl.extract_bases(opt_bases.val, cmds)
        self.val = qlast.CreateConcreteProperty(
            name=name.val,
            bases=vbases,
            declared_overloaded=True,
            is_required=quals.val.required,
            cardinality=quals.val.cardinality,
            target=target,
            commands=vcmds,
        )


class ConcretePropertyShort(Nonterm):
    def reduce_CreateRegularProperty(self, *kids):
        """%reduce
            PROPERTY PathNodeName OptExtendingSimple PtrTarget
        """
        _, name, extending, target = kids
        self.val = qlast.CreateConcreteProperty(
            name=name.val,
            bases=extending.val,
            target=target.val,
        )

    def reduce_CreateRegularQualifiedProperty(self, *kids):
        """%reduce
            PtrQuals PROPERTY PathNodeName OptExtendingSimple PtrTarget
        """
        quals, _, name, extending, target = kids
        self.val = qlast.CreateConcreteProperty(
            name=name.val,
            bases=extending.val,
            is_required=quals.val.required,
            cardinality=quals.val.cardinality,
            target=target.val,
        )

    def reduce_CreateOverloadedProperty(self, *kids):
        """%reduce
            OVERLOADED PROPERTY PathNodeName OptExtendingSimple
            OptPtrTarget
        """
        _, _, name, opt_bases, opt_target = kids
        self.val = qlast.CreateConcreteProperty(
            name=name.val,
            bases=opt_bases.val,
            declared_overloaded=True,
            is_required=None,
            cardinality=None,
            target=opt_target.val,
        )

    def reduce_CreateOverloadedQualifiedProperty(self, *kids):
        """%reduce
            OVERLOADED PtrQuals PROPERTY PathNodeName OptExtendingSimple
            OptPtrTarget
        """
        _, quals, _, name, opt_bases, opt_target = kids
        self.val = qlast.CreateConcreteProperty(
            name=name.val,
            bases=opt_bases.val,
            declared_overloaded=True,
            is_required=quals.val.required,
            cardinality=quals.val.cardinality,
            target=opt_target.val,
        )

    def reduce_CreateComputableProperty(self, *kids):
        """%reduce
            PROPERTY PathNodeName ASSIGN GenExpr
        """
        _, name, _, expr = kids
        self.val = qlast.CreateConcreteProperty(
            name=name.val,
            target=expr.val,
        )

    def reduce_CreateQualifiedComputableProperty(self, *kids):
        """%reduce
            PtrQuals PROPERTY PathNodeName ASSIGN GenExpr
        """
        quals, _, name, _, expr = kids
        self.val = qlast.CreateConcreteProperty(
            name=name.val,
            is_required=quals.val.required,
            cardinality=quals.val.cardinality,
            target=expr.val,
        )


#
# Links
#

sdl_commands_block(
    'CreateLink',
    SetField,
    SetAnnotation,
    ConcreteConstraintBlock,
    ConcreteConstraintShort,
    ConcretePropertyBlock,
    ConcretePropertyShort,
    ConcreteUnknownPointerBlock,
    ConcreteUnknownPointerShort,
    ConcreteIndexDeclarationBlock,
    ConcreteIndexDeclarationShort,
    RewriteDeclarationShort,
    RewriteDeclarationBlock,
    commondl.CreateSimpleExtending,
)


class LinkDeclaration(Nonterm):
    def reduce_CreateLink(self, *kids):
        r"""%reduce \
            ABSTRACT LINK PtrNodeName OptExtendingSimple \
            CreateLinkSDLCommandsBlock \
        """
        _, _, name, extending, commands = kids
        vbases, vcommands = commondl.extract_bases(extending.val, commands.val)
        self.val = qlast.CreateLink(
            name=name.val,
            bases=vbases,
            commands=vcommands,
            abstract=True,
        )


class LinkDeclarationShort(Nonterm):
    def reduce_CreateLink(self, *kids):
        r"""%reduce \
            ABSTRACT LINK PtrNodeName OptExtendingSimple"""
        _, _, name, extending = kids
        self.val = qlast.CreateLink(
            name=name.val,
            bases=extending.val,
            abstract=True,
        )


sdl_commands_block(
    'CreateConcreteLink',
    Using,
    SetField,
    SetAnnotation,
    ConcreteConstraintBlock,
    ConcreteConstraintShort,
    ConcretePropertyBlock,
    ConcretePropertyShort,
    ConcreteUnknownPointerBlock,
    ConcreteUnknownPointerShort,
    ConcreteIndexDeclarationBlock,
    ConcreteIndexDeclarationShort,
    commondl.OnTargetDeleteStmt,
    commondl.OnSourceDeleteStmt,
    RewriteDeclarationShort,
    RewriteDeclarationBlock,
    commondl.CreateSimpleExtending,
)


class ConcreteLinkBlock(Nonterm):
    def _validate(self):
        # A link may carry at most one 'on target delete' clause.
        on_target_delete = None
        for cmd in self.val.commands:
            if isinstance(cmd, qlast.OnTargetDelete):
                if on_target_delete:
                    raise errors.EdgeQLSyntaxError(
                        "more than one 'on target delete' specification",
                        span=cmd.span)
                else:
                    on_target_delete = cmd

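    def _extract_target(self, target, cmds, span, *, overloaded=False):
        # An explicit target takes precedence; otherwise fall back to the
        # 'expr' field set inside the commands block, if there is one.
        if target:
            return target, cmds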

        for cmd in cmds:
            if isinstance(cmd, qlast.SetField) and cmd.name == 'expr':
                if target is not None:
                    raise errors.EdgeQLSyntaxError(
                        'computed link with more than one expression',
                        span=span)
                target = cmd.value

        if not overloaded and target is None:
            raise errors.EdgeQLSyntaxError(
                'computed link without expression',
                span=span)

        return target, cmds

    def reduce_CreateRegularLink(self, *kids):
        """%reduce
            LINK PathNodeName OptExtendingSimple
            OptPtrTarget CreateConcreteLinkSDLCommandsBlock
        """
        _, name, extending, target, commands = kids
        target, cmds = self._extract_target(
            target.val, commands.val, name.span
        )
        vbases, vcmds = commondl.extract_bases(extending.val, cmds)
        self.val = qlast.CreateConcreteLink(
            name=name.val,
            bases=vbases,
            target=target,
            commands=vcmds,
        )
        self._validate()

    def reduce_CreateRegularQualifiedLink(self, *kids):
        """%reduce
            PtrQuals LINK PathNodeName OptExtendingSimple
            OptPtrTarget CreateConcreteLinkSDLCommandsBlock
        """
        quals, _, name, extending, target, commands = kids
        target, cmds = self._extract_target(
            target.val, commands.val, name.span
        )
        vbases, vcmds = commondl.extract_bases(extending.val, cmds)
        self.val = qlast.CreateConcreteLink(
            is_required=quals.val.required,
            cardinality=quals.val.cardinality,
            name=name.val,
            bases=vbases,
            target=target,
            commands=vcmds,
        )
        self._validate()

    def reduce_CreateOverloadedLink(self, *kids):
        """%reduce
            OVERLOADED LINK PathNodeName OptExtendingSimple
            OptPtrTarget CreateConcreteLinkSDLCommandsBlock
        """
        _, _, name, opt_bases, opt_target, block = kids
        target, cmds = self._extract_target(
            opt_target.val, block.val, name.span, overloaded=True)
        vbases, vcmds = commondl.extract_bases(opt_bases.val, cmds)
        self.val = qlast.CreateConcreteLink(
            name=name.val,
            bases=vbases,
            declared_overloaded=True,
            is_required=None,
            cardinality=None,
            target=target,
            commands=vcmds,
        )
        self._validate()

    def reduce_CreateOverloadedQualifiedLink(self, *kids):
        """%reduce
            OVERLOADED PtrQuals LINK PathNodeName OptExtendingSimple
            OptPtrTarget CreateConcreteLinkSDLCommandsBlock
        """
        _, quals, _, name, opt_bases, opt_target, block = kids
        target, cmds = self._extract_target(
            opt_target.val, block.val, name.span, overloaded=True)
        vbases, vcmds = commondl.extract_bases(opt_bases.val, cmds)
        self.val = qlast.CreateConcreteLink(
            name=name.val,
            bases=vbases,
            declared_overloaded=True,
            is_required=quals.val.required,
            cardinality=quals.val.cardinality,
            target=target,
            commands=vcmds,
        )
        self._validate()


class ConcreteLinkShort(Nonterm):

    def reduce_CreateRegularLink(self, *kids):
        """%reduce
            LINK PathNodeName OptExtendingSimple
            PtrTarget
        """
        _, name, opt_bases, target = kids
        self.val = qlast.CreateConcreteLink(
            name=name.val,
            bases=opt_bases.val,
            target=target.val,
        )

    def reduce_CreateRegularQualifiedLink(self, *kids):
        """%reduce
            PtrQuals LINK PathNodeName OptExtendingSimple
            PtrTarget
        """
        quals, _, name, opt_bases, target = kids
        self.val = qlast.CreateConcreteLink(
            name=name.val,
            bases=opt_bases.val,
            target=target.val,
            is_required=quals.val.required,
            cardinality=quals.val.cardinality,
        )

    def reduce_CreateOverloadedLink(self, *kids):
        """%reduce
            OVERLOADED LINK PathNodeName OptExtendingSimple
            OptPtrTarget
        """
        _, _, name, opt_bases, opt_target = kids
        self.val = qlast.CreateConcreteLink(
            name=name.val,
            bases=opt_bases.val,
            declared_overloaded=True,
            is_required=None,
            cardinality=None,
            target=opt_target.val,
        )

    def reduce_CreateOverloadedQualifiedLink(self, *kids):
        """%reduce
            OVERLOADED PtrQuals LINK PathNodeName OptExtendingSimple
            OptPtrTarget
        """
        _, quals, _, name, opt_bases, opt_target = kids
        self.val = qlast.CreateConcreteLink(
            name=name.val,
            bases=opt_bases.val,
            declared_overloaded=True,
            is_required=quals.val.required,
            cardinality=quals.val.cardinality,
            target=opt_target.val,
        )

    def reduce_CreateComputableLink(self, *kids):
        """%reduce
            LINK PathNodeName ASSIGN GenExpr
        """
        _, name, _, expr = kids
        self.val = qlast.CreateConcreteLink(
            name=name.val,
            target=expr.val,
        )

    def reduce_CreateQualifiedComputableLink(self, *kids):
        """%reduce
            PtrQuals LINK PathNodeName ASSIGN GenExpr
        """
        quals, _, name, _, expr = kids
        self.val = qlast.CreateConcreteLink(
            is_required=quals.val.required,
            cardinality=quals.val.cardinality,
            name=name.val,
            target=expr.val,
        )


#
# Access Policies
#
sdl_commands_block(
    'CreateAccessPolicy',
    SetField,
    SetAnnotation
)


class AccessPolicyDeclarationBlock(Nonterm):
    def reduce_CreateAccessPolicy(self, *kids):
        """%reduce
            ACCESS POLICY ShortNodeName
            OptWhenBlock AccessPolicyAction AccessKindList
            OptUsingBlock
            CreateAccessPolicySDLCommandsBlock
        """
        _, _, name, when, action, access_kinds, using, commands = kids
        self.val = qlast.CreateAccessPolicy(
            name=name.val,
            condition=when.val,
            action=action.val,
            # access_kinds.val is a nested sequence; flatten it one level.
            access_kinds=[y for x in access_kinds.val for y in x],
            expr=using.val,
            commands=commands.val,
        )


class AccessPolicyDeclarationShort(Nonterm):
    def reduce_CreateAccessPolicy(self, *kids):
        """%reduce
            ACCESS POLICY ShortNodeName
            OptWhenBlock AccessPolicyAction AccessKindList
            OptUsingBlock
        """
        _, _, name, when, action, access_kinds, using = kids
        self.val = qlast.CreateAccessPolicy(
            name=name.val,
            condition=when.val,
            action=action.val,
            # access_kinds.val is a nested sequence; flatten it one level.
            access_kinds=[y for x in access_kinds.val for y in x],
            expr=using.val,
        )


#
# Triggers
#
sdl_commands_block(
    'CreateTrigger',
    SetField,
    SetAnnotation
)


class TriggerDeclarationBlock(Nonterm):
    def reduce_CreateTrigger(self, *kids):
        """%reduce
            TRIGGER NodeName
            TriggerTiming TriggerKindList
            FOR TriggerScope
            OptWhenBlock
            DO ParenExpr
            CreateTriggerSDLCommandsBlock
        """
        _, name, timing, kinds, _, scope, when, _, expr, commands = kids
        self.val = qlast.CreateTrigger(
            name=name.val,
            timing=timing.val,
            kinds=kinds.val,
            scope=scope.val,
            expr=expr.val,
            condition=when.val,
            commands=commands.val,
        )


class TriggerDeclarationShort(Nonterm):
    def reduce_CreateTrigger(self, *kids):
        """%reduce
            TRIGGER NodeName
            TriggerTiming TriggerKindList
            FOR TriggerScope
            OptWhenBlock
            DO ParenExpr
        """
        _, name, timing, kinds, _, scope, when, _, expr = kids
        self.val = qlast.CreateTrigger(
            name=name.val,
            timing=timing.val,
            kinds=kinds.val,
            scope=scope.val,
            expr=expr.val,
            condition=when.val,
        )


#
# Object Types
#

sdl_commands_block(
    'CreateObjectType',
    SetAnnotation,
    ConcretePropertyBlock,
    ConcretePropertyShort,
    ConcreteLinkBlock,
    ConcreteLinkShort,
    ConcreteUnknownPointerBlock,
    ConcreteUnknownPointerShort,
    ConcreteUnknownPointerObjectShort,
    ConcreteConstraintBlock,
    ConcreteConstraintShort,
    ConcreteIndexDeclarationBlock,
    ConcreteIndexDeclarationShort,
    AccessPolicyDeclarationBlock,
    AccessPolicyDeclarationShort,
    TriggerDeclarationBlock,
    TriggerDeclarationShort,
)


class ObjectTypeDeclaration(Nonterm):
    def reduce_CreateAbstractObjectTypeStmt(self, *kids):
        r"""%reduce \
            ABSTRACT TYPE NodeName OptExtendingSimple \
            CreateObjectTypeSDLCommandsBlock \
        """
        _, _, name, extending, commands = kids
        self.val = qlast.CreateObjectType(
            abstract=True,
            name=name.val,
            bases=extending.val,
            commands=commands.val,
        )

    def reduce_CreateRegularObjectTypeStmt(self, *kids):
        r"""%reduce \
            TYPE NodeName OptExtendingSimple \
            CreateObjectTypeSDLCommandsBlock \
        """
        _, name, extending, commands = kids
        self.val = qlast.CreateObjectType(
            name=name.val,
            bases=extending.val,
            commands=commands.val,
        )


class ObjectTypeDeclarationShort(Nonterm):
    def reduce_CreateAbstractObjectTypeStmt(self, *kids):
        r"""%reduce \
            ABSTRACT TYPE NodeName OptExtendingSimple"""
        _, _, name, extending = kids
        self.val = qlast.CreateObjectType(
            abstract=True,
            name=name.val,
            bases=extending.val,
        )

    def reduce_CreateRegularObjectTypeStmt(self, *kids):
        r"""%reduce \
            TYPE NodeName OptExtendingSimple"""
        _, name, extending = kids
        self.val = qlast.CreateObjectType(
            name=name.val,
            bases=extending.val,
        )


#
# Aliases
#

sdl_commands_block(
    'CreateAlias',
    Using,
    SetField,
    SetAnnotation,
    opt=False
)


class AliasDeclaration(Nonterm):
    def reduce_CreateAliasRegularStmt(self, *kids):
        r"""%reduce
            ALIAS NodeName CreateAliasSDLCommandsBlock
        """
        _, name, commands = kids
        self.val = qlast.CreateAlias(
            name=name.val,
            commands=commands.val,
        )


class AliasDeclarationShort(Nonterm):
    def reduce_CreateAliasShortStmt(self, *kids):
        r"""%reduce
            ALIAS NodeName ASSIGN GenExpr
        """
        _, name, _, expr = kids
        self.val = qlast.CreateAlias(
            name=name.val,
            commands=[
                qlast.SetField(
                    name='expr',
                    value=expr.val,
                    special_syntax=True,
                    span=self.span,
                )
            ]
        )

    def reduce_CreateAliasRegularStmt(self, *kids):
        r"""%reduce
            ALIAS NodeName CreateAliasSingleSDLCommandBlock
        """
        _, name, commands = kids
        self.val = qlast.CreateAlias(
            name=name.val,
            commands=commands.val,
        )


#
# Functions
#


sdl_commands_block(
    'CreateFunction',
    commondl.FromFunction,
    SetField,
    SetAnnotation,
    opt=False
)


class FunctionDeclaration(Nonterm, commondl.ProcessFunctionBlockMixin):
    def reduce_CreateFunction(self, *kids):
        r"""%reduce FUNCTION NodeName CreateFunctionArgs \
                FunctionResult CreateFunctionSDLCommandsBlock
        """
        _, name, args, result, body = kids
        self.val = qlast.CreateFunction(
            name=name.val,
            params=args.val,
            returning=result.val.result_type,
            returning_typemod=result.val.type_qualifier,
            **self._process_function_body(body),
        )


class FunctionDeclarationShort(Nonterm, commondl.ProcessFunctionBlockMixin):
    def reduce_CreateFunction(self, *kids):
        r"""%reduce FUNCTION NodeName CreateFunctionArgs \
                FunctionResult CreateFunctionSingleSDLCommandBlock
        """
        _, name, args, result, body = kids
        self.val = qlast.CreateFunction(
            name=name.val,
            params=args.val,
            returning=result.val.result_type,
            returning_typemod=result.val.type_qualifier,
            **self._process_function_body(body),
        )


#
# Globals
#

sdl_commands_block(
    'CreateGlobal',
    Using,
    SetField,
    SetAnnotation,
)


class GlobalDeclaration(Nonterm):
    def _extract_target(self, target, cmds, span, *, overloaded=False):
        if target:
            return target, cmds

        for cmd in cmds:
            if isinstance(cmd, qlast.SetField) and cmd.name == 'expr':
                if target is not None:
                    raise errors.EdgeQLSyntaxError(
                        'computed global with more than one expression',
                        span=span)
                target = cmd.value

        if not overloaded and target is None:
            raise errors.EdgeQLSyntaxError(
                'computed global without expression',
                span=span)

        return target, cmds

    def reduce_CreateGlobalQuals(self, *kids):
        """%reduce
            PtrQuals GLOBAL NodeName
            OptPtrTarget CreateGlobalSDLCommandsBlock
        """
        quals, glob, name, target, commands = kids
        target, cmds = self._extract_target(
            target.val, commands.val, glob.span
        )
        self.val = qlast.CreateGlobal(
            name=name.val,
            is_required=quals.val.required,
            cardinality=quals.val.cardinality,
            target=target,
            commands=cmds,
        )

    def reduce_CreateGlobal(self, *kids):
        """%reduce
            GLOBAL NodeName
            OptPtrTarget CreateGlobalSDLCommandsBlock
        """
        glob, name, target, commands = kids
        target, cmds = self._extract_target(
            target.val, commands.val, glob.span
        )
        self.val = qlast.CreateGlobal(
            name=name.val,
            target=target,
            commands=cmds,
        )
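

# Illustrative sketch, not used by the grammar: the target-resolution rule of
# GlobalDeclaration._extract_target above, restated on plain data so it can be
# tried in isolation. _SketchSetField and _sketch_extract_target are
# hypothetical names; ValueError stands in for errors.EdgeQLSyntaxError, and
# the sketch assumes overloaded=False.
class _SketchSetField:
    def __init__(self, name, value):
        self.name = name
        self.value = value


def _sketch_extract_target(target, cmds):
    # An explicitly declared target takes precedence over any command.
    if target:
        return target
    for cmd in cmds:
        if isinstance(cmd, _SketchSetField) and cmd.name == 'expr':
            # A second `expr` field is rejected rather than overriding.
            if target is not None:
                raise ValueError('more than one expression')
            target = cmd.value
    if target is None:
        raise ValueError('no target and no expression')
    return target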


class GlobalDeclarationShort(Nonterm):
    def reduce_CreateRegularGlobalShortQuals(self, *kids):
        """%reduce
            PtrQuals GLOBAL NodeName PtrTarget
        """
        quals, _, name, target = kids
        self.val = qlast.CreateGlobal(
            name=name.val,
            is_required=quals.val.required,
            cardinality=quals.val.cardinality,
            target=target.val,
        )

    def reduce_CreateRegularGlobalShort(self, *kids):
        """%reduce
            GLOBAL NodeName PtrTarget
        """
        _, name, target = kids
        self.val = qlast.CreateGlobal(
            name=name.val,
            target=target.val,
        )

    def reduce_CreateComputedGlobalShortQuals(self, *kids):
        """%reduce
            PtrQuals GLOBAL NodeName ASSIGN GenExpr
        """
        quals, _, name, _, expr = kids
        self.val = qlast.CreateGlobal(
            name=name.val,
            is_required=quals.val.required,
            cardinality=quals.val.cardinality,
            target=expr.val,
        )

    def reduce_CreateComputedGlobalShort(self, *kids):
        """%reduce
            GLOBAL NodeName ASSIGN GenExpr
        """
        _, name, _, expr = kids
        self.val = qlast.CreateGlobal(
            name=name.val,
            target=expr.val,
        )


#
# Permissions
#


sdl_commands_block(
    'CreatePermission',
    SetAnnotation,
)


class PermissionDeclaration(Nonterm):
    def reduce_CreatePermission(self, *kids):
        """%reduce
            PERMISSION NodeName
            CreatePermissionSDLCommandsBlock
        """
        _, name, commands = kids
        self.val = qlast.CreatePermission(
            name=name.val,
            commands=commands.val,
        )


class PermissionDeclarationShort(Nonterm):
    def reduce_CreatePermission(self, *kids):
        """%reduce
            PERMISSION NodeName
        """
        _, name = kids
        self.val = qlast.CreatePermission(
            name=name.val,
        )


================================================
FILE: edb/edgeql/parser/grammar/session.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

from edb.edgeql import ast as qlast

from .expressions import Nonterm
from .tokens import *  # NOQA
from .expressions import *  # NOQA


class SetStmt(Nonterm):
    val: qlast.SessionSetAliasDecl

    def reduce_SET_ALIAS_Identifier_AS_MODULE_ModuleName(self, *kids):
        _, _, alias, _, _, module = kids
        self.val = qlast.SessionSetAliasDecl(
            decl=qlast.ModuleAliasDecl(
                module='::'.join(module.val), alias=alias.val, span=self.span
            )
        )

    def reduce_SET_MODULE_ModuleName(self, *kids):
        _, _, module = kids
        self.val = qlast.SessionSetAliasDecl(
            decl=qlast.ModuleAliasDecl(
                module='::'.join(module.val), span=self.span
            )
        )


class ResetStmt(Nonterm):
    val: qlast.SessionResetAliasDecl

    def reduce_RESET_ALIAS_Identifier(self, *kids):
        self.val = qlast.SessionResetAliasDecl(
            alias=kids[2].val)

    def reduce_RESET_MODULE(self, *kids):
        self.val = qlast.SessionResetModule()

    def reduce_RESET_ALIAS_STAR(self, *kids):
        self.val = qlast.SessionResetAllAliases()


================================================
FILE: edb/edgeql/parser/grammar/start.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

from edb.common import parsing
from edb.edgeql import ast as qlast

from . import commondl
from .expressions import Nonterm
from .precedence import *  # NOQA
from .tokens import *  # NOQA
from .statements import *  # NOQA
from .ddl import *  # NOQA
from .session import *  # NOQA
from .config import *  # NOQA


# The main EdgeQL grammar. Every production should start with a
# GrammarToken, which determines the "subgrammar" to use.
#
# To add a new "subgrammar":
# - add a new GrammarToken in tokens.py,
# - add a new production here,
# - add a new token kind in tokenizer.rs,
# - add a mapping from the Python token name into the Rust token kind
#   in parser.rs `fn get_token_kind`
class EdgeQLGrammar(Nonterm):
    "%start"

    val: qlast.GrammarEntryPoint

    def reduce_STARTBLOCK_EdgeQLBlock_EOI(self, *kids):
        self.val = kids[1].val

    def reduce_STARTEXTENSION_CreateExtensionPackageCommandsBlock_EOI(self, *k):
        self.val = k[1].val

    def reduce_STARTMIGRATION_CreateMigrationCommandsBlock_EOI(self, *kids):
        self.val = kids[1].val

    def reduce_STARTFRAGMENT_ExprStmt_EOI(self, *kids):
        self.val = kids[1].val

    def reduce_STARTFRAGMENT_Expr_EOI(self, *kids):
        self.val = kids[1].val

    def reduce_STARTSDLDOCUMENT_SDLDocument_EOI(self, *kids):
        self.val = kids[1].val
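

# Illustrative sketch, not used by the grammar: how a leading grammar token
# can select a "subgrammar", mirroring the EdgeQLGrammar productions above.
# _SKETCH_ENTRY_POINTS and _sketch_entry_point are hypothetical names. The
# real selection is made by the generated parser, not by a dict lookup; the
# table only restates which nonterminals each start token admits.
_SKETCH_ENTRY_POINTS = {
    'STARTBLOCK': ('EdgeQLBlock',),
    'STARTEXTENSION': ('CreateExtensionPackageCommandsBlock',),
    'STARTMIGRATION': ('CreateMigrationCommandsBlock',),
    'STARTFRAGMENT': ('ExprStmt', 'Expr'),
    'STARTSDLDOCUMENT': ('SDLDocument',),
}


def _sketch_entry_point(tokens):
    # The first token names the subgrammar; the rest is that grammar's input.
    start, *rest = tokens
    if start not in _SKETCH_ENTRY_POINTS:
        raise ValueError(f'unknown grammar token: {start!r}')
    return _SKETCH_ENTRY_POINTS[start], rest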


class EdgeQLBlock(Nonterm):
    val: qlast.Commands

    def reduce_StmtList_OptSemicolons(self, s, _semicolon):
        self.val = qlast.Commands(commands=s.val)

    def reduce_OptSemicolons(self, _semicolon):
        self.val = qlast.Commands(commands=[])


class SingleStmt(Nonterm):
    val: qlast.Command

    @parsing.inline(0)
    def reduce_Stmt(self, stmt):
        pass

    def reduce_IfThenElseExpr(self, *kids):
        # TODO: this should not be here, but in ExprStmtSimpleCore instead
        self.val = qlast.SelectQuery(result=kids[0].val, implicit=True)

    @parsing.inline(0)
    def reduce_DDLStmt(self, _):
        # Data definition commands
        pass

    @parsing.inline(0)
    def reduce_SetStmt(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_ResetStmt(self, *kids):
        pass

    @parsing.inline(0)
    def reduce_ConfigStmt(self, _):
        # Configuration commands
        pass


class StmtList(
    parsing.ListNonterm, element=SingleStmt, separator=commondl.Semicolons
):
    val: list[qlast.Command]


class SDLDocument(Nonterm):
    def reduce_OptSemicolons(self, *kids):
        self.val = qlast.Schema(declarations=[])

    def reduce_statement_without_semicolons(self, *kids):
        r"""%reduce \
            OptSemicolons SDLShortStatement
        """
        declarations = [kids[1].val]
        commondl._validate_declarations(declarations)
        self.val = qlast.Schema(declarations=declarations)

    def reduce_statements_without_optional_trailing_semicolons(self, *kids):
        r"""%reduce \
            OptSemicolons SDLStatements \
            OptSemicolons SDLShortStatement
        """
        declarations = kids[1].val + [kids[3].val]
        commondl._validate_declarations(declarations)
        self.val = qlast.Schema(declarations=declarations)

    def reduce_OptSemicolons_SDLStatements(self, *kids):
        declarations = kids[1].val
        commondl._validate_declarations(declarations)
        self.val = qlast.Schema(declarations=declarations)

    def reduce_OptSemicolons_SDLStatements_Semicolons(self, *kids):
        declarations = kids[1].val
        commondl._validate_declarations(declarations)
        self.val = qlast.Schema(declarations=declarations)


================================================
FILE: edb/edgeql/parser/grammar/statements.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

import typing

from edb import errors
from edb.common import parsing

from edb.edgeql import ast as qlast
from edb.edgeql import qltypes

from .expressions import Nonterm, ListNonterm
from .precedence import *  # NOQA
from .tokens import *  # NOQA
from .expressions import *  # NOQA

from . import tokens


class Stmt(Nonterm):
    val: qlast.Command

    @parsing.inline(0)
    def reduce_TransactionStmt(self, stmt):
        pass

    @parsing.inline(0)
    def reduce_DescribeStmt(self, stmt):
        # DESCRIBE
        pass

    @parsing.inline(0)
    def reduce_AnalyzeStmt(self, stmt):
        # ANALYZE
        pass

    @parsing.inline(0)
    def reduce_AdministerStmt(self, stmt):
        pass

    @parsing.inline(0)
    def reduce_ExprStmt(self, stmt):
        pass


class TransactionMode(Nonterm):

    def reduce_ISOLATION_SERIALIZABLE(self, *kids):
        self.val = (qltypes.TransactionIsolationLevel.SERIALIZABLE,
                    kids[0].span)

    def reduce_ISOLATION_REPEATABLE_READ(self, *kids):
        self.val = (qltypes.TransactionIsolationLevel.REPEATABLE_READ,
                    kids[0].span)

    def reduce_READ_WRITE(self, *kids):
        self.val = (qltypes.TransactionAccessMode.READ_WRITE,
                    kids[0].span)

    def reduce_READ_ONLY(self, *kids):
        self.val = (qltypes.TransactionAccessMode.READ_ONLY,
                    kids[0].span)

    def reduce_DEFERRABLE(self, *kids):
        self.val = (qltypes.TransactionDeferMode.DEFERRABLE,
                    kids[0].span)

    def reduce_NOT_DEFERRABLE(self, *kids):
        self.val = (qltypes.TransactionDeferMode.NOT_DEFERRABLE,
                    kids[0].span)


class TransactionModeList(ListNonterm, element=TransactionMode,
                          separator=tokens.T_COMMA):
    pass


class OptTransactionModeList(Nonterm):

    @parsing.inline(0)
    def reduce_TransactionModeList(self, *kids):
        pass

    def reduce_empty(self, *kids):
        self.val = []


class TransactionStmt(Nonterm):

    def reduce_START_TRANSACTION_OptTransactionModeList(self, *kids):
        modes = kids[2].val

        isolation = None
        access = None
        deferrable = None

        for mode, mode_ctx in modes:
            if isinstance(mode, qltypes.TransactionIsolationLevel):
                if isolation is not None:
                    raise errors.EdgeQLSyntaxError(
                        "only one isolation level can be specified",
                        span=mode_ctx)
                isolation = mode

            elif isinstance(mode, qltypes.TransactionAccessMode):
                if access is not None:
                    raise errors.EdgeQLSyntaxError(
                        "only one access mode can be specified",
                        span=mode_ctx)
                access = mode

            else:
                assert isinstance(mode, qltypes.TransactionDeferMode)
                if deferrable is not None:
                    raise errors.EdgeQLSyntaxError(
                        "deferrable mode can only be specified once",
                        span=mode_ctx)
                deferrable = mode

        self.val = qlast.StartTransaction(
            isolation=isolation, access=access, deferrable=deferrable)

    def reduce_COMMIT(self, *kids):
        self.val = qlast.CommitTransaction()

    def reduce_ROLLBACK(self, *kids):
        self.val = qlast.RollbackTransaction()

    def reduce_DECLARE_SAVEPOINT_Identifier(self, *kids):
        self.val = qlast.DeclareSavepoint(name=kids[2].val)

    def reduce_ROLLBACK_TO_SAVEPOINT_Identifier(self, *kids):
        self.val = qlast.RollbackToSavepoint(name=kids[3].val)

    def reduce_RELEASE_SAVEPOINT_Identifier(self, *kids):
        self.val = qlast.ReleaseSavepoint(name=kids[2].val)
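

# Illustrative sketch, not used by the grammar: the "at most one mode of each
# kind" check from reduce_START_TRANSACTION_OptTransactionModeList above,
# restated on stand-in enums. _SketchIsolation, _SketchAccess and
# _sketch_collect_modes are hypothetical names; ValueError stands in for
# errors.EdgeQLSyntaxError. Unlike the reduction, which tracks three named
# variables, the sketch keys a dict by enum class.
import enum as _sketch_enum


class _SketchIsolation(_sketch_enum.Enum):
    SERIALIZABLE = 'SERIALIZABLE'
    REPEATABLE_READ = 'REPEATABLE READ'


class _SketchAccess(_sketch_enum.Enum):
    READ_WRITE = 'READ WRITE'
    READ_ONLY = 'READ ONLY'


def _sketch_collect_modes(modes):
    # Remember the first mode seen for each kind; reject a second one.
    seen = {}
    for mode in modes:
        kind = type(mode)
        if kind in seen:
            raise ValueError(f'{kind.__name__} specified more than once')
        seen[kind] = mode
    return seen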


class DescribeFmt(typing.NamedTuple):

    language: typing.Optional[qltypes.DescribeLanguage] = None
    options: typing.Optional[qlast.Options] = None


class DescribeFormat(Nonterm):
    val: DescribeFmt

    def reduce_empty(self, *kids):
        self.val = DescribeFmt(
            language=qltypes.DescribeLanguage.DDL,
            options=qlast.Options(),
        )

    def reduce_AS_DDL(self, *kids):
        self.val = DescribeFmt(
            language=qltypes.DescribeLanguage.DDL,
            options=qlast.Options(),
        )

    def reduce_AS_SDL(self, *kids):
        self.val = DescribeFmt(
            language=qltypes.DescribeLanguage.SDL,
            options=qlast.Options(),
        )

    def reduce_AS_JSON(self, *kids):
        self.val = DescribeFmt(
            language=qltypes.DescribeLanguage.JSON,
            options=qlast.Options(),
        )

    def reduce_AS_TEXT(self, *kids):
        self.val = DescribeFmt(
            language=qltypes.DescribeLanguage.TEXT,
            options=qlast.Options(),
        )

    def reduce_AS_TEXT_VERBOSE(self, *kids):
        self.val = DescribeFmt(
            language=qltypes.DescribeLanguage.TEXT,
            options=qlast.Options(
                options={'VERBOSE': qlast.OptionFlag(
                    name='VERBOSE', val=True, span=kids[2].span)}
            ),
        )


class DescribeStmt(Nonterm):
    val: qlast.DescribeStmt

    def reduce_DESCRIBE_SCHEMA(self, *kids):
        """%reduce DESCRIBE SCHEMA DescribeFormat"""
        self.val = qlast.DescribeStmt(
            object=qlast.DescribeGlobal.Schema,
            language=kids[2].val.language,
            options=kids[2].val.options,
        )

    def reduce_DESCRIBE_CURRENT_DATABASE_CONFIG(self, *kids):
        """%reduce DESCRIBE CURRENT DATABASE CONFIG DescribeFormat"""
        self.val = qlast.DescribeStmt(
            object=qlast.DescribeGlobal.DatabaseConfig,
            language=kids[4].val.language,
            options=kids[4].val.options,
        )

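    def reduce_DESCRIBE_CURRENT_BRANCH_CONFIG(self, *kids):
        """%reduce DESCRIBE CURRENT BRANCH CONFIG DescribeFormat"""
        # Produces the same DescribeGlobal.DatabaseConfig object as
        # DESCRIBE CURRENT DATABASE CONFIG above.
        self.val = qlast.DescribeStmt(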
            object=qlast.DescribeGlobal.DatabaseConfig,
            language=kids[4].val.language,
            options=kids[4].val.options,
        )

    def reduce_DESCRIBE_INSTANCE_CONFIG(self, *kids):
        """%reduce DESCRIBE INSTANCE CONFIG DescribeFormat"""
        self.val = qlast.DescribeStmt(
            object=qlast.DescribeGlobal.InstanceConfig,
            language=kids[3].val.language,
            options=kids[3].val.options,
        )

    def reduce_DESCRIBE_SYSTEM_CONFIG(self, *kids):
        """%reduce DESCRIBE SYSTEM CONFIG DescribeFormat"""
        # Same token layout as DESCRIBE INSTANCE CONFIG, so delegate to it.
        return self.reduce_DESCRIBE_INSTANCE_CONFIG(*kids)

    def reduce_DESCRIBE_ROLES(self, *kids):
        """%reduce DESCRIBE ROLES DescribeFormat"""
        self.val = qlast.DescribeStmt(
            object=qlast.DescribeGlobal.Roles,
            language=kids[2].val.language,
            options=kids[2].val.options,
        )

    def reduce_DESCRIBE_SchemaItem(self, *kids):
        """%reduce DESCRIBE SchemaItem DescribeFormat"""
        self.val = qlast.DescribeStmt(
            object=kids[1].val,
            language=kids[2].val.language,
            options=kids[2].val.options,
        )

    def reduce_DESCRIBE_OBJECT(self, *kids):
        """%reduce DESCRIBE OBJECT NodeName DescribeFormat"""
        self.val = qlast.DescribeStmt(
            object=kids[2].val,
            language=kids[3].val.language,
            options=kids[3].val.options,
        )

    def reduce_DESCRIBE_CURRENT_MIGRATION(self, *kids):
        """%reduce DESCRIBE CURRENT MIGRATION DescribeFormat"""
        lang = kids[3].val.language
        if (
            lang is not qltypes.DescribeLanguage.DDL
            and lang is not qltypes.DescribeLanguage.JSON
        ):
            raise errors.InvalidSyntaxError(
                f'unexpected DESCRIBE format: {lang!r}',
                span=kids[3].span,
            )
        if kids[3].val.options:
            raise errors.InvalidSyntaxError(
                'DESCRIBE CURRENT MIGRATION does not support options',
                span=kids[3].span,
            )

        self.val = qlast.DescribeCurrentMigration(
            language=lang,
        )


class AnalyzeStmt(Nonterm):
    val: qlast.ExplainStmt

    def reduce_ANALYZE_NamedTuple_ExprStmt(self, *kids):
        _, args, stmt = kids
        self.val = qlast.ExplainStmt(
            args=args.val,
            query=stmt.val,
        )

    def reduce_ANALYZE_ExprStmt(self, *kids):
        _, stmt = kids
        self.val = qlast.ExplainStmt(
            query=stmt.val,
        )


class AdministerStmt(Nonterm):
    val: qlast.AdministerStmt

    def reduce_ADMINISTER_FuncExpr(self, *kids):
        _, expr = kids
        self.val = qlast.AdministerStmt(expr=expr.val)


================================================
FILE: edb/edgeql/parser/grammar/tokens.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

import re
import sys
import types
import typing

from edb.common import parsing

from . import keywords


clean_string = re.compile(r"'(?:\s|\n)+'")
string_quote = re.compile(r'\$(?:[A-Za-z_][A-Za-z_0-9]*)?\$')


class Token(parsing.Token, is_internal=True):
    pass


class GrammarToken(Token, is_internal=True):
    """
    Instead of having different grammars, we prefix each query with a special
    grammar token which directs the parser to the appropriate grammar.

    This greatly reduces the combined size of grammar specifications, since the
    overlap between grammars is substantial.
    """


class T_STARTBLOCK(GrammarToken):
    pass


class T_STARTEXTENSION(GrammarToken):
    pass


class T_STARTFRAGMENT(GrammarToken):
    pass


class T_STARTMIGRATION(GrammarToken):
    pass


class T_STARTSDLDOCUMENT(GrammarToken):
    pass


class T_STRINTERPSTART(GrammarToken):
    pass


class T_STRINTERPCONT(GrammarToken):
    pass


class T_STRINTERPEND(GrammarToken):
    pass


class T_DOT(Token, lextoken='.'):
    pass


class T_DOTBW(Token, lextoken='.<'):
    pass


class T_DOTQ(Token, lextoken='.?>'):
    pass


class T_LBRACKET(Token, lextoken='['):
    pass


class T_RBRACKET(Token, lextoken=']'):
    pass


class T_LPAREN(Token, lextoken='('):
    pass


class T_RPAREN(Token, lextoken=')'):
    pass


class T_LBRACE(Token, lextoken='{'):
    pass


class T_RBRACE(Token, lextoken='}'):
    pass


class T_DOUBLECOLON(Token, lextoken='::'):
    pass


class T_DOUBLESTAR(Token, lextoken='**'):
    pass


class T_DOUBLEQMARK(Token, lextoken='??'):
    pass


class T_COLON(Token, lextoken=':'):
    pass


class T_SEMICOLON(Token, lextoken=';'):
    pass


class T_COMMA(Token, lextoken=','):
    pass


class T_PLUS(Token, lextoken='+'):
    pass


class T_DOUBLEPLUS(Token, lextoken='++'):
    pass


class T_MINUS(Token, lextoken='-'):
    pass


class T_STAR(Token, lextoken='*'):
    pass


class T_SLASH(Token, lextoken='/'):
    pass


class T_DOUBLESLASH(Token, lextoken='//'):
    pass


class T_PERCENT(Token, lextoken='%'):
    pass


class T_CIRCUMFLEX(Token, lextoken='^'):
    pass


class T_AT(Token, lextoken='@'):
    pass


class T_PARAMETER(Token):
    pass


class T_PARAMETERANDTYPE(Token):
    # A special token produced by normalization
    pass


class T_ASSIGN(Token, lextoken=':='):
    pass


class T_ADDASSIGN(Token, lextoken='+='):
    pass


class T_REMASSIGN(Token, lextoken='-='):
    pass


class T_ARROW(Token, lextoken='->'):
    pass


class T_LANGBRACKET(Token, lextoken='<'):
    pass


class T_RANGBRACKET(Token, lextoken='>'):
    pass


class T_EQUALS(Token, lextoken='='):
    pass


class T_AMPER(Token, lextoken='&'):
    pass


class T_PIPE(Token, lextoken='|'):
    pass


class T_NAMEDONLY(Token, lextoken='named only'):
    pass


class T_SETTYPE(Token, lextoken='set type'):
    pass


class T_EXTENSIONPACKAGE(Token, lextoken='extension package'):
    pass


class T_ORDERBY(Token, lextoken='order by'):
    pass


class T_ICONST(Token):
    pass


class T_NICONST(Token):
    pass


class T_FCONST(Token):
    pass


class T_NFCONST(Token):
    pass


class T_BCONST(Token):
    pass


class T_SCONST(Token):
    pass


class T_DISTINCTFROM(Token, lextoken="?!="):
    pass


class T_GREATEREQ(Token, lextoken=">="):
    pass


class T_LESSEQ(Token, lextoken="<="):
    pass


class T_NOTDISTINCTFROM(Token, lextoken="?="):
    pass


class T_NOTEQ(Token, lextoken="!="):
    pass


class T_IDENT(Token):
    pass


class T_EOI(Token):
    pass


# explicitly define tokens which are referenced elsewhere
T_THEN: typing.Optional[Token] = None


def _gen_keyword_tokens():
    # Define keyword tokens

    mod = sys.modules[__name__]

    def clsexec(ns):
        ns['__module__'] = __name__
        return ns

    for token, _ in keywords.edgeql_keywords.values():
        clsname = 'T_{}'.format(token)
        clskwds = dict(token=token)
        cls = types.new_class(clsname, (Token,), clskwds, clsexec)
        setattr(mod, clsname, cls)


_gen_keyword_tokens()
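

# Illustrative sketch (not part of the module): how `types.new_class` can
# generate one token class per keyword, as `_gen_keyword_tokens` does above.
# The `Token` base below is a hypothetical stand-in for `parsing.Token`; it
# only records the `token` class keyword via `__init_subclass__`, which is an
# assumption about how the real base class consumes that keyword. The keyword
# list is made up for the example.

```python
import types


class Token:
    """Stand-in base class that records the `token` class keyword."""

    token = None

    def __init_subclass__(cls, token=None, **kwargs):
        super().__init_subclass__(**kwargs)
        cls.token = token


def gen_keyword_tokens(namespace, keyword_tokens):
    """Create one T_<TOKEN> class per keyword and store it in `namespace`."""

    def clsexec(ns):
        # Populate the class body namespace before the class is created.
        ns['__module__'] = __name__
        return ns

    for token in keyword_tokens:
        clsname = 'T_{}'.format(token)
        # The third argument carries class keywords, equivalent to writing
        # `class T_SELECT(Token, token='SELECT')`.
        cls = types.new_class(clsname, (Token,), dict(token=token), clsexec)
        namespace[clsname] = cls


generated = {}
gen_keyword_tokens(generated, ['SELECT', 'THEN'])
print(sorted(generated))
print(generated['T_THEN'].token)
```

# The real function stores the classes on the module object with `setattr`;
# the sketch writes into a plain dict so that it has no side effects.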


================================================
FILE: edb/edgeql/qltypes.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations
from typing import TYPE_CHECKING

import enum

from edb.common import enum as s_enum


if TYPE_CHECKING:
    from edb.schema import types as s_types


class ParameterKind(s_enum.StrEnum):
    VariadicParam = 'VariadicParam'
    NamedOnlyParam = 'NamedOnlyParam'
    PositionalParam = 'PositionalParam'

    def to_edgeql(self) -> str:
        if self is ParameterKind.VariadicParam:
            return 'VARIADIC'
        elif self is ParameterKind.NamedOnlyParam:
            return 'NAMED ONLY'
        else:
            return ''


class TypeModifier(s_enum.StrEnum):
    SetOfType = 'SetOfType'
    OptionalType = 'OptionalType'
    SingletonType = 'SingletonType'

    def to_edgeql(self) -> str:
        if self is TypeModifier.SetOfType:
            return 'SET OF'
        elif self is TypeModifier.OptionalType:
            return 'OPTIONAL'
        else:
            return ''


class Polymorphism(s_enum.StrEnum):
    NotUsed = 'NotUsed'
    Simple = 'Simple'
    Array = 'Array'
    Collection = 'Collection'

    @staticmethod
    def from_schema_type(type: s_types.Type) -> Polymorphism:
        return (
            Polymorphism.Simple
            if not type.is_collection() else
            Polymorphism.Array
            if type.is_array() else
            Polymorphism.Collection
        )


class OperatorKind(s_enum.StrEnum):
    Infix = 'Infix'
    Postfix = 'Postfix'
    Prefix = 'Prefix'
    Ternary = 'Ternary'


class TransactionIsolationLevel(s_enum.StrEnum):
    REPEATABLE_READ = 'REPEATABLE READ'
    SERIALIZABLE = 'SERIALIZABLE'


class TransactionAccessMode(s_enum.StrEnum):
    READ_WRITE = 'READ WRITE'
    READ_ONLY = 'READ ONLY'


class TransactionDeferMode(s_enum.StrEnum):
    DEFERRABLE = 'DEFERRABLE'
    NOT_DEFERRABLE = 'NOT DEFERRABLE'


class SchemaCardinality(s_enum.OrderedEnumMixin, s_enum.StrEnum):
    '''This enum is used to store cardinality in the schema.'''
    One = 'One'
    Many = 'Many'
    Unknown = 'Unknown'

    def is_multi(self) -> bool:
        if self is SchemaCardinality.One:
            return False
        elif self is SchemaCardinality.Many:
            return True
        else:
            raise ValueError('cardinality is unknown')

    def is_single(self) -> bool:
        return not self.is_multi()

    def is_known(self) -> bool:
        return self is not SchemaCardinality.Unknown

    def as_ptr_qual(self) -> str:
        if self is SchemaCardinality.One:
            return 'single'
        elif self is SchemaCardinality.Many:
            return 'multi'
        else:
            raise ValueError('cardinality is unknown')

    def to_edgeql(self) -> str:
        return self.as_ptr_qual().upper()


class Cardinality(s_enum.StrEnum):
    '''This enum is used in cardinality inference internally.'''
    # [0, 1]
    AT_MOST_ONE = 'AT_MOST_ONE'
    # [1, 1]
    ONE = 'ONE'
    # [0, inf)
    MANY = 'MANY'
    # [1, inf)
    AT_LEAST_ONE = 'AT_LEAST_ONE'
    # Sentinel
    UNKNOWN = 'UNKNOWN'

    def is_single(self) -> bool:
        return self in {Cardinality.AT_MOST_ONE, Cardinality.ONE}

    def is_multi(self) -> bool:
        return not self.is_single()

    def can_be_zero(self) -> bool:
        return self not in {Cardinality.ONE, Cardinality.AT_LEAST_ONE}

    def to_schema_value(self) -> tuple[bool, SchemaCardinality]:
        return _CARD_TO_TUPLE[self]

    @classmethod
    def from_schema_value(
        cls, required: bool, card: SchemaCardinality
    ) -> Cardinality:
        return _TUPLE_TO_CARD[(required, card)]


_CARD_TO_TUPLE = {
    Cardinality.AT_MOST_ONE: (False, SchemaCardinality.One),
    Cardinality.ONE: (True, SchemaCardinality.One),
    Cardinality.MANY: (False, SchemaCardinality.Many),
    Cardinality.AT_LEAST_ONE: (True, SchemaCardinality.Many),
}
_TUPLE_TO_CARD = {
    (False, SchemaCardinality.One): Cardinality.AT_MOST_ONE,
    (True, SchemaCardinality.One): Cardinality.ONE,
    (False, SchemaCardinality.Many): Cardinality.MANY,
    (True, SchemaCardinality.Many): Cardinality.AT_LEAST_ONE,
}
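

# Illustrative sketch (not part of the module): the two tables above pair a
# `required` flag with a schema cardinality and are exact inverses of each
# other. The sketch uses the standard library's `enum.Enum` in place of
# `edb.common.enum.StrEnum` so that it runs outside the repository, and it
# derives the inverse table from the forward one instead of spelling it out.
# `UNKNOWN` has no entry in either table above, so it is omitted here.

```python
import enum


class SchemaCardinality(enum.Enum):
    One = 'One'
    Many = 'Many'


class Cardinality(enum.Enum):
    AT_MOST_ONE = 'AT_MOST_ONE'    # [0, 1]
    ONE = 'ONE'                    # [1, 1]
    MANY = 'MANY'                  # [0, inf)
    AT_LEAST_ONE = 'AT_LEAST_ONE'  # [1, inf)


# (required, schema cardinality) for every inferred cardinality.
CARD_TO_TUPLE = {
    Cardinality.AT_MOST_ONE: (False, SchemaCardinality.One),
    Cardinality.ONE: (True, SchemaCardinality.One),
    Cardinality.MANY: (False, SchemaCardinality.Many),
    Cardinality.AT_LEAST_ONE: (True, SchemaCardinality.Many),
}
# Building the inverse mechanically keeps the two tables from drifting apart.
TUPLE_TO_CARD = {value: card for card, value in CARD_TO_TUPLE.items()}

for card in Cardinality:
    required, schema_card = CARD_TO_TUPLE[card]
    # Every cardinality survives the round trip through its schema value.
    assert TUPLE_TO_CARD[(required, schema_card)] is card
    print(card.name, '->', required, schema_card.name)
```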


class Volatility(s_enum.OrderedEnumMixin, s_enum.StrEnum):
    # Make sure that the values appear from least volatile to most volatile.
    Immutable = 'Immutable'
    Stable = 'Stable'
    Volatile = 'Volatile'
    Modifying = 'Modifying'

    def is_volatile(self) -> bool:
        return self in (Volatility.Volatile, Volatility.Modifying)

    @classmethod
    def _missing_(cls, name):
        # We want both `volatility := 'immutable'` in SDL and
        # `SET volatility := 'IMMUTABLE'` in DDL to work.
        return cls(name.title())
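

# Illustrative sketch (not part of the module): case-insensitive enum lookup
# through the `_missing_` hook, as `Volatility` does above. It uses the
# standard library's `enum.Enum` with a `str` mixin instead of
# `edb.common.enum.StrEnum`. The guard that stops retrying once title-casing
# no longer changes the value is an addition of this sketch, not something
# the original code does.

```python
import enum


class Volatility(str, enum.Enum):
    Immutable = 'Immutable'
    Stable = 'Stable'
    Volatile = 'Volatile'
    Modifying = 'Modifying'

    @classmethod
    def _missing_(cls, value):
        # Called only when the normal by-value lookup fails. Retry with the
        # title-cased spelling, but only if that actually changes the value;
        # returning None lets Enum raise its usual ValueError.
        if isinstance(value, str) and value.title() != value:
            return cls(value.title())
        return None


print(Volatility('immutable') is Volatility.Immutable)  # True
print(Volatility('IMMUTABLE') is Volatility.Immutable)  # True
```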


class Multiplicity(s_enum.OrderedEnumMixin, s_enum.StrEnum):
    # Make sure that the values appear in ascending order.
    EMPTY = 'EMPTY'
    UNIQUE = 'UNIQUE'
    DUPLICATE = 'DUPLICATE'
    UNKNOWN = 'UNKNOWN'

    def is_empty(self) -> bool:
        return self is Multiplicity.EMPTY

    def is_unique(self) -> bool:
        return self is Multiplicity.UNIQUE

    def is_duplicate(self) -> bool:
        return self is Multiplicity.DUPLICATE


class IndexDeferrability(s_enum.OrderedEnumMixin, s_enum.StrEnum):
    Prohibited = 'Prohibited'
    Permitted = 'Permitted'
    Required = 'Required'

    def is_deferrable(self) -> bool:
        return (
            self is IndexDeferrability.Required
            or self is IndexDeferrability.Permitted
        )


class AccessPolicyAction(s_enum.StrEnum):
    Allow = 'Allow'
    Deny = 'Deny'


class AccessKind(s_enum.StrEnum):
    Select = 'Select'
    UpdateRead = 'UpdateRead'
    UpdateWrite = 'UpdateWrite'
    Delete = 'Delete'
    Insert = 'Insert'

    def is_data_check(self) -> bool:
        return self is AccessKind.UpdateWrite or self is AccessKind.Insert


class TriggerTiming(s_enum.StrEnum):
    After = 'After'
    AfterCommitOf = 'After Commit Of'


class TriggerKind(s_enum.StrEnum):
    Update = 'Update'
    Delete = 'Delete'
    Insert = 'Insert'


class TriggerScope(s_enum.StrEnum):
    Each = 'Each'
    All = 'All'


class RewriteKind(s_enum.StrEnum):
    Update = 'Update'
    Insert = 'Insert'


class SplatStrategy(s_enum.StrEnum):
    Default = 'Default'
    Explicit = 'Explicit'
    Implicit = 'Implicit'


class DescribeLanguage(s_enum.StrEnum):
    DDL = 'DDL'
    SDL = 'SDL'
    TEXT = 'TEXT'
    JSON = 'JSON'


class SchemaObjectClass(s_enum.StrEnum):

    ACCESS_POLICY = 'ACCESS_POLICY'
    ALIAS = 'ALIAS'
    ANNOTATION = 'ANNOTATION'
    ARRAY_TYPE = 'ARRAY TYPE'
    BRANCH = 'BRANCH'
    CAST = 'CAST'
    CONSTRAINT = 'CONSTRAINT'
    DATABASE = 'DATABASE'
    EXTENSION = 'EXTENSION'
    EXTENSION_PACKAGE = 'EXTENSION PACKAGE'
    EXTENSION_PACKAGE_MIGRATION = 'EXTENSION PACKAGE MIGRATION'
    FUTURE = 'FUTURE'
    FUNCTION = 'FUNCTION'
    GLOBAL = 'GLOBAL'
    INDEX = 'INDEX'
    INDEX_MATCH = 'INDEX MATCH'
    LINK = 'LINK'
    MIGRATION = 'MIGRATION'
    MODULE = 'MODULE'
    MULTIRANGE_TYPE = 'MULTIRANGE_TYPE'
    OPERATOR = 'OPERATOR'
    PARAMETER = 'PARAMETER'
    PERMISSION = 'PERMISSION'
    PROPERTY = 'PROPERTY'
    PSEUDO_TYPE = 'PSEUDO TYPE'
    RANGE_TYPE = 'RANGE TYPE'
    REWRITE = 'REWRITE'
    ROLE = 'ROLE'
    SCALAR_TYPE = 'SCALAR TYPE'
    TRIGGER = 'TRIGGER'
    TUPLE_TYPE = 'TUPLE TYPE'
    TYPE = 'TYPE'


class LinkTargetDeleteAction(s_enum.StrEnum):
    Restrict = 'Restrict'
    DeleteSource = 'DeleteSource'
    Allow = 'Allow'
    DeferredRestrict = 'DeferredRestrict'

    def to_edgeql(self) -> str:
        if self is LinkTargetDeleteAction.DeleteSource:
            return 'DELETE SOURCE'
        elif self is LinkTargetDeleteAction.DeferredRestrict:
            return 'DEFERRED RESTRICT'
        elif self is LinkTargetDeleteAction.Restrict:
            return 'RESTRICT'
        elif self is LinkTargetDeleteAction.Allow:
            return 'ALLOW'
        else:
            raise ValueError(f'unsupported enum value {self!r}')


class LinkSourceDeleteAction(s_enum.StrEnum):
    DeleteTarget = 'DeleteTarget'
    Allow = 'Allow'
    DeleteTargetIfOrphan = 'DeleteTargetIfOrphan'

    def to_edgeql(self) -> str:
        if self is LinkSourceDeleteAction.DeleteTarget:
            return 'DELETE TARGET'
        elif self is LinkSourceDeleteAction.Allow:
            return 'ALLOW'
        elif self is LinkSourceDeleteAction.DeleteTargetIfOrphan:
            return 'DELETE TARGET IF ORPHAN'
        else:
            raise ValueError(f'unsupported enum value {self!r}')


class ConfigScope(s_enum.StrEnum):

    INSTANCE = 'INSTANCE'
    DATABASE = 'DATABASE'
    SESSION = 'SESSION'
    GLOBAL = 'GLOBAL'

    def to_edgeql(self) -> str:
        if self is ConfigScope.DATABASE:
            return 'CURRENT BRANCH'
        else:
            return str(self)


class TypeTag(enum.IntEnum):
    SCALAR = 0
    TUPLE = 1
    ARRAY = 2


================================================
FILE: edb/edgeql/quote.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2013-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

import re

from .parser.grammar import keywords


_re_ident = re.compile(r'''(?x)
    [^\W\d]\w*  # alphanumeric identifier
''')

_re_ident_or_num = re.compile(r'''(?x)
    [^\W\d]\w*  # alphanumeric identifier
    |
    ([1-9]\d* | 0)  # purely integer identifier
''')


def escape_string(s: str) -> str:
    # characters escaped according to
    # https://www.edgedb.com/docs/reference/edgeql/lexical#strings
    result = s

    # escape backslash first
    result = result.replace('\\', '\\\\')

    result = result.replace('\'', '\\\'')
    result = result.replace('\b', '\\b')
    result = result.replace('\f', '\\f')
    result = result.replace('\n', '\\n')
    result = result.replace('\r', '\\r')
    result = result.replace('\t', '\\t')

    return result


def quote_literal(string: str) -> str:
    return "'" + escape_string(string) + "'"
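

# Illustrative sketch (not part of the module): why `escape_string` above
# escapes the backslash before anything else. The first two functions restate
# the logic in a table-driven form; `escape_wrong_order` is a deliberately
# broken, hypothetical variant that shows what goes wrong otherwise.

```python
def escape_string(s: str) -> str:
    # Backslashes go first, so the backslashes introduced by the later
    # replacements are never escaped a second time.
    result = s.replace('\\', '\\\\')
    for raw, escaped in (
        ("'", "\\'"),
        ('\b', '\\b'),
        ('\f', '\\f'),
        ('\n', '\\n'),
        ('\r', '\\r'),
        ('\t', '\\t'),
    ):
        result = result.replace(raw, escaped)
    return result


def quote_literal(s: str) -> str:
    return "'" + escape_string(s) + "'"


def escape_wrong_order(s: str) -> str:
    # Deliberately wrong: the backslash produced for the newline is doubled
    # by the second replacement.
    return s.replace('\n', '\\n').replace('\\', '\\\\')


print(quote_literal("it's\n"))
print(escape_string('\n') == '\\n')         # True
print(escape_wrong_order('\n') == '\\\\n')  # True
```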


def dollar_quote_literal(text: str) -> str:
    quote = '$$'
    qq = 0

    while quote in text:
        if qq % 16 < 10:
            qq += 10 - qq % 16

        quote = '${:x}$'.format(qq)[::-1]
        qq += 1

    return quote + text + quote
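

# Illustrative sketch (not part of the module): the sequence of delimiters
# that `dollar_quote_literal` above tries. The counter skips every value
# whose lowest hex digit is 0-9, and the hex string is reversed, so each tag
# starts with a letter a-f. That is presumably what keeps the tag within the
# identifier-like shape accepted by `string_quote` in tokens.py. The
# generator `candidate_quotes` is a hypothetical helper introduced here.

```python
import itertools


def candidate_quotes():
    """Yield '$$' and then the tagged delimiters in the order they are tried."""
    yield '$$'
    qq = 0
    while True:
        if qq % 16 < 10:
            # Jump ahead to the next value whose lowest hex digit is a letter.
            qq += 10 - qq % 16
        # Reversing puts that letter first: 0x1a -> '$1a$' -> '$a1$'.
        yield '${:x}$'.format(qq)[::-1]
        qq += 1


def dollar_quote_literal(text: str) -> str:
    # Pick the first delimiter that does not occur in the text itself.
    for quote in candidate_quotes():
        if quote not in text:
            return quote + text + quote


print(list(itertools.islice(candidate_quotes(), 9)))
print(dollar_quote_literal('x $$ $a$ y'))
```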


def needs_quoting(string: str, allow_reserved: bool, allow_num: bool) -> bool:
    if not string or string.startswith('@') or '::' in string:
        # some strings are illegal as identifiers and as such don't
        # require quoting
        return False

    r = _re_ident_or_num if allow_num else _re_ident
    isalnum = r.fullmatch(string)

    string = string.lower()

    is_reserved = (
        string not in {'__type__', '__std__'}
        and string in keywords.by_type[keywords.RESERVED_KEYWORD]
    )

    return (
        not isalnum
        or (not allow_reserved and is_reserved)
    )


def _quote_ident(string: str) -> str:
    return '`' + string.replace('`', '``') + '`'


def quote_ident(
    string: str,
    *,
    force: bool = False,
    allow_reserved: bool = False,
    allow_num: bool = False,
) -> str:
    if force or needs_quoting(string, allow_reserved, allow_num):
        return _quote_ident(string)
    else:
        return string
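

# Illustrative sketch (not part of the module): the quoting decision made by
# `needs_quoting` and `quote_ident` above, without the `allow_num` branch.
# `RESERVED` is a hypothetical two-word stand-in for
# `keywords.by_type[keywords.RESERVED_KEYWORD]`; the real set is larger and
# the special cases for `__type__` and `__std__` are left out.

```python
import re

_re_ident = re.compile(r'[^\W\d]\w*')  # letter or underscore, then word chars
RESERVED = frozenset({'select', 'insert'})  # hypothetical subset


def needs_quoting(string: str, allow_reserved: bool = False) -> bool:
    if not string or string.startswith('@') or '::' in string:
        # Not a plain identifier at all, so quoting does not apply.
        return False
    is_ident = _re_ident.fullmatch(string) is not None
    is_reserved = string.lower() in RESERVED
    return not is_ident or (not allow_reserved and is_reserved)


def quote_ident(
    string: str, *, force: bool = False, allow_reserved: bool = False
) -> str:
    if force or needs_quoting(string, allow_reserved):
        # Backticks inside the name are escaped by doubling them.
        return '`' + string.replace('`', '``') + '`'
    return string


for name in ['name', 'my field', 'Select', '1abc', 'a`b', 'std::int64']:
    print(repr(name), '->', quote_ident(name))
```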


================================================
FILE: edb/edgeql/tokenizer.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations
from typing import Any, Optional, Sequence

import re
import hashlib

import edb._edgeql_parser as ql_parser

from edb import errors


TRAILING_WS_IN_CONTINUATION = re.compile(r'\\ \s+\n')


def deserialize(serialized: bytes, text: str) -> Source:
    match serialized[0]:
        case 0:
            tokens = ql_parser.unpack(serialized)
            assert isinstance(tokens, list)
            return Source(text, tokens, serialized)
        case 1:
            entry = ql_parser.unpack(serialized)
            assert isinstance(entry, ql_parser.Entry)
            return NormalizedSource(entry, text, serialized)

    raise ValueError(f"Invalid type/version byte: {serialized[0]}")


class Source:
    def __init__(
        self,
        text: str,
        tokens: list[ql_parser.OpaqueToken],
        serialized: bytes,
    ) -> None:
        self._cache_key = hashlib.blake2b(serialized).digest()
        self._text = text
        self._tokens = tokens
        self._serialized = serialized

    def text(self) -> str:
        return self._text

    def cache_key(self) -> bytes:
        return self._cache_key

    def variables(self) -> dict[str, Any]:
        return {}

    def tokens(self) -> list[ql_parser.OpaqueToken]:
        return self._tokens

    def first_extra(self) -> Optional[int]:
        return None

    def extra_counts(self) -> Sequence[int]:
        return ()

    def extra_blobs(self) -> list[bytes]:
        return []

    def extra_formatted_as_text(self) -> bool:
        return False

    def extra_type_oids(self) -> Sequence[int]:
        return ()

    def serialize(self) -> bytes:
        return self._serialized

    @staticmethod
    def from_string(text: str) -> Source:
        result = _tokenize(text)
        assert isinstance(result.out, list)
        return Source(text=text, tokens=result.out, serialized=result.pack())

    def __repr__(self):
        return f'<edgeql.Source text={self._text!r}>'

    def denormalized(self) -> Source:
        return self


class NormalizedSource(Source):
    def __init__(
        self,
        normalized: ql_parser.Entry,
        text: str,
        serialized: bytes,
    ) -> None:
        self._text = text
        self._cache_key = normalized.key
        self._tokens = normalized.tokens
        self._variables = normalized.get_variables()
        self._first_extra = normalized.first_extra
        self._extra_counts = normalized.extra_counts
        self._extra_blobs = normalized.extra_blobs
        self._serialized = serialized

    def text(self) -> str:
        return self._text

    def cache_key(self) -> bytes:
        return self._cache_key

    def variables(self) -> dict[str, Any]:
        return self._variables

    def tokens(self) -> list[ql_parser.OpaqueToken]:
        return self._tokens

    def first_extra(self) -> Optional[int]:
        return self._first_extra

    def extra_counts(self) -> Sequence[int]:
        return self._extra_counts

    def extra_blobs(self) -> list[bytes]:
        return self._extra_blobs

    @staticmethod
    def from_string(text: str) -> NormalizedSource:
        normalized = _normalize(text)
        return NormalizedSource(normalized, text, normalized.pack())

    def denormalized(self) -> Source:
        return Source.from_string(self._text)


def inflate_span(
    source: str, span: tuple[int, Optional[int]]
) -> tuple[ql_parser.SourcePoint, Optional[ql_parser.SourcePoint]]:
    (start, end) = span
    source_bytes = source.encode('utf-8')

    points = [start]
    if end is not None:
        points.append(end)

    points_sp = ql_parser.SourcePoint.from_offsets(source_bytes, points)

    start_sp = points_sp[0]
    if end is not None:
        end_sp = points_sp[1]
    else:
        end_sp = None
    return (start_sp, end_sp)


def inflate_position(
    source: str, span: tuple[int, Optional[int]]
) -> tuple[int, int, int, Optional[int]]:
    (start, end) = inflate_span(source, span)
    return (
        start.column,
        start.line,
        start.offset,
        end.offset if end else None,
    )


def line_col_to_source_point(
    source: str,
    line: int,  # zero-based
    col: int,  # zero-based, in UTF-16 code units
) -> ql_parser.SourcePoint:
    points = ql_parser.SourcePoint.from_lines_cols(
        source.encode('utf-8'), [(line, col)]
    )
    return points[0]


def _tokenize(eql: str) -> ql_parser.ParserResult:
    result = ql_parser.tokenize(eql)

    if len(result.errors) > 0:
        # TODO: emit multiple errors
        error = result.errors[0]

        message, span, hint, details = error
        position = inflate_position(eql, span)

        hint = _derive_hint(eql, message, position) or hint
        raise errors.EdgeQLSyntaxError(
            message, position=position, hint=hint, details=details
        )

    return result


def _normalize(eql: str) -> ql_parser.Entry:
    try:
        return ql_parser.normalize(eql)
    except ql_parser.SyntaxError as e:
        message, span, hint, details = e.args
        position = inflate_position(eql, span)

        hint = _derive_hint(eql, message, position) or hint
        raise errors.EdgeQLSyntaxError(
            message, position=position, hint=hint, details=details
        ) from e


def _derive_hint(
    input: str,
    message: str,
    position: tuple[int, int, int, Optional[int]],
) -> Optional[str]:
    _, _, off, _ = position

    if message.endswith(
        r"invalid string literal: invalid escape sequence '\ '"
    ):
        if TRAILING_WS_IN_CONTINUATION.search(input[off:]):
            return "consider removing trailing whitespace"
    return None
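

# Illustrative sketch (not part of the module): how `_derive_hint` above
# turns a generic "invalid escape sequence" error into a more helpful hint
# when a line continuation is followed by stray whitespace. The function
# takes the offset directly instead of a position tuple, and the sample
# query and message are made up for the example.

```python
import re
from typing import Optional

TRAILING_WS_IN_CONTINUATION = re.compile(r'\\ \s+\n')


def derive_hint(source: str, message: str, offset: int) -> Optional[str]:
    if message.endswith(
        r"invalid string literal: invalid escape sequence '\ '"
    ):
        # Only look at the text from the error position onwards.
        if TRAILING_WS_IN_CONTINUATION.search(source[offset:]):
            return "consider removing trailing whitespace"
    return None


message = r"invalid string literal: invalid escape sequence '\ '"
# A backslash meant as a line continuation, followed by two spaces.
query = "select 'abc\\  \ndef';"
print(derive_hint(query, message, 7))
# The same escape error in the middle of a line yields no hint.
print(derive_hint("select 'abc\\ def';", message, 7))
```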


================================================
FILE: edb/edgeql/tracer.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2015-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

# Import specific things to avoid name clashes
from typing import (Generator, Mapping, Optional,
                    Iterable, TypeVar, Sequence,
                    AbstractSet)

import functools

from contextlib import contextmanager
from edb import errors
from edb.common import parsing
from edb.schema import links as s_links
from edb.schema import name as sn
from edb.schema import objects as so
from edb.schema import pointers as s_pointers
from edb.schema import schema as s_schema
from edb.schema import sources as s_sources
from edb.schema import types as s_types

from edb.edgeql import ast as qlast


NamedObject_T = TypeVar("NamedObject_T", bound="NamedObject")


class NamedObject:
    '''Generic tracing object with an explicit name.'''

    def __init__(self, name: sn.QualName) -> None:
        self.name = name

    def get_name(self, schema: s_schema.Schema) -> sn.QualName:
        return self.name

    @classmethod
    def get_schema_class_displayname(cls) -> str:
        return cls.__name__.lower()


SentinelObject = NamedObject(
    name=sn.QualName(module='__unknown__', name='__unknown__'),
)


ObjectLike = NamedObject | so.Object


class Function(NamedObject):
    pass


class Constraint(NamedObject):
    pass


class ConcreteConstraint(NamedObject):
    pass


class Annotation(NamedObject):
    pass


class AnnotationValue(NamedObject):
    pass


class Global(NamedObject):
    pass


class Permission(NamedObject):
    pass


class Index(NamedObject):
    pass


class ConcreteIndex(NamedObject):
    pass


class Field(NamedObject):
    pass


class Type(NamedObject):
    def is_scalar(self) -> bool:
        return False

    def is_object_type(self) -> bool:
        return False


class ScalarType(Type):
    def is_scalar(self) -> bool:
        return True


TypeLike = Type | s_types.Type


T = TypeVar('T')


class UnqualObjectIndex[T]:

    def __init__(self, items: Mapping[sn.UnqualName, T]) -> None:
        self._items = items

    def items(
        self,
        schema: s_schema.Schema,
    ) -> Iterable[tuple[sn.UnqualName, T]]:
        return self._items.items()


class Source(NamedObject):

    '''Abstract type that mocks the s_sources.Source for tracing purposes.'''

    pointers: dict[sn.UnqualName, s_pointers.Pointer | Pointer]

    def __init__(self, name: sn.QualName) -> None:
        super().__init__(name)
        self.pointers = {}

    def maybe_get_ptr(
        self,
        schema: s_schema.Schema,
        name: sn.UnqualName,
    ) -> Optional[s_pointers.Pointer | Pointer]:
        return self.pointers.get(name)

    def getptr(
        self,
        schema: s_schema.Schema,
        name: sn.UnqualName,
    ) -> s_pointers.Pointer | Pointer:
        ptr = self.maybe_get_ptr(schema, name)
        if ptr is None:
            raise AssertionError(f'{self.name} has no link or property {name}')
        return ptr

    def get_pointers(
        self,
        schema: s_schema.Schema,
    ) -> UnqualObjectIndex[s_pointers.Pointer | Pointer]:
        return UnqualObjectIndex(self.pointers)


Source_T = TypeVar("Source_T", bound="Source")
SourceLike = Source | s_sources.Source
SourceLike_T = TypeVar("SourceLike_T", bound="SourceLike")


class ObjectType(Type, Source):

    def is_pointer(self) -> bool:
        return False

    def is_scalar(self) -> bool:
        return False

    def is_object_type(self) -> bool:
        return True


class Alias(ObjectType):
    pass


class CompositeType(Type):

    types: list[Type | CompositeType | s_types.Type]

    def __init__(
        self,
        types: list[Type | CompositeType | s_types.Type],
    ) -> None:
        self.types = types


class UnionType(CompositeType):

    def __init__(
        self,
        types: list[Type | CompositeType | s_types.Type],
    ) -> None:
        super().__init__(types)

    def get_name(self, schema: s_schema.Schema) -> sn.QualName:
        component_ids = sorted(str(t.get_name(schema)) for t in self.types)
        nqname = f"({' | '.join(component_ids)})"
        return sn.QualName(name=nqname, module='__derived__')

    def is_object_type(self) -> bool:
        return True


class IntersectionType(CompositeType):

    def __init__(
        self,
        types: list[Type | CompositeType | s_types.Type],
    ) -> None:
        super().__init__(types)

    def get_name(self, schema: s_schema.Schema) -> sn.QualName:
        component_ids = sorted(str(t.get_name(schema)) for t in self.types)
        nqname = f"({' & '.join(component_ids)})"
        return sn.QualName(name=nqname, module='__derived__')

    def is_object_type(self) -> bool:
        return True
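

# Illustrative sketch (not part of the module): how `UnionType.get_name` and
# `IntersectionType.get_name` above build a derived name. Sorting the
# component names first makes the result independent of the order in which
# the components were written. `derived_name` is a hypothetical helper that
# works on plain strings instead of `sn.QualName` objects.

```python
from typing import Iterable


def derived_name(component_names: Iterable[str], op: str) -> str:
    """Join sorted component names with ` op ` and wrap them in parentheses."""
    separator = f' {op} '
    return '(' + separator.join(sorted(component_names)) + ')'


print(derived_name(['default::B', 'default::A'], '|'))
print(derived_name(['default::A', 'default::B'], '|'))
print(derived_name(['default::B', 'default::A'], '&'))
```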


class Pointer(Source):

    def __init__(
        self,
        name: sn.QualName,
        *,
        source: Optional[SourceLike] = None,
        target: Optional[TypeLike] = None,
        target_expr: Optional[qlast.Expr] = None,
    ) -> None:
        super().__init__(name)
        self.source = source
        self.target = target
        self.target_expr = target_expr

    def is_pointer(self) -> bool:
        return True

    def is_property(
        self,
        schema: s_schema.Schema,
    ) -> bool:
        raise NotImplementedError

    def maybe_get_ptr(
        self,
        schema: s_schema.Schema,
        name: sn.UnqualName,
    ) -> Optional[s_pointers.Pointer | Pointer]:
        if (not (res := super().maybe_get_ptr(schema, name))
                and isinstance(self.target, (Source, s_sources.Source))):
            res = self.target.maybe_get_ptr(schema, name)
        return res

    def get_target(
        self,
        schema: s_schema.Schema,
    ) -> Optional[TypeLike]:
        return self.target

    def get_source(
        self,
        schema: s_schema.Schema,
    ) -> Optional[SourceLike]:
        return self.source


class Property(Pointer):
    def is_property(
        self,
        schema: s_schema.Schema,
    ) -> bool:
        return True


class Link(Pointer):
    def is_property(
        self,
        schema: s_schema.Schema,
    ) -> bool:
        return False


class UnknownPointer(Pointer):
    def is_property(
        self,
        schema: s_schema.Schema,
    ) -> bool:
        return False

    @classmethod
    def get_schema_class_displayname(cls) -> str:
        return 'link or property'


class AccessPolicy(NamedObject):

    def __init__(
        self,
        name: sn.QualName,
        *,
        source: Optional[SourceLike] = None,
    ) -> None:
        super().__init__(name)
        self.source = source

    def get_source(
        self,
        schema: s_schema.Schema,
    ) -> Optional[SourceLike]:
        return self.source


class Trigger(NamedObject):

    def __init__(
        self,
        name: sn.QualName,
        *,
        source: Optional[SourceLike] = None,
    ) -> None:
        super().__init__(name)
        self.source = source

    def get_source(
        self,
        schema: s_schema.Schema,
    ) -> Optional[SourceLike]:
        return self.source


class Rewrite(NamedObject):

    def __init__(
        self,
        name: sn.QualName,
        *,
        source: Optional[SourceLike] = None,
    ) -> None:
        super().__init__(name)
        self.source = source

    def get_source(
        self,
        schema: s_schema.Schema,
    ) -> Optional[SourceLike]:
        return self.source


def qualify_name(name: sn.QualName, qual: str) -> sn.QualName:
    return sn.QualName(name.module, f'{name.name}@{qual}')


def trace_refs(
    qltree: qlast.Base,
    *,
    schema: s_schema.Schema,
    anchors: Optional[Mapping[str, sn.QualName]] = None,
    path_prefix: Optional[sn.QualName] = None,
    module: str,
    objects: dict[sn.QualName, Optional[ObjectLike]],
    pointers: Mapping[sn.UnqualName, set[sn.QualName]],
    params: Mapping[str, qlast.TypeExpr],
    local_modules: AbstractSet[str]
) -> tuple[frozenset[sn.QualName], frozenset[sn.QualName]]:

    """Return the sets of schema item names used in an expression.

    First set is strong deps, second is weak.
    """

    ctx = TracerContext(
        schema=schema,
        module=module,
        objects=objects,
        pointers=pointers,
        anchors=anchors or {},
        path_prefix=path_prefix,
        modaliases={},
        params=params,
        visited=set(),
        local_modules=local_modules,
    )
    trace(qltree, ctx=ctx)
    return frozenset(ctx.refs), frozenset(ctx.weak_refs)


def resolve_name(
    ref: qlast.ObjectRef,
    *,
    current_module: str,
    schema: s_schema.Schema,
    objects: dict[sn.QualName, Optional[ObjectLike]],
    modaliases: Optional[dict[Optional[str], str]],
    local_modules: AbstractSet[str],
    declaration: bool = False,
) -> sn.QualName:
    """Resolve a name into a fully-qualified one.

    This takes into account the current module and modaliases.

    This function mostly mirrors schema.lookup, except:
    - If no module and no default module was set, try the current module
    - When searching in std, ensure module is not a local module
    - If no result found, return a name with the best modname available
    """

    def exists(name: sn.QualName) -> bool:
        return (
            objects.get(name) is not None
            or schema.get(name, default=None, type=so.Object) is not None
        )

    module = ref.module
    orig_module = module

    # Apply module aliases
    module = s_schema.apply_module_aliases(module, modaliases)
    no_std = declaration

    # Check if something matches the name
    if module is not None:
        fqname = sn.QualName(module=module, name=ref.name)
        if exists(fqname):
            return fqname

    elif orig_module is None:
        # Look for name in current module
        fqname = sn.QualName(module=current_module, name=ref.name)
        if exists(fqname):
            return fqname

    # Try something in std
    if not no_std:
        # If no module was given, look in std
        if orig_module is None:
            mod_name = 'std'
            fqname = sn.QualName(mod_name, ref.name)
            if exists(fqname):
                return fqname

        # Ensure module is not a local module.
        # Then try the module as part of std.
        if module and module not in local_modules:
            mod_name = f'std::{module}'
            fqname = sn.QualName(mod_name, ref.name)
            if exists(fqname):
                return fqname

    # Just pick the best module name available
    return sn.QualName(
        module=module or orig_module or current_module,
        name=ref.name,
    )


class TracerContext:
    def __init__(
        self,
        *,
        schema: s_schema.Schema,
        module: str,
        objects: dict[sn.QualName, Optional[ObjectLike]],
        pointers: Mapping[sn.UnqualName, set[sn.QualName]],
        anchors: Mapping[str, sn.QualName],
        path_prefix: Optional[sn.QualName],
        modaliases: dict[Optional[str], str],
        params: Mapping[str, qlast.TypeExpr],
        visited: set[s_pointers.Pointer | Pointer],
        local_modules: AbstractSet[str],
    ) -> None:
        self.schema = schema
        self.refs: set[sn.QualName] = set()
        self.weak_refs: set[sn.QualName] = set()
        self.module = module
        self.objects = objects
        self.pointers = pointers
        self.anchors = anchors
        self.path_prefix = path_prefix
        self.modaliases = modaliases
        self.params = params
        self.local_modules = local_modules
        self.visited = visited

    def get_ref_name(self, ref: qlast.BaseObjectRef) -> sn.QualName:
        # We don't actually expect to handle anything other than
        # ObjectRef here.
        assert isinstance(ref, qlast.ObjectRef)

        return resolve_name(
            ref,
            current_module=self.module,
            schema=self.schema,
            objects=self.objects,
            modaliases=self.modaliases,
            local_modules=self.local_modules,
        )

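    def get_ref_name_startswith(self, ref: qlast.ObjectRef) -> set[sn.QualName]:
        """Return every known object name whose base name matches *ref*.

        The base name is the part of the object name before ``@@``.
        Candidates are looked up in the referenced (or current) module
        and in ``std``.
        """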
        refs = set()
        prefixes = set()

        if ref.module:
            # replace the module alias with the real name
            module = self.modaliases.get(ref.module, ref.module)
            prefixes.add(f'{module}::{ref.name}')
            prefixes.add(f'std::{module}::{ref.name}')
        else:
            prefixes.add(f'{self.module}::{ref.name}')
            prefixes.add(f'std::{ref.name}')

        for objname in self.objects.keys():
            short_name = str(objname).split('@@', 1)[0]
            if short_name in prefixes:
                refs.add(objname)

        return refs


def _fork_context(ctx: TracerContext) -> TracerContext:
    nctx = TracerContext(
        schema=ctx.schema,
        module=ctx.module,
        objects=dict(ctx.objects),
        pointers=ctx.pointers,
        anchors=ctx.anchors,
        path_prefix=ctx.path_prefix,
        modaliases=dict(ctx.modaliases),
        params=ctx.params,
        visited=ctx.visited,
        local_modules=ctx.local_modules,
    )
    nctx.refs = ctx.refs
    nctx.weak_refs = ctx.weak_refs

    return nctx


@contextmanager
def alias_context(
    ctx: TracerContext,
    aliases: Optional[
        Sequence[qlast.Alias]],
) -> Generator[TracerContext, None, None]:
    ctx = _fork_context(ctx)

    for alias in (aliases or ()):
        # module and modaliases in ctx need to be amended
        if isinstance(alias, qlast.ModuleAliasDecl):
            if alias.alias:
                ctx.modaliases[alias.alias] = alias.module
            else:
                # default module
                ctx.module = alias.module

        elif isinstance(alias, qlast.AliasedExpr):
            obj = trace(alias.expr, ctx=ctx)
            # Regardless of whether tracing the expression produces an
            # object, record the alias.
            ctx.objects[sn.QualName('__alias__', alias.alias)] = obj

    try:
        yield ctx
    finally:
        # refs are already updated
        pass


@contextmanager
def result_alias_context(
    ctx: TracerContext,
    node: qlast.ReturningQuery | qlast.SubjectQuery,
    obj: Optional[ObjectLike],
) -> Generator[TracerContext, None, None]:

    alias: Optional[str] = None
    if isinstance(node, qlast.SelectQuery):
        alias = node.result_alias
    elif isinstance(node, qlast.GroupQuery):
        alias = node.subject_alias

    # potentially SELECT uses an alias for the main result
    if obj is not None and alias:
        nctx = TracerContext(
            schema=ctx.schema,
            module=ctx.module,
            objects=dict(ctx.objects),
            pointers=ctx.pointers,
            anchors=ctx.anchors,
            path_prefix=ctx.path_prefix,
            modaliases=ctx.modaliases,
            params=ctx.params,
            visited=ctx.visited,
            local_modules=ctx.local_modules,
        )
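        # use the same refs sets (strong and weak), so that references
        # recorded in the new context are not lost
        nctx.refs = ctx.refs
        nctx.weak_refs = ctx.weak_refs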
        nctx.objects[sn.QualName('__alias__', alias)] = obj
    else:
        nctx = ctx

    try:
        yield nctx
    finally:
        # refs are already updated
        pass


@functools.singledispatch
def trace(
    node: Optional[qlast.Base],
    *,
    ctx: TracerContext,
) -> Optional[ObjectLike]:
    raise NotImplementedError(f"do not know how to trace {node!r}")


@trace.register
def trace_none(node: None, *, ctx: TracerContext) -> None:
    pass


@trace.register
def trace_Constant(node: qlast.BaseConstant, *, ctx: TracerContext) -> None:
    pass


@trace.register
def trace_QueryParameter(
    node: qlast.QueryParameter, *, ctx: TracerContext
) -> None:
    raise errors.SchemaError(
        'query parameters are not allowed in schemas',
        span=node.span,
    )


@trace.register
def trace_FunctionParameter(
    node: qlast.FunctionParameter, *, ctx: TracerContext
) -> None:
    raise AssertionError(
        'function parameters are expected to be substituted for paths '
        'in schemas',
    )


@trace.register
def trace_Array(node: qlast.Array, *, ctx: TracerContext) -> None:
    for el in node.elements:
        trace(el, ctx=ctx)


@trace.register
def trace_StrInterpFragment(
    node: qlast.StrInterpFragment, *, ctx: TracerContext
) -> None:
    trace(node.expr, ctx=ctx)


@trace.register
def trace_StrInterp(node: qlast.StrInterp, *, ctx: TracerContext) -> None:
    for el in node.interpolations:
        trace(el, ctx=ctx)


@trace.register
def trace_Set(node: qlast.Set, *, ctx: TracerContext) -> None:
    for el in node.elements:
        trace(el, ctx=ctx)


@trace.register
def trace_Tuple(node: qlast.Tuple, *, ctx: TracerContext) -> None:
    for el in node.elements:
        trace(el, ctx=ctx)


@trace.register
def trace_NamedTuple(node: qlast.NamedTuple, *, ctx: TracerContext) -> None:
    for el in node.elements:
        trace(el.val, ctx=ctx)


@trace.register
def trace_BinOp(node: qlast.BinOp, *, ctx: TracerContext) -> None:
    trace(node.left, ctx=ctx)
    trace(node.right, ctx=ctx)


@trace.register
def trace_UnaryOp(node: qlast.UnaryOp, *, ctx: TracerContext) -> None:
    trace(node.operand, ctx=ctx)


@trace.register
def trace_Detached(
    node: qlast.DetachedExpr, *, ctx: TracerContext
) -> Optional[ObjectLike]:
    # DETACHED works with partial paths same as its inner expression.
    return trace(node.expr, ctx=ctx)


@trace.register
def trace_Global(
    node: qlast.GlobalExpr, *, ctx: TracerContext
) -> Optional[ObjectLike]:
    refname = ctx.get_ref_name(node.name)
    if refname in ctx.objects:
        ctx.refs.add(refname)
        tip = ctx.objects[refname]
    else:
        tip = ctx.schema.get(refname, span=node.span)
    return tip


def check_type_exists(
    typename: sn.QualName,
    ctx: TracerContext,
    span: Optional[parsing.Span],
    *,
    hint: Optional[str] = None,
) -> None:
    if typename in ctx.objects:
        return

    try:
        # Check if the typename is already in the schema
        ctx.schema.get(typename, type=s_types.Type, span=span)
    except errors.InvalidReferenceError as e:
        if hint and not e.hint:
            e.set_hint_and_details(hint, e.details)
        raise e


@trace.register
def trace_TypeCast(node: qlast.TypeCast, *, ctx: TracerContext) -> None:
    trace(node.expr, ctx=ctx)
    if isinstance(node.type, qlast.TypeName):
        if not node.type.subtypes:
            typename: sn.QualName = ctx.get_ref_name(node.type.maintype)
            check_type_exists(typename, ctx, node.type.span)
            ctx.refs.add(typename)


@trace.register
def trace_IsOp(node: qlast.IsOp, *, ctx: TracerContext) -> None:
    trace(node.left, ctx=ctx)
    if isinstance(node.right, qlast.TypeName):
        if not node.right.subtypes:
            typename: sn.QualName = ctx.get_ref_name(node.right.maintype)

            hint: Optional[str] = None
            if typename.name.lower() in ['null', 'none']:
                hint = (
                    'Did you mean to use `exists` to check if a set is empty?'
                )
            check_type_exists(typename, ctx, node.right.span, hint=hint)

            ctx.refs.add(typename)


@trace.register
def trace_Introspect(node: qlast.Introspect, *, ctx: TracerContext) -> None:
    if isinstance(node.type, qlast.TypeName):
        if not node.type.subtypes:
            typename: sn.QualName = ctx.get_ref_name(node.type.maintype)
            check_type_exists(typename, ctx, node.type.span)
            ctx.refs.add(typename)


@trace.register
def trace_FunctionCall(node: qlast.FunctionCall, *, ctx: TracerContext) -> None:

    if isinstance(node.func, tuple):
        fname = qlast.ObjectRef(module=node.func[0], name=node.func[1])
    else:
        fname = qlast.ObjectRef(name=node.func)
    # The function call is dependent on the function actually being
    # present, so we add all variations of that function name to the
    # dependency list.

    names = ctx.get_ref_name_startswith(fname)
    ctx.refs.update(names)

    for arg in node.args:
        trace(arg, ctx=ctx)
    for arg in node.kwargs.values():
        trace(arg, ctx=ctx)


@trace.register
def trace_Indirection(node: qlast.Indirection, *, ctx: TracerContext) -> None:
    for indirection in node.indirection:
        trace(indirection, ctx=ctx)
    trace(node.arg, ctx=ctx)


@trace.register
def trace_Index(node: qlast.Index, *, ctx: TracerContext) -> None:
    trace(node.index, ctx=ctx)


@trace.register
def trace_Slice(node: qlast.Slice, *, ctx: TracerContext) -> None:
    trace(node.start, ctx=ctx)
    trace(node.stop, ctx=ctx)


@trace.register
def trace_Path(
    node: qlast.Path,
    *,
    ctx: TracerContext,
) -> Optional[ObjectLike]:
    tip: Optional[ObjectLike] = None
    ptr: Optional[Pointer | s_pointers.Pointer] = None
    plen = len(node.steps)

    # HACK: This isn't very smart, and can't properly track types
    # through arbitrary expressions. To try to mitigate the damage
    # from this, when we have a pointer step but don't know the type,
    # we track *weak* references to all pointers with that name.
    # This won't always work (if there is a tangle of cyclic weak deps),
    # but it works pretty well.

    for i, step in enumerate(node.steps):
        if isinstance(step, qlast.ObjectRef):
            # the ObjectRef without a module may be referring to an
            # aliased expression
            aname = sn.QualName('__alias__', step.name)
            if not step.module and aname in ctx.objects:
                tip = ctx.objects[aname]

            elif not step.module and step.name in ctx.params:
                param_type = ctx.params[step.name]
                if (
                    isinstance(param_type, qlast.TypeName)
                    and isinstance(param_type.maintype, qlast.PseudoObjectRef)
                ):
                    # Pretend pseudotypes (e.g. `anytype`) have a fully
                    # qualified name.
                    refname = sn.QualName('std', param_type.maintype.name)
                    ctx.refs.add(refname)

                    tip = ctx.objects[refname]

                else:
                    tip = _resolve_type_expr(param_type, ctx=ctx)

            else:
                refname = ctx.get_ref_name(step)
                if refname in ctx.objects:
                    ctx.refs.add(refname)
                    tip = ctx.objects[refname]
                else:
                    tip = ctx.schema.get(refname, span=step.span)

        elif isinstance(step, qlast.Ptr):
            pname = sn.UnqualName(step.name)

            if i == 0:
                # Abbreviated path.
                if ctx.path_prefix in ctx.objects:
                    tip = ctx.objects[ctx.path_prefix]
                    if isinstance(tip, Pointer):
                        ptr = tip
                else:
                    # We can't reason about this path.
                    # Do a weak dependency on anything with the same name.
                    ctx.weak_refs.update(ctx.pointers.get(pname, ()))

            if step.type == 'property':
                if ptr is None:
                    # This is either a computable def or unknown link, bail.
                    # Do a weak dependency on anything with the same name.
                    ctx.weak_refs.update(ctx.pointers.get(pname, ()))
                    tip = None

                elif isinstance(ptr, (s_links.Link, Pointer)):
                    lprop = ptr.maybe_get_ptr(
                        ctx.schema,
                        pname,
                    )
                    if lprop is None:
                        # Invalid link property reference, bail.
                        return None

                    if (isinstance(lprop, Pointer) and
                            lprop.source is not None):
                        src = lprop.source
                        src_name = src.get_name(ctx.schema)
                        if (isinstance(src, Pointer) and
                                src.source is not None):
                            src_src_name = src.source.get_name(ctx.schema)
                            source_name = qualify_name(
                                src_src_name, src_name.name)
                        else:
                            source_name = src_name
                        ctx.refs.add(qualify_name(source_name, step.name))
            else:
                if step.direction == '<':
                    if plen > i + 1 and isinstance(node.steps[i + 1],
                                                   qlast.TypeIntersection):
                        # A reverse link traversal with a type intersection,
                        # process it on the next step.
                        pass
                    else:
                        # No type intersection, so the only type that
                        # it can be is "Object", which is trivial.
                        # However, we need to make it dependent on
                        # every link of the same name now.
                        for fqname in ctx.pointers.get(pname, ()):
                            obj = ctx.objects.get(fqname)

                            # Ignore what appears to not be a link
                            # with the right name.
                            if (isinstance(obj, (s_pointers.Pointer,
                                                 Pointer)) and
                                fqname.name.split('@', 1)[1] ==
                                    step.name):

                                target = obj.get_target(ctx.schema)
                                # Ignore scalars, but include other
                                # computables to produce better error
                                # messages.
                                if (target is None or
                                        not target.is_scalar()):
                                    # Record link with matching short
                                    # name.
                                    ctx.refs.add(fqname)

                        tip = ptr = None
                else:
                    if isinstance(tip, (Source, s_sources.Source)):
                        ptr = tip.maybe_get_ptr(
                            ctx.schema, sn.UnqualName(step.name)
                        )
                        if ptr is None:
                            # Invalid pointer reference, bail.
                            return None
                        else:
                            ptr_source = ptr.get_source(ctx.schema)

                        if ptr_source is not None:
                            sname = ptr_source.get_name(ctx.schema)
                            assert isinstance(sname, sn.QualName)
                            ctx.refs.add(qualify_name(sname, step.name))
                            tip = ptr.get_target(ctx.schema)

                            if tip is None:
                                if ptr in ctx.visited:
                                    # Possibly recursive definition, bail out.
                                    return None

                                # This can only be Pointer that didn't
                                # infer the target type yet.
                                assert isinstance(ptr, Pointer)
                                # We haven't computed the target yet,
                                # so try computing it now.
                                ctx.visited.add(ptr)

                                target_ctx = _fork_context(ctx)
                                target_ctx.path_prefix = sname
                                ptr_target = trace(
                                    ptr.target_expr, ctx=target_ctx
                                )

                                if isinstance(ptr_target, (Type,
                                                           s_types.Type)):
                                    tip = ptr.target = ptr_target

                        else:
                            # Can't figure out the new tip, so we bail.
                            return None

                    else:
                        # We can't reason about this path.
                        # Do a weak dependency on anything with the same name.
                        ctx.weak_refs.update(ctx.pointers.get(pname, ()))
                        tip = ptr = None

        elif isinstance(step, qlast.TypeIntersection):
            # The tip is normally determined by the type in the type
            # intersection. Backward links are a special case: the
            # pointer also has to be looked up on the intersected type.
            tip = _resolve_type_expr(step.type, ctx=ctx)
            prev_step = node.steps[i - 1]
            if isinstance(prev_step, qlast.Ptr):
                if prev_step.direction == '<':
                    if isinstance(tip, (s_sources.Source, ObjectType)):
                        ptr = tip.maybe_get_ptr(
                            ctx.schema, sn.UnqualName(prev_step.name)
                        )
                        if ptr is None:
                            # Invalid pointer reference, bail.
                            return None

                        if isinstance(tip, Type):
                            tip_name = tip.get_name(ctx.schema)
                            ctx.refs.add(qualify_name(tip_name, prev_step.name))

        elif isinstance(step, qlast.Splat):
            if step.type is not None:
                _resolve_type_expr(step.type, ctx=ctx)
            if step.intersection is not None:
                _resolve_type_expr(step.intersection.type, ctx=ctx)

        else:
            tr = trace(step, ctx=ctx)
            tip = ptr = None
            if tr is not None:
                tip = tr
                if isinstance(tip, Pointer):
                    ptr = tip

    return tip


@trace.register
def trace_Anchor(
    node: qlast.Anchor, *, ctx: TracerContext
) -> Optional[ObjectLike]:
    if name := ctx.anchors.get(node.name):
        return ctx.objects[name]
    return None


def _resolve_type_expr(
    texpr: qlast.TypeExpr,
    *,
    ctx: TracerContext,
) -> TypeLike:

    if isinstance(texpr, qlast.TypeName):
        if texpr.subtypes and isinstance(texpr.maintype, qlast.ObjectRef):
            return Type(
                name=sn.QualName(
                    module='__coll__',
                    name=texpr.maintype.name,
                ),
            )
        else:
            refname = ctx.get_ref_name(texpr.maintype)
            local_obj = ctx.objects.get(refname)
            obj: TypeLike
            if local_obj is None:
                obj = ctx.schema.get(
                    refname, type=s_types.Type, span=texpr.span)
            else:
                assert isinstance(local_obj, Type)
                obj = local_obj
                ctx.refs.add(refname)

            return obj

    elif isinstance(texpr, qlast.TypeOp):

        left = _resolve_type_expr(texpr.left, ctx=ctx)
        right = _resolve_type_expr(texpr.right, ctx=ctx)

        ThisCompositeType: type[CompositeType] = (
            UnionType
            if texpr.op == qlast.TypeOpName.OR else
            IntersectionType
        )

        if isinstance(left, ThisCompositeType):
            if isinstance(right, ThisCompositeType):
                return ThisCompositeType(left.types + right.types)
            else:
                return ThisCompositeType(left.types + [right])
        else:
            if isinstance(right, ThisCompositeType):
                return ThisCompositeType([left] + right.types)
            else:
                return ThisCompositeType([left, right])

    else:
        raise NotImplementedError(
            f'unsupported type expression: {texpr!r}'
        )


@trace.register
def trace_TypeIntersection(
    node: qlast.TypeIntersection, *, ctx: TracerContext
) -> None:
    trace(node.type, ctx=ctx)


@trace.register
def trace_TypeOf(node: qlast.TypeOf, *, ctx: TracerContext) -> None:
    trace(node.expr, ctx=ctx)


@trace.register
def trace_TypeName(node: qlast.TypeName, *, ctx: TracerContext) -> None:
    if node.subtypes:
        for st in node.subtypes:
            trace(st, ctx=ctx)
    elif isinstance(node.maintype, qlast.ObjectRef):
        tref = node.maintype
        if tref.module:
            fq_name = sn.QualName(module=tref.module, name=tref.name)
        else:
            fq_name = sn.QualName(module=ctx.module, name=tref.name)
            if fq_name not in ctx.objects:
                std_name = sn.QualName(module="std", name=tref.name)
                if ctx.schema.get(std_name, default=None) is not None:
                    fq_name = std_name
        ctx.refs.add(fq_name)


@trace.register
def trace_TypeOp(node: qlast.TypeOp, *, ctx: TracerContext) -> None:
    trace(node.left, ctx=ctx)
    trace(node.right, ctx=ctx)


@trace.register
def trace_IfElse(node: qlast.IfElse, *, ctx: TracerContext) -> None:
    trace(node.if_expr, ctx=ctx)
    trace(node.else_expr, ctx=ctx)
    trace(node.condition, ctx=ctx)


@trace.register
def trace_Shape(
    node: qlast.Shape, *, ctx: TracerContext
) -> Optional[ObjectLike]:
    tip = trace(node.expr, ctx=ctx)
    if isinstance(node.expr, qlast.Path):
        orig_prefix = ctx.path_prefix
        if tip is not None:
            tip_name = tip.get_name(ctx.schema)
            assert isinstance(tip_name, sn.QualName)
            ctx.path_prefix = tip_name
        else:
            ctx.path_prefix = None

    for element in node.elements:
        trace(element, ctx=ctx)

    if isinstance(node.expr, qlast.Path):
        ctx.path_prefix = orig_prefix

    return tip


@trace.register
def trace_ShapeElement(node: qlast.ShapeElement, *, ctx: TracerContext) -> None:
    trace(node.expr, ctx=ctx)
    if node.elements:
        for element in node.elements:
            trace(element, ctx=ctx)
    trace(node.where, ctx=ctx)
    if node.orderby:
        for sortexpr in node.orderby:
            trace(sortexpr, ctx=ctx)
    trace(node.offset, ctx=ctx)
    trace(node.limit, ctx=ctx)
    trace(node.compexpr, ctx=ctx)


def _update_path_prefix(tip: Optional[ObjectLike], ctx: TracerContext) -> None:
    if tip is not None:
        tip_name = tip.get_name(ctx.schema)
        assert isinstance(tip_name, sn.QualName)
        ctx.path_prefix = tip_name
    else:
        ctx.path_prefix = None


@trace.register
def trace_Select(
    node: qlast.SelectQuery, *, ctx: TracerContext
) -> Optional[ObjectLike]:
    with alias_context(ctx, node.aliases) as ctx:
        tip = trace(node.result, ctx=ctx)
        _update_path_prefix(tip, ctx=ctx)

        # potentially SELECT uses an alias for the main result
        with result_alias_context(ctx, node, tip) as nctx:
            if node.where is not None:
                trace(node.where, ctx=nctx)
            if node.orderby:
                for expr in node.orderby:
                    trace(expr, ctx=nctx)
            if node.offset is not None:
                trace(node.offset, ctx=nctx)
            if node.limit is not None:
                trace(node.limit, ctx=nctx)

        return tip


def trace_GroupingAtom(node: qlast.GroupingAtom, *, ctx: TracerContext) -> None:
    if isinstance(node, qlast.ObjectRef):
        trace(qlast.Path(steps=[node]), ctx=ctx)
    elif isinstance(node, qlast.Path):
        trace(node, ctx=ctx)
    else:
        assert isinstance(node, qlast.GroupingIdentList)
        for el in node.elements:
            trace_GroupingAtom(el, ctx=ctx)


@trace.register
def trace_GroupingSimple(
    node: qlast.GroupingSimple, *, ctx: TracerContext
) -> None:
    trace_GroupingAtom(node.element, ctx=ctx)


@trace.register
def trace_GroupingSets(node: qlast.GroupingSets, *, ctx: TracerContext) -> None:
    for s in node.sets:
        trace(s, ctx=ctx)


@trace.register
def trace_GroupingOperation(
    node: qlast.GroupingOperation, *, ctx: TracerContext
) -> None:
    for s in node.elements:
        trace(s, ctx=ctx)


@trace.register
def trace_Group(
    node: qlast.GroupQuery, *, ctx: TracerContext
) -> Optional[ObjectLike]:
    return _trace_GroupQuery(node, ctx=ctx)


@trace.register
def trace_InternalGroupQuery(
    node: qlast.InternalGroupQuery, *, ctx: TracerContext
) -> Optional[ObjectLike]:
    return _trace_GroupQuery(node, ctx=ctx)


def _trace_GroupQuery(
    node: qlast.GroupQuery | qlast.InternalGroupQuery, *, ctx: TracerContext
) -> Optional[ObjectLike]:
    with alias_context(ctx, node.aliases) as ctx:
        tip = trace(node.subject, ctx=ctx)
        if tip is not None:
            tip_name = tip.get_name(ctx.schema)
            assert isinstance(tip_name, sn.QualName)
            ctx.path_prefix = tip_name

        # potentially GROUP uses an alias for the main result
        with result_alias_context(ctx, node, tip) as nctx:
            with alias_context(nctx, node.using) as byctx:
                for by_el in node.by:
                    trace(by_el, ctx=byctx)

        if isinstance(node, qlast.InternalGroupQuery):
            with alias_context(nctx, node.using) as byctx:
                ctx.objects[sn.QualName('__alias__', node.group_alias)] = (
                    SentinelObject)
                if node.grouping_alias:
                    ctx.objects[
                        sn.QualName('__alias__', node.grouping_alias)] = (
                            SentinelObject)
                trace(node.result, ctx=byctx)

        return tip


@trace.register
def trace_SortExpr(node: qlast.SortExpr, *, ctx: TracerContext) -> None:
    trace(node.path, ctx=ctx)


@trace.register
def trace_InsertQuery(node: qlast.InsertQuery, *, ctx: TracerContext) -> None:
    with alias_context(ctx, node.aliases) as ctx:
        if node.unless_conflict:
            trace(node.unless_conflict[0], ctx=ctx)
            trace(node.unless_conflict[1], ctx=ctx)

        tip = trace(qlast.Path(steps=[node.subject]), ctx=ctx)
        _update_path_prefix(tip, ctx=ctx)

        for element in node.shape:
            trace(element, ctx=ctx)


@trace.register
def trace_UpdateQuery(
    node: qlast.UpdateQuery, *, ctx: TracerContext
) -> Optional[ObjectLike]:
    with alias_context(ctx, node.aliases) as ctx:
        tip = trace(node.subject, ctx=ctx)
        _update_path_prefix(tip, ctx=ctx)

        # potentially UPDATE uses an alias for the main result
        with result_alias_context(ctx, node, tip) as nctx:
            for element in node.shape:
                trace(element, ctx=nctx)

            trace(node.where, ctx=nctx)

        return tip


@trace.register
def trace_DeleteQuery(
    node: qlast.DeleteQuery, *, ctx: TracerContext
) -> Optional[ObjectLike]:
    with alias_context(ctx, node.aliases) as ctx:
        tip = trace(node.subject, ctx=ctx)
        _update_path_prefix(tip, ctx=ctx)

        # potentially DELETE uses an alias for the main result
        with result_alias_context(ctx, node, tip) as nctx:
            if node.where is not None:
                trace(node.where, ctx=nctx)
            if node.orderby:
                for expr in node.orderby:
                    trace(expr, ctx=nctx)
            if node.offset is not None:
                trace(node.offset, ctx=nctx)
            if node.limit is not None:
                trace(node.limit, ctx=nctx)

        return tip


@trace.register
def trace_For(
    node: qlast.ForQuery, *, ctx: TracerContext
) -> Optional[ObjectLike]:
    with alias_context(ctx, node.aliases) as ctx:
        obj = trace(node.iterator, ctx=ctx)
        if obj is None:
            obj = SentinelObject
        ctx.objects[sn.QualName('__alias__', node.iterator_alias)] = obj
        tip = trace(node.result, ctx=ctx)

        return tip


@trace.register
def trace_DescribeStmt(
    node: qlast.DescribeStmt,
    *,
    ctx: TracerContext,
) -> None:

    if isinstance(node.object, qlast.ObjectRef):
        fq_name = ctx.get_ref_name(node.object)
        ctx.refs.add(fq_name)


@trace.register
def trace_ExplainStmt(
    node: qlast.ExplainStmt,
    *,
    ctx: TracerContext,
) -> None:
    pass


@trace.register
def trace_AdministerStmt(
    node: qlast.AdministerStmt,
    *,
    ctx: TracerContext,
) -> None:
    pass


@trace.register
def trace_Placeholder(
    node: qlast.Placeholder,
    *,
    ctx: TracerContext,
) -> None:
    pass


================================================
FILE: edb/edgeql/utils.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2015-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

import copy
import itertools
from typing import Any, Optional, Mapping

from edb import errors
from edb.common import ast
from edb.schema import schema as s_schema
from edb.schema import functions as s_func

from . import ast as qlast


FREE_SHAPE_EXPR = qlast.DetachedExpr(
    expr=qlast.Path(
        steps=[qlast.ObjectRef(module='std', name='FreeObject')],
        allow_factoring=True,
    ),
)


class ParameterInliner(ast.NodeTransformer):

    def __init__(self, args_map: Mapping[str, qlast.Base]) -> None:
        super().__init__()
        self.args_map = args_map

    def visit_Path(self, node: qlast.Path) -> qlast.Base:
        if len(node.steps) != 1 or not isinstance(
            node.steps[0], qlast.ObjectRef
        ):
            self.visit(node.steps[0])
            return node

        ref: qlast.ObjectRef = node.steps[0]
        try:
            arg = self.args_map[ref.name]
        except KeyError:
            return node

        arg = copy.deepcopy(arg)
        return arg


def inline_parameters(
    ql_expr: qlast.Base, args: Mapping[str, qlast.Base]
) -> None:

    inliner = ParameterInliner(args)
    inliner.visit(ql_expr)


def index_parameters(
    ql_args: list[qlast.Base],
    *,
    parameters: s_func.ParameterLikeList,
    schema: s_schema.Schema
) -> dict[str, qlast.Base]:

    result: dict[str, qlast.Base] = {}
    varargs: Optional[list[qlast.Expr]] = None
    variadic = parameters.find_variadic(schema)
    variadic_num = variadic.get_num(schema) if variadic else -1  # type: ignore

    params = parameters.objects(schema)

    if not variadic and len(ql_args) > len(params):
        # In error message we discount the implicit __subject__ param.
        raise errors.SchemaDefinitionError(
            f'Expected {len(params) - 1} arguments, but found '
            f'{len(ql_args) - 1}',
            span=ql_args[-1].span,
            details='Did you mean to use ON (...) for specifying the subject?',
        )

    e: qlast.Expr
    p: s_func.ParameterLike
    for iter in itertools.zip_longest(
        enumerate(ql_args), params, fillvalue=None
    ):
        (i, e), p = iter  # type: ignore
        if isinstance(e, qlast.SelectQuery):
            e = e.result

        if variadic and variadic_num == i:
            assert varargs is None
            varargs = []
            result[p.get_parameter_name(schema)] = qlast.Array(
                elements=varargs
            )

        if varargs is not None:
            varargs.append(e)
        else:
            result[p.get_parameter_name(schema)] = e

    return result


class AnchorInliner(ast.NodeTransformer):

    def __init__(self, anchors: Mapping[str, qlast.Base]) -> None:
        super().__init__()
        self.anchors = anchors

    def visit_Path(self, node: qlast.Path) -> qlast.Path:
        if not node.steps:
            return node

        step0 = node.steps[0]

        if isinstance(step0, qlast.Anchor):
            node.steps[0] = self.anchors[step0.name]  # type: ignore
        elif isinstance(step0, qlast.ObjectRef) and step0.name in self.anchors:
            node.steps[0] = self.anchors[step0.name]  # type: ignore

        return node


def inline_anchors(
    ql_expr: qlast.Base, anchors: Mapping[Any, qlast.Base]
) -> None:

    inliner = AnchorInliner(anchors)
    inliner.visit(ql_expr)


def find_paths(ql: qlast.Base) -> list[qlast.Path]:
    return ast.find_children(ql, qlast.Path)


def find_subject_ptrs(ast: qlast.Base) -> set[str]:
    ptrs = set()
    for path in find_paths(ast):
        if path.partial:
            p = path.steps[0]
        elif is_anchor(path.steps[0], '__subject__') and len(path.steps) > 1:
            p = path.steps[1]
        else:
            continue

        if isinstance(p, qlast.Ptr):
            ptrs.add(p.name)
    return ptrs


def is_anchor(expr: qlast.PathElement, name: str) -> bool:
    return isinstance(expr, qlast.Anchor) and expr.name == name


def subject_paths_substitute(
    ast: qlast.Base_T,
    subject_ptrs: dict[str, qlast.Expr],
) -> qlast.Base_T:
    ast = copy.deepcopy(ast)
    for path in find_paths(ast):
        if path.partial and isinstance(path.steps[0], qlast.Ptr):
            path.steps[0] = subject_paths_substitute(
                subject_ptrs[path.steps[0].name],
                subject_ptrs,
            )
        elif (
            is_anchor(path.steps[0], '__subject__')
            and len(path.steps) > 1
            and isinstance(path.steps[1], qlast.Ptr)
        ):
            path.steps[0:2] = [subject_paths_substitute(
                subject_ptrs[path.steps[1].name],
                subject_ptrs,
            )]
    return ast


def subject_substitute(
    ast: qlast.Base_T, new_subject: qlast.Expr
) -> qlast.Base_T:
    ast = copy.deepcopy(ast)
    # If the subject is a path (usually will be), graft the path
    # elements directly to avoid an extra SelectStmt/Set in the IR,
    # which can result in worse codegen (unnecessary semijoins, for
    # example).
    # TODO: Unify other substitution functions.
    if isinstance(new_subject, qlast.Path):
        new_partial = new_subject.partial
        new_head = new_subject.steps
    else:
        new_partial = False
        new_head = [new_subject]

    for path in find_paths(ast):
        if is_anchor(path.steps[0], '__subject__'):
            path.steps[0:1] = new_head
            path.partial = new_partial
        elif path.partial:
            path.steps[0:0] = new_head
            path.partial = new_partial
    return ast


def is_enum(type_name: qlast.TypeName):
    return (
        isinstance(type_name.maintype, (qlast.TypeName, qlast.ObjectRef))
        and type_name.maintype.name == "enum"
        and type_name.subtypes
    )


================================================
FILE: edb/edgeql-parser/Cargo.toml
================================================
[package]
name = "edgeql-parser"
version = "0.1.0"
license = "MIT/Apache-2.0"
authors = ["MagicStack Inc. "]
edition = "2021"

[lints]
workspace = true

[dependencies]
pyo3 = { workspace = true, optional = true }

base32 = "0.5.1"
bigdecimal = { version = "0.4.5", features = ["serde"] }
num-bigint = { version = "0.4.6", features = ["serde"] }
sha2 = "0.10.2"
snafu = "0.8.1"
memchr = "2.5.0"
serde = { version = "1.0.106", features = ["derive"], optional = true }
thiserror = "2"
unicode-width = "0.1.8"
edgeql-parser-derive = { path = "edgeql-parser-derive", optional = true }
indexmap = "2.4.0"
serde_json = { version = "1.0", features = ["preserve_order"] }
bumpalo = { version = "3.13.0", features = ["collections"] }
phf = { version = "0.11.1", features = ["macros"] }
append-only-vec = "0.1.2"

[features]
default = []
python = ["pyo3", "serde", "edgeql-parser-derive"]

[lib]


================================================
FILE: edb/edgeql-parser/edgeql-parser-derive/Cargo.toml
================================================
[package]
name = "edgeql-parser-derive"
description = "Derive macros for IntoPython trait for AST"
version = "0.1.0"
edition = "2021"

[lints]
workspace = true

[lib]
proc-macro = true

[dependencies]
syn = { version = "2.0.76" }
quote = "1.0.37"
proc-macro2 = "1.0"


================================================
FILE: edb/edgeql-parser/edgeql-parser-derive/src/lib.rs
================================================
use proc_macro::TokenStream;

use syn::{parse_macro_input, Attribute, Type, TypePath};

use quote::quote;
use syn::{self, Fields, Ident};

#[proc_macro_derive(IntoPython, attributes(py_child, py_enum, py_union))]
pub fn into_python(input: TokenStream) -> TokenStream {
    use syn::Item;
    let mut item = parse_macro_input!(input as Item);
    match &mut item {
        Item::Enum(enum_) => impl_enum_into_python(enum_),
        Item::Struct(struct_) => impl_struct_into_python(struct_),
        unsupported => {
            syn::Error::new_spanned(unsupported, "IntoPython only supports structs and enums")
                .into_compile_error()
                .into()
        }
    }
}

fn impl_enum_into_python(enum_: &mut syn::ItemEnum) -> TokenStream {
    let variants = infer_variants(enum_);

    let name = &enum_.ident;
    let mut cases = Vec::new();

    if let Some(py_enum) = find_attr(&enum_.attrs, "py_enum") {
        let class_path = py_enum.meta.path();

        for Variant { name } in variants {
            cases.push(quote! {
                Self::#name => py.eval(#class_path.#name, None, None),
            });
        }
    } else if find_attr(&enum_.attrs, "py_child").is_some() {
        for Variant { name } in variants {
            cases.push(quote! {
                Self::#name(value) => value.into_python(py, parent),
            });
        }
    } else if find_attr(&enum_.attrs, "py_union").is_some() {
        for Variant { name } in variants {
            cases.push(quote! {
                Self::#name(value) => value.into_python(py, None),
            });
        }
    } else {
        panic!("enum is missing one of #[py_enum], #[py_child] or #[py_union]")
    }

    quote! {
        impl crate::into_python::IntoPython for #name {
            fn into_python(
                self,
                py: cpython::Python,
                parent: Option<cpython::PyDict>,
            ) -> cpython::PyResult<cpython::PyObject> {
                use crate::into_python::IntoPython;

                match self { #(#cases)* }
            }
        }
    }
    .into()
}

fn infer_variants(enum_: &syn::ItemEnum) -> Vec<Variant> {
    let mut variants = Vec::new();

    for variant in &enum_.variants {
        let name = variant.ident.clone();

        match &variant.fields {
            Fields::Named(_) => panic!("IntoPython does not support named enum variant fields"),
            Fields::Unnamed(fields) => {
                if fields.unnamed.len() != 1 {
                    panic!("IntoPython supports only enum variants with zero or one field")
                }

                variants.push(Variant { name });
            }
            Fields::Unit => {
                variants.push(Variant { name });
            }
        }
    }

    variants
}

/// Information about an enum variant of a type annotated with IntoPython
struct Variant {
    name: Ident,
}

fn impl_struct_into_python(struct_: &mut syn::ItemStruct) -> TokenStream {
    let (properties, py_child_field) = infer_fields(struct_);

    let name = &struct_.ident;

    let mut property_assigns = Vec::new();
    for property in properties {
        property_assigns.push(quote! {
            kw_args.set_item(
                py,
                stringify!(#property),
                self.#property.into_python(py, None)?
            )?;
        });
    }

    let init = if let Some(py_child_field) = py_child_field {
        let field = py_child_field.ident;
        if py_child_field.is_option {
            quote! {
                match self.#field {
                    Some(kind) => kind.into_python(py, Some(kw_args)),
                    None => crate::into_python::init_ast_class(py, stringify!(#name), kw_args)
                }
            }
        } else {
            quote! {
                self.#field.into_python(py, Some(kw_args))
            }
        }
    } else {
        quote! {
            crate::into_python::init_ast_class(py, stringify!(#name), kw_args)
        }
    };

    quote! {
        impl crate::into_python::IntoPython for #name {
            fn into_python(
                self,
                py: cpython::Python,
                parent_kw_args: Option<cpython::PyDict>,
            ) -> cpython::PyResult<cpython::PyObject> {
                use crate::into_python::IntoPython;

                let kw_args = parent_kw_args.unwrap_or_else(|| cpython::PyDict::new(py));
                #(#property_assigns)*

                #init
            }
        }
    }
    .into()
}

struct PyChildField {
    ident: Ident,
    is_option: bool,
}

fn infer_fields(r#struct: &mut syn::ItemStruct) -> (Vec<Ident>, Option<PyChildField>) {
    let mut properties = Vec::new();
    let mut py_child = None;

    for field in &mut r#struct.fields {
        let ident = field
            .ident
            .clone()
            .expect("IntoPython supports only named fields");

        if find_attr(&field.attrs, "py_child").is_some() {
            let is_option = is_option(&field.ty);

            py_child = Some(PyChildField { ident, is_option });
            continue;
        }

        properties.push(ident);
    }

    (properties, py_child)
}

fn find_attr<'a>(attrs: &'a [Attribute], name: &'static str) -> Option<&'a Attribute> {
    attrs.iter().find(|a| {
        let Some(ident) = a.path().get_ident() else {
            return false;
        };
        *ident == name
    })
}

fn is_option(ty: &Type) -> bool {
    let Type::Path(TypePath { path, .. }) = ty else {
        return false;
    };
    let Some(segment) = path.segments.first() else {
        return false;
    };
    segment.ident == "Option"
}


================================================
FILE: edb/edgeql-parser/edgeql-parser-python/Cargo.toml
================================================
[package]
name = "edgeql-parser-python"
license = "MIT/Apache-2.0"
version = "0.1.0"
authors = ["MagicStack Inc. "]
edition = "2021"

[lints]
workspace = true

[features]
python_extension = ["pyo3/extension-module"]
default = ["python_extension"]

[dependencies]
pyo3 = { workspace = true, optional = true }

edgeql-parser = { path = "..", features = ["serde"] }
bytes = "1.0.1"
num-bigint = "0.4.3"
bigdecimal = { version = "0.4.5", features = ["string-only"] }
blake2 = "0.10.4"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
indexmap = "2.4.0"
once_cell = "1.18.0"
bincode = { version = "1.3.3" }
gel-protocol = { workspace = true, features = ["with-num-bigint", "with-bigdecimal"] }

[lib]
crate-type = ["lib", "cdylib"]
name = "edgeql_rust"
path = "src/lib.rs"


================================================
FILE: edb/edgeql-parser/edgeql-parser-python/src/errors.rs
================================================
use edgeql_parser::tokenizer::Error;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyList};
use pyo3::{create_exception, exceptions};

use crate::tokenizer::OpaqueToken;

create_exception!(_edgeql_parser, SyntaxError, exceptions::PyException);

#[pyclass]
pub struct ParserResult {
    #[pyo3(get)]
    pub out: Py<PyAny>,

    #[pyo3(get)]
    pub errors: Py<PyAny>,
}

#[pymethods]
impl ParserResult {
    fn pack(&self, py: Python) -> PyResult<Py<PyAny>> {
        let tokens = self.out.downcast_bound::<PyList>(py)?;
        let mut rv = Vec::with_capacity(tokens.len());
        for token in tokens {
            let token: &Bound<OpaqueToken> = token.downcast()?;
            rv.push(token.borrow().inner.clone());
        }
        let mut buf = vec![0u8]; // type and version
        bincode::serialize_into(&mut buf, &rv)
            .map_err(|e| PyValueError::new_err(format!("Failed to pack: {e}")))?;
        Ok(PyBytes::new(py, buf.as_slice()).into())
    }
}

pub fn parser_error_into_tuple(
    error: &Error,
) -> (&str, (u64, u64), Option<&String>, Option<&String>) {
    (
        &error.message,
        (error.span.start, error.span.end),
        error.hint.as_ref(),
        error.details.as_ref(),
    )
}


================================================
FILE: edb/edgeql-parser/edgeql-parser-python/src/hash.rs
================================================
use std::sync::RwLock;

use edgeql_parser::hash;
use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyString};

use crate::errors::SyntaxError;

#[pyclass]
pub struct Hasher {
    _hasher: RwLock<Option<hash::Hasher>>,
}

#[pymethods]
impl Hasher {
    #[staticmethod]
    fn start_migration(parent_id: &Bound<PyString>) -> PyResult<Hasher> {
        let hasher = hash::Hasher::start_migration(parent_id.to_str()?);
        Ok(Hasher {
            _hasher: RwLock::new(Some(hasher)),
        })
    }

    fn add_source(&self, py: Python, data: &Bound<PyString>) -> PyResult<Py<PyAny>> {
        let text = data.to_str()?;
        let mut cell = self._hasher.write().unwrap();
        let hasher = cell
            .as_mut()
            .ok_or_else(|| PyRuntimeError::new_err(("cannot add source after finish",)))?;

        hasher.add_source(text).map_err(|e| match e {
            hash::Error::Tokenizer(msg, pos) => {
                SyntaxError::new_err((msg, (pos.offset, py.None()), py.None(), py.None()))
            }
        })?;
        Ok(py.None())
    }

    fn make_migration_id(&self) -> PyResult<String> {
        let mut cell = self._hasher.write().unwrap();
        let hasher = cell
            .take()
            .ok_or_else(|| PyRuntimeError::new_err(("cannot do migration id twice",)))?;
        Ok(hasher.make_migration_id())
    }
}
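
// The `Hasher` wrapper above stores its state behind a lock as an optional
// value (`RwLock::new(Some(hasher))`) and removes it with `take()`, so a
// migration id can be produced only once and later calls fail. Below is a
// minimal standalone sketch of that pattern using only the standard library.
// `OnceHasher`, `Inner`, `add` and `finish` are hypothetical names chosen for
// illustration; they are not part of this crate.

```rust
use std::sync::RwLock;

/// Hypothetical stand-in for a consumable hasher state.
struct Inner {
    parts: Vec<String>,
}

/// Wrapper whose `finish` can succeed only once, even through `&self`.
struct OnceHasher {
    inner: RwLock<Option<Inner>>,
}

impl OnceHasher {
    fn new() -> Self {
        OnceHasher {
            inner: RwLock::new(Some(Inner { parts: Vec::new() })),
        }
    }

    /// Appends a chunk; fails once the state has been consumed.
    fn add(&self, text: &str) -> Result<(), String> {
        let mut cell = self.inner.write().unwrap();
        let inner = cell
            .as_mut()
            .ok_or_else(|| "cannot add after finish".to_string())?;
        inner.parts.push(text.to_string());
        Ok(())
    }

    /// Consumes the state; a second call finds `None` and fails.
    fn finish(&self) -> Result<String, String> {
        let mut cell = self.inner.write().unwrap();
        let inner = cell
            .take()
            .ok_or_else(|| "cannot finish twice".to_string())?;
        Ok(inner.parts.join("|"))
    }
}
```

// Keeping the state in an `Option` lets a method that only has `&self`
// (as Python-facing methods do) still move the value out exactly once.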


================================================
FILE: edb/edgeql-parser/edgeql-parser-python/src/keywords.rs
================================================
use pyo3::{prelude::*, types::PyFrozenSet};

use edgeql_parser::keywords;

pub struct AllKeywords {
    pub current: Py<PyFrozenSet>,
    pub future: Py<PyFrozenSet>,
    pub unreserved: Py<PyFrozenSet>,
    pub partial: Py<PyFrozenSet>,
}

pub fn get_keywords(py: Python) -> PyResult<AllKeywords> {
    let intern = py.import("sys")?.getattr("intern")?;

    Ok(AllKeywords {
        current: prepare_keywords(py, &keywords::CURRENT_RESERVED_KEYWORDS, &intern)?,
        unreserved: prepare_keywords(py, &keywords::UNRESERVED_KEYWORDS, &intern)?,
        future: prepare_keywords(py, &keywords::FUTURE_RESERVED_KEYWORDS, &intern)?,
        partial: prepare_keywords(py, &keywords::PARTIAL_RESERVED_KEYWORDS, &intern)?,
    })
}

fn prepare_keywords<'a, 'py, I: IntoIterator<Item = &'a &'a str>>(
    py: Python<'py>,
    keyword_set: I,
    intern: &Bound<'py, PyAny>,
) -> PyResult<Py<PyFrozenSet>> {
    PyFrozenSet::new(
        py,
        keyword_set
            .into_iter()
            .map(|s| intern.call((&s,), None).unwrap()),
    )
    .map(|o| o.unbind())
}


================================================
FILE: edb/edgeql-parser/edgeql-parser-python/src/lib.rs
================================================
#![cfg(feature = "python_extension")]
mod errors;
mod hash;
mod keywords;
pub mod normalize;
mod parser;
mod position;
mod pynormalize;
mod tokenizer;
mod unpack;

use pyo3::prelude::*;

/// Rust bindings to the edgeql-parser crate
#[pymodule]
fn _edgeql_parser(py: Python, m: &Bound<PyModule>) -> PyResult<()> {
    m.add("SyntaxError", py.get_type::<errors::SyntaxError>())?;
    m.add("ParserResult", py.get_type::<errors::ParserResult>())?;

    m.add_class::<hash::Hasher>()?;

    let keywords = keywords::get_keywords(py)?;
    m.add("unreserved_keywords", keywords.unreserved)?;
    m.add("partial_reserved_keywords", keywords.partial)?;
    m.add("future_reserved_keywords", keywords.future)?;
    m.add("current_reserved_keywords", keywords.current)?;

    m.add_class::<pynormalize::Entry>()?;
    m.add_function(wrap_pyfunction!(pynormalize::normalize, m)?)?;

    m.add_function(wrap_pyfunction!(parser::parse, m)?)?;
    m.add_function(wrap_pyfunction!(parser::suggest_next_keywords, m)?)?;
    m.add_function(wrap_pyfunction!(parser::preload_spec, m)?)?;
    m.add_function(wrap_pyfunction!(parser::save_spec, m)?)?;
    m.add_class::<parser::CSTNode>()?;
    m.add_class::<parser::Production>()?;
    m.add_class::<parser::Terminal>()?;

    m.add_function(wrap_pyfunction!(position::offset_of_line, m)?)?;
    m.add("SourcePoint", py.get_type::<position::SourcePoint>())?;

    m.add_class::<tokenizer::OpaqueToken>()?;
    m.add_function(wrap_pyfunction!(tokenizer::tokenize, m)?)?;
    m.add_function(wrap_pyfunction!(tokenizer::unpickle_token, m)?)?;

    m.add_function(wrap_pyfunction!(unpack::unpack, m)?)?;

    tokenizer::fini_module(m);

    Ok(())
}


================================================
FILE: edb/edgeql-parser/edgeql-parser-python/src/normalize.rs
================================================
use std::collections::BTreeSet;

use edgeql_parser::keywords::Keyword;
use edgeql_parser::position::{Pos, Span};
use edgeql_parser::tokenizer::{Kind, Token, Tokenizer, Value};

use blake2::{Blake2b512, Digest};

#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct Variable {
    pub value: Value,
}

pub struct Entry {
    pub processed_source: String,
    pub hash: [u8; 64],
    pub tokens: Vec<Token<'static>>,
    pub variables: Vec<Vec<Variable>>,
    pub named_args: bool,
    pub first_arg: Option<usize>,
}

/// PackedEntry is a compact Entry for serialization purposes
#[derive(serde::Serialize, serde::Deserialize)]
pub struct PackedEntry {
    pub tokens: Vec<Token<'static>>,
    pub variables: Vec<Vec<Variable>>,
    pub named_args: bool,
    pub first_arg: Option<usize>,
}

impl From<Entry> for PackedEntry {
    fn from(val: Entry) -> Self {
        PackedEntry {
            tokens: val.tokens,
            variables: val.variables,
            named_args: val.named_args,
            first_arg: val.first_arg,
        }
    }
}

impl From<PackedEntry> for Entry {
    fn from(val: PackedEntry) -> Self {
        let processed_source = serialize_tokens(&val.tokens[..]);
        Entry {
            hash: hash(&processed_source),
            processed_source,
            tokens: val.tokens,
            variables: val.variables,
            named_args: val.named_args,
            first_arg: val.first_arg,
        }
    }
}

#[derive(Debug)]
pub enum Error {
    Tokenizer(String, u64),
    Assertion(String, Pos),
}

pub fn normalize(text: &str) -> Result<Entry, Error> {
    let tokens = Tokenizer::new(text)
        .validated_values()
        .with_eof()
        .map(|x| x.map(|t| t.cloned()))
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| Error::Tokenizer(e.message, e.span.start))?;

    let (named_args, var_idx) = match scan_vars(&tokens) {
        Some(pair) => pair,
        None => {
            // don't extract from invalid query, let python code do its work
            let processed_source = serialize_tokens(&tokens);
            return Ok(Entry {
                hash: hash(&processed_source),
                processed_source,
                tokens,
                variables: Vec::new(),
                named_args: false,
                first_arg: None,
            });
        }
    };
    let mut rewritten_tokens = Vec::with_capacity(tokens.len());
    let mut all_variables = Vec::new();
    let mut variables = Vec::new();
    let mut counter = var_idx;
    let mut next_var = || {
        let n = counter;
        counter += 1;
        if named_args {
            format!("$__edb_arg_{n}")
        } else {
            format!("${n}")
        }
    };
    let mut last_was_set = false;
    for tok in &tokens {
        let mut is_set = false;
        match tok.kind {
            Kind::IntConst
            // Don't replace `.12` because this is a tuple access
            if !matches!(rewritten_tokens.last(),
                Some(Token { kind: Kind::Dot, .. }))
            // Don't replace 'LIMIT 1' as a special case
            && (tok.text != "1"
                || !matches!(rewritten_tokens.last(),
                    Some(Token { kind: Kind::Keyword(Keyword("limit")), .. })))
            // Don't replace 2^63, which overflows int64 unless it is negated
            && tok.text != "9223372036854775808"
            => {
                rewritten_tokens.push(arg_type_cast(
                    "int64", next_var(), tok.span
                ));
                variables.push(Variable {
                    value: tok.value.clone().unwrap(),
                });
                continue;
            }
            Kind::FloatConst => {
                rewritten_tokens.push(arg_type_cast(
                    "float64", next_var(), tok.span
                ));
                variables.push(Variable {
                    value: tok.value.clone().unwrap(),
                });
                continue;
            }
            Kind::BigIntConst => {
                rewritten_tokens.push(arg_type_cast(
                    "bigint", next_var(), tok.span
                ));
                variables.push(Variable {
                    value: tok.value.clone().unwrap(),
                });
                continue;
            }
            Kind::DecimalConst => {
                rewritten_tokens.push(arg_type_cast(
                    "decimal", next_var(), tok.span
                ));
                variables.push(Variable {
                    value: tok.value.clone().unwrap(),
                });
                continue;
            }
            Kind::Str => {
                rewritten_tokens.push(arg_type_cast(
                    "str", next_var(), tok.span
                ));
                variables.push(Variable {
                    value: tok.value.clone().unwrap(),
                });
                continue;
            }
            Kind::Keyword(Keyword(kw))
            if (
                matches!(kw, "administer"|"configure"|"create"|"alter"|"drop"|"start"|"analyze")
                || (last_was_set && kw == "global")
            ) => {
                let processed_source = serialize_tokens(&tokens);
                return Ok(Entry {
                    hash: hash(&processed_source),
                    processed_source,
                    tokens,
                    variables: Vec::new(),
                    named_args: false,
                    first_arg: None,
                });
            }
            // Split on semicolons.
            // N.B: This naive statement splitting on semicolons works
            // because the only statements with internal semis are DDL
            // statements, which we don't support anyway.
            Kind::Semicolon => {
                all_variables.push(variables);
                variables = Vec::new();
                rewritten_tokens.push(tok.clone());
            }
            Kind::Keyword(Keyword("set")) => {
                is_set = true;
                rewritten_tokens.push(tok.clone());
            }
            _ => rewritten_tokens.push(tok.clone()),
        }
        last_was_set = is_set;
    }

    all_variables.push(variables);
    // N.B: We always serialize the tokens to produce
    // processed_source, even when no changes have been made. This is
    // because when Source gets serialized, it always uses a
    // PackedEntry, which will result in it being normalized *there*,
    // and so if we don't do it *here*, then we won't be able to hit
    // the persistent cache in cases where we didn't reserialize the
    // tokens.
    // TODO: Rework the caching to avoid needing to do this.
    let processed_source = serialize_tokens(&rewritten_tokens[..]);
    Ok(Entry {
        hash: hash(&processed_source),
        processed_source,
        named_args,
        first_arg: if counter <= var_idx {
            None
        } else {
            Some(var_idx)
        },
        tokens: rewritten_tokens,
        variables: all_variables,
    })
}

fn is_operator(token: &Token) -> bool {
    use edgeql_parser::tokenizer::Kind::*;
    match token.kind {
        Assign | SubAssign | AddAssign | Arrow | Coalesce | Namespace | DoubleSplat
        | BackwardLink | OptionalLink | FloorDiv | Concat | GreaterEq | LessEq | NotEq
        | NotDistinctFrom | DistinctFrom | Comma | OpenParen | CloseParen | OpenBracket
        | CloseBracket | OpenBrace | CloseBrace | Dot | Semicolon | Colon | Add | Sub | Mul
        | Div | Modulo | Pow | Less | Greater | Eq | Ampersand | Pipe | At => true,
        DecimalConst | FloatConst | IntConst | BigIntConst | BinStr | Parameter
        | ParameterAndType | Str | BacktickName | Keyword(_) | Ident | Substitution | EOI
        | Epsilon | StartBlock | StartExtension | StartFragment | StartMigration
        | StartSDLDocument | StrInterpStart | StrInterpCont | StrInterpEnd => false,
    }
}

fn serialize_tokens(tokens: &[Token]) -> String {
    use edgeql_parser::tokenizer::Kind::Parameter;

    let mut buf = String::new();
    let mut needs_space = false;
    for token in tokens {
        if matches!(token.kind, Kind::EOI) {
            break;
        }

        if needs_space && !is_operator(token) && token.kind != Parameter {
            buf.push(' ');
        }
        buf.push_str(&token.text);
        needs_space = !is_operator(token);
    }
    buf
}

fn scan_vars<'x, 'y: 'x, I>(tokens: I) -> Option<(bool, usize)>
where
    I: IntoIterator<Item = &'x Token<'y>>,
{
    let mut max_visited = None::<usize>;
    let mut names = BTreeSet::new();
    for t in tokens {
        if t.kind == Kind::Parameter {
            if let Ok(v) = t.text[1..].parse() {
                if max_visited.map(|old| v > old).unwrap_or(true) {
                    max_visited = Some(v);
                }
            } else {
                names.insert(&t.text[..]);
            }
        }
    }
    if names.is_empty() {
        let next = max_visited.map(|x| x.checked_add(1)).unwrap_or(Some(0))?;
        Some((false, next))
    } else if max_visited.is_some() {
        return None; // mixed arguments
    } else {
        Some((true, names.len()))
    }
}

fn hash(text: &str) -> [u8; 64] {
    let mut result = [0u8; 64];
    result.copy_from_slice(&Blake2b512::new_with_prefix(text.as_bytes()).finalize());
    result
}

/// Produces tokens corresponding to (<lit typ>$var)
fn arg_type_cast(typ: &'static str, var: String, span: Span) -> Token<'static> {
    // the `lit` is required so these tokens have different text than an actual
    // type cast and parameter, so their hashes don't clash.
    Token {
        kind: Kind::ParameterAndType,
        text: format!("<lit {typ}>{var}").into(),
        value: None,
        span,
    }
}

#[cfg(test)]
mod test {
    use super::scan_vars;
    use edgeql_parser::tokenizer::{Token, Tokenizer};

    fn tokenize(s: &str) -> Vec<Token> {
        let mut r = Vec::new();
        let mut s = Tokenizer::new(s);
        loop {
            match s.next() {
                Some(Ok(x)) => r.push(x),
                None => break,
                Some(Err(e)) => panic!("Parse error at {}: {}", s.current_pos(), e.message),
            }
        }
        r
    }

    #[test]
    fn none() {
        assert_eq!(scan_vars(&tokenize("SELECT 1+1")).unwrap(), (false, 0));
    }

    #[test]
    fn numeric() {
        assert_eq!(scan_vars(&tokenize("$0 $1 $2")).unwrap(), (false, 3));
        assert_eq!(scan_vars(&tokenize("$2 $3 $2")).unwrap(), (false, 4));
        assert_eq!(scan_vars(&tokenize("$0 $0 $0")).unwrap(), (false, 1));
        assert_eq!(scan_vars(&tokenize("$10 $100")).unwrap(), (false, 101));
    }

    #[test]
    fn named() {
        assert_eq!(scan_vars(&tokenize("$a")).unwrap(), (true, 1));
        assert_eq!(scan_vars(&tokenize("$b $c $d")).unwrap(), (true, 3));
        assert_eq!(scan_vars(&tokenize("$b $c $b")).unwrap(), (true, 2));
        assert_eq!(
            scan_vars(&tokenize("$a $b $b $a $c $xx")).unwrap(),
            (true, 4)
        );
    }

    #[test]
    fn mixed() {
        assert_eq!(scan_vars(&tokenize("$a $0")), None);
        assert_eq!(scan_vars(&tokenize("$0 $a")), None);
        assert_eq!(scan_vars(&tokenize("$b $c $100")), None);
        assert_eq!(scan_vars(&tokenize("$10 $xx $yy")), None);
    }
}
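
// The tests above pin down what `scan_vars` returns for positional, named and
// mixed parameters. Below is a simplified standalone sketch of the same
// classification over plain parameter strings instead of tokenizer tokens.
// `scan_params` is a hypothetical name, and rejecting any string without a
// leading `$` is an assumption of this sketch, not behaviour of the crate.

```rust
use std::collections::BTreeSet;

/// Classifies parameters such as "$0" or "$name".
/// Returns (named_args, next_free_index), or None for mixed styles,
/// index overflow, or input that lacks the `$` prefix.
fn scan_params(params: &[&str]) -> Option<(bool, usize)> {
    let mut max_visited: Option<usize> = None;
    let mut names: BTreeSet<&str> = BTreeSet::new();
    for p in params {
        let body = p.strip_prefix('$')?;
        match body.parse::<usize>() {
            // Positional parameter: remember the largest index seen.
            Ok(v) => {
                max_visited = Some(max_visited.map_or(v, |old| old.max(v)));
            }
            // Anything else counts as a named parameter (deduplicated).
            Err(_) => {
                names.insert(*p);
            }
        }
    }
    if names.is_empty() {
        match max_visited {
            None => Some((false, 0)),
            Some(m) => m.checked_add(1).map(|n| (false, n)),
        }
    } else if max_visited.is_some() {
        None // mixed positional and named parameters
    } else {
        Some((true, names.len()))
    }
}
```

// The returned index is where extracted literals can start numbering without
// colliding with parameters the query already uses.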


================================================
FILE: edb/edgeql-parser/edgeql-parser-python/src/parser.rs
================================================
use once_cell::sync::OnceCell;

use edgeql_parser::parser;
use pyo3::exceptions::{PyAssertionError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyList, PyString};

use crate::errors::{parser_error_into_tuple, ParserResult};
use crate::pynormalize::TokenizerValue;
use crate::tokenizer::OpaqueToken;

#[pyfunction]
pub fn parse(
    py: Python,
    start_token_name: &Bound<PyString>,
    tokens: Py<PyAny>,
) -> PyResult<(ParserResult, &'static Py<PyAny>)> {
    let start_token_name = start_token_name.to_string();

    let (spec, productions) = get_spec()?;

    let tokens = downcast_tokens(py, &start_token_name, tokens)?;

    let context = parser::Context::new(spec);
    let (cst, errors) = parser::parse(&tokens, &context);

    let errors = PyList::new(py, errors.iter().map(|e| parser_error_into_tuple(e)))?;

    let res = ParserResult {
        out: cst.as_ref().map(ParserCSTNode).into_pyobject(py)?.unbind(),
        errors: errors.into(),
    };

    Ok((res, productions))
}

#[pyfunction]
pub fn suggest_next_keywords(
    py: Python,
    start_token_name: &Bound<PyString>,
    tokens: Py<PyAny>,
) -> PyResult<(Py<PyList>, bool)> {
    let start_token_name = start_token_name.to_string();

    let (spec, _) = get_spec()?;

    let tokens = downcast_tokens(py, &start_token_name, tokens)?;

    let context = parser::Context::new(spec);
    let (suggestions, can_be_ident) = parser::suggest_next_keyword(&tokens, &context);

    let suggestions_py = suggestions.iter().map(|k| PyString::new(py, k.0));
    let suggestions_py = PyList::new(py, suggestions_py)?.into();

    Ok((suggestions_py, can_be_ident))
}

#[pyclass]
pub struct CSTNode {
    #[pyo3(get)]
    production: Option<Py<Production>>,
    #[pyo3(get)]
    terminal: Option<Py<Terminal>>,
}

#[pyclass]
pub struct Production {
    #[pyo3(get)]
    id: usize,
    #[pyo3(get)]
    args: Py<PyList>,
    #[pyo3(get)]
    start: Option<u64>,
    #[pyo3(get)]
    end: Option<u64>,
}

#[pyclass]
pub struct Terminal {
    #[pyo3(get)]
    text: String,
    #[pyo3(get)]
    value: Py<PyAny>,
    #[pyo3(get)]
    start: u64,
    #[pyo3(get)]
    end: u64,
}

static PARSER_SPECS: OnceCell<(parser::Spec, Py<PyAny>)> = OnceCell::new();

fn downcast_tokens(
    py: Python,
    start_token_name: &str,
    token_list: Py<PyAny>,
) -> PyResult<Vec<parser::Terminal>> {
    let tokens = token_list.downcast_bound::<PyList>(py)?;

    let mut buf = Vec::with_capacity(tokens.len() + 1);
    buf.push(parser::Terminal::from_start_name(start_token_name));
    for token in tokens.iter() {
        let token: &Bound<OpaqueToken> = token.downcast()?;
        let token = token.borrow().inner.clone();

        buf.push(parser::Terminal::from_token(token));
    }

    Ok(buf)
}

fn get_spec() -> PyResult<&'static (parser::Spec, Py<PyAny>)> {
    if let Some(x) = PARSER_SPECS.get() {
        Ok(x)
    } else {
        Err(PyAssertionError::new_err(("grammar spec not loaded",)))
    }
}

/// Loads the grammar specification from file and caches it in memory.
#[pyfunction]
pub fn preload_spec(py: Python, spec_filepath: &Bound<PyString>) -> PyResult<()> {
    if PARSER_SPECS.get().is_some() {
        return Ok(());
    }

    let spec_filepath = spec_filepath.to_string();
    let bytes = std::fs::read(&spec_filepath)
        .unwrap_or_else(|e| panic!("Cannot read grammar spec from {spec_filepath} ({e})"));

    let spec: parser::Spec = bincode::deserialize::<parser::SpecSerializable>(&bytes)
        .map_err(|e| PyValueError::new_err(format!("Bad spec: {e}")))?
        .into();
    let productions = load_productions(py, &spec)?;

    let _ = PARSER_SPECS.set((spec, productions));
    Ok(())
}

/// Serialize the grammar specification and write it to a file.
///
/// Called from setup.py.
#[pyfunction]
pub fn save_spec(spec_json: &Bound<PyString>, dst: &Bound<PyString>) -> PyResult<()> {
    let spec_json = spec_json.to_string();
    let spec: parser::SpecSerializable = serde_json::from_str(&spec_json)
        .map_err(|e| PyValueError::new_err(format!("Invalid JSON: {e}")))?;
    let spec_bitcode = bincode::serialize(&spec)
        .map_err(|e| PyValueError::new_err(format!("Failed to pack spec: {e}")))?;

    let dst = dst.to_string();

    std::fs::write(dst, spec_bitcode).ok().unwrap();
    Ok(())
}

fn load_productions(py: Python<'_>, spec: &parser::Spec) -> PyResult<Py<PyAny>> {
    let grammar_name = "edb.edgeql.parser.grammar.start";
    let grammar_mod = py.import(grammar_name)?;
    let load_productions = py
        .import("edb.common.parsing")?
        .getattr("load_spec_productions")?;

    let productions = load_productions.call((&spec.production_names, grammar_mod), None)?;
    Ok(productions.into())
}

/// Newtype required to define a trait for a foreign type.
struct ParserCSTNode<'a>(&'a parser::CSTNode<'a>);

impl<'py> IntoPyObject<'py> for ParserCSTNode<'_> {
    type Target = CSTNode;
    type Output = Bound<'py, Self::Target>;
    type Error = PyErr;

    fn into_pyobject(self, py: Python<'py>) -> PyResult<Self::Output> {
        let res = match self.0 {
            parser::CSTNode::Empty => CSTNode {
                production: None,
                terminal: None,
            },
            parser::CSTNode::Terminal(token) => CSTNode {
                production: None,
                terminal: Some(Py::new(
                    py,
                    Terminal {
                        text: token.text.clone(),
                        value: (token.value.as_ref())
                            .map(TokenizerValue)
                            .into_pyobject(py)?
                            .unbind(),
                        start: token.span.start,
                        end: token.span.end,
                    },
                )?),
            },
            parser::CSTNode::Production(prod) => CSTNode {
                production: Some(Py::new(
                    py,
                    Production {
                        id: prod.id,
                        args: PyList::new(py, prod.args.iter().map(ParserCSTNode))?.into(),
                        start: prod.span.map(|s| s.start),
                        end: prod.span.map(|s| s.end),
                    },
                )?),
                terminal: None,
            },
        };
        Ok(Py::new(py, res)?.bind(py).clone())
    }
}
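
// A minimal sketch, not part of this crate, of the load-once pattern used by
// `preload_spec` and `get_spec` above: a process-wide cell is filled on the
// first call and every later lookup either returns the cached value or
// reports that nothing was loaded. The sketch assumes `std::sync::OnceLock`
// in place of `once_cell::sync::OnceCell`; the `Spec` struct and the
// `preload` helper are hypothetical stand-ins.

```rust
use std::sync::OnceLock;

/// Hypothetical stand-in for the deserialized grammar spec.
#[derive(Debug, PartialEq)]
struct Spec {
    production_names: Vec<String>,
}

static SPEC: OnceLock<Spec> = OnceLock::new();

/// Stores the spec on the first call; later calls are no-ops.
fn preload(names: &[&str]) {
    if SPEC.get().is_some() {
        return;
    }
    let spec = Spec {
        production_names: names.iter().map(|s| s.to_string()).collect(),
    };
    // `set` returns Err if another thread filled the cell first;
    // the value stored earlier is kept in that case.
    let _ = SPEC.set(spec);
}

/// Returns the cached spec, or an error if `preload` was never called.
fn get_spec() -> Result<&'static Spec, String> {
    SPEC.get()
        .ok_or_else(|| "grammar spec not loaded".to_string())
}
```

// Ignoring the result of `set` mirrors the original code: losing the race is
// harmless because both callers would have stored an equivalent value.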


================================================
FILE: edb/edgeql-parser/edgeql-parser-python/src/position.rs
================================================
use pyo3::{
    exceptions::{PyIndexError, PyRuntimeError},
    prelude::*,
    types::{PyBytes, PyList},
};

use edgeql_parser::position::InflatedPos;

#[pyclass]
pub struct SourcePoint {
    _position: InflatedPos,
}

#[pymethods]
impl SourcePoint {
    #[staticmethod]
    fn from_offsets(py: Python, data: &Bound<PyBytes>, offsets: Py<PyAny>) -> PyResult<Py<PyList>> {
        let mut list: Vec<usize> = offsets.extract(py)?;
        let data: &[u8] = data.as_bytes();
        list.sort();
        let result = InflatedPos::from_offsets(data, &list)
            .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;

        PyList::new(
            py,
            result
                .into_iter()
                .map(|_position| SourcePoint { _position }),
        )
        .map(|v| v.into())
    }

    #[staticmethod]
    fn from_lines_cols(
        py: Python,
        data: &Bound<PyBytes>,
        lines_cols: Py<PyAny>,
    ) -> PyResult<Py<PyList>> {
        let mut list: Vec<(u64, u64)> = lines_cols.extract(py)?;
        let data: &[u8] = data.as_bytes();
        list.sort();
        let result = InflatedPos::from_lines_cols(data, &list)
            .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;

        PyList::new(
            py,
            result
                .into_iter()
                .map(|_position| SourcePoint { _position }),
        )
        .map(|v| v.into())
    }

    #[getter]
    fn line(&self) -> u64 {
        self._position.line + 1
    }
    #[getter]
    fn zero_based_line(&self) -> u64 {
        self._position.line
    }
    #[getter]
    fn column(&self) -> u64 {
        self._position.column + 1
    }
    #[getter]
    fn utf16column(&self) -> u64 {
        self._position.utf16column
    }
    #[getter]
    fn offset(&self) -> u64 {
        self._position.offset
    }
    #[getter]
    fn char_offset(&self) -> u64 {
        self._position.char_offset
    }
}

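fn _offset_of_line(text: &str, target: usize) -> Option<usize> {
    // Despite its name, `was_lf` records that the previous byte was `\r`,
    // so that `\r\n` is counted as a single line break.
    let mut was_lf = false;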
    let mut line = 0; // this assumes line found by rfind
    for (idx, &byte) in text.as_bytes().iter().enumerate() {
        if line >= target {
            return Some(idx);
        }
        match byte {
            b'\n' => {
                line += 1;
                was_lf = false;
            }
            _ if was_lf => {
                line += 1;
                if line >= target {
                    return Some(idx);
                }
                was_lf = byte == b'\r';
            }
            b'\r' => {
                was_lf = true;
            }
            _ => {}
        }
    }
    if was_lf {
        line += 1;
    }
    if target > line {
        return None;
    }
    Some(text.len())
}

#[pyfunction]
pub fn offset_of_line(text: &str, target: usize) -> PyResult<usize> {
    match _offset_of_line(text, target) {
        Some(offset) => Ok(offset),
        None => Err(PyIndexError::new_err("line number is too large")),
    }
}

#[test]
fn line_offsets() {
    assert_eq!(_offset_of_line("line1\nline2\nline3", 0), Some(0));
    assert_eq!(_offset_of_line("line1\nline2\nline3", 1), Some(6));
    assert_eq!(_offset_of_line("line1\nline2\nline3", 2), Some(12));
    assert_eq!(_offset_of_line("line1\nline2\nline3", 3), None);
    assert_eq!(_offset_of_line("line1\rline2\rline3", 0), Some(0));
    assert_eq!(_offset_of_line("line1\rline2\rline3", 1), Some(6));
    assert_eq!(_offset_of_line("line1\rline2\rline3", 2), Some(12));
    assert_eq!(_offset_of_line("line1\rline2\rline3", 3), None);
    assert_eq!(_offset_of_line("line1\r\nline2\r\nline3", 0), Some(0));
    assert_eq!(_offset_of_line("line1\r\nline2\r\nline3", 1), Some(7));
    assert_eq!(_offset_of_line("line1\r\nline2\r\nline3", 2), Some(14));
    assert_eq!(_offset_of_line("line1\r\nline2\r\nline3", 3), None);
    assert_eq!(_offset_of_line("line1\rline2\r\nline3\n", 0), Some(0));
    assert_eq!(_offset_of_line("line1\rline2\r\nline3\n", 1), Some(6));
    assert_eq!(_offset_of_line("line1\rline2\r\nline3\n", 2), Some(13));
    assert_eq!(_offset_of_line("line1\rline2\r\nline3\n", 3), Some(19));
    assert_eq!(_offset_of_line("line1\rline2\r\nline3\n", 4), None);
    assert_eq!(_offset_of_line("line1\nline2\rline3\r\n", 0), Some(0));
    assert_eq!(_offset_of_line("line1\nline2\rline3\r\n", 1), Some(6));
    assert_eq!(_offset_of_line("line1\nline2\rline3\r\n", 2), Some(12));
    assert_eq!(_offset_of_line("line1\nline2\rline3\r\n", 3), Some(19));
    assert_eq!(_offset_of_line("line1\nline2\rline3\r\n", 4), None);
    assert_eq!(_offset_of_line("line1\n\rline2\r\rline3\r", 0), Some(0));
    assert_eq!(_offset_of_line("line1\n\rline2\r\rline3\r", 1), Some(6));
    assert_eq!(_offset_of_line("line1\n\rline2\r\rline3\r", 2), Some(7));
    assert_eq!(_offset_of_line("line1\n\rline2\r\rline3\r", 3), Some(13));
    assert_eq!(_offset_of_line("line1\n\rline2\r\rline3\r", 4), Some(14));
    assert_eq!(_offset_of_line("line1\n\rline2\r\rline3\r", 5), Some(20));
    assert_eq!(_offset_of_line("line1\n\rline2\r\rline3\r", 6), None);
}
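
// A minimal sketch, not part of this crate, of the same line-offset rule
// written with one byte of lookahead instead of a carried flag: `\n`, `\r\n`
// and a lone `\r` each end exactly one line. `line_start` is a hypothetical
// name; the expected values in its checks are taken from the `line_offsets`
// test above.

```rust
/// Byte offset at which zero-based line `target` starts, or `None`
/// if the text has fewer line breaks than `target`.
fn line_start(text: &str, target: usize) -> Option<usize> {
    if target == 0 {
        return Some(0);
    }
    let bytes = text.as_bytes();
    let mut line = 0;
    let mut i = 0;
    while i < bytes.len() {
        // Width of the line terminator starting at `i`, if any.
        let step = match bytes[i] {
            b'\r' if bytes.get(i + 1) == Some(&b'\n') => 2,
            b'\r' | b'\n' => 1,
            _ => {
                i += 1;
                continue;
            }
        };
        i += step;
        line += 1;
        if line == target {
            return Some(i);
        }
    }
    None
}
```

// A trailing terminator still opens a (possibly empty) final line, so the
// offset returned for it equals the length of the text.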


================================================
FILE: edb/edgeql-parser/edgeql-parser-python/src/pynormalize.rs
================================================
use std::convert::TryFrom;

use bigdecimal::Num;

use bytes::{BufMut, Bytes, BytesMut};
use edgeql_parser::tokenizer::Value;
use gel_protocol::codec;
use gel_protocol::model::{BigInt, Decimal};
use pyo3::exceptions::{PyAssertionError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict, PyFloat, PyInt, PyList, PyString};

use crate::errors::SyntaxError;
use crate::normalize::{normalize as _normalize, Error, PackedEntry, Variable};
use crate::tokenizer::tokens_to_py;

#[pyfunction]
pub fn normalize(py: Python<'_>, text: &Bound<PyString>) -> PyResult<Entry> {
    let text = text.to_string();
    match _normalize(&text) {
        Ok(entry) => Entry::new(py, entry),
        Err(Error::Tokenizer(msg, pos)) => Err(SyntaxError::new_err((
            msg,
            (pos, py.None()),
            py.None(),
            py.None(),
        ))),
        Err(Error::Assertion(msg, pos)) => Err(PyAssertionError::new_err(format!("{pos}: {msg}"))),
    }
}

#[pyclass]
pub struct Entry {
    #[pyo3(get)]
    key: Py<PyBytes>,

    #[pyo3(get)]
    tokens: Py<PyAny>,

    #[pyo3(get)]
    extra_blobs: Py<PyList>,

    extra_named: bool,

    #[pyo3(get)]
    first_extra: Option<usize>,

    #[pyo3(get)]
    extra_counts: Py<PyList>,

    entry_pack: PackedEntry,
}

impl Entry {
    pub fn new(py: Python, entry: crate::normalize::Entry) -> PyResult<Entry> {
        let blobs = serialize_all(py, &entry.variables)?;
        let counts = entry.variables.iter().map(|x| x.len());

        Ok(Entry {
            key: PyBytes::new(py, &entry.hash[..]).into(),
            tokens: tokens_to_py(py, entry.tokens.clone())?.into_any(),
            extra_blobs: blobs.into(),
            extra_named: entry.named_args,
            first_extra: entry.first_arg,
            extra_counts: PyList::new(py, counts)?.into(),
            entry_pack: entry.into(),
        })
    }
}

#[pymethods]
impl Entry {
    fn get_variables(&self, py: Python) -> PyResult<Py<PyDict>> {
        let vars = PyDict::new(py);
        let first = match self.first_extra {
            Some(first) => first,
            None => return Ok(vars.into()),
        };
        for (idx, var) in self.entry_pack.variables.iter().flatten().enumerate() {
            let s = if self.extra_named {
                format!("__edb_arg_{}", first + idx)
            } else {
                (first + idx).to_string()
            };
            vars.set_item(s, TokenizerValue(&var.value))?;
        }

        Ok(vars.into())
    }

    fn pack(&self, py: Python) -> PyResult<Py<PyBytes>> {
        let mut buf = vec![1u8]; // type and version
        bincode::serialize_into(&mut buf, &self.entry_pack)
            .map_err(|e| PyValueError::new_err(format!("Failed to pack: {e}")))?;
        Ok(PyBytes::new(py, buf.as_slice()).into())
    }
}

pub fn serialize_extra(variables: &[Variable]) -> Result<Bytes, String> {
    use gel_protocol::codec::Codec;
    use gel_protocol::value::Value as P;

    let mut buf = BytesMut::new();
    buf.reserve(4 * variables.len());
    for var in variables {
        buf.reserve(4);
        let pos = buf.len();
        buf.put_u32(0); // replaced after serializing a value
        match var.value {
            Value::Int(v) => {
                codec::Int64
                    .encode(&mut buf, &P::Int64(v))
                    .map_err(|e| format!("int cannot be encoded: {e}"))?;
            }
            Value::String(ref v) => {
                codec::Str
                    .encode(&mut buf, &P::Str(v.clone()))
                    .map_err(|e| format!("str cannot be encoded: {e}"))?;
            }
            Value::Float(ref v) => {
                codec::Float64
                    .encode(&mut buf, &P::Float64(*v))
                    .map_err(|e| format!("float cannot be encoded: {e}"))?;
            }
            Value::BigInt(ref v) => {
                // We have two different versions of BigInt implementations here.
                // We have to use bigdecimal::num_bigint::BigInt because it can parse with radix 16.

                let val = bigdecimal::num_bigint::BigInt::from_str_radix(v, 16)
                    .map_err(|e| format!("bigint cannot be encoded: {e}"))
                    .and_then(|x| {
                        BigInt::try_from(x).map_err(|e| format!("bigint cannot be encoded: {e}"))
                    })?;

                codec::BigInt
                    .encode(&mut buf, &P::BigInt(val))
                    .map_err(|e| format!("bigint cannot be encoded: {e}"))?;
            }
            Value::Decimal(ref v) => {
                let val = Decimal::try_from(v.clone())
                    .map_err(|e| format!("decimal cannot be encoded: {e}"))?;
                codec::Decimal
                    .encode(&mut buf, &P::Decimal(val))
                    .map_err(|e| format!("decimal cannot be encoded: {e}"))?;
            }
            Value::Bytes(_) => {
                // bytes literals should not be extracted during normalization
                unreachable!()
            }
        }
        let len = buf.len() - pos - 4;
        buf[pos..pos + 4].copy_from_slice(
            &u32::try_from(len)
                .map_err(|_| "element is too long".to_owned())?
                .to_be_bytes(),
        );
    }
    Ok(buf.freeze())
}

pub fn serialize_all<'a>(
    py: Python<'a>,
    variables: &[Vec<Variable>],
) -> PyResult<Bound<'a, PyList>> {
    let mut buf = Vec::with_capacity(variables.len());
    for vars in variables {
        let bytes = serialize_extra(vars).map_err(PyAssertionError::new_err)?;
        buf.push(PyBytes::new(py, &bytes));
    }
    PyList::new(py, &buf)
}

/// Newtype required to define a trait for a foreign type.
pub struct TokenizerValue<'a>(pub &'a Value);

impl<'py> IntoPyObject<'py> for TokenizerValue<'py> {
    type Target = PyAny;
    type Output = Bound<'py, Self::Target>;
    type Error = PyErr;

    fn into_pyobject(self, py: Python<'py>) -> PyResult<Self::Output> {
        let res = match self.0 {
            Value::Int(v) => v.into_pyobject(py)?.into_any(),
            Value::String(v) => v.into_pyobject(py)?.into_any(),
            Value::Float(v) => v.into_pyobject(py)?.into_any(),
            Value::BigInt(v) => py.get_type::<PyInt>().call((v, 16), None)?,
            Value::Decimal(v) => py
                .get_type::<PyFloat>()
                .call((v.to_string(),), None)?
                .into_any(),
            Value::Bytes(v) => PyBytes::new(py, v).into_any(),
        };
        Ok(res)
    }
}
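
// A minimal sketch, not part of this crate, of the length-prefix technique
// used by `serialize_extra` above: reserve four bytes, append the payload,
// then patch the reserved bytes with the payload length as a big-endian u32.
// The sketch assumes plain `Vec<u8>` instead of `BytesMut`, and raw string
// bytes stand in for the output of a codec's `encode` call;
// `write_length_prefixed` is a hypothetical name.

```rust
use std::convert::TryFrom;

/// Appends every payload to one buffer, each preceded by its
/// length encoded as a big-endian u32.
fn write_length_prefixed(payloads: &[&str]) -> Result<Vec<u8>, String> {
    let mut buf = Vec::new();
    for payload in payloads {
        let pos = buf.len();
        buf.extend_from_slice(&[0u8; 4]); // placeholder, patched below
        buf.extend_from_slice(payload.as_bytes()); // stands in for `encode`
        let len = u32::try_from(buf.len() - pos - 4)
            .map_err(|_| "element is too long".to_owned())?;
        buf[pos..pos + 4].copy_from_slice(&len.to_be_bytes());
    }
    Ok(buf)
}
```

// Writing the placeholder first lets the encoder append directly to the
// shared buffer, so the length never has to be computed ahead of time.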


================================================
FILE: edb/edgeql-parser/edgeql-parser-python/src/tokenizer.rs
================================================
use edgeql_parser::tokenizer::{Kind, Token, Tokenizer};
use once_cell::sync::OnceCell;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyList, PyString};

use crate::errors::{parser_error_into_tuple, ParserResult};

#[pyfunction]
pub fn tokenize(py: Python, s: &Bound<PyString>) -> PyResult<ParserResult> {
    let data = s.to_string();

    let token_stream = Tokenizer::new(&data[..]).validated_values().with_eof();

    let mut tokens = vec![];
    let mut errors = vec![];

    for res in token_stream.into_iter() {
        match res {
            Ok(token) => tokens.push(token),
            Err(e) => {
                errors.push(parser_error_into_tuple(&e).into_pyobject(py)?);

                // TODO: fix tokenizer to skip bad tokens and continue
                break;
            }
        }
    }

    let out = tokens_to_py(py, tokens)?.into_pyobject(py)?.into();
    let errors = PyList::new(py, errors)?.into();

    Ok(ParserResult { out, errors })
}

// An opaque wrapper around [edgeql_parser::tokenizer::Token].
// Supports Python pickle serialization.
#[pyclass]
pub struct OpaqueToken {
    pub inner: Token<'static>,
}

#[pymethods]
impl OpaqueToken {
    fn __repr__(&self) -> PyResult<String> {
        Ok(self.inner.to_string())
    }
    fn __reduce__(&self, py: Python) -> PyResult<(Py<PyAny>, (Py<PyAny>,))> {
        let data = bincode::serialize(&self.inner)
            .map_err(|e| PyValueError::new_err(format!("Failed to reduce: {e}")))?;

        let tok = get_unpickle_token_fn(py);
        Ok((tok, (PyBytes::new(py, &data).into(),)))
    }

    fn span_start(&self) -> u64 {
        self.inner.span.start
    }

    fn span_end(&self) -> u64 {
        self.inner.span.end
    }

    fn is_ident(&self) -> bool {
        matches!(self.inner.kind, Kind::Ident)
    }
}

pub fn tokens_to_py(py: Python<'_>, rust_tokens: Vec<Token>) -> PyResult<Py<PyList>> {
    Ok(PyList::new(
        py,
        rust_tokens.into_iter().map(|tok| OpaqueToken {
            inner: tok.cloned(),
        }),
    )?
    .unbind())
}

/// To support pickle serialization of OpaqueTokens, we need to provide a
/// deserialization function in the `__reduce__` method.
/// This function must not be inlined and must be globally accessible.
/// To achieve this, we expose it as part of the module definition
/// (`unpickle_token`) and save a reference to it in `FN_UNPICKLE_TOKEN`.
///
/// A bit hacky, but it works.
static FN_UNPICKLE_TOKEN: OnceCell<Py<PyAny>> = OnceCell::new();

pub fn fini_module(m: &Bound<PyModule>) {
    let _unpickle_token = m.getattr("unpickle_token").unwrap();
    FN_UNPICKLE_TOKEN
        .set(_unpickle_token.unbind())
        .expect("module is already initialized");
}

#[pyfunction]
pub fn unpickle_token(bytes: &Bound<PyBytes>) -> PyResult<OpaqueToken> {
    let token = bincode::deserialize(bytes.as_bytes())
        .map_err(|e| PyValueError::new_err(format!("Failed to read token: {e}")))?;
    Ok(OpaqueToken { inner: token })
}

fn get_unpickle_token_fn(py: Python) -> Py<PyAny> {
    let py_function = FN_UNPICKLE_TOKEN.get().expect("module uninitialized");
    py_function.clone_ref(py)
}


================================================
FILE: edb/edgeql-parser/edgeql-parser-python/src/unpack.rs
================================================
use edgeql_parser::tokenizer::Token;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyBytes;

use crate::normalize::PackedEntry;
use crate::pynormalize::Entry;
use crate::tokenizer::tokens_to_py;

#[pyfunction]
pub fn unpack(py: Python<'_>, serialized: &Bound<PyBytes>) -> PyResult<Py<PyAny>> {
    let buf = serialized.as_bytes();
    // The first byte is a type/version tag. Using `first()` instead of
    // `buf[0]` turns empty input into a ValueError rather than a panic.
    match buf.first() {
        Some(0u8) => {
            let tokens: Vec<Token> = bincode::deserialize(&buf[1..])
                .map_err(|e| PyValueError::new_err(format!("{e}")))?;
            Ok(tokens_to_py(py, tokens)?.into_any())
        }
        Some(1u8) => {
            let pack: PackedEntry = bincode::deserialize(&buf[1..])
                .map_err(|e| PyValueError::new_err(format!("Failed to unpack: {e}")))?;
            let entry = Entry::new(py, pack.into())?;
            entry.into_pyobject(py).map(|e| e.unbind().into_any())
        }
        Some(other) => Err(PyValueError::new_err(format!(
            "Invalid type/version byte: {other}"
        ))),
        None => Err(PyValueError::new_err("Cannot unpack empty data")),
    }
}
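
// A minimal sketch, not part of this crate, of the tag-byte dispatch that
// `unpack` performs: the first byte selects the payload kind (0 for a bare
// token list, 1 for a packed entry) and the rest is handed to the matching
// decoder. The `Unpacked` enum and `split_tagged` are hypothetical; the
// sketch returns the undecoded payload instead of calling bincode.

```rust
/// Which decoder the payload should be routed to.
#[derive(Debug, PartialEq)]
enum Unpacked<'a> {
    Tokens(&'a [u8]),
    Entry(&'a [u8]),
}

/// Splits off the leading type/version byte and classifies the rest.
fn split_tagged(buf: &[u8]) -> Result<Unpacked<'_>, String> {
    match buf.split_first() {
        Some((0, rest)) => Ok(Unpacked::Tokens(rest)),
        Some((1, rest)) => Ok(Unpacked::Entry(rest)),
        Some((tag, _)) => Err(format!("Invalid type/version byte: {tag}")),
        None => Err("empty buffer".to_owned()),
    }
}
```

// `split_first` covers the empty-input case in the same match, so no index
// expression can go out of bounds.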


================================================
FILE: edb/edgeql-parser/edgeql-parser-python/tests/normalize.rs
================================================
#![cfg(feature = "python_extension")]
use edgeql_parser::tokenizer::Value;
use edgeql_rust::normalize::{normalize, Variable};
use num_bigint::BigInt;

#[test]
fn test_verbatim() {
    let entry = normalize(
        r###"
        SELECT $1 + $2
    "###,
    )
    .unwrap();
    assert_eq!(entry.processed_source, "SELECT$1+$2");
    assert_eq!(entry.variables, vec![vec![]]);
}

#[test]
fn test_configure() {
    let entry = normalize(
        r###"
        CONFIGURE INSTANCE SET some_setting := 7
    "###,
    )
    .unwrap();
    assert_eq!(
        entry.processed_source,
        "CONFIGURE INSTANCE SET some_setting:=7"
    );
    assert_eq!(entry.variables, vec![] as Vec<Vec<Variable>>);
}

#[test]
fn test_int() {
    let entry = normalize(
        r###"
        SELECT 1 + 2
    "###,
    )
    .unwrap();
    assert_eq!(entry.processed_source, "SELECT $0+$1");
    assert_eq!(
        entry.variables,
        vec![vec![
            Variable {
                value: Value::Int(1),
            },
            Variable {
                value: Value::Int(2),
            }
        ]]
    );
}

#[test]
fn test_str() {
    let entry = normalize(
        r#"
        SELECT "x" + "yy"
    "#,
    )
    .unwrap();
    assert_eq!(entry.processed_source, "SELECT $0+$1");
    assert_eq!(
        entry.variables,
        vec![vec![
            Variable {
                value: Value::String("x".into()),
            },
            Variable {
                value: Value::String("yy".into()),
            }
        ]]
    );
}

#[test]
fn test_float() {
    let entry = normalize(
        r###"
        SELECT 1.5 + 23.25
    "###,
    )
    .unwrap();
    assert_eq!(
        entry.processed_source,
        "SELECT $0+$1"
    );
    assert_eq!(
        entry.variables,
        vec![vec![
            Variable {
                value: Value::Float(1.5),
            },
            Variable {
                value: Value::Float(23.25),
            }
        ]]
    );
}

#[test]
fn test_bigint() {
    let entry = normalize(
        r###"
        SELECT 1n + 23n
    "###,
    )
    .unwrap();
    assert_eq!(
        entry.processed_source,
        "SELECT $0+$1"
    );
    assert_eq!(
        entry.variables,
        vec![vec![
            Variable {
                value: Value::BigInt("1".into()),
            },
            Variable {
                value: Value::BigInt(BigInt::from(23).to_str_radix(16)),
            }
        ]]
    );
}

#[test]
fn test_bigint_exponent() {
    let entry = normalize(
        r###"
        SELECT 1e10n + 23e13n
    "###,
    )
    .unwrap();
    assert_eq!(
        entry.processed_source,
        "SELECT $0+$1"
    );
    assert_eq!(
        entry.variables,
        vec![vec![
            Variable {
                value: Value::BigInt(BigInt::from(10000000000u64).to_str_radix(16)),
            },
            Variable {
                value: Value::BigInt(BigInt::from(230000000000000u64).to_str_radix(16)),
            }
        ]]
    );
}

#[test]
fn test_decimal() {
    let entry = normalize(
        r###"
        SELECT 1.33n + 23.77n
    "###,
    )
    .unwrap();
    assert_eq!(
        entry.processed_source,
        "SELECT $0+$1"
    );
    assert_eq!(
        entry.variables,
        vec![vec![
            Variable {
                value: Value::Decimal("1.33".parse().unwrap()),
            },
            Variable {
                value: Value::Decimal("23.77".parse().unwrap()),
            }
        ]]
    );
}

#[test]
fn test_positional() {
    let entry = normalize(
        r###"
        SELECT $0 + 2
    "###,
    )
    .unwrap();
    assert_eq!(entry.processed_source, "SELECT$0+$1");
    assert_eq!(
        entry.variables,
        vec![vec![Variable {
            value: Value::Int(2),
        }]]
    );
}

#[test]
fn test_named() {
    let entry = normalize(
        r###"
        SELECT $test_var + 2
    "###,
    )
    .unwrap();
    assert_eq!(
        entry.processed_source,
        "SELECT$test_var+$__edb_arg_1"
    );
    assert_eq!(
        entry.variables,
        vec![vec![Variable {
            value: Value::Int(2),
        }]]
    );
}

#[test]
fn test_limit_1() {
    let entry = normalize(
        r###"
        SELECT User { one := 1 } LIMIT 1
    "###,
    )
    .unwrap();
    assert_eq!(
        entry.processed_source,
        "SELECT User{one:=$0}LIMIT 1"
    );
    assert_eq!(
        entry.variables,
        vec![vec![Variable {
            value: Value::Int(1),
        },]]
    );
}

#[test]
fn test_tuple_access() {
    let entry = normalize(
        r###"
        SELECT User { one := 2, two := .field.2, three := .field  . 3 }
    "###,
    )
    .unwrap();
    assert_eq!(
        entry.processed_source,
        "SELECT User{one:=$0,\
                     two:=.field.2,three:=.field.3}"
    );
    assert_eq!(
        entry.variables,
        vec![vec![Variable {
            value: Value::Int(2),
        },]]
    );
}

#[test]
fn test_script() {
    let entry = normalize(
        r###"
        SELECT 1 + 2;
        SELECT 2;
    "###,
    )
    .unwrap();
    assert_eq!(
        entry.processed_source,
        "SELECT $0+$1;\
        SELECT $2;",
    );
    assert_eq!(
        entry.variables,
        vec![
            vec![
                Variable {
                    value: Value::Int(1),
                },
                Variable {
                    value: Value::Int(2),
                }
            ],
            vec![Variable {
                value: Value::Int(2),
            }],
            vec![]
        ]
    );
}

#[test]
fn test_script_with_args() {
    let entry = normalize(
        r###"
        SELECT 2 + $1;
        SELECT $1 + 2;
    "###,
    )
    .unwrap();
    assert_eq!(
        entry.processed_source,
        "SELECT $2+$1;SELECT$1+$3;",
    );
    assert_eq!(
        entry.variables,
        vec![
            vec![Variable {
                value: Value::Int(2),
            }],
            vec![Variable {
                value: Value::Int(2),
            }],
            vec![]
        ]
    );
}


================================================
FILE: edb/edgeql-parser/src/ast.rs
================================================
// DO NOT EDIT. This file was generated with:
//
// $ edb gen-rust-ast

//! Abstract Syntax Tree for EdgeQL
#![allow(non_camel_case_types)]
#![cfg(never)] // TODO: migrate cpython-rust to pyo3

use indexmap::IndexMap;

#[cfg(feature = "python")]
use edgeql_parser_derive::IntoPython;

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct OptionValue {
    pub name: String,
    #[cfg_attr(feature = "python", py_child)]
    pub kind: Option,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum OptionValueKind {
    OptionFlag(OptionFlag),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct OptionFlag {
    pub val: bool,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct Options {
    pub options: IndexMap,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct Expr {
    #[cfg_attr(feature = "python", py_child)]
    pub kind: ExprKind,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum ExprKind {
    Placeholder(Placeholder),
    Anchor(Anchor),
    DetachedExpr(DetachedExpr),
    GlobalExpr(GlobalExpr),
    Indirection(Indirection),
    BinOp(BinOp),
    FunctionCall(FunctionCall),
    BaseConstant(BaseConstant),
    Parameter(Parameter),
    UnaryOp(UnaryOp),
    IsOp(IsOp),
    Path(Path),
    TypeCast(TypeCast),
    Introspect(Introspect),
    IfElse(IfElse),
    NamedTuple(NamedTuple),
    Tuple(Tuple),
    Array(Array),
    Set(Set),
    ShapeElement(ShapeElement),
    Shape(Shape),
    Query(Query),
    ConfigOp(ConfigOp),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct Placeholder {
    pub name: String,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct SortExpr {
    pub path: Box,
    pub direction: Option,
    pub nones_order: Option,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AliasedExpr {
    pub alias: String,
    pub expr: Box,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct ModuleAliasDecl {
    pub module: String,
    pub alias: Option,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct BaseObjectRef {
    #[cfg_attr(feature = "python", py_child)]
    pub kind: BaseObjectRefKind,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum BaseObjectRefKind {
    ObjectRef(ObjectRef),
    PseudoObjectRef(PseudoObjectRef),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct ObjectRef {
    pub name: String,
    pub module: Option,
    pub itemclass: Option,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct PseudoObjectRef {
    #[cfg_attr(feature = "python", py_child)]
    pub kind: PseudoObjectRefKind,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum PseudoObjectRefKind {
    AnyType(AnyType),
    AnyTuple(AnyTuple),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AnyType {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AnyTuple {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct Anchor {
    pub name: String,
    #[cfg_attr(feature = "python", py_child)]
    pub kind: AnchorKind,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum AnchorKind {
    SpecialAnchor(SpecialAnchor),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct SpecialAnchor {
    #[cfg_attr(feature = "python", py_child)]
    pub kind: Option,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum SpecialAnchorKind {
    Source(Source),
    Subject(Subject),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct Source {
    pub name: String,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct Subject {
    pub name: String,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DetachedExpr {
    pub expr: Box,
    pub preserve_path_prefix: bool,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct GlobalExpr {
    pub name: ObjectRef,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct Index {
    pub index: Box,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct Slice {
    pub start: Option>,
    pub stop: Option>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct Indirection {
    pub arg: Box,
    pub indirection: Vec<IndirectionIndirection>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_union)]
pub enum IndirectionIndirection {
    Index(Index),
    Slice(Slice),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct BinOp {
    pub left: Box,
    pub op: String,
    pub right: Box,
    pub rebalanced: bool,
    #[cfg_attr(feature = "python", py_child)]
    pub kind: Option<BinOpKind>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum BinOpKind {
    SetConstructorOp(SetConstructorOp),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct SetConstructorOp {
    pub op: String,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct WindowSpec {
    pub orderby: Vec,
    pub partition: Vec>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct FunctionCall {
    pub func: FunctionCallFunc,
    pub args: Vec>,
    pub kwargs: IndexMap>,
    pub window: Option,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_union)]
pub enum FunctionCallFunc {
    Tuple((String, String)),
    str(String),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct BaseConstant {
    pub value: String,
    #[cfg_attr(feature = "python", py_child)]
    pub kind: BaseConstantKind,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum BaseConstantKind {
    StringConstant(StringConstant),
    BaseRealConstant(BaseRealConstant),
    BooleanConstant(BooleanConstant),
    BytesConstant(BytesConstant),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct StringConstant {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct BaseRealConstant {
    pub is_negative: bool,
    #[cfg_attr(feature = "python", py_child)]
    pub kind: BaseRealConstantKind,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum BaseRealConstantKind {
    IntegerConstant(IntegerConstant),
    FloatConstant(FloatConstant),
    BigintConstant(BigintConstant),
    DecimalConstant(DecimalConstant),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct IntegerConstant {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct FloatConstant {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct BigintConstant {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DecimalConstant {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct BooleanConstant {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct BytesConstant {
    pub value: Vec<u8>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct Parameter {
    pub name: String,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct UnaryOp {
    pub op: String,
    pub operand: Box,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct TypeExpr {
    pub name: Option,
    #[cfg_attr(feature = "python", py_child)]
    pub kind: Option<TypeExprKind>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum TypeExprKind {
    TypeOf(TypeOf),
    TypeExprLiteral(TypeExprLiteral),
    TypeName(TypeName),
    TypeOp(TypeOp),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct TypeOf {
    pub expr: Box,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct TypeExprLiteral {
    pub val: BaseConstant,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct TypeName {
    pub maintype: BaseObjectRef,
    pub subtypes: Option>,
    pub dimensions: Option>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct TypeOp {
    pub left: Box,
    pub op: String,
    pub right: Box,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct FuncParam {
    pub name: String,
    pub r#type: TypeExpr,
    pub typemod: TypeModifier,
    pub kind: ParameterKind,
    pub default: Option>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct IsOp {
    pub left: Box,
    pub op: String,
    pub right: TypeExpr,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct TypeIntersection {
    pub r#type: TypeExpr,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct Ptr {
    pub ptr: ObjectRef,
    pub direction: Option,
    pub r#type: Option,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct Splat {
    pub depth: i64,
    pub r#type: Option,
    pub intersection: Option,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct Path {
    pub steps: Vec<PathSteps>,
    pub partial: bool,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_union)]
pub enum PathSteps {
    Expr(Box),
    Ptr(Ptr),
    TypeIntersection(TypeIntersection),
    ObjectRef(ObjectRef),
    Splat(Splat),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct TypeCast {
    pub expr: Box,
    pub r#type: TypeExpr,
    pub cardinality_mod: Option,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct Introspect {
    pub r#type: TypeExpr,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct IfElse {
    pub condition: Box,
    pub if_expr: Box,
    pub else_expr: Box,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct TupleElement {
    pub name: ObjectRef,
    pub val: Box,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct NamedTuple {
    pub elements: Vec,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct Tuple {
    pub elements: Vec>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct Array {
    pub elements: Vec>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct Set {
    pub elements: Vec>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct Command {
    pub aliases: Option<Vec<CommandAliases>>,
    #[cfg_attr(feature = "python", py_child)]
    pub kind: CommandKind,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_union)]
pub enum CommandAliases {
    AliasedExpr(AliasedExpr),
    ModuleAliasDecl(ModuleAliasDecl),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum CommandKind {
    SessionSetAliasDecl(SessionSetAliasDecl),
    SessionResetAliasDecl(SessionResetAliasDecl),
    SessionResetModule(SessionResetModule),
    SessionResetAllAliases(SessionResetAllAliases),
    DDLCommand(DDLCommand),
    DescribeStmt(DescribeStmt),
    ExplainStmt(ExplainStmt),
    AdministerStmt(AdministerStmt),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct SessionSetAliasDecl {
    pub decl: ModuleAliasDecl,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct SessionResetAliasDecl {
    pub alias: String,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct SessionResetModule {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct SessionResetAllAliases {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct ShapeOperation {
    pub op: ShapeOp,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct ShapeElement {
    pub expr: Path,
    pub elements: Option>,
    pub compexpr: Option>,
    pub cardinality: Option,
    pub required: Option,
    pub operation: ShapeOperation,
    pub origin: ShapeOrigin,
    pub r#where: Option>,
    pub orderby: Option>,
    pub offset: Option>,
    pub limit: Option>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct Shape {
    pub expr: Option>,
    pub elements: Vec,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct Query {
    pub aliases: Option<Vec<QueryAliases>>,
    #[cfg_attr(feature = "python", py_child)]
    pub kind: QueryKind,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_union)]
pub enum QueryAliases {
    AliasedExpr(AliasedExpr),
    ModuleAliasDecl(ModuleAliasDecl),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum QueryKind {
    PipelinedQuery(PipelinedQuery),
    GroupQuery(GroupQuery),
    InsertQuery(InsertQuery),
    UpdateQuery(UpdateQuery),
    ForQuery(ForQuery),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct PipelinedQuery {
    pub implicit: bool,
    pub r#where: Option>,
    pub orderby: Option>,
    pub offset: Option>,
    pub limit: Option>,
    pub rptr_passthrough: bool,
    #[cfg_attr(feature = "python", py_child)]
    pub kind: PipelinedQueryKind,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum PipelinedQueryKind {
    SelectQuery(SelectQuery),
    DeleteQuery(DeleteQuery),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct SelectQuery {
    pub result_alias: Option,
    pub result: Box,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct GroupingIdentList {
    pub elements: Vec<GroupingIdentListElements>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_union)]
pub enum GroupingIdentListElements {
    ObjectRef(ObjectRef),
    Path(Path),
    GroupingIdentList(GroupingIdentList),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct GroupingElement {
    #[cfg_attr(feature = "python", py_child)]
    pub kind: GroupingElementKind,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum GroupingElementKind {
    GroupingSimple(GroupingSimple),
    GroupingSets(GroupingSets),
    GroupingOperation(GroupingOperation),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct GroupingSimple {
    pub element: GroupingSimpleElement,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_union)]
pub enum GroupingSimpleElement {
    ObjectRef(ObjectRef),
    Path(Path),
    GroupingIdentList(GroupingIdentList),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct GroupingSets {
    pub sets: Vec,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct GroupingOperation {
    pub oper: String,
    pub elements: Vec<GroupingOperationElements>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_union)]
pub enum GroupingOperationElements {
    ObjectRef(ObjectRef),
    Path(Path),
    GroupingIdentList(GroupingIdentList),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct GroupQuery {
    pub subject_alias: Option,
    pub using: Option>,
    pub by: Vec,
    pub subject: Box,
    #[cfg_attr(feature = "python", py_child)]
    pub kind: Option<GroupQueryKind>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum GroupQueryKind {
    InternalGroupQuery(InternalGroupQuery),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct InternalGroupQuery {
    pub group_alias: String,
    pub grouping_alias: Option,
    pub from_desugaring: bool,
    pub result_alias: Option,
    pub result: Box,
    pub r#where: Option>,
    pub orderby: Option>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct InsertQuery {
    pub subject: ObjectRef,
    pub shape: Vec,
    pub unless_conflict: Option<(Option>, Option>)>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct UpdateQuery {
    pub shape: Vec,
    pub subject: Box,
    pub r#where: Option>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DeleteQuery {
    pub subject: Box,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct ForQuery {
    pub iterator: Box,
    pub iterator_alias: String,
    pub result_alias: Option,
    pub result: Box,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct Transaction {
    #[cfg_attr(feature = "python", py_child)]
    pub kind: TransactionKind,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum TransactionKind {
    StartTransaction(StartTransaction),
    CommitTransaction(CommitTransaction),
    RollbackTransaction(RollbackTransaction),
    DeclareSavepoint(DeclareSavepoint),
    RollbackToSavepoint(RollbackToSavepoint),
    ReleaseSavepoint(ReleaseSavepoint),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct StartTransaction {
    pub isolation: Option,
    pub access: Option,
    pub deferrable: Option,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CommitTransaction {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct RollbackTransaction {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DeclareSavepoint {
    pub name: String,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct RollbackToSavepoint {
    pub name: String,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct ReleaseSavepoint {
    pub name: String,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct Position {
    pub r#ref: Option,
    pub position: String,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DDLOperation {
    pub commands: Vec,
    #[cfg_attr(feature = "python", py_child)]
    pub kind: DDLOperationKind,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum DDLOperationKind {
    DDLCommand(DDLCommand),
    AlterAddInherit(AlterAddInherit),
    AlterDropInherit(AlterDropInherit),
    OnTargetDelete(OnTargetDelete),
    OnSourceDelete(OnSourceDelete),
    SetField(SetField),
    SetAccessPerms(SetAccessPerms),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DDLCommand {
    #[cfg_attr(feature = "python", py_child)]
    pub kind: DDLCommandKind,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum DDLCommandKind {
    NamedDDL(NamedDDL),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterAddInherit {
    pub position: Option,
    pub bases: Vec,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterDropInherit {
    pub bases: Vec,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct OnTargetDelete {
    pub cascade: Option,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct OnSourceDelete {
    pub cascade: Option,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct SetField {
    pub name: String,
    pub value: SetFieldValue,
    pub special_syntax: bool,
    #[cfg_attr(feature = "python", py_child)]
    pub kind: Option<SetFieldKind>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_union)]
pub enum SetFieldValue {
    Expr(Box),
    TypeExpr(TypeExpr),
    NoneType(()),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum SetFieldKind {
    SetPointerType(SetPointerType),
    SetPointerCardinality(SetPointerCardinality),
    SetPointerOptionality(SetPointerOptionality),
    SetGlobalType(SetGlobalType),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct SetPointerType {
    pub name: String,
    pub value: Option,
    pub special_syntax: bool,
    pub cast_expr: Option>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct SetPointerCardinality {
    pub name: String,
    pub special_syntax: bool,
    pub conv_expr: Option>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct SetPointerOptionality {
    pub name: String,
    pub special_syntax: bool,
    pub fill_expr: Option>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct NamedDDL {
    pub name: ObjectRef,
    #[cfg_attr(feature = "python", py_child)]
    pub kind: NamedDDLKind,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum NamedDDLKind {
    ObjectDDL(ObjectDDL),
    Rename(Rename),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct ObjectDDL {
    #[cfg_attr(feature = "python", py_child)]
    pub kind: ObjectDDLKind,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum ObjectDDLKind {
    CreateObject(CreateObject),
    AlterObject(AlterObject),
    DropObject(DropObject),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateObject {
    pub r#abstract: bool,
    pub sdl_alter_if_exists: bool,
    pub create_if_not_exists: bool,
    #[cfg_attr(feature = "python", py_child)]
    pub kind: Option<CreateObjectKind>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum CreateObjectKind {
    CreateExtendingObject(CreateExtendingObject),
    CreateMigration(CreateMigration),
    CreateDatabase(CreateDatabase),
    CreateExtensionPackage(CreateExtensionPackage),
    CreateExtension(CreateExtension),
    CreateFuture(CreateFuture),
    CreateModule(CreateModule),
    CreateRole(CreateRole),
    CreatePseudoType(CreatePseudoType),
    CreateConcretePointer(CreateConcretePointer),
    CreateAlias(CreateAlias),
    CreateGlobal(CreateGlobal),
    CreatePermission(CreatePermission),
    CreateConcreteConstraint(CreateConcreteConstraint),
    CreateConcreteIndex(CreateConcreteIndex),
    CreateAnnotationValue(CreateAnnotationValue),
    CreateAccessPolicy(CreateAccessPolicy),
    CreateTrigger(CreateTrigger),
    CreateRewrite(CreateRewrite),
    CreateFunction(CreateFunction),
    CreateOperator(CreateOperator),
    CreateCast(CreateCast),
    CreateIndexMatch(CreateIndexMatch),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterObject {
    #[cfg_attr(feature = "python", py_child)]
    pub kind: Option<AlterObjectKind>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum AlterObjectKind {
    AlterMigration(AlterMigration),
    AlterDatabase(AlterDatabase),
    AlterModule(AlterModule),
    AlterRole(AlterRole),
    AlterAnnotation(AlterAnnotation),
    AlterScalarType(AlterScalarType),
    AlterProperty(AlterProperty),
    AlterConcreteProperty(AlterConcreteProperty),
    AlterObjectType(AlterObjectType),
    AlterAlias(AlterAlias),
    AlterGlobal(AlterGlobal),
    AlterPermission(AlterPermission),
    AlterLink(AlterLink),
    AlterConcreteLink(AlterConcreteLink),
    AlterConstraint(AlterConstraint),
    AlterConcreteConstraint(AlterConcreteConstraint),
    AlterIndex(AlterIndex),
    AlterConcreteIndex(AlterConcreteIndex),
    AlterAnnotationValue(AlterAnnotationValue),
    AlterAccessPolicy(AlterAccessPolicy),
    AlterTrigger(AlterTrigger),
    AlterRewrite(AlterRewrite),
    AlterFunction(AlterFunction),
    AlterOperator(AlterOperator),
    AlterCast(AlterCast),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropObject {
    #[cfg_attr(feature = "python", py_child)]
    pub kind: Option<DropObjectKind>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum DropObjectKind {
    DropMigration(DropMigration),
    DropDatabase(DropDatabase),
    DropExtensionPackage(DropExtensionPackage),
    DropExtension(DropExtension),
    DropFuture(DropFuture),
    DropModule(DropModule),
    DropRole(DropRole),
    DropAnnotation(DropAnnotation),
    DropScalarType(DropScalarType),
    DropProperty(DropProperty),
    DropConcreteProperty(DropConcreteProperty),
    DropObjectType(DropObjectType),
    DropAlias(DropAlias),
    DropGlobal(DropGlobal),
    DropPermission(DropPermission),
    DropLink(DropLink),
    DropConcreteLink(DropConcreteLink),
    DropConstraint(DropConstraint),
    DropConcreteConstraint(DropConcreteConstraint),
    DropIndex(DropIndex),
    DropConcreteIndex(DropConcreteIndex),
    DropAnnotationValue(DropAnnotationValue),
    DropAccessPolicy(DropAccessPolicy),
    DropTrigger(DropTrigger),
    DropRewrite(DropRewrite),
    DropFunction(DropFunction),
    DropOperator(DropOperator),
    DropCast(DropCast),
    DropIndexMatch(DropIndexMatch),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateExtendingObject {
    pub r#final: bool,
    pub bases: Vec,
    #[cfg_attr(feature = "python", py_child)]
    pub kind: Option<CreateExtendingObjectKind>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum CreateExtendingObjectKind {
    CreateAnnotation(CreateAnnotation),
    CreateScalarType(CreateScalarType),
    CreateProperty(CreateProperty),
    CreateObjectType(CreateObjectType),
    CreateLink(CreateLink),
    CreateConcreteLink(CreateConcreteLink),
    CreateConstraint(CreateConstraint),
    CreateIndex(CreateIndex),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct Rename {
    pub new_name: ObjectRef,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct NestedQLBlock {
    pub commands: Vec,
    pub text: Option,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateMigration {
    pub body: NestedQLBlock,
    pub parent: Option,
    pub metadata_only: bool,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CommittedSchema {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct StartMigration {
    pub target: StartMigrationTarget,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_union)]
pub enum StartMigrationTarget {
    Schema(Schema),
    CommittedSchema(CommittedSchema),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AbortMigration {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct PopulateMigration {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterCurrentMigrationRejectProposed {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DescribeCurrentMigration {
    pub language: DescribeLanguage,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CommitMigration {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterMigration {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropMigration {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct ResetSchema {
    pub target: ObjectRef,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct StartMigrationRewrite {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AbortMigrationRewrite {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CommitMigrationRewrite {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateDatabase {
    pub template: Option,
    pub branch_type: BranchType,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterDatabase {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropDatabase {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateExtensionPackage {
    pub body: NestedQLBlock,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropExtensionPackage {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateExtension {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropExtension {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateFuture {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropFuture {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateModule {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterModule {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropModule {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateRole {
    pub superuser: bool,
    pub bases: Vec,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterRole {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropRole {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateAnnotation {
    pub r#type: Option,
    pub inheritable: bool,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterAnnotation {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropAnnotation {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreatePseudoType {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateScalarType {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterScalarType {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropScalarType {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateProperty {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterProperty {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropProperty {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateConcretePointer {
    pub is_required: Option,
    pub declared_overloaded: bool,
    pub target: CreateConcretePointerTarget,
    pub cardinality: SchemaCardinality,
    pub bases: Vec,
    #[cfg_attr(feature = "python", py_child)]
    pub kind: Option<CreateConcretePointerKind>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_union)]
pub enum CreateConcretePointerTarget {
    Expr(Box),
    TypeExpr(TypeExpr),
    NoneType(()),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum CreateConcretePointerKind {
    CreateConcreteUnknownPointer(CreateConcreteUnknownPointer),
    CreateConcreteProperty(CreateConcreteProperty),
    CreateConcreteLink(CreateConcreteLink),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateConcreteUnknownPointer {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateConcreteProperty {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterConcreteProperty {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropConcreteProperty {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateObjectType {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterObjectType {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropObjectType {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateAlias {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterAlias {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropAlias {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateGlobal {
    pub is_required: Option,
    pub target: CreateGlobalTarget,
    pub cardinality: Option,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_union)]
pub enum CreateGlobalTarget {
    Expr(Box),
    TypeExpr(TypeExpr),
    NoneType(()),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterGlobal {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropGlobal {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct SetGlobalType {
    pub name: String,
    pub value: Option,
    pub special_syntax: bool,
    pub cast_expr: Option>,
    pub reset_value: bool,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreatePermission {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterPermission {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropPermission {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateLink {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterLink {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropLink {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateConcreteLink {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterConcreteLink {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropConcreteLink {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateConstraint {
    pub r#abstract: bool,
    pub subjectexpr: Option>,
    pub params: Vec,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterConstraint {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropConstraint {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateConcreteConstraint {
    pub delegated: bool,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterConcreteConstraint {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropConcreteConstraint {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct IndexType {
    pub name: ObjectRef,
    pub args: Vec>,
    pub kwargs: IndexMap>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct IndexCode {
    pub language: Language,
    pub code: String,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateIndex {
    pub kwargs: IndexMap>,
    pub index_types: Vec,
    pub code: Option,
    pub params: Vec,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterIndex {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropIndex {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateConcreteIndex {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterConcreteIndex {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropConcreteIndex {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateIndexMatch {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropIndexMatch {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateAnnotationValue {
    pub value: Box,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterAnnotationValue {
    pub value: Option>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropAnnotationValue {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateAccessPolicy {
    pub condition: Option>,
    pub action: AccessPolicyAction,
    pub access_kinds: Vec,
    pub expr: Option>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct SetAccessPerms {
    pub access_kinds: Vec,
    pub action: AccessPolicyAction,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterAccessPolicy {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropAccessPolicy {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateTrigger {
    pub timing: TriggerTiming,
    pub kinds: Vec,
    pub scope: TriggerScope,
    pub expr: Box,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterTrigger {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropTrigger {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateRewrite {
    pub expr: Box,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterRewrite {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropRewrite {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct FunctionCode {
    pub language: Language,
    pub code: Option,
    pub nativecode: Option>,
    pub from_function: Option,
    pub from_expr: bool,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateFunction {
    pub returning: TypeExpr,
    pub code: FunctionCode,
    pub nativecode: Option>,
    pub returning_typemod: TypeModifier,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterFunction {
    pub code: FunctionCode,
    pub nativecode: Option>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropFunction {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct OperatorCode {
    pub language: Language,
    pub from_operator: Option>,
    pub from_function: Option>,
    pub from_expr: bool,
    pub code: Option,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateOperator {
    pub returning: TypeExpr,
    pub returning_typemod: TypeModifier,
    pub code: OperatorCode,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterOperator {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropOperator {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CastCode {
    pub language: Language,
    pub from_function: String,
    pub from_expr: bool,
    pub from_cast: bool,
    pub code: String,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct CreateCast {
    pub code: CastCode,
    pub allow_implicit: bool,
    pub allow_assignment: bool,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AlterCast {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DropCast {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct ConfigOp {
    pub name: ObjectRef,
    pub scope: ConfigScope,
    #[cfg_attr(feature = "python", py_child)]
    pub kind: ConfigOpKind,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_child)]
pub enum ConfigOpKind {
    ConfigSet(ConfigSet),
    ConfigInsert(ConfigInsert),
    ConfigReset(ConfigReset),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct ConfigSet {
    pub expr: Box,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct ConfigInsert {
    pub shape: Vec,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct ConfigReset {
    pub r#where: Option>,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct DescribeStmt {
    pub language: DescribeLanguage,
    pub object: DescribeStmtObject,
    pub options: Options,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_union)]
pub enum DescribeStmtObject {
    ObjectRef(ObjectRef),
    DescribeGlobal(DescribeGlobal),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct ExplainStmt {
    pub args: Option,
    pub query: Query,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct AdministerStmt {
    pub expr: FunctionCall,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct ModuleDeclaration {
    pub name: ObjectRef,
    pub declarations: Vec,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_union)]
pub enum ModuleDeclarationDeclarations {
    NamedDDL(DDLOperation),
    ModuleDeclaration(ModuleDeclaration),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
pub struct Schema {
    pub declarations: Vec,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_union)]
pub enum SchemaDeclarations {
    NamedDDL(DDLOperation),
    ModuleDeclaration(ModuleDeclaration),
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qlast.SortOrder))]
pub enum SortOrder {
    Asc,
    Desc,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qlast.NonesOrder))]
pub enum NonesOrder {
    First,
    Last,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qlast.CardinalityModifier))]
pub enum CardinalityModifier {
    Optional,
    Required,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qlast.DescribeGlobal))]
pub enum DescribeGlobal {
    Schema,
    DatabaseConfig,
    InstanceConfig,
    Roles,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qlast.ShapeOp))]
pub enum ShapeOp {
    APPEND,
    SUBTRACT,
    ASSIGN,
    MATERIALIZE,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qlast.ShapeOrigin))]
pub enum ShapeOrigin {
    EXPLICIT,
    DEFAULT,
    SPLAT_EXPANSION,
    MATERIALIZATION,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qlast.Language))]
pub enum Language {
    SQL,
    EdgeQL,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qltypes.ParameterKind))]
pub enum ParameterKind {
    VariadicParam,
    NamedOnlyParam,
    PositionalParam,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qltypes.TypeModifier))]
pub enum TypeModifier {
    SetOfType,
    OptionalType,
    SingletonType,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qltypes.OperatorKind))]
pub enum OperatorKind {
    Infix,
    Postfix,
    Prefix,
    Ternary,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qltypes.TransactionIsolationLevel))]
pub enum TransactionIsolationLevel {
    REPEATABLE_READ,
    SERIALIZABLE,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qltypes.TransactionAccessMode))]
pub enum TransactionAccessMode {
    READ_WRITE,
    READ_ONLY,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qltypes.TransactionDeferMode))]
pub enum TransactionDeferMode {
    DEFERRABLE,
    NOT_DEFERRABLE,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qltypes.SchemaCardinality))]
pub enum SchemaCardinality {
    One,
    Many,
    Unknown,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qltypes.Cardinality))]
pub enum Cardinality {
    AT_MOST_ONE,
    ONE,
    MANY,
    AT_LEAST_ONE,
    UNKNOWN,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qltypes.Volatility))]
pub enum Volatility {
    Immutable,
    Stable,
    Volatile,
    Modifying,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qltypes.Multiplicity))]
pub enum Multiplicity {
    EMPTY,
    UNIQUE,
    DUPLICATE,
    UNKNOWN,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qltypes.AccessPolicyAction))]
pub enum AccessPolicyAction {
    Allow,
    Deny,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qltypes.AccessKind))]
pub enum AccessKind {
    Select,
    UpdateRead,
    UpdateWrite,
    Delete,
    Insert,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qltypes.TriggerTiming))]
pub enum TriggerTiming {
    After,
    AfterCommitOf,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qltypes.TriggerKind))]
pub enum TriggerKind {
    Update,
    Delete,
    Insert,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qltypes.TriggerScope))]
pub enum TriggerScope {
    Each,
    All,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qltypes.RewriteKind))]
pub enum RewriteKind {
    Update,
    Insert,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qltypes.DescribeLanguage))]
pub enum DescribeLanguage {
    DDL,
    SDL,
    TEXT,
    JSON,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qltypes.SchemaObjectClass))]
pub enum SchemaObjectClass {
    ACCESS_POLICY,
    ALIAS,
    ANNOTATION,
    ARRAY_TYPE,
    CAST,
    CONSTRAINT,
    DATABASE,
    EXTENSION,
    EXTENSION_PACKAGE,
    FUTURE,
    FUNCTION,
    GLOBAL,
    INDEX,
    LINK,
    MIGRATION,
    MODULE,
    OPERATOR,
    PARAMETER,
    PERMISSION,
    PROPERTY,
    PSEUDO_TYPE,
    RANGE_TYPE,
    REWRITE,
    ROLE,
    SCALAR_TYPE,
    TRIGGER,
    TUPLE_TYPE,
    TYPE,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qltypes.LinkTargetDeleteAction))]
pub enum LinkTargetDeleteAction {
    Restrict,
    DeleteSource,
    Allow,
    DeferredRestrict,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qltypes.LinkSourceDeleteAction))]
pub enum LinkSourceDeleteAction {
    DeleteTarget,
    Allow,
    DeleteTargetIfOrphan,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qltypes.ConfigScope))]
pub enum ConfigScope {
    INSTANCE,
    DATABASE,
    SESSION,
    GLOBAL,
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "python", derive(IntoPython))]
#[cfg_attr(feature = "python", py_enum(qlast.BranchType))]
pub enum BranchType {
    EMPTY,
    SCHEMA,
    DATA,
}


================================================
FILE: edb/edgeql-parser/src/expr.rs
================================================
use crate::position::{InflatedPos, Pos};
use crate::tokenizer::{self, Kind};

/// Error of expression checking
///
/// See [check][].
#[derive(Debug, thiserror::Error)]
pub enum Error {
    #[error("{}: tokenizer error: {}", _1, _0)]
    Tokenizer(String, Pos),
    #[error(
        "{}: closing bracket mismatch, opened {:?} at {}, encountered {:?}",
        closing_pos,
        opened,
        opened_pos,
        encountered
    )]
    BracketMismatch {
        opened: &'static str,
        encountered: &'static str,
        opened_pos: Pos,
        closing_pos: Pos,
    },
    #[error("{}: extra closing bracket {:?}", _1, _0)]
    ExtraBracket(&'static str, Pos),
    #[error("{}: bracket {:?} has never been closed", _1, _0)]
    MissingBracket(&'static str, Pos),
    #[error(
        "{}: token {:?} is not allowed in expression \
             (try parenthesize the expression)",
        _1,
        _0
    )]
    UnexpectedToken(String, Pos),
    #[error("expression is empty")]
    Empty,
}

fn bracket_str(tok: Kind) -> &'static str {
    use crate::tokenizer::Kind::*;

    match tok {
        OpenBracket => "[",
        CloseBracket => "]",
        OpenBrace => "{",
        CloseBrace => "}",
        OpenParen => "(",
        CloseParen => ")",
        _ => unreachable!("token is not a bracket"),
    }
}

fn matching_bracket(tok: Kind) -> Kind {
    use crate::tokenizer::Kind::*;

    match tok {
        OpenBracket => CloseBracket,
        OpenBrace => CloseBrace,
        OpenParen => CloseParen,
        _ => unreachable!("token is not a bracket"),
    }
}

/// Minimal validation of an expression
///
/// This is used for substitutions in migrations. The check merely ensures
/// that the overall structure of the statement is not ruined: mostly that
/// brackets match and quotes are closed.
///
/// More specifically, the current implementation checks that the expression
/// is not empty, that all tokens are valid, and that brackets match, and it
/// disallows comma `,` and semicolon `;` outside of brackets.
///
/// This is NOT a security measure.
pub fn check(text: &str) -> Result<(), Error> {
    use crate::tokenizer::Kind::*;
    use Error::*;

    let mut brackets = Vec::new();
    let mut parser = &mut tokenizer::Tokenizer::new(text);
    let mut empty = true;
    for token in &mut parser {
        let token = match token {
            Ok(t) => t,
            Err(crate::tokenizer::Error { message, .. }) => {
                return Err(Tokenizer(message, parser.current_pos()));
            }
        };
        let pos = token.span.start;
        let pos = InflatedPos::from_offset(text.as_bytes(), pos)
            .unwrap()
            .deflate();

        empty = false;
        match token.kind {
            Comma | Semicolon if brackets.is_empty() => {
                return Err(UnexpectedToken(token.text.into(), pos));
            }
            OpenParen | OpenBracket | OpenBrace => {
                brackets.push((token.kind, pos));
            }
            CloseParen | CloseBracket | CloseBrace => match brackets.pop() {
                Some((opened, opened_pos)) => {
                    if matching_bracket(opened) != token.kind {
                        return Err(BracketMismatch {
                            opened: bracket_str(opened),
                            opened_pos,
                            encountered: bracket_str(token.kind),
                            closing_pos: pos,
                        });
                    }
                }
                None => {
                    return Err(ExtraBracket(bracket_str(token.kind), pos));
                }
            },
            _ => {}
        }
    }
    if let Some((bracket, pos)) = brackets.pop() {
        return Err(MissingBracket(bracket_str(bracket), pos));
    }
    if empty {
        return Err(Empty);
    }
    Ok(())
}
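
// Illustrative sketch only, not part of the crate: the bracket-stack idea
// behind `check` can be shown without the tokenizer. The names
// `BracketError`, `closing_for` and `check_brackets` are hypothetical. The
// sketch works on raw characters, so unlike `check` it does not understand
// string literals or report positions; a bracket inside quotes would be
// counted as a real bracket here.

```rust
// Simplified, character-level stand-in for the token-level check.
#[derive(Debug, PartialEq)]
enum BracketError {
    Mismatch { opened: char, encountered: char },
    Extra(char),
    Missing(char),
    UnexpectedSeparator(char),
    Empty,
}

// Maps an opening bracket to the closing bracket that must follow it.
fn closing_for(open: char) -> char {
    match open {
        '(' => ')',
        '[' => ']',
        '{' => '}',
        _ => unreachable!("not an opening bracket"),
    }
}

fn check_brackets(text: &str) -> Result<(), BracketError> {
    let mut stack = Vec::new();
    let mut empty = true;
    for c in text.chars() {
        if c.is_whitespace() {
            continue;
        }
        empty = false;
        match c {
            // Separators are only rejected at the top level.
            ',' | ';' if stack.is_empty() => {
                return Err(BracketError::UnexpectedSeparator(c));
            }
            '(' | '[' | '{' => stack.push(c),
            ')' | ']' | '}' => match stack.pop() {
                Some(opened) if closing_for(opened) == c => {}
                Some(opened) => {
                    return Err(BracketError::Mismatch {
                        opened,
                        encountered: c,
                    });
                }
                None => return Err(BracketError::Extra(c)),
            },
            _ => {}
        }
    }
    // Anything left on the stack was opened but never closed; like `check`,
    // report the innermost one.
    if let Some(opened) = stack.pop() {
        return Err(BracketError::Missing(opened));
    }
    if empty {
        return Err(BracketError::Empty);
    }
    Ok(())
}
```

// With this sketch, `check_brackets("f(a, [b; c])")` is accepted because the
// separators sit inside brackets, while `check_brackets("a, b")` is rejected.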


================================================
FILE: edb/edgeql-parser/src/hash.rs
================================================
use sha2::digest::Digest;

use crate::position::Pos;
use crate::tokenizer::Tokenizer;

#[derive(Debug, Clone)]
pub struct Hasher {
    hasher: sha2::Sha256,
}

#[derive(Debug)]
pub enum Error {
    // TODO: use [crate::Error] instead
    Tokenizer(String, Pos),
}

impl Hasher {
    pub fn start_migration(parent_id: &str) -> Hasher {
        let mut me = Hasher {
            hasher: sha2::Sha256::new(),
        };
        me.hasher.update(b"CREATE\0MIGRATION\0ONTO\0");
        me.hasher.update(parent_id.as_bytes());
        me.hasher.update(b"\0{\0");
        me
    }
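    /// Tokenizes `data` and feeds each token's text, followed by a NUL byte,
    /// into the hash. Token text is hashed as written: keyword case is not
    /// normalized (see the `tokens_arent_normalized` test).
    pub fn add_source(&mut self, data: &str) -> Result<&mut Self, Error> {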
        let mut parser = &mut Tokenizer::new(data);
        for token in &mut parser {
            let token = match token {
                Ok(t) => t,
                Err(crate::tokenizer::Error { message, .. }) => {
                    return Err(Error::Tokenizer(message, parser.current_pos()));
                }
            };
            self.hasher.update(token.text.as_bytes());
            self.hasher.update(b"\0");
        }
        Ok(self)
    }
    pub fn make_migration_id(mut self) -> String {
        self.hasher.update(b"}\0");
        let hash = base32::encode(
            base32::Alphabet::Rfc4648 { padding: false },
            &self.hasher.finalize(),
        );
        format!("m1{}", hash.to_ascii_lowercase())
    }
}

#[cfg(test)]
mod test {
    use super::Hasher;

    fn hash(initial: &str, text: &str) -> String {
        let mut hasher = Hasher::start_migration(initial);
        hasher.add_source(text).unwrap();
        hasher.make_migration_id()
    }

    #[test]
    fn empty() {
        assert_eq!(
            hash("initial", "    \n   "),
            "m1tjyzfl33vvzwjd5izo5nyp4zdsekyvxpdm7zhtt5ufmqjzczopdq"
        );
    }

    #[test]
    fn hash_1() {
        assert_eq!(
            hash(
                "m1g3qzqdr57pp3w2mdwdkq4g7dq4oefawqdavzgeiov7fiwntpb3lq",
                r###"
                CREATE TYPE Type1;
            "###
            ),
            "m1fvpcra5cxntkss3k2to2yfu7pit3t3owesvdw2nysqvvpihdiszq"
        );
    }

    #[test]
    fn tokens_arent_normalized() {
        assert_eq!(
            hash(
                "m1g3qzqdr57pp3w2mdwdkq4g7dq4oefawqdavzgeiov7fiwntpb3lq",
                r###"
                CREATE type Type1;
            "###
            ),
            "m1ddghtidugdk3mazwfzpfblqzuoqvsxpivgy2fbq4vywykab7z5rq"
        );

        assert_eq!(
            hash(
                "m1g3qzqdr57pp3w2mdwdkq4g7dq4oefawqdavzgeiov7fiwntpb3lq",
                r###"
                creATE TyPe Type1;
            "###
            ),
            "m1oc32ytxeqlvxeyps3ozqiqazy2duuz5bcqog7nkhubmkbsjgf4vq"
        );
    }

    #[test]
    fn hash_parent() {
        assert_eq!(
            hash(
                "initial",
                r###"
                CREATE TYPE Type1;
            "###
            ),
            "m1q3jjfe7zjl74v3n2vxjwzneousdas6vvd4qwrfd6j6xmhmktyada"
        );
    }
}
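
// Illustrative sketch only, not part of the crate: the framing that `Hasher`
// feeds into SHA-256 can be inspected by building the same byte sequence
// without hashing it. `migration_preimage` and `naive_tokens` are
// hypothetical names, and `naive_tokens` is a whitespace split standing in
// for the real `Tokenizer`, so its token boundaries are an assumption.

```rust
/// Builds the byte sequence that would be hashed: a fixed header, the parent
/// id, an opening brace, every token terminated by NUL, and a closing brace.
fn migration_preimage(parent_id: &str, tokens: &[&str]) -> Vec<u8> {
    let mut buf = Vec::new();
    buf.extend_from_slice(b"CREATE\0MIGRATION\0ONTO\0");
    buf.extend_from_slice(parent_id.as_bytes());
    buf.extend_from_slice(b"\0{\0");
    for tok in tokens {
        buf.extend_from_slice(tok.as_bytes());
        buf.push(0);
    }
    buf.extend_from_slice(b"}\0");
    buf
}

/// Stand-in tokenizer (assumption): splits on whitespace only.
fn naive_tokens(src: &str) -> Vec<&str> {
    src.split_whitespace().collect()
}
```

// Under this stand-in, two sources that differ only in whitespace give the
// same byte sequence, while changing the case of a keyword gives a different
// one, which matches the intent of the `tokens_arent_normalized` test.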


================================================
FILE: edb/edgeql-parser/src/helpers/bytes.rs
================================================
pub fn unquote_bytes(value: &str) -> Result<Vec<u8>, String> {
    let idx = value
        .find(['\'', '"'])
        .ok_or_else(|| "invalid bytes literal: missing quotes".to_string())?;
    let prefix = &value[..idx];
    match prefix {
        "br" | "rb" => Ok(value.as_bytes()[3..value.len() - 1].to_vec()),
        "b" => Ok(unquote_bytes_inner(&value[2..value.len() - 1])?),
        _ => Err(format!(
            "prefix {prefix:?} is not allowed for bytes, allowed: `b`, `rb`"
        )),
    }
}

fn unquote_bytes_inner(s: &str) -> Result<Vec<u8>, String> {
    let mut res = Vec::with_capacity(s.len());
    let mut bytes = s.as_bytes().iter();
    while let Some(&c) = bytes.next() {
        match c {
            b'\\' => {
                match *bytes.next().expect("slash cant be at the end") {
                    c @ b'"' | c @ b'\\' | c @ b'/' | c @ b'\'' => res.push(c),
                    b'b' => res.push(b'\x08'),
                    b'f' => res.push(b'\x0C'),
                    b'n' => res.push(b'\n'),
                    b'r' => res.push(b'\r'),
                    b't' => res.push(b'\t'),
                    b'x' => {
                        let tail = &s[s.len() - bytes.as_slice().len()..];
                        let hex = tail.get(0..2);
                        let code = hex
                            .and_then(|s| u8::from_str_radix(s, 16).ok())
                            .ok_or_else(|| {
                                format!(
                                    "invalid bytes literal: \
                                invalid escape sequence '\\x{}'",
                                    hex.unwrap_or(tail).escape_debug()
                                )
                            })?;
                        res.push(code);
                        bytes.nth(1);
                    }
                    b'\r' | b'\n' => {
                        let nskip = bytes
                            .as_slice()
                            .iter()
                            .take_while(|&&x| x.is_ascii_whitespace())
                            .count();
                        if nskip > 0 {
                            bytes.nth(nskip - 1);
                        }
                    }
                    c => {
                        let ch = if c < 0x7f {
                            c as char
                        } else {
                            // recover the full Unicode character that starts at this byte
                            s[s.len() - bytes.as_slice().len() - 1..]
                                .chars()
                                .next()
                                .unwrap()
                        };
                        return Err(format!(
                            "invalid bytes literal: \
                            invalid escape sequence '\\{}'",
                            ch.escape_debug()
                        ));
                    }
                }
            }
            c => res.push(c),
        }
    }

    Ok(res)
}

#[test]
fn simple_bytes() {
    assert_eq!(unquote_bytes_inner(r"\x09").unwrap(), b"\x09");
    assert_eq!(unquote_bytes_inner(r"\x0A").unwrap(), b"\x0A");
    assert_eq!(unquote_bytes_inner(r"\x0D").unwrap(), b"\x0D");
    assert_eq!(unquote_bytes_inner(r"\x20").unwrap(), b"\x20");
    assert_eq!(unquote_bytes(r"b'\x09'").unwrap(), b"\x09");
    assert_eq!(unquote_bytes(r"b'\x0A'").unwrap(), b"\x0A");
    assert_eq!(unquote_bytes(r"b'\x0D'").unwrap(), b"\x0D");
    assert_eq!(unquote_bytes(r"b'\x20'").unwrap(), b"\x20");
    assert_eq!(unquote_bytes(r"br'\x09'").unwrap(), b"\\x09");
    assert_eq!(unquote_bytes(r"br'\x0A'").unwrap(), b"\\x0A");
    assert_eq!(unquote_bytes(r"br'\x0D'").unwrap(), b"\\x0D");
    assert_eq!(unquote_bytes(r"br'\x20'").unwrap(), b"\\x20");
}

#[test]
fn newline_escaping_bytes() {
    assert_eq!(
        unquote_bytes_inner(
            r"hello \
                                world"
        )
        .unwrap(),
        b"hello world"
    );
    assert_eq!(
        unquote_bytes(
            r"br'hello \
                                world'"
        )
        .unwrap(),
        b"hello \\\n                                world"
    );

    assert_eq!(
        unquote_bytes_inner(
            r"bb\
aa \
            bb"
        )
        .unwrap(),
        b"bbaa bb"
    );
    assert_eq!(
        unquote_bytes(
            r"rb'bb\
aa \
            bb'"
        )
        .unwrap(),
        b"bb\\\naa \\\n            bb"
    );
    assert_eq!(
        unquote_bytes_inner(
            r"bb\

        aa"
        )
        .unwrap(),
        b"bbaa"
    );
    assert_eq!(
        unquote_bytes(
            r"br'bb\

        aa'"
        )
        .unwrap(),
        b"bb\\\n\n        aa"
    );
    assert_eq!(
        unquote_bytes_inner(
            r"bb\
        \
        aa"
        )
        .unwrap(),
        b"bbaa"
    );
    assert_eq!(
        unquote_bytes(
            r"rb'bb\
        \
        aa'"
        )
        .unwrap(),
        b"bb\\\n        \\\n        aa"
    );
    assert_eq!(unquote_bytes_inner("bb\\\r   aa").unwrap(), b"bbaa");
    assert_eq!(unquote_bytes("br'bb\\\r   aa'").unwrap(), b"bb\\\r   aa");
    assert_eq!(unquote_bytes_inner("bb\\\r\n   aa").unwrap(), b"bbaa");
    assert_eq!(
        unquote_bytes("rb'bb\\\r\n   aa'").unwrap(),
        b"bb\\\r\n   aa"
    );
}

#[test]
fn complex_bytes() {
    assert_eq!(
        unquote_bytes_inner(r"\x09 hello \x0A there").unwrap(),
        b"\x09 hello \x0A there"
    );
    assert_eq!(
        unquote_bytes(r"br'\x09 hello \x0A there'").unwrap(),
        b"\\x09 hello \\x0A there"
    );
}
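
// Illustrative sketch only, not part of the crate: a reduced version of the
// escape handling above, covering just `\\`, `\n`, `\xHH` and the
// backslash-newline continuation that swallows the whitespace following it.
// `unescape_bytes_sketch` is a hypothetical name. One deliberate difference:
// a trailing backslash returns an error here, whereas `unquote_bytes_inner`
// relies on the tokenizer and panics in that case.

```rust
fn unescape_bytes_sketch(s: &str) -> Result<Vec<u8>, String> {
    let bytes = s.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        let b = bytes[i];
        if b != b'\\' {
            out.push(b);
            i += 1;
            continue;
        }
        // Look at the byte after the backslash, then step past both.
        let next = *bytes.get(i + 1).ok_or("dangling backslash")?;
        i += 2;
        match next {
            b'\\' => out.push(b'\\'),
            b'n' => out.push(b'\n'),
            b'x' => {
                // Exactly two hex digits must follow `\x`.
                let hex = s.get(i..i + 2).ok_or("truncated \\x escape")?;
                let code = u8::from_str_radix(hex, 16)
                    .map_err(|_| format!("invalid escape sequence '\\x{hex}'"))?;
                out.push(code);
                i += 2;
            }
            b'\r' | b'\n' => {
                // Line continuation: drop the newline and any whitespace
                // that follows it.
                while i < bytes.len() && bytes[i].is_ascii_whitespace() {
                    i += 1;
                }
            }
            other => {
                return Err(format!("invalid escape sequence '\\{}'", other as char));
            }
        }
    }
    Ok(out)
}
```

// For instance, the input `bb\` followed by a CRLF and three spaces and then
// `aa` comes out as `bbaa`, mirroring the `newline_escaping_bytes` test.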


================================================
FILE: edb/edgeql-parser/src/helpers/mod.rs
================================================
mod bytes;
mod strings;

pub use bytes::*;
pub use strings::*;


================================================
FILE: edb/edgeql-parser/src/helpers/strings.rs
================================================
use std::borrow::Cow;
use std::char;
use std::error::Error;
use std::fmt::{self, Write};

use crate::keywords;

/// Error returned from `unquote_string` function
///
/// Opaque for now
#[derive(Debug)]
pub struct UnquoteError(String);

/// Converts a string into an EdgeQL-compatible name (of a column or a property)
///
/// # Examples
/// ```
/// use edgeql_parser::helpers::quote_name;
/// assert_eq!(quote_name("col1"), "col1");
/// assert_eq!(quote_name("another name"), "`another name`");
/// assert_eq!(quote_name("with `quotes`"), "`with ``quotes```");
/// ```
pub fn quote_name(s: &str) -> Cow<'_, str> {
    if s.chars().all(|c| c.is_alphanumeric() || c == '_') {
        let lower = s.to_ascii_lowercase();
        if keywords::lookup(&lower).is_none() {
            return s.into();
        }
    }
    let escaped = s.replace('`', "``");
    let mut s = String::with_capacity(escaped.len() + 2);
    s.push('`');
    s.push_str(&escaped);
    s.push('`');
    s.into()
}

pub fn quote_string(s: &str) -> String {
    let mut buf = String::with_capacity(s.len() + 2);
    buf.push('"');
    for c in s.chars() {
        match c {
            '"' => {
                buf.push('\\');
                buf.push('"');
            }
            '\\' => {
                buf.push('\\');
                buf.push('\\');
            }
            '\x00'..='\x08'
            | '\x0B'
            | '\x0C'
            | '\x0E'..='\x1F'
            | '\u{007F}'
            | '\u{0080}'..='\u{009F}' => {
                write!(buf, "\\x{:02x}", c as u32).unwrap();
            }
            c => buf.push(c),
        }
    }
    buf.push('"');
    buf
}

pub fn unquote_string(value: &str) -> Result<Cow<'_, str>, UnquoteError> {
    if value.starts_with('r') {
        Ok(value[2..value.len() - 1].into())
    } else if let Some(stripped) = value.strip_prefix('$') {
        let msize = 2 + stripped
            .find('$')
            .ok_or_else(|| "invalid dollar-quoted string".to_string())
            .map_err(UnquoteError)?;
        Ok(value[msize..value.len() - msize].into())
    } else {
        let end_trim = if value.ends_with("\\(") { 2 } else { 1 };

        Ok(_unquote_string(&value[1..value.len() - end_trim])
            .map_err(UnquoteError)?
            .into())
    }
}

fn _unquote_string(s: &str) -> Result<String, String> {
    let mut res = String::with_capacity(s.len());
    let mut chars = s.chars();
    while let Some(c) = chars.next() {
        match c {
            '\\' => {
                let c = chars
                    .next()
                    .ok_or_else(|| "quoted string cannot end in slash".to_string())?;
                match c {
                    c @ '"' | c @ '\\' | c @ '/' | c @ '\'' => res.push(c),
                    'b' => res.push('\u{0008}'),
                    'f' => res.push('\u{000C}'),
                    'n' => res.push('\n'),
                    'r' => res.push('\r'),
                    't' => res.push('\t'),
                    'x' => {
                        let hex = chars.as_str().get(0..2);
                        let code = hex
                            .and_then(|s| u8::from_str_radix(s, 16).ok())
                            .ok_or_else(|| {
                                format!(
                                    "invalid string literal: \
                                invalid escape sequence '\\x{}'",
                                    hex.unwrap_or(chars.as_str()).escape_debug()
                                )
                            })?;
                        if code > 0x7f || code == 0 {
                            return Err(format!(
                                "invalid string literal: \
                                 invalid escape sequence '\\x{code:x}' \
                                 (only non-null ascii allowed)"
                            ));
                        }
                        res.push(code as char);
                        chars.nth(1);
                    }
                    'u' => {
                        let hex = chars.as_str().get(0..4);
                        let ch = hex
                            .and_then(|s| u32::from_str_radix(s, 16).ok())
                            .and_then(char::from_u32)
                            .and_then(|c| if c == '\0' { None } else { Some(c) })
                            .ok_or_else(|| {
                                format!(
                                    "invalid string literal: \
                                    invalid escape sequence '\\u{}'",
                                    hex.unwrap_or(chars.as_str()).escape_debug()
                                )
                            })?;
                        res.push(ch);
                        chars.nth(3);
                    }
                    'U' => {
                        let hex = chars.as_str().get(0..8);
                        let ch = hex
                            .and_then(|s| u32::from_str_radix(s, 16).ok())
                            .and_then(char::from_u32)
                            .and_then(|c| if c == '\0' { None } else { Some(c) })
                            .ok_or_else(|| {
                                format!(
                                    "invalid string literal: \
                                    invalid escape sequence '\\U{}'",
                                    hex.unwrap_or(chars.as_str()).escape_debug()
                                )
                            })?;
                        res.push(ch);
                        chars.nth(7);
                    }
                    '\r' | '\n' => {
                        let nleft = chars.as_str().trim_start().len();
                        let nskip = chars.as_str().len() - nleft;
                        if nskip > 0 {
                            chars.nth(nskip - 1);
                        }
                    }
                    c => {
                        return Err(format!(
                            "invalid string literal: \
                             invalid escape sequence '\\{}'",
                            c.escape_debug()
                        ));
                    }
                }
            }
            c => res.push(c),
        }
    }

    Ok(res)
}

#[test]
fn unquote_unicode_string() {
    assert_eq!(_unquote_string(r"\x09").unwrap(), "\u{09}");
    assert_eq!(_unquote_string(r"\u000A").unwrap(), "\u{000A}");
    assert_eq!(_unquote_string(r"\u000D").unwrap(), "\u{000D}");
    assert_eq!(_unquote_string(r"\u0020").unwrap(), "\u{0020}");
    assert_eq!(_unquote_string(r"\uFFFF").unwrap(), "\u{FFFF}");
}

#[test]
fn unquote_string_error() {
    assert_eq!(
        _unquote_string(r"\x00").unwrap_err(),
        "invalid string literal: \
             invalid escape sequence '\\x0' (only non-null ascii allowed)"
    );
    assert_eq!(
        _unquote_string(r"\u0000").unwrap_err(),
        "invalid string literal: invalid escape sequence '\\u0000'"
    );
    assert_eq!(
        _unquote_string(r"\U00000000").unwrap_err(),
        "invalid string literal: invalid escape sequence '\\U00000000'"
    );
}

#[test]
fn newline_escaping_str() {
    assert_eq!(
        _unquote_string(
            r"hello \
                                world"
        )
        .unwrap(),
        "hello world"
    );

    assert_eq!(
        _unquote_string(
            r"bb\
aa \
            bb"
        )
        .unwrap(),
        "bbaa bb"
    );
    assert_eq!(
        _unquote_string(
            r"bb\

        aa"
        )
        .unwrap(),
        "bbaa"
    );
    assert_eq!(
        _unquote_string(
            r"bb\
        \
        aa"
        )
        .unwrap(),
        "bbaa"
    );
    assert_eq!(_unquote_string("bb\\\r   aa").unwrap(), "bbaa");
    assert_eq!(_unquote_string("bb\\\r\n   aa").unwrap(), "bbaa");
}

#[test]
fn test_quote_string() {
    assert_eq!(quote_string(r"\n"), r#""\\n""#);
    assert_eq!(unquote_string(&quote_string(r"\n")).unwrap(), r"\n");
}

#[test]
fn complex_strings() {
    assert_eq!(
        _unquote_string(r"\u0009 hello \u000A there").unwrap(),
        "\u{0009} hello \u{000A} there"
    );

    assert_eq!(
        _unquote_string(r"\x62:\u2665:\U000025C6").unwrap(),
        "\u{62}:\u{2665}:\u{25C6}"
    );
}

impl fmt::Display for UnquoteError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        self.0.fmt(f)
    }
}
impl Error for UnquoteError {}


================================================
FILE: edb/edgeql-parser/src/keywords.rs
================================================
use phf::phf_set;

pub const UNRESERVED_KEYWORDS: phf::Set<&str> = phf_set!(
    "abort",
    "abstract",
    "access",
    "after",
    "alias",
    "allow",
    "all",
    "annotation",
    "applied",
    "as",
    "asc",
    "assignment",
    "before",
    "blobal",
    "branch",
    "cardinality",
    "cast",
    "committed",
    "config",
    "conflict",
    "constraint",
    "cube",
    "current",
    "data",
    "database",
    "ddl",
    "declare",
    "default",
    "deferrable",
    "deferred",
    "delegated",
    "desc",
    "deny",
    "each",
    "empty",
    "expression",
    "extension",
    "final",
    "first",
    "force",
    "from",
    "function",
    "future",
    "implicit",
    "index",
    "infix",
    "inheritable",
    "instance",
    "into",
    "isolation",
    "json",
    "last",
    "link",
    "migration",
    "multi",
    "named",
    "object",
    "of",
    "only",
    "onto",
    "operator",
    "optionality",
    "order",
    "orphan",
    "overloaded",
    "owned",
    "package",
    "permission",
    "policy",
    "populate",
    "postfix",
    "prefix",
    "property",
    "proposed",
    "pseudo",
    "read",
    "reject",
    "release",
    "rename",
    "repeatable",
    "required",
    "reset",
    "restrict",
    "rewrite",
    "role",
    "roles",
    "rollup",
    "savepoint",
    "scalar",
    "schema",
    "sdl",
    "serializable",
    "session",
    "source",
    "superuser",
    "system",
    "target",
    "template",
    "ternary",
    "text",
    "then",
    "to",
    "transaction",
    "trigger",
    "type",
    "unless",
    "using",
    "verbose",
    "version",
    "view",
    "write",
);

pub const PARTIAL_RESERVED_KEYWORDS: phf::Set<&str> = phf_set!("except", "intersect", "union",);

pub const FUTURE_RESERVED_KEYWORDS: phf::Set<&str> = phf_set!(
    "anyarray",
    "begin",
    "case",
    "check",
    "deallocate",
    "discard",
    "end",
    "explain",
    "fetch",
    "get",
    "global",
    "grant",
    "import",
    "listen",
    "load",
    "lock",
    "match",
    "move",
    "notify",
    "on",
    "over",
    "prepare",
    "partition",
    "raise",
    "refresh",
    "revoke",
    "single",
    "when",
    "window",
    "never",
);

pub const CURRENT_RESERVED_KEYWORDS: phf::Set<&str> = phf_set!(
    "__source__",
    "__subject__",
    "__type__",
    "__std__",
    "__edgedbsys__",
    "__edgedbtpl__",
    "__new__",
    "__old__",
    "__specified__",
    "__default__",
    "administer",
    "alter",
    "analyze",
    "and",
    "anytuple",
    "anytype",
    "anyobject",
    "by",
    "commit",
    "configure",
    "create",
    "delete",
    "describe",
    "detached",
    "distinct",
    "do",
    "drop",
    "else",
    "exists",
    "extending",
    "false",
    "filter",
    "for",
    "group",
    "if",
    "ilike",
    "in",
    "insert",
    "introspect",
    "is",
    "like",
    "limit",
    "module",
    "not",
    "offset",
    "optional",
    "or",
    "rollback",
    "select",
    "set",
    "start",
    "true",
    "typeof",
    "update",
    "variadic",
    "with",
);

pub const COMBINED_KEYWORDS: phf::Set<&str> = phf_set!(
    "named only",
    "set annotation",
    "set type",
    "extension package",
    "order by",
);

pub fn lookup(s: &str) -> Option<Keyword> {
    None.or_else(|| PARTIAL_RESERVED_KEYWORDS.get_key(s))
        .or_else(|| FUTURE_RESERVED_KEYWORDS.get_key(s))
        .or_else(|| CURRENT_RESERVED_KEYWORDS.get_key(s))
        .map(|x| Keyword(x))
}

pub fn lookup_all(s: &str) -> Option<Keyword> {
    lookup(s).or_else(|| {
        None.or_else(|| COMBINED_KEYWORDS.get_key(s))
            .or_else(|| UNRESERVED_KEYWORDS.get_key(s))
            .map(|x| Keyword(x))
    })
}

/// This is required for serde deserializer for Token to work correctly.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Keyword(pub &'static str);

impl Keyword {
    pub fn is_reserved(&self) -> bool {
        FUTURE_RESERVED_KEYWORDS.contains(self.0) || CURRENT_RESERVED_KEYWORDS.contains(self.0)
    }
    pub fn is_unreserved(&self) -> bool {
        UNRESERVED_KEYWORDS.contains(self.0) || PARTIAL_RESERVED_KEYWORDS.contains(self.0)
    }
    pub fn is_dunder(&self) -> bool {
        self.0.starts_with("__") && self.0.ends_with("__")
    }
    pub fn is_bool(&self) -> bool {
        self.0 == "true" || self.0 == "false"
    }
}

impl From<Keyword> for &'static str {
    fn from(value: Keyword) -> Self {
        value.0
    }
}


================================================
FILE: edb/edgeql-parser/src/lib.rs
================================================
pub mod ast;
pub mod expr;
pub mod hash;
pub mod helpers;
pub mod keywords;
pub mod parser;
pub mod position;
pub mod preparser;
pub mod schema_file;
pub mod tokenizer;
pub mod validation;


================================================
FILE: edb/edgeql-parser/src/parser/cst.rs
================================================
use crate::helpers::quote_name;
use crate::keywords::Keyword;
use crate::position::Span;
use crate::tokenizer::{Kind, Token, Value};

/// A node of the CST tree.
///
/// Warning: allocated in the bumpalo arena, which does not Drop.
/// Any types that do allocation with global allocator (such as String or Vec),
/// must manually drop. This is why Terminal has a special vec arena that does
/// Drop.
#[derive(Debug, Clone, Copy, Default)]
pub enum CSTNode<'a> {
    #[default]
    Empty,
    Terminal(&'a Terminal),
    Production(Production<'a>),
}
#[derive(Clone, Debug)]
pub struct Terminal {
    pub kind: Kind,
    pub text: String,
    pub value: Option<Value>,
    pub span: Span,
    pub(super) is_placeholder: bool,
}

#[derive(Debug, Clone, Copy)]
pub struct Production<'a> {
    pub id: usize,
    pub args: &'a [CSTNode<'a>],
    pub span: Option<Span>,

    /// When a production is inlined, its id is saved into the new production
    /// This is needed when matching CST nodes by production id.
    pub inlined_ids: Option<&'a [usize]>,
}

impl std::fmt::Display for Terminal {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if (self.is_placeholder && self.kind == Kind::Ident) || self.text.is_empty() {
            if let Some(user_friendly) = self.kind.user_friendly_text() {
                return write!(f, "{user_friendly}");
            }
        }

        match self.kind {
            Kind::Ident => write!(f, "'{}'", &quote_name(&self.text)),
            Kind::Keyword(Keyword(kw)) => write!(f, "keyword '{}'", kw.to_ascii_uppercase()),
            _ => write!(f, "'{}'", self.text),
        }
    }
}

impl Terminal {
    pub fn from_token(token: Token) -> Self {
        Terminal {
            kind: token.kind,
            text: token.text.into(),
            value: token.value,
            span: token.span,
            is_placeholder: false,
        }
    }

    #[cfg(feature = "serde")]
    pub fn from_start_name(start_name: &str) -> Self {
        use super::spec;

        Terminal {
            kind: spec::get_token_kind(start_name),
            text: "".to_string(),
            value: None,
            span: Default::default(),
            is_placeholder: false,
        }
    }
}


================================================
FILE: edb/edgeql-parser/src/parser/custom_errors.rs
================================================
use crate::tokenizer::Kind;
use crate::{keywords::Keyword, position::Span};

use super::{CSTNode, Context, Error, Parser, StackNode, Terminal};

impl Parser<'_> {
    pub(super) fn custom_error(&self, ctx: &Context, token: &Terminal) -> Option<Error> {
        let ltok = self.get_from_top(0).unwrap();

        if let Some(value) = self.custom_error_from_rule(token, ctx) {
            return Some(value);
        }

        if matches!(token.kind, Kind::Keyword(Keyword("explain"))) {
            return Some({
                Error {
                    message: format!("Unexpected keyword '{}'", token.text.to_uppercase()),
                    span: Span::default(),
                    hint: Some("Use `analyze` to show query performance details".to_string()),
                    details: None,
                }
            });
        }

        if let Kind::Keyword(kw) = token.kind {
            if kw.is_reserved() && !Cond::Production("Expr").check(ltok, ctx) {
                // Another token followed by a reserved keyword:
                // likely an attempt to use keyword as identifier
                return Some(unexpected_reserved_keyword(&token.text, token.span));
            }
        };

        None
    }

    fn custom_error_from_rule(&self, token: &Terminal, ctx: &Context) -> Option<Error> {
        let last = self.get_from_top(0).unwrap();

        let (i, rule) = self.get_rule(ctx)?;
        // Look at the parsing stack and use tokens and
        // non-terminals to infer the parser rule when the
        // error occurred.

        match rule {
            ParserRule::ListOfArguments
                // The stack is like <NodeName> LPAREN <AnyIdentifier>
                if i == 1
                    && Cond::AnyOf(vec![
                        Cond::Production("AnyIdentifier"),
                        Cond::keyword("with"),
                        Cond::keyword("select"),
                        Cond::keyword("for"),
                        Cond::keyword("insert"),
                        Cond::keyword("update"),
                        Cond::keyword("delete"),
                    ])
                    .check(last, ctx)
            => {
                return Some(Error {
                    message: "Missing parentheses around statement used as an expression"
                        .to_string(),
                    span: super::get_span_of_nodes(&[last.value]).unwrap_or_default(),
                    hint: None,
                    details: None,
                });
            }

            ParserRule::ArraySlice
                if matches!(token.kind, Kind::Ident | Kind::IntConst)
                && !Cond::Terminal(Kind::Colon).check(last, ctx)
            => {
                // The offending token was something that could
                // make an expression
                return Some(Error::new(format!(
                    "It appears that a ':' is missing in {rule} before {}",
                    token.text
                )));
            },

            ParserRule::Definition if token.kind == Kind::Ident => {
                // Something went wrong in a definition, so check
                // if the last successful token is a keyword.
                if Cond::Production("Identifier").check(last, ctx)
                // TODO: && ltok.value.upper() == "INDEX"
                {
                    return Some(Error::new(format!(
                        "Expected 'ON', but got '{}' instead",
                        token.text
                    )));
                }
            },

            ParserRule::ForIterator => {
                let span = if i >= 4 {
                    let span_start = self.get_from_top(i - 4).unwrap();
                    let span = super::get_span_of_nodes(&[span_start.value]).unwrap_or_default();
                    span.combine(token.span)
                } else {
                    token.span
                };
                return Some(Error {
                    message: "Missing parentheses around complex expression in \
                              a FOR iterator clause".to_string(),
                    span,
                    hint: None,
                    details: None,
                });
            },

            ParserRule::Create => {
                if matches!(token.kind, Kind::Keyword(Keyword("branch"))) {
                    return Some(Error {
                        message: "Missing one of keywords 'EMPTY', 'SCHEMA' or 'DATA'".to_string(),
                        span: Span { start: token.span.start - 1, end: token.span.start },
                        hint: None,
                        details: None,
                    })
                }
            }

            _ => {}
        }
        None
    }

    /// Look at the parsing stack and use tokens and non-terminals
    /// to infer the parser rule when the error occurred.
    fn get_rule(&self, ctx: &Context) -> Option<(usize, ParserRule)> {
        // If the last valid token was a closing brace/paren/bracket,
        // we need to find a match for it before deciding what rule
        // context we're in.
        let mut need_match = self.compare_stack(
            &[Cond::AnyOf(vec![
                Cond::Terminal(Kind::CloseBrace),
                Cond::Terminal(Kind::CloseParen),
                Cond::Terminal(Kind::CloseBracket),
            ])],
            0,
            ctx,
        );
        let mut found_union = false;

        let ltok = self.get_from_top(0).unwrap();

        let mut nextel = None;
        let mut curr_el = Some(self.stack_top);
        let mut i = 0;
        while let Some(el) = curr_el {
            // We'll need the element right before "{", "[", or "(".
            let prevel = el.parent;

            match el.value {
                CSTNode::Terminal(Terminal {
                    kind: Kind::OpenBrace,
                    ..
                }) => {
                    if need_match && Cond::Terminal(Kind::CloseBrace).check(ltok, ctx) {
                        // This is matched, while we're looking
                        // for unmatched braces.
                        need_match = false;
                    } else if Cond::Production("OptExtending").check_opt(prevel, ctx) {
                        // This is some SDL/DDL
                        return Some((i, ParserRule::Definition));
                    } else if prevel.is_some_and(|prevel| {
                        Cond::Production("Expr").check(prevel, ctx)
                            || (Cond::Terminal(Kind::Colon).check(prevel, ctx)
                                && Cond::Production("ShapePointer").check_opt(prevel.parent, ctx))
                    }) {
                        // This is some kind of shape.
                        return Some((i, ParserRule::Shape));
                    } else {
                        return None;
                    }
                }

                CSTNode::Terminal(Terminal {
                    kind: Kind::OpenParen,
                    ..
                }) => {
                    if need_match && Cond::Terminal(Kind::CloseParen).check(ltok, ctx) {
                        // This is matched, while we're looking
                        // for unmatched parentheses.
                        need_match = false
                    } else if Cond::Production("NodeName").check_opt(prevel, ctx) {
                        return Some((i, ParserRule::ListOfArguments));
                    } else if Cond::AnyOf(vec![
                        Cond::keyword("for"),
                        Cond::keyword("select"),
                        Cond::keyword("update"),
                        Cond::keyword("delete"),
                        Cond::keyword("insert"),
                        Cond::keyword("for"),
                    ])
                    .check_opt(nextel, ctx)
                    {
                        // A parenthesized subquery expression,
                        // we should leave the error as is.
                        return None;
                    } else {
                        return Some((i, ParserRule::Tuple));
                    }
                }

                CSTNode::Terminal(Terminal {
                    kind: Kind::OpenBracket,
                    ..
                }) => {
                    // This is either an array literal or
                    // array index.

                    if need_match && Cond::Terminal(Kind::CloseBracket).check(ltok, ctx) {
                        // This is matched, while we're looking
                        // for unmatched brackets.
                        need_match = false
                    } else if Cond::Production("Expr").check_opt(prevel, ctx) {
                        return Some((i, ParserRule::ArraySlice));
                    } else {
                        return Some((i, ParserRule::Array));
                    }
                }

                CSTNode::Terminal(Terminal {
                    kind: Kind::Keyword(Keyword("create")),
                    ..
                }) => return Some((i, ParserRule::Create)),

                _ => {}
            }

            // Check if we're in the `FOR x IN bad_tokens` situation
            if self.compare_stack(&[Cond::keyword("union")], i, ctx) {
                found_union = true;
            }
            if !found_union
                && self.compare_stack(
                    &[
                        Cond::keyword("for"),
                        Cond::Production("OptionalOptional"),
                        Cond::Production("Identifier"),
                        Cond::keyword("in"),
                    ],
                    i,
                    ctx,
                )
            {
                return Some((i + 3, ParserRule::ForIterator));
            }

            // Also keep track of the element right after current.
            nextel = Some(el);
            curr_el = el.parent;
            i += 1;
        }

        None
    }

    /// Looks at the stack and compares it with the expected nodes.
    /// Skips the first `top_offset` nodes from the top of the stack.
    ///
    /// Example of matching with top_offset=1, expected=[X, Y, Z]
    /// ```plain
    /// stack top -> A     (offset 1)
    ///              B - Z
    ///              C - Y
    ///              D - X
    ///              E
    /// ```
    fn compare_stack(&self, expected: &[Cond], top_offset: usize, ctx: &Context) -> bool {
        let mut current = self.get_from_top(top_offset);

        for validator in expected.iter().rev() {
            let Some(cur) = current else {
                return false;
            };
            if !validator.check(cur, ctx) {
                return false;
            }

            current = cur.parent;
        }
        true
    }
}

fn unexpected_reserved_keyword(text: &str, span: Span) -> Error {
    let text_upper = text.to_uppercase();
    Error {
        message: format!("Unexpected keyword '{text_upper}'"),
        span,
        details: Some(
            "This name is a reserved keyword and cannot be \
            used as an identifier"
                .to_string(),
        ),
        hint: Some(format!(
            "Use a different identifier or quote the name \
            with backticks: `{text}`"
        )),
    }
}

/// Condition for a stack node. An easier way to match stack node kinds.
enum Cond {
    Terminal(Kind),
    Production(&'static str),
    AnyOf(Vec<Cond>),
}

impl Cond {
    fn keyword(kw: &'static str) -> Self {
        Cond::Terminal(Kind::Keyword(Keyword(kw)))
    }

    fn check(&self, node: &StackNode, ctx: &Context) -> bool {
        match self {
            Cond::Terminal(kind) => matches!(
                node.value,
                CSTNode::Terminal(Terminal { kind: k, .. }) if k == kind
            ),
            Cond::Production(non_term) => match node.value {
                CSTNode::Production(prod) => {
                    let (pn, _) = &ctx.spec.production_names[prod.id];
                    if non_term == pn {
                        return true;
                    }

                    // When looking for a production, it might have happened
                    // that it was inlined and superseded by one of its
                    // arguments. That's why we save the id of the parent into
                    // child's `inlined_ids` and check all of them here.
                    if let Some(inlined_ids) = prod.inlined_ids {
                        for prod_id in inlined_ids {
                            let (pn, _) = &ctx.spec.production_names[*prod_id];
                            if non_term == pn {
                                return true;
                            }
                        }
                    }
                    false
                }
                _ => false,
            },
            Cond::AnyOf(options) => options.iter().any(|v| v.check(node, ctx)),
        }
    }

    fn check_opt(&self, node: Option<&StackNode>, ctx: &Context) -> bool {
        node.is_some_and(|x| self.check(x, ctx))
    }
}

#[derive(Debug)]
enum ParserRule {
    ForIterator,
    Definition,
    Shape,
    ArraySlice,
    Array,
    Tuple,
    ListOfArguments,
    Create,
}

impl std::fmt::Display for ParserRule {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ParserRule::ForIterator => f.write_str("for iterator"),
            ParserRule::Definition => f.write_str("definition"),
            ParserRule::Shape => f.write_str("shape"),
            ParserRule::ArraySlice => f.write_str("array slice"),
            ParserRule::Array => f.write_str("array"),
            ParserRule::Tuple => f.write_str("tuple"),
            ParserRule::ListOfArguments => f.write_str("list of arguments"),
            ParserRule::Create => f.write_str("create"),
        }
    }
}

pub fn post_process(errors: Vec<Error>) -> Vec<Error> {
    let mut new_errors: Vec<Error> = Vec::with_capacity(errors.len());
    for error in errors {
        // Enrich combination of 'Unexpected keyword' + 'Missing identifier'
        if error.message == "Missing identifier" {
            if let Some(last) = new_errors.last() {
                if last.message.starts_with("Unexpected keyword '")
                    && last.span.end == error.span.start
                {
                    let last = new_errors.pop().unwrap();
                    let text = last.message.strip_prefix("Unexpected keyword '").unwrap();
                    let text = text.strip_suffix('\'').unwrap();

                    new_errors.push(unexpected_reserved_keyword(text, last.span));
                    continue;
                }
            }
        }

        new_errors.push(error);
    }

    new_errors
}


================================================
FILE: edb/edgeql-parser/src/parser/mod.rs
================================================
mod cst;
mod custom_errors;
mod spec;

pub use cst::{CSTNode, Production, Terminal};
pub use spec::{Action, Reduce, Spec, SpecSerializable};

use append_only_vec::AppendOnlyVec;

use crate::keywords::{self, Keyword};
use crate::position::Span;
use crate::tokenizer::{Error, Kind, Value};

pub struct Context<'s> {
    spec: &'s Spec,
    arena: bumpalo::Bump,
    terminal_arena: AppendOnlyVec<Terminal>,
}

impl<'s> Context<'s> {
    pub fn new(spec: &'s Spec) -> Self {
        Context {
            spec,
            arena: bumpalo::Bump::new(),
            terminal_arena: AppendOnlyVec::new(),
        }
    }
}

/// This is a const just so we remember to update it everywhere
/// when changing.
const UNEXPECTED: &str = "Unexpected";

pub fn parse<'a>(input: &'a [Terminal], ctx: &'a Context) -> (Option<CSTNode<'a>>, Vec<Error>) {
    let stack_top = ctx.arena.alloc(StackNode {
        parent: None,
        state: 0,
        value: CSTNode::Empty,
    });
    let initial_track = Parser {
        stack_top,
        error_cost: 0,
        node_count: 0,
        can_recover: true,
        errors: Vec::new(),
        has_custom_error: false,
    };

    // Append EOI token.
    // We have a weird setup that requires two EOI tokens:
    // - one is consumed by the grammar generator and does not contribute to
    //   span of the nodes.
    // - second is consumed by explicit EOF tokens in EdgeQLGrammar NonTerm.
    //   Since these are children of productions, they do contribute to the
    //   spans of the top-level nodes.
    // First EOI is produced by tokenizer (with correct offset) and second one
    // is injected here.
    let end = input.last().map(|t| t.span.end).unwrap_or_default();
    let eoi = ctx.alloc_terminal(Terminal {
        kind: Kind::EOI,
        span: Span { start: end, end },
        text: "".to_string(),
        value: None,
        is_placeholder: false,
    });
    let input = input.iter().chain([eoi]);

    let mut parsers = vec![initial_track];
    let mut prev_span: Option<Span> = None;
    let mut new_parsers = Vec::with_capacity(parsers.len() + 5);

    for token in input {
        // println!("token {:?}", token);

        while let Some(mut parser) = parsers.pop() {
            let res = parser.act(ctx, token);

            if res.is_ok() {
                // base case: ok
                parser.node_successful();
                new_parsers.push(parser);
            } else {
                // error: try to recover

                let gap_span = {
                    let prev_end = prev_span.map(|p| p.end).unwrap_or(token.span.start);

                    Span {
                        start: prev_end,
                        end: token.span.start,
                    }
                };

                // option 1: inject a token
                if parser.error_cost <= ERROR_COST_INJECT_MAX && !parser.has_custom_error {
                    let possible_actions = &ctx.spec.actions[parser.stack_top.state];
                    for token_kind in possible_actions.keys() {
                        if parser.can_act(ctx, token_kind).is_none() {
                            continue;
                        }

                        let mut inject = parser.clone();

                        let injection =
                            new_token_for_injection(*token_kind, &prev_span, token.span, ctx);

                        let cost = injection_cost(token_kind);
                        let error = Error::new(format!("Missing {injection}")).with_span(gap_span);
                        inject.push_error(error, cost);

                        if inject.error_cost <= ERROR_COST_INJECT_MAX
                            && inject.act(ctx, injection).is_ok()
                        {
                            // println!("   --> [inject {injection}]");

                            // insert into parsers, to retry the original token
                            parsers.push(inject);
                        }
                    }
                }

                // option 2: check for a custom error and skip token
                //   Due to performance reasons, this is done only on first
                //   error, not during all the steps of recovery.
                if parser.error_cost == 0 {
                    if let Some(error) = parser.custom_error(ctx, token) {
                        parser
                            .push_error(error.default_span_to(token.span), ERROR_COST_CUSTOM_ERROR);
                        parser.has_custom_error = true;

                        // println!("   --> [custom error]");
                        new_parsers.push(parser);
                        continue;
                    }
                } else if parser.has_custom_error {
                    // when there is a custom error, just skip the tokens until
                    // the parser recovers
                    // println!("   --> [skip because of custom error]");
                    new_parsers.push(parser);
                    continue;
                }

                // option 3: skip the token
                let mut skip = parser;
                let error = Error::new(format!("{UNEXPECTED} {token}")).with_span(token.span);
                skip.push_error(error, ERROR_COST_SKIP);
                if token.kind == Kind::EOI || token.kind == Kind::Semicolon {
                    // extra penalty
                    skip.error_cost += ERROR_COST_INJECT_MAX;
                    skip.can_recover = false;
                }

                // insert into new_parsers, so the token is skipped
                // println!("   --> [skip] {}", skip.error_cost);
                new_parsers.push(skip);
            }
        }

        // has any parser recovered?
        if new_parsers.len() > 1 {
            new_parsers.sort_by_key(Parser::adjusted_cost);

            if new_parsers[0].has_custom_error {
                // if we have a custom error, just keep that

                new_parsers.drain(1..);
            } else if new_parsers[0].has_recovered() {
                // recover parsers whose "adjusted error cost" reached 0 and discard the rest

                new_parsers.retain(|p| p.has_recovered());
                for p in &mut new_parsers {
                    p.error_cost = 0;
                }
            } else if new_parsers[0].error_cost > ERROR_COST_INJECT_MAX {
                // prune: keep only the single best parser once its cost exceeds ERROR_COST_INJECT_MAX

                new_parsers.drain(1..);
            } else if new_parsers.len() > PARSER_COUNT_MAX {
                // prune: pick only X best parsers

                new_parsers.drain(PARSER_COUNT_MAX..);
            }
        }

        assert!(parsers.is_empty());
        std::mem::swap(&mut parsers, &mut new_parsers);
        prev_span = Some(token.span);

        // for (index, parser) in parsers.iter().enumerate() {
        //     print!(
        //         "p{index} {:06} {:5}:",
        //         parser.error_cost, parser.can_recover
        //     );
        //     for e in &parser.errors {
        //         print!(" {}", e.message);
        //     }
        //     println!("");
        // }
        // println!("");
    }

    // there will always be a parser left,
    // since we always allow a token to be skipped
    let parser = parsers
        .into_iter()
        .min_by(|a, b| {
            Ord::cmp(&a.error_cost, &b.error_cost).then_with(|| {
                Ord::cmp(
                    &starts_with_unexpected_error(a),
                    &starts_with_unexpected_error(b),
                )
                .reverse()
            })
        })
        .unwrap();

    let node = parser.finish(ctx);
    let errors = custom_errors::post_process(parser.errors);
    (node, errors)
}

/// Parses tokens and then inspects the state of the parser to suggest possible
/// next keywords, along with a boolean indicating whether the next token can be
/// an identifier. This is done by looking at the actions available in the
/// current state.
/// An important detail is that not all of these actions are valid: they might
/// trigger a chain of reductions that ends in a state that does not accept the
/// suggested token. Such actions are filtered out with `Parser::can_act`.
pub fn suggest_next_keyword<'a>(input: &'a [Terminal], ctx: &'a Context) -> (Vec<Keyword>, bool) {
    // init
    let stack_top = ctx.arena.alloc(StackNode {
        parent: None,
        state: 0,
        value: CSTNode::Empty,
    });
    let mut parser = Parser {
        stack_top,
        error_cost: 0,
        node_count: 0,
        can_recover: true,
        errors: Vec::new(),
        has_custom_error: false,
    };

    // parse tokens
    for token in input.iter() {
        if matches!(token.kind, Kind::EOI) {
            break;
        }

        let res = parser.act(ctx, token);

        if res.is_err() {
            return (vec![], false);
        }
    }

    // extract possible next actions
    let actions = &ctx.spec.actions[parser.stack_top.state];

    let can_be_ident =
        actions.contains_key(&Kind::Ident) && parser.can_act(ctx, &Kind::Ident).is_some();

    let keywords = actions
        .keys()
        // suggest only keywords
        .filter_map(|kind| {
            if let Kind::Keyword(keyword) = kind {
                Some(*keyword)
            } else {
                None
            }
        })
        // never suggest dunder or bools, they should be suggested semantically
        .filter(|k| !k.is_dunder() && !k.is_bool())
        // if next token can be ident, hide all unreserved keywords
        .filter(|k| !(can_be_ident && k.is_unreserved()))
        // filter only valid actions
        .filter(|k| parser.can_act(ctx, &Kind::Keyword(*k)).is_some())
        .collect();

    (keywords, can_be_ident)
}

fn starts_with_unexpected_error(a: &Parser) -> bool {
    a.errors
        .first()
        .is_none_or(|x| x.message.starts_with(UNEXPECTED))
}

impl Context<'_> {
    fn alloc_terminal(&self, t: Terminal) -> &'_ Terminal {
        let idx = self.terminal_arena.push(t);
        &self.terminal_arena[idx]
    }

    fn alloc_slice_and_push(&self, slice: &Option<&[usize]>, element: usize) -> &[usize] {
        let curr_len = slice.map_or(0, |x| x.len());
        let mut new = Vec::with_capacity(curr_len + 1);
        if let Some(inlined_ids) = slice {
            new.extend(*inlined_ids);
        }
        new.push(element);
        self.arena.alloc_slice_clone(new.as_slice())
    }
}

fn new_token_for_injection<'a>(
    kind: Kind,
    prev_span: &Option<Span>,
    next_span: Span,
    ctx: &'a Context,
) -> &'a Terminal {
    let (text, value) = match kind {
        Kind::Keyword(Keyword(kw)) => (kind.text(), Some(Value::String(kw.to_string()))),
        Kind::Ident => {
            let ident = "ident_placeholder";
            (Some(ident), Some(Value::String(ident.into())))
        }
        _ => (kind.text(), None),
    };

    ctx.alloc_terminal(Terminal {
        kind,
        text: text.unwrap_or_default().to_string(),
        value,
        span: Span {
            start: prev_span.map_or(0, |x| x.end),
            end: next_span.start,
        },
        is_placeholder: true,
    })
}

struct StackNode<'p> {
    parent: Option<&'p StackNode<'p>>,

    state: usize,
    value: CSTNode<'p>,
}

#[derive(Clone)]
struct Parser<'s> {
    stack_top: &'s StackNode<'s>,

    /// sum of cost of every error recovery action
    error_cost: u16,

    /// number of nodes pushed to stack since last error
    node_count: u16,

    /// prevents the parser from recovering; set when EOI or a semicolon was skipped
    can_recover: bool,

    errors: Vec<Error>,

    /// A flag that is used to make the parser prefer custom errors over other
    /// recovery paths
    has_custom_error: bool,
}

impl<'s> Parser<'s> {
    fn act(&mut self, ctx: &'s Context, token: &'s Terminal) -> Result<(), ()> {
        // self.print_stack();
        // println!("INPUT: {}", token.text);

        loop {
            // find next action
            let Some(action) = ctx.spec.actions[self.stack_top.state].get(&token.kind) else {
                return Err(());
            };

            match action {
                Action::Shift(next) => {
                    // println!("   --> [shift {next}]");

                    // push on stack
                    self.push_on_stack(ctx, *next, CSTNode::Terminal(token));
                    return Ok(());
                }
                Action::Reduce(reduce) => {
                    self.reduce(ctx, reduce);
                }
            }
        }
    }

    fn reduce(&mut self, ctx: &'s Context, reduce: &'s Reduce) {
        let args = ctx.arena.alloc_slice_fill_with(reduce.cnt, |_| {
            let v = self.stack_top.value;
            self.stack_top = self.stack_top.parent.unwrap();
            v
        });
        args.reverse();

        let value = CSTNode::Production(Production {
            id: reduce.production_id,
            span: get_span_of_nodes(args),
            args,
            inlined_ids: None,
        });

        let nstate = self.stack_top.state;

        let next = *ctx.spec.goto[nstate].get(&reduce.non_term).unwrap();

        // inline (if there is an inlining rule)
        let mut value = value;
        if let CSTNode::Production(production) = value {
            if let Some(inline_position) = ctx.spec.inlines.get(&production.id) {
                let inlined_id = production.id;
                // inline rule found
                let args = production.args;

                value = args[*inline_position as usize];

                // save inlined id
                if let CSTNode::Production(new_prod) = &mut value {
                    new_prod.inlined_ids =
                        Some(ctx.alloc_slice_and_push(&new_prod.inlined_ids, inlined_id));
                }
            } else {
                // place back
                value = CSTNode::Production(production);
            }
        }

        self.push_on_stack(ctx, next, value);

        // println!(
        //     "   --> [reduce {} ::= ({} popped) at {}/{}]",
        //     production, cnt, state, nstate
        // );
        // self.print_stack();
    }

    pub fn push_on_stack(&mut self, ctx: &'s Context, state: usize, value: CSTNode<'s>) {
        let node = StackNode {
            parent: Some(self.stack_top),
            state,
            value,
        };
        self.stack_top = ctx.arena.alloc(node);
    }

    pub fn finish(&self, _ctx: &'s Context) -> Option<CSTNode<'s>> {
        if !self.can_recover || self.has_custom_error {
            return None;
        }

        // pop the EOI from the top of the stack
        assert!(
            matches!(
                &self.stack_top.value,
                CSTNode::Terminal(Terminal {
                    kind: Kind::EOI,
                    ..
                })
            ),
            "expected EOI CST node, got {:?}",
            self.stack_top.value
        );

        let final_node = self.stack_top.parent.unwrap();

        // self.print_stack(_ctx);
        // println!("   --> accept");

        let first = final_node.parent.unwrap();
        assert!(
            matches!(&first.value, CSTNode::Empty),
            "expected empty CST node, found {:?}",
            first.value
        );

        Some(final_node.value)
    }

    /// Lightweight version of act that checks if a token *could* be applied.
    /// Returns next state.
    fn can_act(&self, ctx: &'s Context, token: &Kind) -> Option<usize> {
        let mut state = self.stack_top.state;

        let mut node = &self.stack_top;

        // count of "ghost" stack nodes, which should have been pushed to the stack,
        // but haven't because we don't actually need them there, only need to know
        // how many of them there are
        let mut ghosts = 0;

        loop {
            // find next action
            let action = ctx.spec.actions[state].get(token)?;

            match action {
                Action::Shift(next) => {
                    return Some(*next);
                }
                Action::Reduce(reduce) => {
                    // simulate reduce stack pops
                    // (cancel out any ghost nodes, if there are any)
                    let cancel_out = usize::min(ghosts, reduce.cnt);
                    ghosts -= cancel_out;
                    for _ in 0..(reduce.cnt - cancel_out) {
                        node = node.parent.as_ref().unwrap();
                    }

                    // get state of current stack top
                    // Stack top is node.state, unless we have ghosts. In that case, the
                    // state of node we would have pushed is stored in `state`.
                    let stack_state = if ghosts > 0 { state } else { node.state };

                    state = *ctx.spec.goto[stack_state].get(&reduce.non_term)?;

                    ghosts += 1;
                }
            }
        }
    }

    #[cfg(never)]
    fn print_stack(&self, ctx: &'s Context) {
        let prefix = "STACK: ";

        let mut stack = Vec::new();
        let mut node = Some(self.stack_top);
        while let Some(n) = node {
            stack.push(n);
            node = n.parent.clone();
        }
        stack.reverse();

        let names = stack
            .iter()
            .map(|s| match s.value {
                CSTNode::Empty => format!("Empty"),
                CSTNode::Terminal(term) => format!("{term}"),
                CSTNode::Production(prod) => {
                    let prod_name = &ctx.spec.production_names[prod.id];
                    format!("{}.{}", prod_name.0, prod_name.1)
                }
            })
            .collect::<Vec<_>>();

        let mut states = format!("{:5}", ' ');
        for (index, node) in stack.iter().enumerate() {
            let name_width = names[index].chars().count();
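            states += &format!("  {:<name_width$}", node.state);
        }

        println!("{prefix}{}", names.join("  "));
        println!("{states}");
    }

    // ...

    fn adjusted_cost(&self) -> u16 {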
        let x = self.node_count.saturating_sub(3);
        self.error_cost.saturating_sub(x.saturating_mul(x))
    }

    fn has_recovered(&self) -> bool {
        self.can_recover && self.adjusted_cost() == 0
    }

    fn get_from_top(&self, steps: usize) -> Option<&StackNode<'s>> {
        self.stack_top.step_up(steps)
    }
}

impl<'a> StackNode<'a> {
    fn step_up(&self, steps: usize) -> Option<&StackNode<'a>> {
        let mut node = Some(self);
        for _ in 0..steps {
            match node {
                None => return None,
                Some(n) => node = n.parent,
            }
        }
        node
    }
}

/// Returns the span of syntactically ordered nodes. Panics on empty nodes.
fn get_span_of_nodes(nodes: &[CSTNode]) -> Option<Span> {
    let start = nodes.iter().find_map(|x| match x {
        CSTNode::Terminal(t) => Some(t.span.start),
        CSTNode::Production(p) => Some(p.span?.start),
        CSTNode::Empty => panic!(),
    })?;
    let end = nodes.iter().rev().find_map(|x| match x {
        CSTNode::Terminal(t) => Some(t.span.end),
        CSTNode::Production(p) => Some(p.span?.end),
        CSTNode::Empty => panic!(),
    })?;
    Some(Span { start, end })
}

const PARSER_COUNT_MAX: usize = 10;

const ERROR_COST_INJECT_MAX: u16 = 15;
const ERROR_COST_SKIP: u16 = 3;
const ERROR_COST_CUSTOM_ERROR: u16 = 3;

fn injection_cost(kind: &Kind) -> u16 {
    use Kind::*;

    match kind {
        Ident => 10,
        Substitution => 8,

        // Manual keyword tweaks to encourage some error messages and discourage others.
        Keyword(keywords::Keyword(
            "delete" | "update" | "migration" | "role" | "global" | "administer" | "future"
            | "database" | "serializable" | "REPEATABLE" | "NOT", //  | "if" | "group",
        )) => 100,
        Keyword(keywords::Keyword("insert" | "module" | "extension" | "branch")) => 20,
        Keyword(keywords::Keyword("select" | "property" | "type")) => 10,
        Keyword(_) => 15,

        Dot => 5,
        OpenBrace | OpenBracket => 5,
        OpenParen => 4,

        CloseBrace | CloseBracket | CloseParen => 1,

        Namespace => 10,
        Comma | Colon | Semicolon => 2,
        Eq => 5,

        At => 6,
        IntConst => 8,

        Assign | Arrow => 5,

        _ => 100, // forbidden
    }
}


================================================
FILE: edb/edgeql-parser/src/parser/spec.rs
================================================
use indexmap::IndexMap;

use crate::tokenizer::Kind;

pub struct Spec {
    pub actions: Vec<IndexMap<Kind, Action>>,
    pub goto: Vec<IndexMap<String, usize>>,
    pub inlines: IndexMap<usize, u8>,
    pub production_names: Vec<(String, String)>,
}

#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Action {
    Shift(usize),
    Reduce(Reduce),
}

#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Reduce {
    /// Index of the production in the associated production array
    pub production_id: usize,

    pub non_term: String,

    /// Number of arguments
    pub cnt: usize,
}

#[cfg(feature = "serde")]
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct SpecSerializable {
    pub actions: Vec<Vec<(String, Action)>>,
    pub goto: Vec<Vec<(String, usize)>>,
    pub inlines: Vec<(usize, u8)>,
    pub production_names: Vec<(String, String)>,
}

#[cfg(feature = "serde")]
impl From<SpecSerializable> for Spec {
    fn from(v: SpecSerializable) -> Spec {
        let actions = v
            .actions
            .into_iter()
            .map(|x| x.into_iter().map(|(k, a)| (get_token_kind(&k), a)))
            .map(IndexMap::from_iter)
            .collect();
        let goto = v.goto.into_iter().map(IndexMap::from_iter).collect();
        let inlines = IndexMap::from_iter(v.inlines);

        Spec {
            actions,
            goto,
            inlines,
            production_names: v.production_names,
        }
    }
}

#[cfg(feature = "serde")]
pub(super) fn get_token_kind(token_name: &str) -> Kind {
    use Kind::*;

    match token_name {
        "+" => Add,
        "&" => Ampersand,
        "@" => At,
        ".<" => BackwardLink,
        ".?>" => OptionalLink,
        "}" => CloseBrace,
        "]" => CloseBracket,
        ")" => CloseParen,
        "??" => Coalesce,
        ":" => Colon,
        "," => Comma,
        "++" => Concat,
        "/" => Div,
        "." => Dot,
        "**" => DoubleSplat,
        "=" => Eq,
        "//" => FloorDiv,
        "%" => Modulo,
        "*" => Mul,
        "::" => Namespace,
        "{" => OpenBrace,
        "[" => OpenBracket,
        "(" => OpenParen,
        "|" => Pipe,
        "^" => Pow,
        ";" => Semicolon,
        "-" => Sub,

        "?!=" => DistinctFrom,
        ">=" => GreaterEq,
        "<=" => LessEq,
        "?=" => NotDistinctFrom,
        "!=" => NotEq,
        "<" => Less,
        ">" => Greater,

        "IDENT" => Ident,
        "EOI" | "<$>" => EOI,
        "" => Epsilon,

        "BCONST" => BinStr,
        "FCONST" => FloatConst,
        "ICONST" => IntConst,
        "NFCONST" => DecimalConst,
        "NICONST" => BigIntConst,
        "SCONST" => Str,

        "STARTBLOCK" => StartBlock,
        "STARTEXTENSION" => StartExtension,
        "STARTFRAGMENT" => StartFragment,
        "STARTMIGRATION" => StartMigration,
        "STARTSDLDOCUMENT" => StartSDLDocument,

        "+=" => AddAssign,
        "->" => Arrow,
        ":=" => Assign,
        "-=" => SubAssign,

        "PARAMETER" => Parameter,
        "PARAMETERANDTYPE" => ParameterAndType,
        "SUBSTITUTION" => Substitution,

        "STRINTERPSTART" => StrInterpStart,
        "STRINTERPCONT" => StrInterpCont,
        "STRINTERPEND" => StrInterpEnd,

        _ => {
            let mut token_name = token_name.to_lowercase();

            if let Some(rem) = token_name.strip_prefix("dunder") {
                token_name = format!("__{rem}__");
            }

            let kw = crate::keywords::lookup_all(&token_name)
                .unwrap_or_else(|| panic!("unknown keyword {token_name}"));
            Keyword(kw)
        }
    }
}


================================================
FILE: edb/edgeql-parser/src/position.rs
================================================
use std::fmt;
use std::str::{from_utf8, Utf8Error};

use unicode_width::{UnicodeWidthChar, UnicodeWidthStr};

/// Span of an element in source code
#[derive(Debug, Clone, Copy, Default, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Span {
    /// Byte offset in the original file
    ///
    /// Technically you can read > 4Gb file on 32bit machine so it may
    /// not fit in usize
    pub start: u64,

    /// Byte offset in the original file
    ///
    /// Technically you can read > 4Gb file on 32bit machine so it may
    /// not fit in usize
    pub end: u64,
}
/// Original position of element in source code
#[derive(PartialOrd, Ord, PartialEq, Eq, Clone, Copy, Default, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Pos {
    /// One-based line number
    pub line: usize,
    /// One-based column number
    pub column: usize,
    /// Byte offset in the original file
    ///
    /// Technically you can read > 4Gb file on 32bit machine so it may
    /// not fit in usize
    pub offset: u64,
}

/// This contains position in all forms that EdgeDB needs
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct InflatedPos {
    /// Zero-based line number
    pub line: u64,
    /// Zero-based column number
    pub column: u64,
    /// Zero-based Utf16 column offset
    ///
    /// (this is required by language server protocol, LSP)
    pub utf16column: u64,
    /// Byte offset in the original (UTF-8 encoded) byte buffer
    pub offset: u64,
    /// Character offset in the whole string
    pub char_offset: u64,
}

/// Error calculating InflatedPos
#[derive(Debug, thiserror::Error)]
pub enum InflatingError {
    #[error(transparent)]
    Utf8(Utf8Error),
    #[error("offset out of range")]
    OutOfRange,
}

impl fmt::Debug for Pos {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "Pos({}:{})", self.line, self.column)
    }
}

impl fmt::Display for Pos {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{}:{}", self.line, self.column)
    }
}

impl Span {
    /// Returns a span running from the start of `self` to the end of `right`.
    pub fn combine(self, right: Span) -> Span {
        Span {
            start: self.start,
            end: right.end,
        }
    }

    /// Returns the smallest span that covers both `self` and `other`.
    pub fn extend(self, other: &Span) -> Span {
        Span {
            start: u64::min(self.start, other.start),
            end: u64::max(self.end, other.end),
        }
    }
}

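/// Counts line breaks in `data`, treating `\n`, `\r` and `\r\n` as one break each.
fn new_lines_in_fragment(data: &[u8]) -> u64 {
    // true if the previous byte was `\r`, so that a following `\n` is not counted again
    let mut was_cr = false;
    let mut lines = 0;
    for byte in data {
        match byte {
            b'\n' if was_cr => {
                was_cr = false;
            }
            b'\n' => {
                lines += 1;
            }
            b'\r' => {
                lines += 1;
                was_cr = true;
            }
            _ => {
                was_cr = false;
            }
        }
    }
    lines
}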

impl InflatedPos {
    pub fn from_offset(data: &[u8], offset: u64) -> Result<InflatedPos, InflatingError> {
        let res = Self::from_offsets(data, &[offset as usize])?;
        Ok(res.into_iter().next().unwrap())
    }

    pub fn from_offsets(
        data: &[u8],
        offsets: &[usize],
    ) -> Result<Vec<InflatedPos>, InflatingError> {
        let mut result = Vec::with_capacity(offsets.len());
        // TODO(tailhook) optimize calculation if offsets are growing
        for &offset in offsets {
            if offset > data.len() {
                return Err(InflatingError::OutOfRange);
            }
            let prefix = &data[..offset];
            let prefix_s = from_utf8(prefix).map_err(InflatingError::Utf8)?;
            let line_offset;
            let line;
            if let Some(loff) = prefix_s.rfind(['\r', '\n']) {
                line_offset = loff + 1;
                let mut lines = &prefix[..loff];
                if data[loff] == b'\n' && loff > 0 && data[loff - 1] == b'\r' {
                    lines = &lines[..lines.len() - 1];
                }
                line = new_lines_in_fragment(lines) + 1;
            } else {
                line = 0;
                line_offset = 0;
            };
            let col_s = &prefix_s[line_offset..offset];
            result.push(InflatedPos {
                line,
                column: UnicodeWidthStr::width(col_s) as u64,
                utf16column: col_s.chars().map(|c| c.len_utf16() as u64).sum(),
                offset: offset as u64,
                char_offset: prefix_s.chars().count() as u64,
            });
        }
        Ok(result)
    }

    pub fn from_lines_cols(
        data: &[u8],
        lines_cols: &[(u64, u64)],
    ) -> Result<Vec<InflatedPos>, InflatingError> {
        let mut result = Vec::with_capacity(lines_cols.len());

        let text = from_utf8(data).map_err(InflatingError::Utf8)?;
        let mut text_iter = text.chars().peekable();

        let mut lines_cols = lines_cols.iter().peekable();

        let mut offset: u64 = 0;
        let mut char_offset: u64 = 0;
        let mut lines = 0..;

        for line in &mut lines {
            let mut utf16column = 0;
            let mut column = 0;
            'line: loop {
                // emit all matching points (there will typically be only one)
                loop {
                    if let Some((l, c)) = lines_cols.peek() {
                        if line == *l && *c <= utf16column {
                            result.push(InflatedPos {
                                line,
                                column,
                                utf16column,
                                offset,
                                char_offset,
                            });
                            lines_cols.next();
                        } else {
                            break;
                        }
                    } else {
                        break 'line;
                    }
                }

                // stop if end of line
                let eol = text_iter.peek().is_none_or(|c| *c == '\n' || *c == '\r');
                if eol {
                    break;
                }

                // advance a char
                let char = text_iter.next().unwrap();

                offset += char.len_utf8() as u64;
                utf16column += char.len_utf16() as u64;
                char_offset += 1;
                column += UnicodeWidthChar::width(char).unwrap_or(0) as u64;
            }

            // emit all points whose column is past the end of the line
            while let Some((l, _)) = lines_cols.peek() {
                if line == *l {
                    result.push(InflatedPos {
                        line,
                        column,
                        utf16column,
                        offset,
                        char_offset,
                    });
                    lines_cols.next();
                } else {
                    break;
                }
            }

            if text_iter.peek().is_none() || lines_cols.peek().is_none() {
                break;
            }

            // consume \n or \r\n
            if text_iter.peek().is_some_and(|c| *c == '\r') {
                text_iter.next();
                offset += 1;
                char_offset += 1;
            }
            if text_iter.peek().is_some_and(|c| *c == '\n') {
                text_iter.next();
                offset += 1;
                char_offset += 1;
            }
        }

        // emit all points whose line is past the end of the buffer
        let last_line = lines.next().unwrap().saturating_sub(1);
        for _ in lines_cols {
            result.push(InflatedPos {
                line: last_line,
                column: 0,
                utf16column: 0,
                offset,
                char_offset,
            });
        }

        Ok(result)
    }

    pub fn deflate(self) -> Pos {
        Pos {
            line: self.line as usize + 1,
            column: self.column as usize + 1,
            offset: self.offset,
        }
    }
}

#[cfg(test)]
mod test {
    use super::{new_lines_in_fragment, InflatedPos};

    fn mkpos(s: &str, off: usize) -> InflatedPos {
        InflatedPos::from_offsets(s.as_bytes(), &[off]).unwrap()[0]
    }

    fn mkpos2(s: &str, line: u64, col: u64) -> InflatedPos {
        InflatedPos::from_lines_cols(s.as_bytes(), &[(line, col)]).unwrap()[0]
    }

    #[track_caller]
    fn mkpos_both(s: &str, off: usize) -> InflatedPos {
        let pos = mkpos(s, off);
        let pos2 = mkpos2(s, pos.line, pos.utf16column);
        assert_eq!(pos, pos2);
        pos
    }

    #[test]
    fn ascii_line() {
        let text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit,";
        for off in 0..text.len() {
            let pos = mkpos(text, off);
            let off = off as u64;
            assert_eq!(pos.line, 0);
            assert_eq!(pos.column, off);
            assert_eq!(pos.utf16column, off);
            assert_eq!(pos.offset, off);
            assert_eq!(pos.char_offset, off);

            let pos2 = mkpos2(text, pos.line, pos.utf16column);
            assert_eq!(pos.line, pos2.line);
            assert_eq!(pos.column, pos2.column);
            assert_eq!(pos.utf16column, pos2.utf16column);
            assert_eq!(pos.offset, pos2.offset);
            assert_eq!(pos.char_offset, pos2.char_offset);
        }
    }

    #[test]
    fn ascii_multi_line() {
        let text = "line1\nline2";
        for off in 6..text.len() {
            let pos = mkpos(text, off);
            let off = off as u64;
            assert_eq!(pos.line, 1);
            assert_eq!(pos.column, off - 6);
            assert_eq!(pos.utf16column, off - 6);
            assert_eq!(pos.offset, off);
            assert_eq!(pos.char_offset, off);

            let pos2 = mkpos2(text, pos.line, pos.utf16column);
            assert_eq!(pos.line, pos2.line);
            assert_eq!(pos.column, pos2.column);
            assert_eq!(pos.utf16column, pos2.utf16column);
            assert_eq!(pos.offset, pos2.offset);
            assert_eq!(pos.char_offset, pos2.char_offset);
        }
    }

    #[test]
    fn line_endings() {
        fn count(s: &str) -> u64 {
            new_lines_in_fragment(s.as_bytes())
        }
        assert_eq!(count("line1\nline2\nline3"), 2);
        assert_eq!(count("line1\rline2\rline3"), 2);
        assert_eq!(count("line1\r\nline2\r\nline3"), 2);
        assert_eq!(count("line1\rline2\r\nline3\n"), 3);
        assert_eq!(count("line1\nline2\rline3\r\n"), 3);
        assert_eq!(count("line1\n\rline2\r\rline3\r"), 5);
    }

    #[test]
    fn char_offsets_00() {
        let pos = mkpos_both("bomb = 'b'", 9);
        assert_eq!(pos.line, 0);
        assert_eq!(pos.column, 9);
        assert_eq!(pos.utf16column, 9);
        assert_eq!(pos.offset, 9);
        assert_eq!(pos.char_offset, 9);
    }

    #[test]
    fn char_offsets_01() {
        assert!('💣'.len_utf16() == 2);

        // bomb takes 4 bytes when encoded as utf8
        let pos = mkpos("bomb = '💣'", 12);
        assert_eq!(pos.line, 0);
        assert_eq!(pos.column, 10); // bomb takes two columns
        assert_eq!(pos.utf16column, 10); // and also two UTF-16 code units
        assert_eq!(pos.offset, 12);
        assert_eq!(pos.char_offset, 9);
    }

    #[test]
    fn char_offsets_02() {
        let pos = mkpos_both("line1\nbomb = '💣'", 18);
        assert_eq!(pos.line, 1);
        assert_eq!(pos.column, 10);
        assert_eq!(pos.utf16column, 10);
        assert_eq!(pos.offset, 18);
        assert_eq!(pos.char_offset, 15);
    }
    #[test]
    fn char_offsets_03() {
        let pos = mkpos_both("bomb = '💣'\nline1", 18);
        assert_eq!(pos.line, 1);
        assert_eq!(pos.column, 4);
        assert_eq!(pos.utf16column, 4);
        assert_eq!(pos.offset, 18);
        assert_eq!(pos.char_offset, 15);
    }
    #[test]
    fn char_offsets_04() {
        let pos = mkpos_both("letter = 'Ф'", 12);
        assert_eq!(pos.line, 0);
        assert_eq!(pos.column, 11);
        assert_eq!(pos.utf16column, 11);
        assert_eq!(pos.offset, 12);
        assert_eq!(pos.char_offset, 11);
    }
    #[test]
    fn char_offsets_05() {
        let pos = mkpos_both("line1\nletter = 'Ф'", 18);
        assert_eq!(pos.line, 1);
        assert_eq!(pos.column, 11);
        assert_eq!(pos.utf16column, 11);
        assert_eq!(pos.offset, 18);
        assert_eq!(pos.char_offset, 17);
    }
    #[test]
    fn char_offsets_06() {
        let pos = mkpos_both("letter = 'Ф'\nline1", 18);
        assert_eq!(pos.line, 1);
        assert_eq!(pos.column, 4);
        assert_eq!(pos.utf16column, 4);
        assert_eq!(pos.offset, 18);
        assert_eq!(pos.char_offset, 17);
    }
    #[test]
    fn char_offsets_07() {
        let pos = mkpos_both("letter = 'Ｈ'", 13);
        assert_eq!(pos.line, 0);
        assert_eq!(pos.column, 12);
        assert_eq!(pos.utf16column, 11);
        assert_eq!(pos.offset, 13);
        assert_eq!(pos.char_offset, 11);
    }
    #[test]
    fn char_offsets_08() {
        let pos = mkpos_both("line1\nletter = 'Ｈ'", 19);
        assert_eq!(pos.line, 1);
        assert_eq!(pos.column, 12);
        assert_eq!(pos.utf16column, 11);
        assert_eq!(pos.offset, 19);
        assert_eq!(pos.char_offset, 17);
    }
    #[test]
    fn char_offsets_09() {
        let pos = mkpos_both("letter = 'Ｈ'\nline1", 19);
        assert_eq!(pos.line, 1);
        assert_eq!(pos.column, 4);
        assert_eq!(pos.utf16column, 4);
        assert_eq!(pos.offset, 19);
        assert_eq!(pos.char_offset, 17);
    }
    #[test]
    fn char_offsets_10() {
        let pos = mkpos_both("hello\r\nworld", 9);
        assert_eq!(pos.line, 1);
        assert_eq!(pos.column, 2);
        assert_eq!(pos.utf16column, 2);
        assert_eq!(pos.offset, 9);
        assert_eq!(pos.char_offset, 9);
    }
    #[test]
    fn char_offsets_11() {
        let pos = mkpos2("hello\r\nworld", 0, 10);
        assert_eq!(pos.line, 0);
        assert_eq!(pos.column, 5);
        assert_eq!(pos.utf16column, 5);
        assert_eq!(pos.offset, 5);
        assert_eq!(pos.char_offset, 5);
    }
}


================================================
FILE: edb/edgeql-parser/src/preparser.rs
================================================
use memchr::memmem::find;

#[derive(Debug, PartialEq)]
pub struct Continuation {
    position: usize,
    braces: Vec<u8>,
}

/// Returns the index just past the first top-level semicolon, or a
/// `Continuation` describing where to resume the search once more data
/// is available
pub fn full_statement(
    data: &[u8],
    continuation: Option<Continuation>,
) -> Result<usize, Continuation> {
    let mut iter = data.iter().enumerate().peekable();
    if let Some(cont) = continuation.as_ref() {
        if cont.position > 0 {
            iter.nth(cont.position - 1);
        }
    }
    let mut braces_buf = continuation
        .map(|cont| cont.braces)
        .unwrap_or_else(|| Vec::with_capacity(8));
    'outer: while let Some((idx, b)) = iter.next() {
        match b {
            b'"' => {
                while let Some((_, b)) = iter.next() {
                    match b {
                        b'\\' => {
                            // skip any next char, even quote
                            iter.next();
                        }
                        b'"' => continue 'outer,
                        _ => continue,
                    }
                }
                return Err(Continuation {
                    position: idx,
                    braces: braces_buf,
                });
            }
            b'\'' => {
                while let Some((_, b)) = iter.next() {
                    match b {
                        b'\\' => {
                            // skip any next char, even quote
                            iter.next();
                        }
                        b'\'' => continue 'outer,
                        _ => continue,
                    }
                }
                return Err(Continuation {
                    position: idx,
                    braces: braces_buf,
                });
            }
            b'r' => {
                if matches!(iter.peek(), Some((_, b'b'))) {
                    // rb'something' -- skip `b` but match on quote
                    iter.next();
                };
                match iter.peek() {
                    None => {
                        return Err(Continuation {
                            position: idx,
                            braces: braces_buf,
                        });
                    }
                    Some((_, start @ (b'\'' | b'"'))) => {
                        let end = *start;
                        iter.next();
                        for (_, b) in iter.by_ref() {
                            if b == end {
                                continue 'outer;
                            }
                        }
                        return Err(Continuation {
                            position: idx,
                            braces: braces_buf,
                        });
                    }
                    Some((_, _)) => continue,
                }
            }
            b'`' => {
                for (_, b) in iter.by_ref() {
                    match b {
                        b'`' => continue 'outer,
                        _ => continue,
                    }
                }
                return Err(Continuation {
                    position: idx,
                    braces: braces_buf,
                });
            }
            b'#' => {
                for (_, &b) in iter.by_ref() {
                    if b == b'\n' {
                        continue 'outer;
                    }
                }
                return Err(Continuation {
                    position: idx,
                    braces: braces_buf,
                });
            }
            b'$' => {
                match iter.next() {
                    Some((end_idx, b'$')) => {
                        let end = find(&data[end_idx + 1..], b"$$");
                        if let Some(end) = end {
                            iter.nth(end + end_idx - idx);
                            continue 'outer;
                        }
                        return Err(Continuation {
                            position: idx,
                            braces: braces_buf,
                        });
                    }
                    Some((_, b'A'..=b'Z')) | Some((_, b'a'..=b'z')) | Some((_, b'_')) => {}
                    // Not a dollar-quote
                    Some((_, _)) => continue 'outer,
                    None => {
                        return Err(Continuation {
                            position: idx,
                            braces: braces_buf,
                        })
                    }
                }
                loop {
                    let (c_idx, c) = if let Some(pair) = iter.peek() {
                        *pair
                    } else {
                        return Err(Continuation {
                            position: idx,
                            braces: braces_buf,
                        });
                    };
                    match c {
                        b'$' => {
                            let end_idx = c_idx + 1;
                            let marker_size = end_idx - idx;
                            if let Some(end) = find(&data[end_idx..], &data[idx..end_idx]) {
                                iter.nth(1 + end + marker_size - 1);
                                continue 'outer;
                            }
                            return Err(Continuation {
                                position: idx,
                                braces: braces_buf,
                            });
                        }
                        b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'_' => {}
                        // Not a dollar-quote
                        _ => continue 'outer,
                    }
                    iter.next();
                }
            }
            b'{' => braces_buf.push(b'}'),
            b'(' => braces_buf.push(b')'),
            b'[' => braces_buf.push(b']'),
            b'}' | b')' | b']' if braces_buf.last() == Some(b) => {
                braces_buf.pop();
            }
            b';' if braces_buf.is_empty() => return Ok(idx + 1),
            _ => continue,
        }
    }
    Err(Continuation {
        position: data.len(),
        braces: braces_buf,
    })
}

/// Returns true if the text has no partial statements
///
/// This is equivalent to `text.trim().is_empty()`, except that it also
/// ignores EdgeQL comments, stray semicolons, and a byte order mark.
///
/// This is useful for finding out whether the last part of text split by
/// `full_statement` contains anything relevant. Before this function
/// existed, a comment could not be added at the end of an EdgeQL file.
pub fn is_empty(text: &str) -> bool {
    let mut iter = text.chars();
    loop {
        let cur_char = match iter.next() {
            Some(c) => c,
            None => return true,
        };
        match cur_char {
            '\u{feff}' | '\r' | '\t' | '\n' | ' ' | ';' => continue,
            // Comment
            '#' => {
                for c in iter.by_ref() {
                    if c == '\r' || c == '\n' {
                        break;
                    }
                }
                continue;
            }
            _ => return false,
        }
    }
}


================================================
FILE: edb/edgeql-parser/src/schema_file.rs
================================================
use crate::position::Pos;
use crate::tokenizer;
use crate::tokenizer::Tokenizer;

#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum SchemaFileError {
    #[error("{}: bracket `{}` has never been closed", pos, kind)]
    MissingBracket { pos: Pos, kind: char },
    #[error(
        "{}: closing bracket mismatch, opened `{}` at {}, encountered `{}`",
        closing_pos,
        opened,
        opened_pos,
        encountered
    )]
    BracketMismatch {
        opened: char,
        opened_pos: Pos,
        closing_pos: Pos,
        encountered: char,
    },
    #[error("{}: extra closing bracket `{}`", pos, kind)]
    ExtraBracket { pos: Pos, kind: char },
    #[error("{}: tokenizer error: {}", pos, error)]
    TokenizerError { pos: Pos, error: String },
}

fn match_bracket(
    open: char,
    encountered: char,
    pos: Pos,
    brackets: &mut Vec<(char, char, Pos)>,
) -> Result<(), SchemaFileError> {
    use SchemaFileError::*;

    match brackets.pop() {
        Some((_, exp, _)) if exp == encountered => Ok(()),
        Some((opened, _, opened_pos)) => Err(BracketMismatch {
            opened,
            opened_pos,
            closing_pos: pos,
            encountered,
        }),
        None => Err(ExtraBracket { pos, kind: open }),
    }
}

pub fn validate(text: &str) -> Result<(), SchemaFileError> {
    use tokenizer::Kind::*;
    use SchemaFileError::*;

    let mut token_stream = Tokenizer::new(text);
    let mut brackets = Vec::new();
    loop {
        let pos = token_stream.current_pos();
        match token_stream.next() {
            Some(Ok(tok)) => match tok.kind {
                OpenParen => brackets.push(('(', ')', pos)),
                OpenBrace => brackets.push(('{', '}', pos)),
                OpenBracket => brackets.push(('[', ']', pos)),
                CloseParen => match_bracket('(', ')', pos, &mut brackets)?,
                CloseBrace => match_bracket('{', '}', pos, &mut brackets)?,
                CloseBracket => match_bracket('[', ']', pos, &mut brackets)?,
                _ => {}
            },
            None => break,
            Some(Err(e)) => {
                return Err(TokenizerError {
                    pos: token_stream.current_pos(),
                    error: e.message,
                });
            }
        }
    }
    if let Some((kind, _, pos)) = brackets.pop() {
        return Err(MissingBracket { kind, pos });
    }
    Ok(())
}

#[cfg(test)]
mod test {
    use super::validate;

    fn check(s: &str) -> String {
        validate(s)
            .map(|_| String::new())
            .map_err(|e| {
                let s = e.to_string();
                assert!(!s.is_empty());
                s
            })
            .unwrap_or_else(|e| e)
    }

    #[test]
    fn test_normal() {
        assert_eq!(check("alias X := (SELECT 1)"), "");
    }

    #[test]
    fn test_braces() {
        assert_eq!(
            check("type X { property y := '}';"),
            "1:8: bracket `{` has never been closed"
        );

        assert_eq!(
            check("type X { property y -> z; )"),
            "1:27: closing bracket mismatch, \
            opened `{` at 1:8, encountered `)`"
        );

        assert_eq!(
            check("type X\nproperty y; }"),
            "2:13: extra closing bracket `{`"
        );

        assert_eq!(check("type X { property y := (select 1)}"), "");

        assert_eq!(
            check("type X { property y := (select 1})"),
            "1:33: closing bracket mismatch, \
            opened `(` at 1:24, encountered `}`"
        );

        assert_eq!(
            check("type X { property y := (select 1"),
            "1:24: bracket `(` has never been closed"
        );

        assert_eq!(
            check("type X { property y := (select 1)}}"),
            "1:35: extra closing bracket `{`"
        );

        assert_eq!(check("type X { property y := .z[1]}"), "");
    }

    #[test]
    fn test_str() {
        assert_eq!(
            check("create type X { \"} "),
            "1:17: tokenizer error: \
                unterminated string, quoted by `\"`"
        );
    }
}


================================================
FILE: edb/edgeql-parser/src/tokenizer.rs
================================================
use std::borrow::Cow;
use std::fmt;
use std::str::CharIndices;

use bigdecimal::BigDecimal;
use memchr::memmem::find;

use crate::keywords::{self, Keyword};
use crate::position::{Pos, Span};
use crate::validation::Validator;

// Current max keyword length is 10, but we're reserving some space
pub const MAX_KEYWORD_LENGTH: usize = 16;

#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Token<'a> {
    pub kind: Kind,
    pub text: Cow<'a, str>,

    /// Parsed during validation.
    pub value: Option<Value>,

    pub span: Span,
}

#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Value {
    String(String),
    Int(i64),
    Float(f64),
    Bytes(Vec<u8>),

    /// Radix 16
    BigInt(String),
    Decimal(BigDecimal),
}

#[derive(Debug, Clone)]
pub struct Error {
    pub message: String,
    pub span: Span,
    pub hint: Option<String>,
    pub details: Option<String>,
}

impl Error {
    pub fn new<S: ToString>(message: S) -> Self {
        Error {
            message: message.to_string(),
            span: Span::default(),
            hint: None,
            details: None,
        }
    }

    pub fn with_span(mut self, span: Span) -> Self {
        self.span = span;
        self
    }

    pub fn default_span_to(mut self, span: Span) -> Self {
        if self.span == Span::default() {
            self.span = span;
        }
        self
    }
}

#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum Kind {
    Assign,           // :=
    SubAssign,        // -=
    AddAssign,        // +=
    Arrow,            // ->
    Coalesce,         // ??
    Namespace,        // ::
    BackwardLink,     // .<
    OptionalLink,     // .?>
    FloorDiv,         // //
    Concat,           // ++
    GreaterEq,        // >=
    LessEq,           // <=
    NotEq,            // !=
    NotDistinctFrom,  // ?=
    DistinctFrom,     // ?!=
    Comma,            // ,
    OpenParen,        // (
    CloseParen,       // )
    OpenBracket,      // [
    CloseBracket,     // ]
    OpenBrace,        // {
    CloseBrace,       // }
    Dot,              // .
    Semicolon,        // ;
    Colon,            // :
    Add,              // +
    Sub,              // -
    DoubleSplat,      // **
    Mul,              // *
    Div,              // /
    Modulo,           // %
    Pow,              // ^
    Less,             // <
    Greater,          // >
    Eq,               // =
    Ampersand,        // &
    Pipe,             // |
    At,               // @
    Parameter,        // $something, $`something`
    ParameterAndType, // $something
    DecimalConst,
    FloatConst,
    IntConst,
    BigIntConst,
    BinStr, // b"xx", b'xx'
    Str,    // "xx", 'xx', r"xx", r'xx', $$xx$$

    StrInterpStart, // "xx\(, 'xx\(
    StrInterpCont,  // )xx\(
    StrInterpEnd,   // )xx", )xx'

    BacktickName, // `xx`
    Substitution, // \(name)

    #[cfg_attr(feature = "serde", serde(deserialize_with = "deserialize_keyword"))]
    Keyword(Keyword),

    Ident,
    EOI,     // end of input
    Epsilon, //  (needed for LR parser)

    StartBlock,
    StartExtension,
    StartFragment,
    StartMigration,
    StartSDLDocument,
}

#[derive(Debug, PartialEq, Eq, Clone, Copy)]
struct TokenStub<'a> {
    pub kind: Kind,
    pub text: &'a str,
}

#[derive(Debug, PartialEq)]
pub struct Tokenizer<'a> {
    buf: &'a str,
    position: Pos,
    off: usize,
    dot: bool,
    next_state: Option<(usize, TokenStub<'a>, usize, Pos, Pos)>,
    keyword_buf: String,
    // We maintain a stack of the starting string characters and
    // parentheses nesting level for all our open string
    // interpolations, since we need to match the correct one when
    // closing them.
    str_interp_stack: Vec<(String, usize)>,
    // The number of currently open parentheses. If we see a close
    // paren when there are no open parens *and* we are inside a
    // string interpolation, we close it.
    open_parens: usize,
}

#[derive(Clone, Debug, PartialEq)]
pub struct Checkpoint {
    position: Pos,
    off: usize,
    dot: bool,
}

impl<'a> Iterator for Tokenizer<'a> {
    type Item = Result<Token<'a>, Error>;

    fn next(&mut self) -> Option<Self::Item> {
        let start = self.current_pos().offset;

        Some(
            self.read_token()?
                .map(|(token, end)| {
                    let end = end.offset;
                    Token {
                        kind: token.kind,
                        text: token.text.into(),
                        value: None,
                        span: Span { start, end },
                    }
                })
                .map_err(|e| {
                    let end = self.position.offset;
                    e.with_span(Span { start, end })
                }),
        )
    }
}

impl<'a> Tokenizer<'a> {
    pub fn new(s: &str) -> Tokenizer {
        let mut me = Tokenizer {
            buf: s,
            position: Pos {
                line: 1,
                column: 1,
                offset: 0,
            },
            off: 0,
            dot: false,
            next_state: None,
            // Current max keyword length is 10, but we're reserving some
            // space
            keyword_buf: String::with_capacity(MAX_KEYWORD_LENGTH),
            str_interp_stack: Vec::new(),
            open_parens: 0,
        };
        me.skip_whitespace();
        me
    }

    /// Start a stream with a modified position
    ///
    /// Note: we assume that the current position is at the start of slice `s`
    pub fn new_at(s: &str, position: Pos) -> Tokenizer {
        let mut me = Tokenizer {
            buf: s,
            position,
            off: 0,
            dot: false,
            next_state: None,
            keyword_buf: String::with_capacity(MAX_KEYWORD_LENGTH),
            // XXX: If we are in the middle of an interpolated string we will have trouble
            str_interp_stack: Vec::new(),
            open_parens: 0,
        };
        me.skip_whitespace();
        me
    }

    pub fn validated_values(self) -> Validator<'a> {
        Validator::new(self)
    }

    pub fn checkpoint(&self) -> Checkpoint {
        Checkpoint {
            position: self.position,
            off: self.off,
            dot: self.dot,
        }
    }

    pub fn reset(&mut self, checkpoint: Checkpoint) {
        self.position = checkpoint.position;
        self.off = checkpoint.off;
        self.dot = checkpoint.dot;
    }

    pub fn current_pos(&self) -> Pos {
        self.position
    }

    fn read_token(&mut self) -> Option<Result<(TokenStub<'a>, Pos), Error>> {
        use self::Kind::*;

        // This quickly resets the stream one token back
        // (the most common kind of reset, and one that is used quite often)
        if let Some((at, tok, off, end, next)) = self.next_state {
            if at == self.off {
                self.off = off;
                self.position = next;
                return Some(Ok((tok, end)));
            }
        }

        let old_pos = self.off;
        let (kind, len) = match self.peek_token()? {
            Ok(x) => x,
            Err(e) => return Some(Err(e)),
        };

        match kind {
            StrInterpStart => {
                let start = self.buf[self.off..].chars().next()?;
                self.str_interp_stack.push((start.into(), self.open_parens));
            }
            StrInterpEnd => {
                self.str_interp_stack.pop();
            }
            OpenParen => {
                self.open_parens += 1;
            }
            CloseParen => {
                if self.open_parens > 0 {
                    self.open_parens -= 1;
                }
            }
            _ => {}
        }

        // note we may want to get rid of "update_position" here as it's
        // faster to update 'as you go', but this is easier to get right first
        self.update_position(len);
        self.dot = matches!(kind, Kind::Dot);
        let value = &self.buf[self.off - len..self.off];
        let end = self.position;

        self.skip_whitespace();
        let token = TokenStub { kind, text: value };
        // This is for quick reset on token back
        self.next_state = Some((old_pos, token, self.off, end, self.position));
        Some(Ok((token, end)))
    }

    fn peek_token(&mut self) -> Option<Result<(Kind, usize), Error>> {
        let tail = &self.buf[self.off..];
        let mut iter = tail.char_indices();

        let (_, cur_char) = iter.next()?;
        Some(self.peek_token_inner(cur_char, tail, &mut iter))
    }

    fn peek_token_inner(
        &mut self,
        cur_char: char,
        tail: &str,
        iter: &mut CharIndices<'_>,
    ) -> Result<(Kind, usize), Error> {
        use self::Kind::*;

        match cur_char {
            ':' => match iter.next() {
                Some((_, '=')) => Ok((Assign, 2)),
                Some((_, ':')) => Ok((Namespace, 2)),
                _ => Ok((Colon, 1)),
            },
            '-' => match iter.next() {
                Some((_, '>')) => Ok((Arrow, 2)),
                Some((_, '=')) => Ok((SubAssign, 2)),
                _ => Ok((Sub, 1)),
            },
            '>' => match iter.next() {
                Some((_, '=')) => Ok((GreaterEq, 2)),
                _ => Ok((Greater, 1)),
            },
            '<' => match iter.next() {
                Some((_, '=')) => Ok((LessEq, 2)),
                _ => Ok((Less, 1)),
            },
            '+' => match iter.next() {
                Some((_, '=')) => Ok((AddAssign, 2)),
                Some((_, '+')) => Ok((Concat, 2)),
                _ => Ok((Add, 1)),
            },
            '/' => match iter.next() {
                Some((_, '/')) => Ok((FloorDiv, 2)),
                _ => Ok((Div, 1)),
            },
            '.' => match iter.next() {
                Some((_, '<')) => Ok((BackwardLink, 2)),
                Some((_, '?')) => {
                    if let Some((_, '>')) = iter.next() {
                        Ok((OptionalLink, 3))
                    } else {
                        Err(Error::new(
                            "`.?` is not an operator, \
                                did you mean `.?>` ?",
                        ))
                    }
                }
                _ => Ok((Dot, 1)),
            },
            '?' => match iter.next() {
                Some((_, '?')) => Ok((Coalesce, 2)),
                Some((_, '=')) => Ok((NotDistinctFrom, 2)),
                Some((_, '!')) => {
                    if let Some((_, '=')) = iter.next() {
                        Ok((DistinctFrom, 3))
                    } else {
                        Err(Error::new(
                            "`?!` is not an operator, \
                                did you mean `?!=` ?",
                        ))
                    }
                }
                _ => Err(Error::new(
                    "Bare `?` is not an operator, \
                            did you mean `?=` or `??` ?",
                )),
            },
            '!' => match iter.next() {
                Some((_, '=')) => Ok((NotEq, 2)),
                _ => Err(Error::new(
                    "Bare `!` is not an operator, \
                            did you mean `!=`?",
                )),
            },
            '"' | '\'' => self.parse_string(0, false, false),
            '`' => {
                while let Some((idx, c)) = iter.next() {
                    if c == '`' {
                        if let Some((_, '`')) = iter.next() {
                            continue;
                        }
                        let val = &tail[..idx + 1];
                        if val.starts_with("`@") {
                            return Err(Error::new(
                                "backtick-quoted name cannot \
                                    start with char `@`",
                            ));
                        }
                        if val.starts_with("`$") {
                            return Err(Error::new(
                                "backtick-quoted name cannot \
                                    start with char `$`",
                            ));
                        }
                        if val.contains("::") {
                            return Err(Error::new(
                                "backtick-quoted name cannot \
                                    contain `::`",
                            ));
                        }
                        if val.starts_with("`__") && val.ends_with("__`") {
                            return Err(Error::new(
                                "backtick-quoted names surrounded by double \
                                    underscores are forbidden",
                            ));
                        }
                        if idx == 1 {
                            return Err(Error::new("backtick quotes cannot be empty"));
                        }
                        return Ok((BacktickName, idx + 1));
                    }
                    check_prohibited(c, false)?;
                }
                Err(Error::new("unterminated backtick name"))
            }
            '=' => Ok((Eq, 1)),
            ',' => Ok((Comma, 1)),
            '(' => Ok((OpenParen, 1)),
            ')' => match self.str_interp_stack.last() {
                Some((delim, paren_count)) if *paren_count == self.open_parens => {
                    self.parse_string_interp_cont(delim)
                }
                _ => Ok((CloseParen, 1)),
            },
            '[' => Ok((OpenBracket, 1)),
            ']' => Ok((CloseBracket, 1)),
            '{' => Ok((OpenBrace, 1)),
            '}' => Ok((CloseBrace, 1)),
            ';' => Ok((Semicolon, 1)),
            '*' => match iter.next() {
                Some((_, '*')) => Ok((DoubleSplat, 2)),
                _ => Ok((Mul, 1)),
            },
            '%' => Ok((Modulo, 1)),
            '^' => Ok((Pow, 1)),
            '&' => Ok((Ampersand, 1)),
            '|' => Ok((Pipe, 1)),
            '@' => Ok((At, 1)),
            c if c == '_' || c.is_alphabetic() => {
                let end_idx = loop {
                    match iter.next() {
                        Some((idx, '"')) | Some((idx, '\'')) => {
                            let prefix = &tail[..idx];
                            let (raw, binary) = match prefix {
                                "r" => (true, false),
                                "b" => (false, true),
                                "rb" => (true, true),
                                "br" => (true, true),
                                _ => {
                                    return Err(Error::new(format_args!(
                                        "prefix {prefix:?} \
                                    is not allowed for strings, \
                                    allowed: `b`, `r`"
                                    )))
                                }
                            };
                            return self.parse_string(idx, raw, binary);
                        }
                        Some((idx, '`')) => {
                            let prefix = &tail[..idx];
                            return Err(Error::new(format_args!(
                                "prefix {prefix:?} is not \
                                allowed for field names, perhaps missing \
                                comma or dot?"
                            )));
                        }
                        Some((_, c)) if c == '_' || c.is_alphanumeric() => continue,
                        Some((idx, _)) => break idx,
                        None => break self.buf.len() - self.off,
                    }
                };
                let val = &tail[..end_idx];
                if let Some(keyword) = self.as_keyword(val) {
                    Ok((Keyword(keyword), end_idx))
                } else if val.starts_with("__") && val.ends_with("__") {
                    return Err(Error::new(
                        "identifiers surrounded by double \
                            underscores are forbidden",
                    ));
                } else {
                    return Ok((Ident, end_idx));
                }
            }
            '0'..='9' => {
                if self.dot {
                    let len = loop {
                        match iter.next() {
                            Some((_, '0'..='9')) => continue,
                            Some((_, c)) if c.is_alphabetic() => {
                                return Err(Error::new(format_args!(
                                    "unexpected char {c:?}, \
                                        only integers are allowed after dot \
                                        (for tuple access)"
                                )));
                            }
                            Some((idx, _)) => break idx,
                            None => break self.buf.len() - self.off,
                        }
                    };
                    if cur_char == '0' && len > 1 {
                        return Err(Error::new("leading zeros are not allowed in numbers"));
                    }
                    Ok((IntConst, len))
                } else {
                    self.parse_number()
                }
            }
            '$' => {
                let mut has_letter = false;
                if let Some((_, c)) = iter.next() {
                    match c {
                        '$' => {
                            let suffix = &self.buf[self.off + 2..];
                            let end = find(suffix.as_bytes(), b"$$");
                            if let Some(end) = end {
                                for c in self.buf[self.off + 2..][..end].chars() {
                                    check_prohibited(c, false)?;
                                }
                                return Ok((Str, 2 + end + 2));
                            } else {
                                return Err(Error::new("unterminated string started with $$"));
                            }
                        }
                        '`' => {
                            while let Some((idx, c)) = iter.next() {
                                if c == '`' {
                                    if let Some((_, '`')) = iter.next() {
                                        continue;
                                    }
                                    let var = &tail[..idx + 1];
                                    if var.starts_with("$`@") {
                                        return Err(Error::new(
                                            "backtick-quoted argument \
                                                cannot start with char `@`",
                                        ));
                                    }
                                    if var.contains("::") {
                                        return Err(Error::new(
                                            "backtick-quoted argument \
                                                cannot contain `::`",
                                        ));
                                    }
                                    if var.starts_with("$`__") && var.ends_with("__`") {
                                        return Err(Error::new(
                                            "backtick-quoted arguments \
                                                surrounded by double \
                                                underscores are forbidden",
                                        ));
                                    }
                                    if idx == 2 {
                                        return Err(Error::new(
                                            "backtick-quoted argument cannot be empty",
                                        ));
                                    }
                                    return Ok((Parameter, idx + 1));
                                }
                                check_prohibited(c, false)?;
                            }
                            return Err(Error::new("unterminated backtick argument"));
                        }
                        '0'..='9' => {}
                        c if c.is_alphabetic() || c == '_' => {
                            has_letter = true;
                        }
                        _ => return Err(Error::new("bare $ is not allowed")),
                    }
                } else {
                    return Err(Error::new("bare $ is not allowed"));
                }
                let end_idx = loop {
                    match iter.next() {
                        Some((end_idx, '$')) => {
                            let msize = end_idx + 1;
                            let marker = &self.buf[self.off..][..msize];
                            if let Some('0'..='9') = marker[1..].chars().next() {
                                return Err(Error::new("dollar quote must not start with a digit"));
                            }
                            if !marker.is_ascii() {
                                return Err(Error::new("dollar quote supports only ascii chars"));
                            }
                            if let Some(end) =
                                find(&self.buf.as_bytes()[self.off + msize..], marker.as_bytes())
                            {
                                let data = &self.buf[self.off + msize..][..end];
                                for c in data.chars() {
                                    check_prohibited(c, false)?;
                                }
                                return Ok((Str, msize + end + msize));
                            } else {
                                return Err(Error::new(format_args!(
                                    "unterminated string started with {marker:?}"
                                )));
                            }
                        }
                        Some((_, '0'..='9')) => continue,
                        Some((_, c)) if c.is_alphabetic() || c == '_' => {
                            has_letter = true;
                            continue;
                        }
                        Some((end_idx, _)) => break end_idx,
                        None => break self.buf.len() - self.off,
                    }
                };
                if has_letter {
                    let name = &tail[1..];
                    if let Some('0'..='9') = name.chars().next() {
                        return Err(Error::new(format_args!(
                            "the {:?} is not a valid \
                            argument, either name starting with letter \
                            or only digits are expected",
                            &tail[..end_idx]
                        )));
                    }
                }
                Ok((Parameter, end_idx))
            }
            '\\' => match iter.next() {
                Some((_, '(')) => {
                    let len = loop {
                        match iter.next() {
                            Some((_, '_')) => continue,
                            Some((_, c)) if c.is_alphanumeric() => continue,
                            Some((idx, ')')) => break idx,
                            Some((_, _)) => {
                                return Err(Error::new(
                                    "only alphanumerics are allowed in \
                                     \\(name) token",
                                ));
                            }
                            None => {
                                return Err(Error::new("unclosed \\(name) token"));
                            }
                        }
                    };
                    Ok((Substitution, len + 1))
                }
                _ => Err(Error::new(format_args!(
                    "unexpected character {cur_char:?}",
                ))),
            },
            _ => Err(Error::new(format_args!(
                "unexpected character {cur_char:?}",
            ))),
        }
    }

    fn parse_string(
        &self,
        quote_off: usize,
        raw: bool,
        binary: bool,
    ) -> Result<(Kind, usize), Error> {
        let mut iter = self.buf[self.off + quote_off..].char_indices();
        let open_quote = iter.next().unwrap().1;
        if binary {
            while let Some((idx, c)) = iter.next() {
                match c {
                    '\\' if !raw => match iter.next() {
                        // skip any next char, even quote
                        Some((_, _)) => continue,
                        None => break,
                    },
                    c if c as u32 > 0x7f => {
                        return Err(Error::new(format_args!(
                            "invalid bytes literal: character \
                                {c:?} is unexpected, only ascii chars are \
                                allowed in bytes literals"
                        )));
                    }
                    c if c == open_quote => return Ok((Kind::BinStr, quote_off + idx + 1)),
                    _ => {}
                }
            }
        } else {
            while let Some((idx, c)) = iter.next() {
                match c {
                    '\\' if !raw => match iter.next() {
                        Some((idx, '(')) => return Ok((Kind::StrInterpStart, quote_off + idx + 1)),
                        // skip any next char, even quote
                        Some((_, _)) => continue,
                        None => break,
                    },
                    c if c == open_quote => return Ok((Kind::Str, quote_off + idx + 1)),
                    _ => check_prohibited(c, true)?,
                }
            }
        }
        Err(Error::new(format_args!(
            "unterminated string, quoted by `{open_quote}`"
        )))
    }

    fn parse_string_interp_cont(&self, end: &str) -> Result<(Kind, usize), Error> {
        let quote_off = 1;
        let mut iter = self.buf[self.off + quote_off..].char_indices();

        while let Some((idx, c)) = iter.next() {
            match c {
                '\\' => match iter.next() {
                    Some((idx, '(')) => return Ok((Kind::StrInterpCont, quote_off + idx + 1)),
                    // skip any next char, even quote
                    Some((_, _)) => continue,
                    None => break,
                },
                _ if self.buf[self.off + quote_off + idx..].starts_with(end) => {
                    return Ok((Kind::StrInterpEnd, quote_off + idx + end.len()))
                }
                _ => check_prohibited(c, true)?,
            }
        }
        Err(Error::new(format_args!(
            "unterminated string with interpolations, quoted by `{end}`",
        )))
    }

    fn parse_number(&mut self) -> Result<(Kind, usize), Error> {
        #[derive(PartialEq, PartialOrd)]
        enum Break {
            Dot,
            Exponent,
            Letter,
            End,
        }
        use self::Kind::*;
        let mut iter = self.buf[self.off + 1..].char_indices();
        let mut suffix = None;
        let mut decimal = false;
        // decimal part
        let (mut bstate, dec_len) = loop {
            match iter.next() {
                Some((_, '0'..='9')) => continue,
                Some((_, '_')) => continue,
                Some((idx, 'e')) => break (Break::Exponent, idx + 1),
                Some((idx, '.')) => break (Break::Dot, idx + 1),
                Some((idx, c)) if c.is_alphabetic() => {
                    suffix = Some(idx + 1);
                    break (Break::Letter, idx + 1);
                }
                Some((idx, _)) => break (Break::End, idx + 1),
                None => break (Break::End, self.buf.len() - self.off),
            }
        };
        if self.buf.as_bytes()[self.off] == b'0' && dec_len > 1 {
            return Err(Error::new(
                "unexpected leading zeros are not allowed in numbers",
            ));
        }
        if bstate == Break::End {
            return Ok((IntConst, dec_len));
        }
        if bstate == Break::Dot {
            decimal = true;
            bstate = loop {
                if let Some((idx, c)) = iter.next() {
                    match c {
                        '0'..='9' => continue,
                        '_' => {
                            if idx + 1 == dec_len + 1 {
                                return Err(Error::new(
                                    "expected digit after dot, \
                                    found underscore",
                                ));
                            }
                            continue;
                        }
                        'e' => {
                            if idx + 1 == dec_len + 1 {
                                return Err(Error::new(
                                    "expected digit after dot, \
                                    found exponent",
                                ));
                            }
                            break Break::Exponent;
                        }
                        '.' => return Err(Error::new("unexpected extra decimal dot in number")),
                        c if c.is_alphabetic() => {
                            if idx == dec_len {
                                return Err(Error::new("expected digit after dot, found suffix"));
                            }
                            suffix = Some(idx + 1);
                            break Break::Letter;
                        }
                        _ => {
                            if idx + 1 == dec_len + 1 {
                                return Err(Error::new(
                                    "expected digit after dot, \
                                    found end of decimal",
                                ));
                            }
                            return Ok((FloatConst, idx + 1));
                        }
                    }
                } else {
                    if self.buf.len() - self.off == dec_len + 1 {
                        return Err(Error::new("expected digit after dot, found end of decimal"));
                    }
                    return Ok((FloatConst, self.buf.len() - self.off));
                }
            }
        }
        if bstate == Break::Exponent {
            match iter.next() {
                Some((_, '0'..='9')) => {}
                Some((_, c @ '+')) | Some((_, c @ '-')) => {
                    if c == '-' {
                        decimal = true;
                    }
                    match iter.next() {
                        Some((_, '0'..='9')) => {}
                        Some((_, '.')) => {
                            return Err(Error::new("unexpected extra decimal dot in number"))
                        }
                        _ => {
                            return Err(Error::new(
                                "unexpected optional `+` or `-` followed by digits must \
                                follow `e` in float const",
                            ))
                        }
                    }
                }
                _ => {
                    return Err(Error::new(
                        "unexpected optional `+` or `-` followed by digits must \
                        follow `e` in float const",
                    ))
                }
            }
            loop {
                match iter.next() {
                    Some((_, '0'..='9')) => continue,
                    Some((_, '_')) => continue,
                    Some((_, '.')) => {
                        return Err(Error::new("unexpected extra decimal dot in number"))
                    }
                    Some((idx, c)) if c.is_alphabetic() => {
                        suffix = Some(idx + 1);
                        break;
                    }
                    Some((idx, _)) => return Ok((FloatConst, idx + 1)),
                    None => return Ok((FloatConst, self.buf.len() - self.off)),
                }
            }
        }
        let soff = suffix.expect("tokenizer integrity error");
        let end = loop {
            if let Some((idx, c)) = iter.next() {
                if c != '_' && !c.is_alphanumeric() {
                    break idx + 1;
                }
            } else {
                break self.buf.len() - self.off;
            }
        };
        let suffix = &self.buf[self.off + soff..self.off + end];
        if suffix == "n" {
            if decimal {
                Ok((DecimalConst, end))
            } else {
                Ok((BigIntConst, end))
            }
        } else {
            let suffix = if suffix.len() > 8 {
                Cow::Owned(format!("{}...", &suffix[..8]))
            } else {
                Cow::Borrowed(suffix)
            };
            let val = if soff < 20 {
                &self.buf[self.off..][..soff]
            } else {
                "123"
            };
            if suffix.starts_with('O') {
                Err(Error::new(format_args!(
                    "suffix {suffix:?} is invalid for \
                        numbers, perhaps mixed up letter `O` \
                        with zero `0`?"
                )))
            } else if decimal {
                return Err(Error::new(format_args!(
                    "suffix {suffix:?} is invalid for \
                        numbers, perhaps you wanted `{val}n` (decimal)?"
                )));
            } else {
                return Err(Error::new(format_args!(
                    "suffix {suffix:?} is invalid for \
                        numbers, perhaps you wanted `{val}n` (bigint)?"
                )));
            }
        }
    }

    fn skip_whitespace(&mut self) {
        let mut iter = self.buf[self.off..].char_indices();
        let idx = 'outer: loop {
            let (idx, cur_char) = match iter.next() {
                Some(pair) => pair,
                None => break self.buf.len() - self.off,
            };
            match cur_char {
                '\u{feff}' | '\r' => continue,
                '\t' => self.position.column += 8,
                '\n' => {
                    self.position.column = 1;
                    self.position.line += 1;
                }
                // comma is also entirely ignored in spec
                ' ' => {
                    self.position.column += 1;
                    continue;
                }
                //comment
                '#' => {
                    for (idx, cur_char) in iter.by_ref() {
                        if check_prohibited(cur_char, false).is_err() {
                            // can't return an error from skip_whitespace,
                            // but we return up to this char, so the tokenizer
                            // chokes on it the next time it is invoked
                            break 'outer idx;
                        }
                        if cur_char == '\r' || cur_char == '\n' {
                            self.position.column = 1;
                            self.position.line += 1;
                            break;
                        }
                    }
                    continue;
                }
                _ => break idx,
            }
        };
        self.off += idx;
        self.position.offset += idx as u64;
    }

    fn update_position(&mut self, len: usize) {
        let val = &self.buf[self.off..][..len];
        self.off += len;
        let lines = val.as_bytes().iter().filter(|&&x| x == b'\n').count();
        self.position.line += lines;
        if lines > 0 {
            let line_offset = val.rfind('\n').unwrap() + 1;
            let num = val[line_offset..].chars().count();
            self.position.column = num + 1;
        } else {
            let num = val.chars().count();
            self.position.column += num;
        }
        self.position.offset += len as u64;
    }

    fn as_keyword(&mut self, s: &str) -> Option<Keyword> {
        if s.len() > MAX_KEYWORD_LENGTH {
            return None;
        }
        self.keyword_buf.clear();
        self.keyword_buf.push_str(s);
        self.keyword_buf.make_ascii_lowercase();
        keywords::lookup_all(&self.keyword_buf)
    }
}

impl fmt::Display for TokenStub<'_> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{}[{:?}]", self.text, self.kind)
    }
}

impl fmt::Display for Token<'_> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{}[{:?}]", self.text, self.kind)
    }
}

impl Token<'_> {
    pub fn cloned(self) -> Token<'static> {
        Token {
            kind: self.kind,
            text: Cow::<'static, str>::Owned(self.text.to_string()),
            value: self.value,
            span: self.span,
        }
    }
}

fn check_prohibited(c: char, escape: bool) -> Result<(), Error> {
    match c {
        '\0' if escape => Err(Error::new("character U+0000 is not allowed")),
        '\0' | '\u{202A}' | '\u{202B}' | '\u{202C}' | '\u{202D}' | '\u{202E}' | '\u{2066}'
        | '\u{2067}' | '\u{2068}' | '\u{2069}' => {
            if escape {
                Err(Error::new(format!(
                    "character U+{0:04X} is not allowed, \
                     use escaped form \\u{0:04x}",
                    c as u32
                )))
            } else {
                Err(Error::new(format!(
                    "character U+{:04X} is not allowed",
                    c as u32
                )))
            }
        }
        _ => Ok(()),
    }
}

impl std::cmp::PartialEq for Token<'_> {
    fn eq(&self, other: &Self) -> bool {
        self.kind == other.kind && self.text == other.text && self.value == other.value
    }
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.message)
    }
}

#[cfg(feature = "serde")]
fn deserialize_keyword<'de, D>(deserializer: D) -> Result<Keyword, D::Error>
where
    D: serde::Deserializer<'de>,
{
    struct Visitor;
    use serde::de;

    impl de::Visitor<'_> for Visitor {
        type Value = Keyword;

        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            formatter.write_str("EdgeQL keyword")
        }

        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            keywords::lookup_all(v)
                .ok_or_else(|| de::Error::invalid_value(de::Unexpected::Str(v), &"keyword"))
        }
    }

    deserializer.deserialize_str(Visitor)
}

impl Kind {
    pub fn text(&self) -> Option<&'static str> {
        use Kind::*;

        Some(match self {
            Add => "+",
            Ampersand => "&",
            At => "@",
            BackwardLink => ".<",
            OptionalLink => ".?>",
            CloseBrace => "}",
            CloseBracket => "]",
            CloseParen => ")",
            Coalesce => "??",
            Colon => ":",
            Comma => ",",
            Concat => "++",
            Div => "/",
            Dot => ".",
            DoubleSplat => "**",
            Eq => "=",
            FloorDiv => "//",
            Modulo => "%",
            Mul => "*",
            Namespace => "::",
            OpenBrace => "{",
            OpenBracket => "[",
            OpenParen => "(",
            Pipe => "|",
            Pow => "^",
            Semicolon => ";",
            Sub => "-",

            DistinctFrom => "?!=",
            GreaterEq => ">=",
            LessEq => "<=",
            NotDistinctFrom => "?=",
            NotEq => "!=",
            Less => "<",
            Greater => ">",

            AddAssign => "+=",
            Arrow => "->",
            Assign => ":=",
            SubAssign => "-=",

            Keyword(keywords::Keyword(kw)) => kw,

            _ => return None,
        })
    }

    pub fn user_friendly_text(&self) -> Option<&'static str> {
        use Kind::*;
        Some(match self {
            Ident => "identifier",
            EOI => "end of input",

            BinStr => "binary constant",
            FloatConst => "float constant",
            IntConst => "int constant",
            DecimalConst => "decimal constant",
            BigIntConst => "big int constant",
            Str => "string constant",

            _ => return None,
        })
    }
}


================================================
FILE: edb/edgeql-parser/src/validation.rs
================================================
use std::str::FromStr;

use bigdecimal::num_bigint::ToBigInt;
use bigdecimal::BigDecimal;

use crate::helpers::{unquote_bytes, unquote_string};
use crate::keywords::Keyword;
use crate::position::{Pos, Span};
use crate::tokenizer::{Error, Kind, Token, Tokenizer, Value, MAX_KEYWORD_LENGTH};

/// Applies additional validation to the tokens.
/// Combines multi-word keywords into single tokens.
/// Remaps a few token kinds.
pub struct Validator<'a> {
    pub inner: Tokenizer<'a>,

    pub(super) peeked: Option<Option<Result<Token<'a>, Error>>>,
    pub(super) keyword_buf: String,
}

impl<'a> Iterator for Validator<'a> {
    type Item = Result<Token<'a>, Error>;

    fn next(&mut self) -> Option<Self::Item> {
        let mut token = match self.next_inner()? {
            Ok(t) => t,
            Err(e) => return Some(Err(e)),
        };

        token.value = match parse_value(&token) {
            Ok(x) => x,
            Err(e) => return Some(Err(Error::new(e).with_span(token.span))),
        };

        if let Some(keyword) = self.combine_multi_word_keywords(&token) {
            token.text = keyword.into();
            token.kind = Kind::Keyword(Keyword(keyword));
            self.peeked = None;
        }

        token.kind = remap_kind(token.kind);

        Some(Ok(token))
    }
}

impl<'a> Validator<'a> {
    pub(super) fn new(inner: Tokenizer<'a>) -> Self {
        Validator {
            inner,
            peeked: None,
            keyword_buf: String::with_capacity(MAX_KEYWORD_LENGTH),
        }
    }

    pub fn with_eof(self) -> WithEof<'a> {
        WithEof {
            inner: self,
            emitted: false,
        }
    }

    /// Mimics behavior of [std::iter::Peekable]. We could use that, but it
    /// hides access to underlying iterator.
    fn next_inner(&mut self) -> Option<Result<Token<'a>, Error>> {
        if let Some(peeked) = self.peeked.take() {
            peeked
        } else {
            self.inner.next()
        }
    }

    /// Mimics behavior of [std::iter::Peekable]. We could use that, but it
    /// hides access to underlying iterator.
    fn peek(&mut self) -> &Option<Result<Token<'a>, Error>> {
        if self.peeked.is_none() {
            self.peeked = Some(self.inner.next());
        }

        self.peeked.as_ref().unwrap()
    }

    pub fn current_pos(&self) -> Pos {
        self.inner.current_pos()
    }

    fn combine_multi_word_keywords(&mut self, token: &Token<'a>) -> Option<&'static str> {
        if !matches!(token.kind, Kind::Ident | Kind::Keyword(_)) {
            return None;
        }
        let text = &token.text;

        if text.len() > MAX_KEYWORD_LENGTH {
            return None;
        }

        self.keyword_buf.clear();
        self.keyword_buf.push_str(text);
        self.keyword_buf.make_ascii_lowercase();
        match &self.keyword_buf[..] {
            "named" => {
                if self.peek_keyword("only") {
                    return Some("named only");
                }
            }
            "set" => {
                if self.peek_keyword("annotation") {
                    return Some("set annotation");
                }
                if self.peek_keyword("type") {
                    return Some("set type");
                }
            }
            "extension" => {
                if self.peek_keyword("package") {
                    return Some("extension package");
                }
            }
            "order" => {
                if self.peek_keyword("by") {
                    return Some("order by");
                }
            }
            _ => {}
        }
        None
    }

    fn peek_keyword(&mut self, kw: &'static str) -> bool {
        self.peek()
            .as_ref()
            .and_then(|res| res.as_ref().ok())
            .map(|t| {
                t.kind == Kind::Keyword(Keyword(kw))
                    || (t.kind == Kind::Ident && t.text.eq_ignore_ascii_case(kw))
            })
            .unwrap_or(false)
    }
}

pub fn parse_value(token: &Token) -> Result<Option<Value>, String> {
    use Kind::*;
    let text = &token.text;
    let string_value = match token.kind {
        Parameter => {
            if text[1..].starts_with('`') {
                text[2..text.len() - 1].replace("``", "`")
            } else {
                text[1..].to_string()
            }
        }
        DecimalConst => {
            return text[..text.len() - 1]
                .replace('_', "")
                .parse()
                .map(Value::Decimal)
                .map(Some)
                .map_err(|e| format!("can't parse decimal: {e}"))
        }
        FloatConst => {
            return text
                .replace('_', "")
                .parse::<f64>()
                .map_err(|e| format!("can't parse std::float64: {e}"))
                .and_then(|num| {
                    if num.is_infinite() {
                        return Err("number is out of range for std::float64".to_string());
                    }
                    if num == 0.0 {
                        let mend = text.find(['e', 'E']).unwrap_or(text.len());
                        let mantissa = &text[..mend];
                        if mantissa.chars().any(|c| c != '0' && c != '.') {
                            return Err("number is out of range for std::float64".to_string());
                        }
                    }
                    Ok(num)
                })
                .map(Value::Float)
                .map(Some);
        }
        IntConst => {
            // We read unsigned here, because unary minus will only
            // be identified at the parser stage. And there is a number
            // -9223372036854775808 whose absolute (positive) value
            // can't be represented in i64.
            // Python has no problem representing such a positive
            // value, though.
            return u64::from_str(&text.replace('_', ""))
                .map(|x| Some(Value::Int(x as i64)))
                .map_err(|e| format!("error reading int: {e}"));
        }
        BigIntConst => {
            return text[..text.len() - 1]
                .replace('_', "")
                .parse::<BigDecimal>()
                .map_err(|e| format!("error reading bigint: {e}"))
                // this conversion to decimal and back to string
                // fixes things like `1e2n` which we support for bigints
                .and_then(|x| {
                    x.to_bigint()
                        .ok_or_else(|| "number is not integer".to_string())
                })
                .map(|x| Some(Value::BigInt(x.to_str_radix(16))));
        }
        BinStr => {
            return unquote_bytes(text).map(Value::Bytes).map(Some);
        }

        Str | StrInterpStart | StrInterpEnd | StrInterpCont => {
            unquote_string(text).map_err(|s| s.to_string())?.to_string()
        }
        BacktickName => text[1..text.len() - 1].replace("``", "`"),
        Ident | Keyword(_) => text.to_string(),
        Substitution => text[2..text.len() - 1].to_string(),
        _ => return Ok(None),
    };
    Ok(Some(Value::String(string_value)))
}

fn remap_kind(kind: Kind) -> Kind {
    match kind {
        Kind::BacktickName => Kind::Ident,
        kind => kind,
    }
}

pub struct WithEof<'a> {
    inner: Validator<'a>,

    emitted: bool,
}

impl<'a> Iterator for WithEof<'a> {
    type Item = Result<Token<'a>, Error>;

    fn next(&mut self) -> Option<Self::Item> {
        if let Some(next) = self.inner.next() {
            Some(next)
        } else if !self.emitted {
            self.emitted = true;
            let pos = self.inner.current_pos().offset;

            Some(Ok(Token {
                kind: Kind::EOI,
                text: "".into(),
                value: None,
                span: Span {
                    start: pos,
                    end: pos,
                },
            }))
        } else {
            None
        }
    }
}


================================================
FILE: edb/edgeql-parser/tests/expr.rs
================================================
use edgeql_parser::expr::check;

#[test]
fn test_valid() {
    check("1").unwrap();
    check(" 42    ").unwrap();
    check("42 # )").unwrap();
    check("33 ++ 44").unwrap();
    check("33 ++ '44'").unwrap();
    check("(1, 2) # tuple").unwrap();
    check("# next line\n 2+2").unwrap();
    check("{}").unwrap();
    check("()").unwrap();
    check(".user.name").unwrap();
    check("call(me.maybe)").unwrap();
    check("bad +/- grammar **** but --- allowed").unwrap();
}

fn check_err(s: &str) -> String {
    check(s).unwrap_err().to_string()
}

#[test]
fn test_empty() {
    assert_eq!(check_err(""), "expression is empty");
    assert_eq!(check_err("   "), "expression is empty");
    assert_eq!(check_err("# xxx + yyy"), "expression is empty");
}

#[test]
fn bad_token() {
    assert_eq!(
        check_err("'quote"),
        "1:1: tokenizer error: unterminated string, quoted by `'`"
    );
    assert_eq!(
        check_err("\\(quote"),
        "1:1: tokenizer error: unclosed \\(name) token"
    );
}

#[test]
fn bracket_mismatch() {
    assert_eq!(
        check_err("(a[12)]"),
        "1:6: closing bracket mismatch, \
            opened \"[\" at 1:3, encountered \")\""
    );
    assert_eq!(
        check_err("(a12]"),
        "1:5: closing bracket mismatch, \
            opened \"(\" at 1:1, encountered \"]\""
    );
    assert_eq!(
        check_err("{'}']"),
        "1:5: closing bracket mismatch, \
            opened \"{\" at 1:1, encountered \"]\""
    );
}

#[test]
fn extra_brackets() {
    assert_eq!(check_err("func())"), "1:7: extra closing bracket \")\"");
    assert_eq!(check_err("{} + x]"), "1:7: extra closing bracket \"]\"");
    assert_eq!(
        check_err("{'xxx(yyy'})"),
        "1:12: extra closing bracket \")\""
    );
}

#[test]
fn missing_brackets() {
    assert_eq!(
        check_err("func((1, 2)"),
        "1:5: bracket \"(\" has never been closed"
    );
    assert_eq!(
        check_err("{(1, 2), (3, '}')"),
        "1:1: bracket \"{\" has never been closed"
    );
    assert_eq!(
        check_err("{((())[[()"),
        "1:8: bracket \"[\" has never been closed"
    );
}

#[test]
fn delimiter() {
    assert_eq!(
        check_err("1, 2"),
        "1:2: token \",\" is not allowed in expression \
         (try parenthesize the expression)"
    );
    check("(1, 2)").unwrap();

    assert_eq!(
        check_err("create type Type1;"),
        "1:18: token \";\" is not allowed in expression \
         (try parenthesize the expression)"
    );
    // this doesn't work, but is fun to see
    check("{create if not exists type Type1; SELECT Type1}").unwrap();
}


================================================
FILE: edb/edgeql-parser/tests/preparser.rs
================================================
use edgeql_parser::preparser::{full_statement, is_empty};

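// Checks that `full_statement` reports the end of the first statement at `len`.
// Prefixes shorter than `len - 1` must fail as incomplete, and resuming over
// the full input with the returned state must yield `len`. Prefixes of at
// least `len` bytes must yield `len` directly.
fn test_statement(data: &[u8], len: usize) {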
    for i in 0..len - 1 {
        let c = full_statement(&data[..i], None).unwrap_err();
        let parsed_len = full_statement(data, Some(c)).unwrap();
        assert_eq!(len, parsed_len, "at {i}");
    }
    for i in len..data.len() {
        let parsed_len = full_statement(&data[..i], None).unwrap();
        assert_eq!(len, parsed_len);
    }
}

#[test]
fn test_simple() {
    test_statement(b"select 1+1; some trailer", 11);
}

#[test]
fn test_quotes() {
    test_statement(b"select \"x\"; some trailer", 11);
}

#[test]
fn test_quoted_semicolon() {
    test_statement(b"select \"a;\"; some trailer", 12);
}

#[test]
fn test_raw_string() {
    test_statement(br#"select r"\"; some trailer"#, 12);
}

#[test]
fn test_raw_byte_string() {
    test_statement(br#"select rb"\"; some trailer"#, 13);
    test_statement(br"select br'hello\'; some trailer", 18);
}

#[test]
fn test_single_quoted_semicolon() {
    test_statement(b"select 'a;'; some trailer", 12);
}

#[test]
fn test_backtick_quoted_semicolon() {
    test_statement(b"select `a;`; some trailer", 12);
}

#[test]
fn test_commented_semicolon() {
    test_statement(b"select # test;\n1+1;", 19);
}

#[test]
fn test_continuation() {
    test_statement(b"select 'a;'; '", 12);
}

#[test]
fn test_quoted_continuation() {
    test_statement(b"select \"a; \";", 13);
}

#[test]
fn test_single_quoted_continuation() {
    test_statement(b"select 'a; ' ;", 14);
}

#[test]
fn test_backtick_quoted_continuation() {
    test_statement(b"select `a;test`+1;", 18);
}

#[test]
fn test_dollar_semicolon() {
    test_statement(b"select $$ ; $$ test;", 20);
    test_statement(b"select $$$$;", 12);
    test_statement(b"select $$$ ; $$;", 16);
    test_statement(b"select $some_L0ng_name$ ; $some_L0ng_name$;", 43);
}

#[test]
fn test_nested_dollar() {
    test_statement(b"select $a$ ; $b$ ; $b$ ; $a$; x", 29);
    test_statement(b"select $a$ ; $b$ ; $a$; x", 23);
}

#[test]
fn test_dollar_continuation() {
    test_statement(b"select $$ ; $ab$ test; $$ ;", 27);
    test_statement(b"select $a$ ; $$ test; $a$ ;", 27);
    test_statement(b"select $a$ ; test; $a$ ;", 24);
    test_statement(b"select $a$a$ ; $$ test; $a$;", 28);
    test_statement(b"select $a$ ; $b$ ; $c$ ; $b$ test; $a$;", 39);
}

#[test]
fn test_dollar_var() {
    test_statement(b"select $a+b; $ test; $a+b; $ ;", 12);
    test_statement(b"select $a b; $ test; $a b; $ ;", 12);
}

#[test]
fn test_after_variable() {
    test_statement(b"select $$ $$; extra;", 13);
    test_statement(b"select $a$ $a$; extra;", 15);
    test_statement(b"select $a;", 10);
    test_statement(b"select $a{ x; };", 16);
}

#[test]
fn test_schema() {
    test_statement(
        br###"
        START MIGRATION TO {
            module default {
                type Movie {
                    required property title -> str;
                    # the year of release
                    property year -> int64;
                    required link director -> Person;
                    multi link actors -> Person;
                }
                type Person {
                    required property first_name -> str;
                    required property last_name -> str;
                }
            }
        };
        "###,
        532,
    );
}

#[test]
fn test_function() {
    test_statement(b"drop function foo(s: str); ", 26);
}

#[test]
fn empty() {
    assert!(is_empty(""));
    assert!(is_empty(" "));
    assert!(is_empty("\n"));
    assert!(is_empty("#xx"));
    assert!(is_empty("#xx\n"));
    assert!(is_empty("# xx\n# yy"));
    assert!(is_empty(" #xx\n  #yy"));
    assert!(is_empty(";"));
    assert!(is_empty(";;"));
    assert!(is_empty("    ;\n#cd"));
    assert!(!is_empty("a"));
    assert!(!is_empty("ab cd"));
    assert!(!is_empty(","));
    assert!(!is_empty(";ab;"));
    assert!(!is_empty("ab;;de"));
    assert!(!is_empty("    xy"));
    assert!(!is_empty("    xy #c"));
    assert!(!is_empty("    '#c"));
    assert!(!is_empty("ab\n#cd"));
}


================================================
FILE: edb/edgeql-parser/tests/tokenizer.rs
================================================
use edgeql_parser::tokenizer::Kind::*;
use edgeql_parser::tokenizer::{Kind, Tokenizer};

fn tok_str(s: &str) -> Vec<String> {
    let mut r = Vec::new();
    let mut s = Tokenizer::new(s).validated_values();
    loop {
        match s.next() {
            Some(Ok(x)) => r.push(x.text.to_string()),
            None => break,
            Some(Err(e)) => panic!("Parse error at {}: {}", e.span.start, e.message),
        }
    }
    r
}

fn tok_typ(s: &str) -> Vec<Kind> {
    let mut r = Vec::new();
    let mut s = Tokenizer::new(s).validated_values();
    loop {
        match s.next() {
            Some(Ok(x)) => r.push(x.kind),
            None => break,
            Some(Err(e)) => panic!("Parse error at {}: {}", e.span.start, e.message),
        }
    }
    r
}

fn tok_err(s: &str) -> String {
    let mut s = Tokenizer::new(s).validated_values();
    loop {
        match s.next() {
            Some(Ok(_)) => {}
            None => break,
            Some(Err(e)) => return e.message.to_string(),
        }
    }
    panic!("No error, where error expected");
}

fn keyword(kw: &'static str) -> Kind {
    Keyword(edgeql_parser::keywords::Keyword(kw))
}

#[test]
fn whitespace_and_comments() {
    assert_eq!(tok_str("# hello { world }"), &[] as &[&str]);
    assert_eq!(tok_str("# x\n  "), &[] as &[&str]);
    assert_eq!(tok_str("  # x"), &[] as &[&str]);
    assert_eq!(
        tok_err("  # xxx \u{202A} yyy"),
        "unexpected character '\\u{202a}'"
    );
}

#[test]
fn idents() {
    assert_eq!(tok_str("a bc d127"), ["a", "bc", "d127"]);
    assert_eq!(tok_typ("a bc d127"), [Ident, Ident, Ident]);
    assert_eq!(
        tok_str("тест тест_abc abc_тест"),
        ["тест", "тест_abc", "abc_тест"]
    );
    assert_eq!(tok_typ("тест тест_abc abc_тест"), [Ident, Ident, Ident]);
    assert_eq!(
        tok_err(" + __test__"),
        "identifiers surrounded by double underscores are forbidden"
    );
    assert_eq!(tok_str("_1024"), ["_1024"]);
    assert_eq!(tok_typ("_1024"), [Ident]);
}

#[test]
fn keywords() {
    assert_eq!(tok_str("SELECT a"), ["SELECT", "a"]);
    assert_eq!(tok_typ("SELECT a"), [keyword("select"), Ident]);
    assert_eq!(tok_str("with Select"), ["with", "Select"]);
    assert_eq!(tok_typ("with Select"), [keyword("with"), keyword("select")]);
}

#[test]
fn colon_tokens() {
    assert_eq!(tok_str("a :=b"), ["a", ":=", "b"]);
    assert_eq!(tok_typ("a :=b"), [Ident, Assign, Ident]);
    assert_eq!(tok_str("a : = b"), ["a", ":", "=", "b"]);
    assert_eq!(tok_typ("a : = b"), [Ident, Colon, Eq, Ident]);
    assert_eq!(tok_str("a ::= b"), ["a", "::", "=", "b"]);
    assert_eq!(tok_typ("a ::= b"), [Ident, Namespace, Eq, Ident]);
}

#[test]
fn dash_tokens() {
    assert_eq!(tok_str("a-b -> c"), ["a", "-", "b", "->", "c"]);
    assert_eq!(tok_typ("a-b -> c"), [Ident, Sub, Ident, Arrow, Ident]);
    assert_eq!(tok_str("a - > b"), ["a", "-", ">", "b"]);
    assert_eq!(tok_typ("a - > b"), [Ident, Sub, Greater, Ident]);
    assert_eq!(tok_str("a --> b"), ["a", "-", "->", "b"]);
    assert_eq!(tok_typ("a --> b"), [Ident, Sub, Arrow, Ident]);
}

#[test]
fn greater_tokens() {
    assert_eq!(tok_str("a >= c"), ["a", ">=", "c"]);
    assert_eq!(tok_typ("a >= c"), [Ident, GreaterEq, Ident]);
    assert_eq!(tok_str("a > = b"), ["a", ">", "=", "b"]);
    assert_eq!(tok_typ("a > = b"), [Ident, Greater, Eq, Ident]);
    assert_eq!(tok_str("a>b"), ["a", ">", "b"]);
    assert_eq!(tok_typ("a>b"), [Ident, Greater, Ident]);
}

#[test]
fn less_tokens() {
    assert_eq!(tok_str("a <= c"), ["a", "<=", "c"]);
    assert_eq!(tok_typ("a <= c"), [Ident, LessEq, Ident]);
    assert_eq!(tok_str("a < = b"), ["a", "<", "=", "b"]);
    assert_eq!(tok_typ("a < = b"), [Ident, Less, Eq, Ident]);
    assert_eq!(tok_str("a.b .> c"), ["a", ".", "b", ".", ">", "c"]);
    assert_eq!(
        tok_typ("a.b .> c"),
        [Ident, Dot, Ident, Dot, Greater, Ident]
    );
    assert_eq!(tok_str("a . > b"), ["a", ".", ">", "b"]);
    assert_eq!(tok_typ("a . > b"), [Ident, Dot, Greater, Ident]);
    assert_eq!(tok_str("a .>> b"), ["a", ".", ">", ">", "b"]);
    assert_eq!(tok_typ("a .>> b"), [Ident, Dot, Greater, Greater, Ident]);
    assert_eq!(tok_str("a ..> b"), ["a", ".", ".", ">", "b"]);
    assert_eq!(tok_typ("a ..> b"), [Ident, Dot, Dot, Greater, Ident]);

    assert_eq!(tok_str("a.b .< c"), ["a", ".", "b", ".<", "c"]);
    assert_eq!(
        tok_typ("a.b .< c"),
        [Ident, Dot, Ident, BackwardLink, Ident]
    );
    assert_eq!(tok_str("a . < b"), ["a", ".", "<", "b"]);
    assert_eq!(tok_typ("a . < b"), [Ident, Dot, Less, Ident]);
    assert_eq!(tok_str("a .<< b"), ["a", ".<", "<", "b"]);
    assert_eq!(tok_typ("a .<< b"), [Ident, BackwardLink, Less, Ident]);
    assert_eq!(tok_str("a ..< b"), ["a", ".", ".<", "b"]);
    assert_eq!(tok_typ("a ..< b"), [Ident, Dot, BackwardLink, Ident]);
}

#[test]
fn tuple_dot_vs_float() {
    assert_eq!(tok_str("tuple.1.<"), ["tuple", ".", "1", ".<"]);
    assert_eq!(tok_typ("tuple.1.<"), [Ident, Dot, IntConst, BackwardLink]);
    assert_eq!(tok_str("tuple.1.e123"), ["tuple", ".", "1", ".", "e123"]);
    assert_eq!(tok_typ("tuple.1.e123"), [Ident, Dot, IntConst, Dot, Ident]);
}

#[test]
fn div_tokens() {
    assert_eq!(tok_str("a // c"), ["a", "//", "c"]);
    assert_eq!(tok_typ("a // c"), [Ident, FloorDiv, Ident]);
    assert_eq!(tok_str("a / / b"), ["a", "/", "/", "b"]);
    assert_eq!(tok_typ("a / / b"), [Ident, Div, Div, Ident]);
    assert_eq!(tok_str("a/b"), ["a", "/", "b"]);
    assert_eq!(tok_typ("a/b"), [Ident, Div, Ident]);
}

#[test]
fn single_char_tokens() {
    assert_eq!(tok_str(".;:+-*"), [".", ";", ":", "+", "-", "*"]);
    assert_eq!(tok_typ(".;:+-*"), [Dot, Semicolon, Colon, Add, Sub, Mul]);
    assert_eq!(tok_str("/%^<>"), ["/", "%", "^", "<", ">"]);
    assert_eq!(tok_typ("/%^<>"), [Div, Modulo, Pow, Less, Greater]);
    assert_eq!(tok_str("=&|@"), ["=", "&", "|", "@"]);
    assert_eq!(tok_typ("=&|@"), [Eq, Ampersand, Pipe, At]);

    assert_eq!(tok_str(". ; : + - *"), [".", ";", ":", "+", "-", "*"]);
    assert_eq!(
        tok_typ(". ; : + - *"),
        [Dot, Semicolon, Colon, Add, Sub, Mul]
    );
    assert_eq!(tok_str("/ % ^ < >"), ["/", "%", "^", "<", ">"]);
    assert_eq!(tok_typ("/ % ^ < >"), [Div, Modulo, Pow, Less, Greater]);
    assert_eq!(tok_str("= & | @"), ["=", "&", "|", "@"]);
    assert_eq!(tok_typ("= & | @"), [Eq, Ampersand, Pipe, At]);
}

#[test]
fn splats() {
    assert_eq!(tok_str("*"), ["*"]);
    assert_eq!(tok_typ("*"), [Mul]);
    assert_eq!(tok_str("**"), ["**"]);
    assert_eq!(tok_typ("**"), [DoubleSplat]);
    assert_eq!(tok_str("* *"), ["*", "*"]);
    assert_eq!(tok_typ("* *"), [Mul, Mul]);
    assert_eq!(tok_str("User.*,"), ["User", ".", "*", ","]);
    assert_eq!(tok_typ("User.*,"), [Ident, Dot, Mul, Comma]);
    assert_eq!(tok_str("User.**,"), ["User", ".", "**", ","]);
    assert_eq!(tok_typ("User.**,"), [Ident, Dot, DoubleSplat, Comma]);
    assert_eq!(tok_str("User {*}"), ["User", "{", "*", "}"]);
    assert_eq!(tok_typ("User {*}"), [Ident, OpenBrace, Mul, CloseBrace]);
    assert_eq!(tok_str("User {**}"), ["User", "{", "**", "}"]);
    assert_eq!(
        tok_typ("User {**}"),
        [Ident, OpenBrace, DoubleSplat, CloseBrace]
    );
}

#[test]
fn integer() {
    assert_eq!(tok_str("0"), ["0"]);
    assert_eq!(tok_typ("0"), [IntConst]);
    assert_eq!(tok_str("*0"), ["*", "0"]);
    assert_eq!(tok_typ("*0"), [Mul, IntConst]);
    assert_eq!(tok_str("123"), ["123"]);
    assert_eq!(tok_typ("123"), [IntConst]);
    assert_eq!(tok_str("123_"), ["123_"]);
    assert_eq!(tok_typ("123_"), [IntConst]);
    assert_eq!(tok_str("123_456"), ["123_456"]);
    assert_eq!(tok_typ("123_456"), [IntConst]);

    assert_eq!(tok_str("0 "), ["0"]);
    assert_eq!(tok_typ("0 "), [IntConst]);
    assert_eq!(tok_str("123 "), ["123"]);
    assert_eq!(tok_typ("123 "), [IntConst]);
    assert_eq!(tok_str("123_ "), ["123_"]);
    assert_eq!(tok_typ("123_ "), [IntConst]);
    assert_eq!(tok_str("123_456 "), ["123_456"]);
    assert_eq!(tok_typ("123_456 "), [IntConst]);
}

#[test]
fn bigint() {
    assert_eq!(tok_str("0n"), ["0n"]);
    assert_eq!(tok_typ("0n"), [BigIntConst]);
    assert_eq!(tok_str("*0n"), ["*", "0n"]);
    assert_eq!(tok_typ("*0n"), [Mul, BigIntConst]);
    assert_eq!(tok_str("123n"), ["123n"]);
    assert_eq!(tok_typ("123n"), [BigIntConst]);
    assert_eq!(tok_str("123e3n"), ["123e3n"]);
    assert_eq!(tok_typ("123e3n"), [BigIntConst]);
    assert_eq!(tok_str("123e+99n"), ["123e+99n"]);
    assert_eq!(tok_typ("123e+99n"), [BigIntConst]);
    assert_eq!(tok_str("123_n"), ["123_n"]);
    assert_eq!(tok_typ("123_n"), [BigIntConst]);
    assert_eq!(tok_str("123_456n"), ["123_456n"]);
    assert_eq!(tok_typ("123_456n"), [BigIntConst]);

    assert_eq!(tok_str("0n "), ["0n"]);
    assert_eq!(tok_typ("0n "), [BigIntConst]);
    assert_eq!(tok_str("123n "), ["123n"]);
    assert_eq!(tok_typ("123n "), [BigIntConst]);
    assert_eq!(tok_str("123e3n "), ["123e3n"]);
    assert_eq!(tok_typ("123e3n "), [BigIntConst]);
    assert_eq!(tok_str("123e+99n "), ["123e+99n"]);
    assert_eq!(tok_typ("123e+99n "), [BigIntConst]);
    assert_eq!(tok_str("123_n "), ["123_n"]);
    assert_eq!(tok_typ("123_n "), [BigIntConst]);
    assert_eq!(tok_str("123_456n "), ["123_456n"]);
    assert_eq!(tok_typ("123_456n "), [BigIntConst]);
}

#[test]
fn float() {
    assert_eq!(tok_str("     0.0"), ["0.0"]);
    assert_eq!(tok_typ("     0.0"), [FloatConst]);
    assert_eq!(tok_str("123.999"), ["123.999"]);
    assert_eq!(tok_typ("123.999"), [FloatConst]);
    assert_eq!(tok_str("123.999e3"), ["123.999e3"]);
    assert_eq!(tok_typ("123.999e3"), [FloatConst]);
    assert_eq!(tok_str("123.999e+99"), ["123.999e+99"]);
    assert_eq!(tok_typ("123.999e+99"), [FloatConst]);
    assert_eq!(tok_str("2345.567e-7"), ["2345.567e-7"]);
    assert_eq!(tok_typ("2345.567e-7"), [FloatConst]);
    assert_eq!(tok_str("123e3"), ["123e3"]);
    assert_eq!(tok_typ("123e3"), [FloatConst]);
    assert_eq!(tok_str("123e+99"), ["123e+99"]);
    assert_eq!(tok_typ("123e+99"), [FloatConst]);
    assert_eq!(tok_str("123e+99_"), ["123e+99_"]);
    assert_eq!(tok_typ("123e+99_"), [FloatConst]);
    assert_eq!(tok_str("123e+9_9"), ["123e+9_9"]);
    assert_eq!(tok_typ("123e+9_9"), [FloatConst]);
    assert_eq!(tok_str("2345e-7"), ["2345e-7"]);
    assert_eq!(tok_typ("2345e-7"), [FloatConst]);
    assert_eq!(tok_str("2_345e-7"), ["2_345e-7"]);
    assert_eq!(tok_typ("2_345e-7"), [FloatConst]);
    assert_eq!(tok_str("1_023.9_099"), ["1_023.9_099"]);
    assert_eq!(tok_typ("1_023.9_099"), [FloatConst]);
    assert_eq!(tok_str("1_023_.9_099_"), ["1_023_.9_099_"]);
    assert_eq!(tok_typ("1_023_.9_099_"), [FloatConst]);

    assert_eq!(tok_str("     0.0 "), ["0.0"]);
    assert_eq!(tok_typ("     0.0 "), [FloatConst]);
    assert_eq!(tok_str("123.999 "), ["123.999"]);
    assert_eq!(tok_typ("123.999 "), [FloatConst]);
    assert_eq!(tok_str("123.999e3 "), ["123.999e3"]);
    assert_eq!(tok_typ("123.999e3 "), [FloatConst]);
    assert_eq!(tok_str("123.999e+99 "), ["123.999e+99"]);
    assert_eq!(tok_typ("123.999e+99 "), [FloatConst]);
    assert_eq!(tok_str("2345.567e-7 "), ["2345.567e-7"]);
    assert_eq!(tok_typ("2345.567e-7 "), [FloatConst]);
    assert_eq!(tok_str("123e3 "), ["123e3"]);
    assert_eq!(tok_typ("123e3 "), [FloatConst]);
    assert_eq!(tok_str("123e+99 "), ["123e+99"]);
    assert_eq!(tok_typ("123e+99 "), [FloatConst]);
    assert_eq!(tok_str("123e+99_ "), ["123e+99_"]);
    assert_eq!(tok_typ("123e+99_ "), [FloatConst]);
    assert_eq!(tok_str("2345e-7 "), ["2345e-7"]);
    assert_eq!(tok_typ("2345e-7 "), [FloatConst]);
    assert_eq!(tok_str("1_023_.9_099_ "), ["1_023_.9_099_"]);
    assert_eq!(tok_typ("1_023_.9_099_ "), [FloatConst]);

    assert_eq!(
        tok_err("01.2"),
        "unexpected leading zeros are not allowed in numbers"
    );
}

#[test]
fn decimal() {
    assert_eq!(tok_str("     0.0n"), ["0.0n"]);
    assert_eq!(tok_typ("     0.0n"), [DecimalConst]);
    assert_eq!(tok_str("123.999n"), ["123.999n"]);
    assert_eq!(tok_typ("123.999n"), [DecimalConst]);
    assert_eq!(tok_str("123.999e3n"), ["123.999e3n"]);
    assert_eq!(tok_typ("123.999e3n"), [DecimalConst]);
    assert_eq!(tok_str("123.999e+99n"), ["123.999e+99n"]);
    assert_eq!(tok_typ("123.999e+99n"), [DecimalConst]);
    assert_eq!(tok_str("2345.567e-7n"), ["2345.567e-7n"]);
    assert_eq!(tok_typ("2345.567e-7n"), [DecimalConst]);
    assert_eq!(tok_str("2345e-7n"), ["2345e-7n"]);
    assert_eq!(tok_typ("2345e-7n"), [DecimalConst]);
    assert_eq!(tok_str("2_345e-7n"), ["2_345e-7n"]);
    assert_eq!(tok_typ("2_345e-7n"), [DecimalConst]);
    assert_eq!(tok_str("1_023.9_099n"), ["1_023.9_099n"]);
    assert_eq!(tok_typ("1_023.9_099n"), [DecimalConst]);
    assert_eq!(tok_str("1_023_.9_099_n"), ["1_023_.9_099_n"]);
    assert_eq!(tok_typ("1_023_.9_099_n"), [DecimalConst]);
    assert_eq!(tok_str("2_345e-7n"), ["2_345e-7n"]);
    assert_eq!(tok_typ("2_345e-7n"), [DecimalConst]);
    assert_eq!(tok_str("2_345e-7_7n"), ["2_345e-7_7n"]);
    assert_eq!(tok_typ("2_345e-7_7n"), [DecimalConst]);

    assert_eq!(tok_str("     0.0n "), ["0.0n"]);
    assert_eq!(tok_typ("     0.0n "), [DecimalConst]);
    assert_eq!(tok_str("123.999n "), ["123.999n"]);
    assert_eq!(tok_typ("123.999n "), [DecimalConst]);
    assert_eq!(tok_str("123.999e3n "), ["123.999e3n"]);
    assert_eq!(tok_typ("123.999e3n "), [DecimalConst]);
    assert_eq!(tok_str("123.999e+99n "), ["123.999e+99n"]);
    assert_eq!(tok_typ("123.999e+99n "), [DecimalConst]);
    assert_eq!(tok_str("2345.567e-7n "), ["2345.567e-7n"]);
    assert_eq!(tok_typ("2345.567e-7n "), [DecimalConst]);
    assert_eq!(tok_str("2345e-7n "), ["2345e-7n"]);
    assert_eq!(tok_typ("2345e-7n "), [DecimalConst]);

    assert_eq!(
        tok_err("01.0n"),
        "unexpected leading zeros are not allowed in numbers"
    );
}

#[test]
fn numbers_from_py() {
    assert_eq!(tok_str("SELECT 3.5432;"), ["SELECT", "3.5432", ";"]);
    assert_eq!(
        tok_typ("SELECT 3.5432;"),
        [keyword("select"), FloatConst, Semicolon]
    );
    assert_eq!(tok_str("SELECT +3.5432;"), ["SELECT", "+", "3.5432", ";"]);
    assert_eq!(
        tok_typ("SELECT +3.5432;"),
        [keyword("select"), Add, FloatConst, Semicolon]
    );
    assert_eq!(tok_str("SELECT -3.5432;"), ["SELECT", "-", "3.5432", ";"]);
    assert_eq!(
        tok_typ("SELECT -3.5432;"),
        [keyword("select"), Sub, FloatConst, Semicolon]
    );
    assert_eq!(tok_str("SELECT 354.32;"), ["SELECT", "354.32", ";"]);
    assert_eq!(
        tok_typ("SELECT 354.32;"),
        [keyword("select"), FloatConst, Semicolon]
    );
    assert_eq!(
        tok_str("SELECT 35400000000000.32;"),
        ["SELECT", "35400000000000.32", ";"]
    );
    assert_eq!(
        tok_typ("SELECT 35400000000000.32;"),
        [keyword("select"), FloatConst, Semicolon]
    );
    assert_eq!(
        tok_str("SELECT 35400000000000000000.32;"),
        ["SELECT", "35400000000000000000.32", ";"]
    );
    assert_eq!(
        tok_typ("SELECT 35400000000000000000.32;"),
        [keyword("select"), FloatConst, Semicolon]
    );
    assert_eq!(tok_str("SELECT 3.5432e20;"), ["SELECT", "3.5432e20", ";"]);
    assert_eq!(
        tok_typ("SELECT 3.5432e20;"),
        [keyword("select"), FloatConst, Semicolon]
    );
    assert_eq!(tok_str("SELECT 3.5432e+20;"), ["SELECT", "3.5432e+20", ";"]);
    assert_eq!(
        tok_typ("SELECT 3.5432e+20;"),
        [keyword("select"), FloatConst, Semicolon]
    );
    assert_eq!(tok_str("SELECT 3.5432e-20;"), ["SELECT", "3.5432e-20", ";"]);
    assert_eq!(
        tok_typ("SELECT 3.5432e-20;"),
        [keyword("select"), FloatConst, Semicolon]
    );
    assert_eq!(tok_str("SELECT 354.32e-20;"), ["SELECT", "354.32e-20", ";"]);
    assert_eq!(
        tok_typ("SELECT 354.32e-20;"),
        [keyword("select"), FloatConst, Semicolon]
    );
    assert_eq!(tok_str("SELECT -0n;"), ["SELECT", "-", "0n", ";"]);
    assert_eq!(
        tok_typ("SELECT -0n;"),
        [keyword("select"), Sub, BigIntConst, Semicolon]
    );
    assert_eq!(tok_str("SELECT 0n;"), ["SELECT", "0n", ";"]);
    assert_eq!(
        tok_typ("SELECT 0n;"),
        [keyword("select"), BigIntConst, Semicolon]
    );
    assert_eq!(tok_str("SELECT 1n;"), ["SELECT", "1n", ";"]);
    assert_eq!(
        tok_typ("SELECT 1n;"),
        [keyword("select"), BigIntConst, Semicolon]
    );
    assert_eq!(tok_str("SELECT -1n;"), ["SELECT", "-", "1n", ";"]);
    assert_eq!(
        tok_typ("SELECT -1n;"),
        [keyword("select"), Sub, BigIntConst, Semicolon]
    );
    assert_eq!(tok_str("SELECT 100000n;"), ["SELECT", "100000n", ";"]);
    assert_eq!(
        tok_typ("SELECT 100000n;"),
        [keyword("select"), BigIntConst, Semicolon]
    );
    assert_eq!(tok_str("SELECT -100000n;"), ["SELECT", "-", "100000n", ";"]);
    assert_eq!(
        tok_typ("SELECT -100000n;"),
        [keyword("select"), Sub, BigIntConst, Semicolon]
    );
    assert_eq!(tok_str("SELECT -354.32n;"), ["SELECT", "-", "354.32n", ";"]);
    assert_eq!(
        tok_typ("SELECT -354.32n;"),
        [keyword("select"), Sub, DecimalConst, Semicolon]
    );
    assert_eq!(
        tok_str("SELECT 35400000000000.32n;"),
        ["SELECT", "35400000000000.32n", ";"]
    );
    assert_eq!(
        tok_typ("SELECT 35400000000000.32n;"),
        [keyword("select"), DecimalConst, Semicolon]
    );
    assert_eq!(
        tok_str("SELECT -35400000000000000000.32n;"),
        ["SELECT", "-", "35400000000000000000.32n", ";"]
    );
    assert_eq!(
        tok_typ("SELECT -35400000000000000000.32n;"),
        [keyword("select"), Sub, DecimalConst, Semicolon]
    );
    assert_eq!(tok_str("SELECT 3.5432e20n;"), ["SELECT", "3.5432e20n", ";"]);
    assert_eq!(
        tok_typ("SELECT 3.5432e20n;"),
        [keyword("select"), DecimalConst, Semicolon]
    );
    assert_eq!(
        tok_str("SELECT -3.5432e+20n;"),
        ["SELECT", "-", "3.5432e+20n", ";"]
    );
    assert_eq!(
        tok_typ("SELECT -3.5432e+20n;"),
        [keyword("select"), Sub, DecimalConst, Semicolon]
    );
    assert_eq!(
        tok_str("SELECT 3.5432e-20n;"),
        ["SELECT", "3.5432e-20n", ";"]
    );
    assert_eq!(
        tok_typ("SELECT 3.5432e-20n;"),
        [keyword("select"), DecimalConst, Semicolon]
    );
    assert_eq!(
        tok_str("SELECT 354.32e-20n;"),
        ["SELECT", "354.32e-20n", ";"]
    );
    assert_eq!(
        tok_typ("SELECT 354.32e-20n;"),
        [keyword("select"), DecimalConst, Semicolon]
    );
}

#[test]
fn num_errors() {
    assert_eq!(
        tok_err("0. "),
        "expected digit after dot, found end of decimal"
    );
    assert_eq!(
        tok_err("1.<"),
        "expected digit after dot, found end of decimal"
    );
    assert_eq!(tok_err("0.n"), "expected digit after dot, found suffix");
    assert_eq!(tok_err("0.e1"), "expected digit after dot, found exponent");
    assert_eq!(tok_err("0.e1n"), "expected digit after dot, found exponent");
    assert_eq!(
        tok_err("0."),
        "expected digit after dot, found end of decimal"
    );
    assert_eq!(tok_err("1.0.x"), "unexpected extra decimal dot in number");
    assert_eq!(tok_err("1.0e1."), "unexpected extra decimal dot in number");
    assert_eq!(
        tok_err("1.0e."),
        "unexpected optional `+` or `-` \
        followed by digits must follow `e` in float const"
    );
    assert_eq!(
        tok_err("1.0e"),
        "unexpected optional `+` or `-` \
        followed by digits must follow `e` in float const"
    );
    assert_eq!(
        tok_err("1.0ex"),
        "unexpected optional `+` or `-` \
        followed by digits must follow `e` in float const"
    );
    assert_eq!(
        tok_err("1.0en"),
        "unexpected optional `+` or `-` \
        followed by digits must follow `e` in float const"
    );
    assert_eq!(
        tok_err("1.0e "),
        "unexpected optional `+` or `-` \
        followed by digits must follow `e` in float const"
    );
    assert_eq!(
        tok_err("1.0e_"),
        "unexpected optional `+` or `-` \
        followed by digits must follow `e` in float const"
    );
    assert_eq!(
        tok_err("1.0e_ "),
        "unexpected optional `+` or `-` \
        followed by digits must follow `e` in float const"
    );
    assert_eq!(
        tok_err("1.0e_1"),
        "unexpected optional `+` or `-` \
        followed by digits must follow `e` in float const"
    );
    assert_eq!(
        tok_err("1.0e+"),
        "unexpected optional `+` or `-` \
        followed by digits must follow `e` in float const"
    );
    assert_eq!(
        tok_err("1.0e+ "),
        "unexpected optional `+` or `-` \
        followed by digits must follow `e` in float const"
    );
    assert_eq!(
        tok_err("1.0e+x"),
        "unexpected optional `+` or `-` \
        followed by digits must follow `e` in float const"
    );
    assert_eq!(
        tok_err("1.0e+n"),
        "unexpected optional `+` or `-` \
        followed by digits must follow `e` in float const"
    );
    assert_eq!(
        tok_err("1234numeric"),
        "suffix \"numeric\" \
        is invalid for numbers, perhaps you wanted `1234n` (bigint)?"
    );
    assert_eq!(
        tok_err("1234some_l0ng_trash"),
        "suffix \"some_l0n...\" \
        is invalid for numbers, perhaps you wanted `1234n` (bigint)?"
    );
    assert_eq!(
        tok_err("100O00"),
        "suffix \"O00\" is invalid for numbers, \
        perhaps mixed up letter `O` with zero `0`?"
    );
    assert_eq!(
        tok_err("01"),
        "unexpected leading zeros are not allowed in numbers"
    );
    assert_eq!(
        tok_err("01n"),
        "unexpected leading zeros are not allowed in numbers"
    );
    assert_eq!(
        tok_err("01_n"),
        "unexpected leading zeros are not allowed in numbers"
    );
    assert_eq!(
        tok_err("0_1_n"),
        "unexpected leading zeros are not allowed in numbers"
    );
    assert_eq!(
        tok_err("0_1n"),
        "unexpected leading zeros are not allowed in numbers"
    );
}

#[test]
fn tuple_paths() {
    assert_eq!(
        tok_str("tup.1.2.3.4.5"),
        ["tup", ".", "1", ".", "2", ".", "3", ".", "4", ".", "5"]
    );
    assert_eq!(
        tok_typ("tup.1.2.3.4.5"),
        [Ident, Dot, IntConst, Dot, IntConst, Dot, IntConst, Dot, IntConst, Dot, IntConst]
    );
    assert_eq!(
        tok_err("tup.1.2.>3.4.>5"),
        "unexpected extra decimal dot in number"
    );
    assert_eq!(
        tok_str("$0.1.2.3.4.5"),
        ["$0", ".", "1", ".", "2", ".", "3", ".", "4", ".", "5"]
    );
    assert_eq!(
        tok_typ("$0.1.2.3.4.5"),
        [Parameter, Dot, IntConst, Dot, IntConst, Dot, IntConst, Dot, IntConst, Dot, IntConst]
    );
    assert_eq!(
        tok_err("tup.1n"),
        "unexpected char \'n\', only integers \
        are allowed after dot (for tuple access)"
    );

    assert_eq!(
        tok_err("tup.01"),
        "leading zeros are not allowed in numbers"
    );
}

#[test]
fn strings() {
    assert_eq!(tok_str(r#" ""  "#), [r#""""#]);
    assert_eq!(tok_typ(r#" ""  "#), [Str]);
    assert_eq!(tok_str(r#" ''  "#), [r#"''"#]);
    assert_eq!(tok_typ(r#" ''  "#), [Str]);
    assert_eq!(tok_str(r#" r""  "#), [r#"r"""#]);
    assert_eq!(tok_typ(r#" r""  "#), [Str]);
    assert_eq!(tok_str(r#" r''  "#), [r#"r''"#]);
    assert_eq!(tok_typ(r#" r''  "#), [Str]);
    assert_eq!(tok_str(r#" b""  "#), [r#"b"""#]);
    assert_eq!(tok_typ(r#" b""  "#), [BinStr]);
    assert_eq!(tok_str(r#" b''  "#), [r#"b''"#]);
    assert_eq!(tok_typ(r#" b''  "#), [BinStr]);
    assert_eq!(tok_str(r#" br""  "#), [r#"br"""#]);
    assert_eq!(tok_typ(r#" br""  "#), [BinStr]);
    assert_eq!(tok_str(r#" br''  "#), [r#"br''"#]);
    assert_eq!(tok_typ(r#" br''  "#), [BinStr]);
    assert_eq!(tok_err(r#" ``  "#), "backtick quotes cannot be empty");

    assert_eq!(tok_str(r#" "hello"  "#), [r#""hello""#]);
    assert_eq!(tok_typ(r#" "hello"  "#), [Str]);
    assert_eq!(tok_str(r#" 'hello'  "#), [r#"'hello'"#]);
    assert_eq!(tok_typ(r#" 'hello'  "#), [Str]);
    assert_eq!(tok_str(r#" r"hello"  "#), [r#"r"hello""#]);
    assert_eq!(tok_typ(r#" r"hello"  "#), [Str]);
    assert_eq!(tok_str(r#" r'hello'  "#), [r#"r'hello'"#]);
    assert_eq!(tok_typ(r#" r'hello'  "#), [Str]);
    assert_eq!(tok_str(r#" b"hello"  "#), [r#"b"hello""#]);
    assert_eq!(tok_typ(r#" b"hello"  "#), [BinStr]);
    assert_eq!(tok_str(r#" b'hello'  "#), [r#"b'hello'"#]);
    assert_eq!(tok_typ(r#" b'hello'  "#), [BinStr]);
    assert_eq!(tok_str(r#" rb"hello"  "#), [r#"rb"hello""#]);
    assert_eq!(tok_typ(r#" rb"hello"  "#), [BinStr]);
    assert_eq!(tok_str(r#" rb'hello'  "#), [r#"rb'hello'"#]);
    assert_eq!(tok_typ(r#" rb'hello'  "#), [BinStr]);
    assert_eq!(tok_str(r#" `hello`  "#), [r#"`hello`"#]);
    assert_eq!(tok_typ(r#" `hello`  "#), [Ident]);

    assert_eq!(tok_str(r#" "hello""#), [r#""hello""#]);
    assert_eq!(tok_typ(r#" "hello""#), [Str]);
    assert_eq!(tok_str(r#" 'hello'"#), [r#"'hello'"#]);
    assert_eq!(tok_typ(r#" 'hello'"#), [Str]);
    assert_eq!(tok_str(r#" r"hello""#), [r#"r"hello""#]);
    assert_eq!(tok_typ(r#" r"hello""#), [Str]);
    assert_eq!(tok_str(r#" r'hello'"#), [r#"r'hello'"#]);
    assert_eq!(tok_typ(r#" r'hello'"#), [Str]);
    assert_eq!(tok_str(r#" b"hello""#), [r#"b"hello""#]);
    assert_eq!(tok_typ(r#" b"hello""#), [BinStr]);
    assert_eq!(tok_str(r#" b'hello'"#), [r#"b'hello'"#]);
    assert_eq!(tok_typ(r#" b'hello'"#), [BinStr]);
    assert_eq!(tok_str(r#" rb"hello""#), [r#"rb"hello""#]);
    assert_eq!(tok_typ(r#" rb"hello""#), [BinStr]);
    assert_eq!(tok_str(r#" rb'hello'"#), [r#"rb'hello'"#]);
    assert_eq!(tok_typ(r#" rb'hello'"#), [BinStr]);
    assert_eq!(tok_str(r#" `hello`"#), [r#"`hello`"#]);
    assert_eq!(tok_typ(r#" `hello`"#), [Ident]);

    assert_eq!(tok_str(r#" "h\"ello" "#), [r#""h\"ello""#]);
    assert_eq!(tok_typ(r#" "h\"ello" "#), [Str]);
    assert_eq!(tok_str(r" 'h\'ello' "), [r"'h\'ello'"]);
    assert_eq!(tok_typ(r" 'h\'ello' "), [Str]);
    assert_eq!(tok_str(r#" r"hello\" "#), [r#"r"hello\""#]);
    assert_eq!(tok_typ(r#" r"hello\" "#), [Str]);
    assert_eq!(tok_str(r" r'hello\' "), [r"r'hello\'"]);
    assert_eq!(tok_typ(r" r'hello\' "), [Str]);
    assert_eq!(tok_str(r#" b"h\"ello" "#), [r#"b"h\"ello""#]);
    assert_eq!(tok_typ(r#" b"h\"ello" "#), [BinStr]);
    assert_eq!(tok_str(r" b'h\'ello' "), [r"b'h\'ello'"]);
    assert_eq!(tok_typ(r" b'h\'ello' "), [BinStr]);
    assert_eq!(tok_str(r#" rb"hello\" "#), [r#"rb"hello\""#]);
    assert_eq!(tok_typ(r#" rb"hello\" "#), [BinStr]);
    assert_eq!(tok_str(r" rb'hello\' "), [r"rb'hello\'"]);
    assert_eq!(tok_typ(r" rb'hello\' "), [BinStr]);
    assert_eq!(tok_str(r" `hello\` "), [r"`hello\`"]);
    assert_eq!(tok_typ(r" `hello\` "), [Ident]);
    assert_eq!(tok_str(r#" `hel``lo` "#), [r#"`hel``lo`"#]);
    assert_eq!(tok_typ(r#" `hel``lo` "#), [Ident]);

    assert_eq!(tok_str(r#" "h'el`lo" "#), [r#""h'el`lo""#]);
    assert_eq!(tok_typ(r#" "h'el`lo" "#), [Str]);
    assert_eq!(tok_str(r#" 'h"el`lo' "#), [r#"'h"el`lo'"#]);
    assert_eq!(tok_typ(r#" 'h"el`lo' "#), [Str]);
    assert_eq!(tok_str(r#" r"h'el`lo" "#), [r#"r"h'el`lo""#]);
    assert_eq!(tok_typ(r#" r"h'el`lo" "#), [Str]);
    assert_eq!(tok_str(r#" r'h"el`lo' "#), [r#"r'h"el`lo'"#]);
    assert_eq!(tok_typ(r#" r'h"el`lo' "#), [Str]);
    assert_eq!(tok_str(r#" b"h'el`lo" "#), [r#"b"h'el`lo""#]);
    assert_eq!(tok_typ(r#" b"h'el`lo" "#), [BinStr]);
    assert_eq!(tok_str(r#" b'h"el`lo' "#), [r#"b'h"el`lo'"#]);
    assert_eq!(tok_typ(r#" b'h"el`lo' "#), [BinStr]);
    assert_eq!(tok_str(r#" rb"h'el`lo" "#), [r#"rb"h'el`lo""#]);
    assert_eq!(tok_typ(r#" rb"h'el`lo" "#), [BinStr]);
    assert_eq!(tok_str(r#" rb'h"el`lo' "#), [r#"rb'h"el`lo'"#]);
    assert_eq!(tok_typ(r#" rb'h"el`lo' "#), [BinStr]);
    assert_eq!(tok_str(r#" `h'el"lo` "#), [r#"`h'el"lo`"#]);
    assert_eq!(tok_typ(r#" `h'el"lo` "#), [Ident]);

    assert_eq!(tok_str(" \"hel\nlo\" "), ["\"hel\nlo\""]);
    assert_eq!(tok_typ(" \"hel\nlo\" "), [Str]);
    assert_eq!(tok_str(" 'hel\nlo' "), ["'hel\nlo'"]);
    assert_eq!(tok_typ(" 'hel\nlo' "), [Str]);
    assert_eq!(tok_str(" r\"hel\nlo\" "), ["r\"hel\nlo\""]);
    assert_eq!(tok_typ(" r\"hel\nlo\" "), [Str]);
    assert_eq!(tok_str(" r'hel\nlo' "), ["r'hel\nlo'"]);
    assert_eq!(tok_typ(" r'hel\nlo' "), [Str]);
    assert_eq!(tok_str(" b\"hel\nlo\" "), ["b\"hel\nlo\""]);
    assert_eq!(tok_typ(" b\"hel\nlo\" "), [BinStr]);
    assert_eq!(tok_str(" b'hel\nlo' "), ["b'hel\nlo'"]);
    assert_eq!(tok_typ(" b'hel\nlo' "), [BinStr]);
    assert_eq!(tok_typ(" rb'hel\nlo' "), [BinStr]);
    assert_eq!(tok_typ(" br'hel\nlo' "), [BinStr]);
    assert_eq!(tok_str(" rb'hel\nlo' "), ["rb'hel\nlo'"]);
    assert_eq!(tok_str(" br'hel\nlo' "), ["br'hel\nlo'"]);
    assert_eq!(tok_str(" `hel\nlo` "), ["`hel\nlo`"]);
    assert_eq!(tok_typ(" `hel\nlo` "), [Ident]);

    assert_eq!(tok_err(r#""hello"#), "unterminated string, quoted by `\"`");
    assert_eq!(tok_err(r#"'hello"#), "unterminated string, quoted by `'`");
    assert_eq!(tok_err(r#"r"hello"#), "unterminated string, quoted by `\"`");
    assert_eq!(tok_err(r#"r'hello"#), "unterminated string, quoted by `'`");
    assert_eq!(tok_err(r#"b"hello"#), "unterminated string, quoted by `\"`");
    assert_eq!(tok_err(r#"b'hello"#), "unterminated string, quoted by `'`");
    assert_eq!(tok_err(r#"`hello"#), "unterminated backtick name");

    assert_eq!(
        tok_err(r#"name`type`"#),
        "prefix \"name\" is not allowed for field names, \
        perhaps missing comma or dot?"
    );
    assert_eq!(
        tok_err(r#"User`type`"#),
        "prefix \"User\" is not allowed for field names, \
        perhaps missing comma or dot?"
    );
    assert_eq!(
        tok_err(r#"r`hello"#),
        "prefix \"r\" is not allowed for field names, \
        perhaps missing comma or dot?"
    );
    assert_eq!(
        tok_err(r#"b`hello"#),
        "prefix \"b\" is not allowed for field names, \
        perhaps missing comma or dot?"
    );
    assert_eq!(
        tok_err(r#"test"hello""#),
        "prefix \"test\" is not allowed for strings, \
        allowed: `b`, `r`"
    );
    assert_eq!(
        tok_err(r#"test'hello'"#),
        "prefix \"test\" is not allowed for strings, \
        allowed: `b`, `r`"
    );
    assert_eq!(
        tok_err(r#"`@x`"#),
        "backtick-quoted name cannot start with char `@`"
    );
    assert_eq!(
        tok_err(r#"`$x`"#),
        "backtick-quoted name cannot start with char `$`"
    );
    assert_eq!(
        tok_err(r#"`a::b`"#),
        "backtick-quoted name cannot contain `::`"
    );
    assert_eq!(
        tok_err(r#"`__x__`"#),
        "backtick-quoted names surrounded by double \
                    underscores are forbidden"
    );
}

#[test]
fn string_prohibited_chars() {
    assert_eq!(
        tok_err("'xxx \u{202A}'"),
        "character U+202A is not allowed, use escaped form \\u202a"
    );
    assert_eq!(
        tok_err("\"\u{202A} yyy\""),
        "character U+202A is not allowed, use escaped form \\u202a"
    );
    assert_eq!(
        tok_err("r\"\u{202A}\""),
        "character U+202A is not allowed, use escaped form \\u202a"
    );
    assert_eq!(
        tok_err("r'\u{202A}'"),
        "character U+202A is not allowed, use escaped form \\u202a"
    );
    assert_eq!(
        tok_err("b'\u{202A}'"),
        "invalid bytes literal: character '\\u{202a}' \
         is unexpected, only ascii chars are allowed in bytes literals"
    );
    assert_eq!(
        tok_err("b\"\u{202A}\""),
        "invalid bytes literal: character '\\u{202a}' \
         is unexpected, only ascii chars are allowed in bytes literals"
    );
    assert_eq!(tok_err("`\u{202A}`"), "character U+202A is not allowed");
    assert_eq!(tok_err("$`\u{202A}`"), "character U+202A is not allowed");
    assert_eq!(
        tok_err("$x\u{202A}$ inner $x\u{202A}$"),
        "unexpected character '\\u{202a}'"
    );
    assert_eq!(tok_err("$$ \u{202A} $$"), "character U+202A is not allowed");
    assert_eq!(
        tok_err("$hello$ \u{202A} $hello$"),
        "character U+202A is not allowed"
    );
    assert_eq!(tok_err("'xxx \0'"), "character U+0000 is not allowed");
    assert_eq!(tok_err("xxx \0"), "unexpected character '\\0'");
    assert_eq!(tok_err("xxx $x$\0$x$"), "character U+0000 is not allowed");
}

#[test]
fn test_dollar() {
    assert_eq!(
        tok_str("select $$ something $$; x"),
        ["select", "$$ something $$", ";", "x"]
    );
    assert_eq!(
        tok_typ("select $$ something $$; x"),
        [keyword("select"), Str, Semicolon, Ident]
    );
    assert_eq!(
        tok_str("select $a$ ; $b$ ; $b$ ; $a$; x"),
        ["select", "$a$ ; $b$ ; $b$ ; $a$", ";", "x"]
    );
    assert_eq!(
        tok_typ("select $a$ ; $b$ ; $b$ ; $a$; x"),
        [keyword("select"), Str, Semicolon, Ident]
    );
    assert_eq!(
        tok_str("select $a$ ; $b$ ; $a$; x"),
        ["select", "$a$ ; $b$ ; $a$", ";", "x"]
    );
    assert_eq!(
        tok_typ("select $a$ ; $b$ ; $a$; x"),
        [keyword("select"), Str, Semicolon, Ident]
    );
    assert_eq!(
        tok_err("select $$ ; $ab$ test;"),
        "unterminated string started with $$"
    );
    assert_eq!(
        tok_err("select $a$ ; $$ test;"),
        "unterminated string started with \"$a$\""
    );
    assert_eq!(
        tok_err("select $0$"),
        "dollar quote must not start with a digit"
    );
    assert_eq!(
        tok_err("select $фыва$"),
        "dollar quote supports only ascii chars"
    );
    assert_eq!(
        tok_str("select $a$a$ ; $a$ test;"),
        ["select", "$a$a$ ; $a$", "test", ";"]
    );
    assert_eq!(
        tok_typ("select $a$a$ ; $a$ test;"),
        [keyword("select"), Str, Ident, Semicolon]
    );
    assert_eq!(
        tok_str("select $a+b; $b test; $a+b; $b ;"),
        ["select", "$a", "+", "b", ";", "$b", "test", ";", "$a", "+", "b", ";", "$b", ";"]
    );
    assert_eq!(
        tok_typ("select $a+b; $b test; $a+b; $b ;"),
        [
            keyword("select"),
            Parameter,
            Add,
            Ident,
            Semicolon,
            Parameter,
            Ident,
            Semicolon,
            Parameter,
            Add,
            Ident,
            Semicolon,
            Parameter,
            Semicolon
        ]
    );
    assert_eq!(
        tok_str("select $def x$y test; $def x$y"),
        ["select", "$def", "x", "$y", "test", ";", "$def", "x", "$y"]
    );
    assert_eq!(
        tok_typ("select $def x$y test; $def x$y"),
        [
            keyword("select"),
            Parameter,
            Ident,
            Parameter,
            Ident,
            Semicolon,
            Parameter,
            Ident,
            Parameter
        ]
    );
    assert_eq!(
        tok_str("select $`x``y` + $0 + $`zz` + $1.2 + $фыва"),
        [
            "select",
            "$`x``y`",
            "+",
            "$0",
            "+",
            "$`zz`",
            "+",
            "$1",
            ".",
            "2",
            "+",
            "$фыва"
        ]
    );
    assert_eq!(
        tok_typ("select $`x``y` + $0 + $`zz` + $1.2 + $фыва"),
        [
            keyword("select"),
            Parameter,
            Add,
            Parameter,
            Add,
            Parameter,
            Add,
            Parameter,
            Dot,
            IntConst,
            Add,
            Parameter
        ]
    );
    assert_eq!(tok_err(r#"$-"#), "bare $ is not allowed");
    assert_eq!(
        tok_err(r#"$0abc"#),
        "the \"$0abc\" is not a valid argument, \
         either name starting with letter or only digits are expected"
    );
    assert_eq!(tok_err(r#"-$"#), "bare $ is not allowed");
    assert_eq!(
        tok_err(r#" $``  "#),
        "backtick-quoted argument cannot be empty"
    );
    assert_eq!(
        tok_err(r#"$`@x`"#),
        "backtick-quoted argument cannot \
        start with char `@`"
    );
    assert_eq!(
        tok_err(r#"$`a::b`"#),
        "backtick-quoted argument cannot contain `::`"
    );
    assert_eq!(
        tok_err(r#"$`__x__`"#),
        "backtick-quoted arguments surrounded by double \
                    underscores are forbidden"
    );
}

#[test]
fn invalid_suffix() {
    assert_eq!(
        tok_err("SELECT 1d;"),
        "suffix \"d\" \
        is invalid for numbers, perhaps you wanted `1n` (bigint)?"
    );
}

#[test]
fn test_substitution() {
    assert_eq!(tok_str("SELECT \\(expr);"), ["SELECT", "\\(expr)", ";"]);
    assert_eq!(
        tok_typ("SELECT \\(expr);"),
        [keyword("select"), Substitution, Semicolon]
    );
    assert_eq!(
        tok_str("SELECT \\(other_Name1);"),
        ["SELECT", "\\(other_Name1)", ";"]
    );
    assert_eq!(
        tok_typ("SELECT \\(other_Name1);"),
        [keyword("select"), Substitution, Semicolon]
    );
    assert_eq!(
        tok_err("SELECT \\(some-name);"),
        "only alphanumerics are allowed in \\(name) token"
    );
    assert_eq!(tok_err("SELECT \\(some_name"), "unclosed \\(name) token");
}


================================================
FILE: edb/errors/__init__.py
================================================
# AUTOGENERATED FROM "edb/api/errors.txt" WITH
#    $ edb gen-errors


# flake8: noqa


from edb.errors.base import *


__all__ = base.__all__ + (  # type: ignore
    'InternalServerError',
    'UnsupportedFeatureError',
    'ProtocolError',
    'BinaryProtocolError',
    'UnsupportedProtocolVersionError',
    'TypeSpecNotFoundError',
    'UnexpectedMessageError',
    'InputDataError',
    'ParameterTypeMismatchError',
    'StateMismatchError',
    'ResultCardinalityMismatchError',
    'CapabilityError',
    'UnsupportedCapabilityError',
    'DisabledCapabilityError',
    'UnsafeIsolationLevelError',
    'QueryError',
    'InvalidSyntaxError',
    'EdgeQLSyntaxError',
    'SchemaSyntaxError',
    'GraphQLSyntaxError',
    'InvalidTypeError',
    'InvalidTargetError',
    'InvalidLinkTargetError',
    'InvalidPropertyTargetError',
    'InvalidReferenceError',
    'UnknownModuleError',
    'UnknownLinkError',
    'UnknownPropertyError',
    'UnknownUserError',
    'UnknownDatabaseError',
    'UnknownParameterError',
    'DeprecatedScopingError',
    'SchemaError',
    'SchemaDefinitionError',
    'InvalidDefinitionError',
    'InvalidModuleDefinitionError',
    'InvalidLinkDefinitionError',
    'InvalidPropertyDefinitionError',
    'InvalidUserDefinitionError',
    'InvalidDatabaseDefinitionError',
    'InvalidOperatorDefinitionError',
    'InvalidAliasDefinitionError',
    'InvalidFunctionDefinitionError',
    'InvalidConstraintDefinitionError',
    'InvalidCastDefinitionError',
    'DuplicateDefinitionError',
    'DuplicateModuleDefinitionError',
    'DuplicateLinkDefinitionError',
    'DuplicatePropertyDefinitionError',
    'DuplicateUserDefinitionError',
    'DuplicateDatabaseDefinitionError',
    'DuplicateOperatorDefinitionError',
    'DuplicateViewDefinitionError',
    'DuplicateFunctionDefinitionError',
    'DuplicateConstraintDefinitionError',
    'DuplicateCastDefinitionError',
    'DuplicateMigrationError',
    'SessionTimeoutError',
    'IdleSessionTimeoutError',
    'QueryTimeoutError',
    'TransactionTimeoutError',
    'IdleTransactionTimeoutError',
    'ExecutionError',
    'InvalidValueError',
    'DivisionByZeroError',
    'NumericOutOfRangeError',
    'AccessPolicyError',
    'QueryAssertionError',
    'IntegrityError',
    'ConstraintViolationError',
    'CardinalityViolationError',
    'MissingRequiredError',
    'TransactionError',
    'TransactionConflictError',
    'TransactionSerializationError',
    'TransactionDeadlockError',
    'QueryCacheInvalidationError',
    'WatchError',
    'ConfigurationError',
    'AccessError',
    'AuthenticationError',
    'AvailabilityError',
    'BackendUnavailableError',
    'ServerOfflineError',
    'UnknownTenantError',
    'ServerBlockedError',
    'BackendError',
    'UnsupportedBackendFeatureError',
    'LogMessage',
    'WarningMessage',
    'StatusMessage',
    'MigrationStatusMessage',
)


class InternalServerError(EdgeDBError):
    _code = 0x_01_00_00_00


class UnsupportedFeatureError(EdgeDBError):
    _code = 0x_02_00_00_00


class ProtocolError(EdgeDBError):
    _code = 0x_03_00_00_00


class BinaryProtocolError(ProtocolError):
    _code = 0x_03_01_00_00


class UnsupportedProtocolVersionError(BinaryProtocolError):
    _code = 0x_03_01_00_01


class TypeSpecNotFoundError(BinaryProtocolError):
    _code = 0x_03_01_00_02


class UnexpectedMessageError(BinaryProtocolError):
    _code = 0x_03_01_00_03


class InputDataError(ProtocolError):
    _code = 0x_03_02_00_00


class ParameterTypeMismatchError(InputDataError):
    _code = 0x_03_02_01_00


class StateMismatchError(InputDataError):
    _code = 0x_03_02_02_00


class ResultCardinalityMismatchError(ProtocolError):
    _code = 0x_03_03_00_00


class CapabilityError(ProtocolError):
    _code = 0x_03_04_00_00


class UnsupportedCapabilityError(CapabilityError):
    _code = 0x_03_04_01_00


class DisabledCapabilityError(CapabilityError):
    _code = 0x_03_04_02_00


class UnsafeIsolationLevelError(CapabilityError):
    _code = 0x_03_04_03_00


class QueryError(EdgeDBError):
    _code = 0x_04_00_00_00


class InvalidSyntaxError(QueryError):
    _code = 0x_04_01_00_00


class EdgeQLSyntaxError(InvalidSyntaxError):
    _code = 0x_04_01_01_00


class SchemaSyntaxError(InvalidSyntaxError):
    _code = 0x_04_01_02_00


class GraphQLSyntaxError(InvalidSyntaxError):
    _code = 0x_04_01_03_00


class InvalidTypeError(QueryError):
    _code = 0x_04_02_00_00


class InvalidTargetError(InvalidTypeError):
    _code = 0x_04_02_01_00


class InvalidLinkTargetError(InvalidTargetError):
    _code = 0x_04_02_01_01


class InvalidPropertyTargetError(InvalidTargetError):
    _code = 0x_04_02_01_02


class InvalidReferenceError(QueryError):
    _code = 0x_04_03_00_00


class UnknownModuleError(InvalidReferenceError):
    _code = 0x_04_03_00_01


class UnknownLinkError(InvalidReferenceError):
    _code = 0x_04_03_00_02


class UnknownPropertyError(InvalidReferenceError):
    _code = 0x_04_03_00_03


class UnknownUserError(InvalidReferenceError):
    _code = 0x_04_03_00_04


class UnknownDatabaseError(InvalidReferenceError):
    _code = 0x_04_03_00_05


class UnknownParameterError(InvalidReferenceError):
    _code = 0x_04_03_00_06


class DeprecatedScopingError(InvalidReferenceError):
    _code = 0x_04_03_00_07


class SchemaError(QueryError):
    _code = 0x_04_04_00_00


class SchemaDefinitionError(QueryError):
    _code = 0x_04_05_00_00


class InvalidDefinitionError(SchemaDefinitionError):
    _code = 0x_04_05_01_00


class InvalidModuleDefinitionError(InvalidDefinitionError):
    _code = 0x_04_05_01_01


class InvalidLinkDefinitionError(InvalidDefinitionError):
    _code = 0x_04_05_01_02


class InvalidPropertyDefinitionError(InvalidDefinitionError):
    _code = 0x_04_05_01_03


class InvalidUserDefinitionError(InvalidDefinitionError):
    _code = 0x_04_05_01_04


class InvalidDatabaseDefinitionError(InvalidDefinitionError):
    _code = 0x_04_05_01_05


class InvalidOperatorDefinitionError(InvalidDefinitionError):
    _code = 0x_04_05_01_06


class InvalidAliasDefinitionError(InvalidDefinitionError):
    _code = 0x_04_05_01_07


class InvalidFunctionDefinitionError(InvalidDefinitionError):
    _code = 0x_04_05_01_08


class InvalidConstraintDefinitionError(InvalidDefinitionError):
    _code = 0x_04_05_01_09


class InvalidCastDefinitionError(InvalidDefinitionError):
    _code = 0x_04_05_01_0A


class DuplicateDefinitionError(SchemaDefinitionError):
    _code = 0x_04_05_02_00


class DuplicateModuleDefinitionError(DuplicateDefinitionError):
    _code = 0x_04_05_02_01


class DuplicateLinkDefinitionError(DuplicateDefinitionError):
    _code = 0x_04_05_02_02


class DuplicatePropertyDefinitionError(DuplicateDefinitionError):
    _code = 0x_04_05_02_03


class DuplicateUserDefinitionError(DuplicateDefinitionError):
    _code = 0x_04_05_02_04


class DuplicateDatabaseDefinitionError(DuplicateDefinitionError):
    _code = 0x_04_05_02_05


class DuplicateOperatorDefinitionError(DuplicateDefinitionError):
    _code = 0x_04_05_02_06


class DuplicateViewDefinitionError(DuplicateDefinitionError):
    _code = 0x_04_05_02_07


class DuplicateFunctionDefinitionError(DuplicateDefinitionError):
    _code = 0x_04_05_02_08


class DuplicateConstraintDefinitionError(DuplicateDefinitionError):
    _code = 0x_04_05_02_09


class DuplicateCastDefinitionError(DuplicateDefinitionError):
    _code = 0x_04_05_02_0A


class DuplicateMigrationError(DuplicateDefinitionError):
    _code = 0x_04_05_02_0B


class SessionTimeoutError(QueryError):
    _code = 0x_04_06_00_00


class IdleSessionTimeoutError(SessionTimeoutError):
    _code = 0x_04_06_01_00


class QueryTimeoutError(SessionTimeoutError):
    _code = 0x_04_06_02_00


class TransactionTimeoutError(SessionTimeoutError):
    _code = 0x_04_06_0A_00


class IdleTransactionTimeoutError(TransactionTimeoutError):
    _code = 0x_04_06_0A_01


class ExecutionError(EdgeDBError):
    _code = 0x_05_00_00_00


class InvalidValueError(ExecutionError):
    _code = 0x_05_01_00_00


class DivisionByZeroError(InvalidValueError):
    _code = 0x_05_01_00_01


class NumericOutOfRangeError(InvalidValueError):
    _code = 0x_05_01_00_02


class AccessPolicyError(InvalidValueError):
    _code = 0x_05_01_00_03


class QueryAssertionError(InvalidValueError):
    _code = 0x_05_01_00_04


class IntegrityError(ExecutionError):
    _code = 0x_05_02_00_00


class ConstraintViolationError(IntegrityError):
    _code = 0x_05_02_00_01


class CardinalityViolationError(IntegrityError):
    _code = 0x_05_02_00_02


class MissingRequiredError(IntegrityError):
    _code = 0x_05_02_00_03


class TransactionError(ExecutionError):
    _code = 0x_05_03_00_00


class TransactionConflictError(TransactionError):
    _code = 0x_05_03_01_00


class TransactionSerializationError(TransactionConflictError):
    _code = 0x_05_03_01_01


class TransactionDeadlockError(TransactionConflictError):
    _code = 0x_05_03_01_02


class QueryCacheInvalidationError(TransactionConflictError):
    _code = 0x_05_03_01_03


class WatchError(ExecutionError):
    _code = 0x_05_04_00_00


class ConfigurationError(EdgeDBError):
    _code = 0x_06_00_00_00


class AccessError(EdgeDBError):
    _code = 0x_07_00_00_00


class AuthenticationError(AccessError):
    _code = 0x_07_01_00_00


class AvailabilityError(EdgeDBError):
    _code = 0x_08_00_00_00


class BackendUnavailableError(AvailabilityError):
    _code = 0x_08_00_00_01


class ServerOfflineError(AvailabilityError):
    _code = 0x_08_00_00_02


class UnknownTenantError(AvailabilityError):
    _code = 0x_08_00_00_03


class ServerBlockedError(AvailabilityError):
    _code = 0x_08_00_00_04


class BackendError(EdgeDBError):
    _code = 0x_09_00_00_00


class UnsupportedBackendFeatureError(BackendError):
    _code = 0x_09_00_01_00


class LogMessage(EdgeDBMessage):
    _code = 0x_F0_00_00_00


class WarningMessage(LogMessage):
    _code = 0x_F0_01_00_00


class StatusMessage(LogMessage):
    _code = 0x_F0_02_00_00


class MigrationStatusMessage(StatusMessage):
    _code = 0x_F0_02_00_01


================================================
FILE: edb/errors/base.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

from typing import Optional, Iterator

from edb.common import span as edb_span
from edb.common import exceptions as ex

import contextlib


__all__ = (
    'EdgeDBError', 'EdgeDBMessage', 'ensure_span',
)


class EdgeDBErrorMeta(type):
    _error_map: dict[int, type[EdgeDBError]] = {}
    _name_map: dict[str, type[EdgeDBError]] = {}

    def __new__(mcls, name, bases, dct):
        cls = super().__new__(mcls, name, bases, dct)

        assert name not in mcls._name_map
        mcls._name_map[name] = cls

        code = dct.get('_code')
        if code is not None:
            mcls._error_map[code] = cls

        return cls

    def __init__(cls, name, bases, dct):
        if cls._code is None and cls.__module__ != __name__:
            # Every EdgeDBError subclass defined outside this module
            # must have a code, either its own or an inherited one.
            raise RuntimeError(
                'direct subclassing of EdgeDBError is prohibited; '
                'subclass one of its subclasses in edb.errors')

    @classmethod
    def get_error_class_from_code(mcls, code: int) -> type[EdgeDBError]:
        return mcls._error_map[code]

    @classmethod
    def get_error_class_from_name(mcls, name: str) -> type[EdgeDBError]:
        return mcls._name_map[name]


class EdgeDBMessage(Warning):

    _code: Optional[int] = None

    @classmethod
    def get_code(cls):
        if cls._code is None:
            raise RuntimeError(
                f'EdgeDB message code is not set (type: {cls.__name__})')
        return cls._code


class EdgeDBError(Exception, metaclass=EdgeDBErrorMeta):

    _code: Optional[int] = None
    _attrs: dict[int, str]
    _pgext_code: Optional[str] = None

    def __init__(
        self,
        msg: Optional[str] = None,
        *,
        hint: Optional[str] = None,
        details: Optional[str] = None,
        span: Optional[edb_span.Span] = None,
        position: Optional[tuple[int, int, int, int | None]] = None,
        filename: Optional[str] = None,
        pgext_code: Optional[str] = None,
    ):
        if type(self) is EdgeDBError:
            raise RuntimeError(
                'EdgeDBError is not supposed to be instantiated directly')

        self._attrs = {}
        self._pgext_code = pgext_code

        if span:
            self.set_span(span)
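        elif position:
            # `position` is (column, line, start offset, end offset),
            # the same layout that get_position() returns.
            self.set_linecol(position[1], position[0])
            self.set_position(position[2], position[3])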

        if filename is not None:
            self.set_filename(filename)

        self.set_hint_and_details(hint, details)

        super().__init__(msg)

    @classmethod
    def get_code(cls):
        if cls._code is None:
            raise RuntimeError(
                f'Gel message code is not set (type: {cls.__name__})')
        return cls._code

    def to_json(self):
        err_dct = {
            'message': str(self),
            'type': str(type(self).__name__),
            'code': self.get_code(),
        }
        for name, field in _JSON_FIELDS.items():
            if field in self._attrs:
                val = self._attrs[field]
                if field in _INT_FIELDS:
                    val = int(val)
                err_dct[name] = val

        return err_dct

    def set_filename(self, filename):
        self._attrs[FIELD_FILENAME] = filename

    def set_linecol(
        self,
        line: Optional[int],  # one-based
        col: Optional[int],  # one-based
    ):
        if line is not None:
            self._attrs[FIELD_LINE_START] = str(line)
        if col is not None:
            self._attrs[FIELD_COLUMN_START] = str(col)

    def compute_line_col(self, source: str):
        from edb.edgeql import tokenizer

        start: int = self.position
        end: int | None = self.position_end
        # position_end returns -1 when no end offset is recorded.
        if end and end < 0:
            end = None

        start_s, end_s = tokenizer.inflate_span(source, (start, end))

        self._attrs[FIELD_LINE_START] = str(start_s.line)
        self._attrs[FIELD_COLUMN_START] = str(start_s.column)
        if end_s is not None:
            self._attrs[FIELD_LINE_END] = str(end_s.line)
            self._attrs[FIELD_COLUMN_END] = str(end_s.column)

    def set_hint_and_details(self, hint, details=None):
        ex.replace_context(
            self, ex.DefaultExceptionContext(hint=hint, details=details))

        if hint is not None:
            self._attrs[FIELD_HINT] = hint
        if details is not None:
            self._attrs[FIELD_DETAILS] = details

    def has_span(self):
        return FIELD_POSITION_START in self._attrs

    def get_span(self) -> tuple[int, int | None] | None:
        if FIELD_POSITION_START not in self._attrs:
            return None
        return (
            int(self._attrs[FIELD_POSITION_START]),
            (
                int(self._attrs[FIELD_POSITION_END])
                if FIELD_POSITION_END in self._attrs
                else None
            ),
        )

    def set_span(self, span: Optional[edb_span.Span]):
        if not span:
            return

        start = span.start_point
        end = span.end_point
        ex.replace_context(self, span)

        self._attrs[FIELD_POSITION_START] = str(start.offset)
        self._attrs[FIELD_POSITION_END] = str(end.offset)
        self._attrs[FIELD_CHARACTER_START] = str(start.char_offset)
        self._attrs[FIELD_CHARACTER_END] = str(end.char_offset)
        self._attrs[FIELD_LINE_START] = str(start.line)
        self._attrs[FIELD_COLUMN_START] = str(start.column)
        self._attrs[FIELD_UTF16_COLUMN_START] = str(start.utf16column)
        self._attrs[FIELD_LINE_END] = str(end.line)
        self._attrs[FIELD_COLUMN_END] = str(end.column)
        self._attrs[FIELD_UTF16_COLUMN_END] = str(end.utf16column)
        if span.filename:
            self._attrs[FIELD_FILENAME] = span.filename

    def get_position(self) -> tuple[int, int, int, int | None] | None:
        # Returns (column, line, start offset, end offset).
        if FIELD_COLUMN_START not in self._attrs:
            return None
        return (
            int(self._attrs[FIELD_COLUMN_START]),
            int(self._attrs[FIELD_LINE_START]),
            int(self._attrs[FIELD_POSITION_START]),
            int(self._attrs[FIELD_POSITION_END])
            if FIELD_POSITION_END in self._attrs
            else None,
        )

    def set_position(
        self,
        start: int,  # zero-based
        end: Optional[int],  # zero-based
    ):
        self._attrs[FIELD_POSITION_START] = str(start)
        # Without an end offset, the span collapses to the start offset.
        self._attrs[FIELD_POSITION_END] = str(end or start)

    @property
    def line(self):
        return int(self._attrs.get(FIELD_LINE_START, -1))

    @property
    def col(self):
        return int(self._attrs.get(FIELD_COLUMN_START, -1))

    @property
    def line_end(self):
        return int(self._attrs.get(FIELD_LINE_END, -1))

    @property
    def col_end(self):
        return int(self._attrs.get(FIELD_COLUMN_END, -1))

    @property
    def position(self):
        return int(self._attrs.get(FIELD_POSITION_START, -1))

    @property
    def position_end(self):
        return int(self._attrs.get(FIELD_POSITION_END, -1))

    @property
    def hint(self):
        return self._attrs.get(FIELD_HINT)

    @property
    def details(self):
        return self._attrs.get(FIELD_DETAILS)

    @property
    def pgext_code(self):
        return self._pgext_code

    @property
    def filename(self):
        return self._attrs.get(FIELD_FILENAME, None)


@contextlib.contextmanager
def ensure_span(span: Optional[edb_span.Span]) -> Iterator[None]:
    """Attach *span* to any EdgeDBError raised in the block that lacks one."""
    try:
        yield
    except EdgeDBError as e:
        if span and not e.has_span():
            e.set_span(span)
        raise


FIELD_HINT = 0x_00_01
FIELD_DETAILS = 0x_00_02
FIELD_SERVER_TRACEBACK = 0x_01_01

# XXX: Subject to change or deprecation.
FIELD_POSITION_START = 0x_FF_F1
FIELD_POSITION_END = 0x_FF_F2
FIELD_LINE_START = 0x_FF_F3
FIELD_COLUMN_START = 0x_FF_F4
FIELD_UTF16_COLUMN_START = 0x_FF_F5
FIELD_LINE_END = 0x_FF_F6
FIELD_COLUMN_END = 0x_FF_F7
FIELD_UTF16_COLUMN_END = 0x_FF_F8
FIELD_CHARACTER_START = 0x_FF_F9
FIELD_CHARACTER_END = 0x_FF_FA
FIELD_FILENAME = 0x_FF_FB

_INT_FIELDS = {
    FIELD_POSITION_START,
    FIELD_POSITION_END,
    FIELD_LINE_START,
    FIELD_COLUMN_START,
    FIELD_UTF16_COLUMN_START,
    FIELD_LINE_END,
    FIELD_COLUMN_END,
    FIELD_UTF16_COLUMN_END,
    FIELD_CHARACTER_START,
    FIELD_CHARACTER_END,
}

# Fields to include in the JSON dump of an error (see EdgeDBError.to_json).
_JSON_FIELDS = {
    'filename': FIELD_FILENAME,
    'hint': FIELD_HINT,
    'details': FIELD_DETAILS,
    'start': FIELD_CHARACTER_START,
    'end': FIELD_CHARACTER_END,
    'line': FIELD_LINE_START,
    'col': FIELD_COLUMN_START,
}


================================================
FILE: edb/graphql/.gitignore
================================================
extension.c


================================================
FILE: edb/graphql/__init__.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

from .compiler import compile_graphql
from .translator import translate_ast, parse_text, parse_tokens
from .translator import TranspiledOperation
from .tokenizer import Source, NormalizedSource
from .types import GQLCoreSchema

from . import _patch_core
_patch_core.patch_graphql_core()


__all__ = (
    'translate_ast', 'parse_text', 'parse_tokens', 'GQLCoreSchema',
    'compile_graphql', 'TranspiledOperation', 'Source', 'NormalizedSource'
)


================================================
FILE: edb/graphql/_patch_core.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2019-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations


def patch_graphql_core():
    import graphql
    import graphql.utilities.type_comparators as type_comparators

    old_is_type_sub_type_of = type_comparators.is_type_sub_type_of

    def is_type_sub_type_of(schema, maybe_subtype, super_type):
        # allow coercing ints to floats
        if super_type is graphql.GraphQLFloat:
            if maybe_subtype is graphql.GraphQLInt:
                return True
        return old_is_type_sub_type_of(schema, maybe_subtype, super_type)

    type_comparators.is_type_sub_type_of = is_type_sub_type_of
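

# Illustrative sketch, not part of the original file: the wrap-and-replace
# pattern used by patch_graphql_core() above, shown on stand-in objects so it
# runs without graphql-core installed. INT, FLOAT, `comparators` and `patch`
# are hypothetical names introduced here; the real code patches
# graphql.utilities.type_comparators.

```python
import types

# Hypothetical stand-ins for graphql.GraphQLInt and graphql.GraphQLFloat.
INT = object()
FLOAT = object()

# Hypothetical stand-in for the module that owns the comparator.
# The placeholder comparator only accepts identical types.
comparators = types.SimpleNamespace(
    is_type_sub_type_of=lambda schema, sub, sup: sub is sup
)


def patch(module):
    # Keep a reference to the original so every other case still
    # goes through the unmodified logic.
    original = module.is_type_sub_type_of

    def is_type_sub_type_of(schema, maybe_subtype, super_type):
        # Special-case Int -> Float, defer to the original otherwise.
        if super_type is FLOAT and maybe_subtype is INT:
            return True
        return original(schema, maybe_subtype, super_type)

    module.is_type_sub_type_of = is_type_sub_type_of


before = comparators.is_type_sub_type_of(None, INT, FLOAT)
patch(comparators)
after = comparators.is_type_sub_type_of(None, INT, FLOAT)
reverse = comparators.is_type_sub_type_of(None, FLOAT, INT)
```

# The coercion is one-directional: only Int used where Float is expected is
# accepted; Float used where Int is expected still goes to the original check.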


================================================
FILE: edb/graphql/codegen.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

import json
from edb.common.ast import codegen


class GraphQLSourceGenerator(codegen.SourceGenerator):
    def generic_visit(self, node):
        raise RuntimeError(
            'No method to generate code for %s' % node.__class__.__name__)

    def _visit_list(self, items, separator=None):
        for item in items:
            self.visit(item)
            if item is not items[-1] and separator:
                self.write(separator)

    def _visit_arguments(self, node):
        if node.arguments:
            self.write('(')
            self._visit_list(node.arguments, separator=', ')
            self.write(')')

    def _visit_directives(self, node):
        if node.directives:
            self.write(' ')
            self._visit_list(node.directives, separator=', ')

    def _visit_type_condition(self, node):
        if node.type_condition:
            self.write(' on ')
            self.visit(node.type_condition)

    def visit_NameNode(self, node):
        self.write(node.value)

    def visit_DocumentNode(self, node):
        self._visit_list(node.definitions)

    def visit_OperationDefinitionNode(self, node):
        if node.operation:
            self.write(node.operation)
            if node.name:
                self.write(' ')
                self.visit(node.name)
            if node.variable_definitions:
                self.write('(')
                self._visit_list(node.variable_definitions, separator=', ')
                self.write(')')
            self._visit_directives(node)

        self.visit(node.selection_set)

    def visit_FragmentDefinitionNode(self, node):
        self.write('fragment ')
        self.visit(node.name)
        self._visit_type_condition(node)
        self._visit_directives(node)
        self.visit(node.selection_set)

    def visit_SelectionSetNode(self, node):
        self.write('{')
        self.new_lines = 1
        self.indentation += 1
        self._visit_list(node.selections)
        self.indentation -= 1
        self.write('}')
        self.new_lines = 2

    def visit_FieldNode(self, node):
        if node.alias:
            self.visit(node.alias)
            self.write(': ')
        self.visit(node.name)
        self._visit_arguments(node)
        self._visit_directives(node)
        if node.selection_set:
            self.visit(node.selection_set)
        else:
            self.new_lines = 1

    def visit_FragmentSpreadNode(self, node):
        self.write('...')
        self.visit(node.name)
        self._visit_directives(node)
        self.new_lines = 1

    def visit_InlineFragmentNode(self, node):
        self.write('...')
        self._visit_type_condition(node)
        self._visit_directives(node)
        self.visit(node.selection_set)

    def visit_ArgumentNode(self, node):
        self.visit(node.name)
        self.write(': ')
        self.visit(node.value)

    def visit_ObjectFieldNode(self, node):
        self.visit_ArgumentNode(node)
        self.new_lines = 1

    def visit_VariableDefinitionNode(self, node):
        self.visit(node.variable)
        self.write(': ')
        self.visit(node.type)
        if node.default_value:
            self.write(' = ')
            self.visit(node.default_value)

    def visit_DirectiveNode(self, node):
        self.write('@')
        self.visit(node.name)
        self._visit_arguments(node)

    def visit_StringValueNode(self, node):
        # json.dumps() output is also a valid GraphQL string literal
        self.write(json.dumps(node.value))

    def visit_IntValueNode(self, node):
        self.write(node.value)

    def visit_FloatValueNode(self, node):
        self.write(node.value)

    def visit_BooleanValueNode(self, node):
        if node.value:
            self.write('true')
        else:
            self.write('false')

    def visit_ListValueNode(self, node):
        self.write('[')
        self._visit_list(node.values, separator=', ')
        self.write(']')

    def visit_ObjectValueNode(self, node):
        if node.fields:
            self.write('{')
            self.new_lines = 1
            self.indentation += 1
            self._visit_list(node.fields)
            self.indentation -= 1
            self.write('}')
        else:
            self.write('{}')

    def visit_EnumValueNode(self, node):
        self.write(node.value)

    def visit_NullValueNode(self, node):
        self.write('null')

    def visit_VariableNode(self, node):
        self.write('$')
        self.visit(node.name)

    def visit_NamedTypeNode(self, node):
        self.visit(node.name)

    def visit_ListTypeNode(self, node):
        self.write('[')
        self.visit(node.type)
        self.write(']')

    def visit_NonNullTypeNode(self, node):
        self.visit(node.type)
        self.write('!')


generate_source = GraphQLSourceGenerator.to_source
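

# Illustrative sketch, not part of the original file: a stripped-down
# version of the name-based visitor dispatch that GraphQLSourceGenerator
# relies on. Name, Argument, Field and MiniGenerator are hypothetical
# stand-ins; the real generator subclasses edb.common.ast.codegen
# .SourceGenerator and visits graphql-core AST nodes.

```python
from dataclasses import dataclass, field


@dataclass
class Name:
    value: str


@dataclass
class Argument:
    name: Name
    value: str


@dataclass
class Field:
    name: Name
    arguments: list = field(default_factory=list)


class MiniGenerator:
    def __init__(self):
        self.parts = []

    def write(self, text):
        self.parts.append(text)

    def visit(self, node):
        # Dispatch on the node class name, as visit_<ClassName>.
        method = getattr(self, 'visit_' + type(node).__name__, None)
        if method is None:
            raise RuntimeError(
                'No method to generate code for %s' % type(node).__name__)
        method(node)

    def visit_list(self, items, separator=None):
        # Emit the separator between items, never after the last one.
        for index, item in enumerate(items):
            if index and separator:
                self.write(separator)
            self.visit(item)

    def visit_Name(self, node):
        self.write(node.value)

    def visit_Argument(self, node):
        self.visit(node.name)
        self.write(': ')
        self.write(node.value)

    def visit_Field(self, node):
        self.visit(node.name)
        if node.arguments:
            self.write('(')
            self.visit_list(node.arguments, separator=', ')
            self.write(')')

    @classmethod
    def to_source(cls, node):
        generator = cls()
        generator.visit(node)
        return ''.join(generator.parts)


source = MiniGenerator.to_source(
    Field(Name('user'), [
        Argument(Name('id'), '1'),
        Argument(Name('active'), 'true'),
    ])
)
```

# Here `source` is 'user(id: 1, active: true)'. The sketch places separators
# by index rather than by comparing each item with items[-1], which is a
# design choice of this example and not a description of the original code.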


================================================
FILE: edb/graphql/compiler.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2019-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations
from typing import Any, Optional, Mapping


from edb import graphql

from edb.schema import schema as s_schema

from graphql.language import lexer as gql_lexer


def _get_gqlcore(
    std_schema: s_schema.Schema,
    user_schema: s_schema.Schema,
    global_schema: s_schema.Schema,
) -> graphql.GQLCoreSchema:
    return graphql.GQLCoreSchema(
        s_schema.ChainedSchema(
            std_schema,
            user_schema,
            global_schema
        )
    )


def compile_graphql(
    std_schema: s_schema.Schema,
    user_schema: s_schema.Schema,
    global_schema: s_schema.Schema,
    database_config: Mapping[str, Any],
    system_config: Mapping[str, Any],
    gql: str,
    tokens: Optional[
        list[tuple[gql_lexer.TokenKind, int, int, int, int, str]]],
    substitutions: Optional[dict[str, tuple[str, int, int]]],
    operation_name: Optional[str] = None,
    variables: Optional[Mapping[str, object]] = None,
    native_input: bool = False,
    extracted_variables: Optional[Mapping[str, object]] = None,
) -> graphql.TranspiledOperation:
    if tokens is None:
        ast = graphql.parse_text(gql)
    else:
        ast = graphql.parse_tokens(gql, tokens)

    gqlcore = _get_gqlcore(std_schema, user_schema, global_schema)

    return graphql.translate_ast(
        gqlcore,
        ast,
        variables=variables,
        extracted_variables=extracted_variables,
        substitutions=substitutions,
        operation_name=operation_name,
        native_input=native_input,
    )


================================================
FILE: edb/graphql/errors.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

from typing import Optional

from edb import errors


class GraphQLError(errors.QueryError):

    def __init__(self, msg, *, loc: Optional[tuple[int, int]] = None):

        super().__init__(msg)

        if loc:
            # XXX Will be fixed when we have proper LSP SourceLocation
            # abstraction.
            self.set_linecol(loc[0], loc[1])


class GraphQLTranslationError(GraphQLError):
    pass


class GraphQLValidationError(GraphQLTranslationError):
    pass


class GraphQLCoreError(GraphQLError):
    pass


================================================
FILE: edb/graphql/explore.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2019-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

import base64


_react_ver = '16.8.3'
_graphiql_ver = '0.12.0'


_edgedb_logo = base64.b64encode(br'''

  
''').decode()  # NoQA


EXPLORE_HTML = (r'''


  
    
    

    

    
    
    
    
  
  
    
Loading...
''').encode() ================================================ FILE: edb/graphql/extension.pyx ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2019-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from typing import ( Any, Dict, Tuple, List, Optional, Union, ) import cython import http import json import logging import time import urllib.parse from graphql.language import lexer as gql_lexer from edb import _graphql_rewrite from edb import errors from edb.graphql import errors as gql_errors from edb.server.dbview cimport dbview from edb.server import compiler, metrics from edb.server import defines as edbdef from edb.server.pgcon import errors as pgerrors from edb.server.protocol import execute from edb.server.compiler import errormech from edb.schema import schema as s_schema from edb.common import debug from edb.common import markup from . import explore from . 
import translator logger = logging.getLogger(__name__) _USER_ERRORS = ( _graphql_rewrite.LexingError, _graphql_rewrite.SyntaxError, _graphql_rewrite.NotFoundError, ) # key_vars tracks which variables are actually needed for evaluation, # since the compiler depends on some # redirect accounts for the fact that we actually don't always know # which the relevant variables are until after compiling @cython.final cdef class CacheRedirect: cdef public list key_vars # List[str], must be sorted def __init__(self, key_vars: List[str]): self.key_vars = key_vars CacheEntry = Union[ CacheRedirect, Tuple[compiler.QueryUnitGroup, translator.TranspiledOperation], ] async def handle_request( object request, object response, object db, str role_name, list args, object tenant, ): if args == ['explore'] and request.method == b'GET': response.body = explore.EXPLORE_HTML response.content_type = b'text/html' return if args != []: response.body = b'Unknown path' response.status = http.HTTPStatus.NOT_FOUND response.close_connection = True return operation_name = None variables = None globals = None config = None deprecated_globals = None query = None query_bytes_len = 0 try: if request.method == b'POST': if request.content_type and b'json' in request.content_type: body = json.loads(request.body) if not isinstance(body, dict): raise TypeError( 'the body of the request must be a JSON object') query = body.get('query') query_bytes_len = len(query.encode('utf-8')) operation_name = body.get('operationName') variables = body.get('variables') deprecated_globals = body.get('globals') elif request.content_type == 'application/graphql': query_bytes_len = len(request.body) query = request.body.decode('utf-8') else: raise TypeError( 'unable to interpret GraphQL POST request') elif request.method == b'GET': if request.url.query: url_query = request.url.query.decode('ascii') qs = urllib.parse.parse_qs(url_query) query = qs.get('query') if query is not None: query = query[0] query_bytes_len = 
len(query.encode('utf-8')) operation_name = qs.get('operationName') if operation_name is not None: operation_name = operation_name[0] variables = qs.get('variables') if variables is not None: try: variables = json.loads(variables[0]) except Exception: raise TypeError( '"variables" must be a JSON object') deprecated_globals = qs.get('globals') if deprecated_globals is not None: try: deprecated_globals = json.loads(deprecated_globals[0]) except Exception: raise TypeError( '"globals" must be a JSON object') else: raise TypeError('expected a GET or a POST request') if not query: raise TypeError('invalid GraphQL request: query is missing') metrics.query_size.observe( query_bytes_len, tenant.get_instance_name(), 'graphql' ) if (operation_name is not None and not isinstance(operation_name, str)): raise TypeError('operationName must be a string') if variables is not None and not isinstance(variables, dict): raise TypeError('"variables" must be a JSON object') # There are 2 ways of sending globals: # 1) as 'globals' field (deprecated) # 2) as part of 'variables' in the '__globals__' element # # If both ways are present they must match. if variables is not None: globals = variables.get('__globals__') if variables is not None: config = variables.get('__config__') if config is not None and not isinstance(config, dict): raise TypeError('"__config__" must be a JSON object') if globals is not None and not isinstance(globals, dict): raise TypeError('"__globals__" must be a JSON object') if ( deprecated_globals is not None and not isinstance(deprecated_globals, dict) ): raise TypeError('"globals" must be a JSON object') # Globals are dicts if they are present, make sure they are the same. 
if ( globals is not None and deprecated_globals is not None and globals != deprecated_globals ): raise ValueError('invalid "__globals__" and "globals": ' 'values must match when both are present') globals = globals or deprecated_globals except Exception as ex: if debug.flags.server: markup.dump(ex) response.body = str(ex).encode() response.status = http.HTTPStatus.BAD_REQUEST response.close_connection = True return response.status = http.HTTPStatus.OK response.content_type = b'application/json' try: result = await _execute( db, role_name, tenant, query, operation_name, variables, globals, config, ) except Exception as ex: if debug.flags.server: markup.dump(ex) if isinstance(ex, gql_errors.GraphQLError): # XXX Fix this when LSP "location" objects are implemented ex_type = errors.QueryError else: ex = await execute.interpret_error( ex, db, from_graphql=True ) ex_type = type(ex) err_dct = { 'message': f'{ex_type.__name__}: {ex}', } if (isinstance(ex, errors.EdgeDBError) and hasattr(ex, 'line') and hasattr(ex, 'col')): err_dct['locations'] = [{'line': ex.line, 'column': ex.col}] response.body = json.dumps({'errors': [err_dct]}).encode() else: response.body = b'{"data":' + result + b'}' async def compile( dbview.DatabaseConnectionView dbv, tenant, query: str, tokens: Optional[List[Tuple[int, int, int, str]]], substitutions: Optional[Dict[str, Tuple[str, int, int]]], operation_name: Optional[str], variables: Dict[str, Any], ): db = dbv._db server = tenant.server compiler_pool = server.get_compiler_pool() started_at = time.monotonic() try: return await compiler_pool.compile_graphql( db.name, db.user_schema_pickle, tenant.get_global_schema_pickle(), db.reflection_cache, dbv.get_database_config(), dbv.get_compilation_system_config(), dbv.get_session_config(), query, tokens, substitutions, operation_name, variables, client_id=tenant.client_id, client_name=tenant.get_instance_name(), ) finally: metrics.query_compilation_duration.observe( time.monotonic() - started_at, 
tenant.get_instance_name(), "graphql", ) async def _execute( db, role_name, tenant, query, operation_name, variables, globals, config ): dbver = db.dbver query_cache = tenant.server._http_query_cache if variables: for var_name in variables: if var_name.startswith('__edb_arg_'): raise errors.QueryError( f"Variables starting with '__edb_arg_' are prohibited") query_cache_enabled = not debug.flags.disable_qcache if debug.flags.graphql_compile: debug.header('Input graphql') print(query) print(f'variables: {variables}') try: rewritten = _graphql_rewrite.rewrite(operation_name, query) except _graphql_rewrite.QueryError as e: raise errors.QueryError(e.args[0]) except Exception as e: if isinstance(e, _USER_ERRORS): logger.info("Error rewriting graphql query: %r", e) else: logger.warning("Error rewriting graphql query: %r", e) rewritten = None rewrite_error = e prepared_query = query vars = variables.copy() if variables else {} else: prepared_query = rewritten.key vars = rewritten.variables.copy() if variables: vars.update(variables) if debug.flags.graphql_compile: debug.header('GraphQL optimized query') print(rewritten) print(f'variables: {vars}') await db.introspection() dbv: dbview.DatabaseConnectionView = await tenant.new_dbview( dbname=db.name, query_cache=False, protocol_version=edbdef.CURRENT_PROTOCOL, role_name=role_name, ) dbv.is_transient = True dbv.decode_json_session_config(config) # Put the compilation-affecting session config into the cache key. # N.B: We skip putting system/database config in here, since dbver # gets bumped whenever those change. 
config_key = db.server.compilation_config_serializer.encode_configs( dbv.get_session_config() ) cache_key = ( 'graphql', prepared_query, (), operation_name, dbver, config_key ) use_prep_stmt = False entry: CacheEntry = None if query_cache_enabled: entry = query_cache.get(cache_key, None) if isinstance(entry, CacheRedirect): if debug.flags.graphql_compile: print("REDIRECT", entry.key_vars) key_vars2 = tuple(vars[k] for k in entry.key_vars) cache_key2 = ( prepared_query, key_vars2, operation_name, dbver, config_key ) entry = query_cache.get(cache_key2, None) if entry is None: if rewritten is not None: qug, gql_op = await compile( dbv, tenant, query, rewritten.tokens(gql_lexer.TokenKind), rewritten.substitutions, operation_name, vars, ) else: qug, gql_op = await compile( dbv, tenant, query, None, None, operation_name, vars, ) if gql_op.cache_deps_vars and gql_op.cache_deps_vars: key_var_set = set(gql_op.cache_deps_vars) key_var_names = sorted(key_var_set) redir = CacheRedirect(key_vars=key_var_names) query_cache[cache_key] = redir key_vars2 = tuple(vars[k] for k in key_var_names) cache_key2 = ( 'graphql', prepared_query, key_vars2, operation_name, dbver ) query_cache[cache_key2] = qug, gql_op else: query_cache[cache_key] = qug, gql_op metrics.graphql_query_compilations.inc( 1.0, tenant.get_instance_name(), 'compiler' ) else: qug, gql_op = entry # This is at least the second time this query is used # and it's safe to cache. 
use_prep_stmt = True metrics.graphql_query_compilations.inc( 1.0, tenant.get_instance_name(), 'cache' ) compiled = dbview.CompiledQuery(query_unit_group=qug) async with tenant.with_pgcon(db.name) as pgcon: try: return await execute.execute_json( pgcon, dbv, compiled, variables={**gql_op.variables_desc, **vars}, globals_=globals or {}, fe_conn=None, use_prep_stmt=use_prep_stmt, ) finally: tenant.remove_dbview(dbv) ================================================ FILE: edb/graphql/tokenizer.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import Any, Optional, Sequence import hashlib import struct from graphql.language import lexer as gql_lexer from edb import _graphql_rewrite as graphql_rewrite # type: ignore def deserialize( serialized: bytes, text: str, ) -> Source: match serialized[0]: case 0: text = serialized[1:].decode('utf-8') return Source(text, serialized) case 1: entry = graphql_rewrite.unpack(serialized) assert isinstance(entry, graphql_rewrite.Entry) return NormalizedSource(entry, text, serialized) raise ValueError(f"Invalid type/version byte: {serialized[0]}") class Source: def __init__( self, text: str, serialized: bytes, ) -> None: self._cache_key = hashlib.blake2b(serialized).digest() self._text = text self._serialized = serialized def text(self) -> str: return self._text def cache_key(self) -> bytes: return self._cache_key def variables(self) -> dict[str, Any]: return {} def substitutions(self) -> dict[str, Any]: return {} def tokens(self) -> Optional[list[Any]]: return None def first_extra(self) -> Optional[int]: return None def extra_counts(self) -> Sequence[int]: return () def extra_blobs(self) -> list[bytes]: return [] def extra_formatted_as_text(self) -> bool: return False def extra_type_oids(self) -> Sequence[int]: return () def serialize(self) -> bytes: return self._serialized @staticmethod def from_string(text: str, operation_name: Optional[str]=None) -> Source: return Source(text=text, serialized=b'\x00' + text.encode('utf-8')) def __repr__(self): return f'' class NormalizedSource(Source): def __init__( self, # TODO: type it? normalized: Any, text: str, serialized: bytes, ) -> None: self._text = text self._cache_key = normalized.key.encode('utf-8') # or hash? 
self._tokens = normalized.tokens(gql_lexer.TokenKind) self._variables = normalized.variables self._substitutions = normalized.substitutions self._first_extra = ( normalized.num_variables if normalized.substitutions else None ) self._extra_counts = (len(normalized.substitutions),) self._serialized = serialized def text(self) -> str: return self._text def cache_key(self) -> bytes: return self._cache_key def variables(self) -> dict[str, Any]: return self._variables def substitutions(self) -> dict[str, Any]: return self._substitutions def tokens(self) -> Optional[list[Any]]: return self._tokens def first_extra(self) -> Optional[int]: return self._first_extra def extra_counts(self) -> Sequence[int]: return self._extra_counts def extra_blobs(self) -> list[bytes]: out = b'' # Q: Or should we use `variables` instead and reencode it? # (We'd need to use a DecimalEncoder.) # I think the token encodings in substitutions are legit json? # N.B: This relies on the substitutions being in the right order. for v, _, _ in self._substitutions.values(): ev = b'\x01' + v.encode('utf-8') out += struct.pack('!I', len(ev)) out += ev return [out] @staticmethod def from_string(text: str, operation_name: Optional[str]=None) -> Source: rewritten = graphql_rewrite.rewrite(operation_name, text) return NormalizedSource(rewritten, text, rewritten.pack()) ================================================ FILE: edb/graphql/translator.py ================================================ # mypy: ignore-errors # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import contextlib import decimal import json import re from typing import ( Any, Optional, Mapping, NamedTuple, ) import graphql from graphql.language import ast as gql_ast from graphql.language import lexer as gql_lexer from graphql import error as gql_error from graphql import language as gql_lang from edb import errors from edb.common import debug from edb.common import typeutils from edb.common.ast import visitor from edb.edgeql import ast as qlast from edb.edgeql import codegen as ql_codegen from edb.edgeql import qltypes from edb.edgeql import quote as eql_quote from edb.schema import utils as s_utils from . import types as gt from . 
import errors as g_errors ARG_TYPES = { 'Int': gql_ast.IntValueNode, 'String': gql_ast.StringValueNode, } REWRITE_TYPE_ERROR = re.compile( r"Variable '\$(?P__edb_arg_\d+)' of type" r" '(?P\w+)!'" r" used in position expecting type '(?P[^']+)'" ) _STR_TYPES = frozenset(("ID", "ID!")) _INT_TYPES = frozenset(("Int64", "Int64!", "Bigint", "Decimal")) _INT64_TYPES = frozenset(("Bigint", "Decimal")) _IMPLICIT_CONVERSIONS = { # Used, Expected ("String", "ID"), ("String", "ID!"), ("Int", "Int64"), ("Int", "Int64!"), ("Int", "Bigint"), ("Int", "Bigint!"), ("Int", "Decimal"), ("Int", "Decimal!"), ("Int64", "Bigint"), ("Int64", "Bigint!"), ("Int64", "Decimal"), ("Int64", "Decimal!"), ("Decimal", "Float"), ("Decimal", "Float!"), } INT_FLOAT_ERROR = re.compile( r"Variable '\$[^']+' of type 'Int!?'" r" used in position expecting type 'Float!?'" ) class GraphQLTranslatorContext: def __init__( self, *, gqlcore: gt.GQLCoreSchema, variables, query, document_ast, operation_name, native_input, parse_only_mode, ): self.variables = variables self.fragments = {} self.validated_fragments = {} self.vars = {} self.fields = [] self.path = [] self.filter = None self.include_base = [] self.gqlcore = gqlcore self.query = query self.document_ast = document_ast self.operation_name = operation_name self.native_input = native_input self.parse_only_mode = parse_only_mode # only used inside ObjectFieldNode self.base_expr = None self.right_cast = None # auto-incrementing counter self._counter = 0 @property def counter(self): val = self._counter self._counter += 1 return val class Step(NamedTuple): name: Any type: Any eql_alias: str class Field(NamedTuple): name: Any value: Any class Var(NamedTuple): val: Any defn: gql_ast.VariableDefinitionNode critical: bool class Operation(NamedTuple): name: Any stmt: Any critvars: dict[str, Any] vars: dict[str, Any] class TranspiledOperation(NamedTuple): edgeql_ast: qlast.Base cache_deps_vars: Optional[frozenset[str]] variables_desc: dict class 
Ordering(NamedTuple): names: list[str] direction: qlast.SortOrder nulls: qlast.NonesOrder class BookkeepDict(dict): def __init__(self, values): self.update(values) self.touched = set() def __bool__(self): # HACK! And kind of wrong! # But this keeps this from getting replaced by code in graphql-core # that does "raw_variable_values or {}"... return True def __getitem__(self, key): self.touched.add(key) return super().__getitem__(key) def __contains__(self, key): self.touched.add(key) return super().__contains__(key) def keys(self): raise NotImplementedError() def values(self): raise NotImplementedError() def items(self): raise NotImplementedError() class GraphQLTranslator: def __init__(self, *, context=None): self._context = context def node_visit(self, node): for cls in node.__class__.__mro__: method = 'visit_' + cls.__name__ visitor = getattr(self, method, None) if visitor is not None: break if visitor is None: raise AssertionError(f"Unexpected node {node.__class__}") result = visitor(node) return result def visit(self, node): if typeutils.is_container(node): return [self.node_visit(n) for n in node] else: return self.node_visit(node) def get_loc(self, node): if node.loc: token = node.loc.start_token return token.line, token.column else: return None def get_type(self, name): # the type may be from the EdgeDB schema or some special # GraphQL type/adapter assert isinstance(name, str) return self._context.gqlcore.get(name) def is_list_type(self, node): return isinstance(node, gql_ast.ListTypeNode) or ( isinstance(node, gql_ast.NonNullTypeNode) and self.is_list_type(node.type) ) def get_field_type(self, base, name, *, args=None): return base.get_field_type(name) def get_optname(self, node): if node.name: return node.name.value else: return None def visit_DocumentNode(self, node): # we need to index all of the fragments before we process operations if node.definitions: self._context.fragments = { f.name.value: f for f in node.definitions if isinstance(f, 
gql_ast.FragmentDefinitionNode) } else: self._context.fragments = {} operation_name = self._context.operation_name if operation_name is None: opnames = [] for opnode in node.definitions: if not isinstance(opnode, gql_ast.OperationDefinitionNode): continue opname = None if opnode.name: opname = opnode.name.value opnames.append(opname) if len(opnames) > 1: raise errors.QueryError( 'must provide operation name if query contains ' 'multiple operations') operation_name = self._context.operation_name = opnames[0] if node.definitions: translated = {d.name: d for d in self.visit(node.definitions) if d is not None} else: translated = {} if operation_name not in translated: if operation_name: raise errors.QueryError( f'unknown operation named "{operation_name}"') operation = translated[operation_name] for el in operation.stmt.result.elements: # swap in the json bits if (isinstance(el.compexpr, qlast.FunctionCall) and el.compexpr.func == 'to_json'): # An introspection query; let graphql evaluate it for us. vars = BookkeepDict(self._context.variables) result = graphql.execute( self._context.gqlcore.graphql_schema, self._context.document_ast, operation_name=operation_name, variable_values=vars) for var_name in vars.touched: var = self._context.vars.get(var_name) # TODO: Why do we track this twice? self._context.vars[var_name] = var._replace(critical=True) operation.critvars[var_name] = ( self._context.vars[var_name].val ) if result.errors: err = result.errors[0] if ( self._context.parse_only_mode and all( 'was not provided' in str(e) for e in result.errors ) ): # Don't worry about it. # And explain why! 
assert vars.touched elif isinstance(err, graphql.GraphQLError): err_loc = (err.locations[0].line, err.locations[0].column) raise g_errors.GraphQLCoreError( err.message, loc=err_loc) else: raise err expr = qlast.FunctionCall( func='assert_exists', args=[ qlast.TypeCast( type=qlast.TypeName( maintype=qlast.ObjectRef(name='str') ), expr=qlast.Set(elements=[]), ), ], kwargs=dict(message=qlast.Constant.string( "SERVER BUG: error with graphql introspection" )), ) else: name = el.expr.steps[0].name expr = qlast.Constant.string(json.dumps(result.data[name])) el.compexpr.args[0] = expr return translated def visit_FragmentDefinitionNode(self, node): # fragments are already processed, no need to do anything here return None def visit_OperationDefinitionNode(self, node): # create a dict of variables that will be marked as # critical or not self._context.vars = { name: Var(val=val, defn=None, critical=False) for name, val in self._context.variables.items()} self._context.include_base.append(False) opname = None if node.name: opname = node.name.value if opname != self._context.operation_name: self._context.include_base.pop() return None if (node.operation is None or node.operation == graphql.OperationType.QUERY): stmt = self._visit_query(node) elif (node.operation is None or node.operation == graphql.OperationType.MUTATION): stmt = self._visit_mutation(node) else: raise ValueError(f'unsupported operation: {node.operation!r}') # produce the list of variables critical to the shape # of the query critvars = {name: var.val for name, var in self._context.vars.items() if var.critical} # variables that were defined in this operation defvars = {name: var.val for name, var in self._context.vars.items() if var.defn is not None} self._context.include_base.pop() return Operation( name=opname, stmt=stmt, critvars=critvars, vars=defvars, ) def _visit_query(self, node): # populate input variables with defaults, where applicable if node.variable_definitions: 
self.visit(node.variable_definitions) # base Query needs to be configured specially base = self._context.gqlcore.get('__graphql__::Query') # special treatment of the selection_set, different from inner # recursion query = qlast.SelectQuery( result=qlast.Shape( elements=[] ) ) self._context.fields.append({}) self._context.path.append([Step(None, base, None)]) query.result.elements = self.visit(node.selection_set) self._context.fields.pop() self._context.path.pop() return query def _visit_mutation(self, node): # populate input variables with defaults, where applicable if node.variable_definitions: self.visit(node.variable_definitions) # base Mutation needs to be configured specially base = self._context.gqlcore.get('__graphql__::Mutation') # special treatment of the selection_set, different from inner # recursion query = qlast.SelectQuery( result=qlast.Shape( elements=[] ) ) self._context.fields.append({}) self._context.path.append([Step(None, base, None)]) query.result.elements = self.visit(node.selection_set) self._context.fields.pop() self._context.path.pop() return query def _should_include(self, directives): # First mark *everything* as critical for directive in directives: if directive.name.value in ('include', 'skip'): cond = [a.value for a in directive.arguments if a.name.value == 'if'][0] if isinstance(cond, gql_ast.VariableNode): varname = cond.name.value var = self._context.vars[varname] self._context.vars[varname] = var._replace(critical=True) # In parse_only_mode we are done if self._context.parse_only_mode: return True # Otherwise actually evaluate it for directive in directives: if directive.name.value in ('include', 'skip'): cond = [a.value for a in directive.arguments if a.name.value == 'if'][0] if isinstance(cond, gql_ast.VariableNode): varname = cond.name.value var = self._context.vars[varname] value = var.val if value is None: raise g_errors.GraphQLValidationError( f"no value for the {varname!r} variable", loc=self.get_loc(directive.name)) elif 
isinstance(cond, gql_ast.BooleanValueNode): value = cond.value if not isinstance(value, bool): raise g_errors.GraphQLValidationError( f"'if' argument of {directive.name.value} " + "directive must be a Boolean", loc=self.get_loc(directive.name)) if directive.name.value == 'include' and not value: return False elif directive.name.value == 'skip' and value: return False return True def visit_VariableDefinitionNode(self, node): varname = node.variable.name.value variables = self._context.vars var = variables.get(varname) if not var: if node.default_value is None: variables[varname] = Var( val=None, defn=node, critical=False) else: val = convert_default(node.default_value, varname) variables[varname] = Var(val=val, defn=node, critical=False) # In HTTP mode, we rely on merging the dict of defaults with # the dict of actual arguments. # # FIXME: This is actually a little dodgy, since the gel proto # will require passing an explicit None, but in graphql that # should result in null being passed and the default not # used... But in our implementation, that fails, because # the rewriter (I think?) is marking it required? # # Alright, actually for now we are rejecting it. 
if self._context.native_input: raise errors.UnsupportedFeatureError( 'Default variables are not supported on the ' 'gel protocol' ) else: # we have the variable, but we still need to update the defn field variables[varname] = Var( val=var.val, defn=node, critical=var.critical) def visit_SelectionSetNode(self, node): elements = [] for sel in node.selections: spec = self.visit(sel) if not self._should_include(sel.directives): continue if spec is not None: elements.append(spec) elements = self.combine_field_results(elements) return elements def _is_duplicate_field(self, node): # if this field is a duplicate, that is not identical to the # original, throw an exception name = (node.alias or node.name).value dup = self._context.fields[-1].get(name) if dup: return True else: self._context.fields[-1][name] = node return False # XXX: this might need to be trimmed def _is_top_level_field(self, node, fail=None): top = False path = self._context.path[-1] # there is different handling of top-level, built-in and inner # fields top = (len(self._context.path) == 1 and len(path) == 1 and path[0].name is None) prevt = path[-1].type target = self.get_field_type( prevt, node.name.value) path.append(Step(name=node.name.value, type=target, eql_alias=None)) if not top and fail: raise g_errors.GraphQLValidationError( f"field {node.name.value!r} can only appear " f"at the top-level Query", loc=self.get_loc(node)) return top def _maybe_get_current_type(self): if self._context.path: path = self._context.path[-1] return path[-1].type else: return None def _get_parent_and_current_type(self): path = self._context.path[-1] cur = path[-1].type if len(path) > 1: par = path[-2].type else: par = self._context.path[-2][-1].type return par, cur def _prepare_field(self, node): path = self._context.path[-1] include_base = self._context.include_base[-1] is_top = self._is_top_level_field(node) spath = self._context.path[-1] prevt, target = self._get_parent_and_current_type() # insert normal or specialized 
link steps = [] if include_base: base = spath[0].type steps.append(qlast.TypeIntersection( type=qlast.TypeName( maintype=base.edb_base_name_ast ) )) steps.append(qlast.Ptr(name=node.name.value)) return is_top, path, prevt, target, steps def visit_FieldNode(self, node): if self._is_duplicate_field(node): return _is_top, _path, prevt, target, steps = \ self._prepare_field(node) json_mode = False is_shadowed = prevt.is_field_shadowed(node.name.value) # determine if there needs to be extra subqueries if not prevt.dummy and target.dummy: json_mode = True # this is a special introspection type eql, shape, filterable = target.get_template() spec = qlast.ShapeElement( expr=qlast.Path( steps=[qlast.Ptr( name=(node.alias or node.name).value, )] ), compexpr=eql, ) elif is_shadowed and not node.alias: # shadowed field that doesn't need an alias spec = filterable = shape = qlast.ShapeElement( expr=qlast.Path(steps=steps), ) elif not node.selection_set or is_shadowed and node.alias: # this is either an unshadowed terminal field or an aliased # shadowed field prefix = qlast.Path(steps=self.get_path_prefix(-1)) eql, shape, filterable = prevt.get_field_template( node.name.value, parent=prefix, has_shape=bool(node.selection_set) ) spec = qlast.ShapeElement( expr=qlast.Path( steps=[qlast.Ptr( # this is already a sub-query name=(node.alias or node.name).value )] ), compexpr=eql, # preserve the original cardinality of the computable # aliased fields cardinality=prevt.get_field_cardinality(node.name.value), ) else: # if the parent is NOT a shadowed type, we need an explicit SELECT eql, shape, filterable = target.get_template() spec = qlast.ShapeElement( expr=qlast.Path( steps=[qlast.Ptr( # this is already a sub-query name=(node.alias or node.name).value )] ), compexpr=eql, # preserve the original cardinality of the computable, # which is basically one of the top-level query # fields, all of which are returning lists cardinality=qltypes.SchemaCardinality.Many, ) 
self._context.include_base.append(False) # INSERT mutations have different arguments from queries if not is_shadowed and node.name.value.startswith('insert_'): # a single recursion target, so we can process # selection set now with self._update_path_for_eql_alias(): alias = self._context.path[-1][-1].eql_alias self._context.fields.append({}) shape.elements = self.visit(node.selection_set) insert_shapes = self._visit_insert_arguments(node.arguments) if not insert_shapes: # No insert arguments, meaning that a single object must # be inserted without any shape. insert_shapes = [None] self._context.fields.pop() filterable.aliases = [ qlast.AliasedExpr( alias=alias, expr=qlast.Set(elements=[ qlast.InsertQuery( subject=shape.expr, shape=sh, ) for sh in insert_shapes ]) ) ] filterable.result.expr = qlast.Path( steps=[qlast.ObjectRef(name=alias)]) elif node.selection_set is not None: delete_mode = (not is_shadowed and node.name.value.startswith('delete_')) update_mode = (not is_shadowed and node.name.value.startswith('update_')) if not json_mode: # a single recursion target, so we can process # selection set now with self._update_path_for_eql_alias( delete_mode or update_mode): # set up a unique alias for the deleted object alias = self._context.path[-1][-1].eql_alias self._context.fields.append({}) vals = self.visit(node.selection_set) self._context.fields.pop() if shape: shape.elements = vals if filterable: where, orderby, offset, limit = \ self._visit_query_arguments(node.arguments) filterable.where = where filterable.orderby = orderby filterable.offset = offset filterable.limit = limit if delete_mode: # this should be a DELETE operation, so we'll rearrange the # components of the SelectQuery filterable.aliases = [ qlast.AliasedExpr( alias=alias, expr=qlast.DeleteQuery( subject=filterable.result.expr, where=filterable.where, ) ) ] filterable.where = None filterable.result.expr = qlast.Path( steps=[qlast.ObjectRef(name=alias)]) elif update_mode: update_shape =
self._visit_update_arguments(node.arguments) # this should be an UPDATE operation, so we'll rearrange the # components of the SelectQuery and add data operations filterable.aliases = [ qlast.AliasedExpr( alias=alias, expr=qlast.UpdateQuery( subject=filterable.result.expr, where=filterable.where, shape=update_shape, ) ) ] filterable.where = None filterable.result.expr = qlast.Path( steps=[qlast.ObjectRef(name=alias)]) # Remove the processed path. self._context.path[-1].pop() if len(self._context.path[-1]) == 0: # If this was the last shape field, remove the now empty # shell for the shape paths. self._context.path.pop() self._context.include_base.pop() return spec def visit_InlineFragmentNode(self, node): self._validate_fragment_type(node, node) result = self.visit(node.selection_set) if node.type_condition is not None: self._context.path.pop() self._context.include_base.pop() return result def visit_FragmentSpreadNode(self, node): frag = self._context.fragments[node.name.value] self._validate_fragment_type(frag, node) # in case of secondary type, recurse into a copy to avoid # memoized results selection_set = frag.selection_set result = self.visit(selection_set) self._context.path.pop() if frag.type_condition is not None: self._context.include_base.pop() return result def _validate_fragment_type(self, frag, spread): is_specialized = False base_type = None # validate the fragment type w.r.t. 
the base if frag.type_condition is None: return # validate the base if it's nested if len(self._context.path) > 0: path = self._context.path[-1] base_type = path[-1].type frag_type = self.get_type(frag.type_condition.name.value) if base_type.issubclass(frag_type): # legal hierarchy, no change pass elif frag_type.issubclass(base_type): # specialized link, but still legal is_specialized = True else: raise g_errors.GraphQLValidationError( f"{base_type.short_name} and {frag_type.short_name} " + "are not related", loc=self.get_loc(frag)) self._context.path.append([ Step(name=frag.type_condition, type=frag_type, eql_alias=None)]) self._context.include_base.append(is_specialized) def _visit_query_arguments(self, arguments): where = None orderby = [] first = last = before = after = None for arg in arguments: if arg.name.value == 'filter': where = self.visit(arg.value) elif arg.name.value == 'order': orderby = self.visit_order(arg.value) elif arg.name.value == 'first': first = self._visit_pagination_arg( arg, 'Int', expected='an int') elif arg.name.value == 'last': last = self._visit_pagination_arg( arg, 'Int', expected='an int') elif arg.name.value == 'before': before = self._visit_pagination_arg( arg, 'String', expected='a string castable to an int') elif arg.name.value == 'after': after = self._visit_pagination_arg( arg, 'String', expected='a string castable to an int') # convert before, after, first and last into offset and limit offset, limit = self.get_offset_limit(after, before, first, last) # FIXME: it may be a good idea to create special scalar # (positive integer) so that the values used for offset and # limit can be cast into it and appropriate errors will be # produced. 
return where, orderby, offset, limit def _visit_pagination_arg(self, node, argtype, expected): if isinstance(node.value, gql_ast.VariableNode): # variables will be type-checked by this point, so assume # the type is valid return self.visit(node.value) elif not isinstance(node.value, ARG_TYPES[argtype]): raise g_errors.GraphQLValidationError( f"invalid value for {node.name.value!r}: " f"expected {expected}", loc=self.get_loc(node.value)) from None try: return int(node.value.value) except (TypeError, ValueError): raise g_errors.GraphQLValidationError( f"invalid value for {node.name.value!r}: " f"expected {expected}, " f"got {node.value.value!r}", loc=self.get_loc(node.value)) from None def get_offset_limit(self, after, before, first, last): # if all the parameters here are constants we can compute and # compile shorter and simpler OFFSET/LIMIT values if any(isinstance(x, qlast.Base) for x in [after, before, first, last] if x is not None): return self._get_general_offset_limit(after, before, first, last) else: return self._get_static_offset_limit(after, before, first, last) def _get_static_offset_limit(self, after, before, first, last): if after is not None: # The +1 is to make 'after' into an appropriate index. # # 0--a--1--b--2--c--3-- ... we call element at # index 0 (or "element 0" for short), the element # immediately after the mark 0. So after "element # 0" really means after "index 1". 
after += 1 offset = limit = None # convert before, after, first and last into offset and limit if after is not None: offset = after if before is not None: limit = before - (after or 0) if first is not None: if limit is None: limit = first else: limit = min(first, limit) if last is not None: if limit is not None: if last < limit: offset = (offset or 0) + limit - last limit = last else: # FIXME: there wasn't any limit, so we can define last # in terms of offset alone without negative OFFSET # implementation raise g_errors.GraphQLTranslationError( f'last={last} translates to a negative OFFSET in ' f'EdgeQL which is currently unsupported') # convert integers into qlast literals if offset is not None and not isinstance(offset, qlast.Base): offset = qlast.Constant.integer(max(0, offset)) if limit is not None: limit = qlast.Constant.integer(max(0, limit)) return offset, limit def _get_int64_slice_value(self, value): if value is None: return None if isinstance(value, qlast.Base): return qlast.TypeCast( type=qlast.TypeName( maintype=qlast.ObjectRef(name='int64')), expr=value ) else: return qlast.Constant.integer(value) def _get_general_offset_limit(self, after, before, first, last): # Convert any static values to corresponding qlast and # normalize them as int64. after = self._get_int64_slice_value(after) before = self._get_int64_slice_value(before) first = self._get_int64_slice_value(first) last = self._get_int64_slice_value(last) offset = limit = None # convert before, after, first and last into offset and limit if after is not None: # The +1 is to make 'after' into an appropriate index. # # 0--a--1--b--2--c--3-- ... we call element at # index 0 (or "element 0" for short), the element # immediately after the mark 0. So after "element # 0" really means after "index 1". 
offset = qlast.BinOp( left=after, op='+', right=qlast.Constant.integer('1') ) if before is not None: # limit = before - (after or 0) if after: limit = qlast.BinOp( left=before, op='-', right=offset, ) else: limit = before if first is not None: if limit is None: limit = first else: limit = qlast.IfElse( if_expr=first, condition=qlast.BinOp( left=first, op='<', right=limit ), else_expr=limit ) if last is not None: if limit is not None: if offset: offset = qlast.BinOp( left=offset, op='+', right=qlast.BinOp( left=limit, op='-', right=last ) ) else: offset = qlast.BinOp( left=limit, op='-', right=last ) limit = qlast.IfElse( if_expr=last, condition=qlast.BinOp( left=last, op='<', right=limit ), else_expr=limit ) else: # FIXME: there wasn't any limit, so we can define last # in terms of offset alone without negative OFFSET # implementation raise g_errors.GraphQLTranslationError( f'last translates to a negative OFFSET in ' f'EdgeQL which is currently unsupported') return offset, limit @contextlib.contextmanager def _update_path_for_eql_alias(self, alias_needed=True): if alias_needed: # we need to update the path of the delete field to keep track # of the delete alias alias = f'x{self._context.counter}' # just replace the last path element with the same # element, but aliased step = self._context.path[-1].pop() self._context.path.append([ Step(name=step.name, type=step.type, eql_alias=alias)]) yield # replace it back if alias_needed: self._context.path[-1].pop() self._context.path[-1].append(step) def _visit_update_arguments(self, arguments): result = [] for arg in arguments: if arg.name.value == 'data': # the node is an ObjectNode with the update spec for field in arg.value.fields: fname = field.name.value # capture the full path to the field being updated eqlpath = self.get_path_prefix() eqlpath.append(qlast.Ptr(name=fname)) eqlpath = qlast.Path(steps=eqlpath) # set-up the current path to point to the thing # being updated (so that SELECT can be applied if needed) with 
self._update_path_for_insert_field(field): _, target = self._get_parent_and_current_type() res = self._visit_update_op( field.value, eqlpath, target) if res is None: continue shapeop, value = res result.append( qlast.ShapeElement( expr=qlast.Path( steps=[qlast.Ptr(name=field.name.value)] ), operation=qlast.ShapeOperation(op=shapeop), compexpr=value, ) ) return result def _visit_update_op(self, node, eqlpath, ftype): # The node is an ObjectNode with the update spec. The fields represent # different operations that can be performed. Although the spec lists # multiple options, exactly one of the options should be present. if not node.fields: raise g_errors.GraphQLValidationError( "No update operation was specified.", loc=self.get_loc(node)) if len(node.fields) > 1: raise g_errors.GraphQLValidationError( "Too many update operations were specified.", loc=self.get_loc(node)) field = node.fields[0] fname = field.name.value # by default we expect an assign shapeop = qlast.ShapeOp.ASSIGN ptrname = eqlpath.steps[-1].name # NOTE: there will be more operations in the future if fname == 'set': value = self._get_input_expr_for_pointer_mutation(field, ptrname) return shapeop, value elif fname == 'clear': cond = field.value if isinstance(cond, gql_ast.VariableNode): var_name = cond.name.value var = self._context.vars[var_name] if not var.critical: self._context.vars[var_name] = \ var._replace(critical=True) value = var.val elif isinstance(cond, gql_ast.BooleanValueNode): value = cond.value elif isinstance(cond, gql_ast.NullValueNode): value = None else: # We assume that schema was validated, # so variable is of correct type raise AssertionError(f"Unexpected node {cond!r}") if value: # empty set to clear the value return shapeop, qlast.Set(elements=[]) elif fname == 'increment': value = qlast.BinOp( left=eqlpath, op='+', right=self._visit_insert_value(field.value) ) return shapeop, value elif fname == 'decrement': value = qlast.BinOp( left=eqlpath, op='-',
right=self._visit_insert_value(field.value) ) return shapeop, value elif fname == 'prepend': value = qlast.BinOp( left=self._visit_insert_value(field.value), op='++', right=eqlpath ) return shapeop, value elif fname == 'append': value = qlast.BinOp( left=eqlpath, op='++', right=self._visit_insert_value(field.value) ) return shapeop, value elif fname == 'slice': args = field.value.values num_args = len(args) if num_args == 1: start = self.visit(args[0]) stop = None elif num_args == 2: start = self.visit(args[0]) stop = self.visit(args[1]) else: raise g_errors.GraphQLTranslationError( f'"slice" must be a list of 1 or 2 integers') value = qlast.Indirection( arg=eqlpath, indirection=[qlast.Slice( start=start, stop=stop )] ) return shapeop, value elif fname == 'add': # This is a set, so no reason to validate cardinality. value = self._get_input_expr_for_pointer_mutation( field, ptrname, validate_cardinality=False) shapeop = qlast.ShapeOp.APPEND return shapeop, value elif fname == 'remove': # This is a set, so no reason to validate cardinality. 
value = self._get_input_expr_for_pointer_mutation( field, ptrname, validate_cardinality=False) shapeop = qlast.ShapeOp.SUBTRACT return shapeop, value def _visit_insert_arguments(self, arguments): input_data = [] for arg in arguments: if arg.name.value == 'data': # normalize the value to a list if isinstance(arg.value, gql_ast.ListValueNode): input_data = arg.value.values else: input_data = [arg.value] return [self._get_shape_from_input_data(node) for node in input_data] def _get_shape_from_input_data(self, node): # the node is an ObjectNode with the input spec result = [] for field in node.fields: # set-up the current path to point to the thing being inserted with self._update_path_for_insert_field(field): compexpr = self._get_input_expr_for_pointer_mutation( field, field.name.value) result.append( qlast.ShapeElement( expr=qlast.Path( steps=[qlast.Ptr(name=field.name.value)] ), compexpr=compexpr, ) ) return result def _get_input_expr_for_pointer_mutation( self, field, fname, validate_cardinality=True, ): compexpr = self._visit_insert_value(field.value) # get the type of the value being inserted ptype, target = self._get_parent_and_current_type() # Object types in mutations potentially need some extra assertions # to validate them. if target.is_object_type: if validate_cardinality: card = ptype.get_field_cardinality(fname) if card is qltypes.SchemaCardinality.Many: # Need to wrap the set into an "assert_distinct()". msg = f'objects provided for {fname!r} are not distinct' compexpr = qlast.FunctionCall( func='assert_distinct', args=[compexpr], kwargs={ 'message': qlast.Constant.string(msg) } ) else: # Singleton object values need to be verified. msg = f'more than one object provided for {fname!r}' compexpr = qlast.FunctionCall( func='assert_single', args=[compexpr], kwargs={ 'message': qlast.Constant.string(msg) } ) # Object types need to be wrapped in a DETACHED in # mutations to avoid referencing the root object. 
compexpr = qlast.DetachedExpr(expr=compexpr) return compexpr @contextlib.contextmanager def _update_path_for_insert_field(self, node): # we need to update the path of the insert field to keep track # of the insert types path = self._context.path[-1] prevt = path[-1].type target = self.get_field_type( prevt, node.name.value) self._context.path.append([ Step(name=None, type=target, eql_alias=None)]) yield self._context.path.pop() def _visit_range_spec(self, node, target): assert isinstance(node, gql_ast.ObjectValueNode) assert target.is_range or target.is_multirange # This is a range spec subtype = target.edb_base.get_subtypes(target.edb_schema)[0] st_name = subtype.get_name(target.edb_schema) kwargs = { rf.name.value: self.visit(rf.value) for rf in node.fields if not isinstance(rf.value, gql_ast.NullValueNode) } # move some kwargs into args args = [ qlast.TypeCast( expr=kwargs.pop('lower', qlast.Set(elements=[])), type=qlast.TypeName( maintype=qlast.ObjectRef(name=str(st_name)), ), ), qlast.TypeCast( expr=kwargs.pop('upper', qlast.Set(elements=[])), type=qlast.TypeName( maintype=qlast.ObjectRef(name=str(st_name)), ), ), ] return qlast.FunctionCall( func='range', args=args, kwargs=kwargs, ) def _visit_insert_value(self, node): # get the type of the value being inserted _, target = self._get_parent_and_current_type() if isinstance(node, gql_ast.ObjectValueNode): if target.is_range or target.is_multirange: # This is a range spec return self._visit_range_spec(node, target) # get a template AST eql, shape, filterable = target.get_template() if node.fields[0].name.value == 'data': # this may be a new object spec data_node = node.fields[0].value return qlast.InsertQuery( subject=shape.expr, shape=self._get_shape_from_input_data(data_node), ) else: eql.result = shape.expr # this is a filter spec where, orderby, offset, limit = \ self._visit_query_arguments(node.fields) filterable.where = where filterable.orderby = orderby filterable.offset = offset filterable.limit = limit 
return eql elif isinstance(node, gql_ast.ListValueNode) and target.is_multirange: # Multiranges are composed of a list of ranges. So we just need to # wrap the literal array into a multirange function call. return qlast.FunctionCall( func='multirange', args=[ qlast.Array( elements=[ self._visit_insert_value(el) for el in node.values ] ), ], ) elif isinstance(node, gql_ast.ListValueNode) and not target.is_array: # not an actual array or multirange, but a set represented as a # list return qlast.Set(elements=[ self._visit_insert_value(el) for el in node.values]) else: # some scalar value val = self.visit(node) if target.is_json: # JSON can only come as a variable and will already be # converted appropriately. return val elif target.edb_base_name != 'std::str': # bigint data would require a bigint input, so # check if the expression is using a parameter if (target.edb_base_name == 'std::bigint' and isinstance(node, gql_ast.VariableNode) and val.type.maintype.name == 'int64'): res = val res.type.maintype.name = target.edb_base_name else: res = qlast.TypeCast( expr=val, type=qlast.TypeName( maintype=target.edb_base_name_ast ) ) if target.is_array: res = qlast.TypeCast( expr=val, type=qlast.TypeName( maintype=qlast.ObjectRef(name='array'), subtypes=[res.type], ) ) elif target.is_range: # Range inputs come in two varieties: as a variable or as # a literal. Variables are already in JSON format and only # need to be cast into the appropriate range. Literals are # processed earlier as ObjectValueNode. res = qlast.TypeCast( expr=val, type=qlast.TypeName( maintype=qlast.ObjectRef(name='range'), subtypes=[res.type], ) ) elif target.is_multirange: # Multiranges are composed of a list of ranges. List # literal is processed earlier, so we just need to cast # JSON into an array of ranges if it came from a # variable.
rtype = qlast.TypeName( maintype=qlast.ObjectRef(name='range'), subtypes=[res.type], ) res = qlast.FunctionCall( func='multirange', args=[ qlast.TypeCast( expr=val, type=qlast.TypeName( maintype=qlast.ObjectRef(name='array'), subtypes=[rtype], ) ) ], ) return res else: return val def get_path_prefix(self, end_trim=None): # flatten the path path = [step for psteps in self._context.path for step in psteps] # find the first shadowed root prev_step = None base_step = None partial = False base_i = 0 for i, step in enumerate(path): cur = step.type # if the field is specifically shadowed, then this is # appropriate shadow base if base_step is None and not partial: if (prev_step is not None and prev_step.type.is_field_shadowed(step.name)): base_step = prev_step base_i = i break # otherwise the base must be shadowing an entire type elif isinstance(cur, gt.GQLShadowType): base_step = step base_i = i # we have a base, but we might find out that we need to # override it with a partial path elif step.name is None and isinstance(cur, gt.GQLShadowType): partial = True base_step = None base_i = i # this is where the actual partial path steps start elif partial and step.name is not None: break prev_step = step else: # we got to the end of the list without hitting other # conditions, so that's the base if base_step is None: base_step = step base_i = i # trim the rest of the path path = path[base_i + 1:end_trim] if base_step is None: # if the base_step is of the form (None, GQLShadowType), then # we don't want any prefix, because we'll use partial paths prefix = [] elif base_step.eql_alias: # the root may be aliased prefix = [qlast.ObjectRef(name=base_step.eql_alias)] else: prefix = [base_step.type.edb_base_name_ast] for step in path: if isinstance(step.name, gql_ast.NamedTypeNode): # This is coming from a fragment, so we need to add a # type intersection. 
base = step.type prefix.append( qlast.TypeIntersection( type=qlast.TypeName( maintype=base.edb_base_name_ast ) ) ) else: prefix.append(qlast.Ptr(name=step.name)) return prefix def visit_ListValueNode(self, node): return qlast.Array(elements=self.visit(node.values)) def visit_ObjectValueNode(self, node): # This represents some expression to be used in filter. In # case of multiple expressions they are implicitly combined # using AND. return self._visit_list_generalized_bool_op(node.fields, 'AND') def visit_ObjectFieldNode(self, node): fname = node.name.value # handle boolean ops if fname == 'and': # Conform to Postgres AND, which treats False AND NULL = False. return self._visit_list_of_inputs(node.value, 'AND') elif fname == 'or': # Conform to Postgres OR, which treats True OR NULL = True return self._visit_list_of_inputs(node.value, 'OR') elif fname == 'not': return qlast.UnaryOp(op='NOT', operand=self.visit(node.value)) # handle various scalar ops op = gt.GQL_TO_OPS_MAP.get(fname) if op: value = self.visit(node.value) left = self._context.base_expr # 'exists' filter gets converted to: # EXISTS (expr) = value # where the value is either true or false. This is so # that there's a one-to-one correspondence between the # potential input variables and the EdgeQL variables. # # If different EdgeQL code were generated instead, then # the assumption that it's safe to re-run the same EdgeQL # query with different input variables would not hold. if op == 'EXISTS': left = qlast.UnaryOp(op='EXISTS', operand=left) # The binary operator that we need here is "=" op = '=' elif op == 'IN': # Instead of wrapping the values in an array, wrap # them in a set value = qlast.FunctionCall( func='array_unpack', args=[value], ) elif self._context.right_cast is not None: # We don't need to cast the RHS for the EXISTS, only # for other operations.
value = qlast.TypeCast( expr=value, type=self._context.right_cast, ) return qlast.BinOp( left=left, op=op, right=value) # we're at the beginning of a scalar op _, target = self._get_parent_and_current_type() name = self.get_path_prefix() name.append(qlast.Ptr(name=fname)) name = qlast.Path( steps=name, # paths that start with a Ptr are partial partial=isinstance(name[0], qlast.Ptr), ) ftype = target.get_field_type(fname) typename = ftype.edb_base_name if typename not in {'std::str', 'std::uuid'}: gql_type = gt.EDB_TO_GQL_SCALARS_MAP.get(typename) if gql_type == graphql.GraphQLString: # potentially need to cast the 'name' side into a # str, so as to be compatible with the 'value' name = qlast.TypeCast( expr=name, type=qlast.TypeName(maintype=qlast.ObjectRef(name='str')), ) # ### Set up context for the nested visitor ### self._context.base_expr = name # potentially the right-hand-side needs to be cast into a float if ftype.is_float: self._context.right_cast = qlast.TypeName( maintype=ftype.edb_base_name_ast) elif typename == 'std::uuid': self._context.right_cast = qlast.TypeName( maintype=qlast.ObjectRef(name='uuid')) path = self._context.path[-1] path.append(Step(name=fname, type=ftype, eql_alias=None)) try: value = self.visit(node.value) finally: path.pop() self._context.right_cast = None self._context.base_expr = None # we need to cast a target string into uuid or enum if (typename == 'std::uuid' and not ( # EXISTS side does not need a cast isinstance(value.left, qlast.UnaryOp) and value.left.op == 'EXISTS' ) and not isinstance(value.right, qlast.TypeCast)): value.right = qlast.TypeCast( expr=value.right, type=qlast.TypeName(maintype=ftype.edb_base_name_ast), ) elif ftype.is_enum: value.right = qlast.TypeCast( expr=value.right, type=qlast.TypeName(maintype=ftype.edb_base_name_ast), ) return value def visit_order(self, node): if not isinstance(node, gql_ast.ObjectValueNode): raise g_errors.GraphQLTranslationError( f'an object is expected for "order"') # if there is no
specific ordering, then order by id if not node.fields: return [qlast.SortExpr( path=qlast.Path( steps=[qlast.Ptr(name='id')], partial=True, ), direction=qlast.SortAsc, )] # Ordering is handled by specifying a list of special Ordering objects. # Validation is already handled by this point. orderby = [] for ordering in self._visit_order_item(node): orderby.append(qlast.SortExpr( path=qlast.Path( steps=[ qlast.Ptr(name=name) for name in ordering.names ], partial=True, ), direction=ordering.direction, nones_order=ordering.nulls, )) return orderby def _visit_order_item(self, node): if not isinstance(node, gql_ast.ObjectValueNode): raise g_errors.GraphQLTranslationError( f'an object is expected for "order"') orderings = [] direction = nulls = None for part in node.fields: # Check if there's a longer nested path here. If there is, # validate that there's only one option chosen at this # level. if isinstance(part.value, gql_ast.ObjectValueNode): for subordering in self._visit_order_item(part.value): orderings.append( Ordering( names=[part.name.value] + subordering.names, direction=subordering.direction, nulls=subordering.nulls ) ) elif part.name.value == 'dir': direction = part.value.value elif part.name.value == 'nulls': nulls = part.value.value if orderings: # We have compiled some ordering paths, so we don't have # any direction or nulls on this level. 
return orderings # direction is a required field, so we can rely on it having # one of two values if direction == 'ASC': direction = qlast.SortAsc # nulls are optional, but are 'SMALLEST' by default if nulls == 'BIGGEST': nulls = qlast.NonesLast else: nulls = qlast.NonesFirst else: # DESC direction = qlast.SortDesc # nulls are optional, but are 'SMALLEST' by default if nulls == 'BIGGEST': nulls = qlast.NonesFirst else: nulls = qlast.NonesLast return [Ordering(names=[], direction=direction, nulls=nulls)] def visit_VariableNode(self, node): return self._get_variable(node.name.value) def _get_variable(self, varname): var = self._context.vars[varname] err_msg = (f"Only scalar input variables are allowed. " f"Variable {varname!r} has non-scalar value.") vartype = var.defn.type optional = True # get the type of the value being inserted target = self._maybe_get_current_type() if isinstance(vartype, gql_ast.NonNullTypeNode): vartype = vartype.type optional = False if self.is_list_type(vartype): if target and target.is_multirange: castname = qlast.ObjectRef(name='json') else: # So far the only list allowed is a multirange # representation. raise errors.QueryError(err_msg) elif vartype.name.value in gt.GQL_TO_EDB_SCALARS_MAP: castname = qlast.ObjectRef( name=gt.GQL_TO_EDB_SCALARS_MAP[vartype.name.value]) elif ( name := gt.GQL_TO_EDB_RANGES_MAP.get(vartype.name.value) ) is not None: castname = qlast.ObjectRef(name=name) else: try: vtype = self.get_type( self._context.gqlcore.gql_to_edb_name(vartype.name.value)) except AssertionError: raise errors.QueryError(err_msg) if vtype.is_enum: castname = vtype.edb_base_name_ast else: raise errors.QueryError(err_msg) casttype = qlast.TypeName(maintype=castname) casts = [casttype] # Currently, when using the native protocol we pass in # extracted arguments as JSON instead of native encodings.
# We probably should be able to do better, since we do this right # on the edgeql extraction side, but I didn't want to bother # with integrating the extractors to share the code. if self._context.native_input and varname.startswith('__edb_arg_'): casts.append( qlast.TypeName(maintype=qlast.ObjectRef(name='json')) ) val = qlast.QueryParameter(name=varname) for ct in reversed(casts): val = qlast.TypeCast( type=ct, expr=val, cardinality_mod=( qlast.CardinalityModifier.Optional if optional else None ), ) return val def visit_StringValueNode(self, node): return qlast.Constant.string(node.value) def visit_IntValueNode(self, node): # produces an int64 or bigint val = int(node.value) if s_utils.MIN_INT64 <= val <= s_utils.MAX_INT64: return qlast.Constant.integer(val) else: return qlast.Constant( value=f'{val}n', kind=qlast.ConstantKind.BIGINT ) def visit_FloatValueNode(self, node): # Treat all Float as Decimal by default and downcast as necessary return qlast.Constant( value=f'{node.value}n', kind=qlast.ConstantKind.DECIMAL ) def visit_BooleanValueNode(self, node): value = 'true' if node.value else 'false' return qlast.Constant.boolean(value) def visit_EnumValueNode(self, node): return qlast.Constant.string(node.value) def _visit_list_of_inputs(self, inputlist, op): if not isinstance(inputlist, gql_ast.ListValueNode): raise g_errors.GraphQLTranslationError( f'a list was expected') return self._visit_list_generalized_bool_op( [node for node in inputlist.values], op) def _visit_list_generalized_bool_op(self, nodes, op): # Generalization of a boolean operation AND or OR as it is # applied to a list of expressions. This comes up in filters # either explicitly by using 'and' or 'or' or by supplying a # list of expressions where 'and' is implied. 
# # In this limited context it is appropriate to use Postgres' # truth table for AND and OR, short-circuiting "False AND # anything" or "True OR anything" respectively to "False" and # "True" instead of the stricter EdgeQL rules that would # produce empty sets if any of the inputs are empty. if not nodes: return None elif len(nodes) == 1: return self.visit(nodes[0]) # The short-circuiting value is True for OR and False for AND. opname = ('sys', f'__pg_{op.lower()}') exprs = [self.visit(node) for node in nodes] result = qlast.FunctionCall( func=opname, args=exprs[0:2], ) for expr in exprs[2:]: result = qlast.FunctionCall( func=opname, args=[result, expr], ) return result def combine_field_results(self, results, *, flatten=True): if flatten: flattened = [] for res in results: if isinstance(res, Field): flattened.append(res) elif isinstance(res, dict): flattened.extend(res.values()) elif typeutils.is_container(res): flattened.extend(res) else: flattened.append(res) return flattened else: return results def value_node_from_pyvalue(val: Any): if val is None: return None elif isinstance(val, str): val = val.replace('\\', '\\\\') value = eql_quote.quote_literal(val) return gql_ast.StringValueNode(value=value[1:-1]) elif isinstance(val, bool): return gql_ast.BooleanValueNode(value=bool(val)) elif isinstance(val, int): return gql_ast.IntValueNode(value=str(val)) elif isinstance(val, (float, decimal.Decimal)): return gql_ast.FloatValueNode(value=str(val)) elif isinstance(val, list): return gql_ast.ListValueNode( values=[value_node_from_pyvalue(v) for v in val]) elif isinstance(val, dict): return gql_ast.ObjectValueNode( fields=[ gql_ast.ObjectFieldNode( name=n, value=value_node_from_pyvalue(v) ) for n, v in val.items() ]) else: raise ValueError(f'unexpected constant type: {type(val)!r}') def parse_text(query: str) -> graphql.Document: try: return graphql.parse(query) except graphql.GraphQLError as err: err_loc = (err.locations[0].line, err.locations[0].column) raise
g_errors.GraphQLCoreError(err.message, loc=err_loc) from None class TokenLexer(graphql.language.lexer.Lexer): def __init__(self, source, tokens, eof_pos): self.__tokens = tokens self.__index = 0 self.__eof_pos = eof_pos self.source = source kind, start, end, line, col, body = self.__tokens[0] self.token = gql_lexer.Token(kind, start, end, line, col, None, body) def advance(self) -> gql_lexer.Token: self.last_token = self.token token = self.token = self.lookahead() self.__index += 1 return token def lookahead(self) -> gql_lexer.Token: token = self.token if token.kind != gql_lexer.TokenKind.EOF: if token.next: return self.token.next kind, start, end, line, col, body = self.__tokens[self.__index + 1] token.next = gql_lexer.Token( kind, start, end, line, col, token, body) return token.next else: return token def parse_tokens( text: str, tokens: list[tuple[gql_lexer.TokenKind, int, int, int, int, str]] ) -> graphql.Document: try: src = graphql.Source(text) parser = graphql.language.parser.Parser(src) parser._lexer = TokenLexer(src, tokens, len(text)) return parser.parse_document() except graphql.GraphQLError as err: err_loc = (err.locations[0].line, err.locations[0].column) raise g_errors.GraphQLCoreError(err.message, loc=err_loc) from None def convert_errors( errs: list[gql_error.GraphQLError], *, substitutions: Optional[dict[str, tuple[str, int, int]]], ) -> list[gql_error.GraphQLError]: result = [] for err in errs: m = REWRITE_TYPE_ERROR.match(err.message) if not m: # we allow conversion from Int to Float, and that is allowed by # graphql spec.
It's unclear why graphql-core chokes on this if INT_FLOAT_ERROR.match(err.message): continue result.append(err) continue elif (m.group("used"), m.group("expected")) in _IMPLICIT_CONVERSIONS: # skip the error, we avoid it in the execution code continue value, line, col = substitutions[m.group("var_name")] err = gql_error.GraphQLError( f"Expected type {m.group('expected')}, found {value}.") err.locations = [gql_lang.SourceLocation(line, col)] result.append(err) return result def translate_ast( gqlcore: gt.GQLCoreSchema, document_ast: graphql.Document, *, operation_name: Optional[str]=None, variables: Optional[Mapping[str, Any]]=None, substitutions: Optional[dict[str, tuple[str, int, int]]], extracted_variables: Optional[Mapping[str, Any]], native_input: bool = False, ) -> TranspiledOperation: # If no variables have been provided, and we are in native # protocol mode, that means we need to be tolerant of critical # variables not having values. We'll report out which ones # existed, and then recompile with the variables present. parse_only_mode = native_input and variables is None if variables is None: variables = {} # The normalizer tries to handle default variables on its own, # which we still don't support in native mode. # Detect what it does and reject it. 
if ( native_input and len(extracted_variables or ()) > len(substitutions or ()) ): raise errors.UnsupportedFeatureError( 'Default variables are not supported on the gel protocol' ) validation_errors = convert_errors( graphql.validate(gqlcore.graphql_schema, document_ast), substitutions=substitutions) if validation_errors: err = validation_errors[0] if isinstance(err, graphql.GraphQLError): # possibly add additional information and/or hints to the # error message msg = augment_error_message(gqlcore, err.message) err_loc = (err.locations[0].line, err.locations[0].column) raise g_errors.GraphQLCoreError(msg, loc=err_loc) else: raise err context = GraphQLTranslatorContext( gqlcore=gqlcore, query=None, variables=variables, document_ast=document_ast, operation_name=operation_name, native_input=native_input, parse_only_mode=parse_only_mode, ) translator = GraphQLTranslator(context=context) edge_forest_map = translator.visit(document_ast) if debug.flags.graphql_compile: for opname, op in sorted(edge_forest_map.items()): print(f'== operationName: {opname!r} =============') print(ql_codegen.generate_source(op.stmt)) op = next(iter(edge_forest_map.values())) if native_input: used_vars = { p.name for p in visitor.find_children(op.stmt, qlast.QueryParameter) } unused_vars = op.vars.keys() - used_vars if unused_vars: op.stmt.orderby = [ qlast.Tuple( elements=[ translator._get_variable(vn) for vn in sorted(unused_vars) ] ) ] # generate the specific result return TranspiledOperation( edgeql_ast=op.stmt, cache_deps_vars=frozenset(op.critvars) if op.critvars else None, variables_desc=op.vars, ) def augment_error_message(gqlcore: gt.GQLCoreSchema, message: str): # If the error is about wrong Query field, we can add more details # about what seems to have gone wrong. The type is missing, # possibly because this connection is to the wrong DB. However, # this is only relevant if the message doesn't contain a hint already. 
if (re.match(r"^Cannot query field '(.+?)' on type 'Query'\.$", message)): field = message.split("'", 2)[1] name = gqlcore.gql_to_edb_name(field) message += ( f' There\'s no corresponding type or alias "{name}" exposed in ' 'Gel. Please check the configuration settings for this port ' 'to make sure that you\'re connecting to the right database.' ) return message def convert_default( node: gql_ast.ValueNode, varname: str ) -> str | float | int | bool: if isinstance(node, (gql_ast.StringValueNode, gql_ast.BooleanValueNode, gql_ast.EnumValueNode)): return node.value elif isinstance(node, gql_ast.IntValueNode): return int(node.value) elif isinstance(node, gql_ast.FloatValueNode): return float(node.value) else: raise errors.QueryError( f"Only scalar defaults are allowed. " f"Variable {varname!r} has non-scalar default value.") ================================================ FILE: edb/graphql/types.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2018-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import ( Any, ClassVar, Optional, cast, ) from functools import partial from graphql import ( GraphQLAbstractType, GraphQLSchema, GraphQLInputType, GraphQLNamedType, GraphQLOutputType, GraphQLObjectType, GraphQLWrappingType, GraphQLInterfaceType, GraphQLInputObjectType, GraphQLResolveInfo, GraphQLField, GraphQLInputField, GraphQLArgument, GraphQLList, GraphQLNonNull, GraphQLString, GraphQLInt, GraphQLFloat, GraphQLBoolean, GraphQLID, GraphQLEnumType, ) from graphql.type import GraphQLEnumValue, GraphQLScalarType from graphql.language import ast as gql_ast import itertools from edb.edgeql import ast as qlast from edb.edgeql import qltypes from edb.edgeql import codegen from edb.edgeql.parser import parse_fragment from edb.schema import modules as s_mod from edb.schema import name as s_name from edb.schema import pointers as s_pointers from edb.schema import objtypes as s_objtypes from edb.schema import scalars as s_scalars from edb.schema import schema as s_schema from edb.schema import types as s_types from edb.schema import utils as s_utils from . import errors as g_errors ''' This module is responsible for mapping Gel types onto the GraphQL types. However, this is an imperfect mapping because not all the types or relationships between them can be expressed. # Aliased Types Aliased types present a particular problem. Basically, they break inheritance and GraphQL fragments become useless for them. Consider a link friends in this alias: ``` type User {... multi link friends -> User} type SpecialUser extending User {property special-> str} alias UserAlias := User {friends: {some_new_prop := 'foo'}} ``` In GraphQL we will have type UserType implementing interfaces User, Object and SpecialUserType implementing SpecialUser, User, Object. The trouble starts with our implicit aliased type for the aliased friends link that targets __UserAlias__friends. 
This type gets reflected into a GraphQL _edb__UserAlias__friends that implements... what? We have 2 options: 1) Implement interfaces mirroring the Gel types: User, Object. 2) Implement its own interface (or just omit interfaces here, since if the interface is unique it's not adding anything). Case 2) keeps all the fields defined in the alias accessible and filterable, etc., but loses inheritance information. Case 1) leads to the following additional choices: a) The friends target is still interface User, then the field some_new_prop will not appear in the nesting, but will require a specialized fragment: query { UserAlias { friends { ... on _edb__UserAlias__friends { some_new_prop } } } } b) We can make the target of friends of UserAlias to be the actual type _edb__UserAlias__friends (similar to case 2), but then we cannot use the typed fragment `... on SpecialUser` construct inside it because the SpecialUser is a sibling of our aliased type and will cause a GraphQL validation error. c) The field target can be a union type, but it will still only have fields that are common to all union members and will require awkward inlined typed fragments to work with just like the first bullet point. In the end I think that rather than preserving the inheritance and then essentially forcing the use of `... on _edb__UserAlias__friends` just to access the very field for which the alias was created in the first place it's better to bite the bullet and accept that in GraphQL the reflection of aliased types removes all inheritance info, but at least provides all the fields as per normal. The reasoning is that the fields and data are probably much more important for practical purposes than inheritance purity.
To allow the SpecialUser polymorphism I'd rather suggest to the user to bake it into the alias like so: ``` alias UserAlias := User { friends: { some_new_prop := 'foo', [IS SpecialUser].special, } } ``` Also, note that for the same reasons as outlined above that make aliased types into a sibling branch in GraphQL, we can't give an accurate __typename for them, like we can in EdgeQL (__type__.name) because the "correct" types would violate the declared GraphQL type hierarchy. So aliased types are necessarily opaque in GraphQL in all these ways. Unlike in EdgeQL. ''' def coerce_int64(value: Any) -> int: if isinstance(value, int): num = value else: num = int(value) if s_utils.MIN_INT64 <= num <= s_utils.MAX_INT64: return num raise Exception( f"Int64 cannot represent non 64-bit signed integer value: {value}") def coerce_bigint(value: Any) -> int: if isinstance(value, int): num = value else: num = int(value) return num def parse_int_literal( ast: gql_ast.Node, _variables: Optional[dict[str, Any]] = None, ) -> Optional[int]: if isinstance(ast, gql_ast.IntValueNode): return int(ast.value) else: return None GraphQLInt64 = GraphQLScalarType( name="Int64", description="The `Int64` scalar type represents non-fractional signed " "whole numeric values. 
Int64 can represent values between " "-2^63 and 2^63 - 1.", serialize=coerce_int64, parse_value=coerce_int64, parse_literal=parse_int_literal, ) GraphQLBigint = GraphQLScalarType( name="Bigint", description="The `Bigint` scalar type represents non-fractional signed " "whole numeric values.", serialize=coerce_bigint, parse_value=coerce_bigint, parse_literal=parse_int_literal, ) GraphQLJSON = GraphQLScalarType( name="JSON", description="The `JSON` scalar type represents arbitrary JSON values.", ) def parse_decimal_literal( ast: gql_ast.Node, _variables: Optional[dict[str, Any]] = None, ) -> Optional[float]: if isinstance(ast, (gql_ast.FloatValueNode, gql_ast.IntValueNode)): return float(ast.value) else: return None GraphQLDecimal = GraphQLScalarType( name="Decimal", description="The `Decimal` scalar type represents signed " "unlimited-precision fractional values.", serialize=GraphQLFloat.serialize, parse_value=GraphQLFloat.parse_value, parse_literal=parse_decimal_literal, ) EDB_TO_GQL_SCALARS_MAP = { # GraphQL doesn't have a built-in type with arbitrary fields, so # json is exposed through the custom JSON scalar defined above.
'std::json': GraphQLJSON, 'std::str': GraphQLString, 'std::anyint': GraphQLInt, 'std::int16': GraphQLInt, 'std::int32': GraphQLInt, 'std::int64': GraphQLInt64, 'std::bigint': GraphQLBigint, 'std::anyfloat': GraphQLFloat, 'std::float32': GraphQLFloat, 'std::float64': GraphQLFloat, 'std::anyreal': GraphQLFloat, 'std::decimal': GraphQLDecimal, 'std::bool': GraphQLBoolean, 'std::uuid': GraphQLID, 'std::datetime': GraphQLString, 'std::duration': GraphQLString, 'std::bytes': None, 'std::cal::local_datetime': GraphQLString, 'std::cal::local_date': GraphQLString, 'std::cal::local_time': GraphQLString, 'std::cal::relative_duration': GraphQLString, 'std::cal::date_duration': GraphQLString, } # used for casting input values from GraphQL to EdgeQL GQL_TO_EDB_SCALARS_MAP = { 'String': 'str', 'Int': 'int32', 'Int64': 'int64', 'Bigint': 'bigint', 'Float': 'float64', 'Decimal': 'decimal', 'Boolean': 'bool', 'ID': 'uuid', 'JSON': 'json', } GQL_TO_EDB_RANGES_MAP = { 'RangeOfString': 'json', 'RangeOfInt': 'json', 'RangeOfInt64': 'json', 'RangeOfFloat': 'json', 'RangeOfDecimal': 'json', } GQL_TO_OPS_MAP = { 'exists': 'EXISTS', 'in': 'IN', 'eq': '=', 'neq': '!=', 'gt': '>', 'gte': '>=', 'lt': '<', 'lte': '<=', 'like': 'LIKE', 'ilike': 'ILIKE', } HIDDEN_MODULES = set(s_schema.STD_MODULES) - {s_name.UnqualName('std')} # The following are placeholders. TOP_LEVEL_TYPES = { s_name.QualName(module='__graphql__', name='Query'), s_name.QualName(module='__graphql__', name='Mutation'), } # The following types should not be exposed at all.
HIDDEN_TYPES = { s_name.QualName(module='std', name='FreeObject'), } class GQLCoreSchema: _gql_interfaces: dict[ s_name.QualName, GraphQLInterfaceType, ] _gql_objtypes_from_alias: dict[ s_name.QualName, GraphQLObjectType, ] _gql_objtypes: dict[ s_name.QualName, GraphQLObjectType, ] _gql_inobjtypes: dict[ str, GraphQLInputObjectType | GraphQLEnumType | GraphQLScalarType ] _gql_ordertypes: dict[str, GraphQLInputType] _gql_enums: dict[str, GraphQLEnumType] _type_map: dict[tuple[str, bool], GQLBaseType] def __init__(self, edb_schema: s_schema.Schema) -> None: '''Create a graphql schema based on edgedb schema.''' self.edb_schema = edb_schema # extract and sort modules to have a consistent type ordering self.modules = list(sorted({ m.get_name(self.edb_schema) for m in self.edb_schema.get_objects(type=s_mod.Module) } - HIDDEN_MODULES)) self._gql_interfaces = {} self._gql_uniontypes: set[s_name.QualName] = set() self._gql_objtypes_from_alias = {} self._gql_objtypes = {} self._gql_inobjtypes = {} self._gql_ordertypes = {} self._gql_enums = {} self._define_types() # Use a fake name as a placeholder. Query = s_name.QualName(module='__graphql__', name='Query') query = self._gql_objtypes[Query] = GraphQLObjectType( name='Query', fields=self.get_fields(Query), ) # If a database only has abstract types and scalars, no # mutations will be possible (such as in a blank database), # but we would still want the reflection to work without # error, even if all that can be discovered through GraphQL # then is the schema. 
Mutation = s_name.QualName(module='__graphql__', name='Mutation') fields = self.get_fields(Mutation) if not fields: mutation = None else: mutation = self._gql_objtypes[Mutation] = GraphQLObjectType( name='Mutation', fields=fields, ) # get a sorted list of types relevant for the Schema types = [ objt for name, objt in itertools.chain(self._gql_objtypes.items(), self._gql_inobjtypes.items()) # the Query is included separately if name not in TOP_LEVEL_TYPES ] types = sorted(types, key=lambda x: x.name) self._gql_schema = GraphQLSchema( query=query, mutation=mutation, types=types) # this map is used for GQL -> EQL translator needs self._type_map = {} @property def edgedb_schema(self) -> s_schema.Schema: return self.edb_schema @property def graphql_schema(self) -> GraphQLSchema: return self._gql_schema def _get_type_gql_name(self, type: s_types.Type) -> str: if type.get_from_global(self.edb_schema): # The names of global types are mangled so use shortname instead typename = type.get_shortname(self.edb_schema) else: typename = type.get_name(self.edb_schema) assert isinstance(typename, s_name.QualName) return self.get_gql_name(typename) @classmethod def get_gql_name(cls, name: s_name.QualName) -> str: module, shortname = name.module, name.name # Adjust the shortname. if shortname.startswith('__'): # Use '_edb' prefix to mark derived and otherwise # internal types. We opt out of '__edb' because we # still rely on the first occurrence of '__' in # GraphQL names to separate the module from the rest # of the name in some code. shortname = f'_edb{shortname}' elif shortname.startswith('('): # Looks like a union type, so we'll need to process individual # parts of the name. 
names = [] for part in shortname[1:-1].split(' | '): names.append( cls.get_gql_name(s_name.QualName(*part.split(':', 1)))) shortname = '_OR_'.join(names) if module in {'default', 'std'}: return shortname else: assert module != '', f'get_gql_name {name=}' return str(name).replace("::", "__") def get_input_name(self, inputtype: str, name: str) -> str: if '__' in name: module, shortname = name.rsplit('__', 1) assert module != '', f'get_input_name {name=}' return f'{module}__{inputtype}{shortname}' else: return f'{inputtype}{name}' def gql_to_edb_name(self, name: str) -> str: '''Convert the GraphQL field name into a Gel type/view name.''' if '__' in name: return name.replace('__', '::') else: return name def _get_description(self, edb_type: s_types.Type) -> Optional[str]: description_anno = edb_type.get_annotations(self.edb_schema).get( self.edb_schema, s_name.QualName('std', 'description'), None) if description_anno is not None: return description_anno.get_value(self.edb_schema) return None def _convert_edb_type( self, edb_target: s_types.Type, ) -> Optional[GraphQLOutputType]: target: Optional[GraphQLOutputType] = None if isinstance(edb_target, s_types.Array): subtype = edb_target.get_subtypes(self.edb_schema)[0] el_type = self._convert_edb_type(subtype) if el_type is None: # we can't expose an array of unexposable type return el_type else: target = GraphQLList(GraphQLNonNull(el_type)) elif isinstance(edb_target, (s_types.Range, s_types.MultiRange)): # Represent ranges and multiranges as JSON. Same reason as for # tuples: the values are atomic and cannot be fragmented via # GraphQL specification, so we cannot use objects with fields to # represent them.
target = EDB_TO_GQL_SCALARS_MAP['std::json'] elif edb_target.is_view(self.edb_schema): tname = edb_target.get_name(self.edb_schema) assert isinstance(tname, s_name.QualName) target = self._gql_objtypes.get(tname) elif isinstance(edb_target, s_objtypes.ObjectType): target = self._gql_interfaces.get( edb_target.get_name(self.edb_schema), self._gql_objtypes.get(edb_target.get_name(self.edb_schema)) ) elif ( isinstance(edb_target, s_scalars.ScalarType) and edb_target.is_enum(self.edb_schema) ): name = self._get_type_gql_name(edb_target) if name in self._gql_enums: target = self._gql_enums.get(name) elif edb_target.is_tuple(self.edb_schema): # Represent tuples as JSON. target = EDB_TO_GQL_SCALARS_MAP['std::json'] elif isinstance(edb_target, s_types.InheritingType): base_target = edb_target.get_topmost_concrete_base(self.edb_schema) bt_name = base_target.get_name(self.edb_schema) try: target = EDB_TO_GQL_SCALARS_MAP[str(bt_name)] except KeyError: # this is the scalar base case, where all potentially # unrecognized scalars should end up edb_typename = edb_target.get_verbosename(self.edb_schema) raise g_errors.GraphQLCoreError( f"could not convert {edb_typename!r} type to" f" a GraphQL type") else: raise AssertionError(f'unexpected schema object: {edb_target!r}') return target def _get_target( self, ptr: s_pointers.Pointer, ) -> Optional[GraphQLOutputType]: edb_target = ptr.get_target(self.edb_schema) if edb_target is None: raise AssertionError(f'unexpected abstract pointer: {ptr!r}') target = self._convert_edb_type(edb_target) if target is not None: # figure out any additional wrappers due to cardinality # and required flags target = self._wrap_output_type(ptr, target) return target def _wrap_output_type( self, ptr: s_pointers.Pointer, target: GraphQLOutputType, *, ignore_required: bool = False, ) -> GraphQLOutputType: # figure out any additional wrappers due to cardinality # and required flags if not ptr.singular(self.edb_schema): target = 
GraphQLList(GraphQLNonNull(target)) if not ignore_required: # for input values having a default cancels out being required if ptr.get_required(self.edb_schema): target = GraphQLNonNull(target) return target def _wrap_input_type( self, ptr: s_pointers.Pointer, target: GraphQLInputType, *, ignore_required: bool = False, ) -> GraphQLInputType: # figure out any additional wrappers due to cardinality # and required flags if not ptr.singular(self.edb_schema): target = GraphQLList(GraphQLNonNull(target)) if not ignore_required: if ( ptr.get_required(self.edb_schema) and ptr.get_default(self.edb_schema) is None ): target = GraphQLNonNull(target) return target def _get_query_args( self, typename: s_name.QualName, ) -> dict[str, GraphQLArgument]: return { 'filter': GraphQLArgument(self._gql_inobjtypes[str(typename)]), 'order': GraphQLArgument(self._gql_ordertypes[str(typename)]), 'first': GraphQLArgument(GraphQLInt), 'last': GraphQLArgument(GraphQLInt), # before and after are supposed to be opaque values # serialized to string 'before': GraphQLArgument(GraphQLString), 'after': GraphQLArgument(GraphQLString), } def _get_insert_args( self, typename: s_name.QualName, ) -> dict[str, GraphQLArgument]: # The data can only be of a specific non-interface type; if no # such type exists, skip it, as the data input would be # ambiguous. It's still possible to just select some existing # data.
intype = self._gql_inobjtypes.get(f'Insert{typename}') if intype is None: return {} return { 'data': GraphQLArgument( GraphQLNonNull(GraphQLList(GraphQLNonNull(intype)))), } def _get_update_args( self, typename: s_name.QualName, ) -> dict[str, GraphQLArgument]: # some types have no updates uptype = self._gql_inobjtypes.get(f'Update{typename}') if uptype is None: return {} # the update args are same as for query + data args = self._get_query_args(typename) args['data'] = GraphQLArgument(GraphQLNonNull(uptype)) return args def get_fields( self, typename: s_name.QualName, ) -> dict[str, GraphQLField]: fields = {} if str(typename) == '__graphql__::Query': # The fields here will come from abstract types and aliases. queryable: list[tuple[s_name.QualName, GraphQLNamedType]] = [] queryable.extend(self._gql_interfaces.items()) queryable.extend(self._gql_objtypes_from_alias.items()) queryable.sort(key=lambda x: x[1].name) for name, gqliface in queryable: # '_edb' prefix indicates an internally generated type # (e.g. nested aliased type), which should not be # exposed as a top-level query option. if name in TOP_LEVEL_TYPES or gqliface.name.startswith('_edb'): continue # Check that the underlying type is not a union type. if name in self._gql_uniontypes: continue fields[gqliface.name] = GraphQLField( GraphQLList(GraphQLNonNull(gqliface)), args=self._get_query_args(name), ) elif str(typename) == '__graphql__::Mutation': # Get a list of alias names, so that we don't generate inserts for # them. aliases = {t.name for t in self._gql_objtypes_from_alias.values()} for name, gqltype in sorted(self._gql_objtypes.items(), key=lambda x: x[1].name): # '_edb' prefix indicates an internally generated type # (e.g. nested aliased type), which should not be # exposed as a top-level mutation option. 
if name in TOP_LEVEL_TYPES or gqltype.name.startswith('_edb'): continue edb_type = self.edb_schema.get(name, type=s_types.Type) gname = self._get_type_gql_name(edb_type) fields[f'delete_{gname}'] = GraphQLField( GraphQLList(GraphQLNonNull(gqltype)), args=self._get_query_args(name), ) if gname in aliases: # Aliases can only have delete mutations continue args = self._get_insert_args(name) fields[f'insert_{gname}'] = GraphQLField( GraphQLList(GraphQLNonNull(gqltype)), args=args, ) for name, gqliface in sorted(self._gql_interfaces.items(), key=lambda x: x[1].name): if (name in TOP_LEVEL_TYPES or gqliface.name.startswith('_edb') or f'Update{name}' not in self._gql_inobjtypes): continue edb_type = self.edb_schema.get(name, type=s_types.Type) gname = self._get_type_gql_name(edb_type) args = self._get_update_args(name) if args: # If there are no args, there's nothing to update. fields[f'update_{gname}'] = GraphQLField( GraphQLList(GraphQLNonNull(gqliface)), args=args, ) else: edb_type = self.edb_schema.get( typename, type=s_objtypes.ObjectType, ) pointers = edb_type.get_pointers(self.edb_schema) for unqual_pn, ptr in sorted(pointers.items(self.edb_schema)): pn = str(unqual_pn) if pn == '__type__': continue assert isinstance(ptr, s_pointers.Pointer) tgt = ptr.get_target(self.edb_schema) assert tgt is not None # Aliased types ignore their ancestors in order to # allow all their fields to appear properly in the # filters. # # If the target is not a view, but the pointer is computed, # it cannot be overridden later, so we can use the # type as is. if ( not tgt.is_view(self.edb_schema) and not ptr.is_pure_computable(self.edb_schema) ): # We want to look at the pointer lineage because that # will be reflected into the GraphQL interface that is # being extended and the type cannot be changed. ancestors: tuple[s_pointers.Pointer, ...]
ancestors = ptr.get_ancestors( self.edb_schema).objects(self.edb_schema) # We want the first non-generic ancestor of this # pointer as its target type will dictate the target # types of all its derived pointers. # # NOTE: We're guaranteed to have a non-generic one # since we're inspecting the lineage of a pointer # belonging to an actual type. for ancestor in reversed((ptr,) + ancestors): if not ancestor.is_non_concrete(self.edb_schema): ptr = ancestor break target = self._get_target(ptr) if target is not None: ptgt = ptr.get_target(self.edb_schema) if not isinstance(ptgt, s_objtypes.ObjectType): objargs = None else: objargs = self._get_query_args( ptgt.get_name(self.edb_schema)) fields[pn] = GraphQLField(target, args=objargs) return fields def get_filter_fields( self, typename: s_name.QualName, nested: bool = False, ) -> dict[str, GraphQLInputField]: selftype = self._gql_inobjtypes[str(typename)] fields = {} if not nested: fields['and'] = GraphQLInputField( GraphQLList(GraphQLNonNull(selftype))) fields['or'] = GraphQLInputField( GraphQLList(GraphQLNonNull(selftype))) fields['not'] = GraphQLInputField(selftype) else: # Always include the 'exists' operation fields['exists'] = GraphQLInputField(GraphQLBoolean) edb_type = self.edb_schema.get(typename, type=s_objtypes.ObjectType) pointers = edb_type.get_pointers(self.edb_schema) names = sorted(pointers.keys(self.edb_schema)) for unqual_name in names: name = str(unqual_name) if name == '__type__': continue if name in fields: raise g_errors.GraphQLCoreError( f"{name!r} of {typename} clashes with special " "reserved fields required for GraphQL conversion" ) ptr = edb_type.getptr(self.edb_schema, unqual_name) edb_target = ptr.get_target(self.edb_schema) assert edb_target is not None if isinstance(edb_target, s_objtypes.ObjectType): t_name = edb_target.get_name(self.edb_schema) gql_name = self.get_input_name( 'NestedFilter', self._get_type_gql_name(edb_target)) intype = self._gql_inobjtypes.get(gql_name) if intype is None: 
# construct a nested filter type intype = GraphQLInputObjectType( name=gql_name, fields=partial(self.get_filter_fields, t_name, True), ) self._gql_inobjtypes[gql_name] = intype elif not edb_target.is_scalar(): continue else: target = self._convert_edb_type(edb_target) if target is None: # don't expose this continue if isinstance(target, GraphQLNamedType): intype = self._gql_inobjtypes.get(f'Filter{target.name}') else: raise AssertionError( f'unexpected GraphQL type: {target!r}' ) if intype: fields[name] = GraphQLInputField(intype) return fields def get_insert_fields( self, typename: s_name.QualName, ) -> dict[str, GraphQLInputField]: fields = {} edb_type = self.edb_schema.get(typename, type=s_objtypes.ObjectType) pointers = edb_type.get_pointers(self.edb_schema) names = sorted(pointers.keys(self.edb_schema)) for unqual_name in names: name = str(unqual_name) if name in {'__type__', 'id'}: continue ptr = edb_type.getptr(self.edb_schema, unqual_name) edb_target = ptr.get_target(self.edb_schema) intype: GraphQLInputType if ptr.is_pure_computable(self.edb_schema): # skip computed pointer continue elif isinstance(edb_target, s_objtypes.ObjectType): typename = edb_target.get_name(self.edb_schema) inobjtype = self._gql_inobjtypes.get(f'NestedInsert{typename}') if inobjtype is not None: intype = inobjtype else: # construct a nested insert type intype = self._make_generic_nested_insert_type(edb_target) intype = self._wrap_input_type(ptr, intype) fields[name] = GraphQLInputField(intype) elif ( edb_target and edb_target.contains_array_of_tuples(self.edb_schema) ): # Can't insert array<tuple<...>> continue elif ( isinstance(edb_target, s_scalars.ScalarType) or isinstance(edb_target, s_types.Array) ): target = self._convert_edb_type(edb_target) if target is None: # don't expose this continue if isinstance(target, GraphQLList): # Check whether the edb_target is an array of enums, # because enums need slightly different handling.
assert isinstance(edb_target, s_types.Array) el = edb_target.get_element_type(self.edb_schema) if el.is_enum(self.edb_schema): tname = el.get_name(self.edb_schema) assert isinstance(tname, s_name.QualName) else: tname = target.of_type.of_type.name inobjtype = self._gql_inobjtypes.get(f'Insert{tname}') assert inobjtype is not None intype = GraphQLList(GraphQLNonNull(inobjtype)) elif edb_target.is_enum(self.edb_schema): enum_name = edb_target.get_name(self.edb_schema) assert isinstance(enum_name, s_name.QualName) inobjtype = self._gql_inobjtypes.get(f'Insert{enum_name}') assert inobjtype is not None intype = inobjtype elif isinstance(target, GraphQLNamedType): inobjtype = self._gql_inobjtypes.get( f'Insert{target.name}') assert inobjtype is not None intype = inobjtype else: raise AssertionError( f'unexpected GraphQL type: {target!r}' ) intype = self._wrap_input_type(ptr, intype) if intype: fields[name] = GraphQLInputField(intype) elif isinstance(edb_target, s_types.Range): subtype = edb_target.get_subtypes(self.edb_schema)[0] intype = self.get_input_range_type(subtype) intype = self._wrap_input_type(ptr, intype) fields[name] = GraphQLInputField(intype) elif isinstance(edb_target, s_types.MultiRange): subtype = edb_target.get_subtypes(self.edb_schema)[0] intype = GraphQLList(GraphQLNonNull( self.get_input_range_type(subtype))) intype = self._wrap_input_type(ptr, intype) fields[name] = GraphQLInputField(intype) else: continue return fields def get_update_fields( self, typename: s_name.QualName, ) -> dict[str, GraphQLInputField]: fields = {} edb_type = self.edb_schema.get(typename, type=s_objtypes.ObjectType) pointers = edb_type.get_pointers(self.edb_schema) names = sorted(pointers.keys(self.edb_schema)) # This is just a heavily re-used type variable target: GraphQLInputType | Optional[GraphQLOutputType] for unqual_name in names: name = str(unqual_name) if name == '__type__': continue ptr = edb_type.getptr(self.edb_schema, unqual_name) edb_target =
ptr.get_target(self.edb_schema) if ptr.is_pure_computable(self.edb_schema): # skip computed pointer continue elif isinstance(edb_target, s_objtypes.ObjectType): intype = self._gql_inobjtypes.get( f'UpdateOp{typename}__{name}') if intype is None: # the links can only be updated by selecting some # objects, meaning that the basis is the same as for # query of whatever is the link type intype = self._gql_inobjtypes.get( f'NestedUpdate{edb_target.get_name(self.edb_schema)}') if intype is None: # construct a nested update type intype = self._make_generic_nested_update_type( edb_target) # depending on whether this is a multilink or not wrap # it in a List intype = cast( GraphQLInputObjectType, self._wrap_input_type( ptr, intype, ignore_required=True), ) # wrap into additional layer representing update ops intype = self._make_generic_update_op_type( ptr, name, edb_type, intype) fields[name] = GraphQLInputField(intype) elif ( edb_target and edb_target.contains_array_of_tuples(self.edb_schema) ): # Can't update array<tuple<...>> continue elif isinstance( edb_target, ( s_scalars.ScalarType, s_types.Array, ) ): target = self._convert_edb_type(edb_target) if target is None or ptr.get_readonly(self.edb_schema): # don't expose this continue intype = self._gql_inobjtypes.get( f'UpdateOp{typename}__{name}') if intype is None: # construct an update op type assert isinstance( target, ( GraphQLScalarType, GraphQLEnumType, GraphQLInputObjectType, GraphQLWrappingType, ), ), f'got {target!r}, expected GraphQLInputType' intype = self._make_generic_update_op_type( ptr, fname=name, edb_base=edb_type, target=self._wrap_input_type( ptr, target, ignore_required=True, ), ) if intype: fields[name] = GraphQLInputField(intype) elif isinstance( edb_target, ( s_types.Range, s_types.MultiRange, ) ): subtype = edb_target.get_subtypes(self.edb_schema)[0] target = self.get_input_range_type(subtype) if isinstance(edb_target, s_types.MultiRange): target = GraphQLList(GraphQLNonNull(target)) intype =
self._gql_inobjtypes.get( f'UpdateOp{typename}__{name}') if intype is None: # construct a nested insert type intype = self._make_generic_update_op_type( ptr, fname=name, edb_base=edb_type, target=self._wrap_input_type( ptr, target, ignore_required=True, ), ) if intype: fields[name] = GraphQLInputField(intype) else: continue return fields def _make_generic_update_op_type( self, ptr: s_pointers.Pointer, fname: str, edb_base: s_types.Type, target: GraphQLInputType, ) -> GraphQLInputObjectType: typename = edb_base.get_name(self.edb_schema) assert isinstance(typename, s_name.QualName) name = f'UpdateOp{typename}__{fname}' edb_target = ptr.get_target(self.edb_schema) fields = { 'set': GraphQLInputField(target) } # get additional commands based on the pointer type if not ptr.get_required(self.edb_schema): fields['clear'] = GraphQLInputField(GraphQLBoolean) bt_name: Optional[s_name.QualName] if isinstance(edb_target, s_scalars.ScalarType): base_target = edb_target.get_topmost_concrete_base(self.edb_schema) bt_name = base_target.get_name(self.edb_schema) else: bt_name = None # first check for this being a multi-link if not ptr.singular(self.edb_schema): fields['add'] = GraphQLInputField(target) fields['remove'] = GraphQLInputField(target) elif target in {GraphQLInt, GraphQLInt64, GraphQLBigint, GraphQLFloat, GraphQLDecimal}: # anything that maps onto the numeric types is a fair game fields['increment'] = GraphQLInputField(target) fields['decrement'] = GraphQLInputField(target) elif ( bt_name == s_name.QualName(module='std', name='str') or isinstance(edb_target, s_types.Array) ): # only actual strings and arrays have append, prepend and # slice ops fields['prepend'] = GraphQLInputField(target) fields['append'] = GraphQLInputField(target) # slice [from, to] fields['slice'] = GraphQLInputField( GraphQLList(GraphQLNonNull(GraphQLInt)) ) nitype = GraphQLInputObjectType( name=self.get_input_name( f'UpdateOp_{fname}_', self._get_type_gql_name(edb_base), ), fields=fields, ) 
self._gql_inobjtypes[name] = nitype return nitype def _make_generic_nested_update_type( self, edb_base: s_objtypes.ObjectType, ) -> GraphQLInputObjectType: typename = edb_base.get_name(self.edb_schema) name = f'NestedUpdate{typename}' nitype = GraphQLInputObjectType( name=self.get_input_name( 'NestedUpdate', self._get_type_gql_name(edb_base)), fields={ 'filter': GraphQLInputField( self._gql_inobjtypes[str(typename)]), 'order': GraphQLInputField( self._gql_ordertypes[str(typename)]), 'first': GraphQLInputField(GraphQLInt), 'last': GraphQLInputField(GraphQLInt), # before and after are supposed to be opaque values # serialized to string 'before': GraphQLInputField(GraphQLString), 'after': GraphQLInputField(GraphQLString), }, ) self._gql_inobjtypes[name] = nitype return nitype def _make_generic_nested_insert_type( self, edb_base: s_objtypes.ObjectType, ) -> GraphQLInputObjectType: typename = edb_base.get_name(self.edb_schema) name = f'NestedInsert{typename}' fields = { 'filter': GraphQLInputField( self._gql_inobjtypes[str(typename)]), 'order': GraphQLInputField( self._gql_ordertypes[str(typename)]), 'first': GraphQLInputField(GraphQLInt), 'last': GraphQLInputField(GraphQLInt), # before and after are supposed to be opaque values # serialized to string 'before': GraphQLInputField(GraphQLString), 'after': GraphQLInputField(GraphQLString), } # The data can only be a specific non-interface type. If no # such type exists, skip it, as there is no unambiguous way # to accept data input. It's still possible to just select # some existing data.
data_t = self._gql_inobjtypes.get(f'Insert{typename}') if data_t: fields['data'] = GraphQLInputField(data_t) nitype = GraphQLInputObjectType( name=self.get_input_name( 'NestedInsert', self._get_type_gql_name(edb_base)), fields=fields, ) self._gql_inobjtypes[name] = nitype return nitype def define_enums(self) -> None: self._gql_enums['directionEnum'] = GraphQLEnumType( 'directionEnum', values=dict( ASC=GraphQLEnumValue(), DESC=GraphQLEnumValue() ), description='Enum value used to specify ordering direction.', ) self._gql_enums['nullsOrderingEnum'] = GraphQLEnumType( 'nullsOrderingEnum', values=dict( SMALLEST=GraphQLEnumValue(), BIGGEST=GraphQLEnumValue(), ), description='Enum value used to specify how nulls are ordered.', ) scalar_types = list( self.edb_schema.get_objects( included_modules=self.modules, type=s_scalars.ScalarType ), ) for st in scalar_types: enum_values = st.get_enum_values(self.edb_schema) if enum_values is not None: t_name = st.get_name(self.edb_schema) gql_name = self._get_type_gql_name(st) enum_type = GraphQLEnumType( gql_name, values={key: GraphQLEnumValue() for key in enum_values}, description=self._get_description(st), ) self._gql_enums[gql_name] = enum_type self._gql_inobjtypes[f'Insert{t_name}'] = enum_type def define_generic_filter_types(self) -> None: eq = ['eq', 'neq'] comp = eq + ['gte', 'gt', 'lte', 'lt'] string = comp + ['like', 'ilike'] self._make_generic_filter_type(GraphQLBoolean, eq) self._make_generic_filter_type(GraphQLID, eq) self._make_generic_filter_type(GraphQLInt, comp) self._make_generic_filter_type(GraphQLInt64, comp) self._make_generic_filter_type(GraphQLBigint, comp) self._make_generic_filter_type(GraphQLFloat, comp) self._make_generic_filter_type(GraphQLDecimal, comp) self._make_generic_filter_type(GraphQLString, string) self._make_generic_filter_type(GraphQLJSON, comp) for name, etype in self._gql_enums.items(): if name not in {'directionEnum', 'nullsOrderingEnum'}: self._make_generic_filter_type(etype, comp) def 
_make_generic_filter_type( self, base: GraphQLScalarType | GraphQLEnumType, ops: list[str], ) -> None: name = f'Filter{base.name}' fields = {} # Always include the 'exists' operation fields['exists'] = GraphQLInputField(GraphQLBoolean) # Always include the 'in' operation fields['in'] = GraphQLInputField(GraphQLList(GraphQLNonNull(base))) for op in ops: fields[op] = GraphQLInputField(base) self._gql_inobjtypes[name] = GraphQLInputObjectType( name=name, fields=fields, ) def define_generic_insert_types(self) -> None: for itype in [ GraphQLBoolean, GraphQLID, GraphQLInt, GraphQLInt64, GraphQLBigint, GraphQLFloat, GraphQLDecimal, GraphQLString, GraphQLJSON, ]: self._gql_inobjtypes[f'Insert{itype.name}'] = itype def define_generic_order_types(self) -> None: self._gql_ordertypes['directionEnum'] = self._gql_enums['directionEnum'] self._gql_ordertypes['nullsOrderingEnum'] = self._gql_enums[ 'nullsOrderingEnum' ] self._gql_ordertypes['Ordering'] = GraphQLInputObjectType( 'Ordering', fields=dict( dir=GraphQLInputField( GraphQLNonNull(self._gql_enums['directionEnum']), ), nulls=GraphQLInputField( self._gql_enums['nullsOrderingEnum'], default_value='SMALLEST', ), ) ) def get_order_fields( self, typename: s_name.QualName, ) -> dict[str, GraphQLInputField]: fields: dict[str, GraphQLInputField] = {} edb_type = self.edb_schema.get(typename, type=s_objtypes.ObjectType) pointers = edb_type.get_pointers(self.edb_schema) names = sorted(pointers.keys(self.edb_schema)) for unqual_name in names: name = str(unqual_name) if name == '__type__': continue ptr = edb_type.getptr(self.edb_schema, unqual_name) if not ptr.singular(self.edb_schema): continue t = ptr.get_target(self.edb_schema) assert t is not None target = self._convert_edb_type(t) if target is None: # Don't expose this continue if isinstance(t, s_scalars.ScalarType): assert isinstance(target, GraphQLNamedType) # This makes sure that we can only order by properties # that can be reflected into GraphQL intype = 
self._gql_inobjtypes.get(f'Filter{target.name}') if intype: fields[name] = GraphQLInputField( self._gql_ordertypes['Ordering'] ) elif isinstance(t, s_objtypes.ObjectType): # It's a link so we need the link's type order input t_name = t.get_name(self.edb_schema) fields[name] = GraphQLInputField( self._gql_ordertypes[str(t_name)] ) else: # We ignore pointers that aren't scalars or objects. pass return fields def get_input_range_type( self, subtype: s_types.Type ) -> GraphQLInputObjectType: sub_gqltype = self._convert_edb_type(subtype) assert isinstance(sub_gqltype, GraphQLScalarType) r_name = f'RangeOf{sub_gqltype.name}' # Check the type cache... if (res := self._gql_inobjtypes.get(r_name)) is not None: assert isinstance(res, GraphQLInputObjectType) return res gqltype = GraphQLInputObjectType( name=r_name, fields=dict( lower=GraphQLInputField(sub_gqltype), inc_lower=GraphQLInputField(GraphQLBoolean), upper=GraphQLInputField(sub_gqltype), inc_upper=GraphQLInputField(GraphQLBoolean), empty=GraphQLInputField(GraphQLBoolean), ), description=f'Range of {sub_gqltype.name} values', ) self._gql_inobjtypes[r_name] = gqltype return gqltype def _define_types(self) -> None: interface_types = [] obj_types = [] from_union = {} self.define_enums() self.define_generic_filter_types() self.define_generic_order_types() self.define_generic_insert_types() # Every ObjectType is reflected as an interface. interface_types = list( self.edb_schema.get_objects(included_modules=self.modules, type=s_objtypes.ObjectType)) # concrete types are also reflected as Type (with a '_Type' postfix) obj_types += [t for t in interface_types if not t.get_abstract(self.edb_schema)] # interfaces for t in interface_types: t_name = t.get_name(self.edb_schema) gql_name = self._get_type_gql_name(t) if t_name in HIDDEN_TYPES: continue if t.is_view(self.edb_schema): # The aliased types actually only reflect as an object # type, but the rest of the processing is identical to # interfaces. 
self._gql_objtypes_from_alias[t_name] = GraphQLObjectType( name=gql_name, fields=partial(self.get_fields, t_name), description=self._get_description(t), ) else: def _type_resolver( obj: GraphQLObjectType, info: GraphQLResolveInfo, _t: GraphQLAbstractType, ) -> GraphQLObjectType: return obj self._gql_interfaces[t_name] = GraphQLInterfaceType( name=gql_name, fields=partial(self.get_fields, t_name), resolve_type=_type_resolver, description=self._get_description(t), ) if t.is_union_type(self.edb_schema): # NOTE: EdgeDB union types and GraphQL union types are # different in some important ways. In EdgeDB a union object # type will have all the common links and properties that are # shared among the members of the union. In GraphQL a union # type has *no fields* at all and must be accessed via typed # fragments. Effectively, EdgeDB union types behave exactly # like GraphQL interfaces, though, which is why they will be # reflected more naturally as interfaces. # # We still need to internally keep track of which interfaces # are actually union types so that we don't create any # top-level Query or Mutation entries for union types, but # stick to only using them in the nested structures they # actually appear in. self._gql_uniontypes.add(t_name) for member in t.get_union_of(self.edb_schema) \ .names(self.edb_schema): # Union types must be interfaces for each of # the individual components so we need to record that.
from_union[member] = t_name # input object types corresponding to this interface gqlfiltertype = GraphQLInputObjectType( name=self.get_input_name('Filter', gql_name), fields=partial(self.get_filter_fields, t_name), ) self._gql_inobjtypes[str(t_name)] = gqlfiltertype # ordering input type gqlordertype = GraphQLInputObjectType( name=self.get_input_name('Order', gql_name), fields=partial(self.get_order_fields, t_name), ) self._gql_ordertypes[str(t_name)] = gqlordertype # update object types corresponding to this object (all types # except views and union types can appear as update types) if not (t.is_view(self.edb_schema) or t.is_union_type(self.edb_schema)): # only objects that have at least one non-readonly # link/property are eligible pointers = t.get_pointers(self.edb_schema) if any(not p.get_readonly(self.edb_schema) and not p.is_pure_computable(self.edb_schema) for _, p in pointers.items(self.edb_schema)): gqlupdatetype = GraphQLInputObjectType( name=self.get_input_name('Update', gql_name), fields=partial(self.get_update_fields, t_name), ) self._gql_inobjtypes[f'Update{t_name}'] = gqlupdatetype # object types for t in obj_types: interfaces = [] t_name = t.get_name(self.edb_schema) gql_name = self._get_type_gql_name(t) if t_name in HIDDEN_TYPES: continue if t.is_view(self.edb_schema): # Just copy previously computed type. 
self._gql_objtypes[t_name] = \ self._gql_objtypes_from_alias[t_name] continue if t.is_union_type(self.edb_schema): continue if t_name in self._gql_interfaces: interfaces.append(self._gql_interfaces[t_name]) if t_name in from_union: interfaces.append(self._gql_interfaces[from_union[t_name]]) ancestors = t.get_ancestors(self.edb_schema) for st in ancestors.objects(self.edb_schema): if (st.is_object_type() and st.get_name(self.edb_schema) in self._gql_interfaces): interfaces.append( self._gql_interfaces[st.get_name(self.edb_schema)]) gqltype = GraphQLObjectType( name=f'{gql_name}_Type', fields=partial(self.get_fields, t_name), interfaces=interfaces, description=self._get_description(t), ) self._gql_objtypes[t_name] = gqltype # only objects that have at least one non-computed # link/property are eligible to be input objects pointers = t.get_pointers(self.edb_schema) if any(not p.is_pure_computable(self.edb_schema) for pname, p in pointers.items(self.edb_schema) if str(pname) not in {'__type__', 'id'}): # input object types corresponding to this object (only # real objects can appear as input objects) gqlinserttype = GraphQLInputObjectType( name=self.get_input_name('Insert', gql_name), fields=partial(self.get_insert_fields, t_name), ) self._gql_inobjtypes[f'Insert{t_name}'] = gqlinserttype def get(self, name: str, *, dummy: bool = False) -> GQLBaseType: '''Get a special GQL type either by name or based on Gel type.''' # normalize name and possibly add 'edb_base' to kwargs edb_base = None kwargs: dict[str, Any] = {'dummy': dummy} if not name.startswith('__graphql__::'): # The name may potentially contain the suffix "_Type", # which in 99% cases indicates that it's a GraphQL # internal type generated from the EdgeDB base type, but # we technically need to check both. 
if name.endswith('_Type'): names = [name[:-len('_Type')], name] else: names = [name] for tname in names: if edb_base is None: module: s_name.Name | str if '::' in tname: edb_base = self.edb_schema.get( tname, type=s_types.Type, ) elif '__' in tname: # Looks like it's coming from a specific module edb_base = self.edb_schema.get( f"{tname.replace('__', '::')}", type=s_types.Type, ) else: for module in self.modules: edb_base = self.edb_schema.get( f'{module}::{tname}', type=s_types.Type, default=None, ) if edb_base: break # XXX: find a better way to do this for stype in [s_types.Array, s_types.Tuple, s_types.Range, s_types.MultiRange]: if edb_base is None: edb_base = self.edb_schema.get_global( stype, tname, default=None ) else: break if edb_base is None: raise AssertionError( f'unresolved type: {name}') kwargs['edb_base'] = edb_base # check if the type already exists fkey = (name, dummy) gqltype = self._type_map.get(fkey) if not gqltype: _type = GQLTypeMeta.edb_map.get(name, GQLShadowType) gqltype = _type(schema=self, **kwargs) self._type_map[fkey] = gqltype return gqltype class GQLTypeMeta(type): edb_map: dict[str, type[GQLBaseType]] = {} def __new__( mcls, name: str, bases: tuple[type, ...], dct: dict[str, Any], ) -> GQLTypeMeta: cls = super().__new__(mcls, name, bases, dct) edb_type = dct.get('edb_type') if edb_type: mcls.edb_map[str(edb_type)] = cls # type: ignore return cls class GQLBaseType(metaclass=GQLTypeMeta): edb_type: ClassVar[Optional[s_name.QualName]] = None _edb_base: Optional[s_types.Type] _module: Optional[str] _fields: dict[tuple[str, bool], GQLBaseType] _shadow_fields: tuple[str, ...] 
def __init__( self, schema: GQLCoreSchema, *, name: Optional[str] = None, edb_base: Optional[s_types.Type] = None, dummy: bool = False, ) -> None: self._shadow_fields = () if edb_base is None: if self.edb_type: if self.edb_type.module == '__graphql__': edb_base_name = str(self.edb_type) else: edb_base = schema.edb_schema.get( self.edb_type, type=s_objtypes.ObjectType, ) edb_base_name = str(edb_base.get_name(schema.edb_schema)) else: raise AssertionError( f'neither the constructor, nor the class attribute ' f'define a required edb_base for {type(self)!r}', ) else: edb_base_name = str(edb_base.get_name(schema.edb_schema)) # __typename if name is None: self._name = edb_base_name else: self._name = name # determine module from name if not already specified if '::' in self._name: self._module = self._name.rsplit('::', 1)[0] else: self._module = None # what EdgeDB entity will be the root for queries, if any self._edb_base = edb_base self._schema = schema self._fields = {} # XXX clean up needed, but otherwise it means that the type is # used to validate the fields/types/args/etc., but is not # expected to generate non-empty results, so messy EQL is not # needed. 
self.dummy = dummy # JSON and bool need some special treatment so we want to know if # we're dealing with it if isinstance(edb_base, s_scalars.ScalarType): bt = edb_base.get_topmost_concrete_base(self.edb_schema) bt_name = str(bt.get_name(self.edb_schema)) self._is_json = bt_name == 'std::json' self._is_bool = bt_name == 'std::bool' self._is_float = edb_base.issubclass( self.edb_schema, self.edb_schema.get( 'std::anyfloat', type=s_scalars.ScalarType, ), ) else: self._is_json = self._is_bool = self._is_float = False @property def is_json(self) -> bool: return self._is_json @property def is_enum(self) -> bool: return False @property def is_bool(self) -> bool: return self._is_bool @property def is_float(self) -> bool: return self._is_float @property def is_array(self) -> bool: if self.edb_base is None: return False else: return self.edb_base.is_array() @property def is_range(self) -> bool: if self.edb_base is None: return False else: return self.edb_base.is_range() @property def is_multirange(self) -> bool: if self.edb_base is None: return False else: return self.edb_base.is_multirange() @property def is_object_type(self) -> bool: if self.edb_base is None: return False else: return self.edb_base.is_object_type() @property def name(self) -> str: return self._name @property def short_name(self) -> str: return self._name.split('::')[-1] @property def module(self) -> Optional[str]: return self._module @property def edb_base(self) -> Optional[s_types.Type]: return self._edb_base @property def edb_base_name_ast(self) -> Optional[qlast.ObjectRef]: if self.edb_base is None: return None if isinstance(self.edb_base, (s_types.Array, s_types.Range, s_types.MultiRange)): el = self.edb_base.get_element_type(self.edb_schema) base_name = el.get_name(self.edb_schema) assert isinstance(base_name, s_name.QualName) return qlast.ObjectRef( module=base_name.module, name=base_name.name, ) else: base_name = self.edb_base.get_name(self.edb_schema) assert isinstance(base_name, s_name.QualName) 
return qlast.ObjectRef( module=base_name.module, name=base_name.name, ) @property def edb_base_name(self) -> str: ast = self.edb_base_name_ast if ast is None: return '' else: return codegen.generate_source(ast) @property def gql_typename(self) -> str: name = self.name module, shortname = name.rsplit('::', 1) if self.edb_base is None: # We expect that this is one of the fake objects, that # only have an edb_type. assert self.edb_type is not None return self.edb_type.name elif self.edb_base.is_view(self.edb_schema): suffix = '' else: suffix = '_Type' if module in {'default', 'std'}: return f'{shortname}{suffix}' else: assert module != '', 'gql_typename ' + module return f'{name.replace("::", "__")}{suffix}' @property def schema(self) -> GQLCoreSchema: return self._schema @property def edb_schema(self) -> s_schema.Schema: return self._schema.edb_schema @edb_schema.setter def edb_schema(self, schema: s_schema.Schema) -> None: self._schema.edb_schema = schema def convert_edb_to_gql_type( self, base: s_types.Type | s_pointers.Pointer, **kwargs: Any, ) -> GQLBaseType: if isinstance(base, s_pointers.Pointer): tgt = base.get_target(self.edb_schema) assert tgt is not None base = tgt if self.dummy: kwargs['dummy'] = True return self.schema.get(str(base.get_name(self.edb_schema)), **kwargs) def is_field_shadowed(self, name: str) -> bool: return name in self._shadow_fields def get_field_type(self, name: str) -> Optional[GQLBaseType]: if self.dummy: return None # this is just shadowing a real EdgeDB type fkey = (name, self.dummy) target = self._fields.get(fkey) if target is None: # special handling of '__typename' if name == '__typename': target = self.convert_edb_to_gql_type( self.edb_schema.get( s_name.QualName( module='std', name='str', ), type=s_scalars.ScalarType, ), ) elif isinstance(self.edb_base, s_objtypes.ObjectType): ptr = self.edb_base.maybe_get_ptr( self.edb_schema, s_name.UnqualName(name), ) if ptr is not None: target = self.convert_edb_to_gql_type(ptr) if target 
is not None: self._fields[fkey] = target return target def has_native_field(self, name: str) -> bool: if isinstance(self.edb_base, s_objtypes.ObjectType): ptr = self.edb_base.maybe_get_ptr( self.edb_schema, s_name.UnqualName(name)) return ptr is not None else: return False def issubclass(self, other: Any) -> bool: if ( self.edb_base is not None and other.edb_base is not None and isinstance(other, GQLShadowType) ): return self.edb_base.issubclass( self._schema.edb_schema, other.edb_base ) else: return False def get_template( self, ) -> tuple[qlast.Base, Optional[qlast.Expr], Optional[qlast.SelectQuery]]: '''Provide an EQL AST template to be filled. Return the overall ast, a reference to where the shape element with placeholder is, and a reference to the element which may be filtered. ''' if self.dummy: return parse_fragment(f'''to_json("xxx")'''), None, None eql = parse_fragment(f''' SELECT {self.edb_base_name} {{ xxx }} ''') filterable = eql assert isinstance(filterable, qlast.SelectQuery) shape = filterable.result return eql, shape, filterable def get_field_template( self, name: str, *, parent: qlast.Base, has_shape: bool = False, ) -> tuple[ Optional[qlast.Base], Optional[qlast.Expr], Optional[qlast.SelectQuery], ]: eql = shape = filterable = None if self.dummy: return eql, shape, filterable if name == '__typename' and not self.is_field_shadowed(name): if self.edb_base is None: # We expect that this is one of the fake objects, that # only have an edb_type. assert self.edb_type is not None eql = parse_fragment(f'{self.edb_type.name!r}') elif self.edb_base.is_view(self.edb_schema): eql = parse_fragment(f'{self.gql_typename!r}') else: # Construct the GraphQL type name from the actual type name. 
eql = parse_fragment(fr''' WITH name := {codegen.generate_source(parent)} .__type__.name SELECT ( name[5:] IF name LIKE 'std::%' ELSE name[9:] IF name LIKE 'default::%' ELSE str_replace(name, '::', '__') ) ++ '_Type' ''') elif has_shape: eql = parse_fragment( f'''SELECT {codegen.generate_source(parent)}. {codegen.generate_source(qlast.ObjectRef(name=name))} {{ xxx }} ''') assert isinstance(eql, qlast.SelectQuery) filterable = eql shape = filterable.result else: eql = parse_fragment( f'''SELECT {codegen.generate_source(parent)}. {codegen.generate_source(qlast.ObjectRef(name=name))} ''') assert isinstance(eql, qlast.SelectQuery) filterable = eql return eql, shape, filterable def get_field_cardinality( self, name: str, ) -> Optional[qltypes.SchemaCardinality]: if not self.is_field_shadowed(name): return None elif isinstance(self.edb_base, s_objtypes.ObjectType): ptr = self.edb_base.getptr( self.edb_schema, s_name.UnqualName(name), ) if not ptr.singular(self.edb_schema): return qltypes.SchemaCardinality.Many return None class GQLShadowType(GQLBaseType): def is_field_shadowed(self, name: str) -> bool: if name == '__typename': return False ftype = self.get_field_type(name) # JSON fields are not shadowed if ftype is None: return False return True @property def is_enum(self) -> bool: if self.edb_base is None: return False else: return self.edb_base.is_enum(self.edb_schema) class GQLBaseQuery(GQLBaseType): def __init__( self, schema: GQLCoreSchema, *, name: Optional[str] = None, edb_base: Optional[s_types.Type] = None, dummy: bool = False, ) -> None: self.modules = schema.modules super().__init__(schema, name=name, edb_base=edb_base, dummy=dummy) # Record names of std built-in object types self._std_obj_names = [ t.get_name(self.edb_schema).name for t in self.edb_schema.get_objects( included_modules=[s_name.UnqualName('std')], type=s_objtypes.ObjectType, ) ] def get_module_and_name(self, name: str) -> tuple[str, ...]: if name in self._std_obj_names: return ('std', name) 
elif '__' in name: module, name = name.rsplit('__', 1) return (module.replace('__', '::'), name) else: return ('default', name) class GQLQuery(GQLBaseQuery): edb_type = s_name.QualName(module='__graphql__', name='Query') def get_field_type(self, name: str) -> Optional[GQLBaseType]: fkey = (name, self.dummy) target = None if name in {'__type', '__schema'}: if fkey in self._fields: return self._fields[fkey] target = self.schema.get(str(self.edb_type), dummy=True) else: target = super().get_field_type(name) if target is None: module, edb_name = self.get_module_and_name(name) edb_qname = s_name.QualName(module=module, name=edb_name) edb_type = self.edb_schema.get( edb_qname, default=None, type=s_types.Type, ) if edb_type is not None: target = self.convert_edb_to_gql_type(edb_type) if target is not None: self._fields[fkey] = target return target class GQLMutation(GQLBaseQuery): edb_type = s_name.QualName(module='__graphql__', name='Mutation') def get_field_type(self, name: str) -> Optional[GQLBaseType]: fkey = (name, self.dummy) target = None if name == '__typename': # It's a valid field that doesn't start with a command target = super().get_field_type(name) else: op, name = name.split('_', 1) if op in {'delete', 'insert', 'update'}: target = super().get_field_type(name) if target is None: module, edb_name = self.get_module_and_name(name) edb_qname = s_name.QualName(module=module, name=edb_name) edb_type = self.edb_schema.get( edb_qname, default=None, type=s_types.Type, ) if edb_type is not None: target = self.convert_edb_to_gql_type(edb_type) if target is not None: self._fields[fkey] = target return target ================================================ FILE: edb/graphql-rewrite/Cargo.toml ================================================ [package] name = "graphql-rewrite" version = "0.1.0" license = "MIT/Apache-2.0" authors = ["MagicStack Inc. 
"] edition = "2021" [lints] workspace = true [features] python_extension = ["pyo3/extension-module"] default = ["python_extension"] [dependencies] pyo3 = { workspace = true, optional = true } combine = "3.8" thiserror = "2" num-bigint = "0.4.3" num-traits = "0.2.11" edb-graphql-parser = { git="https://github.com/edgedb/graphql-parser", features = ["serde"] } serde = { version = "1.0.106", features = ["derive"] } bincode = { version = "1.3.3" } [dev-dependencies] pretty_assertions = "1.2.0" [lib] crate-type = ["lib", "cdylib"] name = "graphql_rewrite" path = "src/lib.rs" ================================================ FILE: edb/graphql-rewrite/_graphql_rewrite.pyi ================================================ from typing import Any, Optional class Entry: key: str key_vars: list[str] variables: dict[str, Any] substitutions: dict[str, tuple[str, int, int]] def tokens(self) -> list[tuple[Any, int, int, int, int, Any]]: ... def rewrite(operation: Optional[str], text: str) -> Entry: ... ================================================ FILE: edb/graphql-rewrite/src/lib.rs ================================================ #![cfg(feature = "python_extension")] mod py_entry; mod py_exception; mod py_token; mod rewrite; mod token_vec; pub use py_token::{PyToken, PyTokenKind}; pub use rewrite::{rewrite, Value, Variable}; use py_exception::{AssertionError, LexingError, NotFoundError, QueryError, SyntaxError}; use pyo3::{prelude::*, types::PyString}; /// Rust optimizer for graphql queries #[pymodule] fn _graphql_rewrite(py: Python, m: &Bound) -> PyResult<()> { m.add_function(wrap_pyfunction!(py_rewrite, m)?)?; m.add_function(wrap_pyfunction!(py_entry::unpack, m)?)?; m.add_class::()?; m.add("LexingError", py.get_type::())?; m.add("SyntaxError", py.get_type::())?; m.add("NotFoundError", py.get_type::())?; m.add("AssertionError", py.get_type::())?; m.add("QueryError", py.get_type::())?; Ok(()) } #[pyo3::pyfunction(name = "rewrite")] #[pyo3(signature = (operation, text))] fn 
py_rewrite( py: Python<'_>, operation: Option<&Bound>, text: &Bound, ) -> PyResult { // convert args let operation = operation.map(|x| x.to_string()); let text = text.to_string(); match rewrite::rewrite(operation.as_ref().map(|x| &x[..]), &text) { Ok(entry) => py_entry::convert_entry(py, entry), Err(e) => Err(py_exception::convert_error(e)), } } ================================================ FILE: edb/graphql-rewrite/src/py_entry.rs ================================================ use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict, PyInt, PyString, PyType}; use edb_graphql_parser::position::Pos; use crate::py_token::{self, PyToken}; use crate::rewrite::{self, Value}; #[pyclass] pub struct Entry { #[pyo3(get)] key: Py, #[pyo3(get)] variables: Py, #[pyo3(get)] substitutions: Py, _tokens: Vec, _end_pos: Pos, #[pyo3(get)] num_variables: usize, orig_entry: rewrite::Entry, } #[pymethods] impl Entry { fn tokens<'py>(&self, py: Python<'py>, kinds: Py) -> PyResult> { py_token::convert_tokens(py, &self._tokens, &self._end_pos, kinds) } fn pack(&self, py: Python) -> PyResult> { let mut buf = vec![1u8]; // type and version bincode::serialize_into(&mut buf, &self.orig_entry) .map_err(|e| PyValueError::new_err(format!("Failed to pack: {e}")))?; Ok(PyBytes::new(py, buf.as_slice()).into()) } } #[pyfunction] pub fn unpack(py: Python<'_>, serialized: &Bound) -> PyResult> { let buf = serialized.as_bytes(); match buf[0] { 1u8 => { let pack: rewrite::Entry = bincode::deserialize(&buf[1..]) .map_err(|e| PyValueError::new_err(format!("Failed to unpack: {e}")))?; let entry = convert_entry(py, pack)?; entry.into_pyobject(py).map(|e| e.unbind().into_any()) } _ => Err(PyValueError::new_err(format!( "Invalid type/version byte: {}", buf[0] ))), } } pub fn convert_entry(py: Python<'_>, entry: rewrite::Entry) -> PyResult { // import decimal let decimal_cls = PyModule::import(py, "decimal")?.getattr("Decimal")?; let vars = PyDict::new(py); let 
substitutions = PyDict::new(py); for (idx, var) in entry.variables.iter().enumerate() { let s = format!("__edb_arg_{idx}").into_pyobject(py)?; vars.set_item(&s, value_to_py(py, &var.value, &decimal_cls)?)?; substitutions.set_item( s, ( &var.token.value, var.token.position.map(|x| x.line), var.token.position.map(|x| x.column), ), )?; } for (name, var) in &entry.defaults { vars.set_item(name, value_to_py(py, &var.value, &decimal_cls)?)? } let orig_entry = entry.clone(); Ok(Entry { key: PyString::new(py, &entry.key).into(), variables: vars.into_pyobject(py)?.into(), substitutions: substitutions.into(), _tokens: entry.tokens, _end_pos: entry.end_pos, num_variables: entry.num_variables, orig_entry, }) } fn value_to_py(py: Python, value: &Value, decimal_cls: &Bound) -> PyResult> { let v = match value { Value::Str(ref v) => PyString::new(py, v).into_any(), Value::Int32(v) => v.into_pyobject(py)?.into_any(), Value::Int64(v) => v.into_pyobject(py)?.into_any(), Value::Decimal(v) => decimal_cls.call((v.as_str(),), None)?.into_any(), Value::BigInt(ref v) => PyType::new::(py) .call((v.as_str(),), None)? 
.into_any(), Value::Boolean(b) => b.into_pyobject(py)?.to_owned().into_any(), }; Ok(v.into()) } ================================================ FILE: edb/graphql-rewrite/src/py_exception.rs ================================================ use pyo3::{create_exception, exceptions::PyException, PyErr}; use crate::rewrite::Error; create_exception!(_graphql_rewrite, LexingError, PyException); create_exception!(_graphql_rewrite, SyntaxError, PyException); create_exception!(_graphql_rewrite, NotFoundError, PyException); create_exception!(_graphql_rewrite, AssertionError, PyException); create_exception!(_graphql_rewrite, QueryError, PyException); pub fn convert_error(error: Error) -> PyErr { match error { Error::Lexing(e) => LexingError::new_err(e), Error::Syntax(e) => SyntaxError::new_err(e.to_string()), Error::NotFound(e) => NotFoundError::new_err(e), Error::Query(e) => QueryError::new_err(e), Error::Assertion(e) => AssertionError::new_err(e), } } ================================================ FILE: edb/graphql-rewrite/src/py_token.rs ================================================ use edb_graphql_parser::common::{unquote_block_string, unquote_string}; use edb_graphql_parser::position::Pos; use edb_graphql_parser::tokenizer::Token; use pyo3::prelude::*; use pyo3::types::{PyList, PyString, PyTuple}; use std::borrow::Cow; use crate::py_exception::LexingError; use crate::rewrite::Error; #[derive(Debug, PartialEq, Copy, Clone, serde::Serialize, serde::Deserialize)] pub enum PyTokenKind { Sof, Eof, Bang, Dollar, ParenL, ParenR, Spread, Colon, Equals, At, BracketL, BracketR, BraceL, Pipe, BraceR, Name, Int, Float, String, BlockString, } #[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)] pub struct PyToken { pub kind: PyTokenKind, pub value: Cow<'static, str>, pub position: Option<Pos>, } impl PyToken { pub fn new((token, position): &(Token<'_>, Pos)) -> Result<PyToken, Error> { use edb_graphql_parser::tokenizer::Kind::*; use PyTokenKind as T; let (kind, value) = match
(token.kind, token.value) { (IntValue, val) => (T::Int, Cow::Owned(val.into())), (FloatValue, val) => (T::Float, Cow::Owned(val.into())), (StringValue, val) => (T::String, Cow::Owned(val.into())), (BlockString, val) => (T::BlockString, Cow::Owned(val.into())), (Name, val) => (T::Name, Cow::Owned(val.into())), (Punctuator, "!") => (T::Bang, "!".into()), (Punctuator, "$") => (T::Dollar, "$".into()), (Punctuator, "(") => (T::ParenL, "(".into()), (Punctuator, ")") => (T::ParenR, ")".into()), (Punctuator, "...") => (T::Spread, "...".into()), (Punctuator, ":") => (T::Colon, ":".into()), (Punctuator, "=") => (T::Equals, "=".into()), (Punctuator, "@") => (T::At, "@".into()), (Punctuator, "[") => (T::BracketL, "[".into()), (Punctuator, "]") => (T::BracketR, "]".into()), (Punctuator, "{") => (T::BraceL, "{".into()), (Punctuator, "}") => (T::BraceR, "}".into()), (Punctuator, "|") => (T::Pipe, "|".into()), (Punctuator, _) => Err(Error::Assertion("unsupported punctuator".into()))?, }; Ok(PyToken { kind, value, position: Some(*position), }) } } pub fn convert_tokens<'py>( py: Python<'py>, tokens: &[PyToken], end_pos: &Pos, kinds: Py, ) -> PyResult> { use PyTokenKind as K; let sof = kinds.getattr(py, "SOF")?; let eof = kinds.getattr(py, "EOF")?; let bang = kinds.getattr(py, "BANG")?; let bang_v = "!".into_pyobject(py)?; let dollar = kinds.getattr(py, "DOLLAR")?; let dollar_v = "$".into_pyobject(py)?; let paren_l = kinds.getattr(py, "PAREN_L")?; let paren_l_v = "(".into_pyobject(py)?; let paren_r = kinds.getattr(py, "PAREN_R")?; let paren_r_v = ")".into_pyobject(py)?; let spread = kinds.getattr(py, "SPREAD")?; let spread_v = "...".into_pyobject(py)?; let colon = kinds.getattr(py, "COLON")?; let colon_v = ":".into_pyobject(py)?; let equals = kinds.getattr(py, "EQUALS")?; let equals_v = "=".into_pyobject(py)?; let at = kinds.getattr(py, "AT")?; let at_v = "@".into_pyobject(py)?; let bracket_l = kinds.getattr(py, "BRACKET_L")?; let bracket_l_v = "[".into_pyobject(py)?; let bracket_r 
= kinds.getattr(py, "BRACKET_R")?; let bracket_r_v = "]".into_pyobject(py)?; let brace_l = kinds.getattr(py, "BRACE_L")?; let brace_l_v = "{".into_pyobject(py)?; let pipe = kinds.getattr(py, "PIPE")?; let pipe_v = "|".into_pyobject(py)?; let brace_r = kinds.getattr(py, "BRACE_R")?; let brace_r_v = "}".into_pyobject(py)?; let name = kinds.getattr(py, "NAME")?; let int = kinds.getattr(py, "INT")?; let float = kinds.getattr(py, "FLOAT")?; let string = kinds.getattr(py, "STRING")?; let block_string = kinds.getattr(py, "BLOCK_STRING")?; let mut elems: Vec> = Vec::with_capacity(tokens.len()); let zero = 0u32.into_pyobject(py).unwrap(); let start_of_file = [ sof.clone_ref(py), zero.clone().into(), zero.clone().into(), zero.clone().into(), zero.clone().into(), py.None(), ]; elems.push(PyTuple::new(py, &start_of_file)?.into()); for token in tokens { let (kind, value) = match token.kind { K::Sof => (sof.clone_ref(py), py.None()), K::Eof => (eof.clone_ref(py), py.None()), K::Bang => (bang.clone_ref(py), bang_v.to_owned().into()), K::Dollar => (dollar.clone_ref(py), dollar_v.to_owned().into()), K::ParenL => (paren_l.clone_ref(py), paren_l_v.to_owned().into()), K::ParenR => (paren_r.clone_ref(py), paren_r_v.to_owned().into()), K::Spread => (spread.clone_ref(py), spread_v.to_owned().into()), K::Colon => (colon.clone_ref(py), colon_v.to_owned().into()), K::Equals => (equals.clone_ref(py), equals_v.to_owned().into()), K::At => (at.clone_ref(py), at_v.to_owned().into()), K::BracketL => (bracket_l.clone_ref(py), bracket_l_v.to_owned().into()), K::BracketR => (bracket_r.clone_ref(py), bracket_r_v.to_owned().into()), K::BraceL => (brace_l.clone_ref(py), brace_l_v.to_owned().into()), K::Pipe => (pipe.clone_ref(py), pipe_v.to_owned().into()), K::BraceR => (brace_r.clone_ref(py), brace_r_v.to_owned().into()), K::Name => (name.clone_ref(py), PyString::new(py, &token.value).into()), K::Int => (int.clone_ref(py), PyString::new(py, &token.value).into()), K::Float => (float.clone_ref(py), 
PyString::new(py, &token.value).into()), K::String => { // graphql-core 3 receives unescaped strings from the lexer let v = unquote_string(&token.value) .map_err(|e| LexingError::new_err(e.to_string()))? .into_pyobject(py)?; (string.clone_ref(py), v.to_owned().into()) } K::BlockString => { // graphql-core 3 receives unescaped strings from the lexer let v = unquote_block_string(&token.value) .map_err(|e| LexingError::new_err(e.to_string()))? .into_pyobject(py)?; (block_string.clone_ref(py), v.to_owned().into()) } }; let token_tuple = ( kind, token.position.map(|x| x.character), token .position .map(|x| x.character + token.value.chars().count()), token.position.map(|x| x.line), token.position.map(|x| x.column), value, ) .into_pyobject(py)?; elems.push(token_tuple.into()); } elems.push( ( eof, end_pos.character, end_pos.line, end_pos.column, end_pos.character, py.None(), ) .into_pyobject(py)? .into(), ); PyList::new(py, elems) } ================================================ FILE: edb/graphql-rewrite/src/rewrite.rs ================================================ use std::collections::{BTreeMap, HashSet}; use combine::stream::{Positioned, StreamOnce}; use edb_graphql_parser::common::{unquote_string, Type, Value as GqlValue}; use edb_graphql_parser::position::Pos; use edb_graphql_parser::query::{parse_query, Document, ParseError}; use edb_graphql_parser::query::{Definition, Directive}; use edb_graphql_parser::query::{InsertVars, InsertVarsKind, Operation}; use edb_graphql_parser::tokenizer::Kind::{BlockString, StringValue}; use edb_graphql_parser::tokenizer::Kind::{FloatValue, IntValue}; use edb_graphql_parser::tokenizer::Kind::{Name, Punctuator}; use edb_graphql_parser::tokenizer::{Token, TokenStream}; use edb_graphql_parser::visitor::Visit; use crate::py_token::{PyToken, PyTokenKind}; use crate::token_vec::TokenVec; #[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)] pub enum Value { Str(String), Int32(i32), Int64(i64), BigInt(String), 
Decimal(String), Boolean(bool), } #[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)] pub struct Variable { pub value: Value, pub token: PyToken, } #[derive(Debug)] pub enum Error { Lexing(String), Syntax(ParseError), NotFound(String), Assertion(String), Query(String), } #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct Entry { pub key: String, pub variables: Vec<Variable>, pub defaults: BTreeMap<String, Variable>, pub tokens: Vec<PyToken>, pub end_pos: Pos, pub num_variables: usize, } pub fn rewrite(operation: Option<&str>, s: &str) -> Result<Entry, Error> { use crate::py_token::PyTokenKind as P; use edb_graphql_parser::query::Value as G; use Value::*; let document: Document<'_, &str> = parse_query(s).map_err(Error::Syntax)?; let oper = if let Some(oper_name) = operation { find_operation(&document, oper_name) .ok_or_else(|| Error::NotFound(format!("no operation {operation:?} found")))? } else { let mut oper = None; for def in &document.definitions { match def { Definition::Operation(ref op) => { if oper.is_some() { Err(Error::NotFound( "Multiple operations \ found. Please specify operation name" .into(), ))?; } else { oper = Some(op); } } _ => continue, }; } oper.ok_or_else(|| Error::NotFound("no operation found".into()))?
}; let (all_src_tokens, end_pos) = token_array(s)?; let mut src_tokens = TokenVec::new(&all_src_tokens); let mut tokens = Vec::with_capacity(src_tokens.len()); let mut variables = Vec::new(); let mut defaults = BTreeMap::new(); let mut value_positions = HashSet::new(); visit_directives(&mut value_positions, oper); for var in &oper.variable_definitions { if var.name.starts_with("__edb_arg_") { return Err(Error::Query( "Variables starting with '__edb_arg_' are prohibited".into(), )); } if let Some(ref dvalue) = var.default_value { let value = match (&dvalue.value, type_name(&var.var_type)) { (G::String(ref s), Some("String")) => Str(s.clone()), (G::Int(ref s), Some("Int")) | (G::Int(ref s), Some("Int32")) => { let value = match s.as_i64() { Some(v) if v <= i32::MAX as i64 && v >= i32::MIN as i64 => v, // Ignore bad values. Let graphql solver handle that _ => continue, }; Int32(value as i32) } (G::Int(ref s), Some("Int64")) => { let value = match s.as_i64() { Some(v) => v, // Ignore bad values. 
Let graphql solver handle that _ => continue, }; Int64(value) } (G::Int(ref s), Some("Bigint")) => BigInt(s.as_bigint().to_string()), (G::Float(s), Some("Float")) => Decimal(s.clone()), (G::Float(s), Some("Decimal")) => Decimal(s.clone()), (G::Boolean(s), Some("Boolean")) => Boolean(*s), // other types are unsupported _ => continue, }; for tok in src_tokens.drain_to(dvalue.span.0.token) { tokens.push(PyToken::new(tok)?); } if !matches!(var.var_type, Type::NonNullType(..)) { tokens.push(PyToken { kind: P::Bang, value: "!".into(), position: None, }); } // first token is needed for errors, others are discarded let pair = src_tokens .drain_to(dvalue.span.1.token) .next() .expect("at least one token of default value"); defaults.insert( var.name.to_owned(), Variable { value, token: PyToken::new(pair)?, }, ); } } for tok in src_tokens.drain_to(oper.insert_variables.position.token) { tokens.push(PyToken::new(tok)?); } let mut args = Vec::new(); let mut tmp = Vec::with_capacity(oper.selection_set.span.1.token - tokens.len()); for tok in src_tokens.drain_to(oper.selection_set.span.0.token) { tmp.push(PyToken::new(tok)?); } for (token, pos) in src_tokens.drain_to(oper.selection_set.span.1.token) { match token.kind { StringValue | BlockString => { let var_name = format!("__edb_arg_{}", variables.len()); tmp.push(PyToken { kind: P::Dollar, value: "$".into(), position: None, }); tmp.push(PyToken { kind: P::Name, value: var_name.clone().into(), position: None, }); variables.push(Variable { token: PyToken::new(&(*token, *pos))?, value: Str(unquote_string(token.value)?), }); push_var_definition(&mut args, &var_name, "String"); continue; } IntValue => { if token.value == "1" && pos.token > 2 && all_src_tokens[pos.token - 1].0.kind == Punctuator && all_src_tokens[pos.token - 1].0.value == ":" && all_src_tokens[pos.token - 2].0.kind == Name && all_src_tokens[pos.token - 2].0.value == "first" { // skip `first: 1` as this is used to fetch singleton // properties from queries where 
literal `LIMIT 1` // should be present tmp.push(PyToken::new(&(*token, *pos))?); continue; } let var_name = format!("__edb_arg_{}", variables.len()); tmp.push(PyToken { kind: P::Dollar, value: "$".into(), position: None, }); tmp.push(PyToken { kind: P::Name, value: var_name.clone().into(), position: None, }); let (value, typ) = if let Ok(val) = token.value.parse::<i64>() { if val <= i32::MAX as i64 && val >= i32::MIN as i64 { (Value::Int32(val as i32), "Int") } else { (Value::Int64(val), "Int64") } } else { (Value::BigInt(token.value.into()), "Bigint") }; variables.push(Variable { token: PyToken::new(&(*token, *pos))?, value, }); push_var_definition(&mut args, &var_name, typ); continue; } FloatValue => { let var_name = format!("__edb_arg_{}", variables.len()); tmp.push(PyToken { kind: P::Dollar, value: "$".into(), position: None, }); tmp.push(PyToken { kind: P::Name, value: var_name.clone().into(), position: None, }); variables.push(Variable { token: PyToken::new(&(*token, *pos))?, value: Value::Decimal(token.value.to_string()), }); push_var_definition(&mut args, &var_name, "Decimal"); continue; } Name if token.value == "true" || token.value == "false" => { let var_name = format!("__edb_arg_{}", variables.len()); tmp.push(PyToken { kind: P::Dollar, value: "$".into(), position: None, }); tmp.push(PyToken { kind: P::Name, value: var_name.clone().into(), position: None, }); variables.push(Variable { token: PyToken::new(&(*token, *pos))?, value: Value::Boolean(token.value == "true"), }); push_var_definition(&mut args, &var_name, "Boolean"); continue; } _ => {} } tmp.push(PyToken::new(&(*token, *pos))?); } insert_args(&mut tokens, &oper.insert_variables, args); tokens.extend(tmp); for tok in src_tokens.drain(src_tokens.len()) { tokens.push(PyToken::new(tok)?); } Ok(Entry { key: join_tokens(&tokens), variables, defaults, tokens, end_pos, num_variables: oper.variable_definitions.len(), }) } impl From<ParseError> for Error { fn from(v: ParseError) -> Error { Error::Syntax(v) } } impl<'a>
From<combine::easy::Error<Token<'a>, Token<'a>>> for Error { fn from(v: combine::easy::Error<Token<'a>, Token<'a>>) -> Error { Error::Lexing(v.to_string()) } } fn token_array(s: &str) -> Result<(Vec<(Token, Pos)>, Pos), Error> { let mut lexer = TokenStream::new(s); let mut tokens = Vec::new(); let mut pos = lexer.position(); loop { match lexer.uncons() { Ok(token) => { tokens.push((token, pos)); pos = lexer.position(); } Err(ref e) if e == &combine::easy::Error::end_of_input() => break, Err(e) => panic!("Parse error at {}: {}", lexer.position(), e), } } Ok((tokens, lexer.position())) } fn find_operation<'a>( document: &'a Document<'a, &'a str>, operation: &str, ) -> Option<&'a Operation<'a, &'a str>> { for def in &document.definitions { let res = match def { Definition::Operation(ref op) if op.name == Some(operation) => op, _ => continue, }; return Some(res); } None } fn insert_args(dest: &mut Vec<PyToken>, ins: &InsertVars, args: Vec<PyToken>) { use crate::py_token::PyTokenKind as P; if args.is_empty() { return; } if ins.kind == InsertVarsKind::Query { dest.push(PyToken { kind: P::Name, value: "query".into(), position: None, }); } if ins.kind != InsertVarsKind::Normal { dest.push(PyToken { kind: P::ParenL, value: "(".into(), position: None, }); } dest.extend(args); if ins.kind != InsertVarsKind::Normal { dest.push(PyToken { kind: P::ParenR, value: ")".into(), position: None, }); } } fn type_name<'x>(var_type: &'x Type<'x, &'x str>) -> Option<&'x str> { match var_type { Type::NamedType(t) => Some(t), Type::NonNullType(b) => type_name(b), _ => None, } } fn push_var_definition(args: &mut Vec<PyToken>, var_name: &str, var_type: &'static str) { use crate::py_token::PyTokenKind as P; args.push(PyToken { kind: P::Dollar, value: "$".into(), position: None, }); args.push(PyToken { kind: P::Name, value: var_name.to_owned().into(), position: None, }); args.push(PyToken { kind: P::Colon, value: ":".into(), position: None, }); args.push(PyToken { kind: P::Name, value: var_type.into(), position: None, }); args.push(PyToken { kind: P::Bang,
value: "!".into(), position: None, }); } fn visit_directives<'x>(value_positions: &mut HashSet<usize>, oper: &'x Operation<'x, &'x str>) { for dir in oper.selection_set.visit::<Directive<_>>() { if dir.name == "include" || dir.name == "skip" { for arg in &dir.arguments { if let GqlValue::Boolean(_) = arg.value { value_positions.insert(arg.value_position.token); } } } } } fn join_tokens<'a, I: IntoIterator<Item = &'a PyToken>>(tokens: I) -> String { let mut buf = String::new(); let mut needs_whitespace = false; for token in tokens { match (token.kind, needs_whitespace) { // space before punctuators is optional (PyTokenKind::ParenL, true) => {} (PyTokenKind::ParenR, true) => {} (PyTokenKind::Spread, true) => {} (PyTokenKind::Colon, true) => {} (PyTokenKind::Equals, true) => {} (PyTokenKind::At, true) => {} (PyTokenKind::BracketL, true) => {} (PyTokenKind::BracketR, true) => {} (PyTokenKind::BraceL, true) => {} (PyTokenKind::BraceR, true) => {} (PyTokenKind::Pipe, true) => {} (PyTokenKind::Bang, true) => {} (_, true) => buf.push(' '), (_, false) => {} } buf.push_str(&token.value); needs_whitespace = match token.kind { PyTokenKind::Dollar => false, PyTokenKind::Bang => false, PyTokenKind::ParenL => false, PyTokenKind::ParenR => false, PyTokenKind::Spread => false, PyTokenKind::Colon => false, PyTokenKind::Equals => false, PyTokenKind::At => false, PyTokenKind::BracketL => false, PyTokenKind::BracketR => false, PyTokenKind::BraceL => false, PyTokenKind::BraceR => false, PyTokenKind::Pipe => false, PyTokenKind::Int => true, PyTokenKind::Float => true, PyTokenKind::String => true, PyTokenKind::BlockString => true, PyTokenKind::Name => true, PyTokenKind::Eof => unreachable!(), PyTokenKind::Sof => unreachable!(), }; } buf } ================================================ FILE: edb/graphql-rewrite/src/token_vec.rs ================================================ use edb_graphql_parser::position::Pos; use edb_graphql_parser::tokenizer::Token; pub struct TokenVec<'a> { tokens: &'a Vec<(Token<'a>, Pos)>, consumed:
usize, } impl<'a> TokenVec<'a> { pub fn new(tokens: &'a Vec<(Token<'a>, Pos)>) -> TokenVec<'a> { TokenVec { tokens, consumed: 0, } } pub fn drain(&mut self, n: usize) -> impl Iterator<Item = &(Token, Pos)> { let pos = self.consumed; self.consumed += n; assert!(n <= self.tokens.len(), "attempt to drain more tokens than exist"); self.tokens[pos..][..n].iter() } pub fn drain_to(&mut self, end: usize) -> impl Iterator<Item = &(Token, Pos)> { let n = end .checked_sub(self.consumed) .expect("drain_to with index smaller than current"); self.drain(n) } pub fn len(&self) -> usize { self.tokens .len() .checked_sub(self.consumed) .expect("consumed more tokens than exist") } } ================================================ FILE: edb/graphql-rewrite/tests/rewrite.rs ================================================ #![cfg(feature = "python_extension")] use std::collections::BTreeMap; use edb_graphql_parser::Pos; use graphql_rewrite::{rewrite, Value, Variable}; use graphql_rewrite::{PyToken, PyTokenKind}; #[test] fn test_no_args() { let entry = rewrite( None, r#" query { object(filter: {field: {eq: "test"}}) { field } } "#, ) .unwrap(); assert_eq!( entry.key, "\ query($__edb_arg_0:String!){\ object(filter:{field:{eq:$__edb_arg_0}}){\ field\ }\ }\ " ); assert_eq!( entry.variables, vec![Variable { token: PyToken { kind: PyTokenKind::String, value: r#""test""#.into(), position: Some(Pos { line: 3, column: 41, character: 57, token: 12 }), }, value: Value::Str("test".into()), }] ); } #[test] fn test_no_query() { let entry = rewrite( None, r#" { object(filter: {field: {eq: "test"}}) { field } } "#, ) .unwrap(); assert_eq!( entry.key, "\ query($__edb_arg_0:String!){\ object(filter:{field:{eq:$__edb_arg_0}}){\ field\ }\ }\ " ); assert_eq!( entry.variables, vec![Variable { token: PyToken { kind: PyTokenKind::String, value: r#""test""#.into(), position: Some(Pos { line: 3, column: 41, character: 51, token: 11 }), }, value: Value::Str("test".into()), }] ); } #[test] fn test_no_name() { let entry = rewrite( None, r#" query($x: String) {
object(filter: {field: {eq: "test"}}, y: $x) { field } } "#, ) .unwrap(); assert_eq!( entry.key, "\ query($x:String $__edb_arg_0:String!){\ object(filter:{field:{eq:$__edb_arg_0}}y:$x){\ field\ }\ }\ " ); assert_eq!( entry.variables, vec![Variable { token: PyToken { kind: PyTokenKind::String, value: r#""test""#.into(), position: Some(Pos { line: 3, column: 41, character: 69, token: 18 }), }, value: Value::Str("test".into()), }] ); } #[test] fn test_name_args() { let entry = rewrite( Some("Hello"), r#" query Hello($x: String, $y: String!) { object(filter: {field: {eq: "test"}}, x: $x, y: $y) { field } } "#, ) .unwrap(); assert_eq!( entry.key, "\ query Hello($x:String $y:String!$__edb_arg_0:String!){\ object(filter:{field:{eq:$__edb_arg_0}}x:$x y:$y){\ field\ }\ }\ " ); assert_eq!( entry.variables, vec![Variable { token: PyToken { kind: PyTokenKind::String, value: r#""test""#.into(), position: Some(Pos { line: 3, column: 41, character: 88, token: 24 }), }, value: Value::Str("test".into()), }] ); } #[test] fn test_name() { let entry = rewrite( Some("Hello"), r#" query Hello { object(filter: {field: {eq: "test"}}) { field } } "#, ) .unwrap(); assert_eq!( entry.key, "\ query Hello($__edb_arg_0:String!){\ object(filter:{field:{eq:$__edb_arg_0}}){\ field\ }\ }\ " ); assert_eq!( entry.variables, vec![Variable { token: PyToken { kind: PyTokenKind::String, value: r#""test""#.into(), position: Some(Pos { line: 3, column: 41, character: 63, token: 13 }), }, value: Value::Str("test".into()), }] ); } #[test] fn test_default_name() { let entry = rewrite( None, r#" query Hello { object(filter: {field: {eq: "test"}}) { field } } "#, ) .unwrap(); assert_eq!( entry.key, "\ query Hello($__edb_arg_0:String!){\ object(filter:{field:{eq:$__edb_arg_0}}){\ field\ }\ }\ " ); assert_eq!( entry.variables, vec![Variable { token: PyToken { kind: PyTokenKind::String, value: r#""test""#.into(), position: Some(Pos { line: 3, column: 41, character: 63, token: 13 }), }, value: 
Value::Str("test".into()), }] ); } #[test] fn test_other() { let entry = rewrite( Some("Hello"), r#" query Other { object(filter: {field: {eq: "test1"}}) { field } } query Hello { object(filter: {field: {eq: "test2"}}) { field } } "#, ) .unwrap(); assert_eq!( entry.key, "\ query Other{\ object(filter:{field:{eq:\"test1\"}}){\ field\ }\ }\ query Hello($__edb_arg_0:String!){\ object(filter:{field:{eq:$__edb_arg_0}}){\ field\ }\ }\ " ); assert_eq!( entry.variables, vec![Variable { token: PyToken { kind: PyTokenKind::String, value: r#""test2""#.into(), position: Some(Pos { line: 8, column: 41, character: 184, token: 34 }), }, value: Value::Str("test2".into()), }] ); } #[test] fn test_defaults() { let entry = rewrite( Some("Hello"), r#" query Hello($x: String = "xxx", $y: String! = "yyy") { object(filter: {field: {eq: "test"}}, x: $x, y: $y) { field } } "#, ) .unwrap(); assert_eq!( entry.key, "\ query Hello($x:String!$y:String!$__edb_arg_0:String!){\ object(filter:{field:{eq:$__edb_arg_0}}x:$x y:$y){\ field\ }\ }\ " ); let mut defaults = BTreeMap::new(); defaults.insert( "x".to_owned(), Variable { value: Value::Str("xxx".into()), token: PyToken { kind: PyTokenKind::Equals, value: "=".into(), position: Some(Pos { line: 2, column: 32, character: 32, token: 7, }), }, }, ); defaults.insert( "y".to_owned(), Variable { value: Value::Str("yyy".into()), token: PyToken { kind: PyTokenKind::Equals, value: "=".into(), position: Some(Pos { line: 2, column: 53, character: 53, token: 14, }), }, }, ); assert_eq!(entry.defaults, defaults); } #[test] fn test_int32() { let entry = rewrite( None, r###" query { object(filter: {field: {eq: 17}}) { field } } "###, ) .unwrap(); assert_eq!( entry.key, "\ query($__edb_arg_0:Int!){\ object(filter:{field:{eq:$__edb_arg_0}}){\ field\ }\ }\ " ); assert_eq!( entry.variables, vec![Variable { token: PyToken { kind: PyTokenKind::Int, value: r#"17"#.into(), position: Some(Pos { line: 3, column: 41, character: 57, token: 12 }), }, value: 
Value::Int32(17), }] ); } #[test] fn test_int64() { let entry = rewrite( None, r###" query { object(filter: {field: {eq: 17123456790}}) { field } } "###, ) .unwrap(); assert_eq!( entry.key, "\ query($__edb_arg_0:Int64!){\ object(filter:{field:{eq:$__edb_arg_0}}){\ field\ }\ }\ " ); assert_eq!( entry.variables, vec![Variable { token: PyToken { kind: PyTokenKind::Int, value: r#"17123456790"#.into(), position: Some(Pos { line: 3, column: 41, character: 57, token: 12 }), }, value: Value::Int64(17123456790), }] ); } #[test] fn test_bigint() { let entry = rewrite( None, r###" query { object(filter: {field: {eq: 171234567901234567890}}) { field } } "###, ) .unwrap(); assert_eq!( entry.key, "\ query($__edb_arg_0:Bigint!){\ object(filter:{field:{eq:$__edb_arg_0}}){\ field\ }\ }\ " ); assert_eq!( entry.variables, vec![Variable { token: PyToken { kind: PyTokenKind::Int, value: r#"171234567901234567890"#.into(), position: Some(Pos { line: 3, column: 41, character: 57, token: 12 }), }, value: Value::BigInt("171234567901234567890".into()), }] ); } #[test] fn test_first_1() { let entry = rewrite( None, r###" query { object(filter: {field: {eq: 1}}, first: 1) { field } } "###, ) .unwrap(); assert_eq!( entry.key, "\ query($__edb_arg_0:Int!){\ object(filter:{field:{eq:$__edb_arg_0}}first:1){\ field\ }\ }\ " ); assert_eq!( entry.variables, vec![Variable { token: PyToken { kind: PyTokenKind::Int, value: r#"1"#.into(), position: Some(Pos { line: 3, column: 41, character: 57, token: 12 }), }, value: Value::Int32(1), }] ); } #[test] fn test_first_2() { let entry = rewrite( None, r###" query { object(filter: {field: {eq: 1}}, first: 2) { field } } "###, ) .unwrap(); assert_eq!( entry.key, "\ query($__edb_arg_0:Int!$__edb_arg_1:Int!){\ object(filter:{field:{eq:$__edb_arg_0}}first:$__edb_arg_1){\ field\ }\ }\ " ); assert_eq!( entry.variables, vec![ Variable { token: PyToken { kind: PyTokenKind::Int, value: r#"1"#.into(), position: Some(Pos { line: 3, column: 41, character: 57, token: 12 }), 
}, value: Value::Int32(1), }, Variable { token: PyToken { kind: PyTokenKind::Int, value: r#"2"#.into(), position: Some(Pos { line: 3, column: 53, character: 69, token: 17 }), }, value: Value::Int32(2), }, ] ); } #[test] fn test_defaults_int() { let entry = rewrite( Some("Hello"), r###" query Hello($x: Int = 123, $y: Int! = 1234) { object(x: $x, y: $y) { field } } "###, ) .unwrap(); assert_eq!( entry.key, "\ query Hello($x:Int!$y:Int!){\ object(x:$x y:$y){\ field\ }\ }\ " ); let mut defaults = BTreeMap::new(); defaults.insert( "x".to_owned(), Variable { value: Value::Int32(123), token: PyToken { kind: PyTokenKind::Equals, value: "=".into(), position: Some(Pos { line: 2, column: 29, character: 29, token: 7, }), }, }, ); defaults.insert( "y".to_owned(), Variable { value: Value::Int32(1234), token: PyToken { kind: PyTokenKind::Equals, value: "=".into(), position: Some(Pos { line: 2, column: 45, character: 45, token: 14, }), }, }, ); assert_eq!(entry.defaults, defaults); } #[test] fn test_float() { let entry = rewrite( None, r###" query { object(filter: {field: {eq: 17.25}}) { field } } "###, ) .unwrap(); assert_eq!( entry.key, "\ query($__edb_arg_0:Decimal!){\ object(filter:{field:{eq:$__edb_arg_0}}){\ field\ }\ }\ " ); assert_eq!( entry.variables, vec![Variable { token: PyToken { kind: PyTokenKind::Float, value: r#"17.25"#.into(), position: Some(Pos { line: 3, column: 41, character: 57, token: 12 }), }, value: Value::Decimal("17.25".into()), }] ); } #[test] fn test_defaults_float() { let entry = rewrite( Some("Hello"), r###" query Hello($x: Float = 123.25, $y: Float! 
= 1234.75) { object(x: $x, y: $y) { field } } "###, ) .unwrap(); assert_eq!( entry.key, "\ query Hello($x:Float!$y:Float!){\ object(x:$x y:$y){\ field\ }\ }\ " ); let mut defaults = BTreeMap::new(); defaults.insert( "x".to_owned(), Variable { value: Value::Decimal("123.25".into()), token: PyToken { kind: PyTokenKind::Equals, value: "=".into(), position: Some(Pos { line: 2, column: 31, character: 31, token: 7, }), }, }, ); defaults.insert( "y".to_owned(), Variable { value: Value::Decimal("1234.75".into()), token: PyToken { kind: PyTokenKind::Equals, value: "=".into(), position: Some(Pos { line: 2, column: 52, character: 52, token: 14, }), }, }, ); assert_eq!(entry.defaults, defaults); } #[test] fn test_defaults_bool() { let entry = rewrite( Some("Hello"), r###" query Hello($x: Boolean = true, $y: Boolean! = false) { object(x: $x, y: $y) { field } } "###, ) .unwrap(); assert_eq!( entry.key, "\ query Hello($x:Boolean!$y:Boolean!){\ object(x:$x y:$y){\ field\ }\ }\ " ); let mut defaults = BTreeMap::new(); defaults.insert( "x".to_owned(), Variable { value: Value::Boolean(true), token: PyToken { kind: PyTokenKind::Equals, value: "=".into(), position: Some(Pos { line: 2, column: 33, character: 33, token: 7, }), }, }, ); defaults.insert( "y".to_owned(), Variable { value: Value::Boolean(false), token: PyToken { kind: PyTokenKind::Equals, value: "=".into(), position: Some(Pos { line: 2, column: 54, character: 54, token: 14, }), }, }, ); assert_eq!(entry.defaults, defaults); } #[test] fn test_include_skip() { let entry = rewrite( Some("Hello"), r###" query Hello($x: Boolean = true) { object { hello @include(if: $x) world @skip(if: true) } } "###, ) .unwrap(); assert_eq!( entry.key, "\ query Hello($x:Boolean!$__edb_arg_0:Boolean!){\ object{\ hello@include(if:$x)\ world@skip(if:$__edb_arg_0)\ }\ }\ " ); let mut defaults = BTreeMap::new(); defaults.insert( "x".to_owned(), Variable { value: Value::Boolean(true), token: PyToken { kind: PyTokenKind::Equals, value: "=".into(), 
position: Some(Pos { line: 2, column: 33, character: 33, token: 7, }), }, }, ); assert_eq!(entry.defaults, defaults); assert_eq!( entry.variables, vec![Variable { token: PyToken { kind: PyTokenKind::Name, value: r#"true"#.into(), position: Some(Pos { line: 5, column: 33, character: 135, token: 28 }), }, value: Value::Boolean(true), }] ); } ================================================ FILE: edb/ir/__init__.py ================================================ ## # Copyright (c) 2008-present MagicStack Inc. # All rights reserved. # # See LICENSE for details. ## from __future__ import annotations ================================================ FILE: edb/ir/ast.py ================================================ # mypy: implicit-reexport # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """IR expression tree node definitions. The IR expression tree is produced by the EdgeQL compiler (see :mod:`edgeql.compiler`). It is a self-contained representation of an EdgeQL expression, which, together with the accompanying scope tree (:mod:`ir.scopetree`) is sufficient to produce a backend query (e.g. SQL) without any other input or context. The most common part of the IR expression tree is the :class:`~Set` class. 
Every expression is encoded as a ``Set`` instance that contains all common metadata, such as the expression type, its symbolic identity (PathId) and other useful bits. The ``Set.expr`` field contains the specific node for the expression. The expression nodes usually refer to ``Set`` nodes rather than other nodes directly. For example, the EdgeQL expression ``SELECT str_lower('ABC') ++ 'd'`` yields the following IR (roughly): Set ( expr = SelectStmt ( result = Set ( expr = OperatorCall ( args = [ CallArg ( expr = Set ( expr = FunctionCall ( args = [ CallArg ( expr = Set ( expr = StringConstant ( value = 'ABC' ) ), ), CallArg ( expr = Set ( expr = StringConstant ( value = 'd' ) ), ) ] ) ) ) ] ) ) ) ) """ from __future__ import annotations import abc import dataclasses import typing import uuid from edb.common import ast, compiler, span, markup, enum as s_enum from edb import errors from edb.schema import modules as s_mod from edb.schema import name as sn from edb.schema import objects as so from edb.schema import objtypes as s_objtypes from edb.schema import permissions as s_permissions from edb.schema import pointers as s_pointers from edb.schema import schema as s_schema from edb.schema import types as s_types from edb.edgeql import ast as qlast from edb.edgeql import qltypes from .pathid import PathId, Namespace # noqa from .scopetree import ScopeTreeNode # noqa Span = span.Span def new_scope_tree() -> ScopeTreeNode: return ScopeTreeNode(fenced=True) class Base(ast.AST): __abstract_node__ = True __ast_hidden__ = {'span'} span: typing.Optional[Span] = None def __repr__(self) -> str: return ( f'' ) class ImmutableBase(ast.ImmutableASTMixin, Base): __abstract_node__ = True class ViewShapeMetadata(Base): has_implicit_id: bool = False class TypeRef(ImmutableBase): # Hide ancestors and children from debug spew because they are # incredibly noisy. 
__ast_hidden__ = {'ancestors', 'children'} # The id of the referenced type id: uuid.UUID # Full name of the type, not necessarily schema-addressable, # used for annotations only. name_hint: sn.Name # Name hint of the real underlying type, if the type ref was created # with an explicitly specified typename. orig_name_hint: typing.Optional[sn.Name] = None # The ref of the underlying material type, if this is a view type, # else None. material_type: typing.Optional[TypeRef] = None # If this is a scalar type, base_type would be the highest # non-abstract base type. base_type: typing.Optional[TypeRef] = None # A set of type children descriptors, if necessary for # this type description. children: typing.Optional[frozenset[TypeRef]] = None # A set of type ancestor descriptors, if necessary for # this type description. ancestors: typing.Optional[frozenset[TypeRef]] = None # If this is a compound type, this is a non-overlapping set of # constituent types. union: typing.Optional[frozenset[TypeRef]] = None # Whether the union is specified by an exhaustive list of # types, and type inheritance should not be considered. union_is_exhaustive: bool = False # If this is a complex type, record the expression used to generate the # type. This is used later to get the correct rvar in `get_path_var`. expr_intersection: typing.Optional[frozenset[TypeRef]] = None expr_union: typing.Optional[frozenset[TypeRef]] = None # If this node is an element of a collection, and the # collection elements are named, this would be the # name of the element. element_name: typing.Optional[str] = None # The kind of the collection type if this is a collection collection: typing.Optional[str] = None # Collection subtypes if this is a collection subtypes: tuple[TypeRef, ...]
= () # True, if this describes a scalar type is_scalar: bool = False # True, if this describes a view is_view: bool = False # True, if this describes a cfg view is_cfg_view: bool = False # True, if this describes an abstract type is_abstract: bool = False # True, if the collection type is persisted in the schema in_schema: bool = False # True, if this describes an opaque union type is_opaque_union: bool = False # Does this need to call a custom json cast function needs_custom_json_cast: bool = False # If this has a schema-configured backend type, what is it sql_type: typing.Optional[str] = None # If this has a schema-configured custom sql serialization, what is it custom_sql_serialization: typing.Optional[str] = None def __repr__(self) -> str: return f'' @property def real_material_type(self) -> TypeRef: return self.material_type or self @property def real_base_type(self) -> TypeRef: return self.base_type or self def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): return False return self.id == other.id def __hash__(self) -> int: return hash(self.id) class AnyTypeRef(TypeRef): pass class AnyTupleRef(TypeRef): pass class AnyObjectRef(TypeRef): pass class BasePointerRef(ImmutableBase): __abstract_node__ = True # Hide children to reduce noise __ast_hidden__ = {'children'} # cardinality fields need to be mutable for lazy cardinality inference. # and children because we update pointers with newly derived children __ast_mutable_fields__ = frozenset( ('in_cardinality', 'out_cardinality', 'children', 'is_computable') ) # The defaults set here are mostly to try to reduce debug spew output. 
name: sn.QualName shortname: sn.QualName std_parent_name: typing.Optional[sn.QualName] = None out_source: TypeRef out_target: TypeRef source_ptr: typing.Optional[PointerRef] = None base_ptr: typing.Optional[BasePointerRef] = None material_ptr: typing.Optional[BasePointerRef] = None children: frozenset[BasePointerRef] = frozenset() union_components: typing.Optional[set[BasePointerRef]] = None intersection_components: typing.Optional[set[BasePointerRef]] = None union_is_exhaustive: bool = False has_properties: bool = False is_derived: bool = False is_computable: bool = False # Outbound cardinality of the pointer. out_cardinality: qltypes.Cardinality # Inbound cardinality of the pointer. in_cardinality: qltypes.Cardinality = qltypes.Cardinality.MANY defined_here: bool = False computed_link_alias: typing.Optional[BasePointerRef] = None computed_link_alias_is_backward: typing.Optional[bool] = None def dir_target(self, direction: s_pointers.PointerDirection) -> TypeRef: if direction is s_pointers.PointerDirection.Outbound: return self.out_target else: return self.out_source def dir_source(self, direction: s_pointers.PointerDirection) -> TypeRef: if direction is s_pointers.PointerDirection.Outbound: return self.out_source else: return self.out_target def dir_cardinality( self, direction: s_pointers.PointerDirection ) -> qltypes.Cardinality: if direction is s_pointers.PointerDirection.Outbound: return self.out_cardinality else: return self.in_cardinality @property def required(self) -> bool: return self.out_cardinality.to_schema_value()[0] def descendants(self) -> set[BasePointerRef]: res = set(self.children) for child in self.children: res.update(child.descendants()) return res @property def real_material_ptr(self) -> BasePointerRef: return self.material_ptr or self @property def real_base_ptr(self) -> BasePointerRef: return self.base_ptr or self def __repr__(self) -> str: return f'' class PointerRef(BasePointerRef): id: uuid.UUID class ConstraintRef(ImmutableBase): # The 
id of the constraint id: uuid.UUID class TupleIndirectionLink(s_pointers.PseudoPointer): """A Link-alike that can be used in tuple indirection path ids.""" def __init__( self, source: so.Object, target: s_types.Type, *, element_name: str, ) -> None: self._source = source self._target = target self._name = sn.QualName( module='__tuple__', name=str(element_name)) def __hash__(self) -> int: return hash((self.__class__, self._source, self._name)) def __eq__(self, other: typing.Any) -> bool: if not isinstance(other, self.__class__): return False return self._source == other._source and self._name == other._name def get_name(self, schema: s_schema.Schema) -> sn.QualName: return self._name def get_cardinality( self, schema: s_schema.Schema ) -> qltypes.SchemaCardinality: return qltypes.SchemaCardinality.One def singular( self, schema: s_schema.Schema, direction: s_pointers.PointerDirection = s_pointers.PointerDirection.Outbound ) -> bool: return True def scalar(self) -> bool: return self._target.is_scalar() def get_source(self, schema: s_schema.Schema) -> so.Object: return self._source def get_target(self, schema: s_schema.Schema) -> s_types.Type: return self._target def is_tuple_indirection(self) -> bool: return True def get_computable(self, schema: s_schema.Schema) -> bool: return False class TupleIndirectionPointerRef(BasePointerRef): pass class SpecialPointerRef(BasePointerRef): """Pointer ref used for internal columns, such as __fts_document__""" pass class TypeIntersectionLink(s_pointers.PseudoPointer): """A Link-alike that can be used in type intersection path ids.""" def __init__( self, source: so.Object, target: s_types.Type, *, optional: bool, is_empty: bool, is_subtype: bool, rptr_specialization: typing.Iterable[PointerRef] = (), cardinality: qltypes.SchemaCardinality, ) -> None: name = 'optindirection' if optional else 'indirection' self._name = sn.QualName(module='__type__', name=name) self._source = source self._target = target self._cardinality = 
cardinality self._optional = optional self._is_empty = is_empty self._is_subtype = is_subtype self._rptr_specialization = frozenset(rptr_specialization) def get_name(self, schema: s_schema.Schema) -> sn.QualName: return self._name def get_cardinality( self, schema: s_schema.Schema ) -> qltypes.SchemaCardinality: return self._cardinality def get_computable(self, schema: s_schema.Schema) -> bool: return False def is_type_intersection(self) -> bool: return True def is_optional(self) -> bool: return self._optional def is_empty(self) -> bool: return self._is_empty def is_subtype(self) -> bool: return self._is_subtype def get_rptr_specialization(self) -> frozenset[PointerRef]: return self._rptr_specialization def get_source(self, schema: s_schema.Schema) -> so.Object: return self._source def get_target(self, schema: s_schema.Schema) -> s_types.Type: return self._target def singular( self, schema: s_schema.Schema, direction: s_pointers.PointerDirection = s_pointers.PointerDirection.Outbound ) -> bool: if direction is s_pointers.PointerDirection.Outbound: return (self.get_cardinality(schema) is qltypes.SchemaCardinality.One) else: return True def scalar(self) -> bool: return self._target.is_scalar() class TypeIntersectionPointerRef(BasePointerRef): optional: bool is_empty: bool is_subtype: bool rptr_specialization: frozenset[PointerRef] class Expr(Base): __abstract_node__ = True if typing.TYPE_CHECKING: @property @abc.abstractmethod def typeref(self) -> TypeRef: raise NotImplementedError # Sets to materialize at this point, keyed by the type/ptr id. materialized_sets: typing.Optional[ dict[uuid.UUID, MaterializedSet]] = None class Pointer(Expr): source: Set ptrref: BasePointerRef direction: s_pointers.PointerDirection # Whether to make this an optional deref (written '.?>') that # suppresses any error due to looking at a required link hidden by # a policy. optional_deref: bool = False # Whether to *always* use a link table when this pointer is # accessed.
This is needed (for example) when a (possibly single) # link property is being referenced in a FOR iterator, and we # aren't going to have access to the Pointer when we access # the iterator variable. force_link_table: bool = False # If the pointer is a computed pointer (or a computed pointer # definition), the expression. expr: typing.Optional[Expr] = None is_definition: bool # Set when we have placed an rptr to help route link properties # but it is not a genuine pointer use. is_phony: bool = False anchor: typing.Optional[str] = None show_as_anchor: typing.Optional[str] = None is_mutation: bool = False @property def is_inbound(self) -> bool: return self.direction == s_pointers.PointerDirection.Inbound @property def dir_cardinality(self) -> qltypes.Cardinality: return self.ptrref.dir_cardinality(self.direction) @property def typeref(self) -> TypeRef: return self.ptrref.dir_target(self.direction) class TypeIntersectionPointer(Pointer): optional: bool ptrref: TypeIntersectionPointerRef is_definition: bool = False class TupleIndirectionPointer(Pointer): ptrref: TupleIndirectionPointerRef is_definition: bool = False class ImmutableExpr(Expr, ImmutableBase): __abstract_node__ = True class BindingKind(s_enum.StrEnum): With = 'With' For = 'For' Select = 'Select' Schema = 'Schema' class TypeRoot(Expr): # This will be replicated in the enclosing set. typeref: TypeRef # Whether this is a reference to a global that is cached in a # materialized CTE in the query. 
is_cached_global: bool = False # Whether to force this to not select subtypes skip_subtypes: bool = False class RefExpr(Expr): '''Different sorts of expressions that refer to some kind of binding.''' __abstract_node__ = True typeref: TypeRef class MaterializedExpr(RefExpr): pass class VisibleBindingExpr(RefExpr): pass class InlinedParameterExpr(RefExpr): required: bool is_global: bool T_expr_co = typing.TypeVar('T_expr_co', covariant=True, bound=Expr) # SetE is the base 'Set' type, and it is parameterized over what kind # of expression it holds. Most code uses the Set alias below, which # instantiates it with Expr. # irutils.is_set_instance can be used to refine the type. class SetE(Base, typing.Generic[T_expr_co]): # noqa: UP046 '''A somewhat overloaded metadata container for expressions. Its primary purpose is to be the holder for expression metadata such as path_id. It *also* contains shape applications. ''' __ast_frozen_fields__ = frozenset({'typeref'}) # N.B: Make sure to add new fields to setgen.new_set_from_set! path_id: PathId path_scope_id: typing.Optional[int] = None typeref: TypeRef expr: T_expr_co shape: tuple[tuple[SetE[Pointer], qlast.ShapeOp], ...] = () anchor: typing.Optional[str] = None show_as_anchor: typing.Optional[str] = None # A pointer to a set nested within this one that has a shape and the same # typeref, if such a set exists. shape_source: typing.Optional[Set] = None is_binding: typing.Optional[BindingKind] = None is_schema_alias: bool = False is_materialized_ref: bool = False # A ref to a visible binding (like a for iterator variable) should # never need to be compiled--it should always be found. We set a # flag instead of clearing expr because clearing expr can mess up # card/multi inference. is_visible_binding_ref: bool = False # Whether to force this to ignore rewrites. Very dangerous! # Currently for preventing duplicate explicit .id # insertions to BaseObject and for ignoring other access policies # inside access policy expressions.
# # N.B: This is defined on Set and not on TypeRoot because we use the Set # to join against target types on links, and to ensure rvars. ignore_rewrites: bool = False # Is this Set a dummy introduced by simple_scoping to protect a # path from factoring? We track this because we try to collapse # these extra scopes away when they are not needed, at the end of # compilation. is_factoring_protected: bool = False def __repr__(self) -> str: return f'' # We set its name to Set because that's what we want visitors to use. SetE.__name__ = 'Set' if typing.TYPE_CHECKING: Set = SetE[Expr] else: Set = SetE DUMMY_SET = Set() # type: ignore[call-arg] class Command(Base): __abstract_node__ = True @dataclasses.dataclass(frozen=True, kw_only=True) class Param: """Query parameter with its schema type and IR type""" name: str """Parameter name""" required: bool """Whether parameter is OPTIONAL or REQUIRED""" schema_type: s_types.Type """Schema type""" ir_type: TypeRef """IR type reference""" sub_params: SubParams | None = None """Sub-parameters containing tuple components. If the param needs to be split into multiple real postgres params in order to implement tuples, this collects those parameters and the decoder expression. """ @property def is_sub_param(self) -> bool: return ( self.name.startswith('__edb_decoded_') and self.name.endswith('__') ) @dataclasses.dataclass(frozen=True, kw_only=True) class SubParams: """Information about sub-parameters needed for tuple components. If the param needs to be split into multiple real postgres params in order to implement tuples, this collects those parameters and the decoder expression. """ trans_type: ParamTransType decoder_edgeql: qlast.Expr params: tuple[Param, ...] decoder_ir: Set | None = None @dataclasses.dataclass(eq=False) class ParamTransType: """Representation of how a tuple-containing parameter type is broken down. 
The key thing here is that each node contains the index of the sub-parameter that the node in the argument type corresponds to. See edgeql.compiler.tuple_args for details. The reason we track this in a separate data structure (instead of just having a dict from TypeRefs to indexes, say) is that TypeRefs will often be shared among identical types, but we need to track different indexes for different components of a type. (For example, if we have a param type `tuple<str, str>`, this gets decomposed into two `str` params, with indexes 0 and 1.) """ typeref: TypeRef idx: int def flatten(self) -> tuple[typing.Any, ...]: """Flatten out the trans type into a tuple representation. The idea here is to produce something that our inner loop in cython can consume efficiently. """ raise NotImplementedError @dataclasses.dataclass(eq=False) class ParamScalar(ParamTransType): cast_to: typing.Optional[TypeRef] = None def flatten(self) -> tuple[typing.Any, ...]: return (int(qltypes.TypeTag.SCALAR), self.idx) @dataclasses.dataclass(eq=False) class ParamTuple(ParamTransType): typs: tuple[tuple[typing.Optional[str], ParamTransType], ...] def flatten(self) -> tuple[typing.Any, ...]: return ( (int(qltypes.TypeTag.TUPLE), self.idx) + tuple(x.flatten() for _, x in self.typs) ) @dataclasses.dataclass(eq=False) class ParamArray(ParamTransType): typ: ParamTransType def flatten(self) -> tuple[typing.Any, ...]: return (int(qltypes.TypeTag.ARRAY), self.idx, self.typ.flatten()) @dataclasses.dataclass(frozen=True) class Global(Param): global_name: sn.QualName """The name of the global""" has_present_arg: bool """Whether this global needs a companion parameter indicating whether the global is present. This is needed when a global has a default but also is optional, and so we need to distinguish "unset" and "set to {}". """ is_permission: bool """Whether this global comes from a Permission. Permissions are injected directly by the server based on the connection role.
""" @dataclasses.dataclass(frozen=True) class ScriptInfo: """Result of preprocessing a script of multiple statements""" params: dict[str, Param] """All parameters in all statements in the script""" schema: s_schema.Schema """The schema after preprocessing. (Collections may have been created.)""" class MaterializeVolatile(Base): pass class MaterializeVisible(Base): __ast_hidden__ = {'sets'} sets: set[tuple[PathId, Set]] path_scope_id: int @markup.serializer.serializer.register(MaterializeVisible) def _serialize_to_markup_mat_vis( ir: MaterializeVisible, *, ctx: typing.Any ) -> typing.Any: # We want to show the path_ids but *not* to show the full sets node = ast.serialize_to_markup(ir, ctx=ctx) fixed = {(x, y.path_id) for x, y in ir.sets} node.add_child(label='uses', node=markup.serialize(fixed, ctx=ctx)) return node MaterializeReason = MaterializeVolatile | MaterializeVisible class ComputableInfo(typing.NamedTuple): qlexpr: qlast.Expr irexpr: typing.Optional[Set | Expr] context: compiler.ContextLevel path_id: PathId path_id_ns: typing.Optional[Namespace] shape_op: qlast.ShapeOp should_materialize: typing.Sequence[MaterializeReason] @dataclasses.dataclass(frozen=True, kw_only=True) class ServerParamConversion: param_name: str conversion_name: str additional_info: tuple[str, ...] # If the parameter is a query parameter, track its script params index. script_param_index: typing.Optional[int] = None # If the parameter is a constant value, pass to directly to the server. 
constant_value: typing.Optional[typing.Any] = None class Statement(Command): expr: Set views: dict[sn.Name, s_types.Type] params: list[Param] globals: list[Global] required_permissions: set[s_permissions.Permission] server_param_conversions: list[ServerParamConversion] server_param_conversion_params: list[Param] cardinality: qltypes.Cardinality volatility: qltypes.Volatility multiplicity: qltypes.Multiplicity stype: s_types.Type view_shapes: dict[so.Object, list[s_pointers.Pointer]] view_shapes_metadata: dict[s_types.Type, ViewShapeMetadata] schema: s_schema.Schema schema_refs: frozenset[so.Object] schema_ref_exprs: typing.Optional[ dict[so.Object, set[qlast.Base]]] scope_tree: ScopeTreeNode dml_exprs: list[qlast.Base] type_rewrites: dict[tuple[uuid.UUID, bool], Set] singletons: list[PathId] triggers: tuple[tuple[Trigger, ...], ...] warnings: tuple[errors.EdgeDBError, ...] unsafe_isolation_dangers: tuple[errors.UnsafeIsolationLevelError, ...] class TypeIntrospection(ImmutableExpr): # The type value to return output_typeref: TypeRef # The type value *of the output* typeref: TypeRef class ConstExpr(Expr): __abstract_node__ = True typeref: TypeRef class EmptySet(ConstExpr): pass class BaseConstant(ConstExpr, ImmutableExpr): __abstract_node__ = True value: typing.Any def __init__( self, *args: typing.Any, typeref: TypeRef, **kwargs: typing.Any, ) -> None: super().__init__(*args, typeref=typeref, **kwargs) if self.typeref is None: raise ValueError('cannot create irast.Constant without a type') if self.value is None: raise ValueError('cannot create irast.Constant without a value') def _init_copy(self) -> BaseConstant: return self.__class__(typeref=self.typeref, value=self.value) class BaseStrConstant(BaseConstant): __abstract_node__ = True value: str class StringConstant(BaseStrConstant): pass class IntegerConstant(BaseStrConstant): pass class FloatConstant(BaseStrConstant): pass class DecimalConstant(BaseStrConstant): pass class BigintConstant(BaseStrConstant): pass 
class BooleanConstant(BaseStrConstant): pass class BytesConstant(BaseConstant): value: bytes class ConstantSet(ConstExpr, ImmutableExpr): elements: tuple[BaseConstant | BaseParameter, ...] class BaseParameter(ImmutableExpr): __abstract_node__ = True name: str required: bool typeref: TypeRef # None means not a global. Otherwise, whether this is an implicitly # created global for a function call. is_implicit_global: typing.Optional[bool] = None @property def is_global(self) -> bool: return self.is_implicit_global is not None class QueryParameter(BaseParameter): pass class FunctionParameter(BaseParameter): pass class TupleElement(ImmutableBase): name: str val: Set path_id: typing.Optional[PathId] = None class Tuple(ImmutableExpr): named: bool = False elements: list[TupleElement] typeref: TypeRef class Array(ImmutableExpr): elements: typing.Sequence[Set] typeref: TypeRef class TypeCheckOp(ImmutableExpr): left: Set right: TypeRef op: str result: typing.Optional[bool] = None typeref: TypeRef class SortExpr(Base): expr: Set direction: typing.Optional[qlast.SortOrder] nones_order: typing.Optional[qlast.NonesOrder] class CallArg(ImmutableBase): """Call argument.""" # cardinality fields need to be mutable for lazy cardinality inference. __ast_mutable_fields__ = frozenset(('cardinality', 'multiplicity')) expr: Set """PathId for the __type__ link of object type arguments.""" expr_type_path_id: typing.Optional[PathId] = None cardinality: qltypes.Cardinality = qltypes.Cardinality.UNKNOWN multiplicity: qltypes.Multiplicity = qltypes.Multiplicity.UNKNOWN is_default: bool = False param_typemod: qltypes.TypeModifier polymorphism: qltypes.Polymorphism = qltypes.Polymorphism.NotUsed class Call(ImmutableExpr): """Operator or a function call.""" __abstract_node__ = True # Bound callable has polymorphic parameters and # a polymorphic return type. func_polymorphic: bool # Bound callable's name. func_shortname: sn.QualName # Whether the bound callable is a "USING SQL EXPRESSION" callable. 
func_sql_expr: bool = False # Whether the return value of the function should be # explicitly cast into the declared function return type. force_return_cast: bool # Bound arguments. # Named arguments are indexed by argument name. # Positional arguments are indexed by argument position. args: dict[int | str, CallArg] # Return type and typemod. In bodies of polymorphic functions # the return type can be polymorphic; in queries the return # type will be a concrete schema type. typeref: TypeRef typemod: qltypes.TypeModifier # If the return type is a tuple, this will contain a list # of tuple element path ids relative to the call set. tuple_path_ids: list[PathId] # Volatility of the function or operator. volatility: qltypes.Volatility # Whether the underlying implementation is strict in all its required # arguments (NULL inputs lead to NULL results). If not, we need to # filter at the call site. impl_is_strict: bool = False # Kind of a hack: indicates that when possible we should pass arguments # to this function as a subquery-as-an-expression. # See comment in schema/functions.py for more discussion. prefer_subquery_args: bool = False # If this is a set of call but is allowed in singleton expressions. is_singleton_set_of: typing.Optional[bool] = None # The polymorphism of the return type # This is used to identify cases where polymorphism needs to be handled in # a specialized way (eg. arrays of arrays). return_polymorphism: qltypes.Polymorphism = qltypes.Polymorphism.NotUsed class FunctionCall(Call): __ast_mutable_fields__ = frozenset(( 'extras', 'body' )) # If the bound callable is a "USING SQL" callable, this # attribute will be set to the name of the SQL function. func_sql_function: typing.Optional[str] # initial value needed for aggregate function calls to correctly # handle empty set func_initial_value: typing.Optional[Set] = None # True if the bound function has a variadic parameter and # there are no arguments that are bound to it. 
has_empty_variadic: bool = False # The underlying SQL function has OUT parameters. sql_func_has_out_params: bool = False # backend_name for the underlying function backend_name: typing.Optional[uuid.UUID] = None # Error to raise if the underlying SQL function returns NULL. error_on_null_result: typing.Optional[str] = None # Whether the generic function preserves optionality of the generic # argument(s). preserves_optionality: bool = False # Whether the generic function preserves upper cardinality of the generic # argument(s). preserves_upper_cardinality: bool = False # Set to the type of the variadic parameter of the bound function # (or None, if the function has no variadic parameters.) variadic_param_type: typing.Optional[TypeRef] = None # Additional arguments representing global variables global_args: typing.Optional[list[Set]] = None # Any extra information useful for compilation of special-case callables. extras: typing.Optional[dict[str, typing.Any]] = None # Inline body of the callable. body: typing.Optional[Set] = None class OperatorCall(Call): # The kind of the bound operator (INFIX, PREFIX, etc.). operator_kind: qltypes.OperatorKind # If the bound callable is a "USING SQL FUNCTION" callable, this # attribute will be set to the name of the SQL function. sql_function: typing.Optional[tuple[str, ...]] = None # If this operator maps directly onto an SQL operator, this # will contain the operator name, and, optionally, backend # operand types. sql_operator: typing.Optional[tuple[str, ...]] = None # The name of the origin operator if this is a derivative operator. origin_name: typing.Optional[sn.QualName] = None # The module id of the origin operator if this is a derivative operator. 
origin_module_id: typing.Optional[uuid.UUID] = None class IndexIndirection(ImmutableExpr): expr: Base index: Base typeref: TypeRef class SliceIndirection(ImmutableExpr): expr: Set start: typing.Optional[Base] stop: typing.Optional[Base] typeref: TypeRef class TypeCast(ImmutableExpr): """ImmutableExpr""" expr: Set cast_name: typing.Optional[sn.QualName] = None from_type: TypeRef to_type: TypeRef cardinality_mod: typing.Optional[qlast.CardinalityModifier] = None sql_function: typing.Optional[str] = None sql_cast: bool sql_expr: bool error_message_context: typing.Optional[str] = None @property def typeref(self) -> TypeRef: return self.to_type class MaterializedSet(Base): # Hide uses to reduce spew; we produce our own simpler uses __ast_hidden__ = {'use_sets'} materialized: Set reason: typing.Sequence[MaterializeReason] # We really only want the *paths* of all the places it is used, # but we need to store the sets to take advantage of weak # namespace rewriting. use_sets: list[Set] cardinality: qltypes.Cardinality = qltypes.Cardinality.UNKNOWN # Whether this has been "finalized" by stmtctx; just for supporting some # assertions finalized: bool = False @property def uses(self) -> list[PathId]: return [x.path_id for x in self.use_sets] @markup.serializer.serializer.register(MaterializedSet) def _serialize_to_markup_mat_set( ir: MaterializedSet, *, ctx: typing.Any ) -> typing.Any: # We want to show the path_ids but *not* to show the full uses node = ast.serialize_to_markup(ir, ctx=ctx) node.add_child(label='uses', node=markup.serialize(ir.uses, ctx=ctx)) return node class Stmt(Expr): __abstract_node__ = True # Hide parent_stmt to reduce debug spew and to hide it from find_children __ast_hidden__ = {'parent_stmt'} name: typing.Optional[str] = None # Parts of the edgeql->IR compiler need to create statements and fill in # the result later, but making it Optional would cause lots of errors, # so we stick a dummy set in.
result: Set = DUMMY_SET parent_stmt: typing.Optional[Stmt] = None iterator_stmt: typing.Optional[Set] = None bindings: typing.Optional[list[tuple[Set, qltypes.Volatility]]] = None @property def typeref(self) -> TypeRef: return self.result.typeref class FilteredStmt(Stmt): __abstract_node__ = True where: typing.Optional[Set] = None where_card: qltypes.Cardinality = qltypes.Cardinality.UNKNOWN class SelectStmt(FilteredStmt): orderby: typing.Optional[list[SortExpr]] = None offset: typing.Optional[Set] = None limit: typing.Optional[Set] = None implicit_wrapper: bool = False # An expression to use instead of this one for the purpose of # cardinality/multiplicity inference. This is used for when something # is desugared in a way that doesn't preserve cardinality, but we # need to anyway. card_inference_override: typing.Optional[Set] = None class GroupStmt(FilteredStmt): subject: Set = DUMMY_SET using: dict[str, tuple[Set, qltypes.Cardinality]] = ( ast.field(factory=dict)) by: list[qlast.GroupingElement] result: Set = DUMMY_SET group_binding: Set = DUMMY_SET grouping_binding: typing.Optional[Set] = None orderby: typing.Optional[list[SortExpr]] = None # Optimization information group_aggregate_sets: dict[ typing.Optional[Set], frozenset[PathId] ] = ast.field(factory=dict) class MutatingLikeStmt(Expr): """Represents statements that are "like" mutations for certain purposes. In particular, it includes both MutatingStmt, representing actual mutations, and TriggerAnchor, which is a way to signal that something should (or should not) see certain mutation overlays in the backend without being an actual mutation. """ __abstract_node__ = True class TriggerAnchor(MutatingLikeStmt): """A placeholder to be put in trigger __old__ nodes. The idea here is that in the backend, it will be treated as if it was a MutatingStmt for the purposes of determining whether to use overlays. 
""" typeref: TypeRef class MutatingStmt(Stmt, MutatingLikeStmt): __abstract_node__ = True # Parts of the edgeql->IR compiler need to create statements and fill in # the subject later, but making it Optional would cause lots of errors, # so we stick a dummy set in. subject: Set = DUMMY_SET # Conflict checks that we should manually raise constraint violations # for. conflict_checks: typing.Optional[list[OnConflictClause]] = None # Access policy checks that we should raise errors on write_policies: dict[uuid.UUID, WritePolicies] = ast.field( factory=dict ) # Access policy checks that we should filter on read_policies: dict[uuid.UUID, ReadPolicyExpr] = ast.field( factory=dict ) # Rewrites of the subject shape rewrites: typing.Optional[Rewrites] = None @property def material_type(self) -> TypeRef: """The proper material type being operated on. This should have all views stripped out. """ raise NotImplementedError class ReadPolicyExpr(Base): expr: Set cardinality: qltypes.Cardinality = qltypes.Cardinality.UNKNOWN class WritePolicies(Base): policies: list[WritePolicy] class WritePolicy(Base): expr: Set action: qltypes.AccessPolicyAction name: str error_msg: typing.Optional[str] cardinality: qltypes.Cardinality = qltypes.Cardinality.UNKNOWN class Trigger(Base): expr: Set # All the relevant dml affected: set[tuple[TypeRef, MutatingStmt]] all_affected_types: set[TypeRef] source_type: TypeRef kinds: set[qltypes.TriggerKind] scope: qltypes.TriggerScope # N.B: Semantically and in the external language, delete triggers # don't have a __new__ set, but we give it one in the # implementation (identical to the old set), to help make the # implementation more uniform. 
new_set: Set old_set: typing.Optional[Set] class OnConflictClause(Base): constraint: typing.Optional[ConstraintRef] select_ir: Set always_check: bool else_ir: typing.Optional[Set] check_anchor: typing.Optional[PathId] = None else_fail: typing.Optional[MutatingStmt] = None class InsertStmt(MutatingStmt): on_conflict: typing.Optional[OnConflictClause] = None final_typeref: typing.Optional[TypeRef] = None @property def material_type(self) -> TypeRef: return self.subject.typeref.real_material_type @property def typeref(self) -> TypeRef: return self.final_typeref or self.result.typeref # N.B: The PointerRef corresponds to the *definition* point of the rewrite. RewritesOfType = dict[str, tuple[SetE[Pointer], BasePointerRef]] @dataclasses.dataclass(kw_only=True, frozen=True, slots=True) class Rewrites: old_path_id: typing.Optional[PathId] by_type: dict[TypeRef, RewritesOfType] class UpdateStmt(MutatingStmt, FilteredStmt): _material_type: TypeRef | None = None @property def material_type(self) -> TypeRef: assert self._material_type return self._material_type class DeleteStmt(MutatingStmt, FilteredStmt): _material_type: TypeRef | None = None links_to_delete: dict[ uuid.UUID, tuple[PointerRef, ...] 
] = ast.field(factory=dict) @property def material_type(self) -> TypeRef: assert self._material_type return self._material_type class SessionStateCmd(Command): modaliases: dict[typing.Optional[str], s_mod.Module] testmode: bool class ConfigCommand(Command, Expr): __abstract_node__ = True name: str scope: qltypes.ConfigScope cardinality: qltypes.SchemaCardinality requires_restart: bool backend_setting: typing.Optional[str] is_system_config: bool type_rewrites: typing.Optional[dict[tuple[uuid.UUID, bool], Set]] = None globals: typing.Optional[list[Global]] = None scope_tree: typing.Optional[ScopeTreeNode] = None params: list[Param] = ast.field(factory=list) schema: typing.Optional[s_schema.Schema] = None class ConfigSet(ConfigCommand): expr: Set required: bool backend_expr: typing.Optional[Set] = None @property def typeref(self) -> TypeRef: return self.expr.typeref class ConfigReset(ConfigCommand): selector: typing.Optional[Set] = None @property def typeref(self) -> TypeRef: return TypeRef( id=so.get_known_type_id('anytype'), name_hint=sn.UnqualName('anytype'), ) class ConfigInsert(ConfigCommand): expr: Set @property def typeref(self) -> TypeRef: return self.expr.typeref class FTSDocument(ImmutableExpr): """ Text and information on how to search through it. Constructed with `std::fts::with_options`. """ text: Set language: Set language_domain: set[str] weight: typing.Optional[str] typeref: TypeRef # StaticIntrospection is only used in static evaluation (staeval.py), # but unfortunately the IR AST node can only be defined here. class StaticIntrospection(Tuple): ir: TypeIntrospection schema: s_schema.Schema @property def meta_type(self) -> s_objtypes.ObjectType: return self.schema.get_by_id( self.ir.typeref.id, type=s_objtypes.ObjectType ) @property def output_type(self) -> s_types.Type: return self.schema.get_by_id( self.ir.output_typeref.id, type=s_types.Type ) @property def elements(self) -> list[TupleElement]: from . 
import staeval rv = [] schema = self.schema output_type = self.output_type for ptr in self.meta_type.get_pointers(schema).objects(schema): field_sn = ptr.get_shortname(schema) field_name = field_sn.name field_type = ptr.get_target(schema) assert field_type is not None try: field_value = output_type.get_field_value(schema, field_name) except LookupError: continue try: val = staeval.coerce_py_const(field_type.id, field_value) except staeval.UnsupportedExpressionError: continue ref = TypeRef(id=field_type.id, name_hint=field_sn) vset = Set(expr=val, typeref=ref, path_id=PathId.from_typeref(ref)) rv.append(TupleElement(name=field_name, val=vset)) return rv @elements.setter def elements(self, elements: list[TupleElement]) -> None: pass def get_field_value(self, name: sn.QualName) -> ConstExpr | TypeCast: from . import staeval ptr = self.meta_type.getptr(self.schema, name.get_local_name()) rv_type = ptr.get_target(self.schema) assert rv_type is not None rv_value = self.output_type.get_field_value(self.schema, name.name) return staeval.coerce_py_const(rv_type.id, rv_value) ================================================ FILE: edb/ir/astexpr.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2013-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from typing import Optional from edb.schema import name as sn from . 
import ast as irast def get_constraint_references(tree: irast.Base) -> Optional[list[irast.Base]]: return is_constraint_expr(tree) def is_constraint_expr(tree: irast.Base) -> Optional[list[irast.Base]]: return ( is_distinct_expr(tree) or is_set_expr(tree) or is_binop(tree) ) def is_distinct_expr(tree: irast.Base) -> Optional[list[irast.Base]]: return ( is_pure_distinct_expr(tree) or is_possibly_wrapped_distinct_expr(tree) ) def is_pure_distinct_expr(tree: irast.Base) -> Optional[list[irast.Base]]: if not isinstance(tree, irast.FunctionCall): return None if tree.func_shortname != sn.QualName('std', '_is_exclusive'): return None if len(tree.args) != 1: return None if 0 not in tree.args: return None if not isinstance(tree.args[0], irast.CallArg): return None return [tree.args[0].expr] def is_possibly_wrapped_distinct_expr( tree: irast.Base ) -> Optional[list[irast.Base]]: if not isinstance(tree, irast.SelectStmt): return None return is_set_expr(tree.result) def is_set_expr(tree: irast.Base) -> Optional[list[irast.Base]]: if not isinstance(tree, irast.Set): return None return ( is_distinct_expr(tree.expr) or is_binop(tree.expr) ) def is_binop(tree: irast.Base) -> Optional[list[irast.Base]]: if not isinstance(tree, irast.OperatorCall): return None if tree.func_shortname != sn.QualName('std', 'AND'): return None if len(tree.args) != 2: return None refs = [] for arg in tree.args.values(): if not isinstance(arg, irast.CallArg): return None ref = is_constraint_expr(arg.expr) if not ref: return None refs.extend(ref) return refs ================================================ FILE: edb/ir/pathid.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Optional, AbstractSet, Iterator, cast, TYPE_CHECKING, ) from . import typeutils from edb.common import uuidgen from edb.schema import name as s_name from edb.schema import pointers as s_pointers from edb.schema import types as s_types from edb.ir import ast as irast if TYPE_CHECKING: import uuid from edb.schema import schema as s_schema from edb.edgeql.compiler import context as qlcompiler_ctx Namespace = str class PathId: """A descriptor of a *variable* in an expression. ``PathId`` instances are used to identify and describe expressions in EdgeQL. They are immutable, hashable and comparable. Instances of ``PathId`` describing the same expression variable are equal. Another important aspect (and the reason for the class name) is that ``PathId`` instances describe *paths* in a structured way that allows walking the path to its root. ``PathId`` instances are normally directly created for a path root, and then PathIds representing the steps of a path are derived by calling ``extend()`` on the previous step. For example, for the expression ``Movie.reviews.author`` the following would return a corresponding ``PathId`` (in pseudo-code): path_id = PathId.from_type(Movie).extend('reviews').extend('author') """ __slots__ = ('_path', '_norm_path', '_namespace', '_prefix', '_is_ptr', '_is_linkprop', '_hash') #: Actual path information. _path: tuple[ irast.TypeRef | tuple[irast.BasePointerRef, s_pointers.PointerDirection], ... ] #: Normalized path data, used for PathId hashing and comparisons. 
_norm_path: tuple[ uuid.UUID | s_name.Name | tuple[s_name.QualName, s_pointers.PointerDirection, bool], ... ] #: A set of namespace identifiers which this PathId belongs to. _namespace: frozenset[str] #: If this PathId has a prefix from another namespace, this will #: contain said prefix. _prefix: Optional[PathId] #: True if this PathId represents the link portion of a link property path. _is_ptr: bool #: True if this PathId represents a link property path. _is_linkprop: bool def __init__( self, initializer: Optional[PathId] = None, *, namespace: AbstractSet[str] = frozenset(), typename: Optional[str] = None, ) -> None: if isinstance(initializer, PathId): self._path = initializer._path self._norm_path = initializer._norm_path if namespace: self._namespace = frozenset(namespace) else: self._namespace = initializer._namespace self._is_ptr = initializer._is_ptr self._is_linkprop = initializer._is_linkprop self._prefix = initializer._prefix elif initializer is not None: raise TypeError('use PathId.from_type') else: self._path = () self._norm_path = () self._namespace = frozenset(namespace) self._prefix = None self._is_ptr = False self._is_linkprop = False self._hash = -1 def __getstate__(self) -> Any: # We need to omit the cached _hash when we pickle because it won't # be correct in a different process. return tuple([ getattr(self, k) if k != '_hash' else -1 for k in PathId.__slots__ ]) def __setstate__(self, state: Any) -> None: for k, v in zip(PathId.__slots__, state): setattr(self, k, v) @classmethod def from_type( cls, schema: s_schema.Schema, t: s_types.Type, *, env: Optional[qlcompiler_ctx.Environment], namespace: AbstractSet[Namespace] = frozenset(), typename: Optional[s_name.QualName] = None, ) -> PathId: """Return a ``PathId`` instance for a given :class:`schema.types.Type` The returned ``PathId`` instance describes a set variable of type *t*. 
The name of the passed type is used as the name for the variable, unless *typename* is specified, in which case it is used instead. Args: schema: A schema instance where the type *t* is defined. t: The type of the variable being defined. env: Optional EdgeQL compiler environment, used for caching. namespace: Optional namespace in which the variable is defined. typename: If specified, used as the name for the variable instead of the name of the type *t*. Returns: A ``PathId`` instance of type *t*. """ if not isinstance(t, s_types.Type): raise ValueError( f'invalid PathId: bad source: {t!r}') cache = env.type_ref_cache if env is not None else None typeref = typeutils.type_to_typeref( schema, t, cache=cache, typename=typename ) return cls.from_typeref(typeref, namespace=namespace, typename=typename) @classmethod def from_pointer( cls, schema: s_schema.Schema, pointer: s_pointers.Pointer, *, namespace: AbstractSet[Namespace] = frozenset(), env: Optional[qlcompiler_ctx.Environment], ) -> PathId: """Return a ``PathId`` instance for a given link or property. The specified *pointer* argument must be a concrete link or property. The returned ``PathId`` instance describes a set variable of all objects represented by the pointer (i.e, for a link, a set of all link targets). Args: schema: A schema instance where the type *t* is defined. pointer: An instance of a concrete link or property. namespace: Optional namespace in which the variable is defined. Returns: A ``PathId`` instance. 
""" if pointer.is_non_concrete(schema): raise ValueError(f'invalid PathId: {pointer} is not concrete') source = pointer.get_source(schema) if isinstance(source, s_pointers.Pointer): prefix = cls.from_pointer( schema, source, namespace=namespace, env=env ) prefix = prefix.ptr_path() elif isinstance(source, s_types.Type): prefix = cls.from_type(schema, source, namespace=namespace, env=env) else: raise AssertionError(f'unexpected pointer source: {source!r}') typeref_cache = env.type_ref_cache if env is not None else None ptrref_cache = env.ptr_ref_cache if env is not None else None ptrref = typeutils.ptrref_from_ptrcls( schema=schema, ptrcls=pointer, cache=ptrref_cache, typeref_cache=typeref_cache, ) return prefix.extend(ptrref=ptrref) @classmethod def from_typeref( cls, typeref: irast.TypeRef, *, namespace: AbstractSet[Namespace] = frozenset(), typename: Optional[s_name.Name | uuid.UUID] = None, ) -> PathId: """Return a ``PathId`` instance for a given :class:`ir.ast.TypeRef` The returned ``PathId`` instance describes a set variable of type described by *typeref*. The name of the passed type is used as the name for the variable, unless *typename* is specified, in which case it is used instead. Args: typeref: The descriptor of a type of the variable being defined. namespace: Optional namespace in which the variable is defined. typename: If specified, used as the name for the variable instead of the name of the type *t*. Returns: A ``PathId`` instance of type described by *typeref*. """ pid = cls() pid._path = (typeref,) if typename is None: typename = typeref.id pid._norm_path = (typename,) pid._namespace = frozenset(namespace) return pid @classmethod def from_ptrref( cls, ptrref: irast.PointerRef, *, namespace: AbstractSet[Namespace] = frozenset(), ) -> PathId: """Return a ``PathId`` instance for a given :class:`ir.ast.PointerRef` Args: ptrref: The descriptor of a ptr of the variable being defined. namespace: Optional namespace in which the variable is defined. 
Returns: A ``PathId`` instance of type described by *ptrref*. """ pid = cls.from_typeref(ptrref.out_source, namespace=namespace) pid = pid.extend(ptrref=ptrref) return pid @classmethod def new_dummy(cls, name: str) -> PathId: name_hint = s_name.QualName(module='__derived__', name=name) typeref = irast.TypeRef(id=uuidgen.uuid1mc(), name_hint=name_hint) return irast.PathId.from_typeref(typeref=typeref) def __hash__(self) -> int: if self._hash == -1: self._hash = hash(( self.__class__, self._norm_path, self._namespace, self._prefix, self._is_ptr, )) return self._hash def __eq__(self, other: Any) -> bool: if not isinstance(other, PathId): return NotImplemented return ( self._norm_path == other._norm_path and self._namespace == other._namespace and self._prefix == other._prefix and self._is_ptr == other._is_ptr ) def __len__(self) -> int: return len(self._path) def __str__(self) -> str: return self.pformat_internal(debug=False) __repr__ = __str__ def extend( self, *, ptrref: irast.BasePointerRef, direction: s_pointers.PointerDirection = ( s_pointers.PointerDirection.Outbound), ns: AbstractSet[Namespace] = frozenset(), ) -> PathId: """Return a new ``PathId`` that is a *path step* from this ``PathId``. For example, if you have a ``PathId`` that describes a variable ``A``, and you want to obtain a ``PathId`` for ``A.b``, you should call ``path_id_for_A.extend(ptrref=ptrref_for_b)``. Args: ptrref: A ``ir.ast.BasePointerRef`` instance that corresponds to the path step. This may be a regular link or property object, or a pseudo-pointer, like a tuple or type intersection step. direction: The direction of the *ptrref* pointer. This makes sense only for reverse link traversal, all other path steps are always forward. ns: Optional namespace in which the path extension is defined. If specified, it is merged with the namespace of the current PathId; otherwise the namespace of the current PathId is used. Returns: A new ``PathId`` instance representing a step extension of this ``PathId``.
""" if not self: raise ValueError('cannot extend empty PathId') if direction is s_pointers.PointerDirection.Outbound: target_ref = ptrref.out_target else: target_ref = ptrref.out_source is_linkprop = ptrref.source_ptr is not None if is_linkprop and not self._is_ptr: raise ValueError( 'link property path extension on a non-link path') result = self.__class__() result._path = self._path + ((ptrref, direction), target_ref) link_name = ptrref.name lnk = (link_name, direction, is_linkprop) result._is_linkprop = is_linkprop if target_ref.material_type is not None: material_type = target_ref.material_type else: material_type = target_ref result._norm_path = (self._norm_path + (lnk, material_type.id)) if ns: if self._namespace: result._namespace = self._namespace | frozenset(ns) else: result._namespace = frozenset(ns) else: result._namespace = self._namespace if self._namespace != result._namespace: result._prefix = self else: result._prefix = self._prefix return result def replace_namespace( self, namespace: AbstractSet[Namespace], ) -> PathId: """Return a copy of this ``PathId`` with namespace set to *namespace*. """ result = self.__class__(self) result._namespace = frozenset(namespace) if result._prefix is not None: result._prefix = result._get_minimal_prefix( result._prefix.replace_namespace(namespace)) return result def merge_namespace( self, namespace: AbstractSet[Namespace], *, deep: bool=False, ) -> PathId: """Return a copy of this ``PathId`` that has *namespace* added to its namespace. 
""" new_namespace = self._namespace | frozenset(namespace) if new_namespace != self._namespace or deep: result = self.__class__(self) result._namespace = new_namespace if deep and result._prefix is not None: result._prefix = result._prefix.merge_namespace(new_namespace) if result._prefix is not None: result._prefix = result._get_minimal_prefix(result._prefix) return result else: return self def strip_namespace(self, namespace: AbstractSet[Namespace]) -> PathId: """Return a copy of this ``PathId`` with a given portion of the namespace id removed.""" if self._namespace and namespace: stripped_ns = self._namespace - set(namespace) result = self.replace_namespace(stripped_ns) if result._prefix is not None: result._prefix = result._get_minimal_prefix( result._prefix.strip_namespace(namespace)) return result else: return self def pformat_internal(self, debug: bool = False) -> str: """Verbose format for debugging purposes.""" result = '' if not self._path: return '' if self._namespace: result += f'{"@".join(sorted(self._namespace))}@@' path = self._path result += f'({path[0].name_hint})' # type: ignore for i in range(1, len(path) - 1, 2): ptrspec = cast( tuple[irast.BasePointerRef, s_pointers.PointerDirection], path[i], ) tgtspec = cast( irast.TypeRef, path[i + 1], ) if debug: link_name = str(ptrspec[0].name) ptr = f'({link_name})' else: ptr = ptrspec[0].shortname.name ptrdir = ptrspec[1] is_lprop = ptrspec[0].source_ptr is not None if tgtspec.material_type is not None: mat_tgt = tgtspec.material_type else: mat_tgt = tgtspec tgt = mat_tgt.name_hint if tgt: lexpr = f'{ptr}[IS {tgt}]' else: lexpr = f'{ptr}' if is_lprop: step = '@' else: step = f'.{ptrdir}' result += f'{step}{lexpr}' if self._is_ptr: result += '@' return result def pformat(self) -> str: """Pretty PathId format for user-visible messages.""" result = '' if not self._path: return '' path = self._path start_name = s_name.shortname_from_fullname( path[0].name_hint) # type: ignore result += f'{start_name.name}' 
for i in range(1, len(path) - 1, 2): ptrspec = cast( tuple[irast.BasePointerRef, s_pointers.PointerDirection], path[i], ) ptr_name = ptrspec[0].shortname ptrdir = ptrspec[1] is_lprop = ptrspec[0].source_ptr is not None if is_lprop: step = '@' else: step = '.' if ptrdir == s_pointers.PointerDirection.Inbound: step += ptrdir result += f'{step}{ptr_name.name}' if self._is_ptr: result += '@' return result def rptr(self) -> Optional[irast.BasePointerRef]: """Return the descriptor of a pointer for the last path step, if any. If this PathId represents a non-path expression, ``rptr()`` will return ``None``. """ if len(self._path) > 1: return self._path[-2][0] # type: ignore else: return None def rptr_dir(self) -> Optional[s_pointers.PointerDirection]: """Return the direction of a pointer for the last path step, if any. If this PathId represents a non-path expression, ``rptr_dir()`` will return ``None``. """ if len(self._path) > 1: return self._path[-2][1] # type: ignore else: return None def rptr_name(self) -> Optional[s_name.QualName]: """Return the name of a pointer for the last path step, if any. If this PathId represents a non-path expression, ``rptr_name()`` will return ``None``. """ rptr = self.rptr() if rptr is not None: return rptr.shortname else: return None def src_path(self) -> Optional[PathId]: """Return a ``PathId`` instance representing an immediate path prefix of this ``PathId``, i.e ``PathId('Foo.bar.baz').src_path() == PathId('Foo.bar')``. If this PathId represents a non-path expression, ``src_path()`` will return ``None``. """ if len(self._path) > 1: return self._get_prefix(-2) else: return None def ptr_path(self) -> PathId: """Return a new ``PathId`` instance that is a "pointer prefix" of this ``PathId``. 
A pointer prefix is the common path prefix shared by paths to link properties of the same link, i.e common_path_id(Foo.bar@prop1, Foo.bar@prop2) == PathId(Foo.bar).ptr_path() """ if self._is_ptr: return self else: result = self.__class__(self) result._is_ptr = True return result def tgt_path(self) -> PathId: """If this is a pointer prefix, return the ``PathId`` representing the path to the target of the pointer. This is the inverse of :meth:`~PathId.ptr_path`. """ if not self._is_ptr: return self else: result = self.__class__(self) result._is_ptr = False return result def iter_prefixes(self, include_ptr: bool = False) -> Iterator[PathId]: """Return an iterator over all prefixes of this ``PathId``. The order of prefixes is from shortest to longest, i.e ``PathId(A.b.c.d).iter_prefixes()`` will yield [PathId(A), PathId(A.b), PathId(A.b.c), PathId(A.b.c.d)]. If *include_ptr* is ``True``, then pointer prefixes for each step are also included. """ if self._prefix is not None: yield from self._prefix.iter_prefixes(include_ptr=include_ptr) start = len(self._prefix) else: yield self._get_prefix(1) start = 1 for i in range(start, len(self._path) - 1, 2): path_id = self._get_prefix(i + 2) if path_id.is_ptr_path(): yield path_id.tgt_path() if include_ptr: yield path_id else: yield path_id def startswith( self, path_id: PathId, permissive_ptr_path: bool = False ) -> bool: """Return true if this ``PathId`` has *path_id* as a prefix.""" base = self._get_prefix(len(path_id)) return base == path_id or ( permissive_ptr_path and base.tgt_path() == path_id) @property def target(self) -> irast.TypeRef: """Return the type descriptor for this PathId.""" return self._path[-1] # type: ignore @property def target_name_hint(self) -> s_name.Name: """Return the name of the type for this PathId.""" if self.target.material_type is not None: material_type = self.target.material_type else: material_type = self.target return material_type.name_hint def is_objtype_path(self) -> bool: """Return True
if this PathId represents an expression of object type. """ return not self.is_ptr_path() and typeutils.is_object(self.target) def is_scalar_path(self) -> bool: """Return True if this PathId represents an expression of scalar type. """ return not self.is_ptr_path() and typeutils.is_scalar(self.target) def is_view_path(self) -> bool: """Return True if this PathId represents an expression that is a view. """ return not self.is_ptr_path() and typeutils.is_view(self.target) def is_tuple_path(self) -> bool: """Return True if this PathId represents an expression of an tuple type. """ return not self.is_ptr_path() and typeutils.is_tuple(self.target) def is_tuple_indirection_path(self) -> bool: """Return True if this PathId represents a tuple element indirection expression. """ src_path = self.src_path() return src_path is not None and src_path.is_tuple_path() def is_array_path(self) -> bool: """Return True if this PathId represents an expression of an array type. """ return not self.is_ptr_path() and typeutils.is_array(self.target) def is_range_path(self) -> bool: """Return True if this PathId represents an expression of a range type. """ return not self.is_ptr_path() and typeutils.is_range(self.target) def is_collection_path(self) -> bool: """Return True if this PathId represents an expression of a collection type. """ return not self.is_ptr_path() and typeutils.is_collection(self.target) def is_ptr_path(self) -> bool: """Return True if this PathId represents a link prefix of the path. Immediate prefix of a link property ``PathId`` will return True here. 
""" return self._is_ptr def is_linkprop_path(self) -> bool: """Return True if this PathId represents a link property path expression, i.e ``Foo.bar@prop``.""" return self._is_linkprop def is_type_intersection_path(self) -> bool: """Return True if this PathId represents a type intersection expression, i.e ``Foo[IS Bar]``.""" rptr_name = self.rptr_name() if rptr_name is None: return False else: return str(rptr_name) in ( '__type__::indirection', '__type__::optindirection', ) @property def namespace(self) -> frozenset[str]: """The namespace of this ``PathId``""" return self._namespace def _get_prefix(self, size: int) -> PathId: if size < 0: size = len(self._path) + size if size == len(self._path): return self if self._prefix is not None: prefix_len = len(self._prefix) if prefix_len == size: return self._prefix elif prefix_len > size: return self._prefix._get_prefix(size) result = self.__class__() result._path = self._path[0:size] result._norm_path = self._norm_path[0:size] result._prefix = self._prefix result._namespace = self._namespace if rptr := result.rptr(): result._is_linkprop = rptr.source_ptr is not None if size < len(self._path) and self._norm_path[size][2]: # type: ignore # A link property ref has been chopped off. result._is_ptr = True return result def _get_minimal_prefix( self, prefix: Optional[PathId], ) -> Optional[PathId]: while prefix is not None: if prefix._namespace == self._namespace: prefix = prefix._prefix else: break return prefix ================================================ FILE: edb/ir/scopetree.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Query scope tree implementation.""" from __future__ import annotations from typing import ( Any, Optional, AbstractSet, Iterator, Mapping, Collection, NamedTuple, Protocol, cast, TYPE_CHECKING, ) if TYPE_CHECKING: from typing_extensions import TypeGuard import sys import textwrap import weakref from edb import errors from edb.common import ordered from edb.common import span from edb.common import term from . import pathid from . import ast as irast class WarningContext(Protocol): def log_warning(self, warning: errors.EdgeDBError) -> None: ... class FenceInfo(NamedTuple): unnest_fence: bool factoring_fence: bool def __or__(self, other: FenceInfo) -> FenceInfo: return FenceInfo( unnest_fence=self.unnest_fence or other.unnest_fence, factoring_fence=self.factoring_fence or other.factoring_fence, ) def has_path_id(nobe: ScopeTreeNode) -> TypeGuard[ScopeTreeNodeWithPathId]: return nobe.path_id is not None class ScopeTreeNode: unique_id: Optional[int] """A unique identifier used to map scopes on sets.""" path_id: Optional[pathid.PathId] """Node path id, or None for branch nodes.""" fenced: bool """Whether the subtree represents a SET OF argument.""" is_group: bool """Whether the node represents a GROUP binding (and so *is* multi...).""" unnest_fence: bool """Prevent unnesting in parents.""" factoring_fence: bool """Prevent prefix factoring across this node.""" factoring_allowlist: set[pathid.PathId] """A list of prefixes that are always allowed to be factored.""" optional: bool """Whether this node represents an optional path.""" children: list[ScopeTreeNode] """A set 
of child nodes.""" namespaces: set[pathid.Namespace] """A set of namespaces used by paths in this branch. When a path node is pulled up from this branch, and its namespace matches anything in `namespaces`, the namespace will be stripped. This is used to implement "semi-detached" semantics used by aliases declared in a WITH block.""" def __init__( self, *, path_id: Optional[pathid.PathId]=None, fenced: bool=False, unique_id: Optional[int]=None, optional: bool=False, ) -> None: self.unique_id = unique_id self.path_id = path_id self.fenced = fenced self.unnest_fence = False self.factoring_fence = False self.factoring_allowlist = set() self.optional = optional self.children = [] self.namespaces = set() self.is_group = False self._parent: Optional[weakref.ReferenceType[ScopeTreeNode]] = None FIELDS = ( 'unique_id', 'path_id', 'fenced', 'unnest_fence', 'factoring_fence', 'factoring_allowlist', 'optional', 'children', 'namespaces', 'is_group', ) def __getstate__(self) -> Any: res = self.__dict__.copy() del res['_parent'] return res def __setstate__(self, state: Any) -> None: for f, val in state.items(): setattr(self, f, val) self._parent = None for child in self.children: child._parent = weakref.ref(self) def __repr__(self) -> str: name = 'ScopeFenceNode' if self.fenced else 'ScopeTreeNode' return (f'<{name} {self.path_id!r} at {id(self):0x}>') def find_dupe_unique_ids(self) -> set[int]: seen = set() dupes = set() for node in self.root.descendants: if node.unique_id is not None: if node.unique_id in seen: dupes.add(node.unique_id) seen.add(node.unique_id) return dupes def validate_unique_ids(self) -> None: dupes = self.find_dupe_unique_ids() assert not dupes, f'Duplicate "unique" ids seen {dupes}' @property def name(self) -> str: return self._name(debug=False) def _name(self, debug: bool) -> str: if self.path_id is None: name = ( ('FENCE' if self.fenced else 'BRANCH') ) else: name = self.path_id.pformat_internal(debug=debug) return f'{name}{" [OPT]" if self.optional else 
""}' def debugname(self, fuller: bool = False) -> str: parts = [f'{self._name(debug=fuller)}'] if self.unique_id: parts.append(f'uid:{self.unique_id}') if self.namespaces: parts.append(','.join(self.namespaces)) if self.unnest_fence: parts.append('no-unnest') if self.factoring_fence: parts.append('no-factor') if self.is_group: parts.append('group') return ' '.join(parts) @property def fence_info(self) -> FenceInfo: return FenceInfo( unnest_fence=self.unnest_fence, factoring_fence=self.factoring_fence, ) def fence_info_ex( self, path_id: pathid.PathId, namespaces: AbstractSet[str] ) -> FenceInfo: finfo = self.fence_info if any( _paths_equal(path_id, wl, namespaces) for wl in self.factoring_allowlist ): finfo = finfo._replace(factoring_fence=False) return finfo @property def ancestors(self) -> Iterator[ScopeTreeNode]: """An iterator of node's ancestors, including self.""" node: Optional[ScopeTreeNode] = self while node is not None: yield node node = node.parent @property def strict_ancestors(self) -> Iterator[ScopeTreeNode]: """An iterator of node's ancestors, not including self.""" node: Optional[ScopeTreeNode] = self.parent while node is not None: yield node node = node.parent @property def ancestors_and_namespaces( self, ) -> Iterator[tuple[ScopeTreeNode, frozenset[pathid.Namespace]]]: """An iterator of node's ancestors and namespaces, including self.""" namespaces: frozenset[str] = frozenset() node: Optional[ScopeTreeNode] = self while node is not None: namespaces |= node.namespaces yield node, namespaces node = node.parent @property def path_children(self) -> Iterator[ScopeTreeNodeWithPathId]: """An iterator of node's children that have path ids.""" return ( p for p in self.children if has_path_id(p) ) @property def path_descendants(self) -> Iterator[ScopeTreeNodeWithPathId]: """An iterator of node's descendants that have path ids.""" return ( p for p in self.descendants if has_path_id(p) ) def get_all_paths(self) -> AbstractSet[pathid.PathId]: return 
ordered.OrderedSet(pd.path_id for pd in self.path_descendants) @property def descendants(self) -> Iterator[ScopeTreeNode]: """An iterator of node's descendants including self top-first.""" yield self yield from self.strict_descendants @property def strict_descendants(self) -> Iterator[ScopeTreeNode]: """An iterator of node's descendants not including self top-first.""" for child in tuple(self.children): yield child if child.parent is self: yield from child.strict_descendants def descendants_and_namespaces_ex( self, *, unfenced_only: bool=False, strict: bool=False, skip: Optional[ScopeTreeNode]=None, ) -> Iterator[ tuple[ ScopeTreeNode, AbstractSet[pathid.Namespace], FenceInfo ] ]: """An iterator of node's descendants and namespaces. Args: unfenced_only: Whether to skip traversing through fenced nodes strict: Whether to skip the node itself skip: An optional child to skip during the traversal. This is useful for avoiding performance pathologies when repeatedly searching descendants while climbing the tree (see find_factorable_nodes). Top-first. """ if not strict: yield self, frozenset(), FenceInfo( unnest_fence=False, factoring_fence=False) for child in tuple(self.children): if unfenced_only and child.fenced: continue if child is skip: continue finfo = child.fence_info yield child, child.namespaces, finfo if child.parent is not self: continue desc_ns = child.descendants_and_namespaces_ex( unfenced_only=unfenced_only, strict=True) for desc, desc_namespaces, desc_finfo in desc_ns: yield ( desc, child.namespaces | desc_namespaces, finfo | desc_finfo, ) @property def strict_descendants_and_namespaces( self, ) -> Iterator[ tuple[ ScopeTreeNode, AbstractSet[pathid.Namespace], FenceInfo ] ]: """An iterator of node's descendants and namespaces. Does not include self. Top-first. 
""" return self.descendants_and_namespaces_ex(strict=True) @property def descendant_namespaces(self) -> set[pathid.Namespace]: """A set of namespaces declared by descendants.""" namespaces = set() for child in self.descendants: namespaces.update(child.namespaces) return namespaces @property def fence(self) -> ScopeTreeNode: """The nearest ancestor fence (or self, if fence).""" if self.fenced: return self else: return cast(ScopeTreeNode, self.parent_fence) @property def parent(self) -> Optional[ScopeTreeNode]: """The parent node.""" if self._parent is None: return None else: return self._parent() @property def path_ancestor(self) -> Optional[ScopeTreeNodeWithPathId]: for ancestor in self.strict_ancestors: if has_path_id(ancestor): return ancestor return None @property def parent_fence(self) -> Optional[ScopeTreeNode]: """The nearest strict ancestor fence.""" for ancestor in self.strict_ancestors: if ancestor.fenced: return ancestor return None @property def parent_branch(self) -> Optional[ScopeTreeNode]: """The nearest strict ancestor branch or fence.""" for ancestor in self.strict_ancestors: if ancestor.path_id is None: return ancestor return None @property def root(self) -> ScopeTreeNode: """The root of this tree.""" node = self while node.parent is not None: node = node.parent return node def strip_path_namespace(self, ns: AbstractSet[str]) -> None: if not ns: return for pd in self.path_descendants: pd.path_id = pd.path_id.strip_namespace(ns) def attach_child( self, node: ScopeTreeNode, span: Optional[span.Span] = None ) -> None: """Attach a child node to this node. This is a low-level operation, no tree validation is performed. 
For safe tree modification, use attach_subtree(). """ if node.path_id is not None: for child in self.children: if child.path_id == node.path_id: raise errors.InvalidReferenceError( f'{node.path_id} is already present in {self!r}', span=span, ) if node.unique_id is not None: for child in self.children: if child.unique_id == node.unique_id: return node._set_parent(self) def attach_fence(self) -> ScopeTreeNode: """Create and attach an empty fenced node.""" fence = ScopeTreeNode(fenced=True) self.attach_child(fence) return fence def attach_branch(self) -> ScopeTreeNode: """Create and attach an empty branch node.""" fence = ScopeTreeNode() self.attach_child(fence) return fence def attach_path( self, path_id: pathid.PathId, *, optional: bool=False, span: Optional[span.Span], ctx: WarningContext, ) -> None: """Attach a scope subtree representing *path_id*.""" subtree = parent = ScopeTreeNode(fenced=True) is_lprop = False lprop_base = None for prefix in reversed(list(path_id.iter_prefixes())): new_child = ScopeTreeNode(path_id=prefix, optional=optional and parent is subtree) # Normally the prefix is nested, except that tuple # indirection prefixes and the *object* prefixes of link # properties are at the same level. # # For example, Foo.bar.baz, where Foo is an object type, # forms this scope shape: # Foo.bar.baz # |-Foo.bar # |-Foo # # Whereas, .bar.baz results in this: # # .bar # .bar.baz # # And Foo.bar[is Typ]@baz results in: # Foo.bar[is Typ]@baz # |-Foo.bar[is Typ] # |-Foo.bar # Foo # # For tuples, this is permissible because their fields are always # singletons. # FIXME: I think that it should not be *necessary* for tuples, # but test_edgeql_volatility_select_tuples_* fail if it is changed, # I think for incidental reasons. # # For link properties, this is necessary because referring # to a link property at the end of a path suppresses # deduplication of the link, which is realized by forcing # the link source to be visible. 
We avoid making the rest of # the path visible, to preserve prefix visibility information # for certain optimizations. (Foo.bar[is Typ] can be compiled # such that it joins directly on Typ (instead of on Bar first), # but *only* if Foo.bar isn't visible without the type intersection.) if prefix.is_linkprop_path(): assert lprop_base is None # If we just saw a linkprop, track where, since we'll # need to come back to this level in the tree once we # reach the "object prefix" of it. lprop_base = parent is_lprop = True elif is_lprop: # Skip through type intersections (i.e [IS Foo]) until # we actually get to the link. if not prefix.is_type_intersection_path(): is_lprop = False else: # If we've reached the "object prefix" of a path # referencing a linkprop, pop back up to the level the # linkprop was attached to. if lprop_base is not None: parent = lprop_base lprop_base = None parent.attach_child(new_child) if not prefix.is_tuple_indirection_path(): parent = new_child self.attach_subtree(subtree, span=span, ctx=ctx) def attach_subtree( self, node: ScopeTreeNode, was_fenced: bool = False, span: Optional[span.Span] = None, fusing: bool = False, *, ctx: WarningContext, ) -> None: """Attach a subtree to this node. *node* is expected to be a balanced scope tree and may be modified by this function. If *node* is not a path node (path_id is None), it is discarded, and its descendants are attached directly. The tree balance is maintained. """ if node.path_id is not None: # Wrap path node wrapper_node = ScopeTreeNode(fenced=True) wrapper_node.attach_child(node) node = wrapper_node for descendant, dns, _ in node.descendants_and_namespaces_ex(): if not has_path_id(descendant): continue path_id = descendant.path_id.strip_namespace(dns) if descendant.parent_fence is node: # Unfenced path. # Search for occurrences elsewhere in the tree that # can be factored with this one. # If found, attach that node directly to the factoring point # and fuse our node onto it. 
# If there are multiple factorable occurrences, we do # this iteratively, from closest to furthest away. factorable_nodes = self.find_factorable_nodes(path_id) current = descendant if factorable_nodes: descendant.strip_path_namespace(dns) desc_optional = ( descendant.is_optional_upto(node.parent) # Check if there is an optional branch between here # and the *highest* factoring point. or self.is_optional_upto(factorable_nodes[-1][1]) ) if desc_optional: descendant.mark_as_optional() for factorable in factorable_nodes: ( existing, factor_point, current_ns, existing_ns, existing_finfo, unnest_fence, node_fenced, ) = factorable self._check_factoring_errors( path_id, descendant, factor_point, existing, unnest_fence, existing_finfo, span, ) existing_fenced = existing.parent_fence is not None and ( factor_point in existing.parent_fence.strict_ancestors ) if existing.is_optional_upto(factor_point): existing.mark_as_optional() # Strip the namespaces of everything in the lifted nodes # based on what they have been lifted through. existing.strip_path_namespace(existing_ns) current.strip_path_namespace(current_ns) current.remove() if ( factor_point is not existing.parent and factor_point is not existing ): existing.remove() factor_point.attach_child(existing) # Discard the node from the subtree being attached. existing.fuse_subtree( current, self_fenced=existing_fenced, node_fenced=node_fenced, span=span, ctx=ctx, ) current = existing # HACK: If we are being called from fuse_subtree, # skip all but the first. This is because we don't # want to merge any children before the parent # fully finishes all of its factoring. if fusing: break for child in tuple(node.children): # Attach whatever is remaining in the subtree. 
for pd in child.path_descendants: if pd.path_id.namespace: to_strip = set(pd.path_id.namespace) & node.namespaces pd.path_id = pd.path_id.strip_namespace(to_strip) self.attach_child(child) def _check_factoring_errors( self, path_id: pathid.PathId, descendant: ScopeTreeNodeWithPathId, factor_point: ScopeTreeNode, existing: ScopeTreeNodeWithPathId, unnest_fence: bool, existing_finfo: FenceInfo, span: Optional[span.Span], ) -> None: if existing_finfo.factoring_fence: # This node is already present in the surrounding # scope and cannot be factored out, such as # a reference to a correlated set inside a DML # statement. raise errors.InvalidReferenceError( f'cannot reference correlated set ' f'{path_id.pformat()!r} here', span=span, ) if ( unnest_fence and ( factor_point.find_child( path_id, in_branches=True, pfx_with_invariant_card=True, ) is None ) and ( not (src_path := path_id.src_path()) or not self.is_visible(src_path) ) and not existing._node_paths_are_not_links() ): path_ancestor = descendant.path_ancestor if path_ancestor is not None: offending_node = path_ancestor else: offending_node = descendant assert offending_node.path_id is not None imp = '' offending_id = f'{offending_node.path_id.pformat()!r}' existing_id = f'{existing.path_id.pformat()!r}' # If the id is generated, don't leak meaningless info # and try to explain that the reference is implicit. if '~' in offending_id: imp = 'implicit ' offending_id = 'an object' existing_id = 'it' raise errors.InvalidReferenceError( f'{imp}reference to {offending_id} ' f'changes the interpretation of {existing_id} ' f'elsewhere in the query', span=span, ) def _node_paths_are_not_links(self) -> bool: """ Check if all the pointers a path might be hoisted past are not links If the node is a path_id node, return true if the rptrs on all of the chain of parent nodes with path_ids are not links. 
This is in support of allowing queries like select Card.element filter Card.name = 'Imp' No real change in interpretation happens here, since element is a property and so doesn't get deduplicated. """ node: ScopeTreeNode | None = self while node and node.path_id: if ( isinstance(node.path_id.rptr(), irast.PointerRef) and node.path_id.is_objtype_path() ): return False node = node.parent return True def fuse_subtree( self, node: ScopeTreeNode, self_fenced: bool=False, node_fenced: bool=False, span: Optional[span.Span]=None, *, ctx: WarningContext, ) -> None: node.remove() if not node.optional and not node_fenced: self.optional = False if node.optional and self_fenced: self.optional = True if node.path_id is not None: subtree = ScopeTreeNode(fenced=True) subtree.optional = node.optional for child in tuple(node.children): subtree.attach_child(child) else: subtree = node self.attach_subtree( subtree, was_fenced=self_fenced, span=span, fusing=True, ctx=ctx ) def remove_subtree(self, node: ScopeTreeNode) -> None: """Remove the given subtree from this node.""" if node not in self.children: raise KeyError(f'{node} is not a child of {self}') node._set_parent(None) def remove_descendants( self, path_id: pathid.PathId, new: ScopeTreeNode ) -> None: """Remove all descendant nodes matching *path_id*.""" matching = set() for node in self.descendants: if (node.path_id is not None and _paths_equal(node.path_id, path_id, set())): matching.add(node) for node in matching: node.remove() def mark_as_optional(self) -> None: """Indicate that this scope is used as an OPTIONAL argument.""" self.optional = True def is_optional(self, path_id: pathid.PathId) -> bool: node = self.find_visible(path_id) if node is not None: return node.optional else: return False def add_namespaces( self, namespaces: AbstractSet[pathid.Namespace], ) -> None: # Make sure we don't add namespaces that already appear # in one of the ancestors. 
namespaces = frozenset(namespaces) - self.get_effective_namespaces() self.namespaces.update(namespaces) def get_effective_namespaces(self) -> AbstractSet[pathid.Namespace]: namespaces: set[pathid.Namespace] = set() for _node, ans in self.ancestors_and_namespaces: namespaces |= ans return namespaces def remove(self) -> None: """Remove this node from the tree (subtree becomes independent).""" parent = self.parent if parent is not None: parent.remove_subtree(self) def is_empty(self) -> bool: if self.path_id is not None: return False else: return ( not self.children or all(c.is_empty() for c in self.children) ) def get_all_visible(self) -> set[pathid.PathId]: paths = set() for node in self.ancestors: if node.path_id: paths.add(node.path_id) else: for c in node.children: if c.path_id: paths.add(c.path_id) return paths def find_visible_ex( self, path_id: pathid.PathId, *, allow_group: bool=False, ) -> tuple[ Optional[ScopeTreeNode], FenceInfo, AbstractSet[pathid.Namespace], ]: """Find the visible node with the given *path_id*.""" namespaces: set[pathid.Namespace] = set() found = None nodes: list[ScopeTreeNode] = [] for node, ans in self.ancestors_and_namespaces: if (node.path_id is not None and _paths_equal(node.path_id, path_id, namespaces)): found = node break for child in node.children: if (child.path_id is not None and _paths_equal(child.path_id, path_id, namespaces)): found = child break if found is not None: break namespaces |= ans if node is not self: nodes.append(node) finfo = FenceInfo(False, False) for node in nodes: finfo |= node.fence_info_ex(path_id, namespaces) if found and found.is_group and not allow_group: found = None return found, finfo, namespaces def find_visible( self, path_id: pathid.PathId, *, allow_group: bool = False ) -> Optional[ScopeTreeNode]: node, _, _ = self.find_visible_ex(path_id, allow_group=allow_group) return node def is_visible( self, path_id: pathid.PathId, *, allow_group: bool = False ) -> bool: return self.find_visible(path_id, 
allow_group=allow_group) is not None def is_any_prefix_visible(self, path_id: pathid.PathId) -> bool: for prefix in reversed(list(path_id.iter_prefixes())): if self.find_visible(prefix) is not None: return True return False def find_child( self, path_id: pathid.PathId, *, in_branches: bool = False, pfx_with_invariant_card: bool = False, ) -> Optional[ScopeTreeNode]: for child in self.children: if child.path_id == path_id: return child if ( ( in_branches and child.path_id is None and not child.fenced ) or ( pfx_with_invariant_card and child.path_id is not None # Type intersections have invariant cardinality # regardless of prefix visibility. and child.path_id.is_type_intersection_path() ) ): desc = child.find_child( path_id, in_branches=True, pfx_with_invariant_card=pfx_with_invariant_card, ) if desc is not None: return desc return None def find_descendant( self, path_id: pathid.PathId, ) -> Optional[ScopeTreeNode]: for descendant, dns, _ in self.strict_descendants_and_namespaces: if (descendant.path_id is not None and _paths_equal(descendant.path_id, path_id, dns)): return descendant return None def find_descendants( self, path_id: pathid.PathId, ) -> list[ScopeTreeNodeWithPathId]: matched = [] for descendant, dns, _ in self.strict_descendants_and_namespaces: if (has_path_id(descendant) and _paths_equal(descendant.path_id, path_id, dns)): matched.append(descendant) return matched def find_descendant_and_ns(self, path_id: pathid.PathId) -> tuple[ Optional[ScopeTreeNode], AbstractSet[pathid.Namespace], Optional[FenceInfo], ]: for descendant, dns, finfo in self.strict_descendants_and_namespaces: if (descendant.path_id is not None and _paths_equal(descendant.path_id, path_id, dns)): return descendant, dns, finfo return None, frozenset(), None def is_optional_upto(self, ancestor: Optional[ScopeTreeNode]) -> bool: node: Optional[ScopeTreeNode] = self while node and node is not ancestor: if node.optional: return True node = node.parent return False def 
find_factorable_nodes( self, path_id: pathid.PathId, *, child_to_skip: Optional[ScopeTreeNode] = None, ) -> list[ tuple[ ScopeTreeNodeWithPathId, ScopeTreeNode, AbstractSet[pathid.Namespace], AbstractSet[pathid.Namespace], FenceInfo, bool, bool, ] ]: """Find nodes factorable with path_id (if attaching path_id to self) This is done by searching up the tree looking for an ancestor node that has path_id as a descendant such that *at most one* of self and the path_id descendant are fenced. That descendant, then, is a factorable node, and the ancestor is its factoring point. We do this by tracking whether we have passed a fence on our way up the tree, and only looking for unfenced descendants if so. We find all such factorable nodes and return them sorted by factoring point, from closest to furthest up. """ namespaces: AbstractSet[str] = frozenset() unnest_fence_seen = False fence_seen = False points = [] up_finfo = FenceInfo(False, False) # Track the last seen node so that we can skip it while looking # for descendants, to avoid performance pathologies, but also # to avoid rediscovering the same nodes when searching higher # in the tree. last = child_to_skip # Search up the tree for node, ans in self.ancestors_and_namespaces: # For each ancestor, search its descendants for path_id. # If we have passed a fence on the way up, only look for # unfenced descendants. 
for descendant, dns, finfo in ( node.descendants_and_namespaces_ex( unfenced_only=fence_seen, skip=last) ): cns = namespaces | dns if (has_path_id(descendant) and not descendant.is_group and _paths_equal(descendant.path_id, path_id, cns)): points.append(( descendant, node, namespaces, dns, finfo | up_finfo, unnest_fence_seen, fence_seen, )) namespaces |= ans unnest_fence_seen |= node.unnest_fence fence_seen |= node.fenced if node is not self: up_finfo |= node.fence_info_ex(path_id, namespaces) last = node return points def pformat(self) -> str: if self.children: child_formats = [] for c in self.children: cf = c.pformat() if cf: child_formats.append(cf) if child_formats: children = textwrap.indent(',\n'.join(child_formats), ' ') return f'"{self.name}": {{\n{children}\n}}' if self.path_id is not None: return f'"{self.name}"' else: return '' def pdebugformat( self, fuller: bool=False, styles: Optional[Mapping[ScopeTreeNode, term.AbstractStyle]]=None, ) -> str: name = f'"{self.debugname(fuller=fuller)}"' if styles and self in styles: name = styles[self].apply(name) if self.children: child_formats = [] for c in self.children: cf = c.pdebugformat(fuller=fuller, styles=styles) if cf: child_formats.append(cf) children = textwrap.indent(',\n'.join(child_formats), ' ') return f'{name}: {{\n{children}\n}}' else: return name def dump(self) -> None: print(self.pdebugformat()) def dump_full(self, others: Collection[ScopeTreeNode] = ()) -> None: """Do a debug dump of the root but highlight the current node.""" styles = {} if term.supports_colors(sys.stdout.fileno()): styles[self] = term.Style16(color='magenta', bold=True) for other in others: styles[other] = term.Style16(color='blue', bold=True) print(self.root.pdebugformat(styles=styles)) def _set_parent(self, parent: Optional[ScopeTreeNode]) -> None: assert self is not parent current_parent = self.parent if parent is current_parent: return if current_parent is not None: # Make sure no other node refers to us. 
current_parent.children.remove(self) if parent is not None: self._parent = weakref.ref(parent) parent.children.append(self) else: self._parent = None class ScopeTreeNodeWithPathId(ScopeTreeNode): path_id: pathid.PathId def _paths_equal( path_id_1: pathid.PathId, path_id_2: pathid.PathId, namespaces: AbstractSet[str], ) -> bool: if namespaces: path_id_1 = path_id_1.strip_namespace(namespaces) path_id_2 = path_id_2.strip_namespace(namespaces) return path_id_1 == path_id_2 ================================================ FILE: edb/ir/staeval.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# """Static evaluation of EdgeQL IR.""" from __future__ import annotations from typing import ( Any, Optional, TypeVar, ) import decimal import functools import uuid import immutables from edb import errors from edb.common import typeutils from edb.common import parsing from edb.common import value_dispatch from edb.edgeql import ast as qlast from edb.edgeql import compiler as qlcompiler from edb.edgeql import qltypes from edb.ir import ast as irast from edb.ir import typeutils as irtyputils from edb.ir import statypes as statypes from edb.ir import utils as irutils from edb.schema import name as sn from edb.schema import objects as s_obj from edb.schema import objtypes as s_objtypes from edb.schema import types as s_types from edb.schema import scalars as s_scalars from edb.schema import schema as s_schema from edb.schema import constraints as s_constr from edb.schema import pointers as s_pointers from edb.server import config class StaticEvaluationError(errors.QueryError): pass class UnsupportedExpressionError(errors.QueryError): pass EvaluationResult = irast.TypeCast | irast.ConstExpr | irast.Array | irast.Tuple def evaluate_to_python_val( ir: irast.Base, schema: s_schema.Schema, ) -> Any: const = evaluate(ir, schema=schema) return const_to_python(const, schema=schema) @functools.singledispatch def evaluate(ir: irast.Base, schema: s_schema.Schema) -> EvaluationResult: raise UnsupportedExpressionError( f'no static IR evaluation handler for {ir.__class__}') @evaluate.register(irast.SelectStmt) def evaluate_SelectStmt( ir_stmt: irast.SelectStmt, schema: s_schema.Schema ) -> EvaluationResult: if irutils.is_trivial_select(ir_stmt) and not ir_stmt.result.is_binding: return evaluate(ir_stmt.result, schema) else: raise UnsupportedExpressionError( 'expression is not constant', span=ir_stmt.span) @evaluate.register(irast.InsertStmt) def evaluate_InsertStmt( ir: irast.InsertStmt, schema: s_schema.Schema ) -> EvaluationResult: # InsertStmt should NOT be statically evaluated 
in general; # This is a special case for inserting nested cfg::ConfigObject # when it's evaluated into a named tuple and then squashed into # a Python dict to be used in compile_structured_config(). tmp_schema, subject_type = irtyputils.ir_typeref_to_type( schema, ir.subject.expr.typeref ) config_obj = schema.get("cfg::ConfigObject") assert isinstance(config_obj, s_obj.SubclassableObject) if subject_type.issubclass(tmp_schema, config_obj): return irast.Tuple( named=True, typeref=ir.subject.typeref, elements=[ irast.TupleElement( name=ptr_set.expr.ptrref.shortname.name, val=irast.Set( expr=evaluate(ptr_set.expr.expr, schema), typeref=ptr_set.typeref, path_id=ptr_set.path_id, ), ) for ptr_set, _ in ir.subject.shape if ptr_set.expr.ptrref.shortname.name != "id" and ptr_set.expr.expr is not None ], ) raise UnsupportedExpressionError( f'no static IR evaluation handler for general {ir.__class__}' ) @evaluate.register(irast.TypeIntrospection) def evaluate_TypeIntrospection( ir: irast.TypeIntrospection, schema: s_schema.Schema ) -> EvaluationResult: return irast.StaticIntrospection( named=True, ir=ir, schema=schema, elements=[], typeref=ir.typeref ) @evaluate.register(irast.TypeCast) def evaluate_TypeCast( ir_cast: irast.TypeCast, schema: s_schema.Schema ) -> EvaluationResult: schema, from_type = irtyputils.ir_typeref_to_type( schema, ir_cast.from_type) schema, to_type = irtyputils.ir_typeref_to_type( schema, ir_cast.to_type) if ( not isinstance(from_type, s_scalars.ScalarType) or not isinstance(to_type, s_scalars.ScalarType) ): raise UnsupportedExpressionError('object cast not supported') scalar_type_to_python_type(from_type, schema) scalar_type_to_python_type(to_type, schema) evaluate(ir_cast.expr, schema) return ir_cast @evaluate.register(irast.EmptySet) def evaluate_EmptySet( ir_set: irast.EmptySet, schema: s_schema.Schema ) -> EvaluationResult: return ir_set @evaluate.register(irast.Set) def evaluate_Set( ir_set: irast.Set, schema: s_schema.Schema) -> 
EvaluationResult: return evaluate(ir_set.expr, schema=schema) @evaluate.register def evaluate_Pointer( ptr: irast.Pointer, schema: s_schema.Schema ) -> EvaluationResult: if ptr.expr is not None: return evaluate(ptr.expr, schema=schema) elif ( ptr.direction == s_pointers.PointerDirection.Outbound and isinstance(ptr.ptrref, irast.PointerRef) and ptr.ptrref.out_cardinality.is_single() and ptr.ptrref.out_target.is_scalar ): return evaluate_pointer_ref( evaluate(ptr.source.expr, schema=schema), ptr.ptrref ) else: raise UnsupportedExpressionError( 'expression is not constant', span=ptr.span) @functools.singledispatch def evaluate_pointer_ref( evaluated_source: EvaluationResult, ptrref: irast.PointerRef ) -> EvaluationResult: raise UnsupportedExpressionError( f'unsupported PointerRef on source {evaluated_source}', span=ptrref.span, ) @evaluate_pointer_ref.register(irast.StaticIntrospection) def evaluate_pointer_ref_StaticIntrospection( source: irast.StaticIntrospection, ptrref: irast.PointerRef ) -> EvaluationResult: return source.get_field_value(ptrref.shortname) @evaluate.register(irast.ConstExpr) def evaluate_BaseConstant( ir_const: irast.ConstExpr, schema: s_schema.Schema ) -> EvaluationResult: return ir_const @evaluate.register(irast.Array) def evaluate_Array( ir: irast.Array, schema: s_schema.Schema ) -> EvaluationResult: return irast.Array( elements=tuple( x.replace(expr=evaluate(x, schema)) for x in ir.elements ), typeref=ir.typeref, ) @evaluate.register(irast.Tuple) def evaluate_Tuple( ir: irast.Tuple, schema: s_schema.Schema ) -> EvaluationResult: return irast.Tuple( named=ir.named, elements=[ x.replace( val=x.val.replace( expr=evaluate(x.val, schema) ), ) for x in ir.elements ], typeref=ir.typeref, ) def _process_op_result( value: object, typeref: irast.TypeRef, schema: s_schema.Schema, *, span: Optional[parsing.Span]=None, ) -> irast.ConstExpr: qlconst: qlast.BaseConstant if isinstance(value, str): qlconst = qlast.Constant.string(value) elif isinstance(value, 
bool): qlconst = qlast.Constant.boolean(value) else: raise UnsupportedExpressionError( f"unsupported result type: {type(value)}", span=span ) result = qlcompiler.compile_constant_tree_to_ir( qlconst, styperef=typeref, schema=schema) assert isinstance(result, irast.ConstExpr), 'expected ConstExpr' return result op_table = { # Concatenation ('Infix', 'std::++'): lambda a, b: a + b, ('Infix', 'std::>='): lambda a, b: a >= b, ('Infix', 'std::>'): lambda a, b: a > b, ('Infix', 'std::<='): lambda a, b: a <= b, ('Infix', 'std::<'): lambda a, b: a < b, ('Infix', 'std::='): lambda a, b: a == b, ('Infix', 'std::!='): lambda a, b: a != b, } @evaluate.register(irast.OperatorCall) def evaluate_OperatorCall( opcall: irast.OperatorCall, schema: s_schema.Schema ) -> irast.ConstExpr: if irutils.is_union_expr(opcall): return _evaluate_union(opcall, schema) eval_func = op_table.get( (opcall.operator_kind, str(opcall.func_shortname)), ) if eval_func is None: raise UnsupportedExpressionError( f'unsupported operator: {opcall.func_shortname}', span=opcall.span) args: dict[int, irast.CallArg] = {} for key, arg in opcall.args.items(): arg_val = evaluate_to_python_val(arg.expr, schema=schema) if isinstance(arg_val, tuple): raise UnsupportedExpressionError( f'non-singleton operations are not supported', span=opcall.span) if arg_val is None: raise UnsupportedExpressionError( f'empty operations are not supported', span=opcall.span) if isinstance(key, str): raise UnsupportedExpressionError( f'named arguments are not allowed for operators', span=opcall.span) args[key] = arg_val args_list: list[irast.CallArg] = [] for key in range(len(args)): if key not in args: raise UnsupportedExpressionError( f'missing positional argument {key}', span=opcall.span) args_list.append(args[key]) value = eval_func(*args_list) return _process_op_result( value, opcall.typeref, schema, span=opcall.span) @evaluate.register(irast.SliceIndirection) def evaluate_SliceIndirection( slice: irast.SliceIndirection, schema: 
s_schema.Schema ) -> irast.ConstExpr: args = [slice.expr, slice.start, slice.stop] vals = [ evaluate_to_python_val(arg, schema=schema) if arg else None for arg in args ] for arg, arg_val in zip(args, vals): if arg is None: continue if isinstance(arg_val, tuple): raise UnsupportedExpressionError( f'non-singleton operations are not supported', span=slice.span) if arg_val is None: raise UnsupportedExpressionError( f'empty operations are not supported', span=slice.span) base, start, stop = vals value = base[start:stop] # type: ignore[index] return _process_op_result( value, slice.expr.typeref, schema, span=slice.span) def _evaluate_union( opcall: irast.OperatorCall, schema: s_schema.Schema ) -> irast.ConstExpr: elements: list[irast.BaseConstant] = [] for arg in opcall.args.values(): val = evaluate(arg.expr, schema=schema) if isinstance(val, irast.TypeCast): val = evaluate(val.expr, schema=schema) if isinstance(val, irast.ConstantSet): for el in val.elements: if isinstance(el, irast.BaseParameter): raise UnsupportedExpressionError( f'{el!r} not supported in UNION', span=opcall.span) elements.append(el) elif isinstance(val, irast.EmptySet): empty_set = val elif isinstance(val, irast.BaseConstant): elements.append(val) else: raise UnsupportedExpressionError( f'{val!r} not supported in UNION', span=opcall.span) if elements: return irast.ConstantSet( elements=tuple(elements), typeref=next(iter(elements)).typeref, ) else: # We get an empty set if the UNION was exclusively empty set # literals. If that happens, grab one of the empty sets # that we saw and return it. 
return empty_set @functools.singledispatch def const_to_python(ir: irast.Expr | None, schema: s_schema.Schema) -> Any: raise UnsupportedExpressionError(f'cannot convert {ir!r} to Python value') @const_to_python.register(irast.EmptySet) def empty_set_to_python( ir: irast.EmptySet, schema: s_schema.Schema, ) -> None: return None @const_to_python.register(irast.ConstantSet) def const_set_to_python( ir: irast.ConstantSet, schema: s_schema.Schema ) -> tuple[Any, ...]: return tuple(const_to_python(v, schema) for v in ir.elements) @const_to_python.register(irast.Array) def array_const_to_python(ir: irast.Array, schema: s_schema.Schema) -> Any: return [const_to_python(x.expr, schema) for x in ir.elements] @const_to_python.register(irast.Tuple) def tuple_const_to_python(ir: irast.Tuple, schema: s_schema.Schema) -> Any: if ir.named: return { x.name: const_to_python(x.val.expr, schema) for x in ir.elements } else: return tuple( const_to_python(x.val.expr, schema) for x in ir.elements ) @const_to_python.register(irast.IntegerConstant) def int_const_to_python( ir: irast.IntegerConstant, schema: s_schema.Schema ) -> Any: stype = schema.get_by_id(ir.typeref.id) assert isinstance(stype, s_types.Type) bigint = schema.get('std::bigint', type=s_obj.SubclassableObject) if stype.issubclass(schema, bigint): return decimal.Decimal(ir.value) else: return int(ir.value) @const_to_python.register(irast.FloatConstant) def float_const_to_python( ir: irast.FloatConstant, schema: s_schema.Schema ) -> Any: stype = schema.get_by_id(ir.typeref.id) assert isinstance(stype, s_types.Type) bigint = schema.get('std::bigint', type=s_obj.SubclassableObject) if stype.issubclass(schema, bigint): return decimal.Decimal(ir.value) else: return float(ir.value) @const_to_python.register(irast.StringConstant) def str_const_to_python( ir: irast.StringConstant, schema: s_schema.Schema ) -> Any: return ir.value @const_to_python.register(irast.BooleanConstant) def bool_const_to_python( ir: irast.BooleanConstant, 
schema: s_schema.Schema ) -> Any: return ir.value == 'true' @const_to_python.register(irast.TypeCast) def cast_const_to_python(ir: irast.TypeCast, schema: s_schema.Schema) -> Any: schema, stype = irtyputils.ir_typeref_to_type(schema, ir.to_type) if not isinstance(stype, s_scalars.ScalarType): raise UnsupportedExpressionError( "non-scalar casts are not supported in Python eval") pytype = scalar_type_to_python_type(stype, schema) sval = evaluate_to_python_val(ir.expr, schema=schema) return python_cast(sval, pytype) @functools.singledispatch def python_cast(sval: Any, pytype: type) -> Any: return pytype(sval) @python_cast.register(type(None)) def python_cast_none(sval: None, pytype: type) -> None: return None @python_cast.register(tuple) def python_cast_tuple(sval: tuple[Any, ...], pytype: type) -> Any: return tuple(python_cast(elem, pytype) for elem in sval) @python_cast.register(str) def python_cast_str(sval: str, pytype: type) -> Any: if pytype is bool: if sval.lower() == 'true': return True elif sval.lower() == 'false': return False else: raise errors.InvalidValueError( f"invalid input syntax for type bool: {sval!r}", hint="bool value can only be one of: true, false" ) else: return pytype(sval) def schema_type_to_python_type( stype: s_types.Type, schema: s_schema.Schema ) -> type | statypes.CompositeTypeSpec: if isinstance(stype, s_scalars.ScalarType): return scalar_type_to_python_type(stype, schema) elif isinstance(stype, s_objtypes.ObjectType): return object_type_to_spec( stype, schema, spec_class=statypes.CompositeTypeSpec) else: raise UnsupportedExpressionError( f'{stype.get_displayname(schema)} is not representable in Python') def scalar_type_to_python_type( stype: s_scalars.ScalarType, schema: s_schema.Schema, ) -> type: typname = stype.get_name(schema) pytype = statypes.maybe_get_python_type_for_scalar_type_name(str(typname)) if pytype is None: for ancestor in stype.get_ancestors(schema).objects(schema): typname = ancestor.get_name(schema) pytype = 
statypes.maybe_get_python_type_for_scalar_type_name( str(typname)) if pytype is not None: break if pytype is not None: return pytype elif stype.is_enum(schema): return str raise UnsupportedExpressionError( f'{stype.get_displayname(schema)} is not representable in Python') T_spec = TypeVar('T_spec', bound=statypes.CompositeTypeSpec) class _Missing: pass def object_type_to_spec( objtype: s_objtypes.ObjectType, schema: s_schema.Schema, *, # We pass a spec_class so that users like the config system can ask for # their own subtyped versions of a spec. spec_class: type[T_spec], parent: Optional[T_spec] = None, _memo: Optional[dict[s_types.Type, T_spec | type]] = None, ) -> T_spec: if _memo is None: _memo = {} # Prevent infinite recursion _memo[objtype] = _Missing default: Any fields = {} for pn, p in objtype.get_pointers(schema).items(schema): assert isinstance(p, s_pointers.Pointer) str_pn = str(pn) if str_pn in ('id', '__type__'): continue ptype = p.get_target(schema) assert ptype is not None if isinstance(ptype, s_objtypes.ObjectType): pytype = _memo.get(ptype) if pytype is _Missing: raise UnsupportedExpressionError() if pytype is None: pytype = object_type_to_spec( ptype, schema, spec_class=spec_class, parent=parent, _memo=_memo) _memo[ptype] = pytype elif isinstance(ptype, s_scalars.ScalarType): pytype = scalar_type_to_python_type(ptype, schema) else: raise UnsupportedExpressionError(f"unsupported cast type: {ptype}") ptr_card: qltypes.SchemaCardinality = p.get_cardinality(schema) if ptr_card.is_known(): is_multi = ptr_card.is_multi() else: raise UnsupportedExpressionError() if is_multi: pytype = frozenset[pytype] # type: ignore default = p.get_default(schema) if default is None: if p.get_required(schema): default = statypes.MISSING else: default = qlcompiler.evaluate_to_python_val( default.text, schema=schema) if is_multi and not isinstance(default, frozenset): default = frozenset((default,)) constraints = p.get_constraints(schema).objects(schema) exclusive = 
schema.get('std::exclusive', type=s_constr.Constraint) unique = ( not ptype.is_object_type() and any( c.issubclass(schema, exclusive) and not c.get_delegated(schema) for c in constraints ) ) fields[str_pn] = statypes.CompositeTypeSpecField( name=str_pn, type=pytype, unique=unique, default=default, secret=p.get_secret(schema), protected=p.get_protected(schema), ) spec = spec_class( name=str(objtype.get_name(schema)), fields=immutables.Map(fields), parent=parent, ) for subtype in objtype.children(schema): spec.children.append( object_type_to_spec( subtype, schema, spec_class=spec_class, parent=spec, _memo=_memo)) return spec @functools.singledispatch def evaluate_to_config_op( ir: irast.Base, schema: s_schema.Schema ) -> config.Operation: raise UnsupportedExpressionError( f'no config op evaluation handler for {ir.__class__}') @evaluate_to_config_op.register(irast.ConfigSet) def evaluate_config_set( ir: irast.ConfigSet, schema: s_schema.Schema ) -> config.Operation: if ir.scope == qltypes.ConfigScope.GLOBAL: raise UnsupportedExpressionError( 'SET GLOBAL is not supported by static eval' ) value = evaluate_to_python_val(ir.expr, schema) if ir.cardinality is qltypes.SchemaCardinality.Many: if value is None: value = [] elif not typeutils.is_container(value): value = [value] return config.Operation( opcode=config.OpCode.CONFIG_SET, scope=ir.scope, setting_name=ir.name, value=value, ) @evaluate_to_config_op.register(irast.ConfigReset) def evaluate_config_reset( ir: irast.ConfigReset, schema: s_schema.Schema ) -> config.Operation: if ir.selector is not None: raise UnsupportedExpressionError( 'filtered CONFIGURE RESET is not supported by static eval' ) return config.Operation( opcode=config.OpCode.CONFIG_RESET, scope=ir.scope, setting_name=ir.name, value=None, ) @evaluate_to_config_op.register(irast.ConfigInsert) def evaluate_config_insert( ir: irast.ConfigInsert, schema: s_schema.Schema ) -> config.Operation: return config.Operation( opcode=config.OpCode.CONFIG_ADD, 
scope=ir.scope, setting_name=ir.name, value=evaluate_to_python_val( irast.InsertStmt(subject=ir.expr), schema=schema ), ) @value_dispatch.value_dispatch def coerce_py_const( type_id: uuid.UUID, val: Any ) -> irast.ConstExpr | irast.TypeCast: raise UnsupportedExpressionError(f"unimplemented coerce type: {type_id}") @coerce_py_const.register(s_obj.get_known_type_id("std::str")) def evaluate_std_str( type_id: uuid.UUID, val: Any ) -> irast.ConstExpr | irast.TypeCast: return irast.StringConstant( typeref=irast.TypeRef( id=type_id, name_hint=sn.name_from_string("std::str") ), value=str(val), ) @coerce_py_const.register(s_obj.get_known_type_id("std::bool")) def evaluate_std_bool( type_id: uuid.UUID, val: Any ) -> irast.ConstExpr | irast.TypeCast: return irast.BooleanConstant( typeref=irast.TypeRef( id=type_id, name_hint=sn.name_from_string("std::bool") ), value=str(bool(val)).lower(), ) @coerce_py_const.register(s_obj.get_known_type_id("std::uuid")) def evaluate_std_uuid( type_id: uuid.UUID, val: Any ) -> irast.ConstExpr | irast.TypeCast: str_type_id = s_obj.get_known_type_id("std::str") str_typeref = irast.TypeRef( id=str_type_id, name_hint=sn.name_from_string("std::str") ) return irast.TypeCast( from_type=str_typeref, to_type=irast.TypeRef( id=type_id, name_hint=sn.name_from_string("std::uuid") ), expr=irast.Set( expr=irast.StringConstant(typeref=str_typeref, value=str(val)), typeref=str_typeref, path_id=irast.PathId.from_typeref(str_typeref), ), sql_cast=True, sql_expr=False, ) ================================================ FILE: edb/ir/statypes.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2021-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Callable, ClassVar, Mapping, Optional, Self, TYPE_CHECKING, ) import dataclasses import datetime import decimal import enum import functools import re import struct import uuid import immutables from edb import errors from edb.common import parametric from edb.common import uuidgen from edb.schema import name as s_name from edb.schema import objects as s_obj if TYPE_CHECKING: from edb.edgeql import qltypes MISSING: Any = object() @dataclasses.dataclass(frozen=True) class CompositeTypeSpecField: name: str type: type | CompositeTypeSpec _: dataclasses.KW_ONLY unique: bool = False default: Any = MISSING secret: bool = False protected: bool = False @dataclasses.dataclass(frozen=True, kw_only=True) class CompositeTypeSpec: name: str fields: immutables.Map[str, CompositeTypeSpecField] parent: Optional[CompositeTypeSpec] = None children: list[CompositeTypeSpec] = dataclasses.field( default_factory=list, hash=False, compare=False ) has_secret: bool = False def __post_init__(self) -> None: has_secret = any( field.secret or ( isinstance(field, CompositeTypeSpec) # We look at children of pointer targets, and not # children of the object itself, on the idea that for # config objects, omitting individual top level # objects with secrets should be fine. 
and ( field.has_secret or any(child.has_secret for child in field.children) ) ) for field in self.fields.values() ) object.__setattr__(self, 'has_secret', has_secret) @property def __name__(self) -> str: return self.name def get_field_unique_site(self, name: str) -> Optional[CompositeTypeSpec]: typ: Optional[CompositeTypeSpec] = self site: Optional[CompositeTypeSpec] = None while typ: if name in typ.fields and typ.fields[name].unique: site = typ typ = typ.parent return site class CompositeType: _tspec: CompositeTypeSpec def to_json_value(self, redacted: bool = False) -> dict[str, Any]: raise NotImplementedError class ScalarType: def __init__(self, val: str, /) -> None: raise NotImplementedError def to_backend_str(self) -> str: raise NotImplementedError @classmethod def to_backend_expr(cls, expr: str) -> str: raise NotImplementedError(f"{cls}.to_backend_expr()") @classmethod def to_frontend_expr(cls, expr: str) -> Optional[str]: raise NotImplementedError(f"{cls}.to_frontend_expr()") def to_json(self) -> str: raise NotImplementedError def encode(self) -> bytes: raise NotImplementedError @classmethod def decode(cls, data: bytes) -> ScalarType: raise NotImplementedError @functools.total_ordering class Duration(ScalarType): _pg_simple_parser = re.compile(r''' ^ \s* ( (?P<sign>(\+|\-)?) ) ( (?P<hours>\d+) ) : ( (?P<minutes>\d+)? ( :(?P<seconds>\d+) ( \.(?P<milliseconds>\d{0,3}) (?P<microseconds>\d{0,3}) (?P<submicro>\d*) )? )? )? \s* $ ''', re.X) _pg_parser = re.compile(r''' ( ( \s* (?P<hours>(\+|\-)?\d+) \s* (h|hr|hrs|hour|hours) \s* ) | ( \s* (?P<minutes>(\+|\-)?\d+) \s* (m|min|mins|minute|minutes) \s* ) | ( \s* (?P<milliseconds>(\+|\-)?\d+) \s* (ms | (millisecon(s|d|ds)?)) # '12 millisecon' is valid \s* ) | ( \s* (?P<microseconds>(\+|\-)?\d+) \s* (us | (microsecond(s)?)) \s* ) | ( \s* (?P<seconds>(\+|\-)?\d+) ( (\s* $) | (\s* (s|sec|secs|second|seconds)) ) \s* ) )(?=$ | \d | \s) | ( \s* (?P<error>.+) ) ''', re.X | re.I) _iso_parser = re.compile(r''' ^ PT ( (?P<hours>(\+|\-)?\d+) H )? ( (?P<minutes>(\+|\-)?\d+) M )? ( ( (?P<secsign>\+|\-)? (?P<seconds>\d+) ( \. (?P<microseconds>\d+) )? ) S )? 
$ ''', re.X) _codec = struct.Struct('!QLL') _value: int # microseconds def __init__( self, pg_text: str = '', /, *, microseconds: Optional[int] = None ) -> None: if pg_text == '' and microseconds is not None: self._value = microseconds else: self._value = self._us_from_pg_text(pg_text) def _us_from_pg_text(self, input: str, /) -> int: try: seconds = int(input) except ValueError: pass else: return seconds * 1000 * 1000 m = self._pg_simple_parser.match(input) if m is not None: value = 0 parsed = m.groupdict() if parsed['hours']: hours = int(parsed['hours']) if 0 <= hours <= 2147483647: value += hours * 3600_000_000 else: raise errors.NumericOutOfRangeError( 'interval field value out of range') if parsed['minutes']: mins = int(parsed['minutes']) if 0 <= mins <= 59: value += mins * 60_000_000 else: raise errors.NumericOutOfRangeError( 'interval field value out of range') if parsed['seconds']: secs = int(parsed['seconds']) if 0 <= secs <= 59: value += secs * 1_000_000 else: raise errors.NumericOutOfRangeError( 'interval field value out of range') if parsed['milliseconds']: value += int(parsed['milliseconds'].ljust(3, '0')) * 1_000 if parsed['microseconds']: value += int(parsed['microseconds'].ljust(3, '0')) if parsed['submicro'] and int(parsed['submicro'][:1]) >= 5: value += 1 if parsed['sign'] == '-': value = -value return value if (parsed_iso := self._parse_iso8601(input)) is not None: return parsed_iso value = 0 seen: set[str] = set() for m in self._pg_parser.finditer(input): filtered = { k: v for k, v in m.groupdict().items() if v is not None } if len(filtered) != 1: raise errors.InvalidValueError( 'invalid input syntax for type std::duration') kind, val = next(iter(filtered.items())) if kind == 'error': raise errors.InvalidValueError( f'invalid input syntax for type std::duration: ' f'unable to parse {val!r}') if kind in seen: raise errors.InvalidValueError( f'invalid input syntax for type std::duration: ' f'the {kind!r} component has been specified ' f'more than 
once') seen.add(kind) intval = int(val) if kind == 'hours': value += intval * 3600_000_000 elif kind == 'minutes': value += intval * 60_000_000 elif kind == 'seconds': value += intval * 1_000_000 elif kind == 'milliseconds': value += intval * 1_000 elif kind == 'microseconds': value += intval return value @classmethod def _parse_iso8601(cls, input: str, /) -> Optional[int]: m = cls._iso_parser.match(input) if not m: return None value = 0 if m['hours']: value += int(m['hours']) * 3600_000_000 if m['minutes']: value += int(m['minutes']) * 60_000_000 secsign = -1 if m['secsign'] == '-' else +1 if m['seconds']: value += int(m['seconds']) * 1_000_000 * secsign if m['microseconds']: ms = m['microseconds'][:6] ms = ms.ljust(6, '0') value += int(ms) * secsign return value @classmethod def from_iso8601(cls, input: str, /) -> Duration: val = cls._parse_iso8601(input) if val is None: raise errors.InvalidValueError( f'invalid input syntax for type std::duration: ' f'cannot parse {input!r} as ISO 8601') return cls(microseconds=val) @classmethod def from_microseconds(cls, input: int, /) -> Duration: return cls(microseconds=input) def to_microseconds(self) -> int: return self._value def __lt__(self, other: Duration) -> bool: return self._value < other._value def to_iso8601(self) -> str: neg = '-' if self._value < 0 else '' seconds, usecs = divmod(abs(self._value), 1_000_000) minutes, seconds = divmod(seconds, 60) hours, minutes = divmod(minutes, 60) ret = ['PT'] if hours: ret.append(f'{neg}{hours}H') if minutes: ret.append(f'{neg}{minutes}M') if seconds or usecs: if usecs: ret.append(f"{neg}{seconds}.") ret.append(f"{str(usecs).rjust(6, '0')}"[:6].rstrip('0')) else: ret.append(f'{neg}{seconds}') ret.append('S') if ret == ['PT']: ret.append('0S') return ''.join(ret) def to_timedelta(self) -> datetime.timedelta: return datetime.timedelta(microseconds=self.to_microseconds()) def to_backend_str(self) -> str: return f'{self.to_microseconds()}us' @classmethod def to_backend_expr(cls, 
expr: str) -> str: return f"edgedb_VER._interval_to_ms(({expr})::interval)::text || 'ms'" @classmethod def to_frontend_expr(cls, expr: str) -> Optional[str]: return None def to_json(self) -> str: return self.to_iso8601() def __repr__(self) -> str: return f'<statypes.Duration {self.to_iso8601()!r}>' def encode(self) -> bytes: return self._codec.pack(self._value, 0, 0) @classmethod def decode(cls, data: bytes) -> Duration: return cls(microseconds=cls._codec.unpack(data)[0]) def __hash__(self) -> int: return hash(self._value) def __eq__(self, other: object) -> bool: if isinstance(other, Duration): return self._value == other._value else: return False @functools.total_ordering class ConfigMemory(ScalarType): PiB = 1024 * 1024 * 1024 * 1024 * 1024 TiB = 1024 * 1024 * 1024 * 1024 GiB = 1024 * 1024 * 1024 MiB = 1024 * 1024 KiB = 1024 _parser = re.compile(r''' ^ (?P<num>\d+) (?P<unit>B|KiB|MiB|GiB|TiB|PiB) $ ''', re.X) _value: int def __init__( self, val: str | int, /, ) -> None: if isinstance(val, int): self._value = val elif isinstance(val, str): text = val if text == '0': self._value = 0 return m = self._parser.match(text) if m is None: raise errors.InvalidValueError( f'unable to parse memory size: {text!r}') num = int(m.group('num')) unit = m.group('unit') if unit == 'B': self._value = num elif unit == 'KiB': self._value = num * self.KiB elif unit == 'MiB': self._value = num * self.MiB elif unit == 'GiB': self._value = num * self.GiB elif unit == 'TiB': self._value = num * self.TiB elif unit == 'PiB': self._value = num * self.PiB else: raise AssertionError('unexpected unit') else: raise ValueError( f"invalid ConfigMemory value: {type(val)}, expected int | str") def __lt__(self, other: ConfigMemory) -> bool: return self._value < other._value def to_nbytes(self) -> int: return self._value def to_str(self) -> str: if self._value >= self.PiB and self._value % self.PiB == 0: return f'{self._value // self.PiB}PiB' if self._value >= self.TiB and self._value % self.TiB == 0: return f'{self._value // self.TiB}TiB' if self._value 
>= self.GiB and self._value % self.GiB == 0: return f'{self._value // self.GiB}GiB' if self._value >= self.MiB and self._value % self.MiB == 0: return f'{self._value // self.MiB}MiB' if self._value >= self.KiB and self._value % self.KiB == 0: return f'{self._value // self.KiB}KiB' return f'{self._value}B' def to_backend_str(self) -> str: if self._value >= self.TiB and self._value % self.TiB == 0: return f'{self._value // self.TiB}TB' if self._value >= self.GiB and self._value % self.GiB == 0: return f'{self._value // self.GiB}GB' if self._value >= self.MiB and self._value % self.MiB == 0: return f'{self._value // self.MiB}MB' if self._value >= self.KiB and self._value % self.KiB == 0: return f'{self._value // self.KiB}kB' return f'{self._value}B' @classmethod def to_backend_expr(cls, expr: str) -> str: return f"edgedb_VER.cfg_memory_to_str({expr})" @classmethod def to_frontend_expr(cls, expr: str) -> Optional[str]: return f"(edgedb_VER.str_to_cfg_memory({expr})::text || 'B')" def to_json(self) -> str: return self.to_str() def __repr__(self) -> str: return f'<statypes.ConfigMemory {self.to_str()!r}>' def __hash__(self) -> int: return hash(self._value) def __eq__(self, other: Any) -> bool: if isinstance(other, ConfigMemory): return self._value == other._value else: return False typemap = { 'std::str': str, 'std::anyint': int, 'std::anyfloat': float, 'std::decimal': decimal.Decimal, 'std::bigint': decimal.Decimal, 'std::bool': bool, 'std::json': str, 'std::uuid': uuidgen.UUID, 'std::duration': Duration, 'cfg::memory': ConfigMemory, } def maybe_get_python_type_for_scalar_type_name(name: str) -> Optional[type]: return typemap.get(name) class EnumScalarType[E: enum.StrEnum]( ScalarType, parametric.SingleParametricType[E], ): """Configuration value represented by a custom string enum type that supports arbitrary value mapping to backend (Postgres) configuration values, e.g. mapping "Enabled"/"Disabled" enum to a bool value, etc. 
We use SingleParametricType to obtain runtime access to the Generic type arg to avoid having to copy-paste the constructors. """ _val: E _eql_type: ClassVar[Optional[s_name.QualName]] def __init_subclass__( cls, *, edgeql_type: Optional[str] = None, **kwargs: Any, ) -> None: global typemap super().__init_subclass__(**kwargs) if edgeql_type is not None: if edgeql_type in typemap: raise TypeError( f"{edgeql_type} is already a registered EnumScalarType") typemap[edgeql_type] = cls cls._eql_type = s_name.QualName.from_string(edgeql_type) def __init__( self, val: E | str, ) -> None: if isinstance(val, self.type): self._val = val elif isinstance(val, str): try: self._val = self.type(val) except ValueError: raise errors.InvalidValueError( f'unexpected backend value for ' f'{self.__class__.__name__}: {val!r}' ) from None def to_str(self) -> str: return str(self._val) def to_json(self) -> str: return self._val def encode(self) -> bytes: return self._val.encode("utf8") @classmethod def get_translation_map(cls) -> Mapping[E, str]: raise NotImplementedError @classmethod def decode(cls, data: bytes) -> Self: return cls(val=cls.type(data.decode("utf8"))) def __repr__(self) -> str: return f"<statypes.{self.__class__.__name__} '{self._val}'>" def __hash__(self) -> int: return hash(self._val) def __eq__(self, other: Any) -> bool: if isinstance(other, type(self)): return self._val == other._val else: return NotImplemented def __reduce__(self) -> tuple[ Callable[..., EnumScalarType[Any]], tuple[ Optional[tuple[type, ...] | type], E, ], ]: assert type(self).is_fully_resolved(), \ f'{type(self)} parameters are not resolved' cls: type[EnumScalarType[E]] = self.__class__ types: Optional[tuple[type, ...]] = self.orig_args if types is None or not cls.is_anon_parametrized(): typeargs = None else: typeargs = types[0] if len(types) == 1 else types return (cls.__restore__, (typeargs, self._val)) @classmethod def __restore__( cls, typeargs: Optional[tuple[type, ...] 
| type], val: E, ) -> Self: if typeargs is None or cls.is_anon_parametrized(): obj = cls(val) else: obj = cls[typeargs](val) # type: ignore[index] return obj @classmethod def get_edgeql_typeid(cls) -> uuid.UUID: return s_obj.get_known_type_id('std::str') @classmethod def get_edgeql_type(cls) -> s_name.QualName: """Return fully-qualified name of the scalar type for this setting.""" assert cls._eql_type is not None return cls._eql_type def to_backend_str(self) -> str: """Convert static frontend config value to backend config value.""" return self.get_translation_map()[self._val] @classmethod def to_backend_expr(cls, expr: str) -> str: """Convert dynamic frontend config value to backend config value.""" cases_list = [] for fe_val, be_val in cls.get_translation_map().items(): cases_list.append(f"WHEN lower('{fe_val}') THEN '{be_val}'") cases = "\n".join(cases_list) errmsg = f"unexpected frontend value for {cls.__name__}: %s" err = f"edgedb_VER.raise(NULL::text, msg => format('{errmsg}', v))" return ( f"(SELECT CASE v\n{cases}\nELSE\n{err}\nEND " f"FROM lower(({expr})) AS f(v))" ) @classmethod def to_frontend_expr(cls, expr: str) -> Optional[str]: """Convert dynamic backend config value to frontend config value.""" cases_list = [] for fe_val, be_val in cls.get_translation_map().items(): cases_list.append(f"WHEN lower('{be_val}') THEN '{fe_val}'") cases = "\n".join(cases_list) errmsg = f"unexpected backend value for {cls.__name__}: %s" err = f"edgedb_VER.raise(NULL::text, msg => format('{errmsg}', v))" return ( f"(SELECT CASE v\n{cases}\nELSE\n{err}\nEND " f"FROM lower(({expr})) AS f(v))" ) class EnabledDisabledEnum(enum.StrEnum): Enabled = "Enabled" Disabled = "Disabled" class EnabledDisabledType( EnumScalarType[EnabledDisabledEnum], edgeql_type="cfg::TestEnabledDisabledEnum", ): @classmethod def get_translation_map(cls) -> Mapping[EnabledDisabledEnum, str]: return { EnabledDisabledEnum.Enabled: "true", EnabledDisabledEnum.Disabled: "false", } class 
TransactionAccessModeEnum(enum.StrEnum): ReadOnly = "ReadOnly" ReadWrite = "ReadWrite" class TransactionAccessMode( EnumScalarType[TransactionAccessModeEnum], edgeql_type="sys::TransactionAccessMode", ): @classmethod def get_translation_map(cls) -> Mapping[TransactionAccessModeEnum, str]: return { TransactionAccessModeEnum.ReadOnly: "true", TransactionAccessModeEnum.ReadWrite: "false", } def to_qltypes(self) -> qltypes.TransactionAccessMode: from edb.edgeql import qltypes match self._val: case TransactionAccessModeEnum.ReadOnly: return qltypes.TransactionAccessMode.READ_ONLY case TransactionAccessModeEnum.ReadWrite: return qltypes.TransactionAccessMode.READ_WRITE case _: raise AssertionError(f"unexpected value: {self._val!r}") class TransactionDeferrabilityEnum(enum.StrEnum): Deferrable = "Deferrable" NotDeferrable = "NotDeferrable" class TransactionDeferrability( EnumScalarType[TransactionDeferrabilityEnum], edgeql_type="sys::TransactionDeferrability", ): @classmethod def get_translation_map(cls) -> Mapping[TransactionDeferrabilityEnum, str]: return { TransactionDeferrabilityEnum.Deferrable: "true", TransactionDeferrabilityEnum.NotDeferrable: "false", } class TransactionIsolationEnum(enum.StrEnum): Serializable = "Serializable" RepeatableRead = "RepeatableRead" class TransactionIsolation( EnumScalarType[TransactionIsolationEnum], edgeql_type="sys::TransactionIsolation", ): @classmethod def get_translation_map(cls) -> Mapping[TransactionIsolationEnum, str]: return { TransactionIsolationEnum.Serializable: "serializable", TransactionIsolationEnum.RepeatableRead: "repeatable read", } def to_qltypes(self) -> qltypes.TransactionIsolationLevel: from edb.edgeql import qltypes match self._val: case TransactionIsolationEnum.Serializable: return qltypes.TransactionIsolationLevel.SERIALIZABLE case TransactionIsolationEnum.RepeatableRead: return qltypes.TransactionIsolationLevel.REPEATABLE_READ case _: raise AssertionError(f"unexpected value: {self._val!r}") 
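# Illustrative sketch, not part of statypes.py: the ISO 8601 handling of
# ``Duration`` above (``_parse_iso8601`` and ``to_iso8601``) restated as two
# standalone functions so it can be run without the ``edb`` package. The names
# ``iso_to_us`` and ``us_to_iso`` are hypothetical, and the regex is a compact
# equivalent using the group names that the parsing code looks up (hours,
# minutes, secsign, seconds, microseconds).

```python
import re
from typing import Optional

# Assumed compact form of Duration._iso_parser (hypothetical name).
_ISO = re.compile(
    r'^PT'
    r'((?P<hours>[+-]?\d+)H)?'
    r'((?P<minutes>[+-]?\d+)M)?'
    r'(((?P<secsign>[+-])?(?P<seconds>\d+)(\.(?P<microseconds>\d+))?)S)?'
    r'$'
)


def iso_to_us(text: str) -> Optional[int]:
    """Parse an ISO 8601 'PT..H..M..S' string into microseconds."""
    m = _ISO.match(text)
    if m is None:
        return None
    value = 0
    if m['hours']:
        value += int(m['hours']) * 3600_000_000
    if m['minutes']:
        value += int(m['minutes']) * 60_000_000
    sign = -1 if m['secsign'] == '-' else 1
    if m['seconds']:
        value += int(m['seconds']) * 1_000_000 * sign
    if m['microseconds']:
        # Keep at most six fractional digits (truncating, not rounding),
        # then right-pad so '.5' means 500000 microseconds.
        value += int(m['microseconds'][:6].ljust(6, '0')) * sign
    return value


def us_to_iso(us: int) -> str:
    """Format microseconds as ISO 8601, repeating the sign per component."""
    neg = '-' if us < 0 else ''
    seconds, usecs = divmod(abs(us), 1_000_000)
    minutes, seconds = divmod(seconds, 60)
    hours, minutes = divmod(minutes, 60)
    parts = ['PT']
    if hours:
        parts.append(f'{neg}{hours}H')
    if minutes:
        parts.append(f'{neg}{minutes}M')
    if seconds or usecs:
        # Zero-pad to six digits, then drop trailing zeros of the fraction.
        frac = f'.{usecs:06d}'.rstrip('0') if usecs else ''
        parts.append(f'{neg}{seconds}{frac}S')
    if parts == ['PT']:
        parts.append('0S')
    return ''.join(parts)


print(us_to_iso(iso_to_us('PT1H30M')))
print(us_to_iso(iso_to_us('PT-1.5S')))
```

# In this sketch, as in the class above, the seconds sign is parsed separately
# from the hours and minutes signs, and a zero duration formats as 'PT0S'.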
================================================ FILE: edb/ir/typeutils.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2015-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Utilities for IR type descriptors.""" from __future__ import annotations from typing import ( Any, Callable, Iterable, Optional, TYPE_CHECKING, overload, ) import uuid from edb.edgeql import qltypes from edb.schema import casts as s_casts from edb.schema import links as s_links from edb.schema import name as s_name from edb.schema import properties as s_props from edb.schema import pointers as s_pointers from edb.schema import scalars as s_scalars from edb.schema import types as s_types from edb.schema import objtypes as s_objtypes from edb.schema import objects as s_obj from edb.schema import utils as s_utils from . import ast as irast if TYPE_CHECKING: from edb.schema import schema as s_schema TypeRefCacheKey = tuple[uuid.UUID, bool, bool] PtrRefCacheKey = s_pointers.PointerLike PtrRefCache = dict[PtrRefCacheKey, 'irast.BasePointerRef'] TypeRefCache = dict[TypeRefCacheKey, 'irast.TypeRef'] # Modules where all the "types" in them are really just custom views # provided by metaschema. 
VIEW_MODULES = ('sys', 'cfg') def is_cfg_view( obj: s_obj.Object, schema: s_schema.Schema, ) -> bool: return ( isinstance(obj, (s_objtypes.ObjectType, s_pointers.Pointer)) and ( obj.get_name(schema).module in VIEW_MODULES or bool( (cfg_object := schema.get( 'cfg::ConfigObject', type=s_objtypes.ObjectType, default=None )) and ( nobj := ( obj if isinstance(obj, s_objtypes.ObjectType) else obj.get_source(schema) ) ) and nobj.issubclass(schema, cfg_object) ) ) ) def is_excluded_cfg_view( child: s_obj.Object, *, ancestor: s_obj.Object, schema: s_schema.Schema, ) -> bool: """Used to exclude sys/cfg tables from non-sys/cfg views for performance. Also used by access policies to prevent including cfg views when expanding a non-cfg type's descendants (#8865). """ return is_cfg_view(child, schema) and not is_cfg_view(ancestor, schema) def is_scalar(typeref: irast.TypeRef) -> bool: """Return True if *typeref* describes a scalar type.""" return typeref.is_scalar def is_object(typeref: irast.TypeRef) -> bool: """Return True if *typeref* describes an object type.""" return ( not is_scalar(typeref) and not is_collection(typeref) and not is_generic(typeref) ) def is_view(typeref: irast.TypeRef) -> bool: """Return True if *typeref* describes a view.""" return typeref.is_view def is_collection(typeref: irast.TypeRef) -> bool: """Return True if *typeref* describes a collection type.""" return bool(typeref.collection) def is_array(typeref: irast.TypeRef) -> bool: """Return True if *typeref* describes an array type.""" return typeref.collection == s_types.Array.get_schema_name() def is_tuple(typeref: irast.TypeRef) -> bool: """Return True if *typeref* describes a tuple type.""" return typeref.collection == s_types.Tuple.get_schema_name() def is_range(typeref: irast.TypeRef) -> bool: """Return True if *typeref* describes a range type.""" return typeref.collection == s_types.Range.get_schema_name() def is_multirange(typeref: irast.TypeRef) -> bool: """Return True if *typeref* 
describes a multirange type.""" return typeref.collection == s_types.MultiRange.get_schema_name() def is_any(typeref: irast.TypeRef) -> bool: """Return True if *typeref* describes the ``anytype`` generic type.""" return isinstance(typeref, irast.AnyTypeRef) def is_anytuple(typeref: irast.TypeRef) -> bool: """Return True if *typeref* describes the ``anytuple`` generic type.""" return isinstance(typeref, irast.AnyTupleRef) def is_anyobject(typeref: irast.TypeRef) -> bool: """Return True if *typeref* describes the ``anyobject`` generic type.""" return isinstance(typeref, irast.AnyObjectRef) def is_generic(typeref: irast.TypeRef) -> bool: """Return True if *typeref* describes a generic type.""" if is_collection(typeref): return any(is_generic(st) for st in typeref.subtypes) else: return is_any(typeref) or is_anytuple(typeref) or is_anyobject(typeref) def is_abstract(typeref: irast.TypeRef) -> bool: """Return True if *typeref* describes an abstract type.""" return typeref.is_abstract def is_json(typeref: irast.TypeRef) -> bool: """Return True if *typeref* describes the json type.""" return typeref.real_base_type.id == s_obj.get_known_type_id('std::json') def is_bytes(typeref: irast.TypeRef) -> bool: """Return True if *typeref* describes the bytes type.""" return typeref.real_base_type.id == s_obj.get_known_type_id('std::bytes') def is_exactly_free_object(typeref: irast.TypeRef) -> bool: return typeref.name_hint == s_name.QualName('std', 'FreeObject') def is_free_object(typeref: irast.TypeRef) -> bool: if typeref.material_type: typeref = typeref.material_type return is_exactly_free_object(typeref) def is_persistent_tuple(typeref: irast.TypeRef) -> bool: if is_tuple(typeref): if typeref.material_type is not None: material = typeref.material_type else: material = typeref return material.in_schema else: return False def is_empty_typeref(typeref: irast.TypeRef) -> bool: return typeref.union is not None and len(typeref.union) == 0 def needs_custom_serialization(typeref: 
irast.TypeRef) -> bool: # True if any component needs custom serialization return contains_predicate( typeref, lambda typeref: typeref.real_base_type.custom_sql_serialization is not None ) def contains_predicate( typeref: irast.TypeRef, pred: Callable[[irast.TypeRef], bool], ) -> bool: if pred(typeref): return True elif typeref.union: return any( contains_predicate(sub, pred) for sub in typeref.union ) return any( contains_predicate(sub, pred) for sub in typeref.subtypes ) def contains_object(typeref: irast.TypeRef) -> bool: return contains_predicate(typeref, is_object) def type_to_typeref( schema: s_schema.Schema, t: s_types.Type, *, cache: Optional[dict[TypeRefCacheKey, irast.TypeRef]], typename: Optional[s_name.QualName] = None, include_children: bool = False, include_ancestors: bool = False, _name: Optional[str] = None, ) -> irast.TypeRef: """Return an instance of :class:`ir.ast.TypeRef` for a given type. An IR TypeRef is an object that fully describes a schema type for the purposes of query compilation. Args: schema: A schema instance, in which the type *t* is defined. t: A schema type instance. cache: Optional mapping from (type UUID, include_children, include_ancestors) to cached IR TypeRefs. The cache is only consulted when *typename* is ``None``. typename: Optional name hint to use for the type in the returned TypeRef. If ``None``, the type name is used. include_children: Whether to include the description of all material type children of *t*. include_ancestors: Whether to include the description of all material type ancestors of *t*. _name: Optional subtype element name if this type is a collection within a Tuple. Returns: A ``TypeRef`` instance corresponding to the given schema type. """ if cache is not None and typename is None: key = (t.id, include_children, include_ancestors) cached_result = cache.get(key) if cached_result is not None: # If the schema changed due to an ongoing compilation, the name # hint might be outdated. 
if cached_result.name_hint == t.get_name(schema): return cached_result # We separate the uncached version into another function because # it makes it easy to tell in a profiler when the cache isn't # operating, and because if the cache *is* operating it is no # great loss. return _type_to_typeref( schema, t, cache=cache, typename=typename, include_children=include_children, include_ancestors=include_ancestors, _name=_name, ) def _type_to_typeref( schema: s_schema.Schema, t: s_types.Type, *, cache: Optional[dict[TypeRefCacheKey, irast.TypeRef]] = None, typename: Optional[s_name.QualName] = None, include_children: bool = False, include_ancestors: bool = False, _name: Optional[str] = None, ) -> irast.TypeRef: def _typeref( t: s_types.Type, *, include_children: bool = include_children, include_ancestors: bool = include_ancestors, ) -> irast.TypeRef: return type_to_typeref( schema, t, include_children=include_children, include_ancestors=include_ancestors, cache=cache, ) result: irast.TypeRef material_type: s_types.Type name_hint = typename or t.get_name(schema) orig_name_hint = None if not typename else t.get_name(schema) if t.is_anytuple(schema): result = irast.AnyTupleRef( id=t.id, name_hint=name_hint, orig_name_hint=orig_name_hint, ) elif t.is_anyobject(schema): result = irast.AnyObjectRef( id=t.id, name_hint=name_hint, orig_name_hint=orig_name_hint, ) elif t.is_any(schema): result = irast.AnyTypeRef( id=t.id, name_hint=name_hint, orig_name_hint=orig_name_hint, ) elif not isinstance(t, s_types.Collection): assert isinstance(t, s_types.InheritingType) union: Optional[frozenset[irast.TypeRef]] = None union_is_exhaustive: bool = False expr_intersection: Optional[frozenset[irast.TypeRef]] = None expr_union: Optional[frozenset[irast.TypeRef]] = None if t.is_union_type(schema) or t.is_intersection_type(schema): union_types, union_is_exhaustive = ( s_utils.get_type_expr_non_overlapping_union(t, schema) ) union = frozenset( _typeref(c) for c in union_types ) # Keep track of 
type expression structure. # This is necessary to determine the correct rvar when doing # type intersections or polymorphic queries. if expr_intersection_types := t.get_intersection_of(schema): expr_intersection = frozenset( _typeref(c) for c in expr_intersection_types.objects(schema) ) if expr_union_types := t.get_union_of(schema): expr_union = frozenset( _typeref(c) for c in expr_union_types.objects(schema) ) schema, material_type = t.material_type(schema) material_typeref: Optional[irast.TypeRef] if material_type != t: material_typeref = _typeref(material_type) else: material_typeref = None if (isinstance(material_type, s_scalars.ScalarType) and not material_type.get_abstract(schema)): base_type = material_type.get_topmost_concrete_base(schema) if base_type == material_type: base_typeref = None else: assert isinstance(base_type, s_types.Type) base_typeref = _typeref(base_type, include_children=False) else: base_typeref = None children: Optional[frozenset[irast.TypeRef]] = None if ( material_typeref is None and include_children and children is None ): children = frozenset( _typeref(child, include_children=True) for child in t.children(schema) if not child.get_is_derived(schema) and not child.is_compound_type(schema) ) ancestors: Optional[frozenset[irast.TypeRef]] = None if ( material_typeref is None and include_ancestors and ancestors is None ): ancestors = frozenset( _typeref(ancestor, include_ancestors=False) for ancestor in t.get_ancestors(schema).objects(schema) ) sql_type = None needs_custom_json_cast = False custom_sql_serialization = None if isinstance(t, s_scalars.ScalarType): sql_type = t.resolve_sql_type(schema) if material_typeref is None: cast_name = s_casts.get_cast_fullname_from_names( orig_name_hint or name_hint, s_name.QualName('std', 'json')) jcast = schema.get(cast_name, type=s_casts.Cast, default=None) if jcast: needs_custom_json_cast = bool(jcast.get_code(schema)) custom_sql_serialization = t.get_custom_sql_serialization(schema) result = 
irast.TypeRef( id=t.id, name_hint=name_hint, orig_name_hint=orig_name_hint, material_type=material_typeref, base_type=base_typeref, children=children, ancestors=ancestors, union=union, union_is_exhaustive=union_is_exhaustive, expr_intersection=expr_intersection, expr_union=expr_union, element_name=_name, is_scalar=t.is_scalar(), is_abstract=t.get_abstract(schema), is_view=t.is_view(schema), is_cfg_view=is_cfg_view(t, schema), is_opaque_union=t.get_is_opaque_union(schema), needs_custom_json_cast=needs_custom_json_cast, sql_type=sql_type, custom_sql_serialization=custom_sql_serialization, ) elif isinstance(t, s_types.Tuple) and t.is_named(schema): schema, material_type = t.material_type(schema) if material_type != t: material_typeref = _typeref(material_type) else: material_typeref = None result = irast.TypeRef( id=t.id, name_hint=name_hint, orig_name_hint=orig_name_hint, material_type=material_typeref, element_name=_name, collection=t.get_schema_name(), in_schema=t.get_is_persistent(schema), subtypes=tuple( # ??? no cache type_to_typeref(schema, st, _name=sn, cache=None) for sn, st in t.iter_subtypes(schema) ) ) else: schema, material_type = t.material_type(schema) if material_type != t: material_typeref = type_to_typeref( schema, material_type, cache=cache ) else: material_typeref = None result = irast.TypeRef( id=t.id, name_hint=name_hint, orig_name_hint=orig_name_hint, material_type=material_typeref, element_name=_name, collection=t.get_schema_name(), in_schema=t.get_is_persistent(schema), subtypes=tuple( _typeref(st) for st in t.get_subtypes(schema) ) ) if cache is not None and typename is None and _name is None: key = (t.id, include_children, include_ancestors) # Note: there is no cache for `_name` variants since they are only used # for Tuple subtypes and thus they will be cached on the outer level # anyway. # There's also no variant for types with custom typenames since they # proved to have a very low hit rate. 
# This way we save on the size of the key tuple. cache[key] = result return result def ir_typeref_to_type( schema: s_schema.Schema, typeref: irast.TypeRef, ) -> tuple[s_schema.Schema, s_types.Type]: """Return a schema type for a given IR TypeRef. This is the reverse of :func:`~type_to_typeref`. Args: schema: A schema instance. The result type must exist in it. typeref: A :class:`ir.ast.TypeRef` instance for which to return the corresponding schema type. Returns: A tuple containing the possibly modified schema and a :class:`schema.types.Type` instance corresponding to the given *typeref*. """ # Optimistically try to lookup the type by id. Sometimes for # arrays and tuples this will fail, and we'll need to create it. t = schema.get_by_id(typeref.id, default=None, type=s_types.Type) if t: return schema, t elif is_tuple(typeref): named = False tuple_subtypes = {} for si, st in enumerate(typeref.subtypes): if st.element_name: named = True type_name = st.element_name else: type_name = str(si) schema, st_t = ir_typeref_to_type(schema, st) tuple_subtypes[type_name] = st_t return s_types.Tuple.from_subtypes( schema, tuple_subtypes, {'named': named}) elif is_array(typeref): array_subtypes = [] for st in typeref.subtypes: schema, st_t = ir_typeref_to_type(schema, st) array_subtypes.append(st_t) return s_types.Array.from_subtypes(schema, array_subtypes) else: raise AssertionError("couldn't find type from typeref") @overload def ptrref_from_ptrcls( *, schema: s_schema.Schema, ptrcls: s_pointers.Pointer, cache: Optional[PtrRefCache], typeref_cache: Optional[TypeRefCache], ) -> irast.PointerRef: ... @overload def ptrref_from_ptrcls( *, schema: s_schema.Schema, ptrcls: s_pointers.PointerLike, cache: Optional[PtrRefCache], typeref_cache: Optional[TypeRefCache], ) -> irast.BasePointerRef: ... 
def ptrref_from_ptrcls( *, schema: s_schema.Schema, ptrcls: s_pointers.PointerLike, cache: Optional[PtrRefCache], typeref_cache: Optional[TypeRefCache], ) -> irast.BasePointerRef: """Return an IR pointer descriptor for a given schema pointer. An IR PointerRef is an object that fully describes a schema pointer for the purposes of query compilation. Args: schema: A schema instance, in which the pointer *ptrcls* is defined. ptrcls: A :class:`schema.pointers.Pointer` instance for which to return the PointerRef. cache: Optional mapping from schema pointers to cached IR PointerRefs. typeref_cache: Optional TypeRef cache passed through to :func:`~type_to_typeref` for the source and target types. Returns: An instance of a subclass of :class:`ir.ast.BasePointerRef` corresponding to the given schema pointer. """ if cache is not None: cached = cache.get(ptrcls) if cached is not None: return cached kwargs: dict[str, Any] = {} ircls: type[irast.BasePointerRef] source_ref: Optional[irast.TypeRef] target_ref: Optional[irast.TypeRef] out_source: Optional[irast.TypeRef] if isinstance(ptrcls, irast.TupleIndirectionLink): ircls = irast.TupleIndirectionPointerRef elif isinstance(ptrcls, irast.TypeIntersectionLink): ircls = irast.TypeIntersectionPointerRef kwargs['optional'] = ptrcls.is_optional() kwargs['is_empty'] = ptrcls.is_empty() kwargs['is_subtype'] = ptrcls.is_subtype() kwargs['rptr_specialization'] = ptrcls.get_rptr_specialization() elif isinstance(ptrcls, s_pointers.Pointer): ircls = irast.PointerRef kwargs['id'] = ptrcls.id kwargs['defined_here'] = ptrcls.get_defined_here(schema) if backlink := ptrcls.get_computed_link_alias(schema): assert isinstance(backlink, s_pointers.Pointer) kwargs['computed_link_alias'] = ptrref_from_ptrcls( ptrcls=backlink, schema=schema, cache=cache, typeref_cache=typeref_cache, ) kwargs['computed_link_alias_is_backward'] = ( ptrcls.get_computed_link_alias_is_backward(schema)) else: raise AssertionError(f'unexpected pointer class: {ptrcls}') target = ptrcls.get_target(schema) if target is not None and not isinstance(target, irast.TypeRef): assert isinstance(target, 
s_types.Type) target_ref = type_to_typeref( schema, target, include_children=True, cache=typeref_cache) else: target_ref = target source = ptrcls.get_source(schema) source_ptr: Optional[irast.BasePointerRef] if (isinstance(ptrcls, s_props.Property) and isinstance(source, s_links.Link)): source_ptr = ptrref_from_ptrcls( ptrcls=source, schema=schema, cache=cache, typeref_cache=typeref_cache, ) source_ref = None else: if source is not None and not isinstance(source, irast.TypeRef): assert isinstance(source, s_types.Type) source_ref = type_to_typeref(schema, source, include_ancestors=True, cache=typeref_cache) else: source_ref = source source_ptr = None out_source = source_ref out_target = target_ref out_cardinality, in_cardinality = cardinality_from_ptrcls( schema, ptrcls) schema, material_ptrcls = ptrcls.material_type(schema) material_ptr: Optional[irast.BasePointerRef] if material_ptrcls is not None and material_ptrcls != ptrcls: material_ptr = ptrref_from_ptrcls( ptrcls=material_ptrcls, schema=schema, cache=cache, typeref_cache=typeref_cache, ) else: material_ptr = None union_components: Optional[set[irast.BasePointerRef]] = None union_of = ptrcls.get_union_of(schema) union_is_exhaustive = False if union_of: union_ptrs = set() for component in union_of.objects(schema): assert isinstance(component, s_pointers.Pointer) schema, material_comp = component.material_type(schema) union_ptrs.add(material_comp) non_overlapping, union_is_exhaustive = ( s_utils.get_non_overlapping_union( schema, union_ptrs, ) ) union_components = { ptrref_from_ptrcls( ptrcls=p, schema=schema, cache=cache, typeref_cache=typeref_cache, ) for p in non_overlapping } intersection_components: Optional[set[irast.BasePointerRef]] = None intersection_of = ptrcls.get_intersection_of(schema) if intersection_of: intersection_ptrs = set() for component in intersection_of.objects(schema): assert isinstance(component, s_pointers.Pointer) schema, material_comp = component.material_type(schema) 
intersection_ptrs.add(material_comp) intersection_components = { ptrref_from_ptrcls( ptrcls=p, schema=schema, cache=cache, typeref_cache=typeref_cache, ) for p in intersection_ptrs } std_parent_name = None for ancestor in ptrcls.get_ancestors(schema).objects(schema): ancestor_name = ancestor.get_name(schema) if ancestor_name.module == 'std' and ancestor.is_non_concrete(schema): std_parent_name = ancestor_name break is_derived = ptrcls.get_is_derived(schema) base_ptr: Optional[irast.BasePointerRef] if is_derived: base_ptrcls = ptrcls.get_bases(schema).first(schema) top_ptr_name = type(base_ptrcls).get_default_base_name() if base_ptrcls.get_name(schema) != top_ptr_name: base_ptr = ptrref_from_ptrcls( ptrcls=base_ptrcls, schema=schema, cache=cache, typeref_cache=typeref_cache, ) else: base_ptr = None else: base_ptr = None if ( material_ptr is None and isinstance(ptrcls, s_pointers.Pointer) ): children = frozenset( ptrref_from_ptrcls( ptrcls=child, schema=schema, cache=cache, typeref_cache=typeref_cache, ) for child in ptrcls.children(schema) if not child.get_is_derived(schema) ) else: children = frozenset() kwargs.update( dict( out_source=out_source, out_target=out_target, name=ptrcls.get_name(schema), shortname=ptrcls.get_shortname(schema), std_parent_name=std_parent_name, source_ptr=source_ptr, base_ptr=base_ptr, material_ptr=material_ptr, children=children, is_derived=ptrcls.get_is_derived(schema), is_computable=ptrcls.get_computable(schema), union_components=union_components, intersection_components=intersection_components, union_is_exhaustive=union_is_exhaustive, has_properties=ptrcls.has_user_defined_properties(schema), in_cardinality=in_cardinality, out_cardinality=out_cardinality, ) ) ptrref = ircls(**kwargs) if cache is not None: cache[ptrcls] = ptrref # This is kind of unfortunate, but if we are caching, update the # base_ptr with this child if base_ptr and not material_ptr and ptrref not in base_ptr.children: base_ptr.children = base_ptr.children | 
frozenset([ptrref]) return ptrref @overload def ptrcls_from_ptrref( ptrref: irast.PointerRef, *, schema: s_schema.Schema, ) -> tuple[s_schema.Schema, s_pointers.Pointer]: ... @overload def ptrcls_from_ptrref( ptrref: irast.BasePointerRef, *, schema: s_schema.Schema, ) -> tuple[s_schema.Schema, s_pointers.PointerLike]: ... def ptrcls_from_ptrref( ptrref: irast.BasePointerRef, *, schema: s_schema.Schema, ) -> tuple[s_schema.Schema, s_pointers.PointerLike]: """Return a schema pointer for a given IR PointerRef. This is the reverse of :func:`~ptrref_from_ptrcls`. Args: schema: A schema instance. The result type must exist in it. ptrref: A :class:`ir.ast.BasePointerRef` instance for which to return the corresponding schema pointer. Returns: A tuple containing the possibly modified schema and a :class:`schema.pointers.PointerLike` instance corresponding to the given *ptrref*. """ ptrcls: s_pointers.PointerLike if isinstance(ptrref, irast.TupleIndirectionPointerRef): schema, src_t = ir_typeref_to_type(schema, ptrref.out_source) schema, tgt_t = ir_typeref_to_type(schema, ptrref.out_target) ptrcls = irast.TupleIndirectionLink( source=src_t, target=tgt_t, element_name=ptrref.name.name, ) elif isinstance(ptrref, irast.TypeIntersectionPointerRef): target = schema.get_by_id(ptrref.out_target.id) assert isinstance(target, s_types.Type) ptrcls = irast.TypeIntersectionLink( source=schema.get_by_id(ptrref.out_source.id), target=target, optional=ptrref.optional, is_empty=ptrref.is_empty, is_subtype=ptrref.is_subtype, cardinality=ptrref.out_cardinality.to_schema_value()[1], ) elif isinstance(ptrref, irast.PointerRef): ptr = schema.get_by_id(ptrref.id) assert isinstance(ptr, s_pointers.Pointer) ptrcls = ptr else: raise TypeError(f'unexpected pointer ref type: {ptrref!r}') return schema, ptrcls def cardinality_from_ptrcls( schema: s_schema.Schema, ptrcls: s_pointers.PointerLike, ) -> tuple[Optional[qltypes.Cardinality], Optional[qltypes.Cardinality]]: out_card = 
ptrcls.get_cardinality(schema) required = ptrcls.get_required(schema) if out_card is None or not out_card.is_known(): # The cardinality is not yet known. out_cardinality = None in_cardinality = None else: assert isinstance(out_card, qltypes.SchemaCardinality) out_cardinality = qltypes.Cardinality.from_schema_value( required, out_card) # Backward link cannot be required, but exclusivity # controls upper bound on cardinality. if not ptrcls.is_non_concrete(schema) and ptrcls.is_exclusive(schema): in_cardinality = qltypes.Cardinality.AT_MOST_ONE else: in_cardinality = qltypes.Cardinality.MANY return out_cardinality, in_cardinality def is_id_ptrref(ptrref: irast.BasePointerRef) -> bool: """Return True if *ptrref* describes the id property.""" return ( str(ptrref.std_parent_name) == 'std::id' ) and not ptrref.source_ptr def is_computable_ptrref(ptrref: irast.BasePointerRef) -> bool: """Return True if pointer described by *ptrref* is computed.""" return ptrref.is_computable def get_tuple_element_index(ptrref: irast.TupleIndirectionPointerRef) -> int: name = ptrref.name.name if name.isdecimal() and name.isascii(): return int(name) else: for i, st in enumerate(ptrref.out_source.subtypes): if st.element_name == name: return i raise AssertionError(f"element {name} is not found in tuple type") def type_contains( parent: irast.TypeRef, child: irast.TypeRef, ) -> bool: """Check if *parent* typeref contains the given *child* typeref. Both *parent* and *child* can be type expressions. """ if parent == child: return True # Calculate the minterms of both *parent* and *child*. parent_minterms = _disjunctive_normal_form(parent) child_minterms = _disjunctive_normal_form(child) # The *parent* contains *child* if each child minterm is contained # by a parent minterm. 
# # Examples # - [A] contains [AB] # - [A,B] contains [A] # - [AB] does not contain [A] # - [A] does not contain [A,B] # - [AB,CD] does not contain [BD] return all( any( c.issuperset(p) for p in parent_minterms ) for c in child_minterms ) def _disjunctive_normal_form( typeref: irast.TypeRef ) -> list[set[uuid.UUID]]: """Convert any typeref into a minimal disjunctive normal form. In the result: - The outer list represents unions. - The inner sets represent intersections of simple types (ie. minterms). Duplicate and superset minterms are removed as redundant. """ def simplify( expr: Iterable[set[uuid.UUID]] ) -> list[set[uuid.UUID]]: # Remove any minterms which imply others # eg. [A, AB, BC] -> [A, BC] minterms_by_length = sorted( expr, key=lambda i: len(i) ) result: list[set[uuid.UUID]] = [] for minterm in minterms_by_length: if not any( minterm.issuperset(r) for r in result ): result.append(minterm) return result if typeref.expr_union: return simplify( minterm for t in typeref.expr_union for minterm in _disjunctive_normal_form(t) ) elif typeref.expr_intersection: components = [ _disjunctive_normal_form(t) for t in typeref.expr_intersection ] result = components[0] for other in components[1:]: result = [ set.union(r, o) for r in result for o in other ] return simplify(result) else: return [{typeref.id}] def find_actual_ptrref( source_typeref: irast.TypeRef, parent_ptrref: irast.BasePointerRef, *, dir: s_pointers.PointerDirection = s_pointers.PointerDirection.Outbound, material: bool=True, ) -> irast.BasePointerRef: if material and source_typeref.material_type: source_typeref = source_typeref.material_type if material and parent_ptrref.material_ptr: parent_ptrref = parent_ptrref.material_ptr ptrref = parent_ptrref if ptrref.source_ptr is not None: # Link property ref link_ptr: irast.BasePointerRef = ptrref.source_ptr if link_ptr.material_ptr: link_ptr = link_ptr.material_ptr if link_ptr.dir_source(dir).id == source_typeref.id: return ptrref elif 
ptrref.dir_source(dir).id == source_typeref.id: return ptrref # We are updating a subtype, find the # correct descendant ptrref. for dp in ( (ptrref.union_components or set()) | (ptrref.intersection_components or set()) ): candidate = maybe_find_actual_ptrref( source_typeref, dp, material=material, dir=dir) if candidate is not None: return candidate for dp in ptrref.children: if dp.dir_source(dir) and dp.dir_source(dir).id == source_typeref.id: return dp else: candidate = maybe_find_actual_ptrref( source_typeref, dp, material=material, dir=dir) if candidate is not None: return candidate raise LookupError( f'cannot find ptrref matching typeref {source_typeref.id}') def maybe_find_actual_ptrref( source_typeref: irast.TypeRef, parent_ptrref: irast.BasePointerRef, *, material: bool=True, dir: s_pointers.PointerDirection = s_pointers.PointerDirection.Outbound, ) -> Optional[irast.BasePointerRef]: try: return find_actual_ptrref( source_typeref, parent_ptrref, material=material, dir=dir) except LookupError: return None def get_typeref_descendants(typeref: irast.TypeRef) -> set[irast.TypeRef]: result = set() if typeref.children: for child in typeref.children: result.add(child) result.update(get_typeref_descendants(child)) return result def maybe_lookup_obj_pointer( schema: s_schema.Schema, name: s_name.QualName, ptr_name: s_name.UnqualName, ) -> Optional[s_pointers.Pointer]: base_object = schema.get(name, type=s_objtypes.ObjectType, default=None) if not base_object: return None ptr = base_object.maybe_get_ptr(schema, ptr_name) return ptr def lookup_obj_ptrref( schema: s_schema.Schema, name: s_name.QualName, ptr_name: s_name.UnqualName, cache: Optional[dict[PtrRefCacheKey, irast.BasePointerRef]] = None, typeref_cache: Optional[dict[TypeRefCacheKey, irast.TypeRef]] = None, ) -> irast.PointerRef: ptr = maybe_lookup_obj_pointer(schema, name, ptr_name) assert ptr return ptrref_from_ptrcls( ptrcls=ptr, schema=schema, cache=cache, typeref_cache=typeref_cache, ) def 
replace_pathid_prefix( path_id: irast.PathId, prefix: irast.PathId, replacement: irast.PathId, permissive_ptr_path: bool=False, ) -> irast.PathId: """Return a copy of *path_id* with *prefix* replaced by *replacement*. Example: replace_pathid_prefix(A.b.c, A.b, X.y) == PathId(X.y.c) """ if not path_id.startswith(prefix, permissive_ptr_path=permissive_ptr_path): return path_id # TODO: iter_prefixes is kind of expensive; can we do this in a # way that peeks into the internals more? result = replacement prefixes = list(path_id.iter_prefixes(include_ptr=prefix.is_ptr_path())) lastns = prefix.namespace try: start = prefixes.index(prefix) except ValueError: if permissive_ptr_path: start = prefixes.index(prefix.ptr_path()) else: raise for part in prefixes[start + 1:]: if part.is_ptr_path(): continue ptrref = part.rptr() if not ptrref: continue dir = part.rptr_dir() assert dir if ( isinstance(ptrref, irast.TupleIndirectionPointerRef) and result.target.collection == 'tuple' ): # For tuple indirections, we want to update the target # type when we get mapped to a subtype. idx = get_tuple_element_index(ptrref) target = result.target if target.id != target.subtypes[idx].id: ptrref = ptrref.replace( out_source=target, out_target=target.subtypes[idx], ) if ptrref.source_ptr: result = result.ptr_path() result = result.extend( ptrref=ptrref, direction=dir, ns=part.namespace - lastns) lastns = part.namespace if path_id.is_ptr_path(): result = result.ptr_path() return result ================================================ FILE: edb/ir/utils.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2015-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Miscellaneous utilities for the IR.""" from __future__ import annotations from typing import ( Any, Optional, AbstractSet, Mapping, Sequence, Iterable, cast, TYPE_CHECKING, ) if TYPE_CHECKING: from typing_extensions import TypeGuard import json import uuid from edb import errors from edb.common import ast from edb.common import ordered from edb.edgeql import qltypes as ft from . import ast as irast from . import typeutils def get_longest_paths(ir: irast.Base) -> set[irast.Set]: """Return a distinct set of longest paths found in an expression. For example in SELECT (A.B.C, D.E.F, A.B, D.E) the result would be {A.B.C, D.E.F}. 
""" result = set() parents = set() ir_sets = ast.find_children( ir, irast.Set, lambda n: sub_expr(n) is None or isinstance(n.expr, irast.TypeRoot), ) for ir_set in ir_sets: result.add(ir_set) if isinstance(ir_set.expr, irast.Pointer): parents.add(ir_set.expr.source) return result - parents def get_parameters(ir: irast.Base) -> set[irast.QueryParameter]: """Return all parameters found in *ir*.""" return set(ast.find_children(ir, irast.QueryParameter)) def is_const(ir: irast.Base) -> bool: """Return True if the given *ir* expression is constant.""" roots = ast.find_children(ir, irast.TypeRoot) variables = get_parameters(ir) return not roots and not variables def is_union_expr(ir: irast.Base) -> bool: """Return True if the given *ir* expression is a UNION expression.""" return ( isinstance(ir, irast.OperatorCall) and ir.operator_kind is ft.OperatorKind.Infix and str(ir.func_shortname) == 'std::UNION' ) def is_empty_array_expr(ir: Optional[irast.Base]) -> TypeGuard[irast.Array]: """Return True if the given *ir* expression is an empty array expression. """ return ( isinstance(ir, irast.Array) and not ir.elements ) def is_untyped_empty_array_expr( ir: Optional[irast.Base], ) -> TypeGuard[irast.Array]: """Return True if the given *ir* expression is an empty array expression of an uknown type. """ return ( is_empty_array_expr(ir) and (ir.typeref is None or typeutils.is_generic(ir.typeref)) ) def is_empty(ir: irast.Base) -> bool: """Return True if the given *ir* expression is an empty set or an empty array. 
""" return ( isinstance(ir, irast.EmptySet) or (isinstance(ir, irast.Array) and not ir.elements) or ( isinstance(ir, irast.Set) and is_empty(ir.expr) ) ) def is_subquery_set(ir_expr: irast.Base) -> bool: """Return True if the given *ir_expr* expression is a subquery.""" return ( isinstance(ir_expr, irast.Set) and ( isinstance(ir_expr.expr, irast.Stmt) or ( isinstance(ir_expr.expr, irast.Pointer) and ir_expr.expr.expr is not None ) ) ) def is_implicit_wrapper( ir_expr: Optional[irast.Base], ) -> TypeGuard[irast.SelectStmt]: """Return True if the given *ir_expr* expression is an implicit SELECT wrapper. """ return ( isinstance(ir_expr, irast.SelectStmt) and ir_expr.implicit_wrapper ) def is_trivial_select(ir_expr: irast.Base) -> TypeGuard[irast.SelectStmt]: """Return True if the given *ir_expr* expression is a trivial SELECT expression, i.e `SELECT `. """ if not isinstance(ir_expr, irast.SelectStmt): return False return ( not ir_expr.orderby and ir_expr.iterator_stmt is None and ir_expr.where is None and ir_expr.limit is None and ir_expr.offset is None and ir_expr.card_inference_override is None ) def unwrap_set(ir_set: irast.Set) -> irast.Set: """If the given *ir_set* is an implicit SELECT wrapper, return the wrapped set. """ if is_implicit_wrapper(ir_set.expr): return ir_set.expr.result else: return ir_set def get_path_root(ir_set: irast.Set) -> irast.Set: result = ir_set while isinstance(result.expr, irast.Pointer): result = result.expr.source return result def get_span_as_json( expr: irast.Base, exctype: type[errors.EdgeDBError] = errors.InternalServerError, ) -> str: if expr.span: details = json.dumps({ # TODO(tailhook) should we add offset, utf16column here? 
'line': expr.span.start_point.line, 'column': expr.span.start_point.column, 'name': expr.span.filename, 'code': exctype.get_code(), }) else: details = json.dumps({ 'code': exctype.get_code(), }) return details def is_type_intersection_reference(ir_expr: irast.Base) -> bool: """Return True if the given *ir_expr* is a type intersection, i.e ``Foo[IS Type]``. """ if not isinstance(ir_expr, irast.Set): return False if not isinstance(ir_expr.expr, irast.Pointer): return False rptr = ir_expr.expr ir_source = rptr.source if ir_source.path_id.is_type_intersection_path(): source_is_type_intersection = True else: source_is_type_intersection = False return source_is_type_intersection def is_trivial_free_object(ir: irast.Set) -> bool: ir = unwrap_set(ir) return ( isinstance(ir.expr, irast.TypeRoot) and typeutils.is_exactly_free_object(ir.typeref) ) def collapse_type_intersection( ir_set: irast.Set, ) -> tuple[irast.Set, list[irast.TypeIntersectionPointer]]: result: list[irast.TypeIntersectionPointer] = [] source = ir_set while True: rptr = source.expr if not isinstance(rptr, irast.TypeIntersectionPointer): break result.append(rptr) source = rptr.source return source, result class CollectDMLSourceVisitor(ast.NodeVisitor): skip_hidden = True def __init__( self, binding_dml: Mapping[irast.PathId, Sequence[irast.MutatingLikeStmt]], ) -> None: super().__init__() self.binding_dml = binding_dml self.dml: list[irast.MutatingLikeStmt] = [] def visit_MutatingLikeStmt(self, stmt: irast.MutatingLikeStmt) -> None: # Only INSERTs and UPDATEs produce meaningful overlays. 
if not isinstance(stmt, irast.DeleteStmt): self.dml.append(stmt) def visit_Set(self, node: irast.Set) -> None: # Visit sub-trees if node.expr: self.visit(node.expr) if node.is_binding: self.dml.extend(self.binding_dml.get(node.path_id, ())) def visit_Pointer(self, node: irast.Pointer) -> None: if node.expr: self.visit(node.expr) else: self.visit(node.source) def get_dml_sources( ir_set: irast.Set, binding_dml: Mapping[irast.PathId, Sequence[irast.MutatingLikeStmt]], ) -> Sequence[irast.MutatingLikeStmt]: """Find the DML expressions that can contribute to the value of a set This is used to compute which overlays to use during SQL compilation. """ # TODO: Make this caching. visitor = CollectDMLSourceVisitor(binding_dml) visitor.visit(ir_set) # Deduplicate, but preserve order. It shouldn't matter for # *correctness* but it helps keep the nondeterminism in the output # SQL down. return tuple(ordered.OrderedSet(visitor.dml)) class ContainsDMLVisitor(ast.NodeVisitor): skip_hidden = True def __init__(self, *, skip_bindings: bool) -> None: super().__init__() self.skip_bindings = skip_bindings def combine_field_results(self, xs: Iterable[Optional[bool]]) -> bool: return any( x is True or (isinstance(x, (list, tuple)) and self.combine_field_results(x)) or (isinstance(x, dict) and self.combine_field_results(x.values())) for x in xs ) def visit_MutatingStmt(self, stmt: irast.MutatingStmt) -> bool: return True def visit_Set(self, node: irast.Set) -> bool: if self.skip_bindings and node.is_binding: return False # Visit sub-trees return bool(self.generic_visit(node)) def contains_dml( stmt: irast.Base, *, skip_bindings: bool = False, skip_nodes: Iterable[irast.Base] = (), ) -> bool: """Check whether a statement contains any DML in a subtree.""" # TODO: Make this caching. 
visitor = ContainsDMLVisitor(skip_bindings=skip_bindings) for node in skip_nodes: visitor._memo[node] = False res = visitor.visit(stmt) is True return res class FindPathScopes(ast.NodeVisitor): """Visitor to find the enclosing path scope id of sub expressions. Sets inherit an effective scope id from enclosing expressions, and this visitor computes those. This is set up so that another visitor could inherit from it, override process_set, and also collect the scope tree info. """ def __init__(self, init_scope: Optional[int] = None) -> None: super().__init__() self.path_scope_ids: list[Optional[int]] = [init_scope] self.use_scopes: dict[irast.Set, Optional[int]] = {} self.scopes: dict[irast.Set, Optional[int]] = {} def visit_Stmt(self, stmt: irast.Stmt) -> Any: # Sometimes there is sharing, so we want the official scope # for a node to be based on its appearance in the result, # not in a subquery. # I think it might not actually matter, though. self.visit(stmt.bindings) if stmt.iterator_stmt: self.visit(stmt.iterator_stmt) if isinstance(stmt, (irast.MutatingStmt, irast.GroupStmt)): self.visit(stmt.subject) if isinstance(stmt, irast.GroupStmt): for v in stmt.using.values(): self.visit(v) self.visit(stmt.result) return self.generic_visit(stmt) def visit_Set(self, node: irast.Set) -> Any: val = self.path_scope_ids[-1] self.use_scopes[node] = val if node.path_scope_id: self.path_scope_ids.append(node.path_scope_id) if not node.is_binding: val = self.path_scope_ids[-1] # Visit sub-trees self.scopes[node] = val res = self.process_set(node) if node.path_scope_id: self.path_scope_ids.pop() return res def process_set(self, node: irast.Set) -> Any: self.generic_visit(node) return None def find_path_scopes( stmt: irast.Base | Sequence[irast.Base], ) -> dict[irast.Set, Optional[int]]: visitor = FindPathScopes() visitor.visit(stmt) return visitor.scopes class FindPotentiallyVisibleVisitor(FindPathScopes): skip_hidden = True extra_skips = frozenset(['materialized_sets']) def 
__init__( self, to_skip: AbstractSet[irast.PathId], scope: irast.ScopeTreeNode, scope_tree_nodes: Mapping[int, irast.ScopeTreeNode], ) -> None: super().__init__(init_scope=scope.unique_id) self.to_skip = to_skip self.orig_scope = scope self.scope_tree_nodes = scope_tree_nodes def combine_field_results(self, xs: Any) -> set[irast.Set]: out = set() for x in xs: if isinstance(x, (list, tuple)): x = self.combine_field_results(x) if isinstance(x, dict): x = self.combine_field_results(x.values()) if x: if isinstance(x, set): out.update(x) return out def visit_Pointer(self, node: irast.Pointer) -> set[irast.Set]: res: set[irast.Set] = self.visit(node.source) return res def process_set(self, node: irast.Set) -> set[irast.Set]: if node.path_id in self.to_skip: # We only skip nodes in to_skip if their use site is # underneath our original binding site. This prevents us # from skipping references to them embedded in outside # WITH bindings. if ( (psid := self.use_scopes[node]) is not None and ( self.orig_scope in self.scope_tree_nodes[psid].ancestors ) ): return set() results = [{node}] if isinstance(node.expr, irast.Pointer): results.append(self.visit(node.expr)) results.append(self.visit(node.shape)) else: results.append(self.visit(node.shape)) results.append(self.visit(node.expr)) # Bound variables are always potentially visible as are object # references. if ( node.is_binding or isinstance(node.expr, irast.TypeRoot) ): results.append({node}) # Visit sub-trees return self.combine_field_results(results) def find_potentially_visible( stmt: irast.Base, scope: irast.ScopeTreeNode, scope_tree_nodes: Mapping[int, irast.ScopeTreeNode], to_skip: AbstractSet[irast.PathId]=frozenset() ) -> set[tuple[irast.PathId, irast.Set]]: """Find all "potentially visible" sets referenced.""" # TODO: Make this caching. 
visitor = FindPotentiallyVisibleVisitor( to_skip=to_skip, scope=scope, scope_tree_nodes=scope_tree_nodes) visible_sets = cast(set[irast.Set], visitor.visit(stmt)) visible_paths = set() for ir in visible_sets: path_id = ir.path_id # Collect any namespaces between where the set is referred to # and the binding point we are looking from, and strip those off. # We need to do this because visibility *from the binding point* # needs to not include namespaces defined below it. # (See test_edgeql_scope_ref_side_02 for an example where this # matters.) if (set_scope_id := visitor.scopes.get(ir)) is not None: set_scope = scope_tree_nodes[set_scope_id] for anc, ns in set_scope.ancestors_and_namespaces: if anc is scope: path_id = path_id.strip_namespace(ns) break visible_paths.add((path_id, ir)) return visible_paths def is_singleton_set_of_call( call: irast.Call ) -> bool: # Some set functions and operators are allowed in singleton mode # as long as their inputs are singletons return bool(call.is_singleton_set_of) def has_set_of_param( call: irast.Call, ) -> bool: return any( arg.param_typemod == ft.TypeModifier.SetOfType for arg in call.args.values() ) def returns_set_of( call: irast.Call, ) -> bool: return call.typemod == ft.TypeModifier.SetOfType def find_set_of_op( ir: irast.Base, has_multi_param: bool, ) -> Optional[irast.Call]: def flt(n: irast.Call) -> bool: return ( (has_multi_param or not is_singleton_set_of_call(n)) and (has_set_of_param(n) or returns_set_of(n)) ) calls = ast.find_children(ir, irast.Call, flt, terminate_early=True) return next(iter(calls or []), None) def is_set_instance[ExprT: irast.Expr]( ir: irast.Set, typ: type[ExprT], ) -> TypeGuard[irast.SetE[ExprT]]: return isinstance(ir.expr, typ) def ref_contains_multi(ref: irast.Set, singleton_id: uuid.UUID) -> bool: while isinstance(ref.expr, irast.Pointer): pointer: irast.Pointer = ref.expr if pointer.dir_cardinality.is_multi(): return True # We don't need to look further than the object that we know is a 
# singleton. if ( singleton_id and isinstance(pointer.ptrref, irast.PointerRef) and pointer.ptrref.id == singleton_id ): break ref = pointer.source return False def sub_expr(ir: irast.Set) -> Optional[irast.Expr]: """Fetch the "sub-expression" of a set. For a non-pointer Set, it's just the expr, but for a Pointer it is the optional computed expression. """ if isinstance(ir.expr, irast.Pointer): return ir.expr.expr else: return ir.expr class CollectSchemaTypesVisitor(ast.NodeVisitor): types: set[uuid.UUID] def __init__(self) -> None: super().__init__() self.types = set() def visit_Set(self, node: irast.Set) -> None: self.types.add(node.typeref.id) self.generic_visit(node) def collect_schema_types(stmt: irast.Base) -> set[uuid.UUID]: """Collect ids of all types referenced in the statement.""" visitor = CollectSchemaTypesVisitor() visitor.visit(stmt) return visitor.types def is_linkful(ir: irast.Base) -> bool: def flt(p: irast.Pointer) -> bool: return typeutils.is_object(p.typeref) return bool(ast.find_children(ir, irast.Pointer, flt, terminate_early=True)) ================================================ FILE: edb/language_server/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import dataclasses from typing import Optional, Any @dataclasses.dataclass(kw_only=True, slots=True, frozen=True) class Result[T, E]: ok: Optional[T] = None err: Optional[E] = None def is_schema_file(path: str) -> bool: return path.endswith(('.esdl', '.gel')) def is_edgeql_file(path: str) -> bool: return path.endswith('.edgeql') def dump_to_str(node: Any) -> str: import io from edb.common import markup buf = io.StringIO() markup.dump(node, file=buf) return buf.getvalue() def dump_to_local_file(path: str, node: Any): import pathlib from edb.common import markup with ('.' / pathlib.Path(path)).open('w') as file: markup.dump(node, file=file) ================================================ FILE: edb/language_server/completion.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from typing import Any from lsprotocol import types as lsp_types import pygls from edb.common import ast from edb.common import span as edb_span from edb.edgeql import ast as qlast from edb.edgeql import tokenizer as qltokenizer from edb.edgeql import compiler as qlcompiler from edb.schema import name as sn from edb.schema import modules as s_modules from edb.schema import objtypes as s_objtypes from edb.schema import types as s_types from edb.schema import scalars as s_scalars from edb.schema import schema as s_schema from edb.schema import objects as s_objects from . import parsing as ls_parsing from . import server as ls_server from . import schema as ls_schema def get_completion( ls: ls_server.GelLanguageServer, params: lsp_types.CompletionParams ) -> lsp_types.CompletionList: document = ls.workspace.get_text_document(params.text_document.uri) target = qltokenizer.line_col_to_source_point( document.source, params.position.line, params.position.character ) ls.show_message_log(f'get_completion at position {target.offset}') # get syntactic suggestions items, can_be_ident = ls_parsing.get_completion(document, target.offset, ls) ls.show_message_log(f'can_be_ident = {can_be_ident}') if can_be_ident: ql_ast = ls_parsing.parse_and_recover(document) ls.show_message_log(f'ql_ast = {ql_ast}') if isinstance(ql_ast, qlast.Commands): items = ( _get_completion_in_ql(ls, document, ql_ast, target.offset) ) + items elif isinstance(ql_ast, qlast.Schema): items = ( _get_completion_in_schema(ls, document, ql_ast, target.offset) ) + items return lsp_types.CompletionList(is_incomplete=False, items=items) def _get_completion_in_ql( ls: ls_server.GelLanguageServer, document: pygls.workspace.TextDocument, ql_stmts: qlast.Commands, target: int, ) -> list[lsp_types.CompletionItem]: # replace the expr under the cursor with qlast.Cursor if not ql_stmts.commands: return [] for ql_stmt in ql_stmts.commands: replaced = replace_by_source_position(ql_stmt, qlast.Cursor(), target) if replaced: 
break if not replaced: ls.show_message_log('Cannot inject qlast.Cursor') return [] # compile the stmt that now contains the qlast.Cursor, # which should halt compilation when it reaches the cursor try: diagnostics, _ir_stmts = ls_server.compile_ql(ls, document, [ql_stmt]) except qlcompiler.expr.IdentCompletionException as e: return [ lsp_types.CompletionItem( label=s, kind=lsp_types.CompletionItemKind.Variable ) for s in e.suggestions ] for diags in diagnostics.by_doc.values(): for d in diags: ls.show_message_log(f'Cannot provide completion: {d.message}') return [] raise AssertionError('qlast.Cursor did not raise IdentCompletionException') def _get_completion_in_schema( ls: ls_server.GelLanguageServer, document: pygls.workspace.TextDocument, ql_schema: qlast.Schema, target: int, ) -> list[lsp_types.CompletionItem]: node_path = edb_span.find_by_source_position(ql_schema, target) ls.show_message_log(f"node_path = {node_path}") if not node_path: return [] schema = ls.state.schema if not schema: return [] items: list[lsp_types.CompletionItem] = [] # when in a module, suggest objects from that module module = ls_schema.get_module_context(node_path[1:]) ls.show_message_log(f"module = {module}") if module: objects: s_schema.SchemaIterator[s_objects.Object] = schema.get_objects( included_modules=(sn.UnqualName(module),), ) for obj in objects: if isinstance(obj, s_types.Type) and obj.get_from_alias(schema): continue kind: lsp_types.CompletionItemKind if isinstance(obj, s_objtypes.ObjectType): kind = lsp_types.CompletionItemKind.Struct elif isinstance(obj, s_scalars.ScalarType): kind = lsp_types.CompletionItemKind.Value else: continue label = obj.get_name(schema).name items.append( lsp_types.CompletionItem( label=label, kind=kind, # detail=str(obj), ) ) # always suggest modules objects = schema.get_objects(type=s_modules.Module) for obj in objects: name = obj.get_displayname(schema) items.append( lsp_types.CompletionItem( label=name, insert_text=name + '::', 
kind=lsp_types.CompletionItemKind.Module, ) ) return items # Replaces an expr node in AST that has a certain position within the source. # It matches the first Expr whose span contains the target offset in a # post-order traversal of the AST. def replace_by_source_position( tree: qlast.Base, replacement: qlast.Expr, target_offset: int ) -> bool: replacer = SpanReplacer(target_offset, replacement) replacer.visit(tree) return replacer.found class SpanReplacer(ast.NodeTransformer): target_offset: int replacement: qlast.Expr found: bool def __init__(self, target_offset: int, replacement: qlast.Expr): super().__init__() self.target_offset = target_offset self.replacement = replacement self.found = False def generic_visit(self, node, *, combine_results=None) -> Any: if self.found: return node has_span = False if node_span := getattr(node, 'span', None): has_span = True if not edb_span.span_contains(node_span, self.target_offset): return node r = super().generic_visit(node) if not self.found and has_span and isinstance(node, qlast.Expr): self.found = True return self.replacement return r ================================================ FILE: edb/language_server/definition.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from typing import Optional from lsprotocol import types as lsp_types import pygls import pygls.workspace from edb.common import span as edb_span from edb.edgeql import ast as qlast from edb.edgeql import tokenizer as qltokenizer from edb.ir import ast as irast from edb.schema import objects as s_objects from edb.schema import name as s_name from edb.schema import schema as s_schema from edb.schema import types as s_types from . import parsing as ls_parsing from . import utils as ls_utils from . import server as ls_server from . import schema as ls_schema from . import is_schema_file, is_edgeql_file def get_definition( ls: ls_server.GelLanguageServer, params: lsp_types.DefinitionParams ) -> lsp_types.Location | list[lsp_types.Location]: doc_uri = params.text_document.uri document = ls.workspace.get_text_document(doc_uri) position: int = qltokenizer.line_col_to_source_point( document.source, params.position.line, params.position.character ).offset ls.show_message_log(f'get_definition at position = {position}') try: if is_schema_file(doc_uri): return _get_definition_in_schema(ls, document, position) or [] elif is_edgeql_file(doc_uri): ql_ast_res = ls_parsing.parse(document) if not ql_ast_res.ok: return [] ql_ast = ql_ast_res.ok if isinstance(ql_ast, qlast.Commands): return ( _get_definition_in_ql(ls, document, ql_ast, position) or [] ) else: # SDL in query files? 
pass else: ls.show_message_log(f'Unknown file type: {doc_uri}') except BaseException as e: ls_server.send_internal_error(ls, e) return [] def _get_definition_in_ql( ls: ls_server.GelLanguageServer, document: pygls.workspace.TextDocument, ql_ast: qlast.Commands, position: int, ) -> lsp_types.Location | None: # compile the whole doc # TODO: search ql ast before compiling all stmts _, ir_stmts = ls_server.compile_ql(ls, document, ql_ast.commands) # find the ir node at the position node_path = None for ir_stmt in ir_stmts: node_path = edb_span.find_by_source_position(ir_stmt, position) if node_path: break if not node_path: ls.show_message_log(f"cannot find span in {len(ir_stmts)} stmts") return None node = node_path[0] assert isinstance(node, irast.Base), node ls.show_message_log(f"node: {node}") schema = ir_stmt.schema assert schema # lookup schema objects depending on which ir node we are over target = _determine_ir_target(node, schema) if not target: ls.show_message_log(f"don't know how to lookup schema by {node}") return None return _schema_obj_to_doc_location(ls, target, schema, document) def _determine_ir_target( node: irast.Base, schema: s_schema.Schema ) -> Optional[s_objects.Object]: # special handling: references to WITH bindings if ( isinstance(node, irast.SetE) and node.is_binding == irast.BindingKind.With ): target = schema.get_by_id( node.typeref.id, type=s_objects.InheritingObject ) assert target while ( target.get_span(schema) is None and isinstance(target, s_types.Type) and target.get_from_alias(schema) ): target = target.get_bases(schema).objects(schema)[0] return target # unwrap a set if isinstance(node, irast.SetE): node = node.expr # unwrap select stmts while isinstance(node, irast.SelectStmt): node = node.result.expr # references to object types if isinstance(node, irast.TypeRoot): return schema.get_by_id(node.typeref.id) # references to pointers if isinstance(node, irast.Pointer) and isinstance( node.ptrref, irast.PointerRef ): return 
schema.get_by_id(node.ptrref.id) return None # Finds definition of names in schema files. # # Parses the file and finds the ObjectRef at the given position. Then, it # computes "module context", by looking at names of encapsulating modules so # it can convert ObjectRef into a qualified name. Then it just looks up that # name in the schema. # # This impl might be lacking, since it does not use the code we use for name # resolution in the main compiler (tracing.py), and might report some # definitions incorrectly (i.e. within expressions). def _get_definition_in_schema( ls: ls_server.GelLanguageServer, document: pygls.workspace.TextDocument, position: int, ) -> lsp_types.Location | None: res = ls_schema._ensure_schema_docs_loaded(ls) if res.err: return None # parse current doc, return on errors _ = ls_schema._parse_schema(ls) assert ls.state.schema_sdl # find the span in ql ast node_path = edb_span.find_by_source_position(ls.state.schema_sdl, position) if not node_path: return None # ls.show_message_log(f"found node: {dump_to_str(node_path[0])}") # only resolve ObjectRefs if not isinstance(node_path[0], qlast.ObjectRef): return None # convert qlast.ObjectRef into a sn.QualName name: str = node_path[0].name module: Optional[str] = node_path[0].module if not module: module = ls_schema.get_module_context(node_path[1:]) if not module: return None q_name = s_name.QualName(module, name) ls.show_message_log(f"name: {q_name}") # lookup the name in latest compiled schema schema = ls.state.schema if not schema: return None obj = schema.get(q_name, default=None) if not obj: ls.show_message_log(f"object with this name not found") return None return _schema_obj_to_doc_location(ls, obj, schema, document) def _schema_obj_to_doc_location( ls: ls_server.GelLanguageServer, obj: s_objects.Object, schema: s_schema.Schema, curr_doc: pygls.workspace.TextDocument, ) -> lsp_types.Location | None: name = obj.get_name(schema) ls.show_message_log(f"find schema object: {name}") span: 
edb_span.Span | None = obj.get_span(schema) if not span: ls.show_message_log(f"no span for schema object") return None # find originating document doc: Optional[pygls.workspace.TextDocument] = None # is doc the current document? if span.filename == curr_doc.filename: doc = curr_doc # find schema docs with this filename if not doc: docs = ls.state.schema_docs doc = next((d for d in docs if d.filename == span.filename), None) if not doc: ls.show_message_log(f"Cannot find doc: {span.filename}") return None return lsp_types.Location( uri=doc.uri, range=ls_utils.span_to_lsp(doc.source, (span.start, span.end)), ) ================================================ FILE: edb/language_server/main.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import sys import json from lsprotocol import types as lsp_types import click from edb import buildmeta from edb.edgeql import parser as qlparser from . import server as ls_server from . import definition as ls_definition from . import completion as ls_completion @click.command() @click.option('--version', is_flag=True, help="Show the version and exit.") @click.option( '--stdio', is_flag=True, help="Use stdio for LSP. 
This is currently the only transport.", ) @click.argument("options", type=str, default='{}') def main(options: str | None, *, version: bool, stdio: bool): if version: print(f"gel-ls, version {buildmeta.get_version()}") sys.exit(0) ls = init(options) if stdio: ls.start_io() else: print("Error: no LSP transport enabled. Use --stdio.") def init(options_json: str | None) -> ls_server.GelLanguageServer: # load config options_dict = json.loads(options_json or '{}') project_dir = '.' if 'project_dir' in options_dict: project_dir = options_dict['project_dir'] config = ls_server.Config(project_dir=project_dir) # construct server ls = ls_server.GelLanguageServer(config) debug_init(ls) # register hooks @ls.feature( lsp_types.INITIALIZE, ) def init(_params: lsp_types.InitializeParams): qlparser.preload_spec() ls.show_message_log('gel-ls ready for requests') @ls.feature(lsp_types.TEXT_DOCUMENT_DID_OPEN) def text_document_did_open(params: lsp_types.DidOpenTextDocumentParams): ls_server.document_updated(ls, params.text_document.uri, compile=True) @ls.feature(lsp_types.TEXT_DOCUMENT_DID_CHANGE) def text_document_did_change(params: lsp_types.DidChangeTextDocumentParams): ls_server.document_updated(ls, params.text_document.uri, compile=False) @ls.feature(lsp_types.TEXT_DOCUMENT_DID_SAVE) def text_document_did_save(params: lsp_types.DidSaveTextDocumentParams): ls_server.document_updated(ls, params.text_document.uri, compile=True) @ls.feature(lsp_types.TEXT_DOCUMENT_DEFINITION) def text_document_definition( params: lsp_types.DefinitionParams, ) -> lsp_types.Definition: return ls_definition.get_definition(ls, params) @ls.feature( lsp_types.TEXT_DOCUMENT_COMPLETION, lsp_types.CompletionOptions(trigger_characters=[',']), ) def completion(params: lsp_types.CompletionParams): return ls_completion.get_completion(ls, params) return ls # Last gel-ls instance initialized. Use ONLY for debugging purposes. 
__gel_ls: ls_server.GelLanguageServer | None = None def debug_init(ls: ls_server.GelLanguageServer): global __gel_ls __gel_ls = ls def send_log_message(message: str): global __gel_ls assert __gel_ls, 'GelLanguageServer has not been started yet' __gel_ls.show_message_log(message) ================================================ FILE: edb/language_server/parsing.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from typing import Optional from pygls.server import LanguageServer from pygls.workspace import TextDocument from lsprotocol import types as lsp_types from edb import errors from edb.edgeql import ast as qlast from edb.edgeql import tokenizer from edb.edgeql import parser as qlparser from edb.edgeql.parser.grammar import tokens as qltokens import edb._edgeql_parser as rust_parser from . import Result, is_schema_file from . 
import utils as ls_utils def parse( doc: TextDocument, ) -> Result[qlast.Commands | qlast.Schema, list[lsp_types.Diagnostic]]: sdl = is_schema_file(doc.filename) if doc.filename else False start_t = qltokens.T_STARTSDLDOCUMENT if sdl else qltokens.T_STARTBLOCK start_t_name = start_t.__name__[2:] source_res = _tokenize(doc.source) if diagnostics := source_res.err: return Result(err=diagnostics) source = source_res.ok assert source result, productions = rust_parser.parse(start_t_name, source.tokens()) if result.errors: diagnostics = [] for error in result.errors: message, span, hint, details = error if details: message += f"\n{details}" if hint: message += f"\nHint: {hint}" diagnostics.append( lsp_types.Diagnostic( range=ls_utils.span_to_lsp(source.text(), span), severity=lsp_types.DiagnosticSeverity.Error, message=message, ) ) return Result(err=diagnostics) # parsing successful assert isinstance(result.out, rust_parser.CSTNode) try: ast = qlparser._cst_to_ast( result.out, productions, source, doc.filename ).val except errors.EdgeDBError as e: return Result(err=[ls_utils.error_to_lsp(e)]) if sdl: assert isinstance(ast, qlast.Schema), ast else: assert isinstance(ast, qlast.Commands), ast return Result(ok=ast) def parse_and_recover( doc: TextDocument, ) -> Optional[qlast.Commands | qlast.Schema]: sdl = is_schema_file(doc.filename) if doc.filename else False start_t = qltokens.T_STARTSDLDOCUMENT if sdl else qltokens.T_STARTBLOCK start_t_name = start_t.__name__[2:] source_res = _tokenize(doc.source) if not source_res.ok: return None source = source_res.ok result, productions = rust_parser.parse(start_t_name, source.tokens()) if not isinstance(result.out, rust_parser.CSTNode): return None try: ast = qlparser._cst_to_ast( result.out, productions, source, doc.filename ).val except errors.EdgeDBError: return None if sdl: assert isinstance(ast, qlast.Schema), ast else: assert isinstance(ast, qlast.Commands), ast return ast def get_completion( doc: TextDocument, target: int, 
ls: LanguageServer ) -> tuple[list[lsp_types.CompletionItem], bool]: sdl = is_schema_file(doc.path) start_t = qltokens.T_STARTSDLDOCUMENT if sdl else qltokens.T_STARTBLOCK start_t_name = start_t.__name__[2:] # tokenize source_res = _tokenize(doc.source) if not source_res.ok: return [], False source: tokenizer.Source = source_res.ok # limit tokens to things preceding cursor position cut_index = len(source.tokens()) for index, tok in enumerate(source.tokens()): if not tok.span_end() <= target: cut_index = index break tokens = source.tokens()[0:cut_index] # special case: cursor is *on* the last ident # (guard against an empty token list, e.g. cursor before the first token) if tokens and tokens[-1].is_ident() and tokens[-1].span_end() == target: return [], True # run parser and suggest next possible keywords suggestions, can_be_ident = rust_parser.suggest_next_keywords( start_t_name, tokens ) # convert to CompletionItem return [ lsp_types.CompletionItem( label=keyword, kind=lsp_types.CompletionItemKind.Keyword, ) for keyword in suggestions ], can_be_ident def _tokenize( source: str, ) -> Result[tokenizer.Source, list[lsp_types.Diagnostic]]: try: return Result(ok=tokenizer.Source.from_string(source)) except errors.EdgeQLSyntaxError as e: return Result( err=[ lsp_types.Diagnostic( range=ls_utils.span_to_lsp(source, e.get_span()), severity=lsp_types.DiagnosticSeverity.Error, message=e.args[0], ) ] ) ================================================ FILE: edb/language_server/project.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from typing import cast, Any from pathlib import Path import dataclasses import tomllib @dataclasses.dataclass(kw_only=True, frozen=True) class Instance: server_version: str @dataclasses.dataclass(kw_only=True, frozen=True) class Project: schema_dir: Path @dataclasses.dataclass(kw_only=True, frozen=True) class Manifest: instance: Instance | None project: Project | None # hooks: Option, # watch: Vec, def read_manifest(project_dir: Path) -> tuple[Manifest, Path]: try: path = project_dir / 'gel.toml' with open(path, 'rb') as f: manifest_dict = tomllib.load(f) except FileNotFoundError as e: path = project_dir / 'edgedb.toml' try: with open(path, 'rb') as f: manifest_dict = tomllib.load(f) except FileNotFoundError: raise e return (_load_manifest(manifest_dict), path) def _load_manifest(manifest_dict: Any) -> Manifest: instance = None if 'instance' in manifest_dict: instance = _load_instance(manifest_dict['instance']) elif 'edgedb' in manifest_dict: instance = _load_instance(manifest_dict['edgedb']) project = None if 'project' in manifest_dict: project = _load_project(manifest_dict['project']) return Manifest( instance=instance, project=project, ) def _load_instance(instance_dict: Any) -> Instance | None: server_version = None if 'server-version' in instance_dict: server_version = cast(str, instance_dict['server-version']) else: return None return Instance(server_version=server_version) def _load_project(project_dict: Any) -> Project | None: schema_dir = None if 'schema-dir' in project_dict: schema_dir = Path(project_dict['schema-dir']) else: return None return 
Project(schema_dir=schema_dir) ================================================ FILE: edb/language_server/schema.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from typing import Optional, cast, Iterable import pathlib import os from pygls import uris as pygls_uris import pygls from lsprotocol import types as lsp_types from edb import errors from edb.edgeql import ast as qlast from edb.schema import schema as s_schema from edb.schema import std as s_std from edb.schema import ddl as s_ddl import pygls.workspace from . import parsing as ls_parsing from . import is_schema_file from . import project from . import utils as ls_utils from . 
import server as ls_server from edb.language_server import Result def get_schema( ls: ls_server.GelLanguageServer, ) -> tuple[s_schema.Schema | None, ls_utils.DiagnosticsSet, str | None]: if ls.state.schema: return (ls.state.schema, ls_utils.DiagnosticsSet(), None) err_msg: str | None = None if len(ls.state.schema_docs) == 0: schema_dir, err_msg = _determine_schema_dir(ls) if not schema_dir: return (None, ls_utils.DiagnosticsSet(), err_msg) err_msg = _load_schema_docs(ls, schema_dir) if len(ls.state.schema_docs) == 0: return (None, ls_utils.DiagnosticsSet(), err_msg) schema, diagnostics = _compile_schema(ls) return schema, diagnostics, None def store_schema_doc( ls: ls_server.GelLanguageServer, doc: pygls.workspace.TextDocument ) -> list[lsp_types.Diagnostic]: res = _ensure_schema_docs_loaded(ls) if e := res.err: return [ls_utils.new_diagnostic_at_the_top(e)] schema_dir = cast(pathlib.Path, res.ok) # don't update if doc is not in schema_dir if schema_dir not in pathlib.Path(doc.path).parents: return [ ls_utils.new_diagnostic_at_the_top( f"this schema file is not in schema-dir ({schema_dir})" ) ] existing = next( (i for i, d in enumerate(ls.state.schema_docs) if d.path == doc.path), None, ) if existing is not None: # update ls.state.schema_docs[existing] = doc else: # insert ls.show_message_log("new schema file added: " + doc.path) ls.show_message_log("existing files: ") for d in ls.state.schema_docs: ls.show_message_log("- " + d.path) ls.state.schema_docs.append(doc) # clear AST cache ls.state.schema_sdl = None return [] def _ensure_schema_docs_loaded( ls: ls_server.GelLanguageServer, ) -> Result[pathlib.Path, str]: schema_dir, err_msg = _determine_schema_dir(ls) if not schema_dir: return Result(err=err_msg or "cannot find schema-dir") if len(ls.state.schema_docs) == 0: if err_msg := _load_schema_docs(ls, schema_dir): return Result(err=err_msg) return Result(ok=schema_dir) def _get_workspace_path( ls: ls_server.GelLanguageServer, ) -> tuple[pathlib.Path | None, str 
| None]: if len(ls.workspace.folders) > 1: return None, "Workspaces with multiple root folders are not supported" if len(ls.workspace.folders) == 0: return None, "No workspace open, cannot load schema" workspace: lsp_types.WorkspaceFolder = next( iter(ls.workspace.folders.values()) ) return pathlib.Path(pygls_uris.to_fs_path(workspace.uri)), None # Looks at the file system and loads schema documents into ls.state. # Returns an error message on failure, None otherwise. def _load_schema_docs( ls: ls_server.GelLanguageServer, schema_dir: pathlib.Path ) -> Optional[str]: # discard all existing docs ls.state.schema_docs.clear() try: entries = os.listdir(schema_dir) except FileNotFoundError: return f"Cannot list directory: {schema_dir}" # read schema files (as determined by is_schema_file) for entry in entries: if not is_schema_file(entry): continue doc = ls.workspace.get_text_document(f"file://{schema_dir / entry}") ls.state.schema_docs.append(doc) # clear AST cache ls.state.schema_sdl = None return None def _determine_schema_dir( ls: ls_server.GelLanguageServer, ) -> tuple[pathlib.Path | None, str | None]: workspace_path, err_msg = _get_workspace_path(ls) if not workspace_path: return None, err_msg or "Cannot determine schema dir" project_dir = workspace_path / pathlib.Path(ls.config.project_dir) manifest, err_msg = _load_manifest(ls, project_dir) if not manifest: # no manifest: don't infer any schema dir return None, err_msg if manifest.project and manifest.project.schema_dir: schema_dir = project_dir / manifest.project.schema_dir else: schema_dir = project_dir / "dbschema" if schema_dir.is_dir(): return schema_dir, None return None, f"Missing schema dir at {schema_dir}" def _load_manifest( ls: ls_server.GelLanguageServer, project_dir: pathlib.Path, ) -> tuple[project.Manifest | None, str | None]: if ls.state.manifest: return ls.state.manifest[0], None try: ls.state.manifest = project.read_manifest(project_dir) return ls.state.manifest[0], None except BaseException as e: return None, str(e) def _parse_schema( ls: 
ls_server.GelLanguageServer, ) -> ls_utils.DiagnosticsSet: diagnostics = ls_utils.DiagnosticsSet() if ls.state.schema_sdl: return ls_utils.DiagnosticsSet() sdl = qlast.Schema(declarations=[]) for doc in ls.state.schema_docs: res = ls_parsing.parse(doc) if d := res.err: diagnostics.by_doc[doc] = d else: diagnostics.by_doc[doc] = [] if isinstance(res.ok, qlast.Schema): sdl.declarations.extend(res.ok.declarations) else: # TODO: complain that .gel contains non-SDL syntax pass ls.state.schema_sdl = sdl return diagnostics def _compile_schema( ls: ls_server.GelLanguageServer, ) -> tuple[s_schema.Schema | None, ls_utils.DiagnosticsSet]: diagnostics = _parse_schema(ls) assert ls.state.schema_sdl std_schema = _load_std_schema(ls.state) # apply SDL to std schema ls.show_message_log("compiling schema ..") try: schema, _warnings = s_ddl.apply_sdl( ls.state.schema_sdl, base_schema=std_schema ) ls.state.schema = schema ls.show_message_log(".. done") except errors.EdgeDBError as error: ls.show_message_log(".. error") schema = None # find doc do = next( (d for d in ls.state.schema_docs if error.filename == d.filename), None, ) if do is None: ls.show_message_log( f"cannot find original doc of the error ({error.filename}), " "using first schema file" ) do = ls.state.schema_docs[0] # convert error diagnostics.append(do, ls_utils.error_to_lsp(error)) ls.state.schema_diagnostics = diagnostics return (schema, diagnostics) def _load_std_schema(state: ls_server.State) -> s_schema.Schema: if state.std_schema is not None: return state.std_schema schema: s_schema.Schema = s_schema.EMPTY_SCHEMA for modname in s_schema.STD_SOURCES: schema = s_std.load_std_module(schema, modname) state.std_schema = schema return state.std_schema # Given a path from a node to qlast.Schema root, collects the names of # encapsulating modules. 
def get_module_context(path: Iterable[qlast.Base]) -> str | None: mod_names = [] for node in path: if isinstance(node, qlast.ModuleDeclaration): mod_names.append(node.name.name) if not mod_names: return None mod_names.reverse() return '::'.join(mod_names) ================================================ FILE: edb/language_server/server.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from typing import Mapping import dataclasses import pathlib from pygls.server import LanguageServer import pygls from lsprotocol import types as lsp_types from edb import errors from edb.common import traceback as edb_traceback from edb.edgeql import ast as qlast from edb.edgeql import compiler as qlcompiler from edb.ir import ast as irast from edb.schema import schema as s_schema from edb.schema import ddl as s_ddl from . import project as ls_project from . import utils as ls_utils from . import parsing as ls_parsing from . 
import is_edgeql_file, is_schema_file @dataclasses.dataclass(kw_only=True) class State: manifest: tuple[ls_project.Manifest, pathlib.Path] | None = None schema_docs: list[pygls.workspace.TextDocument] = dataclasses.field( default_factory=lambda: [] ) schema_sdl: qlast.Schema | None = None schema: s_schema.Schema | None = None schema_diagnostics: ls_utils.DiagnosticsSet | None = None std_schema: s_schema.Schema | None = None @dataclasses.dataclass(kw_only=True) class Config: project_dir: str class GelLanguageServer(LanguageServer): state: State config: Config def __init__(self, config: Config): super().__init__("Gel Language Server", "v0.1") self.state = State() self.config = config def send_internal_error(ls: GelLanguageServer, e: BaseException): text = edb_traceback.format_exception(e) ls.show_message_log(f'Internal error: {text}') def document_updated(ls: GelLanguageServer, doc_uri: str, *, compile: bool): # each call to this function should result in exactly one # publish_diagnostics call for this document from . 
import schema as ls_schema document = ls.workspace.get_text_document(doc_uri) diagnostic_set = ls_utils.DiagnosticsSet() diagnostic_set.extend(document, []) # make sure we publish for document try: if is_schema_file(doc_uri): # schema file # parse diags = ls_schema.store_schema_doc(ls, document) diagnostic_set.extend(document, diags) diagnostic_set.merge(ls_schema._parse_schema(ls)) # compile if compile and not diagnostic_set.has_any(): _, _ = ls_schema._compile_schema(ls) # add schema diagnostics from last compilation if ls.state.schema_diagnostics: diagnostic_set.merge(ls.state.schema_diagnostics) elif is_edgeql_file(doc_uri): # query file # parse ast_res = ls_parsing.parse(document) if ast_res.err: diagnostic_set.extend(document, ast_res.err) # compile if compile and isinstance(ast_res.ok, qlast.Commands): diag, _ = compile_ql(ls, document, ast_res.ok.commands) diagnostic_set.merge(diag) else: ls.show_message_log(f'Unknown file type: {doc_uri}') for doc, diags in diagnostic_set.by_doc.items(): ls.publish_diagnostics(doc.uri, diags, doc.version) except BaseException as e: send_internal_error(ls, e) ls.publish_diagnostics(document.uri, [], document.version) def compile_ql( ls: GelLanguageServer, doc: pygls.workspace.TextDocument, stmts: list[qlast.Command], ) -> tuple[ls_utils.DiagnosticsSet, list[irast.Statement]]: from . 
import schema as ls_schema if not stmts: return (ls_utils.DiagnosticsSet(by_doc={doc: []}), []) schema, diagnostics_set, err_msg = ls_schema.get_schema(ls) if not schema: if len(ls.state.schema_docs) == 0: diagnostics_set.append( doc, ls_utils.new_diagnostic_at_the_top( err_msg or "Cannot find schema files" ), ) return (diagnostics_set, []) diagnostics: list[lsp_types.Diagnostic] = [] ir_stmts: list[irast.Statement] = [] modaliases: Mapping[str | None, str] = {None: "default"} for ql_stmt in stmts: try: if isinstance(ql_stmt, qlast.DDLCommand): schema, _delta = s_ddl.delta_and_schema_from_ddl( ql_stmt, schema=schema, modaliases=modaliases ) elif isinstance(ql_stmt, (qlast.Command, qlast.Expr)): options = qlcompiler.CompilerOptions(modaliases=modaliases) ir_res = qlcompiler.compile_ast_to_ir( ql_stmt, schema, options=options ) if isinstance(ir_res, irast.Statement): ir_stmts.append(ir_res) else: ls.show_message_log(f"skip compile of {ql_stmt}") except errors.EdgeDBError as error: diagnostics.append(ls_utils.error_to_lsp(error)) diagnostics_set.extend(doc, diagnostics) return (diagnostics_set, ir_stmts) ================================================ FILE: edb/language_server/utils.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import dataclasses from typing import Iterable from lsprotocol import types as lsp_types import pygls from edb import errors from edb.edgeql import tokenizer @dataclasses.dataclass(kw_only=True) class DiagnosticsSet: by_doc: dict[pygls.workspace.TextDocument, list[lsp_types.Diagnostic]] = ( dataclasses.field(default_factory=lambda: {}) ) def append( self, doc: pygls.workspace.TextDocument, diagnostic: lsp_types.Diagnostic, ): if doc not in self.by_doc: self.by_doc[doc] = [] self.by_doc[doc].append(diagnostic) def extend( self, doc: pygls.workspace.TextDocument, diagnostics: Iterable[lsp_types.Diagnostic], ): if doc not in self.by_doc: self.by_doc[doc] = [] self.by_doc[doc].extend(diagnostics) def merge(self, other: 'DiagnosticsSet'): for doc, diags in other.by_doc.items(): self.extend(doc, diags) def has_any(self) -> bool: for diags in self.by_doc.values(): if len(diags) != 0: return True return False # Convert a Span to LSP Range def span_to_lsp( source: str, span: tuple[int, int | None] | None ) -> lsp_types.Range: if span: (start, end) = tokenizer.inflate_span(source, span) else: (start, end) = (None, None) assert end return lsp_types.Range( start=( lsp_types.Position( line=start.line - 1, character=start.column - 1, ) if start else lsp_types.Position(line=0, character=0) ), end=( lsp_types.Position( line=end.line - 1, character=end.column - 1, ) if end else lsp_types.Position(line=0, character=0) ), ) # Convert EdgeDBError into an LSP Diagnostic def error_to_lsp(error: errors.EdgeDBError) -> lsp_types.Diagnostic: message: str = error.args[0] if hint := error.hint: message += f"\nHint: {hint}" return lsp_types.Diagnostic( range=( lsp_types.Range( start=lsp_types.Position( line=error.line - 1, character=error.col - 1, ), end=lsp_types.Position( line=error.line_end - 1, character=error.col_end - 1, ), ) if error.line >= 0 else lsp_types.Range( start=lsp_types.Position(line=0, character=0), end=lsp_types.Position(line=0, character=0), ) ), 
severity=lsp_types.DiagnosticSeverity.Error, message=message, ) # Constructs a new diagnostic in the first line of the document def new_diagnostic_at_the_top(message: str) -> lsp_types.Diagnostic: return lsp_types.Diagnostic( range=lsp_types.Range( start=lsp_types.Position(line=0, character=0), end=lsp_types.Position(line=1, character=0), ), severity=lsp_types.DiagnosticSeverity.Error, message=message, related_information=[], ) ================================================ FILE: edb/lib/__init__.py ================================================ from __future__ import annotations ================================================ FILE: edb/lib/_testmode.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2018-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # Bits used for testing of the std-only functionality. # These definitions are picked up if the EdgeDB instance is bootstrapped # with --testmode. 
CREATE TYPE cfg::TestSessionConfig EXTENDING cfg::ConfigObject { CREATE REQUIRED PROPERTY name -> std::str { CREATE CONSTRAINT std::exclusive; } }; CREATE ABSTRACT TYPE cfg::Base EXTENDING cfg::ConfigObject { CREATE REQUIRED PROPERTY name -> std::str }; CREATE TYPE cfg::Subclass1 EXTENDING cfg::Base { CREATE REQUIRED PROPERTY sub1 -> std::str; }; CREATE TYPE cfg::Subclass2 EXTENDING cfg::Base { CREATE REQUIRED PROPERTY sub2 -> std::str; }; CREATE TYPE cfg::TestInstanceConfig EXTENDING cfg::ConfigObject { CREATE REQUIRED PROPERTY name -> std::str { CREATE CONSTRAINT std::exclusive; }; CREATE LINK obj -> cfg::Base; }; CREATE TYPE cfg::TestInstanceConfigStatTypes EXTENDING cfg::TestInstanceConfig { CREATE PROPERTY memprop -> cfg::memory; CREATE PROPERTY durprop -> std::duration; }; CREATE SCALAR TYPE cfg::TestEnum EXTENDING enum; CREATE SCALAR TYPE cfg::TestEnabledDisabledEnum EXTENDING enum; ALTER TYPE cfg::AbstractConfig { CREATE MULTI LINK sessobj -> cfg::TestSessionConfig { CREATE ANNOTATION cfg::internal := 'true'; }; CREATE MULTI LINK sysobj -> cfg::TestInstanceConfig { CREATE ANNOTATION cfg::internal := 'true'; }; CREATE PROPERTY __internal_testvalue -> std::int64 { CREATE ANNOTATION cfg::internal := 'true'; CREATE ANNOTATION cfg::system := 'true'; SET default := 0; }; CREATE PROPERTY __internal_sess_testvalue -> std::int64 { CREATE ANNOTATION cfg::internal := 'true'; SET default := 0; }; CREATE PROPERTY __internal_testmode -> std::bool { CREATE ANNOTATION cfg::internal := 'true'; CREATE ANNOTATION cfg::affects_compilation := 'true'; SET default := false; }; # Fully suppress apply_query_rewrites, like is done for internal # reflection queries. CREATE PROPERTY __internal_no_apply_query_rewrites -> std::bool { CREATE ANNOTATION cfg::internal := 'true'; CREATE ANNOTATION cfg::affects_compilation := 'true'; SET default := false; }; # Use the "reflection schema" as the base schema instead of the # normal std schema. 
This allows looking at all the schema fields # that are hidden in the public introspection schema. CREATE PROPERTY __internal_query_reflschema -> std::bool { CREATE ANNOTATION cfg::internal := 'true'; CREATE ANNOTATION cfg::affects_compilation := 'true'; SET default := false; }; CREATE PROPERTY __internal_restart -> std::bool { CREATE ANNOTATION cfg::internal := 'true'; CREATE ANNOTATION cfg::system := 'true'; CREATE ANNOTATION cfg::requires_restart := 'true'; SET default := false; }; CREATE MULTI PROPERTY multiprop -> std::str { CREATE ANNOTATION cfg::internal := 'true'; }; CREATE PROPERTY singleprop -> std::str { CREATE ANNOTATION cfg::internal := 'true'; SET default := ''; }; CREATE PROPERTY memprop -> cfg::memory { CREATE ANNOTATION cfg::internal := 'true'; SET default := '0'; }; CREATE PROPERTY durprop -> std::duration { CREATE ANNOTATION cfg::internal := 'true'; SET default := '0 seconds'; }; CREATE PROPERTY enumprop -> cfg::TestEnum { CREATE ANNOTATION cfg::internal := 'true'; SET default := cfg::TestEnum.One; }; CREATE PROPERTY boolprop -> std::bool { CREATE ANNOTATION cfg::internal := 'true'; SET default := true; }; CREATE PROPERTY __pg_max_connections -> std::int64 { CREATE ANNOTATION cfg::internal := 'true'; CREATE ANNOTATION cfg::backend_setting := '"max_connections"'; }; CREATE PROPERTY __check_function_bodies -> cfg::TestEnabledDisabledEnum { CREATE ANNOTATION cfg::internal := 'true'; CREATE ANNOTATION cfg::backend_setting := '"check_function_bodies"'; SET default := cfg::TestEnabledDisabledEnum.Enabled; }; }; # For testing configs defined in extensions create extension package _conf VERSION '1.0' { set ext_module := "ext::_conf"; set sql_extensions := []; create module ext::_conf; create type ext::_conf::SingleObj extending cfg::ConfigObject { create required property name -> std::str { set readonly := true; }; create required property value -> std::str { set readonly := true; }; create required property fixed -> std::str { set default := "fixed!"; 
set readonly := true; set protected := true; }; }; create type ext::_conf::Obj extending cfg::ConfigObject { create required property name -> std::str { set readonly := true; create constraint std::exclusive; }; create required property value -> std::str { set readonly := true; create delegated constraint std::exclusive; create constraint expression on (__subject__[:5] != 'asdf_'); }; create property opt_value -> std::str { set readonly := true; }; }; create type ext::_conf::SubObj extending ext::_conf::Obj { create required property extra -> int64 { set readonly := true; }; create required property duration_config: std::duration { set default := '10 minutes'; }; }; create type ext::_conf::SecretObj extending ext::_conf::Obj { create property secret -> std::str { set readonly := true; set secret := true; }; }; create type ext::_conf::Obj2 extending cfg::ConfigObject { create required property name -> std::str { set readonly := true; create constraint std::exclusive; }; }; create type ext::_conf::Config extending cfg::ExtensionConfig { create multi link objs -> ext::_conf::Obj; create link obj -> ext::_conf::SingleObj; create multi link objs2 -> ext::_conf::Obj2; create property config_name -> std::str { set default := ""; }; create property opt_value -> std::str; create property secret -> std::str { set secret := true; }; }; create function ext::_conf::get_secret(c: ext::_conf::SecretObj) -> optional std::str using (c.secret); create function ext::_conf::get_top_secret() -> set of std::str using ( cfg::Config.extensions[is ext::_conf::Config].secret); create alias ext::_conf::OK := ( cfg::Config.extensions[is ext::_conf::Config].secret ?= 'foobaz'); }; # std::_gen_series CREATE FUNCTION std::_gen_series( `start`: std::int64, stop: std::int64 ) -> SET OF std::int64 { SET volatility := 'Immutable'; USING SQL FUNCTION 'generate_series'; }; CREATE FUNCTION std::_gen_series( `start`: std::int64, stop: std::int64, step: std::int64 ) -> SET OF std::int64 { SET volatility 
:= 'Immutable'; USING SQL FUNCTION 'generate_series'; }; CREATE FUNCTION std::_gen_series( `start`: std::bigint, stop: std::bigint ) -> SET OF std::bigint { SET volatility := 'Immutable'; SET force_return_cast := true; USING SQL FUNCTION 'generate_series'; }; CREATE FUNCTION std::_gen_series( `start`: std::bigint, stop: std::bigint, step: std::bigint ) -> SET OF std::bigint { SET volatility := 'Immutable'; SET force_return_cast := true; USING SQL FUNCTION 'generate_series'; }; CREATE FUNCTION sys::_sleep(duration: std::float64) -> std::bool { CREATE ANNOTATION std::description := 'Make the current session sleep for *duration* seconds.'; # This function has side-effect. SET volatility := 'Volatile'; USING SQL $$ SELECT pg_sleep("duration") IS NOT NULL; $$; }; CREATE FUNCTION sys::_sleep(duration: std::duration) -> std::bool { CREATE ANNOTATION std::description := 'Make the current session sleep for *duration* time.'; # This function has side-effect. SET volatility := 'Volatile'; USING SQL $$ SELECT pg_sleep_for("duration") IS NOT NULL; $$; }; CREATE FUNCTION sys::_postgres_version() -> std::str { CREATE ANNOTATION std::description := 'Get the postgres version string'; USING SQL $$ SELECT version() $$; }; CREATE FUNCTION sys::_advisory_lock(key: std::int64) -> std::bool { CREATE ANNOTATION std::description := 'Obtain an exclusive session-level advisory lock.'; # This function has side-effect. SET volatility := 'Volatile'; USING SQL $$ SELECT CASE WHEN "key" < 0 THEN edgedb_VER.raise(NULL::bool, msg => 'lock key cannot be negative') ELSE pg_advisory_lock("key") IS NOT NULL END; $$; }; CREATE FUNCTION sys::_advisory_unlock(key: std::int64) -> std::bool { CREATE ANNOTATION std::description := 'Release an exclusive session-level advisory lock.'; # This function has side-effect. 
SET volatility := 'Volatile'; USING SQL $$ SELECT CASE WHEN "key" < 0 THEN edgedb_VER.raise(NULL::bool, msg => 'lock key cannot be negative') ELSE pg_advisory_unlock("key") END; $$; }; CREATE FUNCTION sys::_advisory_unlock_all() -> std::bool { CREATE ANNOTATION std::description := 'Release all session-level advisory locks held by the current session.'; # This function has side-effect. SET volatility := 'Volatile'; USING SQL $$ SELECT pg_advisory_unlock_all() IS NOT NULL; $$; }; CREATE FUNCTION std::_datetime_range_buckets( low: std::datetime, high: std::datetime, granularity: str, ) -> SET OF tuple<std::datetime, std::datetime> { CREATE ANNOTATION std::description := 'Generate a set of datetime buckets for a given time period ' ++ 'and a given granularity'; # date_trunc of timestamptz is STABLE in PostgreSQL SET volatility := 'Stable'; USING SQL $$ SELECT lo::edgedbt.timestamptz_t, hi::edgedbt.timestamptz_t FROM (SELECT series AS lo, lead(series) OVER () AS hi FROM generate_series( "low", "high", "granularity"::interval ) AS series) AS q WHERE hi IS NOT NULL $$; }; CREATE FUNCTION std::_current_setting(sqlname: str) -> OPTIONAL std::str { USING SQL $$ SELECT current_setting(sqlname, true) $$; }; create function std::_set_config(sqlname: std::str, val: std::str) -> std::str { using sql $$ select set_config(sqlname, val, true) $$; }; create function std::_warn_on_call() -> std::int64 { using (0) }; CREATE MODULE std::_test; CREATE FUNCTION std::_test::abs(x: std::anyreal) -> std::anyreal { SET volatility := 'Immutable'; USING SQL FUNCTION 'abs'; }; ================================================ FILE: edb/lib/cal.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2018-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # CREATE MODULE std::cal; CREATE SCALAR TYPE std::cal::local_datetime EXTENDING std::anycontiguous; CREATE SCALAR TYPE std::cal::local_date EXTENDING std::anydiscrete; CREATE SCALAR TYPE std::cal::local_time EXTENDING std::anyscalar; CREATE SCALAR TYPE std::cal::relative_duration EXTENDING std::anyscalar; CREATE SCALAR TYPE std::cal::date_duration EXTENDING std::anyscalar; ## Functions ## --------- CREATE FUNCTION std::cal::to_local_datetime(s: std::str, fmt: OPTIONAL str={}) -> std::cal::local_datetime { CREATE ANNOTATION std::description := 'Create a `std::cal::local_datetime` value.'; # Helper function to_local_datetime is VOLATILE. 
SET volatility := 'Volatile'; USING SQL $$ SELECT ( CASE WHEN "fmt" IS NULL THEN edgedb_VER.local_datetime_in("s") WHEN "fmt" = '' THEN edgedb_VER.raise( NULL::edgedbt.timestamp_t, 'invalid_parameter_value', msg => ( 'to_local_datetime(): ' || '"fmt" argument must be a non-empty string' ) ) ELSE edgedb_VER.raise_on_null( edgedb_VER.to_local_datetime("s", "fmt"), 'invalid_parameter_value', msg => ( 'to_local_datetime(): ' || 'format ''' || "fmt" || ''' is invalid' ) ) END ) $$; }; CREATE FUNCTION std::cal::to_local_datetime(year: std::int64, month: std::int64, day: std::int64, hour: std::int64, min: std::int64, sec: std::float64) -> std::cal::local_datetime { CREATE ANNOTATION std::description := 'Create a `std::cal::local_datetime` value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT make_timestamp( "year"::int, "month"::int, "day"::int, "hour"::int, "min"::int, "sec" )::edgedbt.timestamp_t $$; }; CREATE FUNCTION std::cal::to_local_datetime(dt: std::datetime, zone: std::str) -> std::cal::local_datetime { CREATE ANNOTATION std::description := 'Create a `std::cal::local_datetime` value.'; # The version of timezone with these arguments is IMMUTABLE. 
SET volatility := 'Immutable'; USING SQL $$ SELECT timezone("zone", "dt")::edgedbt.timestamp_t; $$; }; CREATE FUNCTION std::cal::to_local_date(s: std::str, fmt: OPTIONAL str={}) -> std::cal::local_date { CREATE ANNOTATION std::description := 'Create a `std::cal::local_date` value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN "fmt" IS NULL THEN edgedb_VER.local_date_in("s") WHEN "fmt" = '' THEN edgedb_VER.raise( NULL::edgedbt.date_t, 'invalid_parameter_value', msg => ( 'to_local_date(): ' || '"fmt" argument must be a non-empty string' ) ) ELSE edgedb_VER.raise_on_null( edgedb_VER.to_local_datetime("s", "fmt")::edgedbt.date_t, 'invalid_parameter_value', msg => ( 'to_local_date(): format ''' || "fmt" || ''' is invalid' ) ) END ) $$; }; CREATE FUNCTION std::cal::to_local_date(dt: std::datetime, zone: std::str) -> std::cal::local_date { CREATE ANNOTATION std::description := 'Create a `std::cal::local_date` value.'; # The version of timezone with these arguments is IMMUTABLE. 
SET volatility := 'Immutable'; USING SQL $$ SELECT timezone("zone", "dt")::edgedbt.date_t; $$; }; CREATE FUNCTION std::cal::to_local_date(year: std::int64, month: std::int64, day: std::int64) -> std::cal::local_date { CREATE ANNOTATION std::description := 'Create a `std::cal::local_date` value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT make_date("year"::int, "month"::int, "day"::int)::edgedbt.date_t $$; }; CREATE FUNCTION std::cal::to_local_time(s: std::str, fmt: OPTIONAL str={}) -> std::cal::local_time { CREATE ANNOTATION std::description := 'Create a `std::cal::local_time` value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN "fmt" IS NULL THEN edgedb_VER.local_time_in("s") WHEN "fmt" = '' THEN edgedb_VER.raise( NULL::time, 'invalid_parameter_value', msg => ( 'to_local_time(): ' || '"fmt" argument must be a non-empty string' ) ) ELSE edgedb_VER.raise_on_null( edgedb_VER.to_local_datetime("s", "fmt")::time, 'invalid_parameter_value', msg => ( 'to_local_time(): ' || 'format ''' || "fmt" || ''' is invalid' ) ) END ) $$; }; CREATE FUNCTION std::cal::to_local_time(dt: std::datetime, zone: std::str) -> std::cal::local_time { CREATE ANNOTATION std::description := 'Create a `std::cal::local_time` value.'; # The version of timezone with these arguments is IMMUTABLE and so # is the cast. 
SET volatility := 'Immutable'; USING SQL $$ SELECT timezone("zone", "dt")::time; $$; }; CREATE FUNCTION std::cal::to_local_time(hour: std::int64, min: std::int64, sec: std::float64) -> std::cal::local_time { CREATE ANNOTATION std::description := 'Create a `std::cal::local_time` value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT CASE WHEN date_part('hour', x.t) = 24 THEN edgedb_VER.raise( NULL::time, 'invalid_datetime_format', msg => ( 'std::cal::local_time field value out of range: ' || quote_literal(x.t::text) ) ) ELSE x.t END FROM ( SELECT make_time("hour"::int, "min"::int, "sec") as t ) as x $$; }; CREATE FUNCTION std::cal::to_relative_duration( NAMED ONLY years: std::int64=0, NAMED ONLY months: std::int64=0, NAMED ONLY days: std::int64=0, NAMED ONLY hours: std::int64=0, NAMED ONLY minutes: std::int64=0, NAMED ONLY seconds: std::float64=0, NAMED ONLY microseconds: std::int64=0 ) -> std::cal::relative_duration { CREATE ANNOTATION std::description := 'Create a `std::cal::relative_duration` value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( make_interval( "years"::int, "months"::int, 0, "days"::int, "hours"::int, "minutes"::int, "seconds" ) + (microseconds::text || ' microseconds')::interval )::edgedbt.relative_duration_t $$; }; CREATE FUNCTION std::cal::to_date_duration( NAMED ONLY years: std::int64=0, NAMED ONLY months: std::int64=0, NAMED ONLY days: std::int64=0 ) -> std::cal::date_duration { CREATE ANNOTATION std::description := 'Create a `std::cal::date_duration` value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT make_interval( "years"::int, "months"::int, 0, "days"::int )::edgedbt.date_duration_t $$; }; CREATE FUNCTION std::cal::time_get(dt: std::cal::local_time, el: std::str) -> std::float64 { CREATE ANNOTATION std::description := 'Extract a specific element of input time by name.'; SET volatility := 'Immutable'; USING SQL $$ SELECT CASE WHEN "el" IN ('hour', 'microseconds', 'milliseconds', 'minutes', 'seconds') THEN 
date_part("el", "dt") WHEN "el" = 'midnightseconds' THEN date_part('epoch', "dt") ELSE edgedb_VER.raise( NULL::float, 'invalid_datetime_format', msg => ( 'invalid unit for std::time_get: ' || quote_literal("el") ), detail => ( '{"hint":"Supported units: hour, microseconds, ' || 'midnightseconds, milliseconds, minutes, seconds."}' ) ) END $$; }; CREATE FUNCTION std::cal::date_get(dt: std::cal::local_date, el: std::str) -> std::float64 { CREATE ANNOTATION std::description := 'Extract a specific element of input date by name.'; SET volatility := 'Immutable'; USING SQL $$ SELECT CASE WHEN "el" IN ( 'century', 'day', 'decade', 'dow', 'doy', 'isodow', 'isoyear', 'millennium', 'month', 'quarter', 'week', 'year') THEN date_part("el", "dt") ELSE edgedb_VER.raise( NULL::float, 'invalid_datetime_format', msg => ( 'invalid unit for std::date_get: ' || quote_literal("el") ), detail => ( '{"hint":"Supported units: century, day, ' || 'decade, dow, doy, isodow, isoyear, ' || 'millennium, month, quarter, week, year."}' ) ) END $$; }; CREATE FUNCTION std::cal::duration_normalize_hours(dur: std::cal::relative_duration) -> std::cal::relative_duration { CREATE ANNOTATION std::description := 'Convert 24-hour chunks into days.'; SET volatility := 'Immutable'; SET force_return_cast := true; USING SQL FUNCTION 'justify_hours'; }; CREATE FUNCTION std::cal::duration_normalize_days(dur: std::cal::relative_duration) -> std::cal::relative_duration { CREATE ANNOTATION std::description := 'Convert 30-day chunks into months.'; SET volatility := 'Immutable'; SET force_return_cast := true; USING SQL FUNCTION 'justify_days'; }; CREATE FUNCTION std::cal::duration_normalize_days(dur: std::cal::date_duration) -> std::cal::date_duration { CREATE ANNOTATION std::description := 'Convert 30-day chunks into months.'; SET volatility := 'Immutable'; SET force_return_cast := true; USING SQL FUNCTION 'justify_days'; };
## Operators on std::datetime
## --------------------------
CREATE INFIX OPERATOR
std::`+` (l: std::datetime, r: std::cal::relative_duration) -> std::datetime { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Time interval and date/time addition.';
# Immutable because datetime is guaranteed to be in UTC and no DST issues
# should affect this.
SET volatility := 'Immutable'; SET commutator := 'std::+'; USING SQL $$ SELECT ("l" + "r")::edgedbt.timestamptz_t $$ }; CREATE INFIX OPERATOR std::`+` (l: std::cal::relative_duration, r: std::datetime) -> std::datetime { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Time interval and date/time addition.';
# Immutable because datetime is guaranteed to be in UTC and no DST issues
# should affect this.
SET volatility := 'Immutable'; SET commutator := 'std::+'; USING SQL $$ SELECT ("l" + "r")::edgedbt.timestamptz_t $$ }; CREATE INFIX OPERATOR std::`-` (l: std::datetime, r: std::cal::relative_duration) -> std::datetime { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Time interval and date/time subtraction.';
# Immutable because datetime is guaranteed to be in UTC and no DST issues
# should affect this.
SET volatility := 'Immutable'; USING SQL $$ SELECT ("l" - "r")::edgedbt.timestamptz_t $$ }; ## Operators on std::cal::local_datetime ## -------------------------------- CREATE INFIX OPERATOR std::`=` (l: std::cal::local_datetime, r: std::cal::local_datetime) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; SET commutator := 'std::='; SET negator := 'std::!='; USING SQL OPERATOR r'=(timestamp,timestamp)'; }; CREATE INFIX OPERATOR std::`?=` (l: OPTIONAL std::cal::local_datetime, r: OPTIONAL std::cal::local_datetime) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`!=` (l: std::cal::local_datetime, r: std::cal::local_datetime) -> std::bool { CREATE ANNOTATION std::identifier := 'neq'; CREATE ANNOTATION std::description := 'Compare two values for inequality.'; SET volatility := 'Immutable'; SET commutator := 'std::!='; SET negator := 'std::='; USING SQL OPERATOR r'<>(timestamp,timestamp)'; }; CREATE INFIX OPERATOR std::`?!=` (l: OPTIONAL std::cal::local_datetime, r: OPTIONAL std::cal::local_datetime) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_neq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`>` (l: std::cal::local_datetime, r: std::cal::local_datetime) -> std::bool { CREATE ANNOTATION std::identifier := 'gt'; CREATE ANNOTATION std::description := 'Greater than.'; SET volatility := 'Immutable'; SET commutator := 'std::<'; SET negator := 'std::<='; USING SQL OPERATOR r'>(timestamp,timestamp)'; }; CREATE INFIX OPERATOR std::`>=` (l: std::cal::local_datetime, r: std::cal::local_datetime) -> std::bool { CREATE ANNOTATION 
std::identifier := 'gte'; CREATE ANNOTATION std::description := 'Greater than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::<='; SET negator := 'std::<'; USING SQL OPERATOR r'>=(timestamp,timestamp)'; }; CREATE INFIX OPERATOR std::`<` (l: std::cal::local_datetime, r: std::cal::local_datetime) -> std::bool { CREATE ANNOTATION std::identifier := 'lt'; CREATE ANNOTATION std::description := 'Less than.'; SET volatility := 'Immutable'; SET commutator := 'std::>'; SET negator := 'std::>='; USING SQL OPERATOR r'<(timestamp,timestamp)'; }; CREATE INFIX OPERATOR std::`<=` (l: std::cal::local_datetime, r: std::cal::local_datetime) -> std::bool { CREATE ANNOTATION std::identifier := 'lte'; CREATE ANNOTATION std::description := 'Less than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::>='; SET negator := 'std::>'; USING SQL OPERATOR r'<=(timestamp,timestamp)'; }; CREATE INFIX OPERATOR std::`+` (l: std::cal::local_datetime, r: std::duration) -> std::cal::local_datetime { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Time interval and date/time addition.'; SET volatility := 'Immutable'; SET commutator := 'std::+'; USING SQL $$ SELECT ("l" + "r")::edgedbt.timestamp_t $$; }; CREATE INFIX OPERATOR std::`+` (l: std::duration, r: std::cal::local_datetime) -> std::cal::local_datetime { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Time interval and date/time addition.'; SET volatility := 'Immutable'; SET commutator := 'std::+'; USING SQL $$ SELECT ("l" + "r")::edgedbt.timestamp_t $$; }; CREATE INFIX OPERATOR std::`-` (l: std::cal::local_datetime, r: std::duration) -> std::cal::local_datetime { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Time interval and date/time subtraction.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ("l" - "r")::edgedbt.timestamp_t $$; }; CREATE INFIX OPERATOR std::`+` (l: std::cal::local_datetime, 
r: std::cal::relative_duration) -> std::cal::local_datetime { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Time interval and date/time addition.'; SET volatility := 'Immutable'; SET commutator := 'std::+'; USING SQL $$ SELECT ("l" + "r")::edgedbt.timestamp_t $$; }; CREATE INFIX OPERATOR std::`+` (l: std::cal::relative_duration, r: std::cal::local_datetime) -> std::cal::local_datetime { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Time interval and date/time addition.'; SET volatility := 'Immutable'; SET commutator := 'std::+'; USING SQL $$ SELECT ("l" + "r")::edgedbt.timestamp_t $$; }; CREATE INFIX OPERATOR std::`-` (l: std::cal::local_datetime, r: std::cal::relative_duration) -> std::cal::local_datetime { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Time interval and date/time subtraction.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ("l" - "r")::edgedbt.timestamp_t $$; }; CREATE INFIX OPERATOR std::`-` (l: std::cal::local_datetime, r: std::cal::local_datetime) -> std::cal::relative_duration { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Date/time subtraction.'; SET volatility := 'Immutable'; SET force_return_cast := true; USING SQL OPERATOR r'-(timestamp, timestamp)'; }; ## Operators on std::cal::local_date ## ---------------------------- CREATE INFIX OPERATOR std::`=` (l: std::cal::local_date, r: std::cal::local_date) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; SET commutator := 'std::='; SET negator := 'std::!='; USING SQL OPERATOR r'=(date,date)'; }; CREATE INFIX OPERATOR std::`?=` (l: OPTIONAL std::cal::local_date, r: OPTIONAL std::cal::local_date) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) 
values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`!=` (l: std::cal::local_date, r: std::cal::local_date) -> std::bool { CREATE ANNOTATION std::identifier := 'neq'; CREATE ANNOTATION std::description := 'Compare two values for inequality.'; SET volatility := 'Immutable'; SET commutator := 'std::!='; SET negator := 'std::='; USING SQL OPERATOR r'<>(date,date)'; }; CREATE INFIX OPERATOR std::`?!=` (l: OPTIONAL std::cal::local_date, r: OPTIONAL std::cal::local_date) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_neq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`>` (l: std::cal::local_date, r: std::cal::local_date) -> std::bool { CREATE ANNOTATION std::identifier := 'gt'; CREATE ANNOTATION std::description := 'Greater than.'; SET volatility := 'Immutable'; SET commutator := 'std::<'; SET negator := 'std::<='; USING SQL OPERATOR r'>(date,date)'; }; CREATE INFIX OPERATOR std::`>=` (l: std::cal::local_date, r: std::cal::local_date) -> std::bool { CREATE ANNOTATION std::identifier := 'gte'; CREATE ANNOTATION std::description := 'Greater than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::<='; SET negator := 'std::<'; USING SQL OPERATOR r'>=(date,date)'; }; CREATE INFIX OPERATOR std::`<` (l: std::cal::local_date, r: std::cal::local_date) -> std::bool { CREATE ANNOTATION std::identifier := 'lt'; CREATE ANNOTATION std::description := 'Less than.'; SET volatility := 'Immutable'; SET commutator := 'std::>'; SET negator := 'std::>='; USING SQL OPERATOR r'<(date,date)'; }; CREATE INFIX OPERATOR std::`<=` (l: std::cal::local_date, r: std::cal::local_date) -> std::bool { CREATE ANNOTATION std::identifier := 'lte'; CREATE ANNOTATION std::description := 'Less than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::>='; SET negator := 'std::>'; USING 
SQL OPERATOR r'<=(date,date)'; }; CREATE INFIX OPERATOR std::`+` (l: std::cal::local_date, r: std::duration) -> std::cal::local_datetime { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Time interval and date/time addition.'; SET volatility := 'Immutable'; SET commutator := 'std::+'; SET force_return_cast := true; USING SQL $$ SELECT ("l" + "r")::edgedbt.timestamp_t $$; }; CREATE INFIX OPERATOR std::`+` (l: std::duration, r: std::cal::local_date) -> std::cal::local_datetime { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Time interval and date/time addition.'; SET volatility := 'Immutable'; SET commutator := 'std::+'; SET force_return_cast := true; USING SQL $$ SELECT ("l" + "r")::edgedbt.timestamp_t $$; }; CREATE INFIX OPERATOR std::`-` (l: std::cal::local_date, r: std::duration) -> std::cal::local_datetime { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Time interval and date/time subtraction.'; SET volatility := 'Immutable'; SET force_return_cast := true; USING SQL $$ SELECT ("l" - "r")::edgedbt.timestamp_t $$; }; CREATE INFIX OPERATOR std::`+` (l: std::cal::local_date, r: std::cal::relative_duration) -> std::cal::local_datetime { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Time interval and date/time addition.'; SET volatility := 'Immutable'; SET commutator := 'std::+'; SET force_return_cast := true; USING SQL $$ SELECT ("l" + "r")::edgedbt.timestamp_t $$; }; CREATE INFIX OPERATOR std::`+` (l: std::cal::relative_duration, r: std::cal::local_date) -> std::cal::local_datetime { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Time interval and date/time addition.'; SET volatility := 'Immutable'; SET commutator := 'std::+'; SET force_return_cast := true; USING SQL $$ SELECT ("l" + "r")::edgedbt.timestamp_t $$; }; CREATE INFIX OPERATOR std::`-` (l: std::cal::local_date, r: 
std::cal::relative_duration) -> std::cal::local_datetime { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Time interval and date/time subtraction.'; SET volatility := 'Immutable'; SET force_return_cast := true; USING SQL $$ SELECT ("l" - "r")::edgedbt.timestamp_t $$; }; CREATE INFIX OPERATOR std::`+` (l: std::cal::local_date, r: std::cal::date_duration) -> std::cal::local_date { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Time interval and date/time addition.'; SET volatility := 'Immutable'; SET commutator := 'std::+'; SET force_return_cast := true; USING SQL $$ SELECT ("l" + "r")::edgedbt.date_t $$; }; CREATE INFIX OPERATOR std::`+` (l: std::cal::date_duration, r: std::cal::local_date) -> std::cal::local_date { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Time interval and date/time addition.'; SET volatility := 'Immutable'; SET commutator := 'std::+'; SET force_return_cast := true; USING SQL $$ SELECT ("l" + "r")::edgedbt.date_t $$; }; CREATE INFIX OPERATOR std::`-` (l: std::cal::local_date, r: std::cal::date_duration) -> std::cal::local_date { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Time interval and date/time subtraction.'; SET volatility := 'Immutable'; SET force_return_cast := true; USING SQL $$ SELECT ("l" - "r")::edgedbt.date_t $$; }; CREATE INFIX OPERATOR std::`-` (l: std::cal::local_date, r: std::cal::local_date) -> std::cal::date_duration { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Date subtraction.'; SET volatility := 'Immutable'; SET force_return_cast := true; USING SQL $$ SELECT make_interval(0, 0, 0, "l" - "r")::edgedbt.date_duration_t $$; }; ## Operators on std::cal::local_time ## ---------------------------- CREATE INFIX OPERATOR std::`=` (l: std::cal::local_time, r: std::cal::local_time) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; 
CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; SET commutator := 'std::='; SET negator := 'std::!='; USING SQL OPERATOR r'='; }; CREATE INFIX OPERATOR std::`?=` (l: OPTIONAL std::cal::local_time, r: OPTIONAL std::cal::local_time) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`!=` (l: std::cal::local_time, r: std::cal::local_time) -> std::bool { CREATE ANNOTATION std::identifier := 'neq'; CREATE ANNOTATION std::description := 'Compare two values for inequality.'; SET volatility := 'Immutable'; SET commutator := 'std::!='; SET negator := 'std::='; USING SQL OPERATOR r'<>'; }; CREATE INFIX OPERATOR std::`?!=` (l: OPTIONAL std::cal::local_time, r: OPTIONAL std::cal::local_time) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_neq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`>` (l: std::cal::local_time, r: std::cal::local_time) -> std::bool { CREATE ANNOTATION std::identifier := 'gt'; CREATE ANNOTATION std::description := 'Greater than.'; SET volatility := 'Immutable'; SET commutator := 'std::<'; SET negator := 'std::<='; USING SQL OPERATOR r'>'; }; CREATE INFIX OPERATOR std::`>=` (l: std::cal::local_time, r: std::cal::local_time) -> std::bool { CREATE ANNOTATION std::identifier := 'gte'; CREATE ANNOTATION std::description := 'Greater than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::<='; SET negator := 'std::<'; USING SQL OPERATOR r'>='; }; CREATE INFIX OPERATOR std::`<` (l: std::cal::local_time, r: std::cal::local_time) -> std::bool { CREATE ANNOTATION std::identifier := 'lt'; CREATE ANNOTATION std::description := 'Less than.'; SET volatility := 
'Immutable'; SET commutator := 'std::>'; SET negator := 'std::>='; USING SQL OPERATOR r'<'; }; CREATE INFIX OPERATOR std::`<=` (l: std::cal::local_time, r: std::cal::local_time) -> std::bool { CREATE ANNOTATION std::identifier := 'lte'; CREATE ANNOTATION std::description := 'Less than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::>='; SET negator := 'std::>'; USING SQL OPERATOR r'<='; }; CREATE INFIX OPERATOR std::`+` (l: std::cal::local_time, r: std::duration) -> std::cal::local_time { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Time interval and date/time addition.'; SET volatility := 'Immutable'; SET commutator := 'std::+'; USING SQL OPERATOR r'+(time, interval)'; }; CREATE INFIX OPERATOR std::`+` (l: std::duration, r: std::cal::local_time) -> std::cal::local_time { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Time interval and date/time addition.'; SET volatility := 'Immutable'; SET commutator := 'std::+'; USING SQL OPERATOR r'+(interval, time)'; }; CREATE INFIX OPERATOR std::`-` (l: std::cal::local_time, r: std::duration) -> std::cal::local_time { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Time interval and date/time subtraction.'; SET volatility := 'Immutable'; USING SQL OPERATOR r'-(time, interval)'; }; CREATE INFIX OPERATOR std::`+` (l: std::cal::local_time, r: std::cal::relative_duration) -> std::cal::local_time { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Time interval and date/time addition.'; SET volatility := 'Immutable'; SET commutator := 'std::+'; USING SQL OPERATOR r'+(time, interval)'; }; CREATE INFIX OPERATOR std::`+` (l: std::cal::relative_duration, r: std::cal::local_time) -> std::cal::local_time { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Time interval and date/time addition.'; SET volatility := 'Immutable'; SET commutator 
:= 'std::+'; USING SQL OPERATOR r'+(interval, time)'; }; CREATE INFIX OPERATOR std::`-` (l: std::cal::local_time, r: std::cal::relative_duration) -> std::cal::local_time { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Time interval and date/time subtraction.'; SET volatility := 'Immutable'; USING SQL OPERATOR r'-(time, interval)'; }; CREATE INFIX OPERATOR std::`-` (l: std::cal::local_time, r: std::cal::local_time) -> std::cal::relative_duration { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Time subtraction.'; SET volatility := 'Immutable'; SET force_return_cast := true; USING SQL OPERATOR r'-(time, time)'; }; ## Operators on std::cal::relative_duration ## ---------------------------- CREATE INFIX OPERATOR std::`=` (l: std::cal::relative_duration, r: std::cal::relative_duration) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; SET commutator := 'std::='; SET negator := 'std::!='; USING SQL OPERATOR r'=(interval,interval)'; }; CREATE INFIX OPERATOR std::`?=` (l: OPTIONAL std::cal::relative_duration, r: OPTIONAL std::cal::relative_duration) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`!=` (l: std::cal::relative_duration, r: std::cal::relative_duration) -> std::bool { CREATE ANNOTATION std::identifier := 'neq'; CREATE ANNOTATION std::description := 'Compare two values for inequality.'; SET volatility := 'Immutable'; SET commutator := 'std::!='; SET negator := 'std::='; USING SQL OPERATOR r'<>(interval,interval)'; }; CREATE INFIX OPERATOR std::`?!=` ( l: OPTIONAL std::cal::relative_duration, r: OPTIONAL std::cal::relative_duration ) -> std::bool { CREATE ANNOTATION std::identifier := 
'coal_neq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`>` (l: std::cal::relative_duration, r: std::cal::relative_duration) -> std::bool { CREATE ANNOTATION std::identifier := 'gt'; CREATE ANNOTATION std::description := 'Greater than.'; SET volatility := 'Immutable'; SET commutator := 'std::<'; SET negator := 'std::<='; USING SQL OPERATOR r'>(interval,interval)'; }; CREATE INFIX OPERATOR std::`>=` (l: std::cal::relative_duration, r: std::cal::relative_duration) -> std::bool { CREATE ANNOTATION std::identifier := 'gte'; CREATE ANNOTATION std::description := 'Greater than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::<='; SET negator := 'std::<'; USING SQL OPERATOR r'>=(interval,interval)'; }; CREATE INFIX OPERATOR std::`<` (l: std::cal::relative_duration, r: std::cal::relative_duration) -> std::bool { CREATE ANNOTATION std::identifier := 'lt'; CREATE ANNOTATION std::description := 'Less than.'; SET volatility := 'Immutable'; SET commutator := 'std::>'; SET negator := 'std::>='; USING SQL OPERATOR r'<(interval,interval)'; }; CREATE INFIX OPERATOR std::`<=` (l: std::cal::relative_duration, r: std::cal::relative_duration) -> std::bool { CREATE ANNOTATION std::identifier := 'lte'; CREATE ANNOTATION std::description := 'Less than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::>='; SET negator := 'std::>'; USING SQL OPERATOR r'<=(interval,interval)'; }; CREATE INFIX OPERATOR std::`+` (l: std::cal::relative_duration, r: std::cal::relative_duration) -> std::cal::relative_duration { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Time interval addition.'; SET volatility := 'Immutable'; SET commutator := 'std::+'; USING SQL $$ SELECT ("l"::interval + "r"::interval)::edgedbt.relative_duration_t; $$; }; CREATE INFIX OPERATOR std::`-` (l: std::cal::relative_duration, r: 
std::cal::relative_duration) -> std::cal::relative_duration { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Time interval subtraction.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ("l"::interval - "r"::interval)::edgedbt.relative_duration_t; $$; }; CREATE INFIX OPERATOR std::`+` (l: std::cal::date_duration, r: std::cal::date_duration) -> std::cal::date_duration { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Time interval addition.'; SET volatility := 'Immutable'; SET commutator := 'std::+'; USING SQL $$ SELECT ("l" + "r")::edgedbt.date_duration_t; $$; }; CREATE INFIX OPERATOR std::`-` (l: std::cal::date_duration, r: std::cal::date_duration) -> std::cal::date_duration { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Time interval subtraction.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ("l" - "r")::edgedbt.date_duration_t; $$; }; CREATE INFIX OPERATOR std::`+` (l: std::duration, r: std::cal::relative_duration) -> std::cal::relative_duration { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Time interval addition.'; SET volatility := 'Immutable'; SET commutator := 'std::+'; USING SQL $$ SELECT ("l"::interval + "r"::interval)::edgedbt.relative_duration_t; $$; }; CREATE INFIX OPERATOR std::`+` (l: std::cal::relative_duration, r: std::duration) -> std::cal::relative_duration { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Time interval addition.'; SET volatility := 'Immutable'; SET commutator := 'std::+'; USING SQL $$ SELECT ("l"::interval + "r"::interval)::edgedbt.relative_duration_t; $$; }; CREATE INFIX OPERATOR std::`-` (l: std::duration, r: std::cal::relative_duration) -> std::cal::relative_duration { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Time interval subtraction.'; SET volatility := 'Immutable'; USING SQL $$ SELECT 
("l"::interval - "r"::interval)::edgedbt.relative_duration_t; $$; }; CREATE INFIX OPERATOR std::`-` (l: std::cal::relative_duration, r: std::duration) -> std::cal::relative_duration { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Time interval subtraction.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ("l"::interval - "r"::interval)::edgedbt.relative_duration_t; $$; }; CREATE PREFIX OPERATOR std::`-` (v: std::cal::relative_duration) -> std::cal::relative_duration { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Time interval negation.'; SET volatility := 'Immutable'; USING SQL $$ SELECT (-"v"::interval)::edgedbt.relative_duration_t; $$; }; ## Date/time casts ## --------------- CREATE CAST FROM std::cal::local_datetime TO std::cal::local_date { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::cal::local_datetime TO std::cal::local_time { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::cal::local_date TO std::cal::local_datetime { SET volatility := 'Immutable'; USING SQL CAST; # Analogous to implicit cast from int64 to float64. 
ALLOW IMPLICIT; }; CREATE CAST FROM std::str TO std::cal::local_datetime { SET volatility := 'Immutable'; USING SQL FUNCTION 'edgedb.local_datetime_in'; }; CREATE CAST FROM std::str TO std::cal::local_date { SET volatility := 'Immutable'; USING SQL FUNCTION 'edgedb.local_date_in'; }; CREATE CAST FROM std::str TO std::cal::local_time { SET volatility := 'Immutable'; USING SQL FUNCTION 'edgedb.local_time_in'; }; CREATE CAST FROM std::str TO std::cal::relative_duration { SET volatility := 'Immutable'; USING SQL $$ SELECT val::edgedbt.relative_duration_t; $$; }; CREATE CAST FROM std::str TO std::cal::date_duration { SET volatility := 'Immutable'; USING SQL FUNCTION 'edgedb.date_duration_in'; }; CREATE CAST FROM std::cal::local_datetime TO std::str { SET volatility := 'Immutable'; USING SQL $$ SELECT trim(to_json(val)::text, '"'); $$; }; CREATE CAST FROM std::cal::local_date TO std::str { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::cal::local_time TO std::str { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::cal::relative_duration TO std::str { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::cal::date_duration TO std::str { SET volatility := 'Immutable';
# We want the zero date_duration to be canonically represented in the
# lowest date_duration units, i.e. in days.
USING SQL $$ SELECT CASE WHEN (val::text = 'PT0S') THEN 'P0D' ELSE val::text END $$; }; CREATE CAST FROM std::cal::local_datetime TO std::json { SET volatility := 'Immutable'; USING SQL FUNCTION 'to_jsonb'; }; CREATE CAST FROM std::cal::local_date TO std::json { SET volatility := 'Immutable'; USING SQL FUNCTION 'to_jsonb'; }; CREATE CAST FROM std::cal::local_time TO std::json { SET volatility := 'Immutable'; USING SQL FUNCTION 'to_jsonb'; }; CREATE CAST FROM std::cal::relative_duration TO std::json { SET volatility := 'Immutable'; USING SQL FUNCTION 'to_jsonb'; }; CREATE CAST FROM std::cal::date_duration TO std::json { SET volatility := 'Immutable'; USING SQL FUNCTION 'to_jsonb';
# We want the zero date_duration to be canonically represented in the
# lowest date_duration units, i.e. in days.
USING SQL $$ SELECT CASE WHEN (val::text = 'PT0S') THEN to_jsonb('P0D'::text) ELSE to_jsonb(val) END $$; }; CREATE CAST FROM std::json TO std::cal::local_datetime { SET volatility := 'Immutable'; USING SQL $$ SELECT edgedb_VER.local_datetime_in( edgedb_VER.jsonb_extract_scalar(val, 'string', detail => detail) ); $$; }; CREATE CAST FROM std::json TO std::cal::local_date { SET volatility := 'Immutable'; USING SQL $$ SELECT edgedb_VER.local_date_in( edgedb_VER.jsonb_extract_scalar(val, 'string', detail => detail) ); $$; }; CREATE CAST FROM std::json TO std::cal::local_time { SET volatility := 'Immutable'; USING SQL $$ SELECT edgedb_VER.local_time_in( edgedb_VER.jsonb_extract_scalar(val, 'string', detail => detail) ); $$; }; CREATE CAST FROM std::json TO std::cal::relative_duration { SET volatility := 'Immutable'; USING SQL $$ SELECT edgedb_VER.jsonb_extract_scalar( val, 'string', detail => detail )::interval::edgedbt.relative_duration_t; $$; }; CREATE CAST FROM std::json TO std::cal::date_duration { SET volatility := 'Immutable'; USING SQL $$ SELECT edgedb_VER.date_duration_in( edgedb_VER.jsonb_extract_scalar(val, 'string', detail => detail) ); $$; }; CREATE CAST FROM std::duration TO
std::cal::relative_duration { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::cal::relative_duration TO std::duration { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::cal::date_duration TO std::cal::relative_duration {
# Same underlying types that don't require any DST calculations to convert
# into each other.
SET volatility := 'Immutable'; USING SQL CAST;
# Analogous to implicit cast from int64 to float64.
ALLOW IMPLICIT; }; CREATE CAST FROM std::cal::relative_duration TO std::cal::date_duration {
# Same underlying types that don't require any DST calculations to convert
# into each other.
SET volatility := 'Immutable'; USING SQL CAST; };
## Modified functions
## ------------------
CREATE FUNCTION std::datetime_get(dt: std::cal::local_datetime, el: std::str) -> std::float64 { CREATE ANNOTATION std::description := 'Extract a specific element of input datetime by name.'; SET volatility := 'Immutable'; USING SQL $$ SELECT CASE WHEN "el" IN ( 'century', 'day', 'decade', 'dow', 'doy', 'hour', 'isodow', 'isoyear', 'microseconds', 'millennium', 'milliseconds', 'minutes', 'month', 'quarter', 'seconds', 'week', 'year') THEN date_part("el", "dt") WHEN "el" = 'epochseconds' THEN date_part('epoch', "dt") ELSE edgedb_VER.raise( NULL::float, 'invalid_datetime_format', msg => ( 'invalid unit for std::datetime_get: ' || quote_literal("el") ), detail => ( '{"hint":"Supported units: epochseconds, century, ' || 'day, decade, dow, doy, hour, isodow, isoyear, ' || 'microseconds, millennium, milliseconds, minutes, ' || 'month, quarter, seconds, week, year."}' ) ) END $$; }; CREATE FUNCTION std::duration_get(dt: std::cal::date_duration, el: std::str) -> std::float64 { CREATE ANNOTATION std::description := 'Extract a specific element of input duration by name.'; SET volatility := 'Immutable'; USING SQL $$ SELECT CASE WHEN "el" IN ( 'millennium', 'century', 'decade', 'year', 'quarter', 'month', 'day') THEN date_part("el", "dt") WHEN "el" =
'totalseconds' THEN date_part('epoch', "dt") ELSE edgedb_VER.raise( NULL::float, 'invalid_datetime_format', msg => ( 'invalid unit for std::duration_get: ' || quote_literal("el") ), detail => ( '{"hint":"Supported units: ' || 'millennium, century, decade, year, quarter, month, day, ' || 'and totalseconds."}' ) ) END $$; }; CREATE FUNCTION std::duration_get(dt: std::cal::relative_duration, el: std::str) -> std::float64 { CREATE ANNOTATION std::description := 'Extract a specific element of input duration by name.'; SET volatility := 'Immutable'; USING SQL $$ SELECT CASE WHEN "el" IN ( 'millennium', 'century', 'decade', 'year', 'quarter', 'month', 'day', 'hour', 'minutes', 'seconds', 'milliseconds', 'microseconds') THEN date_part("el", "dt") WHEN "el" = 'totalseconds' THEN date_part('epoch', "dt") ELSE edgedb_VER.raise( NULL::float, 'invalid_datetime_format', msg => ( 'invalid unit for std::duration_get: ' || quote_literal("el") ), detail => ( '{"hint":"Supported units: ' || 'millennium, century, decade, year, quarter, month, day, ' || 'hour, minutes, seconds, milliseconds, microseconds, ' || 'and totalseconds."}' ) ) END $$; }; CREATE FUNCTION std::duration_truncate( dt: std::cal::date_duration, unit: std::str ) -> std::cal::date_duration { CREATE ANNOTATION std::description := 'Truncate the input duration to a particular precision.'; SET volatility := 'Immutable'; USING SQL $$ SELECT CASE WHEN "unit" IN ( 'days', 'weeks', 'months', 'years', 'decades', 'centuries') THEN date_trunc("unit", "dt")::edgedbt.relative_duration_t WHEN "unit" = 'quarters' THEN date_trunc('quarter', "dt")::edgedbt.relative_duration_t ELSE edgedb_VER.raise( NULL::edgedbt.relative_duration_t, 'invalid_datetime_format', msg => ( 'invalid unit for std::duration_truncate: ' || quote_literal("unit") ), detail => ( '{"hint":"Supported units: days, weeks, months, ' || 'quarters, years, decades, centuries."}' ) ) END $$; }; CREATE FUNCTION std::duration_truncate( dt: std::cal::relative_duration,
unit: std::str ) -> std::cal::relative_duration { CREATE ANNOTATION std::description := 'Truncate the input duration to a particular precision.'; SET volatility := 'Immutable'; USING SQL $$ SELECT CASE WHEN "unit" IN ( 'microseconds', 'milliseconds', 'seconds', 'minutes', 'hours', 'days', 'weeks', 'months', 'years', 'decades', 'centuries') THEN date_trunc("unit", "dt")::edgedbt.relative_duration_t WHEN "unit" = 'quarters' THEN date_trunc('quarter', "dt")::edgedbt.relative_duration_t ELSE edgedb_VER.raise( NULL::edgedbt.relative_duration_t, 'invalid_datetime_format', msg => ( 'invalid unit for std::duration_truncate: ' || quote_literal("unit") ), detail => ( '{"hint":"Supported units: microseconds, milliseconds, ' || 'seconds, minutes, hours, days, weeks, months, ' || 'quarters, years, decades, centuries."}' ) ) END $$; }; CREATE FUNCTION std::to_str(dt: std::cal::local_datetime, fmt: OPTIONAL str={}) -> std::str { CREATE ANNOTATION std::description := 'Return string representation of the input value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN "fmt" IS NULL THEN trim(to_json("dt")::text, '"') WHEN "fmt" = '' THEN edgedb_VER.raise( NULL::text, 'invalid_parameter_value', msg => 'to_str(): "fmt" argument must be a non-empty string' ) ELSE edgedb_VER.raise_on_null( to_char("dt", "fmt"), 'invalid_parameter_value', msg => 'to_str(): format ''' || "fmt" || ''' is invalid' ) END ) $$; }; CREATE FUNCTION std::to_str(d: std::cal::local_date, fmt: OPTIONAL str={}) -> std::str { CREATE ANNOTATION std::description := 'Return string representation of the input value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN "fmt" IS NULL THEN "d"::text WHEN "fmt" = '' THEN edgedb_VER.raise( NULL::text, 'invalid_parameter_value', msg => 'to_str(): "fmt" argument must be a non-empty string' ) ELSE edgedb_VER.raise_on_null( to_char("d", "fmt"), 'invalid_parameter_value', msg => 'to_str(): format ''' || "fmt" || ''' is invalid' ) END ) $$; }; # Currently 
local time is formatted by composing it with the local # current local date. This at least guarantees that the time # formatting is accessible and consistent with full datetime # formatting, but it exposes current date as well if it is included in # the format. # FIXME: date formatting should not have any special effect. CREATE FUNCTION std::to_str(nt: std::cal::local_time, fmt: OPTIONAL str={}) -> std::str { CREATE ANNOTATION std::description := 'Return string representation of the input value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN "fmt" IS NULL THEN "nt"::text WHEN "fmt" = '' THEN edgedb_VER.raise( NULL::text, 'invalid_parameter_value', msg => 'to_str(): "fmt" argument must be a non-empty string' ) ELSE edgedb_VER.raise_on_null( to_char(date_trunc('day', localtimestamp) + "nt", "fmt"), 'invalid_parameter_value', msg => 'to_str(): format ''' || "fmt" || ''' is invalid' ) END ) $$; }; CREATE FUNCTION std::to_str(rd: std::cal::relative_duration, fmt: OPTIONAL str={}) -> std::str { CREATE ANNOTATION std::description := 'Return string representation of the input value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN "fmt" IS NULL THEN "rd"::text WHEN "fmt" = '' THEN edgedb_VER.raise( NULL::text, 'invalid_parameter_value', msg => 'to_str(): "fmt" argument must be a non-empty string' ) ELSE edgedb_VER.raise_on_null( to_char("rd", "fmt"), 'invalid_parameter_value', msg => 'to_str(): format ''' || "fmt" || ''' is invalid' ) END ) $$; }; CREATE FUNCTION std::to_datetime(local: std::cal::local_datetime, zone: std::str) -> std::datetime { CREATE ANNOTATION std::description := 'Create a `datetime` value.'; # The version of timezone with these arguments is IMMUTABLE. 
SET volatility := 'Immutable'; USING SQL $$ SELECT timezone("zone", "local")::edgedbt.timestamptz_t; $$; }; CREATE FUNCTION std::min(vals: SET OF std::cal::local_datetime) -> OPTIONAL std::cal::local_datetime { CREATE ANNOTATION std::description := 'Return the smallest value of the input set.'; SET volatility := 'Immutable'; SET force_return_cast := true; SET preserves_optionality := true; USING SQL FUNCTION 'min'; }; CREATE FUNCTION std::min(vals: SET OF std::cal::local_date) -> OPTIONAL std::cal::local_date { CREATE ANNOTATION std::description := 'Return the smallest value of the input set.'; SET volatility := 'Immutable'; SET force_return_cast := true; SET preserves_optionality := true; USING SQL FUNCTION 'min'; }; CREATE FUNCTION std::min(vals: SET OF std::cal::local_time) -> OPTIONAL std::cal::local_time { CREATE ANNOTATION std::description := 'Return the smallest value of the input set.'; SET volatility := 'Immutable'; SET force_return_cast := true; SET preserves_optionality := true; USING SQL FUNCTION 'min'; }; CREATE FUNCTION std::min(vals: SET OF std::cal::relative_duration) -> OPTIONAL std::cal::relative_duration { CREATE ANNOTATION std::description := 'Return the smallest value of the input set.'; SET volatility := 'Immutable'; SET force_return_cast := true; SET preserves_optionality := true; USING SQL FUNCTION 'min'; }; CREATE FUNCTION std::min(vals: SET OF std::cal::date_duration) -> OPTIONAL std::cal::date_duration { CREATE ANNOTATION std::description := 'Return the smallest value of the input set.'; SET volatility := 'Immutable'; SET force_return_cast := true; SET preserves_optionality := true; USING SQL FUNCTION 'min'; }; CREATE FUNCTION std::min(vals: SET OF array) -> OPTIONAL array { CREATE ANNOTATION std::description := 'Return the smallest value of the input set.'; SET volatility := 'Immutable'; SET force_return_cast := true; SET preserves_optionality := true; USING SQL FUNCTION 'min'; }; CREATE FUNCTION std::min(vals: SET OF array) -> OPTIONAL 
array { CREATE ANNOTATION std::description := 'Return the smallest value of the input set.'; SET volatility := 'Immutable'; SET force_return_cast := true; SET preserves_optionality := true; USING SQL FUNCTION 'min'; }; CREATE FUNCTION std::min(vals: SET OF array) -> OPTIONAL array { CREATE ANNOTATION std::description := 'Return the smallest value of the input set.'; SET volatility := 'Immutable'; SET force_return_cast := true; SET preserves_optionality := true; USING SQL FUNCTION 'min'; }; CREATE FUNCTION std::min(vals: SET OF array) -> OPTIONAL array { CREATE ANNOTATION std::description := 'Return the smallest value of the input set.'; SET volatility := 'Immutable'; SET force_return_cast := true; SET preserves_optionality := true; USING SQL FUNCTION 'min'; }; CREATE FUNCTION std::min(vals: SET OF array) -> OPTIONAL array { CREATE ANNOTATION std::description := 'Return the smallest value of the input set.'; SET volatility := 'Immutable'; SET force_return_cast := true; SET preserves_optionality := true; USING SQL FUNCTION 'min'; }; CREATE FUNCTION std::max(vals: SET OF std::cal::local_datetime) -> OPTIONAL std::cal::local_datetime { CREATE ANNOTATION std::description := 'Return the greatest value of the input set.'; SET volatility := 'Immutable'; SET force_return_cast := true; SET preserves_optionality := true; USING SQL FUNCTION 'max'; }; CREATE FUNCTION std::max(vals: SET OF std::cal::local_date) -> OPTIONAL std::cal::local_date { CREATE ANNOTATION std::description := 'Return the greatest value of the input set.'; SET volatility := 'Immutable'; SET force_return_cast := true; SET preserves_optionality := true; USING SQL FUNCTION 'max'; }; CREATE FUNCTION std::max(vals: SET OF std::cal::local_time) -> OPTIONAL std::cal::local_time { CREATE ANNOTATION std::description := 'Return the greatest value of the input set.'; SET volatility := 'Immutable'; SET force_return_cast := true; SET preserves_optionality := true; USING SQL FUNCTION 'max'; }; CREATE FUNCTION
std::max(vals: SET OF std::cal::relative_duration) -> OPTIONAL std::cal::relative_duration { CREATE ANNOTATION std::description := 'Return the greatest value of the input set.'; SET volatility := 'Immutable'; SET force_return_cast := true; SET preserves_optionality := true; USING SQL FUNCTION 'max'; }; CREATE FUNCTION std::max(vals: SET OF std::cal::date_duration) -> OPTIONAL std::cal::date_duration { CREATE ANNOTATION std::description := 'Return the greatest value of the input set.'; SET volatility := 'Immutable'; SET force_return_cast := true; SET preserves_optionality := true; USING SQL FUNCTION 'max'; }; CREATE FUNCTION std::max(vals: SET OF array) -> OPTIONAL array { CREATE ANNOTATION std::description := 'Return the greatest value of the input set.'; SET volatility := 'Immutable'; SET force_return_cast := true; SET preserves_optionality := true; USING SQL FUNCTION 'max'; }; CREATE FUNCTION std::max(vals: SET OF array) -> OPTIONAL array { CREATE ANNOTATION std::description := 'Return the greatest value of the input set.'; SET volatility := 'Immutable'; SET force_return_cast := true; SET preserves_optionality := true; USING SQL FUNCTION 'max'; }; CREATE FUNCTION std::max(vals: SET OF array) -> OPTIONAL array { CREATE ANNOTATION std::description := 'Return the greatest value of the input set.'; SET volatility := 'Immutable'; SET force_return_cast := true; SET preserves_optionality := true; USING SQL FUNCTION 'max'; }; CREATE FUNCTION std::max(vals: SET OF array) -> OPTIONAL array { CREATE ANNOTATION std::description := 'Return the greatest value of the input set.'; SET volatility := 'Immutable'; SET force_return_cast := true; SET preserves_optionality := true; USING SQL FUNCTION 'max'; }; CREATE FUNCTION std::max(vals: SET OF array) -> OPTIONAL array { CREATE ANNOTATION std::description := 'Return the greatest value of the input set.'; SET volatility := 'Immutable'; SET force_return_cast := true; SET preserves_optionality := true; USING SQL FUNCTION 'max'; };
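# Usage sketch (illustrative only, not part of the upstream file): the
# std::min and std::max overloads defined above are aggregates over the
# std::cal types, so a query along the following lines should return the
# earliest and the latest date of a set. The literal dates are made-up
# sample values, and `cal::` is assumed to resolve to `std::cal::`.
#
#   SELECT (
#     earliest := std::min({
#       <cal::local_date>'2024-01-15', <cal::local_date>'2024-03-01'
#     }),
#     latest := std::max({
#       <cal::local_date>'2024-01-15', <cal::local_date>'2024-03-01'
#     }),
#   );
#
# Because both aggregates are declared `-> OPTIONAL`, calling either one on
# an empty set yields an empty set rather than an error.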
CREATE FUNCTION std::sum(s: SET OF std::cal::relative_duration) -> std::cal::relative_duration { CREATE ANNOTATION std::description := 'Return the arithmetic sum of values in a set.'; SET volatility := 'Immutable'; SET initial_value := "PT0S"; SET force_return_cast := true; USING SQL FUNCTION 'sum'; }; CREATE FUNCTION std::sum(s: SET OF std::cal::date_duration) -> std::cal::date_duration { CREATE ANNOTATION std::description := 'Return the arithmetic sum of values in a set.'; SET volatility := 'Immutable'; SET initial_value := "PT0S"; SET force_return_cast := true; USING SQL FUNCTION 'sum'; }; ## Range functions # FIXME: These functions introduce the concrete multirange types into the # schema. That's why they exist for each concrete type explicitly and aren't # defined generically for anytype. CREATE FUNCTION std::multirange_unpack( val: multirange, ) -> set of range { SET volatility := 'Immutable'; USING SQL FUNCTION 'unnest'; }; CREATE FUNCTION std::multirange_unpack( val: multirange, ) -> set of range { SET volatility := 'Immutable'; USING SQL FUNCTION 'unnest'; }; CREATE FUNCTION std::range_unpack( val: range, step: std::cal::relative_duration ) -> set of std::cal::local_datetime { SET volatility := 'Immutable'; USING SQL $$ SELECT d::edgedbt.timestamp_t FROM generate_series( ( edgedb_VER.range_lower_validate(val) + ( CASE WHEN lower_inc(val) THEN '0'::interval ELSE step END ) )::timestamptz, ( edgedb_VER.range_upper_validate(val) )::timestamptz, step::interval ) AS d WHERE upper_inc(val) OR d::edgedbt.timestamp_t < upper(val) $$; }; CREATE FUNCTION std::range_unpack( val: range ) -> set of std::cal::local_date { SET volatility := 'Immutable'; USING SQL $$ SELECT generate_series( ( edgedb_VER.range_lower_validate(val) + ( CASE WHEN lower_inc(val) THEN '0'::interval ELSE 'P1D'::interval END ) )::timestamp, ( edgedb_VER.range_upper_validate(val) - ( CASE WHEN upper_inc(val) THEN '0'::interval ELSE 'P1D'::interval END ) )::timestamp, 'P1D'::interval 
)::edgedbt.date_t $$; }; CREATE FUNCTION std::range_unpack( val: range, step: std::cal::date_duration ) -> set of std::cal::local_date { SET volatility := 'Immutable'; USING SQL $$ SELECT generate_series( ( edgedb_VER.range_lower_validate(val) + ( CASE WHEN lower_inc(val) THEN '0'::interval ELSE 'P1D'::interval END ) )::timestamp, ( edgedb_VER.range_upper_validate(val) - ( CASE WHEN upper_inc(val) THEN '0'::interval ELSE 'P1D'::interval END ) )::timestamp, step::interval )::edgedbt.date_t $$; }; # Need to cast edgedbt.date_t to date in order for the @> operator to work. CREATE FUNCTION std::contains( haystack: range, needle: std::cal::local_date ) -> std::bool { SET volatility := 'Immutable'; USING SQL $$ SELECT "haystack" @> ("needle"::date) $$; # Needed to pick up the indexes when used in FILTER. set prefer_subquery_args := true; set impl_is_strict := false; }; CREATE FUNCTION std::contains( haystack: multirange, needle: std::cal::local_date ) -> std::bool { SET volatility := 'Immutable'; USING SQL $$ SELECT "haystack" @> ("needle"::date) $$; # Needed to pick up the indexes when used in FILTER. set prefer_subquery_args := true; set impl_is_strict := false; }; ================================================ FILE: edb/lib/cfg.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2018-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# CREATE MODULE cfg; CREATE MODULE cfg::perm; CREATE PERMISSION cfg::perm::configure_timeout; CREATE PERMISSION cfg::perm::configure_apply_access_policies; CREATE PERMISSION cfg::perm::configure_allow_user_specified_id; CREATE ABSTRACT INHERITABLE ANNOTATION cfg::backend_setting; # If report is set to 'true', that *system* config will be included # in the `system_config` ParameterStatus on each connection. # Non-system config cannot be reported. CREATE ABSTRACT INHERITABLE ANNOTATION cfg::report; CREATE ABSTRACT INHERITABLE ANNOTATION cfg::internal; CREATE ABSTRACT INHERITABLE ANNOTATION cfg::requires_restart; # System config means that config value can only be modified using # CONFIGURE INSTANCE command. System config is therefore *not* included # in the binary protocol state. CREATE ABSTRACT INHERITABLE ANNOTATION cfg::system; CREATE ABSTRACT INHERITABLE ANNOTATION cfg::affects_compilation; # Value is json. "*" means always allowed, other strings mean a permission # that must be held. CREATE ABSTRACT INHERITABLE ANNOTATION cfg::session_cfg_permissions; CREATE SCALAR TYPE cfg::memory EXTENDING std::anyscalar; CREATE SCALAR TYPE cfg::AllowBareDDL EXTENDING enum; CREATE SCALAR TYPE cfg::StoreMigrationSDL EXTENDING enum< AlwaysStore, NeverStore, >; CREATE SCALAR TYPE cfg::ConnectionTransport EXTENDING enum< TCP, TCP_PG, HTTP, SIMPLE_HTTP, HTTP_METRICS, HTTP_HEALTH>; CREATE SCALAR TYPE cfg::QueryCacheMode EXTENDING enum< InMemory, RegInline, PgFunc, Default>; CREATE SCALAR TYPE cfg::QueryStatsOption EXTENDING enum; CREATE ABSTRACT TYPE cfg::ConfigObject EXTENDING std::BaseObject; CREATE ABSTRACT TYPE cfg::AuthMethod EXTENDING cfg::ConfigObject { # Connection transports applicable to this auth entry. # An empty set means "apply to all transports". 
CREATE MULTI PROPERTY transports -> cfg::ConnectionTransport { SET readonly := true; }; }; CREATE TYPE cfg::Trust EXTENDING cfg::AuthMethod; CREATE TYPE cfg::SCRAM EXTENDING cfg::AuthMethod { ALTER PROPERTY transports { SET default := { cfg::ConnectionTransport.TCP }; }; }; CREATE TYPE cfg::JWT EXTENDING cfg::AuthMethod { ALTER PROPERTY transports { SET default := { cfg::ConnectionTransport.HTTP }; }; }; CREATE TYPE cfg::Password EXTENDING cfg::AuthMethod { ALTER PROPERTY transports { SET default := { cfg::ConnectionTransport.SIMPLE_HTTP }; }; }; CREATE TYPE cfg::mTLS EXTENDING cfg::AuthMethod { ALTER PROPERTY transports { SET default := { cfg::ConnectionTransport.HTTP_METRICS, cfg::ConnectionTransport.HTTP_HEALTH, }; }; }; CREATE TYPE cfg::Auth EXTENDING cfg::ConfigObject { CREATE REQUIRED PROPERTY priority -> std::int64 { CREATE CONSTRAINT std::exclusive; SET readonly := true; }; CREATE MULTI PROPERTY user -> std::str { SET readonly := true; SET default := {'*'}; }; CREATE SINGLE LINK method -> cfg::AuthMethod { CREATE CONSTRAINT std::exclusive; SET readonly := true; }; CREATE PROPERTY comment -> std::str { SET readonly := true; }; }; CREATE SCALAR TYPE cfg::SMTPSecurity EXTENDING enum< PlainText, TLS, STARTTLS, STARTTLSOrPlainText, >; CREATE ABSTRACT TYPE cfg::EmailProviderConfig EXTENDING cfg::ConfigObject { CREATE REQUIRED PROPERTY name -> std::str { CREATE CONSTRAINT std::exclusive; CREATE ANNOTATION std::description := "The name of the email provider."; }; }; CREATE TYPE cfg::SMTPProviderConfig EXTENDING cfg::EmailProviderConfig { CREATE PROPERTY sender -> std::str { CREATE ANNOTATION std::description := "\"From\" address of system emails sent for e.g. \ password reset, etc."; }; CREATE PROPERTY host -> std::str { CREATE ANNOTATION std::description := "Host of SMTP server to use for sending emails. 
\ If not set, \"localhost\" will be used."; }; CREATE PROPERTY port -> std::int32 { CREATE ANNOTATION std::description := "Port of SMTP server to use for sending emails. \ If not set, common defaults will be used depending on security: \ 465 for TLS, 587 for STARTTLS, 25 otherwise."; }; CREATE PROPERTY username -> std::str { CREATE ANNOTATION std::description := "Username to login as after connected to SMTP server."; }; CREATE PROPERTY password -> std::str { SET secret := true; CREATE ANNOTATION std::description := "Password for login after connected to SMTP server."; }; CREATE REQUIRED PROPERTY security -> cfg::SMTPSecurity { SET default := cfg::SMTPSecurity.STARTTLSOrPlainText; CREATE ANNOTATION std::description := "Security mode of the connection to SMTP server. \ By default, initiate a STARTTLS upgrade if supported by the \ server, or fallback to PlainText."; }; CREATE REQUIRED PROPERTY validate_certs -> std::bool { SET default := true; CREATE ANNOTATION std::description := "Determines if SMTP server certificates are validated."; }; CREATE REQUIRED PROPERTY timeout_per_email -> std::duration { SET default := '60 seconds'; CREATE ANNOTATION std::description := "Maximum time to send an email, including retry attempts."; }; CREATE REQUIRED PROPERTY timeout_per_attempt -> std::duration { SET default := '15 seconds'; CREATE ANNOTATION std::description := "Maximum time for each SMTP request."; }; }; CREATE ABSTRACT TYPE cfg::AbstractConfig extending cfg::ConfigObject; CREATE ABSTRACT TYPE cfg::ExtensionConfig EXTENDING cfg::ConfigObject { CREATE REQUIRED SINGLE LINK cfg -> cfg::AbstractConfig { CREATE DELEGATED CONSTRAINT std::exclusive; }; }; ALTER TYPE cfg::AbstractConfig { CREATE MULTI LINK extensions := . 
std::duration { CREATE ANNOTATION cfg::system := 'true'; CREATE ANNOTATION cfg::report := 'true'; CREATE ANNOTATION std::description := 'How long client connections can stay inactive before being \ closed by the server.'; SET default := '60 seconds'; }; CREATE REQUIRED PROPERTY default_transaction_isolation -> sys::TransactionIsolation { CREATE ANNOTATION cfg::affects_compilation := 'true'; CREATE ANNOTATION cfg::backend_setting := '"default_transaction_isolation"'; CREATE ANNOTATION cfg::session_cfg_permissions := '"*"'; CREATE ANNOTATION std::description := 'Controls the default isolation level of each new transaction, \ including implicit transactions. Defaults to `Serializable`. \ Note that changing this to a lower isolation level implies \ that the transactions are also read-only by default regardless \ of the value of the `default_transaction_access_mode` setting.'; SET default := sys::TransactionIsolation.Serializable; }; CREATE REQUIRED PROPERTY default_transaction_access_mode -> sys::TransactionAccessMode { CREATE ANNOTATION cfg::affects_compilation := 'true'; CREATE ANNOTATION cfg::session_cfg_permissions := '"*"'; CREATE ANNOTATION std::description := 'Controls the default read-only status of each new transaction, \ including implicit transactions. Defaults to `ReadWrite`. \ Note that if `default_transaction_isolation` is set to any value \ other than Serializable this parameter is implied to be \ `ReadOnly` regardless of the actual value.'; SET default := sys::TransactionAccessMode.ReadWrite; }; CREATE REQUIRED PROPERTY default_transaction_deferrable -> sys::TransactionDeferrability { CREATE ANNOTATION cfg::backend_setting := '"default_transaction_deferrable"'; CREATE ANNOTATION cfg::session_cfg_permissions := '"*"'; CREATE ANNOTATION std::description := 'Controls the default deferrable status of each new transaction. \ It currently has no effect on read-write transactions or those \ operating at isolation levels lower than `Serializable`. 
\ The default is `NotDeferrable`.'; SET default := sys::TransactionDeferrability.NotDeferrable; }; CREATE REQUIRED PROPERTY session_idle_transaction_timeout -> std::duration { CREATE ANNOTATION cfg::backend_setting := '"idle_in_transaction_session_timeout"'; CREATE ANNOTATION cfg::session_cfg_permissions := '"cfg::perm::configure_timeout"'; CREATE ANNOTATION std::description := 'How long client connections can stay inactive while in a \ transaction.'; SET default := '10 seconds'; }; CREATE REQUIRED PROPERTY query_execution_timeout -> std::duration { CREATE ANNOTATION cfg::backend_setting := '"statement_timeout"'; CREATE ANNOTATION cfg::session_cfg_permissions := '"cfg::perm::configure_timeout"'; CREATE ANNOTATION std::description := 'How long an individual query can run before being aborted.'; }; CREATE REQUIRED PROPERTY listen_port -> std::int32 { CREATE ANNOTATION cfg::system := 'true'; CREATE ANNOTATION std::description := 'The TCP port the server listens on.'; SET default := 5656; # Really we want a uint16, but oh well CREATE CONSTRAINT std::min_value(0); CREATE CONSTRAINT std::max_value(65535); }; CREATE MULTI PROPERTY listen_addresses -> std::str { CREATE ANNOTATION cfg::system := 'true'; CREATE ANNOTATION std::description := 'The TCP/IP address(es) on which the server is to listen for \ connections from client applications.'; }; CREATE MULTI LINK auth -> cfg::Auth { CREATE ANNOTATION cfg::system := 'true'; }; CREATE MULTI LINK email_providers -> cfg::EmailProviderConfig { CREATE ANNOTATION std::description := 'The list of email providers that can be used to send emails.'; }; CREATE PROPERTY current_email_provider_name -> std::str { CREATE ANNOTATION std::description := 'The name of the current email provider.'; }; CREATE PROPERTY allow_dml_in_functions -> std::bool { SET default := false; CREATE ANNOTATION cfg::affects_compilation := 'true'; CREATE ANNOTATION cfg::internal := 'true'; }; CREATE PROPERTY allow_bare_ddl -> cfg::AllowBareDDL { SET default := 
cfg::AllowBareDDL.AlwaysAllow; CREATE ANNOTATION cfg::affects_compilation := 'true'; CREATE ANNOTATION std::description := 'Whether DDL is allowed to be executed outside a migration.'; }; CREATE PROPERTY store_migration_sdl -> cfg::StoreMigrationSDL { SET default := cfg::StoreMigrationSDL.NeverStore; CREATE ANNOTATION cfg::affects_compilation := 'true'; CREATE ANNOTATION std::description := 'When to store resulting SDL of a Migration. This may be slow.'; }; CREATE PROPERTY apply_access_policies -> std::bool { SET default := true; CREATE ANNOTATION cfg::affects_compilation := 'true'; CREATE ANNOTATION cfg::session_cfg_permissions := '"cfg::perm::configure_apply_access_policies"'; CREATE ANNOTATION std::description := 'Whether access policies will be applied when running queries.'; }; CREATE PROPERTY apply_access_policies_pg -> std::bool { SET default := true; CREATE ANNOTATION cfg::affects_compilation := 'false'; CREATE ANNOTATION cfg::session_cfg_permissions := '"cfg::perm::configure_apply_access_policies"'; CREATE ANNOTATION std::description := 'Whether access policies will be applied when running queries over \ SQL adapter.'; }; CREATE PROPERTY allow_user_specified_id -> std::bool { SET default := false; CREATE ANNOTATION cfg::affects_compilation := 'true'; CREATE ANNOTATION cfg::session_cfg_permissions := '"cfg::perm::configure_allow_user_specified_id"'; CREATE ANNOTATION std::description := 'Whether inserts are allowed to set the \'id\' property.'; }; CREATE PROPERTY simple_scoping -> std::bool { CREATE ANNOTATION cfg::affects_compilation := 'true'; CREATE ANNOTATION cfg::session_cfg_permissions := '"*"'; CREATE ANNOTATION std::description := 'Whether to use the new simple scoping behavior \ (disable path factoring)'; }; CREATE PROPERTY warn_old_scoping -> std::bool { CREATE ANNOTATION cfg::affects_compilation := 'true'; CREATE ANNOTATION cfg::session_cfg_permissions := '"*"'; CREATE ANNOTATION std::description := 'Whether to warn when depending on old scoping 
behavior.'; }; CREATE MULTI PROPERTY cors_allow_origins -> std::str { CREATE ANNOTATION std::description := 'List of origins that can be returned in the \ Access-Control-Allow-Origin HTTP header'; }; CREATE PROPERTY auto_rebuild_query_cache -> std::bool { SET default := true; CREATE ANNOTATION std::description := 'Recompile all cached queries on DDL if enabled.'; }; CREATE PROPERTY auto_rebuild_query_cache_timeout -> std::duration { CREATE ANNOTATION std::description := 'Maximum time to spend recompiling cached queries on DDL.'; SET default := '60 seconds'; }; CREATE PROPERTY query_cache_mode -> cfg::QueryCacheMode { SET default := cfg::QueryCacheMode.Default; CREATE ANNOTATION cfg::affects_compilation := 'true'; CREATE ANNOTATION std::description := 'Where the query cache is finally stored'; }; CREATE PROPERTY query_cache_size -> std::int32 { SET default := 1000; CREATE ANNOTATION cfg::system := 'true'; CREATE ANNOTATION cfg::requires_restart := 'true'; CREATE ANNOTATION std::description := 'Maximum number of queries to cache in the query cache'; }; # HTTP Worker Configuration CREATE PROPERTY http_max_connections -> std::int64 { SET default := 10; CREATE ANNOTATION std::description := 'The maximum number of concurrent HTTP connections.'; CREATE ANNOTATION cfg::system := 'true'; }; # Exposed backend settings follow. # When exposing a new setting, remember to modify # the _read_sys_config function to select the value # from pg_settings in the config_backend CTE. 
CREATE PROPERTY shared_buffers -> cfg::memory { CREATE ANNOTATION cfg::system := 'true'; CREATE ANNOTATION cfg::backend_setting := '"shared_buffers"'; CREATE ANNOTATION cfg::requires_restart := 'true'; CREATE ANNOTATION std::description := 'The amount of memory used for shared memory buffers.'; }; CREATE PROPERTY query_work_mem -> cfg::memory { CREATE ANNOTATION cfg::system := 'true'; CREATE ANNOTATION cfg::backend_setting := '"work_mem"'; CREATE ANNOTATION std::description := 'The amount of memory used by internal query operations such as \ sorting.'; }; CREATE PROPERTY maintenance_work_mem -> cfg::memory { CREATE ANNOTATION cfg::system := 'true'; CREATE ANNOTATION cfg::backend_setting := '"maintenance_work_mem"'; CREATE ANNOTATION std::description := 'The amount of memory used by operations such as \ CREATE INDEX.'; }; CREATE PROPERTY effective_cache_size -> cfg::memory { CREATE ANNOTATION cfg::system := 'true'; CREATE ANNOTATION cfg::backend_setting := '"effective_cache_size"'; CREATE ANNOTATION std::description := 'An estimate of the effective size of the disk cache available \ to a single query.'; }; CREATE PROPERTY effective_io_concurrency -> std::int64 { CREATE ANNOTATION cfg::system := 'true'; CREATE ANNOTATION cfg::backend_setting := '"effective_io_concurrency"'; CREATE ANNOTATION std::description := 'The number of concurrent disk I/O operations that can be \ executed simultaneously.'; }; CREATE PROPERTY default_statistics_target -> std::int64 { CREATE ANNOTATION cfg::system := 'true'; CREATE ANNOTATION cfg::backend_setting := '"default_statistics_target"'; CREATE ANNOTATION std::description := 'The default data statistics target for the planner.'; }; CREATE PROPERTY force_database_error -> std::str { SET default := 'false'; CREATE ANNOTATION cfg::affects_compilation := 'true'; CREATE ANNOTATION cfg::session_cfg_permissions := '"*"'; CREATE ANNOTATION std::description := 'A hook to force all queries to produce an error.'; }; CREATE REQUIRED PROPERTY 
_pg_prepared_statement_cache_size -> std::int16 { CREATE ANNOTATION cfg::system := 'true'; CREATE ANNOTATION std::description := 'The maximum number of prepared statements each backend \ connection could hold at the same time.'; CREATE CONSTRAINT std::min_value(1); SET default := 100; }; CREATE PROPERTY track_query_stats -> cfg::QueryStatsOption { CREATE ANNOTATION cfg::backend_setting := '"edb_stat_statements.track"'; CREATE ANNOTATION std::description := 'Select what queries are tracked in sys::QueryStats'; }; }; CREATE TYPE cfg::Config EXTENDING cfg::AbstractConfig; CREATE TYPE cfg::InstanceConfig EXTENDING cfg::AbstractConfig; CREATE TYPE cfg::DatabaseConfig EXTENDING cfg::AbstractConfig; CREATE ALIAS cfg::BranchConfig := cfg::DatabaseConfig; CREATE FUNCTION cfg::get_config_json( NAMED ONLY sources: OPTIONAL array = {}, NAMED ONLY max_source: OPTIONAL std::str = {} ) -> std::json { USING SQL $$ SELECT coalesce( jsonb_object_agg( cfg.name, -- Redact config values from extension configs, since -- they might contain secrets, and it isn't worth the -- trouble right now to care about which ones actually do. 
(CASE WHEN cfg.name LIKE '%::%' AND cfg.value != 'null'::jsonb THEN jsonb_set(to_jsonb(cfg), '{value}', '{"redacted": true}'::jsonb) ELSE to_jsonb(cfg) END) ), '{}'::jsonb ) FROM edgedb_VER._read_sys_config( sources::edgedb._sys_config_source_t[], max_source::edgedb._sys_config_source_t ) AS cfg $$; }; CREATE FUNCTION cfg::_quote(text: std::str) -> std::str { SET volatility := 'Immutable'; SET internal := true; USING SQL $$ SELECT replace(quote_literal(text), '''''', '\\''') $$ }; CREATE CAST FROM std::int64 TO cfg::memory { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM cfg::memory TO std::int64 { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::str TO cfg::memory { SET volatility := 'Immutable'; USING SQL FUNCTION 'edgedb.str_to_cfg_memory'; }; CREATE CAST FROM cfg::memory TO std::str { SET volatility := 'Immutable'; USING SQL FUNCTION 'edgedb.cfg_memory_to_str'; }; CREATE CAST FROM std::json TO cfg::memory { SET volatility := 'Immutable'; USING SQL $$ SELECT edgedb_VER.str_to_cfg_memory( edgedb_VER.jsonb_extract_scalar(val, 'string', detail => detail) ) $$; }; CREATE CAST FROM cfg::memory TO std::json { SET volatility := 'Immutable'; USING SQL $$ SELECT to_jsonb(edgedb_VER.cfg_memory_to_str(val)) $$; }; CREATE INFIX OPERATOR std::`=` (l: cfg::memory, r: cfg::memory) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; SET commutator := 'std::='; SET negator := 'std::!='; USING SQL OPERATOR r'=(int8,int8)'; }; CREATE INFIX OPERATOR std::`?=` (l: OPTIONAL cfg::memory, r: OPTIONAL cfg::memory) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`!=` (l: cfg::memory, r: cfg::memory) -> std::bool { CREATE ANNOTATION std::identifier 
:= 'neq'; CREATE ANNOTATION std::description := 'Compare two values for inequality.'; SET volatility := 'Immutable'; SET commutator := 'std::!='; SET negator := 'std::='; USING SQL OPERATOR r'<>(int8,int8)'; }; CREATE INFIX OPERATOR std::`?!=` (l: OPTIONAL cfg::memory, r: OPTIONAL cfg::memory) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_neq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`>` (l: cfg::memory, r: cfg::memory) -> std::bool { CREATE ANNOTATION std::identifier := 'gt'; CREATE ANNOTATION std::description := 'Greater than.'; SET volatility := 'Immutable'; SET commutator := 'std::<'; SET negator := 'std::<='; USING SQL OPERATOR r'>(int8,int8)'; }; CREATE INFIX OPERATOR std::`>=` (l: cfg::memory, r: cfg::memory) -> std::bool { CREATE ANNOTATION std::identifier := 'gte'; CREATE ANNOTATION std::description := 'Greater than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::<='; SET negator := 'std::<'; USING SQL OPERATOR r'>=(int8,int8)'; }; CREATE INFIX OPERATOR std::`<` (l: cfg::memory, r: cfg::memory) -> std::bool { CREATE ANNOTATION std::identifier := 'lt'; CREATE ANNOTATION std::description := 'Less than.'; SET volatility := 'Immutable'; SET commutator := 'std::>'; SET negator := 'std::>='; USING SQL OPERATOR r'<(int8,int8)'; }; CREATE INFIX OPERATOR std::`<=` (l: cfg::memory, r: cfg::memory) -> std::bool { CREATE ANNOTATION std::identifier := 'lte'; CREATE ANNOTATION std::description := 'Less than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::>='; SET negator := 'std::>'; USING SQL OPERATOR r'<=(int8,int8)'; }; ================================================ FILE: edb/lib/enc.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright EdgeDB Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # CREATE MODULE std::enc; CREATE SCALAR TYPE std::enc::Base64Alphabet EXTENDING enum; CREATE FUNCTION std::enc::base64_encode( data: std::bytes, NAMED ONLY alphabet: std::enc::Base64Alphabet = std::enc::Base64Alphabet.standard, NAMED ONLY padding: std::bool = true, ) -> std::str { CREATE ANNOTATION std::description := 'Encode given data as a base64 string'; SET volatility := 'Immutable'; USING SQL $$ SELECT CASE WHEN "alphabet" = 'standard' AND "padding" THEN pg_catalog.translate( pg_catalog.encode("data", 'base64'), E'\n', '' ) WHEN "alphabet" = 'standard' AND NOT "padding" THEN pg_catalog.translate( pg_catalog.rtrim( pg_catalog.encode("data", 'base64'), '=' ), E'\n', '' ) WHEN "alphabet" = 'urlsafe' AND "padding" THEN pg_catalog.translate( pg_catalog.encode("data", 'base64'), E'+/\n', '-_' ) WHEN "alphabet" = 'urlsafe' AND NOT "padding" THEN pg_catalog.translate( pg_catalog.rtrim( pg_catalog.encode("data", 'base64'), '=' ), E'+/\n', '-_' ) ELSE edgedb_VER.raise( NULL::text, 'invalid_parameter_value', msg => ( 'invalid alphabet for std::enc::base64_encode: ' || pg_catalog.quote_literal("alphabet") ), detail => ( '{"hint":"Supported alphabets: standard, urlsafe."}' ) ) END $$; }; CREATE FUNCTION std::enc::base64_decode( data: std::str, NAMED ONLY alphabet: std::enc::Base64Alphabet = std::enc::Base64Alphabet.standard, NAMED ONLY padding: std::bool = true, ) -> std::bytes { CREATE ANNOTATION std::description := 'Decode the 
base64-encoded byte string and return decoded bytes.'; SET volatility := 'Immutable'; USING SQL $$ SELECT CASE WHEN "alphabet" = 'standard' AND "padding" THEN pg_catalog.decode("data", 'base64') WHEN "alphabet" = 'standard' AND NOT "padding" THEN pg_catalog.decode( edgedb_VER.pad_base64_string("data"), 'base64' ) WHEN "alphabet" = 'urlsafe' AND "padding" THEN pg_catalog.decode( pg_catalog.translate("data", '-_', '+/'), 'base64' ) WHEN "alphabet" = 'urlsafe' AND NOT "padding" THEN pg_catalog.decode( edgedb_VER.pad_base64_string( pg_catalog.translate("data", '-_', '+/') ), 'base64' ) ELSE edgedb_VER.raise( NULL::bytea, 'invalid_parameter_value', msg => ( 'invalid alphabet for std::enc::base64_decode: ' || pg_catalog.quote_literal("alphabet") ), detail => ( '{"hint":"Supported alphabets: standard, urlsafe."}' ) ) END $$; }; ================================================ FILE: edb/lib/ext/ai.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2023-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# CREATE EXTENSION PACKAGE ai VERSION '1.0' { set ext_module := "ext::ai"; set dependencies := ["pgvector>=0.7"]; create module ext::ai; create module ext::ai::perm; create permission ext::ai::perm::provider_call; create permission ext::ai::perm::chat_prompt_read; create permission ext::ai::perm::chat_prompt_write; create scalar type ext::ai::ProviderAPIStyle extending enum; create abstract type ext::ai::ProviderConfig extending cfg::ConfigObject { create required property name: std::str { set readonly := true; create constraint exclusive; create annotation std::description := "Unique provider name."; }; create required property display_name: std::str { set readonly := true; create annotation std::description := "Human-friendly provider name."; }; create required property api_url: std::str { set readonly := true; create annotation std::description := "Provider API URL."; }; create property client_id: std::str { set readonly := true; create annotation std::description := "ID for client provided by model API vendor."; }; create required property secret: std::str { set readonly := true; set secret := true; create annotation std::description := "Secret provided by model API vendor."; }; create required property api_style: ext::ai::ProviderAPIStyle { create annotation std::description := "The API style exposed by this provider."; }; }; create type ext::ai::CustomProviderConfig extending ext::ai::ProviderConfig { alter property display_name { set default := 'Custom'; }; alter property api_style { set default := ext::ai::ProviderAPIStyle.OpenAI; }; }; create type ext::ai::OpenAIProviderConfig extending ext::ai::ProviderConfig { alter property name { set protected := true; set default := 'builtin::openai'; }; alter property display_name { set protected := true; set default := 'OpenAI'; }; alter property api_url { set default := 'https://api.openai.com/v1' }; alter property api_style { set protected := true; set default := ext::ai::ProviderAPIStyle.OpenAI; }; }; create type 
ext::ai::MistralProviderConfig extending ext::ai::ProviderConfig { alter property name { set protected := true; set default := 'builtin::mistral'; }; alter property display_name { set protected := true; set default := 'Mistral'; }; alter property api_url { set default := 'https://api.mistral.ai/v1' }; alter property api_style { set protected := true; set default := ext::ai::ProviderAPIStyle.OpenAI; }; }; create type ext::ai::AnthropicProviderConfig extending ext::ai::ProviderConfig { alter property name { set protected := true; set default := 'builtin::anthropic'; }; alter property display_name { set protected := true; set default := 'Anthropic'; }; alter property api_url { set default := 'https://api.anthropic.com/v1' }; alter property api_style { set protected := true; set default := ext::ai::ProviderAPIStyle.Anthropic; }; }; create type ext::ai::OllamaProviderConfig extending ext::ai::ProviderConfig { alter property name { set protected := true; set default := 'builtin::ollama'; }; alter property display_name { set protected := true; set default := 'Ollama'; }; alter property api_url { set default := 'http://localhost:11434/api' }; alter property secret { set default := '' }; alter property api_style { set protected := true; set default := ext::ai::ProviderAPIStyle.Ollama; }; }; create type ext::ai::Config extending cfg::ExtensionConfig { create required property indexer_naptime: std::duration { set default := '10s'; create annotation std::description := ' Specifies the minimum delay between runs of the deferred ext::ai::index indexer on any given branch. 
'; }; create multi link providers: ext::ai::ProviderConfig { create annotation std::description := "AI model provider configurations."; }; }; create abstract inheritable annotation ext::ai::model_name; create abstract inheritable annotation ext::ai::model_provider; create abstract type ext::ai::Model extending std::BaseObject { create annotation ext::ai::model_name := ""; create annotation ext::ai::model_provider := ""; }; create abstract inheritable annotation ext::ai::embedding_model_max_input_tokens; create abstract inheritable annotation ext::ai::embedding_model_max_batch_tokens; create abstract inheritable annotation ext::ai::embedding_model_max_batch_size; create abstract inheritable annotation ext::ai::embedding_model_max_output_dimensions; create abstract inheritable annotation ext::ai::embedding_model_supports_shortening; create abstract type ext::ai::EmbeddingModel extending ext::ai::Model { create annotation ext::ai::embedding_model_max_input_tokens := ""; # for now, use the openai batch limit as the default. 
create annotation ext::ai::embedding_model_max_batch_tokens := "8191"; create annotation ext::ai::embedding_model_max_batch_size := ""; create annotation ext::ai::embedding_model_max_output_dimensions := ""; create annotation ext::ai::embedding_model_supports_shortening := "false"; }; create abstract inheritable annotation ext::ai::text_gen_model_context_window; create abstract type ext::ai::TextGenerationModel extending ext::ai::Model { create annotation ext::ai::text_gen_model_context_window := ""; }; # OpenAI models create abstract type ext::ai::OpenAITextEmbedding3SmallModel extending ext::ai::EmbeddingModel { alter annotation ext::ai::model_name := "text-embedding-3-small"; alter annotation ext::ai::model_provider := "builtin::openai"; alter annotation ext::ai::embedding_model_max_input_tokens := "8191"; alter annotation ext::ai::embedding_model_max_batch_tokens := "8191"; alter annotation ext::ai::embedding_model_max_output_dimensions := "1536"; alter annotation ext::ai::embedding_model_supports_shortening := "true"; }; create abstract type ext::ai::OpenAITextEmbedding3LargeModel extending ext::ai::EmbeddingModel { alter annotation ext::ai::model_name := "text-embedding-3-large"; alter annotation ext::ai::model_provider := "builtin::openai"; alter annotation ext::ai::embedding_model_max_input_tokens := "8191"; alter annotation ext::ai::embedding_model_max_batch_tokens := "8191"; # Note: ext::pgvector is currently limited to 2000 dimensions, # so returned embeddings will be automatically truncated if # pgvector is used as the index implementation. 
alter annotation ext::ai::embedding_model_max_output_dimensions := "3072"; alter annotation ext::ai::embedding_model_supports_shortening := "true"; }; create abstract type ext::ai::OpenAITextEmbeddingAda002Model extending ext::ai::EmbeddingModel { alter annotation ext::ai::model_name := "text-embedding-ada-002"; alter annotation ext::ai::model_provider := "builtin::openai"; alter annotation ext::ai::embedding_model_max_input_tokens := "8191"; alter annotation ext::ai::embedding_model_max_batch_tokens := "8191"; alter annotation ext::ai::embedding_model_max_output_dimensions := "1536"; }; create abstract type ext::ai::OpenAIGPT_3_5_TurboModel extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "gpt-3.5-turbo"; alter annotation ext::ai::model_provider := "builtin::openai"; alter annotation ext::ai::text_gen_model_context_window := "16385"; }; create abstract type ext::ai::OpenAIGPT_4_TurboPreviewModel extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "gpt-4-turbo-preview"; alter annotation ext::ai::model_provider := "builtin::openai"; alter annotation ext::ai::text_gen_model_context_window := "128000"; }; create abstract type ext::ai::OpenAIGPT_4_TurboModel extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "gpt-4-turbo"; alter annotation ext::ai::model_provider := "builtin::openai"; alter annotation ext::ai::text_gen_model_context_window := "128000"; }; create abstract type ext::ai::OpenAIGPT_4o_Model extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "gpt-4o"; alter annotation ext::ai::model_provider := "builtin::openai"; alter annotation ext::ai::text_gen_model_context_window := "128000"; }; create abstract type ext::ai::OpenAIGPT_4o_MiniModel extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "gpt-4o-mini"; alter annotation ext::ai::model_provider := "builtin::openai"; alter annotation 
ext::ai::text_gen_model_context_window := "128000"; }; create abstract type ext::ai::OpenAIGPT_4_Model extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "gpt-4"; alter annotation ext::ai::model_provider := "builtin::openai"; alter annotation ext::ai::text_gen_model_context_window := "128000"; }; create abstract type ext::ai::OpenAI_O1_PreviewModel extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "o1-preview"; alter annotation ext::ai::model_provider := "builtin::openai"; alter annotation ext::ai::text_gen_model_context_window := "128000"; }; create abstract type ext::ai::OpenAI_O1_MiniModel extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "o1-mini"; alter annotation ext::ai::model_provider := "builtin::openai"; alter annotation ext::ai::text_gen_model_context_window := "128000"; }; # Mistral models create abstract type ext::ai::MistralEmbedModel extending ext::ai::EmbeddingModel { alter annotation ext::ai::model_name := "mistral-embed"; alter annotation ext::ai::model_provider := "builtin::mistral"; alter annotation ext::ai::embedding_model_max_input_tokens := "8192"; alter annotation ext::ai::embedding_model_max_batch_tokens := "16384"; alter annotation ext::ai::embedding_model_max_output_dimensions := "1024"; }; create abstract type ext::ai::MistralSmallModel extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "mistral-small-latest"; alter annotation ext::ai::model_provider := "builtin::mistral"; alter annotation ext::ai::text_gen_model_context_window := "32000"; }; # Mistral legacy model create abstract type ext::ai::MistralMediumModel extending ext::ai::TextGenerationModel { create annotation std::deprecated := "This model is noted as a legacy model in the Mistral docs."; alter annotation ext::ai::model_name := "mistral-medium-latest"; alter annotation ext::ai::model_provider := "builtin::mistral"; alter annotation 
ext::ai::text_gen_model_context_window := "32000"; }; create abstract type ext::ai::MistralLargeModel extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "mistral-large-latest"; alter annotation ext::ai::model_provider := "builtin::mistral"; alter annotation ext::ai::text_gen_model_context_window := "128000"; }; create abstract type ext::ai::PixtralLargeModel extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "pixtral-large-latest"; alter annotation ext::ai::model_provider := "builtin::mistral"; alter annotation ext::ai::text_gen_model_context_window := "128000"; }; create abstract type ext::ai::Ministral_3B_Model extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "ministral-3b-latest"; alter annotation ext::ai::model_provider := "builtin::mistral"; alter annotation ext::ai::text_gen_model_context_window := "128000"; }; create abstract type ext::ai::Ministral_8B_Model extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "ministral-8b-latest"; alter annotation ext::ai::model_provider := "builtin::mistral"; alter annotation ext::ai::text_gen_model_context_window := "128000"; }; create abstract type ext::ai::CodestralModel extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "codestral-latest"; alter annotation ext::ai::model_provider := "builtin::mistral"; alter annotation ext::ai::text_gen_model_context_window := "32000"; }; # Mistral free models create abstract type ext::ai::PixtralModel extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "pixtral-12b-2409"; alter annotation ext::ai::model_provider := "builtin::mistral"; alter annotation ext::ai::text_gen_model_context_window := "128000"; }; create abstract type ext::ai::MistralNemo extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "open-mistral-nemo"; alter annotation ext::ai::model_provider := 
"builtin::mistral"; alter annotation ext::ai::text_gen_model_context_window := "128000"; }; create abstract type ext::ai::CodestralMamba extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "open-codestral-mamba"; alter annotation ext::ai::model_provider := "builtin::mistral"; alter annotation ext::ai::text_gen_model_context_window := "256000"; }; # Anthropic models # Anthropic most intelligent model create abstract type ext::ai::AnthropicClaude_3_5_SonnetModel extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "claude-3-5-sonnet-latest"; alter annotation ext::ai::model_provider := "builtin::anthropic"; alter annotation ext::ai::text_gen_model_context_window := "200000"; }; # Anthropic fastest model create abstract type ext::ai::AnthropicClaude_3_5_HaikuModel extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "claude-3-5-haiku-latest"; alter annotation ext::ai::model_provider := "builtin::anthropic"; alter annotation ext::ai::text_gen_model_context_window := "200000"; }; create abstract type ext::ai::AnthropicClaude3HaikuModel extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "claude-3-haiku-20240307"; alter annotation ext::ai::model_provider := "builtin::anthropic"; alter annotation ext::ai::text_gen_model_context_window := "200000"; }; create abstract type ext::ai::AnthropicClaude3SonnetModel extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "claude-3-sonnet-20240229"; alter annotation ext::ai::model_provider := "builtin::anthropic"; alter annotation ext::ai::text_gen_model_context_window := "200000"; }; create abstract type ext::ai::AnthropicClaude3OpusModel extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "claude-3-opus-latest"; alter annotation ext::ai::model_provider := "builtin::anthropic"; alter annotation ext::ai::text_gen_model_context_window := "200000"; }; # Ollama embedding 
models create abstract type ext::ai::OllamaLlama_3_2_Model extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "llama3.2"; alter annotation ext::ai::model_provider := "builtin::ollama"; alter annotation ext::ai::text_gen_model_context_window := "131072"; }; create abstract type ext::ai::OllamaLlama_3_3_Model extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "llama3.3"; alter annotation ext::ai::model_provider := "builtin::ollama"; alter annotation ext::ai::text_gen_model_context_window := "131072"; }; create abstract type ext::ai::OllamaNomicEmbedTextModel extending ext::ai::EmbeddingModel { alter annotation ext::ai::model_name := "nomic-embed-text"; alter annotation ext::ai::model_provider := "builtin::ollama"; alter annotation ext::ai::embedding_model_max_input_tokens := "2048"; alter annotation ext::ai::embedding_model_max_batch_tokens := "2048"; alter annotation ext::ai::embedding_model_max_output_dimensions := "768"; }; create abstract type ext::ai::OllamaBgeM3Model extending ext::ai::EmbeddingModel { alter annotation ext::ai::model_name := "bge-m3"; alter annotation ext::ai::model_provider := "builtin::ollama"; alter annotation ext::ai::embedding_model_max_input_tokens := "8192"; alter annotation ext::ai::embedding_model_max_batch_tokens := "8192"; alter annotation ext::ai::embedding_model_max_output_dimensions := "1024"; }; create abstract type ext::ai::OllamaSnowflakeArcticEmbed2Model extending ext::ai::EmbeddingModel { alter annotation ext::ai::model_name := "snowflake-arctic-embed2"; alter annotation ext::ai::model_provider := "builtin::ollama"; alter annotation ext::ai::embedding_model_max_input_tokens := "8192"; alter annotation ext::ai::embedding_model_max_batch_tokens := "8192"; alter annotation ext::ai::embedding_model_max_output_dimensions := "1024"; }; create scalar type ext::ai::DistanceFunction extending enum; create scalar type ext::ai::IndexType extending enum; create abstract 
inheritable annotation ext::ai::embedding_dimensions; create abstract index ext::ai::index ( named only embedding_model: str, named only dimensions: optional int64 = {}, named only distance_function: ext::ai::DistanceFunction = ext::ai::DistanceFunction.Cosine, named only index_type: ext::ai::IndexType = ext::ai::IndexType.HNSW, named only index_parameters: tuple = (m := 32, ef_construction := 100), named only truncate_to_max: bool = False, ) { create annotation std::description := "Semantic similarity index."; create annotation ext::ai::embedding_dimensions := ""; set deferrability := 'Required'; }; create function ext::ai::to_context( object: anyobject, ) -> std::str { create annotation std::description := "Evaluate the expression of an ai::index defined on the passed " ++ "object type and return it."; set volatility := 'Stable'; using sql expression; }; create function ext::ai::search( object: anyobject, query: array, ) -> optional tuple { create annotation std::description := ' Search an object using its ext::ai::index index. Returns objects that match the specified semantic query and the similarity score. '; set volatility := 'Stable'; # Needed to pick up the indexes when used in ORDER BY. set prefer_subquery_args := true; using sql expression; }; create function ext::ai::search( object: anyobject, query: str, ) -> optional tuple { create annotation std::description := ' Search an object using its ext::ai::index index. Gets an embedding for the query from the ai provider then returns objects that match the specified semantic query and the similarity score. '; set volatility := 'Volatile'; # Needed to pick up the indexes when used in ORDER BY. 
set prefer_subquery_args := true; set server_param_conversions := '{"query": ["ai_text_embedding", "object"]}'; set required_permissions := { ext::ai::perm::provider_call }; using sql expression; }; create scalar type ext::ai::ChatParticipantRole extending enum; create type ext::ai::ChatPromptMessage extending std::BaseObject { create required property participant_role: ext::ai::ChatParticipantRole { create annotation std::description := 'The role of the message author.' }; create property participant_name: str { create annotation std::description := 'Optional name for the participant.' }; create required property content: str { create annotation std::description := 'Prompt message content.' }; create access policy ap_read allow select using ( global ext::ai::perm::chat_prompt_read ); create access policy ap_write allow insert, update, delete using ( global ext::ai::perm::chat_prompt_write ); }; create type ext::ai::ChatPrompt extending std::BaseObject { create required property name: str { create constraint exclusive; create annotation std::description := 'Unique name for the prompt configuration'; }; create required multi link messages: ext::ai::ChatPromptMessage { create constraint exclusive; create annotation std::description := 'Messages in this prompt configuration'; }; create access policy ap_read allow select using ( global ext::ai::perm::chat_prompt_read ); create access policy ap_write allow insert, update, delete using ( global ext::ai::perm::chat_prompt_write ); }; insert ext::ai::ChatPrompt { name := 'builtin::rag-default', messages := { (insert ext::ai::ChatPromptMessage { participant_role := ext::ai::ChatParticipantRole.System, content := ( "You are an expert Q&A system.\n" ++ "Always answer questions based on the provided \ context information. Never use prior knowledge.\n" ++ "Follow these additional rules:\n\ 1. Never directly reference the given context in your \ answer.\n\ 2. Never include phrases like 'Based on the context, ...' 
\ or any similar phrases in your responses.\n\ 3. When the context does not provide information about \ the question, answer with \ 'No information available.'.\n\ Context information is below:\n{context}\n\ Given the context information above and not prior \ knowledge, answer the user query." ), }), (insert ext::ai::ChatPromptMessage { participant_role := ext::ai::ChatParticipantRole.User, content := ( "Query: {query}\n\ Answer: " ), }) } }; create index match for std::str using ext::ai::index; }; ================================================ FILE: edb/lib/ext/auth.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2023-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# CREATE EXTENSION PACKAGE auth VERSION '1.0' { set ext_module := "ext::auth"; set dependencies := ["pgcrypto>=1.3"]; create module ext::auth; create module ext::auth::perm; create permission ext::auth::perm::auth_read; create permission ext::auth::perm::auth_write; create permission ext::auth::perm::auth_read_user; create abstract type ext::auth::Auditable extending std::BaseObject { create required property created_at: std::datetime { set default := std::datetime_current(); set readonly := true; }; create required property modified_at: std::datetime { create rewrite insert, update using ( std::datetime_current() ); }; create access policy ap_read allow select using ( global ext::auth::perm::auth_read ); create access policy ap_write allow insert, update, delete using ( global ext::auth::perm::auth_write ); }; create type ext::auth::Identity extending ext::auth::Auditable { create required property issuer: std::str; create required property subject: std::str; create constraint exclusive on ((.issuer, .subject)); }; create type ext::auth::LocalIdentity extending ext::auth::Identity { alter property subject { create rewrite insert using (.id); }; }; create abstract type ext::auth::Factor extending ext::auth::Auditable { create required link identity: ext::auth::LocalIdentity { create constraint exclusive; on target delete delete source; }; }; create type ext::auth::EmailFactor extending ext::auth::Factor { create required property email: str; create property verified_at: std::datetime; }; create type ext::auth::EmailPasswordFactor extending ext::auth::EmailFactor { alter property email { create constraint exclusive; }; create required property password_hash: std::str; }; create type ext::auth::MagicLinkFactor extending ext::auth::EmailFactor { alter property email { create constraint exclusive; }; }; create type ext::auth::WebAuthnFactor extending ext::auth::EmailFactor { create required property user_handle: std::bytes; create required property credential_id: 
std::bytes { create constraint exclusive; }; create required property public_key: std::bytes { create constraint exclusive; }; create trigger email_shares_user_handle after insert for each do ( std::assert( __new__.user_handle = ( select detached ext::auth::WebAuthnFactor filter .email = __new__.email and not .id = __new__.id ).user_handle, message := "user_handle must be the same for a given email" ) ); create constraint exclusive on ((.email, .credential_id)); }; create type ext::auth::WebAuthnRegistrationChallenge extending ext::auth::Auditable { create required property challenge: std::bytes { create constraint exclusive; }; create required property email: std::str; create required property user_handle: std::bytes; create constraint exclusive on ((.user_handle, .email, .challenge)); }; create type ext::auth::WebAuthnAuthenticationChallenge extending ext::auth::Auditable { create required property challenge: std::bytes { create constraint exclusive; }; create required multi link factors: ext::auth::WebAuthnFactor { create constraint exclusive; on target delete delete source; }; }; create type ext::auth::PKCEChallenge extending ext::auth::Auditable { create required property challenge: std::str { create constraint exclusive; }; create property auth_token: std::str { create annotation std::description := "Identity provider's auth token."; }; create property refresh_token: std::str { create annotation std::description := "Identity provider's refresh token."; }; create property id_token: std::str { create annotation std::description := "Identity provider's OpenID Connect id_token."; }; create link identity: ext::auth::Identity { on target delete delete source; }; }; create type ext::auth::OneTimeCode extending ext::auth::Auditable { create required property code_hash: std::bytes { create constraint exclusive; create annotation std::description := "The securely hashed one-time code."; }; create required property expires_at: std::datetime { create annotation 
std::description := "The date and time when the code expires."; }; create index on (.expires_at); create required link factor: ext::auth::Factor { on target delete delete source; }; }; create scalar type ext::auth::AuthenticationAttemptType extending std::enum< SignIn, EmailVerification, PasswordReset, MagicLink, OneTimeCode >; create type ext::auth::AuthenticationAttempt extending ext::auth::Auditable { create required link factor: ext::auth::Factor { on target delete delete source; }; create required property attempt_type: ext::auth::AuthenticationAttemptType { create annotation std::description := "The type of authentication attempt being made."; }; create required property successful: std::bool { create annotation std::description := "Whether this authentication attempt was successful."; }; }; create scalar type ext::auth::VerificationMethod extending std::enum; create abstract type ext::auth::ProviderConfig extending cfg::ConfigObject { create required property name: std::str { set readonly := true; create constraint exclusive; } }; create abstract type ext::auth::OAuthProviderConfig extending ext::auth::ProviderConfig { alter property name { set protected := true; }; create required property secret: std::str { set readonly := true; set secret := true; create annotation std::description := "Secret provided by auth provider."; }; create required property client_id: std::str { set readonly := true; create annotation std::description := "ID for client provided by auth provider."; }; create required property display_name: std::str { set readonly := true; set protected := true; create annotation std::description := "Provider name to be displayed in login UI."; }; create property additional_scope: std::str { set readonly := true; create annotation std::description := "Space-separated list of scopes to be included in the \ authorize request to the OAuth provider."; }; }; create type ext::auth::OpenIDConnectProvider extending ext::auth::OAuthProviderConfig { alter 
property name { set protected := false; }; alter property display_name { set protected := false; }; create required property issuer_url: std::str { create annotation std::description := "The issuer URL of the provider."; }; create property logo_url: std::str { create annotation std::description := "A url to an image of the provider's logo."; }; create constraint exclusive on ((.issuer_url, .client_id)); }; create type ext::auth::AppleOAuthProvider extending ext::auth::OAuthProviderConfig { alter property name { set default := 'builtin::oauth_apple'; }; alter property display_name { set default := 'Apple'; }; }; create type ext::auth::AzureOAuthProvider extending ext::auth::OAuthProviderConfig { alter property name { set default := 'builtin::oauth_azure'; }; alter property display_name { set default := 'Azure'; }; }; create type ext::auth::DiscordOAuthProvider extending ext::auth::OAuthProviderConfig { alter property name { set default := 'builtin::oauth_discord'; }; alter property display_name { set default := 'Discord'; }; create required property prompt: std::str { create annotation std::description := "Controls how the authorization flow handles existing authorizations. \ If a user has previously authorized your application with the \ requested scopes and prompt is set to consent, it will request them \ to reapprove their authorization. If set to none, it will skip the \ authorization screen and redirect them back to your redirect URI \ without requesting their authorization. 
For passthrough scopes, like \ bot and webhook.incoming, authorization is always required."; set default := 'consent'; }; }; create type ext::auth::SlackOAuthProvider extending ext::auth::OAuthProviderConfig { alter property name { set default := 'builtin::oauth_slack'; }; alter property display_name { set default := 'Slack'; }; }; create type ext::auth::GitHubOAuthProvider extending ext::auth::OAuthProviderConfig { alter property name { set default := 'builtin::oauth_github'; }; alter property display_name { set default := 'GitHub'; }; }; create type ext::auth::GoogleOAuthProvider extending ext::auth::OAuthProviderConfig { alter property name { set default := 'builtin::oauth_google'; }; alter property display_name { set default := 'Google'; }; }; create type ext::auth::EmailPasswordProviderConfig extending ext::auth::ProviderConfig { alter property name { set default := 'builtin::local_emailpassword'; set protected := true; }; create required property require_verification: std::bool { set default := true; }; create required property verification_method: ext::auth::VerificationMethod { set default := ext::auth::VerificationMethod.Link; }; }; create type ext::auth::WebAuthnProviderConfig extending ext::auth::ProviderConfig { alter property name { set default := 'builtin::local_webauthn'; set protected := true; }; create required property relying_party_origin: std::str { create annotation std::description := "The full origin of the sign-in page including protocol and \ port of the application. 
If using the built-in UI, this \ should be the origin of the EdgeDB server."; }; create required property require_verification: std::bool { set default := true; }; create required property verification_method: ext::auth::VerificationMethod { set default := ext::auth::VerificationMethod.Link; }; }; create type ext::auth::MagicLinkProviderConfig extending ext::auth::ProviderConfig { alter property name { set default := 'builtin::local_magic_link'; set protected := true; }; create required property token_time_to_live: std::duration { set default := '10 minutes'; create annotation std::description := "The time after which a magic link token expires."; }; create required property verification_method: ext::auth::VerificationMethod { set default := ext::auth::VerificationMethod.Link; }; create required property auto_signup: std::bool { set default := false; }; }; create scalar type ext::auth::FlowType extending std::enum; create type ext::auth::UIConfig extending cfg::ConfigObject { create required property redirect_to: std::str { create annotation std::description := "The url to redirect to after successful sign in."; }; create property redirect_to_on_signup: std::str { create annotation std::description := "The url to redirect to after a new user signs up. 
\ If not set, 'redirect_to' will be used instead."; }; create required property flow_type: ext::auth::FlowType { create annotation std::description := "The flow used when requesting authentication."; set default := ext::auth::FlowType.PKCE; }; create property app_name: std::str { create annotation std::description := "The name of your application to be shown on the login \ screen."; create annotation std::deprecated := "Use the app_name property in ext::auth::AuthConfig instead."; }; create property logo_url: std::str { create annotation std::description := "A url to an image of your application's logo."; create annotation std::deprecated := "Use the logo_url property in ext::auth::AuthConfig instead."; }; create property dark_logo_url: std::str { create annotation std::description := "A url to an image of your application's logo to be used \ with the dark theme."; create annotation std::deprecated := "Use the dark_logo_url property in ext::auth::AuthConfig \ instead."; }; create property brand_color: std::str { create annotation std::description := "The brand color of your application as a hex string."; create annotation std::deprecated := "Use the brand_color property in ext::auth::AuthConfig \ instead."; }; }; create scalar type ext::auth::WebhookEvent extending std::enum< IdentityCreated, IdentityAuthenticated, EmailFactorCreated, EmailVerified, EmailVerificationRequested, PasswordResetRequested, MagicLinkRequested, OneTimeCodeRequested, OneTimeCodeVerified, >; create type ext::auth::WebhookConfig extending cfg::ConfigObject { create required property url: std::str { create annotation std::description := "The url to send webhooks to."; create constraint exclusive; }; create required multi property events: ext::auth::WebhookEvent { create annotation std::description := "The events to send webhooks for."; }; create property signing_secret_key: std::str { set secret := true; create annotation std::description := "The secret key used to sign webhook requests."; }; 
}; create function ext::auth::webhook_signing_key_exists( webhook_config: ext::auth::WebhookConfig ) -> std::bool { using ( select exists webhook_config.signing_secret_key ); SET required_permissions := ext::auth::perm::auth_read; }; create type ext::auth::AuthConfig extending cfg::ExtensionConfig { create multi link providers: ext::auth::ProviderConfig { create annotation std::description := "Configuration for auth provider clients."; }; create link ui: ext::auth::UIConfig { create annotation std::description := "Configuration for builtin auth UI. If not set the builtin \ UI is disabled."; }; create multi link webhooks: ext::auth::WebhookConfig { create annotation std::description := "Configuration for webhooks."; }; create property app_name: std::str { create annotation std::description := "The name of your application."; }; create property logo_url: std::str { create annotation std::description := "A url to an image of your application's logo."; }; create property dark_logo_url: std::str { create annotation std::description := "A url to an image of your application's logo to be used \ with the dark theme."; }; create property brand_color: std::str { create annotation std::description := "The brand color of your application as a hex string."; }; create property auth_signing_key: std::str { set secret := true; create annotation std::description := "The signing key used for auth extension. Must be at \ least 32 characters long."; }; create property token_time_to_live: std::duration { create annotation std::description := "The time after which an auth token expires. A value of 0 \ indicates that the token should never expire."; set default := '336 hours'; }; create multi property allowed_redirect_urls: std::str { create annotation std::description := "When redirecting the user in various flows, the URL will be \ checked against this list to ensure they are going \ to a trusted domain controlled by the application. 
URLs are \ matched based on checking if the candidate redirect URL is \ a match or a subdirectory of any of these allowed URLs"; }; }; create function ext::auth::signing_key_exists() -> std::bool { using ( select exists cfg::Config.extensions[is ext::auth::AuthConfig] .auth_signing_key ); SET required_permissions := ext::auth::perm::auth_read; }; create scalar type ext::auth::JWTAlgo extending enum; create function ext::auth::_jwt_check_signature( jwt: tuple, key: std::str, algo: ext::auth::JWTAlgo = ext::auth::JWTAlgo.HS256, ) -> std::json { set volatility := 'Stable'; using ( with module ext::auth, msg := jwt.header ++ "." ++ jwt.payload, hash := ( "sha256" if algo = JWTAlgo.RS256 or algo = JWTAlgo.HS256 else std::assert( false, message := "unsupported JWT algo") ), select std::to_json( std::to_str( std::enc::base64_decode( jwt.payload, padding := false, alphabet := std::enc::Base64Alphabet.urlsafe, ), ), ) order by assert( std::enc::base64_encode( ext::pgcrypto::hmac(msg, key, hash), padding := false, alphabet := std::enc::Base64Alphabet.urlsafe, ) = jwt.signature, message := "JWT signature mismatch", ) ); }; create function ext::auth::_jwt_parse( token: std::str, ) -> tuple { set volatility := 'Stable'; using ( for parts in std::str_split(token, ".") select ( header := parts[0], payload := parts[1], signature := parts[2], ) order by assert(len(parts) = 3, message := "JWT is malformed") ); }; create function ext::auth::_jwt_verify( token: std::str, key: std::str, algo: ext::auth::JWTAlgo = ext::auth::JWTAlgo.HS256, ) -> std::json { set volatility := 'Stable'; using ( for jwt in ( ext::auth::_jwt_check_signature( ext::auth::_jwt_parse(token), key, algo, ) ) with validity_range := std::range( std::to_datetime(json_get(jwt, "nbf")), std::to_datetime(json_get(jwt, "exp")), ), select jwt order by assert( std::contains( validity_range, std::datetime_of_transaction(), ), message := "JWT is expired or is not yet valid", ) ); }; create global ext::auth::client_token: 
std::str; create single global ext::auth::_client_token_id := ( for conf_key in ( ( select cfg::Config.extensions[is ext::auth::AuthConfig] limit 1 ).auth_signing_key ) for jwt_claims in ( ext::auth::_jwt_verify( global ext::auth::client_token, conf_key, ) ) select json_get(jwt_claims, "sub") ); alter type ext::auth::Identity { create access policy read_current allow select using ( not global ext::auth::perm::auth_read and global ext::auth::perm::auth_read_user and .id ?= global ext::auth::_client_token_id ); }; create single global ext::auth::ClientTokenIdentity := ( select ext::auth::Identity filter .id = global ext::auth::_client_token_id ); }; ================================================ FILE: edb/lib/ext/edgeqlhttp.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2021-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # CREATE EXTENSION PACKAGE edgeql_http VERSION '1.0'; ================================================ FILE: edb/lib/ext/graphql.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2018-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # CREATE EXTENSION PACKAGE graphql VERSION '1.0'; ================================================ FILE: edb/lib/ext/notebook.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2021-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # CREATE EXTENSION PACKAGE notebook VERSION '1.0' { SET internal := true; }; ================================================ FILE: edb/lib/ext/pg_trgm.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2023-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # create extension package pg_trgm version '1.6' { set ext_module := "ext::pg_trgm"; set sql_extensions := ["pg_trgm >=1.6"]; create module ext::pg_trgm; create type ext::pg_trgm::Config extending cfg::ExtensionConfig { create required property similarity_threshold: std::float32 { create annotation cfg::backend_setting := '"pg_trgm.similarity_threshold"'; create annotation cfg::session_cfg_permissions := '"*"'; create annotation std::description := "The current similarity threshold that is used by the " ++ "pg_trgm::similar() function, the pg_trgm::gin and " ++ "the pg_trgm::gist indexes. The threshold must be " ++ "between 0 and 1 (default is 0.3)."; set default := 0.3; create constraint std::min_value(0.0); create constraint std::max_value(1.0); }; create required property word_similarity_threshold: std::float32 { create annotation cfg::backend_setting := '"pg_trgm.word_similarity_threshold"'; create annotation cfg::session_cfg_permissions := '"*"'; create annotation std::description := "The current word similarity threshold that is used by the " ++ "pg_trgm::word_similar() function.
The threshold must be " ++ "between 0 and 1 (default is 0.6)."; set default := 0.6; create constraint std::min_value(0.0); create constraint std::max_value(1.0); }; create required property strict_word_similarity_threshold: std::float32 { create annotation cfg::backend_setting := '"pg_trgm.strict_word_similarity_threshold"'; create annotation cfg::session_cfg_permissions := '"*"'; create annotation std::description := "The current strict word similarity threshold that is used by " ++ "the pg_trgm::strict_word_similar() function. The " ++ "threshold must be between 0 and 1 (default is 0.5)."; set default := 0.5; create constraint std::min_value(0.0); create constraint std::max_value(1.0); }; }; create function ext::pg_trgm::similarity( a: std::str, b: std::str, ) -> std::float32 { set volatility := 'Immutable'; using sql 'select 1.0::real - (a <-> b)'; }; create function ext::pg_trgm::similar( a: std::str, b: std::str, ) -> std::bool { set volatility := 'Stable'; # Depends on config. using sql 'select a % b'; }; create function ext::pg_trgm::similarity_dist( a: std::str, b: std::str, ) -> std::float32 { set volatility := 'Immutable'; # Needed to pick up the indexes when used in ORDER BY. set prefer_subquery_args := true; using sql 'select a <-> b'; }; create function ext::pg_trgm::word_similarity( a: std::str, b: std::str, ) -> std::float32 { set volatility := 'Immutable'; using sql 'select 1.0::real - (a <<-> b)'; }; create function ext::pg_trgm::word_similar( a: std::str, b: std::str, ) -> std::bool { set volatility := 'Stable'; # Depends on config. using sql 'select a <% b'; }; create function ext::pg_trgm::word_similarity_dist( a: std::str, b: std::str, ) -> std::float32 { set volatility := 'Immutable'; # Needed to pick up the indexes when used in ORDER BY.
set prefer_subquery_args := true; using sql 'select a <<-> b'; }; create function ext::pg_trgm::strict_word_similarity( a: std::str, b: std::str, ) -> std::float32 { set volatility := 'Immutable'; using sql 'select 1.0::real - (a <<<-> b)'; }; create function ext::pg_trgm::strict_word_similar( a: std::str, b: std::str, ) -> std::bool { set volatility := 'Stable'; # Depends on config. using sql 'select a <<% b'; }; create function ext::pg_trgm::strict_word_similarity_dist( a: std::str, b: std::str, ) -> std::float32 { set volatility := 'Immutable'; # Needed to pick up the indexes when used in ORDER BY. set prefer_subquery_args := true; using sql 'select a <<<-> b'; }; create abstract index ext::pg_trgm::gin() { create annotation std::description := 'pg_trgm GIN index.'; set code := 'GIN (__col__ gin_trgm_ops)'; }; create abstract index ext::pg_trgm::gist( named only siglen: int64 = 12 ) { create annotation std::description := 'pg_trgm GIST index.'; set code := 'GIST (__col__ gist_trgm_ops(siglen = __kw_siglen__))'; }; create index match for std::str using ext::pg_trgm::gin; create index match for std::str using ext::pg_trgm::gist; }; ================================================ FILE: edb/lib/ext/pg_unaccent.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2024-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# create extension package pg_unaccent version '1.1' { set ext_module := "ext::pg_unaccent"; set sql_extensions := ["unaccent >=1.1"]; create module ext::pg_unaccent; create function ext::pg_unaccent::unaccent( text: std::str, ) -> std::str { set volatility := 'Immutable'; using sql 'select edgedb.unaccent(text)'; }; }; ================================================ FILE: edb/lib/ext/pgcrypto.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2023-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# create extension package pgcrypto version '1.3' { set ext_module := "ext::pgcrypto"; set sql_extensions := ["pgcrypto >=1.3"]; create module ext::pgcrypto; create function ext::pgcrypto::digest( data: std::str, type: std::str, ) -> std::bytes { set volatility := 'Immutable'; using sql function 'edgedb.digest'; }; create function ext::pgcrypto::digest( data: std::bytes, type: std::str, ) -> std::bytes { set volatility := 'Immutable'; using sql function 'edgedb.digest'; }; create function ext::pgcrypto::hmac( data: std::str, key: std::str, type: std::str, ) -> std::bytes { set volatility := 'Immutable'; using sql function 'edgedb.hmac'; }; create function ext::pgcrypto::hmac( data: std::bytes, key: std::bytes, type: std::str, ) -> std::bytes { set volatility := 'Immutable'; using sql function 'edgedb.hmac'; }; create function ext::pgcrypto::gen_salt( ) -> std::str { set volatility := 'Volatile'; using sql "SELECT edgedb.gen_salt('bf')"; }; create function ext::pgcrypto::gen_salt( type: std::str, ) -> std::str { set volatility := 'Volatile'; using sql 'SELECT edgedb.gen_salt("type")'; }; create function ext::pgcrypto::gen_salt( type: std::str, iter_count: std::int64, ) -> std::str { set volatility := 'Volatile'; using sql 'SELECT edgedb.gen_salt("type", "iter_count"::integer)'; }; create function ext::pgcrypto::crypt( password: std::str, salt: std::str, ) -> std::str { set volatility := 'Immutable'; using sql function 'edgedb.crypt'; }; }; ================================================ FILE: edb/lib/ext/pgvector.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2023-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # create extension package pgvector version '0.7.4' { set ext_module := "ext::pgvector"; set sql_extensions := ["vector >=0.7.4,<0.9.0"]; set sql_setup_script := $script$ -- Rename the vector_norm to be consistent with l2_norm ALTER FUNCTION edgedb.vector_norm(edgedb.vector) RENAME TO l2_norm; -- Add some helpers -- about 5-6 times slower than the C cast, but retains 0-based index CREATE FUNCTION sparsevec_to_text(val sparsevec) RETURNS text AS $$ DECLARE vectxt text := val::text; mid text[]; kv text[]; i int8; res text := '{'; BEGIN mid := string_to_array(substr(split_part(vectxt, '}', 1), 2), ','); FOR i IN 1..cardinality(mid) LOOP kv := string_to_array(mid[i], ':'); kv[1] := (kv[1]::int8 - 1)::text; res := res || kv[1] || ':' || kv[2] || ','; END LOOP; RETURN left(res, -1) || '}' || split_part(vectxt, '}', 2); END; $$ LANGUAGE plpgsql IMMUTABLE STRICT; -- about 10 times slower than a cast CREATE FUNCTION text_to_sparsevec(val text) RETURNS sparsevec AS $$ DECLARE mid text[]; kv text[]; i int8; res text := '{'; BEGIN IF val ~ '^\s*{\s*(\d+\s*:.+?,\s*)*\d+\s*:.+}\s*/\s*\d+\s*$' THEN mid := string_to_array(split_part(split_part(val, '}', 1), '{', 2), ','); FOR i IN 1..cardinality(mid) LOOP kv := string_to_array(mid[i], ':'); kv[1] := (trim(kv[1])::int8 + 1)::text; res := res || kv[1] || ':' || kv[2] || ','; END LOOP; RETURN (left(res, -1) || '}' || split_part(val, '}', 2))::sparsevec; ELSE RETURN val::sparsevec; END IF; END; $$ LANGUAGE plpgsql IMMUTABLE STRICT; CREATE FUNCTION sparsevec_to_jsonb(val sparsevec) RETURNS jsonb AS $$ DECLARE vectxt text := val::text; 
mid text[]; kv text[]; i int8; dim text := split_part(vectxt, '/', 2); res text := '{'; BEGIN mid := string_to_array(substr(split_part(vectxt, '}', 1), 2), ','); FOR i IN 1..cardinality(mid) LOOP kv := string_to_array(mid[i], ':'); kv[1] := (kv[1]::int8 - 1)::text; res := res || '"' || kv[1] || '":' || kv[2] || ','; END LOOP; RETURN (res || '"dim":' || dim || '}')::jsonb; END; $$ LANGUAGE plpgsql IMMUTABLE STRICT; CREATE FUNCTION jsonb_to_sparsevec(val jsonb) RETURNS sparsevec AS $$ DECLARE mid text[]; kv text[]; r record; i int8; dim text := NULL; res text := '{'; msg text; BEGIN IF jsonb_typeof(val) = 'object' THEN msg := 'missing "dim"'; FOR r IN SELECT * FROM jsonb_each(val) LOOP CASE WHEN r.key = 'dim' THEN dim := r.value::text; WHEN r.key ~ $r$\d+$r$ THEN res := res || (r.key::int8 + 1)::text || ':' || r.value::text || ','; ELSE msg := 'unexpected key in JSON object: ' || r.key; EXIT; END CASE; END LOOP; IF dim IS NOT NULL THEN RETURN (left(res, -1) || '}/' || dim)::sparsevec; END IF; ELSE msg := 'JSON object expected, got ' || jsonb_typeof(val) || ' instead'; END IF; RAISE EXCEPTION USING ERRCODE = 22000, MESSAGE = msg; END; $$ LANGUAGE plpgsql IMMUTABLE STRICT; $script$; set sql_teardown_script := $$ ALTER FUNCTION edgedb.l2_norm(edgedb.vector) RENAME TO vector_norm; -- remove helpers DROP FUNCTION edgedb.sparsevec_to_jsonb; DROP FUNCTION edgedb.jsonb_to_sparsevec; DROP FUNCTION edgedb.sparsevec_to_text; DROP FUNCTION edgedb.text_to_sparsevec; $$; create module ext::pgvector; create type ext::pgvector::Config extending cfg::ExtensionConfig { create required property probes: std::int64 { create annotation cfg::backend_setting := '"ivfflat.probes"'; create annotation cfg::session_cfg_permissions := '"*"'; create annotation std::description := "The number of probes (1 by default) used by IVFFlat " ++ "index. 
A higher value provides better recall at the " ++ "cost of speed, and it can be set to the number of " ++ "lists for exact nearest neighbor search (at which point " ++ "the planner won’t use the index)"; set default := 1; create constraint std::min_value(1); }; create required property ef_search: std::int64 { create annotation cfg::backend_setting := '"hnsw.ef_search"'; create annotation cfg::session_cfg_permissions := '"*"'; create annotation std::description := "The size of the dynamic candidate list for search (40 " ++ "by default) used by HNSW index. A higher value " ++ "provides better recall at the cost of speed."; set default := 40; create constraint std::min_value(1); }; }; create scalar type ext::pgvector::vector extending std::anyscalar { set id := "9565dd88-04f5-11ee-a691-0b6ebe179825"; set sql_type := "vector"; set sql_type_scheme := "vector({__arg_0__})"; set num_params := 1; }; create scalar type ext::pgvector::halfvec extending std::anyscalar { set id := "4ba84534-188e-43b4-a7ce-cea2af0f405b"; set sql_type := "halfvec"; set sql_type_scheme := "halfvec({__arg_0__})"; set num_params := 1; }; create scalar type ext::pgvector::sparsevec extending std::anyscalar { set id := "003e434d-cac2-430a-b238-fb39d73447d2"; set sql_type := "sparsevec"; set sql_type_scheme := "sparsevec({__arg_0__})"; set num_params := 1; }; create cast from ext::pgvector::vector to std::json { set volatility := 'Immutable'; using sql 'SELECT val::text::jsonb'; }; create cast from std::json to ext::pgvector::vector { set volatility := 'Immutable'; using sql $$ SELECT ( nullif(val, 'null'::jsonb)::text::vector ) $$; }; create cast from ext::pgvector::vector to std::str { set volatility := 'Immutable'; using sql cast; }; create cast from std::str to ext::pgvector::vector { set volatility := 'Immutable'; using sql cast; }; create cast from ext::pgvector::vector to std::bytes { set volatility := 'Immutable'; using sql 'SELECT vector_send(val)'; }; create cast from ext::pgvector::halfvec 
to std::json { set volatility := 'Immutable'; using sql 'SELECT val::text::jsonb'; }; create cast from std::json to ext::pgvector::halfvec { set volatility := 'Immutable'; using sql $$ SELECT ( nullif(val, 'null'::jsonb)::text::halfvec ) $$; }; create cast from ext::pgvector::halfvec to std::str { set volatility := 'Immutable'; using sql cast; }; create cast from std::str to ext::pgvector::halfvec { set volatility := 'Immutable'; using sql cast; }; create cast from ext::pgvector::halfvec to std::bytes { set volatility := 'Immutable'; using sql 'SELECT halfvec_send(val)'; }; create cast from ext::pgvector::sparsevec to std::str { set volatility := 'Immutable'; using sql 'SELECT sparsevec_to_text(val)'; }; create cast from std::str to ext::pgvector::sparsevec { set volatility := 'Immutable'; using sql 'SELECT text_to_sparsevec(val)'; }; create cast from ext::pgvector::sparsevec to std::bytes { set volatility := 'Immutable'; using sql 'SELECT sparsevec_send(val)'; }; create cast from ext::pgvector::sparsevec to std::json { set volatility := 'Immutable'; using sql 'SELECT sparsevec_to_jsonb(val)'; }; create cast from std::json to ext::pgvector::sparsevec { set volatility := 'Immutable'; using sql 'SELECT jsonb_to_sparsevec(val)'; }; # All casts from numerical arrays should allow assignment casts. 
create cast from array to ext::pgvector::vector { set volatility := 'Immutable'; using sql cast; allow assignment; }; create cast from array to ext::pgvector::halfvec { set volatility := 'Immutable'; using sql cast; allow assignment; }; create cast from array to ext::pgvector::vector { set volatility := 'Immutable'; using sql cast; allow assignment; }; create cast from array to ext::pgvector::halfvec { set volatility := 'Immutable'; using sql cast; allow assignment; }; create cast from array to ext::pgvector::vector { set volatility := 'Immutable'; using sql $$ SELECT val::float4[]::vector $$; allow assignment; }; create cast from array to ext::pgvector::halfvec { set volatility := 'Immutable'; using sql $$ SELECT val::float4[]::halfvec $$; allow assignment; }; create cast from array to ext::pgvector::vector { set volatility := 'Immutable'; using sql cast; allow assignment; }; create cast from array to ext::pgvector::halfvec { set volatility := 'Immutable'; using sql cast; allow assignment; }; create cast from array to ext::pgvector::vector { set volatility := 'Immutable'; using sql $$ SELECT val::float4[]::vector $$; allow assignment; }; create cast from array to ext::pgvector::halfvec { set volatility := 'Immutable'; using sql $$ SELECT val::float4[]::halfvec $$; allow assignment; }; create cast from ext::pgvector::vector to array { set volatility := 'Immutable'; using sql cast; }; create cast from ext::pgvector::halfvec to array { set volatility := 'Immutable'; using sql cast; }; create cast from ext::pgvector::vector to ext::pgvector::halfvec { set volatility := 'Immutable'; using sql cast; allow assignment; }; create cast from ext::pgvector::vector to ext::pgvector::sparsevec { set volatility := 'Immutable'; using sql cast; allow assignment; }; create cast from ext::pgvector::halfvec to ext::pgvector::vector { set volatility := 'Immutable'; using sql cast; allow implicit; }; create cast from ext::pgvector::halfvec to ext::pgvector::sparsevec { set volatility 
:= 'Immutable'; using sql cast; allow assignment; }; create cast from ext::pgvector::sparsevec to ext::pgvector::vector { set volatility := 'Immutable'; using sql cast; allow assignment; }; create cast from ext::pgvector::sparsevec to ext::pgvector::halfvec { set volatility := 'Immutable'; using sql cast; allow assignment; }; create function ext::pgvector::euclidean_distance( a: ext::pgvector::vector, b: ext::pgvector::vector, ) -> std::float64 { set volatility := 'Immutable'; # Needed to pick up the indexes when used in ORDER BY. set prefer_subquery_args := true; using sql 'SELECT a <-> b'; }; create function ext::pgvector::euclidean_distance( a: ext::pgvector::halfvec, b: ext::pgvector::halfvec, ) -> std::float64 { set volatility := 'Immutable'; # Needed to pick up the indexes when used in ORDER BY. set prefer_subquery_args := true; using sql 'SELECT a <-> b'; }; create function ext::pgvector::euclidean_distance( a: ext::pgvector::sparsevec, b: ext::pgvector::sparsevec, ) -> std::float64 { set volatility := 'Immutable'; # Needed to pick up the indexes when used in ORDER BY. set prefer_subquery_args := true; using sql 'SELECT a <-> b'; }; create function ext::pgvector::neg_inner_product( a: ext::pgvector::vector, b: ext::pgvector::vector, ) -> std::float64 { set volatility := 'Immutable'; # Needed to pick up the indexes when used in ORDER BY. set prefer_subquery_args := true; using sql 'SELECT (a <#> b)'; }; create function ext::pgvector::neg_inner_product( a: ext::pgvector::halfvec, b: ext::pgvector::halfvec, ) -> std::float64 { set volatility := 'Immutable'; # Needed to pick up the indexes when used in ORDER BY. set prefer_subquery_args := true; using sql 'SELECT (a <#> b)'; }; create function ext::pgvector::neg_inner_product( a: ext::pgvector::sparsevec, b: ext::pgvector::sparsevec, ) -> std::float64 { set volatility := 'Immutable'; # Needed to pick up the indexes when used in ORDER BY. 
set prefer_subquery_args := true; using sql 'SELECT (a <#> b)'; }; create function ext::pgvector::cosine_distance( a: ext::pgvector::vector, b: ext::pgvector::vector, ) -> std::float64 { set volatility := 'Immutable'; # Needed to pick up the indexes when used in ORDER BY. set prefer_subquery_args := true; using sql 'SELECT a <=> b'; }; create function ext::pgvector::cosine_distance( a: ext::pgvector::halfvec, b: ext::pgvector::halfvec, ) -> std::float64 { set volatility := 'Immutable'; # Needed to pick up the indexes when used in ORDER BY. set prefer_subquery_args := true; using sql 'SELECT a <=> b'; }; create function ext::pgvector::cosine_distance( a: ext::pgvector::sparsevec, b: ext::pgvector::sparsevec, ) -> std::float64 { set volatility := 'Immutable'; # Needed to pick up the indexes when used in ORDER BY. set prefer_subquery_args := true; using sql 'SELECT a <=> b'; }; create function ext::pgvector::taxicab_distance( a: ext::pgvector::vector, b: ext::pgvector::vector, ) -> std::float64 { set volatility := 'Immutable'; # Needed to pick up the indexes when used in ORDER BY. set prefer_subquery_args := true; using sql 'SELECT a <+> b'; }; create function ext::pgvector::taxicab_distance( a: ext::pgvector::halfvec, b: ext::pgvector::halfvec, ) -> std::float64 { set volatility := 'Immutable'; # Needed to pick up the indexes when used in ORDER BY. set prefer_subquery_args := true; using sql 'SELECT a <+> b'; }; create function ext::pgvector::taxicab_distance( a: ext::pgvector::sparsevec, b: ext::pgvector::sparsevec, ) -> std::float64 { set volatility := 'Immutable'; # Needed to pick up the indexes when used in ORDER BY. 
set prefer_subquery_args := true; using sql 'SELECT a <+> b'; }; create function ext::pgvector::euclidean_norm( a: ext::pgvector::vector ) -> std::float64 { using sql function 'l2_norm'; set volatility := 'Immutable'; set force_return_cast := true; }; create function ext::pgvector::euclidean_norm( a: ext::pgvector::halfvec ) -> std::float64 { using sql function 'l2_norm'; set volatility := 'Immutable'; set force_return_cast := true; }; create function ext::pgvector::euclidean_norm( a: ext::pgvector::sparsevec ) -> std::float64 { using sql function 'l2_norm'; set volatility := 'Immutable'; set force_return_cast := true; }; create function ext::pgvector::l2_normalize( a: ext::pgvector::vector ) -> ext::pgvector::vector { using sql function 'l2_normalize'; set volatility := 'Immutable'; set force_return_cast := true; }; create function ext::pgvector::l2_normalize( a: ext::pgvector::halfvec ) -> ext::pgvector::halfvec { using sql function 'l2_normalize'; set volatility := 'Immutable'; set force_return_cast := true; }; create function ext::pgvector::l2_normalize( a: ext::pgvector::sparsevec ) -> ext::pgvector::sparsevec { using sql function 'l2_normalize'; set volatility := 'Immutable'; set force_return_cast := true; }; create function ext::pgvector::subvector( a: ext::pgvector::vector, i: std::int64, len: std::int64, ) -> ext::pgvector::vector { set volatility := 'Immutable'; using sql 'SELECT subvector(a, (i+1)::int, len::int)'; }; create function ext::pgvector::subvector( a: ext::pgvector::halfvec, i: std::int64, len: std::int64, ) -> ext::pgvector::halfvec { set volatility := 'Immutable'; using sql 'SELECT subvector(a, (i+1)::int, len::int)'; }; create function ext::pgvector::set_probes(num: std::int64) -> std::int64 { using sql $$ select num from ( select set_config('ivfflat.probes', num::text, true) ) as dummy; $$; CREATE ANNOTATION std::deprecated := 'This function is deprecated. 
' ++ 'Configure ext::pgvector::Config::probes instead'; }; create abstract index ext::pgvector::ivfflat_euclidean( named only lists: int64 ) { create annotation std::description := 'IVFFlat index for euclidean distance.'; set code := 'ivfflat (__col__ vector_l2_ops) WITH (lists = __kw_lists__)'; }; create abstract index ext::pgvector::ivfflat_ip( named only lists: int64 ) { create annotation std::description := 'IVFFlat index for inner product.'; set code := 'ivfflat (__col__ vector_ip_ops) WITH (lists = __kw_lists__)'; }; create abstract index ext::pgvector::ivfflat_cosine( named only lists: int64 ) { create annotation std::description := 'IVFFlat index for cosine distance.'; set code := 'ivfflat (__col__ vector_cosine_ops) WITH (lists = __kw_lists__)'; }; create abstract index ext::pgvector::hnsw_euclidean( named only m: int64 = 16, named only ef_construction: int64 = 64, ) { create annotation std::description := 'HNSW index for euclidean distance.'; set code := $$ hnsw (__col__ vector_l2_ops) WITH (m = __kw_m__, ef_construction = __kw_ef_construction__) $$; }; create abstract index ext::pgvector::hnsw_ip( named only m: int64 = 16, named only ef_construction: int64 = 64, ) { create annotation std::description := 'HNSW index for inner product.'; set code := $$ hnsw (__col__ vector_ip_ops) WITH (m = __kw_m__, ef_construction = __kw_ef_construction__) $$; }; create abstract index ext::pgvector::hnsw_cosine( named only m: int64 = 16, named only ef_construction: int64 = 64, ) { create annotation std::description := 'HNSW index for cosine distance.'; set code := $$ hnsw (__col__ vector_cosine_ops) WITH (m = __kw_m__, ef_construction = __kw_ef_construction__) $$; }; create abstract index ext::pgvector::hnsw_taxicab( named only m: int64 = 16, named only ef_construction: int64 = 64, ) { create annotation std::description := 'HNSW index for taxicab (L1) distance.'; set code := $$ hnsw (__col__ vector_l1_ops) WITH (m = __kw_m__, ef_construction = __kw_ef_construction__) $$; 
}; create index match for ext::pgvector::vector using ext::pgvector::ivfflat_euclidean; create index match for ext::pgvector::vector using ext::pgvector::ivfflat_ip; create index match for ext::pgvector::vector using ext::pgvector::ivfflat_cosine; create index match for ext::pgvector::vector using ext::pgvector::hnsw_euclidean; create index match for ext::pgvector::vector using ext::pgvector::hnsw_ip; create index match for ext::pgvector::vector using ext::pgvector::hnsw_cosine; create index match for ext::pgvector::vector using ext::pgvector::hnsw_taxicab; create abstract index ext::pgvector::ivfflat_hv_euclidean( named only lists: int64 ) { create annotation std::description := 'IVFFlat index for euclidean distance.'; set code := 'ivfflat (__col__ halfvec_l2_ops) WITH (lists = __kw_lists__)'; }; create abstract index ext::pgvector::ivfflat_hv_ip( named only lists: int64 ) { create annotation std::description := 'IVFFlat index for inner product.'; set code := 'ivfflat (__col__ halfvec_ip_ops) WITH (lists = __kw_lists__)'; }; create abstract index ext::pgvector::ivfflat_hv_cosine( named only lists: int64 ) { create annotation std::description := 'IVFFlat index for cosine distance.'; set code := 'ivfflat (__col__ halfvec_cosine_ops) WITH (lists = __kw_lists__)'; }; create abstract index ext::pgvector::hnsw_hv_euclidean( named only m: int64 = 16, named only ef_construction: int64 = 64, ) { create annotation std::description := 'HNSW index for euclidean distance.'; set code := $$ hnsw (__col__ halfvec_l2_ops) WITH (m = __kw_m__, ef_construction = __kw_ef_construction__) $$; }; create abstract index ext::pgvector::hnsw_hv_ip( named only m: int64 = 16, named only ef_construction: int64 = 64, ) { create annotation std::description := 'HNSW index for inner product.'; set code := $$ hnsw (__col__ halfvec_ip_ops) WITH (m = __kw_m__, ef_construction = __kw_ef_construction__) $$; }; create abstract index ext::pgvector::hnsw_hv_cosine( named only m: int64 = 16, named only 
ef_construction: int64 = 64, ) { create annotation std::description := 'HNSW index for cosine distance.'; set code := $$ hnsw (__col__ halfvec_cosine_ops) WITH (m = __kw_m__, ef_construction = __kw_ef_construction__) $$; }; create abstract index ext::pgvector::hnsw_hv_taxicab( named only m: int64 = 16, named only ef_construction: int64 = 64, ) { create annotation std::description := 'HNSW index for taxicab (L1) distance.'; set code := $$ hnsw (__col__ halfvec_l1_ops) WITH (m = __kw_m__, ef_construction = __kw_ef_construction__) $$; }; create index match for ext::pgvector::halfvec using ext::pgvector::ivfflat_hv_euclidean; create index match for ext::pgvector::halfvec using ext::pgvector::ivfflat_hv_ip; create index match for ext::pgvector::halfvec using ext::pgvector::ivfflat_hv_cosine; create index match for ext::pgvector::halfvec using ext::pgvector::hnsw_hv_euclidean; create index match for ext::pgvector::halfvec using ext::pgvector::hnsw_hv_ip; create index match for ext::pgvector::halfvec using ext::pgvector::hnsw_hv_cosine; create index match for ext::pgvector::halfvec using ext::pgvector::hnsw_hv_taxicab; create abstract index ext::pgvector::hnsw_sv_euclidean( named only m: int64 = 16, named only ef_construction: int64 = 64, ) { create annotation std::description := 'HNSW index for euclidean distance.'; set code := $$ hnsw (__col__ sparsevec_l2_ops) WITH (m = __kw_m__, ef_construction = __kw_ef_construction__) $$; }; create abstract index ext::pgvector::hnsw_sv_ip( named only m: int64 = 16, named only ef_construction: int64 = 64, ) { create annotation std::description := 'HNSW index for inner product.'; set code := $$ hnsw (__col__ sparsevec_ip_ops) WITH (m = __kw_m__, ef_construction = __kw_ef_construction__) $$; }; create abstract index ext::pgvector::hnsw_sv_cosine( named only m: int64 = 16, named only ef_construction: int64 = 64, ) { create annotation std::description := 'HNSW index for cosine distance.'; set code := $$ hnsw (__col__ 
sparsevec_cosine_ops) WITH (m = __kw_m__, ef_construction = __kw_ef_construction__) $$; }; create abstract index ext::pgvector::hnsw_sv_taxicab( named only m: int64 = 16, named only ef_construction: int64 = 64, ) { create annotation std::description := 'HNSW index for taxicab (L1) distance.'; set code := $$ hnsw (__col__ sparsevec_l1_ops) WITH (m = __kw_m__, ef_construction = __kw_ef_construction__) $$; }; create index match for ext::pgvector::sparsevec using ext::pgvector::hnsw_sv_euclidean; create index match for ext::pgvector::sparsevec using ext::pgvector::hnsw_sv_ip; create index match for ext::pgvector::sparsevec using ext::pgvector::hnsw_sv_cosine; create index match for ext::pgvector::sparsevec using ext::pgvector::hnsw_sv_taxicab; }; ================================================ FILE: edb/lib/fts.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2023-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # CREATE MODULE std::fts; CREATE SCALAR TYPE std::fts::Language EXTENDING enum< ara, hye, eus, cat, dan, nld, eng, fin, fra, deu, ell, hin, hun, ind, gle, ita, nor, por, ron, rus, spa, swe, tur, > { CREATE ANNOTATION std::description := ' Languages supported by PostgreSQL FTS, ElasticSearch and Apache Lucene. Names are ISO 639-3 language identifiers. 
'; }; CREATE SCALAR TYPE std::fts::Weight EXTENDING enum { CREATE ANNOTATION std::description := " Weight category. Weight values for each category can be provided in std::fts::search. "; }; CREATE ABSTRACT INDEX std::fts::index { CREATE ANNOTATION std::description := "Full-text search index based on Postgres's GIN index."; SET code := ''; # overridden by a special case }; CREATE SCALAR TYPE std::fts::document { SET transient := true; }; create index match for std::fts::document using std::fts::index; CREATE FUNCTION std::fts::with_options( text: std::str, NAMED ONLY language: anyenum, NAMED ONLY weight_category: optional std::fts::Weight = std::fts::Weight.A, ) -> std::fts::document { CREATE ANNOTATION std::description := ' Adds language and weight category information to a string, so it can be indexed with std::fts::index. '; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE FUNCTION std::fts::search( object: anyobject, query: std::str, named only language: std::str = std::fts::Language.eng, named only weights: optional array = {}, ) -> optional tuple { CREATE ANNOTATION std::description := ' Search an object using its std::fts::index index. Returns objects that match the specified query and the matching score. '; SET volatility := 'Stable'; USING SQL EXPRESSION; }; CREATE SCALAR TYPE std::fts::PGLanguage EXTENDING enum< xxx_simple, ara, hye, eus, cat, dan, nld, eng, fin, fra, deu, ell, hin, hun, ind, gle, ita, lit, npi, nor, por, ron, rus, srp, spa, swe, tam, tur, yid, > { CREATE ANNOTATION std::description :=' Languages supported by PostgreSQL FTS. Names are ISO 639-3 language identifiers or Postgres regconfig names prefixed with `xxx_`. 
'; }; CREATE SCALAR TYPE std::fts::ElasticLanguage EXTENDING enum< ara, bul, cat, ces, ckb, dan, deu, ell, eng, eus, fas, fin, fra, gle, glg, hin, hun, hye, ind, ita, lav, nld, nor, por, ron, rus, spa, swe, tha, tur, zho, edb_Brazilian, edb_ChineseJapaneseKorean, > { CREATE ANNOTATION std::description := ' Languages supported by ElasticSearch. Names are ISO 639-3 language identifiers or EdgeDB language identifiers. '; }; CREATE SCALAR TYPE std::fts::LuceneLanguage EXTENDING enum< ara, ben, bul, cat, ces, ckb, dan, deu, ell, eng, est, eus, fas, fin, fra, gle, glg, hin, hun, hye, ind, ita, lav, lit, nld, nor, por, ron, rus, spa, srp, swe, tha, tur, edb_Brazilian, edb_ChineseJapaneseKorean, edb_Indian, > { CREATE ANNOTATION std::description := ' Languages supported by Apache Lucene. Names are ISO 639-3 language identifiers or EdgeDB language identifiers. '; }; ================================================ FILE: edb/lib/math.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2018-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# CREATE MODULE std::math; CREATE FUNCTION std::math::abs(x: std::anyreal) -> std::anyreal { CREATE ANNOTATION std::description := 'Return the absolute value of the input *x*.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'abs'; }; CREATE FUNCTION std::math::ceil(x: std::int64) -> std::int64 { CREATE ANNOTATION std::description := 'Round up to the nearest integer.'; SET volatility := 'Immutable'; USING SQL 'SELECT "x";'; }; CREATE FUNCTION std::math::ceil(x: std::float64) -> std::float64 { CREATE ANNOTATION std::description := 'Round up to the nearest integer.'; SET volatility := 'Immutable'; USING SQL 'SELECT ceil("x");' }; CREATE FUNCTION std::math::ceil(x: std::bigint) -> std::bigint { CREATE ANNOTATION std::description := 'Round up to the nearest integer.'; SET volatility := 'Immutable'; USING SQL 'SELECT "x";' }; CREATE FUNCTION std::math::ceil(x: std::decimal) -> std::decimal { CREATE ANNOTATION std::description := 'Round up to the nearest integer.'; SET volatility := 'Immutable'; USING SQL 'SELECT ceil("x");' }; CREATE FUNCTION std::math::floor(x: std::int64) -> std::int64 { CREATE ANNOTATION std::description := 'Round down to the nearest integer.'; SET volatility := 'Immutable'; USING SQL 'SELECT "x";'; }; CREATE FUNCTION std::math::floor(x: std::float64) -> std::float64 { CREATE ANNOTATION std::description := 'Round down to the nearest integer.'; SET volatility := 'Immutable'; USING SQL 'SELECT floor("x");'; }; CREATE FUNCTION std::math::floor(x: std::bigint) -> std::bigint { CREATE ANNOTATION std::description := 'Round down to the nearest integer.'; SET volatility := 'Immutable'; USING SQL 'SELECT "x";' }; CREATE FUNCTION std::math::floor(x: std::decimal) -> std::decimal { CREATE ANNOTATION std::description := 'Round down to the nearest integer.'; SET volatility := 'Immutable'; USING SQL 'SELECT floor("x");'; }; CREATE FUNCTION std::math::exp(x: std::int64) -> std::float64 { CREATE ANNOTATION std::description := 'Return the exponential of the input 
value.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'exp'; }; CREATE FUNCTION std::math::exp(x: std::float64) -> std::float64 { CREATE ANNOTATION std::description := 'Return the exponential of the input value.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'exp'; }; CREATE FUNCTION std::math::exp(x: std::decimal) -> std::decimal { CREATE ANNOTATION std::description := 'Return the exponential of the input value.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'exp'; }; CREATE FUNCTION std::math::ln(x: std::int64) -> std::float64 { CREATE ANNOTATION std::description := 'Return the natural logarithm of the input value.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'ln'; }; CREATE FUNCTION std::math::ln(x: std::float64) -> std::float64 { CREATE ANNOTATION std::description := 'Return the natural logarithm of the input value.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'ln'; }; CREATE FUNCTION std::math::ln(x: std::decimal) -> std::decimal { CREATE ANNOTATION std::description := 'Return the natural logarithm of the input value.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'ln'; }; CREATE FUNCTION std::math::lg(x: std::int64) -> std::float64 { CREATE ANNOTATION std::description := 'Return the base 10 logarithm of the input value.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'log'; }; CREATE FUNCTION std::math::lg(x: std::float64) -> std::float64 { CREATE ANNOTATION std::description := 'Return the base 10 logarithm of the input value.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'log'; }; CREATE FUNCTION std::math::lg(x: std::decimal) -> std::decimal { CREATE ANNOTATION std::description := 'Return the base 10 logarithm of the input value.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'log'; }; CREATE FUNCTION std::math::log(x: std::decimal, NAMED ONLY base: std::decimal) -> std::decimal { CREATE ANNOTATION std::description := 'Return the logarithm of the input value in the specified *base*.'; SET volatility := 
'Immutable'; USING SQL $$ SELECT log("base", "x") $$; }; CREATE FUNCTION std::math::sqrt(x: std::int64) -> std::float64 { CREATE ANNOTATION std::description := 'Return the square root of the input value.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'sqrt'; }; CREATE FUNCTION std::math::sqrt(x: std::float64) -> std::float64 { CREATE ANNOTATION std::description := 'Return the square root of the input value.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'sqrt'; }; CREATE FUNCTION std::math::sqrt(x: std::decimal) -> std::decimal { CREATE ANNOTATION std::description := 'Return the square root of the input value.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'sqrt'; }; # std::math::mean # ----------- # The mean function raises an error if the input is an empty set. On # all other inputs it returns the mean for that input set. CREATE FUNCTION std::math::mean(vals: SET OF std::decimal) -> std::decimal { CREATE ANNOTATION std::description := 'Return the arithmetic mean of the input set.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'avg'; SET error_on_null_result := 'invalid input to mean(): not ' ++ 'enough elements in input set'; }; CREATE FUNCTION std::math::mean(vals: SET OF std::int64) -> std::float64 { CREATE ANNOTATION std::description := 'Return the arithmetic mean of the input set.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'avg'; # SQL 'avg' returns numeric on integer inputs. 
SET force_return_cast := true; SET error_on_null_result := 'invalid input to mean(): not ' ++ 'enough elements in input set'; }; CREATE FUNCTION std::math::mean(vals: SET OF std::float64) -> std::float64 { CREATE ANNOTATION std::description := 'Return the arithmetic mean of the input set.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'avg'; SET error_on_null_result := 'invalid input to mean(): not ' ++ 'enough elements in input set'; }; # std::math::stddev # ------------ CREATE FUNCTION std::math::stddev(vals: SET OF std::decimal) -> std::decimal { CREATE ANNOTATION std::description := 'Return the sample standard deviation of the input set.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'stddev'; SET error_on_null_result := 'invalid input to stddev(): not ' ++ 'enough elements in input set'; }; CREATE FUNCTION std::math::stddev(vals: SET OF std::int64) -> std::float64 { CREATE ANNOTATION std::description := 'Return the sample standard deviation of the input set.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'stddev'; # SQL 'stddev' returns numeric on integer inputs. 
SET force_return_cast := true; SET error_on_null_result := 'invalid input to stddev(): not ' ++ 'enough elements in input set'; }; CREATE FUNCTION std::math::stddev(vals: SET OF std::float64) -> std::float64 { CREATE ANNOTATION std::description := 'Return the sample standard deviation of the input set.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'stddev'; SET error_on_null_result := 'invalid input to stddev(): not ' ++ 'enough elements in input set'; }; # std::math::stddev_pop # ---------------- CREATE FUNCTION std::math::stddev_pop(vals: SET OF std::decimal) -> std::decimal { CREATE ANNOTATION std::description := 'Return the population standard deviation of the input set.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'stddev_pop'; SET error_on_null_result := 'invalid input to stddev_pop(): not ' ++ 'enough elements in input set'; }; CREATE FUNCTION std::math::stddev_pop(vals: SET OF std::int64) -> std::float64 { CREATE ANNOTATION std::description := 'Return the population standard deviation of the input set.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'stddev_pop'; # SQL 'stddev_pop' returns numeric on integer inputs. 
SET force_return_cast := true; SET error_on_null_result := 'invalid input to stddev_pop(): not ' ++ 'enough elements in input set'; }; CREATE FUNCTION std::math::stddev_pop(vals: SET OF std::float64) -> std::float64 { CREATE ANNOTATION std::description := 'Return the population standard deviation of the input set.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'stddev_pop'; SET error_on_null_result := 'invalid input to stddev_pop(): not ' ++ 'enough elements in input set'; }; # std::math::var # -------------- CREATE FUNCTION std::math::var(vals: SET OF std::decimal) -> OPTIONAL std::decimal { CREATE ANNOTATION std::description := 'Return the sample variance of the input set.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'variance'; SET error_on_null_result := 'invalid input to var(): not ' ++ 'enough elements in input set'; }; CREATE FUNCTION std::math::var(vals: SET OF std::int64) -> OPTIONAL std::float64 { CREATE ANNOTATION std::description := 'Return the sample variance of the input set.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'variance'; # SQL 'variance' returns numeric on integer inputs. 
SET force_return_cast := true; SET error_on_null_result := 'invalid input to var(): not ' ++ 'enough elements in input set'; }; CREATE FUNCTION std::math::var(vals: SET OF std::float64) -> OPTIONAL std::float64 { CREATE ANNOTATION std::description := 'Return the sample variance of the input set.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'variance'; SET error_on_null_result := 'invalid input to var(): not ' ++ 'enough elements in input set'; }; # std::math::var_pop # ------------- CREATE FUNCTION std::math::var_pop(vals: SET OF std::decimal) -> OPTIONAL std::decimal { CREATE ANNOTATION std::description := 'Return the population variance of the input set.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'var_pop'; SET error_on_null_result := 'invalid input to var_pop(): not ' ++ 'enough elements in input set'; }; CREATE FUNCTION std::math::var_pop(vals: SET OF std::int64) -> OPTIONAL std::float64 { CREATE ANNOTATION std::description := 'Return the population variance of the input set.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'var_pop'; # SQL 'var_pop' returns numeric on integer inputs. 
SET force_return_cast := true; SET error_on_null_result := 'invalid input to var_pop(): not ' ++ 'enough elements in input set'; }; CREATE FUNCTION std::math::var_pop(vals: SET OF std::float64) -> OPTIONAL std::float64 { CREATE ANNOTATION std::description := 'Return the population variance of the input set.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'var_pop'; SET error_on_null_result := 'invalid input to var_pop(): not ' ++ 'enough elements in input set'; }; CREATE FUNCTION std::math::pi() -> std::float64 { CREATE ANNOTATION std::description := 'Return the constant value of pi.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'pi'; }; CREATE FUNCTION std::math::e() -> std::float64 { CREATE ANNOTATION std::description := 'Return the constant value of e.'; SET volatility := 'Immutable'; USING SQL 'SELECT exp(1);' }; CREATE FUNCTION std::math::acos(x: std::float64) -> std::float64 { CREATE ANNOTATION std::description := 'Return the inverse cosine of the input value.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'acos'; }; CREATE FUNCTION std::math::asin(x: std::float64) -> std::float64 { CREATE ANNOTATION std::description := 'Return the inverse sine of the input value.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'asin'; }; CREATE FUNCTION std::math::atan(x: std::float64) -> std::float64 { CREATE ANNOTATION std::description := 'Return the inverse tangent of the input value.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'atan'; }; CREATE FUNCTION std::math::atan2(y: std::float64, x: std::float64) -> std::float64 { CREATE ANNOTATION std::description := 'Return the inverse tangent of y/x.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'atan2'; }; CREATE FUNCTION std::math::cos(x: std::float64) -> std::float64 { CREATE ANNOTATION std::description := 'Return the cosine of the input value.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'cos'; }; CREATE FUNCTION std::math::cot(x: std::float64) -> std::float64 { 
CREATE ANNOTATION std::description := 'Return the cotangent of the input value.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'cot'; }; CREATE FUNCTION std::math::sin(x: std::float64) -> std::float64 { CREATE ANNOTATION std::description := 'Return the sine of the input value.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'sin'; }; CREATE FUNCTION std::math::tan(x: std::float64) -> std::float64 { CREATE ANNOTATION std::description := 'Return the tangent of the input value.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'tan'; }; ================================================ FILE: edb/lib/net.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2024-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# CREATE MODULE std::net; CREATE MODULE std::net::perm; CREATE PERMISSION std::net::perm::http_read; CREATE PERMISSION std::net::perm::http_write; CREATE SCALAR TYPE std::net::RequestState EXTENDING std::enum< Pending, InProgress, Completed, Failed >; CREATE SCALAR TYPE std::net::RequestFailureKind EXTENDING std::enum< NetworkError, Timeout >; CREATE MODULE std::net::http; CREATE SCALAR TYPE std::net::http::Method EXTENDING std::enum< `GET`, POST, PUT, `DELETE`, HEAD, OPTIONS, PATCH >; CREATE TYPE std::net::http::Response EXTENDING std::BaseObject { CREATE REQUIRED PROPERTY created_at: std::datetime; CREATE PROPERTY status: std::int16; CREATE PROPERTY headers: std::array>; CREATE PROPERTY body: std::bytes; CREATE ACCESS POLICY ap_read allow select using ( global std::net::perm::http_read ); CREATE ACCESS POLICY ap_write allow insert, update, delete using ( global std::net::perm::http_write ); }; CREATE TYPE std::net::http::ScheduledRequest extending std::BaseObject { CREATE REQUIRED PROPERTY state: std::net::RequestState; CREATE REQUIRED PROPERTY created_at: std::datetime; CREATE REQUIRED PROPERTY updated_at: std::datetime; CREATE PROPERTY failure: tuple; CREATE REQUIRED PROPERTY url: std::str; CREATE REQUIRED PROPERTY method: std::net::http::Method; CREATE PROPERTY headers: std::array>; CREATE PROPERTY body: std::bytes; CREATE LINK response: std::net::http::Response { CREATE CONSTRAINT exclusive; ON SOURCE DELETE DELETE TARGET; }; CREATE INDEX ON ((.state, .updated_at)); CREATE ACCESS POLICY ap_read allow select using ( global std::net::perm::http_read ); CREATE ACCESS POLICY ap_write allow insert, update, delete using ( global std::net::perm::http_write ); }; ALTER TYPE std::net::http::Response { CREATE LINK request := . 
> = {}, named only method: std::net::http::Method = std::net::http::Method.`GET`, ) -> std::net::http::ScheduledRequest { SET is_inlined := true; set required_permissions := { std::net::perm::http_write }; USING (( INSERT std::net::http::ScheduledRequest { url := url, method := method, headers := headers, body := body, created_at := std::datetime_of_statement(), updated_at := std::datetime_of_statement(), state := std::net::RequestState.Pending, } )); }; CREATE FUNCTION std::net::http::schedule_request( url: str, named only body: std::json, named only headers: optional std::array< std::tuple< name: std::str, value: std::str > > = {}, named only method: std::net::http::Method = std::net::http::Method.`GET`, ) -> std::net::http::ScheduledRequest { SET is_inlined := true; USING (( WITH has_content_type := any( std::str_lower(std::array_unpack(headers).0) = 'content-type' ), actual_headers := ( IF has_content_type THEN headers ELSE headers ++ [ (name := "Content-Type", value := "application/json") ] ) select std::net::http::schedule_request( url, body := to_bytes(body), headers := actual_headers, method := method, ) )); }; ================================================ FILE: edb/lib/pg.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2022-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# CREATE MODULE std::pg; CREATE ABSTRACT INDEX std::pg::hash { CREATE ANNOTATION std::description := 'Index based on a 32-bit hash derived from the indexed value.'; SET code := 'hash ((__col__))'; }; create index match for anytype using std::pg::hash; CREATE ABSTRACT INDEX std::pg::btree { CREATE ANNOTATION std::description := 'B-tree index can be used to retrieve data in sorted order.'; SET code := 'btree ((__col__) NULLS FIRST)'; }; create index match for anytype using std::pg::btree; CREATE ABSTRACT INDEX std::pg::gin { CREATE ANNOTATION std::description := 'GIN is an "inverted index" appropriate for data values that \ contain multiple elements, such as arrays and JSON.'; SET code := 'gin ((__col__))'; }; create index match for array using std::pg::gin; create index match for std::json using std::pg::gin; CREATE ABSTRACT INDEX std::pg::gist { CREATE ANNOTATION std::description := 'GIST index can be used to optimize searches involving ranges.'; SET code := 'gist ((__col__))'; }; create index match for array using std::pg::gist; create index match for range using std::pg::gist; create index match for multirange using std::pg::gist; CREATE ABSTRACT INDEX std::pg::spgist { CREATE ANNOTATION std::description := 'SP-GIST index can be used to optimize searches involving ranges \ and strings.'; SET code := 'spgist ((__col__))'; }; create index match for range using std::pg::spgist; create index match for std::str using std::pg::spgist; CREATE ABSTRACT INDEX std::pg::brin { CREATE ANNOTATION std::description := 'BRIN (Block Range INdex) index works with summaries about the values \ stored in consecutive physical block ranges in the database.'; SET code := 'brin ((__col__))'; }; create index match for range using std::pg::brin; create index match for std::anyreal using std::pg::brin; create index match for std::bytes using std::pg::brin; create index match for std::str using std::pg::brin; create index match for std::uuid using std::pg::brin; create index match for 
std::datetime using std::pg::brin; create index match for std::duration using std::pg::brin; create index match for std::cal::local_datetime using std::pg::brin; create index match for std::cal::local_date using std::pg::brin; create index match for std::cal::local_time using std::pg::brin; create index match for std::cal::relative_duration using std::pg::brin; create index match for std::cal::date_duration using std::pg::brin; create scalar type std::pg::json extending std::anyscalar; create scalar type std::pg::timestamptz extending std::anycontiguous; create scalar type std::pg::timestamp extending std::anycontiguous; create scalar type std::pg::date extending std::anydiscrete; create scalar type std::pg::interval extending std::anycontiguous; ================================================ FILE: edb/lib/schema.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ## INTROSPECTION SCHEMA CREATE MODULE schema; CREATE SCALAR TYPE schema::Cardinality EXTENDING enum; CREATE SCALAR TYPE schema::TargetDeleteAction EXTENDING enum; CREATE SCALAR TYPE schema::SourceDeleteAction EXTENDING enum; CREATE SCALAR TYPE schema::OperatorKind EXTENDING enum; CREATE SCALAR TYPE schema::Volatility EXTENDING enum; CREATE SCALAR TYPE schema::ParameterKind EXTENDING enum; CREATE SCALAR TYPE schema::TypeModifier EXTENDING enum; CREATE SCALAR TYPE schema::AccessPolicyAction EXTENDING enum; CREATE SCALAR TYPE schema::AccessKind EXTENDING enum<`Select`, UpdateRead, UpdateWrite, `Delete`, `Insert`>; CREATE SCALAR TYPE schema::TriggerTiming EXTENDING enum; CREATE SCALAR TYPE schema::TriggerKind EXTENDING enum<`Update`, `Delete`, `Insert`>; CREATE SCALAR TYPE schema::TriggerScope EXTENDING enum; CREATE SCALAR TYPE schema::RewriteKind EXTENDING enum<`Update`, `Insert`>; CREATE SCALAR TYPE schema::MigrationGeneratedBy EXTENDING enum; CREATE SCALAR TYPE schema::IndexDeferrability EXTENDING enum; CREATE SCALAR TYPE schema::SplatStrategy EXTENDING enum; # Base type for all schema entities. CREATE ABSTRACT TYPE schema::Object EXTENDING std::BaseObject { CREATE REQUIRED PROPERTY name -> std::str; CREATE REQUIRED PROPERTY internal -> std::bool { SET default := false; }; CREATE REQUIRED PROPERTY builtin -> std::bool { SET default := false; }; CREATE PROPERTY computed_fields -> array; CREATE ACCESS POLICY not_internal ALLOW SELECT USING (not .internal); }; CREATE ABSTRACT TYPE schema::SubclassableObject EXTENDING schema::Object { CREATE PROPERTY abstract -> std::bool { SET default := false; }; # Backwards compatibility. CREATE PROPERTY is_abstract := .abstract; # Backwards compatibility. (But will maybe become a real property one day.) CREATE PROPERTY final := false; # Backwards compatibility. CREATE PROPERTY is_final := .final; }; # Base type for all *types*. 
CREATE ABSTRACT TYPE schema::Type EXTENDING schema::SubclassableObject; CREATE TYPE schema::PseudoType EXTENDING schema::Type; ALTER TYPE schema::Type { CREATE PROPERTY expr -> std::str; CREATE PROPERTY from_alias -> bool; # Backwards compatibility. CREATE PROPERTY is_from_alias := .from_alias; }; CREATE ABSTRACT LINK schema::reference { CREATE PROPERTY owned -> std::bool; # Backwards compatibility. CREATE PROPERTY is_owned := @owned; }; CREATE ABSTRACT LINK schema::ordered { CREATE PROPERTY index -> std::int64; }; CREATE TYPE schema::Module EXTENDING schema::Object; CREATE ABSTRACT TYPE schema::PrimitiveType EXTENDING schema::Type; CREATE ABSTRACT TYPE schema::CollectionType EXTENDING schema::PrimitiveType; CREATE TYPE schema::Array EXTENDING schema::CollectionType { CREATE REQUIRED LINK element_type -> schema::Type; CREATE PROPERTY dimensions -> array; }; CREATE TYPE schema::ArrayExprAlias EXTENDING schema::Array; CREATE TYPE schema::TupleElement EXTENDING std::BaseObject { CREATE REQUIRED LINK type -> schema::Type; CREATE PROPERTY name -> std::str; }; CREATE TYPE schema::Tuple EXTENDING schema::CollectionType { CREATE REQUIRED PROPERTY named -> bool; CREATE MULTI LINK element_types EXTENDING schema::ordered -> schema::TupleElement { CREATE CONSTRAINT std::exclusive; } }; CREATE TYPE schema::TupleExprAlias EXTENDING schema::Tuple; CREATE TYPE schema::Range EXTENDING schema::CollectionType { CREATE REQUIRED LINK element_type -> schema::Type; }; CREATE TYPE schema::RangeExprAlias EXTENDING schema::Range; CREATE TYPE schema::MultiRange EXTENDING schema::CollectionType { CREATE REQUIRED LINK element_type -> schema::Type; }; CREATE TYPE schema::MultiRangeExprAlias EXTENDING schema::MultiRange; CREATE TYPE schema::Delta EXTENDING schema::Object { CREATE MULTI LINK parents -> schema::Delta; }; CREATE ABSTRACT TYPE schema::AnnotationSubject EXTENDING schema::Object; CREATE TYPE schema::Annotation EXTENDING schema::AnnotationSubject { CREATE PROPERTY inheritable -> 
std::bool; }; ALTER TYPE schema::AnnotationSubject { CREATE MULTI LINK annotations EXTENDING schema::reference -> schema::Annotation { CREATE PROPERTY value -> std::str; ON TARGET DELETE ALLOW; }; }; CREATE ABSTRACT TYPE schema::InheritingObject EXTENDING schema::SubclassableObject { CREATE MULTI LINK bases EXTENDING schema::ordered -> schema::InheritingObject; CREATE MULTI LINK ancestors EXTENDING schema::ordered -> schema::InheritingObject; CREATE PROPERTY inherited_fields -> array; }; CREATE TYPE schema::Parameter EXTENDING schema::Object { CREATE REQUIRED LINK type -> schema::Type; CREATE REQUIRED PROPERTY typemod -> schema::TypeModifier; CREATE REQUIRED PROPERTY kind -> schema::ParameterKind; CREATE REQUIRED PROPERTY num -> std::int64; CREATE PROPERTY default -> std::str; }; CREATE ABSTRACT TYPE schema::CallableObject EXTENDING schema::AnnotationSubject { CREATE MULTI LINK params EXTENDING schema::ordered -> schema::Parameter { ON TARGET DELETE ALLOW; }; CREATE LINK return_type -> schema::Type; CREATE PROPERTY return_typemod -> schema::TypeModifier; }; CREATE ABSTRACT TYPE schema::VolatilitySubject EXTENDING schema::Object { CREATE PROPERTY volatility -> schema::Volatility { # NOTE: this default indicates the default value in the python # implementation, but is not itself a source of truth SET default := 'Volatile'; }; }; CREATE TYPE schema::Constraint EXTENDING schema::CallableObject, schema::InheritingObject { ALTER LINK params { CREATE PROPERTY value -> std::str; }; CREATE PROPERTY expr -> std::str; CREATE PROPERTY subjectexpr -> std::str; CREATE PROPERTY finalexpr -> std::str; CREATE PROPERTY errmessage -> std::str; CREATE PROPERTY delegated -> std::bool; CREATE PROPERTY except_expr -> std::str; }; CREATE ABSTRACT TYPE schema::ConsistencySubject EXTENDING schema::InheritingObject { CREATE MULTI LINK constraints EXTENDING schema::reference -> schema::Constraint { CREATE CONSTRAINT std::exclusive; ON TARGET DELETE ALLOW; }; }; ALTER TYPE schema::Constraint { 
CREATE LINK subject -> schema::ConsistencySubject; }; CREATE TYPE schema::Index EXTENDING schema::InheritingObject, schema::AnnotationSubject { CREATE PROPERTY expr -> std::str; CREATE PROPERTY except_expr -> std::str; CREATE PROPERTY deferrability -> schema::IndexDeferrability; CREATE PROPERTY deferred -> std::bool; CREATE PROPERTY active -> std::bool; CREATE PROPERTY build_concurrently -> std::bool; CREATE MULTI LINK params EXTENDING schema::ordered -> schema::Parameter { ON TARGET DELETE ALLOW; }; CREATE PROPERTY kwargs -> array>; }; CREATE ABSTRACT TYPE schema::Source EXTENDING schema::Object { CREATE MULTI LINK indexes EXTENDING schema::reference -> schema::Index { CREATE CONSTRAINT std::exclusive; ON TARGET DELETE ALLOW; }; }; CREATE ABSTRACT TYPE schema::Pointer EXTENDING schema::ConsistencySubject, schema::AnnotationSubject { CREATE PROPERTY cardinality -> schema::Cardinality; CREATE PROPERTY required -> std::bool; CREATE PROPERTY readonly -> std::bool; CREATE PROPERTY default -> std::str; CREATE PROPERTY expr -> std::str; CREATE PROPERTY secret -> std::bool; CREATE PROPERTY splat_strategy -> schema::SplatStrategy; CREATE PROPERTY linkful -> std::bool; CREATE PROPERTY protected -> std::bool; }; CREATE TYPE schema::AccessPolicy EXTENDING schema::InheritingObject, schema::AnnotationSubject; CREATE TYPE schema::Trigger EXTENDING schema::InheritingObject, schema::AnnotationSubject; CREATE TYPE schema::Rewrite EXTENDING schema::InheritingObject, schema::AnnotationSubject; ALTER TYPE schema::Source { CREATE MULTI LINK pointers EXTENDING schema::reference -> schema::Pointer { CREATE CONSTRAINT std::exclusive; ON TARGET DELETE ALLOW; }; }; CREATE TYPE schema::Alias EXTENDING schema::AnnotationSubject { CREATE REQUIRED PROPERTY expr -> std::str; # This link is DEFINITELY not optional. This works around # compiler weirdness that forces the DEFERRED RESTRICT # behavior, which prohibits required-ness. 
CREATE OPTIONAL LINK type -> schema::Type { ON TARGET DELETE DEFERRED RESTRICT; }; }; CREATE TYPE schema::ScalarType EXTENDING schema::PrimitiveType, schema::ConsistencySubject, schema::AnnotationSubject { CREATE PROPERTY default -> std::str; CREATE PROPERTY enum_values -> array; CREATE PROPERTY arg_values -> array; }; CREATE FUNCTION std::sequence_reset( seq: schema::ScalarType, value: std::int64, ) -> std::int64 { SET volatility := 'Volatile'; USING SQL $$ SELECT pg_catalog.setval( pg_catalog.quote_ident(sn.schema) || '.' || pg_catalog.quote_ident(sn.name), "value", true ) FROM ROWS FROM (edgedb_VER.get_user_sequence_backend_name("seq")) AS sn(schema text, name text) $$; }; CREATE FUNCTION std::sequence_reset( seq: schema::ScalarType, ) -> std::int64 { SET volatility := 'Volatile'; USING SQL $$ SELECT pg_catalog.setval( pg_catalog.quote_ident(sn.schema) || '.' || pg_catalog.quote_ident(sn.name), s.start_value, false ) FROM ROWS FROM (edgedb_VER.get_user_sequence_backend_name("seq")) AS sn(schema text, name text), LATERAL ( SELECT start_value FROM pg_catalog.pg_sequences WHERE schemaname = sn.schema AND sequencename = sn.name ) AS s $$; }; CREATE FUNCTION std::sequence_next( seq: schema::ScalarType, ) -> std::int64 { SET volatility := 'Volatile'; USING SQL $$ SELECT pg_catalog.nextval( pg_catalog.quote_ident(sn.schema) || '.' || pg_catalog.quote_ident(sn.name) ) FROM ROWS FROM (edgedb_VER.get_user_sequence_backend_name("seq")) AS sn(schema text, name text) $$; }; CREATE TYPE schema::ObjectType EXTENDING schema::Source, schema::ConsistencySubject, schema::InheritingObject, schema::Type, schema::AnnotationSubject; ALTER TYPE std::BaseObject { # N.B: Since __type__ is uniquely determined by the type of the # source object, as a special-case optimization we do not actually # store it in the database. Instead, we inject it into the views # we use to implement inheritance and inject it in the compiler # when operating on tables directly. 
CREATE REQUIRED LINK __type__ -> schema::ObjectType { SET readonly := True; SET protected := True; }; }; ALTER TYPE schema::ObjectType { CREATE MULTI LINK union_of -> schema::ObjectType; CREATE MULTI LINK intersection_of -> schema::ObjectType; CREATE MULTI LINK access_policies EXTENDING schema::reference -> schema::AccessPolicy { CREATE CONSTRAINT std::exclusive; ON TARGET DELETE ALLOW; }; CREATE MULTI LINK triggers EXTENDING schema::reference -> schema::Trigger { CREATE CONSTRAINT std::exclusive; ON TARGET DELETE ALLOW; }; CREATE PROPERTY compound_type := ( EXISTS .union_of OR EXISTS .intersection_of ); # Backwards compatibility. CREATE PROPERTY is_compound_type := .compound_type; }; ALTER TYPE schema::AccessPolicy { CREATE REQUIRED LINK subject -> schema::ObjectType; CREATE MULTI PROPERTY access_kinds -> schema::AccessKind; CREATE PROPERTY condition -> std::str; CREATE REQUIRED PROPERTY action -> schema::AccessPolicyAction; CREATE PROPERTY expr -> std::str; CREATE PROPERTY errmessage -> std::str; }; ALTER TYPE schema::Trigger { CREATE REQUIRED LINK subject -> schema::ObjectType; CREATE REQUIRED PROPERTY timing -> schema::TriggerTiming; CREATE MULTI PROPERTY kinds -> schema::TriggerKind; CREATE REQUIRED PROPERTY scope -> schema::TriggerScope; CREATE PROPERTY expr -> std::str; CREATE PROPERTY condition -> std::str; }; ALTER TYPE schema::Rewrite { CREATE REQUIRED LINK subject -> schema::Pointer; CREATE REQUIRED PROPERTY kind -> schema::TriggerKind; CREATE REQUIRED PROPERTY expr -> std::str; }; CREATE TYPE schema::Link EXTENDING schema::Pointer, schema::Source; CREATE TYPE schema::Property EXTENDING schema::Pointer; ALTER TYPE schema::Pointer { CREATE LINK source -> schema::Source; CREATE LINK target -> schema::Type; CREATE MULTI LINK rewrites EXTENDING schema::reference -> schema::Rewrite { CREATE CONSTRAINT std::exclusive; ON TARGET DELETE ALLOW; }; }; ALTER TYPE schema::Link { ALTER LINK target SET TYPE schema::ObjectType USING (.target[IS schema::ObjectType]); 
CREATE MULTI LINK properties := .pointers[IS schema::Property]; CREATE PROPERTY on_target_delete -> schema::TargetDeleteAction; CREATE PROPERTY on_source_delete -> schema::SourceDeleteAction; }; ALTER TYPE schema::ObjectType { CREATE MULTI LINK links := .pointers[IS schema::Link]; CREATE MULTI LINK properties := .pointers[IS schema::Property]; }; CREATE TYPE schema::Global EXTENDING schema::AnnotationSubject { # This is most definitely NOT optional. It works around some # compiler weirdness which requires the on target delete deferred restrict CREATE OPTIONAL LINK target -> schema::Type { ON TARGET DELETE DEFERRED RESTRICT; }; CREATE PROPERTY required -> std::bool; CREATE PROPERTY cardinality -> schema::Cardinality; CREATE PROPERTY expr -> std::str; CREATE PROPERTY default -> std::str; }; CREATE TYPE schema::Permission EXTENDING schema::AnnotationSubject; CREATE TYPE schema::Function EXTENDING schema::CallableObject, schema::VolatilitySubject { CREATE PROPERTY preserves_optionality -> std::bool { SET default := false; }; CREATE PROPERTY body -> str; CREATE REQUIRED PROPERTY language -> str; CREATE MULTI LINK used_globals EXTENDING schema::ordered -> schema::Global; CREATE MULTI LINK used_permissions EXTENDING schema::ordered -> schema::Permission; CREATE MULTI LINK required_permissions EXTENDING schema::ordered -> schema::Permission; }; CREATE TYPE schema::Operator EXTENDING schema::CallableObject, schema::VolatilitySubject { CREATE PROPERTY operator_kind -> schema::OperatorKind; CREATE PROPERTY abstract -> std::bool { SET default := false; }; # Backwards compatibility. 
CREATE PROPERTY is_abstract := .abstract; }; CREATE TYPE schema::Cast EXTENDING schema::AnnotationSubject, schema::VolatilitySubject { CREATE LINK from_type -> schema::Type; CREATE LINK to_type -> schema::Type; CREATE PROPERTY allow_implicit -> std::bool; CREATE PROPERTY allow_assignment -> std::bool; }; CREATE TYPE schema::Migration EXTENDING schema::AnnotationSubject { CREATE MULTI LINK parents -> schema::Migration; CREATE REQUIRED PROPERTY script -> str; CREATE PROPERTY sdl -> str; CREATE PROPERTY message -> str; CREATE PROPERTY generated_by -> schema::MigrationGeneratedBy; }; # The package link is added in sys.edgeql CREATE TYPE schema::Extension EXTENDING schema::AnnotationSubject; CREATE TYPE schema::FutureBehavior EXTENDING schema::Object; ================================================ FILE: edb/lib/std/00-prelude.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # CREATE MODULE std; CREATE MODULE ext; ================================================ FILE: edb/lib/std/10-scalars.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # CREATE PSEUDO TYPE `anytype`; CREATE PSEUDO TYPE `anytuple`; CREATE PSEUDO TYPE `anyobject`; CREATE ABSTRACT SCALAR TYPE std::anyscalar; CREATE ABSTRACT SCALAR TYPE std::anypoint EXTENDING std::anyscalar; CREATE ABSTRACT SCALAR TYPE std::anydiscrete EXTENDING std::anypoint; CREATE ABSTRACT SCALAR TYPE std::anycontiguous EXTENDING std::anypoint; CREATE SCALAR TYPE std::bool EXTENDING std::anyscalar; CREATE SCALAR TYPE std::bytes EXTENDING std::anyscalar; CREATE SCALAR TYPE std::uuid EXTENDING std::anyscalar; CREATE SCALAR TYPE std::str EXTENDING std::anyscalar; CREATE SCALAR TYPE std::json EXTENDING std::anyscalar; CREATE SCALAR TYPE std::datetime EXTENDING std::anycontiguous; CREATE SCALAR TYPE std::duration EXTENDING std::anycontiguous; CREATE ABSTRACT SCALAR TYPE std::anyreal EXTENDING std::anyscalar; CREATE ABSTRACT SCALAR TYPE std::anyint EXTENDING std::anyreal; CREATE SCALAR TYPE std::int16 EXTENDING std::anyint; CREATE SCALAR TYPE std::int32 EXTENDING std::anyint, std::anydiscrete; CREATE SCALAR TYPE std::int64 EXTENDING std::anyint, std::anydiscrete; CREATE ABSTRACT SCALAR TYPE std::anyfloat EXTENDING std::anyreal, std::anycontiguous; CREATE SCALAR TYPE std::float32 EXTENDING std::anyfloat; CREATE SCALAR TYPE std::float64 EXTENDING std::anyfloat; CREATE ABSTRACT SCALAR TYPE std::anynumeric EXTENDING std::anyreal; CREATE SCALAR TYPE std::decimal EXTENDING std::anynumeric, std::anycontiguous; CREATE SCALAR TYPE std::bigint 
EXTENDING std::anynumeric, std::anyint; CREATE ABSTRACT SCALAR TYPE std::sequence EXTENDING std::int64; CREATE ABSTRACT SCALAR TYPE std::anyenum EXTENDING std::anyscalar; ================================================ FILE: edb/lib/std/15-attrs.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ## Generic annotations. CREATE ABSTRACT ANNOTATION std::description; ALTER ABSTRACT ANNOTATION std::description { CREATE ANNOTATION std::description := 'A short documentation string.'; }; CREATE ABSTRACT ANNOTATION std::title { CREATE ANNOTATION std::description := 'A human-readable name.'; }; CREATE ABSTRACT ANNOTATION std::deprecated { CREATE ANNOTATION std::description := 'A marker that an item is deprecated.'; }; CREATE ABSTRACT ANNOTATION std::identifier; CREATE MODULE std::lang; CREATE MODULE std::lang::go; CREATE ABSTRACT ANNOTATION std::lang::go::type; CREATE MODULE std::lang::js; CREATE ABSTRACT ANNOTATION std::lang::js::type; CREATE MODULE std::lang::py; CREATE ABSTRACT ANNOTATION std::lang::py::type; CREATE MODULE std::lang::rs; CREATE ABSTRACT ANNOTATION std::lang::rs::type; ================================================ FILE: edb/lib/std/17-abstractops.edgeql ================================================ # # This source file is part of the EdgeDB open source project. 
# # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # All EdgeDB types support ordering and comparison. # The below definitions of abstract operators declare this fact # for the benefit of generic expressions (e.g. in abstract constraints). CREATE ABSTRACT INFIX OPERATOR std::`>=` (l: anytype, r: anytype) -> std::bool { CREATE ANNOTATION std::identifier := 'ge'; }; CREATE ABSTRACT INFIX OPERATOR std::`>` (l: anytype, r: anytype) -> std::bool { CREATE ANNOTATION std::identifier := 'gt'; }; CREATE ABSTRACT INFIX OPERATOR std::`<=` (l: anytype, r: anytype) -> std::bool { CREATE ANNOTATION std::identifier := 'le'; }; CREATE ABSTRACT INFIX OPERATOR std::`<` (l: anytype, r: anytype) -> std::bool { CREATE ANNOTATION std::identifier := 'lt'; }; CREATE ABSTRACT INFIX OPERATOR std::`=` (l: anytype, r: anytype) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; }; CREATE ABSTRACT INFIX OPERATOR std::`?=` (l: OPTIONAL anytype, r: OPTIONAL anytype) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; }; CREATE ABSTRACT INFIX OPERATOR std::`!=` (l: anytype, r: anytype) -> std::bool { CREATE ANNOTATION std::identifier := 'ne'; }; CREATE ABSTRACT INFIX OPERATOR std::`?!=` (l: OPTIONAL anytype, r: OPTIONAL anytype) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_neq'; }; ================================================ FILE: edb/lib/std/20-genericfuncs.edgeql ================================================ 
# # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ## Fundamental polymorphic functions # std::assert_single -- runtime cardinality assertion (upper bound) # ----------------------------------------------------------------- CREATE FUNCTION std::assert_single( input: SET OF anytype, NAMED ONLY message: OPTIONAL str = {}, ) -> OPTIONAL anytype { CREATE ANNOTATION std::description := "Check that the input set contains at most one element, raise CardinalityViolationError otherwise."; SET volatility := 'Immutable'; SET preserves_optionality := true; USING SQL EXPRESSION; }; # std::assert_exists -- runtime cardinality assertion (lower bound) # ----------------------------------------------------------------- CREATE FUNCTION std::assert_exists( input: SET OF anytype, NAMED ONLY message: OPTIONAL str = {}, ) -> SET OF anytype { CREATE ANNOTATION std::description := "Check that the input set contains at least one element, raise CardinalityViolationError otherwise."; SET volatility := 'Immutable'; SET preserves_upper_cardinality := true; USING SQL EXPRESSION; }; # std::assert_distinct -- runtime multiplicity assertion # ------------------------------------------------------ CREATE FUNCTION std::assert_distinct( input: SET OF anytype, NAMED ONLY message: OPTIONAL str = {}, ) -> SET OF anytype { CREATE ANNOTATION std::description := "Check that the input set is a 
proper set, i.e. all elements are unique"; SET volatility := 'Immutable'; SET preserves_optionality := true; SET preserves_upper_cardinality := true; USING SQL EXPRESSION; }; # std::assert -- boolean assertion # -------------------------------- CREATE FUNCTION std::assert( input: bool, NAMED ONLY message: OPTIONAL str = {}, ) -> bool { CREATE ANNOTATION std::description := "Assert that a boolean value is true."; SET volatility := 'Stable'; USING SQL $$ SELECT ( edgedb_VER.raise_on_null( nullif("input", false), 'cardinality_violation', "constraint" => 'std::assert', msg => coalesce("message", 'assertion failed') ) ) $$; }; # std::materialized -- force materialization of a set # --------------------------------------------------- CREATE FUNCTION std::materialized( input: anytype, ) -> anytype { CREATE ANNOTATION std::description := "Force materialization of a set."; SET volatility := 'Volatile'; USING SQL EXPRESSION; }; # std::len # -------- CREATE FUNCTION std::len(str: std::str) -> std::int64 { CREATE ANNOTATION std::description := 'A polymorphic function to calculate a "length" of its first argument.'; SET volatility := 'Immutable'; USING SQL $$ SELECT char_length("str")::bigint $$; }; CREATE FUNCTION std::len(bytes: std::bytes) -> std::int64 { CREATE ANNOTATION std::description := 'A polymorphic function to calculate a "length" of its first argument.'; SET volatility := 'Immutable'; USING SQL $$ SELECT length("bytes")::bigint $$; }; CREATE FUNCTION std::len(array: array) -> std::int64 { CREATE ANNOTATION std::description := 'A polymorphic function to calculate a "length" of its first argument.'; SET volatility := 'Immutable'; USING SQL $$ SELECT cardinality("array")::bigint $$; }; # std::sum # -------- CREATE FUNCTION std::sum(s: SET OF std::bigint) -> std::bigint { CREATE ANNOTATION std::description := 'Return the arithmetic sum of values in a set.'; SET volatility := 'Immutable'; SET initial_value := 0; SET force_return_cast := true; USING SQL
FUNCTION 'sum'; }; CREATE FUNCTION std::sum(s: SET OF std::decimal) -> std::decimal { CREATE ANNOTATION std::description := 'Return the arithmetic sum of values in a set.'; SET volatility := 'Immutable'; SET initial_value := 0; USING SQL FUNCTION 'sum'; }; CREATE FUNCTION std::sum(s: SET OF std::int32) -> std::int64 { CREATE ANNOTATION std::description := 'Return the arithmetic sum of values in a set.'; SET volatility := 'Immutable'; SET initial_value := 0; SET force_return_cast := true; USING SQL FUNCTION 'sum'; }; CREATE FUNCTION std::sum(s: SET OF std::int64) -> std::int64 { CREATE ANNOTATION std::description := 'Return the arithmetic sum of values in a set.'; SET volatility := 'Immutable'; SET initial_value := 0; SET force_return_cast := true; USING SQL FUNCTION 'sum'; }; CREATE FUNCTION std::sum(s: SET OF std::float32) -> std::float32 { CREATE ANNOTATION std::description := 'Return the arithmetic sum of values in a set.'; SET volatility := 'Immutable'; SET initial_value := 0; USING SQL FUNCTION 'sum'; }; CREATE FUNCTION std::sum(s: SET OF std::float64) -> std::float64 { CREATE ANNOTATION std::description := 'Return the arithmetic sum of values in a set.'; SET volatility := 'Immutable'; SET initial_value := 0; USING SQL FUNCTION 'sum'; }; # std::count # ---------- CREATE FUNCTION std::count(s: SET OF anytype) -> std::int64 { CREATE ANNOTATION std::description := 'Return the number of elements in a set.'; SET volatility := 'Immutable'; SET initial_value := 0; USING SQL FUNCTION 'count'; }; # std::random # ----------- CREATE FUNCTION std::random() -> std::float64 { CREATE ANNOTATION std::description := 'Return a pseudo-random number in the range `0.0 <= x < 1.0`'; SET volatility := 'Volatile'; USING SQL FUNCTION 'random'; }; # std::min # -------- CREATE FUNCTION std::min(vals: SET OF anytype) -> OPTIONAL anytype { CREATE ANNOTATION std::description := 'Return the smallest value of the input set.'; SET volatility := 'Immutable'; SET fallback := true; SET 
preserves_optionality := true; USING SQL EXPRESSION; }; # Postgres only implements min and max for specific scalars and their # respective arrays, but in EdgeDB every type is orderable and so # minimum and maximum value can be determined for all types. The # general catch-all using `anytype` above is valid for all types, but # it is somewhat slower than the specialized natively implemented min # and max aggregates. So for the types that Postgres supports, we want # to use the more specialized implementation. # # Turns out that the min/max implementation for arrays is not # noticeably faster than the fallback we use, so there's no # specialized version of it in the polymorphic implementations. CREATE FUNCTION std::min(vals: SET OF anyreal) -> OPTIONAL anyreal { CREATE ANNOTATION std::description := 'Return the smallest value of the input set.'; SET volatility := 'Immutable'; SET preserves_optionality := true; USING SQL FUNCTION 'min'; }; CREATE FUNCTION std::min(vals: SET OF anyenum) -> OPTIONAL anyenum { CREATE ANNOTATION std::description := 'Return the smallest value of the input set.'; SET volatility := 'Immutable'; SET preserves_optionality := true; USING SQL FUNCTION 'min'; }; CREATE FUNCTION std::min(vals: SET OF str) -> OPTIONAL str { CREATE ANNOTATION std::description := 'Return the smallest value of the input set.'; SET volatility := 'Immutable'; SET preserves_optionality := true; USING SQL FUNCTION 'min'; }; CREATE FUNCTION std::min(vals: SET OF datetime) -> OPTIONAL datetime { CREATE ANNOTATION std::description := 'Return the smallest value of the input set.'; SET volatility := 'Immutable'; SET force_return_cast := true; SET preserves_optionality := true; USING SQL FUNCTION 'min'; }; CREATE FUNCTION std::min(vals: SET OF duration) -> OPTIONAL duration { CREATE ANNOTATION std::description := 'Return the smallest value of the input set.'; SET volatility := 'Immutable'; SET force_return_cast := true; SET preserves_optionality := true; USING SQL FUNCTION 
'min'; }; # std::max # -------- CREATE FUNCTION std::max(vals: SET OF anytype) -> OPTIONAL anytype { CREATE ANNOTATION std::description := 'Return the greatest value of the input set.'; SET volatility := 'Immutable'; SET fallback := true; SET preserves_optionality := true; USING SQL EXPRESSION; }; # Postgres only implements min and max for specific scalars and their # respective arrays, but in EdgeDB every type is orderable and so # minimum and maximum value can be determined for all types. The # general catch-all using `anytype` above is valid for all types, but # it is somewhat slower than the specialized natively implemented min # and max aggregates. So for the types that Postgres supports, we want # to use the more specialized implementation. # # Turns out that the min/max implementation for arrays is not # noticeably faster than the fallback we use, so there's no # specialized version of it in the polymorphic implementations. CREATE FUNCTION std::max(vals: SET OF anyreal) -> OPTIONAL anyreal { CREATE ANNOTATION std::description := 'Return the greatest value of the input set.'; SET volatility := 'Immutable'; SET preserves_optionality := true; USING SQL FUNCTION 'max'; }; CREATE FUNCTION std::max(vals: SET OF anyenum) -> OPTIONAL anyenum { CREATE ANNOTATION std::description := 'Return the greatest value of the input set.'; SET volatility := 'Immutable'; SET preserves_optionality := true; USING SQL FUNCTION 'max'; }; CREATE FUNCTION std::max(vals: SET OF str) -> OPTIONAL str { CREATE ANNOTATION std::description := 'Return the greatest value of the input set.'; SET volatility := 'Immutable'; SET preserves_optionality := true; USING SQL FUNCTION 'max'; }; CREATE FUNCTION std::max(vals: SET OF datetime) -> OPTIONAL datetime { CREATE ANNOTATION std::description := 'Return the greatest value of the input set.'; SET volatility := 'Immutable'; SET force_return_cast := true; SET preserves_optionality := true; USING SQL FUNCTION 'max'; }; CREATE FUNCTION std::max(vals: 
SET OF duration) -> OPTIONAL duration
{
    CREATE ANNOTATION std::description :=
        'Return the greatest value of the input set.';
    SET volatility := 'Immutable';
    SET force_return_cast := true;
    SET preserves_optionality := true;
    USING SQL FUNCTION 'max';
};


# std::all
# --------

CREATE FUNCTION
std::all(vals: SET OF std::bool) -> std::bool
{
    CREATE ANNOTATION std::description :=
        'Generalized boolean `AND` applied to the set of *values*.';
    SET volatility := 'Immutable';
    SET initial_value := True;
    USING SQL FUNCTION 'bool_and';
};


# std::any
# --------

CREATE FUNCTION
std::any(vals: SET OF std::bool) -> std::bool
{
    CREATE ANNOTATION std::description :=
        'Generalized boolean `OR` applied to the set of *values*.';
    SET volatility := 'Immutable';
    SET initial_value := False;
    USING SQL FUNCTION 'bool_or';
};


# std::enumerate
# --------------

CREATE FUNCTION
std::enumerate(
    vals: SET OF anytype
) -> SET OF tuple<std::int64, anytype>
{
    CREATE ANNOTATION std::description :=
        'Return a set of tuples of the form `(index, element)`.';
    SET volatility := 'Immutable';
    SET preserves_optionality := true;
    SET preserves_upper_cardinality := true;
    USING SQL EXPRESSION;
};


# std::round
# ----------

CREATE FUNCTION
std::round(val: std::int64) -> std::int64
{
    CREATE ANNOTATION std::description :=
        'Round to the nearest value.';
    SET volatility := 'Immutable';
    USING SQL $$
    SELECT "val"
    $$;
};


CREATE FUNCTION
std::round(val: std::float64) -> std::float64
{
    CREATE ANNOTATION std::description :=
        'Round to the nearest value.';
    SET volatility := 'Immutable';
    USING SQL $$
    SELECT round("val")
    $$;
};


CREATE FUNCTION
std::round(val: std::bigint) -> std::bigint
{
    CREATE ANNOTATION std::description :=
        'Round to the nearest value.';
    SET volatility := 'Immutable';
    USING SQL $$
    SELECT "val";
    $$;
};


CREATE FUNCTION
std::round(val: std::decimal) -> std::decimal
{
    CREATE ANNOTATION std::description :=
        'Round to the nearest value.';
    SET volatility := 'Immutable';
    USING SQL $$
    SELECT round("val");
    $$;
};


CREATE FUNCTION
std::round(val:
std::decimal, d: std::int64) -> std::decimal
{
    CREATE ANNOTATION std::description :=
        'Round to the nearest value.';
    SET volatility := 'Immutable';
    USING SQL $$
    SELECT round("val", "d"::int4)
    $$;
};


# std::contains
# -------------

CREATE FUNCTION
std::contains(haystack: std::str, needle: std::str) -> std::bool
{
    CREATE ANNOTATION std::description :=
        'A polymorphic function to test if a sequence contains a certain element.';
    SET volatility := 'Immutable';
    USING SQL $$
    SELECT (
        -- There was a regression in 12.0 (fixed in 12.1): strpos
        -- started to report 0 for empty search strings:
        -- https://postgr.es/m/CADT4RqAz7oN4vkPir86Kg1_mQBmBxCp-L_=9vRpgSNPJf0KRkw@mail.gmail.com
        --
        -- This CASE..WHEN fixes this edge case.
        CASE
            WHEN "needle" = '' THEN 1
            ELSE strpos("haystack", "needle")
        END
    ) != 0
    $$;
};


CREATE FUNCTION
std::contains(haystack: std::bytes, needle: std::bytes) -> std::bool
{
    CREATE ANNOTATION std::description :=
        'A polymorphic function to test if a sequence contains a certain element.';
    SET volatility := 'Immutable';
    USING SQL $$
    SELECT position("needle" in "haystack") != 0
    $$;
};


CREATE FUNCTION
std::contains(haystack: array<anytype>, needle: anytype) -> std::bool
{
    CREATE ANNOTATION std::description :=
        'A polymorphic function to test if a sequence contains a certain element.';
    SET volatility := 'Immutable';
    # Postgres only manages to inline this function if it isn't marked strict,
    # and we want it to be inlined so that std::pg::gin indexes work with it.
    SET impl_is_strict := false;
    USING SQL $$
    SELECT "haystack" @> ARRAY["needle"]
    $$;
};


CREATE FUNCTION
std::contains(haystack: json, needle: json) -> std::bool
{
    CREATE ANNOTATION std::description :=
        'A polymorphic function to test if one JSON value contains another JSON value.';
    SET volatility := 'Immutable';
    # Postgres only manages to inline this function if it isn't marked strict,
    # and we want it to be inlined so that std::pg::gin indexes work with it.
    SET impl_is_strict := false;
    USING SQL $$
    SELECT "haystack" @> "needle"
    $$;
};


# std::find
# ---------

CREATE FUNCTION
std::find(haystack: std::str, needle: std::str) -> std::int64
{
    CREATE ANNOTATION std::description :=
        'A polymorphic function to find index of an element in a sequence.';
    SET volatility := 'Immutable';
    USING SQL $$
    SELECT (
        -- There was a regression in 12.0 (fixed in 12.1): strpos
        -- started to report 0 for empty search strings:
        -- https://postgr.es/m/CADT4RqAz7oN4vkPir86Kg1_mQBmBxCp-L_=9vRpgSNPJf0KRkw@mail.gmail.com
        --
        -- This CASE..WHEN fixes this edge case.
        CASE
            WHEN "needle" = '' THEN 0
            ELSE strpos("haystack", "needle") - 1
        END
    )::int8
    $$;
};


CREATE FUNCTION
std::find(haystack: std::bytes, needle: std::bytes) -> std::int64
{
    CREATE ANNOTATION std::description :=
        'A polymorphic function to find index of an element in a sequence.';
    SET volatility := 'Immutable';
    USING SQL $$
    SELECT (position("needle" in "haystack") - 1)::int8
    $$;
};


CREATE FUNCTION
std::find(haystack: array<anytype>, needle: anytype,
          from_pos: std::int64=0) -> std::int64
{
    CREATE ANNOTATION std::description :=
        'A polymorphic function to find index of an element in a sequence.';
    SET volatility := 'Immutable';
    USING SQL $$
    SELECT COALESCE(
        array_position("haystack", "needle", ("from_pos"::int4 + 1)::int4) - 1,
        -1)::int8
    $$;
};


# Generic comparison operators
# ----------------------------

CREATE INFIX OPERATOR
std::`=` (l: anyscalar, r: anyscalar) -> std::bool {
    CREATE ANNOTATION std::identifier := 'eq';
    CREATE ANNOTATION std::description := 'Compare two values for equality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::=';
    SET negator := 'std::!=';
    USING SQL OPERATOR r'=';
};


CREATE INFIX OPERATOR
std::`?=` (l: OPTIONAL anyscalar, r: OPTIONAL anyscalar) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_eq';
    CREATE ANNOTATION std::description :=
        'Compare two (potentially empty) values for equality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};


CREATE INFIX
OPERATOR std::`!=` (l: anyscalar, r: anyscalar) -> std::bool { CREATE ANNOTATION std::identifier := 'neq'; CREATE ANNOTATION std::description := 'Compare two values for inequality.'; SET volatility := 'Immutable'; SET commutator := 'std::!='; SET negator := 'std::='; USING SQL OPERATOR r'<>'; }; CREATE INFIX OPERATOR std::`?!=` (l: OPTIONAL anyscalar, r: OPTIONAL anyscalar) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_neq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`>=` (l: anyscalar, r: anyscalar) -> std::bool { CREATE ANNOTATION std::identifier := 'gte'; CREATE ANNOTATION std::description := 'Greater than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::<='; SET negator := 'std::<'; USING SQL OPERATOR '>='; }; CREATE INFIX OPERATOR std::`>` (l: anyscalar, r: anyscalar) -> std::bool { CREATE ANNOTATION std::identifier := 'gt'; CREATE ANNOTATION std::description := 'Greater than.'; SET volatility := 'Immutable'; SET commutator := 'std::<'; SET negator := 'std::<='; USING SQL OPERATOR '>'; }; CREATE INFIX OPERATOR std::`<=` (l: anyscalar, r: anyscalar) -> std::bool { CREATE ANNOTATION std::identifier := 'lte'; CREATE ANNOTATION std::description := 'Less than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::>='; SET negator := 'std::>'; USING SQL OPERATOR '<='; }; CREATE INFIX OPERATOR std::`<` (l: anyscalar, r: anyscalar) -> std::bool { CREATE ANNOTATION std::identifier := 'lt'; CREATE ANNOTATION std::description := 'Less than.'; SET volatility := 'Immutable'; SET commutator := 'std::>'; SET negator := 'std::>='; USING SQL OPERATOR '<'; }; CREATE INFIX OPERATOR std::`=` (l: anytuple, r: anytuple) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; SET recursive := true; SET 
commutator := 'std::='; SET negator := 'std::!='; USING SQL OPERATOR '='; }; CREATE INFIX OPERATOR std::`?=` (l: OPTIONAL anytuple, r: OPTIONAL anytuple) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; SET recursive := true; }; CREATE INFIX OPERATOR std::`!=` (l: anytuple, r: anytuple) -> std::bool { CREATE ANNOTATION std::identifier := 'neq'; CREATE ANNOTATION std::description := 'Compare two values for inequality.'; SET volatility := 'Immutable'; SET recursive := true; SET commutator := 'std::!='; SET negator := 'std::='; USING SQL OPERATOR '<>'; }; CREATE INFIX OPERATOR std::`?!=` (l: OPTIONAL anytuple, r: OPTIONAL anytuple) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_neq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; SET recursive := true; }; CREATE INFIX OPERATOR std::`>=` (l: anytuple, r: anytuple) -> std::bool { CREATE ANNOTATION std::identifier := 'gte'; CREATE ANNOTATION std::description := 'Greater than or equal.'; SET volatility := 'Immutable'; SET recursive := true; SET commutator := 'std::<='; SET negator := 'std::<'; USING SQL OPERATOR '>='; }; CREATE INFIX OPERATOR std::`>` (l: anytuple, r: anytuple) -> std::bool { CREATE ANNOTATION std::identifier := 'gt'; CREATE ANNOTATION std::description := 'Greater than.'; SET volatility := 'Immutable'; SET recursive := true; SET commutator := 'std::<'; SET negator := 'std::<='; USING SQL OPERATOR '>'; }; CREATE INFIX OPERATOR std::`<=` (l: anytuple, r: anytuple) -> std::bool { CREATE ANNOTATION std::identifier := 'lte'; CREATE ANNOTATION std::description := 'Less than or equal.'; SET volatility := 'Immutable'; SET recursive := true; SET commutator := 'std::>='; SET negator := 'std::>'; USING SQL OPERATOR '<='; }; CREATE INFIX OPERATOR 
std::`<` (l: anytuple, r: anytuple) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    SET recursive := true;
    SET commutator := 'std::>';
    SET negator := 'std::>=';
    USING SQL OPERATOR '<';
};


================================================
FILE: edb/lib/std/25-booloperators.edgeql
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


## Standard boolean operators
## --------------------------


CREATE INFIX OPERATOR
std::`OR` (a: std::bool, b: std::bool) -> std::bool {
    CREATE ANNOTATION std::identifier := 'or';
    CREATE ANNOTATION std::description := 'Logical disjunction.';
    SET volatility := 'Immutable';
    SET impl_is_strict := false;
    USING SQL EXPRESSION;
};


CREATE INFIX OPERATOR
std::`AND` (a: std::bool, b: std::bool) -> std::bool {
    CREATE ANNOTATION std::identifier := 'and';
    CREATE ANNOTATION std::description := 'Logical conjunction.';
    SET volatility := 'Immutable';
    SET impl_is_strict := false;
    USING SQL EXPRESSION;
};


CREATE PREFIX OPERATOR
std::`NOT` (v: std::bool) -> std::bool {
    CREATE ANNOTATION std::identifier := 'not';
    CREATE ANNOTATION std::description := 'Logical negation.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};


CREATE INFIX OPERATOR
std::`=` (l: std::bool, r: std::bool) -> std::bool {
    CREATE ANNOTATION std::identifier := 'eq';
    CREATE ANNOTATION std::description := 'Compare two values for equality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::=';
    SET negator := 'std::!=';
    USING SQL OPERATOR r'=';
};


CREATE INFIX OPERATOR
std::`?=` (l: OPTIONAL std::bool, r: OPTIONAL std::bool) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_eq';
    CREATE ANNOTATION std::description :=
        'Compare two (potentially empty) values for equality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};


CREATE INFIX OPERATOR
std::`!=` (l: std::bool, r: std::bool) -> std::bool {
    CREATE ANNOTATION std::identifier := 'neq';
    CREATE ANNOTATION std::description := 'Compare two values for inequality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::!=';
    SET negator := 'std::=';
    USING SQL OPERATOR r'<>';
};


CREATE INFIX OPERATOR
std::`?!=` (l: OPTIONAL std::bool, r: OPTIONAL std::bool) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_neq';
    CREATE ANNOTATION std::description :=
        'Compare two (potentially empty) values for inequality.';
    SET volatility := 'Immutable';
USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`>=` (l: std::bool, r: std::bool) -> std::bool { CREATE ANNOTATION std::identifier := 'gte'; CREATE ANNOTATION std::description := 'Greater than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::<='; SET negator := 'std::<'; USING SQL OPERATOR '>='; }; CREATE INFIX OPERATOR std::`>` (l: std::bool, r: std::bool) -> std::bool { CREATE ANNOTATION std::identifier := 'gt'; CREATE ANNOTATION std::description := 'Greater than.'; SET volatility := 'Immutable'; SET commutator := 'std::<'; SET negator := 'std::<='; USING SQL OPERATOR '>'; }; CREATE INFIX OPERATOR std::`<=` (l: std::bool, r: std::bool) -> std::bool { CREATE ANNOTATION std::identifier := 'lte'; CREATE ANNOTATION std::description := 'Less than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::>='; SET negator := 'std::>'; USING SQL OPERATOR '<='; }; CREATE INFIX OPERATOR std::`<` (l: std::bool, r: std::bool) -> std::bool { CREATE ANNOTATION std::identifier := 'lt'; CREATE ANNOTATION std::description := 'Less than.'; SET volatility := 'Immutable'; SET commutator := 'std::>'; SET negator := 'std::>='; USING SQL OPERATOR '<'; }; ## Boolean casts ## ------------- CREATE CAST FROM std::str TO std::bool { SET volatility := 'Immutable'; USING SQL FUNCTION 'edgedb.str_to_bool'; }; CREATE CAST FROM std::bool TO std::str { SET volatility := 'Immutable'; USING SQL CAST; }; ================================================ FILE: edb/lib/std/25-enumoperators.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ## Standard enum operators ## ----------------------- CREATE INFIX OPERATOR std::`=` (l: std::anyenum, r: std::anyenum) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; SET commutator := 'std::='; SET negator := 'std::!='; USING SQL OPERATOR r'='; }; CREATE INFIX OPERATOR std::`?=` (l: OPTIONAL std::anyenum, r: OPTIONAL std::anyenum) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`!=` (l: std::anyenum, r: std::anyenum) -> std::bool { CREATE ANNOTATION std::identifier := 'neq'; CREATE ANNOTATION std::description := 'Compare two values for inequality.'; SET volatility := 'Immutable'; SET commutator := 'std::!='; SET negator := 'std::='; USING SQL OPERATOR r'<>'; }; CREATE INFIX OPERATOR std::`?!=` (l: OPTIONAL std::anyenum, r: OPTIONAL std::anyenum) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_neq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`>=` (l: std::anyenum, r: std::anyenum) -> std::bool { CREATE ANNOTATION std::identifier := 'gte'; CREATE ANNOTATION std::description := 'Greater than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::<='; SET negator := 'std::<'; USING SQL OPERATOR '>='; }; CREATE 
INFIX OPERATOR std::`>` (l: std::anyenum, r: std::anyenum) -> std::bool { CREATE ANNOTATION std::identifier := 'gt'; CREATE ANNOTATION std::description := 'Greater than.'; SET volatility := 'Immutable'; SET commutator := 'std::<'; SET negator := 'std::<='; USING SQL OPERATOR '>'; }; CREATE INFIX OPERATOR std::`<=` (l: std::anyenum, r: std::anyenum) -> std::bool { CREATE ANNOTATION std::identifier := 'lte'; CREATE ANNOTATION std::description := 'Less than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::>='; SET negator := 'std::>'; USING SQL OPERATOR '<='; }; CREATE INFIX OPERATOR std::`<` (l: std::anyenum, r: std::anyenum) -> std::bool { CREATE ANNOTATION std::identifier := 'lt'; CREATE ANNOTATION std::description := 'Less than.'; SET volatility := 'Immutable'; SET commutator := 'std::>'; SET negator := 'std::>='; USING SQL OPERATOR '<'; }; ## Enum casts ## ---------- # The only way to create an enum is to cast a str into it, so it makes # sense to create an implicit assignment cast. CREATE CAST FROM std::str TO std::anyenum { SET volatility := 'Immutable'; USING SQL CAST; ALLOW ASSIGNMENT; }; CREATE CAST FROM std::anyenum TO std::str { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::anyenum TO std::json { SET volatility := 'Immutable'; USING SQL "SELECT to_jsonb(val::text)" }; # Handled in compile_cast CREATE CAST FROM std::json TO std::anyenum { SET volatility := 'Immutable'; USING SQL EXPRESSION; }; ================================================ FILE: edb/lib/std/25-numoperators.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


## Standard numeric operators
## --------------------------


# NOTE: we follow PostgreSQL in creating an explicit operator
# for each permutation of common integer and floating-point
# operand types to avoid casting overhead, as these operations
# are very common.
#
# Our implicit casts do not coincide with PostgreSQL. In particular we
# do not implicitly cast between decimals and floats. The philosophy
# behind that is that using decimal arithmetic should be opt-in. On
# the other hand, if decimals are used they should not be accidentally
# switched to floating point arithmetic. One of the consequences of
# this is that we need to explicitly define arithmetic operators for
# every legal combination of floats and decimals as unlike PostgreSQL
# we cannot rely on implicit casts between decimals and other numeric
# types.
#
# Floating point numbers are inherently imprecise. This means that
# casting a given float into another representation and back may yield
# a different value. This is especially important with float and
# decimal casts as both directions can lose precision. Discussion
# about precision loss of float to numeric casts can be found here:
# https://www.postgresql.org/message-id/5A937D7E.60305%40anastigmatix.net


# EQUALITY

CREATE INFIX OPERATOR
std::`=` (l: std::int16, r: std::int16) -> std::bool {
    CREATE ANNOTATION std::identifier := 'eq';
    CREATE ANNOTATION std::description := 'Compare two values for equality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::=';
    SET negator := 'std::!=';
    USING SQL OPERATOR r'=';
};


CREATE INFIX OPERATOR
std::`?=` (l: OPTIONAL std::int16, r: OPTIONAL std::int16) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_eq';
    CREATE ANNOTATION std::description :=
        'Compare two (potentially empty) values for equality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};


CREATE INFIX OPERATOR
std::`=` (l: std::int16, r: std::int32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'eq';
    CREATE ANNOTATION std::description := 'Compare two values for equality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::=';
    SET negator := 'std::!=';
    USING SQL OPERATOR r'=';
};


CREATE INFIX OPERATOR
std::`?=` (l: OPTIONAL std::int16, r: OPTIONAL std::int32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_eq';
    CREATE ANNOTATION std::description :=
        'Compare two (potentially empty) values for equality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};


CREATE INFIX OPERATOR
std::`=` (l: std::int16, r: std::int64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'eq';
    CREATE ANNOTATION std::description := 'Compare two values for equality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::=';
    SET negator := 'std::!=';
    USING SQL OPERATOR r'=';
};


CREATE INFIX OPERATOR
std::`?=` (l: OPTIONAL std::int16, r: OPTIONAL std::int64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_eq';
    CREATE ANNOTATION std::description :=
        'Compare two (potentially empty) values for equality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};


CREATE INFIX OPERATOR
std::`=`
(l: std::int32, r: std::int16) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; SET commutator := 'std::='; SET negator := 'std::!='; USING SQL OPERATOR r'='; }; CREATE INFIX OPERATOR std::`?=` (l: OPTIONAL std::int32, r: OPTIONAL std::int16) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`=` (l: std::int32, r: std::int32) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; SET commutator := 'std::='; SET negator := 'std::!='; USING SQL OPERATOR r'='; }; CREATE INFIX OPERATOR std::`?=` (l: OPTIONAL std::int32, r: OPTIONAL std::int32) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`=` (l: std::int32, r: std::int64) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; SET commutator := 'std::='; SET negator := 'std::!='; USING SQL OPERATOR r'='; }; CREATE INFIX OPERATOR std::`?=` (l: OPTIONAL std::int32, r: OPTIONAL std::int64) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`=` (l: std::int64, r: std::int16) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; 
SET commutator := 'std::='; SET negator := 'std::!='; USING SQL OPERATOR r'='; }; CREATE INFIX OPERATOR std::`?=` (l: OPTIONAL std::int64, r: OPTIONAL std::int16) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`=` (l: std::int64, r: std::int32) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; SET commutator := 'std::='; SET negator := 'std::!='; USING SQL OPERATOR r'='; }; CREATE INFIX OPERATOR std::`?=` (l: OPTIONAL std::int64, r: OPTIONAL std::int32) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`=` (l: std::int64, r: std::int64) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; SET commutator := 'std::='; SET negator := 'std::!='; USING SQL OPERATOR r'='; }; CREATE INFIX OPERATOR std::`?=` (l: OPTIONAL std::int64, r: OPTIONAL std::int64) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`=` (l: std::float32, r: std::float32) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; SET commutator := 'std::='; SET negator := 'std::!='; USING SQL OPERATOR r'='; }; CREATE INFIX OPERATOR std::`?=` (l: OPTIONAL std::float32, r: OPTIONAL std::float32) -> std::bool { CREATE 
ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`=` (l: std::float32, r: std::float64) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; SET commutator := 'std::='; SET negator := 'std::!='; USING SQL OPERATOR r'='; }; CREATE INFIX OPERATOR std::`?=` (l: OPTIONAL std::float32, r: OPTIONAL std::float64) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`=` (l: std::float64, r: std::float32) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; SET commutator := 'std::='; SET negator := 'std::!='; USING SQL OPERATOR r'='; }; CREATE INFIX OPERATOR std::`?=` (l: OPTIONAL std::float64, r: OPTIONAL std::float32) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`=` (l: std::float64, r: std::float64) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; SET commutator := 'std::='; SET negator := 'std::!='; USING SQL OPERATOR r'='; }; CREATE INFIX OPERATOR std::`?=` (l: OPTIONAL std::float64, r: OPTIONAL std::float64) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.'; SET volatility := 'Immutable'; USING 
SQL EXPRESSION;
};

CREATE INFIX OPERATOR
std::`=` (l: std::bigint, r: std::bigint) -> std::bool {
    CREATE ANNOTATION std::identifier := 'eq';
    CREATE ANNOTATION std::description := 'Compare two values for equality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::=';
    SET negator := 'std::!=';
    USING SQL OPERATOR r'=(numeric,numeric)';
};

CREATE INFIX OPERATOR
std::`=` (l: std::decimal, r: std::decimal) -> std::bool {
    CREATE ANNOTATION std::identifier := 'eq';
    CREATE ANNOTATION std::description := 'Compare two values for equality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::=';
    SET negator := 'std::!=';
    USING SQL OPERATOR r'=';
};

CREATE INFIX OPERATOR
std::`=` (l: std::decimal, r: std::anyint) -> std::bool {
    CREATE ANNOTATION std::identifier := 'eq';
    CREATE ANNOTATION std::description := 'Compare two values for equality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::=';
    SET negator := 'std::!=';
    USING SQL OPERATOR r'=';
};

CREATE INFIX OPERATOR
std::`?=` (l: OPTIONAL std::bigint, r: OPTIONAL std::bigint) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_eq';
    CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};

CREATE INFIX OPERATOR
std::`?=` (l: OPTIONAL std::decimal, r: OPTIONAL std::decimal) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_eq';
    CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};

CREATE INFIX OPERATOR
std::`?=` (l: OPTIONAL std::decimal, r: OPTIONAL std::anyint) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_eq';
    CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};

CREATE INFIX OPERATOR
std::`=` (l: std::anyint, r: std::decimal) -> std::bool {
    CREATE ANNOTATION std::identifier := 'eq';
    CREATE ANNOTATION std::description := 'Compare two values for equality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::=';
    SET negator := 'std::!=';
    USING SQL OPERATOR r'=';
};

CREATE INFIX OPERATOR
std::`?=` (l: OPTIONAL std::anyint, r: OPTIONAL std::decimal) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_eq';
    CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};

# INEQUALITY

CREATE INFIX OPERATOR
std::`!=` (l: std::int16, r: std::int16) -> std::bool {
    CREATE ANNOTATION std::identifier := 'neq';
    CREATE ANNOTATION std::description := 'Compare two values for inequality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::!=';
    SET negator := 'std::=';
    USING SQL OPERATOR r'<>';
};

CREATE INFIX OPERATOR
std::`?!=` (l: OPTIONAL std::int16, r: OPTIONAL std::int16) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_neq';
    CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};

CREATE INFIX OPERATOR
std::`!=` (l: std::int16, r: std::int32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'neq';
    CREATE ANNOTATION std::description := 'Compare two values for inequality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::!=';
    SET negator := 'std::=';
    USING SQL OPERATOR r'<>';
};

CREATE INFIX OPERATOR
std::`?!=` (l: OPTIONAL std::int16, r: OPTIONAL std::int32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_neq';
    CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};

CREATE INFIX OPERATOR
std::`!=` (l: std::int16, r: std::int64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'neq';
    CREATE ANNOTATION std::description := 'Compare two values for inequality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::!=';
    SET negator := 'std::=';
    USING SQL OPERATOR r'<>';
};

CREATE INFIX OPERATOR
std::`?!=` (l: OPTIONAL std::int16, r: OPTIONAL std::int64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_neq';
    CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};

CREATE INFIX OPERATOR
std::`!=` (l: std::int32, r: std::int16) -> std::bool {
    CREATE ANNOTATION std::identifier := 'neq';
    CREATE ANNOTATION std::description := 'Compare two values for inequality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::!=';
    SET negator := 'std::=';
    USING SQL OPERATOR r'<>';
};

CREATE INFIX OPERATOR
std::`?!=` (l: OPTIONAL std::int32, r: OPTIONAL std::int16) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_neq';
    CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};

CREATE INFIX OPERATOR
std::`!=` (l: std::int32, r: std::int32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'neq';
    CREATE ANNOTATION std::description := 'Compare two values for inequality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::!=';
    SET negator := 'std::=';
    USING SQL OPERATOR r'<>';
};

CREATE INFIX OPERATOR
std::`?!=` (l: OPTIONAL std::int32, r: OPTIONAL std::int32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_neq';
    CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};

CREATE INFIX OPERATOR
std::`!=` (l: std::int32, r: std::int64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'neq';
    CREATE ANNOTATION std::description := 'Compare two values for inequality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::!=';
    SET negator := 'std::=';
    USING SQL OPERATOR r'<>';
};

CREATE INFIX OPERATOR
std::`?!=` (l: OPTIONAL std::int32, r: OPTIONAL std::int64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_neq';
    CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};

CREATE INFIX OPERATOR
std::`!=` (l: std::int64, r: std::int16) -> std::bool {
    CREATE ANNOTATION std::identifier := 'neq';
    CREATE ANNOTATION std::description := 'Compare two values for inequality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::!=';
    SET negator := 'std::=';
    USING SQL OPERATOR r'<>';
};

CREATE INFIX OPERATOR
std::`?!=` (l: OPTIONAL std::int64, r: OPTIONAL std::int16) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_neq';
    CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};

CREATE INFIX OPERATOR
std::`!=` (l: std::int64, r: std::int32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'neq';
    CREATE ANNOTATION std::description := 'Compare two values for inequality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::!=';
    SET negator := 'std::=';
    USING SQL OPERATOR r'<>';
};

CREATE INFIX OPERATOR
std::`?!=` (l: OPTIONAL std::int64, r: OPTIONAL std::int32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_neq';
    CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};

CREATE INFIX OPERATOR
std::`!=` (l: std::int64, r: std::int64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'neq';
    CREATE ANNOTATION std::description := 'Compare two values for inequality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::!=';
    SET negator := 'std::=';
    USING SQL OPERATOR r'<>';
};

CREATE INFIX OPERATOR
std::`?!=` (l: OPTIONAL std::int64, r: OPTIONAL std::int64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_neq';
    CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};

CREATE INFIX OPERATOR
std::`!=` (l: std::float32, r: std::float32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'neq';
    CREATE ANNOTATION std::description := 'Compare two values for inequality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::!=';
    SET negator := 'std::=';
    USING SQL OPERATOR r'<>';
};

CREATE INFIX OPERATOR
std::`?!=` (l: OPTIONAL std::float32, r: OPTIONAL std::float32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_neq';
    CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};

CREATE INFIX OPERATOR
std::`!=` (l: std::float32, r: std::float64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'neq';
    CREATE ANNOTATION std::description := 'Compare two values for inequality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::!=';
    SET negator := 'std::=';
    USING SQL OPERATOR r'<>';
};

CREATE INFIX OPERATOR
std::`?!=` (l: OPTIONAL std::float32, r: OPTIONAL std::float64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_neq';
    CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};

CREATE INFIX OPERATOR
std::`!=` (l: std::float64, r: std::float32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'neq';
    CREATE ANNOTATION std::description := 'Compare two values for inequality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::!=';
    SET negator := 'std::=';
    USING SQL OPERATOR r'<>';
};

CREATE INFIX OPERATOR
std::`?!=` (l: OPTIONAL std::float64, r: OPTIONAL std::float32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_neq';
    CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};

CREATE INFIX OPERATOR
std::`!=` (l: std::float64, r: std::float64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'neq';
    CREATE ANNOTATION std::description := 'Compare two values for inequality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::!=';
    SET negator := 'std::=';
    USING SQL OPERATOR r'<>';
};

CREATE INFIX OPERATOR
std::`?!=` (l: OPTIONAL std::float64, r: OPTIONAL std::float64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_neq';
    CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};

CREATE INFIX OPERATOR
std::`!=` (l: std::bigint, r: std::bigint) -> std::bool {
    CREATE ANNOTATION std::identifier := 'neq';
    CREATE ANNOTATION std::description := 'Compare two values for inequality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::!=';
    SET negator := 'std::=';
    USING SQL OPERATOR r'<>(numeric,numeric)';
};

CREATE INFIX OPERATOR
std::`!=` (l: std::decimal, r: std::decimal) -> std::bool {
    CREATE ANNOTATION std::identifier := 'neq';
    CREATE ANNOTATION std::description := 'Compare two values for inequality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::!=';
    SET negator := 'std::=';
    USING SQL OPERATOR r'<>';
};

CREATE INFIX OPERATOR
std::`!=` (l: std::decimal, r: std::anyint) -> std::bool {
    CREATE ANNOTATION std::identifier := 'neq';
    CREATE ANNOTATION std::description := 'Compare two values for inequality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::!=';
    SET negator := 'std::=';
    USING SQL OPERATOR r'<>';
};

CREATE INFIX OPERATOR
std::`?!=` (l: OPTIONAL std::bigint, r: OPTIONAL std::bigint) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_neq';
    CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};

CREATE INFIX OPERATOR
std::`?!=` (l: OPTIONAL std::decimal, r: OPTIONAL std::decimal) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_neq';
    CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};

CREATE INFIX OPERATOR
std::`?!=` (l: OPTIONAL std::decimal, r: OPTIONAL std::anyint) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_neq';
    CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};

CREATE INFIX OPERATOR
std::`!=` (l: std::anyint, r: std::decimal) -> std::bool {
    CREATE ANNOTATION std::identifier := 'neq';
    CREATE ANNOTATION std::description := 'Compare two values for inequality.';
    SET volatility := 'Immutable';
    SET commutator := 'std::!=';
    SET negator := 'std::=';
    USING SQL OPERATOR r'<>';
};

CREATE INFIX OPERATOR
std::`?!=` (l: OPTIONAL std::anyint, r: OPTIONAL std::decimal) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_neq';
    CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};

# GREATER THAN

CREATE INFIX OPERATOR
std::`>` (l: std::int16, r: std::int16) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gt';
    CREATE ANNOTATION std::description := 'Greater than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<';
    SET negator := 'std::<=';
    USING SQL OPERATOR r'>';
};

CREATE INFIX OPERATOR
std::`>` (l: std::int16, r: std::int32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gt';
    CREATE ANNOTATION std::description := 'Greater than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<';
    SET negator := 'std::<=';
    USING SQL OPERATOR r'>';
};

CREATE INFIX OPERATOR
std::`>` (l: std::int16, r: std::int64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gt';
    CREATE ANNOTATION std::description := 'Greater than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<';
    SET negator := 'std::<=';
    USING SQL OPERATOR r'>';
};

CREATE INFIX OPERATOR
std::`>` (l: std::int32, r: std::int16) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gt';
    CREATE ANNOTATION std::description := 'Greater than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<';
    SET negator := 'std::<=';
    USING SQL OPERATOR r'>';
};

CREATE INFIX OPERATOR
std::`>` (l: std::int32, r: std::int32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gt';
    CREATE ANNOTATION std::description := 'Greater than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<';
    SET negator := 'std::<=';
    USING SQL OPERATOR r'>';
};

CREATE INFIX OPERATOR
std::`>` (l: std::int32, r: std::int64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gt';
    CREATE ANNOTATION std::description := 'Greater than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<';
    SET negator := 'std::<=';
    USING SQL OPERATOR r'>';
};

CREATE INFIX OPERATOR
std::`>` (l: std::int32, r: std::float32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gt';
    CREATE ANNOTATION std::description := 'Greater than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<';
    SET negator := 'std::<=';
    USING SQL OPERATOR '>(float8,float8)';
};

CREATE INFIX OPERATOR
std::`>` (l: std::int64, r: std::int16) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gt';
    CREATE ANNOTATION std::description := 'Greater than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<';
    SET negator := 'std::<=';
    USING SQL OPERATOR r'>';
};

CREATE INFIX OPERATOR
std::`>` (l: std::int64, r: std::int32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gt';
    CREATE ANNOTATION std::description := 'Greater than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<';
    SET negator := 'std::<=';
    USING SQL OPERATOR r'>';
};

CREATE INFIX OPERATOR
std::`>` (l: std::int64, r: std::int64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gt';
    CREATE ANNOTATION std::description := 'Greater than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<';
    SET negator := 'std::<=';
    USING SQL OPERATOR r'>';
};

CREATE INFIX OPERATOR
std::`>` (l: std::int64, r: std::float64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gt';
    CREATE ANNOTATION std::description := 'Greater than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<';
    SET negator := 'std::<=';
    USING SQL OPERATOR r'>(float8,float8)';
};

CREATE INFIX OPERATOR
std::`>` (l: std::float32, r: std::float32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gt';
    CREATE ANNOTATION std::description := 'Greater than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<';
    SET negator := 'std::<=';
    USING SQL OPERATOR r'>';
};

CREATE INFIX OPERATOR
std::`>` (l: std::float32, r: std::float64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gt';
    CREATE ANNOTATION std::description := 'Greater than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<';
    SET negator := 'std::<=';
    USING SQL OPERATOR r'>';
};

CREATE INFIX OPERATOR
std::`>` (l: std::float32, r: std::int32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gt';
    CREATE ANNOTATION std::description := 'Greater than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<';
    SET negator := 'std::<=';
    USING SQL OPERATOR '>(float8,float8)';
};

CREATE INFIX OPERATOR
std::`>` (l: std::float64, r: std::float32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gt';
    CREATE ANNOTATION std::description := 'Greater than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<';
    SET negator := 'std::<=';
    USING SQL OPERATOR r'>';
};

CREATE INFIX OPERATOR
std::`>` (l: std::float64, r: std::float64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gt';
    CREATE ANNOTATION std::description := 'Greater than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<';
    SET negator := 'std::<=';
    USING SQL OPERATOR r'>';
};

CREATE INFIX OPERATOR
std::`>` (l: std::float64, r: std::int64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gt';
    CREATE ANNOTATION std::description := 'Greater than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<';
    SET negator := 'std::<=';
    USING SQL OPERATOR r'>(float8,float8)';
};

CREATE INFIX OPERATOR
std::`>` (l: std::bigint, r: std::bigint) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gt';
    CREATE ANNOTATION std::description := 'Greater than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<';
    SET negator := 'std::<=';
    USING SQL OPERATOR r'>(numeric,numeric)';
};

CREATE INFIX OPERATOR
std::`>` (l: std::decimal, r: std::decimal) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gt';
    CREATE ANNOTATION std::description := 'Greater than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<';
    SET negator := 'std::<=';
    USING SQL OPERATOR r'>';
};

CREATE INFIX OPERATOR
std::`>` (l: std::anyint, r: std::decimal) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gt';
    CREATE ANNOTATION std::description := 'Greater than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<';
    SET negator := 'std::<=';
    USING SQL OPERATOR r'>';
};

CREATE INFIX OPERATOR
std::`>` (l: std::decimal, r: std::anyint) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gt';
    CREATE ANNOTATION std::description := 'Greater than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<';
    SET negator := 'std::<=';
    USING SQL OPERATOR r'>';
};

# GREATER OR EQUAL

CREATE INFIX OPERATOR
std::`>=` (l: std::int16, r: std::int16) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gte';
    CREATE ANNOTATION std::description := 'Greater than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<=';
    SET negator := 'std::<';
    USING SQL OPERATOR r'>=';
};

CREATE INFIX OPERATOR
std::`>=` (l: std::int16, r: std::int32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gte';
    CREATE ANNOTATION std::description := 'Greater than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<=';
    SET negator := 'std::<';
    USING SQL OPERATOR r'>=';
};

CREATE INFIX OPERATOR
std::`>=` (l: std::int16, r: std::int64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gte';
    CREATE ANNOTATION std::description := 'Greater than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<=';
    SET negator := 'std::<';
    USING SQL OPERATOR r'>=';
};

CREATE INFIX OPERATOR
std::`>=` (l: std::int32, r: std::int16) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gte';
    CREATE ANNOTATION std::description := 'Greater than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<=';
    SET negator := 'std::<';
    USING SQL OPERATOR r'>=';
};

CREATE INFIX OPERATOR
std::`>=` (l: std::int32, r: std::int32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gte';
    CREATE ANNOTATION std::description := 'Greater than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<=';
    SET negator := 'std::<';
    USING SQL OPERATOR r'>=';
};

CREATE INFIX OPERATOR
std::`>=` (l: std::int32, r: std::int64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gte';
    CREATE ANNOTATION std::description := 'Greater than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<=';
    SET negator := 'std::<';
    USING SQL OPERATOR r'>=';
};

CREATE INFIX OPERATOR
std::`>=` (l: std::int32, r: std::float32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gte';
    CREATE ANNOTATION std::description := 'Greater than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<=';
    SET negator := 'std::<';
    USING SQL OPERATOR '>=(float8,float8)';
};

CREATE INFIX OPERATOR
std::`>=` (l: std::int64, r: std::int16) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gte';
    CREATE ANNOTATION std::description := 'Greater than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<=';
    SET negator := 'std::<';
    USING SQL OPERATOR r'>=';
};

CREATE INFIX OPERATOR
std::`>=` (l: std::int64, r: std::int32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gte';
    CREATE ANNOTATION std::description := 'Greater than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<=';
    SET negator := 'std::<';
    USING SQL OPERATOR r'>=';
};

CREATE INFIX OPERATOR
std::`>=` (l: std::int64, r: std::int64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gte';
    CREATE ANNOTATION std::description := 'Greater than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<=';
    SET negator := 'std::<';
    USING SQL OPERATOR r'>=';
};

CREATE INFIX OPERATOR
std::`>=` (l: std::int64, r: std::float64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gte';
    CREATE ANNOTATION std::description := 'Greater than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<=';
    SET negator := 'std::<';
    USING SQL OPERATOR r'>=(float8,float8)';
};

CREATE INFIX OPERATOR
std::`>=` (l: std::float32, r: std::float32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gte';
    CREATE ANNOTATION std::description := 'Greater than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<=';
    SET negator := 'std::<';
    USING SQL OPERATOR r'>=';
};

CREATE INFIX OPERATOR
std::`>=` (l: std::float32, r: std::float64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gte';
    CREATE ANNOTATION std::description := 'Greater than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<=';
    SET negator := 'std::<';
    USING SQL OPERATOR r'>=';
};

CREATE INFIX OPERATOR
std::`>=` (l: std::float32, r: std::int32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gte';
    CREATE ANNOTATION std::description := 'Greater than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<=';
    SET negator := 'std::<';
    USING SQL OPERATOR '>=(float8,float8)';
};

CREATE INFIX OPERATOR
std::`>=` (l: std::float64, r: std::float32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gte';
    CREATE ANNOTATION std::description := 'Greater than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<=';
    SET negator := 'std::<';
    USING SQL OPERATOR r'>=';
};

CREATE INFIX OPERATOR
std::`>=` (l: std::float64, r: std::float64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gte';
    CREATE ANNOTATION std::description := 'Greater than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<=';
    SET negator := 'std::<';
    USING SQL OPERATOR r'>=';
};

CREATE INFIX OPERATOR
std::`>=` (l: std::float64, r: std::int64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gte';
    CREATE ANNOTATION std::description := 'Greater than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<=';
    SET negator := 'std::<';
    USING SQL OPERATOR r'>=(float8,float8)';
};

CREATE INFIX OPERATOR
std::`>=` (l: std::bigint, r: std::bigint) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gte';
    CREATE ANNOTATION std::description := 'Greater than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<=';
    SET negator := 'std::<';
    USING SQL OPERATOR r'>=(numeric,numeric)';
};

CREATE INFIX OPERATOR
std::`>=` (l: std::decimal, r: std::decimal) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gte';
    CREATE ANNOTATION std::description := 'Greater than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<=';
    SET negator := 'std::<';
    USING SQL OPERATOR r'>=';
};

CREATE INFIX OPERATOR
std::`>=` (l: std::anyint, r: std::decimal) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gte';
    CREATE ANNOTATION std::description := 'Greater than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<=';
    SET negator := 'std::<';
    USING SQL OPERATOR r'>=';
};

CREATE INFIX OPERATOR
std::`>=` (l: std::decimal, r: std::anyint) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gte';
    CREATE ANNOTATION std::description := 'Greater than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::<=';
    SET negator := 'std::<';
    USING SQL OPERATOR r'>=';
};

# LESS THAN

CREATE INFIX OPERATOR
std::`<` (l: std::int16, r: std::int16) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>';
    SET negator := 'std::>=';
    USING SQL OPERATOR r'<';
};

CREATE INFIX OPERATOR
std::`<` (l: std::int16, r: std::int32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>';
    SET negator := 'std::>=';
    USING SQL OPERATOR r'<';
};

CREATE INFIX OPERATOR
std::`<` (l: std::int16, r: std::int64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>';
    SET negator := 'std::>=';
    USING SQL OPERATOR r'<';
};

CREATE INFIX OPERATOR
std::`<` (l: std::int32, r: std::int16) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>';
    SET negator := 'std::>=';
    USING SQL OPERATOR r'<';
};

CREATE INFIX OPERATOR
std::`<` (l: std::int32, r: std::int32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>';
    SET negator := 'std::>=';
    USING SQL OPERATOR r'<';
};

CREATE INFIX OPERATOR
std::`<` (l: std::int32, r: std::int64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>';
    SET negator := 'std::>=';
    USING SQL OPERATOR r'<';
};

CREATE INFIX OPERATOR
std::`<` (l: std::int32, r: std::float32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>';
    SET negator := 'std::>=';
    USING SQL OPERATOR '<(float8,float8)';
};

CREATE INFIX OPERATOR
std::`<` (l: std::int64, r: std::int16) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>';
    SET negator := 'std::>=';
    USING SQL OPERATOR r'<';
};

CREATE INFIX OPERATOR
std::`<` (l: std::int64, r: std::int32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>';
    SET negator := 'std::>=';
    USING SQL OPERATOR r'<';
};

CREATE INFIX OPERATOR
std::`<` (l: std::int64, r: std::int64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>';
    SET negator := 'std::>=';
    USING SQL OPERATOR r'<';
};

CREATE INFIX OPERATOR
std::`<` (l: std::int64, r: std::float64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>';
    SET negator := 'std::>=';
    USING SQL OPERATOR r'<(float8,float8)';
};

CREATE INFIX OPERATOR
std::`<` (l: std::float32, r: std::float32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>';
    SET negator := 'std::>=';
    USING SQL OPERATOR r'<';
};

CREATE INFIX OPERATOR
std::`<` (l: std::float32, r: std::float64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>';
    SET negator := 'std::>=';
    USING SQL OPERATOR r'<';
};

CREATE INFIX OPERATOR
std::`<` (l: std::float32, r: std::int32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>';
    SET negator := 'std::>=';
    USING SQL OPERATOR '<(float8,float8)';
};

CREATE INFIX OPERATOR
std::`<` (l: std::float64, r: std::float32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>';
    SET negator := 'std::>=';
    USING SQL OPERATOR r'<';
};

CREATE INFIX OPERATOR
std::`<` (l: std::float64, r: std::float64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>';
    SET negator := 'std::>=';
    USING SQL OPERATOR r'<';
};

CREATE INFIX OPERATOR
std::`<` (l: std::float64, r: std::int64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>';
    SET negator := 'std::>=';
    USING SQL OPERATOR r'<(float8,float8)';
};

CREATE INFIX OPERATOR
std::`<` (l: std::bigint, r: std::bigint) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>';
    SET negator := 'std::>=';
    USING SQL OPERATOR r'<(numeric,numeric)';
};

CREATE INFIX OPERATOR
std::`<` (l: std::decimal, r: std::decimal) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>';
    SET negator := 'std::>=';
    USING SQL OPERATOR r'<';
};

CREATE INFIX OPERATOR
std::`<` (l: std::anyint, r: std::decimal) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>';
    SET negator := 'std::>=';
    USING SQL OPERATOR r'<';
};

CREATE INFIX OPERATOR
std::`<` (l: std::decimal, r: std::anyint) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>';
    SET negator := 'std::>=';
    USING SQL OPERATOR r'<';
};

# LESS THAN OR EQUAL

CREATE INFIX OPERATOR
std::`<=` (l: std::int16, r: std::int16) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lte';
    CREATE ANNOTATION std::description := 'Less than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>=';
    SET negator := 'std::>';
    USING SQL OPERATOR r'<=';
};

CREATE INFIX OPERATOR
std::`<=` (l: std::int16, r: std::int32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lte';
    CREATE ANNOTATION std::description := 'Less than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>=';
    SET negator := 'std::>';
    USING SQL OPERATOR r'<=';
};

CREATE INFIX OPERATOR
std::`<=` (l: std::int16, r: std::int64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lte';
    CREATE ANNOTATION std::description := 'Less than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>=';
    SET negator := 'std::>';
    USING SQL OPERATOR r'<=';
};

CREATE INFIX OPERATOR
std::`<=` (l: std::int32, r: std::int16) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lte';
    CREATE ANNOTATION std::description := 'Less than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>=';
    SET negator := 'std::>';
    USING SQL OPERATOR r'<=';
};

CREATE INFIX OPERATOR
std::`<=` (l: std::int32, r: std::int32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lte';
    CREATE ANNOTATION std::description := 'Less than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>=';
    SET negator := 'std::>';
    USING SQL OPERATOR r'<=';
};

CREATE INFIX OPERATOR
std::`<=` (l: std::int32, r: std::int64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lte';
    CREATE ANNOTATION std::description := 'Less than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>=';
    SET negator := 'std::>';
    USING SQL OPERATOR r'<=';
};

CREATE INFIX OPERATOR
std::`<=` (l: std::int32, r: std::float32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lte';
    CREATE ANNOTATION std::description := 'Less than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>=';
    SET negator := 'std::>';
    USING SQL OPERATOR '<=(float8,float8)';
};

CREATE INFIX OPERATOR
std::`<=` (l: std::int64, r: std::int16) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lte';
    CREATE ANNOTATION std::description := 'Less than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>=';
    SET negator := 'std::>';
    USING SQL OPERATOR r'<=';
};

CREATE INFIX OPERATOR
std::`<=` (l: std::int64, r: std::int32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lte';
    CREATE ANNOTATION std::description := 'Less than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>=';
    SET negator := 'std::>';
    USING SQL OPERATOR r'<=';
};

CREATE INFIX OPERATOR
std::`<=` (l: std::int64, r: std::int64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lte';
    CREATE ANNOTATION std::description := 'Less than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>=';
    SET negator := 'std::>';
    USING SQL OPERATOR r'<=';
};

CREATE INFIX OPERATOR
std::`<=` (l: std::int64, r: std::float64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lte';
    CREATE ANNOTATION std::description := 'Less than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>=';
    SET negator := 'std::>';
    USING SQL OPERATOR r'<=(float8,float8)';
};

CREATE INFIX OPERATOR
std::`<=` (l: std::float32, r: std::float32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lte';
    CREATE ANNOTATION std::description := 'Less than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>=';
    SET negator := 'std::>';
    USING SQL OPERATOR r'<=';
};

CREATE INFIX OPERATOR
std::`<=` (l: std::float32, r: std::float64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lte';
    CREATE ANNOTATION std::description := 'Less than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>=';
    SET negator := 'std::>';
    USING SQL OPERATOR r'<=';
};

CREATE INFIX OPERATOR
std::`<=` (l: std::float32, r: std::int32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lte';
    CREATE ANNOTATION std::description := 'Less than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>=';
    SET negator := 'std::>';
    USING SQL OPERATOR '<=(float8,float8)';
};

CREATE INFIX OPERATOR
std::`<=` (l: std::float64, r: std::float32) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lte';
    CREATE ANNOTATION std::description := 'Less than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>=';
    SET negator := 'std::>';
    USING SQL OPERATOR r'<=';
};

CREATE INFIX OPERATOR
std::`<=` (l: std::float64, r: std::float64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lte';
    CREATE ANNOTATION std::description := 'Less than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>=';
    SET negator := 'std::>';
    USING SQL OPERATOR r'<=';
};

CREATE INFIX OPERATOR
std::`<=` (l: std::float64, r: std::int64) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lte';
    CREATE ANNOTATION std::description := 'Less than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>=';
    SET negator := 'std::>';
    USING SQL OPERATOR r'<=(float8,float8)';
};

CREATE INFIX OPERATOR
std::`<=` (l: std::bigint, r: std::bigint) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lte';
    CREATE ANNOTATION std::description := 'Less than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>=';
    SET negator := 'std::>';
    USING SQL OPERATOR r'<=(numeric,numeric)';
};

CREATE INFIX OPERATOR
std::`<=` (l: std::decimal, r: std::decimal) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lte';
    CREATE ANNOTATION std::description := 'Less than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>=';
    SET negator := 'std::>';
    USING SQL OPERATOR r'<=';
};

CREATE INFIX OPERATOR
std::`<=` (l: std::anyint, r: std::decimal) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lte';
    CREATE ANNOTATION std::description := 'Less than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>=';
    SET negator := 'std::>';
    USING SQL OPERATOR r'<=';
};

CREATE INFIX OPERATOR
std::`<=` (l: std::decimal, r: std::anyint) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lte';
    CREATE ANNOTATION std::description := 'Less than or equal.';
    SET volatility := 'Immutable';
    SET commutator := 'std::>=';
    SET negator := 'std::>';
    USING SQL OPERATOR r'<=';
};

# INFIX PLUS

CREATE INFIX OPERATOR
std::`+` (l: std::int16, r: std::int16) -> std::int16 {
    CREATE ANNOTATION std::identifier := 'plus';
    CREATE ANNOTATION std::description := 'Arithmetic addition.';
    SET volatility := 'Immutable';
    SET commutator := 'std::+';
    USING SQL OPERATOR r'+';
};

CREATE INFIX OPERATOR
std::`+` (l: std::int32, r: std::int32) -> std::int32 {
    CREATE ANNOTATION std::identifier := 'plus';
    CREATE ANNOTATION std::description := 'Arithmetic addition.';
    SET volatility := 'Immutable';
    SET commutator := 'std::+';
    USING SQL OPERATOR r'+';
};

CREATE INFIX OPERATOR
std::`+` (l: std::int64, r: std::int64) -> std::int64 {
    CREATE ANNOTATION std::identifier := 'plus';
    CREATE ANNOTATION std::description := 'Arithmetic addition.';
    SET volatility := 'Immutable';
    SET commutator := 'std::+';
    USING SQL OPERATOR r'+';
};

CREATE INFIX OPERATOR
std::`+` (l: std::float32, r: std::float32) -> std::float32 {
    CREATE ANNOTATION std::identifier := 'plus';
    CREATE ANNOTATION std::description := 'Arithmetic addition.';
    SET volatility := 'Immutable';
    SET commutator := 'std::+';
    USING SQL OPERATOR r'+';
};

CREATE INFIX OPERATOR
std::`+` (l: std::float64, r: std::float64) -> std::float64 {
    CREATE ANNOTATION std::identifier := 'plus';
    CREATE ANNOTATION std::description := 'Arithmetic addition.';
    SET volatility := 'Immutable';
    SET commutator := 'std::+';
    USING SQL OPERATOR r'+';
};

CREATE INFIX OPERATOR
std::`+` (l: std::bigint, r: std::bigint) -> std::bigint {
    CREATE ANNOTATION std::identifier := 'plus';
    CREATE ANNOTATION std::description := 'Arithmetic addition.';
    SET volatility := 'Immutable';
    SET commutator := 'std::+';
    SET force_return_cast := true;
    USING SQL OPERATOR r'+(numeric,numeric)';
};

CREATE INFIX OPERATOR
std::`+` (l: std::decimal, r: std::decimal) -> std::decimal {
    CREATE ANNOTATION std::identifier := 'plus';
    CREATE ANNOTATION std::description := 'Arithmetic addition.';
    SET volatility :=
'Immutable'; SET commutator := 'std::+'; USING SQL OPERATOR r'+'; }; # PREFIX PLUS CREATE PREFIX OPERATOR std::`+` (l: std::int16) -> std::int16 { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Arithmetic addition.'; SET volatility := 'Immutable'; USING SQL OPERATOR r'+'; }; CREATE PREFIX OPERATOR std::`+` (l: std::int32) -> std::int32 { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Arithmetic addition.'; SET volatility := 'Immutable'; USING SQL OPERATOR r'+'; }; CREATE PREFIX OPERATOR std::`+` (l: std::int64) -> std::int64 { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Arithmetic addition.'; SET volatility := 'Immutable'; USING SQL OPERATOR r'+'; }; CREATE PREFIX OPERATOR std::`+` (l: std::float32) -> std::float32 { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Arithmetic addition.'; SET volatility := 'Immutable'; USING SQL OPERATOR r'+'; }; CREATE PREFIX OPERATOR std::`+` (l: std::float64) -> std::float64 { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Arithmetic addition.'; SET volatility := 'Immutable'; USING SQL OPERATOR r'+'; }; CREATE PREFIX OPERATOR std::`+` (l: std::bigint) -> std::bigint { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Arithmetic addition.'; SET volatility := 'Immutable'; SET force_return_cast := true; USING SQL OPERATOR r'+(,numeric)'; }; CREATE PREFIX OPERATOR std::`+` (l: std::decimal) -> std::decimal { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Arithmetic addition.'; SET volatility := 'Immutable'; USING SQL OPERATOR r'+'; }; # INFIX MINUS CREATE INFIX OPERATOR std::`-` (l: std::int16, r: std::int16) -> std::int16 { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Arithmetic subtraction.'; SET volatility := 'Immutable'; USING SQL 
OPERATOR r'-'; }; CREATE INFIX OPERATOR std::`-` (l: std::int32, r: std::int32) -> std::int32 { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Arithmetic subtraction.'; SET volatility := 'Immutable'; USING SQL OPERATOR r'-'; }; CREATE INFIX OPERATOR std::`-` (l: std::int64, r: std::int64) -> std::int64 { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Arithmetic subtraction.'; SET volatility := 'Immutable'; USING SQL OPERATOR r'-'; }; CREATE INFIX OPERATOR std::`-` (l: std::float32, r: std::float32) -> std::float32 { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Arithmetic subtraction.'; SET volatility := 'Immutable'; USING SQL OPERATOR r'-'; }; CREATE INFIX OPERATOR std::`-` (l: std::float64, r: std::float64) -> std::float64 { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Arithmetic subtraction.'; SET volatility := 'Immutable'; USING SQL OPERATOR r'-'; }; CREATE INFIX OPERATOR std::`-` (l: std::bigint, r: std::bigint) -> std::bigint { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Arithmetic subtraction.'; SET volatility := 'Immutable'; SET force_return_cast := true; USING SQL OPERATOR r'-(numeric,numeric)'; }; CREATE INFIX OPERATOR std::`-` (l: std::decimal, r: std::decimal) -> std::decimal { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Arithmetic subtraction.'; SET volatility := 'Immutable'; USING SQL OPERATOR r'-'; }; # PREFIX MINUS CREATE PREFIX OPERATOR std::`-` (l: std::int16) -> std::int16 { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Arithmetic subtraction.'; SET volatility := 'Immutable'; USING SQL OPERATOR r'-'; }; CREATE PREFIX OPERATOR std::`-` (l: std::int32) -> std::int32 { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Arithmetic subtraction.'; 
SET volatility := 'Immutable'; USING SQL OPERATOR r'-'; }; CREATE PREFIX OPERATOR std::`-` (l: std::int64) -> std::int64 { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Arithmetic subtraction.'; SET volatility := 'Immutable'; USING SQL OPERATOR r'-'; }; CREATE PREFIX OPERATOR std::`-` (l: std::float32) -> std::float32 { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Arithmetic subtraction.'; SET volatility := 'Immutable'; USING SQL OPERATOR r'-'; }; CREATE PREFIX OPERATOR std::`-` (l: std::float64) -> std::float64 { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Arithmetic subtraction.'; SET volatility := 'Immutable'; USING SQL OPERATOR r'-'; }; CREATE PREFIX OPERATOR std::`-` (l: std::bigint) -> std::bigint { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Arithmetic subtraction.'; SET volatility := 'Immutable'; SET force_return_cast := true; USING SQL OPERATOR r'-(,numeric)'; }; CREATE PREFIX OPERATOR std::`-` (l: std::decimal) -> std::decimal { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Arithmetic subtraction.'; SET volatility := 'Immutable'; USING SQL OPERATOR r'-'; }; # MUL CREATE INFIX OPERATOR std::`*` (l: std::int16, r: std::int16) -> std::int16 { CREATE ANNOTATION std::identifier := 'mult'; CREATE ANNOTATION std::description := 'Arithmetic multiplication.'; SET volatility := 'Immutable'; SET commutator := 'std::*'; USING SQL OPERATOR r'*'; }; CREATE INFIX OPERATOR std::`*` (l: std::int32, r: std::int32) -> std::int32 { CREATE ANNOTATION std::identifier := 'mult'; CREATE ANNOTATION std::description := 'Arithmetic multiplication.'; SET volatility := 'Immutable'; SET commutator := 'std::*'; USING SQL OPERATOR r'*'; }; CREATE INFIX OPERATOR std::`*` (l: std::int64, r: std::int64) -> std::int64 { CREATE ANNOTATION std::identifier := 'mult'; CREATE ANNOTATION 
std::description := 'Arithmetic multiplication.'; SET volatility := 'Immutable'; SET commutator := 'std::*'; USING SQL OPERATOR r'*'; }; CREATE INFIX OPERATOR std::`*` (l: std::float32, r: std::float32) -> std::float32 { CREATE ANNOTATION std::identifier := 'mult'; CREATE ANNOTATION std::description := 'Arithmetic multiplication.'; SET volatility := 'Immutable'; SET commutator := 'std::*'; USING SQL OPERATOR r'*'; }; CREATE INFIX OPERATOR std::`*` (l: std::float64, r: std::float64) -> std::float64 { CREATE ANNOTATION std::identifier := 'mult'; CREATE ANNOTATION std::description := 'Arithmetic multiplication.'; SET volatility := 'Immutable'; SET commutator := 'std::*'; USING SQL OPERATOR r'*'; }; CREATE INFIX OPERATOR std::`*` (l: std::bigint, r: std::bigint) -> std::bigint { CREATE ANNOTATION std::identifier := 'mult'; CREATE ANNOTATION std::description := 'Arithmetic multiplication.'; SET volatility := 'Immutable'; SET commutator := 'std::*'; SET force_return_cast := true; USING SQL OPERATOR r'*(numeric,numeric)'; }; CREATE INFIX OPERATOR std::`*` (l: std::decimal, r: std::decimal) -> std::decimal { CREATE ANNOTATION std::identifier := 'mult'; CREATE ANNOTATION std::description := 'Arithmetic multiplication.'; SET volatility := 'Immutable'; SET commutator := 'std::*'; USING SQL OPERATOR r'*'; }; # DIV CREATE INFIX OPERATOR std::`/` (l: std::int64, r: std::int64) -> std::float64 { CREATE ANNOTATION std::identifier := 'div'; CREATE ANNOTATION std::description := 'Arithmetic division.'; SET volatility := 'Immutable'; # We need both USING SQL OPERATOR and USING SQL to copy # the common attributes of the SQL division operator while # overriding the main operator function. 
USING SQL OPERATOR r'/'; USING SQL 'SELECT "l" / ("r"::float8)'; }; CREATE INFIX OPERATOR std::`/` (l: std::float32, r: std::float32) -> std::float32 { CREATE ANNOTATION std::identifier := 'div'; CREATE ANNOTATION std::description := 'Arithmetic division.'; SET volatility := 'Immutable'; USING SQL OPERATOR r'/'; }; CREATE INFIX OPERATOR std::`/` (l: std::float64, r: std::float64) -> std::float64 { CREATE ANNOTATION std::identifier := 'div'; CREATE ANNOTATION std::description := 'Arithmetic division.'; SET volatility := 'Immutable'; USING SQL OPERATOR r'/'; }; CREATE INFIX OPERATOR std::`/` (l: std::decimal, r: std::decimal) -> std::decimal { CREATE ANNOTATION std::identifier := 'div'; CREATE ANNOTATION std::description := 'Arithmetic division.'; SET volatility := 'Immutable'; USING SQL OPERATOR r'/'; }; # FLOORDIV # PostgreSQL uses truncation division, so the -12 % 5 is -2, because -12 // 5 # is -2, but EdgeQL uses floor division, so -12 // 5 is -3, and so -12 % 5 # must be 3. The correct divmod behavior is implemented via the floor # function working specifically with numeric type. The numeric value needs to # be forced into using arbitrary precision by getting multiplied by # 1.0::numeric. CREATE INFIX OPERATOR std::`//` (n: std::int16, d: std::int16) -> std::int16 { CREATE ANNOTATION std::identifier := 'floordiv'; CREATE ANNOTATION std::description := 'Floor division. Result is rounded down to the nearest integer'; SET volatility := 'Immutable'; # We need both USING SQL OPERATOR and USING SQL FUNCTION to copy # the common attributes of the SQL division operator while # overriding the main operator function. USING SQL OPERATOR r'/'; USING SQL 'SELECT floor(1.0::numeric * "n"::numeric / "d"::numeric)::int2'; }; CREATE INFIX OPERATOR std::`//` (n: std::int32, d: std::int32) -> std::int32 { CREATE ANNOTATION std::identifier := 'floordiv'; CREATE ANNOTATION std::description := 'Floor division. 
Result is rounded down to the nearest integer'; SET volatility := 'Immutable'; USING SQL OPERATOR r'/'; USING SQL 'SELECT floor(1.0::numeric * "n"::numeric / "d"::numeric)::int4'; }; CREATE INFIX OPERATOR std::`//` (n: std::int64, d: std::int64) -> std::int64 { CREATE ANNOTATION std::identifier := 'floordiv'; CREATE ANNOTATION std::description := 'Floor division. Result is rounded down to the nearest integer'; SET volatility := 'Immutable'; USING SQL OPERATOR r'/'; USING SQL 'SELECT floor(1.0::numeric * "n"::numeric / "d"::numeric)::int8'; }; CREATE INFIX OPERATOR std::`//` (n: std::float32, d: std::float32) -> std::float32 { CREATE ANNOTATION std::identifier := 'floordiv'; CREATE ANNOTATION std::description := 'Floor division. Result is rounded down to the nearest integer'; SET volatility := 'Immutable'; USING SQL OPERATOR r'/'; USING SQL 'SELECT floor("n" / "d")::float4'; }; CREATE INFIX OPERATOR std::`//` (n: std::float64, d: std::float64) -> std::float64 { CREATE ANNOTATION std::identifier := 'floordiv'; CREATE ANNOTATION std::description := 'Floor division. Result is rounded down to the nearest integer'; SET volatility := 'Immutable'; USING SQL OPERATOR r'/'; USING SQL 'SELECT floor("n" / "d")'; }; CREATE INFIX OPERATOR std::`//` (n: std::bigint, d: std::bigint) -> std::bigint { CREATE ANNOTATION std::identifier := 'floordiv'; CREATE ANNOTATION std::description := 'Floor division. Result is rounded down to the nearest integer'; SET volatility := 'Immutable'; USING SQL OPERATOR r'/(numeric,numeric)'; USING SQL $$ SELECT floor( 1.0::numeric * "n"::numeric / "d"::numeric )::edgedbt.bigint_t; $$; }; CREATE INFIX OPERATOR std::`//` (n: std::decimal, d: std::decimal) -> std::decimal { CREATE ANNOTATION std::identifier := 'floordiv'; CREATE ANNOTATION std::description := 'Floor division. 
Result is rounded down to the nearest integer'; SET volatility := 'Immutable'; USING SQL OPERATOR r'/'; USING SQL 'SELECT floor("n" / "d");' }; # MODULO # We have 2 issues to deal with: # 1) Postgres will produce a negative remainder for a positive divisor, # whereas generally it's a bit more intuitive to have the remainder in the # range [0, divisor). # 2) When implementing the modulo operator we need to make sure that addition # or subtraction doesn't cause an overflow. # # The easiest way to avoid overflow errors is to upcast values to a larger # integer type. However, upcasting int64 to bigint and back is very slow # (5x-6x slower), so we need a different approach here. # # The breakdown is like this: # - We only want to add `d` if `n` and `d` have opposite signs. # - XOR helps to isolate the sign bit if it is different. # - Right arithmetic shift by 63 bits produces an "all 1" bitmask for # negative integers and 0 otherwise. # - Performing AND using the above bitmask makes `d` go away if # `sign(n) = sign(d)` and keeps it as is otherwise. # - Finally we want to perform another MOD `d` operation to address the corner # case of 10 % -5 = -5 instead of 0 (which is equivalent, but does not # conform to making 0 inclusive and `d` itself exclusive). # # According to our microbenchmarks this kind of bit magic is no worse and # maybe slightly better than upcasting for int16 and int32 cases.
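#
# Worked example of the breakdown above for the int64 form,
# ((n % d) + (d & ((n # d) >> 63))) % d, traced by hand. The values are
# illustrative and were chosen for this note; they are not taken from the
# test suite. `%` here is the truncating Postgres operator.
#   n = -12, d = 5:   n % d = -2;  n # d = -15 (signs differ), so the shift
#                     yields -1 (all bits set);  d & -1 = 5;
#                     (-2 + 5) % 5 = 3.
#   n = 12,  d = 5:   n % d = 2;   n # d = 9 (signs agree), so the shift
#                     yields 0;  d & 0 = 0;  (2 + 0) % 5 = 2.
#   n = 10,  d = -5:  n % d = 0;   n # d = -15, so the mask keeps d = -5;
#                     (0 + -5) % -5 = 0, which is the corner case that the
#                     final MOD `d` addresses.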
CREATE INFIX OPERATOR std::`%` (n: std::int16, d: std::int16) -> std::int16 { CREATE ANNOTATION std::identifier := 'mod'; CREATE ANNOTATION std::description := 'Remainder from division (modulo).'; SET volatility := 'Immutable'; USING SQL OPERATOR r'%'; USING SQL $$ SELECT ( (n % d) + (d & ((n # d)>>15::int4)) ) % d $$; }; CREATE INFIX OPERATOR std::`%` (n: std::int32, d: std::int32) -> std::int32 { CREATE ANNOTATION std::identifier := 'mod'; CREATE ANNOTATION std::description := 'Remainder from division (modulo).'; SET volatility := 'Immutable'; USING SQL OPERATOR r'%'; USING SQL $$ SELECT ( (n % d) + (d & ((n # d)>>31::int4)) ) % d $$; }; CREATE INFIX OPERATOR std::`%` (n: std::int64, d: std::int64) -> std::int64 { CREATE ANNOTATION std::identifier := 'mod'; CREATE ANNOTATION std::description := 'Remainder from division (modulo).'; SET volatility := 'Immutable'; USING SQL OPERATOR r'%'; USING SQL $$ SELECT ( (n % d) + (d & ((n # d)>>63::int4)) ) % d $$; }; CREATE INFIX OPERATOR std::`%` (n: std::float32, d: std::float32) -> std::float32 { CREATE ANNOTATION std::identifier := 'mod'; CREATE ANNOTATION std::description := 'Remainder from division (modulo).'; SET volatility := 'Immutable'; # We cheat here a bit by copying most of SQL operator metadata # from the `/` operator, since there is no float % in Postgres. 
USING SQL OPERATOR r'/'; USING SQL $$ SELECT n - floor(n / d)::float4 * d; $$; }; CREATE INFIX OPERATOR std::`%` (n: std::float64, d: std::float64) -> std::float64 { CREATE ANNOTATION std::identifier := 'mod'; CREATE ANNOTATION std::description := 'Remainder from division (modulo).'; SET volatility := 'Immutable'; USING SQL OPERATOR r'/'; USING SQL $$ SELECT n - floor(n / d) * d; $$; }; CREATE INFIX OPERATOR std::`%` (n: std::bigint, d: std::bigint) -> std::bigint { CREATE ANNOTATION std::identifier := 'mod'; CREATE ANNOTATION std::description := 'Remainder from division (modulo).'; SET volatility := 'Immutable'; USING SQL OPERATOR r'%(numeric,numeric)'; USING SQL $$ SELECT (((n % d) + d) % d)::edgedbt.bigint_t; $$; }; CREATE INFIX OPERATOR std::`%` (n: std::decimal, d: std::decimal) -> std::decimal { CREATE ANNOTATION std::identifier := 'mod'; CREATE ANNOTATION std::description := 'Remainder from division (modulo).'; SET volatility := 'Immutable'; USING SQL OPERATOR r'%'; USING SQL $$ SELECT ((n % d) + d) % d; $$; }; # need an explicit operator for int64 in order to guarantee the result # is float64 and not decimal CREATE INFIX OPERATOR std::`^` (n: std::int64, p: std::int64) -> std::float64 { CREATE ANNOTATION std::identifier := 'pow'; CREATE ANNOTATION std::description := 'Power operation.'; SET volatility := 'Immutable'; # We cheat here a bit by copying most of SQL operator metadata # from the `/` operator, since there is no int ^ in Postgres. The # power operator can behave like a division (negative power), # therefore it should have the same basic properties w.r.t. types, # etc. We don't use an explicit cast of the result because # Postgres will treat this as float8 already. 
USING SQL OPERATOR r'/'; USING SQL 'SELECT ("n" ^ "p")'; }; CREATE INFIX OPERATOR std::`^` (n: std::float32, p: std::float32) -> std::float32 { CREATE ANNOTATION std::identifier := 'pow'; CREATE ANNOTATION std::description := 'Power operation.'; SET volatility := 'Immutable'; # We cheat here a bit by copying most of SQL operator metadata # from the `/` operator, since there is no float4 ^ in Postgres. # The power operator can behave like a division (negative power), # therefore it should have the same basic properties w.r.t. types, # etc. USING SQL OPERATOR '/'; USING SQL 'SELECT ("n" ^ "p")::float4'; }; CREATE INFIX OPERATOR std::`^` (n: std::float64, p: std::float64) -> std::float64 { CREATE ANNOTATION std::identifier := 'pow'; CREATE ANNOTATION std::description := 'Power operation.'; SET volatility := 'Immutable'; USING SQL OPERATOR '^'; }; CREATE INFIX OPERATOR std::`^` (n: std::bigint, p: std::bigint) -> std::decimal { CREATE ANNOTATION std::identifier := 'pow'; CREATE ANNOTATION std::description := 'Power operation.'; SET volatility := 'Immutable'; SET force_return_cast := true; USING SQL OPERATOR '^(numeric,numeric)'; }; CREATE INFIX OPERATOR std::`^` (n: std::decimal, p: std::decimal) -> std::decimal { CREATE ANNOTATION std::identifier := 'pow'; CREATE ANNOTATION std::description := 'Power operation.'; SET volatility := 'Immutable'; USING SQL OPERATOR '^'; }; ## Standard numeric casts ## ---------------------- ## Implicit casts between numerics. 
CREATE CAST FROM std::int16 TO std::int32 { SET volatility := 'Immutable'; USING SQL CAST; ALLOW IMPLICIT; }; CREATE CAST FROM std::int32 TO std::int64 { SET volatility := 'Immutable'; USING SQL CAST; ALLOW IMPLICIT; }; CREATE CAST FROM std::int16 TO std::float32 { SET volatility := 'Immutable'; USING SQL CAST; ALLOW IMPLICIT; }; CREATE CAST FROM std::int64 TO std::float64 { SET volatility := 'Immutable'; USING SQL CAST; ALLOW IMPLICIT; }; CREATE CAST FROM std::int64 TO std::bigint { SET volatility := 'Immutable'; USING SQL CAST; ALLOW IMPLICIT; }; CREATE CAST FROM std::int64 TO std::decimal { SET volatility := 'Immutable'; USING SQL CAST; ALLOW IMPLICIT; }; CREATE CAST FROM std::bigint TO std::decimal { SET volatility := 'Immutable'; USING SQL CAST; ALLOW IMPLICIT; }; CREATE CAST FROM std::float32 TO std::float64 { SET volatility := 'Immutable'; USING SQL CAST; ALLOW IMPLICIT; }; ## Explicit and assignment casts. CREATE CAST FROM std::int32 TO std::int16 { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::int64 TO std::int32 { SET volatility := 'Immutable'; USING SQL CAST; ALLOW ASSIGNMENT; }; CREATE CAST FROM std::int64 TO std::int16 { SET volatility := 'Immutable'; USING SQL CAST; ALLOW ASSIGNMENT; }; CREATE CAST FROM std::int64 TO std::float32 { SET volatility := 'Immutable'; USING SQL CAST; ALLOW ASSIGNMENT; }; CREATE CAST FROM std::float64 TO std::float32 { SET volatility := 'Immutable'; USING SQL CAST; ALLOW ASSIGNMENT; }; CREATE CAST FROM std::decimal TO std::int16 { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::decimal TO std::int32 { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::decimal TO std::int64 { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::decimal TO std::float64 { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::decimal TO std::float32 { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::decimal TO 
std::bigint { SET volatility := 'Immutable'; USING SQL 'SELECT round($1)::edgedbt.bigint_t'; }; CREATE CAST FROM std::float32 TO std::int16 { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::float32 TO std::int32 { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::float32 TO std::int64 { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::float32 TO std::bigint { SET volatility := 'Immutable'; USING SQL 'SELECT round($1)::edgedbt.bigint_t'; }; CREATE CAST FROM std::float32 TO std::decimal { SET volatility := 'Immutable'; USING SQL $$ SELECT (CASE WHEN val != 'NaN' AND val != 'Infinity' AND val != '-Infinity' THEN val::numeric WHEN val IS NULL THEN NULL::numeric ELSE edgedb_VER.raise( NULL::numeric, 'invalid_text_representation', msg => 'invalid value for numeric: ' || quote_literal(val) ) END) ; $$; }; CREATE CAST FROM std::float64 TO std::int16 { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::float64 TO std::int32 { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::float64 TO std::int64 { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::float64 TO std::bigint { SET volatility := 'Immutable'; USING SQL 'SELECT round($1)::edgedbt.bigint_t'; }; CREATE CAST FROM std::float64 TO std::decimal { SET volatility := 'Immutable'; USING SQL $$ SELECT (CASE WHEN val != 'NaN' AND val != 'Infinity' AND val != '-Infinity' THEN val::numeric WHEN val IS NULL THEN NULL::numeric ELSE edgedb_VER.raise( NULL::numeric, 'invalid_text_representation', msg => 'invalid value for numeric: ' || quote_literal(val) ) END) ; $$; }; ## String casts. 
CREATE CAST FROM std::str TO std::int16 { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::str TO std::int32 { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::str TO std::int64 { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::str TO std::float32 { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::str TO std::float64 { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::str TO std::bigint { SET volatility := 'Immutable'; USING SQL FUNCTION 'edgedb.str_to_bigint'; }; CREATE CAST FROM std::str TO std::decimal { SET volatility := 'Immutable'; USING SQL FUNCTION 'edgedb.str_to_decimal'; }; CREATE CAST FROM std::int16 TO std::str { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::int32 TO std::str { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::int64 TO std::str { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::float32 TO std::str { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::float64 TO std::str { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::decimal TO std::str { SET volatility := 'Immutable'; USING SQL CAST; }; ================================================ FILE: edb/lib/std/25-setoperators.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # ## Standard set operators ## -------------------------- # The set membership operators (IN, NOT IN) are defined # in terms of the corresponding equality operator. CREATE INFIX OPERATOR std::`IN` (e: anytype, s: SET OF anytype) -> std::bool { CREATE ANNOTATION std::identifier := 'in'; CREATE ANNOTATION std::description := 'Test the membership of an element in a set.'; USING SQL EXPRESSION; SET volatility := 'Immutable'; SET derivative_of := 'std::='; SET is_singleton_set_of := true; }; CREATE INFIX OPERATOR std::`NOT IN` (e: anytype, s: SET OF anytype) -> std::bool { CREATE ANNOTATION std::identifier := 'not_in'; CREATE ANNOTATION std::description := 'Test the membership of an element in a set.'; USING SQL EXPRESSION; SET volatility := 'Immutable'; SET derivative_of := 'std::!='; SET is_singleton_set_of := true; }; CREATE PREFIX OPERATOR std::`EXISTS` (s: SET OF anytype) -> bool { CREATE ANNOTATION std::identifier := 'exists'; CREATE ANNOTATION std::description := 'Test whether a set is not empty.'; SET volatility := 'Immutable'; SET is_singleton_set_of := true; USING SQL EXPRESSION; }; CREATE PREFIX OPERATOR std::`DISTINCT` (s: SET OF anytype) -> SET OF anytype { CREATE ANNOTATION std::identifier := 'distinct'; CREATE ANNOTATION std::description := 'Return a set without repeating any elements.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`UNION` (s1: SET OF anytype, s2: SET OF anytype) -> SET OF anytype { CREATE ANNOTATION std::identifier := 'union'; CREATE ANNOTATION std::description := 'Merge two sets.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`EXCEPT` (s1: SET OF anytype, s2: SET OF anytype) -> SET OF anytype { CREATE ANNOTATION std::identifier := 'except'; CREATE ANNOTATION std::description := 'Multiset difference.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; 
CREATE INFIX OPERATOR std::`INTERSECT` (s1: SET OF anytype, s2: SET OF anytype) -> SET OF anytype { CREATE ANNOTATION std::identifier := 'intersect'; CREATE ANNOTATION std::description := 'Multiset intersection.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`??` (l: OPTIONAL anytype, r: SET OF anytype) -> SET OF anytype { CREATE ANNOTATION std::identifier := 'coalesce'; CREATE ANNOTATION std::description := 'Coalesce.'; SET volatility := 'Immutable'; SET is_singleton_set_of := true; USING SQL EXPRESSION; }; CREATE TERNARY OPERATOR std::`IF` (if_true: SET OF anytype, condition: bool, if_false: SET OF anytype) -> SET OF anytype { CREATE ANNOTATION std::identifier := 'if_else'; CREATE ANNOTATION std::description := 'Conditionally provide one or the other result.'; SET volatility := 'Immutable'; SET is_singleton_set_of := true; USING SQL EXPRESSION; }; ================================================ FILE: edb/lib/std/26-bitwisefuncs.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2022-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ## Bitwise numeric functions ## ------------------------- CREATE FUNCTION std::bit_and(l: std::int16, r: std::int16) -> std::int16 { CREATE ANNOTATION std::description := 'Bitwise AND operator for 16-bit integers.'; SET volatility := 'Immutable'; USING SQL $$ SELECT l & r $$; }; CREATE FUNCTION std::bit_and(l: std::int32, r: std::int32) -> std::int32 { CREATE ANNOTATION std::description := 'Bitwise AND operator for 32-bit integers.'; SET volatility := 'Immutable'; USING SQL $$ SELECT l & r $$; }; CREATE FUNCTION std::bit_and(l: std::int64, r: std::int64) -> std::int64 { CREATE ANNOTATION std::description := 'Bitwise AND operator for 64-bit integers.'; SET volatility := 'Immutable'; USING SQL $$ SELECT l & r $$; }; CREATE FUNCTION std::bit_or(l: std::int16, r: std::int16) -> std::int16 { CREATE ANNOTATION std::description := 'Bitwise OR operator for 16-bit integers.'; SET volatility := 'Immutable'; USING SQL $$ SELECT l | r $$; }; CREATE FUNCTION std::bit_or(l: std::int32, r: std::int32) -> std::int32 { CREATE ANNOTATION std::description := 'Bitwise OR operator for 32-bit integers.'; SET volatility := 'Immutable'; USING SQL $$ SELECT l | r $$; }; CREATE FUNCTION std::bit_or(l: std::int64, r: std::int64) -> std::int64 { CREATE ANNOTATION std::description := 'Bitwise OR operator for 64-bit integers.'; SET volatility := 'Immutable'; USING SQL $$ SELECT l | r $$; }; CREATE FUNCTION std::bit_xor(l: std::int16, r: std::int16) -> std::int16 { CREATE ANNOTATION std::description := 'Bitwise exclusive OR operator for 16-bit integers.'; SET volatility := 'Immutable'; USING SQL $$ SELECT l # r $$; }; CREATE FUNCTION std::bit_xor(l: std::int32, r: std::int32) -> std::int32 { CREATE ANNOTATION std::description := 'Bitwise exclusive OR operator for 32-bit integers.'; SET volatility := 'Immutable'; USING SQL $$ SELECT l # r $$; }; CREATE FUNCTION std::bit_xor(l: std::int64, r: std::int64) -> std::int64 { CREATE ANNOTATION std::description := 'Bitwise exclusive OR operator for 
64-bit integers.'; SET volatility := 'Immutable'; USING SQL $$ SELECT l # r $$; }; CREATE FUNCTION std::bit_not(r: std::int16) -> std::int16 { CREATE ANNOTATION std::description := 'Bitwise NOT operator for 16-bit integers.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ~r $$; }; CREATE FUNCTION std::bit_not(r: std::int32) -> std::int32 { CREATE ANNOTATION std::description := 'Bitwise NOT operator for 32-bit integers.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ~r $$; }; CREATE FUNCTION std::bit_not(r: std::int64) -> std::int64 { CREATE ANNOTATION std::description := 'Bitwise NOT operator for 64-bit integers.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ~r $$; }; # In Postgres bitwise shift operators accept a 32-bit integer as the number of # bit positions that need to be shifted. However, in EdgeDB the default # integer literal is int64, so we should accept that and cast it down inside # the function body. # # In Postgres the number of bits shifted gets truncated using a positive mod # 32 (or mod 64 for int8). We do not want such truncation in EdgeDB. Shifting by 20 bits 2 times # should be the same as shifting by 40 bits once.
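#
# Hand-traced illustration of the CASE expressions in the functions that
# follow, using the int32 overloads. The values are chosen for this note and
# are not taken from the test suite.
#   bit_lshift(<int32>1, 40):  n > 31, so the result is 0. Shifting 1 left
#     by 20 twice should give the same answer, because the set bit falls
#     off the 32-bit value.
#   bit_rshift(<int32>-8, 40): n > 31 and val < 0, so the result is -1.
#     Shifting -8 right by 20 twice should give the same answer, because
#     the sign bit fills every position.
# Under the mod-32 truncation described above, a raw 40-bit shift would
# instead act like an 8-bit shift.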
CREATE FUNCTION std::bit_rshift(val: std::int16, n: std::int64) -> std::int16 { CREATE ANNOTATION std::description := 'Bitwise right-shift operator for 16-bit integers.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN n < 0 THEN edgedb_VER.raise( NULL::int8, 'invalid_parameter_value', msg => ( 'bit_rshift(): cannot shift by negative amount' ) ) WHEN n > 31 THEN CASE WHEN val < 0 THEN -1 ELSE 0 END ELSE val >> n::int4 END ) $$; }; CREATE FUNCTION std::bit_rshift(val: std::int32, n: std::int64) -> std::int32 { CREATE ANNOTATION std::description := 'Bitwise right-shift operator for 32-bit integers.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN n < 0 THEN edgedb_VER.raise( NULL::int8, 'invalid_parameter_value', msg => ( 'bit_rshift(): cannot shift by negative amount' ) ) WHEN n > 31 THEN CASE WHEN val < 0 THEN -1 ELSE 0 END ELSE val >> n::int4 END ) $$; }; CREATE FUNCTION std::bit_rshift(val: std::int64, n: std::int64) -> std::int64 { CREATE ANNOTATION std::description := 'Bitwise right-shift operator for 64-bit integers.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN n < 0 THEN edgedb_VER.raise( NULL::int8, 'invalid_parameter_value', msg => ( 'bit_rshift(): cannot shift by negative amount' ) ) WHEN n > 63 THEN CASE WHEN val < 0 THEN -1 ELSE 0 END ELSE val >> n::int4 END ) $$; }; CREATE FUNCTION std::bit_lshift(val: std::int16, n: std::int64) -> std::int16 { CREATE ANNOTATION std::description := 'Bitwise left-shift operator for 16-bit integers.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN n < 0 THEN edgedb_VER.raise( NULL::int8, 'invalid_parameter_value', msg => ( 'bit_lshift(): cannot shift by negative amount' ) ) WHEN n > 31 THEN 0 ELSE val << n::int4 END ) $$; }; CREATE FUNCTION std::bit_lshift(val: std::int32, n: std::int64) -> std::int32 { CREATE ANNOTATION std::description := 'Bitwise left-shift operator for 32-bit integers.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN n 
< 0 THEN edgedb_VER.raise( NULL::int8, 'invalid_parameter_value', msg => ( 'bit_lshift(): cannot shift by negative amount' ) ) WHEN n > 31 THEN 0 ELSE val << n::int4 END ) $$; }; CREATE FUNCTION std::bit_lshift(val: std::int64, n: std::int64) -> std::int64 { CREATE ANNOTATION std::description := 'Bitwise left-shift operator for 64-bit integers.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN n < 0 THEN edgedb_VER.raise( NULL::int8, 'invalid_parameter_value', msg => ( 'bit_lshift(): cannot shift by negative amount' ) ) WHEN n > 63 THEN 0 ELSE val << n::int4 END ) $$; }; CREATE FUNCTION std::bit_count(val: std::int16) -> std::int64 { CREATE ANNOTATION std::description := 'Count the number of set bits in a 16-bit integer.'; SET volatility := 'Immutable'; USING SQL $$ SELECT bit_count(val::int4::bit(16)) $$; }; CREATE FUNCTION std::bit_count(val: std::int32) -> std::int64 { CREATE ANNOTATION std::description := 'Count the number of set bits in a 32-bit integer.'; SET volatility := 'Immutable'; USING SQL $$ SELECT bit_count(val::bit(32)) $$; }; CREATE FUNCTION std::bit_count(val: std::int64) -> std::int64 { CREATE ANNOTATION std::description := 'Count the number of set bits in a 64-bit integer.'; SET volatility := 'Immutable'; USING SQL $$ SELECT bit_count(val::bit(64)) $$; }; ================================================ FILE: edb/lib/std/30-arrayfuncs.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ## Array functions CREATE FUNCTION std::array_agg(s: SET OF anytype) -> array { CREATE ANNOTATION std::description := 'Return the array made from all of the input set elements.'; SET volatility := 'Immutable'; SET initial_value := []; SET impl_is_strict := false; USING SQL FUNCTION 'array_agg'; }; CREATE FUNCTION std::array_unpack(array: array) -> SET OF anytype { CREATE ANNOTATION std::description := 'Return array elements as a set.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'unnest'; }; CREATE FUNCTION std::array_fill(val: anytype, n: std::int64) -> array { CREATE ANNOTATION std::description := 'Return an array filled with the given value repeated \ as many times as specified.'; SET volatility := 'Immutable'; # Postgres uses integer (int4) as the second argument. There is a maximum # array size, however. So when we get an `n` value greater than maximum # int4, we just truncate it to the maximum and let Postgres produce its # error. 
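# Illustrative call (a sketch, not part of the original source; the literal
# values are arbitrary):
#   SELECT array_fill(0, 3);
# builds a three-element array of zeros. An `n` above 2147483647 is clamped
# by the CASE expression below, so the oversized request reaches Postgres,
# which is expected to reject it as exceeding the maximum array size.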
USING SQL $$ SELECT array_fill( val, ARRAY[(CASE WHEN n > 2147483647 THEN 2147483647 ELSE n END)::int4] ) $$; }; CREATE FUNCTION std::array_replace( array: array, old: anytype, new: anytype ) -> array { CREATE ANNOTATION std::description := 'Replace each array element equal to the second argument \ with the third argument.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'array_replace'; }; CREATE FUNCTION std::array_get( array: array, idx: std::int64, NAMED ONLY default: OPTIONAL anytype={} ) -> OPTIONAL anytype { CREATE ANNOTATION std::description := 'Return the element of *array* at the specified *index*.'; SET volatility := 'Immutable'; USING SQL $$ SELECT COALESCE( "array"[ edgedb_VER._normalize_array_index( "idx"::int, array_upper("array", 1)) ], "default" ) $$; }; CREATE FUNCTION std::array_set( array: array, idx: std::int64, val: anytype ) -> array { CREATE ANNOTATION std::description := 'Set the element of *array* at the specified *index*.'; SET volatility := 'Immutable'; USING SQL $$ SELECT CASE WHEN cardinality("array") = 0 THEN edgedb.raise( "array", 'invalid_parameter_value', msg => 'array index ' || idx::text || ' is out of bounds' ) WHEN edgedb._normalize_array_index( "idx"::int, array_upper("array", 1) ) NOT BETWEEN 1 and array_upper("array", 1) THEN edgedb.raise( "array", 'invalid_parameter_value', msg => 'array index ' || idx::text || ' is out of bounds' ) WHEN edgedb._normalize_array_index( "idx"::int, array_upper("array", 1) ) = 1 THEN ARRAY[val] || "array"[2 :] WHEN edgedb._normalize_array_index( "idx"::int, array_upper("array", 1) ) = array_upper("array", 1) THEN "array"[: array_upper("array", 1) - 1] || ARRAY[val] ELSE "array"[ : edgedb._normalize_array_index( "idx"::int, array_upper("array", 1) ) - 1 ] || ARRAY[val] || "array"[ edgedb._normalize_array_index( "idx"::int, array_upper("array", 1) ) + 1 : ] END $$; }; CREATE FUNCTION std::array_insert( array: array, idx: std::int64, val: anytype ) -> array { CREATE ANNOTATION std::description 
:= 'Insert *val* at the specified *index* of the *array*.'; SET volatility := 'Immutable'; USING SQL $$ SELECT CASE WHEN cardinality("array") = 0 AND "idx"::int != 0 THEN edgedb.raise( "array", 'invalid_parameter_value', msg => 'array index ' || idx::text || ' is out of bounds' ) WHEN cardinality("array") = 0 AND "idx"::int = 0 THEN ARRAY[val] WHEN edgedb._normalize_array_index( "idx"::int, array_upper("array", 1) ) NOT BETWEEN 1 and array_upper("array", 1) + 1 THEN edgedb.raise( "array", 'invalid_parameter_value', msg => 'array index ' || idx::text || ' is out of bounds' ) WHEN edgedb._normalize_array_index( "idx"::int, array_upper("array", 1) ) = 1 THEN ARRAY[val] || "array" WHEN edgedb._normalize_array_index( "idx"::int, array_upper("array", 1) ) = array_upper("array", 1) + 1 THEN "array" || ARRAY[val] ELSE "array"[ : edgedb._normalize_array_index( "idx"::int, array_upper("array", 1) ) - 1 ] || ARRAY[val] || "array"[ edgedb._normalize_array_index( "idx"::int, array_upper("array", 1) ) : ] END $$; }; CREATE FUNCTION std::array_join(array: array, delimiter: std::str) -> std::str { CREATE ANNOTATION std::description := 'Render an array to a string.'; # The Postgres function array_to_string works for any array type, but we # use it specifically for string arrays. For string arrays it should be # "immutable". 
SET volatility := 'Immutable'; USING SQL $$ SELECT array_to_string("array", "delimiter"); $$; }; CREATE FUNCTION std::array_join(array: array, delimiter: std::bytes) -> std::bytes { CREATE ANNOTATION std::description := 'Render an array to a byte-string.'; SET volatility := 'Immutable'; USING SQL $$ SELECT COALESCE (string_agg(el, "delimiter"), '\x') FROM (SELECT unnest("array") AS el) AS t $$; }; ## Array operators CREATE INFIX OPERATOR std::`=` (l: array, r: array) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; SET recursive := true; SET commutator := 'std::='; SET negator := 'std::!='; USING SQL OPERATOR '='; }; CREATE INFIX OPERATOR std::`?=` (l: OPTIONAL array, r: OPTIONAL array) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; SET recursive := true; }; CREATE INFIX OPERATOR std::`!=` (l: array, r: array) -> std::bool { CREATE ANNOTATION std::identifier := 'neq'; CREATE ANNOTATION std::description := 'Compare two values for inequality.'; SET volatility := 'Immutable'; SET recursive := true; SET commutator := 'std::!='; SET negator := 'std::='; USING SQL OPERATOR '<>'; }; CREATE INFIX OPERATOR std::`?!=` (l: OPTIONAL array, r: OPTIONAL array) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_neq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; SET recursive := true; }; CREATE INFIX OPERATOR std::`>=` (l: array, r: array) -> std::bool { CREATE ANNOTATION std::identifier := 'gte'; CREATE ANNOTATION std::description := 'Greater than or equal.'; SET volatility := 'Immutable'; SET recursive := true; SET commutator := 'std::<='; SET negator := 'std::<'; USING SQL OPERATOR '>='; }; CREATE 
INFIX OPERATOR std::`>` (l: array, r: array) -> std::bool { CREATE ANNOTATION std::identifier := 'gt'; CREATE ANNOTATION std::description := 'Greater than.'; SET volatility := 'Immutable'; SET recursive := true; SET commutator := 'std::<'; SET negator := 'std::<='; USING SQL OPERATOR '>'; }; CREATE INFIX OPERATOR std::`<=` (l: array, r: array) -> std::bool { CREATE ANNOTATION std::identifier := 'lte'; CREATE ANNOTATION std::description := 'Less than or equal.'; SET volatility := 'Immutable'; SET recursive := true; SET commutator := 'std::>='; SET negator := 'std::>'; USING SQL OPERATOR '<='; }; CREATE INFIX OPERATOR std::`<` (l: array, r: array) -> std::bool { CREATE ANNOTATION std::identifier := 'lt'; CREATE ANNOTATION std::description := 'Less than.'; SET volatility := 'Immutable'; SET recursive := true; SET commutator := 'std::>'; SET negator := 'std::>='; USING SQL OPERATOR '<'; }; # Concatenation CREATE INFIX OPERATOR std::`++` (l: array, r: array) -> array { CREATE ANNOTATION std::identifier := 'concat'; CREATE ANNOTATION std::description := 'Array concatenation.'; SET volatility := 'Immutable'; SET impl_is_strict := false; USING SQL FUNCTION 'array_cat'; }; CREATE INFIX OPERATOR std::`[]` (l: array, r: std::int64) -> anytype { CREATE ANNOTATION std::identifier := 'index'; CREATE ANNOTATION std::description := 'Array indexing.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`[]` (l: array, r: tuple) -> array { CREATE ANNOTATION std::identifier := 'slice'; CREATE ANNOTATION std::description := 'Array slicing.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; ================================================ FILE: edb/lib/std/30-bytesfuncs.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ## Byte string functions ## --------------------- CREATE FUNCTION std::bytes_get_bit(bytes: std::bytes, num: int64) -> std::int64 { CREATE ANNOTATION std::description := 'Get the *nth* bit of the *bytes* value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT get_bit("bytes", "num"::int)::bigint $$; }; CREATE FUNCTION std::bit_count(bytes: std::bytes) -> std::int64 { CREATE ANNOTATION std::description := 'Count the number of set bits in the bytes value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT bit_count(bytes) $$; }; ## Byte string operators ## --------------------- CREATE INFIX OPERATOR std::`=` (l: std::bytes, r: std::bytes) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; SET commutator := 'std::='; SET negator := 'std::!='; USING SQL OPERATOR r'='; }; CREATE INFIX OPERATOR std::`?=` (l: OPTIONAL std::bytes, r: OPTIONAL std::bytes) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`!=` (l: std::bytes, r: std::bytes) -> std::bool { CREATE ANNOTATION std::identifier := 'neq'; CREATE ANNOTATION std::description := 'Compare two values for inequality.'; SET volatility := 'Immutable'; SET commutator := 'std::!='; SET 
negator := 'std::='; USING SQL OPERATOR r'<>'; }; CREATE INFIX OPERATOR std::`?!=` (l: OPTIONAL std::bytes, r: OPTIONAL std::bytes) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_neq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`++` (l: std::bytes, r: std::bytes) -> std::bytes { CREATE ANNOTATION std::identifier := 'concat'; CREATE ANNOTATION std::description := 'Bytes concatenation.'; SET volatility := 'Immutable'; USING SQL OPERATOR r'||'; }; CREATE INFIX OPERATOR std::`>=` (l: std::bytes, r: std::bytes) -> std::bool { CREATE ANNOTATION std::identifier := 'gte'; CREATE ANNOTATION std::description := 'Greater than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::<='; SET negator := 'std::<'; USING SQL OPERATOR '>='; }; CREATE INFIX OPERATOR std::`>` (l: std::bytes, r: std::bytes) -> std::bool { CREATE ANNOTATION std::identifier := 'gt'; CREATE ANNOTATION std::description := 'Greater than.'; SET volatility := 'Immutable'; SET commutator := 'std::<'; SET negator := 'std::<='; USING SQL OPERATOR '>'; }; CREATE INFIX OPERATOR std::`<=` (l: std::bytes, r: std::bytes) -> std::bool { CREATE ANNOTATION std::identifier := 'lte'; CREATE ANNOTATION std::description := 'Less than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::>='; SET negator := 'std::>'; USING SQL OPERATOR '<='; }; CREATE INFIX OPERATOR std::`<` (l: std::bytes, r: std::bytes) -> std::bool { CREATE ANNOTATION std::identifier := 'lt'; CREATE ANNOTATION std::description := 'Less than.'; SET volatility := 'Immutable'; SET commutator := 'std::>'; SET negator := 'std::>='; USING SQL OPERATOR '<'; }; CREATE INFIX OPERATOR std::`[]` (l: std::bytes, r: std::int64) -> std::bytes { CREATE ANNOTATION std::identifier := 'index'; CREATE ANNOTATION std::description := 'Bytes indexing.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; 
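# Illustrative usage of the bytes functions and operators in this file.
# This is a sketch, not part of the original source; the literals are
# arbitrary sample values, and no outputs are claimed here.
#
#   SELECT b'ab' ++ b'cd';             # concatenation via `++`
#   SELECT b'abcd'[0];                 # indexing; the result type is std::bytes
#   SELECT b'abcd'[1:3];               # slicing; the result type is std::bytes
#   SELECT bytes_get_bit(b'\x01', 7);  # delegates to Postgres get_bit()
#   SELECT bit_count(b'\x0f');         # delegates to Postgres bit_count()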
CREATE INFIX OPERATOR std::`[]` (l: std::bytes, r: tuple) -> std::bytes { CREATE ANNOTATION std::identifier := 'slice'; CREATE ANNOTATION std::description := 'Bytes slicing.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; ================================================ FILE: edb/lib/std/30-datetimefuncs.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ## Date/time functions ## ------------------- CREATE FUNCTION std::datetime_current() -> std::datetime { CREATE ANNOTATION std::description := 'Return the current server date and time.'; SET volatility := 'Volatile'; SET force_return_cast := true; USING SQL FUNCTION 'clock_timestamp'; }; CREATE FUNCTION std::datetime_of_transaction() -> std::datetime { CREATE ANNOTATION std::description := 'Return the date and time of the start of the current transaction.'; SET volatility := 'Stable'; SET force_return_cast := true; USING SQL FUNCTION 'transaction_timestamp'; }; CREATE FUNCTION std::datetime_of_statement() -> std::datetime { CREATE ANNOTATION std::description := 'Return the date and time of the start of the current statement.'; SET volatility := 'Stable'; SET force_return_cast := true; USING SQL FUNCTION 'statement_timestamp'; }; CREATE FUNCTION std::datetime_get(dt: std::datetime, el: std::str) -> std::float64 { CREATE ANNOTATION std::description := 'Extract a specific element of input datetime by name.'; SET volatility := 'Immutable'; USING SQL $$ SELECT CASE WHEN "el" IN ( 'century', 'day', 'decade', 'dow', 'doy', 'hour', 'isodow', 'isoyear', 'microseconds', 'millennium', 'milliseconds', 'minutes', 'month', 'quarter', 'seconds', 'week', 'year') THEN date_part("el", "dt") WHEN "el" = 'epochseconds' THEN date_part('epoch', "dt") ELSE edgedb_VER.raise( NULL::float, 'invalid_datetime_format', msg => ( 'invalid unit for std::datetime_get: ' || quote_literal("el") ), detail => ( '{"hint":"Supported units: epochseconds, century, day, ' || 'decade, dow, doy, hour, isodow, isoyear, ' || 'microseconds, millennium, milliseconds, minutes, ' || 'month, quarter, seconds, week, year."}' ) ) END $$; }; CREATE FUNCTION std::datetime_truncate(dt: std::datetime, unit: std::str) -> std::datetime { CREATE ANNOTATION std::description := 'Truncate the input datetime to a particular precision.'; # date_trunc of timestamptz is STABLE in PostgreSQL SET volatility := 'Immutable'; USING 
SQL $$ SELECT CASE WHEN "unit" IN ( 'microseconds', 'milliseconds', 'seconds', 'minutes', 'hours', 'days', 'weeks', 'months', 'years', 'decades', 'centuries') THEN date_trunc("unit", "dt")::edgedbt.timestamptz_t WHEN "unit" = 'quarters' THEN date_trunc('quarter', "dt")::edgedbt.timestamptz_t ELSE edgedb_VER.raise( NULL::edgedbt.timestamptz_t, 'invalid_datetime_format', msg => ( 'invalid unit for std::datetime_truncate: ' || quote_literal("unit") ), detail => ( '{"hint":"Supported units: microseconds, milliseconds, ' || 'seconds, minutes, hours, days, weeks, months, ' || 'quarters, years, decades, centuries."}' ) ) END $$; }; CREATE FUNCTION std::duration_get(dt: std::duration, el: std::str) -> std::float64 { CREATE ANNOTATION std::description := 'Extract a specific element of input duration by name.'; SET volatility := 'Immutable'; USING SQL $$ SELECT CASE WHEN "el" IN ( 'hour', 'minutes', 'seconds', 'milliseconds', 'microseconds') THEN date_part("el", "dt") WHEN "el" = 'totalseconds' THEN date_part('epoch', "dt") ELSE edgedb_VER.raise( NULL::float, 'invalid_datetime_format', msg => ( 'invalid unit for std::duration_get: ' || quote_literal("el") ), detail => ( '{"hint":"Supported units: ' || 'hour, minutes, seconds, milliseconds, microseconds, ' || 'and totalseconds."}' ) ) END $$; }; CREATE FUNCTION std::duration_truncate(dt: std::duration, unit: std::str) -> std::duration { CREATE ANNOTATION std::description := 'Truncate the input duration to a particular precision.'; SET volatility := 'Immutable'; USING SQL $$ SELECT CASE WHEN "unit" in ('microseconds', 'milliseconds', 'seconds', 'minutes', 'hours') THEN date_trunc("unit", "dt")::edgedbt.duration_t ELSE edgedb_VER.raise( NULL::edgedbt.duration_t, 'invalid_datetime_format', msg => ( 'invalid unit for std::duration_truncate: ' || quote_literal("unit") ), detail => ( '{"hint":"Supported units: microseconds, milliseconds, ' || 'seconds, minutes, hours."}' ) ) END $$; }; CREATE FUNCTION std::duration_to_seconds(dur: 
std::duration) -> std::decimal { CREATE ANNOTATION std::description := 'Return duration as total number of seconds in interval.'; SET volatility := 'Immutable'; USING SQL $$ SELECT EXTRACT(epoch FROM date_trunc('minute', dur))::bigint::decimal + '0.000001'::decimal*EXTRACT(microsecond FROM dur)::decimal $$; }; ## Date/time operators ## ------------------- # std::datetime CREATE INFIX OPERATOR std::`=` (l: std::datetime, r: std::datetime) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; SET commutator := 'std::='; SET negator := 'std::!='; USING SQL OPERATOR r'=(timestamptz,timestamptz)'; }; CREATE INFIX OPERATOR std::`?=` (l: OPTIONAL std::datetime, r: OPTIONAL std::datetime) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`!=` (l: std::datetime, r: std::datetime) -> std::bool { CREATE ANNOTATION std::identifier := 'neq'; CREATE ANNOTATION std::description := 'Compare two values for inequality.'; SET volatility := 'Immutable'; SET commutator := 'std::!='; SET negator := 'std::='; USING SQL OPERATOR r'<>(timestamptz,timestamptz)'; }; CREATE INFIX OPERATOR std::`?!=` (l: OPTIONAL std::datetime, r: OPTIONAL std::datetime) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_neq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`>` (l: std::datetime, r: std::datetime) -> std::bool { CREATE ANNOTATION std::identifier := 'gt'; CREATE ANNOTATION std::description := 'Greater than.'; SET volatility := 'Immutable'; SET commutator := 'std::<'; SET negator := 'std::<='; USING SQL OPERATOR r'>(timestamptz,timestamptz)'; }; CREATE INFIX OPERATOR 
std::`>=` (l: std::datetime, r: std::datetime) -> std::bool { CREATE ANNOTATION std::identifier := 'gte'; CREATE ANNOTATION std::description := 'Greater than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::<='; SET negator := 'std::<'; USING SQL OPERATOR r'>=(timestamptz,timestamptz)'; }; CREATE INFIX OPERATOR std::`<` (l: std::datetime, r: std::datetime) -> std::bool { CREATE ANNOTATION std::identifier := 'lt'; CREATE ANNOTATION std::description := 'Less than.'; SET volatility := 'Immutable'; SET commutator := 'std::>'; SET negator := 'std::>='; USING SQL OPERATOR r'<(timestamptz,timestamptz)'; }; CREATE INFIX OPERATOR std::`<=` (l: std::datetime, r: std::datetime) -> std::bool { CREATE ANNOTATION std::identifier := 'lte'; CREATE ANNOTATION std::description := 'Less than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::>='; SET negator := 'std::>'; USING SQL OPERATOR r'<=(timestamptz,timestamptz)'; }; CREATE INFIX OPERATOR std::`+` (l: std::datetime, r: std::duration) -> std::datetime { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Time interval and date/time addition.'; # Immutable because datetime is guaranteed to be in UTC and no DST issues # should affect this. SET volatility := 'Immutable'; SET commutator := 'std::+'; USING SQL $$ SELECT ("l" + "r")::edgedbt.timestamptz_t $$ }; CREATE INFIX OPERATOR std::`+` (l: std::duration, r: std::datetime) -> std::datetime { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Time interval and date/time addition.'; # Immutable because datetime is guaranteed to be in UTC and no DST issues # should affect this. 
SET volatility := 'Immutable'; SET commutator := 'std::+'; USING SQL $$ SELECT ("l" + "r")::edgedbt.timestamptz_t $$ }; CREATE INFIX OPERATOR std::`-` (l: std::datetime, r: std::duration) -> std::datetime { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Time interval and date/time subtraction.'; # Immutable because datetime is guaranteed to be in UTC and no DST issues # should affect this. SET volatility := 'Immutable'; USING SQL $$ SELECT ("l" - "r")::edgedbt.timestamptz_t $$ }; CREATE INFIX OPERATOR std::`-` (l: std::datetime, r: std::datetime) -> std::duration { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Date/time subtraction.'; # Immutable because datetime is guaranteed to be in UTC and no DST issues # should affect this. SET volatility := 'Immutable'; USING SQL $$ SELECT EXTRACT(epoch FROM "l" - "r")::text::edgedbt.duration_t $$ }; # std::duration CREATE INFIX OPERATOR std::`=` (l: std::duration, r: std::duration) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; SET commutator := 'std::='; SET negator := 'std::!='; USING SQL OPERATOR r'=(interval,interval)'; }; CREATE INFIX OPERATOR std::`?=` (l: OPTIONAL std::duration, r: OPTIONAL std::duration) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`!=` (l: std::duration, r: std::duration) -> std::bool { CREATE ANNOTATION std::identifier := 'neq'; CREATE ANNOTATION std::description := 'Compare two values for inequality.'; SET volatility := 'Immutable'; SET commutator := 'std::!='; SET negator := 'std::='; USING SQL OPERATOR r'<>(interval,interval)'; }; CREATE INFIX OPERATOR std::`?!=` ( l: OPTIONAL std::duration, r: OPTIONAL 
std::duration ) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_neq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`>` (l: std::duration, r: std::duration) -> std::bool { CREATE ANNOTATION std::identifier := 'gt'; CREATE ANNOTATION std::description := 'Greater than.'; SET volatility := 'Immutable'; SET commutator := 'std::<'; SET negator := 'std::<='; USING SQL OPERATOR r'>(interval,interval)'; }; CREATE INFIX OPERATOR std::`>=` (l: std::duration, r: std::duration) -> std::bool { CREATE ANNOTATION std::identifier := 'gte'; CREATE ANNOTATION std::description := 'Greater than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::<='; SET negator := 'std::<'; USING SQL OPERATOR r'>=(interval,interval)'; }; CREATE INFIX OPERATOR std::`<` (l: std::duration, r: std::duration) -> std::bool { CREATE ANNOTATION std::identifier := 'lt'; CREATE ANNOTATION std::description := 'Less than.'; SET volatility := 'Immutable'; SET commutator := 'std::>'; SET negator := 'std::>='; USING SQL OPERATOR r'<(interval,interval)'; }; CREATE INFIX OPERATOR std::`<=` (l: std::duration, r: std::duration) -> std::bool { CREATE ANNOTATION std::identifier := 'lte'; CREATE ANNOTATION std::description := 'Less than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::>='; SET negator := 'std::>'; USING SQL OPERATOR r'<=(interval,interval)'; }; CREATE INFIX OPERATOR std::`+` (l: std::duration, r: std::duration) -> std::duration { CREATE ANNOTATION std::identifier := 'plus'; CREATE ANNOTATION std::description := 'Time interval addition.'; SET volatility := 'Immutable'; SET commutator := 'std::+'; USING SQL $$ SELECT ("l"::interval + "r"::interval)::edgedbt.duration_t; $$; }; CREATE INFIX OPERATOR std::`-` (l: std::duration, r: std::duration) -> std::duration { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION 
std::description := 'Time interval subtraction.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ("l"::interval - "r"::interval)::edgedbt.duration_t; $$; }; CREATE PREFIX OPERATOR std::`-` (v: std::duration) -> std::duration { CREATE ANNOTATION std::identifier := 'minus'; CREATE ANNOTATION std::description := 'Time interval negation.'; SET volatility := 'Immutable'; USING SQL $$ SELECT (-"v"::interval)::edgedbt.duration_t; $$; }; ## String casts CREATE CAST FROM std::str TO std::datetime { # Stable because the input string can contain an explicit time-zone. Time # zones are externally defined things that can change suddenly and # arbitrarily by human laws, thus potentially changing the interpretation # of the input string. SET volatility := 'Stable'; USING SQL FUNCTION 'edgedb.datetime_in'; }; CREATE CAST FROM std::str TO std::duration { SET volatility := 'Immutable'; USING SQL FUNCTION 'edgedb.duration_in'; }; # Normalize [local] datetime to text conversion to have # the same format as one would get by serializing to JSON. # Otherwise Postgres doesn't follow the ISO8601 standard # and uses ' ' instead of 'T' as a separator between date # and time. CREATE CAST FROM std::datetime TO std::str { SET volatility := 'Immutable'; USING SQL $$ SELECT trim(to_json(val)::text, '"'); $$; }; CREATE CAST FROM std::duration TO std::str { SET volatility := 'Immutable'; USING SQL $$ SELECT regexp_replace(val::text, '[[:<:]]mon(?=s?[[:>:]])', 'month'); $$; }; # std::sum CREATE FUNCTION std::sum(s: SET OF std::duration) -> std::duration { CREATE ANNOTATION std::description := 'Return the arithmetic sum of values in a set.'; SET volatility := 'Immutable'; SET initial_value := "PT0S"; SET force_return_cast := true; USING SQL FUNCTION 'sum'; }; ================================================ FILE: edb/lib/std/30-jsonfuncs.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. 
and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ## JSON functions and operators. CREATE SCALAR TYPE std::JsonEmpty EXTENDING enum; CREATE FUNCTION std::json_typeof(json: std::json) -> std::str { CREATE ANNOTATION std::description := 'Return the type of the outermost JSON value as a string.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'jsonb_typeof'; }; CREATE FUNCTION std::json_array_unpack(array: std::json) -> SET OF std::json { CREATE ANNOTATION std::description := 'Return elements of JSON array as a set of `json`.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'jsonb_array_elements'; }; CREATE FUNCTION std::json_object_unpack(obj: std::json) -> SET OF tuple { CREATE ANNOTATION std::description := 'Return set of key/value tuples that make up the JSON object.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'jsonb_each'; # jsonb_each is defined as (jsonb, OUT key text, OUT value jsonb), # and, quite perplexingly, would reject a column definition list # with `a column definition list is only allowed for functions # returning "record"`, even though it _is_ returning "record". # Hence, we need this flag to tell the compiler to avoid generating # a coldeflist for this function. 
SET sql_func_has_out_params := True; }; CREATE FUNCTION std::json_object_pack(pairs: SET OF tuple) -> std::json { CREATE ANNOTATION std::description := 'Return a JSON object with set key/value pairs.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE FUNCTION std::json_get( json: std::json, VARIADIC path: std::str, NAMED ONLY default: OPTIONAL std::json={}) -> OPTIONAL std::json { CREATE ANNOTATION std::description := 'Return the JSON value at the end of the specified path or an empty set.'; SET volatility := 'Immutable'; USING SQL $$ SELECT COALESCE( jsonb_extract_path("json", VARIADIC "path"), "default" ) $$; }; CREATE FUNCTION std::__json_get_not_null( json: std::json, VARIADIC path: std::str, NAMED ONLY detail: std::str='') -> OPTIONAL std::json { SET volatility := 'Immutable'; SET internal := true; USING SQL $$ SELECT CASE WHEN "json" = 'null'::jsonb THEN NULL ELSE edgedb_VER.raise_on_null( jsonb_extract_path("json", VARIADIC "path"), 'invalid_parameter_value', 'missing value in JSON object', detail => detail ) END $$; }; CREATE FUNCTION std::json_set( target: std::json, VARIADIC path: std::str, NAMED ONLY value: OPTIONAL std::json, NAMED ONLY create_if_missing: std::bool = true, NAMED ONLY empty_treatment: std::JsonEmpty = std::JsonEmpty.ReturnEmpty, ) -> OPTIONAL std::json { CREATE ANNOTATION std::description := 'Return an updated JSON target with a new value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN "value" IS NULL AND "empty_treatment" = 'ReturnEmpty' THEN NULL WHEN "value" IS NULL AND "empty_treatment" = 'ReturnTarget' THEN "target" WHEN "value" IS NULL AND "empty_treatment" = 'Error' THEN edgedb_VER.raise( NULL::jsonb, 'invalid_parameter_value', msg => 'invalid empty JSON value' ) WHEN "value" IS NULL AND "empty_treatment" = 'UseNull' THEN jsonb_set("target", "path", 'null'::jsonb, "create_if_missing") WHEN "value" IS NULL AND "empty_treatment" = 'DeleteKey' THEN "target" #- "path" ELSE jsonb_set("target", "path", 
"value", "create_if_missing") END ) $$; }; CREATE INFIX OPERATOR std::`=` (l: std::json, r: std::json) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; SET commutator := 'std::='; SET negator := 'std::!='; USING SQL OPERATOR r'='; }; CREATE INFIX OPERATOR std::`?=` (l: OPTIONAL std::json, r: OPTIONAL std::json) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`!=` (l: std::json, r: std::json) -> std::bool { CREATE ANNOTATION std::identifier := 'neq'; CREATE ANNOTATION std::description := 'Compare two values for inequality.'; SET volatility := 'Immutable'; SET commutator := 'std::!='; SET negator := 'std::='; USING SQL OPERATOR r'<>'; }; CREATE INFIX OPERATOR std::`?!=` (l: OPTIONAL std::json, r: OPTIONAL std::json) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_neq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`>=` (l: std::json, r: std::json) -> std::bool { CREATE ANNOTATION std::identifier := 'gte'; CREATE ANNOTATION std::description := 'Greater than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::<='; SET negator := 'std::<'; USING SQL OPERATOR '>='; }; CREATE INFIX OPERATOR std::`>` (l: std::json, r: std::json) -> std::bool { CREATE ANNOTATION std::identifier := 'gt'; CREATE ANNOTATION std::description := 'Greater than.'; SET volatility := 'Immutable'; SET commutator := 'std::<'; SET negator := 'std::<='; USING SQL OPERATOR '>'; }; CREATE INFIX OPERATOR std::`<=` (l: std::json, r: std::json) -> std::bool { CREATE ANNOTATION std::identifier := 'lte'; CREATE ANNOTATION std::description := 'Less than or 
equal.'; SET volatility := 'Immutable'; SET commutator := 'std::>='; SET negator := 'std::>'; USING SQL OPERATOR '<='; }; CREATE INFIX OPERATOR std::`<` (l: std::json, r: std::json) -> std::bool { CREATE ANNOTATION std::identifier := 'lt'; CREATE ANNOTATION std::description := 'Less than.'; SET volatility := 'Immutable'; SET commutator := 'std::>'; SET negator := 'std::>='; USING SQL OPERATOR '<'; }; CREATE INFIX OPERATOR std::`[]` (l: std::json, r: std::int64) -> std::json { CREATE ANNOTATION std::identifier := 'index'; CREATE ANNOTATION std::description := 'JSON array/string indexing.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`[]` (l: std::json, r: tuple) -> std::json { CREATE ANNOTATION std::identifier := 'slice'; CREATE ANNOTATION std::description := 'JSON array/string slicing.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`[]` (l: std::json, r: std::str) -> std::json { CREATE ANNOTATION std::identifier := 'destructure'; CREATE ANNOTATION std::description := 'JSON object property access.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`++` (l: std::json, r: std::json) -> std::json { CREATE ANNOTATION std::identifier := 'concatenate'; CREATE ANNOTATION std::description := 'Concatenate two JSON values into a new JSON value.'; SET volatility := 'Stable'; USING SQL $$ SELECT ( CASE WHEN jsonb_typeof("l") = 'array' AND jsonb_typeof("r") = 'array' THEN "l" || "r" WHEN jsonb_typeof("l") = 'object' AND jsonb_typeof("r") = 'object' THEN "l" || "r" WHEN jsonb_typeof("l") = 'string' AND jsonb_typeof("r") = 'string' THEN to_jsonb(("l"#>>'{}') || ("r"#>>'{}')) ELSE edgedb_VER.raise( NULL::jsonb, 'invalid_parameter_value', msg => ( 'invalid JSON values for ++ operator' ), detail => ( '{"hint":"Supported JSON types for concatenation: ' || 'array ++ array, object ++ object, string ++ string."}' ) ) END ) $$; }; ## CASTS # This is only a container cast, and 
subject to element type cast # availability. CREATE CAST FROM array TO std::json { SET volatility := 'Stable'; USING SQL FUNCTION 'to_jsonb'; }; # This is only a container cast, and subject to element type cast # availability. CREATE CAST FROM anytuple TO std::json { SET volatility := 'Stable'; USING SQL EXPRESSION; }; CREATE CAST FROM std::json TO anytuple { SET volatility := 'Stable'; USING SQL EXPRESSION; }; CREATE FUNCTION std::__tuple_validate_json( v: std::json, allow_null: std::bool, detail: std::str='' ) -> OPTIONAL std::json { SET volatility := 'Immutable'; SET internal := true; USING SQL $$ SELECT CASE WHEN v = 'null'::jsonb AND NOT allow_null THEN edgedb_VER.raise( NULL::jsonb, 'wrong_object_type', msg => 'invalid null value in cast', detail => detail ) ELSE edgedb_VER.jsonb_assert_type( v, ARRAY['array','object','null'], detail => detail ) END; $$; }; CREATE CAST FROM std::json TO array { SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN nullif(val, 'null'::jsonb) IS NULL THEN NULL ELSE (SELECT COALESCE(array_agg(j), ARRAY[]::jsonb[]) FROM jsonb_array_elements( edgedb_VER.jsonb_assert_type(val, ARRAY['array'], detail => detail) ) as j) END ) $$; }; CREATE CAST FROM std::json TO array { SET volatility := 'Stable'; USING SQL EXPRESSION; }; CREATE FUNCTION std::__range_validate_json(val: std::json, detail: std::str='') -> OPTIONAL std::json { SET volatility := 'Immutable'; SET internal := true; USING SQL $$ SELECT ( SELECT CASE WHEN v = 'null'::jsonb THEN NULL WHEN empty AND (lower IS DISTINCT FROM upper OR lower IS NOT NULL AND inc_upper AND inc_lower) THEN edgedb_VER.raise( NULL::jsonb, 'invalid_parameter_value', msg => 'conflicting arguments in range constructor:' || ' ''empty'' is `true` while the specified' || ' bounds suggest otherwise', detail => detail ) WHEN NOT empty AND inc_lower IS NULL THEN edgedb_VER.raise( NULL::jsonb, 'invalid_parameter_value', msg => 'JSON object representing a range must include an' || ' ''inc_lower'' boolean 
property', detail => detail ) WHEN NOT empty AND inc_upper IS NULL THEN edgedb_VER.raise( NULL::jsonb, 'invalid_parameter_value', msg => 'JSON object representing a range must include an' || ' ''inc_upper'' boolean property', detail => detail ) WHEN EXISTS ( SELECT jsonb_object_keys(v) EXCEPT VALUES ('lower'), ('upper'), ('inc_lower'), ('inc_upper'), ('empty') ) THEN (SELECT edgedb_VER.raise( NULL::jsonb, 'invalid_parameter_value', msg => 'JSON object representing a range contains unexpected' || ' keys: ' || string_agg(k.k, ', ' ORDER BY k.k), detail => detail ) FROM (SELECT jsonb_object_keys(v) EXCEPT VALUES ('lower'), ('upper'), ('inc_lower'), ('inc_upper'), ('empty') ) AS k(k) ) ELSE v END FROM (SELECT (v ->> 'lower') AS lower, (v ->> 'upper') AS upper, (v ->> 'inc_lower')::bool AS inc_lower, (v ->> 'inc_upper')::bool AS inc_upper, coalesce((v ->> 'empty')::bool, false) AS empty ) j ) FROM ( SELECT edgedb_VER.jsonb_assert_type( val, ARRAY['object', 'null'], detail => detail ) AS v ) AS x $$; }; CREATE CAST FROM range TO std::json { SET volatility := 'Immutable'; USING SQL FUNCTION 'edgedb.range_to_jsonb'; }; CREATE CAST FROM multirange TO std::json { SET volatility := 'Immutable'; USING SQL FUNCTION 'edgedb.multirange_to_jsonb'; }; CREATE CAST FROM std::json TO range { SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE CAST FROM std::json TO multirange { SET volatility := 'Immutable'; USING SQL EXPRESSION; }; # The function to_jsonb is STABLE in PostgreSQL, but this function is # generic and STABLE volatility may be an overestimation in many cases. 
CREATE CAST FROM std::bool TO std::json { SET volatility := 'Immutable'; USING SQL FUNCTION 'to_jsonb'; }; CREATE CAST FROM std::bytes TO std::json { SET volatility := 'Immutable'; USING SQL $$ SELECT to_jsonb(encode(val, 'base64')); $$; }; CREATE CAST FROM std::uuid TO std::json { SET volatility := 'Immutable'; USING SQL FUNCTION 'to_jsonb'; }; CREATE CAST FROM std::str TO std::json { SET volatility := 'Immutable'; USING SQL FUNCTION 'to_jsonb'; }; CREATE CAST FROM std::datetime TO std::json { SET volatility := 'Immutable'; USING SQL FUNCTION 'to_jsonb'; }; CREATE CAST FROM std::duration TO std::json { SET volatility := 'Immutable'; USING SQL FUNCTION 'to_jsonb'; }; CREATE CAST FROM std::int16 TO std::json { SET volatility := 'Immutable'; USING SQL FUNCTION 'to_jsonb'; }; CREATE CAST FROM std::int32 TO std::json { SET volatility := 'Immutable'; USING SQL FUNCTION 'to_jsonb'; }; CREATE CAST FROM std::int64 TO std::json { SET volatility := 'Immutable'; USING SQL FUNCTION 'to_jsonb'; }; CREATE CAST FROM std::float32 TO std::json { SET volatility := 'Immutable'; USING SQL FUNCTION 'to_jsonb'; }; CREATE CAST FROM std::float64 TO std::json { SET volatility := 'Immutable'; USING SQL FUNCTION 'to_jsonb'; }; CREATE CAST FROM std::decimal TO std::json { SET volatility := 'Immutable'; USING SQL FUNCTION 'to_jsonb'; }; CREATE CAST FROM std::json TO std::bool { SET volatility := 'Immutable'; USING SQL $$ SELECT edgedb_VER.jsonb_extract_scalar(val, 'boolean', detail => detail)::bool; $$; }; CREATE CAST FROM std::json TO std::uuid { SET volatility := 'Immutable'; USING SQL $$ SELECT edgedb_VER.jsonb_extract_scalar(val, 'string', detail => detail)::uuid; $$; }; CREATE CAST FROM std::json TO std::bytes { SET volatility := 'Immutable'; USING SQL $$ SELECT decode( edgedb_VER.jsonb_extract_scalar(val, 'string', detail => detail), 'base64' )::bytea; $$; }; CREATE CAST FROM std::json TO std::str { SET volatility := 'Immutable'; USING SQL $$ SELECT edgedb_VER.jsonb_extract_scalar(val, 
'string', detail => detail); $$; }; CREATE CAST FROM std::json TO std::datetime { # Stable because the input string can contain an explicit time-zone. Time # zones are externally defined things that can change suddenly and # arbitrarily by human laws, thus potentially changing the interpretation # of the input string. SET volatility := 'Stable'; USING SQL $$ SELECT edgedb_VER.datetime_in( edgedb_VER.jsonb_extract_scalar(val, 'string', detail => detail) ); $$; }; CREATE CAST FROM std::json TO std::duration { SET volatility := 'Immutable'; USING SQL $$ SELECT edgedb_VER.duration_in( edgedb_VER.jsonb_extract_scalar(val, 'string', detail => detail) ); $$; }; CREATE CAST FROM std::json TO std::int16 { SET volatility := 'Immutable'; USING SQL $$ SELECT edgedb_VER.jsonb_extract_scalar(val, 'number', detail => detail)::int2; $$; }; CREATE CAST FROM std::json TO std::int32 { SET volatility := 'Immutable'; USING SQL $$ SELECT edgedb_VER.jsonb_extract_scalar(val, 'number', detail => detail)::int4; $$; }; CREATE CAST FROM std::json TO std::int64 { SET volatility := 'Immutable'; USING SQL $$ SELECT edgedb_VER.jsonb_extract_scalar(val, 'number', detail => detail)::int8; $$; }; CREATE CAST FROM std::json TO std::float32 { SET volatility := 'Immutable'; USING SQL $$ SELECT edgedb_VER.jsonb_extract_scalar(val, 'number', detail => detail)::float4; $$; }; CREATE CAST FROM std::json TO std::float64 { SET volatility := 'Immutable'; USING SQL $$ SELECT edgedb_VER.jsonb_extract_scalar(val, 'number', detail => detail)::float8; $$; }; CREATE CAST FROM std::json TO std::decimal { SET volatility := 'Immutable'; USING SQL $$ SELECT edgedb_VER.str_to_decimal( edgedb_VER.jsonb_extract_scalar(val, 'number', detail => detail) ); $$; }; CREATE CAST FROM std::json TO std::bigint { SET volatility := 'Immutable'; USING SQL $$ SELECT edgedb_VER.str_to_bigint( edgedb_VER.jsonb_extract_scalar(val, 'number', detail => detail) ); $$; }; ================================================ FILE:
edb/lib/std/30-regexpfuncs.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ## Regular expression functions. CREATE FUNCTION std::re_match(pattern: std::str, str: std::str) -> array { CREATE ANNOTATION std::description := 'Find the first regular expression match in a string.'; SET volatility := 'Immutable'; USING SQL $$ SELECT array_replace(regexp_matches("str", "pattern"), NULL, ''); $$; }; CREATE FUNCTION std::re_match_all(pattern: std::str, str: std::str) -> SET OF array { CREATE ANNOTATION std::description := 'Find all regular expression matches in a string.'; SET volatility := 'Immutable'; USING SQL $$ SELECT array_replace(regexp_matches("str", "pattern", 'g'), NULL, ''); $$; }; CREATE FUNCTION std::re_test(pattern: std::str, str: std::str) -> std::bool { CREATE ANNOTATION std::description := 'Test if a regular expression has a match in a string.'; SET volatility := 'Immutable'; USING SQL $$ SELECT "str" ~ "pattern"; $$; }; CREATE FUNCTION std::re_replace( pattern: std::str, sub: std::str, str: std::str, NAMED ONLY flags: std::str = '') -> std::str { CREATE ANNOTATION std::description := 'Replace matching substrings in a given string.'; SET volatility := 'Immutable'; USING SQL $$ SELECT regexp_replace("str", "pattern", "sub", "flags"); $$; }; ================================================ 
FILE: edb/lib/std/30-sequencefuncs.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ## std::sequence functions and operators. # See schema.edgeql for definitions of sequence_next() and friends. ================================================ FILE: edb/lib/std/30-strfuncs.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ## String operators CREATE INFIX OPERATOR std::`=` (l: std::str, r: std::str) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; SET commutator := 'std::='; SET negator := 'std::!='; USING SQL OPERATOR r'='; }; CREATE INFIX OPERATOR std::`?=` (l: OPTIONAL std::str, r: OPTIONAL std::str) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`!=` (l: std::str, r: std::str) -> std::bool { CREATE ANNOTATION std::identifier := 'neq'; CREATE ANNOTATION std::description := 'Compare two values for inequality.'; SET volatility := 'Immutable'; SET commutator := 'std::!='; SET negator := 'std::='; USING SQL OPERATOR r'<>'; }; CREATE INFIX OPERATOR std::`?!=` (l: OPTIONAL std::str, r: OPTIONAL std::str) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_neq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; # Concatenation. 
CREATE INFIX OPERATOR std::`++` (l: std::str, r: std::str) -> std::str { CREATE ANNOTATION std::identifier := 'concat'; CREATE ANNOTATION std::description := 'String concatenation.'; SET volatility := 'Immutable'; USING SQL OPERATOR '||'; }; CREATE INFIX OPERATOR std::`LIKE` (string: std::str, pattern: std::str) -> std::bool { CREATE ANNOTATION std::identifier := 'like'; CREATE ANNOTATION std::description := 'Case-sensitive simple string matching.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`ILIKE` (string: std::str, pattern: std::str) -> std::bool { CREATE ANNOTATION std::identifier := 'ilike'; CREATE ANNOTATION std::description := 'Case-insensitive simple string matching.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`NOT LIKE` (string: std::str, pattern: std::str) -> std::bool { CREATE ANNOTATION std::identifier := 'not_like'; CREATE ANNOTATION std::description := 'Case-sensitive simple string matching.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`NOT ILIKE` (string: std::str, pattern: std::str) -> std::bool { CREATE ANNOTATION std::identifier := 'not_ilike'; CREATE ANNOTATION std::description := 'Case-insensitive simple string matching.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`<` (l: std::str, r: std::str) -> std::bool { CREATE ANNOTATION std::identifier := 'lt'; CREATE ANNOTATION std::description := 'Less than.'; SET volatility := 'Immutable'; SET commutator := 'std::>'; SET negator := 'std::>='; USING SQL OPERATOR r'<'; }; CREATE INFIX OPERATOR std::`<=` (l: std::str, r: std::str) -> std::bool { CREATE ANNOTATION std::identifier := 'lte'; CREATE ANNOTATION std::description := 'Less than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::>='; SET negator := 'std::>'; USING SQL OPERATOR r'<='; }; CREATE INFIX OPERATOR std::`>` (l: std::str, r: std::str) -> std::bool { CREATE 
ANNOTATION std::identifier := 'gt'; CREATE ANNOTATION std::description := 'Greater than.'; SET volatility := 'Immutable'; SET commutator := 'std::<'; SET negator := 'std::<='; USING SQL OPERATOR r'>'; }; CREATE INFIX OPERATOR std::`>=` (l: std::str, r: std::str) -> std::bool { CREATE ANNOTATION std::identifier := 'gte'; CREATE ANNOTATION std::description := 'Greater than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::<='; SET negator := 'std::<'; USING SQL OPERATOR r'>='; }; CREATE INFIX OPERATOR std::`[]` (l: std::str, r: std::int64) -> std::str { CREATE ANNOTATION std::identifier := 'index'; CREATE ANNOTATION std::description := 'String indexing.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`[]` (l: std::str, r: tuple) -> std::str { CREATE ANNOTATION std::identifier := 'slice'; CREATE ANNOTATION std::description := 'String slicing.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; ## String functions CREATE FUNCTION std::str_repeat(s: std::str, n: std::int64) -> std::str { CREATE ANNOTATION std::description := 'Repeat the input *string* *n* times.'; SET volatility := 'Immutable'; USING SQL $$ SELECT repeat("s", "n"::int4) $$; }; CREATE FUNCTION std::str_lower(s: std::str) -> std::str { CREATE ANNOTATION std::description := 'Return a lowercase copy of the input *string*.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'lower'; }; CREATE FUNCTION std::str_upper(s: std::str) -> std::str { CREATE ANNOTATION std::description := 'Return an uppercase copy of the input *string*.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'upper'; }; CREATE FUNCTION std::str_title(s: std::str) -> std::str { CREATE ANNOTATION std::description := 'Return a titlecase copy of the input *string*.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'initcap'; }; CREATE FUNCTION std::str_pad_start(s: std::str, n: std::int64, fill: std::str=' ') -> std::str { CREATE ANNOTATION std::description := 'Return the input 
string padded at the start to the length *n*.'; SET volatility := 'Immutable'; USING SQL $$ SELECT lpad("s", "n"::int4, "fill") $$; }; CREATE FUNCTION std::str_lpad(s: std::str, n: std::int64, fill: std::str=' ') -> std::str { CREATE ANNOTATION std::description := 'Return the input string left-padded to the length *n*.'; CREATE ANNOTATION std::deprecated := 'This function is deprecated and is scheduled \ to be removed before 1.0.\n\ Use std::str_pad_start() instead.'; SET volatility := 'Immutable'; USING (std::str_pad_start(s, n, fill)); }; CREATE FUNCTION std::str_pad_end(s: std::str, n: std::int64, fill: std::str=' ') -> std::str { CREATE ANNOTATION std::description := 'Return the input string padded at the end to the length *n*.'; SET volatility := 'Immutable'; USING SQL $$ SELECT rpad("s", "n"::int4, "fill") $$; }; CREATE FUNCTION std::str_rpad(s: std::str, n: std::int64, fill: std::str=' ') -> std::str { CREATE ANNOTATION std::description := 'Return the input string right-padded to the length *n*.'; CREATE ANNOTATION std::deprecated := 'This function is deprecated and is scheduled \ to be removed before 1.0.\n\ Use std::str_pad_end() instead.'; SET volatility := 'Immutable'; USING (std::str_pad_end(s, n, fill)); }; CREATE FUNCTION std::str_trim_start(s: std::str, tr: std::str=' ') -> std::str { CREATE ANNOTATION std::description := 'Return the input string with all *trim* characters removed from \ its start.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'ltrim'; }; CREATE FUNCTION std::str_ltrim(s: std::str, tr: std::str=' ') -> std::str { CREATE ANNOTATION std::description := 'Return the input string with all leftmost *trim* characters removed.'; CREATE ANNOTATION std::deprecated := 'This function is deprecated and is scheduled \ to be removed before 1.0.\n\ Use std::str_trim_start() instead.'; SET volatility := 'Immutable'; USING (std::str_trim_start(s, tr)); }; CREATE FUNCTION std::str_trim_end(s: std::str, tr: std::str=' ') -> std::str { CREATE 
ANNOTATION std::description := 'Return the input string with all *trim* characters removed from \ its end.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'rtrim'; }; CREATE FUNCTION std::str_rtrim(s: std::str, tr: std::str=' ') -> std::str { CREATE ANNOTATION std::description := 'Return the input string with all rightmost *trim* characters removed.'; CREATE ANNOTATION std::deprecated := 'This function is deprecated and is scheduled \ to be removed before 1.0.\n\ Use std::str_trim_end() instead.'; SET volatility := 'Immutable'; USING (std::str_trim_end(s, tr)); }; CREATE FUNCTION std::str_trim(s: std::str, tr: std::str=' ') -> std::str { CREATE ANNOTATION std::description := 'Return the input string with *trim* characters removed from \ both ends.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'btrim'; }; CREATE FUNCTION std::str_split(s: std::str, delimiter: std::str) -> array { CREATE ANNOTATION std::description := 'Split string into array elements using the supplied delimiter.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN "delimiter" != '' THEN string_to_array("s", "delimiter") ELSE regexp_split_to_array("s", '') END ); $$; }; CREATE FUNCTION std::str_replace(s: std::str, old: std::str, new: std::str) -> std::str { CREATE ANNOTATION std::description := 'Given a string, find a matching substring and replace all its \ occurrences with a new substring.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'replace'; }; CREATE FUNCTION std::str_reverse(s: std::str) -> std::str { CREATE ANNOTATION std::description := 'Reverse the order of the characters in the string.'; SET volatility := 'Immutable'; USING SQL FUNCTION 'reverse'; }; ================================================ FILE: edb/lib/std/30-uuidfuncs.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ## UUID functions and operators. CREATE FUNCTION std::uuid_generate_v1mc() -> std::uuid { CREATE ANNOTATION std::description := 'Return a version 1 UUID.'; SET volatility := 'Volatile'; USING SQL FUNCTION 'edgedb.uuid_generate_v1mc'; }; CREATE FUNCTION std::uuid_generate_v4() -> std::uuid { CREATE ANNOTATION std::description := 'Return a version 4 UUID.'; SET volatility := 'Volatile'; USING SQL FUNCTION 'edgedb.uuid_generate_v4'; }; CREATE INFIX OPERATOR std::`=` (l: std::uuid, r: std::uuid) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; SET commutator := 'std::='; SET negator := 'std::!='; USING SQL OPERATOR r'='; }; CREATE INFIX OPERATOR std::`?=` (l: OPTIONAL std::uuid, r: OPTIONAL std::uuid) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`!=` (l: std::uuid, r: std::uuid) -> std::bool { CREATE ANNOTATION std::identifier := 'neq'; CREATE ANNOTATION std::description := 'Compare two values for inequality.'; SET volatility := 'Immutable'; SET commutator := 'std::!='; SET negator := 'std::='; USING SQL OPERATOR r'<>'; }; CREATE INFIX OPERATOR std::`?!=` (l: OPTIONAL std::uuid, r: OPTIONAL std::uuid) -> std::bool { CREATE 
ANNOTATION std::identifier := 'coal_neq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`>=` (l: std::uuid, r: std::uuid) -> std::bool { CREATE ANNOTATION std::identifier := 'gte'; CREATE ANNOTATION std::description := 'Greater than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::<='; SET negator := 'std::<'; USING SQL OPERATOR '>='; }; CREATE INFIX OPERATOR std::`>` (l: std::uuid, r: std::uuid) -> std::bool { CREATE ANNOTATION std::identifier := 'gt'; CREATE ANNOTATION std::description := 'Greater than.'; SET volatility := 'Immutable'; SET commutator := 'std::<'; SET negator := 'std::<='; USING SQL OPERATOR '>'; }; CREATE INFIX OPERATOR std::`<=` (l: std::uuid, r: std::uuid) -> std::bool { CREATE ANNOTATION std::identifier := 'lte'; CREATE ANNOTATION std::description := 'Less than or equal.'; SET volatility := 'Immutable'; SET commutator := 'std::>='; SET negator := 'std::>'; USING SQL OPERATOR '<='; }; CREATE INFIX OPERATOR std::`<` (l: std::uuid, r: std::uuid) -> std::bool { CREATE ANNOTATION std::identifier := 'lt'; CREATE ANNOTATION std::description := 'Less than.'; SET volatility := 'Immutable'; SET commutator := 'std::>'; SET negator := 'std::>='; USING SQL OPERATOR '<'; }; ## String casts. CREATE CAST FROM std::str TO std::uuid { SET volatility := 'Immutable'; USING SQL CAST; }; CREATE CAST FROM std::uuid TO std::str { SET volatility := 'Immutable'; USING SQL CAST; }; ================================================ FILE: edb/lib/std/31-rangefuncs.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2022-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ## Range/multirange functions CREATE FUNCTION std::range( lower: optional std::anypoint = {}, upper: optional std::anypoint = {}, named only inc_lower: bool = true, named only inc_upper: bool = false, named only empty: bool = false, ) -> range { SET volatility := 'Immutable'; USING SQL EXPRESSION; }; # TODO: maybe also add a constructor taking a set? CREATE FUNCTION std::multirange( ranges: array>, ) -> multirange { SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE FUNCTION std::range_is_empty( val: range ) -> bool { SET volatility := 'Immutable'; USING SQL FUNCTION 'isempty'; }; CREATE FUNCTION std::range_is_empty( val: multirange ) -> bool { SET volatility := 'Immutable'; USING SQL FUNCTION 'isempty'; }; CREATE FUNCTION std::range_unpack( val: range ) -> set of int32 { SET volatility := 'Immutable'; USING SQL $$ SELECT generate_series( ( edgedb_VER.range_lower_validate(val) + (CASE WHEN lower_inc(val) THEN 0 ELSE 1 END) )::int8, ( edgedb_VER.range_upper_validate(val) - (CASE WHEN upper_inc(val) THEN 0 ELSE 1 END) )::int8 )::int4 $$; }; CREATE FUNCTION std::range_unpack( val: range, step: int32 ) -> set of int32 { SET volatility := 'Immutable'; USING SQL $$ SELECT generate_series( ( edgedb_VER.range_lower_validate(val) + (CASE WHEN lower_inc(val) THEN 0 ELSE 1 END) )::int8, ( edgedb_VER.range_upper_validate(val) - (CASE WHEN upper_inc(val) THEN 0 ELSE 1 END) )::int8, step::int8 )::int4 $$; }; CREATE FUNCTION std::range_unpack( val: range ) -> set of int64 { SET volatility := 'Immutable'; USING SQL $$ SELECT generate_series( ( 
edgedb_VER.range_lower_validate(val) + (CASE WHEN lower_inc(val) THEN 0 ELSE 1 END) )::int8, ( edgedb_VER.range_upper_validate(val) - (CASE WHEN upper_inc(val) THEN 0 ELSE 1 END) )::int8 ) $$; }; CREATE FUNCTION std::range_unpack( val: range, step: int64 ) -> set of int64 { SET volatility := 'Immutable'; USING SQL $$ SELECT generate_series( ( edgedb_VER.range_lower_validate(val) + (CASE WHEN lower_inc(val) THEN 0 ELSE 1 END) )::int8, ( edgedb_VER.range_upper_validate(val) - (CASE WHEN upper_inc(val) THEN 0 ELSE 1 END) )::int8, step ) $$; }; CREATE FUNCTION std::range_unpack( val: range, step: float32 ) -> set of float32 { SET volatility := 'Immutable'; USING SQL $$ SELECT num::float4 FROM generate_series( ( edgedb_VER.range_lower_validate(val) + (CASE WHEN lower_inc(val) THEN 0 ELSE step END) )::numeric, ( edgedb_VER.range_upper_validate(val) )::numeric, step::numeric ) AS num WHERE upper_inc(val) OR num::float4 < upper(val) $$; }; CREATE FUNCTION std::range_unpack( val: range, step: float64 ) -> set of float64 { SET volatility := 'Immutable'; USING SQL $$ SELECT num::float8 FROM generate_series( ( edgedb_VER.range_lower_validate(val) + (CASE WHEN lower_inc(val) THEN 0 ELSE step END) )::numeric, ( edgedb_VER.range_upper_validate(val) )::numeric, step::numeric ) AS num WHERE upper_inc(val) OR num::float8 < upper(val) $$; }; CREATE FUNCTION std::range_unpack( val: range, step: decimal ) -> set of decimal { SET volatility := 'Immutable'; USING SQL $$ SELECT num FROM generate_series( edgedb_VER.range_lower_validate(val) + (CASE WHEN lower_inc(val) THEN 0 ELSE step END), edgedb_VER.range_upper_validate(val), step ) AS num WHERE upper_inc(val) OR num < upper(val) $$; }; CREATE FUNCTION std::range_unpack( val: range, step: duration ) -> set of datetime { SET volatility := 'Immutable'; USING SQL $$ SELECT d::edgedbt.timestamptz_t FROM generate_series( ( edgedb_VER.range_lower_validate(val) + ( CASE WHEN lower_inc(val) THEN '0'::interval ELSE step END ) )::timestamptz, ( 
                    edgedb_VER.range_upper_validate(val)
                )::timestamptz,
                step::interval
            ) AS d
        WHERE
            upper_inc(val) OR d::edgedbt.timestamptz_t < upper(val)
    $$;
};

CREATE FUNCTION
std::range_get_upper(r: range<anypoint>) -> optional anypoint
{
    SET volatility := 'Immutable';
    USING SQL FUNCTION 'upper';
    SET force_return_cast := true;
};

CREATE FUNCTION
std::range_get_lower(r: range<anypoint>) -> optional anypoint
{
    SET volatility := 'Immutable';
    USING SQL FUNCTION 'lower';
    SET force_return_cast := true;
};

CREATE FUNCTION
std::range_is_inclusive_upper(r: range<anypoint>) -> std::bool
{
    SET volatility := 'Immutable';
    USING SQL FUNCTION 'upper_inc';
};

CREATE FUNCTION
std::range_is_inclusive_lower(r: range<anypoint>) -> std::bool
{
    SET volatility := 'Immutable';
    USING SQL FUNCTION 'lower_inc';
};

CREATE FUNCTION std::range_get_upper(
    r: multirange<anypoint>
) -> optional anypoint
{
    SET volatility := 'Immutable';
    USING SQL FUNCTION 'upper';
    SET force_return_cast := true;
};

CREATE FUNCTION std::range_get_lower(
    r: multirange<anypoint>
) -> optional anypoint
{
    SET volatility := 'Immutable';
    USING SQL FUNCTION 'lower';
    SET force_return_cast := true;
};

CREATE FUNCTION std::range_is_inclusive_upper(
    r: multirange<anypoint>
) -> std::bool
{
    SET volatility := 'Immutable';
    USING SQL FUNCTION 'upper_inc';
};

CREATE FUNCTION std::range_is_inclusive_lower(
    r: multirange<anypoint>
) -> std::bool
{
    SET volatility := 'Immutable';
    USING SQL FUNCTION 'lower_inc';
};

CREATE FUNCTION std::contains(
    haystack: range<anypoint>,
    needle: range<anypoint>
) -> std::bool
{
    SET volatility := 'Immutable';
    USING SQL $$
        SELECT "haystack" @> "needle"
    $$;
    # Needed to pick up the indexes when used in FILTER.
    set prefer_subquery_args := true;
    # Postgres only manages to inline this function if it isn't marked strict,
    # and we want it to be inlined so that std::pg::gin indexes work with it.
    set impl_is_strict := false;
};

CREATE FUNCTION std::contains(
    haystack: range<anypoint>,
    needle: anypoint
) -> std::bool
{
    SET volatility := 'Immutable';
    USING SQL $$
        SELECT "haystack" @> "needle"
    $$;
    # Needed to pick up the indexes when used in FILTER.
    set prefer_subquery_args := true;
    set impl_is_strict := false;
};

CREATE FUNCTION std::contains(
    haystack: multirange<anypoint>,
    needle: multirange<anypoint>
) -> std::bool
{
    SET volatility := 'Immutable';
    USING SQL $$
        SELECT "haystack" @> "needle"
    $$;
    # Needed to pick up the indexes when used in FILTER.
    set prefer_subquery_args := true;
    set impl_is_strict := false;
};

CREATE FUNCTION std::contains(
    haystack: multirange<anypoint>,
    needle: range<anypoint>
) -> std::bool
{
    SET volatility := 'Immutable';
    USING SQL $$
        SELECT "haystack" @> "needle"
    $$;
    # Needed to pick up the indexes when used in FILTER.
    set prefer_subquery_args := true;
    set impl_is_strict := false;
};

CREATE FUNCTION std::contains(
    haystack: multirange<anypoint>,
    needle: anypoint
) -> std::bool
{
    SET volatility := 'Immutable';
    USING SQL $$
        SELECT "haystack" @> "needle"
    $$;
    # Needed to pick up the indexes when used in FILTER.
    set prefer_subquery_args := true;
    set impl_is_strict := false;
};

CREATE FUNCTION std::overlaps(
    l: range<anypoint>,
    r: range<anypoint>
) -> std::bool
{
    SET volatility := 'Immutable';
    USING SQL $$
        SELECT "l" && "r"
    $$;
    # Needed to pick up the indexes when used in FILTER.
    set prefer_subquery_args := true;
    set impl_is_strict := false;
};

CREATE FUNCTION std::overlaps(
    l: multirange<anypoint>,
    r: multirange<anypoint>
) -> std::bool
{
    SET volatility := 'Immutable';
    USING SQL $$
        SELECT "l" && "r"
    $$;
    # Needed to pick up the indexes when used in FILTER.
    set prefer_subquery_args := true;
    set impl_is_strict := false;
};

# FIXME: These functions introduce the concrete multirange types into the
# schema. That's why they exist for each concrete type explicitly and aren't
# defined generically for anytype.
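#
# Usage sketch for the multirange_unpack() overloads that follow (an
# illustrative assumption, not taken from this file). A multirange keeps
# its subranges normalized, so overlapping inputs are merged before the
# value is unpacked:
#
#     SELECT multirange_unpack(
#         multirange([range(1, 4), range(7), range(3, 5)])
#     );
#
# This is expected to yield two ranges, [1, 5) and [7, unbounded): the
# first and third inputs overlap and are stored as a single subrange.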
CREATE FUNCTION std::multirange_unpack(
    val: multirange<int32>,
) -> set of range<int32>
{
    SET volatility := 'Immutable';
    USING SQL FUNCTION 'unnest';
};

CREATE FUNCTION std::multirange_unpack(
    val: multirange<int64>,
) -> set of range<int64>
{
    SET volatility := 'Immutable';
    USING SQL FUNCTION 'unnest';
};

CREATE FUNCTION std::multirange_unpack(
    val: multirange<float32>,
) -> set of range<float32>
{
    SET volatility := 'Immutable';
    USING SQL FUNCTION 'unnest';
};

CREATE FUNCTION std::multirange_unpack(
    val: multirange<float64>,
) -> set of range<float64>
{
    SET volatility := 'Immutable';
    USING SQL FUNCTION 'unnest';
};

CREATE FUNCTION std::multirange_unpack(
    val: multirange<decimal>,
) -> set of range<decimal>
{
    SET volatility := 'Immutable';
    USING SQL FUNCTION 'unnest';
};

CREATE FUNCTION std::multirange_unpack(
    val: multirange<datetime>,
) -> set of range<datetime>
{
    SET volatility := 'Immutable';
    USING SQL FUNCTION 'unnest';
};

CREATE FUNCTION std::strictly_below(
    l: range<anypoint>,
    r: range<anypoint>
) -> std::bool
{
    SET volatility := 'Immutable';
    USING SQL $$
        SELECT "l" << "r"
    $$;
    # Needed to pick up the indexes when used in FILTER.
    set prefer_subquery_args := true;
    set impl_is_strict := false;
};

CREATE FUNCTION std::strictly_below(
    l: multirange<anypoint>,
    r: multirange<anypoint>
) -> std::bool
{
    SET volatility := 'Immutable';
    USING SQL $$
        SELECT "l" << "r"
    $$;
    # Needed to pick up the indexes when used in FILTER.
    set prefer_subquery_args := true;
    set impl_is_strict := false;
};

CREATE FUNCTION std::strictly_above(
    l: range<anypoint>,
    r: range<anypoint>
) -> std::bool
{
    SET volatility := 'Immutable';
    USING SQL $$
        SELECT "l" >> "r"
    $$;
    # Needed to pick up the indexes when used in FILTER.
    set prefer_subquery_args := true;
    set impl_is_strict := false;
};

CREATE FUNCTION std::strictly_above(
    l: multirange<anypoint>,
    r: multirange<anypoint>
) -> std::bool
{
    SET volatility := 'Immutable';
    USING SQL $$
        SELECT "l" >> "r"
    $$;
    # Needed to pick up the indexes when used in FILTER.
    set prefer_subquery_args := true;
    set impl_is_strict := false;
};

CREATE FUNCTION std::bounded_above(
    l: range<anypoint>,
    r: range<anypoint>
) -> std::bool
{
    SET volatility := 'Immutable';
    USING SQL $$
        SELECT "l" &< "r"
    $$;
    # Needed to pick up the indexes when used in FILTER.
    set prefer_subquery_args := true;
    set impl_is_strict := false;
};

CREATE FUNCTION std::bounded_above(
    l: multirange<anypoint>,
    r: multirange<anypoint>
) -> std::bool
{
    SET volatility := 'Immutable';
    USING SQL $$
        SELECT "l" &< "r"
    $$;
    # Needed to pick up the indexes when used in FILTER.
    set prefer_subquery_args := true;
    set impl_is_strict := false;
};

CREATE FUNCTION std::bounded_below(
    l: range<anypoint>,
    r: range<anypoint>
) -> std::bool
{
    SET volatility := 'Immutable';
    USING SQL $$
        SELECT "l" &> "r"
    $$;
    # Needed to pick up the indexes when used in FILTER.
    set prefer_subquery_args := true;
    set impl_is_strict := false;
};

CREATE FUNCTION std::bounded_below(
    l: multirange<anypoint>,
    r: multirange<anypoint>
) -> std::bool
{
    SET volatility := 'Immutable';
    USING SQL $$
        SELECT "l" &> "r"
    $$;
    # Needed to pick up the indexes when used in FILTER.
    set prefer_subquery_args := true;
    set impl_is_strict := false;
};

CREATE FUNCTION std::adjacent(
    l: range<anypoint>,
    r: range<anypoint>
) -> std::bool
{
    SET volatility := 'Immutable';
    USING SQL $$
        SELECT "l" -|- "r"
    $$;
    # Needed to pick up the indexes when used in FILTER.
    set prefer_subquery_args := true;
    set impl_is_strict := false;
};

CREATE FUNCTION std::adjacent(
    l: multirange<anypoint>,
    r: multirange<anypoint>
) -> std::bool
{
    SET volatility := 'Immutable';
    USING SQL $$
        SELECT "l" -|- "r"
    $$;
    # Needed to pick up the indexes when used in FILTER.
    set prefer_subquery_args := true;
    set impl_is_strict := false;
};

## Range operators

CREATE INFIX OPERATOR
std::`=` (l: range<anypoint>, r: range<anypoint>) -> std::bool {
    CREATE ANNOTATION std::identifier := 'eq';
    CREATE ANNOTATION std::description := 'Compare two values for equality.';
    SET volatility := 'Immutable';
    SET recursive := true;
    SET commutator := 'std::=';
    SET negator := 'std::!=';
    USING SQL OPERATOR '=';
};

CREATE INFIX OPERATOR
std::`?=` (l: OPTIONAL range<anypoint>,
           r: OPTIONAL range<anypoint>) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_eq';
    CREATE ANNOTATION std::description :=
        'Compare two (potentially empty) values for equality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
    SET recursive := true;
};

CREATE INFIX OPERATOR
std::`!=` (l: range<anypoint>, r: range<anypoint>) -> std::bool {
    CREATE ANNOTATION std::identifier := 'neq';
    CREATE ANNOTATION std::description := 'Compare two values for inequality.';
    SET volatility := 'Immutable';
    SET recursive := true;
    SET commutator := 'std::!=';
    SET negator := 'std::=';
    USING SQL OPERATOR '<>';
};

CREATE INFIX OPERATOR
std::`?!=` (l: OPTIONAL range<anypoint>,
            r: OPTIONAL range<anypoint>) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_neq';
    CREATE ANNOTATION std::description :=
        'Compare two (potentially empty) values for inequality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
    SET recursive := true;
};

CREATE INFIX OPERATOR
std::`>=` (l: range<anypoint>, r: range<anypoint>) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gte';
    CREATE ANNOTATION std::description := 'Greater than or equal.';
    SET volatility := 'Immutable';
    SET recursive := true;
    SET commutator := 'std::<=';
    SET negator := 'std::<';
    USING SQL OPERATOR '>=';
};

CREATE INFIX OPERATOR
std::`>` (l: range<anypoint>, r: range<anypoint>) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gt';
    CREATE ANNOTATION std::description := 'Greater than.';
    SET volatility := 'Immutable';
    SET recursive := true;
    SET commutator := 'std::<';
    SET negator := 'std::<=';
    USING SQL OPERATOR '>';
};

CREATE INFIX OPERATOR
std::`<=` (l:
           range<anypoint>, r: range<anypoint>) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lte';
    CREATE ANNOTATION std::description := 'Less than or equal.';
    SET volatility := 'Immutable';
    SET recursive := true;
    SET commutator := 'std::>=';
    SET negator := 'std::>';
    USING SQL OPERATOR '<=';
};

CREATE INFIX OPERATOR
std::`<` (l: range<anypoint>, r: range<anypoint>) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    SET recursive := true;
    SET commutator := 'std::>';
    SET negator := 'std::>=';
    USING SQL OPERATOR '<';
};

CREATE INFIX OPERATOR
std::`+` (l: range<anypoint>, r: range<anypoint>) -> range<anypoint> {
    CREATE ANNOTATION std::identifier := 'plus';
    CREATE ANNOTATION std::description := 'Range union.';
    SET volatility := 'Immutable';
    SET recursive := true;
    SET commutator := 'std::+';
    USING SQL OPERATOR r'+';
};

CREATE INFIX OPERATOR
std::`-` (l: range<anypoint>, r: range<anypoint>) -> range<anypoint> {
    CREATE ANNOTATION std::identifier := 'minus';
    CREATE ANNOTATION std::description := 'Range difference.';
    SET volatility := 'Immutable';
    SET recursive := true;
    USING SQL OPERATOR r'-';
};

CREATE INFIX OPERATOR
std::`*` (l: range<anypoint>, r: range<anypoint>) -> range<anypoint> {
    CREATE ANNOTATION std::identifier := 'mult';
    CREATE ANNOTATION std::description := 'Range intersection.';
    SET volatility := 'Immutable';
    SET recursive := true;
    SET commutator := 'std::*';
    USING SQL OPERATOR r'*';
};

## MultiRange operators

CREATE INFIX OPERATOR
std::`=` (l: multirange<anypoint>, r: multirange<anypoint>) -> std::bool {
    CREATE ANNOTATION std::identifier := 'eq';
    CREATE ANNOTATION std::description := 'Compare two values for equality.';
    SET volatility := 'Immutable';
    SET recursive := true;
    SET commutator := 'std::=';
    SET negator := 'std::!=';
    USING SQL OPERATOR '=';
};

CREATE INFIX OPERATOR
std::`?=` (l: OPTIONAL multirange<anypoint>,
           r: OPTIONAL multirange<anypoint>) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_eq';
    CREATE ANNOTATION std::description :=
        'Compare two (potentially empty) values for equality.';
    SET volatility := 'Immutable';
    USING
        SQL EXPRESSION;
    SET recursive := true;
};

CREATE INFIX OPERATOR
std::`!=` (l: multirange<anypoint>, r: multirange<anypoint>) -> std::bool {
    CREATE ANNOTATION std::identifier := 'neq';
    CREATE ANNOTATION std::description := 'Compare two values for inequality.';
    SET volatility := 'Immutable';
    SET recursive := true;
    SET commutator := 'std::!=';
    SET negator := 'std::=';
    USING SQL OPERATOR '<>';
};

CREATE INFIX OPERATOR
std::`?!=` (l: OPTIONAL multirange<anypoint>,
            r: OPTIONAL multirange<anypoint>) -> std::bool {
    CREATE ANNOTATION std::identifier := 'coal_neq';
    CREATE ANNOTATION std::description :=
        'Compare two (potentially empty) values for inequality.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
    SET recursive := true;
};

CREATE INFIX OPERATOR
std::`>=` (l: multirange<anypoint>, r: multirange<anypoint>) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gte';
    CREATE ANNOTATION std::description := 'Greater than or equal.';
    SET volatility := 'Immutable';
    SET recursive := true;
    SET commutator := 'std::<=';
    SET negator := 'std::<';
    USING SQL OPERATOR '>=';
};

CREATE INFIX OPERATOR
std::`>` (l: multirange<anypoint>, r: multirange<anypoint>) -> std::bool {
    CREATE ANNOTATION std::identifier := 'gt';
    CREATE ANNOTATION std::description := 'Greater than.';
    SET volatility := 'Immutable';
    SET recursive := true;
    SET commutator := 'std::<';
    SET negator := 'std::<=';
    USING SQL OPERATOR '>';
};

CREATE INFIX OPERATOR
std::`<=` (l: multirange<anypoint>, r: multirange<anypoint>) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lte';
    CREATE ANNOTATION std::description := 'Less than or equal.';
    SET volatility := 'Immutable';
    SET recursive := true;
    SET commutator := 'std::>=';
    SET negator := 'std::>';
    USING SQL OPERATOR '<=';
};

CREATE INFIX OPERATOR
std::`<` (l: multirange<anypoint>, r: multirange<anypoint>) -> std::bool {
    CREATE ANNOTATION std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    SET recursive := true;
    SET commutator := 'std::>';
    SET negator := 'std::>=';
    USING SQL OPERATOR '<';
};

CREATE INFIX OPERATOR
std::`+` (l:
          multirange<anypoint>, r: multirange<anypoint>) -> multirange<anypoint> {
    CREATE ANNOTATION std::identifier := 'plus';
    CREATE ANNOTATION std::description := 'Range union.';
    SET volatility := 'Immutable';
    SET recursive := true;
    SET commutator := 'std::+';
    USING SQL OPERATOR r'+';
};

CREATE INFIX OPERATOR
std::`-` (l: multirange<anypoint>, r: multirange<anypoint>) -> multirange<anypoint> {
    CREATE ANNOTATION std::identifier := 'minus';
    CREATE ANNOTATION std::description := 'Range difference.';
    SET volatility := 'Immutable';
    SET recursive := true;
    USING SQL OPERATOR r'-';
};

CREATE INFIX OPERATOR
std::`*` (l: multirange<anypoint>, r: multirange<anypoint>) -> multirange<anypoint> {
    CREATE ANNOTATION std::identifier := 'mult';
    CREATE ANNOTATION std::description := 'Range intersection.';
    SET volatility := 'Immutable';
    SET recursive := true;
    SET commutator := 'std::*';
    USING SQL OPERATOR r'*';
};

## Range/multirange casts

CREATE CAST FROM range<anypoint> TO multirange<anypoint> {
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
    # Any range can be implicitly cast into a multirange.
    ALLOW IMPLICIT;
};

## For annoying performance reasons, we want to be able to internally
## directly call generate_series.
## Hopefully I'll fix this better later.
CREATE FUNCTION
std::__pg_generate_series(
    `start`: std::int64,
    stop: std::int64
) -> SET OF std::int64
{
    SET volatility := 'Immutable';
    USING SQL FUNCTION 'generate_series';
};

================================================
FILE: edb/lib/std/50-constraints.edgeql
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ## Standard constraints. CREATE FUNCTION std::_is_exclusive(s: SET OF anytype) -> std::bool { SET volatility := 'Immutable'; SET initial_value := True; SET internal := true; USING SQL EXPRESSION; }; CREATE ABSTRACT CONSTRAINT std::constraint { SET errmessage := 'invalid {__subject__}'; }; CREATE ABSTRACT CONSTRAINT std::expression EXTENDING std::constraint { CREATE ANNOTATION std::description := 'Arbitrary constraint expression.'; USING (__subject__); }; CREATE ABSTRACT CONSTRAINT std::exclusive EXTENDING std::constraint { CREATE ANNOTATION std::description := 'Specifies that the link or property value must be exclusive (unique).'; SET is_aggregate := true; SET errmessage := '{__subject__} violates exclusivity constraint'; USING (std::_is_exclusive(__subject__)); }; CREATE ABSTRACT CONSTRAINT std::one_of(VARIADIC vals: anytype) EXTENDING std::constraint { CREATE ANNOTATION std::description := 'Specifies the list of allowed values directly.'; SET errmessage := '{__subject__} must be one of: {vals}.'; USING (contains(vals, __subject__)); }; CREATE ABSTRACT CONSTRAINT std::len_value ON (len(__subject__)) EXTENDING std::constraint { SET errmessage := 'invalid {__subject__}'; }; CREATE ABSTRACT CONSTRAINT std::max_value(max: anytype) EXTENDING std::constraint { CREATE ANNOTATION std::description := 'Specifies the maximum value for the subject.'; SET errmessage := 'Maximum allowed value for {__subject__} is {max}.'; USING (__subject__ <= max); }; CREATE ABSTRACT CONSTRAINT std::min_value(min: anytype) EXTENDING std::constraint { CREATE ANNOTATION std::description := 
'Specifies the minimum value for the subject.'; SET errmessage := 'Minimum allowed value for {__subject__} is {min}.'; USING (__subject__ >= min); }; CREATE ABSTRACT CONSTRAINT std::max_ex_value(max: anytype) EXTENDING std::max_value { CREATE ANNOTATION std::description := 'Specifies the maximum value (as an open interval) for the subject.'; SET errmessage := '{__subject__} must be less than {max}.'; USING (__subject__ < max); }; CREATE ABSTRACT CONSTRAINT std::min_ex_value(min: anytype) EXTENDING std::min_value { CREATE ANNOTATION std::description := 'Specifies the minimum value (as an open interval) for the subject.'; SET errmessage := '{__subject__} must be greater than {min}.'; USING (__subject__ > min); }; CREATE ABSTRACT CONSTRAINT std::regexp(pattern: std::str) EXTENDING std::constraint { CREATE ANNOTATION std::description := 'Specifies that the string representation of the subject must match a regexp.'; SET errmessage := 'invalid {__subject__}'; USING (re_test(pattern, __subject__)); }; CREATE ABSTRACT CONSTRAINT std::max_len_value(max: std::int64) EXTENDING std::max_value, std::len_value { CREATE ANNOTATION std::description := 'Specifies the maximum length of subject string representation.'; SET errmessage := '{__subject__} must be no longer than {max} characters.'; }; CREATE ABSTRACT CONSTRAINT std::min_len_value(min: std::int64) EXTENDING std::min_value, std::len_value { CREATE ANNOTATION std::description := 'Specifies the minimum length of subject string representation.'; SET errmessage := '{__subject__} must be no shorter than {min} characters.'; }; ================================================ FILE: edb/lib/std/60-baseobject.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ## Base object type, link and property definitions. CREATE ABSTRACT PROPERTY std::property; CREATE ABSTRACT PROPERTY std::id; CREATE ABSTRACT PROPERTY std::source; CREATE ABSTRACT PROPERTY std::target; CREATE ABSTRACT LINK std::link; CREATE ABSTRACT TYPE std::BaseObject { CREATE REQUIRED PROPERTY id EXTENDING std::id -> std::uuid { SET default := std::uuid_generate_v1mc(); SET readonly := True; CREATE CONSTRAINT std::exclusive; }; CREATE ANNOTATION std::description := 'Root object type.' }; CREATE ABSTRACT TYPE std::Object EXTENDING std::BaseObject { CREATE ANNOTATION std::description := 'Root object type for user-defined types'; }; # N.B: This does *not* derive from std::BaseObject! CREATE TYPE std::FreeObject { CREATE ANNOTATION std::description := 'Object type for free shapes'; }; # 'USING SQL EXPRESSION' creates an EdgeDB Operator for purposes of # introspection and validation by the EdgeQL compiler. However, no # object is created in Postgres and the EdgeQL->SQL compiler is expected # to produce some expression that will be valid. # # 'USING SQL OPERATOR' does all of the above and it also creates an # actual Postgres operator. It is expected that the EdgeQL->SQL compiler # will specifically use that operator. # HACK: We use 'USING SQL EXPRESSION' instead of 'USING SQL OPERATOR' # here because in actuality Objects will be resolved as their uuids # and in the end it's the uuid operators that will be called in SQL. # On the other hand, if we use "USING SQL OPERATOR", we will end up # clashing with the operators for uuid in Postgres. 
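#
# Illustration of the note above (a sketch; `User` and `.name` are
# hypothetical and not part of the standard library):
#
#     SELECT (SELECT User FILTER .name = 'a' LIMIT 1)
#          = (SELECT User FILTER .name = 'b' LIMIT 1);
#
# The EdgeQL compiler validates this comparison against the std::`=`
# operator declared for std::BaseObject, while the generated SQL is
# expected to compare the underlying uuid values with the Postgres
# uuid operators.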
CREATE INFIX OPERATOR std::`=` (l: std::BaseObject, r: std::BaseObject) -> std::bool { CREATE ANNOTATION std::identifier := 'eq'; CREATE ANNOTATION std::description := 'Compare two values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`?=` ( l: OPTIONAL std::BaseObject, r: OPTIONAL std::BaseObject ) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_eq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for equality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`!=` (l: std::BaseObject, r: std::BaseObject) -> std::bool { CREATE ANNOTATION std::identifier := 'neq'; CREATE ANNOTATION std::description := 'Compare two values for inequality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`?!=` ( l: OPTIONAL std::BaseObject, r: OPTIONAL std::BaseObject ) -> std::bool { CREATE ANNOTATION std::identifier := 'coal_neq'; CREATE ANNOTATION std::description := 'Compare two (potentially empty) values for inequality.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`>=` (l: std::BaseObject, r: std::BaseObject) -> std::bool { CREATE ANNOTATION std::identifier := 'gte'; CREATE ANNOTATION std::description := 'Greater than or equal.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`>` (l: std::BaseObject, r: std::BaseObject) -> std::bool { CREATE ANNOTATION std::identifier := 'gt'; CREATE ANNOTATION std::description := 'Greater than.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`<=` (l: std::BaseObject, r: std::BaseObject) -> std::bool { CREATE ANNOTATION std::identifier := 'lte'; CREATE ANNOTATION std::description := 'Less than or equal.'; SET volatility := 'Immutable'; USING SQL EXPRESSION; }; CREATE INFIX OPERATOR std::`<` (l: std::BaseObject, r: std::BaseObject) -> std::bool { CREATE ANNOTATION 
        std::identifier := 'lt';
    CREATE ANNOTATION std::description := 'Less than.';
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};

# The only possible Object cast is into json.
CREATE CAST FROM std::BaseObject TO std::json {
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};

CREATE CAST FROM std::FreeObject TO std::json {
    SET volatility := 'Immutable';
    USING SQL EXPRESSION;
};

================================================
FILE: edb/lib/std/70-converters.edgeql
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2018-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

## Functions that construct various scalars from strings or other types.

# std::to_str
# --------
# Normalize [local] datetime to text conversion to have
# the same format as one would get by serializing to JSON.
# Otherwise Postgres doesn't follow the ISO8601 standard
# and uses ' ' instead of 'T' as a separator between date
# and time.
#
# EdgeQL: '2010-10-10';
# To SQL: trim(to_json('2010-01-01'::timestamptz)::text, '"')
CREATE FUNCTION
std::to_str(dt: std::datetime, fmt: OPTIONAL str={}) -> std::str
{
    CREATE ANNOTATION std::description :=
        'Return string representation of the input value.';
    # Helper functions raising exceptions are STABLE.
SET volatility := 'Stable'; USING SQL $$ SELECT ( CASE WHEN "fmt" IS NULL THEN trim(to_json("dt")::text, '"') WHEN "fmt" = '' THEN edgedb_VER.raise( NULL::text, 'invalid_parameter_value', msg => 'to_str(): "fmt" argument must be a non-empty string' ) ELSE edgedb_VER.raise_on_null( to_char("dt", "fmt"), 'invalid_parameter_value', msg => 'to_str(): format ''' || "fmt" || ''' is invalid' ) END ) $$; }; CREATE FUNCTION std::to_str(td: std::duration, fmt: OPTIONAL str={}) -> std::str { CREATE ANNOTATION std::description := 'Return string representation of the input value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN "fmt" IS NULL THEN trim(to_json("td")::text, '"') WHEN "fmt" = '' THEN edgedb_VER.raise( NULL::text, 'invalid_parameter_value', msg => 'to_str(): "fmt" argument must be a non-empty string' ) ELSE edgedb_VER.raise_on_null( to_char("td", "fmt"), 'invalid_parameter_value', msg => 'to_str(): format ''' || "fmt" || ''' is invalid' ) END ) $$; }; # FIXME: There's no good safe default for all possible durations and some # durations cannot be formatted without non-trivial conversions (e.g. # 7,000 days). 
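#
# Usage sketch for the numeric to_str() overloads that follow
# (illustrative only; a non-empty format string is handed to the
# Postgres to_char() function, as the SQL bodies show):
#
#     SELECT to_str(123);            # no fmt: plain text cast, '123'
#     SELECT to_str(123, '999999');  # formatted by to_char()
#     SELECT to_str(123, '');        # raises invalid_parameter_value:
#                                    # "fmt" must be a non-empty string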
CREATE FUNCTION std::to_str(i: std::int64, fmt: OPTIONAL str={}) -> std::str { CREATE ANNOTATION std::description := 'Return string representation of the input value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN "fmt" IS NULL THEN "i"::text WHEN "fmt" = '' THEN edgedb_VER.raise( NULL::text, 'invalid_parameter_value', msg => 'to_str(): "fmt" argument must be a non-empty string' ) ELSE edgedb_VER.raise_on_null( to_char("i", "fmt"), 'invalid_parameter_value', msg => 'to_str(): format ''' || "fmt" || ''' is invalid' ) END ) $$; }; CREATE FUNCTION std::to_str(f: std::float64, fmt: OPTIONAL str={}) -> std::str { CREATE ANNOTATION std::description := 'Return string representation of the input value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN "fmt" IS NULL THEN "f"::text WHEN "fmt" = '' THEN edgedb_VER.raise( NULL::text, 'invalid_parameter_value', msg => 'to_str(): "fmt" argument must be a non-empty string' ) ELSE edgedb_VER.raise_on_null( to_char("f", "fmt"), 'invalid_parameter_value', msg => 'to_str(): format ''' || "fmt" || ''' is invalid' ) END ) $$; }; CREATE FUNCTION std::to_str(d: std::bigint, fmt: OPTIONAL str={}) -> std::str { CREATE ANNOTATION std::description := 'Return string representation of the input value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN "fmt" IS NULL THEN "d"::text WHEN "fmt" = '' THEN edgedb_VER.raise( NULL::text, 'invalid_parameter_value', msg => 'to_str(): "fmt" argument must be a non-empty string' ) ELSE edgedb_VER.raise_on_null( to_char("d", "fmt"), 'invalid_parameter_value', 'to_str(): format ''' || "fmt" || ''' is invalid' ) END ) $$; }; CREATE FUNCTION std::to_str(d: std::decimal, fmt: OPTIONAL str={}) -> std::str { CREATE ANNOTATION std::description := 'Return string representation of the input value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN "fmt" IS NULL THEN "d"::text WHEN "fmt" = '' THEN edgedb_VER.raise( NULL::text, 'invalid_parameter_value', msg 
                    => 'to_str(): "fmt" argument must be a non-empty string'
            )
        ELSE
            edgedb_VER.raise_on_null(
                to_char("d", "fmt"),
                'invalid_parameter_value',
                msg => 'to_str(): format ''' || "fmt" || ''' is invalid'
            )
        END
    )
    $$;
};

CREATE FUNCTION
std::to_str(array: array<std::str>, delimiter: std::str) -> std::str
{
    CREATE ANNOTATION std::description :=
        'Return string representation of the input value.';
    CREATE ANNOTATION std::deprecated :=
        'This converter function is deprecated and \
         is scheduled to be removed before 1.0.\n\
         Use std::array_join() instead.';
    SET volatility := 'Immutable';
    USING (
        SELECT std::array_join(array, delimiter)
    );
};

# JSON can be prettified by specifying 'pretty' as the format, any
# other value will result in an exception.
CREATE FUNCTION
std::to_str(json: std::json, fmt: OPTIONAL str={}) -> std::str
{
    CREATE ANNOTATION std::description :=
        'Return string representation of the input value.';
    SET volatility := 'Immutable';
    USING SQL $$
    SELECT (
        CASE WHEN "fmt" IS NULL THEN
            "json"::text
        WHEN "fmt" = 'pretty' THEN
            jsonb_pretty("json")
        WHEN "fmt" = '' THEN
            edgedb_VER.raise(
                NULL::text,
                'invalid_parameter_value',
                msg => 'to_str(): "fmt" argument must be a non-empty string'
            )
        ELSE
            edgedb_VER.raise(
                NULL::text,
                'invalid_parameter_value',
                msg => 'to_str(): format ''' || "fmt" || ''' is invalid'
            )
        END
    )
    $$;
};

CREATE FUNCTION
std::to_str(b: std::bytes) -> std::str {
    CREATE ANNOTATION std::description :=
        'Convert a binary UTF-8 string to a text value.';
    SET volatility := 'Immutable';
    USING SQL $$
    SELECT pg_catalog.convert_from("b", 'UTF8')
    $$;
};

CREATE FUNCTION
std::to_bytes(s: std::str) -> std::bytes {
    CREATE ANNOTATION std::description :=
        'Convert a text string to a binary UTF-8 string.';
    SET volatility := 'Immutable';
    USING SQL $$
    SELECT pg_catalog.convert_to("s", 'UTF8')
    $$;
};

CREATE FUNCTION
std::to_bytes(j: std::json) -> std::bytes {
    CREATE ANNOTATION std::description :=
        'Convert a json value to a binary UTF-8 string.';
    SET volatility := 'Immutable';
    USING
(to_bytes(to_str(j))); }; CREATE SCALAR TYPE std::Endian EXTENDING enum; CREATE FUNCTION std::to_bytes(val: std::int16, endian: std::Endian) -> std::bytes { CREATE ANNOTATION std::description := 'Convert an int16 using specified endian binary format.'; SET volatility := 'Immutable'; USING SQL $$ SELECT CASE WHEN (endian = 'Little') THEN substring(bin, 2, 1) || substring(bin, 1, 1) ELSE bin END FROM ( SELECT int2send(val) AS bin ) AS t; $$; }; CREATE FUNCTION std::to_bytes(val: std::int32, endian: std::Endian) -> std::bytes { CREATE ANNOTATION std::description := 'Convert an int32 using specified endian binary format.'; SET volatility := 'Immutable'; USING SQL $$ SELECT CASE WHEN (endian = 'Little') THEN substring(bin, 4, 1) || substring(bin, 3, 1) || substring(bin, 2, 1) || substring(bin, 1, 1) ELSE bin END FROM ( SELECT int4send(val) AS bin ) AS t; $$; }; CREATE FUNCTION std::to_bytes(val: std::int64, endian: std::Endian) -> std::bytes { CREATE ANNOTATION std::description := 'Convert an int64 using specified endian binary format.'; SET volatility := 'Immutable'; USING SQL $$ SELECT CASE WHEN (endian = 'Little') THEN substring(bin, 8, 1) || substring(bin, 7, 1) || substring(bin, 6, 1) || substring(bin, 5, 1) || substring(bin, 4, 1) || substring(bin, 3, 1) || substring(bin, 2, 1) || substring(bin, 1, 1) ELSE bin END FROM ( SELECT int8send(val) AS bin ) AS t; $$; }; CREATE FUNCTION std::to_bytes(val: std::uuid) -> std::bytes { CREATE ANNOTATION std::description := 'Convert an UUID to binary format.'; SET volatility := 'Immutable'; USING SQL $$ SELECT uuid_send(val); $$; }; CREATE FUNCTION std::to_json(str: std::str) -> std::json { CREATE ANNOTATION std::description := 'Return JSON value represented by the input *string*.'; # Casting of jsonb to and from text in PostgreSQL is IMMUTABLE. 
SET volatility := 'Immutable'; USING SQL $$ SELECT "str"::jsonb $$; }; CREATE FUNCTION std::to_datetime(s: std::str, fmt: OPTIONAL str={}) -> std::datetime { CREATE ANNOTATION std::description := 'Create a `datetime` value.'; # Helper function to_datetime is VOLATILE. SET volatility := 'Volatile'; USING SQL $$ SELECT ( CASE WHEN "fmt" IS NULL THEN edgedb_VER.datetime_in("s") WHEN "fmt" = '' THEN edgedb_VER.raise( NULL::edgedbt.timestamptz_t, 'invalid_parameter_value', msg => ( 'to_datetime(): "fmt" argument must be a non-empty string' ) ) ELSE edgedb_VER.raise_on_null( edgedb_VER.to_datetime("s", "fmt"), 'invalid_parameter_value', msg => 'to_datetime(): format ''' || "fmt" || ''' is invalid' ) END ) $$; }; CREATE FUNCTION std::to_datetime(year: std::int64, month: std::int64, day: std::int64, hour: std::int64, min: std::int64, sec: std::float64, timezone: std::str) -> std::datetime { CREATE ANNOTATION std::description := 'Create a `datetime` value.'; # make_timestamptz is STABLE SET volatility := 'Stable'; USING SQL $$ SELECT make_timestamptz( "year"::int, "month"::int, "day"::int, "hour"::int, "min"::int, "sec", "timezone" )::edgedbt.timestamptz_t $$; }; CREATE FUNCTION std::to_datetime(epochseconds: std::float64) -> std::datetime { CREATE ANNOTATION std::description := 'Create a `datetime` value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT to_timestamp("epochseconds")::edgedbt.timestamptz_t $$; }; CREATE FUNCTION std::to_datetime(epochseconds: std::int64) -> std::datetime { CREATE ANNOTATION std::description := 'Create a `datetime` value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT to_timestamp("epochseconds")::edgedbt.timestamptz_t $$; }; CREATE FUNCTION std::to_datetime(epochseconds: std::decimal) -> std::datetime { CREATE ANNOTATION std::description := 'Create a `datetime` value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT to_timestamp("epochseconds")::edgedbt.timestamptz_t $$; }; CREATE FUNCTION std::to_duration( NAMED ONLY hours: 
std::int64=0, NAMED ONLY minutes: std::int64=0, NAMED ONLY seconds: std::float64=0, NAMED ONLY microseconds: std::int64=0 ) -> std::duration { CREATE ANNOTATION std::description := 'Create a `duration` value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( make_interval( 0, 0, 0, 0, "hours"::int, "minutes"::int, "seconds" ) + (microseconds::text || ' microseconds')::interval )::edgedbt.duration_t $$; }; CREATE FUNCTION std::to_bigint(s: std::str, fmt: OPTIONAL str={}) -> std::bigint { CREATE ANNOTATION std::description := 'Create a `bigint` value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN "fmt" IS NULL THEN edgedb_VER.str_to_bigint("s") WHEN "fmt" = '' THEN edgedb_VER.raise( NULL::edgedbt.bigint_t, 'invalid_parameter_value', msg => ( 'to_bigint(): "fmt" argument must be a non-empty string' ) ) ELSE edgedb_VER.raise_on_null( to_number("s", "fmt")::edgedbt.bigint_t, 'invalid_parameter_value', msg => 'to_bigint(): format ''' || "fmt" || ''' is invalid' ) END ) $$; }; CREATE FUNCTION std::to_decimal(s: std::str, fmt: OPTIONAL str={}) -> std::decimal { CREATE ANNOTATION std::description := 'Create a `decimal` value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN "fmt" IS NULL THEN edgedb_VER.str_to_decimal("s") WHEN "fmt" = '' THEN edgedb_VER.raise( NULL::numeric, 'invalid_parameter_value', msg => ( 'to_decimal(): "fmt" argument must be a non-empty string' ) ) ELSE edgedb_VER.raise_on_null( to_number("s", "fmt")::numeric, 'invalid_parameter_value', msg => 'to_decimal(): format ''' || "fmt" || ''' is invalid' ) END ) $$; }; CREATE FUNCTION std::to_int64(s: std::str, fmt: OPTIONAL str={}) -> std::int64 { CREATE ANNOTATION std::description := 'Create a `int64` value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN "fmt" IS NULL THEN -- Must use the noninline version to prevent -- the overeager function inliner from crashing edgedb_VER.str_to_int64_noinline("s") WHEN "fmt" = '' THEN edgedb_VER.raise( 
NULL::bigint, 'invalid_parameter_value', msg => 'to_int64(): "fmt" argument must be a non-empty string' ) ELSE edgedb_VER.raise_on_null( to_number("s", "fmt")::bigint, 'invalid_parameter_value', msg => 'to_int64(): format ''' || "fmt" || ''' is invalid' ) END ) $$; }; CREATE FUNCTION std::to_int64(val: std::bytes, endian: std::Endian) -> std::int64 { CREATE ANNOTATION std::description := 'Convert bytes into `int64` value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT CASE WHEN (length(val) = 8) THEN ( 'x' || right( ( CASE WHEN (endian = 'Little') THEN substring(val, 8, 1) || substring(val, 7, 1) || substring(val, 6, 1) || substring(val, 5, 1) || substring(val, 4, 1) || substring(val, 3, 1) || substring(val, 2, 1) || substring(val, 1, 1) ELSE val END )::text, 16 ) )::bit(64)::int8 ELSE edgedb_VER.raise( 0::int8, 'invalid_parameter_value', msg => ( 'to_int64(): the argument must be exactly 8 bytes long' ) ) END $$; }; CREATE FUNCTION std::to_int32(s: std::str, fmt: OPTIONAL str={}) -> std::int32 { CREATE ANNOTATION std::description := 'Create a `int32` value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN "fmt" IS NULL THEN -- Must use the noninline version to prevent -- the overeager function inliner from crashing edgedb_VER.str_to_int32_noinline("s") WHEN "fmt" = '' THEN edgedb_VER.raise( NULL::int, 'invalid_parameter_value', msg => 'to_int32(): "fmt" argument must be a non-empty string' ) ELSE edgedb_VER.raise_on_null( to_number("s", "fmt")::int, 'invalid_parameter_value', msg => 'to_int32(): format ''' || "fmt" || ''' is invalid' ) END ) $$; }; CREATE FUNCTION std::to_int32(val: std::bytes, endian: std::Endian) -> std::int32 { CREATE ANNOTATION std::description := 'Convert bytes into `int32` value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT CASE WHEN (length(val) = 4) THEN ( 'x' || right( ( CASE WHEN (endian = 'Little') THEN substring(val, 4, 1) || substring(val, 3, 1) || substring(val, 2, 1) || substring(val, 1, 1) ELSE val END 
)::text, 8 ) )::bit(32)::int4 ELSE edgedb_VER.raise( 0::int4, 'invalid_parameter_value', msg => ( 'to_int32(): the argument must be exactly 4 bytes long' ) ) END $$; }; CREATE FUNCTION std::to_int16(s: std::str, fmt: OPTIONAL str={}) -> std::int16 { CREATE ANNOTATION std::description := 'Create a `int16` value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN "fmt" IS NULL THEN -- Must use the noninline version to prevent -- the overeager function inliner from crashing edgedb_VER.str_to_int16_noinline("s") WHEN "fmt" = '' THEN edgedb_VER.raise( NULL::smallint, 'invalid_parameter_value', msg => 'to_int16(): "fmt" argument must be a non-empty string' ) ELSE edgedb_VER.raise_on_null( to_number("s", "fmt")::smallint, 'invalid_parameter_value', msg => 'to_int16(): format ''' || "fmt" || ''' is invalid' ) END ) $$; }; CREATE FUNCTION std::to_int16(val: std::bytes, endian: std::Endian) -> std::int16 { CREATE ANNOTATION std::description := 'Convert bytes into `int16` value.'; SET volatility := 'Immutable'; # There is no direct cast from bits to int2 in Postgres, so we need to use # the bit(32)::int4 as an intermediary value. However, the first bit is # the sign bit and must be preserved as such, otherwise we will have # overflow when casting from int4 to int2. So we pad the bytes with 0 on # the right (which happens by default when casting 2 bytes from text to # bit(32)) and then right-shift preserving the sign bit. This results in # the int4 value in the lower two bytes being fully compatible with int2 # value. 
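    #
    # Worked example of the steps above, assuming the big-endian input
    # bytes FF FE (the hex values are illustrative notation, not SQL
    # literals):
    #   'x' || 'fffe' cast to bit(32)  -> 0xFFFE0000 (zero-padded on the right)
    #   0xFFFE0000 read as int4        -> -131072
    #   -131072 >> 16                  -> -2 (the shift keeps the sign bit)
    #   (-2)::int2                     -> -2, the value 0xFFFE encodes as int16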
USING SQL $$ SELECT CASE WHEN (length(val) = 2) THEN ( ( ( 'x' || right( ( CASE WHEN (endian = 'Little') THEN substring(val, 2, 1) || substring(val, 1, 1) ELSE val END )::text, 4 ) )::bit(32)::int4 )>>16 )::int2 ELSE edgedb_VER.raise( 0::int2, 'invalid_parameter_value', msg => ( 'to_int16(): the argument must be exactly 2 bytes long' ) ) END $$; }; CREATE FUNCTION std::to_float64(s: std::str, fmt: OPTIONAL str={}) -> std::float64 { CREATE ANNOTATION std::description := 'Create a `float64` value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN "fmt" IS NULL THEN edgedb_VER.str_to_float64_noinline("s") WHEN "fmt" = '' THEN edgedb_VER.raise( NULL::float8, 'invalid_parameter_value', msg => ( 'to_float64(): "fmt" argument must be a non-empty string' ) ) ELSE edgedb_VER.raise_on_null( to_number("s", "fmt")::float8, 'invalid_parameter_value', msg => 'to_float64(): format ''' || "fmt" || ''' is invalid' ) END ) $$; }; CREATE FUNCTION std::to_float32(s: std::str, fmt: OPTIONAL str={}) -> std::float32 { CREATE ANNOTATION std::description := 'Create a `float32` value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT ( CASE WHEN "fmt" IS NULL THEN edgedb_VER.str_to_float32_noinline("s") WHEN "fmt" = '' THEN edgedb_VER.raise( NULL::float4, 'invalid_parameter_value', msg => ( 'to_float32(): "fmt" argument must be a non-empty string' ) ) ELSE edgedb_VER.raise_on_null( to_number("s", "fmt")::float4, 'invalid_parameter_value', msg => 'to_float32(): format ''' || "fmt" || ''' is invalid' ) END ) $$; }; CREATE FUNCTION std::to_uuid(val: std::bytes) -> std::uuid { CREATE ANNOTATION std::description := 'Convert binary representation into UUID value.'; SET volatility := 'Immutable'; USING SQL $$ SELECT CASE WHEN (length(val) = 16) THEN ENCODE(val, 'hex')::uuid ELSE edgedb_VER.raise( NULL::uuid, 'invalid_parameter_value', msg => ( 'to_uuid(): the argument must be exactly 16 bytes long' ) ) END $$; }; ================================================ FILE: 
edb/lib/sys.edgeql ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2018-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # CREATE MODULE sys; CREATE MODULE sys::perm; CREATE PERMISSION sys::perm::superuser; CREATE PERMISSION sys::perm::data_modification; CREATE PERMISSION sys::perm::ddl; CREATE PERMISSION sys::perm::branch_config; CREATE PERMISSION sys::perm::sql_session_config; CREATE PERMISSION sys::perm::analyze; CREATE PERMISSION sys::perm::query_stats_read; CREATE PERMISSION sys::perm::approximate_count; CREATE SCALAR TYPE sys::TransactionIsolation EXTENDING enum; CREATE SCALAR TYPE sys::TransactionAccessMode EXTENDING enum; CREATE SCALAR TYPE sys::TransactionDeferrability EXTENDING enum; CREATE SCALAR TYPE sys::VersionStage EXTENDING enum; CREATE SCALAR TYPE sys::QueryType EXTENDING enum; CREATE SCALAR TYPE sys::OutputFormat EXTENDING enum; CREATE ABSTRACT TYPE sys::SystemObject EXTENDING schema::Object; CREATE ABSTRACT TYPE sys::ExternalObject EXTENDING sys::SystemObject; CREATE TYPE sys::Branch EXTENDING sys::ExternalObject, schema::AnnotationSubject { ALTER PROPERTY name { CREATE CONSTRAINT std::exclusive; }; CREATE PROPERTY last_migration-> std::str; }; CREATE ALIAS sys::Database := sys::Branch; CREATE TYPE sys::ExtensionPackage EXTENDING sys::SystemObject, schema::AnnotationSubject { CREATE REQUIRED PROPERTY script -> str; CREATE REQUIRED 
PROPERTY version -> tuple<
        major: std::int64,
        minor: std::int64,
        stage: sys::VersionStage,
        stage_no: std::int64,
        local: array<std::str>,
    >;
};


CREATE TYPE sys::ExtensionPackageMigration
    EXTENDING sys::SystemObject, schema::AnnotationSubject
{
    CREATE REQUIRED PROPERTY script -> str;
    CREATE REQUIRED PROPERTY from_version -> tuple<
        major: std::int64,
        minor: std::int64,
        stage: sys::VersionStage,
        stage_no: std::int64,
        local: array<std::str>,
    >;
    CREATE REQUIRED PROPERTY to_version -> tuple<
        major: std::int64,
        minor: std::int64,
        stage: sys::VersionStage,
        stage_no: std::int64,
        local: array<std::str>,
    >;
};


ALTER TYPE schema::Extension {
    CREATE REQUIRED LINK package -> sys::ExtensionPackage {
        CREATE CONSTRAINT std::exclusive;
    }
};


CREATE TYPE sys::Role
    EXTENDING
        sys::SystemObject,
        schema::InheritingObject,
        schema::AnnotationSubject
{
    ALTER PROPERTY name {
        CREATE CONSTRAINT std::exclusive;
    };

    CREATE REQUIRED PROPERTY superuser -> std::bool;
    # Backwards compatibility.
    CREATE PROPERTY is_superuser := .superuser;
    CREATE PROPERTY password -> std::str;
    CREATE MULTI PROPERTY permissions -> std::str;
    CREATE MULTI PROPERTY branches -> std::str;
    CREATE PROPERTY apply_access_policies_pg_default -> std::bool;

    CREATE ACCESS POLICY ap_read deny select using (
        not global sys::perm::superuser
    );
};


ALTER TYPE sys::Role {
    CREATE MULTI LINK member_of -> sys::Role;
};


CREATE TYPE sys::QueryStats EXTENDING sys::ExternalObject {
    CREATE LINK branch -> sys::Branch {
        CREATE ANNOTATION std::description :=
            "The branch this statistics entry was collected in.";
    };
    CREATE PROPERTY query -> std::str {
        CREATE ANNOTATION std::description :=
            "Text string of a representative query.";
    };
    CREATE PROPERTY query_type -> sys::QueryType {
        CREATE ANNOTATION std::description :=
            "Type of the query.";
    };
    CREATE PROPERTY tag -> std::str {
        CREATE ANNOTATION std::description :=
            "Query tag, commonly specifies the origin of the query, e.g. 'gel/cli' for queries originating from the CLI.
Clients can specify a tag for easier query identification."; }; CREATE PROPERTY compilation_config -> std::json; CREATE PROPERTY protocol_version -> tuple; CREATE PROPERTY default_namespace -> std::str; CREATE OPTIONAL PROPERTY namespace_aliases -> std::json; CREATE OPTIONAL PROPERTY output_format -> sys::OutputFormat; CREATE OPTIONAL PROPERTY expect_one -> std::bool; CREATE OPTIONAL PROPERTY implicit_limit -> std::int64; CREATE OPTIONAL PROPERTY inline_typeids -> std::bool; CREATE OPTIONAL PROPERTY inline_typenames -> std::bool; CREATE OPTIONAL PROPERTY inline_objectids -> std::bool; CREATE PROPERTY plans -> std::int64 { CREATE ANNOTATION std::description := "Number of times the query was planned in the backend."; }; CREATE PROPERTY total_plan_time -> std::duration { CREATE ANNOTATION std::description := "Total time spent planning the query in the backend."; }; CREATE PROPERTY min_plan_time -> std::duration { CREATE ANNOTATION std::description := "Minimum time spent planning the query in the backend. " ++ "This field will be zero if the counter has been reset " ++ "using the `sys::reset_query_stats` function " ++ "with the `minmax_only` parameter set to `true` " ++ "and never been planned since."; }; CREATE PROPERTY max_plan_time -> std::duration { CREATE ANNOTATION std::description := "Maximum time spent planning the query in the backend. 
" ++ "This field will be zero if the counter has been reset " ++ "using the `sys::reset_query_stats` function " ++ "with the `minmax_only` parameter set to `true` " ++ "and never been planned since."; }; CREATE PROPERTY mean_plan_time -> std::duration { CREATE ANNOTATION std::description := "Mean time spent planning the query in the backend."; }; CREATE PROPERTY stddev_plan_time -> std::duration { CREATE ANNOTATION std::description := "Population standard deviation of time spent " ++ "planning the query in the backend."; }; CREATE PROPERTY calls -> std::int64 { CREATE ANNOTATION std::description := "Number of times the query was executed."; }; CREATE PROPERTY total_exec_time -> std::duration { CREATE ANNOTATION std::description := "Total time spent executing the query in the backend."; }; CREATE PROPERTY min_exec_time -> std::duration { CREATE ANNOTATION std::description := "Minimum time spent executing the query in the backend, " ++ "this field will be zero until this query is executed " ++ "first time after reset performed by the " ++ "`sys::reset_query_stats` function with the " ++ "`minmax_only` parameter set to `true`"; }; CREATE PROPERTY max_exec_time -> std::duration { CREATE ANNOTATION std::description := "Maximum time spent executing the query in the backend, " ++ "this field will be zero until this query is executed " ++ "first time after reset performed by the " ++ "`sys::reset_query_stats` function with the " ++ "`minmax_only` parameter set to `true`"; }; CREATE PROPERTY mean_exec_time -> std::duration { CREATE ANNOTATION std::description := "Mean time spent executing the query in the backend."; }; CREATE PROPERTY stddev_exec_time -> std::duration { CREATE ANNOTATION std::description := "Population standard deviation of time spent " ++ "executing the query in the backend."; }; CREATE PROPERTY rows -> std::int64 { CREATE ANNOTATION std::description := "Total number of rows retrieved or affected by the query."; }; CREATE PROPERTY stats_since -> 
std::datetime {
        CREATE ANNOTATION std::description :=
            "Time at which statistics gathering started for this query.";
    };
    CREATE PROPERTY minmax_stats_since -> std::datetime {
        CREATE ANNOTATION std::description :=
            "Time at which min/max statistics gathering started "
            ++ "for this query (fields `min_plan_time`, `max_plan_time`, "
            ++ "`min_exec_time` and `max_exec_time`).";
    };

    CREATE ACCESS POLICY ap_read allow select using (
        global sys::perm::query_stats_read
    );
};


CREATE FUNCTION
sys::reset_query_stats(
    named only branch_name: OPTIONAL std::str = {},
    named only id: OPTIONAL std::uuid = {},
    named only minmax_only: OPTIONAL std::bool = false,
) -> OPTIONAL std::datetime
{
    CREATE ANNOTATION std::description :=
        'Discard query statistics gathered so far corresponding to the '
        ++ 'specified `branch_name` and `id`. If either of the '
        ++ 'parameters is not specified, the statistics that match with the '
        ++ 'other parameter will be reset. If no parameter is specified, '
        ++ 'it will discard all statistics. When `minmax_only` is `true`, '
        ++ 'only the values of minimum and maximum planning and execution '
        ++ 'time will be reset (i.e. `min_plan_time`, `max_plan_time`, '
        ++ '`min_exec_time` and `max_exec_time` fields). The default value '
        ++ 'for `minmax_only` parameter is `false`. This function returns '
        ++ 'the time of a reset. This time is saved to `stats_reset` or '
        ++ '`minmax_stats_since` field of `sys::QueryStats` if the '
        ++ 'corresponding reset was actually performed.';
    SET volatility := 'Volatile';
    USING SQL FUNCTION 'edgedb.reset_query_stats';
    set required_permissions := { sys::perm::superuser };
};


# An intermediate function is needed because we can't
# cast JSON to tuples yet.  DO NOT use directly, it'll go away.
CREATE FUNCTION
sys::__version_internal() -> tuple<major: std::int64,
                                   minor: std::int64,
                                   stage: std::str,
                                   stage_no: std::int64,
                                   local: array<std::str>>
{
    # This function reads from a table.
    SET volatility := 'Stable';
    SET internal := true;
    USING SQL $$
    SELECT
        (v ->> 'major')::int8,
        (v ->> 'minor')::int8,
        (v ->> 'stage')::text,
        (v ->> 'stage_no')::int8,
        (SELECT coalesce(array_agg(el), ARRAY[]::text[])
         FROM jsonb_array_elements_text(v -> 'local') AS el)
    FROM
        (SELECT
            pg_catalog.current_setting('edgedb.server_version')::jsonb AS v
        ) AS q
    $$;
};


CREATE FUNCTION
sys::get_version() -> tuple<major: std::int64,
                            minor: std::int64,
                            stage: sys::VersionStage,
                            stage_no: std::int64,
                            local: array<std::str>>
{
    CREATE ANNOTATION std::description :=
        'Return the server version as a tuple.';
    SET volatility := 'Stable';
    USING (
        SELECT <tuple<major: std::int64,
                      minor: std::int64,
                      stage: sys::VersionStage,
                      stage_no: std::int64,
                      local: array<std::str>>>sys::__version_internal()
    );
};


CREATE FUNCTION
sys::get_version_as_str() -> std::str
{
    CREATE ANNOTATION std::description :=
        'Return the server version as a string.';
    SET volatility := 'Stable';
    USING (
        WITH v := sys::get_version()
        SELECT
            <str>v.major
            ++ '.' ++ <str>v.minor
            ++ (('-' ++ <str>v.stage ++ '.' ++ <str>v.stage_no)
                IF <str>v.stage != 'final' ELSE '')
            ++ (('+' ++ std::array_join(v.local, '.'))
                IF len(v.local) > 0 ELSE '')
    );
};


CREATE FUNCTION
sys::get_instance_name() -> std::str
{
    CREATE ANNOTATION std::description :=
        'Return the server instance name.';
    SET volatility := 'Stable';
    USING SQL $$
        SELECT pg_catalog.current_setting('edgedb.instance_name');
    $$;
};


CREATE FUNCTION
sys::get_transaction_isolation() -> sys::TransactionIsolation
{
    CREATE ANNOTATION std::description :=
        'Return the isolation level of the current transaction.';
    # This function only reads from a table.
    SET volatility := 'Stable';
    SET force_return_cast := true;
    USING SQL FUNCTION 'edgedb._get_transaction_isolation';
};


CREATE FUNCTION
sys::get_current_database() -> str
{
    CREATE ANNOTATION std::description :=
        'Return the name of the current database branch as a string.';
    # The results won't change within a single statement.
    SET volatility := 'Stable';
    USING SQL FUNCTION 'edgedb.get_current_database';
};


CREATE FUNCTION
sys::get_current_branch() -> str
{
    CREATE ANNOTATION std::description :=
        'Return the name of the current database branch as a string.';
    # The results won't change within a single statement.
    SET volatility := 'Stable';
    USING SQL FUNCTION 'edgedb.get_current_database';
};


CREATE FUNCTION
sys::_describe_roles_as_ddl() -> str
{
    # The results won't change within a single statement.
    SET volatility := 'Stable';
    SET internal := true;
    USING SQL FUNCTION 'edgedb._describe_roles_as_ddl';
    set required_permissions := { sys::perm::superuser };
};


CREATE FUNCTION
sys::_get_all_role_memberships(r: uuid) -> array<uuid>
{
    # The results won't change within a single statement.
    SET volatility := 'Stable';
    SET internal := true;
    USING SQL FUNCTION 'edgedb._all_role_memberships';
    set impl_is_strict := false;
    set required_permissions := { sys::perm::superuser };
};


ALTER TYPE sys::Role {
    CREATE MULTI PROPERTY all_permissions := distinct({
        .permissions,
        (
            with self_id := .id
            select detached sys::Role
            filter .id in array_unpack(
                sys::_get_all_role_memberships(self_id)
            )
        ).permissions,
    });
};


CREATE FUNCTION
sys::__pg_and(a: OPTIONAL std::bool, b: OPTIONAL std::bool) -> std::bool
{
    SET volatility := 'Immutable';
    SET internal := true;
    USING SQL $$
        SELECT a AND b;
    $$;
};


CREATE FUNCTION
sys::__pg_or(a: OPTIONAL std::bool, b: OPTIONAL std::bool) -> std::bool
{
    SET volatility := 'Immutable';
    SET internal := true;
    USING SQL $$
        SELECT a OR b;
    $$;
};


CREATE FUNCTION
sys::approximate_count(
    type: schema::ObjectType,
    NAMED ONLY ignore_subtypes: std::bool=false,
) -> int64
{
    SET volatility := 'Stable';
    USING SQL FUNCTION 'edgedb.approximate_count';
    set impl_is_strict := false;
    set required_permissions := { sys::perm::approximate_count };
};


CREATE REQUIRED GLOBAL sys::current_role -> str {
    SET default := '';
};


CREATE REQUIRED GLOBAL sys::current_permissions -> array<str> {
    SET default := <array<str>>[];
};


# Add permissions to
schema and std. # These modules are populated before sys permissions so we need to # add these restrictions here. ALTER TYPE schema::Migration { CREATE ACCESS POLICY ap_read allow select using ( global sys::perm::ddl ); }; ALTER FUNCTION std::sequence_reset( seq: schema::ScalarType, value: std::int64, ) { SET required_permissions := sys::perm::ddl; }; ALTER FUNCTION std::sequence_reset( seq: schema::ScalarType, ) { SET required_permissions := sys::perm::ddl; }; ================================================ FILE: edb/load_ext/main.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Command to load an extension into an edgedb installation. It is a command distributed with the server, but it is designed so that it has no dependencies and does not import any server code if it is *only* installing the postgres part of an extension with a specified pg_config, so it *can* be pulled out and used standalone. (It requires Python 3.11 for tomllib.) """ from __future__ import annotations import argparse import json import os import pathlib import shutil import subprocess import sys import tempfile import tomllib import zipfile # Directories that we map to config values in pg_config. 
CONFIG_PATHS = { 'share': 'sharedir', 'lib': 'pkglibdir', 'include': 'pkgincludedir-server', } def install_pg_extension( pkg: pathlib.Path, pg_config: dict[str, str], manifest_target: pathlib.Path | None, ) -> None: to_install = [] with zipfile.ZipFile(pkg) as z: base = get_dir(z) # Compute what files to install for entry in z.infolist(): fpath = pathlib.Path(entry.filename) if entry.is_dir(): continue if fpath.parts[0] != str(base): continue # If the path is too short or isn't one of the # directories we know about, skip it. if ( len(fpath.parts) < 3 or not (config_field := CONFIG_PATHS.get(fpath.parts[1])) or fpath.parts[2] != 'postgresql' ): # print("Skipping", fpath) continue fpath = fpath.relative_to( pathlib.Path(fpath.parts[0]) / fpath.parts[1] / 'postgresql' ) to_install.append((entry.filename, config_field, fpath)) # Write a manifest out of all the files installed into the # postgres installation. if manifest_target: manifest_contents = [ {'postgres_dir': config_field, 'path': str(fpath)} for _, config_field, fpath in to_install ] with open(manifest_target, "w") as f: json.dump(manifest_contents, f) # Install them for zip_name, config_field, fpath in to_install: config_dir = pg_config[config_field] target_file = config_dir / fpath os.makedirs(target_file.parent, exist_ok=True) with z.open(zip_name) as src: with open(target_file, "wb") as dst: print("Installing", target_file) shutil.copyfileobj(src, dst) def uninstall_pg_extension( pg_manifest: list[dict[str, str]], pg_config: dict[str, str], ) -> None: for entry in pg_manifest: config_field = entry['postgres_dir'] fpath = entry['path'] full_path = pathlib.Path(pg_config[config_field]) / fpath print("Removing", full_path) try: os.remove(full_path) except FileNotFoundError: print("Could not remove missing", full_path) def get_pg_config(pg_config_path: pathlib.Path) -> dict[str, str]: output = subprocess.run( pg_config_path, capture_output=True, text=True, check=True, ) stdout_lines = 
output.stdout.split('\n') config = {} for line in stdout_lines: k, eq, v = line.partition('=') if eq: config[k.strip().lower()] = v.strip() return config def get_dir(z: zipfile.ZipFile) -> pathlib.Path: files = z.infolist() if not (files and files[0].is_dir()): print('ERROR: Extension package must contain one top-level dir') sys.exit(1) dirname = pathlib.Path(files[0].filename) return dirname def install_edgedb_extension( pkg: pathlib.Path, ext_dir: pathlib.Path, ) -> pathlib.Path: with tempfile.TemporaryDirectory() as tdir, \ zipfile.ZipFile(pkg) as z: dirname = get_dir(z) target = ext_dir / dirname if target.exists(): print( f'ERROR: Extension {dirname} is already installed at {target}' ) sys.exit(1) print("Installing", target) ttarget = pathlib.Path(tdir) / dirname os.mkdir(ttarget) with z.open(str(dirname / 'MANIFEST.toml')) as m: manifest = tomllib.load(m) files = ['MANIFEST.toml'] + manifest['files'] for f in files: target_file = target / f ttarget_file = ttarget / f with z.open(str(dirname / f)) as src: with open(ttarget_file, "wb") as dst: print("Installing", target_file) shutil.copyfileobj(src, dst) os.makedirs(ext_dir, exist_ok=True) # If there was a race and the file was created between the # earlier check and now, we'll produce a worse error # message. Oh well. 
shutil.move(ttarget, ext_dir) return target def load_ext_install( package: pathlib.Path, skip_edgedb: bool, skip_gel: bool, skip_postgres: bool, with_pg_config: pathlib.Path | None, ) -> None: target_dir = None if not skip_edgedb and not skip_gel: from edb import buildmeta ext_dir = buildmeta.get_extension_dir_path() target_dir = install_edgedb_extension(package, ext_dir) if not skip_postgres: if with_pg_config is None: from edb import buildmeta with_pg_config = buildmeta.get_pg_config_path() pg_config = get_pg_config(with_pg_config) pg_manifest = target_dir / "PG_MANIFEST.json" if target_dir else None install_pg_extension(package, pg_config, pg_manifest) def load_ext_uninstall( package: pathlib.Path, skip_edgedb: bool, skip_gel: bool, skip_postgres: bool, with_pg_config: pathlib.Path | None, ) -> None: from edb import buildmeta target_dir = None if len(package.parts) != 1: print( f'ERROR: {package} is not a valid extension name' ) sys.exit(1) ext_dir = buildmeta.get_extension_dir_path() target_dir = ext_dir / package if not target_dir.exists(): print( f'ERROR: Extension {package} is not currently ' f'installed at {target_dir}' ) sys.exit(1) if not skip_postgres: try: with open(target_dir / "PG_MANIFEST.json") as f: pg_manifest = json.load(f) except FileNotFoundError: pg_manifest = [] if with_pg_config is None: with_pg_config = buildmeta.get_pg_config_path() pg_config = get_pg_config(with_pg_config) uninstall_pg_extension(pg_manifest, pg_config) if not skip_edgedb and not skip_gel: print("Removing", target_dir) shutil.rmtree(target_dir) def load_ext_list_packages() -> None: from edb import buildmeta ext_dir = buildmeta.get_extension_dir_path() exts = [] try: with os.scandir(ext_dir) as it: for entry in it: entry_path = pathlib.Path(entry) manifest_path = entry_path / 'MANIFEST.toml' if ( entry.is_dir() and manifest_path.exists() ): with open(manifest_path, 'rb') as m: manifest = tomllib.load(m) info = dict( key=entry_path.name, extension_name=manifest['name'], 
extension_version=manifest['version'], path=str(entry_path.absolute()), ) exts.append(info) except FileNotFoundError: pass print(json.dumps(exts, indent=4)) def load_ext_main( *, package: pathlib.Path | None, uninstall: pathlib.Path | None, list_packages: bool, **kwargs, ) -> None: if uninstall: load_ext_uninstall(uninstall, **kwargs) elif package: load_ext_install(package, **kwargs) elif list_packages: load_ext_list_packages() else: raise AssertionError('No command specified?') parser = argparse.ArgumentParser(description='Install an extension package') parser.add_argument( '--skip-gel', action='store_true', help="Skip installing the extension package into the Gel " "installation", ) parser.add_argument( '--skip-edgedb', action='store_true', help=argparse.SUPPRESS, ) parser.add_argument( '--skip-postgres', action='store_true', help="Skip installing the extension package into the " "Postgres installation", ) parser.add_argument( '--with-pg-config', metavar='PATH', help="Use the specified pg_config binary to find the Postgres " "to install into (instead of using the bundled one)" ) group = parser.add_mutually_exclusive_group(required=True) group.add_argument( '--list-packages', action='store_true', help="List the extension packages that are installed (in JSON)" ) group.add_argument( '--uninstall', metavar='NAME', type=pathlib.Path, help="Uninstall a package (by package directory name) instead of " "installing it" ) group.add_argument('package', nargs='?', type=pathlib.Path) def main(argv: tuple[str, ...] | None = None): argv = argv if argv is not None else tuple(sys.argv[1:]) args = parser.parse_args(argv) load_ext_main(**vars(args)) if __name__ == '__main__': main() ================================================ FILE: edb/pgsql/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations ================================================ FILE: edb/pgsql/ast.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import enum import dataclasses import typing import uuid from edb.common import ast, span from edb.common import typeutils from edb.edgeql import ast as qlast from edb.ir import ast as irast if typing.TYPE_CHECKING: # PathAspect is imported without qualifiers here because otherwise in # base.AST._collect_direct_fields, typing.get_type_hints will not correctly # locate the type. from .compiler.enums import PathAspect # The structure of the nodes mostly follows that of Postgres' # parsenodes.h and primnodes.h, but only with fields that are # relevant to parsing and code generation. 
#
# Certain nodes have EdgeDB-specific fields used by the
# compiler.


Span = span.Span


class Base(ast.AST):
    __ast_hidden__ = {'span'}

    span: typing.Optional[Span] = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __repr__(self):
        return f'<pg.{self.__class__.__name__} at 0x{id(self):x}>'

    def dump_sql(self) -> None:
        from edb.common.debug import dump_sql
        dump_sql(self, reordered=True, pretty=True)


class ImmutableBase(ast.ImmutableASTMixin, Base):
    __ast_mutable_fields__ = frozenset(['span'])


class Alias(ImmutableBase):
    """Alias for a range variable."""

    # aliased relation name
    aliasname: str
    # optional list of column aliases
    colnames: typing.Optional[list[str]] = None


class Keyword(ImmutableBase):
    """An SQL keyword that must be output without quoting."""

    name: str  # Keyword name


class Star(Base):
    """'*' representing all columns of a table or compound field."""


class BaseExpr(Base):
    """Any non-statement expression node that returns a value."""

    __ast_meta__ = {'nullable'}

    nullable: typing.Optional[bool] = None  # Whether the result can be NULL.
    ser_safe: bool = False  # Whether the expr is serialization-safe.
def __init__( self, *, nullable: typing.Optional[bool] = None, **kwargs ) -> None: nullable = self._is_nullable(kwargs, nullable) super().__init__(nullable=nullable, **kwargs) def _is_nullable( self, kwargs: dict[str, object], nullable: typing.Optional[bool] ) -> bool: if nullable is None: default = type(self).get_field('nullable').default if default is not None: nullable = default else: nullable = self._infer_nullability(kwargs) return nullable def _infer_nullability(self, kwargs: dict[str, object]) -> bool: nullable = False for v in kwargs.values(): if typeutils.is_container(v): items = typing.cast(typing.Iterable, v) nullable = any(getattr(vv, 'nullable', False) for vv in items) elif getattr(v, 'nullable', None): nullable = True if nullable: break return nullable class ImmutableBaseExpr(BaseExpr, ImmutableBase): pass class OutputVar(ImmutableBaseExpr): """A base class representing expression output address.""" # Whether this represents a packed array of data is_packed_multi: bool = False class ExprOutputVar(OutputVar): """A "fake" output var representing a wrapped BaseExpr. In some obscure cases (specifically, returning __type__ from a non-view base relation that doesn't actually contain it), we need to return a non output var value from something expecting OutputVar. Instead of fully blowing away the type discipline of OutputVar and making everything operate on BaseExpr, we require such expressions to be explicitly wrapped. """ expr: BaseExpr class EdgeQLPathInfo(Base): """A general mixin providing EdgeQL-specific metadata on certain nodes.""" # Ignore the below fields in AST visitor/transformer. __ast_meta__ = { 'path_id', 'path_bonds', 'path_outputs', 'is_distinct', 'path_id_mask', 'path_namespace', 'packed_path_outputs', 'packed_path_namespace', } # The path id represented by the node. path_id: typing.Optional[irast.PathId] = None # Whether the node represents a distinct set. is_distinct: bool = True # A subset of paths necessary to perform joining. 
path_bonds: set[tuple[irast.PathId, bool]] = ast.field(factory=set) # Whether to ignore namespaces when looking at path outputs. # TODO: Maybe instead, Relation should have a way of specifying # output by PointerRef instead. strip_output_namespaces: bool = False # Map of res target names corresponding to paths. path_outputs: dict[ tuple[irast.PathId, PathAspect], OutputVar ] = ast.field(factory=dict) # Map of res target names corresponding to materialized paths. packed_path_outputs: typing.Optional[dict[ tuple[irast.PathId, PathAspect], OutputVar, ]] = None def get_path_outputs( self, flavor: str ) -> dict[tuple[irast.PathId, PathAspect], OutputVar]: if flavor == 'packed': if self.packed_path_outputs is None: self.packed_path_outputs = {} return self.packed_path_outputs elif flavor == 'normal': return self.path_outputs else: raise AssertionError(f'unexpected flavor "{flavor}"') path_id_mask: set[irast.PathId] = ast.field(factory=set) # Map of col refs corresponding to paths. path_namespace: dict[ tuple[irast.PathId, PathAspect], BaseExpr, ] = ast.field(factory=dict) # Same, but for packed. packed_path_namespace: typing.Optional[dict[ tuple[irast.PathId, PathAspect], BaseExpr, ]] = None class BaseRangeVar(ImmutableBaseExpr): """ Range variable, used in FROM clauses. This can be thought of as a specific instance of a table within a query. """ __ast_meta__ = {'schema_object_id', 'tag', 'ir_origins'} __ast_mutable_fields__ = frozenset(['ir_origins', 'span']) # This is a hack, since there is some code that relies on not # having an alias on a range var (to refer to a CTE directly, for # example), while other code depends on reading the alias name out # of range vars. This is mostly disjoint code, so we hack around it # with an empty aliasname.
alias: Alias = Alias(aliasname='') #: The id of the schema object this rvar represents schema_object_id: typing.Optional[uuid.UUID] = None #: Optional identification piece to describe what's inside the rvar tag: typing.Optional[str] = None #: Optional reference to the sets that this refers to #: Only used for helping recover information during explain. #: The type is a list of objects to help prevent any thought #: of using this field computationally during compilation. ir_origins: typing.Optional[list[object]] = None def __repr__(self) -> str: return ( f'' ) class BaseRelation(EdgeQLPathInfo, BaseExpr): """ A relation-valued (table-valued) expression. """ name: typing.Optional[str] = None nullable: typing.Optional[bool] = None # Whether the result can be NULL. class Relation(BaseRelation): """A reference to a table or a view.""" # The type or pointer this represents. # Should be non-None for any relation arising from a type or # pointer during compilation. type_or_ptr_ref: typing.Optional[irast.TypeRef | irast.PointerRef] = None catalogname: typing.Optional[str] = None schemaname: typing.Optional[str] = None is_temporary: typing.Optional[bool] = None class CommonTableExpr(Base): # Query name (unqualified) name: str # Whether the result can be NULL. nullable: typing.Optional[bool] = None # Optional list of column names aliascolnames: typing.Optional[list[str]] = None # The CTE query query: Query # True if this CTE is recursive recursive: bool = False # If specified, determines if CTE is [NOT] MATERIALIZED materialized: typing.Optional[bool] = None # the dml stmt that this CTE was generated for for_dml_stmt: typing.Optional[irast.MutatingLikeStmt] = None # marks the CTE that contains the output of a DML operation # (so it can be used in RETURNING and CommandComplete tag) output_of_dml: typing.Optional[irast.MutatingLikeStmt] = None def __repr__(self): return ( f'' ) class PathRangeVar(BaseRangeVar): #: The IR TypeRef this rvar represents (if any). 
typeref: typing.Optional[irast.TypeRef] = None @property def query(self) -> BaseRelation: raise NotImplementedError class RelRangeVar(PathRangeVar): """Relation range variable, used in FROM clauses.""" relation: BaseRelation | CommonTableExpr include_inherited: bool = True @property def query(self) -> BaseRelation: if isinstance(self.relation, CommonTableExpr): return self.relation.query else: return self.relation def __repr__(self) -> str: return ( f'' ) class IntersectionRangeVar(PathRangeVar): component_rvars: list[PathRangeVar] class DynamicRangeVarFunc(typing.Protocol): """A 'dynamic' range var that provides a callback hook for finding path_ids in range var. Used to sneak more complex search logic in. I am 100% going to regret this. Update: Sully says that he hasn't regretted it yet. """ # Lookup function for a DynamicRangeVar. If it returns a # PathRangeVar, keep looking in that rvar. If it returns # another expression, that's the output. def __call__( self, rel: Query, path_id: irast.PathId, *, flavor: str, aspect: str, env: typing.Any, ) -> typing.Optional[BaseExpr | PathRangeVar]: pass class DynamicRangeVar(PathRangeVar): dynamic_get_path: DynamicRangeVarFunc @property def query(self) -> BaseRelation: raise AssertionError('cannot retrieve query from a dynamic range var') # pickling is broken here, oh well def __getstate__(self) -> typing.Any: return () def __setstate__(self, state: typing.Any) -> None: self.dynamic_get_path = None # type: ignore class TypeName(ImmutableBase): """Type in definitions and casts.""" name: tuple[str, ...] # Type name setof: bool = False # SET OF? typmods: typing.Optional[list] = None # Type modifiers array_bounds: typing.Optional[list[int]] = None class ColumnRef(OutputVar): """Specifies a reference to a column.""" # Column name list. 
name: typing.Sequence[str | Star] # Whether the col is an optional path bond (i.e. accepted when NULL) optional: typing.Optional[bool] = None def __repr__(self): if hasattr(self, 'name'): return ( f'' ) else: return super().__repr__() class TupleElementBase(ImmutableBase): path_id: irast.PathId name: typing.Optional[OutputVar | str] def __init__( self, path_id: irast.PathId, name: typing.Optional[OutputVar | str] = None, ): self.path_id = path_id self.name = name def __repr__(self): return ( f'<{self.__class__.__name__} ' f'name={self.name} path_id={self.path_id}>' ) class TupleElement(TupleElementBase): val: BaseExpr def __init__( self, path_id: irast.PathId, val: BaseExpr, *, name: typing.Optional[OutputVar | str] = None, ): super().__init__(path_id, name) self.val = val def __repr__(self): return ( f'<{self.__class__.__name__} ' f'name={self.name} val={self.val} path_id={self.path_id}>' ) class TupleVarBase(OutputVar): elements: typing.Sequence[TupleElementBase] named: bool nullable: bool typeref: typing.Optional[irast.TypeRef] def __init__( self, elements: list[TupleElementBase], *, named: bool = False, nullable: bool = False, is_packed_multi: bool = False, typeref: typing.Optional[irast.TypeRef] = None, ): self.elements = elements self.named = named self.nullable = nullable self.is_packed_multi = is_packed_multi self.typeref = typeref def __repr__(self): return f'<{self.__class__.__name__} [{self.elements!r}]>' class TupleVar(TupleVarBase): elements: typing.Sequence[TupleElement] def __init__( self, elements: list[TupleElement], *, named: bool = False, nullable: bool = False, is_packed_multi: bool = False, typeref: typing.Optional[irast.TypeRef] = None, ): self.elements = elements self.named = named self.nullable = nullable self.is_packed_multi = is_packed_multi self.typeref = typeref class ParamRef(ImmutableBaseExpr): """Query parameter ($0..$n).""" __ast_mutable_fields__ = ( ImmutableBaseExpr.__ast_mutable_fields__ | frozenset(['number'])) # Number of the
parameter. number: int class ResTarget(ImmutableBaseExpr): """Query result target.""" # Column name (optional) name: typing.Optional[str] = None # value expression to compute val: BaseExpr class InsertTarget(ImmutableBaseExpr): """Column reference in INSERT.""" # Column name name: str class UpdateTarget(ImmutableBaseExpr): """Query update target.""" # column names name: str # value expression to assign val: BaseExpr # subscripts, field names and '*' indirection: typing.Optional[list[IndirectionOp]] = None class OnConflictTarget(ImmutableBaseExpr): # IndexElems to infer unique index index_elems: typing.Optional[list[IndexElem]] = None # Partial-index predicate index_where: typing.Optional[BaseExpr] = None # Constraint name constraint_name: typing.Optional[str] = None class IndexElem(ImmutableBaseExpr): expr: BaseExpr ordering: typing.Optional[qlast.SortOrder] = None nulls_ordering: typing.Optional[qlast.NonesOrder] = None class OnConflictAction(enum.StrEnum): DO_NOTHING = "DO_NOTHING" DO_UPDATE = "DO_UPDATE" class OnConflictClause(ImmutableBaseExpr): action: OnConflictAction target: typing.Optional[OnConflictTarget] = None update_list: typing.Optional[list[UpdateTarget | MultiAssignRef]] = None update_where: typing.Optional[BaseExpr] = None class ReturningQuery(BaseRelation): target_list: list[ResTarget] = ast.field(factory=list) class NullRelation(ReturningQuery): """Special relation that produces nulls for all its attributes.""" type_or_ptr_ref: typing.Optional[irast.TypeRef | irast.PointerRef] = None where_clause: typing.Optional[BaseExpr] = None @dataclasses.dataclass class Param: #: postgres' variable index index: int #: whether parameter is required required: bool #: index in the "logical" arg map logical_index: int class Query(ReturningQuery): """Generic superclass representing a query.""" # Ignore the below fields in AST visitor/transformer. 
__ast_meta__ = {'path_rvar_map', 'path_packed_rvar_map', 'view_path_id_map', 'argnames', 'nullable'} view_path_id_map: dict[ irast.PathId, irast.PathId ] = ast.field(factory=dict) # Map of RangeVars corresponding to paths. path_rvar_map: dict[ tuple[irast.PathId, PathAspect], PathRangeVar ] = ast.field(factory=dict) # Map of materialized RangeVars corresponding to paths. path_packed_rvar_map: typing.Optional[dict[ tuple[irast.PathId, PathAspect], PathRangeVar, ]] = None argnames: typing.Optional[dict[str, Param]] = None ctes: typing.Optional[list[CommonTableExpr]] = None def get_rvar_map( self, flavor: str ) -> dict[tuple[irast.PathId, PathAspect], PathRangeVar]: if flavor == 'packed': if self.path_packed_rvar_map is None: self.path_packed_rvar_map = {} return self.path_packed_rvar_map elif flavor == 'normal': return self.path_rvar_map else: raise AssertionError(f'unexpected flavor "{flavor}"') def maybe_get_rvar_map( self, flavor: str ) -> typing.Optional[ dict[tuple[irast.PathId, PathAspect], PathRangeVar] ]: if flavor == 'packed': return self.path_packed_rvar_map elif flavor == 'normal': return self.path_rvar_map else: raise AssertionError(f'unexpected flavor "{flavor}"') @property def ser_safe(self): if not self.target_list: return False return all(t.ser_safe for t in self.target_list) def append_cte(self, cte: CommonTableExpr) -> None: if self.ctes is None: self.ctes = [] self.ctes.append(cte) class DMLQuery(Query): """Generic superclass for INSERT/UPDATE/DELETE statements.""" __abstract_node__ = True # Target relation to perform the operation on. 
relation: RelRangeVar # List of expressions returned returning_list: list[ResTarget] = ast.field(factory=list) @property def target_list(self): return self.returning_list class InsertStmt(DMLQuery): # (optional) list of target column names cols: typing.Optional[list[InsertTarget]] = None # source SELECT/VALUES or None select_stmt: typing.Optional[Query] = None # ON CONFLICT clause on_conflict: typing.Optional[OnConflictClause] = None class UpdateStmt(DMLQuery): # The UPDATE target list targets: list[UpdateTarget | MultiAssignRef] = ast.field( factory=list ) # WHERE clause where_clause: typing.Optional[BaseExpr] = None # optional FROM clause from_clause: list[BaseRangeVar] = ast.field(factory=list) class DeleteStmt(DMLQuery): # WHERE clause where_clause: typing.Optional[BaseExpr] = None # optional USING clause using_clause: list[BaseRangeVar] = ast.field(factory=list) class SelectStmt(Query): # List of DISTINCT ON expressions, empty list for DISTINCT ALL distinct_clause: typing.Optional[typing.Sequence[OutputVar | Star]] = None # The FROM clause from_clause: list[BaseRangeVar] = ast.field(factory=list) # The WHERE clause where_clause: typing.Optional[BaseExpr] = None # GROUP BY clauses group_clause: typing.Optional[list[Base]] = None # HAVING expression having_clause: typing.Optional[BaseExpr] = None # WINDOW window_name AS(...), window_clause: typing.Optional[list[Base]] = None # List of ImplicitRow's in a VALUES query values: typing.Optional[list[Base]] = None # ORDER BY clause sort_clause: typing.Optional[list[SortBy]] = None # OFFSET expression limit_offset: typing.Optional[BaseExpr] = None # LIMIT expression limit_count: typing.Optional[BaseExpr] = None # FOR UPDATE clause locking_clause: typing.Optional[list[LockingClause]] = None # Set operation type op: typing.Optional[str] = None # ALL modifier all: bool = False # Left operand of set op larg: typing.Optional[Query] = None # Right operand of set op, rarg: typing.Optional[Query] = None # When used as a 
sub-query, it is generally nullable. nullable: bool = True class Expr(ImmutableBaseExpr): """Infix, prefix, and postfix expressions.""" # Possibly-qualified name of operator name: str # Left argument, if any lexpr: typing.Optional[BaseExpr] = None # Right argument, if any rexpr: typing.Optional[BaseExpr] = None class BaseConstant(ImmutableBaseExpr): pass class StringConstant(BaseConstant): """A literal string constant.""" # Constant value val: str class NullConstant(BaseConstant): """A NULL constant.""" nullable: bool = True class BitStringConstant(BaseConstant): """A bit string constant.""" # x or b kind: str val: str class ByteaConstant(BaseConstant): """A bytea string.""" val: bytes class NumericConstant(BaseConstant): val: str class BooleanConstant(BaseConstant): val: bool class LiteralExpr(ImmutableBaseExpr): """A literal expression.""" # Expression text expr: str class TypeCast(ImmutableBaseExpr): """A CAST expression.""" # Expression being cast. arg: BaseExpr # Target type. type_name: TypeName class CollateClause(ImmutableBaseExpr): """A COLLATE expression.""" # Input expression arg: BaseExpr # Possibly-qualified collation name collname: typing.Sequence[str] class VariadicArgument(ImmutableBaseExpr): expr: BaseExpr nullable: bool = False class TableElement(ImmutableBase): pass class ColumnDef(TableElement): # name of column name: str # type of column typename: TypeName # default value, if any default_expr: typing.Optional[BaseExpr] = None # COLLATE clause, if any coll_clause: typing.Optional[BaseExpr] = None # NOT NULL is_not_null: bool = False class FuncCall(ImmutableBaseExpr): # Function name name: tuple[str, ...] # List of arguments args: list[BaseExpr] # ORDER BY agg_order: typing.Optional[list[SortBy]] # FILTER clause agg_filter: typing.Optional[BaseExpr] # Argument list is '*' agg_star: bool # Arguments were labeled DISTINCT agg_distinct: bool # agg_order is in WITHIN GROUP (...)
agg_within_group: bool = False # OVER clause, if any over: typing.Optional[WindowDef] # WITH ORDINALITY with_ordinality: bool = False # list of Columndef nodes to describe result of # the function returning RECORD. coldeflist: list[ColumnDef] def __init__( self, *, nullable: typing.Optional[bool] = None, null_safe: bool = False, **kwargs, ) -> None: """Function call node. @param null_safe: Specifies whether this function is guaranteed to never return NULL on non-NULL input. """ if nullable is None and not null_safe: nullable = True super().__init__(nullable=nullable, **kwargs) class NamedFuncArg(ImmutableBaseExpr): name: str val: BaseExpr # N.B: Index and Slice aren't *really* Exprs but we mark them as such # so that nullability inference gets done on them. class Index(ImmutableBaseExpr): """Array subscript.""" idx: BaseExpr class Slice(ImmutableBaseExpr): """Array slice bounds.""" # Lower bound, if any lidx: typing.Optional[BaseExpr] # Upper bound if any ridx: typing.Optional[BaseExpr] class RecordIndirectionOp(ImmutableBase): name: str IndirectionOp = Slice | Index | Star | RecordIndirectionOp class Indirection(ImmutableBaseExpr): """Field and/or array element indirection.""" # Indirection subject arg: BaseExpr # Subscripts and/or field names and/or '*' indirection: list[IndirectionOp] class ArrayExpr(ImmutableBaseExpr): """ARRAY[] construct.""" # array element expressions elements: list[BaseExpr] class ArrayDimension(ImmutableBaseExpr): """An array dimension""" elements: list[BaseExpr] class MultiAssignRef(ImmutableBase): """UPDATE (a, b, c) = row-valued-expr.""" # row-valued expression source: BaseExpr # list of columns to assign to columns: list[str] class SortBy(ImmutableBase): """ORDER BY clause element.""" # expression to sort on node: BaseExpr # ASC/DESC/USING/default dir: typing.Optional[qlast.SortOrder] = None # NULLS FIRST/LAST nulls: typing.Optional[qlast.NonesOrder] = None class LockClauseStrength(enum.StrEnum): UPDATE = "UPDATE" NO_KEY_UPDATE = "NO 
KEY UPDATE" SHARE = "SHARE" KEY_SHARE = "KEY SHARE" class LockWaitPolicy(enum.StrEnum): LockWaitBlock = "" LockWaitSkip = "SKIP LOCKED" LockWaitError = "NOWAIT" class LockingClause(ImmutableBase): """Locking clause element (FOR ... )""" strength: LockClauseStrength "lock strength" locked_rels: typing.Optional[list[RelRangeVar]] = None "locked relations" wait_policy: typing.Optional[LockWaitPolicy] = None "lock wait policy" class WindowDef(ImmutableBase): """WINDOW and OVER clauses.""" # window name name: typing.Optional[str] = None # referenced window name, if any refname: typing.Optional[str] = None # PARTITION BY expr list partition_clause: typing.Optional[list[BaseExpr]] = None # ORDER BY order_clause: typing.Optional[list[SortBy]] = None # Window frame options frame_options: typing.Optional[list] = None # expression for starting bound, if any start_offset: typing.Optional[BaseExpr] = None # expression for ending bound, if any end_offset: typing.Optional[BaseExpr] = None class RangeSubselect(PathRangeVar): """Subquery appearing in FROM clauses.""" # Before postgres 16, an alias is always required on selects from # a subquery. Try to catch that with the typechecker by getting # rid of the default value.
alias: Alias lateral: bool = False subquery: Query @property def query(self) -> Query: return self.subquery class RangeFunction(BaseRangeVar): lateral: bool = False # WITH ORDINALITY with_ordinality: bool = False # ROWS FROM form is_rowsfrom: bool = False functions: list[BaseExpr] class JoinClause(BaseRangeVar): # Type of join type: str # Right subtree rarg: BaseRangeVar # USING clause, if any using_clause: typing.Optional[list[ColumnRef]] = None # Qualifiers on join, if any quals: typing.Optional[BaseExpr] = None class JoinExpr(BaseRangeVar): # Left subtree larg: BaseRangeVar # Join clauses # We represent joins as being N-ary to avoid recursing too deeply joins: list[JoinClause] @classmethod def make_inplace( cls, *, larg: BaseRangeVar, type: str, rarg: BaseRangeVar, using_clause: typing.Optional[list[ColumnRef]] = None, quals: typing.Optional[BaseExpr] = None, ) -> JoinExpr: clause = JoinClause( type=type, rarg=rarg, using_clause=using_clause, quals=quals ) if isinstance(larg, JoinExpr): larg.joins.append(clause) return larg else: return JoinExpr(larg=larg, joins=[clause]) class SubLink(ImmutableBaseExpr): """Subselect appearing in an expression.""" # Sublink expression test_expr: typing.Optional[BaseExpr] = None # EXISTS, NOT_EXISTS, ALL, ANY operator: typing.Optional[str] # Sublink expression expr: BaseExpr # Sublink is never NULL nullable: bool = False class RowExpr(ImmutableBaseExpr): """A ROW() expression.""" # The fields. args: list[BaseExpr] # Row expressions, while they may contain NULLs, are not NULL themselves. nullable: bool = False class ImplicitRowExpr(ImmutableBaseExpr): """A (a, b, c) expression.""" # The fields. args: typing.Sequence[BaseExpr] # Row expressions, while they may contain NULLs, are not NULL themselves. nullable: bool = False class CoalesceExpr(ImmutableBaseExpr): """A COALESCE() expression.""" # The arguments.
args: list[Base] def _infer_nullability(self, kwargs: dict[str, typing.Any]) -> bool: # nullability of COALESCE is the nullability of the RHS if 'args' in kwargs: return kwargs['args'][1].nullable else: return True class NullTest(ImmutableBaseExpr): """IS [NOT] NULL.""" # Input expression, arg: BaseExpr # NOT NULL? negated: bool = False # NullTest is never NULL nullable: bool = False class BooleanTest(ImmutableBaseExpr): """IS [NOT] {TRUE,FALSE}""" # Input expression, arg: BaseExpr negated: bool = False is_true: bool = False # BooleanTest is never NULL nullable: bool = False class CaseWhen(ImmutableBase): # Condition expression expr: BaseExpr # substitution result result: BaseExpr class CaseExpr(ImmutableBaseExpr): # Equality comparison argument arg: typing.Optional[BaseExpr] = None # List of WHEN clauses args: list[CaseWhen] # ELSE clause defresult: typing.Optional[BaseExpr] = None class GroupingOperation(Base): operation: typing.Optional[str] = None args: list[Base] SortAsc = qlast.SortAsc SortDesc = qlast.SortDesc SortDefault = qlast.SortDefault NullsFirst = qlast.NonesFirst NullsLast = qlast.NonesLast class AlterSystem(ImmutableBaseExpr): name: str value: typing.Optional[BaseExpr] class Set(ImmutableBaseExpr): name: str value: BaseExpr class ConfigureDatabase(ImmutableBase): database_name: str parameter_name: str value: BaseExpr class IteratorCTE(ImmutableBase): path_id: irast.PathId cte: CommonTableExpr parent: typing.Optional[IteratorCTE] # A list of other paths to *also* register the iterator rvar as # providing when it is merged into a statement. other_paths: tuple[tuple[irast.PathId, PathAspect], ...]
= () iterator_bond: bool = False @property def aspect(self) -> PathAspect: from .compiler import enums as pgce return ( pgce.PathAspect.ITERATOR if self.iterator_bond else pgce.PathAspect.IDENTITY ) class Statement(Base): """A statement that does not return a relation""" pass class VariableSetStmt(Statement): name: str args: ArgsList scope: OptionsScope class ArgsList(Base): args: list[BaseExpr] class VariableResetStmt(Statement): name: typing.Optional[str] scope: OptionsScope class SetTransactionStmt(Statement): """A special case of VariableSetStmt""" options: TransactionOptions scope: OptionsScope class VariableShowStmt(Statement): name: str class TransactionStmt(Statement): pass class OptionsScope(enum.IntEnum): TRANSACTION = enum.auto() SESSION = enum.auto() class BeginStmt(TransactionStmt): options: typing.Optional[TransactionOptions] class StartStmt(TransactionStmt): options: typing.Optional[TransactionOptions] class CommitStmt(TransactionStmt): chain: typing.Optional[bool] class RollbackStmt(TransactionStmt): chain: typing.Optional[bool] class SavepointStmt(TransactionStmt): savepoint_name: str class ReleaseStmt(TransactionStmt): savepoint_name: str class RollbackToStmt(TransactionStmt): savepoint_name: str class TwoPhaseTransactionStmt(TransactionStmt): gid: str class PrepareTransaction(TwoPhaseTransactionStmt): pass class CommitPreparedStmt(TwoPhaseTransactionStmt): pass class RollbackPreparedStmt(TwoPhaseTransactionStmt): pass class TransactionOptions(Base): options: dict[str, BaseExpr] class PrepareStmt(Statement): name: str argtypes: typing.Optional[list[Base]] query: BaseRelation class ExecuteStmt(Statement): name: str params: typing.Optional[list[Base]] class DeallocateStmt(Statement): name: str class SQLValueFunctionOP(enum.IntEnum): CURRENT_DATE = enum.auto() CURRENT_TIME = enum.auto() CURRENT_TIME_N = enum.auto() CURRENT_TIMESTAMP = enum.auto() CURRENT_TIMESTAMP_N = enum.auto() LOCALTIME = enum.auto() LOCALTIME_N = enum.auto() LOCALTIMESTAMP = 
enum.auto() LOCALTIMESTAMP_N = enum.auto() CURRENT_ROLE = enum.auto() CURRENT_USER = enum.auto() USER = enum.auto() SESSION_USER = enum.auto() CURRENT_CATALOG = enum.auto() CURRENT_SCHEMA = enum.auto() class SQLValueFunction(BaseExpr): op: SQLValueFunctionOP arg: typing.Optional[BaseExpr] = None class CreateStmt(Statement): relation: Relation table_elements: list[TableElement] on_commit: typing.Optional[str] class CreateTableAsStmt(Statement): into: CreateStmt query: Query with_no_data: bool class MinMaxExpr(BaseExpr): # GREATEST / LEAST expression # Very similar to FuncCall, except that the name is not escaped op: str args: list[BaseExpr] class LockStmt(Statement): relations: list[BaseRangeVar] mode: str no_wait: bool = False class CopyFormat(enum.IntEnum): TEXT = enum.auto() CSV = enum.auto() BINARY = enum.auto() class CopyOptions(Base): # Options for the copy command format: typing.Optional[CopyFormat] = None freeze: typing.Optional[bool] = None delimiter: typing.Optional[str] = None null: typing.Optional[str] = None header: typing.Optional[bool] = None quote: typing.Optional[str] = None escape: typing.Optional[str] = None force_quote: list[str] = [] force_not_null: list[str] = [] force_null: list[str] = [] encoding: typing.Optional[str] = None class CopyStmt(Statement): relation: typing.Optional[Relation] colnames: typing.Optional[list[str]] query: typing.Optional[Query] is_from: bool = False is_program: bool = False filename: typing.Optional[str] options: CopyOptions where_clause: typing.Optional[BaseExpr] = None class FTSDocument(BaseExpr): """ Text and information on how to search through it. Constructed with `std::fts::with_options`. """ text: BaseExpr language: BaseExpr language_domain: set[str] weight: typing.Optional[str] ================================================ FILE: edb/pgsql/codegen.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. 
and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any, Optional, Sequence import abc import collections import dataclasses from edb import errors from edb.pgsql import common from edb.pgsql import ast as pgast from edb.common.ast import codegen from edb.common import exceptions from edb.common import markup def generate( node: pgast.Base, *, indent_with: str = ' ' * 4, add_line_information: bool = False, pretty: bool = True, reordered: bool = False, with_source_map: bool = False, ) -> SQLSource: # Main entrypoint generator = SQLSourceGenerator( opts=codegen.Options( indent_with=indent_with, add_line_information=add_line_information, pretty=pretty, ), reordered=reordered, with_source_map=with_source_map, ) try: generator.visit(node) except RecursionError: # Don't try to wrap and add context to a recursion error, # since the context might easily be too deeply recursive to # process further down the pipe. 
raise except GeneratorError as error: ctx = GeneratorContext(node, generator.result) exceptions.add_context(error, ctx) raise except Exception as error: ctx = GeneratorContext(node, generator.result) err = GeneratorError('error while generating SQL source') exceptions.add_context(err, ctx) raise err from error if with_source_map: assert generator.source_map return SQLSource( text=generator.finish(), source_map=generator.source_map, param_index=generator.param_index, ) def generate_source( node: pgast.Base, *, indent_with: str = ' ' * 4, add_line_information: bool = False, pretty: bool = False, reordered: bool = False, ) -> str: # Simplified entrypoint source = generate( node, indent_with=indent_with, add_line_information=add_line_information, pretty=pretty, reordered=reordered, ) return source.text def generate_ctes_source( ctes: list[pgast.CommonTableExpr], ) -> str: # Alternative simplified entrypoint generating 'WITH a AS (...)' only. generator = SQLSourceGenerator(opts=codegen.Options()) generator.gen_ctes(ctes) return generator.finish() class SourceMap: @abc.abstractmethod def translate(self, pos: int) -> int: ... 
@dataclasses.dataclass(kw_only=True) class BaseSourceMap(SourceMap): source_start: int output_start: int output_end: int | None = None children: list[BaseSourceMap] = ( dataclasses.field(default_factory=list)) def translate(self, pos: int) -> int: bu = None for u in self.children: if u.output_start >= pos: break bu = u if bu and (bu.output_end is None or bu.output_end > pos): return bu.translate(pos) return self.source_start @dataclasses.dataclass class ChainedSourceMap(SourceMap): parts: list[SourceMap] = ( dataclasses.field(default_factory=list)) def translate(self, pos: int) -> int: for part in self.parts: pos = part.translate(pos) return pos @dataclasses.dataclass(frozen=True) class SQLSource: text: str param_index: dict[int, list[int]] source_map: Optional[SourceMap] = None class SQLSourceGenerator(codegen.SourceGenerator): def __init__( self, opts: codegen.Options, *, with_source_map: bool = False, reordered: bool = False, ): super().__init__( indent_with=opts.indent_with, add_line_information=opts.add_line_information, pretty=opts.pretty, ) self.is_toplevel = True # params self.with_source_map: bool = with_source_map self.reordered = reordered # state self.param_index: collections.defaultdict[int, list[int]] = ( collections.defaultdict(list)) self.write_index: int = 0 self.source_map: Optional[BaseSourceMap] = None def write( self, *x: str, delimiter: Optional[str] = None, ) -> None: self.is_toplevel = False start = len(self.result) super().write(*x, delimiter=delimiter) for new in range(start, len(self.result)): self.write_index += len(self.result[new]) def visit(self, node): # type: ignore if self.with_source_map: source_map = BaseSourceMap( source_start=node.span.start if node.span else 0, output_start=self.write_index, ) old_top = self.source_map self.source_map = source_map super().visit(node) if self.with_source_map: assert self.source_map == source_map self.source_map.output_end = self.write_index if old_top: old_top.children.append(self.source_map) 
self.source_map = old_top def generic_visit(self, node): # type: ignore raise GeneratorError( 'No method to generate code for %s' % node.__class__.__name__ ) def gen_ctes(self, ctes: list[pgast.CommonTableExpr]) -> None: count = len(ctes) for i, cte in enumerate(ctes): self.new_lines = 1 if i == 0 and getattr(cte, 'recursive', None): self.write('RECURSIVE ') self.write(common.quote_ident(cte.name)) if cte.aliascolnames: self.write('(') for (index, col_name) in enumerate(cte.aliascolnames): self.write(common.qname(col_name, column=True)) if index + 1 < len(cte.aliascolnames): self.write(',') self.write(')') self.write(' AS ') if cte.materialized is not None: if cte.materialized: self.write('MATERIALIZED ') else: self.write('NOT MATERIALIZED ') self.indentation += 1 self.new_lines = 1 self.write('(') self.visit(cte.query) self.write(')') if i != count - 1: self.write(',') self.indentation -= 1 self.new_lines = 1 def visit__Ref(self, node): # type: ignore self.visit(node.node) def visit_Relation(self, node: pgast.Relation) -> None: assert node.name if node.schemaname is None: self.write(common.qname(node.name)) else: self.write(common.qname(node.schemaname, node.name)) def visit_NullRelation(self, node: pgast.NullRelation) -> None: self.write('(SELECT ') if node.target_list: self.visit_list(node.target_list) if node.where_clause: self.indentation += 1 self.new_lines = 1 self.write('WHERE') self.new_lines = 1 self.indentation += 1 self.visit(node.where_clause) self.indentation -= 2 self.write(')') def visit_SelectStmt(self, node: pgast.SelectStmt) -> None: parenthesize = not self.is_toplevel if parenthesize: if not self.reordered and self.result: self.new_lines = 1 self.write('(') if self.reordered: self.new_lines = 1 if not node.op: self.indentation += 1 if node.ctes: self.write('WITH ') self.gen_ctes(node.ctes) if node.values: self.write('VALUES') self.new_lines = 1 self.visit_list(node.values) if parenthesize: self.new_lines = 1 if self.reordered and not node.op: 
self.indentation -= 1 self.write(')') return # If reordered is True, we try to put the FROM clause *before* SELECT, # like it *ought* to be. We do various hokey things to try to make # that look good. # Otherwise we emit real SQL. def _select() -> None: self.write('SELECT') if node.distinct_clause: self.write(' DISTINCT') if len(node.distinct_clause) > 1 or not isinstance( node.distinct_clause[0], pgast.Star ): self.write(' ON (') self.visit_list(node.distinct_clause, newlines=False) self.write(')') if self.pretty: self.write('/*', repr(node), '*/') self.new_lines = 1 if node.op: # Upper level set operation node (UNION/INTERSECT) # HACK: The LHS of a set operation is *not* top-level, and # shouldn't be treated as such. Since we (also hackily) # use whether anything has been written to determine # whether we are at the top level, clear the top-level # flag explicitly to force parenthesization. self.is_toplevel = False self.visit(node.larg) self.write(' ' + node.op + ' ') if node.all: self.write('ALL ') self.visit(node.rarg) else: if not self.reordered: _select() self.indentation += 2 if not self.reordered: if node.target_list: self.visit_list(node.target_list) if not node.op: self.indentation -= 2 if node.from_clause: if not self.reordered: self.indentation += 1 self.new_lines = 1 self.write('FROM ') if not self.reordered: self.new_lines = 1 self.indentation += 1 self.visit_list(node.from_clause) if self.reordered: self.new_lines = 1 else: self.indentation -= 2 if self.reordered and not node.op: _select() self.indentation += 1 if node.target_list: self.visit_list(node.target_list) # In reordered mode, we don't want to indent the clauses, # so we overreduce the indentation at this point and fix # it up at the end self.indentation -= 2 if node.where_clause: self.indentation += 1 self.new_lines = 1 self.write('WHERE') self.new_lines = 1 self.indentation += 1 self.visit(node.where_clause) self.indentation -= 2 if node.group_clause: self.indentation += 1 self.new_lines = 1 
self.write('GROUP BY') self.new_lines = 1 self.indentation += 1 self.visit_list(node.group_clause) self.indentation -= 2 if node.having_clause: self.indentation += 1 self.new_lines = 1 self.write('HAVING') self.new_lines = 1 self.indentation += 1 self.visit(node.having_clause) self.indentation -= 2 if node.sort_clause: self.indentation += 1 self.new_lines = 1 self.write('ORDER BY') self.new_lines = 1 self.indentation += 1 self.visit_list(node.sort_clause) self.indentation -= 2 if node.limit_offset: self.indentation += 1 self.new_lines = 1 self.write('OFFSET ') self.visit(node.limit_offset) self.indentation -= 1 if node.limit_count: self.indentation += 1 self.new_lines = 1 self.write('LIMIT ') self.visit(node.limit_count) self.indentation -= 1 if node.locking_clause: self.indentation += 1 self.new_lines = 1 self.visit_list(node.locking_clause, separator=" ") self.indentation -= 1 if self.reordered and not node.op: self.indentation += 1 if parenthesize: self.new_lines = 1 if self.reordered and not node.op: self.indentation -= 1 self.write(')') def visit_InsertStmt(self, node: pgast.InsertStmt) -> None: if node.ctes: self.write('WITH ') self.gen_ctes(node.ctes) self.write('INSERT INTO ') self.visit(node.relation) if node.cols: self.new_lines = 1 self.indentation += 1 self.write('(') self.visit_list(node.cols, newlines=False) self.write(')') self.indentation -= 1 self.indentation += 1 self.new_lines = 1 if node.select_stmt: if ( isinstance(node.select_stmt, pgast.SelectStmt) and node.select_stmt.values ): self.write('VALUES ') self.new_lines = 1 self.indentation += 1 self.visit_list(node.select_stmt.values) self.indentation -= 1 else: self.write('(') self.visit(node.select_stmt) self.write(')') else: self.write('DEFAULT VALUES') if node.on_conflict: self.new_lines = 1 self.write('ON CONFLICT') if node.on_conflict.target: self.visit(node.on_conflict.target) self.write(' DO ') self.write(node.on_conflict.action[3:]) if node.on_conflict.update_list: self.write(' SET') 
self.new_lines = 1 self.indentation += 1 self.visit_list(node.on_conflict.update_list) self.indentation -= 1 if node.on_conflict.update_where: self.write(' WHERE ') self.indentation += 1 self.visit(node.on_conflict.update_where) self.indentation -= 1 if node.returning_list: self.new_lines = 1 self.write('RETURNING') self.new_lines = 1 self.indentation += 1 self.visit_list(node.returning_list) self.indentation -= 1 self.indentation -= 1 def visit_UpdateStmt(self, node: pgast.UpdateStmt) -> None: if node.ctes: self.write('WITH ') self.gen_ctes(node.ctes) self.write('UPDATE ') self.new_lines = 1 self.indentation += 1 self.visit(node.relation) self.indentation -= 1 self.new_lines = 1 self.write('SET') self.new_lines = 1 self.indentation += 1 self.visit_list(node.targets) self.indentation -= 1 if node.from_clause: self.new_lines = 1 self.write('FROM') self.new_lines = 1 self.indentation += 1 self.visit_list(node.from_clause) self.indentation -= 1 if node.where_clause: self.new_lines = 1 self.write('WHERE') self.new_lines = 1 self.indentation += 1 self.visit(node.where_clause) self.new_lines = 1 self.indentation -= 1 if node.returning_list: self.new_lines = 1 self.write('RETURNING') self.new_lines = 1 self.indentation += 1 self.visit_list(node.returning_list) self.indentation -= 1 def visit_DeleteStmt(self, node: pgast.DeleteStmt) -> None: if node.ctes: self.write('WITH ') self.gen_ctes(node.ctes) self.write('DELETE FROM ') self.new_lines = 1 self.indentation += 1 self.visit(node.relation) self.indentation -= 1 if node.using_clause: self.new_lines = 1 self.write('USING') self.new_lines = 1 self.indentation += 1 self.visit_list(node.using_clause) self.indentation -= 1 if node.where_clause: self.new_lines = 1 self.write('WHERE') self.new_lines = 1 self.indentation += 1 self.visit(node.where_clause) self.new_lines = 1 self.indentation -= 1 if node.returning_list: self.new_lines = 1 self.write('RETURNING') self.new_lines = 1 self.indentation += 1 
self.visit_list(node.returning_list) self.indentation -= 1 def visit_OnConflictTarget(self, node: pgast.OnConflictTarget) -> None: assert not node.constraint_name or not node.index_elems if node.constraint_name: self.write(' ON CONSTRAINT ') self.write(node.constraint_name) if node.index_elems: self.write(' (') self.visit_list(node.index_elems) self.write(')') if node.index_where: self.write(' WHERE ') self.visit(node.index_where) def visit_IndexElem(self, node: pgast.IndexElem) -> None: self.visit(node.expr) if node.ordering == pgast.SortAsc: self.write(' ASC') elif node.ordering == pgast.SortDesc: self.write(' DESC') if node.nulls_ordering == pgast.NullsFirst: self.write(' NULLS FIRST') elif node.nulls_ordering == pgast.NullsLast: self.write(' NULLS LAST') def visit_MultiAssignRef(self, node: pgast.MultiAssignRef) -> None: self.write('(') for index, col in enumerate(node.columns): if index > 0: self.write(', ') self.write(common.quote_col(col)) self.write(') = ') self.visit(node.source) def visit_LiteralExpr(self, node: pgast.LiteralExpr) -> None: self.write(node.expr) def visit_ResTarget(self, node: pgast.ResTarget) -> None: self.visit(node.val) if node.name: self.write(' AS ' + common.quote_col(node.name)) def visit_InsertTarget(self, node: pgast.InsertTarget) -> None: self.write(common.quote_col(node.name)) def visit_UpdateTarget(self, node: pgast.UpdateTarget) -> None: if isinstance(node.name, list): self.write('(') self.write(', '.join(common.quote_col(n) for n in node.name)) self.write(')') else: self.write(common.quote_col(node.name)) if node.indirection: self._visit_indirection_ops(node.indirection) self.write(' = ') self.visit(node.val) def visit_Alias(self, node: pgast.Alias) -> None: self.write(common.quote_ident(node.aliasname)) if node.colnames: self.write('(') self.write(', '.join(common.quote_col(n) for n in node.colnames)) self.write(')') def visit_Keyword(self, node: pgast.Keyword) -> None: self.write(node.name) def visit_RelRangeVar(self, node: 
pgast.RelRangeVar) -> None: rel = node.relation if not node.include_inherited: self.write(' ONLY (') if isinstance(rel, (pgast.Relation, pgast.NullRelation)): self.visit(rel) elif isinstance(rel, pgast.CommonTableExpr): self.write(common.quote_ident(rel.name)) else: raise GeneratorError( 'unexpected relation in RelRangeVar: {!r}'.format(rel) ) if not node.include_inherited: self.write(')') if node.alias and node.alias.aliasname: self.write(' AS ') self.visit(node.alias) def visit_RangeSubselect(self, node: pgast.RangeSubselect) -> None: if node.lateral: self.write('LATERAL ') self.visit(node.subquery) if node.alias and node.alias.aliasname: self.write(' AS ') self.visit(node.alias) def visit_RangeFunction(self, node: pgast.RangeFunction) -> None: if node.lateral: self.write('LATERAL ') if node.is_rowsfrom: self.write('ROWS FROM (') self.visit_list(node.functions) if node.is_rowsfrom: self.write(')') if node.with_ordinality: self.write(' WITH ORDINALITY ') if node.alias and node.alias.aliasname: self.write(' AS ') self.visit(node.alias) def visit_ColumnRef(self, node: pgast.ColumnRef) -> None: names = node.name if isinstance(names[-1], pgast.Star): self.write(common.qname(*names)) else: if names == ['VALUE']: self.write('VALUE') elif names[0] in {'OLD', 'NEW'}: assert isinstance(names[0], str) self.write(names[0]) if len(names) > 1: self.write('.') self.write(common.qname(*names[1:], column=True)) else: self.write(common.qname(*names, column=True)) def visit_ExprOutputVar(self, node: pgast.ExprOutputVar) -> None: self.visit(node.expr) def visit_ColumnDef(self, node: pgast.ColumnDef) -> None: self.write(common.quote_col(node.name)) if node.typename: self.write(' ') self.visit(node.typename) if node.is_not_null: self.write(' NOT NULL') if node.default_expr: self.write(' DEFAULT ') self.visit(node.default_expr) def visit_GroupingOperation(self, node: pgast.GroupingOperation) -> None: if node.operation: self.write(node.operation) self.write(' ') self.write('(') 
self.visit_list(node.args, newlines=False) self.write(')') def visit_JoinExpr(self, node: pgast.JoinExpr) -> None: self.visit(node.larg) for join in node.joins: self.new_lines = 1 if not join.quals and not join.using_clause: join_type = 'CROSS' else: join_type = join.type.upper() if join_type == 'INNER': self.write('JOIN ') else: self.write(join_type + ' JOIN ') nested_join = ( isinstance(join.rarg, pgast.JoinExpr) and join.rarg.joins ) if nested_join: self.write('(') self.new_lines = 1 self.indentation += 1 self.visit(join.rarg) if nested_join: self.indentation -= 1 self.new_lines = 1 self.write(')') if join.quals is not None: if not nested_join: self.indentation += 1 self.new_lines = 1 self.write('ON ') else: self.write(' ON ') self.visit(join.quals) if not nested_join: self.indentation -= 1 elif join.using_clause: self.write(" USING (") self.visit_list(join.using_clause) self.write(")") def visit_Expr(self, node: pgast.Expr) -> None: self.write('(') if node.lexpr is not None: self.visit(node.lexpr) self.write(' ') op = str(node.name) if '.' 
not in op: op = op.upper() self.write(op) if node.rexpr is not None: self.write(" ") self.visit_indented(node.rexpr, indent=op in {"OR", "AND"}) self.write(")") def visit_NullConstant(self, _node: pgast.NullConstant) -> None: self.write('NULL') def visit_NumericConstant(self, node: pgast.NumericConstant) -> None: self.write(node.val) def visit_BooleanConstant(self, node: pgast.BooleanConstant) -> None: self.write('TRUE' if node.val else 'FALSE') def visit_StringConstant(self, node: pgast.StringConstant) -> None: self.write(common.quote_literal(node.val)) def visit_BitStringConstant(self, node: pgast.BitStringConstant) -> None: self.write(f"{node.kind}'{node.val}'") def visit_ByteaConstant(self, node: pgast.ByteaConstant) -> None: self.write(common.quote_bytea_literal(node.val)) def visit_ParamRef(self, node: pgast.ParamRef) -> None: self.write(f'${node.number}') self.param_index[node.number].append(len(self.result) - 1) def visit_RowExpr(self, node: pgast.RowExpr) -> None: self.write('ROW(') self.visit_list(node.args, newlines=False) self.write(')') def visit_ImplicitRowExpr(self, node: pgast.ImplicitRowExpr) -> None: self.write('(') self.visit_list(node.args, newlines=False) self.write(')') def visit_ArrayExpr(self, node: pgast.ArrayExpr) -> None: self.write('ARRAY[') self.visit_list(node.elements, newlines=False) self.write(']') def visit_ArrayDimension(self, node: pgast.ArrayDimension) -> None: self.write('[') self.visit_list(node.elements, newlines=False) self.write(']') def visit_VariadicArgument(self, node: pgast.VariadicArgument) -> None: self.write('VARIADIC ') self.visit(node.expr) def visit_FuncCall(self, node: pgast.FuncCall) -> None: self.write(common.qname(*node.name)) self.write('(') if node.agg_star: self.write("*") elif node.agg_distinct: self.write('DISTINCT ') self.visit_list(node.args, newlines=False) if node.agg_order and not node.agg_within_group: self.write(' ORDER BY ') self.visit_list(node.agg_order, newlines=False) self.write(')') if 
node.agg_order and node.agg_within_group: self.write(' WITHIN GROUP (ORDER BY ') self.visit_list(node.agg_order, newlines=False) self.write(')') if node.agg_filter: self.write(' FILTER (WHERE ') self.visit(node.agg_filter) self.write(')') if node.over: self.write(' OVER (') if node.over.partition_clause: self.write('PARTITION BY ') self.visit_list(node.over.partition_clause, newlines=False) if node.over.order_clause: self.write(' ORDER BY ') self.visit_list(node.over.order_clause, newlines=False) # XXX: add support for frame definition self.write(')') if node.with_ordinality: self.write(' WITH ORDINALITY') if node.coldeflist: self.write(' AS (') self.visit_list(node.coldeflist, newlines=False) self.write(')') def visit_NamedFuncArg(self, node: pgast.NamedFuncArg) -> None: self.write(common.quote_ident(node.name), ' => ') self.visit(node.val) def visit_SubLink(self, node: pgast.SubLink) -> None: if node.test_expr: self.visit(node.test_expr) if node.operator: self.write(" " + node.operator + " ") self.visit_indented(node.expr, indent=True, nest=True) def visit_SortBy(self, node: pgast.SortBy) -> None: self.visit(node.node) if node.dir: direction = 'ASC' if node.dir == pgast.SortAsc else 'DESC' self.write(' ' + direction) if node.nulls is None: if node.dir == pgast.SortDesc: self.write(' NULLS LAST') else: self.write(' NULLS FIRST') elif node.nulls == pgast.NullsFirst: self.write(' NULLS FIRST') elif node.nulls == pgast.NullsLast: self.write(' NULLS LAST') else: raise GeneratorError( 'unexpected NULLS order: {}'.format(node.nulls) ) def visit_LockingClause(self, node: pgast.LockingClause) -> None: self.write("FOR ", str(node.strength)) if node.locked_rels: self.write(" OF ") self.visit_list(node.locked_rels) if node.wait_policy is not None: if kw := str(node.wait_policy): self.write(f" {kw}") def visit_TypeCast(self, node: pgast.TypeCast) -> None: # '::' has very high precedence, so parenthesize the expression. 
self.write('(') self.visit(node.arg) self.write(')') self.write('::') self.visit(node.type_name) def visit_TypeName(self, node: pgast.TypeName) -> None: self.write(common.quote_type(node.name)) if node.array_bounds: for array_bound in node.array_bounds: self.write('[') if array_bound >= 0: self.write(str(array_bound)) self.write(']') def visit_Star(self, _: pgast.Star) -> None: self.write('*') def visit_CaseExpr(self, node: pgast.CaseExpr) -> None: self.write('(CASE ') if node.arg: self.visit(node.arg) self.write(' ') for arg in node.args: self.visit(arg) self.new_lines = 1 if node.defresult: self.write('ELSE ') self.visit(node.defresult) self.new_lines = 1 self.write('END)') def visit_CaseWhen(self, node: pgast.CaseWhen) -> None: self.write('WHEN ') self.visit(node.expr) self.write(' THEN ') self.visit(node.result) def visit_NullTest(self, node: pgast.NullTest) -> None: self.write('(') self.visit(node.arg) if node.negated: self.write(' IS NOT NULL') else: self.write(' IS NULL') self.write(')') def visit_BooleanTest(self, node: pgast.BooleanTest) -> None: self.write("(") self.visit(node.arg) op = " IS" if node.negated: op += " NOT" if node.is_true: op += " TRUE" else: op += " FALSE" self.write(op) self.write(")") def visit_Indirection(self, node: pgast.Indirection) -> None: self.write('(') self.visit(node.arg) self.write(')') self._visit_indirection_ops(node.indirection) def visit_RecordIndirectionOp( self, node: pgast.RecordIndirectionOp ) -> None: self.write('.') self.write(common.qname(node.name)) def _visit_indirection_ops( self, ops: Sequence[pgast.IndirectionOp] ) -> None: for op in ops: if isinstance(op, pgast.Star): self.write('.') self.visit(op) def visit_Index(self, node: pgast.Index) -> None: self.write('[') self.visit(node.idx) self.write(']') def visit_Slice(self, node: pgast.Slice) -> None: self.write('[') if node.lidx is not None: self.visit(node.lidx) self.write(':') if node.ridx is not None: self.visit(node.ridx) self.write(']') def 
visit_CollateClause(self, node: pgast.CollateClause) -> None: self.visit(node.arg) self.write(f' COLLATE {common.qname(*node.collname)}') def visit_CoalesceExpr(self, node: pgast.CoalesceExpr) -> None: self.write('COALESCE(') self.visit_list(node.args, newlines=False) self.write(')') def visit_AlterSystem(self, node: pgast.AlterSystem) -> None: self.write('ALTER SYSTEM ') if node.value is not None: self.write('SET ') self.write(common.quote_ident(node.name)) self.write(' = ') self.visit(node.value) else: self.write('RESET ') self.write(common.quote_ident(node.name)) def visit_Set(self, node: pgast.Set) -> None: if node.value is not None: self.write('SET ') self.write(common.quote_ident(node.name)) self.write(' = ') self.visit(node.value) else: self.write('RESET ') self.write(common.quote_ident(node.name)) def visit_VariableSetStmt(self, node: pgast.VariableSetStmt) -> None: self.write("SET ") if node.scope == pgast.OptionsScope.TRANSACTION: self.write("LOCAL ") self.write(common.qname(node.name)) self.write(" TO ") self.visit(node.args) def visit_ArgsList(self, node: pgast.ArgsList) -> None: self.visit_list(node.args) def visit_VariableResetStmt(self, node: pgast.VariableResetStmt) -> None: if node.name is None: assert node.scope == pgast.OptionsScope.SESSION self.write("RESET ALL") else: self.write("SET ") if node.scope == pgast.OptionsScope.TRANSACTION: self.write("LOCAL ") self.write(common.qname(node.name)) self.write(" TO DEFAULT") def visit_SetTransactionStmt(self, node: pgast.SetTransactionStmt) -> None: self.write("SET ") if node.scope == pgast.OptionsScope.TRANSACTION: self.write("TRANSACTION ") else: self.write("SESSION CHARACTERISTICS AS TRANSACTION ") self.visit(node.options) def visit_VariableShowStmt(self, node: pgast.VariableShowStmt) -> None: self.write("SHOW ") self.write(common.qname(node.name)) def visit_BeginStmt(self, node: pgast.BeginStmt) -> None: self.write("BEGIN") if node.options: self.visit(node.options) def visit_StartStmt(self, node: 
pgast.StartStmt) -> None: self.write("START TRANSACTION") if node.options: self.visit(node.options) def visit_CommitStmt(self, node: pgast.CommitStmt) -> None: self.write("COMMIT") if node.chain: self.write(" AND CHAIN") def visit_RollbackStmt(self, node: pgast.RollbackStmt) -> None: self.write("ROLLBACK") if node.chain: self.write(" AND CHAIN") def visit_SavepointStmt(self, node: pgast.SavepointStmt) -> None: self.write(f"SAVEPOINT {node.savepoint_name}") def visit_ReleaseStmt(self, node: pgast.ReleaseStmt) -> None: self.write(f"RELEASE {node.savepoint_name}") def visit_RollbackToStmt(self, node: pgast.RollbackToStmt) -> None: self.write(f"ROLLBACK TO SAVEPOINT {node.savepoint_name}") def visit_PrepareTransaction(self, node: pgast.PrepareTransaction) -> None: self.write(f"PREPARE TRANSACTION '{node.gid}'") def visit_CommitPreparedStmt(self, node: pgast.CommitPreparedStmt) -> None: self.write(f"COMMIT PREPARED '{node.gid}'") def visit_RollbackPreparedStmt( self, node: pgast.RollbackPreparedStmt ) -> None: self.write(f"ROLLBACK PREPARED '{node.gid}'") def visit_TransactionOptions(self, node: pgast.TransactionOptions) -> None: for def_name, arg in node.options.items(): if def_name == "transaction_isolation": self.write(" ISOLATION LEVEL ") if isinstance(arg, pgast.StringConstant): self.write(arg.val.upper()) elif def_name == "transaction_read_only": if isinstance(arg, pgast.NumericConstant): if arg.val == "1": self.write(" READ ONLY") else: self.write(" READ WRITE") elif def_name == "transaction_deferrable": if isinstance(arg, pgast.NumericConstant): if arg.val != "1": self.write(" NOT") self.write(" DEFERRABLE") def visit_PrepareStmt(self, node: pgast.PrepareStmt) -> None: self.write(f"PREPARE {common.quote_ident(node.name)}") if node.argtypes: self.write(f"(") self.visit_list(node.argtypes, newlines=False) self.write(f")") self.write(f" AS ") self.visit(node.query) def visit_ExecuteStmt(self, node: pgast.ExecuteStmt) -> None: self.write(f"EXECUTE 
{common.quote_ident(node.name)}") if node.params: self.write(f"(") self.visit_list(node.params, newlines=False) self.write(f")") def visit_DeallocateStmt(self, node: pgast.DeallocateStmt) -> None: self.write(f"DEALLOCATE {common.quote_ident(node.name)}") def visit_SQLValueFunction(self, node: pgast.SQLValueFunction) -> None: self.write(common.get_sql_value_function_op(node.op)) if node.arg: self.write("(") self.visit(node.arg) self.write(")") def visit_CreateStmt(self, node: pgast.CreateStmt) -> None: self.write('CREATE ') if node.relation.is_temporary: self.write('TEMPORARY ') self.write('TABLE ') self.visit_Relation(node.relation) if node.table_elements: self.write(' (') self.visit_list(node.table_elements) self.write(')') if node.on_commit: self.write(' ON COMMIT ') self.write(node.on_commit) def visit_CreateTableAsStmt(self, node: pgast.CreateTableAsStmt) -> None: self.visit(node.into) self.write(' AS ') self.visit(node.query) if node.with_no_data: self.write(' WITH NO DATA') def visit_MinMaxExpr(self, node: pgast.MinMaxExpr) -> None: self.write(node.op) self.write('(') self.visit_list(node.args) self.write(')') def visit_LockStmt(self, node: pgast.LockStmt) -> None: self.write('LOCK TABLE ') self.visit_list(node.relations) self.write(' IN ') self.write(node.mode) self.write(' MODE') if node.no_wait: self.write(' NOWAIT') def visit_CopyStmt(self, node: pgast.CopyStmt) -> None: self.write('COPY ') if node.query: self.write('(') self.indentation += 1 self.new_lines = 1 self.visit(node.query) self.indentation -= 1 self.write(')') elif node.relation: self.visit_Relation(node.relation) if node.colnames: self.write(' (') self.write( ', '.join(common.quote_ident(n) for n in node.colnames)) self.write(')') if node.is_from: self.write(' FROM ') else: self.write(' TO ') if node.is_program: self.write('PROGRAM ') if node.filename: self.write(common.quote_literal(node.filename)) else: if node.is_from: self.write('STDIN') else: self.write('STDOUT') 
self.visit_CopyOptions(node.options) if node.where_clause: self.indentation += 1 self.new_lines = 1 self.write('WHERE') self.new_lines = 1 self.indentation += 1 self.visit(node.where_clause) self.indentation -= 2 def visit_CopyOptions(self, node: pgast.CopyOptions) -> None: ql = common.quote_literal qi = common.quote_ident opts = [] if node.format: opts.append('FORMAT ' + node.format._name_) if node.freeze is not None: opts.append('FREEZE' + ('' if node.freeze else ' FALSE')) if node.delimiter: opts.append('DELIMITER ' + ql(node.delimiter)) if node.null: opts.append('NULL ' + ql(node.null)) if node.header is not None: opts.append('HEADER' + ('' if node.header else ' FALSE')) if node.quote: opts.append('QUOTE ' + ql(node.quote)) if node.escape: opts.append('ESCAPE ' + ql(node.escape)) if node.force_quote: opts.append( 'FORCE_QUOTE (' + ', '.join(map(qi, node.force_quote)) + ')' ) if node.force_not_null: opts.append( 'FORCE_NOT_NULL (' + ', '.join(map(qi, node.force_not_null)) + ')' ) if node.force_null: opts.append( 'FORCE_NULL (' + ', '.join(map(qi, node.force_null)) + ')' ) if node.encoding: opts.append('ENCODING ' + ql(node.encoding)) if opts: self.write(' (' + ', '.join(opts), ')') class GeneratorContext(markup.MarkupExceptionContext): title = 'SQL Source Generator Context' def __init__( self, node: pgast.Base, chunks_generated: Optional[Sequence[str]] = None, ): self.node = node self.chunks_generated = chunks_generated @classmethod def as_markup(cls: Any, self: Any, *, ctx: Any): # type: ignore me = markup.elements body = [ me.doc.Section( title='SQL Tree', body=[markup.serialize(self.node, ctx=ctx)], # type: ignore ) ] if self.chunks_generated: code = markup.serializer.serialize_code( ''.join(self.chunks_generated), lexer='sql' ) body.append( me.doc.Section( title='SQL generated so far', body=[code] # type: ignore ) ) return me.lang.ExceptionContext( title=self.title, body=body # type: ignore ) class GeneratorError(errors.InternalServerError): def __init__( 
self, msg: str, *, node: Optional[pgast.Base] = None, details: Optional[str] = None, hint: Optional[str] = None, ) -> None: super().__init__(msg, details=details, hint=hint) if node is not None: ctx = GeneratorContext(node) exceptions.add_context(self, ctx) ================================================ FILE: edb/pgsql/common.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import binascii import functools import hashlib import base64 import re from typing import Literal, Optional, overload import uuid from edb import buildmeta from edb.common import typeutils from edb.common import uuidgen from edb.schema import casts as s_casts from edb.schema import constraints as s_constr from edb.schema import defines as s_def from edb.schema import functions as s_func from edb.schema import indexes as s_indexes from edb.schema import name as s_name from edb.schema import objects as so from edb.schema import objtypes as s_objtypes from edb.schema import operators as s_opers from edb.schema import pointers as s_pointers from edb.schema import scalars as s_scalars from edb.schema import types as s_types from edb.schema import schema as s_schema from edb.pgsql import ast as pgast from . import keywords as pg_keywords # This is a postgres limitation. 
# Note that this can be overridden in custom builds. # https://www.postgresql.org/docs/current/datatype-enum.html MAX_ENUM_LABEL_LENGTH = 63 def quote_e_literal(string: str) -> str: def escape_sq(s): split = re.split(r"(\n|\\\\|\\')", s) if len(split) == 1: return s.replace(r"'", r"\'") return ''.join((r if i % 2 else r.replace(r"'", r"\'")) for i, r in enumerate(split)) return "E'" + escape_sq(string) + "'" def quote_literal(string: str) -> str: return "'" + string.replace("'", "''") + "'" def _quote_ident(string: str) -> str: return '"' + string.replace('"', '""') + '"' def quote_ident(ident: str | pgast.Star, *, force=False, column=False) -> str: if isinstance(ident, pgast.Star): return "*" return ( _quote_ident(ident) if needs_quoting(ident, column=column) or force else ident ) def quote_col(ident: str | pgast.Star) -> str: return quote_ident(ident, column=True) def quote_bytea_literal(data: bytes) -> str: """Return valid SQL representation of a bytes value.""" if data: b = binascii.b2a_hex(data).decode('ascii') return f"'\\x{b}'::bytea" else: return "''::bytea" def needs_quoting(string: str, column: bool = False) -> bool: isalnum = ( string and not string[0].isdecimal() and string.replace('_', 'a').isalnum() ) return ( not isalnum or string.lower() in pg_keywords.by_type[ pg_keywords.RESERVED_KEYWORD] or string.lower() in pg_keywords.by_type[ pg_keywords.TYPE_FUNC_NAME_KEYWORD] or (column and string.lower() in pg_keywords.by_type[ pg_keywords.COL_NAME_KEYWORD]) or string.lower() != string ) def qname(*parts: str | pgast.Star, column: bool = False) -> str: assert len(parts) <= 3, parts return '.'.join([quote_ident(q, column=column) for q in parts]) def quote_type(type_: tuple[str, ...] | str) -> str: if isinstance(type_, tuple): first = qname(*type_[:-1]) + '.' 
if len(type_) > 1 else '' last = type_[-1] else: first = '' last = type_ is_rowtype = last.endswith('%ROWTYPE') if is_rowtype: last = last[:-8] is_array = last.endswith('[]') if is_array: last = last[:-2] param = None if '(' in last: last, param = last.split('(', 1) param = '(' + param last = quote_ident(last) if is_rowtype: last += '%ROWTYPE' if param: last += param if is_array: last += '[]' return first + last def get_module_backend_name(module: s_name.Name) -> str: # standard modules go into "edgedbstd", user ones into "edgedbpub" return "edgedbstd" if module in s_schema.STD_MODULES else "edgedbpub" def get_unique_random_name() -> str: return base64.b64encode(uuidgen.uuid1mc().bytes).rstrip(b'=').decode() VERSIONED_SCHEMAS = ('edgedb', 'edgedbstd', 'edgedbsql', 'edgedbinstdata') SCHEMA_SUFFIX: str | None = None def versioned_schema(s: str) -> str: global SCHEMA_SUFFIX if SCHEMA_SUFFIX is None: version = buildmeta.get_version_dict()['major'] SCHEMA_SUFFIX = f'_v{version}_{buildmeta.EDGEDB_CATALOG_VERSION:x}' # N.B: We don't bother quoting the schema name, so make sure it is # lower case and doesn't have weird characters. return f'{s}{SCHEMA_SUFFIX}' def maybe_versioned_schema( s: str, versioned: bool=True, ) -> str: return ( versioned_schema(s) if versioned and s in VERSIONED_SCHEMAS else s ) def versioned_name( s: tuple[str, ...], ) -> tuple[str, ...]: if len(s) > 1: return (maybe_versioned_schema(s[0]), *s[1:]) else: return s def maybe_versioned_name( s: tuple[str, ...], *, versioned: bool, ) -> tuple[str, ...]: return versioned_name(s) if versioned else s @functools.lru_cache() def _edgedb_name_to_pg_name(name: str, prefix_length: int = 0) -> str: # Note: PostgreSQL doesn't have a sha1 implementation as a # built-in function available in all versions, hence we use md5. # # Although sha1 would be slightly better as it's marginally faster than # md5 (and it doesn't matter which function is better cryptographically # in this case.) 
hashed = base64.b64encode( hashlib.md5(name.encode(), usedforsecurity=False).digest() ).decode().rstrip('=') return ( name[:prefix_length] + hashed + ':' + name[-(s_def.MAX_NAME_LENGTH - prefix_length - 1 - len(hashed)):] ) def edgedb_name_to_pg_name(name: str, prefix_length: int = 0) -> str: """Convert Gel name to a valid PostgreSQL column name. PostgreSQL has a limit of 63 characters for column names. @param name: Gel name to convert @return: PostgreSQL column name """ if not (0 <= prefix_length < s_def.MAX_NAME_LENGTH): raise ValueError('supplied name is too long ' 'to be kept in original form') name = str(name) if len(name) <= s_def.MAX_NAME_LENGTH - prefix_length: return name return _edgedb_name_to_pg_name(name, prefix_length) def convert_name( name: s_name.QualName, suffix='', catenate=True, *, versioned=True, ): schema = get_module_backend_name(name.get_module_name()) if suffix: sname = f'{name.name}_{suffix}' else: sname = name.name dbname = edgedb_name_to_pg_name(sname) if versioned: schema = maybe_versioned_schema(schema) if catenate: return qname(schema, dbname) else: return schema, dbname def get_database_backend_name(db_name: str, *, tenant_id: str) -> str: return f'{tenant_id}_{db_name}' def get_role_backend_name(role_name: str, *, tenant_id: str) -> str: return f'{tenant_id}_{role_name}' def update_aspect(name, aspect): """Update the aspect on a non catenated name.
It also needs to be from an object that uses ids for names""" suffix = get_aspect_suffix(aspect) stripped = name[1].rsplit("_", 1)[0] if suffix: return (name[0], f'{stripped}_{suffix}') else: return (name[0], stripped) def get_scalar_backend_name( id, module_name, catenate=True, *, versioned=True, aspect=None ): if aspect is None: aspect = 'domain' if aspect not in ( "domain", "sequence", "enum", "enum-cast-into-str", "enum-cast-from-str", "source-del-imm-otl-f", "source-del-imm-otl-t", ): raise ValueError( f'unexpected aspect for scalar backend name: {aspect!r}') name = s_name.QualName(module=module_name, name=str(id)) # XXX: TRAMPOLINE: VERSIONING??? if aspect.startswith("enum-cast-"): suffix = "_into_str" if aspect == "enum-cast-into-str" else "_from_str" name = s_name.QualName(name.module, name.name + suffix) return get_cast_backend_name( name, catenate, versioned=versioned, aspect="function") return convert_name(name, aspect, catenate, versioned=False) def get_aspect_suffix(aspect): if aspect == 'table': return '' elif aspect == 'inhview': return 't' else: return aspect def is_inhview_name(name: str) -> bool: return name.endswith('_t') def get_objtype_backend_name( id: uuid.UUID, module_name: str, *, catenate: bool = True, versioned: bool = False, aspect: Optional[str] = None, ): if aspect is None: aspect = 'table' if ( aspect not in {'table', 'inhview', 'dummy'} and not re.match( r'(source|target)-del-(def|imm)-(inl|otl)-(f|t)', aspect) and not aspect.startswith("ext") ): raise ValueError( f'unexpected aspect for object type backend name: {aspect!r}') name = s_name.QualName(module=module_name, name=str(id)) suffix = get_aspect_suffix(aspect) return convert_name( name, suffix=suffix, catenate=catenate, versioned=versioned) def get_pointer_backend_name( id, module_name, *, catenate=False, aspect=None, versioned=True ): if aspect is None: aspect = 'table' if aspect not in ('table', 'index', 'inhview', 'dummy'): raise ValueError( f'unexpected aspect for pointer 
backend name: {aspect!r}') name = s_name.QualName(module=module_name, name=str(id)) suffix = get_aspect_suffix(aspect) return convert_name( name, suffix=suffix, catenate=catenate, versioned=versioned ) operator_map = { s_name.name_from_string('std::AND'): 'AND', s_name.name_from_string('std::OR'): 'OR', s_name.name_from_string('std::NOT'): 'NOT', s_name.name_from_string('std::?='): 'IS NOT DISTINCT FROM', s_name.name_from_string('std::?!='): 'IS DISTINCT FROM', s_name.name_from_string('std::LIKE'): 'LIKE', s_name.name_from_string('std::ILIKE'): 'ILIKE', s_name.name_from_string('std::NOT LIKE'): 'NOT LIKE', s_name.name_from_string('std::NOT ILIKE'): 'NOT ILIKE', } def get_operator_backend_name( name, catenate=False, *, versioned=True, aspect=None ): if aspect is None: aspect = 'operator' if aspect == 'function': return convert_name(name, 'f', catenate=catenate, versioned=versioned) elif aspect != 'operator': raise ValueError( f'unexpected aspect for operator backend name: {aspect!r}') oper_name = operator_map.get(name) if oper_name is None: oper_name = name.name if re.search(r'[a-zA-Z]', oper_name): raise ValueError( f'cannot represent operator {oper_name} in Postgres') oper_name = f'`{oper_name}`' schema = 'edgedb' else: schema = '' if catenate: return qname(schema, oper_name) else: return schema, oper_name def get_cast_backend_name( fullname: s_name.QualName, catenate=False, *, versioned=True, aspect=None ): if aspect == "function": return convert_name( fullname, "f", catenate=catenate, versioned=versioned) else: raise ValueError( f'unexpected aspect for cast backend name: {aspect!r}') def get_function_backend_name( name, backend_name, catenate=False, versioned=True, ): real_name = backend_name or name.name fullname = s_name.QualName(module=name.module, name=real_name) schema, func_name = convert_name( fullname, catenate=False, versioned=versioned) if catenate: return qname(schema, func_name) else: return schema, func_name def get_constraint_backend_name(id, 
module_name, catenate=True, *, aspect=None): if aspect not in ('trigproc', 'index'): raise ValueError( f'unexpected aspect for constraint backend name: {aspect!r}') sname = str(id) if aspect == 'index': aspect = None sname = get_constraint_raw_name(id) name = s_name.QualName(module=module_name, name=sname) return convert_name(name, aspect, catenate) def get_constraint_raw_name(id): return f'{id};schemaconstr' def get_index_backend_name(id, module_name, catenate=True, *, aspect=None): if aspect is None: aspect = 'index' name = s_name.QualName(module=module_name, name=str(id)) return convert_name(name, aspect, catenate) def get_index_table_backend_name( index: s_indexes.Index, schema: s_schema.Schema, *, aspect: Optional[str] = None, ) -> tuple[str, str]: subject = index.get_subject(schema) assert isinstance(subject, s_types.Type) return get_backend_name(schema, subject, aspect=aspect, catenate=False) def get_tuple_backend_name( id, catenate=True, *, aspect=None ) -> str | tuple[str, str]: name = s_name.QualName(module='edgedb', name=f'{id}_t') return convert_name(name, aspect, catenate) @overload def get_backend_name( schema: s_schema.Schema, obj: so.Object, catenate: Literal[True]=True, *, versioned: bool=True, aspect: Optional[str]=None ) -> str: ... @overload def get_backend_name( schema: s_schema.Schema, obj: so.Object, catenate: Literal[False], *, versioned: bool=True, aspect: Optional[str]=None ) -> tuple[str, str]: ... 
def get_backend_name( schema: s_schema.Schema, obj: so.Object, catenate: bool=True, *, aspect: Optional[str]=None, versioned: bool=True, ) -> str | tuple[str, str]: name: s_name.QualName | s_name.Name if isinstance(obj, s_objtypes.ObjectType): name = obj.get_name(schema) return get_objtype_backend_name( obj.id, name.module, catenate=catenate, aspect=aspect, versioned=versioned, ) elif isinstance(obj, s_pointers.Pointer): name = obj.get_name(schema) return get_pointer_backend_name(obj.id, name.module, catenate=catenate, versioned=versioned, aspect=aspect) elif isinstance(obj, s_scalars.ScalarType): name = obj.get_name(schema) return get_scalar_backend_name(obj.id, name.module, catenate=catenate, versioned=versioned, aspect=aspect) elif isinstance(obj, s_opers.Operator): name = obj.get_shortname(schema) return get_operator_backend_name( name, catenate, versioned=versioned, aspect=aspect) elif isinstance(obj, s_casts.Cast): name = obj.get_name(schema) return get_cast_backend_name( name, catenate, versioned=versioned, aspect=aspect) elif isinstance(obj, s_func.Function): name = obj.get_shortname(schema) backend_name = obj.get_backend_name(schema) return get_function_backend_name( name, backend_name, catenate, versioned=versioned) elif isinstance(obj, s_constr.Constraint): name = obj.get_name(schema) return get_constraint_backend_name( obj.id, name.module, catenate, aspect=aspect) elif isinstance(obj, s_indexes.Index): name = obj.get_name(schema) return get_index_backend_name( obj.id, name.module, catenate, aspect=aspect) elif isinstance(obj, s_types.Tuple): # XXX: TRAMPOLINE: VERSIONED? 
return get_tuple_backend_name( obj.id, catenate, aspect=aspect) else: raise ValueError(f'cannot determine backend name for {obj!r}') def get_object_from_backend_name(schema, metaclass, name, *, aspect=None): if issubclass(metaclass, s_objtypes.ObjectType): table_name = name[1] obj_id = uuidgen.UUID(table_name) return schema.get_by_id(obj_id) elif issubclass(metaclass, s_pointers.Pointer): obj_id = uuidgen.UUID(name) return schema.get_by_id(obj_id) else: raise ValueError( f'cannot determine object from backend name for {metaclass!r}') def get_sql_value_function_op(op: pgast.SQLValueFunctionOP) -> str: from edb.pgsql.ast import SQLValueFunctionOP as OP NAMES = { OP.CURRENT_DATE: "current_date", OP.CURRENT_TIME: "current_time", OP.CURRENT_TIME_N: "current_time", OP.CURRENT_TIMESTAMP: "current_timestamp", OP.CURRENT_TIMESTAMP_N: "current_timestamp", OP.LOCALTIME: "localtime", OP.LOCALTIME_N: "localtime", OP.LOCALTIMESTAMP: "localtimestamp", OP.LOCALTIMESTAMP_N: "localtimestamp", OP.CURRENT_ROLE: "current_role", OP.CURRENT_USER: "current_user", OP.USER: "user", OP.SESSION_USER: "session_user", OP.CURRENT_CATALOG: "current_catalog", OP.CURRENT_SCHEMA: "current_schema", } return NAMES[op] # Settings that are enums or bools and should not be quoted. 
# Can be retrieved from PostgreSQL with: # SELECT name FROM pg_catalog.pg_settings WHERE vartype IN ('enum', 'bool'); ENUM_SETTINGS = { 'allow_alter_system', 'allow_in_place_tablespaces', 'allow_system_table_mods', 'archive_mode', 'array_nulls', 'autovacuum', 'backslash_quote', 'bytea_output', 'check_function_bodies', 'client_min_messages', 'compute_query_id', 'constraint_exclusion', 'data_checksums', 'data_sync_retry', 'debug_assertions', 'debug_logical_replication_streaming', 'debug_parallel_query', 'debug_pretty_print', 'debug_print_parse', 'debug_print_plan', 'debug_print_rewritten', 'default_toast_compression', 'default_transaction_deferrable', 'default_transaction_isolation', 'default_transaction_read_only', 'dynamic_shared_memory_type', 'edb_stat_statements.save', 'edb_stat_statements.track', 'edb_stat_statements.track_planning', 'edb_stat_statements.track_utility', 'enable_async_append', 'enable_bitmapscan', 'enable_gathermerge', 'enable_group_by_reordering', 'enable_hashagg', 'enable_hashjoin', 'enable_incremental_sort', 'enable_indexonlyscan', 'enable_indexscan', 'enable_material', 'enable_memoize', 'enable_mergejoin', 'enable_nestloop', 'enable_parallel_append', 'enable_parallel_hash', 'enable_partition_pruning', 'enable_partitionwise_aggregate', 'enable_partitionwise_join', 'enable_presorted_aggregate', 'enable_seqscan', 'enable_sort', 'enable_tidscan', 'escape_string_warning', 'event_triggers', 'exit_on_error', 'fsync', 'full_page_writes', 'geqo', 'gss_accept_delegation', 'hot_standby', 'hot_standby_feedback', 'huge_pages', 'huge_pages_status', 'icu_validation_level', 'ignore_checksum_failure', 'ignore_invalid_pages', 'ignore_system_indexes', 'in_hot_standby', 'integer_datetimes', 'intervalstyle', 'jit', 'jit_debugging_support', 'jit_dump_bitcode', 'jit_expressions', 'jit_profiling_support', 'jit_tuple_deforming', 'krb_caseins_users', 'lo_compat_privileges', 'log_checkpoints', 'log_connections', 'log_disconnections', 'log_duration',
'log_error_verbosity', 'log_executor_stats', 'log_hostname', 'log_lock_waits', 'log_min_error_statement', 'log_min_messages', 'log_parser_stats', 'log_planner_stats', 'log_recovery_conflict_waits', 'log_replication_commands', 'log_statement', 'log_statement_stats', 'log_truncate_on_rotation', 'logging_collector', 'parallel_leader_participation', 'password_encryption', 'plan_cache_mode', 'quote_all_identifiers', 'recovery_init_sync_method', 'recovery_prefetch', 'recovery_target_action', 'recovery_target_inclusive', 'remove_temp_files_after_crash', 'restart_after_crash', 'row_security', 'send_abort_for_crash', 'send_abort_for_kill', 'session_replication_role', 'shared_memory_type', 'ssl', 'ssl_max_protocol_version', 'ssl_min_protocol_version', 'ssl_passphrase_command_supports_reload', 'ssl_prefer_server_ciphers', 'standard_conforming_strings', 'stats_fetch_consistency', 'summarize_wal', 'sync_replication_slots', 'synchronize_seqscans', 'synchronous_commit', 'syslog_facility', 'syslog_sequence_numbers', 'syslog_split_messages', 'trace_connection_negotiation', 'trace_notify', 'trace_sort', 'track_activities', 'track_commit_timestamp', 'track_counts', 'track_functions', 'track_io_timing', 'track_wal_io_timing', 'transaction_deferrable', 'transaction_isolation', 'transaction_read_only', 'transform_null_equals', 'update_process_title', 'wal_compression', 'wal_init_zero', 'wal_level', 'wal_log_hints', 'wal_receiver_create_temp_slot', 'wal_recycle', 'wal_sync_method', 'xmlbinary', 'xmloption', 'zero_damaged_pages', # additionally, there are some settings that also should not be quoted 'work_mem', } def setting_to_sql(name, setting): is_enum = name.lower() in ENUM_SETTINGS assert typeutils.is_container(setting) return ', '.join(setting_val_to_sql(v, is_enum) for v in setting) def setting_val_to_sql(val: str | int | float, is_enum: bool): if isinstance(val, str): if is_enum: # special case: no quoting return val # quote as identifier return quote_ident(val) if isinstance(val, 
int): return str(val) if isinstance(val, float): return str(val) raise NotImplementedError('cannot convert setting to SQL: ', val) ================================================ FILE: edb/pgsql/compiler/ARCHITECTURE.md ================================================ # Architecture of PG compiler RelVar = Relation variable. Basically an instance of a relation within a query. PathVar = Reference to a column, as seen from within the declaring query. OutputVar = Reference to a column that can be used from outside the declaring query. ## Recursive column injection When an IR set is compiled, it may not be known which properties of that object will be needed downstream. To avoid fetching, computing and possibly materializing too much data, sets are compiled in two steps: 1. Compile the general structure of the query. In this process every IR set will be bound to some SQL select statement. 2. Inject columns into this tree. This is mainly done in `pathctx.get_path_var`, which: - finds which RelVar provides the source aspect for this path (see `pathctx._find_rel_rvar`) - determines the OutputVar of this path within the RelVar (see `pathctx.get_path_output`). This recursively calls `get_path_var`. - when an actual table is encountered, returns a plain ColRef to its columns. ## Overlays Postgres has a limitation where the effects of any DML are not visible within the same query. For example: ``` WITH insert_result AS (INSERT INTO my_table(a) VALUES (1) RETURNING a) SELECT a FROM my_table, insert_result ``` In this query, `my_table` will not contain the inserted value. The obvious solution is to use only `insert_result` and not rely on `my_table` anymore. This is the gist of what overlays accomplish. They define a new relation that should be used instead of the base table when the compiler wants to pull data for some path_id. An overlay also specifies the operation that needs to be applied when constructing the rel var: union, exclude, or replace.
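
The three overlay operations can be pictured with a toy model. This is a minimal sketch under stated assumptions, not the compiler's implementation: `apply_overlays`, `Row` and the sample rows are hypothetical names introduced here, and real overlays are SQL relations combined while building a rel var, not Python sets.

```python
from typing import Iterable

# Hypothetical stand-in for a table row; the real compiler works on SQL ASTs.
Row = tuple[int, ...]


def apply_overlays(
    base_rows: Iterable[Row],
    overlays: list[tuple[str, set[Row]]],
) -> set[Row]:
    """Return the rows a later part of the query should see.

    Each overlay is an (operation, relation) pair applied in order:
    "union" adds rows, "exclude" removes rows, and "replace" discards
    everything seen so far in favour of the overlay relation.
    """
    visible = set(base_rows)
    for op, rel in overlays:
        if op == "union":
            visible |= rel
        elif op == "exclude":
            visible -= rel
        elif op == "replace":
            visible = set(rel)
        else:
            raise ValueError(f"unknown overlay operation: {op!r}")
    return visible


# my_table as the statement's snapshot sees it, before any DML takes effect.
base = {(1,), (2,)}

after_insert = apply_overlays(base, [("union", {(3,)})])
after_delete = apply_overlays(base, [("exclude", {(2,)})])
after_replace = apply_overlays(base, [("replace", {(9,)})])
```

In this model the base table itself is never modified, which mirrors the point above: the rows come from the overlay relations layered on top of the snapshot rather than from re-reading `my_table`.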
For example, union is used after INSERTing, exclude when DELETING. Overlays are also used for access policies and rewrites. ## Misc Most references to database objects are prepared by `common.get_backend_name`. ================================================ FILE: edb/pgsql/compiler/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Optional, Mapping, TYPE_CHECKING from dataclasses import dataclass import itertools import uuid from edb import errors from edb.ir import ast as irast from edb.pgsql import ast as pgast from edb.pgsql import params as pgparams from edb.pgsql import types as pgtypes from . import config as _config_compiler # NOQA from . import expr as _expr_compiler # NOQA from . import stmt as _stmt_compiler # NOQA from . import clauses from . import context from . import dispatch from . import dml from . import pathctx from . 
import aliases from .context import OutputFormat as OutputFormat # NOQA if TYPE_CHECKING: import enums as pgce @dataclass(kw_only=True, slots=True, repr=False, eq=False, frozen=True) class CompileResult: ast: pgast.Base env: context.Environment argmap: dict[str, pgast.Param] cached_params: Optional[list[tuple[str, ...]]] = None def compile_ir_to_sql_tree( ir_expr: irast.Base, *, output_format: Optional[OutputFormat] = None, ignore_shapes: bool = False, explicit_top_cast: Optional[irast.TypeRef] = None, singleton_mode: bool = False, named_param_prefix: Optional[tuple[str, ...]] = None, expected_cardinality_one: bool = False, is_explain: bool = False, external_rvars: Optional[ Mapping[tuple[irast.PathId, pgce.PathAspect], pgast.PathRangeVar] ] = None, external_rels: Optional[ Mapping[ irast.PathId, tuple[ pgast.BaseRelation | pgast.CommonTableExpr, tuple[pgce.PathAspect, ...] ], ] ] = None, json_parameters: bool = False, backend_runtime_params: Optional[pgparams.BackendRuntimeParams]=None, cache_as_function: bool = False, alias_generator: Optional[aliases.AliasGenerator] = None, versioned_stdlib: bool = True, # HACK? versioned_singleton: bool = False, sql_dml_mode: bool = False, ) -> CompileResult: if singleton_mode and not versioned_singleton: versioned_stdlib = False try: # Transform to sql tree query_params: list[irast.Param] = [] query_globals: list[irast.Global] = [] server_param_conversion_params: list[irast.Param] = [] type_rewrites: dict[tuple[uuid.UUID, bool], irast.Set] = {} triggers: tuple[tuple[irast.Trigger, ...], ...] 
= () singletons = [] if isinstance(ir_expr, irast.Statement): scope_tree = ir_expr.scope_tree query_params = list(ir_expr.params) query_globals = list(ir_expr.globals) server_param_conversion_params = ( ir_expr.server_param_conversion_params ) type_rewrites = ir_expr.type_rewrites singletons = ir_expr.singletons triggers = ir_expr.triggers ir_expr = ir_expr.expr elif isinstance(ir_expr, irast.ConfigCommand): assert ir_expr.scope_tree scope_tree = ir_expr.scope_tree query_params = list(ir_expr.params) if ir_expr.globals: query_globals = list(ir_expr.globals) if ir_expr.type_rewrites: type_rewrites = ir_expr.type_rewrites else: scope_tree = irast.new_scope_tree() # In JSON parameters mode, keep only the synthetic globals if json_parameters: query_globals = [ g for g in query_globals if g.global_name.module == '__' ] # Ensure permissions are after globals, since they are injected # after other globals. query_globals.sort(key=lambda g: g.is_permission) scope_tree_nodes = { node.unique_id: node for node in scope_tree.descendants if node.unique_id is not None } if backend_runtime_params is None: backend_runtime_params = pgparams.get_default_runtime_params() env = context.Environment( alias_generator=alias_generator, output_format=output_format, expected_cardinality_one=expected_cardinality_one, named_param_prefix=named_param_prefix, query_params=list(tuple(query_params) + tuple(query_globals)), type_rewrites=type_rewrites, ignore_object_shapes=ignore_shapes, explicit_top_cast=explicit_top_cast, is_explain=is_explain, singleton_mode=singleton_mode, scope_tree_nodes=scope_tree_nodes, external_rvars=external_rvars, backend_runtime_params=backend_runtime_params, versioned_stdlib=versioned_stdlib, sql_dml_mode=sql_dml_mode, ) ctx = context.CompilerContextLevel( None, context.ContextSwitchMode.TRANSPARENT, env=env, scope_tree=scope_tree, ) ctx.rel = pgast.SelectStmt() _ = context.CompilerContext(initial=ctx) ctx.singleton_mode = singleton_mode ctx.expr_exposed = True for sing 
in singletons: ctx.path_scope[sing] = ctx.rel if external_rels: ctx.external_rels = external_rels clauses.populate_argmap( query_params, query_globals, server_param_conversion_params, ctx=ctx, ) qtree = dispatch.compile(ir_expr, ctx=ctx) dml.compile_triggers(triggers, qtree, ctx=ctx) if not singleton_mode: if isinstance(ir_expr, irast.Set): assert isinstance(qtree, pgast.Query) clauses.fini_toplevel(qtree, ctx) elif isinstance(qtree, pgast.Query): # Other types of expressions may compile to queries which may # use inheritance CTEs. Ensure they are added here. clauses.insert_ctes(qtree, ctx) if cache_as_function: cached_params_idx = { ctx.argmap[param.name].index: ( pgtypes.pg_type_from_ir_typeref( param.ir_type.base_type or param.ir_type, # Needs serialized=True so types without their own # binary encodings (like postgis::box2d) get mapped # to the real underlying type. serialized=True, ) ) for param in itertools.chain( ctx.env.query_params, server_param_conversion_params, ) if not param.sub_params } else: cached_params_idx = {} cached_params = [p for _, p in sorted(cached_params_idx.items())] except errors.EdgeDBError: # Don't wrap properly typed EdgeDB errors into # InternalServerError; raise them as is. raise except Exception as e: # pragma: no cover try: args = [e.args[0]] except (AttributeError, IndexError): args = [] raise errors.InternalServerError(*args) from e return CompileResult( ast=qtree, env=env, argmap=ctx.argmap, cached_params=cached_params ) def new_external_rvar( *, rel_name: tuple[str, ...], path_id: irast.PathId, outputs: Mapping[tuple[irast.PathId, tuple[pgce.PathAspect, ...]], str], ) -> pgast.RelRangeVar: """Construct a ``RangeVar`` instance given a relation name and a path id. Given an optionally-qualified relation name *rel_name* and a *path_id*, return a ``RangeVar`` instance over the specified relation that is then assumed to represent the *path_id* binding.
This is useful in situations where it is necessary to "prime" the compiler with a list of external relations that exist in a larger SQL expression that _this_ expression is being embedded into. The *outputs* mapping optionally specifies a set of outputs in the resulting range var as a ``(path_id, tuple-of-aspects): attribute name`` mapping. """ rel = new_external_rel(rel_name=rel_name, path_id=path_id) assert rel.name alias = pgast.Alias(aliasname=rel.name) if not path_id.is_ptr_path(): rvar = pgast.RelRangeVar( relation=rel, typeref=path_id.target, alias=alias) else: rvar = pgast.RelRangeVar( relation=rel, alias=alias) for (output_pid, output_aspects), colname in outputs.items(): var = pgast.ColumnRef(name=[colname]) for aspect in output_aspects: rel.path_outputs[output_pid, aspect] = var return rvar def new_external_rvar_as_subquery( *, rel_name: tuple[str, ...], path_id: irast.PathId, aspects: tuple[pgce.PathAspect, ...], ) -> pgast.SelectStmt: rvar = new_external_rvar( rel_name=rel_name, path_id=path_id, outputs={}, ) qry = pgast.SelectStmt( from_clause=[rvar], ) for aspect in aspects: pathctx.put_path_rvar(qry, path_id, rvar, aspect=aspect) return qry def new_external_rel( *, rel_name: tuple[str, ...], path_id: irast.PathId, ) -> pgast.Relation: if len(rel_name) == 1: table_name = rel_name[0] schema_name = None elif len(rel_name) == 2: schema_name, table_name = rel_name else: raise AssertionError(f'unexpected rvar name: {rel_name}') return pgast.Relation( name=table_name, schemaname=schema_name, path_id=path_id, ) ================================================ FILE: edb/pgsql/compiler/aliases.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from edb.common import compiler from edb.pgsql import common class AliasGenerator(compiler.AliasGenerator): def get(self, hint: str = '') -> str: alias = super().get(hint) return common.edgedb_name_to_pg_name(alias) ================================================ FILE: edb/pgsql/compiler/astutils.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Context-agnostic SQL AST utilities.""" from __future__ import annotations from typing import Optional, Iterator, Sequence, TYPE_CHECKING from edb.ir import typeutils as irtyputils from edb.pgsql import ast as pgast from edb.pgsql import common from edb.pgsql import types as pg_types if TYPE_CHECKING: from typing_extensions import TypeGuard from edb.ir import ast as irast from . 
import context def tuple_element_for_shape_el( shape_el: irast.Set, value: Optional[pgast.BaseExpr]=None, *, ctx: context.CompilerContextLevel ) -> pgast.TupleElementBase: from edb.ir import ast as irast if shape_el.path_id.is_type_intersection_path(): assert isinstance(shape_el.expr, irast.Pointer) rptr = shape_el.expr.source.expr else: rptr = shape_el.expr assert isinstance(rptr, irast.Pointer) ptrref = rptr.ptrref ptrname = ptrref.shortname if value is not None: return pgast.TupleElement( path_id=shape_el.path_id, name=ptrname.name, val=value, ) else: return pgast.TupleElementBase( path_id=shape_el.path_id, name=ptrname.name, ) def tuple_getattr( tuple_val: pgast.BaseExpr, tuple_typeref: irast.TypeRef, attr: str, ) -> pgast.BaseExpr: ttypes = [] pgtypes = [] for i, st in enumerate(tuple_typeref.subtypes): pgtype = pg_types.pg_type_from_ir_typeref(st) pgtypes.append(pgtype) if st.element_name: ttypes.append(st.element_name) else: ttypes.append(str(i)) index = ttypes.index(attr) set_expr: pgast.BaseExpr if irtyputils.is_persistent_tuple(tuple_typeref): set_expr = pgast.Indirection( arg=tuple_val, indirection=[pgast.RecordIndirectionOp(name=attr)], ) else: set_expr = pgast.SelectStmt( target_list=[ pgast.ResTarget( val=pgast.ColumnRef( name=[str(index)], ), ), ], from_clause=[ pgast.RangeFunction( functions=[ pgast.FuncCall( name=('unnest',), args=[ pgast.ArrayExpr( elements=[tuple_val], ) ], coldeflist=[ pgast.ColumnDef( name=str(i), typename=pgast.TypeName( name=t ) ) for i, t in enumerate(pgtypes) ] ) ] ) ] ) return set_expr def array_get_inner_array( wrapped_array: pgast.BaseExpr, array_typeref: irast.TypeRef, ) -> pgast.BaseExpr: """Unwrap and get the inner array of a formerly nested array. Since array<array<...>> is implemented as array<tuple<array<...>>>, when an element is accessed, it needs to be unwrapped.
Essentially, this function takes tuple<array<...>> and returns array<...>. Postgres does not support arbitrarily accessing fields out of unnamed composites and so we need to do an extra unnest(array[]) to be able to specify the name and type of our resulting array. For example, the query: `select [[1]][0];` will produce the following SQL: SELECT "expr-6~2"."array_value~4" AS "array_serialized~1" FROM LATERAL (SELECT "expr-5~2"."array_value~3" AS "array_value~4" FROM LATERAL (SELECT (SELECT "0" FROM -- EXTRA unnest(array[]) unnest(ARRAY[ -- INDEX INDIRECTION edgedb_v7_2f26206480._index( "expr-3~2"."array_value~2", ($2)::int8, 'ERROR MESSAGE' ) ]) AS ("0" int8[]) ) AS "array_value~3" FROM LATERAL -- INITIAL ARRAY [[1]] (SELECT ARRAY[ROW("expr-2~2"."array_value~1")] AS "array_value~2" FROM LATERAL (SELECT ARRAY[($1)::int8] AS "array_value~1" ) AS "expr-2~2" ) AS "expr-3~2" ) AS "expr-5~2" ) AS "expr-6~2" WHERE ("expr-6~2"."array_value~4" IS NOT NULL) LIMIT (SELECT (101)::int8 AS "expr~7_value~1" ) """ return pgast.SelectStmt( target_list=[ pgast.ResTarget(val=pgast.ColumnRef(name=['0'])), ], from_clause=[ pgast.RangeFunction( functions=[ pgast.FuncCall( name=('unnest',), args=[ pgast.ArrayExpr( elements=[wrapped_array], ) ], coldeflist=[ pgast.ColumnDef( name='0', typename=pgast.TypeName( name=pg_types.pg_type_from_ir_typeref(array_typeref) ) ) ] ) ] ) ] ) def is_null_const(expr: pgast.BaseExpr) -> bool: if isinstance(expr, pgast.TypeCast): expr = expr.arg return isinstance(expr, pgast.NullConstant) def is_set_op_query(query: pgast.BaseExpr) -> TypeGuard[pgast.SelectStmt]: return ( isinstance(query, pgast.SelectStmt) and query.op is not None ) def get_leftmost_query(query: pgast.Query) -> pgast.Query: result = query while is_set_op_query(result): assert result.larg result = result.larg return result def each_query_in_set(qry: pgast.Query) -> Iterator[pgast.Query]: # We do this iteratively instead of recursively (with yield from) # to avoid being pointlessly quadratic.
stack = [qry] while stack: qry = stack.pop() if is_set_op_query(qry): assert qry.larg and qry.rarg stack.append(qry.rarg) stack.append(qry.larg) else: yield qry def each_base_rvar(rvar: pgast.BaseRangeVar) -> Iterator[pgast.BaseRangeVar]: # We do this iteratively instead of recursively (with yield from) # to avoid being pointlessly quadratic. stack = [rvar] while stack: rvar = stack.pop() if isinstance(rvar, pgast.JoinExpr): for clause in reversed(rvar.joins): stack.append(clause.rarg) stack.append(rvar.larg) else: yield rvar def new_binop( lexpr: pgast.BaseExpr, rexpr: pgast.BaseExpr, op: str, ) -> pgast.Expr: return pgast.Expr( name=op, lexpr=lexpr, rexpr=rexpr ) def extend_binop( binop: Optional[pgast.BaseExpr], *exprs: pgast.BaseExpr, op: str = 'AND', ) -> pgast.BaseExpr: exprlist = list(exprs) result: pgast.BaseExpr if binop is None: result = exprlist.pop(0) else: result = binop for expr in exprlist: if expr is not None and expr is not result: result = new_binop(lexpr=result, op=op, rexpr=expr) return result def extend_concat( expr: str | pgast.BaseExpr, *exprs: str | pgast.BaseExpr ) -> pgast.BaseExpr: return extend_binop( pgast.StringConstant(val=expr) if isinstance(expr, str) else expr, *[ pgast.StringConstant(val=e) if isinstance(e, str) else e for e in exprs ], op='||', ) def new_coalesce( expr: pgast.BaseExpr, fallback: pgast.BaseExpr ) -> pgast.BaseExpr: return pgast.FuncCall(name=('coalesce',), args=[expr, fallback]) def extend_select_op( stmt: Optional[pgast.SelectStmt], *stmts: pgast.SelectStmt, op: str = 'UNION', ) -> Optional[pgast.SelectStmt]: stmt_list = list(stmts) result: pgast.SelectStmt if stmt is None: if len(stmt_list) == 0: return None result = stmt_list.pop(0) else: result = stmt for s in stmt_list: if s is not None and s is not result: result = pgast.SelectStmt(larg=result, op=op, rarg=s) return result def new_unop(op: str, expr: pgast.BaseExpr) -> pgast.Expr: return pgast.Expr(name=op, rexpr=expr) def join_condition( lref: 
pgast.ColumnRef, rref: pgast.ColumnRef, ) -> pgast.BaseExpr: path_cond: pgast.BaseExpr = new_binop(lref, rref, op='=') if lref.optional: opt_cond = pgast.NullTest(arg=lref) path_cond = extend_binop(path_cond, opt_cond, op='OR') if rref.optional: opt_cond = pgast.NullTest(arg=rref) path_cond = extend_binop(path_cond, opt_cond, op='OR') return path_cond def safe_array_expr( elements: list[pgast.BaseExpr], *, ser_safe: bool = False, ctx: context.CompilerContextLevel, ) -> pgast.BaseExpr: result: pgast.BaseExpr = pgast.ArrayExpr( elements=elements, ser_safe=ser_safe, ) if any(el.nullable for el in elements): result = pgast.FuncCall( name=edgedb_func('_nullif_array_nulls', ctx=ctx), args=[result], ser_safe=ser_safe, ) return result def find_column_in_subselect_rvar( rvar: pgast.RangeSubselect, name: str, ) -> int: # Range over a subquery, we can inspect the output list # of the subquery. If the subquery is a UNION (or EXCEPT), # we take the leftmost non-setop query. subquery = get_leftmost_query(rvar.subquery) for i, rt in enumerate(subquery.target_list): if rt.name == name: return i raise RuntimeError(f'cannot find {name!r} in {rvar} output') def get_column( rvar: pgast.BaseRangeVar, colspec: str | pgast.ColumnRef, *, is_packed_multi: bool = True, nullable: Optional[bool] = None, ) -> pgast.ColumnRef: if isinstance(colspec, pgast.ColumnRef): colname = colspec.name[-1] else: colname = colspec assert isinstance(colname, str) ser_safe = False if nullable is None: if isinstance(rvar, pgast.RelRangeVar): # Range over a relation, we cannot infer nullability in # this context, so assume it's true, unless we are looking # at a colspec that says it is false if isinstance(colspec, pgast.ColumnRef): nullable = colspec.nullable else: nullable = True elif isinstance(rvar, pgast.RangeSubselect): col_idx = find_column_in_subselect_rvar(rvar, colname) if is_set_op_query(rvar.subquery): nullables = [] ser_safes = [] for q in each_query_in_set(rvar.subquery): 
nullables.append(q.target_list[col_idx].nullable) ser_safes.append(q.target_list[col_idx].ser_safe) nullable = any(nullables) ser_safe = all(ser_safes) else: rt = rvar.subquery.target_list[col_idx] nullable = rt.nullable ser_safe = rt.ser_safe elif isinstance(rvar, pgast.RangeFunction): # Range over a function. # TODO: look into the possibility of inspecting coldeflist. nullable = True elif isinstance(rvar, pgast.JoinExpr): raise RuntimeError( f'cannot find {colname!r} in unexpected {rvar!r} range var') name = [rvar.alias.aliasname, colname] return pgast.ColumnRef( name=name, nullable=nullable, ser_safe=ser_safe, is_packed_multi=is_packed_multi) def get_rvar_var( rvar: pgast.BaseRangeVar, var: pgast.OutputVar ) -> pgast.OutputVar: fieldref: pgast.OutputVar if isinstance(var, pgast.TupleVarBase): elements = [] for el in var.elements: assert isinstance(el.name, pgast.OutputVar) val = get_rvar_var(rvar, el.name) elements.append( pgast.TupleElement( path_id=el.path_id, name=el.name, val=val)) fieldref = pgast.TupleVar( elements, named=var.named, typeref=var.typeref, is_packed_multi=var.is_packed_multi, ) elif isinstance(var, pgast.ColumnRef): fieldref = get_column(rvar, var, is_packed_multi=var.is_packed_multi) elif isinstance(var, pgast.ExprOutputVar): fieldref = var else: raise AssertionError(f'unexpected OutputVar subclass: {var!r}') return fieldref def strip_output_var( var: pgast.OutputVar, *, optional: Optional[bool] = None, nullable: Optional[bool] = None, ) -> pgast.OutputVar: result: pgast.OutputVar if isinstance(var, pgast.TupleVarBase): elements = [] for el in var.elements: val: pgast.OutputVar el_name = el.name if isinstance(el_name, str): val = pgast.ColumnRef(name=[el_name]) elif isinstance(el_name, pgast.OutputVar): val = strip_output_var(el_name) else: raise AssertionError( f'unexpected tuple element class: {el_name!r}') elements.append( pgast.TupleElement( path_id=el.path_id, name=el_name, val=val)) result = pgast.TupleVar( elements, named=var.named, 
typeref=var.typeref, ) elif isinstance(var, pgast.ColumnRef): result = pgast.ColumnRef( name=[var.name[-1]], optional=optional if optional is not None else var.optional, nullable=nullable if nullable is not None else var.nullable, ) else: raise AssertionError(f'unexpected OutputVar subclass: {var!r}') return result def select_is_simple(stmt: pgast.SelectStmt) -> bool: return ( not stmt.distinct_clause and not stmt.where_clause and not stmt.group_clause and not stmt.having_clause and not stmt.window_clause and not stmt.values and not stmt.sort_clause and not stmt.limit_offset and not stmt.limit_count and not stmt.locking_clause and not stmt.op ) def is_row_expr(expr: pgast.BaseExpr) -> bool: while True: if isinstance(expr, (pgast.RowExpr, pgast.ImplicitRowExpr)): return True elif isinstance(expr, pgast.TypeCast): expr = expr.arg else: return False def _get_target_from_range( target: pgast.BaseExpr, rvar: pgast.BaseRangeVar ) -> Optional[pgast.BaseExpr]: """Try to read a target out of a very simple rvar. The goal here is to allow collapsing trivial pass-through subqueries. In particular, given a target `foo.bar` and an rvar `(SELECT <expr> as "bar") AS "foo"`, we produce `<expr>`. We can also recursively handle the nested case. 
""" if ( not isinstance(rvar, pgast.RangeSubselect) # Check that the relation name matches the rvar or not isinstance(target, pgast.ColumnRef) or not target.name or target.name[0] != rvar.alias.aliasname # And that the rvar is a simple subquery with one target # and at most one from clause or not (subq := rvar.subquery) or len(subq.target_list) != 1 or not isinstance(subq, pgast.SelectStmt) or not select_is_simple(subq) or len(subq.from_clause) > 1 # And that the one target matches or not (inner_tgt := rvar.subquery.target_list[0]) or inner_tgt.name != target.name[1] ): return None if subq.from_clause: return _get_target_from_range(inner_tgt.val, subq.from_clause[0]) else: return inner_tgt.val def collapse_query(query: pgast.Query) -> pgast.BaseExpr: """Try to collapse trivial queries into simple expressions. In particular, we want to transform `(SELECT foo.bar FROM LATERAL (SELECT <expr> as "bar") AS "foo")` into simply `<expr>`. """ if not isinstance(query, pgast.SelectStmt): return query if ( isinstance(query, pgast.SelectStmt) and len(query.target_list) == 1 and len(query.from_clause) == 0 and select_is_simple(query) ): return query.target_list[0].val if ( not isinstance(query, pgast.SelectStmt) or len(query.target_list) != 1 or len(query.from_clause) != 1 ): return query val = _get_target_from_range( query.target_list[0].val, query.from_clause[0]) if val: return val else: return query def compile_typeref(expr: irast.TypeRef) -> pgast.BaseExpr: if expr.collection: raise NotImplementedError() else: result = pgast.TypeCast( arg=pgast.StringConstant(val=str(expr.id)), type_name=pgast.TypeName( name=('uuid',) ) ) return result def maybe_unpack_row(expr: pgast.Base) -> Sequence[pgast.BaseExpr]: assert isinstance(expr, pgast.BaseExpr) match expr: case pgast.ImplicitRowExpr(): return expr.args case pgast.RowExpr(): return expr.args return (expr,) def edgedb_func( name: str, *, ctx: context.CompilerContextLevel ) -> tuple[str, ...]: return common.maybe_versioned_name( ('edgedb', 
name), versioned=ctx.env.versioned_stdlib, ) ================================================ FILE: edb/pgsql/compiler/clauses.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Optional, Sequence import random from edb.common import ast as ast_visitor from edb.edgeql import qltypes from edb.ir import ast as irast from edb.ir import utils as irutils from edb.pgsql import ast as pgast from edb.pgsql import types as pg_types from . import astutils from . import context from . import dispatch from . import dml from . import enums as pgce from . import output from . import pathctx from . import relctx from . 
import relgen def get_volatility_ref( path_id: irast.PathId, stmt: pgast.SelectStmt, *, ctx: context.CompilerContextLevel) -> Optional[pgast.BaseExpr]: """Produce an appropriate volatility_ref from a path_id.""" ref: Optional[pgast.BaseExpr] = relctx.maybe_get_path_var( stmt, path_id, aspect=pgce.PathAspect.ITERATOR, ctx=ctx) if not ref: ref = relctx.maybe_get_path_var( stmt, path_id, aspect=pgce.PathAspect.IDENTITY, ctx=ctx) if not ref: rvar = relctx.maybe_get_path_rvar( stmt, path_id, aspect=pgce.PathAspect.VALUE, ctx=ctx) if ( rvar and isinstance(rvar.query, pgast.ReturningQuery) # Expanded inhviews might be unions, which can't naively have # a row_number stuck on; they should be safe to just grab # the path_id value from, though and rvar.tag != 'expanded-inhview' ): # If we are selecting from a nontrivial subquery, manually # add a volatility ref based on row_number. We do it # manually because the row number isn't /really/ the # identity of the set. name = ctx.env.aliases.get('key') rvar.query.target_list.append( pgast.ResTarget( name=name, val=pgast.FuncCall(name=('row_number',), args=[], over=pgast.WindowDef()) ) ) ref = pgast.ColumnRef(name=[rvar.alias.aliasname, name]) else: ref = relctx.maybe_get_path_var( stmt, path_id, aspect=pgce.PathAspect.VALUE, ctx=ctx) return ref def setup_iterator_volatility( iterator: Optional[irast.Set | pgast.IteratorCTE], *, ctx: context.CompilerContextLevel) -> None: if iterator is None: return path_id = iterator.path_id # We use a callback scheme here to avoid inserting volatility ref # columns unless there is actually a volatile operation that # requires it. 
ctx.volatility_ref += ( lambda stmt, xctx: get_volatility_ref(path_id, stmt, ctx=xctx),) def compile_materialized_exprs( query: pgast.SelectStmt, stmt: irast.Stmt, *, ctx: context.CompilerContextLevel) -> None: if not stmt.materialized_sets: return if stmt in ctx.materializing: return with context.output_format(ctx, context.OutputFormat.NATIVE), ( ctx.new()) as matctx: matctx.materializing |= {stmt} matctx.expr_exposed = True # HACK: Sort longer paths before shorter ones # We want foo->bar to appear before foo mat_sets = sorted( (stmt.materialized_sets.values()), key=lambda m: -len(m.materialized.path_id), ) for mat_set in mat_sets: if len(mat_set.uses) <= 1: continue assert mat_set.finalized, "materialized set was not finalized!" if relctx.find_rvar( query, flavor='packed', path_id=mat_set.materialized.path_id, ctx=matctx): continue _compile_materialized_expr(query, mat_set, ctx=matctx) def _compile_materialized_expr( query: pgast.SelectStmt, mat_set: irast.MaterializedSet, *, ctx: context.CompilerContextLevel, ) -> None: mat_ids = set(mat_set.uses) # We pack optional things into arrays also, since it works. # TODO: use NULL? 
card = mat_set.cardinality assert card != qltypes.Cardinality.UNKNOWN is_singleton = card.is_single() and not card.can_be_zero() old_scope = ctx.path_scope ctx.path_scope = old_scope.new_child() for mat_id in mat_ids: for k in old_scope: if k.startswith(mat_id): ctx.path_scope[k] = None mat_qry = relgen.set_as_subquery( mat_set.materialized, as_value=True, ctx=ctx ) if not is_singleton: mat_qry = relctx.set_to_array( path_id=mat_set.materialized.path_id, query=mat_qry, ctx=ctx) if not mat_qry.target_list[0].name: mat_qry.target_list[0].name = ctx.env.aliases.get('v') ref = pgast.ColumnRef( name=[mat_qry.target_list[0].name], is_packed_multi=not is_singleton, ) for mat_id in mat_ids: pathctx.put_path_packed_output(mat_qry, mat_id, ref) mat_rvar = relctx.rvar_for_rel(mat_qry, lateral=True, ctx=ctx) for mat_id in mat_ids: relctx.include_rvar( query, mat_rvar, path_id=mat_id, flavor='packed', update_mask=False, pull_namespace=False, ctx=ctx, ) def compile_iterator_expr( query: pgast.SelectStmt, iterator_expr: irast.Set, *, is_dml: bool, ctx: context.CompilerContextLevel) \ -> pgast.PathRangeVar: assert isinstance(iterator_expr.expr, (irast.GroupStmt, irast.SelectStmt)) ctx.env.binding_dml[iterator_expr.path_id] = irutils.get_dml_sources( iterator_expr, ctx.env.binding_dml ) with ctx.new() as subctx: subctx.expr_exposed = False subctx.rel = query dispatch.visit(iterator_expr, ctx=subctx) iterator_rvar = relctx.get_path_rvar( query, iterator_expr.path_id, aspect=pgce.PathAspect.VALUE, ctx=ctx, ) iterator_query = iterator_rvar.query # If the iterator value is nullable, add a null test. This # makes sure that we don't spuriously produce output when # iterating over optional pointers. 
is_optional = ctx.scope_tree.is_optional(iterator_expr.path_id) if isinstance(iterator_query, pgast.SelectStmt): iterator_var = pathctx.get_path_value_var( iterator_query, path_id=iterator_expr.path_id, env=ctx.env) if not is_optional: if isinstance(iterator_query, pgast.SelectStmt): iterator_var = pathctx.get_path_value_var( iterator_query, path_id=iterator_expr.path_id, env=ctx.env) if iterator_var.nullable: iterator_query.where_clause = astutils.extend_binop( iterator_query.where_clause, pgast.NullTest(arg=iterator_var, negated=True)) elif isinstance(iterator_query, pgast.Relation): # will never be null pass else: raise NotImplementedError() # For DML-containing FOR, regardless of result type, iterators need # their own transient identity for path identity of the # iterator expression in order to maintain correct correlation # for the state of iteration in DML statements, even when # there are duplicates in the iterator. This gets tracked as # a special ITERATOR aspect in order to distinguish it from # actual object identity. # # We also do this for optional iterators, since object # identity isn't safe to use as a volatility ref if the object # might be NULL. 
if is_dml or is_optional: relctx.create_iterator_identity_for_path( iterator_expr.path_id, iterator_query, apply_volatility=is_dml, ctx=subctx) pathctx.put_path_rvar( query, iterator_expr.path_id, iterator_rvar, aspect=pgce.PathAspect.ITERATOR, ) return iterator_rvar def compile_output( ir_set: irast.Set, *, ctx: context.CompilerContextLevel) -> pgast.OutputVar: with ctx.new() as newctx: dispatch.visit(ir_set, ctx=newctx) path_id = ir_set.path_id if (output.in_serialization_ctx(ctx) and newctx.stmt is newctx.toplevel_stmt): val = pathctx.get_path_serialized_output( ctx.rel, path_id, env=ctx.env) else: val = pathctx.get_path_value_output( ctx.rel, path_id, env=ctx.env) return val def compile_volatile_bindings( stmt: irast.Stmt, *, ctx: context.CompilerContextLevel ) -> None: for binding, volatility in (stmt.bindings or ()): # If something we are WITH binding contains DML, we want to # compile it *now*, in the context of its initial appearance # and not where the variable is used. # # Similarly, if something we are WITH binding is volatile and the stmt # contains dml, we similarly want to compile it *now*. # If the binding is a with binding for a DML stmt, manually construct # the CTEs. # # Note: This condition is checked first, because if the binding # *references* DML then contains_dml is true. If the binding is compiled # normally, since the referenced DML was already compiled, the rvar will # be retrieved, and no CTEs will be set up. if volatility.is_volatile() and irutils.contains_dml(stmt): _compile_volatile_binding_for_dml(stmt, binding, ctx=ctx) # For typical DML, just compile it. This will populate dml_stmts with # the CTEs, which will be picked up when the variable is referenced. 
elif irutils.contains_dml(binding): with ctx.substmt() as bctx: dispatch.compile(binding, ctx=bctx) def _compile_volatile_binding_for_dml( stmt: irast.Stmt, binding: irast.Set, *, ctx: context.CompilerContextLevel ) -> None: materialized_set = None if ( stmt.materialized_sets and binding.typeref.id in stmt.materialized_sets ): materialized_set = stmt.materialized_sets[binding.typeref.id] assert materialized_set is not None last_iterator = ctx.enclosing_cte_iterator with ( context.output_format(ctx, context.OutputFormat.NATIVE), ctx.newrel() as matctx ): matctx.materializing |= {stmt} matctx.expr_exposed = True dml.merge_iterator(last_iterator, matctx.rel, ctx=matctx) setup_iterator_volatility(last_iterator, ctx=matctx) _compile_materialized_expr( matctx.rel, materialized_set, ctx=matctx ) # Add iterator identity bind_pathid = ( irast.PathId.new_dummy(ctx.env.aliases.get('bind_path')) ) with matctx.subrel() as bind_pathid_ctx: relctx.create_iterator_identity_for_path( bind_pathid, bind_pathid_ctx.rel, ctx=bind_pathid_ctx ) bind_id_rvar = relctx.rvar_for_rel( bind_pathid_ctx.rel, lateral=True, ctx=matctx ) relctx.include_rvar( matctx.rel, bind_id_rvar, path_id=bind_pathid, ctx=matctx ) bind_cte = pgast.CommonTableExpr( name=ctx.env.aliases.get('bind'), query=matctx.rel, materialized=False, ) bind_iterator = pgast.IteratorCTE( path_id=bind_pathid, cte=bind_cte, parent=last_iterator, iterator_bond=True, ) ctx.toplevel_stmt.append_cte(bind_cte) # Merge the new iterator ctx.path_scope = ctx.path_scope.new_child() dml.merge_iterator(bind_iterator, ctx.rel, ctx=ctx) setup_iterator_volatility(bind_iterator, ctx=ctx) ctx.enclosing_cte_iterator = bind_iterator def compile_filter_clause( ir_set: irast.Set, cardinality: qltypes.Cardinality, *, ctx: context.CompilerContextLevel) -> pgast.BaseExpr: where_clause: pgast.BaseExpr with ctx.new() as ctx1: ctx1.expr_exposed = False assert cardinality != qltypes.Cardinality.UNKNOWN if cardinality.is_single(): where_clause = 
dispatch.compile(ir_set, ctx=ctx1) else: # In WHERE we compile ir.Set as a boolean disjunction: # EXISTS(SELECT FROM SetRel WHERE SetRel.value) with ctx1.subrel() as subctx: dispatch.visit(ir_set, ctx=subctx) wrapper = subctx.rel wrapper.where_clause = pathctx.get_path_value_var( wrapper, ir_set.path_id, env=subctx.env) where_clause = pgast.SubLink(operator="EXISTS", expr=wrapper) return where_clause def compile_orderby_clause( ir_exprs: Sequence[irast.SortExpr], *, ctx: context.CompilerContextLevel) -> list[pgast.SortBy]: sort_clause = [] for expr in ir_exprs: with ctx.new() as orderctx: orderctx.expr_exposed = False # In ORDER BY we compile ir.Set as a subquery: # (SELECT SetRel.value FROM SetRel) subq = relgen.set_as_subquery( expr.expr, as_value=True, ctx=orderctx) # pg apparently can't use indexes for ordering if the body # of an ORDER BY is a subquery, so try to collapse the query # into a simple expression. value = astutils.collapse_query(subq) sortexpr = pgast.SortBy( node=value, dir=expr.direction, nulls=expr.nones_order) sort_clause.append(sortexpr) return sort_clause def compile_limit_offset_clause( ir_set: Optional[irast.Set], *, ctx: context.CompilerContextLevel) -> Optional[pgast.BaseExpr]: if ir_set is None: return None with ctx.new() as ctx1: ctx1.expr_exposed = False # In OFFSET/LIMIT we compile ir.Set as a subquery: # (SELECT SetRel.value FROM SetRel) limit_offset_clause = relgen.set_as_subquery(ir_set, ctx=ctx1) return limit_offset_clause def make_check_scan( check_cte: pgast.CommonTableExpr, *, ctx: context.CompilerContextLevel, ) -> pgast.BaseExpr: return pgast.SelectStmt( target_list=[ pgast.ResTarget( val=pgast.FuncCall(name=('count',), args=[pgast.Star()]), ) ], from_clause=[ relctx.rvar_for_rel(check_cte, ctx=ctx), ], ) def scan_check_ctes( stmt: pgast.Query, check_ctes: list[pgast.CommonTableExpr], *, ctx: context.CompilerContextLevel, ) -> None: if not check_ctes: return # Scan all of the check CTEs to enforce constraints that are # checked 
as explicit queries and not Postgres constraints or # triggers. # To make sure that Postgres can't optimize the checks away, we # reference them in the where clause of an UPDATE to a dummy # table. # Add a big random number, so that different queries should try to # access different "rows" of the table, in case that matters. base_int = random.randint(0, (1 << 60) - 1) val: pgast.BaseExpr = pgast.NumericConstant(val=str(base_int)) for check_cte in check_ctes: # We want the CTE to be MATERIALIZED, because otherwise # Postgres might not fully evaluate all its columns when # scanning it. check_cte.materialized = True check = make_check_scan(check_cte, ctx=ctx) val = pgast.Expr(name="+", lexpr=val, rexpr=check) update_query = pgast.UpdateStmt( targets=[pgast.UpdateTarget( name='flag', val=pgast.BooleanConstant(val=True) )], relation=pgast.RelRangeVar(relation=pgast.Relation( name='_dml_dummy')), where_clause=pgast.Expr( name="=", lexpr=pgast.ColumnRef(name=["id"]), rexpr=val, ) ) stmt.append_cte(pgast.CommonTableExpr( query=update_query, name=ctx.env.aliases.get(hint='check_scan') )) def insert_ctes( stmt: pgast.Query, ctx: context.CompilerContextLevel ) -> None: if stmt.ctes is None: stmt.ctes = [] stmt.ctes[:0] = [ *ctx.param_ctes.values(), *ctx.ptr_inheritance_ctes.values(), *ctx.ordered_type_ctes, ] def fini_toplevel( stmt: pgast.Query, ctx: context.CompilerContextLevel) -> None: scan_check_ctes(stmt, ctx.env.check_ctes, ctx=ctx) # Type rewrites and inheritance CTEs go first. insert_ctes(stmt, ctx) if ctx.env.named_param_prefix is None: # Adding unused parameters into a CTE # Find the used parameters by searching the query, so we don't # get confused if something has been compiled but then omitted # from the output for some reason. 
param_refs = ast_visitor.find_children(stmt, pgast.ParamRef) used = {param_ref.number for param_ref in param_refs} targets = [] for param in ctx.env.query_params: pgparam = ctx.argmap[param.name] if pgparam.index in used or param.sub_params: continue targets.append(pgast.ResTarget(val=pgast.TypeCast( arg=pgast.ParamRef(number=pgparam.index), type_name=pgast.TypeName( name=pg_types.pg_type_from_ir_typeref(param.ir_type) ) ))) if isinstance(param, irast.Global) and param.has_present_arg: targets.append(pgast.ResTarget(val=pgast.TypeCast( arg=pgast.ParamRef(number=pgparam.index + 1), type_name=pgast.TypeName(name=('bool',)), ))) if targets: stmt.append_cte( pgast.CommonTableExpr( name="__unused_vars", query=pgast.SelectStmt(target_list=targets) ) ) def populate_argmap( params: list[irast.Param], globals: list[irast.Global], server_param_conversion_params: list[irast.Param], *, ctx: context.CompilerContextLevel, ) -> None: physical_index = 1 logical_index = 1 for map_extra in (False, True): for param in params: if ( ctx.env.named_param_prefix is not None and not param.name.isdecimal() ): continue if param.name.startswith('__edb_arg_') != map_extra: continue ctx.argmap[param.name] = pgast.Param( index=physical_index, logical_index=logical_index, required=param.required, ) if not param.sub_params: physical_index += 1 if not param.is_sub_param: logical_index += 1 for param in globals: ctx.argmap[param.name] = pgast.Param( index=physical_index, required=param.required, logical_index=-1, ) physical_index += 1 if param.has_present_arg: ctx.argmap[param.name + "present__"] = pgast.Param( index=physical_index, required=True, logical_index=-1, ) physical_index += 1 for param in server_param_conversion_params: ctx.argmap[param.name] = pgast.Param( index=physical_index, logical_index=logical_index, required=param.required, ) if not param.sub_params: physical_index += 1 if not param.is_sub_param: logical_index += 1 ================================================ FILE: 
edb/pgsql/compiler/config.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from edb import errors from edb.ir import ast as irast from edb.ir import typeutils as irtyputils from edb.edgeql import ast as qlast from edb.edgeql import qltypes from edb.schema import casts as s_casts from edb.schema import name as sn from edb.pgsql import ast as pgast from edb.pgsql import common from . import astutils from . import context from . import dispatch from . import pathctx from . import relctx from . 
import output @dispatch.compile.register def compile_ConfigSet( op: irast.ConfigSet, *, ctx: context.CompilerContextLevel, ) -> pgast.BaseExpr: val = _compile_config_value(op, ctx=ctx) result: pgast.BaseExpr if op.scope is qltypes.ConfigScope.INSTANCE and op.backend_setting: if not ctx.env.backend_runtime_params.has_configfile_access: raise errors.UnsupportedBackendFeatureError( "configuring backend parameters via CONFIGURE INSTANCE" " is not supported by the current backend" ) result = pgast.AlterSystem( name=op.backend_setting, value=val, ) elif op.scope is qltypes.ConfigScope.DATABASE and op.backend_setting: if not isinstance(val, pgast.StringConstant): val = pgast.TypeCast( arg=val, type_name=pgast.TypeName(name=('text',)), ) fcall = pgast.FuncCall( name=astutils.edgedb_func('_alter_current_database_set', ctx=ctx), args=[pgast.StringConstant(val=op.backend_setting), val], ) result = output.wrap_script_stmt( pgast.SelectStmt(target_list=[pgast.ResTarget(val=fcall)]), suppress_all_output=True, env=ctx.env, ) elif op.scope is qltypes.ConfigScope.SESSION and op.backend_setting: if not isinstance(val, pgast.StringConstant): val = pgast.TypeCast( arg=val, type_name=pgast.TypeName(name=('text',)), ) fcall = pgast.FuncCall( name=('pg_catalog', 'set_config'), args=[ pgast.StringConstant(val=op.backend_setting), val, pgast.BooleanConstant(val=False), ], ) result = output.wrap_script_stmt( pgast.SelectStmt(target_list=[pgast.ResTarget(val=fcall)]), suppress_all_output=True, env=ctx.env, ) elif op.scope is qltypes.ConfigScope.INSTANCE: result_row = pgast.RowExpr( args=[ pgast.StringConstant(val='SET'), pgast.StringConstant(val=str(op.scope)), pgast.StringConstant(val=op.name), val, ] ) result = pgast.FuncCall( name=('jsonb_build_array',), args=result_row.args, null_safe=True, ser_safe=True, ) result = pgast.SelectStmt( target_list=[ pgast.ResTarget( val=result, ), ], ) elif op.scope is qltypes.ConfigScope.SESSION: result = pgast.InsertStmt( relation=pgast.RelRangeVar( 
relation=pgast.Relation( name='_edgecon_state', ), ), select_stmt=pgast.SelectStmt( values=[ pgast.ImplicitRowExpr( args=[ pgast.StringConstant( val=op.name, ), val, pgast.StringConstant( val='C', ), ] ) ] ), cols=[ pgast.InsertTarget(name='name'), pgast.InsertTarget(name='value'), pgast.InsertTarget(name='type'), ], on_conflict=pgast.OnConflictClause( action=pgast.OnConflictAction.DO_UPDATE, target=pgast.OnConflictTarget( index_elems=[ pgast.IndexElem(expr=pgast.ColumnRef(name=['name'])), pgast.IndexElem(expr=pgast.ColumnRef(name=['type'])), ], ), update_list=[ pgast.MultiAssignRef( columns=['value'], source=pgast.RowExpr( args=[ val, ], ), ), ], ), ) elif op.scope is qltypes.ConfigScope.GLOBAL: result_row = pgast.RowExpr( args=[ pgast.StringConstant(val='SET'), pgast.StringConstant(val=str(op.scope)), pgast.StringConstant(val=op.name), val, ] ) build_array = pgast.FuncCall( name=('jsonb_build_array',), args=result_row.args, null_safe=True, ser_safe=True, ) result = pgast.SelectStmt( target_list=[pgast.ResTarget(val=build_array)], ) elif op.scope is qltypes.ConfigScope.DATABASE: result = pgast.InsertStmt( relation=pgast.RelRangeVar( relation=pgast.Relation( name='_db_config', schemaname='edgedb', ), ), select_stmt=pgast.SelectStmt( values=[ pgast.ImplicitRowExpr( args=[ pgast.StringConstant( val=op.name, ), val, ] ) ] ), cols=[ pgast.InsertTarget(name='name'), pgast.InsertTarget(name='value'), ], on_conflict=pgast.OnConflictClause( action=pgast.OnConflictAction.DO_UPDATE, target=pgast.OnConflictTarget( index_elems=[ pgast.IndexElem(expr=pgast.ColumnRef(name=['name'])), ], ), update_list=[ pgast.MultiAssignRef( columns=['value'], source=pgast.RowExpr( args=[ val, ], ), ), ], ), ) else: raise AssertionError(f'unexpected configuration scope: {op.scope}') return result @dispatch.compile.register def compile_ConfigReset( op: irast.ConfigReset, *, ctx: context.CompilerContextLevel, ) -> pgast.BaseExpr: stmt: pgast.BaseExpr if op.scope is qltypes.ConfigScope.INSTANCE and 
op.backend_setting: stmt = pgast.AlterSystem( name=op.backend_setting, value=None, ) elif op.scope is qltypes.ConfigScope.DATABASE and op.backend_setting: fcall = pgast.FuncCall( name=astutils.edgedb_func('_alter_current_database_set', ctx=ctx), args=[ pgast.StringConstant(val=op.backend_setting), pgast.NullConstant(), ], ) stmt = output.wrap_script_stmt( pgast.SelectStmt(target_list=[pgast.ResTarget(val=fcall)]), suppress_all_output=True, env=ctx.env, ) elif op.scope is qltypes.ConfigScope.SESSION and op.backend_setting: fcall = pgast.FuncCall( name=('pg_catalog', 'set_config'), args=[ pgast.StringConstant(val=op.backend_setting), pgast.NullConstant(), pgast.BooleanConstant(val=False), ], ) stmt = output.wrap_script_stmt( pgast.SelectStmt(target_list=[pgast.ResTarget(val=fcall)]), suppress_all_output=True, env=ctx.env, ) elif op.scope is qltypes.ConfigScope.INSTANCE: if op.selector is None: # Scalar reset result_row = pgast.RowExpr( args=[ pgast.StringConstant(val='RESET'), pgast.StringConstant(val=str(op.scope)), pgast.StringConstant(val=op.name), pgast.NullConstant(), ] ) rvar = None else: with context.output_format(ctx, context.OutputFormat.JSONB): selector = dispatch.compile(op.selector, ctx=ctx) assert isinstance(selector, pgast.SelectStmt), \ "expected ast.SelectStmt" target = selector.target_list[0] if not target.name: target = selector.target_list[0] = pgast.ResTarget( name=ctx.env.aliases.get('res'), val=target.val, ) assert target.name is not None rvar = relctx.rvar_for_rel(selector, ctx=ctx) result_row = pgast.RowExpr( args=[ pgast.StringConstant(val='REM'), pgast.StringConstant(val=str(op.scope)), pgast.StringConstant(val=op.name), astutils.get_column(rvar, target.name), ] ) result = pgast.FuncCall( name=('jsonb_build_array',), args=result_row.args, null_safe=True, ser_safe=True, ) stmt = pgast.SelectStmt( target_list=[ pgast.ResTarget( val=result, ), ], ) if rvar is not None: stmt.from_clause = [rvar] elif op.scope is qltypes.ConfigScope.DATABASE and 
op.selector is None: stmt = pgast.DeleteStmt( relation=pgast.RelRangeVar( relation=pgast.Relation( name='_db_config', schemaname='edgedb', ), ), where_clause=astutils.new_binop( lexpr=pgast.ColumnRef(name=['name']), rexpr=pgast.StringConstant(val=op.name), op='=', ), ) elif op.scope is qltypes.ConfigScope.DATABASE and op.selector is not None: # For FILTERed RESET on the database, we have to do a decent # amount of work to actually delete the RESET objects from the # json config blobs. # # This is because the server isn't set up to write back just # the changed parts of the config based on interpreting the output, # so instead we do all the work here. with context.output_format(ctx, context.OutputFormat.JSONB): selector = dispatch.compile(op.selector, ctx=ctx) assert isinstance(selector, pgast.SelectStmt), \ "expected ast.SelectStmt" target = selector.target_list[0] if not target.name: target = selector.target_list[0] = pgast.ResTarget( name=ctx.env.aliases.get('res'), val=target.val, ) assert target.name is not None rvar = relctx.rvar_for_rel(selector, ctx=ctx) sel_expr = op.selector.expr assert isinstance(sel_expr, irast.SelectStmt) sel_expr = sel_expr.result.expr assert isinstance(sel_expr, irast.SelectStmt) # Grab all the non-link properties of the object as keys. We # could just do the exclusive ones, but this works too and we # have the information at hand. # XXX: Do we need to consider _tname also? 
keys = [ el.expr.ptrref.shortname.name for el, op in sel_expr.result.shape if op == qlast.ShapeOp.ASSIGN and not irtyputils.is_object(el.expr.ptrref.out_target) ] newval = pgast.SelectStmt( target_list=[pgast.ResTarget( val=pgast.FuncCall( name=('jsonb_agg',), args=[pgast.ColumnRef(name=['ov', 'value'])], ), )], from_clause=[ pgast.RangeFunction( lateral=True, alias=pgast.Alias(aliasname='ov'), functions=[pgast.FuncCall( name=('jsonb_array_elements',), args=[pgast.ColumnRef(name=['value'])], )], ), ], where_clause=( pgast.SubLink( operator="NOT EXISTS", expr=pgast.SelectStmt( from_clause=[rvar], where_clause=astutils.extend_binop( None, *[ pgast.Expr( name='=', lexpr=pgast.Expr( name='->', lexpr=pgast.ColumnRef(name=[ rvar.alias.aliasname, target.name, ]), rexpr=pgast.StringConstant(val=key), ), rexpr=pgast.CoalesceExpr( args=[ pgast.Expr( name='->', lexpr=pgast.ColumnRef(name=[ 'ov', 'value' ]), rexpr=pgast.StringConstant( val=key ), ), pgast.TypeCast( arg=pgast.StringConstant( val='null'), type_name=pgast.TypeName( name=('jsonb',), ), ), ] ) ) for key in keys ], ) ) ) ), ) stmt = pgast.UpdateStmt( targets=[pgast.UpdateTarget( name='value', val=newval, )], relation=pgast.RelRangeVar( relation=pgast.Relation( name='_db_config', schemaname='edgedb', ), ), where_clause=astutils.new_binop( lexpr=pgast.ColumnRef(name=['name']), rexpr=pgast.StringConstant(val=op.name), op='=', ), returning_list=[pgast.ResTarget( val=pgast.CaseExpr( args=[ pgast.CaseWhen( expr=pgast.NullTest( arg=pgast.ColumnRef(name=['value']) ), result=pgast.FuncCall( name=('jsonb_build_array',), args=[ pgast.StringConstant(val='RESET'), pgast.StringConstant(val=str(op.scope)), pgast.StringConstant(val=op.name), pgast.NullConstant(), ], ) ), ], defresult=pgast.FuncCall( name=('jsonb_build_array',), args=[ pgast.StringConstant(val='SET'), pgast.StringConstant(val=str(op.scope)), pgast.StringConstant(val=op.name), pgast.ColumnRef(name=['value']), ], ) ) )], ) elif op.scope is qltypes.ConfigScope.SESSION: 
stmt = pgast.DeleteStmt( relation=pgast.RelRangeVar( relation=pgast.Relation( name='_edgecon_state', ), ), where_clause=astutils.new_binop( lexpr=astutils.new_binop( lexpr=pgast.ColumnRef(name=['name']), rexpr=pgast.StringConstant(val=op.name), op='=', ), rexpr=astutils.new_binop( lexpr=pgast.ColumnRef(name=['type']), rexpr=pgast.StringConstant(val='C'), op='=', ), op='AND', ) ) elif op.scope is qltypes.ConfigScope.GLOBAL: stmt = pgast.SelectStmt( where_clause=pgast.BooleanConstant(val=False) ) else: raise AssertionError(f'unexpected configuration scope: {op.scope}') return stmt @dispatch.compile.register def compile_ConfigInsert( stmt: irast.ConfigInsert, *, ctx: context.CompilerContextLevel ) -> pgast.BaseExpr: with ctx.new() as subctx: with context.output_format(ctx, context.OutputFormat.JSONB): subctx.expr_exposed = True rewritten = _rewrite_config_insert(stmt.expr, ctx=subctx) dispatch.compile(rewritten, ctx=subctx) return pathctx.get_path_serialized_output( ctx.rel, stmt.expr.path_id, env=ctx.env) def _rewrite_config_insert( ir_set: irast.Set, *, ctx: context.CompilerContextLevel ) -> irast.Set: overwrite_query = pgast.SelectStmt() id_expr = pgast.FuncCall( name=astutils.edgedb_func('uuid_generate_v1mc', ctx=ctx), args=[], ) pathctx.put_path_identity_var( overwrite_query, ir_set.path_id, id_expr, force=True ) pathctx.put_path_value_var( overwrite_query, ir_set.path_id, id_expr, force=True ) pathctx.put_path_source_rvar( overwrite_query, ir_set.path_id, relctx.rvar_for_rel(pgast.NullRelation(), ctx=ctx), ) relctx.add_type_rel_overlay( ir_set.typeref, context.OverlayOp.REPLACE, overwrite_query, path_id=ir_set.path_id, ctx=ctx, ) # Config objects have derived computed ids, # so the autogenerated id must not be returned. 
ir_set.shape = tuple(filter( lambda el: ( el[0].expr.ptrref.shortname.name != 'id' ), ir_set.shape, )) for el, _ in ir_set.shape: if isinstance(el.expr.expr, irast.InsertStmt): el.shape = tuple(filter( lambda e: ( e[0].expr.ptrref.shortname.name != 'id' ), el.shape, )) result = _rewrite_config_insert(el.expr.expr.subject, ctx=ctx) el.expr.expr = irast.SelectStmt( result=result, parent_stmt=el.expr.expr.parent_stmt, ) return ir_set def _compile_config_value( op: irast.ConfigSet, *, ctx: context.CompilerContextLevel, ) -> pgast.BaseExpr: val: pgast.BaseExpr expr = op.backend_expr or op.expr with ctx.new() as subctx: if op.backend_setting or op.scope == qltypes.ConfigScope.GLOBAL: output_format = context.OutputFormat.NATIVE else: output_format = context.OutputFormat.JSONB with context.output_format(ctx, output_format): if isinstance(expr.expr, irast.EmptySet): # Special handling for empty sets, because we want a # singleton representation of the value and not an empty rel # in this context. if op.cardinality is qltypes.SchemaCardinality.One: val = pgast.NullConstant() elif subctx.env.output_format is context.OutputFormat.JSONB: val = pgast.TypeCast( arg=pgast.StringConstant(val='[]'), type_name=pgast.TypeName( name=('jsonb',), ), ) else: val = pgast.TypeCast( arg=pgast.ArrayExpr(elements=[]), type_name=pgast.TypeName( name=('text[]',), ), ) else: val = dispatch.compile(expr, ctx=subctx) assert isinstance(val, pgast.SelectStmt), "expected SelectStmt" pathctx.get_path_serialized_output( val, expr.path_id, env=ctx.env) if op.cardinality is qltypes.SchemaCardinality.Many: val = output.aggregate_json_output( val, expr, env=ctx.env) # For globals, we need to output the binary encoding so that we # can just hand it back to the server. 
We abuse `record_send` to # act as a generic `_send` function if op.scope is qltypes.ConfigScope.GLOBAL: val = pgast.FuncCall( name=('substring',), args=[ pgast.FuncCall( name=('record_send',), args=[pgast.RowExpr(args=[val])], ), # The first 8 bytes are header, then 4 bytes are the length # of our element, then the encoding of our actual element. # We include the length so we can distinguish NULL (len=-1) # from empty strings and the like (len=0). pgast.NumericConstant(val="9"), ], ) cast_name = s_casts.get_cast_fullname_from_names( sn.QualName('std', 'bytes'), sn.QualName('std', 'json')) val = pgast.FuncCall( name=common.get_cast_backend_name( cast_name, aspect='function', versioned=ctx.env.versioned_stdlib, ), args=[val], ) if op.backend_setting and op.scope is qltypes.ConfigScope.INSTANCE: assert isinstance(val, pgast.SelectStmt) and len(val.target_list) == 1 val = val.target_list[0].val if isinstance(val, pgast.TypeCast): val = val.arg if not isinstance(val, pgast.BaseConstant): raise AssertionError('value is not a constant in ConfigSet') return val def top_output_as_config_op( ir_set: irast.Set, stmt: pgast.SelectStmt, *, env: context.Environment ) -> pgast.Query: assert isinstance(ir_set.expr, irast.ConfigCommand) op = ir_set.expr alias = env.aliases.get('cfg') cte = pgast.CommonTableExpr(query=stmt, name=alias) ctes = [cte] subrvar = relctx.rvar_for_rel(cte, env=env) stmt_res = stmt.target_list[0] if stmt_res.name is None: stmt_res = stmt.target_list[0] = pgast.ResTarget( name=env.aliases.get('v'), val=stmt_res.val, ) assert stmt_res.name is not None val = pgast.ColumnRef(name=[stmt_res.name]) # FIXME: Can the duplication with other db cases be reduced? 
if op.scope is qltypes.ConfigScope.DATABASE: sval = pgast.SelectStmt( target_list=[pgast.ResTarget(val=val)], from_clause=[subrvar]) ins_val = pgast.FuncCall( name=('jsonb_build_array',), args=[sval], null_safe=True, ser_safe=True, ) old_val = pgast.CoalesceExpr( args=[ pgast.ColumnRef(name=['edgedb', '_db_config', 'value']), pgast.TypeCast( arg=pgast.StringConstant(val='[]'), type_name=pgast.TypeName( name=('jsonb',), ), ), ], ) upd_val = pgast.Expr( name='||', lexpr=old_val, rexpr=ins_val, ) ins = pgast.InsertStmt( relation=pgast.RelRangeVar( relation=pgast.Relation( name='_db_config', schemaname='edgedb', ), ), select_stmt=pgast.SelectStmt( values=[ pgast.ImplicitRowExpr( args=[ pgast.StringConstant( val=op.name, ), ins_val, ] ) ], ), cols=[ pgast.InsertTarget(name='name'), pgast.InsertTarget(name='value'), ], on_conflict=pgast.OnConflictClause( action=pgast.OnConflictAction.DO_UPDATE, target=pgast.OnConflictTarget( index_elems=[ pgast.IndexElem(expr=pgast.ColumnRef(name=['name'])), ], ), update_list=[ pgast.MultiAssignRef( columns=['value'], source=pgast.RowExpr( args=[ upd_val, ], ), ), ], ), returning_list=[ pgast.ResTarget( val=pgast.ColumnRef(name=[pgast.Star()]) ) ], ) ctes.append( pgast.CommonTableExpr(query=ins, name=env.aliases.get('ins')) ) subrvar = relctx.rvar_for_rel(ctes[-1], env=env) val = pgast.ColumnRef(name=['value']) if ir_set.expr.scope in ( qltypes.ConfigScope.INSTANCE, qltypes.ConfigScope.DATABASE ): # For database config, we do SET, and we return the entire new # value, in order to avoid race conditions in duplicate # checking. 
command = ( 'SET' if ir_set.expr.scope is qltypes.ConfigScope.DATABASE else 'ADD' ) result_row = pgast.RowExpr( args=[ pgast.StringConstant(val=command), pgast.StringConstant(val=str(ir_set.expr.scope)), pgast.StringConstant(val=ir_set.expr.name), val, ] ) array = pgast.FuncCall( name=('jsonb_build_array',), args=result_row.args, null_safe=True, ser_safe=True, ) result = pgast.SelectStmt( target_list=[ pgast.ResTarget( val=array, ), ], from_clause=[ subrvar, ], ctes=ctes + (stmt.ctes or []), ) stmt.ctes = [] return result else: raise errors.InternalServerError( f'CONFIGURE {ir_set.expr.scope} INSERT is not supported') ================================================ FILE: edb/pgsql/compiler/context.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """IR compiler context.""" from __future__ import annotations from typing import ( Callable, Optional, Mapping, ChainMap, Generator, Sequence, TYPE_CHECKING, ) import collections import contextlib import dataclasses import enum import uuid import immutables as immu from edb.common import compiler from edb.common import enum as s_enum from edb.pgsql import ast as pgast from edb.pgsql import params as pgparams from . import aliases as pg_aliases if TYPE_CHECKING: from edb.ir import ast as irast from . 
import enums as pgce class ContextSwitchMode(enum.Enum): TRANSPARENT = enum.auto() SUBREL = enum.auto() NEWREL = enum.auto() SUBSTMT = enum.auto() NEWSCOPE = enum.auto() class ShapeFormat(enum.Enum): SERIALIZED = enum.auto() FLAT = enum.auto() class OutputFormat(enum.Enum): #: Result data output in PostgreSQL format. NATIVE = enum.auto() #: Result data output as a single JSON string. JSON = enum.auto() #: Result data output as a single PostgreSQL JSONB type value. JSONB = enum.auto() #: Result data output as a JSON string for each element in returned set. JSON_ELEMENTS = enum.auto() #: None mode: query result not returned, cardinality of result set #: is returned instead. NONE = enum.auto() #: Like NATIVE, but objects without an explicit shape are serialized #: as UUIDs. NATIVE_INTERNAL = enum.auto() NO_STMT = pgast.SelectStmt() class OverlayOp(s_enum.StrEnum): UNION = 'union' REPLACE = 'replace' FILTER = 'filter' EXCEPT = 'except' OverlayEntry = tuple[ OverlayOp, pgast.BaseRelation | pgast.CommonTableExpr, 'irast.PathId', ] @dataclasses.dataclass(kw_only=True) class RelOverlays: """Container for relation overlays. These track "overlays" that can be registered for different types, in the context of DML. Consider the query: with X := ( insert Person { name := "Sully", notes := assert_distinct({ (insert Note {name := "1"}), (select Note filter .name = "2"), }), }), select X { name, notes: {name} }; When we go to select X, we find the source of that set without any trouble (it's the result of the actual insert statement, more or less; in any case, it's in a CTE that we then include). Handling the notes is trickier, though: * The links aren't in the link table yet, but only in a CTE. (In similar update cases, with things like +=, they might be mixed between both.) * Some of the actual Note objects aren't in the table yet, just an insert CTE. But some *are*, so we need to union them. 
We solve these problems using overlays: * Whenever we do DML (or reference WITH-bound DML), we register overlays describing the changes done to *all of the enclosing DML*. So here, the Note insert's overlays get registered both for the Note insert and for the Person insert. * When we try to compile a root set or pointer, we see if it is connected to a DML statement, and if so we apply the overlays. The overlay itself is simply a sequence of operations on relations and CTEs that mix in the new data. In the obvious insert cases, these consist of unioning the new data in. This system works decently well but is also a little broken: I think that both the "all of the enclosing DML" and the "see if it is connected to a DML statement" have dangers; see Issue #3030. In relctx, see range_for_material_objtype, range_for_ptrref, and range_from_queryset (which those two call) for details on how overlays are applied. Overlays are added with relctx.add_type_rel_overlay and relctx.add_ptr_rel_overlay. ===== NOTE ON MUTABILITY: In typical use, the overlays are mutable: nested DML adds overlays that are then consumed by code in enclosing contexts. In some places, however, we need to temporarily customize the overlay environment (during policy and trigger compilation, for example). The original version of overlays was implemented as a dict of dicts of lists. Doing temporary customizations required doing at least some copying. Doing a full deep copy always felt excessive but doing anything short of that left me constantly terrified. So instead we represent the overlays as a mutable object that contains immutable maps. When we add overlays, we update the maps and then reassign their values. When we want to do a temporary adjustment, we can cheaply make a fresh RelOverlays object and then modify that without touching the original. """ #: Relations used to "overlay" the main table for #: the type. Mostly used with DML statements. 
type: immu.Map[ Optional[irast.MutatingLikeStmt], immu.Map[ uuid.UUID, tuple[OverlayEntry, ...], ], ] = immu.Map() #: Relations used to "overlay" the main table for #: the pointer. Mostly used with DML statements. ptr: immu.Map[ Optional[irast.MutatingLikeStmt], immu.Map[ tuple[uuid.UUID, str], tuple[ tuple[ OverlayOp, pgast.BaseRelation | pgast.CommonTableExpr, irast.PathId, ], ... ], ], ] = immu.Map() def copy(self) -> RelOverlays: return RelOverlays(type=self.type, ptr=self.ptr) class CompilerContextLevel(compiler.ContextLevel): #: static compilation environment env: Environment #: mapping of named args to position argmap: dict[str, pgast.Param] #: whether compiling in singleton expression mode singleton_mode: bool #: whether compiling a trigger trigger_mode: bool #: the top-level SQL statement toplevel_stmt: pgast.Query #: Record of DML CTEs generated for the corresponding IR DML. #: CTEs generated for DML-containing FOR statements are keyed #: by their iterator set. dml_stmts: dict[irast.MutatingStmt | irast.Set, pgast.CommonTableExpr] #: Inline DML functions may require additional CTEs. #: Record such CTEs as well as the path used by their iterators. #: This ensures CTEs are created only once, and that the correct #: iterator bonds are applied. inline_dml_ctes: dict[ irast.PathId, tuple[irast.PathId, pgast.CommonTableExpr], ] #: SQL statement corresponding to the IR statement #: currently being compiled. 
stmt: pgast.SelectStmt #: Current SQL subquery rel: pgast.SelectStmt #: SQL query hierarchy rel_hierarchy: dict[pgast.Query, pgast.Query] #: CTEs representing decoded parameters param_ctes: dict[str, pgast.CommonTableExpr] #: CTEs representing pointers and their inherited pointers ptr_inheritance_ctes: dict[uuid.UUID, pgast.CommonTableExpr] #: CTEs representing types, when rewritten based on access policy type_rewrite_ctes: dict[FullRewriteKey, pgast.CommonTableExpr] #: A set of type CTEs currently being generated pending_type_rewrite_ctes: set[RewriteKey] #: CTEs representing types and their inherited types type_inheritance_ctes: dict[uuid.UUID, pgast.CommonTableExpr] # Type and type inheritance CTEs in creation order. This ensures type CTEs # referring to other CTEs are in the correct order. ordered_type_ctes: list[pgast.CommonTableExpr] #: The logical parent of the current query in the #: query hierarchy parent_rel: Optional[pgast.Query] #: Query to become current in the next SUBSTMT switch. pending_query: Optional[pgast.SelectStmt] #: Sets currently being materialized materializing: frozenset[irast.Stmt] #: Whether the expression currently being processed is #: directly exposed to the output of the statement. expr_exposed: Optional[bool] #: A hack that indicates a tuple element that should be treated as #: exposed. This enables us to treat 'bar' in (foo, bar).1 as exposed, #: which eta-expansion and some casts rely on. expr_exposed_tuple_cheat: Optional[irast.TupleElement] #: Expression to use to force SQL expression volatility in this context #: (Delayed with a lambda to avoid inserting it when not used.) volatility_ref: tuple[ Callable[[pgast.SelectStmt, CompilerContextLevel], Optional[pgast.BaseExpr]], ...] # Current path_id we are INSERTing, so that we can avoid creating # a bogus volatility ref to it... current_insert_path_id: Optional[irast.PathId] #: Paths, for which semi-join is banned in this context. 
disable_semi_join: frozenset[irast.PathId] #: Paths, which need to be explicitly wrapped into SQL #: optionality scaffolding. force_optional: frozenset[irast.PathId] #: Paths that can be ignored when they appear as the source of a # computable. This is key to optimizing away free object sources in # group by aggregates. skippable_sources: frozenset[irast.PathId] #: Specifies that references to a specific Set must be narrowed #: by only selecting instances of type specified by the mapping value. intersection_narrowing: dict[irast.Set, irast.Set] #: Which SQL query holds the SQL scope for the given PathId path_scope: ChainMap[irast.PathId, Optional[pgast.SelectStmt]] #: Relevant IR scope for this context. scope_tree: irast.ScopeTreeNode #: A stack of dml statements currently being compiled. Used for #: figuring out what to record in type_rel_overlays. dml_stmt_stack: list[irast.MutatingLikeStmt] #: Relations used to "overlay" the main table for #: the type. Mostly used with DML statements. rel_overlays: RelOverlays #: Mapping from path ids to "external" rels given by a particular relation external_rels: Mapping[ irast.PathId, tuple[ pgast.BaseRelation | pgast.CommonTableExpr, tuple[pgce.PathAspect, ...] ] ] #: The CTE and some metadata of any enclosing iterator-like #: construct (which includes iterators, insert/update, and INSERT #: ELSE select clauses) currently being compiled. enclosing_cte_iterator: Optional[pgast.IteratorCTE] #: Sets to force shape compilation on, because the values are #: needed by DML. 
shapes_needed_by_dml: set[irast.Set] def __init__( self, prevlevel: Optional[CompilerContextLevel], mode: ContextSwitchMode, *, env: Optional[Environment] = None, scope_tree: Optional[irast.ScopeTreeNode] = None, ) -> None: if prevlevel is None: assert env is not None assert scope_tree is not None self.env = env self.argmap = collections.OrderedDict() self.singleton_mode = False self.toplevel_stmt = NO_STMT self.stmt = NO_STMT self.rel = NO_STMT self.rel_hierarchy = {} self.param_ctes = {} self.ptr_inheritance_ctes = {} self.type_rewrite_ctes = {} self.pending_type_rewrite_ctes = set() self.type_inheritance_ctes = {} self.ordered_type_ctes = [] self.dml_stmts = {} self.inline_dml_ctes = {} self.parent_rel = None self.pending_query = None self.materializing = frozenset() self.expr_exposed = None self.expr_exposed_tuple_cheat = None self.volatility_ref = () self.current_insert_path_id = None self.disable_semi_join = frozenset() self.force_optional = frozenset() self.skippable_sources = frozenset() self.intersection_narrowing = {} self.path_scope = collections.ChainMap() self.scope_tree = scope_tree self.dml_stmt_stack = [] self.rel_overlays = RelOverlays() self.external_rels = {} self.enclosing_cte_iterator = None self.shapes_needed_by_dml = set() self.trigger_mode = False else: self.env = prevlevel.env self.argmap = prevlevel.argmap self.singleton_mode = prevlevel.singleton_mode self.toplevel_stmt = prevlevel.toplevel_stmt self.stmt = prevlevel.stmt self.rel = prevlevel.rel self.rel_hierarchy = prevlevel.rel_hierarchy self.param_ctes = prevlevel.param_ctes self.ptr_inheritance_ctes = prevlevel.ptr_inheritance_ctes self.type_rewrite_ctes = prevlevel.type_rewrite_ctes self.pending_type_rewrite_ctes = prevlevel.pending_type_rewrite_ctes self.type_inheritance_ctes = prevlevel.type_inheritance_ctes self.ordered_type_ctes = prevlevel.ordered_type_ctes self.dml_stmts = prevlevel.dml_stmts self.inline_dml_ctes = prevlevel.inline_dml_ctes self.parent_rel = 
prevlevel.parent_rel self.pending_query = prevlevel.pending_query self.materializing = prevlevel.materializing self.expr_exposed = prevlevel.expr_exposed self.expr_exposed_tuple_cheat = prevlevel.expr_exposed_tuple_cheat self.volatility_ref = prevlevel.volatility_ref self.current_insert_path_id = prevlevel.current_insert_path_id self.disable_semi_join = prevlevel.disable_semi_join self.force_optional = prevlevel.force_optional self.skippable_sources = prevlevel.skippable_sources self.intersection_narrowing = prevlevel.intersection_narrowing self.path_scope = prevlevel.path_scope self.scope_tree = prevlevel.scope_tree self.dml_stmt_stack = prevlevel.dml_stmt_stack self.rel_overlays = prevlevel.rel_overlays self.enclosing_cte_iterator = prevlevel.enclosing_cte_iterator self.shapes_needed_by_dml = prevlevel.shapes_needed_by_dml self.external_rels = prevlevel.external_rels self.trigger_mode = prevlevel.trigger_mode if mode is ContextSwitchMode.SUBSTMT: if self.pending_query is not None: self.rel = self.pending_query else: self.rel = pgast.SelectStmt() if prevlevel.parent_rel is not None: parent_rel = prevlevel.parent_rel else: parent_rel = prevlevel.rel self.rel_hierarchy[self.rel] = parent_rel self.stmt = self.rel self.pending_query = None self.parent_rel = None elif mode is ContextSwitchMode.SUBREL: self.rel = pgast.SelectStmt() if prevlevel.parent_rel is not None: parent_rel = prevlevel.parent_rel else: parent_rel = prevlevel.rel self.rel_hierarchy[self.rel] = parent_rel self.pending_query = None self.parent_rel = None elif mode is ContextSwitchMode.NEWREL: self.rel = pgast.SelectStmt() self.pending_query = None self.parent_rel = None self.path_scope = collections.ChainMap() self.rel_hierarchy = {} self.scope_tree = prevlevel.scope_tree.root self.volatility_ref = () self.disable_semi_join = frozenset() self.force_optional = frozenset() self.intersection_narrowing = {} self.pending_type_rewrite_ctes = set( prevlevel.pending_type_rewrite_ctes ) elif mode == 
ContextSwitchMode.NEWSCOPE: self.path_scope = prevlevel.path_scope.new_child() def get_current_dml_stmt(self) -> Optional[irast.MutatingLikeStmt]: if len(self.dml_stmt_stack) == 0: return None return self.dml_stmt_stack[-1] def subrel( self, ) -> compiler.CompilerContextManager[CompilerContextLevel]: return self.new(ContextSwitchMode.SUBREL) def newrel( self, ) -> compiler.CompilerContextManager[CompilerContextLevel]: return self.new(ContextSwitchMode.NEWREL) def substmt( self, ) -> compiler.CompilerContextManager[CompilerContextLevel]: return self.new(ContextSwitchMode.SUBSTMT) def newscope( self, ) -> compiler.CompilerContextManager[CompilerContextLevel]: return self.new(ContextSwitchMode.NEWSCOPE) def up_hierarchy( self, n: int, q: Optional[pgast.Query]=None ) -> Optional[pgast.Query]: # mostly intended as a debugging helper q = q or self.rel for _ in range(n): if q: q = self.rel_hierarchy.get(q) return q class CompilerContext(compiler.CompilerContext[CompilerContextLevel]): ContextLevelClass = CompilerContextLevel default_mode = ContextSwitchMode.TRANSPARENT RewriteKey = tuple[uuid.UUID, bool] FullRewriteKey = tuple[ uuid.UUID, bool, Optional[frozenset['irast.MutatingLikeStmt']]] class Environment: """Static compilation environment.""" aliases: pg_aliases.AliasGenerator output_format: Optional[OutputFormat] named_param_prefix: Optional[tuple[str, ...]] ptrref_source_visibility: dict[irast.BasePointerRef, bool] expected_cardinality_one: bool ignore_object_shapes: bool explicit_top_cast: Optional[irast.TypeRef] singleton_mode: bool query_params: list[irast.Param] type_rewrites: dict[RewriteKey, irast.Set] scope_tree_nodes: dict[int, irast.ScopeTreeNode] external_rvars: Mapping[ tuple[irast.PathId, pgce.PathAspect], pgast.PathRangeVar ] materialized_views: dict[uuid.UUID, irast.Set] backend_runtime_params: pgparams.BackendRuntimeParams versioned_stdlib: bool sql_dml_mode: bool #: A list of CTEs that implement constraint validation at the #: query level. 
check_ctes: list[pgast.CommonTableExpr] #: Map of binding path ids to DML used in the binding. I hope and #: suspect that this will grow towards becoming a more general and #: traditional symbol table as I rip out path factoring? Who knows #: though. binding_dml: dict[irast.PathId, Sequence[irast.MutatingLikeStmt]] def __init__( self, *, alias_generator: Optional[pg_aliases.AliasGenerator] = None, output_format: Optional[OutputFormat], named_param_prefix: Optional[tuple[str, ...]], expected_cardinality_one: bool, ignore_object_shapes: bool, singleton_mode: bool, is_explain: bool, explicit_top_cast: Optional[irast.TypeRef], query_params: list[irast.Param], type_rewrites: dict[RewriteKey, irast.Set], scope_tree_nodes: dict[int, irast.ScopeTreeNode], external_rvars: Optional[ Mapping[tuple[irast.PathId, pgce.PathAspect], pgast.PathRangeVar] ] = None, backend_runtime_params: pgparams.BackendRuntimeParams, # XXX: TRAMPOLINE: THIS IS WRONG versioned_stdlib: bool = True, sql_dml_mode: bool = False, ) -> None: self.aliases = alias_generator or pg_aliases.AliasGenerator() self.output_format = output_format self.named_param_prefix = named_param_prefix self.ptrref_source_visibility = {} self.expected_cardinality_one = expected_cardinality_one self.ignore_object_shapes = ignore_object_shapes self.singleton_mode = singleton_mode self.is_explain = is_explain self.explicit_top_cast = explicit_top_cast self.query_params = query_params self.type_rewrites = type_rewrites self.scope_tree_nodes = scope_tree_nodes self.external_rvars = external_rvars or {} self.materialized_views = {} self.check_ctes = [] self.backend_runtime_params = backend_runtime_params self.versioned_stdlib = versioned_stdlib self.sql_dml_mode = sql_dml_mode self.binding_dml = {} # XXX: this context hack is necessary until pathctx is converted # to use context levels instead of using env directly. 
@contextlib.contextmanager def output_format( ctx: CompilerContextLevel, output_format: OutputFormat, ) -> Generator[None, None, None]: original_output_format = ctx.env.output_format original_ignore_object_shapes = ctx.env.ignore_object_shapes ctx.env.output_format = output_format ctx.env.ignore_object_shapes = False try: yield finally: ctx.env.output_format = original_output_format ctx.env.ignore_object_shapes = original_ignore_object_shapes ================================================ FILE: edb/pgsql/compiler/dispatch.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import functools from edb.ir import ast as irast from edb.pgsql import ast as pgast from . import context @functools.singledispatch def compile( ir: irast.Base, *, ctx: context.CompilerContextLevel ) -> pgast.BaseExpr: raise NotImplementedError(f'no IR compiler handler for {ir.__class__}') @functools.singledispatch def visit(ir: irast.Base, *, ctx: context.CompilerContextLevel) -> None: """A compilation version that does not pull the value eagerly.""" compile(ir, ctx=ctx) ================================================ FILE: edb/pgsql/compiler/dml.py ================================================ # # This source file is part of the EdgeDB open source project. 
# # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """IR compiler support for INSERT/UPDATE/DELETE statements.""" # # The processing of the DML statement is done in two parts. # # 1. The statement's *range* query is built: the relation representing # the statement's target Object with any WHERE quals taken into account. # # 2. The statement body is processed to generate a series of # SQL substatements to modify all relations touched by the statement # depending on the link layout. # from __future__ import annotations from typing import ( Optional, Mapping, Sequence, Collection, NamedTuple, ) import immutables as immu from edb.common.typeutils import downcast, not_none from edb.edgeql import ast as qlast from edb.edgeql import qltypes from edb.schema import name as sn from edb.ir import ast as irast from edb.ir import typeutils as irtyputils from edb.ir import utils as irutils from edb.pgsql import ast as pgast from edb.pgsql import types as pg_types from edb.pgsql import common from . import astutils from . import clauses from . import context from . import dispatch from . import enums as pgce from . import output from . import pathctx from . import relctx from . 
import relgen class DMLParts(NamedTuple): dml_ctes: Mapping[ irast.TypeRef, tuple[pgast.CommonTableExpr, pgast.PathRangeVar], ] else_cte: Optional[tuple[pgast.CommonTableExpr, pgast.PathRangeVar]] range_cte: Optional[pgast.CommonTableExpr] def init_dml_stmt( ir_stmt: irast.MutatingStmt, *, ctx: context.CompilerContextLevel, ) -> DMLParts: """Prepare the common structure of the query representing a DML stmt. Args: ir_stmt: IR of the DML statement. Returns: A ``DMLParts`` tuple containing a map of DML CTEs as well as the common range CTE for UPDATE/DELETE statements. """ range_cte: Optional[pgast.CommonTableExpr] range_rvar: Optional[pgast.RelRangeVar] clauses.compile_volatile_bindings(ir_stmt, ctx=ctx) if isinstance(ir_stmt, (irast.UpdateStmt, irast.DeleteStmt)): # UPDATE and DELETE operate over a range, so generate # the corresponding CTE and connect it to the DML statements. range_cte = get_dml_range(ir_stmt, ctx=ctx) range_rvar = pgast.RelRangeVar( relation=range_cte, alias=pgast.Alias( aliasname=ctx.env.aliases.get(hint='range') ) ) else: range_cte = None range_rvar = None top_typeref = ir_stmt.material_type typerefs = [top_typeref] if isinstance(ir_stmt, (irast.UpdateStmt, irast.DeleteStmt)): if top_typeref.union: for component in top_typeref.union: if component.material_type: component = component.material_type typerefs.append(component) typerefs.extend(irtyputils.get_typeref_descendants(component)) typerefs.extend(irtyputils.get_typeref_descendants(top_typeref)) # Only update/delete concrete types. (Except in the degenerate # corner case where there are none, in which case keep using # everything so as to avoid needing a more complex special case.) 
concrete_typerefs = [t for t in typerefs if not t.is_abstract] if concrete_typerefs: typerefs = concrete_typerefs dml_map = {} for typeref in typerefs: if typeref.union: continue if ( isinstance(typeref.name_hint, sn.QualName) and typeref.name_hint.module in ('sys', 'cfg') ): continue dml_cte, dml_rvar = gen_dml_cte( ir_stmt, range_rvar=range_rvar, typeref=typeref, ctx=ctx, ) dml_map[typeref] = (dml_cte, dml_rvar) else_cte = None if ( isinstance(ir_stmt, irast.InsertStmt) and ir_stmt.on_conflict and ir_stmt.on_conflict.else_ir is not None ): dml_cte = pgast.CommonTableExpr( query=pgast.SelectStmt(), name=ctx.env.aliases.get(hint='melse'), for_dml_stmt=ir_stmt, ) dml_rvar = relctx.rvar_for_rel(dml_cte, ctx=ctx) else_cte = (dml_cte, dml_rvar) put_iterator_bond(ctx.enclosing_cte_iterator, ctx.rel) ctx.dml_stmt_stack.append(ir_stmt) return DMLParts( dml_ctes=dml_map, range_cte=range_cte, else_cte=else_cte, ) def gen_dml_union( ir_stmt: irast.MutatingStmt, parts: DMLParts, *, ctx: context.CompilerContextLevel ) -> tuple[pgast.CommonTableExpr, pgast.PathRangeVar]: dml_entries = list(parts.dml_ctes.values()) if parts.else_cte: dml_entries.append(parts.else_cte) if len(dml_entries) == 1: union_cte, union_rvar = dml_entries[0] else: union_components = [] for _, dml_rvar in dml_entries: union_component = pgast.SelectStmt() relctx.include_rvar( union_component, dml_rvar, ir_stmt.subject.path_id, ctx=ctx, ) union_components.append(union_component) qry = pgast.SelectStmt( all=True, larg=union_components[0], ) for union_component in union_components[1:]: qry.op = 'UNION' qry.rarg = union_component qry = pgast.SelectStmt( all=True, larg=qry, ) assert qry.larg put_iterator_bond(ctx.enclosing_cte_iterator, qry.larg) union_cte = pgast.CommonTableExpr( query=qry.larg, name=ctx.env.aliases.get(hint='dml_union'), for_dml_stmt=ctx.get_current_dml_stmt(), ) union_rvar = relctx.rvar_for_rel( union_cte, typeref=ir_stmt.subject.typeref, ctx=ctx, ) ctx.dml_stmts[ir_stmt] = union_cte 
union_cte.output_of_dml = ir_stmt return union_cte, union_rvar def gen_dml_cte( ir_stmt: irast.MutatingStmt, *, range_rvar: Optional[pgast.RelRangeVar], typeref: irast.TypeRef, ctx: context.CompilerContextLevel, ) -> tuple[pgast.CommonTableExpr, pgast.PathRangeVar]: subject_ir_set = ir_stmt.subject subject_path_id = subject_ir_set.path_id dml_stmt: pgast.InsertStmt | pgast.SelectStmt | pgast.DeleteStmt subject_rvar: pgast.BaseRangeVar if isinstance(ir_stmt, irast.InsertStmt): relation = relctx.range_for_typeref( typeref, subject_path_id, for_mutation=True, ctx=ctx, ) assert isinstance(relation, pgast.RelRangeVar), ( "spurious overlay on DML target" ) dml_stmt = pgast.InsertStmt(relation=relation) subject_rvar = relation elif isinstance(ir_stmt, irast.UpdateStmt): # We generate a Select as the initial statement for an update, # since the contents select is the query that needs to join # the range and include policy filters and because we # sometimes end up not needing an UPDATE anyway (if it only # touches link tables). dml_stmt = pgast.SelectStmt() if ctx.env.sql_dml_mode: # We join with the concrete table for this type, but also include # overlays produced by previous DML stmts. 
This is needed for SQL # DML, which needs to update a link table of a newly inserted object subject_rel_overlayed = relctx.range_for_typeref( typeref, subject_path_id, for_mutation=False, include_descendants=False, dml_source=[ k for k in ctx.dml_stmts.keys() if isinstance(k, irast.MutatingLikeStmt) ], ctx=ctx, ) dml_stmt.from_clause.append(subject_rel_overlayed) subject_rvar = subject_rel_overlayed else: # We join with the concrete table for this type relation = relctx.range_for_typeref( typeref, subject_path_id, for_mutation=True, ctx=ctx, ) dml_stmt.from_clause.append(relation) subject_rvar = relation elif isinstance(ir_stmt, irast.DeleteStmt): relation = relctx.range_for_typeref( typeref, subject_path_id, for_mutation=True, ctx=ctx, ) assert isinstance(relation, pgast.RelRangeVar), ( "spurious overlay on DML target" ) dml_stmt = pgast.DeleteStmt(relation=relation) subject_rvar = relation else: raise AssertionError(f'unexpected DML IR: {ir_stmt!r}') pathctx.put_path_value_rvar(dml_stmt, subject_path_id, subject_rvar) pathctx.put_path_source_rvar(dml_stmt, subject_path_id, subject_rvar) # Skip the path bond for inserts, since it doesn't help and # interferes when inserting in an UNLESS CONFLICT ELSE if not isinstance(ir_stmt, irast.InsertStmt): pathctx.put_path_bond(dml_stmt, subject_path_id) dml_cte = pgast.CommonTableExpr( query=dml_stmt, name=ctx.env.aliases.get(hint='dml'), for_dml_stmt=ir_stmt, ) # Due to the fact that DML statements are structured # as a flat list of CTEs instead of nested range vars, # the top level path scope must be empty. The necessary # range vars will be injected explicitly in all rels that # need them. ctx.path_scope.maps.clear() if range_rvar is not None: relctx.pull_path_namespace( target=dml_stmt, source=range_rvar, ctx=ctx) # Auxiliary relations are always joined via the WHERE # clause due to the structure of the UPDATE/DELETE SQL statements. 
assert isinstance(dml_stmt, (pgast.SelectStmt, pgast.DeleteStmt)) dml_stmt.where_clause = astutils.new_binop( lexpr=pathctx.get_rvar_path_identity_var( subject_rvar, subject_path_id, env=ctx.env ), op='=', rexpr=pathctx.get_rvar_path_identity_var( range_rvar, subject_path_id, env=ctx.env ) ) # Do any read-side filtering if pol_expr := ir_stmt.read_policies.get(typeref.id): with ctx.newrel() as sctx: pathctx.put_path_value_rvar( sctx.rel, subject_path_id, subject_rvar ) pathctx.put_path_source_rvar( sctx.rel, subject_path_id, subject_rvar ) val = clauses.compile_filter_clause( pol_expr.expr, pol_expr.cardinality, ctx=sctx ) sctx.rel.target_list.append(pgast.ResTarget(val=val)) dml_stmt.where_clause = astutils.extend_binop( dml_stmt.where_clause, sctx.rel ) # SELECT has "FROM", while DELETE has "USING". if isinstance(dml_stmt, pgast.SelectStmt): dml_stmt.from_clause.append(range_rvar) elif isinstance(dml_stmt, pgast.DeleteStmt): dml_stmt.using_clause.append(range_rvar) dml_rvar = relctx.rvar_for_rel(dml_cte, typeref=typeref, ctx=ctx) return dml_cte, dml_rvar def wrap_dml_cte( ir_stmt: irast.MutatingStmt, dml_cte: pgast.CommonTableExpr, *, ctx: context.CompilerContextLevel, ) -> pgast.PathRangeVar: wrapper = ctx.rel dml_rvar = relctx.rvar_for_rel( dml_cte, typeref=ir_stmt.subject.typeref, ctx=ctx, ) relctx.include_rvar(wrapper, dml_rvar, ir_stmt.subject.path_id, ctx=ctx) pathctx.put_path_bond(wrapper, ir_stmt.subject.path_id) if ctx.dml_stmt_stack: relctx.reuse_type_rel_overlays( dml_source=ir_stmt, dml_stmts=ctx.dml_stmt_stack, ctx=ctx) return dml_rvar def put_iterator_bond( iterator: Optional[pgast.IteratorCTE], select: pgast.Query, ) -> None: if iterator: pathctx.put_path_bond( select, iterator.path_id, iterator=iterator.iterator_bond) def merge_iterator_scope( iterator: Optional[pgast.IteratorCTE], select: pgast.SelectStmt, *, ctx: context.CompilerContextLevel ) -> None: while iterator: ctx.path_scope[iterator.path_id] = select iterator = iterator.parent def 
merge_iterator( iterator: Optional[pgast.IteratorCTE], select: pgast.SelectStmt, *, ctx: context.CompilerContextLevel ) -> Optional[pgast.PathRangeVar]: merge_iterator_scope(iterator, select, ctx=ctx) if iterator: iterator_rvar = relctx.rvar_for_rel(iterator.cte, ctx=ctx) put_iterator_bond(iterator, select) relctx.include_rvar( select, iterator_rvar, aspects=(pgce.PathAspect.VALUE, iterator.aspect) + ( (pgce.PathAspect.SOURCE,) if iterator.path_id.is_objtype_path() else () ), path_id=iterator.path_id, overwrite_path_rvar=True, ctx=ctx) # We need nested iterators to re-export their enclosing # iterators in some cases that the path_id_mask blocks # otherwise. select.path_id_mask.discard(iterator.path_id) # HACK: This is a hack for triggers, to stick __old__ in # as a reference to __new__'s identity for updates/deletes for other_path, aspect in iterator.other_paths: pathctx.put_path_rvar( select, other_path, iterator_rvar, aspect=aspect ) return iterator_rvar else: return None def fini_dml_stmt( ir_stmt: irast.MutatingStmt, parts: DMLParts, *, ctx: context.CompilerContextLevel, ) -> None: union_cte, union_rvar = gen_dml_union(ir_stmt, parts, ctx=ctx) if len(parts.dml_ctes) > 1 or parts.else_cte: ctx.toplevel_stmt.append_cte(union_cte) relctx.include_rvar(ctx.rel, union_rvar, ir_stmt.subject.path_id, ctx=ctx) # Record the effect of this insertion in the relation overlay # context to ensure that the RETURNING clause potentially # referencing this class yields the expected results. dml_stack = ctx.dml_stmt_stack if isinstance(ir_stmt, irast.InsertStmt): # The union CTE might have a SELECT from an ELSE clause, which # we don't actually want to include. 
assert len(parts.dml_ctes) == 1 cte = next(iter(parts.dml_ctes.values()))[0] relctx.add_type_rel_overlay( ir_stmt.subject.typeref, context.OverlayOp.UNION, cte, dml_stmts=dml_stack, path_id=ir_stmt.subject.path_id, ctx=ctx) elif isinstance(ir_stmt, irast.UpdateStmt): base_typeref = ir_stmt.subject.typeref.real_material_type for typeref, (cte, _) in parts.dml_ctes.items(): # Because we have a nice union_cte for the base type, # we don't need to propagate the children overlays to # that type or its ancestors, hence the stop_ref argument. if typeref.id == base_typeref.id: cte = union_cte stop_ref = None else: stop_ref = base_typeref # When the base type is abstract, there will be no CTE for it, # so the overlays of children types have to apply to the whole # ancestry tree. if base_typeref.is_abstract: stop_ref = None # The overlay for update is in two parts: # First, filter out objects that have been updated, then union them # back in. (If we just did union, we'd see the old values also.) relctx.add_type_rel_overlay( typeref, context.OverlayOp.FILTER, cte, stop_ref=stop_ref, dml_stmts=dml_stack, path_id=ir_stmt.subject.path_id, ctx=ctx) relctx.add_type_rel_overlay( typeref, context.OverlayOp.UNION, cte, stop_ref=stop_ref, dml_stmts=dml_stack, path_id=ir_stmt.subject.path_id, ctx=ctx) process_extra_conflicts(ir_stmt=ir_stmt, dml_parts=parts, ctx=ctx) elif isinstance(ir_stmt, irast.DeleteStmt): base_typeref = ir_stmt.subject.typeref.real_material_type for typeref, (cte, _) in parts.dml_ctes.items(): # see above, re: stop_ref if typeref.id == base_typeref.id: cte = union_cte stop_ref = None else: stop_ref = base_typeref relctx.add_type_rel_overlay( typeref, context.OverlayOp.EXCEPT, cte, stop_ref=stop_ref, dml_stmts=dml_stack, path_id=ir_stmt.subject.path_id, ctx=ctx) clauses.compile_output(ir_stmt.result, ctx=ctx) ctx.dml_stmt_stack.pop() def get_dml_range( ir_stmt: irast.UpdateStmt | irast.DeleteStmt, *, ctx: context.CompilerContextLevel, ) -> pgast.CommonTableExpr: 
"""Create a range CTE for the given DML statement. Args: ir_stmt: IR of the DML statement. Returns: A CommonTableExpr node representing the range affected by the DML statement. """ target_ir_set = ir_stmt.subject ir_qual_expr = ir_stmt.where ir_qual_card = ir_stmt.where_card with ctx.newrel() as subctx: subctx.expr_exposed = False range_stmt = subctx.rel merge_iterator(ctx.enclosing_cte_iterator, range_stmt, ctx=subctx) dispatch.visit(target_ir_set, ctx=subctx) relgen.ensure_source_rvar(target_ir_set, range_stmt, ctx=subctx) pathctx.get_path_identity_output( range_stmt, target_ir_set.path_id, env=subctx.env) if ir_qual_expr is not None: with subctx.new() as wctx: clauses.setup_iterator_volatility(target_ir_set, ctx=wctx) range_stmt.where_clause = astutils.extend_binop( range_stmt.where_clause, clauses.compile_filter_clause( ir_qual_expr, ir_qual_card, ctx=wctx)) range_stmt.path_id_mask.discard(target_ir_set.path_id) pathctx.put_path_bond(range_stmt, target_ir_set.path_id) range_cte = pgast.CommonTableExpr( query=range_stmt, name=ctx.env.aliases.get('range'), for_dml_stmt=ctx.get_current_dml_stmt(), ) return range_cte def compile_iterator_cte( iterator_set: irast.Set, *, ctx: context.CompilerContextLevel ) -> Optional[pgast.IteratorCTE]: last_iterator = ctx.enclosing_cte_iterator # If this iterator has already been compiled to a CTE, use # that CTE instead of recompiling. (This will happen when # a DML-containing FOR loop is WITH bound, for example.) 
if iterator_set in ctx.dml_stmts: iterator_cte = ctx.dml_stmts[iterator_set] return pgast.IteratorCTE( path_id=iterator_set.path_id, cte=iterator_cte, parent=last_iterator, iterator_bond=True) with ctx.newrel() as ictx: ictx.scope_tree = ctx.scope_tree ictx.path_scope[iterator_set.path_id] = ictx.rel # Correlate with enclosing iterators merge_iterator(last_iterator, ictx.rel, ctx=ictx) clauses.setup_iterator_volatility(last_iterator, ctx=ictx) clauses.compile_iterator_expr( ictx.rel, iterator_set, is_dml=True, ctx=ictx) if iterator_set.path_id.is_objtype_path(): relgen.ensure_source_rvar(iterator_set, ictx.rel, ctx=ictx) ictx.rel.path_id = iterator_set.path_id pathctx.put_path_bond(ictx.rel, iterator_set.path_id, iterator=True) iterator_cte = pgast.CommonTableExpr( query=ictx.rel, name=ctx.env.aliases.get('iter'), for_dml_stmt=ctx.get_current_dml_stmt(), ) ictx.toplevel_stmt.append_cte(iterator_cte) ctx.dml_stmts[iterator_set] = iterator_cte return pgast.IteratorCTE( path_id=iterator_set.path_id, cte=iterator_cte, parent=last_iterator, iterator_bond=True, ) def _mk_dynamic_get_path( ptr_map: dict[sn.Name, pgast.BaseExpr], typeref: irast.TypeRef, fallback_rvar: Optional[pgast.PathRangeVar] = None, ) -> pgast.DynamicRangeVarFunc: """A dynamic rvar function for insert/update. It returns values out of a select purely based on material rptr, as if it was a base relation. This is to make it easy for access policies to operate on the results. """ def dynamic_get_path( rel: pgast.Query, path_id: irast.PathId, *, flavor: str, aspect: str, env: context.Environment ) -> Optional[pgast.BaseExpr | pgast.PathRangeVar]: if ( flavor != 'normal' or aspect not in ( pgce.PathAspect.VALUE, pgce.PathAspect.IDENTITY ) ): return None if rptr := path_id.rptr(): if ret := ptr_map.get(rptr.real_material_ptr.name): return ret if rptr.real_material_ptr.shortname.name == '__type__': return astutils.compile_typeref(typeref) # If a fallback rvar is specified, defer to that. 
        # This is used in rewrites to go back to the original
        if fallback_rvar:
            return fallback_rvar
        if not rptr:
            raise LookupError('only pointers appear in insert fallback')
        # Properties that aren't specified are {}
        return pgast.NullConstant()
    return dynamic_get_path


def process_insert_body(
    *,
    ir_stmt: irast.InsertStmt,
    insert_cte: pgast.CommonTableExpr,
    dml_parts: DMLParts,
    ctx: context.CompilerContextLevel,
) -> None:
    """Generate SQL DML CTEs from an InsertStmt IR.

    Args:
        ir_stmt:
            IR of the DML statement.

        insert_cte:
            A CommonTableExpr node representing the SQL INSERT into the main
            relation of the DML subject.

        dml_parts:
            A DMLParts tuple returned by init_dml_stmt(). Its ``else_cte``
            field, if present, is a tuple containing a CommonTableExpr and
            a RangeVar for it, which represent the body of an ELSE clause
            in an UNLESS CONFLICT construct.
    """

    # We build the tuples to insert in a select we put into a CTE
    select = pgast.SelectStmt(target_list=[])

    # The main INSERT query of this statement will always be
    # present to insert at least the `id` property.
    insert_stmt = insert_cte.query
    assert isinstance(insert_stmt, pgast.InsertStmt)

    typeref = ir_stmt.subject.typeref.real_material_type

    # Handle an UNLESS CONFLICT if we need it

    # If there is an UNLESS CONFLICT, we need to know that there is a
    # conflict *before* we execute DML for fields stored in the object
    # itself, so we can prevent that execution from happening. If that
    # is necessary, compile_insert_else_body will generate an iterator
    # CTE with a row for each non-conflicting insert we want to do. We
    # then use that as the iterator for any DML in inline fields.
    #
    # (For DML in the definition of pointers stored in link tables, we
    # don't need to worry about this, because we can run that DML
    # after the enclosing INSERT, using the enclosing INSERT as the
    # iterator.)
on_conflict_fake_iterator = None if ir_stmt.on_conflict: assert not insert_stmt.on_conflict on_conflict_fake_iterator = compile_insert_else_body( insert_stmt, ir_stmt, ir_stmt.on_conflict, ctx.enclosing_cte_iterator, dml_parts.else_cte, ctx=ctx, ) iterator = ctx.enclosing_cte_iterator inner_iterator = on_conflict_fake_iterator or iterator # ptr_map needs to be set up in advance of compiling the shape # because defaults might reference earlier pointers. ptr_map: dict[sn.Name, pgast.BaseExpr] = {} # Use a dynamic rvar to return values out of the select purely # based on material rptr, as if it was a base relation. # This is to make it easy for access policies to operate on the result # of the INSERT. fallback_rvar = pgast.DynamicRangeVar( dynamic_get_path=_mk_dynamic_get_path(ptr_map, typeref)) pathctx.put_path_source_rvar( select, ir_stmt.subject.path_id, fallback_rvar ) pathctx.put_path_value_rvar(select, ir_stmt.subject.path_id, fallback_rvar) # compile contents CTE elements: list[tuple[irast.SetE[irast.Pointer], irast.BasePointerRef]] = [] for shape_el, shape_op in ir_stmt.subject.shape: assert shape_op is qlast.ShapeOp.ASSIGN # If the shape element is a linkprop, we do nothing. # It will be picked up by the enclosing DML. if shape_el.path_id.is_linkprop_path(): continue ptrref = shape_el.expr.ptrref if ptrref.material_ptr is not None: ptrref = ptrref.material_ptr assert shape_el.expr.expr elements.append((shape_el, ptrref)) external_inserts = process_insert_shape( ir_stmt, select, ptr_map, elements, iterator, inner_iterator, ctx ) single_external = [ ir for ir in external_inserts if ir.expr.dir_cardinality.is_single() ] # Put the select that builds the tuples to insert into its own CTE. # We do this for two reasons: # 1. Generating the object ids outside of the actual SQL insert allows # us to join any enclosing iterators into any nested external inserts. # 2. We can use the contents CTE to evaluate insert access policies # before we actually try the insert. 
This is important because # otherwise an exclusive constraint could be raised first, # which leaks information. pathctx.put_path_bond(select, ir_stmt.subject.path_id) contents_cte = pgast.CommonTableExpr( query=select, name=ctx.env.aliases.get('ins_contents'), for_dml_stmt=ctx.get_current_dml_stmt(), ) ctx.toplevel_stmt.append_cte(contents_cte) contents_rvar = relctx.rvar_for_rel(contents_cte, ctx=ctx) rewrites = ir_stmt.rewrites and ir_stmt.rewrites.by_type.get(typeref) pol_expr = ir_stmt.write_policies.get(typeref.id) pol_ctx = None if pol_expr or rewrites or single_external: # Create a context for handling policies/rewrites that we will # use later. We do this in advance so that the link update code # can populate overlay fields in it. with ctx.new() as pol_ctx: pol_ctx.rel_overlays = context.RelOverlays() needs_insert_on_conflict = bool( ir_stmt.on_conflict and not on_conflict_fake_iterator) # The first serious bit of trickiness: if there are rewrites, the link # table updates need to be done *before* we compute the rewrites, since # the rewrites might refer to them. # # However, we can't unconditionally do it like this, because we # want to be able to use ON CONFLICT to implement UNLESS CONFLICT # ON when possible, and in that case the link table operations # need to be done after the *actual insert*, because it is the actual # insert that filters out conflicting rows. (This also means that we # can't use ON CONFLICT if there are rewrites.) # # Similar issues obtain with access policies: we can't use ON # CONFLICT if there are access policies, since we can't "see" all # possible conflicting objects. # # We *also* need link tables to go first if there are any single links # with link properties. We do the actual computation for those in a link # table and then join it in to the main table, where it is duplicated. link_ctes = [] def _update_link_tables(inp_cte: pgast.CommonTableExpr) -> None: # Process necessary updates to the link tables. 
for shape_el in external_inserts: link_cte, check_cte = process_link_update( ir_stmt=ir_stmt, ir_set=shape_el, dml_cte=inp_cte, source_typeref=typeref, iterator=iterator, policy_ctx=pol_ctx, ctx=ctx, ) if link_cte: link_ctes.append(link_cte) if check_cte is not None: ctx.env.check_ctes.append(check_cte) if not needs_insert_on_conflict: _update_link_tables(contents_cte) # compile rewrites CTE if rewrites or single_external: rewrites = rewrites or {} assert not needs_insert_on_conflict assert pol_ctx # Now that all the compilation for the INSERT has been done, # apply the tweaked policy overlays. pol_ctx.rel_overlays = update_overlay( ctx.rel_overlays, pol_ctx.rel_overlays ) with pol_ctx.reenter(), pol_ctx.newrel() as rctx: # Pull in ptr rel overlays, so we can see the pointers merge_overlays_globally((ir_stmt,), ctx=rctx) contents_cte, contents_rvar = process_insert_rewrites( ir_stmt, contents_cte=contents_cte, iterator=iterator, inner_iterator=inner_iterator, rewrites=rewrites, single_external=single_external, elements=elements, ctx=rctx, ) # Populate the real insert statement based on the select we generated insert_stmt.cols = [ pgast.InsertTarget(name=name) for value in contents_cte.query.target_list # Filter out generated columns; only keep concrete ones if '~' not in (name := not_none(value.name)) ] insert_stmt.select_stmt = pgast.SelectStmt( target_list=[ pgast.ResTarget(val=col) for col in insert_stmt.cols ], from_clause=[contents_rvar], ) pathctx.put_path_bond(insert_stmt, ir_stmt.subject.path_id) real_insert_cte = pgast.CommonTableExpr( query=insert_stmt, name=ctx.env.aliases.get('ins'), for_dml_stmt=ctx.get_current_dml_stmt(), ) # Create the final CTE for the insert that joins the insert # and the select together. 
with ctx.newrel() as ictx: merge_iterator(iterator, ictx.rel, ctx=ictx) insert_rvar = relctx.rvar_for_rel(real_insert_cte, ctx=ctx) relctx.include_rvar( ictx.rel, insert_rvar, ir_stmt.subject.path_id, ctx=ictx) relctx.include_rvar( ictx.rel, contents_rvar, ir_stmt.subject.path_id, ctx=ictx) # TODO: set up dml_parts with a SelectStmt for inserts always? insert_cte.query = ictx.rel # If there is an ON CONFLICT clause, insert the CTEs now so that the # link inserts can depend on it. Otherwise we have the link updates # depend on the contents cte so that policies can operate before # doing any actual INSERTs. if needs_insert_on_conflict: ctx.toplevel_stmt.append_cte(real_insert_cte) ctx.toplevel_stmt.append_cte(insert_cte) link_op_cte = insert_cte else: link_op_cte = contents_cte if needs_insert_on_conflict: _update_link_tables(link_op_cte) if pol_expr: assert pol_ctx assert not needs_insert_on_conflict with pol_ctx.reenter(): policy_cte = compile_policy_check( contents_cte, ir_stmt, pol_expr, typeref=typeref, ctx=pol_ctx ) force_policy_checks( policy_cte, (insert_stmt,) + tuple(cte.query for cte in link_ctes), ctx=ctx) for link_cte in link_ctes: ctx.toplevel_stmt.append_cte(link_cte) if not needs_insert_on_conflict: ctx.toplevel_stmt.append_cte(real_insert_cte) ctx.toplevel_stmt.append_cte(insert_cte) # XXX: do we need to pass in inner_iterator here? 
process_extra_conflicts(ir_stmt=ir_stmt, dml_parts=dml_parts, ctx=ctx) def process_insert_rewrites( ir_stmt: irast.InsertStmt, *, contents_cte: pgast.CommonTableExpr, iterator: Optional[pgast.IteratorCTE], inner_iterator: Optional[pgast.IteratorCTE], rewrites: irast.RewritesOfType, single_external: list[irast.SetE[irast.Pointer]], elements: Sequence[tuple[irast.SetE[irast.Pointer], irast.BasePointerRef]], ctx: context.CompilerContextLevel, ) -> tuple[pgast.CommonTableExpr, pgast.PathRangeVar]: typeref = ir_stmt.subject.typeref.real_material_type subject_path_id = ir_stmt.subject.path_id rew_stmt = ctx.rel # Use the original contents as the iterator. inner_iterator = pgast.IteratorCTE( path_id=subject_path_id, cte=contents_cte, parent=inner_iterator, other_paths=( (subject_path_id, pgce.PathAspect.IDENTITY), (subject_path_id, pgce.PathAspect.VALUE), (subject_path_id, pgce.PathAspect.SOURCE), ), ) # compile rewrite shape rewrite_elements = list(rewrites.values()) nptr_map: dict[sn.Name, pgast.BaseExpr] = {} process_insert_shape( ir_stmt, rew_stmt, nptr_map, rewrite_elements, iterator, inner_iterator, ctx, force_optional=True, ) iterator_rvar = pathctx.get_path_rvar( rew_stmt, path_id=subject_path_id, aspect=pgce.PathAspect.VALUE ) fallback_rvar = pgast.DynamicRangeVar( dynamic_get_path=_mk_dynamic_get_path(nptr_map, typeref, iterator_rvar) ) pathctx.put_path_source_rvar(rew_stmt, subject_path_id, fallback_rvar) pathctx.put_path_value_rvar(rew_stmt, subject_path_id, fallback_rvar) # If there are any single links that were compiled externally, # populate the field from the link overlays. 
handled = set(rewrites) for ext_ir in single_external: handled.add(ext_ir.expr.ptrref.shortname.name) with ctx.subrel() as ectx: ext_rvar = relctx.new_pointer_rvar( ext_ir, link_bias=True, src_rvar=iterator_rvar, ctx=ectx) relctx.include_rvar(ectx.rel, ext_rvar, ext_ir.path_id, ctx=ectx) # Make the subquery output the target pathctx.get_path_value_output( ectx.rel, ext_ir.path_id, env=ctx.env) ptr_info = pg_types.get_ptrref_storage_info( ext_ir.expr.ptrref, resolve_type=True, link_bias=False) rew_stmt.target_list.append(pgast.ResTarget( name=ptr_info.column_name, val=ectx.rel)) nptr_map[ext_ir.expr.ptrref.real_material_ptr.name] = ectx.rel # Pull in pointers that were not rewritten not_rewritten = { (e, ptrref) for e, ptrref in elements if ptrref.shortname.name not in handled } for e, ptrref in not_rewritten: # FIXME: Duplicates some with process_insert_shape ptr_info = pg_types.get_ptrref_storage_info( ptrref, resolve_type=True, link_bias=False) if ptr_info.table_type == 'ObjectType': val = pathctx.get_path_var( rew_stmt, e.path_id, aspect=pgce.PathAspect.VALUE, env=ctx.env, ) val = output.output_as_value(val, env=ctx.env) rew_stmt.target_list.append(pgast.ResTarget( name=ptr_info.column_name, val=val)) # construct the CTE pathctx.put_path_bond(rew_stmt, ir_stmt.subject.path_id) rewrites_cte = pgast.CommonTableExpr( query=rew_stmt, name=ctx.env.aliases.get('ins_rewrites'), for_dml_stmt=ctx.get_current_dml_stmt(), ) ctx.toplevel_stmt.append_cte(rewrites_cte) rewrites_rvar = relctx.rvar_for_rel(rewrites_cte, ctx=ctx) return rewrites_cte, rewrites_rvar def process_insert_shape( ir_stmt: irast.InsertStmt, select: pgast.SelectStmt, ptr_map: dict[sn.Name, pgast.BaseExpr], elements: Sequence[tuple[irast.SetE[irast.Pointer], irast.BasePointerRef]], iterator: Optional[pgast.IteratorCTE], inner_iterator: Optional[pgast.IteratorCTE], ctx: context.CompilerContextLevel, force_optional: bool=False, ) -> list[irast.SetE[irast.Pointer]]: # Compile the shape external_inserts = [] 
    with ctx.newrel() as subctx:
        subctx.enclosing_cte_iterator = inner_iterator

        subctx.rel = select
        subctx.expr_exposed = False

        inner_iterator_id = None
        if inner_iterator is not None:
            subctx.path_scope = ctx.path_scope.new_child()
            merge_iterator(inner_iterator, select, ctx=subctx)
            inner_iterator_id = relctx.get_path_var(
                select, inner_iterator.path_id,
                aspect=inner_iterator.aspect,
                ctx=ctx)

        # Process the Insert IR and separate links that go
        # into the main table from links that are inserted into
        # a separate link table.
        for element, ptrref in elements:
            ptr_info = pg_types.get_ptrref_storage_info(
                ptrref, resolve_type=True, link_bias=False)
            link_ptr_info = pg_types.get_ptrref_storage_info(
                ptrref, resolve_type=False, link_bias=True)

            # First, process all local link inserts. Single links with
            # link properties are not processed here; we compile those
            # in link tables and then select those back into the main
            # table as a rewrite.
            if not link_ptr_info and ptr_info.table_type == 'ObjectType':
                compile_insert_shape_element(
                    element,
                    ir_stmt=ir_stmt,
                    iterator_id=inner_iterator_id,
                    force_optional=force_optional,
                    ctx=subctx,
                )

                insvalue = pathctx.get_path_value_var(
                    subctx.rel, element.path_id, env=ctx.env)

                if irtyputils.is_tuple(element.typeref):
                    # Tuples require an explicit cast.
insvalue = pgast.TypeCast( arg=output.output_as_value(insvalue, env=ctx.env), type_name=pgast.TypeName( name=ptr_info.column_type, ), ) ptr_map[ptrref.name] = insvalue select.target_list.append(pgast.ResTarget( name=ptr_info.column_name, val=insvalue)) if link_ptr_info and link_ptr_info.table_type == 'link': external_inserts.append(element) put_iterator_bond(iterator, select) for aspect in (pgce.PathAspect.VALUE, pgce.PathAspect.IDENTITY): pathctx._put_path_output_var( select, ir_stmt.subject.path_id, aspect=aspect, var=pgast.ColumnRef(name=['id']), ) return external_inserts def compile_insert_shape_element( shape_el: irast.SetE[irast.Pointer], *, ir_stmt: irast.MutatingStmt, iterator_id: Optional[pgast.BaseExpr], force_optional: bool, ctx: context.CompilerContextLevel, ) -> None: with ctx.new() as insvalctx: if iterator_id is not None: id = iterator_id insvalctx.volatility_ref = (lambda _stmt, _ctx: id,) else: # Single inserts have no need for forced # computable volatility, and, furthermore, # we do not have a valid identity reference # anyway. insvalctx.volatility_ref = () insvalctx.current_insert_path_id = ir_stmt.subject.path_id if shape_el.expr.dir_cardinality.can_be_zero() or force_optional: # If the element can be empty, compile it in a subquery to force it # to be NULL. 
value = relgen.set_as_subquery( shape_el, as_value=True, ctx=insvalctx) pathctx.put_path_value_var(insvalctx.rel, shape_el.path_id, value) else: dispatch.visit(shape_el, ctx=insvalctx) def merge_overlays_globally( ir_stmts: Collection[irast.MutatingLikeStmt | None], *, ctx: context.CompilerContextLevel, ) -> None: ctx.rel_overlays = ctx.rel_overlays.copy() type_overlay = ctx.rel_overlays.type.get(None, immu.Map()) ptr_overlay = ctx.rel_overlays.ptr.get(None, immu.Map()) for ir_stmt in ir_stmts: if not ir_stmt: continue for k, v in ctx.rel_overlays.type.get(ir_stmt, immu.Map()).items(): els = set(type_overlay.get(k, ())) n_els = ( type_overlay.get(k, ()) + tuple(e for e in v if e not in els) ) type_overlay = type_overlay.set(k, n_els) for k2, v2 in ctx.rel_overlays.ptr.get(ir_stmt, immu.Map()).items(): els = set(ptr_overlay.get(k2, ())) n_els = ( ptr_overlay.get(k2, ()) + tuple(e for e in v2 if e not in els) ) ptr_overlay = ptr_overlay.set(k2, n_els) ctx.rel_overlays.type = ctx.rel_overlays.type.set(None, type_overlay) ctx.rel_overlays.ptr = ctx.rel_overlays.ptr.set(None, ptr_overlay) def update_overlay( left: context.RelOverlays, right: context.RelOverlays, ) -> context.RelOverlays: '''Make an updated copy of the left overlay using right.''' left = left.copy() for ir_stmt, tm in right.type.items(): ltm = left.type.get(ir_stmt, immu.Map()) ltm = ltm.update(tm) left.type = left.type.set(ir_stmt, ltm) for ir_stmt, pm in right.ptr.items(): lpm = left.ptr.get(ir_stmt, immu.Map()) lpm = lpm.update(pm) left.ptr = left.ptr.set(ir_stmt, lpm) return left def compile_policy_check( dml_cte: pgast.CommonTableExpr, ir_stmt: irast.MutatingStmt, access_policies: irast.WritePolicies, typeref: irast.TypeRef, *, ctx: context.CompilerContextLevel, ) -> pgast.CommonTableExpr: subject_id = ir_stmt.subject.path_id with ctx.newrel() as ictx: # Pull in ptr rel overlays, so we can see the pointers merge_overlays_globally((ir_stmt,), ctx=ictx) dml_rvar = relctx.rvar_for_rel(dml_cte, ctx=ctx) 
        relctx.include_rvar(ictx.rel, dml_rvar, path_id=subject_id, ctx=ictx)

        # split and compile
        allow, deny = [], []
        for policy in access_policies.policies:
            cond_ref = clauses.compile_filter_clause(
                policy.expr, policy.cardinality, ctx=ictx
            )

            if policy.action == qltypes.AccessPolicyAction.Allow:
                allow.append((policy, cond_ref))
            else:
                deny.append((policy, cond_ref))

        def raise_if(a: pgast.BaseExpr, msg: pgast.BaseExpr) -> pgast.BaseExpr:
            return pgast.FuncCall(
                name=astutils.edgedb_func('raise_on_null', ctx=ctx),
                args=[
                    pgast.FuncCall(
                        name=('nullif',),
                        args=[a, pgast.BooleanConstant(val=True)],
                    ),
                    pgast.StringConstant(val='insufficient_privilege'),
                    pgast.NamedFuncArg(
                        name='msg',
                        val=msg,
                    ),
                    pgast.NamedFuncArg(
                        name='table',
                        val=pgast.StringConstant(val=str(typeref.id)),
                    ),
                ],
            )

        # allow
        if allow:
            allow_conds = (cond for _, cond in allow)
            no_allow_expr: pgast.BaseExpr = astutils.new_unop(
                'NOT', astutils.extend_binop(None, *allow_conds, op='OR')
            )
        else:
            no_allow_expr = pgast.BooleanConstant(val=True)

        # deny
        deny_exprs = (cond for _, cond in deny)

        # message
        if isinstance(ir_stmt, irast.InsertStmt):
            op = 'insert'
        else:
            op = 'update'
        msg = f'access policy violation on {op} of {typeref.name_hint}'

        allow_hints = (pol.error_msg for pol, _ in allow if pol.error_msg)
        allow_hint = '; '.join(allow_hints)

        hints = [(allow_hint, no_allow_expr)] + [
            (pol.error_msg, cond) for pol, cond in deny if pol.error_msg
        ]

        hint = _conditional_string_agg(hints)
        if hint:
            hint = astutils.new_coalesce(
                astutils.extend_concat(' (', hint, ')'),
                pgast.StringConstant(val=''),
            )
            message = astutils.extend_concat(msg, hint)
        else:
            message = astutils.extend_concat(msg)

        ictx.rel.target_list.append(
            pgast.ResTarget(
                name='error',
                val=raise_if(
                    astutils.extend_binop(no_allow_expr, *deny_exprs, op='OR'),
                    msg=message,
                ),
            )
        )

        policy_cte = pgast.CommonTableExpr(
            query=ictx.rel,
            name=ctx.env.aliases.get('policy'),
            materialized=True,
            for_dml_stmt=ctx.get_current_dml_stmt(),
        )
        ictx.toplevel_stmt.append_cte(policy_cte)
return policy_cte def _conditional_string_agg( pairs: Sequence[tuple[Optional[str], pgast.BaseExpr]], ) -> Optional[pgast.BaseExpr]: selects = [ pgast.SelectStmt( target_list=[ pgast.ResTarget( val=pgast.StringConstant(val=str) if str else pgast.NullConstant() ) ], where_clause=cond, ) for str, cond in pairs ] union = astutils.extend_select_op(None, *selects) if not union: return None return pgast.SelectStmt( target_list=[ pgast.ResTarget( val=pgast.FuncCall( name=('string_agg',), args=[ pgast.ColumnRef(name=('error_msg',)), pgast.StringConstant(val='; '), ], ), ) ], from_clause=[ pgast.RangeSubselect( subquery=union, alias=pgast.Alias(aliasname='t', colnames=['error_msg']), ) ], ) def force_policy_checks( policy_cte: pgast.CommonTableExpr, queries: Sequence[pgast.Query], *, ctx: context.CompilerContextLevel) -> None: # The actual DML statements need to be made dependent on the # policy CTE, to ensure that it is evaluated before any # modifications are done. scan = pgast.Expr( name=">", lexpr=clauses.make_check_scan(policy_cte, ctx=ctx), rexpr=pgast.NumericConstant(val="-1"), ) stmt: Optional[pgast.Query] for stmt in queries: if isinstance(stmt, pgast.InsertStmt): stmt = stmt.select_stmt if isinstance(stmt, (pgast.SelectStmt, pgast.UpdateStmt)): stmt.where_clause = astutils.extend_binop( stmt.where_clause, scan ) # If there aren't any update/insert queries to put it into # (because it is just an update with a -=, probably), make it a # normal check CTE. if not queries: ctx.env.check_ctes.append(policy_cte) def insert_needs_conflict_cte( ir_stmt: irast.MutatingStmt, on_conflict: irast.OnConflictClause, *, ctx: context.CompilerContextLevel, ) -> bool: # We need to generate a conflict CTE if it is possible that # the query might generate two conflicting objects. 
if on_conflict.else_fail: return False if on_conflict.always_check or ir_stmt.conflict_checks: return True # We can't use ON CONFLICT if there are access policies # on the type, since UNLESS CONFLICT only should avoid # conflicts with objects that are visible. type_id = ir_stmt.subject.typeref.real_material_type.id if ( (type_id, True) in ctx.env.type_rewrites or (type_id, False) in ctx.env.type_rewrites ): return True # We can't use ON CONFLICT if there are rewrites on the type # because rewrites might reference multi pointers, which means # we need to execute link operations before the final INSERT. if ir_stmt.rewrites and ir_stmt.rewrites.by_type: return True for shape_el, _ in ir_stmt.subject.shape: ptrref = shape_el.expr.ptrref ptr_info = pg_types.get_ptrref_storage_info( ptrref, resolve_type=True, link_bias=False) # We need to generate a conflict CTE if we have a DML containing # pointer stored in the object itself if ( ptr_info.table_type == 'ObjectType' and shape_el.expr.expr and irutils.contains_dml( shape_el.expr.expr, skip_bindings=True, skip_nodes=(ir_stmt.subject,), ) ): return True # If there are any single links with link properties, we need # a conflict CTE, since the link tables have to go before the # insert. 
if ptr_info.table_type == 'ObjectType': link_ptr_info = pg_types.get_ptrref_storage_info( ptrref, resolve_type=True, link_bias=True) if link_ptr_info: return True return False def compile_insert_else_body( stmt: Optional[pgast.InsertStmt], ir_stmt: irast.MutatingStmt, on_conflict: irast.OnConflictClause, enclosing_cte_iterator: Optional[pgast.IteratorCTE], else_cte_rvar: Optional[ tuple[pgast.CommonTableExpr, pgast.PathRangeVar]], *, ctx: context.CompilerContextLevel) -> Optional[pgast.IteratorCTE]: else_select = on_conflict.select_ir else_branch = on_conflict.else_ir else_fail = on_conflict.else_fail # We need to generate a "conflict CTE" that filters out # objects-to-insert that would conflict with existing objects in # three scenarios: # 1) When there is a nested DML operation as part of the value # of a pointer that is stored inline with the object. # This is because we need to prevent that DML from executing # before we have a chance to see what ON CONFLICT does. # 2) When there could be a conflict with an object INSERT/UPDATEd # in this same query. (Either because of FOR or other DML statements.) # This is because we need that to raise a ConstraintError, # which means we can't use ON CONFLICT, and so we need to prevent # the insertion of objects that conflict with existing ones ourselves. # 3) When the type to insert has rewrite rules on it that could # prevent seeing the existing objects, we use conflict ctes # instead of setting ON CONFLICT so that we raise ConstraintError # instead of succeeding. This is partially for compatibility with # cases that have access rules and fall into case 1, where we # must do this, and partly because we would not be able to return # the objects in the ELSE anyway. # # When we need a conflict CTE, we don't use SQL ON CONFLICT. In # cases 2 & 3, that is the whole point, while in case 1 it would # just be superfluous to do so. # # When none of these cases obtain, we use ON CONFLICT because it # ought to be more performant. 
needs_conflict_cte = insert_needs_conflict_cte( ir_stmt, on_conflict, ctx=ctx) if not needs_conflict_cte and not else_fail: target = None if on_conflict.constraint: constraint_name = common.get_constraint_raw_name( on_conflict.constraint.id) target = pgast.OnConflictTarget( constraint_name=f'"{constraint_name}"' ) assert isinstance(stmt, pgast.InsertStmt) stmt.on_conflict = pgast.OnConflictClause( action=pgast.OnConflictAction.DO_NOTHING, target=target, ) if not else_branch and not needs_conflict_cte and not else_fail: return None subject_id = ir_stmt.subject.path_id # Compile the query CTE that selects out the existing rows # that we would conflict with with ctx.newrel() as ictx: ictx.expr_exposed = False ictx.path_scope[subject_id] = ictx.rel compile_insert_else_body_failure_check(ir_stmt, on_conflict, ctx=ictx) merge_iterator(enclosing_cte_iterator, ictx.rel, ctx=ictx) clauses.setup_iterator_volatility(enclosing_cte_iterator, ctx=ictx) dispatch.compile(else_select, ctx=ictx) pathctx.put_path_id_map(ictx.rel, subject_id, else_select.path_id) # Discard else_branch from the path_id_mask to prevent subject_id # from being masked. 
ictx.rel.path_id_mask.discard(else_select.path_id) else_select_cte = pgast.CommonTableExpr( query=ictx.rel, name=ctx.env.aliases.get('else'), for_dml_stmt=ctx.get_current_dml_stmt(), ) if else_fail: ctx.env.check_ctes.append(else_select_cte) ictx.toplevel_stmt.append_cte(else_select_cte) else_select_rvar = relctx.rvar_for_rel(else_select_cte, ctx=ctx) if else_branch: # Compile the body of the ELSE query with ctx.newrel() as ictx: ictx.path_scope[subject_id] = ictx.rel relctx.include_rvar(ictx.rel, else_select_rvar, path_id=else_select.path_id, ctx=ictx) ictx.enclosing_cte_iterator = pgast.IteratorCTE( path_id=else_select.path_id, cte=else_select_cte, parent=enclosing_cte_iterator) dispatch.compile(else_branch, ctx=ictx) pathctx.put_path_id_map(ictx.rel, subject_id, else_branch.path_id) # Discard else_branch from the path_id_mask to prevent subject_id # from being masked. ictx.rel.path_id_mask.discard(else_branch.path_id) assert else_cte_rvar else_branch_cte = else_cte_rvar[0] else_branch_cte.query = ictx.rel ictx.toplevel_stmt.append_cte(else_branch_cte) anti_cte_iterator = None if needs_conflict_cte: # Compile a CTE that matches rows that didn't appear in the # ELSE query of conflicting rows. 
with ctx.newrel() as ictx: merge_iterator(enclosing_cte_iterator, ictx.rel, ctx=ictx) clauses.setup_iterator_volatility(enclosing_cte_iterator, ctx=ictx) # Set up a dummy path to represent all of the rows # that *aren't* being filtered out dummy_pathid = irast.PathId.new_dummy(ctx.env.aliases.get('dummy')) with ictx.subrel() as dctx: dummy_q = dctx.rel relctx.create_iterator_identity_for_path( dummy_pathid, dummy_q, ctx=dctx) dummy_rvar = relctx.rvar_for_rel( dummy_q, lateral=True, ctx=ictx) relctx.include_rvar(ictx.rel, dummy_rvar, path_id=dummy_pathid, ctx=ictx) with ictx.subrel() as subrelctx: subrel = subrelctx.rel relctx.include_rvar(subrel, else_select_rvar, path_id=subject_id, ctx=ictx) # Do the anti-join iter_path_id = ( enclosing_cte_iterator.path_id if enclosing_cte_iterator else None) aspect = ( enclosing_cte_iterator.aspect if enclosing_cte_iterator else pgce.PathAspect.IDENTITY ) relctx.anti_join(ictx.rel, subrel, iter_path_id, aspect=aspect, ctx=ctx) # Package it up as a CTE anti_cte = pgast.CommonTableExpr( query=ictx.rel, name=ctx.env.aliases.get('non_conflict'), for_dml_stmt=ctx.get_current_dml_stmt(), ) ictx.toplevel_stmt.append_cte(anti_cte) anti_cte_iterator = pgast.IteratorCTE( path_id=dummy_pathid, cte=anti_cte, parent=ictx.enclosing_cte_iterator, iterator_bond=True ) return anti_cte_iterator def compile_insert_else_body_failure_check( ir_stmt: irast.MutatingStmt, on_conflict: irast.OnConflictClause, *, ctx: context.CompilerContextLevel) -> None: else_fail = on_conflict.else_fail if not else_fail: return # Copy the type rels from the possibly conflicting earlier DML # into the None overlays so it gets picked up. # Also copy our own overlays, which we care about just for # the pointer overlays. merge_overlays_globally((ir_stmt, else_fail,), ctx=ctx) # Do some work so that we aren't looking at the existing on-disk # data, just newly created data.
overlays_map = ctx.rel_overlays.type.get(None, immu.Map()) for k, overlays in overlays_map.items(): # Strip out filters, which we don't care about in this context overlays = tuple([ (k, r, p) for k, r, p in overlays if k != context.OverlayOp.FILTER ]) # Drop the initial set if overlays and overlays[0][0] == context.OverlayOp.UNION: overlays = ( (context.OverlayOp.REPLACE, *overlays[0][1:]), *overlays[1:] ) overlays_map = overlays_map.set(k, overlays) ctx.rel_overlays.type = ctx.rel_overlays.type.set(None, overlays_map) assert on_conflict.constraint cid = common.get_constraint_raw_name(on_conflict.constraint.id) maybe_raise = pgast.FuncCall( name=astutils.edgedb_func('raise', ctx=ctx), args=[ pgast.TypeCast( arg=pgast.NullConstant(), type_name=pgast.TypeName(name=('text',))), pgast.StringConstant(val='exclusion_violation'), pgast.NamedFuncArg( name='msg', val=pgast.StringConstant( val=( f'duplicate key value violates unique ' f'constraint "{cid}"' ) ), ), pgast.NamedFuncArg( name='constraint', val=pgast.StringConstant(val=f"{cid}") ), ], ) ctx.rel.target_list.append( pgast.ResTarget(name='error', val=maybe_raise) ) def process_update_body( *, ir_stmt: irast.UpdateStmt, update_cte: pgast.CommonTableExpr, dml_parts: DMLParts, typeref: irast.TypeRef, ctx: context.CompilerContextLevel, ) -> None: """Generate SQL DML CTEs from an UpdateStmt IR. Args: ir_stmt: IR of the DML statement. update_cte: CTE representing the SQL UPDATE to the main relation of the UPDATE subject. dml_parts: A DMLParts tuple returned by init_dml_stmt(). typeref: A TypeRef corresponding to the type of a subject being updated by the update_cte.
""" assert isinstance(update_cte.query, pgast.SelectStmt) contents_select = update_cte.query toplevel = ctx.toplevel_stmt put_iterator_bond(ctx.enclosing_cte_iterator, contents_select) assert dml_parts.range_cte iterator = pgast.IteratorCTE( path_id=ir_stmt.subject.path_id, cte=dml_parts.range_cte, parent=ctx.enclosing_cte_iterator, ) with ctx.newscope() as subctx: # It is necessary to process the expressions in # the UpdateStmt shape body in the context of the # UPDATE statement so that references to the current # values of the updated object are resolved correctly. subctx.parent_rel = contents_select subctx.expr_exposed = False subctx.enclosing_cte_iterator = iterator clauses.setup_iterator_volatility(iterator, ctx=subctx) # compile contents CTE elements = [ (shape_el, shape_el.expr.ptrref, shape_op) for shape_el, shape_op in ir_stmt.subject.shape if shape_op != qlast.ShapeOp.MATERIALIZE ] values, external_updates, ptr_map = process_update_shape( ir_stmt, contents_select, elements, typeref, subctx ) relation = contents_select.from_clause[0] assert isinstance(relation, pgast.PathRangeVar) # Use a dynamic rvar to return values out of the select purely # based on material rptr, as if it was a base relation (and to # fall back to the base relation if the value wasn't updated.) fallback_rvar = pgast.DynamicRangeVar( dynamic_get_path=_mk_dynamic_get_path(ptr_map, typeref, relation), ) pathctx.put_path_source_rvar( contents_select, ir_stmt.subject.path_id, fallback_rvar, ) pathctx.put_path_value_rvar( contents_select, ir_stmt.subject.path_id, fallback_rvar, ) update_stmt = None single_external = [ ir for ir, _ in external_updates if ir.expr.dir_cardinality.is_single() ] rewrites = ir_stmt.rewrites and ir_stmt.rewrites.by_type.get(typeref) pol_expr = ir_stmt.write_policies.get(typeref.id) pol_ctx = None if pol_expr or rewrites or single_external: # Create a context for handling policies/rewrites that we will # use later. 
We do this in advance so that the link update code # can populate overlay fields in it. with ctx.new() as pol_ctx: pol_ctx.rel_overlays = context.RelOverlays() no_update = not values and not rewrites and not single_external if no_update: # No updates directly to the set target table, # so convert the UPDATE statement into a SELECT. update_cte.query = contents_select contents_cte = update_cte else: contents_cte = pgast.CommonTableExpr( query=contents_select, name=ctx.env.aliases.get("upd_contents"), for_dml_stmt=ctx.get_current_dml_stmt(), ) toplevel.append_cte(contents_cte) # Process necessary updates to the link tables. # We do link tables before we do the main update so that link_ctes = [] for expr, shape_op in external_updates: link_cte, check_cte = process_link_update( ir_stmt=ir_stmt, ir_set=expr, dml_cte=contents_cte, iterator=iterator, shape_op=shape_op, source_typeref=typeref, ctx=ctx, policy_ctx=pol_ctx, ) if link_cte: link_ctes.append(link_cte) if check_cte is not None: ctx.env.check_ctes.append(check_cte) if not no_update: table_relation = relctx.range_for_typeref( typeref, ir_stmt.subject.path_id, for_mutation=True, ctx=ctx, ) assert isinstance(table_relation, pgast.RelRangeVar) range_relation = contents_select.from_clause[1] assert isinstance(range_relation, pgast.PathRangeVar) contents_rvar = relctx.rvar_for_rel(contents_cte, ctx=ctx) subject_path_id = ir_stmt.subject.path_id # Compile rewrites CTE if rewrites or single_external: rewrites = rewrites or {} assert pol_ctx # Now that all the compilation for the UPDATE has been done, # apply the tweaked policy overlays. 
pol_ctx.rel_overlays = update_overlay( ctx.rel_overlays, pol_ctx.rel_overlays ) with pol_ctx.reenter(), pol_ctx.new() as rctx: merge_overlays_globally((ir_stmt,), ctx=rctx) contents_cte, contents_rvar, values = process_update_rewrites( ir_stmt, typeref=typeref, contents_cte=contents_cte, contents_rvar=contents_rvar, iterator=iterator, contents_select=contents_select, table_relation=table_relation, range_relation=range_relation, single_external=single_external, rewrites=rewrites, elements=elements, ctx=rctx, ) update_stmt = pgast.UpdateStmt( relation=table_relation, where_clause=astutils.new_binop( lexpr=pgast.ColumnRef( name=[table_relation.alias.aliasname, "id"] ), op="=", rexpr=pathctx.get_rvar_path_identity_var( contents_rvar, subject_path_id, env=ctx.env ), ), from_clause=[contents_rvar], targets=[ pgast.MultiAssignRef( columns=[not_none(value.name) for value, _ in values], source=pgast.SelectStmt( target_list=[ pgast.ResTarget( val=pgast.ColumnRef( name=[ contents_rvar.alias.aliasname, not_none(value.name), ] ) ) for value, _ in values ], ), ) ], ) relctx.pull_path_namespace( target=update_stmt, source=contents_rvar, ctx=ctx ) pathctx.put_path_value_rvar( update_stmt, subject_path_id, table_relation ) pathctx.put_path_source_rvar( update_stmt, subject_path_id, table_relation ) put_iterator_bond(ctx.enclosing_cte_iterator, update_stmt) update_cte.query = update_stmt if pol_expr: assert pol_ctx with pol_ctx.reenter(): policy_cte = compile_policy_check( contents_cte, ir_stmt, pol_expr, typeref=typeref, ctx=pol_ctx ) force_policy_checks( policy_cte, ((update_stmt,) if update_stmt else ()) + tuple(cte.query for cte in link_ctes), ctx=ctx, ) if values: toplevel.append_cte(update_cte) for link_cte in link_ctes: toplevel.append_cte(link_cte) def process_update_rewrites( ir_stmt: irast.UpdateStmt, *, typeref: irast.TypeRef, contents_cte: pgast.CommonTableExpr, contents_rvar: pgast.PathRangeVar, iterator: Optional[pgast.IteratorCTE], contents_select: pgast.SelectStmt, 
table_relation: pgast.RelRangeVar, range_relation: pgast.PathRangeVar, single_external: list[irast.SetE[irast.Pointer]], rewrites: irast.RewritesOfType, elements: Sequence[ tuple[irast.SetE[irast.Pointer], irast.BasePointerRef, qlast.ShapeOp]], ctx: context.CompilerContextLevel, ) -> tuple[ pgast.CommonTableExpr, pgast.PathRangeVar, list[tuple[pgast.ResTarget, irast.PathId]], ]: # assert ir_stmt.rewrites subject_path_id = ir_stmt.subject.path_id if ir_stmt.rewrites: old_path_id = ir_stmt.rewrites.old_path_id else: # Need values for the single external link case old_path_id = subject_path_id assert old_path_id table_rel = table_relation.relation assert isinstance(table_rel, pgast.Relation) # Need to set up an iterator for any internal DML. iterator = pgast.IteratorCTE( path_id=subject_path_id, cte=contents_cte, parent=iterator, # __old__ other_paths=( ((old_path_id, pgce.PathAspect.IDENTITY),) ), ) with ctx.newrel() as rctx: rewrites_stmt = rctx.rel clauses.setup_iterator_volatility(iterator, ctx=rctx) rctx.enclosing_cte_iterator = iterator # pruned down version of gen_dml_cte rewrites_stmt.from_clause.append(range_relation) # pull in contents_select for __subject__ relctx.include_rvar( rewrites_stmt, contents_rvar, subject_path_id, # We don't want to update the mask... in case the subject # is an iterator that needs to be reexported. 
update_mask=False, ctx=ctx, ) rewrites_stmt.where_clause = astutils.new_binop( lexpr=pathctx.get_rvar_path_identity_var( contents_rvar, subject_path_id, env=ctx.env ), op="=", rexpr=pathctx.get_rvar_path_identity_var( range_relation, subject_path_id, env=ctx.env ), ) # pull in table_relation for __old__ table_rel.path_outputs[ (old_path_id, pgce.PathAspect.VALUE) ] = pathctx.get_path_value_output( table_rel, subject_path_id, env=ctx.env ) relctx.include_rvar( rewrites_stmt, table_relation, old_path_id, ctx=ctx ) rewrites_stmt.where_clause = astutils.extend_binop( rewrites_stmt.where_clause, astutils.new_binop( lexpr=pgast.ColumnRef( name=[table_relation.alias.aliasname, "id"] ), op="=", rexpr=pathctx.get_rvar_path_identity_var( range_relation, subject_path_id, env=ctx.env ), ), ) relctx.pull_path_namespace( target=rewrites_stmt, source=table_relation, ctx=ctx ) rewrite_elements = [ (el, ptrref, qlast.ShapeOp.ASSIGN) for el, ptrref in rewrites.values() ] values, _, nptr_map = process_update_shape( ir_stmt, rewrites_stmt, rewrite_elements, typeref, rctx, ) # If there are any single links that were compiled externally, # populate the field from the link overlays. 
handled = set(rewrites) for ext_ir in single_external: handled.add(ext_ir.expr.ptrref.shortname.name) actual_ptrref = irtyputils.find_actual_ptrref( typeref, ext_ir.expr.ptrref) with rctx.subrel() as ectx: ext_rvar = relctx.new_pointer_rvar( ext_ir, link_bias=True, src_rvar=contents_rvar, ctx=ectx) relctx.include_rvar( ectx.rel, ext_rvar, ext_ir.path_id, ctx=ectx) # Make the subquery output the target pathctx.get_path_value_output( ectx.rel, ext_ir.path_id, env=ctx.env) ptr_info = pg_types.get_ptrref_storage_info( actual_ptrref, resolve_type=True, link_bias=False) updval = pgast.ResTarget( name=ptr_info.column_name, val=ectx.rel) rewrites_stmt.target_list.append(updval) values.append((updval, ext_ir.path_id)) nptr_map[actual_ptrref.name] = ectx.rel # Pull in pointers that were not rewritten not_rewritten = { (e, ptrref) for e, ptrref, _ in elements if ptrref.shortname.name not in handled } for e, ptrref in not_rewritten: # FIXME: Duplicates some with process_update_shape actual_ptrref = irtyputils.find_actual_ptrref(typeref, ptrref) ptr_info = pg_types.get_ptrref_storage_info( actual_ptrref, resolve_type=True, link_bias=False) if ptr_info.table_type == 'ObjectType': val = pathctx.get_path_var( rewrites_stmt, e.path_id, aspect=pgce.PathAspect.VALUE, env=ctx.env, ) updval = pgast.ResTarget( name=ptr_info.column_name, val=val) values.append((updval, e.path_id)) rewrites_stmt.target_list.append(updval) fallback_rvar = pgast.DynamicRangeVar( dynamic_get_path=_mk_dynamic_get_path( nptr_map, typeref, contents_rvar), ) pathctx.put_path_source_rvar(rctx.rel, subject_path_id, fallback_rvar) pathctx.put_path_value_rvar(rctx.rel, subject_path_id, fallback_rvar) rewrites_cte = pgast.CommonTableExpr( query=rctx.rel, name=ctx.env.aliases.get("upd_rewrites"), for_dml_stmt=ctx.get_current_dml_stmt(), ) ctx.toplevel_stmt.append_cte(rewrites_cte) rewrites_rvar = relctx.rvar_for_rel(rewrites_cte, ctx=ctx) return rewrites_cte, rewrites_rvar, values def process_update_shape( ir_stmt: 
irast.UpdateStmt, rel: pgast.SelectStmt, elements: Sequence[ tuple[irast.SetE[irast.Pointer], irast.BasePointerRef, qlast.ShapeOp]], typeref: irast.TypeRef, ctx: context.CompilerContextLevel, ) -> tuple[ list[tuple[pgast.ResTarget, irast.PathId]], list[tuple[irast.SetE[irast.Pointer], qlast.ShapeOp]], dict[sn.Name, pgast.BaseExpr], ]: values: list[tuple[pgast.ResTarget, irast.PathId]] = [] external_updates: list[tuple[irast.SetE[irast.Pointer], qlast.ShapeOp]] = [] ptr_map: dict[sn.Name, pgast.BaseExpr] = {} for element, shape_ptrref, shape_op in elements: actual_ptrref = irtyputils.find_actual_ptrref(typeref, shape_ptrref) ptr_info = pg_types.get_ptrref_storage_info( actual_ptrref, resolve_type=True, link_bias=False ) link_ptr_info = pg_types.get_ptrref_storage_info( actual_ptrref, resolve_type=False, link_bias=True ) # XXX: Slightly nervous about this. assert isinstance(element.expr, irast.Pointer) updvalue = element.expr.expr if ( ptr_info.table_type == "ObjectType" and not link_ptr_info and updvalue is not None ): with ctx.newscope() as scopectx: scopectx.expr_exposed = False val: pgast.BaseExpr if irtyputils.is_tuple(element.typeref): # When target is a tuple type, make sure # the expression is compiled into a subquery # returning a single column that is explicitly # cast into the appropriate composite type. 
val = relgen.set_as_subquery( element, as_value=True, explicit_cast=ptr_info.column_type, ctx=scopectx, ) else: if ( isinstance(updvalue, irast.MutatingStmt) and updvalue in ctx.dml_stmts ): with scopectx.substmt() as srelctx: dml_cte = ctx.dml_stmts[updvalue] wrap_dml_cte(updvalue, dml_cte, ctx=srelctx) pathctx.get_path_identity_output( srelctx.rel, updvalue.subject.path_id, env=srelctx.env, ) val = srelctx.rel else: # base case val = dispatch.compile(updvalue, ctx=scopectx) assert isinstance(updvalue, irast.Stmt) val = check_update_type( val, val, is_subquery=True, ir_stmt=ir_stmt, ir_set=updvalue.result, shape_ptrref=shape_ptrref, actual_ptrref=actual_ptrref, ctx=scopectx, ) val = pgast.TypeCast( arg=val, type_name=pgast.TypeName(name=ptr_info.column_type), ) if shape_op is qlast.ShapeOp.SUBTRACT: val = pgast.FuncCall( name=("nullif",), args=[ pgast.ColumnRef(name=[ptr_info.column_name]), val, ], ) ptr_map[actual_ptrref.name] = val updtarget = pgast.ResTarget( name=ptr_info.column_name, val=val, ) values.append((updtarget, element.path_id)) # Register the output as both a var and an output # so that if it is referenced in a policy or # rewrite, the find_path_output optimization fires # and we reuse the output instead of duplicating # it. # XXX: Maybe this suggests a rework of the # DynamicRangeVar mechanism would be a good idea. 
pathctx.put_path_var( rel, element.path_id, aspect=pgce.PathAspect.VALUE, var=val, ) pathctx._put_path_output_var( rel, element.path_id, aspect=pgce.PathAspect.VALUE, var=pgast.ColumnRef(name=[ptr_info.column_name]), ) if link_ptr_info and link_ptr_info.table_type == "link": external_updates.append((element, shape_op)) rel.target_list.extend(v for v, _ in values) return (values, external_updates, ptr_map) def process_extra_conflicts( *, ir_stmt: irast.MutatingStmt, dml_parts: DMLParts, ctx: context.CompilerContextLevel, ) -> None: if not ir_stmt.conflict_checks: return for extra_conflict in ir_stmt.conflict_checks: q_path = extra_conflict.check_anchor assert q_path typeref = q_path.target.real_material_type cte, _ = dml_parts.dml_ctes[typeref] pathctx.put_path_id_map( cte.query, q_path, ir_stmt.subject.path_id) conflict_iterator = pgast.IteratorCTE( path_id=q_path, cte=cte, parent=ctx.enclosing_cte_iterator) compile_insert_else_body( None, ir_stmt, extra_conflict, conflict_iterator, None, ctx=ctx, ) def check_update_type( val: pgast.BaseExpr, rel_or_rvar: pgast.BaseExpr | pgast.PathRangeVar, *, is_subquery: bool, ir_stmt: irast.UpdateStmt, ir_set: irast.Set, shape_ptrref: irast.BasePointerRef, actual_ptrref: irast.BasePointerRef, ctx: context.CompilerContextLevel, ) -> pgast.BaseExpr: """Possibly insert a type check on an UPDATE to a link Because edgedb allows subtypes to covariantly override the target types of links, we need to insert runtime type checks when the target in a base type being UPDATEd does not match the target type for this concrete subtype being handled. """ assert isinstance(actual_ptrref, irast.PointerRef) base_ptrref = shape_ptrref.real_material_ptr # We skip the check if either the base type matches exactly # or the shape type matches exactly. FIXME: *Really* we want to do # a subtype check, here, though, since this could do a needless # check if we have multiple levels of overloading, but we don't # have the infrastructure here. 
if ( not irtyputils.is_object(ir_set.typeref) or base_ptrref.out_target.id == actual_ptrref.out_target.id or shape_ptrref.out_target.id == actual_ptrref.out_target.id ): return val if isinstance(rel_or_rvar, pgast.PathRangeVar): rvar = rel_or_rvar else: assert isinstance(rel_or_rvar, pgast.BaseRelation) rvar = relctx.rvar_for_rel(rel_or_rvar, ctx=ctx) # Make up a ptrref for the __type__ link on our actual target type # and make up a new path_id to access it. Relies on __type__ always # being named __type__, but that's fine. # (Arranging to actually get a legit pointer ref is pointlessly expensive.) el_name = sn.QualName('__', '__type__') actual_type_ptrref = irast.SpecialPointerRef( name=el_name, shortname=el_name, out_source=actual_ptrref.out_target, # HACK: This is obviously not the right target type, but we don't # need it for anything and the pathid never escapes this function. out_target=actual_ptrref.out_target, out_cardinality=qltypes.Cardinality.AT_MOST_ONE, ) type_pathid = ir_set.path_id.extend(ptrref=actual_type_ptrref) # Grab the actual value we have inserted and pull the __type__ out rval = pathctx.get_rvar_path_identity_var( rvar, ir_set.path_id, env=ctx.env) typ = pathctx.get_rvar_path_identity_var(rvar, type_pathid, env=ctx.env) typeref_val = dispatch.compile(actual_ptrref.out_target, ctx=ctx) # Do the check! Include the ptrref for this concrete class and # also the (dynamic) type of the argument, so that we can produce # a good error message. 
check_result = pgast.FuncCall( name=astutils.edgedb_func('issubclass', ctx=ctx), args=[typ, typeref_val], ) maybe_null = pgast.CaseExpr( args=[pgast.CaseWhen(expr=check_result, result=rval)]) maybe_raise = pgast.FuncCall( name=astutils.edgedb_func('raise_on_null', ctx=ctx), args=[ maybe_null, pgast.StringConstant(val='wrong_object_type'), pgast.NamedFuncArg( name='msg', val=pgast.StringConstant(val='covariance error') ), pgast.NamedFuncArg( name='column', val=pgast.StringConstant(val=str(actual_ptrref.id)), ), pgast.NamedFuncArg( name='table', val=pgast.TypeCast( arg=typ, type_name=pgast.TypeName(name=('text',)) ), ), ], ) if is_subquery: # If this is supposed to be a subquery (because it is an # update of a single link), wrap the result query in a new one, # since we need to access two outputs from it and produce just one # from this query return pgast.SelectStmt( from_clause=[rvar], target_list=[pgast.ResTarget(val=maybe_raise)], ) else: return maybe_raise def process_link_update( *, ir_stmt: irast.MutatingStmt, ir_set: irast.SetE[irast.Pointer], shape_op: qlast.ShapeOp = qlast.ShapeOp.ASSIGN, source_typeref: irast.TypeRef, dml_cte: pgast.CommonTableExpr, iterator: Optional[pgast.IteratorCTE] = None, ctx: context.CompilerContextLevel, policy_ctx: Optional[context.CompilerContextLevel], ) -> tuple[Optional[pgast.CommonTableExpr], Optional[pgast.CommonTableExpr]]: """Perform updates to a link relation as part of a DML statement. Args: ir_stmt: IR of the DML statement. ir_set: IR of the INSERT/UPDATE body element. shape_op: The operation of the UPDATE body element (:=, +=, -=). For INSERT this should always be :=. source_typeref: An ir.TypeRef instance representing the specific type of an object being updated. dml_cte: CTE representing the SQL INSERT or UPDATE to the main relation of the DML subject. iterator: IR and CTE representing the iterator range in the FOR clause of the EdgeQL DML statement (if present). 
policy_ctx: Optionally, a context in which to populate overlays that use the select CTE for overlays instead of the actual insert CTE. This is needed if an access policy is to be applied, and requires disabling a potential optimization. We need separate overlay contexts because default values for link properties don't currently get populated in our IR, so we need to do actual SQL DML to get their values. (And so we disallow their use in policies.) """ toplevel = ctx.toplevel_stmt is_insert = isinstance(ir_stmt, irast.InsertStmt) rptr = ir_set.expr ptrref = rptr.ptrref assert isinstance(ptrref, irast.PointerRef) target_is_scalar = not irtyputils.is_object(ir_set.typeref) path_id = ir_set.path_id # The links in the dml class shape have been derived, # but we must use the correct specialized link class for the # base material type. mptrref = irtyputils.find_actual_ptrref(source_typeref, ptrref) assert isinstance(mptrref, irast.PointerRef) target_rvar = relctx.range_for_ptrref( mptrref, for_mutation=True, only_self=True, ctx=ctx) assert isinstance(target_rvar, pgast.RelRangeVar) assert isinstance(target_rvar.relation, pgast.Relation) target_alias = target_rvar.alias.aliasname dml_cte_rvar = pgast.RelRangeVar( relation=dml_cte, alias=pgast.Alias( aliasname=ctx.env.aliases.get('m') ) ) # Turn the IR of the expression on the right side of := # into a subquery returning records for the link table. 
data_cte, specified_cols = process_link_values( ir_stmt=ir_stmt, ir_expr=ir_set, dml_rvar=dml_cte_rvar, source_typeref=source_typeref, target_is_scalar=target_is_scalar, enforce_cardinality=(shape_op is qlast.ShapeOp.ASSIGN), dml_cte=dml_cte, iterator=iterator, ctx=ctx, ) toplevel.append_cte(data_cte) delqry: Optional[pgast.DeleteStmt] data_select = pgast.SelectStmt( target_list=[ pgast.ResTarget( val=pgast.ColumnRef( name=[data_cte.name, pgast.Star()] ), ), ], from_clause=[ pgast.RelRangeVar(relation=data_cte), ], ) if not is_insert and shape_op is not qlast.ShapeOp.APPEND: source_ref = pathctx.get_rvar_path_identity_var( dml_cte_rvar, ir_stmt.subject.path_id, env=ctx.env, ) if shape_op is qlast.ShapeOp.SUBTRACT: data_rvar = relctx.rvar_for_rel(data_select, ctx=ctx) if target_is_scalar: # MULTI properties are not distinct, and since `-=` must # be a proper inverse of `+=` we cannot simply DELETE # all property values matching the `-=` expression, and # instead have to resort to careful deletion of no more # than the number of tuples returned by the expression. # Here, we rely on the "ctid" system column to refer to # specific tuples. 
# # DELETE # FROM # WHERE # ctid IN ( # SELECT # shortlist.ctid # FROM # (SELECT # source, # target, # count(target) AS cnt # FROM # # GROUP BY source, target # ) AS counts, # LATERAL ( # SELECT # candidates.ctid # FROM # (SELECT # ctid, # row_number() OVER ( # PARTITION BY data # ORDER BY data # ) AS rn # FROM # # WHERE # source = counts.source # AND target = counts.target # ) AS candidates # WHERE # candidates.rn <= counts.cnt # ) AS shortlist # ); val_src_ref = pgast.ColumnRef( name=[data_rvar.alias.aliasname, 'source'], ) val_tgt_ref = pgast.ColumnRef( name=[data_rvar.alias.aliasname, 'target'], ) counts_select = pgast.SelectStmt( target_list=[ pgast.ResTarget(name='source', val=val_src_ref), pgast.ResTarget(name='target', val=val_tgt_ref), pgast.ResTarget( name='cnt', val=pgast.FuncCall( name=('count',), args=[val_tgt_ref], ), ), ], from_clause=[data_rvar], group_clause=[val_src_ref, val_tgt_ref], ) counts_rvar = relctx.rvar_for_rel(counts_select, ctx=ctx) counts_alias = counts_rvar.alias.aliasname target_ref = pgast.ColumnRef(name=[target_alias, 'target']) candidates_select = pgast.SelectStmt( target_list=[ pgast.ResTarget( name='ctid', val=pgast.ColumnRef( name=[target_alias, 'ctid'], ), ), pgast.ResTarget( name='rn', val=pgast.FuncCall( name=('row_number',), args=[], over=pgast.WindowDef( partition_clause=[target_ref], order_clause=[ pgast.SortBy(node=target_ref), ], ), ), ), ], from_clause=[target_rvar], where_clause=astutils.new_binop( lexpr=astutils.new_binop( lexpr=pgast.ColumnRef( name=[counts_alias, 'source'], ), op='=', rexpr=pgast.ColumnRef( name=[target_alias, 'source'], ), ), op='AND', rexpr=astutils.new_binop( lexpr=target_ref, op='=', rexpr=pgast.ColumnRef( name=[counts_alias, 'target']), ), ), ) candidates_rvar = relctx.rvar_for_rel( candidates_select, ctx=ctx) candidates_alias = candidates_rvar.alias.aliasname shortlist_select = pgast.SelectStmt( target_list=[ pgast.ResTarget( name='ctid', val=pgast.ColumnRef( name=[candidates_alias, 'ctid'], 
), ), ], from_clause=[candidates_rvar], where_clause=astutils.new_binop( lexpr=pgast.ColumnRef(name=[candidates_alias, 'rn']), op='<=', rexpr=pgast.ColumnRef(name=[counts_alias, 'cnt']), ), ) shortlist_rvar = relctx.rvar_for_rel( shortlist_select, lateral=True, ctx=ctx) shortlist_alias = shortlist_rvar.alias.aliasname ctid_select = pgast.SelectStmt( target_list=[ pgast.ResTarget( name='ctid', val=pgast.ColumnRef(name=[shortlist_alias, 'ctid']) ), ], from_clause=[ counts_rvar, shortlist_rvar, ], ) delqry = pgast.DeleteStmt( relation=target_rvar, where_clause=astutils.new_binop( lexpr=pgast.ColumnRef( name=[target_alias, 'ctid'], ), op='=', rexpr=pgast.SubLink( operator="ANY", expr=ctid_select, ), ), returning_list=[ pgast.ResTarget( val=pgast.ColumnRef( name=[target_alias, pgast.Star()], ), ) ] ) else: # Links are always distinct, so we can simply # DELETE the tuples matching the `-=` expression. delqry = pgast.DeleteStmt( relation=target_rvar, where_clause=astutils.new_binop( lexpr=astutils.new_binop( lexpr=source_ref, op='=', rexpr=pgast.ColumnRef( name=[target_alias, 'source'], ), ), op='AND', rexpr=astutils.new_binop( lexpr=pgast.ColumnRef( name=[target_alias, 'target'], ), op='=', rexpr=pgast.ColumnRef( name=[data_rvar.alias.aliasname, 'target'], ), ), ), using_clause=[ dml_cte_rvar, data_rvar, ], returning_list=[ pgast.ResTarget( val=pgast.ColumnRef( name=[target_alias, pgast.Star()], ), ) ] ) else: # Drop all previous link records for this source. 
delqry = pgast.DeleteStmt( relation=target_rvar, where_clause=astutils.new_binop( lexpr=source_ref, op='=', rexpr=pgast.ColumnRef( name=[target_alias, 'source'], ), ), using_clause=[dml_cte_rvar], returning_list=[ pgast.ResTarget( val=pgast.ColumnRef( name=[target_alias, pgast.Star()], ), ) ] ) delcte = pgast.CommonTableExpr( name=ctx.env.aliases.get(hint='link_upd_del'), query=delqry, for_dml_stmt=ctx.get_current_dml_stmt(), ) if shape_op is not qlast.ShapeOp.SUBTRACT: # Correlate the deletion with INSERT to make sure # link properties get erased properly and we aren't # just ON CONFLICT UPDATE-ing the link rows. # This basically just tacks on a # WHERE (SELECT count(*) FROM delcte) IS NOT NULL del_select = pgast.SelectStmt( target_list=[ pgast.ResTarget( val=pgast.FuncCall( name=['count'], args=[pgast.ColumnRef(name=[pgast.Star()])], ), ), ], from_clause=[ pgast.RelRangeVar(relation=delcte), ], ) data_select.where_clause = astutils.extend_binop( data_select.where_clause, pgast.NullTest(arg=del_select, negated=True), ) pathctx.put_path_value_rvar( delcte.query, path_id.ptr_path(), target_rvar ) # Record the effect of this removal in the relation overlay # context to ensure that references to the link in the result # of this DML statement yield the expected results. except_overlay = (lambda octx: relctx.add_ptr_rel_overlay( mptrref, context.OverlayOp.EXCEPT, delcte, path_id=path_id.ptr_path(), dml_stmts=ctx.dml_stmt_stack, ctx=octx, ) ) except_overlay(ctx) if policy_ctx: except_overlay(policy_ctx) toplevel.append_cte(delcte) else: delqry = None if shape_op is qlast.ShapeOp.SUBTRACT: if mptrref.dir_cardinality(rptr.direction).can_be_zero(): # The pointer is OPTIONAL, no checks or further processing # is needed. return None, None else: # The pointer is REQUIRED, so we must take the result of # the subtraction produced by the "delcte" above, apply it # as a subtracting overlay, and re-compute the pointer relation # to see if there are any newly created empty sets. 
# # The actual work is done via raise_on_null injection performed # by "process_link_values()" below (hence "enforce_cardinality"). # # The other part of this enforcement is in doing it when a # target is deleted and the link policy is ALLOW. This is # handled in _get_outline_link_trigger_proc_text in # pgsql/delta.py. # Turn `foo := <expr>` into just `foo`. ptr_ref_set = irast.Set( path_id=ir_set.path_id, path_scope_id=ir_set.path_scope_id, typeref=ir_set.typeref, expr=ir_set.expr.replace(expr=None), ) assert irutils.is_set_instance(ptr_ref_set, irast.Pointer) with ctx.new() as subctx: # TODO: Do we really need a copy here? things /seem/ # to work without it subctx.rel_overlays = subctx.rel_overlays.copy() relctx.add_ptr_rel_overlay( ptrref, context.OverlayOp.EXCEPT, delcte, path_id=path_id.ptr_path(), ctx=subctx ) check_cte, _ = process_link_values( ir_stmt=ir_stmt, ir_expr=ptr_ref_set, dml_rvar=dml_cte_rvar, source_typeref=source_typeref, target_is_scalar=target_is_scalar, enforce_cardinality=True, dml_cte=dml_cte, iterator=iterator, ctx=subctx, ) toplevel.append_cte(check_cte) return None, check_cte cols = [pgast.ColumnRef(name=[col]) for col in specified_cols] conflict_cols = ['source', 'target'] if is_insert or target_is_scalar: conflict_clause = None elif ( len(cols) == len(conflict_cols) and delqry is not None and not policy_ctx ): # There are no link properties, so we can optimize the # link replacement operation by omitting the overlapping # link rows from deletion. 
filter_select = pgast.SelectStmt( target_list=[ pgast.ResTarget( val=pgast.ColumnRef(name=['source']), ), pgast.ResTarget( val=pgast.ColumnRef(name=['target']), ), ], from_clause=[pgast.RelRangeVar(relation=data_cte)], ) delqry.where_clause = astutils.extend_binop( delqry.where_clause, astutils.new_binop( lexpr=pgast.ImplicitRowExpr( args=[ pgast.ColumnRef(name=['source']), pgast.ColumnRef(name=['target']), ], ), rexpr=pgast.SubLink( operator="ALL", expr=filter_select, ), op='!=', ) ) conflict_clause = pgast.OnConflictClause( action=pgast.OnConflictAction.DO_NOTHING, target=pgast.OnConflictTarget( index_elems=[ pgast.IndexElem(expr=pgast.ColumnRef(name=[col])) for col in conflict_cols ] ), ) else: # Inserting rows into the link table may produce cardinality # constraint violations, since the INSERT into the link table # is executed in the snapshot where the above DELETE from # the link table is not visible. Hence, we need to use # the ON CONFLICT clause to resolve this. conflict_inference = [ pgast.IndexElem(expr=pgast.ColumnRef(name=[col])) for col in conflict_cols ] target_cols = [ col.name[0] for col in cols if isinstance(col.name[0], str) and col.name[0] not in conflict_cols ] if len(target_cols) == 0: conflict_clause = pgast.OnConflictClause( action=pgast.OnConflictAction.DO_NOTHING, target=pgast.OnConflictTarget( index_elems=conflict_inference ) ) else: conflict_data = pgast.RowExpr( args=[ pgast.ColumnRef(name=['excluded', col]) for col in target_cols ], ) conflict_clause = pgast.OnConflictClause( action=pgast.OnConflictAction.DO_UPDATE, target=pgast.OnConflictTarget( index_elems=conflict_inference ), update_list=[ pgast.MultiAssignRef( columns=target_cols, source=conflict_data ) ] ) update = pgast.CommonTableExpr( name=ctx.env.aliases.get(hint='link_upd_ins'), query=pgast.InsertStmt( relation=target_rvar, select_stmt=data_select, cols=[ pgast.InsertTarget(name=downcast(str, col.name[0])) for col in cols ], on_conflict=conflict_clause, returning_list=[ 
pgast.ResTarget( val=pgast.ColumnRef(name=[pgast.Star()]) ) ] ), for_dml_stmt=ctx.get_current_dml_stmt(), ) pathctx.put_path_value_rvar(update.query, path_id.ptr_path(), target_rvar) def register_overlays( overlay_cte: pgast.CommonTableExpr, octx: context.CompilerContextLevel ) -> None: assert isinstance(mptrref, irast.PointerRef) # Record the effect of this insertion in the relation overlay # context to ensure that references to the link in the result # of this DML statement yield the expected results. if shape_op is qlast.ShapeOp.APPEND and not target_is_scalar: # When doing an UPDATE with +=, we need to do an anti-join # based filter to filter out links that were already present # and have been re-added. relctx.add_ptr_rel_overlay( mptrref, context.OverlayOp.FILTER, overlay_cte, dml_stmts=ctx.dml_stmt_stack, path_id=path_id.ptr_path(), ctx=octx ) relctx.add_ptr_rel_overlay( mptrref, context.OverlayOp.UNION, overlay_cte, dml_stmts=ctx.dml_stmt_stack, path_id=path_id.ptr_path(), ctx=octx ) if policy_ctx: register_overlays(data_cte, policy_ctx) register_overlays(update, ctx) return update, None def process_link_values( *, ir_stmt: irast.MutatingStmt, ir_expr: irast.SetE[irast.Pointer], dml_rvar: pgast.PathRangeVar, dml_cte: pgast.CommonTableExpr, source_typeref: irast.TypeRef, target_is_scalar: bool, enforce_cardinality: bool, iterator: Optional[pgast.IteratorCTE], ctx: context.CompilerContextLevel, ) -> tuple[pgast.CommonTableExpr, list[str]]: """Produce a pointer relation for a given body element of an INSERT/UPDATE. Given an INSERT/UPDATE body shape element that mutates a MULTI pointer, produce a (source, target [, link properties]) relation as a CTE and return it along with a list of relation attribute names. Args: ir_stmt: IR of the DML statement. ir_expr: IR of the INSERT/UPDATE body element. dml_rvar: The RangeVar over the SQL INSERT/UPDATE of the main relation of the object being updated. 
dml_cte: CTE representing the SQL INSERT or UPDATE to the main relation of the DML subject. source_typeref: An ir.TypeRef instance representing the specific type of an object being updated. target_is_scalar: True, if mutating a property, False if a link. enforce_cardinality: Whether an explicit empty set check should be generated. Used for REQUIRED pointers. iterator: IR and CTE representing the iterator range in the FOR clause of the EdgeQL DML statement (if present). Returns: A tuple containing the pointer relation CTE and a list of attribute names in it. """ old_dml_count = len(ctx.dml_stmts) with ctx.newrel() as subrelctx: # For inserts, we need to use the main DML statement as the # iterator, while for updates, we need to use the DML range # CTE as the iterator (and so arrange for it to be passed in). # This is because, for updates, we need to execute any nested # DML once for each row in the range over all types, while # dml_cte contains just one subtype. if isinstance(ir_stmt, irast.InsertStmt): subrelctx.enclosing_cte_iterator = pgast.IteratorCTE( path_id=ir_stmt.subject.path_id, cte=dml_cte, parent=iterator) else: subrelctx.enclosing_cte_iterator = iterator row_query = subrelctx.rel merge_iterator(iterator, row_query, ctx=subrelctx) relctx.include_rvar(row_query, dml_rvar, pull_namespace=False, path_id=ir_stmt.subject.path_id, ctx=subrelctx) subrelctx.path_scope[ir_stmt.subject.path_id] = row_query ir_rptr = ir_expr.expr ptrref = ir_rptr.ptrref if ptrref.material_ptr is not None: ptrref = ptrref.material_ptr assert isinstance(ptrref, irast.PointerRef) ptr_is_multi_required = ( ptrref.out_cardinality == qltypes.Cardinality.AT_LEAST_ONE ) with subrelctx.newscope() as sctx, sctx.subrel() as input_rel_ctx: input_rel = input_rel_ctx.rel input_rel_ctx.expr_exposed = False input_rel_ctx.volatility_ref = ( lambda _stmt, _ctx: pathctx.get_path_identity_var( row_query, ir_stmt.subject.path_id, env=input_rel_ctx.env),) # Check if some nested Set provides a shape that 
is # visible here. shape_expr = ir_expr.shape_source or ir_expr # Register that this shape needs to be compiled for use by DML, # so that the values will be there for us to grab later. input_rel_ctx.shapes_needed_by_dml.add(shape_expr) if ptr_is_multi_required and enforce_cardinality: input_rel_ctx.force_optional |= {ir_expr.path_id} dispatch.visit(ir_expr, ctx=input_rel_ctx) input_stmt: pgast.Query = input_rel input_rvar = pgast.RangeSubselect( subquery=input_rel, lateral=True, alias=pgast.Alias( aliasname=ctx.env.aliases.get('val') ) ) if len(ctx.dml_stmts) > old_dml_count: # If there were any nested inserts, we need to join them in. pathctx.put_rvar_path_bond(input_rvar, ir_stmt.subject.path_id) relctx.include_rvar(row_query, input_rvar, path_id=ir_stmt.subject.path_id, ctx=ctx) source_data: dict[str, tuple[irast.PathId, pgast.BaseExpr]] = {} if isinstance(input_stmt, pgast.SelectStmt) and input_stmt.op is not None: # UNION assert input_stmt.rarg input_stmt = input_stmt.rarg path_id = ir_expr.path_id target_ref: pgast.BaseExpr if shape_expr.shape: for element, _ in shape_expr.shape: if not element.path_id.is_linkprop_path(): continue val = pathctx.get_rvar_path_value_var( input_rvar, element.path_id, env=ctx.env) rptr = element.path_id.rptr() assert isinstance(rptr, irast.PointerRef) actual_rptr = irtyputils.find_actual_ptrref(source_typeref, rptr) ptr_info = pg_types.get_ptrref_storage_info(actual_rptr) real_path_id = path_id.ptr_path().extend(ptrref=actual_rptr) source_data.setdefault( ptr_info.column_name, (real_path_id, val)) if not target_is_scalar and 'target' not in source_data: target_ref = pathctx.get_rvar_path_identity_var( input_rvar, path_id, env=ctx.env) else: if target_is_scalar: target_ref = pathctx.get_rvar_path_value_var( input_rvar, path_id, env=ctx.env) target_ref = output.output_as_value(target_ref, env=ctx.env) else: target_ref = pathctx.get_rvar_path_identity_var( input_rvar, path_id, env=ctx.env) if isinstance(ir_stmt, irast.UpdateStmt) 
and not target_is_scalar: actual_ptrref = irtyputils.find_actual_ptrref(source_typeref, ptrref) target_ref = check_update_type( target_ref, input_rvar, is_subquery=False, ir_stmt=ir_stmt, ir_set=ir_expr, shape_ptrref=ptrref, actual_ptrref=actual_ptrref, ctx=ctx, ) if ptr_is_multi_required and enforce_cardinality: target_ref = pgast.FuncCall( name=astutils.edgedb_func('raise_on_null', ctx=ctx), args=[ target_ref, pgast.StringConstant(val='not_null_violation'), pgast.NamedFuncArg( name='msg', val=pgast.StringConstant(val='missing value'), ), pgast.NamedFuncArg( name='column', val=pgast.StringConstant(val=str(ptrref.id)), ), ], ) source_data['target'] = (path_id, target_ref) row_query.target_list.append( pgast.ResTarget( name='source', val=pathctx.get_rvar_path_identity_var( dml_rvar, ir_stmt.subject.path_id, env=ctx.env, ), ), ) specified_cols = ['source'] for col, (col_path_id, expr) in source_data.items(): row_query.target_list.append( pgast.ResTarget( val=expr, name=col, ), ) specified_cols.append(col) # XXX: This is dodgy. Do we need to do the dynamic rvar thing? # XXX: And can we make defaults work? pathctx._put_path_output_var( row_query, col_path_id, aspect=pgce.PathAspect.VALUE, var=pgast.ColumnRef(name=[col]), ) link_rows = pgast.CommonTableExpr( query=row_query, name=ctx.env.aliases.get(hint='r'), for_dml_stmt=ctx.get_current_dml_stmt(), ) return link_rows, specified_cols def process_delete_body( *, ir_stmt: irast.DeleteStmt, delete_cte: pgast.CommonTableExpr, typeref: irast.TypeRef, ctx: context.CompilerContextLevel, ) -> None: """Finalize DELETE on an object. The actual DELETE was generated in gen_dml_cte, so we only have work to do here if there are link tables to clean up. 
""" ctx.toplevel_stmt.append_cte(delete_cte) put_iterator_bond(ctx.enclosing_cte_iterator, delete_cte.query) pointers = ir_stmt.links_to_delete[typeref.id] for ptrref in pointers: target_rvar = relctx.range_for_ptrref( ptrref, for_mutation=True, only_self=True, ctx=ctx) assert isinstance(target_rvar, pgast.RelRangeVar) range_rvar = pgast.RelRangeVar( relation=delete_cte, alias=pgast.Alias( aliasname=ctx.env.aliases.get(hint='range') ) ) where_clause = astutils.new_binop( lexpr=pgast.ColumnRef(name=[ target_rvar.alias.aliasname, 'source' ]), op='=', rexpr=pathctx.get_rvar_path_identity_var( range_rvar, ir_stmt.result.path_id, env=ctx.env) ) del_query = pgast.DeleteStmt( relation=target_rvar, where_clause=where_clause, using_clause=[range_rvar], ) ctx.toplevel_stmt.append_cte(pgast.CommonTableExpr( query=del_query, name=ctx.env.aliases.get(hint='mlink'), for_dml_stmt=ctx.get_current_dml_stmt(), )) # Trigger compilation def compile_triggers( triggers: tuple[tuple[irast.Trigger, ...], ...], stmt: pgast.Base, *, ctx: context.CompilerContextLevel, ) -> None: if not triggers: return assert isinstance(stmt, pgast.Query) if stmt.ctes is None: stmt.ctes = [] start_ctes = len(stmt.ctes) with ctx.new() as ictx: # Clear out type_ctes so that we will recompile them all with # our overlays baked in (trigger_mode = True causes the # overlays to be included), so that access policies still # apply to our "new view" of the database. # FIXME: I think we actually need to keep the old type_ctes # available for pointers off of __old__ to use. 
ictx.trigger_mode = True ictx.type_rewrite_ctes = {} ictx.type_inheritance_ctes = {} ictx.ordered_type_ctes = [] ictx.toplevel_stmt = stmt for stage in triggers: new_overlays = [] for trigger in stage: ictx.path_scope = ctx.path_scope.new_child() new_overlays.append(compile_trigger(trigger, ctx=ictx)) for overlay in new_overlays: ictx.rel_overlays.type = ( ictx.rel_overlays.type.update(overlay.type)) ictx.rel_overlays.ptr = ( ictx.rel_overlays.ptr.update(overlay.ptr)) # Install any newly created type CTEs before the CTEs created from # this trigger compilation but after anything compiled before. stmt.ctes[start_ctes:start_ctes] = ictx.ordered_type_ctes def compile_trigger( trigger: irast.Trigger, *, ctx: context.CompilerContextLevel, ) -> context.RelOverlays: # N.B: The *base type* overlays have the whole union, while subtypes # just have subtype things. # The things we produce for `affected` take this into account. new_path = trigger.new_set.path_id old_path = trigger.old_set.path_id if trigger.old_set else None # We use overlays to drive the trigger, since with a bit of # tweaking, they contain all the relevant information. 
overlays: list[context.OverlayEntry] = [] for typeref, dml in trigger.affected: toverlays = ctx.rel_overlays.type[dml] if ov := toverlays.get(typeref.id): overlays.extend(ov) # Handle deletions by turning except into union # Drop FILTER, which is included by update but doesn't help us here overlays = [ ( (context.OverlayOp.UNION, *x[1:]) if x[0] == context.OverlayOp.EXCEPT else x ) for x in overlays if x[0] != context.OverlayOp.FILTER ] # Replace an initial union with REPLACE, since we *don't* want whatever # already existed assert overlays and overlays[0][0] == context.OverlayOp.UNION overlays[0] = (context.OverlayOp.REPLACE, *overlays[0][1:]) # Produce a CTE containing all of the affected objects for this trigger with ctx.newrel() as ictx: ictx.rel_overlays = context.RelOverlays() ictx.rel_overlays.type = immu.Map({ None: immu.Map({trigger.source_type.id: tuple(overlays)}) }) # The range produced here will be driven just by the overlays rvar = relctx.range_for_material_objtype( trigger.source_type, new_path, include_overlays=True, ignore_rewrites=True, ctx=ictx, ) relctx.include_rvar( ictx.rel, rvar, path_id=new_path, ctx=ictx ) # If __old__ is available, we register its identity/value, # but *not* its source. if old_path: new_ident = pathctx.get_path_identity_var( ictx.rel, new_path, env=ctx.env ) pathctx.put_path_identity_var(ictx.rel, old_path, new_ident) pathctx.put_path_value_var(ictx.rel, old_path, new_ident) contents_cte = pgast.CommonTableExpr( query=ictx.rel, name=ctx.env.aliases.get('trig_contents'), materialized=True, # XXX: or not? 
for_dml_stmt=ctx.get_current_dml_stmt(), ) ictx.toplevel_stmt.append_cte(contents_cte) # Actually compile the trigger with ctx.newrel() as tctx: # With FOR EACH, we use the iterator machinery to iterate over # all of the objects if trigger.scope == qltypes.TriggerScope.Each: tctx.enclosing_cte_iterator = pgast.IteratorCTE( path_id=new_path, cte=contents_cte, parent=None, # old_path gets registered as also appearing in the # iterator cte, and so will get included whenever # merged other_paths=( ((old_path, pgce.PathAspect.IDENTITY),) if old_path else () ), ) merge_iterator(tctx.enclosing_cte_iterator, tctx.rel, ctx=ctx) # While with FOR ALL, we register the sets as external rels else: tctx.external_rels = dict(tctx.external_rels) # new_path is just the contents_cte tctx.external_rels[new_path] = ( contents_cte, (pgce.PathAspect.VALUE, pgce.PathAspect.SOURCE) ) if old_path: # old_path is *also* the contents_cte, but without a source # aspect, so we need to include the real database back in. tctx.external_rels[old_path] = ( contents_cte, (pgce.PathAspect.VALUE, pgce.PathAspect.IDENTITY,) ) # This is somewhat subtle: we merge *every* DML into # the "None" overlay, so that all the new database state shows # up everywhere... but __old__ has a TriggerAnchor set up in # it, which acts like a dml statement, and *diverts* __old__ # away from the new data! # We grab the list of DML out of dml_stmts instead of just # from the overlays for determinism reasons; it affects the # order in which overlays appear all_dml = [ x for x in ctx.dml_stmts if isinstance(x, irast.MutatingStmt)] merge_overlays_globally(all_dml, ctx=tctx) # Strip out everything but None. This tidies things up and makes # it easy to detect new additions. tctx.rel_overlays.type = immu.Map({None: tctx.rel_overlays.type[None]}) tctx.rel_overlays.ptr = immu.Map({None: tctx.rel_overlays.ptr[None]}) # Copy over the global overlay to __new__, since it should see # the new data also. 
# TODO: We should consider building a dedicated __new__overlay # in order to reduce overlay sizes in common cases assert isinstance(trigger.new_set.expr, irast.TriggerAnchor) tctx.rel_overlays.type = tctx.rel_overlays.type.set( trigger.new_set.expr, tctx.rel_overlays.type[None]) tctx.rel_overlays.ptr = tctx.rel_overlays.ptr.set( trigger.new_set.expr, tctx.rel_overlays.ptr[None]) # N.B: Any DML in the trigger will have the "global" overlay (None) # as its starting point. dispatch.compile(trigger.expr, ctx=tctx) # Force the value to get output so that if it might error # it will be forced up by check_ctes pathctx.get_path_value_output( tctx.rel, trigger.expr.path_id, env=ctx.env) pathctx.get_path_serialized_output( tctx.rel, trigger.expr.path_id, env=ctx.env) # If the expression is *just* DML, as an optimization, skip # generating a CTE for the expression and forcing its evaluation # with check_ctes. The actual work is all in a DML CTE so we # don't need to worry about anything more. if ( not isinstance(trigger.expr.expr, irast.MutatingStmt) and not trigger.expr.shape ): trigger_cte = pgast.CommonTableExpr( query=tctx.rel, name=ctx.env.aliases.get('trig_body'), materialized=True, # XXX: or not? for_dml_stmt=ctx.get_current_dml_stmt(), ) tctx.toplevel_stmt.append_cte(trigger_cte) tctx.env.check_ctes.append(trigger_cte) saved_overlays = tctx.rel_overlays.copy() saved_overlays.type = saved_overlays.type.delete(trigger.new_set.expr) saved_overlays.ptr = saved_overlays.ptr.delete(trigger.new_set.expr) return saved_overlays ================================================ FILE: edb/pgsql/compiler/enums.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from edb.common import enum as s_enum class PathAspect(s_enum.StrEnum): IDENTITY = 'identity' VALUE = 'value' SOURCE = 'source' SERIALIZED = 'serialized' ITERATOR = 'iterator' ================================================ FILE: edb/pgsql/compiler/expr.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Compilation handlers for non-statement expressions.""" from __future__ import annotations from typing import Optional, Sequence from edb import errors from edb.edgeql import qltypes as ql_ft from edb.edgeql import ast as qlast from edb.ir import ast as irast from edb.ir import typeutils as irtyputils from edb.ir import utils as irutils from edb.pgsql import ast as pgast from edb.pgsql import common from edb.pgsql import types as pg_types from . import astutils from . import config from . import context from . import dispatch from . import enums as pgce from . 
import expr as expr_compiler # NOQA from . import output from . import pathctx from . import relgen from . import shapecomp @dispatch.compile.register(irast.Set) def compile_Set( ir_set: irast.Set, *, ctx: context.CompilerContextLevel) -> pgast.BaseExpr: if ctx.singleton_mode: return dispatch.compile(ir_set.expr, ctx=ctx) is_toplevel = ctx.toplevel_stmt is context.NO_STMT _compile_set_impl(ir_set, ctx=ctx) if is_toplevel: if isinstance(ir_set.expr, irast.ConfigCommand): return config.top_output_as_config_op( ir_set, ctx.rel, env=ctx.env) else: pathctx.get_path_serialized_output( ctx.rel, ir_set.path_id, env=ctx.env) return output.top_output_as_value(ctx.rel, ir_set, env=ctx.env) else: value = pathctx.get_path_value_var( ctx.rel, ir_set.path_id, env=ctx.env) return output.output_as_value(value, env=ctx.env) @dispatch.visit.register(irast.Set) def visit_Set( ir_set: irast.Set, *, ctx: context.CompilerContextLevel) -> None: if ctx.singleton_mode: dispatch.compile(ir_set.expr, ctx=ctx) _compile_set_impl(ir_set, ctx=ctx) def _compile_set_impl( ir_set: irast.Set, *, ctx: context.CompilerContextLevel) -> None: is_toplevel = ctx.toplevel_stmt is context.NO_STMT if isinstance(ir_set.expr, (irast.BaseConstant, irast.BaseParameter)): # Avoid creating needlessly complicated constructs for # constant expressions. Besides being an optimization, # this helps in GROUP BY queries. value = dispatch.compile(ir_set.expr, ctx=ctx) if is_toplevel: ctx.rel = ctx.toplevel_stmt = pgast.SelectStmt() pathctx.put_path_value_var_if_not_exists( ctx.rel, ir_set.path_id, value) if (output.in_serialization_ctx(ctx) and ir_set.shape and not ctx.env.ignore_object_shapes): _compile_shape(ir_set, ir_set.shape, ctx=ctx) elif ir_set.path_scope_id is not None and not is_toplevel: # This Set is behind a scope fence, so compute it # in a fenced context. with ctx.newscope() as scopectx: _compile_set(ir_set, ctx=scopectx) else: # All other sets. 
_compile_set(ir_set, ctx=ctx) @dispatch.compile.register(irast.QueryParameter) def compile_QueryParameter( expr: irast.QueryParameter, *, ctx: context.CompilerContextLevel, ) -> pgast.BaseExpr: result: pgast.BaseExpr params = [p for p in ctx.env.query_params if p.name == expr.name] param = params[0] if params else None if param and param.sub_params: return relgen.process_encoded_param(param, ctx=ctx) else: index = ctx.argmap[expr.name].index result = pgast.ParamRef(number=index, nullable=not expr.required) if irtyputils.needs_custom_serialization(expr.typeref): if irtyputils.is_array(expr.typeref): subt = expr.typeref.subtypes[0] el_sql_type = subt.real_base_type.custom_sql_serialization # Arrays of text encoded types need to come in as the custom type result = pgast.TypeCast( arg=result, type_name=pgast.TypeName(name=(f'{el_sql_type}[]',)), ) else: el_sql_type = expr.typeref.real_base_type.custom_sql_serialization assert el_sql_type is not None result = pgast.TypeCast( arg=result, type_name=pgast.TypeName(name=(el_sql_type,)), ) return pgast.TypeCast( arg=result, type_name=pgast.TypeName( name=pg_types.pg_type_from_ir_typeref(expr.typeref) ) ) @dispatch.compile.register(irast.FunctionParameter) def compile_FunctionParameter( expr: irast.FunctionParameter, *, ctx: context.CompilerContextLevel, ) -> pgast.BaseExpr: result: pgast.BaseExpr if ctx.env.named_param_prefix is not None: # When compiling functions result = pgast.ColumnRef( name=ctx.env.named_param_prefix + (expr.name,), nullable=not expr.required, ) else: # Other things such as constraints index = ctx.argmap[expr.name].index result = pgast.ParamRef(number=index, nullable=not expr.required) return pgast.TypeCast( arg=result, type_name=pgast.TypeName( name=pg_types.pg_type_from_ir_typeref(expr.typeref) ) ) @dispatch.compile.register(irast.StringConstant) def compile_StringConstant( expr: irast.StringConstant, *, ctx: context.CompilerContextLevel) -> pgast.BaseExpr: return pgast.TypeCast( 
arg=pgast.StringConstant(val=expr.value), type_name=pgast.TypeName( name=pg_types.pg_type_from_ir_typeref(expr.typeref) ) ) @dispatch.compile.register(irast.BytesConstant) def compile_BytesConstant( expr: irast.BytesConstant, *, ctx: context.CompilerContextLevel ) -> pgast.BaseExpr: return pgast.ByteaConstant(val=expr.value) @dispatch.compile.register(irast.FloatConstant) @dispatch.compile.register(irast.DecimalConstant) @dispatch.compile.register(irast.BigintConstant) @dispatch.compile.register(irast.IntegerConstant) def compile_FloatConstant( expr: irast.BaseConstant, *, ctx: context.CompilerContextLevel) -> pgast.BaseExpr: return pgast.TypeCast( arg=pgast.NumericConstant(val=expr.value), type_name=pgast.TypeName( name=pg_types.pg_type_from_ir_typeref(expr.typeref) ) ) @dispatch.compile.register(irast.BooleanConstant) def compile_BooleanConstant( expr: irast.BooleanConstant, *, ctx: context.CompilerContextLevel) -> pgast.BaseExpr: return pgast.TypeCast( arg=pgast.BooleanConstant(val=expr.value.lower() == 'true'), type_name=pgast.TypeName( name=pg_types.pg_type_from_ir_typeref(expr.typeref) ) ) @dispatch.compile.register(irast.TypeCast) def compile_TypeCast( expr: irast.TypeCast, *, ctx: context.CompilerContextLevel) -> pgast.BaseExpr: pg_expr = dispatch.compile(expr.expr, ctx=ctx) detail: Optional[pgast.StringConstant] = None if expr.error_message_context is not None: detail = pgast.StringConstant( val=( '{"error_message_context": "' + expr.error_message_context + '"}' ) ) if expr.sql_cast: # Use explicit SQL cast. pg_type = pg_types.pg_type_from_ir_typeref(expr.to_type) res: pgast.BaseExpr = pgast.TypeCast( arg=pg_expr, type_name=pgast.TypeName( name=pg_type ) ) elif expr.sql_expr: # Cast implemented as a function. 
assert expr.cast_name func_name = common.get_cast_backend_name( expr.cast_name, aspect="function", versioned=ctx.env.versioned_stdlib, ) args = [pg_expr] if detail is not None: args.append(detail) res = pgast.FuncCall( name=func_name, args=args, ) elif expr.sql_function: res = pgast.FuncCall( name=tuple(expr.sql_function.split(".")), args=[pg_expr], ) else: raise errors.UnsupportedFeatureError('cast not supported') if expr.cardinality_mod is qlast.CardinalityModifier.Required: args = [ res, pgast.StringConstant( val='invalid_parameter_value', ), pgast.StringConstant( val='invalid null value in cast', ), ] if detail is not None: args.append(detail) res = pgast.FuncCall( name=astutils.edgedb_func('raise_on_null', ctx=ctx), args=args ) return res @dispatch.compile.register(irast.IndexIndirection) def compile_IndexIndirection( expr: irast.IndexIndirection, *, ctx: context.CompilerContextLevel) -> pgast.BaseExpr: # Handle Expr[Index], where Expr may be std::str, array or # std::json. For strings we translate this into substr calls. # Arrays use the native index access. JSON is handled by using the # `->` accessor. Additionally, in all of the above cases a # boundary-check is performed on the index and an exception is # potentially raised. 
# line, column and filename are captured here to be used with the # error message span = pgast.StringConstant( val=irutils.get_span_as_json( expr.index, errors.InvalidValueError ) ) with ctx.new() as subctx: subctx.expr_exposed = False subj = dispatch.compile(expr.expr, ctx=subctx) index = dispatch.compile(expr.index, ctx=subctx) result: pgast.BaseExpr = pgast.FuncCall( name=astutils.edgedb_func('_index', ctx=ctx), args=[subj, index, span] ) if irtyputils.is_array(expr.typeref): # Unwrap the nested array from its tuple result = astutils.array_get_inner_array(result, expr.typeref) return result @dispatch.compile.register(irast.SliceIndirection) def compile_SliceIndirection( expr: irast.SliceIndirection, *, ctx: context.CompilerContextLevel ) -> pgast.BaseExpr: # Handle Expr[Start:Stop], where Expr may be std::str, array or # std::json. For strings we translate this into substr calls. # Arrays use the native slice syntax. JSON is handled by a # combination of unnesting aggregation and array slicing. 
with ctx.new() as subctx: subctx.expr_exposed = False subj = dispatch.compile(expr.expr, ctx=subctx) if expr.start is None: start: pgast.BaseExpr = pgast.LiteralExpr(expr="0") else: start = dispatch.compile(expr.start, ctx=subctx) if expr.stop is None: stop: pgast.BaseExpr = pgast.LiteralExpr(expr=str(2**31 - 1)) else: stop = dispatch.compile(expr.stop, ctx=subctx) typ = expr.expr.typeref inline_array_slicing = irtyputils.is_array(typ) and any( irtyputils.is_tuple(st) or irtyputils.is_array(st) for st in typ.subtypes ) if inline_array_slicing: return _inline_array_slicing(subj, start, stop, ctx=ctx) else: return pgast.FuncCall( name=astutils.edgedb_func('_slice', ctx=ctx), args=[subj, start, stop] ) def _inline_array_slicing( subj: pgast.BaseExpr, start: pgast.BaseExpr, stop: pgast.BaseExpr, ctx: context.CompilerContextLevel ) -> pgast.BaseExpr: return pgast.Indirection( arg=subj, indirection=[ pgast.Slice( lidx=pgast.FuncCall( name=astutils.edgedb_func( '_normalize_array_slice_index', ctx=ctx), args=[ start, pgast.FuncCall( name=("cardinality",), args=[subj] ), ], ), ridx=astutils.new_binop( lexpr=pgast.FuncCall( name=astutils.edgedb_func( '_normalize_array_slice_index', ctx=ctx), args=[ stop, pgast.FuncCall( name=("cardinality",), args=[subj] ), ], ), op="-", rexpr=pgast.LiteralExpr(expr="1"), ), ) ], ) def _compile_call_args( expr: irast.Call, *, ctx: context.CompilerContextLevel ) -> tuple[list[pgast.BaseExpr], list[pgast.BaseExpr]]: args = [] maybe_null = [] if isinstance(expr, irast.FunctionCall) and expr.global_args: args += [dispatch.compile(arg, ctx=ctx) for arg in expr.global_args] for ir_arg in expr.args.values(): ref = dispatch.compile(ir_arg.expr, ctx=ctx) args.append(ref) if ( not expr.impl_is_strict and ir_arg.cardinality.can_be_zero() and ref.nullable and ir_arg.param_typemod == ql_ft.TypeModifier.SingletonType ): maybe_null.append(ref) return args, maybe_null def _wrap_call( expr: pgast.BaseExpr, maybe_nulls: list[pgast.BaseExpr], *, ctx: 
context.CompilerContextLevel ) -> pgast.BaseExpr: # If necessary, use CASE to filter out NULLs while calling a # non-strict function. if maybe_nulls: tests = [pgast.NullTest(arg=arg, negated=True) for arg in maybe_nulls] expr = pgast.CaseExpr( args=[pgast.CaseWhen( expr=astutils.extend_binop(None, *tests, op='AND'), result=expr, )] ) return expr @dispatch.compile.register(irast.OperatorCall) def compile_OperatorCall( expr: irast.OperatorCall, *, ctx: context.CompilerContextLevel) -> pgast.BaseExpr: if (str(expr.func_shortname) == 'std::IF' and expr.args[0].cardinality.is_single() and expr.args[2].cardinality.is_single()): if_expr, condition, else_expr = (a.expr for a in expr.args.values()) return pgast.CaseExpr( args=[ pgast.CaseWhen( expr=dispatch.compile(condition, ctx=ctx), result=dispatch.compile(if_expr, ctx=ctx)) ], defresult=dispatch.compile(else_expr, ctx=ctx)) elif (str(expr.func_shortname) == 'std::??' and expr.args[0].cardinality.is_single() and expr.args[1].cardinality.is_single()): l_expr, r_expr = (a.expr for a in expr.args.values()) return pgast.CoalesceExpr( args=[ dispatch.compile(l_expr, ctx=ctx), dispatch.compile(r_expr, ctx=ctx), ], ) elif irutils.is_singleton_set_of_call(expr): pass elif irutils.returns_set_of(expr): raise errors.UnsupportedFeatureError( f"set returning operator '{expr.func_shortname}' is not supported " f"in singleton expressions") elif irutils.has_set_of_param(expr): raise errors.UnsupportedFeatureError( f"aggregate operator '{expr.func_shortname}' is not supported " f"in singleton expressions") args, maybe_null = _compile_call_args(expr, ctx=ctx) return _wrap_call( compile_operator(expr, args, ctx=ctx), maybe_null, ctx=ctx) def compile_operator( expr: irast.OperatorCall, args: Sequence[pgast.BaseExpr], *, ctx: context.CompilerContextLevel) -> pgast.BaseExpr: lexpr = rexpr = None result: Optional[pgast.BaseExpr] = None if expr.operator_kind is ql_ft.OperatorKind.Infix: lexpr, rexpr = args elif expr.operator_kind is 
ql_ft.OperatorKind.Prefix: rexpr = args[0] elif expr.operator_kind is ql_ft.OperatorKind.Postfix: lexpr = args[0] else: raise RuntimeError(f'unexpected operator kind: {expr.operator_kind!r}') str_func_name = str(expr.func_shortname) if ((str_func_name in {'std::=', 'std::!='} or str(expr.origin_name) in {'std::=', 'std::!='}) and expr.args[0].expr.typeref is not None and irtyputils.is_object(expr.args[0].expr.typeref) and expr.args[1].expr.typeref is not None and irtyputils.is_object(expr.args[1].expr.typeref)): if str_func_name == 'std::=' or str(expr.origin_name) == 'std::=': sql_oper = '=' else: sql_oper = '!=' elif str_func_name == 'std::EXISTS': assert rexpr result = pgast.NullTest(arg=rexpr, negated=True) elif expr.func_shortname in common.operator_map: sql_oper = common.operator_map[expr.func_shortname] elif expr.sql_operator: sql_oper = expr.sql_operator[0] if len(expr.sql_operator) > 1: # Explicit operand types given in FROM SQL OPERATOR lexpr, rexpr = _cast_operands(lexpr, rexpr, expr.sql_operator[1:]) elif expr.origin_name is not None: sql_oper = common.get_operator_backend_name( expr.origin_name)[1] else: if expr.sql_function: sql_func, *cast_types = expr.sql_function func_name = common.maybe_versioned_name( tuple(sql_func.split('.', 1)), versioned=( ctx.env.versioned_stdlib and expr.func_shortname.get_root_module_name().name != 'ext' ), ) if cast_types: # Explicit operand types given in FROM SQL FUNCTION lexpr, rexpr = _cast_operands(lexpr, rexpr, cast_types) else: func_name = common.get_operator_backend_name( expr.func_shortname, aspect='function', versioned=ctx.env.versioned_stdlib) args = [] if lexpr is not None: args.append(lexpr) if rexpr is not None: args.append(rexpr) result = pgast.FuncCall(name=func_name, args=args) # If result was not already computed, it's going to be a generic Expr. 
if result is None: result = pgast.Expr( name=sql_oper, lexpr=lexpr, rexpr=rexpr, ) if expr.force_return_cast: # The underlying operator has a return value type # different from that of the EdgeQL operator declaration, # so we need to make an explicit cast here. result = pgast.TypeCast( arg=result, type_name=pgast.TypeName( name=pg_types.pg_type_from_ir_typeref(expr.typeref) ) ) return result def _cast_operands( lexpr: Optional[pgast.BaseExpr], rexpr: Optional[pgast.BaseExpr], sql_types: Sequence[str], ) -> tuple[Optional[pgast.BaseExpr], Optional[pgast.BaseExpr]]: if lexpr is not None: lexpr = pgast.TypeCast( arg=lexpr, type_name=pgast.TypeName( name=(sql_types[0],) ) ) if rexpr is not None: rexpr_qry = None if (isinstance(rexpr, pgast.SubLink) and isinstance(rexpr.expr, pgast.SelectStmt)): rexpr_qry = rexpr.expr elif isinstance(rexpr, pgast.SelectStmt): rexpr_qry = rexpr if rexpr_qry is not None: # Handle cases like foo <op> ANY (SELECT) and # foo <op> (SELECT). rexpr_qry.target_list[0] = pgast.ResTarget( name=rexpr_qry.target_list[0].name, val=pgast.TypeCast( arg=rexpr_qry.target_list[0].val, type_name=pgast.TypeName( name=(sql_types[1],) ) ) ) else: rexpr = pgast.TypeCast( arg=rexpr, type_name=pgast.TypeName( name=(sql_types[1],) ) ) return lexpr, rexpr def get_func_call_backend_name( expr: irast.FunctionCall, *, ctx: context.CompilerContextLevel ) -> tuple[str, ...]: if expr.func_sql_function: # The name might contain a "." if it's one of our # metaschema helpers. 
func_name = common.maybe_versioned_name( tuple(expr.func_sql_function.split('.', 1)), versioned=( ctx.env.versioned_stdlib and expr.func_shortname.get_root_module_name().name != 'ext' ), ) else: func_name = common.get_function_backend_name( expr.func_shortname, expr.backend_name, versioned=ctx.env.versioned_stdlib) return func_name @dispatch.compile.register(irast.TypeCheckOp) def compile_TypeCheckOp( expr: irast.TypeCheckOp, *, ctx: context.CompilerContextLevel) -> pgast.BaseExpr: with ctx.new() as newctx: newctx.expr_exposed = False left = dispatch.compile(expr.left, ctx=newctx) negated = expr.op == 'IS NOT' result: pgast.BaseExpr if expr.result is not None: result = pgast.BooleanConstant( val=(expr.result and not negated) ) else: right: pgast.BaseExpr if expr.right.union: right = pgast.ArrayExpr( elements=[ dispatch.compile(c, ctx=newctx) for c in expr.right.union ] ) else: right = dispatch.compile(expr.right, ctx=newctx) result = pgast.FuncCall( name=astutils.edgedb_func('issubclass', ctx=ctx), args=[left, right]) if negated: result = astutils.new_unop('NOT', result) return result @dispatch.compile.register(irast.ConstantSet) def compile_ConstantSet( expr: irast.ConstantSet, *, ctx: context.CompilerContextLevel) -> pgast.BaseExpr: raise errors.UnsupportedFeatureError( "Constant sets not allowed in singleton mode", hint="Are you passing a set into a variadic function?") @dispatch.compile.register(irast.Array) def compile_Array( expr: irast.Array, *, ctx: context.CompilerContextLevel) -> pgast.BaseExpr: elements = [dispatch.compile(e, ctx=ctx) for e in expr.elements] return relgen.build_array_expr(expr, elements, ctx=ctx) @dispatch.compile.register(irast.Tuple) def compile_Tuple( expr: irast.Tuple, *, ctx: context.CompilerContextLevel) -> pgast.BaseExpr: ttype = expr.typeref ttypes = {} for i, st in enumerate(ttype.subtypes): if st.element_name: ttypes[st.element_name] = st else: ttypes[str(i)] = st telems = list(ttypes) elements = [] for i, e in 
enumerate(expr.elements): telem = telems[i] ttype = ttypes[telem] val = dispatch.compile(e.val, ctx=ctx) assert e.path_id elements.append(pgast.TupleElement(path_id=e.path_id, val=val)) result = pgast.TupleVar(elements=elements, typeref=ttype) return output.output_as_value(result, env=ctx.env) @dispatch.compile.register(irast.TypeRef) def compile_TypeRef( expr: irast.TypeRef, *, ctx: context.CompilerContextLevel) -> pgast.BaseExpr: return astutils.compile_typeref(expr) @dispatch.compile.register(irast.TypeIntrospection) def compile_TypeIntrospection( expr: irast.TypeIntrospection, *, ctx: context.CompilerContextLevel) -> pgast.BaseExpr: return astutils.compile_typeref(expr.output_typeref) @dispatch.compile.register(irast.FunctionCall) def compile_FunctionCall( expr: irast.FunctionCall, *, ctx: context.CompilerContextLevel) -> pgast.BaseExpr: fname = str(expr.func_shortname) if sfunc := relgen._SIMPLE_SPECIAL_FUNCTIONS.get(fname): return sfunc(expr, ctx=ctx) if expr.func_sql_expr: raise errors.UnsupportedFeatureError( f'unimplemented function for singleton mode: {fname}' ) if irutils.is_singleton_set_of_call(expr): pass elif irutils.returns_set_of(expr): raise errors.UnsupportedFeatureError( 'set returning functions are not supported in simple expressions') elif irutils.has_set_of_param(expr): raise errors.UnsupportedFeatureError( f"aggregate function '{expr.func_shortname}' is not supported " f"in singleton expressions") args, maybe_null = _compile_call_args(expr, ctx=ctx) if expr.has_empty_variadic and expr.variadic_param_type is not None: var = pgast.TypeCast( arg=pgast.ArrayExpr(elements=[]), type_name=pgast.TypeName( name=pg_types.pg_type_from_ir_typeref(expr.variadic_param_type) ) ) args.append(pgast.VariadicArgument(expr=var)) name = get_func_call_backend_name(expr, ctx=ctx) result: pgast.BaseExpr = pgast.FuncCall(name=name, args=args) result = _wrap_call(result, maybe_null, ctx=ctx) if expr.force_return_cast: # The underlying function has a return value type 
# different from that of the EdgeQL function declaration, # so we need to make an explicit cast here. result = pgast.TypeCast( arg=result, type_name=pgast.TypeName( name=pg_types.pg_type_from_ir_typeref(expr.typeref) ) ) return result def _tuple_to_row_expr( tuple_set: irast.Set, *, ctx: context.CompilerContextLevel, ) -> pgast.ImplicitRowExpr | pgast.RowExpr: tuple_val = dispatch.compile(tuple_set, ctx=ctx) if not isinstance(tuple_val, (pgast.RowExpr, pgast.ImplicitRowExpr)): raise RuntimeError('tuple compilation unexpectedly did ' 'not return a RowExpr or ImplicitRowExpr') return tuple_val def _compile_set( ir_set: irast.Set, *, ctx: context.CompilerContextLevel) -> None: relgen.get_set_rvar(ir_set, ctx=ctx) if (output.in_serialization_ctx(ctx) and ir_set.shape and not ctx.env.ignore_object_shapes): _compile_shape(ir_set, ir_set.shape, ctx=ctx) elif ir_set.shape and ir_set in ctx.shapes_needed_by_dml: # If this shape is needed for DML purposes (because it is # populating link properties), compile it and populate its # elements as *values*, so that process_link_values can pick # them up and insert them. shape_tuple = shapecomp.compile_shape(ir_set, ir_set.shape, ctx=ctx) for element in shape_tuple.elements: pathctx.put_path_var_if_not_exists( ctx.rel, element.path_id, element.val, aspect=pgce.PathAspect.VALUE, ) def _compile_shape( ir_set: irast.Set, shape: Sequence[tuple[irast.SetE[irast.Pointer], qlast.ShapeOp]], *, ctx: context.CompilerContextLevel) -> pgast.TupleVar: result = shapecomp.compile_shape(ir_set, shape, ctx=ctx) for element in result.elements: # We want to force, because the path id might already exist # serialized with a different shape, and we need ours to be # visible. (Anything needing the old one needs to have pulled # it already: see the "unfortunate hack" in # process_set_as_tuple.) 
pathctx.put_path_serialized_var( ctx.rel, element.path_id, element.val, force=True ) # When we compile a shape during materialization, stash the # set away so we can consume it in unpack_rvar. if ( ctx.materializing and ir_set.typeref.id not in ctx.env.materialized_views ): ctx.env.materialized_views[ir_set.typeref.id] = ir_set ser_elements = [] for el in result.elements: ser_val = pathctx.get_path_serialized_or_value_var( ctx.rel, el.path_id, env=ctx.env) ser_elements.append(pgast.TupleElement( path_id=el.path_id, name=el.name, val=ser_val )) # Don't let the elements themselves leak out, since they may # be wrapped in arrays. pathctx.put_path_id_mask(ctx.rel, el.path_id) ser_result = pgast.TupleVar(elements=ser_elements, named=True) sval = output.serialize_expr( ser_result, path_id=ir_set.path_id, env=ctx.env ) pathctx.put_path_serialized_var(ctx.rel, ir_set.path_id, sval, force=True) return result @dispatch.compile.register def compile_EmptySet( expr: irast.EmptySet, *, ctx: context.CompilerContextLevel ) -> pgast.BaseExpr: return pgast.NullConstant() @dispatch.compile.register def compile_TypeRoot( expr: irast.TypeRoot, *, ctx: context.CompilerContextLevel ) -> pgast.BaseExpr: name = [common.edgedb_name_to_pg_name(str(expr.typeref.id))] if irtyputils.is_object(expr.typeref): name.append('id') return pgast.ColumnRef(name=name) @dispatch.compile.register def compile_Pointer( rptr: irast.Pointer, *, ctx: context.CompilerContextLevel ) -> pgast.BaseExpr: assert ctx.singleton_mode if rptr.expr: return dispatch.compile(rptr.expr, ctx=ctx) ptrref = rptr.ptrref source = rptr.source if ptrref.source_ptr is None and isinstance(source.expr, irast.Pointer): raise errors.UnsupportedFeatureError( 'unexpectedly long path in simple expr') # In most cases, we don't need to reference the rvar (since there # will be only one in scope), but sometimes we do (for example NEW # in trigger functions). 
rvar_name = [] if src := ctx.env.external_rvars.get( (source.path_id, pgce.PathAspect.SOURCE) ): rvar_name = [src.alias.aliasname] # compile column name ptr_stor_info = pg_types.get_ptrref_storage_info( ptrref, resolve_type=False) colref = pgast.ColumnRef( name=rvar_name + [ptr_stor_info.column_name], nullable=rptr.dir_cardinality.can_be_zero()) return colref @dispatch.compile.register def compile_TupleIndirectionPointer( rptr: irast.TupleIndirectionPointer, *, ctx: context.CompilerContextLevel ) -> pgast.BaseExpr: tuple_val = dispatch.compile(rptr.source, ctx=ctx) set_expr = astutils.tuple_getattr( tuple_val, rptr.source.typeref, rptr.ptrref.shortname.name, ) return set_expr @dispatch.compile.register(irast.FTSDocument) def compile_FTSDocument( expr: irast.FTSDocument, *, ctx: context.CompilerContextLevel ) -> pgast.BaseExpr: return pgast.FTSDocument( text=dispatch.compile(expr.text, ctx=ctx), language=dispatch.compile(expr.language, ctx=ctx), language_domain=expr.language_domain, weight=expr.weight, ) ================================================ FILE: edb/pgsql/compiler/group.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import Optional, AbstractSet from edb.edgeql import ast as qlast from edb.edgeql import desugar_group from edb.ir import ast as irast from edb.ir import utils as irutils from edb.pgsql import ast as pgast from . import astutils from . import clauses from . import context from . import dispatch from . import enums as pgce from . import output from . import pathctx from . import relctx from . import relgen def compile_grouping_atom( el: qlast.GroupingAtom, stmt: irast.GroupStmt, *, ctx: context.CompilerContextLevel ) -> pgast.Base: '''Compile a GroupingAtom into sql grouping sets''' if isinstance(el, qlast.GroupingIdentList): return pgast.GroupingOperation( args=[ compile_grouping_atom(at, stmt, ctx=ctx) for at in el.elements ], ) assert isinstance(el, qlast.ObjectRef) alias_set, _ = stmt.using[el.name] return pathctx.get_path_value_var( ctx.rel, alias_set.path_id, env=ctx.env) def compile_grouping_el( el: qlast.GroupingElement, stmt: irast.GroupStmt, *, ctx: context.CompilerContextLevel ) -> pgast.Base: '''Compile a GroupingElement into sql grouping sets''' if isinstance(el, qlast.GroupingSets): return pgast.GroupingOperation( operation='GROUPING SETS', args=[compile_grouping_el(sub, stmt, ctx=ctx) for sub in el.sets], ) elif isinstance(el, qlast.GroupingOperation): return pgast.GroupingOperation( operation=el.oper, args=[ compile_grouping_atom(at, stmt, ctx=ctx) for at in el.elements ], ) elif isinstance(el, qlast.GroupingSimple): return compile_grouping_atom(el.element, stmt, ctx=ctx) raise AssertionError('Unknown GroupingElement') def _compile_grouping_value( stmt: irast.GroupStmt, used_args: AbstractSet[str], *, ctx: context.CompilerContextLevel) -> pgast.BaseExpr: '''Produce the value for the grouping binding saying what is grouped on''' assert stmt.grouping_binding grouprel = ctx.rel # If there is only one grouping set, hardcode the output if all(isinstance(b, qlast.GroupingSimple) for b in stmt.by): return 
pgast.ArrayExpr( elements=[ pgast.StringConstant(val=desugar_group.key_name(arg)) for arg in used_args ], ) using = {k: stmt.using[k] for k in used_args} args = [ pathctx.get_path_var( grouprel, alias_set.path_id, aspect=pgce.PathAspect.VALUE, env=ctx.env, ) for alias_set, _ in using.values() ] # Call grouping on each element we group on to produce a bitmask grouping_alias = ctx.env.aliases.get('g') grouping_call = pgast.FuncCall(name=('grouping',), args=args) subq = pgast.SelectStmt( target_list=[ pgast.ResTarget(name=grouping_alias, val=grouping_call), ] ) q = pgast.SelectStmt( from_clause=[pgast.RangeSubselect( subquery=subq, alias=pgast.Alias(aliasname=ctx.env.aliases.get()) )] ) grouping_ref = pgast.ColumnRef(name=(grouping_alias,)) # Generate a call to ARRAY[...] with a case for each grouping # element, then array_remove out the NULLs. els: list[pgast.BaseExpr] = [] for i, name in enumerate(using): name = desugar_group.key_name(name) mask = 1 << (len(using) - i - 1) # (CASE (e & <mask>) WHEN 0 THEN '<name>' ELSE NULL END) els.append(pgast.CaseExpr( arg=pgast.Expr( name='&', lexpr=grouping_ref, rexpr=pgast.LiteralExpr(expr=str(mask)) ), args=[ pgast.CaseWhen( expr=pgast.LiteralExpr(expr='0'), result=pgast.StringConstant(val=name) ) ], defresult=pgast.NullConstant() )) val = pgast.FuncCall( name=('array_remove',), args=[pgast.ArrayExpr(elements=els), pgast.NullConstant()] ) q.target_list.append(pgast.ResTarget(val=val)) return q def _compile_grouping_binding( stmt: irast.GroupStmt, *, used_args: AbstractSet[str], ctx: context.CompilerContextLevel) -> None: assert stmt.grouping_binding pathctx.put_path_var( ctx.rel, stmt.grouping_binding.path_id, _compile_grouping_value(stmt, used_args=used_args, ctx=ctx), aspect=pgce.PathAspect.VALUE, ) def _compile_group( stmt: irast.GroupStmt, *, ctx: context.CompilerContextLevel, parent_ctx: context.CompilerContextLevel) -> pgast.BaseExpr: clauses.compile_volatile_bindings(stmt, ctx=ctx) query = ctx.stmt # Compile a GROUP BY into a 
subquery, along with all the aggregations with ctx.subrel() as groupctx: grouprel = groupctx.rel # First compile the actual subject # subrel *solely* for path id map reasons with groupctx.subrel() as subjctx: subjctx.expr_exposed = False dispatch.visit(stmt.subject, ctx=subjctx) if stmt.subject.path_id.is_objtype_path(): # This shouldn't technically be needed but we generate # better code with it. relgen.ensure_source_rvar( stmt.subject, subjctx.rel, ctx=subjctx) subj_rvar = relctx.rvar_for_rel( subjctx.rel, ctx=groupctx, lateral=True) aspects = pathctx.list_path_aspects(subjctx.rel, stmt.subject.path_id) pathctx.put_path_id_map( subjctx.rel, stmt.group_binding.path_id, stmt.subject.path_id) # update_mask=False because we are doing this solely to remap # elements individually and don't want to affect the mask. relctx.include_rvar( grouprel, subj_rvar, stmt.group_binding.path_id, aspects=aspects, update_mask=False, ctx=groupctx) relctx.include_rvar( grouprel, subj_rvar, stmt.subject.path_id, aspects=aspects, update_mask=False, ctx=groupctx) # Now we compile the bindings groupctx.path_scope = subjctx.path_scope.new_child() groupctx.path_scope[stmt.group_binding.path_id] = None if stmt.grouping_binding: groupctx.path_scope[stmt.grouping_binding.path_id] = None # Compile all the 'using' items for _alias, (value, using_card) in stmt.using.items(): # If the using bit is nullable, we need to compile it # as optional, or we'll get in trouble. # TODO: Can we do better here and not do this # in obvious cases like directly referencing an optional # property. if using_card.can_be_zero(): groupctx.force_optional = ctx.force_optional | {value.path_id} groupctx.path_scope[value.path_id] = None dispatch.visit(value, ctx=groupctx) groupctx.force_optional = ctx.force_optional # Compile all of the aggregate calls that we found. This lets us # compute things like sum and count without needing to materialize # the result. 
for group_use, skippable in stmt.group_aggregate_sets.items(): if not group_use: continue with groupctx.subrel() as hoistctx: hoistctx.skippable_sources |= skippable assert irutils.is_set_instance(group_use, irast.FunctionCall) relgen.process_set_as_agg_expr_inner( group_use, aspect=pgce.PathAspect.VALUE, wrapper=None, for_group_by=True, ctx=hoistctx, ) pathctx.get_path_value_output( rel=hoistctx.rel, path_id=group_use.path_id, env=ctx.env) pathctx.put_path_value_var( grouprel, group_use.path_id, hoistctx.rel ) packed = False # Materializing the actual grouping sets if None in stmt.group_aggregate_sets: packed = True # TODO: Be able to directly output the final serialized version # if it is consumed directly. with context.output_format(ctx, context.OutputFormat.NATIVE), ( groupctx.new()) as matctx: matctx.materializing |= {stmt} matctx.expr_exposed = True mat_qry = relgen.set_as_subquery( stmt.group_binding, as_value=True, ctx=matctx) mat_qry = relctx.set_to_array( path_id=stmt.group_binding.path_id, for_group_by=True, query=mat_qry, ctx=matctx) if not mat_qry.target_list[0].name: mat_qry.target_list[0].name = ctx.env.aliases.get('v') ref = pgast.ColumnRef( name=[mat_qry.target_list[0].name], is_packed_multi=True, ) pathctx.put_path_packed_output( mat_qry, stmt.group_binding.path_id, ref) pathctx.put_path_var( grouprel, stmt.group_binding.path_id, mat_qry, aspect=pgce.PathAspect.VALUE, flavor='packed', ) used_args = desugar_group.collect_grouping_atoms(stmt.by) if stmt.grouping_binding: _compile_grouping_binding(stmt, used_args=used_args, ctx=groupctx) # We want to make sure that every grouping key is associated # with exactly one output from the query. This means that # tuples must be packed up and keys must not have an extra # serialized output. # # We do this by manually packing up any TupleVarBases and # copying value aspects to serialized. 
# of the grouping keys get an extra serialized output from # grouprel, so we just copy all their value aspects to their # serialized aspects. using = {k: stmt.using[k] for k in used_args} for using_val, _ in using.values(): uvar = pathctx.get_path_var( grouprel, using_val.path_id, aspect=pgce.PathAspect.VALUE, env=ctx.env, ) if isinstance(uvar, pgast.TupleVarBase): uvar = output.output_as_value(uvar, env=ctx.env) pathctx.put_path_var( grouprel, using_val.path_id, uvar, aspect=pgce.PathAspect.VALUE, force=True, ) uout = pathctx.get_path_output( grouprel, using_val.path_id, aspect=pgce.PathAspect.VALUE, env=ctx.env, ) pathctx._put_path_output_var( grouprel, using_val.path_id, pgce.PathAspect.SERIALIZED, uout, ) grouprel.group_clause = [ compile_grouping_el(el, stmt, ctx=groupctx) for el in stmt.by ] group_rvar = relctx.rvar_for_rel(grouprel, ctx=ctx, lateral=True) if packed: relctx.include_rvar( query, group_rvar, path_id=stmt.group_binding.path_id, flavor='packed', update_mask=False, pull_namespace=False, aspects=(pgce.PathAspect.VALUE,), ctx=ctx) else: # Not include_rvar because we don't actually provide the path id! relctx.rel_join(query, group_rvar, ctx=ctx) # Set up the hoisted aggregates and bindings to be found # in the group subquery. 
for group_use in [ *stmt.group_aggregate_sets, *[x for x, _ in stmt.using.values()], stmt.grouping_binding, ]: if group_use: pathctx.put_path_rvar( query, group_use.path_id, group_rvar, aspect=pgce.PathAspect.VALUE, ) vol_ref = None def _get_volatility_ref() -> Optional[pgast.BaseExpr]: nonlocal vol_ref if vol_ref: return vol_ref name = ctx.env.aliases.get('key') grouprel.target_list.append( pgast.ResTarget( name=name, val=pgast.FuncCall(name=('row_number',), args=[], over=pgast.WindowDef()) ) ) vol_ref = pgast.ColumnRef(name=[group_rvar.alias.aliasname, name]) return vol_ref with ctx.new() as outctx: # Inherit the path_scope we made earlier (with the GROUP bindings # removed), so that we'll always look for those in the right place. outctx.path_scope = groupctx.path_scope outctx.volatility_ref += (lambda stmt, xctx: _get_volatility_ref(),) # Process materialized sets clauses.compile_materialized_exprs(query, stmt, ctx=outctx) clauses.compile_output(stmt.result, ctx=outctx) with ctx.new() as ictx: ictx.path_scope = groupctx.path_scope # FILTER and ORDER BY need to have the base result as a # volatility ref. clauses.setup_iterator_volatility(stmt.result, ctx=ictx) # The FILTER clause. if stmt.where is not None: query.where_clause = astutils.extend_binop( query.where_clause, clauses.compile_filter_clause( stmt.where, stmt.where_card, ctx=ictx)) # The ORDER BY clause if stmt.orderby is not None: with ictx.new() as octx: query.sort_clause = clauses.compile_orderby_clause( stmt.orderby, ctx=octx) return query def compile_group( stmt: irast.GroupStmt, *, ctx: context.CompilerContextLevel) -> pgast.BaseExpr: with ctx.substmt() as sctx: return _compile_group(stmt, ctx=sctx, parent_ctx=ctx) ================================================ FILE: edb/pgsql/compiler/output.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Compilation helpers for output formatting and serialization.""" from __future__ import annotations from typing import Optional, Sequence import itertools from edb.ir import ast as irast from edb.ir import typeutils as irtyputils from edb.schema import casts as s_casts from edb.schema import defines as s_defs from edb.schema import name as sn from edb.pgsql import ast as pgast from edb.pgsql import common from edb.pgsql import types as pgtypes from . import astutils from . import context _JSON_FORMATS = {context.OutputFormat.JSON, context.OutputFormat.JSON_ELEMENTS} def _get_json_func( name: str, *, output_format: Optional[context.OutputFormat] = None, env: context.Environment, ) -> tuple[str, ...]: if output_format is None: output_format = env.output_format if output_format in _JSON_FORMATS: prefix_suffix = 'json' else: prefix_suffix = 'jsonb' if name == 'to': return (f'{name}_{prefix_suffix}',) else: return (f'{prefix_suffix}_{name}',) def _build_json( name: str, args: Sequence[pgast.BaseExpr], *, null_safe: bool = False, ser_safe: bool = False, nullable: Optional[bool] = None, env: context.Environment, ) -> pgast.BaseExpr: # PostgreSQL has a limit on the maximum number of arguments # passed to a function call, so we must chop input into chunks # if the argument count is greater than the limit. 
if len(args) > s_defs.MAX_FUNC_ARG_COUNT: json_func = _get_json_func( name, output_format=context.OutputFormat.JSONB, env=env, ) chunk_iters = [iter(args)] * s_defs.MAX_FUNC_ARG_COUNT chunks = list(itertools.zip_longest(*chunk_iters, fillvalue=None)) if len(args) != len(chunks) * s_defs.MAX_FUNC_ARG_COUNT: chunks[-1] = tuple(filter(None, chunks[-1])) result: pgast.BaseExpr = pgast.FuncCall( name=json_func, args=list(chunks[0]), null_safe=null_safe, ser_safe=ser_safe, nullable=nullable, ) for chunk in chunks[1:]: fc = pgast.FuncCall( name=json_func, args=list(chunk), null_safe=null_safe, ser_safe=ser_safe, nullable=nullable, ) result = astutils.new_binop( lexpr=result, rexpr=fc, op='||', ) if env.output_format in _JSON_FORMATS: result = pgast.TypeCast( arg=result, type_name=pgast.TypeName( name=('json',) ) ) return result else: json_func = _get_json_func(name, env=env) return pgast.FuncCall( name=json_func, args=args, null_safe=null_safe, ser_safe=ser_safe, nullable=nullable, ) def coll_as_json_object( expr: pgast.BaseExpr, *, styperef: irast.TypeRef, env: context.Environment, ) -> pgast.BaseExpr: if irtyputils.is_tuple(styperef): return tuple_as_json_object(expr, styperef=styperef, env=env) elif irtyputils.is_array(styperef): return array_as_json_object(expr, styperef=styperef, env=env) else: raise RuntimeError(f'{styperef!r} is not a collection') def array_as_json_object( expr: pgast.BaseExpr, *, styperef: irast.TypeRef, env: context.Environment, ) -> pgast.BaseExpr: el_type = styperef.subtypes[0] is_tuple = irtyputils.is_tuple(el_type) # Tuples/ranges/scalars with custom casts need underlying casts to be done if ( is_tuple or irtyputils.is_range(el_type) or irtyputils.is_multirange(el_type) or el_type.real_base_type.needs_custom_json_cast ): coldeflist = [] out_alias = env.aliases.get('q') val: pgast.BaseExpr if is_tuple: json_args: list[pgast.BaseExpr] = [] is_named = any(st.element_name for st in el_type.subtypes) for i, st in enumerate(el_type.subtypes): if 
is_named: colname = st.element_name assert colname json_args.append(pgast.StringConstant(val=colname)) else: colname = str(i) val = pgast.ColumnRef(name=[colname]) val = serialize_expr_to_json( val, styperef=st, nested=True, env=env) json_args.append(val) if not irtyputils.is_persistent_tuple(el_type): # Column definition list is only allowed for functions # returning "record", i.e. an anonymous tuple, which # would not be the case for schema-persistent tuple types. coldeflist.append( pgast.ColumnDef( name=colname, typename=pgast.TypeName( name=pgtypes.pg_type_from_ir_typeref(st) ) ) ) json_func = 'build_object' if is_named else 'build_array' agg_arg = _build_json(json_func, json_args, env=env) needs_unnest = bool(el_type.subtypes) else: val = pgast.ColumnRef(name=[out_alias]) agg_arg = serialize_expr_to_json( val, styperef=el_type, nested=True, env=env) needs_unnest = True return pgast.SelectStmt( target_list=[ pgast.ResTarget( val=pgast.CoalesceExpr( args=[ pgast.FuncCall( name=_get_json_func('agg', env=env), args=[agg_arg], ), pgast.StringConstant(val='[]'), ] ), ser_safe=True, ) ], from_clause=[ pgast.RangeFunction( alias=pgast.Alias(aliasname=out_alias), is_rowsfrom=True, functions=[ pgast.FuncCall( name=('unnest',), args=[expr], coldeflist=coldeflist, ) ] ) ] if needs_unnest else [], ) elif irtyputils.is_array(el_type): # array<array<...>> implemented as array<tuple<array<...>>> # # If we serialize without any special handling, the tuple will be # included, with the key 'f1' # # eg. [[1, 2, 3], [4, 5]] -> [{'f1': [1,2,3]}, {'f1': [4,5]}] # # To prevent this, we need to explicitly serialize the inner arrays then # aggregate them. 
el_name = 'f1' coldeflist = [ pgast.ColumnDef( name=str(el_name), typename=pgast.TypeName( name=pgtypes.pg_type_from_ir_typeref(el_type), ), ) ] unwrapped_inner_array = pgast.RangeFunction( functions=[ pgast.FuncCall( name=('unnest',), args=[expr], coldeflist=coldeflist, ) ] ) serialized_inner_array = serialize_expr_to_json( pgast.ColumnRef(name=[str(el_name)]), styperef=el_type, nested=True, env=env, ) return pgast.SelectStmt( target_list=[ pgast.ResTarget( val=pgast.CoalesceExpr( args=[ pgast.FuncCall( name=_get_json_func('agg', env=env), args=[serialized_inner_array], ), pgast.StringConstant(val='[]'), ] ), ser_safe=True, ) ], from_clause=[unwrapped_inner_array] ) else: return pgast.FuncCall( name=_get_json_func('to', env=env), args=[expr], null_safe=True, ser_safe=True) def tuple_as_json_object( expr: pgast.BaseExpr, *, styperef: irast.TypeRef, env: context.Environment, ) -> pgast.BaseExpr: if any(st.element_name for st in styperef.subtypes): return named_tuple_as_json_object(expr, styperef=styperef, env=env) else: return unnamed_tuple_as_json_object(expr, styperef=styperef, env=env) def unnamed_tuple_as_json_object( expr: pgast.BaseExpr, *, styperef: irast.TypeRef, env: context.Environment, ) -> pgast.BaseExpr: vals: list[pgast.BaseExpr] = [] if irtyputils.is_persistent_tuple(styperef): for el_idx, el_type in enumerate(styperef.subtypes): val: pgast.BaseExpr = pgast.Indirection( arg=expr, indirection=[pgast.RecordIndirectionOp(name=str(el_idx))], ) val = serialize_expr_to_json( val, styperef=el_type, nested=True, env=env) vals.append(val) obj = _build_json( 'build_array', args=vals, null_safe=True, ser_safe=True, nullable=expr.nullable, env=env, ) else: coldeflist = [] for el_idx, el_type in enumerate(styperef.subtypes): coldeflist.append(pgast.ColumnDef( name=str(el_idx), typename=pgast.TypeName( name=pgtypes.pg_type_from_ir_typeref(el_type), ), )) val = pgast.ColumnRef(name=[str(el_idx)]) val = serialize_expr_to_json( val, styperef=el_type, nested=True, 
env=env) vals.append(val) obj = _build_json( 'build_array', args=vals, null_safe=True, ser_safe=True, nullable=expr.nullable, env=env, ) obj = pgast.SelectStmt( target_list=[ pgast.ResTarget( val=obj, ), ], from_clause=[ pgast.RangeFunction( functions=[ pgast.FuncCall( name=('unnest',), args=[ pgast.ArrayExpr( elements=[expr], ) ], coldeflist=coldeflist, ) ] ) ] if styperef.subtypes else [] ) if expr.nullable: obj = pgast.SelectStmt( target_list=[pgast.ResTarget(val=obj)], where_clause=pgast.NullTest(arg=expr, negated=True) ) return obj def named_tuple_as_json_object( expr: pgast.BaseExpr, *, styperef: irast.TypeRef, env: context.Environment, ) -> pgast.BaseExpr: keyvals: list[pgast.BaseExpr] = [] if irtyputils.is_persistent_tuple(styperef): for el_type in styperef.subtypes: assert el_type.element_name keyvals.append(pgast.StringConstant(val=el_type.element_name)) val: pgast.BaseExpr = pgast.Indirection( arg=expr, indirection=[ pgast.RecordIndirectionOp( name=el_type.element_name ) ] ) val = serialize_expr_to_json( val, styperef=el_type, nested=True, env=env) keyvals.append(val) obj = _build_json( 'build_object', args=keyvals, null_safe=True, ser_safe=True, nullable=expr.nullable, env=env, ) else: coldeflist = [] for el_type in styperef.subtypes: assert el_type.element_name keyvals.append(pgast.StringConstant(val=el_type.element_name)) coldeflist.append(pgast.ColumnDef( name=el_type.element_name, typename=pgast.TypeName( name=pgtypes.pg_type_from_ir_typeref(el_type), ), )) val = pgast.ColumnRef(name=[el_type.element_name]) val = serialize_expr_to_json( val, styperef=el_type, nested=True, env=env) keyvals.append(val) obj = _build_json( 'build_object', args=keyvals, null_safe=True, ser_safe=True, nullable=expr.nullable, env=env, ) obj = pgast.SelectStmt( target_list=[ pgast.ResTarget( val=obj, ), ], from_clause=[ pgast.RangeFunction( functions=[ pgast.FuncCall( name=('unnest',), args=[ pgast.ArrayExpr( elements=[expr], ) ], coldeflist=coldeflist, ) ] ) ] if 
styperef.subtypes else [] ) if expr.nullable: obj = pgast.SelectStmt( target_list=[pgast.ResTarget(val=obj)], where_clause=pgast.NullTest(arg=expr, negated=True) ) return obj def tuple_var_as_json_object( tvar: pgast.TupleVar, *, styperef: irast.TypeRef, env: context.Environment, ) -> pgast.BaseExpr: if not tvar.named: vals = [ serialize_expr(t.val, path_id=t.path_id, nested=True, env=env) for t in tvar.elements ] return _build_json( 'build_array', args=vals, null_safe=True, ser_safe=True, nullable=tvar.nullable, env=env, ) else: keyvals: list[pgast.BaseExpr] = [] for element in tvar.elements: rptr = element.path_id.rptr() assert rptr is not None name = rptr.shortname.name if rptr.source_ptr is not None: name = '@' + name keyvals.append(pgast.StringConstant(val=name)) val = serialize_expr( element.val, path_id=element.path_id, nested=True, env=env) keyvals.append(val) return _build_json( 'build_object', args=keyvals, null_safe=True, ser_safe=True, nullable=tvar.nullable, env=env, ) def in_serialization_ctx(ctx: context.CompilerContextLevel) -> bool: return ctx.expr_exposed is None or ctx.expr_exposed def serialize_custom_tuple( expr: pgast.BaseExpr, *, styperef: irast.TypeRef, env: context.Environment, ) -> pgast.BaseExpr: """Serialize a tuple that needs custom serialization for a component""" vals: list[pgast.BaseExpr] = [] obj: pgast.BaseExpr if irtyputils.is_persistent_tuple(styperef): for el_idx, el_type in enumerate(styperef.subtypes): val: pgast.BaseExpr = pgast.Indirection( arg=expr, indirection=[ pgast.RecordIndirectionOp(name=str(el_idx)), ], ) val = output_as_value( val, ser_typeref=el_type, env=env) vals.append(val) obj = _row(vals) else: coldeflist = [] for el_idx, el_type in enumerate(styperef.subtypes): coldeflist.append(pgast.ColumnDef( name=str(el_idx), typename=pgast.TypeName( name=pgtypes.pg_type_from_ir_typeref(el_type), ), )) val = pgast.ColumnRef(name=[str(el_idx)]) val = output_as_value( val, ser_typeref=el_type, env=env) vals.append(val) obj 
= _row(vals) obj = pgast.SelectStmt( target_list=[ pgast.ResTarget( val=obj, ), ], from_clause=[ pgast.RangeFunction( functions=[ pgast.FuncCall( name=('unnest',), args=[ pgast.ArrayExpr( elements=[expr], ) ], coldeflist=coldeflist, ) ] ) ] if styperef.subtypes else [] ) if expr.nullable: obj = pgast.SelectStmt( target_list=[pgast.ResTarget(val=obj)], where_clause=pgast.NullTest(arg=expr, negated=True) ) return obj def serialize_custom_array( expr: pgast.BaseExpr, *, styperef: irast.TypeRef, env: context.Environment, ) -> pgast.BaseExpr: """Serialize an array that needs custom serialization for a component""" el_type = styperef.subtypes[0] is_tuple = irtyputils.is_tuple(el_type) if is_tuple: coldeflist = [] out_alias = env.aliases.get('q') val: pgast.BaseExpr args: list[pgast.BaseExpr] = [] is_named = any(st.element_name for st in el_type.subtypes) for i, st in enumerate(el_type.subtypes): if is_named: colname = st.element_name assert colname args.append(pgast.StringConstant(val=colname)) else: colname = str(i) val = pgast.ColumnRef(name=[colname]) val = output_as_value(val, ser_typeref=st, env=env) args.append(val) if not irtyputils.is_persistent_tuple(el_type): # Column definition list is only allowed for functions # returning "record", i.e. an anonymous tuple, which # would not be the case for schema-persistent tuple types. 
coldeflist.append( pgast.ColumnDef( name=colname, typename=pgast.TypeName( name=pgtypes.pg_type_from_ir_typeref(st) ) ) ) agg_arg: pgast.BaseExpr = _row(args) return pgast.SelectStmt( target_list=[ pgast.ResTarget( val=pgast.CoalesceExpr( args=[ pgast.FuncCall( name=('array_agg',), args=[agg_arg], ), pgast.TypeCast( arg=pgast.ArrayExpr(elements=[]), type_name=pgast.TypeName(name=('record[]',)), ), ] ), ser_safe=True, ) ], from_clause=[ pgast.RangeFunction( alias=pgast.Alias(aliasname=out_alias), is_rowsfrom=True, functions=[ pgast.FuncCall( name=('unnest',), args=[expr], coldeflist=coldeflist, ) ] ) ] ) else: el_sql_type = el_type.real_base_type.custom_sql_serialization return pgast.TypeCast( arg=expr, type_name=pgast.TypeName(name=(f'{el_sql_type}[]',)), ) def _row( args: list[pgast.BaseExpr] ) -> pgast.ImplicitRowExpr | pgast.RowExpr: if len(args) > 1: return pgast.ImplicitRowExpr(args=args) else: return pgast.RowExpr(args=args) def output_as_value( expr: pgast.BaseExpr, *, ser_typeref: Optional[irast.TypeRef] = None, env: context.Environment) -> pgast.BaseExpr: """Format an expression as a proper value. Normally this just means packing TupleVars into real expressions, but if ser_typeref is provided, we also will do binary serialization. In particular, certain types actually need to be serialized as text or some other format, and we handle that here. """ needs_custom_serialization = ser_typeref and ( irtyputils.needs_custom_serialization(ser_typeref)) val = expr if isinstance(expr, pgast.TupleVar): if ( env.output_format is context.OutputFormat.NATIVE_INTERNAL and len(expr.elements) == 1 and (path_id := (el0 := expr.elements[0]).path_id) is not None and (rptr_name := path_id.rptr_name()) is not None and (rptr_name.name == 'id') ): # This is a special mode whereby bare refs to objects # are serialized to UUID values.
return output_as_value(el0.val, env=env) ser_typerefs = [ ser_typeref.subtypes[i] if ser_typeref and ser_typeref.subtypes else None for i in range(len(expr.elements)) ] val = _row([ output_as_value(e.val, ser_typeref=ser_typerefs[i], env=env) for i, e in enumerate(expr.elements) ]) if (expr.typeref is not None and not needs_custom_serialization and not env.singleton_mode and irtyputils.is_persistent_tuple(expr.typeref)): pg_type = pgtypes.pg_type_from_ir_typeref(expr.typeref) val = pgast.TypeCast( arg=val, type_name=pgast.TypeName( name=pg_type, ), ) elif (needs_custom_serialization and not expr.ser_safe): assert ser_typeref is not None if irtyputils.is_array(ser_typeref): return serialize_custom_array(expr, styperef=ser_typeref, env=env) elif irtyputils.is_tuple(ser_typeref): return serialize_custom_tuple(expr, styperef=ser_typeref, env=env) else: el_sql_type = ser_typeref.real_base_type.custom_sql_serialization assert el_sql_type is not None val = pgast.TypeCast( arg=val, type_name=pgast.TypeName(name=(el_sql_type,)), ) return val def add_null_test(expr: pgast.BaseExpr, query: pgast.SelectStmt) -> None: if not expr.nullable: return while isinstance(expr, pgast.TupleVar) and expr.elements: expr = expr.elements[0].val query.where_clause = astutils.extend_binop( query.where_clause, pgast.NullTest(arg=expr, negated=True) ) def serialize_expr_if_needed( expr: pgast.BaseExpr, *, path_id: irast.PathId, ctx: context.CompilerContextLevel) -> pgast.BaseExpr: if in_serialization_ctx(ctx): val = serialize_expr(expr, path_id=path_id, env=ctx.env) else: val = expr return val def serialize_expr_to_json( expr: pgast.BaseExpr, *, styperef: irast.TypeRef, nested: bool=False, env: context.Environment) -> pgast.BaseExpr: val: pgast.BaseExpr if isinstance(expr, pgast.TupleVar): val = tuple_var_as_json_object(expr, styperef=styperef, env=env) elif isinstance(expr, (pgast.RowExpr, pgast.ImplicitRowExpr)): val = _build_json( 'build_array', args=expr.args, null_safe=True, ser_safe=True, 
env=env, ) elif irtyputils.is_range(styperef) and not expr.ser_safe: val = pgast.FuncCall( # Use the actual generic helper for converting anyrange to jsonb name=common.maybe_versioned_name( ('edgedb', 'range_to_jsonb'), versioned=env.versioned_stdlib, ), args=[expr], null_safe=True, ser_safe=True) if env.output_format in _JSON_FORMATS: val = pgast.TypeCast( arg=val, type_name=pgast.TypeName(name=('json',)) ) elif irtyputils.is_multirange(styperef) and not expr.ser_safe: val = pgast.FuncCall( # Use the actual generic helper for converting anymultirange to # jsonb name=common.maybe_versioned_name( ('edgedb', 'multirange_to_jsonb'), versioned=env.versioned_stdlib, ), args=[expr], null_safe=True, ser_safe=True) if env.output_format in _JSON_FORMATS: val = pgast.TypeCast( arg=val, type_name=pgast.TypeName(name=('json',)) ) elif irtyputils.is_collection(styperef) and not expr.ser_safe: val = coll_as_json_object(expr, styperef=styperef, env=env) elif ( styperef.real_base_type.needs_custom_json_cast and not expr.ser_safe ): base = styperef.real_base_type cast_name = s_casts.get_cast_fullname_from_names( base.orig_name_hint or base.name_hint, sn.QualName('std', 'json'), ) val = pgast.FuncCall( name=common.get_cast_backend_name( cast_name, aspect='function', versioned=env.versioned_stdlib ), args=[expr], null_safe=True, ser_safe=True) if env.output_format in _JSON_FORMATS: val = pgast.TypeCast( arg=val, type_name=pgast.TypeName(name=('json',)) ) elif not nested: val = pgast.FuncCall( name=_get_json_func('to', env=env), args=[expr], null_safe=True, ser_safe=True) else: val = expr return val def serialize_expr( expr: pgast.BaseExpr, *, path_id: irast.PathId, nested: bool=False, env: context.Environment) -> pgast.BaseExpr: if env.output_format in (context.OutputFormat.JSON, context.OutputFormat.JSON_ELEMENTS, context.OutputFormat.JSONB): val = serialize_expr_to_json( expr, styperef=path_id.target, nested=nested, env=env) elif env.output_format in (context.OutputFormat.NATIVE, 
context.OutputFormat.NATIVE_INTERNAL, context.OutputFormat.NONE): val = output_as_value(expr, ser_typeref=path_id.target, env=env) else: raise RuntimeError(f'unexpected output format: {env.output_format!r}') return val def get_pg_type( typeref: irast.TypeRef, *, ctx: context.CompilerContextLevel) -> tuple[str, ...]: if in_serialization_ctx(ctx): if ctx.env.output_format is context.OutputFormat.JSONB: return ('jsonb',) elif ctx.env.output_format in _JSON_FORMATS: return ('json',) elif irtyputils.is_object(typeref): return ('record',) else: return pgtypes.pg_type_from_ir_typeref(typeref) else: return pgtypes.pg_type_from_ir_typeref(typeref) def aggregate_json_output( stmt: pgast.SelectStmt, ir_set: irast.Set, *, env: context.Environment) -> pgast.SelectStmt: subrvar = pgast.RangeSubselect( subquery=stmt, alias=pgast.Alias( aliasname=env.aliases.get('aggw') ) ) stmt_res = stmt.target_list[0] if stmt_res.name is None: stmt_res = stmt.target_list[0] = pgast.ResTarget( name=env.aliases.get('v'), val=stmt_res.val, ) assert stmt_res.name is not None new_val = pgast.CoalesceExpr( args=[ pgast.FuncCall( name=_get_json_func('agg', env=env), args=[pgast.ColumnRef(name=[stmt_res.name])] ), pgast.StringConstant(val='[]') ] ) result = pgast.SelectStmt( target_list=[ pgast.ResTarget( val=new_val ) ], from_clause=[ subrvar ] ) result.ctes = stmt.ctes stmt.ctes = [] return result def wrap_script_stmt( stmt: pgast.SelectStmt, *, suppress_all_output: bool = False, env: context.Environment, ) -> pgast.SelectStmt: subrvar = pgast.RangeSubselect( subquery=stmt, alias=pgast.Alias( aliasname=env.aliases.get('aggw') ) ) stmt_res = stmt.target_list[0] if stmt_res.name is None: stmt_res = stmt.target_list[0] = pgast.ResTarget( name=env.aliases.get('v'), val=stmt_res.val, ) assert stmt_res.name is not None count_val = pgast.FuncCall( name=('count',), args=[pgast.ColumnRef(name=[stmt_res.name])] ) result = pgast.SelectStmt( target_list=[ pgast.ResTarget( val=count_val, name=stmt_res.name, ), ], 
from_clause=[ subrvar, ] ) if suppress_all_output: subrvar = pgast.RangeSubselect( subquery=result, alias=pgast.Alias( aliasname=env.aliases.get('q') ) ) result = pgast.SelectStmt( target_list=[], from_clause=[ subrvar, ], where_clause=pgast.NullTest( arg=pgast.ColumnRef( name=[subrvar.alias.aliasname, stmt_res.name], ), ), ) result.ctes = stmt.ctes stmt.ctes = [] return result def top_output_as_value( stmt: pgast.SelectStmt, ir_set: irast.Set, *, env: context.Environment) -> pgast.SelectStmt: """Finalize output serialization on the top level.""" if (env.output_format is context.OutputFormat.JSON and not env.expected_cardinality_one): # For JSON we just want to aggregate the whole thing # into a JSON array. return aggregate_json_output(stmt, ir_set, env=env) elif ( env.explicit_top_cast is not None and ( env.output_format is context.OutputFormat.NATIVE or env.output_format is context.OutputFormat.NATIVE_INTERNAL ) ): typecast = pgast.TypeCast( arg=stmt.target_list[0].val, type_name=pgast.TypeName( name=pgtypes.pg_type_from_ir_typeref( env.explicit_top_cast, persistent_tuples=True, ), ), ) stmt.target_list[0] = pgast.ResTarget( name=env.aliases.get('v'), val=typecast, ) return stmt elif env.output_format is context.OutputFormat.NONE: return wrap_script_stmt(stmt, env=env) else: # JSON_ELEMENTS and BINARY don't require any wrapping return stmt ================================================ FILE: edb/pgsql/compiler/pathctx.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Helpers to manage statement path contexts.""" from __future__ import annotations from typing import Optional, Sequence, TypeGuard from edb.ir import ast as irast from edb.ir import typeutils as irtyputils from edb.schema import pointers as s_pointers from edb.pgsql import ast as pgast from edb.pgsql import common from edb.pgsql import types as pg_types from . import astutils from . import context from . import enums as pgce from . import output # A mapping of more specific aspect -> less specific aspect for objects OBJECT_ASPECT_SPECIFICITY_MAP = { pgce.PathAspect.IDENTITY: pgce.PathAspect.VALUE, pgce.PathAspect.VALUE: pgce.PathAspect.SOURCE, pgce.PathAspect.SERIALIZED: pgce.PathAspect.SOURCE, } # A mapping of more specific aspect -> less specific aspect for primitives PRIMITIVE_ASPECT_SPECIFICITY_MAP = { pgce.PathAspect.SERIALIZED: pgce.PathAspect.VALUE, } def get_less_specific_aspect( path_id: irast.PathId, aspect: pgce.PathAspect, ) -> Optional[pgce.PathAspect]: if path_id.is_objtype_path(): mapping = OBJECT_ASPECT_SPECIFICITY_MAP else: mapping = PRIMITIVE_ASPECT_SPECIFICITY_MAP less_specific_aspect = mapping.get(pgce.PathAspect(aspect)) if less_specific_aspect is not None: return less_specific_aspect else: return None def map_path_id( path_id: irast.PathId, path_id_map: dict[irast.PathId, irast.PathId]) -> irast.PathId: sorted_map = sorted( path_id_map.items(), key=lambda kv: len(kv[0]), reverse=True) for outer_id, inner_id in sorted_map: new_path_id = irtyputils.replace_pathid_prefix( path_id, outer_id, inner_id, permissive_ptr_path=True) if new_path_id != 
path_id: path_id = new_path_id break return path_id def reverse_map_path_id( path_id: irast.PathId, path_id_map: dict[irast.PathId, irast.PathId]) -> irast.PathId: for outer_id, inner_id in path_id_map.items(): new_path_id = irtyputils.replace_pathid_prefix( path_id, inner_id, outer_id) if new_path_id != path_id: path_id = new_path_id break return path_id def put_path_id_mask( stmt: pgast.EdgeQLPathInfo, path_id: irast.PathId ) -> None: stmt.path_id_mask.add(path_id) def put_path_id_map( rel: pgast.Query, outer_path_id: irast.PathId, inner_path_id: irast.PathId, ) -> None: inner_path_id = map_path_id(inner_path_id, rel.view_path_id_map) rel.view_path_id_map[outer_path_id] = inner_path_id def get_path_var( rel: pgast.Query, path_id: irast.PathId, *, flavor: str='normal', aspect: pgce.PathAspect, env: context.Environment, ) -> pgast.BaseExpr: """ Return a value expression for a given *path_id* in a given *rel*. This function is part of the "recursive column injection" algorithm, described in [./ARCHITECTURE.md].
""" if isinstance(rel, pgast.CommonTableExpr): rel = rel.query if flavor == 'normal': if rel.view_path_id_map: path_id = map_path_id(path_id, rel.view_path_id_map) if (path_id, aspect) in rel.path_namespace: return rel.path_namespace[path_id, aspect] elif flavor == 'packed': if ( rel.packed_path_namespace and (path_id, aspect) in rel.packed_path_namespace ): return rel.packed_path_namespace[path_id, aspect] if astutils.is_set_op_query(rel): return _get_path_var_in_setop( rel, path_id, aspect=aspect, flavor=flavor, env=env) ptrref = path_id.rptr() ptrref_dir = path_id.rptr_dir() is_type_intersection = path_id.is_type_intersection_path() src_path_id: Optional[irast.PathId] = None if ptrref is not None and not is_type_intersection: ptr_info = pg_types.get_ptrref_storage_info( ptrref, resolve_type=False, link_bias=False, allow_missing=True) ptr_dir = path_id.rptr_dir() is_inbound = ptr_dir == s_pointers.PointerDirection.Inbound if is_inbound: src_path_id = path_id else: src_path_id = path_id.src_path() assert src_path_id is not None src_rptr = src_path_id.rptr() if ( irtyputils.is_id_ptrref(ptrref) and ( src_rptr is None or ptrref_dir is not s_pointers.PointerDirection.Inbound ) ): # When there is a reference to the id property of # an object which is linked to by a link stored # inline, we want to route the reference to the # inline attribute. For example, # Foo.__type__.id gets resolved to the Foo.__type__ # column. This can only be done if Foo is visible # in scope, and Foo.__type__ is not a computable. pid = src_path_id while pid.is_type_intersection_path(): # Skip type intersection step(s). 
src_pid = pid.src_path() if src_pid is not None: src_rptr = src_pid.rptr() pid = src_pid else: break if (src_rptr is not None and not irtyputils.is_computable_ptrref(src_rptr) and env.ptrref_source_visibility.get(src_rptr)): src_ptr_info = pg_types.get_ptrref_storage_info( src_rptr, resolve_type=False, link_bias=False, allow_missing=True) if (src_ptr_info and src_ptr_info.table_type == 'ObjectType'): src_path_id = src_path_id.src_path() ptr_info = src_ptr_info else: ptr_info = None ptr_dir = None var: Optional[pgast.BaseExpr] if ptrref is None: if len(path_id) == 1: # This is a scalar set derived from an expression. src_path_id = path_id elif ptrref.source_ptr is not None: if ptr_info and ptr_info.table_type != 'link' and not is_inbound: # This is a link prop that is stored in source rel, # step back to link source rvar. _prefix_pid = path_id.src_path() assert _prefix_pid is not None src_path_id = _prefix_pid.src_path() elif is_type_intersection: src_path_id = path_id assert src_path_id is not None # Find which rvar will have path_id as an output src_aspect, rel_rvar, found_path_var = _find_rel_rvar( rel, path_id, src_path_id, aspect=aspect, flavor=flavor ) if found_path_var: return found_path_var # Slight hack: Inject the __type__ field of a FreeObject when necessary if ( rel_rvar is None and ptrref and ptrref.shortname.name == '__type__' and irtyputils.is_free_object(src_path_id.target) ): return astutils.compile_typeref(src_path_id.target.real_material_type) if isinstance(rel_rvar, pgast.DynamicRangeVar): var = rel_rvar.dynamic_get_path( rel, path_id, flavor=flavor, aspect=aspect, env=env) if isinstance(var, pgast.PathRangeVar): rel_rvar = var elif var: put_path_var(rel, path_id, var, aspect=aspect, flavor=flavor) return var else: rel_rvar = None if rel_rvar is None: raise LookupError( f'there is no range var for ' f'{src_path_id} {src_aspect} in {rel}') if isinstance(rel_rvar, pgast.IntersectionRangeVar): if ( (path_id.is_objtype_path() and src_path_id ==
path_id) or (ptrref is not None and irtyputils.is_id_ptrref(ptrref)) ): rel_rvar = rel_rvar.component_rvars[-1] else: # Intersection rvars are basically JOINs of the relevant # parts of the type intersection, and so we need to make # sure we pick the correct component relation of that JOIN. rel_rvar = _find_rvar_in_intersection_by_typeref( path_id, rel_rvar.component_rvars, ) source_rel = rel_rvar.query outvar = get_path_output( source_rel, path_id, aspect=aspect, flavor=flavor, env=env) var = astutils.get_rvar_var(rel_rvar, outvar) put_path_var(rel, path_id, var, aspect=aspect, flavor=flavor) if isinstance(var, pgast.TupleVar): for element in var.elements: put_path_var_if_not_exists( rel, element.path_id, element.val, flavor=flavor, aspect=aspect ) return var def _find_rel_rvar( rel: pgast.Query, path_id: irast.PathId, src_path_id: irast.PathId, *, aspect: pgce.PathAspect, flavor: str, ) -> tuple[str, Optional[pgast.PathRangeVar], Optional[pgast.BaseExpr]]: """Rummage around rel looking for an appropriate rvar for path_id. Somewhat unfortunately, some checks to find the actual path var (in a particular tuple case) need to occur in the middle of the rvar rel search, so we can also find the actual path var in passing. 
""" src_aspect = aspect rel_rvar = maybe_get_path_rvar(rel, path_id, aspect=aspect, flavor=flavor) if rel_rvar is None: alt_aspect = get_less_specific_aspect(path_id, aspect) if alt_aspect is not None: rel_rvar = maybe_get_path_rvar(rel, path_id, aspect=alt_aspect) else: alt_aspect = None if rel_rvar is None: if flavor == 'packed': src_aspect = aspect elif src_path_id.is_objtype_path(): src_aspect = pgce.PathAspect.SOURCE else: src_aspect = aspect if src_path_id.is_tuple_path(): if src_aspect == pgce.PathAspect.IDENTITY: src_aspect = pgce.PathAspect.VALUE if var := _find_in_output_tuple(rel, path_id, src_aspect): return src_aspect, None, var rel_rvar = maybe_get_path_rvar(rel, src_path_id, aspect=src_aspect) if rel_rvar is None: _src_path_id_prefix = src_path_id.src_path() if _src_path_id_prefix is not None: rel_rvar = maybe_get_path_rvar( rel, _src_path_id_prefix, aspect=src_aspect ) else: rel_rvar = maybe_get_path_rvar(rel, src_path_id, aspect=src_aspect) if ( rel_rvar is None and src_aspect != pgce.PathAspect.SOURCE and path_id != src_path_id ): rel_rvar = maybe_get_path_rvar( rel, src_path_id, aspect=pgce.PathAspect.SOURCE ) if rel_rvar is None and alt_aspect is not None and flavor == 'normal': # There is no source range var for the requested aspect, # check if there is a cached var with less specificity. var = rel.path_namespace.get((path_id, alt_aspect)) if var is not None: put_path_var(rel, path_id, var, aspect=aspect, flavor=flavor) return src_aspect, None, var return src_aspect, rel_rvar, None def _get_path_var_in_setop( rel: pgast.Query, path_id: irast.PathId, *, aspect: pgce.PathAspect, flavor: str, env: context.Environment, ) -> pgast.BaseExpr: test_vals = [] if aspect in (pgce.PathAspect.VALUE, pgce.PathAspect.SERIALIZED): test_vals = [ maybe_get_path_var(q, env=env, path_id=path_id, aspect=aspect) for q in astutils.each_query_in_set(rel) ] # In order to ensure output balance, we only want to output # a TupleVar if *every* subquery outputs a TupleVar. 
# If some but not all output TupleVars, we need to fix up # the output TupleVars by outputting them as a real tuple. # This is needed for cases like `(Foo.bar UNION (1,2))`. if ( any(isinstance(x, pgast.TupleVarBase) for x in test_vals) and not all(isinstance(x, pgast.TupleVarBase) for x in test_vals) ): for subrel in astutils.each_query_in_set(rel): cur = get_path_var( subrel, env=env, path_id=path_id, aspect=aspect) assert flavor == 'normal' if isinstance(cur, pgast.TupleVarBase): new = output.output_as_value(cur, env=env) new_path_id = map_path_id(path_id, subrel.view_path_id_map) put_path_var( subrel, new_path_id, new, force=True, aspect=aspect ) # We disable the find_path_output optimization when doing # UNIONs to avoid situations where they have different numbers # of columns. outputs = [ get_path_output_or_null( q, env=env, disable_output_fusion=True, path_id=path_id, aspect=aspect, flavor=flavor ) for q in astutils.each_query_in_set(rel) ] counts = [len(x.target_list) for x in astutils.each_query_in_set(rel)] assert counts == [counts[0]] * len(counts) first: Optional[pgast.OutputVar] = None optional = False all_null = True nullable = False for colref, is_null in outputs: if colref.nullable: nullable = True if first is None: first = colref if is_null: optional = True else: all_null = False # Fail if no subquery had the path or, for scalar identity paths, # if any did not have it. # # We need to do this for scalar identity because scalar identity # is only used for volatility refs, and it is OK if looking it up # fails, because we create a backup volatility ref---but it is # *not* OK for it to succeed and produce NULL in some cases. if all_null or ( aspect == pgce.PathAspect.IDENTITY and optional and not path_id.is_objtype_path() ): # If *none* of the subqueries had it, we have to remove them all # before erroring, lest a future call see them and decide # they really exist. 
for subrel in astutils.each_query_in_set(rel): assert flavor == 'normal' new_path_id = map_path_id(path_id, subrel.view_path_id_map) del subrel.path_outputs[new_path_id, aspect] subrel.target_list.pop() raise LookupError( f'cannot find refs for ' f'path {path_id} {aspect} in {rel}') if first is None: raise AssertionError( 'union did not produce any outputs') # Path vars produced by UNION expressions can be "optional", # i.e. the record is accepted as-is when such var is NULL. # This is necessary to correctly join heterogeneous UNIONs. var = astutils.strip_output_var( first, optional=optional, nullable=optional or nullable) put_path_var(rel, path_id, var, aspect=aspect, flavor=flavor) return var def _find_rvar_in_intersection_by_typeref( path_id: irast.PathId, component_rvars: Sequence[pgast.PathRangeVar], ) -> pgast.PathRangeVar: assert component_rvars if src_path := path_id.src_path(): tref = src_path.target else: tref = path_id.target for component_rvar in component_rvars: if ( component_rvar.typeref is not None and irtyputils.type_contains(tref, component_rvar.typeref) ): rel_rvar = component_rvar break else: raise AssertionError( f'no rvar in intersection matches path id {path_id}' ) return rel_rvar def _find_in_output_tuple( rel: pgast.Query, path_id: irast.PathId, aspect: pgce.PathAspect ) -> Optional[pgast.BaseExpr]: """Try indirecting a source tuple already present as an output. Normally tuple indirections are handled by process_set_as_tuple_indirection, but UNIONing an explicit tuple with a tuple coming from a base relation (like `(Foo.bar UNION (1,2)).0`) can lead to us looking for a tuple path in relations that only have the actual full tuple. (See test_edgeql_coalesce_tuple_{08,09}). We handle this by checking whether some prefix of the tuple path is present in the path_outputs. This is sufficient because the relevant cases are all caused by set ops, and the "fixup" done in set op cases ensures that the tuple will be already present.
""" steps = [] src_path_id = path_id.src_path() ptrref = path_id.rptr() while ( src_path_id and src_path_id.is_tuple_path() and isinstance(ptrref, irast.TupleIndirectionPointerRef) ): steps.append((ptrref.shortname.name, src_path_id)) if ( (var := rel.path_namespace.get((src_path_id, aspect))) and not isinstance(var, pgast.TupleVarBase) ): for name, src in reversed(steps): var = astutils.tuple_getattr(var, src.target, name) put_path_var(rel, path_id, var, aspect=aspect) return var ptrref = src_path_id.rptr() src_path_id = src_path_id.src_path() return None def get_path_identity_var( rel: pgast.Query, path_id: irast.PathId, *, env: context.Environment, ) -> pgast.BaseExpr: return get_path_var(rel, path_id, aspect=pgce.PathAspect.IDENTITY, env=env) def get_path_value_var( rel: pgast.Query, path_id: irast.PathId, *, env: context.Environment, ) -> pgast.BaseExpr: return get_path_var(rel, path_id, aspect=pgce.PathAspect.VALUE, env=env) def is_relation_rvar( rvar: pgast.BaseRangeVar, ) -> bool: return ( isinstance(rvar, pgast.RelRangeVar) and is_terminal_relation(rvar.query) ) def is_terminal_relation( rel: pgast.BaseRelation ) -> TypeGuard[pgast.Relation | pgast.NullRelation]: return isinstance(rel, (pgast.Relation, pgast.NullRelation)) def is_values_relation( rel: pgast.BaseRelation, ) -> bool: return bool(getattr(rel, 'values', None)) def maybe_get_path_var( rel: pgast.Query, path_id: irast.PathId, *, aspect: pgce.PathAspect, flavor: str='normal', env: context.Environment, ) -> Optional[pgast.BaseExpr]: try: return get_path_var( rel, path_id, aspect=aspect, flavor=flavor, env=env) except LookupError: return None def maybe_get_path_identity_var( rel: pgast.Query, path_id: irast.PathId, *, env: context.Environment, ) -> Optional[pgast.BaseExpr]: try: return get_path_var( rel, path_id, aspect=pgce.PathAspect.IDENTITY, env=env ) except LookupError: return None def maybe_get_path_value_var( rel: pgast.Query, path_id: irast.PathId, *, env: context.Environment, ) -> 
Optional[pgast.BaseExpr]: try: return get_path_var( rel, path_id, aspect=pgce.PathAspect.VALUE, env=env ) except LookupError: return None def maybe_get_path_serialized_var( rel: pgast.Query, path_id: irast.PathId, *, env: context.Environment, ) -> Optional[pgast.BaseExpr]: try: return get_path_var( rel, path_id, aspect=pgce.PathAspect.SERIALIZED, env=env ) except LookupError: return None def put_path_var( rel: pgast.BaseRelation, path_id: irast.PathId, var: pgast.BaseExpr, *, aspect: pgce.PathAspect, flavor: str = 'normal', force: bool = False, ) -> None: if flavor == 'packed': if rel.packed_path_namespace is None: rel.packed_path_namespace = {} path_namespace = rel.packed_path_namespace else: path_namespace = rel.path_namespace if (path_id, aspect) in path_namespace and not force: raise KeyError( f'{aspect} of {path_id} is already present in {rel}') path_namespace[path_id, aspect] = var def put_path_var_if_not_exists( rel: pgast.Query, path_id: irast.PathId, var: pgast.BaseExpr, *, flavor: str = 'normal', aspect: pgce.PathAspect, ) -> None: try: put_path_var(rel, path_id, var, aspect=aspect, flavor=flavor) except KeyError: pass def put_path_identity_var( rel: pgast.BaseRelation, path_id: irast.PathId, var: pgast.BaseExpr, *, force: bool = False, ) -> None: put_path_var( rel, path_id, var, aspect=pgce.PathAspect.IDENTITY, force=force ) def put_path_value_var( rel: pgast.BaseRelation, path_id: irast.PathId, var: pgast.BaseExpr, *, force: bool = False, ) -> None: put_path_var( rel, path_id, var, aspect=pgce.PathAspect.VALUE, force=force ) def put_path_serialized_var( rel: pgast.BaseRelation, path_id: irast.PathId, var: pgast.BaseExpr, *, force: bool = False, ) -> None: put_path_var( rel, path_id, var, aspect=pgce.PathAspect.SERIALIZED, force=force ) def put_path_value_var_if_not_exists( rel: pgast.BaseRelation, path_id: irast.PathId, var: pgast.BaseExpr, *, force: bool = False, ) -> None: try: put_path_var( rel, path_id, var, aspect=pgce.PathAspect.VALUE, force=force 
) except KeyError: pass def put_path_serialized_var_if_not_exists( rel: pgast.BaseRelation, path_id: irast.PathId, var: pgast.BaseExpr, *, force: bool = False, ) -> None: try: put_path_var( rel, path_id, var, aspect=pgce.PathAspect.SERIALIZED, force=force, ) except KeyError: pass def put_path_bond( stmt: pgast.BaseRelation, path_id: irast.PathId, iterator: bool=False ) -> None: '''Register a path id that should be joined on when joining stmt. *iterator* indicates whether the identity or iterator aspect should be used. ''' stmt.path_bonds.add((path_id, iterator)) def put_rvar_path_bond( rvar: pgast.PathRangeVar, path_id: irast.PathId) -> None: put_path_bond(rvar.query, path_id) def get_path_output_alias( path_id: irast.PathId, aspect: pgce.PathAspect, *, env: context.Environment, ) -> str: rptr = path_id.rptr() if rptr is not None: alias_base = rptr.shortname.name elif path_id.is_collection_path(): assert path_id.target.collection is not None alias_base = path_id.target.collection else: alias_base = path_id.target_name_hint.name return env.aliases.get(f'{alias_base}_{aspect}') def get_rvar_path_var( rvar: pgast.PathRangeVar, path_id: irast.PathId, aspect: pgce.PathAspect, *, flavor: str='normal', env: context.Environment, ) -> pgast.OutputVar: """Return ColumnRef for a given *path_id* in a given *range var*.""" outvar = get_path_output( rvar.query, path_id, aspect=aspect, flavor=flavor, env=env) return astutils.get_rvar_var(rvar, outvar) def put_rvar_path_output( rvar: pgast.PathRangeVar, path_id: irast.PathId, aspect: pgce.PathAspect, var: pgast.OutputVar, ) -> None: _put_path_output_var(rvar.query, path_id, aspect, var) def maybe_get_rvar_path_var( rvar: pgast.PathRangeVar, path_id: irast.PathId, *, aspect: pgce.PathAspect, flavor: str='normal', env: context.Environment, ) -> Optional[pgast.OutputVar]: try: return get_rvar_path_var( rvar, path_id, aspect=aspect, flavor=flavor, env=env) except LookupError: return None def get_rvar_path_identity_var( rvar:
pgast.PathRangeVar, path_id: irast.PathId, *, env: context.Environment, ) -> pgast.OutputVar: return get_rvar_path_var( rvar, path_id, aspect=pgce.PathAspect.IDENTITY, env=env ) def get_rvar_path_value_var( rvar: pgast.PathRangeVar, path_id: irast.PathId, *, env: context.Environment, ) -> pgast.OutputVar: return get_rvar_path_var( rvar, path_id, aspect=pgce.PathAspect.VALUE, env=env ) def get_rvar_output_var_as_col_list( rvar: pgast.PathRangeVar, outvar: pgast.OutputVar, aspect: pgce.PathAspect, *, env: context.Environment, ) -> list[pgast.OutputVar]: cols: list[pgast.OutputVar] if isinstance(outvar, pgast.TupleVarBase): cols = [] for el in outvar.elements: col = get_rvar_path_var(rvar, el.path_id, aspect=aspect, env=env) cols.extend(get_rvar_output_var_as_col_list( rvar, col, aspect=aspect, env=env)) else: cols = [outvar] return cols def put_path_rvar( stmt: pgast.Query, path_id: irast.PathId, rvar: pgast.PathRangeVar, *, flavor: str = 'normal', aspect: pgce.PathAspect, ) -> None: assert isinstance(path_id, irast.PathId) stmt.get_rvar_map(flavor)[path_id, aspect] = rvar def put_path_value_rvar( stmt: pgast.Query, path_id: irast.PathId, rvar: pgast.PathRangeVar, *, flavor: str = 'normal', ) -> None: put_path_rvar( stmt, path_id, rvar, aspect=pgce.PathAspect.VALUE, flavor=flavor ) def put_path_source_rvar( stmt: pgast.Query, path_id: irast.PathId, rvar: pgast.PathRangeVar, *, flavor: str = 'normal', ) -> None: put_path_rvar( stmt, path_id, rvar, aspect=pgce.PathAspect.SOURCE, flavor=flavor ) def has_rvar(stmt: pgast.Query, rvar: pgast.PathRangeVar) -> bool: return any( rvar in set(stmt.get_rvar_map(flavor).values()) for flavor in ('normal', 'packed') ) def put_path_rvar_if_not_exists( stmt: pgast.Query, path_id: irast.PathId, rvar: pgast.PathRangeVar, *, flavor: str = 'normal', aspect: pgce.PathAspect, ) -> None: if (path_id, aspect) not in stmt.get_rvar_map(flavor): put_path_rvar(stmt, path_id, rvar, aspect=aspect, flavor=flavor) def get_path_rvar( stmt: 
pgast.Query, path_id: irast.PathId, *, flavor: str = 'normal', aspect: pgce.PathAspect, ) -> pgast.PathRangeVar: rvar = maybe_get_path_rvar(stmt, path_id, aspect=aspect, flavor=flavor) if rvar is None: raise LookupError( f'there is no range var for {path_id} {aspect} in {stmt}') return rvar def maybe_get_path_rvar( stmt: pgast.Query, path_id: irast.PathId, *, aspect: pgce.PathAspect, flavor: str = 'normal', ) -> Optional[pgast.PathRangeVar]: rvar = None path_rvar_map = stmt.maybe_get_rvar_map(flavor) if path_rvar_map is not None: if path_rvar_map: rvar = path_rvar_map.get((path_id, aspect)) if rvar is None and aspect == pgce.PathAspect.IDENTITY: rvar = path_rvar_map.get((path_id, pgce.PathAspect.VALUE)) return rvar def _has_path_aspect( stmt: pgast.Query, path_id: irast.PathId, *, aspect: pgce.PathAspect, ) -> bool: key = path_id, aspect return ( key in stmt.path_rvar_map or key in stmt.path_namespace or key in stmt.path_outputs ) def has_path_aspect( stmt: pgast.Query, path_id: irast.PathId, *, aspect: pgce.PathAspect ) -> bool: path_id = map_path_id(path_id, stmt.view_path_id_map) return _has_path_aspect(stmt, path_id, aspect=aspect) def list_path_aspects( stmt: pgast.Query, path_id: irast.PathId ) -> set[pgce.PathAspect]: path_aspects = ( pgce.PathAspect.VALUE, pgce.PathAspect.IDENTITY, pgce.PathAspect.SOURCE, pgce.PathAspect.SERIALIZED, ) path_id = map_path_id(path_id, stmt.view_path_id_map) return { aspect for aspect in path_aspects if _has_path_aspect(stmt, path_id, aspect=aspect) } def maybe_get_path_value_rvar( stmt: pgast.Query, path_id: irast.PathId ) -> Optional[pgast.BaseRangeVar]: return maybe_get_path_rvar(stmt, path_id, aspect=pgce.PathAspect.VALUE) def _same_expr(expr1: pgast.BaseExpr, expr2: pgast.BaseExpr) -> bool: if (isinstance(expr1, pgast.ColumnRef) and isinstance(expr2, pgast.ColumnRef)): return expr1.name == expr2.name else: return expr1 == expr2 def put_path_packed_output( rel: pgast.EdgeQLPathInfo, path_id: irast.PathId, val: 
pgast.OutputVar, aspect: pgce.PathAspect=pgce.PathAspect.VALUE, ) -> None: if rel.packed_path_outputs is None: rel.packed_path_outputs = {} rel.packed_path_outputs[path_id, aspect] = val def _put_path_output_var( rel: pgast.BaseRelation, path_id: irast.PathId, aspect: pgce.PathAspect, var: pgast.OutputVar, *, flavor: str = 'normal', ) -> None: if flavor == 'packed': put_path_packed_output(rel, path_id, var, aspect) else: rel.path_outputs[path_id, aspect] = var def _get_rel_object_id_output( rel: pgast.BaseRelation, path_id: irast.PathId, *, aspect: pgce.PathAspect, env: context.Environment, ) -> pgast.OutputVar: var = rel.path_outputs.get((path_id, aspect)) if var is not None: return var if isinstance(rel, pgast.NullRelation): name = env.aliases.get('id') val = pgast.TypeCast( arg=pgast.NullConstant(), type_name=pgast.TypeName( name=('uuid',), ) ) rel.target_list.append(pgast.ResTarget(name=name, val=val)) result = pgast.ColumnRef(name=[name], nullable=True) else: result = pgast.ColumnRef(name=['id'], nullable=False) _put_path_output_var(rel, path_id, aspect, result) return result def _get_rel_path_output( rel: pgast.Relation | pgast.NullRelation, path_id: irast.PathId, *, aspect: pgce.PathAspect, flavor: str, env: context.Environment, ) -> pgast.OutputVar: if path_id.is_objtype_path(): if aspect == pgce.PathAspect.IDENTITY: aspect = pgce.PathAspect.VALUE if aspect != pgce.PathAspect.VALUE: raise LookupError( f'invalid request for non-scalar path {path_id} {aspect}') if (path_id == rel.path_id or (rel.path_id and rel.path_id.is_type_intersection_path() and path_id == rel.path_id.src_path())): return _get_rel_object_id_output( rel, path_id, aspect=aspect, env=env) else: if aspect == pgce.PathAspect.IDENTITY: raise LookupError( f'invalid request for scalar path {path_id} {aspect}') elif aspect == pgce.PathAspect.SERIALIZED: aspect = pgce.PathAspect.VALUE var = rel.path_outputs.get((path_id, aspect)) if var is not None: return var # The ptrref from the path id may be 
from a super type of the # actual type this relation corresponds to. We know the relation's # type, so find the real ptrref that corresponds to the current # type (since the column names will be different in the parent # and child tables). rptr_dir = path_id.rptr_dir() ptrref = path_id.rptr() if isinstance(ptrref, irast.PointerRef) and rel.type_or_ptr_ref: typeref = rel.type_or_ptr_ref if isinstance(typeref, irast.PointerRef): typeref = typeref.out_source assert rptr_dir actual_ptrref = irtyputils.maybe_find_actual_ptrref( typeref, ptrref, dir=rptr_dir) if actual_ptrref: ptrref = actual_ptrref ptr_info = None if ptrref and not isinstance(ptrref, irast.TypeIntersectionPointerRef): ptr_info = pg_types.get_ptrref_storage_info( ptrref, resolve_type=False, link_bias=bool(rel.path_id and rel.path_id.is_ptr_path()), ) if (rptr_dir is not None and rptr_dir != s_pointers.PointerDirection.Outbound): raise LookupError( f'{path_id} is an inbound pointer and cannot be resolved ' f'on a base relation') result: pgast.OutputVar if isinstance(rel, pgast.NullRelation): if ptrref is not None: target = ptrref.out_target else: target = path_id.target pg_type = pg_types.pg_type_from_ir_typeref(target) if ptr_info is not None: name = env.aliases.get(ptr_info.column_name) else: name = env.aliases.get('v') val = pgast.TypeCast( arg=pgast.NullConstant(), type_name=pgast.TypeName( name=pg_type, ) ) rel.target_list.append(pgast.ResTarget(name=name, val=val)) result = pgast.ColumnRef(name=[name], nullable=True) else: if ptrref is None or ptr_info is None: raise LookupError( f'could not resolve trailing pointer class for {path_id}') if ptrref.is_computable: raise LookupError("can't lookup computable ptrref") # Refuse to try to access a link table when we are actually # looking at an object rel. This check is needed because # relgen._lookup_set_rvar_in_source sometimes does some pretty # wild maybe_get_path_value_var calls. 
if ( ptr_info.table_type == 'link' and isinstance(rel.type_or_ptr_ref, irast.TypeRef) ): raise LookupError("can't access link table on object rel") if ( ptrref.shortname.name == '__type__' and rel.name and not common.is_inhview_name(rel.name) ): assert isinstance(rel.type_or_ptr_ref, irast.TypeRef) result = pgast.ExprOutputVar( expr=astutils.compile_typeref(rel.type_or_ptr_ref)) else: result = pgast.ColumnRef( name=[ptr_info.column_name], nullable=not ptrref.required) _put_path_output_var(rel, path_id, aspect, result, flavor=flavor) return result def has_type_rewrite( typeref: irast.TypeRef, *, env: context.Environment) -> bool: return any( (typeref.real_material_type.id, b) in env.type_rewrites for b in (True, False) ) def link_needs_type_rewrite( typeref: irast.TypeRef, *, env: context.Environment) -> bool: return ( has_type_rewrite(typeref, env=env) # Typically we need to apply rewrites when looking at a link # target that has a policy on it, but we suppress this for # schema::ObjectType. None of the hidden objects should be # user visible anyway, and this allows us to do type id # injection without a join. and str(typeref.real_material_type.name_hint) != 'schema::ObjectType' ) def find_path_output( rel: pgast.BaseRelation, ref: pgast.BaseExpr ) -> Optional[pgast.OutputVar]: if isinstance(ref, pgast.TupleVarBase): return None for key, other_ref in rel.path_namespace.items(): if _same_expr(other_ref, ref) and key in rel.path_outputs: return rel.path_outputs.get(key) else: return None def get_path_output( rel: pgast.BaseRelation, path_id: irast.PathId, *, aspect: pgce.PathAspect, allow_nullable: bool=True, disable_output_fusion: bool=False, flavor: str='normal', env: context.Environment ) -> pgast.OutputVar: if isinstance(rel, pgast.Query) and flavor == 'normal': path_id = map_path_id(path_id, rel.view_path_id_map) # XXX: This is a haaaaack. 
if rel.strip_output_namespaces: path_id = path_id.strip_namespace(path_id.namespace) return _get_path_output(rel, path_id=path_id, aspect=aspect, disable_output_fusion=disable_output_fusion, allow_nullable=allow_nullable, flavor=flavor, env=env) def _get_path_output( rel: pgast.BaseRelation, path_id: irast.PathId, *, aspect: pgce.PathAspect, allow_nullable: bool=True, disable_output_fusion: bool=False, flavor: str, env: context.Environment, ) -> pgast.OutputVar: if flavor == 'packed': result = (rel.packed_path_outputs.get((path_id, aspect)) if rel.packed_path_outputs else None) else: result = rel.path_outputs.get((path_id, aspect)) if result is not None: return result ref: pgast.BaseExpr alias = None rptr = path_id.rptr() if ( rptr is not None and irtyputils.is_id_ptrref(rptr) and (src_path_id := path_id.src_path()) and not disable_output_fusion and not ( (src_rptr := src_path_id.rptr()) and src_rptr.real_material_ptr.out_cardinality.is_multi() and not irtyputils.is_free_object(src_path_id.target) ) and not link_needs_type_rewrite(src_path_id.target, env=env) ): # A value reference to Object.id is the same as a value # reference to the Object itself. (Though we want to only # apply this in the cases that process_set_as_path does this # optimization, which means not for multi props. We also always # allow it for free objects.) id_output = maybe_get_path_output( rel, src_path_id, aspect=pgce.PathAspect.VALUE, allow_nullable=allow_nullable, env=env ) if id_output is not None: _put_path_output_var(rel, path_id, aspect, id_output) return id_output if is_terminal_relation(rel): return _get_rel_path_output( rel, path_id, aspect=aspect, flavor=flavor, env=env) assert isinstance(rel, pgast.Query) if is_values_relation(rel) and aspect != pgce.PathAspect.IDENTITY: # The VALUES() construct seems to always expose its # value as "column1". 
alias = 'column1' ref = pgast.ColumnRef(name=[alias], nullable=rel.nullable) else: ref = get_path_var(rel, path_id, aspect=aspect, flavor=flavor, env=env) # As an optimization, look to see if the same expression is being # output on a different aspect. This can save us needing to do the # work twice in the query. other_output = find_path_output(rel, ref) if other_output is not None and not disable_output_fusion: _put_path_output_var(rel, path_id, aspect, other_output, flavor=flavor) return other_output if isinstance(ref, pgast.TupleVarBase): elements = [] for el in ref.elements: element = _get_path_output( rel, el.path_id, aspect=aspect, disable_output_fusion=disable_output_fusion, flavor=flavor, allow_nullable=allow_nullable, env=env) # We need to reverse the mapping for the element path in # the output TupleVar, since it will be used *outside* # this rel, and so without the map applied. el_path_id = reverse_map_path_id(el.path_id, rel.view_path_id_map) elements.append(pgast.TupleElement( path_id=el_path_id, val=element, name=element)) result = pgast.TupleVar( elements=elements, named=ref.named, typeref=ref.typeref, is_packed_multi=ref.is_packed_multi, ) else: if astutils.is_set_op_query(rel): assert isinstance(ref, pgast.OutputVar) result = astutils.strip_output_var(ref) else: assert isinstance(rel, pgast.ReturningQuery), \ "expected ReturningQuery" if alias is None: alias = get_path_output_alias(path_id, aspect, env=env) if isinstance(ref, pgast.NullConstant): pg_type = pg_types.pg_type_from_ir_typeref(path_id.target) ref = pgast.TypeCast( arg=ref, type_name=pgast.TypeName(name=pg_type) ) restarget = pgast.ResTarget( name=alias, val=ref, ser_safe=getattr(ref, 'ser_safe', False)) rel.target_list.append(restarget) nullable = is_nullable(ref, env=env) optional = None is_packed_multi = False if isinstance(ref, pgast.ColumnRef): optional = ref.optional is_packed_multi = ref.is_packed_multi # group by will register a *subquery* as a path var # for a packed group, and 
if we want to avoid losing # track of whether it is multi, we need to figure that out. if ( isinstance(ref, pgast.SelectStmt) and flavor == 'packed' and ref.packed_path_outputs and (path_id, aspect) in ref.packed_path_outputs ): is_packed_multi = ref.packed_path_outputs[ path_id, aspect].is_packed_multi if nullable and not allow_nullable: assert isinstance(rel, pgast.SelectStmt), \ "expected SelectStmt" var = get_path_var(rel, path_id, aspect=aspect, env=env) rel.where_clause = astutils.extend_binop( rel.where_clause, pgast.NullTest(arg=var, negated=True) ) nullable = False result = pgast.ColumnRef( name=[alias], nullable=nullable, optional=optional, is_packed_multi=is_packed_multi) _put_path_output_var(rel, path_id, aspect, result, flavor=flavor) if (path_id.is_objtype_path() and not isinstance(result, pgast.TupleVarBase)): equiv_aspect = None if aspect == pgce.PathAspect.IDENTITY: equiv_aspect = pgce.PathAspect.VALUE elif aspect == pgce.PathAspect.VALUE: equiv_aspect = pgce.PathAspect.IDENTITY if (equiv_aspect is not None and (path_id, equiv_aspect) not in rel.path_outputs): _put_path_output_var( rel, path_id, equiv_aspect, result, flavor=flavor ) return result def maybe_get_path_output( rel: pgast.BaseRelation, path_id: irast.PathId, *, aspect: pgce.PathAspect, allow_nullable: bool=True, disable_output_fusion: bool=False, flavor: str='normal', env: context.Environment, ) -> Optional[pgast.OutputVar]: try: return get_path_output(rel, path_id=path_id, aspect=aspect, allow_nullable=allow_nullable, disable_output_fusion=disable_output_fusion, flavor=flavor, env=env) except LookupError: return None def get_path_identity_output( rel: pgast.Query, path_id: irast.PathId, *, env: context.Environment, ) -> pgast.OutputVar: return get_path_output( rel, path_id, aspect=pgce.PathAspect.IDENTITY, env=env ) def get_path_value_output( rel: pgast.BaseRelation, path_id: irast.PathId, *, env: context.Environment, ) -> pgast.OutputVar: return get_path_output( rel, path_id,
aspect=pgce.PathAspect.VALUE, env=env ) def get_path_serialized_or_value_var( rel: pgast.Query, path_id: irast.PathId, *, env: context.Environment) -> pgast.BaseExpr: ref = maybe_get_path_serialized_var(rel, path_id, env=env) if ref is None: ref = get_path_value_var(rel, path_id, env=env) return ref def fix_tuple( rel: pgast.Query, ref: pgast.BaseExpr, *, aspect: pgce.PathAspect, env: context.Environment, ) -> pgast.BaseExpr: if ( isinstance(ref, pgast.TupleVarBase) and not isinstance(ref, pgast.TupleVar) ): elements = [] for el in ref.elements: assert el.path_id is not None var = get_path_var(rel, el.path_id, aspect=aspect, env=env) val = fix_tuple(rel, var, aspect=aspect, env=env) elements.append( pgast.TupleElement( path_id=el.path_id, name=el.name, val=val)) ref = pgast.TupleVar( elements, named=ref.named, typeref=ref.typeref, ) return ref def get_path_serialized_output( rel: pgast.Query, path_id: irast.PathId, *, env: context.Environment, ) -> pgast.OutputVar: # Serialized output is a special case: we don't # want this behaviour to be recursive, so it # must be kept outside of the generic get_path_output().
aspect = pgce.PathAspect.SERIALIZED path_id = map_path_id(path_id, rel.view_path_id_map) result = rel.path_outputs.get((path_id, aspect)) if result is not None: return result ref = get_path_serialized_or_value_var(rel, path_id, env=env) if ( isinstance(ref, pgast.TupleVarBase) and not isinstance(ref, pgast.TupleVar) ): elements = [] for el in ref.elements: assert el.path_id is not None val = get_path_serialized_or_value_var(rel, el.path_id, env=env) elements.append( pgast.TupleElement( path_id=el.path_id, name=el.name, val=val)) ref = pgast.TupleVar( elements, named=ref.named, typeref=ref.typeref, ) refexpr = output.serialize_expr(ref, path_id=path_id, env=env) alias = get_path_output_alias(path_id, aspect, env=env) restarget = pgast.ResTarget(name=alias, val=refexpr, ser_safe=True) rel.target_list.append(restarget) result = pgast.ColumnRef( name=[alias], nullable=refexpr.nullable, ser_safe=True) _put_path_output_var(rel, path_id, aspect, result) return result def get_path_output_or_null( rel: pgast.Query, path_id: irast.PathId, *, disable_output_fusion: bool=False, flavor: str='normal', aspect: pgce.PathAspect, env: context.Environment, ) -> tuple[pgast.OutputVar, bool]: path_id = map_path_id(path_id, rel.view_path_id_map) ref = maybe_get_path_output( rel, path_id, disable_output_fusion=disable_output_fusion, flavor=flavor, aspect=aspect, env=env) if ref is not None: return ref, False alt_aspect = get_less_specific_aspect(path_id, aspect) if alt_aspect is not None and flavor == 'normal': # If disable_output_fusion is true, we need to be careful # to not reuse an existing column if disable_output_fusion: preexisting = rel.path_outputs.pop((path_id, alt_aspect), None) ref = maybe_get_path_output( rel, path_id, disable_output_fusion=disable_output_fusion, aspect=alt_aspect, env=env) if disable_output_fusion: # Put back the path_output to whatever it was before if not preexisting: rel.path_outputs.pop((path_id, alt_aspect), None) else: rel.path_outputs[(path_id, 
alt_aspect)] = preexisting if ref is not None: _put_path_output_var(rel, path_id, aspect, ref) return ref, False alias = env.aliases.get('null') restarget = pgast.ResTarget( name=alias, val=pgast.NullConstant()) rel.target_list.append(restarget) ref = pgast.ColumnRef(name=[alias], nullable=True) _put_path_output_var(rel, path_id, aspect, ref, flavor=flavor) return ref, True def is_nullable( expr: pgast.BaseExpr, *, env: context.Environment) -> Optional[bool]: try: return expr.nullable except AttributeError: if isinstance(expr, pgast.ReturningQuery): tl_len = len(expr.target_list) if tl_len != 1: raise RuntimeError( f'subquery used as a value returns {tl_len} columns') return is_nullable(expr.target_list[0].val, env=env) else: raise ================================================ FILE: edb/pgsql/compiler/relctx.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# """Compiler routines managing relation ranges and scope.""" from __future__ import annotations from typing import ( Callable, Optional, AbstractSet, Iterable, Sequence, NamedTuple, cast, ) import uuid import immutables as immu from edb import errors from edb.edgeql import qltypes from edb.edgeql import ast as qlast from edb.ir import ast as irast from edb.ir import typeutils as irtyputils from edb.ir import utils as irutils from edb.schema import pointers as s_pointers from edb.schema import name as sn from edb.pgsql import ast as pgast from edb.pgsql import common from edb.pgsql import types as pg_types from . import astutils from . import context from . import dispatch from . import enums as pgce from . import output from . import pathctx def init_toplevel_query( ir_set: irast.Set, *, ctx: context.CompilerContextLevel) -> None: ctx.toplevel_stmt = ctx.stmt = ctx.rel update_scope(ir_set, ctx.rel, ctx=ctx) ctx.pending_query = ctx.rel def _pull_path_namespace( *, target: pgast.Query, source: pgast.PathRangeVar, flavor: str='normal', replace_bonds: bool=True, ctx: context.CompilerContextLevel) -> None: squery = source.query source_qs: list[pgast.BaseRelation] if astutils.is_set_op_query(squery): # Set op query assert squery.larg and squery.rarg source_qs = [squery, squery.larg, squery.rarg] else: source_qs = [squery] for source_q in source_qs: s_paths: set[tuple[irast.PathId, pgce.PathAspect]] = set() if flavor == 'normal': if hasattr(source_q, 'path_outputs'): s_paths.update(source_q.path_outputs) if hasattr(source_q, 'path_namespace'): s_paths.update(source_q.path_namespace) if isinstance(source_q, pgast.Query): s_paths.update(source_q.path_rvar_map) elif flavor == 'packed': if hasattr(source_q, 'packed_path_outputs'): if source_q.packed_path_outputs: s_paths.update(source_q.packed_path_outputs) if isinstance(source_q, pgast.Query): if source_q.path_packed_rvar_map: s_paths.update(source_q.path_packed_rvar_map) else: raise AssertionError(f'unexpected flavor 
"{flavor}"') view_path_id_map = getattr(source_q, 'view_path_id_map', {}) for path_id, aspect in s_paths: orig_path_id = path_id if flavor != 'packed': path_id = pathctx.reverse_map_path_id( path_id, view_path_id_map) # Skip pulling paths that match the path_id_mask before or after # doing path id mapping. We need to look at before as well # to prevent paths leaking out under a different name. if flavor != 'packed' and ( path_id in squery.path_id_mask or orig_path_id in squery.path_id_mask ): continue rvar = pathctx.maybe_get_path_rvar( target, path_id, aspect=aspect, flavor=flavor ) if rvar is None or flavor == 'packed': pathctx.put_path_rvar( target, path_id, source, aspect=aspect, flavor=flavor ) def pull_path_namespace( *, target: pgast.Query, source: pgast.PathRangeVar, replace_bonds: bool=True, ctx: context.CompilerContextLevel) -> None: for flavor in ('normal', 'packed'): _pull_path_namespace(target=target, source=source, flavor=flavor, replace_bonds=replace_bonds, ctx=ctx) def find_rvar( stmt: pgast.Query, *, flavor: str='normal', source_stmt: Optional[pgast.Query]=None, path_id: irast.PathId, ctx: context.CompilerContextLevel) -> \ Optional[pgast.PathRangeVar]: """Find an existing range var for a given *path_id* in stmt hierarchy. If a range var is visible in a given SQL scope denoted by *stmt*, or, optionally, *source_stmt*, record it on *stmt* for future reference. :param stmt: The statement to ensure range var visibility in. :param flavor: Whether to look for normal rvars or packed rvars :param source_stmt: An optional statement object which is used as the starting SQL scope for range var search. If not specified, *stmt* is used as the starting scope. :param path_id: The path ID of the range var being searched. :param ctx: Compiler context. :return: A range var instance if found, ``None`` otherwise. 
""" if source_stmt is None: source_stmt = stmt rvar = maybe_get_path_rvar( source_stmt, path_id=path_id, aspect=pgce.PathAspect.VALUE, flavor=flavor, ctx=ctx, ) if rvar is not None: pathctx.put_path_rvar_if_not_exists( stmt, path_id, rvar, aspect=pgce.PathAspect.VALUE, flavor=flavor, ) src_rvar = maybe_get_path_rvar( source_stmt, path_id=path_id, aspect=pgce.PathAspect.SOURCE, flavor=flavor, ctx=ctx ) if src_rvar is not None: pathctx.put_path_rvar_if_not_exists( stmt, path_id, src_rvar, aspect=pgce.PathAspect.SOURCE, flavor=flavor, ) return rvar def include_rvar( stmt: pgast.SelectStmt, rvar: pgast.PathRangeVar, path_id: irast.PathId, *, overwrite_path_rvar: bool=False, pull_namespace: bool=True, update_mask: bool=True, flavor: str='normal', aspects: Optional[ tuple[pgce.PathAspect, ...] | AbstractSet[pgce.PathAspect] ]=None, ctx: context.CompilerContextLevel, ) -> pgast.PathRangeVar: """Ensure that *rvar* is visible in *stmt* as a value/source aspect. :param stmt: The statement to include *rel* in. :param rvar: The range var node to join. :param join_type: JOIN type to use when including *rel*. :param flavor: Whether this is a normal or packed rvar :param aspect: The reference aspect of the range var. :param ctx: Compiler context. 
""" if aspects is None: aspects = (pgce.PathAspect.VALUE,) if path_id.is_objtype_path(): if isinstance(rvar, pgast.RangeSubselect): if pathctx.has_path_aspect( rvar.query, path_id, aspect=pgce.PathAspect.SOURCE, ): aspects += (pgce.PathAspect.SOURCE,) else: aspects += (pgce.PathAspect.SOURCE,) elif path_id.is_tuple_path(): aspects += (pgce.PathAspect.SOURCE,) return include_specific_rvar( stmt, rvar=rvar, path_id=path_id, overwrite_path_rvar=overwrite_path_rvar, pull_namespace=pull_namespace, update_mask=update_mask, flavor=flavor, aspects=aspects, ctx=ctx) def include_specific_rvar( stmt: pgast.SelectStmt, rvar: pgast.PathRangeVar, path_id: irast.PathId, *, overwrite_path_rvar: bool=False, pull_namespace: bool=True, update_mask: bool=True, flavor: str='normal', aspects: Iterable[pgce.PathAspect]=(pgce.PathAspect.VALUE,), ctx: context.CompilerContextLevel, ) -> pgast.PathRangeVar: """Make the *aspect* of *path_id* visible in *stmt* as *rvar*. :param stmt: The statement to include *rel* in. :param rvar: The range var node to join. :param join_type: JOIN type to use when including *rel*. :param flavor: Whether this is a normal or packed rvar :param aspect: The reference aspect of the range var. :param ctx: Compiler context. """ if not has_rvar(stmt, rvar, ctx=ctx): rel_join(stmt, rvar, ctx=ctx) # Make sure that the path namespace of *rvar* is mapped # onto the path namespace of *stmt*. 
if pull_namespace: pull_path_namespace(target=stmt, source=rvar, ctx=ctx) for aspect in aspects: if overwrite_path_rvar: pathctx.put_path_rvar( stmt, path_id, rvar, flavor=flavor, aspect=aspect ) else: pathctx.put_path_rvar_if_not_exists( stmt, path_id, rvar, flavor=flavor, aspect=aspect ) if update_mask: scopes = [ctx.scope_tree] parent_scope = ctx.scope_tree.parent if parent_scope is not None: scopes.append(parent_scope) tpath_id = path_id.tgt_path() if not any(scope.path_id == tpath_id or scope.find_child(tpath_id) for scope in scopes): pathctx.put_path_id_mask(stmt, path_id) return rvar def has_rvar( stmt: pgast.Query, rvar: pgast.PathRangeVar, *, ctx: context.CompilerContextLevel) -> bool: curstmt: Optional[pgast.Query] = stmt if ctx.env.external_rvars and has_external_rvar(rvar, ctx=ctx): return True while curstmt is not None: if pathctx.has_rvar(curstmt, rvar): return True curstmt = ctx.rel_hierarchy.get(curstmt) return False def has_external_rvar( rvar: pgast.PathRangeVar, *, ctx: context.CompilerContextLevel, ) -> bool: return rvar in ctx.env.external_rvars.values() def _maybe_get_path_rvar( stmt: pgast.Query, path_id: irast.PathId, *, flavor: str='normal', aspect: pgce.PathAspect, ctx: context.CompilerContextLevel, ) -> Optional[tuple[pgast.PathRangeVar, irast.PathId]]: rvar = ctx.env.external_rvars.get((path_id, aspect)) if rvar: return rvar, path_id qry: Optional[pgast.Query] = stmt while qry is not None: rvar = pathctx.maybe_get_path_rvar( qry, path_id, aspect=aspect, flavor=flavor ) if rvar is not None: if qry is not stmt: # Cache the rvar reference. 
pathctx.put_path_rvar( stmt, path_id, rvar, flavor=flavor, aspect=aspect ) return rvar, path_id qry = ctx.rel_hierarchy.get(qry) return None def _get_path_rvar( stmt: pgast.Query, path_id: irast.PathId, *, flavor: str='normal', aspect: pgce.PathAspect, ctx: context.CompilerContextLevel, ) -> tuple[pgast.PathRangeVar, irast.PathId]: result = _maybe_get_path_rvar( stmt, path_id, flavor=flavor, aspect=aspect, ctx=ctx) if result is None: raise LookupError(f'there is no range var for {path_id} in {stmt}') else: return result def get_path_rvar( stmt: pgast.Query, path_id: irast.PathId, *, flavor: str='normal', aspect: pgce.PathAspect, ctx: context.CompilerContextLevel, ) -> pgast.PathRangeVar: return _get_path_rvar( stmt, path_id, flavor=flavor, aspect=aspect, ctx=ctx)[0] def get_path_var( stmt: pgast.Query, path_id: irast.PathId, *, aspect: pgce.PathAspect, ctx: context.CompilerContextLevel, ) -> pgast.BaseExpr: var = pathctx.maybe_get_path_var( stmt, path_id=path_id, aspect=aspect, env=ctx.env) if var is not None: return var else: rvar, path_id = _get_path_rvar(stmt, path_id, aspect=aspect, ctx=ctx) return pathctx.get_rvar_path_var( rvar, path_id, aspect=aspect, env=ctx.env) def maybe_get_path_rvar( stmt: pgast.Query, path_id: irast.PathId, *, flavor: str='normal', aspect: pgce.PathAspect, ctx: context.CompilerContextLevel, ) -> Optional[pgast.PathRangeVar]: result = _maybe_get_path_rvar(stmt, path_id, aspect=aspect, flavor=flavor, ctx=ctx) return result[0] if result is not None else None def maybe_get_path_var( stmt: pgast.Query, path_id: irast.PathId, *, aspect: pgce.PathAspect, ctx: context.CompilerContextLevel, ) -> Optional[pgast.BaseExpr]: result = _maybe_get_path_rvar(stmt, path_id, aspect=aspect, ctx=ctx) if result is None: return None else: try: return pathctx.get_rvar_path_var( result[0], result[1], aspect=aspect, env=ctx.env) except LookupError: return None def new_empty_rvar( ir_set: irast.SetE[irast.EmptySet], *, ctx: context.CompilerContextLevel) -> 
pgast.PathRangeVar: nullrel = pgast.NullRelation( path_id=ir_set.path_id, type_or_ptr_ref=ir_set.typeref) rvar = rvar_for_rel(nullrel, ctx=ctx) return rvar def new_free_object_rvar( typeref: irast.TypeRef, path_id: irast.PathId, *, lateral: bool=False, ctx: context.CompilerContextLevel, ) -> pgast.PathRangeVar: """Create a fake source rel for a free object Free objects don't *really* have ids, but the compiler needs all objects to have ids, so we just inject the type id as if it were an id. It shouldn't get used for anything but NULL testing, so no problem. The only thing other than ids that needs to come from a free object is __type__, which we inject as a special case in pathctx.get_path_var. We also have a special case in relgen.ensure_source_rvar to reuse an existing value rvar instead of creating a new root rvar. (We inject __type__ in get_path_var instead of injecting it here because we don't have the pathid for it available to us here and because it allows ensure_source_rvar to simply reuse a value rvar.) """ with ctx.subrel() as subctx: qry = subctx.rel id_expr = astutils.compile_typeref(typeref.real_material_type) pathctx.put_path_identity_var(qry, path_id, id_expr) pathctx.put_path_value_var(qry, path_id, id_expr) return rvar_for_rel(qry, typeref=typeref, lateral=lateral, ctx=ctx) def deep_copy_primitive_rvar_path_var( orig_id: irast.PathId, new_id: irast.PathId, rvar: pgast.PathRangeVar, *, env: context.Environment ) -> None: """Copy one identity path to another in a primitive rvar. The trickiness here is because primitive rvars might have an overlay stack, which means that a join on them might use _lateral_union_join, which requires every component of the union to have all the path bonds.
""" if isinstance(rvar, pgast.RangeSubselect): for component in astutils.each_query_in_set(rvar.query): rref = pathctx.get_path_var( component, orig_id, aspect=pgce.PathAspect.IDENTITY, env=env ) pathctx.put_path_var( component, new_id, rref, aspect=pgce.PathAspect.IDENTITY, ) else: rref = pathctx.get_path_output( rvar.query, orig_id, aspect=pgce.PathAspect.IDENTITY, env=env ) pathctx.put_rvar_path_output( rvar, new_id, aspect=pgce.PathAspect.IDENTITY, var=rref, ) def new_primitive_rvar( ir_set: irast.Set, *, path_id: irast.PathId, lateral: bool, ctx: context.CompilerContextLevel, ) -> pgast.PathRangeVar: # XXX: is this needed? expr = irutils.sub_expr(ir_set) if isinstance(expr, irast.TypeRoot): skip_subtypes = expr.skip_subtypes is_global = expr.is_cached_global else: skip_subtypes = False is_global = False typeref = ir_set.typeref dml_source = irutils.get_dml_sources(ir_set, ctx.env.binding_dml) set_rvar = range_for_typeref( typeref, path_id, lateral=lateral, dml_source=dml_source, include_descendants=not skip_subtypes, ignore_rewrites=ir_set.ignore_rewrites, is_global=is_global, ctx=ctx, ) pathctx.put_rvar_path_bond(set_rvar, path_id) # FIXME: This feels like it should all not be here. if isinstance(ir_set.expr, irast.Pointer): rptr = ir_set.expr if ( isinstance(rptr.ptrref, irast.TypeIntersectionPointerRef) and isinstance(rptr.source.expr, irast.Pointer) ): rptr = rptr.source.expr # If the set comes from an backlink, and the link is stored inline, # we want to output the source path. 
if ( rptr.is_inbound and ( rptrref := irtyputils.maybe_find_actual_ptrref( set_rvar.typeref, rptr.ptrref) or rptr.ptrref if set_rvar.typeref else rptr.ptrref ) and ( ptr_info := pg_types.get_ptrref_storage_info( rptrref, resolve_type=False, link_bias=False, allow_missing=True) ) and ptr_info.table_type == 'ObjectType' ): # Inline link prefix_path_id = path_id.src_path() assert prefix_path_id is not None, 'expected a path' flipped_id = path_id.extend(ptrref=rptrref) # Unfortunately we can't necessarily just install the # prefix path id: the rvar from range_for_typeref # might be a DML overlay, which means joins on it will try # to use _lateral_union_join; this means that all of the # path bonds need to be valid on each *subquery*, so we # need to set them up in each subquery. deep_copy_primitive_rvar_path_var( flipped_id, prefix_path_id, set_rvar, env=ctx.env) pathctx.put_rvar_path_bond(set_rvar, prefix_path_id) return set_rvar def new_root_rvar( ir_set: irast.Set, *, lateral: bool = False, path_id: Optional[irast.PathId] = None, ctx: context.CompilerContextLevel, ) -> pgast.PathRangeVar: if path_id is None: path_id = ir_set.path_id if irtyputils.is_free_object(path_id.target): return new_free_object_rvar( path_id.target, path_id, lateral=lateral, ctx=ctx) narrowing = ctx.intersection_narrowing.get(ir_set) if narrowing is not None: ir_set = narrowing return new_primitive_rvar( ir_set, lateral=lateral, path_id=path_id, ctx=ctx) def new_pointer_rvar( ir_set: irast.SetE[irast.Pointer], *, link_bias: bool=False, src_rvar: pgast.PathRangeVar, ctx: context.CompilerContextLevel, ) -> pgast.PathRangeVar: ir_ptr = ir_set.expr ptrref = ir_ptr.ptrref link_bias = link_bias or ir_ptr.force_link_table ptr_info = pg_types.get_ptrref_storage_info( ptrref, resolve_type=False, link_bias=link_bias, allow_missing=True, versioned=ctx.env.versioned_stdlib, ) if ptr_info and ptr_info.table_type == 'ObjectType': # Inline link return _new_inline_pointer_rvar( ir_set,
src_rvar=src_rvar, ctx=ctx ) else: return _new_mapped_pointer_rvar(ir_set, ctx=ctx) def _new_inline_pointer_rvar( ir_set: irast.SetE[irast.Pointer], *, lateral: bool=True, src_rvar: pgast.PathRangeVar, ctx: context.CompilerContextLevel) -> pgast.PathRangeVar: ir_ptr = ir_set.expr ptr_rel = pgast.SelectStmt() ptr_rvar = rvar_for_rel(ptr_rel, lateral=lateral, ctx=ctx) ptr_rvar.query.path_id = ir_set.path_id.ptr_path() is_inbound = ir_ptr.direction == s_pointers.PointerDirection.Inbound if is_inbound: far_pid = ir_ptr.source.path_id else: far_pid = ir_set.path_id far_ref = pathctx.get_rvar_path_identity_var( src_rvar, far_pid, env=ctx.env) pathctx.put_rvar_path_bond(ptr_rvar, far_pid) pathctx.put_path_identity_var(ptr_rel, far_pid, var=far_ref) return ptr_rvar def _new_mapped_pointer_rvar( ir_set: irast.SetE[irast.Pointer], *, ctx: context.CompilerContextLevel) -> pgast.PathRangeVar: ir_ptr = ir_set.expr ptrref = ir_ptr.ptrref dml_source = irutils.get_dml_sources(ir_ptr.source, ctx.env.binding_dml) ptr_rvar = range_for_pointer(ir_set, dml_source=dml_source, ctx=ctx) src_col = 'source' source_ref = pgast.ColumnRef(name=[src_col], nullable=False) tgt_col = 'target' target_ref = pgast.ColumnRef( name=[tgt_col], nullable=not ptrref.required) if ( ir_ptr.direction == s_pointers.PointerDirection.Inbound or ptrref.computed_link_alias_is_backward ): near_ref = target_ref far_ref = source_ref else: near_ref = source_ref far_ref = target_ref src_pid = ir_ptr.source.path_id tgt_pid = ir_set.path_id ptr_pid = tgt_pid.ptr_path() ptr_rvar.query.path_id = ptr_pid pathctx.put_rvar_path_bond(ptr_rvar, src_pid) pathctx.put_rvar_path_output( ptr_rvar, src_pid, aspect=pgce.PathAspect.IDENTITY, var=near_ref ) pathctx.put_rvar_path_output( ptr_rvar, src_pid, aspect=pgce.PathAspect.VALUE, var=near_ref ) pathctx.put_rvar_path_output( ptr_rvar, tgt_pid, aspect=pgce.PathAspect.VALUE, var=far_ref ) if tgt_pid.is_objtype_path(): pathctx.put_rvar_path_bond(ptr_rvar, tgt_pid) 
pathctx.put_rvar_path_output( ptr_rvar, tgt_pid, aspect=pgce.PathAspect.IDENTITY, var=far_ref ) return ptr_rvar def is_pointer_rvar( rvar: pgast.PathRangeVar, *, ctx: context.CompilerContextLevel, ) -> bool: return rvar.query.path_id is not None and rvar.query.path_id.is_ptr_path() def new_rel_rvar( ir_set: irast.Set, stmt: pgast.Query, *, lateral: bool=True, ctx: context.CompilerContextLevel) -> pgast.PathRangeVar: return rvar_for_rel(stmt, typeref=ir_set.typeref, lateral=lateral, ctx=ctx) def semi_join( stmt: pgast.SelectStmt, ir_set: irast.SetE[irast.Pointer], src_rvar: pgast.PathRangeVar, *, ctx: context.CompilerContextLevel) -> pgast.PathRangeVar: """Join an IR Set using semi-join.""" rptr = ir_set.expr # Target set range. set_rvar = new_root_rvar(ir_set, lateral=True, ctx=ctx) ptrref = rptr.ptrref ptr_info = pg_types.get_ptrref_storage_info( ptrref, resolve_type=False, allow_missing=True) if ptr_info and ptr_info.table_type == 'ObjectType': if rptr.is_inbound: far_pid = ir_set.path_id.src_path() assert far_pid is not None else: far_pid = ir_set.path_id else: far_pid = ir_set.path_id # Link range. map_rvar = new_pointer_rvar(ir_set, src_rvar=src_rvar, ctx=ctx) include_rvar( ctx.rel, map_rvar, path_id=ir_set.path_id.ptr_path(), ctx=ctx) tgt_ref = pathctx.get_rvar_path_identity_var( set_rvar, far_pid, env=ctx.env) pathctx.get_path_identity_output( ctx.rel, far_pid, env=ctx.env) cond = astutils.new_binop(tgt_ref, ctx.rel, 'IN') stmt.where_clause = astutils.extend_binop( stmt.where_clause, cond) return set_rvar def apply_volatility_ref( stmt: pgast.SelectStmt, *, ctx: context.CompilerContextLevel) -> None: for ref in ctx.volatility_ref: # Apply the volatility reference. # See the comment in process_set_as_subquery(). 
arg = ref(stmt, ctx) if not arg: continue stmt.where_clause = astutils.extend_binop( stmt.where_clause, pgast.NullTest( arg=arg, negated=True, ) ) def create_iterator_identity_for_path( path_id: irast.PathId, stmt: pgast.BaseRelation, *, ctx: context.CompilerContextLevel, apply_volatility: bool=True, ) -> None: id_expr = pgast.FuncCall( name=astutils.edgedb_func('uuid_generate_v4', ctx=ctx), args=[], ) if isinstance(stmt, pgast.SelectStmt): path_id = pathctx.map_path_id(path_id, stmt.view_path_id_map) if apply_volatility: apply_volatility_ref(stmt, ctx=ctx) pathctx.put_path_var( stmt, path_id, id_expr, force=True, aspect=pgce.PathAspect.ITERATOR, ) pathctx.put_path_bond(stmt, path_id, iterator=True) def get_scope( ir_set: irast.Set, *, ctx: context.CompilerContextLevel, ) -> Optional[irast.ScopeTreeNode]: result: Optional[irast.ScopeTreeNode] = None if ir_set.path_scope_id is not None: result = ctx.env.scope_tree_nodes.get(ir_set.path_scope_id) return result def update_scope( ir_set: irast.Set, stmt: pgast.SelectStmt, *, ctx: context.CompilerContextLevel) -> None: """Update the scope of an ir set to be a pg stmt. If ir_set has a scope node associated with it, update path_scope so that any paths bound in that scope will be compiled in the context of stmt. This, combined with maybe_get_scope_stmt, is the mechanism by which the scope tree influences the shape of the output query. """ scope_tree = get_scope(ir_set, ctx=ctx) if scope_tree is None: return ctx.scope_tree = scope_tree ctx.path_scope = ctx.path_scope.new_child() # Register paths in the current scope to be compiled as a subrel # of stmt. 
for p in scope_tree.path_children: assert p.path_id is not None ctx.path_scope[p.path_id] = stmt def update_scope_masks( ir_set: irast.Set, rvar: pgast.PathRangeVar, *, ctx: context.CompilerContextLevel) -> None: if not isinstance(rvar, pgast.RangeSubselect): return stmt = rvar.subquery # Mark any paths under the scope tree as masked, so that they # won't get picked up by pull_path_namespace. for child_path in ctx.scope_tree.get_all_paths(): pathctx.put_path_id_mask(stmt, child_path) # If this is an optional scope node, we need to be certain that # we don't leak out any paths that collide with a visible non-optional # path. # See test_edgeql_optional_leakage_01 for one case where this comes up. # # FIXME: I actually think we ought to be able to mask off visible # paths in *most* cases, but when I tried it I ran into trouble # with some DML linkprop cases (probably easy to fix) and a number # of materialization cases (possibly hard to fix), so I'm going # with a more conservative approach. if ctx.scope_tree.is_optional(ir_set.path_id): # Since compilation is done, anything visible to us *will* be # up on the spine. Anything tucked away under a node must have # been pulled up. 
for anc in ctx.scope_tree.ancestors: for direct_child in anc.path_children: if not direct_child.optional: pathctx.put_path_id_mask(stmt, direct_child.path_id) def maybe_get_scope_stmt( path_id: irast.PathId, *, ctx: context.CompilerContextLevel, ) -> Optional[pgast.SelectStmt]: stmt = ctx.path_scope.get(path_id) if stmt is None and path_id.is_ptr_path(): stmt = ctx.path_scope.get(path_id.tgt_path()) return stmt def set_to_array( path_id: irast.PathId, query: pgast.Query, *, for_group_by: bool=False, ctx: context.CompilerContextLevel) -> pgast.Query: """Collapse a set into an array.""" subrvar = pgast.RangeSubselect( subquery=query, alias=pgast.Alias( aliasname=ctx.env.aliases.get('aggw') ) ) result = pgast.SelectStmt() aspects = pathctx.list_path_aspects(subrvar.query, path_id) include_rvar(result, subrvar, path_id=path_id, aspects=aspects, ctx=ctx) val: Optional[pgast.BaseExpr] = ( pathctx.maybe_get_path_serialized_var( result, path_id, env=ctx.env) ) if val is None: value_var = pathctx.get_path_value_var(result, path_id, env=ctx.env) val = output.serialize_expr(value_var, path_id=path_id, env=ctx.env) pathctx.put_path_serialized_var(result, path_id, val, force=True) if isinstance(val, pgast.TupleVarBase): val = output.serialize_expr( val, path_id=path_id, env=ctx.env) pg_type = output.get_pg_type(path_id.target, ctx=ctx) agg_filter_safe = True if for_group_by: # When doing this as part of a GROUP, the stuff being aggregated # needs to actually appear *inside* of the aggregate call... result.target_list = [pgast.ResTarget(val=val, ser_safe=val.ser_safe)] val = result try_collapse = astutils.collapse_query(val) if isinstance(try_collapse, pgast.ColumnRef): val = try_collapse else: agg_filter_safe = False result = pgast.SelectStmt() orig_val = val if (path_id.is_array_path() and ctx.env.output_format is context.OutputFormat.NATIVE): # We cannot aggregate arrays straight away, as # they may be of different lengths, so we have to # wrap each element in a record.
val = pgast.RowExpr(args=[val], ser_safe=val.ser_safe) pg_type = ('record',) array_agg = pgast.FuncCall( name=('array_agg',), args=[val], agg_filter=( astutils.new_binop(orig_val, pgast.NullConstant(), 'IS DISTINCT FROM') if orig_val.nullable and agg_filter_safe else None ), ser_safe=val.ser_safe, ) # If this is for a group by, and the body isn't just a column ref, # then we need to remove NULLs after the fact. if orig_val.nullable and not agg_filter_safe: array_agg = pgast.FuncCall( name=('array_remove',), args=[array_agg, pgast.NullConstant()] ) agg_expr = pgast.CoalesceExpr( args=[ array_agg, pgast.TypeCast( arg=pgast.ArrayExpr(elements=[]), type_name=pgast.TypeName(name=pg_type, array_bounds=[-1]) ) ], ser_safe=array_agg.ser_safe, nullable=False, ) result.target_list = [ pgast.ResTarget( name=ctx.env.aliases.get('v'), val=agg_expr, ser_safe=agg_expr.ser_safe, ) ] return result class UnpackElement(NamedTuple): path_id: irast.PathId colname: str packed: bool multi: bool ref: Optional[pgast.BaseExpr] def unpack_rvar( stmt: pgast.SelectStmt, path_id: irast.PathId, *, packed_rvar: pgast.PathRangeVar, ctx: context.CompilerContextLevel) -> pgast.PathRangeVar: ref = pathctx.get_rvar_path_var( packed_rvar, path_id, aspect=pgce.PathAspect.VALUE, flavor='packed', env=ctx.env, ) return unpack_var(stmt, path_id, ref=ref, ctx=ctx) def unpack_var( stmt: pgast.SelectStmt, path_id: irast.PathId, *, ref: pgast.OutputVar, ctx: context.CompilerContextLevel) -> pgast.PathRangeVar: qry = pgast.SelectStmt() view_tvars: list[tuple[irast.PathId, pgast.TupleVarBase, bool]] = [] els = [] ctr = 0 def walk(ref: pgast.BaseExpr, path_id: irast.PathId, multi: bool) -> None: nonlocal ctr coldeflist = [] alias = ctx.env.aliases.get('unpack') simple = False if irtyputils.is_tuple(path_id.target): els.append(UnpackElement( path_id, alias, packed=False, multi=False, ref=None )) orig_view_count = len(view_tvars) tuple_tvar_elements = [] for i, st in enumerate(path_id.target.subtypes): colname = 
f'_t{ctr}' ctr += 1 typ = pg_types.pg_type_from_ir_typeref(st) if st.id in ctx.env.materialized_views: typ = ('record',) # Construct a path_id for the element el_name = sn.QualName('__tuple__', st.element_name or str(i)) el_ref = irast.TupleIndirectionPointerRef( name=el_name, shortname=el_name, out_source=path_id.target, out_target=st, out_cardinality=qltypes.Cardinality.ONE, ) el_path_id = path_id.extend(ptrref=el_ref) el_var = ( astutils.tuple_getattr( pgast.ColumnRef(name=[alias]), path_id.target, el_name.name) if irtyputils.is_persistent_tuple(path_id.target) else pgast.ColumnRef(name=[colname]) ) walk(el_var, el_path_id, multi=False) tuple_tvar_elements.append( pgast.TupleElementBase( path_id=el_path_id, name=el_name.name ) ) coldeflist.append( pgast.ColumnDef( name=colname, typename=pgast.TypeName(name=typ) ) ) if len(view_tvars) > orig_view_count: tuple_tvar = pgast.TupleVarBase( elements=tuple_tvar_elements, typeref=path_id.target, named=any( st.element_name for st in path_id.target.subtypes), ) view_tvars.append((path_id, tuple_tvar, True)) if irtyputils.is_persistent_tuple(path_id.target): coldeflist = [] elif irtyputils.is_array(path_id.target) and multi: # TODO: materialized arrays of tuples and arrays are really # quite broken coldeflist = [ pgast.ColumnDef( name='q', typename=pgast.TypeName( name=pg_types.pg_type_from_ir_typeref( path_id.target) ) ) ] els.append(UnpackElement( path_id, coldeflist[0].name, packed=False, multi=False, ref=None )) elif path_id.target.id in ctx.env.materialized_views: view_tuple = ctx.env.materialized_views[path_id.target.id] vpath_ids = [] id_idx = None for el, _ in view_tuple.shape: src_path, el_ptrref = el.path_id.src_path(), el.path_id.rptr() assert src_path and el_ptrref el_id = path_id.ptr_path().extend(ptrref=el_ptrref) card = el.expr.dir_cardinality is_singleton = card.is_single() and not card.can_be_zero() must_pack = not is_singleton if (rptr_name := el_id.rptr_name()) and rptr_name.name == 'id': id_idx = 
len(els) colname = f'_t{ctr}' ctr += 1 typ = pg_types.pg_type_from_ir_typeref(el_id.target) if el_id.target.id in ctx.env.materialized_views: typ = ('record',) must_pack = True if not is_singleton: # Arrays get wrapped in a record before they can be put # in another array if el_id.is_array_path(): typ = ('record',) must_pack = True typ = pg_types.pg_type_array(typ) coldeflist.append( pgast.ColumnDef( name=colname, typename=pgast.TypeName(name=typ), ) ) els.append(UnpackElement( el_id, colname, packed=must_pack, multi=not is_singleton, ref=None )) vpath_ids.append(el_id) if id_idx is not None: els.append(UnpackElement( path_id, els[id_idx].colname, multi=False, packed=False, ref=None, )) else: colname = f'_t{ctr}' ctr += 1 coldeflist.append( pgast.ColumnDef( name=colname, typename=pgast.TypeName(name=('uuid',)), ) ) els.append(UnpackElement( path_id, colname, packed=False, multi=False, ref=None )) view_tvars.append((path_id, pgast.TupleVarBase( elements=[ pgast.TupleElementBase( path_id=pid, name=astutils.tuple_element_for_shape_el( el, ctx=ctx).name, ) for (el, op), pid in zip(view_tuple.shape, vpath_ids) if op != qlast.ShapeOp.MATERIALIZE or ctx.materializing ], typeref=path_id.target, named=True, ), False)) else: coldeflist = [] simple = not multi els.append(UnpackElement( path_id, alias, multi=False, packed=False, ref=ref if simple else None, )) if not simple: if not multi: # Sigh, have to wrap in an array so we can unpack. 
ref = pgast.ArrayExpr(elements=[ref]) qry.from_clause.insert( 0, pgast.RangeFunction( alias=pgast.Alias( aliasname=alias, ), is_rowsfrom=True, functions=[ pgast.FuncCall( name=('unnest',), args=[ref], coldeflist=coldeflist, ) ] ) ) ######################## walk(ref, path_id, ref.is_packed_multi) rvar = rvar_for_rel(qry, lateral=True, ctx=ctx) include_rvar( stmt, rvar, path_id=path_id, aspects=(pgce.PathAspect.VALUE,), ctx=ctx, ) for el in els: el_id = el.path_id cur_ref = el.ref or pgast.ColumnRef(name=[el.colname]) for aspect in (pgce.PathAspect.VALUE, pgce.PathAspect.SERIALIZED): pathctx.put_path_var(qry, el_id, cur_ref, aspect=aspect) if not el.packed: pathctx.put_path_rvar( stmt, el_id, rvar, aspect=pgce.PathAspect.VALUE, ) pathctx.put_path_rvar( ctx.rel, el_id, rvar, aspect=pgce.PathAspect.VALUE, ) else: cref = pathctx.get_path_output( qry, el_id, aspect=pgce.PathAspect.VALUE, env=ctx.env, ) cref = cref.replace(is_packed_multi=el.multi) pathctx.put_path_packed_output(qry, el_id, val=cref) pathctx.put_path_rvar( stmt, el_id, rvar, flavor='packed', aspect=pgce.PathAspect.VALUE, ) # When we're producing an exposed shape, we need to rewrite the # serialized shape. # We also need to rewrite tuples that contain such shapes! # What a pain! 
# # We *also* need to rewrite tuple values, so that we don't consider # serialized materialized objects as part of the value of the tuple for view_path_id, view_tvar, is_tuple in view_tvars: if not view_tvar.elements: continue rewrite_aspects = [] if ctx.expr_exposed and not is_tuple: rewrite_aspects.append(pgce.PathAspect.SERIALIZED) if is_tuple: rewrite_aspects.append(pgce.PathAspect.VALUE) # Reserialize links if we are producing final output if ( ctx.expr_exposed and not ctx.materializing and not is_tuple ): for tel in view_tvar.elements: el = [x for x in els if x.path_id == tel.path_id][0] if not el.packed: continue reqry = reserialize_object(el, tel, ctx=ctx) pathctx.put_path_var( qry, tel.path_id, reqry, aspect=pgce.PathAspect.SERIALIZED, force=True, ) for aspect in rewrite_aspects: tv = pathctx.fix_tuple(qry, view_tvar, aspect=aspect, env=ctx.env) sval = ( output.output_as_value(tv, env=ctx.env) if aspect == pgce.PathAspect.VALUE else output.serialize_expr(tv, path_id=view_path_id, env=ctx.env) ) pathctx.put_path_var( qry, view_path_id, sval, aspect=aspect, force=True ) pathctx.put_path_rvar(ctx.rel, view_path_id, rvar, aspect=aspect) return rvar def reserialize_object( el: UnpackElement, tel: pgast.TupleElementBase, *, ctx: context.CompilerContextLevel) -> pgast.Query: tref = pgast.ColumnRef(name=[el.colname], is_packed_multi=el.multi) with ctx.subrel() as subctx: sub_rvar = unpack_var(subctx.rel, tel.path_id, ref=tref, ctx=subctx) reqry = sub_rvar.query assert isinstance(reqry, pgast.Query) rptr = tel.path_id.rptr() pathctx.get_path_serialized_output(reqry, tel.path_id, env=ctx.env) assert rptr if rptr.out_cardinality.is_multi(): with ctx.subrel() as subctx: reqry = set_to_array( path_id=tel.path_id, query=reqry, ctx=subctx) return reqry def get_scope_stmt( path_id: irast.PathId, *, ctx: context.CompilerContextLevel, ) -> pgast.SelectStmt: stmt = maybe_get_scope_stmt(path_id, ctx=ctx) if stmt is None: raise LookupError(f'cannot find scope statement for 
{path_id}') else: return stmt def rel_join( query: pgast.SelectStmt, right_rvar: pgast.PathRangeVar, *, ctx: context.CompilerContextLevel, ) -> None: if ( isinstance(right_rvar, pgast.RangeSubselect) and astutils.is_set_op_query(right_rvar.subquery) and right_rvar.tag == "overlay-stack" and all(isinstance(q, pgast.SelectStmt) for q in astutils.each_query_in_set(right_rvar.subquery)) and not is_pointer_rvar(right_rvar, ctx=ctx) ): # Unfortunately Postgres sometimes produces a very bad plan # when we join a UNION which is not a trivial Append, most notably # those produced by DML overlays. To work around this we push # the JOIN condition into the WHERE clause of each UNION component. # While this is likely not harmful (and possibly beneficial) for # all kinds of UNIONs, we restrict this optimization to overlay # UNIONs only to limit the possibility of breakage as not all # UNIONs are guaranteed to have correct path namespace and # translation maps set up. _lateral_union_join(query, right_rvar, ctx=ctx) else: _plain_join(query, right_rvar, ctx=ctx) def _plain_join( query: pgast.SelectStmt, right_rvar: pgast.PathRangeVar, *, ctx: context.CompilerContextLevel, ) -> None: condition = None for path_id, iterator_var in right_rvar.query.path_bonds: lref = None aspect = ( pgce.PathAspect.ITERATOR if iterator_var else pgce.PathAspect.IDENTITY ) lref = maybe_get_path_var( query, path_id, aspect=aspect, ctx=ctx) if lref is None and not iterator_var: lref = maybe_get_path_var( query, path_id, aspect=pgce.PathAspect.VALUE, ctx=ctx, ) if lref is None: continue rref = pathctx.get_rvar_path_var( right_rvar, path_id, aspect=aspect, env=ctx.env) assert isinstance(lref, pgast.ColumnRef) assert isinstance(rref, pgast.ColumnRef) path_cond = astutils.join_condition(lref, rref) condition = astutils.extend_binop(condition, path_cond) if condition is None: join_type = 'cross' else: join_type = 'inner' if not query.from_clause: query.from_clause.append(right_rvar) if condition is not None: 
query.where_clause = astutils.extend_binop( query.where_clause, condition) else: larg = query.from_clause[0] rarg = right_rvar query.from_clause[0] = pgast.JoinExpr.make_inplace( type=join_type, larg=larg, rarg=rarg, quals=condition) def _lateral_union_join( query: pgast.SelectStmt, right_rvar: pgast.RangeSubselect, *, ctx: context.CompilerContextLevel, ) -> None: # Inject the filter into every subquery for component in astutils.each_query_in_set(right_rvar.subquery): condition = None for path_id, iterator_var in right_rvar.query.path_bonds: aspect = ( pgce.PathAspect.ITERATOR if iterator_var else pgce.PathAspect.IDENTITY ) lref = maybe_get_path_var( query, path_id, aspect=aspect, ctx=ctx) if lref is None and not iterator_var: lref = maybe_get_path_var( query, path_id, aspect=pgce.PathAspect.VALUE, ctx=ctx, ) if lref is None: continue rref = pathctx.get_path_var( component, path_id, aspect=aspect, env=ctx.env) assert isinstance(lref, pgast.ColumnRef) assert isinstance(rref, pgast.ColumnRef) path_cond = astutils.join_condition(lref, rref) condition = astutils.extend_binop(condition, path_cond) if condition is not None: assert isinstance(component, pgast.SelectStmt) component.where_clause = astutils.extend_binop( component.where_clause, condition) # Do the actual join if not query.from_clause: query.from_clause.append(right_rvar) else: larg = query.from_clause[0] rarg = right_rvar query.from_clause[0] = pgast.JoinExpr.make_inplace( type='cross', larg=larg, rarg=rarg) def _needs_cte(typeref: irast.TypeRef) -> bool: """Check whether a typeref needs to be forced into a materialized CTE. The main use case here is for sys::SystemObjects which are stored as views that populate their data by parsing JSON metadata embedded in comments on the SQL system objects. The query plans when fetching multi links from these objects wind up being pretty pathologically quadratic. 
So instead we force the objects and links into materialized CTEs so that they *can't* be shoved into nested loops. """ assert isinstance(typeref.name_hint, sn.QualName) return typeref.name_hint.module == 'sys' def range_for_material_objtype( typeref: irast.TypeRef, path_id: irast.PathId, *, for_mutation: bool=False, lateral: bool=False, include_overlays: bool=True, include_descendants: bool=True, ignore_rewrites: bool=False, is_global: bool=False, dml_source: Sequence[irast.MutatingLikeStmt]=(), ctx: context.CompilerContextLevel, ) -> pgast.PathRangeVar: if not is_global: typeref = typeref.real_material_type if not is_global and not path_id.is_objtype_path(): raise ValueError('cannot create root rvar for non-object path') assert isinstance(typeref.name_hint, sn.QualName) dml_source_key = ( frozenset(dml_source) if ctx.trigger_mode and dml_source else None ) rw_key = (typeref.id, include_descendants) key = rw_key + (dml_source_key,) force_cte = _needs_cte(typeref) if ( (not ignore_rewrites or is_global) and ( (rewrite := ctx.env.type_rewrites.get(rw_key)) is not None or force_cte ) and rw_key not in ctx.pending_type_rewrite_ctes and not for_mutation ): if not rewrite: # If we are forcing CTE materialization but there is not a # real rewrite, then create a trivial one. rewrite = irast.Set( path_id=irast.PathId.from_typeref(typeref, namespace={'rw'}), typeref=typeref, expr=irast.TypeRoot(typeref=typeref), ) # Don't include overlays in the normal way in trigger mode # when a type cte is used, because we bake the overlays into # the cte instead (and so including them normally could union # back in things that we have filtered out). # # We *don't* do this if we actually have DML sources; if it is # __old__, because we don't want overlays at all and for # __new__ and for actual DML because we want the overlays to # apply after policies. 
trigger_mode = ctx.trigger_mode and not dml_source if trigger_mode: include_overlays = False type_rel: pgast.BaseRelation | pgast.CommonTableExpr if (type_cte := ctx.type_rewrite_ctes.get(key)) is None: with ctx.newrel() as sctx: sctx.pending_type_rewrite_ctes.add(rw_key) sctx.pending_query = sctx.rel # Normally we want to compile type rewrites without # polluting them with any sort of overlays, but when # compiling triggers, we recompile all of the type # rewrites *to include* overlays, so that we can't peek # at all newly created objects that we can't see if not trigger_mode: sctx.rel_overlays = context.RelOverlays() dispatch.visit(rewrite, ctx=sctx) # If we are explaining, we also expand type # rewrites, so don't populate type_ctes. The normal # case is to stick it in a CTE and cache that, though. if ctx.env.is_explain and not is_global: type_rel = sctx.rel else: type_cte = pgast.CommonTableExpr( name=ctx.env.aliases.get(f't_{typeref.name_hint}'), query=sctx.rel, materialized=is_global or force_cte, ) ctx.type_rewrite_ctes[key] = type_cte ctx.ordered_type_ctes.append(type_cte) type_rel = type_cte else: type_rel = type_cte with ctx.subrel() as sctx: cte_rvar = rvar_for_rel( type_rel, typeref=typeref, alias=ctx.env.aliases.get('t'), ctx=ctx, ) pathctx.put_path_id_map(sctx.rel, path_id, rewrite.path_id) include_rvar( sctx.rel, cte_rvar, rewrite.path_id, pull_namespace=False, ctx=sctx, ) rvar = rvar_for_rel( sctx.rel, lateral=lateral, typeref=typeref, ctx=sctx) else: assert not typeref.is_view, "attempting to generate range from view" typeref_descendants = _get_typeref_descendants( typeref, include_descendants=include_descendants, for_mutation=for_mutation, ) if ( # When we are compiling a query for EXPLAIN, expand out type # references to an explicit union of all the types, rather than # using a CTE. This allows postgres to actually give us back the # alias names that we use for relations, which we use to track which # parts of the query are being referred to. 
ctx.env.is_explain # Don't use CTEs if there is no inheritance. (ie. There is only a # single material type) or len(typeref_descendants) <= 1 ): inheritance_selects = _selects_for_typeref_descendants( typeref_descendants, path_id, ctx=ctx, ) ops = [ (context.OverlayOp.UNION, select) for select in inheritance_selects ] rvar = range_from_queryset( ops, typeref.name_hint, lateral=lateral, path_id=path_id, typeref=typeref, tag='expanded-inhview', ctx=ctx, ) else: typeref_path: irast.PathId = irast.PathId.from_typeref( typeref, # If there are backlinks and the path revisits a type, a # semi-join is produced. This ensures that the rvar produced # does not have a duplicate path var. # For example: (select A.b. list[irast.TypeRef]: if ( include_descendants and not for_mutation ): descendants = [ typeref, *( descendant for descendant in irtyputils.get_typeref_descendants(typeref) # XXX: Exclude sys/cfg tables from non sys/cfg inheritance CTEs. # This probably isn't *really* what we want to do, but until we # figure that out, do *something* so that DDL isn't # excruciatingly slow because of the cost of explicit id # checks. See #5168. if ( not descendant.is_cfg_view or typeref.is_cfg_view ) ) ] # Try to only select from actual concrete types. concrete_descendants = [ subref for subref in descendants if not subref.is_abstract ] # If there aren't any concrete types, we still need to # generate *something*, so just do the initial one. 
if concrete_descendants: return concrete_descendants else: return [typeref] else: return [typeref] def _selects_for_typeref_descendants( typeref_descendants: Sequence[irast.TypeRef], path_id: irast.PathId, *, ctx: context.CompilerContextLevel, ) -> list[pgast.SelectStmt]: selects = [] for subref in typeref_descendants: rvar = _table_from_typeref( subref, path_id, ctx=ctx, ) qry = pgast.SelectStmt(from_clause=[rvar]) sub_path_id = path_id pathctx.put_path_value_rvar(qry, sub_path_id, rvar) pathctx.put_path_source_rvar(qry, sub_path_id, rvar) selects.append(qry) return selects def _table_from_typeref( typeref: irast.TypeRef, path_id: irast.PathId, *, ctx: context.CompilerContextLevel, ) -> pgast.PathRangeVar: assert isinstance(typeref.name_hint, sn.QualName) aspect = 'table' table_schema_name, table_name = common.get_objtype_backend_name( typeref.id, typeref.name_hint.module, aspect=aspect, catenate=False, versioned=ctx.env.versioned_stdlib, ) relation = pgast.Relation( schemaname=table_schema_name, name=table_name, path_id=path_id, type_or_ptr_ref=typeref, ) return pgast.RelRangeVar( relation=relation, typeref=typeref, alias=pgast.Alias( aliasname=ctx.env.aliases.get(typeref.name_hint.name) ) ) def range_for_typeref( typeref: irast.TypeRef, path_id: irast.PathId, *, lateral: bool=False, for_mutation: bool=False, include_descendants: bool=True, ignore_rewrites: bool=False, is_global: bool=False, dml_source: Sequence[irast.MutatingLikeStmt]=(), ctx: context.CompilerContextLevel, ) -> pgast.PathRangeVar: if typeref.union: # Union object types are represented as a UNION of selects # from their children, which is, for most purposes, equivalent # to SELECTing from a parent table. set_ops = [] # Concrete unions might have view type elements with duplicate # material types, and we need to filter those out. 
seen = set() for child in typeref.union: mat_child = child.material_type or child if mat_child.id in seen: assert typeref.union_is_exhaustive continue seen.add(mat_child.id) c_rvar = range_for_typeref( child, path_id=path_id, include_descendants=not typeref.union_is_exhaustive, for_mutation=for_mutation, dml_source=dml_source, lateral=lateral, ctx=ctx, ) qry = pgast.SelectStmt( from_clause=[c_rvar], ) pathctx.put_path_value_rvar(qry, path_id, c_rvar) if path_id.is_objtype_path(): pathctx.put_path_source_rvar(qry, path_id, c_rvar) pathctx.put_path_bond(qry, path_id) set_ops.append((context.OverlayOp.UNION, qry)) rvar = range_from_queryset( set_ops, typeref.name_hint, lateral=lateral, typeref=typeref, ctx=ctx, ) else: rvar = range_for_material_objtype( typeref, path_id, lateral=lateral, include_descendants=include_descendants, ignore_rewrites=ignore_rewrites, include_overlays=not for_mutation, for_mutation=for_mutation, is_global=is_global, dml_source=dml_source, ctx=ctx, ) rvar.query.path_id = path_id return rvar def wrap_set_op_query( qry: pgast.SelectStmt, *, ctx: context.CompilerContextLevel ) -> pgast.SelectStmt: if astutils.is_set_op_query(qry): rvar = rvar_for_rel(qry, ctx=ctx) nqry = pgast.SelectStmt(from_clause=[rvar]) nqry.target_list = [ pgast.ResTarget( name=col.name, val=pgast.ColumnRef( name=[rvar.alias.aliasname, col.name], ) ) for col in astutils.get_leftmost_query(qry).target_list if col.name ] pull_path_namespace(target=nqry, source=rvar, ctx=ctx) qry = nqry return qry def anti_join( lhs: pgast.SelectStmt, rhs: pgast.SelectStmt, path_id: Optional[irast.PathId], *, aspect: pgce.PathAspect=pgce.PathAspect.IDENTITY, ctx: context.CompilerContextLevel, ) -> None: """Filter elements out of the LHS that appear on the RHS""" if path_id: # grab the identity from the LHS and do an # anti-join against the RHS. 
src_ref = pathctx.get_path_var( lhs, path_id=path_id, aspect=aspect, env=ctx.env) pathctx.get_path_output( rhs, path_id=path_id, aspect=aspect, env=ctx.env) cond_expr: pgast.BaseExpr = astutils.new_binop( src_ref, rhs, 'NOT IN') else: # No path we care about. Just check existence. cond_expr = pgast.SubLink(operator="NOT EXISTS", expr=rhs) lhs.where_clause = astutils.extend_binop(lhs.where_clause, cond_expr) def range_from_queryset( set_ops: Sequence[tuple[context.OverlayOp, pgast.SelectStmt]], objname: sn.Name, *, prep_filter: Callable[ [pgast.SelectStmt, pgast.SelectStmt], None]=lambda a, b: None, path_id: Optional[irast.PathId]=None, lateral: bool=False, typeref: Optional[irast.TypeRef]=None, tag: Optional[str]=None, ctx: context.CompilerContextLevel, ) -> pgast.PathRangeVar: rvar: pgast.PathRangeVar if len(set_ops) > 1: # More than one class table, generate a UNION/EXCEPT clause. qry = set_ops[0][1] for op, rarg in set_ops[1:]: if op == context.OverlayOp.FILTER: qry = wrap_set_op_query(qry, ctx=ctx) prep_filter(qry, rarg) anti_join(qry, rarg, path_id, ctx=ctx) else: qry = pgast.SelectStmt( op=op, all=True, larg=qry, rarg=rarg, ) rvar = pgast.RangeSubselect( subquery=qry, lateral=lateral, tag=tag, alias=pgast.Alias( aliasname=ctx.env.aliases.get(objname.name), ), typeref=typeref, ) elif any( ( target.name is not None and isinstance(target.val, pgast.ColumnRef) and target.name != target.val.name[-1] ) for target in set_ops[0][1].target_list ): # A column name is being changed rvar = pgast.RangeSubselect( subquery=set_ops[0][1], lateral=lateral, tag=tag, alias=pgast.Alias( aliasname=ctx.env.aliases.get(objname.name), ), typeref=typeref, ) else: # Just one class table, so return it directly from_rvar = set_ops[0][1].from_clause[0] assert isinstance(from_rvar, pgast.PathRangeVar) from_rvar = from_rvar.replace(typeref=typeref) rvar = from_rvar return rvar def range_for_ptrref( ptrref: irast.BasePointerRef, *, dml_source: Sequence[irast.MutatingLikeStmt]=(),
for_mutation: bool=False, only_self: bool=False, path_id: Optional[irast.PathId]=None, ctx: context.CompilerContextLevel, ) -> pgast.PathRangeVar: """Return a Range subclass corresponding to a given ptr step. The return value may potentially be a UNION of all tables corresponding to a set of specialized links computed from the given `ptrref` taking source inheritance into account. """ if ptrref.union_components: component_refs = ptrref.union_components if only_self and len(component_refs) > 1: raise errors.InternalServerError( 'unexpected union link' ) elif ptrref.intersection_components: # This is a little funky, but in an intersection, the pointer # needs to appear in *all* of the tables, so we just pick any # one of them. component_refs = {next(iter((ptrref.intersection_components)))} elif ptrref.computed_link_alias: component_refs = {ptrref.computed_link_alias} else: component_refs = {ptrref} assert isinstance(ptrref.out_source.name_hint, sn.QualName) include_descendants = not ptrref.union_is_exhaustive output_cols = ('source', 'target') set_ops = [] for component_ref in component_refs: assert isinstance(component_ref, irast.PointerRef), \ "expected regular PointerRef" component_rvar = _range_for_component_ptrref( component_ref, output_cols, dml_source=dml_source, include_descendants=include_descendants, for_mutation=for_mutation, path_id=path_id, ctx=ctx, ) component_qry = pgast.SelectStmt( target_list=[ pgast.ResTarget( val=pgast.ColumnRef( name=[output_colname] ), name=output_colname ) for output_colname in output_cols ], from_clause=[component_rvar] ) if path_id: target_ref = pgast.ColumnRef( name=[component_rvar.alias.aliasname, output_cols[1]] ) pathctx.put_path_identity_var( component_qry, path_id, var=target_ref ) pathctx.put_path_source_rvar( component_qry, path_id, component_rvar ) set_ops.append((context.OverlayOp.UNION, component_qry)) return range_from_queryset( set_ops, ptrref.shortname, path_id=path_id, ctx=ctx, ) def
_range_for_component_ptrref( component_ptrref: irast.PointerRef, output_cols: Sequence[str], *, dml_source: Sequence[irast.MutatingLikeStmt], include_descendants: bool, for_mutation: bool, path_id: Optional[irast.PathId], ctx: context.CompilerContextLevel, ) -> pgast.PathRangeVar: ptrref_descendants = _get_ptrref_descendants( component_ptrref, include_descendants=include_descendants, for_mutation=for_mutation, ) if ( # When we are compiling a query for EXPLAIN, expand out pointer # references in place. See range_for_material_objtype for more details. ctx.env.is_explain # Don't use CTEs if there is no inheritance. (ie. There is only a # single ptrref) or len(ptrref_descendants) <= 1 ): descendant_selects = _selects_for_ptrref_descendants( ptrref_descendants, output_cols=output_cols, path_id=path_id, ctx=ctx, ) descendant_ops = [ (context.OverlayOp.UNION, select) for select in descendant_selects ] component_rvar = range_from_queryset( descendant_ops, component_ptrref.shortname, path_id=path_id, ctx=ctx, ) else: component_ptrref_path_id: irast.PathId = irast.PathId.from_ptrref( component_ptrref, ).ptr_path() if component_ptrref.id not in ctx.ptr_inheritance_ctes: descendant_selects = _selects_for_ptrref_descendants( ptrref_descendants, output_cols=output_cols, path_id=component_ptrref_path_id, ctx=ctx, ) inheritance_qry: pgast.SelectStmt = descendant_selects[0] for rarg in descendant_selects[1:]: inheritance_qry = pgast.SelectStmt( op='union', all=True, larg=inheritance_qry, rarg=rarg, ) # Add the path to the CTE's query. 
This allows for the proper # path mapping to occur when processing link properties inheritance_qry.path_id = component_ptrref_path_id ptr_cte = pgast.CommonTableExpr( name=ctx.env.aliases.get(f't_{component_ptrref.name}'), query=inheritance_qry, materialized=False, ) ctx.ptr_inheritance_ctes[component_ptrref.id] = ptr_cte else: ptr_cte = ctx.ptr_inheritance_ctes[component_ptrref.id] with ctx.subrel() as sctx: cte_rvar = rvar_for_rel( ptr_cte, typeref=component_ptrref.out_target, ctx=ctx, ) if path_id is not None and path_id != component_ptrref_path_id: pathctx.put_path_id_map( sctx.rel, path_id, component_ptrref_path_id, ) include_rvar( sctx.rel, cte_rvar, component_ptrref_path_id, pull_namespace=False, ctx=sctx, ) # Ensure source and target columns are output for output_colname in output_cols: selexpr = pgast.ColumnRef( name=[cte_rvar.alias.aliasname, output_colname]) sctx.rel.target_list.append( pgast.ResTarget(val=selexpr, name=output_colname) ) target_ref = sctx.rel.target_list[1].val pathctx.put_path_identity_var( sctx.rel, component_ptrref_path_id, var=target_ref ) pathctx.put_path_source_rvar( sctx.rel, component_ptrref_path_id, cte_rvar, ) component_rvar = rvar_for_rel( sctx.rel, typeref=component_ptrref.out_target, ctx=sctx ) # Add overlays at the end of each expanded inheritance. 
overlays = get_ptr_rel_overlays( component_ptrref, dml_source=dml_source, ctx=ctx) if overlays and not for_mutation: set_ops = [] component_qry = pgast.SelectStmt( target_list=[ pgast.ResTarget( val=pgast.ColumnRef( name=[output_colname] ), name=output_colname ) for output_colname in output_cols ], from_clause=[component_rvar] ) if path_id: target_ref = pgast.ColumnRef( name=[component_rvar.alias.aliasname, output_cols[1]] ) pathctx.put_path_identity_var( component_qry, path_id, var=target_ref ) pathctx.put_path_source_rvar( component_qry, path_id, component_rvar ) set_ops.append((context.OverlayOp.UNION, component_qry)) orig_ptr_info = _get_ptrref_storage_info(component_ptrref, ctx=ctx) cols = _get_ptrref_column_names(orig_ptr_info) for op, cte, cte_path_id in overlays: rvar = rvar_for_rel(cte, ctx=ctx) qry = pgast.SelectStmt( target_list=[ pgast.ResTarget( val=pgast.ColumnRef( name=[col] ) ) for col in cols ], from_clause=[rvar], ) # Set up identity var, source rvar for reasons discussed above if path_id: target_ref = pgast.ColumnRef( name=[rvar.alias.aliasname, cols[1]]) pathctx.put_path_identity_var( qry, cte_path_id, var=target_ref ) pathctx.put_path_source_rvar(qry, cte_path_id, rvar) pathctx.put_path_id_map(qry, path_id, cte_path_id) set_ops.append((op, qry)) component_rvar = range_from_queryset( set_ops, component_ptrref.shortname, prep_filter=_prep_filter, path_id=path_id, ctx=ctx) return component_rvar def _prep_filter(larg: pgast.SelectStmt, rarg: pgast.SelectStmt) -> None: # Set up the proper join on the source field and clear the target list # of the rhs of a filter overlay. # If the names don't have table refs, make them refer to the table # being joined. 
lval = larg.target_list[0].val assert isinstance(lval, pgast.ColumnRef) if len(lval.name) == 1: lval = astutils.get_column(larg.from_clause[0], lval) rval = rarg.target_list[0].val assert isinstance(rval, pgast.ColumnRef) if len(rval.name) == 1: rval = astutils.get_column(rarg.from_clause[0], rval) rarg.where_clause = astutils.join_condition(lval, rval) rarg.target_list.clear() def _get_ptrref_descendants( ptrref: irast.PointerRef, *, include_descendants: bool, for_mutation: bool, ) -> list[irast.PointerRef]: # When doing EXPLAIN, don't use CTEs. See range_for_material_objtype for # details. if ( include_descendants and not for_mutation ): include_descendants = False descendants: list[irast.PointerRef] = [] descendants.extend( cast(Iterable[irast.PointerRef], ptrref.descendants()) ) descendants.append(ptrref) assert isinstance(ptrref, irast.PointerRef) # Try to only select from actual concrete types. concrete_descendants = [ ref for ref in descendants if not ref.out_source.is_abstract ] # If there aren't any concrete types, we still need to # generate *something*, so just do the initial one. 
if concrete_descendants: return concrete_descendants else: return [ptrref] else: return [ptrref] def _selects_for_ptrref_descendants( ptrref_descendants: Sequence[irast.PointerRef], output_cols: Iterable[str], *, path_id: Optional[irast.PathId], ctx: context.CompilerContextLevel, ) -> list[pgast.SelectStmt]: selects = [] for ptrref_descendant in ptrref_descendants: ptr_info = _get_ptrref_storage_info(ptrref_descendant, ctx=ctx) cols = _get_ptrref_column_names(ptr_info) table = _table_from_ptrref( ptrref_descendant, ptr_info, ctx=ctx, ) table.query.path_id = path_id qry = pgast.SelectStmt() qry.from_clause.append(table) # Make sure all property references are pulled up properly for colname, output_colname in zip(cols, output_cols): selexpr = pgast.ColumnRef( name=[table.alias.aliasname, colname]) qry.target_list.append( pgast.ResTarget(val=selexpr, name=output_colname)) selects.append(qry) # We need the identity var for semi_join to work and # the source rvar so that linkprops can be found here. if path_id: target_ref = qry.target_list[1].val pathctx.put_path_identity_var(qry, path_id, var=target_ref) pathctx.put_path_source_rvar(qry, path_id, table) return selects def _table_from_ptrref( ptrref: irast.PointerRef, ptr_info: pg_types.PointerStorageInfo, *, ctx: context.CompilerContextLevel, ) -> pgast.RelRangeVar: """Return a Table corresponding to a given Link.""" aspect = 'table' table_schema_name, table_name = common.update_aspect( ptr_info.table_name, aspect ) typeref = ptrref.out_source if ptrref else None relation = pgast.Relation( schemaname=table_schema_name, name=table_name, type_or_ptr_ref=ptrref, ) # Pseudo pointers (tuple and type intersection) have no schema id. 
sobj_id = ptrref.id if isinstance(ptrref, irast.PointerRef) else None rvar = pgast.RelRangeVar( schema_object_id=sobj_id, typeref=typeref, relation=relation, alias=pgast.Alias( aliasname=ctx.env.aliases.get(ptrref.shortname.name) ) ) return rvar def _get_ptrref_storage_info( ptrref: irast.PointerRef, *, ctx: context.CompilerContextLevel ) -> pg_types.PointerStorageInfo: # Most references to inline links are dispatched to a separate # code path (_new_inline_pointer_rvar) by new_pointer_rvar, # but when we have union pointers, some might be inline. We # always use the link table if it exists (because this range # needs to contain any link properties, for one reason.) ptr_info = pg_types.get_ptrref_storage_info( ptrref, resolve_type=False, link_bias=True, versioned=ctx.env.versioned_stdlib, ) if not ptr_info: ptr_info = pg_types.get_ptrref_storage_info( ptrref, resolve_type=False, link_bias=False, versioned=ctx.env.versioned_stdlib, ) return ptr_info def _get_ptrref_column_names( ptr_info: pg_types.PointerStorageInfo ) -> list[str]: return [ 'source' if ptr_info.table_type == 'link' else 'id', ptr_info.column_name, ] def range_for_pointer( ir_set: irast.SetE[irast.Pointer], *, dml_source: Sequence[irast.MutatingLikeStmt]=(), ctx: context.CompilerContextLevel, ) -> pgast.PathRangeVar: pointer = ir_set.expr path_id = ir_set.path_id.ptr_path() external_rvar = ctx.env.external_rvars.get( (path_id, pgce.PathAspect.SOURCE) ) if external_rvar is not None: return external_rvar ptrref = pointer.ptrref if ptrref.material_ptr is not None: ptrref = ptrref.material_ptr return range_for_ptrref( ptrref, dml_source=dml_source, path_id=path_id, ctx=ctx) def rvar_for_rel( rel: pgast.BaseRelation | pgast.CommonTableExpr, *, alias: Optional[str] = None, typeref: Optional[irast.TypeRef] = None, lateral: bool = False, colnames: Optional[list[str]] = None, ctx: Optional[context.CompilerContextLevel] = None, env: Optional[context.Environment] = None, ) -> pgast.PathRangeVar: if ctx: env = 
ctx.env assert env rvar: pgast.PathRangeVar if colnames is None: colnames = [] if isinstance(rel, pgast.Query): alias = alias or env.aliases.get(rel.name or 'q') rvar = pgast.RangeSubselect( subquery=rel, alias=pgast.Alias(aliasname=alias, colnames=colnames), lateral=lateral, typeref=typeref, ) else: alias = alias or env.aliases.get(rel.name or '') rvar = pgast.RelRangeVar( relation=rel, alias=pgast.Alias(aliasname=alias, colnames=colnames), typeref=typeref, ) return rvar def _add_type_rel_overlay( typeid: uuid.UUID, op: context.OverlayOp, rel: pgast.BaseRelation | pgast.CommonTableExpr, *, dml_stmts: Iterable[irast.MutatingLikeStmt] = (), path_id: irast.PathId, ctx: context.CompilerContextLevel ) -> None: entry = (op, rel, path_id) dml_stmts2 = dml_stmts if dml_stmts else (None,) # If there is a "global" overlay, and there is none for the # current statements, use it as the base. This is important for # not losing track of the global environment in triggers. root = ctx.rel_overlays.type.get(None, immu.Map()) for dml_stmt in dml_stmts2: ds_overlays = ctx.rel_overlays.type.get(dml_stmt, root) overlays = ds_overlays.get(typeid, ()) if entry not in overlays: ds_overlays = ds_overlays.set(typeid, overlays + (entry,)) ctx.rel_overlays.type = ( ctx.rel_overlays.type.set(dml_stmt, ds_overlays)) def add_type_rel_overlay( typeref: irast.TypeRef, op: context.OverlayOp, rel: pgast.BaseRelation | pgast.CommonTableExpr, *, stop_ref: Optional[irast.TypeRef]=None, dml_stmts: Iterable[irast.MutatingLikeStmt] = (), path_id: irast.PathId, ctx: context.CompilerContextLevel ) -> None: typeref = typeref.real_material_type objs = [typeref] if typeref.ancestors: objs.extend(typeref.ancestors) for obj in objs: if stop_ref and ( obj == stop_ref or (stop_ref.ancestors and obj in stop_ref.ancestors) ): continue _add_type_rel_overlay( obj.id, op, rel, dml_stmts=dml_stmts, path_id=path_id, ctx=ctx) def get_type_rel_overlays( typeref: irast.TypeRef, *, dml_source: 
Sequence[irast.MutatingLikeStmt]=(), ctx: context.CompilerContextLevel, ) -> tuple[context.OverlayEntry, ...]: if typeref.material_type is not None: typeref = typeref.material_type xdml_source = dml_source or (None,) return tuple( entry for src in xdml_source if src in ctx.rel_overlays.type for entry in ctx.rel_overlays.type[src].get(typeref.id, ()) ) def reuse_type_rel_overlays( *, dml_stmts: Iterable[irast.MutatingLikeStmt] = (), dml_source: irast.MutatingLikeStmt, ctx: context.CompilerContextLevel, ) -> None: """Update type rel overlays when a DML statement is reused. When a WITH bound DML is used, we need to add it (and all of its nested overlays) as an overlay for all the enclosing DML statements. """ ref_overlays = ctx.rel_overlays.type.get(dml_source, immu.Map()) for tid, overlays in ref_overlays.items(): for op, rel, path_id in overlays: _add_type_rel_overlay( tid, op, rel, dml_stmts=dml_stmts, path_id=path_id, ctx=ctx ) ptr_overlays = ctx.rel_overlays.ptr.get(dml_source, immu.Map()) for (obj, ptr), poverlays in ptr_overlays.items(): for op, rel, path_id in poverlays: _add_ptr_rel_overlay( obj, ptr, op, rel, path_id=path_id, dml_stmts=dml_stmts, ctx=ctx ) def _add_ptr_rel_overlay( typeid: uuid.UUID, ptrref_name: str, op: context.OverlayOp, rel: pgast.BaseRelation | pgast.CommonTableExpr, *, dml_stmts: Iterable[irast.MutatingLikeStmt] = (), path_id: irast.PathId, ctx: context.CompilerContextLevel ) -> None: entry = (op, rel, path_id) dml_stmts2 = dml_stmts if dml_stmts else (None,) key = typeid, ptrref_name # If there is a "global" overlay, and there is none for the # current statements, use it as the base. This is important for # not losing track of the global environment in triggers. 
root = ctx.rel_overlays.ptr.get(None, immu.Map()) for dml_stmt in dml_stmts2: ds_overlays = ctx.rel_overlays.ptr.get(dml_stmt, root) overlays = ds_overlays.get(key, ()) if entry not in overlays: ds_overlays = ds_overlays.set(key, overlays + (entry,)) ctx.rel_overlays.ptr = ( ctx.rel_overlays.ptr.set(dml_stmt, ds_overlays)) def add_ptr_rel_overlay( ptrref: irast.PointerRef, op: context.OverlayOp, rel: pgast.BaseRelation | pgast.CommonTableExpr, *, dml_stmts: Iterable[irast.MutatingLikeStmt] = (), path_id: irast.PathId, ctx: context.CompilerContextLevel ) -> None: typeref = ptrref.out_source.real_material_type objs = [typeref] if typeref.ancestors: objs.extend(typeref.ancestors) for obj in objs: _add_ptr_rel_overlay( obj.id, ptrref.shortname.name, op, rel, path_id=path_id, dml_stmts=dml_stmts, ctx=ctx) def get_ptr_rel_overlays( ptrref: irast.PointerRef, *, dml_source: Sequence[irast.MutatingLikeStmt]=(), ctx: context.CompilerContextLevel, ) -> tuple[context.OverlayEntry, ...]: typeref = ptrref.out_source.real_material_type key = typeref.id, ptrref.shortname.name xdml_source = dml_source or (None,) return tuple( entry for src in xdml_source if src in ctx.rel_overlays.ptr for entry in ctx.rel_overlays.ptr[src].get(key, ()) ) ================================================ FILE: edb/pgsql/compiler/relgen.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # """Compiler functions to generate SQL relations for IR sets.""" from __future__ import annotations from typing import ( Callable, Optional, Protocol, Iterable, Collection, NamedTuple, cast, ) import dataclasses import contextlib import functools from edb import errors from edb.edgeql import qltypes from edb.schema import objects as s_obj from edb.schema import name as sn from edb.edgeql import ast as qlast from edb.ir import ast as irast from edb.ir import typeutils as irtyputils from edb.ir import utils as irutils from edb.pgsql import ast as pgast from edb.pgsql import common from edb.pgsql import types as pg_types from edb.common.typeutils import not_none from . import astutils from . import clauses from . import context from . import dispatch from . import dml from . import enums as pgce from . import expr as exprcomp from . import output from . import pathctx from . import relctx @dataclasses.dataclass(repr=False, eq=False) class SetRVar: rvar: pgast.PathRangeVar path_id: irast.PathId aspects: Iterable[pgce.PathAspect] = dataclasses.field( default=(pgce.PathAspect.VALUE,) ) @dataclasses.dataclass(kw_only=True, repr=False, eq=False) class SetRVars: main: SetRVar new: list[SetRVar] def new_simple_set_rvar( ir_set: irast.Set, rvar: pgast.PathRangeVar, aspects: Iterable[pgce.PathAspect], ) -> SetRVars: srvar = SetRVar(rvar=rvar, path_id=ir_set.path_id, aspects=aspects) return SetRVars(main=srvar, new=[srvar]) def new_source_set_rvar( ir_set: irast.Set, rvar: pgast.PathRangeVar, ) -> SetRVars: aspects = [pgce.PathAspect.VALUE] if ir_set.path_id.is_objtype_path(): aspects.append(pgce.PathAspect.SOURCE) return new_simple_set_rvar(ir_set, rvar, aspects) def new_stmt_set_rvar( ir_set: irast.Set, stmt: pgast.Query, *, aspects: Optional[Iterable[pgce.PathAspect]]=None, ctx: context.CompilerContextLevel, ) -> SetRVars: rvar = relctx.new_rel_rvar(ir_set, stmt, ctx=ctx) 
if aspects is not None: aspects = tuple(aspects) else: aspects = pathctx.list_path_aspects(stmt, ir_set.path_id) return new_simple_set_rvar(ir_set, rvar, aspects=aspects) class OptionalRel(NamedTuple): scope_rel: pgast.SelectStmt target_rel: pgast.SelectStmt emptyrel: pgast.SelectStmt unionrel: pgast.SelectStmt wrapper: pgast.SelectStmt container: pgast.SelectStmt marker: str def _lookup_set_rvar( ir_set: irast.Set, *, scope_stmt: Optional[pgast.SelectStmt]=None, ctx: context.CompilerContextLevel) -> Optional[pgast.PathRangeVar]: path_id = ir_set.path_id rvar = relctx.find_rvar(ctx.rel, source_stmt=scope_stmt, path_id=path_id, ctx=ctx) if rvar is not None: return rvar # We couldn't find a regular rvar, but maybe we can find a packed one? packed_rvar = relctx.find_rvar(ctx.rel, flavor='packed', source_stmt=scope_stmt, path_id=path_id, ctx=ctx) if packed_rvar is not None: rvar = relctx.unpack_rvar( scope_stmt or ctx.rel, path_id, packed_rvar=packed_rvar, ctx=ctx) return rvar return None def get_set_rvar( ir_set: irast.Set, *, ctx: context.CompilerContextLevel) -> pgast.PathRangeVar: """Return a PathRangeVar for a given IR Set. Basically all of compilation comes through here for each set. @param ir_set: IR Set node. """ path_id = ir_set.path_id scope_stmt = relctx.maybe_get_scope_stmt(path_id, ctx=ctx) if rvar := _lookup_set_rvar(ir_set, scope_stmt=scope_stmt, ctx=ctx): return rvar if ctx.toplevel_stmt is context.NO_STMT: # Top level query return _process_toplevel_query(ir_set, ctx=ctx) with contextlib.ExitStack() as cstack: # If there was a scope_stmt registered for our path, we compile # as a subrel of that scope_stmt. Otherwise we use whatever the # current rel was. if scope_stmt is not None: newctx = cstack.enter_context(ctx.new()) newctx.rel = scope_stmt else: newctx = ctx scope_stmt = newctx.rel subctx = cstack.enter_context(newctx.subrel()) # *stmt* here is a tentative container for the relation generated # by processing the *ir_set*. 
However, the actual compilation # is free to return something else instead of a range var over # stmt. stmt = subctx.rel stmt.name = ctx.env.aliases.get(get_set_rel_alias(ir_set, ctx=ctx)) # If ir.Set compilation needs to produce a subquery, # make sure it uses the current subrel. This makes it # possible to set up the path scope here and not worry # about it later. subctx.pending_query = stmt is_empty_set = isinstance(ir_set.expr, irast.EmptySet) path_scope = relctx.get_scope(ir_set, ctx=subctx) new_scope = path_scope or subctx.scope_tree is_optional = ( subctx.scope_tree.is_optional(path_id) or new_scope.is_optional(path_id) or path_id in subctx.force_optional ) and not can_omit_optional_wrapper(ir_set, new_scope, ctx=ctx) optional_wrapping = is_optional and not is_empty_set if optional_wrapping: stmt, optrel = prepare_optional_rel( ir_set=ir_set, stmt=stmt, ctx=subctx) subctx.pending_query = subctx.rel = stmt # XXX: This is pretty dodgy, because it updates the path_scope # *before* we call new_child() on it. Removing it only breaks two # tests of lprops on backlinks. if path_scope and path_scope.is_visible(path_id): subctx.path_scope[path_id] = scope_stmt # If this set has a scope in the scope tree associated with it, # register paths in that scope to be compiled with this stmt # as their scope_stmt. if path_scope: relctx.update_scope(ir_set, stmt, ctx=subctx) # Actually compile the set rvars = _get_expr_set_rvar(ir_set.expr, ir_set, ctx=subctx) relctx.update_scope_masks(ir_set, rvars.main.rvar, ctx=subctx) if ctx.env.is_explain: for srvar in rvars.new: if not srvar.rvar.ir_origins: srvar.rvar.ir_origins = [] srvar.rvar.ir_origins.append(ir_set) if optional_wrapping: rvars = finalize_optional_rel(ir_set, optrel=optrel, rvars=rvars, ctx=subctx) relctx.update_scope_masks(ir_set, rvars.main.rvar, ctx=subctx) elif not is_optional and is_empty_set: # In most cases it is totally fine for us to represent an # empty set as an empty relation.
# (except when it needs to be fed to an optional argument) null_query = rvars.main.rvar.query assert isinstance( null_query, (pgast.SelectStmt, pgast.NullRelation)) null_query.where_clause = pgast.BooleanConstant(val=False) result_rvar = _include_rvars(rvars, scope_stmt=scope_stmt, ctx=subctx) for aspect in rvars.main.aspects: pathctx.put_path_rvar_if_not_exists( ctx.rel, path_id, result_rvar, aspect=aspect, ) return result_rvar def _include_rvars( rvars: SetRVars, *, scope_stmt: pgast.SelectStmt, ctx: context.CompilerContextLevel, ) -> pgast.PathRangeVar: for set_rvar in rvars.new: # overwrite_path_rvar is needed because we want # the outermost Set with the given path_id to # represent the path. Nested Sets with the # same path_id but different expression are # possible when there is a computable pointer # that refers to itself in its expression. relctx.include_specific_rvar( scope_stmt, set_rvar.rvar, path_id=set_rvar.path_id, overwrite_path_rvar=True, aspects=set_rvar.aspects, ctx=ctx, ) return rvars.main.rvar def _process_toplevel_query( ir_set: irast.Set, *, ctx: context.CompilerContextLevel, ) -> pgast.PathRangeVar: # TODO: Can we get rid of the need for this special handling of # the toplevel? What is it good for anyway? # I think it might just be suppressing what would be one extra # level of select wrapping? relctx.init_toplevel_query(ir_set, ctx=ctx) rvars = _get_expr_set_rvar(ir_set.expr, ir_set, ctx=ctx) result_rvar = rvars.main.rvar # Usually the result_rvar is wrapping ctx.rel, which is the final # top-level query. (And thus the result_rvar is actually bogus and # will never be used!) But if not, we need to include it, or we'll # have an empty query. 
if result_rvar.query is not ctx.rel: _include_rvars(rvars, scope_stmt=ctx.rel, ctx=ctx) return result_rvar class _SpecialCaseFunc(Protocol): def __call__( self, ir_set: irast.SetE[irast.Call], *, ctx: context.CompilerContextLevel, ) -> SetRVars: pass class _FunctionSpecialCase(NamedTuple): func: _SpecialCaseFunc only_as_fallback: bool _SPECIAL_FUNCTIONS: dict[str, _FunctionSpecialCase] = {} def _special_case(name: str, only_as_fallback: bool = False) -> Callable[ [_SpecialCaseFunc], _SpecialCaseFunc ]: def func(f: _SpecialCaseFunc) -> _SpecialCaseFunc: _SPECIAL_FUNCTIONS[name] = _FunctionSpecialCase(f, only_as_fallback) return f return func class _SimpleSpecialCaseFunc(Protocol): def __call__( self, expr: irast.FunctionCall, *, ctx: context.CompilerContextLevel ) -> pgast.BaseExpr: pass _SIMPLE_SPECIAL_FUNCTIONS: dict[str, _SimpleSpecialCaseFunc] = {} def simple_special_case( name: str, ) -> Callable[[_SimpleSpecialCaseFunc], _SimpleSpecialCaseFunc]: def func(f: _SimpleSpecialCaseFunc) -> _SimpleSpecialCaseFunc: _SIMPLE_SPECIAL_FUNCTIONS[name] = f return f return func # Dispatcher for _get_set_rvar implementations for different expressions. # The implementations just take a SetE[T] for some T, so register_get_rvar # needs to do some wrapping. 
@functools.singledispatch def _get_expr_set_rvar( expr: irast.Expr, ir: irast.Set, *, ctx: context.CompilerContextLevel, ) -> SetRVars: raise NotImplementedError(f'no relgen handler for {ir.__class__}') class _GetExprRvarFunc[T_expr: irast.Expr](Protocol): # noqa: UP046 def __call__( self, __ir_set: irast.SetE[T_expr], *, ctx: context.CompilerContextLevel ) -> SetRVars: pass def register_get_rvar[T_expr: irast.Expr]( typ: type[T_expr], ) -> Callable[[_GetExprRvarFunc[T_expr]], _GetExprRvarFunc[T_expr]]: def func(f: _GetExprRvarFunc[T_expr]) -> _GetExprRvarFunc[T_expr]: _get_expr_set_rvar.register(typ)( lambda _, ir, *, ctx: f(ir, ctx=ctx)) return f return func def _get_source_rvar( ir_set: irast.Set, scope_stmt: pgast.SelectStmt, *, ctx: context.CompilerContextLevel, ) -> pgast.PathRangeVar: is_optional = ( ctx.scope_tree.is_optional(ir_set.path_id) or ir_set.path_id in ctx.force_optional ) if not is_optional: rvar = relctx.new_root_rvar(ir_set, lateral=True, ctx=ctx) relctx.include_rvar( scope_stmt, rvar, path_id=ir_set.path_id, ctx=ctx ) else: # If the path is optional in the context we are in, then we # need to put optional wrapping around the join with the base table. with ctx.subrel() as subctx: stmt, optrel = prepare_optional_rel( ir_set=ir_set, stmt=subctx.rel, ctx=subctx) subctx.pending_query = subctx.rel = stmt rvar = relctx.new_root_rvar(ir_set, lateral=True, ctx=subctx) rvars = new_source_set_rvar(ir_set, rvar) rvars = finalize_optional_rel( ir_set, optrel=optrel, rvars=rvars, ctx=subctx) rvar = _include_rvars(rvars, scope_stmt=scope_stmt, ctx=ctx) return rvar def ensure_source_rvar( ir_set: irast.Set, stmt: pgast.Query, *, ctx: context.CompilerContextLevel, ) -> pgast.PathRangeVar: """Make sure that a source aspect is available for ir_set. If no aspect is available, compile it. If value/identity is available but source is not, select from the base relation and join it in. 
""" rvar = relctx.maybe_get_path_rvar( stmt, ir_set.path_id, aspect=pgce.PathAspect.SOURCE, ctx=ctx) if rvar is None: get_set_rvar(ir_set, ctx=ctx) rvar = relctx.maybe_get_path_rvar( stmt, ir_set.path_id, aspect=pgce.PathAspect.SOURCE, ctx=ctx) if rvar is None: scope_stmt = relctx.maybe_get_scope_stmt(ir_set.path_id, ctx=ctx) if scope_stmt is None: scope_stmt = ctx.rel rvar = relctx.maybe_get_path_rvar( scope_stmt, ir_set.path_id, aspect=pgce.PathAspect.SOURCE, ctx=ctx, ) if rvar is None: if irtyputils.is_free_object(ir_set.path_id.target): # Free objects don't have a real source, and # generating a new fake source doesn't work because # the ids don't match, so instead we call the existing # value rvar a source. rvar = relctx.get_path_rvar( scope_stmt, ir_set.path_id, aspect=pgce.PathAspect.VALUE, ctx=ctx, ) else: rvar = _get_source_rvar(ir_set, scope_stmt, ctx=ctx) pathctx.put_path_rvar( stmt, ir_set.path_id, rvar, aspect=pgce.PathAspect.SOURCE, ) return rvar def set_as_subquery( ir_set: irast.Set, *, as_value: bool=False, explicit_cast: Optional[tuple[str, ...]] = None, ctx: context.CompilerContextLevel) -> pgast.Query: # Compile *ir_set* into a subquery as follows: # ( # SELECT .v # FROM # ) with ctx.subrel() as subctx: wrapper = subctx.rel wrapper.name = ctx.env.aliases.get('set_as_subquery') dispatch.visit(ir_set, ctx=subctx) if as_value: if output.in_serialization_ctx(ctx): pathctx.get_path_serialized_output( rel=wrapper, path_id=ir_set.path_id, env=ctx.env) else: pathctx.get_path_value_output( rel=wrapper, path_id=ir_set.path_id, env=ctx.env) var = pathctx.get_path_value_var( rel=wrapper, path_id=ir_set.path_id, env=ctx.env) value = output.output_as_value(var, env=ctx.env) if explicit_cast is not None: value = pgast.TypeCast( arg=value, type_name=pgast.TypeName(name=explicit_cast), ) wrapper.target_list = [ pgast.ResTarget(val=value) ] else: pathctx.get_path_value_output( rel=wrapper, path_id=ir_set.path_id, env=ctx.env) return wrapper def 
can_omit_optional_wrapper( ir_set: irast.Set, new_scope: irast.ScopeTreeNode, *, ctx: context.CompilerContextLevel) -> bool: """Determine whether it is safe to omit the optional wrapper. Doing so is safe when the expression is guaranteed to result in a NULL and not an empty set. The main such case implemented is a path `foo.bar` where foo is visible and bar is a single non-computed property, which we know will be stored as NULL in the database. We also handle trivial SELECTs wrapping such an expression. """ if ir_set.expr and irutils.is_trivial_select(ir_set.expr): return can_omit_optional_wrapper( ir_set.expr.result, relctx.get_scope(ir_set.expr.result, ctx=ctx) or new_scope, ctx=ctx, ) if isinstance(ir_set.expr, irast.QueryParameter): return True # Our base json casts should all preserve nullity (instead of # turning it into an empty set), so allow passing through those # cases. This is mainly an optimization for passing globals to # functions, where we need to convert a bunch of optional params # to json, and for casting out of json there and in schema updates.
if ( isinstance(ir_set.expr, irast.TypeCast) and (( irtyputils.is_scalar(ir_set.expr.expr.typeref) and irtyputils.is_json(ir_set.expr.to_type) ) or ( irtyputils.is_json(ir_set.expr.expr.typeref) and irtyputils.is_scalar(ir_set.expr.to_type) )) ): return can_omit_optional_wrapper( ir_set.expr.expr, relctx.get_scope(ir_set.expr.expr, ctx=ctx) or new_scope, ctx=ctx, ) if isinstance(ir_set.expr, irast.TupleIndirectionPointer): return can_omit_optional_wrapper(ir_set.expr.source, new_scope, ctx=ctx) return bool( isinstance(ir_set.expr, irast.Pointer) and (rptr := ir_set.expr) and rptr.expr is None and not ir_set.path_id.is_objtype_path() and not ir_set.path_id.is_type_intersection_path() and new_scope.is_visible(rptr.source.path_id) and not rptr.is_inbound and rptr.ptrref.out_cardinality.is_single() and not rptr.ptrref.is_computable ) def prepare_optional_rel( *, ir_set: irast.Set, stmt: pgast.SelectStmt, ctx: context.CompilerContextLevel) \ -> tuple[pgast.SelectStmt, OptionalRel]: # For OPTIONAL sets we compute a UNION of both sides and annotate # each side with a marker. We then select only rows that match # the marker of the first row: # # SELECT # q.* # FROM # (SELECT # marker = first_value(marker) OVER () AS marker, # ... 
# FROM # (SELECT 1 AS marker, * FROM left # UNION ALL # SELECT 2 AS marker, * FROM right) AS u # ) AS q # WHERE marker with ctx.new() as subctx: subctx.rel = stmt with subctx.subrel() as wrapctx: wrapper = wrapctx.rel with wrapctx.subrel() as unionctx: with unionctx.subrel() as scopectx: scope_rel = scopectx.rel with scopectx.subrel() as targetctx: target_rel = targetctx.rel with unionctx.subrel() as scopectx: emptyrel = scopectx.rel empty_ir = irast.Set( path_id=ir_set.path_id, typeref=ir_set.typeref, expr=irast.EmptySet(typeref=ir_set.typeref), ) emptyrvar = relctx.new_empty_rvar( cast('irast.SetE[irast.EmptySet]', empty_ir), ctx=scopectx) relctx.include_rvar( emptyrel, emptyrvar, path_id=ir_set.path_id, ctx=scopectx) marker = unionctx.env.aliases.get('m') scope_rel.target_list.insert( 0, pgast.ResTarget(val=pgast.NumericConstant(val='1'), name=marker)) emptyrel.target_list.insert( 0, pgast.ResTarget(val=pgast.NumericConstant(val='2'), name=marker)) unionqry = unionctx.rel unionqry.op = 'UNION' unionqry.all = True unionqry.larg = scope_rel unionqry.rarg = emptyrel lagged_marker = pgast.FuncCall( name=('first_value',), args=[pgast.ColumnRef(name=[marker])], over=pgast.WindowDef() ) marker_ok = astutils.new_binop( pgast.ColumnRef(name=[marker]), lagged_marker, op='=', ) wrapper.target_list.append( pgast.ResTarget( name=marker, val=marker_ok ) ) return ( target_rel, OptionalRel(scope_rel=scope_rel, target_rel=target_rel, emptyrel=emptyrel, unionrel=unionqry, wrapper=wrapper, container=stmt, marker=marker) ) def finalize_optional_rel( ir_set: irast.Set, optrel: OptionalRel, rvars: SetRVars, ctx: context.CompilerContextLevel) -> SetRVars: with ctx.new() as subctx: subctx.rel = setrel = optrel.scope_rel for set_rvar in rvars.new: relctx.include_specific_rvar( setrel, set_rvar.rvar, path_id=set_rvar.path_id, aspects=set_rvar.aspects, ctx=subctx) for aspect in rvars.main.aspects: pathctx.put_path_rvar_if_not_exists( setrel, ir_set.path_id, rvars.main.rvar, aspect=aspect 
) lvar = pathctx.get_path_value_var( setrel, path_id=ir_set.path_id, env=subctx.env) if lvar.nullable: # The left var is still nullable, which may be the # case for non-required singleton properties. # Filter out NULLs. setrel.where_clause = astutils.extend_binop( setrel.where_clause, pgast.NullTest( arg=lvar, negated=True ) ) unionrel = optrel.unionrel union_rvar = relctx.rvar_for_rel(unionrel, lateral=True, ctx=ctx) with ctx.new() as subctx: subctx.rel = wrapper = optrel.wrapper relctx.include_rvar(wrapper, union_rvar, ir_set.path_id, ctx=subctx) with ctx.new() as subctx: subctx.rel = stmt = optrel.container wrapper_rvar = relctx.rvar_for_rel(wrapper, lateral=True, ctx=subctx) relctx.include_rvar(stmt, wrapper_rvar, ir_set.path_id, ctx=subctx) stmt.where_clause = astutils.extend_binop( stmt.where_clause, astutils.get_column(wrapper_rvar, optrel.marker, nullable=False)) stmt.nullable = True sub_rvar = SetRVar(rvar=relctx.new_rel_rvar(ir_set, stmt, ctx=ctx), path_id=ir_set.path_id, aspects=rvars.main.aspects) return SetRVars(main=sub_rvar, new=[sub_rvar]) def get_set_rel_alias(ir_set: irast.Set, *, ctx: context.CompilerContextLevel) -> str: dname = ir_set.path_id.target_name_hint.name if ( isinstance(ir_set.expr, irast.Pointer) and ir_set.expr.source.typeref is not None ): alias_hint = '{}_{}'.format( dname, ir_set.expr.ptrref.shortname.name ) else: alias_hint = dname.replace('~', '-') return alias_hint # N.B: registered for get_rvar for TypeRoot below def process_set_as_root( ir_set: irast.Set, *, ctx: context.CompilerContextLevel ) -> SetRVars: # TODO(ir): Represent these as something other than TypeRoot? 
if ir_set.path_id in ctx.external_rels: return process_external_rel(ir_set, ctx=ctx) assert not ir_set.is_visible_binding_ref, ( f"Can't compile ref to visible binding root {ir_set.path_id}" ) rvar = relctx.new_root_rvar(ir_set, ctx=ctx) return new_source_set_rvar(ir_set, rvar) register_get_rvar(irast.TypeRoot)(process_set_as_root) @register_get_rvar(irast.VisibleBindingExpr) def process_set_as_visible_binding( ir_set: irast.SetE[irast.VisibleBindingExpr], *, ctx: context.CompilerContextLevel ) -> SetRVars: raise AssertionError( f"Can't compile ref to visible binding {ir_set.path_id}" ) @register_get_rvar(irast.InlinedParameterExpr) def process_set_as_inlined_parameter( ir_set: irast.SetE[irast.InlinedParameterExpr], *, ctx: context.CompilerContextLevel ) -> SetRVars: raise AssertionError( f"Can't compile ref to inline parameter {ir_set.path_id}" ) @register_get_rvar(irast.EmptySet) def process_set_as_empty( ir_set: irast.SetE[irast.EmptySet], *, ctx: context.CompilerContextLevel ) -> SetRVars: rvar = relctx.new_empty_rvar(ir_set, ctx=ctx) return new_source_set_rvar(ir_set, rvar) def process_external_rel( ir_set: irast.Set, *, ctx: context.CompilerContextLevel ) -> SetRVars: rel, aspects = ctx.external_rels[ir_set.path_id] for a in aspects: assert isinstance(a, str) rvar = relctx.rvar_for_rel(rel, ctx=ctx) return new_simple_set_rvar(ir_set, rvar, aspects) def process_set_as_link_property_ref( ir_set: irast.SetE[irast.Pointer], *, ctx: context.CompilerContextLevel ) -> SetRVars: rptr = ir_set.expr ir_source = rptr.source rvars = [] lpropref = rptr.ptrref ptr_info = pg_types.get_ptrref_storage_info( lpropref, resolve_type=False, link_bias=False) if (ptr_info.table_type == 'ObjectType' or str(lpropref.std_parent_name) == 'std::target'): # This is a singleton link property stored in source rel, # e.g. 
@target src_rvar = get_set_rvar(ir_source, ctx=ctx) val = pathctx.get_rvar_path_var( src_rvar, ir_source.path_id, aspect=pgce.PathAspect.VALUE, env=ctx.env, ) pathctx.put_rvar_path_output( src_rvar, ir_set.path_id, aspect=pgce.PathAspect.VALUE, var=val ) return SetRVars( main=SetRVar(rvar=src_rvar, path_id=ir_set.path_id), new=[]) with ctx.new() as newctx: link_path_id = ir_set.path_id.src_path() assert link_path_id is not None rptr_specialization: Optional[set[irast.PointerRef]] = None if link_path_id.is_type_intersection_path(): rptr_specialization = set() link_prefix, ind_ptrs = ( irutils.collapse_type_intersection(ir_source)) for ind_ptr in ind_ptrs: rptr_specialization.update(ind_ptr.ptrref.rptr_specialization) else: link_prefix = ir_source source_scope_stmt = relctx.maybe_get_scope_stmt( ir_source.path_id, ctx=ctx ) or ctx.rel link_rvar = pathctx.maybe_get_path_rvar( source_scope_stmt, link_path_id, aspect=pgce.PathAspect.SOURCE ) if link_rvar is None: src_rvar = get_set_rvar(ir_source, ctx=newctx) assert irutils.is_set_instance(link_prefix, irast.Pointer), ( f'projecting lprop on {link_prefix.expr}') link_rvar = relctx.new_pointer_rvar( link_prefix, src_rvar=src_rvar, link_bias=True, ctx=newctx) # Make sure the link rvar understands the path_id we are using. # (FIXME: Would it be better to pass this in to new_pointer_rvar?) pathctx.put_path_bond(link_rvar.query, link_path_id.tgt_path()) var = pathctx.get_rvar_path_identity_var( link_rvar, link_prefix.path_id, env=ctx.env) pathctx.put_rvar_path_output( link_rvar, link_path_id.tgt_path(), pgce.PathAspect.IDENTITY, var, ) if astutils.is_set_op_query(link_rvar.query): # If we have an rptr_specialization, then this is a link # property reference to a link union narrowed by a type # intersection. We already know which union components # match the indirection expression, and can route the link # property references to correct UNION subqueries. 
ptr_ids = ( {spec.id for spec in rptr_specialization} if rptr_specialization is not None else None ) if ptr_ids and rptr_specialization: ptr_ids.update( x.id for spec in rptr_specialization for x in spec.descendants() if isinstance(x, irast.PointerRef) ) for subquery in astutils.each_query_in_set(link_rvar.query): if isinstance(subquery, pgast.SelectStmt): rvar = subquery.from_clause[0] assert isinstance(rvar, pgast.PathRangeVar) if ptr_ids is None or rvar.schema_object_id in ptr_ids: pathctx.put_path_source_rvar( subquery, link_path_id, rvar ) continue # Spare get_path_var() from attempting to rebalance # the UNION by recording an explicit NULL as the # link property var. pathctx.put_path_value_var( subquery, ir_set.path_id, pgast.TypeCast( arg=pgast.NullConstant(), type_name=pgast.TypeName( name=pg_types.pg_type_from_ir_typeref( ir_set.typeref), ), ), ) elif isinstance(link_rvar.query, pgast.SelectStmt): # When processing link properties into a CTE, map the current link # path id into the form used by the CTE. for from_rvar in link_rvar.query.from_clause: if ( isinstance(from_rvar, pgast.RelRangeVar) and isinstance(from_rvar.relation, pgast.CommonTableExpr) and from_rvar.relation.query.path_id is not None ): pathctx.put_path_id_map( link_rvar.query, link_path_id, from_rvar.relation.query.path_id ) rvars.append(SetRVar( link_rvar, link_path_id, aspects=[pgce.PathAspect.VALUE, pgce.PathAspect.SOURCE], )) return SetRVars(main=SetRVar(link_rvar, ir_set.path_id), new=rvars) @register_get_rvar(irast.TypeIntersectionPointer) def process_set_as_path_type_intersection( ir_set: irast.SetE[irast.TypeIntersectionPointer], *, ctx: context.CompilerContextLevel, ) -> SetRVars: rptr = ir_set.expr ir_source = rptr.source source_is_visible = ctx.scope_tree.is_visible(ir_source.path_id) stmt = ctx.rel assert not rptr.expr, 'type intersection pointer with expr??'
if irtyputils.is_empty_typeref(ir_set.typeref): # If the typeref was a type expression which resolves to no actual # types, just return an empty set. empty_ir = irast.Set( path_id=ir_set.path_id, typeref=ir_set.typeref, expr=irast.EmptySet(typeref=ir_set.typeref), ) source_rvar = relctx.new_empty_rvar( cast('irast.SetE[irast.EmptySet]', empty_ir), ctx=ctx) relctx.include_rvar(stmt, source_rvar, ir_set.path_id, ctx=ctx) elif (not source_is_visible and isinstance(ir_source.expr, irast.Pointer) and not ir_source.path_id.is_type_intersection_path() and not ir_source.expr.expr and ( rptr.ptrref.is_subtype or pg_types.get_ptrref_storage_info( ir_source.expr.ptrref).table_type != 'ObjectType' )): # Otherwise, if the source link path is not visible, # and this is a subtype intersection, or the pointer is not inline, # we have an opportunity to optimize the target join by # directly replacing the target type. with ctx.new() as subctx: subctx.intersection_narrowing = ( subctx.intersection_narrowing.copy()) subctx.intersection_narrowing[ir_source] = ir_set source_rvar = get_set_rvar(ir_source, ctx=subctx) pathctx.put_path_id_map(stmt, ir_set.path_id, ir_source.path_id) relctx.include_rvar(stmt, source_rvar, ir_set.path_id, ctx=ctx) else: source_rvar = get_set_rvar(ir_source, ctx=ctx) poly_rvar = relctx.range_for_typeref( rptr.ptrref.out_target, path_id=ir_set.path_id, dml_source=irutils.get_dml_sources(ir_set, ctx.env.binding_dml), lateral=True, ctx=ctx, ) prefix_path_id = ir_set.path_id.src_path() assert prefix_path_id is not None, 'expected a path' relctx.deep_copy_primitive_rvar_path_var( ir_set.path_id, prefix_path_id, poly_rvar, env=ctx.env) pathctx.put_rvar_path_bond(poly_rvar, prefix_path_id) relctx.include_rvar(stmt, poly_rvar, ir_set.path_id, ctx=ctx) int_rvar = pgast.IntersectionRangeVar( component_rvars=[ source_rvar, poly_rvar, ] ) if isinstance(source_rvar.query, pgast.Query): pathctx.put_path_id_map( source_rvar.query, ir_set.path_id, ir_source.path_id) for
aspect in (pgce.PathAspect.SOURCE, pgce.PathAspect.VALUE): pathctx.put_path_rvar( stmt, ir_source.path_id, source_rvar, aspect=aspect, ) pathctx.put_path_rvar( stmt, ir_set.path_id, int_rvar, aspect=aspect, ) rvars = new_stmt_set_rvar( ir_set, stmt, aspects=[pgce.PathAspect.VALUE, pgce.PathAspect.SOURCE], ctx=ctx, ) # If the inner set also exposes a pointer path source, we need to # also expose a pointer path source. See tests like # test_edgeql_for_lprop_02, where it is needed to make FOR binding # of backlinks work. if pathctx.maybe_get_path_rvar( stmt, ir_source.path_id.ptr_path(), aspect=pgce.PathAspect.SOURCE, ): rvars.new.append( SetRVar( rvars.main.rvar, ir_set.path_id.ptr_path(), aspects=(pgce.PathAspect.SOURCE,), ) ) return rvars def _source_path_needs_semi_join( ir_source: irast.Set, ctx: context.CompilerContextLevel) -> bool: """Check if the path might need a semi-join. It does not need one if it has a visible prefix followed by single pointers. Otherwise it might. This is an optimization that allows us to avoid doing a semi-join when there is a chain of single links referenced (probably in a filter or a computable). """ if ctx.scope_tree.is_visible(ir_source.path_id): return False while ( isinstance(ir_source.expr, irast.Pointer) and ir_source.expr.dir_cardinality.is_single() and not ir_source.expr.expr ): ir_source = ir_source.expr.source if ctx.scope_tree.is_visible(ir_source.path_id): return False return True @register_get_rvar(irast.Pointer) def process_set_as_path( ir_set: irast.SetE[irast.Pointer], *, ctx: context.CompilerContextLevel ) -> SetRVars: if ir_set.expr.expr: return process_set_as_subquery(ir_set, ctx=ctx) rptr = ir_set.expr ptrref = rptr.ptrref ir_source = rptr.source stmt = ctx.rel source_is_visible = ctx.scope_tree.is_visible(ir_source.path_id) rvars = [] ptr_info = pg_types.get_ptrref_storage_info( ptrref, resolve_type=False, link_bias=rptr.force_link_table, allow_missing=True, ) # Path is a link property.
is_linkprop = ptrref.source_ptr is not None is_primitive_ref = not irtyputils.is_object(ptrref.out_target) # Path is a reference to a relationship stored in the source table. is_inline_ref = bool(ptr_info and ptr_info.table_type == 'ObjectType') is_inline_primitive_ref = is_inline_ref and is_primitive_ref is_id_ref_to_inline_source = False semi_join = ( ir_set.path_id not in ctx.disable_semi_join and not (is_linkprop or is_primitive_ref) and _source_path_needs_semi_join(ir_source, ctx=ctx) and # This is an optimization for when we are inside of a semi-join on # a computable: process_set_as_subquery will have included an # rvar for the computable source, and we want to join on it # instead of semi-joining. not relctx.find_rvar(stmt, path_id=ir_source.path_id, ctx=ctx) ) if irtyputils.is_empty_typeref(ir_source.typeref): # If the source is an empty type intersection, just produce an empty set if is_primitive_ref: aspects = [pgce.PathAspect.VALUE] else: aspects = [pgce.PathAspect.VALUE, pgce.PathAspect.SOURCE] empty_ir = irast.Set( path_id=ir_set.path_id, typeref=ir_set.typeref, expr=irast.EmptySet(typeref=ir_set.typeref), ) empty_rvar = SetRVar( relctx.new_empty_rvar( cast('irast.SetE[irast.EmptySet]', empty_ir), ctx=ctx ), path_id=ir_set.path_id, aspects=aspects, ) return SetRVars(main=empty_rvar, new=[empty_rvar]) main_rvar = None source_rptr = ( ir_source.expr if isinstance(ir_source.expr, irast.Pointer) else None) if (irtyputils.is_id_ptrref(ptrref) and source_rptr is not None and isinstance(source_rptr.ptrref, irast.PointerRef) and not source_rptr.is_inbound and not irtyputils.is_computable_ptrref(source_rptr.ptrref) and not irutils.is_type_intersection_reference(ir_set) and not pathctx.link_needs_type_rewrite( ir_source.typeref, env=ctx.env)): src_src_is_visible = ctx.scope_tree.is_visible( source_rptr.source.path_id) # Record the ptrref visibility in a way that get_path_var # can access, to properly apply the second part of this # optimization. 
ctx.env.ptrref_source_visibility[source_rptr.ptrref] = ( src_src_is_visible) if src_src_is_visible: # When there is a reference to the id property of # an object which is linked to by a link stored # inline, we want to route the reference to the # inline attribute. For example, # Foo.__type__.id gets resolved to the Foo.__type__ # column. However, this optimization must not be # applied if the source is a type intersection, e.g # __type__[IS Array].id, or if Foo is not visible in # this scope. source_ptr_info = pg_types.get_ptrref_storage_info( source_rptr.ptrref, resolve_type=False, link_bias=False, allow_missing=True) is_id_ref_to_inline_source = bool( source_ptr_info and source_ptr_info.table_type == 'ObjectType') if semi_join: with ctx.subrel() as srcctx: srcctx.expr_exposed = False src_rvar = get_set_rvar(ir_source, ctx=srcctx) # semi_join needs a source rvar, so make sure we have one. # (The returned one won't be a source rvar if it comes # from a function, for example) if not ir_source.path_id.is_type_intersection_path(): src_rvar = ensure_source_rvar(ir_source, stmt, ctx=srcctx) set_rvar = relctx.semi_join(stmt, ir_set, src_rvar, ctx=srcctx) rvars.append(SetRVar( set_rvar, ir_set.path_id, [pgce.PathAspect.VALUE, pgce.PathAspect.SOURCE] )) elif is_id_ref_to_inline_source: assert source_rptr is not None ir_source = source_rptr.source src_rvar = get_set_rvar(ir_source, ctx=ctx) elif not source_is_visible: with ctx.subrel() as srcctx: srcctx.expr_exposed = False get_set_rvar(ir_source, ctx=srcctx) if is_inline_primitive_ref: # Semi-join variant for inline scalar links, # which is, essentially, just filtering out NULLs. 
ensure_source_rvar(ir_source, srcctx.rel, ctx=srcctx) var = pathctx.get_path_value_var( srcctx.rel, path_id=ir_set.path_id, env=ctx.env) if var.nullable: srcctx.rel.where_clause = astutils.extend_binop( srcctx.rel.where_clause, pgast.NullTest(arg=var, negated=True)) srcrel = srcctx.rel src_rvar = relctx.rvar_for_rel(srcrel, lateral=True, ctx=srcctx) relctx.include_rvar(stmt, src_rvar, path_id=ir_source.path_id, ctx=ctx) pathctx.put_path_id_mask(stmt, ir_source.path_id) # Path is a reference to a link property. if is_linkprop: srvars = process_set_as_link_property_ref(ir_set, ctx=ctx) main_rvar = srvars.main rvars.extend(srvars.new) elif is_id_ref_to_inline_source: main_rvar = SetRVar( ensure_source_rvar(ir_source, stmt, ctx=ctx), path_id=ir_set.path_id, aspects=[pgce.PathAspect.VALUE] ) elif is_inline_primitive_ref: # There is an opportunity to also expose the "source" aspect # for tuple refs here, but that requires teaching pathctx about # complex field indirections, so rely on tuple_getattr() # fallback for tuple properties for now. main_rvar = SetRVar( ensure_source_rvar(ir_source, stmt, ctx=ctx), path_id=ir_set.path_id, aspects=[pgce.PathAspect.VALUE] ) rvars = [main_rvar] elif not semi_join: # Link range. if is_inline_ref: aspects = [pgce.PathAspect.VALUE] # If this is a link that is stored inline, make sure # the source aspect is actually accessible (not just value). src_rvar = ensure_source_rvar(ir_source, stmt, ctx=ctx) # In case the source is visible (so the codepath below # that uses the current statement doesn't trigger) but the # source aspect wasn't available, make sure we include it # in our return. This can come up with __old__ in triggers. 
if source_is_visible: rvars.append(SetRVar( src_rvar, path_id=ir_source.path_id, aspects=[pgce.PathAspect.SOURCE] )) else: aspects = [pgce.PathAspect.VALUE, pgce.PathAspect.SOURCE] src_rvar = get_set_rvar(ir_source, ctx=ctx) map_rvar = SetRVar( relctx.new_pointer_rvar(ir_set, src_rvar=src_rvar, ctx=ctx), path_id=ir_set.path_id.ptr_path(), aspects=aspects ) rvars.append(map_rvar) # Target set range. if irtyputils.is_object(ir_set.typeref): target_rvar = relctx.new_root_rvar(ir_set, lateral=True, ctx=ctx) main_rvar = SetRVar( target_rvar, path_id=ir_set.path_id, aspects=[pgce.PathAspect.VALUE, pgce.PathAspect.SOURCE] ) rvars.append(main_rvar) else: main_rvar = SetRVar( map_rvar.rvar, path_id=ir_set.path_id, aspects=[pgce.PathAspect.VALUE], ) rvars.append(main_rvar) if not source_is_visible: # If the source path is not visible in the current scope, # it means that there are no other paths sharing this path prefix # in this scope. In such cases the path is represented by a subquery # rather than a simple set of ranges. for srvar in rvars: relctx.include_specific_rvar( stmt, srvar.rvar, path_id=srvar.path_id, aspects=srvar.aspects, ctx=ctx) if is_primitive_ref: aspects = [pgce.PathAspect.VALUE] else: aspects = [pgce.PathAspect.VALUE, pgce.PathAspect.SOURCE] main_rvar = SetRVar( relctx.new_rel_rvar(ir_set, stmt, ctx=ctx), path_id=ir_set.path_id, aspects=aspects, ) rvars = [main_rvar] assert main_rvar return SetRVars(main=main_rvar, new=rvars) def _new_subquery_stmt_set_rvar( ir_set: irast.Set, stmt: pgast.Query, *, ctx: context.CompilerContextLevel, ) -> SetRVars: aspects = pathctx.list_path_aspects(stmt, ir_set.path_id) if ir_set.path_id.is_tuple_path(): # If we are wrapping a tuple expression, make sure not to # over-represent it in terms of the exposed aspects. 
aspects -= {pgce.PathAspect.SERIALIZED} return new_stmt_set_rvar( ir_set, stmt, aspects=aspects, ctx=ctx) def _lookup_set_rvar_in_source( ir_set: irast.Set, src_rvar: Optional[pgast.PathRangeVar], *, ctx: context.CompilerContextLevel) -> Optional[pgast.PathRangeVar]: if not ( ir_set.is_materialized_ref and isinstance(src_rvar, pgast.RangeSubselect) ): return None if pathctx.maybe_get_path_value_var( src_rvar.subquery, ir_set.path_id, env=ctx.env ): return src_rvar # When looking for a packed value in our source rvar, we need to # account for the fact that unpack_rvar names all of its outputs # based solely on the source--that is, if any of the pointer paths # have extra namespaces on them, they won't appear. Rebuild the # path_id without any namespaces that aren't on the src_path. path_id = ir_set.path_id path_id = not_none(path_id.src_path()).extend( ptrref=not_none(path_id.rptr()), direction=not_none(path_id.rptr_dir()), ) if packed_ref := pathctx.maybe_get_rvar_path_var( src_rvar, pathctx.map_path_id( path_id, src_rvar.subquery.view_path_id_map, ), aspect=pgce.PathAspect.VALUE, flavor='packed', env=ctx.env, ): return relctx.unpack_var( ctx.rel, ir_set.path_id, ref=packed_ref, ctx=ctx) return None # N.B: registered for get_rvar for Stmt and MaterializedExpr below # Also, called explicitly for Pointer when expr is not None # TODO: This is a tangled mess that handles several cases. # Most of the code is for computed pointers. def process_set_as_subquery( ir_set: irast.Set, *, ctx: context.CompilerContextLevel ) -> SetRVars: is_objtype_path = ir_set.path_id.is_objtype_path() stmt = ctx.rel if isinstance(ir_set.expr, irast.Pointer): rptr = ir_set.expr expr = rptr.expr else: expr = ir_set.expr rptr = None ir_source: Optional[irast.Set] source_set_rvar = None if rptr is not None: ir_source = rptr.source if not is_objtype_path: source_is_visible = True else: # Non-scalar computable pointer. Check if path source is # visible in the outer scope.
outer_fence = ctx.scope_tree.parent_branch assert outer_fence is not None source_is_visible = outer_fence.is_visible(ir_source.path_id) if source_is_visible and ( ir_source.path_id not in ctx.skippable_sources ): with ctx.new() as sctx: sctx.expr_exposed = False source_set_rvar = get_set_rvar(ir_source, ctx=sctx) # Force a source rvar so that trivial computed pointers # on erroneous objects (like a bad array deref) fail. # (Most sensible computables will end up requiring the # source rvar anyway.) ensure_source_rvar(ir_source, stmt, ctx=sctx) else: ir_source = None source_is_visible = False with ctx.new() as newctx: # Suppress volatility refs while compiling schema # aliases/globals. While they might try to apply volatility # refs due to FOR/free objects, it shouldn't be semantically # necessary that they actually are attached to the enclosing # location. This turns out to be an important optimization for # ext::auth::ClientTokenIdentity. if ir_set.is_schema_alias: newctx.volatility_ref = () outer_id = ir_set.path_id semi_join = False if ir_source is not None: if ( ir_source.path_id != ctx.current_insert_path_id and not irutils.is_trivial_free_object(ir_source) ): # This is a computable pointer. In order to ensure that # the volatile functions in the pointer expression are called # the necessary number of times, we must inject a # "volatility reference" into function expressions. # The volatility_ref is the identity of the pointer source. # If the source is an insert that we are in the middle # of doing, we don't have a volatility ref to add, so # skip it based on the current_insert_path_id check. # Note also that we skip this when the source is a # trivial free object reference. A trivial free object # reference is always executed exactly once (if there # is an outer iterator of some kind, we'll pick up # *that* volatility ref) and, unlike other shapes, may # contain DML. 
We disable the volatility ref for # trivial free objects then both as a minor # optimization and to avoid it interfering with DML in # the object (since the volatility ref would not be # visible in DML CTEs). path_id = ir_source.path_id newctx.volatility_ref += ( lambda _stmt, xctx: relctx.maybe_get_path_var( stmt, path_id=path_id, aspect=pgce.PathAspect.IDENTITY, ctx=xctx, ), ) if is_objtype_path and not source_is_visible: # Non-scalar computable semi-join. # TODO: The basic path case has a more sophisticated # understanding of when to do semi-joins. Using that # naively here doesn't work, but perhaps it could be # adapted? # Don't semi-join on free objects, since they are all unique # (but don't *actually* have unique ids...) semi_join = not irtyputils.is_free_object(ir_set.typeref) # We need to compile the source and include it in, # since we need to do the semi-join deduplication here # on the outside, and not when the source is used in a # path inside the computable. # (See test_edgeql_scope_computables_09 for an example.) with newctx.subrel() as _, _.newscope() as subctx: get_set_rvar(ir_source, ctx=subctx) subrvar = relctx.rvar_for_rel(subctx.rel, ctx=subctx) # Force a source rvar. See above. ensure_source_rvar(ir_source, subctx.rel, ctx=subctx) relctx.include_rvar( stmt, subrvar, ir_source.path_id, ctx=newctx) # If we are looking at a materialized computable, running # get_set_rvar on the source above may have made it show # up. So try to look up the rvar again, and try to look it up # in the source_rvar itself, and if we find it, skip compiling # the computable. if ir_source and (new_rvar := ( _lookup_set_rvar(ir_set, ctx=newctx) or _lookup_set_rvar_in_source(ir_set, source_set_rvar, ctx=newctx) )): if semi_join: # We need to use DISTINCT, instead of doing an actual # semi-join, unfortunately: we need to extract data # out from stmt, which we can't do with a semi-join.
value_var = pathctx.get_rvar_path_var( new_rvar, outer_id, aspect=pgce.PathAspect.VALUE, env=ctx.env, ) stmt.distinct_clause = ( pathctx.get_rvar_output_var_as_col_list( subrvar, value_var, aspect=pgce.PathAspect.VALUE, env=ctx.env, ) ) return _new_subquery_stmt_set_rvar(ir_set, stmt, ctx=newctx) # materialized refs should always get picked up by now assert not isinstance(expr, irast.MaterializedExpr), ( f"Can't find materialized set {ir_set.path_id}" ) assert isinstance(expr, irast.Stmt) inner_set = expr.result inner_id = inner_set.path_id if inner_id != outer_id: pathctx.put_path_id_map(stmt, outer_id, inner_id) if isinstance(expr, irast.MutatingStmt) and expr in ctx.dml_stmts: # The DML table-routing logic may result in the same # DML subquery to be visited twice, such as in the case # of a nested INSERT declaring link properties, so guard # against generating a duplicate DML CTE. with newctx.substmt() as subrelctx: dml_cte = ctx.dml_stmts[expr] dml.wrap_dml_cte(expr, dml_cte, ctx=subrelctx) else: dispatch.visit(expr, ctx=newctx) if semi_join: set_rvar = relctx.new_root_rvar(ir_set, ctx=newctx) tgt_ref = pathctx.get_rvar_path_identity_var( set_rvar, ir_set.path_id, env=ctx.env) pathctx.get_path_identity_output( stmt, ir_set.path_id, env=ctx.env) cond_expr = astutils.new_binop(tgt_ref, stmt, 'IN') # Make a new stmt, join in the new root, and semi join on # the original statement. stmt = pgast.SelectStmt() relctx.include_rvar(stmt, set_rvar, ir_set.path_id, ctx=newctx) stmt.where_clause = astutils.extend_binop( stmt.where_clause, cond_expr) rvars = _new_subquery_stmt_set_rvar(ir_set, stmt, ctx=ctx) # If the inner set also exposes a pointer path source, we need to # also expose a pointer path source. 
See tests like # test_edgeql_select_linkprop_rebind_01 if pathctx.maybe_get_path_rvar( stmt, inner_id.ptr_path(), aspect=pgce.PathAspect.SOURCE, ): rvars.new.append( SetRVar( rvars.main.rvar, outer_id.ptr_path(), aspects=(pgce.PathAspect.SOURCE,), ) ) return rvars register_get_rvar(irast.Stmt)(process_set_as_subquery) register_get_rvar(irast.MaterializedExpr)(process_set_as_subquery) @_special_case('std::IN') @_special_case('std::NOT IN') def process_set_as_membership_expr( ir_set: irast.SetE[irast.Call], *, ctx: context.CompilerContextLevel ) -> SetRVars: expr = ir_set.expr assert isinstance(expr, irast.OperatorCall) with ctx.new() as newctx: left, right = (a for a in expr.args.values()) left_arg, right_arg = left.expr, right.expr newctx.expr_exposed = False left_out = dispatch.compile(left_arg, ctx=newctx) orig_right_arg = right_arg unwrapped_right_arg = irutils.unwrap_set(right_arg) # If the right operand of [NOT] IN is an array_unpack call, # then use the ANY/ALL array comparison operator directly, # since that has a higher chance of using the indexes. 
right_expr = unwrapped_right_arg.expr needs_coalesce = False if ( isinstance(right_expr, irast.FunctionCall) and str(right_expr.func_shortname) == 'std::array_unpack' and not right_expr.args[0].cardinality.is_multi() and (not expr.sql_operator or len(expr.sql_operator) <= 1) ): is_array_unpack = True right_arg = right_expr.args[0].expr needs_coalesce = right_expr.args[0].cardinality.can_be_zero() else: is_array_unpack = False left_is_row_expr = astutils.is_row_expr(left_out) with newctx.subrel() as _, _.newscope() as subctx: if is_array_unpack: relctx.update_scope(orig_right_arg, subctx.rel, ctx=subctx) relctx.update_scope( unwrapped_right_arg, subctx.rel, ctx=subctx) dispatch.compile(right_arg, ctx=subctx) right_rel = subctx.rel right_out = pathctx.get_path_value_var( right_rel, right_arg.path_id, env=subctx.env) right_out = output.output_as_value(right_out, env=ctx.env) if ( left_is_row_expr and right_arg.path_id.is_tuple_path() ): # When the RHS is an opaque tuple, we must unpack # it using the (...).* indirection syntax, otherwise # we get "subquery has too few columns". right_out = pgast.Indirection( arg=right_out, indirection=[pgast.Star()], ) right_rel.target_list = [pgast.ResTarget(val=right_out)] if is_array_unpack: right_rel = pgast.TypeCast( arg=right_rel, type_name=pgast.TypeName( name=pg_types.pg_type_from_ir_typeref( right_arg.typeref) ) ) negated = str(expr.func_shortname) == 'std::NOT IN' set_expr = exprcomp.compile_operator( expr, [ left_out, pgast.SubLink( operator="ALL" if negated else "ANY", expr=right_rel, ), ], ctx=ctx, ) # A NULL argument to the array variant will produce NULL, so we # need to coalesce if that is possible. if needs_coalesce: empty_val = negated set_expr = pgast.CoalesceExpr(args=[ set_expr, pgast.BooleanConstant(val=empty_val)]) # Filter out situations where the LHS is a SQL NULL, # since those will report false instead of {}. 
if left.cardinality.can_be_zero() and left_out.nullable: ctx.rel.where_clause = astutils.extend_binop( ctx.rel.where_clause, pgast.NullTest(arg=left_out, negated=True), ) pathctx.put_path_value_var_if_not_exists( ctx.rel, ir_set.path_id, set_expr ) return new_stmt_set_rvar(ir_set, ctx.rel, ctx=ctx) @_special_case('std::UNION') @_special_case('std::EXCEPT') @_special_case('std::INTERSECT') def process_set_as_setop( ir_set: irast.SetE[irast.Call], *, ctx: context.CompilerContextLevel ) -> SetRVars: expr = ir_set.expr with ctx.new() as newctx: newctx.expr_exposed = False left, right = (a.expr for a in expr.args.values()) with newctx.subrel() as _, _.newscope() as scopectx: larg = scopectx.rel pathctx.put_path_id_map(larg, ir_set.path_id, left.path_id) dispatch.visit(left, ctx=scopectx) with newctx.subrel() as _, _.newscope() as scopectx: rarg = scopectx.rel pathctx.put_path_id_map(rarg, ir_set.path_id, right.path_id) dispatch.visit(right, ctx=scopectx) aspects = pathctx.list_path_aspects( larg, left.path_id ) & pathctx.list_path_aspects(rarg, right.path_id) with ctx.subrel() as subctx: subqry = subctx.rel # There are three possible binary set operators coming from IR: # UNION, EXCEPT, and INTERSECT subqry.op = expr.func_shortname.name subqry.all = True subqry.larg = larg subqry.rarg = rarg setop_rvar = relctx.rvar_for_rel(subqry, lateral=True, ctx=subctx) # No pull_namespace because we don't want the union arguments to # escape, just the final result. 
relctx.include_rvar( ctx.rel, setop_rvar, ir_set.path_id, aspects=aspects, pull_namespace=False, ctx=subctx, ) return new_stmt_set_rvar(ir_set, ctx.rel, ctx=ctx) @_special_case('std::DISTINCT') def process_set_as_distinct( ir_set: irast.SetE[irast.Call], *, ctx: context.CompilerContextLevel ) -> SetRVars: expr = ir_set.expr stmt = ctx.rel with ctx.subrel() as subctx: subqry = subctx.rel arg = expr.args[0].expr pathctx.put_path_id_map(subqry, ir_set.path_id, arg.path_id) dispatch.visit(arg, ctx=subctx) subrvar = relctx.rvar_for_rel( subqry, typeref=arg.typeref, lateral=True, ctx=subctx) relctx.include_rvar(stmt, subrvar, ir_set.path_id, ctx=ctx) value_var = pathctx.get_rvar_path_var( subrvar, ir_set.path_id, aspect=pgce.PathAspect.VALUE, env=ctx.env, ) stmt.distinct_clause = pathctx.get_rvar_output_var_as_col_list( subrvar, value_var, aspect=pgce.PathAspect.VALUE, env=ctx.env, ) # If there aren't any columns, we are doing DISTINCT on empty # tuples. All empty tuples are equivalent, so we can just compile # this by adding a LIMIT 1. 
if not stmt.distinct_clause: stmt.limit_count = pgast.NumericConstant(val="1") return new_stmt_set_rvar(ir_set, stmt, ctx=ctx) @_special_case('std::IF') def process_set_as_ifelse( ir_set: irast.SetE[irast.Call], *, ctx: context.CompilerContextLevel ) -> SetRVars: # A IF Cond ELSE B is transformed into: # SELECT A WHERE Cond UNION ALL SELECT B WHERE NOT Cond expr = ir_set.expr stmt = ctx.rel if_expr, condition, else_expr = (a.expr for a in expr.args.values()) if_expr_card, _, else_expr_card = ( a.cardinality for a in expr.args.values() ) with ctx.new() as newctx: newctx.expr_exposed = False dispatch.visit(condition, ctx=newctx) condref = relctx.get_path_var( stmt, path_id=condition.path_id, aspect=pgce.PathAspect.VALUE, ctx=newctx, ) if (if_expr_card.is_single() and else_expr_card.is_single() and irtyputils.is_scalar(expr.typeref)): # For a simple case of singleton scalars on both ends of IF, # use a CASE WHEN construct, since it's normally faster than # a UNION ALL with filters. The reason why we limit this # optimization to scalars is because CASE WHEN can only yield # a single value, hence no other aspects can be supported # by this rvar. with ctx.new() as newctx: newctx.expr_exposed = False # Values still need to be encased in subqueries to guard # against empty sets. 
if_val = set_as_subquery(if_expr, as_value=True, ctx=newctx) else_val = set_as_subquery(else_expr, as_value=True, ctx=newctx) set_expr = pgast.CaseExpr( args=[pgast.CaseWhen(expr=condref, result=if_val)], defresult=else_val, ) with ctx.subrel() as subctx: pathctx.put_path_value_var_if_not_exists( subctx.rel, ir_set.path_id, set_expr, ) sub_rvar = relctx.rvar_for_rel( subctx.rel, lateral=True, ctx=subctx, ) relctx.include_rvar(stmt, sub_rvar, ir_set.path_id, ctx=subctx) rvar = pathctx.get_path_value_var( stmt, path_id=ir_set.path_id, env=ctx.env) # We need to NULL filter both the result and the input condition for var in [rvar, condref]: stmt.where_clause = astutils.extend_binop( stmt.where_clause, pgast.NullTest( arg=var, negated=True ) ) else: with ctx.subrel() as _, _.newscope() as subctx: subctx.expr_exposed = False larg = subctx.rel pathctx.put_path_id_map(larg, ir_set.path_id, if_expr.path_id) dispatch.visit(if_expr, ctx=subctx) larg.where_clause = astutils.extend_binop( larg.where_clause, condref ) with ctx.subrel() as _, _.newscope() as subctx: subctx.expr_exposed = False rarg = subctx.rel pathctx.put_path_id_map(rarg, ir_set.path_id, else_expr.path_id) dispatch.visit(else_expr, ctx=subctx) rarg.where_clause = astutils.extend_binop( rarg.where_clause, astutils.new_unop('NOT', condref) ) aspects = pathctx.list_path_aspects( larg, if_expr.path_id ) & pathctx.list_path_aspects(rarg, else_expr.path_id) with ctx.subrel() as subctx: subqry = subctx.rel subqry.op = 'UNION' subqry.all = True subqry.larg = larg subqry.rarg = rarg union_rvar = relctx.rvar_for_rel(subqry, lateral=True, ctx=subctx) relctx.include_rvar( stmt, union_rvar, ir_set.path_id, pull_namespace=False, aspects=aspects, ctx=subctx, ) return new_stmt_set_rvar(ir_set, stmt, ctx=ctx) @_special_case('std::??') def process_set_as_coalesce( ir_set: irast.SetE[irast.Call], *, ctx: context.CompilerContextLevel ) -> SetRVars: expr = ir_set.expr with ctx.new() as newctx: newctx.expr_exposed = False left_ir, 
right_ir = (a.expr for a in expr.args.values()) _left_card, right_card = (a.cardinality for a in expr.args.values()) is_object = ( ir_set.path_id.is_objtype_path() or ir_set.path_id.is_tuple_path() ) # The cardinality optimization below applies only to # non-object/non-tuple expressions, because we don't want to # have to deal with the complexity of resolving coalesced # sources for potential link or property references. if right_card.is_single() and not is_object: left = dispatch.compile(left_ir, ctx=newctx) # If the RHS is optional, we compile it in a subquery so that # it becomes NULL instead of potentially joining in zero rows. # If not, just compile it without any protection. right = ( set_as_subquery(right_ir, ctx=newctx) if right_card.can_be_zero() else dispatch.compile(right_ir, ctx=newctx) ) # Just use scalar COALESCE now set_expr = pgast.CoalesceExpr(args=[left, right]) pathctx.put_path_value_var(ctx.rel, ir_set.path_id, set_expr) else: # Things become tricky in cases where the RHS is a non-singleton # or where we need to worry about source aspects. # We cannot use the regular scalar COALESCE over a JOIN, # as that'll blow up the result cardinality. Instead, we do # something like: # # CROSS JOIN LATERAL # () as lhs # CROSS JOIN LATERAL ( # SELECT * FROM lhs WHERE lhs.value IS NOT NULL # UNION ALL # SELECT * FROM as rhs WHERE lhs.value IS NULL # ) as q # # Note that will be compiled with optional wrapping, # so it shouldn't ever produce zero rows. lhs_rvar = get_set_rvar(left_ir, ctx=newctx) lvar = pathctx.get_rvar_path_var( lhs_rvar, left_ir.path_id, aspect=pgce.PathAspect.VALUE, env=ctx.env, ) lval = output.output_as_value(lvar, env=ctx.env) with newctx.subrel() as lctx: larg = lctx.rel pathctx.put_path_id_map(larg, ir_set.path_id, left_ir.path_id) relctx.include_rvar( larg, lhs_rvar, path_id=left_ir.path_id, ctx=lctx, # Only include the aspects that got included # into our rel; it's possible some were explicitly # left out, and we should respect that. 
aspects=pathctx.list_path_aspects( newctx.rel, left_ir.path_id ), ) # Include the LHS when it is not NULL. larg.where_clause = astutils.extend_binop( larg.where_clause, pgast.NullTest( arg=lval, negated=True ) ) with newctx.subrel() as rctx: rarg = rctx.rel pathctx.put_path_id_map(rarg, ir_set.path_id, right_ir.path_id) rvar = dispatch.compile(right_ir, ctx=rctx) # Include the RHS when the LHS is NULL. # Note that an important precondition of this is that # there are not "stray" NULLs in the LHS input set. # # HACK: We need to use IS NOT DISTINCT FROM # because ROW() IS NULL is true, and that # breaks some things. rarg.where_clause = astutils.extend_binop( rarg.where_clause, astutils.new_binop( lval, pgast.NullConstant(), 'IS NOT DISTINCT FROM', ), ) if rvar.nullable: rarg.where_clause = astutils.extend_binop( rarg.where_clause, pgast.NullTest(arg=rvar, negated=True) ) union_rvar = relctx.rvar_for_rel( pgast.SelectStmt(op='UNION', all=True, larg=larg, rarg=rarg), lateral=True, ctx=newctx, ) aspects = ( pathctx.list_path_aspects(larg, left_ir.path_id) & pathctx.list_path_aspects(rarg, right_ir.path_id) ) # No pull_namespace because we don't want the coalesce arguments to # escape, just the final result. relctx.include_rvar( ctx.rel, union_rvar, path_id=ir_set.path_id, aspects=aspects, pull_namespace=False, ctx=newctx, ) return new_stmt_set_rvar(ir_set, ctx.rel, ctx=ctx) @register_get_rvar(irast.Tuple) def process_set_as_tuple( ir_set: irast.SetE[irast.Tuple], *, ctx: context.CompilerContextLevel ) -> SetRVars: expr = ir_set.expr stmt = ctx.rel with ctx.new() as subctx: subctx.expr_exposed_tuple_cheat = None elements = [] ttypes = {} for i, st in enumerate(ir_set.typeref.subtypes): if st.element_name: ttypes[st.element_name] = st else: ttypes[str(i)] = st for element in expr.elements: assert element.path_id path_id = element.path_id # We compile in a subrel *solely* so that we can map # each element individually. 
It would be nice to have # a way to do this that doesn't actually affect the output! with subctx.subrel() as newctx: if element is ctx.expr_exposed_tuple_cheat: newctx.expr_exposed = True if path_id != element.val.path_id: pathctx.put_path_id_map( newctx.rel, path_id, element.val.path_id) dispatch.visit(element.val, ctx=newctx) el_rvar = relctx.new_rel_rvar(ir_set, newctx.rel, ctx=ctx) aspects = pathctx.list_path_aspects( newctx.rel, element.val.path_id ) # update_mask=False because we are doing this solely to remap # elements individually and don't want to affect the mask. relctx.include_rvar( stmt, el_rvar, path_id, update_mask=False, aspects=aspects, ctx=ctx, ) tvar = pathctx.get_path_value_var(stmt, path_id, env=subctx.env) elements.append(pgast.TupleElementBase(path_id=path_id)) # We need to filter out NULLs at tuple creation time, to # prevent having tuples that are part-NULL. if tvar.nullable: stmt.where_clause = astutils.extend_binop( stmt.where_clause, pgast.NullTest(arg=tvar, negated=True) ) var = pathctx.maybe_get_path_var( stmt, element.val.path_id, aspect=pgce.PathAspect.SERIALIZED, env=subctx.env, ) if var is not None: pathctx.put_path_var( stmt, path_id, var, aspect=pgce.PathAspect.SERIALIZED, ) set_expr = pgast.TupleVarBase( elements=elements, named=expr.named, typeref=ir_set.typeref, ) pathctx.put_path_value_var(stmt, ir_set.path_id, set_expr) # This is an unfortunate hack. If any of those types that we # contain is an object, then force the computation of the # serialized output now. This avoids issues where there may be # references to tuple elements with the same path id but different # shapes, and the delaying induced by a TupleVarBase can cause the # wrong one to be output. (See test_edgeql_scope_shape_03 for an example # where this can come up.) # (We only do it for objects as an optimization.)
if ( output.in_serialization_ctx(ctx) and any(irtyputils.is_object(x) for x in ir_set.typeref.subtypes) ): pathctx.get_path_serialized_output(stmt, ir_set.path_id, env=ctx.env) return new_stmt_set_rvar( ir_set, stmt, aspects=[pgce.PathAspect.VALUE, pgce.PathAspect.SOURCE], ctx=ctx, ) @register_get_rvar(irast.TupleIndirectionPointer) def process_set_as_tuple_indirection( ir_set: irast.SetE[irast.TupleIndirectionPointer], *, ctx: context.CompilerContextLevel ) -> SetRVars: rptr = ir_set.expr tuple_set = rptr.source stmt = ctx.rel assert not rptr.expr, 'tuple indirection pointer with expr??' with ctx.new() as subctx: # Usually the LHS is not exposed, but when we are directly # projecting from an explicit tuple, and the result is a # collection, arrange to have the element we are projecting # treated as exposed. This behavior is needed for our # eta-expansion of arrays to work, since it generates that # idiom in a place where it needs the output to be exposed. subctx.expr_exposed = False if ( ctx.expr_exposed and not tuple_set.is_binding and isinstance(tuple_set.expr, irast.Tuple) and ir_set.path_id.is_collection_path() ): for el in tuple_set.expr.elements: if el.name == rptr.ptrref.shortname.name: subctx.expr_exposed_tuple_cheat = el break rvar = get_set_rvar(tuple_set, ctx=subctx) source_rvar = relctx.maybe_get_path_rvar( stmt, tuple_set.path_id, aspect=pgce.PathAspect.SOURCE, ctx=subctx, ) if source_rvar is None: # Lack of visible tuple source means we are # an indirection over an opaque tuple, e.g. in # `SELECT [(1,)][0].0`. This means we must # use an explicit row attribute dereference.
tuple_val = pathctx.get_path_value_var( stmt, path_id=tuple_set.path_id, env=subctx.env) set_expr = astutils.tuple_getattr( tuple_val, tuple_set.typeref, rptr.ptrref.shortname.name, ) pathctx.put_path_var_if_not_exists( stmt, ir_set.path_id, set_expr, aspect=pgce.PathAspect.VALUE, ) rvar = relctx.new_rel_rvar(ir_set, stmt, ctx=subctx) return new_simple_set_rvar( ir_set, rvar, aspects=(pgce.PathAspect.VALUE,), ) @register_get_rvar(irast.TypeCast) def process_set_as_type_cast( ir_set: irast.SetE[irast.TypeCast], *, ctx: context.CompilerContextLevel ) -> SetRVars: expr = ir_set.expr stmt = ctx.rel inner_set = expr.expr is_json_cast = expr.to_type.id == s_obj.get_known_type_id('std::json') # Are we casting by compiling the innards in json mode? implicit_cast = ( is_json_cast and not irtyputils.is_range(inner_set.typeref) and (irtyputils.is_collection(inner_set.typeref) or irtyputils.is_object(inner_set.typeref)) ) fmt_ctx = ( context.output_format(ctx, context.OutputFormat.JSONB) if implicit_cast else contextlib.nullcontext() ) with fmt_ctx, ctx.new() as subctx: pathctx.put_path_id_map(ctx.rel, ir_set.path_id, inner_set.path_id) if implicit_cast: subctx.expr_exposed = True set_expr = dispatch.compile(inner_set, ctx=subctx) serialized: Optional[pgast.BaseExpr] = ( pathctx.maybe_get_path_serialized_var( stmt, inner_set.path_id, env=subctx.env) ) if serialized is not None: if irtyputils.is_collection(inner_set.typeref): serialized = output.serialize_expr_to_json( serialized, styperef=inner_set.path_id.target, env=subctx.env) pathctx.put_path_value_var( stmt, inner_set.path_id, serialized, force=True ) pathctx.put_path_serialized_var( stmt, inner_set.path_id, serialized, force=True ) else: # Rely on the simple implementation of TypeCast set_expr = dispatch.compile(expr, ctx=subctx) # A proper path var mapping way would be to wrap # the inner expression in a subquery, but that # seems excessive for a type cast, so we cover # our tracks here by removing the mapping and # 
relying on the value and serialized vars # populated above. stmt.view_path_id_map.pop(ir_set.path_id) pathctx.put_path_value_var_if_not_exists(stmt, ir_set.path_id, set_expr) return new_stmt_set_rvar(ir_set, stmt, ctx=ctx) @register_get_rvar(irast.ConstantSet) def process_set_as_const_set( ir_set: irast.SetE[irast.ConstantSet], *, ctx: context.CompilerContextLevel ) -> SetRVars: with ctx.subrel() as subctx: vals = [dispatch.compile(v, ctx=subctx) for v in ir_set.expr.elements] vals_rel = subctx.rel vals_rel.values = [pgast.ImplicitRowExpr(args=[v]) for v in vals] vals_rel.nullable = any(v.nullable for v in vals) vals_rvar = relctx.new_rel_rvar(ir_set, vals_rel, ctx=ctx) relctx.include_rvar(ctx.rel, vals_rvar, ir_set.path_id, ctx=ctx) return new_stmt_set_rvar(ir_set, ctx.rel, ctx=ctx) def process_set_as_oper_expr( ir_set: irast.SetE[irast.OperatorCall], *, ctx: context.CompilerContextLevel ) -> SetRVars: # XXX: do we need a subrel? with ctx.new() as newctx: newctx.expr_exposed = False args = _compile_call_args(ir_set, ctx=newctx) oper_expr = exprcomp.compile_operator(ir_set.expr, args, ctx=newctx) if _should_unwrap_polymorphic_return_array(ir_set.expr): oper_expr = astutils.array_get_inner_array( oper_expr, ir_set.expr.typeref ) pathctx.put_path_value_var_if_not_exists( ctx.rel, ir_set.path_id, oper_expr ) return new_stmt_set_rvar(ir_set, ctx.rel, ctx=ctx) @register_get_rvar(irast.TriggerAnchor) def process_set_as_trigger_anchor( ir_set: irast.Set, *, ctx: context.CompilerContextLevel ) -> SetRVars: # XXX: This will need to grow more things if ir_set.path_id in ctx.external_rels: return process_external_rel(ir_set, ctx=ctx) return process_set_as_root(ir_set, ctx=ctx) @register_get_rvar(irast.Expr) def process_set_as_expr( ir_set: irast.SetE[irast.Expr], *, ctx: context.CompilerContextLevel ) -> SetRVars: with ctx.new() as newctx: newctx.expr_exposed = False set_expr = dispatch.compile(ir_set.expr, ctx=newctx) pathctx.put_path_value_var_if_not_exists(ctx.rel, 
ir_set.path_id, set_expr) return new_stmt_set_rvar(ir_set, ctx.rel, ctx=ctx) @_special_case('std::assert_single') def process_set_as_singleton_assertion( ir_set: irast.SetE[irast.Call], *, ctx: context.CompilerContextLevel, ) -> SetRVars: expr = ir_set.expr stmt = ctx.rel msg_arg = expr.args['message'] ir_arg = expr.args[0] ir_arg_set = ir_arg.expr if ( ir_arg.cardinality.is_single() and not msg_arg.cardinality.is_multi() ): # If the argument has been statically proven to be a singleton, # elide the entire assertion. arg_ref = dispatch.compile(ir_arg_set, ctx=ctx) pathctx.put_path_value_var(stmt, ir_set.path_id, arg_ref) pathctx.put_path_id_map(stmt, ir_set.path_id, ir_arg_set.path_id) return new_stmt_set_rvar(ir_set, stmt, ctx=ctx) with ctx.subrel() as newctx: arg_ref = dispatch.compile(ir_arg_set, ctx=newctx) arg_val = output.output_as_value(arg_ref, env=newctx.env) msg = dispatch.compile(msg_arg.expr, ctx=newctx) # Generate a singleton set assertion as the following SQL: # # SELECT # , # raise_on_null(NULLIF(row_number() OVER (), 2)) AS _sentinel # ORDER BY # _sentinel # # This effectively raises an error whenever the row counter reaches 2. check_expr = pgast.FuncCall( name=('nullif',), args=[ pgast.FuncCall( name=('row_number',), args=[], over=pgast.WindowDef() ), pgast.NumericConstant( val='2', ), ], ) maybe_raise = pgast.FuncCall( name=astutils.edgedb_func('raise_on_null', ctx=ctx), args=[ check_expr, pgast.StringConstant(val='cardinality_violation'), pgast.NamedFuncArg( name='msg', val=pgast.CoalesceExpr( args=[ msg, pgast.StringConstant( val='assert_single violation: more than one ' 'element returned by an expression', ), ], ), ), pgast.NamedFuncArg( name='constraint', val=pgast.StringConstant(val='std::assert_single'), ), ], ) output.add_null_test(arg_ref, newctx.rel) # Force Postgres to actually evaluate the result target # by putting it into an ORDER BY. 
newctx.rel.target_list.append( pgast.ResTarget( name="_sentinel", val=maybe_raise, ), ) if newctx.rel.sort_clause is None: newctx.rel.sort_clause = [] newctx.rel.sort_clause.append( pgast.SortBy(node=pgast.ColumnRef(name=["_sentinel"])), ) pathctx.put_path_var_if_not_exists( newctx.rel, ir_set.path_id, arg_val, aspect=pgce.PathAspect.VALUE ) pathctx.put_path_id_map(newctx.rel, ir_set.path_id, ir_arg_set.path_id) aspects = pathctx.list_path_aspects(newctx.rel, ir_arg_set.path_id) func_rvar = relctx.new_rel_rvar(ir_set, newctx.rel, ctx=ctx) relctx.include_rvar(stmt, func_rvar, ir_set.path_id, aspects=aspects, ctx=ctx) return new_stmt_set_rvar(ir_set, stmt, aspects=aspects, ctx=ctx) @_special_case('std::assert_exists') def process_set_as_existence_assertion( ir_set: irast.SetE[irast.Call], *, ctx: context.CompilerContextLevel, ) -> SetRVars: """Implementation of std::assert_exists""" expr = ir_set.expr stmt = ctx.rel msg_arg = expr.args['message'] ir_arg = expr.args[0] ir_arg_set = ir_arg.expr if ( not ir_arg.cardinality.can_be_zero() and not msg_arg.cardinality.is_multi() ): # If the argument has been statically proven to be non empty, # elide the entire assertion. arg_ref = dispatch.compile(ir_arg_set, ctx=ctx) pathctx.put_path_value_var(stmt, ir_set.path_id, arg_ref) pathctx.put_path_id_map(stmt, ir_set.path_id, ir_arg_set.path_id) return new_stmt_set_rvar(ir_set, stmt, ctx=ctx) with ctx.subrel() as newctx: # The solution to assert_exists() is as simple as # calling raise_on_null(). 
newctx.expr_exposed = False newctx.force_optional |= {ir_arg_set.path_id} pathctx.put_path_id_map(newctx.rel, ir_set.path_id, ir_arg_set.path_id) arg_ref = dispatch.compile(ir_arg_set, ctx=newctx) arg_val = output.output_as_value(arg_ref, env=newctx.env) msg = dispatch.compile(msg_arg.expr, ctx=newctx) set_expr = pgast.FuncCall( name=astutils.edgedb_func('raise_on_null', ctx=ctx), args=[ arg_val, pgast.StringConstant(val='cardinality_violation'), pgast.NamedFuncArg( name='msg', val=pgast.CoalesceExpr( args=[ msg, pgast.StringConstant( val='assert_exists violation: expression ' 'returned an empty set', ), ] ), ), pgast.NamedFuncArg( name='constraint', val=pgast.StringConstant(val='std::assert_exists'), ), ], ) pathctx.put_path_value_var( newctx.rel, ir_arg_set.path_id, set_expr, force=True, ) other_aspect = ( pgce.PathAspect.IDENTITY if ir_set.path_id.is_objtype_path() else pgce.PathAspect.SERIALIZED ) pathctx.put_path_var( newctx.rel, ir_arg_set.path_id, set_expr, force=True, aspect=other_aspect, ) # It is important that we do not provide source, which could allow # fields on the object to be accessed without triggering the # raise_on_null. Not providing source means another join is # needed, which will trigger it. func_rvar = relctx.new_rel_rvar(ir_set, newctx.rel, ctx=ctx) relctx.include_rvar( stmt, func_rvar, ir_set.path_id, aspects=(pgce.PathAspect.VALUE,), ctx=ctx, ) return new_stmt_set_rvar( ir_set, stmt, aspects=(pgce.PathAspect.VALUE,), ctx=ctx, ) @_special_case('std::assert_distinct') def process_set_as_multiplicity_assertion( ir_set: irast.SetE[irast.Call], *, ctx: context.CompilerContextLevel, ) -> SetRVars: """Implementation of std::assert_distinct""" expr = ir_set.expr msg_arg = expr.args['message'] ir_arg = expr.args[0] ir_arg_set = ir_arg.expr if ( not ir_arg.multiplicity.is_duplicate() and not msg_arg.cardinality.is_multi() ): # If the argument has been statically proven to be distinct, # elide the entire assertion. 
arg_ref = dispatch.compile(ir_arg_set, ctx=ctx) pathctx.put_path_value_var(ctx.rel, ir_set.path_id, arg_ref) pathctx.put_path_id_map(ctx.rel, ir_set.path_id, ir_arg_set.path_id) return new_stmt_set_rvar(ir_set, ctx.rel, ctx=ctx) # Generate a distinct set assertion as the following SQL: # # SELECT # , # (CASE WHEN # # IS DISTINCT FROM # lag() OVER (ORDER BY ) # THEN # ELSE edgedb.raise(ConstraintViolationError)) AS check_expr # FROM # (SELECT , row_number() OVER () AS i) AS q # ORDER BY # q.i, check_expr # # NOTE: sorting over original row_number() is necessary to preserve # order, as assert_distinct() must be completely transparent for # compliant sets. with ctx.subrel() as newctx: with newctx.subrel() as subctx: dispatch.visit(ir_arg_set, ctx=subctx) arg_ref = pathctx.get_path_output( subctx.rel, ir_arg_set.path_id, aspect=pgce.PathAspect.VALUE, env=subctx.env, ) arg_val = output.output_as_value(arg_ref, env=newctx.env) sub_rvar = relctx.new_rel_rvar(ir_arg_set, subctx.rel, ctx=subctx) aspects = pathctx.list_path_aspects(subctx.rel, ir_arg_set.path_id) relctx.include_rvar( newctx.rel, sub_rvar, ir_arg_set.path_id, aspects=aspects, ctx=subctx, ) alias = ctx.env.aliases.get('i') subctx.rel.target_list.append( pgast.ResTarget( name=alias, val=pgast.FuncCall( name=('row_number',), args=[], over=pgast.WindowDef(), ) ) ) msg = dispatch.compile(msg_arg.expr, ctx=newctx) do_raise = pgast.FuncCall( name=astutils.edgedb_func('raise', ctx=ctx), args=[ pgast.TypeCast( arg=pgast.NullConstant(), type_name=pgast.TypeName( name=pg_types.pg_type_from_ir_typeref( ir_arg_set.typeref), ), ), pgast.StringConstant(val='cardinality_violation'), pgast.NamedFuncArg( name='msg', val=pgast.CoalesceExpr( args=[ msg, pgast.StringConstant( val='assert_distinct violation: expression ' 'returned a set with duplicate elements', ), ], ), ), pgast.NamedFuncArg( name='constraint', val=pgast.StringConstant(val='std::assert_distinct'), ), ], ) check_expr = pgast.CaseExpr( args=[ pgast.CaseWhen( 
expr=astutils.new_binop( lexpr=arg_val, op='IS DISTINCT FROM', rexpr=pgast.FuncCall( name=('lag',), args=[arg_val], over=pgast.WindowDef( order_clause=[pgast.SortBy(node=arg_val)], ), ), ), result=arg_val, ), ], defresult=do_raise, ) alias2 = ctx.env.aliases.get('v') newctx.rel.target_list.append( pgast.ResTarget( val=check_expr, name=alias2, ) ) pathctx.put_path_var( newctx.rel, ir_set.path_id, check_expr, aspect=pgce.PathAspect.VALUE, ) if newctx.rel.sort_clause is None: newctx.rel.sort_clause = [] newctx.rel.sort_clause.extend([ pgast.SortBy( node=pgast.ColumnRef(name=[sub_rvar.alias.aliasname, alias]), ), pgast.SortBy( node=pgast.ColumnRef(name=[alias2]), ), ]) pathctx.put_path_id_map(newctx.rel, ir_set.path_id, ir_arg_set.path_id) func_rvar = relctx.new_rel_rvar(ir_set, newctx.rel, ctx=ctx) relctx.include_rvar( ctx.rel, func_rvar, ir_set.path_id, aspects=aspects, ctx=ctx ) return new_stmt_set_rvar(ir_set, ctx.rel, aspects=aspects, ctx=ctx) @_special_case('std::materialized') def process_set_as_materialized_call( ir_set: irast.SetE[irast.Call], *, ctx: context.CompilerContextLevel, ) -> SetRVars: # It's a pure pass-through. Just an identity function marked volatile. 
stmt = ctx.rel ir_arg_set = ir_set.expr.args[0].expr arg_ref = dispatch.compile(ir_arg_set, ctx=ctx) pathctx.put_path_value_var(stmt, ir_set.path_id, arg_ref) pathctx.put_path_id_map(stmt, ir_set.path_id, ir_arg_set.path_id) return new_stmt_set_rvar(ir_set, stmt, ctx=ctx) def process_set_as_simple_enumerate( ir_set: irast.Set, *, ctx: context.CompilerContextLevel ) -> SetRVars: expr = ir_set.expr assert isinstance(expr, irast.FunctionCall) with ctx.subrel() as newctx: ir_call_arg = expr.args[0] ir_arg = ir_call_arg.expr arg_ref = dispatch.compile(ir_arg, ctx=newctx) arg_val = output.output_as_value(arg_ref, env=newctx.env) if arg_ref.nullable: newctx.rel.where_clause = astutils.extend_binop( newctx.rel.where_clause, pgast.NullTest(arg=arg_ref, negated=True) ) rtype = expr.typeref named_tuple = any(st.element_name for st in rtype.subtypes) num_expr = pgast.Expr( name='-', lexpr=pgast.FuncCall( name=('row_number',), args=[], over=pgast.WindowDef() ), rexpr=pgast.NumericConstant(val='1'), nullable=False, ) set_expr = pgast.TupleVar( elements=[ pgast.TupleElement( path_id=expr.tuple_path_ids[0], name=rtype.subtypes[0].element_name or '0', val=num_expr, ), pgast.TupleElement( path_id=expr.tuple_path_ids[1], name=rtype.subtypes[1].element_name or '1', val=arg_val, ), ], named=named_tuple, typeref=ir_set.typeref, ) for element in set_expr.elements: pathctx.put_path_value_var( newctx.rel, element.path_id, element.val ) var = pathctx.maybe_get_path_var( newctx.rel, ir_arg.path_id, aspect=pgce.PathAspect.SERIALIZED, env=newctx.env, ) if var is not None: pathctx.put_path_var( newctx.rel, set_expr.elements[1].path_id, var, aspect=pgce.PathAspect.SERIALIZED, ) pathctx.put_path_var_if_not_exists( newctx.rel, ir_set.path_id, set_expr, aspect=pgce.PathAspect.VALUE, ) aspects = pathctx.list_path_aspects(newctx.rel, ir_arg.path_id) | { pgce.PathAspect.SOURCE } pathctx.put_path_id_map(newctx.rel, expr.tuple_path_ids[1], ir_arg.path_id) func_rvar = relctx.new_rel_rvar(ir_set, 
newctx.rel, ctx=ctx) relctx.include_rvar( ctx.rel, func_rvar, ir_set.path_id, aspects=aspects, ctx=ctx ) return new_stmt_set_rvar(ir_set, ctx.rel, aspects=aspects, ctx=ctx) @_special_case('std::enumerate') def process_set_as_enumerate( ir_set: irast.SetE[irast.Call], *, ctx: context.CompilerContextLevel ) -> SetRVars: expr = ir_set.expr arg_set = expr.args[0].expr arg_expr = arg_set.expr arg_subj = irutils.unwrap_set(arg_set).expr if ( isinstance(arg_subj, irast.FunctionCall) and not arg_subj.func_sql_expr and not ( isinstance(arg_expr, irast.SelectStmt) and ( arg_expr.where or arg_expr.orderby or arg_expr.limit or arg_expr.offset ) ) and not any( f_arg.param_typemod == qltypes.TypeModifier.SetOfType for _, f_arg in arg_subj.args.items() ) ): # Enumeration of a non-aggregate function rvars = process_set_as_func_enumerate(ir_set, ctx=ctx) else: rvars = process_set_as_simple_enumerate(ir_set, ctx=ctx) return rvars @_special_case('std::max', only_as_fallback=True) @_special_case('std::min', only_as_fallback=True) def process_set_as_std_min_max( ir_set: irast.SetE[irast.Call], *, ctx: context.CompilerContextLevel, ) -> SetRVars: # Postgres implements min/max aggregates for only a specific # subset of scalars and their respective arrays. However, in # EdgeDB every type is orderable (supports < and >) and so to # accommodate that we must choose between the native Postgres # aggregate and the generic fallback implementation (the native # implementation being faster). # # Since the fallback implementation is not mapped onto the same # polymorphic function in Postgres as the implementation for # supported types, we cannot rely on Postgres to always correctly # pick the polymorphic function to call, instead we use static # type inference to determine whether we'll delegate this to # Postgres (e.g. for anyreal) or we'll use the slower # one-size-fits-all fallback which then gets compiled differently. 
# In particular this means that when used inside a body of another # polymorphic (anytype) function, the slower generic version of # min/max will be used regardless of the actual concrete input # type. expr = ir_set.expr with ctx.subrel() as newctx: ir_arg = expr.args[0].expr dispatch.visit(ir_arg, ctx=newctx) arg_ref = pathctx.get_path_value_var( newctx.rel, ir_arg.path_id, env=newctx.env) arg_val = output.output_as_value(arg_ref, env=newctx.env) if newctx.rel.sort_clause is None: newctx.rel.sort_clause = [] newctx.rel.sort_clause.append( pgast.SortBy( node=arg_val, dir=( pgast.SortAsc if str(expr.func_shortname) == 'std::min' else pgast.SortDesc ), ), ) newctx.rel.limit_count = pgast.NumericConstant(val='1') pathctx.put_path_id_map(newctx.rel, ir_set.path_id, ir_arg.path_id) func_rvar = relctx.new_rel_rvar(ir_set, newctx.rel, ctx=ctx) relctx.include_rvar( ctx.rel, func_rvar, ir_set.path_id, pull_namespace=False, ctx=ctx ) return new_stmt_set_rvar(ir_set, ctx.rel, ctx=ctx) @simple_special_case('std::range') def process_set_as_std_range( expr: irast.FunctionCall, *, ctx: context.CompilerContextLevel, ) -> pgast.BaseExpr: # Generic range constructor implementation # # std::range( # lower, # upper, # named only inc_lower, # named only inc_upper, # named only empty, # ) # # into # # case when empty then # 'empty'::<pg_type> # else # <pg_type>( # lower, # upper, # (array[['()', '(]'], ['[)', '[]']]) # [inc_lower::int + 1][inc_upper::int + 1] # ) # end empty = dispatch.compile(expr.args['empty'].expr, ctx=ctx) inc_lower = dispatch.compile(expr.args['inc_lower'].expr, ctx=ctx) inc_upper = dispatch.compile(expr.args['inc_upper'].expr, ctx=ctx) lower = dispatch.compile(expr.args[0].expr, ctx=ctx) upper = dispatch.compile(expr.args[1].expr, ctx=ctx) lb = pgast.Index( idx=astutils.new_binop( lexpr=pgast.TypeCast( arg=inc_lower, type_name=pgast.TypeName(name=('int4',)), ), op='+', rexpr=pgast.NumericConstant(val='1'), ), ) rb = pgast.Index( idx=astutils.new_binop( lexpr=pgast.TypeCast( 
arg=inc_upper, type_name=pgast.TypeName(name=('int4',)), ), op='+', rexpr=pgast.NumericConstant(val='1'), ), ) bounds_matrix = pgast.ArrayExpr( elements=[ pgast.ArrayDimension( elements=[ pgast.StringConstant(val="()"), pgast.StringConstant(val="(]"), ], ), pgast.ArrayDimension( elements=[ pgast.StringConstant(val="[)"), pgast.StringConstant(val="[]"), ], ), ] ) bounds = pgast.Indirection(arg=bounds_matrix, indirection=[lb, rb]) pg_type = pg_types.pg_type_from_ir_typeref(expr.typeref) non_empty_range = pgast.FuncCall(name=pg_type, args=[lower, upper, bounds]) empty_range = pgast.TypeCast( arg=pgast.StringConstant(val='empty'), type_name=pgast.TypeName(name=pg_type), ) # If any of the non-optional arguments are nullable, add an explicit # null check for them. null_checks = [ pgast.NullTest(arg=e) for e in [empty, inc_upper, inc_lower] if e.nullable ] if null_checks: null_case = [ pgast.CaseWhen( expr=astutils.extend_binop(None, *null_checks, op='OR'), result=pgast.NullConstant(), ) ] else: null_case = [] set_expr = pgast.CaseExpr( args=[ *null_case, pgast.CaseWhen( expr=pgast.FuncCall( name=astutils.edgedb_func('range_validate', ctx=ctx), args=[lower, upper, inc_lower, inc_upper, empty], ), result=empty_range, ), ], defresult=non_empty_range, ) return set_expr @simple_special_case('std::_is_exclusive') def process_set_as_std_is_exclusive( expr: irast.FunctionCall, *, ctx: context.CompilerContextLevel, ) -> pgast.BaseExpr: # `std::_is_exclusive` is a helper function used in the implementation of # exclusive constraints. It is removed before (ir->sql) compilation and will # never be executed by the server. # # However, during the (ql->ir) compilation, an additional (ir->sql) # compilation takes place in order to catch any potential downstream errors, # such as set returning functions. # # This simple special case is used to prevent unwanted errors during this # additional (ir->sql) compilation. 
return pgast.BooleanConstant(val=False) @_special_case('std::multirange', only_as_fallback=True) def process_set_as_std_multirange( ir_set: irast.SetE[irast.Call], *, ctx: context.CompilerContextLevel, ) -> SetRVars: # Generic multirange constructor implementation # # std::multirange( # ranges: array<range<anypoint>>, # ) # # into # # <pg_type>(variadic ranges) expr = ir_set.expr ranges = dispatch.compile(expr.args[0].expr, ctx=ctx) pg_type = pg_types.pg_type_from_ir_typeref(expr.typeref) set_expr = pgast.FuncCall( name=pg_type, args=[pgast.VariadicArgument(expr=ranges)] ) pathctx.put_path_value_var(ctx.rel, ir_set.path_id, set_expr) return new_stmt_set_rvar(ir_set, ctx.rel, ctx=ctx) @register_get_rvar(irast.Call) def process_set_as_call( ir_set: irast.SetE[irast.Call], *, ctx: context.CompilerContextLevel ) -> SetRVars: fname = str(ir_set.expr.func_shortname) if (func := _SPECIAL_FUNCTIONS.get(fname)) and ( not func.only_as_fallback or ir_set.expr.func_sql_expr ): return func.func(ir_set, ctx=ctx) # Route simple special functions through expr compilation if fname in _SIMPLE_SPECIAL_FUNCTIONS: return process_set_as_expr(ir_set, ctx=ctx) if irutils.is_set_instance(ir_set, irast.OperatorCall): # Operator call return process_set_as_oper_expr(ir_set, ctx=ctx) assert irutils.is_set_instance(ir_set, irast.FunctionCall) if any( arg.param_typemod is qltypes.TypeModifier.SetOfType for key, arg in ir_set.expr.args.items() ): # Call to an aggregate function. return process_set_as_agg_expr(ir_set, ctx=ctx) # Regular function call. 
return process_set_as_func_expr(ir_set, ctx=ctx) @dataclasses.dataclass class _FuncWithOrdinalityInfo: rvar: pgast.BaseRangeVar colnames: list[str] inner_expr: pgast.OutputVar arg_is_tuple: bool nullable: bool def _process_typical_set_func_with_ordinality( ir_set: irast.Set, *, outer_func_set: irast.Set, func_name: tuple[str, ...], args: list[pgast.BaseExpr], ctx: context.CompilerContextLevel ) -> _FuncWithOrdinalityInfo: expr = ir_set.expr assert isinstance(expr, irast.FunctionCall) rtype = outer_func_set.typeref outer_func_expr = outer_func_set.expr assert isinstance(outer_func_expr, irast.FunctionCall) inner_rtype = ir_set.typeref coldeflist = [] arg_is_tuple = irtyputils.is_tuple(inner_rtype) if arg_is_tuple: subtypes = {} for i, st in enumerate(inner_rtype.subtypes): colname = st.element_name or f'_t{i + 1}' subtypes[colname] = st coldeflist.append( pgast.ColumnDef( name=colname, typename=pgast.TypeName( name=pg_types.pg_type_from_ir_typeref(st) ) ) ) colnames = list(subtypes) else: colnames = [ctx.env.aliases.get('v')] coldeflist = [] if (expr.sql_func_has_out_params or irtyputils.is_persistent_tuple(inner_rtype)): # SQL functions declared with OUT params reject column definitions. 
# Also persistent tuple types coldeflist = [] fexpr = pgast.FuncCall(name=func_name, args=args, coldeflist=coldeflist) colnames.append( rtype.subtypes[0].element_name or '_i' ) func_rvar = pgast.RangeFunction( alias=pgast.Alias( aliasname=ctx.env.aliases.get('f'), colnames=colnames), lateral=True, is_rowsfrom=True, with_ordinality=True, functions=[fexpr]) ctx.rel.from_clause.append(func_rvar) inner_expr: pgast.OutputVar if arg_is_tuple: inner_named_tuple = any(st.element_name for st in inner_rtype.subtypes) inner_expr = pgast.TupleVar( elements=[ pgast.TupleElement( path_id=outer_func_expr.tuple_path_ids[ len(rtype.subtypes) + i], name=n, val=astutils.get_column( func_rvar, n, nullable=fexpr.nullable) ) for i, n in enumerate(colnames[:-1]) ], named=inner_named_tuple, typeref=inner_rtype, ) else: inner_expr = astutils.get_column( func_rvar, colnames[0], nullable=fexpr.nullable) return _FuncWithOrdinalityInfo( rvar=func_rvar, colnames=colnames, inner_expr=inner_expr, arg_is_tuple=arg_is_tuple, nullable=bool(fexpr.nullable), ) def _process_nested_array_set_func_with_ordinality( ir_set: irast.Set, *, outer_func_set: irast.Set, args: list[pgast.BaseExpr], ctx: context.CompilerContextLevel ) -> _FuncWithOrdinalityInfo: # array<array<...>> is implemented as array<tuple<array<...>>> # # If we are unnesting with ordinality, the ordinality is paired with # the wrapping tuple. We need to unpack the tuple and re-pair the # ordinality with the array. 
expr = ir_set.expr assert isinstance(expr, irast.FunctionCall) outer_func_expr = outer_func_set.expr assert isinstance(outer_func_expr, irast.FunctionCall) inner_rtype = ir_set.typeref colnames = [ctx.env.aliases.get('v'), '_i'] alias = pgast.Alias( aliasname=ctx.env.aliases.get('f'), colnames=colnames, ) # Get (ordinal, tuple<array<...>>) # - unnests the outer array with ordinal # - selects the resulting ordinal and tuple<array<...>> ordinal_tuple_expr = pgast.SelectStmt( target_list=[ pgast.ResTarget( val=pgast.ColumnRef(name=[colname]), name=colname ) for colname in colnames ], from_clause=[ pgast.RangeFunction( alias=alias, functions=[ pgast.FuncCall( name=('unnest',), args=[args[0]], coldeflist=[ pgast.ColumnDef( name='0', typename=pgast.TypeName( name=pg_types.pg_type_from_ir_typeref( inner_rtype ) ) ) ] ) ], lateral=True, is_rowsfrom=True, with_ordinality=True, ) ] ) ordinal_tuple_rvar = relctx.rvar_for_rel( ordinal_tuple_expr, lateral=True, ctx=ctx, ) # Get (ordinal, array<...>) inner_array_expr = pgast.CoalesceExpr( args=[ astutils.get_column(ordinal_tuple_rvar, colnames[0]), pgast.TypeCast( arg=pgast.ArrayExpr(elements=[]), type_name=pgast.TypeName( name=pg_types.pg_type_from_ir_typeref( inner_rtype ) ), ), ] ) ordinal_expr = astutils.get_column(ordinal_tuple_rvar, colnames[-1]) ordinal_array_expr = pgast.SelectStmt( target_list=[ pgast.ResTarget( val=inner_array_expr, name=colnames[0], ), pgast.ResTarget( val=ordinal_expr, name=colnames[-1], ), ], from_clause=[ ordinal_tuple_rvar, ] ) func_rvar = relctx.rvar_for_rel( ordinal_array_expr, lateral=True, ctx=ctx, ) ctx.rel.from_clause.append(func_rvar) return _FuncWithOrdinalityInfo( rvar=func_rvar, colnames=colnames, inner_expr=astutils.get_column(func_rvar, colnames[0]), arg_is_tuple=False, nullable=True, ) def _process_set_func_with_ordinality( ir_set: irast.Set, *, outer_func_set: irast.Set, func_name: tuple[str, ...], args: list[pgast.BaseExpr], ctx: context.CompilerContextLevel ) -> pgast.BaseExpr: expr = ir_set.expr assert 
isinstance(expr, irast.FunctionCall) rtype = outer_func_set.typeref outer_func_expr = outer_func_set.expr assert isinstance(outer_func_expr, irast.FunctionCall) func_info: _FuncWithOrdinalityInfo if ( func_name == ('unnest',) and irtyputils.is_array(ir_set.typeref) ): func_info = _process_nested_array_set_func_with_ordinality( ir_set, outer_func_set=outer_func_set, args=args, ctx=ctx, ) else: func_info = _process_typical_set_func_with_ordinality( ir_set, outer_func_set=outer_func_set, func_name=func_name, args=args, ctx=ctx, ) named_tuple = any(st.element_name for st in rtype.subtypes) set_expr = pgast.TupleVar( elements=[ pgast.TupleElement( path_id=outer_func_expr.tuple_path_ids[0], name=func_info.colnames[0], val=pgast.Expr( name='-', lexpr=astutils.get_column( func_info.rvar, func_info.colnames[-1], nullable=func_info.nullable, ), rexpr=pgast.NumericConstant(val='1') ) ), pgast.TupleElement( path_id=outer_func_expr.tuple_path_ids[1], name=rtype.subtypes[1].element_name or '1', val=func_info.inner_expr, ), ], named=named_tuple, typeref=outer_func_set.typeref, ) for element in set_expr.elements: pathctx.put_path_value_var(ctx.rel, element.path_id, element.val) if func_info.arg_is_tuple: arg_tuple = set_expr.elements[1].val assert isinstance(arg_tuple, pgast.TupleVar) for element in arg_tuple.elements: pathctx.put_path_value_var(ctx.rel, element.path_id, element.val) # If there is a shape specified on the argument to enumerate, we need # to compile it here manually, since we are skipping the normal # code path for it. 
if (output.in_serialization_ctx(ctx) and ir_set.shape and not ctx.env.ignore_object_shapes): ensure_source_rvar(ir_set, ctx.rel, ctx=ctx) exprcomp._compile_shape(ir_set, ir_set.shape, ctx=ctx) var = pathctx.maybe_get_path_var( ctx.rel, ir_set.path_id, aspect=pgce.PathAspect.SERIALIZED, env=ctx.env, ) if var is not None: pathctx.put_path_var( ctx.rel, set_expr.elements[1].path_id, var, aspect=pgce.PathAspect.SERIALIZED, ) return set_expr def _process_set_func( ir_set: irast.Set, *, func_name: tuple[str, ...], args: list[pgast.BaseExpr], ctx: context.CompilerContextLevel, ) -> pgast.BaseExpr: expr = ir_set.expr assert isinstance(expr, irast.FunctionCall) rtype = expr.typeref named_tuple = any(st.element_name for st in rtype.subtypes) coldeflist = [] is_tuple = irtyputils.is_tuple(rtype) if is_tuple: subtypes = {} for i, st in enumerate(rtype.subtypes): colname = st.element_name or str(i) subtypes[colname] = st coldeflist.append( pgast.ColumnDef( name=colname, typename=pgast.TypeName( name=pg_types.pg_type_from_ir_typeref(st) ) ) ) colnames = list(subtypes) else: colnames = [ctx.env.aliases.get('v')] if _should_unwrap_polymorphic_return_array(expr): # If we are unwrapping a previously nested array, its pg type is # record and so we need to provide a column definition list. coldeflist = [ pgast.ColumnDef( name='v', typename=pgast.TypeName( name=pg_types.pg_type_from_ir_typeref(expr.typeref) ) ) ] if ( # SQL functions declared with OUT params or returning # named composite types reject column definitions. 
irtyputils.is_persistent_tuple(rtype) or expr.sql_func_has_out_params ): coldeflist = [] fexpr = pgast.FuncCall(name=func_name, args=args, coldeflist=coldeflist) func_rvar = pgast.RangeFunction( alias=pgast.Alias( aliasname=ctx.env.aliases.get('f'), colnames=colnames), lateral=True, is_rowsfrom=True, functions=[fexpr]) ctx.rel.from_clause.append(func_rvar) set_expr: pgast.BaseExpr if not is_tuple: set_expr = astutils.get_column( func_rvar, colnames[0], nullable=fexpr.nullable) else: set_expr = pgast.TupleVar( elements=[ pgast.TupleElement( path_id=expr.tuple_path_ids[i], name=n, val=astutils.get_column( func_rvar, n, nullable=fexpr.nullable) ) for i, n in enumerate(colnames) ], named=named_tuple, typeref=rtype, ) for element in set_expr.elements: pathctx.put_path_value_var_if_not_exists( ctx.rel, element.path_id, element.val ) return set_expr def _compile_func_epilogue( ir_set: irast.Set, *, set_expr: pgast.BaseExpr, func_rel: pgast.SelectStmt, ctx: context.CompilerContextLevel) -> SetRVars: expr = ir_set.expr assert isinstance(expr, irast.FunctionCall) if expr.volatility.is_volatile(): relctx.apply_volatility_ref(func_rel, ctx=ctx) pathctx.put_path_var_if_not_exists( func_rel, ir_set.path_id, set_expr, aspect=pgce.PathAspect.VALUE, ) aspects: tuple[pgce.PathAspect, ...] if expr.body: # For inlined functions, we want all of the aspects provided. aspects = tuple(pathctx.list_path_aspects(func_rel, ir_set.path_id)) else: # Otherwise we just know we have value. 
aspects = (pgce.PathAspect.VALUE,) func_rvar = relctx.new_rel_rvar(ir_set, func_rel, ctx=ctx) relctx.include_rvar( ctx.rel, func_rvar, ir_set.path_id, pull_namespace=False, aspects=aspects, ctx=ctx, ) if (ir_set.path_id.is_tuple_path() and expr.typemod is qltypes.TypeModifier.SetOfType): # Functions returning a set of tuples are compiled with an # explicit coldeflist, so the result is represented as a # TupleVar as opposed to an opaque record datum, so # we can access the elements directly without using # `tuple_getattr()`. aspects += (pgce.PathAspect.SOURCE,) return new_stmt_set_rvar(ir_set, ctx.rel, aspects=aspects, ctx=ctx) def _needs_arg_null_check( call_expr: irast.Call, ir_arg: irast.CallArg, typemod: qltypes.TypeModifier, *, ctx: context.CompilerContextLevel ) -> bool: return ( not call_expr.impl_is_strict and not ir_arg.is_default and ( ( typemod == qltypes.TypeModifier.SingletonType and ir_arg.cardinality.can_be_zero() ) or typemod == qltypes.TypeModifier.SetOfType ) ) def _compile_arg_null_check( call_expr: irast.Call, ir_arg: irast.CallArg, arg_ref: pgast.BaseExpr, typemod: qltypes.TypeModifier, *, ctx: context.CompilerContextLevel ) -> None: if ( _needs_arg_null_check(call_expr, ir_arg, typemod, ctx=ctx) and arg_ref.nullable ): ctx.rel.where_clause = astutils.extend_binop( ctx.rel.where_clause, pgast.NullTest(arg=arg_ref, negated=True) ) # Polymorphic calls need special handling for nested arrays since # array<array<...>> is implemented as array<tuple<array<...>>>. # # Currently 2 cases are handled: # A) At least 1 arg type is `anyarray` # B) Return type is `anyarray` # # In both cases, any simple polymorphic types will be added into an array in # the result. If this parameter is also an array, then it needs to be wrapped # in a tuple. # # In case A and if the return type is anytype, the call is returning the # contents of the array. The result needs to be unwrapped to get the actual # array. 
def _has_polymorphic_array_arg( expr: irast.Call ) -> bool: return any( ir_arg.polymorphism == qltypes.Polymorphism.Array for ir_arg in expr.args.values() ) def _should_wrap_polymorphic_array_args( expr: irast.Call ) -> bool: return ( _has_polymorphic_array_arg(expr) or expr.return_polymorphism == qltypes.Polymorphism.Array ) def _is_array_arg_as_simple_polymorphic( arg: irast.CallArg ) -> bool: return ( arg.polymorphism == qltypes.Polymorphism.Simple and irtyputils.is_array(arg.expr.typeref) ) def _should_unwrap_polymorphic_return_array( expr: irast.Call ) -> bool: return ( _has_polymorphic_array_arg(expr) and expr.return_polymorphism == qltypes.Polymorphism.Simple and irtyputils.is_array(expr.typeref) ) def _compile_call_args( ir_set: irast.SetE[irast.Call], *, skip: Collection[int] = (), no_subquery_args: bool = False, ctx: context.CompilerContextLevel, ) -> list[pgast.BaseExpr]: """ Compiles function call arguments, whose index is not in `skip`. """ expr = ir_set.expr args = [] if isinstance(expr, irast.FunctionCall) and expr.global_args: for glob_arg in expr.global_args: arg_ref = dispatch.compile(glob_arg, ctx=ctx) args.append(output.output_as_value(arg_ref, env=ctx.env)) for ir_key, ir_arg in expr.args.items(): if ir_key in skip: continue assert ir_arg.multiplicity != qltypes.Multiplicity.UNKNOWN typemod = ir_arg.param_typemod # Support a mode where we try to compile arguments as pure # subqueries. This is occasionally valuable as it lets us # "push down" the subqueries from the top level, which is # important for things like hitting pgvector indexes in an # ORDER BY. 
arg_typeref = ir_arg.expr.typeref make_subquery = ( expr.prefer_subquery_args and typemod != qltypes.TypeModifier.SetOfType and ir_arg.cardinality.is_single() and (arg_typeref.is_scalar or arg_typeref.collection) and not _needs_arg_null_check(expr, ir_arg, typemod, ctx=ctx) and not no_subquery_args ) if make_subquery: arg_ref = set_as_subquery(ir_arg.expr, as_value=True, ctx=ctx) arg_ref.nullable = ir_arg.cardinality.can_be_zero() arg_ref = astutils.collapse_query(arg_ref) else: arg_ref = dispatch.compile(ir_arg.expr, ctx=ctx) arg_ref = output.output_as_value(arg_ref, env=ctx.env) if ( _should_wrap_polymorphic_array_args(expr) and _is_array_arg_as_simple_polymorphic(ir_arg) ): arg_ref = pgast.RowExpr(args=[arg_ref]) args.append(arg_ref) _compile_arg_null_check(expr, ir_arg, arg_ref, typemod, ctx=ctx) if ( isinstance(expr, irast.FunctionCall) and ir_arg.expr_type_path_id is not None ): # Object type arguments are represented by two # SQL arguments: object id and object type id. # The latter is needed for proper overload # dispatch. ensure_source_rvar(ir_arg.expr, ctx.rel, ctx=ctx) type_ref = relctx.get_path_var( ctx.rel, ir_arg.expr_type_path_id, aspect=pgce.PathAspect.IDENTITY, ctx=ctx, ) args.append(type_ref) if ( isinstance(expr, irast.FunctionCall) and expr.has_empty_variadic and expr.variadic_param_type is not None ): var = pgast.TypeCast( arg=pgast.ArrayExpr(elements=[]), type_name=pgast.TypeName( name=pg_types.pg_type_from_ir_typeref( expr.variadic_param_type) ) ) args.append(pgast.VariadicArgument(expr=var)) return args def _compile_inlined_call_args( ir_set: irast.SetE[irast.FunctionCall], *, ctx: context.CompilerContextLevel ) -> None: expr = ir_set.expr assert expr.body is not None if irutils.contains_dml(expr.body): last_iterator = ctx.enclosing_cte_iterator # If this function call has already been compiled to a CTE, don't # recompile the arguments. # (This will happen when a DML-containing function in a FOR loop is # WITH bound, for example.) 
if ir_set.path_id in ctx.inline_dml_ctes: args_pathid, arg_cte = ctx.inline_dml_ctes[ir_set.path_id] else: # Compile args into an iterator CTE with ctx.newrel() as arg_ctx: dml.merge_iterator(last_iterator, arg_ctx.rel, ctx=arg_ctx) clauses.setup_iterator_volatility(last_iterator, ctx=arg_ctx) _compile_call_args(ir_set, ctx=arg_ctx) # Add iterator identity args_pathid = irast.PathId.new_dummy( ctx.env.aliases.get('args') ) with arg_ctx.subrel() as args_pathid_ctx: relctx.create_iterator_identity_for_path( args_pathid, args_pathid_ctx.rel, ctx=args_pathid_ctx ) args_id_rvar = relctx.rvar_for_rel( args_pathid_ctx.rel, lateral=True, ctx=arg_ctx ) relctx.include_rvar( arg_ctx.rel, args_id_rvar, path_id=args_pathid, ctx=arg_ctx ) for ir_arg in expr.args.values(): arg_path_id = ir_arg.expr.path_id # Ensure args appear in arg CTE pathctx.get_path_output( arg_ctx.rel, arg_path_id, aspect=pgce.PathAspect.VALUE, env=arg_ctx.env, ) pathctx.put_path_bond( arg_ctx.rel, arg_path_id, iterator=True ) arg_cte = pgast.CommonTableExpr( name=ctx.env.aliases.get('args'), query=arg_ctx.rel, materialized=False, ) ctx.toplevel_stmt.append_cte(arg_cte) ctx.inline_dml_ctes[ir_set.path_id] = (args_pathid, arg_cte) arg_iterator = pgast.IteratorCTE( path_id=args_pathid, cte=arg_cte, parent=last_iterator, other_paths=tuple( (ir_arg.expr.path_id, pgce.PathAspect.VALUE) for ir_arg in expr.args.values() ), iterator_bond=True, ) # Merge the new iterator ctx.path_scope = ctx.path_scope.new_child() arg_rvar = not_none( dml.merge_iterator(arg_iterator, ctx.rel, ctx=ctx) ) clauses.setup_iterator_volatility(arg_iterator, ctx=ctx) ctx.enclosing_cte_iterator = arg_iterator else: with ctx.subrel() as arg_ctx: # Compile the call args but don't do anything with the resulting # exprs. Their rvar will be found when compiling the inlined body. 
_compile_call_args(ir_set, ctx=arg_ctx) arg_rvar = relctx.rvar_for_rel(arg_ctx.rel, ctx=ctx) for ir_arg in expr.args.values(): arg_path_id = ir_arg.expr.path_id relctx.include_rvar( ctx.rel, arg_rvar, arg_path_id, ctx=ctx, ) for ir_arg in expr.args.values(): arg_path_id = ir_arg.expr.path_id if arg_scope_stmt := relctx.maybe_get_scope_stmt( arg_path_id, ctx=ctx ): # The args are joined to ctx.rel, but other sets may # look for it in the scope statement. Make sure it's # available. pathctx.put_path_value_rvar( arg_scope_stmt, arg_path_id, arg_rvar, ) def process_set_as_func_enumerate( ir_set: irast.SetE[irast.Call], *, ctx: context.CompilerContextLevel ) -> SetRVars: expr = ir_set.expr inner_func_set = irutils.unwrap_set(expr.args[0].expr) assert irutils.is_set_instance(inner_func_set, irast.FunctionCall) inner_func = inner_func_set.expr with ctx.subrel() as newctx: with newctx.new() as newctx2: newctx2.expr_exposed = False args = _compile_call_args(inner_func_set, ctx=newctx2) func_name = exprcomp.get_func_call_backend_name(inner_func, ctx=newctx) set_expr = _process_set_func_with_ordinality( ir_set=inner_func_set, outer_func_set=ir_set, func_name=func_name, args=args, ctx=newctx) func_rel = newctx.rel return _compile_func_epilogue( ir_set, set_expr=set_expr, func_rel=func_rel, ctx=ctx ) def process_set_as_func_expr( ir_set: irast.SetE[irast.FunctionCall], *, ctx: context.CompilerContextLevel ) -> SetRVars: expr = ir_set.expr with ctx.subrel() as newctx: newctx.expr_exposed = False if expr.body is not None: _compile_inlined_call_args(ir_set, ctx=newctx) set_expr = dispatch.compile(expr.body, ctx=newctx) # Map the path id so that we can extract source aspects # from it, which we want so that we can directly select # from an INSERT instead of using overlays. 
pathctx.put_path_id_map( newctx.rel, ir_set.path_id, expr.body.path_id ) if _should_unwrap_polymorphic_return_array(expr): set_expr = astutils.array_get_inner_array( set_expr, expr.typeref ) else: args = _compile_call_args(ir_set, ctx=newctx) name = exprcomp.get_func_call_backend_name(expr, ctx=newctx) if expr.typemod is qltypes.TypeModifier.SetOfType: set_expr = _process_set_func( ir_set, func_name=name, args=args, ctx=newctx, ) else: set_expr = pgast.FuncCall(name=name, args=args) if _should_unwrap_polymorphic_return_array(expr): set_expr = astutils.array_get_inner_array( set_expr, expr.typeref ) if expr.error_on_null_result: set_expr = pgast.FuncCall( name=astutils.edgedb_func('raise_on_null', ctx=ctx), args=[ set_expr, pgast.StringConstant( val='invalid_parameter_value', ), pgast.StringConstant( val=expr.error_on_null_result, ), pgast.StringConstant( val=irutils.get_span_as_json( expr, errors.InvalidValueError), ), ] ) if expr.force_return_cast: # The underlying function has a return value type # different from that of the EdgeQL function declaration, # so we need to make an explicit cast here. 
set_expr = pgast.TypeCast( arg=set_expr, type_name=pgast.TypeName( name=pg_types.pg_type_from_ir_typeref(expr.typeref) ) ) func_rel = newctx.rel return _compile_func_epilogue( ir_set, set_expr=set_expr, func_rel=func_rel, ctx=ctx ) def process_set_as_agg_expr_inner( ir_set: irast.SetE[irast.FunctionCall], *, aspect: pgce.PathAspect, wrapper: Optional[pgast.SelectStmt], for_group_by: bool = False, ctx: context.CompilerContextLevel, ) -> SetRVars: expr = ir_set.expr assert isinstance(expr, irast.FunctionCall) stmt = ctx.rel set_expr: pgast.BaseExpr with ctx.newscope() as newctx: agg_filter = None agg_sort = [] with newctx.new() as argctx: # We want array_agg() (and similar) to do the right # thing with respect to output format, so, barring # the (unacceptable) hardcoding of function names, # check if the aggregate accepts a single argument # of "any" to determine serialized input safety. serialization_safe = ( expr.func_polymorphic and aspect == pgce.PathAspect.SERIALIZED ) if not serialization_safe: argctx.expr_exposed = False args = [] for ir_call_key, ir_call_arg in expr.args.items(): ir_arg = ir_call_arg.expr arg_ref: pgast.BaseExpr if for_group_by: arg_ref = set_as_subquery( ir_arg, as_value=True, ctx=argctx) arg_ref.nullable = False elif aspect == pgce.PathAspect.SERIALIZED: dispatch.visit(ir_arg, ctx=argctx) arg_ref = pathctx.get_path_serialized_or_value_var( argctx.rel, ir_arg.path_id, env=argctx.env) if isinstance(arg_ref, pgast.TupleVar): arg_ref = output.serialize_expr( arg_ref, path_id=ir_arg.path_id, env=argctx.env) else: dispatch.visit(ir_arg, ctx=argctx) arg_ref = pathctx.get_path_value_var( argctx.rel, ir_arg.path_id, env=argctx.env) if isinstance(arg_ref, pgast.TupleVar): arg_ref = output.output_as_value( arg_ref, env=argctx.env) _compile_arg_null_check( expr, ir_call_arg, arg_ref, ir_call_arg.param_typemod, ctx=argctx ) path_scope = relctx.get_scope(ir_arg, ctx=argctx) if path_scope is not None and path_scope.parent is not None: arg_is_visible = 
path_scope.parent.is_any_prefix_visible( ir_arg.path_id) else: arg_is_visible = False if arg_is_visible: # If the argument set is visible above us, we # are aggregating a singleton set, potentially on # the same query level, as the source set. # Postgres doesn't like aggregates on the same query # level, so wrap the arg ref into a VALUES range. wrapper = pgast.SelectStmt( values=[pgast.ImplicitRowExpr(args=[arg_ref])] ) colname = argctx.env.aliases.get('a') wrapper_rvar = relctx.rvar_for_rel( wrapper, lateral=True, colnames=[colname], ctx=argctx) relctx.include_rvar( ctx.rel, wrapper_rvar, path_id=ir_arg.path_id, ctx=argctx, ) arg_ref = astutils.get_column(wrapper_rvar, colname) if ir_call_key == 0 and irutils.is_subquery_set(ir_arg): # If the first argument of the aggregate # is a SELECT or GROUP with an ORDER BY clause, # we move the ordering conditions to the aggregate # call to make sure the ordering is as expected. substmt = ir_arg.expr if isinstance(substmt, irast.GroupStmt): substmt = substmt.result.expr if (isinstance(substmt, irast.SelectStmt) and substmt.orderby): qrvar = pathctx.get_path_rvar( ctx.rel, ir_arg.path_id, aspect=pgce.PathAspect.VALUE, ) query = qrvar.query assert isinstance(query, pgast.SelectStmt) for i, sortref in enumerate(query.sort_clause or ()): alias = argctx.env.aliases.get(f's{i}') query.target_list.append( pgast.ResTarget( val=sortref.node, name=alias ) ) agg_sort.append( pgast.SortBy( node=astutils.get_column(qrvar, alias), dir=sortref.dir, nulls=sortref.nulls)) query.sort_clause = [] if ( _should_wrap_polymorphic_array_args(expr) and _is_array_arg_as_simple_polymorphic(ir_call_arg) ): # Wrap aggregated arrays in a tuple arg_ref = pgast.RowExpr(args=[arg_ref]) if aspect == pgce.PathAspect.SERIALIZED: arg_ref = output.serialize_expr( arg_ref, path_id=ir_arg.path_id, env=argctx.env) args.append(arg_ref) name = exprcomp.get_func_call_backend_name(expr, ctx=newctx) set_expr = pgast.FuncCall( name=name, args=args, agg_order=agg_sort, 
agg_filter=agg_filter, ser_safe=serialization_safe and all(x.ser_safe for x in args)) if _should_unwrap_polymorphic_return_array(expr): set_expr = astutils.array_get_inner_array( set_expr, expr.typeref ) if for_group_by and not expr.impl_is_strict: # If we are doing this for a GROUP BY, and the function is not # strict in its arguments, we are in trouble! # The problem is that we don't have a way to filter the NULLs # out in the subquery in general. The value could be # computed *inside* the subquery, so we can't use an agg_filter, # and we can't filter it inside the subquery because it gets # executed separately for each row and collapses to NULL when # it is empty! # Fortunately I think that only array_agg has this property, # so we can just handle that by popping the NULLs out. # If other cases turn up, we could handle it by falling # back to aggregate grouping. # TODO: only do this when there might really be a null? assert str(expr.func_shortname) == 'std::array_agg' set_expr = pgast.FuncCall( name=('array_remove',), args=[set_expr, pgast.NullConstant()] ) if expr.error_on_null_result: set_expr = pgast.FuncCall( name=astutils.edgedb_func('raise_on_null', ctx=ctx), args=[ set_expr, pgast.StringConstant( val='invalid_parameter_value', ), pgast.StringConstant( val=expr.error_on_null_result, ), pgast.StringConstant( val=irutils.get_span_as_json( expr, errors.InvalidValueError), ), ] ) if expr.force_return_cast: # The underlying function has a return value type # different from that of the EdgeQL function declaration, # so we need to make an explicit cast here. set_expr = pgast.TypeCast( arg=set_expr, type_name=pgast.TypeName( name=pg_types.pg_type_from_ir_typeref(expr.typeref) ) ) if expr.func_initial_value is not None and wrapper: iv_ir = expr.func_initial_value.expr assert iv_ir is not None if serialization_safe and aspect == pgce.PathAspect.SERIALIZED: # Serialization has changed the output type. 
with ctx.new() as ivctx: iv = dispatch.compile(iv_ir, ctx=ivctx) iv = output.serialize_expr_if_needed( iv, path_id=ir_set.path_id, ctx=ctx) set_expr = output.serialize_expr_if_needed( set_expr, path_id=ir_set.path_id, ctx=ctx) else: with ctx.new() as ivctx: iv = dispatch.compile(iv_ir, ctx=ivctx) pathctx.put_path_var(stmt, ir_set.path_id, set_expr, aspect=aspect) out = pathctx.get_path_output( stmt, ir_set.path_id, aspect=aspect, env=ctx.env ) assert isinstance(out, pgast.ColumnRef) # HACK: We select join in the inner statement instead of just # using it as a subquery to work around a postgres bug that # occurs when something defined with a subquery is used as an # argument to `grouping`. See #3844. stmt_rvar = relctx.rvar_for_rel(stmt, ctx=ctx) wrapper.from_clause.append(stmt_rvar) val = astutils.get_column(stmt_rvar, out) assert wrapper set_expr = pgast.CoalesceExpr( args=[val, iv], ser_safe=serialization_safe) pathctx.put_path_var(wrapper, ir_set.path_id, set_expr, aspect=aspect) stmt = wrapper pathctx.put_path_var_if_not_exists( stmt, ir_set.path_id, set_expr, aspect=aspect ) # Cheat a little bit: as discussed above, pretend the serialized # value is also really a value. Eta-expansion should ensure this # only happens when we don't really need the value again. if aspect == pgce.PathAspect.SERIALIZED: pathctx.put_path_var_if_not_exists( stmt, ir_set.path_id, set_expr, aspect=pgce.PathAspect.VALUE ) return new_stmt_set_rvar(ir_set, stmt, ctx=ctx) def process_set_as_agg_expr( ir_set: irast.SetE[irast.FunctionCall], *, ctx: context.CompilerContextLevel ) -> SetRVars: # If the func has an initial val, we need to do the interesting # work in subrels and provide a wrapper to put the coalesces in wrapper = None if ir_set.expr.func_initial_value is not None: wrapper = ctx.rel # In a serialization context that produces something containing an object, # we produce *only* a serialized value, and we claim it is the value too. 
# For this to be correct, we need to only have serialized agg expr results # in cases where value can't be used anymore. Our eta-expansion pass # makes sure this happens. # (... the only such *function* currently is array_agg.) # Though if the result type contains no objects, the value should be good # enough, so don't generate a bunch of unnecessary code to produce # a serialized value when we can use value. serialized = ( output.in_serialization_ctx(ctx=ctx) and irtyputils.contains_object(ir_set.typeref) ) cctx = ctx.subrel() if wrapper else ctx.new() with cctx as xctx: xctx.expr_exposed = serialized aspect = ( pgce.PathAspect.SERIALIZED if serialized else pgce.PathAspect.VALUE ) process_set_as_agg_expr_inner( ir_set, aspect=aspect, wrapper=wrapper, ctx=xctx ) return new_stmt_set_rvar(ir_set, ctx.rel, ctx=ctx) @_special_case('std::EXISTS') def process_set_as_exists_expr( ir_set: irast.SetE[irast.Call], *, ctx: context.CompilerContextLevel ) -> SetRVars: expr = ir_set.expr with ctx.subrel() as subctx: wrapper = subctx.rel subctx.expr_exposed = False ir_expr = expr.args[0].expr set_ref = dispatch.compile(ir_expr, ctx=subctx) pathctx.put_path_value_var(wrapper, ir_set.path_id, set_ref) pathctx.get_path_value_output(wrapper, ir_set.path_id, env=ctx.env) wrapper.where_clause = astutils.extend_binop( wrapper.where_clause, pgast.NullTest(arg=set_ref, negated=True)) set_expr = pgast.SubLink(operator="EXISTS", expr=wrapper) pathctx.put_path_value_var(ctx.rel, ir_set.path_id, set_expr) return new_stmt_set_rvar(ir_set, ctx.rel, ctx=ctx) @_special_case('std::json_object_pack') def process_set_as_json_object_pack( ir_set: irast.SetE[irast.Call], *, ctx: context.CompilerContextLevel ) -> SetRVars: ir_arg = ir_set.expr.args[0].expr # compile IR to pg AST dispatch.visit(ir_arg, ctx=ctx) arg_val = pathctx.get_path_value_var(ctx.rel, ir_arg.path_id, env=ctx.env) # get first and the second fields of the tuple if isinstance(arg_val, pgast.TupleVar): keys = arg_val.elements[0].val 
values = arg_val.elements[1].val else: keys = astutils.tuple_getattr(arg_val, ir_arg.typeref, "0") values = astutils.tuple_getattr(arg_val, ir_arg.typeref, "1") # construct the function call set_expr = pgast.FuncCall( name=("coalesce",), args=[ pgast.FuncCall(name=("jsonb_object_agg",), args=[keys, values]), pgast.TypeCast( arg=pgast.StringConstant(val="{}"), type_name=pgast.TypeName(name=('jsonb',)), ), ], ) # declare that the 'aspect=value' of ir_set (original set) # can be found in ctx.rel, by using set_expr pathctx.put_path_value_var_if_not_exists(ctx.rel, ir_set.path_id, set_expr) # return subquery as set_rvar return new_stmt_set_rvar(ir_set, ctx.rel, ctx=ctx) def build_array_expr( ir_expr: irast.Base, elements: list[pgast.BaseExpr], *, ctx: context.CompilerContextLevel) -> pgast.BaseExpr: array = astutils.safe_array_expr(elements, ctx=ctx) if irutils.is_empty_array_expr(ir_expr): assert isinstance(ir_expr, irast.Array) typeref = ir_expr.typeref if irtyputils.is_any(typeref.subtypes[0]): # The type of the input is not determined, which means that # the result of this expression is passed as an argument # to a generic function, e.g. `count(array_agg({}))`. In this # case, amend the array type to a concrete type, # since Postgres balks at `[]::anyarray`. pg_type: tuple[str, ...] 
= ('text[]',) else: serialized = output.in_serialization_ctx(ctx=ctx) pg_type = pg_types.pg_type_from_ir_typeref( typeref, serialized=serialized) return pgast.TypeCast( arg=array, type_name=pgast.TypeName( name=pg_type, ), ) else: return array @register_get_rvar(irast.Array) def process_set_as_array_expr( ir_set: irast.SetE[irast.Array], *, ctx: context.CompilerContextLevel ) -> SetRVars: expr = ir_set.expr elements = [] s_elements = [] serializing = ( output.in_serialization_ctx(ctx=ctx) and irtyputils.contains_object(ir_set.typeref) ) for ir_element in expr.elements: element = dispatch.compile(ir_element, ctx=ctx) if irtyputils.is_array(ir_element.typeref): # Wrap nested arrays in a tuple element = pgast.RowExpr(args=[element]) elements.append(element) if serializing: s_var: Optional[pgast.BaseExpr] s_var = pathctx.maybe_get_path_serialized_var( ctx.rel, ir_element.path_id, env=ctx.env ) if s_var is None: v_var = pathctx.get_path_value_var( ctx.rel, ir_element.path_id, env=ctx.env ) s_var = output.serialize_expr( v_var, path_id=ir_element.path_id, env=ctx.env) elif isinstance(s_var, pgast.TupleVar): s_var = output.serialize_expr( s_var, path_id=ir_element.path_id, env=ctx.env) s_elements.append(s_var) if serializing: set_expr = astutils.safe_array_expr( s_elements, ser_safe=all(x.ser_safe for x in s_elements), ctx=ctx) if irutils.is_empty_array_expr(expr): set_expr = pgast.TypeCast( arg=set_expr, type_name=pgast.TypeName( name=pg_types.pg_type_from_ir_typeref(expr.typeref) ) ) pathctx.put_path_serialized_var(ctx.rel, ir_set.path_id, set_expr) else: set_expr = build_array_expr(expr, elements, ctx=ctx) pathctx.put_path_value_var_if_not_exists(ctx.rel, ir_set.path_id, set_expr) return new_stmt_set_rvar(ir_set, ctx.rel, ctx=ctx) def process_encoded_param( param: irast.Param, *, ctx: context.CompilerContextLevel) -> pgast.BaseExpr: assert param.sub_params decoder = param.sub_params.decoder_ir assert decoder if (param_cte := ctx.param_ctes.get(param.name)) is None: 
with ctx.newrel() as sctx: sctx.pending_query = sctx.rel sctx.rel_overlays = context.RelOverlays() arg_ref = dispatch.compile(decoder, ctx=sctx) # Force it into a real tuple so we can just always grab it # from a subquery below. arg_val = output.output_as_value(arg_ref, env=sctx.env) pathctx.put_path_value_var( sctx.rel, decoder.path_id, arg_val, force=True ) param_cte = pgast.CommonTableExpr( name=ctx.env.aliases.get('p'), query=sctx.rel, materialized=False, ) ctx.param_ctes[param.name] = param_cte with ctx.subrel() as sctx: cte_rvar = pgast.RelRangeVar( relation=param_cte, typeref=decoder.typeref, alias=pgast.Alias(aliasname=ctx.env.aliases.get('t')) ) relctx.include_rvar( sctx.rel, cte_rvar, decoder.path_id, pull_namespace=False, aspects=(pgce.PathAspect.VALUE,), ctx=sctx, ) pathctx.get_path_value_output(sctx.rel, decoder.path_id, env=ctx.env) if not param.required: sctx.rel.nullable = True return sctx.rel _ObjectSearchInnerCallback = Callable[ [ irast.Call, irast.PathId, list[pgast.BaseExpr], context.CompilerContextLevel, context.CompilerContextLevel, context.CompilerContextLevel, ], tuple[pgast.BaseExpr, Optional[pgast.BaseExpr]], ] @_special_case('std::fts::search') def process_set_as_fts_search( ir_set: irast.SetE[irast.Call], *, ctx: context.CompilerContextLevel ) -> SetRVars: from edb.common import debug cb: _ObjectSearchInnerCallback if debug.flags.zombodb: cb = _fts_search_inner_zombo else: cb = _fts_search_inner_pg return _process_set_as_object_search( ir_set, inner_cb=cb, ctx=ctx) @_special_case('ext::ai::search') def process_set_as_ext_ai_search( ir_set: irast.SetE[irast.Call], *, ctx: context.CompilerContextLevel ) -> SetRVars: cb = _ext_ai_search_inner_pgvector return _process_set_as_object_search( ir_set, inner_cb=cb, ctx=ctx) def _ext_ai_search_inner_pgvector( call: irast.Call, obj_id: irast.PathId, args_pg: list[pgast.BaseExpr], _ctx: context.CompilerContextLevel, newctx: context.CompilerContextLevel, _inner_ctx: context.CompilerContextLevel, ) 
-> tuple[pgast.BaseExpr, Optional[pgast.BaseExpr]]: assert isinstance(call, irast.FunctionCall) if call.extras is None: raise AssertionError( "missing expected index metadata in FunctionCall.extras") index_metadata = call.extras.get("index_metadata") if index_metadata is None: raise AssertionError( "missing expected index metadata in FunctionCall.extras") tgt = obj_id.target if tgt.material_type is not None: tgt = tgt.material_type target_index_metadata = index_metadata.get(tgt) if target_index_metadata is None: raise AssertionError( "missing expected index metadata in FunctionCall.extras") index_id = target_index_metadata.get("id") if index_id is None: raise AssertionError( "missing expected index metadata in FunctionCall.extras") dimensions = target_index_metadata.get("dimensions") if dimensions is None: raise AssertionError( "missing expected index metadata in FunctionCall.extras") df = target_index_metadata.get("distance_function") if df is None: raise AssertionError( "missing expected index metadata in FunctionCall.extras") query, = args_pg el_name = sn.QualName( '__object__', f'__ext_ai_{index_id}_embedding__', ) embedding_ptrref = irast.SpecialPointerRef( name=el_name, shortname=el_name, out_source=obj_id.target, out_target=pg_types.pg_tsvector_typeref, out_cardinality=qltypes.Cardinality.AT_MOST_ONE, ) embedding_id = obj_id.extend(ptrref=embedding_ptrref) embedding = relctx.get_path_var( newctx.rel, embedding_id, aspect=pgce.PathAspect.VALUE, ctx=newctx, ) similarity = pgast.FuncCall( name=common.get_function_backend_name(*df), args=[ embedding, pgast.TypeCast( arg=query, type_name=pgast.TypeName( name=('edgedb', f'vector({dimensions})'), ), ), ], ) # Install the filter directly in newctx.rel. We could return it # and have it put in inner_ctx.rel, and that does seem to work, # but seems weirder. valid = pgast.NullTest(arg=embedding, negated=True) newctx.rel.where_clause = astutils.extend_binop( newctx.rel.where_clause, valid ) # Do an integrated sort. 
This ensures we can hit the index, and is # more ergonomic anyway. Having the ORDER BY operate directly on # the function call is not the *only* way to have it work, but it # is the most reliable. sort_by = pgast.SortBy( node=similarity, dir=qlast.SortOrder.Asc, nulls=qlast.NonesOrder.Last, ) if newctx.rel.sort_clause is None: newctx.rel.sort_clause = [] newctx.rel.sort_clause.append(sort_by) return similarity, None def _process_set_as_object_search( ir_set: irast.SetE[irast.Call], *, inner_cb: _ObjectSearchInnerCallback, ctx: context.CompilerContextLevel, ) -> SetRVars: func_call = ir_set.expr # We skip the object, as it has to be compiled as rvar source. # # Also, disable subquery args. ai::search needs it for its # scoping effects, but we don't need to use it here, since # it can cause the ai search to duplicate arguments. args_pg = _compile_call_args( ir_set, skip={0}, no_subquery_args=True, ctx=ctx) with ctx.subrel() as newctx: newctx.expr_exposed = False obj_ir = func_call.args[0].expr obj_id = obj_ir.path_id obj_rvar = ensure_source_rvar(obj_ir, newctx.rel, ctx=newctx) out_obj_id, out_score_id = func_call.tuple_path_ids with newctx.subrel() as inner_ctx: # inner_ctx generates the `SELECT score WHERE test` relation score_pg, where_clause = inner_cb( func_call, obj_id, args_pg, ctx, newctx, inner_ctx, ) pathctx.put_path_var( inner_ctx.rel, out_score_id, score_pg, aspect=pgce.PathAspect.VALUE, ) if where_clause is not None: inner_ctx.rel.where_clause = astutils.extend_binop( inner_ctx.rel.where_clause, where_clause ) in_rvar = relctx.new_rel_rvar(ir_set, inner_ctx.rel, ctx=newctx) relctx.include_rvar( newctx.rel, in_rvar, out_score_id, aspects={pgce.PathAspect.VALUE}, ctx=newctx, ) obj_id_pg_ref = pathctx.get_rvar_path_var( obj_rvar, obj_id, aspect=pgce.PathAspect.VALUE, env=newctx.env, ) score_pg_ref = pathctx.get_path_var( newctx.rel, out_score_id, aspect=pgce.PathAspect.VALUE, env=newctx.env, ) tuple_expr = pgast.TupleVar( elements=[ pgast.TupleElement( 
path_id=out_obj_id, name='object', val=obj_id_pg_ref, ), pgast.TupleElement( path_id=out_score_id, name='score', val=score_pg_ref, ), ], named=True, typeref=ir_set.typeref, ) pathctx.put_path_var( newctx.rel, ir_set.path_id, tuple_expr, aspect=pgce.PathAspect.VALUE, ) var = pathctx.maybe_get_path_var( newctx.rel, obj_id, aspect=pgce.PathAspect.SERIALIZED, env=newctx.env, ) if var is not None: pathctx.put_path_var( newctx.rel, out_obj_id, var, aspect=pgce.PathAspect.SERIALIZED, ) pathctx.put_path_id_map(newctx.rel, out_obj_id, obj_id) aspects = {pgce.PathAspect.VALUE, pgce.PathAspect.SOURCE} func_rvar = relctx.new_rel_rvar(ir_set, newctx.rel, ctx=ctx) relctx.include_rvar( ctx.rel, func_rvar, ir_set.path_id, aspects=aspects, ctx=ctx ) pathctx.put_path_rvar( ctx.rel, out_obj_id, func_rvar, aspect=pgce.PathAspect.SOURCE, ) return new_stmt_set_rvar(ir_set, ctx.rel, aspects=aspects, ctx=ctx) def _fts_search_inner_pg( _call: irast.Call, obj_id: irast.PathId, args_pg: list[pgast.BaseExpr], ctx: context.CompilerContextLevel, newctx: context.CompilerContextLevel, inner_ctx: context.CompilerContextLevel, ) -> tuple[pgast.BaseExpr, pgast.BaseExpr]: lang, weights, query = args_pg el_name = sn.QualName('__object__', '__fts_document__') fts_document_ptrref = irast.SpecialPointerRef( name=el_name, shortname=el_name, out_source=obj_id.target, out_target=pg_types.pg_tsvector_typeref, out_cardinality=qltypes.Cardinality.AT_MOST_ONE, ) fts_document_id = obj_id.extend(ptrref=fts_document_ptrref) fts_document = relctx.get_path_var( newctx.rel, fts_document_id, aspect=pgce.PathAspect.VALUE, ctx=newctx, ) lang = pgast.FuncCall( name=astutils.edgedb_func('fts_to_regconfig', ctx=ctx), args=[lang], ) parsed_query: pgast.BaseExpr = pgast.FuncCall( name=astutils.edgedb_func('fts_parse_query', ctx=ctx), args=[query, lang] ) parsed_query_id = create_subrel_for_expr(parsed_query, ctx=inner_ctx) parsed_query = pathctx.get_path_var( inner_ctx.rel, parsed_query_id, aspect=pgce.PathAspect.VALUE, 
env=ctx.env, ) weights = _fts_prepare_weights(weights, ctx=inner_ctx) score_pg = pgast.FuncCall( name=('pg_catalog', 'ts_rank'), args=[weights, fts_document, parsed_query], ) where_clause = pgast.Expr(lexpr=fts_document, name='@@', rexpr=parsed_query) return score_pg, where_clause def _fts_prepare_weights( weights: pgast.BaseExpr, ctx: context.CompilerContextLevel, ) -> pgast.BaseExpr: # default value default_weights = pgast.ArrayExpr( elements=[ pgast.NumericConstant(val='1.0'), pgast.NumericConstant(val='0.5'), pgast.NumericConstant(val='0.25'), pgast.NumericConstant(val='0.125'), ] ) weights = pgast.CoalesceExpr(args=[weights, default_weights]) # cast to reals weights = pgast.TypeCast( arg=weights, type_name=pgast.TypeName(name=('real',), array_bounds=[-1]) ) # pad with zeros zero = pgast.NumericConstant(val='0.0') padding_weights = pgast.ArrayExpr(elements=[zero, zero, zero, zero]) weights = pgast.Expr( lexpr=weights, name='||', rexpr=padding_weights, ) # put the whole expression into subrel, # so it can be referenced multiple times weights_id = create_subrel_for_expr(weights, ctx=ctx) weights = pathctx.get_path_var( ctx.rel, weights_id, aspect=pgce.PathAspect.VALUE, env=ctx.env, ) # return array of first 4 values, reversed return pgast.ArrayExpr( elements=[ pgast.Indirection( arg=weights, indirection=[ pgast.Index(idx=pgast.NumericConstant(val=str(i))) ], ) for i in range(4, 0, -1) ] ) def _fts_search_inner_zombo( _call: irast.Call, obj_id: irast.PathId, args_pg: list[pgast.BaseExpr], _ctx: context.CompilerContextLevel, newctx: context.CompilerContextLevel, _inner_ctx: context.CompilerContextLevel, ) -> tuple[pgast.BaseExpr, pgast.BaseExpr]: _, _, query = args_pg el_name = sn.QualName('__object__', 'ctid') ctid_ptrref = irast.SpecialPointerRef( name=el_name, shortname=el_name, out_source=obj_id.target, out_target=pg_types.pg_oid_typeref, out_cardinality=qltypes.Cardinality.AT_MOST_ONE, ) ctid_id = obj_id.extend(ptrref=ctid_ptrref) ctid = relctx.get_path_var( 
newctx.rel, ctid_id, aspect=pgce.PathAspect.VALUE, ctx=newctx, ) score_pg = pgast.FuncCall(name=('zdb', 'score'), args=[ctid]) where_clause = pgast.Expr( lexpr=ctid, name='==>', rexpr=query, ) return score_pg, where_clause def create_subrel_for_expr( expr: pgast.BaseExpr, *, ctx: context.CompilerContextLevel ) -> irast.PathId: """ Creates a sub query relation that contains the given expression. """ # create a dummy path id for a dummy object expr_id = irast.PathId.new_dummy(ctx.env.aliases.get('d')) with ctx.subrel() as newctx: # register the expression pathctx.put_path_var( newctx.rel, expr_id, expr, aspect=pgce.PathAspect.VALUE, ) # include the subrel in the parent new_rvar = relctx.rvar_for_rel(newctx.rel, ctx=ctx) relctx.include_rvar( ctx.rel, new_rvar, expr_id, aspects=(pgce.PathAspect.VALUE,), ctx=ctx, ) return expr_id ================================================ FILE: edb/pgsql/compiler/shapecomp.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Compilation helpers for shapes.""" from __future__ import annotations from typing import Sequence from edb.edgeql import ast as qlast from edb.ir import ast as irast from edb.ir import utils as irutils from edb.pgsql import ast as pgast from . import astutils from . import context from . import dispatch from . 
import expr as expr_compiler # NOQA from . import relgen from . import relctx from . import pathctx def compile_shape( ir_set: irast.Set, shape: Sequence[tuple[irast.SetE[irast.Pointer], qlast.ShapeOp]], *, ctx: context.CompilerContextLevel) -> pgast.TupleVar: elements = [] # If the object identity is potentially nullable, filter it out # to prevent shapes with bogusly null insides. var = pathctx.get_path_value_var( ctx.rel, path_id=ir_set.path_id, env=ctx.env) if var.nullable: ctx.rel.where_clause = astutils.extend_binop( ctx.rel.where_clause, pgast.NullTest(arg=var, negated=True)) with ctx.newscope() as shapectx: shapectx.disable_semi_join |= {ir_set.path_id} if isinstance(ir_set.expr, irast.Stmt): # The source set for this shape is a FOR statement, # which is special in that besides set path_id it # should also expose the path_id of the FOR iterator # so that shape element expressions that might contain # an iterator reference find it properly. # # So, for: # SELECT Bar { # foo := (FOR x := ... UNION Foo { spam := x }) # } # # the path scope when processing the shape of Bar.foo # should be {'Bar.foo', 'x'}. 
iterator = ir_set.expr.iterator_stmt if iterator: shapectx.path_scope[iterator.path_id] = ctx.rel has_id = False for el, op in shape: if op == qlast.ShapeOp.MATERIALIZE and not ctx.materializing: continue rptr = el.expr ptrref = rptr.ptrref has_id |= ptrref.shortname.name == 'id' # As an implementation expedient, we currently represent # AT_MOST_ONE materialized values with arrays card = rptr.dir_cardinality is_singleton = ( card.is_single() and ( not ctx.materializing or not card.can_be_zero() ) ) value: pgast.BaseExpr if (irutils.is_subquery_set(el) or el.path_id.is_objtype_path() or not is_singleton or not ptrref.required): wrapper = relgen.set_as_subquery( el, as_value=True, ctx=shapectx) if not is_singleton: value = relctx.set_to_array( path_id=el.path_id, query=wrapper, ctx=shapectx) else: value = wrapper else: value = dispatch.compile(el, ctx=shapectx) tuple_el = astutils.tuple_element_for_shape_el( el, value, ctx=shapectx) assert isinstance(tuple_el, pgast.TupleElement) elements.append(tuple_el) # If there wasn't an id (because it's a FreeObject), add a fake one. if ctx.materializing and not has_id: elements.append(pgast.TupleElement( path_id=ir_set.path_id, val=var, )) return pgast.TupleVar(elements=elements, named=True) ================================================ FILE: edb/pgsql/compiler/stmt.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Optional from edb import errors from edb.ir import ast as irast from edb.ir import utils as irutils from edb.pgsql import ast as pgast from . import astutils from . import clauses from . import context from . import dispatch from . import enums as pgce from . import group from . import dml from . import output from . import pathctx @dispatch.compile.register(irast.SelectStmt) def compile_SelectStmt( stmt: irast.SelectStmt, *, ctx: context.CompilerContextLevel ) -> pgast.BaseExpr: if ctx.singleton_mode: if not irutils.is_trivial_select(stmt): raise errors.UnsupportedFeatureError( 'Clause on SELECT statement in simple expression') return dispatch.compile(stmt.result, ctx=ctx) parent_ctx = ctx with parent_ctx.substmt() as ctx: # Common setup. clauses.compile_volatile_bindings(stmt, ctx=ctx) query = ctx.stmt # Process materialized sets clauses.compile_materialized_exprs(query, stmt, ctx=ctx) iterator_set = stmt.iterator_stmt last_iterator: Optional[irast.Set] = None if iterator_set: if irutils.contains_dml(stmt): # If we have iterators and we contain nested DML # statements, we need to hoist the iterators into CTEs and # then explicitly join them back into the query. iterator = dml.compile_iterator_cte(iterator_set, ctx=ctx) ctx.path_scope = ctx.path_scope.new_child() dml.merge_iterator(iterator, ctx.rel, ctx=ctx) ctx.enclosing_cte_iterator = iterator last_iterator = stmt.iterator_stmt else: # Process FOR clause. with ctx.new() as ictx: clauses.setup_iterator_volatility(last_iterator, ctx=ictx) iterator_rvar = clauses.compile_iterator_expr( query, iterator_set, is_dml=False, ctx=ictx) for aspect in {pgce.PathAspect.IDENTITY, pgce.PathAspect.VALUE}: pathctx.put_path_rvar( query, path_id=iterator_set.path_id, rvar=iterator_rvar, aspect=aspect, ) last_iterator = iterator_set # Process the result expression. 
with ctx.new() as ictx: clauses.setup_iterator_volatility(last_iterator, ctx=ictx) outvar = clauses.compile_output(stmt.result, ctx=ictx) with ctx.new() as ictx: # FILTER and ORDER BY need to have the base result as a # volatility ref. clauses.setup_iterator_volatility(stmt.result, ctx=ictx) # The FILTER clause. if stmt.where is not None: query.where_clause = astutils.extend_binop( query.where_clause, clauses.compile_filter_clause( stmt.where, stmt.where_card, ctx=ictx)) # The ORDER BY clause if stmt.orderby is not None: with ictx.new() as octx: query.sort_clause = clauses.compile_orderby_clause( stmt.orderby, ctx=octx) # Need to filter out NULLs in certain cases: if outvar.nullable and ( # A nullable var has bubbled up to the top query is ctx.toplevel_stmt # The cardinality is being overridden, so we need to make # sure there aren't extra NULLs in single set or stmt.card_inference_override # There is a LIMIT or OFFSET clause and NULLs would interfere or stmt.limit or stmt.offset ): valvar = pathctx.get_path_value_var( query, stmt.result.path_id, env=ctx.env) output.add_null_test(valvar, query) # The OFFSET clause query.limit_offset = clauses.compile_limit_offset_clause( stmt.offset, ctx=ctx) # The LIMIT clause query.limit_count = clauses.compile_limit_offset_clause( stmt.limit, ctx=ctx) return query @dispatch.compile.register(irast.GroupStmt) def compile_GroupStmt( stmt: irast.GroupStmt, *, ctx: context.CompilerContextLevel ) -> pgast.BaseExpr: return group.compile_group(stmt, ctx=ctx) @dispatch.compile.register(irast.InsertStmt) def compile_InsertStmt( stmt: irast.InsertStmt, *, ctx: context.CompilerContextLevel ) -> pgast.Query: parent_ctx = ctx with parent_ctx.substmt() as ctx: # Common DML bootstrap. parts = dml.init_dml_stmt(stmt, ctx=ctx) top_typeref = stmt.subject.typeref if top_typeref.material_type is not None: top_typeref = top_typeref.material_type insert_cte, _ = parts.dml_ctes[top_typeref] # Process INSERT body. 
dml.process_insert_body( ir_stmt=stmt, insert_cte=insert_cte, dml_parts=parts, ctx=ctx, ) # Wrap up. dml.fini_dml_stmt(stmt, parts, ctx=ctx) return ctx.rel @dispatch.compile.register(irast.UpdateStmt) def compile_UpdateStmt( stmt: irast.UpdateStmt, *, ctx: context.CompilerContextLevel ) -> pgast.Query: parent_ctx = ctx with parent_ctx.substmt() as ctx: # Common DML bootstrap. parts = dml.init_dml_stmt(stmt, ctx=ctx) range_cte = parts.range_cte assert range_cte is not None toplevel = ctx.toplevel_stmt toplevel.append_cte(range_cte) for typeref, (update_cte, _) in parts.dml_ctes.items(): # Process UPDATE body. dml.process_update_body( ir_stmt=stmt, update_cte=update_cte, dml_parts=parts, typeref=typeref, ctx=ctx, ) dml.fini_dml_stmt(stmt, parts, ctx=ctx) return ctx.rel @dispatch.compile.register(irast.DeleteStmt) def compile_DeleteStmt( stmt: irast.DeleteStmt, *, ctx: context.CompilerContextLevel ) -> pgast.Query: parent_ctx = ctx with parent_ctx.substmt() as ctx: # Common DML bootstrap parts = dml.init_dml_stmt(stmt, ctx=ctx) range_cte = parts.range_cte assert range_cte is not None ctx.toplevel_stmt.append_cte(range_cte) for typeref, (delete_cte, _) in parts.dml_ctes.items(): dml.process_delete_body( ir_stmt=stmt, delete_cte=delete_cte, typeref=typeref, ctx=ctx, ) # Wrap up. dml.fini_dml_stmt(stmt, parts, ctx=ctx) return ctx.rel ================================================ FILE: edb/pgsql/dbops/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Abstractions for low-level database DDL and DML operations and data.""" from __future__ import annotations from .base import * # NOQA from .config import * # type: ignore # NOQA from .ddl import * # NOQA from .databases import * # NOQA from .domains import * # NOQA from .enums import * # NOQA from .extensions import * # NOQA from .functions import * # NOQA from .indexes import * # NOQA from .operators import * # NOQA from .ranges import * # NOQA from .roles import * # NOQA from .schemas import * # NOQA from .sequences import * # NOQA from .tables import * # NOQA from .triggers import * # NOQA from .types import * # NOQA from .views import * # NOQA ================================================ FILE: edb/pgsql/dbops/base.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import ( Any, Final, Iterable, Iterator, Mapping, Optional, Sequence, ) from collections.abc import MutableSequence import collections import enum import numbers import textwrap from edb.common import markup from edb.common import struct from edb.common import typeutils from ..common import quote_ident as qi from ..common import quote_literal as ql from ..common import quote_type as qt from ..common import qname as qn class NotSpecifiedT(enum.Enum): NotSpecified = 0 NotSpecified: Final = NotSpecifiedT.NotSpecified def encode_value(val: Any) -> str: """Encode value into an appropriate SQL expression.""" if hasattr(val, 'to_sql_expr'): val = val.to_sql_expr() elif isinstance(val, tuple): val_list = [encode_value(el) for el in val] val = f'ROW({", ".join(val_list)})' elif isinstance(val, struct.Struct): val_list = [encode_value(el) for el in val.as_tuple()] val = f'ROW({", ".join(val_list)})' elif typeutils.is_container(val): val_list = [encode_value(el) for el in val] val = f'ARRAY[{", ".join(val_list)}]' elif val is None: val = 'NULL' elif not isinstance(val, numbers.Number): val = ql(str(val)) elif isinstance(val, int): val = str(int(val)) else: val = str(val) return val class PLExpression(str): pass class SQLBlock: commands: list[str | PLBlock] def __init__(self) -> None: self.commands = [] self._transactional = True def add_block(self) -> PLBlock: block = PLTopBlock() self.add_command(block) return block def to_string(self) -> str: if not self._transactional: raise ValueError( 'block is non-transactional, please use .get_statements()' ) stmts = self.get_statements() body = '\n\n'.join(stmt + ';' if stmt[-1] != ';' else stmt for stmt in stmts if stmt).rstrip() if body and body[-1] != ';': body += ';' return body def get_statements(self) -> list[str]: return [(cmd if isinstance(cmd, str) else cmd.to_string()).rstrip() for cmd in self.commands] def add_command(self, stmt: str | PLBlock) -> None: 
self.commands.append(stmt) def has_declarations(self) -> bool: return False def set_non_transactional(self) -> None: self._transactional = False def is_transactional(self) -> bool: return self._transactional class PLBlock(SQLBlock): varcounter: dict[str, int] shared_vars: set[str] declarations: list[tuple[str, str | tuple[str, str]]] conditions: Iterable[str | Condition] neg_conditions: Iterable[str | Condition] def __init__(self, top_block: Optional[PLTopBlock], level: int) -> None: super().__init__() self.top_block = top_block self.varcounter = collections.defaultdict(int) self.shared_vars = set() self.declarations = [] self.level = level self.conditions = set() self.neg_conditions = set() def has_declarations(self) -> bool: return bool(self.declarations) def has_statements(self) -> bool: return bool(self.commands) def get_top_block(self) -> PLTopBlock: return typeutils.not_none(self.top_block) def add_block(self) -> PLBlock: block = PLBlock(top_block=self.top_block, level=self.level + 1) self.add_command(block) return block def to_string(self) -> str: if self.declarations: vv = (f' {qi(n)} {qt(t)};' for n, t in self.declarations) decls = 'DECLARE\n' + '\n'.join(vv) + '\n' else: decls = '' body = super().to_string() if self.conditions or self.neg_conditions: exprs = [] if self.conditions: for cond in self.conditions: if not isinstance(cond, str): cond_expr = f'EXISTS ({cond.code()})' else: cond_expr = cond exprs.append(cond_expr) if self.neg_conditions: for cond in self.neg_conditions: if not isinstance(cond, str): cond_expr = f'EXISTS ({cond.code()})' else: cond_expr = cond exprs.append(f'NOT {cond_expr}') if_clause = '\n AND'.join( f'({textwrap.indent(expr, " ").lstrip()})' for expr in exprs ) body = textwrap.indent(body, ' ').rstrip() semicolon = ';' if body[-1] != ';' else '' body = f'IF {if_clause}\nTHEN\n{body}{semicolon}\nEND IF;' if decls or not isinstance(self.top_block, PLBlock): return textwrap.indent( f'{decls}BEGIN\n{body}\nEND;', ' ' * self.level * 
4, ) else: return body def add_command( self, cmd: str | PLBlock, *, conditions: Optional[Iterable[str | Condition]] = None, neg_conditions: Optional[Iterable[str | Condition]] = None ) -> None: stmt: str | PLBlock if conditions or neg_conditions: exprs = [] if conditions: for cond in conditions: if not isinstance(cond, str): cond_expr = f'EXISTS ({cond.code()})' else: cond_expr = cond exprs.append(cond_expr) if neg_conditions: for cond in neg_conditions: if not isinstance(cond, str): cond_expr = f'EXISTS ({cond.code()})' else: cond_expr = cond exprs.append(f'NOT {cond_expr}') if_clause = '\n AND'.join( f'({textwrap.indent(expr, " ").lstrip()})' for expr in exprs ) if isinstance(cmd, PLBlock): cmd = cmd.to_string() cmd = textwrap.indent(cmd, ' ').rstrip() semicolon = ';' if cmd[-1] != ';' else '' stmt = f'IF {if_clause}\nTHEN\n{cmd}{semicolon}\nEND IF;' else: stmt = cmd super().add_command(stmt) def get_var_name(self, hint: Optional[str] = None) -> str: if hint is None: hint = 'v' self.varcounter[hint] += 1 return f'{hint}_{self.varcounter[hint]}' def declare_var( self, type_name: str | tuple[str, str], *, var_name: str='', var_name_prefix: str='v', shared: bool=False, ) -> str: if shared: if not var_name: var_name = var_name_prefix if var_name not in self.shared_vars: self.declarations.append((var_name, type_name)) self.shared_vars.add(var_name) else: if not var_name: var_name = self.get_var_name(var_name_prefix) self.declarations.append((var_name, type_name)) return var_name class PLTopBlock(PLBlock): def __init__(self) -> None: super().__init__(top_block=None, level=0) self.declare_var('text', var_name='_dummy_text', shared=True) def add_block(self) -> PLBlock: block = PLBlock(top_block=self, level=self.level + 1) self.add_command(block) return block def to_string(self) -> str: body = super().to_string() return f'DO LANGUAGE plpgsql $__$\n{body}\n$__$;' def get_top_block(self) -> PLTopBlock: return self class BaseCommand(markup.MarkupCapableMixin): def 
generate(self, block: SQLBlock) -> None: raise NotImplementedError @classmethod def as_markup(cls, self, *, ctx) -> markup.elements.lang.TreeNode: return markup.elements.lang.TreeNode(name=repr(self)) def dump(self) -> str: return str(self) class Command(BaseCommand): conditions: set[str | Condition] neg_conditions: set[str | Condition] def __init__( self, *, conditions: Optional[Iterable[str | Condition]] = None, neg_conditions: Optional[Iterable[str | Condition]] = None, ) -> None: self.opid = id(self) self.conditions = set(conditions) if conditions else set() self.neg_conditions = set(neg_conditions) if neg_conditions else set() def generate(self, block: SQLBlock) -> None: self_block = self.generate_self_block(block) if self_block is None: return self.generate_extra(self_block) self_block.conditions = self.conditions self_block.neg_conditions = self.neg_conditions def generate_self_block(self, block: SQLBlock) -> Optional[PLBlock]: # Default implementation simply calls self.code_with_block() self_block = block.add_block() self_block.add_command(self.code_with_block(self_block)) return self_block def generate_extra(self, block: PLBlock) -> None: pass def code(self) -> str: raise NotImplementedError def code_with_block(self, block: PLBlock) -> str: return self.code() class CommandGroup(Command): commands: MutableSequence[Command] def __init__( self, *, conditions: Optional[Iterable[str | Condition]] = None, neg_conditions: Optional[Iterable[str | Condition]] = None, ) -> None: super().__init__(conditions=conditions, neg_conditions=neg_conditions) self.commands = [] def add_command(self, cmd: Command) -> None: self.commands.append(cmd) def add_commands(self, cmds: Sequence[Command]) -> None: self.commands.extend(cmds) def generate_self_block(self, block: SQLBlock) -> Optional[PLBlock]: if not self.commands: return None self_block = block.add_block() for cmd in self.commands: cmd.generate(self_block) return self_block @classmethod def as_markup(cls, self, *, ctx) -> 
markup.elements.lang.TreeNode: node = markup.elements.lang.TreeNode(name=repr(self)) for op in self.commands: node.add_child(node=markup.serialize(op, ctx=ctx)) return node def __iter__(self) -> Iterator[Command]: return iter(self.commands) def __len__(self) -> int: return len(self.commands) class CompositeCommand(Command): def generate_extra_composite( self, block: PLBlock, group: CompositeCommandGroup ) -> None: pass class CompositeCommandGroup(Command): commands: MutableSequence[CompositeCommand] def __init__( self, *, conditions: Optional[Iterable[str | Condition]] = None, neg_conditions: Optional[Iterable[str | Condition]] = None, ) -> None: super().__init__(conditions=conditions, neg_conditions=neg_conditions) self.commands = [] def add_command(self, cmd: CompositeCommand) -> None: self.commands.append(cmd) def add_commands(self, cmds: Sequence[CompositeCommand]) -> None: self.commands.extend(cmds) def generate_self_block(self, block: SQLBlock) -> Optional[PLBlock]: if not self.commands: return None self_block = block.add_block() prefix_code = self.prefix_code() actions = [] dynamic_actions = [] for cmd in self.commands: if isinstance(cmd, tuple) and (cmd[1] or cmd[2]): action = cmd[0].code_with_block(self_block) if isinstance(action, PLExpression): subcommand = \ f"EXECUTE {ql(prefix_code)} || ' ' || {action}" else: subcommand = prefix_code + ' ' + action self_block.add_command( subcommand, conditions=cmd[1], neg_conditions=cmd[2]) else: action = cmd.code_with_block(self_block) if isinstance(action, PLExpression): subcommand = \ f"EXECUTE {ql(prefix_code)} || ' ' || {action}" dynamic_actions.append(subcommand) else: actions.append(action) if actions: command = prefix_code + ' ' + ', '.join(actions) self_block.add_command(command) if dynamic_actions: for action in dynamic_actions: self_block.add_command(action) extra_block = self_block.add_block() for cmd in self.commands: if isinstance(cmd, tuple) and (cmd[1] or cmd[2]): 
cmd[0].generate_extra_composite(extra_block, self) else: cmd.generate_extra_composite(extra_block, self) return self_block def prefix_code(self) -> str: raise NotImplementedError def __iter__(self) -> Iterator[CompositeCommand]: return iter(self.commands) def __len__(self) -> int: return len(self.commands) class Condition(BaseCommand): def code(self) -> str: raise NotImplementedError() class Query(Command): def __init__( self, text: str, *, type: Optional[str | tuple[str, str]] = None, trampoline_fixup: bool = True, ) -> None: from .. import trampoline super().__init__() if trampoline_fixup: text = trampoline.fixup_query(text) self.text = text self.type = type def to_sql_expr(self) -> str: if self.type: return f'({self.text})::{qn(*self.type)}' else: return self.text def code(self) -> str: return self.text def __repr__(self) -> str: return f'<Query {self.text!r}>' class PLQuery(Query): pass class DefaultMeta(type): def __bool__(cls): return False def __repr__(self): return '<DEFAULT>' __str__ = __repr__ class Default(metaclass=DefaultMeta): pass class DBObject: def __init__( self, *, metadata: Optional[Mapping[str, Any]] = None ) -> None: self.metadata = dict(metadata) if metadata else None def add_metadata(self, key: str, value: Any) -> None: if self.metadata is None: self.metadata = {} self.metadata[key] = value def get_metadata(self, key: str) -> Any: if self.metadata is None: return None else: return self.metadata.get(key) def is_shared(self) -> bool: return False def get_type(self) -> str: raise NotImplementedError() def get_id(self) -> str: raise NotImplementedError() class InheritableDBObject(DBObject): def __init__( self, *, inherit: bool = False, metadata: Optional[Mapping[str, Any]] = None, ) -> None: super().__init__(metadata=metadata) if inherit: self.add_metadata('ddl:inherit', inherit) @property def inherit(self) -> bool: return self.get_metadata('ddl:inherit') or False class NoOpCommand(Command): def generate_self_block(self, block: SQLBlock) -> Optional[PLBlock]: return None
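# Illustration only, not part of the dbops API: a minimal standalone sketch
# of how PLBlock.add_command above wraps a statement in a PL/pgSQL
# IF ... END IF guard when conditions or neg_conditions are supplied.
# Assumptions: the helper name wrap_in_if is hypothetical, conditions are
# plain SQL strings (Condition objects would first be rendered as
# EXISTS (<code>)), and the four-space indent is chosen for readability.

```python
import textwrap
from typing import Iterable, Optional


def wrap_in_if(
    cmd: str,
    conditions: Optional[Iterable[str]] = None,
    neg_conditions: Optional[Iterable[str]] = None,
) -> str:
    """Hypothetical helper mirroring the guard logic of PLBlock.add_command."""
    # Positive conditions are used as-is; negative ones get a NOT prefix.
    exprs = list(conditions or [])
    exprs += [f'NOT {cond}' for cond in (neg_conditions or [])]
    if not exprs:
        # No guard requested: the statement is emitted unchanged.
        return cmd
    # All guards must hold, so they are joined with AND.
    if_clause = '\n    AND '.join(f'({expr})' for expr in exprs)
    body = textwrap.indent(cmd, '    ').rstrip()
    # PL/pgSQL statements inside IF must be terminated with a semicolon.
    semicolon = '' if body.endswith(';') else ';'
    return f'IF {if_clause}\nTHEN\n{body}{semicolon}\nEND IF;'


stmt = wrap_in_if(
    'DROP DOMAIN "foo"."bar"',
    neg_conditions=['EXISTS (SELECT 1 FROM pg_tables)'],
)
print(stmt)
```

# A Command subclass gets the same effect by passing conditions= or
# neg_conditions= to its constructor; generate() then copies them onto the
# block it creates, and PLBlock.to_string() renders the guard.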
================================================ FILE: edb/pgsql/dbops/catalogs.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2013-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from . import tables class PgDescriptionTable(tables.Table): def __init__(self, name=None): super().__init__(name=('pg_catalog', 'pg_description')) self.add_columns([ tables.Column(name='objoid', type='oid', required=True), tables.Column(name='classoid', type='oid', required=True), tables.Column(name='objsubid', type='integer', required=True), tables.Column(name='description', type='text') ]) ================================================ FILE: edb/pgsql/dbops/composites.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Iterable, Sequence from edb.common import ordered from .. import common from . import base from . import tables class Record(type): def __new__(mcls, name, fields, default=None): dct = {'_fields___': fields, '_default___': default} bases = (RecordBase, ) return super(Record, mcls).__new__(mcls, name, bases, dct) def __init__(cls, name, fields, default): pass def has_field(cls, name): return name in cls._fields___ class RecordBase: def __init__(self, **kwargs): for k, v in kwargs.items(): if k not in self.__class__._fields___: msg = '__init__() got an unexpected keyword argument %s' % k raise TypeError(msg) setattr(self, k, v) for k in set(self.__class__._fields___) - set(kwargs.keys()): setattr(self, k, self.__class__._default___) def __setattr__(self, name, value): if name not in self.__class__._fields___: msg = '%s has no attribute %s' % (self.__class__.__name__, name) raise AttributeError(msg) super().__setattr__(name, value) def __eq__(self, tup): if not isinstance(tup, tuple): return NotImplemented return tuple(self) == tup def __getitem__(self, index): return getattr(self, self.__class__._fields___[index]) def __iter__(self): for name in self.__class__._fields___: yield getattr(self, name) def __len__(self): return len(self.__class__._fields___) def items(self): for name in self.__class__._fields___: yield name, getattr(self, name) def keys(self): return iter(self.__class__._fields___) def __str__(self): f = ', '.join(str(v) for v in self) if len(self) == 1: f += ',' return '(%s)' % f __repr__ = __str__ class CompositeDBObject(base.DBObject): def __init__( self, name: Sequence[str], columns: Iterable[tables.Column] | None = None, ): super().__init__() self.name = name self._columns: ordered.OrderedSet[tables.Column] = ordered.OrderedSet() self.add_columns(columns or []) def add_columns(self, iterable: 
Iterable[tables.Column]): self._columns.update(iterable) @property def record(self): return Record( self.__class__.__name__ + '_record', [c.name for c in self._columns], default=base.Default) class CompositeAttributeCommand: def __init__(self, attribute): self.attribute = attribute def __repr__(self): return '<%s.%s %r>' % ( self.__class__.__module__, self.__class__.__name__, self.attribute) class AlterCompositeAddAttribute(CompositeAttributeCommand): def code(self) -> str: return (f'ADD {self.get_attribute_term()} ' # type: ignore f'{self.attribute.code()}') def generate_extra_composite( self, block: base.PLBlock, alter: base.CompositeCommandGroup ) -> None: self.attribute.generate_extra_composite(block, alter) class AlterCompositeDropAttribute(CompositeAttributeCommand): def code(self) -> str: attrname = common.qname(self.attribute.name) return f'DROP {self.get_attribute_term()} {attrname}' # type: ignore class AlterCompositeAlterAttributeType: def __init__(self, attribute_name, new_type, *, cast_expr=None): self.attribute_name = attribute_name self.new_type = new_type self.cast_expr = cast_expr def code(self) -> str: attrterm = self.get_attribute_term() # type: ignore attrname = common.quote_ident(str(self.attribute_name)) code = f'ALTER {attrterm} {attrname} SET DATA TYPE {self.new_type}' if self.cast_expr is not None: code += f' USING ({self.cast_expr})' return code def __repr__(self): cls = self.__class__ return f'<{cls.__name__} {self.attribute_name!r} to {self.new_type}>' ================================================ FILE: edb/pgsql/dbops/config.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from ..common import quote_ident as qi from ..common import quote_literal as ql from . import base class Set(base.Command): def __init__(self, key, val, **kwargs): super().__init__(**kwargs) self.key = key self.val = val def code(self) -> str: return f'SET {qi(self.key)} = {ql(self.val)}' ================================================ FILE: edb/pgsql/dbops/constraints.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Optional, Sequence from .. import common from . 
import base class Constraint(base.DBObject): def __init__( self, subject_name: Sequence[str], constraint_name: Optional[str] = None, ): self._subject_name = tuple(subject_name) self._constraint_name = constraint_name def get_type(self): return 'CONSTRAINT' def get_subject_type(self): raise NotImplementedError def generate_extra(self, block: base.PLBlock) -> None: raise NotImplementedError def get_subject_name(self, quote=True): if quote: return common.qname(*self._subject_name) else: return self._subject_name def get_id(self): return '{} ON {} {}'.format( self.constraint_name(), self.get_subject_type(), self.get_subject_name()) def constraint_name(self, quote=True) -> str: if quote and self._constraint_name: return common.quote_ident(self._constraint_name) else: return self._constraint_name or '' def constraint_code(self, block: base.PLBlock) -> str | list[str]: raise NotImplementedError ================================================ FILE: edb/pgsql/dbops/databases.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any, Optional, Mapping import textwrap from ..common import quote_ident as qi from ..common import quote_literal as ql from ..common import versioned_schema as V from . import base from . 
import ddl class AbstractDatabase(base.DBObject): def get_type(self): return 'DATABASE' def is_shared(self) -> bool: return True def _get_id_expr(self) -> str: raise NotImplementedError() def get_oid(self) -> base.Query: qry = textwrap.dedent(f'''\ SELECT 'pg_database'::regclass::oid AS classoid, pg_database.oid AS objectoid, 0 FROM pg_database WHERE datname = {self._get_id_expr()} ''') return base.Query(text=qry) class Database(AbstractDatabase): def __init__( self, name: str, *, owner: Optional[str] = None, is_template: bool = False, encoding: Optional[str] = None, lc_collate: Optional[str] = None, lc_ctype: Optional[str] = None, metadata: Optional[Mapping[str, Any]] = None, ) -> None: super().__init__(metadata=metadata) self.name = name self.owner = owner self.is_template = is_template self.encoding = encoding self.lc_collate = lc_collate self.lc_ctype = lc_ctype def get_id(self): return qi(self.name) def _get_id_expr(self) -> str: return ql(self.name) class DatabaseWithTenant(Database): def __init__( self, name: str, ) -> None: super().__init__(name=name) def get_id(self) -> str: return f"' || quote_ident({self._get_id_expr()}) || '" def _get_id_expr(self) -> str: return f'{V("edgedb")}.get_database_backend_name({ql(self.name)})' class CurrentDatabase(AbstractDatabase): def get_id(self) -> str: return f"' || quote_ident({self._get_id_expr()}) || '" def _get_id_expr(self) -> str: return 'current_database()' class DatabaseExists(base.Condition): def __init__(self, name): self.name = name def code(self) -> str: return textwrap.dedent(f'''\ SELECT datname FROM pg_catalog.pg_database AS db WHERE datname = {ql(self.name)} ''') class CreateDatabase(ddl.CreateObject, ddl.NonTransactionalDDLOperation): def __init__(self, object, *, template: str | None, **kwargs): super().__init__(object, **kwargs) self.template = template def code(self) -> str: extra = '' if self.object.owner: extra += f' OWNER={ql(self.object.owner)}' if self.object.is_template: extra += f' 
IS_TEMPLATE = TRUE' if self.template: extra += f' TEMPLATE={ql(self.template)}' if self.object.encoding: extra += f' ENCODING={ql(self.object.encoding)}' if self.object.lc_collate: extra += f' LC_COLLATE={ql(self.object.lc_collate)}' if self.object.lc_ctype: extra += f' LC_CTYPE={ql(self.object.lc_ctype)}' return (f'CREATE DATABASE {self.object.get_id()} {extra}') class DropDatabase(ddl.SchemaObjectOperation, ddl.NonTransactionalDDLOperation): def code(self) -> str: return f'DROP DATABASE {qi(self.name)}' class RenameDatabase(ddl.AlterObject, ddl.NonTransactionalDDLOperation): def __init__(self, object, *, old_name: str, **kwargs): super().__init__(object, **kwargs) self.old_name = old_name def code(self) -> str: return ( f'ALTER DATABASE {qi(self.old_name)} ' f'RENAME TO {self.object.get_id()}' ) ================================================ FILE: edb/pgsql/dbops/ddl.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import json import textwrap from edb.server import defines from ..common import quote_ident as qi from ..common import quote_literal as ql from . 
import base class DDLOperation(base.Command): pass class NonTransactionalDDLOperation(DDLOperation): def generate( self, block: base.SQLBlock, ) -> None: block.add_command(self.code()) block.set_non_transactional() self_block = block.add_block() self.generate_extra(self_block) self_block.conditions = self.conditions self_block.neg_conditions = self.neg_conditions class SchemaObjectOperation(DDLOperation): def __init__(self, name, *, conditions=None, neg_conditions=None): super().__init__(conditions=conditions, neg_conditions=neg_conditions) self.name = name self.opid = name def __repr__(self): return '<%s %s>' % (self.__class__.__name__, self.name) class Comment(DDLOperation): def __init__(self, object, text, **kwargs): super().__init__(**kwargs) self.object = object self.text = text def code(self) -> str: object_type = self.object.get_type() object_id = self.object.get_id() code = 'COMMENT ON {type} {id} IS {text}'.format( type=object_type, id=object_id, text=ql(self.text)) return code class ReassignOwned(DDLOperation): def __init__(self, old_role, new_role, **kwargs): super().__init__(**kwargs) self.old_role = old_role self.new_role = new_role def qi(self, ident: str) -> str: if ident.upper() in ('CURRENT_USER', 'SESSION_USER'): return ident else: return qi(ident) def code(self) -> str: return ( f'REASSIGN OWNED BY {self.qi(self.old_role)} ' f'TO {self.qi(self.new_role)}' ) class GetMetadata(base.Command): def __init__(self, object): super().__init__() self.object = object def code_with_block(self, block: base.PLBlock) -> str: from .. 
import trampoline oid = self.object.get_oid() is_shared = self.object.is_shared() if isinstance(oid, base.Query): qry = oid.text classoid = block.declare_var('oid') objoid = block.declare_var('oid') objsubid = block.declare_var('oid') block.add_command( qry + f' INTO {classoid}, {objoid}, {objsubid}') else: objoid, classoid, objsubid = oid if is_shared: q = textwrap.dedent(f'''\ SELECT edgedb_VER.shobj_metadata( {objoid}, {classoid}::regclass::text ) ''') elif objsubid: q = textwrap.dedent(f'''\ SELECT edgedb_VER.col_metadata( {objoid}, {objsubid} ) ''') else: q = textwrap.dedent(f'''\ SELECT edgedb_VER.obj_metadata( {objoid}, {classoid}::regclass::text ) ''') return trampoline.fixup_query(q) class GetSingleDBMetadata(base.Command): def __init__(self, dbname, **kwargs): super().__init__(**kwargs) self.dbname = dbname def code(self) -> str: from .. import trampoline key = f'{self.dbname}metadata' return textwrap.dedent(trampoline.fixup_query(f'''\ SELECT json FROM edgedbinstdata_VER.instdata WHERE key = {ql(key)} ''')) class PutMetadata(DDLOperation): def __init__(self, object, metadata, **kwargs): super().__init__(**kwargs) self.object = object self.metadata = metadata def __repr__(self): return \ '<{mod}.{cls} {object!r} {metadata!r}>'.format( mod=self.__class__.__module__, cls=self.__class__.__name__, object=self.object, metadata=self.metadata) class PutSingleDBMetadata(DDLOperation): def __init__(self, dbname, metadata, **kwargs): super().__init__(**kwargs) self.dbname = dbname self.metadata = metadata @property def key(self): return f'{self.dbname}metadata' def __repr__(self): return \ '<{mod}.{cls} Branch({dbname!r}) {metadata!r}>'.format( mod=self.__class__.__module__, cls=self.__class__.__name__, dbname=self.dbname, metadata=self.metadata) class SetMetadata(PutMetadata): def creation_code(self) -> str: metadata = self.metadata object_type = self.object.get_type() object_id = self.object.get_id() prefix = ql(defines.EDGEDB_VISIBLE_METADATA_PREFIX) desc = 
ql(json.dumps(metadata)) comment = f'E{prefix} || {desc}' return textwrap.dedent(f'''\ 'COMMENT ON {object_type} {object_id} IS ' || quote_literal({comment}) ''') def code(self) -> str: return 'EXECUTE ' + self.creation_code() + ';' class SetSingleDBMetadata(PutSingleDBMetadata): def code(self) -> str: from .. import trampoline metadata = ql(json.dumps(self.metadata)) return textwrap.dedent(trampoline.fixup_query(f'''\ UPDATE edgedbinstdata_VER.instdata SET json = {metadata} WHERE key = {ql(self.key)}; ''')) class UpdateMetadata(PutMetadata): def code_with_block(self, block: base.PLBlock) -> str: metadata_qry = GetMetadata(self.object).code_with_block(block) prefix = ql(defines.EDGEDB_VISIBLE_METADATA_PREFIX) json_v = block.declare_var('jsonb') upd_v = block.declare_var('text') meta_v = block.declare_var('jsonb') block.add_command(f'{json_v} := ({metadata_qry});') upd_metadata = ql(json.dumps(self.metadata)) block.add_command(f'{meta_v} := {upd_metadata}::jsonb') block.add_command(textwrap.dedent(f'''\ IF {json_v} IS NOT NULL THEN {upd_v} := E{prefix} || ({json_v} || {meta_v})::text; ELSE {upd_v} := E{prefix} || {meta_v}::text; END IF; ''')) object_type = self.object.get_type() object_id = self.object.get_id() return textwrap.dedent(f'''\ IF {upd_v} IS NOT NULL THEN EXECUTE 'COMMENT ON {object_type} {object_id} IS ' || quote_literal({upd_v}); END IF; ''') class UpdateSingleDBMetadata(PutSingleDBMetadata): def code_with_block(self, block: base.PLBlock) -> str: from .. 
import trampoline metadata_qry = GetSingleDBMetadata(self.dbname).code_with_block(block) json_v = block.declare_var('jsonb') meta_v = block.declare_var('jsonb') block.add_command(f'{json_v} := ({metadata_qry});') upd_metadata = ql(json.dumps(self.metadata)) block.add_command(f'{meta_v} := {upd_metadata}::jsonb') return textwrap.dedent(trampoline.fixup_query(f'''\ UPDATE edgedbinstdata_VER.instdata SET json = {json_v} || {meta_v} WHERE key = {ql(self.key)} ''')) class UpdateMetadataSectionMixin: def __init__(self, *args, section, **kwargs): super().__init__(*args, **kwargs) self.section = section def _metadata_query(self) -> base.Command: raise NotImplementedError def _merge(self, block): metadata_qry = self._metadata_query().code_with_block(block) json_v = block.declare_var('jsonb') meta_v = block.declare_var('jsonb') block.add_command(f'{json_v} := ({metadata_qry});') upd_metadata = ql(json.dumps(self.metadata)) block.add_command( f"{meta_v} := jsonb_strip_nulls(jsonb_build_object(\n" f" {ql(self.section)},\n" f" COALESCE({json_v} -> {ql(self.section)}, '{{}}')" f" || {upd_metadata}::jsonb\n" f"))" ) return json_v, meta_v class UpdateMetadataSection(UpdateMetadataSectionMixin, PutMetadata): def _metadata_query(self) -> base.Command: return GetMetadata(self.object) def code_with_block(self, block: base.PLBlock) -> str: json_v, meta_v = self._merge(block) upd_v = block.declare_var('text') prefix = ql(defines.EDGEDB_VISIBLE_METADATA_PREFIX) block.add_command(textwrap.dedent(f'''\ IF {json_v} IS NOT NULL THEN {upd_v} := E{prefix} || ({json_v} || {meta_v})::text; ELSE {upd_v} := E{prefix} || {meta_v}::text; END IF; ''')) object_type = self.object.get_type() object_id = self.object.get_id() return textwrap.dedent(f'''\ IF {upd_v} IS NOT NULL THEN EXECUTE 'COMMENT ON {object_type} {object_id} IS ' || quote_literal({upd_v}); END IF; ''') class UpdateSingleDBMetadataSection( UpdateMetadataSectionMixin, PutSingleDBMetadata ): def _metadata_query(self) -> base.Command: 
return GetSingleDBMetadata(self.dbname) def code_with_block(self, block: base.PLBlock) -> str: from .. import trampoline json_v, meta_v = self._merge(block) return textwrap.dedent(trampoline.fixup_query(f'''\ UPDATE edgedbinstdata_VER.instdata SET json = {json_v} || {meta_v} WHERE key = {ql(self.key)} ''')) class CreateObject(SchemaObjectOperation): def __init__(self, object, **kwargs): super().__init__(object.get_id(), **kwargs) self.object = object def generate_extra(self, block: base.PLBlock) -> None: super().generate_extra(block) if self.object.metadata: mdata = SetMetadata(self.object, self.object.metadata) block.add_command(mdata.code_with_block(block)) class RenameObject(SchemaObjectOperation): def __init__(self, object, *, new_name, **kwargs): super().__init__(name=object.name, **kwargs) self.object = object self.altered_object = object.copy() self.altered_object.rename(new_name) self.new_name = new_name def generate_extra(self, block: base.PLBlock) -> None: super().generate_extra(block) # FIXME?: This used to be here, in the original code, but it # doesn't work anymore and probably isn't important. # if self.object.metadata: # mdata = UpdateMetadata( # self.altered_object, self.altered_object.metadata) # block.add_command(mdata.code_with_block(block)) class AlterObject(SchemaObjectOperation): def __init__(self, object, **kwargs): super().__init__(object.get_id(), **kwargs) self.object = object def generate_extra(self, block: base.PLBlock) -> None: super().generate_extra(block) if self.object.metadata: mdata = SetMetadata(self.object, self.object.metadata) block.add_command(mdata.code_with_block(block)) class DropObject(SchemaObjectOperation): def __init__(self, object, **kwargs): super().__init__(object.get_id(), **kwargs) self.object = object ================================================ FILE: edb/pgsql/dbops/domains.py ================================================ # # This source file is part of the EdgeDB open source project. 
# # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Iterable, Mapping, Optional, Sequence, TypeAlias, ) import textwrap from ..common import qname as qn from ..common import quote_literal as ql from ..common import quote_type as qt from . import base from . import constraints from . import ddl DomainName: TypeAlias = tuple[str, ...] class DomainExists(base.Condition): def __init__(self, name: DomainName): self.name = name def code(self) -> str: return textwrap.dedent(f'''\ SELECT domain_name FROM information_schema.domains WHERE domain_schema = {ql(self.name[0])} AND domain_name = {ql(self.name[1])} ''') class Domain(base.DBObject): def __init__( self, name: DomainName, *, base: str | DomainName, constraints: Sequence[DomainConstraint] = (), metadata: Optional[Mapping[str, Any]] = None ): self.constraints = tuple(constraints) self.base = base self.name = name super().__init__(metadata=metadata) class CreateDomain(ddl.SchemaObjectOperation): def __init__( self, domain: Domain, *, conditions: Optional[Iterable[str | base.Condition]] = None, neg_conditions: Optional[Iterable[str | base.Condition]] = None, ) -> None: super().__init__( domain.name, conditions=conditions, neg_conditions=neg_conditions ) self.domain = domain def code_with_block(self, block: base.PLBlock) -> str: extra: list[str] = [] for constraint in self.domain.constraints: 
            extra.append(constraint.constraint_code(block))

        return textwrap.dedent(f'''\
            CREATE DOMAIN {qn(*self.domain.name)} AS {qt(self.domain.base)}
            {' '.join(extra) if extra else ''}
        ''').strip()


class AlterDomain(ddl.DDLOperation):
    def __init__(
        self,
        name: DomainName,
        *,
        conditions: Optional[Iterable[str | base.Condition]] = None,
        neg_conditions: Optional[Iterable[str | base.Condition]] = None,
    ) -> None:
        super().__init__(conditions=conditions, neg_conditions=neg_conditions)
        self.name = name

    def prefix_code(self) -> str:
        return f'ALTER DOMAIN {qn(*self.name)}'

    def __repr__(self) -> str:
        return '<edb.sync.%s %s>' % (self.__class__.__name__, self.name)


class AlterDomainAlterDefault(AlterDomain):
    def __init__(
        self,
        name: DomainName,
        default: Optional[str]
    ) -> None:
        super().__init__(name)
        self.default = default

    def code(self) -> str:
        code = self.prefix_code()
        if self.default is None:
            code += ' DROP DEFAULT '
        else:
            # self.default is known to be non-None in this branch.
            value = ql(str(self.default))
            code += f' SET DEFAULT {value}'
        return code


class AlterDomainAlterNull(AlterDomain):
    def __init__(self, name: DomainName, null: bool) -> None:
        super().__init__(name)
        self.null = null

    def code(self) -> str:
        code = self.prefix_code()
        if self.null:
            code += ' DROP NOT NULL '
        else:
            code += ' SET NOT NULL '
        return code


class AlterDomainAlterConstraint(AlterDomain):
    def __init__(
        self,
        name: DomainName,
        constraint: DomainConstraint,
        *,
        conditions: Optional[Iterable[str | base.Condition]] = None,
        neg_conditions: Optional[Iterable[str | base.Condition]] = None,
    ) -> None:
        super().__init__(
            name, conditions=conditions, neg_conditions=neg_conditions)
        self._constraint = constraint


class DomainConstraint(constraints.Constraint):
    def get_subject_type(self) -> str:
        return 'DOMAIN'

    def constraint_code(self, block: base.PLBlock) -> str:
        raise NotImplementedError()


class DomainCheckConstraint(DomainConstraint):
    def __init__(
        self,
        domain_name: DomainName,
        constraint_name: Optional[str] = None,
        *,
        expr: base.Query |
str, ) -> None: super().__init__(domain_name, constraint_name=constraint_name) self.expr = expr def constraint_code(self, block: base.PLBlock) -> str: if isinstance(self.expr, base.Query): assert self.expr.type var = block.declare_var(self.expr.type) indent = len(var) + 5 expr_text = textwrap.indent(self.expr.text, ' ' * indent).strip() block.add_command(f'{var} := ({expr_text})') code = f"'CHECK (' || {var} || ')'" code = base.PLExpression(code) else: code = f'CHECK ({self.expr})' return code class AlterDomainAddConstraint(AlterDomainAlterConstraint): def code_with_block(self, block: base.PLBlock) -> str: code = self.prefix_code() constr_name = self._constraint.constraint_name() constr_code = self._constraint.constraint_code(block) if isinstance(constr_code, base.PLExpression): code = (f"EXECUTE {ql(code)} || ' ADD CONSTRAINT {constr_name} ' " f"|| {constr_code}") else: code += f' ADD CONSTRAINT {constr_name} {constr_code}' return code def generate_extra(self, block: base.PLBlock) -> None: return self._constraint.generate_extra(block) class AlterDomainDropConstraint(AlterDomainAlterConstraint): def code(self) -> str: code = super().prefix_code() code += f' DROP CONSTRAINT {self._constraint.constraint_name()}' return code class DropDomain(ddl.SchemaObjectOperation): def code(self) -> str: return f'DROP DOMAIN {qn(*self.name)}' ================================================ FILE: edb/pgsql/dbops/enums.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Iterable, Mapping, Optional, Sequence, TypeAlias, ) import textwrap from ..common import qname as qn from ..common import quote_literal as ql from . import base from . import ddl EnumName: TypeAlias = tuple[str, str] class EnumExists(base.Condition): def __init__(self, name: EnumName) -> None: self.name = name def code(self) -> str: return textwrap.dedent(f'''\ SELECT t.typname FROM pg_catalog.pg_type t INNER JOIN pg_namespace nsp ON (t.typnamespace = nsp.oid) WHERE nsp.nspname = {ql(self.name[0])} AND t.typname = {ql(self.name[1])} AND t.typtype = 'e' ''') class Enum(base.DBObject): def __init__( self, name: EnumName, values: Sequence[str], *, metadata: Optional[Mapping[str, Any]] = None, ) -> None: self.name = name self.values = values super().__init__(metadata=metadata) class CreateEnum(ddl.SchemaObjectOperation): def __init__( self, enum: Enum, *, conditions: Optional[Iterable[str | base.Condition]] = None, neg_conditions: Optional[Iterable[str | base.Condition]] = None, ) -> None: super().__init__( enum.name, conditions=conditions, neg_conditions=neg_conditions) self.values = enum.values def code(self) -> str: vals = ', '.join(ql(v) for v in self.values) return f'CREATE TYPE {qn(*self.name)} AS ENUM ({vals})' class AlterEnum(ddl.DDLOperation): def __init__( self, name: EnumName, *, conditions: Optional[Iterable[str | base.Condition]] = None, neg_conditions: Optional[Iterable[str | base.Condition]] = None, ) -> None: super().__init__(conditions=conditions, neg_conditions=neg_conditions) self.name = 
name

    def prefix_code(self) -> str:
        return f'ALTER TYPE {qn(*self.name)}'

    def __repr__(self) -> str:
        return '<edb.sync.%s %s>' % (self.__class__.__name__, self.name)


class AlterEnumAddValue(AlterEnum):
    def __init__(
        self,
        name: EnumName,
        value: str,
        *,
        before: Optional[str] = None,
        after: Optional[str] = None,
        conditional: bool = False,
    ):
        super().__init__(name)
        self.value = value
        self.before = before
        self.after = after
        self.conditional = conditional

    def code(self) -> str:
        code = self.prefix_code()
        code += ' ADD VALUE'
        if self.conditional:
            code += ' IF NOT EXISTS'
        code += f' {ql(self.value)}'
        # BEFORE takes precedence if both positions are given.
        if self.before:
            code += f' BEFORE {ql(self.before)}'
        elif self.after:
            code += f' AFTER {ql(self.after)}'

        return code


class DropEnum(ddl.SchemaObjectOperation):
    def code(self) -> str:
        return f'DROP TYPE {qn(*self.name)}'

================================================
FILE: edb/pgsql/dbops/extensions.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

from ..common import quote_ident as qi

from .
import ddl class Extension: def __init__(self, name, schema='edgedb'): self.name = name self.schema = schema def get_extension_name(self): return self.name def code(self) -> str: name = qi(self.get_extension_name()) schema = qi(self.schema) return f'CREATE EXTENSION {name} WITH SCHEMA {schema}' class CreateExtension(ddl.DDLOperation): def __init__( self, extension, *, conditions=None, neg_conditions=None, conditional=False, ): super().__init__(conditions=conditions, neg_conditions=neg_conditions) self.extension = extension self.opid = extension.name self.conditional = conditional def code(self) -> str: ext = self.extension name = qi(ext.get_extension_name()) schema = qi(ext.schema) condition = "IF NOT EXISTS " if self.conditional else '' return f'CREATE EXTENSION {condition}{name} WITH SCHEMA {schema}' class DropExtension(ddl.DDLOperation): def __init__(self, extension, *, conditions=None, neg_conditions=None): super().__init__(conditions=conditions, neg_conditions=neg_conditions) self.extension = extension self.opid = extension.name def code(self) -> str: ext = self.extension name = qi(ext.get_extension_name()) return f'DROP EXTENSION {name}' ================================================ FILE: edb/pgsql/dbops/functions.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#

from __future__ import annotations

import textwrap
from typing import Optional, Sequence, cast

from ..common import qname as qn
from ..common import quote_literal as ql
from ..common import quote_type as qt

from . import base
from . import ddl


FunctionArgType = str | tuple[str, ...]
FunctionArgTyped = tuple[Optional[str], FunctionArgType]
FunctionArgDefaulted = tuple[Optional[str], FunctionArgType, str]
FunctionArg = FunctionArgTyped | FunctionArgDefaulted
NormalizedFunctionArg = tuple[Optional[str], tuple[str, ...], Optional[str]]


class Function(base.DBObject):
    def __init__(
        self,
        name: tuple[str, ...],
        *,
        args: Optional[Sequence[FunctionArg]] = None,
        returns: str | tuple[str, ...],
        text: str,
        volatility: str = "volatile",
        language: str = "sql",
        has_variadic: Optional[bool] = None,
        strict: bool = False,
        parallel_safe: bool = False,
        set_returning: bool = False,
        # Unused for Function, used in VersionedFunction.
        wrapper_volatility: Optional[str] = None,
    ):
        # 'modifying' has no direct SQL counterpart; treat it as VOLATILE.
        if volatility.lower() == 'modifying':
            volatility = 'volatile'
        self.name = name
        self.args = args
        self.returns = returns
        self.text = text
        self.volatility = volatility
        self.language = language
        self.has_variadic = has_variadic
        self.strict = strict
        self.set_returning = set_returning
        self.parallel_safe = parallel_safe

    def __repr__(self):
        # Format the object id as hexadecimal to match the '0x' prefix.
        return '<{} {} at 0x{:x}>'.format(
            self.__class__.__name__, self.name, id(self))


class FunctionExists(base.Condition):
    def __init__(self, name, args=None):
        self.name = name
        self.args = FunctionOperation.normalize_args(args)

    def code(self) -> str:
        targs = [f"{ql(qt(a))}::regtype::oid" for _, a, _ in self.args]
        args = f"ARRAY[{','.join(targs)}]"

        return textwrap.dedent(f'''\
            SELECT
                p.proname
            FROM
                pg_catalog.pg_proc p
                INNER JOIN pg_catalog.pg_namespace ns
                    ON (ns.oid = p.pronamespace)
            WHERE
                p.proname = {ql(self.name[1])}
                AND ns.nspname = {ql(self.name[0])}
                AND {args}::oid[] = ARRAY(
                    SELECT t
                    FROM unnest(p.proargtypes) AS t)
        ''')


class FunctionOperation:
    @staticmethod
    def normalize_args(
        args:
Optional[Sequence[FunctionArg]] ) -> Sequence[NormalizedFunctionArg]: normed = [] for arg in args or (): name = None default = None if isinstance(arg, tuple): name = arg[0] typ = arg[1] if len(arg) > 2: arg_def = cast(FunctionArgDefaulted, arg) default = arg_def[2] else: typ = arg ttyp = (typ,) if isinstance(typ, str) else typ normed.append((name, ttyp, default)) return normed @staticmethod def format_args( args: Optional[Sequence[FunctionArg]], has_variadic: Optional[bool], *, include_defaults: bool = True, ) -> str: if not args: return '' args_buf = [] normed = FunctionOperation.normalize_args(args) for argi, (arg_name, arg_typ, arg_def) in enumerate(normed, 1): vararg = has_variadic and (len(args) == argi) arg_expr = 'VARIADIC ' if vararg else '' if arg_name is not None: arg_expr += qn(arg_name, column=True) arg_expr += ' ' + qt(arg_typ) if include_defaults: if arg_def: arg_expr += ' = ' + arg_def args_buf.append(arg_expr) return ', '.join(args_buf) class CreateFunction(ddl.DDLOperation, FunctionOperation): def __init__( self, function: Function, *, or_replace: bool = False, **kwargs ): super().__init__(**kwargs) self.function = function self.or_replace = or_replace def code(self) -> str: args = self.format_args(self.function.args, self.function.has_variadic) code = textwrap.dedent(''' CREATE {replace} FUNCTION {name}({args}) RETURNS {setof} {returns} AS $____funcbody____$ {text} $____funcbody____$ LANGUAGE {lang} {volatility} {strict} {parallel}; ''').format_map({ 'replace': 'OR REPLACE' if self.or_replace else '', 'name': qn(*self.function.name), 'args': args, 'returns': qt(self.function.returns), 'lang': self.function.language, 'volatility': self.function.volatility.upper(), 'text': textwrap.dedent(self.function.text).strip(), 'strict': 'STRICT' if self.function.strict else '', 'setof': 'SETOF' if self.function.set_returning else '', 'parallel': ( 'PARALLEL ' + ('SAFE' if self.function.parallel_safe else 'UNSAFE') ), }) return code.strip() class 
DropFunction(ddl.DDLOperation, FunctionOperation): def __init__( self, name: tuple[str, ...], args: Sequence[FunctionArg], *, if_exists: bool = False, has_variadic: bool = False, conditions: Optional[list[str | base.Condition]] = None, neg_conditions: Optional[list[str | base.Condition]] = None, ): self.conditional = if_exists super().__init__(conditions=conditions, neg_conditions=neg_conditions) self.name = name self.args = args self.has_variadic = has_variadic def code(self) -> str: ifexists = ' IF EXISTS' if self.conditional else '' args = self.format_args(self.args, self.has_variadic, include_defaults=False) return f'DROP FUNCTION{ifexists} {qn(*self.name)}({args})' ================================================ FILE: edb/pgsql/dbops/indexes.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import re import textwrap from typing import Any, Iterable from edb.common import ordered from ..common import qname as qn from ..common import quote_ident as qi from ..common import quote_literal as ql from .. import ast as pgast from . import base from . import ddl from . 
import tables


class Index(tables.InheritableTableObject):
    def __init__(
        self,
        name: str,
        table_name: tuple[str, str],
        unique: bool = True,
        exprs: Iterable[str] | None = None,
        with_clause: dict[str, str] | None = None,
        predicate: str | None = None,
        inherit: bool = False,
        metadata: dict[str, Any] | None = None,
        columns: Iterable[str | pgast.Star] | None = None,
    ) -> None:
        super().__init__(inherit=inherit, metadata=metadata)

        assert table_name[1] != 'feature'

        self.name = name
        self.table_name = table_name
        self._columns: ordered.OrderedSet[str] = ordered.OrderedSet()
        if columns:
            self.add_columns(columns)
        self.with_clause = with_clause
        self.predicate = predicate
        self.unique = unique
        self.exprs = exprs

        self.add_metadata('fullname', self.name)

    @property
    def name_in_catalog(self) -> str:
        return self.name

    def add_columns(self, columns: Iterable[str | pgast.Star]) -> None:
        for col in columns:
            if isinstance(col, pgast.Star):
                raise NotImplementedError()
            self._columns.add(col)

    def rename(self, new_name):
        self.name = new_name

    def creation_code(
        self,
        concurrently: bool = False,
        conditional: bool = False,
    ) -> str:
        if self.exprs:
            exprs = self.exprs
        else:
            exprs = [qi(c) for c in self.columns]

        # Break down the code into the index name (if present) and the rest of
        # the expression
        m = re.match(r'(?P<using>\w+)?\s*(?P<expr>.+)',
                     self.get_metadata('code').strip())
        assert m is not None
        using = m['using']
        expr = m['expr']

        code = 'CREATE'
        if self.unique:
            code += ' UNIQUE'
        code += ' INDEX'
        if concurrently:
            code += ' CONCURRENTLY'
        if conditional:
            code += ' IF NOT EXISTS'
        code += f' {qn(self.name_in_catalog)} ON {qn(*self.table_name)}'

        if using:
            code += f' USING {using}'

        # expr is expected to be wrapped in parentheses, but in order to
        # manipulate it better we strip the parentheses
        expr = expr[1:-1].replace('__col__', '{col}')
        expr = ', '.join(expr.format(col=e) for e in exprs)

        kwargs = self.get_metadata('kwargs')
        if kwargs is not None:
            # Escape the expression first
            expr = expr.replace('{',
'{{').replace('}', '}}') expr = re.sub(r'(__kw_(\w+?)__)', r'{\2}', expr) expr = expr.format(**kwargs) # Put the stripped parentheses back code += f'({expr})' if self.with_clause: with_content = ','.join( f'{k}={v}' for (k, v) in self.with_clause.items() ) code += f' WITH ({with_content})' if self.predicate: code += f' WHERE {self.predicate}' return code @property def columns(self) -> list[str]: return list(self._columns) def get_type(self) -> str: return 'INDEX' def get_id(self) -> str: return qn(self.table_name[0], self.name_in_catalog) def get_oid(self) -> base.Query: qry = textwrap.dedent(f'''\ SELECT 'pg_class'::regclass::oid AS classoid, i.indexrelid AS objectoid, 0 FROM pg_class AS c INNER JOIN pg_namespace AS ns ON ns.oid = c.relnamespace INNER JOIN pg_index AS i ON i.indrelid = c.oid INNER JOIN pg_class AS ic ON i.indexrelid = ic.oid WHERE ic.relname = {ql(self.name_in_catalog)} AND ns.nspname = {ql(self.table_name[0])} AND c.relname = {ql(self.table_name[1])} ''') return base.Query(text=qry) def copy(self) -> Index: return self.__class__( name=self.name, table_name=self.table_name, unique=self.unique, exprs=self.exprs, predicate=self.predicate, columns=self.columns, metadata=( self.metadata.copy() if self.metadata is not None else None ) ) def __repr__(self) -> str: return \ '<%(mod)s.%(cls)s table=%(table)s name=%(name)s ' \ 'cols=(%(cols)s) unique=%(uniq)s predicate=%(pred)s>' % \ {'mod': self.__class__.__module__, 'cls': self.__class__.__name__, 'name': self.name, 'cols': ','.join('%r' % c for c in self.columns), 'uniq': self.unique, 'pred': self.predicate, 'table': '{}.{}'.format(*self.table_name)} class IndexExists(base.Condition): def __init__(self, index_name: tuple[str, str]): self.index_name = index_name def code(self) -> str: return textwrap.dedent(f'''\ SELECT i.indexrelid FROM pg_catalog.pg_index i INNER JOIN pg_catalog.pg_class ic ON ic.oid = i.indexrelid INNER JOIN pg_catalog.pg_namespace icn ON icn.oid = ic.relnamespace WHERE icn.nspname = 
{ql(self.index_name[0])} AND ic.relname = {ql(self.index_name[1])} ''') class CreateIndex(ddl.CreateObject): def __init__( self, index: Index, *, conditional: bool = False, builtin_conditional: bool = False, concurrently: bool = False, **kwargs: Any ) -> None: super().__init__(index, **kwargs) self.index = index self.concurrently = concurrently self.builtin_conditional = builtin_conditional if conditional: self.neg_conditions.add( IndexExists((index.table_name[0], index.name_in_catalog)) ) def code(self) -> str: return self.index.creation_code( concurrently=self.concurrently, conditional=self.builtin_conditional, ) class RenameIndex(ddl.RenameObject): def __init__(self, index, *, new_name, conditional=False, **kwargs): super().__init__(index, new_name=new_name, **kwargs) if conditional: self.conditions.add( IndexExists((index.table_name[0], index.name_in_catalog))) def code(self) -> str: name = qn(self.object.table_name[0], self.object.name_in_catalog) new_name = qi(self.altered_object.name_in_catalog) return f'ALTER INDEX {name} RENAME TO {new_name}' @classmethod def pl_code(cls, index_desc_var: str, block: base.PLBlock) -> str: index_name = ( f"(quote_ident({index_desc_var}.table_name[0])" f" || '.' 
|| quote_ident({index_desc_var}.name))" ) new_name = ( f"quote_ident({index_desc_var}.name)" ) return ( f"EXECUTE 'ALTER INDEX ' || {index_name} " f"|| ' RENAME TO ' || {new_name};" ) class DropIndex(ddl.DropObject): def __init__( self, index: Index, *, conditional: bool = False, **kwargs: Any ): super().__init__(index, **kwargs) if conditional: self.conditions.add( IndexExists((index.table_name[0], index.name_in_catalog)) ) def code(self) -> str: assert isinstance(self.object, Index) name = qn(self.object.table_name[0], self.object.name_in_catalog) return f'DROP INDEX {name}' ================================================ FILE: edb/pgsql/dbops/operators.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import textwrap from ..common import quote_ident as qi from ..common import quote_literal as ql from ..common import quote_type as qt from . import base from . 
import ddl class CreateOperatorAlias(ddl.SchemaObjectOperation): def __init__( self, *, name, args, base_operator, operator_args, negator=None, commutator=None, procedure=None, **kwargs, ): super().__init__(name=name, **kwargs) self.args = args self.operator = base_operator self.operator_args = operator_args self.procedure = procedure self.commutator = commutator self.negator = negator def code_with_block(self, block: base.PLBlock) -> str: oper_var = block.declare_var(('pg_catalog', 'pg_operator%ROWTYPE')) oper_cond = [] oper_name = f'{qi(self.name[0])}.{self.name[1]}' if self.args[0] is not None: left_type_desc = qt(self.args[0]) left_type = f"', LEFTARG = {left_type_desc}'" oper_cond.append( f'o.oprleft = {ql(qt(self.operator_args[0]))}::regtype') else: left_type_desc = 'NONE' left_type = "''" oper_cond.append(f'o.oprleft = 0') if self.args[1] is not None: right_type_desc = qt(self.args[1]) right_type = f"', RIGHTARG = {right_type_desc}'" oper_cond.append( f'o.oprright = {ql(qt(self.operator_args[1]))}::regtype') else: right_type_desc = 'NONE' right_type = "''" oper_cond.append(f'o.oprright = 0') oper_desc = ( f'{qi(self.operator[0])}.{self.operator[1]} (' f'{left_type_desc}, {right_type_desc})' ).strip() if self.commutator: commutator_name = f'{qi(self.commutator[0])}.{self.commutator[1]}' commutator_decl = textwrap.indent(textwrap.dedent(f'''\ ', COMMUTATOR = OPERATOR({commutator_name})' '''), ' ' * 8).strip() commutator_cond = 'TRUE' else: commutator_decl = textwrap.indent(textwrap.dedent(f'''\ ', COMMUTATOR = ' || ( SELECT edgedb.raise( NULL::text, 'invalid_object_definition', msg => ( 'missing required commutator for operator ' || {ql(oper_name)} ) ) ) '''), ' ' * 8).strip() commutator_cond = 'FALSE' if self.negator: negator_name = f'{qi(self.negator[0])}.{self.negator[1]}' negator_decl = textwrap.indent(textwrap.dedent(f'''\ ', NEGATOR = OPERATOR({negator_name})' '''), ' ' * 8).strip() negator_cond = 'TRUE' else: negator_decl = 
textwrap.indent(textwrap.dedent(f'''\ ', NEGATOR = ' || ( SELECT edgedb.raise( NULL::text, 'invalid_object_definition', msg => ( 'missing required negator for operator ' || {ql(oper_name)} ) ) ) '''), ' ' * 8).strip() negator_cond = 'FALSE' def _get_op_field(field, oid): return textwrap.indent(textwrap.dedent(f'''\ (CASE WHEN {oid} != 0 THEN ', {field} = ' || ( SELECT 'OPERATOR(' || quote_ident(nspname) || '.' || oprname || ')' FROM pg_operator o INNER JOIN pg_namespace ns ON (o.oprnamespace = ns.oid) WHERE o.oid = {oid} ) ELSE '' END) '''), ' ' * 8).strip() code = textwrap.dedent('''\ SELECT o.* INTO {oper} FROM pg_operator o INNER JOIN pg_namespace ns ON (o.oprnamespace = ns.oid) WHERE o.oprname = {oper_name} AND ns.nspname = {oper_schema} AND {oper_cond} ; IF NOT FOUND THEN RAISE 'SQL operator does not exist: %', {oper_desc} USING ERRCODE = 'undefined_function'; END IF; EXECUTE 'CREATE OPERATOR {name} (' || 'PROCEDURE = ' || {procedure} || {left_type} || {right_type} || {commutator} || {negator} || {restrict} || {join} || {hashes} || {merges} || ')' ; ''').format_map({ 'name': oper_name, 'oper': oper_var, 'procedure': (ql(self.procedure) if self.procedure else f'{oper_var}.oprcode::text'), 'oper_schema': ql(self.operator[0]), 'oper_name': ql(self.operator[1]), 'oper_cond': ' AND '.join(oper_cond), 'oper_desc': ql(oper_desc), 'left_type': left_type, 'right_type': right_type, 'commutator': ( f"(CASE WHEN {oper_var}.oprcom != 0 OR {commutator_cond} THEN " f"{commutator_decl} " f"ELSE '' END)" ), 'negator': ( f"(CASE WHEN {oper_var}.oprnegate != 0 OR {negator_cond} THEN " f"{negator_decl} " f"ELSE '' END)" ), 'restrict': ( f"(CASE WHEN {oper_var}.oprrest != 0 THEN " f"', RESTRICT = ' || {oper_var}.oprrest::text " f"ELSE '' END)" ), 'join': ( f"(CASE WHEN {oper_var}.oprjoin != 0 THEN " f"', JOIN = ' || {oper_var}.oprjoin::text " f"ELSE '' END)" ), 'hashes': ( f"(CASE WHEN {oper_var}.oprcanhash THEN " f"', HASHES ' " f"ELSE '' END)" ), 'merges': ( f"(CASE WHEN 
{oper_var}.oprcanmerge THEN " f"', MERGES ' " f"ELSE '' END)" ), }) return code.strip() class CreateOperator(ddl.SchemaObjectOperation): def __init__(self, *, name, args, procedure, **kwargs): super().__init__(name=name, **kwargs) self.args = args self.procedure = procedure def code(self) -> str: if self.args[0] is not None: left_type_desc = qt(self.args[0]) left_type = f", LEFTARG = {left_type_desc}" else: left_type = "" if self.args[1] is not None: right_type_desc = qt(self.args[1]) right_type = f", RIGHTARG = {right_type_desc}" else: right_type = "" return textwrap.dedent(f'''\ CREATE OPERATOR {qi(self.name[0])}.{self.name[1]} ( PROCEDURE = {self.procedure} {left_type} {right_type} ); ''') class DropOperator(ddl.SchemaObjectOperation): def __init__(self, *, name, args, **kwargs): super().__init__(name=name, **kwargs) self.args = args def code(self) -> str: if self.args[0] is not None: left_type = qt(self.args[0]) else: left_type = 'NONE' if self.args[1] is not None: right_type = qt(self.args[1]) else: right_type = 'NONE' return textwrap.dedent(f'''\ DROP OPERATOR {qi(self.name[0])}.{self.name[1]} ( {left_type}, {right_type} ); ''') ================================================ FILE: edb/pgsql/dbops/ranges.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2022-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations import textwrap from ..common import qname as qn from ..common import quote_literal as ql from . import base from . import ddl class RangeExists(base.Condition): def __init__(self, name): self.name = name def code(self) -> str: return textwrap.dedent(f'''\ SELECT t.typname FROM pg_catalog.pg_type t INNER JOIN pg_namespace nsp ON (t.typnamespace = nsp.oid) WHERE nsp.nspname = {ql(self.name[0])} AND t.typname = {ql(self.name[1])} AND t.typtype = 'r' ''') class Range(base.DBObject): def __init__(self, name, subtype, *, subtype_diff=None): super().__init__() self.name = name self.subtype = subtype self.subtype_diff = subtype_diff class CreateRange(ddl.SchemaObjectOperation): def __init__(self, range, *, conditions=None, neg_conditions=None): super().__init__( range.name, conditions=conditions, neg_conditions=neg_conditions) self.range = range def code(self) -> str: subs = [f'subtype = {qn(*self.range.subtype)}'] if self.range.subtype_diff is not None: subs.append(f'subtype_diff = {qn(*self.range.subtype_diff)}') subcommands = ', '.join(subs) return f'''\ CREATE TYPE {qn(*self.name)} AS RANGE ({subcommands}) ''' class DropRange(ddl.SchemaObjectOperation): def code(self) -> str: return f'DROP TYPE {qn(*self.name)}' ================================================ FILE: edb/pgsql/dbops/roles.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Iterable, Mapping, Optional, TypeAlias, ) import json import textwrap from ..common import quote_ident as qi from ..common import quote_literal as ql from . import base from . import ddl RoleName: TypeAlias = str class Role(base.DBObject): def __init__( self, name: RoleName, *, allow_login: bool | base.NotSpecifiedT = base.NotSpecified, allow_createdb: bool | base.NotSpecifiedT = base.NotSpecified, allow_createrole: bool | base.NotSpecifiedT = base.NotSpecified, password: None | str | base.NotSpecifiedT = base.NotSpecified, superuser: bool | base.NotSpecifiedT = base.NotSpecified, membership: Optional[Iterable[str]] = None, members: Optional[Iterable[str]] = None, metadata: Optional[Mapping[str, Any]] = None, ) -> None: super().__init__(metadata=metadata) self.name = name self.superuser = superuser self.allow_login = allow_login self.allow_createdb = allow_createdb self.allow_createrole = allow_createrole self.password = password self.membership = membership self.members = members def get_type(self) -> str: return 'ROLE' def get_id(self) -> str: return qi(self.name) class SingleRole(Role): def __init__( self, *, password: None | str | base.NotSpecifiedT = base.NotSpecified, metadata: Optional[Mapping[str, Any]] = None, ) -> None: super().__init__('current_user', password=password) self.single_role_metadata = metadata def get_id(self) -> str: return self.name class RoleExists(base.Condition): def __init__(self, name: RoleName): self.name = name def code(self) -> str: return textwrap.dedent(f'''\ SELECT rolname FROM pg_catalog.pg_roles WHERE rolname = {ql(self.name)} ''') class RoleCommand: object: Role def _role(self) -> str: return f'ROLE {self.object.get_id()}' def _attrs(self) -> str: attrs = [] attrmap = { 'superuser': 'SUPERUSER', 'allow_login': 'LOGIN', 'allow_createdb': 'CREATEDB', 
'allow_createrole': 'CREATEROLE', } for objattr, stmtattr in attrmap.items(): attr = getattr(self.object, objattr) if attr is base.NotSpecified: continue elif attr: attrs.append(stmtattr) else: attrs.append(f'NO{stmtattr}') if self.object.password is None: attrs.append('PASSWORD NULL') elif self.object.password is not base.NotSpecified: attrs.append(f'PASSWORD {ql(self.object.password)}') return " ".join(attrs) class CreateRole(ddl.CreateObject, RoleCommand): def code(self) -> str: if self.object.membership: roles = ', '.join(qi(str(m)) for m in self.object.membership) membership = f'IN ROLE {roles}' else: membership = '' if self.object.members: roles = ', '.join(qi(str(m)) for m in self.object.members) members = f'ROLE {roles}' else: members = '' return f'CREATE {self._role()} {self._attrs()} {membership} {members}' class AlterRole(ddl.AlterObject, RoleCommand): def code(self) -> str: attrs = self._attrs() if attrs: return f'ALTER {self._role()} {attrs}' else: return '' def generate_extra(self, block: base.PLBlock) -> None: from .. 
import trampoline super().generate_extra(block) if getattr(self.object, 'single_role_metadata', None): value = json.dumps(self.object.single_role_metadata) query = base.Query(trampoline.fixup_query( f''' UPDATE edgedbinstdata_VER.instdata SET json = {ql(value)}::jsonb WHERE key = 'single_role_metadata' ''' )) block.add_command(query.code_with_block(block)) class DropRole(ddl.SchemaObjectOperation): def code(self) -> str: return f'DROP ROLE {qi(self.name)}' class AlterRoleAddMember(ddl.SchemaObjectOperation): def __init__( self, name: RoleName, member: str, *, conditions: Optional[Iterable[str | base.Condition]] = None, neg_conditions: Optional[Iterable[str | base.Condition]] = None, ): super().__init__( name, conditions=conditions, neg_conditions=neg_conditions ) self.member = member def code(self) -> str: return f'GRANT {qi(self.name)} TO {qi(self.member)}' class AlterRoleDropMember(ddl.SchemaObjectOperation): def __init__( self, name: RoleName, member: str, *, conditions: Optional[Iterable[str | base.Condition]] = None, neg_conditions: Optional[Iterable[str | base.Condition]] = None, ) -> None: super().__init__( name, conditions=conditions, neg_conditions=neg_conditions ) self.member = member def code(self) -> str: return f'REVOKE {qi(self.name)} FROM {qi(self.member)}' class AlterRoleAddMembership(ddl.SchemaObjectOperation): def __init__( self, name: RoleName, membership: Iterable[str], *, conditions: Optional[Iterable[str | base.Condition]] = None, neg_conditions: Optional[Iterable[str | base.Condition]] = None, ): super().__init__( name, conditions=conditions, neg_conditions=neg_conditions ) self.membership = membership def code(self) -> str: roles = ', '.join(qi(m) for m in self.membership) return f'GRANT {roles} TO {qi(self.name)}' ================================================ FILE: edb/pgsql/dbops/schemas.py ================================================ # # This source file is part of the EdgeDB open source project. 
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

import textwrap

from ..common import quote_ident as qi
from ..common import quote_literal as ql

from . import base
from . import ddl


class SchemaExists(base.Condition):
    def __init__(self, name):
        self.name = name

    def code(self) -> str:
        return textwrap.dedent(f'''\
            SELECT
                oid
            FROM
                pg_catalog.pg_namespace
            WHERE
                nspname = {ql(self.name)}
        ''')


class CreateSchema(ddl.DDLOperation):
    def __init__(
        self, name, *, conditions=None, neg_conditions=None, conditional=False
    ):
        super().__init__(conditions=conditions, neg_conditions=neg_conditions)
        self.name = name
        self.opid = name
        self.conditional = conditional

    def code(self) -> str:
        condition = "IF NOT EXISTS " if self.conditional else ''
        return f'CREATE SCHEMA {condition}{qi(self.name)}'

    def __repr__(self):
        return '<%s %s>' % (self.__class__.__name__, self.name)


class DropSchema(ddl.DDLOperation):
    def __init__(self, name, *, conditions=None, neg_conditions=None):
        super().__init__(conditions=conditions, neg_conditions=neg_conditions)
        self.name = name

    def code(self) -> str:
        return f'DROP SCHEMA {qi(self.name)}'

    def __repr__(self):
        return '<%s %s>' % (self.__class__.__name__, self.name)

================================================
FILE: edb/pgsql/dbops/sequences.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

from ..common import qname as qn
from . import ddl


class CreateSequence(ddl.SchemaObjectOperation):
    def __init__(self, name):
        super().__init__(name)

    def code(self) -> str:
        return f'CREATE SEQUENCE {qn(*self.name)}'


class DropSequence(ddl.SchemaObjectOperation):
    def __init__(self, name):
        super().__init__(name)

    def code(self) -> str:
        return f'DROP SEQUENCE {qn(*self.name)}'

================================================
FILE: edb/pgsql/dbops/tables.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# from __future__ import annotations import collections import textwrap from typing import ( Iterable, Iterator, Optional, Sequence, TypeAlias, ) from edb.common import ordered from ..common import qname as qn from ..common import quote_ident as qi from ..common import quote_literal as ql from .. import ast as pgast from . import base from . import composites from . import constraints from . import ddl TableName: TypeAlias = tuple[str, ...] ColumnName: TypeAlias = str class Table(composites.CompositeDBObject): bases: ordered.OrderedSet[Table] constraints: ordered.OrderedSet[SingleTableConstraint] # Columns from bases and self all_columns: collections.OrderedDict[str, Column] def __init__( self, name: TableName, *, columns: Optional[Iterable[Column]] = None, bases: Optional[ordered.OrderedSet[Table]] = None, constraints: Optional[ordered.OrderedSet[SingleTableConstraint]] = None ) -> None: self.bases = ordered.OrderedSet(bases or []) self.constraints = ordered.OrderedSet(constraints or []) super().__init__(name, columns=columns) self.all_columns = collections.OrderedDict( (c.name, c) for c in self.iter_columns() ) def iter_columns( self, writable_only: bool = False, only_self: bool = False, ) -> Iterable[Column]: cols: collections.OrderedDict[ColumnName, Column] = ( collections.OrderedDict() ) cols.update( (c.name, c) for c in self._columns if not writable_only or not c.readonly ) if not only_self: for base in reversed(self.bases): cols.update( (name, bc) for name, bc in base.all_columns.items() if not writable_only or not bc.readonly ) return ordered.OrderedSet(cols.values()) def __iter__(self) -> Iterator[Column]: return iter(self._columns) def add_bases(self, iterable: Iterable[Table]) -> None: self.bases.update(iterable) self.all_columns = collections.OrderedDict( (c.name, c) for c in self.iter_columns() ) def add_columns(self, iterable: Iterable[Column]) -> None: super().add_columns(iterable) self.all_columns = collections.OrderedDict( (c.name, c) for c in 
            self.iter_columns()
        )

    def add_constraint(self, const: SingleTableConstraint) -> None:
        self.constraints.add(const)

    def get_column(self, name: ColumnName) -> Optional[Column]:
        return self.all_columns.get(name)

    def get_type(self) -> str:
        return 'TABLE'

    def get_id(self) -> str:
        return qn(*self.name)

    @property
    def record(self) -> composites.Record:
        return composites.Record(
            self.__class__.__name__ + '_record',
            list(self.all_columns),
            default=base.Default)

    @property
    def system_catalog(self) -> str:
        return 'pg_class'

    @property
    def oid_type(self) -> str:
        return 'regclass'

    def __repr__(self) -> str:
        return (
            f'<{self.__class__.__module__}.{self.__class__.__name__} '
            f'{self.name}>'
        )


class InheritableTableObject(base.InheritableDBObject):
    name: str

    @property
    def name_in_catalog(self) -> str:
        return self.name


class Column(base.DBObject):
    def __init__(
        self,
        name: ColumnName,
        type: str | tuple[str, str],
        required: bool = False,
        default: Optional[str] = None,
        constraints: Sequence[ColumnConstraint] = (),
        readonly: bool = False,
        comment: Optional[str] = None,
    ) -> None:
        self.name = name
        self.type = type
        self.required = required
        self.default = default
        self.constraints = constraints
        self.readonly = readonly
        self.comment = comment

    def add_constraint(self, constraint: ColumnConstraint) -> None:
        self.constraints = list(self.constraints) + [constraint]

    def code(self, short: bool = False) -> str:
        code = f"{qi(self.name)} {self.type}"
        if not short:
            if self.required:
                code += ' NOT NULL'
            if self.default is not None:
                code += f' DEFAULT {self.default}'
            for c in self.constraints:
                code += ' ' + c.code()
        return code

    def generate_extra_composite(
        self, block: base.PLBlock, alter_table: base.CompositeCommandGroup
    ) -> None:
        if self.comment is not None:
            assert isinstance(alter_table, AlterTable)
            col = TableColumn(table_name=alter_table.name, column=self)
            cmd = ddl.Comment(object=col, text=self.comment)
            cmd.generate(block)

    def __repr__(self) -> str:
        return '<%s.%s "%s" %s>' % (
            self.__class__.__module__, self.__class__.__name__,
            self.name, self.type)


class
TableColumn(base.DBObject): def __init__(self, table_name: TableName, column: Column) -> None: self.table_name = table_name self.column = column def get_type(self) -> str: return 'COLUMN' def get_id(self) -> str: return qn( self.table_name[0], self.table_name[1], self.column.name ) class ColumnConstraint: def __init__(self, constraint_name: str): self.constraint_name = constraint_name def code(self) -> str: raise NotImplementedError() class GeneratedConstraint(ColumnConstraint): def __init__(self, constraint_name: str, expr: str) -> None: super().__init__(constraint_name) self.expr = expr def code(self) -> str: return ( f'CONSTRAINT {self.constraint_name} ' f'GENERATED ALWAYS AS ({self.expr}) STORED' ) class TableConstraint(constraints.Constraint): def generate_extra(self, block: base.PLBlock) -> None: pass def get_subject_type(self) -> str: return '' # For table constraints the accepted syntax is # simply CONSTRAINT ON "{tab_name}", not # CONSTRAINT ON TABLE, unlike constraints on # other objects. 
class SingleTableConstraint(TableConstraint): def constraint_code(self, block: base.PLBlock) -> str: raise NotImplementedError() class PrimaryKey(SingleTableConstraint): def __init__( self, table_name: TableName, columns: Sequence[str | pgast.Star] ) -> None: super().__init__(table_name) self.columns = columns def constraint_code(self, block: base.PLBlock) -> str: cols = ', '.join(qi(c) for c in self.columns) return f'PRIMARY KEY ({cols})' class UniqueConstraint(SingleTableConstraint): def __init__( self, table_name: TableName, columns: Sequence[str | pgast.Star] ) -> None: super().__init__(table_name) self.columns = columns def constraint_code(self, block: base.PLBlock) -> str: cols = ', '.join(qi(c) for c in self.columns) return f'UNIQUE ({cols})' class CheckConstraint(SingleTableConstraint): def __init__( self, table_name: TableName, constraint_name: str, expr: base.Query | str, inherit: bool = True, ) -> None: super().__init__(table_name, constraint_name=constraint_name) self.expr = expr self.inherit = inherit def constraint_code(self, block: base.PLBlock) -> str: if isinstance(self.expr, base.Query): assert self.expr.type var = block.declare_var(self.expr.type) indent = len(var) + 5 expr_text = textwrap.indent(self.expr.text, ' ' * indent).strip() block.add_command(f'{var} := ({expr_text})') code = f"'CHECK (' || {var} || ')'" if not self.inherit: code += " || ' NO INHERIT'" code = base.PLExpression(code) else: code = f'CHECK ({self.expr})' if not self.inherit: code += ' NO INHERIT' return code class TableExists(base.Condition): def __init__(self, name: TableName) -> None: self.name = name def code(self) -> str: return textwrap.dedent(f'''\ SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname = {ql(self.name[0])} AND tablename = {ql(self.name[1])} ''') class TableInherits(base.Condition): def __init__( self, name: TableName, parent_name: TableName, ) -> None: self.name = name self.parent_name = parent_name def code(self) -> str: return 
textwrap.dedent(f'''\ SELECT c.relname FROM pg_class c INNER JOIN pg_namespace ns ON ns.oid = c.relnamespace INNER JOIN pg_inherits i ON i.inhrelid = c.oid INNER JOIN pg_class pc ON i.inhparent = pc.oid INNER JOIN pg_namespace pns ON pns.oid = pc.relnamespace WHERE ns.nspname = {ql(self.name[0])} AND c.relname = {ql(self.name[1])} AND pns.nspname = {ql(self.parent_name[0])} AND pc.relname = {ql(self.parent_name[1])} ''') class ColumnExists(base.Condition): def __init__( self, table_name: TableName, column_name: ColumnName, ) -> None: self.table_name = table_name self.column_name = column_name def code(self) -> str: return textwrap.dedent(f'''\ SELECT column_name FROM information_schema.columns WHERE table_schema = {ql(self.table_name[0])} AND table_name = {ql(self.table_name[1])} AND column_name = {ql(self.column_name)} ''') class ColumnIsInherited(base.Condition): def __init__( self, table_name: TableName, column_name: ColumnName, ) -> None: self.table_name = table_name self.column_name = column_name def code(self) -> str: return textwrap.dedent(f'''\ SELECT True FROM pg_class c INNER JOIN pg_namespace ns ON ns.oid = c.relnamespace INNER JOIN pg_attribute a ON a.attrelid = c.oid WHERE ns.nspname = {ql(self.table_name[0])} AND c.relname = {ql(self.table_name[1])} AND a.attname = {ql(self.column_name)} AND a.attinhcount > 0 ''') class CreateTable(ddl.SchemaObjectOperation): def __init__( self, table: Table, temporary: bool = False, *, conditions: Optional[Iterable[str | base.Condition]] = None, neg_conditions: Optional[Iterable[str | base.Condition]] = None, ) -> None: super().__init__( table.name, conditions=conditions, neg_conditions=neg_conditions ) self.table = table self.temporary = temporary def code_with_block(self, block: base.PLBlock) -> str: elems = [ c.code() for c in self.table.iter_columns(only_self=True) ] for c in self.table.constraints: elems.append(c.constraint_code(block)) name = qn(*self.table.name) temp = 'TEMPORARY ' if self.temporary else '' 
chunks = [f'CREATE {temp}TABLE {name} (', ')'] if self.table.bases: bases = ','.join(qn(*b.name) for b in self.table.bases) chunks.append(f' INHERITS ({bases})') if any(isinstance(e, base.PLExpression) for e in elems): # Dynamic declaration elem_chunks: list[base.PLExpression | str] = [] for e in elems: if isinstance(e, base.PLExpression): elem_chunks.append(e) else: elem_chunks.append(ql(e)) chunks = [ql(c) for c in chunks] chunks.insert(1, " || ',' || ".join(elem_chunks)) code = 'EXECUTE ' + ' || '.join(chunks) else: # Static declaration chunks.insert(1, ', '.join(elems)) code = ''.join(chunks) return code class AlterTableBaseMixin: name: TableName contained: bool def __init__( self, name: TableName, contained: bool = False ) -> None: self.name = name self.contained = contained def prefix_code(self) -> str: return 'ALTER TABLE %s%s' % ( 'ONLY ' if self.contained else '', qn(*self.name)) def __repr__(self) -> str: return '<%s.%s %s>' % ( self.__class__.__module__, self.__class__.__name__, self.name) class AlterTableBase(AlterTableBaseMixin, ddl.DDLOperation): def __init__( self, name: TableName, *, contained: bool = False, conditions: Optional[Iterable[str | base.Condition]] = None, neg_conditions: Optional[Iterable[str | base.Condition]] = None, ) -> None: ddl.DDLOperation.__init__( self, conditions=conditions, neg_conditions=neg_conditions) AlterTableBaseMixin.__init__(self, name=name, contained=contained) def get_attribute_term(self) -> str: return 'COLUMN' class AlterTableFragment(ddl.DDLOperation, base.CompositeCommand): def get_attribute_term(self) -> str: return 'COLUMN' def generate_extra_composite( self, block: base.PLBlock, group: base.CompositeCommandGroup ) -> None: pass class AlterTable( AlterTableBaseMixin, ddl.DDLOperation, base.CompositeCommandGroup ): def __init__( self, name: TableName, *, contained: bool = False, conditions: Optional[Iterable[str | base.Condition]] = None, neg_conditions: Optional[Iterable[str | base.Condition]] = None, ): 
base.CompositeCommandGroup.__init__( self, conditions=conditions, neg_conditions=neg_conditions) AlterTableBaseMixin.__init__(self, name=name, contained=contained) self.ops = self.commands add_operation = base.CompositeCommandGroup.add_command class AlterTableAddParent(AlterTableFragment): def __init__(self, parent_name: TableName, **kwargs) -> None: super().__init__(**kwargs) self.parent_name = parent_name def code(self) -> str: return f'INHERIT {qn(*self.parent_name)}' def __repr__(self) -> str: return '<%s.%s %s>' % ( self.__class__.__module__, self.__class__.__name__, self.parent_name) class AlterTableDropParent(AlterTableFragment): def __init__(self, parent_name: TableName): self.parent_name = parent_name def code(self) -> str: return f'NO INHERIT {qn(*self.parent_name)}' def __repr__(self) -> str: return '<%s.%s %s>' % ( self.__class__.__module__, self.__class__.__name__, self.parent_name) class AlterTableAddColumn( # type: ignore composites.AlterCompositeAddAttribute, AlterTableFragment): pass class AlterTableDropColumn( composites.AlterCompositeDropAttribute, AlterTableFragment): pass class AlterTableAlterColumnType( composites.AlterCompositeAlterAttributeType, AlterTableFragment): pass class AlterTableAlterColumnNull(AlterTableFragment): def __init__(self, column_name: ColumnName, null) -> None: self.column_name = column_name self.null = null def code(self) -> str: action = 'DROP' if self.null else 'SET' return f'ALTER COLUMN {qi(self.column_name)} {action} NOT NULL' def __repr__(self) -> str: return '<{}.{} "{}" {} NOT NULL>'.format( self.__class__.__module__, self.__class__.__name__, self.column_name, 'DROP' if self.null else 'SET') class AlterTableAlterColumnDefault(AlterTableFragment): def __init__(self, column_name: ColumnName, default: Optional[str]): self.column_name = column_name self.default = default def code(self) -> str: if self.default is None: return f'ALTER COLUMN {qi(self.column_name)} DROP DEFAULT' else: return (f'ALTER COLUMN 
{qi(self.column_name)} ' f'SET DEFAULT {self.default}') def __repr__(self) -> str: return '<{}.{} "{}" {} DEFAULT{}>'.format( self.__class__.__module__, self.__class__.__name__, self.column_name, 'DROP' if self.default is None else 'SET', '' if self.default is None else ' {!r}'.format(self.default)) class TableConstraintCommand: pass class TableConstraintExists(base.Condition): def __init__(self, table_name: TableName, constraint_name: str): self.table_name = table_name self.constraint_name = constraint_name def code(self) -> str: return textwrap.dedent(f'''\ SELECT True FROM pg_catalog.pg_constraint c INNER JOIN pg_catalog.pg_class t ON c.conrelid = t.oid INNER JOIN pg_catalog.pg_namespace ns ON t.relnamespace = ns.oid WHERE conname = {ql(self.constraint_name)} AND nspname = {ql(self.table_name[0])} AND relname = {ql(self.table_name[1])} ''') class AlterTableAddConstraint[TableConstraint_T: "TableConstraint"]( AlterTableFragment, TableConstraintCommand, ): constraint: TableConstraint_T def __init__(self, constraint: TableConstraint_T): assert not isinstance(constraint, list) self.constraint = constraint def code_with_block(self, block: base.PLBlock) -> str: code = 'ADD ' name = self.constraint.constraint_name() if name: code += f'CONSTRAINT {name} ' constr_code = self.constraint.constraint_code(block) assert isinstance(constr_code, str) if not isinstance(constr_code, base.PLExpression): # Static declaration return code + constr_code else: # Dynamic declaration return base.PLExpression(f'{ql(code)} || {constr_code}') def generate_extra_composite( self, block: base.PLBlock, group: base.CompositeCommandGroup ) -> None: return self.constraint.generate_extra(block) def __repr__(self) -> str: return '<%s.%s %r>' % ( self.__class__.__module__, self.__class__.__name__, self.constraint) class AlterTableDropConstraint(AlterTableFragment, TableConstraintCommand): def __init__(self, constraint) -> None: self.constraint = constraint def code(self) -> str: return f'DROP 
CONSTRAINT {self.constraint.constraint_name()}' def __repr__(self) -> str: return '<%s.%s %r>' % ( self.__class__.__module__, self.__class__.__name__, self.constraint) class DropTable(ddl.SchemaObjectOperation): def code(self) -> str: return f'DROP TABLE {qn(*self.name)}' ================================================ FILE: edb/pgsql/dbops/triggers.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import textwrap from typing import ( Any, Mapping, Optional, TypeAlias, ) from ...common import enum as s_enum from ..common import qname as qn from ..common import quote_ident as qi from ..common import quote_literal as ql from . import base from . import ddl from . 
import tables TriggerName: TypeAlias = str class TriggerTiming(s_enum.StrEnum): Before = 'before' After = 'after' class TriggerGranularity(s_enum.StrEnum): Row = 'row' class TriggerExists(base.Condition): def __init__( self, trigger_name: TriggerName, table_name: tables.TableName ) -> None: self.trigger_name = trigger_name self.table_name = table_name def code(self) -> str: return textwrap.dedent( f'''\ SELECT tg.tgname FROM pg_catalog.pg_trigger tg INNER JOIN pg_catalog.pg_class tab ON (tab.oid = tg.tgrelid) INNER JOIN pg_catalog.pg_namespace ns ON (ns.oid = tab.relnamespace) WHERE tab.relname = {ql(self.table_name[1])} AND ns.nspname = {ql(self.table_name[0])} AND tg.tgname = {ql(self.trigger_name)} ''' ) class Trigger(tables.InheritableTableObject): def __init__( self, name: TriggerName, *, table_name: tables.TableName, events: tuple[str, ...], timing: TriggerTiming = TriggerTiming.After, granularity: TriggerGranularity = TriggerGranularity.Row, procedure, condition=None, is_constraint: bool = False, deferred: bool = False, inherit: bool = False, metadata: Optional[Mapping[str, Any]] = None, ) -> None: super().__init__(inherit=inherit, metadata=metadata) self.name = name self.table_name = table_name self.events = events self.timing = timing self.granularity = granularity self.procedure = procedure self.condition = condition self.is_constraint = is_constraint self.deferred = deferred if is_constraint and granularity != TriggerGranularity.Row: msg = 'invalid granularity for ' 'constraint trigger: {}'.format( granularity ) raise ValueError(msg) if deferred and not is_constraint: raise ValueError('only constraint triggers can be deferred') def get_type(self) -> str: return 'TRIGGER' def get_id(self) -> str: return f'{qi(self.name)} ON {qn(*self.table_name)}' def get_oid(self) -> base.Query: qry = textwrap.dedent( f'''\ SELECT 'pg_trigger'::regclass::oid AS classoid, pg_trigger.oid AS objectoid, 0 FROM pg_trigger INNER JOIN pg_class ON tgrelid = pg_class.oid INNER 
JOIN pg_namespace ON relnamespace = pg_namespace.oid WHERE tgname = {ql(self.name)} AND nspname = {ql(self.table_name[0])} AND relname = {ql(self.table_name[1])} ''' ) return base.Query(text=qry) def copy(self) -> Trigger: return self.__class__( name=self.name, table_name=self.table_name, events=self.events, timing=self.timing, granularity=self.granularity, procedure=self.procedure, condition=self.condition, is_constraint=self.is_constraint, deferred=self.deferred, metadata=( self.metadata.copy() if self.metadata is not None else None ), ) def __repr__(self) -> str: return '<{mod}.{cls} {name} ON {table_name} {timing} {events}>'.format( mod=self.__class__.__module__, cls=self.__class__.__name__, name=self.name, table_name=qn(*self.table_name), timing=self.timing, events=' OR '.join(self.events), ) class CreateTrigger(ddl.CreateObject): def __init__( self, object: Trigger, *, conditional: bool = False, **kwargs: Any, ) -> None: super().__init__(object, **kwargs) self.trigger = object if conditional: self.neg_conditions.add( TriggerExists(self.trigger.name, self.trigger.table_name) ) def code(self) -> str: return textwrap.dedent( '''\ CREATE {constr}TRIGGER {trigger_name} {timing} {events} ON {table_name} {deferred} FOR EACH {granularity} {condition} EXECUTE PROCEDURE {procedure} ''' ).format( constr='CONSTRAINT ' if self.trigger.is_constraint else '', trigger_name=qi(self.trigger.name), timing=self.trigger.timing, events=' OR '.join(self.trigger.events), table_name=qn(*self.trigger.table_name), deferred=( 'DEFERRABLE INITIALLY DEFERRED' if self.trigger.deferred else '' ), granularity=self.trigger.granularity, condition=( f'WHEN ({self.trigger.condition})' if self.trigger.condition else '' ), procedure=f'{qn(*self.trigger.procedure)}()', ) class DropTrigger(ddl.DropObject): def __init__( self, object: Trigger, *, conditional: bool = False, **kwargs: Any, ) -> None: super().__init__(object, **kwargs) self.trigger = object self.conditional = conditional if conditional: 
self.conditions.add( TriggerExists(self.trigger.name, self.trigger.table_name) ) def code(self) -> str: ifexists = ' IF EXISTS' if self.conditional else '' return ( f'DROP TRIGGER{ifexists} {qi(self.trigger.name)} ' f'ON {qn(*self.trigger.table_name)}' ) class DisableTrigger(ddl.DDLOperation): def __init__( self, trigger: Trigger, *, self_only: bool = False, **kwargs: Any, ): super().__init__(**kwargs) self.trigger = trigger self.self_only = self_only def code(self) -> str: only = ' ONLY' if self.self_only else '' return ( f'ALTER TABLE{only} {qn(*self.trigger.table_name)} ' f'DISABLE TRIGGER {qi(self.trigger.name)}' ) def __repr__(self) -> str: return '<{mod}.{cls} {trigger!r}>'.format( mod=self.__class__.__module__, cls=self.__class__.__name__, trigger=self.trigger, ) class EnableTrigger(ddl.DDLOperation): def __init__( self, trigger: Trigger, **kwargs: Any, ): super().__init__(**kwargs) self.trigger = trigger def code(self) -> str: return ( f'ALTER TABLE {qn(*self.trigger.table_name)} ' f'ENABLE TRIGGER {qi(self.trigger.name)}' ) def __repr__(self) -> str: return '<{mod}.{cls} {trigger!r}>'.format( mod=self.__class__.__module__, cls=self.__class__.__name__, trigger=self.trigger, ) ================================================ FILE: edb/pgsql/dbops/types.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Collection, Iterable, Iterator, Optional, TypeAlias, ) import textwrap from edb.common import ordered from ..common import qname as qn from ..common import quote_literal as ql from . import base from . import composites from . import ddl from . import tables CompositeTypeName: TypeAlias = tuple[str, str] class CompositeType(composites.CompositeDBObject): def __init__( self, name: CompositeTypeName, columns: Collection[tables.Column] = (), ): super().__init__(name) self._columns = ordered.OrderedSet(columns) def iter_columns(self) -> Iterator[tables.Column]: return iter(self._columns) class TypeExists(base.Condition): def __init__(self, name: CompositeTypeName): self.name = name def code(self) -> str: return textwrap.dedent(f'''\ SELECT typname FROM pg_catalog.pg_type typ INNER JOIN pg_catalog.pg_namespace nsp ON nsp.oid = typ.typnamespace WHERE nsp.nspname = {ql(self.name[0])} AND typ.typname = {ql(self.name[1])} ''') def type_oid(name: CompositeTypeName) -> base.Query: if len(name) == 2: typnamespace, typname = name else: typname = name[0] typnamespace = 'pg_catalog' qry = textwrap.dedent(f'''\ SELECT typ.oid FROM pg_catalog.pg_type typ INNER JOIN pg_catalog.pg_namespace nsp ON nsp.oid = typ.typnamespace WHERE typ.typname = {ql(typname)} AND nsp.nspname = {ql(typnamespace)} ''') return base.Query(qry) CompositeTypeExists = TypeExists class CompositeTypeAttributeExists(base.Condition): def __init__( self, type_name: CompositeTypeName, attribute_name: str, ): self.type_name = type_name self.attribute_name = attribute_name def code(self) -> str: return textwrap.dedent(f'''\ SELECT attribute_name FROM information_schema.attributes WHERE udt_schema = {ql(self.type_name[0])} AND udt_name = {ql(self.type_name[1])} AND attribute_name = {ql(self.attribute_name)} ''') class 
CreateCompositeType(ddl.SchemaObjectOperation): def __init__( self, type: CompositeType, *, conditions: Optional[Iterable[str | base.Condition]] = None, neg_conditions: Optional[Iterable[str | base.Condition]] = None, ) -> None: super().__init__( type.name, conditions=conditions, neg_conditions=neg_conditions ) self.type = type def code(self) -> str: elems = [c.code(short=True) for c in self.type.iter_columns()] name = qn(*self.type.name) cols = ', '.join(c for c in elems) return f'CREATE TYPE {name} AS ({cols})' class AlterCompositeTypeBaseMixin: def __init__(self, name: CompositeTypeName, **kwargs: Any): self.name = name def prefix_code(self) -> str: return f'ALTER TYPE {qn(*self.name)}' def __repr__(self) -> str: return '<%s.%s %s>' % ( self.__class__.__module__, self.__class__.__name__, self.name) class AlterCompositeTypeBase(AlterCompositeTypeBaseMixin, ddl.DDLOperation): def __init__( self, name: CompositeTypeName, *, conditions: Optional[Iterable[str | base.Condition]] = None, neg_conditions: Optional[Iterable[str | base.Condition]] = None, ) -> None: ddl.DDLOperation.__init__( self, conditions=conditions, neg_conditions=neg_conditions) AlterCompositeTypeBaseMixin.__init__(self, name=name) class AlterCompositeTypeFragment(ddl.DDLOperation): def get_attribute_term(self) -> str: return 'ATTRIBUTE' class AlterCompositeType( AlterCompositeTypeBaseMixin, base.CompositeCommandGroup ): def __init__( self, name: CompositeTypeName, *, conditions: Optional[Iterable[str | base.Condition]] = None, neg_conditions: Optional[Iterable[str | base.Condition]] = None, ) -> None: base.CompositeCommandGroup.__init__( self, conditions=conditions, neg_conditions=neg_conditions) AlterCompositeTypeBaseMixin.__init__(self, name=name) class AlterCompositeTypeAddAttribute( # type: ignore composites.AlterCompositeAddAttribute, AlterCompositeTypeFragment ): def code(self) -> str: return 'ADD {} {}'.format( self.get_attribute_term(), self.attribute.code(short=True)) class 
AlterCompositeTypeDropAttribute( composites.AlterCompositeDropAttribute, AlterCompositeTypeFragment): pass class AlterCompositeTypeAlterAttributeType( composites.AlterCompositeAlterAttributeType, AlterCompositeTypeFragment): pass class DropCompositeType(ddl.SchemaObjectOperation): def __init__( self, name: CompositeTypeName, *, cascade: bool = False, conditions: Optional[Iterable[str | base.Condition]] = None, neg_conditions: Optional[Iterable[str | base.Condition]] = None, ): super().__init__( name, conditions=conditions, neg_conditions=neg_conditions ) self.cascade = cascade def code(self) -> str: cascade = ' CASCADE' if self.cascade else '' return f'DROP TYPE {qn(*self.name)}{cascade}' ================================================ FILE: edb/pgsql/dbops/views.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import textwrap from ..common import qname as qn from ..common import quote_literal as ql from . import base from . 
import ddl class View(base.DBObject): def __init__(self, name, query, materialized=False): super().__init__() self.name = name self.query = query self.materialized = materialized def get_type(self) -> str: return "VIEW" if not self.materialized else "MATERIALIZED VIEW" def get_id(self): return qn(*self.name) class CreateView(ddl.SchemaObjectOperation): def __init__( self, view, *, conditions=None, neg_conditions=None, or_replace=False, ): super().__init__(view.name, conditions=conditions, neg_conditions=neg_conditions) self.view = view self.or_replace = or_replace def code(self) -> str: query = textwrap.indent(textwrap.dedent(self.view.query), ' ') return ( f'CREATE {"OR REPLACE" if self.or_replace else ""}' f' {self.view.get_type()} {qn(*self.view.name)} AS\n{query}' ) class DropView(ddl.SchemaObjectOperation): def __init__( self, name, *, conditional=False, conditions=None, neg_conditions=None, materialized=False, ): super().__init__( name, conditions=conditions, neg_conditions=neg_conditions, ) self.conditional = conditional self.materialized = materialized def code(self) -> str: mat = 'MATERIALIZED ' if self.materialized else '' if self.conditional: return f'DROP {mat}VIEW IF EXISTS {qn(*self.name)}' else: return f'DROP {mat}VIEW {qn(*self.name)}' class ViewExists(base.Condition): def __init__(self, name, materialized=False): self.name = name self.materialized = materialized def code(self) -> str: mat = 'mat' if self.materialized else '' return textwrap.dedent(f'''\ SELECT {mat}viewname FROM pg_catalog.pg_{mat}views WHERE schemaname = {ql(self.name[0])} AND {mat}viewname = {ql(self.name[1])} ''') ================================================ FILE: edb/pgsql/debug.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import uuid from edb.common import debug from edb.common import uuidgen from edb.schema import functions as s_funcs from edb.schema import objects as so from edb.schema import pointers as s_pointers from edb.schema import schema as s_schema from edb.pgsql import ast as pgast from edb.pgsql import codegen as pgcodegen from edb.ir import ast as irast def dump_ast_and_query( pg_expr: pgast.Base, ir_expr: irast.Base, ) -> None: if not ( debug.flags.edgeql_compile or debug.flags.edgeql_compile_sql_ast or debug.flags.edgeql_compile_sql_reordered_text or debug.flags.edgeql_compile_sql_text ): return if debug.flags.edgeql_compile or debug.flags.edgeql_compile_sql_ast: debug.header('SQL Tree') debug.dump( pg_expr, _ast_include_meta=debug.flags.edgeql_compile_sql_ast_meta ) if debug.flags.edgeql_compile or debug.flags.edgeql_compile_sql_text: sql_text = pgcodegen.generate_source(pg_expr, pretty=True) debug.header('SQL') debug.dump_code(sql_text, lexer='sql') if debug.flags.edgeql_compile_sql_reordered_text: debug.header('Reordered SQL') debug_sql_text = pgcodegen.generate_source( pg_expr, pretty=True, reordered=True ) if isinstance(ir_expr, irast.Statement): debug_sql_text = _rewrite_names_in_sql( debug_sql_text, ir_expr.schema ) debug.dump_code(debug_sql_text, lexer='sql') def _rewrite_names_in_sql(text: str, schema: s_schema.Schema) -> str: """Rewrite the SQL output of the compiler to include real object 
names. Replace UUIDs with object names when possible. The output of this won't be valid, but will probably be easier to read. This is done by default when pretty printing our "reordered" output, which isn't anything like valid SQL anyway. """ # Functions are actually named after their `backend_name` rather # than their id, so that overloaded functions all have the same # name. Build a map from `backend_name` to real names. (This dict # comprehension might have collisions, but that's fine; the names # we get out will be the same no matter which is picked.) func_map = { f.get_backend_name(schema): f for f in schema.get_objects(type=s_funcs.Function) } # Find all the uuids and try to rewrite them. for m in set(uuidgen.UUID_RE.findall(text)): uid = uuid.UUID(m) sobj = schema.get_by_id(uid, default=None) if not sobj: sobj = func_map.get(uid) if sobj: s = _obj_to_name(sobj, schema) text = text.replace(m, s) return text def _obj_to_name( sobj: so.Object, schema: s_schema.Schema, ) -> str: if isinstance(sobj, s_pointers.Pointer): s = str(sobj.get_shortname(schema).name) if sobj.is_link_property(schema): s = f'@{s}' # If the pointer is multi, then it is probably a table name, # so let's give a fully qualified version with the source. if sobj.get_cardinality(schema).is_multi() and ( src := sobj.get_source(schema) ): src_name = src.get_name(schema) s = f'{src_name}.{s}' elif isinstance(sobj, s_funcs.Function): return str(sobj.get_shortname(schema)) else: s = str(sobj.get_name(schema)) return s ================================================ FILE: edb/pgsql/delta.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Callable, Optional, Iterable, Mapping, Sequence, cast, TYPE_CHECKING, ) from copy import copy import collections.abc import itertools import textwrap import uuid from edb import errors from edb.edgeql import ast as ql_ast from edb.edgeql import qltypes as ql_ft from edb.edgeql import compiler as qlcompiler from edb.schema import annos as s_anno from edb.schema import casts as s_casts from edb.schema import scalars as s_scalars from edb.schema import objtypes as s_objtypes from edb.schema import constraints as s_constr from edb.schema import database as s_db from edb.schema import delta as sd from edb.schema import expr as s_expr from edb.schema import expraliases as s_aliases from edb.schema import extensions as s_exts from edb.schema import futures as s_futures from edb.schema import functions as s_funcs from edb.schema import globals as s_globals from edb.schema import indexes as s_indexes from edb.schema import links as s_links from edb.schema import permissions as s_permissions from edb.schema import policies as s_policies from edb.schema import properties as s_props from edb.schema import migrations as s_migrations from edb.schema import modules as s_mod from edb.schema import name as sn from edb.schema import objects as so from edb.schema import operators as s_opers from edb.schema import pointers as s_pointers from edb.schema import pseudo as s_pseudo from edb.schema import roles as s_roles from edb.schema import rewrites as s_rewrites from edb.schema import sources as s_sources from edb.schema import 
triggers as s_triggers from edb.schema import types as s_types from edb.schema import version as s_ver from edb.schema import utils as s_utils from edb.common import markup from edb.common import ordered from edb.common import uuidgen from edb.common import parsing from edb.common.typeutils import not_none from edb.ir import ast as irast from edb.ir import pathid as irpathid from edb.ir import typeutils as irtyputils from edb.ir import utils as irutils from edb.pgsql import common from edb.pgsql import debug as pg_debug from edb.pgsql import dbops from edb.pgsql import params from edb.pgsql import deltafts from edb.pgsql import delta_ext_ai from edb.server import defines as edbdef from edb.server import config from edb.server.config import ops as config_ops from edb.server.compiler import sertypes from . import ast as pgast from .common import qname as q from .common import quote_literal as ql from .common import quote_ident as qi from .common import quote_type as qt from .common import versioned_schema as V from .compiler import enums as pgce from . import compiler from . import codegen from . import schemamech from . import trampoline from . 
import types if TYPE_CHECKING: from edb.schema import schema as s_schema DEFAULT_INDEX_CODE = ' ((__col__) NULLS FIRST)' class CommandMeta(sd.CommandMeta): pass class MetaCommand(sd.Command, metaclass=CommandMeta): pgops: ordered.OrderedSet[dbops.Command | sd.Command] def __init__(self, **kwargs): super().__init__(**kwargs) self.pgops = ordered.OrderedSet() def apply_prerequisites( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().apply_prerequisites(schema, context) for op in self.get_prerequisites(): if not isinstance(op, sd.AlterObjectProperty): self.pgops.add(op) return schema def apply_subcommands( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().apply_subcommands(schema, context) for op in self.get_subcommands( include_prerequisites=False, include_caused=False, ): if not isinstance(op, sd.AlterObjectProperty): self.pgops.add(op) return schema def apply_caused( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().apply_caused(schema, context) for op in self.get_caused(): if not isinstance(op, sd.AlterObjectProperty): self.pgops.add(op) return schema def generate(self, block: dbops.PLBlock) -> None: for op in self.pgops: assert isinstance(op, (dbops.Command, MetaCommand)) op.generate(block) @classmethod def as_markup(cls, self, *, ctx): node = markup.elements.lang.TreeNode(name=str(self)) for dd in self.ops: if isinstance(dd, AlterObjectProperty): diff = markup.elements.doc.ValueDiff( before=repr(dd.old_value), after=repr(dd.new_value)) if dd.new_inherited: diff.comment = 'inherited' elif dd.new_computed: diff.comment = 'computed' node.add_child(label=dd.property, node=diff) for dd in self.pgops: node.add_child(node=markup.serialize(dd, ctx=ctx)) return node def _get_backend_params( self, context: sd.CommandContext, ) -> params.BackendRuntimeParams: ctx_backend_params = context.backend_runtime_params if ctx_backend_params 
is not None: backend_params = cast( params.BackendRuntimeParams, ctx_backend_params) else: backend_params = params.get_default_runtime_params() return backend_params def _get_instance_params( self, context: sd.CommandContext, ) -> params.BackendInstanceParams: return self._get_backend_params(context).instance_params def _get_tenant_id(self, context: sd.CommandContext) -> str: return self._get_instance_params(context).tenant_id def _get_topmost_command_op( self, context: sd.CommandContext, ctxcls: type[sd.CommandContextToken[sd.Command]], ) -> CompositeMetaCommand: ctx = context.get_topmost_ancestor(ctxcls) if ctx is None: raise AssertionError(f"there is no {ctxcls} in context stack") assert isinstance(ctx.op, CompositeMetaCommand) return ctx.op def schedule_constraint_trigger_update( self, constraint: s_constr.Constraint, schema: s_schema.Schema, context: sd.CommandContext, ctxcls: type[sd.CommandContextToken[sd.Command]], ) -> None: if ( not isinstance( constraint.get_subject(schema), (s_objtypes.ObjectType, s_pointers.Pointer) ) or not schemamech.table_constraint_requires_triggers( constraint, schema, 'unique' ) ): return op = self._get_topmost_command_op(context, ctxcls) op.constraint_trigger_updates.add(constraint.id) @staticmethod def get_function_type( name: tuple[str, str] ) -> type[dbops.Function] | type[trampoline.VersionedFunction]: return ( trampoline.VersionedFunction if name[0] == 'edgedbstd' else dbops.Function ) @classmethod def maybe_trampoline( cls, f: Optional[dbops.Function], context: sd.CommandContext, ) -> None: if isinstance(f, trampoline.VersionedFunction): create = trampoline.make_trampoline(f) ctx = not_none(context.get(sd.DeltaRootContext)) assert isinstance(ctx.op, DeltaRoot) create_trampolines = ctx.op.create_trampolines create_trampolines.trampolines.append(create) class CommandGroupAdapted(MetaCommand, adapts=sd.CommandGroup): pass class Nop(MetaCommand, adapts=sd.Nop): pass class Query(MetaCommand, adapts=sd.Query): def apply( self, 
schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().apply(schema, context) assert self.expr.irast sql_res = compiler.compile_ir_to_sql_tree( self.expr.irast, output_format=compiler.OutputFormat.NATIVE_INTERNAL, explicit_top_cast=irtyputils.type_to_typeref( schema, schema.get('std::str', type=s_types.Type), cache=None, ), backend_runtime_params=context.backend_runtime_params, ) sql_text = codegen.generate_source(sql_res.ast) # The INTO _dummy_text bit is needed because PL/pgSQL _really_ # wants the result of a returning query to be stored in a variable, # and the PERFORM hack does not work if the query has DML CTEs. self.pgops.add(dbops.Query( text=f'{sql_text} INTO _dummy_text', )) return schema class AlterObjectProperty(MetaCommand, adapts=sd.AlterObjectProperty): pass class SchemaVersionCommand(MetaCommand): pass class CreateSchemaVersion( SchemaVersionCommand, adapts=s_ver.CreateSchemaVersion, ): pass class AlterSchemaVersion( SchemaVersionCommand, adapts=s_ver.AlterSchemaVersion, ): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().apply(schema, context) expected_ver = self.get_orig_attribute_value('version') check = dbops.Query( f''' SELECT edgedb_VER.raise_on_not_null( (SELECT NULLIF( (SELECT version::text FROM {V('edgedb')}."_SchemaSchemaVersion" FOR UPDATE), {ql(str(expected_ver))} )), 'serialization_failure', msg => ( 'Cannot serialize DDL: ' || (SELECT version::text FROM {V('edgedb')}."_SchemaSchemaVersion") ) ) INTO _dummy_text ''' ) self.pgops.add(check) return schema class GlobalSchemaVersionCommand(MetaCommand): pass class CreateGlobalSchemaVersion( MetaCommand, adapts=s_ver.CreateGlobalSchemaVersion, ): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().apply(schema, context) ver_id = str(self.scls.id) ver_name = str(self.scls.get_name(schema)) ctx_backend_params = context.backend_runtime_params 
if ctx_backend_params is not None: backend_params = cast( params.BackendRuntimeParams, ctx_backend_params) else: backend_params = params.get_default_runtime_params() metadata = { ver_id: { 'id': ver_id, 'name': ver_name, 'version': str(self.scls.get_version(schema)), 'builtin': self.scls.get_builtin(schema), 'internal': self.scls.get_internal(schema), } } if backend_params.has_create_database: self.pgops.add( dbops.UpdateMetadataSection( dbops.DatabaseWithTenant(name=edbdef.EDGEDB_TEMPLATE_DB), section='GlobalSchemaVersion', metadata=metadata ) ) else: self.pgops.add( dbops.UpdateSingleDBMetadataSection( edbdef.EDGEDB_TEMPLATE_DB, section='GlobalSchemaVersion', metadata=metadata ) ) return schema class AlterGlobalSchemaVersion( MetaCommand, adapts=s_ver.AlterGlobalSchemaVersion, ): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().apply(schema, context) ver_id = str(self.scls.id) ver_name = str(self.scls.get_name(schema)) ctx_backend_params = context.backend_runtime_params if ctx_backend_params is not None: backend_params = cast( params.BackendRuntimeParams, ctx_backend_params) else: backend_params = params.get_default_runtime_params() if not backend_params.has_create_database: key = f'{edbdef.EDGEDB_TEMPLATE_DB}metadata' lock = dbops.Query( trampoline.fixup_query(f''' SELECT json FROM edgedbinstdata_VER.instdata WHERE key = {ql(key)} FOR UPDATE INTO _dummy_text ''' )) elif backend_params.has_superuser_access: # Only superusers are generally allowed to make an UPDATE # lock on shared catalogs. lock = dbops.Query( f''' SELECT description FROM pg_catalog.pg_shdescription WHERE objoid = ( SELECT oid FROM pg_database WHERE datname = {V('edgedb')}.get_database_backend_name( {ql(edbdef.EDGEDB_TEMPLATE_DB)}) ) AND classoid = 'pg_database'::regclass::oid FOR UPDATE INTO _dummy_text ''' ) else: # Without superuser access we have to resort to lock polling. # This is racy, but is unfortunately the best we can do. 
lock = dbops.Query(f''' SELECT edgedb_VER.raise_on_not_null( ( SELECT 'locked' FROM pg_catalog.pg_locks WHERE locktype = 'object' AND classid = 'pg_database'::regclass::oid AND objid = ( SELECT oid FROM pg_database WHERE datname = {V('edgedb')}.get_database_backend_name( {ql(edbdef.EDGEDB_TEMPLATE_DB)}) ) AND mode = 'ShareUpdateExclusiveLock' AND granted AND pid != pg_backend_pid() ), 'serialization_failure', msg => ( 'Cannot serialize global DDL: ' || (SELECT version::text FROM {V('edgedb')}."_SysGlobalSchemaVersion") ) ) INTO _dummy_text ''') self.pgops.add(lock) expected_ver = self.get_orig_attribute_value('version') check = dbops.Query( f''' SELECT edgedb_VER.raise_on_not_null( (SELECT NULLIF( (SELECT version::text FROM {V('edgedb')}."_SysGlobalSchemaVersion" ), {ql(str(expected_ver))} )), 'serialization_failure', msg => ( 'Cannot serialize global DDL: ' || (SELECT version::text FROM {V('edgedb')}."_SysGlobalSchemaVersion") ) ) INTO _dummy_text ''' ) self.pgops.add(check) metadata = { ver_id: { 'id': ver_id, 'name': ver_name, 'version': str(self.scls.get_version(schema)), 'builtin': self.scls.get_builtin(schema), 'internal': self.scls.get_internal(schema), } } if backend_params.has_create_database: self.pgops.add( dbops.UpdateMetadataSection( dbops.DatabaseWithTenant(name=edbdef.EDGEDB_TEMPLATE_DB), section='GlobalSchemaVersion', metadata=metadata ) ) else: self.pgops.add( dbops.UpdateSingleDBMetadataSection( edbdef.EDGEDB_TEMPLATE_DB, section='GlobalSchemaVersion', metadata=metadata ) ) return schema class PseudoTypeCommand(MetaCommand): pass class CreatePseudoType( PseudoTypeCommand, adapts=s_pseudo.CreatePseudoType, ): pass class TupleCommand(MetaCommand): pass class CreateTuple(TupleCommand, adapts=s_types.CreateTuple): @classmethod def create_tuple( cls, tup: s_types.Tuple, schema: s_schema.Schema, conditional: bool=False, ) -> dbops.Command: elements = tup.get_element_types(schema).items(schema) name = common.get_backend_name(schema, tup, catenate=False) 
ctype = dbops.CompositeType( name=name, columns=[ dbops.Column( name=n, type=qt(types.pg_type_from_object( schema, t, persistent_tuples=True)), ) for n, t in elements ] ) neg_conditions = [] if conditional: neg_conditions.append(dbops.TypeExists(name=name)) return dbops.CreateCompositeType( type=ctype, neg_conditions=neg_conditions) def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().apply(schema, context) if self.scls.is_polymorphic(schema): return schema self.pgops.add(self.create_tuple( self.scls, schema, # XXX: WHY conditional=context.stdmode, )) return schema class AlterTuple(TupleCommand, adapts=s_types.AlterTuple): pass class RenameTuple(TupleCommand, adapts=s_types.RenameTuple): pass class DeleteTuple(TupleCommand, adapts=s_types.DeleteTuple): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: tup = schema.get_global(s_types.Tuple, self.classname) if not tup.is_polymorphic(schema): domain_name = common.get_backend_name(schema, tup, catenate=False) assert isinstance(domain_name, tuple) self.pgops.add(drop_dependant_func_cache(domain_name)) self.pgops.add(dbops.DropCompositeType(name=domain_name)) schema = super().apply(schema, context) return schema class ExprAliasCommand(MetaCommand): pass class CreateAlias( ExprAliasCommand, adapts=s_aliases.CreateAlias, ): pass class RenameAlias( ExprAliasCommand, adapts=s_aliases.RenameAlias, ): pass class AlterAlias( ExprAliasCommand, adapts=s_aliases.AlterAlias, ): pass class DeleteAlias( ExprAliasCommand, adapts=s_aliases.DeleteAlias, ): pass class GlobalCommand(MetaCommand): pass class CreateGlobal( GlobalCommand, adapts=s_globals.CreateGlobal, ): pass class RenameGlobal( GlobalCommand, adapts=s_globals.RenameGlobal, ): pass class AlterGlobal( GlobalCommand, adapts=s_globals.AlterGlobal, ): pass class SetGlobalType( GlobalCommand, # ??? 
adapts=s_globals.SetGlobalType, ): def register_config_op(self, op, context): ops = context.get(sd.DeltaRootContext).op.config_ops ops.append(op) def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().apply(schema, context) if self.reset_value: op = config_ops.Operation( opcode=config_ops.OpCode.CONFIG_RESET, scope=ql_ft.ConfigScope.GLOBAL, setting_name=str(self.scls.get_name(schema)), value=None, ) self.register_config_op(op, context) return schema class DeleteGlobal( GlobalCommand, adapts=s_globals.DeleteGlobal, ): pass class PermissionCommand(MetaCommand): pass class CreatePermission( PermissionCommand, adapts=s_permissions.CreatePermission, ): pass class AlterPermission( PermissionCommand, adapts=s_permissions.AlterPermission, ): pass class DeletePermission( PermissionCommand, adapts=s_permissions.DeletePermission, ): pass class RenamePermission( PermissionCommand, adapts=s_permissions.RenamePermission, ): pass class AccessPolicyCommand(MetaCommand): pass class CreateAccessPolicy( AccessPolicyCommand, adapts=s_policies.CreateAccessPolicy, ): pass class RenameAccessPolicy( AccessPolicyCommand, adapts=s_policies.RenameAccessPolicy, ): pass class RebaseAccessPolicy( AccessPolicyCommand, adapts=s_policies.RebaseAccessPolicy, ): pass class AlterAccessPolicy( AccessPolicyCommand, adapts=s_policies.AlterAccessPolicy, ): pass class DeleteAccessPolicy( AccessPolicyCommand, adapts=s_policies.DeleteAccessPolicy, ): pass class TriggerCommand(MetaCommand): pass class CreateTrigger( TriggerCommand, adapts=s_triggers.CreateTrigger, ): pass class RenameTrigger( TriggerCommand, adapts=s_triggers.RenameTrigger, ): pass class RebaseTrigger( TriggerCommand, adapts=s_triggers.RebaseTrigger, ): pass class AlterTrigger( TriggerCommand, adapts=s_triggers.AlterTrigger, ): pass class DeleteTrigger( TriggerCommand, adapts=s_triggers.DeleteTrigger, ): pass class RewriteCommand(MetaCommand): pass class CreateRewrite( RewriteCommand, 
adapts=s_rewrites.CreateRewrite, ): pass class RebaseRewrite( RewriteCommand, adapts=s_rewrites.RebaseRewrite, ): pass class RenameRewrite( RewriteCommand, adapts=s_rewrites.RenameRewrite, ): pass class AlterRewrite( RewriteCommand, adapts=s_rewrites.AlterRewrite, ): pass class DeleteRewrite( RewriteCommand, adapts=s_rewrites.DeleteRewrite, ): pass class TupleExprAliasCommand(MetaCommand): pass class CreateTupleExprAlias( TupleExprAliasCommand, adapts=s_types.CreateTupleExprAlias, ): pass class RenameTupleExprAlias( TupleExprAliasCommand, adapts=s_types.RenameTupleExprAlias, ): pass class AlterTupleExprAlias( TupleExprAliasCommand, adapts=s_types.AlterTupleExprAlias, ): pass class DeleteTupleExprAlias( TupleExprAliasCommand, adapts=s_types.DeleteTupleExprAlias, ): pass class ArrayCommand(MetaCommand): pass class CreateArray(ArrayCommand, adapts=s_types.CreateArray): pass class AlterArray(ArrayCommand, adapts=s_types.AlterArray): pass class RenameArray(ArrayCommand, adapts=s_types.RenameArray): pass class DeleteArray(ArrayCommand, adapts=s_types.DeleteArray): pass class ArrayExprAliasCommand(MetaCommand): pass class CreateArrayExprAlias( ArrayExprAliasCommand, adapts=s_types.CreateArrayExprAlias, ): pass class RenameArrayExprAlias( ArrayExprAliasCommand, adapts=s_types.RenameArrayExprAlias, ): pass class AlterArrayExprAlias( ArrayExprAliasCommand, adapts=s_types.AlterArrayExprAlias, ): pass class DeleteArrayExprAlias( ArrayExprAliasCommand, adapts=s_types.DeleteArrayExprAlias, ): pass class RangeCommand(MetaCommand): pass class CreateRange(RangeCommand, adapts=s_types.CreateRange): pass class AlterRange(RangeCommand, adapts=s_types.AlterRange): pass class RenameRange(RangeCommand, adapts=s_types.RenameRange): pass class DeleteRange(RangeCommand, adapts=s_types.DeleteRange): pass class RangeExprAliasCommand(MetaCommand): pass class CreateRangeExprAlias( RangeExprAliasCommand, adapts=s_types.CreateRangeExprAlias, ): pass class RenameRangeExprAlias( 
RangeExprAliasCommand, adapts=s_types.RenameRangeExprAlias, ): pass class AlterRangeExprAlias( RangeExprAliasCommand, adapts=s_types.AlterRangeExprAlias, ): pass class DeleteRangeExprAlias( RangeExprAliasCommand, adapts=s_types.DeleteRangeExprAlias, ): pass class MultiRangeCommand(MetaCommand): pass class CreateMultiRange(MultiRangeCommand, adapts=s_types.CreateMultiRange): pass class AlterMultiRange(MultiRangeCommand, adapts=s_types.AlterMultiRange): pass class RenameMultiRange(MultiRangeCommand, adapts=s_types.RenameMultiRange): pass class DeleteMultiRange(MultiRangeCommand, adapts=s_types.DeleteMultiRange): pass class ParameterCommand(MetaCommand): pass class CreateParameter( ParameterCommand, adapts=s_funcs.CreateParameter, ): pass class DeleteParameter( ParameterCommand, adapts=s_funcs.DeleteParameter, ): pass class RenameParameter( ParameterCommand, adapts=s_funcs.RenameParameter, ): pass class AlterParameter( ParameterCommand, adapts=s_funcs.AlterParameter, ): pass class FunctionCommand(MetaCommand): def get_pgname(self, func: s_funcs.Function, schema, versioned: bool=False): return common.get_backend_name( schema, func, catenate=False, versioned=versioned) def get_pgtype(self, func: s_funcs.CallableObject, obj, schema): if obj.is_any(schema): return ('anyelement',) try: return types.pg_type_from_object( schema, obj, persistent_tuples=True) except ValueError: raise errors.QueryError( f'could not compile parameter type {obj!r} ' f'of function {func.get_shortname(schema)}', span=self.span) from None def compile_default( self, func: s_funcs.Function, default: s_expr.Expression, schema ): try: comp = default.compiled( schema=schema, as_fragment=True, context=None, ) ir = comp.irast if not irutils.is_const(ir.expr): raise ValueError('expression not constant') sql_res = compiler.compile_ir_to_sql_tree( ir.expr, singleton_mode=True) return codegen.generate_source(sql_res.ast) except Exception as ex: raise errors.QueryError( f'could not compile default expression 
{default!r} ' f'of function {func.get_shortname(schema)}: {ex}', span=self.span) from ex def compile_args(self, func: s_funcs.Function, schema): func_params = func.get_params(schema) has_inlined_defaults = func.has_inlined_defaults(schema) args = [] func_language = func.get_language(schema) if func_language is ql_ast.Language.EdgeQL: args.append(('__edb_json_globals__', ('jsonb',), None)) if has_inlined_defaults: args.append(('__defaults_mask__', ('bytea',), None)) compile_defaults = not ( has_inlined_defaults or func_params.find_named_only(schema) ) for param in func_params.get_in_canonical_order(schema): param_type = param.get_type(schema) param_default = param.get_default(schema) pg_at = self.get_pgtype(func, param_type, schema) default = None if compile_defaults and param_default is not None: default = self.compile_default(func, param_default, schema) pn = param.get_parameter_name(schema) args.append((pn, pg_at, default)) if param_type.is_object_type(): args.append((f'__{pn}__type', ('uuid',), None)) return args def make_function(self, func: s_funcs.Function, code, schema): func_return_typemod = func.get_return_typemod(schema) func_params = func.get_params(schema) name = self.get_pgname(func, schema, versioned=False) return self.get_function_type(name)( name=name, args=self.compile_args(func, schema), has_variadic=func_params.find_variadic(schema) is not None, set_returning=func_return_typemod is ql_ft.TypeModifier.SetOfType, volatility=func.get_volatility(schema), strict=func.get_impl_is_strict(schema), returns=self.get_pgtype( func, func.get_return_type(schema), schema), text=code, ) def compile_sql_function(self, func: s_funcs.Function, schema): return self.make_function(func, func.get_code(schema), schema) def _compile_edgeql_function( self, schema: s_schema.Schema, context: sd.CommandContext, func: s_funcs.Function, body: s_expr.Expression, ) -> s_expr.CompiledExpression: if isinstance(body, s_expr.CompiledExpression): return body # HACK: When an object 
type selected by a function (via # inheritance) is dropped, the function gets # recompiled. Unfortunately, 'caused' subcommands run *before* # the object is actually deleted, and so we would ordinarily # still try to select from the deleted object. To avoid # needing to add *another* type of subcommand, we work around # this by temporarily stripping all objects that are about to # be deleted from the schema. for ctx in context.stack: if isinstance(ctx.op, s_objtypes.DeleteObjectType): # Also get the pointers, since we look at pointer descendents. # This is really all a pretty bad hack. for ptr in ctx.op.scls.get_pointers(schema).objects(schema): schema = schema.delete(ptr) schema = schema.delete(ctx.op.scls) elif isinstance(ctx.op, s_pointers.DeletePointer): schema = schema.delete(ctx.op.scls) return s_funcs.compile_function( schema, context, body=body, func_name=func.get_name(schema), params=func.get_params(schema), language=ql_ast.Language.EdgeQL, return_type=func.get_return_type(schema), return_typemod=func.get_return_typemod(schema), ) def fix_return_type( self, func: s_funcs.Function, nativecode: s_expr.CompiledExpression, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_expr.CompiledExpression: return_type = func.get_return_type(schema) ir = nativecode.irast if not ( return_type.is_object_type() or s_types.is_type_compatible(return_type, ir.stype, schema=nativecode.schema) ): # Add a cast and recompile it qlexpr = qlcompiler.astutils.ensure_ql_query( ql_ast.TypeCast( type=s_utils.typeref_to_ast(schema, return_type), expr=nativecode.parse(), ) ) nativecode = self._compile_edgeql_function( schema, context, func, type(nativecode).from_ast(qlexpr, schema), ) return nativecode def compile_edgeql_function_body( self, func: s_funcs.Function, schema: s_schema.Schema, context: sd.CommandContext, ) -> str: nativecode = func.get_nativecode(schema) assert nativecode nativecode = self._compile_edgeql_function( schema, context, func, nativecode, ) nativecode = 
self.fix_return_type(func, nativecode, schema, context) sql_res = compiler.compile_ir_to_sql_tree( nativecode.irast, ignore_shapes=True, explicit_top_cast=irtyputils.type_to_typeref( # note: no cache schema, func.get_return_type(schema), cache=None), output_format=compiler.OutputFormat.NATIVE, named_param_prefix=self.get_pgname(func, schema)[-1:], versioned_stdlib=context.stdmode, ) return codegen.generate_source(sql_res.ast) def compile_edgeql_overloaded_function_body( self, func: s_funcs.Function, overloads: list[s_funcs.Function], ov_param_idx: int, schema: s_schema.Schema, context: sd.CommandContext, ) -> str: func_return_typemod = func.get_return_typemod(schema) set_returning = func_return_typemod is ql_ft.TypeModifier.SetOfType my_params = func.get_params(schema).objects(schema) param_name = my_params[ov_param_idx].get_parameter_name(schema) type_param_name = f'__{param_name}__type' cases = {} all_overloads = list(overloads) if not isinstance(self, DeleteFunction): all_overloads.append(func) for overload in all_overloads: ov_p = tuple(overload.get_params(schema).objects(schema)) ov_p_t = ov_p[ov_param_idx].get_type(schema) ov_body = self.compile_edgeql_function_body( overload, schema, context) if set_returning: case = ( f"(SELECT * FROM ({ov_body}) AS q " f"WHERE ancestor = {ql(str(ov_p_t.id))})" ) else: case = ( f"WHEN ancestor = {ql(str(ov_p_t.id))} " f"THEN \n({ov_body})" ) cases[ov_p_t] = case impl_ids = ', '.join(f'{ql(str(t.id))}::uuid' for t in cases) branches = list(cases.values()) # N.B: edgedb_VER.raise and coalesce are used below instead of # raise_on_null, because the latter somehow results in a # significantly more complex query plan. 
matching_impl = f""" coalesce( ( SELECT ancestor FROM (SELECT {qi(type_param_name)} AS ancestor, -1 AS index UNION ALL SELECT target AS ancestor, index FROM edgedb._object_ancestors WHERE source = {qi(type_param_name)} ) a WHERE ancestor IN ({impl_ids}) ORDER BY index LIMIT 1 ), edgedb.raise( NULL::uuid, 'assert_failure', msg => format( 'unhandled object type %s in overloaded function', {qi(type_param_name)} ) ) ) AS impl(ancestor) """ if set_returning: arms = "\nUNION ALL\n".join(branches) return f""" SELECT q.* FROM {matching_impl}, LATERAL ( {arms} ) AS q """ else: arms = "\n".join(branches) return f""" SELECT (CASE {arms} END) FROM {matching_impl} """ def compile_edgeql_function( self, func: s_funcs.Function, schema: s_schema.Schema, context: sd.CommandContext, ) -> tuple[Optional[dbops.Function], bool]: if func.get_volatility(schema) == ql_ft.Volatility.Modifying: # Modifying functions cannot be compiled correctly and should be # inlined at the call point. if func.find_object_param_overloads(schema) is not None: raise errors.SchemaDefinitionError( f"cannot overload an existing function " f"with a modifying function: " f"'{func.get_shortname(schema)}'", span=self.span, ) return None, False nativecode: s_expr.Expression = not_none(func.get_nativecode(schema)) compiled_expr = self._compile_edgeql_function( schema, context, func, nativecode ) compiled_expr = self.fix_return_type( func, compiled_expr, schema, context ) replace = False obj_overload = func.find_object_param_overloads(schema) if obj_overload is not None: overloads, ov_param_idx = obj_overload if any( overload.get_volatility(schema) == ql_ft.Volatility.Modifying for overload in overloads ): raise errors.SchemaDefinitionError( f"cannot overload an existing modifying function: " f"'{func.get_shortname(schema)}'", span=self.span, ) body = self.compile_edgeql_overloaded_function_body( func, overloads, ov_param_idx, schema, context ) replace = True else: sql_res = compiler.compile_ir_to_sql_tree( 
                compiled_expr.irast,
                ignore_shapes=True,
                explicit_top_cast=irtyputils.type_to_typeref(  # note: no cache
                    schema, func.get_return_type(schema), cache=None),
                output_format=compiler.OutputFormat.NATIVE,
                named_param_prefix=self.get_pgname(func, schema)[-1:],
                backend_runtime_params=context.backend_runtime_params,
                versioned_stdlib=context.stdmode,
            )

            pg_debug.dump_ast_and_query(sql_res.ast, compiled_expr.irast)

            body = codegen.generate_source(sql_res.ast)

        return self.make_function(func, body, schema), replace

    def sql_rval_consistency_check(
        self,
        cobj: s_funcs.CallableObject,
        expr: str,
        schema: s_schema.Schema,
    ) -> dbops.Command:
        fname = cobj.get_verbosename(schema)
        rtype = types.pg_type_from_object(
            schema,
            cobj.get_return_type(schema),
            persistent_tuples=True,
        )
        rtype_desc = '.'.join(rtype)

        # Determine the actual returned type of the SQL function.
        # We can't easily do this by looking in system catalogs because
        # of polymorphic dispatch, but, fortunately, there's pg_typeof().
        # We only need to be sure to actually NOT call the target function,
        # as we can't assume how it'll behave with dummy inputs. Hence the
        # weird-looking query below, where we rely on the Postgres executor
        # to skip the call, because no rows satisfy the WHERE condition, but
        # we then still generate a NULL row via a LEFT JOIN.
        f_test = textwrap.dedent(f'''\
            (SELECT
                pg_typeof(f.i)
            FROM
                (SELECT NULL::text) AS spreader
                LEFT JOIN (SELECT {expr} WHERE False) AS f(i) ON (true))''')

        check = dbops.Query(text=f'''
            PERFORM
                edgedb_VER.raise_on_not_null(
                    NULLIF(
                        pg_typeof(NULL::{qt(rtype)}),
                        {f_test}
                    ),
                    'invalid_function_definition',
                    msg => format(
                        '%s is declared to return SQL type "%s", but '
                        || 'the underlying SQL function returns "%s"',
                        {ql(fname)},
                        {ql(rtype_desc)},
                        {f_test}::text
                    ),
                    hint => (
                        'Declare the function with '
                        || '`force_return_cast := true`, '
                        || 'or add an explicit cast to its body.'
                    )
                );
        ''')

        return check

    def sql_strict_consistency_check(
        self,
        cobj: s_funcs.CallableObject,
        func: str,
        schema: s_schema.Schema,
    ) -> dbops.Command:
        fname = cobj.get_verbosename(schema)

        # impl_is_strict means that the function is strict in all
        # singleton arguments, so we don't need to do the check if
        # no such arguments exist.
        if (
            not cobj.get_impl_is_strict(schema)
            or not cobj.get_params(schema).has_type_mod(
                schema, ql_ft.TypeModifier.SingletonType
            )
        ):
            return dbops.CommandGroup()

        if '.' in func:
            ns, func = func.split('.')
        else:
            ns = 'pg_catalog'

        f_test = textwrap.dedent(f'''\
            COALESCE((
                SELECT bool_and(proisstrict) FROM pg_proc
                INNER JOIN pg_namespace ON pg_namespace.oid = pronamespace
                WHERE proname = {ql(func)} AND nspname = {ql(ns)}
            ), false)
        ''')

        check = dbops.Query(text=f'''
            PERFORM
                edgedb_VER.raise_on_null(
                    NULLIF(
                        false,
                        {f_test}
                    ),
                    'invalid_function_definition',
                    msg => format(
                        '%s is declared to have a strict impl but does not',
                        {ql(fname)}
                    ),
                    hint => (
                        'Add `impl_is_strict := false` to the declaration.'
                    )
                );
        ''')

        return check

    def get_dummy_func_call(
        self,
        cobj: s_funcs.CallableObject,
        sql_func: Sequence[str],
        schema: s_schema.Schema,
    ) -> str:
        name = common.maybe_versioned_name(
            tuple(sql_func),
            versioned=(
                cobj.get_name(schema).get_root_module_name().name != 'ext'
            ),
        )

        args = []
        func_params = cobj.get_params(schema)
        for param in func_params.get_in_canonical_order(schema):
            param_type = param.get_type(schema)
            pg_at = self.get_pgtype(cobj, param_type, schema)
            args.append(f'NULL::{qt(pg_at)}')
            if isinstance(param_type, s_objtypes.ObjectType):
                args.append('NULL::uuid')

        return f'{q(*name)}({", ".join(args)})'

    def make_op(
        self,
        func: s_funcs.Function,
        schema: s_schema.Schema,
        context: sd.CommandContext,
        *,
        or_replace: bool = False,
    ) -> Iterable[dbops.Command]:
        if func.get_from_expr(schema):
            # Intrinsic function, handled directly by the compiler.
return () elif sql_func := func.get_from_function(schema): func_params = func.get_params(schema) if ( func.get_force_return_cast(schema) or func_params.has_polymorphic(schema) or func.get_sql_func_has_out_params(schema) ): return () else: # Function backed directly by an SQL function. # Check the consistency of the return type. dexpr = self.get_dummy_func_call( func, sql_func.split('.'), schema) return ( self.sql_rval_consistency_check(func, dexpr, schema), self.sql_strict_consistency_check(func, sql_func, schema), ) else: func_language = func.get_language(schema) dbf: Optional[dbops.Function] if func_language is ql_ast.Language.SQL: dbf = self.compile_sql_function(func, schema) elif func_language is ql_ast.Language.EdgeQL: dbf, overload_replace = self.compile_edgeql_function( func, schema, context ) if overload_replace: or_replace = True else: raise errors.QueryError( f'cannot compile function {func.get_shortname(schema)}: ' f'unsupported language {func_language}', span=self.span) ops: list[dbops.Command] = [] if dbf is not None: ops.append(dbops.CreateFunction(dbf, or_replace=or_replace)) self.maybe_trampoline(dbf, context) return ops class CreateFunction( FunctionCommand, adapts=s_funcs.CreateFunction, ): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().apply(schema, context) ops = self.make_op(self.scls, schema, context) self.pgops.update(ops) return schema class RenameFunction(FunctionCommand, adapts=s_funcs.RenameFunction): pass class AlterFunction(FunctionCommand, adapts=s_funcs.AlterFunction): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().apply(schema, context) if self.metadata_only: return schema if ( self.get_attribute_value('volatility') is not None or self.get_attribute_value('nativecode') is not None or self.get_attribute_value('code') is not None ): self.pgops.update( self.make_op(self.scls, schema, context, or_replace=True)) 
        return schema


class DeleteFunction(FunctionCommand, adapts=s_funcs.DeleteFunction):

    def apply(
        self,
        schema: s_schema.Schema,
        context: sd.CommandContext,
    ) -> s_schema.Schema:
        func = self.get_object(schema, context)
        nativecode = func.get_nativecode(schema)

        if func.get_code(schema) or nativecode:
            # An EdgeQL or a SQL function
            # (not just an alias to a SQL function).

            overload = False
            if nativecode and func.find_object_param_overloads(schema):
                dbf, overload_replace = self.compile_edgeql_function(
                    func, schema, context
                )
                if dbf is not None and overload_replace:
                    self.pgops.add(dbops.CreateFunction(dbf, or_replace=True))
                    overload = True

            if not overload:
                variadic = func.get_params(schema).find_variadic(schema)
                if func.get_volatility(schema) != ql_ft.Volatility.Modifying:
                    # Modifying functions are not compiled.
                    # See: compile_edgeql_function
                    self.pgops.add(
                        dbops.DropFunction(
                            name=self.get_pgname(func, schema),
                            args=self.compile_args(func, schema),
                            has_variadic=variadic is not None,
                        )
                    )

        return super().apply(schema, context)


class OperatorCommand(FunctionCommand):

    def oper_name_to_pg_name(
        self,
        schema,
        name: sn.QualName,
    ) -> tuple[str, str]:
        return common.get_operator_backend_name(
            name, catenate=False)

    def get_pg_operands(self, schema, oper: s_opers.Operator):
        left_type = None
        right_type = None
        oper_params = list(oper.get_params(schema).objects(schema))
        oper_kind = oper.get_operator_kind(schema)

        if oper_kind is ql_ft.OperatorKind.Infix:
            left_type = types.pg_type_from_object(
                schema, oper_params[0].get_type(schema))

            right_type = types.pg_type_from_object(
                schema, oper_params[1].get_type(schema))

        elif oper_kind is ql_ft.OperatorKind.Prefix:
            right_type = types.pg_type_from_object(
                schema, oper_params[0].get_type(schema))

        elif oper_kind is ql_ft.OperatorKind.Postfix:
            left_type = types.pg_type_from_object(
                schema, oper_params[0].get_type(schema))

        else:
            raise RuntimeError(
                f'unexpected operator type: {oper_kind!r}')

        return left_type, right_type

    # FIXME: We should split FunctionCommand into CallableCommand
    # and FunctionCommand and only inherit from CallableCommand
    def compile_args(self, oper: s_opers.Operator, schema):  # type: ignore
        args = []
        oper_params = oper.get_params(schema)

        for param in oper_params.get_in_canonical_order(schema):
            pg_at = self.get_pgtype(oper, param.get_type(schema), schema)
            args.append((param.get_parameter_name(schema), pg_at))

        return args

    def make_operator_function(self, oper: s_opers.Operator, schema):
        name = common.get_backend_name(
            schema, oper, catenate=False, versioned=False, aspect='function')

        return self.get_function_type(name)(
            name=name,
            args=self.compile_args(oper, schema),
            volatility=oper.get_volatility(schema),
            returns=self.get_pgtype(
                oper, oper.get_return_type(schema), schema),
            text=not_none(oper.get_code(schema)),
        )

    def get_dummy_operator_call(
        self,
        oper: s_opers.Operator,
        pgop: str,
        from_args: Sequence[tuple[str, ...] | str],
        schema: s_schema.Schema,
    ) -> str:
        # Need a proxy function with casts
        oper_kind = oper.get_operator_kind(schema)

        if oper_kind is ql_ft.OperatorKind.Infix:
            op = f'NULL::{qt(from_args[0])} {pgop} NULL::{qt(from_args[1])}'
        elif oper_kind is ql_ft.OperatorKind.Postfix:
            op = f'NULL::{qt(from_args[0])} {pgop}'
        elif oper_kind is ql_ft.OperatorKind.Prefix:
            op = f'{pgop} NULL::{qt(from_args[1])}'
        else:
            raise RuntimeError(f'unexpected operator kind: {oper_kind!r}')

        return op


class CreateOperator(OperatorCommand, adapts=s_opers.CreateOperator):

    def apply(
        self,
        schema: s_schema.Schema,
        context: sd.CommandContext,
    ) -> s_schema.Schema:
        schema = super().apply(schema, context)
        oper = self.scls
        if oper.get_abstract(schema):
            return schema

        params = oper.get_params(schema)
        oper_language = oper.get_language(schema)
        oper_fromop = oper.get_from_operator(schema)
        oper_fromfunc = oper.get_from_function(schema)
        oper_code = oper.get_code(schema)

        # We support having both fromop and one of the others for
        # "legacy" purposes, but ignore it.
if oper_code or oper_fromfunc: oper_fromop = None if oper_language is ql_ast.Language.SQL and oper_fromop: pg_oper_name = oper_fromop[0] args = self.get_pg_operands(schema, oper) if len(oper_fromop) > 1: # Explicit operand types given in FROM SQL OPERATOR. from_args = oper_fromop[1:] else: from_args = args if ( pg_oper_name is not None and not params.has_polymorphic(schema) and not oper.get_force_return_cast(schema) ): cexpr = self.get_dummy_operator_call( oper, pg_oper_name, from_args, schema) # We don't do a strictness consistency check for # USING SQL OPERATOR because they are heavily # overloaded, and so we'd need to take the types # into account; this is doable, but doesn't seem # worth doing since the only non-strict operator # is || on arrays, and we use array_cat for that # anyway! check = self.sql_rval_consistency_check( oper, cexpr, schema) self.pgops.add(check) elif oper_language is ql_ast.Language.SQL and oper_code: args = self.get_pg_operands(schema, oper) oper_func = self.make_operator_function(oper, schema) self.pgops.add(dbops.CreateFunction(oper_func)) self.maybe_trampoline(oper_func, context) if not params.has_polymorphic(schema): cexpr = self.get_dummy_func_call( oper, oper_func.name, schema) check = self.sql_rval_consistency_check(oper, cexpr, schema) self.pgops.add(check) elif oper_language is ql_ast.Language.SQL and oper_fromfunc: args = self.get_pg_operands(schema, oper) oper_func_name = oper_fromfunc[0] if len(oper_fromfunc) > 1: args = oper_fromfunc[1:] cargs = [] for t in args: if t is not None: cargs.append(f'NULL::{qt(t)}') if not params.has_polymorphic(schema): cexpr = f"{qi(oper_func_name)}({', '.join(cargs)})" check = self.sql_rval_consistency_check(oper, cexpr, schema) self.pgops.add(check) check2 = self.sql_strict_consistency_check( oper, oper_func_name, schema) self.pgops.add(check2) elif oper.get_from_expr(schema): # This operator is handled by the compiler and does not # need explicit representation in the backend. 
pass else: raise errors.QueryError( f'cannot create operator {oper.get_shortname(schema)}: ' f'only "FROM SQL" and "FROM SQL OPERATOR" operators ' f'are currently supported', span=self.span) return schema class RenameOperator(OperatorCommand, adapts=s_opers.RenameOperator): pass class AlterOperator(OperatorCommand, adapts=s_opers.AlterOperator): pass class DeleteOperator(OperatorCommand, adapts=s_opers.DeleteOperator): pass class CastCommand(MetaCommand): def make_cast_function(self, cast: s_casts.Cast, schema): name = common.get_backend_name( schema, cast, catenate=False, versioned=False, aspect='function') args: Sequence[dbops.FunctionArg] = [ ( 'val', types.pg_type_from_object(schema, cast.get_from_type(schema)) ), ('detail', ('text',), "''"), ] returns = types.pg_type_from_object(schema, cast.get_to_type(schema)) # N.B: Semantically, strict *ought* to be true, since we want # all of our casts to have strict behavior. Unfortunately, # actually marking them as strict causes a huge performance # regression when bootstrapping (and probably anything else that # is heavy on json casts), so instead we just need to make sure # to write cast code that is naturally strict (this is enforced # by test_edgeql_casts_all_null). 
        return self.get_function_type(name)(
            name=name,
            args=args,
            returns=returns,
            strict=False,
            wrapper_volatility=cast.get_volatility(schema),
            text=not_none(cast.get_code(schema)),
        )


class CreateCast(CastCommand, adapts=s_casts.CreateCast):

    def apply(
        self,
        schema: s_schema.Schema,
        context: sd.CommandContext,
    ) -> s_schema.Schema:
        schema = super().apply(schema, context)
        cast = self.scls
        cast_language = cast.get_language(schema)
        cast_code = cast.get_code(schema)
        from_cast = cast.get_from_cast(schema)
        from_expr = cast.get_from_expr(schema)

        if cast_language is ql_ast.Language.SQL and cast_code:
            cast_func = self.make_cast_function(cast, schema)
            self.pgops.add(dbops.CreateFunction(cast_func))
            self.maybe_trampoline(cast_func, context)

        elif from_cast is not None or from_expr is not None:
            # This cast is handled by the compiler and does not
            # need explicit representation in the backend.
            pass

        else:
            raise errors.QueryError(
                'cannot create cast: '
                'only "FROM SQL" and "FROM SQL FUNCTION" casts '
                'are currently supported',
                span=self.span)

        return schema


class RenameCast(CastCommand, adapts=s_casts.RenameCast):
    pass


class AlterCast(CastCommand, adapts=s_casts.AlterCast):
    pass


class DeleteCast(CastCommand, adapts=s_casts.DeleteCast):

    def apply(
        self,
        schema: s_schema.Schema,
        context: sd.CommandContext,
    ) -> s_schema.Schema:
        orig_schema = schema
        cast = schema.get(self.classname, type=s_casts.Cast)
        cast_language = cast.get_language(schema)
        cast_code = cast.get_code(schema)

        schema = super().apply(schema, context)

        if cast_language is ql_ast.Language.SQL and cast_code:
            cast_func = self.make_cast_function(cast, orig_schema)
            self.pgops.add(dbops.DropFunction(
                cast_func.name, cast_func.args))

        return schema


class AnnotationCommand(MetaCommand):
    pass


class CreateAnnotation(AnnotationCommand, adapts=s_anno.CreateAnnotation):
    pass


class RenameAnnotation(AnnotationCommand, adapts=s_anno.RenameAnnotation):
    pass


class AlterAnnotation(AnnotationCommand, adapts=s_anno.AlterAnnotation):
    pass
class DeleteAnnotation(AnnotationCommand, adapts=s_anno.DeleteAnnotation): pass class AnnotationValueCommand(MetaCommand): pass class CreateAnnotationValue( AnnotationValueCommand, adapts=s_anno.CreateAnnotationValue, ): pass class AlterAnnotationValue( AnnotationValueCommand, adapts=s_anno.AlterAnnotationValue, ): pass class AlterAnnotationValueOwned( AnnotationValueCommand, adapts=s_anno.AlterAnnotationValueOwned, ): pass class RenameAnnotationValue( AnnotationValueCommand, adapts=s_anno.RenameAnnotationValue, ): pass class RebaseAnnotationValue( AnnotationValueCommand, adapts=s_anno.RebaseAnnotationValue, ): pass class DeleteAnnotationValue( AnnotationValueCommand, adapts=s_anno.DeleteAnnotationValue, ): pass class ConstraintCommand(MetaCommand): @classmethod def constraint_is_effective( cls, schema: s_schema.Schema, constraint: s_constr.Constraint ) -> bool: subject = constraint.get_subject(schema) if subject is None: return False ancestors = [ a for a in constraint.get_ancestors(schema).objects(schema) if not a.is_non_concrete(schema) ] if ( constraint.get_delegated(schema) and all(ancestor.get_delegated(schema) for ancestor in ancestors) ): return False if irtyputils.is_cfg_view(subject, schema): return False match subject: case s_pointers.Pointer(): if subject.is_non_concrete(schema): return True else: return types.has_table(subject.get_source(schema), schema) case s_objtypes.ObjectType(): return types.has_table(subject, schema) case s_scalars.ScalarType(): return not subject.get_abstract(schema) raise NotImplementedError(subject) def schedule_relatives_constraint_trigger_update( self, constraint: s_constr.Constraint, orig_schema: s_schema.Schema, curr_schema: s_schema.Schema, context: sd.CommandContext, ): # Find all origins whose relationship with the constraint has changed. 
orig_origins: dict[uuid.UUID, s_constr.Constraint] = {} if orig_schema.has_object(constraint.id): for origin in constraint.get_constraint_origins(orig_schema): orig_origins[origin.id] = origin curr_origins: dict[uuid.UUID, s_constr.Constraint] = {} if curr_schema.has_object(constraint.id): for origin in constraint.get_constraint_origins(curr_schema): curr_origins[origin.id] = origin # Find all constraints whose inheritance relationship with the # constraint has changed. relative_ids: set[uuid.UUID] = set() for origin_id in (orig_origins.keys() - curr_origins.keys()): origin = orig_origins[origin_id] for relative in ( [origin] + list(origin.descendants(orig_schema)) ): if not curr_schema.has_object(relative.id): # The constraint was deleted, updating the triggers is # not needed. continue relative_ids.add(relative.id) for origin_id in (curr_origins.keys() - orig_origins.keys()): origin = curr_origins[origin_id] for relative in ( [origin] + list(origin.descendants(curr_schema)) ): relative_ids.add(relative.id) relatives: list[s_constr.Constraint] = [ curr_schema.get_by_id(relative_id, type=s_constr.Constraint) for relative_id in relative_ids ] op = dbops.CommandGroup() # Schedule constraint trigger updates for relatives. 
for relative in relatives: self.schedule_constraint_trigger_update( relative, curr_schema, context, s_sources.SourceCommandContext, ) return op @staticmethod def create_constraint( current_command: MetaCommand, constraint: s_constr.Constraint, schema: s_schema.Schema, context: sd.CommandContext, span: Optional[parsing.Span] = None, *, create_triggers_if_needed: bool = True, ) -> dbops.Command: op = dbops.CommandGroup() if ConstraintCommand.constraint_is_effective(schema, constraint): subject = constraint.get_subject(schema) if subject is not None: op.add_command(ConstraintCommand._get_create_ops( current_command, constraint, schema, context, span, create_triggers_if_needed=create_triggers_if_needed, )) return op @staticmethod def _get_create_ops( current_command: MetaCommand, constraint: s_constr.Constraint, schema: s_schema.Schema, context: sd.CommandContext, span: Optional[parsing.Span] = None, *, create_triggers_if_needed: bool = True, ) -> dbops.CommandGroup: subject = constraint.get_subject(schema) assert subject is not None compiled_constraint = schemamech.compile_constraint( subject, constraint, schema, span, ) op = compiled_constraint.create_ops() if create_triggers_if_needed: # Constraint triggers are created last to avoid repeated # recompilation. 
current_command.schedule_constraint_trigger_update( constraint, schema, context, s_sources.SourceCommandContext, ) return op @staticmethod def _get_alter_ops( current_command: MetaCommand, constraint: s_constr.Constraint, orig_schema: s_schema.Schema, schema: s_schema.Schema, context: sd.CommandContext, span: Optional[parsing.Span] = None, ) -> dbops.CommandGroup: orig_subject = constraint.get_subject(orig_schema) assert orig_subject is not None orig_compiled_constraint = schemamech.compile_constraint( orig_subject, constraint, orig_schema, span, ) subject = constraint.get_subject(schema) assert subject is not None compiled_constraint = schemamech.compile_constraint( subject, constraint, schema, span, ) op = compiled_constraint.alter_ops(orig_compiled_constraint) # Constraint triggers are created last to avoid repeated recompilation. current_command.schedule_constraint_trigger_update( constraint, schema, context, s_sources.SourceCommandContext, ) return op @classmethod def delete_constraint( cls, constraint: s_constr.Constraint, schema: s_schema.Schema, span: Optional[parsing.Span] = None, ) -> dbops.Command: op = dbops.CommandGroup() if cls.constraint_is_effective(schema, constraint): subject = constraint.get_subject(schema) if subject is not None: bconstr = schemamech.compile_constraint( subject, constraint, schema, span ) op.add_command(bconstr.delete_ops()) return op @classmethod def enforce_constraint( cls, constraint: s_constr.Constraint, schema: s_schema.Schema, span: Optional[parsing.Span] = None, ) -> dbops.Command: if cls.constraint_is_effective(schema, constraint): subject = constraint.get_subject(schema) if subject is not None: bconstr = schemamech.compile_constraint( subject, constraint, schema, span ) return bconstr.enforce_ops() else: return dbops.CommandGroup() class CreateConstraint(ConstraintCommand, adapts=s_constr.CreateConstraint): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: orig_schema = schema 
schema = super().apply(schema, context) constraint: s_constr.Constraint = self.scls self.pgops.add(ConstraintCommand.create_constraint( self, constraint, schema, context, self.span )) self.pgops.add(self.schedule_relatives_constraint_trigger_update( constraint, orig_schema, schema, context, )) # If the constraint is being added to existing data, # we need to enforce it on the existing data. (This only # matters when inheritance is in play and we use triggers # to enforce exclusivity across tables.) if ( (subject := constraint.get_subject(schema)) and isinstance( subject, (s_objtypes.ObjectType, s_pointers.Pointer)) and not context.is_creating(subject) ): self.pgops.add(self.enforce_constraint( constraint, schema, self.span )) return schema class RenameConstraint(ConstraintCommand, adapts=s_constr.RenameConstraint): pass class AlterConstraintOwned( ConstraintCommand, adapts=s_constr.AlterConstraintOwned, ): pass class AlterConstraint( ConstraintCommand, adapts=s_constr.AlterConstraint, ): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: orig_schema = schema schema = super().apply(schema, context) constraint: s_constr.Constraint = self.scls if self.metadata_only: return schema if ( not self.constraint_is_effective(schema, constraint) and not self.constraint_is_effective(orig_schema, constraint) ): return schema subject = constraint.get_subject(schema) subcommands = list(self.get_subcommands()) if (not subcommands or isinstance(subcommands[0], s_constr.RenameConstraint)): # This is a pure rename, so everything had been handled by # RenameConstraint above. return schema if subject is not None: if pcontext := context.get(s_pointers.PointerCommandContext): orig_schema = pcontext.original_schema op = dbops.CommandGroup() if not self.constraint_is_effective(orig_schema, constraint): op.add_command(ConstraintCommand._get_create_ops( self, constraint, schema, context, self.span )) # XXX: I don't think any of this logic is needed?? 
for child in constraint.children(schema): op.add_command(ConstraintCommand._get_alter_ops( self, child, orig_schema, schema, context, self.span )) elif not self.constraint_is_effective(schema, constraint): op.add_command(ConstraintCommand._get_alter_ops( self, constraint, orig_schema, schema, context, self.span )) for child in constraint.children(schema): op.add_command(ConstraintCommand._get_alter_ops( self, child, orig_schema, schema, context, self.span )) else: op.add_command(ConstraintCommand._get_alter_ops( self, constraint, orig_schema, schema, context, self.span )) self.pgops.add(op) if ( (subject := constraint.get_subject(schema)) and isinstance( subject, (s_objtypes.ObjectType, s_pointers.Pointer)) and not context.is_creating(subject) and not context.is_deleting(subject) ): self.pgops.add(self.enforce_constraint( constraint, schema, self.span )) self.pgops.add(self.schedule_relatives_constraint_trigger_update( constraint, orig_schema, schema, context, )) return schema class DeleteConstraint(ConstraintCommand, adapts=s_constr.DeleteConstraint): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: delta_root_ctx = context.top() orig_schema = delta_root_ctx.original_schema constraint: s_constr.Constraint = ( schema.get(self.classname, type=s_constr.Constraint) ) schema = super().apply(schema, context) op = self.delete_constraint( constraint, orig_schema, self.span ) self.pgops.add(op) self.pgops.add(self.schedule_relatives_constraint_trigger_update( constraint, orig_schema, schema, context, )) return schema class RebaseConstraint(ConstraintCommand, adapts=s_constr.RebaseConstraint): pass class AliasCapableMetaCommand(MetaCommand): pass class ScalarTypeMetaCommand(AliasCapableMetaCommand): @classmethod def is_sequence(cls, schema, scalar): seq = schema.get('std::sequence', default=None) return seq is not None and scalar.issubclass(schema, seq) class CreateScalarType(ScalarTypeMetaCommand, 
                       adapts=s_scalars.CreateScalarType):

    @classmethod
    def create_scalar(
        cls,
        scalar: s_scalars.ScalarType,
        default: Optional[s_expr.Expression],
        schema: s_schema.Schema,
        context: sd.CommandContext,
    ) -> dbops.Command:

        if scalar.is_concrete_enum(schema):
            enum_values = scalar.get_enum_values(schema)
            assert enum_values
            return CreateScalarType.create_enum(
                scalar, enum_values, schema, context)
        else:
            ops = dbops.CommandGroup()
            if scalar.get_transient(schema):
                return ops

            base = types.get_scalar_base(schema, scalar)

            new_domain_name = types.pg_type_from_scalar(schema, scalar)

            if cls.is_sequence(schema, scalar):
                seq_name = common.get_backend_name(
                    schema, scalar, catenate=False, aspect='sequence')
                ops.add_command(dbops.CreateSequence(name=seq_name))

            domain = dbops.Domain(name=new_domain_name, base=base)
            ops.add_command(dbops.CreateDomain(domain=domain))

            if (default is not None
                    and not isinstance(default, s_expr.Expression)):
                # We only care to support literal defaults here. Supporting
                # query-based defaults makes no sense at the database
                # level, since the database forbids queries in DEFAULT, and
                # pre-calculating the value makes no sense either, since
                # the whole point of query defaults is for them to be
                # dynamic.
                ops.add_command(
                    dbops.AlterDomainAlterDefault(
                        name=new_domain_name, default=default))

            return ops

    @classmethod
    def create_enum(
        cls,
        scalar: s_scalars.ScalarType,
        values: Sequence[str],
        schema: s_schema.Schema,
        context: sd.CommandContext,
    ) -> dbops.Command:
        ops = dbops.CommandGroup()

        new_enum_name = common.get_backend_name(schema, scalar, catenate=False)

        neg_conditions = []
        if context.stdmode:
            neg_conditions.append(dbops.EnumExists(name=new_enum_name))

        ops.add_command(
            dbops.CreateEnum(
                dbops.Enum(name=new_enum_name, values=values),
                neg_conditions=neg_conditions,
            )
        )

        fcls = cls.get_function_type(new_enum_name)

        # Cast wrapper function is needed for immutable casts, which are
        # needed for casting within indexes/constraints.
        # (Postgres casts are only stable)
        cast_func_name = common.get_backend_name(
            schema, scalar, catenate=False, aspect="enum-cast-from-str"
        )
        cast_func = fcls(
            name=cast_func_name,
            args=[("value", ("anyelement",))],
            volatility="immutable",
            returns=new_enum_name,
            text=f"SELECT value::{qt(new_enum_name)}",
        )
        ops.add_command(dbops.CreateFunction(cast_func))
        cls.maybe_trampoline(cast_func, context)

        # Similarly, uncast from enum to str
        uncast_func_name = common.get_backend_name(
            schema, scalar, catenate=False, aspect="enum-cast-into-str"
        )
        uncast_func = fcls(
            name=uncast_func_name,
            args=[("value", ("anyelement",))],
            volatility="immutable",
            returns="text",
            text="SELECT value::text",
        )
        ops.add_command(dbops.CreateFunction(uncast_func))
        cls.maybe_trampoline(uncast_func, context)
        return ops

    def _create_begin(
        self,
        schema: s_schema.Schema,
        context: sd.CommandContext,
    ) -> s_schema.Schema:
        schema = super()._create_begin(schema, context)
        scalar = self.scls

        if scalar.get_abstract(schema):
            return schema

        if types.is_builtin_scalar(schema, scalar):
            return schema

        # If this type exposes a SQL type or is a parameterized
        # subtype of a SQL type, we don't create a real type here.
        if scalar.resolve_sql_type_scheme(schema)[0]:
            return schema

        default = self.get_resolved_attribute_value(
            'default',
            schema=schema,
            context=context,
        )
        self.pgops.add(self.create_scalar(scalar, default, schema, context))

        return schema


class RenameScalarType(
    ScalarTypeMetaCommand,
    adapts=s_scalars.RenameScalarType,
):
    pass


class RebaseScalarType(
    ScalarTypeMetaCommand,
    adapts=s_scalars.RebaseScalarType,
):
    # Actual rebase is taken care of in AlterScalarType
    pass


class AlterScalarType(ScalarTypeMetaCommand,
                      adapts=s_scalars.AlterScalarType):

    problematic_refs: Optional[tuple[
        tuple[so.Object, ...],
        dict[s_props.Property, s_types.TypeShell],
    ]]

    def _get_problematic_refs(
        self,
        schema: s_schema.Schema,
        context: sd.CommandContext,
        *,
        composite_only: bool,
    ) -> Optional[tuple[
        tuple[so.Object, ...],
        dict[s_props.Property, s_types.TypeShell],
    ]]:
        """Find problematic references to this scalar type that need handling.

        This is used to work around two irritating limitations of Postgres:
          1. That elements of enum types may not be removed or reordered
          2. That a constraint may not be added to a domain type if that
             domain type appears in a *composite* type that is used in a
             column somewhere.

        We don't want to have these limitations, and we need to do a decent
        amount of work to work around them.

        1. Find all of the affected properties. For case 2, this is any
           property whose type is a container type that contains this
           scalar. (Possibly transitively.) For case 1, the container type
           restriction is dropped.
        2. Change the type of all offending properties to an equivalent type
           that does not reference this scalar. This may require creating
           new types. (See _undo_everything.)
        3. Add the constraint.
        4. Restore the type of all offending properties. If existing data
           violates the new constraint, we will fail here. Delete any
           temporarily created types. (See _redo_everything.)
Somewhat hackily, _undo_everything and _redo_everything operate by creating new schema delta command objects, and adapting and applying them. This is the most straightforward way to perform the high-level operations needed here. I've kept this code in pgsql/delta instead of trying to put in schema/delta because it is pretty aggressively an irritating pgsql implementation detail and because I didn't want it to have to interact with ordering ever. This function finds all of the relevant properties and returns a list of them along with the appropriate replacement type. In case 1, it also finds other referencing objects which need to be deleted and then recreated. """ seen_props = set() seen_other: set[so.Object] = set() typ = self.scls typs = [typ] # Do a worklist driven search for properties that refer to this scalar # through a collection type. We search backwards starting from # referring collection types or from all refs, depending on # composite_only. scls_type = s_types.Collection if composite_only else None wl = list(schema.get_referrers(typ, scls_type=scls_type)) while wl: obj = wl.pop() if isinstance(obj, s_props.Property): seen_props.add(obj) elif isinstance(obj, s_scalars.ScalarType) and not composite_only: wl.extend(schema.get_referrers(obj)) seen_other.add(obj) typs.append(obj) elif isinstance(obj, s_types.Collection): wl.extend(schema.get_referrers(obj)) seen_other.add(obj) elif isinstance(obj, s_funcs.Parameter) and not composite_only: wl.extend(schema.get_referrers(obj)) seen_other.add(obj) elif isinstance(obj, s_funcs.Function) and not composite_only: wl.extend(schema.get_referrers(obj)) seen_other.add(obj) elif isinstance(obj, s_constr.Constraint) and not composite_only: seen_other.add(obj) elif isinstance(obj, s_indexes.Index) and not composite_only: seen_other.add(obj) if not seen_props and not seen_other: return None props = {} if seen_props: type_substs: dict[sn.Name, s_types.TypeShell[s_types.Type]] = {} for typ in typs: # Find a concrete 
ancestor to substitute in. if typ.is_enum(schema): ancestor = schema.get( sn.QualName('std', 'str'), type=s_types.Type) else: for ancestor in typ.get_ancestors(schema).objects(schema): if not ancestor.get_abstract(schema): break else: raise AssertionError( "can't find concrete base for scalar") type_substs[typ.get_name(schema)] = ancestor.as_shell(schema) props = { prop: s_utils.type_shell_multi_substitute( type_substs, not_none(prop.get_target(schema)).as_shell(schema), schema, ) for prop in seen_props } objs = sd.sort_by_cross_refs(schema, seen_props | seen_other) return objs, props def _undo_everything( self, schema: s_schema.Schema, context: sd.CommandContext, objs: tuple[so.Object, ...], props: dict[s_props.Property, s_types.TypeShell], ) -> s_schema.Schema: """Rewrite the type of everything that uses this scalar dangerously. See _get_problematic_refs above for details. """ # First we need to strip out any default value that might reference # one of the functions we are going to delete. # We also create any new types, in this pass. cmd = sd.DeltaRoot() for prop, new_typ in props.items(): try: cmd.add(new_typ.as_create_delta(schema)) except NotImplementedError as e: if e.args == ('unsupported typeshell',): pass else: raise if prop.get_default(schema): delta_alter, cmd_alter, _alter_context = prop.init_delta_branch( schema, context, cmdtype=sd.AlterObject) cmd_alter.set_attribute_value('default', None) cmd.add(delta_alter) cmd.apply(schema, context) acmd = CommandMeta.adapt(cmd) schema = acmd.apply(schema, context) self.pgops.update(acmd.get_subcommands()) # Now process all the objects in the appropriate order for obj in objs: if isinstance(obj, s_funcs.Function): # Force function deletions at the SQL level without ever # bothering to remove them from our schema. 
fc = FunctionCommand() variadic = obj.get_params(schema).find_variadic(schema) self.pgops.add( dbops.DropFunction( name=fc.get_pgname(obj, schema), args=fc.compile_args(obj, schema), has_variadic=variadic is not None, ) ) elif isinstance(obj, s_constr.Constraint): self.pgops.add(ConstraintCommand.delete_constraint(obj, schema)) elif isinstance(obj, s_indexes.Index): self.pgops.add(DeleteIndex.delete_index(obj, schema, context)) elif isinstance(obj, s_types.Tuple): self.pgops.add(dbops.DropCompositeType( name=common.get_backend_name(schema, obj, catenate=False), )) elif isinstance(obj, s_scalars.ScalarType): self.pgops.add(DeleteScalarType.delete_scalar(obj, schema)) elif isinstance(obj, s_props.Property): new_typ = props[obj] delta_alter, cmd_alter, _alter_context = obj.init_delta_branch( schema, context, cmdtype=sd.AlterObject) cmd_alter.set_attribute_value('target', new_typ) cmd_alter.set_attribute_value('default', None) delta_alter.apply(schema, context) acmd2 = CommandMeta.adapt(delta_alter) schema = acmd2.apply(schema, context) self.pgops.add(acmd2) return schema def _redo_everything( self, schema: s_schema.Schema, orig_schema: s_schema.Schema, context: sd.CommandContext, objs: tuple[so.Object, ...], props: dict[s_props.Property, s_types.TypeShell], ) -> s_schema.Schema: """Restore the type of everything that uses this scalar dangerously. See _get_problematic_refs above for details. 
""" for obj in reversed(objs): if isinstance(obj, s_funcs.Function): # Super hackily recreate the functions fc = CreateFunction( classname=obj.get_name(schema)) # type: ignore for f in ('language', 'params', 'return_type'): fc.set_attribute_value(f, obj.get_field_value(schema, f)) self.pgops.update(fc.make_op(obj, schema, context)) elif isinstance(obj, s_constr.Constraint): self.pgops.add(ConstraintCommand.create_constraint( self, obj, schema, context, create_triggers_if_needed=False, )) elif isinstance(obj, s_indexes.Index): self.pgops.add( CreateIndex.create_index(obj, orig_schema, context)) elif isinstance(obj, s_types.Tuple): self.pgops.add(CreateTuple.create_tuple(obj, orig_schema)) elif isinstance(obj, s_scalars.ScalarType): self.pgops.add( CreateScalarType.create_scalar( obj, obj.get_default(schema), orig_schema, context ) ) elif isinstance(obj, s_props.Property): new_typ = props[obj] delta_alter, cmd_alter, _ = obj.init_delta_branch( schema, context, cmdtype=sd.AlterObject) cmd_alter.set_attribute_value( 'target', obj.get_target(orig_schema)) delta_alter.apply(schema, context) acmd = CommandMeta.adapt(delta_alter) schema = acmd.apply(schema, context) self.pgops.add(acmd) # Restore defaults and prune newly created types cmd = sd.DeltaRoot() for prop, new_typ in props.items(): rnew_typ = new_typ.resolve(schema) if delete := rnew_typ.as_type_delete_if_unused(schema): cmd.add_caused(delete) delta_alter, cmd_alter, _ = prop.init_delta_branch( schema, context, cmdtype=sd.AlterObject) cmd_alter.set_attribute_value( 'default', prop.get_default(orig_schema)) cmd.add(delta_alter) # do an apply of the schema-level command to force it to canonicalize, # which prunes out duplicate deletions # # HACK: Clear out the context's stack so that # context.canonical is false while doing this. 
stack, context.stack = context.stack, [] cmd.apply(schema, context) context.stack = stack for sub in cmd.get_subcommands(): acmd2 = CommandMeta.adapt(sub) schema = acmd2.apply(schema, context) self.pgops.add(acmd2) return schema def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: orig_schema = schema schema = super()._alter_begin(schema, context) new_scalar = self.scls has_create_constraint = bool( list(self.get_subcommands(type=s_constr.CreateConstraint))) has_rebase = bool( list(self.get_subcommands(type=s_scalars.RebaseScalarType))) old_enum_values: Sequence[str] = ( new_scalar.get_enum_values(orig_schema) or []) new_enum_values: Sequence[str] if has_rebase and old_enum_values: # Ugly hack alert: we need to do this "lookahead" rebase # apply to get the list of new enum values to decide # whether special handling is needed _before_ the actual # _alter_innards() takes place, because we are also handling # domain constraints here. TODO: a cleaner way to handle this # would be to move this logic into actual subcommands # (RebaseScalarType and CreateConstraint). rebased_schema = super()._alter_innards(schema, context) new_enum_values = new_scalar.get_enum_values(rebased_schema) or [] else: new_enum_values = old_enum_values # If values were deleted or reordered, we need to drop the enum # and recreate it. 
needs_recreate = ( old_enum_values != new_enum_values and old_enum_values != new_enum_values[:len(old_enum_values)]) self.problematic_refs = None if needs_recreate or has_create_constraint: self.problematic_refs = self._get_problematic_refs( schema, context, composite_only=not needs_recreate) if self.problematic_refs: objs, props = self.problematic_refs schema = self._undo_everything(schema, context, objs, props) if new_enum_values: type_name = common.get_backend_name( schema, new_scalar, catenate=False) if needs_recreate: self.pgops.add( DeleteScalarType.delete_scalar(new_scalar, orig_schema) ) self.pgops.add( CreateScalarType.create_enum( new_scalar, new_enum_values, schema, context ) ) elif old_enum_values != new_enum_values: old_idx = 0 old_enum_values = list(old_enum_values) for v in new_enum_values: if old_idx >= len(old_enum_values): self.pgops.add( dbops.AlterEnumAddValue( type_name, v, ) ) elif v != old_enum_values[old_idx]: self.pgops.add( dbops.AlterEnumAddValue( type_name, v, before=old_enum_values[old_idx], ) ) old_enum_values.insert(old_idx, v) else: old_idx += 1 default_delta = self.get_resolved_attribute_value( 'default', schema=schema, context=context, ) if default_delta: if (default_delta is None or isinstance(default_delta, s_expr.Expression)): new_default = None else: new_default = default_delta domain_name = common.get_backend_name( schema, new_scalar, catenate=False) adad = dbops.AlterDomainAlterDefault( name=domain_name, default=new_default) self.pgops.add(adad) return schema def _alter_finalize( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._alter_finalize(schema, context) if self.problematic_refs: objs, props = self.problematic_refs schema = self._redo_everything( schema, context.current().original_schema, context, objs, props, ) return schema def drop_dependant_func_cache(pg_type: tuple[str, ...]) -> dbops.PLQuery: if len(pg_type) == 1: types_cte = f''' SELECT pt.oid AS oid FROM pg_type 
pt WHERE pt.typname = {ql(pg_type[0])} OR pt.typname = {ql('_' + pg_type[0])}\ ''' else: types_cte = f''' SELECT pt.oid AS oid FROM pg_type pt JOIN pg_namespace pn ON pt.typnamespace = pn.oid WHERE pn.nspname = {ql(pg_type[0])} AND ( pt.typname = {ql(pg_type[1])} OR pt.typname = {ql('_' + pg_type[1])} )\ ''' drop_func_cache_sql = textwrap.dedent(f''' DECLARE qc RECORD; BEGIN FOR qc IN WITH types AS ({types_cte} ), class AS ( SELECT pc.oid AS oid FROM pg_class pc JOIN pg_namespace pn ON pc.relnamespace = pn.oid WHERE pn.nspname = 'pg_catalog' AND pc.relname = 'pg_type' ) SELECT substring(p.proname FROM 6)::uuid AS key FROM pg_proc p JOIN pg_depend d ON d.objid = p.oid JOIN types t ON d.refobjid = t.oid JOIN class c ON d.refclassid = c.oid WHERE p.proname LIKE '__qh_%' LOOP PERFORM edgedb_VER."_evict_query_cache"(qc.key); END LOOP; END; ''') return dbops.PLQuery(drop_func_cache_sql) class DeleteScalarType(ScalarTypeMetaCommand, adapts=s_scalars.DeleteScalarType): @classmethod def delete_scalar( cls, scalar: s_scalars.ScalarType, orig_schema: s_schema.Schema ) -> dbops.Command: ops = dbops.CommandGroup() # The custom scalar types are sometimes included in the function # signatures of query cache functions under QueryCacheMode.PgFunc. # We need to find such functions through pg_depend and evict the cache # before dropping the custom scalar type. 
pg_type = types.pg_type_from_scalar(orig_schema, scalar) ops.add_command(drop_dependant_func_cache(pg_type)) old_domain_name = common.get_backend_name( orig_schema, scalar, catenate=False) cond: dbops.Condition if scalar.is_concrete_enum(orig_schema): old_enum_name = old_domain_name cond = dbops.EnumExists(old_enum_name) cast_func_name = common.get_backend_name( orig_schema, scalar, False, aspect="enum-cast-from-str" ) cast_func = dbops.DropFunction( name=cast_func_name, args=[("value", ("anyelement",))], conditions=[cond], ) ops.add_command(cast_func) uncast_func_name = common.get_backend_name( orig_schema, scalar, False, aspect="enum-cast-into-str" ) uncast_func = dbops.DropFunction( name=uncast_func_name, args=[("value", ("anyelement",))], conditions=[cond], ) ops.add_command(uncast_func) enum = dbops.DropEnum(name=old_enum_name, conditions=[cond]) ops.add_command(enum) else: cond = dbops.DomainExists(old_domain_name) domain = dbops.DropDomain(name=old_domain_name, conditions=[cond]) ops.add_command(domain) return ops def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: orig_schema = schema schema = super().apply(schema, context) scalar = self.scls link = None if context: link = context.get(s_links.LinkCommandContext) if link: assert isinstance(link.op, MetaCommand) ops = link.op.pgops else: ops = self.pgops ops.add(self.delete_scalar(scalar, orig_schema)) if self.is_sequence(orig_schema, scalar): seq_name = common.get_backend_name( orig_schema, scalar, catenate=False, aspect='sequence') self.pgops.add(dbops.DropSequence(name=seq_name)) return schema if TYPE_CHECKING: # In pgsql/delta, a "composite object" is anything that can have a table. # That is, an object type, a link, or a property. # We represent it as Source | Pointer, since many call sites are generic # over one of those things. 
CompositeObject = s_sources.Source | s_pointers.Pointer PostCommand = ( dbops.Command | Callable[ [s_schema.Schema, sd.CommandContext], Optional[dbops.Command] ] ) class CompositeMetaCommand(MetaCommand): constraint_trigger_updates: set[uuid.UUID] def __init__(self, **kwargs): super().__init__(**kwargs) self.table_name = None self._multicommands = {} self.update_search_indexes = None self.constraint_trigger_updates = set() def schedule_trampoline(self, obj, schema, context): delta = context.get(sd.DeltaRootContext).op create_trampolines = delta.create_trampolines create_trampolines.table_targets.append(obj) def _get_multicommand( self, context, cmdtype, object_name, *, force_new=False, manual=False, cmdkwargs=None, ): if cmdkwargs is None: cmdkwargs = {} key = (object_name, frozenset(cmdkwargs.items())) try: typecommands = self._multicommands[cmdtype] except KeyError: typecommands = self._multicommands[cmdtype] = {} commands = typecommands.get(key) if commands is None or force_new or manual: command = cmdtype(object_name, **cmdkwargs) if not manual: try: commands = typecommands[key] except KeyError: commands = typecommands[key] = [] commands.append(command) else: command = commands[-1] return command def _attach_multicommand(self, context, cmdtype): try: typecommands = self._multicommands[cmdtype] except KeyError: return else: commands = list( itertools.chain.from_iterable(typecommands.values())) if commands: self.pgops.update(commands) def get_alter_table( self, schema, context, force_new=False, contained=False, manual=False, table_name=None, ): tabname = table_name if table_name else self.table_name # XXX: should this be arranged to always have been done? 
if not tabname: ctx = context.get(self.__class__) assert ctx tabname = self._get_table_name(ctx.scls, schema) if table_name is None: self.table_name = tabname return self._get_multicommand( context, dbops.AlterTable, tabname, force_new=force_new, manual=manual, cmdkwargs={'contained': contained}) def attach_alter_table(self, context): self._attach_multicommand(context, dbops.AlterTable) @staticmethod def _get_table_name(obj, schema) -> tuple[str, str]: is_internal_view = irtyputils.is_cfg_view(obj, schema) aspect = 'dummy' if is_internal_view else None return common.get_backend_name( schema, obj, catenate=False, aspect=aspect) @classmethod def _refresh_fake_cfg_view_cmd( cls, obj: CompositeObject, schema: s_schema.Schema, context: sd.CommandContext, ) -> dbops.Command: if not types.has_table(obj, schema): return dbops.CommandGroup() # Objects in sys and cfg are actually implemented by views # that are defined in metaschema. The metaschema scripts run # *after* the schema is instantiated, though, and we need to # populate something *now* that can go into inhviews. # # The way we do this is by creating an actual concrete table # with the suffix "_dummy" and then creating a view with the # expected table name that simply `select *`s from the dummy # table. Pointer creation on the type gets routed to the dummy # table, so it has the right columns. Since the view `select # *`s from the table, it also has the right columns, and can # go into all of the inheritance views without any trouble. # # We refresh the fake config view before creating/updating # inhviews associated with the object, since that corresponds # with when it actually needs to happen by. # # Then, when we run the metaschema script, it simply swaps out # this hacky view for the real one and everything works out fine. 
orig_name = common.get_backend_name( schema, obj, catenate=False, ) dummy_name = cls._get_table_name(obj, schema) query = f''' SELECT * FROM {q(*dummy_name)} ''' view = dbops.View(name=orig_name, query=query) return dbops.CreateView(view, or_replace=True) def update_if_cfg_view( self, schema: s_schema.Schema, context: sd.CommandContext, obj: CompositeObject, ): if irtyputils.is_cfg_view(obj, schema) and not context.in_deletion(): self.pgops.add( self._refresh_fake_cfg_view_cmd(obj, schema, context)) def update_source_if_cfg_view( self, schema: s_schema.Schema, context: sd.CommandContext, ptr: s_pointers.Pointer, ) -> None: if src := ptr.get_source(schema): assert isinstance(src, s_sources.Source) self.update_if_cfg_view(schema, context, src) @classmethod def get_source_and_pointer_ctx(cls, schema, context): if context: objtype = context.get(s_objtypes.ObjectTypeCommandContext) link = context.get(s_links.LinkCommandContext) else: objtype = link = None if objtype: source, pointer = objtype, link elif link: property = context.get(s_props.PropertyCommandContext) source, pointer = link, property else: source = pointer = None return source, pointer @classmethod def create_type_trampoline( cls, schema: s_schema.Schema, obj: CompositeObject, aspect: str='table', ) -> Optional[trampoline.TrampolineView]: versioned_name = common.get_backend_name( schema, obj, aspect=aspect, catenate=False ) trampolined_name = common.get_backend_name( schema, obj, aspect=aspect, catenate=False, versioned=False ) if versioned_name != trampolined_name: return trampoline.make_table_trampoline(versioned_name) else: return None def apply_constraint_trigger_updates( self, schema: s_schema.Schema, ) -> None: for constraint_id in self.constraint_trigger_updates: constraint = ( schema.get_by_id(constraint_id, type=s_constr.Constraint) if schema.has_object(constraint_id) else None ) if not constraint: continue if not ConstraintCommand.constraint_is_effective( schema, constraint ): continue subject = 
constraint.get_subject(schema) bconstr = schemamech.compile_constraint( subject, constraint, schema, None ) self.pgops.add(bconstr.update_trigger_ops()) class IndexCommand(MetaCommand): pass def get_index_compile_options( index: s_indexes.Index, schema: s_schema.Schema, modaliases: Mapping[Optional[str], str], schema_object_context: Optional[type[so.Object_T]], ) -> qlcompiler.CompilerOptions: subject = index.get_subject(schema) assert isinstance(subject, (s_types.Type, s_pointers.Pointer)) return qlcompiler.CompilerOptions( modaliases=modaliases, schema_object_context=schema_object_context, anchors={'__subject__': subject}, path_prefix_anchor='__subject__', singletons=[subject], apply_query_rewrites=False, ) def get_reindex_sql( obj: s_objtypes.ObjectType, restore_desc: sertypes.ShapeDesc, schema: s_schema.Schema, ) -> Optional[str]: """Generate SQL statement that repopulates the index after a restore. Currently this only applies to FTS indexes, and it only fires if __fts_document__ is not in the dump (which it wasn't prior to 5.0). AI index columns might also be missing if they were made with a 5.0rc1 dump, but the indexer will pick them up without our intervention. 
""" (fts_index, _) = s_indexes.get_effective_object_index( schema, obj, sn.QualName("std::fts", "index") ) if fts_index and '__fts_document__' not in restore_desc.fields: options = get_index_compile_options(fts_index, schema, {}, None) cmd = deltafts.update_fts_document(fts_index, options, schema) return cmd.code() return None class CreateIndex(IndexCommand, adapts=s_indexes.CreateIndex): @classmethod def create_index( cls, index: s_indexes.Index, schema: s_schema.Schema, context: sd.CommandContext, ) -> dbops.Command: from .compiler import astutils options = get_index_compile_options( index, schema, context.modaliases, cls.get_schema_metaclass() ) index_sexpr: Optional[s_expr.Expression] = index.get_expr(schema) assert index_sexpr index_expr = index_sexpr.ensure_compiled( schema=schema, options=options, context=None, ) ir = index_expr.irast except_expr = index.get_except_expr(schema) if except_expr: except_expr = except_expr.ensure_compiled( schema=schema, options=options, context=None, ) assert except_expr.irast except_res = compiler.compile_ir_to_sql_tree( except_expr.irast.expr, singleton_mode=True) except_src = codegen.generate_source(except_res.ast) predicate_src = f'({except_src}) is not true' else: predicate_src = None sql_kwarg_exprs = dict() # Get the name of the root index that this index implements orig_name: sn.Name = sn.shortname_from_fullname(index.get_name(schema)) root_name: sn.Name root_code: str | None if orig_name == s_indexes.DEFAULT_INDEX: root_name = orig_name root_code = DEFAULT_INDEX_CODE else: root = index.get_root(schema) root_name = root.get_name(schema) root_code = root.get_code(schema) kwargs = index.get_concrete_kwargs(schema) for name, expr in kwargs.items(): kw_ir = expr.assert_compiled().irast kw_sql_res = compiler.compile_ir_to_sql_tree( kw_ir.expr, singleton_mode=True) kw_sql_tree = kw_sql_res.ast # HACK: the compiled SQL is expected to have some unnecessary # casts, strip them as they mess with the requirement that # index 
expressions are IMMUTABLE (also indexes expect the # usage of literals and will do their own implicit casts). if isinstance(kw_sql_tree, pgast.TypeCast): kw_sql_tree = kw_sql_tree.arg sql = codegen.generate_source(kw_sql_tree) sql_kwarg_exprs[name] = sql # FTS if root_name == sn.QualName('std::fts', 'index'): return deltafts.create_fts_index( index, ir.expr, predicate_src, sql_kwarg_exprs, options, schema, context, ) elif root_name == sn.QualName('ext::ai', 'index'): return delta_ext_ai.create_ext_ai_index( index, predicate_src, sql_kwarg_exprs, options, schema, context, ) if root_code is None: raise AssertionError(f'index {root_name} is missing the code') sql_res = compiler.compile_ir_to_sql_tree(ir.expr, singleton_mode=True) exprs = astutils.maybe_unpack_row(sql_res.ast) if len(exprs) == 0: raise errors.SchemaDefinitionError( f'cannot index empty tuples using {root_name}' ) subject = index.get_subject(schema) assert subject table_name = common.get_backend_name(schema, subject, catenate=False) module_name = index.get_name(schema).module index_name = common.get_index_backend_name( index.id, module_name, catenate=False) sql_exprs = [codegen.generate_source(e) for e in exprs] pg_index = dbops.Index( name=index_name[1], table_name=table_name, # type: ignore exprs=sql_exprs, unique=False, inherit=True, predicate=predicate_src, metadata={ 'schemaname': str(index.get_name(schema)), 'code': root_code, 'kwargs': sql_kwarg_exprs, } ) concurrently = index.get_build_concurrently(schema) return dbops.CreateIndex( pg_index, concurrently=concurrently, builtin_conditional=concurrently, ) # N.B: This is in _create_finalize instead of _create_innards # because when trying to do repair_schema() to repair issue #9033, # there will be generated CreateIndexes where the annotation # creation is in `caused`, not regular subcommands... 
def _create_finalize( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._create_finalize(schema, context) index = self.scls if index.get_abstract(schema): # Don't do anything for abstract indexes return schema if index.get_build_concurrently(schema): return schema with errors.ensure_span(self.span): self.pgops.add(self.create_index(index, schema, context)) return schema # mypy claims that _cmd_from_ast in IndexCommand is incompatible with # that in RenameObject. class RenameIndex(IndexCommand, adapts=s_indexes.RenameIndex): # type: ignore pass class AlterIndexOwned(IndexCommand, adapts=s_indexes.AlterIndexOwned): pass class AlterIndex(IndexCommand, adapts=s_indexes.AlterIndex): pass class DeleteIndex(IndexCommand, adapts=s_indexes.DeleteIndex): @classmethod def delete_index( cls, index: s_indexes.Index, schema: s_schema.Schema, context: sd.CommandContext, ): subject = index.get_subject(schema) assert subject table_name = common.get_backend_name( schema, subject, catenate=False) module_name = index.get_name(schema).module orig_idx_name = common.get_index_backend_name( index.id, module_name, catenate=False) pg_index = dbops.Index( name=orig_idx_name[1], table_name=table_name, inherit=True) index_exists = dbops.IndexExists( (table_name[0], pg_index.name_in_catalog) ) return dbops.DropIndex(pg_index, conditions=(index_exists,)) def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: orig_schema = schema schema = super().apply(schema, context) index = self.scls if index.get_abstract(orig_schema): # Don't do anything for abstract indexes return schema source: Optional[ sd.CommandContextToken[s_sources.SourceCommand[s_sources.Source]]] # XXX: I think to make these work, the type vars in the Commands # would need to be covariant. 
source = context.get(s_links.LinkCommandContext) # type: ignore if not source: source = context.get( s_objtypes.ObjectTypeCommandContext) # type: ignore assert source if not isinstance(source.op, sd.DeleteObject): # We should not drop indexes when the host is being dropped since # the indexes are dropped automatically in this case. drop_index = self.delete_index(index, orig_schema, context) else: drop_index = dbops.NoOpCommand() # FTS if s_indexes.is_fts_index(orig_schema, index): # compile commands for index drop options = get_index_compile_options( index, orig_schema, context.modaliases, self.get_schema_metaclass() ) self.pgops.add(deltafts.delete_fts_index( index, drop_index, options, schema, orig_schema, context )) # ext::ai::index elif s_indexes.is_ext_ai_index(orig_schema, index): # compile commands for index drop options = get_index_compile_options( index, orig_schema, context.modaliases, self.get_schema_metaclass() ) drop_support_ops, drop_col_ops = delta_ext_ai.delete_ext_ai_index( index, drop_index, options, schema, orig_schema, context ) # Even though the object type table is getting dropped, we have # to drop the trigger and its function self.pgops.add(drop_support_ops) self.pgops.add(drop_col_ops) else: self.pgops.add(drop_index) return schema class RebaseIndex(IndexCommand, adapts=s_indexes.RebaseIndex): pass class IndexMatchCommand(MetaCommand): pass class CreateIndexMatch(IndexMatchCommand, adapts=s_indexes.CreateIndexMatch): # Index match is handled by the compiler and does not need explicit # representation in the backend. 
pass class DeleteIndexMatch(IndexMatchCommand, adapts=s_indexes.DeleteIndexMatch): pass class CreateUnionType( MetaCommand, adapts=s_types.CreateUnionType, metaclass=CommandMeta, ): pass class CreateIntersectionType( MetaCommand, adapts=s_types.CreateIntersectionType, metaclass=CommandMeta, ): pass class ObjectTypeMetaCommand(AliasCapableMetaCommand, CompositeMetaCommand): def schedule_endpoint_delete_action_update(self, obj, schema, context): endpoint_delete_actions = context.get( sd.DeltaRootContext).op.update_endpoint_delete_actions changed_targets = endpoint_delete_actions.changed_targets changed_targets.add((self, obj)) def _fixup_configs( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: orig_schema = context.current().original_schema eff_schema = ( orig_schema if isinstance(self, sd.DeleteObject) else schema) scls: s_objtypes.ObjectType = self.scls # type: ignore # If we are updating a config object that is *not* in cfg:: # (that is, an extension config), we need to update the config # views and specs. We *don't* do that for standard library # configs, since those need to be created after the standard # schema is in place. 
if not ( irtyputils.is_cfg_view(scls, eff_schema) and scls.get_name(eff_schema).module not in irtyputils.VIEW_MODULES ): return from edb.pgsql import metaschema new_local_spec = config.load_spec_from_schema( schema, only_exts=True, # suppress validation because we might be in an intermediate state validate=False, ) spec_json = config.spec_to_json(new_local_spec) self.pgops.add(dbops.Query(textwrap.dedent(trampoline.fixup_query(f'''\ UPDATE edgedbinstdata_VER.instdata SET json = {ql(spec_json)} WHERE key = 'configspec_ext'; ''')))) for sub in self.get_subcommands(type=s_pointers.DeletePointer): if types.has_table(sub.scls, orig_schema): self.pgops.add(dbops.DropView(common.get_backend_name( orig_schema, sub.scls, catenate=False))) if isinstance(self, sd.DeleteObject): self.pgops.add(dbops.DropView(common.get_backend_name( eff_schema, scls, catenate=False))) elif isinstance(self, sd.CreateObject): views = metaschema.get_config_type_views( eff_schema, scls, scope=None) self.pgops.update(views) # FIXME: ALTER doesn't work in meaningful ways. We'll maybe # need to fix that when we have patching configs. 
class CreateObjectType( ObjectTypeMetaCommand, adapts=s_objtypes.CreateObjectType ): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().apply(schema, context) objtype = self.scls if objtype.is_compound_type(schema) or objtype.get_is_derived(schema): return schema self.attach_alter_table(context) if self.update_search_indexes: schema = self.update_search_indexes.apply(schema, context) self.pgops.add(self.update_search_indexes) self.schedule_endpoint_delete_action_update(self.scls, schema, context) self.schedule_trampoline(self.scls, schema, context) self._fixup_configs(schema, context) return schema def _create_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._create_begin(schema, context) objtype = self.scls if objtype.is_compound_type(schema) or objtype.get_is_derived(schema): return schema new_table_name = self._get_table_name(self.scls, schema) self.table_name = new_table_name columns: list[dbops.Column] = [] objtype_table = dbops.Table(name=new_table_name, columns=columns) self.pgops.add(dbops.CreateTable(table=objtype_table)) self.pgops.add(dbops.Comment( object=objtype_table, text=str(objtype.get_verbosename(schema)), )) # Don't update ancestors yet: no pointers have been added to # the type yet, so this type won't actually be added to any # ancestor views. We'll fix up the ancestors in # _create_finalize. 
self.update_if_cfg_view(schema, context, objtype) return schema def _create_finalize(self, schema, context): schema = super()._create_finalize(schema, context) self.apply_constraint_trigger_updates(schema) return schema class RenameObjectType( ObjectTypeMetaCommand, adapts=s_objtypes.RenameObjectType, ): pass class RebaseObjectType( ObjectTypeMetaCommand, adapts=s_objtypes.RebaseObjectType ): def _alter_innards( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: if types.has_table(self.scls, schema): self.update_if_cfg_view(schema, context, self.scls) schema = super()._alter_innards(schema, context) self.schedule_endpoint_delete_action_update(self.scls, schema, context) return schema class AlterObjectType(ObjectTypeMetaCommand, adapts=s_objtypes.AlterObjectType): def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._alter_begin(schema, context) # We want to set this name up early, so children operations see it self.table_name = self._get_table_name(self.scls, schema) return schema def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: orig_schema = schema schema = super().apply(schema, context=context) objtype = self.scls self.apply_constraint_trigger_updates(schema) self._maybe_do_abstract_test(orig_schema, schema, context) if types.has_table(objtype, schema): self.attach_alter_table(context) if self.update_search_indexes: schema = self.update_search_indexes.apply(schema, context) self.pgops.add(self.update_search_indexes) self._fixup_configs(schema, context) return schema def _maybe_do_abstract_test( self, orig_schema: s_schema.Schema, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: orig_abstract = self.scls.get_abstract(orig_schema) new_abstract = self.scls.get_abstract(schema) if orig_abstract or not new_abstract: return table = q(*common.get_backend_name( schema, self.scls, catenate=False, )) vn = 
self.scls.get_verbosename(schema) check_qry = textwrap.dedent(f'''\ SELECT edgedb_VER.raise( NULL::text, 'cardinality_violation', msg => {common.quote_literal( f"may not make non-empty {vn} abstract")}, "constraint" => 'set abstract' ) FROM {table} INTO _dummy_text; ''') self.pgops.add(dbops.Query(check_qry)) class DeleteObjectType( ObjectTypeMetaCommand, adapts=s_objtypes.DeleteObjectType ): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: self.scls = objtype = schema.get( self.classname, type=s_objtypes.ObjectType) old_table_name = self._get_table_name(self.scls, schema) orig_schema = schema schema = super().apply(schema, context) self.apply_constraint_trigger_updates(schema) if types.has_table(objtype, orig_schema): self.attach_alter_table(context) self.pgops.add(dbops.DropTable(name=old_table_name)) self._fixup_configs(schema, context) return schema class SchedulePointerCardinalityUpdate(MetaCommand): pass class CancelPointerCardinalityUpdate(MetaCommand): pass class PointerMetaCommand[Pointer_T: s_pointers.Pointer]( CompositeMetaCommand, s_pointers.PointerCommand[Pointer_T], ): def get_host(self, schema, context): if context: link = context.get(s_links.LinkCommandContext) if link and isinstance(self, s_props.PropertyCommand): return link objtype = context.get(s_objtypes.ObjectTypeCommandContext) if objtype: return objtype def is_sequence_ptr(self, ptr, schema): return bool( (tgt := ptr.get_target(schema)) and tgt.issubclass(schema, schema.get('std::sequence')) ) def get_pointer_default(self, ptr, schema, context): if ptr.is_pure_computable(schema): return None # Skip id, because it shouldn't ever matter for performance # and because it wants to use the trampoline function, which # might not exist yet. 
if ptr.is_id_pointer(schema): return None if context.stdmode: return None # We only *need* to use postgres defaults for link properties # and sequence values (since we always explicitly inject it in # INSERTs anyway), but we *want* to use it whenever we can, # since it is much faster than explicitly populating the # column. default = ptr.get_default(schema) default_value = None if default is not None: default_value = schemamech.ptr_default_to_col_default( schema, ptr, default) elif self.is_sequence_ptr(ptr, schema): # TODO: replace this with a generic scalar type default # using std::nextval(). seq_name = common.quote_literal( common.get_backend_name( schema, ptr.get_target(schema), aspect='sequence')) default_value = f'nextval({seq_name}::regclass)' return default_value @classmethod def get_columns( cls, pointer, schema, default=None, sets_required=False ) -> list[dbops.Column]: ptr_stor_info = types.get_pointer_storage_info(pointer, schema=schema) col_type = common.quote_type(tuple(ptr_stor_info.column_type)) return [ dbops.Column( name=ptr_stor_info.column_name, type=col_type, required=( ( pointer.get_required(schema) and not pointer.is_pure_computable(schema) and not sets_required and not (pointer.get_default(schema) and not default) ) or ( ptr_stor_info.table_type == 'link' and not pointer.is_link_property(schema) ) ), default=default, comment=str(pointer.get_shortname(schema)), ), ] def create_table(self, ptr, schema, context): if types.has_table(ptr, schema): c = self._create_table(ptr, schema, context, conditional=True) self.pgops.add(c) self.update_if_cfg_view(schema, context, ptr) return True else: return False def _alter_pointer_cardinality( self, schema: s_schema.Schema, orig_schema: s_schema.Schema, context: sd.CommandContext, ) -> None: assert isinstance(self, s_pointers.AlterPointerUpperCardinality) ptr = self.scls ptr_stor_info = types.get_pointer_storage_info(ptr, schema=schema) old_ptr_stor_info = types.get_pointer_storage_info( ptr, 
schema=orig_schema) ptr_table = ptr_stor_info.table_type == 'link' is_lprop = ptr.is_link_property(schema) is_multi = ptr_table and not is_lprop is_required = ptr.get_required(schema) is_scalar = ptr.is_property() ref_ctx = self.get_referrer_context_or_die(context) ref_op = ref_ctx.op source_op: sd.ObjectCommand if is_multi: if isinstance(self, sd.AlterObjectFragment): source_op = self.get_parent_op(context) else: source_op = self else: source_op = ref_op assert isinstance(source_op, CompositeMetaCommand) # Ignore cardinality changes resulting from the creation of # an overloaded pointer as there is no data yet. if isinstance(source_op, sd.CreateObject): return # If the pointer has any constraints, drop them now. We'll # create them again at the end. # N.B: Since the pointer is either starting or ending as multi, # it can't have any object constraints referencing it. # TODO?: Maybe we should handle the constraint by generating # an alter in the front-end and running _alter_innards in the # middle of this function. (After creations, before deletions.) 
for constr in ptr.get_constraints(schema).objects(schema): self.pgops.add(ConstraintCommand.delete_constraint( constr, orig_schema )) assert ptr_stor_info.table_name tab = q(*ptr_stor_info.table_name) target_col = ptr_stor_info.column_name source = ptr.get_source(orig_schema) assert source src_tab = q(*common.get_backend_name( orig_schema, source, catenate=False, )) # initial extern relvar (see docs of _compile_conversion_expr) source_rel = textwrap.dedent(f'''\ SELECT * FROM {src_tab} ''') source_rel_alias = f'source_{uuidgen.uuid1mc()}' if self.conv_expr is not None: (conv_expr_ctes, _, _) = self._compile_conversion_expr( ptr, self.conv_expr, source_rel_alias, schema=schema, orig_schema=orig_schema, context=context, target_as_singleton=False, check_non_null=is_required and not is_multi ) else: if not is_multi: raise AssertionError( 'explicit conversion expression was expected' ' for multi->single transition' ) # single -> multi conv_expr_ctes = textwrap.dedent(f'''\ _conv_rel(val, id) AS ( SELECT {qi(old_ptr_stor_info.column_name)}, id FROM {qi(source_rel_alias)} ) ''') if not is_multi: # Moving from pointer table to source table. 
cols = self.get_columns(ptr, schema) # create columns alter_table = source_op.get_alter_table( schema, context, manual=True) cols_required: list[dbops.Column] = [] for col in cols: cond = dbops.ColumnExists( ptr_stor_info.table_name, column_name=col.name, ) if col.required: cols_required.append(copy(col)) col.required = False op = (dbops.AlterTableAddColumn(col), None, (cond, )) alter_table.add_operation(op) self.pgops.add(alter_table) update_qry = textwrap.dedent(f'''\ WITH "{source_rel_alias}" AS ({source_rel}), {conv_expr_ctes} UPDATE {tab} AS _update SET {qi(target_col)} = _conv_rel.val FROM _conv_rel WHERE _update.id = _conv_rel.id ''') self.pgops.add(dbops.Query(update_qry)) # set NOT NULL if cols_required: alter_table = source_op.get_alter_table( schema, context, manual=True ) for col in cols_required: op2 = dbops.AlterTableAlterColumnNull(col.name, False) alter_table.add_operation(op2) self.pgops.add(alter_table) # A link might still own a table if it has properties. if not types.has_table(ptr, schema): otabname = common.get_backend_name( orig_schema, ptr, catenate=False) condition = dbops.TableExists(name=otabname) dt = dbops.DropTable(name=otabname, conditions=[condition]) self.pgops.add(dt) self.update_source_if_cfg_view( schema, context, ptr, ) else: # Moving from source table to pointer table. 
self.create_table(ptr, schema, context) source = ptr.get_source(orig_schema) assert source src_tab = q(*common.get_backend_name( orig_schema, source, catenate=False, )) update_qry = textwrap.dedent(f'''\ WITH "{source_rel_alias}" AS ({source_rel}), {conv_expr_ctes} INSERT INTO {tab} (source, target) (SELECT id, val FROM _conv_rel WHERE _conv_rel.val IS NOT NULL) ''') if not is_scalar: update_qry += 'ON CONFLICT (source, target) DO NOTHING' self.pgops.add(dbops.Query(update_qry)) assert isinstance(ref_op.scls, s_sources.Source) self.update_if_cfg_view(schema, context, ref_op.scls) ref_op = self.get_referrer_context_or_die(context).op assert isinstance(ref_op, CompositeMetaCommand) alter_table = ref_op.get_alter_table( schema, context, manual=True) col = dbops.Column( name=old_ptr_stor_info.column_name, type=common.qname(*old_ptr_stor_info.column_type), ) alter_table.add_operation(dbops.AlterTableDropColumn(col)) self.pgops.add(alter_table) for constr in ptr.get_constraints(schema).objects(schema): self.pgops.add(ConstraintCommand.create_constraint( self, constr, schema, context )) def _alter_pointer_optionality( self, schema: s_schema.Schema, orig_schema: s_schema.Schema, context: sd.CommandContext, *, fill_expr: Optional[s_expr.Expression], is_default: bool=False, ) -> None: new_required = self.scls.get_required(schema) ptr = self.scls ptr_stor_info = types.get_pointer_storage_info(ptr, schema=schema) ptr_table = ptr_stor_info.table_type == 'link' is_lprop = ptr.is_link_property(schema) is_multi = ptr_table and not is_lprop is_required = ptr.get_required(schema) source_ctx = self.get_referrer_context_or_die(context) source_op = source_ctx.op assert isinstance(source_op, CompositeMetaCommand) alter_table = None if not ptr_table or is_lprop: alter_table = source_op.get_alter_table( schema, context, manual=True, ) alter_table.add_operation( dbops.AlterTableAlterColumnNull( column_name=ptr_stor_info.column_name, null=not new_required, ) ) # Ignore optionality changes 
resulting from the creation of # an overloaded pointer as there is no data yet. if isinstance(source_op, sd.CreateObject): if alter_table: self.pgops.add(alter_table) return ops = dbops.CommandGroup() # For multi pointers, if there is no fill expression, we # synthesize a bogus one so that an error will trip if there # are any objects with empty values. if fill_expr is None and is_multi and is_required: if ( ptr.get_cardinality(schema).is_multi() and fill_expr is None and (target := ptr.get_target(schema)) ): fill_ast = ql_ast.TypeCast( expr=ql_ast.Set(elements=[]), type=s_utils.typeref_to_ast(schema, target), ) fill_expr = s_expr.Expression.from_ast( qltree=fill_ast, schema=schema ) if fill_expr is not None: assert ptr_stor_info.table_name tab = q(*ptr_stor_info.table_name) target_col = ptr_stor_info.column_name source = ptr.get_source(orig_schema) assert source src_tab = q(*common.get_backend_name( orig_schema, source, catenate=False, )) if not is_multi: # For singleton pointers we simply update the # requisite column of the host source in every # row where it is NULL. source_rel = textwrap.dedent(f'''\ SELECT * FROM {tab} WHERE {qi(target_col)} IS NULL ''') else: # For multi pointers we have to INSERT the # result of USING into the link table for # every source object that has _no entries_ # in said link table. source_rel = textwrap.dedent(f'''\ SELECT * FROM {src_tab} WHERE id NOT IN (SELECT source FROM {tab}) ''') source_rel_alias = f'source_{uuidgen.uuid1mc()}' (conv_expr_ctes, conv_expr, _) = self._compile_conversion_expr( ptr, fill_expr, source_rel_alias, schema=schema, orig_schema=orig_schema, context=context, check_non_null=is_required and not is_multi, allow_globals=is_default, produce_ctes=not is_lprop, ) if is_lprop: # The produce_ctes=True flow really wants to key # everything based on id, which doesn't work for link # properties. 
If produce_ctes=True was able to use # ctid or (source, target) for keying, we could use # that here, but for now we will just use the # old-school version. See #5050, which would require # that in order to support DML in cast expressions. assert conv_expr update_with = ( f'WITH {conv_expr_ctes}' if conv_expr_ctes else '' ) update_qry = f''' {update_with} UPDATE {tab} AS {qi(source_rel_alias)} SET {qi(target_col)} = ({conv_expr}) WHERE {qi(target_col)} IS NULL ''' self.pgops.add(dbops.Query(update_qry)) elif not is_multi: update_qry = textwrap.dedent(f'''\ WITH "{source_rel_alias}" AS ({source_rel}), {conv_expr_ctes} UPDATE {tab} AS _update SET {qi(target_col)} = _conv_rel.val FROM _conv_rel WHERE _update.id = _conv_rel.id ''') ops.add_command(dbops.Query(update_qry)) else: update_qry = textwrap.dedent(f'''\ WITH "{source_rel_alias}" AS ({source_rel}), {conv_expr_ctes} INSERT INTO {tab} (source, target) (SELECT id, val FROM _conv_rel WHERE val IS NOT NULL) ''') ops.add_command(dbops.Query(update_qry)) if is_required: check_qry = textwrap.dedent(f'''\ SELECT edgedb_VER.raise( NULL::text, 'not_null_violation', msg => 'missing value for required property', detail => '{{"object_id": "' || id || '"}}', "column" => {ql(str(ptr.id))} ) FROM {src_tab} WHERE id != ALL (SELECT source FROM {tab}) LIMIT 1 INTO _dummy_text; ''') ops.add_command(dbops.Query(check_qry)) if alter_table: ops.add_command(alter_table) self.pgops.add(ops) def _drop_constraints(self, pointer, schema, context): # We need to be able to drop all the constraints referencing a # pointer before modifying its type, and then recreate them # once the change is done. # We look at all referrers to the pointer (and not just the # constraints directly on the pointer) because we want to # pick up object constraints that reference it as well. 
for cnstr in schema.get_referrers( pointer, scls_type=s_constr.Constraint ): self.pgops.add(ConstraintCommand.delete_constraint(cnstr, schema)) def _recreate_constraints(self, pointer, schema, context): for cnstr in schema.get_referrers( pointer, scls_type=s_constr.Constraint ): self.pgops.add(ConstraintCommand.create_constraint( self, cnstr, schema, context, )) def _alter_pointer_type(self, pointer, schema, orig_schema, context): old_ptr_stor_info = types.get_pointer_storage_info( pointer, schema=orig_schema) new_target = pointer.get_target(schema) ptr_table = old_ptr_stor_info.table_type == 'link' is_link = isinstance(pointer, s_links.Link) is_lprop = pointer.is_link_property(schema) is_multi = ptr_table and not is_lprop is_required = pointer.get_required(schema) changing_col_type = not is_link source_ctx = self.get_referrer_context_or_die(context) ptr_op = self.get_parent_op(context) if is_multi: source_op = ptr_op else: source_op = source_ctx.op # Ignore type narrowing resulting from a creation of a subtype # as there isn't any data in the link yet. if is_link and isinstance(source_ctx.op, sd.CreateObject): return new_target = pointer.get_target(schema) orig_target = pointer.get_target(orig_schema) new_type = types.pg_type_from_object( schema, new_target, persistent_tuples=True) source = source_op.scls cast_expr = self.cast_expr # For links, when the new type is a supertype of the old, no # SQL-level changes are necessary, unless an explicit conversion # expression was specified. if ( is_link and cast_expr is None and orig_target.issubclass(orig_schema, new_target) ): return # We actually have work to do, so drop any constraints we have self._drop_constraints(pointer, schema, context) if cast_expr is None and not is_link: # A lack of an explicit EdgeQL conversion expression means # that the new type is assignment-castable from the old type # in the EdgeDB schema. 
BUT, it would not necessarily be # assignment-castable in Postgres, especially if the types are # compound. Thus, generate an explicit cast expression. pname = pointer.get_shortname(schema).name cast_expr = s_expr.Expression.from_ast( ql_ast.TypeCast( expr=ql_ast.Path( partial=True, steps=[ ql_ast.Ptr( name=pname, type='property' if is_lprop else None, ), ], ), type=s_utils.typeref_to_ast(schema, new_target), ), schema=orig_schema, ) tab = q(*old_ptr_stor_info.table_name) target_col = old_ptr_stor_info.column_name aux_ptr_table = None aux_ptr_col = None source_rel_alias = f'source_{uuidgen.uuid1mc()}' # There are two major possibilities for the USING clause: # 1) trivial case, where the USING clause refers only to the # columns of the source table, in which case we simply compile that # into an equivalent SQL USING clause, and 2) complex case, which # supports arbitrary queries, but requires a temporary column, # which is populated with the transition query and then used as the # source for the SQL USING clause.
(cast_expr_ctes, cast_expr_sql, expr_is_nullable) = ( self._compile_conversion_expr( pointer, cast_expr, source_rel_alias, schema=schema, orig_schema=orig_schema, context=context, check_non_null=is_required and not is_multi, produce_ctes=False, ) ) assert cast_expr_sql is not None need_temp_col = ( (is_multi and expr_is_nullable) or changing_col_type ) if is_link: old_lb_ptr_stor_info = types.get_pointer_storage_info( pointer, link_bias=True, schema=orig_schema) if ( old_lb_ptr_stor_info is not None and old_lb_ptr_stor_info.table_type == 'link' ): aux_ptr_table = old_lb_ptr_stor_info.table_name aux_ptr_col = old_lb_ptr_stor_info.column_name if need_temp_col: alter_table = source_op.get_alter_table( schema, context, force_new=True, manual=True) temp_column = dbops.Column( name=f'??{pointer.id}_{common.get_unique_random_name()}', type=qt(new_type), ) alter_table.add_operation( dbops.AlterTableAddColumn(temp_column)) self.pgops.add(alter_table) target_col = temp_column.name update_with = f'WITH {cast_expr_ctes}' if cast_expr_ctes else '' update_qry = f''' {update_with} UPDATE {tab} AS {qi(source_rel_alias)} SET {qi(target_col)} = ({cast_expr_sql}) ''' self.pgops.add(dbops.Query(update_qry)) trivial_cast_expr = qi(target_col) if changing_col_type or need_temp_col: alter_table = source_op.get_alter_table( schema, context, force_new=True, manual=True) if is_multi: # Remove all rows where the conversion expression produced NULLs. 
col = qi(target_col) if pointer.get_required(schema): clean_nulls = dbops.Query(textwrap.dedent(f'''\ WITH d AS ( DELETE FROM {tab} WHERE {col} IS NULL RETURNING source ) SELECT edgedb_VER.raise( NULL::text, 'not_null_violation', msg => 'missing value for required property', detail => '{{"object_id": "' || l.source || '"}}', "column" => {ql(str(pointer.id))} ) FROM {tab} AS l WHERE l.source IN (SELECT source FROM d) AND True = ALL ( SELECT {col} IS NULL FROM {tab} AS l2 WHERE l2.source = l.source ) LIMIT 1 INTO _dummy_text; ''')) else: clean_nulls = dbops.Query(textwrap.dedent(f'''\ DELETE FROM {tab} WHERE {col} IS NULL ''')) self.pgops.add(clean_nulls) elif aux_ptr_table is not None: # SINGLE links with link properties are represented in # _two_ tables (the host type table and a link table with # properties), and we must update both. actual_col = qi(old_ptr_stor_info.column_name) if expr_is_nullable and not is_required: cleanup_qry = textwrap.dedent(f'''\ DELETE FROM {q(*aux_ptr_table)} AS aux USING {tab} AS main WHERE main.id = aux.source AND {actual_col} IS NULL ''') self.pgops.add(dbops.Query(cleanup_qry)) update_qry = textwrap.dedent(f'''\ UPDATE {q(*aux_ptr_table)} AS aux SET {qi(aux_ptr_col)} = main.{actual_col} FROM {tab} AS main WHERE main.id = aux.source ''') self.pgops.add(dbops.Query(update_qry)) if changing_col_type: # In case the column has a default, clear it out before # changing the type alter_table.add_operation( dbops.AlterTableAlterColumnDefault( column_name=old_ptr_stor_info.column_name, default=None)) alter_type = dbops.AlterTableAlterColumnType( old_ptr_stor_info.column_name, common.quote_type(new_type), cast_expr=trivial_cast_expr, ) alter_table.add_operation(alter_type) elif need_temp_col: move_data = dbops.Query(textwrap.dedent(f'''\ UPDATE {q(*old_ptr_stor_info.table_name)} SET {qi(old_ptr_stor_info.column_name)} = ({qi(target_col)}) ''')) self.pgops.add(move_data) if need_temp_col: 
alter_table.add_operation(dbops.AlterTableDropColumn(temp_column)) if changing_col_type or need_temp_col: self.pgops.add(alter_table) self._recreate_constraints(pointer, schema, context) if changing_col_type: self.update_if_cfg_view(schema, context, source) def _compile_conversion_expr( self, pointer: s_pointers.Pointer, conv_expr: s_expr.Expression, source_alias: str, *, schema: s_schema.Schema, orig_schema: s_schema.Schema, context: sd.CommandContext, target_as_singleton: bool = True, check_non_null: bool = False, produce_ctes: bool = True, allow_globals: bool=False, ) -> tuple[ str, # CTE SQL Optional[str], # Query SQL bool, # is_nullable ]: """ Compile the USING expression of an ALTER statement. produce_ctes=True contract: - Must be provided with the alias of the "source" rel - the relation that contains a row for each evaluation of the USING expression. - Source rel var must contain all columns of the __subject__ ObjectType. - Result is an SQL string that contains CTEs, the last of which has the following signature: _conv_rel (val, id) produce_ctes=False contract: - Alias of the source must refer to a relation var, not a relation. - Result is an SQL string that contains a single SELECT statement that has a single value column. """ old_ptr_stor_info = types.get_pointer_storage_info( pointer, schema=orig_schema) ptr_table = old_ptr_stor_info.table_type == 'link' is_link = isinstance(pointer, s_links.Link) is_lprop = pointer.is_link_property(schema) new_target = not_none(pointer.get_target(schema)) if conv_expr.irast is None: conv_expr = self._compile_expr( orig_schema, context, conv_expr, target_as_singleton=target_as_singleton, make_globals_empty=allow_globals, no_query_rewrites=True, ) ir = conv_expr.irast assert ir if ir.stype != new_target and not is_link: # The result of an EdgeQL USING clause does not match # the target type exactly, but is castable.
Like in the # case of an empty USING clause, we still have to make # an explicit EdgeQL cast rather than rely on Postgres # casting. conv_expr = self._compile_expr( orig_schema, context, s_expr.Expression.from_ast( ql_ast.TypeCast( expr=conv_expr.parse(), type=s_utils.typeref_to_ast(schema, new_target), ), schema=orig_schema, ), target_as_singleton=target_as_singleton, make_globals_empty=allow_globals, no_query_rewrites=True, ) ir = conv_expr.irast if params := irutils.get_parameters(ir): param = list(params)[0] if param.is_global: if param.is_implicit_global: problem = 'functions that reference globals' else: problem = 'globals' else: problem = 'parameters' raise errors.UnsupportedFeatureError( f'{problem} may not be used when converting/populating ' f'data in migrations', span=self.span, ) # Non-trivial conversion expression means that we # are compiling a full-blown EdgeQL statement as # opposed to compiling a scalar fragment in trivial # expression mode. if is_lprop: # For linkprops we actually want the source path. # To make it work for abstract links, get the source # path out of the IR's output (to take advantage # of the types we made up for it). # FIXME: Maybe we shouldn't be compiling stuff # for abstract links!
tgt_path_id = ir.singletons[0] else: tgt_path_id = irpathid.PathId.from_pointer( orig_schema, pointer, env=None ) refs = irutils.get_longest_paths(ir.expr) ref_tables = schemamech.get_ref_storage_info(ir.schema, refs) local_table_only = all( t == old_ptr_stor_info.table_name for t in ref_tables ) ptr_path_id = tgt_path_id.ptr_path() src_path_id = ptr_path_id.src_path() assert src_path_id external_rels = {} external_rvars = {} if produce_ctes: external_rels[src_path_id] = compiler.new_external_rel( rel_name=(source_alias,), path_id=src_path_id, ), (pgce.PathAspect.VALUE, pgce.PathAspect.SOURCE) else: if ptr_table: rvar = compiler.new_external_rvar( rel_name=(source_alias,), path_id=ptr_path_id, outputs={ (src_path_id, (pgce.PathAspect.IDENTITY,)): 'source', }, ) external_rvars[ptr_path_id, pgce.PathAspect.SOURCE] = rvar external_rvars[ptr_path_id, pgce.PathAspect.VALUE] = rvar external_rvars[src_path_id, pgce.PathAspect.IDENTITY] = rvar external_rvars[tgt_path_id, pgce.PathAspect.IDENTITY] = rvar if local_table_only and not is_lprop: external_rvars[src_path_id, pgce.PathAspect.SOURCE] = rvar external_rvars[src_path_id, pgce.PathAspect.VALUE] = rvar elif is_lprop: external_rvars[tgt_path_id, pgce.PathAspect.VALUE] = rvar else: src_rvar = compiler.new_external_rvar( rel_name=(source_alias,), path_id=src_path_id, outputs={}, ) external_rvars[src_path_id, pgce.PathAspect.IDENTITY] = src_rvar external_rvars[src_path_id, pgce.PathAspect.VALUE] = src_rvar external_rvars[src_path_id, pgce.PathAspect.SOURCE] = src_rvar # Wrap the expression into a select with iterator, so DML and # volatile expressions are executed once for each object. 
# # The result is roughly equivalent to: # for obj in Object union select # generate a unique path id for the outer scope typ = orig_schema.get(f'schema::ObjectType', type=s_types.Type) outer_path = irast.PathId.from_type( orig_schema, typ, typename=sn.QualName("std", "obj"), env=None, ) root_uid = -1 iter_uid = -2 body_uid = -3 # scope tree wrapping is roughly equivalent to: # "(std::obj) uid:-1": { # "BRANCH uid:-2", # "FENCE uid:-3": { ... compiled scope children ... } # } scope_iter = irast.ScopeTreeNode( unique_id=iter_uid, ) scope_body = irast.ScopeTreeNode( unique_id=body_uid, fenced=True ) # Need to make a copy of the children list because # attach_child removes the node from the parent list. for child in list(ir.scope_tree.children): scope_body.attach_child(child) scope_node = irast.ScopeTreeNode( unique_id=root_uid, path_id=outer_path, ) scope_node.attach_child(scope_iter) scope_node.attach_child(scope_body) # The top-level node must be a fence. scope_root = irast.ScopeTreeNode(fenced=True) scope_root.attach_child(scope_node) ir.scope_tree = scope_root # IR ast wrapping assert isinstance(ir.expr, irast.Set) for_body = ir.expr for_body.path_scope_id = body_uid ir.expr = irast.Set( path_id=outer_path, typeref=outer_path.target, path_scope_id=root_uid, expr=irast.SelectStmt( iterator_stmt=irast.Set( path_id=src_path_id, typeref=src_path_id.target, path_scope_id=iter_uid, expr=irast.SelectStmt( result=irast.Set( path_scope_id=iter_uid, path_id=src_path_id, typeref=src_path_id.target, expr=irast.TypeRoot(typeref=src_path_id.target), ) ) ), result=for_body, ) ) # compile sql_res = compiler.compile_ir_to_sql_tree( ir, output_format=compiler.OutputFormat.NATIVE_INTERNAL, external_rels=external_rels, external_rvars=external_rvars, backend_runtime_params=context.backend_runtime_params, ) sql_tree = sql_res.ast assert isinstance(sql_tree, pgast.SelectStmt) if produce_ctes: # ensure the result contains the object id in the second column from edb.pgsql.compiler import 
pathctx pathctx.get_path_output( sql_tree, src_path_id, aspect=pgce.PathAspect.IDENTITY, env=sql_res.env, ) ctes = list(sql_tree.ctes or []) if sql_tree.ctes: sql_tree.ctes.clear() if check_non_null: # wrap into raise_on_null pointer_name = 'link' if is_link else 'property' msg = pgast.StringConstant( val=f"missing value for required {pointer_name}" ) # Concat to string which is a JSON. Great. Equivalent to SQL: # '{"object_id": "' || {obj_id_ref} || '"}' # We report 'source' for linkprops. Seems OK, I guess. id_col = 'source' if is_lprop else 'id' detail = pgast.Expr( name='||', lexpr=pgast.StringConstant(val='{"object_id": "'), rexpr=pgast.Expr( name='||', lexpr=pgast.ColumnRef(name=(id_col,)), rexpr=pgast.StringConstant(val='"}'), ) ) column = pgast.StringConstant(val=str(pointer.id)) null_check = pgast.FuncCall( name=("edgedb", "raise_on_null"), args=[ pgast.ColumnRef(name=("val",)), pgast.StringConstant(val="not_null_violation"), pgast.NamedFuncArg(name="msg", val=msg), pgast.NamedFuncArg(name="detail", val=detail), pgast.NamedFuncArg(name="column", val=column), ], ) inner_colnames = ["val"] target_list = [pgast.ResTarget(val=null_check)] if produce_ctes: inner_colnames.append("id") target_list.append( pgast.ResTarget(val=pgast.ColumnRef(name=("id",))) ) sql_tree = pgast.SelectStmt( target_list=target_list, from_clause=[ pgast.RangeSubselect( subquery=sql_tree, alias=pgast.Alias( aliasname="_inner", colnames=inner_colnames ) ) ] ) nullable = conv_expr.cardinality.can_be_zero() if produce_ctes: # convert root query into last CTE ctes.append( pgast.CommonTableExpr( name="_conv_rel", aliascolnames=["val", "id"], query=sql_tree, ) ) # compile to SQL ctes_sql = codegen.generate_ctes_source(ctes) return (ctes_sql, None, nullable) else: # keep CTEs and select separate ctes_sql = codegen.generate_ctes_source(ctes) select_sql = codegen.generate_source(sql_tree) return (ctes_sql, select_sql, nullable) def schedule_endpoint_delete_action_update( self, link, orig_schema, 
schema, context ): endpoint_delete_actions = context.get( sd.DeltaRootContext).op.update_endpoint_delete_actions link_ops = endpoint_delete_actions.link_ops if isinstance(self, sd.DeleteObject): for i, (_, ex_link, _, _) in enumerate(link_ops): if ex_link == link: link_ops.pop(i) break link_ops.append((self, link, orig_schema, schema)) class LinkMetaCommand(PointerMetaCommand[s_links.Link]): @classmethod def _create_table( cls, link: s_links.Link, schema: s_schema.Schema, context: sd.CommandContext, conditional: bool = False, create_children: bool = True, ): new_table_name = cls._get_table_name(link, schema) create_c = dbops.CommandGroup() constraints = [] columns = [] src_col = 'source' tgt_col = 'target' columns.append( dbops.Column( name=src_col, type='uuid', required=True)) columns.append( dbops.Column( name=tgt_col, type='uuid', required=True)) constraints.append( dbops.UniqueConstraint( table_name=new_table_name, columns=[src_col, tgt_col])) if not link.is_non_concrete(schema) and link.is_property(): tgt_prop = link.getptr(schema, sn.UnqualName('target')) tgt_ptr = types.get_pointer_storage_info( tgt_prop, schema=schema) columns.append( dbops.Column( name=tgt_ptr.column_name, type=common.qname(*tgt_ptr.column_type))) table = dbops.Table(name=new_table_name) table.add_columns(columns) table.constraints = ordered.OrderedSet(constraints) ct = dbops.CreateTable(table=table) index_name = common.edgedb_name_to_pg_name( str(link.id) + '_target_key') index = dbops.Index( index_name, new_table_name, unique=False, metadata={'code': DEFAULT_INDEX_CODE}, ) index.add_columns([tgt_col]) ci = dbops.CreateIndex(index) if conditional: c = dbops.CommandGroup( neg_conditions=[dbops.TableExists(new_table_name)]) else: c = dbops.CommandGroup() c.add_command(ct) c.add_command(ci) c.add_command( dbops.Comment( table, str(link.get_verbosename(schema, with_parent=True)), ), ) create_c.add_command(c) if create_children: for l_descendant in link.descendants(schema): if 
types.has_table(l_descendant, schema): lc = LinkMetaCommand._create_table( l_descendant, schema, context, conditional=True, create_children=False, ) create_c.add_command(lc) return create_c def _create_link( self, link: s_links.Link, schema: s_schema.Schema, orig_schema: s_schema.Schema, context: sd.CommandContext, ) -> None: objtype = context.get(s_objtypes.ObjectTypeCommandContext) source = link.get_source(schema) if source is not None: source_is_view = ( source.is_view(schema) or source.is_compound_type(schema) or source.get_is_derived(schema) ) else: source_is_view = None if types.has_table(self.scls, schema): self.create_table(self.scls, schema, context) if ( source is not None and not source_is_view and not link.is_pure_computable(schema) ): # We optimize away __type__ and don't store it. # Nothing to do except make sure the inhviews get updated. if link.get_shortname(schema).name == '__type__': self.update_source_if_cfg_view( schema, context, link ) return assert objtype ptr_stor_info = types.get_pointer_storage_info( link, resolve_type=False, schema=schema) fills_required = any( x.fill_expr for x in self.get_subcommands( type=s_pointers.AlterPointerLowerCardinality)) sets_required = bool( self.get_subcommands( type=s_pointers.AlterPointerLowerCardinality)) if ptr_stor_info.table_type == 'ObjectType': cols = self.get_columns( link, schema, None, sets_required) table_name = objtype.op.table_name # type: ignore assert isinstance(objtype.op, CompositeMetaCommand) objtype_alter_table = objtype.op.get_alter_table( schema, context, manual=True) for col in cols: cmd = dbops.AlterTableAddColumn(col) objtype_alter_table.add_operation(cmd) self.pgops.add(objtype_alter_table) index_name = common.get_backend_name( schema, link, catenate=False, aspect='index' )[1] pg_index = dbops.Index( name=index_name, table_name=table_name, unique=False, columns=[c.name for c in cols], inherit=True, metadata={ 'code': DEFAULT_INDEX_CODE, }, ) ci = dbops.CreateIndex(pg_index) 
self.pgops.add(ci) self.update_source_if_cfg_view( schema, context, link ) if ( (default := link.get_default(schema)) and not link.is_pure_computable(schema) and not fills_required ): self._alter_pointer_optionality( schema, schema, context, fill_expr=default, is_default=True) # If we're creating a required multi pointer without a SET # REQUIRED USING inside, run the alter_pointer_optionality # path to produce an error if there is existing data. elif ( link.get_cardinality(schema).is_multi() and link.get_required(schema) and not link.is_pure_computable(schema) and not sets_required ): self._alter_pointer_optionality( schema, schema, context, fill_expr=None) if not link.is_pure_computable(schema): self.schedule_endpoint_delete_action_update( link, orig_schema, schema, context) def _delete_link( self, link: s_links.Link, schema: s_schema.Schema, orig_schema: s_schema.Schema, context: sd.CommandContext, ) -> None: # We optimize away __type__ and don't store it. Nothing to do. if link.get_shortname(schema).name == '__type__': return old_table_name = self._get_table_name(link, schema) if ( not link.is_non_concrete(orig_schema) and types.has_table(link.get_source(orig_schema), orig_schema) and not link.is_pure_computable(orig_schema) ): ptr_stor_info = types.get_pointer_storage_info( link, schema=orig_schema) objtype = context.get(s_objtypes.ObjectTypeCommandContext) assert objtype if (not isinstance(objtype.op, s_objtypes.DeleteObjectType) and ptr_stor_info.table_type == 'ObjectType'): self.update_if_cfg_view(schema, context, objtype.scls) assert isinstance(objtype.op, CompositeMetaCommand) alter_table = objtype.op.get_alter_table( schema, context, manual=True) col = dbops.Column( name=ptr_stor_info.column_name, type=common.qname(*ptr_stor_info.column_type)) colop = dbops.AlterTableDropColumn(col) alter_table.add_operation(colop) self.pgops.add(alter_table) self.attach_alter_table(context) if types.has_table(link, orig_schema): condition = 
dbops.TableExists(name=old_table_name) self.pgops.add( dbops.DropTable(name=old_table_name, conditions=[condition])) self.schedule_endpoint_delete_action_update( link, orig_schema, schema, context) class CreateLink(LinkMetaCommand, adapts=s_links.CreateLink): def _create_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: orig_schema = schema schema = super()._create_begin(schema, context) link = self.scls self.table_name = self._get_table_name(self.scls, schema) self._create_link(link, schema, orig_schema, context) return schema def _create_finalize(self, schema, context): schema = super()._create_finalize(schema, context) self.apply_constraint_trigger_updates(schema) self.schedule_trampoline(self.scls, schema, context) return schema class RenameLink(LinkMetaCommand, adapts=s_links.RenameLink): pass class RebaseLink(LinkMetaCommand, adapts=s_links.RebaseLink): def _alter_innards( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: orig_schema = context.current().original_schema if types.has_table(self.scls, schema): self.update_if_cfg_view(schema, context, self.scls) schema = super()._alter_innards(schema, context) self.schedule_endpoint_delete_action_update( self.scls, orig_schema, schema, context) return schema class SetLinkType(LinkMetaCommand, adapts=s_links.SetLinkType): def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: orig_schema = schema schema = super()._alter_begin(schema, context) orig_type = self.scls.get_target(orig_schema) new_type = self.scls.get_target(schema) if ( types.has_table(self.scls.get_source(schema), schema) and not self.scls.is_pure_computable(schema) and (orig_type != new_type or self.cast_expr is not None) ): self._alter_pointer_type(self.scls, schema, orig_schema, context) self.schedule_endpoint_delete_action_update( self.scls, orig_schema, schema, context) return schema class AlterLinkUpperCardinality( LinkMetaCommand, 
adapts=s_links.AlterLinkUpperCardinality, ): def _alter_innards( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: orig_schema = context.current().original_schema # We need to run the parent change *before* the children, # or else the view update in the child might fail if a # link table isn't created in the parent yet. if ( not self.scls.is_non_concrete(schema) and not self.scls.is_pure_computable(schema) and types.has_table(self.scls.get_source(schema), schema) ): orig_card = self.scls.get_cardinality(orig_schema) new_card = self.scls.get_cardinality(schema) if orig_card != new_card: self._alter_pointer_cardinality(schema, orig_schema, context) return super()._alter_innards(schema, context) class AlterLinkLowerCardinality( LinkMetaCommand, adapts=s_links.AlterLinkLowerCardinality, ): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: orig_schema = schema schema = super().apply(schema, context) if not self.scls.is_non_concrete(schema): orig_required = self.scls.get_required(orig_schema) new_required = self.scls.get_required(schema) if ( types.has_table(self.scls.get_source(schema), schema) and not self.scls.is_endpoint_pointer(schema) and not self.scls.is_pure_computable(schema) and orig_required != new_required ): self._alter_pointer_optionality( schema, orig_schema, context, fill_expr=self.fill_expr) # If the link has an Allow on target delete action, we # need to refresh the triggers, since required and # optional behave differently. (Required does a # check.) 
if ( self.scls.get_on_target_delete(schema) == s_links.LinkTargetDeleteAction.Allow ): self.schedule_endpoint_delete_action_update( self.scls, orig_schema, schema, context) return schema class AlterLinkOwned(LinkMetaCommand, adapts=s_links.AlterLinkOwned): pass class AlterLink(LinkMetaCommand, adapts=s_links.AlterLink): def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._alter_begin(schema, context) # We want to set this name up early, so children operations see it self.table_name = self._get_table_name(self.scls, schema) return schema def _alter_innards( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: orig_schema = context.current().original_schema link = self.scls is_abs = link.is_non_concrete(schema) is_comp = link.is_pure_computable(schema) was_comp = link.is_pure_computable(orig_schema) if not is_abs and (was_comp and not is_comp): self._create_link(link, schema, orig_schema, context) schema = super()._alter_innards(schema, context) if not is_abs and (not was_comp and is_comp): self._delete_link(link, schema, orig_schema, context) # We check whether otd has changed, rather than whether # it is an attribute on this alter, because it might # live on a nested SetOwned, for example. 
otd_changed = ( link.get_on_target_delete(orig_schema) != link.get_on_target_delete(schema) ) osd_changed = ( link.get_on_source_delete(orig_schema) != link.get_on_source_delete(schema) ) card_changed = ( link.get_cardinality(orig_schema) != link.get_cardinality(schema) ) if ( (otd_changed or osd_changed or card_changed) and not link.is_pure_computable(schema) ): self.schedule_endpoint_delete_action_update( link, orig_schema, schema, context) return schema def _alter_finalize(self, schema, context): schema = super()._alter_finalize(schema, context) self.apply_constraint_trigger_updates(schema) return schema class DeleteLink(LinkMetaCommand, adapts=s_links.DeleteLink): def _delete_innards( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: orig_schema = context.current().original_schema link = schema.get(self.classname, type=s_links.Link) schema = super()._delete_innards(schema, context) self._delete_link(link, schema, orig_schema, context) return schema def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().apply(schema, context) self.apply_constraint_trigger_updates(schema) return schema class PropertyMetaCommand(PointerMetaCommand[s_props.Property]): @classmethod def _create_table( cls, prop: s_props.Property, schema: s_schema.Schema, context: sd.CommandContext, conditional: bool = False, create_children: bool = True, ): new_table_name = cls._get_table_name(prop, schema) create_c = dbops.CommandGroup() columns = [] src_col = common.edgedb_name_to_pg_name('source') columns.append( dbops.Column( name=src_col, type='uuid', required=True)) id = sn.QualName( module=prop.get_name(schema).module, name=str(prop.id)) index_name = common.convert_name(id, 'idx0', catenate=False) pg_index = dbops.Index( name=index_name[1], table_name=new_table_name, unique=False, columns=[src_col], metadata={'code': DEFAULT_INDEX_CODE}, ) ci = dbops.CreateIndex(pg_index) if not prop.is_non_concrete(schema): 
tgt_cols = cls.get_columns(prop, schema, None) columns.extend(tgt_cols) table = dbops.Table(name=new_table_name) table.add_columns(columns) ct = dbops.CreateTable(table=table) if conditional: c = dbops.CommandGroup( neg_conditions=[dbops.TableExists(new_table_name)]) else: c = dbops.CommandGroup() c.add_command(ct) c.add_command(ci) c.add_command( dbops.Comment( table, str(prop.get_verbosename(schema, with_parent=True)), ), ) create_c.add_command(c) if create_children: for p_descendant in prop.descendants(schema): if types.has_table(p_descendant, schema): pc = PropertyMetaCommand._create_table( p_descendant, schema, context, conditional=True, create_children=False, ) create_c.add_command(pc) return create_c def _create_property( self, prop: s_props.Property, src: Optional[sd.ObjectCommandContext[s_sources.Source]], schema: s_schema.Schema, orig_schema: s_schema.Schema, context: sd.CommandContext, ) -> None: propname = prop.get_shortname(schema).name if types.has_table(prop, schema): self.create_table(prop, schema, context) if ( src and types.has_table(src.scls, schema) and not prop.is_pure_computable(schema) ): if ( isinstance(src.scls, s_links.Link) and not types.has_table(src.scls, orig_schema) ): ct = src.op._create_table( # type: ignore src.scls, schema, context) self.pgops.add(ct) ptr_stor_info = types.get_pointer_storage_info( prop, resolve_type=False, schema=schema) fills_required = any( x.fill_expr for x in self.get_subcommands( type=s_pointers.AlterPointerLowerCardinality)) sets_required = bool( self.get_subcommands( type=s_pointers.AlterPointerLowerCardinality)) if ( not isinstance(src.scls, s_objtypes.ObjectType) or ptr_stor_info.table_type == 'ObjectType' ): if ( not isinstance(src.scls, s_links.Link) or propname not in {'source', 'target'} ): assert isinstance(src.op, CompositeMetaCommand) alter_table = src.op.get_alter_table( schema, context, force_new=True, manual=True, ) default_value = self.get_pointer_default( prop, schema, context) if ( 
isinstance(src.scls, s_links.Link) and not default_value and prop.get_default(schema) ): raise errors.UnsupportedFeatureError( f'default value for ' f'{prop.get_verbosename(schema, with_parent=True)}' f' is too complicated; link property defaults ' f'must not depend on database contents', span=self.span) cols = self.get_columns( prop, schema, default_value, sets_required) for col in cols: cmd = dbops.AlterTableAddColumn(col) alter_table.add_operation(cmd) self.pgops.add(alter_table) self.update_source_if_cfg_view( schema, context, prop ) if ( (default := prop.get_default(schema)) and not prop.is_pure_computable(schema) and not fills_required and not irtyputils.is_cfg_view(src.scls, schema) # sigh # link properties use SQL defaults and shouldn't need # us to do it explicitly (which is good, since # _alter_pointer_optionality doesn't currently work on # linkprops) and not prop.is_link_property(schema) ): self._alter_pointer_optionality( schema, schema, context, fill_expr=default, is_default=True) # If we're creating a required multi pointer without a SET # REQUIRED USING inside, run the alter_pointer_optionality # path to produce an error if there is existing data. 
elif ( prop.get_cardinality(schema).is_multi() and prop.get_required(schema) and not prop.is_pure_computable(schema) and not sets_required ): self._alter_pointer_optionality( schema, schema, context, fill_expr=None) if not prop.is_pure_computable(schema): self.schedule_endpoint_delete_action_update( prop, orig_schema, schema, context) def _delete_property( self, prop: s_props.Property, source: s_sources.Source, source_op, schema: s_schema.Schema, orig_schema: s_schema.Schema, context: sd.CommandContext, ) -> None: if types.has_table(source, schema): ptr_stor_info = types.get_pointer_storage_info( prop, schema=schema, link_bias=prop.is_link_property(schema), ) if ( ptr_stor_info.table_type == 'ObjectType' or prop.is_link_property(schema) ): alter_table = source_op.get_alter_table( schema, context, force_new=True, manual=True) # source and target don't have a proper inheritance # hierarchy, so we can't do the source trick for them self.update_if_cfg_view(schema, context, source) col = dbops.AlterTableDropColumn( dbops.Column(name=ptr_stor_info.column_name, type=ptr_stor_info.column_type)) alter_table.add_operation(col) self.pgops.add(alter_table) elif ( prop.is_link_property(schema) and types.has_table(source, orig_schema) ): old_table_name = self._get_table_name(source, orig_schema) self.pgops.add(dbops.DropTable(name=old_table_name)) if types.has_table(prop, orig_schema): old_table_name = self._get_table_name(prop, orig_schema) self.pgops.add(dbops.DropTable(name=old_table_name)) self.schedule_endpoint_delete_action_update( prop, orig_schema, schema, context) class CreateProperty(PropertyMetaCommand, adapts=s_props.CreateProperty): def _create_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: orig_schema = schema schema = super()._create_begin(schema, context) prop = self.scls src = context.get(s_sources.SourceCommandContext) self.table_name = self._get_table_name(prop, schema) self._create_property(prop, src, schema,
orig_schema, context) self.schedule_trampoline(self.scls, schema, context) return schema class RenameProperty(PropertyMetaCommand, adapts=s_props.RenameProperty): pass class RebaseProperty(PropertyMetaCommand, adapts=s_props.RebaseProperty): def _alter_innards( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: orig_schema = context.current().original_schema if types.has_table(self.scls, schema): self.update_if_cfg_view(schema, context, self.scls) schema = super()._alter_innards(schema, context) if not self.scls.is_pure_computable(schema): self.schedule_endpoint_delete_action_update( self.scls, orig_schema, schema, context) return schema class SetPropertyType(PropertyMetaCommand, adapts=s_props.SetPropertyType): def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: orig_schema = schema schema = super()._alter_begin(schema, context) orig_type = self.scls.get_target(orig_schema) new_type = self.scls.get_target(schema) if ( types.has_table(self.scls.get_source(schema), schema) and not self.scls.is_pure_computable(schema) and not self.scls.is_endpoint_pointer(schema) and (orig_type != new_type or self.cast_expr is not None) ): self._alter_pointer_type(self.scls, schema, orig_schema, context) return schema class AlterPropertyUpperCardinality( PropertyMetaCommand, adapts=s_props.AlterPropertyUpperCardinality, ): def _alter_innards( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: orig_schema = context.current().original_schema # We need to run the parent change *before* the children, # or else the view update in the child might fail if a # link table isn't created in the parent yet. 
if ( not self.scls.is_non_concrete(schema) and not self.scls.is_pure_computable(schema) and not self.scls.is_endpoint_pointer(schema) and types.has_table(self.scls.get_source(schema), schema) ): orig_card = self.scls.get_cardinality(orig_schema) new_card = self.scls.get_cardinality(schema) if orig_card != new_card: self._alter_pointer_cardinality(schema, orig_schema, context) return super()._alter_innards(schema, context) class AlterPropertyLowerCardinality( PropertyMetaCommand, adapts=s_props.AlterPropertyLowerCardinality, ): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: orig_schema = schema schema = super().apply(schema, context) if not self.scls.is_non_concrete(schema): orig_required = self.scls.get_required(orig_schema) new_required = self.scls.get_required(schema) if ( types.has_table(self.scls.get_source(schema), schema) and not self.scls.is_endpoint_pointer(schema) and not self.scls.is_pure_computable(schema) and orig_required != new_required ): self._alter_pointer_optionality( schema, orig_schema, context, fill_expr=self.fill_expr) return schema class AlterPropertyOwned( PropertyMetaCommand, adapts=s_props.AlterPropertyOwned, ): pass class AlterProperty(PropertyMetaCommand, adapts=s_props.AlterProperty): def _alter_innards( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: prop = self.scls orig_schema = context.current().original_schema src = context.get(s_sources.SourceCommandContext) is_comp = prop.is_pure_computable(schema) was_comp = prop.is_pure_computable(orig_schema) if src and (was_comp and not is_comp): self._create_property(prop, src, schema, orig_schema, context) schema = super()._alter_innards(schema, context) if src and (not was_comp and is_comp): self._delete_property( prop, src.scls, src.op, schema, orig_schema, context) if self.metadata_only: return schema if ( not is_comp and (src and types.has_table(src.scls, schema)) ): orig_def_val = 
self.get_pointer_default(prop, orig_schema, context) def_val = self.get_pointer_default(prop, schema, context) if orig_def_val != def_val: if prop.get_cardinality(schema).is_multi(): source_op: sd.Command = self else: source_op = not_none(context.get_ancestor( s_sources.SourceCommandContext, self)).op assert isinstance(source_op, CompositeMetaCommand) alter_table = source_op.get_alter_table( schema, context, manual=True) ptr_stor_info = types.get_pointer_storage_info( prop, schema=schema) alter_table.add_operation( dbops.AlterTableAlterColumnDefault( column_name=ptr_stor_info.column_name, default=def_val)) self.pgops.add(alter_table) card = self.get_resolved_attribute_value( 'cardinality', schema=schema, context=context, ) if card: self.schedule_endpoint_delete_action_update( prop, orig_schema, schema, context) return schema class DeleteProperty(PropertyMetaCommand, adapts=s_props.DeleteProperty): def _delete_innards( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._delete_innards(schema, context) prop = self.scls orig_schema = context.current().original_schema source_ctx = self.get_referrer_context(context) if source_ctx is not None: source = source_ctx.scls source_op = source_ctx.op else: source = None source_op = None if source and not prop.is_pure_computable(schema): assert isinstance(source, s_sources.Source) self._delete_property( prop, source, source_op, schema, orig_schema, context) return schema class CreateTrampolines(MetaCommand): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.trampolines: list[trampoline.Trampoline] = [] self.table_targets: list[s_objtypes.ObjectType | s_pointers.Pointer] = ( [] ) def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: for obj in self.table_targets: if not (schema.has_object(obj.id) and types.has_table(obj, schema)): continue if tramp := CompositeMetaCommand.create_type_trampoline( schema, obj ): 
self.trampolines.append(tramp) for t in self.trampolines: self.pgops.add(t.make()) return schema class UpdateEndpointDeleteActions(MetaCommand): def __init__(self, **kwargs): super().__init__(**kwargs) self.link_ops = [] self.changed_targets = set() def _get_link_table_union(self, schema, links) -> str: selects = [] for link in links: selects.append(textwrap.dedent('''\ (SELECT {id}::uuid AS __sobj_id__, {src} as source, {tgt} as target FROM {table}) ''').format( id=ql(str(link.id)), src=common.quote_ident('source'), tgt=common.quote_ident('target'), table=common.get_backend_name( schema, link, ), )) return '(' + '\nUNION ALL\n '.join(selects) + ') as q' def _get_inline_link_table_union( self, schema, links ) -> str: selects = [] for link in links: link_psi = types.get_pointer_storage_info(link, schema=schema) link_col = link_psi.column_name selects.append(textwrap.dedent('''\ (SELECT {id}::uuid AS __sobj_id__, {src} as source, {tgt} as target FROM {table}) ''').format( id=ql(str(link.id)), src=common.quote_ident('id'), tgt=common.quote_ident(link_col), table=common.get_backend_name( schema, link.get_source(schema), ), )) return '(' + '\nUNION ALL\n '.join(selects) + ') as q' def get_target_objs(self, link, schema): tgt = link.get_target(schema) if union := tgt.get_union_of(schema).objects(schema): objs = set(union) else: objs = {tgt} objs |= { x for obj in objs for x in obj.descendants(schema)} return { obj for obj in objs if ( not obj.is_view(schema) and not irtyputils.is_cfg_view(obj, schema) ) } def get_orphan_link_ancestors(self, link, schema): val = s_links.LinkSourceDeleteAction.DeleteTargetIfOrphan if link.get_on_source_delete(schema) != val: return set() ancestors = { x for base in link.get_bases(schema).objects(schema) for x in self.get_orphan_link_ancestors(base, schema) } if ancestors: return ancestors else: return {link} def get_trigger_name( self, schema, target, disposition, deferred=False, inline=False ): if disposition == 'target': aspect = 
'target-del' else: aspect = 'source-del' if deferred: aspect += '-def' else: aspect += '-imm' if inline: aspect += '-inl' else: aspect += '-otl' aspect += '-t' # Postgres applies triggers in alphabetical order, and # the names are uuids, which are not useful here. # # All we want for now is for source triggers to apply first, # though, so that a loop of objects with # 'on source delete delete target' + 'on target delete restrict' # succeeds. # # Fortunately S comes before T. order_prefix = disposition[0] name = common.get_backend_name(schema, target, catenate=False) return f'{order_prefix}_{name[1]}_{aspect}' def get_trigger_proc_name( self, schema, target, disposition, deferred=False, inline=False ): if disposition == 'target': aspect = 'target-del' else: aspect = 'source-del' if deferred: aspect += '-def' else: aspect += '-imm' if inline: aspect += '-inl' else: aspect += '-otl' aspect += '-f' name = common.get_backend_name(schema, target, catenate=False) return (name[0], f'{name[1]}_{aspect}') def get_trigger_proc_text( self, target, links, *, disposition, inline, schema, context, ): if inline: return self._get_inline_link_trigger_proc_text( target, links, disposition=disposition, schema=schema, context=context) else: return self._get_outline_link_trigger_proc_text( target, links, disposition=disposition, schema=schema, context=context) def _get_outline_link_trigger_proc_text( self, target, links, *, disposition, schema, context ): chunks = [] DA = s_links.LinkTargetDeleteAction if disposition == 'target': groups = itertools.groupby( links, lambda l: l.get_on_target_delete(schema)) near_endpoint, far_endpoint = 'target', 'source' else: groups = itertools.groupby( links, lambda l: ( l.get_on_source_delete(schema) if isinstance(l, s_links.Link) else s_links.LinkSourceDeleteAction.Allow)) near_endpoint, far_endpoint = 'source', 'target' for action, links in groups: if action is DA.Restrict or action is DA.DeferredRestrict: tables = self._get_link_table_union(schema, 
links) # We want versioned for stdlib (since the trampolines # don't exist yet) but trampolined for user code prefix = 'edgedb_VER' if context.stdmode else 'edgedb' text = textwrap.dedent(trampoline.fixup_query('''\ SELECT q.__sobj_id__, q.source, q.target INTO link_type_id, srcid, tgtid FROM {tables} WHERE q.{near_endpoint} = OLD.{id} LIMIT 1; IF FOUND THEN SELECT {prefix}.shortname_from_fullname(link.name), {prefix}._get_schema_object_name( link.{far_endpoint}) INTO linkname, endname FROM {prefix}._schema_links AS link WHERE link.id = link_type_id; RAISE foreign_key_violation USING TABLE = TG_TABLE_NAME, SCHEMA = TG_TABLE_SCHEMA, MESSAGE = 'deletion of {tgtname} (' || tgtid || ') is prohibited by link target policy', DETAIL = 'Object is still referenced in link ' || linkname || ' of ' || endname || ' (' || srcid || ').'; END IF; '''.format( tables=tables, id='id', tgtname=target.get_displayname(schema), near_endpoint=near_endpoint, far_endpoint=far_endpoint, prefix=prefix, ))) chunks.append(text) elif ( action == s_links.LinkTargetDeleteAction.Allow or action == s_links.LinkSourceDeleteAction.Allow ): for link in links: link_table = common.get_backend_name( schema, link) # Since 'required' on multi links is enforced # manually on the query side (and not through # constraints/triggers of its own), we also # need to enforce it manually when deleting # the target of a required multi link. if link.get_required(schema) and disposition == 'target': required_text = textwrap.dedent('''\ SELECT q.source INTO srcid FROM {link_table} as q WHERE q.target = OLD.{id} AND NOT EXISTS ( SELECT FROM {link_table} as q2 WHERE q.source = q2.source AND q2.target != OLD.{id} ); IF FOUND THEN RAISE not_null_violation USING TABLE = TG_TABLE_NAME, SCHEMA = TG_TABLE_SCHEMA, MESSAGE = 'missing value', COLUMN = '{link_id}'; END IF; ''').format( link_table=link_table, link_id=str(link.id), id='id' ) chunks.append(required_text) # Otherwise just delete it from the link table.
text = textwrap.dedent('''\ DELETE FROM {link_table} WHERE {endpoint} = OLD.{id}; ''').format( link_table=link_table, endpoint=common.quote_ident(near_endpoint), id='id' ) chunks.append(text) elif action == s_links.LinkTargetDeleteAction.DeleteSource: sources = collections.defaultdict(list) for link in links: sources[link.get_source(schema)].append(link) for source, source_links in sources.items(): tables = self._get_link_table_union(schema, source_links) text = textwrap.dedent('''\ DELETE FROM {source_table} WHERE {source_table}.{id} IN ( SELECT source FROM {tables} WHERE target = OLD.{id} ); ''').format( source_table=common.get_backend_name(schema, source), id='id', tables=tables, ) chunks.append(text) elif ( action == s_links.LinkSourceDeleteAction.DeleteTarget or action == s_links.LinkSourceDeleteAction.DeleteTargetIfOrphan ): for link in links: link_table = common.get_backend_name(schema, link) objs = self.get_target_objs(link, schema) # If the link is DELETE TARGET IF ORPHAN, build # filters to ignore any objects that aren't # orphans (with respect to this link). roots = { x for root in self.get_orphan_link_ancestors(link, schema) for x in [root, *root.descendants(schema)] } orphan_check = '' for orphan_check_root in roots: if not types.has_table(orphan_check_root, schema): continue check_table = common.get_backend_name( schema, orphan_check_root ) orphan_check += f'''\ AND NOT EXISTS ( SELECT FROM {check_table} as q2 WHERE q.target = q2.target AND q2.source != OLD.id ) '''.strip() # We find all the objects to delete in a CTE, then # delete the link table entries, and then delete # the targets. We apply the non-orphan filter when # finding the objects.
prefix = textwrap.dedent(f'''\ WITH range AS ( SELECT target FROM {link_table} as q WHERE q.source = OLD.id {orphan_check} ), del AS ( DELETE FROM {link_table} WHERE source = OLD.id ) ''').strip() parts = [prefix] for i, obj in enumerate(objs): tgt_table = common.get_backend_name(schema, obj) text = textwrap.dedent(f'''\ d{i} AS ( DELETE FROM {tgt_table} WHERE {tgt_table}.id IN ( SELECT target FROM range ) ) ''').strip() parts.append(text) full = ',\n'.join(parts) + "\nSELECT '' INTO _dummy_text;" chunks.append(full) text = textwrap.dedent('''\ DECLARE link_type_id uuid; srcid uuid; tgtid uuid; linkname text; endname text; _dummy_text text; BEGIN {chunks} RETURN OLD; END; ''').format(chunks='\n\n'.join(chunks)) return text def _get_inline_link_trigger_proc_text( self, target, links, *, disposition, schema, context ): chunks = [] DA = s_links.LinkTargetDeleteAction if disposition == 'target': groups = itertools.groupby( links, lambda l: l.get_on_target_delete(schema)) else: groups = itertools.groupby( links, lambda l: l.get_on_source_delete(schema)) near_endpoint, far_endpoint = 'target', 'source' for action, links in groups: if action is DA.Restrict or action is DA.DeferredRestrict: tables = self._get_inline_link_table_union(schema, links) # We want versioned for stdlib (since the trampolines # don't exist yet) but trampolined for user code prefix = 'edgedb_VER' if context.stdmode else 'edgedb' text = textwrap.dedent(trampoline.fixup_query('''\ SELECT q.__sobj_id__, q.source, q.target INTO link_type_id, srcid, tgtid FROM {tables} WHERE q.{near_endpoint} = OLD.{id} LIMIT 1; IF FOUND THEN SELECT {prefix}.shortname_from_fullname(link.name), {prefix}._get_schema_object_name( link.{far_endpoint}) INTO linkname, endname FROM {prefix}._schema_links AS link WHERE link.id = link_type_id; RAISE foreign_key_violation USING TABLE = TG_TABLE_NAME, SCHEMA = TG_TABLE_SCHEMA, MESSAGE = 'deletion of {tgtname} (' || tgtid || ') is prohibited by link target policy', DETAIL = 'Object 
is still referenced in link ' || linkname || ' of ' || endname || ' (' || srcid || ').'; END IF; '''.format( tables=tables, id='id', tgtname=target.get_displayname(schema), near_endpoint=near_endpoint, far_endpoint=far_endpoint, prefix=prefix, ))) chunks.append(text) elif action == s_links.LinkTargetDeleteAction.Allow: for link in links: link_psi = types.get_pointer_storage_info( link, schema=schema) link_col = link_psi.column_name source_table = common.get_backend_name( schema, link.get_source(schema)) text = textwrap.dedent(f'''\ UPDATE {source_table} SET {qi(link_col)} = NULL WHERE {qi(link_col)} = OLD.id; ''') chunks.append(text) elif action == s_links.LinkTargetDeleteAction.DeleteSource: sources = collections.defaultdict(list) for link in links: sources[link.get_source(schema)].append(link) for source, source_links in sources.items(): tables = self._get_inline_link_table_union( schema, source_links) text = textwrap.dedent('''\ DELETE FROM {source_table} WHERE {source_table}.{id} IN ( SELECT source FROM {tables} WHERE target = OLD.{id} ); ''').format( source_table=common.get_backend_name(schema, source), id='id', tables=tables, ) chunks.append(text) elif ( action == s_links.LinkSourceDeleteAction.DeleteTarget or action == s_links.LinkSourceDeleteAction.DeleteTargetIfOrphan ): for link in links: objs = self.get_target_objs(link, schema) link_psi = types.get_pointer_storage_info( link, schema=schema) link_col = common.quote_ident(link_psi.column_name) # If the link is DELETE TARGET IF ORPHAN, filter out # any objects that aren't orphans (with respect to this link).
roots = { x for root in self.get_orphan_link_ancestors(link, schema) for x in [root, *root.descendants(schema)] } orphan_check = '' for orphan_check_root in roots: check_source = orphan_check_root.get_source(schema) if not types.has_table(check_source, schema): continue check_table = common.get_backend_name( schema, check_source ) check_link_psi = types.get_pointer_storage_info( orphan_check_root, schema=schema) check_link_col = common.quote_ident( check_link_psi.column_name) orphan_check += f'''\ AND NOT EXISTS ( SELECT FROM {check_table} as q2 WHERE q2.{check_link_col} = OLD.{link_col} AND q2.id != OLD.id ) '''.strip() # Do the orphan check (which trivially succeeds if # the link isn't IF ORPHAN) text = textwrap.dedent(f'''\ SELECT ( SELECT true {orphan_check} ) INTO ok; ''').strip() chunks.append(text) for obj in objs: tgt_table = common.get_backend_name(schema, obj) text = textwrap.dedent(f'''\ IF ok THEN DELETE FROM {tgt_table} WHERE {tgt_table}.id = OLD.{link_col}; END IF; ''') chunks.append(text) text = textwrap.dedent('''\ DECLARE link_type_id uuid; srcid uuid; tgtid uuid; linkname text; endname text; ok bool; links text[]; BEGIN {chunks} RETURN OLD; END; ''').format(chunks='\n\n'.join(chunks)) return text def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: if not self.link_ops and not self.changed_targets: return schema DA = s_links.LinkTargetDeleteAction DS = s_links.LinkSourceDeleteAction affected_sources: set[s_sources.Source] = set() affected_targets = {t for _, t in self.changed_targets} modifications = any( isinstance(op, RebaseObjectType) and op.removed_bases for op, _ in self.changed_targets ) for link_op, link, orig_schema, eff_schema in self.link_ops: # Skip __type__ triggers, since __type__ isn't real and # also would be a huge pain to update each time if it was. 
if link.get_shortname(eff_schema).name == '__type__': continue if ( isinstance(link_op, (DeleteProperty, DeleteLink)) or ( link.is_pure_computable(eff_schema) and not link.is_pure_computable(orig_schema) ) ): source = link.get_source(orig_schema) if source: current_source = schema.get_by_id(source.id, None) if (current_source is not None and not current_source.is_view(schema)): modifications = True affected_sources.add(current_source) if not eff_schema.has_object(link.id): continue target_is_affected = isinstance(link, s_links.Link) if link.is_non_concrete(eff_schema) or ( link.is_pure_computable(eff_schema) and link.is_pure_computable(orig_schema) ): continue source = link.get_source(eff_schema) target = link.get_target(eff_schema) if ( not isinstance(source, s_objtypes.ObjectType) or irtyputils.is_cfg_view(source, eff_schema) ): continue if not isinstance(link_op, (CreateProperty, CreateLink)): modifications = True if isinstance(link_op, (DeleteProperty, DeleteLink)): current_target = schema.get_by_id(target.id, None) if target_is_affected and current_target is not None: affected_targets.add(current_target) else: if not source.is_material_object_type(eff_schema): continue current_source = schema.get_by_id(source.id, None) if current_source: affected_sources.add(current_source) if target_is_affected: affected_targets.add(target) if isinstance(link_op, (SetLinkType, SetPropertyType)): orig_target = link.get_target(orig_schema) if target != orig_target: current_orig_target = schema.get_by_id( orig_target.id, None) if current_orig_target is not None: affected_targets.add(current_orig_target) # All descendants of affected targets also need to have their # triggers updated, so track them down. 
all_affected_targets = set() for target in affected_targets: union_of = target.get_union_of(schema) if union_of: objtypes = tuple(union_of.objects(schema)) else: objtypes = (target,) for objtype in objtypes: all_affected_targets.add(objtype) for descendant in objtype.descendants(schema): if types.has_table(descendant, schema): all_affected_targets.add(descendant) delete_target_targets = set() for target in all_affected_targets: if irtyputils.is_cfg_view(target, schema): continue deferred_links = [] deferred_inline_links = [] links = [] inline_links = [] inbound_links = schema.get_referrers( target, scls_type=s_links.Link, field_name='target') # We need to look at all inbound links to all ancestors for ancestor in itertools.chain( target.get_ancestors(schema).objects(schema), schema.get_referrers( target, scls_type=s_objtypes.ObjectType, field_name='union_of' ), ): inbound_links |= schema.get_referrers( ancestor, scls_type=s_links.Link, field_name='target') for link in inbound_links: if link.is_pure_computable(schema): continue # Skip __type__ triggers, since __type__ isn't real and # also would be a huge pain to update each time if it was. 
if link.get_shortname(schema).name == '__type__': continue source = link.get_source(schema) if ( not source.is_material_object_type(schema) or irtyputils.is_cfg_view(source, schema) ): continue # We need to track what objects are targets that can be # deleted on a source delete; it feeds into a decision we # need to make when handling source triggers below if link.get_on_source_delete(schema) != DS.Allow: delete_target_targets.add(target) affected_sources.add(target) action = link.get_on_target_delete(schema) ptr_stor_info = types.get_pointer_storage_info( link, schema=schema) if ptr_stor_info.table_type != 'link': if action is DA.DeferredRestrict: deferred_inline_links.append(link) else: inline_links.append(link) else: if action is DA.DeferredRestrict: deferred_links.append(link) else: links.append(link) # The ordering that we process links matters: Restrict # must be processed *after* Allow and DeleteSource, # because Restrict is applied (via views) to all # descendant links regardless of whether they have been # overridden, and so Allow and DeleteSource must be # handled first. 
ordering = (DA.Restrict, DA.Allow, DA.DeleteSource) links.sort( key=lambda l: (ordering.index(l.get_on_target_delete(schema)), l.get_name(schema))) inline_links.sort( key=lambda l: (ordering.index(l.get_on_target_delete(schema)), l.get_name(schema))) deferred_links.sort( key=lambda l: l.get_name(schema)) deferred_inline_links.sort( key=lambda l: l.get_name(schema)) if links or modifications: self._update_action_triggers( schema, context, target, links, disposition='target') if inline_links or modifications: self._update_action_triggers( schema, context, target, inline_links, disposition='target', inline=True) if deferred_links or modifications: self._update_action_triggers( schema, context, target, deferred_links, disposition='target', deferred=True) if deferred_inline_links or modifications: self._update_action_triggers( schema, context, target, deferred_inline_links, disposition='target', deferred=True, inline=True) # Now process source targets for source in affected_sources: links = [] inline_links = [] can_be_deleted_by_trigger = any( link.get_on_target_delete(schema) == DA.DeleteSource for link in source.get_pointers(schema).objects(schema) if isinstance(link, s_links.Link) ) or source in delete_target_targets for link in source.get_pointers(schema).objects(schema): if link.is_pure_computable(schema): continue ptr_stor_info = types.get_pointer_storage_info( link, schema=schema) delete_target = ( isinstance(link, s_links.Link) and link.get_on_source_delete(schema) != DS.Allow ) if ptr_stor_info.table_type == 'link' and ( # When a query does a delete, link tables get # cleared out explicitly in our SQL, and so we # don't need to run a source trigger unless there # is an interesting source delete policy. 
# # However, if the object might be deleted by a # link policy, then we still use a trigger to # clean up the link table, since handling it # in the original policy triggers would require # lots of pretty nonlocal changes (adding a link # to type Bar might require changing the triggers for # type Foo that links to Bar). can_be_deleted_by_trigger or delete_target ): links.append(link) # Inline links only need source actions if they might # delete the target elif delete_target: inline_links.append(link) links.sort( key=lambda l: ( (l.get_on_target_delete(schema),) if isinstance(l, s_links.Link) else (), l.get_name(schema))) inline_links.sort( key=lambda l: ( (l.get_on_target_delete(schema),) if isinstance(l, s_links.Link) else (), l.get_name(schema))) if links or modifications: self._update_action_triggers( schema, context, source, links, disposition='source') if inline_links or modifications: self._update_action_triggers( schema, context, source, inline_links, inline=True, disposition='source') return schema def _update_action_triggers( self, schema, context: sd.CommandContext, objtype: s_objtypes.ObjectType, links: list[s_links.Link], *, disposition: str, deferred: bool = False, inline: bool = False, ) -> None: table_name = common.get_backend_name(schema, objtype, catenate=False) trigger_name = self.get_trigger_name( schema, objtype, disposition=disposition, deferred=deferred, inline=inline) proc_name = self.get_trigger_proc_name( schema, objtype, disposition=disposition, deferred=deferred, inline=inline) trigger = dbops.Trigger( name=trigger_name, table_name=table_name, events=('delete',), procedure=proc_name, is_constraint=True, inherit=True, deferred=deferred) if links: proc_text = self.get_trigger_proc_text( objtype, links, disposition=disposition, inline=inline, schema=schema, context=context) trig_func = dbops.Function( name=proc_name, text=proc_text, volatility='volatile', returns='trigger', language='plpgsql') self.pgops.add(dbops.CreateFunction(trig_func, 
or_replace=True)) self.pgops.add(dbops.CreateTrigger( trigger, neg_conditions=[dbops.TriggerExists( trigger_name=trigger_name, table_name=table_name )] )) else: self.pgops.add( dbops.DropTrigger( trigger, conditions=[dbops.TriggerExists( trigger_name=trigger_name, table_name=table_name, )] ) ) self.pgops.add( dbops.DropFunction( name=proc_name, args=[], conditions=[dbops.FunctionExists( name=proc_name, args=[], )] ) ) class ModuleMetaCommand(MetaCommand): pass class CreateModule(ModuleMetaCommand, adapts=s_mod.CreateModule): pass class AlterModule(ModuleMetaCommand, adapts=s_mod.AlterModule): pass class DeleteModule(ModuleMetaCommand, adapts=s_mod.DeleteModule): pass class DatabaseMixin: def ensure_has_create_database(self, backend_params): if not backend_params.has_create_database: self.pgops.add( dbops.Query( f''' SELECT edgedb_VER.raise( NULL::uuid, msg => 'operation is not supported by the backend', exc => 'feature_not_supported' ) INTO _dummy_text ''' ) ) class CreateDatabase(MetaCommand, DatabaseMixin, adapts=s_db.CreateBranch): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: backend_params = self._get_backend_params(context) self.ensure_has_create_database(backend_params) schema = super().apply(schema, context) db = self.scls tenant_id = self._get_tenant_id(context) db_name = common.get_database_backend_name( str(self.classname), tenant_id=tenant_id) # We use the base template for SCHEMA and DATA branches, since we # implement branches ourselves using pg_dump in order to avoid # connection restrictions. # For internal-only TEMPLATE branches, we use the source as # the template. 
template = ( self.template if self.template and self.branch_type == ql_ast.BranchType.TEMPLATE else edbdef.EDGEDB_TEMPLATE_DB ) tpl_name = common.get_database_backend_name( template, tenant_id=tenant_id) self.pgops.add( dbops.CreateDatabase( dbops.Database( db_name, metadata=dict( id=str(db.id), tenant_id=tenant_id, builtin=self.get_attribute_value('builtin'), ), ), template=tpl_name, ) ) return schema class DropDatabase(MetaCommand, DatabaseMixin, adapts=s_db.DropBranch): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: backend_params = self._get_backend_params(context) self.ensure_has_create_database(backend_params) schema = super().apply(schema, context) tenant_id = self._get_tenant_id(context) db_name = common.get_database_backend_name( str(self.classname), tenant_id=tenant_id) self.pgops.add(dbops.DropDatabase(db_name)) return schema class AlterDatabase(MetaCommand, DatabaseMixin, adapts=s_db.AlterBranch): pass class RenameDatabase(MetaCommand, DatabaseMixin, adapts=s_db.RenameBranch): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: backend_params = self._get_backend_params(context) self.ensure_has_create_database(backend_params) schema = super().apply(schema, context) tenant_id = self._get_tenant_id(context) db_name = common.get_database_backend_name( str(self.classname), tenant_id=tenant_id) new_name = common.get_database_backend_name( str(self.new_name), tenant_id=tenant_id) self.pgops.add( dbops.RenameDatabase( dbops.Database( new_name, ), old_name=db_name, ) ) return schema class RoleMixin: def ensure_has_create_role(self, backend_params): if not backend_params.has_create_role: self.pgops.add( dbops.Query( f''' SELECT edgedb_VER.raise( NULL::uuid, msg => 'operation is not supported by the backend', exc => 'feature_not_supported' ) INTO _dummy_text ''' ) ) class CreateRole(MetaCommand, RoleMixin, adapts=s_roles.CreateRole): def apply( self, schema: s_schema.Schema, 
context: sd.CommandContext, ) -> s_schema.Schema: backend_params = self._get_backend_params(context) self.ensure_has_create_role(backend_params) schema = super().apply(schema, context) role = self.scls membership = [str(x) for x in role.get_bases(schema).names(schema)] passwd = role.get_password(schema) superuser_flag = False members = set() role_name = str(role.get_name(schema)) permissions: list[str] = list(sorted( role.get_permissions(schema) or () )) branches: list[str] = list(sorted( role.get_branches(schema) )) apply_access_policies_pg_default = ( role.get_apply_access_policies_pg_default(schema) ) instance_params = backend_params.instance_params tenant_id = instance_params.tenant_id if role.get_superuser(schema): membership.append(edbdef.EDGEDB_SUPERGROUP) # If the cluster is not exposing an explicit superuser role, # we will make the created Postgres role superuser if we can if not instance_params.base_superuser: superuser_flag = backend_params.has_superuser_access if backend_params.session_authorization_role is not None: # When we connect to the backend via a proxy role, we # must ensure that role is a member of _every_ EdgeDB # role so that `SET ROLE` can work properly. 
members.add(backend_params.session_authorization_role) db_role = dbops.Role( name=common.get_role_backend_name(role_name, tenant_id=tenant_id), allow_login=True, superuser=superuser_flag, password=passwd, membership=[ common.get_role_backend_name(parent_role, tenant_id=tenant_id) for parent_role in membership ], metadata=dict( id=str(role.id), name=role_name, tenant_id=tenant_id, password_hash=passwd, builtin=role.get_builtin(schema), permissions=permissions, branches=branches, apply_access_policies_pg_default=( apply_access_policies_pg_default ), ), ) self.pgops.add(dbops.CreateRole(db_role)) return schema class AlterRole(MetaCommand, RoleMixin, adapts=s_roles.AlterRole): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().apply(schema, context) role = self.scls backend_params = self._get_backend_params(context) instance_params = backend_params.instance_params tenant_id = instance_params.tenant_id role_name = str(role.get_name(schema)) kwargs = {} update_metadata = False metadata = dict( id=str(role.id), name=role_name, tenant_id=tenant_id, builtin=role.get_builtin(schema), permissions=list(sorted(role.get_permissions(schema) or ())), branches=list(sorted(role.get_branches(schema) or ())), apply_access_policies_pg_default=( role.get_apply_access_policies_pg_default(schema) ), ) if self.has_attribute_value('password'): passwd = self.get_attribute_value('password') if backend_params.has_create_role: # Only modify Postgres password of roles managed by EdgeDB kwargs['password'] = passwd update_metadata = True metadata['password_hash'] = passwd elif old_passwd := role.get_password(schema): if backend_params.has_create_role: # Only modify Postgres password of roles managed by EdgeDB kwargs['password'] = old_passwd metadata['password_hash'] = old_passwd if ( self.has_attribute_value('permissions') or self.has_attribute_value('branches') or self.has_attribute_value('apply_access_policies_pg_default') ): 
update_metadata = True pg_role_name = common.get_role_backend_name( role_name, tenant_id=tenant_id) if self.has_attribute_value('superuser'): self.ensure_has_create_role(backend_params) membership = [str(x) for x in role.get_bases(schema).names(schema)] membership.append(edbdef.EDGEDB_SUPERGROUP) self.pgops.add( dbops.AlterRoleAddMembership( name=pg_role_name, membership=[ common.get_role_backend_name( parent_role, tenant_id=tenant_id) for parent_role in membership ], ) ) superuser_flag = False # If the cluster is not exposing an explicit superuser role, # we will make the modified Postgres role superuser if we can if not instance_params.base_superuser: superuser_flag = backend_params.has_superuser_access kwargs['superuser'] = superuser_flag if update_metadata: kwargs['metadata'] = metadata if backend_params.has_create_role: dbrole = dbops.Role(name=pg_role_name, **kwargs) else: dbrole = dbops.SingleRole(**kwargs) self.pgops.add(dbops.AlterRole(dbrole)) return schema class RebaseRole(MetaCommand, RoleMixin, adapts=s_roles.RebaseRole): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: backend_params = self._get_backend_params(context) self.ensure_has_create_role(backend_params) schema = super().apply(schema, context) role = self.scls tenant_id = self._get_tenant_id(context) for dropped in self.removed_bases: self.pgops.add(dbops.AlterRoleDropMember( name=common.get_role_backend_name( str(dropped.name), tenant_id=tenant_id), member=common.get_role_backend_name( str(role.get_name(schema)), tenant_id=tenant_id), )) for bases, _pos in self.added_bases: for added in bases: self.pgops.add(dbops.AlterRoleAddMember( name=common.get_role_backend_name( str(added.name), tenant_id=tenant_id), member=common.get_role_backend_name( str(role.get_name(schema)), tenant_id=tenant_id), )) return schema class DeleteRole(MetaCommand, RoleMixin, adapts=s_roles.DeleteRole): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> 
s_schema.Schema: backend_params = self._get_backend_params(context) self.ensure_has_create_role(backend_params) schema = super().apply(schema, context) tenant_id = self._get_tenant_id(context) self.pgops.add(dbops.DropRole( common.get_role_backend_name( str(self.classname), tenant_id=tenant_id))) return schema class CreateExtensionPackage( MetaCommand, adapts=s_exts.CreateExtensionPackage, ): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().apply(schema, context) ext_id = str(self.scls.id) name__internal = str(self.scls.get_name(schema)) name = self.scls.get_displayname(schema) version = self.scls.get_version(schema)._asdict() version['stage'] = version['stage'].name.lower() ext_module = self.scls.get_ext_module(schema) metadata = { ext_id: { 'id': ext_id, 'name': name, 'name__internal': name__internal, 'script': self.scls.get_script(schema), 'version': version, 'builtin': self.scls.get_builtin(schema), 'internal': self.scls.get_internal(schema), 'ext_module': ext_module and str(ext_module), 'sql_extensions': list(self.scls.get_sql_extensions(schema)), 'sql_setup_script': self.scls.get_sql_setup_script(schema), 'sql_teardown_script': ( self.scls.get_sql_teardown_script(schema) ), 'dependencies': list(self.scls.get_dependencies(schema)), } } ctx_backend_params = context.backend_runtime_params if ctx_backend_params is not None: backend_params = cast( params.BackendRuntimeParams, ctx_backend_params) else: backend_params = params.get_default_runtime_params() if backend_params.has_create_database: self.pgops.add( dbops.UpdateMetadataSection( dbops.DatabaseWithTenant(name=edbdef.EDGEDB_TEMPLATE_DB), section='ExtensionPackage', metadata=metadata ) ) else: self.pgops.add( dbops.UpdateSingleDBMetadataSection( edbdef.EDGEDB_TEMPLATE_DB, section='ExtensionPackage', metadata=metadata ) ) return schema class DeleteExtensionPackage( MetaCommand, adapts=s_exts.DeleteExtensionPackage, ): def apply( self, schema: 
s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().apply(schema, context) ctx_backend_params = context.backend_runtime_params if ctx_backend_params is not None: backend_params = cast( params.BackendRuntimeParams, ctx_backend_params) else: backend_params = params.get_default_runtime_params() ext_id = str(self.scls.id) metadata = { ext_id: None } if backend_params.has_create_database: self.pgops.add( dbops.UpdateMetadataSection( dbops.DatabaseWithTenant(name=edbdef.EDGEDB_TEMPLATE_DB), section='ExtensionPackage', metadata=metadata ) ) else: self.pgops.add( dbops.UpdateSingleDBMetadataSection( edbdef.EDGEDB_TEMPLATE_DB, section='ExtensionPackage', metadata=metadata ) ) return schema class CreateExtensionPackageMigration( MetaCommand, adapts=s_exts.CreateExtensionPackageMigration, ): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().apply(schema, context) ext_id = str(self.scls.id) name__internal = str(self.scls.get_name(schema)) name = self.scls.get_displayname(schema) from_version = self.scls.get_from_version(schema)._asdict() from_version['stage'] = from_version['stage'].name.lower() to_version = self.scls.get_to_version(schema)._asdict() to_version['stage'] = to_version['stage'].name.lower() metadata = { ext_id: { 'id': ext_id, 'name': name, 'name__internal': name__internal, 'script': self.scls.get_script(schema), 'from_version': from_version, 'to_version': to_version, 'builtin': self.scls.get_builtin(schema), 'internal': self.scls.get_internal(schema), 'sql_early_script': self.scls.get_sql_early_script(schema), 'sql_late_script': self.scls.get_sql_late_script(schema), } } ctx_backend_params = context.backend_runtime_params if ctx_backend_params is not None: backend_params = cast( params.BackendRuntimeParams, ctx_backend_params) else: backend_params = params.get_default_runtime_params() if backend_params.has_create_database: self.pgops.add( dbops.UpdateMetadataSection( 
dbops.DatabaseWithTenant(name=edbdef.EDGEDB_TEMPLATE_DB), section='ExtensionPackageMigration', metadata=metadata ) ) else: self.pgops.add( dbops.UpdateSingleDBMetadataSection( edbdef.EDGEDB_TEMPLATE_DB, section='ExtensionPackageMigration', metadata=metadata ) ) return schema class DeleteExtensionPackageMigration( MetaCommand, adapts=s_exts.DeleteExtensionPackageMigration, ): # XXX: 100% duplication with DeleteExtensionPackage def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().apply(schema, context) ctx_backend_params = context.backend_runtime_params if ctx_backend_params is not None: backend_params = cast( params.BackendRuntimeParams, ctx_backend_params) else: backend_params = params.get_default_runtime_params() ext_id = str(self.scls.id) metadata = { ext_id: None } if backend_params.has_create_database: self.pgops.add( dbops.UpdateMetadataSection( dbops.DatabaseWithTenant(name=edbdef.EDGEDB_TEMPLATE_DB), section='ExtensionPackageMigration', metadata=metadata ) ) else: self.pgops.add( dbops.UpdateSingleDBMetadataSection( edbdef.EDGEDB_TEMPLATE_DB, section='ExtensionPackageMigration', metadata=metadata ) ) return schema class ExtensionCommand(MetaCommand): def _compute_version(self, ext_spec: str) -> None: '''Emits a Query to compute the version. Dumps it in _dummy_text. ''' ext, vclauses = _parse_spec(ext_spec) # Dynamically select the highest version extension that matches # the provided version specification. 
        lclauses = []
        for op, ver in vclauses:
            pver = f"string_to_array({ql(ver)}, '.')::int8[]"
            assert op in {'=', '>', '>=', '<', '<='}
            lclauses.append(f'v.split {op} {pver}')
        cond = ' and '.join(lclauses) if lclauses else 'true'

        ver_regexp = r'^\d+(\.\d+)+$'
        qry = textwrap.dedent(f'''\
            with v as (
               select name, version,
               string_to_array(version, '.')::int8[] as split
               from pg_available_extension_versions
               where name = {ql(ext)} and version ~ '{ver_regexp}'
            )
            select edgedb_VER.raise_on_null(
              (
                 select v.version from v
                 where {cond}
                 order by split desc limit 1
              ),
              'feature_not_supported',
              msg => (
                'could not find extension satisfying '
                || {ql(ext_spec)}
                || ': ' ||
                coalesce(
                  'only found versions ' ||
                    (select string_agg(v.version, ', ' order by v.split)
                     from v),
                  'extension not found'
                )
              )
            )
            into _dummy_text;
        ''')
        self.pgops.add(dbops.Query(qry))

    def _create_extension(self, ext_spec: str) -> None:
        ext = _get_ext_name(ext_spec)
        self._compute_version(ext_spec)
        # XXX: hardcode to put stuff into edgedb schema
        # so that operations can be easily accessed.
        # N.B: this won't work on heroku; is that fine?
        target_schema = 'edgedb'

        self.pgops.add(dbops.Query(textwrap.dedent(f"""\
            EXECUTE
                'CREATE EXTENSION {ext} WITH SCHEMA {target_schema} VERSION '''
                || _dummy_text || ''''
        """)))


def _get_ext_name(spec: str) -> str:
    return spec.split(' ')[0]


def _parse_spec(spec: str) -> tuple[str, list[tuple[str, str]]]:
    if ' ' not in spec:
        return (spec, [])

    ext, versions = spec.split(' ', 1)
    clauses = versions.split(',')
    pclauses = []
    for clause in clauses:
        for i in range(len(clause)):
            if clause[i].isnumeric():
                break
        pclauses.append((clause[:i], clause[i:]))

    return ext, pclauses


class CreateExtension(ExtensionCommand, adapts=s_exts.CreateExtension):
    def _create_begin(
        self,
        schema: s_schema.Schema,
        context: sd.CommandContext,
    ) -> s_schema.Schema:
        schema = super()._create_begin(schema, context)

        # backend_params = self._get_backend_params(context)
        # ext_schema = backend_params.instance_params.ext_schema

        package = self.scls.get_package(schema)

        for ext_spec in package.get_sql_extensions(schema):
            self._create_extension(ext_spec)

        if script := package.get_sql_setup_script(schema):
            self.pgops.add(dbops.Query(script))

        return schema

    def _create_innards(
        self,
        schema: s_schema.Schema,
        context: sd.CommandContext,
    ) -> s_schema.Schema:
        schema = super()._create_innards(schema, context)

        if str(self.classname) == "ai":
            self.pgops.add(
                delta_ext_ai.pg_rebuild_all_pending_embeddings_views(
                    schema, context
                ),
            )

        return schema


class AlterExtension(ExtensionCommand, adapts=s_exts.AlterExtension):
    def _upgrade_extension(self, ext_spec: str) -> None:
        ext = _get_ext_name(ext_spec)
        self._compute_version(ext_spec)

        self.pgops.add(dbops.Query(textwrap.dedent(f"""\
            EXECUTE
                'ALTER EXTENSION {ext} UPDATE TO '''
                || _dummy_text || ''''
        """)))

    def _alter_begin(
        self,
        schema: s_schema.Schema,
        context: sd.CommandContext,
    ) -> s_schema.Schema:
        old_package = self.scls.get_package(schema)
        schema = super()._alter_begin(schema, context)

        if not self.migration:
            return schema

        new_package = self.scls.get_package(schema)
        old_exts = {
            _get_ext_name(spec)
            for spec in old_package.get_sql_extensions(schema)
        }
        new_exts = set()
        # XXX: be smarter!
        for ext_spec in new_package.get_sql_extensions(schema):
            ext = _get_ext_name(ext_spec)
            new_exts.add(ext)
            if ext in old_exts:
                self._upgrade_extension(ext_spec)
            else:
                self._create_extension(ext_spec)

        # # XXX??? should do this after
        # for ext in old_exts:
        #     if ext not in new_exts:
        #         self.pgops.add(
        #             dbops.DropExtension(
        #                 dbops.Extension(
        #                     name=ext,
        #                     schema=ext,
        #                 ),
        #             )
        #         )

        # TODO: UPDATE the sql extension? Or should we do that in the
        # script?
        if script := self.migration.get_sql_early_script(schema):
            self.pgops.add(dbops.Query(script))

        return schema

    def _alter_finalize(
        self,
        schema: s_schema.Schema,
        context: sd.CommandContext,
    ) -> s_schema.Schema:
        schema = super()._alter_finalize(schema, context)

        if (
            self.migration
            and (script := self.migration.get_sql_late_script(schema))
        ):
            self.pgops.add(dbops.Query(script))

        return schema


class DeleteExtension(ExtensionCommand, adapts=s_exts.DeleteExtension):
    def apply(
        self,
        schema: s_schema.Schema,
        context: sd.CommandContext,
    ) -> s_schema.Schema:
        extension = schema.get_global(s_exts.Extension, self.classname)
        package = extension.get_package(schema)

        if str(self.classname) == "ai":
            self.pgops.add(
                delta_ext_ai.pg_drop_all_pending_embeddings_views(schema),
            )

        schema = super().apply(schema, context)

        if script := package.get_sql_teardown_script(schema):
            self.pgops.add(dbops.Query(script))

        for ext_spec in package.get_sql_extensions(schema):
            ext = _get_ext_name(ext_spec)
            self.pgops.add(
                dbops.DropExtension(
                    dbops.Extension(
                        name=ext,
                        schema=ext,
                    ),
                )
            )

        return schema


class FutureBehaviorCommand(MetaCommand, s_futures.FutureBehaviorCommand):
    def apply(
        self,
        schema: s_schema.Schema,
        context: sd.CommandContext,
    ) -> s_schema.Schema:
        schema = super().apply(schema, context)

        if self.future_cmd:
            self.pgops.add(self.future_cmd)

        return schema


class CreateFutureBehavior(
        FutureBehaviorCommand,
        adapts=s_futures.CreateFutureBehavior):
    pass


class DeleteFutureBehavior(
        FutureBehaviorCommand,
        adapts=s_futures.DeleteFutureBehavior):
    pass


class AlterFutureBehavior(
        FutureBehaviorCommand,
        adapts=s_futures.AlterFutureBehavior):
    pass


class DeltaRoot(MetaCommand, adapts=sd.DeltaRoot):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.config_ops = []

    def apply(
        self,
        schema: s_schema.Schema,
        context: sd.CommandContext,
    ) -> s_schema.Schema:
        self.update_endpoint_delete_actions = UpdateEndpointDeleteActions()
        self.create_trampolines = CreateTrampolines()

        schema = super().apply(schema, context)

        self.update_endpoint_delete_actions.apply(schema, context)
        self.pgops.add(self.update_endpoint_delete_actions)

        self.create_trampolines.apply(schema, context)

        return schema


class MigrationCommand(MetaCommand):
    def apply(
        self,
        schema: s_schema.Schema,
        context: sd.CommandContext,
    ) -> s_schema.Schema:
        schema = super().apply(schema, context)

        if last_mig := schema.get_last_migration():
            last_mig_name = last_mig.get_name(schema).name
        else:
            last_mig_name = None

        self.pgops.add(dbops.UpdateMetadata(
            dbops.CurrentDatabase(),
            {'last_migration': last_mig_name},
        ))

        return schema


class CreateMigration(
    MigrationCommand,
    adapts=s_migrations.CreateMigration,
):
    pass


class AlterMigration(
    MigrationCommand,
    adapts=s_migrations.AlterMigration,
):
    pass


class DeleteMigration(
    MigrationCommand,
    adapts=s_migrations.DeleteMigration,
):
    pass

================================================
FILE: edb/pgsql/delta_ext_ai.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

#
# Backend support for ext::ai::index
#
# The index adds the following hidden attribute to the object type relation:
#
#     __ext_ai_{idx_id}_embedding__ vector()
#
# The data in the attribute is populated by an external indexing process,
# hence ext::ai::index is currently always deferred.  If a given object
# record is not yet indexed, the attribute value is NULL and the entry
# is picked up by the work queue view (see below).
#
# To invalidate embeddings when data referenced in the index expression
# changes, a simple trigger is also added, which resets the value of the
# embedding attribute back to NULL.
#
# The index is currently always deferred, hence the unindexed data
# needs to be exposed conveniently to an external indexer.
# We do this here by creating the following internal SQL views:
#
# Enumeration of embedding models currently used in ext::ai::index()
# declarations in the current schema:
#
#    CREATE VIEW edgedbext.ai_active_embedding_models(
#        id,       -- generated unique id as int64 (could be used for locking)
#        name,     -- model name as specified in the ext::ai::model_name
#                  -- annotation
#        provider, -- provider name as specified in the
#                  -- ext::ai::provider_name annotation
#    )
#
# For each active model in the above view, the following view is also
# generated:
#
#    CREATE VIEW edgedbext."ai_pending_embeddings_{model_name}"(
#        id,           -- Object ID
#        text,         -- Indexed text document (result of index expr eval)
#        target_rel,   -- SQL relation containing the embedding data
#        target_attr,  -- Column in the above relation containing embedding
#                      -- data
#        target_dims_shortening  -- If the embedding model produces more
#                                -- dimensions than the underlying index can
#                                -- handle, this is the maximum number of
#                                -- dimensions supported by the index.  The
#                                -- embedding model must support vector
#                                -- shortening (e.g. OpenAI
#                                -- text-embedding-3-* models).
#    )
#
# The above view is a UNION of SELECTs over object relations, where each
# UNION element is roughly this:
#
#    SELECT (
#        Object.id,
#        eval(get_index_expr(Object, 'ext::ai::index')),
#    )
#    WHERE
#        eval(get_index_except_expr(Object, 'ext::ai::index')) IS NOT TRUE
#        AND Object.__ext_ai_{idx_id}_embedding__ IS NULL
#

from __future__ import annotations
from typing import (
    cast,
    Optional,
)

import collections
import dataclasses
import hashlib
import struct
import textwrap

from edb.schema import expr as s_expr
from edb.schema import indexes as s_indexes
from edb.schema import types as s_types
from edb.schema import schema as s_schema
from edb.schema import delta as sd
from edb.schema import name as sn
from edb.schema import properties as s_props

from edb.ir import ast as irast

from edb.edgeql import ast as qlast
from edb.edgeql import compiler as qlcompiler

from . import codegen
from . import common
from . import deltadbops
from . import dbops
from . import compiler
from . import types
from . import ast as pgast
from .common import qname as q
from .common import quote_literal as ql
from .common import quote_ident as qi
from .compiler import astutils
from .compiler import enums as pgce


ai_index_base_name = sn.QualName("ext::ai", "index")


def get_ext_ai_pre_restore_script(
    schema: s_schema.Schema,
) -> str:
    # We helpfully populate ext::ai::ChatPrompt with a starter prompt
    # in the extension setup script.
    # Unfortunately, this means that before user data is restored, we need
    # to delete those objects, or there will be a constraint error.
    return '''
        delete {ext::ai::ChatPrompt, ext::ai::ChatPromptMessage}
    '''


def create_ext_ai_index(
    index: s_indexes.Index,
    predicate_src: Optional[str],
    sql_kwarg_exprs: dict[str, str],
    options: qlcompiler.CompilerOptions,
    schema: s_schema.Schema,
    context: sd.CommandContext,
) -> dbops.Command:
    subject = index.get_subject(schema)
    assert isinstance(subject, s_indexes.IndexableSubject)

    effective, has_overridden = s_indexes.get_effective_object_index(
        schema, subject, ai_index_base_name
    )

    if index != effective:
        return dbops.CommandGroup()

    # When creating an index on a child that already has an ext::ai index
    # inherited from the parent, we don't need to create the index, but just
    # update the populating expressions.
if has_overridden: return _refresh_ai_embeddings( index, has_overridden[0], options, schema, context, ) else: return _create_ai_embeddings( index, predicate_src, sql_kwarg_exprs, options, schema, context, ) def delete_ext_ai_index( index: s_indexes.Index, drop_index: dbops.Command, options: qlcompiler.CompilerOptions, schema: s_schema.Schema, orig_schema: s_schema.Schema, context: sd.CommandContext, ) -> tuple[dbops.Command, dbops.Command]: subject = index.get_subject(orig_schema) assert isinstance(subject, s_indexes.IndexableSubject) effective, _ = s_indexes.get_effective_object_index( schema, subject, ai_index_base_name ) if not effective: return _delete_ai_embeddings( index, drop_index, schema, orig_schema, context) else: # effective index remains: don't drop the embeddings return ( dbops.CommandGroup(), _refresh_ai_embeddings( effective, index, options, orig_schema, context, ), ) def _compile_ai_embeddings_source_view_expr( index: s_indexes.Index, options: qlcompiler.CompilerOptions, schema: s_schema.Schema, ) -> pgast.SelectStmt: # Compile a view returning a set of (id, text-to-embed) tuples # roughly as the following pseudo-QL # # FOR obj in Object UNION ( # SELECT ( # obj.id, # eval(get_index_expr(obj, 'ext::ai::index')), # ) # WHERE # eval(get_index_except_expr(obj, 'ext::ai::index')) IS NOT TRUE # AND obj.embedding_column IS NULL # ) index_sexpr: Optional[s_expr.Expression] = index.get_expr(schema) assert index_sexpr ql_iter_alias = "__ext_ai_index_iter__" ql_index_sexpr_name = "__ext_ai_index_sexpr__" ql = qlast.ForQuery( iterator_alias=ql_iter_alias, iterator=qlast.Shape( expr=qlast.Path( steps=[qlast.IRAnchor(name='__subject__')], ), elements=[ qlast.ShapeElement( expr=qlast.Path( steps=[qlast.Ptr(name="id")], ), ), qlast.ShapeElement( expr=qlast.Path( steps=[qlast.Ptr(name=ql_index_sexpr_name)], ), compexpr=index_sexpr.parse(), ), ] ), result=qlast.Tuple( elements=[ qlast.Path( steps=[ qlast.ObjectRef(name=ql_iter_alias), qlast.Ptr(name="id"), ], ), 
qlast.Path( steps=[ qlast.ObjectRef(name=ql_iter_alias), qlast.Ptr(name=ql_index_sexpr_name), ], ), ], ), ) my_options = dataclasses.replace(options, singletons=frozenset()) ir = qlcompiler.compile_ast_to_ir( ql, schema=schema, options=my_options, ) assert isinstance(ir, irast.Statement) subject = index.get_subject(schema) assert isinstance(subject, s_types.Type) subject_id = irast.PathId.from_type(schema, subject, env=None) idx_id = _get_index_root_id(schema, index) table_name = common.get_index_table_backend_name(index, schema) aspects = ( pgce.PathAspect.IDENTITY, pgce.PathAspect.VALUE, pgce.PathAspect.SOURCE ) qry = compiler.new_external_rvar_as_subquery( rel_name=table_name, path_id=subject_id, aspects=aspects, ) qry.where_clause = astutils.extend_binop( qry.where_clause, pgast.NullTest( arg=pgast.ColumnRef( name=(f"__ext_ai_{idx_id}_embedding__",), ), ) ) except_expr = index.get_except_expr(schema) if except_expr: except_expr = except_expr.ensure_compiled( schema=schema, options=options, context=None, ) assert except_expr.irast except_res = compiler.compile_ir_to_sql_tree( except_expr.irast.expr, singleton_mode=True) assert isinstance(except_res.ast, pgast.BaseExpr) qry.where_clause = astutils.extend_binop( qry.where_clause, pgast.Expr( lexpr=except_res.ast, name="IS NOT", rexpr=pgast.BooleanConstant(val=True), ), ) sql_res = compiler.compile_ir_to_sql_tree( ir, output_format=compiler.OutputFormat.NATIVE_INTERNAL, external_rels={ subject_id: (qry, aspects), }, ) expr = sql_res.ast assert isinstance(expr, pgast.SelectStmt) return expr def _create_ai_embeddings( index: s_indexes.Index, predicate_src: Optional[str], sql_kwarg_exprs: dict[str, str], options: qlcompiler.CompilerOptions, schema: s_schema.Schema, context: sd.CommandContext, ) -> dbops.Command: return _pg_create_ai_embeddings( index, options, predicate_src, sql_kwarg_exprs, schema, context, ) def _refresh_ai_embeddings( index: s_indexes.Index, old_index: s_indexes.Index, options: 
qlcompiler.CompilerOptions, schema: s_schema.Schema, context: sd.CommandContext, ) -> dbops.Command: ops = dbops.CommandGroup() table_name = common.get_index_table_backend_name(index, schema) ops.add_command( _pg_drop_trigger(index, table_name, schema)) idx_id = _get_index_root_id(schema, index) ops.add_command(dbops.Query(textwrap.dedent(f"""\ UPDATE {common.qname(*table_name)} SET __ext_ai_{idx_id}_embedding__ = NULL WHERE __ext_ai_{idx_id}_embedding__ IS NOT NULL """))) ops.add_command( _pg_create_trigger(index, table_name, schema)) ops.add_command( _pg_create_ai_embeddings_source_view( index, options, schema, context)) # Sigh, we need to rename the main index to match the new id, # entirely for the purpose of having ANALYZE be able to pick it up dimensions = index.must_get_json_annotation( schema, sn.QualName("ext::ai", "embedding_dimensions"), int, ) ops.add_command( deltadbops.rename_pg_index( old_index=old_index, new_index=index, schema=schema, aspect=f'{dimensions}_index', ) ) return ops def _delete_ai_embeddings( index: s_indexes.Index, drop_index: dbops.Command, schema: s_schema.Schema, orig_schema: s_schema.Schema, context: sd.CommandContext, ) -> tuple[dbops.Command, dbops.Command]: return _pg_delete_ai_embeddings( index, drop_index, schema, orig_schema, context ) # --- pgvector --- def _pg_create_ai_embeddings( index: s_indexes.Index, options: qlcompiler.CompilerOptions, predicate_src: Optional[str], sql_kwarg_exprs: dict[str, str], schema: s_schema.Schema, context: sd.CommandContext, ) -> dbops.Command: # Create: # * the "__ext_ai_{idx_id}_embedding__" vector attribute; # * pgvector index on the above; # * the embedding attribute invalidation trigger # * a component view for the "ai_pending_embeddings_{model_name}" union ops = dbops.CommandGroup() table_name = common.get_index_table_backend_name(index, schema) with_clause = {} kwargs = index.get_concrete_kwargs(schema) index_params_expr = kwargs.get("index_parameters") if index_params_expr is not 
None: index_params = index_params_expr.assert_compiled().as_python_value() with_clause["m"] = index_params["m"] with_clause["ef_construction"] = index_params["ef_construction"] dimensions = index.must_get_json_annotation( schema, sn.QualName("ext::ai", "embedding_dimensions"), int, ) idx_id = _get_index_root_id(schema, index) alter_table = dbops.AlterTable(table_name) # The attribute alter_table.add_operation( dbops.AlterTableAddColumn( dbops.Column( name=f'__ext_ai_{idx_id}_embedding__', type=f'edgedb.vector({dimensions})', required=False, ) ) ) ops.add_command(alter_table) # Also create a constant partial index on outdated entries # so that we use an index scan and not a seq scan when # picking out pending embeddings. outdated_idx_name = common.get_index_table_backend_name( index, schema, aspect="extaiselidx") ops.add_command( dbops.CreateIndex( dbops.Index( name=outdated_idx_name[1], table_name=table_name, exprs=["(1)"], predicate=( f'__ext_ai_{idx_id}_embedding__ IS NULL'), unique=False, metadata={ 'code': '(__col__)', }, ), ), ) df_expr = kwargs.get("distance_function") if df_expr is not None: df = df_expr.assert_compiled().as_python_value() else: df = "Cosine" match df: case "Cosine": opclass = "vector_cosine_ops" case "InnerProduct": opclass = "vector_ip_ops" case "L2": opclass = "vector_l2_ops" case _: raise RuntimeError(f"unsupported distance_function: {df}") # The main similarity (a.k.a distance) search index. 
module_name = index.get_name(schema).module index_name = common.get_index_backend_name( index.id, module_name, catenate=False, aspect=f'{dimensions}_index' ) pg_index = dbops.Index( name=index_name[1], table_name=table_name, exprs=[f"__ext_ai_{idx_id}_embedding__"], with_clause=with_clause, unique=False, predicate=predicate_src, metadata={ 'schemaname': str(index.get_name(schema)), 'kwargs': sql_kwarg_exprs, 'code': f'hnsw (__col__ {opclass})', 'dimensions': str(dimensions), 'distance_function': str(df), }, ) ops.add_command(dbops.CreateIndex(pg_index)) # The invalidation trigger ops.add_command(_pg_create_trigger(index, table_name, schema)) # The component view for the "ai_pending_embeddings_{model_name}" union ops.add_command( _pg_create_ai_embeddings_source_view(index, options, schema, context)) return ops def _get_dep_cols( index: s_indexes.Index, schema: s_schema.Schema, ) -> list[str]: index_expr = index.get_expr(schema) assert index_expr is not None dep_cols = [] assert index_expr.refs is not None for obj in index_expr.refs.objects(schema): if ( isinstance(obj, s_props.Property) # Exclude computed pointers, they don't actually have columns and obj.get_expr(schema) is None ): ptrinfo = types.get_pointer_storage_info(obj, schema=schema) dep_cols.append(ptrinfo.column_name) return dep_cols def _pg_delete_ai_embeddings( index: s_indexes.Index, drop_index: dbops.Command, schema: s_schema.Schema, orig_schema: s_schema.Schema, context: sd.CommandContext, ) -> tuple[dbops.Command, dbops.Command]: table_name = common.get_index_table_backend_name(index, orig_schema) idx_id = _get_index_root_id(orig_schema, index) table_ops = dbops.CommandGroup() ops = dbops.CommandGroup() ops.add_command(drop_index) # Drop the invalidation trigger ops.add_command(_pg_drop_trigger(index, table_name, orig_schema)) # Drop component view for the "ai_pending_embeddings_{model_name}" union ops.add_command(_pg_drop_ai_embeddings_source_view( index, schema, orig_schema, context)) # When the 
ObjectType is being deleted, we don't drop the index, # as it will get dropped with the parent table. # The same goes for __ext_ai_{idx_id}_embedding__. source_drop = isinstance(drop_index, dbops.NoOpCommand) if not source_drop: table_name = common.get_index_table_backend_name(index, orig_schema) alter_table = dbops.AlterTable(table_name) alter_table.add_operation( dbops.AlterTableDropColumn( dbops.Column( name=f'__ext_ai_{idx_id}_embedding__', # This isn't actually needed to do the drop, and # it saves us needing to get the dimensions. # (Which is good, because they are missing when we # try to schema_repair to fix #9033.) type='XXX UNUSED', ) ) ) table_ops.add_command(alter_table) return ops, table_ops def _pg_create_trigger( index: s_indexes.Index, table_name: tuple[str, str], schema: s_schema.Schema, ) -> dbops.Command: dep_cols = _get_dep_cols(index, schema) # Create a trigger that resets the __ext_ai_{idx_id}_embedding__ to # NULL whenever data referenced in the ext::ai::index expression gets # modified (TODO: the selective approach could also be used on # std::fts::index) ops = dbops.CommandGroup() idx_id = _get_index_root_id(schema, index) # create update function func_name = _pg_update_func_name(table_name, idx_id) function = dbops.Function( name=func_name, text=f""" BEGIN NEW."__ext_ai_{idx_id}_embedding__" := NULL; RETURN NEW; END; """, volatility='immutable', returns='trigger', language='plpgsql', ) ops.add_command(dbops.CreateFunction(function)) conditions = [] for dep_col in dep_cols: dep_col = qi(dep_col) conditions.append(f'OLD.{dep_col} IS DISTINCT FROM NEW.{dep_col}') trigger_name = _pg_trigger_name(table_name[1], idx_id) trigger = dbops.Trigger( name=trigger_name, table_name=table_name, events=('update',), timing=dbops.TriggerTiming.Before, procedure=func_name, condition=' OR '.join(conditions), ) ops.add_command(dbops.CreateTrigger(trigger)) return ops def _pg_drop_trigger( index: s_indexes.Index, table_name: tuple[str, str], schema: 
s_schema.Schema, override_id: Optional[str] = None, ) -> dbops.Command: idx_id = override_id or _get_index_root_id(schema, index) ops = dbops.CommandGroup() ops.add_command( dbops.DropTrigger( dbops.Trigger( _pg_trigger_name(table_name[1], idx_id), table_name=table_name, events=(), procedure='', ) ) ) ops.add_command( dbops.DropFunction( _pg_update_func_name(table_name, idx_id), (), ) ) return ops def pg_rebuild_all_pending_embeddings_views( schema: s_schema.Schema, context: sd.CommandContext, ) -> dbops.Command: ops = dbops.CommandGroup() def flt(schema: s_schema.Schema, index: s_indexes.Index) -> bool: return ( index.get_subject(schema) is not None and s_indexes.is_ext_ai_index(schema, index) ) all_ai_indexes = schema.get_objects( type=s_indexes.Index, extra_filters=(flt,), ) all_models = s_indexes.get_defined_ext_ai_embedding_models(schema) used_models = collections.defaultdict(list) for other_index in all_ai_indexes: if context.is_deleting(other_index): continue tabname = common.get_index_table_backend_name( other_index, schema, aspect="extaiview") model_name = other_index.must_get_annotation( schema, sn.QualName("ext::ai", "model_name")) used_models[model_name].append(f"SELECT * FROM {q(*tabname)}") model_providers = {} for model_name, model_stype in all_models.items(): views = used_models.get(model_name) if views: query = " UNION ALL ".join(views) else: query = textwrap.dedent("""\ SELECT NULL::uuid AS "id", NULL::text AS "text", NULL::text AS "target_rel", NULL::text AS "target_attr", NULL::int AS "target_dims_shortening", NULL::boolean AS "truncate_to_max" WHERE FALSE """) view = dbops.View( name=( "edgedbext", common.edgedb_name_to_pg_name( f"ai_pending_embeddings_{model_name}" ), ), query=query, ) ops.add_command(dbops.CreateView(view, or_replace=True)) provider = model_stype.must_get_annotation( schema, sn.QualName("ext::ai", "model_provider")) model_providers[model_name] = provider if used_models: bits = [] for model_name in used_models: mnhash = 
hashlib.blake2b(model_name.encode("utf-8"), digest_size=8) model_id: int = struct.unpack("q", mnhash.digest())[0] provider = model_providers[model_name] bits.append(textwrap.dedent(f"""\ SELECT {model_id}::bigint AS id, {ql(model_name)} AS name, {ql(provider)} AS provider """)) used_sql = " UNION ALL ".join(bits) else: used_sql = textwrap.dedent("""\ SELECT NULL::bigint AS id, NULL::text AS name, NULL::text AS provider WHERE FALSE """) ops.add_command(dbops.CreateView( view=dbops.View( name=("edgedbext", "ai_active_embedding_models"), query=used_sql, ), or_replace=True, )) return ops def pg_drop_all_pending_embeddings_views( schema: s_schema.Schema, ) -> dbops.Command: ops = dbops.CommandGroup() all_models = s_indexes.get_defined_ext_ai_embedding_models(schema) for model_name in all_models: view_name = ( "edgedbext", common.edgedb_name_to_pg_name( f"ai_pending_embeddings_{model_name}" ), ) ops.add_command(dbops.DropView(view_name, conditional=True)) ops.add_command( dbops.DropView(("edgedbext", "ai_active_embedding_models"))) return ops def _pg_create_ai_embeddings_source_view( index: s_indexes.Index, options: qlcompiler.CompilerOptions, schema: s_schema.Schema, context: sd.CommandContext, *, rebuild_all: bool=True, ) -> dbops.Command: ops = dbops.CommandGroup() expr = _compile_ai_embeddings_source_view_expr(index, options, schema) view_name = common.get_index_table_backend_name( index, schema, aspect="extaiview") idx_id = _get_index_root_id(schema, index) target_col = f"__ext_ai_{idx_id}_embedding__" index_dimensions = index.must_get_json_annotation( schema, sn.QualName("ext::ai", "embedding_dimensions"), int, ) model_dimensions = index.must_get_json_annotation( schema, sn.QualName("ext::ai", "embedding_model_max_output_dimensions"), int, ) if index_dimensions < model_dimensions: target_dims_shortening = str(index_dimensions) else: target_dims_shortening = "NULL" kwargs = index.get_concrete_kwargs(schema) truncate_to_max_arg = kwargs.get("truncate_to_max") if 
truncate_to_max_arg is not None: truncate_to_max = cast( bool, truncate_to_max_arg.assert_compiled().as_python_value() ) else: truncate_to_max = False table_name = common.get_index_table_backend_name(index, schema) expr_sql = codegen.generate_source(expr) document_sql = textwrap.dedent(f"""\ SELECT (q.val).f1 AS "id", (q.val).f2 AS "text", {ql(q(*table_name))} AS "target_rel", {ql(qi(target_col))} AS "target_attr", {target_dims_shortening}::int AS "target_dims_shortening", {truncate_to_max}::boolean AS "truncate_to_max" FROM ({expr_sql}) AS q(val) """) view = dbops.View(name=view_name, query=document_sql) ops.add_command(dbops.CreateView(view, or_replace=True)) if rebuild_all: ops.add_command( pg_rebuild_all_pending_embeddings_views(schema, context)) return ops def _pg_drop_ai_embeddings_source_view( index: s_indexes.Index, schema: s_schema.Schema, orig_schema: s_schema.Schema, context: sd.CommandContext, ) -> dbops.Command: ops = dbops.CommandGroup() ops.add_command(pg_rebuild_all_pending_embeddings_views( schema, context )) view_name = common.get_index_table_backend_name( index, orig_schema, aspect="extaiview") ops.add_command(dbops.DropView(view_name)) return ops def _pg_update_func_name( tbl_name: tuple[str, str], idx_id: str, ) -> tuple[str, ...]: return ( tbl_name[0], common.edgedb_name_to_pg_name(tbl_name[1] + f'_extai_{idx_id}_upd'), ) def _pg_trigger_name( tbl_name: str, idx_id: str, ) -> str: return common.edgedb_name_to_pg_name(tbl_name + f'_extai_{idx_id}_trg') def _get_index_root_id( schema: s_schema.Schema, index: s_indexes.Index, ) -> str: return s_indexes.get_ai_index_id(schema, index) ================================================ FILE: edb/pgsql/deltadbops.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Abstractions for low-level database DDL and DML operations.""" from __future__ import annotations from typing import Optional import itertools from edb.common import adapter from edb.schema import indexes as s_indexes from edb.schema import objects as s_obj from edb.schema import schema as s_schema from edb.pgsql import common from edb.pgsql import dbops from edb.pgsql import schemamech class SchemaDBObjectMeta(adapter.Adapter): # type: ignore def __init__(cls, name, bases, dct, *, adapts=None): adapter.Adapter.__init__(cls, name, bases, dct, adapts=adapts) type(s_obj.Object).__init__(cls, name, bases, dct) class SchemaDBObject(metaclass=SchemaDBObjectMeta): @classmethod def adapt(cls, obj): return cls.copy(obj) class ConstraintCommon: def __init__(self, constraint, schema): self._constr_id = constraint.id self._schema_constr_name = constraint.get_name(schema) self._schema_constr_is_delegated = constraint.get_delegated(schema) self._schema = schema self._constraint = constraint def constraint_name(self, quote=True): name = self.raw_constraint_name() name = common.edgedb_name_to_pg_name(name) return common.quote_ident(name) if quote else name def schema_constraint_name(self): return self._schema_constr_name def raw_constraint_name(self): return common.get_constraint_raw_name(self._constr_id) def generate_extra(self, block: dbops.PLBlock) -> None: text = self.raw_constraint_name() cmd = dbops.Comment(object=self, text=text) 
cmd.generate(block) @property def delegated(self): return self._schema_constr_is_delegated class SchemaConstraintDomainConstraint( ConstraintCommon, dbops.DomainConstraint ): def __init__(self, domain_name, constraint, exprdata, schema): ConstraintCommon.__init__(self, constraint, schema) dbops.DomainConstraint.__init__(self, domain_name) self._exprdata = exprdata def constraint_code(self, block: dbops.PLBlock) -> str: if len(self._exprdata) == 1: expr = self._exprdata[0].exprdata.plain else: exprs = [e.plain for e in self._exprdata.exprdata] expr = '(' + ') AND ('.join(exprs) + ')' return f'CHECK ({expr})' def __repr__(self): return '<{}.{} {!r} {!r}>'.format( self.__class__.__module__, self.__class__.__name__, self.domain_name, self._constraint) class SchemaConstraintTableConstraint(ConstraintCommon, dbops.TableConstraint): def __init__( self, table_name, *, constraint, exprdata: list[schemamech.ExprData], relative_exprdata: list[schemamech.ExprData], scope, type, table_type, except_data, schema, ): ConstraintCommon.__init__(self, constraint, schema) dbops.TableConstraint.__init__(self, table_name, None) self._exprdata = exprdata self._relative_exprdata = relative_exprdata self._scope = scope self._type = type self._table_type = table_type self._except_data = except_data def is_non_row_and_identical( self, other: SchemaConstraintTableConstraint ) -> bool: """Constraints which only contain references to columns and have no expressions or except clause are considered trivial. Such constraints can be checked for equivalence without actually generating their code. This function should be updated if `constraint_code` changes. 
""" return ( # Constraint is on the same subject self._subject_name == other._subject_name # Row scope constraints treated differently, see `constraint_code` and self._scope != 'row' and other._scope != 'row' # Expr data is identical and len(self._exprdata) == len(other._exprdata) and all( exprdata.is_trivial == other_exprdata.is_trivial and ( list(exprdata.exprdata.plain_chunks) == list(other_exprdata.exprdata.plain_chunks) ) and ( list(exprdata.exprdata.plain_chunks) == list(other_exprdata.exprdata.plain_chunks) ) for exprdata, other_exprdata in zip( self._exprdata, other._exprdata ) ) # Except data is identical and self._except_data == other._except_data ) def constraint_code(self, block: dbops.PLBlock) -> str | list[str]: if self._scope == 'row': if len(self._exprdata) == 1: expr = self._exprdata[0].exprdata.plain else: exprs = [e.exprdata.plain for e in self._exprdata] expr = '(' + ') AND ('.join(exprs) + ')' if self._except_data: cond = self._except_data.plain expr = f'({expr}) OR ({cond}) is true' return f'CHECK ({expr})' else: if self._type != 'unique': raise ValueError( 'unexpected constraint type: {}'.format(self._type)) constr_exprs = [] for exprdata in self._exprdata: if exprdata.is_trivial and not self._except_data: # A constraint that contains one or more # references to columns, and no expressions. # # Update `is_non_row_and_identical` if this ever changes! expr = ', '.join(exprdata.exprdata.plain_chunks) expr = 'UNIQUE ({})'.format(expr) else: # Complex constraint with arbitrary expressions # needs to use EXCLUDE. 
# chunks = exprdata.exprdata.plain_chunks expr = ', '.join( "{} WITH =".format(chunk) for chunk in chunks) expr = f'EXCLUDE ({expr})' if self._except_data: cond = self._except_data.plain expr = f'{expr} WHERE (({cond}) is not true)' constr_exprs.append(expr) return constr_exprs def numbered_constraint_name(self, i, quote=True): raw_name = self.raw_constraint_name() name = common.edgedb_name_to_pg_name('{}#{}'.format(raw_name, i)) return common.quote_ident(name) if quote else name def get_trigger_procname(self): return common.get_backend_name( self._schema, self._constraint, catenate=False, aspect='trigproc') def get_trigger_condition(self): chunks = [] for expr in self._exprdata: condition = '{old_expr} IS DISTINCT FROM {new_expr}'.format( old_expr=expr.exprdata.old, new_expr=expr.exprdata.new ) chunks.append(condition) if len(chunks) == 1: return chunks[0] else: return '(' + ') OR ('.join(chunks) + ')' def get_trigger_proc_text(self): chunks = [] constr_name = self.constraint_name() raw_constr_name = self.constraint_name(quote=False) errmsg = 'duplicate key value violates unique ' \ 'constraint {constr}'.format(constr=constr_name) for expr, relative_expr in zip( itertools.cycle(self._exprdata), self._relative_exprdata ): exprdata = expr.exprdata relative_exprdata = relative_expr.exprdata except_data = self._except_data relative_except_data = relative_expr.except_data if self._except_data: except_part = f''' AND ({relative_except_data.plain} is not true) AND ({except_data.new} is not true) ''' else: except_part = '' # Link tables get updated by deleting and then reinserting # rows, and so the trigger might fire even on rows that # did not *really* change. Check `source` also to prevent # spurious errors in those cases. (Anything with the same # source must have the same type, so any genuine constraint # errors this filters away will get caught by the *actual* # constraint.) 
# We *could* do a check for id on object tables, but it # isn't needed and would take at least some time. src_check = ( ' AND source != NEW.source' if self._table_type == 'link' else '' ) schemaname, tablename = relative_expr.subject_db_name text = ''' PERFORM TRUE FROM {table} WHERE {plain_expr} = {new_expr}{except_part}{src_check}; IF FOUND THEN RAISE unique_violation USING TABLE = '{tablename}', SCHEMA = '{schemaname}', CONSTRAINT = '{constr}', MESSAGE = '{errmsg}', DETAIL = {detail}; END IF; '''.format( table=common.qname(schemaname, tablename), plain_expr=relative_exprdata.plain, new_expr=exprdata.new, except_part=except_part, src_check=src_check, schemaname=schemaname, tablename=tablename, constr=raw_constr_name, errmsg=errmsg, detail=common.quote_literal( f"Key ({relative_exprdata.plain}) already exists." ), ) chunks.append(text) text = 'BEGIN\n' + '\n\n'.join(chunks) + '\nRETURN NEW;\nEND;' return text def requires_triggers(self): return schemamech.table_constraint_requires_triggers( self._constraint, self._schema, self._type, ) def can_disable_triggers(self): return self._constraint.is_independent(self._schema) def __repr__(self): return '<{}.{} {!r} at 0x{:x}>'.format( self.__class__.__module__, self.__class__.__name__, self.schema_constraint_name(), id(self)) class MultiConstraintItem: def __init__( self, constraint: SchemaConstraintTableConstraint, index: int, ): self.constraint = constraint self.index = index def get_type(self): return self.constraint.get_type() def get_id(self): # XXX name = self.constraint.numbered_constraint_name(self.index) return '{} ON {} {}'.format( name, self.constraint.get_subject_type(), self.constraint.get_subject_name()) class AlterTableAddMultiConstraint( dbops.AlterTableAddConstraint[SchemaConstraintTableConstraint] ): def code_with_block(self, block: dbops.PLBlock) -> str: exprs = self.constraint.constraint_code(block) if isinstance(exprs, list) and len(exprs) > 1: chunks = [] for i, expr in enumerate(exprs): name = 
self.constraint.numbered_constraint_name(i) chunk = f'ADD CONSTRAINT {name} {expr}' chunks.append(chunk) code = ', '.join(chunks) else: if isinstance(exprs, list): exprs = exprs[0] name = self.constraint.constraint_name() code = f'ADD CONSTRAINT {name} {exprs}' return code def generate_extra_composite( self, block: dbops.PLBlock, group: dbops.CompositeCommandGroup ) -> None: comments = [] exprs = self.constraint.constraint_code(block) if isinstance(exprs, list) and len(exprs) > 1: assert isinstance(self.constraint, SchemaConstraintTableConstraint) for i, _expr in enumerate(exprs): name = self.constraint.numbered_constraint_name(i) constraint = MultiConstraintItem(self.constraint, i) comment = dbops.Comment(constraint, name) comments.append(comment) else: name = self.constraint.constraint_name() comment = dbops.Comment(self.constraint, name) comments.append(comment) for comment in comments: comment.generate(block) class AlterTableDropMultiConstraint(dbops.AlterTableDropConstraint): def code_with_block(self, block: dbops.PLBlock) -> str: exprs = self.constraint.constraint_code(block) if isinstance(exprs, list) and len(exprs) > 1: chunks = [] for i, _expr in enumerate(exprs): name = self.constraint.numbered_constraint_name(i) chunk = f'DROP CONSTRAINT {name}' chunks.append(chunk) code = ', '.join(chunks) else: name = self.constraint.constraint_name() code = f'DROP CONSTRAINT {name}' return code class AlterTableConstraintBase(dbops.AlterTableBaseMixin, dbops.CommandGroup): def __init__( self, name: tuple[str, ...], *, constraint: SchemaConstraintTableConstraint, contained: bool = False, conditions: Optional[set[str | dbops.Condition]] = None, neg_conditions: Optional[set[str | dbops.Condition]] = None, ): dbops.CommandGroup.__init__( self, conditions=conditions, neg_conditions=neg_conditions ) dbops.AlterTableBaseMixin.__init__( self, name=name, contained=contained) self._constraint = constraint def _get_triggers( self, table_name: tuple[str, ...], constraint: 
SchemaConstraintTableConstraint, proc_name='null', ) -> tuple[dbops.Trigger, ...]: cname = constraint.raw_constraint_name() ins_trigger_name = cname + '_instrigger' ins_trigger = dbops.Trigger( name=ins_trigger_name, table_name=table_name, events=('insert', ), procedure=proc_name, is_constraint=True, inherit=True) upd_trigger_name = cname + '_updtrigger' condition = constraint.get_trigger_condition() upd_trigger = dbops.Trigger( name=upd_trigger_name, table_name=table_name, events=('update', ), procedure=proc_name, condition=condition, is_constraint=True, inherit=True) return ins_trigger, upd_trigger def create_constr_trigger( self, table_name: tuple[str, ...], constraint: SchemaConstraintTableConstraint, proc_name: str, ) -> list[dbops.CreateTrigger]: ins_trigger, upd_trigger = self._get_triggers( table_name, constraint, proc_name ) return [ dbops.CreateTrigger(ins_trigger), dbops.CreateTrigger(upd_trigger), ] def drop_constr_trigger( self, table_name: tuple[str, ...], constraint: SchemaConstraintTableConstraint, ) -> list[dbops.DDLOperation]: ins_trigger, upd_trigger = self._get_triggers(table_name, constraint) return [ dbops.DropTrigger(ins_trigger, conditional=True), dbops.DropTrigger(upd_trigger, conditional=True), ] def create_constr_trigger_function( self, constraint: SchemaConstraintTableConstraint ): proc_name = constraint.get_trigger_procname() proc_text = constraint.get_trigger_proc_text() # Because casting is not immutable in PG, this function may not be # immutable, only stable. But because we check that casting in edgeql # *is* immutable, we can (almost) safely assume that this function is # also immutable. 
func = dbops.Function( name=proc_name, text=proc_text, volatility='immutable', returns='trigger', language='plpgsql', ) return [dbops.CreateFunction(func, or_replace=True)] def drop_constr_trigger_function(self, proc_name: tuple[str, ...]): return [dbops.DropFunction( name=proc_name, args=(), # Use a condition instead of if_exists to reduce annoying # debug spew from postgres. conditions=[dbops.FunctionExists(name=proc_name, args=())], )] def create_constraint(self, constraint: SchemaConstraintTableConstraint): # Add the constraint normally to our table # my_alter = dbops.AlterTable(self.name) add_constr = AlterTableAddMultiConstraint(constraint=constraint) my_alter.add_command(add_constr) self.add_command(my_alter) def create_constraint_trigger_and_fuction( self, constraint: SchemaConstraintTableConstraint ): """Create constraint trigger FUNCTION and TRIGGER. Adds the new function to the trigger. Disables the trigger if possible. """ if ( constraint.requires_triggers() and not constraint.can_disable_triggers() ): # Create trigger function self.add_commands(self.create_constr_trigger_function(constraint)) proc_name = constraint.get_trigger_procname() cr_trigger = self.create_constr_trigger( self.name, constraint, proc_name) self.add_commands(cr_trigger) def alter_constraint( self, old_constraint: SchemaConstraintTableConstraint, new_constraint: SchemaConstraintTableConstraint, ): if old_constraint.delegated and not new_constraint.delegated: # No longer delegated, create db structures self.create_constraint(new_constraint) elif not old_constraint.delegated and new_constraint.delegated: # Becoming delegated, drop db structures self.drop_constraint(old_constraint) elif not new_constraint.delegated: # Some other modification if old_constraint.is_non_row_and_identical(new_constraint): # If the constraint itself is unchanged, it is still necessary # to drop any constraint triggers. This is to clear any # postgres dependencies on constraint columns. 
                #
                # For example, given:
                #   type X { property n: int64 { constraint exclusive } };
                #   type Y extending X;
                #
                # Altering the property with
                #   alter type X alter property n using (1);
                #
                # will result in the column for X.n being dropped. To ensure
                # this operation succeeds, the constraint trigger must be
                # deleted before dropping the column.
                self.drop_constraint_trigger_and_fuction(old_constraint)

            else:
                # If the constraint is actually different, drop and create.
                self.drop_constraint(old_constraint)
                self.create_constraint(new_constraint)

    def drop_constraint(self, constraint: SchemaConstraintTableConstraint):
        self.drop_constraint_trigger_and_fuction(constraint)

        # Drop the constraint normally from our table
        #
        my_alter = dbops.AlterTable(constraint._subject_name)

        drop_constr = AlterTableDropMultiConstraint(constraint=constraint)
        my_alter.add_command(drop_constr)

        self.add_command(my_alter)

    def drop_constraint_trigger_and_fuction(
        self, constraint: SchemaConstraintTableConstraint
    ):
        """Drop constraint trigger FUNCTION and TRIGGER."""
        if constraint.requires_triggers():
            self.add_commands(self.drop_constr_trigger(
                constraint._subject_name, constraint))
            proc_name = constraint.get_trigger_procname()
            self.add_commands(self.drop_constr_trigger_function(proc_name))


class AlterTableAddConstraint(AlterTableConstraintBase):
    def __repr__(self):
        return '<{}.{} {!r}>'.format(
            self.__class__.__module__, self.__class__.__name__,
            self._constraint)

    def generate(self, block):
        if not self._constraint.delegated:
            self.create_constraint(self._constraint)
        super().generate(block)


class AlterTableAlterConstraint(AlterTableConstraintBase):
    def __init__(
        self, name, *, constraint, new_constraint, **kwargs
    ):
        super().__init__(name, constraint=constraint, **kwargs)
        self._new_constraint = new_constraint

    def __repr__(self):
        return '<{}.{} {!r}>'.format(
            self.__class__.__module__, self.__class__.__name__,
            self._constraint)

    def generate(self, block):
        self.alter_constraint(self._constraint, self._new_constraint)
super().generate(block) class AlterTableDropConstraint(AlterTableConstraintBase): def __repr__(self): return '<{}.{} {!r}>'.format( self.__class__.__module__, self.__class__.__name__, self._constraint) def generate(self, block): if not self._constraint.delegated: self.drop_constraint(self._constraint) super().generate(block) class AlterTableUpdateConstraintTrigger(AlterTableConstraintBase): def __repr__(self): return '<{}.{} {!r}>'.format( self.__class__.__module__, self.__class__.__name__, self._constraint) def generate(self, block): self.drop_constraint_trigger_and_fuction(self._constraint) self.create_constraint_trigger_and_fuction(self._constraint) super().generate(block) class AlterTableUpdateConstraintTriggerFixup(AlterTableConstraintBase): def __repr__(self): return '<{}.{} {!r}>'.format( self.__class__.__module__, self.__class__.__name__, self._constraint) def generate(self, block): # Pre 6.8 versions of gel created needless disabled triggers # in some cases. This path (invoked by administer # remove_pointless_triggers()) deletes them. 
if ( self._constraint.requires_triggers() and self._constraint.can_disable_triggers() ): self.drop_constraint_trigger_and_fuction(self._constraint) super().generate(block) def rename_pg_index( old_index: s_indexes.Index, new_index: s_indexes.Index, schema: s_schema.Schema, aspect: str = 'index' ) -> dbops.Command: table_name = common.get_index_table_backend_name(new_index, schema) module_name = new_index.get_name(schema).module old_index_name = common.get_index_backend_name( old_index.id, module_name, catenate=False, aspect=aspect ) new_index_name = common.get_index_backend_name( new_index.id, module_name, catenate=False, aspect=aspect ) pg_index = dbops.Index( name=old_index_name[1], table_name=table_name, # type: ignore ) return dbops.RenameIndex( pg_index, new_name=new_index_name[1], conditional=True, ) ================================================ FILE: edb/pgsql/deltafts.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from typing import Optional, Iterable, Sequence, Collection from edb import errors from edb.schema import indexes as s_indexes from edb.schema import types as s_types from edb.schema import expr as s_expr from edb.schema import schema as s_schema from edb.schema import delta as sd from edb.schema import name as sn from edb.ir import ast as irast from edb.edgeql import compiler as qlcompiler from . import common from . import dbops from . import deltadbops from . import compiler from . import codegen from . import types from . import ast as pgast from .common import qname as q from .common import quote_literal as ql from .common import quote_ident as qi from .compiler import astutils from .compiler import enums as pgce def create_fts_index( index: s_indexes.Index, index_expr: irast.Set, predicate_src: Optional[str], sql_kwarg_exprs: dict[str, str], options: qlcompiler.CompilerOptions, schema: s_schema.Schema, context: sd.CommandContext, ) -> dbops.Command: subject = index.get_subject(schema) assert isinstance(subject, s_indexes.IndexableSubject) effective, has_overridden = s_indexes.get_effective_object_index( schema, subject, sn.QualName("std::fts", "index") ) if index != effective: return dbops.CommandGroup() # When creating an index on a child that already has an fts index # inherited from the parent, we don't need to create the index, but just # update the populating expressions. 
if has_overridden: return _refresh_fts_document( index, has_overridden[0], options, schema, context ) else: return _create_fts_document( index, index_expr, predicate_src, sql_kwarg_exprs, schema, context, ) def delete_fts_index( index: s_indexes.Index, drop_index: dbops.Command, options: qlcompiler.CompilerOptions, schema: s_schema.Schema, orig_schema: s_schema.Schema, context: sd.CommandContext, ) -> dbops.Command: subject = index.get_subject(orig_schema) assert isinstance(subject, s_indexes.IndexableSubject) effective, _ = s_indexes.get_effective_object_index( schema, subject, sn.QualName("std::fts", "index") ) if not effective: return _delete_fts_document(index, drop_index, orig_schema, context) else: # effective index remains: don't drop the fts document effective_subject = effective.get_subject(schema) is_eff_on_direct_parent = effective_subject in subject.get_bases( schema ).objects(schema) if is_eff_on_direct_parent: return _refresh_fts_document( index, effective, options, schema, context ) else: return dbops.CommandGroup() def _compile_ir_index_exprs( index: s_indexes.Index, index_expr: irast.Set, schema: s_schema.Schema ): subject = index.get_subject(schema) assert isinstance(subject, s_types.Type) subject_id = irast.PathId.from_type(schema, subject, env=None) sql_res = compiler.compile_ir_to_sql_tree( index_expr, singleton_mode=True, external_rvars={ (subject_id, pgce.PathAspect.SOURCE): pgast.RelRangeVar( alias=pgast.Alias(aliasname='NEW'), relation=pgast.Relation(name='NEW'), ) }, ) return astutils.maybe_unpack_row(sql_res.ast) def _create_fts_document( index: s_indexes.Index, index_expr: irast.Set, predicate_src: Optional[str], sql_kwarg_exprs: dict[str, str], schema: s_schema.Schema, context: sd.CommandContext, ) -> dbops.Command: exprs = _compile_ir_index_exprs(index, index_expr, schema) from edb.common import debug if debug.flags.zombodb: return _zombo_create_fts_document( index, exprs, predicate_src, sql_kwarg_exprs, schema ) else: return 
_pg_create_fts_document( index, exprs, predicate_src, sql_kwarg_exprs, schema ) def _delete_fts_document( index: s_indexes.Index, drop_index: dbops.Command, schema: s_schema.Schema, context: sd.CommandContext, ) -> dbops.Command: table_name = common.get_index_table_backend_name(index, schema) ops = dbops.CommandGroup() ops.add_command(drop_index) from edb.common import debug if debug.flags.zombodb: zombo_func_name = _zombo_func_name(table_name) ops.add_command(dbops.DropFunction(zombo_func_name, args=[table_name])) zombo_type_name = _zombo_type_name(table_name) ops.add_command(dbops.DropCompositeType(zombo_type_name)) else: ops.add_command(_pg_drop_trigger(table_name)) # When the ObjectType is being deleted, we don't drop the index, as it # will get dropped with parent table. # The same goes for the __fts_document__ column. source_drop = isinstance(drop_index, dbops.NoOpCommand) if not source_drop: fts_document = dbops.Column( name=f'__fts_document__', type=('pg_catalog', 'tsvector'), ) alter_table = dbops.AlterTable(table_name) alter_table.add_operation(dbops.AlterTableDropColumn(fts_document)) ops.add_command(alter_table) return ops def update_fts_document( index: s_indexes.Index, options: qlcompiler.CompilerOptions, schema: s_schema.Schema, ) -> dbops.Query: table_name = common.get_index_table_backend_name(index, schema) # compile the expression index_sexpr: Optional[s_expr.Expression] = index.get_expr(schema) assert index_sexpr index_expr = index_sexpr.ensure_compiled( schema=schema, options=options, context=None, ) exprs = _compile_ir_index_exprs(index, index_expr.irast.expr, schema) from edb.common import debug if debug.flags.zombodb: raise NotImplementedError('zombo refresh index not implemented') else: # to avoid code duplication, we call code for creating triggers and # extract the first UPDATE command create_trigger_ops = _pg_create_trigger(table_name, exprs) update_fts_document_op = create_trigger_ops.commands[0] assert isinstance(update_fts_document_op, 
dbops.Query) return update_fts_document_op def _refresh_fts_document( index: s_indexes.Index, old_index: s_indexes.Index, options: qlcompiler.CompilerOptions, schema: s_schema.Schema, context: sd.CommandContext, ) -> dbops.Command: table_name = common.get_index_table_backend_name(index, schema) # compile the expression index_sexpr: Optional[s_expr.Expression] = index.get_expr(schema) assert index_sexpr index_expr = index_sexpr.ensure_compiled( schema=schema, options=options, context=context, ) exprs = _compile_ir_index_exprs(index, index_expr.irast.expr, schema) ops = dbops.CommandGroup() from edb.common import debug if debug.flags.zombodb: raise NotImplementedError('zombo refresh index not implemented') else: ops.add_command(_pg_drop_trigger(table_name)) ops.add_command(_pg_create_trigger(table_name, exprs)) # Sigh, we need to rename the main index to match the new id, # entirely for the purpose of having ANALYZE be able to pick it up ops.add_command( deltadbops.rename_pg_index( old_index=old_index, new_index=index, schema=schema, ) ) return ops def _raise_unsupported_language_error( unsupported: Collection[str], ) -> None: unsupported = list(unsupported) unsupported.sort() msg = 'Full text search language' if len(unsupported) > 1: msg += 's' msg += ' ' + ', '.join(f'`{l}`' for l in unsupported) msg += ' not supported' raise errors.UnsupportedFeatureError(msg) # --- pg fts --- def _pg_create_fts_document( index: s_indexes.Index, exprs: Sequence[pgast.BaseExpr], predicate_src: Optional[str], sql_kwarg_exprs: dict[str, str], schema: s_schema.Schema, ) -> dbops.Command: ops = dbops.CommandGroup() # create column __fts_document__ table_name = common.get_index_table_backend_name(index, schema) module_name = index.get_name(schema).module index_name = common.get_index_backend_name( index.id, module_name, catenate=False ) fts_document = dbops.Column( name=f'__fts_document__', type='pg_catalog.tsvector' ) alter_table = dbops.AlterTable(table_name) 
alter_table.add_operation(dbops.AlterTableAddColumn(fts_document)) ops.add_command(alter_table) ops.add_command(_pg_create_trigger(table_name, exprs)) pg_index = dbops.Index( name=index_name[1], table_name=table_name, # type: ignore exprs=['__fts_document__'], unique=False, inherit=True, predicate=predicate_src, metadata={ 'schemaname': str(index.get_name(schema)), 'kwargs': sql_kwarg_exprs, # use a reference to the new column in the index instead 'code': 'gin (__col__)', }, ) ops.add_command(dbops.CreateIndex(pg_index)) return ops def _pg_create_trigger( table_name: tuple[str, str], exprs: Sequence[pgast.BaseExpr], ) -> dbops.CommandGroup: ops = dbops.CommandGroup() # prepare the expression to update __fts_document__ document_exprs = [] for expr in exprs: assert isinstance(expr, pgast.FTSDocument) lang_domain: Iterable[str] = expr.language_domain lang_domain = map(types.to_regconfig, lang_domain) unsupported = set(lang_domain).difference(types.pg_langs) if len(unsupported) > 0: _raise_unsupported_language_error(unsupported) text_sql = codegen.generate_source(expr.text) language_sql = codegen.generate_source(expr.language) document_expr = f''' to_tsvector( edgedb.fts_to_regconfig(({language_sql})::text), COALESCE({text_sql}, '') ) ''' if expr.weight: document_expr = f'setweight({document_expr}, {ql(expr.weight)})' document_exprs.append(document_expr) document_sql = ' || '.join(document_exprs) if document_exprs else 'NULL' # update existing rows ops.add_command( dbops.Query( f""" UPDATE {q(*table_name)} as NEW SET __fts_document__ = ({document_sql}); """ ) ) # create update function func_name = _pg_update_func_name(table_name) function = dbops.Function( name=func_name, text=f''' BEGIN NEW.__fts_document__ := ({document_sql}); RETURN NEW; END; ''', volatility='immutable', returns='trigger', language='plpgsql', ) ops.add_command(dbops.CreateFunction(function)) # create trigger to update the __fts_document__ trigger_name = _pg_trigger_name(table_name[1]) trigger = 
dbops.Trigger( name=trigger_name, table_name=table_name, events=('insert', 'update'), timing=dbops.TriggerTiming.Before, procedure=func_name, ) ops.add_command(dbops.CreateTrigger(trigger)) return ops def _pg_drop_trigger( table_name: tuple[str, str], ) -> dbops.Command: ops = dbops.CommandGroup() ops.add_command( dbops.DropTrigger( dbops.Trigger( _pg_trigger_name(table_name[1]), table_name=table_name, events=(), procedure='', ) ) ) ops.add_command( dbops.DropFunction( _pg_update_func_name(table_name), (), ) ) return ops def _pg_update_func_name( tbl_name: tuple[str, str], ) -> tuple[str, ...]: return ( tbl_name[0], common.edgedb_name_to_pg_name(tbl_name[1] + '_ftsupdate'), ) def _pg_trigger_name( tbl_name: str, ) -> str: return common.edgedb_name_to_pg_name(tbl_name + '_ftstrigger') # --- zombo --- def _zombo_create_fts_document( index: s_indexes.Index, exprs: Sequence[pgast.BaseExpr], predicate_src: Optional[str], sql_kwarg_exprs: dict[str, str], schema: s_schema.Schema, ) -> dbops.Command: ops = dbops.CommandGroup() table_name = common.get_index_table_backend_name(index, schema) module_name = index.get_name(schema).module index_name = common.get_index_backend_name( index.id, module_name, catenate=False ) zombo_type_name = _zombo_type_name(table_name) ops.add_command( dbops.CreateCompositeType( dbops.CompositeType( name=zombo_type_name, columns=[ dbops.Column( name=f'field{idx}', type='text', ) for idx, _ in enumerate(exprs) ], ) ) ) type_mappings: list[tuple[str, str]] = [] document_exprs = [] for idx, expr in enumerate(exprs): assert isinstance(expr, pgast.FTSDocument) text_sql = codegen.generate_source(expr.text) if len(expr.language_domain) != 1: raise errors.UnsupportedFeatureError( 'zombo fts indexes support only exactly one language' ) language = next(iter(expr.language_domain)) document_exprs.append(text_sql) type_mappings.append((f'field{idx}', language)) zombo_func_name = _zombo_func_name(table_name) ops.add_command( dbops.CreateFunction( 
dbops.Function( name=zombo_func_name, args=[('new', table_name)], returns=zombo_type_name, text=f''' SELECT ROW({','.join(document_exprs)})::{q(*zombo_type_name)}; ''', ) ) ) for col_name, language in type_mappings: mapping = f'{{"type": "text", "analyzer": "{language}"}}' ops.add_command( dbops.Query( f"""PERFORM zdb.define_field_mapping( {ql(q(*table_name))}::regclass, {ql(col_name)}::text, {ql(mapping)}::json )""" ) ) index_exprs = [f'{q(*zombo_func_name)}({qi(table_name[1])}.*)'] pg_index = dbops.Index( name=index_name[1], table_name=table_name, # type: ignore exprs=index_exprs, unique=False, inherit=True, with_clause={'url': ql('http://localhost:9200/')}, predicate=predicate_src, metadata={ 'schemaname': str(index.get_name(schema)), 'code': 'zombodb ((__col__))', 'kwargs': sql_kwarg_exprs, }, ) ops.add_command(dbops.CreateIndex(pg_index)) return ops def _zombo_type_name( tbl_name: tuple[str, str], ) -> tuple[str, str]: return ( tbl_name[0], common.edgedb_name_to_pg_name(tbl_name[1] + '_zombo_type'), ) def _zombo_func_name( tbl_name: tuple[str, str], ) -> tuple[str, ...]: return ( tbl_name[0], common.edgedb_name_to_pg_name(tbl_name[1] + '_zombo_func'), ) ================================================ FILE: edb/pgsql/inheritance.py ================================================ from __future__ import annotations from typing import Optional, AbstractSet, Iterator from edb.schema import links as s_links from edb.schema import name as sn from edb.schema import pointers as s_pointers from edb.schema import sources as s_sources from edb.schema import schema as s_schema from edb.ir import typeutils as irtyputils from . import ast as pgast from . import types from . 
import common def get_inheritance_view( schema: s_schema.Schema, obj: s_sources.Source | s_pointers.Pointer, exclude_children: AbstractSet[ s_sources.Source | s_pointers.Pointer ] = frozenset(), exclude_ptrs: AbstractSet[s_pointers.Pointer] = frozenset(), ) -> pgast.SelectStmt: ptrs: dict[sn.UnqualName, tuple[list[str], tuple[str, ...]]] = {} if isinstance(obj, s_sources.Source): pointers = list(obj.get_pointers(schema).items(schema)) # Sort by UUID timestamp for stable VIEW column order. pointers.sort(key=lambda p: p[1].id.time) for ptrname, ptr in pointers: if ptr in exclude_ptrs: continue if ptr.is_pure_computable(schema): continue ptr_stor_info = types.get_pointer_storage_info( ptr, link_bias=isinstance(obj, s_links.Link), schema=schema, ) if ( isinstance(obj, s_links.Link) or ptr_stor_info.table_type == 'ObjectType' ): ptrs[ptrname] = ( [ptr_stor_info.column_name], ptr_stor_info.column_type, ) shortname = ptr.get_shortname(schema).name if shortname != ptr_stor_info.column_name: ptrs[ptrname][0].append(common.quote_ident(shortname)) for name, alias, type in obj.get_addon_columns(schema): ptrs[sn.UnqualName(name)] = ([alias], type) else: # MULTI PROPERTY ptrs[sn.UnqualName('source')] = (['source'], ('uuid',)) lp_info = types.get_pointer_storage_info( obj, link_bias=True, schema=schema, ) ptrs[sn.UnqualName('target')] = (['target'], lp_info.column_type) descendants = [ child for child in obj.descendants(schema) if types.has_table(child, schema) and child not in exclude_children # XXX: Exclude sys/cfg tables from non sys/cfg views. This # probably isn't *really* what we want to do, but until we # figure that out, do *something* so that DDL isn't # excruciatingly slow because of the cost of explicit id # checks. See #5168. and not irtyputils.is_excluded_cfg_view( child, ancestor=obj, schema=schema ) ] # Hackily force 'source' to appear in abstract links. We need # source present in the code we generate to enforce newly # created exclusive constraints across types. 
if ( ptrs and isinstance(obj, s_links.Link) and sn.UnqualName('source') not in ptrs and obj.is_non_concrete(schema) ): ptrs[sn.UnqualName('source')] = (['source'], ('uuid',)) components = [] components.append(_get_select_from(schema, obj, ptrs)) components.extend( _get_select_from(schema, child, ptrs) for child in descendants ) return _union_all(filter(None, components)) def _union_all(components: Iterator[pgast.SelectStmt]) -> pgast.SelectStmt: query = next(components) for component in components: query = pgast.SelectStmt( larg=query, op='UNION', all=True, rarg=component, ) return query def _get_select_from( schema: s_schema.Schema, obj: s_sources.Source | s_pointers.Pointer, ptr_names: dict[sn.UnqualName, tuple[list[str], tuple[str, ...]]], ) -> Optional[pgast.SelectStmt]: schema_name, table_name = common.get_backend_name( schema, obj, catenate=False, aspect='table', ) # the name of the rel var of the object table within the select query table_rvar_name = table_name target_list: list[pgast.ResTarget] = [] system_cols = ['tableoid', 'xmin', 'cmin', 'xmax', 'cmax', 'ctid'] for sys_col_name in system_cols: val: pgast.BaseExpr if not irtyputils.is_cfg_view(obj, schema): val = pgast.ColumnRef(name=(table_rvar_name, sys_col_name)) else: val = pgast.NullConstant() target_list.append(pgast.ResTarget(name=sys_col_name, val=val)) if isinstance(obj, s_sources.Source): ptrs = dict(obj.get_pointers(schema).items(schema)) for ptr_name, (aliases, pg_type) in ptr_names.items(): ptr = ptrs.get(ptr_name) if ptr_name == sn.UnqualName('__type__'): # __type__ is special cased: since it is uniquely # determined by the type, we directly insert it # into the views instead of storing it (to save space) val = pgast.TypeCast( arg=pgast.StringConstant(val=str(obj.id)), type_name=pgast.TypeName(name=('uuid',)), ) elif ptr is not None: ptr_stor_info = types.get_pointer_storage_info( ptr, link_bias=isinstance(obj, s_links.Link), schema=schema, ) if ptr_stor_info.column_type != pg_type: return 
None val = pgast.ColumnRef( name=(table_rvar_name, ptr_stor_info.column_name) ) elif ptr_name == sn.UnqualName('source'): val = pgast.TypeCast( arg=pgast.NullConstant(), type_name=pgast.TypeName(name=('uuid',)), ) elif ptr_name == sn.UnqualName('__fts_document__') or ( ptr_name.name.startswith('__ext_ai_') and ptr_name.name.endswith('__') ): # an addon column val = pgast.ColumnRef(name=(table_rvar_name, ptr_name.name)) else: return None for alias in aliases: target_list.append(pgast.ResTarget(name=alias, val=val)) else: for ptr_name, (aliases, _) in ptr_names.items(): for alias in aliases: target_list.append( pgast.ResTarget( name=alias, val=pgast.ColumnRef( name=(table_rvar_name, str(ptr_name)), ), ) ) return pgast.SelectStmt( from_clause=[ pgast.RelRangeVar( alias=pgast.Alias(aliasname=table_rvar_name), relation=pgast.Relation( schemaname=schema_name, name=table_name, ), ) ], target_list=target_list, ) ================================================ FILE: edb/pgsql/keywords.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2010-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations keyword_types = range(1, 5) (UNRESERVED_KEYWORD, RESERVED_KEYWORD, TYPE_FUNC_NAME_KEYWORD, COL_NAME_KEYWORD) = keyword_types pg_keywords = { "abort": ("ABORT_P", UNRESERVED_KEYWORD), "absolute": ("ABSOLUTE_P", UNRESERVED_KEYWORD), "access": ("ACCESS", UNRESERVED_KEYWORD), "action": ("ACTION", UNRESERVED_KEYWORD), "add": ("ADD_P", UNRESERVED_KEYWORD), "admin": ("ADMIN", UNRESERVED_KEYWORD), "after": ("AFTER", UNRESERVED_KEYWORD), "aggregate": ("AGGREGATE", UNRESERVED_KEYWORD), "all": ("ALL", RESERVED_KEYWORD), "also": ("ALSO", UNRESERVED_KEYWORD), "alter": ("ALTER", UNRESERVED_KEYWORD), "always": ("ALWAYS", UNRESERVED_KEYWORD), "analyse": ("ANALYSE", RESERVED_KEYWORD), "analyze": ("ANALYZE", RESERVED_KEYWORD), "and": ("AND", RESERVED_KEYWORD), "any": ("ANY", RESERVED_KEYWORD), "array": ("ARRAY", RESERVED_KEYWORD), "as": ("AS", RESERVED_KEYWORD), "asc": ("ASC", RESERVED_KEYWORD), "assertion": ("ASSERTION", UNRESERVED_KEYWORD), "assignment": ("ASSIGNMENT", UNRESERVED_KEYWORD), "asymmetric": ("ASYMMETRIC", RESERVED_KEYWORD), "at": ("AT", UNRESERVED_KEYWORD), "authorization": ("AUTHORIZATION", TYPE_FUNC_NAME_KEYWORD), "backward": ("BACKWARD", UNRESERVED_KEYWORD), "before": ("BEFORE", UNRESERVED_KEYWORD), "begin": ("BEGIN_P", UNRESERVED_KEYWORD), "between": ("BETWEEN", COL_NAME_KEYWORD), "bigint": ("BIGINT", COL_NAME_KEYWORD), "binary": ("BINARY", TYPE_FUNC_NAME_KEYWORD), "bit": ("BIT", COL_NAME_KEYWORD), "boolean": ("BOOLEAN_P", COL_NAME_KEYWORD), "both": ("BOTH", RESERVED_KEYWORD), "by": ("BY", UNRESERVED_KEYWORD), "cache": ("CACHE", UNRESERVED_KEYWORD), "called": ("CALLED", UNRESERVED_KEYWORD), "cascade": ("CASCADE", UNRESERVED_KEYWORD), "cascaded": ("CASCADED", UNRESERVED_KEYWORD), "case": ("CASE", RESERVED_KEYWORD), "cast": ("CAST", RESERVED_KEYWORD), "catalog": ("CATALOG_P", UNRESERVED_KEYWORD), "chain": ("CHAIN", UNRESERVED_KEYWORD), "char": ("CHAR_P", COL_NAME_KEYWORD), "character": ("CHARACTER", COL_NAME_KEYWORD), 
"characteristics": ("CHARACTERISTICS", UNRESERVED_KEYWORD), "check": ("CHECK", RESERVED_KEYWORD), "checkpoint": ("CHECKPOINT", UNRESERVED_KEYWORD), "class": ("CLASS", UNRESERVED_KEYWORD), "close": ("CLOSE", UNRESERVED_KEYWORD), "cluster": ("CLUSTER", UNRESERVED_KEYWORD), "coalesce": ("COALESCE", COL_NAME_KEYWORD), "collate": ("COLLATE", RESERVED_KEYWORD), "column": ("COLUMN", RESERVED_KEYWORD), "comment": ("COMMENT", UNRESERVED_KEYWORD), "comments": ("COMMENTS", UNRESERVED_KEYWORD), "commit": ("COMMIT", UNRESERVED_KEYWORD), "committed": ("COMMITTED", UNRESERVED_KEYWORD), "concurrently": ("CONCURRENTLY", TYPE_FUNC_NAME_KEYWORD), "configuration": ("CONFIGURATION", UNRESERVED_KEYWORD), "connection": ("CONNECTION", UNRESERVED_KEYWORD), "constraint": ("CONSTRAINT", RESERVED_KEYWORD), "constraints": ("CONSTRAINTS", UNRESERVED_KEYWORD), "content": ("CONTENT_P", UNRESERVED_KEYWORD), "continue": ("CONTINUE_P", UNRESERVED_KEYWORD), "conversion": ("CONVERSION_P", UNRESERVED_KEYWORD), "copy": ("COPY", UNRESERVED_KEYWORD), "cost": ("COST", UNRESERVED_KEYWORD), "create": ("CREATE", RESERVED_KEYWORD), "createdb": ("CREATEDB", UNRESERVED_KEYWORD), "createrole": ("CREATEROLE", UNRESERVED_KEYWORD), "createuser": ("CREATEUSER", UNRESERVED_KEYWORD), "cross": ("CROSS", TYPE_FUNC_NAME_KEYWORD), "csv": ("CSV", UNRESERVED_KEYWORD), "current": ("CURRENT_P", UNRESERVED_KEYWORD), "current_catalog": ("CURRENT_CATALOG", RESERVED_KEYWORD), "current_date": ("CURRENT_DATE", RESERVED_KEYWORD), "current_role": ("CURRENT_ROLE", RESERVED_KEYWORD), "current_schema": ("CURRENT_SCHEMA", TYPE_FUNC_NAME_KEYWORD), "current_time": ("CURRENT_TIME", RESERVED_KEYWORD), "current_timestamp": ("CURRENT_TIMESTAMP", RESERVED_KEYWORD), "current_user": ("CURRENT_USER", RESERVED_KEYWORD), "cursor": ("CURSOR", UNRESERVED_KEYWORD), "cycle": ("CYCLE", UNRESERVED_KEYWORD), "data": ("DATA_P", UNRESERVED_KEYWORD), "database": ("DATABASE", UNRESERVED_KEYWORD), "day": ("DAY_P", UNRESERVED_KEYWORD), "deallocate": 
("DEALLOCATE", UNRESERVED_KEYWORD), "dec": ("DEC", COL_NAME_KEYWORD), "decimal": ("DECIMAL_P", COL_NAME_KEYWORD), "declare": ("DECLARE", UNRESERVED_KEYWORD), "default": ("DEFAULT", RESERVED_KEYWORD), "defaults": ("DEFAULTS", UNRESERVED_KEYWORD), "deferrable": ("DEFERRABLE", RESERVED_KEYWORD), "deferred": ("DEFERRED", UNRESERVED_KEYWORD), "definer": ("DEFINER", UNRESERVED_KEYWORD), "delete": ("DELETE_P", UNRESERVED_KEYWORD), "delimiter": ("DELIMITER", UNRESERVED_KEYWORD), "delimiters": ("DELIMITERS", UNRESERVED_KEYWORD), "desc": ("DESC", RESERVED_KEYWORD), "dictionary": ("DICTIONARY", UNRESERVED_KEYWORD), "disable": ("DISABLE_P", UNRESERVED_KEYWORD), "discard": ("DISCARD", UNRESERVED_KEYWORD), "distinct": ("DISTINCT", RESERVED_KEYWORD), "do": ("DO", RESERVED_KEYWORD), "document": ("DOCUMENT_P", UNRESERVED_KEYWORD), "domain": ("DOMAIN_P", UNRESERVED_KEYWORD), "double": ("DOUBLE_P", UNRESERVED_KEYWORD), "drop": ("DROP", UNRESERVED_KEYWORD), "each": ("EACH", UNRESERVED_KEYWORD), "else": ("ELSE", RESERVED_KEYWORD), "enable": ("ENABLE_P", UNRESERVED_KEYWORD), "encoding": ("ENCODING", UNRESERVED_KEYWORD), "encrypted": ("ENCRYPTED", UNRESERVED_KEYWORD), "end": ("END_P", RESERVED_KEYWORD), "enum": ("ENUM_P", UNRESERVED_KEYWORD), "escape": ("ESCAPE", UNRESERVED_KEYWORD), "except": ("EXCEPT", RESERVED_KEYWORD), "exclude": ("EXCLUDE", UNRESERVED_KEYWORD), "excluding": ("EXCLUDING", UNRESERVED_KEYWORD), "exclusive": ("EXCLUSIVE", UNRESERVED_KEYWORD), "execute": ("EXECUTE", UNRESERVED_KEYWORD), "exists": ("EXISTS", COL_NAME_KEYWORD), "explain": ("EXPLAIN", UNRESERVED_KEYWORD), "external": ("EXTERNAL", UNRESERVED_KEYWORD), "extract": ("EXTRACT", COL_NAME_KEYWORD), "false": ("FALSE_P", RESERVED_KEYWORD), "family": ("FAMILY", UNRESERVED_KEYWORD), "fetch": ("FETCH", RESERVED_KEYWORD), "first": ("FIRST_P", UNRESERVED_KEYWORD), "float": ("FLOAT_P", COL_NAME_KEYWORD), "following": ("FOLLOWING", UNRESERVED_KEYWORD), "for": ("FOR", RESERVED_KEYWORD), "force": ("FORCE", 
UNRESERVED_KEYWORD), "foreign": ("FOREIGN", RESERVED_KEYWORD), "forward": ("FORWARD", UNRESERVED_KEYWORD), "freeze": ("FREEZE", TYPE_FUNC_NAME_KEYWORD), "from": ("FROM", RESERVED_KEYWORD), "full": ("FULL", TYPE_FUNC_NAME_KEYWORD), "function": ("FUNCTION", UNRESERVED_KEYWORD), "functions": ("FUNCTIONS", UNRESERVED_KEYWORD), "global": ("GLOBAL", UNRESERVED_KEYWORD), "grant": ("GRANT", RESERVED_KEYWORD), "granted": ("GRANTED", UNRESERVED_KEYWORD), "greatest": ("GREATEST", COL_NAME_KEYWORD), "group": ("GROUP_P", RESERVED_KEYWORD), "handler": ("HANDLER", UNRESERVED_KEYWORD), "having": ("HAVING", RESERVED_KEYWORD), "header": ("HEADER_P", UNRESERVED_KEYWORD), "hold": ("HOLD", UNRESERVED_KEYWORD), "hour": ("HOUR_P", UNRESERVED_KEYWORD), "identity": ("IDENTITY_P", UNRESERVED_KEYWORD), "if": ("IF_P", UNRESERVED_KEYWORD), "ilike": ("ILIKE", TYPE_FUNC_NAME_KEYWORD), "immediate": ("IMMEDIATE", UNRESERVED_KEYWORD), "immutable": ("IMMUTABLE", UNRESERVED_KEYWORD), "implicit": ("IMPLICIT_P", UNRESERVED_KEYWORD), "in": ("IN_P", RESERVED_KEYWORD), "including": ("INCLUDING", UNRESERVED_KEYWORD), "increment": ("INCREMENT", UNRESERVED_KEYWORD), "index": ("INDEX", UNRESERVED_KEYWORD), "indexes": ("INDEXES", UNRESERVED_KEYWORD), "inherit": ("INHERIT", UNRESERVED_KEYWORD), "inherits": ("INHERITS", UNRESERVED_KEYWORD), "initially": ("INITIALLY", RESERVED_KEYWORD), "inline": ("INLINE_P", UNRESERVED_KEYWORD), "inner": ("INNER_P", TYPE_FUNC_NAME_KEYWORD), "inout": ("INOUT", COL_NAME_KEYWORD), "input": ("INPUT_P", UNRESERVED_KEYWORD), "insensitive": ("INSENSITIVE", UNRESERVED_KEYWORD), "insert": ("INSERT", UNRESERVED_KEYWORD), "instead": ("INSTEAD", UNRESERVED_KEYWORD), "int": ("INT_P", COL_NAME_KEYWORD), "integer": ("INTEGER", COL_NAME_KEYWORD), "intersect": ("INTERSECT", RESERVED_KEYWORD), "interval": ("INTERVAL", COL_NAME_KEYWORD), "into": ("INTO", RESERVED_KEYWORD), "invoker": ("INVOKER", UNRESERVED_KEYWORD), "is": ("IS", TYPE_FUNC_NAME_KEYWORD), "isnull": ("ISNULL", 
TYPE_FUNC_NAME_KEYWORD), "isolation": ("ISOLATION", UNRESERVED_KEYWORD), "join": ("JOIN", TYPE_FUNC_NAME_KEYWORD), "json": ("JSON", TYPE_FUNC_NAME_KEYWORD), "json_array": ("JSON_ARRAY", TYPE_FUNC_NAME_KEYWORD), "json_arrayagg": ("JSON_ARRAYAGG", TYPE_FUNC_NAME_KEYWORD), "json_exists": ("JSON_EXISTS", TYPE_FUNC_NAME_KEYWORD), "json_object": ("JSON_OBJECT", TYPE_FUNC_NAME_KEYWORD), "json_objectagg": ("JSON_OBJECTAGG", TYPE_FUNC_NAME_KEYWORD), "json_query": ("JSON_QUERY", TYPE_FUNC_NAME_KEYWORD), "json_scalar": ("JSON_SCALAR", TYPE_FUNC_NAME_KEYWORD), "json_serialize": ("JSON_SERIALIZE", TYPE_FUNC_NAME_KEYWORD), "json_table": ("JSON_TABLE", TYPE_FUNC_NAME_KEYWORD), "json_table_primitive": ("JSON_TABLE_PRIMITIVE", TYPE_FUNC_NAME_KEYWORD), "json_value": ("JSON_VALUE", TYPE_FUNC_NAME_KEYWORD), "key": ("KEY", UNRESERVED_KEYWORD), "language": ("LANGUAGE", UNRESERVED_KEYWORD), "large": ("LARGE_P", UNRESERVED_KEYWORD), "last": ("LAST_P", UNRESERVED_KEYWORD), "lc_collate": ("LC_COLLATE_P", UNRESERVED_KEYWORD), "lc_ctype": ("LC_CTYPE_P", UNRESERVED_KEYWORD), "leading": ("LEADING", RESERVED_KEYWORD), "least": ("LEAST", COL_NAME_KEYWORD), "left": ("LEFT", TYPE_FUNC_NAME_KEYWORD), "level": ("LEVEL", UNRESERVED_KEYWORD), "like": ("LIKE", TYPE_FUNC_NAME_KEYWORD), "limit": ("LIMIT", RESERVED_KEYWORD), "listen": ("LISTEN", UNRESERVED_KEYWORD), "load": ("LOAD", UNRESERVED_KEYWORD), "local": ("LOCAL", UNRESERVED_KEYWORD), "localtime": ("LOCALTIME", RESERVED_KEYWORD), "localtimestamp": ("LOCALTIMESTAMP", RESERVED_KEYWORD), "location": ("LOCATION", UNRESERVED_KEYWORD), "lock": ("LOCK_P", UNRESERVED_KEYWORD), "login": ("LOGIN_P", UNRESERVED_KEYWORD), "mapping": ("MAPPING", UNRESERVED_KEYWORD), "match": ("MATCH", UNRESERVED_KEYWORD), "maxvalue": ("MAXVALUE", UNRESERVED_KEYWORD), "minute": ("MINUTE_P", UNRESERVED_KEYWORD), "minvalue": ("MINVALUE", UNRESERVED_KEYWORD), "mode": ("MODE", UNRESERVED_KEYWORD), "month": ("MONTH_P", UNRESERVED_KEYWORD), "move": ("MOVE", UNRESERVED_KEYWORD), 
"name": ("NAME_P", UNRESERVED_KEYWORD), "names": ("NAMES", UNRESERVED_KEYWORD), "national": ("NATIONAL", COL_NAME_KEYWORD), "natural": ("NATURAL", TYPE_FUNC_NAME_KEYWORD), "nchar": ("NCHAR", COL_NAME_KEYWORD), "next": ("NEXT", UNRESERVED_KEYWORD), "no": ("NO", UNRESERVED_KEYWORD), "nocreatedb": ("NOCREATEDB", UNRESERVED_KEYWORD), "nocreaterole": ("NOCREATEROLE", UNRESERVED_KEYWORD), "nocreateuser": ("NOCREATEUSER", UNRESERVED_KEYWORD), "noinherit": ("NOINHERIT", UNRESERVED_KEYWORD), "nologin": ("NOLOGIN_P", UNRESERVED_KEYWORD), "none": ("NONE", COL_NAME_KEYWORD), "nosuperuser": ("NOSUPERUSER", UNRESERVED_KEYWORD), "not": ("NOT", RESERVED_KEYWORD), "nothing": ("NOTHING", UNRESERVED_KEYWORD), "notify": ("NOTIFY", UNRESERVED_KEYWORD), "notnull": ("NOTNULL", TYPE_FUNC_NAME_KEYWORD), "nowait": ("NOWAIT", UNRESERVED_KEYWORD), "null": ("NULL_P", RESERVED_KEYWORD), "nullif": ("NULLIF", COL_NAME_KEYWORD), "nulls": ("NULLS_P", UNRESERVED_KEYWORD), "numeric": ("NUMERIC", COL_NAME_KEYWORD), "object": ("OBJECT_P", UNRESERVED_KEYWORD), "of": ("OF", UNRESERVED_KEYWORD), "off": ("OFF", RESERVED_KEYWORD), "offset": ("OFFSET", RESERVED_KEYWORD), "oids": ("OIDS", UNRESERVED_KEYWORD), "on": ("ON", RESERVED_KEYWORD), "only": ("ONLY", RESERVED_KEYWORD), "operator": ("OPERATOR", UNRESERVED_KEYWORD), "option": ("OPTION", UNRESERVED_KEYWORD), "options": ("OPTIONS", UNRESERVED_KEYWORD), "or": ("OR", RESERVED_KEYWORD), "order": ("ORDER", RESERVED_KEYWORD), "out": ("OUT_P", COL_NAME_KEYWORD), "outer": ("OUTER_P", TYPE_FUNC_NAME_KEYWORD), "over": ("OVER", TYPE_FUNC_NAME_KEYWORD), "overlaps": ("OVERLAPS", TYPE_FUNC_NAME_KEYWORD), "overlay": ("OVERLAY", COL_NAME_KEYWORD), "owned": ("OWNED", UNRESERVED_KEYWORD), "owner": ("OWNER", UNRESERVED_KEYWORD), "parser": ("PARSER", UNRESERVED_KEYWORD), "partial": ("PARTIAL", UNRESERVED_KEYWORD), "partition": ("PARTITION", UNRESERVED_KEYWORD), "password": ("PASSWORD", UNRESERVED_KEYWORD), "placing": ("PLACING", RESERVED_KEYWORD), "plans": ("PLANS", 
UNRESERVED_KEYWORD), "position": ("POSITION", COL_NAME_KEYWORD), "preceding": ("PRECEDING", UNRESERVED_KEYWORD), "precision": ("PRECISION", COL_NAME_KEYWORD), "prepare": ("PREPARE", UNRESERVED_KEYWORD), "prepared": ("PREPARED", UNRESERVED_KEYWORD), "preserve": ("PRESERVE", UNRESERVED_KEYWORD), "primary": ("PRIMARY", RESERVED_KEYWORD), "prior": ("PRIOR", UNRESERVED_KEYWORD), "privileges": ("PRIVILEGES", UNRESERVED_KEYWORD), "procedural": ("PROCEDURAL", UNRESERVED_KEYWORD), "procedure": ("PROCEDURE", UNRESERVED_KEYWORD), "quote": ("QUOTE", UNRESERVED_KEYWORD), "range": ("RANGE", UNRESERVED_KEYWORD), "read": ("READ", UNRESERVED_KEYWORD), "real": ("REAL", COL_NAME_KEYWORD), "reassign": ("REASSIGN", UNRESERVED_KEYWORD), "recheck": ("RECHECK", UNRESERVED_KEYWORD), "recursive": ("RECURSIVE", UNRESERVED_KEYWORD), "references": ("REFERENCES", RESERVED_KEYWORD), "reindex": ("REINDEX", UNRESERVED_KEYWORD), "relative": ("RELATIVE_P", UNRESERVED_KEYWORD), "release": ("RELEASE", UNRESERVED_KEYWORD), "rename": ("RENAME", UNRESERVED_KEYWORD), "repeatable": ("REPEATABLE", UNRESERVED_KEYWORD), "replace": ("REPLACE", UNRESERVED_KEYWORD), "replica": ("REPLICA", UNRESERVED_KEYWORD), "reset": ("RESET", UNRESERVED_KEYWORD), "restart": ("RESTART", UNRESERVED_KEYWORD), "restrict": ("RESTRICT", UNRESERVED_KEYWORD), "returning": ("RETURNING", RESERVED_KEYWORD), "returns": ("RETURNS", UNRESERVED_KEYWORD), "revoke": ("REVOKE", UNRESERVED_KEYWORD), "right": ("RIGHT", TYPE_FUNC_NAME_KEYWORD), "role": ("ROLE", UNRESERVED_KEYWORD), "rollback": ("ROLLBACK", UNRESERVED_KEYWORD), "row": ("ROW", COL_NAME_KEYWORD), "rows": ("ROWS", UNRESERVED_KEYWORD), "rule": ("RULE", UNRESERVED_KEYWORD), "savepoint": ("SAVEPOINT", UNRESERVED_KEYWORD), "schema": ("SCHEMA", UNRESERVED_KEYWORD), "scroll": ("SCROLL", UNRESERVED_KEYWORD), "search": ("SEARCH", UNRESERVED_KEYWORD), "second": ("SECOND_P", UNRESERVED_KEYWORD), "security": ("SECURITY", UNRESERVED_KEYWORD), "select": ("SELECT", RESERVED_KEYWORD), "sequence": 
("SEQUENCE", UNRESERVED_KEYWORD), "sequences": ("SEQUENCES", UNRESERVED_KEYWORD), "serializable": ("SERIALIZABLE", UNRESERVED_KEYWORD), "server": ("SERVER", UNRESERVED_KEYWORD), "session": ("SESSION", UNRESERVED_KEYWORD), "session_user": ("SESSION_USER", RESERVED_KEYWORD), "set": ("SET", UNRESERVED_KEYWORD), "setof": ("SETOF", COL_NAME_KEYWORD), "share": ("SHARE", UNRESERVED_KEYWORD), "show": ("SHOW", UNRESERVED_KEYWORD), "similar": ("SIMILAR", TYPE_FUNC_NAME_KEYWORD), "simple": ("SIMPLE", UNRESERVED_KEYWORD), "smallint": ("SMALLINT", COL_NAME_KEYWORD), "some": ("SOME", RESERVED_KEYWORD), "stable": ("STABLE", UNRESERVED_KEYWORD), "standalone": ("STANDALONE_P", UNRESERVED_KEYWORD), "start": ("START", UNRESERVED_KEYWORD), "statement": ("STATEMENT", UNRESERVED_KEYWORD), "statistics": ("STATISTICS", UNRESERVED_KEYWORD), "stdin": ("STDIN", UNRESERVED_KEYWORD), "stdout": ("STDOUT", UNRESERVED_KEYWORD), "storage": ("STORAGE", UNRESERVED_KEYWORD), "strict": ("STRICT_P", UNRESERVED_KEYWORD), "strip": ("STRIP_P", UNRESERVED_KEYWORD), "substring": ("SUBSTRING", COL_NAME_KEYWORD), "superuser": ("SUPERUSER_P", UNRESERVED_KEYWORD), "symmetric": ("SYMMETRIC", RESERVED_KEYWORD), "sysid": ("SYSID", UNRESERVED_KEYWORD), "system": ("SYSTEM_P", UNRESERVED_KEYWORD), "table": ("TABLE", RESERVED_KEYWORD), "tables": ("TABLES", UNRESERVED_KEYWORD), "tablespace": ("TABLESPACE", UNRESERVED_KEYWORD), "temp": ("TEMP", UNRESERVED_KEYWORD), "template": ("TEMPLATE", UNRESERVED_KEYWORD), "temporary": ("TEMPORARY", UNRESERVED_KEYWORD), "text": ("TEXT_P", UNRESERVED_KEYWORD), "then": ("THEN", RESERVED_KEYWORD), "time": ("TIME", COL_NAME_KEYWORD), "timestamp": ("TIMESTAMP", COL_NAME_KEYWORD), "to": ("TO", RESERVED_KEYWORD), "trailing": ("TRAILING", RESERVED_KEYWORD), "transaction": ("TRANSACTION", UNRESERVED_KEYWORD), "treat": ("TREAT", COL_NAME_KEYWORD), "trigger": ("TRIGGER", UNRESERVED_KEYWORD), "trim": ("TRIM", COL_NAME_KEYWORD), "true": ("TRUE_P", RESERVED_KEYWORD), "truncate": ("TRUNCATE", 
UNRESERVED_KEYWORD), "trusted": ("TRUSTED", UNRESERVED_KEYWORD), "type": ("TYPE_P", UNRESERVED_KEYWORD), "unbounded": ("UNBOUNDED", UNRESERVED_KEYWORD), "uncommitted": ("UNCOMMITTED", UNRESERVED_KEYWORD), "unencrypted": ("UNENCRYPTED", UNRESERVED_KEYWORD), "union": ("UNION", RESERVED_KEYWORD), "unique": ("UNIQUE", RESERVED_KEYWORD), "unknown": ("UNKNOWN", UNRESERVED_KEYWORD), "unlisten": ("UNLISTEN", UNRESERVED_KEYWORD), "until": ("UNTIL", UNRESERVED_KEYWORD), "update": ("UPDATE", UNRESERVED_KEYWORD), "user": ("USER", RESERVED_KEYWORD), "using": ("USING", RESERVED_KEYWORD), "vacuum": ("VACUUM", UNRESERVED_KEYWORD), "valid": ("VALID", UNRESERVED_KEYWORD), "validator": ("VALIDATOR", UNRESERVED_KEYWORD), "value": ("VALUE_P", UNRESERVED_KEYWORD), "values": ("VALUES", COL_NAME_KEYWORD), "varchar": ("VARCHAR", COL_NAME_KEYWORD), "variadic": ("VARIADIC", RESERVED_KEYWORD), "varying": ("VARYING", UNRESERVED_KEYWORD), "verbose": ("VERBOSE", TYPE_FUNC_NAME_KEYWORD), "version": ("VERSION_P", UNRESERVED_KEYWORD), "view": ("VIEW", UNRESERVED_KEYWORD), "volatile": ("VOLATILE", UNRESERVED_KEYWORD), "when": ("WHEN", RESERVED_KEYWORD), "where": ("WHERE", RESERVED_KEYWORD), "whitespace": ("WHITESPACE_P", UNRESERVED_KEYWORD), "window": ("WINDOW", RESERVED_KEYWORD), "with": ("WITH", RESERVED_KEYWORD), "without": ("WITHOUT", UNRESERVED_KEYWORD), "work": ("WORK", UNRESERVED_KEYWORD), "wrapper": ("WRAPPER", UNRESERVED_KEYWORD), "write": ("WRITE", UNRESERVED_KEYWORD), "xml": ("XML_P", UNRESERVED_KEYWORD), "xmlattributes": ("XMLATTRIBUTES", COL_NAME_KEYWORD), "xmlconcat": ("XMLCONCAT", COL_NAME_KEYWORD), "xmlelement": ("XMLELEMENT", COL_NAME_KEYWORD), "xmlforest": ("XMLFOREST", COL_NAME_KEYWORD), "xmlparse": ("XMLPARSE", COL_NAME_KEYWORD), "xmlpi": ("XMLPI", COL_NAME_KEYWORD), "xmlroot": ("XMLROOT", COL_NAME_KEYWORD), "xmlserialize": ("XMLSERIALIZE", COL_NAME_KEYWORD), "year": ("YEAR_P", UNRESERVED_KEYWORD), "yes": ("YES_P", UNRESERVED_KEYWORD), "zone": ("ZONE", UNRESERVED_KEYWORD), } 
by_type: dict[int, dict[str, str]] = {typ: {} for typ in keyword_types} for val, spec in pg_keywords.items(): by_type[spec[1]][val] = spec[0] ================================================ FILE: edb/pgsql/metaschema.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Database structure and objects supporting Gel metadata.""" from __future__ import annotations from typing import ( Callable, Optional, Protocol, Iterable, Sequence, cast, Any ) import functools import json import re import edb._edgeql_parser as ql_parser from edb.common import debug from edb.common import exceptions from edb.common import ordered from edb.common import uuidgen from edb.common import xdedent from edb.common.typeutils import not_none from edb.edgeql import ast as qlast from edb.edgeql import qltypes from edb.edgeql import quote as qlquote from edb.edgeql import compiler as qlcompiler from edb.ir import statypes from edb.schema import constraints as s_constr from edb.schema import links as s_links from edb.schema import name as s_name from edb.schema import objects as s_obj from edb.schema import objtypes as s_objtypes from edb.schema import pointers as s_pointers from edb.schema import properties as s_props from edb.schema import scalars as s_scalars from edb.schema import schema as s_schema from edb.schema import 
sources as s_sources from edb.schema import types as s_types from edb.schema import utils as s_utils from edb.server import defines from edb.server import compiler as edbcompiler from edb.server import config as edbconfig from edb.server import pgcon # HM. from .resolver import sql_introspection from . import codegen from . import common from . import compiler from . import dbops from . import inheritance from . import params from . import trampoline from . import types q = common.qname qi = common.quote_ident ql = common.quote_literal qt = common.quote_type V = common.versioned_schema DATABASE_ID_NAMESPACE = uuidgen.UUID('0e6fed66-204b-11e9-8666-cffd58a5240b') CONFIG_ID_NAMESPACE = uuidgen.UUID('a48b38fa-349b-11e9-a6be-4f337f82f5ad') CONFIG_ID = { None: uuidgen.UUID('172097a4-39f4-11e9-b189-9321eb2f4b97'), qltypes.ConfigScope.INSTANCE: uuidgen.UUID( '172097a4-39f4-11e9-b189-9321eb2f4b98'), qltypes.ConfigScope.DATABASE: uuidgen.UUID( '172097a4-39f4-11e9-b189-9321eb2f4b99'), } def qtl(t: tuple[str, ...]) -> str: """Quote type literal""" return ql(f'{t[0]}.{t[1]}') if len(t) == 2 else ql(f'pg_catalog.{t[0]}') class PGConnection(Protocol): async def sql_execute( self, sql: bytes, ) -> None: ... async def sql_fetch( self, sql: bytes, *, args: tuple[bytes, ...] | list[bytes] = (), ) -> list[tuple[bytes, ...]]: ... async def sql_fetch_val( self, sql: bytes, *, args: tuple[bytes, ...] | list[bytes] = (), ) -> bytes: ... async def sql_fetch_col( self, sql: bytes, *, args: tuple[bytes, ...] | list[bytes] = (), ) -> list[bytes]: ... 
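# Illustrative sketch, not part of the original module: because PGConnection
# above is a structural Protocol, any object with matching async methods can
# stand in for a real backend connection, for example in unit tests.
# PGConnectionLike and RecordingConnection are hypothetical names introduced
# only for this example; PGConnectionLike re-declares the protocol locally so
# that the snippet runs on its own.

```python
from __future__ import annotations

import asyncio
from typing import Protocol, runtime_checkable


@runtime_checkable
class PGConnectionLike(Protocol):
    # Local stand-in mirroring the PGConnection protocol (assumption:
    # same four method names and keyword-only ``args``).
    async def sql_execute(self, sql: bytes) -> None: ...

    async def sql_fetch(
        self, sql: bytes, *, args: tuple[bytes, ...] | list[bytes] = (),
    ) -> list[tuple[bytes, ...]]: ...

    async def sql_fetch_val(
        self, sql: bytes, *, args: tuple[bytes, ...] | list[bytes] = (),
    ) -> bytes: ...

    async def sql_fetch_col(
        self, sql: bytes, *, args: tuple[bytes, ...] | list[bytes] = (),
    ) -> list[bytes]: ...


class RecordingConnection:
    """Hypothetical test double: records SQL and returns canned rows."""

    def __init__(self, rows: list[tuple[bytes, ...]]) -> None:
        self.rows = rows
        self.executed: list[bytes] = []

    async def sql_execute(self, sql: bytes) -> None:
        self.executed.append(sql)

    async def sql_fetch(self, sql, *, args=()):
        self.executed.append(sql)
        return list(self.rows)

    async def sql_fetch_val(self, sql, *, args=()):
        # First column of the first canned row.
        return (await self.sql_fetch(sql, args=args))[0][0]

    async def sql_fetch_col(self, sql, *, args=()):
        # First column of every canned row.
        return [row[0] for row in await self.sql_fetch(sql, args=args)]


async def demo() -> bytes:
    conn: PGConnectionLike = RecordingConnection([(b'1',), (b'2',)])
    await conn.sql_execute(b'SELECT 1')
    return await conn.sql_fetch_val(b'SELECT x FROM t')
```

# The runtime isinstance() check only verifies that the method names exist;
# signature compatibility is left to a static type checker.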
class DBConfigTable(dbops.Table): def __init__(self) -> None: super().__init__(name=('edgedb', '_db_config')) self.add_columns([ dbops.Column(name='name', type='text'), dbops.Column(name='value', type='jsonb'), ]) self.add_constraint( dbops.UniqueConstraint( table_name=('edgedb', '_db_config'), columns=['name'], ), ) class InstDataTable(dbops.Table): def __init__(self) -> None: sname = V('edgedbinstdata') super().__init__( name=(sname, 'instdata'), columns=[ dbops.Column( name='key', type='text', ), dbops.Column( name='bin', type='bytea', ), dbops.Column( name='text', type='text', ), dbops.Column( name='json', type='jsonb', ), ], constraints=ordered.OrderedSet([ dbops.PrimaryKey( table_name=(sname, 'instdata'), columns=['key'], ), ]), ) class QueryCacheTable(dbops.Table): def __init__(self) -> None: super().__init__(name=('edgedb', '_query_cache')) self.add_columns([ dbops.Column(name='key', type='uuid', required=True), dbops.Column(name='schema_version', type='uuid', required=True), dbops.Column(name='input', type='bytea', required=True), dbops.Column(name='output', type='bytea', required=True), dbops.Column(name='evict', type='text', required=True), dbops.Column( name='creation_time', type='timestamp with time zone', required=True, default='current_timestamp', ), ]) self.add_constraint( dbops.PrimaryKey( table_name=('edgedb', '_query_cache'), columns=['key'], ), ) class EvictQueryCacheFunction(trampoline.VersionedFunction): text = f''' DECLARE evict_sql text; BEGIN DELETE FROM "edgedb"."_query_cache" WHERE "key" = cache_key RETURNING "evict" INTO evict_sql; IF evict_sql IS NOT NULL THEN EXECUTE evict_sql; END IF; END; ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_evict_query_cache'), args=[("cache_key", ("uuid",))], returns=("void",), language='plpgsql', volatility='volatile', text=self.text, ) class ClearQueryCacheFunction(trampoline.VersionedFunction): # TODO(fantix): this may consume a lot of memory in Postgres text = f''' DECLARE row 
record; BEGIN FOR row IN DELETE FROM "edgedb"."_query_cache" RETURNING "input", "evict" LOOP EXECUTE row."evict"; RETURN NEXT row."input"; END LOOP; END; ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_clear_query_cache'), args=[], returns=('bytea',), set_returning=True, language='plpgsql', volatility='volatile', text=self.text, ) class CreateTrampolineViewFunction(trampoline.VersionedFunction): text = f''' DECLARE cols text; tgt text; dummy text; BEGIN tgt := quote_ident(tgt_schema) || '.' || quote_ident(tgt_name); -- Check if the view already exists. select viewname into dummy from pg_catalog.pg_views where schemaname = tgt_schema and viewname = tgt_name; IF FOUND THEN -- If the view already existed, we need to generate a column -- list that maintains the order of anything that was present in -- the old view, and that doesn't remove any columns that were -- dropped. select string_agg( COALESCE( quote_ident(tname), 'NULL::' || vtypname || ' AS ' || quote_ident(vname) ), ',' ) from ( select a1.attname as tname, a2.attname as vname, pg_catalog.format_type(a2.atttypid, NULL) as vtypname from ( select * from pg_catalog.pg_attribute where attrelid = src::regclass::oid and attnum >= 0 ) a1 full outer join ( select * from pg_catalog.pg_attribute where attrelid = tgt::regclass::oid ) a2 on a1.attname = a2.attname order by a2.attnum, a1.attnum ) t INTO cols; END IF; -- If it doesn't exist or has no columns, create it with SELECT * cols := COALESCE(cols, '*'); EXECUTE 'CREATE OR REPLACE VIEW ' || tgt || ' AS ' || 'SELECT ' || cols || ' FROM ' || src; END; ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_create_trampoline_view'), args=[ ('src', ('text',)), ('tgt_schema', ('text',)), ('tgt_name', ('text',)), ], returns=('void',), language='plpgsql', volatility='volatile', text=self.text, ) class BigintDomain(dbops.Domain): """Bigint: a variant of numeric that enforces zero digits after the dot. 
We're using an explicit scale check as opposed to simply specifying the numeric bounds, because using bounds severely restricts the range of the numeric type (1000 vs 131072 digits). """ def __init__(self) -> None: super().__init__( name=('edgedbt', 'bigint_t'), base='numeric', constraints=( dbops.DomainCheckConstraint( domain_name=('edgedbt', 'bigint_t'), expr=("scale(VALUE) = 0 AND VALUE != 'NaN'"), ), ), ) class ConfigMemoryDomain(dbops.Domain): """Represents the cfg::memory type. Stores number of bytes. Unlike edgedbt.bigint_t, this is a domain over int8: * int8 comfortably represents amounts of data beyond petabytes; * negative values are rejected. """ def __init__(self) -> None: super().__init__( name=('edgedbt', 'memory_t'), base='int8', constraints=( dbops.DomainCheckConstraint( domain_name=('edgedbt', 'memory_t'), expr=("VALUE >= 0"), ), ), ) class TimestampTzDomain(dbops.Domain): """Timestamptz clamped to years 0001-9999. The default timestamp range of (4713 BC - 294276 AD) has problems: Postgres isn't ISO compliant with years out of the 1-9999 range and language compatibility is questionable. """ def __init__(self) -> None: super().__init__( name=('edgedbt', 'timestamptz_t'), base='timestamptz', constraints=( dbops.DomainCheckConstraint( domain_name=('edgedbt', 'timestamptz_t'), expr=("EXTRACT(years from VALUE) BETWEEN 1 AND 9999"), ), ), ) class TimestampDomain(dbops.Domain): """Timestamp clamped to years 0001-9999. The default timestamp range of (4713 BC - 294276 AD) has problems: Postgres isn't ISO compliant with years out of the 1-9999 range and language compatibility is questionable. """ def __init__(self) -> None: super().__init__( name=('edgedbt', 'timestamp_t'), base='timestamp', constraints=( dbops.DomainCheckConstraint( domain_name=('edgedbt', 'timestamp_t'), expr=("EXTRACT(years from VALUE) BETWEEN 1 AND 9999"), ), ), ) class DateDomain(dbops.Domain): """Date clamped to years 0001-9999. 
The default timestamp range of (4713 BC - 294276 AD) has problems: Postgres isn't ISO compliant with years out of the 1-9999 range and language compatibility is questionable. """ def __init__(self) -> None: super().__init__( name=('edgedbt', 'date_t'), base='date', constraints=( dbops.DomainCheckConstraint( domain_name=('edgedbt', 'date_t'), expr=("EXTRACT(years from VALUE) BETWEEN 1 AND 9999"), ), ), ) class DurationDomain(dbops.Domain): def __init__(self) -> None: super().__init__( name=('edgedbt', 'duration_t'), base='interval', constraints=( dbops.DomainCheckConstraint( domain_name=('edgedbt', 'duration_t'), expr=r''' EXTRACT(months from VALUE) = 0 AND EXTRACT(years from VALUE) = 0 AND EXTRACT(days from VALUE) = 0 ''', ), ), ) class RelativeDurationDomain(dbops.Domain): def __init__(self) -> None: super().__init__( name=('edgedbt', 'relative_duration_t'), base='interval', constraints=( dbops.DomainCheckConstraint( domain_name=('edgedbt', 'relative_duration_t'), expr="true", ), ), ) class DateDurationDomain(dbops.Domain): def __init__(self) -> None: super().__init__( name=('edgedbt', 'date_duration_t'), base='interval', constraints=( dbops.DomainCheckConstraint( domain_name=('edgedbt', 'date_duration_t'), expr=r''' EXTRACT(hour from VALUE) = 0 AND EXTRACT(minute from VALUE) = 0 AND EXTRACT(second from VALUE) = 0 ''', ), ), ) class Float32Range(dbops.Range): def __init__(self) -> None: super().__init__( name=types.type_to_range_name_map[('float4',)], subtype=('float4',), ) class Float64Range(dbops.Range): def __init__(self) -> None: super().__init__( name=types.type_to_range_name_map[('float8',)], subtype=('float8',), subtype_diff=('float8mi',) ) class DatetimeRange(dbops.Range): def __init__(self) -> None: super().__init__( name=types.type_to_range_name_map[('edgedbt', 'timestamptz_t')], subtype=('edgedbt', 'timestamptz_t'), ) class LocalDatetimeRange(dbops.Range): def __init__(self) -> None: super().__init__( name=types.type_to_range_name_map[('edgedbt', 
'timestamp_t')], subtype=('edgedbt', 'timestamp_t'), ) class RangeToJsonFunction(trampoline.VersionedFunction): """Convert anyrange to a jsonb object.""" text = r''' SELECT CASE WHEN val IS NULL THEN NULL WHEN isempty(val) THEN jsonb_build_object('empty', true) ELSE to_jsonb(o) END FROM (SELECT lower(val) as lower, lower_inc(val) as inc_lower, upper(val) as upper, upper_inc(val) as inc_upper ) AS o ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'range_to_jsonb'), args=[ ('val', ('anyrange',)), ], returns=('jsonb',), volatility='immutable', language='sql', text=self.text, ) class MultiRangeToJsonFunction(trampoline.VersionedFunction): """Convert anymultirange to a jsonb object.""" text = r''' SELECT CASE WHEN val IS NULL THEN NULL WHEN isempty(val) THEN jsonb_build_array() ELSE ( SELECT jsonb_agg(edgedb_VER.range_to_jsonb(m.el)) FROM (SELECT unnest(val) AS el ) AS m ) END ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'multirange_to_jsonb'), args=[ ('val', ('anymultirange',)), ], returns=('jsonb',), volatility='immutable', language='sql', text=self.text, ) class RangeValidateFunction(trampoline.VersionedFunction): """Range constructor validation function.""" text = r''' SELECT CASE WHEN empty AND (lower IS DISTINCT FROM upper OR lower IS NOT NULL AND inc_upper AND inc_lower) THEN edgedb_VER.raise( NULL::bool, 'invalid_parameter_value', msg => 'conflicting arguments in range constructor:' || ' "empty" is `true` while the specified' || ' bounds suggest otherwise' ) ELSE empty END; ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'range_validate'), args=[ ('lower', ('anyelement',)), ('upper', ('anyelement',)), ('inc_lower', ('bool',)), ('inc_upper', ('bool',)), ('empty', ('bool',)), ], returns=('bool',), volatility='immutable', language='sql', text=self.text, ) class RangeUnpackLowerValidateFunction(trampoline.VersionedFunction): """Range unpack validation function.""" text = r''' SELECT CASE WHEN NOT 
isempty(range) THEN edgedb_VER.raise_on_null( lower(range), 'invalid_parameter_value', msg => 'cannot unpack an unbounded range' ) ELSE lower(range) END ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'range_lower_validate'), args=[ ('range', ('anyrange',)), ], returns=('anyelement',), volatility='immutable', language='sql', text=self.text, ) class RangeUnpackUpperValidateFunction(trampoline.VersionedFunction): """Range unpack validation function.""" text = r''' SELECT CASE WHEN NOT isempty(range) THEN edgedb_VER.raise_on_null( upper(range), 'invalid_parameter_value', msg => 'cannot unpack an unbounded range' ) ELSE upper(range) END ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'range_upper_validate'), args=[ ('range', ('anyrange',)), ], returns=('anyelement',), volatility='immutable', language='sql', text=self.text, ) class StrToConfigMemoryFunction(trampoline.VersionedFunction): """An implementation of std::str to cfg::memory cast.""" text = r''' SELECT (CASE WHEN m.v[1] IS NOT NULL AND m.v[2] IS NOT NULL THEN ( CASE WHEN m.v[2] = 'B' THEN m.v[1]::int8 WHEN m.v[2] = 'KiB' THEN m.v[1]::int8 * 1024 WHEN m.v[2] = 'MiB' THEN m.v[1]::int8 * 1024 * 1024 WHEN m.v[2] = 'GiB' THEN m.v[1]::int8 * 1024 * 1024 * 1024 WHEN m.v[2] = 'TiB' THEN m.v[1]::int8 * 1024 * 1024 * 1024 * 1024 WHEN m.v[2] = 'PiB' THEN m.v[1]::int8 * 1024 * 1024 * 1024 * 1024 * 1024 ELSE -- Won't happen but we still have a guard for -- completeness. 
edgedb_VER.raise( NULL::int8, 'invalid_parameter_value', msg => ( 'unsupported memory size unit "' || m.v[2] || '"' ) ) END ) ELSE CASE WHEN "val" = '0' THEN 0::int8 ELSE edgedb_VER.raise( NULL::int8, 'invalid_parameter_value', msg => ( 'unable to parse memory size "' || "val" || '"' ) ) END END)::edgedbt.memory_t FROM LATERAL ( SELECT regexp_match( "val", '^(\d+)([[:alpha:]]+)$') AS v ) AS m ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'str_to_cfg_memory'), args=[ ('val', ('text',)), ], returns=('edgedbt', 'memory_t'), strict=True, volatility='immutable', language='sql', text=self.text, ) class ConfigMemoryToStrFunction(trampoline.VersionedFunction): """An implementation of cfg::memory to std::str cast.""" text = r''' SELECT CASE WHEN "val" >= (1024::int8 * 1024 * 1024 * 1024 * 1024) AND "val" % (1024::int8 * 1024 * 1024 * 1024 * 1024) = 0 THEN ( "val" / (1024::int8 * 1024 * 1024 * 1024 * 1024) )::text || 'PiB' WHEN "val" >= (1024::int8 * 1024 * 1024 * 1024) AND "val" % (1024::int8 * 1024 * 1024 * 1024) = 0 THEN ( "val" / (1024::int8 * 1024 * 1024 * 1024) )::text || 'TiB' WHEN "val" >= (1024::int8 * 1024 * 1024) AND "val" % (1024::int8 * 1024 * 1024) = 0 THEN ("val" / (1024::int8 * 1024 * 1024))::text || 'GiB' WHEN "val" >= 1024::int8 * 1024 AND "val" % (1024::int8 * 1024) = 0 THEN ("val" / (1024::int8 * 1024))::text || 'MiB' WHEN "val" >= 1024 AND "val" % 1024 = 0 THEN ("val" / 1024::int8)::text || 'KiB' ELSE "val"::text || 'B' END ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'cfg_memory_to_str'), args=[ ('val', ('edgedbt', 'memory_t')), ], returns=('text',), volatility='immutable', language='sql', text=self.text, ) class AlterCurrentDatabaseSetString(trampoline.VersionedFunction): """Alter a PostgreSQL configuration parameter of the current database.""" text = ''' BEGIN EXECUTE 'ALTER DATABASE ' || quote_ident(current_database()) || ' SET ' || quote_ident(parameter) || ' = ' || coalesce(quote_literal(value), 'DEFAULT'); 
RETURN value; END; ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_alter_current_database_set'), args=[('parameter', ('text',)), ('value', ('text',))], returns=('text',), volatility='volatile', language='plpgsql', text=self.text, ) class AlterCurrentDatabaseSetStringArray(trampoline.VersionedFunction): """Alter a PostgreSQL configuration parameter of the current database.""" text = ''' BEGIN EXECUTE 'ALTER DATABASE ' || quote_ident(current_database()) || ' SET ' || quote_ident(parameter) || ' = ' || coalesce( (SELECT array_to_string(array_agg(quote_literal(q.v)), ',') FROM unnest(value) AS q(v) ), 'DEFAULT' ); RETURN value; END; ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_alter_current_database_set'), args=[ ('parameter', ('text',)), ('value', ('text[]',)), ], returns=('text[]',), volatility='volatile', language='plpgsql', text=self.text, ) class AlterCurrentDatabaseSetNonArray(trampoline.VersionedFunction): """Alter a PostgreSQL configuration parameter of the current database.""" text = ''' BEGIN EXECUTE 'ALTER DATABASE ' || quote_ident(current_database()) || ' SET ' || quote_ident(parameter) || ' = ' || coalesce(value::text, 'DEFAULT'); RETURN value; END; ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_alter_current_database_set'), args=[ ('parameter', ('text',)), ('value', ('anynonarray',)), ], returns=('anynonarray',), volatility='volatile', language='plpgsql', text=self.text, ) class AlterCurrentDatabaseSetArray(trampoline.VersionedFunction): """Alter a PostgreSQL configuration parameter of the current database.""" text = ''' BEGIN EXECUTE 'ALTER DATABASE ' || quote_ident(current_database()) || ' SET ' || quote_ident(parameter) || ' = ' || coalesce( (SELECT array_to_string(array_agg(q.v::text), ',') FROM unnest(value) AS q(v) ), 'DEFAULT' ); RETURN value; END; ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_alter_current_database_set'), args=[ ('parameter', ('text',)), 
('value', ('anyarray',)), ], returns=('anyarray',), volatility='volatile', language='plpgsql', text=self.text, ) class CopyDatabaseConfigs(trampoline.VersionedFunction): """Copy database configs from one database to the current one""" text = ''' SELECT edgedb_VER._alter_current_database_set( nameval.name, nameval.value) FROM pg_db_role_setting AS cfg, LATERAL unnest(cfg.setconfig) as cfg_set(s), LATERAL ( SELECT split_part(cfg_set.s, '=', 1) AS name, split_part(cfg_set.s, '=', 2) AS value ) AS nameval WHERE setdatabase = ( SELECT oid FROM pg_database WHERE datname = source_db ) AND setrole = 0; ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_copy_database_configs'), args=[('source_db', ('text',))], returns=('text',), volatility='volatile', text=self.text, ) class StrToBigint(trampoline.VersionedFunction): """Parse bigint from text.""" # The plpgsql exception handling nonsense is actually just so that # we can produce an exception that mentions edgedbt.bigint_t # instead of numeric, and thus produce the right user-facing # exception. As a nice side effect it is like twice as fast # as the previous code too. 
text = r''' DECLARE v numeric; BEGIN BEGIN v := val::numeric; EXCEPTION WHEN OTHERS THEN v := NULL; END; IF scale(v) = 0 THEN RETURN v::edgedbt.bigint_t; ELSE EXECUTE edgedb_VER.raise( NULL::numeric, 'invalid_text_representation', msg => ( 'invalid input syntax for type edgedbt.bigint_t: ' || quote_literal(val) ) ); END IF; END; ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'str_to_bigint'), args=[('val', ('text',))], returns=('edgedbt', 'bigint_t'), language='plpgsql', volatility='immutable', strict=True, text=self.text) class StrToDecimal(trampoline.VersionedFunction): """Parse decimal from text.""" text = r''' SELECT (CASE WHEN v.column1 != 'NaN' THEN v.column1 ELSE edgedb_VER.raise( NULL::numeric, 'invalid_text_representation', msg => ( 'invalid input syntax for type numeric: ' || quote_literal(val) ) ) END) FROM (VALUES ( val::numeric )) AS v ; ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'str_to_decimal'), args=[('val', ('text',))], returns=('numeric',), volatility='immutable', strict=True, text=self.text, ) class StrToInt64NoInline(trampoline.VersionedFunction): """String-to-int64 cast with noinline guard. Adding a LIMIT clause to the function statement makes it uninlinable due to the Postgres inlining heuristic looking for simple SELECT expressions only (i.e. no clauses.) This might need to change in the future if the heuristic changes. 
""" text = r''' SELECT "val"::bigint LIMIT 1 ; ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'str_to_int64_noinline'), args=[('val', ('text',))], returns=('bigint',), volatility='immutable', text=self.text, ) class StrToInt32NoInline(trampoline.VersionedFunction): """String-to-int32 cast with noinline guard.""" text = r''' SELECT "val"::int LIMIT 1 ; ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'str_to_int32_noinline'), args=[('val', ('text',))], returns=('int',), volatility='immutable', text=self.text, ) class StrToInt16NoInline(trampoline.VersionedFunction): """String-to-int16 cast with noinline guard.""" text = r''' SELECT "val"::smallint LIMIT 1 ; ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'str_to_int16_noinline'), args=[('val', ('text',))], returns=('smallint',), volatility='immutable', text=self.text, ) class StrToFloat64NoInline(trampoline.VersionedFunction): """String-to-float64 cast with noinline guard.""" text = r''' SELECT "val"::float8 LIMIT 1 ; ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'str_to_float64_noinline'), args=[('val', ('text',))], returns=('float8',), volatility='immutable', text=self.text, ) class StrToFloat32NoInline(trampoline.VersionedFunction): """String-to-float32 cast with noinline guard.""" text = r''' SELECT "val"::float4 LIMIT 1 ; ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'str_to_float32_noinline'), args=[('val', ('text',))], returns=('float4',), volatility='immutable', text=self.text, ) class GetBackendCapabilitiesFunction(trampoline.VersionedFunction): text = f''' SELECT (json ->> 'capabilities')::bigint FROM edgedbinstdata_VER.instdata WHERE key = 'backend_instance_params' ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'get_backend_capabilities'), args=[], returns=('bigint',), language='sql', volatility='stable', text=self.text, ) class GetBackendTenantIDFunction(trampoline.VersionedFunction): text = 
f''' SELECT (json ->> 'tenant_id')::text FROM edgedbinstdata_VER.instdata WHERE key = 'backend_instance_params' ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'get_backend_tenant_id'), args=[], returns=('text',), language='sql', volatility='stable', text=self.text, ) class GetDatabaseBackendNameFunction(trampoline.VersionedFunction): text = f''' SELECT CASE WHEN (edgedb_VER.get_backend_capabilities() & {int(params.BackendCapabilities.CREATE_DATABASE)}) != 0 THEN edgedb_VER.get_backend_tenant_id() || '_' || "db_name" ELSE current_database()::text END ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'get_database_backend_name'), args=[('db_name', ('text',))], returns=('text',), language='sql', volatility='stable', text=self.text, ) class GetDatabaseFrontendNameFunction(trampoline.VersionedFunction): text = f''' SELECT CASE WHEN (edgedb_VER.get_backend_capabilities() & {int(params.BackendCapabilities.CREATE_DATABASE)}) != 0 THEN substring(db_name, position('_' in db_name) + 1) ELSE 'main' END ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'get_database_frontend_name'), args=[('db_name', ('text',))], returns=('text',), language='sql', volatility='stable', text=self.text, ) class GetRoleBackendNameFunction(trampoline.VersionedFunction): text = f''' SELECT CASE WHEN (edgedb_VER.get_backend_capabilities() & {int(params.BackendCapabilities.CREATE_ROLE)}) != 0 THEN edgedb_VER.get_backend_tenant_id() || '_' || "role_name" ELSE current_user::text END ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'get_role_backend_name'), args=[('role_name', ('text',))], returns=('text',), language='sql', volatility='stable', text=self.text, ) class GetUserSequenceBackendNameFunction(trampoline.VersionedFunction): text = f""" SELECT 'edgedbpub', "sequence_type_id"::text || '_sequence' """ def __init__(self) -> None: super().__init__( name=('edgedb', 'get_user_sequence_backend_name'), args=[('sequence_type_id', 
('uuid',))], returns=('record',), language='sql', volatility='stable', text=self.text, ) class GetSequenceBackendNameFunction(trampoline.VersionedFunction): text = f''' SELECT (CASE WHEN edgedb_VER.get_name_module(st.name) = any(edgedb_VER.get_std_modules()) THEN 'edgedbstd' ELSE 'edgedbpub' END), "sequence_type_id"::text || '_sequence' FROM edgedb_VER."_SchemaScalarType" AS st WHERE st.id = "sequence_type_id" ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'get_sequence_backend_name'), args=[('sequence_type_id', ('uuid',))], returns=('record',), language='sql', volatility='stable', text=self.text, ) class GetStdModulesFunction(trampoline.VersionedFunction): text = f''' SELECT ARRAY[{",".join(ql(str(m)) for m in s_schema.STD_MODULES)}] ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'get_std_modules'), args=[], returns=('text[]',), language='sql', volatility='immutable', text=self.text, ) class GetObjectMetadata(trampoline.VersionedFunction): """Return Gel metadata associated with a backend object.""" text = ''' SELECT CASE WHEN substr(d, 1, char_length({prefix})) = {prefix} THEN substr(d, char_length({prefix}) + 1)::jsonb ELSE '{{}}'::jsonb END FROM obj_description("objoid", "objclass") AS d '''.format( prefix=f'E{ql(defines.EDGEDB_VISIBLE_METADATA_PREFIX)}', ) def __init__(self) -> None: super().__init__( name=('edgedb', 'obj_metadata'), args=[('objoid', ('oid',)), ('objclass', ('text',))], returns=('jsonb',), volatility='stable', text=self.text) class GetColumnMetadata(trampoline.VersionedFunction): """Return Gel metadata associated with a backend table column.""" text = ''' SELECT CASE WHEN substr(d, 1, char_length({prefix})) = {prefix} THEN substr(d, char_length({prefix}) + 1)::jsonb ELSE '{{}}'::jsonb END FROM col_description("tableoid", "column") AS d '''.format( prefix=f'E{ql(defines.EDGEDB_VISIBLE_METADATA_PREFIX)}', ) def __init__(self) -> None: super().__init__( name=('edgedb', 'col_metadata'), args=[('tableoid', ('oid',)),
('column', ('integer',))], returns=('jsonb',), volatility='stable', text=self.text) class GetSharedObjectMetadata(trampoline.VersionedFunction): """Return Gel metadata associated with a shared backend object.""" text = ''' SELECT CASE WHEN substr(d, 1, char_length({prefix})) = {prefix} THEN substr(d, char_length({prefix}) + 1)::jsonb ELSE '{{}}'::jsonb END FROM shobj_description("objoid", "objclass") AS d '''.format( prefix=f'E{ql(defines.EDGEDB_VISIBLE_METADATA_PREFIX)}', ) def __init__(self) -> None: super().__init__( name=('edgedb', 'shobj_metadata'), args=[('objoid', ('oid',)), ('objclass', ('text',))], returns=('jsonb',), volatility='stable', text=self.text) class GetDatabaseMetadataFunction(trampoline.VersionedFunction): """Return Gel metadata associated with a given database.""" text = f''' SELECT CASE WHEN "dbname" = {ql(defines.EDGEDB_SUPERUSER_DB)} OR (edgedb_VER.get_backend_capabilities() & {int(params.BackendCapabilities.CREATE_DATABASE)}) != 0 THEN edgedb_VER.shobj_metadata( (SELECT oid FROM pg_database WHERE datname = edgedb_VER.get_database_backend_name("dbname") ), 'pg_database' ) ELSE COALESCE( (SELECT json FROM edgedbinstdata_VER.instdata WHERE key = "dbname" || 'metadata' ), '{{}}'::jsonb ) END ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'get_database_metadata'), args=[('dbname', ('text',))], returns=('jsonb',), volatility='stable', text=self.text, ) class GetCurrentDatabaseFunction(trampoline.VersionedFunction): text = f''' SELECT CASE WHEN (edgedb_VER.get_backend_capabilities() & {int(params.BackendCapabilities.CREATE_DATABASE)}) != 0 THEN substr( current_database(), char_length(edgedb_VER.get_backend_tenant_id()) + 2 ) ELSE {ql(defines.EDGEDB_SUPERUSER_DB)} END ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'get_current_database'), args=[], returns=('text',), language='sql', volatility='stable', text=self.text, ) class RaiseNoticeFunction(trampoline.VersionedFunction): text = ''' BEGIN RAISE NOTICE USING
MESSAGE = "msg", DETAIL = COALESCE("detail", ''), HINT = COALESCE("hint", ''), COLUMN = COALESCE("column", ''), CONSTRAINT = COALESCE("constraint", ''), DATATYPE = COALESCE("datatype", ''), TABLE = COALESCE("table", ''), SCHEMA = COALESCE("schema", ''); RETURN "rtype"; END; ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'notice'), args=[ ('rtype', ('anyelement',)), ('msg', ('text',), "''"), ('detail', ('text',), "''"), ('hint', ('text',), "''"), ('column', ('text',), "''"), ('constraint', ('text',), "''"), ('datatype', ('text',), "''"), ('table', ('text',), "''"), ('schema', ('text',), "''"), ], returns=('anyelement',), # NOTE: The main reason why we don't want this function to be # immutable is that immutable functions can be # pre-evaluated by the query planner once if they have # constant arguments. This means that using this function # as the second argument in a COALESCE will raise a # notice regardless of whether the first argument is # NULL or not. volatility='stable', language='plpgsql', text=self.text, ) # edgedb.indirect_return() to be used to return values from # anonymous code blocks or other contexts that have no return # data channel. class IndirectReturnFunction(trampoline.VersionedFunction): text = """ SELECT edgedb_VER.notice( NULL::text, msg => 'edb:notice:indirect_return', detail => "value" ) """ def __init__(self) -> None: super().__init__( name=('edgedb', 'indirect_return'), args=[ ('value', ('text',)), ], returns=('text',), # NOTE: The main reason why we don't want this function to be # immutable is that immutable functions can be # pre-evaluated by the query planner once if they have # constant arguments. This means that using this function # as the second argument in a COALESCE will raise a # notice regardless of whether the first argument is # NULL or not. 
volatility='stable', language='sql', text=self.text, ) class RaiseExceptionFunction(trampoline.VersionedFunction): text = ''' BEGIN RAISE EXCEPTION USING ERRCODE = "exc", MESSAGE = "msg", DETAIL = COALESCE("detail", ''), HINT = COALESCE("hint", ''), COLUMN = COALESCE("column", ''), CONSTRAINT = COALESCE("constraint", ''), DATATYPE = COALESCE("datatype", ''), TABLE = COALESCE("table", ''), SCHEMA = COALESCE("schema", ''); RETURN "rtype"; END; ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'raise'), args=[ ('rtype', ('anyelement',)), ('exc', ('text',), "'raise_exception'"), ('msg', ('text',), "''"), ('detail', ('text',), "''"), ('hint', ('text',), "''"), ('column', ('text',), "''"), ('constraint', ('text',), "''"), ('datatype', ('text',), "''"), ('table', ('text',), "''"), ('schema', ('text',), "''"), ], returns=('anyelement',), # NOTE: The main reason why we don't want this function to be # immutable is that immutable functions can be # pre-evaluated by the query planner once if they have # constant arguments. This means that using this function # as the second argument in a COALESCE will raise an # exception regardless of whether the first argument is # NULL or not. 
volatility='stable', language='plpgsql', text=self.text, ) class RaiseExceptionOnNullFunction(trampoline.VersionedFunction): """Return the passed value or raise an exception if it's NULL.""" text = ''' SELECT coalesce( val, edgedb_VER.raise( val, exc, msg => msg, detail => detail, hint => hint, "column" => "column", "constraint" => "constraint", "datatype" => "datatype", "table" => "table", "schema" => "schema" ) ) ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'raise_on_null'), args=[ ('val', ('anyelement',)), ('exc', ('text',)), ('msg', ('text',)), ('detail', ('text',), "''"), ('hint', ('text',), "''"), ('column', ('text',), "''"), ('constraint', ('text',), "''"), ('datatype', ('text',), "''"), ('table', ('text',), "''"), ('schema', ('text',), "''"), ], returns=('anyelement',), # Same volatility as raise() volatility='stable', text=self.text, ) class RaiseExceptionOnNotNullFunction(trampoline.VersionedFunction): """Return the passed value or raise an exception if it's NOT NULL.""" text = ''' SELECT CASE WHEN val IS NULL THEN val ELSE edgedb_VER.raise( val, exc, msg => msg, detail => detail, hint => hint, "column" => "column", "constraint" => "constraint", "datatype" => "datatype", "table" => "table", "schema" => "schema" ) END ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'raise_on_not_null'), args=[ ('val', ('anyelement',)), ('exc', ('text',)), ('msg', ('text',)), ('detail', ('text',), "''"), ('hint', ('text',), "''"), ('column', ('text',), "''"), ('constraint', ('text',), "''"), ('datatype', ('text',), "''"), ('table', ('text',), "''"), ('schema', ('text',), "''"), ], returns=('anyelement',), # Same volatility as raise() volatility='stable', text=self.text, ) class RaiseExceptionOnEmptyStringFunction(trampoline.VersionedFunction): """Return the passed string or raise an exception if it's empty.""" text = ''' SELECT CASE WHEN edgedb_VER._length(val) = 0 THEN edgedb_VER.raise(val, exc, msg => msg, detail => detail) ELSE val 
END; ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'raise_on_empty'), args=[ ('val', ('anyelement',)), ('exc', ('text',)), ('msg', ('text',)), ('detail', ('text',), "''"), ], returns=('anyelement',), # Same volatility as raise() volatility='stable', text=self.text, ) class AssertJSONTypeFunction(trampoline.VersionedFunction): """Assert that the JSON type matches what is expected.""" text = ''' SELECT CASE WHEN array_position(typenames, jsonb_typeof(val)) IS NULL THEN edgedb_VER.raise( NULL::jsonb, 'wrong_object_type', msg => coalesce( msg, format( 'expected JSON %s; got JSON %s: %s', array_to_string(typenames, ' or '), coalesce(jsonb_typeof(val), 'UNKNOWN'), val::text ) ), detail => detail ) ELSE val END ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'jsonb_assert_type'), args=[ ('val', ('jsonb',)), ('typenames', ('text[]',)), ('msg', ('text',), 'NULL'), ('detail', ('text',), "''"), ], returns=('jsonb',), # Max volatility of raise() and array_to_string() (stable) volatility='stable', text=self.text, ) class ExtractJSONScalarFunction(trampoline.VersionedFunction): """Convert a given JSON scalar value into a text value.""" text = ''' SELECT (to_jsonb(ARRAY[ edgedb_VER.jsonb_assert_type( coalesce(val, 'null'::jsonb), ARRAY[json_typename, 'null'], msg => msg, detail => detail ) ])->>0) ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'jsonb_extract_scalar'), args=[ ('val', ('jsonb',)), ('json_typename', ('text',)), ('msg', ('text',), 'NULL'), ('detail', ('text',), "''"), ], returns=('text',), volatility='stable', wrapper_volatility='immutable', text=self.text, ) class GetSchemaObjectNameFunction(trampoline.VersionedFunction): text = ''' SELECT coalesce( (SELECT name FROM edgedb_VER."_SchemaObject" WHERE id = type::uuid), edgedb_VER.raise( NULL::text, msg => 'resolve_type_name: unknown type: "' || type || '"' ) ) ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_get_schema_object_name'), 
args=[('type', ('uuid',))], returns=('text',), # Max volatility of raise() and a SELECT from a # table (stable). volatility='stable', text=self.text, strict=True, ) # We create this version first (since it is used by the stdlib), and # then replace it with the real version later. class ApproximateCountDummy(trampoline.VersionedFunction): text = ''' SELECT 0 ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'approximate_count'), args=[ ('ignore_subtypes', ('bool',)), ('type', ('uuid',)), ('type_type', ('uuid',), "NULL"), ], returns=('bigint',), volatility='stable', text=self.text, ) class ApproximateCount(trampoline.VersionedFunction): text = ''' SELECT coalesce(sum(reltuples::bigint), 0) AS estimate FROM pg_class pc WHERE pc.relname IN ( SELECT oa.source::text FROM edgedb_VER."_SchemaObjectType__ancestors" oa WHERE oa.target = type AND not ignore_subtypes UNION select type::text ) AND pc.reltuples >= 0; ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'approximate_count'), args=[ ('ignore_subtypes', ('bool',)), ('type', ('uuid',)), ('type_type', ('uuid',), "NULL"), ], returns=('bigint',), volatility='stable', text=self.text, ) class IssubclassFunction(trampoline.VersionedFunction): text = ''' SELECT clsid = any(classes) OR ( SELECT classes && q.ancestors FROM (SELECT array_agg(o.target) AS ancestors FROM edgedb_VER."_SchemaInheritingObject__ancestors" o WHERE o.source = clsid ) AS q ); ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'issubclass'), args=[('clsid', 'uuid'), ('classes', 'uuid[]')], returns='bool', volatility='stable', text=self.__class__.text) class IssubclassFunction2(trampoline.VersionedFunction): text = ''' SELECT clsid = pclsid OR ( SELECT pclsid IN ( SELECT o.target FROM edgedb_VER."_SchemaInheritingObject__ancestors" o WHERE o.source = clsid ) ); ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'issubclass'), args=[('clsid', 'uuid'), ('pclsid', 'uuid')], returns='bool', 
volatility='stable', text=self.__class__.text) class NormalizeNameFunction(trampoline.VersionedFunction): text = ''' SELECT CASE WHEN strpos(name, '@') = 0 THEN name ELSE CASE WHEN strpos(name, '::') = 0 THEN replace(split_part(name, '@', 1), '|', '::') ELSE replace( split_part( -- "reverse" calls are to emulate "rsplit" reverse(split_part(reverse(name), '::', 1)), '@', 1), '|', '::') END END; ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'shortname_from_fullname'), args=[('name', 'text')], returns='text', volatility='immutable', language='sql', text=self.__class__.text) class GetNameModuleFunction(trampoline.VersionedFunction): text = ''' SELECT reverse(split_part(reverse("name"), '::', 1)) ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'get_name_module'), args=[('name', 'text')], returns='text', volatility='immutable', language='sql', text=self.__class__.text) class NullIfArrayNullsFunction(trampoline.VersionedFunction): """Check if array contains NULLs and if so, return NULL.""" def __init__(self) -> None: super().__init__( name=('edgedb', '_nullif_array_nulls'), args=[('a', 'anyarray')], returns='anyarray', volatility='stable', language='sql', text=''' SELECT CASE WHEN array_position(a, NULL) IS NULL THEN a ELSE NULL END ''') class NormalizeArrayIndexFunction(trampoline.VersionedFunction): """Convert an EdgeQL index to SQL index.""" text = ''' SELECT CASE WHEN index > (2147483647-1) OR index < -2147483648 THEN NULL WHEN index < 0 THEN length + index::int + 1 ELSE index::int + 1 END ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_normalize_array_index'), args=[('index', ('bigint',)), ('length', ('int',))], returns=('int',), volatility='immutable', text=self.text, ) class NormalizeArraySliceIndexFunction(trampoline.VersionedFunction): """Convert an EdgeQL index to SQL index (for slices)""" text = ''' SELECT GREATEST(0, LEAST(2147483647, CASE WHEN index < 0 THEN length::bigint + index + 1 ELSE index + 1 
END )) WHERE index IS NOT NULL ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_normalize_array_slice_index'), args=[('index', ('bigint',)), ('length', ('int',))], returns=('int',), volatility='immutable', text=self.text, ) class IntOrNullFunction(trampoline.VersionedFunction): """ Convert bigint to int. If it does not fit, return NULL. """ text = """ SELECT CASE WHEN val <= 2147483647 AND val >= -2147483648 THEN val ELSE NULL END """ def __init__(self) -> None: super().__init__( name=("edgedb", "_int_or_null"), args=[("val", ("bigint",))], returns=("int",), volatility="immutable", strict=True, text=self.text, ) class ArrayIndexWithBoundsFunction(trampoline.VersionedFunction): """Get an array element or raise an out-of-bounds exception.""" text = ''' SELECT CASE WHEN val IS NULL THEN NULL ELSE edgedb_VER.raise_on_null( val[edgedb_VER._normalize_array_index( index, array_upper(val, 1))], 'array_subscript_error', msg => 'array index ' || index::text || ' is out of bounds', detail => detail ) END ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_index'), args=[('val', ('anyarray',)), ('index', ('bigint',)), ('detail', ('text',))], returns=('anyelement',), # Min volatility of exception helpers and pg_typeof is 'stable', # but for all practical purposes, we can assume 'immutable' volatility='stable', wrapper_volatility='immutable', text=self.text, ) class ArraySliceFunction(trampoline.VersionedFunction): """Get an array slice.""" # This function is also inlined in expr.py#_inline_array_slicing. # Known bug: if array has 2G elements and both bounds are overflowing, # this will return last element instead of an empty array. 
text = """ SELECT val[ edgedb_VER._normalize_array_slice_index(start, cardinality(val)) : edgedb_VER._normalize_array_slice_index(stop, cardinality(val)) - 1 ] """ def __init__(self) -> None: super().__init__( name=("edgedb", "_slice"), args=[ ("val", ("anyarray",)), ("start", ("bigint",)), ("stop", ("bigint",)), ], returns=("anyarray",), volatility="stable", wrapper_volatility='immutable', text=self.text, ) class StringIndexWithBoundsFunction(trampoline.VersionedFunction): """Get a string character or raise an out-of-bounds exception.""" text = ''' SELECT edgedb_VER.raise_on_empty( CASE WHEN pg_index IS NULL THEN '' ELSE substr("val", pg_index, 1) END, 'invalid_parameter_value', "typename" || ' index ' || "index"::text || ' is out of bounds', "detail" ) FROM ( SELECT ( edgedb_VER._normalize_array_index("index", char_length("val")) ) as pg_index ) t ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_index'), args=[ ('val', ('text',)), ('index', ('bigint',)), ('detail', ('text',)), ('typename', ('text',), "'string'"), ], returns=('text',), # Min volatility of exception helpers and pg_typeof is 'stable', # but for all practical purposes, we can assume 'immutable' volatility='stable', wrapper_volatility='immutable', text=self.text, ) class BytesIndexWithBoundsFunction(trampoline.VersionedFunction): """Get a bytes character or raise an out-of-bounds exception.""" text = ''' SELECT edgedb_VER.raise_on_empty( CASE WHEN pg_index IS NULL THEN ''::bytea ELSE substr("val", pg_index, 1) END, 'invalid_parameter_value', 'byte string index ' || "index"::text || ' is out of bounds', "detail" ) FROM ( SELECT ( edgedb_VER._normalize_array_index("index", length("val")) ) as pg_index ) t ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_index'), args=[ ('val', ('bytea',)), ('index', ('bigint',)), ('detail', ('text',)), ], returns=('bytea',), # Min volatility of exception helpers and pg_typeof is 'stable', # but for all practical purposes, we can 
assume 'immutable' volatility='stable', wrapper_volatility='immutable', text=self.text, ) class SubstrProxyFunction(trampoline.VersionedFunction): """Same as substr, but interpret negative length as 0 instead.""" text = r""" SELECT CASE WHEN length < 0 THEN '' ELSE substr(val, start::int, length) END """ def __init__(self) -> None: super().__init__( name=("edgedb", "_substr"), args=[ ("val", ("anyelement",)), ("start", ("int",)), ("length", ("int",)), ], returns=("anyelement",), volatility="immutable", strict=True, text=self.text, ) class LengthStringProxyFunction(trampoline.VersionedFunction): """Return the length of a string in characters (proxy for char_length).""" text = r''' SELECT char_length(val) ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_length'), args=[('val', ('text',))], returns=('int',), volatility='immutable', strict=True, text=self.text) class LengthBytesProxyFunction(trampoline.VersionedFunction): """Return the length of a byte string in bytes (proxy for length).""" text = r''' SELECT length(val) ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_length'), args=[('val', ('bytea',))], returns=('int',), volatility='immutable', strict=True, text=self.text) class StringSliceImplFunction(trampoline.VersionedFunction): """Get a slice of a string or bytes value.""" text = r""" SELECT edgedb_VER._substr( val, pg_start, pg_end - pg_start ) FROM (SELECT edgedb_VER._normalize_array_slice_index( start, edgedb_VER._length(val) ) as pg_start, edgedb_VER._normalize_array_slice_index( stop, edgedb_VER._length(val) ) as pg_end ) t """ def __init__(self) -> None: super().__init__( name=("edgedb", "_str_slice"), args=[ ("val", ("anyelement",)), ("start", ("bigint",)), ("stop", ("bigint",)), ], returns=("anyelement",), volatility="immutable", text=self.text, ) class StringSliceFunction(trampoline.VersionedFunction): """Get a string slice.""" text = r''' SELECT edgedb_VER._str_slice(val, start, stop) ''' def __init__(self) -> None: super().__init__(
name=('edgedb', '_slice'), args=[ ('val', ('text',)), ('start', ('bigint',)), ('stop', ('bigint',)), ], returns=('text',), volatility='stable', text=self.text) class BytesSliceFunction(trampoline.VersionedFunction): """Get a bytes slice.""" text = r''' SELECT edgedb_VER._str_slice(val, start, stop) ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_slice'), args=[ ('val', ('bytea',)), ('start', ('bigint',)), ('stop', ('bigint',)), ], returns=('bytea',), volatility='stable', text=self.text) class JSONIndexByTextFunction(trampoline.VersionedFunction): """Get a JSON element by text index or raise an exception.""" text = r''' SELECT CASE jsonb_typeof(val) WHEN 'object' THEN ( edgedb_VER.raise_on_null( val -> index, 'invalid_parameter_value', msg => ( 'JSON index ' || quote_literal(index) || ' is out of bounds' ), detail => detail ) ) WHEN 'array' THEN ( edgedb_VER.raise( NULL::jsonb, 'wrong_object_type', msg => ( 'cannot index JSON ' || jsonb_typeof(val) || ' by ' || pg_typeof(index)::text ), detail => detail ) ) ELSE edgedb_VER.raise( NULL::jsonb, 'wrong_object_type', msg => ( 'cannot index JSON ' || coalesce(jsonb_typeof(val), 'UNKNOWN') ), detail => ( '{"hint":"Retrieving an element by a string index ' || 'is only available for JSON objects."}' ) ) END ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_index'), args=[ ('val', ('jsonb',)), ('index', ('text',)), ('detail', ('text',), "''"), ], returns=('jsonb',), # Min volatility of exception helpers and pg_typeof is 'stable', # but for all practical purposes, we can assume 'immutable' volatility='stable', wrapper_volatility='immutable', strict=True, text=self.text, ) class JSONIndexByIntFunction(trampoline.VersionedFunction): """Get a JSON element by int index or raise an exception.""" text = r''' SELECT CASE jsonb_typeof(val) WHEN 'object' THEN ( edgedb_VER.raise( NULL::jsonb, 'wrong_object_type', msg => ( 'cannot index JSON ' || jsonb_typeof(val) || ' by ' || pg_typeof(index)::text ), detail =>
detail ) ) WHEN 'array' THEN ( edgedb_VER.raise_on_null( val -> edgedb_VER._int_or_null(index), 'invalid_parameter_value', msg => 'JSON index ' || index::text || ' is out of bounds', detail => detail ) ) WHEN 'string' THEN ( to_jsonb(edgedb_VER._index( val#>>'{}', index, detail, 'JSON' )) ) ELSE edgedb_VER.raise( NULL::jsonb, 'wrong_object_type', msg => ( 'cannot index JSON ' || coalesce(jsonb_typeof(val), 'UNKNOWN') ), detail => ( '{"hint":"Retrieving an element by an integer index ' || 'is only available for JSON arrays and strings."}' ) ) END ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_index'), args=[ ('val', ('jsonb',)), ('index', ('bigint',)), ('detail', ('text',), "''"), ], returns=('jsonb',), # Min volatility of exception helpers and pg_typeof is 'stable', # but for all practical purposes, we can assume 'immutable' volatility='stable', wrapper_volatility='immutable', strict=True, text=self.text, ) class JSONSliceFunction(trampoline.VersionedFunction): """Get a JSON array slice.""" text = r""" SELECT CASE WHEN val IS NULL THEN NULL WHEN jsonb_typeof(val) = 'array' THEN ( to_jsonb(edgedb_VER._slice( ( SELECT coalesce(array_agg(value), '{}'::jsonb[]) FROM jsonb_array_elements(val) ), start, stop )) ) WHEN jsonb_typeof(val) = 'string' THEN ( to_jsonb(edgedb_VER._slice(val#>>'{}', start, stop)) ) ELSE edgedb_VER.raise( NULL::jsonb, 'wrong_object_type', msg => ( 'cannot slice JSON ' || coalesce(jsonb_typeof(val), 'UNKNOWN') ), detail => ( '{"hint":"Slicing is only available for JSON arrays' || ' and strings."}' ) ) END """ def __init__(self) -> None: super().__init__( name=("edgedb", "_slice"), args=[ ("val", ("jsonb",)), ("start", ("bigint",)), ("stop", ("bigint",)), ], returns=("jsonb",), # Min volatility of to_jsonb is 'stable', # but for all practical purposes, we can assume 'immutable' volatility="stable", wrapper_volatility='immutable', text=self.text, ) # We need custom casting functions for various datetime scalars in # order to 
enforce correctness w.r.t. local vs time-zone-aware # datetime. Postgres does a lot of magic and guessing for time zones # and generally will accept text with or without time zone for any # particular flavor of timestamp. In order to guarantee that we can # detect time zones we restrict the inputs to ISO8601 format. # # See issue #740. class DatetimeInFunction(trampoline.VersionedFunction): """Cast text into timestamptz using ISO8601 spec.""" text = r''' SELECT CASE WHEN val !~ ( '^\s*(' || '(\d{4}-\d{2}-\d{2}|\d{8})' || '[ tT]' || '(\d{2}(:\d{2}(:\d{2}(\.\d+)?)?)?|\d{2,6}(\.\d+)?)' || '([zZ]|[-+](\d{2,4}|\d{2}:\d{2}))' || ')\s*$' ) THEN edgedb_VER.raise( NULL::edgedbt.timestamptz_t, 'invalid_datetime_format', msg => ( 'invalid input syntax for type timestamptz: ' || quote_literal(val) ), detail => ( '{"hint":"Please use ISO8601 format. Example: ' || '2010-12-27T23:59:59-07:00. Alternatively ' || '\"to_datetime\" function provides custom ' || 'formatting options."}' ) ) ELSE val::edgedbt.timestamptz_t END; ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'datetime_in'), args=[('val', ('text',))], returns=('edgedbt', 'timestamptz_t'), # Same volatility as raise() (stable) volatility='stable', text=self.text) class DurationInFunction(trampoline.VersionedFunction): """Cast text into duration, ensuring there are no day, month, or year units.""" text = r''' SELECT CASE WHEN EXTRACT(MONTH FROM v.column1) != 0 OR EXTRACT(YEAR FROM v.column1) != 0 OR EXTRACT(DAY FROM v.column1) != 0 THEN edgedb_VER.raise( NULL::edgedbt.duration_t, 'invalid_datetime_format', msg => ( 'invalid input syntax for type std::duration: ' || quote_literal(val) ), detail => ( '{"hint":"Day, month and year units cannot be used ' || 'for std::duration."}' ) ) ELSE v.column1::edgedbt.duration_t END FROM (VALUES ( val::interval )) AS v ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'duration_in'), args=[('val', ('text',))], returns=('edgedbt', 'duration_t'),
volatility='immutable', text=self.text, ) class DateDurationInFunction(trampoline.VersionedFunction): """ Cast text into date_duration, ensuring there is no unit smaller than days. """ text = r''' SELECT CASE WHEN EXTRACT(HOUR FROM v.column1) != 0 OR EXTRACT(MINUTE FROM v.column1) != 0 OR EXTRACT(SECOND FROM v.column1) != 0 THEN edgedb_VER.raise( NULL::edgedbt.date_duration_t, 'invalid_datetime_format', msg => ( 'invalid input syntax for type ' || 'std::cal::date_duration: ' || quote_literal(val) ), detail => ( '{"hint":"Units smaller than days cannot be used ' || 'for std::cal::date_duration."}' ) ) ELSE v.column1::edgedbt.date_duration_t END FROM (VALUES ( val::interval )) AS v ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'date_duration_in'), args=[('val', ('text',))], returns=('edgedbt', 'date_duration_t'), volatility='immutable', text=self.text, ) class LocalDatetimeInFunction(trampoline.VersionedFunction): """Cast text into timestamp using ISO8601 spec.""" text = r''' SELECT CASE WHEN val !~ ( '^\s*(' || '(\d{4}-\d{2}-\d{2}|\d{8})' || '[ tT]' || '(\d{2}(:\d{2}(:\d{2}(\.\d+)?)?)?|\d{2,6}(\.\d+)?)' || ')\s*$' ) THEN edgedb_VER.raise( NULL::edgedbt.timestamp_t, 'invalid_datetime_format', msg => ( 'invalid input syntax for type timestamp: ' || quote_literal(val) ), detail => ( '{"hint":"Please use ISO8601 format. 
Example ' || '2010-04-18T09:27:00 Alternatively ' || '\"to_local_datetime\" function provides custom ' || 'formatting options."}' ) ) ELSE val::edgedbt.timestamp_t END; ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'local_datetime_in'), args=[('val', ('text',))], returns=('edgedbt', 'timestamp_t'), volatility='immutable', text=self.text) class LocalDateInFunction(trampoline.VersionedFunction): """Cast text into date using ISO8601 spec.""" text = r''' SELECT CASE WHEN val !~ ( '^\s*(' || '(\d{4}-\d{2}-\d{2}|\d{8})' || ')\s*$' ) THEN edgedb_VER.raise( NULL::edgedbt.date_t, 'invalid_datetime_format', msg => ( 'invalid input syntax for type date: ' || quote_literal(val) ), detail => ( '{"hint":"Please use ISO8601 format. Example ' || '2010-04-18 Alternatively ' || '\"to_local_date\" function provides custom ' || 'formatting options."}' ) ) ELSE val::edgedbt.date_t END; ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'local_date_in'), args=[('val', ('text',))], returns=('edgedbt', 'date_t'), volatility='immutable', text=self.text) class LocalTimeInFunction(trampoline.VersionedFunction): """Cast text into time using ISO8601 spec.""" text = r''' SELECT CASE WHEN date_part('hour', x.t) = 24 THEN edgedb_VER.raise( NULL::time, 'invalid_datetime_format', msg => ( 'std::cal::local_time field value out of range: ' || quote_literal(val) ) ) ELSE x.t END FROM ( SELECT CASE WHEN val !~ ('^\s*(' || '(\d{2}(:\d{2}(:\d{2}(\.\d+)?)?)?|\d{2,6}(\.\d+)?)' || ')\s*$') THEN edgedb_VER.raise( NULL::time, 'invalid_datetime_format', msg => ( 'invalid input syntax for type time: ' || quote_literal(val) ), detail => ( '{"hint":"Please use ISO8601 format. 
Examples: ' || '18:43:27 or 18:43 Alternatively ' || '\"to_local_time\" function provides custom ' || 'formatting options."}' ) ) ELSE val::time END as t ) as x; ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'local_time_in'), args=[('val', ('text',))], returns=('time',), volatility='immutable', text=self.text, ) class ToTimestampTZCheck(trampoline.VersionedFunction): """Checks if the original text has time zone or not.""" # What are we trying to mitigate? # We're trying to detect that when we're casting to datetime the # time zone is in fact present in the input. It is a problem if # it's not since then one gets assigned implicitly based on the # server settings. # # It is insufficient to rely on the presence of TZH in the format # string, since `to_timestamp` will happily ignore the missing # time-zone in the input anyway. So in order to tell whether the # input string contained a time zone that was in fact parsed we # employ the following trick: # # If the time zone is in the input then it is unambiguous and the # parsed value will not depend on the current server time zone. # However, if the time zone was omitted, then the parsed value # will default to the server time zone. This implies that if # changing the server time zone for the same input string affects # the parsed value, the input string itself didn't contain a time # zone. text = r''' DECLARE result timestamptz; chk timestamptz; msg text; BEGIN result := to_timestamp(val, fmt); PERFORM set_config('TimeZone', 'America/Toronto', true); chk := to_timestamp(val, fmt); -- We're deliberately not doing any save/restore because -- the server MUST be in UTC. In fact, this check relies -- on it. 
PERFORM set_config('TimeZone', 'UTC', true); IF hastz THEN msg := 'missing required'; ELSE msg := 'unexpected'; END IF; IF (result = chk) != hastz THEN RAISE EXCEPTION USING ERRCODE = 'invalid_datetime_format', MESSAGE = msg || ' time zone in input ' || quote_literal(val), DETAIL = ''; END IF; RETURN result::edgedbt.timestamptz_t; END; ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_to_timestamptz_check'), args=[('val', ('text',)), ('fmt', ('text',)), ('hastz', ('bool',))], returns=('edgedbt', 'timestamptz_t'), # We're relying on changing settings, so it's volatile. volatility='volatile', language='plpgsql', text=self.text) class ToDatetimeFunction(trampoline.VersionedFunction): """Convert text into timestamptz using a formatting spec.""" # NOTE that if only the TZM (minutes) are mentioned it is not # enough for a valid time zone definition text = r''' SELECT CASE WHEN fmt !~ ( '^(' || '("([^"\\]|\\.)*")|' || '([^"]+)' || ')*(TZH).*$' ) THEN edgedb_VER.raise( NULL::edgedbt.timestamptz_t, 'invalid_datetime_format', msg => ( 'missing required time zone in format: ' || quote_literal(fmt) ), detail => ( $h${"hint":"Use one or both of the following: $h$ || $h$'TZH', 'TZM'"}$h$ ) ) ELSE edgedb_VER._to_timestamptz_check(val, fmt, true) END; ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'to_datetime'), args=[('val', ('text',)), ('fmt', ('text',))], returns=('edgedbt', 'timestamptz_t'), # Same as _to_timestamptz_check. volatility='volatile', text=self.text) class ToLocalDatetimeFunction(trampoline.VersionedFunction): """Convert text into timestamp using a formatting spec.""" # NOTE time zone should not be mentioned at all. 
text = r''' SELECT CASE WHEN fmt ~ ( '^(' || '("([^"\\]|\\.)*")|' || '([^"]+)' || ')*(TZH|TZM).*$' ) THEN edgedb_VER.raise( NULL::edgedbt.timestamp_t, 'invalid_datetime_format', msg => ( 'unexpected time zone in format: ' || quote_literal(fmt) ) ) ELSE edgedb_VER._to_timestamptz_check(val, fmt, false) ::edgedbt.timestamp_t END; ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'to_local_datetime'), args=[('val', ('text',)), ('fmt', ('text',))], returns=('edgedbt', 'timestamp_t'), # Same as _to_timestamptz_check. volatility='volatile', text=self.text) class StrToBool(trampoline.VersionedFunction): """Parse bool from text.""" # We first try to match case-insensitive "true|false" at all. On # null, we raise an exception. But otherwise we know that we have # an array of matches. The first element matching "true" and # second - "false". So the boolean value is then "true" if the # second array element is NULL and false otherwise. text = r''' SELECT ( coalesce( regexp_match(val, '^\s*(?:(true)|(false))\s*$', 'i')::text[], edgedb_VER.raise( NULL::text[], 'invalid_text_representation', msg => 'invalid input syntax for type bool: ' || quote_literal(val) ) ) )[2] IS NULL; ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'str_to_bool'), args=[('val', ('text',))], returns=('bool',), strict=True, # Stable because it's raising exceptions. 
volatility='stable', text=self.text) class QuoteLiteralFunction(trampoline.VersionedFunction): """Encode string as edgeql literal quoted string""" text = r''' SELECT concat('\'', replace( replace(val, '\\', '\\\\'), '\'', '\\\''), '\'') ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'quote_literal'), args=[('val', ('text',))], returns=('str',), volatility='immutable', text=self.text) class QuoteIdentFunction(trampoline.VersionedFunction): """Quote ident function.""" # TODO do not quote valid identifiers unless they are reserved text = r''' SELECT concat('`', replace(val, '`', '``'), '`') ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'quote_ident'), args=[('val', ('text',))], returns=('text',), volatility='immutable', text=self.text, ) class QuoteNameFunction(trampoline.VersionedFunction): text = r""" SELECT string_agg(edgedb_VER.quote_ident(np), '::') FROM unnest(string_to_array("name", '::')) AS np """ def __init__(self) -> None: super().__init__( name=('edgedb', 'quote_name'), args=[('name', ('text',))], returns=('text',), volatility='immutable', text=self.text, ) class DescribeRolesAsDDLFunctionForwardDecl(trampoline.VersionedFunction): """Forward declaration for _describe_roles_as_ddl""" def __init__(self) -> None: super().__init__( name=('edgedb', '_describe_roles_as_ddl'), args=[], returns=('text'), # Stable because it's raising exceptions. 
volatility='stable', text='SELECT NULL::text', ) class DescribeRolesAsDDLFunction(trampoline.VersionedFunction): """Describe roles as DDL""" def __init__(self, schema: s_schema.Schema) -> None: role_obj = schema.get("sys::Role", type=s_objtypes.ObjectType) roles = _schema_alias_view_name(schema, role_obj) roles = (common.maybe_versioned_schema(roles[0]), roles[1]) member_of = role_obj.getptr(schema, s_name.UnqualName('member_of')) members = _schema_alias_view_name(schema, member_of) members = (common.maybe_versioned_schema(members[0]), members[1]) permissions_ptr = role_obj.getptr( schema, s_name.UnqualName('permissions'), type=s_props.Property ) permissions = _schema_alias_view_name(schema, permissions_ptr) permissions = ( common.maybe_versioned_schema(permissions[0]), permissions[1] ) branches_ptr = role_obj.getptr( schema, s_name.UnqualName('branches'), type=s_props.Property ) branches = _schema_alias_view_name(schema, branches_ptr) branches = ( common.maybe_versioned_schema(branches[0]), branches[1] ) super_col = ptr_col_name(schema, role_obj, 'superuser') name_col = ptr_col_name(schema, role_obj, 'name') pass_col = ptr_col_name(schema, role_obj, 'password') pg_pol_col = ptr_col_name( schema, role_obj, 'apply_access_policies_pg_default', ) qi_superuser = qlquote.quote_ident(defines.EDGEDB_SUPERUSER) text = f""" WITH RECURSIVE dependencies AS ( SELECT r.id AS id, m.target AS parent FROM {q(*roles)} r LEFT OUTER JOIN {q(*members)} m ON r.id = m.source ), roles_with_depths(id, depth) AS ( SELECT id, 0 FROM dependencies WHERE parent IS NULL UNION ALL SELECT dependencies.id, roles_with_depths.depth + 1 FROM dependencies INNER JOIN roles_with_depths ON dependencies.parent = roles_with_depths.id ), ordered_roles AS ( SELECT id, max(depth) FROM roles_with_depths GROUP BY id ORDER BY max(depth) ASC ) SELECT coalesce(string_agg( CASE WHEN role.{qi(name_col)} = {ql(defines.EDGEDB_SUPERUSER)} THEN NULLIF( concat( 'ALTER ROLE {qi_superuser} {{ ', NULLIF( (SELECT concat( 
'EXTENDING ', string_agg( edgedb_VER.quote_ident( parent.{qi(name_col)} ), ', ' ), '; ' ) FROM {q(*members)} member INNER JOIN {q(*roles)} parent ON parent.id = member.target WHERE member.source = role.id ), 'EXTENDING ; ' ), (CASE WHEN role.{qi(pass_col)} IS NOT NULL THEN concat( 'SET password_hash := ', quote_literal(role.{qi(pass_col)}), '; ' ) ELSE NULL END ), (CASE WHEN role.{qi(pg_pol_col)} IS NOT NULL THEN concat( 'SET apply_access_policies_pg_default ', ':= ', role.{qi(pg_pol_col)}::text, '; ' ) ELSE NULL END ), NULLIF ( concat( 'SET permissions := {{ ', ( SELECT string_agg( permissions.target, ', ' ) FROM {q(*permissions)} permissions WHERE permissions.source = role.id ), ' }}; ' ), 'SET permissions := {{ }}; ' ), NULLIF ( concat( 'SET branches := {{ ', ( SELECT string_agg( quote_literal(branches.target), ', ' ) FROM {q(*branches)} branches WHERE branches.source = role.id ), ' }}; ' ), 'SET branches := {{ ''*'' }}; ' ), '}};' ), 'ALTER ROLE {qi_superuser} {{ }};' ) ELSE concat( 'CREATE ', (CASE WHEN role.{qi(super_col)} THEN 'SUPERUSER ' ELSE NULL END ), 'ROLE ', edgedb_VER.quote_ident(role.{qi(name_col)}), NULLIF( ( SELECT concat( ' EXTENDING ', string_agg( edgedb_VER.quote_ident( parent.{qi(name_col)} ), ', ' ) ) FROM {q(*members)} member INNER JOIN {q(*roles)} parent ON parent.id = member.target WHERE member.source = role.id ), ' EXTENDING ' ), NULLIF( concat( ' {{ ', (CASE WHEN role.{qi(pass_col)} IS NOT NULL THEN concat( 'SET password_hash := ', quote_literal(role.{qi(pass_col)}), '; ' ) ELSE NULL END ), (CASE WHEN role.{qi(pg_pol_col)} IS NOT NULL THEN concat( 'SET ', 'apply_access_policies_pg_default ', ':= ', role.{qi(pg_pol_col)}::text, '; ' ) ELSE NULL END ), NULLIF ( concat( 'SET permissions := {{ ', ( SELECT string_agg( permissions.target, ', ' ) FROM {q(*permissions)} permissions WHERE permissions.source = role.id ), ' }}; ' ), 'SET permissions := {{ }}; ' ), NULLIF ( concat( 'SET branches := {{ ', ( SELECT string_agg( quote_literal( 
branches.target), ', ' ) FROM {q(*branches)} branches WHERE branches.source = role.id ), ' }}; ' ), 'SET branches := {{ ''*'' }}; ' ), '}}' ), ' {{ }}' ), ';' ) END, '\n' ), '') str FROM ordered_roles JOIN {q(*roles)} role ON role.id = ordered_roles.id """ super().__init__( name=('edgedb', '_describe_roles_as_ddl'), args=[], returns=('text'), # Stable because it's raising exceptions. volatility='stable', text=text) class AllRoleMembershipsFunctionForwardDecl(trampoline.VersionedFunction): """Forward declaration for _all_role_memberships""" def __init__(self) -> None: super().__init__( name=('edgedb', '_all_role_memberships'), args=[('role_id', ('uuid',))], returns=('uuid[]'), volatility='stable', text='SELECT NULL::uuid[]', ) class AllRoleMembershipsFunction(trampoline.VersionedFunction): """Get all memberships for a given role""" def __init__(self, schema: s_schema.Schema) -> None: role_obj = schema.get("sys::Role", type=s_objtypes.ObjectType) roles = _schema_alias_view_name(schema, role_obj) roles = (common.maybe_versioned_schema(roles[0]), roles[1]) member_of = role_obj.getptr(schema, s_name.UnqualName('member_of')) members = _schema_alias_view_name(schema, member_of) members = (common.maybe_versioned_schema(members[0]), members[1]) text = f""" WITH RECURSIVE memberships (id, member_of) AS ( ( SELECT r.id as id, m.target as member_of FROM {q(*roles)} r INNER JOIN {q(*members)} m ON r.id = m.source WHERE r.id = "role_id" ) UNION ( SELECT r.id as id, m.target as member_of FROM {q(*roles)} r INNER JOIN {q(*members)} m ON r.id = m.source INNER JOIN memberships ms ON ms.member_of = r.id ) ) SELECT array_agg(member_of) FROM memberships """ super().__init__( name=('edgedb', '_all_role_memberships'), args=[('role_id', ('uuid',))], returns=('uuid[]'), volatility='stable', text=text, ) class DumpSequencesFunction(trampoline.VersionedFunction): text = r""" SELECT string_agg( 'SELECT std::sequence_reset(' || 'INTROSPECT ' || edgedb_VER.quote_name(seq.name) || (CASE WHEN 
seq_st.is_called THEN ', ' || seq_st.last_value::text ELSE '' END) || ');', E'\n' ) FROM (SELECT id, name FROM edgedb_VER."_SchemaScalarType" WHERE id = any("seqs") ) AS seq, LATERAL ( SELECT COALESCE(last_value, start_value)::text AS last_value, last_value IS NOT NULL AS is_called FROM pg_sequences, LATERAL ROWS FROM ( edgedb_VER.get_sequence_backend_name(seq.id) ) AS seq_name(schema text, name text) WHERE (pg_sequences.schemaname, pg_sequences.sequencename) = (seq_name.schema, seq_name.name) ) AS seq_st """ def __init__(self) -> None: super().__init__( name=('edgedb', '_dump_sequences'), args=[('seqs', ('uuid[]',))], returns=('text',), # Volatile because sequence state is volatile volatility='volatile', text=self.text, ) class SysConfigSourceType(dbops.Enum): def __init__(self) -> None: super().__init__( name=('edgedb', '_sys_config_source_t'), values=[ 'default', 'postgres default', 'postgres environment variable', 'postgres configuration file', 'environment variable', 'command line', 'postgres command line', 'postgres global', 'postgres client', 'system override', 'database', 'postgres override', 'postgres interactive', 'postgres test', 'session', ] ) class SysConfigScopeType(dbops.Enum): def __init__(self) -> None: super().__init__( name=('edgedb', '_sys_config_scope_t'), values=[ 'INSTANCE', 'DATABASE', 'SESSION', ] ) class SysConfigValueType(dbops.CompositeType): """Type of values returned by _read_sys_config.""" def __init__(self) -> None: super().__init__(name=('edgedb', '_sys_config_val_t')) self.add_columns([ dbops.Column(name='name', type='text'), dbops.Column(name='value', type='jsonb'), dbops.Column(name='source', type='edgedb._sys_config_source_t'), dbops.Column(name='scope', type='edgedb._sys_config_scope_t'), ]) class SysConfigEntryType(dbops.CompositeType): """Type of values returned by _read_sys_config_full.""" def __init__(self) -> None: super().__init__(name=('edgedb', '_sys_config_entry_t')) self.add_columns([ dbops.Column(name='max_source', 
type='edgedb._sys_config_source_t'), dbops.Column(name='value', type='edgedb._sys_config_val_t'), ]) class IntervalToMillisecondsFunction(trampoline.VersionedFunction): """Cast an interval into milliseconds.""" text = r''' SELECT trunc(extract(hours from "val"))::numeric * 3600000 + trunc(extract(minutes from "val"))::numeric * 60000 + trunc(extract(milliseconds from "val"))::numeric ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_interval_to_ms'), args=[('val', ('interval',))], returns=('numeric',), volatility='immutable', text=self.text, ) class SafeIntervalCastFunction(trampoline.VersionedFunction): """A safer text to interval casting implementation. Casting large-unit durations (like '4032000000us') results in an error. Huge durations like this can be returned when introspecting current database config. Fix that by parsing the argument and using multiplication. """ text = r''' SELECT CASE WHEN m.v[1] IS NOT NULL AND m.v[2] IS NOT NULL THEN m.v[1]::numeric * ('1' || m.v[2])::interval ELSE "val"::interval END FROM LATERAL ( SELECT regexp_match( "val", '^(\d+)\s*(us|ms|s|min|h)$') AS v ) AS m ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_interval_safe_cast'), args=[('val', ('text',))], returns=('interval',), volatility='immutable', text=self.text, ) class ConvertPostgresConfigUnitsFunction(trampoline.VersionedFunction): """Convert duration/memory values to milliseconds/kilobytes. See https://www.postgresql.org/docs/12/config-setting.html for information about the units Postgres config system has.
""" text = r""" SELECT ( CASE WHEN "unit" = any(ARRAY['us', 'ms', 's', 'min', 'h']) THEN to_jsonb( edgedb_VER._interval_safe_cast( ("value" * "multiplier")::text || "unit" ) ) WHEN "unit" = 'B' THEN to_jsonb( ("value" * "multiplier")::text || 'B' ) WHEN "unit" = 'kB' THEN to_jsonb( ("value" * "multiplier")::text || 'KiB' ) WHEN "unit" = 'MB' THEN to_jsonb( ("value" * "multiplier")::text || 'MiB' ) WHEN "unit" = 'GB' THEN to_jsonb( ("value" * "multiplier")::text || 'GiB' ) WHEN "unit" = 'TB' THEN to_jsonb( ("value" * "multiplier")::text || 'TiB' ) WHEN "unit" = '' THEN ("value" * "multiplier")::text::jsonb ELSE edgedb_VER.raise( NULL::jsonb, msg => ( 'unknown configuration unit "' || COALESCE("unit", '') || '"' ) ) END ) """ def __init__(self) -> None: super().__init__( name=('edgedb', '_convert_postgres_config_units'), args=[ ('value', ('numeric',)), ('multiplier', ('numeric',)), ('unit', ('text',)) ], returns=('jsonb',), volatility='immutable', text=self.text, ) class TypeIDToConfigType(trampoline.VersionedFunction): """Get a postgres config type from a type id. (We typically try to read extension configs straight from the config tables, but for extension configs those aren't present.) 
""" config_types = { 'bool': ['std::bool'], 'string': ['std::str'], 'integer': ['std::int16', 'std::int32', 'std::int64'], 'real': ['std::float32', 'std::float64'], } cases = [ f''' WHEN "typeid" = '{s_obj.get_known_type_id(t)}' THEN '{ct}' ''' for ct, types in config_types.items() for t in types ] scases = '\n'.join(cases) text = f""" SELECT ( CASE {scases} ELSE edgedb_VER.raise( NULL::text, msg => ( 'unknown configuration type "' || "typeid" || '"' ) ) END ) """ def __init__(self) -> None: super().__init__( name=('edgedb', '_type_id_to_config_type'), args=[ ('typeid', ('uuid',)), ], returns=('text',), volatility='immutable', text=self.text, ) class NormalizedPgSettingsView(trampoline.VersionedView): """Just like `pg_settings` but with the parsed 'unit' column.""" query = r''' SELECT s.name AS name, s.setting AS setting, s.vartype AS vartype, s.source AS source, unit.multiplier AS multiplier, unit.unit AS unit FROM pg_settings AS s, LATERAL ( SELECT regexp_match( s.unit, '^(\d*)\s*([a-zA-Z]{1,3})$') AS v ) AS _unit, LATERAL ( SELECT COALESCE( CASE WHEN _unit.v[1] = '' THEN 1 ELSE _unit.v[1]::int END, 1 ) AS multiplier, COALESCE(_unit.v[2], '') AS unit ) AS unit ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_normalized_pg_settings'), query=self.query, ) class InterpretConfigValueToJsonFunction(trampoline.VersionedFunction): """Convert a Postgres config value to jsonb. This function: * converts booleans to JSON true/false; * converts enums and strings to JSON strings; * converts real/integers to JSON numbers: - for durations: we always convert to milliseconds; - for memory size: we always convert to kilobytes; - already unitless numbers are left as is. See https://www.postgresql.org/docs/current/config-setting.html for information about the units Postgres config system has. 
""" text = r""" SELECT ( CASE WHEN "type" = 'bool' THEN ( CASE WHEN lower("value") = any(ARRAY['on', 'true', 'yes', '1']) THEN 'true' ELSE 'false' END )::jsonb WHEN "type" = 'enum' OR "type" = 'string' THEN to_jsonb("value") WHEN "type" = 'integer' OR "type" = 'real' THEN edgedb_VER._convert_postgres_config_units( "value"::numeric, "multiplier"::numeric, "unit" ) ELSE edgedb_VER.raise( NULL::jsonb, msg => ( 'unknown configuration type "' || COALESCE("type", '') || '"' ) ) END ) """ def __init__(self) -> None: super().__init__( name=('edgedb', '_interpret_config_value_to_json'), args=[ ('value', ('text',)), ('type', ('text',)), ('multiplier', ('int',)), ('unit', ('text',)) ], returns=('jsonb',), volatility='immutable', text=self.text, ) class PostgresJsonConfigValueToFrontendConfigValueFunction( trampoline.VersionedFunction, ): """Convert a Postgres config value to frontend config value. Most values are retained as-is, but some need translation, which is implemented as a to_frontend_expr() on the corresponding setting ScalarType. """ def __init__(self, config_spec: edbconfig.Spec) -> None: variants_list = [] for setting in config_spec.values(): if ( setting.backend_setting and isinstance(setting.type, type) and issubclass(setting.type, statypes.ScalarType) ): conv_expr = setting.type.to_frontend_expr('"value"->>0') if conv_expr is not None: variants_list.append(f""" WHEN {ql(setting.backend_setting)} THEN to_jsonb({conv_expr}) """) variants = "\n".join(variants_list) text = f""" SELECT ( CASE "setting_name" {variants} ELSE "value" END ) """ super().__init__( name=('edgedb', '_postgres_json_config_value_to_fe_config_value'), args=[ ('setting_name', ('text',)), ('value', ('jsonb',)) ], returns=('jsonb',), volatility='immutable', text=text, ) class PostgresConfigValueToJsonFunction(trampoline.VersionedFunction): """Convert a Postgres setting to JSON value. Steps: * Lookup the `setting_name` in pg_settings to determine its type and unit. 
* Parse `setting_value` to see if it starts with numbers and ends with what looks like a unit. * Fetch the unit/multiplier from pg_settings (well, from our view over it). * If `setting_value` has a unit, pass it to `_interpret_config_value_to_json`. * If `setting_value` doesn't have a unit, pass it to `_interpret_config_value_to_json` along with the base unit/multiplier from pg_settings. * Then `_interpret_config_value_to_json` can cast the value correctly based on the pg_settings type and the supplied unit/multiplier. """ text = r""" SELECT edgedb_VER._postgres_json_config_value_to_fe_config_value( "setting_name", backend_json_value.value ) FROM LATERAL ( SELECT regexp_match( "setting_value", '^(\d+)\s*([a-zA-Z]{0,3})$') AS v ) AS _unit, LATERAL ( SELECT COALESCE(_unit.v[1], "setting_value") AS val, COALESCE(_unit.v[2], '') AS unit ) AS parsed_value LEFT OUTER JOIN ( SELECT epg_settings.vartype AS vartype, epg_settings.multiplier AS multiplier, epg_settings.unit AS unit FROM edgedb_VER._normalized_pg_settings AS epg_settings WHERE epg_settings.name = "setting_name" ) AS settings_in ON true CROSS JOIN LATERAL ( SELECT COALESCE(settings_in.vartype, edgedb_VER._type_id_to_config_type("setting_typeid")) as vartype, COALESCE(settings_in.multiplier, '1') as multiplier, COALESCE(settings_in.unit, '') as unit ) AS settings CROSS JOIN LATERAL (SELECT (CASE WHEN parsed_value.unit != '' THEN edgedb_VER._interpret_config_value_to_json( parsed_value.val, settings.vartype, 1, parsed_value.unit ) ELSE edgedb_VER._interpret_config_value_to_json( "setting_value", settings.vartype, settings.multiplier, settings.unit ) END) AS value ) AS backend_json_value """ def __init__(self) -> None: super().__init__( name=('edgedb', '_postgres_config_value_to_json'), args=[ ('setting_name', ('text',)), ('setting_typeid', ('uuid',)), ('setting_value', ('text',)), ], returns=('jsonb',), volatility='volatile', text=self.text, ) class SysConfigFullFunction(trampoline.VersionedFunction):
# This is a function because "_edgecon_state" is a temporary table # and therefore cannot be used in a view. text = f''' DECLARE query text; BEGIN query := $$ WITH config_spec AS ( SELECT s.key AS name, s.value->'default' AS default, (s.value->>'internal')::bool AS internal, (s.value->>'system')::bool AS system, (s.value->>'typeid')::uuid AS typeid, (s.value->>'typemod') AS typemod, (s.value->>'backend_setting') AS backend_setting FROM edgedbinstdata_VER.instdata as id, LATERAL jsonb_each(id.json) AS s WHERE id.key LIKE 'configspec%' ), config_defaults AS ( SELECT s.name AS name, s.default AS value, 'default' AS source, s.backend_setting IS NOT NULL AS is_backend FROM config_spec s ), config_extension_defaults AS ( SELECT * FROM config_defaults WHERE name like '%::%' ), config_static AS ( SELECT s.name AS name, s.value AS value, (CASE WHEN s.type = 'A' THEN 'command line' -- Due to inplace upgrade limits, without adding a new -- layer, configuration file values are manually squashed -- into the `environment variables` layer, see below. ELSE 'environment variable' END) AS source, config_spec.backend_setting IS NOT NULL AS is_backend FROM _edgecon_state s INNER JOIN config_spec ON (config_spec.name = s.name) WHERE -- Give precedence to configuration file values over -- environment variables manually. 
s.type = 'A' OR s.type = 'F' OR ( s.type = 'E' AND NOT EXISTS ( SELECT 1 FROM _edgecon_state ss WHERE ss.name = s.name AND ss.type = 'F' ) ) ), config_sys AS ( SELECT s.key AS name, s.value AS value, 'system override' AS source, config_spec.backend_setting IS NOT NULL AS is_backend FROM jsonb_each( edgedb_VER.get_database_metadata( {ql(defines.EDGEDB_SYSTEM_DB)} ) -> 'sysconfig' ) AS s INNER JOIN config_spec ON (config_spec.name = s.key) ), config_db AS ( SELECT s.name AS name, s.value AS value, 'database' AS source, config_spec.backend_setting IS NOT NULL AS is_backend FROM edgedb._db_config s INNER JOIN config_spec ON (config_spec.name = s.name) ), config_sess AS ( SELECT s.name AS name, s.value AS value, 'session' AS source, FALSE AS is_backend -- only 'B' is for backend settings FROM _edgecon_state s WHERE s.type = 'C' ), pg_db_setting AS ( SELECT spec.name, edgedb_VER._postgres_config_value_to_json( spec.backend_setting, spec.typeid, nameval.value ) AS value, 'database' AS source, TRUE AS is_backend FROM (SELECT setconfig FROM pg_db_role_setting WHERE setdatabase = ( SELECT oid FROM pg_database WHERE datname = current_database() ) AND setrole = 0 ) AS cfg_array, LATERAL unnest(cfg_array.setconfig) AS cfg_set(s), LATERAL ( SELECT split_part(cfg_set.s, '=', 1) AS name, split_part(cfg_set.s, '=', 2) AS value ) AS nameval, LATERAL ( SELECT config_spec.name, config_spec.backend_setting, config_spec.typeid FROM config_spec WHERE nameval.name = config_spec.backend_setting ) AS spec ), $$; IF fs_access THEN query := query || $$ pg_conf_settings AS ( SELECT spec.name, edgedb_VER._postgres_config_value_to_json( spec.backend_setting, spec.typeid, setting ) AS value, 'postgres configuration file' AS source, TRUE AS is_backend FROM pg_file_settings, LATERAL ( SELECT config_spec.name, config_spec.backend_setting, config_spec.typeid FROM config_spec WHERE pg_file_settings.name = config_spec.backend_setting ) AS spec WHERE sourcefile != (( SELECT setting FROM pg_settings 
WHERE name = 'data_directory' ) || '/postgresql.auto.conf') AND applied ), pg_auto_conf_settings AS ( SELECT spec.name, edgedb_VER._postgres_config_value_to_json( spec.backend_setting, spec.typeid, setting ) AS value, 'system override' AS source, TRUE AS is_backend FROM pg_file_settings, LATERAL ( SELECT config_spec.name, config_spec.backend_setting, config_spec.typeid FROM config_spec WHERE pg_file_settings.name = config_spec.backend_setting ) AS spec WHERE sourcefile = (( SELECT setting FROM pg_settings WHERE name = 'data_directory' ) || '/postgresql.auto.conf') AND applied ), $$; END IF; query := query || $$ pg_config AS ( SELECT spec.name, edgedb_VER._postgres_json_config_value_to_fe_config_value( settings.name, edgedb_VER._interpret_config_value_to_json( settings.setting, settings.vartype, settings.multiplier, settings.unit ) ) AS value, source AS source, TRUE AS is_backend FROM ( SELECT epg_settings.name AS name, epg_settings.unit AS unit, epg_settings.multiplier AS multiplier, epg_settings.vartype AS vartype, epg_settings.setting AS setting, (CASE WHEN epg_settings.source = 'session' THEN epg_settings.source ELSE 'postgres ' || epg_settings.source END) AS source FROM edgedb_VER._normalized_pg_settings AS epg_settings WHERE epg_settings.source != 'database' ) AS settings, LATERAL ( SELECT config_spec.name FROM config_spec WHERE settings.name = config_spec.backend_setting ) AS spec ), -- extension session configs don't show up in any system view, so we -- check _edgecon_state to see when they are present. pg_extension_config AS ( SELECT config_spec.name, -- XXX: Or would it be better to just use the json directly? 
edgedb_VER._postgres_config_value_to_json( config_spec.backend_setting, config_spec.typeid, current_setting(config_spec.backend_setting, true) ) AS value, 'session' AS source, TRUE AS is_backend FROM _edgecon_state s INNER JOIN config_spec ON s.name = config_spec.name WHERE s.type = 'B' AND s.name LIKE '%::%' ), edge_all_settings AS MATERIALIZED ( SELECT q.* FROM ( SELECT * FROM config_defaults UNION ALL SELECT * FROM config_static UNION ALL SELECT * FROM config_sys UNION ALL SELECT * FROM config_db UNION ALL SELECT * FROM config_sess ) AS q WHERE NOT q.is_backend ), $$; IF fs_access THEN query := query || $$ pg_all_settings AS MATERIALIZED ( SELECT q.* FROM ( -- extension defaults aren't in any system views SELECT * FROM config_extension_defaults UNION ALL SELECT * FROM pg_db_setting UNION ALL SELECT * FROM pg_conf_settings UNION ALL SELECT * FROM pg_auto_conf_settings UNION ALL SELECT * FROM pg_config UNION ALL SELECT * FROM pg_extension_config ) AS q WHERE q.is_backend ) $$; ELSE query := query || $$ pg_all_settings AS MATERIALIZED ( SELECT q.* FROM ( -- extension defaults aren't in any system views SELECT * FROM config_extension_defaults UNION ALL -- config_sys is here, because there -- is no other way to read instance-level -- configuration overrides. 
SELECT * FROM config_sys UNION ALL SELECT * FROM pg_db_setting UNION ALL SELECT * FROM pg_config UNION ALL SELECT * FROM pg_extension_config ) AS q WHERE q.is_backend ) $$; END IF; query := query || $$ SELECT max_source AS max_source, (q.name, q.value, q.source, (CASE WHEN q.source < 'database'::edgedb._sys_config_source_t THEN 'INSTANCE' WHEN q.source = 'database'::edgedb._sys_config_source_t THEN 'DATABASE' ELSE 'SESSION' END)::edgedb._sys_config_scope_t )::edgedb._sys_config_val_t as value FROM unnest($2) as max_source, LATERAL (SELECT u.name, u.value, u.source::edgedb._sys_config_source_t, row_number() OVER ( PARTITION BY u.name ORDER BY u.source::edgedb._sys_config_source_t DESC ) AS n FROM (SELECT * FROM ( SELECT * FROM edge_all_settings UNION ALL SELECT * FROM pg_all_settings ) AS q WHERE q.value IS NOT NULL AND ($1 IS NULL OR q.source::edgedb._sys_config_source_t = any($1) ) AND (max_source IS NULL OR q.source::edgedb._sys_config_source_t <= max_source ) ) AS u ) AS q WHERE q.n = 1; $$; RETURN QUERY EXECUTE query USING source_filter, max_sources; END; ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_read_sys_config_full'), args=[ ( 'source_filter', ('edgedb', '_sys_config_source_t[]',), 'NULL', ), ( 'max_sources', ('edgedb', '_sys_config_source_t[]'), 'NULL', ), ( 'fs_access', ('bool',), 'TRUE', ) ], returns=('edgedb', '_sys_config_entry_t'), set_returning=True, language='plpgsql', volatility='volatile', text=self.text, ) class SysConfigUncachedFunction(trampoline.VersionedFunction): text = f''' DECLARE backend_caps bigint; BEGIN backend_caps := edgedb_VER.get_backend_capabilities(); IF (backend_caps & {int(params.BackendCapabilities.CONFIGFILE_ACCESS)}) != 0 THEN RETURN QUERY SELECT * FROM edgedb_VER._read_sys_config_full( source_filter, max_sources, TRUE); ELSE RETURN QUERY SELECT * FROM edgedb_VER._read_sys_config_full( source_filter, max_sources, FALSE); END IF; END; ''' def __init__(self) -> None: super().__init__( name=('edgedb', 
'_read_sys_config_uncached'), args=[ ( 'source_filter', ('edgedb', '_sys_config_source_t[]',), 'NULL', ), ( 'max_sources', ('edgedb', '_sys_config_source_t[]'), 'NULL', ), ], returns=('edgedb', '_sys_config_entry_t'), set_returning=True, language='plpgsql', volatility='volatile', text=self.text, ) class SysConfigFunction(trampoline.VersionedFunction): text = f''' DECLARE BEGIN -- Only bother caching the source_filter IS NULL case, since that -- is what drives the config views. source_filter is used in -- DESCRIBE CONFIG IF source_filter IS NOT NULL OR array_position( ARRAY[NULL, 'database', 'system override']::edgedb._sys_config_source_t[], max_source) IS NULL THEN RETURN QUERY SELECT (c.value).name, (c.value).value, (c.value).source, (c.value).scope FROM edgedb_VER._read_sys_config_uncached( source_filter, ARRAY[max_source]) AS c; RETURN; END IF; IF count(*) = 0 FROM "_config_cache" c WHERE source IS NOT DISTINCT FROM max_source THEN INSERT INTO "_config_cache" SELECT (s.max_source), (s.value) FROM edgedb_VER._read_sys_config_uncached( source_filter, ARRAY[ NULL, 'database', 'system override']::edgedb._sys_config_source_t[]) AS s; END IF; RETURN QUERY SELECT (c.value).name, (c.value).value, (c.value).source, (c.value).scope FROM "_config_cache" c WHERE source IS NOT DISTINCT FROM max_source; END; ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_read_sys_config'), args=[ ( 'source_filter', ('edgedb', '_sys_config_source_t[]',), 'NULL', ), ( 'max_source', ('edgedb', '_sys_config_source_t'), 'NULL', ), ], returns=('edgedb', '_sys_config_val_t'), set_returning=True, language='plpgsql', volatility='volatile', text=self.text, ) class SysClearConfigCacheFunction(trampoline.VersionedFunction): text = f''' DECLARE BEGIN DELETE FROM "_config_cache" c; RETURN true; END; ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_clear_sys_config_cache'), args=[], returns=("boolean"), set_returning=False, language='plpgsql', volatility='volatile', 
text=self.text, ) class ResetSessionConfigFunction(trampoline.VersionedFunction): text = f''' RESET ALL ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_reset_session_config'), args=[], returns=('void',), language='sql', volatility='volatile', text=self.text, ) class ApplySessionConfigFunction(trampoline.VersionedFunction): """Apply a Gel config setting to the backend, if possible. The function accepts any Gel config name/value pair. If this specific config setting happens to be implemented via a backend setting, it would be applied to the current PostgreSQL session. If the config setting doesn't reflect into a backend setting the function is a no-op. The function always returns the passed config name, unmodified (this simplifies using the function in queries.) """ def __init__(self, config_spec: edbconfig.Spec) -> None: backend_settings = {} for setting_name in config_spec: setting = config_spec[setting_name] if setting.backend_setting and not setting.system: backend_settings[setting_name] = setting.backend_setting variants_list = [] for setting_name, backend_setting_name in backend_settings.items(): setting = config_spec[setting_name] valql = '"value"->>0' if ( isinstance(setting.type, type) and issubclass(setting.type, statypes.ScalarType) ): valql = setting.type.to_backend_expr(valql) variants_list.append(f''' WHEN "name" = {ql(setting_name)} THEN pg_catalog.set_config( {ql(backend_setting_name)}::text, {valql}, false ) ''') ext_config = ''' SELECT pg_catalog.set_config( (s.val->>'backend_setting')::text, "value"->>0, false ) FROM edgedbinstdata_VER.instdata as id, LATERAL jsonb_each(id.json) AS s(key, val) WHERE id.key = 'configspec_ext' AND s.key = "name" ''' variants = "\n".join(variants_list) text = f''' SELECT ( CASE WHEN "name" = any( ARRAY[{",".join(ql(str(bs)) for bs in backend_settings)}] ) THEN ( CASE WHEN (CASE {variants} END) IS NULL THEN "name" ELSE "name" END ) WHEN "name" LIKE '%::%' THEN CASE WHEN ({ext_config}) IS NULL THEN 
"name" ELSE "name" END ELSE "name" END ) ''' super().__init__( name=('edgedb', '_apply_session_config'), args=[ ('name', ('text',)), ('value', ('jsonb',)), ], returns=('text',), language='sql', volatility='volatile', text=text, ) class SysGetTransactionIsolation(trampoline.VersionedFunction): "Get transaction isolation value as text compatible with Gel's enum." text = r''' SELECT CASE setting WHEN 'repeatable read' THEN 'RepeatableRead' WHEN 'serializable' THEN 'Serializable' ELSE ( SELECT edgedb_VER.raise( NULL::text, msg => ( 'unknown transaction isolation level "' || setting || '"' ) ) ) END FROM pg_settings WHERE name = 'transaction_isolation' ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_get_transaction_isolation'), args=[], returns=('text',), # This function only reads from a table. volatility='stable', text=self.text) class GetCachedReflection(trampoline.VersionedFunction): "Return a list of existing schema reflection helpers." text = ''' SELECT substring(proname, '__rh_#"%#"', '#') AS eql_hash, proargnames AS argnames FROM pg_proc INNER JOIN pg_namespace ON (pronamespace = pg_namespace.oid) WHERE proname LIKE '\\_\\_rh\\_%' AND nspname = 'edgedb_VER' ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_get_cached_reflection'), args=[], returns=('record',), set_returning=True, # This function only reads from a table. 
volatility='stable', text=self.text, ) class GetBaseScalarTypeMap(trampoline.VersionedFunction): """Return a map of base Gel scalar type ids to Postgres type names.""" text = "VALUES" + ", ".join( f"({ql(str(k))}::uuid, {qtl(v)})" for k, v in types.base_type_name_map.items() ) def __init__(self) -> None: super().__init__( name=('edgedb', '_get_base_scalar_type_map'), args=[], returns=('record',), set_returning=True, volatility='immutable', text=self.text, ) class GetTypeToRangeNameMap(trampoline.VersionedFunction): """Return a map of type names to the name of the associated range type""" text = f"VALUES" + ", ".join( f"({qtl(k)}, {qtl(v)})" for k, v in types.type_to_range_name_map.items() ) def __init__(self) -> None: super().__init__( name=('edgedb', '_get_type_to_range_type_map'), args=[], returns=('record',), set_returning=True, volatility='immutable', text=self.text, ) class GetTypeToMultiRangeNameMap(trampoline.VersionedFunction): "Return a map of type names to the name of the associated multirange type" text = f"VALUES" + ", ".join( f"({qtl(k)}, {qtl(v)})" for k, v in types.type_to_multirange_name_map.items() ) def __init__(self) -> None: super().__init__( name=('edgedb', '_get_type_to_multirange_type_map'), args=[], returns=('record',), set_returning=True, volatility='immutable', text=self.text, ) class GetPgTypeForEdgeDBTypeFunction(trampoline.VersionedFunction): """Return Postgres OID representing a given Gel type.""" text = f''' SELECT coalesce( sql_type::regtype::oid, ( SELECT tn::regtype::oid FROM edgedb_VER._get_base_scalar_type_map() AS m(tid uuid, tn text) WHERE m.tid = "typeid" ), ( SELECT typ.oid FROM pg_catalog.pg_type typ WHERE typ.typname = "typeid"::text || '_domain' OR typ.typname = "typeid"::text || '_t' ), ( SELECT typ.typarray FROM pg_catalog.pg_type typ WHERE "kind" = 'schema::Array' AND ( typ.typname = "elemid"::text || '_domain' OR typ.typname = "elemid"::text || '_t' OR typ.oid = ( SELECT tn::regtype::oid FROM 
edgedb_VER._get_base_scalar_type_map() AS m(tid uuid, tn text) WHERE tid = "elemid" ) ) ), ( SELECT rng.rngtypid FROM pg_catalog.pg_range rng WHERE "kind" = 'schema::Range' -- For ranges, we need to do the lookup based on -- our internal map of elem names to range names, -- because we use the builtin daterange as the range -- for edgedbt.date_t. AND rng.rngtypid = ( SELECT rn::regtype::oid FROM edgedb_VER._get_base_scalar_type_map() AS m(tid uuid, tn text) INNER JOIN edgedb_VER._get_type_to_range_type_map() AS m2(tn2 text, rn text) ON tn = tn2 WHERE tid = "elemid" ) ), ( SELECT rng.rngmultitypid FROM pg_catalog.pg_range rng WHERE "kind" = 'schema::MultiRange' -- For multiranges, we need to do the lookup based on -- our internal map of elem names to range names, -- because we use the builtin daterange as the range -- for edgedbt.date_t. AND rng.rngmultitypid = ( SELECT rn::regtype::oid FROM edgedb_VER._get_base_scalar_type_map() AS m(tid uuid, tn text) INNER JOIN edgedb_VER._get_type_to_multirange_type_map() AS m2(tn2 text, rn text) ON tn = tn2 WHERE tid = "elemid" ) ), edgedb_VER.raise( NULL::bigint, 'invalid_parameter_value', msg => ( format( 'cannot determine OID of Gel type %L', "typeid"::text ) ) ) )::bigint ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'get_pg_type_for_edgedb_type'), args=[ ('typeid', ('uuid',)), ('kind', ('text',)), ('elemid', ('uuid',)), ('sql_type', ('text',)), ], returns=('bigint',), volatility='stable', text=self.text, ) class GetPgTypeForEdgeDBTypeFunction2(trampoline.VersionedFunction): """Return Postgres OID representing a given Gel type. This is an updated version that should replace the original. It takes advantage of the schema views to correctly identify non-trivial array types. 
""" text = f''' SELECT coalesce( sql_type::regtype::oid, ( SELECT tn::regtype::oid FROM edgedb_VER._get_base_scalar_type_map() AS m(tid uuid, tn text) WHERE m.tid = "typeid" ), ( SELECT typ.oid FROM pg_catalog.pg_type typ WHERE typ.typname = "typeid"::text || '_domain' OR typ.typname = "typeid"::text || '_t' ), ( SELECT typ.typarray FROM pg_catalog.pg_type typ WHERE "kind" = 'schema::Array' AND ( typ.typname = "elemid"::text || '_domain' OR typ.typname = "elemid"::text || '_t' OR typ.oid = ( SELECT tn::regtype::oid FROM edgedb_VER._get_base_scalar_type_map() AS m(tid uuid, tn text) WHERE tid = "elemid" ) ) ), ( SELECT typ.typarray FROM pg_catalog.pg_type typ WHERE "kind" = 'schema::Array' AND ( typ.typname = "elemid"::text || '_domain' OR typ.typname = "elemid"::text OR typ.oid = ( SELECT st.backend_id FROM edgedb_VER."_SchemaType" AS st WHERE st.id = "elemid" ) ) ), ( SELECT rng.rngtypid FROM pg_catalog.pg_range rng WHERE "kind" = 'schema::Range' -- For ranges, we need to do the lookup based on -- our internal map of elem names to range names, -- because we use the builtin daterange as the range -- for edgedbt.date_t. AND rng.rngtypid = ( SELECT rn::regtype::oid FROM edgedb_VER._get_base_scalar_type_map() AS m(tid uuid, tn text) INNER JOIN edgedb_VER._get_type_to_range_type_map() AS m2(tn2 text, rn text) ON tn = tn2 WHERE tid = "elemid" ) ), ( SELECT rng.rngmultitypid FROM pg_catalog.pg_range rng WHERE "kind" = 'schema::MultiRange' -- For multiranges, we need to do the lookup based on -- our internal map of elem names to range names, -- because we use the builtin daterange as the range -- for edgedbt.date_t. 
AND rng.rngmultitypid = ( SELECT rn::regtype::oid FROM edgedb_VER._get_base_scalar_type_map() AS m(tid uuid, tn text) INNER JOIN edgedb_VER._get_type_to_multirange_type_map() AS m2(tn2 text, rn text) ON tn = tn2 WHERE tid = "elemid" ) ), edgedb_VER.raise( NULL::bigint, 'invalid_parameter_value', msg => ( format( 'cannot determine Postgres OID of Gel %s(%L)%s', "kind", "typeid"::text, (case when "elemid" is not null then ' with element type ' || "elemid"::text else '' end) ) ) ) )::bigint ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'get_pg_type_for_edgedb_type'), args=[ ('typeid', ('uuid',)), ('kind', ('text',)), ('elemid', ('uuid',)), ('sql_type', ('text',)), ], returns=('bigint',), volatility='stable', text=self.text, ) class FTSParseQueryFunction(trampoline.VersionedFunction): """Return tsquery representing the given FTS input query.""" text = r''' DECLARE parts text[]; exclude text; term text; rest text; cur_op text := NULL; default_op text; tsq tsquery; el tsquery; result tsquery := ''::tsquery; BEGIN IF q IS NULL OR q = '' THEN RETURN result; END IF; -- Break up the query string into the current term, optional next -- operator and the rest. 
parts := regexp_match( q, $$^(-)?((?:"[^"]*")|(?:\S+))\s*(OR|AND)?\s*(.*)$$ ); exclude := parts[1]; term := parts[2]; cur_op := parts[3]; rest := parts[4]; IF starts_with(term, '"') THEN -- match as a phrase tsq := phraseto_tsquery(language, trim(both '"' from term)); ELSE tsq := to_tsquery(language, term); END IF; IF exclude IS NOT NULL THEN tsq := !!tsq; END IF; -- figure out the operator between the current term and the next one IF rest = '' THEN -- base case, only one term left, so we ignore the cur_op even if -- present IF prev_op = 'OR' THEN -- explicit 'OR' terms are "should" should := array_append(should, tsq); ELSIF starts_with(term, '"') OR exclude IS NOT NULL OR prev_op = 'AND' THEN -- phrases, exclusions and 'AND' terms are "must" must := array_append(must, tsq); ELSE -- regular terms are "should" by default should := array_append(should, tsq); END IF; ELSE -- recursion IF prev_op = 'OR' OR cur_op = 'OR' THEN -- if at least one of the surrounding operators is 'OR', -- then the phrase is put into "should" category should := array_append(should, tsq); ELSIF prev_op = 'AND' OR cur_op = 'AND' THEN -- if at least one of the surrounding operators is 'AND', -- then the phrase is put into "must" category must := array_append(must, tsq); ELSIF starts_with(term, '"') OR exclude IS NOT NULL THEN -- phrases and exclusions are "must" must := array_append(must, tsq); ELSE -- regular terms are "should" by default should := array_append(should, tsq); END IF; RETURN edgedb_VER.fts_parse_query( rest, language, must, should, cur_op); END IF; FOREACH el IN ARRAY should LOOP result := result || el; END LOOP; FOREACH el IN ARRAY must LOOP result := result && el; END LOOP; RETURN result; END; ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'fts_parse_query'), args=[ ('q', ('text',)), ('language', ('regconfig',), "'english'"), ('must', ('tsquery[]',), 'array[]::tsquery[]'), ('should', ('tsquery[]',), 'array[]::tsquery[]'), ('prev_op', ('text',), 'NULL'), ],
returns=('tsquery',), volatility='immutable', language='plpgsql', text=self.text, ) class FTSNormalizeWeightFunction(trampoline.VersionedFunction): """Normalize an array of weights to be a 4-value weight array.""" text = r''' SELECT CASE COALESCE(array_length(weights, 1), 0) WHEN 0 THEN array[1, 1, 1, 1]::float4[] WHEN 1 THEN array[0, 0, 0, weights[1]]::float4[] WHEN 2 THEN array[0, 0, weights[2], weights[1]]::float4[] WHEN 3 THEN array[0, weights[3], weights[2], weights[1]]::float4[] ELSE ( WITH raw as ( SELECT w FROM UNNEST(weights) AS w ORDER BY w DESC ) SELECT array_prepend(rest.w, first.arrw)::float4[] FROM ( SELECT array_agg(rw1.w) as arrw FROM ( SELECT w FROM (SELECT w FROM raw LIMIT 3) as rw0 ORDER BY w ASC ) as rw1 ) AS first, ( SELECT avg(rw2.w) as w FROM (SELECT w FROM raw OFFSET 3) as rw2 ) AS rest ) END ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'fts_normalize_weights'), args=[ ('weights', ('float8[]',)), ], returns=('float4[]',), volatility='immutable', text=self.text, ) class FTSNormalizeDocFunction(trampoline.VersionedFunction): """Normalize a document based on an array of weights.""" text = r''' SELECT CASE COALESCE(array_length(doc, 1), 0) WHEN 0 THEN ''::tsvector WHEN 1 THEN setweight(to_tsvector(language, doc[1]), 'A') WHEN 2 THEN ( setweight(to_tsvector(language, doc[1]), 'A') || setweight(to_tsvector(language, doc[2]), 'B') ) WHEN 3 THEN ( setweight(to_tsvector(language, doc[1]), 'A') || setweight(to_tsvector(language, doc[2]), 'B') || setweight(to_tsvector(language, doc[3]), 'C') ) ELSE ( WITH raw as ( SELECT d.v as t FROM UNNEST(doc) WITH ORDINALITY AS d(v, n) LEFT JOIN UNNEST(weights) WITH ORDINALITY AS w(v, n) ON d.n = w.n ORDER BY w.v DESC ) SELECT setweight(to_tsvector(language, d.arr[1]), 'A') || setweight(to_tsvector(language, d.arr[2]), 'B') || setweight(to_tsvector(language, d.arr[3]), 'C') || setweight(to_tsvector(language, array_to_string(d.arr[4:], ' ')), 'D') FROM ( SELECT array_agg(raw.t) as arr FROM raw ) 
AS d ) END ''' def __init__(self) -> None: super().__init__( name=('edgedb', 'fts_normalize_doc'), args=[ ('doc', ('text[]',)), ('weights', ('float8[]',)), ('language', ('regconfig',)), ], returns=('tsvector',), volatility='stable', text=self.text, ) class FTSToRegconfig(trampoline.VersionedFunction): """ Converts ISO 639-3 language identifiers into a regconfig. Defaults to english. Identifiers prefixed with 'xxx_' have the prefix stripped and the remainder used as the regconfig identifier. """ def __init__(self) -> None: super().__init__( name=('edgedb', 'fts_to_regconfig'), args=[ ('language', ('text',)), ], returns=('regconfig',), volatility='immutable', text=''' SELECT CASE WHEN language ILIKE 'xxx_%' THEN SUBSTR(language, 4) ELSE (CASE LOWER(language) WHEN 'ara' THEN 'arabic' WHEN 'hye' THEN 'armenian' WHEN 'eus' THEN 'basque' WHEN 'cat' THEN 'catalan' WHEN 'dan' THEN 'danish' WHEN 'nld' THEN 'dutch' WHEN 'eng' THEN 'english' WHEN 'fin' THEN 'finnish' WHEN 'fra' THEN 'french' WHEN 'deu' THEN 'german' WHEN 'ell' THEN 'greek' WHEN 'hin' THEN 'hindi' WHEN 'hun' THEN 'hungarian' WHEN 'ind' THEN 'indonesian' WHEN 'gle' THEN 'irish' WHEN 'ita' THEN 'italian' WHEN 'lit' THEN 'lithuanian' WHEN 'npi' THEN 'nepali' WHEN 'nor' THEN 'norwegian' WHEN 'por' THEN 'portuguese' WHEN 'ron' THEN 'romanian' WHEN 'rus' THEN 'russian' WHEN 'srp' THEN 'serbian' WHEN 'spa' THEN 'spanish' WHEN 'swe' THEN 'swedish' WHEN 'tam' THEN 'tamil' WHEN 'tur' THEN 'turkish' WHEN 'yid' THEN 'yiddish' ELSE 'english' END ) END::pg_catalog.regconfig; ''', ) class UuidGenerateV1mcFunction(trampoline.VersionedFunction): def __init__(self, ext_schema: str) -> None: super().__init__( name=('edgedb', 'uuid_generate_v1mc'), args=[], returns=('uuid',), volatility='volatile', language='sql', strict=True, parallel_safe=True, text=f'SELECT "{ext_schema}".uuid_generate_v1mc();' ) class UuidGenerateV4Function(trampoline.VersionedFunction): def __init__(self, ext_schema: str) -> None: super().__init__(
name=('edgedb', 'uuid_generate_v4'), args=[], returns=('uuid',), volatility='volatile', language='sql', strict=True, parallel_safe=True, text=f'SELECT "{ext_schema}".uuid_generate_v4();' ) class UuidGenerateV5Function(trampoline.VersionedFunction): def __init__(self, ext_schema: str) -> None: super().__init__( name=('edgedb', 'uuid_generate_v5'), args=[ ('namespace', ('uuid',)), ('name', ('text',)), ], returns=('uuid',), volatility='immutable', language='sql', strict=True, parallel_safe=True, text=f'SELECT "{ext_schema}".uuid_generate_v5(namespace, name);' ) class PadBase64StringFunction(trampoline.VersionedFunction): text = r""" WITH l AS (SELECT pg_catalog.length("s") % 4 AS r), p AS ( SELECT (CASE WHEN l.r > 0 THEN repeat('=', (4 - l.r)) ELSE '' END) AS p FROM l ) SELECT "s" || p.p FROM p """ def __init__(self) -> None: super().__init__( name=('edgedb', 'pad_base64_string'), args=[ ('s', ('text',)), ], returns=('text',), volatility='immutable', language='sql', strict=True, parallel_safe=True, text=self.text, ) class ResetQueryStatsFunction(trampoline.VersionedFunction): text = r""" DECLARE tenant_id TEXT; other_tenant_exists BOOLEAN; db_oid OID; queryid bigint; BEGIN tenant_id := edgedb_VER.get_backend_tenant_id(); IF id IS NULL THEN queryid := 0; ELSE queryid := edgedbext.edb_stat_queryid(id); END IF; SELECT EXISTS ( SELECT 1 FROM pg_database dat CROSS JOIN LATERAL ( SELECT edgedb_VER.shobj_metadata(dat.oid, 'pg_database') AS description ) AS d WHERE (d.description)->>'id' IS NOT NULL AND (d.description)->>'tenant_id' != tenant_id ) INTO other_tenant_exists; IF branch_name IS NULL THEN IF other_tenant_exists THEN RETURN edgedbext.edb_stat_statements_reset( 0, -- userid ARRAY( SELECT dat.oid FROM pg_database dat CROSS JOIN LATERAL ( SELECT edgedb_VER.shobj_metadata(dat.oid, 'pg_database') AS description ) AS d WHERE (d.description)->>'id' IS NOT NULL AND (d.description)->>'tenant_id' = tenant_id ), queryid, COALESCE(minmax_only, false) ); ELSE RETURN 
edgedbext.edb_stat_statements_reset( 0, -- userid '{}', -- database oid queryid, COALESCE(minmax_only, false) ); END IF; ELSE SELECT dat.oid INTO db_oid FROM pg_database dat CROSS JOIN LATERAL ( SELECT edgedb_VER.shobj_metadata(dat.oid, 'pg_database') AS description ) AS d WHERE (d.description)->>'id' IS NOT NULL AND (d.description)->>'tenant_id' = tenant_id AND edgedb_VER.get_database_frontend_name(dat.datname) = branch_name; IF db_oid IS NULL THEN RETURN NULL::edgedbt.timestamptz_t; END IF; RETURN edgedbext.edb_stat_statements_reset( 0, -- userid ARRAY[db_oid], queryid, COALESCE(minmax_only, false) ); END IF; RETURN now()::edgedbt.timestamptz_t; END; """ noop_text = r""" BEGIN RETURN NULL::edgedbt.timestamptz_t; END; """ def __init__(self, enable_stats: bool) -> None: super().__init__( name=('edgedb', 'reset_query_stats'), args=[ ('branch_name', ('text',)), ('id', ('uuid',)), ('minmax_only', ('bool',)), ], returns=('edgedbt', 'timestamptz_t'), volatility='volatile', language='plpgsql', text=self.text if enable_stats else self.noop_text, ) # N.B: This is a VersionedFunction but it can not be trampolined, since # the trampoline wrapper can't be a trigger. 
class ClearFELocalSQLSettingsFunction(trampoline.VersionedFunction): text = r""" BEGIN DELETE FROM _edgecon_state WHERE type = 'L' AND name = NEW.name; RETURN NEW; END; """ def __init__(self) -> None: super().__init__( name=('edgedb', '_clear_fe_local_sql_settings'), args=[], returns='trigger', language='plpgsql', volatility='volatile', text=self.text, ) def _maybe_trampoline( cmd: dbops.Command, out: list[trampoline.Trampoline] ) -> None: namespace = V('') if ( isinstance(cmd, dbops.CreateFunction) and cmd.function.name[0].endswith(namespace) ): out.append(trampoline.make_trampoline(cmd.function)) elif ( isinstance(cmd, dbops.CreateView) and cmd.view.name[0].endswith(namespace) ): out.append(trampoline.make_view_trampoline(cmd.view)) elif ( isinstance(cmd, dbops.CreateTable) and cmd.table.name[0].endswith(namespace) ): f, n = cmd.table.name out.append(trampoline.make_table_trampoline((f, n))) def trampoline_functions( cmds: Sequence[dbops.Command] ) -> list[trampoline.Trampoline]: ncmds: list[trampoline.Trampoline] = [] for cmd in cmds: _maybe_trampoline(cmd, ncmds) return ncmds def trampoline_command(cmd: dbops.Command) -> list[trampoline.Trampoline]: ncmds: list[trampoline.Trampoline] = [] def go(cmd: dbops.Command) -> None: if isinstance(cmd, dbops.CommandGroup): for subcmd in cmd.commands: go(subcmd) else: _maybe_trampoline(cmd, ncmds) go(cmd) return ncmds def get_fixed_bootstrap_commands() -> dbops.CommandGroup: """Create metaschema objects that are truly global""" cmds = [ dbops.CreateSchema(name='edgedb'), dbops.CreateSchema(name='edgedbt'), dbops.CreateSchema(name='edgedbpub'), dbops.CreateSchema(name='edgedbstd'), dbops.CreateSchema(name='edgedbinstdata'), dbops.CreateTable( DBConfigTable(), ), # TODO: SHOULD THIS BE VERSIONED? 
dbops.CreateTable(QueryCacheTable()), dbops.CreateDomain(BigintDomain()), dbops.CreateDomain(ConfigMemoryDomain()), dbops.CreateDomain(TimestampTzDomain()), dbops.CreateDomain(TimestampDomain()), dbops.CreateDomain(DateDomain()), dbops.CreateDomain(DurationDomain()), dbops.CreateDomain(RelativeDurationDomain()), dbops.CreateDomain(DateDurationDomain()), dbops.CreateEnum(SysConfigSourceType()), dbops.CreateEnum(SysConfigScopeType()), dbops.CreateCompositeType(SysConfigValueType()), dbops.CreateCompositeType(SysConfigEntryType()), dbops.CreateRange(Float32Range()), dbops.CreateRange(Float64Range()), dbops.CreateRange(DatetimeRange()), dbops.CreateRange(LocalDatetimeRange()), ] commands = dbops.CommandGroup() commands.add_commands(cmds) return commands def get_instdata_commands( ) -> tuple[dbops.CommandGroup, list[trampoline.Trampoline]]: cmds = [ dbops.CreateSchema(name=V('edgedbinstdata')), dbops.CreateTable(InstDataTable()), ] commands = dbops.CommandGroup() commands.add_commands(cmds) return commands, trampoline_functions(cmds) async def generate_instdata_table( conn: PGConnection, ) -> list[trampoline.Trampoline]: commands, trampolines = get_instdata_commands() block = dbops.PLTopBlock() commands.generate(block) await _execute_block(conn, block) return trampolines def get_bootstrap_commands( config_spec: edbconfig.Spec, ) -> tuple[dbops.CommandGroup, list[trampoline.Trampoline]]: trampolined = [ dbops.CreateSchema(name=V('edgedb')), dbops.CreateSchema(name=V('edgedbpub')), dbops.CreateSchema(name=V('edgedbstd')), dbops.CreateSchema(name=V('edgedbsql')), dbops.CreateView(NormalizedPgSettingsView()), dbops.CreateFunction(EvictQueryCacheFunction()), dbops.CreateFunction(ClearQueryCacheFunction()), dbops.CreateFunction(CreateTrampolineViewFunction()), dbops.CreateFunction(UuidGenerateV1mcFunction('edgedbext')), dbops.CreateFunction(UuidGenerateV4Function('edgedbext')), dbops.CreateFunction(UuidGenerateV5Function('edgedbext')), 
dbops.CreateFunction(IntervalToMillisecondsFunction()), dbops.CreateFunction(SafeIntervalCastFunction()), dbops.CreateFunction(QuoteIdentFunction()), dbops.CreateFunction(QuoteNameFunction()), dbops.CreateFunction(AlterCurrentDatabaseSetString()), dbops.CreateFunction(AlterCurrentDatabaseSetStringArray()), dbops.CreateFunction(AlterCurrentDatabaseSetNonArray()), dbops.CreateFunction(AlterCurrentDatabaseSetArray()), dbops.CreateFunction(CopyDatabaseConfigs()), dbops.CreateFunction(GetBackendCapabilitiesFunction()), dbops.CreateFunction(GetBackendTenantIDFunction()), dbops.CreateFunction(GetDatabaseBackendNameFunction()), dbops.CreateFunction(GetDatabaseFrontendNameFunction()), dbops.CreateFunction(GetRoleBackendNameFunction()), dbops.CreateFunction(GetUserSequenceBackendNameFunction()), dbops.CreateFunction(GetStdModulesFunction()), dbops.CreateFunction(GetObjectMetadata()), dbops.CreateFunction(GetColumnMetadata()), dbops.CreateFunction(GetSharedObjectMetadata()), dbops.CreateFunction(GetDatabaseMetadataFunction()), dbops.CreateFunction(GetCurrentDatabaseFunction()), dbops.CreateFunction(RaiseNoticeFunction()), dbops.CreateFunction(IndirectReturnFunction()), dbops.CreateFunction(RaiseExceptionFunction()), dbops.CreateFunction(RaiseExceptionOnNullFunction()), dbops.CreateFunction(RaiseExceptionOnNotNullFunction()), dbops.CreateFunction(RaiseExceptionOnEmptyStringFunction()), dbops.CreateFunction(AssertJSONTypeFunction()), dbops.CreateFunction(ExtractJSONScalarFunction()), dbops.CreateFunction(NormalizeNameFunction()), dbops.CreateFunction(GetNameModuleFunction()), dbops.CreateFunction(NullIfArrayNullsFunction()), dbops.CreateFunction(StrToConfigMemoryFunction()), dbops.CreateFunction(ConfigMemoryToStrFunction()), dbops.CreateFunction(StrToBigint()), dbops.CreateFunction(StrToDecimal()), dbops.CreateFunction(StrToInt64NoInline()), dbops.CreateFunction(StrToInt32NoInline()), dbops.CreateFunction(StrToInt16NoInline()), dbops.CreateFunction(StrToFloat64NoInline()), 
dbops.CreateFunction(StrToFloat32NoInline()), dbops.CreateFunction(NormalizeArrayIndexFunction()), dbops.CreateFunction(NormalizeArraySliceIndexFunction()), dbops.CreateFunction(IntOrNullFunction()), dbops.CreateFunction(ArrayIndexWithBoundsFunction()), dbops.CreateFunction(ArraySliceFunction()), dbops.CreateFunction(StringIndexWithBoundsFunction()), dbops.CreateFunction(LengthStringProxyFunction()), dbops.CreateFunction(LengthBytesProxyFunction()), dbops.CreateFunction(SubstrProxyFunction()), dbops.CreateFunction(StringSliceImplFunction()), dbops.CreateFunction(StringSliceFunction()), dbops.CreateFunction(BytesSliceFunction()), dbops.CreateFunction(JSONIndexByTextFunction()), dbops.CreateFunction(JSONIndexByIntFunction()), dbops.CreateFunction(JSONSliceFunction()), dbops.CreateFunction(DatetimeInFunction()), dbops.CreateFunction(DurationInFunction()), dbops.CreateFunction(DateDurationInFunction()), dbops.CreateFunction(LocalDatetimeInFunction()), dbops.CreateFunction(LocalDateInFunction()), dbops.CreateFunction(LocalTimeInFunction()), dbops.CreateFunction(ToTimestampTZCheck()), dbops.CreateFunction(ToDatetimeFunction()), dbops.CreateFunction(ToLocalDatetimeFunction()), dbops.CreateFunction(StrToBool()), dbops.CreateFunction(BytesIndexWithBoundsFunction()), dbops.CreateFunction(TypeIDToConfigType()), dbops.CreateFunction(ConvertPostgresConfigUnitsFunction()), dbops.CreateFunction(InterpretConfigValueToJsonFunction()), dbops.CreateFunction( PostgresJsonConfigValueToFrontendConfigValueFunction(config_spec)), dbops.CreateFunction(PostgresConfigValueToJsonFunction()), dbops.CreateFunction(SysConfigFullFunction()), dbops.CreateFunction(SysConfigUncachedFunction()), dbops.Query(pgcon.SETUP_CONFIG_CACHE_SCRIPT), dbops.CreateFunction(SysConfigFunction()), dbops.CreateFunction(SysClearConfigCacheFunction()), dbops.CreateFunction(ResetSessionConfigFunction()), dbops.CreateFunction(ApplySessionConfigFunction(config_spec)), dbops.CreateFunction(SysGetTransactionIsolation()), 
dbops.CreateFunction(GetCachedReflection()), dbops.CreateFunction(GetBaseScalarTypeMap()), dbops.CreateFunction(GetTypeToRangeNameMap()), dbops.CreateFunction(GetTypeToMultiRangeNameMap()), dbops.CreateFunction(GetPgTypeForEdgeDBTypeFunction()), dbops.CreateFunction(DescribeRolesAsDDLFunctionForwardDecl()), dbops.CreateFunction(AllRoleMembershipsFunctionForwardDecl()), dbops.CreateFunction(RangeToJsonFunction()), dbops.CreateFunction(MultiRangeToJsonFunction()), dbops.CreateFunction(RangeValidateFunction()), dbops.CreateFunction(RangeUnpackLowerValidateFunction()), dbops.CreateFunction(RangeUnpackUpperValidateFunction()), dbops.CreateFunction(FTSParseQueryFunction()), dbops.CreateFunction(FTSNormalizeWeightFunction()), dbops.CreateFunction(FTSNormalizeDocFunction()), dbops.CreateFunction(FTSToRegconfig()), dbops.CreateFunction(PadBase64StringFunction()), dbops.CreateFunction(ResetQueryStatsFunction(False)), dbops.CreateFunction(ApproximateCountDummy()), ] non_trampolined = [ dbops.CreateFunction(ClearFELocalSQLSettingsFunction()), ] commands = dbops.CommandGroup() commands.add_commands(trampolined) commands.add_commands(non_trampolined) return commands, trampoline_functions(trampolined) async def create_pg_extensions( conn: PGConnection, backend_params: params.BackendRuntimeParams, ) -> None: inst_params = backend_params.instance_params ext_schema = inst_params.ext_schema # Both the extension schema, and the desired extension # might already exist in a single database backend, # attempt to create things conditionally. 
commands = dbops.CommandGroup() commands.add_command( dbops.CreateSchema(name=ext_schema, conditional=True), ) extensions = ["uuid-ossp"] if backend_params.has_stat_statements: extensions.append("edb_stat_statements") for ext in extensions: if ( inst_params.existing_exts is None or inst_params.existing_exts.get(ext) is None ): commands.add_commands([ dbops.CreateExtension( dbops.Extension(name=ext, schema=ext_schema), ), ]) block = dbops.PLTopBlock() commands.generate(block) await _execute_block(conn, block) async def patch_pg_extensions( conn: PGConnection, backend_params: params.BackendRuntimeParams, ) -> None: # A single database backend might restrict creation of extensions # to a specific schema, or restrict creation of extensions altogether # and provide a way to register them using a different method # (e.g. a hosting panel UI). inst_params = backend_params.instance_params if inst_params.existing_exts is not None: uuid_ext_schema = inst_params.existing_exts.get("uuid-ossp") if uuid_ext_schema is None: uuid_ext_schema = inst_params.ext_schema else: uuid_ext_schema = inst_params.ext_schema commands = dbops.CommandGroup() if uuid_ext_schema != "edgedbext": commands.add_commands([ dbops.CreateFunction( UuidGenerateV1mcFunction(uuid_ext_schema), or_replace=True), dbops.CreateFunction( UuidGenerateV4Function(uuid_ext_schema), or_replace=True), dbops.CreateFunction( UuidGenerateV5Function(uuid_ext_schema), or_replace=True), ]) if len(commands) > 0: block = dbops.PLTopBlock() commands.generate(block) await _execute_block(conn, block) classref_attr_aliases = { 'links': 'pointers', 'link_properties': 'pointers' } def tabname( schema: s_schema.Schema, obj: s_obj.Object ) -> tuple[str, str]: return common.get_backend_name( schema, obj, aspect='table', catenate=False, versioned=True, ) def ptr_col_name( schema: s_schema.Schema, obj: s_sources.Source, propname: str, ) -> str: prop = obj.getptr(schema, s_name.UnqualName(propname)) psi = types.get_pointer_storage_info(prop, 
schema=schema) return psi.column_name def format_fields( schema: s_schema.Schema, obj: s_sources.Source, fields: dict[str, str], ) -> str: """Format a dictionary of column mappings for database views The reason we do it this way is because, since these views are overwriting existing temporary views, we need to put all the columns in the same order as the original view. """ ptrs = [obj.getptr(schema, s_name.UnqualName(s)) for s in fields] # Sort by the order the pointers were added to the source. # N.B: This only works because we are using the original in-memory # schema. If it was loaded from reflection it probably wouldn't # work. ptr_indexes = { v: i for i, v in enumerate(obj.get_pointers(schema).objects(schema)) } ptrs.sort(key=( lambda p: (not p.is_link_source_property(schema), ptr_indexes[p]) )) cols = [] for ptr in ptrs: name = ptr.get_shortname(schema).name val = fields[name] sname = qi(ptr_col_name(schema, obj, name)) cols.append(f' {val} AS {sname}') return ',\n'.join(cols) def _generate_branch_views(schema: s_schema.Schema) -> list[dbops.View]: Branch = schema.get('sys::Branch', type=s_objtypes.ObjectType) annos = Branch.getptr( schema, s_name.UnqualName('annotations'), type=s_links.Link) int_annos = Branch.getptr( schema, s_name.UnqualName('annotations__internal'), type=s_links.Link) view_fields = { 'id': "((d.description)->>'id')::uuid", 'internal': f"""(CASE WHEN (edgedb_VER.get_backend_capabilities() & {int(params.BackendCapabilities.CREATE_DATABASE)}) != 0 THEN datname IN ( edgedb_VER.get_database_backend_name( {ql(defines.EDGEDB_TEMPLATE_DB)}), edgedb_VER.get_database_backend_name( {ql(defines.EDGEDB_SYSTEM_DB)}) ) ELSE False END )""", 'name': ( 'edgedb_VER.get_database_frontend_name(datname) COLLATE "default"' ), 'name__internal': ( 'edgedb_VER.get_database_frontend_name(datname) COLLATE "default"' ), 'computed_fields': 'ARRAY[]::text[]', 'builtin': "((d.description)->>'builtin')::bool", 'last_migration': "(d.description)->>'last_migration'", } 
view_query = f''' SELECT {format_fields(schema, Branch, view_fields)} FROM pg_database dat CROSS JOIN LATERAL ( SELECT edgedb_VER.shobj_metadata(dat.oid, 'pg_database') AS description ) AS d WHERE (d.description)->>'id' IS NOT NULL AND (d.description)->>'tenant_id' = edgedb_VER.get_backend_tenant_id() ''' annos_link_fields = { 'source': "((d.description)->>'id')::uuid", 'target': "(annotations->>'id')::uuid", 'value': "(annotations->>'value')::text", 'owned': "(annotations->>'owned')::bool", } annos_link_query = f''' SELECT {format_fields(schema, annos, annos_link_fields)} FROM pg_database dat CROSS JOIN LATERAL ( SELECT edgedb_VER.shobj_metadata(dat.oid, 'pg_database') AS description ) AS d CROSS JOIN LATERAL ROWS FROM ( jsonb_array_elements((d.description)->'annotations') ) AS annotations ''' int_annos_link_fields = { 'source': "((d.description)->>'id')::uuid", 'target': "(annotations->>'id')::uuid", 'owned': "(annotations->>'owned')::bool", } int_annos_link_query = f''' SELECT {format_fields(schema, int_annos, int_annos_link_fields)} FROM pg_database dat CROSS JOIN LATERAL ( SELECT edgedb_VER.shobj_metadata(dat.oid, 'pg_database') AS description ) AS d CROSS JOIN LATERAL ROWS FROM ( jsonb_array_elements( (d.description)->'annotations__internal' ) ) AS annotations ''' objects = { Branch: view_query, annos: annos_link_query, int_annos: int_annos_link_query, } views: list[dbops.View] = [] for obj, query in objects.items(): tabview = trampoline.VersionedView( name=tabname(schema, obj), query=query) views.append(tabview) return views def _generate_extension_views(schema: s_schema.Schema) -> list[dbops.View]: ExtPkg = schema.get('sys::ExtensionPackage', type=s_objtypes.ObjectType) annos = ExtPkg.getptr( schema, s_name.UnqualName('annotations'), type=s_links.Link) int_annos = ExtPkg.getptr( schema, s_name.UnqualName('annotations__internal'), type=s_links.Link) ver = ExtPkg.getptr( schema, s_name.UnqualName('version'), type=s_props.Property) ver_t = 
common.get_backend_name( schema, not_none(ver.get_target(schema)), catenate=False, ) view_query_fields = { 'id': "(e.value->>'id')::uuid", 'name': "(e.value->>'name')", 'name__internal': "(e.value->>'name__internal')", 'script': "(e.value->>'script')", 'sql_extensions': ''' COALESCE( (SELECT array_agg(edgedb_VER.jsonb_extract_scalar(q.v, 'string')) FROM jsonb_array_elements( e.value->'sql_extensions' ) AS q(v)), ARRAY[]::text[] ) ''', 'dependencies': ''' COALESCE( (SELECT array_agg(edgedb_VER.jsonb_extract_scalar(q.v, 'string')) FROM jsonb_array_elements( e.value->'dependencies' ) AS q(v)), ARRAY[]::text[] ) ''', 'ext_module': "(e.value->>'ext_module')", 'sql_setup_script': "(e.value->>'sql_setup_script')", 'sql_teardown_script': "(e.value->>'sql_teardown_script')", 'computed_fields': 'ARRAY[]::text[]', 'builtin': "(e.value->>'builtin')::bool", 'internal': "(e.value->>'internal')::bool", 'version': f''' ( (e.value->'version'->>'major')::int, (e.value->'version'->>'minor')::int, (e.value->'version'->>'stage')::text, (e.value->'version'->>'stage_no')::int, COALESCE( (SELECT array_agg(q.v::text) FROM jsonb_array_elements( e.value->'version'->'local' ) AS q(v)), ARRAY[]::text[] ) )::{qt(ver_t)} ''', } view_query = f''' SELECT {format_fields(schema, ExtPkg, view_query_fields)} FROM jsonb_each( edgedb_VER.get_database_metadata( {ql(defines.EDGEDB_TEMPLATE_DB)} ) -> 'ExtensionPackage' ) AS e ''' annos_link_fields = { 'source': "(e.value->>'id')::uuid", 'target': "(annotations->>'id')::uuid", 'value': "(annotations->>'value')::text", 'owned': "(annotations->>'owned')::bool", } int_annos_link_fields = { 'source': "(e.value->>'id')::uuid", 'target': "(annotations->>'id')::uuid", 'owned': "(annotations->>'owned')::bool", } annos_link_query = f''' SELECT {format_fields(schema, annos, annos_link_fields)} FROM jsonb_each( edgedb_VER.get_database_metadata( {ql(defines.EDGEDB_TEMPLATE_DB)} ) -> 'ExtensionPackage' ) AS e CROSS JOIN LATERAL ROWS FROM ( 
jsonb_array_elements(e.value->'annotations') ) AS annotations ''' int_annos_link_query = f''' SELECT {format_fields(schema, int_annos, int_annos_link_fields)} FROM jsonb_each( edgedb_VER.get_database_metadata( {ql(defines.EDGEDB_TEMPLATE_DB)} ) -> 'ExtensionPackage' ) AS e CROSS JOIN LATERAL ROWS FROM ( jsonb_array_elements(e.value->'annotations__internal') ) AS annotations ''' objects = { ExtPkg: view_query, annos: annos_link_query, int_annos: int_annos_link_query, } views: list[dbops.View] = [] for obj, query in objects.items(): tabview = trampoline.VersionedView( name=tabname(schema, obj), query=query) views.append(tabview) return views def _generate_extension_migration_views( schema: s_schema.Schema ) -> list[dbops.View]: ExtPkgMigration = schema.get( 'sys::ExtensionPackageMigration', type=s_objtypes.ObjectType) annos = ExtPkgMigration.getptr( schema, s_name.UnqualName('annotations'), type=s_links.Link) int_annos = ExtPkgMigration.getptr( schema, s_name.UnqualName('annotations__internal'), type=s_links.Link) from_ver = ExtPkgMigration.getptr( schema, s_name.UnqualName('from_version'), type=s_props.Property) ver_t = common.get_backend_name( schema, not_none(from_ver.get_target(schema)), catenate=False, ) view_query_fields = { 'id': "(e.value->>'id')::uuid", 'name': "(e.value->>'name')", 'name__internal': "(e.value->>'name__internal')", 'script': "(e.value->>'script')", 'sql_early_script': "(e.value->>'sql_early_script')", 'sql_late_script': "(e.value->>'sql_late_script')", 'computed_fields': 'ARRAY[]::text[]', 'builtin': "(e.value->>'builtin')::bool", 'internal': "(e.value->>'internal')::bool", # XXX: code duplication here 'from_version': f''' ( (e.value->'from_version'->>'major')::int, (e.value->'from_version'->>'minor')::int, (e.value->'from_version'->>'stage')::text, (e.value->'from_version'->>'stage_no')::int, COALESCE( (SELECT array_agg(q.v::text) FROM jsonb_array_elements( e.value->'from_version'->'local' ) AS q(v)), ARRAY[]::text[] ) )::{qt(ver_t)} ''', 
'to_version': f''' ( (e.value->'to_version'->>'major')::int, (e.value->'to_version'->>'minor')::int, (e.value->'to_version'->>'stage')::text, (e.value->'to_version'->>'stage_no')::int, COALESCE( (SELECT array_agg(q.v::text) FROM jsonb_array_elements( e.value->'to_version'->'local' ) AS q(v)), ARRAY[]::text[] ) )::{qt(ver_t)} ''', } view_query = f''' SELECT {format_fields(schema, ExtPkgMigration, view_query_fields)} FROM jsonb_each( edgedb_VER.get_database_metadata( {ql(defines.EDGEDB_TEMPLATE_DB)} ) -> 'ExtensionPackageMigration' ) AS e ''' annos_link_fields = { 'source': "(e.value->>'id')::uuid", 'target': "(annotations->>'id')::uuid", 'value': "(annotations->>'value')::text", 'owned': "(annotations->>'owned')::bool", } int_annos_link_fields = { 'source': "(e.value->>'id')::uuid", 'target': "(annotations->>'id')::uuid", 'owned': "(annotations->>'owned')::bool", } annos_link_query = f''' SELECT {format_fields(schema, annos, annos_link_fields)} FROM jsonb_each( edgedb_VER.get_database_metadata( {ql(defines.EDGEDB_TEMPLATE_DB)} ) -> 'ExtensionPackageMigration' ) AS e CROSS JOIN LATERAL ROWS FROM ( jsonb_array_elements(e.value->'annotations') ) AS annotations ''' int_annos_link_query = f''' SELECT {format_fields(schema, int_annos, int_annos_link_fields)} FROM jsonb_each( edgedb_VER.get_database_metadata( {ql(defines.EDGEDB_TEMPLATE_DB)} ) -> 'ExtensionPackageMigration' ) AS e CROSS JOIN LATERAL ROWS FROM ( jsonb_array_elements(e.value->'annotations__internal') ) AS annotations ''' objects = { ExtPkgMigration: view_query, annos: annos_link_query, int_annos: int_annos_link_query, } views: list[dbops.View] = [] for obj, query in objects.items(): tabview = trampoline.VersionedView( name=tabname(schema, obj), query=query) views.append(tabview) return views def _generate_role_views(schema: s_schema.Schema) -> list[dbops.View]: Role = schema.get('sys::Role', type=s_objtypes.ObjectType) member_of = Role.getptr( schema, s_name.UnqualName('member_of'), type=s_links.Link ) bases 
= Role.getptr( schema, s_name.UnqualName('bases'), type=s_links.Link ) ancestors = Role.getptr( schema, s_name.UnqualName('ancestors'), type=s_links.Link ) annos = Role.getptr( schema, s_name.UnqualName('annotations'), type=s_links.Link ) int_annos = Role.getptr( schema, s_name.UnqualName('annotations__internal'), type=s_links.Link ) permissions = Role.getptr( schema, s_name.UnqualName('permissions'), type=s_props.Property ) branches = Role.getptr( schema, s_name.UnqualName('branches'), type=s_props.Property ) superuser = f''' a.rolsuper OR EXISTS ( SELECT FROM pg_auth_members m INNER JOIN pg_catalog.pg_roles g ON (m.roleid = g.oid) WHERE m.member = a.oid AND g.rolname = edgedb_VER.get_role_backend_name( {ql(defines.EDGEDB_SUPERGROUP)} ) ) ''' view_query_fields = { 'id': "((d.description)->>'id')::uuid", 'name': "(d.description)->>'name'", 'name__internal': "(d.description)->>'name'", 'superuser': f'{superuser}', 'abstract': 'False', 'is_derived': 'False', 'inherited_fields': 'ARRAY[]::text[]', 'computed_fields': 'ARRAY[]::text[]', 'builtin': "((d.description)->>'builtin')::bool", 'internal': 'False', 'password': "(d.description)->>'password_hash'", 'apply_access_policies_pg_default': ( "((d.description)->>'apply_access_policies_pg_default')::bool" ), } view_query = f''' SELECT {format_fields(schema, Role, view_query_fields)} FROM pg_catalog.pg_roles AS a CROSS JOIN LATERAL ( SELECT edgedb_VER.shobj_metadata(a.oid, 'pg_authid') AS description ) AS d WHERE (d.description)->>'id' IS NOT NULL AND (d.description)->>'tenant_id' = edgedb_VER.get_backend_tenant_id() ''' member_of_link_query_fields = { 'source': "((d.description)->>'id')::uuid", 'target': "((md.description)->>'id')::uuid", } member_of_link_query = f''' SELECT {format_fields(schema, member_of, member_of_link_query_fields)} FROM pg_catalog.pg_roles AS a CROSS JOIN LATERAL ( SELECT edgedb_VER.shobj_metadata(a.oid, 'pg_authid') AS description ) AS d INNER JOIN pg_auth_members m ON m.member = a.oid CROSS JOIN 
LATERAL ( SELECT edgedb_VER.shobj_metadata(m.roleid, 'pg_authid') AS description ) AS md ''' bases_link_query_fields = { 'source': "((d.description)->>'id')::uuid", 'target': "((md.description)->>'id')::uuid", 'index': 'row_number() OVER (PARTITION BY a.oid ORDER BY m.roleid)', } bases_link_query = f''' SELECT {format_fields(schema, bases, bases_link_query_fields)} FROM pg_catalog.pg_roles AS a CROSS JOIN LATERAL ( SELECT edgedb_VER.shobj_metadata(a.oid, 'pg_authid') AS description ) AS d INNER JOIN pg_auth_members m ON m.member = a.oid CROSS JOIN LATERAL ( SELECT edgedb_VER.shobj_metadata(m.roleid, 'pg_authid') AS description ) AS md ''' ancestors_link_query = f''' SELECT {format_fields(schema, ancestors, bases_link_query_fields)} FROM pg_catalog.pg_roles AS a CROSS JOIN LATERAL ( SELECT edgedb_VER.shobj_metadata(a.oid, 'pg_authid') AS description ) AS d INNER JOIN pg_auth_members m ON m.member = a.oid CROSS JOIN LATERAL ( SELECT edgedb_VER.shobj_metadata(m.roleid, 'pg_authid') AS description ) AS md ''' annos_link_fields = { 'source': "((d.description)->>'id')::uuid", 'target': "(annotations->>'id')::uuid", 'value': "(annotations->>'value')::text", 'owned': "(annotations->>'owned')::bool", } annos_link_query = f''' SELECT {format_fields(schema, annos, annos_link_fields)} FROM pg_catalog.pg_roles AS a CROSS JOIN LATERAL ( SELECT edgedb_VER.shobj_metadata(a.oid, 'pg_authid') AS description ) AS d CROSS JOIN LATERAL ROWS FROM ( jsonb_array_elements( (d.description)->'annotations' ) ) AS annotations ''' int_annos_link_fields = { 'source': "((d.description)->>'id')::uuid", 'target': "(annotations->>'id')::uuid", 'owned': "(annotations->>'owned')::bool", } int_annos_link_query = f''' SELECT {format_fields(schema, int_annos, int_annos_link_fields)} FROM pg_catalog.pg_roles AS a CROSS JOIN LATERAL ( SELECT edgedb_VER.shobj_metadata(a.oid, 'pg_authid') AS description ) AS d CROSS JOIN LATERAL ROWS FROM ( jsonb_array_elements( (d.description)->'annotations__internal' ) ) 
AS annotations ''' permissions_query = f''' SELECT ((d.description)->>'id')::uuid AS source, jsonb_array_elements_text( (d.description)->'permissions' )::text as target FROM pg_catalog.pg_roles AS a CROSS JOIN LATERAL ( SELECT edgedb_VER.shobj_metadata(a.oid, 'pg_authid') AS description ) AS d WHERE (d.description)->>'id' IS NOT NULL AND (d.description)->>'tenant_id' = edgedb_VER.get_backend_tenant_id() ''' branches_query = f''' SELECT ((d.description)->>'id')::uuid AS source, jsonb_array_elements_text( -- The coalesce is to handle inplace upgrades from versions -- before the field was added. If it is lacking from the dict, -- make it ['*']. coalesce((d.description)->'branches', '["*"]'::jsonb) )::text as target FROM pg_catalog.pg_roles AS a CROSS JOIN LATERAL ( SELECT edgedb_VER.shobj_metadata(a.oid, 'pg_authid') AS description ) AS d WHERE (d.description)->>'id' IS NOT NULL AND (d.description)->>'tenant_id' = edgedb_VER.get_backend_tenant_id() ''' objects = { Role: view_query, member_of: member_of_link_query, bases: bases_link_query, ancestors: ancestors_link_query, annos: annos_link_query, int_annos: int_annos_link_query, permissions: permissions_query, branches: branches_query, } views: list[dbops.View] = [] for obj, query in objects.items(): tabview = trampoline.VersionedView( name=tabname(schema, obj), query=query) views.append(tabview) return views def _generate_single_role_views(schema: s_schema.Schema) -> list[dbops.View]: Role = schema.get('sys::Role', type=s_objtypes.ObjectType) member_of = Role.getptr( schema, s_name.UnqualName('member_of'), type=s_links.Link ) bases = Role.getptr( schema, s_name.UnqualName('bases'), type=s_links.Link ) ancestors = Role.getptr( schema, s_name.UnqualName('ancestors'), type=s_links.Link ) annos = Role.getptr( schema, s_name.UnqualName('annotations'), type=s_links.Link ) int_annos = Role.getptr( schema, s_name.UnqualName('annotations__internal'), type=s_links.Link ) permissions = Role.getptr( schema, 
s_name.UnqualName('permissions'), type=s_props.Property ) branches = Role.getptr( schema, s_name.UnqualName('branches'), type=s_props.Property ) view_query_fields = { 'id': "(json->>'id')::uuid", 'name': "json->>'name'", 'name__internal': "json->>'name'", 'superuser': 'True', 'abstract': 'False', 'is_derived': 'False', 'inherited_fields': 'ARRAY[]::text[]', 'computed_fields': 'ARRAY[]::text[]', 'builtin': 'True', 'internal': 'False', 'password': "json->>'password_hash'", 'apply_access_policies_pg_default': ( "(json->>'pg_apply_access_policies_default')::bool" ), } view_query = f''' SELECT {format_fields(schema, Role, view_query_fields)} FROM edgedbinstdata_VER.instdata WHERE key = 'single_role_metadata' AND json->>'tenant_id' = edgedb_VER.get_backend_tenant_id() ''' member_of_link_query_fields = { 'source': "'00000000-0000-0000-0000-000000000000'::uuid", 'target': "'00000000-0000-0000-0000-000000000000'::uuid", } member_of_link_query = f''' SELECT {format_fields(schema, member_of, member_of_link_query_fields)} LIMIT 0 ''' bases_link_query_fields = { 'source': "'00000000-0000-0000-0000-000000000000'::uuid", 'target': "'00000000-0000-0000-0000-000000000000'::uuid", 'index': "0::bigint", } bases_link_query = f''' SELECT {format_fields(schema, bases, bases_link_query_fields)} LIMIT 0 ''' ancestors_link_query = f''' SELECT {format_fields(schema, ancestors, bases_link_query_fields)} LIMIT 0 ''' annos_link_fields = { 'source': "(json->>'id')::uuid", 'target': "(annotations->>'id')::uuid", 'value': "(annotations->>'value')::text", 'owned': "(annotations->>'owned')::bool", } annos_link_query = f''' SELECT {format_fields(schema, annos, annos_link_fields)} FROM edgedbinstdata_VER.instdata CROSS JOIN LATERAL ROWS FROM ( jsonb_array_elements(json->'annotations') ) AS annotations WHERE key = 'single_role_metadata' AND json->>'tenant_id' = edgedb_VER.get_backend_tenant_id() ''' int_annos_link_fields = { 'source': "(json->>'id')::uuid", 'target': "(annotations->>'id')::uuid", 
'owned': "(annotations->>'owned')::bool", } int_annos_link_query = f''' SELECT {format_fields(schema, int_annos, int_annos_link_fields)} FROM edgedbinstdata_VER.instdata CROSS JOIN LATERAL ROWS FROM ( jsonb_array_elements(json->'annotations__internal') ) AS annotations WHERE key = 'single_role_metadata' AND json->>'tenant_id' = edgedb_VER.get_backend_tenant_id() ''' # The single superuser role already has all permissions. # For completeness, create a permissions multi-prop table with dummy # values. It will return no rows since its WHERE clause is always false. permissions_query = f''' SELECT '00000000-0000-0000-0000-000000000000'::uuid AS source, ''::text AS target WHERE 1 = 0 ''' branches_query = f''' SELECT (json->>'id')::uuid AS source, '*'::text as target FROM edgedbinstdata_VER.instdata WHERE key = 'single_role_metadata' AND json->>'tenant_id' = edgedb_VER.get_backend_tenant_id() ''' objects = { Role: view_query, member_of: member_of_link_query, bases: bases_link_query, ancestors: ancestors_link_query, annos: annos_link_query, int_annos: int_annos_link_query, permissions: permissions_query, branches: branches_query, } views: list[dbops.View] = [] for obj, query in objects.items(): tabview = trampoline.VersionedView( name=tabname(schema, obj), query=query) views.append(tabview) return views def _generate_schema_ver_views(schema: s_schema.Schema) -> list[dbops.View]: Ver = schema.get( 'sys::GlobalSchemaVersion', type=s_objtypes.ObjectType, ) view_fields = { 'id': "(v.value->>'id')::uuid", 'name': "(v.value->>'name')", 'name__internal': "(v.value->>'name')", 'version': "(v.value->>'version')::uuid", 'builtin': "(v.value->>'builtin')::bool", 'internal': "(v.value->>'internal')::bool", 'computed_fields': 'ARRAY[]::text[]', } view_query = f''' SELECT {format_fields(schema, Ver, view_fields)} FROM jsonb_each( edgedb_VER.get_database_metadata( {ql(defines.EDGEDB_TEMPLATE_DB)} ) -> 'GlobalSchemaVersion' ) AS v ''' objects = { Ver: view_query } views: list[dbops.View] 
= [] for obj, query in objects.items(): tabview = trampoline.VersionedView( name=tabname(schema, obj), query=query) views.append(tabview) return views def _generate_stats_views(schema: s_schema.Schema) -> list[dbops.View]: QueryStats = schema.get( 'sys::QueryStats', type=s_objtypes.ObjectType, ) pvd = common.get_backend_name( schema, QueryStats .getptr(schema, s_name.UnqualName("protocol_version")) .get_target(schema) # type: ignore ) QueryType = schema.get( 'sys::QueryType', type=s_scalars.ScalarType, ) query_type_domain = common.get_backend_name(schema, QueryType) type_mapping = { str(v): k for k, v in defines.QueryType.__members__.items() } output_format_domain = common.get_backend_name( schema, schema.get('sys::OutputFormat', type=s_scalars.ScalarType) ) def float64_to_duration_t(val: str) -> str: return f"({val} * interval '1ms')::edgedbt.duration_t" query_stats_fields = { 'id': "s.id", 'name': "s.id::text", 'name__internal': "s.queryid::text", 'builtin': "false", 'internal': "false", 'computed_fields': 'ARRAY[]::text[]', 'compilation_config': "s.extras->'cc'", 'protocol_version': f"ROW(s.extras->'pv'->0, s.extras->'pv'->1)::{pvd}", 'default_namespace': "s.extras->>'dn'", 'namespace_aliases': "s.extras->'na'", 'output_format': f"(s.extras->>'of')::{output_format_domain}", 'expect_one': "(s.extras->'e1')::boolean", 'implicit_limit': "(s.extras->'il')::bigint", 'inline_typeids': "(s.extras->'ii')::boolean", 'inline_typenames': "(s.extras->'in')::boolean", 'inline_objectids': "(s.extras->'io')::boolean", 'branch': "((d.description)->>'id')::uuid", 'query': "s.query", 'query_type': f"(t.mapping->>s.stmt_type::text)::{query_type_domain}", 'tag': "s.tag", 'plans': 's.plans', 'total_plan_time': float64_to_duration_t('s.total_plan_time'), 'min_plan_time': float64_to_duration_t('s.min_plan_time'), 'max_plan_time': float64_to_duration_t('s.max_plan_time'), 'mean_plan_time': float64_to_duration_t('s.mean_plan_time'), 'stddev_plan_time': 
float64_to_duration_t('s.stddev_plan_time'), 'calls': 's.calls', 'total_exec_time': float64_to_duration_t('s.total_exec_time'), 'min_exec_time': float64_to_duration_t('s.min_exec_time'), 'max_exec_time': float64_to_duration_t('s.max_exec_time'), 'mean_exec_time': float64_to_duration_t('s.mean_exec_time'), 'stddev_exec_time': float64_to_duration_t('s.stddev_exec_time'), 'rows': 's.rows', 'stats_since': 's.stats_since::edgedbt.timestamptz_t', 'minmax_stats_since': 's.minmax_stats_since::edgedbt.timestamptz_t', } query_stats_query = fr''' SELECT {format_fields(schema, QueryStats, query_stats_fields)} FROM edgedbext.edb_stat_statements AS s INNER JOIN pg_database dat ON s.dbid = dat.oid CROSS JOIN LATERAL ( SELECT edgedb_VER.shobj_metadata(dat.oid, 'pg_database') AS description ) AS d CROSS JOIN LATERAL ( SELECT {ql(json.dumps(type_mapping))}::jsonb AS mapping ) AS t WHERE s.id IS NOT NULL AND (d.description)->>'id' IS NOT NULL AND (d.description)->>'tenant_id' = edgedb_VER.get_backend_tenant_id() AND t.mapping ? 
s.stmt_type::text ''' objects = { QueryStats: query_stats_query, } views: list[dbops.View] = [] for obj, query in objects.items(): tabview = trampoline.VersionedView( name=tabname(schema, obj), query=query) views.append(tabview) return views def _make_json_caster( schema: s_schema.Schema, stype: s_types.Type, versioned: bool, ) -> Callable[[str], str]: cast_expr = qlast.TypeCast( expr=qlast.TypeCast( expr=qlast.FunctionParameter(name="__replaceme__"), type=s_utils.typeref_to_ast(schema, schema.get('std::json')), ), type=s_utils.typeref_to_ast(schema, stype), ) cast_ir = qlcompiler.compile_ast_fragment_to_ir( cast_expr, schema, ) cast_sql_res = compiler.compile_ir_to_sql_tree( cast_ir, named_param_prefix=(), singleton_mode=True, versioned_singleton=versioned, ) cast_sql = codegen.generate_source(cast_sql_res.ast) return lambda val: cast_sql.replace('__replaceme__', val) def _generate_schema_alias_views( schema: s_schema.Schema, module: s_name.UnqualName, ) -> list[dbops.View]: views = [] schema_objs = schema.get_objects( type=s_objtypes.ObjectType, included_modules=(module,), ) for schema_obj in schema_objs: if not schema_obj.get_from_alias(schema): views.append(_generate_schema_alias_view(schema, schema_obj)) return views def _generate_schema_alias_view( schema: s_schema.Schema, obj: s_sources.Source | s_pointers.Pointer, ) -> dbops.View: name = _schema_alias_view_name(schema, obj) select = inheritance.get_inheritance_view(schema, obj) return trampoline.VersionedView( name=name, query=codegen.generate_source(select), ) def _schema_alias_view_name( schema: s_schema.Schema, obj: s_sources.Source | s_pointers.Pointer, ) -> tuple[str, str]: module = obj.get_name(schema).module prefix = module.capitalize() if isinstance(obj, s_links.Link): objtype = obj.get_source(schema) assert objtype is not None objname = objtype.get_name(schema).name lname = obj.get_shortname(schema).name name = f'_{prefix}{objname}__{lname}' else: name = f'_{prefix}{obj.get_name(schema).name}' 
return ('edgedb', name) def _generate_sql_information_schema( backend_version: params.BackendVersion ) -> list[dbops.Command]: # Helper to create wrappers around materialized views. For # performance, we use MATERIALIZED VIEW for some of our SQL # emulation tables. Unfortunately we can't use those directly, # since we need tableoid to match the real pg_catalog table. def make_wrapper_view(name: str) -> trampoline.VersionedView: return trampoline.VersionedView( name=("edgedbsql", name), query=f""" SELECT *, 'pg_catalog.{name}'::regclass::oid as tableoid, xmin, cmin, xmax, cmax, ctid FROM edgedbsql_VER.{name}_ """, ) # A helper view that contains all data tables we expose over SQL, excluding # introspection tables. # It contains table & schema names and associated module id. virtual_tables = trampoline.VersionedView( name=('edgedbsql', 'virtual_tables'), materialized=True, query=''' WITH obj_ty_pre AS ( SELECT id, REGEXP_REPLACE(name, '::[^:]*$', '') AS module_name, REGEXP_REPLACE(name, '^.*::', '') as table_name FROM edgedb_VER."_SchemaObjectType" WHERE internal IS NOT TRUE ), obj_ty AS ( SELECT id, REGEXP_REPLACE(module_name, '^default(?=::|$)', 'public') AS schema_name, module_name, table_name FROM obj_ty_pre ), all_tables (id, schema_name, module_name, table_name) AS (( SELECT * FROM obj_ty ) UNION ALL ( WITH qualified_links AS ( -- multi links and links with at least one property -- (besides source and target) SELECT link.id FROM edgedb_VER."_SchemaLink" link JOIN edgedb_VER."_SchemaProperty" AS prop ON link.id = prop.source WHERE prop.computable IS NOT TRUE AND prop.internal IS NOT TRUE GROUP BY link.id, link.cardinality HAVING link.cardinality = 'Many' OR COUNT(*) > 2 ) SELECT link.id, obj_ty.schema_name, obj_ty.module_name, CONCAT(obj_ty.table_name, '.', link.name) AS table_name FROM edgedb_VER."_SchemaLink" link JOIN obj_ty ON obj_ty.id = link.source WHERE link.id IN (SELECT * FROM qualified_links) ) UNION ALL ( -- multi properties SELECT prop.id, 
obj_ty.schema_name, obj_ty.module_name, CONCAT(obj_ty.table_name, '.', prop.name) AS table_name FROM edgedb_VER."_SchemaProperty" AS prop JOIN obj_ty ON obj_ty.id = prop.source WHERE prop.computable IS NOT TRUE AND prop.internal IS NOT TRUE AND prop.cardinality = 'Many' )) SELECT at.id, schema_name, table_name, sm.id AS module_id, pt.oid AS pg_type_id FROM all_tables at JOIN edgedb_VER."_SchemaModule" sm ON sm.name = at.module_name LEFT JOIN pg_type pt ON pt.typname = at.id::text WHERE schema_name not in ( 'cfg', 'sys', 'schema', 'std', 'std::net', 'std::net::http', 'std::net::perm' ) ''' ) # A few tables in here were causing problems, so let's hide them as an # implementation detail. # To be more specific: # - following tables were missing from information_schema: # Link.properties, ObjectType.links, ObjectType.properties # - even though introspection worked, I wasn't able to select from some # tables in cfg and sys # For making up oids of schemas that represent modules uuid_to_oid = trampoline.VersionedFunction( name=('edgedbsql', 'uuid_to_oid'), args=( ('id', 'uuid'), # extra is two extra bits to throw into the oid, for now ('extra', 'int4', '0'), ), returns=('oid',), volatility='immutable', text=""" SELECT ( ('x' || substring(id::text, 2, 7))::bit(28)::bigint*4 + extra + 40000)::oid; """ ) long_name = trampoline.VersionedFunction( name=('edgedbsql', '_long_name'), args=[ ('origname', ('text',)), ('longname', ('text',)), ], returns=('text',), volatility='stable', text=r''' SELECT CASE WHEN length(longname) > 63 THEN left(longname, 55) || left(origname, 8) ELSE longname END ''' ) type_rename = trampoline.VersionedFunction( name=('edgedbsql', '_pg_type_rename'), args=[ ('typeoid', ('oid',)), ('typename', ('name',)), ], returns=('name',), volatility='stable', text=r''' SELECT COALESCE ( -- is the name in virtual_tables? ( SELECT vt.table_name::name FROM edgedbsql_VER.virtual_tables vt WHERE vt.pg_type_id = typeoid ), -- is this a scalar or tuple? 
( SELECT name::name FROM ( -- get the user-defined (non-builtin) scalars SELECT split_part(name, '::', 2) AS name, backend_id FROM edgedb_VER."_SchemaScalarType" WHERE NOT builtin AND arg_values IS NULL UNION ALL -- get the tuples SELECT edgedbsql_VER._long_name(typename, name), backend_id FROM edgedb_VER."_SchemaTuple" ) x WHERE x.backend_id = typeoid ), typename ) ''' ) namespace_rename = trampoline.VersionedFunction( name=('edgedbsql', '_pg_namespace_rename'), args=[ ('typeoid', ('oid',)), ('typens', ('oid',)), ], returns=('oid',), volatility='stable', text=r''' WITH nspub AS ( SELECT oid FROM pg_namespace WHERE nspname = 'edgedbpub' ), nsdef AS ( SELECT edgedbsql_VER.uuid_to_oid(id) AS oid FROM edgedb_VER."_SchemaModule" WHERE name = 'default' ) SELECT COALESCE ( ( SELECT edgedbsql_VER.uuid_to_oid(vt.module_id) FROM edgedbsql_VER.virtual_tables vt WHERE vt.pg_type_id = typeoid ), -- just replace "edgedbpub" with "public" (SELECT nsdef.oid WHERE typens = nspub.oid), typens ) FROM nspub, nsdef ''' ) # pg_settings is a function because "_edgecon_state" is a temporary table # and therefore cannot be used in a view. 
fe_pg_settings = trampoline.VersionedFunction( name=('edgedbsql', 'pg_show_all_settings'), args=[], returns=('pg_catalog', 'pg_settings'), set_returning=True, volatility='volatile', text=''' SELECT p.name, COALESCE( COALESCE(l.value, s.value, d.value) #>> '{}', p.setting ) AS setting, unit, category, short_desc, extra_desc, context, vartype, CASE WHEN l.value IS NOT NULL THEN 'session' WHEN s.value IS NOT NULL THEN 'session' WHEN d.value IS NOT NULL THEN 'default' ELSE p.source END AS source, min_val, max_val, enumvals, boot_val, CASE WHEN d.value IS NOT NULL THEN d.value #>> '{}' ELSE p.reset_val END AS reset_val, sourcefile, sourceline, pending_restart FROM pg_settings p LEFT JOIN _edgecon_state l ON p.name = l.name AND l.type = 'L' LEFT JOIN _edgecon_state s ON p.name = s.name AND s.type = 'S' LEFT JOIN ( SELECT (j->>'name') AS name, (j->'value') AS value FROM edgedbinstdata_VER.instdata CROSS JOIN LATERAL jsonb_array_elements(instdata.json) AS j WHERE key = 'sql_default_fe_settings' ) d ON p.name = d.name ''' ) sql_ident = 'information_schema.sql_identifier' sql_str = 'information_schema.character_data' sql_bool = 'information_schema.yes_or_no' sql_card = 'information_schema.cardinal_number' tables_and_columns = [ trampoline.VersionedView( name=('edgedbsql', 'tables'), query=( f''' SELECT edgedb_VER.get_current_database()::{sql_ident} AS table_catalog, vt.schema_name::{sql_ident} AS table_schema, vt.table_name::{sql_ident} AS table_name, ist.table_type, ist.self_referencing_column_name, ist.reference_generation, ist.user_defined_type_catalog, ist.user_defined_type_schema, ist.user_defined_type_name, ist.is_insertable_into, ist.is_typed, ist.commit_action FROM information_schema.tables ist JOIN edgedbsql_VER.virtual_tables vt ON vt.id::text = ist.table_name ''' ), ), trampoline.VersionedView( name=('edgedbsql', 'columns'), query=( f''' SELECT edgedb_VER.get_current_database()::{sql_ident} AS table_catalog, vt_table_schema::{sql_ident} AS table_schema, 
vt_table_name::{sql_ident} AS table_name, v_column_name::{sql_ident} as column_name, cast(ROW_NUMBER() OVER ( PARTITION BY vt_table_schema, vt_table_name ORDER BY position, v_column_name ) AS INT) AS ordinal_position, column_default, is_nullable, data_type, NULL::{sql_card} AS character_maximum_length, NULL::{sql_card} AS character_octet_length, NULL::{sql_card} AS numeric_precision, NULL::{sql_card} AS numeric_precision_radix, NULL::{sql_card} AS numeric_scale, NULL::{sql_card} AS datetime_precision, NULL::{sql_str} AS interval_type, NULL::{sql_card} AS interval_precision, NULL::{sql_ident} AS character_set_catalog, NULL::{sql_ident} AS character_set_schema, NULL::{sql_ident} AS character_set_name, NULL::{sql_ident} AS collation_catalog, NULL::{sql_ident} AS collation_schema, NULL::{sql_ident} AS collation_name, NULL::{sql_ident} AS domain_catalog, NULL::{sql_ident} AS domain_schema, NULL::{sql_ident} AS domain_name, edgedb_VER.get_current_database()::{sql_ident} AS udt_catalog, 'pg_catalog'::{sql_ident} AS udt_schema, NULL::{sql_ident} AS udt_name, NULL::{sql_ident} AS scope_catalog, NULL::{sql_ident} AS scope_schema, NULL::{sql_ident} AS scope_name, NULL::{sql_card} AS maximum_cardinality, 0::{sql_ident} AS dtd_identifier, 'NO'::{sql_bool} AS is_self_referencing, 'NO'::{sql_bool} AS is_identity, NULL::{sql_str} AS identity_generation, NULL::{sql_str} AS identity_start, NULL::{sql_str} AS identity_increment, NULL::{sql_str} AS identity_maximum, NULL::{sql_str} AS identity_minimum, 'NO' ::{sql_bool} AS identity_cycle, 'NEVER'::{sql_str} AS is_generated, NULL::{sql_str} AS generation_expression, 'YES'::{sql_bool} AS is_updatable FROM ( SELECT vt.schema_name AS vt_table_schema, vt.table_name AS vt_table_name, COALESCE( -- this happens for id and __type__ spec.name, -- fallback to pointer name, with suffix '_id' for links sp.name || case when sl.id is not null then '_id' else '' end ) AS v_column_name, COALESCE(spec.position, 2) AS position, (sp.expr IS NOT NULL) AS 
is_computed, isc.column_default, CASE WHEN sp.required OR spec.k IS NOT NULL THEN 'NO' ELSE 'YES' END AS is_nullable, -- HACK: computeds don't have backing rows in isc, -- so we just default to 'text'. This is wrong. COALESCE(isc.data_type, 'text') AS data_type FROM edgedb_VER."_SchemaPointer" sp LEFT JOIN information_schema.columns isc ON ( isc.table_name = sp.source::TEXT AND CASE WHEN length(isc.column_name) = 36 -- if column name is uuid THEN isc.column_name = sp.id::text -- compare uuids ELSE isc.column_name = sp.name -- for id, source, target END ) -- needed for attaching `_id` LEFT JOIN edgedb_VER."_SchemaLink" sl ON sl.id = sp.id -- needed for determining table name JOIN edgedbsql_VER.virtual_tables vt ON vt.id = sp.source -- positions for special pointers -- duplicate id get both id and __type__ columns out of it LEFT JOIN ( VALUES ('id', 'id', 0), ('id', '__type__', 1), ('source', 'source', 0), ('target', 'target', 1) ) spec(k, name, position) ON (spec.k = isc.column_name) WHERE isc.column_name IS NOT NULL -- normal pointers OR sp.expr IS NOT NULL AND sp.cardinality <> 'Many' -- computeds UNION ALL -- special case: multi properties source and target -- (this is needed, because schema does not create pointers for -- these two columns) SELECT vt.schema_name AS vt_table_schema, vt.table_name AS vt_table_name, isc.column_name AS v_column_name, spec.position as position, FALSE as is_computed, isc.column_default, 'NO' as is_nullable, isc.data_type as data_type FROM edgedb_VER."_SchemaPointer" sp JOIN information_schema.columns isc ON isc.table_name = sp.id::TEXT -- needed for filtering out links LEFT JOIN edgedb_VER."_SchemaLink" sl ON sl.id = sp.id -- needed for determining table name JOIN edgedbsql_VER.virtual_tables vt ON vt.id = sp.id -- positions for special pointers JOIN ( VALUES ('source', 'source', 0), ('target', 'target', 1) ) spec(k, name, position) ON (spec.k = isc.column_name) WHERE sl.id IS NULL -- property (non-link) AND sp.cardinality = 'Many' -- 
multi AND sp.expr IS NULL -- non-computed ) t ''' ), ), ] pg_catalog_views = [ trampoline.VersionedView( name=("edgedbsql", "pg_namespace_"), materialized=True, query=""" -- system schemas SELECT oid, nspname, nspowner, nspacl FROM pg_namespace WHERE nspname IN ('pg_catalog', 'pg_toast', 'information_schema', 'edgedb', 'edgedbstd', 'edgedbt', 'edgedb_VER', 'edgedbstd_VER') UNION ALL -- virtual schemas SELECT edgedbsql_VER.uuid_to_oid(t.module_id) AS oid, t.schema_name AS nspname, (SELECT oid FROM pg_roles WHERE rolname = CURRENT_USER LIMIT 1) AS nspowner, NULL AS nspacl FROM ( SELECT schema_name, module_id FROM edgedbsql_VER.virtual_tables UNION -- always include the default module, -- because it is needed for tuple types SELECT 'public' AS schema_name, id AS module_id FROM edgedb_VER."_SchemaModule" WHERE name = 'default' ) t """, ), make_wrapper_view("pg_namespace"), trampoline.VersionedView( name=("edgedbsql", "pg_type_"), materialized=True, query=""" SELECT pt.oid, edgedbsql_VER._pg_type_rename(pt.oid, pt.typname) AS typname, edgedbsql_VER._pg_namespace_rename(pt.oid, pt.typnamespace) AS typnamespace, {0} FROM pg_type pt JOIN pg_namespace pn ON pt.typnamespace = pn.oid WHERE nspname IN ('pg_catalog', 'pg_toast', 'information_schema', 'edgedb', 'edgedbstd', 'edgedb_VER', 'edgedbstd_VER', 'edgedbpub', 'edgedbt') """.format( ",".join( f"pt.{col}" for col, _, _ in sql_introspection.PG_CATALOG["pg_type"][3:] ) ), ), make_wrapper_view("pg_type"), # pg_class that contains classes only for tables # This is needed so we can use it to filter pg_index to indexes only on # visible tables. 
trampoline.VersionedView( name=("edgedbsql", "pg_class_tables"), materialized=True, query=""" -- Postgres tables SELECT pc.* FROM pg_class pc JOIN pg_namespace pn ON pc.relnamespace = pn.oid WHERE nspname IN ('pg_catalog', 'pg_toast', 'information_schema') UNION ALL -- user-defined tables SELECT oid, vt.table_name as relname, edgedbsql_VER.uuid_to_oid(vt.module_id) as relnamespace, reltype, reloftype, relowner, relam, relfilenode, reltablespace, relpages, reltuples, relallvisible, reltoastrelid, relhasindex, relisshared, relpersistence, relkind, relnatts, 0 as relchecks, -- don't care about CHECK constraints relhasrules, relhastriggers, relhassubclass, relrowsecurity, relforcerowsecurity, relispopulated, relreplident, relispartition, relrewrite, relfrozenxid, relminmxid, relacl, reloptions, relpartbound FROM pg_class pc JOIN edgedbsql_VER.virtual_tables vt ON vt.pg_type_id = pc.reltype """, ), trampoline.VersionedView( name=("edgedbsql", "pg_index_"), materialized=True, query=f""" SELECT pi.indexrelid, pi.indrelid, pi.indnatts, pi.indnkeyatts, CASE WHEN COALESCE(is_id.t, FALSE) THEN TRUE ELSE pi.indisunique END AS indisunique, {'pi.indnullsnotdistinct,' if backend_version.major >= 15 else ''} CASE WHEN COALESCE(is_id.t, FALSE) THEN TRUE ELSE pi.indisprimary END AS indisprimary, pi.indisexclusion, pi.indimmediate, pi.indisclustered, pi.indisvalid, pi.indcheckxmin, CASE WHEN COALESCE(is_id.t, FALSE) THEN TRUE ELSE FALSE -- override so pg_dump won't try to recreate them END AS indisready, pi.indislive, pi.indisreplident, CASE WHEN COALESCE(is_id.t, FALSE) THEN ARRAY[1]::int2vector -- id: 1 ELSE pi.indkey END AS indkey, pi.indcollation, pi.indclass, pi.indoption, pi.indexprs, pi.indpred FROM pg_index pi -- filter by tables visible in pg_class INNER JOIN edgedbsql_VER.pg_class_tables pr ON pi.indrelid = pr.oid -- find indexes that are on virtual tables and on `id` columns LEFT JOIN LATERAL ( SELECT TRUE AS t FROM pg_attribute pa WHERE pa.attrelid = pi.indrelid AND 
pa.attnum = ANY(pi.indkey) AND pa.attname = 'id' ) is_id ON TRUE -- for our tables show only primary key indexes LEFT JOIN edgedbsql_VER.virtual_tables vt ON vt.pg_type_id = pr.reltype WHERE vt.id IS NULL OR is_id.t IS NOT NULL """, ), make_wrapper_view('pg_index'), trampoline.VersionedView( name=("edgedbsql", "pg_class_"), materialized=True, query=""" -- tables SELECT pc.* FROM edgedbsql_VER.pg_class_tables pc UNION -- indexes SELECT pc.* FROM pg_class pc JOIN pg_index pi ON pc.oid = pi.indexrelid UNION -- compound types (tuples) SELECT pc.oid, edgedbsql_VER._long_name(pc.reltype::text, tup.name) as relname, nsdef.oid as relnamespace, pc.reltype, pc.reloftype, pc.relowner, pc.relam, pc.relfilenode, pc.reltablespace, pc.relpages, pc.reltuples, pc.relallvisible, pc.reltoastrelid, pc.relhasindex, pc.relisshared, pc.relpersistence, pc.relkind, pc.relnatts, 0 as relchecks, -- don't care about CHECK constraints pc.relhasrules, pc.relhastriggers, pc.relhassubclass, pc.relrowsecurity, pc.relforcerowsecurity, pc.relispopulated, pc.relreplident, pc.relispartition, pc.relrewrite, pc.relfrozenxid, pc.relminmxid, pc.relacl, pc.reloptions, pc.relpartbound FROM pg_class pc JOIN edgedb_VER."_SchemaTuple" tup ON tup.backend_id = pc.reltype JOIN ( SELECT edgedbsql_VER.uuid_to_oid(id) AS oid FROM edgedb_VER."_SchemaModule" WHERE name = 'default' ) nsdef ON TRUE """, ), make_wrapper_view("pg_class"), # Because we hide some columns and # because pg_dump expects attnum to be sequential numbers # we have to invent new attnums with ROW_NUMBER(). # Since attnum is used elsewhere, we need to know the mapping from # constructed attnum into underlying attnum. # To do that, we have pg_attribute_ext view with additional # attnum_internal column. 
trampoline.VersionedView( name=("edgedbsql", "pg_attribute_ext"), materialized=True, query=r""" SELECT attrelid, attname, atttypid, attstattarget, attlen, attnum, attnum as attnum_internal, attndims, attcacheoff, atttypmod, attbyval, attstorage, attalign, attnotnull, atthasdef, atthasmissing, attidentity, attgenerated, attisdropped, attislocal, attinhcount, attcollation, attacl, attoptions, attfdwoptions, null::int[] as attmissingval FROM pg_attribute pa JOIN pg_class pc ON pa.attrelid = pc.oid JOIN pg_namespace pn ON pc.relnamespace = pn.oid LEFT JOIN edgedb_VER."_SchemaTuple" tup ON tup.backend_id = pc.reltype WHERE nspname IN ('pg_catalog', 'pg_toast', 'information_schema') OR tup.backend_id IS NOT NULL UNION ALL SELECT pc_oid as attrelid, col_name as attname, COALESCE(atttypid, 25) as atttypid, -- defaults to TEXT COALESCE(attstattarget, -1) as attstattarget, COALESCE(attlen, -1) as attlen, (ROW_NUMBER() OVER ( PARTITION BY pc_oid ORDER BY col_position, col_name ) - 6)::smallint AS attnum, t.attnum as attnum_internal, COALESCE(attndims, 0) as attndims, COALESCE(attcacheoff, -1) as attcacheoff, COALESCE(atttypmod, -1) as atttypmod, COALESCE(attbyval, FALSE) as attbyval, COALESCE(attstorage, 'x') as attstorage, COALESCE(attalign, 'i') as attalign, required as attnotnull, -- Always report no default, to avoid expr trouble false as atthasdef, COALESCE(atthasmissing, FALSE) as atthasmissing, COALESCE(attidentity, '') as attidentity, COALESCE(attgenerated, '') as attgenerated, COALESCE(attisdropped, FALSE) as attisdropped, COALESCE(attislocal, TRUE) as attislocal, COALESCE(attinhcount, 0) as attinhcount, COALESCE(attcollation, 0) as attcollation, attacl, attoptions, attfdwoptions, null::int[] as attmissingval FROM ( SELECT COALESCE( spec.name, -- for special columns sp.name || case when sl.id is not null then '_id' else '' end, pa.attname -- for system columns ) as col_name, COALESCE(spec.position, 2) AS col_position, (sp.required IS TRUE OR spec.k IS NOT NULL) as 
required, pc.oid AS pc_oid, pa.* FROM edgedb_VER."_SchemaPointer" sp JOIN edgedbsql_VER.virtual_tables vt ON vt.id = sp.source JOIN pg_class pc ON pc.reltype = vt.pg_type_id -- try to find existing pg_attribute (it will not exist for computeds) LEFT JOIN pg_attribute pa ON ( pa.attrelid = pc.oid AND CASE WHEN length(pa.attname) = 36 -- if column name is uuid THEN pa.attname = sp.id::text -- compare uuids ELSE pa.attname = sp.name -- for id, source, target END ) -- positions for special pointers -- duplicate id get both id and __type__ columns out of it LEFT JOIN ( VALUES ('id', 'id', 0), ('id', '__type__', 1), ('source', 'source', 0), ('target', 'target', 1) ) spec(k, name, position) ON (spec.k = pa.attname) -- needed for attaching `_id` LEFT JOIN edgedb_VER."_SchemaLink" sl ON sl.id = sp.id WHERE pa.attname IS NOT NULL -- non-computed pointers OR sp.expr IS NOT NULL AND sp.cardinality <> 'Many' -- computeds UNION ALL -- special case: multi properties source and target -- (this is needed, because schema does not create pointers for -- these two columns) SELECT pa.attname AS col_name, spec.position as position, TRUE as required, pa.attrelid as pc_oid, pa.* FROM edgedb_VER."_SchemaProperty" sp JOIN pg_class pc ON pc.relname = sp.id::TEXT JOIN pg_attribute pa ON pa.attrelid = pc.oid -- positions for special pointers JOIN ( VALUES ('source', 0), ('target', 1) ) spec(k, position) ON (spec.k = pa.attname) WHERE sp.cardinality = 'Many' -- multi AND sp.expr IS NULL -- non-computed UNION ALL -- special case: system columns SELECT pa.attname AS col_name, pa.attnum as position, TRUE as required, pa.attrelid as pc_oid, pa.* FROM pg_attribute pa JOIN pg_class pc ON pc.oid = pa.attrelid JOIN edgedbsql_VER.virtual_tables vt ON vt.pg_type_id = pc.reltype WHERE pa.attnum < 0 ) t """, ), trampoline.VersionedView( name=("edgedbsql", "pg_attribute"), query=""" SELECT attrelid, attname, atttypid, attstattarget, attlen, attnum, attndims, attcacheoff, atttypmod, attbyval, attstorage, 
attalign, attnotnull, atthasdef, atthasmissing, attidentity, attgenerated, attisdropped, attislocal, attinhcount, attcollation, attacl, attoptions, attfdwoptions, attmissingval, 'pg_catalog.pg_attribute'::regclass::oid as tableoid, xmin, cmin, xmax, cmax, ctid FROM edgedbsql_VER.pg_attribute_ext """, ), trampoline.VersionedView( name=("edgedbsql", "pg_database"), query=f""" SELECT oid, frontend_name.n as datname, datdba, encoding, {'datlocprovider,' if backend_version.major >= 15 else ''} datcollate, datctype, datistemplate, datallowconn, {'dathasloginevt,' if backend_version.major >= 17 else ''} datconnlimit, 0::oid AS datlastsysoid, datfrozenxid, datminmxid, dattablespace, {'datlocale,' if backend_version.major >= 17 else ''} {'daticurules,' if backend_version.major >= 16 else ''} {'datcollversion,' if backend_version.major >= 15 else ''} datacl, tableoid, xmin, cmin, xmax, cmax, ctid FROM pg_database, LATERAL ( SELECT edgedb_VER.get_database_frontend_name(datname) AS n ) frontend_name, LATERAL ( SELECT edgedb_VER.get_database_metadata(frontend_name.n) AS j ) metadata WHERE metadata.j->>'tenant_id' = edgedb_VER.get_backend_tenant_id() AND NOT (metadata.j->'builtin')::bool """, ), # HACK: there were problems with pg_dump when exposing this table, so # I've added WHERE FALSE. The query could be simplified, but it may # be needed in the future. Its EXPLAIN cost is 0..0 anyway. 
trampoline.VersionedView( name=("edgedbsql", "pg_stats"), query=""" SELECT n.nspname AS schemaname, c.relname AS tablename, a.attname, s.stainherit AS inherited, s.stanullfrac AS null_frac, s.stawidth AS avg_width, s.stadistinct AS n_distinct, NULL::real[] AS most_common_vals, s.stanumbers1 AS most_common_freqs, s.stanumbers1 AS histogram_bounds, s.stanumbers1[1] AS correlation, NULL::real[] AS most_common_elems, s.stanumbers1 AS most_common_elem_freqs, s.stanumbers1 AS elem_count_histogram FROM pg_statistic s JOIN pg_class c ON c.oid = s.starelid JOIN edgedbsql_VER.pg_attribute_ext a ON ( c.oid = a.attrelid and a.attnum_internal = s.staattnum ) LEFT JOIN pg_namespace n ON n.oid = c.relnamespace WHERE FALSE """, ), trampoline.VersionedView( name=("edgedbsql", "pg_constraint"), query=r""" -- primary keys for: -- - objects tables (that contains id) -- - link tables (that contains source and target) -- there exists a unique constraint for each of these SELECT pc.oid, vt.table_name || '_pk' AS conname, pc.connamespace, 'p'::"char" AS contype, pc.condeferrable, pc.condeferred, pc.convalidated, pc.conrelid, pc.contypid, pc.conindid, pc.conparentid, NULL::oid AS confrelid, NULL::"char" AS confupdtype, NULL::"char" AS confdeltype, NULL::"char" AS confmatchtype, pc.conislocal, pc.coninhcount, pc.connoinherit, CASE WHEN pa.attname = 'id' THEN ARRAY[1]::int2[] -- id will always have attnum 1 ELSE ARRAY[1, 2]::int2[] -- source and target END AS conkey, NULL::int2[] AS confkey, NULL::oid[] AS conpfeqop, NULL::oid[] AS conppeqop, NULL::oid[] AS conffeqop, NULL::int2[] AS confdelsetcols, NULL::oid[] AS conexclop, pc.conbin, pc.tableoid, pc.xmin, pc.cmin, pc.xmax, pc.cmax, pc.ctid FROM pg_constraint pc JOIN edgedbsql_VER.pg_class_tables pct ON pct.oid = pc.conrelid JOIN edgedbsql_VER.virtual_tables vt ON vt.pg_type_id = pct.reltype JOIN pg_attribute pa ON (pa.attrelid = pct.oid AND pa.attnum = ANY(conkey) AND pa.attname IN ('id', 'source') ) WHERE contype = 'u' -- our ids and all 
links will have unique constraint UNION ALL -- foreign keys for object tables SELECT -- uuid_to_oid needs "extra" arg to disambiguate from the link table -- keys below edgedbsql_VER.uuid_to_oid(sl.id, 0) as oid, vt.table_name || '_fk_' || sl.name AS conname, edgedbsql_VER.uuid_to_oid(vt.module_id) AS connamespace, 'f'::"char" AS contype, FALSE AS condeferrable, FALSE AS condeferred, TRUE AS convalidated, pc.oid AS conrelid, 0::oid AS contypid, 0::oid AS conindid, -- let's hope this is not needed 0::oid AS conparentid, pc_target.oid AS confrelid, 'a'::"char" AS confupdtype, 'a'::"char" AS confdeltype, 's'::"char" AS confmatchtype, TRUE AS conislocal, 0::int2 AS coninhcount, TRUE AS connoinherit, ARRAY[pa.attnum]::int2[] AS conkey, ARRAY[1]::int2[] AS confkey, -- id will always have attnum 1 ARRAY['uuid_eq'::regproc]::oid[] AS conpfeqop, ARRAY['uuid_eq'::regproc]::oid[] AS conppeqop, ARRAY['uuid_eq'::regproc]::oid[] AS conffeqop, NULL::int2[] AS confdelsetcols, NULL::oid[] AS conexclop, NULL::pg_node_tree AS conbin, pa.tableoid, pa.xmin, pa.cmin, pa.xmax, pa.cmax, pa.ctid FROM edgedbsql_VER.virtual_tables vt JOIN pg_class pc ON pc.reltype = vt.pg_type_id JOIN edgedb_VER."_SchemaLink" sl ON sl.source = vt.id -- AND COALESCE(sl.cardinality = 'One', TRUE) JOIN edgedbsql_VER.virtual_tables vt_target ON sl.target = vt_target.id JOIN pg_class pc_target ON pc_target.reltype = vt_target.pg_type_id JOIN edgedbsql_VER.pg_attribute pa ON pa.attrelid = pc.oid AND pa.attname = sl.name || '_id' UNION ALL -- foreign keys for: -- - multi link tables (source & target), -- - multi property tables (source), -- - single link with link properties (source & target), -- these constraints do not actually exist, so we emulate them entirely SELECT -- uuid_to_oid needs "extra" arg to disambiguate from other -- constraints using this pointer edgedbsql_VER.uuid_to_oid(sp.id, spec.attnum) AS oid, vt.table_name || '_fk_' || spec.name AS conname, edgedbsql_VER.uuid_to_oid(vt.module_id) AS
connamespace, 'f'::"char" AS contype, FALSE AS condeferrable, FALSE AS condeferred, TRUE AS convalidated, pc.oid AS conrelid, pc.reltype AS contypid, 0::oid AS conindid, -- TODO 0::oid AS conparentid, pcf.oid AS confrelid, 'r'::"char" AS confupdtype, 'r'::"char" AS confdeltype, 's'::"char" AS confmatchtype, TRUE AS conislocal, 0::int2 AS coninhcount, TRUE AS connoinherit, ARRAY[spec.attnum]::int2[] AS conkey, ARRAY[1]::int2[] AS confkey, -- id will have attnum 1 ARRAY['uuid_eq'::regproc]::oid[] AS conpfeqop, ARRAY['uuid_eq'::regproc]::oid[] AS conppeqop, ARRAY['uuid_eq'::regproc]::oid[] AS conffeqop, NULL::int2[] AS confdelsetcols, NULL::oid[] AS conexclop, pc.relpartbound AS conbin, pc.tableoid, pc.xmin, pc.cmin, pc.xmax, pc.cmax, pc.ctid FROM edgedb_VER."_SchemaPointer" sp -- find links with link properties LEFT JOIN LATERAL ( SELECT sl.id FROM edgedb_VER."_SchemaLink" sl LEFT JOIN edgedb_VER."_SchemaProperty" AS slp ON slp.source = sl.id GROUP BY sl.id HAVING COUNT(*) > 2 ) link_props ON link_props.id = sp.id JOIN pg_class pc ON pc.relname = sp.id::TEXT JOIN edgedbsql_VER.virtual_tables vt ON vt.pg_type_id = pc.reltype -- duplicate each row for source and target JOIN LATERAL (VALUES ('source', 1::int2, sp.source), ('target', 2::int2, sp.target) ) spec(name, attnum, foreign_id) ON TRUE JOIN edgedbsql_VER.virtual_tables vtf ON vtf.id = spec.foreign_id JOIN pg_class pcf ON pcf.reltype = vtf.pg_type_id WHERE sp.cardinality = 'Many' OR link_props.id IS NOT NULL AND sp.computable IS NOT TRUE AND sp.internal IS NOT TRUE """ ), trampoline.VersionedView( name=("edgedbsql", "pg_statistic"), query=""" SELECT starelid, a.attnum as staattnum, stainherit, stanullfrac, stawidth, stadistinct, stakind1, stakind2, stakind3, stakind4, stakind5, staop1, staop2, staop3, staop4, staop5, stacoll1, stacoll2, stacoll3, stacoll4, stacoll5, stanumbers1, stanumbers2, stanumbers3, stanumbers4, stanumbers5, NULL::real[] AS stavalues1, NULL::real[] AS stavalues2, NULL::real[] AS stavalues3, 
NULL::real[] AS stavalues4, NULL::real[] AS stavalues5, s.tableoid, s.xmin, s.cmin, s.xmax, s.cmax, s.ctid FROM pg_statistic s JOIN edgedbsql_VER.pg_attribute_ext a ON ( a.attrelid = s.starelid AND a.attnum_internal = s.staattnum ) """, ), trampoline.VersionedView( name=("edgedbsql", "pg_statistic_ext"), query=""" SELECT oid, stxrelid, stxname, stxnamespace, stxowner, stxstattarget, stxkeys, stxkind, NULL::pg_node_tree as stxexprs, tableoid, xmin, cmin, xmax, cmax, ctid FROM pg_statistic_ext """, ), trampoline.VersionedView( name=("edgedbsql", "pg_statistic_ext_data"), query=""" SELECT stxoid, stxdndistinct, stxddependencies, stxdmcv, NULL::oid AS stxdexpr, tableoid, xmin, cmin, xmax, cmax, ctid FROM pg_statistic_ext_data """, ), trampoline.VersionedView( name=("edgedbsql", "pg_rewrite"), query=""" SELECT pr.*, pr.tableoid, pr.xmin, pr.cmin, pr.xmax, pr.cmax, pr.ctid FROM pg_rewrite pr JOIN edgedbsql_VER.pg_class pn ON pr.ev_class = pn.oid """, ), # HACK: Automatically generated cast function for ranges/multiranges # was causing issues for pg_dump. So at the end of the day we opt for # not exposing any casts at all here since there is no real reason for # this compatibility layer that is read-only to have elaborate casts # present. trampoline.VersionedView( name=("edgedbsql", "pg_cast"), query=""" SELECT pc.*, pc.tableoid, pc.xmin, pc.cmin, pc.xmax, pc.cmax, pc.ctid FROM pg_cast pc WHERE FALSE """, ), # Omit all functions for now. trampoline.VersionedView( name=("edgedbsql", "pg_proc"), query=""" SELECT *, tableoid, xmin, cmin, xmax, cmax, ctid FROM pg_proc WHERE FALSE """, ), # Omit all operators for now. trampoline.VersionedView( name=("edgedbsql", "pg_operator"), query=""" SELECT *, tableoid, xmin, cmin, xmax, cmax, ctid FROM pg_operator WHERE FALSE """, ), # Omit all triggers for now.
trampoline.VersionedView( name=("edgedbsql", "pg_trigger"), query=""" SELECT *, tableoid, xmin, cmin, xmax, cmax, ctid FROM pg_trigger WHERE FALSE """, ), # Omit all subscriptions for now. # This table is queried by pg_dump with COUNT(*) when user does not # have permissions to access it. This should be allowed, but the # view expands the query to all columns, which is not allowed. # So we have to construct an empty view with correct signature that # does not reference pg_subscription. trampoline.VersionedView( name=("edgedbsql", "pg_subscription"), query=""" SELECT NULL::oid AS oid, NULL::oid AS subdbid, NULL::name AS subname, NULL::oid AS subowner, NULL::boolean AS subenabled, NULL::text AS subconninfo, NULL::name AS subslotname, NULL::text AS subsynccommit, NULL::oid AS subpublications, tableoid, xmin, cmin, xmax, cmax, ctid FROM pg_namespace WHERE FALSE """, ), trampoline.VersionedView( name=("edgedbsql", "pg_tables"), query=""" SELECT n.nspname AS schemaname, c.relname AS tablename, pg_get_userbyid(c.relowner) AS tableowner, t.spcname AS tablespace, c.relhasindex AS hasindexes, c.relhasrules AS hasrules, c.relhastriggers AS hastriggers, c.relrowsecurity AS rowsecurity FROM edgedbsql_VER.pg_class c LEFT JOIN edgedbsql_VER.pg_namespace n ON n.oid = c.relnamespace LEFT JOIN pg_tablespace t ON t.oid = c.reltablespace WHERE c.relkind = ANY (ARRAY['r'::"char", 'p'::"char"]) """, ), trampoline.VersionedView( name=("edgedbsql", "pg_views"), query=""" SELECT n.nspname AS schemaname, c.relname AS viewname, pg_get_userbyid(c.relowner) AS viewowner, pg_get_viewdef(c.oid) AS definition FROM edgedbsql_VER.pg_class c LEFT JOIN edgedbsql_VER.pg_namespace n ON n.oid = c.relnamespace WHERE c.relkind = 'v'::"char" """, ), # Omit all descriptions (comments), because all non-system comments # are our internal implementation details.
trampoline.VersionedView( name=("edgedbsql", "pg_description"), query=""" SELECT *, tableoid, xmin, cmin, xmax, cmax, ctid FROM pg_description WHERE FALSE """, ), trampoline.VersionedView( name=("edgedbsql", "pg_settings"), query=""" select * from edgedbsql_VER.pg_show_all_settings() """, ), # A helper view for the `SHOW` SQL command. # See also NormalizedPgSettingsView, InterpretConfigValueToJsonFunction # as well as the Postgres C function ShowGUCOption. trampoline.VersionedView( name=("edgedbsql", "pg_settings_for_show"), query=r""" SELECT s.name AS name, CASE WHEN vartype = 'bool' THEN ( CASE WHEN lower(setting) = any(ARRAY['on', 'true', 'yes', '1']) THEN 'on' ELSE 'off' END ) WHEN vartype = 'enum' OR vartype = 'string' THEN setting WHEN vartype = 'integer' OR vartype = 'real' THEN ( CASE WHEN setting::numeric > 0 AND unit.unit IS NOT NULL THEN (setting::numeric * unit.multiplier)::text || unit.unit ELSE setting END ) ELSE edgedb_VER.raise( NULL::text, msg => ( 'unknown configuration type "' || COALESCE(vartype, '') || '"' ) ) END AS setting, short_desc AS description FROM edgedbsql_VER.pg_settings AS s, LATERAL ( SELECT regexp_match( s.unit, '^(\d*)\s*([a-zA-Z]{1,3})$') AS v ) AS _unit, LATERAL ( SELECT COALESCE( CASE WHEN _unit.v[1] = '' THEN 1 ELSE _unit.v[1]::int END, 1 ) AS multiplier, COALESCE(_unit.v[2], '') AS unit ) AS unit """ ), ] # We expose most of the views as empty tables, just to prevent errors when # the tools do introspection. # For the tables that it turns out are actually needed, we handcraft the # views that expose the actual data. # I've been cautious about exposing too much data, for example limiting # pg_type to pg_catalog and pg_toast namespaces. 
views: list[dbops.View] = [] views.extend(tables_and_columns) for table_name, columns in sql_introspection.INFORMATION_SCHEMA.items(): if table_name in ["tables", "columns"]: continue views.append( trampoline.VersionedView( name=("edgedbsql", table_name), query="SELECT {} LIMIT 0".format( ",".join( f"NULL::information_schema.{type} AS {name}" for name, type, _ver_since in columns ) ), ) ) PG_TABLES_SKIP = { 'pg_type', 'pg_attribute', 'pg_namespace', 'pg_class', 'pg_database', 'pg_proc', 'pg_operator', 'pg_pltemplate', 'pg_stats', 'pg_stats_ext_exprs', 'pg_statistic', 'pg_statistic_ext', 'pg_statistic_ext_data', 'pg_rewrite', 'pg_cast', 'pg_index', 'pg_constraint', 'pg_trigger', 'pg_subscription', 'pg_tables', 'pg_views', 'pg_description', 'pg_settings', } PG_TABLES_WITH_SYSTEM_COLS = { 'pg_aggregate', 'pg_am', 'pg_amop', 'pg_amproc', 'pg_attrdef', 'pg_attribute', 'pg_auth_members', 'pg_authid', 'pg_cast', 'pg_class', 'pg_collation', 'pg_constraint', 'pg_conversion', 'pg_database', 'pg_db_role_setting', 'pg_default_acl', 'pg_depend', 'pg_enum', 'pg_event_trigger', 'pg_extension', 'pg_foreign_data_wrapper', 'pg_foreign_server', 'pg_foreign_table', 'pg_index', 'pg_inherits', 'pg_init_privs', 'pg_language', 'pg_largeobject', 'pg_largeobject_metadata', 'pg_namespace', 'pg_opclass', 'pg_operator', 'pg_opfamily', 'pg_partitioned_table', 'pg_policy', 'pg_publication', 'pg_publication_rel', 'pg_range', 'pg_replication_origin', 'pg_rewrite', 'pg_seclabel', 'pg_sequence', 'pg_shdepend', 'pg_shdescription', 'pg_shseclabel', 'pg_statistic', 'pg_statistic_ext', 'pg_statistic_ext_data', 'pg_subscription_rel', 'pg_tablespace', 'pg_transform', 'pg_trigger', 'pg_ts_config', 'pg_ts_config_map', 'pg_ts_dict', 'pg_ts_parser', 'pg_ts_template', 'pg_type', 'pg_user_mapping', } SYSTEM_COLUMNS = ['tableoid', 'xmin', 'cmin', 'xmax', 'cmax', 'ctid'] def construct_pg_view( table_name: str, backend_version: params.BackendVersion ) -> Optional[dbops.View]: pg_columns = 
sql_introspection.PG_CATALOG[table_name] columns = [] has_columns = False for c_name, c_typ, c_ver_since in pg_columns: if c_ver_since <= backend_version.major: columns.append('o.' + c_name) has_columns = True elif c_typ: columns.append(f'NULL::{c_typ} as {c_name}') else: columns.append(f'NULL as {c_name}') if not has_columns: return None if table_name in PG_TABLES_WITH_SYSTEM_COLS: for c_name in SYSTEM_COLUMNS: columns.append('o.' + c_name) return trampoline.VersionedView( name=("edgedbsql", table_name), query=f"SELECT {','.join(columns)} FROM pg_catalog.{table_name} o", ) views.extend(pg_catalog_views) for table_name in sql_introspection.PG_CATALOG.keys(): if table_name in PG_TABLES_SKIP: continue if v := construct_pg_view(table_name, backend_version): views.append(v) util_functions = [ # A 1:1 PL/pgSQL replication of the Postgres SplitIdentifierString trampoline.VersionedFunction( name=('edgedbsql', 'split_identifier_string'), args=( ('rawstring', 'text',), ('separator', 'text',), ), returns=('text[]',), language="plpgsql", volatility="immutable", text=r""" DECLARE namelist text[] := '{}'; pos integer := 1; len integer; c char; sep_char char; curname text; in_quote boolean; db_encoding text := getdatabaseencoding(); BEGIN -- Initialization IF length(separator) != 1 THEN RAISE EXCEPTION 'Separator must be a single character'; END IF; sep_char := substring(separator FROM 1 FOR 1); len := length(rawstring); -- Skip leading whitespace WHILE pos <= len LOOP c := substring(rawstring FROM pos FOR 1); IF c IN (' ', '\t', '\n', '\r', '\f') THEN pos := pos + 1; ELSE EXIT; END IF; END LOOP; -- Allow empty string IF pos > len THEN RETURN namelist; END IF; -- At the top of the loop, we are at start of a new identifier. 
LOOP IF substring(rawstring FROM pos FOR 1) = '"' THEN -- Quoted name --- collapse quote-quote pairs, no downcasing pos := pos + 1; curname := ''; in_quote := TRUE; WHILE pos <= len LOOP c := substring(rawstring FROM pos FOR 1); IF c = '"' THEN IF pos < len AND substring(rawstring FROM pos + 1 FOR 1) = '"' THEN -- Collapse adjacent quotes into one quote, -- and look again curname := curname || '"'; pos := pos + 2; ELSE -- Found end of quoted name pos := pos + 1; in_quote := FALSE; EXIT; END IF; ELSE curname := curname || c; pos := pos + 1; END IF; END LOOP; -- Mismatched quotes IF in_quote THEN RAISE EXCEPTION 'Unterminated quoted identifier'; END IF; ELSE -- Unquoted name --- extends to separator or whitespace curname := ''; WHILE pos <= len LOOP c := substring(rawstring FROM pos FOR 1); IF c = sep_char OR c IN (' ', '\t', '\n', '\r', '\f') THEN EXIT; END IF; curname := curname || c; pos := pos + 1; END LOOP; IF curname = '' THEN RAISE EXCEPTION 'Empty unquoted identifier'; END IF; -- Downcase the identifier curname := lower(curname); END IF; -- Truncate name if it's overlength IF octet_length(curname) > 63 THEN RAISE NOTICE 'identifier "%" will be truncated', curname; curname := convert_from( substring(convert_to(curname, db_encoding) FROM 1 FOR 63), db_encoding ); END IF; -- Finished isolating current name --- add it to list namelist := array_append(namelist, curname); -- Skip trailing whitespace WHILE pos <= len LOOP c := substring(rawstring FROM pos FOR 1); IF c IN (' ', '\t', '\n', '\r', '\f') THEN pos := pos + 1; ELSE EXIT; END IF; END LOOP; IF pos > len THEN EXIT; -- End of string ELSIF substring(rawstring FROM pos FOR 1) = sep_char THEN pos := pos + 1; -- Skip leading whitespace for next WHILE pos <= len LOOP c := substring(rawstring FROM pos FOR 1); IF c IN (' ', '\t', '\n', '\r', '\f') THEN pos := pos + 1; ELSE EXIT; END IF; END LOOP; ELSE RAISE EXCEPTION 'Invalid character at position %: "%"', pos, c; END IF; END LOOP; RETURN namelist; END; """, ), # 
current_schemas() is an emulation of Postgres current_schemas(), # fetch_search_path() and recomputeNamespacePath() internal functions. trampoline.VersionedFunction( name=('edgedbsql', 'current_schemas'), args=( ('include_implicit', 'bool',), ), returns=('name[]',), language="plpgsql", volatility="stable", text=""" DECLARE search_path_str text; schema_list text[]; rv name[] := '{}'::name[]; is_valid_namespace boolean; has_pg_catalog boolean; resolved_schema text; BEGIN -- Get the current search_path from the emulated pg_settings view SELECT COALESCE(setting, 'public') INTO search_path_str FROM edgedbsql_VER.pg_settings WHERE name = 'search_path'; -- Split using our custom function with comma separator schema_list := edgedbsql_VER.split_identifier_string(search_path_str, ','); -- Handle implicit schemas if requested IF include_implicit THEN -- Temporary schema is not supported yet -- Check if pg_catalog is already present has_pg_catalog := 'pg_catalog' = ANY(schema_list); -- Add pg_catalog if not present and GUC variables allow it IF NOT has_pg_catalog THEN rv := array_append(rv, 'pg_catalog'::name); END IF; END IF; -- Process each schema element FOR i IN 1..array_length(schema_list, 1) LOOP -- Handle $user substitution IF schema_list[i] = '$user' THEN resolved_schema := current_user; ELSE resolved_schema := schema_list[i]; END IF; -- Check existence in namespace catalog IF EXISTS ( SELECT 1 FROM edgedbsql_VER.pg_namespace WHERE nspname = resolved_schema ) THEN rv := array_append(rv, resolved_schema::name); END IF; END LOOP; RETURN rv; END; """, ), trampoline.VersionedFunction( name=('edgedbsql', 'to_regclass'), args=( ('name_or_oid', 'text',), ), returns=('regclass',), language="plpgsql", text=""" DECLARE parts text[]; result regclass; BEGIN IF name_or_oid = '-' THEN RETURN 0::regclass; END IF; IF name_or_oid ~ '^[0-9]+$' THEN RETURN name_or_oid::oid::regclass; END IF; parts := parse_ident(name_or_oid); IF array_length(parts, 1) = 1 THEN SELECT pc.oid::regclass 
INTO result FROM unnest(edgedbsql_VER.current_schemas(true)) WITH ORDINALITY AS ns(nspname, ord) JOIN edgedbsql_VER.pg_namespace pn USING (nspname) JOIN edgedbsql_VER.pg_class pc ON pn.oid = pc.relnamespace WHERE pc.relname = parts[1] ORDER BY ns.ord LIMIT 1; ELSEIF array_length(parts, 1) = 2 THEN SELECT pc.oid::regclass INTO result FROM edgedbsql_VER.pg_class pc JOIN edgedbsql_VER.pg_namespace pn ON pn.oid = pc.relnamespace WHERE relname = parts[2] AND nspname = parts[1]; ELSE RAISE EXCEPTION 'improper relation name (too many dotted names): %', name_or_oid; END IF; RETURN result; END; """ ), # Unlike pg_catalog.to_regclass(), edgedbsql.to_regclass() also takes # numeric parameters to support compiled `::regclass` typecasting. trampoline.VersionedFunction( name=('edgedbsql', 'to_regclass'), args=( ('oid', 'integer',), ), returns=('regclass',), volatility="stable", text=""" SELECT oid::regclass """ ), trampoline.VersionedFunction( name=('edgedbsql', 'to_regclass'), args=( ('oid', 'smallint',), ), returns=('regclass',), volatility="stable", text=""" SELECT oid::regclass """ ), trampoline.VersionedFunction( name=('edgedbsql', 'to_regclass'), args=( ('oid', 'bigint',), ), returns=('regclass',), volatility="stable", text=""" SELECT oid::regclass """ ), trampoline.VersionedFunction( name=('edgedbsql', 'to_regclass'), args=( ('oid', 'oid',), ), returns=('regclass',), volatility="stable", text=""" SELECT oid::regclass """ ), trampoline.VersionedFunction( name=('edgedbsql', 'has_database_privilege'), args=( ('database_name', 'text'), ('privilege', 'text'), ), returns=('bool',), text=""" SELECT has_database_privilege(oid, privilege) FROM edgedbsql_VER.pg_database WHERE datname = database_name """ ), trampoline.VersionedFunction( name=('edgedbsql', 'has_database_privilege'), args=( ('database_oid', 'oid'), ('privilege', 'text'), ), returns=('bool',), text=""" SELECT has_database_privilege(database_oid, privilege) """ ), trampoline.VersionedFunction( name=('edgedbsql', 
'has_schema_privilege'), args=( ('schema_name', 'text'), ('privilege', 'text'), ), returns=('bool',), text=""" SELECT COALESCE(( SELECT has_schema_privilege(oid, privilege) FROM edgedbsql_VER.pg_namespace WHERE nspname = schema_name ), TRUE); """ ), trampoline.VersionedFunction( name=('edgedbsql', 'has_schema_privilege'), args=( ('schema_oid', 'oid'), ('privilege', 'text'), ), returns=('bool',), text=""" SELECT COALESCE( has_schema_privilege(schema_oid, privilege), TRUE ) """ ), trampoline.VersionedFunction( name=('edgedbsql', 'has_table_privilege'), args=( ('table_name', 'text'), ('privilege', 'text'), ), returns=('bool',), text=""" SELECT has_table_privilege( edgedbsql_VER.to_regclass(table_name), privilege) """ ), trampoline.VersionedFunction( name=('edgedbsql', 'has_table_privilege'), args=( ('table_oid', 'oid'), ('privilege', 'text'), ), returns=('bool',), text=""" SELECT has_table_privilege(table_oid, privilege) """ ), # pg_catalog.has_column_privilege will return NULL for computed and # static columns. So we COALESCE to TRUE. 
trampoline.VersionedFunction( name=('edgedbsql', 'has_column_privilege'), args=( ('tbl', 'oid'), ('col', 'smallint'), ('privilege', 'text'), ), returns=('bool',), text=""" SELECT COALESCE(( SELECT has_column_privilege(tbl, col, privilege) ), TRUE) """ ), trampoline.VersionedFunction( name=('edgedbsql', 'has_column_privilege'), args=( ('tbl', 'text'), ('col', 'smallint'), ('privilege', 'text'), ), returns=('bool',), text=""" SELECT COALESCE(( SELECT has_column_privilege( edgedbsql_VER.to_regclass(tbl), col, privilege) ), TRUE) """ ), trampoline.VersionedFunction( name=('edgedbsql', 'has_column_privilege'), args=( ('tbl', 'oid'), ('col', 'text'), ('privilege', 'text'), ), returns=('bool',), text=""" SELECT COALESCE(( SELECT has_column_privilege(tbl, attnum_internal, privilege) FROM edgedbsql_VER.pg_attribute_ext pa WHERE attrelid = tbl AND attname = col ), TRUE) """ ), trampoline.VersionedFunction( name=('edgedbsql', 'has_column_privilege'), args=( ('tbl', 'text'), ('col', 'text'), ('privilege', 'text'), ), returns=('bool',), text=""" SELECT COALESCE(( SELECT has_column_privilege(pc.oid, attnum_internal, privilege) FROM edgedbsql_VER.pg_attribute_ext pa, LATERAL (SELECT edgedbsql_VER.to_regclass(tbl) AS oid) pc WHERE pa.attrelid = pc.oid AND pa.attname = col ), TRUE) """ ), trampoline.VersionedFunction( name=('edgedbsql', 'has_any_column_privilege'), args=( ('tbl', 'oid'), ('privilege', 'text'), ), returns=('bool',), text=""" SELECT has_any_column_privilege(tbl, privilege) """ ), trampoline.VersionedFunction( name=('edgedbsql', 'has_any_column_privilege'), args=( ('tbl', 'text'), ('privilege', 'text'), ), returns=('bool',), text=""" SELECT has_any_column_privilege( edgedbsql_VER.to_regclass(tbl), privilege) """ ), trampoline.VersionedFunction( name=('edgedbsql', '_pg_truetypid'), args=( ('att', ('edgedbsql_VER', 'pg_attribute')), ('typ', ('edgedbsql_VER', 'pg_type')), ), returns=('oid',), volatility='IMMUTABLE', strict=True, text=""" SELECT CASE WHEN typ.typtype = 
'd' THEN typ.typbasetype ELSE att.atttypid END """ ), trampoline.VersionedFunction( name=('edgedbsql', '_pg_truetypmod'), args=( ('att', ('edgedbsql_VER', 'pg_attribute')), ('typ', ('edgedbsql_VER', 'pg_type')), ), returns=('int4',), volatility='IMMUTABLE', strict=True, text=""" SELECT CASE WHEN typ.typtype = 'd' THEN typ.typtypmod ELSE att.atttypmod END """ ), trampoline.VersionedFunction( name=('edgedbsql', 'pg_table_is_visible'), args=[ ('id', ('oid',)), ('search_path', ('text[]',)), ], returns=('bool',), volatility='stable', text=r''' SELECT pc.relnamespace IN ( SELECT oid FROM edgedbsql_VER.pg_namespace pn WHERE pn.nspname IN (select * from unnest(search_path)) ) FROM edgedbsql_VER.pg_class pc WHERE id = pc.oid ''' ), trampoline.VersionedFunction( # Used instead of pg_catalog.format_type in pg_dump. name=('edgedbsql', '_format_type'), args=[ ('typeoid', ('oid',)), ('typemod', ('integer',)), ], returns=('text',), volatility='STABLE', text=r''' SELECT CASE -- arrays WHEN t.typcategory = 'A' THEN ( SELECT quote_ident(nspname) || '.' || quote_ident(el.typname) || tm.mod || '[]' FROM edgedbsql_VER.pg_namespace WHERE oid = el.typnamespace ) -- composite (tuples) and types in irregular schemas WHEN ( t.typcategory = 'C' OR COALESCE(tn.nspname IN ( 'edgedb', 'edgedbt', 'edgedbpub', 'edgedbstd', 'edgedb_VER', 'edgedbstd_VER' ), TRUE) ) THEN ( SELECT quote_ident(nspname) || '.' 
|| quote_ident(t.typname) || tm.mod FROM edgedbsql_VER.pg_namespace WHERE oid = t.typnamespace ) ELSE format_type(typeoid, typemod) END FROM edgedbsql_VER.pg_type t LEFT JOIN pg_namespace tn ON t.typnamespace = tn.oid LEFT JOIN edgedbsql_VER.pg_type el ON t.typelem = el.oid CROSS JOIN ( SELECT CASE WHEN typemod >= 0 THEN '(' || typemod::text || ')' ELSE '' END AS mod ) as tm WHERE t.oid = typeoid ''', ), trampoline.VersionedFunction( name=("edgedbsql", "pg_get_constraintdef"), args=[ ('conid', ('oid',)), ], returns=('text',), volatility='stable', text=r""" -- Wrap in a subquery SELECT so that we get a clear failure -- if something is broken and this returns multiple rows. -- (By default it would silently return the first.) SELECT ( SELECT CASE WHEN contype = 'p' THEN 'PRIMARY KEY(' || ( SELECT string_agg('"' || attname || '"', ', ') FROM edgedbsql_VER.pg_attribute WHERE attrelid = conrelid AND attnum = ANY(conkey) ) || ')' WHEN contype = 'f' THEN 'FOREIGN KEY ("' || ( SELECT attname FROM edgedbsql_VER.pg_attribute WHERE attrelid = conrelid AND attnum = ANY(conkey) ) || '")' || ' REFERENCES "' || pn.nspname || '"."' || pc.relname || '"(id)' ELSE '' END FROM edgedbsql_VER.pg_constraint con LEFT JOIN edgedbsql_VER.pg_class_tables pc ON pc.oid = confrelid LEFT JOIN edgedbsql_VER.pg_namespace pn ON pc.relnamespace = pn.oid WHERE con.oid = conid ) """ ), trampoline.VersionedFunction( name=("edgedbsql", "pg_get_constraintdef"), args=[ ('conid', ('oid',)), ('pretty', ('bool',)), ], returns=('text',), volatility='stable', text=r""" SELECT pg_get_constraintdef(conid) """ ), ] return ( [cast(dbops.Command, dbops.CreateFunction(uuid_to_oid))] + [dbops.CreateView(virtual_tables)] + [ cast(dbops.Command, dbops.CreateFunction(long_name)), cast(dbops.Command, dbops.CreateFunction(type_rename)), cast(dbops.Command, dbops.CreateFunction(namespace_rename)), cast(dbops.Command, dbops.CreateFunction(fe_pg_settings)), ] + [dbops.CreateView(view) for view in views] + 
[dbops.CreateFunction(func) for func in util_functions] ) @functools.cache def generate_sql_information_schema_refresh( backend_version: params.BackendVersion ) -> dbops.Command: refresh = dbops.CommandGroup() for command in _generate_sql_information_schema(backend_version): if ( isinstance(command, dbops.CreateView) and command.view.materialized ): refresh.add_command(dbops.Query( text=f'REFRESH MATERIALIZED VIEW {q(*command.view.name)}' )) return refresh class ObjectAncestorsView(trampoline.VersionedView): """A trampolined and explicit version of _SchemaObjectType__ancestors""" query = r''' SELECT source, target, index FROM edgedb_VER."_SchemaObjectType__ancestors" ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_object_ancestors'), query=self.query, ) class LinksView(trampoline.VersionedView): """A trampolined and explicit version of _SchemaLink""" query = r''' SELECT id, name, source, target FROM edgedb_VER."_SchemaLink" ''' def __init__(self) -> None: super().__init__( name=('edgedb', '_schema_links'), query=self.query, ) def get_config_type_views( schema: s_schema.Schema, conf: s_objtypes.ObjectType, scope: Optional[qltypes.ConfigScope], existing_view_columns: Optional[dict[str, list[str]]]=None, ) -> dbops.CommandGroup: commands = dbops.CommandGroup() cfg_views, _ = _generate_config_type_view( schema, conf, scope=scope, path=[], rptr=None, existing_view_columns=existing_view_columns, ) commands.add_commands([ dbops.CreateView( (trampoline.VersionedView if tn[0] == 'edgedbstd' else dbops.View)( name=tn, query=trampoline.fixup_query(q) ), or_replace=True, ) for tn, q in cfg_views ]) return commands def generate_drop_views( group: Sequence[dbops.Command | trampoline.Trampoline], preblock: dbops.PLBlock, ) -> None: for cv in reversed(list(group)): dv: Any if isinstance(cv, dbops.CreateView): # We try deleting both a MATERIALIZED and not materialized # version, since that allows us to switch between them # more easily. 
dv = dbops.CommandGroup() dv.add_command(dbops.DropView( cv.view.name, conditions=[dbops.ViewExists(cv.view.name)], )) dv.add_command(dbops.DropView( cv.view.name, conditions=[dbops.ViewExists(cv.view.name, materialized=True)], materialized=True, )) elif isinstance(cv, dbops.CreateFunction): dv = dbops.DropFunction( cv.function.name, args=cv.function.args or (), has_variadic=bool(cv.function.has_variadic), if_exists=True, ) elif isinstance(cv, trampoline.Trampoline): dv = cv.drop() else: raise AssertionError(f'unsupported support view command {cv}') dv.generate(preblock) def get_config_views( schema: s_schema.Schema, existing_view_columns: Optional[dict[str, list[str]]]=None, ) -> dbops.CommandGroup: commands = dbops.CommandGroup() conf = schema.get('cfg::Config', type=s_objtypes.ObjectType) commands.add_command( get_config_type_views( schema, conf, scope=None, existing_view_columns=existing_view_columns, ), ) conf = schema.get('cfg::InstanceConfig', type=s_objtypes.ObjectType) commands.add_command( get_config_type_views( schema, conf, scope=qltypes.ConfigScope.INSTANCE, existing_view_columns=existing_view_columns, ), ) conf = schema.get('cfg::DatabaseConfig', type=s_objtypes.ObjectType) commands.add_command( get_config_type_views( schema, conf, scope=qltypes.ConfigScope.DATABASE, existing_view_columns=existing_view_columns, ), ) return commands def get_synthetic_type_views( schema: s_schema.Schema, backend_params: params.BackendRuntimeParams, ) -> dbops.CommandGroup: commands = dbops.CommandGroup() commands.add_command(get_config_views(schema)) for dbview in _generate_branch_views(schema): commands.add_command(dbops.CreateView(dbview, or_replace=True)) for extview in _generate_extension_views(schema): commands.add_command(dbops.CreateView(extview, or_replace=True)) for extview in _generate_extension_migration_views(schema): commands.add_command(dbops.CreateView(extview, or_replace=True)) if backend_params.has_create_role: role_views = _generate_role_views(schema) 
else: role_views = _generate_single_role_views(schema) for roleview in role_views: commands.add_command(dbops.CreateView(roleview, or_replace=True)) for verview in _generate_schema_ver_views(schema): commands.add_command(dbops.CreateView(verview, or_replace=True)) if backend_params.has_stat_statements: for stats_view in _generate_stats_views(schema): commands.add_command(dbops.CreateView(stats_view, or_replace=True)) commands.add_command( dbops.CreateFunction( ResetQueryStatsFunction(True), or_replace=True ) ) return commands def _get_wrapper_views() -> dbops.CommandGroup: # Create some trampolined wrapper views around _Schema types we need # to reference from functions. wrapper_commands = dbops.CommandGroup() wrapper_commands.add_command( dbops.CreateView(ObjectAncestorsView(), or_replace=True)) wrapper_commands.add_command( dbops.CreateView(LinksView(), or_replace=True)) return wrapper_commands def get_support_views( schema: s_schema.Schema, backend_params: params.BackendRuntimeParams, ) -> tuple[dbops.CommandGroup, list[trampoline.Trampoline]]: commands = dbops.CommandGroup() schema_alias_views = _generate_schema_alias_views( schema, s_name.UnqualName('schema')) InhObject = schema.get( 'schema::InheritingObject', type=s_objtypes.ObjectType) InhObject__ancestors = InhObject.getptr( schema, s_name.UnqualName('ancestors'), type=s_links.Link) schema_alias_views.append( _generate_schema_alias_view(schema, InhObject__ancestors)) ObjectType = schema.get( 'schema::ObjectType', type=s_objtypes.ObjectType) ObjectType__ancestors = ObjectType.getptr( schema, s_name.UnqualName('ancestors'), type=s_links.Link) schema_alias_views.append( _generate_schema_alias_view(schema, ObjectType__ancestors)) for alias_view in schema_alias_views: commands.add_command(dbops.CreateView(alias_view, or_replace=True)) synthetic_types = get_synthetic_type_views(schema, backend_params) commands.add_command(synthetic_types) wrapper_commands = _get_wrapper_views() 
commands.add_command(wrapper_commands) sys_alias_views = _generate_schema_alias_views( schema, s_name.UnqualName('sys')) # Include sys::Role::member_of and sys::Role::permissions # to support DescribeRolesAsDDLFunction SysRole = schema.get( 'sys::Role', type=s_objtypes.ObjectType) SysRole__member_of = SysRole.getptr( schema, s_name.UnqualName('member_of')) SysRole__permissions = SysRole.getptr( schema, s_name.UnqualName('permissions')) SysRole__branches = SysRole.getptr( schema, s_name.UnqualName('branches')) sys_alias_views.append( _generate_schema_alias_view(schema, SysRole__member_of) ) sys_alias_views.append( _generate_schema_alias_view(schema, SysRole__permissions) ) sys_alias_views.append( _generate_schema_alias_view(schema, SysRole__branches) ) for alias_view in sys_alias_views: commands.add_command(dbops.CreateView(alias_view, or_replace=True)) commands.add_commands( _generate_sql_information_schema( backend_params.instance_params.version ) ) # The synthetic type views (cfg::, sys::) need to be trampolined trampolines = [] trampolines.extend(trampoline_command(synthetic_types)) trampolines.extend(trampoline_command(wrapper_commands)) return commands, trampolines async def generate_support_views( conn: PGConnection, schema: s_schema.Schema, backend_params: params.BackendRuntimeParams, ) -> list[trampoline.Trampoline]: commands, trampolines = get_support_views(schema, backend_params) block = dbops.PLTopBlock() commands.generate(block) await _execute_block(conn, block) return trampolines async def generate_support_functions( conn: PGConnection, schema: s_schema.Schema, ) -> list[trampoline.Trampoline]: commands = dbops.CommandGroup() cmds = [ dbops.CreateFunction(GetPgTypeForEdgeDBTypeFunction2(), or_replace=True), dbops.CreateFunction(IssubclassFunction()), dbops.CreateFunction(IssubclassFunction2()), dbops.CreateFunction(GetSchemaObjectNameFunction()), dbops.CreateFunction(ApproximateCount(), or_replace=True), ] commands.add_commands(cmds) block = 
dbops.PLTopBlock() commands.generate(block) await _execute_block(conn, block) return trampoline_functions(cmds) def _get_regenerated_config_support_functions( config_spec: edbconfig.Spec, ) -> dbops.CommandGroup: # Regenerate functions dependent on config spec. commands = dbops.CommandGroup() funcs = [ ApplySessionConfigFunction(config_spec), PostgresJsonConfigValueToFrontendConfigValueFunction(config_spec), ] cmds = [dbops.CreateFunction(func, or_replace=True) for func in funcs] commands.add_commands(cmds) return commands async def regenerate_config_support_functions( conn: PGConnection, config_spec: edbconfig.Spec, ) -> None: # Regenerate functions dependent on config spec. commands = _get_regenerated_config_support_functions(config_spec) block = dbops.PLTopBlock() commands.generate(block) await _execute_block(conn, block) async def generate_more_support_functions( conn: PGConnection, compiler: edbcompiler.Compiler, schema: s_schema.Schema, testmode: bool, ) -> list[trampoline.Trampoline]: commands = dbops.CommandGroup() cmds = [ dbops.CreateFunction( DescribeRolesAsDDLFunction(schema), or_replace=True ), dbops.CreateFunction( AllRoleMembershipsFunction(schema), or_replace=True ), dbops.CreateFunction(GetSequenceBackendNameFunction()), dbops.CreateFunction(DumpSequencesFunction()), ] commands.add_commands(cmds) block = dbops.PLTopBlock() commands.generate(block) await _execute_block(conn, block) return trampoline_functions(cmds) def _build_key_source( schema: s_schema.Schema, exc_props: Iterable[s_pointers.Pointer], rptr: Optional[s_pointers.Pointer], source_idx: str, ) -> str: if exc_props: restargets = [] for prop in exc_props: pname = prop.get_shortname(schema).name restarget = f'(q{source_idx}.val)->>{ql(pname)}' restargets.append(restarget) targetlist = ','.join(restargets) keysource = f''' (SELECT ARRAY[{targetlist}] AS key ) AS k{source_idx}''' else: assert rptr is not None rptr_name = rptr.get_shortname(schema).name keysource = f''' (SELECT ARRAY[ (CASE 
WHEN q{source_idx}.val = 'null'::jsonb THEN NULL ELSE {ql(rptr_name)} END) ] AS key ) AS k{source_idx}''' return keysource def _build_key_expr( key_components: list[str], versioned: bool, ) -> str: prefix = 'edgedb_VER' if versioned else 'edgedb' key_expr = ' || '.join(key_components) final_keysource = f''' (SELECT (CASE WHEN array_position(q.v, NULL) IS NULL THEN {prefix}.uuid_generate_v5( '{DATABASE_ID_NAMESPACE}'::uuid, array_to_string(q.v, ';') ) ELSE NULL END) AS key FROM (SELECT {key_expr} AS v) AS q )''' return final_keysource def _build_data_source( schema: s_schema.Schema, rptr: s_pointers.Pointer, source_idx: int, *, always_array: bool = False, alias: Optional[str] = None, ) -> str: rptr_name = rptr.get_shortname(schema).name rptr_card = rptr.get_cardinality(schema) rptr_multi = rptr_card.is_multi() if alias is None: alias = f'q{source_idx + 1}' else: alias = f'q{alias}' if rptr_multi: sourceN = f''' (SELECT jel.val FROM jsonb_array_elements( (q{source_idx}.val)->{ql(rptr_name)}) AS jel(val) ) AS {alias}''' else: proj = '[0]' if always_array else '' sourceN = f''' (SELECT (q{source_idx}.val){proj}->{ql(rptr_name)} AS val ) AS {alias}''' return sourceN def _escape_like(s: str) -> str: return s.replace('\\', '\\\\').replace('%', '\\%').replace('_', '\\_') def _generate_config_type_view( schema: s_schema.Schema, stype: s_objtypes.ObjectType, *, scope: Optional[qltypes.ConfigScope], path: list[tuple[s_pointers.Pointer, list[s_pointers.Pointer]]], rptr: Optional[s_pointers.Pointer], existing_view_columns: Optional[dict[str, list[str]]], override_exclusive_props: Optional[list[s_pointers.Pointer]] = None, _memo: Optional[set[s_obj.Object]] = None, ) -> tuple[ list[tuple[tuple[str, str], str]], list[s_pointers.Pointer], ]: X = xdedent.escape exc = schema.get('std::exclusive', type=s_constr.Constraint) if scope is not None: if scope is qltypes.ConfigScope.INSTANCE: max_source = "'system override'" elif scope is qltypes.ConfigScope.DATABASE: max_source = 
"'database'" else: raise AssertionError(f'unexpected config scope: {scope!r}') else: max_source = 'NULL' if _memo is None: _memo = set() _memo.add(stype) views = [] sources = [] ext_cfg = schema.get('cfg::ExtensionConfig', type=s_objtypes.ObjectType) is_ext_cfg = stype.issubclass(schema, ext_cfg) if is_ext_cfg: rptr = None is_rptr_ext_cfg = False # For extension configs, we want to use the trampolined version, # since we know it must exist already and don't want to have to # recreate the views on update. versioned = not is_ext_cfg or stype == ext_cfg prefix = 'edgedb_VER' if versioned else 'edgedb' if not path: if is_ext_cfg: # Extension configs get one object per scope. cfg_name = str(stype.get_name(schema)) escaped_name = _escape_like(cfg_name) source0 = f''' (SELECT (SELECT jsonb_object_agg( substr(name, {len(cfg_name) + 3}), value) AS val FROM {prefix}._read_sys_config( NULL, scope::edgedb._sys_config_source_t) cfg WHERE name LIKE {ql(escaped_name + '%')} ) AS val, scope::text AS scope, scope_id AS scope_id FROM (VALUES (NULL, '{CONFIG_ID[None]}'::uuid), ('database', '{CONFIG_ID[qltypes.ConfigScope.DATABASE]}'::uuid) ) AS s(scope, scope_id) ) AS q0 ''' elif rptr is None: # This is the root config object. 
source0 = f''' (SELECT jsonb_object_agg(name, value) AS val FROM {prefix}._read_sys_config(NULL, {max_source}) cfg) AS q0''' else: rptr_name = rptr.get_shortname(schema).name rptr_source = not_none(rptr.get_source(schema)) is_rptr_ext_cfg = rptr_source.issubclass(schema, ext_cfg) if is_rptr_ext_cfg: versioned = False prefix = 'edgedb' cfg_name = str(rptr_source.get_name(schema)) + '::' + rptr_name escaped_name = _escape_like(cfg_name) source0 = f''' (SELECT el.val AS val, s.scope::text AS scope, s.scope_id AS scope_id FROM (VALUES (NULL, '{CONFIG_ID[None]}'::uuid), ('database', '{CONFIG_ID[qltypes.ConfigScope.DATABASE]}'::uuid) ) AS s(scope, scope_id), LATERAL ( SELECT (value::jsonb) AS val FROM {prefix}._read_sys_config( NULL, scope::edgedb._sys_config_source_t) cfg WHERE name LIKE {ql(escaped_name + '%')} ) AS cfg, LATERAL jsonb_array_elements(cfg.val) AS el(val) ) AS q0 ''' else: source0 = f''' (SELECT el.val FROM (SELECT (value::jsonb) AS val FROM {prefix}._read_sys_config(NULL, {max_source}) WHERE name = {ql(rptr_name)}) AS cfg, LATERAL jsonb_array_elements(cfg.val) AS el(val) ) AS q0''' sources.append(source0) key_start = 0 else: # XXX: The second level is broken for extension configs. # Can we solve this without code duplication? 
root = path[0][0] root_source = not_none(root.get_source(schema)) is_root_ext_cfg = root_source.issubclass(schema, ext_cfg) assert not is_root_ext_cfg, ( "nested conf objects not yet supported for ext configs") key_start = 0 for i, (l, exc_props) in enumerate(path): l_card = l.get_cardinality(schema) l_multi = l_card.is_multi() l_name = l.get_shortname(schema).name if i == 0: if l_multi: sourceN = f''' (SELECT el.val FROM (SELECT (value::jsonb) AS val FROM {prefix}._read_sys_config(NULL, {max_source}) WHERE name = {ql(l_name)}) AS cfg, LATERAL jsonb_array_elements(cfg.val) AS el(val) ) AS q{i}''' else: sourceN = f''' (SELECT (value::jsonb) AS val FROM {prefix}._read_sys_config(NULL, {max_source}) cfg WHERE name = {ql(l_name)}) AS q{i}''' else: sourceN = _build_data_source(schema, l, i - 1) sources.append(sourceN) sources.append(_build_key_source(schema, exc_props, l, str(i))) if exc_props: key_start = i exclusive_props = [] single_links = [] multi_links = [] multi_props = [] target_cols: dict[s_pointers.Pointer, str] = {} where = '' path_steps = [p.get_shortname(schema).name for p, _ in path] if rptr is not None: self_idx = len(path) # Generate a source rvar for _this_ target rptr_name = rptr.get_shortname(schema).name path_steps.append(rptr_name) if self_idx > 0: sourceN = _build_data_source(schema, rptr, self_idx - 1) sources.append(sourceN) else: self_idx = 0 sval = f'(q{self_idx}.val)' for pp_name, pp in stype.get_pointers(schema).items(schema): pn = str(pp_name) if pn in ('id', '__type__'): continue pp_type = pp.get_target(schema) assert pp_type is not None pp_card = pp.get_cardinality(schema) pp_multi = pp_card.is_multi() pp_psi = types.get_pointer_storage_info(pp, schema=schema) pp_col = pp_psi.column_name if isinstance(pp, s_links.Link): if pp_multi: multi_links.append(pp) else: single_links.append(pp) else: pp_cast = _make_json_caster( schema, pp_type, versioned=versioned ) if pp_multi: multi_props.append((pp, pp_cast)) else: extract_col = ( 
f'{pp_cast(f"{sval}->{ql(pn)}")} AS {qi(pp_col)}') target_cols[pp] = extract_col constraints = pp.get_constraints(schema).objects(schema) if any(c.issubclass(schema, exc) for c in constraints): exclusive_props.append(pp) if override_exclusive_props: exclusive_props = [ stype.getptr( schema, s_name.UnqualName(p.get_shortname(schema).name) ) for p in override_exclusive_props ] exclusive_props.sort(key=lambda p: p.get_shortname(schema).name) if is_ext_cfg: # Extension configs get custom keys based on their type name # and the scope, since we create one object per scope. key_components = [ f'ARRAY[{ql(str(stype.get_name(schema)))}]', "ARRAY[coalesce(q0.scope, 'session')]" ] final_keysource = f'{_build_key_expr(key_components, versioned)} AS k' sources.append(final_keysource) key_expr = 'k.key' where = f"q0.val IS NOT NULL" elif exclusive_props or rptr: sources.append( _build_key_source(schema, exclusive_props, rptr, str(self_idx))) key_components = [f'k{i}.key' for i in range(key_start, self_idx + 1)] if is_rptr_ext_cfg: assert rptr_source key_components = [ f'ARRAY[{ql(str(rptr_source.get_name(schema)))}]', "ARRAY[coalesce(q0.scope, 'session')]" ] + key_components final_keysource = f'{_build_key_expr(key_components, versioned)} AS k' sources.append(final_keysource) key_expr = 'k.key' tname = str(stype.get_name(schema)) where = f"{key_expr} IS NOT NULL AND ({sval}->>'_tname') = {ql(tname)}" else: key_expr = f"'{CONFIG_ID[scope]}'::uuid" key_components = [] id_ptr = stype.getptr(schema, s_name.UnqualName('id')) target_cols[id_ptr] = f'{X(key_expr)} AS id' base_sources = list(sources) for link in single_links: link_name = link.get_shortname(schema).name link_type = link.get_target(schema) link_psi = types.get_pointer_storage_info(link, schema=schema) link_col = link_psi.column_name if str(link_type.get_name(schema)) == 'cfg::AbstractConfig': target_cols[link] = f'q0.scope_id AS {qi(link_col)}' continue if rptr is not None: target_path = path + [(rptr, exclusive_props)] 
else: target_path = path target_views, target_exc_props = _generate_config_type_view( schema, link_type, scope=scope, path=target_path, rptr=link, existing_view_columns=existing_view_columns, _memo=_memo, ) for descendant in link_type.descendants(schema): if descendant not in _memo: desc_views, _ = _generate_config_type_view( schema, descendant, scope=scope, path=target_path, rptr=link, existing_view_columns=existing_view_columns, override_exclusive_props=target_exc_props, _memo=_memo, ) views.extend(desc_views) target_source = _build_data_source( schema, link, self_idx, alias=link_name, always_array=rptr is None, ) sources.append(target_source) target_key_source = _build_key_source( schema, target_exc_props, link, source_idx=link_name) sources.append(target_key_source) target_key_components = key_components + [f'k{link_name}.key'] target_key = _build_key_expr(target_key_components, versioned) target_cols[link] = f'({X(target_key)}) AS {qi(link_col)}' views.extend(target_views) # You can't change the order of a postgres view... so # we have to maintain the original order. # # If we are applying patches that modify the config views, # then we will have an existing_view_columns map that tells us # the existing order in postgres. # If it isn't already in that map, then we order based on # the order in the pointers refdict, which will be the order # the pointers were created, *if* they were added to the in-memory # schema in this process. (If it was loaded from reflection, that # order won't be preserved, which is why we need existing_view_columns). # # FIXME: We should consider adding enough info to the schema to not need # this complication. 
existing_indexes = { v: i for i, v in enumerate(existing_view_columns.get(str(stype.id), [])) } if existing_view_columns else {} ptr_indexes = {} for i, v in enumerate(stype.get_pointers(schema).objects(schema)): # First try the id if (eidx := existing_indexes.get(str(v.id))) is not None: idx = (0, eidx) # Certain columns use their actual names, so try the actual # name also. elif ( eidx := existing_indexes.get(v.get_shortname(schema).name) ) is not None: idx = (0, eidx) # Not already in the database, use the order in pointers refdict else: idx = (1, i) ptr_indexes[v] = idx target_cols_sorted = sorted( target_cols.items(), key=lambda p: ptr_indexes[p[0]] ) target_cols_str = ',\n'.join([x for _, x in target_cols_sorted if x]) fromlist = ',\n'.join(f'LATERAL {X(s)}' for s in sources) target_query = xdedent.xdedent(f''' SELECT {X(target_cols_str)} FROM {X(fromlist)} ''') if where: target_query += f'\nWHERE\n {where}' views.append((tabname(schema, stype), target_query)) for link in multi_links: target_sources = list(base_sources) link_name = link.get_shortname(schema).name link_type = link.get_target(schema) if rptr is not None: target_path = path + [(rptr, exclusive_props)] else: target_path = path target_views, target_exc_props = _generate_config_type_view( schema, link_type, scope=scope, path=target_path, rptr=link, existing_view_columns=existing_view_columns, _memo=_memo, ) views.extend(target_views) for descendant in link_type.descendants(schema): if descendant not in _memo: desc_views, _ = _generate_config_type_view( schema, descendant, scope=scope, path=target_path, rptr=link, existing_view_columns=existing_view_columns, override_exclusive_props=target_exc_props, _memo=_memo, ) views.extend(desc_views) # HACK: For computable links (just extensions hopefully?), we # want to compile the targets as a side effect, but we don't # want to actually include them in the view. 
if link.get_computable(schema): continue target_source = _build_data_source( schema, link, self_idx, alias=link_name) target_sources.append(target_source) target_key_source = _build_key_source( schema, target_exc_props, link, source_idx=link_name) target_sources.append(target_key_source) target_key_components = key_components + [f'k{link_name}.key'] target_key = _build_key_expr(target_key_components, versioned) target_fromlist = ',\n'.join(f'LATERAL {X(s)}' for s in target_sources) link_query = xdedent.xdedent(f'''\ SELECT q.source, q.target FROM (SELECT {X(key_expr)} AS source, {X(target_key)} AS target FROM {X(target_fromlist)} ) q WHERE q.target IS NOT NULL ''') views.append((tabname(schema, link), link_query)) for prop, pp_cast in multi_props: target_sources = list(sources) pn = prop.get_shortname(schema).name target_source = _build_data_source( schema, prop, self_idx, alias=pn) target_sources.append(target_source) target_fromlist = ',\n'.join(f'LATERAL {X(s)}' for s in target_sources) link_query = xdedent.xdedent(f'''\ SELECT {X(key_expr)} AS source, {pp_cast(f'q{pn}.val')} AS target FROM {X(target_fromlist)} ''') if where: link_query += f'\nWHERE\n {where}' views.append((tabname(schema, prop), link_query)) return views, exclusive_props async def _execute_block( conn: PGConnection, block: dbops.SQLBlock, ) -> None: await execute_sql_script(conn, block.to_string()) async def execute_sql_script( conn: PGConnection, sql_text: str, ) -> None: from edb.server import pgcon if debug.flags.bootstrap: debug.header('Bootstrap Script') if len(sql_text) > 102400: # Make sure we don't hog CPU by attempting to highlight # huge scripts. 
print(sql_text) else: debug.dump_code(sql_text, lexer='sql') try: await conn.sql_execute(sql_text.encode("utf-8")) except pgcon.BackendError as e: position = e.get_field('P') internal_position = e.get_field('p') context = e.get_field('W') if context: pl_func_line_m = re.search( r'^PL/pgSQL function inline_code_block line (\d+).*', context, re.M) if pl_func_line_m: pl_func_line = int(pl_func_line_m.group(1)) else: pl_func_line = None point = None text = None if position is not None: point = int(position) - 1 text = sql_text elif internal_position is not None: point = int(internal_position) - 1 text = e.get_field('q') elif pl_func_line: point = ql_parser.offset_of_line(sql_text, pl_func_line) text = sql_text if point is not None: assert text span = qlast.Span( None, text, start=point, end=point, context_lines=30 ) exceptions.replace_context(e, span) raise ================================================ FILE: edb/pgsql/params.py ================================================ # Copyright (C) 2020-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from typing import Any, Optional, Mapping, NamedTuple import enum import functools import locale from edb import buildmeta BackendVersion = buildmeta.BackendVersion class BackendCapabilities(enum.IntFlag): NONE = 0 #: Whether CREATE ROLE .. 
SUPERUSER is allowed SUPERUSER_ACCESS = 1 << 0 #: Whether reading PostgreSQL configuration files #: via pg_file_settings is allowed CONFIGFILE_ACCESS = 1 << 1 #: Whether the PostgreSQL server supports the C.UTF-8 locale C_UTF8_LOCALE = 1 << 2 #: Whether CREATE ROLE is allowed CREATE_ROLE = 1 << 3 #: Whether CREATE DATABASE is allowed CREATE_DATABASE = 1 << 4 #: Whether extension "edb_stat_statements" is available STAT_STATEMENTS = 1 << 5 ALL_BACKEND_CAPABILITIES = ( BackendCapabilities.SUPERUSER_ACCESS | BackendCapabilities.CONFIGFILE_ACCESS | BackendCapabilities.C_UTF8_LOCALE | BackendCapabilities.CREATE_ROLE | BackendCapabilities.CREATE_DATABASE | BackendCapabilities.STAT_STATEMENTS ) class BackendInstanceParams(NamedTuple): capabilities: BackendCapabilities version: BackendVersion tenant_id: str base_superuser: Optional[str] = None max_connections: int = 500 reserved_connections: int = 0 ext_schema: str = "edgedbext" """A Postgres schema where extensions can be created.""" existing_exts: Optional[Mapping[str, str]] = None """A map of preexisting extensions in the target backend with schemas.""" class BackendRuntimeParams(NamedTuple): instance_params: BackendInstanceParams session_authorization_role: Optional[str] = None @property def tenant_id(self) -> str: return self.instance_params.tenant_id @property def has_superuser_access(self) -> bool: return bool( self.instance_params.capabilities & BackendCapabilities.SUPERUSER_ACCESS ) @property def has_configfile_access(self) -> bool: return bool( self.instance_params.capabilities & BackendCapabilities.CONFIGFILE_ACCESS ) @property def has_c_utf8_locale(self) -> bool: return bool( self.instance_params.capabilities & BackendCapabilities.C_UTF8_LOCALE ) @property def has_create_role(self) -> bool: return bool( self.instance_params.capabilities & BackendCapabilities.CREATE_ROLE ) @property def has_create_database(self) -> bool: return bool( self.instance_params.capabilities & BackendCapabilities.CREATE_DATABASE ) 
@property def has_stat_statements(self) -> bool: return self.has_superuser_access and bool( self.instance_params.capabilities & BackendCapabilities.STAT_STATEMENTS ) @functools.lru_cache def get_default_runtime_params( **instance_params: Any, ) -> BackendRuntimeParams: capabilities = ALL_BACKEND_CAPABILITIES if not _is_c_utf8_locale_present(): capabilities &= ~BackendCapabilities.C_UTF8_LOCALE instance_params.setdefault('capabilities', capabilities) if 'tenant_id' not in instance_params: instance_params = dict( tenant_id=buildmeta.get_default_tenant_id(), **instance_params, ) if 'version' not in instance_params: try: version = buildmeta.get_pg_version() except buildmeta.MetadataError as _: # HACK: if get_pg_version fails, this means we have no pg_config, # which happens for edgedb-ls. It is invoking pg compiler from # schema delta. Ideally, schema delta would not need pg compiler, # but that would require a lot of cleanups. version = BackendVersion( major=100, minor=0, micro=0, releaselevel='final', serial=0, string='100.0' ) instance_params = dict( version=version, **instance_params, ) return BackendRuntimeParams( instance_params=BackendInstanceParams(**instance_params), ) def _is_c_utf8_locale_present() -> bool: try: locale.setlocale(locale.LC_CTYPE, 'C.UTF-8') except Exception: return False else: # We specifically don't use locale.getlocale(), because # it can lie and return a non-existent locale due to PEP 538. locale.setlocale(locale.LC_CTYPE, '') return True ================================================ FILE: edb/pgsql/parser/.gitignore ================================================ /*.c ================================================ FILE: edb/pgsql/parser/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2010-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import json from edb.pgsql import ast as pgast from . import ast_builder from . import parser from .parser import ( Source, NormalizedSource, deserialize, ) __all__ = ( "parse", "Source", "NormalizedSource", "deserialize" ) def parse( sql_query: str, propagate_spans: bool = False ) -> list[pgast.Query | pgast.Statement]: ast_json = parser.pg_parse(sql_query) return ast_builder.build_stmts( json.loads(ast_json), sql_query, propagate_spans, ) ================================================ FILE: edb/pgsql/parser/ast_builder.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2010-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import dataclasses from typing import ( Any, Callable, Optional, Sequence, cast, ) from edb.common import span from edb.common.parsing import Span from edb.pgsql import ast as pgast from edb.edgeql import ast as qlast from edb.pgsql.parser.exceptions import PSqlUnsupportedError, get_node_name @dataclasses.dataclass(kw_only=True) class Context: source_sql: str has_fallback = False # Node = bool | str | int | float | List[Any] | dict[str, Any] Node = Any def build_stmts( node: Node, source_sql: str, propagate_spans: bool ) -> list[pgast.Query | pgast.Statement]: ctx = Context(source_sql=source_sql) try: res = [_build_stmt(node["stmt"], ctx) for node in node["stmts"]] except IndexError: raise PSqlUnsupportedError() except PSqlUnsupportedError as e: if e.node and "location" in e.node: e.location = e.node["location"] assert e.location e.message += f" near location {e.location}:" e.message += source_sql[e.location : (e.location + 50)] raise if propagate_spans: # we need to do a full pass of span propagation, because some # nodes (CommonTableExpr) have span, but their children don't (Insert). 
span.SpanPropagator(full_pass=True).container_visit(res) return res def _maybe[T]( node: Node, ctx: Context, name: str, builder: Callable[[Node, Context], T] ) -> Optional[T]: if name in node: return builder(node[name], ctx) return None def _ident(t: Any) -> Any: return t def _list[T, U]( node: Node, ctx: Context, name: str, builder: Callable[[Node, Context], T], mapper: Callable[[T], U] = _ident, *, unwrap: Optional[str] = None, ) -> list[U]: if unwrap is not None: return [ mapper(builder(_unwrap(n, unwrap), ctx)) for n in node.get(name, []) ] else: return [mapper(builder(n, ctx)) for n in node.get(name, [])] def _maybe_list[T, U]( node: Node, ctx: Context, name: str, builder: Callable[[Node, Context], T], mapper: Callable[[T], U] = _ident, *, unwrap: Optional[str] = None, ) -> Optional[list[U]]: return ( _list(node, ctx, name, builder, mapper, unwrap=unwrap) if name in node else None ) def _enum[T]( _ty: type[T], node: Node, ctx: Context, builders: dict[str, Callable[[Node, Context], T]], fallbacks: Sequence[Callable[[Node, Context], T]] = (), ) -> T: outer_fallback = ctx.has_fallback ctx.has_fallback = False try: for name in builders: if name in node: builder = builders[name] return builder(node[name], ctx) ctx.has_fallback = True for fallback in fallbacks: res = fallback(node, ctx) if res: return res if outer_fallback: return None # type: ignore raise PSqlUnsupportedError( node, ", ".join(get_node_name(k) for k in node.keys()) ) finally: ctx.has_fallback = outer_fallback def _build_str(node: Node, _: Context) -> str: node = _unwrap_string(node) return str(node) def _build_bool(node: Node, _: Context) -> bool: assert isinstance(node, bool) return node def _bool_or_false(node: Node, name: str) -> bool: return node[name] if name in node else False def _unwrap(node: Node, name: str) -> Node: if isinstance(node, dict) and name in node: return node[name] return node def _unwrap_boolean(n: Node) -> Node: n = _unwrap(n, 'Boolean') n = _unwrap(n, 'str') n = _unwrap(n, 
'boolval') n = _unwrap(n, 'boolval') if isinstance(n, dict) and len(n) == 0: n = False return n def _unwrap_int(n: Node) -> Node: n = _unwrap(n, 'Integer') n = _unwrap(n, 'str') n = _unwrap(n, 'ival') n = _unwrap(n, 'ival') if isinstance(n, dict) and len(n) == 0: n = 0 return n def _unwrap_float(n: Node) -> Node: n = _unwrap(n, 'Float') n = _unwrap(n, 'str') n = _unwrap(n, 'fval') n = _unwrap(n, 'fval') return n def _unwrap_string(n: Node) -> Node: n = _unwrap(n, 'String') n = _unwrap(n, 'str') n = _unwrap(n, 'sval') n = _unwrap(n, 'sval') return n def _as_column_ref(name: str) -> pgast.ColumnRef: return pgast.ColumnRef( name=(name,), ) def _build_span(n: Node, c: Context) -> Optional[Span]: if 'location' not in n: return None return Span( filename=None, buffer=c.source_sql, start=n["location"], end=n["location"], ) def _build_stmt(node: Node, c: Context) -> pgast.Query | pgast.Statement: return _enum( pgast.Query | pgast.Statement, # type: ignore node, c, { "VariableSetStmt": _build_variable_set_stmt, "VariableShowStmt": _build_variable_show_stmt, "TransactionStmt": _build_transaction_stmt, "PrepareStmt": _build_prepare, "ExecuteStmt": _build_execute, "DeallocateStmt": _build_deallocate, "CreateStmt": _build_create, "CreateTableAsStmt": _build_create_table_as, "LockStmt": _build_lock, "CopyStmt": _build_copy, }, [_build_query], ) def _build_query(node: Node, c: Context) -> pgast.Query: return _enum( pgast.Query, node, c, { "SelectStmt": _build_select_stmt, "InsertStmt": _build_insert_stmt, "UpdateStmt": _build_update_stmt, "DeleteStmt": _build_delete_stmt, }, ) def _build_select_stmt(n: Node, c: Context) -> pgast.SelectStmt: op = _maybe(n, c, "op", _build_str) if op: op = op[6:] if op == "NONE": op = None return pgast.SelectStmt( distinct_clause=( _maybe(n, c, "distinctClause", _build_distinct) # type: ignore ), target_list=_maybe_list(n, c, "targetList", _build_res_target) or [], from_clause=_maybe_list(n, c, "fromClause", _build_base_range_var) or [], 
where_clause=_maybe(n, c, "whereClause", _build_base_expr), group_clause=_maybe_list(n, c, "groupClause", _build_base), having_clause=_maybe(n, c, "havingClause", _build_base_expr), window_clause=_maybe_list(n, c, "windowClause", _build_base), values=_maybe_list(n, c, "valuesLists", _build_base_expr), sort_clause=_maybe_list(n, c, "sortClause", _build_sort_by), limit_offset=_maybe(n, c, "limitOffset", _build_base_expr), limit_count=_maybe(n, c, "limitCount", _build_base_expr), locking_clause=_maybe_list( n, c, "lockingClause", _build_locking_clause ), op=op, all=n["all"] if "all" in n else False, larg=_maybe(n, c, "larg", _build_select_stmt), rarg=_maybe(n, c, "rarg", _build_select_stmt), ctes=_maybe(n, c, "withClause", _build_ctes), ) def _build_insert_stmt(n: Node, c: Context) -> pgast.InsertStmt: select_stmt = _maybe(n, c, "selectStmt", _build_stmt) assert select_stmt is None or isinstance(select_stmt, pgast.Query) return pgast.InsertStmt( relation=_build_rel_range_var(n["relation"], c), returning_list=( _maybe_list(n, c, "returningList", _build_res_target) or [] ), cols=_maybe_list(n, c, "cols", _build_insert_target), select_stmt=select_stmt, on_conflict=_maybe(n, c, "onConflictClause", _build_on_conflict), ctes=_maybe(n, c, "withClause", _build_ctes), ) def _build_update_stmt(n: Node, c: Context) -> pgast.UpdateStmt: return pgast.UpdateStmt( relation=_build_rel_range_var(n["relation"], c), targets=_maybe(n, c, "targetList", _build_update_targets) or [], where_clause=_maybe(n, c, "whereClause", _build_base_expr), from_clause=( _maybe_list(n, c, "fromClause", _build_base_range_var) or [] ), returning_list=( _maybe_list(n, c, "returningList", _build_res_target) or [] ), ctes=_maybe(n, c, "withClause", _build_ctes), ) def _build_delete_stmt(n: Node, c: Context) -> pgast.DeleteStmt: return pgast.DeleteStmt( relation=_build_rel_range_var(n["relation"], c), returning_list=( _maybe_list(n, c, "returningList", _build_res_target) or [] ), where_clause=_maybe(n, c, 
"whereClause", _build_base_expr), using_clause=( _maybe_list(n, c, "usingClause", _build_base_range_var) or [] ), ctes=_maybe(n, c, "withClause", _build_ctes), ) def _build_lock(n: Node, c: Context) -> pgast.LockStmt: MODES = { 1: 'ACCESS SHARE', 2: 'ROW SHARE', 3: 'ROW EXCLUSIVE', 4: 'SHARE UPDATE EXCLUSIVE', 5: 'SHARE', 6: 'SHARE ROW EXCLUSIVE', 7: 'EXCLUSIVE', 8: 'ACCESS EXCLUSIVE', } return pgast.LockStmt( relations=_list(n, c, "relations", _build_base_range_var), mode=MODES[n['mode']], no_wait=_bool_or_false(n, 'nowait'), ) def _build_variable_set_stmt(n: Node, c: Context) -> pgast.Statement: tx_only_vars = { "transaction_isolation", "transaction_read_only", "transaction_deferrable", } if n["kind"] == "VAR_RESET": return pgast.VariableResetStmt( name=n["name"], scope=( pgast.OptionsScope.TRANSACTION if n["name"] in tx_only_vars else pgast.OptionsScope.SESSION ), span=_build_span(n, c), ) if n["kind"] == "VAR_RESET_ALL": return pgast.VariableResetStmt( name=None, scope=pgast.OptionsScope.SESSION, span=_build_span(n, c), ) if n["kind"] == "VAR_SET_MULTI": name = n["name"] if name == "TRANSACTION" or name == "SESSION CHARACTERISTICS": return pgast.SetTransactionStmt( options=_build_transaction_options(n["args"], c), scope=( pgast.OptionsScope.TRANSACTION if name == "TRANSACTION" else pgast.OptionsScope.SESSION ), ) if n["kind"] == "VAR_SET_VALUE": return pgast.VariableSetStmt( name=n["name"], args=pgast.ArgsList(args=_list(n, c, "args", _build_base_expr)), scope=( pgast.OptionsScope.TRANSACTION if n["name"] in tx_only_vars or ("is_local" in n and n["is_local"]) else pgast.OptionsScope.SESSION ), span=_build_span(n, c), ) if n["kind"] == "VAR_SET_DEFAULT": return pgast.VariableResetStmt( name=n["name"], scope=( pgast.OptionsScope.TRANSACTION if n["name"] in tx_only_vars or ("is_local" in n and n["is_local"]) else pgast.OptionsScope.SESSION ), span=_build_span(n, c), ) raise PSqlUnsupportedError(n) def _build_variable_show_stmt(n: Node, c: Context) -> 
pgast.Statement: return pgast.VariableShowStmt( name=n["name"], span=_build_span(n, c), ) def _build_transaction_stmt(n: Node, c: Context) -> pgast.TransactionStmt: def begin(n: Node, c: Context) -> pgast.BeginStmt: return pgast.BeginStmt( options=_maybe(n, c, "options", _build_transaction_options) ) def start(n: Node, c: Context) -> pgast.StartStmt: return pgast.StartStmt( options=_maybe(n, c, "options", _build_transaction_options) ) def commit(n: Node, _: Context) -> pgast.CommitStmt: return pgast.CommitStmt(chain=_bool_or_false(n, "chain")) def rollback(n: Node, _: Context) -> pgast.RollbackStmt: return pgast.RollbackStmt(chain=_bool_or_false(n, "chain")) def savepoint(n: Node, _: Context) -> pgast.SavepointStmt: return pgast.SavepointStmt(savepoint_name=n["savepoint_name"]) def release(n: Node, _: Context) -> pgast.ReleaseStmt: return pgast.ReleaseStmt(savepoint_name=n["savepoint_name"]) def rollback_to(n: Node, _: Context) -> pgast.RollbackToStmt: return pgast.RollbackToStmt(savepoint_name=n["savepoint_name"]) def prepare(n: Node, _: Context) -> pgast.PrepareTransaction: return pgast.PrepareTransaction(gid=n["gid"]) def commit_prepared(n: Node, _: Context) -> pgast.CommitPreparedStmt: return pgast.CommitPreparedStmt(gid=n["gid"]) def rollback_prepared(n: Node, _: Context) -> pgast.RollbackPreparedStmt: return pgast.RollbackPreparedStmt(gid=n["gid"]) kinds = { "BEGIN": begin, "START": start, "COMMIT": commit, "ROLLBACK": rollback, "SAVEPOINT": savepoint, "RELEASE": release, "ROLLBACK_TO": rollback_to, "PREPARE": prepare, "COMMIT_PREPARED": commit_prepared, "ROLLBACK_PREPARED": rollback_prepared, } kind = cast(str, n["kind"]).removeprefix("TRANS_STMT_") if kind not in kinds: raise PSqlUnsupportedError(n) return kinds[kind](n, c) def _build_transaction_options( nodes: list[Node], c: Context ) -> pgast.TransactionOptions: options = {} for n in nodes: if "DefElem" not in n: continue def_e = n["DefElem"] if not ("defname" in def_e and "arg" in def_e): continue 
options[def_e["defname"]] = _build_base_expr(def_e["arg"], c) return pgast.TransactionOptions(options=options) def _build_prepare(n: Node, c: Context) -> pgast.PrepareStmt: return pgast.PrepareStmt( name=n["name"], argtypes=_maybe_list(n, c, "argtypes", _build_type_name), query=_build_base_relation(n["query"], c), ) def _build_execute(n: Node, c: Context) -> pgast.ExecuteStmt: return pgast.ExecuteStmt( name=n["name"], params=_maybe_list(n, c, "params", _build_base_expr) ) def _build_deallocate(n: Node, c: Context) -> pgast.DeallocateStmt: return pgast.DeallocateStmt(name=n["name"]) def _build_create_table_as(n: Node, c: Context) -> pgast.CreateTableAsStmt: return pgast.CreateTableAsStmt( into=_build_create(n['into'], c), query=_build_query(n['query'], c), with_no_data=_bool_or_false(n['into'], 'skipData'), ) def _build_create(n: Node, c: Context) -> pgast.CreateStmt: def _build_on_commit(n: str, _c: Context) -> Optional[str]: on_commit = n[9:] return on_commit if on_commit != 'NOOP' else None relation = n['relation'] if 'relation' in n else n['rel'] return pgast.CreateStmt( relation=_build_relation(relation, c), table_elements=_maybe_list(n, c, 'tableElts', _build_table_element) or [], span=_build_span(n, c), on_commit=_maybe(n, c, 'oncommit', _build_on_commit), ) def _build_table_element(n: Node, c: Context) -> pgast.TableElement: return _enum( pgast.TableElement, n, c, { "ColumnDef": _build_column_def, }, ) def _build_column_def(n: Node, c: Context) -> pgast.ColumnDef: is_not_null = False default_expr = None if 'constraints' in n: for constraint in n['constraints']: constraint = _unwrap(constraint, 'Constraint') if constraint['contype'] == 'CONSTR_NOTNULL': is_not_null = True if constraint['contype'] == 'CONSTR_DEFAULT': default_expr = _maybe( constraint, c, 'raw_expr', _build_base_expr ) return pgast.ColumnDef( name=n['colname'], typename=_build_type_name(n['typeName'], c), default_expr=default_expr, is_not_null=is_not_null, span=_build_span(n, c), ) def
_build_base(n: Node, c: Context) -> pgast.Base: return _enum( pgast.Base, n, c, { "CommonTableExpr": _build_cte, }, [_build_base_expr], # type: ignore ) def _build_base_expr(node: Node, c: Context) -> pgast.BaseExpr: return _enum( pgast.BaseExpr, node, c, { "ResTarget": _build_res_target, "FuncCall": _build_func_call, "CoalesceExpr": _build_coalesce, "List": _build_implicit_row, "A_Expr": _build_a_expr, "A_ArrayExpr": _build_array_expr, "A_Const": _build_const, "A_Indirection": _build_indirection, "BoolExpr": _build_bool_expr, "CaseExpr": _build_case_expr, "TypeCast": _build_type_cast, "NullTest": _build_null_test, "BooleanTest": _build_boolean_test, "RowExpr": _build_row_expr, "SubLink": _build_sub_link, "ParamRef": _build_param_ref, "SetToDefault": _build_keyword("DEFAULT"), # type: ignore "SQLValueFunction": _build_sql_value_function, "CollateClause": _build_collate_clause, "MinMaxExpr": _build_min_max_expr, }, [_build_base_range_var, _build_indirection_op], # type: ignore ) def _build_distinct(nodes: list[Node], c: Context) -> list[pgast.Base]: # For some reason, plain DISTINCT is parsed as [{}] # In our AST this is represented by [pgast.Star()] if len(nodes) == 1 and len(nodes[0]) == 0: return [pgast.Star()] return [_build_base_expr(n, c) for n in nodes] def _build_indirection(n: Node, c: Context) -> pgast.Indirection: return pgast.Indirection( arg=_build_base_expr(n['arg'], c), indirection=_list(n, c, 'indirection', _build_indirection_op), ) def _build_indirection_op(n: Node, c: Context) -> pgast.IndirectionOp: return _enum( pgast.IndirectionOp, # type: ignore n, c, { 'A_Indices': _build_index_or_slice, 'Star': _build_star, 'ColumnRef': _build_column_ref, 'String': _build_record_indirection_op, }, ) def _build_record_indirection_op( n: Node, c: Context ) -> pgast.RecordIndirectionOp: return pgast.RecordIndirectionOp(name=_build_str(n, c)) def _build_ctes(n: Node, c: Context) -> list[pgast.CommonTableExpr]: is_recursive = _maybe(n, c, 'recursive', lambda x, _: 
bool(x)) or False ctes: list[pgast.CommonTableExpr] = _list(n, c, "ctes", _build_cte) for cte in ctes: cte.recursive = is_recursive return ctes def _build_cte(n: Node, c: Context) -> pgast.CommonTableExpr: n = _unwrap(n, "CommonTableExpr") materialized = None if n["ctematerialized"] == "CTEMaterializeAlways": materialized = True elif n["ctematerialized"] == "CTEMaterializeNever": materialized = False return pgast.CommonTableExpr( name=n["ctename"], query=_build_query(n["ctequery"], c), recursive=False, aliascolnames=_maybe_list(n, c, "aliascolnames", _build_str), materialized=materialized, span=_build_span(n, c), ) def _build_keyword(name: str) -> Callable[[Node, Context], pgast.Keyword]: return lambda n, c: pgast.Keyword(name=name, span=_build_span(n, c)) def _build_param_ref(n: Node, c: Context) -> pgast.ParamRef: return pgast.ParamRef(number=n.get("number", 0), span=_build_span(n, c)) def _build_collate_clause(n: Node, c: Context) -> pgast.CollateClause: return pgast.CollateClause( arg=_build_base_expr(n['arg'], c), collname=tuple(_list(n, c, 'collname', _build_str)), span=_build_span(n, c), ) def _build_min_max_expr(n: Node, c: Context) -> pgast.MinMaxExpr: if n['op'] == 'IS_GREATEST': op = 'GREATEST' elif n['op'] == 'IS_LEAST': op = 'LEAST' else: raise PSqlUnsupportedError(n) return pgast.MinMaxExpr( op=op, args=_list(n, c, 'args', _build_base_expr), span=_build_span(n, c), ) def _build_sql_value_function(n: Node, c: Context) -> pgast.SQLValueFunction: op = n["op"].removeprefix("SVFOP_") op_mapping = { "CURRENT_DATE": pgast.SQLValueFunctionOP.CURRENT_DATE, "CURRENT_TIME": pgast.SQLValueFunctionOP.CURRENT_TIME, "CURRENT_TIME_N": pgast.SQLValueFunctionOP.CURRENT_TIME_N, "CURRENT_TIMESTAMP": pgast.SQLValueFunctionOP.CURRENT_TIMESTAMP, "CURRENT_TIMESTAMP_N": pgast.SQLValueFunctionOP.CURRENT_TIMESTAMP_N, "LOCALTIME": pgast.SQLValueFunctionOP.LOCALTIME, "LOCALTIME_N": pgast.SQLValueFunctionOP.LOCALTIME_N, "LOCALTIMESTAMP": pgast.SQLValueFunctionOP.LOCALTIMESTAMP, 
"LOCALTIMESTAMP_N": pgast.SQLValueFunctionOP.LOCALTIMESTAMP_N, "CURRENT_ROLE": pgast.SQLValueFunctionOP.CURRENT_ROLE, "CURRENT_USER": pgast.SQLValueFunctionOP.CURRENT_USER, "USER": pgast.SQLValueFunctionOP.USER, "SESSION_USER": pgast.SQLValueFunctionOP.SESSION_USER, "CURRENT_CATALOG": pgast.SQLValueFunctionOP.CURRENT_CATALOG, "CURRENT_SCHEMA": pgast.SQLValueFunctionOP.CURRENT_SCHEMA, } if op not in op_mapping: raise PSqlUnsupportedError(n) return pgast.SQLValueFunction( op=op_mapping[op], arg=_maybe(n, c, "xpr", _build_base_expr) ) def _build_sub_link(n: Node, c: Context) -> pgast.SubLink: typ = n["subLinkType"] if typ == "EXISTS_SUBLINK": operator = "EXISTS" elif typ == "NOT_EXISTS_SUBLINK": operator = "NOT_EXISTS" elif typ == "ALL_SUBLINK": operator = "ALL" elif typ == "ANY_SUBLINK": operator = "= ANY" elif typ == "EXPR_SUBLINK": operator = None elif typ == "ARRAY_SUBLINK": operator = "ARRAY" else: raise PSqlUnsupportedError(n) return pgast.SubLink( operator=operator, expr=_build_query(n["subselect"], c), test_expr=_maybe(n, c, 'testexpr', _build_base_expr), span=_build_span(n, c), ) def _build_row_expr(n: Node, c: Context) -> pgast.BaseExpr: args: list[pgast.BaseExpr] = ( _maybe_list(n, c, "args", _build_base_expr) or [] ) format = n.get('row_format', None) if format in {'COERCE_EXPLICIT_CALL', 'COERCE_EXPLICIT_CAST'}: return pgast.RowExpr(args=args, span=_build_span(n, c)) else: return pgast.ImplicitRowExpr(args=args, span=_build_span(n, c)) def _build_boolean_test(n: Node, c: Context) -> pgast.BooleanTest: return pgast.BooleanTest( arg=_build_base_expr(n["arg"], c), negated=n["booltesttype"].startswith("IS_NOT"), is_true=n["booltesttype"].endswith("TRUE"), span=_build_span(n, c), ) def _build_null_test(n: Node, c: Context) -> pgast.NullTest: return pgast.NullTest( arg=_build_base_expr(n["arg"], c), negated=n["nulltesttype"] == "IS_NOT_NULL", span=_build_span(n, c), ) def _build_type_cast(n: Node, c: Context) -> pgast.TypeCast: return pgast.TypeCast( 
arg=_build_base_expr(n["arg"], c), type_name=_build_type_name(n["typeName"], c), span=_build_span(n, c), ) def _build_type_name(n: Node, c: Context) -> pgast.TypeName: n = _unwrap(n, "TypeName") name: tuple[str, ...] = tuple(_list(n, c, "names", _build_str)) # we don't escape char properly, so let's just resolve it during parsing if name == ("char",): name = ("pg_catalog", "char") def unwrap_int_builder(n: Node, _c: Context) -> Node: return _unwrap_int(n) return pgast.TypeName( name=name, setof=_bool_or_false(n, "setof"), typmods=None, array_bounds=_maybe_list(n, c, "arrayBounds", unwrap_int_builder), span=_build_span(n, c), ) def _build_case_expr(n: Node, c: Context) -> pgast.CaseExpr: return pgast.CaseExpr( arg=_maybe(n, c, "arg", _build_base_expr), args=_list(n, c, "args", _build_case_when), defresult=_maybe(n, c, "defresult", _build_base_expr), span=_build_span(n, c), ) def _build_case_when(n: Node, c: Context) -> pgast.CaseWhen: n = _unwrap(n, "CaseWhen") return pgast.CaseWhen( expr=_build_base_expr(n["expr"], c), result=_build_base_expr(n["result"], c), span=_build_span(n, c), ) def _build_bool_expr(n: Node, c: Context) -> pgast.Expr: name = _build_str(n["boolop"], c)[0:-5] args = list(n["args"]) res = pgast.Expr( name=name, lexpr=_build_base_expr(args.pop(0), c) if len(args) > 1 else None, rexpr=_build_base_expr(args.pop(0), c) if len(args) > 0 else None, span=_build_span(n, c), ) while len(args) > 0: res = pgast.Expr( name=_build_str(n["boolop"], c)[0:-5], lexpr=res, rexpr=_build_base_expr(args.pop(0), c) if len(args) > 0 else None, span=_build_span(n, c), ) return res def _build_base_range_var(n: Node, c: Context) -> pgast.BaseRangeVar: return _enum( pgast.BaseRangeVar, n, c, { "RangeVar": _build_rel_range_var, "JoinExpr": _build_join_expr, "RangeFunction": _build_range_function, "RangeSubselect": _build_range_subselect, }, ) def _build_const(n: Node, c: Context) -> pgast.BaseConstant: n = _unwrap(n, "val") span = _build_span(n, c) if "Null" in n or 
"isnull" in n: return pgast.NullConstant(span=span) if "Boolean" in n or "boolval" in n: return pgast.BooleanConstant(val=_unwrap_boolean(n), span=span) if "Integer" in n or "ival" in n: return pgast.NumericConstant(val=str(_unwrap_int(n)), span=span) if "Float" in n or "fval" in n: return pgast.NumericConstant(val=_unwrap_float(n), span=span) if "String" in n or "sval" in n: return pgast.StringConstant(val=_unwrap_string(n), span=span) if "BitString" in n or "bsval" in n: n = _unwrap(n, 'BitString') n = _unwrap(n, 'str') n = _unwrap(n, 'bsval') n = _unwrap(n, 'bsval') return pgast.BitStringConstant(kind=n[0], val=n[1:], span=span) raise PSqlUnsupportedError(n) def _build_range_subselect(n: Node, c: Context) -> pgast.RangeSubselect: return pgast.RangeSubselect( alias=_maybe(n, c, "alias", _build_alias) or pgast.Alias(aliasname=""), lateral=_bool_or_false(n, "lateral"), subquery=_build_query(n["subquery"], c), ) def _build_range_function(n: Node, c: Context) -> pgast.RangeFunction: return pgast.RangeFunction( alias=_maybe(n, c, "alias", _build_alias) or pgast.Alias(aliasname=""), lateral=_bool_or_false(n, "lateral"), with_ordinality=_bool_or_false(n, "ordinality"), is_rowsfrom=_bool_or_false(n, "is_rowsfrom"), functions=[ _build_base_expr(fn, c) for fn in n["functions"][0]["List"]["items"] if len(fn) > 0 ], ) def _build_join_expr(n: Node, c: Context) -> pgast.JoinExpr: return pgast.JoinExpr( alias=_maybe(n, c, "alias", _build_alias) or pgast.Alias(aliasname=""), larg=_build_base_range_var(n["larg"], c), joins=[ pgast.JoinClause( type=n["jointype"][5:], rarg=_build_base_range_var(n["rarg"], c), using_clause=_maybe_list( n, c, "usingClause", _build_str, _as_column_ref ), quals=_maybe(n, c, "quals", _build_base_expr), ) ], ) def _build_rel_range_var(n: Node, c: Context) -> pgast.RelRangeVar: return pgast.RelRangeVar( alias=_maybe(n, c, "alias", _build_alias) or pgast.Alias(aliasname=""), relation=_build_relation(n, c), include_inherited=_bool_or_false(n, "inh"), 
span=_build_span(n, c), ) def _build_alias(n: Node, c: Context) -> pgast.Alias: return pgast.Alias( aliasname=_build_str(n["aliasname"], c), colnames=_maybe_list(n, c, "colnames", _build_str), ) def _build_base_relation(n: Node, c: Context) -> pgast.BaseRelation: return _enum( pgast.BaseRelation, n, c, { "SelectStmt": _build_select_stmt, "Relation": _build_relation, }, ) def _build_relation(n: Node, c: Context) -> pgast.Relation: return pgast.Relation( name=_maybe(n, c, "relname", _build_str), catalogname=_maybe(n, c, "catalogname", _build_str), schemaname=_maybe(n, c, "schemaname", _build_str), is_temporary=_maybe(n, c, "relpersistence", lambda n, _c: n == 't'), span=_build_span(n, c), ) def _build_implicit_row(n: Node, c: Context) -> pgast.ImplicitRowExpr: if isinstance(n, list): n = n[0] n = _unwrap(n, "List") return pgast.ImplicitRowExpr( args=[_build_base_expr(e, c) for e in n["items"] if len(e) > 0], ) def _build_array_expr(n: Node, c: Context) -> pgast.ArrayExpr: return pgast.ArrayExpr(elements=_list(n, c, "elements", _build_base_expr)) def _build_a_expr(n: Node, c: Context) -> pgast.BaseExpr: names: list[str] = _list(n, c, 'name', _build_str) if names[0] == 'pg_catalog': names.pop(0) name = names.pop(0) if n["kind"] == "AEXPR_OP": pass elif n["kind"] in ("AEXPR_LIKE", "AEXPR_ILIKE"): op = name if op.endswith("*"): name = "ILIKE" else: name = "LIKE" if op.startswith("!"): name = "NOT " + name elif n["kind"] == "AEXPR_IN": if name == "<>": name = "NOT IN" else: name = "IN" elif n["kind"] in ("AEXPR_OP_ANY", "AEXPR_OP_ALL"): operator = "ANY" if n["kind"] == "AEXPR_OP_ANY" else "ALL" return pgast.SubLink( operator=name + " " + operator, test_expr=_maybe(n, c, "lexpr", _build_base_expr), expr=_build_base_expr(n["rexpr"], c), span=_build_span(n, c), ) elif n['kind'] == 'AEXPR_NULLIF': return pgast.FuncCall( name=('nullif',), args=[ _build_base_expr(n['lexpr'], c), _build_base_expr(n['rexpr'], c), ], ) elif n['kind'] == 'AEXPR_DISTINCT': name = 'IS DISTINCT FROM' 
elif n['kind'] == 'AEXPR_NOT_DISTINCT': name = 'IS NOT DISTINCT FROM' else: raise PSqlUnsupportedError(n) return pgast.Expr( name=name, lexpr=_maybe(n, c, "lexpr", _build_base_expr), rexpr=_maybe(n, c, "rexpr", _build_base_expr), span=_build_span(n, c), ) def _build_func_call(n: Node, c: Context) -> pgast.FuncCall: n = _unwrap(n, "FuncCall") return pgast.FuncCall( name=tuple(_list(n, c, "funcname", _build_str)), args=_maybe_list(n, c, "args", _build_base_expr) or [], agg_order=( _maybe_list(n, c, "aggOrder", _build_sort_by) or _maybe_list(n, c, "agg_order", _build_sort_by) ), agg_filter=_maybe(n, c, "aggFilter", _build_base_expr), agg_star=_bool_or_false(n, "agg_star"), agg_distinct=_bool_or_false(n, "agg_distinct"), agg_within_group=_bool_or_false(n, "agg_within_group"), over=_maybe(n, c, "over", _build_window_def), with_ordinality=_bool_or_false(n, "withOrdinality"), span=_build_span(n, c), ) def _build_coalesce(n: Node, c: Context) -> pgast.CoalesceExpr: return pgast.CoalesceExpr( args=_list(n, c, "args", _build_base_expr), ) def _build_index_or_slice(n: Node, c: Context) -> pgast.Slice | pgast.Index: if 'is_slice' in n and n['is_slice']: return pgast.Slice( lidx=_build_base_expr(n['lidx'], c), ridx=_build_base_expr(n['uidx'], c), ) else: idx = n['lidx'] if 'lidx' in n else n['uidx'] return pgast.Index( idx=_build_base_expr(idx, c), ) def _build_res_target(n: Node, c: Context) -> pgast.ResTarget: n = _unwrap(n, "ResTarget") return pgast.ResTarget( name=_maybe(n, c, "name", _build_str), val=_build_base_expr(n["val"], c), span=_build_span(n, c), ) def _build_insert_target(n: Node, c: Context) -> pgast.InsertTarget: n = _unwrap(n, "ResTarget") return pgast.InsertTarget( name=_build_str(n['name'], c), span=_build_span(n, c), ) def _build_update_targets( target_list: list[Node], c: Context ) -> list[pgast.UpdateTarget | pgast.MultiAssignRef]: targets: list[pgast.UpdateTarget | pgast.MultiAssignRef] = [] while len(target_list) > 0: val: dict = 
target_list[0]["ResTarget"]["val"] if first_mar := val.get("MultiAssignRef", None): ncolumns = first_mar['ncolumns'] columns: list[str] = [] for _ in range(ncolumns): target = target_list.pop(0) mar = target['ResTarget'] if 'indirection' in mar: raise PSqlUnsupportedError( val, f"multi-assign SET with indirection" ) columns.append(mar['name']) targets.append(_build_multi_assign_ref(first_mar, columns, c)) else: target = target_list.pop(0) targets.append(_build_update_target(target, c)) return targets def _build_multi_assign_ref( n: Node, columns: list[str], c: Context ) -> pgast.MultiAssignRef: return pgast.MultiAssignRef( source=_build_base_expr(n['source'], c), columns=columns, span=_build_span(n, c), ) def _build_update_target( n: Node, c: Context ) -> pgast.UpdateTarget | pgast.MultiAssignRef: n = _unwrap(n, "ResTarget") return pgast.UpdateTarget( name=_build_str(n['name'], c), val=_build_base_expr(n['val'], c), indirection=_maybe_list(n, c, "indirection", _build_indirection_op), span=_build_span(n, c), ) def _build_window_def(n: Node, c: Context) -> pgast.WindowDef: return pgast.WindowDef( name=_maybe(n, c, "name", _build_str), refname=_maybe(n, c, "refname", _build_str), partition_clause=_maybe_list(n, c, "partitionClause", _build_base_expr), order_clause=_maybe_list(n, c, "orderClause", _build_sort_by), frame_options=None, start_offset=_maybe(n, c, "startOffset", _build_base_expr), end_offset=_maybe(n, c, "endOffset", _build_base_expr), span=_build_span(n, c), ) def _build_sort_by(n: Node, c: Context) -> pgast.SortBy: n = _unwrap(n, "SortBy") return pgast.SortBy( node=_build_base_expr(n["node"], c), dir=_maybe(n, c, "sortby_dir", _build_sort_order), nulls=_maybe(n, c, "sortby_nulls", _build_nones_order), span=_build_span(n, c), ) def _build_locking_clause(n: Node, c: Context) -> pgast.LockingClause: n = _unwrap(n, "LockingClause") match n["strength"]: case "LCS_FORUPDATE": strength = pgast.LockClauseStrength.UPDATE case "LCS_FORNOKEYUPDATE": strength = 
pgast.LockClauseStrength.NO_KEY_UPDATE case "LCS_FORSHARE": strength = pgast.LockClauseStrength.SHARE case "LCS_FORKEYSHARE": strength = pgast.LockClauseStrength.KEY_SHARE case lcs: raise PSqlUnsupportedError(n, f"FOR lock strength: {lcs}") if pol := n.get("waitPolicy"): wait_policy = getattr(pgast.LockWaitPolicy, pol, None) if wait_policy is None: raise PSqlUnsupportedError(n, f"FOR wait policy: {pol}") else: wait_policy = None return pgast.LockingClause( strength=strength, locked_rels=_maybe_list( n, c, "lockedRels", _build_rel_range_var, unwrap="RangeVar" ), wait_policy=wait_policy, ) def _build_nones_order(n: Node, _c: Context) -> Optional[qlast.NonesOrder]: if n == "SORTBY_NULLS_FIRST": return qlast.NonesFirst if n == "SORTBY_NULLS_LAST": return qlast.NonesLast return None def _build_sort_order(n: Node, _c: Context) -> Optional[qlast.SortOrder]: if n == "SORTBY_DESC": return qlast.SortOrder.Desc if n == "SORTBY_ASC": return qlast.SortOrder.Asc if n == "SORTBY_DEFAULT": return None raise PSqlUnsupportedError(n) def _build_column_ref(n: Node, c: Context) -> pgast.ColumnRef: return pgast.ColumnRef( name=_list(n, c, "fields", _build_string_or_star), optional=_maybe(n, c, "optional", _build_bool), span=_build_span(n, c), ) def _build_on_conflict(n: Node, c: Context) -> pgast.OnConflictClause: action_str = _build_str(n["action"], c) if action_str == "ONCONFLICT_NOTHING": action = pgast.OnConflictAction.DO_NOTHING elif action_str == "ONCONFLICT_UPDATE": action = pgast.OnConflictAction.DO_UPDATE else: raise PSqlUnsupportedError(n, f"ON CONFLICT {action_str}") return pgast.OnConflictClause( action=action, target=_maybe(n, c, "infer", _build_on_conflict_target), update_list=_maybe(n, c, "targetList", _build_update_targets) or [], update_where=_maybe(n, c, "whereClause", _build_base_expr), span=_build_span(n, c), ) def _build_on_conflict_target(n: Node, c: Context) -> pgast.OnConflictTarget: return pgast.OnConflictTarget( index_elems=_maybe_list(n, c, "indexElems", 
_build_index_elem), index_where=_maybe(n, c, "whereClause", _build_base_expr), constraint_name=_maybe(n, c, "conname", _build_str), span=_build_span(n, c), ) def _build_index_elem(n: Node, c: Context) -> pgast.IndexElem: n = _unwrap(n, 'IndexElem') expr: pgast.BaseExpr if 'name' in n: expr = pgast.ColumnRef(name=(n['name'],)) elif 'indexcolname' in n: expr = pgast.ColumnRef(name=(n['indexcolname'],)) elif 'expr' in n: expr = _build_base_expr(n['expr'], c) else: raise PSqlUnsupportedError(n) # TODO: # collation # opclass # opclassopts return pgast.IndexElem( expr=expr, ordering=_build_sort_order(n['ordering'], c), nulls_ordering=_build_nones_order(n['nulls_ordering'], c), span=_build_span(n, c), ) def _build_star(_n: Node, _c: Context) -> pgast.Star | str: return pgast.Star() def _build_string_or_star(node: Node, c: Context) -> pgast.Star | str: return _enum( pgast.Star | str, # type: ignore node, c, {"String": _build_str, "A_Star": _build_star}, ) def _build_copy_format(n: Node, c: Context) -> pgast.CopyFormat: val = _build_str(n, c) if val == 'text': return pgast.CopyFormat.TEXT if val == 'csv': return pgast.CopyFormat.CSV if val == 'binary': return pgast.CopyFormat.BINARY raise PSqlUnsupportedError(val, f"{val} format") def _build_copy_options(nodes: list[Node], c: Context) -> pgast.CopyOptions: opt = pgast.CopyOptions() for n in nodes: if 'DefElem' not in n or 'defname' not in n['DefElem']: continue n = n['DefElem'] def_name = n['defname'] arg = n['arg'] if 'arg' in n else None if def_name == 'format' and arg: opt.format = _build_copy_format(arg, c) elif def_name == 'freeze': opt.freeze = _build_str(arg, c) == 'true' if arg else True elif def_name == 'delimiter' and arg: opt.delimiter = _build_str(arg, c) elif def_name == 'null' and arg: opt.null = _build_str(arg, c) elif def_name == 'header' and arg: opt.header = _build_str(arg, c) == 'true' if arg else True elif def_name == 'quote' and arg: opt.quote = _build_str(arg, c) elif def_name == 'escape' and arg: opt.escape =
_build_str(arg, c) elif def_name == 'force_quote': arg = _unwrap(arg, 'List') opt.force_quote = _list(arg, c, 'items', _build_str) elif def_name == 'force_not_null': arg = _unwrap(arg, 'List') opt.force_not_null = _list(arg, c, 'items', _build_str) elif def_name == 'force_null': arg = _unwrap(arg, 'List') opt.force_null = _list(arg, c, 'items', _build_str) elif def_name == 'encoding': opt.encoding = _build_str(arg, c) return opt def _build_copy(n: Node, c: Context) -> pgast.CopyStmt: return pgast.CopyStmt( relation=_maybe(n, c, 'relation', _build_relation), colnames=_maybe_list(n, c, 'attlist', _build_str), query=_maybe(n, c, 'query', _build_query), is_from=_bool_or_false(n, 'is_from'), is_program=_bool_or_false(n, 'is_program'), filename=_maybe(n, c, 'filename', _build_str), options=( _maybe(n, c, 'options', _build_copy_options) or pgast.CopyOptions() ), where_clause=_maybe(n, c, "whereClause", _build_base_expr), span=_build_span(n, c), ) ================================================ FILE: edb/pgsql/parser/exceptions.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2010-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import re from typing import Any, Optional class PSqlParseError(Exception): pass class PSqlSyntaxError(PSqlParseError): def __init__( self, message: str, cursor_pos: int, # 0-based query_source: str, ): self.message = message self.cursor_pos = cursor_pos self.query_source = query_source def __str__(self): return self.message class PSqlUnsupportedError(PSqlParseError): node: Optional[Any] location: Optional[int] message: str def __init__(self, node: Optional[Any] = None, feat: Optional[str] = None): self.node = node self.location = None self.message = "not supported" if feat: self.message += f": {feat}" def __str__(self): return self.message def get_node_name(name: str) -> str: """ Given a node name (CreateTableStmt), this function tries to guess the SQL command text (CREATE TABLE). """ name = name.removesuffix('Stmt').removesuffix('Expr') name = re.sub(r'(?<!^)(?=[A-Z])', ' ', name).upper() return name ================================================ FILE: edb/pgsql/parser/parser.pyx ================================================ # ... def pg_parse(query) -> str: cdef: PgQueryParseResult result bytes queryb queryb = query.encode("utf-8") result = pg_query_parse(queryb) if result.error: error = PSqlSyntaxError( result.error.message.decode('utf8'), result.error.cursorpos, query, ) pg_query_free_parse_result(result) raise error result_utf8 = result.parse_tree.decode('utf8') pg_query_free_parse_result(result) return result_utf8 class LiteralTokenType(enum.StrEnum): FCONST = "FCONST" SCONST = "SCONST" BCONST = "BCONST" XCONST = "XCONST" ICONST = "ICONST" TRUE_P = "TRUE_P" FALSE_P = "FALSE_P" class PgLiteralTypeOID(enum.IntEnum): BOOL = 16 INT4 = 23 TEXT = 25 UNKNOWN = 705 VARBIT = 1562 NUMERIC = 1700 class NormalizedQuery(NamedTuple): text: str highest_extern_param_id: int extracted_constants: list[tuple[int, LiteralTokenType, bytes]] def pg_normalize(query: str) -> NormalizedQuery: cdef: PgQueryNormalizeResult result PgQueryNormalizeConstLocation loc const ProtobufCEnumValue *token int i bytes queryb bytes const queryb = query.encode("utf-8") result = pg_query_normalize(queryb) try: if result.error: error = PSqlSyntaxError( result.error.message.decode('utf8'), 
result.error.cursorpos, query, ) raise error normalized_query = result.normalized_query.decode('utf8') consts = [] for i in range(result.clocations_count): loc = result.clocations[i] if loc.length != -1: if loc.param_id < 0: # Negative param_id means *relative* to highest explicit # param id (after taking the absolute value). param_id = ( abs(loc.param_id) + result.highest_extern_param_id ) else: # Otherwise it's the absolute param id. param_id = loc.param_id if loc.val != NULL: token = protobuf_c_enum_descriptor_get_value( &pg_query__token__descriptor, loc.token) if token == NULL: raise RuntimeError( f"could not lookup pg_query enum descriptor " f"for token value {loc.token}" ) consts.append(( param_id, LiteralTokenType(bytes(token.name).decode("ascii")), bytes(loc.val), )) return NormalizedQuery( text=normalized_query, highest_extern_param_id=result.highest_extern_param_id, extracted_constants=consts, ) finally: pg_query_free_normalize_result(result) cdef ReadBuffer _init_deserializer(serialized: bytes, tag: uint8_t, cls: str): cdef ReadBuffer buf buf = ReadBuffer.new_message_parser(serialized) if buf.read_byte() != tag: raise ValueError(f"malformed {cls} serialization") return buf cdef class Source: def __init__( self, text: str, serialized: Optional[bytes] = None, ) -> None: self._text = text if serialized is not None: self._serialized = serialized else: self._serialized = b'' self._cache_key = b'' @classmethod def _tag(self) -> int: return 0 cdef WriteBuffer _serialize(self): cdef WriteBuffer buf = WriteBuffer.new() buf.write_byte(self._tag()) buf.write_len_prefixed_utf8(self._text) return buf def serialize(self) -> bytes: if not self._serialized: self._serialized = bytes(self._serialize()) return self._serialized @classmethod def from_serialized(cls, serialized: bytes) -> Source: cdef ReadBuffer buf buf = _init_deserializer(serialized, cls._tag(), cls.__name__) text = buf.read_len_prefixed_utf8() return Source(text, serialized) def text(self) -> str: return 
self._text def original_text(self) -> str: return self._text def cache_key(self) -> bytes: if not self._cache_key: h = hashlib.blake2b(self._tag().to_bytes()) h.update(bytes(self.text(), 'UTF-8')) # Include types of extracted constants for extra_type_oid in self.extra_type_oids(): h.update(extra_type_oid.to_bytes(8, signed=True)) self._cache_key = h.digest() return self._cache_key def variables(self) -> dict[str, Any]: return {} def first_extra(self) -> Optional[int]: return None def extra_counts(self) -> Sequence[int]: return [] def extra_blobs(self) -> list[bytes]: return [] def extra_formatted_as_text(self) -> bool: return True def extra_type_oids(self) -> Sequence[int]: return () @classmethod def from_string(cls, text: str) -> Source: return Source(text) def denormalized(self) -> Source: return self cdef class NormalizedSource(Source): def __init__( self, normalized: NormalizedQuery, orig_text: str, serialized: Optional[bytes] = None, ) -> None: super().__init__(text=normalized.text, serialized=serialized) self._extracted_constants = list( sorted(normalized.extracted_constants, key=lambda i: i[0]), ) self._highest_extern_param_id = normalized.highest_extern_param_id self._orig_text = orig_text @classmethod def _tag(cls) -> int: return 1 def original_text(self) -> str: return self._orig_text cdef WriteBuffer _serialize(self): cdef WriteBuffer buf buf = Source._serialize(self) buf.write_len_prefixed_utf8(self._orig_text) buf.write_int32(self._highest_extern_param_id) buf.write_int32(len(self._extracted_constants)) for param_id, token, val in self._extracted_constants: buf.write_int32(param_id) buf.write_len_prefixed_utf8(token.value) buf.write_len_prefixed_bytes(val) return buf def variables(self) -> dict[str, bytes]: return {f"${n}": v for n, _, v in self._extracted_constants} def first_extra(self) -> Optional[int]: return ( self._highest_extern_param_id if self._extracted_constants else None ) def extra_counts(self) -> Sequence[int]: return 
[len(self._extracted_constants)] def extra_blobs(self) -> list[bytes]: cdef WriteBuffer buf buf = WriteBuffer.new() for _, _, v in self._extracted_constants: buf.write_len_prefixed_bytes(v) return [bytes(buf)] def extra_type_oids(self) -> Sequence[int]: oids = [] for _, token, _ in self._extracted_constants: if token is LiteralTokenType.FCONST: oids.append(PgLiteralTypeOID.NUMERIC) elif token is LiteralTokenType.ICONST: oids.append(PgLiteralTypeOID.INT4) elif ( token is LiteralTokenType.FALSE_P or token is LiteralTokenType.TRUE_P ): oids.append(PgLiteralTypeOID.BOOL) elif token is LiteralTokenType.SCONST: oids.append(PgLiteralTypeOID.UNKNOWN) elif ( token is LiteralTokenType.XCONST or token is LiteralTokenType.BCONST ): oids.append(PgLiteralTypeOID.VARBIT) else: raise AssertionError(f"unexpected literal token type: {token}") return oids @classmethod def from_string(cls, text: str) -> NormalizedSource: normalized = pg_normalize(text) return NormalizedSource(normalized, text) @classmethod def from_serialized(cls, serialized: bytes) -> NormalizedSource: cdef ReadBuffer buf buf = _init_deserializer(serialized, cls._tag(), cls.__name__) text = buf.read_len_prefixed_utf8() orig_text = buf.read_len_prefixed_utf8() highest_extern_param_id = buf.read_int32() n_constants = buf.read_int32() consts = [] for _ in range(n_constants): param_id = buf.read_int32() token = buf.read_len_prefixed_utf8() val = buf.read_len_prefixed_bytes() consts.append((param_id, LiteralTokenType(token), val)) return NormalizedSource( NormalizedQuery( text=text, highest_extern_param_id=highest_extern_param_id, extracted_constants=consts, ), orig_text, serialized, ) def denormalized(self) -> Source: return Source.from_string(self._orig_text) def deserialize(serialized: bytes) -> Source: if serialized[0] == 0: return Source.from_serialized(serialized) elif serialized[0] == 1: return NormalizedSource.from_serialized(serialized) raise ValueError(f"Invalid type/version byte: {serialized[0]}") 
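# Illustrative sketch, not part of the module: the param id rule applied in
# pg_normalize() and the token-to-OID mapping applied in
# NormalizedSource.extra_type_oids(), restated in plain Python so they can be
# tried without the Cython extension. The names TOKEN_TO_OID,
# resolve_param_id and describe_constants are hypothetical and exist only for
# this example; the OID values are copied from PgLiteralTypeOID above.

```python
# Type OIDs as listed in PgLiteralTypeOID.
BOOL, INT4, UNKNOWN, VARBIT, NUMERIC = 16, 23, 705, 1562, 1700

# Literal token name -> Postgres type OID, following extra_type_oids().
TOKEN_TO_OID = {
    "FCONST": NUMERIC,
    "ICONST": INT4,
    "TRUE_P": BOOL,
    "FALSE_P": BOOL,
    "SCONST": UNKNOWN,
    "XCONST": VARBIT,
    "BCONST": VARBIT,
}


def resolve_param_id(loc_param_id: int, highest_extern_param_id: int) -> int:
    """Hypothetical helper mirroring the param id rule in pg_normalize().

    A negative id is relative to the highest explicit parameter id;
    a non-negative id is already absolute.
    """
    if loc_param_id < 0:
        return abs(loc_param_id) + highest_extern_param_id
    return loc_param_id


def describe_constants(raw, highest_extern_param_id):
    """Return (param_id, type_oid, value) triples sorted by param id.

    `raw` is an iterable of (param_id, token_name, value) tuples, shaped
    like NormalizedQuery.extracted_constants but with plain string tokens.
    """
    resolved = [
        (
            resolve_param_id(param_id, highest_extern_param_id),
            TOKEN_TO_OID[token],
            value,
        )
        for param_id, token, value in raw
    ]
    return sorted(resolved, key=lambda item: item[0])


if __name__ == "__main__":
    # One explicit parameter ($1) plus two extracted literals.
    extracted = [(-2, "SCONST", b"'x'"), (-1, "ICONST", b"42")]
    print(describe_constants(extracted, highest_extern_param_id=1))
```

# With one explicit parameter, the two extracted literals land on ids 2 and
# 3, which matches how NormalizedSource sorts its constants by param id
# before serializing them.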
================================================ FILE: edb/pgsql/patches.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Patches to apply to databases""" from __future__ import annotations def get_patch_level(num_patches: int) -> int: return sum(p.startswith('edgeql+schema') for p, _ in PATCHES[:num_patches]) def get_version_key(num_patches: int) -> str: """Produce a version key to add to instdata keys after major patches. Patches that modify the schema class layout and introspection queries are not safe to downgrade from. So for such patches, we add a version suffix to the names of the core instdata entries that we would need to update, so that we don't clobber the old version. After a downgrade, we'll have more patches applied than we actually know exist in the running version, but since we compute the key based on the number of schema layout patches that we can *see*, we still compute the right key. """ level = get_patch_level(num_patches) if level == 0: return '' else: return f'_v{level}' """ The actual list of patches. The patches are (kind, script) pairs. 
The current kinds are: * sql - simply runs a SQL script * metaschema-sql - create a function from metaschema * edgeql - runs an edgeql DDL command * edgeql+schema - runs an edgeql DDL command and updates the std schemas * NOTE: objects and fields added to the reflschema must * have their patch_level set to the `get_patch_level` value * for this patch. * edgeql+user_ext| - updates extensions installed in user databases * - should be paired with an ext-pkg patch * ...+config - updates config views * ext-pkg - installs an extension package given a name * repair - fix up inconsistencies in *user* schemas * sql-introspection - refresh all sql introspection views * ...+testmode - only run the patch in testmode. Works with any patch kind. """ PATCHES: list[tuple[str, str]] = [ ] ================================================ FILE: edb/pgsql/patches_6x.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Patches copied over from 6.x. Deeply unfortunately, we need to be able to apply the user_ext patches... """ from __future__ import annotations """ The actual list of patches. The patches are (kind, script) pairs. 
The current kinds are: * sql - simply runs a SQL script * metaschema-sql - create a function from metaschema * edgeql - runs an edgeql DDL command * edgeql+schema - runs an edgeql DDL command and updates the std schemas * NOTE: objects and fields added to the reflschema must * have their patch_level set to the `get_patch_level` value * for this patch. * edgeql+user_ext| - updates extensions installed in user databases * - should be paired with an ext-pkg patch * ...+config - updates config views * ext-pkg - installs an extension package given a name * repair - fix up inconsistencies in *user* schemas * sql-introspection - refresh all sql introspection views * ...+testmode - only run the patch in testmode. Works with any patch kind. """ PATCHES: list[tuple[str, str]] = [ # 6.0b2 # One of the sql-introspection's adds a param with a default to # uuid_to_oid, so we need to drop the original to avoid ambiguity. ('sql', ''' drop function if exists edgedbsql_v6_2f20b3fed0.uuid_to_oid(uuid) cascade '''), ('sql-introspection', ''), ('metaschema-sql', 'SysConfigFullFunction'), # 6.0rc1 ('edgeql+schema+config+testmode', ''' CREATE SCALAR TYPE cfg::TestEnabledDisabledEnum EXTENDING enum; ALTER TYPE cfg::AbstractConfig { CREATE PROPERTY __check_function_bodies -> cfg::TestEnabledDisabledEnum { CREATE ANNOTATION cfg::internal := 'true'; CREATE ANNOTATION cfg::backend_setting := '"check_function_bodies"'; SET default := cfg::TestEnabledDisabledEnum.Enabled; }; }; '''), ('metaschema-sql', 'PostgresConfigValueToJsonFunction'), ('metaschema-sql', 'SysConfigFullFunction'), ('edgeql', ''' ALTER FUNCTION std::assert_single( input: SET OF anytype, NAMED ONLY message: OPTIONAL str = {}, ) { SET volatility := 'Immutable'; }; ALTER FUNCTION std::assert_exists( input: SET OF anytype, NAMED ONLY message: OPTIONAL str = {}, ) { SET volatility := 'Immutable'; }; ALTER FUNCTION std::assert_distinct( input: SET OF anytype, NAMED ONLY message: OPTIONAL str = {}, ) { SET volatility := 'Immutable'; 
}; '''), ('edgeql+schema+config', ''' CREATE SCALAR TYPE sys::TransactionAccessMode EXTENDING enum; CREATE SCALAR TYPE sys::TransactionDeferrability EXTENDING enum; ALTER TYPE cfg::AbstractConfig { CREATE REQUIRED PROPERTY default_transaction_isolation -> sys::TransactionIsolation { CREATE ANNOTATION cfg::affects_compilation := 'true'; CREATE ANNOTATION cfg::backend_setting := '"default_transaction_isolation"'; CREATE ANNOTATION std::description := 'Controls the default isolation level of each new transaction, \ including implicit transactions. Defaults to `Serializable`. \ Note that changing this to a lower isolation level implies \ that the transactions are also read-only by default regardless \ of the value of the `default_transaction_access_mode` setting.'; SET default := sys::TransactionIsolation.Serializable; }; CREATE REQUIRED PROPERTY default_transaction_access_mode -> sys::TransactionAccessMode { CREATE ANNOTATION cfg::affects_compilation := 'true'; CREATE ANNOTATION std::description := 'Controls the default read-only status of each new transaction, \ including implicit transactions. Defaults to `ReadWrite`. \ Note that if `default_transaction_isolation` is set to any value \ other than Serializable this parameter is implied to be \ `ReadOnly` regardless of the actual value.'; SET default := sys::TransactionAccessMode.ReadWrite; }; CREATE REQUIRED PROPERTY default_transaction_deferrable -> sys::TransactionDeferrability { CREATE ANNOTATION cfg::backend_setting := '"default_transaction_deferrable"'; CREATE ANNOTATION std::description := 'Controls the default deferrable status of each new transaction. \ It currently has no effect on read-write transactions or those \ operating at isolation levels lower than `Serializable`. 
\ The default is `NotDeferrable`.'; SET default := sys::TransactionDeferrability.NotDeferrable; }; }; '''), # 6.2 ('ext-pkg', 'ai'), ('edgeql+user_ext+config|ai', ''' alter type ext::ai::EmbeddingModel { drop annotation ext::ai::embedding_model_max_batch_tokens; create annotation ext::ai::embedding_model_max_batch_tokens := "8191"; } '''), # 6.3 ('repair', ''), # For #8466 # 6.5 ('sql-introspection', ''), # For #8511 ('edgeql+user_ext|ai', r''' update ext::ai::ChatPrompt filter .name = 'builtin::rag-default' set { messages += (insert ext::ai::ChatPromptMessage { participant_role := ext::ai::ChatParticipantRole.User, content := ( "Query: {query}\n\ Answer: " ), }) } '''), # For #8553 # 6.6 ('edgeql+schema', ''), # For #8554 ('ext-pkg', 'ai'), # For #8521, #8646 ('edgeql+user_ext+config|ai', ''' create function ext::ai::search( object: anyobject, query: str, ) -> optional tuple { create annotation std::description := ' Search an object using its ext::ai::index index. Gets an embedding for the query from the ai provider then returns objects that match the specified semantic query and the similarity score. '; set volatility := 'Stable'; # Needed to pick up the indexes when used in ORDER BY. 
set prefer_subquery_args := true; set server_param_conversions := '{"query": ["ai_text_embedding", "object"]}'; using sql expression; }; alter scalar type ext::ai::ProviderAPIStyle extending enum; create type ext::ai::OllamaProviderConfig extending ext::ai::ProviderConfig { alter property name { set protected := true; set default := 'builtin::ollama'; }; alter property display_name { set protected := true; set default := 'Ollama'; }; alter property api_url { set default := 'http://localhost:11434/api' }; alter property secret { set default := '' }; alter property api_style { set protected := true; set default := ext::ai::ProviderAPIStyle.Ollama; }; }; # Ollama embedding models create abstract type ext::ai::OllamaLlama_3_2_Model extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "llama3.2"; alter annotation ext::ai::model_provider := "builtin::ollama"; alter annotation ext::ai::text_gen_model_context_window := "131072"; }; create abstract type ext::ai::OllamaLlama_3_3_Model extending ext::ai::TextGenerationModel { alter annotation ext::ai::model_name := "llama3.3"; alter annotation ext::ai::model_provider := "builtin::ollama"; alter annotation ext::ai::text_gen_model_context_window := "131072"; }; create abstract type ext::ai::OllamaNomicEmbedTextModel extending ext::ai::EmbeddingModel { alter annotation ext::ai::model_name := "nomic-embed-text"; alter annotation ext::ai::model_provider := "builtin::ollama"; alter annotation ext::ai::embedding_model_max_input_tokens := "8192"; alter annotation ext::ai::embedding_model_max_batch_tokens := "8192"; alter annotation ext::ai::embedding_model_max_output_dimensions := "768"; }; create abstract type ext::ai::OllamaBgeM3Model extending ext::ai::EmbeddingModel { alter annotation ext::ai::model_name := "bge-m3"; alter annotation ext::ai::model_provider := "builtin::ollama"; alter annotation ext::ai::embedding_model_max_input_tokens := "8192"; alter annotation 
ext::ai::embedding_model_max_batch_tokens := "8192"; alter annotation ext::ai::embedding_model_max_output_dimensions := "1024"; }; '''), # 6.8 ('edgeql+user+remove_pointless_triggers', ''), ('edgeql', ''' CREATE FUNCTION std::to_bytes(j: std::json) -> std::bytes { CREATE ANNOTATION std::description := 'Convert a json value to a binary UTF-8 string.'; SET volatility := 'Immutable'; USING (to_bytes(to_str(j))); }; '''), ('metaschema-sql', 'ArrayIndexWithBoundsFunction'), ('metaschema-sql', 'ArraySliceFunction'), ('metaschema-sql', 'StringIndexWithBoundsFunction'), ('metaschema-sql', 'BytesIndexWithBoundsFunction'), ('metaschema-sql', 'StringSliceFunction'), ('metaschema-sql', 'BytesSliceFunction'), ('metaschema-sql', 'JSONIndexByTextFunction'), ('metaschema-sql', 'JSONIndexByIntFunction'), ('metaschema-sql', 'JSONSliceFunction'), ('edgeql', ''' CREATE MODULE std::lang; CREATE MODULE std::lang::go; CREATE ABSTRACT ANNOTATION std::lang::go::type; CREATE MODULE std::lang::js; CREATE ABSTRACT ANNOTATION std::lang::js::type; CREATE MODULE std::lang::py; CREATE ABSTRACT ANNOTATION std::lang::py::type; CREATE MODULE std::lang::rs; CREATE ABSTRACT ANNOTATION std::lang::rs::type; '''), # 6.9 ('edgeql', ''' CREATE FUNCTION std::__pg_generate_series( `start`: std::int64, stop: std::int64 ) -> SET OF std::int64 { SET volatility := 'Immutable'; USING SQL FUNCTION 'generate_series'; }; '''), # !!!!!! 7.x !!!!! 
('edgeql+user_ext+config|auth', ''' create type ext::auth::OneTimeCode extending ext::auth::Auditable { create required property code_hash: std::bytes { create constraint exclusive; create annotation std::description := "The securely hashed one-time code."; }; create required property expires_at: std::datetime { create annotation std::description := "The date and time when the code expires."; }; create index on (.expires_at); create required link factor: ext::auth::Factor { on target delete delete source; }; }; create scalar type ext::auth::AuthenticationAttemptType extending std::enum< SignIn, EmailVerification, PasswordReset, MagicLink, OneTimeCode >; create type ext::auth::AuthenticationAttempt extending ext::auth::Auditable { create required link factor: ext::auth::Factor { on target delete delete source; }; create required property attempt_type: ext::auth::AuthenticationAttemptType { create annotation std::description := "The type of authentication attempt being made."; }; create required property successful: std::bool { create annotation std::description := "Whether this authentication attempt was successful."; }; }; create scalar type ext::auth::VerificationMethod extending std::enum; alter type ext::auth::EmailPasswordProviderConfig { create required property verification_method: ext::auth::VerificationMethod { set default := ext::auth::VerificationMethod.Link; }; }; alter type ext::auth::WebAuthnProviderConfig { create required property verification_method: ext::auth::VerificationMethod { set default := ext::auth::VerificationMethod.Link; }; }; alter type ext::auth::MagicLinkProviderConfig { create required property verification_method: ext::auth::VerificationMethod { set default := ext::auth::VerificationMethod.Link; }; }; alter scalar type ext::auth::WebhookEvent extending std::enum< IdentityCreated, IdentityAuthenticated, EmailFactorCreated, EmailVerified, EmailVerificationRequested, PasswordResetRequested, MagicLinkRequested, OneTimeCodeRequested, 
OneTimeCodeVerified, >; '''), ] ================================================ FILE: edb/pgsql/resolver/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Optional import copy import dataclasses from edb.common import debug from edb.pgsql import ast as pgast from edb.pgsql import codegen as pgcodegen from edb.schema import schema as s_schema from edb.server.compiler import dbstate, enums from . import dispatch from . import context from . import expr # NOQA from . import relation # NOQA from . import command # NOQA Options = context.Options @dataclasses.dataclass(kw_only=True, eq=False, frozen=True, repr=False) class ResolvedSQL: # AST representing the query that can be sent to PostgreSQL ast: pgast.Base # Optionally, AST representing the query returning data in EdgeQL # format (i.e. single-column output). edgeql_output_format_ast: Optional[pgast.Base] # Special behavior for "tag" of "CommandComplete" message of this query. 
command_complete_tag: Optional[dbstate.CommandCompleteTag] # query parameters params: list[dbstate.SQLParam] capabilities: enums.Capability = enums.Capability.NONE def resolve( query: pgast.Query | pgast.CopyStmt, schema: s_schema.Schema, options: context.Options, ) -> ResolvedSQL: if debug.flags.sql_input: debug.header('SQL Input') debug_sql_text = pgcodegen.generate_source( query, reordered=True, pretty=True ) debug.dump_code(debug_sql_text, lexer='sql') ctx = context.ResolverContextLevel( None, context.ContextSwitchMode.EMPTY, schema=schema, options=options ) _ = context.ResolverContext(initial=ctx) command.init_external_params(query, ctx) top_level_ctes = command.compile_dml(query, ctx=ctx) resolved: pgast.Base if isinstance(query, pgast.Query): resolved, resolved_table = dispatch.resolve_relation(query, ctx=ctx) elif isinstance(query, pgast.CopyStmt): resolved = dispatch.resolve(query, ctx=ctx) resolved_table = None else: raise AssertionError() if limit := ctx.options.implicit_limit: resolved = apply_implicit_limit(resolved, limit, resolved_table, ctx) command.fini_external_params(ctx) if top_level_ctes: assert isinstance(resolved, pgast.Query) if not resolved.ctes: resolved.ctes = [] resolved.ctes.extend(top_level_ctes) # when the top-level query is DML statement, clients will expect a tag in # the CommandComplete message that describes the number of modified rows. # Since our resolved SQL does not have a top-level DML stmt, we need to # override that tag. 
command_complete_tag: Optional[dbstate.CommandCompleteTag] = None if isinstance(query, pgast.DMLQuery): prefix: str if isinstance(query, pgast.InsertStmt): prefix = 'INSERT 0 ' elif isinstance(query, pgast.DeleteStmt): prefix = 'DELETE ' elif isinstance(query, pgast.UpdateStmt): prefix = 'UPDATE ' if query.returning_list: # resolved SQL will return a result, we count those rows command_complete_tag = dbstate.TagCountMessages(prefix=prefix) else: # resolved SQL will contain an injected COUNT clause # we instruct io process to unpack that command_complete_tag = dbstate.TagUnpackRow(prefix=prefix) if debug.flags.sql_output: debug.header('SQL Output') debug_sql_text = pgcodegen.generate_source( resolved, pretty=True, reordered=True ) debug.dump_code(debug_sql_text, lexer='sql') if options.include_edgeql_io_format_alternative: edgeql_output_format_ast = copy.copy(resolved) if e := as_plain_select(edgeql_output_format_ast, resolved_table, ctx): # Turn the query into one that returns a ROW. # # We need to do this by injecting a new query and putting # the old one in its FROM clause, since things like # DISTINCT/ORDER BY care about what exact columns are in # the target list. 
columns = [] for i, target in enumerate(e.target_list): if not target.name: e.target_list[i] = target = target.replace(name=f'__i~{i}') assert target.name columns.append(pgast.ColumnRef(name=(target.name,))) edgeql_output_format_ast = pgast.SelectStmt( target_list=[ pgast.ResTarget( val=expr.construct_row_expr(columns, ctx=ctx) ) ], from_clause=[pgast.RangeSubselect( subquery=e, alias=pgast.Alias(aliasname='r'), )], ctes=e.ctes, ) e.ctes = [] else: edgeql_output_format_ast = None return ResolvedSQL( ast=resolved, edgeql_output_format_ast=edgeql_output_format_ast, command_complete_tag=command_complete_tag, params=ctx.query_params, capabilities=ctx.env.capabilities, ) def as_plain_select( query: pgast.Base, table: Optional[context.Table], ctx: context.ResolverContextLevel, ) -> Optional[pgast.SelectStmt]: if not isinstance(query, pgast.Query): return None assert table if ( isinstance(query, pgast.SelectStmt) and not query.op and not query.values ): return query table.alias = "t" return pgast.SelectStmt( from_clause=[ pgast.RangeSubselect( subquery=query, alias=pgast.Alias(aliasname="t"), ) ], target_list=[ pgast.ResTarget( name=f'column{index + 1}', val=expr.resolve_column_kind(table, c.kind, ctx=ctx), ) for index, c in enumerate(table.columns) ], ) def apply_implicit_limit( expr: pgast.Base, limit: int, table: Optional[context.Table], ctx: context.ResolverContextLevel, ) -> pgast.Base: e = as_plain_select(expr, table, ctx) if not e: return expr if e.limit_count is None: e.limit_count = pgast.NumericConstant(val=str(limit)) return e ================================================ FILE: edb/pgsql/resolver/command.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2023-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """SQL resolver that compiles public SQL to internal SQL which is executable in our internal Postgres instance.""" from typing import Optional, Iterable, Mapping import dataclasses import functools import uuid from edb.server.pgcon import errors as pgerror from edb.server.compiler import dbstate from edb.common import ast from edb.common.typeutils import not_none from edb import errors from edb.pgsql import ast as pgast from edb.pgsql.parser import parser as pg_parser from edb.pgsql import compiler as pgcompiler from edb.pgsql import types as pgtypes from edb.pgsql.compiler import enums as pgce from edb.edgeql import ast as qlast from edb.edgeql import qltypes from edb.edgeql import compiler as qlcompiler from edb.ir import ast as irast from edb.ir import typeutils as irtypeutils from edb.schema import objtypes as s_objtypes from edb.schema import pointers as s_pointers from edb.schema import links as s_links from edb.schema import properties as s_properties from edb.schema import name as sn from edb.schema import types as s_types from edb.schema import utils as s_utils from edb.server.compiler import enums from . import dispatch from . import context from . import expr as pg_res_expr from . 
import relation as pg_res_rel Context = context.ResolverContextLevel @dispatch._resolve.register def resolve_CopyStmt(stmt: pgast.CopyStmt, *, ctx: Context) -> pgast.CopyStmt: query: Optional[pgast.Query] if stmt.query: query = dispatch.resolve(stmt.query, ctx=ctx) elif stmt.relation: relation, table = dispatch.resolve_relation(stmt.relation, ctx=ctx) table.reference_as = ctx.alias_generator.get('rel') selected_columns = _pull_columns_from_table( table, ((c, stmt.span) for c in stmt.colnames) if stmt.colnames else None, ) # The relation being copied is potentially a view and views cannot be # copied if referenced by name, so we just always wrap it into a SELECT. # This is probably a view based on edgedb schema, so wrap it into # a select query. query = pgast.SelectStmt( from_clause=[ pgast.RelRangeVar( alias=pgast.Alias(aliasname=table.reference_as), relation=relation, ) ], target_list=[ pgast.ResTarget( val=pg_res_expr.resolve_column_kind(table, c.kind, ctx=ctx) ) for c in selected_columns ], ) else: raise AssertionError('CopyStmt must either have relation or query set') # WHERE where = dispatch.resolve_opt(stmt.where_clause, ctx=ctx) # COPY will always be top-level, so we must extract CTEs if not query.ctes: query.ctes = list() query.ctes.extend(ctx.ctes_buffer) ctx.ctes_buffer.clear() _validate_no_params(ctx) return pgast.CopyStmt( relation=None, colnames=None, query=query, is_from=stmt.is_from, is_program=stmt.is_program, filename=stmt.filename, # TODO: forbid some options? 
options=stmt.options, where_clause=where, ) def _validate_no_params(ctx: Context): if len(ctx.query_params) == 0: return has_globals = any( isinstance(p, dbstate.SQLParamGlobal) and not p.is_permission for p in ctx.query_params ) has_permissions = any( isinstance(p, dbstate.SQLParamGlobal) and p.is_permission for p in ctx.query_params ) hint: str | None if has_globals or has_permissions: if has_globals and has_permissions: offender = "globals or permissions" elif has_globals: offender = "globals" elif has_permissions: offender = "permissions" offender += " in computed properties and access policies" hint = "To disable policies, set apply_access_policies_pg := false" else: offender = "query parameters" hint = None raise errors.QueryError( f"COPY cannot use {offender}", pgext_code=pgerror.ERROR_FEATURE_NOT_SUPPORTED, hint=hint, ) def _pull_columns_from_table( table: context.Table, col_names: Optional[Iterable[tuple[str, pgast.Span | None]]], ) -> list[context.Column]: if not col_names: return [c for c in table.columns if not c.hidden] col_map: dict[str, context.Column] = { col.name: col for col in table.columns } res = [] for name, span in col_names: col = col_map.get(name, None) if not col: raise errors.QueryError( f'column {name} does not exist', span=span, ) res.append(col) return res def compile_dml( stmt: pgast.Base, *, ctx: Context ) -> list[pgast.CommonTableExpr]: # extract all dml stmts dml_stmts_sql = _collect_dml_stmts(stmt) if len(dml_stmts_sql) == 0: return [] # un-compile each SQL dml stmt into EdgeQL stmts = [_uncompile_dml_stmt(s, ctx=ctx) for s in dml_stmts_sql] # merge EdgeQL stmts & compile to SQL ctx.compiled_dml, ctes = _compile_uncompiled_dml(stmts, ctx=ctx) return ctes def _collect_dml_stmts(stmt: pgast.Base) -> list[pgast.DMLQuery]: if not isinstance(stmt, pgast.Query): return [] # DML can only be in the top-level statement or its CTEs. 
# If it is in any of the nested CTEs, throw errors later on res: list[pgast.DMLQuery] = [] if stmt.ctes: for cte in stmt.ctes: if isinstance(cte.query, pgast.DMLQuery): res.append(cte.query) if isinstance(stmt, pgast.DMLQuery): res.append(stmt) return res ExternalRel = tuple[pgast.BaseRelation, tuple[pgce.PathAspect, ...]] @dataclasses.dataclass(kw_only=True, eq=False, repr=False) class UncompiledDML: # the input DML node input: pgast.Query # schema object associated with the table that is the subject of the DML subject: s_objtypes.ObjectType | s_links.Link | s_properties.Property # EdgeQL equivalent to the input node ql_stmt: qlast.Expr # additional params needed during compilation of the edgeql node ql_returning_shape: list[qlast.ShapeElement] ql_singletons: set[irast.PathId] ql_anchors: Mapping[str, irast.PathId] external_rels: Mapping[irast.PathId, ExternalRel] # list of column names of the subject type, along with pointer name # these columns will be available within RETURNING clause subject_columns: list[tuple[str, str]] stype_refs: dict[uuid.UUID, list[qlast.Set]] # data needed for stitching the compiled ast into the resolver output early_result: context.CompiledDML @functools.singledispatch def _uncompile_dml_stmt(stmt: pgast.DMLQuery, *, ctx: Context): """ Takes an SQL DML query and produces an equivalent EdgeQL query plus a bunch of metadata needed to extract associated CTEs from result of the EdgeQL compiler. In this context: - subject is the object type/pointer being updated, - source is the source of the subject (when subject is a pointer), - value is the relation that provides new value to be inserted/updated, - ptr-s are (usually) pointers on the subject. """ raise dispatch._raise_unsupported(stmt) def _uncompile_dml_subject( rvar: pgast.RelRangeVar, *, ctx: Context ) -> tuple[ context.Table, s_objtypes.ObjectType | s_links.Link | s_properties.Property ]: """ Determines the subject object of a DML operation. 
This can either be an ObjectType or a Pointer for link tables. """ assert isinstance(rvar.relation, pgast.Relation) _sub_rel, sub_table = pg_res_rel.resolve_relation( rvar.relation, include_inherited=False, ctx=ctx ) if not sub_table.schema_id: # this happens doing DML on an introspection table or (somehow) a CTE raise errors.QueryError( msg=f'cannot write into table "{sub_table.name}"', pgext_code=pgerror.ERROR_UNDEFINED_TABLE, ) sub = ctx.schema.get_by_id(sub_table.schema_id) assert isinstance( sub, (s_objtypes.ObjectType, s_links.Link, s_properties.Property) ) return sub_table, sub def _uncompile_subject_columns( sub: s_objtypes.ObjectType | s_links.Link | s_properties.Property, sub_table: context.Table, res: UncompiledDML, *, ctx: Context, ): ''' Instruct UncompiledDML to wrap the EdgeQL DML into a select shape that selects all pointers. This is applied when a RETURNING clause is present and these columns might be used in the clause. ''' for column in sub_table.columns: if column.hidden: continue _, ptr_name, _ = _get_pointer_for_column(column, sub, ctx) res.subject_columns.append((column.name, ptr_name)) @_uncompile_dml_stmt.register def _uncompile_insert_stmt( stmt: pgast.InsertStmt, *, ctx: Context ) -> UncompiledDML: # determine the subject object sub_table, sub = _uncompile_dml_subject(stmt.relation, ctx=ctx) expected_columns = _pull_columns_from_table( sub_table, ((c.name, c.span) for c in stmt.cols) if stmt.cols else None, ) res: UncompiledDML if isinstance(sub, s_objtypes.ObjectType): res = _uncompile_insert_object_stmt( stmt, sub, sub_table, expected_columns, ctx=ctx ) elif isinstance(sub, (s_links.Link, s_properties.Property)): res = _uncompile_insert_pointer_stmt( stmt, sub, sub_table, expected_columns, ctx=ctx ) else: raise NotImplementedError() if stmt.returning_list: _uncompile_subject_columns(sub, sub_table, res, ctx=ctx) return res def _uncompile_insert_object_stmt( stmt: pgast.InsertStmt, sub: s_objtypes.ObjectType, sub_table: context.Table, 
expected_columns: list[context.Column], *, ctx: Context, ) -> UncompiledDML: """ Translates a 'SQL INSERT into an object type table' to an EdgeQL insert. """ subject_id = irast.PathId.from_type( ctx.schema, sub, env=None, ) # handle DEFAULT and prepare the value relation value_relation, expected_columns = _uncompile_default_value( stmt.select_stmt, stmt.ctes, expected_columns, sub, ctx=ctx ) # if we are sure that we are inserting a single row # we can skip for-loops and iterators, which produces better SQL is_value_single = _has_at_most_one_row(stmt.select_stmt) # prepare anchors for inserted value columns value_name = ctx.alias_generator.get('ins_val') iterator_name = ctx.alias_generator.get('ins_iter') value_id = irast.PathId.from_type( ctx.schema, sub, typename=sn.QualName('__derived__', value_name), env=None, ) value_ql: qlast.PathElement = ( qlast.IRAnchor(name=value_name) if is_value_single else qlast.ObjectRef(name=iterator_name) ) # a phantom relation that is supposed to hold the inserted value # (in the resolver, this will be replaced by the real value relation) value_cte_name = ctx.alias_generator.get('ins_value') value_rel = pgast.Relation( name=value_cte_name, strip_output_namespaces=True, ) value_columns = [] insert_shape = [] stype_refs: dict[uuid.UUID, list[qlast.Set]] = {} for index, expected_col in enumerate(expected_columns): ptr, ptr_name, is_link = _get_pointer_for_column(expected_col, sub, ctx) value_columns.append((expected_col.name, ptr_name, is_link)) # inject type annotation into value relation _try_inject_ptr_type_cast(value_relation, index, ptr, ctx) # prepare the outputs of the source CTE ptr_id = _get_ptr_id(value_id, ptr, ctx) output = pgast.ColumnRef(name=(ptr_name,), nullable=True) if is_link: value_rel.path_outputs[(ptr_id, pgce.PathAspect.IDENTITY)] = output value_rel.path_outputs[(ptr_id, pgce.PathAspect.VALUE)] = output else: value_rel.path_outputs[(ptr_id, pgce.PathAspect.VALUE)] = output if ptr_name == 'id': 
value_rel.path_outputs[(value_id, pgce.PathAspect.VALUE)] = output # prepare insert shape that will use the paths from source_outputs insert_shape.append( _construct_assign_element_for_ptr( value_ql, ptr_name, ptr, is_link, ctx, stype_refs, ) ) # source needs an iterator column, so we need to invent one # Here we only decide on the name of that iterator column, the actual column # is generated later, when resolving the DML stmt. value_iterator = ctx.alias_generator.get('iter') output = pgast.ColumnRef(name=(value_iterator,)) value_rel.path_outputs[(value_id, pgce.PathAspect.ITERATOR)] = output if not any(c.name == 'id' for c in expected_columns): value_rel.path_outputs[(value_id, pgce.PathAspect.VALUE)] = output # construct the EdgeQL DML AST sub_name = sub.get_name(ctx.schema) ql_stmt_insert = qlast.InsertQuery( subject=s_utils.name_to_ast_ref(sub_name), shape=insert_shape, ) ql_stmt: qlast.Expr = ql_stmt_insert if not is_value_single: # value relation might contain multiple rows # to express this in EdgeQL, we must wrap `insert` into a `for` query ql_stmt = qlast.ForQuery( iterator=qlast.Path(steps=[qlast.IRAnchor(name=value_name)]), iterator_alias=iterator_name, result=ql_stmt, ) # on conflict conflict = _uncompile_on_conflict( stmt, sub, sub_table, value_id, ctx, stype_refs ) if conflict: ql_stmt_insert.unless_conflict = conflict.ql_unless_conflict ql_returning_shape: list[qlast.ShapeElement] = [] if stmt.returning_list: # construct the shape that will extract all needed columns of the subject # table (because they might be used by the RETURNING clause) for column in sub_table.columns: if column.hidden: continue _, ptr_name, _ = _get_pointer_for_column(column, sub, ctx) ql_returning_shape.append( qlast.ShapeElement( expr=qlast.Path(steps=[qlast.Ptr(name=ptr_name)]), ) ) ql_singletons = {value_id} ql_anchors = {value_name: value_id} external_rels: dict[irast.PathId, ExternalRel] = { value_id: ( value_rel, (pgce.PathAspect.SOURCE,), ) } subject_columns = None if 
conflict and conflict.update_name is not None: assert conflict.update_id assert conflict.update_input_placeholder # inject path_output for identity aspect # (that will be injected after resolving) conflict.update_input_placeholder.path_outputs[ (conflict.update_id, pgce.PathAspect.VALUE) ] = pgast.ColumnRef(name=('id',)) conflict.update_input_placeholder.path_outputs[ (conflict.update_id, pgce.PathAspect.IDENTITY) ] = pgast.ColumnRef(name=('id',)) # register __cu__ as singleton and provided by an external rel ql_singletons.add(conflict.update_id) ql_anchors[conflict.update_name] = conflict.update_id external_rels[conflict.update_id] = ( conflict.update_input_placeholder, ( pgce.PathAspect.SOURCE, pgce.PathAspect.VALUE, pgce.PathAspect.IDENTITY, ), ) # subject columns are needed to pull them into the "conflict update" rel subject_columns = [] for column in sub_table.columns: if column.hidden: continue ptr, _, _ = _get_pointer_for_column(column, sub, ctx) path_id = _get_ptr_id(subject_id, ptr, ctx) subject_columns.append((column.name, path_id)) return UncompiledDML( input=stmt, subject=sub, ql_stmt=ql_stmt, ql_returning_shape=ql_returning_shape, ql_singletons=ql_singletons, ql_anchors=ql_anchors, external_rels=external_rels, stype_refs=stype_refs, early_result=context.CompiledDML( value_cte_name=value_cte_name, value_relation_input=value_relation, value_columns=value_columns, value_iterator_name=value_iterator, conflict_update_input=conflict.update_input if conflict else None, conflict_update_name=conflict.update_name if conflict else None, conflict_update_iterator=( conflict.update_iterator if conflict else None ), subject_id=subject_id, subject_columns=subject_columns, value_id=value_id, # these will be populated after compilation output_ctes=[], output_relation_name='', output_namespace={}, ), # these will be populated by _uncompile_dml_stmt subject_columns=[], ) @dataclasses.dataclass(kw_only=True, frozen=True, repr=False, eq=False) class UncompileOnConflict: 
ql_unless_conflict: tuple[qlast.Expr | None, qlast.Expr | None] # relation (that still has to be resolved) which will provide the values of # columns that must be updated by the ON CONFLICT clause update_input: Optional[pgast.Query] = None # name of the CTE that will contain resolved update_input update_name: Optional[str] = None # name of the IR anchor which provides the iterator to the update stmt update_iterator: Optional[str] = None # IR id which can be used as an anchor that will refer to the update_input update_id: Optional[irast.PathId] = None # a dummy relation that has all necessary path_outputs set, so it can be # passed into external_rels update_input_placeholder: Optional[pgast.Relation] = None # Uncompiles pg INSERT ON CONFLICT into EdgeQL UNLESS CONFLICT. # Will produce: # - qlast unless conflict that contains an empty node or an update stmt, # - update_input relation (that provides values to the UPDATE stmt) and a bunch # of related variables needed for compiling it. def _uncompile_on_conflict( stmt: pgast.InsertStmt, sub: s_objtypes.ObjectType, sub_table: context.Table, value_id: irast.PathId, ctx: Context, stype_refs: dict[uuid.UUID, list[qlast.Set]] ) -> Optional[UncompileOnConflict]: if not stmt.on_conflict: return None # determine the target constraint on_clause: Optional[qlast.Expr] = None if stmt.on_conflict.target: tgt = stmt.on_conflict.target if tgt.constraint_name: raise errors.UnsupportedFeatureError( 'ON CONFLICT ON CONSTRAINT', span=tgt.span, ) if tgt.index_where: raise errors.UnsupportedFeatureError( 'ON CONFLICT WHERE', span=tgt.span, ) index_col_names: list[tuple[str, qlast.Span | None]] = [] for e in tgt.index_elems or []: if e.nulls_ordering or e.ordering: raise errors.UnsupportedFeatureError( 'ON CONFLICT index ordering', span=tgt.span, ) if not ( isinstance(e.expr, pgast.ColumnRef) and len(e.expr.name) == 1 and isinstance(e.expr.name[0], str) ): raise errors.UnsupportedFeatureError( 'ON CONFLICT supports only plain column names', 
) index_col_names.append((e.expr.name[0], e.expr.span)) index_cols = _pull_columns_from_table(sub_table, iter(index_col_names)) index_paths: list[qlast.Expr] = [] for index_col in index_cols: _ptr, ptr_name, _is_link = _get_pointer_for_column( index_col, sub, ctx ) index_paths.append(qlast.Path( partial=True, steps=[qlast.Ptr(name=ptr_name)] )) on_clause = qlast.Tuple(elements=index_paths) if stmt.on_conflict.action == pgast.OnConflictAction.DO_NOTHING: return UncompileOnConflict( ql_unless_conflict=(None, None) # constraint columns, update ) if not on_clause: raise errors.QueryError( 'ON CONFLICT DO UPDATE requires index specification by column ' 'names', pgext_code=pgerror.ERROR_SYNTAX_ERROR, span=stmt.on_conflict.span, ) # determine names of updated columns update_col_names = [] assert stmt.on_conflict.update_list is not None for col in stmt.on_conflict.update_list: if isinstance(col, pgast.MultiAssignRef): raise errors.UnsupportedFeatureError( 'ON CONFLICT UPDATE of multiple columns at once', span=col.span, ) assert isinstance(col, pgast.UpdateTarget) if col.indirection: raise errors.UnsupportedFeatureError( 'ON CONFLICT UPDATE with index indirection', span=col.span, ) update_col_names.append((col.name, col.span)) update_columns = _pull_columns_from_table( sub_table, iter(update_col_names), ) # construct update shape update_name = ctx.alias_generator.get('cu') iterator_name = ctx.alias_generator.get('cu_iter') update_id = irast.PathId.from_type( ctx.schema, sub, typename=sn.QualName('__derived__', update_name), env=None, ) # the shape of the EdgeQL update we will generate conflict_update_shape = [] # IR anchor and placeholder relation for the relation that will provide # values for each of the updated columns. This relation will be replaced # by the resolved SQL relation. 
conflict_source_ql = qlast.Path(steps=[qlast.ObjectRef(name=iterator_name)]) update_input_placeholder = pgast.Relation( name=update_name, strip_output_namespaces=True, ) for column in update_columns: ptr, ptr_name, is_link = _get_pointer_for_column(column, sub, ctx) # prepare the outputs of the source CTE ptr_id = _get_ptr_id(update_id, ptr, ctx) output = pgast.ColumnRef(name=(ptr_name,), nullable=True) if is_link: update_input_placeholder.path_outputs[ (ptr_id, pgce.PathAspect.IDENTITY)] = output update_input_placeholder.path_outputs[ (ptr_id, pgce.PathAspect.VALUE)] = output else: update_input_placeholder.path_outputs[ (ptr_id, pgce.PathAspect.VALUE)] = output conflict_update_shape.append( _construct_assign_element_for_ptr( conflict_source_ql, ptr_name, ptr, is_link, ctx, stype_refs, ) ) sub_name = sub.get_name(ctx.schema) ql_update = qlast.UpdateQuery( subject=qlast.Path( steps=[s_utils.name_to_ast_ref(sub_name)] ), shape=conflict_update_shape, where=qlast.BinOp( left=qlast.Path(steps=[ qlast.ObjectRef(name=iterator_name), qlast.Ptr(name='id') ]), op='=', right=qlast.Path(steps=[ s_utils.name_to_ast_ref(sub_name), qlast.Ptr(name='id') ]), ) ) # update_value relation has to be evaluated *for each conflicting row* # to express this in EdgeQL, we must wrap `update` into a `for` query ql_stmt = qlast.ForQuery( iterator=qlast.Path(steps=[qlast.IRAnchor(name=update_name)]), iterator_alias=iterator_name, result=ql_update, ) # the relation that we will later resolve and which contains all columns # that the user specified to need to be updated. When this is resolved, # it will replace conflict_source_rel CTE in the final compiled output. 
update_input = pgast.SelectStmt( target_list=[ pgast.ResTarget( name=ut.name, val=ut.val, # TODO: UpdateTarget indirections ) for ut in stmt.on_conflict.update_list if isinstance(ut, pgast.UpdateTarget) ], where_clause=stmt.on_conflict.update_where, ) return UncompileOnConflict( ql_unless_conflict=(on_clause, ql_stmt), update_name=update_name, update_id=update_id, update_input=update_input, update_input_placeholder=update_input_placeholder, ) def _construct_assign_element_for_ptr( source_ql: qlast.PathElement, ptr_name: str, ptr: s_pointers.Pointer, is_link: bool, ctx: context.ResolverContextLevel, stype_refs: dict[uuid.UUID, list[qlast.Set]], ): ptr_ql: qlast.Expr = qlast.Path( steps=[ source_ql, qlast.Ptr(name=ptr_name), ] ) if is_link: # Convert UUIDs into objects. assert isinstance(ptr_ql, qlast.Path) target = ptr.get_target(ctx.schema) assert isinstance(target, s_objtypes.ObjectType) ptr_ql = _construct_cast_from_uuid_to_obj_type( ptr_ql, target, stype_refs, optional=True, ctx=ctx ) return qlast.ShapeElement( expr=qlast.Path(steps=[qlast.Ptr(name=ptr_name)]), operation=qlast.ShapeOperation(op=qlast.ShapeOp.ASSIGN), compexpr=ptr_ql, ) def _construct_cast_from_uuid_to_obj_type( ptr_ql: qlast.Path, object: s_objtypes.ObjectType, stype_refs: dict[uuid.UUID, list[qlast.Set]], *, optional: bool, ctx: Context, ) -> qlast.Expr: # Constructs AST that converts a UUID provided by ptr_ql to an object type. # This mechanism is similar to overlays in the IR->SQL compiler. # Makes sure that when an object is inserted, later casts from UUID do find # this object. This is needed because this cast is part of the # "under-the-hood" mechanism and is not visible to the user. They perceive # plain UUID insertion and they expect FOREIGN KEY constraints to reject # invalid UUIDs. 
# Constructs qlast equivalent to: # for i in ptr_ql union # assert_exists(( # select {type_name, #all preceding DML clauses#} # filter .id = i.id # limit 1 # )) # else {} object_name: sn.Name = object.get_name(ctx.schema) ptr_id_ql = qlast.Path(steps=ptr_ql.steps + [qlast.Ptr(name='id')]) if optional: ptr_iter = ctx.alias_generator.get('i') id_ql = qlast.Path(steps=[qlast.ObjectRef(name=ptr_iter)]) else: id_ql = ptr_id_ql stype_ref = qlast.Set( elements=[ qlast.Path(steps=[s_utils.name_to_ast_ref(object_name)]), # here we later inject references to preceding inserts of this type ] ) if object.id not in stype_refs: stype_refs[object.id] = [] stype_refs[object.id].append(stype_ref) res: qlast.Expr = qlast.FunctionCall( func=('std', 'assert_exists'), args=[ qlast.SelectQuery( result=stype_ref, where=qlast.BinOp( left=qlast.Path(partial=True, steps=[qlast.Ptr(name='id')]), op='=', right=id_ql, ), # this is needed for cardinality check only: there will # always be at most one object with matching id. It will be # either an existing object or a newly inserted one. 
limit=qlast.Constant.integer(1), ) ], kwargs={ 'message': qlast.BinOp( left=qlast.Constant.string( f'object type {object_name} with id \'' ), op='++', right=qlast.BinOp( left=qlast.TypeCast( expr=id_ql, type=qlast.TypeName( maintype=qlast.ObjectRef(module='std', name='str'), ), ), op='++', right=qlast.Constant.string(f'\' does not exist'), ), ) }, ) if optional: res = qlast.ForQuery( iterator=ptr_id_ql, iterator_alias=ptr_iter, result=res, ) return res def _add_pointer( source: s_objtypes.ObjectType, name: str, target_scls: s_types.Type, *, ctx: Context, ) -> s_pointers.Pointer: base_name = 'link' if target_scls.is_object_type() else 'property' base = ctx.schema.get( sn.QualName('std', base_name), type=s_pointers.Pointer, ) ctx.schema, ptr = base.derive_ref( ctx.schema, source, name=base.get_derived_name( ctx.schema, source, derived_name_base=sn.QualName('__', name) ), target=target_scls, inheritance_refdicts={'pointers'}, mark_derived=True, transient=True, ) return ptr def _uncompile_insert_pointer_stmt( stmt: pgast.InsertStmt, sub: s_links.Link | s_properties.Property, sub_table: context.Table, expected_columns: list[context.Column], *, ctx: Context, ) -> UncompiledDML: """ Translates a SQL 'INSERT INTO a link / multi-property table' into an `EdgeQL update SourceObject { subject: ... }`. 
""" if stmt.on_conflict: raise errors.UnsupportedFeatureError( 'ON CONFLICT is not yet supported for link tables', span=stmt.on_conflict.span, ) if not any(c.name == 'source' for c in expected_columns): raise errors.QueryError( 'column source is required when inserting into link tables', span=stmt.span, ) if not any(c.name == 'target' for c in expected_columns): raise errors.QueryError( 'column target is required when inserting into link tables', span=stmt.span, ) sub_source = sub.get_source(ctx.schema) assert isinstance(sub_source, s_objtypes.ObjectType) sub_target = sub.get_target(ctx.schema) assert sub_target # handle DEFAULT and prepare the value relation value_relation, expected_columns = _uncompile_default_value( stmt.select_stmt, stmt.ctes, expected_columns, sub, ctx=ctx ) # if we are sure that we are inserting a single row # we can skip for-loops and iterators, which produces better SQL # is_value_single = _has_at_most_one_row(stmt.select_stmt) is_value_single = False free_obj_ty = ctx.schema.get('std::FreeObject', type=s_objtypes.ObjectType) ctx.schema, dummy_ty = free_obj_ty.derive_subtype( ctx.schema, name=sn.QualName('__derived__', ctx.alias_generator.get('ins_ty')), mark_derived=True, transient=True, ) src_ptr = _add_pointer(dummy_ty, '__source__', sub_source, ctx=ctx) tgt_ptr = _add_pointer(dummy_ty, '__target__', sub_target, ctx=ctx) # prepare anchors for inserted value columns value_name = ctx.alias_generator.get('ins_val') iterator_name = ctx.alias_generator.get('ins_iter') base_id = irast.PathId.from_type( ctx.schema, dummy_ty, typename=sn.QualName('__derived__', value_name), env=None, ) source_id = _get_ptr_id(base_id, src_ptr, ctx=ctx) target_id = _get_ptr_id(base_id, tgt_ptr, ctx=ctx) value_ql: qlast.PathElement = ( qlast.IRAnchor(name=value_name) if is_value_single else qlast.ObjectRef(name=iterator_name) ) # a phantom relation that is supposed to hold the inserted value # (in the resolver, this will be replaced by the real value relation) 
value_cte_name = ctx.alias_generator.get('ins_value') value_rel = pgast.Relation( name=value_cte_name, strip_output_namespaces=True, ) value_columns: list[tuple[str, str, bool]] = [] for index, expected_col in enumerate(expected_columns): ptr: Optional[s_pointers.Pointer] = None if expected_col.name == 'source': ptr_name = 'source' is_link = True ptr_id = source_id elif expected_col.name == 'target': ptr_name = 'target' is_link = isinstance(sub, s_links.Link) ptr = sub ptr_id = target_id else: # link pointer assert isinstance(sub, s_links.Link) ptr_name = expected_col.name ptr = sub.maybe_get_ptr(ctx.schema, sn.UnqualName(ptr_name)) assert ptr lprop_tgt = not_none(ptr.get_target(ctx.schema)) lprop_ptr = _add_pointer(dummy_ty, ptr_name, lprop_tgt, ctx=ctx) ptr_id = _get_ptr_id(base_id, lprop_ptr, ctx=ctx) is_link = False var = pgast.ColumnRef(name=(ptr_name,), nullable=True) value_rel.path_outputs[(ptr_id, pgce.PathAspect.VALUE)] = var # inject type annotation into value relation if is_link: _try_inject_type_cast( value_relation, index, pgast.TypeName(name=('uuid',)) ) else: assert ptr _try_inject_ptr_type_cast(value_relation, index, ptr, ctx) value_columns.append((expected_col.name, ptr_name, is_link)) # source needs an iterator column, so we need to invent one # Here we only decide on the name of that iterator column, the actual column # is generated later, when resolving the DML stmt. 
value_iterator = ctx.alias_generator.get('iter') var = pgast.ColumnRef(name=(value_iterator,)) value_rel.path_outputs[(base_id, pgce.PathAspect.ITERATOR)] = var value_rel.path_outputs[(base_id, pgce.PathAspect.VALUE)] = var # construct the EdgeQL DML AST stype_refs: dict[uuid.UUID, list[qlast.Set]] = {} sub_name = sub.get_shortname(ctx.schema) target_ql: qlast.Expr = qlast.Path( steps=[value_ql, qlast.Ptr(name='__target__')] ) if isinstance(sub_target, s_objtypes.ObjectType): assert isinstance(target_ql, qlast.Path) target_ql = _construct_cast_from_uuid_to_obj_type( target_ql, sub_target, stype_refs, optional=True, ctx=ctx ) ql_ptr_val: qlast.Expr if isinstance(sub, s_links.Link): ql_ptr_val = qlast.Shape( expr=target_ql, elements=[ qlast.ShapeElement( expr=qlast.Path( steps=[qlast.Ptr(name=ptr_name, type='property')], ), compexpr=qlast.Path( steps=[ value_ql, # qlast.Ptr(name=sub_name.name), qlast.Ptr(name=ptr_name), ], ), ) for ptr_name, _, _ in value_columns if ptr_name not in ('source', 'target') ], ) else: # multi pointer ql_ptr_val = target_ql source_ql_p = qlast.Path(steps=[value_ql, qlast.Ptr(name='__source__')]) # XXX: rewrites are getting missed when we do this cast! Now, we # *want* rewrites getting missed tbh, but I think it's a broader # bug. source_ql = _construct_cast_from_uuid_to_obj_type( source_ql_p, sub_source, stype_refs, optional=True, ctx=ctx, ) is_multi = sub.get_cardinality(ctx.schema) == qltypes.SchemaCardinality.Many # Update the source_ql directly -- the filter is done there. 
ql_stmt: qlast.Expr = qlast.UpdateQuery( subject=source_ql, shape=[ qlast.ShapeElement( expr=qlast.Path(steps=[qlast.Ptr(name=sub_name.name)]), operation=( qlast.ShapeOperation(op=qlast.ShapeOp.APPEND) if is_multi else qlast.ShapeOperation(op=qlast.ShapeOp.ASSIGN) ), compexpr=ql_ptr_val, ) ], ) if not is_value_single: # value relation might contain multiple rows # to express this in EdgeQL, we must wrap `update` into a `for` query ql_stmt = qlast.ForQuery( iterator=qlast.Path(steps=[qlast.IRAnchor(name=value_name)]), iterator_alias=iterator_name, result=ql_stmt, ) ql_returning_shape: list[qlast.ShapeElement] = [] if stmt.returning_list: # construct the shape that will extract all needed columns of the subject # table (because they might be used by the RETURNING clause) for column in sub_table.columns: if column.hidden: continue if column.name in ('source', 'target'): # no need to include in shape, they will be present anyway continue ql_returning_shape.append( qlast.ShapeElement( expr=qlast.Path(steps=[qlast.Ptr(name=column.name)]), compexpr=qlast.Path( partial=True, steps=[ qlast.Ptr(name=sub_name.name), qlast.Ptr(name=column.name, type='property'), ], ), ) ) return UncompiledDML( input=stmt, subject=sub, ql_stmt=ql_stmt, ql_returning_shape=ql_returning_shape, ql_singletons={base_id}, ql_anchors={value_name: base_id}, external_rels={ base_id: ( value_rel, (pgce.PathAspect.SOURCE,), ) }, stype_refs=stype_refs, early_result=context.CompiledDML( value_cte_name=value_cte_name, value_relation_input=value_relation, value_columns=value_columns, value_iterator_name=value_iterator, # these will be populated after compilation output_ctes=[], output_relation_name='', output_namespace={}, ), # these will be populated by _uncompile_dml_stmt subject_columns=[], ) def _has_at_most_one_row(query: pgast.Query | None) -> bool: if not query: return True return False def _compile_standalone_default( col: context.Column, sub: s_objtypes.ObjectType | s_links.Link | s_properties.Property, 
ctx: Context, ) -> pgast.BaseExpr: ptr, _, _ = _get_pointer_for_column(col, sub, ctx) default = ptr.get_default(ctx.schema) if default is None: return pgast.NullConstant() # TODO(?): Support defaults that reference the object being inserted. # That seems like a pretty heavy lift in this scenario, though. options = qlcompiler.CompilerOptions( make_globals_empty=False, apply_user_access_policies=ctx.options.apply_access_policies, ) compiled = default.compiled(ctx.schema, options=options, context=None) sql_tree = pgcompiler.compile_ir_to_sql_tree( compiled.irast, output_format=pgcompiler.OutputFormat.NATIVE_INTERNAL, alias_generator=ctx.alias_generator, ) merge_params(sql_tree, compiled.irast, ctx) assert isinstance(sql_tree.ast, pgast.BaseExpr) return sql_tree.ast def _uncompile_default_value( value_query: Optional[pgast.Query], value_ctes: Optional[list[pgast.CommonTableExpr]], expected_columns: list[context.Column], sub: s_objtypes.ObjectType | s_links.Link | s_properties.Property, *, ctx: Context, ) -> tuple[pgast.BaseRelation, list[context.Column]]: # INSERT INTO x DEFAULT VALUES if not value_query: value_query = pgast.SelectStmt(values=[]) # edgeql compiler will provide default values # (and complain about missing ones) expected_columns = [] return value_query, expected_columns # VALUES (DEFAULT) if isinstance(value_query, pgast.SelectStmt) and value_query.values: # find DEFAULT keywords in VALUES def is_default(e: pgast.BaseExpr) -> bool: return isinstance(e, pgast.Keyword) and e.name == 'DEFAULT' default_columns: dict[int, int] = {} for row in value_query.values: assert isinstance(row, pgast.ImplicitRowExpr) for to_remove, col in enumerate(row.args): if is_default(col): default_columns[to_remove] = ( default_columns.setdefault(to_remove, 0) + 1 ) # remove DEFAULT keywords and expected columns, # so EdgeQL insert will not get those columns, which will use the # property defaults. 
for to_remove in sorted(default_columns, reverse=True): if default_columns[to_remove] != len(value_query.values): continue raise errors.QueryError( 'DEFAULT keyword is supported only when ' 'used for a column in all rows', span=value_query.span, pgext_code=pgerror.ERROR_FEATURE_NOT_SUPPORTED, ) del expected_columns[to_remove] for r_index, row in enumerate(value_query.values): assert isinstance(row, pgast.ImplicitRowExpr) assert is_default(row.args[to_remove]) cols = list(row.args) del cols[to_remove] value_query.values[r_index] = row.replace(args=cols) # Go back through and compile any left over for r_index, row in enumerate(value_query.values): assert isinstance(row, pgast.ImplicitRowExpr) if not any(is_default(col) for col in row.args): continue cols = list(row.args) for i, col in enumerate(row.args): if is_default(col): cols[i] = _compile_standalone_default( expected_columns[i], sub, ctx=ctx ) value_query.values[r_index] = row.replace(args=cols) if ( len(value_query.values) > 0 and isinstance(value_query.values[0], pgast.ImplicitRowExpr) and len(value_query.values[0].args) == 0 ): # special case: `VALUES (), (), ..., ()` # This is syntactically incorrect, so we transform it into: # `SELECT FROM (VALUES (NULL), (NULL), ..., (NULL)) _` value_query = pgast.SelectStmt( target_list=[], from_clause=[ pgast.RangeSubselect( subquery=pgast.SelectStmt( values=[ pgast.ImplicitRowExpr( args=[pgast.NullConstant()] ) for _ in value_query.values ] ), alias=pgast.Alias(aliasname='_'), ) ], ) # compile these CTEs as they were defined on value relation assert not value_query.ctes value_query.ctes = value_ctes return value_query, expected_columns @_uncompile_dml_stmt.register def _uncompile_delete_stmt( stmt: pgast.DeleteStmt, *, ctx: Context ) -> UncompiledDML: # determine the subject object sub_table, sub = _uncompile_dml_subject(stmt.relation, ctx=ctx) res: UncompiledDML if isinstance(sub, s_objtypes.ObjectType): res = _uncompile_delete_object_stmt(stmt, sub, sub_table, 
ctx=ctx) elif isinstance(sub, (s_links.Link, s_properties.Property)): res = _uncompile_delete_pointer_stmt(stmt, sub, sub_table, ctx=ctx) else: raise NotImplementedError() if stmt.returning_list: _uncompile_subject_columns(sub, sub_table, res, ctx=ctx) return res def _uncompile_delete_object_stmt( stmt: pgast.DeleteStmt, sub: s_objtypes.ObjectType, sub_table: context.Table, *, ctx: Context, ) -> UncompiledDML: """ Translates a 'SQL DELETE of object type table' to an EdgeQL delete. """ # prepare value relation # For deletes, value relation contains a single column of ids of all the # objects that need to be deleted. We construct this relation from WHERE # and USING clauses of DELETE. assert isinstance(stmt.relation, pgast.RelRangeVar) val_sub_rvar = stmt.relation.alias.aliasname or stmt.relation.relation.name assert val_sub_rvar value_relation = pgast.SelectStmt( ctes=stmt.ctes, target_list=[ pgast.ResTarget( val=pgast.ColumnRef( name=(val_sub_rvar, 'id'), ) ) ], from_clause=[ pgast.RelRangeVar( relation=stmt.relation.relation, alias=pgast.Alias(aliasname=val_sub_rvar), # DELETE ONLY include_inherited=stmt.relation.include_inherited, ) ] + stmt.using_clause, where_clause=stmt.where_clause, ) stmt.ctes = [] # prepare anchors for inserted value columns value_name = ctx.alias_generator.get('del_val') value_id = irast.PathId.from_type( ctx.schema, sub, typename=sn.QualName('__derived__', value_name), env=None, ) value_ql = qlast.IRAnchor(name=value_name) # a phantom relation that contains a single column, which is the id of all # the objects that should be deleted. 
value_cte_name = ctx.alias_generator.get('del_value') value_rel = pgast.Relation( name=value_cte_name, strip_output_namespaces=True, ) value_columns = [('id', 'id', False)] output_var = pgast.ColumnRef(name=('id',), nullable=False) value_rel.path_outputs[(value_id, pgce.PathAspect.IDENTITY)] = output_var value_rel.path_outputs[(value_id, pgce.PathAspect.VALUE)] = output_var value_rel.path_outputs[(value_id, pgce.PathAspect.ITERATOR)] = output_var # construct the EdgeQL DML AST sub_name = sub.get_name(ctx.schema) where = qlast.BinOp( left=qlast.Path(partial=True, steps=[qlast.Ptr(name='id')]), op='IN', right=qlast.Path(steps=[value_ql, qlast.Ptr(name='id')]), ) ql_stmt: qlast.Expr = qlast.DeleteQuery( subject=qlast.Path(steps=[s_utils.name_to_ast_ref(sub_name)]), where=where, ) ql_returning_shape: list[qlast.ShapeElement] = [] if stmt.returning_list: # construct the shape that will extract all needed columns of the subject # table (because they might be used by the RETURNING clause) for column in sub_table.columns: if column.hidden: continue _, ptr_name, _ = _get_pointer_for_column(column, sub, ctx) ql_returning_shape.append( qlast.ShapeElement( expr=qlast.Path(steps=[qlast.Ptr(name=ptr_name)]), ) ) return UncompiledDML( input=stmt, subject=sub, ql_stmt=ql_stmt, ql_returning_shape=ql_returning_shape, ql_singletons={value_id}, ql_anchors={value_name: value_id}, external_rels={ value_id: ( value_rel, (pgce.PathAspect.SOURCE,), ) }, stype_refs={}, early_result=context.CompiledDML( value_cte_name=value_cte_name, value_relation_input=value_relation, value_columns=value_columns, value_iterator_name=None, # these will be populated after compilation output_ctes=[], output_relation_name='', output_namespace={}, ), # these will be populated by _uncompile_dml_stmt subject_columns=[], ) def _uncompile_delete_pointer_stmt( stmt: pgast.DeleteStmt, sub: s_links.Link | s_properties.Property, sub_table: context.Table, *, ctx: Context, ) -> UncompiledDML: """ Translates a SQL 'DELETE
FROM a link / multi-property table' into an EdgeQL `update SourceObject { pointer := ... }.pointer`. """ sub_source = sub.get_source(ctx.schema) assert isinstance(sub_source, s_objtypes.ObjectType) sub_target = sub.get_target(ctx.schema) assert sub_target # prepare value relation # For link deletes, value relation contains two columns: source and target # of all links that need to be deleted. We construct this relation from # WHERE and USING clauses of DELETE. assert isinstance(stmt.relation, pgast.RelRangeVar) val_sub_rvar = stmt.relation.alias.aliasname or stmt.relation.relation.name assert val_sub_rvar value_relation = pgast.SelectStmt( ctes=stmt.ctes, target_list=[ pgast.ResTarget( val=pgast.ColumnRef( name=(val_sub_rvar, 'source'), ) ), pgast.ResTarget( val=pgast.ColumnRef( name=(val_sub_rvar, 'target'), ) ), ], from_clause=[ pgast.RelRangeVar( relation=stmt.relation.relation, alias=pgast.Alias(aliasname=val_sub_rvar), ) ] + stmt.using_clause, where_clause=stmt.where_clause, ) stmt.ctes = [] # if we are sure that we are updating a single source object, # we can skip for-loops and iterators, which produces better SQL is_value_single = False # prepare anchors for inserted value columns value_name = ctx.alias_generator.get('ins_val') iterator_name = ctx.alias_generator.get('ins_iter') source_id = irast.PathId.from_type( ctx.schema, sub_source, typename=sn.QualName('__derived__', value_name), env=None, ) link_ref = irtypeutils.ptrref_from_ptrcls( schema=ctx.schema, ptrcls=sub, cache=None, typeref_cache=None ) value_id: irast.PathId = source_id.extend(ptrref=link_ref) value_ql: qlast.PathElement = ( qlast.IRAnchor(name=value_name) if is_value_single else qlast.ObjectRef(name=iterator_name) ) # a phantom relation that is supposed to hold the two source and target # columns of rows that need to be deleted. 
value_cte_name = ctx.alias_generator.get('del_value') value_rel = pgast.Relation( name=value_cte_name, strip_output_namespaces=True, ) value_columns = [('source', 'source', False), ('target', 'target', False)] var = pgast.ColumnRef(name=('source',), nullable=True) value_rel.path_outputs[(source_id, pgce.PathAspect.VALUE)] = var value_rel.path_outputs[(source_id, pgce.PathAspect.IDENTITY)] = var tgt_id = value_id.tgt_path() var = pgast.ColumnRef(name=('target',), nullable=True) value_rel.path_outputs[(tgt_id, pgce.PathAspect.VALUE)] = var value_rel.path_outputs[(tgt_id, pgce.PathAspect.IDENTITY)] = var # source needs an iterator column, so we need to invent one # Here we only decide on the name of that iterator column, the actual column # is generated later, when resolving the DML stmt. value_iterator = ctx.alias_generator.get('iter') var = pgast.ColumnRef(name=(value_iterator,)) value_rel.path_outputs[(source_id, pgce.PathAspect.ITERATOR)] = var value_rel.path_outputs[(value_id, pgce.PathAspect.ITERATOR)] = var # construct the EdgeQL DML AST sub_name = sub.get_name(ctx.schema) sub_source_name = sub_source.get_name(ctx.schema) sub_target_name = sub_target.get_name(ctx.schema) sub_name = sub.get_shortname(ctx.schema) ql_sub_source_ref = s_utils.name_to_ast_ref(sub_source_name) ql_sub_target_ref = s_utils.name_to_ast_ref(sub_target_name) ql_ptr_val: qlast.Expr = qlast.Path( steps=[value_ql, qlast.Ptr(name=sub_name.name)] ) if isinstance(sub, s_links.Link): ql_ptr_val = qlast.TypeCast( expr=ql_ptr_val, type=qlast.TypeName(maintype=ql_sub_target_ref), ) ql_stmt: qlast.Expr = qlast.UpdateQuery( subject=qlast.Path(steps=[ql_sub_source_ref]), where=qlast.BinOp( # ObjectType == value.source left=qlast.Path(steps=[ql_sub_source_ref]), op='=', right=qlast.Path(steps=[value_ql]), ), shape=[ qlast.ShapeElement( expr=qlast.Path(steps=[qlast.Ptr(name=sub_name.name)]), operation=qlast.ShapeOperation(op=qlast.ShapeOp.SUBTRACT), compexpr=ql_ptr_val, ) ], ) if not is_value_single: # 
value relation might contain multiple rows # to express this in EdgeQL, we must wrap `update` into a `for` query ql_stmt = qlast.ForQuery( iterator=qlast.Path(steps=[qlast.IRAnchor(name=value_name)]), iterator_alias=iterator_name, result=ql_stmt, ) # append .pointer onto the shape, so the resulting CTE contains the pointer # data, not the subject table # ql_stmt = qlast.Path( # steps=[ql_stmt, qlast.Ptr(name=sub_name.name)] # ) ql_returning_shape: list[qlast.ShapeElement] = [] if stmt.returning_list: # construct the shape that will extract all needed columns of the subject # table (because they might be used by the RETURNING clause) for column in sub_table.columns: if column.hidden: continue if column.name in ('source', 'target'): # no need to include in shape, they will be present anyway continue ql_returning_shape.append( qlast.ShapeElement( expr=qlast.Path(steps=[qlast.Ptr(name=column.name)]), compexpr=qlast.Path( partial=True, steps=[ qlast.Ptr(name=sub_name.name), qlast.Ptr(name=column.name, type='property'), ], ), ) ) return UncompiledDML( input=stmt, subject=sub, ql_stmt=ql_stmt, ql_returning_shape=ql_returning_shape, ql_singletons={source_id}, ql_anchors={value_name: source_id}, external_rels={ source_id: ( value_rel, (pgce.PathAspect.SOURCE,), ) }, stype_refs={}, early_result=context.CompiledDML( value_cte_name=value_cte_name, value_relation_input=value_relation, value_columns=value_columns, value_iterator_name=value_iterator, # these will be populated after compilation output_ctes=[], output_relation_name='', output_namespace={}, ), # these will be populated by _uncompile_dml_stmt subject_columns=[], ) @_uncompile_dml_stmt.register def _uncompile_update_stmt( stmt: pgast.UpdateStmt, *, ctx: Context ) -> UncompiledDML: # determine the subject object sub_table, sub = _uncompile_dml_subject(stmt.relation, ctx=ctx) # convert the general repr of SET clause into a list of columns update_targets: list[pgast.UpdateTarget] = [] for target in stmt.targets: if
isinstance(target, pgast.UpdateTarget): if target.indirection: raise errors.QueryError( 'indirections in UPDATE SET not supported', pgext_code=pgerror.ERROR_FEATURE_NOT_SUPPORTED, span=stmt.span, ) update_targets.append(target) elif isinstance(target, pgast.MultiAssignRef): if not isinstance( target.source, (pgast.ImplicitRowExpr, pgast.RowExpr) ): raise errors.QueryError( 'multi-assigns UPDATE SET are supported only for plain row ' 'literals (`ROW(...)`)', pgext_code=pgerror.ERROR_FEATURE_NOT_SUPPORTED, span=stmt.span, ) update_targets.extend( pgast.UpdateTarget(name=c, val=v, span=v.span) for (c, v) in zip(target.columns, target.source.args) ) else: raise NotImplementedError() set_columns = _pull_columns_from_table( sub_table, ((c.name, c.span) for c in update_targets), ) column_updates = list(zip(set_columns, (c.val for c in update_targets))) res: UncompiledDML if isinstance(sub, s_objtypes.ObjectType): res = _uncompile_update_object_stmt( stmt, sub, sub_table, column_updates, ctx=ctx ) elif isinstance(sub, (s_links.Link, s_properties.Property)): raise errors.QueryError( f'UPDATE of link tables is not supported', pgext_code=pgerror.ERROR_FEATURE_NOT_SUPPORTED, span=stmt.span, ) else: raise NotImplementedError() if stmt.returning_list: _uncompile_subject_columns(sub, sub_table, res, ctx=ctx) return res def _uncompile_update_object_stmt( stmt: pgast.UpdateStmt, sub: s_objtypes.ObjectType, sub_table: context.Table, column_updates: list[tuple[context.Column, pgast.BaseExpr]], *, ctx: Context, ) -> UncompiledDML: """ Translates a 'SQL UPDATE into an object type table' to an EdgeQL update. """ def is_default(e: pgast.BaseExpr) -> bool: return isinstance(e, pgast.Keyword) and e.name == 'DEFAULT' # prepare value relation # For updates, value relation contains: # - `id` column, that contains the id of the subject, # - one column for each of the pointers on the subject to be updated, # We construct this relation from WHERE and FROM clauses of UPDATE. 
assert isinstance(stmt.relation, pgast.RelRangeVar) val_sub_rvar = stmt.relation.alias.aliasname or stmt.relation.relation.name assert val_sub_rvar value_relation = pgast.SelectStmt( ctes=stmt.ctes, target_list=[ pgast.ResTarget( val=pgast.ColumnRef( name=(val_sub_rvar, 'id'), ) ) ] + [ pgast.ResTarget(val=val, name=c.name) for c, val in column_updates if not is_default(val) # skip DEFAULT column updates ], from_clause=[ pgast.RelRangeVar( relation=stmt.relation.relation, alias=pgast.Alias(aliasname=val_sub_rvar), # UPDATE ONLY include_inherited=stmt.relation.include_inherited, ) ] + stmt.from_clause, where_clause=stmt.where_clause, ) stmt.ctes = [] # prepare anchors for inserted value columns value_name = ctx.alias_generator.get('upd_val') iterator_name = ctx.alias_generator.get('upd_iter') value_id = irast.PathId.from_type( ctx.schema, sub, typename=sn.QualName('__derived__', value_name), env=None, ) value_ql: qlast.PathElement = qlast.ObjectRef(name=iterator_name) # a phantom relation that is supposed to hold the inserted value # (in the resolver, this will be replaced by the real value relation) value_cte_name = ctx.alias_generator.get('upd_value') value_rel = pgast.Relation( name=value_cte_name, strip_output_namespaces=True, ) output_var = pgast.ColumnRef(name=('id',)) value_rel.path_outputs[(value_id, pgce.PathAspect.ITERATOR)] = output_var value_rel.path_outputs[(value_id, pgce.PathAspect.VALUE)] = output_var value_columns = [('id', 'id', False)] update_shape = [] stype_refs: dict[uuid.UUID, list[qlast.Set]] = {} for index, (col, val) in enumerate(column_updates): ptr, ptr_name, is_link = _get_pointer_for_column(col, sub, ctx) if not is_default(val): value_columns.append((col.name, ptr_name, is_link)) # inject type annotation into value relation _try_inject_ptr_type_cast(value_relation, index + 1, ptr, ctx) # prepare the outputs of the source CTE ptr_id = _get_ptr_id(value_id, ptr, ctx) output_var = pgast.ColumnRef(name=(ptr_name,), nullable=True) if 
is_link: value_rel.path_outputs[(ptr_id, pgce.PathAspect.IDENTITY)] = ( output_var ) value_rel.path_outputs[(ptr_id, pgce.PathAspect.VALUE)] = output_var else: value_rel.path_outputs[(ptr_id, pgce.PathAspect.VALUE)] = output_var # prepare insert shape that will use the paths from source_outputs if is_default(val): # special case: DEFAULT default_ql: qlast.Expr if ptr.get_default(ctx.schema) is None: default_ql = qlast.Set(elements=[]) # NULL else: default_ql = qlast.Path( steps=[qlast.SpecialAnchor(name='__default__')] ) update_shape.append( qlast.ShapeElement( expr=qlast.Path(steps=[qlast.Ptr(name=ptr_name)]), operation=qlast.ShapeOperation(op=qlast.ShapeOp.ASSIGN), compexpr=default_ql, ) ) else: # base case update_shape.append( _construct_assign_element_for_ptr( value_ql, ptr_name, ptr, is_link, ctx, stype_refs, ) ) # construct the EdgeQL DML AST sub_name = sub.get_name(ctx.schema) ql_sub_ref = s_utils.name_to_ast_ref(sub_name) where = qlast.BinOp( # ObjectType == value.source left=qlast.Path(steps=[ql_sub_ref]), op='=', right=qlast.Path(steps=[value_ql]), ) ql_stmt: qlast.Expr = qlast.UpdateQuery( subject=qlast.Path(steps=[ql_sub_ref]), where=where, shape=update_shape, ) # value relation might contain multiple rows # to express this in EdgeQL, we must wrap `update` into a `for` query ql_stmt = qlast.ForQuery( iterator=qlast.Path(steps=[qlast.IRAnchor(name=value_name)]), iterator_alias=iterator_name, result=ql_stmt, ) ql_returning_shape: list[qlast.ShapeElement] = [] if stmt.returning_list: # construct the shape that will extract all needed columns of the subject # table (because they might be used by the RETURNING clause) for column in sub_table.columns: if column.hidden: continue _, ptr_name, _ = _get_pointer_for_column(column, sub, ctx) ql_returning_shape.append( qlast.ShapeElement( expr=qlast.Path(steps=[qlast.Ptr(name=ptr_name)]), ) ) return UncompiledDML( input=stmt, subject=sub, ql_stmt=ql_stmt, ql_returning_shape=ql_returning_shape,
ql_singletons={value_id}, ql_anchors={value_name: value_id}, external_rels={ value_id: ( value_rel, (pgce.PathAspect.SOURCE,), ) }, stype_refs=stype_refs, early_result=context.CompiledDML( value_cte_name=value_cte_name, value_relation_input=value_relation, value_columns=value_columns, value_iterator_name=None, # these will be populated after compilation output_ctes=[], output_relation_name='', output_namespace={}, ), # these will be populated by _uncompile_dml_stmt subject_columns=[], ) def _compile_uncompiled_dml( stmts: list[UncompiledDML], ctx: context.ResolverContextLevel ) -> tuple[ Mapping[pgast.Query, context.CompiledDML], list[pgast.CommonTableExpr], ]: """ Compiles *all* DML statements in the query. Statements must already be uncompiled into equivalent EdgeQL statements. Will merge the statements into one large shape of all DML queries and compile that with a single invocation of EdgeQL compiler. Returns: - mapping from the original SQL statement into CompiledDML and - a list of "global" CTEs that should be injected at the end of top-level CTE list. 
""" # merge params singletons = set() anchors: dict[str, irast.PathId] = {} for stmt in stmts: singletons.update(stmt.ql_singletons) anchors.update(stmt.ql_anchors) # construct the main query ql_aliases: list[qlast.Alias] = [] ql_stmt_shape: list[qlast.ShapeElement] = [] ql_stmt_shape_names = [] inserts_by_type: dict[uuid.UUID, list[str]] = {} for index, stmt in enumerate(stmts): # fixup references to stypes that have been modified be previous inserts # for more info, see _construct_cast_from_uuid_to_obj_type for stype_id, ref_sets in stmt.stype_refs.items(): if insert_names := inserts_by_type.get(stype_id, None): for ref_set in ref_sets: for name in insert_names: ref_set.elements.append( qlast.Path(steps=[qlast.ObjectRef(name=name)]) ) # the main thing name = f'dml_{index}' ql_stmt_shape_names.append(name) ql_aliases.append( qlast.AliasedExpr( alias=name, expr=stmt.ql_stmt, ) ) ql_stmt_shape.append( qlast.ShapeElement( expr=qlast.Path(steps=[qlast.Ptr(name=name)]), compexpr=qlast.Shape( expr=qlast.Path(steps=[qlast.ObjectRef(name=name)]), elements=stmt.ql_returning_shape, ), ) ) # save inserts for later fixups if isinstance(stmt.input, pgast.InsertStmt) and stmt.subject.id: if stmt.subject.id not in inserts_by_type: inserts_by_type[stmt.subject.id] = [] inserts_by_type[stmt.subject.id].append(name) ql_stmt = qlast.SelectQuery( aliases=ql_aliases, result=qlast.Shape(expr=None, elements=ql_stmt_shape), ) ir_stmt: irast.Statement try: # compile synthetic ql statement into SQL options = qlcompiler.CompilerOptions( modaliases={None: 'default'}, make_globals_empty=False, singletons=singletons, anchors=anchors, allow_user_specified_id=ctx.options.allow_user_specified_id, apply_user_access_policies=ctx.options.apply_access_policies ) ir_stmt = qlcompiler.compile_ast_to_ir( ql_stmt, schema=ctx.schema, options=options, ) external_rels, ir_stmts = _merge_and_prepare_external_rels( ir_stmt, stmts, ql_stmt_shape_names ) sql_result = pgcompiler.compile_ir_to_sql_tree( ir_stmt, 
external_rels=external_rels, output_format=pgcompiler.OutputFormat.NATIVE_INTERNAL, alias_generator=ctx.alias_generator, sql_dml_mode=True, ) merge_params(sql_result, ir_stmt, ctx) except errors.QueryError as e: raise errors.QueryError( msg=e.args[0], details=e.details, hint=e.hint, # not sure if this is ok, but it is better than InternalServerError, # which is the default pgext_code=pgerror.ERROR_DATA_EXCEPTION, ) except errors.UnsupportedFeatureError as e: raise errors.QueryError( msg=e.args[0], position=e.get_position(), details=e.details, hint=e.hint, pgext_code=pgerror.ERROR_FEATURE_NOT_SUPPORTED, ) assert isinstance(sql_result.ast, pgast.Query) assert sql_result.ast.ctes ctes = list(sql_result.ast.ctes) result = {} for stmt, ir_mutating_stmt in zip(stmts, ir_stmts): stmt_ctes = _collect_stmt_ctes(ctes, ir_mutating_stmt) # Find the output CTE of the DML operation. We do this in two different # ways: # - look for SQL DML on the subject relation. This is used for # operations on link tables. Kinda hacky. # - use the `output_of_dml`, which will be set on the CTE that contains # the union of all SQL DML stmts that are generated for an IR DML. # There might be multiple because: 1) inheritance, which stores child # objects in separate tables, 2) an UNLESS CONFLICT clause that contains # another DML stmt. output_cte: pgast.CommonTableExpr | None if isinstance(stmt.subject, (s_pointers.Pointer)): subject_id = str(stmt.subject.id) output_cte = next( c for c in reversed(stmt_ctes) if isinstance(c.query, pgast.DMLQuery) and isinstance(c.query.relation, pgast.RelRangeVar) and c.query.relation.relation.name == subject_id ) else: output_cte = next( (c for c in stmt_ctes if c.output_of_dml == ir_mutating_stmt), None, ) assert output_cte, 'cannot find the output CTE of a DML stmt' output_rel = output_cte.query # This "output_rel" must contain an entry in path_namespace for each column # of the subject table.
This is ensured by applying a shape on the ql # dml stmt, which selects all pointers. Although the shape is not # constructed in CTEs (so we discard it), it causes values for pointers # to be read from DML CTEs, which makes them appear in the path_namespace. # prepare a map from pointer name into pgast ptr_map: dict[tuple[str, str], pgast.BaseExpr] = {} for (ptr_id, aspect), output_var in output_rel.path_outputs.items(): qual_name = ptr_id.rptr_name() if not qual_name: ptr_map['id', aspect] = output_var else: ptr_map[qual_name.name, aspect] = output_var output_namespace: dict[str, pgast.BaseExpr] = {} for col_name, ptr_name in stmt.subject_columns: val = ptr_map.get((ptr_name, 'serialized'), None) if not val: val = ptr_map.get((ptr_name, 'value'), None) if not val: val = ptr_map.get((ptr_name, 'identity'), None) if ptr_name in ('source', 'target'): val = pgast.ColumnRef(name=(ptr_name,)) assert val, f'{ptr_name} was in shape, but not in path_namespace' output_namespace[col_name] = val result[stmt.input] = context.CompiledDML( value_cte_name=stmt.early_result.value_cte_name, value_relation_input=stmt.early_result.value_relation_input, value_columns=stmt.early_result.value_columns, value_iterator_name=stmt.early_result.value_iterator_name, conflict_update_input=stmt.early_result.conflict_update_input, conflict_update_name=stmt.early_result.conflict_update_name, conflict_update_iterator=stmt.early_result.conflict_update_iterator, subject_id=stmt.early_result.subject_id, subject_columns=stmt.early_result.subject_columns, value_id=stmt.early_result.value_id, env=sql_result.env, output_ctes=stmt_ctes, output_relation_name=output_cte.name, output_namespace=output_namespace, ) # The remaining CTEs do not belong to any one specific DML statement and # should be included at the end of the top-level query. # They were probably generated by "after all" triggers.
return result, ctes def _collect_stmt_ctes( ctes: list[pgast.CommonTableExpr], ir_stmt: irast.MutatingStmt ) -> list[pgast.CommonTableExpr]: # We compile all SQL DML queries in a single EdgeQL query. # Result is an enormous SQL query with many CTEs. # This function looks through these CTEs and matches them to the original # IR stmt that they originate from. # It will pop elements from ctes list until it reaches the main CTE for # the IR stmt. # When looking for "CTEs of a stmt" we also want to include stmts that # originate from the UNLESS CONFLICT clause. ir_stmts = {ir_stmt} if isinstance(ir_stmt, irast.InsertStmt): if ir_stmt.on_conflict and ir_stmt.on_conflict.else_ir: else_expr = ir_stmt.on_conflict.else_ir.expr assert isinstance(else_expr, irast.SelectStmt), else_expr assert else_expr.iterator_stmt dml_stmt = else_expr.result.expr assert isinstance(dml_stmt, irast.MutatingStmt), dml_stmt ir_stmts.add(dml_stmt) stmt_ctes = [] found_it = False while len(ctes) > 0: matches = ctes[0].for_dml_stmt in ir_stmts if not matches and found_it: # use all matching CTEs plus all preceding break if matches: found_it = True stmt_ctes.append(ctes.pop(0)) return stmt_ctes def _merge_and_prepare_external_rels( ir_stmt: irast.Statement, stmts: list[UncompiledDML], stmt_names: list[str], ) -> tuple[ Mapping[irast.PathId, ExternalRel], list[irast.MutatingStmt], ]: """Construct external rels used for compiling all DML statements at once.""" # This should be straight-forward, but because we put DML into shape # elements, ql compiler will put each binding into a separate namespace. # So we need to find the correct path_id for each DML stmt in the IR by # looking at the paths in the shape elements. 
assert isinstance(ir_stmt.expr, irast.SetE) assert isinstance(ir_stmt.expr.expr, irast.SelectStmt) ir_shape = ir_stmt.expr.expr.result.shape assert ir_shape # extract stmt name from the shape elements shape_elements_by_name = {} for b, _ in ir_shape: rptr_name = b.path_id.rptr_name() if not rptr_name: continue shape_elements_by_name[rptr_name.name] = b.expr.expr external_rels: dict[irast.PathId, ExternalRel] = {} ir_stmts = [] for stmt, name in zip(stmts, stmt_names): # find the associated binding (this is real funky) element = shape_elements_by_name[name] while not isinstance(element, irast.MutatingStmt): if isinstance(element, irast.SelectStmt): element = element.result.expr elif isinstance(element, irast.Pointer): element = element.source.expr elif isinstance(element, irast.OperatorCall): element = element.args[0].expr.expr else: raise NotImplementedError('cannot find mutating stmt') ir_stmts.append(element) subject_path_id = element.result.path_id subject_namespace = subject_path_id.namespace # add all external rels, but add the namespace to their output's ids for rel_id, (rel, aspects) in stmt.external_rels.items(): for (out_id, out_asp), out in list(rel.path_outputs.items()): # HACK: this is a hacky hack to get the path_id used by the # pointers within the DML statement's namespace out_id = out_id.replace_namespace(subject_namespace) out_id._prefix = out_id._get_prefix(1).replace_namespace(set()) rel.path_outputs[out_id, out_asp] = out external_rels[rel_id] = (rel, aspects) return external_rels, ir_stmts @dispatch._resolve_relation.register def resolve_DMLQuery( stmt: pgast.DMLQuery, *, include_inherited: bool, ctx: Context ) -> tuple[pgast.Query, context.Table]: assert stmt.relation if ctx.subquery_depth >= 2: raise errors.QueryError( 'WITH clause containing a data-modifying statement must be at ' 'the top level', span=stmt.span, pgext_code=pgerror.ERROR_FEATURE_NOT_SUPPORTED, ) subject_alias: str | None = None if stmt.relation.alias and 
stmt.relation.alias.aliasname: subject_alias = stmt.relation.alias.aliasname assert stmt.relation.relation.name subject_name = (stmt.relation.relation.name, subject_alias) compiled_dml = ctx.compiled_dml[stmt] _resolve_dml_value_rel(compiled_dml, ctx=ctx) if compiled_dml.conflict_update_input is not None: _resolve_conflict_update_rel(compiled_dml, subject_name, ctx=ctx) ctx.ctes_buffer.extend(compiled_dml.output_ctes) ctx.env.capabilities |= enums.Capability.MODIFICATIONS return _fini_resolve_dml(stmt, compiled_dml, ctx=ctx) def _resolve_dml_value_rel( compiled_dml: context.CompiledDML, *, ctx: Context ): # resolve the value relation with ctx.child() as sctx: # this subctx is needed so it is not deemed as top-level which would # extract and attach CTEs, but not make them available to all # following CTEs # but it is not a "real" subquery context sctx.subquery_depth -= 1 val_rel, val_table = dispatch.resolve_relation( compiled_dml.value_relation_input, ctx=sctx ) assert isinstance(val_rel, pgast.Query) if len(compiled_dml.value_columns) != len(val_table.columns): col_names = ', '.join(c for c, _, _ in compiled_dml.value_columns) raise errors.QueryError( f'Expected {len(compiled_dml.value_columns)} columns ' f'({col_names}), but got {len(val_table.columns)}', span=compiled_dml.value_relation_input.span, ) if val_rel.ctes: ctx.ctes_buffer.extend(val_rel.ctes) val_rel.ctes = None val_table.alias = ctx.alias_generator.get('rel') # wrap the value relation into a "pre-projection", # so we can add type casts for link ids and an iterator column pre_projection_needed = compiled_dml.value_iterator_name or ( any(cast_to_uuid for _, _, cast_to_uuid in compiled_dml.value_columns) ) if pre_projection_needed: value_target_list: list[pgast.ResTarget] = [] for val_col, (_col_name, ptr_name, cast_to_uuid) in zip( val_table.columns, compiled_dml.value_columns ): # prepare pre-projection of this pointer value val_col_pg = pg_res_expr.resolve_column_kind( val_table, val_col.kind, ctx=ctx
) if cast_to_uuid: val_col_pg = pgast.TypeCast( arg=val_col_pg, type_name=pgast.TypeName(name=('uuid',)) ) value_target_list.append( pgast.ResTarget(name=ptr_name, val=val_col_pg) ) if compiled_dml.value_iterator_name: # source needs an iterator column, so we need to invent one # The name of the column was invented before (in pre-processing) so # it could be used in DML CTEs. value_target_list.append( pgast.ResTarget( name=compiled_dml.value_iterator_name, val=pgast.FuncCall( name=('edgedb', 'uuid_generate_v4'), args=(), ), ) ) val_rel = pgast.SelectStmt( from_clause=[ pgast.RangeSubselect( subquery=val_rel, alias=pgast.Alias(aliasname=val_table.alias), ) ], target_list=value_target_list, ) value_cte = pgast.CommonTableExpr( name=compiled_dml.value_cte_name, query=val_rel, ) ctx.ctes_buffer.append(value_cte) def _resolve_conflict_update_rel( compiled_dml: context.CompiledDML, subject_name: tuple[str, str | None], *, ctx: Context, ): # Resolves the relation that provides the rows that should be updated in # ON CONFLICT DO UPDATE # This is done by manipulating CTEs produced by our pg compiler compiling # `insert Subject { ... } unless conflict (update Subject set {...})` # There are a few relevant CTEs for this: # - "value": provides rows to be inserted # (produced by resolver just before this function), # - "else": provides iterator over rows that are in conflict and should not # be inserted but updated instead (generated by pgcompiler), # - "conflict update": provides rows to be updated, # (produced by this function) # - "ins_contents": computes values of the inserted shape, # (generated by pg compiler, anti-joined against "else") # "conflict update" needs two relational inputs: # - subject relation of rows that exist in the database and are in conflict, # - excluded relation which are rows provided to the insert.
# "subject" is provided by "else", "excluded" is provided by "value" # To correlate the two, we join them on the iterator that is generated in # the "value" CTE. We cannot use real ids for this because they do not yet # exist for these objects, since they are computed in "ins_contents". # Additionally, we need to correlate "conflict update" rel with the EdgeQL # update stmt. This is done via EdgeQL where clause, which compares the ids # of the subject and the update iterator. if not ( compiled_dml.conflict_update_name and compiled_dml.conflict_update_input ): return from edb.pgsql.compiler import enums as pgce from edb.pgsql.compiler import astutils as pg_astutils # Find "else" CTE in the list of compiled CTEs # This is comparing by CTE name, which is hacky and error-prone # We are guaranteed to have only one such CTE, since these CTEs all # belong to a single insert stmt. else_index, else_cte = next( (i, cte) for i, cte in enumerate(compiled_dml.output_ctes) if cte.name.startswith('else') ) # Add a view_path_id_map entry to get around the fact that rvar map path_ids # contain namespaces due to us using for loops around the insert stmt.
assert compiled_dml.subject_id else_cte.query.view_path_id_map[compiled_dml.subject_id] = ( next(iter(else_cte.query.view_path_id_map.keys())) ) # Include 'excluded' rel var in scope # This is outside of the inner scope, so the subject table takes precedence ctx.scope.tables.append( context.Table( name='excluded', reference_as='excluded', columns=[ context.Column( name=col_name, kind=context.ColumnByName(reference_as=ptr_name), ) for col_name, ptr_name, _ in compiled_dml.value_columns ] ) ) # resolve the relation with ctx.child() as sctx: # this subctx is needed so it is not deemed as top-level which would # extract and attach CTEs, but not make them available to all # following CTEs # include subject rel var in scope sctx.scope.tables.append( context.Table( name=subject_name[0], alias=subject_name[1], reference_as='else', columns=[ context.Column( name=col_name, kind=context.ColumnByName( reference_as=_get_path_id_output( else_cte.query, path_id, compiled_dml, ) ), ) for col_name, path_id in compiled_dml.subject_columns or [] ] ) ) # the important bit: resolve "conflict update" rel cu_rel, _ = dispatch.resolve_relation( compiled_dml.conflict_update_input, ctx=sctx ) assert isinstance(cu_rel, pgast.SelectStmt) # inject the 'excluded' rel var (from "value" CTE) cu_rel.from_clause.append( pgast.RelRangeVar( relation=pgast.Relation(name=compiled_dml.value_cte_name), alias=pgast.Alias(aliasname='excluded'), ) ) # inject the subject rel var (from "else" CTE) cu_rel.from_clause.append( pgast.RelRangeVar( relation=pgast.Relation(name=else_cte.name), alias=pgast.Alias(aliasname='else'), ) ) # inject iterator column, which we can pull from excluded assert compiled_dml.value_iterator_name cu_rel.target_list.append(pgast.ResTarget( val=pgast.ColumnRef( name=('excluded', compiled_dml.value_iterator_name) ), )) # inject subject id from "else" rvar subject_id_col = _get_path_id_output( else_cte.query, compiled_dml.subject_id, compiled_dml, aspect=pgce.PathAspect.IDENTITY )
cu_rel.target_list.append(pgast.ResTarget( val=pgast.ColumnRef( name=('else', subject_id_col), ), name='id' )) # add a join condition for "excluded" and "subject" rvars # We start with a plain value_id, but because of for loops, rvar_map # will contain path_ids polluted with namespaces. So instead of a plain # one, we find a path_id from rvar_map that matches the plain one in all # but the namespace. value_id = next( p for p, _ in else_cte.query.path_rvar_map.keys() if p.replace_namespace(set()) == compiled_dml.value_id ) # pull value iterator from the else CTE value_iter = _get_path_id_output( else_cte.query, value_id, compiled_dml ) cu_rel.where_clause = pg_astutils.extend_binop( cu_rel.where_clause, pgast.Expr( lexpr=pgast.ColumnRef(name=('else', value_iter)), name='=', rexpr=pgast.ColumnRef( name=('excluded', compiled_dml.value_iterator_name) ), ) ) # convert the resolved "conflict update" into a flat list of ctes conflict_ctes = [] if cu_rel.ctes: conflict_ctes.extend(cu_rel.ctes) cu_rel.ctes = None cu_cte = pgast.CommonTableExpr( name=compiled_dml.conflict_update_name, query=cu_rel, ) conflict_ctes.append(cu_cte) # combine compiled CTEs and CTEs from "conflict update" compiled_dml.output_ctes = ( compiled_dml.output_ctes[0:else_index + 1] + conflict_ctes + compiled_dml.output_ctes[else_index + 1:] ) # Invokes pg compiler machinery to pull value for columns out of a query def _get_path_id_output( query: pgast.Query, path_id: irast.PathId, compiled_dml: context.CompiledDML, *, aspect: pgce.PathAspect = pgce.PathAspect.VALUE, ) -> str: # The mere fact that this is used outside of the pg compiler signals # that this is hacky. 
    assert compiled_dml.env
    output = pgcompiler.pathctx.get_path_output(
        query, path_id, aspect=aspect, env=compiled_dml.env
    )
    assert isinstance(output, pgast.ColumnRef)
    name = output.name[-1]
    assert isinstance(name, str)
    return name


def _fini_resolve_dml(
    stmt: pgast.DMLQuery, compiled_dml: context.CompiledDML, *, ctx: Context
) -> tuple[pgast.Query, context.Table]:
    if stmt.returning_list:
        assert isinstance(stmt.relation.relation, pgast.Relation)
        assert stmt.relation.relation.name
        res_query, res_table = _resolve_returning_rows(
            stmt.returning_list,
            compiled_dml.output_relation_name,
            compiled_dml.output_namespace,
            stmt.relation.relation.name,
            stmt.relation.alias.aliasname,
            ctx,
        )
    else:
        if ctx.subquery_depth == 0:
            # when a top-level DML query does not have a RETURNING clause,
            # we inject a COUNT(*) clause so we can efficiently count
            # modified rows, which will be converted into the
            # CommandComplete tag.
            res_query = pgast.SelectStmt(
                target_list=[
                    pgast.ResTarget(
                        val=pgast.FuncCall(
                            name=('count',), agg_star=True, args=[]
                        ),
                    )
                ],
                from_clause=[
                    pgast.RelRangeVar(
                        relation=pgast.Relation(
                            name=compiled_dml.output_relation_name
                        )
                    )
                ],
            )
        else:
            # nested DML queries without RETURNING do not need any result
            res_query = pgast.SelectStmt()
        res_table = context.Table()

    if not res_query.ctes:
        res_query.ctes = []
    res_query.ctes.extend(pg_res_rel.extract_ctes_from_ctx(ctx))
    return res_query, res_table


def _resolve_returning_rows(
    returning_list: list[pgast.ResTarget],
    output_relation_name: str,
    output_namespace: Mapping[str, pgast.BaseExpr],
    subject_name: str,
    subject_alias: Optional[str],
    ctx: context.ResolverContextLevel,
) -> tuple[pgast.Query, context.Table]:
    # "output" is the relation that provides the values of the subject table
    # after the DML operation.
    # It contains data you'd get by having `RETURNING *` on the DML.
output_rvar_name = ctx.alias_generator.get('output') output_query = pgast.SelectStmt( from_clause=[ pgast.RelRangeVar( relation=pgast.Relation(name=output_relation_name), ) ] ) output_table = context.Table( name=subject_name, alias=subject_alias, reference_as=output_rvar_name, ) for col_name, val in output_namespace.items(): output_query.target_list.append( pgast.ResTarget(name=col_name, val=val) ) output_table.columns.append( context.Column( name=col_name, kind=context.ColumnByName(reference_as=col_name), ) ) with ctx.empty() as sctx: sctx.scope.tables.append(output_table) returning_query = pgast.SelectStmt( from_clause=[ pgast.RangeSubselect( alias=pgast.Alias(aliasname=output_rvar_name), subquery=output_query, ) ], target_list=[], ) returning_table = context.Table() names: set[str] = set() for t in returning_list: targets, columns = pg_res_expr.resolve_ResTarget( t, existing_names=names, ctx=sctx ) returning_query.target_list.extend(targets) returning_table.columns.extend(columns) return returning_query, returning_table def _get_pointer_for_column( col: context.Column, subject: s_objtypes.ObjectType | s_links.Link | s_properties.Property, ctx: context.ResolverContextLevel, ) -> tuple[s_pointers.Pointer, str, bool]: if isinstance( subject, (s_links.Link, s_properties.Property) ) and col.name in ('source', 'target'): return subject, col.name, False assert not isinstance(subject, s_properties.Property) is_link = False ptr_name = col.name if col.name.endswith('_id'): # If the name ends with _id, and a single link exists with that name, # then we are referring to the link. 
root_name = ptr_name[0:-3] if ( (link := subject.maybe_get_ptr( ctx.schema, sn.UnqualName(root_name), type=s_links.Link )) and link.singular(ctx.schema) ): ptr_name = root_name is_link = True ptr = subject.maybe_get_ptr(ctx.schema, sn.UnqualName(ptr_name)) assert ptr return ptr, ptr_name, is_link def _get_ptr_id( source_id: irast.PathId, ptr: s_pointers.Pointer, ctx: context.ResolverContextLevel, ) -> irast.PathId: ptrref = irtypeutils.ptrref_from_ptrcls( schema=ctx.schema, ptrcls=ptr, cache=None, typeref_cache=None ) return source_id.extend(ptrref=ptrref) def _try_inject_ptr_type_cast( rel: pgast.BaseRelation, index: int, ptr: s_pointers.Pointer, ctx: Context ): ptr_name = ptr.get_shortname(ctx.schema).name tgt_pg: tuple[str, ...] if ptr_name == 'id' or isinstance(ptr, s_links.Link): tgt_pg = ('uuid',) else: tgt = ptr.get_target(ctx.schema) assert tgt tgt_pg = pgtypes.pg_type_from_object(ctx.schema, tgt) _try_inject_type_cast(rel, index, pgast.TypeName(name=tgt_pg)) def _try_inject_type_cast( rel: pgast.BaseRelation, pos: int, ty: pgast.TypeName, ): """ If a relation is simple, injects type annotation for a column. This is needed for Postgres to correctly infer the type so it will be able to bind to correct parameter types. For example: INSERT x (a, b) VALUES ($1, $2) is compiled into something like: WITH cte AS (VALUES ($1, $2)) INSERT x (a, b) SELECT * FROM cte This function adds type casts into `cte`. 
""" if not isinstance(rel, pgast.SelectStmt): return if rel.values: for row_i, row in enumerate(rel.values): if isinstance(row, pgast.ImplicitRowExpr) and pos < len(row.args): args = list(row.args) args[pos] = pgast.TypeCast(arg=args[pos], type_name=ty) rel.values[row_i] = row.replace(args=args) elif rel.target_list and pos < len(rel.target_list): target = rel.target_list[pos] rel.target_list[pos] = target.replace( val=pgast.TypeCast(arg=target.val, type_name=ty) ) def merge_params( sql_result: pgcompiler.CompileResult, ir_stmt: irast.Statement, ctx: Context ): # Merge the params produced by the main compiler with params for the rest of # the query that the resolved is keeping track of. param_remapping: dict[int, int] = {} for arg_name, arg in sql_result.argmap.items(): # find the global glob = next(g for g in ir_stmt.globals if g.name == arg_name) # search for existing params for this global existing_param = next( ( p for p in ctx.query_params if isinstance(p, dbstate.SQLParamGlobal) and p.global_name == glob.global_name ), None, ) internal_index: int if existing_param is not None: internal_index = existing_param.internal_index else: # append a new param internal_index = len(ctx.query_params) + 1 pg_type = pgtypes.pg_type_from_ir_typeref( glob.ir_type.base_type or glob.ir_type ) ctx.query_params.append( dbstate.SQLParamGlobal( global_name=glob.global_name, pg_type=pg_type, is_permission=glob.is_permission, internal_index=internal_index, ) ) # remap if necessary if internal_index != arg.index: param_remapping[arg.index] = internal_index if len(param_remapping) > 0: ParamMapper(param_remapping).visit(sql_result.ast) class ParamMapper(ast.NodeVisitor): def __init__(self, mapping: dict[int, int]) -> None: super().__init__() self.mapping = mapping def visit_ParamRef(self, p: pgast.ParamRef) -> None: p.number = self.mapping[p.number] def init_external_params(query: pgast.Base, ctx: Context): counter = ParamCounter() counter.node_visit(query) for _ in range(0, 
counter.param_count - len(ctx.options.normalized_params)): ctx.query_params.append(dbstate.SQLParamExternal()) for param_type_oid in ctx.options.normalized_params: ctx.query_params.append(dbstate.SQLParamExtractedConst( type_oid=param_type_oid )) class ParamCounter(ast.NodeVisitor): def __init__(self) -> None: super().__init__() self.param_count = 0 def visit_ParamRef(self, p: pgast.ParamRef) -> None: if self.param_count < p.number: self.param_count = p.number def fini_external_params(ctx: Context): for param in ctx.query_params: if ( not param.used and isinstance(param, dbstate.SQLParamExtractedConst) and param.type_oid == pg_parser.PgLiteralTypeOID.UNKNOWN ): param.type_oid = pg_parser.PgLiteralTypeOID.TEXT ================================================ FILE: edb/pgsql/resolver/context.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#

from __future__ import annotations

from copy import deepcopy
from typing import Optional, Sequence, Mapping, Any
from dataclasses import dataclass, field
import enum
import uuid

from edb.pgsql import ast as pgast
from edb.pgsql.compiler import aliases

from edb.common import compiler

from edb.server.compiler import dbstate, enums

from edb.schema import schema as s_schema
from edb.schema import objects as s_objects
from edb.schema import pointers as s_pointers


@dataclass(frozen=True, kw_only=True, repr=False, match_args=False)
class Options:
    current_database: str

    current_query: str

    # schemas that will be searched when idents don't have an explicit one
    search_path: Sequence[str]

    # allow setting id in inserts
    allow_user_specified_id: bool

    # apply access policies to select & dml statements
    apply_access_policies: bool

    # whether to generate an EdgeQL-compatible single-column output variant.
    include_edgeql_io_format_alternative: Optional[bool]

    # makes sure that output does not contain duplicated column names
    disambiguate_column_names: bool

    # Type oids of parameters that have taken the place of constants during
    # query normalization.
    # When this is non-empty, the resolver is allowed to raise
    # DisableNormalization to recompile the query without normalization.
    normalized_params: list[int]

    # Apply a limit to the number of rows in the top-level query
    implicit_limit: Optional[int]


@dataclass(kw_only=True)
class Scope:
    """
    Information about which objects are visible at a specific point in an SQL
    query.

    Scope is modified during resolving of a query, when new tables are
    discovered in FROM or JOIN or new columns declared in SELECT's projection.

    After a query is done resolving, resulting relations are extracted from
    its scope and inserted into parent scope.
""" # RangeVars (table instances) in this query tables: list[Table] = field(default_factory=lambda: []) # Common Table Expressions ctes: list[CTE] = field(default_factory=lambda: []) # Pairs of columns of the same name that have been compared in a USING # clause. This makes unqualified references to their name them un-ambiguous. # The fourth tuple element is the join type. factored_columns: list[tuple[str, Table, Table, str]] = field( default_factory=lambda: [] ) @dataclass(kw_only=True) class Table: # The schema id of the object that is the source of this table schema_id: Optional[uuid.UUID] = None # Public SQL name: Optional[str] = None alias: Optional[str] = None columns: list[Column] = field(default_factory=lambda: []) # Internal SQL reference_as: Optional[str] = None # For ambiguous references, this fields determines lookup order. # Higher value is matched before lower. # Aliases from current relation have higher precedence in GROUP BY # than columns of input rel vars (tables). # Columns from parent scopes have lower precedence # than columns of input rel vars (tables). precedence: int = 0 # True when this relation is compiled to a direct reference to the # underlying table, without any views or CTEs. # Is the condition for usage of locking clauses. 
    is_direct_relation: bool = False

    def __str__(self) -> str:
        columns = ', '.join(str(c) for c in self.columns)
        alias = f'{self.alias} = ' if self.alias else ''
        return f'{alias}{self.name or ""}({columns})'


@dataclass(kw_only=True)
class CTE:
    name: Optional[str] = None
    columns: list[Column] = field(default_factory=lambda: [])


@dataclass(kw_only=True)
class Column:
    # Public SQL
    name: str

    # When true, column is not included when selecting *
    # Used for system columns
    # https://www.postgresql.org/docs/14/ddl-system-columns.html
    hidden: bool = False

    kind: ColumnKind

    def __str__(self) -> str:
        return self.name or ''


class ColumnKind:
    # When a column is referenced, the implementation of this class
    # determines what the reference is compiled to.
    # The base case is ColumnByName, which just means that it compiles to an
    # identifier of a column.
    pass


@dataclass(kw_only=True)
class ColumnByName(ColumnKind):
    # Internal SQL column name
    reference_as: str


@dataclass(kw_only=True)
class ColumnStaticVal(ColumnKind):
    # Value that can be used instead of referencing the column.
    # Used for __type__ only, so that's why it is UUID (for now).
    val: uuid.UUID


@dataclass(kw_only=True)
class ColumnComputable(ColumnKind):
    # An EdgeQL computable property. To get the AST for this column, EdgeQL
    # compiler needs to be invoked.
    pointer: s_pointers.Pointer


@dataclass(kw_only=True)
class ColumnPgExpr(ColumnKind):
    # Value that was provided by some special resolver path.
    expr: pgast.BaseExpr


@dataclass(kw_only=True, eq=False, slots=True, repr=False)
class CompiledDML:
    # name of the CTE that provides the DML value
    value_cte_name: str

    # relation that provides the DML value. not yet resolved.
    value_relation_input: pgast.BaseRelation

    # columns that are expected to be produced by the value relation
    # contains: column name, ptr name, is_link
    value_columns: list[tuple[str, str, bool]]

    # name of the column in the value relation that should provide the identity
    value_iterator_name: Optional[str]

    # for INSERTs, relation that provides values for UPDATE that happens ON
    # CONFLICT. not yet resolved
    conflict_update_input: Optional[pgast.BaseRelation] = None

    # for INSERTs, name of CTE that provides values for UPDATE that happens ON
    # CONFLICT
    conflict_update_name: Optional[str] = None

    conflict_update_iterator: Optional[str] = None
    subject_id: Optional[Any] = None
    subject_columns: list[tuple[str, Any]] | None = None
    value_id: Optional[Any] = None
    env: Optional[Any] = None

    # CTEs that perform the operation
    output_ctes: list[pgast.CommonTableExpr]

    # name of the CTE that contains the output of the insert
    output_relation_name: str

    # mapping from output column names into output vars
    output_namespace: Mapping[str, pgast.BaseExpr]


class ContextSwitchMode(enum.Enum):
    EMPTY = enum.auto()
    CHILD = enum.auto()
    LATERAL = enum.auto()


@dataclass(kw_only=True)
class Environment:
    """Static compilation environment."""

    # Capabilities required by the query
    capabilities: enums.Capability = enums.Capability.NONE


class ResolverContextLevel(compiler.ContextLevel):
    # Compilation environment common for all context levels.
    env: Environment

    schema: s_schema.Schema
    alias_generator: aliases.AliasGenerator

    # Visible names in scope
    scope: Scope

    # 0 for top-level statement, 1 for its CTEs/sub-relations/links
    # and so on for all subqueries.
    subquery_depth: int

    # List of CTEs to add to the top-level statement.
    # This is used, for example, by DML compilation to ensure that all DML is
    # in the top-level WITH binding.
    ctes_buffer: list[pgast.CommonTableExpr]

    # A mapping from objects to CTEs that provide an "inheritance view",
    # which is basically a union of all of their descendants' tables.
inheritance_ctes: dict[s_objects.InheritingObject, str] compiled_dml: Mapping[pgast.Query, CompiledDML] options: Options query_params: list[dbstate.SQLParam] """List of params needed by the compiled query. Gets populated during compilation and also includes params needed for globals, from calls to ql compiler.""" def __init__( self, prevlevel: Optional[ResolverContextLevel], mode: ContextSwitchMode, *, schema: Optional[s_schema.Schema] = None, options: Optional[Options] = None, ) -> None: if prevlevel is None: assert schema assert options self.env = Environment() self.schema = schema self.options = options self.scope = Scope() self.alias_generator = aliases.AliasGenerator() self.subquery_depth = 0 self.ctes_buffer = [] self.inheritance_ctes = dict() self.compiled_dml = dict() self.query_params = [] else: self.env = prevlevel.env self.schema = prevlevel.schema self.options = prevlevel.options self.alias_generator = prevlevel.alias_generator self.subquery_depth = prevlevel.subquery_depth + 1 self.ctes_buffer = prevlevel.ctes_buffer self.inheritance_ctes = prevlevel.inheritance_ctes self.compiled_dml = prevlevel.compiled_dml self.query_params = prevlevel.query_params if mode == ContextSwitchMode.EMPTY: self.scope = Scope(ctes=prevlevel.scope.ctes) elif mode == ContextSwitchMode.CHILD: self.scope = deepcopy(prevlevel.scope) for t in self.scope.tables: t.precedence -= 1 elif mode == ContextSwitchMode.LATERAL: self.scope = deepcopy(prevlevel.scope) def empty( self, ) -> compiler.CompilerContextManager[ResolverContextLevel]: """Create a new empty context""" return self.new(ContextSwitchMode.EMPTY) def child(self) -> compiler.CompilerContextManager[ResolverContextLevel]: """Clone current context, prevent changes from leaking to parent""" return self.new(ContextSwitchMode.CHILD) def lateral(self) -> compiler.CompilerContextManager[ResolverContextLevel]: """Clone current context, prevent changes from leaking to parent""" return self.new(ContextSwitchMode.LATERAL) class 
ResolverContext(compiler.CompilerContext[ResolverContextLevel]): ContextLevelClass = ResolverContextLevel default_mode = ContextSwitchMode.EMPTY ================================================ FILE: edb/pgsql/resolver/dispatch.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import functools import typing import re from edb.server.pgcon import errors as pgerror from edb.pgsql import ast as pgast from edb import errors from . 
import context @functools.singledispatch def _resolve( expr: pgast.Base, *, ctx: context.ResolverContextLevel ) -> pgast.Base: _raise_unsupported(expr) def resolve[Base_T: pgast.Base]( expr: Base_T, *, ctx: context.ResolverContextLevel ) -> Base_T: res = _resolve(expr, ctx=ctx) return typing.cast(Base_T, res.replace(span=expr.span)) def resolve_opt[Base_T: pgast.Base]( node: typing.Optional[Base_T], *, ctx: context.ResolverContextLevel ) -> typing.Optional[Base_T]: if not node: return None return resolve(node, ctx=ctx) def resolve_list[Base_T: pgast.Base]( exprs: typing.Sequence[Base_T], *, ctx: context.ResolverContextLevel ) -> list[Base_T]: return [resolve(e, ctx=ctx) for e in exprs] def resolve_opt_list[Base_T: pgast.Base]( exprs: typing.Optional[list[Base_T]], *, ctx: context.ResolverContextLevel, ) -> typing.Optional[list[Base_T]]: if not exprs: return None return resolve_list(exprs, ctx=ctx) def resolve_relation( rel: pgast.BaseRelation, *, include_inherited: bool = True, ctx: context.ResolverContextLevel, ) -> tuple[pgast.BaseRelation, context.Table]: rel, tab = _resolve_relation( rel, include_inherited=include_inherited, ctx=ctx ) return rel.replace(span=rel.span), tab @functools.singledispatch def _resolve_relation( rel: pgast.BaseRelation, *, include_inherited: bool, ctx: context.ResolverContextLevel, ) -> tuple[pgast.BaseRelation, context.Table]: _raise_unsupported(rel) @_resolve.register def _resolve_BaseRelation( rel: pgast.BaseRelation, *, ctx: context.ResolverContextLevel ) -> pgast.BaseRelation: # use _resolve_BaseRelation in normal _resolve dispatch rel, _ = resolve_relation(rel, ctx=ctx) return rel def _raise_unsupported(expr: pgast.Base) -> typing.Never: pretty_name = expr.__class__.__name__ pretty_name = pretty_name.removesuffix('Stmt') # title case to spaces pretty_name = re.sub(r'(? 
Optional[str]:
    if res_target.name:
        return res_target.name

    val = res_target.val

    if isinstance(val, pgast.TypeCast):
        val = val.arg

    if isinstance(val, pgast.FuncCall):
        return val.name[-1]

    if isinstance(val, pgast.ImplicitRowExpr):
        return 'row'

    # if just name has been selected, use it as the alias
    if isinstance(val, pgast.ColumnRef):
        name = val.name
        if isinstance(name[-1], str):
            return name[-1]

    return None


# this function cannot go through dispatch,
# because it may return multiple nodes, due to * notation
def resolve_ResTarget(
    res_target: pgast.ResTarget,
    *,
    existing_names: set[str],
    ctx: Context,
) -> tuple[Sequence[pgast.ResTarget], Sequence[context.Column]]:
    targets, columns = _resolve_ResTarget(
        res_target, existing_names=existing_names, ctx=ctx
    )

    return (targets, columns)


def _resolve_ResTarget(
    res_target: pgast.ResTarget,
    *,
    existing_names: set[str],
    ctx: Context,
) -> tuple[Sequence[pgast.ResTarget], Sequence[context.Column]]:
    alias = infer_alias(res_target)

    # special case for ColumnRef for handling wildcards
    if not alias and isinstance(res_target.val, pgast.ColumnRef):
        col_res = _lookup_column(res_target.val, ctx)

        res = []
        columns = []
        for table, column in col_res:
            val = resolve_column_kind(table, column.kind, ctx=ctx)

            # make sure name is not duplicated
            # this behavior is technically different than Postgres, but EdgeDB
            # protocol does not support duplicate names. And we doubt that
            # anyone is depending on original behavior.
nam: str = column.name if nam in existing_names: # prefix with table name rel_var_name = table.alias or table.name if rel_var_name: nam = rel_var_name + '_' + nam if nam in existing_names: if ctx.options.disambiguate_column_names: raise errors.QueryError( f'duplicate column name: `{nam}`', span=res_target.span, pgext_code=pgerror.ERROR_UNDEFINED_COLUMN, ) existing_names.add(nam) res.append( pgast.ResTarget( name=nam, val=val, ) ) columns.append( context.Column( name=nam, hidden=column.hidden, kind=column.kind, ) ) return (res, columns) # base case val = dispatch.resolve(res_target.val, ctx=ctx) # special case for statically-evaluated FuncCall if ( not alias and isinstance(val, pgast.StringConstant) and isinstance(res_target.val, pgast.FuncCall) ): alias = static.name_in_pg_catalog(res_target.val.name) if alias in existing_names: # duplicate name if res_target.name: # explicit duplicate name: error out if ctx.options.disambiguate_column_names: raise errors.QueryError( f'duplicate column name: `{alias}`', span=res_target.span, pgext_code=pgerror.ERROR_UNDEFINED_COLUMN, ) else: # inferred duplicate name: use generated alias instead # this behavior is technically different than Postgres, but it is # also not documented and users should not be relying on it. # It does help us in some cases # (passing `SELECT a.id, b.id` into DML). 
alias = None name: str = alias or ctx.alias_generator.get('col') existing_names.add(name) col = context.Column( name=name, kind=context.ColumnByName(reference_as=name) ) new_target = pgast.ResTarget(val=val, name=name, span=res_target.span) return (new_target,), (col,) def resolve_column_kind( table: context.Table, column: context.ColumnKind, *, ctx: Context ) -> pgast.BaseExpr: match column: case context.ColumnByName(reference_as=reference_as): if table.reference_as: return pgast.ColumnRef(name=(table.reference_as, reference_as)) else: # In some cases tables might not have an assigned alias # because that is not syntactically possible (COPY), or because # the table being referenced is currently being assembled # (e.g. ORDER BY refers to a newly defined column). # So we make an assumption that in such cases, this will not # be ambiguous. I think this is not strictly correct. return pgast.ColumnRef(name=(reference_as,)) case context.ColumnStaticVal(val=val): # special case: __type__ static value return _uuid_const(val) case context.ColumnPgExpr(expr=e): return e case context.ColumnComputable(pointer=pointer): expr = pointer.get_expr(ctx.schema) assert expr source = pointer.get_source(ctx.schema) subject_id: irast.PathId source_id: irast.PathId if isinstance(source, s_types.Type): subject_id = irast.PathId.from_type( ctx.schema, source, env=None ) source_id = subject_id else: assert isinstance(source, s_pointers.Pointer) subject_id = irast.PathId.from_pointer( ctx.schema, source, env=None ) s = source.get_source(ctx.schema) assert isinstance(s, s_types.Type) source_id = irast.PathId.from_type(ctx.schema, s, env=None) singletons = [source] options = qlcompiler.CompilerOptions( modaliases={None: 'default'}, anchors={'__source__': source}, path_prefix_anchor='__source__', singletons=singletons, make_globals_empty=False, apply_user_access_policies=ctx.options.apply_access_policies, ) compiled = expr.compiled(ctx.schema, options=options, context=None) subject_rel = 
pgast.Relation(name=table.reference_as) if isinstance(source, s_types.Type): subject_rel.path_outputs = { (source_id, pgce.PathAspect.IDENTITY): pgast.ColumnRef( name=('id',) ) } else: subject_rel.path_outputs = { (source_id, pgce.PathAspect.IDENTITY): pgast.ColumnRef( name=('source',) ) } subject_rel_var = pgast.RelRangeVar( alias=pgast.Alias(aliasname=table.reference_as), relation=subject_rel, ) sql_tree = pgcompiler.compile_ir_to_sql_tree( compiled.irast, external_rvars={ (subject_id, pgce.PathAspect.SOURCE): subject_rel_var, (subject_id, pgce.PathAspect.VALUE): subject_rel_var, (source_id, pgce.PathAspect.IDENTITY): subject_rel_var, }, output_format=pgcompiler.OutputFormat.NATIVE_INTERNAL, alias_generator=ctx.alias_generator, ) command.merge_params(sql_tree, compiled.irast, ctx) assert isinstance(sql_tree.ast, pgast.BaseExpr) return sql_tree.ast case _: raise NotImplementedError(column) @dispatch._resolve.register def resolve_ColumnRef( column_ref: pgast.ColumnRef, *, ctx: Context ) -> pgast.BaseExpr: res = _lookup_column(column_ref, ctx) table, column = res[0] if len(res) != 1: # Lookup can have multiple results only when using *. assert table.reference_as return pgast.ColumnRef(name=(table.reference_as, pgast.Star())) return resolve_column_kind(table, column.kind, ctx=ctx) def _uuid_const(val: uuid.UUID): return pgast.TypeCast( arg=pgast.StringConstant(val=str(val)), type_name=pgast.TypeName(name=('uuid',)), ) def _lookup_column( column_ref: pgast.ColumnRef, ctx: Context, ) -> Sequence[tuple[context.Table, context.Column]]: matched_columns: list[tuple[context.Table, context.Column]] = [] name = column_ref.name col_name: str | pgast.Star if len(name) == 1: # look for the column in all tables col_name = name[0] if isinstance(col_name, pgast.Star): return [ (t, c) for t in ctx.scope.tables # Only look at the highest precedence level for # *. That is, we take everything in our local FROM # clauses but not stuff in enclosing queries, if we # are a subquery. 
if t.precedence == 0 for c in t.columns if not c.hidden ] for table in ctx.scope.tables: matched_columns.extend(_lookup_in_table(col_name, table)) if not matched_columns: # is it a reference to a rel var? try: tab = _lookup_table(col_name, ctx) assert tab.reference_as col = context.Column( name=tab.reference_as, kind=context.ColumnByName(reference_as=tab.reference_as), ) return [(context.Table(), col)] except errors.QueryError: pass elif len(name) >= 2: # look for the column in the specific table tab_name, col_name = name[-2:] try: table = _lookup_table(cast(str, tab_name), ctx) except errors.QueryError as e: e.set_span(column_ref.span) raise if isinstance(col_name, pgast.Star): return [(table, c) for c in table.columns if not c.hidden] else: matched_columns.extend(_lookup_in_table(col_name, table)) if not matched_columns: raise errors.QueryError( f'column {qi(col_name, force=True)} does not exist', span=column_ref.span, pgext_code=pgerror.ERROR_UNDEFINED_COLUMN, ) # apply precedence if len(matched_columns) > 1: max_precedence = max(t.precedence for t, _ in matched_columns) matched_columns = [ (t, c) for t, c in matched_columns if t.precedence == max_precedence ] # when ambiguous references have been used in USING clause, # we resolve them to first or the second column or a COALESCE of the two. 
if ( len(matched_columns) == 2 and matched_columns[0][1].name == matched_columns[1][1].name ): matched_name = matched_columns[0][1].name matched_tables = [t for t, _c in matched_columns] for c_name, t_left, t_right, join_type in ctx.scope.factored_columns: if matched_name != c_name: continue if not (t_left in matched_tables and t_right in matched_tables): continue c_left = next(c for c in t_left.columns if c.name == c_name) c_right = next(c for c in t_right.columns if c.name == c_name) if join_type == 'INNER' or join_type == 'LEFT': matched_columns = [(t_left, c_left)] elif join_type == 'RIGHT': matched_columns = [(t_right, c_right)] elif join_type == 'FULL': coalesce = pgast.CoalesceExpr( args=[ resolve_column_kind(t_left, c_left.kind, ctx=ctx), resolve_column_kind(t_right, c_right.kind, ctx=ctx), ] ) c_coalesce = context.Column( name=c_name, kind=context.ColumnPgExpr(expr=coalesce), ) matched_columns = [(t_left, c_coalesce)] else: raise NotImplementedError() break if len(matched_columns) > 1: potential_tables = ', '.join([t.name or '' for t, _ in matched_columns]) raise errors.QueryError( f'ambiguous column `{col_name}` could belong to ' f'following tables: {potential_tables}', span=column_ref.span, ) return matched_columns def _lookup_in_table( col_name: str, table: context.Table ) -> Iterator[tuple[context.Table, context.Column]]: for column in table.columns: if column.name == col_name: yield (table, column) def _maybe_lookup_table(tab_name: str, ctx: Context) -> context.Table | None: matched_tables: list[context.Table] = [] for t in ctx.scope.tables: t_name = t.alias or t.name if t_name == tab_name: matched_tables.append(t) if not matched_tables: return None # apply precedence if len(matched_tables) > 1: max_precedence = max(t.precedence for t in matched_tables) matched_tables = [ t for t in matched_tables if t.precedence == max_precedence ] if len(matched_tables) > 1: raise errors.QueryError(f'ambiguous table `{tab_name}`') table = matched_tables[0] return 
table def _lookup_table(tab_name: str, ctx: Context) -> context.Table: table = _maybe_lookup_table(tab_name, ctx=ctx) if table is None: raise errors.QueryError(f'cannot find table `{tab_name}`') return table @dispatch._resolve.register def resolve_SubLink( sub_link: pgast.SubLink, *, ctx: Context, ) -> pgast.SubLink: with ctx.child() as subctx: expr = dispatch.resolve(sub_link.expr, ctx=subctx) return pgast.SubLink( operator=sub_link.operator, expr=expr, test_expr=dispatch.resolve_opt(sub_link.test_expr, ctx=ctx), ) @dispatch._resolve.register def resolve_Expr(expr: pgast.Expr, *, ctx: Context) -> pgast.Expr: return pgast.Expr( name=expr.name, lexpr=dispatch.resolve(expr.lexpr, ctx=ctx) if expr.lexpr else None, rexpr=dispatch.resolve(expr.rexpr, ctx=ctx) if expr.rexpr else None, ) @dispatch._resolve.register def resolve_TypeCast( expr: pgast.TypeCast, *, ctx: Context, ) -> pgast.BaseExpr: pg_catalog_name = static.name_in_pg_catalog(expr.type_name.name) if pg_catalog_name == 'regclass' and not expr.type_name.array_bounds: return static.cast_to_regclass(expr.arg, ctx) return pgast.TypeCast( arg=dispatch.resolve(expr.arg, ctx=ctx), type_name=expr.type_name, ) @dispatch._resolve.register def resolve_BaseConstant( expr: pgast.BaseConstant, *, ctx: Context, ) -> pgast.BaseConstant: return expr @dispatch._resolve.register def resolve_CaseExpr( expr: pgast.CaseExpr, *, ctx: Context, ) -> pgast.CaseExpr: return pgast.CaseExpr( arg=dispatch.resolve_opt(expr.arg, ctx=ctx), args=dispatch.resolve_list(expr.args, ctx=ctx), defresult=dispatch.resolve_opt(expr.defresult, ctx=ctx), ) @dispatch._resolve.register def resolve_CaseWhen( expr: pgast.CaseWhen, *, ctx: Context, ) -> pgast.CaseWhen: return pgast.CaseWhen( expr=dispatch.resolve(expr.expr, ctx=ctx), result=dispatch.resolve(expr.result, ctx=ctx), ) @dispatch._resolve.register def resolve_SortBy( expr: pgast.SortBy, *, ctx: Context, ) -> pgast.SortBy: return pgast.SortBy( node=dispatch.resolve(expr.node, ctx=ctx), 
dir=expr.dir, nulls=expr.nulls, ) @dispatch._resolve.register def resolve_LockingClause( expr: pgast.LockingClause, *, ctx: Context, ) -> pgast.LockingClause: tables: list[context.Table] = [] if expr.locked_rels is not None: for rvar in expr.locked_rels: assert rvar.relation.name table = _lookup_table(rvar.relation.name, ctx=ctx) tables.append(table) else: tables.extend(ctx.scope.tables) # validate that the locking clause can be used on these tables for table in tables: if table.schema_id and not table.is_direct_relation: raise errors.QueryError( f'locking clause not supported: `{table.name or table.alias}` ' 'must not have child types or access policies', pgext_code=pgerror.ERROR_FEATURE_NOT_SUPPORTED, ) return pgast.LockingClause( strength=expr.strength, locked_rels=[ pgast.RelRangeVar(relation=pgast.Relation(name=table.reference_as)) for table in tables ], wait_policy=expr.wait_policy, ) func_calls_remapping: dict[tuple[str, ...], tuple[str, ...]] = { ('information_schema', '_pg_truetypid'): ( common.versioned_schema('edgedbsql'), '_pg_truetypid', ), ('information_schema', '_pg_truetypmod'): ( common.versioned_schema('edgedbsql'), '_pg_truetypmod', ), ('pg_catalog', 'format_type'): ( common.versioned_schema('edgedbsql'), '_format_type', ), ('format_type',): ( common.versioned_schema('edgedbsql'), '_format_type', ), ('pg_catalog', 'pg_get_constraintdef'): ( common.versioned_schema('edgedbsql'), 'pg_get_constraintdef', ), ('pg_get_constraintdef',): ( common.versioned_schema('edgedbsql'), 'pg_get_constraintdef', ), } funcs_with_text_args: set[str] = { 'num_nulls', 'num_nonnulls', 'int8inc_any', 'int8dec_any', 'pg_typeof', 'pg_collation_for', 'concat', 'concat_ws', 'format', 'count', 'pg_column_size', 'json_build_array', 'jsonb_build_array', 'json_build_object', 'jsonb_build_object', 'json_object_agg', 'jsonb_object_agg', 'json_object_agg_strict', 'jsonb_object_agg_strict', 'json_object_agg_unique', 'jsonb_object_agg_unique', 'json_object_agg_unique_strict', 
'jsonb_object_agg_unique_strict', } @dispatch._resolve.register def resolve_FuncCall( call: pgast.FuncCall, *, ctx: Context, ) -> pgast.BaseExpr: # Special case: some function calls (mostly from pg_catalog) are # intercepted and statically evaluated. if res := static.eval_FuncCall(call, ctx=ctx): return res # Remap function name and default to the original name. # Effectively, this exposes all non-remapped functions. name = func_calls_remapping.get(call.name, call.name) args = dispatch.resolve_list(call.args, ctx=ctx) # If arg is a param, add type annotations, so function can be resolved. # For example, `json_build_object($1, $2)` must be injected with annotations # See maybe_annotate_param for more info name_in_pg = static.name_in_pg_catalog(call.name) unknown_as = 'text' if name_in_pg in funcs_with_text_args else 'unknown' args = [ maybe_annotate_param(a, unknown_as=unknown_as, ctx=ctx) for a in args ] res = pgast.FuncCall( name=name, args=args, agg_order=dispatch.resolve_opt_list(call.agg_order, ctx=ctx), agg_filter=dispatch.resolve_opt(call.agg_filter, ctx=ctx), agg_star=call.agg_star, agg_distinct=call.agg_distinct, agg_within_group=call.agg_within_group, over=dispatch.resolve_opt(call.over, ctx=ctx), with_ordinality=call.with_ordinality, ) return res @dispatch._resolve.register def resolve_WindowDef( expr: pgast.WindowDef, *, ctx: Context, ) -> pgast.WindowDef: return pgast.WindowDef( partition_clause=dispatch.resolve_opt_list( expr.partition_clause, ctx=ctx ), order_clause=dispatch.resolve_opt_list(expr.order_clause, ctx=ctx), start_offset=dispatch.resolve_opt(expr.start_offset, ctx=ctx), end_offset=dispatch.resolve_opt(expr.end_offset, ctx=ctx), ) @dispatch._resolve.register def resolve_CoalesceExpr( expr: pgast.CoalesceExpr, *, ctx: Context, ) -> pgast.CoalesceExpr: return pgast.CoalesceExpr(args=dispatch.resolve_list(expr.args, ctx=ctx)) @dispatch._resolve.register def resolve_NullTest( expr: pgast.NullTest, *, ctx: Context, ) -> pgast.NullTest: return 
pgast.NullTest( arg=dispatch.resolve(expr.arg, ctx=ctx), negated=expr.negated ) @dispatch._resolve.register def resolve_BooleanTest( expr: pgast.BooleanTest, *, ctx: Context, ) -> pgast.BooleanTest: return pgast.BooleanTest( arg=dispatch.resolve(expr.arg, ctx=ctx), negated=expr.negated, is_true=expr.is_true, ) @dispatch._resolve.register def resolve_ImplicitRowExpr( expr: pgast.ImplicitRowExpr, *, ctx: Context, ) -> pgast.ImplicitRowExpr: return pgast.ImplicitRowExpr( args=dispatch.resolve_list(expr.args, ctx=ctx), ) @dispatch._resolve.register def resolve_RowExpr( expr: pgast.RowExpr, *, ctx: Context, ) -> pgast.RowExpr: return construct_row_expr( dispatch.resolve_list(expr.args, ctx=ctx), ctx=ctx, ) def construct_row_expr( args: Iterable[pgast.BaseExpr], *, ctx: Context ) -> pgast.RowExpr: # Constructs a ROW and maybe injects type casts for params. return pgast.RowExpr( args=[maybe_annotate_param(a, unknown_as='text', ctx=ctx) for a in args] ) def maybe_annotate_param( expr: pgast.BaseExpr, *, unknown_as: str = 'unknown', ctx: Context, ): # If the expression is a param whose type is `unknown`, we need to inject a # type cast that passes this information to Postgres. # Ideally, we could inject type `unknown`, but that would not work for some # cases, such as ROW('hello') or json_build_object('hello', TRUE). # So for special cases, where string literals are known to represent text, # we inject text and otherwise, we inject unknown. 
if isinstance(expr, pgast.ParamRef): param = ctx.query_params[expr.number - 1] if ( isinstance(param, dbstate.SQLParamExtractedConst) and param.type_oid == pg_parser.PgLiteralTypeOID.UNKNOWN ): return pgast.TypeCast( arg=expr, type_name=pgast.TypeName(name=(unknown_as,)) ) return expr @dispatch._resolve.register def resolve_ParamRef( expr: pgast.ParamRef, *, ctx: Context, ) -> pgast.ParamRef: # external params map one-to-one to internal params if expr.number < 1: raise errors.QueryError( f'param out of bounds: ${expr.number}', pgext_code=pgerror.ERROR_UNDEFINED_PARAMETER, hint='query parameters start with 1', ) param = ctx.query_params[expr.number - 1] param.used = True return expr @dispatch._resolve.register def resolve_ArrayExpr( expr: pgast.ArrayExpr, *, ctx: Context, ) -> pgast.ArrayExpr: return pgast.ArrayExpr( elements=dispatch.resolve_list(expr.elements, ctx=ctx) ) @dispatch._resolve.register def resolve_Indirection( expr: pgast.Indirection, *, ctx: Context, ) -> pgast.Indirection: return pgast.Indirection( arg=dispatch.resolve(expr.arg, ctx=ctx), indirection=dispatch.resolve_list(expr.indirection, ctx=ctx), ) @dispatch._resolve.register def resolve_RecordIndirectionOp( expr: pgast.RecordIndirectionOp, *, ctx: Context, ) -> pgast.RecordIndirectionOp: return expr @dispatch._resolve.register def resolve_Slice( expr: pgast.Slice, *, ctx: Context, ) -> pgast.Slice: return pgast.Slice( lidx=dispatch.resolve_opt(expr.lidx, ctx=ctx), ridx=dispatch.resolve_opt(expr.ridx, ctx=ctx), ) @dispatch._resolve.register def resolve_Index( expr: pgast.Index, *, ctx: Context, ) -> pgast.Index: return pgast.Index( idx=dispatch.resolve(expr.idx, ctx=ctx), ) @dispatch._resolve.register def resolve_SQLValueFunction( expr: pgast.SQLValueFunction, *, ctx: Context, ) -> pgast.BaseExpr: return static.eval_SQLValueFunction(expr, ctx=ctx) @dispatch._resolve.register def resolve_CollateClause( expr: pgast.CollateClause, *, ctx: Context, ) -> pgast.BaseExpr: return pgast.CollateClause( 
arg=dispatch.resolve(expr.arg, ctx=ctx), collname=expr.collname ) @dispatch._resolve.register def resolve_MinMaxExpr( expr: pgast.MinMaxExpr, *, ctx: Context, ) -> pgast.BaseExpr: return pgast.MinMaxExpr( op=expr.op, args=dispatch.resolve_list(expr.args, ctx=ctx), ) ================================================ FILE: edb/pgsql/resolver/range_functions.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# """Declarations of supported range functions""" COLUMNS = { 'json_array_elements': ['value'], 'json_array_elements_text': ['value'], 'json_each': ['key', 'value'], 'json_each_text': ['key', 'value'], 'jsonb_array_elements': ['value'], 'jsonb_array_elements_text': ['value'], 'jsonb_each': ['key', 'value'], 'jsonb_each_text': ['key', 'value'], 'pg_available_extension_versions': [ 'name', 'version', 'superuser', 'trusted', 'relocatable', 'schema', 'requires', 'comment', ], 'pg_available_extensions': ['name', 'default_version', 'comment'], 'pg_config': ['name', 'setting'], 'pg_control_checkpoint': [ 'checkpoint_lsn', 'redo_lsn', 'redo_wal_file', 'timeline_id', 'prev_timeline_id', 'full_page_writes', 'next_xid', 'next_oid', 'next_multixact_id', 'next_multi_offset', 'oldest_xid', 'oldest_commit_ts_xid', 'checkpoint_time', 'newest_commit_ts_xid', 'oldest_multi_dbid', 'oldest_multi_xid', 'oldest_active_xid', 'oldest_xid_dbid', ], 'pg_control_init': [ 'max_data_alignment', 'database_block_size', 'blocks_per_segment', 'wal_block_size', 'bytes_per_wal_segment', 'max_identifier_length', 'max_index_columns', 'max_toast_chunk_size', 'large_object_chunk_size', 'float8_pass_by_value', 'data_page_checksum_version', ], 'pg_control_recovery': [ 'min_recovery_end_lsn', 'min_recovery_end_timeline', 'backup_start_lsn', 'backup_end_lsn', 'end_of_backup_record_required', ], 'pg_control_system': [ 'pg_control_version', 'catalog_version_no', 'system_identifier', 'pg_control_last_modified', ], 'pg_copy_logical_replication_slot': [ 'slot_name', 'slot_name', 'lsn', 'slot_name', 'lsn', 'lsn', ], 'pg_copy_physical_replication_slot': [ 'slot_name', 'lsn', 'slot_name', 'lsn', ], 'pg_create_logical_replication_slot': ['slot_name', 'lsn'], 'pg_create_physical_replication_slot': ['slot_name', 'lsn'], 'pg_cursor': [ 'name', 'statement', 'is_holdable', 'is_binary', 'is_scrollable', 'creation_time', ], 'pg_event_trigger_ddl_commands': [ 'classid', 'objid', 'objsubid', 'command_tag', 'object_type', 
'schema_name', 'object_identity', 'in_extension', 'command', ], 'pg_event_trigger_dropped_objects': [ 'classid', 'objid', 'objsubid', 'original', 'normal', 'is_temporary', 'object_type', 'schema_name', 'object_name', 'object_identity', 'address_names', 'address_args', ], 'pg_event_trigger_table_rewrite_oid': ['oid'], 'pg_extension_update_paths': ['source', 'target', 'path'], 'pg_get_backend_memory_contexts': [ 'name', 'ident', 'parent', 'level', 'total_bytes', 'total_nblocks', 'free_bytes', 'free_chunks', 'used_bytes', ], 'pg_get_catalog_foreign_keys': [ 'fktable', 'fkcols', 'pktable', 'pkcols', 'is_array', 'is_opt', ], 'pg_get_keywords': ['word', 'catcode', 'barelabel', 'catdesc', 'baredesc'], 'pg_get_multixact_members': ['xid', 'mode'], 'pg_get_object_address': ['classid', 'objid', 'objsubid'], 'pg_get_publication_tables': ['relid'], 'pg_get_replication_slots': [ 'slot_name', 'plugin', 'slot_type', 'datoid', 'temporary', 'active', 'active_pid', 'xmin', 'catalog_xmin', 'restart_lsn', 'confirmed_flush_lsn', 'wal_status', 'safe_wal_size', 'two_phase', ], 'pg_get_shmem_allocations': ['name', 'off', 'size', 'allocated_size'], 'pg_hba_file_rules': [ 'line_number', 'type', 'database', 'user_name', 'address', 'netmask', 'auth_method', 'options', 'error', ], 'pg_identify_object': ['type', 'schema', 'name', 'identity'], 'pg_identify_object_as_address': ['type', 'object_names', 'object_args'], 'pg_last_committed_xact': ['xid', 'timestamp', 'roident'], 'pg_lock_status': [ 'locktype', 'database', 'relation', 'page', 'tuple', 'virtualxid', 'transactionid', 'classid', 'objid', 'objsubid', 'virtualtransaction', 'pid', 'fastpath', 'waitstart', 'granted', 'mode', ], 'pg_logical_slot_get_binary_changes': ['lsn', 'xid', 'data'], 'pg_logical_slot_get_changes': ['lsn', 'xid', 'data'], 'pg_logical_slot_peek_binary_changes': ['lsn', 'xid', 'data'], 'pg_logical_slot_peek_changes': ['lsn', 'xid', 'data'], 'pg_ls_archive_statusdir': ['name', 'size', 'modification'], 'pg_ls_logdir': 
['name', 'size', 'modification'], 'pg_ls_tmpdir': [ 'name', 'size', 'name', 'modification', 'size', 'modification', ], 'pg_ls_waldir': ['name', 'size', 'modification'], 'pg_mcv_list_items': [ 'index', 'values', 'nulls', 'frequency', 'base_frequency', ], 'pg_options_to_table': ['option_name', 'option_value'], 'pg_partition_ancestors': ['relid'], 'pg_partition_tree': ['relid', 'parentrelid', 'isleaf', 'level'], 'pg_prepared_statement': [ 'name', 'statement', 'prepare_time', 'parameter_types', 'from_sql', 'generic_plans', 'custom_plans', ], 'pg_prepared_xact': ['transaction', 'gid', 'prepared', 'ownerid', 'dbid'], 'pg_replication_slot_advance': ['slot_name', 'end_lsn'], 'pg_sequence_parameters': [ 'start_value', 'minimum_value', 'maximum_value', 'increment', 'cycle_option', 'cache_size', 'data_type', ], 'pg_show_all_file_settings': [ 'sourcefile', 'sourceline', 'seqno', 'name', 'setting', 'applied', 'error', ], 'pg_show_all_settings': [ 'name', 'setting', 'unit', 'category', 'short_desc', 'extra_desc', 'context', 'vartype', 'source', 'min_val', 'max_val', 'enumvals', 'boot_val', 'reset_val', 'sourcefile', 'sourceline', 'pending_restart', ], 'pg_show_replication_origin_status': [ 'local_id', 'external_id', 'remote_lsn', 'local_lsn', ], 'pg_stat_file': [ 'size', 'size', 'access', 'modification', 'access', 'modification', 'change', 'change', 'creation', 'isdir', 'creation', 'isdir', ], 'pg_stat_get_activity': [ 'datid', 'pid', 'usesysid', 'application_name', 'state', 'query', 'wait_event_type', 'wait_event', 'xact_start', 'query_start', 'ssl', 'backend_start', 'state_change', 'client_addr', 'client_hostname', 'client_port', 'backend_xid', 'backend_xmin', 'query_id', 'leader_pid', 'gss_enc', 'gss_princ', 'gss_auth', 'ssl_issuer_dn', 'ssl_client_serial', 'ssl_client_dn', 'sslbits', 'sslcipher', 'sslversion', 'backend_type', ], 'pg_stat_get_archiver': [ 'archived_count', 'last_archived_wal', 'last_archived_time', 'failed_count', 'last_failed_wal', 'last_failed_time', 
'stats_reset', ], 'pg_stat_get_progress_info': [ 'pid', 'datid', 'relid', 'param1', 'param2', 'param3', 'param4', 'param5', 'param6', 'param7', 'param20', 'param18', 'param17', 'param16', 'param15', 'param14', 'param13', 'param12', 'param11', 'param10', 'param9', 'param8', 'param19', ], 'pg_stat_get_replication_slot': [ 'slot_name', 'spill_txns', 'spill_count', 'spill_bytes', 'stream_txns', 'stream_count', 'stream_bytes', 'total_txns', 'total_bytes', 'stats_reset', ], 'pg_stat_get_slru': [ 'name', 'blks_zeroed', 'blks_hit', 'blks_read', 'blks_written', 'blks_exists', 'flushes', 'truncates', 'stats_reset', ], 'pg_stat_get_subscription': [ 'subid', 'relid', 'pid', 'received_lsn', 'last_msg_send_time', 'last_msg_receipt_time', 'latest_end_lsn', 'latest_end_time', ], 'pg_stat_get_wal': [ 'wal_records', 'wal_fpi', 'wal_bytes', 'wal_buffers_full', 'wal_write', 'wal_sync', 'wal_write_time', 'wal_sync_time', 'stats_reset', ], 'pg_stat_get_wal_receiver': [ 'pid', 'status', 'receive_start_lsn', 'receive_start_tli', 'written_lsn', 'flushed_lsn', 'received_tli', 'last_msg_send_time', 'last_msg_receipt_time', 'latest_end_lsn', 'latest_end_time', 'slot_name', 'sender_host', 'sender_port', 'conninfo', ], 'pg_stat_get_wal_senders': [ 'pid', 'state', 'sent_lsn', 'write_lsn', 'flush_lsn', 'replay_lsn', 'write_lag', 'flush_lag', 'replay_lag', 'sync_priority', 'sync_state', 'reply_time', ], 'pg_stop_backup': ['lsn', 'labelfile', 'spcmapfile'], 'pg_timezone_abbrevs': ['abbrev', 'utc_offset', 'is_dst'], 'pg_timezone_names': ['name', 'abbrev', 'utc_offset', 'is_dst'], 'pg_walfile_name_offset': ['file_name', 'file_offset'], 'pg_xact_commit_timestamp_origin': ['timestamp', 'roident'], } # retrieved with r''' WITH procedures AS ( SELECT * FROM pg_proc WHERE proname NOT ILIKE 'ts\_%' AND proname NOT ILIKE '\_%' AND proname != 'unnest' AND proname != 'aclexplode' ), pro_args AS ( SELECT proname, UNNEST(proargnames) AS argname, UNNEST(proargmodes) AS argmode, GENERATE_SERIES(0, 10, 1) AS argn
FROM procedures ), pro_outputs AS ( SELECT * FROM pro_args WHERE argmode = 'o' ORDER BY proname, argn ) SELECT proname, ARRAY_AGG(argname) FROM pro_outputs GROUP BY proname; ''' ================================================ FILE: edb/pgsql/resolver/range_var.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """SQL resolver that compiles public SQL to internal SQL which is executable in our internal Postgres instance.""" import functools from typing import Optional, Iterable, cast from edb import errors from edb.common.parsing import Span from edb.pgsql import ast as pgast from edb.pgsql import common as pgcommon from edb.pgsql.compiler import astutils as pgastutils from . import dispatch from . import context from . import range_functions from . 
import expr Context = context.ResolverContextLevel def resolve_BaseRangeVar( range_var: pgast.BaseRangeVar, *, ctx: Context ) -> pgast.BaseRangeVar: # handle join that returns multiple tables and does not use alias if isinstance(range_var, pgast.JoinExpr): return _resolve_JoinExpr(range_var, ctx=ctx) # generate internal alias internal_alias = ctx.alias_generator.get('rel') alias = pgast.Alias( aliasname=internal_alias, colnames=range_var.alias.colnames ) # general case node, table = _resolve_range_var(range_var, alias, ctx=ctx) node = node.replace(span=range_var.span) # infer public name and internal alias table.alias = range_var.alias.aliasname table.reference_as = internal_alias # pull result relation of inner scope into outer scope ctx.scope.tables.append(table) return node @functools.singledispatch def _resolve_range_var( ir: pgast.BaseRangeVar, alias: pgast.Alias, *, ctx: context.ResolverContextLevel, ) -> tuple[pgast.BaseRangeVar, context.Table]: raise ValueError(f'no SQL resolve handler for {ir.__class__}') @_resolve_range_var.register def _resolve_RelRangeVar( range_var: pgast.RelRangeVar, alias: pgast.Alias, *, ctx: Context, ) -> tuple[pgast.BaseRangeVar, context.Table]: with ctx.child() as subctx: relation: pgast.BaseRelation | pgast.CommonTableExpr if isinstance(range_var.relation, pgast.BaseRelation): relation, table = dispatch.resolve_relation( range_var.relation, include_inherited=range_var.include_inherited, ctx=subctx, ) else: relation, cte = resolve_CommonTableExpr( range_var.relation, ctx=subctx ) table = context.Table( name=cte.name, columns=cte.columns, reference_as=cte.name, ) table.columns = [ context.Column( name=alias or col.name, hidden=col.hidden, kind=( context.ColumnByName(reference_as=alias) if alias else col.kind ), ) for col, alias in _zip_column_alias( table.columns, alias, ctx=range_var.span ) ] rel: pgast.BaseRangeVar if isinstance(relation, pgast.Relation): rel = pgast.RelRangeVar(relation=relation, alias=alias) else: assert 
isinstance(relation, pgast.Query) rel = pgast.RangeSubselect(subquery=relation, alias=alias) return (rel, table) @_resolve_range_var.register def _resolve_RangeSubselect( range_var: pgast.RangeSubselect, alias: pgast.Alias, *, ctx: Context, ) -> tuple[pgast.BaseRangeVar, context.Table]: with ctx.lateral() if range_var.lateral else ctx.child() as subctx: subquery, subtable = dispatch.resolve_relation( range_var.subquery, ctx=subctx ) result = context.Table( name=range_var.alias.aliasname, reference_as=alias.aliasname, columns=[ context.Column( name=alias or col.name, kind=context.ColumnByName( reference_as=alias if alias else col.name ), ) for col, alias in _zip_column_alias( subtable.columns, alias, ctx=range_var.span ) ], ) alias = pgast.Alias( aliasname=alias.aliasname, colnames=[ cast(context.ColumnByName, c.kind).reference_as for c in result.columns ], ) node = pgast.RangeSubselect( subquery=cast(pgast.Query, subquery), alias=alias, lateral=range_var.lateral, ) return node, result def _resolve_JoinExpr( range_var: pgast.JoinExpr, *, ctx: Context, ) -> pgast.BaseRangeVar: larg = resolve_BaseRangeVar(range_var.larg, ctx=ctx) ltable = ctx.scope.tables[len(ctx.scope.tables) - 1] assert len(range_var.joins) == 1, ( "pg resolver should always produce non-flattened joins" ) join = range_var.joins[0] rarg = resolve_BaseRangeVar(join.rarg, ctx=ctx) rtable = ctx.scope.tables[len(ctx.scope.tables) - 1] quals: Optional[pgast.BaseExpr] = None if join.quals: quals = dispatch.resolve(join.quals, ctx=ctx) if join.using_clause: for c in join.using_clause: assert len(c.name) == 1 assert isinstance(c.name[-1], str) c_name = c.name[-1] with ctx.child() as subctx: subctx.scope.tables = [ltable] l_expr = dispatch.resolve(c, ctx=subctx) with ctx.child() as subctx: subctx.scope.tables = [rtable] r_expr = dispatch.resolve(c, ctx=subctx) ctx.scope.factored_columns.append( (c_name, ltable, rtable, join.type) ) quals = pgastutils.extend_binop( quals, pgast.Expr( name='=', lexpr=l_expr, 
rexpr=r_expr, ), ) return pgast.JoinExpr( larg=larg, joins=[ pgast.JoinClause( type=join.type, rarg=rarg, quals=quals, ) ], ) def resolve_CommonTableExpr( cte: pgast.CommonTableExpr, *, ctx: Context ) -> tuple[pgast.CommonTableExpr, context.CTE]: reference_as = None with ctx.child() as subctx: aliascolnames = cte.aliascolnames if isinstance(cte.query, pgast.SelectStmt): # When no explicit column names were given, we look into the actual # select to see if we can extract the column names from that # instead. This is needed for some RECURSIVE CTEs. if not aliascolnames: if isinstance(cte.query.larg, pgast.SelectStmt): if res := _infer_col_aliases(cte.query.larg): aliascolnames = res if not aliascolnames: if isinstance(cte.query.rarg, pgast.SelectStmt): if res := _infer_col_aliases(cte.query.rarg): aliascolnames = res if cte.recursive and aliascolnames: reference_as = [ subctx.alias_generator.get('col') for _ in aliascolnames ] columns = [ context.Column( name=col, kind=context.ColumnByName(reference_as=ref_as) ) for col, ref_as in zip(aliascolnames, reference_as) ] subctx.scope.ctes.append( context.CTE(name=cte.name, columns=columns) ) query, table = dispatch.resolve_relation(cte.query, ctx=subctx) result = context.CTE(name=cte.name, columns=[]) alias = pgast.Alias(aliasname=cte.name, colnames=aliascolnames) for col, al in _zip_column_alias(table.columns, alias, cte.span): result.columns.append( context.Column( name=al or col.name, kind=context.ColumnByName(reference_as=col.name), ) ) if reference_as: for col, ref_as in zip(result.columns, reference_as): col.kind = context.ColumnByName(reference_as=ref_as) node = pgast.CommonTableExpr( name=cte.name, span=cte.span, aliascolnames=reference_as, query=cast(pgast.Query, query), recursive=cte.recursive, materialized=cte.materialized, ) return node, result def _infer_col_aliases(query: pgast.SelectStmt) -> Optional[list[str]]: aliases = [expr.infer_alias(t) for t in query.target_list] if not all(aliases): return None 
return cast(list[str], aliases) @_resolve_range_var.register def _resolve_RangeFunction( range_var: pgast.RangeFunction, alias: pgast.Alias, *, ctx: Context, ) -> tuple[pgast.BaseRangeVar, context.Table]: with ctx.lateral() if range_var.lateral else ctx.child() as subctx: functions: list[pgast.BaseExpr] = [] col_names = [] for function in range_var.functions: match function: case pgast.FuncCall(): name = function.name[len(function.name) - 1] if name in range_functions.COLUMNS: col_names.extend(range_functions.COLUMNS[name]) elif name == 'unnest': col_names.extend('unnest' for _ in function.args) else: col_names.append(name) functions.append(dispatch.resolve(function, ctx=subctx)) case pgast.SQLValueFunction(op=op): # If SQLValueFunction gets statically evaluated, we need to # wrap it into a subquery, otherwise it is syntactically # incorrect. E.g. `SELECT * FROM current_user`, should be # compiled to `SELECT * FROM (SELECT 'admin')` val = dispatch.resolve(function, ctx=subctx) name = pgcommon.get_sql_value_function_op(op) range = pgast.RangeSubselect( subquery=pgast.SelectStmt( target_list=[pgast.ResTarget(val=val, name=name)] ), alias=pgast.Alias( aliasname=alias.aliasname, colnames=[name], ), ) column = context.Column( name=name, kind=context.ColumnByName(reference_as=name), ) table = context.Table(columns=[column]) return range, table case _: functions.append(dispatch.resolve(function, ctx=subctx)) inferred_columns = [ context.Column( name=name, kind=context.ColumnByName(reference_as='') ) for name in col_names ] if range_var.with_ordinality: inferred_columns.append( context.Column( name='ordinality', kind=context.ColumnByName(reference_as='ordinality'), ) ) table = context.Table( columns=[ context.Column( name=al or col.name, kind=context.ColumnByName( reference_as=al or ctx.alias_generator.get('col') ), ) for col, al in _zip_column_alias( inferred_columns, alias, ctx=range_var.span ) ] ) alias = pgast.Alias( aliasname=alias.aliasname, colnames=[ 
cast(context.ColumnByName, c.kind).reference_as for c in table.columns if not c.hidden ], ) node = pgast.RangeFunction( lateral=range_var.lateral, with_ordinality=range_var.with_ordinality, is_rowsfrom=range_var.is_rowsfrom, functions=functions, alias=alias, ) return node, table def _zip_column_alias( columns: list[context.Column], alias: pgast.Alias, ctx: Optional[Span], ) -> Iterable[tuple[context.Column, Optional[str]]]: if not alias.colnames: return map(lambda c: (c, None), columns) columns = [c for c in columns if not c.hidden] if len(columns) != len(alias.colnames): from edb.server.pgcon import errors as pgerror raise errors.QueryError( f'Table alias for `{alias.aliasname}` contains ' f'{len(alias.colnames)} columns, but the query resolves to ' f'{len(columns)} columns', span=ctx, pgext_code=pgerror.ERROR_INVALID_COLUMN_REFERENCE, ) return zip(columns, alias.colnames) ================================================ FILE: edb/pgsql/resolver/relation.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# """SQL resolver that compiles public SQL to internal SQL which is executable in our internal Postgres instance.""" from typing import Optional, cast import uuid from edb import errors from edb.server.pgcon import errors as pgerror from edb.edgeql import qltypes from edb.pgsql import ast as pgast from edb.pgsql import common as pgcommon from edb.pgsql import codegen as pgcodegen from edb.pgsql import inheritance as pginheritance from edb.pgsql import types as pgtypes from edb.schema import objtypes as s_objtypes from edb.schema import links as s_links from edb.schema import properties as s_properties from edb.schema import pointers as s_pointers from edb.schema import sources as s_sources from edb.schema import name as sn from . import dispatch from . import context from . import range_var from . import expr from . import sql_introspection from . import command Context = context.ResolverContextLevel @dispatch._resolve_relation.register def resolve_SelectStmt( stmt: pgast.SelectStmt, *, include_inherited: bool, ctx: Context ) -> tuple[pgast.SelectStmt, context.Table]: # CTEs ctes: list[pgast.CommonTableExpr] = [] if stmt.ctes: for cte in stmt.ctes: cte, tab = range_var.resolve_CommonTableExpr(cte, ctx=ctx) ctes.extend(extract_ctes_from_ctx(ctx)) ctes.append(cte) ctx.scope.ctes.append(tab) # VALUES if stmt.values: values = dispatch.resolve_list(stmt.values, ctx=ctx) relation = pgast.SelectStmt( values=values, ctes=ctes + extract_ctes_from_ctx(ctx) ) first_val = values[0] assert isinstance(first_val, pgast.ImplicitRowExpr) table = context.Table( columns=[ context.Column( name=f'column{index + 1}', kind=context.ColumnByName( reference_as=f'column{index + 1}' ), ) for index, _ in enumerate(first_val.args) ] ) return relation, table # UNION if stmt.larg or stmt.rarg: assert stmt.larg and stmt.rarg with ctx.child() as subctx: larg, ltable = dispatch.resolve_relation(stmt.larg, ctx=subctx) with ctx.child() as subctx: rarg, rtable = dispatch.resolve_relation(stmt.rarg, 
ctx=subctx) # validate equal columns from both sides if len(ltable.columns) != len(rtable.columns): raise errors.QueryError( f'{stmt.op} requires equal number of columns in both sides', span=stmt.span, ) relation = stmt.replace( larg=cast(pgast.Query, larg), rarg=cast(pgast.Query, rarg), ctes=ctes + extract_ctes_from_ctx(ctx), ) return (relation, ltable) # FROM from_clause: list[pgast.BaseRangeVar] = [] for clause in stmt.from_clause: from_clause.append(range_var.resolve_BaseRangeVar(clause, ctx=ctx)) # WHERE where = dispatch.resolve_opt(stmt.where_clause, ctx=ctx) # GROUP BY with ctx.child() as subctx: register_projections(stmt.target_list, ctx=subctx) group_clause = dispatch.resolve_opt_list(stmt.group_clause, ctx=subctx) # HAVING having = dispatch.resolve_opt(stmt.having_clause, ctx=ctx) # SELECT projection table = context.Table() target_list: list[pgast.ResTarget] = [] names: set[str] = set() for t in stmt.target_list: targets, columns = expr.resolve_ResTarget( t, existing_names=names, ctx=ctx ) target_list.extend(targets) table.columns.extend(columns) names.update(c.name for c in columns) distinct_clause = None if stmt.distinct_clause: distinct_clause = [ (c if isinstance(c, pgast.Star) else dispatch.resolve(c, ctx=ctx)) for c in stmt.distinct_clause ] # order by can refer to columns in SELECT projection, so we need to add # table.columns into scope projected_table = context.Table( columns=[ context.Column( name=c.name, kind=context.ColumnByName(reference_as=c.name), ) for c, target in zip(table.columns, stmt.target_list) if target.name and ( not isinstance(target.val, pgast.ColumnRef) or target.val.name[-1] != target.name ) ] ) if len(projected_table.columns) > 0: ctx.scope.tables.append(projected_table) sort_clause = dispatch.resolve_opt_list(stmt.sort_clause, ctx=ctx) limit_offset = dispatch.resolve_opt(stmt.limit_offset, ctx=ctx) limit_count = dispatch.resolve_opt(stmt.limit_count, ctx=ctx) locking_clause = dispatch.resolve_opt_list(stmt.locking_clause, 
ctx=ctx) ctes.extend(extract_ctes_from_ctx(ctx)) res = pgast.SelectStmt( distinct_clause=distinct_clause, from_clause=from_clause, target_list=target_list, group_clause=group_clause, having_clause=having, where_clause=where, sort_clause=sort_clause, limit_offset=limit_offset, limit_count=limit_count, locking_clause=locking_clause, ctes=ctes if len(ctes) > 0 else None, ) return ( res, table, ) # If current context is top-level, return additional CTEs that need to be # injected. They were probably generated by DML. def extract_ctes_from_ctx( ctx: context.ResolverContextLevel, ) -> list[pgast.CommonTableExpr]: if ctx.subquery_depth != 0: return [] res = list(ctx.ctes_buffer) ctx.ctes_buffer.clear() return res def register_projections(target_list: list[pgast.ResTarget], *, ctx: Context): # add aliases from target_list into scope table = context.Table() for target in target_list: if not target.name: continue table.columns.append( context.Column( name=target.name, kind=context.ColumnByName(reference_as=target.name), ) ) ctx.scope.tables.append(table) PG_TOAST_TABLE: list[ tuple[sql_introspection.ColumnName, sql_introspection.ColumnType, int] ] = [ ('chunk_id', None, 13), ('chunk_seq', None, 13), ('chunk_data', None, 13), ] @dispatch._resolve_relation.register def resolve_relation( relation: pgast.Relation, *, include_inherited: bool, ctx: Context ) -> tuple[pgast.BaseRelation, context.Table]: assert relation.name rel: pgast.BaseRelation if relation.catalogname and relation.catalogname != 'postgres': raise errors.QueryError( 'cross-database queries are not supported', span=relation.span, ) # try information_schema, pg_catalog and pg_toast preset_tables = None if relation.schemaname == 'information_schema': preset_tables = ( sql_introspection.INFORMATION_SCHEMA, pgcommon.versioned_schema('edgedbsql'), ) elif not relation.schemaname or relation.schemaname == 'pg_catalog': preset_tables = ( sql_introspection.PG_CATALOG, pgcommon.versioned_schema('edgedbsql'), ) elif
relation.schemaname == 'pg_toast': preset_tables = ({relation.name: PG_TOAST_TABLE}, 'pg_toast') if preset_tables and relation.name in preset_tables[0]: cols = [ context.Column(name=n, kind=context.ColumnByName(reference_as=n)) for n, _type, _ver_since in preset_tables[0][relation.name] ] cols.extend(_construct_system_columns()) table = context.Table( name=relation.name, columns=cols, is_direct_relation=True ) rel = pgast.Relation(name=relation.name, schemaname=preset_tables[1]) return rel, table schema_name = relation.schemaname # try a CTE if not schema_name or schema_name == 'public': cte = next((t for t in ctx.scope.ctes if t.name == relation.name), None) if cte: table = context.Table(name=cte.name, columns=cte.columns.copy()) return pgast.Relation(name=cte.name, schemaname=None), table def public_to_default(s: str) -> str: # make sure to match `public`, `public::blah`, but not `public_blah` if s == 'public': return 'default' if s.startswith('public::'): return 'default' + s[6:] return s # lookup the object in schema schemas = [schema_name] if schema_name else ctx.options.search_path modules = [public_to_default(s) for s in schemas] obj: Optional[s_sources.Source | s_properties.Property] = None for module in modules: if obj: break object_name = sn.QualName(module, relation.name) obj = ctx.schema.get( # type: ignore object_name, None, module_aliases={None: 'default'}, type=s_objtypes.ObjectType, ) # try pointer table for module in modules: if obj: break obj = _lookup_pointer_table(module, relation.name, ctx) if not obj: rel_name = pgcodegen.generate_source(relation) raise errors.QueryError( f'unknown table `{rel_name}`', span=relation.span, pgext_code=pgerror.ERROR_UNDEFINED_TABLE, ) # extract table name table = context.Table(schema_id=obj.id, name=relation.name) # extract table columns # when changing this, make sure to update sql information_schema columns: list[context.Column] = [] if isinstance(obj, s_sources.Source): pointers = 
obj.get_pointers(ctx.schema).objects(ctx.schema) for p in pointers: card = p.get_cardinality(ctx.schema) if card.is_multi(): continue columns.append(_construct_column(p, ctx)) else: for c in ['source', 'target']: columns.append( context.Column( name=c, kind=context.ColumnByName(reference_as=c) ) ) def column_order_key(c: context.Column) -> tuple[int, str]: spec = {'id': 0, 'source': 0, 'target': 1} order: int if isinstance(c.kind, context.ColumnByName): order = spec.get(c.kind.reference_as, 2) else: order = 2 return (order, c.name or '') # sort by name but put `id` first columns.sort(key=column_order_key) table.columns.extend(columns) table.columns.extend(_construct_system_columns()) if ctx.options.apply_access_policies and _has_access_policies(obj, ctx): if isinstance(obj, s_objtypes.ObjectType): rel = _compile_read_of_obj_table(obj, include_inherited, table, ctx) else: assert isinstance(obj, (s_links.Link | s_pointers.Pointer)) rel = _compile_read_of_link_table( obj, include_inherited, table, ctx ) else: if include_inherited and _has_sub_types(obj, ctx): rel = _relation_of_inheritance_cte(obj, ctx) else: rel = _relation_of_table(obj, table, ctx) return rel, table def _has_access_policies( obj: s_sources.Source | s_properties.Property, ctx: Context ): if isinstance(obj, s_pointers.Pointer): source = obj.get_source(ctx.schema) assert isinstance(source, (s_objtypes.ObjectType, s_links.Link)) if isinstance(source, s_objtypes.ObjectType): obj = source elif isinstance(source, s_links.Link): source = source.get_source(ctx.schema) assert isinstance(source, s_objtypes.ObjectType) obj = source else: return False assert isinstance(obj, s_objtypes.ObjectType) policies = obj.get_access_policies(ctx.schema) return len(policies) > 0 def _has_sub_types(obj: s_sources.Source | s_properties.Property, ctx: Context): return len(obj.children(ctx.schema)) > 0 def _relation_of_table( obj: s_sources.Source | s_properties.Property, table: context.Table, ctx: Context, ) -> pgast.Relation: 
    schemaname, dbname = pgcommon.get_backend_name(
        ctx.schema, obj, aspect='table', catenate=False
    )
    relation = pgast.Relation(name=dbname, schemaname=schemaname)
    table.is_direct_relation = True

    # When referencing actual tables, we need to statically provide __type__,
    # since this column does not exist in the database.
    for col in table.columns:
        if col.name == '__type__':
            col.kind = context.ColumnStaticVal(val=obj.id)
            break

    return relation


def _relation_of_inheritance_cte(
    obj: s_sources.Source | s_properties.Property, ctx: Context
) -> pgast.Relation:
    if obj not in ctx.inheritance_ctes:
        cte = pgast.CommonTableExpr(
            name=ctx.alias_generator.get('inh'),
            query=pginheritance.get_inheritance_view(ctx.schema, obj),
        )
        ctx.ctes_buffer.append(cte)
        ctx.inheritance_ctes[obj] = cte.name

    return pgast.Relation(name=ctx.inheritance_ctes[obj])


def _lookup_pointer_table(
    module: str, name: str, ctx: Context
) -> Optional[s_links.Link | s_properties.Property]:
    # Pointer tables are either:
    # - multi link tables
    # - single link tables with at least one property besides source and target
    # - multi property tables

    if '.' not in name:
        return None
    object_name, link_name = name.split('.')
    object_name_qual = sn.QualName(module, object_name)

    parent: s_objtypes.ObjectType = ctx.schema.get(  # type: ignore
        object_name_qual,
        None,
        module_aliases={None: 'default'},
        type=s_objtypes.ObjectType,
    )
    if not parent:
        return None

    pointer = parent.maybe_get_ptr(
        ctx.schema, sn.UnqualName.from_string(link_name)
    )
    if not pointer:
        return None

    if pointer.get_computable(ctx.schema) or pointer.get_internal(ctx.schema):
        return None

    match pointer:
        case s_links.Link():
            if pointer.get_cardinality(ctx.schema).is_single():
                # single links only for tables with at least one property
                # besides source and target
                l_pointers = pointer.get_pointers(ctx.schema).objects(
                    ctx.schema
                )
                if len(l_pointers) <= 2:
                    return None
            return pointer

        case s_properties.Property():
            if pointer.get_cardinality(ctx.schema).is_single():
                return None
            return pointer

    raise NotImplementedError()


def _construct_column(p: s_pointers.Pointer, ctx: Context) -> context.Column:
    short_name = p.get_shortname(ctx.schema)

    col_name: str
    kind: context.ColumnKind
    if isinstance(p, s_properties.Property):
        col_name = short_name.name

        if p.get_computable(ctx.schema):
            kind = context.ColumnComputable(pointer=p)
        elif p.is_link_source_property(ctx.schema):
            kind = context.ColumnByName(reference_as='source')
        elif p.is_link_target_property(ctx.schema):
            kind = context.ColumnByName(reference_as='target')
        elif p.is_id_pointer(ctx.schema):
            kind = context.ColumnByName(reference_as='id')
        else:
            _, dbname = pgcommon.get_backend_name(ctx.schema, p, catenate=False)
            kind = context.ColumnByName(reference_as=dbname)

    elif isinstance(p, s_links.Link):
        if p.get_computable(ctx.schema):
            col_name = short_name.name + '_id'
            kind = context.ColumnComputable(pointer=p)
        elif short_name.name == '__type__':
            col_name = '__type__'
            kind = context.ColumnByName(reference_as='__type__')
        else:
            col_name = short_name.name + '_id'
            _, dbname = pgcommon.get_backend_name(ctx.schema, p, catenate=False)
            kind = context.ColumnByName(reference_as=dbname)

    return context.Column(name=col_name, kind=kind)


def _construct_system_columns() -> list[context.Column]:
    return [
        context.Column(
            name=c, kind=context.ColumnByName(reference_as=c), hidden=True
        )
        for c in ['tableoid', 'xmin', 'cmin', 'xmax', 'cmax', 'ctid']
    ]


def _compile_read_of_obj_table(
    obj: s_objtypes.ObjectType | s_links.Link,
    include_inherited: bool,
    table: context.Table,
    ctx: Context,
) -> pgast.Relation:
    from edb.edgeql import ast as qlast
    from edb.edgeql import compiler as qlcompiler
    from edb.ir import ast as irast
    from edb.pgsql import compiler as pgcompiler
    from edb.pgsql.compiler import enums as pgce

    obj_name: sn.QualName = obj.get_name(ctx.schema)
    assert obj_name

    ql_stmt = qlast.SelectQuery(
        result=qlast.Path(
            steps=[qlast.ObjectRef(module=obj_name.module, name=obj_name.name)]
        )
    )
    if not include_inherited:
        ql_stmt.where = qlast.BinOp(
            left=qlast.Path(
                partial=True,
                steps=[qlast.Ptr(name='__type__'), qlast.Ptr(name='id')],
            ),
            op='=',
            right=qlast.TypeCast(
                expr=qlast.Constant.string(str(obj.id)),
                type=qlast.TypeName(maintype=qlast.ObjectRef(name='uuid')),
            ),
        )

    ir_stmt = qlcompiler.compile_ast_to_ir(
        ql_stmt,
        ctx.schema,
        options=qlcompiler.CompilerOptions(apply_user_access_policies=True),
    )

    sql_tree = pgcompiler.compile_ir_to_sql_tree(
        ir_stmt,
        output_format=pgcompiler.OutputFormat.NATIVE_INTERNAL,
        alias_generator=ctx.alias_generator,
    )
    command.merge_params(sql_tree, ir_stmt, ctx)

    # add CTEs to resolver's CTE buffer
    assert isinstance(sql_tree.ast, pgast.Query)
    if sql_tree.ast.ctes:
        ctx.ctes_buffer.extend(sql_tree.ast.ctes)
        sql_tree.ast.ctes.clear()

    SYSTEM_COLS = {'tableoid', 'xmin', 'cmin', 'xmax', 'cmax', 'ctid'}

    # pull all expected columns out of the result rvar
    obj_id: irast.PathId
    if isinstance(obj, s_objtypes.ObjectType):
        obj_id = irast.PathId.from_type(ctx.schema, obj, env=None)
    else:
        obj_id = irast.PathId.from_pointer(ctx.schema, obj, env=None)
    for column in table.columns:
        if not isinstance(column.kind,
                          context.ColumnByName):
            continue

        if column.kind.reference_as in {'id', 'source', 'target', '__type__'}:
            ptr = obj.getptr(
                ctx.schema, sn.UnqualName.from_string(column.kind.reference_as)
            )
            ptr_id = irast.PathId.from_pointer(ctx.schema, ptr, env=None)
        elif column.kind.reference_as in SYSTEM_COLS:
            el_name = sn.QualName('__object__', column.kind.reference_as)
            ptr_ref = irast.SpecialPointerRef(
                name=el_name,
                shortname=el_name,
                out_source=obj_id.target,
                out_target=pgtypes.pg_oid_typeref,
                out_cardinality=qltypes.Cardinality.AT_MOST_ONE,
            )
            ptr_id = obj_id.extend(ptrref=ptr_ref)
        else:
            ptr = ctx.schema.get_by_id(
                uuid.UUID(column.kind.reference_as), type=s_pointers.Pointer
            )
            ptr_id = irast.PathId.from_pointer(ctx.schema, ptr, env=None)

        output = pgcompiler.pathctx.get_path_output(
            sql_tree.ast,
            ptr_id,
            aspect=pgce.PathAspect.VALUE,
            env=sql_tree.env,
        )
        assert isinstance(output, pgast.ColumnRef)
        assert isinstance(output.name[-1], str)

        # override how this column will be referenced as
        column.kind.reference_as = output.name[-1]

    cte_name = ctx.alias_generator.get('tbl')
    ctx.ctes_buffer.append(
        pgast.CommonTableExpr(
            name=cte_name,
            query=sql_tree.ast,
        )
    )
    return pgast.Relation(name=cte_name)


def _compile_read_of_link_table(
    obj: s_links.Link | s_properties.Property,
    include_inherited: bool,
    table: context.Table,
    ctx: Context,
) -> pgast.BaseRelation:
    # get CTE that will provide source relation, with access policies applied
    source = obj.get_source(ctx.schema)
    assert isinstance(source, (s_objtypes.ObjectType, s_links.Link))
    source_table = context.Table(
        schema_id=source.id,
        columns=[
            context.Column(
                name="id", kind=context.ColumnByName(reference_as="id")
            )
        ],
    )
    source_rel = _compile_read_of_obj_table(
        source, include_inherited, source_table, ctx
    )
    source_table_id = source_table.columns[0].kind
    assert isinstance(source_table_id, context.ColumnByName)

    # get name of link table (with inheritance)
    if obj not in ctx.inheritance_ctes:
        cte = pgast.CommonTableExpr(
            name=ctx.alias_generator.get('inh'),
            query=pginheritance.get_inheritance_view(ctx.schema, obj),
        )
        ctx.ctes_buffer.append(cte)
        ctx.inheritance_ctes[obj] = cte.name
    link_table_name = ctx.inheritance_ctes[obj]

    # inner join source table with the link table
    target_list = []
    for c in table.columns:
        if not isinstance(c.kind, context.ColumnByName):
            continue
        target_list.append(
            pgast.ResTarget(
                val=pgast.ColumnRef(name=("l", c.kind.reference_as))
            )
        )

    return pgast.SelectStmt(
        from_clause=[
            pgast.JoinExpr(
                larg=pgast.RelRangeVar(
                    relation=source_rel, alias=pgast.Alias(aliasname="s")
                ),
                joins=[
                    pgast.JoinClause(
                        type="INNER",
                        rarg=pgast.RelRangeVar(
                            relation=pgast.Relation(name=link_table_name),
                            alias=pgast.Alias(aliasname="l"),
                        ),
                        quals=pgast.Expr(
                            name="=",
                            lexpr=pgast.ColumnRef(
                                name=("s", source_table_id.reference_as)
                            ),
                            rexpr=pgast.ColumnRef(name=("l", "source")),
                        ),
                    ),
                ],
            ),
        ],
        target_list=target_list,
    )


================================================
FILE: edb/pgsql/resolver/sql_introspection.py
================================================
# AUTOGENERATED FROM _localdev postgres instance WITH
# $ edb gen-sql-introspection

"""Declarations of information schema and pg_catalog"""

ColumnName = str
ColumnType = str | None

INFORMATION_SCHEMA: dict[str, list[tuple[ColumnName, ColumnType, int]]] = {
    "administrable_role_authorizations": [
        ("grantee", "sql_identifier", 13),
        ("role_name", "sql_identifier", 13),
        ("is_grantable", "yes_or_no", 13),
    ],
    "applicable_roles": [
        ("grantee", "sql_identifier", 13),
        ("role_name", "sql_identifier", 13),
        ("is_grantable", "yes_or_no", 13),
    ],
    "attributes": [
        ("udt_catalog", "sql_identifier", 13),
        ("udt_schema", "sql_identifier", 13),
        ("udt_name", "sql_identifier", 13),
        ("attribute_name", "sql_identifier", 13),
        ("ordinal_position", "cardinal_number", 13),
        ("attribute_default", "character_data", 13),
        ("is_nullable", "yes_or_no", 13),
        ("data_type", "character_data", 13),
        ("character_maximum_length", "cardinal_number", 13),
("character_octet_length", "cardinal_number", 13), ("character_set_catalog", "sql_identifier", 13), ("character_set_schema", "sql_identifier", 13), ("character_set_name", "sql_identifier", 13), ("collation_catalog", "sql_identifier", 13), ("collation_schema", "sql_identifier", 13), ("collation_name", "sql_identifier", 13), ("numeric_precision", "cardinal_number", 13), ("numeric_precision_radix", "cardinal_number", 13), ("numeric_scale", "cardinal_number", 13), ("datetime_precision", "cardinal_number", 13), ("interval_type", "character_data", 13), ("interval_precision", "cardinal_number", 13), ("attribute_udt_catalog", "sql_identifier", 13), ("attribute_udt_schema", "sql_identifier", 13), ("attribute_udt_name", "sql_identifier", 13), ("scope_catalog", "sql_identifier", 13), ("scope_schema", "sql_identifier", 13), ("scope_name", "sql_identifier", 13), ("maximum_cardinality", "cardinal_number", 13), ("dtd_identifier", "sql_identifier", 13), ("is_derived_reference_attribute", "yes_or_no", 13), ], "character_sets": [ ("character_set_catalog", "sql_identifier", 13), ("character_set_schema", "sql_identifier", 13), ("character_set_name", "sql_identifier", 13), ("character_repertoire", "sql_identifier", 13), ("form_of_use", "sql_identifier", 13), ("default_collate_catalog", "sql_identifier", 13), ("default_collate_schema", "sql_identifier", 13), ("default_collate_name", "sql_identifier", 13), ], "check_constraint_routine_usage": [ ("constraint_catalog", "sql_identifier", 13), ("constraint_schema", "sql_identifier", 13), ("constraint_name", "sql_identifier", 13), ("specific_catalog", "sql_identifier", 13), ("specific_schema", "sql_identifier", 13), ("specific_name", "sql_identifier", 13), ], "check_constraints": [ ("constraint_catalog", "sql_identifier", 13), ("constraint_schema", "sql_identifier", 13), ("constraint_name", "sql_identifier", 13), ("check_clause", "character_data", 13), ], "collation_character_set_applicability": [ ("collation_catalog", "sql_identifier", 13), 
("collation_schema", "sql_identifier", 13), ("collation_name", "sql_identifier", 13), ("character_set_catalog", "sql_identifier", 13), ("character_set_schema", "sql_identifier", 13), ("character_set_name", "sql_identifier", 13), ], "collations": [ ("collation_catalog", "sql_identifier", 13), ("collation_schema", "sql_identifier", 13), ("collation_name", "sql_identifier", 13), ("pad_attribute", "character_data", 13), ], "column_column_usage": [ ("table_catalog", "sql_identifier", 13), ("table_schema", "sql_identifier", 13), ("table_name", "sql_identifier", 13), ("column_name", "sql_identifier", 13), ("dependent_column", "sql_identifier", 13), ], "column_domain_usage": [ ("domain_catalog", "sql_identifier", 13), ("domain_schema", "sql_identifier", 13), ("domain_name", "sql_identifier", 13), ("table_catalog", "sql_identifier", 13), ("table_schema", "sql_identifier", 13), ("table_name", "sql_identifier", 13), ("column_name", "sql_identifier", 13), ], "column_options": [ ("table_catalog", "sql_identifier", 13), ("table_schema", "sql_identifier", 13), ("table_name", "sql_identifier", 13), ("column_name", "sql_identifier", 13), ("option_name", "sql_identifier", 13), ("option_value", "character_data", 13), ], "column_privileges": [ ("grantor", "sql_identifier", 13), ("grantee", "sql_identifier", 13), ("table_catalog", "sql_identifier", 13), ("table_schema", "sql_identifier", 13), ("table_name", "sql_identifier", 13), ("column_name", "sql_identifier", 13), ("privilege_type", "character_data", 13), ("is_grantable", "yes_or_no", 13), ], "column_udt_usage": [ ("udt_catalog", "sql_identifier", 13), ("udt_schema", "sql_identifier", 13), ("udt_name", "sql_identifier", 13), ("table_catalog", "sql_identifier", 13), ("table_schema", "sql_identifier", 13), ("table_name", "sql_identifier", 13), ("column_name", "sql_identifier", 13), ], "columns": [ ("table_catalog", "sql_identifier", 13), ("table_schema", "sql_identifier", 13), ("table_name", "sql_identifier", 13), ("column_name", 
"sql_identifier", 13), ("ordinal_position", "cardinal_number", 13), ("column_default", "character_data", 13), ("is_nullable", "yes_or_no", 13), ("data_type", "character_data", 13), ("character_maximum_length", "cardinal_number", 13), ("character_octet_length", "cardinal_number", 13), ("numeric_precision", "cardinal_number", 13), ("numeric_precision_radix", "cardinal_number", 13), ("numeric_scale", "cardinal_number", 13), ("datetime_precision", "cardinal_number", 13), ("interval_type", "character_data", 13), ("interval_precision", "cardinal_number", 13), ("character_set_catalog", "sql_identifier", 13), ("character_set_schema", "sql_identifier", 13), ("character_set_name", "sql_identifier", 13), ("collation_catalog", "sql_identifier", 13), ("collation_schema", "sql_identifier", 13), ("collation_name", "sql_identifier", 13), ("domain_catalog", "sql_identifier", 13), ("domain_schema", "sql_identifier", 13), ("domain_name", "sql_identifier", 13), ("udt_catalog", "sql_identifier", 13), ("udt_schema", "sql_identifier", 13), ("udt_name", "sql_identifier", 13), ("scope_catalog", "sql_identifier", 13), ("scope_schema", "sql_identifier", 13), ("scope_name", "sql_identifier", 13), ("maximum_cardinality", "cardinal_number", 13), ("dtd_identifier", "sql_identifier", 13), ("is_self_referencing", "yes_or_no", 13), ("is_identity", "yes_or_no", 13), ("identity_generation", "character_data", 13), ("identity_start", "character_data", 13), ("identity_increment", "character_data", 13), ("identity_maximum", "character_data", 13), ("identity_minimum", "character_data", 13), ("identity_cycle", "yes_or_no", 13), ("is_generated", "character_data", 13), ("generation_expression", "character_data", 13), ("is_updatable", "yes_or_no", 13), ], "constraint_column_usage": [ ("table_catalog", "sql_identifier", 13), ("table_schema", "sql_identifier", 13), ("table_name", "sql_identifier", 13), ("column_name", "sql_identifier", 13), ("constraint_catalog", "sql_identifier", 13), ("constraint_schema", 
"sql_identifier", 13), ("constraint_name", "sql_identifier", 13), ], "constraint_table_usage": [ ("table_catalog", "sql_identifier", 13), ("table_schema", "sql_identifier", 13), ("table_name", "sql_identifier", 13), ("constraint_catalog", "sql_identifier", 13), ("constraint_schema", "sql_identifier", 13), ("constraint_name", "sql_identifier", 13), ], "data_type_privileges": [ ("object_catalog", "sql_identifier", 13), ("object_schema", "sql_identifier", 13), ("object_name", "sql_identifier", 13), ("object_type", "character_data", 13), ("dtd_identifier", "sql_identifier", 13), ], "domain_constraints": [ ("constraint_catalog", "sql_identifier", 13), ("constraint_schema", "sql_identifier", 13), ("constraint_name", "sql_identifier", 13), ("domain_catalog", "sql_identifier", 13), ("domain_schema", "sql_identifier", 13), ("domain_name", "sql_identifier", 13), ("is_deferrable", "yes_or_no", 13), ("initially_deferred", "yes_or_no", 13), ], "domain_udt_usage": [ ("udt_catalog", "sql_identifier", 13), ("udt_schema", "sql_identifier", 13), ("udt_name", "sql_identifier", 13), ("domain_catalog", "sql_identifier", 13), ("domain_schema", "sql_identifier", 13), ("domain_name", "sql_identifier", 13), ], "domains": [ ("domain_catalog", "sql_identifier", 13), ("domain_schema", "sql_identifier", 13), ("domain_name", "sql_identifier", 13), ("data_type", "character_data", 13), ("character_maximum_length", "cardinal_number", 13), ("character_octet_length", "cardinal_number", 13), ("character_set_catalog", "sql_identifier", 13), ("character_set_schema", "sql_identifier", 13), ("character_set_name", "sql_identifier", 13), ("collation_catalog", "sql_identifier", 13), ("collation_schema", "sql_identifier", 13), ("collation_name", "sql_identifier", 13), ("numeric_precision", "cardinal_number", 13), ("numeric_precision_radix", "cardinal_number", 13), ("numeric_scale", "cardinal_number", 13), ("datetime_precision", "cardinal_number", 13), ("interval_type", "character_data", 13), 
("interval_precision", "cardinal_number", 13), ("domain_default", "character_data", 13), ("udt_catalog", "sql_identifier", 13), ("udt_schema", "sql_identifier", 13), ("udt_name", "sql_identifier", 13), ("scope_catalog", "sql_identifier", 13), ("scope_schema", "sql_identifier", 13), ("scope_name", "sql_identifier", 13), ("maximum_cardinality", "cardinal_number", 13), ("dtd_identifier", "sql_identifier", 13), ], "element_types": [ ("object_catalog", "sql_identifier", 13), ("object_schema", "sql_identifier", 13), ("object_name", "sql_identifier", 13), ("object_type", "character_data", 13), ("collection_type_identifier", "sql_identifier", 13), ("data_type", "character_data", 13), ("character_maximum_length", "cardinal_number", 13), ("character_octet_length", "cardinal_number", 13), ("character_set_catalog", "sql_identifier", 13), ("character_set_schema", "sql_identifier", 13), ("character_set_name", "sql_identifier", 13), ("collation_catalog", "sql_identifier", 13), ("collation_schema", "sql_identifier", 13), ("collation_name", "sql_identifier", 13), ("numeric_precision", "cardinal_number", 13), ("numeric_precision_radix", "cardinal_number", 13), ("numeric_scale", "cardinal_number", 13), ("datetime_precision", "cardinal_number", 13), ("interval_type", "character_data", 13), ("interval_precision", "cardinal_number", 13), ("udt_catalog", "sql_identifier", 13), ("udt_schema", "sql_identifier", 13), ("udt_name", "sql_identifier", 13), ("scope_catalog", "sql_identifier", 13), ("scope_schema", "sql_identifier", 13), ("scope_name", "sql_identifier", 13), ("maximum_cardinality", "cardinal_number", 13), ("dtd_identifier", "sql_identifier", 13), ], "enabled_roles": [ ("role_name", "sql_identifier", 13), ], "foreign_data_wrapper_options": [ ("foreign_data_wrapper_catalog", "sql_identifier", 13), ("foreign_data_wrapper_name", "sql_identifier", 13), ("option_name", "sql_identifier", 13), ("option_value", "character_data", 13), ], "foreign_data_wrappers": [ 
("foreign_data_wrapper_catalog", "sql_identifier", 13), ("foreign_data_wrapper_name", "sql_identifier", 13), ("authorization_identifier", "sql_identifier", 13), ("library_name", "character_data", 13), ("foreign_data_wrapper_language", "character_data", 13), ], "foreign_server_options": [ ("foreign_server_catalog", "sql_identifier", 13), ("foreign_server_name", "sql_identifier", 13), ("option_name", "sql_identifier", 13), ("option_value", "character_data", 13), ], "foreign_servers": [ ("foreign_server_catalog", "sql_identifier", 13), ("foreign_server_name", "sql_identifier", 13), ("foreign_data_wrapper_catalog", "sql_identifier", 13), ("foreign_data_wrapper_name", "sql_identifier", 13), ("foreign_server_type", "character_data", 13), ("foreign_server_version", "character_data", 13), ("authorization_identifier", "sql_identifier", 13), ], "foreign_table_options": [ ("foreign_table_catalog", "sql_identifier", 13), ("foreign_table_schema", "sql_identifier", 13), ("foreign_table_name", "sql_identifier", 13), ("option_name", "sql_identifier", 13), ("option_value", "character_data", 13), ], "foreign_tables": [ ("foreign_table_catalog", "sql_identifier", 13), ("foreign_table_schema", "sql_identifier", 13), ("foreign_table_name", "sql_identifier", 13), ("foreign_server_catalog", "sql_identifier", 13), ("foreign_server_name", "sql_identifier", 13), ], "information_schema_catalog_name": [ ("catalog_name", "sql_identifier", 13), ], "key_column_usage": [ ("constraint_catalog", "sql_identifier", 13), ("constraint_schema", "sql_identifier", 13), ("constraint_name", "sql_identifier", 13), ("table_catalog", "sql_identifier", 13), ("table_schema", "sql_identifier", 13), ("table_name", "sql_identifier", 13), ("column_name", "sql_identifier", 13), ("ordinal_position", "cardinal_number", 13), ("position_in_unique_constraint", "cardinal_number", 13), ], "parameters": [ ("specific_catalog", "sql_identifier", 13), ("specific_schema", "sql_identifier", 13), ("specific_name", 
"sql_identifier", 13), ("ordinal_position", "cardinal_number", 13), ("parameter_mode", "character_data", 13), ("is_result", "yes_or_no", 13), ("as_locator", "yes_or_no", 13), ("parameter_name", "sql_identifier", 13), ("data_type", "character_data", 13), ("character_maximum_length", "cardinal_number", 13), ("character_octet_length", "cardinal_number", 13), ("character_set_catalog", "sql_identifier", 13), ("character_set_schema", "sql_identifier", 13), ("character_set_name", "sql_identifier", 13), ("collation_catalog", "sql_identifier", 13), ("collation_schema", "sql_identifier", 13), ("collation_name", "sql_identifier", 13), ("numeric_precision", "cardinal_number", 13), ("numeric_precision_radix", "cardinal_number", 13), ("numeric_scale", "cardinal_number", 13), ("datetime_precision", "cardinal_number", 13), ("interval_type", "character_data", 13), ("interval_precision", "cardinal_number", 13), ("udt_catalog", "sql_identifier", 13), ("udt_schema", "sql_identifier", 13), ("udt_name", "sql_identifier", 13), ("scope_catalog", "sql_identifier", 13), ("scope_schema", "sql_identifier", 13), ("scope_name", "sql_identifier", 13), ("maximum_cardinality", "cardinal_number", 13), ("dtd_identifier", "sql_identifier", 13), ("parameter_default", "character_data", 13), ], "referential_constraints": [ ("constraint_catalog", "sql_identifier", 13), ("constraint_schema", "sql_identifier", 13), ("constraint_name", "sql_identifier", 13), ("unique_constraint_catalog", "sql_identifier", 13), ("unique_constraint_schema", "sql_identifier", 13), ("unique_constraint_name", "sql_identifier", 13), ("match_option", "character_data", 13), ("update_rule", "character_data", 13), ("delete_rule", "character_data", 13), ], "role_column_grants": [ ("grantor", "sql_identifier", 13), ("grantee", "sql_identifier", 13), ("table_catalog", "sql_identifier", 13), ("table_schema", "sql_identifier", 13), ("table_name", "sql_identifier", 13), ("column_name", "sql_identifier", 13), ("privilege_type", 
"character_data", 13), ("is_grantable", "yes_or_no", 13), ], "role_routine_grants": [ ("grantor", "sql_identifier", 13), ("grantee", "sql_identifier", 13), ("specific_catalog", "sql_identifier", 13), ("specific_schema", "sql_identifier", 13), ("specific_name", "sql_identifier", 13), ("routine_catalog", "sql_identifier", 13), ("routine_schema", "sql_identifier", 13), ("routine_name", "sql_identifier", 13), ("privilege_type", "character_data", 13), ("is_grantable", "yes_or_no", 13), ], "role_table_grants": [ ("grantor", "sql_identifier", 13), ("grantee", "sql_identifier", 13), ("table_catalog", "sql_identifier", 13), ("table_schema", "sql_identifier", 13), ("table_name", "sql_identifier", 13), ("privilege_type", "character_data", 13), ("is_grantable", "yes_or_no", 13), ("with_hierarchy", "yes_or_no", 13), ], "role_udt_grants": [ ("grantor", "sql_identifier", 13), ("grantee", "sql_identifier", 13), ("udt_catalog", "sql_identifier", 13), ("udt_schema", "sql_identifier", 13), ("udt_name", "sql_identifier", 13), ("privilege_type", "character_data", 13), ("is_grantable", "yes_or_no", 13), ], "role_usage_grants": [ ("grantor", "sql_identifier", 13), ("grantee", "sql_identifier", 13), ("object_catalog", "sql_identifier", 13), ("object_schema", "sql_identifier", 13), ("object_name", "sql_identifier", 13), ("object_type", "character_data", 13), ("privilege_type", "character_data", 13), ("is_grantable", "yes_or_no", 13), ], "routine_column_usage": [ ("specific_catalog", "sql_identifier", 14), ("specific_schema", "sql_identifier", 14), ("specific_name", "sql_identifier", 14), ("routine_catalog", "sql_identifier", 14), ("routine_schema", "sql_identifier", 14), ("routine_name", "sql_identifier", 14), ("table_catalog", "sql_identifier", 14), ("table_schema", "sql_identifier", 14), ("table_name", "sql_identifier", 14), ("column_name", "sql_identifier", 14), ], "routine_privileges": [ ("grantor", "sql_identifier", 13), ("grantee", "sql_identifier", 13), ("specific_catalog", 
"sql_identifier", 13), ("specific_schema", "sql_identifier", 13), ("specific_name", "sql_identifier", 13), ("routine_catalog", "sql_identifier", 13), ("routine_schema", "sql_identifier", 13), ("routine_name", "sql_identifier", 13), ("privilege_type", "character_data", 13), ("is_grantable", "yes_or_no", 13), ], "routine_routine_usage": [ ("specific_catalog", "sql_identifier", 14), ("specific_schema", "sql_identifier", 14), ("specific_name", "sql_identifier", 14), ("routine_catalog", "sql_identifier", 14), ("routine_schema", "sql_identifier", 14), ("routine_name", "sql_identifier", 14), ], "routine_sequence_usage": [ ("specific_catalog", "sql_identifier", 14), ("specific_schema", "sql_identifier", 14), ("specific_name", "sql_identifier", 14), ("routine_catalog", "sql_identifier", 14), ("routine_schema", "sql_identifier", 14), ("routine_name", "sql_identifier", 14), ("sequence_catalog", "sql_identifier", 14), ("sequence_schema", "sql_identifier", 14), ("sequence_name", "sql_identifier", 14), ], "routine_table_usage": [ ("specific_catalog", "sql_identifier", 14), ("specific_schema", "sql_identifier", 14), ("specific_name", "sql_identifier", 14), ("routine_catalog", "sql_identifier", 14), ("routine_schema", "sql_identifier", 14), ("routine_name", "sql_identifier", 14), ("table_catalog", "sql_identifier", 14), ("table_schema", "sql_identifier", 14), ("table_name", "sql_identifier", 14), ], "routines": [ ("specific_catalog", "sql_identifier", 13), ("specific_schema", "sql_identifier", 13), ("specific_name", "sql_identifier", 13), ("routine_catalog", "sql_identifier", 13), ("routine_schema", "sql_identifier", 13), ("routine_name", "sql_identifier", 13), ("routine_type", "character_data", 13), ("module_catalog", "sql_identifier", 13), ("module_schema", "sql_identifier", 13), ("module_name", "sql_identifier", 13), ("udt_catalog", "sql_identifier", 13), ("udt_schema", "sql_identifier", 13), ("udt_name", "sql_identifier", 13), ("data_type", "character_data", 13), 
("character_maximum_length", "cardinal_number", 13), ("character_octet_length", "cardinal_number", 13), ("character_set_catalog", "sql_identifier", 13), ("character_set_schema", "sql_identifier", 13), ("character_set_name", "sql_identifier", 13), ("collation_catalog", "sql_identifier", 13), ("collation_schema", "sql_identifier", 13), ("collation_name", "sql_identifier", 13), ("numeric_precision", "cardinal_number", 13), ("numeric_precision_radix", "cardinal_number", 13), ("numeric_scale", "cardinal_number", 13), ("datetime_precision", "cardinal_number", 13), ("interval_type", "character_data", 13), ("interval_precision", "cardinal_number", 13), ("type_udt_catalog", "sql_identifier", 13), ("type_udt_schema", "sql_identifier", 13), ("type_udt_name", "sql_identifier", 13), ("scope_catalog", "sql_identifier", 13), ("scope_schema", "sql_identifier", 13), ("scope_name", "sql_identifier", 13), ("maximum_cardinality", "cardinal_number", 13), ("dtd_identifier", "sql_identifier", 13), ("routine_body", "character_data", 13), ("routine_definition", "character_data", 13), ("external_name", "character_data", 13), ("external_language", "character_data", 13), ("parameter_style", "character_data", 13), ("is_deterministic", "yes_or_no", 13), ("sql_data_access", "character_data", 13), ("is_null_call", "yes_or_no", 13), ("sql_path", "character_data", 13), ("schema_level_routine", "yes_or_no", 13), ("max_dynamic_result_sets", "cardinal_number", 13), ("is_user_defined_cast", "yes_or_no", 13), ("is_implicitly_invocable", "yes_or_no", 13), ("security_type", "character_data", 13), ("to_sql_specific_catalog", "sql_identifier", 13), ("to_sql_specific_schema", "sql_identifier", 13), ("to_sql_specific_name", "sql_identifier", 13), ("as_locator", "yes_or_no", 13), ("created", "time_stamp", 13), ("last_altered", "time_stamp", 13), ("new_savepoint_level", "yes_or_no", 13), ("is_udt_dependent", "yes_or_no", 13), ("result_cast_from_data_type", "character_data", 13), ("result_cast_as_locator", 
"yes_or_no", 13), ("result_cast_char_max_length", "cardinal_number", 13), ("result_cast_char_octet_length", "cardinal_number", 13), ("result_cast_char_set_catalog", "sql_identifier", 13), ("result_cast_char_set_schema", "sql_identifier", 13), ("result_cast_char_set_name", "sql_identifier", 13), ("result_cast_collation_catalog", "sql_identifier", 13), ("result_cast_collation_schema", "sql_identifier", 13), ("result_cast_collation_name", "sql_identifier", 13), ("result_cast_numeric_precision", "cardinal_number", 13), ("result_cast_numeric_precision_radix", "cardinal_number", 13), ("result_cast_numeric_scale", "cardinal_number", 13), ("result_cast_datetime_precision", "cardinal_number", 13), ("result_cast_interval_type", "character_data", 13), ("result_cast_interval_precision", "cardinal_number", 13), ("result_cast_type_udt_catalog", "sql_identifier", 13), ("result_cast_type_udt_schema", "sql_identifier", 13), ("result_cast_type_udt_name", "sql_identifier", 13), ("result_cast_scope_catalog", "sql_identifier", 13), ("result_cast_scope_schema", "sql_identifier", 13), ("result_cast_scope_name", "sql_identifier", 13), ("result_cast_maximum_cardinality", "cardinal_number", 13), ("result_cast_dtd_identifier", "sql_identifier", 13), ], "schemata": [ ("catalog_name", "sql_identifier", 13), ("schema_name", "sql_identifier", 13), ("schema_owner", "sql_identifier", 13), ("default_character_set_catalog", "sql_identifier", 13), ("default_character_set_schema", "sql_identifier", 13), ("default_character_set_name", "sql_identifier", 13), ("sql_path", "character_data", 13), ], "sequences": [ ("sequence_catalog", "sql_identifier", 13), ("sequence_schema", "sql_identifier", 13), ("sequence_name", "sql_identifier", 13), ("data_type", "character_data", 13), ("numeric_precision", "cardinal_number", 13), ("numeric_precision_radix", "cardinal_number", 13), ("numeric_scale", "cardinal_number", 13), ("start_value", "character_data", 13), ("minimum_value", "character_data", 13), 
("maximum_value", "character_data", 13), ("increment", "character_data", 13), ("cycle_option", "yes_or_no", 13), ], "sql_features": [ ("feature_id", "character_data", 13), ("feature_name", "character_data", 13), ("sub_feature_id", "character_data", 13), ("sub_feature_name", "character_data", 13), ("is_supported", "yes_or_no", 13), ("is_verified_by", "character_data", 13), ("comments", "character_data", 13), ], "sql_implementation_info": [ ("implementation_info_id", "character_data", 13), ("implementation_info_name", "character_data", 13), ("integer_value", "cardinal_number", 13), ("character_value", "character_data", 13), ("comments", "character_data", 13), ], "sql_parts": [ ("feature_id", "character_data", 13), ("feature_name", "character_data", 13), ("is_supported", "yes_or_no", 13), ("is_verified_by", "character_data", 13), ("comments", "character_data", 13), ], "sql_sizing": [ ("sizing_id", "cardinal_number", 13), ("sizing_name", "character_data", 13), ("supported_value", "cardinal_number", 13), ("comments", "character_data", 13), ], "table_constraints": [ ("constraint_catalog", "sql_identifier", 13), ("constraint_schema", "sql_identifier", 13), ("constraint_name", "sql_identifier", 13), ("table_catalog", "sql_identifier", 13), ("table_schema", "sql_identifier", 13), ("table_name", "sql_identifier", 13), ("constraint_type", "character_data", 13), ("is_deferrable", "yes_or_no", 13), ("initially_deferred", "yes_or_no", 13), ("enforced", "yes_or_no", 13), ("nulls_distinct", "yes_or_no", 15), ], "table_privileges": [ ("grantor", "sql_identifier", 13), ("grantee", "sql_identifier", 13), ("table_catalog", "sql_identifier", 13), ("table_schema", "sql_identifier", 13), ("table_name", "sql_identifier", 13), ("privilege_type", "character_data", 13), ("is_grantable", "yes_or_no", 13), ("with_hierarchy", "yes_or_no", 13), ], "tables": [ ("table_catalog", "sql_identifier", 13), ("table_schema", "sql_identifier", 13), ("table_name", "sql_identifier", 13), ("table_type", 
"character_data", 13), ("self_referencing_column_name", "sql_identifier", 13), ("reference_generation", "character_data", 13), ("user_defined_type_catalog", "sql_identifier", 13), ("user_defined_type_schema", "sql_identifier", 13), ("user_defined_type_name", "sql_identifier", 13), ("is_insertable_into", "yes_or_no", 13), ("is_typed", "yes_or_no", 13), ("commit_action", "character_data", 13), ], "transforms": [ ("udt_catalog", "sql_identifier", 13), ("udt_schema", "sql_identifier", 13), ("udt_name", "sql_identifier", 13), ("specific_catalog", "sql_identifier", 13), ("specific_schema", "sql_identifier", 13), ("specific_name", "sql_identifier", 13), ("group_name", "sql_identifier", 13), ("transform_type", "character_data", 13), ], "triggered_update_columns": [ ("trigger_catalog", "sql_identifier", 13), ("trigger_schema", "sql_identifier", 13), ("trigger_name", "sql_identifier", 13), ("event_object_catalog", "sql_identifier", 13), ("event_object_schema", "sql_identifier", 13), ("event_object_table", "sql_identifier", 13), ("event_object_column", "sql_identifier", 13), ], "triggers": [ ("trigger_catalog", "sql_identifier", 13), ("trigger_schema", "sql_identifier", 13), ("trigger_name", "sql_identifier", 13), ("event_manipulation", "character_data", 13), ("event_object_catalog", "sql_identifier", 13), ("event_object_schema", "sql_identifier", 13), ("event_object_table", "sql_identifier", 13), ("action_order", "cardinal_number", 13), ("action_condition", "character_data", 13), ("action_statement", "character_data", 13), ("action_orientation", "character_data", 13), ("action_timing", "character_data", 13), ("action_reference_old_table", "sql_identifier", 13), ("action_reference_new_table", "sql_identifier", 13), ("action_reference_old_row", "sql_identifier", 13), ("action_reference_new_row", "sql_identifier", 13), ("created", "time_stamp", 13), ], "udt_privileges": [ ("grantor", "sql_identifier", 13), ("grantee", "sql_identifier", 13), ("udt_catalog", "sql_identifier", 
13), ("udt_schema", "sql_identifier", 13), ("udt_name", "sql_identifier", 13), ("privilege_type", "character_data", 13), ("is_grantable", "yes_or_no", 13), ], "usage_privileges": [ ("grantor", "sql_identifier", 13), ("grantee", "sql_identifier", 13), ("object_catalog", "sql_identifier", 13), ("object_schema", "sql_identifier", 13), ("object_name", "sql_identifier", 13), ("object_type", "character_data", 13), ("privilege_type", "character_data", 13), ("is_grantable", "yes_or_no", 13), ], "user_defined_types": [ ("user_defined_type_catalog", "sql_identifier", 13), ("user_defined_type_schema", "sql_identifier", 13), ("user_defined_type_name", "sql_identifier", 13), ("user_defined_type_category", "character_data", 13), ("is_instantiable", "yes_or_no", 13), ("is_final", "yes_or_no", 13), ("ordering_form", "character_data", 13), ("ordering_category", "character_data", 13), ("ordering_routine_catalog", "sql_identifier", 13), ("ordering_routine_schema", "sql_identifier", 13), ("ordering_routine_name", "sql_identifier", 13), ("reference_type", "character_data", 13), ("data_type", "character_data", 13), ("character_maximum_length", "cardinal_number", 13), ("character_octet_length", "cardinal_number", 13), ("character_set_catalog", "sql_identifier", 13), ("character_set_schema", "sql_identifier", 13), ("character_set_name", "sql_identifier", 13), ("collation_catalog", "sql_identifier", 13), ("collation_schema", "sql_identifier", 13), ("collation_name", "sql_identifier", 13), ("numeric_precision", "cardinal_number", 13), ("numeric_precision_radix", "cardinal_number", 13), ("numeric_scale", "cardinal_number", 13), ("datetime_precision", "cardinal_number", 13), ("interval_type", "character_data", 13), ("interval_precision", "cardinal_number", 13), ("source_dtd_identifier", "sql_identifier", 13), ("ref_dtd_identifier", "sql_identifier", 13), ], "user_mapping_options": [ ("authorization_identifier", "sql_identifier", 13), ("foreign_server_catalog", "sql_identifier", 13), 
("foreign_server_name", "sql_identifier", 13), ("option_name", "sql_identifier", 13), ("option_value", "character_data", 13), ], "user_mappings": [ ("authorization_identifier", "sql_identifier", 13), ("foreign_server_catalog", "sql_identifier", 13), ("foreign_server_name", "sql_identifier", 13), ], "view_column_usage": [ ("view_catalog", "sql_identifier", 13), ("view_schema", "sql_identifier", 13), ("view_name", "sql_identifier", 13), ("table_catalog", "sql_identifier", 13), ("table_schema", "sql_identifier", 13), ("table_name", "sql_identifier", 13), ("column_name", "sql_identifier", 13), ], "view_routine_usage": [ ("table_catalog", "sql_identifier", 13), ("table_schema", "sql_identifier", 13), ("table_name", "sql_identifier", 13), ("specific_catalog", "sql_identifier", 13), ("specific_schema", "sql_identifier", 13), ("specific_name", "sql_identifier", 13), ], "view_table_usage": [ ("view_catalog", "sql_identifier", 13), ("view_schema", "sql_identifier", 13), ("view_name", "sql_identifier", 13), ("table_catalog", "sql_identifier", 13), ("table_schema", "sql_identifier", 13), ("table_name", "sql_identifier", 13), ], "views": [ ("table_catalog", "sql_identifier", 13), ("table_schema", "sql_identifier", 13), ("table_name", "sql_identifier", 13), ("view_definition", "character_data", 13), ("check_option", "character_data", 13), ("is_updatable", "yes_or_no", 13), ("is_insertable_into", "yes_or_no", 13), ("is_trigger_updatable", "yes_or_no", 13), ("is_trigger_deletable", "yes_or_no", 13), ("is_trigger_insertable_into", "yes_or_no", 13), ] } PG_CATALOG: dict[str, list[tuple[ColumnName, ColumnType, int]]] = { "pg_aggregate": [ ("aggfnoid", "regproc", 13), ("aggkind", "\"char\"", 13), ("aggnumdirectargs", "smallint", 13), ("aggtransfn", "regproc", 13), ("aggfinalfn", "regproc", 13), ("aggcombinefn", "regproc", 13), ("aggserialfn", "regproc", 13), ("aggdeserialfn", "regproc", 13), ("aggmtransfn", "regproc", 13), ("aggminvtransfn", "regproc", 13), ("aggmfinalfn", "regproc", 
13), ("aggfinalextra", "boolean", 13), ("aggmfinalextra", "boolean", 13), ("aggfinalmodify", "\"char\"", 13), ("aggmfinalmodify", "\"char\"", 13), ("aggsortop", "oid", 13), ("aggtranstype", "oid", 13), ("aggtransspace", "integer", 13), ("aggmtranstype", "oid", 13), ("aggmtransspace", "integer", 13), ("agginitval", "text", 13), ("aggminitval", "text", 13), ], "pg_am": [ ("oid", "oid", 13), ("amname", "name", 13), ("amhandler", "regproc", 13), ("amtype", "\"char\"", 13), ], "pg_amop": [ ("oid", "oid", 13), ("amopfamily", "oid", 13), ("amoplefttype", "oid", 13), ("amoprighttype", "oid", 13), ("amopstrategy", "smallint", 13), ("amoppurpose", "\"char\"", 13), ("amopopr", "oid", 13), ("amopmethod", "oid", 13), ("amopsortfamily", "oid", 13), ], "pg_amproc": [ ("oid", "oid", 13), ("amprocfamily", "oid", 13), ("amproclefttype", "oid", 13), ("amprocrighttype", "oid", 13), ("amprocnum", "smallint", 13), ("amproc", "regproc", 13), ], "pg_attrdef": [ ("oid", "oid", 13), ("adrelid", "oid", 13), ("adnum", "smallint", 13), ("adbin", "pg_node_tree", 13), ], "pg_attribute": [ ("attrelid", "oid", 13), ("attname", "name", 13), ("atttypid", "oid", 13), ("attlen", "smallint", 13), ("attnum", "smallint", 13), ("attcacheoff", "integer", 13), ("atttypmod", "integer", 13), ("attndims", "smallint", 13), ("attbyval", "boolean", 13), ("attalign", "\"char\"", 13), ("attstorage", "\"char\"", 13), ("attcompression", "\"char\"", 14), ("attnotnull", "boolean", 13), ("atthasdef", "boolean", 13), ("atthasmissing", "boolean", 13), ("attidentity", "\"char\"", 13), ("attgenerated", "\"char\"", 13), ("attisdropped", "boolean", 13), ("attislocal", "boolean", 13), ("attinhcount", "smallint", 13), ("attcollation", "oid", 13), ("attstattarget", "smallint", 13), ("attacl", None, 13), ("attoptions", None, 13), ("attfdwoptions", None, 13), ("attmissingval", None, 13), ], "pg_auth_members": [ ("oid", "oid", 16), ("roleid", "oid", 13), ("member", "oid", 13), ("grantor", "oid", 13), ("admin_option", "boolean", 
13), ("inherit_option", "boolean", 16), ("set_option", "boolean", 16), ], "pg_authid": [ ("oid", "oid", 13), ("rolname", "name", 13), ("rolsuper", "boolean", 13), ("rolinherit", "boolean", 13), ("rolcreaterole", "boolean", 13), ("rolcreatedb", "boolean", 13), ("rolcanlogin", "boolean", 13), ("rolreplication", "boolean", 13), ("rolbypassrls", "boolean", 13), ("rolconnlimit", "integer", 13), ("rolpassword", "text", 13), ("rolvaliduntil", "timestamp with time zone", 13), ], "pg_available_extension_versions": [ ("name", "name", 13), ("version", "text", 13), ("installed", "boolean", 13), ("superuser", "boolean", 13), ("trusted", "boolean", 13), ("relocatable", "boolean", 13), ("schema", "name", 13), ("requires", None, 13), ("comment", "text", 13), ], "pg_available_extensions": [ ("name", "name", 13), ("default_version", "text", 13), ("installed_version", "text", 13), ("comment", "text", 13), ], "pg_backend_memory_contexts": [ ("name", "text", 14), ("ident", "text", 14), ("parent", "text", 14), ("level", "integer", 14), ("total_bytes", "bigint", 14), ("total_nblocks", "bigint", 14), ("free_bytes", "bigint", 14), ("free_chunks", "bigint", 14), ("used_bytes", "bigint", 14), ], "pg_cast": [ ("oid", "oid", 13), ("castsource", "oid", 13), ("casttarget", "oid", 13), ("castfunc", "oid", 13), ("castcontext", "\"char\"", 13), ("castmethod", "\"char\"", 13), ], "pg_class": [ ("oid", "oid", 13), ("relname", "name", 13), ("relnamespace", "oid", 13), ("reltype", "oid", 13), ("reloftype", "oid", 13), ("relowner", "oid", 13), ("relam", "oid", 13), ("relfilenode", "oid", 13), ("reltablespace", "oid", 13), ("relpages", "integer", 13), ("reltuples", "real", 13), ("relallvisible", "integer", 13), ("reltoastrelid", "oid", 13), ("relhasindex", "boolean", 13), ("relisshared", "boolean", 13), ("relpersistence", "\"char\"", 13), ("relkind", "\"char\"", 13), ("relnatts", "smallint", 13), ("relchecks", "smallint", 13), ("relhasrules", "boolean", 13), ("relhastriggers", "boolean", 13), 
("relhassubclass", "boolean", 13), ("relrowsecurity", "boolean", 13), ("relforcerowsecurity", "boolean", 13), ("relispopulated", "boolean", 13), ("relreplident", "\"char\"", 13), ("relispartition", "boolean", 13), ("relrewrite", "oid", 13), ("relfrozenxid", "xid", 13), ("relminmxid", "xid", 13), ("relacl", None, 13), ("reloptions", None, 13), ("relpartbound", "pg_node_tree", 13), ], "pg_collation": [ ("oid", "oid", 13), ("collname", "name", 13), ("collnamespace", "oid", 13), ("collowner", "oid", 13), ("collprovider", "\"char\"", 13), ("collisdeterministic", "boolean", 13), ("collencoding", "integer", 13), ("collcollate", "text", 13), ("collctype", "text", 13), ("colllocale", "text", 17), ("collicurules", "text", 16), ("collversion", "text", 13), ], "pg_config": [ ("name", "text", 13), ("setting", "text", 13), ], "pg_constraint": [ ("oid", "oid", 13), ("conname", "name", 13), ("connamespace", "oid", 13), ("contype", "\"char\"", 13), ("condeferrable", "boolean", 13), ("condeferred", "boolean", 13), ("convalidated", "boolean", 13), ("conrelid", "oid", 13), ("contypid", "oid", 13), ("conindid", "oid", 13), ("conparentid", "oid", 13), ("confrelid", "oid", 13), ("confupdtype", "\"char\"", 13), ("confdeltype", "\"char\"", 13), ("confmatchtype", "\"char\"", 13), ("conislocal", "boolean", 13), ("coninhcount", "smallint", 13), ("connoinherit", "boolean", 13), ("conkey", None, 13), ("confkey", None, 13), ("conpfeqop", None, 13), ("conppeqop", None, 13), ("conffeqop", None, 13), ("confdelsetcols", None, 15), ("conexclop", None, 13), ("conbin", "pg_node_tree", 13), ], "pg_conversion": [ ("oid", "oid", 13), ("conname", "name", 13), ("connamespace", "oid", 13), ("conowner", "oid", 13), ("conforencoding", "integer", 13), ("contoencoding", "integer", 13), ("conproc", "regproc", 13), ("condefault", "boolean", 13), ], "pg_cursors": [ ("name", "text", 13), ("statement", "text", 13), ("is_holdable", "boolean", 13), ("is_binary", "boolean", 13), ("is_scrollable", "boolean", 13), 
("creation_time", "timestamp with time zone", 13), ], "pg_database": [ ("oid", "oid", 13), ("datname", "name", 13), ("datdba", "oid", 13), ("encoding", "integer", 13), ("datlocprovider", "\"char\"", 15), ("datistemplate", "boolean", 13), ("datallowconn", "boolean", 13), ("dathasloginevt", "boolean", 17), ("datconnlimit", "integer", 13), ("datfrozenxid", "xid", 13), ("datminmxid", "xid", 13), ("dattablespace", "oid", 13), ("datcollate", "text", 13), ("datctype", "text", 13), ("datlocale", "text", 17), ("daticurules", "text", 16), ("datcollversion", "text", 15), ("datacl", None, 13), ], "pg_db_role_setting": [ ("setdatabase", "oid", 13), ("setrole", "oid", 13), ("setconfig", None, 13), ], "pg_default_acl": [ ("oid", "oid", 13), ("defaclrole", "oid", 13), ("defaclnamespace", "oid", 13), ("defaclobjtype", "\"char\"", 13), ("defaclacl", None, 13), ], "pg_depend": [ ("classid", "oid", 13), ("objid", "oid", 13), ("objsubid", "integer", 13), ("refclassid", "oid", 13), ("refobjid", "oid", 13), ("refobjsubid", "integer", 13), ("deptype", "\"char\"", 13), ], "pg_description": [ ("objoid", "oid", 13), ("classoid", "oid", 13), ("objsubid", "integer", 13), ("description", "text", 13), ], "pg_enum": [ ("oid", "oid", 13), ("enumtypid", "oid", 13), ("enumsortorder", "real", 13), ("enumlabel", "name", 13), ], "pg_event_trigger": [ ("oid", "oid", 13), ("evtname", "name", 13), ("evtevent", "name", 13), ("evtowner", "oid", 13), ("evtfoid", "oid", 13), ("evtenabled", "\"char\"", 13), ("evttags", None, 13), ], "pg_extension": [ ("oid", "oid", 13), ("extname", "name", 13), ("extowner", "oid", 13), ("extnamespace", "oid", 13), ("extrelocatable", "boolean", 13), ("extversion", "text", 13), ("extconfig", None, 13), ("extcondition", None, 13), ], "pg_file_settings": [ ("sourcefile", "text", 13), ("sourceline", "integer", 13), ("seqno", "integer", 13), ("name", "text", 13), ("setting", "text", 13), ("applied", "boolean", 13), ("error", "text", 13), ], "pg_foreign_data_wrapper": [ ("oid", 
"oid", 13), ("fdwname", "name", 13), ("fdwowner", "oid", 13), ("fdwhandler", "oid", 13), ("fdwvalidator", "oid", 13), ("fdwacl", None, 13), ("fdwoptions", None, 13), ], "pg_foreign_server": [ ("oid", "oid", 13), ("srvname", "name", 13), ("srvowner", "oid", 13), ("srvfdw", "oid", 13), ("srvtype", "text", 13), ("srvversion", "text", 13), ("srvacl", None, 13), ("srvoptions", None, 13), ], "pg_foreign_table": [ ("ftrelid", "oid", 13), ("ftserver", "oid", 13), ("ftoptions", None, 13), ], "pg_group": [ ("groname", "name", 13), ("grosysid", "oid", 13), ("grolist", None, 13), ], "pg_hba_file_rules": [ ("rule_number", "integer", 16), ("file_name", "text", 16), ("line_number", "integer", 13), ("type", "text", 13), ("database", None, 13), ("user_name", None, 13), ("address", "text", 13), ("netmask", "text", 13), ("auth_method", "text", 13), ("options", None, 13), ("error", "text", 13), ], "pg_ident_file_mappings": [ ("map_number", "integer", 16), ("file_name", "text", 16), ("line_number", "integer", 15), ("map_name", "text", 15), ("sys_name", "text", 15), ("pg_username", "text", 15), ("error", "text", 15), ], "pg_index": [ ("indexrelid", "oid", 13), ("indrelid", "oid", 13), ("indnatts", "smallint", 13), ("indnkeyatts", "smallint", 13), ("indisunique", "boolean", 13), ("indnullsnotdistinct", "boolean", 15), ("indisprimary", "boolean", 13), ("indisexclusion", "boolean", 13), ("indimmediate", "boolean", 13), ("indisclustered", "boolean", 13), ("indisvalid", "boolean", 13), ("indcheckxmin", "boolean", 13), ("indisready", "boolean", 13), ("indislive", "boolean", 13), ("indisreplident", "boolean", 13), ("indkey", None, 13), ("indcollation", None, 13), ("indclass", None, 13), ("indoption", None, 13), ("indexprs", "pg_node_tree", 13), ("indpred", "pg_node_tree", 13), ], "pg_indexes": [ ("schemaname", "name", 13), ("tablename", "name", 13), ("indexname", "name", 13), ("tablespace", "name", 13), ("indexdef", "text", 13), ], "pg_inherits": [ ("inhrelid", "oid", 13), ("inhparent", "oid", 
13), ("inhseqno", "integer", 13), ("inhdetachpending", "boolean", 14), ], "pg_init_privs": [ ("objoid", "oid", 13), ("classoid", "oid", 13), ("objsubid", "integer", 13), ("privtype", "\"char\"", 13), ("initprivs", None, 13), ], "pg_language": [ ("oid", "oid", 13), ("lanname", "name", 13), ("lanowner", "oid", 13), ("lanispl", "boolean", 13), ("lanpltrusted", "boolean", 13), ("lanplcallfoid", "oid", 13), ("laninline", "oid", 13), ("lanvalidator", "oid", 13), ("lanacl", None, 13), ], "pg_largeobject": [ ("loid", "oid", 13), ("pageno", "integer", 13), ("data", "bytea", 13), ], "pg_largeobject_metadata": [ ("oid", "oid", 13), ("lomowner", "oid", 13), ("lomacl", None, 13), ], "pg_locks": [ ("locktype", "text", 13), ("database", "oid", 13), ("relation", "oid", 13), ("page", "integer", 13), ("tuple", "smallint", 13), ("virtualxid", "text", 13), ("transactionid", "xid", 13), ("classid", "oid", 13), ("objid", "oid", 13), ("objsubid", "smallint", 13), ("virtualtransaction", "text", 13), ("pid", "integer", 13), ("mode", "text", 13), ("granted", "boolean", 13), ("fastpath", "boolean", 13), ("waitstart", "timestamp with time zone", 14), ], "pg_matviews": [ ("schemaname", "name", 13), ("matviewname", "name", 13), ("matviewowner", "name", 13), ("tablespace", "name", 13), ("hasindexes", "boolean", 13), ("ispopulated", "boolean", 13), ("definition", "text", 13), ], "pg_namespace": [ ("oid", "oid", 13), ("nspname", "name", 13), ("nspowner", "oid", 13), ("nspacl", None, 13), ], "pg_opclass": [ ("oid", "oid", 13), ("opcmethod", "oid", 13), ("opcname", "name", 13), ("opcnamespace", "oid", 13), ("opcowner", "oid", 13), ("opcfamily", "oid", 13), ("opcintype", "oid", 13), ("opcdefault", "boolean", 13), ("opckeytype", "oid", 13), ], "pg_operator": [ ("oid", "oid", 13), ("oprname", "name", 13), ("oprnamespace", "oid", 13), ("oprowner", "oid", 13), ("oprkind", "\"char\"", 13), ("oprcanmerge", "boolean", 13), ("oprcanhash", "boolean", 13), ("oprleft", "oid", 13), ("oprright", "oid", 13), 
("oprresult", "oid", 13), ("oprcom", "oid", 13), ("oprnegate", "oid", 13), ("oprcode", "regproc", 13), ("oprrest", "regproc", 13), ("oprjoin", "regproc", 13), ], "pg_opfamily": [ ("oid", "oid", 13), ("opfmethod", "oid", 13), ("opfname", "name", 13), ("opfnamespace", "oid", 13), ("opfowner", "oid", 13), ], "pg_parameter_acl": [ ("oid", "oid", 15), ("parname", "text", 15), ("paracl", None, 15), ], "pg_partitioned_table": [ ("partrelid", "oid", 13), ("partstrat", "\"char\"", 13), ("partnatts", "smallint", 13), ("partdefid", "oid", 13), ("partattrs", None, 13), ("partclass", None, 13), ("partcollation", None, 13), ("partexprs", "pg_node_tree", 13), ], "pg_policies": [ ("schemaname", "name", 13), ("tablename", "name", 13), ("policyname", "name", 13), ("permissive", "text", 13), ("roles", None, 13), ("cmd", "text", 13), ("qual", "text", 13), ("with_check", "text", 13), ], "pg_policy": [ ("oid", "oid", 13), ("polname", "name", 13), ("polrelid", "oid", 13), ("polcmd", "\"char\"", 13), ("polpermissive", "boolean", 13), ("polroles", None, 13), ("polqual", "pg_node_tree", 13), ("polwithcheck", "pg_node_tree", 13), ], "pg_prepared_statements": [ ("name", "text", 13), ("statement", "text", 13), ("prepare_time", "timestamp with time zone", 13), ("parameter_types", None, 13), ("result_types", None, 16), ("from_sql", "boolean", 13), ("generic_plans", "bigint", 14), ("custom_plans", "bigint", 14), ], "pg_prepared_xacts": [ ("transaction", "xid", 13), ("gid", "text", 13), ("prepared", "timestamp with time zone", 13), ("owner", "name", 13), ("database", "name", 13), ], "pg_proc": [ ("oid", "oid", 13), ("proname", "name", 13), ("pronamespace", "oid", 13), ("proowner", "oid", 13), ("prolang", "oid", 13), ("procost", "real", 13), ("prorows", "real", 13), ("provariadic", "oid", 13), ("prosupport", "regproc", 13), ("prokind", "\"char\"", 13), ("prosecdef", "boolean", 13), ("proleakproof", "boolean", 13), ("proisstrict", "boolean", 13), ("proretset", "boolean", 13), ("provolatile", 
"\"char\"", 13), ("proparallel", "\"char\"", 13), ("pronargs", "smallint", 13), ("pronargdefaults", "smallint", 13), ("prorettype", "oid", 13), ("proargtypes", None, 13), ("proallargtypes", None, 13), ("proargmodes", None, 13), ("proargnames", None, 13), ("proargdefaults", "pg_node_tree", 13), ("protrftypes", None, 13), ("prosrc", "text", 13), ("probin", "text", 13), ("prosqlbody", "pg_node_tree", 14), ("proconfig", None, 13), ("proacl", None, 13), ], "pg_publication": [ ("oid", "oid", 13), ("pubname", "name", 13), ("pubowner", "oid", 13), ("puballtables", "boolean", 13), ("pubinsert", "boolean", 13), ("pubupdate", "boolean", 13), ("pubdelete", "boolean", 13), ("pubtruncate", "boolean", 13), ("pubviaroot", "boolean", 13), ], "pg_publication_namespace": [ ("oid", "oid", 15), ("pnpubid", "oid", 15), ("pnnspid", "oid", 15), ], "pg_publication_rel": [ ("oid", "oid", 13), ("prpubid", "oid", 13), ("prrelid", "oid", 13), ("prqual", "pg_node_tree", 15), ("prattrs", None, 15), ], "pg_publication_tables": [ ("pubname", "name", 13), ("schemaname", "name", 13), ("tablename", "name", 13), ("attnames", None, 15), ("rowfilter", "text", 15), ], "pg_range": [ ("rngtypid", "oid", 13), ("rngsubtype", "oid", 13), ("rngmultitypid", "oid", 14), ("rngcollation", "oid", 13), ("rngsubopc", "oid", 13), ("rngcanonical", "regproc", 13), ("rngsubdiff", "regproc", 13), ], "pg_replication_origin": [ ("roident", "oid", 13), ("roname", "text", 13), ], "pg_replication_origin_status": [ ("local_id", "oid", 13), ("external_id", "text", 13), ("remote_lsn", "pg_lsn", 13), ("local_lsn", "pg_lsn", 13), ], "pg_replication_slots": [ ("slot_name", "name", 13), ("plugin", "name", 13), ("slot_type", "text", 13), ("datoid", "oid", 13), ("database", "name", 13), ("temporary", "boolean", 13), ("active", "boolean", 13), ("active_pid", "integer", 13), ("xmin", "xid", 13), ("catalog_xmin", "xid", 13), ("restart_lsn", "pg_lsn", 13), ("confirmed_flush_lsn", "pg_lsn", 13), ("wal_status", "text", 13), ("safe_wal_size", 
"bigint", 13), ("two_phase", "boolean", 14), ("inactive_since", "timestamp with time zone", 17), ("conflicting", "boolean", 16), ("invalidation_reason", "text", 17), ("failover", "boolean", 17), ("synced", "boolean", 17), ], "pg_rewrite": [ ("oid", "oid", 13), ("rulename", "name", 13), ("ev_class", "oid", 13), ("ev_type", "\"char\"", 13), ("ev_enabled", "\"char\"", 13), ("is_instead", "boolean", 13), ("ev_qual", "pg_node_tree", 13), ("ev_action", "pg_node_tree", 13), ], "pg_roles": [ ("rolname", "name", 13), ("rolsuper", "boolean", 13), ("rolinherit", "boolean", 13), ("rolcreaterole", "boolean", 13), ("rolcreatedb", "boolean", 13), ("rolcanlogin", "boolean", 13), ("rolreplication", "boolean", 13), ("rolconnlimit", "integer", 13), ("rolpassword", "text", 13), ("rolvaliduntil", "timestamp with time zone", 13), ("rolbypassrls", "boolean", 13), ("rolconfig", None, 13), ("oid", "oid", 13), ], "pg_rules": [ ("schemaname", "name", 13), ("tablename", "name", 13), ("rulename", "name", 13), ("definition", "text", 13), ], "pg_seclabel": [ ("objoid", "oid", 13), ("classoid", "oid", 13), ("objsubid", "integer", 13), ("provider", "text", 13), ("label", "text", 13), ], "pg_seclabels": [ ("objoid", "oid", 13), ("classoid", "oid", 13), ("objsubid", "integer", 13), ("objtype", "text", 13), ("objnamespace", "oid", 13), ("objname", "text", 13), ("provider", "text", 13), ("label", "text", 13), ], "pg_sequence": [ ("seqrelid", "oid", 13), ("seqtypid", "oid", 13), ("seqstart", "bigint", 13), ("seqincrement", "bigint", 13), ("seqmax", "bigint", 13), ("seqmin", "bigint", 13), ("seqcache", "bigint", 13), ("seqcycle", "boolean", 13), ], "pg_sequences": [ ("schemaname", "name", 13), ("sequencename", "name", 13), ("sequenceowner", "name", 13), ("data_type", "regtype", 13), ("start_value", "bigint", 13), ("min_value", "bigint", 13), ("max_value", "bigint", 13), ("increment_by", "bigint", 13), ("cycle", "boolean", 13), ("cache_size", "bigint", 13), ("last_value", "bigint", 13), ], "pg_settings": 
[ ("name", "text", 13), ("setting", "text", 13), ("unit", "text", 13), ("category", "text", 13), ("short_desc", "text", 13), ("extra_desc", "text", 13), ("context", "text", 13), ("vartype", "text", 13), ("source", "text", 13), ("min_val", "text", 13), ("max_val", "text", 13), ("enumvals", None, 13), ("boot_val", "text", 13), ("reset_val", "text", 13), ("sourcefile", "text", 13), ("sourceline", "integer", 13), ("pending_restart", "boolean", 13), ], "pg_shadow": [ ("usename", "name", 13), ("usesysid", "oid", 13), ("usecreatedb", "boolean", 13), ("usesuper", "boolean", 13), ("userepl", "boolean", 13), ("usebypassrls", "boolean", 13), ("passwd", "text", 13), ("valuntil", "timestamp with time zone", 13), ("useconfig", None, 13), ], "pg_shdepend": [ ("dbid", "oid", 13), ("classid", "oid", 13), ("objid", "oid", 13), ("objsubid", "integer", 13), ("refclassid", "oid", 13), ("refobjid", "oid", 13), ("deptype", "\"char\"", 13), ], "pg_shdescription": [ ("objoid", "oid", 13), ("classoid", "oid", 13), ("description", "text", 13), ], "pg_shmem_allocations": [ ("name", "text", 13), ("off", "bigint", 13), ("size", "bigint", 13), ("allocated_size", "bigint", 13), ], "pg_shseclabel": [ ("objoid", "oid", 13), ("classoid", "oid", 13), ("provider", "text", 13), ("label", "text", 13), ], "pg_stat_activity": [ ("datid", "oid", 13), ("datname", "name", 13), ("pid", "integer", 13), ("leader_pid", "integer", 13), ("usesysid", "oid", 13), ("usename", "name", 13), ("application_name", "text", 13), ("client_addr", "inet", 13), ("client_hostname", "text", 13), ("client_port", "integer", 13), ("backend_start", "timestamp with time zone", 13), ("xact_start", "timestamp with time zone", 13), ("query_start", "timestamp with time zone", 13), ("state_change", "timestamp with time zone", 13), ("wait_event_type", "text", 13), ("wait_event", "text", 13), ("state", "text", 13), ("backend_xid", "xid", 13), ("backend_xmin", "xid", 13), ("query_id", "bigint", 14), ("query", "text", 13), ("backend_type", 
"text", 13), ], "pg_stat_all_indexes": [ ("relid", "oid", 13), ("indexrelid", "oid", 13), ("schemaname", "name", 13), ("relname", "name", 13), ("indexrelname", "name", 13), ("idx_scan", "bigint", 13), ("last_idx_scan", "timestamp with time zone", 16), ("idx_tup_read", "bigint", 13), ("idx_tup_fetch", "bigint", 13), ], "pg_stat_all_tables": [ ("relid", "oid", 13), ("schemaname", "name", 13), ("relname", "name", 13), ("seq_scan", "bigint", 13), ("last_seq_scan", "timestamp with time zone", 16), ("seq_tup_read", "bigint", 13), ("idx_scan", "bigint", 13), ("last_idx_scan", "timestamp with time zone", 16), ("idx_tup_fetch", "bigint", 13), ("n_tup_ins", "bigint", 13), ("n_tup_upd", "bigint", 13), ("n_tup_del", "bigint", 13), ("n_tup_hot_upd", "bigint", 13), ("n_tup_newpage_upd", "bigint", 16), ("n_live_tup", "bigint", 13), ("n_dead_tup", "bigint", 13), ("n_mod_since_analyze", "bigint", 13), ("n_ins_since_vacuum", "bigint", 13), ("last_vacuum", "timestamp with time zone", 13), ("last_autovacuum", "timestamp with time zone", 13), ("last_analyze", "timestamp with time zone", 13), ("last_autoanalyze", "timestamp with time zone", 13), ("vacuum_count", "bigint", 13), ("autovacuum_count", "bigint", 13), ("analyze_count", "bigint", 13), ("autoanalyze_count", "bigint", 13), ], "pg_stat_archiver": [ ("archived_count", "bigint", 13), ("last_archived_wal", "text", 13), ("last_archived_time", "timestamp with time zone", 13), ("failed_count", "bigint", 13), ("last_failed_wal", "text", 13), ("last_failed_time", "timestamp with time zone", 13), ("stats_reset", "timestamp with time zone", 13), ], "pg_stat_bgwriter": [ ("buffers_clean", "bigint", 13), ("maxwritten_clean", "bigint", 13), ("buffers_alloc", "bigint", 13), ("stats_reset", "timestamp with time zone", 13), ], "pg_stat_checkpointer": [ ("num_timed", "bigint", 17), ("num_requested", "bigint", 17), ("restartpoints_timed", "bigint", 17), ("restartpoints_req", "bigint", 17), ("restartpoints_done", "bigint", 17), ("write_time", 
"double precision", 17), ("sync_time", "double precision", 17), ("buffers_written", "bigint", 17), ("stats_reset", "timestamp with time zone", 17), ], "pg_stat_database": [ ("datid", "oid", 13), ("datname", "name", 13), ("numbackends", "integer", 13), ("xact_commit", "bigint", 13), ("xact_rollback", "bigint", 13), ("blks_read", "bigint", 13), ("blks_hit", "bigint", 13), ("tup_returned", "bigint", 13), ("tup_fetched", "bigint", 13), ("tup_inserted", "bigint", 13), ("tup_updated", "bigint", 13), ("tup_deleted", "bigint", 13), ("conflicts", "bigint", 13), ("temp_files", "bigint", 13), ("temp_bytes", "bigint", 13), ("deadlocks", "bigint", 13), ("checksum_failures", "bigint", 13), ("checksum_last_failure", "timestamp with time zone", 13), ("blk_read_time", "double precision", 13), ("blk_write_time", "double precision", 13), ("session_time", "double precision", 14), ("active_time", "double precision", 14), ("idle_in_transaction_time", "double precision", 14), ("sessions", "bigint", 14), ("sessions_abandoned", "bigint", 14), ("sessions_fatal", "bigint", 14), ("sessions_killed", "bigint", 14), ("stats_reset", "timestamp with time zone", 13), ], "pg_stat_database_conflicts": [ ("datid", "oid", 13), ("datname", "name", 13), ("confl_tablespace", "bigint", 13), ("confl_lock", "bigint", 13), ("confl_snapshot", "bigint", 13), ("confl_bufferpin", "bigint", 13), ("confl_deadlock", "bigint", 13), ("confl_active_logicalslot", "bigint", 16), ], "pg_stat_gssapi": [ ("pid", "integer", 13), ("gss_authenticated", "boolean", 13), ("principal", "text", 13), ("encrypted", "boolean", 13), ("credentials_delegated", "boolean", 16), ], "pg_stat_io": [ ("backend_type", "text", 16), ("object", "text", 16), ("context", "text", 16), ("reads", "bigint", 16), ("read_time", "double precision", 16), ("writes", "bigint", 16), ("write_time", "double precision", 16), ("writebacks", "bigint", 16), ("writeback_time", "double precision", 16), ("extends", "bigint", 16), ("extend_time", "double precision", 
16), ("op_bytes", "bigint", 16), ("hits", "bigint", 16), ("evictions", "bigint", 16), ("reuses", "bigint", 16), ("fsyncs", "bigint", 16), ("fsync_time", "double precision", 16), ("stats_reset", "timestamp with time zone", 16), ], "pg_stat_progress_analyze": [ ("pid", "integer", 13), ("datid", "oid", 13), ("datname", "name", 13), ("relid", "oid", 13), ("phase", "text", 13), ("sample_blks_total", "bigint", 13), ("sample_blks_scanned", "bigint", 13), ("ext_stats_total", "bigint", 13), ("ext_stats_computed", "bigint", 13), ("child_tables_total", "bigint", 13), ("child_tables_done", "bigint", 13), ("current_child_table_relid", "oid", 13), ], "pg_stat_progress_basebackup": [ ("pid", "integer", 13), ("phase", "text", 13), ("backup_total", "bigint", 13), ("backup_streamed", "bigint", 13), ("tablespaces_total", "bigint", 13), ("tablespaces_streamed", "bigint", 13), ], "pg_stat_progress_cluster": [ ("pid", "integer", 13), ("datid", "oid", 13), ("datname", "name", 13), ("relid", "oid", 13), ("command", "text", 13), ("phase", "text", 13), ("cluster_index_relid", "oid", 13), ("heap_tuples_scanned", "bigint", 13), ("heap_tuples_written", "bigint", 13), ("heap_blks_total", "bigint", 13), ("heap_blks_scanned", "bigint", 13), ("index_rebuild_count", "bigint", 13), ], "pg_stat_progress_copy": [ ("pid", "integer", 14), ("datid", "oid", 14), ("datname", "name", 14), ("relid", "oid", 14), ("command", "text", 14), ("type", "text", 14), ("bytes_processed", "bigint", 14), ("bytes_total", "bigint", 14), ("tuples_processed", "bigint", 14), ("tuples_excluded", "bigint", 14), ("tuples_skipped", "bigint", 17), ], "pg_stat_progress_create_index": [ ("pid", "integer", 13), ("datid", "oid", 13), ("datname", "name", 13), ("relid", "oid", 13), ("index_relid", "oid", 13), ("command", "text", 13), ("phase", "text", 13), ("lockers_total", "bigint", 13), ("lockers_done", "bigint", 13), ("current_locker_pid", "bigint", 13), ("blocks_total", "bigint", 13), ("blocks_done", "bigint", 13), ("tuples_total", 
"bigint", 13), ("tuples_done", "bigint", 13), ("partitions_total", "bigint", 13), ("partitions_done", "bigint", 13), ], "pg_stat_progress_vacuum": [ ("pid", "integer", 13), ("datid", "oid", 13), ("datname", "name", 13), ("relid", "oid", 13), ("phase", "text", 13), ("heap_blks_total", "bigint", 13), ("heap_blks_scanned", "bigint", 13), ("heap_blks_vacuumed", "bigint", 13), ("index_vacuum_count", "bigint", 13), ("max_dead_tuple_bytes", "bigint", 17), ("dead_tuple_bytes", "bigint", 17), ("num_dead_item_ids", "bigint", 17), ("indexes_total", "bigint", 17), ("indexes_processed", "bigint", 17), ], "pg_stat_recovery_prefetch": [ ("stats_reset", "timestamp with time zone", 15), ("prefetch", "bigint", 15), ("hit", "bigint", 15), ("skip_init", "bigint", 15), ("skip_new", "bigint", 15), ("skip_fpw", "bigint", 15), ("skip_rep", "bigint", 15), ("wal_distance", "integer", 15), ("block_distance", "integer", 15), ("io_depth", "integer", 15), ], "pg_stat_replication": [ ("pid", "integer", 13), ("usesysid", "oid", 13), ("usename", "name", 13), ("application_name", "text", 13), ("client_addr", "inet", 13), ("client_hostname", "text", 13), ("client_port", "integer", 13), ("backend_start", "timestamp with time zone", 13), ("backend_xmin", "xid", 13), ("state", "text", 13), ("sent_lsn", "pg_lsn", 13), ("write_lsn", "pg_lsn", 13), ("flush_lsn", "pg_lsn", 13), ("replay_lsn", "pg_lsn", 13), ("write_lag", "interval", 13), ("flush_lag", "interval", 13), ("replay_lag", "interval", 13), ("sync_priority", "integer", 13), ("sync_state", "text", 13), ("reply_time", "timestamp with time zone", 13), ], "pg_stat_replication_slots": [ ("slot_name", "text", 14), ("spill_txns", "bigint", 14), ("spill_count", "bigint", 14), ("spill_bytes", "bigint", 14), ("stream_txns", "bigint", 14), ("stream_count", "bigint", 14), ("stream_bytes", "bigint", 14), ("total_txns", "bigint", 14), ("total_bytes", "bigint", 14), ("stats_reset", "timestamp with time zone", 14), ], "pg_stat_slru": [ ("name", "text", 13), 
("blks_zeroed", "bigint", 13), ("blks_hit", "bigint", 13), ("blks_read", "bigint", 13), ("blks_written", "bigint", 13), ("blks_exists", "bigint", 13), ("flushes", "bigint", 13), ("truncates", "bigint", 13), ("stats_reset", "timestamp with time zone", 13), ], "pg_stat_ssl": [ ("pid", "integer", 13), ("ssl", "boolean", 13), ("version", "text", 13), ("cipher", "text", 13), ("bits", "integer", 13), ("client_dn", "text", 13), ("client_serial", "numeric", 13), ("issuer_dn", "text", 13), ], "pg_stat_subscription": [ ("subid", "oid", 13), ("subname", "name", 13), ("worker_type", "text", 17), ("pid", "integer", 13), ("leader_pid", "integer", 16), ("relid", "oid", 13), ("received_lsn", "pg_lsn", 13), ("last_msg_send_time", "timestamp with time zone", 13), ("last_msg_receipt_time", "timestamp with time zone", 13), ("latest_end_lsn", "pg_lsn", 13), ("latest_end_time", "timestamp with time zone", 13), ], "pg_stat_subscription_stats": [ ("subid", "oid", 15), ("subname", "name", 15), ("apply_error_count", "bigint", 15), ("sync_error_count", "bigint", 15), ("stats_reset", "timestamp with time zone", 15), ], "pg_stat_sys_indexes": [ ("relid", "oid", 13), ("indexrelid", "oid", 13), ("schemaname", "name", 13), ("relname", "name", 13), ("indexrelname", "name", 13), ("idx_scan", "bigint", 13), ("last_idx_scan", "timestamp with time zone", 16), ("idx_tup_read", "bigint", 13), ("idx_tup_fetch", "bigint", 13), ], "pg_stat_sys_tables": [ ("relid", "oid", 13), ("schemaname", "name", 13), ("relname", "name", 13), ("seq_scan", "bigint", 13), ("last_seq_scan", "timestamp with time zone", 16), ("seq_tup_read", "bigint", 13), ("idx_scan", "bigint", 13), ("last_idx_scan", "timestamp with time zone", 16), ("idx_tup_fetch", "bigint", 13), ("n_tup_ins", "bigint", 13), ("n_tup_upd", "bigint", 13), ("n_tup_del", "bigint", 13), ("n_tup_hot_upd", "bigint", 13), ("n_tup_newpage_upd", "bigint", 16), ("n_live_tup", "bigint", 13), ("n_dead_tup", "bigint", 13), ("n_mod_since_analyze", "bigint", 13), 
("n_ins_since_vacuum", "bigint", 13), ("last_vacuum", "timestamp with time zone", 13), ("last_autovacuum", "timestamp with time zone", 13), ("last_analyze", "timestamp with time zone", 13), ("last_autoanalyze", "timestamp with time zone", 13), ("vacuum_count", "bigint", 13), ("autovacuum_count", "bigint", 13), ("analyze_count", "bigint", 13), ("autoanalyze_count", "bigint", 13), ], "pg_stat_user_functions": [ ("funcid", "oid", 13), ("schemaname", "name", 13), ("funcname", "name", 13), ("calls", "bigint", 13), ("total_time", "double precision", 13), ("self_time", "double precision", 13), ], "pg_stat_user_indexes": [ ("relid", "oid", 13), ("indexrelid", "oid", 13), ("schemaname", "name", 13), ("relname", "name", 13), ("indexrelname", "name", 13), ("idx_scan", "bigint", 13), ("last_idx_scan", "timestamp with time zone", 16), ("idx_tup_read", "bigint", 13), ("idx_tup_fetch", "bigint", 13), ], "pg_stat_user_tables": [ ("relid", "oid", 13), ("schemaname", "name", 13), ("relname", "name", 13), ("seq_scan", "bigint", 13), ("last_seq_scan", "timestamp with time zone", 16), ("seq_tup_read", "bigint", 13), ("idx_scan", "bigint", 13), ("last_idx_scan", "timestamp with time zone", 16), ("idx_tup_fetch", "bigint", 13), ("n_tup_ins", "bigint", 13), ("n_tup_upd", "bigint", 13), ("n_tup_del", "bigint", 13), ("n_tup_hot_upd", "bigint", 13), ("n_tup_newpage_upd", "bigint", 16), ("n_live_tup", "bigint", 13), ("n_dead_tup", "bigint", 13), ("n_mod_since_analyze", "bigint", 13), ("n_ins_since_vacuum", "bigint", 13), ("last_vacuum", "timestamp with time zone", 13), ("last_autovacuum", "timestamp with time zone", 13), ("last_analyze", "timestamp with time zone", 13), ("last_autoanalyze", "timestamp with time zone", 13), ("vacuum_count", "bigint", 13), ("autovacuum_count", "bigint", 13), ("analyze_count", "bigint", 13), ("autoanalyze_count", "bigint", 13), ], "pg_stat_wal": [ ("wal_records", "bigint", 14), ("wal_fpi", "bigint", 14), ("wal_bytes", "numeric", 14), ("wal_buffers_full", 
"bigint", 14), ("wal_write", "bigint", 14), ("wal_sync", "bigint", 14), ("wal_write_time", "double precision", 14), ("wal_sync_time", "double precision", 14), ("stats_reset", "timestamp with time zone", 14), ], "pg_stat_wal_receiver": [ ("pid", "integer", 13), ("status", "text", 13), ("receive_start_lsn", "pg_lsn", 13), ("receive_start_tli", "integer", 13), ("written_lsn", "pg_lsn", 13), ("flushed_lsn", "pg_lsn", 13), ("received_tli", "integer", 13), ("last_msg_send_time", "timestamp with time zone", 13), ("last_msg_receipt_time", "timestamp with time zone", 13), ("latest_end_lsn", "pg_lsn", 13), ("latest_end_time", "timestamp with time zone", 13), ("slot_name", "text", 13), ("sender_host", "text", 13), ("sender_port", "integer", 13), ("conninfo", "text", 13), ], "pg_stat_xact_all_tables": [ ("relid", "oid", 13), ("schemaname", "name", 13), ("relname", "name", 13), ("seq_scan", "bigint", 13), ("seq_tup_read", "bigint", 13), ("idx_scan", "bigint", 13), ("idx_tup_fetch", "bigint", 13), ("n_tup_ins", "bigint", 13), ("n_tup_upd", "bigint", 13), ("n_tup_del", "bigint", 13), ("n_tup_hot_upd", "bigint", 13), ("n_tup_newpage_upd", "bigint", 16), ], "pg_stat_xact_sys_tables": [ ("relid", "oid", 13), ("schemaname", "name", 13), ("relname", "name", 13), ("seq_scan", "bigint", 13), ("seq_tup_read", "bigint", 13), ("idx_scan", "bigint", 13), ("idx_tup_fetch", "bigint", 13), ("n_tup_ins", "bigint", 13), ("n_tup_upd", "bigint", 13), ("n_tup_del", "bigint", 13), ("n_tup_hot_upd", "bigint", 13), ("n_tup_newpage_upd", "bigint", 16), ], "pg_stat_xact_user_functions": [ ("funcid", "oid", 13), ("schemaname", "name", 13), ("funcname", "name", 13), ("calls", "bigint", 13), ("total_time", "double precision", 13), ("self_time", "double precision", 13), ], "pg_stat_xact_user_tables": [ ("relid", "oid", 13), ("schemaname", "name", 13), ("relname", "name", 13), ("seq_scan", "bigint", 13), ("seq_tup_read", "bigint", 13), ("idx_scan", "bigint", 13), ("idx_tup_fetch", "bigint", 13), 
("n_tup_ins", "bigint", 13), ("n_tup_upd", "bigint", 13), ("n_tup_del", "bigint", 13), ("n_tup_hot_upd", "bigint", 13), ("n_tup_newpage_upd", "bigint", 16), ], "pg_statio_all_indexes": [ ("relid", "oid", 13), ("indexrelid", "oid", 13), ("schemaname", "name", 13), ("relname", "name", 13), ("indexrelname", "name", 13), ("idx_blks_read", "bigint", 13), ("idx_blks_hit", "bigint", 13), ], "pg_statio_all_sequences": [ ("relid", "oid", 13), ("schemaname", "name", 13), ("relname", "name", 13), ("blks_read", "bigint", 13), ("blks_hit", "bigint", 13), ], "pg_statio_all_tables": [ ("relid", "oid", 13), ("schemaname", "name", 13), ("relname", "name", 13), ("heap_blks_read", "bigint", 13), ("heap_blks_hit", "bigint", 13), ("idx_blks_read", "bigint", 13), ("idx_blks_hit", "bigint", 13), ("toast_blks_read", "bigint", 13), ("toast_blks_hit", "bigint", 13), ("tidx_blks_read", "bigint", 13), ("tidx_blks_hit", "bigint", 13), ], "pg_statio_sys_indexes": [ ("relid", "oid", 13), ("indexrelid", "oid", 13), ("schemaname", "name", 13), ("relname", "name", 13), ("indexrelname", "name", 13), ("idx_blks_read", "bigint", 13), ("idx_blks_hit", "bigint", 13), ], "pg_statio_sys_sequences": [ ("relid", "oid", 13), ("schemaname", "name", 13), ("relname", "name", 13), ("blks_read", "bigint", 13), ("blks_hit", "bigint", 13), ], "pg_statio_sys_tables": [ ("relid", "oid", 13), ("schemaname", "name", 13), ("relname", "name", 13), ("heap_blks_read", "bigint", 13), ("heap_blks_hit", "bigint", 13), ("idx_blks_read", "bigint", 13), ("idx_blks_hit", "bigint", 13), ("toast_blks_read", "bigint", 13), ("toast_blks_hit", "bigint", 13), ("tidx_blks_read", "bigint", 13), ("tidx_blks_hit", "bigint", 13), ], "pg_statio_user_indexes": [ ("relid", "oid", 13), ("indexrelid", "oid", 13), ("schemaname", "name", 13), ("relname", "name", 13), ("indexrelname", "name", 13), ("idx_blks_read", "bigint", 13), ("idx_blks_hit", "bigint", 13), ], "pg_statio_user_sequences": [ ("relid", "oid", 13), ("schemaname", "name", 13), 
("relname", "name", 13), ("blks_read", "bigint", 13), ("blks_hit", "bigint", 13), ], "pg_statio_user_tables": [ ("relid", "oid", 13), ("schemaname", "name", 13), ("relname", "name", 13), ("heap_blks_read", "bigint", 13), ("heap_blks_hit", "bigint", 13), ("idx_blks_read", "bigint", 13), ("idx_blks_hit", "bigint", 13), ("toast_blks_read", "bigint", 13), ("toast_blks_hit", "bigint", 13), ("tidx_blks_read", "bigint", 13), ("tidx_blks_hit", "bigint", 13), ], "pg_statistic": [ ("starelid", "oid", 13), ("staattnum", "smallint", 13), ("stainherit", "boolean", 13), ("stanullfrac", "real", 13), ("stawidth", "integer", 13), ("stadistinct", "real", 13), ("stakind1", "smallint", 13), ("stakind2", "smallint", 13), ("stakind3", "smallint", 13), ("stakind4", "smallint", 13), ("stakind5", "smallint", 13), ("staop1", "oid", 13), ("staop2", "oid", 13), ("staop3", "oid", 13), ("staop4", "oid", 13), ("staop5", "oid", 13), ("stacoll1", "oid", 13), ("stacoll2", "oid", 13), ("stacoll3", "oid", 13), ("stacoll4", "oid", 13), ("stacoll5", "oid", 13), ("stanumbers1", None, 13), ("stanumbers2", None, 13), ("stanumbers3", None, 13), ("stanumbers4", None, 13), ("stanumbers5", None, 13), ("stavalues1", None, 13), ("stavalues2", None, 13), ("stavalues3", None, 13), ("stavalues4", None, 13), ("stavalues5", None, 13), ], "pg_statistic_ext": [ ("oid", "oid", 13), ("stxrelid", "oid", 13), ("stxname", "name", 13), ("stxnamespace", "oid", 13), ("stxowner", "oid", 13), ("stxkeys", None, 13), ("stxstattarget", "smallint", 13), ("stxkind", None, 13), ("stxexprs", "pg_node_tree", 14), ], "pg_statistic_ext_data": [ ("stxoid", "oid", 13), ("stxdinherit", "boolean", 15), ("stxdndistinct", "pg_ndistinct", 13), ("stxddependencies", "pg_dependencies", 13), ("stxdmcv", "pg_mcv_list", 13), ("stxdexpr", None, 14), ], "pg_stats": [ ("schemaname", "name", 13), ("tablename", "name", 13), ("attname", "name", 13), ("inherited", "boolean", 13), ("null_frac", "real", 13), ("avg_width", "integer", 13), ("n_distinct", 
"real", 13), ("most_common_vals", None, 13), ("most_common_freqs", None, 13), ("histogram_bounds", None, 13), ("correlation", "real", 13), ("most_common_elems", None, 13), ("most_common_elem_freqs", None, 13), ("elem_count_histogram", None, 13), ("range_length_histogram", None, 17), ("range_empty_frac", "real", 17), ("range_bounds_histogram", None, 17), ], "pg_stats_ext": [ ("schemaname", "name", 13), ("tablename", "name", 13), ("statistics_schemaname", "name", 13), ("statistics_name", "name", 13), ("statistics_owner", "name", 13), ("attnames", None, 13), ("exprs", None, 14), ("kinds", None, 13), ("inherited", "boolean", 15), ("n_distinct", "pg_ndistinct", 13), ("dependencies", "pg_dependencies", 13), ("most_common_vals", None, 13), ("most_common_val_nulls", None, 13), ("most_common_freqs", None, 13), ("most_common_base_freqs", None, 13), ], "pg_stats_ext_exprs": [ ("schemaname", "name", 14), ("tablename", "name", 14), ("statistics_schemaname", "name", 14), ("statistics_name", "name", 14), ("statistics_owner", "name", 14), ("expr", "text", 14), ("inherited", "boolean", 15), ("null_frac", "real", 14), ("avg_width", "integer", 14), ("n_distinct", "real", 14), ("most_common_vals", None, 14), ("most_common_freqs", None, 14), ("histogram_bounds", None, 14), ("correlation", "real", 14), ("most_common_elems", None, 14), ("most_common_elem_freqs", None, 14), ("elem_count_histogram", None, 14), ], "pg_subscription": [ ("oid", "oid", 13), ("subdbid", "oid", 13), ("subskiplsn", "pg_lsn", 15), ("subname", "name", 13), ("subowner", "oid", 13), ("subenabled", "boolean", 13), ("subbinary", "boolean", 14), ("substream", "\"char\"", 14), ("subtwophasestate", "\"char\"", 15), ("subdisableonerr", "boolean", 15), ("subpasswordrequired", "boolean", 16), ("subrunasowner", "boolean", 16), ("subfailover", "boolean", 17), ("subconninfo", "text", 13), ("subslotname", "name", 13), ("subsynccommit", "text", 13), ("subpublications", None, 13), ("suborigin", "text", 16), ], 
"pg_subscription_rel": [ ("srsubid", "oid", 13), ("srrelid", "oid", 13), ("srsubstate", "\"char\"", 13), ("srsublsn", "pg_lsn", 13), ], "pg_tables": [ ("schemaname", "name", 13), ("tablename", "name", 13), ("tableowner", "name", 13), ("tablespace", "name", 13), ("hasindexes", "boolean", 13), ("hasrules", "boolean", 13), ("hastriggers", "boolean", 13), ("rowsecurity", "boolean", 13), ], "pg_tablespace": [ ("oid", "oid", 13), ("spcname", "name", 13), ("spcowner", "oid", 13), ("spcacl", None, 13), ("spcoptions", None, 13), ], "pg_timezone_abbrevs": [ ("abbrev", "text", 13), ("utc_offset", "interval", 13), ("is_dst", "boolean", 13), ], "pg_timezone_names": [ ("name", "text", 13), ("abbrev", "text", 13), ("utc_offset", "interval", 13), ("is_dst", "boolean", 13), ], "pg_transform": [ ("oid", "oid", 13), ("trftype", "oid", 13), ("trflang", "oid", 13), ("trffromsql", "regproc", 13), ("trftosql", "regproc", 13), ], "pg_trigger": [ ("oid", "oid", 13), ("tgrelid", "oid", 13), ("tgparentid", "oid", 13), ("tgname", "name", 13), ("tgfoid", "oid", 13), ("tgtype", "smallint", 13), ("tgenabled", "\"char\"", 13), ("tgisinternal", "boolean", 13), ("tgconstrrelid", "oid", 13), ("tgconstrindid", "oid", 13), ("tgconstraint", "oid", 13), ("tgdeferrable", "boolean", 13), ("tginitdeferred", "boolean", 13), ("tgnargs", "smallint", 13), ("tgattr", None, 13), ("tgargs", "bytea", 13), ("tgqual", "pg_node_tree", 13), ("tgoldtable", "name", 13), ("tgnewtable", "name", 13), ], "pg_ts_config": [ ("oid", "oid", 13), ("cfgname", "name", 13), ("cfgnamespace", "oid", 13), ("cfgowner", "oid", 13), ("cfgparser", "oid", 13), ], "pg_ts_config_map": [ ("mapcfg", "oid", 13), ("maptokentype", "integer", 13), ("mapseqno", "integer", 13), ("mapdict", "oid", 13), ], "pg_ts_dict": [ ("oid", "oid", 13), ("dictname", "name", 13), ("dictnamespace", "oid", 13), ("dictowner", "oid", 13), ("dicttemplate", "oid", 13), ("dictinitoption", "text", 13), ], "pg_ts_parser": [ ("oid", "oid", 13), ("prsname", "name", 13), 
("prsnamespace", "oid", 13), ("prsstart", "regproc", 13), ("prstoken", "regproc", 13), ("prsend", "regproc", 13), ("prsheadline", "regproc", 13), ("prslextype", "regproc", 13), ], "pg_ts_template": [ ("oid", "oid", 13), ("tmplname", "name", 13), ("tmplnamespace", "oid", 13), ("tmplinit", "regproc", 13), ("tmpllexize", "regproc", 13), ], "pg_type": [ ("oid", "oid", 13), ("typname", "name", 13), ("typnamespace", "oid", 13), ("typowner", "oid", 13), ("typlen", "smallint", 13), ("typbyval", "boolean", 13), ("typtype", "\"char\"", 13), ("typcategory", "\"char\"", 13), ("typispreferred", "boolean", 13), ("typisdefined", "boolean", 13), ("typdelim", "\"char\"", 13), ("typrelid", "oid", 13), ("typsubscript", "regproc", 14), ("typelem", "oid", 13), ("typarray", "oid", 13), ("typinput", "regproc", 13), ("typoutput", "regproc", 13), ("typreceive", "regproc", 13), ("typsend", "regproc", 13), ("typmodin", "regproc", 13), ("typmodout", "regproc", 13), ("typanalyze", "regproc", 13), ("typalign", "\"char\"", 13), ("typstorage", "\"char\"", 13), ("typnotnull", "boolean", 13), ("typbasetype", "oid", 13), ("typtypmod", "integer", 13), ("typndims", "integer", 13), ("typcollation", "oid", 13), ("typdefaultbin", "pg_node_tree", 13), ("typdefault", "text", 13), ("typacl", None, 13), ], "pg_user": [ ("usename", "name", 13), ("usesysid", "oid", 13), ("usecreatedb", "boolean", 13), ("usesuper", "boolean", 13), ("userepl", "boolean", 13), ("usebypassrls", "boolean", 13), ("passwd", "text", 13), ("valuntil", "timestamp with time zone", 13), ("useconfig", None, 13), ], "pg_user_mapping": [ ("oid", "oid", 13), ("umuser", "oid", 13), ("umserver", "oid", 13), ("umoptions", None, 13), ], "pg_user_mappings": [ ("umid", "oid", 13), ("srvid", "oid", 13), ("srvname", "name", 13), ("umuser", "oid", 13), ("usename", "name", 13), ("umoptions", None, 13), ], "pg_views": [ ("schemaname", "name", 13), ("viewname", "name", 13), ("viewowner", "name", 13), ("definition", "text", 13), ], "pg_wait_events": [ 
("type", "text", 17), ("name", "text", 17), ("description", "text", 17), ] } ================================================ FILE: edb/pgsql/resolver/static.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Static evaluation for SQL.""" import functools import platform from typing import Optional, Sequence from edb import errors from edb.pgsql import ast as pgast from edb.pgsql.ast import SQLValueFunctionOP as val_func_op from edb.pgsql import common from edb.server import defines from edb.server.pgcon import errors as pgerror from edb.server.compiler import enums from edb.server.compiler.sql import DisableNormalization from . import context from . import dispatch V = common.versioned_schema Context = context.ResolverContextLevel @functools.singledispatch def eval(expr: pgast.BaseExpr, *, ctx: Context) -> Optional[pgast.BaseExpr]: """ Tries to statically evaluate expr, recursing into sub-expressions. Returns None if that is not possible. """ return None def eval_list( exprs: list[pgast.BaseExpr], *, ctx: Context ) -> Optional[list[pgast.BaseExpr]]: """ Tries to statically evaluate exprs, recursing into sub-expressions. Returns None if that is not possible. Raises DisableNormalization if param refs are encountered. 
""" res = [] for expr in exprs: r = eval(expr, ctx=ctx) if not r: return None res.append(r) return res def name_in_pg_catalog(name: Sequence[str]) -> Optional[str]: """ Strips `pg_catalog.` schema name from an SQL ident that resides in pg_catalog. If name is unqualified, it is deemed to be in pg_catalog, because pg_catalog is always the first schema in search_path. """ if len(name) == 1 or name[0] == 'pg_catalog': return name[-1] return None @eval.register def eval_BaseConstant( expr: pgast.BaseConstant, *, ctx: Context ) -> Optional[pgast.BaseExpr]: return expr @eval.register def eval_TypeCast( expr: pgast.TypeCast, *, ctx: Context ) -> Optional[pgast.BaseExpr]: if expr.type_name.array_bounds: return None pg_catalog_name = name_in_pg_catalog(expr.type_name.name) if pg_catalog_name == 'regclass': return cast_to_regclass(expr.arg, ctx) arg = eval(expr.arg, ctx=ctx) if not arg: return None if isinstance(arg, pgast.StringConstant): type_name = name_in_pg_catalog(expr.type_name.name) if type_name == 'text': return arg if type_name == 'bool': string = arg.val.lower() if 'true'.startswith(string) or 'yes'.startswith(string): return pgast.BooleanConstant(val=True) if 'false'.startswith(string) or 'no'.startswith(string): return pgast.BooleanConstant(val=False) raise errors.QueryError('invalid cast', span=expr.arg.span) return None # Functions that inquire about privileges of users or schemas. # Dict from function name to the number of trailing arguments that are passed # through. PRIVILEGE_INQUIRY_FUNCTIONS_ARGS = { 'has_any_column_privilege': 2, 'has_column_privilege': 3, 'has_database_privilege': 2, 'has_foreign_data_wrapper_privilege': 2, 'has_function_privilege': 2, 'has_language_privilege': 2, 'has_parameter_privilege': 2, 'has_schema_privilege': 2, 'has_sequence_privilege': 2, 'has_server_privilege': 2, 'has_table_privilege': 2, 'has_tablespace_privilege': 2, 'has_type_privilege': 2, 'pg_has_role': 2, } # Allowed functions from pg_catalog that start with `pg_`.
# All such functions are forbidden by default. # To see the list of forbidden functions, use `edb ls-forbidden-functions`. ALLOWED_ADMIN_FUNCTIONS = frozenset( { 'pg_is_in_recovery', 'pg_is_wal_replay_paused', 'pg_get_wal_replay_pause_state', 'pg_column_size', 'pg_column_compression', 'pg_database_size', 'pg_indexes_size', 'pg_relation_size', 'pg_size_bytes', 'pg_size_pretty', 'pg_table_size', 'pg_tablespace_size', 'pg_total_relation_size', 'pg_relation_filenode', 'pg_relation_filepath', 'pg_filenode_relation', 'pg_char_to_encoding', 'pg_column_is_updatable', 'pg_conf_load_time', 'pg_current_xact_id', 'pg_current_xact_id_if_assigned', 'pg_describe_object', 'pg_encoding_max_length', 'pg_encoding_to_char', 'pg_get_constraintdef', 'pg_get_expr', 'pg_get_function_arg_default', 'pg_get_function_arguments', 'pg_get_function_identity_arguments', 'pg_get_function_result', 'pg_get_functiondef', 'pg_get_indexdef', 'pg_get_keywords', 'pg_get_multixact_members', 'pg_get_object_address', 'pg_get_partition_constraintdef', 'pg_get_partkeydef', 'pg_get_publication_tables', 'pg_get_replica_identity_index', 'pg_get_replication_slots', 'pg_get_ruledef', 'pg_get_serial_sequence', 'pg_get_shmem_allocations', 'pg_get_statisticsobjdef', 'pg_get_triggerdef', 'pg_get_userbyid', 'pg_get_viewdef', 'pg_options_to_table', 'pg_has_role', 'pg_function_is_visible', 'pg_opclass_is_visible', 'pg_operator_is_visible', 'pg_opfamily_is_visible', 'pg_statistics_obj_is_visible', 'pg_table_is_visible', 'pg_ts_config_is_visible', 'pg_ts_dict_is_visible', 'pg_ts_parser_is_visible', 'pg_ts_template_is_visible', 'pg_type_is_visible', 'pg_index_column_has_property', 'pg_index_has_property', 'pg_is_in_backup', 'pg_is_other_temp_schema', 'pg_jit_available', 'pg_relation_is_updatable', 'pg_sequence_last_value', 'pg_sequence_parameters', 'pg_timezone_abbrevs', 'pg_timezone_names', 'pg_typeof', 'pg_visible_in_snapshot', 'pg_xact_commit_timestamp', 'pg_xact_status', 'pg_partition_ancestors', 
'pg_backend_pid', 'pg_wal_lsn_diff', 'pg_last_wal_replay_lsn', 'pg_current_wal_flush_lsn', 'pg_relation_is_publishable', 'pg_show_all_settings', } ) WRAPPED_FUNCTIONS = frozenset( { "to_regclass", 'pg_show_all_settings', } ) @eval.register def eval_FuncCall( expr: pgast.FuncCall, *, ctx: Context, ) -> Optional[pgast.BaseExpr]: if len(expr.name) >= 3: raise errors.QueryError("unknown function", span=expr.span) fn_name = name_in_pg_catalog(expr.name) if not fn_name: return None if fn_name.startswith('pg_') and fn_name not in ALLOWED_ADMIN_FUNCTIONS: raise errors.QueryError( f"forbidden function '{fn_name}'", span=expr.span, pgext_code=pgerror.ERROR_INSUFFICIENT_PRIVILEGE, ) if fn_name == 'current_schemas': return eval_current_schemas(expr, ctx=ctx) if fn_name == 'current_database': return pgast.StringConstant(val=ctx.options.current_database) if fn_name == 'current_query': return pgast.StringConstant(val=ctx.options.current_query) if fn_name == 'version': from edb import buildmeta edgedb_version = buildmeta.get_version_line() return pgast.StringConstant( val=" ".join( [ "PostgreSQL", str(defines.PGEXT_POSTGRES_VERSION), f"(Gel {edgedb_version}),", platform.architecture()[0], ] ), ) if fn_name == "set_config": # HACK: pg_dump # - set_config('search_path', '', false) # - set_config(name, 'view, foreign-table', false) # HACK: pgadmin # - set_config('bytea_output','hex',false) # HACK: asyncpg # - set_config('jit', ...) 
ctx.env.capabilities |= enums.Capability.SQL_SESSION_CONFIG if args := eval_list(expr.args, ctx=ctx): name, value, is_local = args if isinstance(name, pgast.StringConstant): if ( isinstance(value, pgast.StringConstant) and isinstance(is_local, pgast.BooleanConstant) ): if ( name.val == "search_path" and value.val == "" and not is_local.val ): return value if ( name.val == "bytea_output" and value.val == "hex" and not is_local.val ): return value if name.val == "jit": return value elif args := eval_list(expr.args[1:], ctx=ctx): value, is_local = args if ( isinstance(value, pgast.StringConstant) and isinstance(is_local, pgast.BooleanConstant) ): if ( value.val == "view, foreign-table" and not is_local.val ): return value raise errors.QueryError( "function set_config is not supported", span=expr.span, pgext_code=pgerror.ERROR_FEATURE_NOT_SUPPORTED, ) if fn_name == 'current_setting': arg = require_string_param(expr, ctx) val = None if arg == 'search_path': val = ', '.join(ctx.options.search_path) if val: return pgast.StringConstant(val=val) return expr if fn_name == "pg_filenode_relation": raise errors.QueryError( f"function pg_catalog.{fn_name} is not supported", span=expr.span, pgext_code=pgerror.ERROR_FEATURE_NOT_SUPPORTED, ) if fn_name == "pg_get_serial_sequence": eval_list(expr.args, ctx=ctx) # we do not expose sequences, so any call to this function returns NULL return pgast.NullConstant() if fn_name in WRAPPED_FUNCTIONS: args = eval_list(expr.args, ctx=ctx) return pgast.FuncCall( name=(V('edgedbsql'), fn_name), args=expr.args if args is None else args, ) cast_arg_to_regclass = { 'pg_relation_filenode', 'pg_relation_filepath', 'pg_relation_size', } if fn_name in cast_arg_to_regclass: regclass_oid = cast_to_regclass(expr.args[0], ctx=ctx) return pgast.FuncCall( name=('pg_catalog', fn_name), args=[regclass_oid] ) if num_allowed_args := PRIVILEGE_INQUIRY_FUNCTIONS_ARGS.get(fn_name, None): # For privilege inquiry functions, we strip the leading user (role), # so the 
inquiry refers to current user's privileges. # This is needed because the exposed username is not necessarily the # same as the user we use to connect to Postgres instance. # We do not allow creating additional users, so this should not be a # problem anyway. # See: https://www.postgresql.org/docs/15/functions-info.html # TODO: deny INSERT, UPDATE and all other unsupported functions fn_args = expr.args[-num_allowed_args:] fn_args = dispatch.resolve_list(fn_args, ctx=ctx) # schema and table names need to be remapped. This is accomplished # with wrapper functions defined in metaschema.py. has_wrapper = { 'has_database_privilege', 'has_schema_privilege', 'has_table_privilege', 'has_column_privilege', 'has_any_column_privilege', } if fn_name in has_wrapper: return pgast.FuncCall(name=(V('edgedbsql'), fn_name), args=fn_args) return pgast.FuncCall(name=('pg_catalog', fn_name), args=fn_args) if fn_name == 'pg_table_is_visible': arg_0 = dispatch.resolve(expr.args[0], ctx=ctx) # our *_is_visible functions need search_path, passed in as an array arg_1 = pgast.ArrayExpr( elements=[ pgast.StringConstant(val=v) for v in ctx.options.search_path ] ) return pgast.FuncCall( name=(V('edgedbsql'), fn_name), args=[arg_0, arg_1] ) return None def require_string_param( expr: pgast.FuncCall, ctx: Context ) -> str: args = eval_list(expr.args, ctx=ctx) arg = args[0] if args and len(args) == 1 else None if not isinstance(arg, pgast.StringConstant): raise errors.QueryError( f"function pg_catalog.{expr.name[-1]} requires a string literal", span=expr.span, pgext_code=pgerror.ERROR_UNDEFINED_FUNCTION ) return arg.val def require_bool_param( expr: pgast.FuncCall, ctx: Context ) -> bool: args = eval_list(expr.args, ctx=ctx) arg = args[0] if args and len(args) == 1 else None if not isinstance(arg, pgast.BooleanConstant): raise errors.QueryError( f"function pg_catalog.{expr.name[-1]} requires a boolean literal", span=expr.span, pgext_code=pgerror.ERROR_UNDEFINED_FUNCTION ) return arg.val def 
cast_to_regclass(param: pgast.BaseExpr, ctx: Context) -> pgast.BaseExpr: """ Equivalent to `::regclass` in SQL. Converts a string constant or an oid to a "registered class" (fully-qualified name of the table/index/sequence). In practice, the type of the resulting expression is oid. """ expr = eval(param, ctx=ctx) if expr is None: param = dispatch.resolve(param, ctx=ctx) return pgast.FuncCall( name=(V('edgedbsql'), "to_regclass"), args=[param] ) elif isinstance(expr, pgast.NullConstant): return pgast.NullConstant() elif isinstance(expr, pgast.StringConstant): return pgast.FuncCall( name=(V('edgedbsql'), "to_regclass"), args=[expr], ) elif isinstance(expr, pgast.NumericConstant): return pgast.TypeCast( arg=expr, type_name=pgast.TypeName(name=('pg_catalog', 'regclass')), ) else: return pgast.FuncCall( name=(V('edgedbsql'), "to_regclass"), args=[expr] ) def eval_current_schemas( expr: pgast.FuncCall, ctx: Context ) -> Optional[pgast.BaseExpr]: include_implicit = require_bool_param(expr, ctx) res = [] if include_implicit: # if any temporary object has been created in current session, # here we should also append res.append('pg_temp_xxx') where xxx is # a number assigned by the server. res.append('pg_catalog') res.extend(ctx.options.search_path) return pgast.ArrayExpr(elements=[pgast.StringConstant(val=r) for r in res]) VALUE_FUNC_PASS_THROUGH = frozenset({ val_func_op.CURRENT_DATE, val_func_op.CURRENT_TIME, val_func_op.CURRENT_TIME_N, val_func_op.CURRENT_TIMESTAMP, val_func_op.CURRENT_TIMESTAMP_N, val_func_op.LOCALTIME, val_func_op.LOCALTIME_N, val_func_op.LOCALTIMESTAMP, val_func_op.LOCALTIMESTAMP_N, }) VALUE_FUNC_USER = frozenset({ val_func_op.CURRENT_ROLE, val_func_op.CURRENT_USER, val_func_op.USER, val_func_op.SESSION_USER, }) def eval_current_user( *, ctx: Context, ) -> pgast.BaseExpr: from edb.edgeql import ast as qlast from edb.edgeql import compiler as qlcompiler from edb.pgsql import compiler as pgcompiler from . 
import command ql_stmt = qlast.SelectQuery( result=qlast.GlobalExpr( name=qlast.ObjectRef(module='sys', name='current_role'), ) ) ir_stmt = qlcompiler.compile_ast_to_ir( ql_stmt, ctx.schema, options=qlcompiler.CompilerOptions(), ) sql_tree = pgcompiler.compile_ir_to_sql_tree( ir_stmt, output_format=pgcompiler.OutputFormat.NATIVE_INTERNAL, alias_generator=ctx.alias_generator, ) command.merge_params(sql_tree, ir_stmt, ctx) assert isinstance(sql_tree.ast, pgast.BaseExpr) return sql_tree.ast @eval.register def eval_SQLValueFunction( expr: pgast.SQLValueFunction, *, ctx: Context, ) -> pgast.BaseExpr: if expr.op in VALUE_FUNC_PASS_THROUGH: return expr if expr.op in VALUE_FUNC_USER: return eval_current_user(ctx=ctx) if expr.op == val_func_op.CURRENT_CATALOG: return pgast.StringConstant(val=ctx.options.current_database) if expr.op == val_func_op.CURRENT_SCHEMA: # note: PG also does a check that this schema exists and proceeds to # the next one in the search path return pgast.StringConstant(val=ctx.options.search_path[0]) # this should never happen raise NotImplementedError() @eval.register def eval_ParamRef( _expr: pgast.ParamRef, *, ctx: Context, ) -> Optional[pgast.BaseExpr]: if len(ctx.options.normalized_params) > 0: raise DisableNormalization() else: return None ================================================ FILE: edb/pgsql/schemamech.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Optional, Sequence, Collection import itertools import dataclasses from edb import errors from edb.ir import ast as irast from edb.ir import astexpr as irastexpr from edb.ir import typeutils as irtyputils from edb.ir import utils as ir_utils from edb.edgeql import compiler as qlcompiler from edb.edgeql import ast as qlast from edb.edgeql import parser as ql_parser from edb.edgeql.compiler import astutils as ql_astutils from edb.schema import name as s_name from edb.schema import pointers as s_pointers from edb.schema import scalars as s_scalars from edb.schema import utils as s_utils from edb.schema import types as s_types from edb.schema import constraints as s_constraints from edb.schema import schema as s_schema from edb.schema import sources as s_sources from edb.schema import expr as s_expr from edb.schema import objects as s_obj from edb.common import ast from edb.common import parsing from . import ast as pgast from . import dbops from . import deltadbops from . import common from . import types from . import compiler from . import codegen from .common import qname as qn def _get_exclusive_refs(tree: irast.Statement) -> Sequence[irast.Base] | None: # Check if the expression is # std::_is_exclusive() [and std::_is_exclusive()...] 
expr = tree.expr.expr return irastexpr.get_constraint_references(expr) @dataclasses.dataclass(kw_only=True, repr=False, eq=False, slots=True) class PGConstrData: subject_db_name: Optional[tuple[str, str]] expressions: list[ExprData] relative_expressions: list[ExprData] table_type: str except_data: Optional[ExprDataSources] scope: Optional[str] = None type: Optional[str] = None @dataclasses.dataclass(kw_only=True, repr=False, eq=False, slots=True) class ExprData: exprdata: ExprDataSources is_multicol: bool is_trivial: bool subject_db_name: Optional[tuple[str, str]] = None except_data: Optional[ExprDataSources] = None @dataclasses.dataclass(kw_only=True, repr=False, eq=False, slots=True) class ExprDataSources: plain: str new: str old: str plain_chunks: Sequence[str] def _to_source(sql_expr: pgast.Base) -> str: src = codegen.generate_source(sql_expr) # ColumnRefs are the most common thing, and they should be safe to # skip parenthesizing, for deuglification purposes. anything else # we put parens around, to be sure. 
if not isinstance(sql_expr, pgast.ColumnRef): src = f'({src})' return src def _edgeql_tree_to_expr_data( sql_expr: pgast.Base, refs: Optional[set[pgast.ColumnRef]] = None ) -> ExprDataSources: if refs is None: refs = set( ast.find_children( sql_expr, pgast.ColumnRef, lambda n: len(n.name) == 1 ) ) plain_expr = _to_source(sql_expr) if isinstance(sql_expr, (pgast.RowExpr, pgast.ImplicitRowExpr)): chunks = [] for elem in sql_expr.args: chunks.append(_to_source(elem)) else: chunks = [plain_expr] if isinstance(sql_expr, pgast.ColumnRef): refs.add(sql_expr) for ref in refs: assert isinstance(ref.name, list) ref.name.insert(0, 'NEW') new_expr = _to_source(sql_expr) for ref in refs: assert isinstance(ref.name, list) ref.name[0] = 'OLD' old_expr = _to_source(sql_expr) return ExprDataSources( plain=plain_expr, new=new_expr, old=old_expr, plain_chunks=chunks ) def _edgeql_ref_to_pg_constr( subject: s_constraints.ConsistencySubject, origin_subject: s_types.Type | s_pointers.Pointer | None, tree: irast.Base, ) -> ExprData: sql_res = compiler.compile_ir_to_sql_tree(tree, singleton_mode=True) sql_expr: pgast.Base if isinstance(sql_res.ast, pgast.SelectStmt): # XXX: use ast pattern matcher for this from_clause = sql_res.ast.from_clause[0] assert isinstance(from_clause, pgast.RelRangeVar) assert isinstance(from_clause.relation, pgast.CommonTableExpr) sql_expr = from_clause.relation.query.target_list[0].val else: sql_expr = sql_res.ast if isinstance(tree, irast.Statement): tree = tree.expr if isinstance(tree, irast.Set) and isinstance(tree.expr, irast.SelectStmt): tree = tree.expr.result is_multicol = isinstance(sql_expr, (pgast.RowExpr, pgast.ImplicitRowExpr)) # Determine if the sequence of references are all simple refs, not # expressions. This influences the type of Postgres constraint used. 
# is_trivial = isinstance(sql_expr, pgast.ColumnRef) or ( isinstance(sql_expr, (pgast.RowExpr, pgast.ImplicitRowExpr)) and all(isinstance(el, pgast.ColumnRef) for el in sql_expr.args) ) # Find all field references # refs = set( ast.find_children(sql_expr, pgast.ColumnRef, lambda n: len(n.name) == 1) ) if isinstance(subject, s_scalars.ScalarType): # Domain constraint, replace with VALUE assert origin_subject subj_pgname = common.edgedb_name_to_pg_name(str(subject.id)) orgsubj_pgname = common.edgedb_name_to_pg_name(str(origin_subject.id)) for ref in refs: if ref.name != [subj_pgname] and ref.name != [orgsubj_pgname]: raise ValueError( f'unexpected node reference in ' f'ScalarType constraint: {qn(*ref.name)}' ) # work around the immutability check object.__setattr__(ref, 'name', ['VALUE']) exprdata = _edgeql_tree_to_expr_data(sql_expr, refs=refs) # Scalar constraints shouldn't ever fail on NULL if isinstance(subject, s_scalars.ScalarType): exprdata.plain = f"VALUE IS NULL OR ({exprdata.plain})" return ExprData( exprdata=exprdata, is_multicol=is_multicol, is_trivial=is_trivial ) @dataclasses.dataclass(frozen=True) class CompiledConstraintData: subject: s_types.Type | s_pointers.Pointer exclusive_expr_refs: Optional[Sequence[irast.Base]] subject_db_name: Optional[tuple[str, str]] except_data: Optional[ExprDataSources] ir: irast.Statement subject_table_type: str def _compile_constraint_data( constraint: s_constraints.Constraint, schema: s_schema.Schema, is_optional: bool, *, span: Optional[parsing.Span] = None, type_remaps: Optional[dict[s_obj.Object, s_obj.Object]] = None, ) -> CompiledConstraintData: sub = constraint.get_subject(schema) assert isinstance( sub, (s_types.Type, s_pointers.Pointer, s_scalars.ScalarType) ) subject: s_types.Type | s_pointers.Pointer = sub path_prefix_anchor = '__subject__' singletons = frozenset({(subject, is_optional)}) options = qlcompiler.CompilerOptions( anchors={'__subject__': subject}, path_prefix_anchor=path_prefix_anchor, 
apply_query_rewrites=False, singletons=singletons, schema_object_context=type(constraint), type_remaps=type_remaps if type_remaps is not None else {}, ) final_expr: Optional[s_expr.Expression] = constraint.get_finalexpr(schema) assert final_expr is not None and final_expr.parse() is not None ir = qlcompiler.compile_ast_to_ir( final_expr.parse(), schema, options=options, ) assert isinstance(ir, irast.Statement) except_ir: Optional[irast.Statement] = None except_data = None if except_expr := constraint.get_except_expr(schema): assert isinstance(except_expr, s_expr.Expression) except_ir = qlcompiler.compile_ast_to_ir( except_expr.parse(), schema, options=options, ) except_sql = compiler.compile_ir_to_sql_tree( except_ir, singleton_mode=True ) except_data = _edgeql_tree_to_expr_data(except_sql.ast) terminal_refs: set[irast.Set] = ( ir_utils.get_longest_paths(ir.expr.expr) ) if except_ir is not None: terminal_refs.update( ir_utils.get_longest_paths(except_ir.expr) ) ref_tables = get_ref_storage_info(ir.schema, terminal_refs) if len(ref_tables) > 1: raise errors.InvalidConstraintDefinitionError( f'Constraint {constraint.get_displayname(schema)} on ' f'{subject.get_displayname(schema)} is not supported ' f'because it would depend on multiple objects', span=span, ) elif ref_tables: subject_db_name, info = next(iter(ref_tables.items())) subject_table_type = info[0][3].table_type else: # the expression does not have any refs: default to the subject table subject_table: Optional[s_obj.InheritingObject] | s_types.Type if isinstance(subject, s_pointers.Pointer): subject_table = subject.get_source(schema) else: subject_table = subject assert subject_table subject_db_name = common.get_backend_name( schema, subject_table, catenate=False, ) subject_table_type = 'ObjectType' exclusive_expr_refs = _get_exclusive_refs(ir) return CompiledConstraintData( subject, exclusive_expr_refs, subject_db_name, except_data, ir, subject_table_type, ) def _get_compiled_constraint_expr_data( 
primary_subject: s_constraints.ConsistencySubject, constraint_data: CompiledConstraintData, ) -> list[ExprData]: exprdatas: list[ExprData] = [] constraint_subject = ( constraint_data.subject if constraint_data.subject != primary_subject else None ) assert constraint_data.exclusive_expr_refs is not None for ref in constraint_data.exclusive_expr_refs: exprdata = _edgeql_ref_to_pg_constr( primary_subject, constraint_subject, ref ) exprdata.subject_db_name = constraint_data.subject_db_name exprdata.except_data = constraint_data.except_data exprdatas.append(exprdata) return exprdatas def table_constraint_requires_triggers( constraint: s_constraints.Constraint, schema: s_schema.Schema, constraint_type: str, ): subject = constraint.get_subject(schema) cname = constraint.get_shortname(schema) if ( isinstance(subject, s_pointers.Pointer) and subject.is_id_pointer(schema) and cname == s_name.QualName('std', 'exclusive') ): return False else: return constraint_type != 'check' def compile_constraint( subject: s_constraints.ConsistencySubject, constraint: s_constraints.Constraint, schema: s_schema.Schema, span: Optional[parsing.Span], ) -> SchemaDomainConstraint | SchemaTableConstraint: assert constraint.get_subject(schema) is not None assert isinstance( subject, (s_types.Type, s_pointers.Pointer, s_scalars.ScalarType) ) constraint_origins = constraint.get_constraint_origins(schema) first_subject = constraint_origins[0].get_subject(schema) is_optional = isinstance( first_subject, s_pointers.Pointer ) and not first_subject.get_required(schema) constraint_data = _compile_constraint_data( constraint, schema, is_optional, span=span, # Remap the constraint origin to the subject, so that if # we have B <: A, and the constraint references A.foo, it # gets rewritten in the subtype to B.foo. It's OK to only # look at one constraint origin, because if there were # multiple different origins, they couldn't get away with # referring to the type explicitly. 
type_remaps={first_subject: subject}, ) pg_constr_data = PGConstrData( subject_db_name=constraint_data.subject_db_name, expressions=[], relative_expressions=[], table_type=constraint_data.subject_table_type, except_data=constraint_data.except_data, ) if constraint_data.exclusive_expr_refs: origin_expr_datas: dict[ s_constraints.Constraint, list[ExprData] ] = {} for origin in constraint_origins: if origin == constraint: origin_data = constraint_data else: origin_data = _compile_constraint_data( origin, schema, is_optional, ) origin_expr_datas[origin] = _get_compiled_constraint_expr_data( subject, origin_data ) # Set constraint expressions expressions: list[ExprData] if constraint in origin_expr_datas: expressions = origin_expr_datas[constraint] else: expressions = _get_compiled_constraint_expr_data( subject, constraint_data ) pg_constr_data.expressions.extend(expressions) # Set relative expressions # These are only needed for constraint triggers. if ( not isinstance(constraint.get_subject(schema), s_scalars.ScalarType) and table_constraint_requires_triggers( constraint, schema, 'unique' ) ): relatives = list(set( descendant for origin in constraint_origins for descendant in itertools.chain( [origin], origin.descendants(schema) ) )) relative_expressions: list[ExprData] = [] for relative in relatives: if relative == constraint: relative_expressions.extend(expressions) elif relative in origin_expr_datas: relative_expressions.extend(origin_expr_datas[relative]) else: relative_data = _compile_constraint_data( relative, schema, is_optional, ) relative_expressions.extend( _get_compiled_constraint_expr_data( subject, relative_data ) ) pg_constr_data.relative_expressions.extend(relative_expressions) pg_constr_data.scope = 'relation' pg_constr_data.type = 'unique' else: assert len(constraint_origins) == 1 origin_data = ( _compile_constraint_data( constraint_origins[0], schema, is_optional, ) if constraint_origins[0] != constraint else constraint_data ) exprdata = 
_edgeql_ref_to_pg_constr( subject, origin_data.subject, constraint_data.ir ) exprdata.subject_db_name = origin_data.subject_db_name exprdata.except_data = origin_data.except_data pg_constr_data.expressions.append(exprdata) pg_constr_data.scope = 'row' pg_constr_data.type = 'check' if isinstance(constraint.get_subject(schema), s_scalars.ScalarType): return SchemaDomainConstraint( subject=subject, constraint=constraint, pg_constr_data=pg_constr_data, schema=schema, ) else: return SchemaTableConstraint( subject=subject, constraint=constraint, pg_constr_data=pg_constr_data, schema=schema, ) @dataclasses.dataclass(kw_only=True, repr=False, eq=False, slots=True) class SchemaDomainConstraint: subject: s_constraints.ConsistencySubject constraint: s_constraints.Constraint pg_constr_data: PGConstrData schema: s_schema.Schema def _domain_constraint(self, constr: SchemaConstraint): domain_name = constr.pg_constr_data.subject_db_name expressions = constr.pg_constr_data.expressions return deltadbops.SchemaConstraintDomainConstraint( domain_name, constr.constraint, expressions, schema=self.schema ) def create_ops(self): ops = dbops.CommandGroup() domconstr = self._domain_constraint(self) add_constr = dbops.AlterDomainAddConstraint( name=domconstr.get_subject_name(quote=False), constraint=domconstr) ops.add_command(add_constr) return ops def alter_ops( self, orig_constr: SchemaConstraint ): ops = dbops.CommandGroup() return ops def delete_ops(self): ops = dbops.CommandGroup() domconstr = self._domain_constraint(self) add_constr = dbops.AlterDomainDropConstraint( name=domconstr.get_subject_name(quote=False), constraint=domconstr) ops.add_command(add_constr) return ops def enforce_ops(self): ops = dbops.CommandGroup() return ops def update_trigger_ops(self) -> dbops.CommandGroup: ops = dbops.CommandGroup() return ops def fixup_trigger_ops(self) -> dbops.CommandGroup: ops = dbops.CommandGroup() return ops @dataclasses.dataclass(kw_only=True, repr=False, eq=False, slots=True) class 
SchemaTableConstraint: subject: s_constraints.ConsistencySubject constraint: s_constraints.Constraint pg_constr_data: PGConstrData schema: s_schema.Schema def _table_constraint( self, constr: SchemaConstraint ) -> deltadbops.SchemaConstraintTableConstraint: pg_c = constr.pg_constr_data table_name = pg_c.subject_db_name expressions = pg_c.expressions relative_expressions = pg_c.relative_expressions assert table_name return deltadbops.SchemaConstraintTableConstraint( table_name, constraint=constr.constraint, exprdata=expressions, relative_exprdata=relative_expressions, except_data=pg_c.except_data, scope=pg_c.scope, type=pg_c.type, table_type=pg_c.table_type, schema=constr.schema, ) def create_ops(self): ops = dbops.CommandGroup() tabconstr = self._table_constraint(self) add_constr = deltadbops.AlterTableAddConstraint( name=tabconstr.get_subject_name(quote=False), constraint=tabconstr, ) ops.add_command(add_constr) return ops def alter_ops( self, orig_constr: SchemaConstraint ): ops = dbops.CommandGroup() tabconstr = self._table_constraint(self) orig_tabconstr = self._table_constraint(orig_constr) alter_constr = deltadbops.AlterTableAlterConstraint( name=tabconstr.get_subject_name(quote=False), constraint=orig_tabconstr, new_constraint=tabconstr, ) ops.add_command(alter_constr) return ops def delete_ops(self): ops = dbops.CommandGroup() tabconstr = self._table_constraint(self) add_constr = deltadbops.AlterTableDropConstraint( name=tabconstr.get_subject_name(quote=False), constraint=tabconstr, ) ops.add_command(add_constr) return ops def enforce_ops(self) -> dbops.CommandGroup: ops = dbops.CommandGroup() tabconstr = self._table_constraint(self) constr_name = tabconstr.constraint_name() raw_constr_name = tabconstr.constraint_name(quote=False) for expr, relative_expr in zip( itertools.cycle(tabconstr._exprdata), tabconstr._relative_exprdata ): exprdata = expr.exprdata relative_exprdata = relative_expr.exprdata old_expr = relative_exprdata.old new_expr = exprdata.new 
assert relative_expr.subject_db_name schemaname, tablename = relative_expr.subject_db_name real_tablename = tabconstr.get_subject_name(quote=False) errmsg = 'duplicate key value violates unique ' \ 'constraint {constr}'.format(constr=constr_name) detail = common.quote_literal( f"Key ({relative_exprdata.plain}) already exists." ) if ( isinstance(self.subject, s_pointers.Pointer) and self.pg_constr_data.table_type == 'link' ): key = "source" else: key = "id" except_data = tabconstr._except_data relative_except_data = relative_expr.except_data if except_data: assert relative_except_data except_part = f''' AND ({relative_except_data.old} is not true) AND ({except_data.new} is not true) ''' else: except_part = '' check = dbops.Query( f''' SELECT edgedb_VER.raise( NULL::text, 'unique_violation', msg => '{errmsg}', "constraint" => '{raw_constr_name}', "table" => '{tablename}', "schema" => '{schemaname}', detail => {detail} ) FROM {common.qname(schemaname, tablename)} AS OLD CROSS JOIN {common.qname(*real_tablename)} AS NEW WHERE {old_expr} = {new_expr} and OLD.{key} != NEW.{key} {except_part} INTO _dummy_text; ''' ) ops.add_command(check) return ops def update_trigger_ops(self) -> dbops.CommandGroup: ops = dbops.CommandGroup() tabconstr = self._table_constraint(self) add_constr = deltadbops.AlterTableUpdateConstraintTrigger( name=tabconstr.get_subject_name(quote=False), constraint=tabconstr, ) ops.add_command(add_constr) return ops def fixup_trigger_ops(self) -> dbops.CommandGroup: # Pre 6.8 versions of gel created needless disabled triggers # in some cases. This path (invoked by administer # remove_pointless_triggers()) deletes them. 
ops = dbops.CommandGroup() tabconstr = self._table_constraint(self) add_constr = deltadbops.AlterTableUpdateConstraintTriggerFixup( name=tabconstr.get_subject_name(quote=False), constraint=tabconstr, ) ops.add_command(add_constr) return ops SchemaConstraint = SchemaDomainConstraint | SchemaTableConstraint def ptr_default_to_col_default(schema, ptr, expr): try: # NOTE: This code currently will only be invoked for scalars. # Blindly cast the default expression into the ptr target # type, validation of the expression type is not the concern # of this function. eql = ql_parser.parse_query(expr.text) eql = ql_astutils.ensure_ql_query( qlast.TypeCast( type=s_utils.typeref_to_ast( schema, ptr.get_target(schema)), expr=eql, ) ) ir = qlcompiler.compile_ast_to_ir(eql, schema) except (errors.SchemaError, errors.QueryError): # Reference errors mean that it is a non-constant default # referring to not-yet-existing objects. return None if not ir_utils.is_const(ir): return None if ast.find_children(ir, irast.TupleIndirectionPointer): return None try: sql_res = compiler.compile_ir_to_sql_tree(ir, singleton_mode=True) except errors.UnsupportedFeatureError: return None sql_text = _to_source(sql_res.ast) return sql_text RefTables = dict[ Optional[tuple[str, str]], list[ tuple[ irast.Set, s_pointers.PointerLike, s_pointers.PointerLike | s_types.Type, types.PointerStorageInfo, ] ], ] def get_ref_storage_info( schema: s_schema.Schema, refs: Collection[irast.Set] ) -> RefTables: link_biased: dict[irast.Set, types.PointerStorageInfo] = {} objtype_biased: dict[irast.Set, types.PointerStorageInfo] = {} RefPtr = tuple[ s_pointers.PointerLike, s_types.Type | s_pointers.PointerLike ] ref_ptrs: dict[irast.Set, RefPtr] = {} refs = list(refs) for ref in refs: ptr: s_pointers.PointerLike src: s_types.Type | s_pointers.PointerLike rptr = ref.expr if isinstance(ref.expr, irast.Pointer) else None if rptr is None: source_typeref = ref.typeref if not irtyputils.is_object(source_typeref): continue if 
irtyputils.is_free_object(source_typeref): continue schema, t = irtyputils.ir_typeref_to_type(schema, ref.typeref) assert isinstance(t, s_sources.Source) ptr = t.getptr(schema, s_name.UnqualName('id')) else: ptrref = rptr.ptrref schema, ptr = irtyputils.ptrcls_from_ptrref(ptrref, schema=schema) source_typeref = rptr.source.typeref if ptr.is_link_property(schema): assert rptr and rptr.source assert isinstance(rptr.source.expr, irast.Pointer) srcref = rptr.source.expr.ptrref schema, src = irtyputils.ptrcls_from_ptrref( srcref, schema=schema) if src.get_is_derived(schema): # This specialized pointer was derived specifically # for the purposes of constraint expr compilation. src = src.get_bases(schema).first(schema) elif ptr.is_tuple_indirection(): assert rptr refs.append(rptr.source) # noqa: B909 continue elif ptr.is_type_intersection(): assert rptr refs.append(rptr.source) # noqa: B909 continue else: schema, src = irtyputils.ir_typeref_to_type(schema, source_typeref) ref_ptrs[ref] = (ptr, src) for ref, (ptr, src) in ref_ptrs.items(): ptr_info = types.get_pointer_storage_info( ptr, source=src, resolve_type=False, schema=schema) # type: ignore # See if any of the refs are hosted in pointer tables and others # are not... 
if ptr_info.table_type == 'link': link_biased[ref] = ptr_info else: objtype_biased[ref] = ptr_info if link_biased and objtype_biased: break if link_biased and objtype_biased: for ref in objtype_biased.copy(): ptr, src = ref_ptrs[ref] ptr_info = types.get_pointer_storage_info( ptr, source=src, # type: ignore resolve_type=False, link_bias=True, schema=schema, ) if ptr_info is not None and ptr_info.table_type == 'link': link_biased[ref] = ptr_info objtype_biased.pop(ref) ref_tables: RefTables = {} for ref, ptr_info in itertools.chain( objtype_biased.items(), link_biased.items()): ptr, src = ref_ptrs[ref] try: ref_tables[ptr_info.table_name].append((ref, ptr, src, ptr_info)) except KeyError: ref_tables[ptr_info.table_name] = [(ref, ptr, src, ptr_info)] return ref_tables ================================================ FILE: edb/pgsql/trampoline.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Support for namespacing and trampolining the standard library. The idea here is that all of the functions, tables, and views in edgedb/edgedbstd/edgedbsql should be moved into namespaced libraries of the form `edgedb_VER`, where VER will be substituted with some version identifier. 
Then, for anything that might be referenced by a function, constraint, etc, we will create a *trampoline* in the un-suffixed namespace. When doing (eventually) in-place version upgrades, we will create the new namespace and then update the trampolines to point to it. CURRENT STATUS: So far, functions and views are mostly namespaced. Standard library schema object tables aren't yet. """ from __future__ import annotations from typing import ( TYPE_CHECKING, Optional, Sequence, ) import abc import copy import dataclasses from . import common from . import dbops q = common.qname qi = common.quote_ident ql = common.quote_literal V = common.versioned_schema def fixup_query(query: str) -> str: for s in common.VERSIONED_SCHEMAS: query = query.replace(f"{s}_VER", V(s)) return query class VersionedFunction(dbops.Function): if TYPE_CHECKING: # What the volatility of the trampoline wrapper should be. # This is sometimes immutable even when the underlying is # stable, for functions that must be immutable so they can go # into indexes/constraints but might do something technically # stable (like raise an error). # # This allows the real function to be inlined while allowing # the wrapper to get used in indexes/constraints. 
wrapper_volatility: Optional[str] def __init__( self, name: tuple[str, ...], *, args: Optional[Sequence[dbops.FunctionArg]] = None, returns: str | tuple[str, ...], text: str, volatility: str = "volatile", language: str = "sql", has_variadic: Optional[bool] = None, strict: bool = False, parallel_safe: bool = False, set_returning: bool = False, wrapper_volatility: Optional[str] = None, ): pass else: def __init__(self, *args, wrapper_volatility=None, **kwargs): super().__init__(*args, **kwargs) self.name = ( common.maybe_versioned_schema(self.name[0]), *self.name[1:]) self.text = fixup_query(self.text) self.wrapper_volatility = wrapper_volatility if self.args: nargs = [] for arg in self.args: if isinstance(arg, tuple) and isinstance(arg[1], tuple): new_name = ( arg[1][0].replace('_VER', V('')), *arg[1][1:]) arg = (arg[0], new_name, *arg[2:]) nargs.append(arg) self.args = nargs class VersionedView(dbops.View): if not TYPE_CHECKING: def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.name = ( common.maybe_versioned_schema(self.name[0]), *self.name[1:]) self.query = fixup_query(self.query) @dataclasses.dataclass class Trampoline: name: tuple[str, str] @abc.abstractmethod def make(self) -> dbops.Command: pass @abc.abstractmethod def drop(self) -> dbops.Command: pass @dataclasses.dataclass class TrampolineFunction(Trampoline): func: dbops.Function def make(self) -> dbops.Command: return dbops.CreateFunction(self.func, or_replace=True) def drop(self) -> dbops.Command: return dbops.DropFunction( self.func.name, args=self.func.args or (), has_variadic=bool(self.func.has_variadic), if_exists=True, ) @dataclasses.dataclass class TrampolineView(Trampoline): old_name: tuple[str, str] def make(self) -> dbops.Command: return dbops.Query(f''' PERFORM {V('edgedb')}._create_trampoline_view( {ql(q(*self.old_name))}, {ql(self.name[0])}, {ql(self.name[1])}) ''') def drop(self) -> dbops.Command: return dbops.DropView( self.name, 
conditions=[dbops.ViewExists(self.name)], ) def make_trampoline(func: dbops.Function) -> TrampolineFunction: new_func = copy.copy(func) schema, name = func.name namespace = V('') assert schema.endswith(namespace), schema new_func.name = (schema[:-len(namespace)], name) args = [] for arg in (func.args or ()): if isinstance(arg, str): args.append(arg) else: assert arg[0] args.append(arg[0]) args = [qi(arg) for arg in args] if func.has_variadic: args[-1] = f'VARIADIC {args[-1]}' new_func.text = f'select {q(*func.name)}({", ".join(args)})' new_func.language = 'sql' new_func.strict = False if isinstance(func, VersionedFunction) and func.wrapper_volatility: new_func.volatility = func.wrapper_volatility return TrampolineFunction(new_func.name, new_func) def make_table_trampoline(fullname: tuple[str, str]) -> TrampolineView: schema, name = fullname namespace = V('') assert schema.endswith(namespace), schema new_name = (schema[:-len(namespace)], name) return TrampolineView(new_name, fullname) def make_view_trampoline(view: dbops.View) -> TrampolineView: return make_table_trampoline(view.name) ================================================ FILE: edb/pgsql/types.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2010-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations import dataclasses import uuid from typing import Literal, Optional, cast, overload from edb.common.typeutils import not_none from edb.common import lru from edb.ir import ast as irast from edb.ir import typeutils as irtyputils from edb.schema import scalars as s_scalars from edb.schema import objtypes as s_objtypes from edb.schema import name as sn from edb.schema import objects as s_obj from edb.schema import schema as s_schema from edb.schema import types as s_types from edb.schema import pointers as s_pointers from edb.schema import properties as s_properties from . import common base_type_name_map = { s_obj.get_known_type_id('std::str'): ('text',), s_obj.get_known_type_id('std::int64'): ('int8',), s_obj.get_known_type_id('std::int32'): ('int4',), s_obj.get_known_type_id('std::int16'): ('int2',), s_obj.get_known_type_id('std::decimal'): ('numeric',), s_obj.get_known_type_id('std::bigint'): ('edgedbt', 'bigint_t'), s_obj.get_known_type_id('std::bool'): ('bool',), s_obj.get_known_type_id('std::float64'): ('float8',), s_obj.get_known_type_id('std::float32'): ('float4',), s_obj.get_known_type_id('std::uuid'): ('uuid',), s_obj.get_known_type_id('std::datetime'): ('edgedbt', 'timestamptz_t'), s_obj.get_known_type_id('std::duration'): ('edgedbt', 'duration_t',), s_obj.get_known_type_id('std::bytes'): ('bytea',), s_obj.get_known_type_id('std::json'): ('jsonb',), s_obj.get_known_type_id('std::cal::local_datetime'): ('edgedbt', 'timestamp_t'), s_obj.get_known_type_id('std::cal::local_date'): ('edgedbt', 'date_t'), s_obj.get_known_type_id('std::cal::local_time'): ('time',), s_obj.get_known_type_id('std::cal::relative_duration'): ('edgedbt', 'relative_duration_t'), s_obj.get_known_type_id('std::cal::date_duration'): ('edgedbt', 'date_duration_t'), s_obj.get_known_type_id('cfg::memory'): ('edgedbt', 'memory_t'), s_obj.get_known_type_id('std::pg::json'): ('json',), s_obj.get_known_type_id('std::pg::timestamptz'): ('timestamptz',), 
s_obj.get_known_type_id('std::pg::timestamp'): ('timestamp',), s_obj.get_known_type_id('std::pg::date'): ('date',), s_obj.get_known_type_id('std::pg::interval'): ('interval',), } type_to_range_name_map = { ('int4',): ('int4range',), ('int8',): ('int8range',), ('numeric',): ('numrange',), ('float4',): ('edgedb', 'float32_range_t'), ('float8',): ('edgedb', 'float64_range_t'), ('edgedbt', 'timestamptz_t'): ('edgedb', 'datetime_range_t'), ('edgedbt', 'timestamp_t'): ('edgedb', 'local_datetime_range_t'), # cal::local_date uses the built-in daterange instead of a custom # one that actually uses edgedbt.date_t as its subtype. This is # because cal::local_date is discrete, and its range type should # get canonicalized. Defining a canonicalization function for a # custom range is a big hassle, and daterange already has the # correct canonicalization function ('edgedbt', 'date_t'): ('daterange',), ('timestamptz',): ('tstzrange',), ('timestamp',): ('tsrange',), ('date',): ('daterange',), } # Construct a multirange map based on type_to_range_name_map by replacing # 'range' with 'multirange' in the names. # # The multiranges are created automatically when ranges are created. They # have the same names except with "multi" in front of the "range". 
type_to_multirange_name_map = {} for key, val in type_to_range_name_map.items(): *pre, name = val pre.append(name.replace('range', 'multirange')) type_to_multirange_name_map[key] = tuple(pre) base_type_name_map_r = { 'character varying': sn.QualName('std', 'str'), 'character': sn.QualName('std', 'str'), 'text': sn.QualName('std', 'str'), 'numeric': sn.QualName('std', 'decimal'), 'edgedbt.bigint_t': sn.QualName('std', 'bigint'), 'bigint_t': sn.QualName('std', 'bigint'), 'int4': sn.QualName('std', 'int32'), 'integer': sn.QualName('std', 'int32'), 'bigint': sn.QualName('std', 'int64'), 'int8': sn.QualName('std', 'int64'), 'int2': sn.QualName('std', 'int16'), 'smallint': sn.QualName('std', 'int16'), 'boolean': sn.QualName('std', 'bool'), 'bool': sn.QualName('std', 'bool'), 'double precision': sn.QualName('std', 'float64'), 'float8': sn.QualName('std', 'float64'), 'real': sn.QualName('std', 'float32'), 'float4': sn.QualName('std', 'float32'), 'uuid': sn.QualName('std', 'uuid'), 'timestamp with time zone': sn.QualName('std', 'datetime'), 'edgedbt.timestamptz_t': sn.QualName('std', 'datetime'), 'timestamptz_t': sn.QualName('std', 'datetime'), 'timestamptz': sn.QualName('std', 'datetime'), 'duration_t': sn.QualName('std', 'duration'), 'edgedbt.duration_t': sn.QualName('std', 'duration'), 'interval': sn.QualName('std', 'duration'), 'bytea': sn.QualName('std', 'bytes'), 'jsonb': sn.QualName('std', 'json'), 'timestamp': sn.QualName('std::cal', 'local_datetime'), 'timestamp_t': sn.QualName('std::cal', 'local_datetime'), 'edgedbt.timestamp_t': sn.QualName('std::cal', 'local_datetime'), 'date': sn.QualName('std::cal', 'local_date'), 'date_t': sn.QualName('std::cal', 'local_date'), 'edgedbt.date_t': sn.QualName('std::cal', 'local_date'), 'time': sn.QualName('std::cal', 'local_time'), 'relative_duration_t': sn.QualName('std::cal', 'relative_duration'), 'edgedbt.relative_duration_t': sn.QualName('std::cal', 'relative_duration'), 'date_duration_t': sn.QualName('std::cal', 
'date_duration'), 'edgedbt.date_duration_t': sn.QualName('std::cal', 'date_duration'), 'edgedbt.memory_t': sn.QualName('cfg', 'memory'), 'memory_t': sn.QualName('cfg', 'memory'), 'json': sn.QualName('std::pg', 'json'), } pg_tsvector_typeref = irast.TypeRef( id=uuid.UUID('44d73839-8882-419f-80e5-84f7a3402919'), name_hint=sn.QualName('pg_catalog', 'tsvector'), is_scalar=True, sql_type='pg_catalog.tsvector', ) pg_oid_typeref = irast.TypeRef( id=uuid.UUID('44d73839-8882-419f-80e5-84f7a3402920'), name_hint=sn.QualName('pg_catalog', 'oid'), is_scalar=True, sql_type='pg_catalog.oid', ) pg_langs = { 'simple', 'arabic', 'armenian', 'basque', 'catalan', 'danish', 'dutch', 'english', 'finnish', 'french', 'german', 'greek', 'hindi', 'hungarian', 'indonesian', 'irish', 'italian', 'lithuanian', 'nepali', 'norwegian', 'portuguese', 'romanian', 'russian', 'serbian', 'spanish', 'swedish', 'tamil', 'turkish', 'yiddish', } pg_langs_by_iso_639_3 = { 'ara': 'arabic', 'hye': 'armenian', 'eus': 'basque', 'cat': 'catalan', 'dan': 'danish', 'nld': 'dutch', 'eng': 'english', 'fin': 'finnish', 'fra': 'french', 'deu': 'german', 'ell': 'greek', 'hin': 'hindi', 'hun': 'hungarian', 'ind': 'indonesian', 'gle': 'irish', 'ita': 'italian', 'lit': 'lithuanian', 'npi': 'nepali', 'nor': 'norwegian', 'por': 'portuguese', 'ron': 'romanian', 'rus': 'russian', 'srp': 'serbian', 'spa': 'spanish', 'swe': 'swedish', 'tam': 'tamil', 'tur': 'turkish', 'yid': 'yiddish', } def to_regconfig(language: str) -> str: "Analogous to edgedb.fts_to_regconfig function in metaschema" language = language.lower() if language.startswith('xxx_'): return language[4:] else: return pg_langs_by_iso_639_3.get(language, language) def is_builtin_scalar( schema: s_schema.Schema, scalar: s_scalars.ScalarType ) -> bool: return scalar.id in base_type_name_map def type_has_stable_oid(typ: s_types.Type) -> bool: pg_type = base_type_name_map.get(typ.id) return pg_type is not None and len(pg_type) == 1 def get_scalar_base( schema: 
s_schema.Schema, scalar: s_scalars.ScalarType ) -> tuple[str, ...]: if base := base_type_name_map.get(scalar.id): return base for ancestor in scalar.get_ancestors(schema).objects(schema): if not ancestor.get_abstract(schema): # Check if base is fundamental, if not, then it is # another domain. if base := base_type_name_map.get(ancestor.id): pass elif typstr := ancestor.resolve_sql_type(schema): base = tuple(typstr.split('.')) else: base = common.get_backend_name( schema, ancestor, catenate=False) assert base return base raise ValueError(f'cannot determine backend type for scalar type ' f'{scalar.get_name(schema)}') def pg_type_from_scalar( schema: s_schema.Schema, scalar: s_scalars.ScalarType ) -> tuple[str, ...]: if scalar.is_polymorphic(schema): return ('anynonarray',) column_type = base_type_name_map.get(scalar.id) if column_type: pass elif typstr := scalar.resolve_sql_type(schema): column_type = tuple(typstr.split('.')) else: column_type = common.get_backend_name(schema, scalar, catenate=False) assert column_type return column_type def pg_type_array(tp: tuple[str, ...]) -> tuple[str, ...]: if len(tp) == 1: return (tp[0] + '[]',) else: return (tp[0], tp[1] + '[]') def pg_type_range(tp: tuple[str, ...]) -> tuple[str, ...]: return type_to_range_name_map[tp] def pg_type_multirange(tp: tuple[str, ...]) -> tuple[str, ...]: return type_to_multirange_name_map[tp] def pg_type_from_object( schema: s_schema.Schema, obj: s_obj.Object, persistent_tuples: bool = False ) -> tuple[str, ...]: if isinstance(obj, s_scalars.ScalarType): return pg_type_from_scalar(schema, obj) elif isinstance(obj, s_types.Type) and obj.is_anytuple(schema): return ('record',) elif isinstance(obj, s_types.Tuple): if persistent_tuples: return cast( tuple[str, ...], common.get_tuple_backend_name(obj.id, catenate=False), ) else: return ('record',) elif isinstance(obj, s_types.Array): if obj.is_polymorphic(schema): return ('anyarray',) else: tp = pg_type_from_object( schema, obj.get_subtypes(schema)[0], 
persistent_tuples=persistent_tuples) return pg_type_array(tp) elif isinstance(obj, s_types.Range): if obj.is_polymorphic(schema): return ('anyrange',) else: tp = pg_type_from_object( schema, obj.get_subtypes(schema)[0], persistent_tuples=persistent_tuples) return pg_type_range(tp) elif isinstance(obj, s_types.MultiRange): if obj.is_polymorphic(schema): return ('anymultirange',) else: tp = pg_type_from_object( schema, obj.get_subtypes(schema)[0], persistent_tuples=persistent_tuples) return pg_type_multirange(tp) elif isinstance(obj, s_objtypes.ObjectType): return ('uuid',) elif isinstance(obj, s_types.Type) and obj.is_any(schema): return ('anyelement',) else: raise ValueError(f'could not determine PG type for {obj!r}') def pg_type_from_ir_typeref( ir_typeref: irast.TypeRef, *, serialized: bool = False, persistent_tuples: bool = False, ) -> tuple[str, ...]: if irtyputils.is_array(ir_typeref): if (irtyputils.is_generic(ir_typeref) or (irtyputils.is_abstract(ir_typeref.subtypes[0]) and irtyputils.is_scalar(ir_typeref.subtypes[0]))): return ('anyarray',) elif irtyputils.is_array(ir_typeref.subtypes[0]): return ('record[]',) else: tp = pg_type_from_ir_typeref( ir_typeref.subtypes[0], serialized=serialized, persistent_tuples=persistent_tuples) return pg_type_array(tp) elif irtyputils.is_range(ir_typeref): if (irtyputils.is_generic(ir_typeref) or (irtyputils.is_abstract(ir_typeref.subtypes[0]) and irtyputils.is_scalar(ir_typeref.subtypes[0]))): return ('anyrange',) else: tp = pg_type_from_ir_typeref( ir_typeref.subtypes[0], serialized=serialized, persistent_tuples=persistent_tuples) return pg_type_range(tp) elif irtyputils.is_multirange(ir_typeref): if (irtyputils.is_generic(ir_typeref) or (irtyputils.is_abstract(ir_typeref.subtypes[0]) and irtyputils.is_scalar(ir_typeref.subtypes[0]))): return ('anymultirange',) else: tp = pg_type_from_ir_typeref( ir_typeref.subtypes[0], serialized=serialized, persistent_tuples=persistent_tuples) return pg_type_multirange(tp) elif 
irtyputils.is_anytuple(ir_typeref): return ('record',) elif irtyputils.is_tuple(ir_typeref): if ir_typeref.material_type: material = ir_typeref.material_type else: material = ir_typeref if persistent_tuples or material.in_schema: return cast( tuple[str, str], common.get_tuple_backend_name(material.id, catenate=False), ) else: return ('record',) elif irtyputils.is_any(ir_typeref) or irtyputils.is_anyobject(ir_typeref): return ('anyelement',) else: if ir_typeref.material_type: material = ir_typeref.material_type else: material = ir_typeref if irtyputils.is_object(material): if serialized: return ('record',) else: return ('uuid',) elif irtyputils.is_abstract(material): return ('anynonarray',) elif material.custom_sql_serialization and serialized: return tuple(material.custom_sql_serialization.split('.')) elif material.sql_type: return tuple(material.sql_type.split('.')) else: pg_type = base_type_name_map.get(material.id) if pg_type is None: real_name_hint = material.orig_name_hint or material.name_hint assert isinstance(real_name_hint, sn.QualName) # User-defined scalar type pg_type = common.get_scalar_backend_name( material.id, real_name_hint.module, catenate=False) return pg_type TableInfo = tuple[tuple[str, str], str, str] def _source_table_info( schema: s_schema.Schema, pointer: s_pointers.Pointer, versioned: bool, ) -> TableInfo: table = common.get_backend_name( schema, not_none(pointer.get_source(schema)), catenate=False, versioned=versioned, ) ptr_name = pointer.get_shortname(schema).name if ptr_name.startswith('__') or ptr_name == 'id': col_name = ptr_name else: col_name = str(pointer.id) table_type = 'ObjectType' return table, table_type, col_name def _pointer_table_info( schema: s_schema.Schema, pointer: s_pointers.Pointer, versioned: bool, ) -> TableInfo: table = common.get_backend_name( schema, pointer, catenate=False, versioned=versioned) col_name = 'target' table_type = 'link' return table, table_type, col_name def _resolve_type( schema: s_schema.Schema, 
pointer: s_pointers.Pointer ) -> tuple[str, ...]: column_type: tuple[str, ...] pointer_target = pointer.get_target(schema) if pointer_target is not None: if pointer_target.is_object_type(): column_type = ('uuid',) elif pointer_target.is_tuple(schema): column_type = common.get_backend_name( schema, pointer_target, catenate=False ) else: column_type = pg_type_from_object( schema, pointer_target, persistent_tuples=True ) else: # The target may not be known in circular object-to-object # linking scenarios. column_type = ('uuid',) return column_type def _pointer_storable_in_source( schema: s_schema.Schema, pointer: s_pointers.Pointer ) -> bool: return pointer.singular(schema) def _pointer_storable_in_pointer( schema: s_schema.Schema, pointer: s_pointers.Pointer ) -> bool: return not pointer.singular(schema) or pointer.has_user_defined_properties( schema ) @lru.per_job_lru_cache() def get_pointer_storage_info( pointer: s_pointers.Pointer, *, schema: s_schema.Schema, source: Optional[s_obj.InheritingObject] = None, resolve_type: bool = True, versioned: bool = True, link_bias: bool = False, ) -> PointerStorageInfo: assert not pointer.is_non_concrete( schema ), "only specialized pointers can be stored" if pointer.get_computable(schema): material_ptrcls = None else: schema, material_ptrcls = pointer.material_type(schema) if material_ptrcls is not None: pointer = material_ptrcls if source is None: source = pointer.get_source(schema) is_lprop = pointer.is_link_property(schema) if resolve_type and schema is None: msg = 'PointerStorageInfo needs a schema to resolve column_type' raise ValueError(msg) if is_lprop and pointer.issubclass( schema, schema.get('std::target', type=s_obj.SubclassableObject) ): # Normalize link@target to link assert isinstance(source, s_pointers.Pointer) pointer = source is_lprop = False if isinstance(pointer, irast.TupleIndirectionLink): table = None table_type = 'ObjectType' col_name = pointer.get_shortname(schema).name elif is_lprop: assert source 
table = common.get_backend_name( schema, source, catenate=False, versioned=versioned) table_type = 'link' if pointer.get_shortname(schema).name == 'source': col_name = 'source' else: col_name = str(pointer.id) else: if isinstance(source, s_scalars.ScalarType): # This is a pseudo-link on a scalar (__type__) table = None table_type = 'ObjectType' col_name = None elif _pointer_storable_in_source(schema, pointer) and not link_bias: table, table_type, col_name = _source_table_info( schema, pointer, versioned=versioned ) elif _pointer_storable_in_pointer(schema, pointer): table, table_type, col_name = _pointer_table_info( schema, pointer, versioned=versioned, ) else: return None # type: ignore if resolve_type: column_type = _resolve_type(schema, pointer) else: column_type = None return PointerStorageInfo( table_name=table, table_type=table_type, column_name=col_name, # type: ignore column_type=column_type, # type: ignore ) @dataclasses.dataclass(kw_only=True, eq=False, slots=True) class PointerStorageInfo: table_name: Optional[tuple[str, str]] table_type: str column_name: str column_type: tuple[str, str] @overload def get_ptrref_storage_info( ptrref: irast.BasePointerRef, *, resolve_type: bool = ..., link_bias: Literal[False] = False, allow_missing: Literal[False] = False, versioned: bool = True, ) -> PointerStorageInfo: ... @overload def get_ptrref_storage_info( ptrref: irast.BasePointerRef, *, resolve_type: bool = ..., link_bias: bool = ..., allow_missing: bool = ..., versioned: bool = True, ) -> Optional[PointerStorageInfo]: ... def get_ptrref_storage_info( ptrref: irast.BasePointerRef, *, resolve_type: bool = True, link_bias: bool = False, allow_missing: bool = False, # XXX versioned: bool = True, ) -> Optional[PointerStorageInfo]: # We wrap the real version because of bad mypy interactions # with lru_cache.
return _get_ptrref_storage_info( ptrref, resolve_type=resolve_type, link_bias=link_bias, allow_missing=allow_missing, versioned=versioned, ) @lru.per_job_lru_cache() def _get_ptrref_storage_info( ptrref: irast.BasePointerRef, *, resolve_type: bool = True, link_bias: bool = False, allow_missing: bool = False, versioned: bool = False, ) -> Optional[PointerStorageInfo]: if ptrref.material_ptr: ptrref = ptrref.material_ptr if ptrref.out_cardinality is None: # Guard against the IR generator failing to populate the PointerRef # cardinality correctly. raise RuntimeError( f'cannot determine backend storage parameters for the ' f'{ptrref.name!r} pointer: the cardinality is not known') target = ptrref.out_target if isinstance( ptrref, (irast.TupleIndirectionPointerRef, irast.SpecialPointerRef) ): table = None table_type = 'ObjectType' col_name = ptrref.shortname.name elif ptrref.source_ptr is not None: # link property assert isinstance(ptrref, irast.PointerRef) source_ptr = ptrref.source_ptr table = common.get_pointer_backend_name( source_ptr.id, source_ptr.name.module, catenate=False, versioned=versioned, ) table_type = 'link' if ptrref.shortname.name in ('source', 'target'): col_name = ptrref.shortname.name else: col_name = str(ptrref.id) else: assert isinstance(ptrref, irast.PointerRef) source = ptrref.out_source if irtyputils.is_scalar(source): # This is a pseudo-link on a scalar (__type__) table = None table_type = 'ObjectType' col_name = None elif _ptrref_storable_in_source(ptrref) and not link_bias: assert isinstance(source.name_hint, sn.QualName) # XXX: TRAMPOLINE table = common.get_objtype_backend_name( source.id, source.name_hint.module, catenate=False, versioned=versioned, ) ptrname = ptrref.shortname.name if ptrname.startswith('__') or ptrname == 'id': col_name = ptrname else: col_name = str(ptrref.id) table_type = 'ObjectType' elif _ptrref_storable_in_pointer(ptrref): table = common.get_pointer_backend_name( ptrref.id, ptrref.name.module, catenate=False,
versioned=versioned) col_name = 'target' table_type = 'link' elif not link_bias and not allow_missing: raise RuntimeError( f'cannot determine backend storage parameters for the ' f'{ptrref.name} pointer: unexpected characteristics') else: return None column_type: tuple[str, ...] | None if resolve_type: if irtyputils.is_object(target): column_type = ('uuid',) else: column_type = pg_type_from_ir_typeref( target, persistent_tuples=True) else: column_type = None return PointerStorageInfo( table_name=table, table_type=table_type, column_name=col_name, # type: ignore column_type=column_type, # type: ignore ) def _ptrref_storable_in_source(ptrref: irast.BasePointerRef) -> bool: return ptrref.out_cardinality.is_single() def _ptrref_storable_in_pointer(ptrref: irast.BasePointerRef) -> bool: if ptrref.union_components: return all( _ptrref_storable_in_pointer(c) for c in ptrref.union_components ) else: return ( ptrref.out_cardinality.is_multi() or ptrref.has_properties ) def has_table( obj: Optional[s_obj.InheritingObject], schema: s_schema.Schema ) -> bool: """Returns True for all schema objects that need a postgres table""" assert obj if isinstance(obj, s_objtypes.ObjectType): return not ( obj.is_compound_type(schema) or obj.get_is_derived(schema) or obj.is_view(schema) ) assert isinstance(obj, s_pointers.Pointer) if obj.is_pure_computable(schema) or obj.get_is_derived(schema): return False elif obj.is_non_concrete(schema): return ( not isinstance(obj, s_properties.Property) and str(obj.get_name(schema)) != 'std::link' ) elif obj.is_link_property(schema): return not obj.singular(schema) elif not has_table(obj.get_source(schema), schema): return False else: ptr_stor_info = get_pointer_storage_info( obj, resolve_type=False, schema=schema, link_bias=True) return ( ptr_stor_info is not None and ptr_stor_info.table_type == 'link' ) ================================================ FILE: edb/protocol/.gitignore ================================================ /*.c 
================================================ FILE: edb/protocol/README ================================================ Protocol documentation and testing utilities. The actual protocol implementation is in the edb.server package. ================================================ FILE: edb/protocol/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2020-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import enum from . import messages from . import render_utils from .messages import * # NoQA def render( obj: type[enum.Enum] | type[messages.Struct] ) -> str: if issubclass(obj, messages.Struct): return obj.render() else: assert issubclass(obj, enum.Enum) buf = render_utils.RenderBuffer() buf.write(f'enum {obj.__name__} {{') with buf.indent(): for membername, member in obj.__members__.items(): buf.write( f'{membername.ljust(messages._PAD - 1)} = ' f'{member.value:#x};' ) buf.write('};') return str(buf) ================================================ FILE: edb/protocol/enums.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import enum class Cardinality(enum.Enum): # Cardinality isn't applicable for the query: # * the query is a command like CONFIGURE that # does not return any data; # * the query is composed of multiple queries. NO_RESULT = 0x6e # Cardinality is 1 or 0 AT_MOST_ONE = 0x6f # Cardinality is 1 ONE = 0x41 # Cardinality is >= 0 MANY = 0x6d # Cardinality is >= 1 AT_LEAST_ONE = 0x4d ================================================ FILE: edb/protocol/messages.py ================================================ # mypy: ignore-errors # # This source file is part of the EdgeDB open source project. # # Copyright 2020-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import enum import io import typing from edb.common import binwrapper from .enums import Cardinality from . 
import render_utils _PAD = 16 class CType: pass class Scalar(CType): cname = None def __init__( self, doc: typing.Optional[str] = None, *, default: typing.Any = None ) -> None: self.doc = doc self.default = default def validate(self, val: typing.Any) -> bool: raise NotImplementedError def parse(self, buffer: binwrapper.BinWrapper) -> any: raise NotImplementedError def dump(self, val: typing.Any, buffer: binwrapper.BinWrapper) -> None: raise NotImplementedError def render_field( self, fieldname: str, buf: render_utils.RenderBuffer ) -> None: cname = self.cname if cname is None: raise NotImplementedError if self.default and isinstance(self.default, int): buf.write( f'{cname.ljust(_PAD - 1)} {fieldname} = {self.default:#x};') elif self.default: buf.write( f'{cname.ljust(_PAD - 1)} {fieldname} = {self.default};') else: buf.write( f'{cname.ljust(_PAD - 1)} {fieldname};') class UInt8(Scalar): cname = 'uint8' def validate(self, val: typing.Any) -> bool: return isinstance(val, int) and (0 <= val <= 255) def parse(self, buffer: binwrapper.BinWrapper) -> any: return buffer.read_ui8() def dump(self, val: int, buffer: binwrapper.BinWrapper) -> None: buffer.write_ui8(val) class UInt16(Scalar): cname = 'uint16' def validate(self, val: typing.Any) -> bool: return isinstance(val, int) and (0 <= val <= 2 ** 16 - 1) def parse(self, buffer: binwrapper.BinWrapper) -> any: return buffer.read_ui16() def dump(self, val: int, buffer: binwrapper.BinWrapper) -> None: buffer.write_ui16(val) class UInt32(Scalar): cname = 'uint32' def validate(self, val: typing.Any) -> bool: return isinstance(val, int) and (0 <= val <= 2 ** 32 - 1) def parse(self, buffer: binwrapper.BinWrapper) -> any: return buffer.read_ui32() def dump(self, val: int, buffer: binwrapper.BinWrapper) -> None: buffer.write_ui32(val) class UInt64(Scalar): cname = 'uint64' def validate(self, val: typing.Any) -> bool: return isinstance(val, int) and (0 <= val <= 2 ** 64 - 1) def parse(self, buffer: binwrapper.BinWrapper) -> any: 
return buffer.read_ui64() def dump(self, val: int, buffer: binwrapper.BinWrapper) -> None: buffer.write_ui64(val) class Bytes(Scalar): cname = 'bytes' def validate(self, val: typing.Any) -> bool: return isinstance(val, bytes) def parse(self, buffer: binwrapper.BinWrapper) -> any: return buffer.read_len32_prefixed_bytes() def dump(self, val: bytes, buffer: binwrapper.BinWrapper) -> None: buffer.write_len32_prefixed_bytes(val) class String(Scalar): cname = 'string' def validate(self, val: typing.Any) -> bool: return isinstance(val, str) def parse(self, buffer: binwrapper.BinWrapper) -> any: return buffer.read_len32_prefixed_bytes().decode('utf-8') def dump(self, val: str, buffer: binwrapper.BinWrapper) -> None: buffer.write_len32_prefixed_bytes(val.encode('utf-8')) class UUID(Scalar): cname = 'uuid' def validate(self, val: typing.Any) -> bool: return isinstance(val, bytes) and len(val) == 16 def parse(self, buffer: binwrapper.BinWrapper) -> any: return buffer.read_bytes(16) def dump(self, val: bytes, buffer: binwrapper.BinWrapper) -> None: assert isinstance(val, bytes) and len(val) == 16 buffer.write_bytes(val) class ArrayOf(CType): def __init__( self, length_in: type[CType], element: CType | type[Struct], doc: str = None, ) -> None: self.length_in = length_in() self.element = element self.doc = doc def validate(self, val: typing.Any) -> bool: if not isinstance(val, list) or not self.length_in.validate(len(val)): return False if isinstance(self.element, CType): return all(self.element.validate(x) for x in val) else: return all(isinstance(x, self.element) for x in val) def parse(self, buffer: binwrapper.BinWrapper) -> any: length = self.length_in.parse(buffer) result = [] for _ in range(length): result.append(self.element.parse(buffer)) return result def dump(self, val: list, buffer: binwrapper.BinWrapper) -> None: self.length_in.dump(len(val), buffer) for el in val: self.element.dump(el, buffer) def render_field( self, fieldname: str, buf: render_utils.RenderBuffer ) 
-> None: self.length_in.render_field(f'num_{fieldname}', buf) self.element.render_field(f'{fieldname}[num_{fieldname}]', buf) class FixedArrayOf(CType): def __init__( self, length: int, element: CType | type[Struct], doc: typing.Optional[str] = None ) -> None: self.length = length self.element = element self.doc = doc def validate(self, val: typing.Any) -> bool: if not isinstance(val, list) or len(val) != self.length: return False if isinstance(self.element, CType): return all(self.element.validate(x) for x in val) else: return all(isinstance(x, self.element) for x in val) def parse(self, buffer: binwrapper.BinWrapper) -> any: result = [] for _ in range(self.length): result.append(self.element.parse(buffer)) return result def dump(self, val: list, buffer: binwrapper.BinWrapper) -> None: assert len(val) == self.length for el in val: self.element.dump(el, buffer) def render_field( self, fieldname: str, buf: render_utils.RenderBuffer ) -> None: self.element.render_field(f'{fieldname}[{self.length}]', buf) class EnumOf(CType): def __init__( self, value_in: type[Scalar], enum: type[enum.Enum], doc: typing.Optional[str] = None, ) -> None: self.value_in = value_in() self.enum = enum self.doc = doc def validate(self, val: typing.Any) -> bool: if isinstance(val, self.enum): return True if not self.value_in.validate(val): return False try: self.enum(val) except ValueError: return False else: return True def parse(self, buffer: binwrapper.BinWrapper) -> any: result = self.value_in.parse(buffer) return self.enum(result) def dump(self, val: typing.Any, buffer: binwrapper.BinWrapper) -> None: self.value_in.dump(val.value, buffer) def render_field( self, fieldname: str, buf: render_utils.RenderBuffer ) -> None: typename = f'{self.value_in.cname}<{self.enum.__name__}>' buf.write(f'{typename.ljust(_PAD - 1)} {fieldname};') class Struct: _fields: dict[str, CType | type[Struct]] = {} def __init_subclass__(cls, *, abstract=False): if abstract:
return fields = {} for name in cls.__dict__: attr = cls.__dict__[name] if name.startswith('__') or callable(attr): continue if not isinstance(attr, CType): raise TypeError( f'field {cls.__name__}.{name!r} must be a Type') else: fields[name] = attr cls._fields = fields def __init__(self, **args: typing.Any): for fieldname in ['mtype', 'message_length']: if fieldname in args: raise ValueError( f'cannot construct instance of {type(self).__name__}: ' f'{fieldname!r} field is not supposed to be passed to ' f'the constructor') for fieldname, field in type(self)._fields.items(): if fieldname in ['mtype', 'message_length']: continue try: arg = args[fieldname] except KeyError: raise ValueError( f'cannot construct instance of {type(self).__name__}: ' f'the {fieldname!r} field is missing') if ( isinstance(field, CType) and not field.validate(arg) or isinstance(field, type) and not isinstance(arg, field) ): raise ValueError( f'cannot construct instance of {type(self).__name__}: ' f'invalid value {arg!r} for the {fieldname!r} field') setattr(self, fieldname, arg) @classmethod def parse(cls, buffer: binwrapper.BinWrapper) -> Struct: kwargs: dict[str, any] = {} for fieldname, field in cls._fields.items(): if fieldname in {'mtype', 'message_length'}: continue kwargs[fieldname] = field.parse(buffer) return cls(**kwargs) @classmethod def dump(cls, val: Struct, buffer: binwrapper.BinWrapper) -> None: fields = val._fields for fieldname, field in fields.items(): if fieldname in {'mtype', 'message_length'}: continue fval = getattr(val, fieldname) field.dump(fval, buffer) def __repr__(self): res = [f'<{type(self).__name__}'] for fieldname in type(self)._fields: if fieldname in {'mtype', 'message_length'}: continue val = getattr(self, fieldname) res.append(f' {fieldname}={val!r}') res.append('>') return ''.join(res) @classmethod def render_field( cls, fieldname: str, buf: render_utils.RenderBuffer ) -> None: buf.write(f'{cls.__name__.ljust(_PAD - 1)} {fieldname};') @classmethod def 
render(cls) -> str: buf = render_utils.RenderBuffer() buf.write(f'struct {cls.__name__} {{') with buf.indent(): for fieldname, field in cls._fields.items(): if field.doc: buf.write_comment(field.doc) field.render_field(fieldname, buf) buf.newline() if buf.lastline() == '': buf.popline() buf.write('};') return str(buf) class KeyValue(Struct): code = UInt16('Key code (specific to the type of the Message).') value = Bytes('Value data.') class Annotation(Struct): name = String('Name of the annotation') value = String('Value of the annotation (in JSON format).') KeyValues = ArrayOf(UInt16, KeyValue, 'A set of key-value pairs.') Annotations = ArrayOf(UInt16, Annotation, 'A set of annotations.') MessageLength = UInt32('Length of message contents in bytes, including self.') MessageType = (lambda letter: UInt8(f"Message type ('{letter}').", default=ord(letter))) class Message(Struct, abstract=True): pass class ServerMessage(Message, abstract=True): index: dict[int, list[type[ServerMessage]]] = {} def __init_subclass__(cls): super().__init_subclass__() if 'mtype' not in cls._fields: raise TypeError(f'mtype field is missing for {cls}') if 'message_length' not in cls._fields: raise TypeError(f'message_length field is missing for {cls}') cls.index.setdefault(cls._fields['mtype'].default, []).append(cls) @classmethod def parse(cls, mtype: int, data: bytes) -> ServerMessage: iobuf = io.BytesIO(data) buffer = binwrapper.BinWrapper(iobuf) kwargs: dict[str, any] = {} msg_types = cls.index.get(mtype) if not msg_types: raise ValueError(f"unspecced message type {chr(mtype)!r}") if len(msg_types) > 1: raise ValueError(f"multiple specs for message type {chr(mtype)!r}") msg_type = msg_types[0] for fieldname, field in msg_type._fields.items(): if fieldname in {'mtype', 'message_length'}: continue kwargs[fieldname] = field.parse(buffer) if len(iobuf.read(1)): raise ValueError( f'buffer is not empty after parsing {chr(mtype)!r} message') return msg_type(**kwargs) class ClientMessage(Message, 
abstract=True): def __init_subclass__(cls): super().__init_subclass__() if 'mtype' not in cls._fields: raise TypeError(f'mtype field is missing for {cls}') if 'message_length' not in cls._fields: raise TypeError(f'message_length field is missing for {cls}') def dump(self) -> bytes: iobuf = io.BytesIO() buf = binwrapper.BinWrapper(iobuf) fields = type(self)._fields for fieldname, field in fields.items(): if fieldname in {'mtype', 'message_length'}: continue val = getattr(self, fieldname) field.dump(val, buf) dumped = iobuf.getvalue() return ( fields['mtype'].default.to_bytes(1, 'big') + (len(dumped) + 4).to_bytes(4, 'big') + dumped ) ############################################################################### # Protocol Messages Definitions ############################################################################### class InputLanguage(enum.Enum): EDGEQL = 0x45 # b'E' SQL = 0x53 # b'S' class OutputFormat(enum.Enum): BINARY = 0x62 JSON = 0x6a JSON_ELEMENTS = 0x4a NONE = 0x6e class Capability(enum.IntFlag): MODIFICATIONS = 1 << 0 # noqa SESSION_CONFIG = 1 << 1 # noqa TRANSACTION = 1 << 2 # noqa DDL = 1 << 3 # noqa PERSISTENT_CONFIG = 1 << 4 # noqa ALL = 0xFFFFFFFFFFFFFFFF # noqa class CompilationFlag(enum.IntFlag): INJECT_OUTPUT_TYPE_IDS = 1 << 0 # noqa INJECT_OUTPUT_TYPE_NAMES = 1 << 1 # noqa INJECT_OUTPUT_OBJECT_IDS = 1 << 2 # noqa class DumpFlag(enum.IntFlag): DUMP_SECRETS = 1 << 0 # noqa class ErrorSeverity(enum.Enum): ERROR = 120 FATAL = 200 PANIC = 255 class ErrorResponse(ServerMessage): mtype = MessageType('E') message_length = MessageLength severity = EnumOf(UInt8, ErrorSeverity, 'Message severity.') error_code = UInt32('Message code.') message = String('Error message.') attributes = ArrayOf(UInt16, KeyValue, 'Error attributes.') class MessageSeverity(enum.Enum): DEBUG = 20 INFO = 40 NOTICE = 60 WARNING = 80 class LogMessage(ServerMessage): mtype = MessageType('L') message_length = MessageLength severity = EnumOf(UInt8, MessageSeverity, 'Message 
severity.') code = UInt32('Message code.') text = String('Message text.') annotations = ArrayOf(UInt16, Annotation, 'Message annotations.') class TransactionState(enum.Enum): NOT_IN_TRANSACTION = 0x49 IN_TRANSACTION = 0x54 IN_FAILED_TRANSACTION = 0x45 class ReadyForCommand(ServerMessage): mtype = MessageType('Z') message_length = MessageLength annotations = Annotations transaction_state = EnumOf(UInt8, TransactionState, 'Transaction state.') class RestoreReady(ServerMessage): mtype = MessageType('+') message_length = MessageLength annotations = Annotations jobs = UInt16('Number of parallel jobs for restore, currently always "1"') class DataElement(Struct): data = ArrayOf(UInt32, UInt8(), 'Encoded output data.') class CommandComplete(ServerMessage): mtype = MessageType('C') message_length = MessageLength annotations = Annotations capabilities = EnumOf(UInt64, Capability, 'A bit mask of allowed capabilities.') status = String('Command status.') state_typedesc_id = UUID('State data descriptor ID.') state_data = Bytes('Encoded state data.') class CommandDataDescription(ServerMessage): mtype = MessageType('T') message_length = MessageLength annotations = Annotations capabilities = EnumOf(UInt64, Capability, 'A bit mask of allowed capabilities.') result_cardinality = EnumOf( UInt8, Cardinality, 'Actual result cardinality.') input_typedesc_id = UUID('Argument data descriptor ID.') input_typedesc = Bytes('Argument data descriptor.') output_typedesc_id = UUID('Output data descriptor ID.') output_typedesc = Bytes('Output data descriptor.') class StateDataDescription(ServerMessage): mtype = MessageType('s') message_length = MessageLength typedesc_id = UUID('Updated state data descriptor ID.') typedesc = Bytes('State data descriptor.') class Data(ServerMessage): mtype = MessageType('D') message_length = MessageLength data = ArrayOf( UInt16, DataElement, 'Encoded output data array. The array is currently always of size 1.' 
) class DumpTypeInfo(Struct): type_name = String() type_class = String() type_id = UUID() class DumpObjectDesc(Struct): object_id = UUID() description = Bytes() dependencies = ArrayOf(UInt16, UUID()) class DumpHeader(ServerMessage): mtype = MessageType('@') message_length = MessageLength attributes = KeyValues major_ver = UInt16('Major version of Gel.') minor_ver = UInt16('Minor version of Gel.') schema_ddl = String('Schema.') types = ArrayOf(UInt32, DumpTypeInfo, 'Type identifiers.') descriptors = ArrayOf(UInt32, DumpObjectDesc, 'Object descriptors.') class DumpBlock(ServerMessage): mtype = MessageType('=') message_length = MessageLength attributes = KeyValues class ServerKeyData(ServerMessage): mtype = MessageType('K') message_length = MessageLength data = FixedArrayOf(32, UInt8(), 'Key data.') class ParameterStatus(ServerMessage): mtype = MessageType('S') message_length = MessageLength name = Bytes('Parameter name.') value = Bytes('Parameter value.') class ParameterStatus_SystemConfig(Struct): typedesc = ArrayOf(UInt32, UInt8(), 'Type descriptor prefixed with ' 'type descriptor uuid.') data = FixedArrayOf(1, DataElement, 'Configuration settings data.') class ProtocolExtension(Struct): name = String('Extension name.') annotations = ArrayOf(UInt16, Annotation, 'A set of extension annotations.') class ServerHandshake(ServerMessage): mtype = MessageType('v') message_length = MessageLength major_ver = UInt16('maximum supported or client-requested ' 'protocol major version, whichever is greater.') minor_ver = UInt16('maximum supported or client-requested ' 'protocol minor version, whichever is greater.') extensions = ArrayOf( UInt16, ProtocolExtension, 'Supported protocol extensions.') class AuthenticationOK(ServerMessage): mtype = MessageType('R') message_length = MessageLength auth_status = UInt32('Specifies that this message contains ' 'a successful authentication indicator.', default=0x0) class AuthenticationRequiredSASLMessage(ServerMessage): mtype =
MessageType('R') message_length = MessageLength auth_status = UInt32('Specifies that this message contains ' 'a SASL authentication request.', default=0x0A) methods = ArrayOf(UInt32, String(), 'A list of supported SASL authentication methods.') class AuthenticationSASLContinue(ServerMessage): mtype = MessageType('R') message_length = MessageLength auth_status = UInt32('Specifies that this message contains ' 'a SASL challenge.', default=0x0B) sasl_data = Bytes('Mechanism-specific SASL data.') class AuthenticationSASLFinal(ServerMessage): mtype = MessageType('R') message_length = MessageLength auth_status = UInt32('Specifies that SASL authentication ' 'has completed.', default=0x0C) sasl_data = Bytes() class Dump(ClientMessage): mtype = MessageType('>') message_length = MessageLength annotations = Annotations flags = EnumOf(UInt64, DumpFlag, 'A bit mask of dump options.') class Sync(ClientMessage): mtype = MessageType('S') message_length = MessageLength class Flush(ClientMessage): mtype = MessageType('H') message_length = MessageLength class Restore(ClientMessage): mtype = MessageType('<') message_length = MessageLength attributes = KeyValues jobs = UInt16( 'Number of parallel jobs for restore (only "1" is supported)') header_data = Bytes( 'Original DumpHeader packet data excluding mtype and message_length') class RestoreBlock(ClientMessage): mtype = MessageType('=') message_length = MessageLength block_data = Bytes( 'Original DumpBlock packet data excluding mtype and message_length') class RestoreEof(ClientMessage): mtype = MessageType('.') message_length = MessageLength class Parse(ClientMessage): mtype = MessageType('P') message_length = MessageLength annotations = Annotations allowed_capabilities = EnumOf(UInt64, Capability, 'A bit mask of allowed capabilities.') compilation_flags = EnumOf(UInt64, CompilationFlag, 'A bit mask of query options.') implicit_limit = UInt64('Implicit LIMIT clause on returned sets.') input_language = EnumOf(UInt8, InputLanguage, 
'Command source language.') output_format = EnumOf(UInt8, OutputFormat, 'Data output format.') expected_cardinality = EnumOf(UInt8, Cardinality, 'Expected result cardinality.') command_text = String('Command text.') state_typedesc_id = UUID('State data descriptor ID.') state_data = Bytes('Encoded state data.') class Execute(ClientMessage): mtype = MessageType('O') message_length = MessageLength annotations = Annotations allowed_capabilities = EnumOf(UInt64, Capability, 'A bit mask of allowed capabilities.') compilation_flags = EnumOf(UInt64, CompilationFlag, 'A bit mask of query options.') implicit_limit = UInt64('Implicit LIMIT clause on returned sets.') input_language = EnumOf(UInt8, InputLanguage, 'Command source language.') output_format = EnumOf(UInt8, OutputFormat, 'Data output format.') expected_cardinality = EnumOf(UInt8, Cardinality, 'Expected result cardinality.') command_text = String('Command text.') state_typedesc_id = UUID('State data descriptor ID.') state_data = Bytes('Encoded state data.') input_typedesc_id = UUID('Argument data descriptor ID.') output_typedesc_id = UUID('Output data descriptor ID.') arguments = Bytes('Encoded argument data.') class ConnectionParam(Struct): name = String() value = String() class ClientHandshake(ClientMessage): mtype = MessageType('V') message_length = MessageLength major_ver = UInt16('Requested protocol major version.') minor_ver = UInt16('Requested protocol minor version.') params = ArrayOf(UInt16, ConnectionParam, 'Connection parameters.') extensions = ArrayOf( UInt16, ProtocolExtension, 'Requested protocol extensions.') class Terminate(ClientMessage): mtype = MessageType('X') message_length = MessageLength class AuthenticationSASLInitialResponse(ClientMessage): mtype = MessageType('p') message_length = MessageLength method = String('Name of the SASL authentication mechanism ' 'that the client selected.') sasl_data = Bytes('Mechanism-specific "Initial Response" data.') class 
AuthenticationSASLResponse(ClientMessage): mtype = MessageType('r') message_length = MessageLength sasl_data = Bytes('Mechanism-specific response data.') ================================================ FILE: edb/protocol/protocol.pxd ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2020-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from gel.protocol.asyncio_proto cimport AsyncIOProtocol cdef class Protocol(AsyncIOProtocol): cdef: public bytes last_state cdef parse_command_complete_message(self) cdef encode_state(self, state) cdef class Connection: cdef: object _transport readonly list inbox AsyncIOProtocol _protocol ================================================ FILE: edb/protocol/protocol.pyi ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2020-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # from typing import Any from . import messages class Connection: async def connect(self) -> None: ... async def execute(self, query: str, state_id: bytes = ..., state: bytes = ...) -> None: ... async def sync(self) -> bytes: ... async def recv(self) -> messages.ServerMessage: ... async def recv_match( self, msgcls: type[messages.ServerMessage], _ignore_msg: type[messages.ServerMessage] | None = None, **fields: Any, ) -> messages.ServerMessage: ... async def send(self, *msgs: messages.ClientMessage) -> None: ... async def aclose(self) -> None: ... ================================================ FILE: edb/protocol/protocol.pyx ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2020-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import asyncio import re import time from gel import con_utils from gel import enums from gel.protocol.asyncio_proto cimport AsyncIOProtocol from gel.protocol.protocol cimport ReadBuffer, WriteBuffer from .
import messages cdef class Protocol(AsyncIOProtocol): cdef parse_command_complete_message(self): cdef WriteBuffer buf = WriteBuffer.new() self.ignore_headers() self.last_capabilities = enums.Capability(self.buffer.read_int64()) self.last_status = self.buffer.read_len_prefixed_bytes() state_typedesc_id = self.buffer.read_bytes(16) buf.write_len_prefixed_bytes(self.buffer.read_len_prefixed_bytes()) if state_typedesc_id != b'\x00' * 16: self.last_state = bytes(buf) self.buffer.finish_message() cdef encode_state(self, state): if self.last_state is None: return AsyncIOProtocol.encode_state(self, None) else: return self.state_type_id, self.last_state cdef class Connection: def __init__(self, pr, tr): self._protocol = pr self._transport = tr self.inbox = [] async def connect(self): await self._protocol.connect() async def execute(self, query, state_id=b'\0' * 16, state=b''): await self.send( messages.Execute( annotations=[], command_text=query, input_language=messages.InputLanguage.EDGEQL, output_format=messages.OutputFormat.NONE, expected_cardinality=messages.Cardinality.MANY, allowed_capabilities=messages.Capability.ALL, compilation_flags=messages.CompilationFlag(9), implicit_limit=0, input_typedesc_id=b'\0' * 16, output_typedesc_id=b'\0' * 16, state_typedesc_id=state_id, arguments=b'', state_data=state, ), messages.Sync(), ) await self.recv_match( messages.CommandComplete, _ignore_msg=messages.StateDataDescription, ) await self.recv_match(messages.ReadyForCommand) async def sync(self): await self.send(messages.Sync()) reply = await self.recv() if not isinstance(reply, messages.ReadyForCommand): raise AssertionError( f'invalid response for Sync request: {reply!r}') return reply.transaction_state async def recv(self): while True: await self._protocol.wait_for_message() mtype = self._protocol.buffer.get_message_type() data = self._protocol.buffer.consume_message() msg = messages.ServerMessage.parse(mtype, data) if isinstance(msg, messages.LogMessage): 
self.inbox.append(msg) continue return msg async def recv_match(self, msgcls, _ignore_msg=None, **fields): message = await self.recv() if _ignore_msg is not None and isinstance(message, _ignore_msg): message = await self.recv() if not isinstance(message, msgcls): raise AssertionError( f'expected {msgcls.__name__} message, received ' f'{type(message).__name__}: {message!r}') for fieldname, expected in fields.items(): val = getattr(message, fieldname) if isinstance(expected, str): if not re.match(expected, val): raise AssertionError( f'{msgcls.__name__}.{fieldname} value {val!r} ' f'does not match expected regexp {expected!r}') else: if expected != val: raise AssertionError( f'{msgcls.__name__}.{fieldname} value {val!r} ' f'does not equal expected {expected!r}') return message async def send(self, *msgs: messages.ClientMessage): cdef WriteBuffer buf for msg in msgs: out = msg.dump() buf = WriteBuffer.new() buf.write_bytes(out) self._protocol.write(buf) async def aclose(self): # TODO: Fix when edgedb-python implements proper cancellation asyncio.get_running_loop().call_soon(lambda: self._protocol.abort()) await self._protocol.wait_for_disconnect() async def new_connection( dsn: str = None, *, host: str = None, port: int = None, user: str = None, password: str = None, secret_key: str = None, branch: str = None, database: str = None, timeout: float = 60, tls_ca: str = None, tls_ca_file: str = None, tls_security: str = 'default', credentials: str = None, credentials_file: str = None, **kwargs ): connect_config, client_config = con_utils.parse_connect_arguments( dsn=dsn, host=host, port=port, user=user, password=password, secret_key=secret_key, database=database, branch=branch, timeout=timeout, command_timeout=None, server_settings=None, tls_ca=tls_ca, tls_ca_file=tls_ca_file, tls_security=tls_security, tls_server_name=None, wait_until_available=timeout, credentials=credentials, credentials_file=credentials_file, **kwargs ) loop = asyncio.get_running_loop()
last_error = None addr = None for addr in [connect_config.address]: before = time.monotonic() try: if timeout <= 0: raise asyncio.TimeoutError protocol_factory = lambda: Protocol(connect_config, loop) if isinstance(addr, str): connector = loop.create_unix_connection( protocol_factory, addr) else: connector = loop.create_connection( protocol_factory, *addr, ssl=connect_config.ssl_ctx if tls_security else None, ) before = time.monotonic() try: tr, pr = await asyncio.wait_for(connector, timeout=timeout) finally: timeout -= time.monotonic() - before return Connection(pr, tr) except (OSError, asyncio.TimeoutError, ConnectionError) as ex: last_error = ex raise last_error ================================================ FILE: edb/protocol/render_utils.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2020-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import Optional import contextlib import textwrap class RenderBuffer: ilevel: int buf: list[str] def __init__(self): self.ilevel = 0 self.buf = [] def write(self, line: str) -> None: self.buf.append(' ' * (self.ilevel * 2) + line) def newline(self) -> None: self.buf.append('') def lastline(self) -> Optional[str]: return self.buf[-1] if len(self.buf) else None def popline(self) -> str: return self.buf.pop() def write_comment(self, comment: str) -> None: lines = textwrap.wrap(comment, width=40) for line in lines: self.write(f'// {line}') def __str__(self): return '\n'.join(self.buf) @contextlib.contextmanager def indent(self): self.ilevel += 1 try: yield finally: self.ilevel -= 1 ================================================ FILE: edb/schema/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2013-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from .name import QualName # NOQA from .objects import Object, ObjectMeta # NOQA from .schema import Schema # NOQA from .modules import Module # NOQA ================================================ FILE: edb/schema/_types.py ================================================ # AUTOGENERATED FROM "edb/api/types.txt" WITH # $ edb gen-types from __future__ import annotations import uuid from edb.common import uuidgen from edb.schema import name as sn UUID: type[uuid.UUID] = uuidgen.UUID TYPE_IDS = { sn.name_from_string('anytype'): UUID('00000000-0000-0000-0000-000000000001'), sn.name_from_string('anytuple'): UUID('00000000-0000-0000-0000-000000000002'), sn.name_from_string('anyobject'): UUID('00000000-0000-0000-0000-000000000003'), sn.name_from_string('std'): UUID('00000000-0000-0000-0000-0000000000f0'), sn.name_from_string('empty-tuple'): UUID('00000000-0000-0000-0000-0000000000ff'), sn.name_from_string('std::uuid'): UUID('00000000-0000-0000-0000-000000000100'), sn.name_from_string('std::str'): UUID('00000000-0000-0000-0000-000000000101'), sn.name_from_string('std::bytes'): UUID('00000000-0000-0000-0000-000000000102'), sn.name_from_string('std::int16'): UUID('00000000-0000-0000-0000-000000000103'), sn.name_from_string('std::int32'): UUID('00000000-0000-0000-0000-000000000104'), sn.name_from_string('std::int64'): UUID('00000000-0000-0000-0000-000000000105'), sn.name_from_string('std::float32'): UUID('00000000-0000-0000-0000-000000000106'), sn.name_from_string('std::float64'): UUID('00000000-0000-0000-0000-000000000107'), sn.name_from_string('std::decimal'): UUID('00000000-0000-0000-0000-000000000108'), sn.name_from_string('std::bool'): UUID('00000000-0000-0000-0000-000000000109'), sn.name_from_string('std::datetime'): UUID('00000000-0000-0000-0000-00000000010a'), sn.name_from_string('std::duration'): UUID('00000000-0000-0000-0000-00000000010e'), sn.name_from_string('std::json'): UUID('00000000-0000-0000-0000-00000000010f'), 
sn.name_from_string('std::bigint'): UUID('00000000-0000-0000-0000-000000000110'), sn.name_from_string('std::cal::local_datetime'): UUID('00000000-0000-0000-0000-00000000010b'), sn.name_from_string('std::cal::local_date'): UUID('00000000-0000-0000-0000-00000000010c'), sn.name_from_string('std::cal::local_time'): UUID('00000000-0000-0000-0000-00000000010d'), sn.name_from_string('std::cal::relative_duration'): UUID('00000000-0000-0000-0000-000000000111'), sn.name_from_string('std::cal::date_duration'): UUID('00000000-0000-0000-0000-000000000112'), sn.name_from_string('cfg::memory'): UUID('00000000-0000-0000-0000-000000000130'), sn.name_from_string('std::pg::json'): UUID('00000000-0000-0000-0000-000001000001'), sn.name_from_string('std::pg::timestamptz'): UUID('00000000-0000-0000-0000-000001000002'), sn.name_from_string('std::pg::timestamp'): UUID('00000000-0000-0000-0000-000001000003'), sn.name_from_string('std::pg::date'): UUID('00000000-0000-0000-0000-000001000004'), sn.name_from_string('std::pg::interval'): UUID('00000000-0000-0000-0000-000001000005'), } ================================================ FILE: edb/schema/abc.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations import typing class Reducible: """An interface implemented by all non-builtin objects stored in schema.""" def schema_reduce(self) -> typing.Any: """Return a primitive representation of the object. The return value must consist of primitive Python objects. """ raise NotImplementedError @classmethod def schema_restore( cls, data: typing.Any, ) -> Reducible: """Restore object from data returned by *schema_reduce*.""" raise NotImplementedError ================================================ FILE: edb/schema/annos.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( AbstractSet, Any, Callable, Optional, TypeVar, cast, TYPE_CHECKING, ) import json from edb import errors from edb.edgeql import ast as qlast from edb.edgeql import compiler as qlcompiler from edb.edgeql import qltypes from . import delta as sd from . import name as sn from . import referencing from . import objects as so from . import utils if TYPE_CHECKING: from . 
import schema as s_schema class AnnotationValue( referencing.ReferencedInheritingObject, qlkind=qltypes.SchemaObjectClass.ANNOTATION, reflection=so.ReflectionMethod.AS_LINK, reflection_link='annotation', data_safe=True, ): subject = so.SchemaField( so.Object, compcoef=1.0, default=None, inheritable=False) # N.B: This is really an Annotation, and we even patch it up below # to be one, but we can't reference it here because it hasn't been # declared. (And the tricks used for this sort of thing elsewhere # don't work for reflection AS_LINK) annotation = so.SchemaField( so.Object, compcoef=0.429, ddl_identity=True) value = so.SchemaField( str, compcoef=0.909) def get_annotation(self, schema: s_schema.Schema) -> Annotation: return self.get_field_value( # type: ignore[no-any-return] schema, 'annotation') def should_propagate(self, schema: s_schema.Schema) -> bool: return self.get_annotation(schema).get_inheritable(schema) def __str__(self) -> str: return '<{}: at 0x{:x}>'.format(self.__class__.__name__, id(self)) __repr__ = __str__ @classmethod def get_schema_class_displayname(cls) -> str: return 'annotation' def get_verbosename( self, schema: s_schema.Schema, *, with_parent: bool = False ) -> str: vn = super().get_verbosename(schema) if with_parent: subject = self.get_subject(schema) assert subject is not None pvn = subject.get_verbosename(schema, with_parent=True) return f'{vn} of {pvn}' else: return vn T = TypeVar("T") class AnnotationSubject(so.Object): annotations_refs = so.RefDict( attr='annotations', ref_cls=AnnotationValue) annotations = so.SchemaField( so.ObjectIndexByShortname[AnnotationValue], inheritable=False, ephemeral=True, coerce=True, compcoef=0.909, default=so.DEFAULT_CONSTRUCTOR) def get_annotation( self, schema: s_schema.Schema, name: sn.QualName, ) -> Optional[str]: attrval = self.get_annotations(schema).get(schema, name, None) return attrval.get_value(schema) if attrval is not None else None def must_get_annotation( self, schema: s_schema.Schema, 
name: sn.QualName, ) -> str: annotation_text = self.get_annotation(schema, name) if annotation_text is None: vn = self.get_verbosename(schema, with_parent=True) raise errors.SchemaDefinitionError( f"annotation {name} on {vn} is not set") return annotation_text def get_json_annotation( self, schema: s_schema.Schema, name: sn.QualName, t: Callable[[Any], T], ) -> Optional[T]: annotation_text = self.get_annotation(schema, name) if annotation_text is None: return None else: try: value = json.loads(annotation_text) except Exception: vn = self.get_verbosename(schema, with_parent=True) raise errors.SchemaDefinitionError( f"annotation {name} on {vn} is not set to " f"a valid JSON value") try: return t(value) except Exception as e: vn = self.get_verbosename(schema, with_parent=True) raise errors.SchemaDefinitionError( f"annotation {name} on {vn} is not set to " f"JSON containing a valid value of type {t}: {e}" ) def must_get_json_annotation( self, schema: s_schema.Schema, name: sn.QualName, t: Callable[[Any], T], ) -> T: value = self.get_json_annotation(schema, name, t) if value is None: vn = self.get_verbosename(schema, with_parent=True) raise errors.SchemaDefinitionError( f"annotation {name} is not set on {vn}" ) else: return value class Annotation( so.QualifiedObject, so.InheritingObject, AnnotationSubject, qlkind=qltypes.SchemaObjectClass.ANNOTATION, data_safe=True, ): inheritable = so.SchemaField( bool, default=False, compcoef=0.2) def get_verbosename( self, schema: s_schema.Schema, *, with_parent: bool = False ) -> str: vn = super().get_verbosename(schema) return f"abstract {vn}" # HACK?: Fix up annotation field in AnnotationValue to have the proper type AnnotationValue.get_field('annotation').type = Annotation class AnnotationSubjectCommandContext: pass class AnnotationSubjectCommand(sd.ObjectCommand[so.Object_T]): pass class AnnotationCommandContext(sd.ObjectCommandContext[Annotation], AnnotationSubjectCommandContext): pass class 
AnnotationCommand(sd.QualifiedObjectCommand[Annotation], AnnotationSubjectCommand[Annotation], context_class=AnnotationCommandContext): def get_ast_attr_for_field( self, field: str, astnode: type[qlast.DDLOperation], ) -> Optional[str]: if field in {'abstract', 'inheritable'}: return field else: return super().get_ast_attr_for_field(field, astnode) class CreateAnnotation(AnnotationCommand, sd.CreateObject[Annotation]): astnode = qlast.CreateAnnotation @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> CreateAnnotation: cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(astnode, qlast.CreateAnnotation) cmd.set_attribute_value('inheritable', astnode.inheritable) cmd.set_attribute_value('abstract', True) assert isinstance(cmd, CreateAnnotation) return cmd class RenameAnnotation(AnnotationCommand, sd.RenameObject[Annotation]): def _canonicalize( self, schema: s_schema.Schema, context: sd.CommandContext, scls: Annotation, ) -> None: super()._canonicalize(schema, context, scls) # AnnotationValues have names derived from the abstract # annotations. We unfortunately need to go update their names. annot_vals = cast( AbstractSet[AnnotationValue], schema.get_referrers( scls, scls_type=AnnotationValue, field_name='annotation')) for ref in annot_vals: if ref.get_implicit_bases(schema): # This annotation value is inherited, and presumably # the rename in parent will propagate. 
continue ref_name = ref.get_name(schema) quals = list(sn.quals_from_fullname(ref_name)) new_ref_name = sn.QualName( name=sn.get_specialized_name(self.new_name, *quals), module=ref_name.module, ) self.add(self.init_rename_branch( ref, new_ref_name, schema=schema, context=context, )) class AlterAnnotation(AnnotationCommand, sd.AlterObject[Annotation]): astnode = qlast.AlterAnnotation class DeleteAnnotation(AnnotationCommand, sd.DeleteObject[Annotation]): astnode = qlast.DropAnnotation class AnnotationValueCommandContext(sd.ObjectCommandContext[AnnotationValue]): pass class AnnotationValueCommand( referencing.ReferencedInheritingObjectCommand[AnnotationValue], context_class=AnnotationValueCommandContext, referrer_context_class=AnnotationSubjectCommandContext, ): def _deparse_name( self, schema: s_schema.Schema, context: sd.CommandContext, name: sn.Name, ) -> qlast.ObjectRef: ref = super()._deparse_name(schema, context, name) # Clear `itemclass` ref.itemclass = None return ref @classmethod def _classname_from_ast( cls, schema: s_schema.Schema, astnode: qlast.ObjectDDL, context: sd.CommandContext, ) -> sn.QualName: parent_ctx = cls.get_referrer_context_or_die(context) assert isinstance(parent_ctx.op, sd.QualifiedObjectCommand) referrer_name = context.get_referrer_name(parent_ctx) base_ref = utils.ast_to_object_shell( astnode.name, modaliases=context.modaliases, schema=schema, metaclass=Annotation, ) base_name = base_ref.name quals = cls._classname_quals_from_ast( schema, astnode, base_name, referrer_name, context) pnn = sn.get_specialized_name(base_name, str(referrer_name), *quals) return sn.QualName(name=pnn, module=referrer_name.module) def populate_ddl_identity( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().populate_ddl_identity(schema, context) if not isinstance(self, sd.CreateObject): anno = self.scls.get_annotation(schema) else: annoname = sn.shortname_from_fullname(self.classname) anno = schema.get(annoname, 
type=Annotation) self.set_ddl_identity('annotation', anno) return schema class CreateAnnotationValue( AnnotationValueCommand, referencing.CreateReferencedInheritingObject[AnnotationValue], ): astnode = qlast.CreateAnnotationValue referenced_astnode = qlast.CreateAnnotationValue @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext ) -> CreateAnnotationValue: assert isinstance(astnode, qlast.CreateAnnotationValue) cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(cmd, CreateAnnotationValue) annoname = sn.shortname_from_fullname(cmd.classname) value, ir = qlcompiler.evaluate_ast_to_python_val_and_ir( astnode.value, schema=schema) if ir.stype.get_name(schema) != sn.QualName('std', 'str'): vn = ir.stype.get_verbosename(schema) raise errors.InvalidValueError( f"annotation values must be 'std::str', got {vn}", span=astnode.value.span, ) anno = utils.ast_objref_to_object_shell( utils.name_to_ast_ref(annoname), metaclass=Annotation, modaliases=context.modaliases, schema=schema, ) cmd.set_attribute_value('annotation', anno) cmd.set_attribute_value('value', value) return cmd def canonicalize_attributes( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().canonicalize_attributes(schema, context) anno = self.get_ddl_identity('annotation') assert anno is not None self.set_attribute_value('internal', True) return schema def _apply_field_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, op: sd.AlterObjectProperty, ) -> None: if op.property == 'value': assert isinstance(op.new_value, str) assert isinstance(node, ( qlast.CreateAnnotationValue, qlast.AlterAnnotationValue)) node.value = qlast.Constant.string(op.new_value) else: super()._apply_field_ast(schema, context, node, op) class AlterAnnotationValueOwned( referencing.AlterOwned[AnnotationValue], AnnotationValueCommand, field='owned', 
referrer_context_class=AnnotationSubjectCommandContext, ): pass class AlterAnnotationValue( AnnotationValueCommand, referencing.AlterReferencedInheritingObject[AnnotationValue], ): astnode = qlast.AlterAnnotationValue referenced_astnode = qlast.AlterAnnotationValue @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext ) -> AlterAnnotationValue: assert isinstance( astnode, (qlast.CreateAnnotationValue, qlast.AlterAnnotationValue), ) cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(cmd, AlterAnnotationValue) if astnode.value is not None: value, ir = qlcompiler.evaluate_ast_to_python_val_and_ir( astnode.value, schema=schema) if ir.stype.get_name(schema) != sn.QualName('std', 'str'): vn = ir.stype.get_verbosename(schema) raise errors.InvalidValueError( f"annotation values must be 'std::str', got {vn}", span=astnode.value.span, ) cmd.set_attribute_value( 'value', value, ) annoname = sn.shortname_from_fullname(cmd.classname) anno = utils.ast_objref_to_object_shell( utils.name_to_ast_ref(annoname), metaclass=Annotation, modaliases=context.modaliases, schema=schema, ) cmd.set_attribute_value('annotation', value=anno, orig_value=anno) return cmd def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: if ( not self.has_attribute_value('value') and not self.has_attribute_value('owned') ): return None # Skip AlterObject's _get_ast, because we *don't* want to # filter out things without subcommands! 
return sd.ObjectCommand._get_ast( self, schema, context, parent_node=parent_node) def _apply_field_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, op: sd.AlterObjectProperty ) -> None: assert isinstance(node, qlast.AlterAnnotationValue) if op.property == 'value': assert isinstance(op.new_value, str) node.value = qlast.Constant.string(op.new_value) else: super()._apply_field_ast(schema, context, node, op) class RebaseAnnotationValue( AnnotationValueCommand, referencing.RebaseReferencedInheritingObject[AnnotationValue], ): pass class RenameAnnotationValue( AnnotationValueCommand, referencing.RenameReferencedInheritingObject[AnnotationValue], ): pass class DeleteAnnotationValue( AnnotationValueCommand, referencing.DeleteReferencedInheritingObject[AnnotationValue], ): astnode = qlast.DropAnnotationValue @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext ) -> DeleteAnnotationValue: assert isinstance(astnode, qlast.DropAnnotationValue) cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(cmd, DeleteAnnotationValue) annoname = sn.shortname_from_fullname(cmd.classname) anno = utils.ast_objref_to_object_shell( utils.name_to_ast_ref(annoname), metaclass=Annotation, modaliases=context.modaliases, schema=schema, ) cmd.set_attribute_value('annotation', value=None, orig_value=anno) return cmd def canonicalize_attributes( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().canonicalize_attributes(schema, context) anno = self.get_ddl_identity('annotation') assert anno is not None self.set_attribute_value( 'annotation', value=None, orig_value=anno, ) return schema ================================================ FILE: edb/schema/casts.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. 
and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any, Optional, Mapping, cast from edb import errors from edb.common import lru from edb.edgeql import ast as qlast from edb.edgeql import qltypes from . import annos as s_anno from . import delta as sd from . import functions as s_func from . import name as sn from . import objects as so from . import types as s_types from . import schema as s_schema from . import utils _NOT_REACHABLE = 10000000 def _is_reachable( schema: s_schema.Schema, cast_kwargs: Mapping[str, bool], source: s_types.Type, target: s_types.Type, distance: int, ) -> int: if source == target: return distance casts = schema.get_casts_to_type(target, **cast_kwargs) if not casts: return _NOT_REACHABLE sources = {c.get_from_type(schema) for c in casts} distance += 1 if source in sources: return distance else: return min( _is_reachable(schema, cast_kwargs, source, s, distance) for s in sources ) @lru.per_job_lru_cache() def get_implicit_cast_distance( schema: s_schema.Schema, source: s_types.Type, target: s_types.Type, ) -> int: dist = _is_reachable(schema, {'implicit': True}, source, target, 0) if dist == _NOT_REACHABLE: return -1 else: return dist def is_implicitly_castable( schema: s_schema.Schema, source: s_types.Type, target: s_types.Type, ) -> bool: return get_implicit_cast_distance(schema, source, target) >= 0 @lru.per_job_lru_cache() def find_common_castable_type( schema: 
s_schema.Schema, source: s_types.Type, target: s_types.Type, ) -> Optional[s_types.Type]: if get_implicit_cast_distance(schema, target, source) >= 0: return source if get_implicit_cast_distance(schema, source, target) >= 0: return target # Elevate target in the castability ladder, and check if # source is castable to it on each step. while True: casts = schema.get_casts_from_type(target, implicit=True) if not casts: return None targets = {c.get_to_type(schema) for c in casts} if len(targets) > 1: for t in targets: candidate = find_common_castable_type(schema, source, t) if candidate is not None: return candidate else: return None else: target = next(iter(targets)) if get_implicit_cast_distance(schema, source, target) >= 0: return target @lru.per_job_lru_cache() def is_assignment_castable( schema: s_schema.Schema, source: s_types.Type, target: s_types.Type, ) -> bool: # Implicitly castable implies assignment castable. if is_implicitly_castable(schema, source, target): return True # Assignment casts are valid only as one-hop casts. 
casts = schema.get_casts_to_type(target, assignment=True) if not casts: return False for c in casts: if c.get_from_type(schema) == source: return True return False @lru.per_job_lru_cache() def is_castable( schema: s_schema.Schema, source: s_types.Type, target: s_types.Type, ) -> bool: # Implicitly castable if is_implicitly_castable(schema, source, target): return True elif is_assignment_castable(schema, source, target): return True else: casts = schema.get_casts_to_type(target) if not casts: return False else: for c in casts: if c.get_from_type(schema) == source: return True else: return False def get_cast_fullname_from_names( from_type: sn.Name, to_type: sn.Name, ) -> sn.QualName: std = not ( ( isinstance(from_type, sn.QualName) and sn.UnqualName(from_type.module) not in s_schema.STD_MODULES ) or ( isinstance(to_type, sn.QualName) and sn.UnqualName(to_type.module) not in s_schema.STD_MODULES ) ) module = 'std' if std else '__ext_casts__' quals = [str(from_type), str(to_type)] shortname = sn.QualName(module, 'cast') return sn.QualName( module=shortname.module, name=sn.get_specialized_name(shortname, *quals), ) def get_cast_fullname( schema: s_schema.Schema, from_type: s_types.TypeShell[s_types.Type], to_type: s_types.TypeShell[s_types.Type], ) -> sn.QualName: return get_cast_fullname_from_names( from_type.get_name(schema), to_type.get_name(schema), ) class Cast( so.QualifiedObject, s_anno.AnnotationSubject, s_func.VolatilitySubject, qlkind=qltypes.SchemaObjectClass.CAST, data_safe=True, ): from_type = so.SchemaField( s_types.Type, compcoef=0.5) to_type = so.SchemaField( s_types.Type, compcoef=0.5) allow_implicit = so.SchemaField( bool, default=False, compcoef=0.4) allow_assignment = so.SchemaField( bool, default=False, compcoef=0.4) language = so.SchemaField( qlast.Language, default=None, compcoef=0.4, coerce=True) from_function = so.SchemaField( str, default=None, compcoef=0.4) from_expr = so.SchemaField( bool, default=False, compcoef=0.4) from_cast = 
so.SchemaField( bool, default=False, compcoef=0.4) code = so.SchemaField( str, default=None, compcoef=0.4) class CastCommandContext(sd.ObjectCommandContext[Cast], s_anno.AnnotationSubjectCommandContext): pass class CastCommand(sd.QualifiedObjectCommand[Cast], context_class=CastCommandContext): def get_ast_attr_for_field( self, field: str, astnode: type[qlast.DDLOperation], ) -> Optional[str]: if field in {'allow_assignment', 'allow_implicit'}: return field else: return super().get_ast_attr_for_field(field, astnode) @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: if not context.stdmode and not context.testmode: raise errors.UnsupportedFeatureError( 'user-defined casts are not supported', span=astnode.span ) return super()._cmd_tree_from_ast(schema, astnode, context) @classmethod def _classname_from_ast( cls, schema: s_schema.Schema, astnode: qlast.ObjectDDL, context: sd.CommandContext, ) -> sn.QualName: assert isinstance(astnode, qlast.CastCommand) modaliases = context.modaliases from_type = utils.ast_to_type_shell( astnode.from_type, metaclass=s_types.Type, modaliases=modaliases, schema=schema, ) to_type = utils.ast_to_type_shell( astnode.to_type, metaclass=s_types.Type, modaliases=modaliases, schema=schema, ) return get_cast_fullname(schema, from_type, to_type) def canonicalize_attributes( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().canonicalize_attributes(schema, context) schema = s_types.materialize_type_in_attribute( schema, context, self, 'from_type') schema = s_types.materialize_type_in_attribute( schema, context, self, 'to_type') return schema class CreateCast(CastCommand, sd.CreateObject[Cast]): astnode = qlast.CreateCast def _create_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: fullname = self.classname cast = schema.get(fullname, None) if cast: from_type = 
self.get_attribute_value('from_type') to_type = self.get_attribute_value('to_type') raise errors.DuplicateCastDefinitionError( f'a cast from {from_type.get_displayname(schema)!r} ' f'to {to_type.get_displayname(schema)!r} is already defined', span=self.span) return super()._create_begin(schema, context) @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: assert isinstance(astnode, qlast.CreateCast) cmd = super()._cmd_tree_from_ast(schema, astnode, context) modaliases = context.modaliases from_type = utils.ast_to_type_shell( astnode.from_type, metaclass=s_types.Type, modaliases=modaliases, schema=schema, ) cmd.set_attribute_value('from_type', from_type) to_type = utils.ast_to_type_shell( astnode.to_type, metaclass=s_types.Type, modaliases=modaliases, schema=schema, ) cmd.set_attribute_value('to_type', to_type) cmd.set_attribute_value('allow_implicit', astnode.allow_implicit) cmd.set_attribute_value('allow_assignment', astnode.allow_assignment) if astnode.code is not None: cmd.set_attribute_value( 'language', astnode.code.language, ) if astnode.code.from_function is not None: cmd.set_attribute_value( 'from_function', astnode.code.from_function, ) if astnode.code.code is not None: cmd.set_attribute_value( 'code', astnode.code.code, ) if astnode.code.from_expr is not None: cmd.set_attribute_value( 'from_expr', astnode.code.from_expr, ) if astnode.code.from_cast is not None: cmd.set_attribute_value( 'from_cast', astnode.code.from_cast, ) return cmd def _apply_field_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, op: sd.AlterObjectProperty, ) -> None: assert isinstance(node, qlast.CreateCast) new_value: Any = op.new_value if op.property == 'from_type': # In a cast we can only have pure types, so this is going # to be a TypeName. 
node.from_type = cast(qlast.TypeName, utils.typeref_to_ast(schema, new_value)) elif op.property == 'to_type': # In a cast we can only have pure types, so this is going # to be a TypeName. node.to_type = cast(qlast.TypeName, utils.typeref_to_ast(schema, new_value)) elif op.property == 'code': if node.code is None: node.code = qlast.CastCode() node.code.code = new_value elif op.property == 'language': if node.code is None: node.code = qlast.CastCode() node.code.language = new_value elif op.property == 'from_function' and new_value: if node.code is None: node.code = qlast.CastCode() node.code.from_function = new_value elif op.property == 'from_expr' and new_value: if node.code is None: node.code = qlast.CastCode() node.code.from_expr = new_value elif op.property == 'from_cast' and new_value: if node.code is None: node.code = qlast.CastCode() node.code.from_cast = new_value else: super()._apply_field_ast(schema, context, node, op) class RenameCast(CastCommand, sd.RenameObject[Cast]): pass class AlterCast(CastCommand, sd.AlterObject[Cast]): astnode = qlast.AlterCast class DeleteCast(CastCommand, sd.DeleteObject[Cast]): astnode = qlast.DropCast def _delete_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._delete_begin(schema, context) if not context.canonical: from_type = self.scls.get_from_type(schema) if op := from_type.as_type_delete_if_unused(schema): self.add_caused(op) to_type = self.scls.get_to_type(schema) if op := to_type.as_type_delete_if_unused(schema): self.add_caused(op) return schema ================================================ FILE: edb/schema/constraints.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Optional, Mapping, cast, Iterable, TYPE_CHECKING, ) import re from edb import errors from edb.common import verutils from edb import edgeql from edb.edgeql import ast as qlast from edb.edgeql import compiler as qlcompiler from edb.edgeql import qltypes as ft from edb.edgeql import parser as qlparser from edb.edgeql import utils as qlutils from edb.edgeql import qltypes from . import annos as s_anno from . import delta as sd from . import expr as s_expr from . import functions as s_func from . import inheriting from . import name as sn from . import objects as so from . import types as s_types from . import pseudo as s_pseudo from . import referencing from . import utils if TYPE_CHECKING: from edb.common import parsing as c_parsing from edb.schema import schema as s_schema def _assert_not_none[T](value: Optional[T]) -> T: if value is None: raise TypeError("A value is expected") return value def merge_constraint_params( constraint: Constraint, supers: list[Constraint], field_name: str, *, ignore_local: bool, schema: s_schema.Schema, ) -> Any: if constraint.get_subject(schema) is None: # consistency of abstract constraint params is checked # in CreateConstraint.validate_create return constraint.get_explicit_field_value(schema, field_name, None) else: # concrete constraints cannot redefine parameters and always # inherit from super. 
return supers[0].get_explicit_field_value(schema, field_name, None) def constraintname_from_fullname(name: sn.Name) -> sn.QualName: assert isinstance(name, sn.QualName) # the dict key for constraints drops the first qual, which makes # it independent of where it is declared short = sn.shortname_from_fullname(name) quals = sn.quals_from_fullname(name) return sn.QualName( name=sn.get_specialized_name(short, *quals[1:]), module='__', ) def _constraint_object_key(schema: s_schema.Schema, o: so.Object) -> sn.Name: return constraintname_from_fullname(o.get_name(schema)) class ObjectIndexByConstraintName( so.ObjectIndexBase[sn.Name, so.Object_T], key=_constraint_object_key, ): @classmethod def get_key_for_name( cls, schema: s_schema.Schema, name: sn.Name, ) -> sn.Name: return constraintname_from_fullname(name) class Constraint( referencing.ReferencedInheritingObject, s_func.CallableObject, qlkind=ft.SchemaObjectClass.CONSTRAINT, data_safe=True, ): params = so.SchemaField( s_func.FuncParameterList, coerce=True, compcoef=0.4, default=so.DEFAULT_CONSTRUCTOR, inheritable=True, merge_fn=merge_constraint_params, ) expr = so.SchemaField( s_expr.Expression, default=None, compcoef=0.909, coerce=True) subjectexpr = so.SchemaField( s_expr.Expression, default=None, compcoef=0.833, coerce=True, ddl_identity=True) finalexpr = so.SchemaField( s_expr.Expression, default=None, compcoef=0.909, coerce=True) except_expr = so.SchemaField( s_expr.Expression, default=None, coerce=True, compcoef=0.909, ddl_identity=True, ) subject = so.SchemaField( so.Object, default=None, inheritable=False) args = so.SchemaField( s_expr.ExpressionList, default=None, coerce=True, inheritable=False, compcoef=0.875, ddl_identity=True) delegated = so.SchemaField( bool, default=False, inheritable=False, special_ddl_syntax=True, compcoef=0.9, ) errmessage = so.SchemaField( str, default=None, compcoef=0.971, allow_ddl_set=True, allow_interpolation=True, ) is_aggregate = so.SchemaField( bool, default=False, 
compcoef=0.971, allow_ddl_set=False) def get_name_impacting_ancestors( self, schema: s_schema.Schema, ) -> list[Constraint]: if self.is_non_concrete(schema): return [] else: return [self.get_nearest_generic_parent(schema)] def get_constraint_origins( self, schema: s_schema.Schema ) -> list[Constraint]: """ Origins of a constraint are the constraints that should actually perform validation on their subjects. Example: If we have `Baz <: Bar <: Foo` and `Foo` declares some exclusive constraint, this constraint will be inherited by `Bar` and then `Baz`. But this inherited constraint on `Baz` should not validate exclusivity of the property within just `Baz`, but within all `Foo` objects. That's why origin of the constraint on `Baz` and on `Bar` is the constraint on `Foo`. Determining which type is that is non-trivial because of: - multiple inheritance (constraint might originate from multiple unrelated ancestors) - delegated exclusive constraints, which are defined on a parent, but should be exclusive within each of the children. We validate constraints using triggers, and this function helps drive their generation. 
""" # collect origins from all ancestors origins: set[Constraint] = set() for base in self.get_bases(schema).objects(schema): # abstract bases are not an origin if base.is_non_concrete(schema): continue # delegated bases are not an origin if base.get_delegated(schema): continue # recurse origins.update(base.get_constraint_origins(schema)) # if no ancestors have an origin, I am the origin return [self] if not origins else list(origins) def is_independent(self, schema: s_schema.Schema) -> bool: return ( not self.descendants(schema) and self.get_constraint_origins(schema) == [self] ) def get_verbosename( self, schema: s_schema.Schema, *, with_parent: bool = False ) -> str: vn = super().get_verbosename(schema, with_parent=with_parent) if self.is_non_concrete(schema): return f'abstract {vn}' else: # concrete constraint must have a subject assert self.get_subject(schema) is not None return vn def is_non_concrete(self, schema: s_schema.Schema) -> bool: return self.get_subject(schema) is None def get_subject(self, schema: s_schema.Schema) -> ConsistencySubject: return cast( ConsistencySubject, self.get_field_value(schema, 'subject'), ) def format_error_message( self, schema: s_schema.Schema, ) -> str: subject = self.get_subject(schema) title_ann = subject.get_annotation(schema, sn.QualName('std', 'title')) if title_ann: subject_name = title_ann else: short_name = subject.get_shortname(schema) subject_name = short_name.name return self.format_error_text(schema, subject_name) def format_error_text( self, schema: s_schema.Schema, subject_name: str, ) -> str: text = self.get_errmessage(schema) assert text args: Optional[s_expr.ExpressionList] = self.get_args(schema) if args: args_ql: list[qlast.Base] = [ qlast.Path(steps=[qlast.ObjectRef(name=subject_name)]), ] args_ql.extend(arg.parse() for arg in args) constr_base: Constraint = schema.get( self.get_name(schema), type=type(self)) index_parameters = qlutils.index_parameters( args_ql, parameters=constr_base.get_params(schema), 
schema=schema, ) expr = constr_base.get_field_value(schema, 'expr') expr_ql = qlparser.parse_query(expr.text) qlutils.inline_parameters(expr_ql, index_parameters) args_map = {name: edgeql.generate_source(val, pretty=False) for name, val in index_parameters.items()} else: args_map = {'__subject__': subject_name} return interpolate_error_text(text, args_map) def as_alter_delta( self, other: Constraint, *, self_schema: s_schema.Schema, other_schema: s_schema.Schema, confidence: float, context: so.ComparisonContext, ) -> sd.ObjectCommand[Constraint]: return super().as_alter_delta( other, self_schema=self_schema, other_schema=other_schema, confidence=confidence, context=context, ) def as_delete_delta( self, *, schema: s_schema.Schema, context: so.ComparisonContext, ) -> sd.ObjectCommand[Constraint]: return super().as_delete_delta(schema=schema, context=context) def get_ddl_identity( self, schema: s_schema.Schema, ) -> Optional[dict[str, str]]: ddl_identity = super().get_ddl_identity(schema) if ( ddl_identity is not None and self.field_is_inherited(schema, 'subjectexpr') and (bases := self.get_bases(schema).objects(schema)) and ( bases[0].is_non_concrete(schema) or 'subjectexpr' not in ( bases[0].get_ddl_identity(schema) or ()) ) ): ddl_identity.pop('subjectexpr', None) return ddl_identity @classmethod def get_root_classes(cls) -> tuple[sn.QualName, ...]: return ( sn.QualName(module='std', name='constraint'), ) @classmethod def get_default_base_name(self) -> sn.QualName: return sn.QualName('std', 'constraint') class ConsistencySubject( so.QualifiedObject, so.InheritingObject, s_anno.AnnotationSubject, ): constraints_refs = so.RefDict( attr='constraints', ref_cls=Constraint) constraints = so.SchemaField( ObjectIndexByConstraintName[Constraint], inheritable=False, ephemeral=True, coerce=True, compcoef=0.887, default=so.DEFAULT_CONSTRUCTOR ) def add_constraint( self, schema: s_schema.Schema, constraint: Constraint, replace: bool = False, ) -> s_schema.Schema: return 
self.add_classref( schema, 'constraints', constraint, replace=replace, ) def can_accept_constraints(self, schema: s_schema.Schema) -> bool: return True class ConsistencySubjectCommandContext: # context mixin pass class ConsistencySubjectCommand( inheriting.InheritingObjectCommand[so.InheritingObjectT], ): pass class ConstraintCommandContext(sd.ObjectCommandContext[Constraint], s_anno.AnnotationSubjectCommandContext): pass class ConstraintCommand( referencing.ReferencedInheritingObjectCommand[Constraint], s_func.CallableCommand[Constraint], context_class=ConstraintCommandContext, referrer_context_class=ConsistencySubjectCommandContext, ): @classmethod def _validate_subcommands( cls, astnode: qlast.DDLOperation, ) -> None: # check that 'subject' and 'subjectexpr' are not set as annotations for command in astnode.commands: if isinstance(command, qlast.SetField): cname = command.name if cname in {'subject', 'subjectexpr'}: raise errors.InvalidConstraintDefinitionError( f'{cname} is not a valid constraint annotation', span=command.span) @classmethod def _classname_quals_from_ast( cls, schema: s_schema.Schema, astnode: qlast.ObjectDDL, base_name: sn.Name, referrer_name: sn.QualName, context: sd.CommandContext, ) -> tuple[str, ...]: if isinstance(astnode, qlast.CreateConstraint): return () exprs = [] args = cls._constraint_args_from_ast(schema, astnode, context) for arg in args: exprs.append(arg.text) assert isinstance(astnode, qlast.ConcreteConstraintOp) if astnode.subjectexpr: # use the normalized text directly from the expression expr = s_expr.Expression.from_ast( astnode.subjectexpr, schema, context.modaliases) exprs.append(expr.text) if astnode.except_expr: # use the normalized text directly from the expression expr = s_expr.Expression.from_ast( astnode.except_expr, schema, context.modaliases) # but mangle it a bit, so that we can distinguish between # on and except when only one is present exprs.append('!' 
+ expr.text) return (cls._name_qual_from_exprs(schema, exprs),) @classmethod def _classname_quals_from_name(cls, name: sn.QualName) -> tuple[str, ...]: quals = sn.quals_from_fullname(name) return (quals[-1],) @classmethod def _constraint_args_from_ast( cls, schema: s_schema.Schema, astnode: qlast.ObjectDDL, context: sd.CommandContext, ) -> list[s_expr.Expression]: args = [] assert isinstance(astnode, qlast.ConcreteConstraintOp) if astnode.args: for arg in astnode.args: arg_expr = s_expr.Expression.from_ast( arg, schema, context.modaliases) args.append(arg_expr) return args @classmethod def as_inherited_ref_ast( cls, schema: s_schema.Schema, context: sd.CommandContext, name: sn.Name, parent: so.Object, ) -> qlast.ObjectDDL: assert isinstance(parent, Constraint) astnode_cls = cls.referenced_astnode # type: ignore nref = cls.get_inherited_ref_name(schema, context, parent, name) args = [] parent_args = parent.get_args(schema) if parent_args: for arg_expr in parent_args: arg = edgeql.parse_fragment(arg_expr.text) args.append(arg) subj_expr = parent.get_subjectexpr(schema) if ( subj_expr is None # Don't include subjectexpr if it was inherited from an # abstract constraint. or parent.get_nearest_generic_parent( schema).get_subjectexpr(schema) is not None ): subj_expr_ql = None else: subj_expr_ql = edgeql.parse_fragment(subj_expr.text) except_expr: s_expr.Expression | None = parent.get_except_expr(schema) if except_expr: except_expr_ql = except_expr.parse() else: except_expr_ql = None astnode = astnode_cls( name=nref, args=args, subjectexpr=subj_expr_ql, except_expr=except_expr_ql) return cast(qlast.ObjectDDL, astnode) def compile_expr_field( self, schema: s_schema.Schema, context: sd.CommandContext, field: so.Field[Any], value: s_expr.Expression, track_schema_ref_exprs: bool=False, ) -> s_expr.CompiledExpression: from . 
import pointers as s_pointers base: Optional[so.Object] = None if isinstance(self, AlterConstraint): base = self.scls.get_subject(schema) else: referrer_ctx = self.get_referrer_context(context) if referrer_ctx: base = referrer_ctx.op.scls if base is not None: assert isinstance(base, (s_types.Type, s_pointers.Pointer)) # Concrete constraint if field.name == 'expr': # Concrete constraints cannot redefine the base check # expressions, and so the only way we should get here # is through field inheritance, so check that the # value is compiled and move on. if not value.is_compiled(): mcls = self.get_schema_metaclass() dn = mcls.get_schema_class_displayname() raise errors.InternalServerError( f'uncompiled expression in the {field.name!r} field of' f' {dn} {self.classname!r}' ) # HACK: Not *really* compiled, but... return value # type: ignore elif field.name in {'subjectexpr', 'finalexpr', 'except_expr'}: compiled = value.compiled( schema=schema, options=qlcompiler.CompilerOptions( modaliases=context.modaliases, anchors={'__subject__': base}, path_prefix_anchor='__subject__', singletons=frozenset([base]), allow_generic_type_output=True, schema_object_context=self.get_schema_metaclass(), apply_query_rewrites=False, track_schema_ref_exprs=track_schema_ref_exprs, ), context=context, ) # compile the expression to sql to preempt errors downstream utils.try_compile_irast_to_sql_tree(compiled, self.span) return compiled else: return super().compile_expr_field( schema, context, field, value) elif field.name in ('expr', 'subjectexpr'): # Abstract constraint. 
params = self._get_params(schema, context) param_anchors = s_func.get_params_symtable( params, schema, inlined_defaults=False, ) compiled = value.compiled( schema=schema, options=qlcompiler.CompilerOptions( modaliases=context.modaliases, anchors=param_anchors, func_params=params, allow_generic_type_output=True, schema_object_context=self.get_schema_metaclass(), apply_query_rewrites=not context.stdmode, track_schema_ref_exprs=track_schema_ref_exprs, ), context=context, ) # compile the expression to sql to preempt errors downstream utils.try_compile_irast_to_sql_tree(compiled, self.span) return compiled else: return super().compile_expr_field( schema, context, field, value, track_schema_ref_exprs) def get_dummy_expr_field_value( self, schema: s_schema.Schema, context: sd.CommandContext, field: so.Field[Any], value: Any, ) -> Optional[s_expr.Expression]: if field.name in {'expr', 'subjectexpr', 'finalexpr', 'except_expr'}: return s_expr.Expression(text='false') else: raise NotImplementedError(f'unhandled field {field.name!r}') @classmethod def get_inherited_ref_name( cls, schema: s_schema.Schema, context: sd.CommandContext, parent: so.Object, name: sn.Name, ) -> qlast.ObjectRef: bn = sn.shortname_from_fullname(name) return utils.name_to_ast_ref(bn) def get_ref_implicit_base_delta( self, schema: s_schema.Schema, context: sd.CommandContext, refcls: Constraint, implicit_bases: list[Constraint], ) -> inheriting.BaseDelta_T[Constraint]: child_bases = refcls.get_bases(schema).objects(schema) return inheriting.delta_bases( [b.get_name(schema) for b in child_bases], [b.get_name(schema) for b in implicit_bases], t=Constraint, ) def get_ast_attr_for_field( self, field: str, astnode: type[qlast.DDLOperation], ) -> Optional[str]: if field in ('subjectexpr', 'args', 'except_expr'): return field elif ( field == 'delegated' and astnode is qlast.CreateConcreteConstraint ): return field else: return super().get_ast_attr_for_field(field, astnode) def get_ddl_identity_fields( self, 
context: sd.CommandContext, ) -> tuple[so.Field[Any], ...]: id_fields = super().get_ddl_identity_fields(context) omit_fields = set() if not self.has_ddl_identity('subjectexpr'): omit_fields.add('subjectexpr') if self.get_referrer_context(context) is None: omit_fields.add('args') if omit_fields: return tuple(f for f in id_fields if f.name not in omit_fields) else: return id_fields @classmethod def localnames_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> set[str]: localnames = super().localnames_from_ast( schema, astnode, context ) # Set up the constraint parameters as part of names to be # ignored in expression normalization. if isinstance(astnode, qlast.CreateConstraint): localnames |= {param.name for param in astnode.params} elif isinstance(astnode, qlast.AlterConstraint): # ALTER ABSTRACT CONSTRAINT doesn't repeat the params, # but we can get them from the schema. objref = astnode.name # Merge the context modaliases and the command modaliases. modaliases = dict(context.modaliases) modaliases.update( cls._modaliases_from_ast(schema, astnode, context)) # Get the original constraint. constr = schema.get( utils.ast_ref_to_name(objref), module_aliases=modaliases, type=Constraint, ) localnames |= {param.get_parameter_name(schema) for param in constr.get_params(schema).objects(schema)} return localnames def inherit_fields( self, schema: s_schema.Schema, context: sd.CommandContext, bases: tuple[so.Object, ...], *, fields: Optional[Iterable[str]] = None, ignore_local: bool = False, apply: bool = True, ) -> s_schema.Schema: # Concrete constraints populate a bunch of other fields that # can be based on their abstract parents but don't come from # them. So if we are inheriting a new expr from a potentially # abstract parent, we need to actually inherit all of these # other properties that can be populated by # _populate_concrete_constraint_attrs. # # This is pretty fragile though, and I don't love it. 
if fields is not None: fields = set(fields) | { 'subjectexpr', 'finalexpr', 'abstract', 'args' } return super().inherit_fields( schema, context, bases, fields=fields, ignore_local=ignore_local, apply=apply, ) def _populate_concrete_constraint_attrs( self, schema: s_schema.Schema, context: sd.CommandContext, subject_obj: so.Object, *, name: sn.QualName, subjectexpr: Optional[s_expr.Expression] = None, subjectexpr_inherited: bool = False, span: Optional[c_parsing.Span] = None, args: Optional[Iterable[s_expr.Expression]] = None, **kwargs: Any ) -> None: from edb.ir import ast as irast from edb.ir import utils as ir_utils from . import pointers as s_pointers from . import links as s_links from . import objtypes as s_objtypes from . import scalars as s_scalars bases = self.get_resolved_attribute_value( 'bases', schema=schema, context=context, ) if not bases: bases = self.scls.get_bases(schema) constr_base = bases.objects(schema)[0] # If we have a concrete base, then we should inherit all of # these attrs through the normal inherit_fields() mechanisms, # and populating them ourselves will just mess up # inherited_fields. 
if not constr_base.is_non_concrete(schema): return attrs = dict(kwargs) inherited = dict() base_subjectexpr = constr_base.get_field_value(schema, 'subjectexpr') if subjectexpr is None: attrs['subjectexpr'] = subjectexpr inherited['subjectexpr'] = subjectexpr_inherited subjectexpr = base_subjectexpr else: if (base_subjectexpr is not None and subjectexpr.text != base_subjectexpr.text): raise errors.InvalidConstraintDefinitionError( f'subjectexpr is already defined for {name}', span=span, ) base_subjectexpr = constr_base.get_subjectexpr(schema) if base_subjectexpr is not None: attrs['subjectexpr'] = base_subjectexpr inherited['subjectexpr'] = True if (isinstance(subject_obj, s_scalars.ScalarType) and constr_base.get_is_aggregate(schema)): raise errors.InvalidConstraintDefinitionError( f'{constr_base.get_verbosename(schema)} may not ' f'be used on scalar types', span=span, ) if ( subjectexpr is None and isinstance(subject_obj, s_objtypes.ObjectType) ): raise errors.InvalidConstraintDefinitionError( "constraints on object types must have an 'on' clause", span=span, ) expr: s_expr.Expression = constr_base.get_field_value(schema, 'expr') if not expr: raise errors.InvalidConstraintDefinitionError( f'missing constraint expression in {name}') # Re-parse instead of using expr.parse, because we mutate # the AST below. 
expr_ql = qlparser.parse_query(expr.text) if subjectexpr is not None: # subject has been redefined subject_ql = subjectexpr.parse() assert isinstance(subject_ql, qlast.Base) qlutils.inline_anchors( expr_ql, anchors={'__subject__': subject_ql}) if not args: args = constr_base.get_field_value(schema, 'args') if args: args_ql: list[qlast.Base] = [ qlast.Path(steps=[qlast.SpecialAnchor(name='__subject__')]), ] args_ql.extend(arg.parse() for arg in args) args_map = qlutils.index_parameters( args_ql, parameters=constr_base.get_params(schema), schema=schema, ) qlutils.inline_parameters(expr_ql, args_map) attrs['args'] = args assert isinstance(subject_obj, (s_types.Type, s_pointers.Pointer)) singletons = frozenset({subject_obj}) final_expr = s_expr.Expression.from_ast(expr_ql, schema, {}).compiled( schema=schema, options=qlcompiler.CompilerOptions( anchors={'__subject__': subject_obj}, path_prefix_anchor='__subject__', singletons=singletons, apply_query_rewrites=False, schema_object_context=self.get_schema_metaclass(), ), context=context, ) bool_t = schema.get('std::bool', type=s_scalars.ScalarType) expr_type = final_expr.irast.stype expr_schema = final_expr.irast.schema if not expr_type.issubclass(expr_schema, bool_t): raise errors.InvalidConstraintDefinitionError( f'{name} constraint expression expected ' f'to return a bool value, got ' f'{expr_type.get_verbosename(expr_schema)}', span=span ) except_expr: s_expr.Expression | None = attrs.get('except_expr') if subjectexpr is not None: options = qlcompiler.CompilerOptions( anchors={'__subject__': subject_obj}, path_prefix_anchor='__subject__', singletons=singletons, apply_query_rewrites=False, schema_object_context=self.get_schema_metaclass(), ) final_subjectexpr = subjectexpr.compiled( schema=schema, options=options, context=context ) refs = ir_utils.get_longest_paths(final_expr.irast) final_except_expr: s_expr.CompiledExpression | None = None if except_expr: final_except_expr = except_expr.compiled( schema=schema, 
options=options, context=context ) refs |= ir_utils.get_longest_paths(final_except_expr.irast) has_any_multi = has_non_subject_multi = False for ref in refs: while isinstance(ref.expr, irast.Pointer): rptr = ref.expr if rptr.dir_cardinality.is_multi(): has_any_multi = True # We don't need to look further than the subject, # which is always valid. (And which is a singleton # in a constraint expression if it is itself a # singleton, regardless of other parts of the path.) if ( isinstance(rptr.ptrref, irast.PointerRef) and rptr.ptrref.id == subject_obj.id ): break if rptr.dir_cardinality.is_multi(): has_non_subject_multi = True if (not isinstance(rptr.ptrref, irast.TupleIndirectionPointerRef) and rptr.ptrref.source_ptr is None and isinstance(rptr.source.expr, irast.Pointer)): if isinstance(subject_obj, s_links.Link): raise errors.InvalidConstraintDefinitionError( "link constraints may not access " "the link target", span=span ) else: raise errors.InvalidConstraintDefinitionError( "constraints cannot contain paths with more " "than one hop", span=span ) ref = rptr.source if has_non_subject_multi and len(refs) > 1: raise errors.InvalidConstraintDefinitionError( "cannot reference multiple links or properties in a " "constraint where at least one link or property is MULTI", span=span ) if set_of_op := ir_utils.find_set_of_op( final_subjectexpr.irast, has_any_multi, ): label = ( 'function' if isinstance(set_of_op, irast.FunctionCall) else 'operator' ) op_name = str(set_of_op.func_shortname) raise errors.UnsupportedFeatureError( f"cannot use SET OF {label} '{op_name}' " f"in a constraint", span=set_of_op.span ) if ( final_subjectexpr.irast.volatility != qltypes.Volatility.Immutable ): raise errors.InvalidConstraintDefinitionError( f'constraint expressions must be immutable', span=final_subjectexpr.irast.span, ) if final_except_expr: if ( final_except_expr.irast.volatility != qltypes.Volatility.Immutable ): raise errors.InvalidConstraintDefinitionError( f'constraint 
expressions must be immutable', span=final_except_expr.irast.span, ) if final_expr.irast.volatility != qltypes.Volatility.Immutable: raise errors.InvalidConstraintDefinitionError( f'constraint expressions must be immutable', span=span, ) attrs['finalexpr'] = final_expr attrs['params'] = constr_base.get_params(schema) inherited['params'] = True attrs['abstract'] = False for k, v in attrs.items(): self.set_attribute_value(k, v, inherited=bool(inherited.get(k))) class CreateConstraint( ConstraintCommand, s_func.CreateCallableObject[Constraint], referencing.CreateReferencedInheritingObject[Constraint], ): astnode = [qlast.CreateConcreteConstraint, qlast.CreateConstraint] referenced_astnode = qlast.CreateConcreteConstraint @classmethod def _get_param_desc_from_ast( cls, schema: s_schema.Schema, modaliases: Mapping[Optional[str], str], astnode: qlast.ObjectDDL, *, param_offset: int=0 ) -> list[s_func.ParameterDesc]: if not isinstance(astnode, qlast.CallableObjectCommand): # Concrete constraint. return [] params = super()._get_param_desc_from_ast( schema, modaliases, astnode, param_offset=param_offset + 1) params.insert(0, s_func.ParameterDesc( num=param_offset, name=sn.UnqualName('__subject__'), default=None, type=s_pseudo.PseudoTypeShell(name=sn.UnqualName('anytype')), typemod=ft.TypeModifier.SingletonType, kind=ft.ParameterKind.PositionalParam, )) return params def validate_create( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: super().validate_create(schema, context) if self.get_referrer_context(context) is not None: # The checks below apply only to abstract constraints. return base_params: Optional[s_func.FuncParameterList] = None base_with_params: Optional[Constraint] = None bases = self.get_resolved_attribute_value( 'bases', schema=schema, context=context, ) for base in bases.objects(schema): params = base.get_params(schema) if params and len(params) > 1: # All constraints have __subject__ parameter # auto-injected, hence the "> 1" check. 
if base_params is not None: raise errors.InvalidConstraintDefinitionError( f'{self.get_verbosename()} ' f'extends multiple constraints ' f'with parameters', span=self.span, ) base_params = params base_with_params = base if base_params: assert base_with_params is not None params = self._get_params(schema, context) if not params or len(params) == 1: # All constraints have __subject__ parameter # auto-injected, hence the "== 1" check. raise errors.InvalidConstraintDefinitionError( f'{self.get_verbosename()} ' f'must define parameters to reflect parameters of ' f'the {base_with_params.get_verbosename(schema)} ' f'it extends', span=self.span, ) if len(params) < len(base_params): raise errors.InvalidConstraintDefinitionError( f'{self.get_verbosename()} ' f'has fewer parameters than the ' f'{base_with_params.get_verbosename(schema)} ' f'it extends', span=self.span, ) # Skipping the __subject__ param for base_param, param in zip(base_params.objects(schema)[1:], params.objects(schema)[1:]): param_name = param.get_parameter_name(schema) base_param_name = base_param.get_parameter_name(schema) if param_name != base_param_name: raise errors.InvalidConstraintDefinitionError( f'the {param_name!r} parameter of the ' f'{self.get_verbosename()} ' f'must be renamed to {base_param_name!r} ' f'to match the signature of the base ' f'{base_with_params.get_verbosename(schema)} ', span=self.span, ) param_type = param.get_type(schema) base_param_type = base_param.get_type(schema) if ( not base_param_type.is_polymorphic(schema) and param_type.is_polymorphic(schema) ): raise errors.InvalidConstraintDefinitionError( f'the {param_name!r} parameter of the ' f'{self.get_verbosename()} cannot ' f'be of generic type because the corresponding ' f'parameter of the ' f'{base_with_params.get_verbosename(schema)} ' f'it extends has a concrete type', span=self.span, ) if ( not base_param_type.is_polymorphic(schema) and not param_type.is_polymorphic(schema) and not param_type.implicitly_castable_to( 
base_param_type, schema) ): raise errors.InvalidConstraintDefinitionError( f'the {param_name!r} parameter of the ' f'{self.get_verbosename()} has type of ' f'{param_type.get_displayname(schema)} that ' f'is not implicitly castable to the ' f'corresponding parameter of the ' f'{base_with_params.get_verbosename(schema)} with ' f'type {base_param_type.get_displayname(schema)}', span=self.span, ) def _create_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: referrer_ctx = self.get_referrer_context(context) if referrer_ctx is None: schema = super()._create_begin(schema, context) return schema subject = referrer_ctx.scls assert isinstance(subject, ConsistencySubject) if not subject.can_accept_constraints(schema): raise errors.UnsupportedFeatureError( f'constraints cannot be defined on ' f'{subject.get_verbosename(schema)}', span=self.span, ) if not context.canonical: props = self.get_attributes(schema, context) props.pop('name') props.pop('subject', None) fullname = self.classname shortname = sn.shortname_from_fullname(fullname) assert isinstance(shortname, sn.QualName), \ "expected qualified name" self._populate_concrete_constraint_attrs( schema, context, subject_obj=subject, name=shortname, subjectexpr_inherited=self.is_attribute_inherited( 'subjectexpr'), span=self.span, **props, ) self.set_attribute_value('subject', subject) return super()._create_begin(schema, context) @classmethod def as_inherited_ref_cmd( cls, *, schema: s_schema.Schema, context: sd.CommandContext, astnode: qlast.ObjectDDL, bases: list[Constraint], referrer: so.Object, ) -> sd.ObjectCommand[Constraint]: cmd = super().as_inherited_ref_cmd( schema=schema, context=context, astnode=astnode, bases=bases, referrer=referrer, ) args = cls._constraint_args_from_ast(schema, astnode, context) if args: cmd.set_attribute_value('args', args) subj_expr = bases[0].get_subjectexpr(schema) if subj_expr is not None: cmd.set_attribute_value('subjectexpr', subj_expr, 
inherited=True) params = bases[0].get_params(schema) if params is not None: cmd.set_attribute_value('params', params, inherited=True) return cmd @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> CreateConstraint: cmd = super()._cmd_tree_from_ast(schema, astnode, context) if isinstance(astnode, qlast.CreateConcreteConstraint): if astnode.delegated: cmd.set_attribute_value('delegated', astnode.delegated) args = cls._constraint_args_from_ast(schema, astnode, context) if args: cmd.set_attribute_value('args', args) elif isinstance(astnode, qlast.CreateConstraint): params = cls._get_param_desc_from_ast( schema, context.modaliases, astnode) for param in params: if param.get_kind(schema) is ft.ParameterKind.NamedOnlyParam: raise errors.InvalidConstraintDefinitionError( 'named only parameters are not allowed ' 'in this context', span=astnode.span) if param.get_default(schema) is not None: raise errors.InvalidConstraintDefinitionError( 'constraints do not support parameters ' 'with defaults', span=astnode.span) if cmd.get_attribute_value('return_type') is None: cmd.set_attribute_value( 'return_type', schema.get('std::bool'), ) if cmd.get_attribute_value('return_typemod') is None: cmd.set_attribute_value( 'return_typemod', ft.TypeModifier.SingletonType, ) assert isinstance(astnode, (qlast.CreateConstraint, qlast.CreateConcreteConstraint)) # 'subjectexpr' can be present in either astnode type if astnode.subjectexpr: orig_text = cls.get_orig_expr_text(schema, astnode, 'subjectexpr') expr_ql: qlast.Expr if ( orig_text is not None and context.compat_ver_is_before( (1, 0, verutils.VersionStage.ALPHA, 6) ) ): # Versions prior to a6 used a different expression # normalization strategy, so we must renormalize the # expression. 
expr_ql = qlcompiler.renormalize_compat( astnode.subjectexpr, orig_text, schema=schema, localnames=context.localnames, ) else: expr_ql = astnode.subjectexpr subjectexpr = s_expr.Expression.from_ast( expr_ql, schema, context.modaliases, context.localnames, ) cmd.set_attribute_value( 'subjectexpr', subjectexpr, ) if ( isinstance(astnode, qlast.CreateConcreteConstraint) and astnode.except_expr ): except_expr = s_expr.Expression.from_ast( astnode.except_expr, schema, context.modaliases, context.localnames, ) cmd.set_attribute_value('except_expr', except_expr) cls._validate_subcommands(astnode) assert isinstance(cmd, CreateConstraint) return cmd def _skip_param(self, props: dict[str, Any]) -> bool: pname = s_func.Parameter.paramname_from_fullname(props['name']) return pname == '__subject__' def _get_params_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, ) -> list[tuple[int, qlast.FuncParamDecl]]: if isinstance(node, qlast.CreateConstraint): return super()._get_params_ast(schema, context, node) else: return [] def _apply_field_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, op: sd.AlterObjectProperty, ) -> None: if ( op.property == 'args' and isinstance(node, (qlast.CreateConcreteConstraint, qlast.AlterConcreteConstraint)) ): assert isinstance(op.new_value, s_expr.ExpressionList) args = [] for arg in op.new_value: exprast = arg.parse() assert isinstance(exprast, qlast.Expr), "expected qlast.Expr" args.append(exprast) node.args = args return super()._apply_field_ast(schema, context, node, op) @classmethod def _classbases_from_ast( cls, schema: s_schema.Schema, astnode: qlast.ObjectDDL, context: sd.CommandContext, ) -> list[so.ObjectShell[Constraint]]: if isinstance(astnode, qlast.CreateConcreteConstraint): classname = cls._classname_from_ast(schema, astnode, context) base_name = sn.shortname_from_fullname(classname) assert isinstance(base_name, sn.QualName), \ "expected qualified name" 
base = utils.ast_objref_to_object_shell( qlast.ObjectRef( module=base_name.module, name=base_name.name, ), metaclass=Constraint, schema=schema, modaliases=context.modaliases, ) return [base] else: return super()._classbases_from_ast(schema, astnode, context) class RenameConstraint( ConstraintCommand, s_func.RenameCallableObject[Constraint], referencing.RenameReferencedInheritingObject[Constraint], ): @classmethod def _classname_quals_from_ast( cls, schema: s_schema.Schema, astnode: qlast.ObjectDDL, base_name: sn.Name, referrer_name: sn.QualName, context: sd.CommandContext, ) -> tuple[str, ...]: parent_op = cls.get_parent_op(context) assert isinstance(parent_op.classname, sn.QualName) return cls._classname_quals_from_name(parent_op.classname) def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._alter_begin(schema, context) if not context.canonical and self.scls.get_abstract(schema): self._propagate_ref_rename(schema, context, self.scls) return schema class AlterConstraintOwned( referencing.AlterOwned[Constraint], ConstraintCommand, field='owned', referrer_context_class=ConsistencySubjectCommandContext, ): pass class AlterConstraint( ConstraintCommand, referencing.AlterReferencedInheritingObject[Constraint], ): astnode = [qlast.AlterConcreteConstraint, qlast.AlterConstraint] referenced_astnode = qlast.AlterConcreteConstraint def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: referrer_ctx = self.get_referrer_context(context) if referrer_ctx is None: schema = super()._alter_begin(schema, context) return schema subject = referrer_ctx.scls assert isinstance(subject, ConsistencySubject) if not context.canonical: props = self.get_attributes(schema, context) props.pop('name', None) props.pop('subject', None) props.pop('expr', None) args = props.pop('args', None) if not args: args = self.scls.get_args(schema) subjectexpr = props.pop('subjectexpr', None) 
subjectexpr_inherited = self.is_attribute_inherited('subjectexpr') if not subjectexpr: subjectexpr_inherited = self.scls.field_is_inherited( schema, 'subjectexpr') subjectexpr = self.scls.get_subjectexpr(schema) fullname = self.classname shortname = sn.shortname_from_fullname(fullname) assert isinstance(shortname, sn.QualName), \ "expected qualified name" self._populate_concrete_constraint_attrs( schema, context, subject_obj=subject, name=shortname, subjectexpr=subjectexpr, subjectexpr_inherited=subjectexpr_inherited, args=args, span=self.span, **props, ) return super()._alter_begin(schema, context) @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> AlterConstraint: cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(cmd, AlterConstraint) if isinstance(astnode, (qlast.CreateConcreteConstraint, qlast.AlterConcreteConstraint)): if getattr(astnode, 'delegated', False): assert isinstance(astnode, qlast.CreateConcreteConstraint) cmd.set_attribute_value('delegated', astnode.delegated) new_name = None for op in cmd.get_subcommands(type=RenameConstraint): new_name = op.new_name if new_name is not None: cmd.set_attribute_value('name', new_name) cls._validate_subcommands(astnode) return cmd def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: if self.scls.get_abstract(schema): return super()._get_ast(schema, context, parent_node=parent_node) # We need to make sure to include subjectexpr and args # in the AST, since they are really part of the name. 
op = self.as_inherited_ref_ast( schema, context, self.scls.get_name(schema), self.scls, ) self._apply_fields_ast(schema, context, op) if (op is not None and hasattr(op, 'commands') and not op.commands): return None return op def validate_alter( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: super().validate_alter(schema, context) self_delegated = self.get_attribute_value('delegated') if not self_delegated: return concrete_bases = [ b for b in self.scls.get_bases(schema).objects(schema) if not b.is_non_concrete(schema) and not b.get_delegated(schema) ] if concrete_bases: tgt_repr = self.scls.get_verbosename(schema, with_parent=True) bases_repr = ', '.join( b.get_subject(schema).get_verbosename(schema, with_parent=True) for b in concrete_bases ) raise errors.InvalidConstraintDefinitionError( f'cannot redefine {tgt_repr} as delegated:' f' it is defined as non-delegated in {bases_repr}', span=self.span, ) def canonicalize_alter_from_external_ref( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: if ( not self.get_attribute_value('abstract') and (subjectexpr := self.get_attribute_value('subjectexpr')) is not None ): assert isinstance(subjectexpr, s_expr.Expression) # To compute the new name, we construct an AST of the # constraint, since that is the infrastructure we have for # computing the classname. 
name = sn.shortname_from_fullname(self.classname) assert isinstance(name, sn.QualName), "expected qualified name" ast = qlast.CreateConcreteConstraint( name=qlast.ObjectRef(name=name.name, module=name.module), subjectexpr=subjectexpr.parse(), args=[], ) quals = sn.quals_from_fullname(self.classname) new_name = self._classname_from_ast_and_referrer( schema, sn.QualName.from_string(quals[0]), ast, context) if new_name == self.classname: return rename = self.scls.init_delta_command( schema, sd.RenameObject, new_name=new_name) rename.set_attribute_value( 'name', value=new_name, orig_value=self.classname) self.add(rename) def _get_params( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_func.FuncParameterList: return self.scls.get_params(schema) class DeleteConstraint( ConstraintCommand, referencing.DeleteReferencedInheritingObject[Constraint], s_func.DeleteCallableObject[Constraint], ): astnode = [qlast.DropConcreteConstraint, qlast.DropConstraint] referenced_astnode = qlast.DropConcreteConstraint def _apply_field_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, op: sd.AlterObjectProperty, ) -> None: if op.property == 'args': assert isinstance(op.old_value, s_expr.ExpressionList) assert isinstance(node, qlast.DropConcreteConstraint) node.args = [arg.parse() for arg in op.old_value] return super()._apply_field_ast(schema, context, node, op) class RebaseConstraint( ConstraintCommand, referencing.RebaseReferencedInheritingObject[Constraint], ): # finalexpr is fully determined by a bunch of ddl_identity fields, # so it should be inherited by compute_inherited_fields if they are EXTRA_INHERITED_FIELDS = {'finalexpr'} def _get_bases_for_ast( self, schema: s_schema.Schema, context: sd.CommandContext, bases: tuple[so.ObjectShell[Constraint], ...], ) -> tuple[so.ObjectShell[Constraint], ...]: return () def interpolate_error_text(text: str, args: dict[str, str]) -> str: """ Converts message template "hello {world}! 
{nope}{{world}}" and arguments
    {"world": "Alice", "hell": "Eve"} into "hello Alice! {nope}{world}".
    """
    regex = r"\{\{.*\}\}|\{([A-Za-z_0-9]+)\}"

    formatted = ""
    last_start = 0
    for match in re.finditer(regex, text, flags=0):
        formatted += text[last_start : match.start()]
        last_start = match.end()

        if match[1] is None:
            # escape double curly braces
            formatted += match[0][1:-1]
        elif match[1] in args:
            # lookup an arg
            formatted += args[match[1]]
        else:
            # arg not found
            formatted += match[0]

    formatted += text[last_start:]
    return formatted

================================================
FILE: edb/schema/database.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

from edb import errors

from edb.common import struct
from edb.edgeql import ast as qlast
from edb.edgeql import qltypes
from edb.schema import defines as s_def

from . import annos as s_anno
from . import delta as sd
from . import objects as so
from . import schema as s_schema

from typing import cast


class Branch(
    so.ExternalObject,
    s_anno.AnnotationSubject,
    qlkind=qltypes.SchemaObjectClass.DATABASE,
    data_safe=False,
):
    pass


class BranchCommandContext(sd.ObjectCommandContext[Branch]):
    pass


class BranchCommand(
    sd.ExternalObjectCommand[Branch],
    context_class=BranchCommandContext,
):

    def _validate_name(
        self,
        schema: s_schema.Schema,
        context: sd.CommandContext,
    ) -> None:
        name = self.get_attribute_value('name')
        if len(str(name)) > s_def.MAX_NAME_LENGTH:
            span = self.get_attribute_span('name')
            raise errors.SchemaDefinitionError(
                f'Branch names longer than {s_def.MAX_NAME_LENGTH} '
                f'characters are not supported',
                span=span,
            )


class CreateBranch(BranchCommand, sd.CreateExternalObject[Branch]):

    astnode = qlast.CreateDatabase
    template = struct.Field(str, default=None)
    branch_type = struct.Field(
        qlast.BranchType, default=qlast.BranchType.EMPTY)

    @classmethod
    def _cmd_tree_from_ast(
        cls,
        schema: s_schema.Schema,
        astnode: qlast.DDLOperation,
        context: sd.CommandContext,
    ) -> CreateBranch:
        cmd = super()._cmd_tree_from_ast(schema, astnode, context)
        assert isinstance(cmd, CreateBranch)

        assert isinstance(astnode, qlast.CreateDatabase)
        if astnode.template is not None:
            cmd.template = astnode.template.name

        if (
            astnode.branch_type == qlast.BranchType.TEMPLATE
            and not context.testmode
        ):
            raise errors.EdgeQLSyntaxError(
                f'unexpected TEMPLATE',
                span=astnode.span,
            )

        cmd.branch_type = astnode.branch_type

        return cmd

    def validate_create(
        self,
        schema: s_schema.Schema,
        context: sd.CommandContext,
    ) -> None:
        # no call to super().validate_create() as we don't want to enforce
        # rules that hold for any other schema objects
        self._validate_name(schema, context)


class AlterBranch(BranchCommand, sd.AlterExternalObject[Branch]):
    astnode = qlast.AlterDatabase

    def validate_alter(
        self,
        schema: s_schema.Schema,
        context: sd.CommandContext,
    ) -> None:
        super().validate_alter(schema, context)
        self._validate_name(schema, context)


class DropBranch(BranchCommand, sd.DeleteExternalObject[Branch]):
    astnode = qlast.DropDatabase

    def _validate_legal_command(
        self,
        schema: s_schema.Schema,
        context: sd.CommandContext,
    ) -> None:
        super()._validate_legal_command(schema, context)
        if self.classname.name in s_def.EDGEDB_SPECIAL_DBS:
            raise errors.ExecutionError(
                f"database {self.classname.name!r} cannot be dropped"
            )


class RenameBranch(BranchCommand, sd.RenameObject[Branch]):
    # databases are ExternalObjects, so they might not be properly
    # present in the schema, so we can't do a proper rename.
    def apply(
        self,
        schema: s_schema.Schema,
        context: sd.CommandContext,
    ) -> s_schema.Schema:
        scls = self.get_parent_op(context).scls
        self.scls = cast(Branch, scls)
        return schema

================================================
FILE: edb/schema/ddl.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations
from typing import (
    Callable,
    Optional,
    Iterable,
    Mapping,
    cast,
    TYPE_CHECKING,
)

from collections import defaultdict
import itertools

from edb import errors

from edb import edgeql
from edb.common import debug
from edb.common import uuidgen
from edb.common import verutils
from edb.edgeql import ast as qlast
from edb.edgeql import declarative as s_decl

from . import delta as sd
from . import expr as s_expr
from . import extensions as s_ext
from . import functions as s_func
from . import migrations as s_migr
from . import modules as s_mod
from . import name as sn
from . import objects as so
from . import objtypes as s_objtypes
from . import ordering as s_ordering
from . import pseudo as s_pseudo
from . import schema as s_schema
from . import types as s_types
from . import version as s_ver

if TYPE_CHECKING:
    import uuid


def delta_schemas(
    schema_a: Optional[s_schema.Schema],
    schema_b: s_schema.Schema,
    *,
    included_modules: Optional[Iterable[sn.Name]]=None,
    excluded_modules: Optional[Iterable[sn.Name]]=None,
    included_items: Optional[Iterable[sn.Name]]=None,
    excluded_items: Optional[Iterable[sn.Name]]=None,
    schema_a_filters: Iterable[
        Callable[[s_schema.Schema, so.Object], bool]
    ] = (),
    schema_b_filters: Iterable[
        Callable[[s_schema.Schema, so.Object], bool]
    ] = (),
    include_module_diff: bool=True,
    include_std_diff: bool=False,
    include_derived_types: bool=True,
    include_extensions: bool=False,
    linearize_delta: bool=True,
    descriptive_mode: bool=False,
    generate_prompts: bool=False,
    guidance: Optional[so.DeltaGuidance]=None,
) -> sd.DeltaRoot:
    """Return difference between *schema_a* and *schema_b*.

    The returned object is a delta tree that, when applied to
    *schema_a* results in *schema_b*.

    Args:
        schema_a:
            Schema to use as a starting state.  If ``None``,
            then a schema with only standard modules is assumed,
            unless *include_std_diff* is ``True``, in which case
            an entirely empty schema is assumed as a starting point.

        schema_b:
            Schema to use as the ending state.

        included_modules:
            Optional list of modules to include in the delta.

        excluded_modules:
            Optional list of modules to exclude from the delta.
            Takes precedence over *included_modules*.
            NOTE: standard library modules are always excluded,
            unless *include_std_diff* is ``True``.

        included_items:
            Optional list of names of objects to include in the delta.

        excluded_items:
            Optional list of names of objects to exclude from the delta.
Takes precedence over *included_items*. schema_a_filters: Optional list of additional filters to place on *schema_a*. schema_b_filters: Optional list of additional filters to place on *schema_b*. include_module_diff: Whether to include create/drop module operations in the delta diff. include_std_diff: Whether to include the standard library in the diff. include_derived_types: Whether to include derived types, like unions, in the diff. include_extensions: Whether to include objects from extensions in the diff. linearize_delta: Whether the resulting diff should be properly ordered using the dependencies between objects. descriptive_mode: DESCRIBE AS TEXT mode. generate_prompts: Whether to generate prompts that can be used in DESCRIBE MIGRATION. guidance: Optional explicit guidance to schema diff. Returns: A :class:`schema.delta.DeltaRoot` instance representing the delta between *schema_a* and *schema_b*. """ result = sd.DeltaRoot() schema_a_filters = list(schema_a_filters) schema_b_filters = list(schema_b_filters) context = so.ComparisonContext( generate_prompts=generate_prompts, descriptive_mode=descriptive_mode, guidance=guidance, ) if schema_a is None: if include_std_diff: schema_a = s_schema.EMPTY_SCHEMA else: schema_a = schema_b def _filter(schema: s_schema.Schema, obj: so.Object) -> bool: return ( ( isinstance(obj, so.QualifiedObject) and ( obj.get_name(schema).get_module_name() in s_schema.STD_MODULES ) ) or ( isinstance(obj, s_mod.Module) and obj.get_name(schema) in s_schema.STD_MODULES ) ) schema_a_filters.append(_filter) my_modules = { m.get_name(schema_b) for m in schema_b.get_objects( type=s_mod.Module, extra_filters=schema_b_filters, ) } other_modules = { m.get_name(schema_a) for m in schema_a.get_objects( type=s_mod.Module, extra_filters=schema_a_filters, ) } added_modules = my_modules - other_modules dropped_modules = other_modules - my_modules if included_modules is not None: included_modules = set(included_modules) added_modules &= included_modules dropped_modules &= included_modules else: included_modules = set() if
excluded_modules is None: excluded_modules = set() else: excluded_modules = set(excluded_modules) if not include_std_diff: excluded_modules.update(s_schema.STD_MODULES) def _filter(schema: s_schema.Schema, obj: so.Object) -> bool: return not obj.get_builtin(schema) schema_a_filters.append(_filter) schema_b_filters.append(_filter) # In theory, __derived__ is ephemeral and should not need to be # included. In practice, unions created by computed links put # persistent things into __derived__ and need to be included in # diffs. # TODO: Fix this. if not include_derived_types: excluded_modules.add(sn.UnqualName('__derived__')) excluded_modules.add(sn.UnqualName('__ext_casts__')) excluded_modules.add(sn.UnqualName('__ext_index_matches__')) # Don't analyze the objects from extensions. if not include_extensions and isinstance(schema_b, s_schema.ChainedSchema): ext_packages = schema_b._global_schema.get_objects( type=s_ext.ExtensionPackage) ext_mods = set() for pkg in ext_packages: if not (modname := pkg.get_ext_module(schema_b)): continue if schema_a and schema_a.get_referrers(pkg): ext_mods.add(sn.UnqualName(modname)) for submod in schema_a.get_modules(): submod_name = submod.get_name(schema_a) assert isinstance(submod_name, sn.UnqualName) if submod_name.name.startswith(modname + '::'): ext_mods.add(submod_name) if schema_b.get_referrers(pkg): ext_mods.add(sn.UnqualName(modname)) for submod in schema_b.get_modules(): submod_name = submod.get_name(schema_b) assert isinstance(submod_name, sn.UnqualName) if submod_name.name.startswith(modname + '::'): ext_mods.add(submod_name) for ext_mod in ext_mods: if ext_mod not in included_modules: excluded_modules.add(ext_mod) if excluded_modules: added_modules -= excluded_modules dropped_modules -= excluded_modules if included_items is not None: included_items = set(included_items) if excluded_items is not None: excluded_items = set(excluded_items) if include_module_diff: for added_module in sorted(added_modules): if ( guidance is 
None or ( (s_mod.Module, added_module) not in guidance.banned_creations ) ): mod = schema_b.get_global(s_mod.Module, added_module, None) assert mod is not None create = mod.as_create_delta( schema=schema_b, context=context, ) assert isinstance(create, sd.CreateObject) create.if_not_exists = True # We currently fully assume that modules are created # or deleted and never renamed. This is fine, because module # objects are never actually referenced directly, only by # the virtue of being the leading part of a fully-qualified # name. create.set_annotation('confidence', 1.0) result.add(create) excluded_classes = ( so.GlobalObject, s_mod.Module, s_func.Parameter, s_pseudo.PseudoType, s_migr.Migration, ) schemaclasses = [ schemacls for schemacls in so.ObjectMeta.get_schema_metaclasses() if ( not issubclass(schemacls, excluded_classes) and not schemacls.is_abstract() ) ] assert not context.renames # We retry performing the diff until we stop finding new renames # and deletions. This allows us to be agnostic to the order that # we process schemaclasses. old_count = -1, -1 while old_count != (len(context.renames), len(context.deletions)): old_count = len(context.renames), len(context.deletions) objects = sd.DeltaRoot() for sclass in schemaclasses: filters: list[Callable[[s_schema.Schema, so.Object], bool]] = [] if not issubclass(sclass, so.QualifiedObject): # UnqualifiedObjects (like anonymous tuples and arrays) # should not use an included_modules filter. 
incl_modules = None else: if issubclass(sclass, so.DerivableObject): def _only_generic( schema: s_schema.Schema, obj: so.Object, ) -> bool: assert isinstance(obj, so.DerivableObject) return obj.is_non_concrete(schema) or ( isinstance(obj, s_types.Type) and obj.get_from_global(schema) ) filters.append(_only_generic) incl_modules = included_modules new = schema_b.get_objects( type=sclass, included_modules=incl_modules, excluded_modules=excluded_modules, included_items=included_items, excluded_items=excluded_items, extra_filters=filters + schema_b_filters, ) old = schema_a.get_objects( type=sclass, included_modules=incl_modules, excluded_modules=excluded_modules, included_items=included_items, excluded_items=excluded_items, extra_filters=filters + schema_a_filters, ) objects.add( sd.delta_objects( old, new, sclass=sclass, old_schema=schema_a, new_schema=schema_b, context=context, ) ) # We don't properly understand the dependencies on extensions, so # instead of having s_ordering sort them, we just put all # CreateExtension commands first and all DeleteExtension commands # last.
create_exts: list[s_ext.CreateExtension] = [] delete_exts = [] for cmd in list(objects.get_subcommands()): if isinstance(cmd, s_ext.CreateExtension): cmd.canonical = False objects.discard(cmd) create_exts.append(cmd) elif isinstance(cmd, s_ext.DeleteExtension): cmd.canonical = False objects.discard(cmd) delete_exts.append(cmd) if linearize_delta: objects = s_ordering.linearize_delta( objects, old_schema=schema_a, new_schema=schema_b) if include_derived_types: result.add(objects) else: for cmd in objects.get_subcommands(): if isinstance(cmd, s_objtypes.ObjectTypeCommand): if isinstance(cmd, s_objtypes.DeleteObjectType): relevant_schema = schema_a else: relevant_schema = schema_b obj = cast(s_objtypes.ObjectType, relevant_schema.get(cmd.classname)) if obj.is_union_type(relevant_schema): continue result.add(cmd) if include_module_diff: # Process dropped modules in *reverse* sorted order, so that # `foo::bar` gets dropped before `foo`. for dropped_module in reversed(sorted(dropped_modules)): if ( guidance is None or ( (s_mod.Module, dropped_module) not in guidance.banned_deletions ) ): mod = schema_a.get_global(s_mod.Module, dropped_module, None) assert mod is not None dropped = mod.as_delete_delta( schema=schema_a, context=context, ) dropped.set_annotation('confidence', 1.0) result.add(dropped) create_exts_sorted = sd.sort_by_cross_refs_key( schema_b, create_exts, key=lambda x: x.scls) delete_exts_sorted = sd.sort_by_cross_refs_key( schema_a, delete_exts, key=lambda x: x.scls) for op in create_exts_sorted: result.prepend(op) result.update(delete_exts_sorted) return result def cmd_from_ddl( stmt: qlast.DDLOperation, *, context: Optional[sd.CommandContext]=None, schema: s_schema.Schema, modaliases: Mapping[Optional[str], str], testmode: bool=False ) -> sd.Command: ddl = s_expr.imprint_expr_context(stmt, modaliases) assert isinstance(ddl, qlast.DDLOperation) if context is None: context = sd.CommandContext( schema=schema, modaliases=modaliases, testmode=testmode) res = 
sd.compile_ddl(schema, ddl, context=context) context.early_renames.clear() return res def apply_sdl( sdl_document: qlast.Schema, *, base_schema: s_schema.Schema, stdmode: bool = False, testmode: bool = False, ) -> tuple[s_schema.Schema, list[errors.EdgeDBError]]: # group declarations by module documents: dict[str, list[qlast.DDLCommand]] = defaultdict(list) # initialize the "default" module documents[s_mod.DEFAULT_MODULE_ALIAS] = [] extensions = {} futures = {} def collect( decl: qlast.ObjectDDL | qlast.ModuleDeclaration, module: Optional[str], ) -> None: # declarations are either in a module block or fully-qualified if isinstance(decl, qlast.ModuleDeclaration): new_mod = ( f'{module}::{decl.name.name}' if module else decl.name.name) # make sure the new one is present documents.setdefault(new_mod, []) for sdecl in decl.declarations: collect(sdecl, new_mod) elif isinstance(decl, qlast.CreateExtension): assert not module extensions[decl.name.name] = decl elif isinstance(decl, qlast.CreateFuture): assert not module futures[decl.name.name] = decl else: assert isinstance(decl, qlast.ObjectDDL) assert module or decl.name.module is not None if decl.name.module is None: assert module name = module else: name = ( f'{module}::{decl.name.name}' if module else decl.name.module) documents[name].append(decl) context = sd.CommandContext( modaliases={}, schema=base_schema, stdmode=stdmode, testmode=testmode, declarative=True, ) for decl in sdl_document.declarations: collect(decl, None) target_schema = base_schema warnings = [] def process(ddl_stmt: qlast.DDLCommand) -> None: nonlocal target_schema delta = sd.DeltaRoot() with context(sd.DeltaRootContext(schema=target_schema, op=delta)): cmd = cmd_from_ddl( ddl_stmt, schema=target_schema, modaliases={}, context=context, testmode=testmode) delta.add(cmd) target_schema = delta.apply(target_schema, context) context.schema = target_schema warnings.extend(delta.warnings) # Process all the extensions first, since sdl_to_ddl needs to be # 
able to see their contents. While we do so, also collect any # transitive dependency extensions and add those as well. We do this # dependency resolution automatically as part of SDL processing # instead of when doing CREATE EXTENSION because I didn't want # *DROP EXTENSION* to automatically drop transitive dependencies, # and so CREATE EXTENSION shouldn't either, symmetrically. extensions_done = set() def process_ext(ddl_stmt: qlast.CreateExtension) -> None: name = ddl_stmt.name.name pkg = s_ext.get_package( sn.UnqualName(name), ( verutils.parse_version(ddl_stmt.version.value) if ddl_stmt.version else None ), base_schema, ) pkg_ver = pkg.get_version(base_schema) if (name, pkg_ver) in extensions_done: return extensions_done.add((name, pkg_ver)) if pkg: for dep in pkg.get_dependencies(base_schema): if '>=' not in dep: builtin = ( 'built-in ' if pkg.get_builtin(base_schema) else '' ) raise errors.SchemaError( f'{builtin}extension {name} missing version for {dep}') dep, dep_version = dep.split('>=') process_ext( qlast.CreateExtension( name=qlast.ObjectRef(name=dep), version=qlast.Constant.string(value=dep_version), ) ) process(ddl_stmt) ddl_stmt: qlast.DDLCommand for ddl_stmt in extensions.values(): process_ext(ddl_stmt) # Now, sort the main body of SDL and apply it.
ddl_stmts = s_decl.sdl_to_ddl(target_schema, documents) if debug.flags.sdl_loading: debug.header('SDL loading script') for ddl_stmt in itertools.chain( extensions.values(), futures.values(), ddl_stmts ): ddl_stmt.dump_edgeql() for ddl_stmt in itertools.chain(futures.values(), ddl_stmts): process(ddl_stmt) return target_schema, warnings def apply_ddl_script( ddl_text: str, *, schema: s_schema.Schema, modaliases: Optional[Mapping[Optional[str], str]] = None, stdmode: bool = False, testmode: bool = False, ) -> s_schema.Schema: schema, _ = apply_ddl_script_ex( ddl_text, schema=schema, modaliases=modaliases, stdmode=stdmode, testmode=testmode, ) return schema def apply_ddl_script_ex( ddl_text: str, *, schema: s_schema.Schema, modaliases: Optional[Mapping[Optional[str], str]] = None, stdmode: bool = False, internal_schema_mode: bool = False, testmode: bool = False, store_migration_sdl: bool=False, schema_object_ids: Optional[ Mapping[tuple[sn.Name, Optional[str]], uuid.UUID] ]=None, compat_ver: Optional[verutils.Version] = None, ) -> tuple[s_schema.Schema, sd.DeltaRoot]: delta = sd.DeltaRoot() if modaliases is None: modaliases = {} for ddl_stmt in edgeql.parse_block(ddl_text): if not isinstance(ddl_stmt, qlast.DDLCommand): raise AssertionError(f'expected DDLCommand node, got {ddl_stmt!r}') schema, cmd = delta_and_schema_from_ddl( ddl_stmt, schema=schema, modaliases=modaliases, stdmode=stdmode, internal_schema_mode=internal_schema_mode, testmode=testmode, store_migration_sdl=store_migration_sdl, schema_object_ids=schema_object_ids, compat_ver=compat_ver, ) delta.add(cmd) return schema, delta def delta_from_ddl( ddl_stmt: qlast.DDLCommand, *, schema: s_schema.Schema, modaliases: Mapping[Optional[str], str], stdmode: bool=False, testmode: bool=False, store_migration_sdl: bool=False, schema_object_ids: Optional[ Mapping[tuple[sn.Name, Optional[str]], uuid.UUID] ]=None, compat_ver: Optional[verutils.Version] = None, ) -> sd.DeltaRoot: _, cmd = delta_and_schema_from_ddl( 
ddl_stmt, schema=schema, modaliases=modaliases, stdmode=stdmode, testmode=testmode, store_migration_sdl=store_migration_sdl, schema_object_ids=schema_object_ids, compat_ver=compat_ver, ) return cmd def delta_and_schema_from_ddl( ddl_stmt: qlast.DDLCommand, *, schema: s_schema.Schema, modaliases: Mapping[Optional[str], str], stdmode: bool=False, internal_schema_mode: bool=False, testmode: bool=False, store_migration_sdl: bool=False, schema_object_ids: Optional[ Mapping[tuple[sn.Name, Optional[str]], uuid.UUID] ]=None, compat_ver: Optional[verutils.Version] = None, ) -> tuple[s_schema.Schema, sd.DeltaRoot]: delta = sd.DeltaRoot() context = sd.CommandContext( modaliases=modaliases, schema=schema, stdmode=stdmode, internal_schema_mode=internal_schema_mode, testmode=testmode, store_migration_sdl=store_migration_sdl, schema_object_ids=schema_object_ids, compat_ver=compat_ver, ) with context(sd.DeltaRootContext(schema=schema, op=delta)): cmd = cmd_from_ddl( ddl_stmt, schema=schema, modaliases=modaliases, context=context, testmode=testmode, ) if debug.flags.delta_plan: debug.header('Delta Plan Input') debug.dump(cmd) schema = cmd.apply(schema, context) if not stdmode: if not isinstance( cmd, (sd.GlobalObjectCommand, sd.ExternalObjectCommand), ): ver = schema.get_global( s_ver.SchemaVersion, '__schema_version__') ver_cmd = ver.init_delta_command(schema, sd.AlterObject) ver_cmd.set_attribute_value('version', uuidgen.uuid1mc()) schema = ver_cmd.apply(schema, context) delta.add(ver_cmd) elif not isinstance(cmd, sd.ExternalObjectCommand): gver = schema.get_global( s_ver.GlobalSchemaVersion, '__global_schema_version__') g_ver_cmd = gver.init_delta_command(schema, sd.AlterObject) g_ver_cmd.set_attribute_value('version', uuidgen.uuid1mc()) schema = g_ver_cmd.apply(schema, context) delta.add(g_ver_cmd) delta.add(cmd) delta.canonical = True return schema, delta def ddlast_from_delta( schema_a: Optional[s_schema.Schema], schema_b: s_schema.Schema, delta: sd.DeltaRoot, *, sdlmode: 
bool = False, testmode: bool = False, descriptive_mode: bool = False, include_ext_version: bool = True, ) -> dict[qlast.DDLOperation, sd.Command]: context = sd.CommandContext( descriptive_mode=descriptive_mode, declarative=sdlmode, testmode=testmode, include_ext_version=include_ext_version, ) if schema_a is None: schema = schema_b update_schema = False else: schema = schema_a update_schema = True stmts = {} for command in delta.get_subcommands(): with context(sd.DeltaRootContext(schema=schema, op=delta)): # The reason we do this instead of just directly using the new # schema is to populate the renames field of the context. # We do this one part at a time to avoid referring to things # that have not been renamed yet. # XXX: Is this fine-grained enough, though? if update_schema: schema = command.apply(schema, context) ql_ast = command.get_ast(schema, context) if ql_ast: stmts[ql_ast] = command return stmts def statements_from_delta( schema_a: Optional[s_schema.Schema], schema_b: s_schema.Schema, delta: sd.DeltaRoot, *, sdlmode: bool = False, descriptive_mode: bool = False, # Used for backwards compatibility with older migration text. uppercase: bool = False, limit_ref_classes: Iterable[so.ObjectMeta] = tuple(), include_ext_version: bool = True, ) -> tuple[tuple[str, qlast.DDLOperation, sd.Command], ...]: stmts = ddlast_from_delta( schema_a, schema_b, delta, sdlmode=sdlmode, descriptive_mode=descriptive_mode, include_ext_version=include_ext_version, ) ql_classes_src = { scls.get_ql_class() for scls in limit_ref_classes } ql_classes = {q for q in ql_classes_src if q is not None} # If we're generating SDL and it includes modules, try to nest the # module contents in the actual modules. processed: list[tuple[qlast.DDLOperation, sd.Command]] = [] unqualified: list[tuple[qlast.DDLOperation, sd.Command]] = [] modules = dict() for stmt_ast, cmd in stmts.items(): if sdlmode: if isinstance(stmt_ast, qlast.CreateModule): # Record the module stubs. 
modules[stmt_ast.name.name] = stmt_ast stmt_ast.commands = [] processed.append((stmt_ast, cmd)) elif ( modules and not isinstance(stmt_ast, qlast.UnqualifiedObjectCommand) ): # This SDL included creation of modules, so we will try to # nest the declarations in them. assert isinstance(stmt_ast, qlast.CreateObject) assert stmt_ast.name.module is not None module = modules[stmt_ast.name.module] module.commands.append(stmt_ast) # Strip the module from the object name, since we nest # them in a module already. stmt_ast.name.module = None elif isinstance(stmt_ast, qlast.UnqualifiedObjectCommand): unqualified.append((stmt_ast, cmd)) else: processed.append((stmt_ast, cmd)) else: processed.append((stmt_ast, cmd)) text = [] for stmt_ast, cmd in itertools.chain(unqualified, processed): stmt_text = edgeql.generate_source( stmt_ast, sdlmode=sdlmode, descmode=descriptive_mode, limit_ref_classes=ql_classes, uppercase=uppercase, ) text.append((stmt_text + ';', stmt_ast, cmd)) return tuple(text) def text_from_delta( schema_a: Optional[s_schema.Schema], schema_b: s_schema.Schema, delta: sd.DeltaRoot, *, sdlmode: bool = False, descriptive_mode: bool = False, limit_ref_classes: Iterable[so.ObjectMeta] = tuple(), include_ext_version: bool = True, ) -> str: stmts = statements_from_delta( schema_a, schema_b, delta, sdlmode=sdlmode, descriptive_mode=descriptive_mode, limit_ref_classes=limit_ref_classes, include_ext_version=include_ext_version, ) return '\n'.join(text for text, _, _ in stmts) def ddl_text_from_delta( schema_a: Optional[s_schema.Schema], schema_b: s_schema.Schema, delta: sd.DeltaRoot, *, include_ext_version: bool = True, ) -> str: """Return DDL text corresponding to a delta plan. Args: schema_a: The original schema (or None if starting from empty/std) schema_b: The schema to which the *delta* has **already** been applied. delta: The delta plan. Returns: DDL text corresponding to *delta*. 
""" return text_from_delta( schema_a, schema_b, delta, sdlmode=False, include_ext_version=include_ext_version, ) def sdl_text_from_delta( schema_a: Optional[s_schema.Schema], schema_b: s_schema.Schema, delta: sd.DeltaRoot, ) -> str: """Return SDL text corresponding to a delta plan. Args: schema_a: The original schema (or None if starting from empty/std) schema_b: The schema to which the *delta* has **already** been applied. delta: The delta plan. Returns: SDL text corresponding to *delta*. """ return text_from_delta(schema_a, schema_b, delta, sdlmode=True) def descriptive_text_from_delta( schema_a: Optional[s_schema.Schema], schema_b: s_schema.Schema, delta: sd.DeltaRoot, *, limit_ref_classes: Iterable[so.ObjectMeta]=tuple(), ) -> str: """Return descriptive text corresponding to a delta plan. Args: schema_a: The original schema (or None if starting from empty/std) schema_b: The schema to which the *delta* has **already** been applied. delta: The delta plan. limit_ref_classes: If specified, limit the output of referenced objects to the specified classes. Returns: Descriptive text corresponding to *delta*. 
""" return text_from_delta( schema_a, schema_b, delta, sdlmode=True, descriptive_mode=True, limit_ref_classes=limit_ref_classes, ) def ddl_text_from_schema( schema: s_schema.Schema, *, included_modules: Optional[Iterable[sn.Name]] = None, excluded_modules: Optional[Iterable[sn.Name]] = None, included_items: Optional[Iterable[sn.Name]] = None, excluded_items: Optional[Iterable[sn.Name]] = None, included_ref_classes: Iterable[so.ObjectMeta] = tuple(), include_module_ddl: bool = True, include_std_ddl: bool = False, include_migrations: bool = False, ) -> str: diff = delta_schemas( schema_a=None, schema_b=schema, included_modules=included_modules, excluded_modules=excluded_modules, included_items=included_items, excluded_items=excluded_items, include_module_diff=include_module_ddl, include_std_diff=include_std_ddl, include_derived_types=False, ) if include_migrations: context = so.ComparisonContext() for mig in s_migr.get_ordered_migrations(schema): diff.add(mig.as_create_delta(schema, context)) return ddl_text_from_delta(None, schema, diff, include_ext_version=not include_migrations) def sdl_text_from_schema( schema: s_schema.Schema, *, included_modules: Optional[Iterable[sn.Name]] = None, excluded_modules: Optional[Iterable[sn.Name]] = None, included_items: Optional[Iterable[sn.Name]] = None, excluded_items: Optional[Iterable[sn.Name]] = None, included_ref_classes: Iterable[so.ObjectMeta] = tuple(), include_module_ddl: bool = True, include_std_ddl: bool = False, ) -> str: diff = delta_schemas( schema_a=None, schema_b=schema, included_modules=included_modules, excluded_modules=excluded_modules, included_items=included_items, excluded_items=excluded_items, include_module_diff=include_module_ddl, include_std_diff=include_std_ddl, include_derived_types=False, linearize_delta=False, ) return sdl_text_from_delta(None, schema, diff) def descriptive_text_from_schema( schema: s_schema.Schema, *, included_modules: Optional[Iterable[sn.Name]] = None, excluded_modules: 
Optional[Iterable[sn.Name]] = None, included_items: Optional[Iterable[sn.Name]] = None, excluded_items: Optional[Iterable[sn.Name]] = None, included_ref_classes: Iterable[so.ObjectMeta] = tuple(), include_module_ddl: bool = True, include_std_ddl: bool = False, include_derived_types: bool = False, ) -> str: diff = delta_schemas( schema_a=None, schema_b=schema, included_modules=included_modules, excluded_modules=excluded_modules, included_items=included_items, excluded_items=excluded_items, include_module_diff=include_module_ddl, include_std_diff=include_std_ddl, include_derived_types=False, linearize_delta=False, descriptive_mode=True, ) return descriptive_text_from_delta( None, schema, diff, limit_ref_classes=included_ref_classes) ================================================ FILE: edb/schema/defines.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2020-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations # Maximum length of Postgres tenant ID. MAX_TENANT_ID_LENGTH = 10 # Maximum length of names that are reflected 1:1 to Postgres: MAX_NAME_LENGTH = 63 - MAX_TENANT_ID_LENGTH - 1 - 1 # ^ ^ ^ # max Postgres name len tenant_id scheme tenant_id separator # Maximum number of arguments supported by SQL functions. 
MAX_FUNC_ARG_COUNT = 100 EDGEDB_SUPERUSER = 'admin' EDGEDB_OLD_SUPERUSER = 'edgedb' EDGEDB_TEMPLATE_DB = '__edgedbtpl__' EDGEDB_SYSTEM_DB = '__edgedbsys__' EDGEDB_SPECIAL_DBS = {EDGEDB_TEMPLATE_DB, EDGEDB_SYSTEM_DB} ================================================ FILE: edb/schema/delta.py ================================================ # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( AbstractSet, Any, Callable, cast, ClassVar, Generator, Generic, Hashable, Iterable, Iterator, Mapping, NoReturn, Optional, overload, Self, Sequence, TypeVar, ) import collections import collections.abc import contextlib import functools import itertools import uuid from edb import errors from edb.common import adapter from edb.common import ast from edb.common import checked from edb.common import markup from edb.common import ordered from edb.common import parsing from edb.common import struct from edb.common import topological from edb.common import typing_inspect from edb.common import verutils from edb.edgeql import ast as qlast from edb.edgeql import compiler as qlcompiler from . import expr as s_expr from . import name as sn from . import objects as so from . import schema as s_schema from . 
import utils def delta_objects( old_in: Iterable[so.Object_T], new_in: Iterable[so.Object_T], sclass: type[so.Object_T], *, parent_confidence: Optional[float] = None, context: so.ComparisonContext, old_schema: s_schema.Schema, new_schema: s_schema.Schema, ) -> DeltaRoot: delta = DeltaRoot() # TODO: Previously, we attempted to do an optimization based on # computing a hash_criteria of each object, in order to discard # unchanged objects. Because hash_criteria returns values that # include schema objects, the optimization didn't work; it wasn't # cheap, so I removed it. But the general idea is sound and should # be revisited. old = {o.get_name(old_schema): o for o in old_in} new = {o.get_name(new_schema): o for o in new_in} # If an object exists with the same name in both the old and the # new schemas, we don't compare those objects to anything but # each other. This makes our runtime linear in most common cases, # though the worst case remains quadratic. # # This unfortunately means that we have trouble understanding # "chain renames" (Foo -> Bar, Bar -> Baz), which may be worth # addressing in the future, but we fail to understand that for # other reasons also. # Collect all the pairs of objects with the same name in both schemas. pairs = [ (new[k], o) for k, o in old.items() if k in new ] # Then collect the cross product of all the other objects. 
pairs.extend( itertools.product( [o for k, o in new.items() if k not in old], [o for k, o in old.items() if k not in new], ) ) full_matrix: list[tuple[so.Object_T, so.Object_T, float]] = [] # If there are any renames that are already decided on, honor those first renames_x: set[sn.Name] = set() renames_y: set[sn.Name] = set() for y in old.values(): rename = context.renames.get((type(y), y.get_name(old_schema))) if rename: renames_x.add(rename.new_name) renames_y.add(rename.classname) if context.guidance is not None: guidance = context.guidance # In these functions, we need to look at the actual object to # figure out the type instead of just using sclass because # sclass might be an abstract parent like Pointer. def can_create(obj: so.Object_T, name: sn.Name) -> bool: return (type(obj), name) not in guidance.banned_creations def can_alter( obj: so.Object_T, old_name: sn.Name, new_name: sn.Name ) -> bool: return ( (type(obj), (old_name, new_name)) not in guidance.banned_alters) def can_delete(obj: so.Object_T, name: sn.Name) -> bool: return (type(obj), name) not in guidance.banned_deletions else: def can_create(obj: so.Object_T, name: sn.Name) -> bool: return True def can_alter( obj: so.Object_T, old_name: sn.Name, new_name: sn.Name ) -> bool: return True def can_delete(obj: so.Object_T, name: sn.Name) -> bool: return True for x, y in pairs: x_name = x.get_name(new_schema) y_name = y.get_name(old_schema) similarity = y.compare( x, our_schema=old_schema, their_schema=new_schema, context=context, ) # If similarity for an alter is 1.0, that means there is no # actual change. We keep that, since otherwise we will generate # extra drop/create pairs when we are already done. 
if similarity < 1.0 and not can_alter(y, y_name, x_name): similarity = 0.0 full_matrix.append((x, y, similarity)) full_matrix.sort( key=lambda v: ( 1.0 - v[2], str(v[0].get_name(new_schema)), str(v[1].get_name(old_schema)), ), ) full_matrix_x = {} full_matrix_y = {} seen_x = set() seen_y = set() x_alter_variants: dict[so.Object_T, int] = collections.defaultdict(int) y_alter_variants: dict[so.Object_T, int] = collections.defaultdict(int) comparison_map: dict[so.Object_T, tuple[float, so.Object_T]] = {} comparison_map_y: dict[so.Object_T, tuple[float, so.Object_T]] = {} # Find the top similarity pairs for x, y, similarity in full_matrix: if x not in seen_x and y not in seen_y: comparison_map[x] = (similarity, y) comparison_map_y[y] = (similarity, x) seen_x.add(x) seen_y.add(y) if x not in full_matrix_x: full_matrix_x[x] = (similarity, y) if y not in full_matrix_y: full_matrix_y[y] = (similarity, x) if ( can_alter(y, y.get_name(old_schema), x.get_name(new_schema)) and full_matrix_x[x][0] != 1.0 and full_matrix_y[y][0] != 1.0 ): x_alter_variants[x] += 1 y_alter_variants[y] += 1 alters = [] alter_pairs = [] if comparison_map: if issubclass(sclass, so.InheritingObject): # Generate the diff from the top of the inheritance # hierarchy, since changes to parent objects may inform # how the delta in child objects is treated. 
order_x = cast( Iterable[so.Object_T], sort_by_inheritance( new_schema, cast(Iterable[so.InheritingObject], comparison_map), ), ) else: order_x = comparison_map for x in order_x: confidence, y = comparison_map[x] x_name = x.get_name(new_schema) y_name = y.get_name(old_schema) already_has = x_name == y_name and x_name not in renames_x if ( (0.6 < confidence < 1.0 and can_alter(y, y_name, x_name)) or ( (not can_create(x, x_name) or not can_delete(y, y_name)) and can_alter(y, y_name, x_name) ) or x_name in renames_x ): alter_pairs.append((x, y)) alter = y.as_alter_delta( other=x, context=context, self_schema=old_schema, other_schema=new_schema, confidence=confidence, ) # If we are basically certain about this alter, # make the confidence 1.0, unless child steps # are not confident. if not ( (x_alter_variants[x] > 1 or ( not already_has and can_create(x, x_name))) and parent_confidence != 1.0 ): cons = [ sub.get_annotation('confidence') for sub in alter.get_subcommands(type=ObjectCommand) ] confidence = min( [1.0, *[c for c in cons if c is not None]]) alter.set_annotation('confidence', confidence) alters.append(alter) elif confidence == 1.0: alter_pairs.append((x, y)) created = ordered.OrderedSet(new.values()) - {x for x, _ in alter_pairs} for x in created: x_name = x.get_name(new_schema) if can_create(x, x_name) and x_name not in renames_x: create = x.as_create_delta(schema=new_schema, context=context) if x_alter_variants[x] > 0 and parent_confidence != 1.0: confidence = full_matrix_x[x][0] else: confidence = 1.0 create.set_annotation('confidence', confidence) delta.add(create) delta.update(alters) deleted_order: Iterable[so.Object_T] deleted = ordered.OrderedSet(old.values()) - {y for _, y in alter_pairs} if issubclass(sclass, so.InheritingObject): deleted_order = sort_by_inheritance( # type: ignore[assignment] old_schema, cast(Iterable[so.InheritingObject], deleted), ) else: deleted_order = deleted for y in deleted_order: y_name = y.get_name(old_schema) if 
can_delete(y, y_name) and y_name not in renames_y: delete = y.as_delete_delta(schema=old_schema, context=context) if y_alter_variants[y] > 0 and parent_confidence != 1.0: confidence = full_matrix_y[y][0] else: confidence = 1.0 delete.set_annotation('confidence', confidence) delta.add(delete) return delta def sort_by_inheritance( schema: s_schema.Schema, objs: Iterable[so.InheritingObjectT], ) -> tuple[so.InheritingObjectT, ...]: graph = {} for x in objs: graph[x] = topological.DepGraphEntry( item=x, deps=ordered.OrderedSet(x.get_ancestors(schema).objects(schema)), extra=False, ) return topological.sort(graph, allow_unresolved=True) T = TypeVar("T") def sort_by_cross_refs_key( schema: s_schema.Schema, objs: Iterable[T], *, key: Callable[[T], so.Object], ) -> tuple[T, ...]: """Sort an iterable of objects according to cross-references between them. Return a topological ordering of a graph of objects joined by references. It is assumed that the graph has no cycles. """ graph = {} # We want to report longer cycles before trivial self references, # since cycles with (for example) computed properties will *also* # lead to self references (because the computed property gets # inlined, essentially).
self_ref = None for entry in objs: x = key(entry) referrers = schema.get_referrers(x) if x in referrers: self_ref = x graph[x] = topological.DepGraphEntry( item=entry, deps={ref for ref in referrers if not x.is_parent_ref(schema, ref) and x != ref}, extra=False, ) res = topological.sort(graph, allow_unresolved=True) if self_ref: raise topological.CycleError( f"{self_ref!r} refers to itself", item=self_ref) return res def sort_by_cross_refs[ObjectT: so.Object]( schema: s_schema.Schema, objs: Iterable[ObjectT], ) -> tuple[ObjectT, ...]: return sort_by_cross_refs_key(schema, objs, key=lambda x: x) class CommandMeta( adapter.Adapter, struct.MixedStructMeta, ): _astnode_map: dict[type[qlast.DDLOperation], type[Command]] = {} def __new__[CommandMeta_T: CommandMeta]( mcls: type[CommandMeta_T], name: str, bases: tuple[type, ...], dct: dict[str, Any], *, context_class: Optional[type[CommandContextToken[Command]]] = None, **kwargs: Any, ) -> CommandMeta_T: cls = super().__new__(mcls, name, bases, dct, **kwargs) if context_class is not None: cast(Command, cls)._context_class = context_class return cls def __init__( cls, name: str, bases: tuple[type, ...], clsdict: dict[str, Any], *, adapts: Optional[type] = None, **kwargs: Any, ) -> None: adapter.Adapter.__init__(cls, name, bases, clsdict, adapts=adapts) struct.MixedStructMeta.__init__(cls, name, bases, clsdict) astnodes = clsdict.get('astnode') if astnodes and not isinstance(astnodes, (list, tuple)): astnodes = [astnodes] if astnodes: cls.register_astnodes(astnodes) def register_astnodes( cls, astnodes: Iterable[type[qlast.DDLCommand]], ) -> None: mapping = type(cls)._astnode_map for astnode in astnodes: existing = mapping.get(astnode) if existing: msg = ('duplicate EdgeQL AST node to command mapping: ' + '{!r} is already declared for {!r}') raise TypeError(msg.format(astnode, existing)) mapping[astnode] = cast(type["Command"], cls) # We use _dummy_object for contexts where an instance of an object is # required by type
signatures, and the actual reference will be quickly # replaced by a real object. _dummy_object = so.Object( _private_id=uuid.UUID('C0FFEE00-C0DE-0000-0000-000000000000'), ) Command_T = TypeVar("Command_T", bound="Command") Command_T_co = TypeVar("Command_T_co", bound="Command", covariant=True) class Command( struct.MixedStruct, markup.MarkupCapableMixin, metaclass=CommandMeta, ): span = struct.Field(parsing.Span, default=None) canonical = struct.Field(bool, default=False) _context_class: Optional[type[CommandContextToken[Command]]] = None #: An optional list of commands that are prerequisites of this #: command and must run before any of the operations in this #: command or its subcommands in ops or caused_ops. before_ops: list[Command] #: An optional list of subcommands that are considered to be #: an integral part of this command. ops: list[Command] #: An optional list of commands that are _caused_ by this command, #: such as any propagation to children or any other side-effects #: that are not considered integral to this command.
caused_ops: list[Command] #: AlterObjectProperty lookup table for get|set_attribute_value _attrs: dict[str, AlterObjectProperty] #: AlterSpecialObjectField lookup table _special_attrs: dict[str, AlterSpecialObjectField[so.Object]] def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self.ops = [] self.before_ops = [] self.caused_ops = [] self.qlast: qlast.DDLOperation self._attrs = {} self._special_attrs = {} def dump(self) -> None: markup.dump(self) def copy(self: Self) -> Self: result = super().copy() result.before_ops = [op.copy() for op in self.before_ops] result.ops = [op.copy() for op in self.ops] result.caused_ops = [op.copy() for op in self.caused_ops] return result def get_verb(self) -> str: """Return a verb representing this command in infinitive form.""" raise NotImplementedError def get_friendly_description( self, *, parent_op: Optional[Command] = None, schema: Optional[s_schema.Schema] = None, object: Any = None, object_desc: Optional[str] = None, ) -> str: """Return a friendly description of this command in imperative mood. The result is used in error messages and other user-facing renderings of the command. """ raise NotImplementedError @classmethod def adapt(cls: type[Command_T], obj: Command) -> Command_T: result = obj.copy_with_class(cls) mcls = cast(CommandMeta, type(cls)) for op in obj.get_prerequisites(): result.add_prerequisite(mcls.adapt(op)) for op in obj.get_subcommands( include_prerequisites=False, include_caused=False, ): result.add(mcls.adapt(op)) for op in obj.get_caused(): result.add_caused(mcls.adapt(op)) return result def is_data_safe(self) -> bool: return False def get_required_user_input(self) -> list[dict[str, str]]: return [] def record_diff_annotations( self, *, schema: s_schema.Schema, orig_schema: Optional[s_schema.Schema], context: so.ComparisonContext, orig_object: Optional[so.Object], object: Optional[so.Object], ) -> None: """Record extra information on a delta obtained by diffing schemas. 
This provides an opportunity for a delta command to annotate itself in schema diff scenarios (i.e. migrations). Args: schema: Final schema of a migration. orig_schema: Original schema of a migration. context: Schema comparison context. """ pass def resolve_obj_collection( self, value: Any, schema: s_schema.Schema, ) -> Sequence[so.Object]: sequence: Sequence[so.Object] if isinstance(value, so.ObjectCollection): sequence = value.objects(schema) else: sequence = [] for v in value: if isinstance(v, so.Shell): val = v.resolve(schema) else: val = v sequence.append(val) return sequence def _resolve_attr_value( self, value: Any, fname: str, field: so.Field[Any], schema: s_schema.Schema, ) -> Any: ftype = field.type if isinstance(value, so.Shell): value = value.resolve(schema) else: if issubclass(ftype, so.ObjectDict): if isinstance(value, so.ObjectDict): items = dict(value.items(schema)) elif isinstance(value, collections.abc.Mapping): items = {} for k, v in value.items(): if isinstance(v, so.Shell): val = v.resolve(schema) else: val = v items[k] = val value = ftype.create(schema, items) elif issubclass(ftype, so.ObjectCollection): sequence = self.resolve_obj_collection(value, schema) value = ftype.create(schema, sequence) else: value = field.coerce_value(schema, value) return value def enumerate_attributes(self) -> tuple[str, ...]: return tuple(self._attrs) def _enumerate_attribute_cmds(self) -> tuple[AlterObjectProperty, ...]: return tuple(self._attrs.values()) def has_attribute_value(self, attr_name: str) -> bool: return attr_name in self._attrs or attr_name in self._special_attrs def _get_simple_attribute_set_cmd( self, attr_name: str, ) -> Optional[AlterObjectProperty]: return self._attrs.get(attr_name) def _get_attribute_set_cmd( self, attr_name: str, ) -> Optional[AlterObjectProperty]: cmd = self._get_simple_attribute_set_cmd(attr_name) if cmd is None: special_cmd = self._special_attrs.get(attr_name) if special_cmd is not None: cmd =
special_cmd._get_attribute_set_cmd(attr_name) return cmd def get_attribute_value( self, attr_name: str, ) -> Any: op = self._get_attribute_set_cmd(attr_name) if op is not None: return op.new_value else: return None def get_local_attribute_value( self, attr_name: str, ) -> Any: """Return the new value of field, if not inherited.""" op = self._get_attribute_set_cmd(attr_name) if op is not None and not op.new_inherited: return op.new_value else: return None def get_orig_attribute_value( self, attr_name: str, ) -> Any: op = self._get_attribute_set_cmd(attr_name) if op is not None: return op.old_value else: return None def is_attribute_inherited( self, attr_name: str, ) -> bool: op = self._get_attribute_set_cmd(attr_name) if op is not None: return op.new_inherited else: return False def is_attribute_computed( self, attr_name: str, ) -> bool: op = self._get_attribute_set_cmd(attr_name) if op is not None: return op.new_computed else: return False def get_attribute_span( self, attr_name: str, ) -> Optional[parsing.Span]: op = self._get_attribute_set_cmd(attr_name) if op is not None: return op.span else: return None def set_attribute_value( self, attr_name: str, value: Any, *, orig_value: Any = None, inherited: bool = False, orig_inherited: Optional[bool] = None, computed: bool = False, from_default: bool = False, orig_computed: Optional[bool] = None, span: Optional[parsing.Span] = None, ) -> Command: orig_op = op = self._get_simple_attribute_set_cmd(attr_name) if op is None: op = AlterObjectProperty(property=attr_name, new_value=value) else: op.new_value = value if orig_inherited is None: orig_inherited = inherited op.new_inherited = inherited op.old_inherited = orig_inherited if orig_computed is None: orig_computed = computed op.new_computed = computed op.old_computed = orig_computed op.from_default = from_default if span is not None: op.span = span if orig_value is not None: op.old_value = orig_value if orig_op is None: self.add(op) return op def discard_attribute(self, 
attr_name: str) -> None: op = self._get_attribute_set_cmd(attr_name) if op is not None: self.discard(op) def __iter__(self) -> NoReturn: raise TypeError(f'{type(self)} object is not iterable') @overload def get_subcommands( self, *, type: type[Command_T], metaclass: Optional[type[so.Object]] = None, exclude: type[Command] | tuple[type[Command], ...] | None = None, include_prerequisites: bool = True, include_caused: bool = True, ) -> tuple[Command_T, ...]: ... @overload def get_subcommands( self, *, type: None = None, metaclass: Optional[type[so.Object]] = None, exclude: type[Command] | tuple[type[Command], ...] | None = None, include_prerequisites: bool = True, include_caused: bool = True, ) -> tuple[Command, ...]: ... def get_subcommands( self, *, type: type[Command_T] | None = None, metaclass: Optional[type[so.Object]] = None, exclude: type[Command] | tuple[type[Command], ...] | None = None, include_prerequisites: bool = True, include_caused: bool = True, ) -> tuple[Command, ...]: ops: Iterable[Command] = self.ops if include_prerequisites: ops = itertools.chain(self.before_ops, ops) if include_caused: ops = itertools.chain(ops, self.caused_ops) filters = [] if type is not None: t = type filters.append(lambda i: isinstance(i, t)) if exclude is not None: ex = exclude filters.append(lambda i: not isinstance(i, ex)) if metaclass is not None: mcls = metaclass filters.append( lambda i: ( isinstance(i, ObjectCommand) and issubclass(i.get_schema_metaclass(), mcls) ) ) if filters: return tuple(filter(lambda i: all(f(i) for f in filters), ops)) else: return tuple(ops) @overload def get_prerequisites( self, *, type: type[Command_T], ) -> tuple[Command_T, ...]: ... @overload def get_prerequisites( self, *, type: None = None, ) -> tuple[Command, ...]: ... 
def get_prerequisites( self, *, type: type[Command_T] | None = None, ) -> tuple[Command, ...]: if type is not None: t = type return tuple(filter(lambda i: isinstance(i, t), self.before_ops)) else: return tuple(self.before_ops) @overload def get_caused( self, *, type: type[Command_T], ) -> tuple[Command_T, ...]: ... @overload def get_caused( self, *, type: None = None, ) -> tuple[Command, ...]: ... def get_caused( self, *, type: type[Command_T] | None = None, ) -> tuple[Command, ...]: if type is not None: t = type return tuple(filter(lambda i: isinstance(i, t), self.caused_ops)) else: return tuple(self.caused_ops) def has_subcommands(self) -> bool: return bool(self.ops) or bool(self.before_ops) or bool(self.caused_ops) def get_nonattr_subcommand_count(self) -> int: attr_cmds = (AlterObjectProperty, AlterSpecialObjectField) return len(self.get_subcommands(exclude=attr_cmds)) def get_nonattr_special_subcommand_count(self) -> int: attr_cmds = (AlterObjectProperty,) return len(self.get_subcommands(exclude=attr_cmds)) def prepend_prerequisite(self, command: Command) -> None: if isinstance(command, CommandGroup): for op in reversed(command.get_subcommands()): self.prepend_prerequisite(op) else: self.before_ops.insert(0, command) def add_prerequisite(self, command: Command) -> None: if isinstance(command, CommandGroup): self.before_ops.extend(command.get_subcommands()) else: self.before_ops.append(command) def prepend_caused(self, command: Command) -> None: if isinstance(command, CommandGroup): for op in reversed(command.get_subcommands()): self.prepend_caused(op) else: self.caused_ops.insert(0, command) def add_caused(self, command: Command) -> None: if isinstance(command, CommandGroup): self.caused_ops.extend(command.get_subcommands()) else: self.caused_ops.append(command) def prepend(self, command: Command) -> None: if isinstance(command, CommandGroup): for op in reversed(command.get_subcommands()): self.prepend(op) else: if isinstance(command, AlterObjectProperty): 
self._attrs[command.property] = command elif isinstance(command, AlterSpecialObjectField): self._special_attrs[command._field] = command self.ops.insert(0, command) def add(self, command: Command) -> None: if isinstance(command, CommandGroup): self.update(command.get_subcommands()) else: if isinstance(command, AlterObjectProperty): self._attrs[command.property] = command elif isinstance(command, AlterSpecialObjectField): self._special_attrs[command._field] = command self.ops.append(command) def update(self, commands: Iterable[Command]) -> None: for command in commands: self.add(command) def replace(self, existing: Command, new: Command) -> None: # type: ignore try: i = self.ops.index(existing) self.ops[i] = new return except ValueError: pass try: i = self.before_ops.index(existing) self.before_ops[i] = new return except ValueError: pass i = self.caused_ops.index(existing) self.caused_ops[i] = new def replace_all(self, commands: Iterable[Command]) -> None: self.ops.clear() self._attrs.clear() self._special_attrs.clear() self.update(commands) def discard(self, command: Command) -> None: try: self.ops.remove(command) except ValueError: pass try: self.before_ops.remove(command) except ValueError: pass try: self.caused_ops.remove(command) except ValueError: pass if isinstance(command, AlterObjectProperty): self._attrs.pop(command.property) elif isinstance(command, AlterSpecialObjectField): self._special_attrs.pop(command._field) def apply( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: return schema def apply_prerequisites( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: for op in self.get_prerequisites(): schema = op.apply(schema, context) return schema def apply_subcommands( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: for op in self.get_subcommands( include_prerequisites=False, include_caused=False, ): if not isinstance(op, AlterObjectProperty): schema = op.apply(schema, 
context=context) return schema def apply_caused( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: for op in self.get_caused(): schema = op.apply(schema, context) return schema def get_ast( self, schema: s_schema.Schema, context: CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: context_class = type(self).get_context_class() assert context_class is not None with context(context_class(schema=schema, op=self)): return self._get_ast(schema, context, parent_node=parent_node) def _get_ast( self, schema: s_schema.Schema, context: CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: raise NotImplementedError def _log_all_renames(self, context: CommandContext) -> None: if isinstance(self, RenameObject): context.early_renames[self.classname] = self.new_name for subcmd in self.get_subcommands(): subcmd._log_all_renames(context) @classmethod def get_orig_expr_text( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, name: str, ) -> Optional[str]: orig_text_expr = qlast.get_ddl_field_value(astnode, f'orig_{name}') if orig_text_expr: orig_text = qlcompiler.evaluate_ast_to_python_val( orig_text_expr, schema=schema) else: orig_text = None return orig_text # type: ignore @classmethod def command_for_ast_node( cls, astnode: qlast.DDLOperation, schema: s_schema.Schema, context: CommandContext, ) -> type[Command]: return cls @classmethod def _modaliases_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: CommandContext, ) -> dict[Optional[str], str]: modaliases = {} if isinstance(astnode, qlast.DDLCommand) and astnode.aliases: for alias in astnode.aliases: if isinstance(alias, qlast.ModuleAliasDecl): modaliases[alias.alias] = alias.module return modaliases @classmethod def localnames_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: CommandContext, ) -> set[str]: localnames: set[str] = set() 
if isinstance(astnode, qlast.DDLCommand) and astnode.aliases: for alias in astnode.aliases: if isinstance(alias, qlast.AliasedExpr): localnames.add(alias.alias) return localnames @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: CommandContext, ) -> Command: cmd = cls._cmd_from_ast(schema, astnode, context) cmd.span = astnode.span cmd.qlast = astnode ctx = context.current() if ctx is not None and type(ctx) is cls.get_context_class(): ctx.op = cmd if astnode.commands: for subastnode in astnode.commands: subcmd = compile_ddl(schema, subastnode, context=context) if subcmd is not None: cmd.add(subcmd) return cmd @classmethod def _cmd_from_ast( cls: type[Command_T], schema: s_schema.Schema, astnode: qlast.DDLOperation, context: CommandContext, ) -> Command: return cls() @classmethod def as_markup(cls, self: Command, *, ctx: markup.Context) -> markup.Markup: node = markup.elements.lang.TreeNode(name=str(self)) def _markup(dd: Command) -> None: if isinstance(dd, AlterObjectProperty): diff = markup.elements.doc.ValueDiff( before=repr(dd.old_value), after=repr(dd.new_value)) if dd.new_inherited: diff.comment = 'inherited' elif dd.new_computed: diff.comment = 'computed' node.add_child(label=dd.property, node=diff) else: node.add_child(node=markup.serialize(dd, ctx=ctx)) prereqs = self.get_prerequisites() caused = self.get_caused() if prereqs: node.add_child( node=markup.elements.doc.Marker(text='prerequsites')) for dd in prereqs: _markup(dd) if subs := self.get_subcommands( include_prerequisites=False, include_caused=False, ): # Only label regular subcommands if there are prereqs or # caused actions, and so there is room for confusion if prereqs or caused: node.add_child(node=markup.elements.doc.Marker(text='main')) for dd in subs: _markup(dd) if caused: node.add_child(node=markup.elements.doc.Marker(text='caused')) for dd in caused: _markup(dd) return node @classmethod def get_context_class( cls: type[Command_T], ) -> 
Optional[type[CommandContextToken[Command_T]]]: return cls._context_class # type: ignore @classmethod def get_context_class_or_die( cls: type[Command_T], ) -> type[CommandContextToken[Command_T]]: ctxcls = cls.get_context_class() if ctxcls is None: raise RuntimeError(f'context class not defined for {cls}') return ctxcls def formatfields( self, formatter: str = 'str', ) -> Iterator[tuple[str, str]]: """Return an iterator over fields formatted using `formatter`.""" for name, field in self.__class__._fields.items(): value = getattr(self, name) default = field.default formatter_obj = field.formatters.get(formatter) if formatter_obj and value != default: yield (name, formatter_obj(value)) class Nop(Command): pass # Similarly to _dummy_object, we use _dummy_command for places where # the typing requires an object, but we don't have it just yet. _dummy_command = Command() CommandList = checked.CheckedList[Command] class CommandGroup(Command): def apply( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: schema = self.apply_prerequisites(schema, context) schema = self.apply_subcommands(schema, context) schema = self.apply_caused(schema, context) return schema class CommandContextToken(Generic[Command_T_co]): # noqa: UP046 original_schema: s_schema.Schema op: Command_T_co modaliases: Mapping[Optional[str], str] localnames: AbstractSet[str] inheritance_merge: Optional[bool] inheritance_refdicts: Optional[AbstractSet[str]] mark_derived: Optional[bool] enable_recursion: Optional[bool] transient_derivation: Optional[bool] # Whether to skip creating @source/@target properties on links. # Typically this is set whenever transient_derivation is, # (so that it doesn't get set on transient views, etc), # except when compiling aliases where we need to produce # fully populated links. # This is a surprisingly valuable optimization (25% on a big schema).
slim_links: Optional[bool] def __init__( self, schema: s_schema.Schema, op: Command_T_co, *, modaliases: Optional[Mapping[Optional[str], str]] = None, # localnames are the names defined locally via with block or # as function parameters and should not be fully-qualified localnames: AbstractSet[str] = frozenset(), ) -> None: self.original_schema = schema self.op = op self.modaliases = modaliases if modaliases is not None else {} self.localnames = localnames self.inheritance_merge = None self.inheritance_refdicts = None self.mark_derived = None self.enable_recursion = None self.transient_derivation = None self.slim_links = None class CommandContextWrapper(Generic[Command_T_co]): # noqa: UP046 def __init__( self, context: CommandContext, token: CommandContextToken[Command_T_co], ) -> None: self.context = context self.token = token def __enter__(self) -> CommandContextToken[Command_T_co]: self.context.push(self.token) return self.token def __exit__( self, exc_type: type[Exception], exc_value: Exception, traceback: Any, ) -> None: self.context.pop() class CommandContext: def __init__( self, *, schema: Optional[s_schema.Schema] = None, modaliases: Optional[Mapping[Optional[str], str]] = None, localnames: AbstractSet[str] = frozenset(), declarative: bool = False, stdmode: bool = False, testmode: bool = False, internal_schema_mode: bool = False, disable_dep_verification: bool = False, store_migration_sdl: bool = False, descriptive_mode: bool = False, schema_object_ids: Optional[ Mapping[tuple[sn.Name, Optional[str]], uuid.UUID] ] = None, backend_runtime_params: Optional[Any] = None, compat_ver: Optional[verutils.Version] = None, include_ext_version: bool = True, ) -> None: self.stack: list[CommandContextToken[Command]] = [] self._cache: dict[Hashable, Any] = {} self._values: dict[Hashable, Any] = {} self.declarative = declarative self.schema = schema self._modaliases = modaliases if modaliases is not None else {} self._localnames = localnames self.stdmode = stdmode 
self.stable_ids = stdmode self.internal_schema_mode = internal_schema_mode self.testmode = testmode self.descriptive_mode = descriptive_mode self.disable_dep_verification = disable_dep_verification self.store_migration_sdl = store_migration_sdl self.renames: dict[sn.Name, sn.Name] = {} self.early_renames: dict[sn.Name, sn.Name] = {} self.renamed_objs: set[so.Object] = set() self.change_log: dict[tuple[type[so.Object], str], set[so.Object]] = ( collections.defaultdict(set)) self.schema_object_ids = schema_object_ids self.backend_runtime_params = backend_runtime_params self.affected_finalization: dict[ Command, list[tuple[Command, AlterObject[so.Object], list[str]]], ] = collections.defaultdict(list) self.compat_ver = compat_ver self.include_ext_version = include_ext_version @property def modaliases(self) -> Mapping[Optional[str], str]: maps = [t.modaliases for t in reversed(self.stack)] maps.append(self._modaliases) return collections.ChainMap(*maps) # type: ignore @property def localnames(self) -> set[str]: ign: set[str] = set() for ctx in reversed(self.stack): ign.update(ctx.localnames) ign.update(self._localnames) return ign @property def inheritance_merge(self) -> Optional[bool]: for ctx in reversed(self.stack): if ctx.inheritance_merge is not None: return ctx.inheritance_merge return None @property def mark_derived(self) -> Optional[bool]: for ctx in reversed(self.stack): if ctx.mark_derived is not None: return ctx.mark_derived return None @property def inheritance_refdicts(self) -> Optional[AbstractSet[str]]: for ctx in reversed(self.stack): if ctx.inheritance_refdicts is not None: return ctx.inheritance_refdicts return None @property def enable_recursion(self) -> bool: for ctx in reversed(self.stack): if ctx.enable_recursion is not None: return ctx.enable_recursion return True @property def transient_derivation(self) -> bool: for ctx in reversed(self.stack): if ctx.transient_derivation is not None: return ctx.transient_derivation return False @property def 
slim_links(self) -> bool: return any(ctx.slim_links for ctx in self.stack) @property def canonical(self) -> bool: return any(ctx.op.canonical for ctx in self.stack) def in_deletion(self, offset: int = 0) -> bool: """Return True if any object is being deleted in this context. :param offset: The offset in the context stack to start looking at. :returns: True if any object is being deleted in this context starting from *offset* in the stack. """ return any(isinstance(ctx.op, DeleteObject) for ctx in self.stack[:-offset if offset else None]) def is_deleting(self, obj: so.Object) -> bool: """Return True if *obj* is being deleted in this context. :param obj: The object in question. :returns: True if *obj* is being deleted in this context. """ return any(isinstance(ctx.op, DeleteObject) and ctx.op.scls == obj for ctx in self.stack) def is_creating(self, obj: so.Object) -> bool: """Return True if *obj* is being created in this context. :param obj: The object in question. :returns: True if *obj* is being created in this context. """ return any(isinstance(ctx.op, CreateObject) and getattr(ctx.op, 'scls', None) == obj for ctx in self.stack) def is_altering(self, obj: so.Object) -> bool: """Return True if *obj* is being altered in this context. :param obj: The object in question. :returns: True if *obj* is being altered in this context. 
""" return any(isinstance(ctx.op, AlterObject) and getattr(ctx.op, 'scls', None) == obj for ctx in self.stack) def push(self, token: CommandContextToken[Command]) -> None: self.stack.append(token) def pop(self) -> CommandContextToken[Command]: return self.stack.pop() def get_referrer_name( self, referrer_ctx: CommandContextToken[ObjectCommand[so.Object]], ) -> sn.QualName: referrer_name = referrer_ctx.op.classname renamed = self.early_renames.get(referrer_name) if renamed: referrer_name = renamed else: renamed = self.renames.get(referrer_name) if renamed: referrer_name = renamed assert isinstance(referrer_name, sn.QualName) return referrer_name @overload def get( self, cls: type[ObjectCommandContext[so.Object_T]], ) -> Optional[ObjectCommandContext[so.Object_T]]: ... @overload def get( self, cls: type[Command_T] | type[CommandContextToken[Command_T]], ) -> Optional[CommandContextToken[Command_T]]: ... def get( self, cls: type[Command_T] | type[CommandContextToken[Command_T]], ) -> Optional[CommandContextToken[Command_T]]: ctxcls: Any if issubclass(cls, Command): ctxcls = cls.get_context_class() assert ctxcls is not None else: ctxcls = cls for item in reversed(self.stack): if isinstance(item, ctxcls): return item # type: ignore return None def get_ancestor( self, cls: type[Command] | type[CommandContextToken[Command]], op: Optional[Command] = None, ) -> Optional[CommandContextToken[Command]]: if issubclass(cls, Command): ctxcls = cls.get_context_class() assert ctxcls is not None else: ctxcls = cls if op is not None: for item in list(reversed(self.stack)): if isinstance(item, ctxcls) and item.op is not op: return item else: for item in list(reversed(self.stack))[1:]: if isinstance(item, ctxcls): return item return None def get_topmost_ancestor( self, cls: type[Command] | type[CommandContextToken[Command]], ) -> Optional[CommandContextToken[Command]]: if issubclass(cls, Command): ctxcls = cls.get_context_class() assert ctxcls is not None else: ctxcls = cls for item in 
self.stack: if isinstance(item, ctxcls): return item return None def top(self) -> CommandContextToken[Command]: if self.stack: return self.stack[0] else: raise KeyError('command context stack is empty') def current(self) -> CommandContextToken[Command]: if self.stack: return self.stack[-1] else: raise KeyError('command context stack is empty') def parent(self) -> Optional[CommandContextToken[Command]]: if len(self.stack) > 1: return self.stack[-2] else: return None def copy(self) -> CommandContext: ctx = CommandContext() ctx.stack = self.stack[:] return ctx def cache_value(self, key: Hashable, value: Any) -> None: self._cache[key] = value def get_cached(self, key: Hashable) -> Any: return self._cache.get(key) def drop_cache(self, key: Hashable) -> None: self._cache.pop(key, None) def store_value(self, key: Hashable, value: Any) -> None: self._values[key] = value def get_value(self, key: Hashable) -> Any: return self._values.get(key) @contextlib.contextmanager def suspend_dep_verification(self) -> Iterator[CommandContext]: dep_ver = self.disable_dep_verification self.disable_dep_verification = True try: yield self finally: self.disable_dep_verification = dep_ver def __call__( self, token: CommandContextToken[Command_T], ) -> CommandContextWrapper[Command_T]: return CommandContextWrapper(self, token) def compat_ver_is_before( self, ver: tuple[int, int, verutils.VersionStage, int], ) -> bool: return self.compat_ver is not None and self.compat_ver < ver class ContextStack: def __init__( self, contexts: Iterable[CommandContextWrapper[Command]], ) -> None: self._contexts = list(contexts) def push(self, ctx: CommandContextWrapper[Command]) -> None: self._contexts.append(ctx) def pop(self) -> None: self._contexts.pop() @contextlib.contextmanager def __call__(self) -> Generator[None, None, None]: with contextlib.ExitStack() as stack: for ctx in self._contexts: stack.enter_context(ctx) # type: ignore yield class DeltaRootContext(CommandContextToken["DeltaRoot"]): pass class 
DeltaRoot(CommandGroup, context_class=DeltaRootContext): def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self.new_types: set[uuid.UUID] = set() self.warnings: list[errors.EdgeDBError] = [] @classmethod def from_commands(cls, *cmds: Command) -> DeltaRoot: delta = DeltaRoot() delta.update(cmds) return delta def apply( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: with context(DeltaRootContext(schema=schema, op=self)): schema = self.apply_prerequisites(schema, context) schema = self.apply_subcommands(schema, context) schema = self.apply_caused(schema, context) return schema def is_data_safe(self) -> bool: return all( subcmd.is_data_safe() for subcmd in self.get_subcommands() ) class Query(Command): """A special delta command representing a non-DDL query. These are found in migrations. """ astnode = qlast.DDLQuery expr = struct.Field(s_expr.Expression) @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: CommandContext, ) -> Command: assert isinstance(astnode, qlast.DDLQuery) return cls( span=astnode.span, expr=s_expr.Expression.from_ast( astnode.query, schema=schema, modaliases=context.modaliases, localnames=context.localnames, ), ) @classmethod def as_markup(cls, self: Command, *, ctx: markup.Context) -> markup.Markup: node = super().as_markup(self, ctx=ctx) assert isinstance(node, markup.elements.lang.TreeNode) assert isinstance(self, Query) qltext = self.expr.text node.add_child(node=markup.elements.lang.MultilineString(str=qltext)) return node def apply( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: schema = super().apply(schema, context) if not self.expr.is_compiled(): self.expr = self.expr.compiled( schema, options=qlcompiler.CompilerOptions( modaliases=context.modaliases, apply_query_rewrites=False, ), context=context, ) return schema _command_registry: dict[ tuple[str, type[so.Object]], type[ObjectCommand[so.Object]] ] = {} 
def get_object_command_class[Command_T: Command]( cmdtype: type[Command_T], schema_metaclass: type[so.Object], ) -> Optional[type[Command_T]]: assert issubclass(cmdtype, ObjectCommand) return _command_registry.get( # type: ignore (cmdtype._delta_action, schema_metaclass), ) def get_object_command_class_or_die[Command_T: Command]( cmdtype: type[Command_T], schema_metaclass: type[so.Object], ) -> type[Command_T]: cmdcls = get_object_command_class(cmdtype, schema_metaclass) if cmdcls is None: raise TypeError(f'missing {cmdtype.__name__} implementation ' f'for {schema_metaclass.__name__}') return cmdcls class ObjectCommand[Object_T: so.Object](Command): """Base class for all Object-related commands.""" #: Full name of the object this command operates on. classname = struct.Field(sn.Name) #: An optional set of values necessary to render the command in DDL. ddl_identity = struct.Field( dict, # type: ignore default=None, ) #: An optional dict of metadata annotations for this command. annotations = struct.Field( dict, # type: ignore default=None, ) #: Auxiliary object information that might be necessary to process #: this command, derived from object fields. aux_object_data = struct.Field( dict, # type: ignore default=None, ) #: When this command is produced by a breakup of a larger command #: subtree, *orig_cmd_type* would contain the type of the original #: command. orig_cmd_type = struct.Field( CommandMeta, default=None, ) #: Is this from an expression change being propagated. #: FIXME: Every place this is used is a hack and some are bugs.
from_expr_propagation = struct.Field(bool, default=False) scls: Object_T _delta_action: ClassVar[str] _schema_metaclass: ClassVar[ # type: ignore Optional[type[Object_T]] ] = None astnode: ClassVar[type[qlast.DDLOperation] | list[type[qlast.DDLOperation]]] def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: # Check if the command subclass has been parametrized with # a concrete schema object class, and if so, record the # argument to be made available via get_schema_metaclass(). super().__init_subclass__(*args, **kwargs) generic_bases = typing_inspect.get_generic_bases(cls) mcls: Optional[type[so.Object]] = None for gb in generic_bases: base_origin = typing_inspect.get_origin(gb) # Find the [Type] base, where ObjectCommand # is any ObjectCommand subclass. if ( base_origin is not None and issubclass(base_origin, ObjectCommand) ): args = typing_inspect.get_args(gb) if len(args) != 1: raise AssertionError( 'expected only one argument to ObjectCommand generic') arg_0 = args[0] if not typing_inspect.is_typevar(arg_0): assert issubclass(arg_0, so.Object) if not arg_0.is_abstract(): mcls = arg_0 break if mcls is not None: existing = getattr(cls, '_schema_metaclass', None) if existing is not None and existing is not mcls: raise TypeError( f'cannot redefine schema class of {cls.__name__} to ' f'{mcls.__name__}: a superclass has already defined it as ' f'{existing.__name__}' ) cls._schema_metaclass = mcls # If this is a command adapter rather than the actual # command, skip the command class registration. 
if not cls.has_adaptee(): delta_action = getattr(cls, '_delta_action', None) schema_metaclass = getattr(cls, '_schema_metaclass', None) if schema_metaclass is not None and delta_action is not None: key = delta_action, schema_metaclass cmdcls = _command_registry.get(key) if cmdcls is not None: raise TypeError( f'Action {cls._delta_action!r} for ' f'{schema_metaclass} is already claimed by {cmdcls}' ) _command_registry[key] = cls # type: ignore @classmethod def _classname_from_ast( cls, schema: s_schema.Schema, astnode: qlast.ObjectDDL, context: CommandContext, ) -> sn.Name: return sn.UnqualName(astnode.name.name) @classmethod def _cmd_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: CommandContext, ) -> ObjectCommand[Object_T]: assert isinstance(astnode, qlast.ObjectDDL), 'expected ObjectDDL' classname = cls._classname_from_ast(schema, astnode, context) return cls(classname=classname) def is_data_safe(self) -> bool: if self.get_schema_metaclass()._data_safe: return True else: return all( subcmd.is_data_safe() for subcmd in self.get_subcommands() ) def get_required_user_input(self) -> list[dict[str, str]]: result: list[dict[str, str]] = [] if ann := self.get_annotation('required_input'): result.append(ann) for cmd in self.get_subcommands(): result.extend(cmd.get_required_user_input()) return result def get_friendly_description( self, *, parent_op: Optional[Command] = None, schema: Optional[s_schema.Schema] = None, object: Any = None, object_desc: Optional[str] = None, ) -> str: """Return a friendly description of this command in imperative mood. The result is used in error messages and other user-facing renderings of the command. 
""" object_desc = self.get_friendly_object_name_for_description( parent_op=parent_op, schema=schema, object=object, object_desc=object_desc, ) return f'{self.get_verb()} {object_desc}' def get_user_prompt( self, *, parent_op: Optional[Command] = None, ) -> tuple[CommandKey, str]: """Return a human-friendly prompt describing this operation.""" # The prompt is determined by the *innermost* subcommand as # long as all its parents have exactly one child. The tree # traversal stops on fragments and CreateObject commands, # since there is no point to prompt about the creation of # object innards. if ( not isinstance(self, AlterObjectFragment) and ( not isinstance(self, CreateObject) and ( self.orig_cmd_type is None or not issubclass( self.orig_cmd_type, CreateObject ) ) ) ): from . import referencing as s_referencing subcommands = self.get_subcommands( type=ObjectCommand, exclude=(AlterObjectProperty, s_referencing.AlterOwned), ) if len(subcommands) == 1: subcommand = subcommands[0] if isinstance(subcommand, AlterObjectFragment): return subcommand.get_user_prompt(parent_op=parent_op) else: return subcommand.get_user_prompt(parent_op=self) desc = self.get_friendly_description(parent_op=parent_op) prompt_text = f'did you {desc}?' 
prompt_id = get_object_command_key(self) assert prompt_id is not None return prompt_id, prompt_text def validate_object( self, schema: s_schema.Schema, context: CommandContext, ) -> None: pass @classmethod def get_parent_op( cls, context: CommandContext, ) -> ObjectCommand[so.Object]: parent = context.parent() if parent is None: raise AssertionError(f'{cls!r} has no parent context') op = parent.op assert isinstance(op, ObjectCommand) return op @classmethod @functools.lru_cache() def _get_special_handler( cls, field_name: str, ) -> Optional[type[AlterSpecialObjectField[so.Object]]]: if ( issubclass(cls, AlterObjectOrFragment) and not issubclass(cls, AlterSpecialObjectField) ): schema_cls = cls.get_schema_metaclass() return get_special_field_alter_handler(field_name, schema_cls) else: return None def set_attribute_value( self, attr_name: str, value: Any, *, orig_value: Any = None, inherited: bool = False, orig_inherited: Optional[bool] = None, computed: bool = False, orig_computed: Optional[bool] = None, from_default: bool = False, span: Optional[parsing.Span] = None, disallow_special: bool = False, ) -> Command: special = ( type(self)._get_special_handler(attr_name) if not disallow_special else None ) op = self._get_attribute_set_cmd(attr_name) top_op: Optional[Command] = None if orig_inherited is None: orig_inherited = inherited if orig_computed is None: orig_computed = computed if op is None: op = AlterObjectProperty( property=attr_name, new_value=value, old_value=orig_value, new_inherited=inherited, old_inherited=orig_inherited, new_computed=computed, old_computed=orig_computed, from_default=from_default, span=span, ) top_op = self._special_attrs.get(attr_name) if top_op is None and special is not None: top_op = special(classname=self.classname) self.add(top_op) if top_op: top_op.add(op) else: self.add(op) top_op = op return top_op else: op.new_value = value op.new_inherited = inherited op.old_inherited = orig_inherited op.new_computed = computed op.old_computed 
= orig_computed op.from_default = from_default if span is not None: op.span = span if orig_value is not None: op.old_value = orig_value return op def _fix_referencing_expr_after_rename( self, schema: s_schema.Schema, cmd: ObjectCommand[so.Object], fn: str, context: CommandContext, expr: s_expr.Expression, ) -> s_expr.Expression: if isinstance(self, RenameObject): new_name = self.new_name elif (fops := self.get_subcommands(type=RenameObject)): new_name = fops[0].new_name else: raise AssertionError("not a rename!") # Recompile the expression with reference tracking on so that we # can clean up the ast. field = cmd.get_schema_metaclass().get_field(fn) compiled = cmd.compile_expr_field( schema, context, field, expr, track_schema_ref_exprs=True) assert compiled.irast.schema_ref_exprs is not None # Now that the compilation is done, try to do the fixup. new_shortname = sn.shortname_from_fullname(new_name) old_shortname = sn.shortname_from_fullname(self.classname).name for ref in compiled.irast.schema_ref_exprs.get(self.scls, []): assert isinstance( ref, (qlast.ObjectRef, qlast.FunctionCall, qlast.Ptr) ), f"only support object refs and func calls but got {ref}" if isinstance(ref, qlast.FunctionCall): ref.func = ((new_shortname.module, new_shortname.name) if isinstance(new_shortname, sn.QualName) else new_shortname.name) elif ( isinstance(ref, (qlast.Ptr, qlast.ObjectRef)) and ref.name == old_shortname ): ref.name = new_shortname.name if ( isinstance(new_shortname, sn.QualName) and isinstance(ref, qlast.ObjectRef) and new_shortname.module != "__" ): ref.module = new_shortname.module # say as_fragment=True as a hack to avoid renormalizing it out = s_expr.Expression.from_ast( compiled.parse(), schema, modaliases={}, as_fragment=True) return out def _propagate_if_expr_refs( self, schema: s_schema.Schema, context: CommandContext, *, action: str, include_self: bool=True, include_ancestors: bool=False, extra_refs: Optional[dict[so.Object, list[str]]]=None, filter: type[so.Object] 
| tuple[type[so.Object], ...] | None = None, metadata_only: bool=False, ) -> s_schema.Schema: # If we are a rename or contain a rename, we need to fix up expressions if ( isinstance(self, RenameObject) or self.get_subcommands(type=RenameObject) ): fixer = self._fix_referencing_expr_after_rename else: fixer = None scls = self.scls expr_refs: dict[so.Object, list[str]] = {} if include_self: expr_refs.update(s_expr.get_expr_referrers(schema, scls)) if include_ancestors and isinstance(scls, so.InheritingObject): for anc in scls.get_ancestors(schema).objects(schema): expr_refs.update(s_expr.get_expr_referrers(schema, anc)) if extra_refs: expr_refs.update(extra_refs) if filter is not None: expr_refs = { k: v for k, v in expr_refs.items() if isinstance(k, filter)} if expr_refs: try: sorted_ref_objs = sort_by_cross_refs(schema, expr_refs.keys()) except topological.CycleError as e: assert e.item is not None item_vn = e.item.get_verbosename(schema, with_parent=True) if e.path: # Recursion involving more than one schema object. rec_vn = e.path[-1].get_verbosename( schema, with_parent=True) # Sort for output determinism vn1, vn2 = sorted([rec_vn, item_vn]) msg = ( f'definition dependency cycle between {vn1} and {vn2}' ) else: # A single schema object with a recursive definition. msg = f'{item_vn} is defined recursively' raise errors.InvalidDefinitionError(msg) from e ref_desc = [] for ref in sorted_ref_objs: cmd_drop: Command cmd_create: Command fns = expr_refs[ref] this_ref_desc = [] for fn in fns: if fn == 'expr': fdesc = 'expression' else: sfn = type(ref).get_field(fn).sname fdesc = f"{sfn.replace('_', ' ')} expression" vn = ref.get_verbosename(schema, with_parent=True) this_ref_desc.append(f'{fdesc} of {vn}') # Alter the affected entity to change the body to # a dummy version (removing the dependency) and # then reset the body to original expression. 
delta_drop, cmd_drop, _ = ref.init_delta_branch( schema, context, cmdtype=AlterObject) delta_create, cmd_create, ctx_stack = ref.init_delta_branch( schema, context, cmdtype=AlterObject, possible_parent=self, # type: ignore ) cmd_drop.from_expr_propagation = True cmd_create.from_expr_propagation = True # Mark it metadata_only so that if it actually gets # applied, only the metadata is changed but not # the real underlying schema. if metadata_only: cmd_drop.metadata_only = True cmd_create.metadata_only = True # Treat the drop as canonical, since we only need # to eliminate the reference, not get to a fully # consistent state, and the canonicalization can # mess up "associated" attributes. cmd_drop.canonical = True for fn, cur_ref_desc in zip(fns, this_ref_desc): value: s_expr.Expression | None = ( ref.get_explicit_field_value(schema, fn, None)) if value is None: continue try: # Compute a dummy value dummy = cmd_create.get_dummy_expr_field_value( schema, context, field=type(ref).get_field(fn), value=ref.get_field_value(schema, fn) ) except NotImplementedError: ref_desc.append(cur_ref_desc) else: # Do the switcheroos # Strip the "compiled" out of the expression value = s_expr.Expression.not_compiled(value) # We don't run the fixer on inherited fields because # they can't have changed (and because running it # on inherited constraint finalexprs breaks # the extra parens in it...) 
if fixer and not ref.field_is_inherited(schema, fn): with ctx_stack(): value = fixer( schema, cmd_create, fn, context, value) cmd_drop.set_attribute_value(fn, dummy) cmd_create.set_attribute_value( fn, value, inherited=ref.field_is_inherited(schema, fn), computed=ref.field_is_computed(schema, fn), ) context.affected_finalization[self].append( (delta_create, cmd_create, this_ref_desc) ) schema = delta_drop.apply(schema, context) if ref_desc: expr_s = ( 'an expression' if len(ref_desc) == 1 else 'expressions') ref_desc_s = "\n - " + "\n - ".join(ref_desc) raise errors.SchemaDefinitionError( f'cannot {action} because it is used in {expr_s}', details=( f'{scls.get_verbosename(schema)} is used in:' f'{ref_desc_s}' ) ) return schema def _finalize_affected_refs( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: # There might be dependencies between the things we need to # fix up (a computed property and a constraint on it, for # example, requires us to fix up the computed property first), # so sort by dependency order. objs_to_cmds: dict[ so.Object, list[tuple[Command, AlterObject[so.Object], list[str]]] ] = {} for delta, cmd, refdesc in context.affected_finalization.get(self, []): if schema.has_object(cmd.scls.id): cmds = objs_to_cmds.setdefault(cmd.scls, []) cmds.append((delta, cmd, refdesc)) objs = sort_by_cross_refs(schema, objs_to_cmds.keys()) for obj in reversed(objs): for delta, cmd, refdesc in objs_to_cmds[obj]: try: cmd.canonicalize_alter_from_external_ref(schema, context) schema = delta.apply(schema, context) if not context.canonical and delta: # We need to force the attributes to be resolved so # that expressions get compiled *now* under a schema # where they are correct, and not later, when more # renames may have broken them. 
assert isinstance(cmd, ObjectCommand) res_attrs = cmd.get_resolved_attributes(schema, context) for key, value in res_attrs.items(): cmd.set_attribute_value(key, value) # HACK: Apply constraint of pointers in innards, because # when converting a pointer to a computed pointer, # constraints need to be adjusted before the column is # dropped. We cannot drop the column later because we # need to maintain the ordering of drops of any children # pointers. from . import constraints as s_constraints if isinstance(cmd, s_constraints.ConstraintCommand): self.add(delta) else: # base case self.add_caused(delta) except errors.QueryError as e: orig_schema = context.current().original_schema desc = self.get_friendly_description(schema=orig_schema) raise errors.SchemaDefinitionError( f'cannot {desc} because this affects' f' {" and ".join(refdesc)}', details=e.args[0], ) from e return schema def _get_computed_status_of_fields( self, schema: s_schema.Schema, context: CommandContext, ) -> dict[str, bool]: result = {} mcls = self.get_schema_metaclass() for op in self._enumerate_attribute_cmds(): field = mcls.get_field(op.property) if not field.ephemeral: result[op.property] = op.new_computed return result def _update_computed_fields( self, schema: s_schema.Schema, context: CommandContext, update: Mapping[str, bool], ) -> None: raise NotImplementedError def _append_subcmd_ast( self, schema: s_schema.Schema, node: qlast.DDLOperation, subcmd: Command, context: CommandContext, ) -> None: subnode = subcmd.get_ast(schema, context, parent_node=node) if subnode is not None: node.commands.append(subnode) def _get_ast_node( self, schema: s_schema.Schema, context: CommandContext, ) -> type[qlast.DDLOperation]: # TODO: how to handle the following type: ignore?
# in this class, astnode is always a Type[DDLOperation], # but the current design of constraints handles it as # a List[Type[DDLOperation]] return type(self).astnode # type: ignore def _deparse_name( self, schema: s_schema.Schema, context: CommandContext, name: sn.Name, ) -> qlast.ObjectRef: qlclass = self.get_schema_metaclass().get_ql_class() if isinstance(name, sn.QualName): nname = sn.shortname_from_fullname(name) assert isinstance(nname, sn.QualName), \ "expected qualified name" ref = qlast.ObjectRef( module=nname.module, name=nname.name, itemclass=qlclass) else: ref = qlast.ObjectRef(module='', name=str(name), itemclass=qlclass) return ref def _get_ast( self, schema: s_schema.Schema, context: CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: astnode = self._get_ast_node(schema, context) if astnode.get_field('name'): # We need to be able to catch both renames of the object # itself, which might have a long name (for pointers, for # example) as well as an object being referenced by # shortname, if this is (for example) a concrete # constraint and the abstract constraint was renamed. 
name = context.early_renames.get(self.classname, self.classname) name = sn.shortname_from_fullname(name) if self.classname not in context.early_renames: name = context.early_renames.get(name, name) op = astnode( # type: ignore name=self._deparse_name(schema, context, name), ) else: op = astnode() self._apply_fields_ast(schema, context, op) return op def _apply_fields_ast( self, schema: s_schema.Schema, context: CommandContext, node: qlast.DDLOperation, ) -> None: mcls = self.get_schema_metaclass() if not isinstance(self, DeleteObject): fops = self.get_subcommands(type=AlterObjectProperty) for fop in sorted(fops, key=lambda f: f.property): field = mcls.get_field(fop.property) if fop.new_value is not None: new_value = fop.new_value else: new_value = field.get_default() if ( ( # Only include fields that are not inherited # and that have their value actually changed. not fop.new_inherited or context.descriptive_mode or self.ast_ignore_ownership() or self.ast_ignore_field_ownership(fop.property) ) and ( fop.old_value != new_value or fop.old_inherited != fop.new_inherited or fop.old_computed != fop.new_computed ) ): self._apply_field_ast(schema, context, node, fop) if not isinstance(self, AlterObjectFragment): for field in self.get_ddl_identity_fields(context): ast_attr = self.get_ast_attr_for_field(field.name, type(node)) if ( ast_attr is not None and not getattr(node, ast_attr, None) and ( field.required or self.has_ddl_identity(field.name) ) ): ddl_id = self.get_ddl_identity(field.name) attr_val: Any if issubclass(field.type, s_expr.Expression): assert isinstance(ddl_id, s_expr.Expression) attr_val = ddl_id.parse() elif issubclass(field.type, s_expr.ExpressionList): assert isinstance(ddl_id, s_expr.ExpressionList) attr_val = [e.parse() for e in ddl_id] elif issubclass(field.type, s_expr.ExpressionDict): assert isinstance(ddl_id, s_expr.ExpressionDict) attr_val = { name: e.parse() for name, e in ddl_id.items() } else: raise AssertionError( f'unexpected type of 
ddl_identity' f' field: {field.type!r}' ) setattr(node, ast_attr, attr_val) # Keep subcommands from refdicts and alter fragments (like # rename, rebase) in order when producing DDL asts refdicts = tuple(x.ref_cls for x in mcls.get_refdicts()) for op in self.get_subcommands(): if ( isinstance(op, AlterObjectFragment) or (isinstance(op, ObjectCommand) and issubclass(op.get_schema_metaclass(), refdicts)) ): self._append_subcmd_ast(schema, node, op, context) else: for op in self.get_subcommands(type=AlterObjectFragment): self._append_subcmd_ast(schema, node, op, context) if isinstance(node, qlast.DropObject): def _is_drop(ddl: qlast.DDLOperation) -> bool: return ( isinstance(ddl, (qlast.DropObject, qlast.AlterObject)) and all(_is_drop(sub) for sub in ddl.commands) ) # Deletes in the AST shouldn't have subcommands, so we # drop them. To try to make sure we aren't papering # over bugs by dropping things we don't expect, make # sure every subcommand was also a delete (or an alter # containing only deletes) assert all(_is_drop(sub) for sub in node.commands) node.commands = [] def _apply_field_ast( self, schema: s_schema.Schema, context: CommandContext, node: qlast.DDLOperation, op: AlterObjectProperty, ) -> None: if op.property != 'name': subnode = op._get_ast(schema, context, parent_node=node) if subnode is not None: node.commands.append(subnode) def get_ast_attr_for_field( self, field: str, astnode: type[qlast.DDLOperation], ) -> Optional[str]: return None def get_ddl_identity_fields( self, context: CommandContext, ) -> tuple[so.Field[Any], ...]: mcls = self.get_schema_metaclass() return tuple(f for f in mcls.get_fields().values() if f.ddl_identity) @classmethod def maybe_get_schema_metaclass(cls) -> Optional[type[Object_T]]: return cls._schema_metaclass @classmethod def get_schema_metaclass(cls) -> type[Object_T]: if cls._schema_metaclass is None: raise TypeError(f'schema metaclass not set for {cls}') return cls._schema_metaclass @classmethod def
get_other_command_class[ObjectCommand_T: ObjectCommand[so.Object]]( cls, cmdtype: type[ObjectCommand_T], ) -> type[ObjectCommand_T]: mcls = cls.get_schema_metaclass() return get_object_command_class_or_die(cmdtype, mcls) def _validate_legal_command( self, schema: s_schema.Schema, context: CommandContext, ) -> None: from . import functions as s_func if (not context.stdmode and not context.testmode and not isinstance(self, s_func.ParameterCommand)): if ( isinstance(self.classname, sn.QualName) and (modroot := self.classname.get_root_module_name()) and modroot in s_schema.STD_MODULES and not ( modroot == s_schema.EXT_MODULE and context.transient_derivation ) ): raise errors.SchemaDefinitionError( f'cannot {self._delta_action} {self.get_verbosename()}: ' f'module {modroot} is read-only', span=self.span) def get_verbosename(self, parent: Optional[str] = None) -> str: mcls = self.get_schema_metaclass() return mcls.get_verbosename_static(self.classname, parent=parent) def get_displayname(self) -> str: mcls = self.get_schema_metaclass() return mcls.get_displayname_static(self.classname) def get_friendly_object_name_for_description( self, *, parent_op: Optional[Command] = None, schema: Optional[s_schema.Schema] = None, object: Optional[Object_T] = None, object_desc: Optional[str] = None, ) -> str: if object_desc is not None: return object_desc else: if object is None: object = cast( Object_T, getattr(self, 'scls', cast(Object_T, _dummy_object)), ) if object is _dummy_object or schema is None: if not isinstance(parent_op, ObjectCommand): parent_desc = None else: parent_desc = parent_op.get_verbosename() object_desc = self.get_verbosename(parent=parent_desc) else: object_desc = object.get_verbosename(schema, with_parent=True) return object_desc @overload def get_object( self, schema: s_schema.Schema, context: CommandContext, *, name: Optional[sn.Name] = None, default: Object_T | so.NoDefaultT = so.NoDefault, span: Optional[parsing.Span] = None, ) -> Object_T: ... 
@overload def get_object( self, schema: s_schema.Schema, context: CommandContext, *, name: Optional[sn.Name] = None, default: None = None, span: Optional[parsing.Span] = None, ) -> Optional[Object_T]: ... def get_object( self, schema: s_schema.Schema, context: CommandContext, *, name: Optional[sn.Name] = None, default: Object_T | so.NoDefaultT | None = so.NoDefault, span: Optional[parsing.Span] = None, ) -> Optional[Object_T]: metaclass = self.get_schema_metaclass() if name is None: name = self.classname rename = context.renames.get(name) if rename is not None: name = rename return schema.get_global(metaclass, name, default=default) def canonicalize_attributes( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: """Resolve, canonicalize and amend field mutations in this command. This is called just before the object described by this command is created or updated but after all prerequisite commands have been applied, so it is safe to resolve object shells and do other schema inquiries here. 
""" return schema def update_field_status( self, schema: s_schema.Schema, context: CommandContext, ) -> None: computed_status = self._get_computed_status_of_fields(schema, context) self._update_computed_fields(schema, context, computed_status) def populate_ddl_identity( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: return schema def get_resolved_attribute_value( self, attr_name: str, *, schema: s_schema.Schema, context: CommandContext, ) -> Any: raw_value = self.get_attribute_value(attr_name) if raw_value is None: return None value = context.get_cached((self, 'attribute', attr_name)) if value is None: value = self.resolve_attribute_value( attr_name, raw_value, schema=schema, context=context, ) context.cache_value((self, 'attribute', attr_name), value) return value def resolve_attribute_value( self, attr_name: str, raw_value: Any, *, schema: s_schema.Schema, context: CommandContext, ) -> Any: metaclass = self.get_schema_metaclass() field = metaclass.get_field(attr_name) if field is None: raise errors.SchemaDefinitionError( f'got AlterObjectProperty command for ' f'invalid field: {metaclass.__name__}.{attr_name}') value = self._resolve_attr_value( raw_value, attr_name, field, schema) if isinstance(value, s_expr.Expression): if not value.is_compiled(): value = self.compile_expr_field(schema, context, field, value) if id := self.get_attribute_value('id'): value.set_origin(id, attr_name) elif isinstance(value, s_expr.ExpressionDict): compiled = {} obj_id = self.get_attribute_value('id') for k, v in value.items(): if not v.is_compiled(): v = self.compile_expr_field(schema, context, field, v) if obj_id: v.set_origin(obj_id, attr_name) compiled[k] = v value = compiled return value def get_attributes( self, schema: s_schema.Schema, context: CommandContext, ) -> dict[str, Any]: result = {} for attr in self.enumerate_attributes(): result[attr] = self.get_attribute_value(attr) return result def get_resolved_attributes( self, schema: 
s_schema.Schema, context: CommandContext, ) -> dict[str, Any]: result = {} for attr in self.enumerate_attributes(): result[attr] = self.get_resolved_attribute_value( attr, schema=schema, context=context) return result def get_orig_attributes( self, schema: s_schema.Schema, context: CommandContext, ) -> dict[str, Any]: result = {} for attr in self.enumerate_attributes(): result[attr] = self.get_orig_attribute_value(attr) return result def get_specified_attribute_value( self, field: str, schema: s_schema.Schema, context: CommandContext, ) -> Optional[Any]: """Fetch the specified (not computed) value of a field. If the command is an alter, it will fall back to the value in the schema. Return None if there is no specified value or if the specified value is being reset. """ spec = self.get_attribute_value(field) is_alter = ( isinstance(self, AlterObject) or ( isinstance(self, AlterObjectFragment) and isinstance(self.get_parent_op(context), AlterObject) ) ) if ( is_alter and spec is None and not self.has_attribute_value(field) and field not in self.scls.get_computed_fields(schema) ): spec = self.scls.get_explicit_field_value( schema, field, default=None) return spec def compile_expr_field( self, schema: s_schema.Schema, context: CommandContext, field: so.Field[Any], value: Any, track_schema_ref_exprs: bool=False, ) -> s_expr.CompiledExpression: cdn = self.get_schema_metaclass().get_schema_class_displayname() raise errors.InternalServerError( f'uncompiled expression in the field {field.name!r} of ' f'{cdn} {self.classname!r}' ) def get_dummy_expr_field_value( self, schema: s_schema.Schema, context: CommandContext, field: so.Field[Any], value: Any, ) -> Optional[s_expr.Expression]: """Return a dummy value for an expression stored in *field*. 
Schema class command implementations should overload this to specify a dummy value for an expression field, which is necessary when doing dependency type and name propagation switcheroo in _propagate_if_expr_refs() / _finalize_affected_refs(). """ raise NotImplementedError def _create_begin( self, schema: s_schema.Schema, context: CommandContext ) -> s_schema.Schema: raise NotImplementedError def new_context( self: ObjectCommand[Object_T], schema: s_schema.Schema, context: CommandContext, scls: Object_T, ) -> CommandContextWrapper[ObjectCommand[Object_T]]: ctxcls = type(self).get_context_class() assert ctxcls is not None return context( ctxcls(schema=schema, op=self, scls=scls), # type: ignore ) def get_ast( self, schema: s_schema.Schema, context: CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: dummy = cast(Object_T, _dummy_object) context_class = type(self).get_context_class() if context_class is not None: with self.new_context(schema, context, dummy): return self._get_ast(schema, context, parent_node=parent_node) else: return self._get_ast(schema, context, parent_node=parent_node) def get_ddl_identity(self, aspect: str) -> Any: if self.ddl_identity is None: raise LookupError(f'{self!r} has no DDL identity information') value = self.ddl_identity.get(aspect) if value is None: raise LookupError(f'{self!r} has no {aspect!r} in DDL identity') return value def has_ddl_identity(self, aspect: str) -> bool: return ( self.ddl_identity is not None and self.ddl_identity.get(aspect) is not None ) def set_ddl_identity(self, aspect: str, value: Any) -> None: if self.ddl_identity is None: self.ddl_identity = {} self.ddl_identity[aspect] = value def maybe_get_object_aux_data(self, field: str) -> Any: if self.aux_object_data is None: return None else: value = self.aux_object_data.get(field) if value is None: return None else: return value def get_object_aux_data(self, field: str) -> Any: if self.aux_object_data is None: raise 
LookupError(f'{self!r} has no auxiliary object information') value = self.aux_object_data.get(field) if value is None: raise LookupError( f'{self!r} has no {field!r} in auxiliary object information') return value def has_object_aux_data(self, field: str) -> bool: return ( self.aux_object_data is not None and self.aux_object_data.get(field) is not None ) def set_object_aux_data(self, field: str, value: Any) -> None: if self.aux_object_data is None: self.aux_object_data = {} self.aux_object_data[field] = value def get_annotation(self, name: str) -> Any: if self.annotations is None: return None else: return self.annotations.get(name) def set_annotation(self, name: str, value: Any) -> None: if self.annotations is None: self.annotations = {} self.annotations[name] = value def ast_ignore_ownership(self) -> bool: """Whether to force generating an AST even though it isn't owned""" return False def ast_ignore_field_ownership(self, field: str) -> bool: """Whether to force generating an AST even though it isn't owned""" return False class ObjectCommandContext[Object_T: so.Object]( CommandContextToken[ObjectCommand[Object_T]] ): def __init__( self, schema: s_schema.Schema, op: ObjectCommand[Object_T], scls: Object_T, *, modaliases: Optional[Mapping[Optional[str], str]] = None, localnames: AbstractSet[str] = frozenset(), ) -> None: super().__init__( schema, op, modaliases=modaliases, localnames=localnames) self.scls = scls class QualifiedObjectCommand(ObjectCommand[so.QualifiedObject_T]): classname = struct.Field(sn.QualName) @classmethod def _classname_from_ast( cls, schema: s_schema.Schema, astnode: qlast.ObjectDDL, context: CommandContext, ) -> sn.QualName: objref = astnode.name module = context.modaliases.get(objref.module, objref.module) if module is None: raise errors.SchemaDefinitionError( f'unqualified name and no default module set', span=objref.span, ) return sn.QualName(module=module, name=objref.name) @overload def get_object( self, schema: s_schema.Schema, context: 
CommandContext, *, name: Optional[sn.Name] = None, default: so.QualifiedObject_T | so.NoDefaultT = so.NoDefault, span: Optional[parsing.Span] = None, ) -> so.QualifiedObject_T: ... @overload def get_object( self, schema: s_schema.Schema, context: CommandContext, *, name: Optional[sn.Name] = None, default: None = None, span: Optional[parsing.Span] = None, ) -> Optional[so.QualifiedObject_T]: ... def get_object( self, schema: s_schema.Schema, context: CommandContext, *, name: Optional[sn.Name] = None, default: so.QualifiedObject_T | so.NoDefaultT | None = so.NoDefault, span: Optional[parsing.Span] = None, ) -> Optional[so.QualifiedObject_T]: if name is None: name = self.classname rename = context.renames.get(name) if rename is not None: name = rename metaclass = self.get_schema_metaclass() if span is None: span = self.span return schema.get( name, type=metaclass, default=default, span=span) class GlobalObjectCommand(ObjectCommand[so.GlobalObject_T]): pass class ExternalObjectCommand(ObjectCommand[so.ExternalObject_T]): pass class CreateObject[Object_T: so.Object](ObjectCommand[Object_T]): _delta_action = 'create' # If the command is conditioned with IF NOT EXISTS if_not_exists = struct.Field(bool, default=False) def is_data_safe(self) -> bool: # Creations are always data-safe. 
return True @classmethod def command_for_ast_node( cls, astnode: qlast.DDLOperation, schema: s_schema.Schema, context: CommandContext, ) -> type[ObjectCommand[Object_T]]: assert isinstance(astnode, qlast.CreateObject), "expected CreateObject" if astnode.sdl_alter_if_exists: modaliases = cls._modaliases_from_ast(schema, astnode, context) dummy_op = cls( classname=sn.QualName('placeholder', 'placeholder')) ctxcls = cast( type[ObjectCommandContext[Object_T]], cls.get_context_class_or_die(), ) ctx = ctxcls( schema, op=dummy_op, scls=cast(Object_T, _dummy_object), modaliases=modaliases, ) with context(ctx): classname = cls._classname_from_ast(schema, astnode, context) mcls = cls.get_schema_metaclass() if schema.get(classname, default=None) is not None: return get_object_command_class_or_die( AlterObject, mcls) return cls @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: CommandContext, ) -> Command: cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(astnode, qlast.CreateObject) assert isinstance(cmd, CreateObject) cmd.if_not_exists = astnode.create_if_not_exists cmd.set_attribute_value('name', cmd.classname) if getattr(astnode, 'abstract', False): cmd.set_attribute_value('abstract', True) return cmd def get_verb(self) -> str: return 'create' def validate_create( self, schema: s_schema.Schema, context: CommandContext, ) -> None: # Check if functions by this name exist obj_name = self.get_attribute_value('name') if obj_name is not None and not sn.is_fullname(str(obj_name)): from . 
import functions as s_func funcs = s_func.lookup_functions(obj_name, tuple(), schema=schema) if funcs: raise errors.SchemaError( f'{funcs[0].get_verbosename(schema)} already exists') def _create_begin( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: self._validate_legal_command(schema, context) schema = self.apply_prerequisites(schema, context) if not context.canonical: schema = self.populate_ddl_identity(schema, context) schema = self.canonicalize_attributes(schema, context) self.update_field_status(schema, context) self.validate_create(schema, context) metaclass = self.get_schema_metaclass() props = self.get_resolved_attributes(schema, context) if not props.get('id'): if context.schema_object_ids is not None: specified_id = self.get_prespecified_id(context) if specified_id is not None: props['id'] = specified_id # This takes the span of the delta command and attaches it to the schema # object. In practice, this means that span of DDL CREATE and SDL cmds # is saved to the schema. # But only to the in-memory repr of schema, since span is marked as # ephemeral. This is because spans are large and not really needed in # normal schema work, but are needed for language server. if self.span and 'span' not in props: props['span'] = self.span schema, self.scls = metaclass.create_in_schema( schema, stable_ids=context.stable_ids, **props) if not self.get_attribute_value('id'): # Record the generated ID. self.set_attribute_value('id', self.scls.id) return schema def get_prespecified_id( self, context: CommandContext, *, id_field: str = 'id', ) -> Optional[uuid.UUID]: if context.schema_object_ids is None: return None mcls = self.get_schema_metaclass() qlclass: Optional[str] if issubclass(mcls, so.QualifiedObject): qlclass = None else: qlclass = mcls.get_ql_class_or_die() objname = self.classname if context.compat_ver_is_before( (1, 0, verutils.VersionStage.ALPHA, 5) ): # Pre alpha.5 used to have a different name mangling scheme. 
objname = sn.compat_name_remangle(str(objname)) if id_field != 'id': qlclass = f'{qlclass}-{id_field}' key = (objname, qlclass) return context.schema_object_ids.get(key) def canonicalize_attributes( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: schema = super().canonicalize_attributes(schema, context) self.set_attribute_value('builtin', context.stdmode) if not self.has_attribute_value('internal'): self.set_attribute_value('internal', context.internal_schema_mode) return schema def _update_computed_fields( self, schema: s_schema.Schema, context: CommandContext, update: Mapping[str, bool], ) -> None: computed_fields = {n for n, v in update.items() if v} if computed_fields: self.set_attribute_value( 'computed_fields', frozenset(computed_fields)) def _get_ast( self, schema: s_schema.Schema, context: CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: node = super()._get_ast(schema, context, parent_node=parent_node) if node is not None and self.if_not_exists: assert isinstance(node, qlast.CreateObject) node.create_if_not_exists = True return node def _create_innards( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: return self.apply_subcommands(schema, context) def _create_finalize( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: if not context.canonical: # This is rarely triggered. 
schema = self._finalize_affected_refs(schema, context) self.validate_object(schema, context) return schema def apply( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: with self.new_context(schema, context, _dummy_object): # type: ignore if self.if_not_exists: scls = self.get_object(schema, context, default=None) if scls is not None: parent_ctx = context.parent() if parent_ctx is not None and not self.canonical: parent_ctx.op.discard(self) self.scls = scls return schema schema = self._create_begin(schema, context) ctx = context.current() objctx = cast(ObjectCommandContext[Object_T], ctx) objctx.scls = self.scls schema = self._create_innards(schema, context) schema = self.apply_caused(schema, context) schema = self._create_finalize(schema, context) return schema class CreateExternalObject( CreateObject[so.ExternalObject_T], ExternalObjectCommand[so.ExternalObject_T], ): def apply( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: with self.new_context(schema, context, _dummy_object): # type: ignore if self.if_not_exists: raise NotImplementedError( 'if_not_exists not implemented for external objects') schema = self._create_begin(schema, context) schema = self._create_innards(schema, context) schema = self.apply_caused(schema, context) schema = self._create_finalize(schema, context) return schema def _create_begin( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: self._validate_legal_command(schema, context) if not context.canonical: schema = self.populate_ddl_identity(schema, context) schema = self.canonicalize_attributes(schema, context) self.update_field_status(schema, context) self.validate_create(schema, context) props = self.get_resolved_attributes(schema, context) metaclass = self.get_schema_metaclass() obj_id = props.get('id') if obj_id is None: obj_id = metaclass._prepare_id(schema, context.stable_ids, props) self.set_attribute_value('id', obj_id) self.scls = 
metaclass._create_from_id(obj_id) return schema class AlterObjectOrFragment[Object_T: so.Object](ObjectCommand[Object_T]): def canonicalize_attributes( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: schema = super().canonicalize_attributes(schema, context) # Hydrate the ALTER fields with original field values, # if not present. for cmd in self.get_subcommands(type=AlterObjectProperty): if cmd.old_value is None: cmd.old_value = self.scls.get_explicit_field_value( schema, cmd.property, default=None) return schema def validate_alter( self, schema: s_schema.Schema, context: CommandContext, ) -> None: self._validate_legal_command(schema, context) def _alter_begin( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: schema = self.apply_prerequisites(schema, context) if not context.canonical: schema = self.populate_ddl_identity(schema, context) schema = self.canonicalize_attributes(schema, context) self.update_field_status(schema, context) self.validate_alter(schema, context) props = self.get_resolved_attributes(schema, context) return self.scls.update(schema, props) def _alter_innards( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: return self.apply_subcommands(schema, context) def _populate_link_reflection_fields( self, schema: s_schema.Schema, context: CommandContext, ) -> None: """For objects reflected with AS_LINK, populate all attributes This is kind of a hack around reflection.writer (... and edgeql) deficiencies, where reflection needs anything that is reflected as a linkprop to actually be present in the Alter, or it will be lost. 
""" if isinstance(self, AlterSpecialObjectField): return mcls = self.get_schema_metaclass() for name in mcls.get_schema_fields().keys(): if not self.has_attribute_value(name): try: value = self.scls.get_explicit_field_value( schema, name ) except so.FieldValueNotFoundError: continue self.set_attribute_value( name, value, orig_value=value, inherited=self.scls.field_is_inherited(schema, name), computed=self.scls.field_is_computed(schema, name), disallow_special=True, ) def _alter_finalize( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: schema = self._finalize_affected_refs(schema, context) if not context.canonical: self.validate_object(schema, context) mcls = self.get_schema_metaclass() if mcls.get_reflection_method() == so.ReflectionMethod.AS_LINK: self._populate_link_reflection_fields(schema, context) return schema def _update_computed_fields( self, schema: s_schema.Schema, context: CommandContext, update: Mapping[str, bool], ) -> None: cur_comp_fields = self.scls.get_computed_fields(schema) comp_fields = set(cur_comp_fields) for fn, computed in update.items(): if computed: comp_fields.add(fn) else: comp_fields.discard(fn) if cur_comp_fields != comp_fields: if comp_fields: self.set_attribute_value( 'computed_fields', frozenset(comp_fields), orig_value=cur_comp_fields if cur_comp_fields else None, ) else: self.set_attribute_value( 'computed_fields', None, orig_value=cur_comp_fields if cur_comp_fields else None, ) class AlterObjectFragment[Object_T: so.Object](AlterObjectOrFragment[Object_T]): def apply( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: # AlterObjectFragment must be executed in the context # of a parent AlterObject command. 
scls = self.get_parent_op(context).scls self.scls = cast(Object_T, scls) schema = self._alter_begin(schema, context) schema = self._alter_innards(schema, context) schema = self.apply_caused(schema, context) schema = self._alter_finalize(schema, context) return schema @classmethod def get_parent_op( cls, context: CommandContext, ) -> ObjectCommand[so.Object]: op = context.current().op assert isinstance(op, ObjectCommand) return op class RenameObject[Object_T: so.Object](AlterObjectFragment[Object_T]): _delta_action = 'rename' astnode = qlast.Rename new_name = struct.Field(sn.Name) def is_data_safe(self) -> bool: # Renames are always data-safe. return True def get_verb(self) -> str: return 'rename' def get_friendly_description( self, *, parent_op: Optional[Command] = None, schema: Optional[s_schema.Schema] = None, object: Any = None, object_desc: Optional[str] = None, ) -> str: object_desc = self.get_friendly_object_name_for_description( parent_op=parent_op, schema=schema, object=object, object_desc=object_desc, ) mcls = self.get_schema_metaclass() new_name = mcls.get_displayname_static(self.new_name) return f"rename {object_desc} to '{new_name}'" def _alter_begin( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: scls = self.scls context.renames[self.classname] = self.new_name context.renamed_objs.add(scls) # Propagate the change, but only if it wasn't handled by the # enclosing Alter. 
if context.current().op not in context.affected_finalization: vn = scls.get_verbosename(schema) schema = self._propagate_if_expr_refs( schema, context, action=f'rename {vn}', metadata_only=True, ) if not context.canonical: self.set_attribute_value( 'name', value=self.new_name, orig_value=self.classname, ) return super()._alter_begin(schema, context) def _alter_innards( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: if not context.canonical: self._canonicalize(schema, context, self.scls) return super()._alter_innards(schema, context) def init_rename_branch( self, ref: so.Object, new_ref_name: sn.Name, schema: s_schema.Schema, context: CommandContext, ) -> Command: ref_root, ref_alter, _ = ref.init_delta_branch( schema, context, AlterObject) ref_alter.add( ref.init_delta_command( schema, RenameObject, new_name=new_ref_name, ), ) return ref_root def _canonicalize( self, schema: s_schema.Schema, context: CommandContext, scls: Object_T, ) -> None: mcls = self.get_schema_metaclass() for refdict in mcls.get_refdicts(): all_refs = set( scls.get_field_value(schema, refdict.attr).objects(schema) ) ref: so.Object for ref in all_refs: ref_name = ref.get_name(schema) quals = list(sn.quals_from_fullname(ref_name)) assert isinstance(self.new_name, sn.QualName) quals[0] = str(self.new_name) shortname = sn.shortname_from_fullname(ref_name) new_ref_name = sn.QualName( name=sn.get_specialized_name(shortname, *quals), module=self.new_name.module, ) self.add(self.init_rename_branch( ref, new_ref_name, schema=schema, context=context, )) def _get_ast( self, schema: s_schema.Schema, context: CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: astnode = self._get_ast_node(schema, context) ref = self._deparse_name(schema, context, self.new_name) ref.itemclass = None orig_ref = self._deparse_name(schema, context, self.classname) # Ha, ha! Do it recursively to force any renames in children! 
self._log_all_renames(context) if (orig_ref.module, orig_ref.name) != (ref.module, ref.name): return astnode(new_name=ref) # type: ignore else: return None @classmethod def _cmd_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: CommandContext, ) -> RenameObject[Object_T]: parent_ctx = context.current() parent_op = parent_ctx.op assert isinstance(parent_op, ObjectCommand) parent_class = parent_op.get_schema_metaclass() rename_class = get_object_command_class_or_die( RenameObject, parent_class) return rename_class._rename_cmd_from_ast(schema, astnode, context) @classmethod def _rename_cmd_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: CommandContext, ) -> RenameObject[Object_T]: assert isinstance(astnode, qlast.Rename) parent_ctx = context.current() parent_op = parent_ctx.op assert isinstance(parent_op, ObjectCommand) parent_class = parent_op.get_schema_metaclass() rename_class = get_object_command_class_or_die( RenameObject, parent_class) new_name = cls._classname_from_ast(schema, astnode, context) # Populate the early_renames map of the context as we go, since # in-flight renames will affect the generated names of later # operations. context.early_renames[parent_op.classname] = new_name return rename_class( classname=parent_op.classname, new_name=new_name, ) class AlterObject[Object_T: so.Object](AlterObjectOrFragment[Object_T]): _delta_action = 'alter' #: If True, apply the command only if the object exists. 
if_exists = struct.Field(bool, default=False) #: If True, only apply changes to properties, not "real" schema changes metadata_only = struct.Field(bool, default=False) def get_verb(self) -> str: return 'alter' @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: CommandContext, ) -> Command: cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(cmd, AlterObject) if getattr(astnode, 'abstract', False): cmd.set_attribute_value('abstract', True) return cmd def _get_ast( self, schema: s_schema.Schema, context: CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: node = super()._get_ast(schema, context, parent_node=parent_node) if (node is not None and hasattr(node, 'commands') and not node.commands): # Alter node without subcommands. Occurs when all # subcommands have been filtered out of DDL stream, # so filter it out as well. node = None return node def canonicalize_alter_from_external_ref( self, schema: s_schema.Schema, context: CommandContext, ) -> None: """Canonicalize an ALTER command triggered by a modification of an object referred to by an expression in this object.""" pass def apply( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: if not context.canonical and self.if_exists: scls = self.get_object(schema, context, default=None) if scls is None: context.current().op.discard(self) return schema else: scls = self.get_object(schema, context) self.scls = scls with self.new_context(schema, context, scls): schema = self._alter_begin(schema, context) schema = self._alter_innards(schema, context) schema = self.apply_caused(schema, context) schema = self._alter_finalize(schema, context) return schema class DeleteObject[Object_T: so.Object](ObjectCommand[Object_T]): _delta_action = 'delete' #: If True, apply the command only if the object exists. 
if_exists = struct.Field(bool, default=False) #: If True, apply the command only if the object has no referrers #: in the schema. if_unused = struct.Field(bool, default=False) def get_verb(self) -> str: return 'drop' def is_data_safe(self) -> bool: # Deletions are only safe if the entire object class # has been declared as data-safe. return self.get_schema_metaclass()._data_safe def _delete_begin( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: from . import ordering self._validate_legal_command(schema, context) schema = self.apply_prerequisites(schema, context) if not context.canonical: schema = self.populate_ddl_identity(schema, context) schema = self.canonicalize_attributes(schema, context) if not context.get_value(('delcanon', self)): commands = self._canonicalize(schema, context, self.scls) root = DeltaRoot() root.update(commands) root = ordering.linearize_delta(root, schema, schema) self.update(root.get_subcommands()) return schema def _canonicalize( self, schema: s_schema.Schema, context: CommandContext, scls: Object_T, ) -> list[Command]: mcls = self.get_schema_metaclass() commands: list[Command] = [] for refdict in mcls.get_refdicts(): deleted_refs = set() all_refs = set( scls.get_field_value(schema, refdict.attr).objects(schema) ) refcmds = cast( tuple[ObjectCommand[so.Object], ...], self.get_subcommands(metaclass=refdict.ref_cls), ) for op in refcmds: deleted_ref: so.Object = schema.get(op.classname) deleted_refs.add(deleted_ref) # Add implicit Delete commands for any local refs not # deleted explicitly. for ref in all_refs - deleted_refs: op = ref.init_delta_command(schema, DeleteObject) assert isinstance(op, DeleteObject) subcmds = op._canonicalize(schema, context, ref) op.update(subcmds) commands.append(op) # Record the fact that DeleteObject._canonicalize # was called on this object to guard against possible # duplicate calls. 
context.store_value(('delcanon', self), True) return commands def _delete_innards( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: return self.apply_subcommands(schema, context) def _delete_finalize( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: ref_strs = [] if not context.canonical and not context.disable_dep_verification: refs = schema.get_referrers(self.scls) ctx = context.current() assert ctx is not None orig_schema = ctx.original_schema if refs: for ref in refs: if (not self._is_deleting_ref(schema, context, ref) and ref.is_blocking_ref(orig_schema, self.scls)): ref_strs.append( ref.get_verbosename(orig_schema, with_parent=True)) if ref_strs: vn = self.scls.get_verbosename(orig_schema, with_parent=True) dn = self.scls.get_displayname(orig_schema) detail = '; '.join(f'{ref_str} depends on {dn}' for ref_str in ref_strs) raise errors.SchemaError( f'cannot drop {vn} because ' f'other objects in the schema depend on it', details=detail, ) schema = schema.delete(self.scls) if not context.canonical: schema = self._finalize_affected_refs(schema, context) return schema def _is_deleting_ref( self, schema: s_schema.Schema, context: CommandContext, ref: so.Object, ) -> bool: if context.is_deleting(ref): return True for op in self.get_prerequisites(): if isinstance(op, DeleteObject) and op.scls == ref: return True return False def _has_outside_references( self, schema: s_schema.Schema, context: CommandContext, ) -> bool: # Check if the subject of this command has any outside references # minus any current expiring refs and minus structural child refs # (e.g. source backref in pointers of an object type). 
refs = [ ref for ref in schema.get_referrers(self.scls) if not ref.is_parent_ref(schema, self.scls) and not context.is_deleting(ref) ] return bool(refs) def apply( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: if self.if_exists: scls = self.get_object(schema, context, default=None) if scls is None: context.current().op.discard(self) return schema else: scls = self.get_object(schema, context) self.scls = scls with self.new_context(schema, context, scls): if ( self.if_unused and self._has_outside_references(schema, context) ): parent_ctx = context.parent() if parent_ctx is not None: parent_ctx.op.discard(self) return schema schema = self._delete_begin(schema, context) schema = self._delete_innards(schema, context) schema = self.apply_caused(schema, context) schema = self._delete_finalize(schema, context) return schema class AlterExternalObject[ExternalObject_T: so.ExternalObject]( AlterObject[ExternalObject_T], ExternalObjectCommand[ExternalObject_T], ): def _alter_begin( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: schema = self.apply_prerequisites(schema, context) return schema def _alter_innards( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: return self.apply_subcommands(schema, context) def apply( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: self.scls = _dummy_object # type: ignore with self.new_context(schema, context, self.scls): schema = self._alter_begin(schema, context) schema = self._alter_innards(schema, context) schema = self.apply_caused(schema, context) schema = self._alter_finalize(schema, context) return schema class DeleteExternalObject[ExternalObject_T: so.ExternalObject]( DeleteObject[ExternalObject_T], ExternalObjectCommand[ExternalObject_T], ): def _delete_begin( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: self._validate_legal_command(schema, context) return schema def _delete_innards( 
self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: for op in self.get_subcommands(metaclass=so.Object): schema = op.apply(schema, context=context) return schema def _delete_finalize( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: return schema def apply( self, schema: s_schema.Schema, context: CommandContext, ) -> s_schema.Schema: self.scls = _dummy_object # type: ignore with self.new_context(schema, context, self.scls): schema = self._delete_begin(schema, context) schema = self._delete_innards(schema, context) schema = self.apply_caused(schema, context) schema = self._delete_finalize(schema, context) return schema special_field_alter_handlers: dict[ str, dict[type[so.Object], type[AlterSpecialObjectField[so.Object]]], ] = {} class AlterSpecialObjectField[Object_T: so.Object]( AlterObjectFragment[Object_T] ): """Base class for AlterObjectFragment implementations for special fields. When the generic `AlterObjectProperty` handling of field value transitions is insufficient, declare a subclass of this to implement custom handling. 
""" _field: ClassVar[str] def __init_subclass__( cls, *, field: Optional[str] = None, **kwargs: Any, ) -> None: super().__init_subclass__(**kwargs) if field is None: if any( issubclass(b, AlterSpecialObjectField) for b in cls.__mro__[1:] ): return else: raise TypeError( "AlterSpecialObjectField.__init_subclass__() missing " "1 required keyword-only argument: 'field'" ) handlers = special_field_alter_handlers.get(field) if handlers is None: handlers = special_field_alter_handlers[field] = {} schema_metaclass = cls.get_schema_metaclass() handlers[schema_metaclass] = cls # type: ignore cls._field = field def clone(self, name: sn.Name) -> AlterSpecialObjectField[Object_T]: return struct.Struct.replace(self, classname=name) @classmethod def _cmd_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: CommandContext, ) -> ObjectCommand[Object_T]: this_op = context.current().op assert isinstance(this_op, ObjectCommand) return cls(classname=this_op.classname) @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: CommandContext, ) -> Command: assert isinstance(astnode, qlast.SetField) cmd = super()._cmd_tree_from_ast(schema, astnode, context) cmd.add(AlterObjectProperty.regular_cmd_from_ast( schema, astnode, context)) return cmd def _get_ast( self, schema: s_schema.Schema, context: CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: attrs = self._enumerate_attribute_cmds() assert len(attrs) == 1, "expected one attribute command" return attrs[0]._get_ast(schema, context, parent_node=parent_node) def get_verb(self) -> str: return f'alter the {self._field} of' def get_special_field_alter_handler( field: str, schema_cls: type[so.Object], ) -> Optional[type[AlterSpecialObjectField[so.Object]]]: """Return a custom handler for the field value transition, if any. 
Returns a subclass of AlterSpecialObjectField, when in the context of an AlterObject operation, and a special handler has been declared. """ field_handlers = special_field_alter_handlers.get(field) if field_handlers is None: return None return field_handlers.get(schema_cls) def get_special_field_create_handler( field: str, schema_cls: type[so.Object], ) -> Optional[type[AlterSpecialObjectField[so.Object]]]: """Return a custom handler for the field value transition, if any. Returns a subclass of AlterSpecialObjectField, when in the context of a CreateObject operation, and a special handler has been declared. For now this is just a hacky special case: the 'required' field of Pointers. If that changes, we should generalize the mechanism. """ if field != 'required': return None return get_special_field_alter_handler(field, schema_cls) def get_special_field_alter_handler_for_context( field: str, context: CommandContext, ) -> Optional[type[AlterSpecialObjectField[so.Object]]]: """Return a custom handler for the field value transition, if any. Returns a subclass of AlterSpecialObjectField, when in the context of an AlterObject or CreateObject operation, and a special handler has been declared. 
""" this_op = context.current().op if ( isinstance(this_op, AlterObjectOrFragment) and not isinstance(this_op, AlterSpecialObjectField) ): mcls = this_op.get_schema_metaclass() return get_special_field_alter_handler(field, mcls) elif isinstance(this_op, CreateObject): mcls = this_op.get_schema_metaclass() return get_special_field_create_handler(field, mcls) else: return None class AlterObjectProperty(Command): astnode = qlast.SetField property = struct.Field(str) old_value = struct.Field[Any](object, default=None) new_value = struct.Field[Any](object, default=None) old_inherited = struct.Field(bool, default=False) new_inherited = struct.Field(bool, default=False) new_computed = struct.Field(bool, default=False) old_computed = struct.Field(bool, default=False) from_default = struct.Field(bool, default=False) @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: CommandContext, ) -> Command: assert isinstance(astnode, qlast.SetField) handler = get_special_field_alter_handler_for_context( astnode.name, context) if handler is not None: return handler._cmd_tree_from_ast(schema, astnode, context) else: return cls.regular_cmd_from_ast(schema, astnode, context) @classmethod def regular_cmd_from_ast( cls, schema: s_schema.Schema, astnode: qlast.SetField, context: CommandContext, ) -> Command: propname = astnode.name parent_ctx = context.current() parent_op = parent_ctx.op assert isinstance(parent_op, ObjectCommand) parent_cls = parent_op.get_schema_metaclass() if ( propname.startswith('orig_') and context.compat_ver_is_before( (1, 0, verutils.VersionStage.ALPHA, 8) ) and not parent_cls.has_field(propname) ): return Nop() else: try: field = parent_cls.get_field(propname) except LookupError: raise errors.SchemaDefinitionError( f'{propname!r} is not a valid field', span=astnode.span) if not ( astnode.special_syntax or field.allow_ddl_set or context.stdmode or context.testmode ): raise errors.SchemaDefinitionError( 
f'{propname!r} is not a valid field', span=astnode.span) if field.name == 'id' and not isinstance(parent_op, CreateObject): raise errors.SchemaDefinitionError( f'cannot alter object id', span=astnode.span) ast_value: Optional[qlast.Expr | qlast.TypeExpr] = astnode.value if field.obj_names_as_string: inliner = NameToStringConverter() ast_value = cast( Optional[qlast.Expr | qlast.TypeExpr], inliner.visit(ast_value), ) new_value: Any if field.type is s_expr.Expression: if ast_value is None: new_value = None else: assert isinstance(ast_value, qlast.Expr) orig_text = cls.get_orig_expr_text( schema, parent_op.qlast, field.name) if ( orig_text is not None and context.compat_ver_is_before( (1, 0, verutils.VersionStage.ALPHA, 6) ) ): # Versions prior to a6 used a different expression # normalization strategy, so we must renormalize the # expression. expr_ql = qlcompiler.renormalize_compat( ast_value, orig_text, schema=schema, localnames=context.localnames, ) else: expr_ql = ast_value new_value = s_expr.Expression.from_ast( expr_ql, schema, context.modaliases, context.localnames, ) else: if ( isinstance(ast_value, qlast.Set) and not ast_value.elements ): # empty set new_value = None elif isinstance(ast_value, qlast.Tuple): new_value = tuple( qlcompiler.evaluate_ast_to_python_val( el, schema=schema) for el in ast_value.elements ) # Handle object references elif ( isinstance(ast_value, qlast.Path) and not ast_value.partial and len(ast_value.steps) == 1 and isinstance(ast_value.steps[0], qlast.ObjectRef) ): new_value = utils.ast_to_object_shell( ast_value.steps[0], metaclass=so.Object, modaliases=context.modaliases, schema=schema, ) if issubclass(field.type, so.ObjectCollection): new_value = [new_value] # ... 
and sets of object references # It is kind of a bummer the way this is special cased, though elif ( isinstance(ast_value, qlast.Set) and all( isinstance(v, qlast.Path) and not v.partial and len(v.steps) == 1 and isinstance(v.steps[0], qlast.ObjectRef) for v in ast_value.elements ) ): new_value = [ utils.ast_to_object_shell( v.steps[0], metaclass=so.Object, modaliases=context.modaliases, schema=schema, ) for v in ast_value.elements if isinstance(v, qlast.Path) and isinstance(v.steps[0], qlast.ObjectRef) ] elif isinstance(ast_value, qlast.TypeExpr): from . import types as s_types if not isinstance(parent_op, QualifiedObjectCommand): raise AssertionError( 'cannot determine module for derived compound type: ' 'parent operation is not a QualifiedObjectCommand' ) new_value = utils.ast_to_type_shell( ast_value, metaclass=s_types.Type, module=parent_op.classname.module, modaliases=context.modaliases, schema=schema, ) if issubclass(field.type, so.ObjectCollection): new_value = [new_value] # ... and sets of object references # It is kind of a bummer the way this is special cased, though elif ( isinstance(ast_value, qlast.Set) and all( isinstance(v, qlast.TypeExpr) for v in ast_value.elements ) ): from . import types as s_types new_value = [ utils.ast_to_type_shell( v, metaclass=s_types.Type, modaliases=context.modaliases, schema=schema, ) for v in ast_value.elements if isinstance(v, qlast.TypeExpr) ] elif ( isinstance(ast_value, qlast.StrInterp) and field.allow_interpolation ): new_value = utils.str_interpolation_to_old_style(ast_value) else: try: new_value = qlcompiler.evaluate_ast_to_python_val( ast_value, schema=schema) if ast_value else None except Exception: raise if new_value is not None: new_value = field.coerce_value(schema, new_value) return cls( property=propname, new_value=new_value, span=astnode.span, ) def is_data_safe(self) -> bool: # Field alterations on existing schema objects # generally represent semantic changes and are # reversible. 
Non-safe field alters are normally # represented by a dedicated subcommand, such as # SetLinkType. return True def _get_ast( self, schema: s_schema.Schema, context: CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: value = self.new_value new_value_empty = ( value is None or ( utils.is_nontrivial_container(value) is not None and not value ) ) old_value_empty = ( self.old_value is None or ( utils.is_nontrivial_container(self.old_value) is not None and not self.old_value ) ) parent_ctx = context.current() parent_op = parent_ctx.op assert isinstance(parent_op, ObjectCommand) assert parent_node is not None parent_cls = parent_op.get_schema_metaclass() field = parent_cls.get_field(self.property) if field is None: raise errors.SchemaDefinitionError( f'{self.property!r} is not a valid field', span=self.span) if self.property == 'id': return None parent_node_attr = parent_op.get_ast_attr_for_field( field.name, type(parent_node)) if ( not field.allow_ddl_set and not ( field.special_ddl_syntax and isinstance(parent_node, qlast.AlterObject) ) and self.property != 'expr' and parent_node_attr is None ): # Don't produce any AST if: # # * a field does not have the "allow_ddl_set" option, unless # it's an 'expr' field. # # 'expr' fields come from the "USING" clause and are specially # treated in parser and codegen. return None if ( ( self.new_inherited and not self.old_inherited and not old_value_empty ) or ( self.new_computed and not self.old_computed and not self.old_inherited and not old_value_empty ) ): # The field became inherited or computed, in which case we should # generate a RESET. return qlast.SetField( name=self.property, value=None, special_syntax=field.special_ddl_syntax, ) if self.new_inherited or self.new_computed: # We don't want to show inherited or computed properties unless # we are in "descriptive_mode" ... 
if not context.descriptive_mode: return None if not ( field.describe_visibility & so.DescribeVisibilityFlags.SHOW_IF_DERIVED ): # ... or if the field shouldn't be shown when inherited # or computed. return None if ( not ( field.describe_visibility & so.DescribeVisibilityFlags.SHOW_IF_DEFAULT ) and field.default == value ): # ... or if the field should not be shown when the value # matches the default. return None parentop_sn = sn.shortname_from_fullname(parent_op.classname).name if self.property == 'default' and parentop_sn == 'id': # ... or if it's 'default' for the 'id' property # (special case). return None if self.from_default: if not context.descriptive_mode: return None if not ( field.describe_visibility & so.DescribeVisibilityFlags.SHOW_IF_DEFAULT ): # ... or if the field should not be shown when the value # matches the default. return None if new_value_empty: if old_value_empty: return None else: value = None elif issubclass(field.type, s_expr.Expression): return self._get_expr_field_ast( schema, context, parent_op=parent_op, field=field, parent_node=parent_node, parent_node_attr=parent_node_attr, ) elif issubclass(field.type, so.ObjectCollection): value = qlast.Set(elements=[ # HACK: This is wrong, but it's good enough.
cast(qlast.Expr, utils.shell_to_ast(schema, v)) for v in (value or ()) ]) elif parent_node_attr is not None: setattr(parent_node, parent_node_attr, value) return None elif (v := utils.is_nontrivial_container(value)) and v is not None: value = qlast.Tuple(elements=[ utils.const_ast_from_python(el) for el in v ]) elif isinstance(value, uuid.UUID): value = qlast.TypeCast( expr=qlast.Constant.string(str(value)), type=qlast.TypeName( maintype=qlast.ObjectRef( name='uuid', module='std', ) ) ) elif isinstance(value, so.ObjectShell): value = utils.shell_to_ast(schema, value) else: value = utils.const_ast_from_python(value) return qlast.SetField( name=self.property, value=value, special_syntax=field.special_ddl_syntax, ) def _get_expr_field_ast( self, schema: s_schema.Schema, context: CommandContext, *, parent_op: ObjectCommand[so.Object], field: so.Field[Any], parent_node: qlast.DDLOperation, parent_node_attr: Optional[str], ) -> Optional[qlast.DDLOperation]: from edb import edgeql assert isinstance( self.new_value, (s_expr.Expression, s_expr.ExpressionShell), ) expr_ql = edgeql.parse_fragment(self.new_value.text) if parent_node is not None and parent_node_attr is not None: setattr(parent_node, parent_node_attr, expr_ql) return None else: return qlast.SetField( name=self.property, value=expr_ql, special_syntax=( self.property == 'expr' or field.special_ddl_syntax ), ) def __repr__(self) -> str: return '<%s.%s "%s":"%s"->"%s">' % ( self.__class__.__module__, self.__class__.__name__, self.property, self.old_value, self.new_value) def get_friendly_description( self, *, parent_op: Optional[Command] = None, schema: Optional[s_schema.Schema] = None, object: Any = None, object_desc: Optional[str] = None, ) -> str: if parent_op is not None: assert isinstance(parent_op, ObjectCommand) object_desc = parent_op.get_friendly_object_name_for_description( schema=schema, object=object, object_desc=object_desc, ) return f'alter the {self.property} of {object_desc}' else: return f'alter the 
{self.property} of schema object' class NameToStringConverter(ast.NodeTransformer): def visit_Path(self, node: qlast.Path) -> qlast.Base: if ( len(node.steps) == 1 and (obj_name := node.steps[0]) and isinstance(obj_name, qlast.ObjectRef) ): if obj_name.module is None: raise errors.SchemaDefinitionError( f"Object name must be fully qualified.", span=node.span, ) return qlast.Constant.string(f"{obj_name.module}::{obj_name.name}") raise errors.SchemaDefinitionError( f"Object references are not allowed here.", span=node.span, ) def compile_ddl( schema: s_schema.Schema, astnode: qlast.DDLOperation, *, context: Optional[CommandContext]=None, ) -> Command: if context is None: context = CommandContext() astnode_type = type(astnode) primary_cmdcls = CommandMeta._astnode_map.get(astnode_type) if primary_cmdcls is None: for astnode_type_base in astnode_type.__mro__[1:]: primary_cmdcls = CommandMeta._astnode_map.get(astnode_type_base) if primary_cmdcls is not None: break else: raise AssertionError( f'no delta command class for AST node {astnode!r}') cmdcls = primary_cmdcls.command_for_ast_node(astnode, schema, context) context_class = cmdcls.get_context_class() if context_class is not None: modaliases = cmdcls._modaliases_from_ast(schema, astnode, context) localnames = cmdcls.localnames_from_ast(schema, astnode, context) ctxcls = cast( type[ObjectCommandContext[so.Object]], context_class, ) ctx = ctxcls( schema, op=cast(ObjectCommand[so.Object], _dummy_command), scls=_dummy_object, modaliases=modaliases, localnames=localnames, ) with context(ctx): cmd = cmdcls._cmd_tree_from_ast(schema, astnode, context) else: cmd = cmdcls._cmd_tree_from_ast(schema, astnode, context) return cmd def get_object_delta_command[ Object_T: so.Object, ObjectCommand_T: ObjectCommand[so.Object] ]( *, objtype: type[Object_T], cmdtype: type[ObjectCommand_T], schema: s_schema.Schema, name: sn.Name, ddl_identity: Optional[Mapping[str, Any]] = None, **kwargs: Any, ) -> ObjectCommand_T: cmdcls = cast( 
type[ObjectCommand_T], get_object_command_class_or_die(cmdtype, objtype), ) return cmdcls( classname=name, ddl_identity=dict(ddl_identity) if ddl_identity is not None else None, **kwargs, ) CommandKey = tuple[str, type[so.Object], sn.Name, Optional[sn.Name]] def get_object_command_key(delta: ObjectCommand[Any]) -> CommandKey: if delta.orig_cmd_type is not None: cmdtype = delta.orig_cmd_type else: cmdtype = type(delta) new_name = ( getattr(delta, 'new_name', None) or delta.get_annotation('new_name') ) mcls = delta.get_schema_metaclass() return cmdtype.__name__, mcls, delta.classname, new_name def get_object_command_id(key: CommandKey) -> str: cmdclass_name, mcls, name, new_name = key qlcls = mcls.get_ql_class_or_die() extra = ' TO ' + str(new_name) if new_name else '' return f'{cmdclass_name} {qlcls} {name}{extra}' def apply[S: s_schema.Schema]( delta: Command, *, schema: S, context: Optional[CommandContext] = None, ) -> S: if context is None: context = CommandContext() if not isinstance(delta, DeltaRoot): root = DeltaRoot() root.add(delta) else: root = delta return cast(S, root.apply(schema, context)) ================================================ FILE: edb/schema/expr.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import ( Any, Callable, Optional, AbstractSet, Iterable, Mapping, Sequence, TYPE_CHECKING, ) import copy import uuid from edb.common import checked from edb.common import struct from edb.edgeql import ast as qlast_ from edb.edgeql import codegen as qlcodegen from edb.edgeql import compiler as qlcompiler from edb.edgeql import parser as qlparser from edb.edgeql import qltypes from . import objects as so from . import name as sn from . import delta as sd if TYPE_CHECKING: from edb.schema import schema as s_schema from edb.schema import types as s_types from edb.ir import ast as irast_ class Expression(struct.MixedRTStruct, so.ObjectContainer): text = struct.Field(str, frozen=True) # mypy wants an argument to the ObjectSet generic, but # that wouldn't work for struct.Field, since subscripted # generics are not types. refs = struct.Field( so.ObjectSet, # type: ignore coerce=True, default=None, frozen=True, ) # A string describing the provenance of the expression, used to # help annotate the parser contexts. We don't store it explicitly # in the database or explicitly populate it when creating # Expressions, but instead populate it in resolve_attribute_value # and when reading in the schema. origin = struct.Field(str, default=None) def __init__( self, *args: Any, _qlast: Optional[qlast_.Expr] = None, _irast: Optional[irast_.Statement] = None, **kwargs: Any ) -> None: super().__init__(*args, **kwargs) self._qlast = _qlast self._irast = _irast def __getstate__(self) -> dict[str, Any]: return { 'text': self.text, 'refs': self.refs, '_qlast': None, '_irast': None, } def __setstate__(self, state: Mapping[str, Any]) -> None: # Since `origin` is omitted from the pickled schema, it needs to be # explicitly set to `None` when loading pickles. 
super().__setstate__({"origin": None, **state}) def __eq__(self, rhs: object) -> bool: if not isinstance(rhs, Expression): return NotImplemented return ( self.text == rhs.text and self.refs == rhs.refs and self.origin == rhs.origin ) def parse(self) -> qlast_.Expr: """Parse the expression text into an AST. Cached.""" if self._qlast is None: self._qlast = qlparser.parse_fragment( self.text, filename=f'<{self.origin}>' if self.origin else "") return self._qlast @property def irast(self) -> Optional[irast_.Statement]: return self._irast def set_origin(self, id: uuid.UUID, field: str) -> None: """ Set the origin of the expression based on field and enclosing object. We base the origin on the id of the object, not on its name, because these strings should be useful to a client, which can't do a lookup based on the mangled internal names. """ self.origin = f'{id} {field}' def is_compiled(self) -> bool: return self.refs is not None def _refs_keys( self, schema: s_schema.Schema ) -> set[tuple[type[so.Object], sn.Name]]: return { (type(x), x.get_name(schema)) for x in (self.refs.objects(schema) if self.refs else ()) } @classmethod def compare_values( cls: type[Expression], ours: Expression, theirs: Expression, *, our_schema: s_schema.Schema, their_schema: s_schema.Schema, context: so.ComparisonContext, compcoef: float, ) -> float: if not ours and not theirs: return 1.0 elif not ours or not theirs: return compcoef # If the new and old versions share a reference to an object # that is being deleted, then we must delete this object as well. 
our_refs = ours._refs_keys(our_schema) their_refs = theirs._refs_keys(their_schema) if (our_refs & their_refs) & context.deletions.keys(): return 0.0 if ours.text == theirs.text: return 1.0 else: return compcoef @classmethod def from_ast( cls: type[Expression], qltree: qlast_.Expr, schema: s_schema.Schema, modaliases: Optional[Mapping[Optional[str], str]] = None, localnames: AbstractSet[str] = frozenset(), *, as_fragment: bool = False, ) -> Expression: if modaliases is None: modaliases = {} if not as_fragment: qlcompiler.normalize( qltree, schema=schema, modaliases=modaliases, localnames=localnames ) norm_text = qlcodegen.generate_source(qltree, pretty=False) return Expression( text=norm_text, _qlast=qltree, ) def not_compiled(self) -> Expression: return Expression(text=self.text, origin=self.origin) def compiled( self, schema: s_schema.Schema, *, options: Optional[qlcompiler.CompilerOptions] = None, as_fragment: bool = False, detached: bool = False, find_extra_refs: Optional[ Callable[[irast_.Set], set[so.Object]] ] = None, context: Optional[sd.CommandContext], ) -> CompiledExpression: from edb.ir import ast as irast_ from edb.edgeql import ast as qlast if as_fragment: ir: irast_.Command = qlcompiler.compile_ast_fragment_to_ir( self.parse(), schema=schema, options=options, ) else: ql_expr = self.parse() if detached: ql_expr = qlast.DetachedExpr( expr=ql_expr, preserve_path_prefix=True, ) ir = qlcompiler.compile_ast_to_ir( ql_expr, schema=schema, options=options, ) assert isinstance(ir, irast_.Statement) if context and ir.warnings: delta_root = context.top().op if isinstance(delta_root, sd.DeltaRoot): delta_root.warnings.extend(ir.warnings) # XXX: ref stuff - why doesn't it go into the delta tree? - temporary?? 
srefs: set[so.Object] = { ref for ref in ir.schema_refs if schema.has_object(ref.id) } if find_extra_refs is not None: srefs |= find_extra_refs(ir.expr) return CompiledExpression( text=self.text, refs=so.ObjectSet.create(schema, srefs), _qlast=self.parse(), _irast=ir, origin=self.origin, ) def ensure_compiled( self, schema: s_schema.Schema, *, options: Optional[qlcompiler.CompilerOptions] = None, as_fragment: bool = False, context: Optional[sd.CommandContext], ) -> CompiledExpression: if self._irast: return self # type: ignore else: return self.compiled( schema, options=options, as_fragment=as_fragment, context=context) def assert_compiled(self) -> CompiledExpression: if self._irast: return self # type: ignore else: raise AssertionError( f"uncompiled expression {self.text!r} (origin: {self.origin})") @classmethod def from_ir( cls: type[Expression], expr: Expression, ir: irast_.Statement, schema: s_schema.Schema, ) -> CompiledExpression: return CompiledExpression( text=expr.text, refs=so.ObjectSet.create(schema, ir.schema_refs), _qlast=expr.parse(), _irast=ir, origin=expr.origin, ) def as_shell(self, schema: s_schema.Schema) -> ExpressionShell: return ExpressionShell( text=self.text, refs=( r.as_shell(schema) for r in self.refs.objects(schema) ) if self.refs is not None else None, _qlast=self._qlast, ) def schema_reduce( self, ) -> tuple[ str, tuple[ str, Optional[tuple[type, ...] | type], tuple[uuid.UUID, ...], tuple[tuple[str, Any], ...], ], Optional[str], ]: assert self.refs is not None, 'expected expression to be compiled' return ( self.text, self.refs.schema_reduce(), self.origin, ) @classmethod def schema_restore( cls, data: tuple[ str, tuple[ str, Optional[tuple[type, ...] 
| type], tuple[uuid.UUID, ...], tuple[tuple[str, Any], ...], ], Optional[str], ], ) -> Expression: text, refs_data, origin = data return Expression( text=text, refs=so.ObjectCollection.schema_restore(refs_data), origin=origin, ) @classmethod def schema_refs_from_data( cls, data: tuple[ str, tuple[ str, Optional[tuple[type, ...] | type], tuple[uuid.UUID, ...], tuple[tuple[str, Any], ...], ], ], ) -> frozenset[uuid.UUID]: return so.ObjectCollection.schema_refs_from_data(data[1]) @property def ir_statement(self) -> irast_.Statement: """Assert this expr is a compiled EdgeQL statement and return its IR""" from edb.ir import ast as irast_ if not self.is_compiled(): raise AssertionError('expected a compiled expression') ir = self.irast if not isinstance(ir, irast_.Statement): raise AssertionError( 'expected the result of an expression to be a Statement') return ir @property def stype(self) -> s_types.Type: return self.ir_statement.stype @property def cardinality(self) -> qltypes.Cardinality: return self.ir_statement.cardinality @property def schema(self) -> s_schema.Schema: return self.ir_statement.schema class CompiledExpression(Expression): refs = struct.Field( so.ObjectSet, # type: ignore coerce=True, frozen=True, ) def __init__( self, *args: Any, _qlast: Optional[qlast_.Expr] = None, _irast: irast_.Statement, **kwargs: Any ) -> None: super().__init__(*args, _qlast=_qlast, _irast=_irast, **kwargs) @property def irast(self) -> irast_.Statement: assert self._irast return self._irast def as_python_value(self) -> Any: return qlcompiler.evaluate_ir_statement_to_python_val(self.irast) class ExpressionShell(so.Shell): def __init__( self, *, text: str, refs: Optional[Iterable[so.ObjectShell[so.Object]]], _qlast: Optional[qlast_.Expr] = None, _irast: Optional[irast_.Statement] = None, ) -> None: self.text = text self.refs = tuple(refs) if refs is not None else None self._qlast = _qlast self._irast = _irast def resolve(self, schema: s_schema.Schema) -> Expression: cls = 
CompiledExpression if self._irast else Expression return cls( text=self.text, refs=so.ObjectSet.create( schema, [s.resolve(schema) for s in self.refs], ) if self.refs is not None else None, _qlast=self._qlast, _irast=self._irast, # type: ignore[arg-type] ) def parse(self) -> qlast_.Expr: if self._qlast is None: self._qlast = qlparser.parse_fragment(self.text) return self._qlast def __repr__(self) -> str: if self.refs is None: refs = 'N/A' else: refs = ', '.join(repr(obj) for obj in self.refs) return f'' class ExpressionList(checked.FrozenCheckedList[Expression]): @staticmethod def merge_values( target: so.Object, sources: Sequence[so.Object], field_name: str, *, ignore_local: bool = False, schema: s_schema.Schema, ) -> Any: if not ignore_local: result = target.get_explicit_field_value(schema, field_name, None) else: result = None for source in sources: theirs = source.get_explicit_field_value(schema, field_name, None) if theirs: if result is None: result = theirs[:] else: result.extend(theirs) return result @classmethod def compare_values( cls: type[ExpressionList], ours: Optional[ExpressionList], theirs: Optional[ExpressionList], *, our_schema: s_schema.Schema, their_schema: s_schema.Schema, context: so.ComparisonContext, compcoef: float, ) -> float: """See the comment in Object.compare_values""" if not ours and not theirs: basecoef = 1.0 elif (not ours or not theirs) or (len(ours) != len(theirs)): basecoef = 0.2 else: similarity = [] for expr1, expr2 in zip(ours, theirs): similarity.append( Expression.compare_values( expr1, expr2, our_schema=our_schema, their_schema=their_schema, context=context, compcoef=compcoef)) basecoef = sum(similarity) / len(similarity) return basecoef + (1 - basecoef) * compcoef class ExpressionDict(checked.CheckedDict[str, Expression]): @staticmethod def merge_values( target: so.Object, sources: Sequence[so.Object], field_name: str, *, ignore_local: bool = False, schema: s_schema.Schema, ) -> Any: result = None # Assume that sources are 
given in MRO order, so we need to reverse # them to figure out the merged value. for source in reversed(sources): theirs = source.get_explicit_field_value(schema, field_name, None) if theirs: if result is None: result = dict(theirs) else: result.update(theirs) # Finally merge the most relevant data. if not ignore_local: ours = target.get_explicit_field_value(schema, field_name, None) if result is None: result = ours elif ours: result.update(ours) return result @classmethod def compare_values( cls: type[ExpressionDict], ours: Optional[ExpressionDict], theirs: Optional[ExpressionDict], *, our_schema: s_schema.Schema, their_schema: s_schema.Schema, context: so.ComparisonContext, compcoef: float, ) -> float: """See the comment in Object.compare_values""" if not ours and not theirs: basecoef = 1.0 elif (not ours or not theirs) or (len(ours) != len(theirs)): basecoef = 0.2 elif set(ours.keys()) != set(theirs.keys()): # Same length dicts can still have different keys, which is # similar to having mismatched length. basecoef = 0.2 else: # We have the same keys, so just compare the values. similarity = [] for ((_, expr1), (_, expr2)) in zip( sorted(ours.items()), sorted(theirs.items()) ): similarity.append( Expression.compare_values( expr1, expr2, our_schema=our_schema, their_schema=their_schema, context=context, compcoef=compcoef)) basecoef = sum(similarity) / len(similarity) return basecoef + (1 - basecoef) * compcoef EXPRESSION_TYPES = ( Expression, ExpressionList, ExpressionDict ) def imprint_expr_context( qltree: qlast_.Base, modaliases: Mapping[Optional[str], str], ) -> qlast_.Base: # Imprint current module aliases as explicit # alias declarations in the expression. if (isinstance(qltree, qlast_.BaseConstant) or qltree is None or (isinstance(qltree, qlast_.Set) and not qltree.elements) or (isinstance(qltree, qlast_.Array) and all(isinstance(el, qlast_.BaseConstant) for el in qltree.elements))): # Leave constants alone.
return qltree if isinstance(qltree, qlast_.Expr): qltree = qlast_.SelectQuery(result=qltree, implicit=True) else: assert isinstance(qltree, (qlast_.Command, qlast_.DDLCommand)) qltree = copy.copy(qltree) qltree.aliases = ( list(qltree.aliases) if qltree.aliases is not None else None) assert isinstance(qltree, (qlast_.Query, qlast_.Command)) existing_aliases: dict[Optional[str], str] = {} for alias in (qltree.aliases or ()): if isinstance(alias, qlast_.ModuleAliasDecl): existing_aliases[alias.alias] = alias.module aliases_to_add = set(modaliases) - set(existing_aliases) for alias_name in aliases_to_add: if qltree.aliases is None: qltree.aliases = [] qltree.aliases.append( qlast_.ModuleAliasDecl( alias=alias_name, module=modaliases[alias_name], ) ) return qltree def get_expr_referrers( schema: s_schema.Schema, obj: so.Object ) -> dict[so.Object, list[str]]: """Return schema referrers with refs in expressions.""" refs: dict[tuple[type[so.Object], str], frozenset[so.Object]] = ( schema.get_referrers_ex(obj)) result: dict[so.Object, list[str]] = {} for (mcls, fn), referrers in refs.items(): field = mcls.get_field(fn) if issubclass(field.type, (Expression, ExpressionList)): for ref in referrers: result.setdefault(ref, []).append(fn) return result ================================================ FILE: edb/schema/expraliases.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any, Optional, TYPE_CHECKING from edb import errors from edb.common import parsing from edb.edgeql import ast as qlast from edb.edgeql import compiler as qlcompiler from edb.edgeql import qltypes from edb.edgeql.compiler import astutils as qlastutils from . import annos as s_anno from . import expr as s_expr from . import delta as sd from . import name as sn from . import objects as so from . import types as s_types if TYPE_CHECKING: from edb.ir import ast as irast from . import schema as s_schema class Alias( so.QualifiedObject, s_anno.AnnotationSubject, qlkind=qltypes.SchemaObjectClass.ALIAS, data_safe=True, ): expr = so.SchemaField( s_expr.Expression, default=None, coerce=True, compcoef=0.909, ) type = so.SchemaField( s_types.Type, compcoef=0.909, ) created_types = so.SchemaField( so.ObjectSet[s_types.Type], default=so.DEFAULT_CONSTRUCTOR, ) class AliasCommandContext( sd.ObjectCommandContext[so.Object], s_anno.AnnotationSubjectCommandContext ): pass class AliasLikeCommand( sd.QualifiedObjectCommand[so.QualifiedObject_T], ): """Common code for "alias-likes": that is, aliases and globals Aliases and computed globals behave extremely similarly, except for a few annoying differences that need to be handled in the subclasses with appropriate overloads: * In aliases, the type field name is 'type', while for computed globals it is 'target'. This annoying discrepancy is because computed globals also share a bunch of code paths with pointers, which use 'target', and so there needed to be a mismatch on one of the sides. Handled by overloading TYPE_FIELD_NAME. * For aliases, it is the generated view type that gets the real name and the alias that gets the mangled one. For globals, the real global needs to get the real name, so that the name does not depend on whether it is computed or not.
This is handled by overloading _get_alias_name, which computes the name of the alias type (and _classname_from_ast). * Also, aliases are *always* alias-like, while globals are only when computed. This is handled by overloading _is_computable. Computed globals also have explicit 'required' and 'cardinality' fields, which are managed explicitly in the globals code. """ TYPE_FIELD_NAME = '' ALIAS_LIKE_EXPR_FIELDS: tuple[str, ...] = () @classmethod def _get_alias_name(cls, type_name: sn.QualName) -> sn.QualName: raise NotImplementedError @classmethod def _is_computable( cls, obj: so.QualifiedObject_T, schema: s_schema.Schema ) -> bool: raise NotImplementedError # Generic code def _delete_alias_types( self, scls: so.QualifiedObject_T, schema: s_schema.Schema, context: sd.CommandContext, *, unset_type: bool = True, ) -> sd.CommandGroup: from . import globals as s_globals from . import ordering as s_ordering assert isinstance(scls, (Alias, s_globals.Global)) types: so.ObjectSet[s_types.Type] = scls.get_created_types(schema) created = types.objects(schema) delta = sd.DeltaRoot() # Unset created_types and type/target, so the types can be dropped alter_alias = scls.init_delta_command(schema, sd.AlterObject) alter_alias.canonical = True # This would usually not be needed, as both Alias.type and # Global.target have ON TARGET DELETE DEFERRED RESTRICT. # But because we are using "if_unused", we want to delete references # to these types so they get dropped if this had been the only ref.
if unset_type: # (there are cases when we don't need to unset the type, such as # when a computed global has been converted to a non-computed one) alter_alias.add( sd.AlterObjectProperty( property=self.TYPE_FIELD_NAME, new_value=None ) ) alter_alias.set_attribute_value('created_types', set()) for dep_type in created: if_unused = isinstance(dep_type, s_types.Collection) drop_dep = dep_type.init_delta_command( schema, sd.DeleteObject, if_exists=True, if_unused=if_unused ) subcmds = drop_dep._canonicalize(schema, context, dep_type) drop_dep.update(subcmds) delta.add(drop_dep) delta = s_ordering.linearize_delta( delta, old_schema=schema, new_schema=schema ) delta.prepend(alter_alias) return delta @classmethod def get_type( cls, obj: so.QualifiedObject_T, schema: s_schema.Schema ) -> s_types.Type: obj = obj.get_field_value(schema, cls.TYPE_FIELD_NAME) assert isinstance(obj, s_types.Type) return obj @classmethod def _mangle_name( cls, type_name: sn.QualName, *, include_module_in_name: bool, ) -> sn.QualName: base_name = ( type_name if include_module_in_name else type_name.get_local_name() ) quals = (cls.get_schema_metaclass().get_schema_class_displayname(),) pnn = sn.get_specialized_name(base_name, str(type_name), *quals) name = sn.QualName(name=pnn, module=type_name.module) assert isinstance(name, sn.QualName) return name def get_dummy_expr_field_value( self, schema: s_schema.Schema, context: sd.CommandContext, field: so.Field[Any], value: Any, ) -> Optional[s_expr.Expression]: if field.name in self.ALIAS_LIKE_EXPR_FIELDS: rt = self.get_type(self.scls, schema) return s_types.type_dummy_expr(rt, schema) else: raise NotImplementedError(f'unhandled field {field.name!r}') def _handle_alias_op( self, *, expr: s_expr.Expression, classname: sn.QualName, schema: s_schema.Schema, context: sd.CommandContext, is_alter: bool = False, span: Optional[parsing.Span] = None, ) -> tuple[ sd.Command, s_types.TypeShell[s_types.Type], s_expr.Expression, set[so.ObjectShell[s_types.Type]], ]: 
pschema = schema # For alters, drop the alias first, use the schema without the alias # for compilation of the new alias expr drop_old_types_cmd: Optional[sd.Command] = None if is_alter: drop_old_types_cmd = self._delete_alias_types( self.scls, schema, context) with context.suspend_dep_verification(): pschema = drop_old_types_cmd.apply(pschema, context) ir = compile_alias_expr( expr.parse(), classname, pschema, context, span=span, ) expr = s_expr.Expression.from_ir(expr, ir, schema=schema) is_global = (self.get_schema_metaclass(). get_schema_class_displayname() == 'global') cmd, type_shell, created_types = _create_alias_types( expr=expr, classname=classname, schema=schema, is_global=is_global, span=span, ) if drop_old_types_cmd: cmd.prepend(drop_old_types_cmd) return cmd, type_shell, expr, created_types class AliasCommand( AliasLikeCommand[Alias], context_class=AliasCommandContext, ): TYPE_FIELD_NAME = 'type' ALIAS_LIKE_EXPR_FIELDS = ('expr',) @classmethod def _get_alias_name(cls, type_name: sn.QualName) -> sn.QualName: alias_name = sn.shortname_from_fullname(type_name) assert isinstance(alias_name, sn.QualName), "expected qualified name" return alias_name @classmethod def _is_computable(cls, obj: Alias, schema: s_schema.Schema) -> bool: return True @classmethod def _classname_from_ast( cls, schema: s_schema.Schema, astnode: qlast.ObjectDDL, context: sd.CommandContext, ) -> sn.QualName: type_name = super()._classname_from_ast(schema, astnode, context) return cls._mangle_name(type_name, include_module_in_name=True) def compile_expr_field( self, schema: s_schema.Schema, context: sd.CommandContext, field: so.Field[Any], value: s_expr.Expression, track_schema_ref_exprs: bool=False, ) -> s_expr.CompiledExpression: assert field.name == 'expr' classname = sn.shortname_from_fullname(self.classname) assert isinstance(classname, sn.QualName), \ "expected qualified name" return value.compiled( schema=schema, options=qlcompiler.CompilerOptions( 
derived_target_module=classname.module, modaliases=context.modaliases, in_ddl_context_name='alias definition', schema_object_context=self.get_schema_metaclass(), track_schema_ref_exprs=track_schema_ref_exprs, ), context=context, ) class CreateAliasLike( AliasLikeCommand[so.QualifiedObject_T], sd.CreateObject[so.QualifiedObject_T], ): def _create_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: if not context.canonical and self.get_attribute_value('expr'): alias_name = self._get_alias_name(self.classname) # generated types might conflict with existing types if other_obj := schema.get(alias_name, default=None): vn = other_obj.get_verbosename(schema, with_parent=True) raise errors.SchemaError(f'{vn} already exists') type_cmd, type_shell, expr, created_types = self._handle_alias_op( expr=self.get_attribute_value('expr'), classname=alias_name, schema=schema, context=context, span=self.get_attribute_span('expr'), ) self.add_prerequisite(type_cmd) self.set_attribute_value('expr', expr) self.set_attribute_value( self.TYPE_FIELD_NAME, type_shell, computed=True) self.set_attribute_value('created_types', created_types) return super()._create_begin(schema, context) class CreateAlias( CreateAliasLike[Alias], AliasCommand, ): astnode = qlast.CreateAlias class RenameAliasLike( AliasLikeCommand[so.QualifiedObject_T], sd.RenameObject[so.QualifiedObject_T], ): def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: if not context.canonical and self._is_computable(self.scls, schema): assert isinstance(self.new_name, sn.QualName) new_alias_name = self._get_alias_name(self.new_name) alias_type = self.get_type(self.scls, schema) alter_cmd = alias_type.init_delta_command(schema, sd.AlterObject) rename_cmd = alias_type.init_delta_command( schema, sd.RenameObject, new_name=new_alias_name, ) alter_cmd.add(rename_cmd) self.add_prerequisite(alter_cmd) return super()._alter_begin(schema, context) class 
RenameAlias(RenameAliasLike[Alias], AliasCommand): pass class AlterAliasLike( AliasLikeCommand[so.QualifiedObject_T], sd.AlterObject[so.QualifiedObject_T], ): def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: if not context.canonical: schema = self._propagate_if_expr_refs( schema, context, action=self.get_friendly_description(schema=schema), ) expr = self.get_attribute_value('expr') is_computable = self._is_computable(self.scls, schema) if expr: alias_name = self._get_alias_name(self.classname) type_cmd, type_shell, expr, created_tys = self._handle_alias_op( expr=expr, classname=alias_name, schema=schema, context=context, is_alter=is_computable, span=self.get_attribute_span('expr'), ) self.add_prerequisite(type_cmd) self.set_attribute_value('expr', expr) self.set_attribute_value( self.TYPE_FIELD_NAME, type_shell, computed=True) self.set_attribute_value('created_types', created_tys) # Clear out the type field in the schema *now*, # before we call the parent _alter_begin, which will # run prerequisites. This prevents the type reference # from interfering with deletion. (And the deletion of # the type has to be done as a prereq, since it needs # to precede the creation of the replacement type # with the same name.) 
schema = schema.unset_field( self.scls, self.TYPE_FIELD_NAME) else: # there is no expr if is_computable and self.has_attribute_value('expr'): # this is a global that just had its expr unset self.add( self._delete_alias_types( self.scls, schema, context, unset_type=False ) ) return super()._alter_begin(schema, context) class AlterAlias( AlterAliasLike[Alias], AliasCommand, ): astnode = qlast.AlterAlias class DeleteAliasLike[QualifiedObject_T: so.QualifiedObject]( AliasLikeCommand[QualifiedObject_T], sd.DeleteObject[QualifiedObject_T], ): def _canonicalize( self, schema: s_schema.Schema, context: sd.CommandContext, scls: QualifiedObject_T, ) -> list[sd.Command]: ops = super()._canonicalize(schema, context, scls) if self._is_computable(scls, schema): ops.append(self._delete_alias_types(scls, schema, context)) return ops class DeleteAlias( DeleteAliasLike[Alias], AliasCommand, ): astnode = qlast.DropAlias def compile_alias_expr( expr: qlast.Expr, classname: sn.QualName, schema: s_schema.Schema, context: sd.CommandContext, span: Optional[parsing.Span] = None, ) -> irast.Statement: cached: Optional[irast.Statement] = ( context.get_cached((expr, classname))) if cached is not None: return cached expr = qlastutils.ensure_ql_query(expr) ir = qlcompiler.compile_ast_to_ir( expr, schema, options=qlcompiler.CompilerOptions( derived_target_module=classname.module, result_view_name=classname, modaliases=context.modaliases, schema_view_mode=True, schema_object_context=Alias, in_ddl_context_name='alias definition', bootstrap_mode=context.stdmode, ), ) if ir.volatility.is_volatile(): raise errors.SchemaDefinitionError( f'volatile functions are not permitted in schema-defined ' f'computed expressions', span=span, ) context.cache_value((expr, classname), ir) return ir def _create_alias_types( *, expr: s_expr.CompiledExpression, classname: sn.QualName, schema: s_schema.Schema, is_global: bool, span: Optional[parsing.Span] = None, ) -> tuple[ sd.Command, s_types.TypeShell[s_types.Type], 
set[so.ObjectShell[s_types.Type]], ]: from . import ordering as s_ordering from edb.ir import utils as irutils ir = expr.irast new_schema = ir.schema derived_delta = sd.DeltaRoot() created_type_shells: set[so.ObjectShell[s_types.Type]] = set() for ty_id in irutils.collect_schema_types(ir.expr): ty = new_schema.get_by_id(ty_id, type=s_types.Type) name = ty.get_name(new_schema) if schema.has_object(ty_id): # This is not a new type. # Add any existing collection types to the `created_types` of aliases # and computed globals. This adds a reference which prevents their # deletion as other types are deleted. if ( not ty.get_from_alias(schema) and isinstance(ty, s_types.Collection) ): created_type_shells.add( so.ObjectShell(name=name, schemaclass=type(ty)) ) continue if ( not isinstance(ty, s_types.Collection) and not _has_alias_name_prefix(classname, name) ): # not all created types are visible from the root, so they don't # need to be created in the schema continue # Schema views derive an alias subtype for their expressions, which # are stored in the schema with `from_alias` set to True. Aliases will # also store any new types from their expressions into the schema. # # This is not an issue for most derived types since they are used only # in their alias expression. # # However, collections which are not expr aliases that are created in # an alias expression may be used in other places and should # not be created with `from_alias` or other alias/global associated # fields. 
if ( not isinstance(ty, s_types.Collection) or isinstance(ty, s_types.CollectionExprAlias) ): new_schema = ty.update( new_schema, dict( alias_is_persistent=True, expr_type=s_types.ExprType.Select, from_alias=True, from_global=is_global, ), ) new_schema = ty.update( new_schema, dict( internal=False, builtin=False, ), ) if isinstance(ty, s_types.Collection): new_schema = ty.set_field_value(new_schema, 'is_persistent', True) derived_delta.add( ty.as_create_delta( schema=new_schema, context=so.ComparisonContext() ) ) created_type_shells.add(so.ObjectShell(name=name, schemaclass=type(ty))) derived_delta = s_ordering.linearize_delta( derived_delta, old_schema=schema, new_schema=new_schema ) type_cmd = None for op in derived_delta.get_subcommands(): assert isinstance(op, sd.ObjectCommand) if op.classname == classname: type_cmd = op break assert type_cmd type_cmd.set_attribute_value('expr', expr) result = sd.CommandGroup() result.update(derived_delta.get_subcommands()) type_shell = s_types.TypeShell( name=classname, origname=classname, schemaclass=type_cmd.get_schema_metaclass(), span=span, ) return result, type_shell, created_type_shells def _has_alias_name_prefix( alias_name: sn.QualName, name: sn.Name, ) -> bool: return alias_name == name or ( isinstance(name, sn.QualName) and name.module == alias_name.module and name.name.startswith(f'__{alias_name.name}__') ) ================================================ FILE: edb/schema/extensions.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2021-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any, Optional, Iterator, cast import collections import contextlib import uuid from edb import errors from edb.common import verutils from edb.common import struct from edb.edgeql import ast as qlast from edb.edgeql import qltypes from edb.edgeql import parser as qlparser from edb.common import checked from . import annos as s_anno from . import casts as s_casts from . import delta as sd from . import indexes as s_indexes from . import name as sn from . import objects as so from . import schema as s_schema class ExtensionPackage( so.GlobalObject, s_anno.AnnotationSubject, qlkind=qltypes.SchemaObjectClass.EXTENSION_PACKAGE, data_safe=False, ): # Note: !!!!!! # ExtensionPackage, like all GlobalObjects, needs to store its # data in globally stored JSON instead of via reflection schema. # When you add a field to ExtensionPackage, you must also update # CreateExtensionPackage in pgsql/delta.py and # _generate_extension_views in metaschema to store and retrieve # the data from json. 
version = so.SchemaField( verutils.Version, compcoef=0.9, ) script = so.SchemaField( str, compcoef=0.9, ) sql_extensions = so.SchemaField( checked.FrozenCheckedSet[str], default=so.DEFAULT_CONSTRUCTOR, coerce=True, inheritable=False, compcoef=0.9, ) sql_setup_script = so.SchemaField( str, default=None, compcoef=0.9) sql_teardown_script = so.SchemaField( str, default=None, compcoef=0.9) ext_module = so.SchemaField( str, default=None, compcoef=0.9) # It uses str instead of direct references so we can stick # versions in there eventually dependencies = so.SchemaField( checked.FrozenCheckedSet[str], default=so.DEFAULT_CONSTRUCTOR, coerce=True, inheritable=False, compcoef=0.9, ) @classmethod def get_shortname_static(cls, name: sn.Name) -> sn.UnqualName: return sn.UnqualName(sn.shortname_from_fullname(name).name) @classmethod def get_displayname_static(cls, name: sn.Name) -> str: shortname = cls.get_shortname_static(name) return shortname.name class ExtensionPackageMigration( so.GlobalObject, s_anno.AnnotationSubject, qlkind=qltypes.SchemaObjectClass.EXTENSION_PACKAGE_MIGRATION, data_safe=False, ): # Note: !!!!!! # ExtensionPackageMigration, like all GlobalObjects, needs to store its # data in globally stored JSON instead of via reflection schema. # When you add a field to ExtensionPackageMigration, you must also update # CreateExtensionPackageMigration in pgsql/delta.py and # _generate_extension_views in metaschema to store and retrieve # the data from json. 
from_version = so.SchemaField( verutils.Version, compcoef=0.9, ) to_version = so.SchemaField( verutils.Version, compcoef=0.9, ) script = so.SchemaField( str, compcoef=0.9, ) sql_early_script = so.SchemaField( str, default=None, compcoef=0.9) sql_late_script = so.SchemaField( str, default=None, compcoef=0.9) @classmethod def get_shortname_static(cls, name: sn.Name) -> sn.UnqualName: return sn.UnqualName(sn.shortname_from_fullname(name).name) @classmethod def get_displayname_static(cls, name: sn.Name) -> str: shortname = cls.get_shortname_static(name) return shortname.name class Extension( so.Object, qlkind=qltypes.SchemaObjectClass.EXTENSION, data_safe=False, ): package = so.SchemaField( ExtensionPackage, compcoef=0.9, ) dependencies = so.SchemaField( so.ObjectList['Extension'], default=so.DEFAULT_CONSTRUCTOR, coerce=True, inheritable=False, compcoef=0.9, ) @classmethod def create_in_schema[Schema_T: s_schema.Schema]( cls: type[Extension], schema: Schema_T, stable_ids: bool = False, *, id: Optional[uuid.UUID] = None, **data: Any, ) -> tuple[Schema_T, Extension]: name = data['name'] pkg = data['package'] if existing_ext := schema.get_global(Extension, name, default=None): vn = existing_ext.get_verbosename(schema) existing_pkg = existing_ext.get_package(schema) raise errors.SchemaError( f'cannot install {vn} version {pkg.get_version(schema)}: ' f'version {existing_pkg.get_version(schema)} is already ' f'installed' ) return super().create_in_schema(schema, stable_ids, id=id, **data) class ExtensionPackageCommandContext( sd.ObjectCommandContext[ExtensionPackage], s_anno.AnnotationSubjectCommandContext, ): pass class ExtensionPackageCommand( sd.GlobalObjectCommand[ExtensionPackage], s_anno.AnnotationSubjectCommand[ExtensionPackage], context_class=ExtensionPackageCommandContext, ): @classmethod def _classname_from_ast( cls, schema: s_schema.Schema, astnode: qlast.ObjectDDL, context: sd.CommandContext ) -> sn.UnqualName: assert isinstance(astnode, 
qlast.ExtensionPackageCommand) parsed_version = verutils.parse_version(astnode.version.value) quals = ['pkg', str(parsed_version)] pnn = sn.get_specialized_name(sn.UnqualName(astnode.name.name), *quals) return sn.UnqualName(pnn) def get_package( name: sn.Name, version: Optional[verutils.Version], schema: s_schema.Schema ) -> ExtensionPackage: filters = [ lambda schema, pkg: ( pkg.get_shortname(schema) == name ) ] # Version specs are always implicitly >=. if version is not None: filters.append( lambda schema, pkg: ( pkg.get_version(schema) >= version ) ) pkgs = list(schema.get_objects( type=ExtensionPackage, extra_filters=filters, )) if not pkgs: dname = str(name) if version is None: raise errors.SchemaError( f'cannot create extension {dname!r}:' f' extension package {dname!r} does' f' not exist' ) else: raise errors.SchemaError( f'cannot create extension {dname!r}:' f' extension package {dname!r} version' f' {str(version)!r} does not exist' ) pkgs.sort(key=lambda pkg: pkg.get_version(schema)) # If the exact version exists, then use it. Otherwise, take the # newest version. if pkgs[0].get_version(schema) == version: return pkgs[0] else: return pkgs[-1] def get_package_migrations( name: sn.Name, from_version: verutils.Version, to_version: verutils.Version, schema: s_schema.Schema, ) -> list[ExtensionPackageMigration]: # TODO: We need to figure out migration chains # # That will have some fiddliness, though, with when SQL extension # upgrades and SQL scripts run? filters = [ lambda schema, mig: ( mig.get_shortname(schema) == name ) ] migs = list(schema.get_objects( type=ExtensionPackageMigration, extra_filters=filters, )) # Build a graph of available migrations. We make this # complicated, just in case, but probably it will be simple. # TODO: What about missing packages? 
graph: dict[ verutils.Version, list[tuple[verutils.Version, ExtensionPackageMigration]], ] = {} for mig in migs: fromv = mig.get_from_version(schema) tov = mig.get_to_version(schema) graph.setdefault(fromv, []).append((tov, mig)) for tgts in graph.values(): tgts.sort() # BFS it out sources = {} todo = collections.deque([from_version]) while todo: cur_node = todo.popleft() if cur_node == to_version: break for next_ver, mig in graph.get(cur_node, []): if next_ver not in sources: sources[next_ver] = cur_node, mig todo.append(next_ver) else: dname = str(name) raise errors.SchemaError( f'cannot migrate extension {dname!r} from ' f'{from_version} to {to_version}' ) # Trace back the path mig_path = [] while cur_node in sources: cur_node, mig = sources[cur_node] mig_path.append(mig) mig_path.reverse() return mig_path # XXX: Trying to CREATE/DROP these from within a transaction managed # to get me stuck getting "Cannot serialize global DDL" errors. # # I haven't fully investigated whether it is actually sensible to do # this kind of global command in a transaction, but we currently allow # it and some tests do it. 
class CreateExtensionPackage( ExtensionPackageCommand, sd.CreateObject[ExtensionPackage], ): astnode = qlast.CreateExtensionPackage @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> CreateExtensionPackage: if not context.stdmode and not context.testmode: raise errors.UnsupportedFeatureError( 'user-defined extension packages are not supported yet', span=astnode.span ) cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(cmd, CreateExtensionPackage) assert isinstance(astnode, qlast.CreateExtensionPackage) assert astnode.body.text is not None parsed_version = verutils.parse_version(astnode.version.value) cmd.set_attribute_value('version', parsed_version) cmd.set_attribute_value('script', astnode.body.text) cmd.set_attribute_value('builtin', context.stdmode) if not cmd.has_attribute_value('internal'): cmd.set_attribute_value('internal', False) return cmd def _apply_field_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, op: sd.AlterObjectProperty, ) -> None: assert isinstance(node, qlast.CreateExtensionPackage) if op.property == 'script': node.body = qlast.NestedQLBlock( text=op.new_value, commands=cast( list[qlast.DDLOperation], qlparser.parse_block(op.new_value)), ) elif op.property == 'version': node.version = qlast.Constant.string( value=str(op.new_value), ) else: super()._apply_field_ast(schema, context, node, op) class DeleteExtensionPackage( ExtensionPackageCommand, sd.DeleteObject[ExtensionPackage], ): astnode = qlast.DropExtensionPackage def _delete_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: if ( not context.stdmode and not context.testmode and self.scls.get_builtin(schema) ): name = self.scls.get_shortname(schema) raise errors.UnsupportedFeatureError( f"cannot drop builtin extension package '{name}'", span=self.span, ) return super()._delete_begin(schema, context) class 
ExtensionPackageMigrationCommandContext( sd.ObjectCommandContext[ExtensionPackageMigration], s_anno.AnnotationSubjectCommandContext, ): pass class ExtensionPackageMigrationCommand( sd.GlobalObjectCommand[ExtensionPackageMigration], s_anno.AnnotationSubjectCommand[ExtensionPackageMigration], context_class=ExtensionPackageMigrationCommandContext, ): @classmethod def _classname_from_ast( cls, schema: s_schema.Schema, astnode: qlast.ObjectDDL, context: sd.CommandContext ) -> sn.UnqualName: assert isinstance(astnode, ( qlast.CreateExtensionPackageMigration, qlast.DropExtensionPackageMigration, )) from_version = verutils.parse_version(astnode.from_version.value) to_version = verutils.parse_version(astnode.to_version.value) quals = ['pkg-migration', str(from_version), str(to_version)] pnn = sn.get_specialized_name(sn.UnqualName(astnode.name.name), *quals) return sn.UnqualName(pnn) class CreateExtensionPackageMigration( ExtensionPackageMigrationCommand, sd.CreateObject[ExtensionPackageMigration], ): astnode = qlast.CreateExtensionPackageMigration @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> CreateExtensionPackageMigration: if not context.stdmode and not context.testmode: raise errors.UnsupportedFeatureError( 'user-defined extension packages are not supported yet', span=astnode.span ) cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(cmd, CreateExtensionPackageMigration) assert isinstance(astnode, qlast.CreateExtensionPackageMigration) assert astnode.body.text is not None from_version = verutils.parse_version(astnode.from_version.value) cmd.set_attribute_value('from_version', from_version) to_version = verutils.parse_version(astnode.to_version.value) cmd.set_attribute_value('to_version', to_version) cmd.set_attribute_value('script', astnode.body.text) cmd.set_attribute_value('builtin', context.stdmode) if not cmd.has_attribute_value('internal'): 
cmd.set_attribute_value('internal', False) return cmd def _apply_field_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, op: sd.AlterObjectProperty, ) -> None: assert isinstance(node, qlast.CreateExtensionPackageMigration) if op.property == 'script': node.body = qlast.NestedQLBlock( text=op.new_value, commands=cast( list[qlast.DDLOperation], qlparser.parse_block(op.new_value)), ) elif op.property == 'from_version': node.from_version = qlast.Constant.string( value=str(op.new_value), ) elif op.property == 'to_version': node.to_version = qlast.Constant.string( value=str(op.new_value), ) else: super()._apply_field_ast(schema, context, node, op) class DeleteExtensionPackageMigration( ExtensionPackageMigrationCommand, sd.DeleteObject[ExtensionPackageMigration], ): astnode = qlast.DropExtensionPackageMigration def _delete_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: if ( not context.stdmode and not context.testmode and self.scls.get_builtin(schema) ): name = self.scls.get_shortname(schema) raise errors.UnsupportedFeatureError( f"cannot drop builtin extension package migration '{name}'", span=self.span, ) return super()._delete_begin(schema, context) class ExtensionCommandContext( sd.ObjectCommandContext[Extension], ): pass class ExtensionCommand( sd.ObjectCommand[Extension], context_class=ExtensionCommandContext, ): def _get_dependencies( self, pkg: ExtensionPackage, schema: s_schema.Schema, ) -> list[Extension]: deps = [] for dep_name in pkg.get_dependencies(schema): if '>=' not in dep_name: builtin = 'built-in ' if pkg.get_builtin(schema) else '' raise errors.SchemaError( f'{builtin}extension {self.classname} missing ' f'version for {dep_name}') dep_name, dep_version_s = dep_name.split('>=') dep = schema.get_global(Extension, dep_name, default=None) if not dep: raise errors.SchemaError( f'cannot create extension {self.get_displayname()!r} ' f'version {str(pkg.get_version(schema))!r}: ' 
f'it depends on extension {dep_name} which has not been' f' created' ) dep_version = verutils.parse_version(dep_version_s) real_version = dep.get_package(schema).get_version(schema) if dep_version > real_version: raise errors.SchemaError( f'cannot create extension {self.get_displayname()!r} ' f'version {str(pkg.get_version(schema))!r}: ' f'it depends on extension {dep_name}, but the wrong ' f'version is installed: {real_version} is present but ' f'{dep_version} is required' ) deps.append(dep) return deps @contextlib.contextmanager def _extension_mode(context: sd.CommandContext) -> Iterator[None]: # TODO: We'll want to be a bit more discriminating once we support # user extensions, and not set stable_ids then? stable_ids = context.stable_ids testmode = context.testmode declarative = context.declarative context.stable_ids = True context.testmode = True context.declarative = False try: yield finally: context.stable_ids = stable_ids context.testmode = testmode context.declarative = declarative class CreateExtension( ExtensionCommand, sd.CreateObject[Extension], ): astnode = qlast.CreateExtension def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: with _extension_mode(context): return super().apply(schema, context) @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext ) -> CreateExtension: assert isinstance(astnode, qlast.CreateExtension) cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(cmd, CreateExtension) if astnode.version is not None: parsed_version = verutils.parse_version(astnode.version.value) cmd.set_attribute_value('version', parsed_version) return cmd def _create_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._create_begin(schema, context) if not context.canonical: package = self.scls.get_package(schema) module = package.get_ext_module(schema) if module: module_name = 
sn.UnqualName(module) if module_name.get_root_module_name() != s_schema.EXT_MODULE: builtin = 'built-in ' if package.get_builtin(schema) else '' raise errors.SchemaError( f'{builtin}extension {self.classname} has invalid ' f'module "{module}": ' f'extension modules must begin with "ext::"' ) script = package.get_script(schema) if script: block, _ = qlparser.parse_extension_package_body_block(script) for subastnode in block.commands: subcmd = sd.compile_ddl( schema, subastnode, context=context) if subcmd is not None: self.add(subcmd) return schema def canonicalize_attributes( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().canonicalize_attributes(schema, context) pkg: ExtensionPackage if pkg_attr := self.get_attribute_value('package'): pkg = pkg_attr.resolve(schema) else: # If we're restoring a dump ignore the extension package version # as the current EdgeDB might have a different version available # and we don't have a way to select specific versions yet. # # Use `compat_ver` as a way to detect that we're working with a # dump rather than some other operation. if context.compat_ver is not None: version = None else: version = self.get_attribute_value('version') pkg = get_package(self.classname, version, schema) self.discard_attribute('version') self.set_attribute_value('package', pkg) deps = self._get_dependencies(pkg, schema) self.set_attribute_value('dependencies', deps) return schema # XXX: I think this is wrong, but it might not matter ever. 
def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: node = super()._get_ast(schema, context, parent_node=parent_node) assert isinstance(node, qlast.CreateExtension) pkg = self.get_resolved_attribute_value( 'package', schema=schema, context=context) # When performing dumps we don't want to include the extension version # as we're not guaranteed that the same version will be available when # restoring the dump. We also have no mechanism for installing a specific # extension version, yet. if context.include_ext_version: node.version = qlast.Constant.string( value=str(pkg.get_version(schema)) ) return node class AlterExtension( ExtensionCommand, sd.AlterObject[Extension], ): astnode = qlast.AlterExtension to_version = struct.Field(verutils.Version, default=None) migration = struct.Field(ExtensionPackageMigration, default=None) def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: with _extension_mode(context): return super().apply(schema, context) @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext ) -> AlterExtension: assert isinstance(astnode, qlast.AlterExtension) cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(cmd, AlterExtension) cmd.to_version = verutils.parse_version(astnode.to_version.value) return cmd def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: # HACK: AlterObject insists on filtering out any ALTERs # without subcommands, but we don't have any, so skip # AlterObject. 
return super(sd.AlterObject, self)._get_ast( schema, context, parent_node=parent_node ) def _apply_field_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, op: sd.AlterObjectProperty, ) -> None: if op.property == 'package': assert isinstance(node, qlast.AlterExtension) package = op.new_value.resolve(schema) node.to_version = qlast.Constant.string( str(package.get_version(schema)) ) else: super()._apply_field_ast(schema, context, node, op) def canonicalize_attributes( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().canonicalize_attributes(schema, context) pkg: ExtensionPackage if pkg_attr := self.get_attribute_value('package'): pkg = pkg_attr.resolve(schema) else: assert self.to_version pkg = get_package(self.classname, self.to_version, schema) self.set_attribute_value('package', pkg) deps = self._get_dependencies(pkg, schema) self.set_attribute_value('dependencies', deps) return schema def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: from_version = self.scls.get_package(schema).get_version(schema) schema = super()._alter_begin(schema, context) if not context.canonical: assert self.to_version if not self.migration: migrations = get_package_migrations( self.classname, from_version, self.to_version, schema ) else: migrations = [self.migration] if len(migrations) == 1: self.migration = migrations[0] script = self.migration.get_script(schema) if script: block, _ = qlparser.parse_extension_package_body_block( script) for subastnode in block.commands: subcmd = sd.compile_ddl( schema, subastnode, context=context) if subcmd is not None: self.add(subcmd) else: for migration in migrations: self.add(AlterExtension( classname=self.classname, to_version=migration.get_to_version(schema), migration=migration, )) return schema class DeleteExtension( ExtensionCommand, sd.DeleteObject[Extension], ): astnode = qlast.DropExtension def _delete_begin( 
self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: module = self.scls.get_package(schema).get_ext_module(schema) schema = super()._delete_begin(schema, context) if context.canonical or not module: return schema # If the extension included a module, delete everything in it. from . import ddl as s_ddl module_name = sn.UnqualName(module) def _name_in_mod(name: sn.Name) -> bool: return ( ( isinstance(name, sn.QualName) and ( name.module == module or name.module.startswith(module + '::') ) ) or ( isinstance(name, sn.UnqualName) and ( name == module_name or name.name.startswith(module + '::') ) ) ) # Clean up the casts separately because we can't keep them in # our own module, so we keep them in __ext_casts__. (Cast # names are derived solely from the names of their from and to # types, which means that if we have a cast between ext::a::T # and ext::b::S, we wouldn't have a way to distinguish which # module it should be in.) casts_cleanup: list[sd.Command] = [] for obj in schema.get_objects( included_modules=(sn.UnqualName('__ext_casts__'),), type=s_casts.Cast, ): if ( _name_in_mod(obj.get_from_type(schema).get_name(schema)) or _name_in_mod(obj.get_to_type(schema).get_name(schema)) ): drop = obj.init_delta_command( schema, sd.DeleteObject, ) casts_cleanup.append(drop) # Similarly, index matches are kept in __ext_index_matches__. We can # remove them first since nothing else depends on them. for im in schema.get_objects( included_modules=(sn.UnqualName('__ext_index_matches__'),), type=s_indexes.IndexMatch, ): if ( _name_in_mod(im.get_valid_type(schema).get_name(schema)) or _name_in_mod(im.get_index(schema).get_name(schema)) ): self.add(im.init_delta_command(schema, sd.DeleteObject)) def filt(schema: s_schema.Schema, obj: so.Object) -> bool: return not _name_in_mod(obj.get_name(schema)) or obj == self.scls # We handle deleting the module contents in a heavy-handed way: # do a schema diff. 
module_names = [ m.get_name(schema) for m in schema.get_modules() if _name_in_mod(m.get_name(schema)) ] delta = s_ddl.delta_schemas( schema, schema, included_modules=module_names, schema_b_filters=[filt], include_extensions=True, linearize_delta=True, ) # The output of delta_schemas is really just intended to be # dumped as an AST. So, sigh, just do that, and then read it # back. # # This is horrific, but it does actually work and is built # around codepaths that are heavily tested. from . import ddl for subast in ddl.ddlast_from_delta(None, schema, delta): # We want to clean the casts right before we're cleaning the # scalar types. Cleaning casts earlier may cause issues with # functions that use casts in their signatures as part of the # default expression. if casts_cleanup and isinstance(subast, qlast.DropScalarType): self.update(casts_cleanup) casts_cleanup.clear() self.add(sd.compile_ddl(schema, subast, context=context)) return schema def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: with _extension_mode(context): return super().apply(schema, context) ================================================ FILE: edb/schema/functions.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations import abc import types import uuid import builtins from typing import ( Any, Optional, TypeVar, Iterable, Mapping, Sequence, cast, TYPE_CHECKING, ) from edb import errors from edb.common import ast from edb.common import parsing from edb.common import struct from edb.common import verutils from edb.common import lru from edb.edgeql import ast as qlast from edb.edgeql import compiler as qlcompiler from edb.edgeql import qltypes as ft from edb.edgeql import parser as qlparser from edb.edgeql import qltypes from edb.common import uuidgen from . import annos as s_anno from . import delta as sd from . import expr as s_expr from . import globals as s_globals from . import name as sn from . import objects as so from . import permissions as s_permissions from . import referencing from . import types as s_types from . import utils from . import schema as s_schema if TYPE_CHECKING: from edb.edgeql.compiler import context as qlcontext from edb.ir import ast as irast FUNC_NAMESPACE = uuidgen.UUID('80cd3b19-bb51-4659-952d-6bb03e3347d7') def param_as_str( schema: s_schema.Schema, param: ParameterDesc | Parameter, ) -> str: ret = [] kind = param.get_kind(schema) typemod = param.get_typemod(schema) default = param.get_default(schema) if kind is not ft.ParameterKind.PositionalParam: ret.append(kind.to_edgeql()) ret.append(' ') ret.append(f'{param.get_parameter_name(schema)}: ') if typemod is not ft.TypeModifier.SingletonType: ret.append(typemod.to_edgeql()) ret.append(' ') paramt: s_types.Type | s_types.TypeShell[s_types.Type] if isinstance(param, ParameterDesc): paramt = param.get_type_shell(schema) else: paramt = param.get_type(schema) ret.append(paramt.get_displayname(schema)) if default is not None: ret.append(f'={default.text}') return ''.join(ret) def canonical_param_sort[ParameterLike_T: "ParameterLike"]( schema: s_schema.Schema, params: Iterable[ParameterLike_T], ) -> tuple[ParameterLike_T, ...]: canonical_order = [] named = [] variadic 
= None for param in params: param_kind = param.get_kind(schema) if param_kind is ft.ParameterKind.PositionalParam: canonical_order.append(param) elif param_kind is ft.ParameterKind.NamedOnlyParam: named.append(param) else: variadic = param if variadic is not None: canonical_order.append(variadic) if named: named.sort(key=lambda p: p.get_name(schema)) named.extend(canonical_order) canonical_order = named return tuple(canonical_order) def param_is_inherited( schema: s_schema.Schema, func: CallableObject, param: ParameterLike, ) -> bool: qualname = sn.get_specialized_name( sn.UnqualName(param.get_parameter_name(schema)), str(func.get_name(schema)), ) param_name = param.get_name(schema) assert isinstance(param_name, sn.QualName) return qualname != param_name.name class ParameterLike: def get_parameter_name(self, schema: s_schema.Schema) -> str: raise NotImplementedError def get_name(self, schema: s_schema.Schema) -> sn.Name: raise NotImplementedError def get_kind(self, _: s_schema.Schema) -> ft.ParameterKind: raise NotImplementedError def get_default(self, _: s_schema.Schema) -> Optional[s_expr.Expression]: raise NotImplementedError def get_type(self, _: s_schema.Schema) -> s_types.Type: raise NotImplementedError def get_typemod(self, _: s_schema.Schema) -> ft.TypeModifier: raise NotImplementedError def as_str(self, schema: s_schema.Schema) -> str: raise NotImplementedError # Non-schema description of a parameter. 
class ParameterDesc(ParameterLike): num: int name: sn.Name default: Optional[s_expr.Expression] type: s_types.TypeShell[s_types.Type] typemod: ft.TypeModifier kind: ft.ParameterKind def __init__( self, *, num: int, name: sn.Name, default: Optional[s_expr.Expression], type: s_types.TypeShell[s_types.Type], typemod: ft.TypeModifier, kind: ft.ParameterKind, ) -> None: self.num = num self.name = name self.default = default self.type = type self.typemod = typemod self.kind = kind @classmethod def from_ast( cls, schema: s_schema.Schema, modaliases: Mapping[Optional[str], str], num: int, astnode: qlast.FuncParamDecl, ) -> ParameterDesc: paramd = None if astnode.default is not None: paramd = s_expr.Expression.from_ast( astnode.default, schema, modaliases, as_fragment=True) paramt_ast = astnode.type if astnode.kind is ft.ParameterKind.VariadicParam: paramt_ast = qlast.TypeName( maintype=qlast.ObjectRef( name='array', ), subtypes=[paramt_ast], ) paramt = utils.ast_to_type_shell( paramt_ast, metaclass=s_types.Type, modaliases=modaliases, schema=schema, ) return cls( num=num, name=sn.UnqualName(astnode.name), type=paramt, typemod=astnode.typemod, kind=astnode.kind, default=paramd ) def get_parameter_name(self, schema: s_schema.Schema) -> str: return str(self.name) def get_name(self, schema: s_schema.Schema) -> sn.Name: return self.name def get_kind(self, _: s_schema.Schema) -> ft.ParameterKind: return self.kind def get_default(self, _: s_schema.Schema) -> Optional[s_expr.Expression]: return self.default def get_type(self, schema: s_schema.Schema) -> s_types.Type: return self.type.resolve(schema) def get_type_shell( self, schema: s_schema.Schema, ) -> s_types.TypeShell[s_types.Type]: return self.type def get_typemod(self, _: s_schema.Schema) -> ft.TypeModifier: return self.typemod def as_str(self, schema: s_schema.Schema) -> str: return param_as_str(schema, self) @classmethod def from_create_delta( cls, schema: s_schema.Schema, context: sd.CommandContext, cmd: CreateParameter, 
) -> tuple[s_schema.Schema, ParameterDesc]: props = cmd.get_attributes(schema, context) props['name'] = Parameter.paramname_from_fullname(props['name']) if not isinstance(props['type'], s_types.TypeShell): paramt = props['type'].as_shell(schema) else: paramt = props['type'] return schema, cls( num=props['num'], name=props['name'], type=paramt, typemod=props['typemod'], kind=props['kind'], default=props.get('default'), ) def get_fqname( self, schema: s_schema.Schema, func_fqname: sn.QualName, ) -> sn.QualName: return sn.QualName( func_fqname.module, sn.get_specialized_name(self.get_name(schema), str(func_fqname)) ) def as_create_delta( self, schema: s_schema.Schema, func_fqname: sn.QualName, *, context: sd.CommandContext, ) -> sd.CreateObject[Parameter]: CreateParameter = sd.get_object_command_class_or_die( sd.CreateObject, Parameter) param_name = self.get_fqname(schema, func_fqname) cmd = CreateParameter(classname=param_name) cmd.set_attribute_value('name', param_name) cmd.set_attribute_value('type', self.type) for attr in ('num', 'typemod', 'kind', 'default'): cmd.set_attribute_value(attr, getattr(self, attr)) return cmd def _params_are_all_required_singletons( params: Sequence[ParameterLike], schema: s_schema.Schema, ) -> bool: return all( param.get_kind(schema) is not ft.ParameterKind.VariadicParam and param.get_typemod(schema) is ft.TypeModifier.SingletonType and param.get_default(schema) is None for param in params ) def make_func_param( *, name: str, type: qlast.TypeExpr, typemod: qltypes.TypeModifier = qltypes.TypeModifier.SingletonType, kind: qltypes.ParameterKind, default: Optional[qlast.Expr] = None, ) -> qlast.FuncParamDecl: # If the param is variadic, strip the array from the type in the schema if kind is ft.ParameterKind.VariadicParam: assert ( isinstance(type, qlast.TypeName) and isinstance(type.maintype, qlast.ObjectRef) and type.maintype.name == 'array' and type.subtypes ) type = type.subtypes[0] return qlast.FuncParamDecl( name=name, type=type, 
typemod=typemod, kind=kind, default=default, ) class Parameter( so.ObjectFragment, so.Object, # Help reflection figure out the right db MRO ParameterLike, qlkind=ft.SchemaObjectClass.PARAMETER, data_safe=True, ): num = so.SchemaField( int, compcoef=0.4) default = so.SchemaField( s_expr.Expression, default=None, compcoef=0.4) type = so.SchemaField( s_types.Type, compcoef=0.4) typemod = so.SchemaField( ft.TypeModifier, default=ft.TypeModifier.SingletonType, coerce=True, compcoef=0.4) kind = so.SchemaField( ft.ParameterKind, coerce=True, compcoef=0.4) @classmethod def paramname_from_fullname(cls, fullname: sn.Name) -> str: parts = str(fullname.name).split('@', 1) if len(parts) == 2: return sn.unmangle_name(parts[0]) else: return fullname.name def get_verbosename( self, schema: s_schema.Schema, *, with_parent: bool = False, ) -> str: vn = super().get_verbosename(schema) if with_parent: pfns = [r for r in schema.get_referrers(self) if isinstance(r, CallableObject)] if pfns: pvn = pfns[0].get_verbosename(schema, with_parent=True) return f'{vn} of {pvn}' else: return vn else: return vn @classmethod def get_shortname_static(cls, name: sn.Name) -> sn.QualName: assert isinstance(name, sn.QualName) return sn.QualName( module='__', name=cls.paramname_from_fullname(name), ) @classmethod def get_displayname_static(cls, name: sn.Name) -> str: shortname = cls.get_shortname_static(name) return shortname.name def get_parameter_name(self, schema: s_schema.Schema) -> str: fullname = self.get_name(schema) return self.paramname_from_fullname(fullname) def get_ir_default( self, *, schema: s_schema.Schema, context: sd.CommandContext, ) -> irast.Statement: from edb.ir import utils as irutils defexpr = self.get_default(schema) assert defexpr is not None defexpr = defexpr.compiled( as_fragment=True, schema=schema, context=context, ) ir = defexpr.irast if not irutils.is_const(ir.expr): raise ValueError('expression not constant') return ir def as_str(self, schema: s_schema.Schema) -> str: 
return param_as_str(schema, self) @classmethod def compare_field_value[T]( cls, field: so.Field[builtins.type[T]], our_value: T, their_value: T, *, our_schema: s_schema.Schema, their_schema: s_schema.Schema, context: so.ComparisonContext, ) -> float: # Only compare the actual param name, not the full name. if field.name == 'name': assert isinstance(our_value, sn.Name) assert isinstance(their_value, sn.Name) if ( cls.paramname_from_fullname(our_value) == cls.paramname_from_fullname(their_value) ): return 1.0 return super().compare_field_value( field, our_value, their_value, our_schema=our_schema, their_schema=their_schema, context=context, ) def get_ast(self, schema: s_schema.Schema) -> qlast.FuncParamDecl: default = self.get_default(schema) kind = self.get_kind(schema) return make_func_param( name=self.get_parameter_name(schema), type=utils.typeref_to_ast(schema, self.get_type(schema)), typemod=self.get_typemod(schema), kind=kind, default=default.parse() if default else None, ) class CallableCommandContext(sd.ObjectCommandContext['CallableObject'], s_anno.AnnotationSubjectCommandContext): pass class ParameterCommandContext(sd.ObjectCommandContext[Parameter]): pass # type ignore below, because making Parameter # a referencing.ReferencedObject breaks the code class ParameterCommand( referencing.ReferencedObjectCommandBase[Parameter], # type: ignore context_class=ParameterCommandContext, referrer_context_class=CallableCommandContext ): is_strong_ref = struct.Field(bool, default=True) def get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: # ParameterCommand cannot have its own AST because it is a # side-effect of a FunctionCommand. 
return None def canonicalize_attributes( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().canonicalize_attributes(schema, context) return s_types.materialize_type_in_attribute( schema, context, self, 'type') def compile_expr_field( self, schema: s_schema.Schema, context: sd.CommandContext, field: so.Field[Any], value: s_expr.Expression, track_schema_ref_exprs: bool=False, ) -> s_expr.CompiledExpression: if field.name == 'default': return value.compiled( schema=schema, as_fragment=True, options=qlcompiler.CompilerOptions( modaliases=context.modaliases, schema_object_context=self.get_schema_metaclass(), apply_query_rewrites=not context.stable_ids, track_schema_ref_exprs=track_schema_ref_exprs, ), context=context, ) else: return super().compile_expr_field( schema, context, field, value, track_schema_ref_exprs) def get_dummy_expr_field_value( self, schema: s_schema.Schema, context: sd.CommandContext, field: so.Field[Any], value: Any, ) -> Optional[s_expr.Expression]: if field.name == 'default': type = self.scls.get_type(schema) return s_types.type_dummy_expr(type, schema) else: raise NotImplementedError(f'unhandled field {field.name!r}') class CreateParameter(ParameterCommand, sd.CreateObject[Parameter]): @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: cmd = super()._cmd_tree_from_ast(schema, astnode, context) for sub in cmd.get_subcommands(type=sd.AlterObjectProperty): if sub.property == 'default': sub.new_value = [sub.new_value] return cmd class DeleteParameter(ParameterCommand, sd.DeleteObject[Parameter]): def _delete_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._delete_begin(schema, context) if not context.canonical: typ = self.scls.get_type(schema) if op := typ.as_type_delete_if_unused(schema): self.add_caused(op) return schema class 
RenameParameter(ParameterCommand, sd.RenameObject[Parameter]): pass class AlterParameter(ParameterCommand, sd.AlterObject[Parameter]): pass class ParameterLikeList(abc.ABC): @abc.abstractmethod def get_by_name( self, schema: s_schema.Schema, name: str, ) -> Optional[ParameterLike]: raise NotImplementedError @abc.abstractmethod def as_str(self, schema: s_schema.Schema) -> str: raise NotImplementedError @abc.abstractmethod def has_polymorphic(self, schema: s_schema.Schema) -> bool: raise NotImplementedError @abc.abstractmethod def has_set_of(self, schema: s_schema.Schema) -> bool: raise NotImplementedError @abc.abstractmethod def has_objects(self, schema: s_schema.Schema) -> bool: raise NotImplementedError @abc.abstractmethod def find_named_only( self, schema: s_schema.Schema, ) -> Mapping[str, ParameterLike]: raise NotImplementedError @abc.abstractmethod def find_variadic( self, schema: s_schema.Schema, ) -> Optional[ParameterLike]: raise NotImplementedError @abc.abstractmethod def has_required_params( self, schema: s_schema.Schema, ) -> bool: raise NotImplementedError @abc.abstractmethod def objects( self, schema: s_schema.Schema, ) -> tuple[ParameterLike, ...]: raise NotImplementedError @abc.abstractmethod def get_in_canonical_order( self, schema: s_schema.Schema, ) -> tuple[ParameterLike, ...]: raise NotImplementedError class FuncParameterList(so.ObjectList[Parameter], ParameterLikeList): def get_by_name( self, schema: s_schema.Schema, name: str, ) -> Optional[Parameter]: for param in self.objects(schema): if param.get_parameter_name(schema) == name: return param return None def as_str(self, schema: s_schema.Schema) -> str: ret = [] for param in self.objects(schema): ret.append(param.as_str(schema)) return '(' + ', '.join(ret) + ')' def has_polymorphic(self, schema: s_schema.Schema) -> bool: return any( p.get_type(schema).is_polymorphic(schema) for p in self.objects(schema) ) def has_type_mod( self, schema: s_schema.Schema, mod: ft.TypeModifier ) -> bool: return 
any(p.get_typemod(schema) is mod for p in self.objects(schema)) def has_set_of(self, schema: s_schema.Schema) -> bool: return self.has_type_mod(schema, ft.TypeModifier.SetOfType) def has_objects(self, schema: s_schema.Schema) -> bool: return any( p.get_type(schema).is_object_type() for p in self.objects(schema) ) def find_named_only( self, schema: s_schema.Schema, ) -> Mapping[str, Parameter]: named = {} for param in self.objects(schema): if param.get_kind(schema) is ft.ParameterKind.NamedOnlyParam: named[param.get_parameter_name(schema)] = param return types.MappingProxyType(named) def find_variadic(self, schema: s_schema.Schema) -> Optional[Parameter]: for param in self.objects(schema): if param.get_kind(schema) is ft.ParameterKind.VariadicParam: return param return None def has_required_params(self, schema: s_schema.Schema) -> bool: return any( param.get_kind(schema) is not ft.ParameterKind.VariadicParam and param.get_default(schema) is None for param in self.objects(schema) ) def get_in_canonical_order( self, schema: s_schema.Schema, ) -> tuple[Parameter, ...]: return canonical_param_sort(schema, self.objects(schema)) def get_ast(self, schema: s_schema.Schema) -> list[qlast.FuncParamDecl]: result = [] for param in self.objects(schema): result.append(param.get_ast(schema)) return result @classmethod def compare_values( cls, ours_o: so.ObjectCollection[Parameter], theirs_o: so.ObjectCollection[Parameter], *, our_schema: s_schema.Schema, their_schema: s_schema.Schema, context: so.ComparisonContext, compcoef: float, ) -> float: ours = list(ours_o.objects(our_schema)) theirs = list(theirs_o.objects(their_schema)) # Because parameter lists can't currently be ALTERed, any # changes are catastrophic, so return compcoef on any mismatch # at all. 
if len(ours) != len(theirs): return compcoef for param1, param2 in zip(ours, theirs): coef = param1.compare( param2, our_schema=our_schema, their_schema=their_schema, context=context) if coef != 1.0: return compcoef return 1.0 class VolatilitySubject(so.Object): volatility = so.SchemaField( ft.Volatility, default=ft.Volatility.Volatile, compcoef=0.4, coerce=True, allow_ddl_set=True) class CallableLike: """A minimal callable object interface required by multidispatch.""" def has_inlined_defaults(self, schema: s_schema.Schema) -> bool: raise NotImplementedError def get_params(self, schema: s_schema.Schema) -> ParameterLikeList: raise NotImplementedError def get_return_type(self, schema: s_schema.Schema) -> s_types.Type: raise NotImplementedError def get_return_typemod(self, schema: s_schema.Schema) -> ft.TypeModifier: raise NotImplementedError def get_signature_as_str(self, schema: s_schema.Schema) -> str: raise NotImplementedError def get_verbosename(self, schema: s_schema.Schema) -> str: raise NotImplementedError def get_abstract(self, schema: s_schema.Schema) -> bool: raise NotImplementedError CallableObjectT = TypeVar('CallableObjectT', bound='CallableObject') class CallableObject( so.QualifiedObject, s_anno.AnnotationSubject, CallableLike, ): params = so.SchemaField( FuncParameterList, coerce=True, compcoef=0.4, default=so.DEFAULT_CONSTRUCTOR, inheritable=False, simpledelta=False) return_type = so.SchemaField( s_types.Type, compcoef=0.2) return_typemod = so.SchemaField( ft.TypeModifier, compcoef=0.4, coerce=True) abstract = so.SchemaField( bool, default=False, inheritable=False, compcoef=0.909) impl_is_strict = so.SchemaField( bool, default=True, compcoef=0.4) # Kind of a hack: indicates that when possible we should pass arguments # to this function as a subquery-as-an-expression. This is important for # functions that see use in ORDER BY clauses that need indexes. 
# The compilation strategy this asks for /should/ work in general, # but I didn't want to make a major codegen change in an rc3. # We should consider doing this a different way. prefer_subquery_args = so.SchemaField( bool, default=False, compcoef=0.9) # Some set of calls are allowed in singleton expressions is_singleton_set_of = so.SchemaField( bool, default=False, compcoef=0.4) def as_create_delta( self: CallableObjectT, schema: s_schema.Schema, context: so.ComparisonContext, ) -> sd.CreateObject[CallableObjectT]: delta = super().as_create_delta(schema, context) new_params = self.get_params(schema).objects(schema) for p in new_params: if not param_is_inherited(schema, self, p): delta.add_prerequisite( p.as_create_delta(schema=schema, context=context), ) return delta def as_alter_delta( self: CallableObjectT, other: CallableObjectT, *, self_schema: s_schema.Schema, other_schema: s_schema.Schema, confidence: float, context: so.ComparisonContext, ) -> sd.ObjectCommand[CallableObjectT]: delta = super().as_alter_delta( other, self_schema=self_schema, other_schema=other_schema, confidence=confidence, context=context, ) old_params = self.get_params(self_schema).objects(self_schema) oldcoll = [ p for p in old_params if not param_is_inherited(self_schema, self, p) ] new_params = other.get_params(other_schema).objects(other_schema) newcoll = [ p for p in new_params if not param_is_inherited(other_schema, other, p) ] delta.add_prerequisite( sd.delta_objects( oldcoll, newcoll, sclass=Parameter, context=context, old_schema=self_schema, new_schema=other_schema, ), ) return delta def as_delete_delta( self: CallableObjectT, *, schema: s_schema.Schema, context: so.ComparisonContext, ) -> sd.ObjectCommand[CallableObjectT]: delta = super().as_delete_delta(schema=schema, context=context) old_params = self.get_params(schema).objects(schema) for p in old_params: if not param_is_inherited(schema, self, p): delta.add(p.as_delete_delta(schema=schema, context=context)) return delta 
@classmethod def _get_fqname_quals( cls, schema: s_schema.Schema, params: list[ParameterDesc], ) -> tuple[str, ...]: quals: list[str] = [] canonical_order = canonical_param_sort(schema, params) for param in canonical_order: pt = param.get_type_shell(schema) pt_id = str(pt.get_name(schema)) quals.append(pt_id) pk = param.get_kind(schema) if pk is ft.ParameterKind.NamedOnlyParam: quals.append(f'$NO-{param.get_name(schema)}-{pt_id}$') elif pk is ft.ParameterKind.VariadicParam: quals.append('$V$') return tuple(quals) @classmethod def get_fqname( cls, schema: s_schema.Schema, shortname: sn.QualName, params: list[ParameterDesc], *extra_quals: str, ) -> sn.QualName: quals = cls._get_fqname_quals(schema, params) return sn.QualName( module=shortname.module, name=sn.get_specialized_name(shortname, *(quals + extra_quals))) def has_inlined_defaults(self, schema: s_schema.Schema) -> bool: return False def is_blocking_ref( self, schema: s_schema.Schema, reference: so.Object, ) -> bool: # Parameters cannot be deleted via DDL syntax, # so the only possible scenario is the deletion of # the host function.
return not isinstance(reference, Parameter) class ParametrizedCommand(sd.ObjectCommand[so.Object_T]): def _get_params( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> FuncParameterList: params = self.get_attribute_value('params') result: Any if params is None: param_list = [] for cr_param in self.get_subcommands(type=ParameterCommand): param = schema.get(cr_param.classname, type=Parameter) param_list.append(param) result = FuncParameterList.create(schema, param_list) elif isinstance(params, so.ObjectCollectionShell): result = params.resolve(schema) else: result = params assert isinstance(result, FuncParameterList) return result @classmethod def _get_param_desc_from_params_ast( cls, schema: s_schema.Schema, modaliases: Mapping[Optional[str], str], params: list[qlast.FuncParamDecl], *, param_offset: int=0, ) -> list[ParameterDesc]: return [ ParameterDesc.from_ast(schema, modaliases, num, param) for num, param in enumerate(params, param_offset) ] @classmethod def _get_param_desc_from_ast( cls, schema: s_schema.Schema, modaliases: Mapping[Optional[str], str], astnode: qlast.ObjectDDL, *, param_offset: int=0, ) -> list[ParameterDesc]: if not hasattr(astnode, 'params'): # Some Callables, like the concrete constraints, # have no params in their AST. 
return [] assert isinstance(astnode, qlast.CallableObjectCommand) return cls._get_param_desc_from_params_ast( schema, modaliases, astnode.params, param_offset=param_offset) @classmethod def _get_param_desc_from_delta( cls, schema: s_schema.Schema, context: sd.CommandContext, cmd: sd.Command, ) -> tuple[s_schema.Schema, list[ParameterDesc]]: params = [] for subcmd in cmd.get_subcommands(type=CreateParameter): schema, param = ParameterDesc.from_create_delta( schema, context, subcmd) params.append(param) return schema, params class CallableCommand(sd.QualifiedObjectCommand[CallableObjectT], ParametrizedCommand[CallableObjectT]): def canonicalize_attributes( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().canonicalize_attributes(schema, context) return s_types.materialize_type_in_attribute( schema, context, self, 'return_type') class RenameCallableObject( CallableCommand[CallableObjectT], sd.RenameObject[CallableObjectT], ): def _canonicalize( self, schema: s_schema.Schema, context: sd.CommandContext, scls: CallableObjectT, ) -> None: super()._canonicalize(schema, context, scls) # Don't do anything for concrete constraints if not isinstance(scls, Function) and not scls.get_abstract(schema): return # params don't get picked up by the base _canonicalize because # they aren't RefDicts (and use a different mangling scheme to # boot), so we need to do it ourselves. 
param_list = scls.get_params(schema) params = CallableCommand._get_param_desc_from_params_ast( schema, context.modaliases, param_list.get_ast(schema)) assert isinstance(self.new_name, sn.QualName) for dparam, oparam in zip(params, param_list.objects(schema)): self.add(self.init_rename_branch( oparam, dparam.get_fqname(schema, self.new_name), schema=schema, context=context, )) class AlterCallableObject( CallableCommand[CallableObjectT], sd.AlterObject[CallableObjectT], ): def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.CallableObjectCommand]: node = cast( Optional[qlast.CallableObjectCommand], # Skip AlterObject's _get_ast, since we don't want to # filter things without subcommands. (Since updating # nativecode isn't a subcommand in the AST.) super(sd.AlterObject, self)._get_ast( schema, context, parent_node=parent_node) ) if not node: return None scls = self.get_object(schema, context) node.params = scls.get_params(schema).get_ast(schema) return node def _alter_innards( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._alter_innards(schema, context) for op in self.get_subcommands(metaclass=Parameter): schema = op.apply(schema, context=context) return schema class CreateCallableObject( CallableCommand[CallableObjectT], sd.CreateObject[CallableObjectT], ): @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(astnode, qlast.CreateObject) assert isinstance(cmd, CreateCallableObject) params = cls._get_param_desc_from_ast( schema, context.modaliases, astnode) for param in params: # as_create_delta requires the specific type cmd.add_prerequisite(param.as_create_delta( schema, cmd.classname, context=context)) if hasattr(astnode, 'returning'): assert 
isinstance(astnode, (qlast.CreateOperator, qlast.CreateFunction)) modaliases = context.modaliases return_type = utils.ast_to_type_shell( astnode.returning, metaclass=s_types.Type, modaliases=modaliases, module=cmd.classname.module, schema=schema, ) cmd.set_attribute_value( 'return_type', return_type) cmd.set_attribute_value( 'return_typemod', astnode.returning_typemod) return cmd def get_resolved_attributes( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> dict[str, Any]: params = self._get_params(schema, context) props = super().get_resolved_attributes(schema, context) props['params'] = params return props def _skip_param(self, props: dict[str, Any]) -> bool: return False def _get_params_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, ) -> list[tuple[int, qlast.FuncParamDecl]]: params: list[tuple[int, qlast.FuncParamDecl]] = [] for op in self.get_subcommands(type=ParameterCommand): props = op.get_resolved_attributes(schema, context) if self._skip_param(props): continue num: int = props['num'] default: Optional[s_expr.Expression] = props.get('default') param = make_func_param( name=Parameter.paramname_from_fullname(props['name']), type=utils.typeref_to_ast(schema, props['type']), typemod=props['typemod'], kind=props['kind'], default=default.parse() if default is not None else None, ) params.append((num, param)) params.sort(key=lambda e: e[0]) return params def _apply_fields_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, ) -> None: super()._apply_fields_ast(schema, context, node) params = self._get_params_ast(schema, context, node) if isinstance(node, qlast.CallableObjectCommand): node.params = [p[1] for p in params] class DeleteCallableObject( CallableCommand[CallableObjectT], sd.DeleteObject[CallableObjectT], ): def _delete_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema =
super()._delete_begin(schema, context) scls = self.scls if ( not context.canonical # Don't do anything for concrete constraints and (isinstance(scls, Function) or scls.get_abstract(schema)) ): for param in scls.get_params(schema).objects(schema): self.add(param.init_delta_command(schema, sd.DeleteObject)) return_type = scls.get_return_type(schema) if op := return_type.as_type_delete_if_unused(schema): self.add_caused(op) return schema class Function( CallableObject, VolatilitySubject, qlkind=ft.SchemaObjectClass.FUNCTION, data_safe=True, ): used_globals = so.SchemaField( so.ObjectSet[s_globals.Global], coerce=True, default=so.DEFAULT_CONSTRUCTOR, inheritable=False ) used_permissions = so.SchemaField( so.ObjectSet[s_permissions.Permission], coerce=True, default=so.DEFAULT_CONSTRUCTOR, inheritable=False, ) required_permissions = so.SchemaField( so.ObjectSet[s_permissions.Permission], coerce=True, default=so.DEFAULT_CONSTRUCTOR, inheritable=False, allow_ddl_set=True, compcoef=0.8, ) # A backend_name that is shared between all overloads of the same # function, to make them independent from the actual name. 
backend_name = so.SchemaField( uuid.UUID, default=None, ) code = so.SchemaField( str, default=None, compcoef=0.4) # Function body, when language is EdgeQL nativecode = so.SchemaField( s_expr.Expression, default=None, compcoef=0.9, reflection_name='body') language = so.SchemaField( qlast.Language, default=None, compcoef=0.4, coerce=True, reflection_name='language_real') reflected_language = so.SchemaField( str, reflection_name='language') from_function = so.SchemaField( str, default=None, compcoef=0.4) from_expr = so.SchemaField( bool, default=False, compcoef=0.4) force_return_cast = so.SchemaField( bool, default=False, compcoef=0.9) sql_func_has_out_params = so.SchemaField( bool, default=False, compcoef=0.9) error_on_null_result = so.SchemaField( str, default=None, compcoef=0.9) #: For a generic function, if True, indicates that the #: optionality of the result set should be the same as #: of the generic argument. (See std::assert_single). preserves_optionality = so.SchemaField( bool, default=False, compcoef=0.99) #: For a generic function, if True, indicates that the #: upper cardinality of the result set should be the same as #: of the generic argument. (See std::assert_exists). preserves_upper_cardinality = so.SchemaField( bool, default=False, compcoef=0.99) initial_value = so.SchemaField( s_expr.Expression, default=None, compcoef=0.4, coerce=True) # This flag indicates that this function is intended to be used as # a generic fallback implementation for a particular polymorphic # function. The fallback implementation is exempted from the # limitation that all polymorphic functions have to map to the # same function in Postgres. There can only be at most one # fallback implementation for any given polymorphic function. # # The flag is intended for internal use for standard library # functions. 
fallback = so.SchemaField( bool, default=False, inheritable=False, compcoef=0.909, ) is_inlined = so.SchemaField(bool, default=False) # A json string which describes any server param conversions to apply. # # The data should take the form: dict[str, str | list[str]] # # The key should be the names of the converted params. # The value should be either: the conversion name, or a list of strings # where the first item is the name of the conversion. # # If the value is a list, the additional items act as parameters to the # conversion. server_param_conversions = so.SchemaField( str, default=None, compcoef=0.0, # HACK: We don't actually allow users to set this in DDL, but # we want to do it in one of our test suite schemas, and # unless we set allow_ddl_set, it won't get DESCRIBEd # correctly, which breaks patch and upgrade testing. # So we do this check explicitly. allow_ddl_set=True, ) def has_inlined_defaults(self, schema: s_schema.Schema) -> bool: # This can be relaxed to just `language is EdgeQL` when we # support non-constant defaults. return bool(self.get_language(schema) is qlast.Language.EdgeQL and self.get_params(schema).find_named_only(schema)) def get_signature_as_str( self, schema: s_schema.Schema, ) -> str: params = self.get_params(schema) sn = self.get_shortname(schema) return f"{sn}{params.as_str(schema)}" def get_verbosename( self, schema: s_schema.Schema, *, with_parent: bool=False, ) -> str: return f"function '{self.get_signature_as_str(schema)}'" def find_object_param_overloads( self, schema: s_schema.Schema, *, span: Optional[parsing.Span] = None, ) -> Optional[tuple[list[Function], int]]: """Find if this function overloads another in object parameter. If so, check the following rules: - in the signatures of functions, only the overloaded object parameter must differ, the number and the types of other parameters must be the same across all object-overloaded functions; - the names of arguments in object-overloaded functions must match. 
If there are object overloads, return a tuple containing the list of overloaded functions and the position of the overloaded parameter. """ params = self.get_params(schema) if not params.has_objects(schema): return None new_params = params.objects(schema) new_pt = tuple(p.get_type(schema) for p in new_params) diff_param = -1 overloads = [] sn = self.get_shortname(schema) for f in lookup_functions(sn, schema=schema): if f == self: continue f_params = f.get_params(schema) if not f_params.has_objects(schema): continue ext_params = f_params.objects(schema) ext_pt = (p.get_type(schema) for p in ext_params) this_diff_param = -1 non_obj_param_diff = False multi_overload = False for i, (new_t, ext_t) in enumerate(zip(new_pt, ext_pt)): if new_t != ext_t: if new_t.is_object_type() and ext_t.is_object_type(): if ( this_diff_param != -1 or ( this_diff_param != -1 and diff_param != -1 and diff_param != this_diff_param ) or non_obj_param_diff ): multi_overload = True break else: this_diff_param = i else: non_obj_param_diff = True if this_diff_param != -1: multi_overload = True break if this_diff_param != -1: if not multi_overload: multi_overload = len(new_params) != len(ext_params) if multi_overload: # Multiple dispatch of object-taking functions is # not supported. my_sig = self.get_signature_as_str(schema) other_sig = f.get_signature_as_str(schema) raise errors.UnsupportedFeatureError( f'cannot create the `{my_sig}` function: ' f'overloading an object type-receiving ' f'function with differences in the remaining ' f'parameters is not supported', span=span, details=( f"Other function is defined as `{other_sig}`" ) ) if not all( new_p.get_parameter_name(schema) == ext_p.get_parameter_name(schema) for new_p, ext_p in zip(new_params, ext_params) ): # And also _all_ parameter names must match due to # current implementation constraints. 
                    my_sig = self.get_signature_as_str(schema)
                    other_sig = f.get_signature_as_str(schema)
                    raise errors.UnsupportedFeatureError(
                        f'cannot create the `{my_sig}` '
                        f'function: overloading an object type-receiving '
                        f'function with differences in the names of '
                        f'parameters is not supported',
                        span=span,
                        details=(
                            f"Other function is defined as `{other_sig}`"
                        )
                    )

                if not all(
                    new_p.get_typemod(schema)
                    == ext_p.get_typemod(schema)
                    for new_p, ext_p in zip(new_params, ext_params)
                ):
                    # And also _all_ parameter type modifiers must match
                    # due to current implementation constraints.
                    my_sig = self.get_signature_as_str(schema)
                    other_sig = f.get_signature_as_str(schema)
                    raise errors.UnsupportedFeatureError(
                        f'cannot create the `{my_sig}` '
                        f'function: overloading an object type-receiving '
                        f'function with differences in the type modifiers of '
                        f'parameters is not supported',
                        span=span,
                        details=(
                            f"Other function is defined as `{other_sig}`"
                        )
                    )

                if (
                    new_params[this_diff_param].get_typemod(schema)
                    != ft.TypeModifier.SingletonType
                ):
                    my_sig = self.get_signature_as_str(schema)
                    raise errors.UnsupportedFeatureError(
                        f'cannot create the `{my_sig}` function: '
                        f'object type-receiving '
                        f'functions may not be overloaded on an OPTIONAL '
                        f'parameter',
                        span=span,
                    )

                diff_param = this_diff_param
                overloads.append(f)

        if diff_param == -1:
            return None
        else:
            return (overloads, diff_param)


class FunctionCommandContext(CallableCommandContext):
    pass


class FunctionCommand(
    CallableCommand[Function],
    context_class=FunctionCommandContext,
):

    @classmethod
    def _classname_from_ast(
        cls,
        schema: s_schema.Schema,
        astnode: qlast.ObjectDDL,
        context: sd.CommandContext,
    ) -> sn.QualName:
        # _classname_from_ast signature expects qlast.ObjectDDL,
        # but _get_param_desc_from_ast expects a ObjectDDL,
        # which is more specific
        assert isinstance(astnode, qlast.ObjectDDL)
        name = super()._classname_from_ast(schema, astnode, context)

        params = cls._get_param_desc_from_ast(
            schema, context.modaliases, astnode)

        return
cls.get_schema_metaclass().get_fqname(schema, name, params) def get_ast_attr_for_field( self, field: str, astnode: type[qlast.DDLOperation], ) -> Optional[str]: if field == 'nativecode': return 'nativecode' else: return super().get_ast_attr_for_field(field, astnode) def validate_object( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: if not context.stdmode and not context.testmode: if self.scls.get_server_param_conversions(schema): raise errors.InvalidFunctionDefinitionError( f'setting server_param_conversions is not supported in ' f'user-defined functions', span=self.span) def compile_expr_field( self, schema: s_schema.Schema, context: sd.CommandContext, field: so.Field[Any], value: s_expr.Expression, track_schema_ref_exprs: bool=False, ) -> s_expr.CompiledExpression: if field.name == 'initial_value': return value.compiled( schema=schema, options=qlcompiler.CompilerOptions( allow_generic_type_output=True, schema_object_context=self.get_schema_metaclass(), apply_query_rewrites=not context.stdmode, track_schema_ref_exprs=track_schema_ref_exprs, ), context=context, ) elif field.name == 'nativecode': return self.compile_this_function( schema, context, value, track_schema_ref_exprs, ) else: return super().compile_expr_field( schema, context, field, value, track_schema_ref_exprs) def get_dummy_expr_field_value( self, schema: s_schema.Schema, context: sd.CommandContext, field: so.Field[Any], value: Any, ) -> Optional[s_expr.Expression]: if field.name == 'nativecode': func = schema.get(self.classname, type=Function) rt = func.get_return_type(schema) return s_types.type_dummy_expr(rt, schema) else: raise NotImplementedError(f'unhandled field {field.name!r}') def _get_attribute_value( self, schema: s_schema.Schema, context: sd.CommandContext, name: str, ) -> Any: val = self.get_resolved_attribute_value( name, schema=schema, context=context, ) mcls = self.get_schema_metaclass() if val is None: field = mcls.get_field(name) assert isinstance(field, 
so.SchemaField) val = field.default if val is None: raise AssertionError( f'missing required {name} for {mcls.__name__}' ) return val def canonicalize_attributes( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().canonicalize_attributes(schema, context) # When volatility is altered, we need to force a # reconsideration of nativecode if it exists in order to check # it against the new volatility or compute the volatility on a # RESET. This is kind of unfortunate. if ( isinstance(self, sd.AlterObject) and self.has_attribute_value('volatility') and not self.has_attribute_value('nativecode') and (nativecode := self.scls.get_nativecode(schema)) is not None ): self.set_attribute_value( 'nativecode', nativecode.not_compiled() ) # Resolving 'nativecode' has side effects on has_dml and # volatility, so force it to happen as part of # canonicalization of attributes. super().get_resolved_attribute_value( 'nativecode', schema=schema, context=context) return schema def compile_this_function( self, schema: s_schema.Schema, context: sd.CommandContext, body: s_expr.Expression, track_schema_ref_exprs: bool=False, ) -> s_expr.CompiledExpression: params = self._get_params(schema, context) language = self._get_attribute_value(schema, context, 'language') return_type = self._get_attribute_value(schema, context, 'return_type') return_typemod = self._get_attribute_value( schema, context, 'return_typemod') expr = compile_function( schema, context, body=body, func_name=self.classname, params=params, language=language, return_type=return_type, return_typemod=return_typemod, track_schema_ref_exprs=track_schema_ref_exprs, ) ir = expr.irast spec_volatility: Optional[ft.Volatility] = ( self.get_specified_attribute_value('volatility', schema, context)) if spec_volatility is None: self.set_attribute_value('volatility', ir.volatility, computed=True) # If a volatility is specified, it can be more volatile than the # inferred volatility but not less. 
if spec_volatility is not None and spec_volatility < ir.volatility: # When restoring from old versions, just ignore the problem # and use the inferred volatility if context.compat_ver_is_before( (1, 0, verutils.VersionStage.ALPHA, 8) ): self.set_attribute_value('volatility', ir.volatility) else: raise errors.InvalidFunctionDefinitionError( f'volatility mismatch in function declared as ' f'{str(spec_volatility).lower()}', details=f'Actual volatility is ' f'{str(ir.volatility).lower()}', span=body.parse().span, ) globs = { schema.get(glob.global_name, type=s_globals.Global) for glob in ir.globals if not glob.is_permission } self.set_attribute_value('used_globals', globs) permissions = { schema.get(glob.global_name, type=s_permissions.Permission) for glob in ir.globals if glob.is_permission } self.set_attribute_value('used_permissions', permissions) return expr @classmethod def localnames_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> set[str]: localnames = super().localnames_from_ast( schema, astnode, context ) if isinstance(astnode, (qlast.CreateFunction, qlast.AlterFunction)): localnames |= {param.name for param in astnode.params} return localnames class CreateFunction(CreateCallableObject[Function], FunctionCommand): astnode = qlast.CreateFunction def _create_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: from edb.ir import utils as irutils fullname = self.classname shortname = sn.shortname_from_fullname(fullname) schema, cp = self._get_param_desc_from_delta(schema, context, self) signature = f'{shortname}({", ".join(p.as_str(schema) for p in cp)})' if func := schema.get(fullname, None): raise errors.DuplicateFunctionDefinitionError( f'cannot create the `{signature}` function: ' f'a function with the same signature ' f'is already defined', span=self.span) if not context.canonical: fullname = self.classname shortname = sn.shortname_from_fullname(fullname) if backend_name 
:= self.get_prespecified_id( context, id_field='backend_name'): pass elif others := lookup_functions( sn.QualName(fullname.module, shortname.name), (), schema=schema ): backend_name = others[0].get_backend_name(schema) elif context.stdmode: backend_name = uuidgen.uuid5(FUNC_NAMESPACE, str(fullname)) else: backend_name = uuidgen.uuid1mc() if not self.has_attribute_value('backend_name'): self.set_attribute_value('backend_name', backend_name) if ( self.has_attribute_value("code") or self.has_attribute_value("nativecode") ) and not self.has_attribute_value('impl_is_strict'): self.set_attribute_value( 'impl_is_strict', _params_are_all_required_singletons(cp, schema), ) # Check if other schema objects with the same name (ignoring # signature, of course) exist. if other := schema.get( sn.QualName(fullname.module, shortname.name), None): raise errors.SchemaError( f'{other.get_verbosename(schema)} already exists') schema = super()._create_begin(schema, context) params: FuncParameterList = self.scls.get_params(schema) language = self.scls.get_language(schema) return_type = self.scls.get_return_type(schema) return_typemod = self.scls.get_return_typemod(schema) from_function = self.scls.get_from_function(schema) has_polymorphic = params.has_polymorphic(schema) has_set_of = params.has_set_of(schema) has_objects = params.has_objects(schema) polymorphic_return_type = return_type.is_polymorphic(schema) named_only = params.find_named_only(schema) fallback = self.scls.get_fallback(schema) preserves_opt = self.scls.get_preserves_optionality(schema) preserves_upper_card = self.scls.get_preserves_upper_cardinality( schema) if preserves_opt and not has_set_of: raise errors.InvalidFunctionDefinitionError( f'cannot create `{signature}` function: ' f'"preserves_optionality" makes no sense ' f'in a non-aggregate function', span=self.span) if preserves_upper_card and not has_set_of: raise errors.InvalidFunctionDefinitionError( f'cannot create `{signature}` function: ' 
f'"preserves_upper_cardinality" makes no sense ' f'in a non-aggregate function', span=self.span) if preserves_upper_card and ( return_typemod is not ft.TypeModifier.SetOfType ): raise errors.InvalidFunctionDefinitionError( f'cannot create `{signature}` function: ' f'"preserves_upper_cardinality" makes no sense ' f'in a function not returning SET OF', span=self.span) # Certain syntax is only allowed in "EdgeDB developer" mode, # i.e. when populating std library, etc. if not context.stdmode and not context.testmode: if has_polymorphic or polymorphic_return_type: raise errors.InvalidFunctionDefinitionError( f'cannot create `{signature}` function: ' f'generic types are not supported in ' f'user-defined functions', span=self.span) elif from_function: raise errors.InvalidFunctionDefinitionError( f'cannot create `{signature}` function: ' f'"USING SQL FUNCTION" is not supported in ' f'user-defined functions', span=self.span) elif language != qlast.Language.EdgeQL: raise errors.InvalidFunctionDefinitionError( f'cannot create `{signature}` function: ' f'"USING {language}" is not supported in ' f'user-defined functions', span=self.span) if polymorphic_return_type and not has_polymorphic: raise errors.InvalidFunctionDefinitionError( f'cannot create `{signature}` function: ' f'function returns a generic type but has no ' f'generic parameters', span=self.span) overloaded_funcs = lookup_functions(shortname, (), schema=schema) has_from_function = from_function for func in overloaded_funcs: func_params = func.get_params(schema) func_named_only = func_params.find_named_only(schema) func_from_function = func.get_from_function(schema) func_preserves_opt = func.get_preserves_optionality(schema) func_preserves_upper_card = func.get_preserves_upper_cardinality( schema) if func_named_only.keys() != named_only.keys(): raise errors.InvalidFunctionDefinitionError( f'cannot create `{signature}` function: ' f'overloading another function with different ' f'named only parameters: ' 
f'"{func.get_signature_as_str(schema)}"', span=self.span) if ((has_polymorphic or func_params.has_polymorphic(schema)) and ( func.get_return_typemod(schema) != return_typemod)): func_return_typemod = func.get_return_typemod(schema) raise errors.InvalidFunctionDefinitionError( f'cannot create the polymorphic `{signature} -> ' f'{return_typemod.to_edgeql()} ' f'{return_type.get_displayname(schema)}` ' f'function: overloading another function with different ' f'return type {func_return_typemod.to_edgeql()} ' f'{func.get_return_type(schema).get_displayname(schema)}', span=self.span) if fallback and func.get_fallback(schema) and self.scls != func: raise errors.InvalidFunctionDefinitionError( f'cannot create the polymorphic `{signature} -> ' f'{return_typemod.to_edgeql()} ' f'{return_type.get_displayname(schema)}` ' f'function: only one generic fallback per polymorphic ' f'function is allowed', span=self.span) if func_from_function: has_from_function = func_from_function if func_preserves_opt != preserves_opt: raise errors.InvalidFunctionDefinitionError( f'cannot create `{signature}` function: ' f'overloading another function with different ' f'"preserves_optionality" attribute: ' f'`{func.get_signature_as_str(schema)}`', span=self.span) if func_preserves_upper_card != preserves_upper_card: raise errors.InvalidFunctionDefinitionError( f'cannot create `{signature}` function: ' f'overloading another function with different ' f'"preserves_upper_cardinality" attribute: ' f'`{func.get_signature_as_str(schema)}`', span=self.span) if has_objects: self.scls.find_object_param_overloads( schema, span=self.span) if has_from_function: # Ignore the generic fallback when considering # from_function for polymorphic functions. 
if (not fallback and from_function != has_from_function or any(not f.get_fallback(schema) and f.get_from_function(schema) != has_from_function for f in overloaded_funcs)): raise errors.InvalidFunctionDefinitionError( f'cannot create the `{signature}` function: ' f'overloading "USING SQL FUNCTION" functions is ' f'allowed only when all functions point to the same ' f'SQL function', span=self.span) if (language == qlast.Language.EdgeQL and any(p.get_typemod(schema) is ft.TypeModifier.SetOfType for p in params.objects(schema))): raise errors.UnsupportedFeatureError( f'cannot create the `{signature}` function: ' f'SET OF parameters in user-defined EdgeQL functions are ' f'not supported', span=self.span) # check that params of type 'anytype' don't have defaults for p in params.objects(schema): p_default = p.get_default(schema) if p_default is None: continue p_type = p.get_type(schema) try: ir_default = p.get_ir_default(schema=schema, context=context) except Exception as ex: raise errors.InvalidFunctionDefinitionError( f'cannot create the `{signature}` function: ' f'invalid default value {p_default.text!r} of parameter ' f'{p.get_displayname(schema)!r}: {ex}', span=self.span) check_default_type = True if p_type.is_polymorphic(schema): if irutils.is_empty(ir_default.expr): check_default_type = False else: raise errors.InvalidFunctionDefinitionError( f'cannot create the `{signature}` function: ' f'polymorphic parameter of type ' f'{p_type.get_displayname(schema)} cannot ' f'have a non-empty default value', span=self.span) elif (p.get_typemod(schema) is ft.TypeModifier.OptionalType and irutils.is_empty(ir_default.expr)): check_default_type = False if check_default_type: default_type = ir_default.stype if not default_type.assignment_castable_to( p_type, ir_default.schema ): raise errors.InvalidFunctionDefinitionError( f'cannot create the `{signature}` function: ' f'invalid declaration of parameter ' f'{p.get_displayname(schema)!r}: ' f'unexpected type of the default 
expression: ' f'{default_type.get_displayname(ir_default.schema)}, ' f'expected ' f'{p_type.get_displayname(schema)}', span=self.span) # Make sure variadic parameters do not contain optional types in # user-defined functions if language == qlast.Language.EdgeQL: if variadic := params.find_variadic(schema): typemod = variadic.get_typemod(schema) if typemod is ft.TypeModifier.OptionalType: raise errors.InvalidFunctionDefinitionError( f'cannot create the `{signature}` function: ' f'variadic argument ' f'`{variadic.get_displayname(schema)}` ' f'illegally declared with optional type in ' f'user-defined function', span=self.span) return schema @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(cmd, CreateFunction) reflected_language = 'builtin' assert isinstance(astnode, qlast.CreateFunction) if astnode.code is not None: cmd.set_attribute_value( 'language', astnode.code.language, ) if astnode.code.language is qlast.Language.EdgeQL: reflected_language = 'EdgeQL' nativecode_expr: qlast.Base if astnode.nativecode is not None: nativecode_expr = astnode.nativecode else: assert astnode.code.code is not None nativecode_expr = qlparser.parse_query(astnode.code.code) nativecode = s_expr.Expression.from_ast( nativecode_expr, schema, context.modaliases, context.localnames, ) cmd.set_attribute_value( 'nativecode', nativecode, ) elif astnode.code.from_function is not None: cmd.set_attribute_value( 'from_function', astnode.code.from_function ) elif ( astnode.code.from_expr is not None and astnode.code.code is None ): cmd.set_attribute_value( 'from_expr', astnode.code.from_expr, ) else: cmd.set_attribute_value( 'code', astnode.code.code, ) cmd.set_attribute_value('reflected_language', reflected_language) return cmd def _apply_field_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, 
op: sd.AlterObjectProperty, ) -> None: assert isinstance(node, qlast.CreateFunction) new_value: Any = op.new_value if op.property == 'return_type': node.returning = utils.typeref_to_ast(schema, new_value) elif op.property == 'return_typemod': node.returning_typemod = new_value elif op.property == 'code': if node.code is None: node.code = qlast.FunctionCode() node.code.code = new_value elif op.property == 'language': if node.code is None: node.code = qlast.FunctionCode() node.code.language = new_value elif op.property == 'from_function' and new_value: if node.code is None: node.code = qlast.FunctionCode() node.code.from_function = new_value elif op.property == 'from_expr' and new_value: if node.code is None: node.code = qlast.FunctionCode() node.code.from_expr = new_value else: super()._apply_field_ast(schema, context, node, op) class RenameFunction(RenameCallableObject[Function], FunctionCommand): @classmethod def _classname_from_ast( cls, schema: s_schema.Schema, astnode: qlast.ObjectDDL, context: sd.CommandContext, ) -> sn.QualName: ctx = context.current() assert isinstance(ctx.op, AlterFunction) name = sd.QualifiedObjectCommand._classname_from_ast( schema, astnode, context) quals = list(sn.quals_from_fullname(ctx.op.classname)) out = sn.QualName( name=sn.get_specialized_name(name, *quals), module=name.module ) return out def validate_alter( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: cur_shortname = sn.shortname_from_fullname(self.classname) cur_name = sn.QualName(self.classname.module, cur_shortname.name) new_shortname = sn.shortname_from_fullname(self.new_name) assert isinstance(self.new_name, sn.QualName) new_name = sn.QualName(self.new_name.module, new_shortname.name) if cur_name == new_name: return existing = lookup_functions(cur_name, schema=schema) if len(existing) > 1: raise errors.SchemaError( 'renaming an overloaded function is not allowed', span=self.span) target = lookup_functions(new_name, (), schema=schema) if target: 
raise errors.SchemaError( f"can not rename function to '{new_name!s}' because " f"a function with the same name already exists, and " f"renaming into an overload is not supported", span=self.span) class AlterFunction(AlterCallableObject[Function], FunctionCommand): astnode = qlast.AlterFunction def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._alter_begin(schema, context) scls = self.scls if context.canonical: return schema if self.has_attribute_value("fallback"): overloaded_funcs = schema._get_by_shortname( Function, self.scls.get_shortname(schema) ) or () if len([func for func in overloaded_funcs if func.get_fallback(schema)]) > 1: raise errors.InvalidFunctionDefinitionError( f'cannot alter the polymorphic ' f'{self.scls.get_verbosename(schema)}: ' f'only one generic fallback per polymorphic ' f'function is allowed', span=self.span) # If volatility or nativecode changed, propagate that to # referring exprs if not ( self.has_attribute_value("volatility") or self.has_attribute_value("nativecode") ): return schema # We also need to propagate changes to "parent" # overloads. This is mainly so they can get the proper global # variables updated. extra_refs: Optional[dict[so.Object, list[str]]] = None if (overloaded := scls.find_object_param_overloads(schema)): ov_funcs, ov_idx = overloaded cur_type = ( scls.get_params(schema).objects(schema)[ov_idx]. get_type(schema) ) extra_refs = { f: ['nativecode'] for f in ov_funcs if (f_type := f.get_params(schema).objects(schema)[ov_idx]. 
get_type(schema)) and f_type != cur_type and cur_type.issubclass(schema, f_type) } vn = scls.get_verbosename(schema, with_parent=True) schema = self._propagate_if_expr_refs( schema, context, extra_refs=extra_refs, action=f'alter the definition of {vn}') return schema @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(astnode, qlast.AlterFunction) if astnode.code is not None: if ( astnode.code.from_function is not None or astnode.code.from_expr ): raise errors.EdgeQLSyntaxError( 'altering function code is only supported for ' 'pure EdgeQL functions', span=astnode.span ) nativecode_expr: Optional[qlast.Expr] = None if astnode.nativecode is not None: nativecode_expr = astnode.nativecode elif ( astnode.code.language is qlast.Language.EdgeQL and astnode.code.code is not None ): nativecode_expr = qlparser.parse_query(astnode.code.code) else: cmd.set_attribute_value( 'code', astnode.code.code, ) if nativecode_expr is not None: nativecode = s_expr.Expression.from_ast( nativecode_expr, schema, context.modaliases, context.localnames, ) cmd.set_attribute_value( 'nativecode', nativecode, ) return cmd def _get_attribute_value( self, schema: s_schema.Schema, context: sd.CommandContext, name: str, ) -> Any: val = self.get_resolved_attribute_value( name, schema=schema, context=context, ) if val is None: val = self.scls.get_field_value(schema, name) if val is None: mcls = self.get_schema_metaclass() raise AssertionError( f'missing required {name} for {mcls.__name__}' ) return val def _get_params( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> FuncParameterList: return self.scls.get_params(schema) def canonicalize_alter_from_external_ref( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: # Produce a param desc list which we use to find a new name. 
param_list = self.scls.get_params(schema) params = CallableCommand._get_param_desc_from_params_ast( schema, context.modaliases, param_list.get_ast(schema)) name = sn.shortname_from_fullname(self.classname) assert isinstance(name, sn.QualName), "expected qualified name" new_fname = CallableObject.get_fqname(schema, name, params) if new_fname == self.classname: return # Do the rename rename = self.scls.init_delta_command( schema, sd.RenameObject, new_name=new_fname) rename.set_attribute_value( 'name', value=new_fname, orig_value=self.classname) self.add(rename) class DeleteFunction(DeleteCallableObject[Function], FunctionCommand): astnode = qlast.DropFunction def _apply_fields_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, ) -> None: super()._apply_fields_ast(schema, context, node) params = [] for op in self.get_subcommands(type=ParameterCommand): props = op.get_orig_attributes(schema, context) num: int = props['num'] param = make_func_param( name=Parameter.paramname_from_fullname(props['name']), type=utils.typeref_to_ast(schema, props['type']), typemod=props['typemod'], kind=props['kind'], ) params.append((num, param)) params.sort(key=lambda e: e[0]) assert isinstance(node, qlast.CallableObjectCommand) node.params = [p[1] for p in params] def get_params_symtable( params: FuncParameterList, schema: s_schema.Schema, *, inlined_defaults: bool, ) -> dict[str, qlast.Expr]: anchors: dict[str, qlast.Expr] = {} defaults_mask = qlast.TypeCast( expr=qlast.FunctionParameter(name='__defaults_mask__'), type=qlast.TypeName( maintype=qlast.ObjectRef( module='std', name='bytes', ), ), ) for pi, p in enumerate(params.get_in_canonical_order(schema)): p_shortname = p.get_parameter_name(schema) p_is_optional = ( p.get_typemod(schema) is not ft.TypeModifier.SingletonType ) anchors[p_shortname] = qlast.TypeCast( expr=qlast.FunctionParameter(name=p_shortname), cardinality_mod=( qlast.CardinalityModifier.Optional if p_is_optional else None ), 
type=utils.typeref_to_ast(schema, p.get_type(schema)), ) p_default = p.get_default(schema) if p_default is None: continue if not inlined_defaults: continue anchors[p_shortname] = qlast.IfElse( condition=qlast.BinOp( left=qlast.FunctionCall( func=('std', 'bytes_get_bit'), args=[ defaults_mask, qlast.Constant.integer(pi), ]), op='=', right=qlast.Constant.integer(0), ), if_expr=anchors[p_shortname], else_expr=qlast.OptionalExpr(expr=p_default.parse()), ) return anchors def compile_function( schema: s_schema.Schema, context: sd.CommandContext, *, body: s_expr.Expression, func_name: sn.QualName, params: FuncParameterList, language: qlast.Language, return_type: s_types.Type, return_typemod: ft.TypeModifier, track_schema_ref_exprs: bool=False, ) -> s_expr.CompiledExpression: assert language is qlast.Language.EdgeQL compiled = body.compiled( schema, options=get_compiler_options( schema, context, func_name=func_name, params=params, track_schema_ref_exprs=track_schema_ref_exprs, ), context=context, ) ir = compiled.irast schema = ir.schema if (not ir.stype.issubclass(schema, return_type) and not ir.stype.implicitly_castable_to(return_type, schema)): raise errors.InvalidFunctionDefinitionError( f'return type mismatch in function declared to return ' f'{return_type.get_verbosename(schema)}', details=f'Actual return type is ' f'{ir.stype.get_verbosename(schema)}', span=body.parse().span, ) if (return_typemod is not ft.TypeModifier.SetOfType and ir.cardinality.is_multi()): raise errors.InvalidFunctionDefinitionError( f'return cardinality mismatch in function declared to return ' f'a singleton', details=( f'Function may return a set with more than one element.' ), span=body.parse().span, ) elif (return_typemod is ft.TypeModifier.SingletonType and ir.cardinality.can_be_zero()): raise errors.InvalidFunctionDefinitionError( f'return cardinality mismatch in function declared to return ' f'exactly one value', details=( f'Function may return an empty set.' 
), span=body.parse().span, ) return compiled def compile_function_inline( schema: s_schema.Schema, context: sd.CommandContext, *, body: s_expr.Expression, func_name: sn.QualName, params: FuncParameterList, language: qlast.Language, return_type: s_types.Type, return_typemod: ft.TypeModifier, track_schema_ref_exprs: bool=False, inlining_context: qlcontext.ContextLevel, ) -> irast.Set: """Compile a function body to be inlined.""" assert language is qlast.Language.EdgeQL from edb.edgeql.compiler import dispatch from edb.edgeql.compiler import pathctx from edb.edgeql.compiler import setgen from edb.edgeql.compiler import stmtctx ctx = stmtctx.init_context( schema=schema, options=get_compiler_options( schema, context, func_name=func_name, params=params, track_schema_ref_exprs=track_schema_ref_exprs, inlining_context=inlining_context, ), inlining_context=inlining_context, ) ql_expr = body.parse() # Wrap argument paths param_names: set[str] = { param.get_parameter_name(inlining_context.env.schema) for param in params.objects(inlining_context.env.schema) } argument_path_wrapper = ArgumentPathWrapper(param_names) ql_expr = argument_path_wrapper.visit(ql_expr) # Add implicit limit if present if ctx.implicit_limit: ql_expr = qlast.SelectQuery(result=ql_expr, implicit=True) ql_expr.limit = qlast.Constant.integer(ctx.implicit_limit) ir_set: irast.Set = dispatch.compile(ql_expr, ctx=ctx) # Copy schema back to inlining context if inlining_context: inlining_context.env.schema = ctx.env.schema # Create scoped set if necessary if pathctx.get_set_scope(ir_set, ctx=ctx) is None: ir_set = setgen.scoped_set(ir_set, ctx=ctx) return ir_set class ArgumentPathWrapper(ast.NodeTransformer): # Wrap paths based on the inlined arguments which are arguments to other # inlined functions. 
    #
    # Given the functions:
    #   function inner(x: int64) -> int64 using (x);
    #   function outer(x: int64) -> int64 using (inner(x));
    #
    # Before inlining the outer function, the irast may look like this:
    #   FunctionCall outer
    #     CallArg: Set expr~1: Parameter x
    #     Body
    #       Set: FunctionCall inner
    #         CallArg: Set expr~2: Parameter x
    #         Body
    #           SelectStmt: Set expr~2: InlinedParameterExpr
    #
    # The outer function will then inline `Parameter x`:
    #   FunctionCall outer
    #     CallArg: Set expr~1: Parameter x
    #     Body
    #       Set: FunctionCall inner
    #         CallArg: Set expr~1: InlinedParameterExpr
    #         Body
    #           SelectStmt: Set expr~2: InlinedParameterExpr
    #
    # And the definition of `Set expr~2` will be removed.
    #
    # To ensure the outer function inlines `Parameter x` while keeping the
    # path id of the inner function, wrap the parameter with a Select:
    #   FunctionCall outer
    #     CallArg: Set expr~1: Parameter x
    #     Body
    #       Set: FunctionCall inner
    #         CallArg: Set expr~2: SelectStmt: Set expr~1: Parameter x
    #         Body
    #           SelectStmt: Set expr~2: InlinedParameterExpr

    def __init__(
        self,
        param_names: set[str],
    ) -> None:
        super().__init__()
        self.param_names = param_names

    def visit_FunctionCall(self, node: qlast.FunctionCall) -> qlast.Base:
        has_direct_args = False
        new_args: list[qlast.Expr] = []
        new_kwargs: dict[str, qlast.Expr] = {}

        for arg in node.args:
            if (
                isinstance(arg, qlast.Path)
                and isinstance(arg.steps[0], qlast.ObjectRef)
                and arg.steps[0].name in self.param_names
            ):
                has_direct_args = True
                new_args.append(qlast.SelectQuery(result=arg))
            else:
                new_args.append(arg)

        for arg_name, arg in node.kwargs.items():
            if (
                isinstance(arg, qlast.Path)
                and isinstance(arg.steps[0], qlast.ObjectRef)
                and arg.steps[0].name in self.param_names
            ):
                has_direct_args = True
                new_kwargs[arg_name] = qlast.SelectQuery(result=arg)
            else:
                new_kwargs[arg_name] = arg

        if has_direct_args:
            node = node.replace(args=new_args, kwargs=new_kwargs)

        return cast(qlast.Base, self.generic_visit(node))


def get_compiler_options(
    schema: s_schema.Schema,
    context:
sd.CommandContext, *, func_name: sn.QualName, params: FuncParameterList, track_schema_ref_exprs: bool, inlining_context: Optional[qlcontext.ContextLevel] = None, ) -> qlcompiler.CompilerOptions: has_inlined_defaults = ( bool(params.find_named_only(schema)) and inlining_context is None ) param_anchors = get_params_symtable( params, schema, inlined_defaults=has_inlined_defaults, ) return qlcompiler.CompilerOptions( anchors=param_anchors, func_name=( inlining_context.env.options.func_name if inlining_context is not None else func_name ), func_params=( inlining_context.env.options.func_params if inlining_context is not None else params ), json_parameters=( inlining_context.env.options.json_parameters if inlining_context is not None else False ), apply_query_rewrites=not context.stdmode, track_schema_ref_exprs=track_schema_ref_exprs, ) def lookup_functions( name: str | sn.Name, default: tuple[Function, ...] | so.NoDefaultT = so.NoDefault, *, module_aliases: Optional[Mapping[Optional[str], str]] = None, schema: s_schema.Schema, ) -> tuple[Function, ...]: funcs = s_schema.lookup( schema, name, getter=_get_functions, module_aliases=module_aliases, default=default, ) if funcs is not so.NoDefault: return funcs else: return s_schema.Schema.raise_bad_reference( name=name, module_aliases=module_aliases, type=Function, ) @lru.per_job_lru_cache() def _get_functions( schema: s_schema.Schema, name: sn.Name, ) -> tuple[Function, ...] | None: return schema._get_by_shortname(Function, name) ================================================ FILE: edb/schema/futures.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2021-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Callable, cast from edb import errors from edb.edgeql import ast as qlast from edb.edgeql import qltypes from . import delta as sd from . import name as sn from . import objects as so from . import schema as s_schema class FutureBehavior( so.Object, qlkind=qltypes.SchemaObjectClass.FUTURE, data_safe=False, ): name = so.SchemaField( sn.Name, inheritable=False, compcoef=0.0, # can't rename ) class FutureBehaviorCommandContext( sd.ObjectCommandContext[FutureBehavior], ): pass # Unlike extensions, futures are *explicitly* built into the # language. Enabling or disabling a future might require making # other changes (recompiling functions that depend on it, for # example), so each future is mapped to a handler function that can # generate a command. _FutureBehaviorHandler = Callable[ ['FutureBehaviorCommand', s_schema.Schema, sd.CommandContext, bool], tuple[s_schema.Schema, sd.Command], ] FUTURE_HANDLERS: dict[str, _FutureBehaviorHandler] = {} def register_handler( name: str, ) -> Callable[[_FutureBehaviorHandler], _FutureBehaviorHandler]: def func(f: _FutureBehaviorHandler) -> _FutureBehaviorHandler: FUTURE_HANDLERS[name] = f return f return func def future_enabled(schema: s_schema.Schema, feat: str) -> bool: return bool(schema.get_global(FutureBehavior, feat, default=None)) class FutureBehaviorCommand( sd.ObjectCommand[FutureBehavior], context_class=FutureBehaviorCommandContext, ): # A command that gets run after adjusting the future value.
# It needs to run *after* the delete, for a 'drop future', # and so it can't use any of the existing varieties of subcommands. # # If anything else ends up needing to do this, we can add another # variety of subcommand. future_cmd: sd.Command | None = None def copy(self: FutureBehaviorCommand) -> FutureBehaviorCommand: result = super().copy() if self.future_cmd: result.future_cmd = self.future_cmd.copy() return result @classmethod def adapt( cls: type[FutureBehaviorCommand], obj: sd.Command ) -> FutureBehaviorCommand: result = super(FutureBehaviorCommand, cls).adapt(obj) assert isinstance(obj, FutureBehaviorCommand) mcls = cast(sd.CommandMeta, type(cls)) if obj.future_cmd: result.future_cmd = mcls.adapt(obj.future_cmd) return result def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().apply(schema, context) if not context.canonical and not isinstance(self, sd.AlterObject): key = str(self.classname) if key not in FUTURE_HANDLERS: raise errors.QueryError( f"Unknown future '{str(key)}'" ) schema, cmd = FUTURE_HANDLERS[key]( self, schema, context, isinstance(self, sd.CreateObject)) self.future_cmd = cmd if self.future_cmd: schema = self.future_cmd.apply(schema, context) return schema class CreateFutureBehavior( FutureBehaviorCommand, sd.CreateObject[FutureBehavior], ): astnode = qlast.CreateFuture class DeleteFutureBehavior( FutureBehaviorCommand, sd.DeleteObject[FutureBehavior], ): astnode = qlast.DropFuture class AlterFutureBehavior( FutureBehaviorCommand, sd.AlterObject[FutureBehavior], ): pass # These are registered here because they aren't directly related to # any schema elements. # They are all dummies now, too.
@register_handler('simple_scoping') @register_handler('warn_old_scoping') @register_handler('_scoping_noop_test') def toggle_scoping_future( cmd: FutureBehaviorCommand, schema: s_schema.Schema, context: sd.CommandContext, on: bool, ) -> tuple[s_schema.Schema, sd.Command]: return schema, sd.CommandGroup() ================================================ FILE: edb/schema/globals.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any, Optional, TYPE_CHECKING from edb import errors from edb.common import struct from edb.edgeql import ast as qlast from edb.edgeql import compiler as qlcompiler from edb.edgeql import qltypes from . import annos as s_anno from . import delta as sd from . import expr as s_expr from . import expraliases as s_expraliases from . import name as sn from . import objects as so from . import types as s_types from . 
import utils if TYPE_CHECKING: from edb.schema import schema as s_schema class Global( so.QualifiedObject, s_anno.AnnotationSubject, qlkind=qltypes.SchemaObjectClass.GLOBAL, data_safe=True, ): target = so.SchemaField( s_types.Type, compcoef=0.85, special_ddl_syntax=True, ) required = so.SchemaField( bool, default=False, compcoef=0.909, special_ddl_syntax=True, describe_visibility=( so.DescribeVisibilityPolicy.SHOW_IF_EXPLICIT_OR_DERIVED ), ) cardinality = so.SchemaField( qltypes.SchemaCardinality, default=qltypes.SchemaCardinality.One, compcoef=0.833, coerce=True, special_ddl_syntax=True, describe_visibility=( so.DescribeVisibilityPolicy.SHOW_IF_EXPLICIT_OR_DERIVED ), ) # Computable globals have this set to an expression # defining them. expr = so.SchemaField( s_expr.Expression, default=None, coerce=True, compcoef=0.909, special_ddl_syntax=True, ) default = so.SchemaField( s_expr.Expression, allow_ddl_set=True, default=None, coerce=True, compcoef=0.909, ) created_types = so.SchemaField( so.ObjectSet[s_types.Type], default=so.DEFAULT_CONSTRUCTOR, ) def is_computable(self, schema: s_schema.Schema) -> bool: return bool(self.get_expr(schema)) def needs_present_arg(self, schema: s_schema.Schema) -> bool: return bool(self.get_default(schema)) and not self.get_required(schema) class GlobalCommandContext( sd.ObjectCommandContext[so.Object], s_anno.AnnotationSubjectCommandContext ): pass class GlobalCommand( s_expraliases.AliasLikeCommand[Global], context_class=GlobalCommandContext, ): TYPE_FIELD_NAME = 'target' ALIAS_LIKE_EXPR_FIELDS = ('expr', 'default') @classmethod def _get_alias_name(cls, type_name: sn.QualName) -> sn.QualName: return cls._mangle_name(type_name, include_module_in_name=False) @classmethod def _is_computable(cls, obj: Global, schema: s_schema.Schema) -> bool: return obj.is_computable(schema) def _check_expr( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: expression = self.get_attribute_value('expr') assert 
isinstance(expression, s_expr.Expression) # If it's not compiled, don't worry about it. This should just # be a dummy expression. if not expression.irast: return schema required, card = expression.irast.cardinality.to_schema_value() spec_required: Optional[bool] = ( self.get_specified_attribute_value('required', schema, context)) spec_card: Optional[qltypes.SchemaCardinality] = ( self.get_specified_attribute_value('cardinality', schema, context)) glob_name = self.get_verbosename() if spec_required and not required: span = self.get_attribute_span('target') raise errors.SchemaDefinitionError( f'possibly an empty set returned by an ' f'expression for the computed ' f'{glob_name} ' f"explicitly declared as 'required'", span=span ) if ( spec_card is qltypes.SchemaCardinality.One and card is not qltypes.SchemaCardinality.One ): span = self.get_attribute_span('target') raise errors.SchemaDefinitionError( f'possibly more than one element returned by an ' f'expression for the computed ' f'{glob_name} ' f"explicitly declared as 'single'", span=span ) if spec_card is None: self.set_attribute_value('cardinality', card, computed=True) if spec_required is None: self.set_attribute_value('required', required, computed=True) return schema def canonicalize_attributes( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().canonicalize_attributes(schema, context) if self.get_attribute_value('expr'): schema = self._check_expr(schema, context) schema = s_types.materialize_type_in_attribute( schema, context, self, 'target') return schema def validate_object( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: scls = self.scls is_computable = scls.is_computable(schema) target = scls.get_target(schema) if not is_computable: if ( scls.get_required(schema) and not scls.get_default(schema) ): raise errors.SchemaDefinitionError( "required globals must have a default", span=self.span, ) if scls.get_cardinality(schema) == 
qltypes.SchemaCardinality.Many: raise errors.SchemaDefinitionError( "non-computed globals may not be multi", span=self.span, ) if target.contains_object(schema): raise errors.SchemaDefinitionError( "non-computed globals may not have have object type", span=self.span, ) default_expr = scls.get_default(schema) if default_expr is not None: default_expr = default_expr.ensure_compiled(schema, context=context) default_schema = default_expr.irast.schema default_type = default_expr.irast.stype span = self.get_attribute_span('default') if is_computable: raise errors.SchemaDefinitionError( f'computed globals may not have default values', span=span, ) if not default_type.assignment_castable_to(target, default_schema): raise errors.SchemaDefinitionError( f'default expression is of invalid type: ' f'{default_type.get_displayname(default_schema)}, ' f'expected {target.get_displayname(schema)}', span=span, ) ptr_cardinality = scls.get_cardinality(schema) default_required, default_cardinality = \ default_expr.irast.cardinality.to_schema_value() if (ptr_cardinality is qltypes.SchemaCardinality.One and default_cardinality != ptr_cardinality): raise errors.SchemaDefinitionError( f'possibly more than one element returned by ' f'the default expression for ' f'{scls.get_verbosename(schema)} declared as ' f"'single'", span=span, ) if scls.get_required(schema) and not default_required: raise errors.SchemaDefinitionError( f'possibly no elements returned by ' f'the default expression for ' f'{scls.get_verbosename(schema)} declared as ' f"'required'", span=span, ) if default_expr.irast.volatility.is_volatile(): raise errors.SchemaDefinitionError( f'{scls.get_verbosename(schema)} has a volatile ' f'default expression, which is not allowed', span=span, ) def compile_expr_field( self, schema: s_schema.Schema, context: sd.CommandContext, field: so.Field[Any], value: s_expr.Expression, track_schema_ref_exprs: bool=False, ) -> s_expr.CompiledExpression: if field.name in {'default', 'expr'}: 
ptr_name = self.get_verbosename() in_ddl_context_name = None if field.name == 'expr': in_ddl_context_name = f'computed {ptr_name}' return value.compiled( schema=schema, options=qlcompiler.CompilerOptions( modaliases=context.modaliases, schema_object_context=self.get_schema_metaclass(), apply_query_rewrites=not context.stdmode, track_schema_ref_exprs=track_schema_ref_exprs, in_ddl_context_name=in_ddl_context_name, ), context=context, ) else: return super().compile_expr_field( schema, context, field, value, track_schema_ref_exprs) class CreateGlobal( s_expraliases.CreateAliasLike[Global], GlobalCommand, ): astnode = qlast.CreateGlobal def get_ast_attr_for_field( self, field: str, astnode: type[qlast.DDLOperation], ) -> Optional[str]: if ( field == 'required' and issubclass(astnode, qlast.CreateGlobal) ): return 'is_required' elif ( field == 'cardinality' and issubclass(astnode, qlast.CreateGlobal) ): return 'cardinality' else: return super().get_ast_attr_for_field(field, astnode) def _apply_field_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, op: sd.AlterObjectProperty, ) -> None: assert isinstance(node, qlast.CreateGlobal) if op.property == 'target': if not node.target: expr: Optional[s_expr.Expression] = ( self.get_attribute_value('expr') ) if expr is not None: node.target = expr.parse() else: t = op.new_value assert isinstance(t, (so.Object, so.ObjectShell)) node.target = utils.typeref_to_ast(schema, t) else: super()._apply_field_ast(schema, context, node, op) @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(astnode, qlast.CreateGlobal) assert isinstance(cmd, GlobalCommand) if astnode.is_required is not None: cmd.set_attribute_value( 'required', astnode.is_required, span=astnode.span, ) if astnode.cardinality is not None: cmd.set_attribute_value( 
'cardinality', astnode.cardinality, span=astnode.span, ) assert astnode.target is not None if isinstance(astnode.target, qlast.TypeExpr): type_ref = utils.ast_to_type_shell( astnode.target, metaclass=s_types.Type, modaliases=context.modaliases, schema=schema, ) cmd.set_attribute_value( 'target', type_ref, span=astnode.target.span, ) else: # computable qlcompiler.normalize( astnode.target, schema=schema, modaliases=context.modaliases ) cmd.set_attribute_value( 'expr', s_expr.Expression.from_ast( astnode.target, schema, context.modaliases, context.localnames, ), ) if ( cmd.has_attribute_value('expr') and cmd.has_attribute_value('target') ): raise errors.UnsupportedFeatureError( "cannot specify a type and an expression for a global", span=astnode.span, ) return cmd class RenameGlobal( s_expraliases.RenameAliasLike[Global], GlobalCommand, ): pass class AlterGlobal( s_expraliases.AlterAliasLike[Global], GlobalCommand, ): astnode = qlast.AlterGlobal def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: if not context.canonical: old_expr = self.scls.get_expr(schema) has_expr = self.has_attribute_value('expr') clears_expr = has_expr and not self.get_attribute_value('expr') # Force reconsideration of the expression if cardinality # or required is changed. if ( ( self.has_attribute_value('cardinality') or self.has_attribute_value('required') ) and not has_expr and old_expr ): self.set_attribute_value( 'expr', s_expr.Expression.not_compiled(old_expr), ) # Produce an error when setting a type on something with # an expression if ( self.get_attribute_value('target') and ( (self.scls.get_expr(schema) or has_expr) and not clears_expr ) ): raise errors.UnsupportedFeatureError( "cannot specify a type and an expression for a global", span=self.span, ) if clears_expr and old_expr: # If the expression was explicitly set to None, # that means that `RESET EXPRESSION` was executed # and this is no longer a computable. 
computed_fields = self.scls.get_computed_fields(schema) if ( 'required' in computed_fields and not self.has_attribute_value('required') ): self.set_attribute_value('required', None) if ( 'cardinality' in computed_fields and not self.has_attribute_value('cardinality') ): self.set_attribute_value('cardinality', None) if not old_expr and (old_target := self.scls.get_target(schema)): if op := old_target.as_type_delete_if_unused(schema): self.add_caused(op) return super()._alter_begin(schema, context) class SetGlobalType( sd.AlterSpecialObjectField[Global], field='target', ): cast_expr = struct.Field(s_expr.Expression, default=None) reset_value = struct.Field(bool, default=False) def get_verb(self) -> str: return 'alter the type of' def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: orig_schema = schema schema = super()._alter_begin(schema, context) scls = self.scls orig_target = scls.get_explicit_field_value( orig_schema, 'target', None) new_target = scls.get_target(schema) if not orig_target or orig_target == new_target: return schema if not context.canonical: if self.cast_expr: raise errors.UnsupportedFeatureError( f'USING casts for SET TYPE on globals are not supported', hint='Use RESET TO DEFAULT instead', span=self.span, ) if not self.reset_value: raise errors.SchemaDefinitionError( f"SET TYPE on global must explicitly reset the " f"global's value", hint='Use RESET TO DEFAULT after the type', span=self.span, ) return schema @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(cmd, SetGlobalType) if ( isinstance(astnode, qlast.SetGlobalType) and astnode.cast_expr is not None ): cmd.cast_expr = s_expr.Expression.from_ast( astnode.cast_expr, schema, context.modaliases, context.localnames, ) if isinstance(astnode, qlast.SetGlobalType): cmd.reset_value = 
astnode.reset_value return cmd def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: set_field = super()._get_ast(schema, context, parent_node=parent_node) if set_field is None or self.is_attribute_computed('target'): return None else: assert isinstance(set_field, qlast.SetField) assert not isinstance(set_field.value, qlast.Expr) case_expr = None if self.cast_expr: assert isinstance(self.cast_expr, s_expr.Expression) case_expr = self.cast_expr.parse() return qlast.SetGlobalType( value=set_field.value, cast_expr=case_expr, reset_value=self.reset_value, ) def record_diff_annotations( self, *, schema: s_schema.Schema, orig_schema: Optional[s_schema.Schema], context: so.ComparisonContext, object: Optional[so.Object], orig_object: Optional[so.Object], ) -> None: super().record_diff_annotations( schema=schema, orig_schema=orig_schema, context=context, orig_object=orig_object, object=object, ) if orig_schema is None: return if ( not self.get_orig_attribute_value('expr') and not self.get_attribute_value('expr') ): self.reset_value = True class DeleteGlobal( s_expraliases.DeleteAliasLike[Global], GlobalCommand, ): astnode = qlast.DropGlobal def _delete_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._delete_begin(schema, context) scls = self.scls if not self._is_computable(scls, schema): target = scls.get_target(schema) if op := target.as_type_delete_if_unused(schema): self.add_caused(op) return schema ================================================ FILE: edb/schema/indexes.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Optional, TypeVar, Mapping, Sequence, cast, overload, TYPE_CHECKING, ) from edb import edgeql from edb import errors from edb.common import parsing from edb.common import verutils from edb.edgeql import ast as qlast from edb.edgeql import compiler as qlcompiler from edb.edgeql import qltypes from . import annos as s_anno from . import delta as sd from . import expr as s_expr from . import functions as s_func from . import inheriting from . import name as sn from . import pointers as s_pointers from . import objects as so from . import referencing from . import scalars as s_scalars from . import types as s_types from . import schema as s_schema from . import utils if TYPE_CHECKING: from . import objtypes as s_objtypes # The name used for default concrete indexes DEFAULT_INDEX = sn.QualName(module='__', name='idx') def is_index_valid_for_type( index: Index, expr_type: s_types.Type, schema: s_schema.Schema, context: sd.CommandContext, ) -> bool: index_allows_tuples = is_index_supporting_tuples(index, schema) for index_match in schema.get_referrers( index, scls_type=IndexMatch, field_name='index', ): valid_type = index_match.get_valid_type(schema) if index_allows_tuples: if is_subclass_or_tuple(expr_type, valid_type, schema): return True elif expr_type.issubclass(schema, valid_type): return True if context.testmode and str(index.get_name(schema)) == 'default::test': # For functional tests of abstract indexes. 
return expr_type.issubclass( schema, schema.get('std::str', type=s_scalars.ScalarType), ) return False def is_index_supporting_tuples( index: Index, schema: s_schema.Schema, ) -> bool: index_name = str(index.get_name(schema)) return index_name in { "std::fts::index", "ext::pg_trgm::gin", "ext::pg_trgm::gist", "std::pg::gist", "std::pg::gin", "std::pg::brin", } def is_subclass_or_tuple( ty: s_types.Type, parent: s_types.Type, schema: s_schema.Schema ) -> bool: if isinstance(ty, s_types.Tuple): for (_, st) in ty.iter_subtypes(schema): if not st.issubclass(schema, parent): return False return True else: return ty.issubclass(schema, parent) def _merge_deferrability( a: qltypes.IndexDeferrability, b: qltypes.IndexDeferrability, ) -> qltypes.IndexDeferrability: if a is b: return a else: if a is qltypes.IndexDeferrability.Prohibited: raise ValueError(f"{a} and {b} are incompatible") elif a is qltypes.IndexDeferrability.Permitted: return b else: return a def merge_deferrability( idx: Index, bases: list[Index], field_name: str, *, ignore_local: bool = False, schema: s_schema.Schema, ) -> Optional[qltypes.IndexDeferrability]: """Merge function for abstract index deferrability.""" return utils.merge_reduce( idx, bases, field_name=field_name, ignore_local=ignore_local, schema=schema, f=_merge_deferrability, type=qltypes.IndexDeferrability, ) def merge_deferred( idx: Index, bases: list[Index], field_name: str, *, ignore_local: bool = False, schema: s_schema.Schema, ) -> Optional[bool]: """Merge function for the DEFERRED qualifier on indexes.""" if idx.is_non_concrete(schema): return None if bases: deferrability = next(iter(bases)).get_deferrability(schema) else: deferrability = qltypes.IndexDeferrability.Prohibited local_deferred = idx.get_explicit_local_field_value( schema, field_name, None) idx_repr = idx.get_verbosename(schema, with_parent=True) if not idx.is_defined_here(schema): ignore_local = True if ignore_local: return deferrability is 
qltypes.IndexDeferrability.Required elif local_deferred is None: # No explicit local declaration, derive from abstract index # deferrability. if deferrability is qltypes.IndexDeferrability.Required: raise errors.SchemaDefinitionError( f"{idx_repr} must be declared as deferred" ) else: return False else: if ( local_deferred and deferrability is qltypes.IndexDeferrability.Prohibited ): raise errors.SchemaDefinitionError( f"{idx_repr} cannot be declared as deferred" ) elif ( not local_deferred and deferrability is qltypes.IndexDeferrability.Required ): raise errors.SchemaDefinitionError( f"{idx_repr} must be declared as deferred" ) return local_deferred # type: ignore def get_index_match_fullname_from_names( valid_type: sn.Name, index: sn.Name, ) -> sn.QualName: std = not ( ( isinstance(valid_type, sn.QualName) and sn.UnqualName(valid_type.module) not in s_schema.STD_MODULES ) or ( isinstance(index, sn.QualName) and sn.UnqualName(index.module) not in s_schema.STD_MODULES ) ) module = 'std' if std else '__ext_index_matches__' quals = [str(valid_type), str(index)] shortname = sn.QualName(module, 'index_match') return sn.QualName( module=shortname.module, name=sn.get_specialized_name(shortname, *quals), ) def get_index_match_fullname( schema: s_schema.Schema, valid_type: s_types.TypeShell[s_types.Type], index: so.ObjectShell[Index], ) -> sn.QualName: return get_index_match_fullname_from_names( valid_type.get_name(schema), index.get_name(schema), ) class Index( referencing.ReferencedInheritingObject, so.InheritingObject, # Help reflection figure out the right db MRO s_anno.AnnotationSubject, qlkind=qltypes.SchemaObjectClass.INDEX, data_safe=True, ): # redefine, so we can change compcoef bases = so.SchemaField( so.ObjectList['Index'], # type: ignore type_is_generic_self=True, default=so.DEFAULT_CONSTRUCTOR, coerce=True, inheritable=False, compcoef=0.0, # can't rebase ) subject = so.SchemaField( so.Object, default=None, compcoef=None, inheritable=False, ) # These can only 
appear in base abstract index definitions. These # determine how indexes can be configured. params = so.SchemaField( s_func.FuncParameterList, coerce=True, compcoef=0.4, default=so.DEFAULT_CONSTRUCTOR, inheritable=False, ) # Appears in base abstract index definitions and defines how the index # is represented in postgres. code = so.SchemaField( str, default=None, compcoef=None, inheritable=False, allow_ddl_set=True, ) # These can appear in abstract indexes extending an existing one in order # to override existing parameters. Also they can appear in concrete # indexes. kwargs = so.SchemaField( s_expr.ExpressionDict, coerce=True, compcoef=0, default=so.DEFAULT_CONSTRUCTOR, inheritable=False, ddl_identity=True, ) type_args = so.SchemaField( so.ObjectList[so.Object], coerce=True, compcoef=0, default=so.DEFAULT_CONSTRUCTOR, inheritable=False, ) expr = so.SchemaField( s_expr.Expression, default=None, coerce=True, compcoef=0.0, ddl_identity=True, ) except_expr = so.SchemaField( s_expr.Expression, default=None, coerce=True, compcoef=0.0, ddl_identity=True, ) deferrability = so.SchemaField( qltypes.IndexDeferrability, default=qltypes.IndexDeferrability.Prohibited, coerce=True, compcoef=0.909, merge_fn=merge_deferrability, allow_ddl_set=True, ) deferred = so.SchemaField( bool, default=False, compcoef=0.909, special_ddl_syntax=True, describe_visibility=( so.DescribeVisibilityPolicy.SHOW_IF_EXPLICIT_OR_DERIVED ), merge_fn=merge_deferred, ) # Whether the index is created and populated in pg. Relevant if # build_concurrently is true? active = so.SchemaField( bool, default=True, ) # XXX: I am not sure this is what I want to do.
build_concurrently = so.SchemaField( bool, default=False, compcoef=0.803, allow_ddl_set=True, ) def __repr__(self) -> str: cls = self.__class__ return '<{}.{} {!r} at 0x{:x}>'.format( cls.__module__, cls.__name__, self.id, id(self)) __str__ = __repr__ def as_delete_delta( self, *, schema: s_schema.Schema, context: so.ComparisonContext, ) -> sd.ObjectCommand[Index]: delta = super().as_delete_delta(schema=schema, context=context) old_params = self.get_params(schema).objects(schema) for p in old_params: delta.add(p.as_delete_delta(schema=schema, context=context)) return delta def get_verbosename( self, schema: s_schema.Schema, *, with_parent: bool = False ) -> str: # baseline name for indexes vn = self.get_displayname(schema) if self.get_abstract(schema): return f"abstract index '{vn}'" else: # concrete index must have a subject assert self.get_subject(schema) is not None # add kwargs (if any) to the concrete name kwargs = self.get_kwargs(schema) if kwargs: kw = [] for key, val in kwargs.items(): kw.append(f'{key}:={val.text}') vn = f'{vn}({", ".join(kw)})' vn = f"index {vn!r}" if with_parent: return self.add_parent_name(vn, schema) return vn def add_parent_name( self, base_name: str, schema: s_schema.Schema, ) -> str: # Remove the placeholder name of the generic index. 
if base_name == f"index '{DEFAULT_INDEX}'": base_name = 'index' return super().add_parent_name(base_name, schema) def is_non_concrete(self, schema: s_schema.Schema) -> bool: return self.get_subject(schema) is None @classmethod def get_shortname_static(cls, name: sn.Name) -> sn.QualName: name = sn.shortname_from_fullname(name) assert isinstance(name, sn.QualName) return name def get_all_kwargs( self, schema: s_schema.Schema, ) -> s_expr.ExpressionDict: kwargs = s_expr.ExpressionDict() all_kw = type(self).get_field('kwargs').merge_fn( self, self.get_ancestors(schema).objects(schema), 'kwargs', schema=schema, ) if all_kw: kwargs.update(all_kw) return kwargs def get_ddl_identity( self, schema: s_schema.Schema, ) -> Optional[dict[str, Any]]: v = super().get_ddl_identity(schema) or {} v['kwargs'] = self.get_all_kwargs(schema) return v def get_root( self, schema: s_schema.Schema, ) -> Index: if not self.get_abstract(schema): name = sn.shortname_from_fullname(self.get_name(schema)) index = schema.get(name, type=Index) else: index = self if index.get_bases(schema): return index.get_ancestors(schema).objects(schema)[-1] else: return index def get_concrete_kwargs( self, schema: s_schema.Schema, ) -> s_expr.ExpressionDict: assert not self.get_abstract(schema) root = self.get_root(schema) kwargs = self.get_all_kwargs(schema) for param in root.get_params(schema).objects(schema): kwname = param.get_parameter_name(schema) if ( kwname not in kwargs and (val := param.get_default(schema)) is not None ): kwargs[kwname] = val for k, v in kwargs.items(): kwargs[k] = v.ensure_compiled( schema, as_fragment=True, options=qlcompiler.CompilerOptions( schema_object_context=s_func.Parameter, ), context=None, ) return kwargs def get_concrete_kwargs_as_values( self, schema: s_schema.Schema, ) -> dict[str, Any]: kwargs = self.get_concrete_kwargs(schema) return { k: v.assert_compiled().as_python_value() for k, v in kwargs.items() } def is_defined_here( self, schema: s_schema.Schema, ) -> bool: """ 
Returns True iff the index has not been inherited from a parent subject, and was originally defined on the subject. """ return all( base.get_abstract(schema) for base in self.get_bases(schema).objects(schema) ) IndexableSubject_T = TypeVar('IndexableSubject_T', bound='IndexableSubject') class IndexableSubject(so.InheritingObject): indexes_refs = so.RefDict( attr='indexes', ref_cls=Index) indexes = so.SchemaField( so.ObjectIndexByFullname[Index], inheritable=False, ephemeral=True, coerce=True, compcoef=0.909, default=so.DEFAULT_CONSTRUCTOR) def add_index( self, schema: s_schema.Schema, index: Index, ) -> s_schema.Schema: return self.add_classref(schema, 'indexes', index) class IndexMatch( so.QualifiedObject, s_anno.AnnotationSubject, qlkind=qltypes.SchemaObjectClass.INDEX_MATCH, data_safe=True, abstract=False, ): valid_type = so.SchemaField( s_types.Type, compcoef=0.5) index = so.SchemaField( Index, compcoef=0.5) class IndexSourceCommandContext: pass class IndexSourceCommand( inheriting.InheritingObjectCommand[IndexableSubject_T], ): pass class IndexCommandContext(sd.ObjectCommandContext[Index], s_anno.AnnotationSubjectCommandContext): pass class IndexMatchCommandContext(sd.ObjectCommandContext[IndexMatch], s_anno.AnnotationSubjectCommandContext): pass class IndexCommand( referencing.ReferencedInheritingObjectCommand[Index], s_func.ParametrizedCommand[Index], context_class=IndexCommandContext, referrer_context_class=IndexSourceCommandContext, ): @classmethod def _classname_from_ast( cls, schema: s_schema.Schema, astnode: qlast.ObjectDDL, context: sd.CommandContext, ) -> sn.QualName: # We actually want to override how ReferencedObjectCommand determines # the classname # # We need to resolve the name so that we get fully # canonicalized names for things like fts::index, which are # properly std::fts::index. # (We have to do that ourselves here because we are skipping # ReferencedObjectCommand, which would otherwise handle it.) 
shortname = utils.resolve_name( utils.ast_ref_to_name(astnode.name), modaliases=context.modaliases, schema=schema, metaclass=cls.get_schema_metaclass(), ) referrer_ctx = cls.get_referrer_context(context) if referrer_ctx is not None: referrer_name = referrer_ctx.op.classname assert isinstance(referrer_name, sn.QualName) quals = cls._classname_quals_from_ast( schema, astnode, shortname, referrer_name, context) name = sn.QualName( module=referrer_name.module, name=sn.get_specialized_name( shortname, str(referrer_name), *quals, ), ) else: name = super()._classname_from_ast(schema, astnode, context) return name @classmethod def _classname_quals_from_ast( cls, schema: s_schema.Schema, astnode: qlast.ObjectDDL, base_name: sn.Name, referrer_name: sn.QualName, context: sd.CommandContext, ) -> tuple[str, ...]: assert isinstance(astnode, qlast.ConcreteIndexCommand) exprs = [] kwargs = cls._index_kwargs_from_ast(schema, astnode, context) for key, val in kwargs.items(): exprs.append(f'{key}:={val.text}') # use the normalized text directly from the expression expr = s_expr.Expression.from_ast( astnode.expr, schema, context.modaliases) expr_text = expr.text assert expr_text is not None exprs.append(expr_text) if astnode.except_expr: expr = s_expr.Expression.from_ast( astnode.except_expr, schema, context.modaliases) exprs.append('!' + expr.text) return (cls._name_qual_from_exprs(schema, exprs),) @classmethod def _classname_quals_from_name(cls, name: sn.QualName) -> tuple[str, ...]: quals = sn.quals_from_fullname(name) return tuple(quals[-1:]) @classmethod def _index_kwargs_from_ast( cls, schema: s_schema.Schema, astnode: qlast.ObjectDDL, context: sd.CommandContext, ) -> dict[str, s_expr.Expression]: kwargs = dict() # Some abstract indexes and all concrete index commands have kwargs. 
assert isinstance(astnode, (qlast.CreateIndex, qlast.ConcreteIndexCommand)) for key, val in astnode.kwargs.items(): kwargs[key] = s_expr.Expression.from_ast( val, schema, context.modaliases, as_fragment=True) return kwargs @overload def get_object( self, schema: s_schema.Schema, context: sd.CommandContext, *, name: Optional[sn.Name] = None, default: Index | so.NoDefaultT = so.NoDefault, span: Optional[parsing.Span] = None, ) -> Index: ... @overload def get_object( self, schema: s_schema.Schema, context: sd.CommandContext, *, name: Optional[sn.Name] = None, default: None = None, span: Optional[parsing.Span] = None, ) -> Optional[Index]: ... def get_object( self, schema: s_schema.Schema, context: sd.CommandContext, *, name: Optional[sn.Name] = None, default: Index | so.NoDefaultT | None = so.NoDefault, span: Optional[parsing.Span] = None, ) -> Optional[Index]: try: return super().get_object( schema, context, name=name, default=default, span=span, ) except errors.InvalidReferenceError: referrer_ctx = self.get_referrer_context_or_die(context) referrer = referrer_ctx.scls expr = self.get_ddl_identity('expr') raise errors.InvalidReferenceError( f"index on ({expr.text}) does not exist on " f"{referrer.get_verbosename(schema)}" ) from None @classmethod def _cmd_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.ObjectCommand[Index]: cmd = super()._cmd_from_ast(schema, astnode, context) if isinstance(astnode, qlast.ConcreteIndexCommand): cmd.set_ddl_identity( 'expr', s_expr.Expression.from_ast( astnode.expr, schema, context.modaliases, ), ) return cmd def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: astnode = super()._get_ast(schema, context, parent_node=parent_node) kwargs: Optional[Mapping[str, s_expr.Expression]] = ( self.get_resolved_attribute_value( 'kwargs', schema=schema, context=context, ) ) if kwargs 
and astnode: assert isinstance(astnode, (qlast.CreateIndex, qlast.ConcreteIndexCommand)) astnode.kwargs = { name: expr.parse() for name, expr in kwargs.items() } return astnode def get_ast_attr_for_field( self, field: str, astnode: type[qlast.DDLOperation], ) -> Optional[str]: if field in ('kwargs', 'expr', 'except_expr'): return field elif ( field == 'deferred' and astnode is qlast.CreateConcreteIndex ): return field else: return super().get_ast_attr_for_field(field, astnode) def get_ddl_identity_fields( self, context: sd.CommandContext, ) -> tuple[so.Field[Any], ...]: id_fields = super().get_ddl_identity_fields(context) omit_fields = set() if ( self.get_attribute_value('abstract') and not self.get_attribute_value('bases') ): # Base abstract indexes don't have kwargs at all. omit_fields.add('kwargs') if omit_fields: return tuple(f for f in id_fields if f.name not in omit_fields) else: return id_fields def get_friendly_object_name_for_description( self, *, parent_op: Optional[sd.Command] = None, schema: Optional[s_schema.Schema] = None, object: Optional[so.Object_T] = None, object_desc: Optional[str] = None, ) -> str: friendly_name: str = 'index' expr: Optional[s_expr.Expression] = None if ( self.has_ddl_identity('expr') and (expr := self.get_ddl_identity('expr')) ): expr_text = expr.text if expr_text[0] != '(' or expr_text[-1] != ')': expr_text = '(' + expr_text + ')' friendly_name = f"index on {expr_text}" if not isinstance(parent_op, sd.ObjectCommand): return f"{friendly_name}" else: return f"{friendly_name} of {parent_op.get_verbosename()}" def compile_expr_field( self, schema: s_schema.Schema, context: sd.CommandContext, field: so.Field[Any], value: s_expr.Expression, track_schema_ref_exprs: bool=False, ) -> s_expr.CompiledExpression: from edb.ir import ast as irast from edb.ir import utils as irutils if field.name in {'expr', 'except_expr'}: # type ignore below, for the class is used as mixin parent_ctx = context.get_ancestor( IndexSourceCommandContext, # 
type: ignore self ) assert parent_ctx is not None assert isinstance(parent_ctx.op, sd.ObjectCommand) subject = parent_ctx.op.get_object(schema, context) expr = value.compiled( schema=schema, options=qlcompiler.CompilerOptions( modaliases=context.modaliases, schema_object_context=self.get_schema_metaclass(), anchors={'__subject__': subject}, path_prefix_anchor='__subject__', singletons=frozenset([subject]), apply_query_rewrites=False, track_schema_ref_exprs=track_schema_ref_exprs, detached=True, ), context=context, ) # Check that the inferred cardinality is no more than 1 if expr.irast.cardinality.is_multi(): raise errors.SchemaDefinitionError( f'possibly more than one element returned by ' f'the index expression where only singletons ' f'are allowed', span=value.parse().span, ) if expr.irast.volatility != qltypes.Volatility.Immutable: raise errors.SchemaDefinitionError( f'index expressions must be immutable', span=value.parse().span, ) refs = irutils.get_longest_paths(expr.irast) has_multi = False for ref in refs: assert subject # Subject is a singleton in an index expression if it is itself # a singleton, regardless of other parts of the path. 
if irutils.ref_contains_multi(ref, subject.id): has_multi = True break if set_of_op := irutils.find_set_of_op( expr.irast, has_multi, ): label = ( 'function' if isinstance(set_of_op, irast.FunctionCall) else 'operator' ) op_name = str(set_of_op.func_shortname) raise errors.SchemaDefinitionError( f"cannot use SET OF {label} '{op_name}' " f"in an index expression", span=set_of_op.span ) # compile the expression to sql to preempt errors downstream utils.try_compile_irast_to_sql_tree(expr, self.span) return expr elif field.name == "kwargs": parent_ctx = context.get_ancestor( IndexSourceCommandContext, # type: ignore self ) if parent_ctx is not None: assert isinstance(parent_ctx.op, sd.ObjectCommand) subject = parent_ctx.op.get_object(schema, context) subject_vname = subject.get_verbosename(schema) idx_name = self.get_verbosename(parent=subject_vname) else: idx_name = self.get_verbosename() return type(value).compiled( value, schema=schema, options=qlcompiler.CompilerOptions( modaliases=context.modaliases, schema_object_context=self.get_schema_metaclass(), apply_query_rewrites=not context.stdmode, track_schema_ref_exprs=track_schema_ref_exprs, in_ddl_context_name=idx_name, detached=True, ), context=context, ) else: return super().compile_expr_field( schema, context, field, value, track_schema_ref_exprs) def get_dummy_expr_field_value( self, schema: s_schema.Schema, context: sd.CommandContext, field: so.Field[Any], value: Any, ) -> Optional[s_expr.Expression]: if field.name == 'expr': return s_expr.Expression(text='0') else: raise NotImplementedError(f'unhandled field {field.name!r}') def canonicalize_attributes( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().canonicalize_attributes(schema, context) referrer_ctx = self.get_referrer_context(context) if referrer_ctx is not None: # Concrete index deferrability = self.get_attribute_value("deferrability") if deferrability is not None: raise errors.SchemaDefinitionError( 
"deferrability can only be specified on abstract indexes", span=self.get_attribute_span("deferrability"), ) return schema def ast_ignore_field_ownership(self, field: str) -> bool: """Whether to force generating an AST even though field isn't owned""" return field == "deferred" def _append_subcmd_ast( self, schema: s_schema.Schema, node: qlast.DDLOperation, subcmd: sd.Command, context: sd.CommandContext, ) -> None: if isinstance(subcmd, s_anno.AnnotationValueCommand): pname = sn.shortname_from_fullname(subcmd.classname) assert isinstance(pname, sn.QualName) # Skip injected annotations if pname.module == "ext::ai": return super()._append_subcmd_ast(schema, node, subcmd, context) class CreateIndex( IndexCommand, referencing.CreateReferencedInheritingObject[Index], ): astnode = [qlast.CreateConcreteIndex, qlast.CreateIndex] referenced_astnode = qlast.CreateConcreteIndex @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(cmd, IndexCommand) assert isinstance(astnode, (qlast.CreateConcreteIndex, qlast.CreateIndex)) if isinstance(astnode, qlast.CreateIndex): cmd.set_attribute_value('abstract', True) params = cls._get_param_desc_from_ast( schema, context.modaliases, astnode) for param in params: # as_create_delta requires the specific type cmd.add_prerequisite(param.as_create_delta( schema, cmd.classname, context=context)) # There are several possibilities for abstract indexes: # 1) base abstract index # 2) an abstract index extending another one # 3) an abstract index listing index fallback alternatives if astnode.bases is None: if astnode.index_types is None: # This actually defines a new index (1). pass else: # This is for index fallback alternatives (3). raise NotImplementedError("Index fallback not implemented") else: # Extending existing indexes for composition (2). 
kwargs = cls._index_kwargs_from_ast(schema, astnode, context) if kwargs: cmd.set_attribute_value('kwargs', kwargs) elif isinstance(astnode, qlast.CreateConcreteIndex): orig_text = cls.get_orig_expr_text(schema, astnode, 'expr') if ( orig_text is not None and context.compat_ver_is_before( (1, 0, verutils.VersionStage.ALPHA, 6) ) ): # Versions prior to a6 used a different expression # normalization strategy, so we must renormalize the # expression. expr_ql = qlcompiler.renormalize_compat( astnode.expr, orig_text, schema=schema, localnames=context.localnames, ) else: expr_ql = astnode.expr kwargs = cls._index_kwargs_from_ast(schema, astnode, context) if kwargs: cmd.set_attribute_value('kwargs', kwargs) cmd.set_attribute_value( 'expr', s_expr.Expression.from_ast( expr_ql, schema, context.modaliases, ), ) if astnode.except_expr: cmd.set_attribute_value( 'except_expr', s_expr.Expression.from_ast( astnode.except_expr, schema, context.modaliases, ), ) if astnode.deferred: cmd.set_attribute_value( 'deferred', astnode.deferred, span=astnode.span, ) if cmd.get_attribute_span('build_concurrently'): cmd.set_attribute_value( 'active', False, span=astnode.span, ) return cmd @classmethod def as_inherited_ref_cmd( cls, *, schema: s_schema.Schema, context: sd.CommandContext, astnode: qlast.ObjectDDL, bases: list[Index], referrer: so.Object, ) -> sd.ObjectCommand[Index]: cmd = super().as_inherited_ref_cmd( schema=schema, context=context, astnode=astnode, bases=bases, referrer=referrer, ) assert isinstance(astnode, qlast.ConcreteIndexCommand), astnode if astnode.kwargs: cmd.set_attribute_value( 'kwargs', cls._index_kwargs_from_ast(schema, astnode, context), ) return cmd @classmethod def as_inherited_ref_ast( cls, schema: s_schema.Schema, context: sd.CommandContext, name: sn.Name, parent: referencing.ReferencedObject, ) -> qlast.ObjectDDL: assert isinstance(parent, Index) astnode_cls = cls.referenced_astnode expr = parent.get_expr(schema) assert expr is not None expr_ql = 
edgeql.parse_fragment(expr.text) except_expr: s_expr.Expression | None = parent.get_except_expr(schema) if except_expr: except_expr_ql = except_expr.parse() else: except_expr_ql = None qlkwargs = { key: val.parse() for key, val in parent.get_kwargs(schema).items() } return astnode_cls( name=cls.get_inherited_ref_name(schema, context, parent, name), kwargs=qlkwargs, expr=expr_ql, except_expr=except_expr_ql, deferred=parent.get_deferred(schema), ) @classmethod def get_inherited_ref_name( cls, schema: s_schema.Schema, context: sd.CommandContext, parent: so.Object, name: sn.Name, ) -> qlast.ObjectRef: bn = sn.shortname_from_fullname(name) return utils.name_to_ast_ref(bn) def _validate_kwargs( self, schema: s_schema.Schema, params: s_func.FuncParameterList, kwargs: s_expr.ExpressionDict, ancestor_name: str, ) -> None: if not kwargs: return if not params: raise errors.SchemaDefinitionError( f'the {ancestor_name} does not support any parameters', span=self.span ) # Make sure that the kwargs are valid. 
for key in kwargs: expr = kwargs[key] param = params.get_by_name(schema, key) if param is None: raise errors.SchemaDefinitionError( f'the {ancestor_name} does not have a parameter {key!r}', span=self.span ) param_type = param.get_type(schema) comp_expr = s_expr.Expression.compiled( expr, schema=schema, context=None) expr_type = comp_expr.irast.stype if ( not param_type.is_polymorphic(schema) and not expr_type.is_polymorphic(schema) and not expr_type.implicitly_castable_to( param_type, schema) ): raise errors.SchemaDefinitionError( f'the {key!r} parameter of the ' f'{self.get_verbosename()} has type of ' f'{expr_type.get_displayname(schema)} that ' f'is not implicitly castable to the ' f'corresponding parameter of the ' f'{ancestor_name} with type ' f'{param_type.get_displayname(schema)}', span=self.span, ) def validate_object( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: super().validate_object(schema, context) referrer_ctx = self.get_referrer_context(context) # Get kwargs if any, so that we can process them later. kwargs = self.get_resolved_attribute_value( 'kwargs', schema=schema, context=context, ) if referrer_ctx is None: # Make sure that all bases are ultimately inherited from the same # root base class. bases = self.get_resolved_attribute_value( 'bases', schema=schema, context=context, ) if bases: # Users can extend abstract indexes. root = None for base in bases.objects(schema): lineage = [base] + list( base.get_ancestors(schema).objects(schema)) if root is None: root = lineage[-1] elif root != lineage[-1]: raise errors.SchemaDefinitionError( f'cannot create {self.get_verbosename()} ' f'because it extends incompatible abstract indexes', span=self.span ) # We should have found a root because we have bases. assert root is not None # Make sure that the kwargs are valid. 
self._validate_kwargs( schema, root.get_params(schema), kwargs, root.get_verbosename(schema), ) else: # Creating new abstract indexes is only allowed in "EdgeDB # developer" mode, i.e. when populating std library, etc. if not context.stdmode and not context.testmode: raise errors.SchemaDefinitionError( f'cannot create {self.get_verbosename()} ' f'because user-defined abstract indexes are not ' f'supported', span=self.span ) return # The checks below apply only to concrete indexes. subject = referrer_ctx.scls assert isinstance(subject, (s_types.Type, s_pointers.Pointer)) assert isinstance(subject, IndexableSubject) if ( is_object_scope_index(schema, self.scls) and isinstance(subject, s_pointers.Pointer) ): dn = self.scls.get_displayname(schema) raise errors.SchemaDefinitionError( f"{dn} cannot be declared on links", span=self.span, ) # Ensure that the name of the index (if given) matches an existing # abstract index. name = sn.shortname_from_fullname( self.get_resolved_attribute_value( 'name', schema=schema, context=context, ) ) # HACK: the old concrete indexes all have names in form __::idx, but # this should be the actual name provided. Also the index without name # defaults to '__::idx'. if name != DEFAULT_INDEX and ( abs_index := schema.get(name, type=Index) ): # only abstract indexes should have unmangled names assert abs_index.get_abstract(schema) root = abs_index.get_root(schema) # For indexes that can only appear once per object, call # get_effective_object_index for its side-effect of # checking the error. 
if is_exclusive_object_scope_index(schema, self.scls): effective, others = get_effective_object_index( schema, subject, root.get_name(schema), span=self.span) if effective == self.scls and others: other = others[0] if ( other.get_concrete_kwargs_as_values(schema) != self.scls.get_concrete_kwargs_as_values(schema) ): subject_name = subject.get_verbosename(schema) other_subject = other.get_subject(schema) assert other_subject other_name = other_subject.get_verbosename(schema) raise errors.InvalidDefinitionError( f"{root.get_name(schema)} indexes defined for " f"{subject_name} with different " f"parameters than on base type {other_name}", span=self.span, ) # Make sure that kwargs and parameters match in name and type. # Also make sure that all parameters have values at this point # (either default or provided in kwargs). params = root.get_params(schema) inh_kwargs = self.scls.get_all_kwargs(schema) self._validate_kwargs(schema, params, kwargs, abs_index.get_verbosename(schema)) unused_names = {p.get_parameter_name(schema) for p in params.objects(schema)} if kwargs: unused_names -= set(kwargs) if inh_kwargs: unused_names -= set(inh_kwargs) if unused_names: # Check that all of these parameters have defaults. for pname in list(unused_names): param = params.get_by_name(schema, pname) if param and param.get_default(schema) is not None: unused_names.discard(pname) if unused_names: names = ', '.join(repr(n) for n in sorted(unused_names)) raise errors.SchemaDefinitionError( f'cannot create {self.get_verbosename()} ' f'because the following parameters are still undefined: ' f'{names}.', span=self.span ) # Make sure that the concrete index expression type matches the # abstract index type. 
expr = self.get_resolved_attribute_value( 'expr', schema=schema, context=context, ) options = qlcompiler.CompilerOptions( anchors={'__subject__': subject}, path_prefix_anchor='__subject__', singletons=frozenset([subject]), apply_query_rewrites=False, schema_object_context=self.get_schema_metaclass(), ) comp_expr = s_expr.Expression.compiled( expr, schema=schema, options=options, context=context ) expr_type = comp_expr.irast.stype if not is_index_valid_for_type( root, expr_type, comp_expr.schema, context, ): hint = None if str(name) == 'std::fts::index': hint = ( 'std::fts::document can be constructed with ' 'std::fts::with_options(str, ...)' ) raise errors.SchemaDefinitionError( f'index expression ({expr.text}) ' f'is not of a valid type for the ' f'{self.scls.get_verbosename(comp_expr.schema)}', span=self.span, details=hint, ) def _create_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._create_begin(schema, context) referrer_ctx = self.get_referrer_context(context) if ( referrer_ctx is not None and not context.canonical and is_ext_ai_index(schema, self.scls) ): schema = self._inject_ext_ai_model_dependency(schema, context) return schema def _create_innards( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: referrer_ctx = self.get_referrer_context(context) if ( referrer_ctx is not None and not context.canonical and is_ext_ai_index(schema, self.scls) ): self._copy_ext_ai_model_annotations(schema, context) return super()._create_innards(schema, context) def get_resolved_attributes( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> dict[str, Any]: params = self._get_params(schema, context) props = super().get_resolved_attributes(schema, context) props['params'] = params return props @classmethod def _classbases_from_ast( cls, schema: s_schema.Schema, astnode: qlast.ObjectDDL, context: sd.CommandContext, ) -> list[so.ObjectShell[Index]]: if ( isinstance(astnode, 
qlast.CreateConcreteIndex) and astnode.name and astnode.name.module != DEFAULT_INDEX.module and astnode.name.name != DEFAULT_INDEX.name ): base = utils.ast_objref_to_object_shell( astnode.name, metaclass=Index, schema=schema, modaliases=context.modaliases, ) return [base] else: return super()._classbases_from_ast(schema, astnode, context) def _inject_ext_ai_model_dependency( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: model_stype = self._get_referenced_embedding_model(schema, context) type_args = so.ObjectList.create( schema, [model_stype], ) self.set_attribute_value( "type_args", type_args.as_shell(schema), ) return self.scls.update(schema, {"type_args": type_args}) def _copy_ext_ai_model_annotations( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: # Copy ext::ai:: annotations declared on the model specified # by the `embedding_model` kwarg. This is necessary to avoid # expensive lookups later where the index is used. model_stype = self._get_referenced_embedding_model(schema, context) model_stype_vn = model_stype.get_verbosename(schema) model_annos = model_stype.get_annotations(schema) my_name = self.scls.get_name(schema) idx_defined_here = self.scls.is_defined_here(schema) for model_anno in model_annos.objects(schema): anno_name = model_anno.get_shortname(schema) if anno_name.module != "ext::ai": continue value = model_anno.get_value(schema) if value is None or value == "": raise errors.SchemaDefinitionError( f"{model_stype_vn} is missing a value for the " f"'{anno_name}' annotation" ) anno_sname = sn.get_specialized_name( anno_name, str(my_name), ) anno_fqname = sn.QualName(my_name.module, anno_sname) schema1 = model_anno.update( schema, { "name": anno_fqname, "subject": self.scls, }, ) anno_copy = schema1.get( anno_fqname, type=s_anno.AnnotationValue, ) anno_cmd: sd.ObjectCommand[s_anno.AnnotationValue] if idx_defined_here: anno_cmd = anno_copy.as_create_delta( schema1, so.ComparisonContext()) 
anno_cmd.discard_attribute("bases") anno_cmd.discard_attribute("ancestors") else: anno_cmd = sd.get_object_delta_command( objtype=s_anno.AnnotationValue, cmdtype=sd.AlterObject, schema=schema, name=anno_fqname, ) anno_cmd.set_attribute_value("owned", True) self.add(anno_cmd) model_dimensions = model_stype.must_get_json_annotation( schema, sn.QualName("ext::ai", "embedding_model_max_output_dimensions"), int, ) supports_shortening = model_stype.must_get_json_annotation( schema, sn.QualName("ext::ai", "embedding_model_supports_shortening"), bool, ) kwargs = self.scls.get_concrete_kwargs_as_values(schema) specified_dimensions = kwargs["dimensions"] MAX_DIM = 2000 # pgvector limit if specified_dimensions is None: if model_dimensions > MAX_DIM: if not supports_shortening: raise errors.SchemaDefinitionError( f"{model_stype_vn} returns embeddings with over " f"{MAX_DIM} dimensions, does not support embedding " f"shortening, and thus cannot be used with " f"this index", span=self.span, ) else: dimensions = MAX_DIM else: dimensions = model_dimensions else: if specified_dimensions > MAX_DIM: raise errors.SchemaDefinitionError( f"cannot use more than {MAX_DIM} dimensions with " f"this index", span=self.span, ) elif specified_dimensions > model_dimensions: raise errors.SchemaDefinitionError( f"{model_stype_vn} does not support more than " f"{model_dimensions} dimensions, " f"got {specified_dimensions}", span=self.span, ) elif ( specified_dimensions != model_dimensions and not supports_shortening ): raise errors.SchemaDefinitionError( f"{model_stype_vn} returns embeddings with over " f"{model_dimensions} dimensions, and does not support " f"embedding shortening, and thus {specified_dimensions} " f"cannot be used for this index", span=self.span, ) else: dimensions = specified_dimensions dims_anno_sname = sn.get_specialized_name( sn.QualName("ext::ai", "embedding_dimensions"), str(my_name), ) alter_anno = sd.get_object_delta_command( objtype=s_anno.AnnotationValue, 
cmdtype=sd.AlterObject, schema=schema, name=sn.QualName(my_name.module, dims_anno_sname), ) alter_anno.set_attribute_value("value", str(dimensions)) alter_anno.set_attribute_value("owned", True) self.add(alter_anno) def _get_referenced_embedding_model( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_objtypes.ObjectType: # Resolve the `embedding_model` kwarg to the single subtype of # ext::ai::EmbeddingModel annotated with that model name; raise # if there is no such subtype or more than one. kwargs = self.scls.get_concrete_kwargs_as_values(schema) model_name = kwargs["embedding_model"] models = get_defined_ext_ai_embedding_models(schema, model_name) if len(models) == 0: raise errors.SchemaDefinitionError( f'undefined embedding model: no subtype of ' f'ext::ai::EmbeddingModel is annotated as {model_name!r}', span=self.span, ) elif len(models) > 1: models_dn = [ model.get_displayname(schema) for model in models.values() ] raise errors.SchemaDefinitionError( f'expecting only one embedding model to be annotated ' f'with ext::ai::model_name={model_name!r}: got multiple: ' f'{", ".join(models_dn)}', span=self.span, ) return next(iter(models.values())) class RenameIndex( IndexCommand, referencing.RenameReferencedInheritingObject[Index], ): @classmethod def _cmd_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> RenameIndex: return cast( RenameIndex, super()._cmd_from_ast(schema, astnode, context), ) class AlterIndexOwned( IndexCommand, referencing.AlterOwned[Index], field='owned', ): def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._alter_begin(schema, context) referrer_ctx = self.get_referrer_context(context) if ( referrer_ctx is not None and not context.canonical and is_ext_ai_index(schema, self.scls) ): schema = self._fixup_ext_ai_model_annotations(schema, context) return schema def _fixup_ext_ai_model_annotations( self, schema: 
s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: # Fixup the special ext::ai annotations that got copied to an # ai index. They are always owned, even if the index is not, # and so we have some hackiness to keep that true when DROP OWNED # is run on the index. # TODO: Can this be rationalized more? for ref in self.scls.get_annotations(schema).objects(schema): anno_name = ref.get_shortname(schema) if anno_name.module != "ext::ai": continue alter = ref.init_delta_command(schema, sd.AlterObject) alter.set_attribute_value('owned', True) if anno_name.name == 'embedding_dimensions': alter.set_attribute_value( 'value', ref.get_value(schema), inherited=False) schema = alter.apply(schema, context) self.add(alter) return schema class AlterIndex( IndexCommand, referencing.AlterReferencedInheritingObject[Index], ): astnode = [qlast.AlterConcreteIndex, qlast.AlterIndex] referenced_astnode = qlast.AlterConcreteIndex def validate_object( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: super().validate_object(schema, context) vn = self.scls.get_verbosename(schema, with_parent=True) if ( not self.scls.get_build_concurrently(schema) and not self.scls.get_active(schema) ): raise errors.SchemaDefinitionError( f'{vn} is not active, so build_concurrently may ' f'not be cleared', span=self.span, ) def canonicalize_alter_from_external_ref( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: if ( not self.get_attribute_value('abstract') and (indexexpr := self.get_attribute_value('expr')) is not None ): assert isinstance(indexexpr, s_expr.Expression) # To compute the new name, we construct an AST of the # index, since that is the infrastructure we have for # computing the classname. 
name = sn.shortname_from_fullname(self.classname) assert isinstance(name, sn.QualName), "expected qualified name" ast = qlast.CreateConcreteIndex( name=qlast.ObjectRef(name=name.name, module=name.module), expr=indexexpr.parse(), ) quals = sn.quals_from_fullname(self.classname) new_name = self._classname_from_ast_and_referrer( schema, sn.QualName.from_string(quals[0]), ast, context) if new_name == self.classname: return rename = self.scls.init_delta_command( schema, sd.RenameObject, new_name=new_name) rename.set_attribute_value( 'name', value=new_name, orig_value=self.classname) self.add(rename) class DeleteIndex( IndexCommand, referencing.DeleteReferencedInheritingObject[Index], ): astnode = [qlast.DropConcreteIndex, qlast.DropIndex] referenced_astnode = qlast.DropConcreteIndex def _delete_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._delete_begin(schema, context) if not context.canonical: for param in self.scls.get_params(schema).objects(schema): self.add(param.init_delta_command(schema, sd.DeleteObject)) return schema @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: cmd = super()._cmd_tree_from_ast(schema, astnode, context) if isinstance(astnode, qlast.ConcreteIndexCommand): cmd.set_attribute_value( 'expr', s_expr.Expression.from_ast( astnode.expr, schema, context.modaliases), ) return cmd class RebaseIndex( IndexCommand, referencing.RebaseReferencedInheritingObject[Index], ): pass def get_effective_object_index( schema: s_schema.Schema, subject: IndexableSubject, base_idx_name: sn.QualName, span: Optional[parsing.Span] = None, ) -> tuple[Optional[Index], Sequence[Index]]: """ Returns the effective index of a subject for the given abstract base index, along with any indexes it overrides. """ indexes: so.ObjectIndexByFullname[Index] = subject.get_indexes(schema) base = schema.get(base_idx_name, type=Index, default=None) if base is None: # Abstract base 
index does not exist. return (None, ()) object_indexes = [ ind for ind in indexes.objects(schema) if ind.issubclass(schema, base) ] if len(object_indexes) == 0: return (None, ()) object_indexes_defined_here = [ ind for ind in object_indexes if ind.is_defined_here(schema) ] if len(object_indexes_defined_here) > 0: # indexes defined here have priority if len(object_indexes_defined_here) > 1: subject_name = subject.get_displayname(schema) raise errors.InvalidDefinitionError( f'multiple {base_idx_name} indexes defined for {subject_name}', span=span, ) effective = object_indexes_defined_here[0] overridden = [ i for i in object_indexes if i != effective ] else: # there are no object-scoped indexes defined on the subject # the inherited indexes take effect if len(object_indexes) > 1: subject_name = subject.get_displayname(schema) raise errors.InvalidDefinitionError( f'multiple {base_idx_name} indexes ' f'inherited for {subject_name}', span=span, ) effective = object_indexes[0] overridden = [] return (effective, overridden) class IndexMatchCommand(sd.QualifiedObjectCommand[IndexMatch], context_class=IndexMatchCommandContext): @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: if not context.stdmode and not context.testmode: raise errors.UnsupportedFeatureError( 'user-defined index matches are not supported', span=astnode.span ) return super()._cmd_tree_from_ast(schema, astnode, context) @classmethod def _classname_from_ast( cls, schema: s_schema.Schema, astnode: qlast.ObjectDDL, context: sd.CommandContext, ) -> sn.QualName: assert isinstance(astnode, qlast.IndexMatchCommand) modaliases = context.modaliases valid_type = utils.ast_to_type_shell( astnode.valid_type, metaclass=s_types.Type, modaliases=modaliases, schema=schema, ) index = utils.ast_objref_to_object_shell( astnode.name, metaclass=Index, modaliases=context.modaliases, schema=schema, ) return get_index_match_fullname(schema, 
valid_type, index) def canonicalize_attributes( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().canonicalize_attributes(schema, context) schema = s_types.materialize_type_in_attribute( schema, context, self, 'valid_type') return schema class CreateIndexMatch(IndexMatchCommand, sd.CreateObject[IndexMatch]): astnode = qlast.CreateIndexMatch def _create_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: fullname = self.classname index_match = schema.get(fullname, None) if index_match: valid_type = self.get_attribute_value('valid_type') index = self.get_attribute_value('index') raise errors.DuplicateDefinitionError( f'an index match for {valid_type.get_displayname(schema)!r} ' f'using {index.get_displayname(schema)!r} is already defined', span=self.span) return super()._create_begin(schema, context) @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(astnode, qlast.CreateIndexMatch) modaliases = context.modaliases valid_type = utils.ast_to_type_shell( astnode.valid_type, metaclass=s_types.Type, modaliases=modaliases, schema=schema, ) cmd.set_attribute_value('valid_type', valid_type) index = utils.ast_objref_to_object_shell( qlast.ObjectRef( module=astnode.name.module, name=astnode.name.name, ), metaclass=Index, modaliases=context.modaliases, schema=schema, ) cmd.set_attribute_value('index', index) return cmd def _apply_field_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, op: sd.AlterObjectProperty, ) -> None: assert isinstance(node, qlast.CreateIndexMatch) new_value: Any = op.new_value if op.property == 'valid_type': # In an index match we can only have pure types, so this is going # to be a TypeName. 
node.valid_type = cast(qlast.TypeName, utils.typeref_to_ast(schema, new_value)) else: super()._apply_field_ast(schema, context, node, op) class DeleteIndexMatch(IndexMatchCommand, sd.DeleteObject[IndexMatch]): astnode = qlast.DropIndexMatch def _delete_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._delete_begin(schema, context) if not context.canonical: valid_type = self.scls.get_valid_type(schema) if op := valid_type.as_type_delete_if_unused(schema): self.add_caused(op) return schema # XXX: the below hardcode should be replaced by an index scope # field instead. def is_object_scope_index( schema: s_schema.Schema, index: Index, ) -> bool: return ( is_fts_index(schema, index) or is_ext_ai_index(schema, index) ) def is_exclusive_object_scope_index( schema: s_schema.Schema, index: Index, ) -> bool: return is_object_scope_index(schema, index) def is_fts_index( schema: s_schema.Schema, index: Index, ) -> bool: fts_index = schema.get(sn.QualName("std::fts", "index"), type=Index) return index.issubclass(schema, fts_index) def get_ai_index_id( schema: s_schema.Schema, index: Index, ) -> str: # TODO: Use the model name? return 'base' def is_ext_ai_index( schema: s_schema.Schema, index: Index, ) -> bool: ai_index = schema.get( sn.QualName("ext::ai", "index"), type=Index, default=None, ) if ai_index is None: return False else: return index.issubclass(schema, ai_index) _embedding_model = sn.QualName("ext::ai", "EmbeddingModel") _model_name = sn.QualName("ext::ai", "model_name") def get_defined_ext_ai_embedding_models( schema: s_schema.Schema, model_name: Optional[str] = None, ) -> dict[str, s_objtypes.ObjectType]: from . 
import objtypes as s_objtypes base_embedding_model = schema.get( _embedding_model, type=s_objtypes.ObjectType, ) def _flt( schema: s_schema.Schema, anno: s_anno.AnnotationValue, ) -> bool: if anno.get_shortname(schema) != _model_name: return False subject = anno.get_subject(schema) value = anno.get_value(schema) return ( value is not None and value != "" and (model_name is None or anno.get_value(schema) == model_name) and isinstance(subject, s_objtypes.ObjectType) and subject.issubclass(schema, base_embedding_model) ) annos = schema.get_objects( type=s_anno.AnnotationValue, extra_filters=(_flt,), ) result = {} for anno in annos: subject = anno.get_subject(schema) assert isinstance(subject, s_objtypes.ObjectType) result[anno.get_value(schema)] = subject return result ================================================ FILE: edb/schema/inheriting.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Optional, AbstractSet, Iterable, Mapping, Sequence, cast, TYPE_CHECKING, ) from edb import errors from edb.common import span as edb_span from edb.common import struct from edb.edgeql import ast as qlast from edb.schema import schema as s_schema from . import delta as sd from . import expr as s_expr from . import name as sn from . 
import objects as so from . import utils if TYPE_CHECKING: from edb.schema import referencing as s_referencing class InheritingObjectCommand[InheritingObjectT: so.InheritingObject]( sd.ObjectCommand[InheritingObjectT] ): def _update_inherited_fields( self, schema: s_schema.Schema, context: sd.CommandContext, update: Mapping[str, bool], ) -> None: raise NotImplementedError def update_field_status( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: super().update_field_status(schema, context) inherited_status = self.compute_inherited_fields(schema, context) self._update_inherited_fields(schema, context, inherited_status) def compute_inherited_fields( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> dict[str, bool]: result = {} mcls = self.get_schema_metaclass() for op in self.get_subcommands(type=sd.AlterObjectProperty): field = mcls.get_field(op.property) if field.inheritable and not field.ephemeral: result[op.property] = op.new_inherited return result def inherit_fields( self, schema: s_schema.Schema, context: sd.CommandContext, bases: tuple[so.Object, ...], *, fields: Optional[Iterable[str]] = None, ignore_local: bool = False, apply: bool = True, ) -> s_schema.Schema: from . import referencing as s_referencing # HACK: Don't inherit fields if the command comes from # expression change propagation. It shouldn't be necessary, # and can cause a knock-on bug: when aliases directly refer to # another alias, they *incorrectly* have 'expr' marked as an # inherited_field, which causes trouble here. # Fixing this in 3.x/4.x would require a schema repair, though. 
if self.from_expr_propagation: return schema mcls = self.get_schema_metaclass() scls = self.scls is_owned = ( isinstance(scls, s_referencing.ReferencedObject) and scls.get_owned(schema) ) field_names: Iterable[str] if fields is not None: field_names = set(scls.inheritable_fields()) & set(fields) else: field_names = set(scls.inheritable_fields()) inherited_fields = scls.get_inherited_fields(schema) inherited_fields_update = {} deferred_complex_ops = [] # Iterate over mcls.get_schema_fields() instead of field_names for # determinism reasons, and so earlier declared fields get # processed first. for field_name, field in mcls.get_schema_fields().items(): if field_name not in field_names: continue was_inherited = field_name in inherited_fields ignore_local_field = ignore_local or was_inherited try: result = field.merge_fn( scls, bases, field_name, ignore_local=ignore_local_field, schema=schema, ) except (errors.SchemaDefinitionError, errors.SchemaError) as e: if (span := self.get_attribute_span(field_name)): e.set_span(span) raise if not ignore_local_field: ours = scls.get_explicit_field_value(schema, field_name, None) else: ours = None inherited = result is not None and ours is None inherited_fields_update[field_name] = inherited if ( ( result != ours or inherited or (was_inherited and not is_owned) ) or ( result is None and ours is None and ignore_local ) ): if ( inherited and not context.transient_derivation ): if isinstance(result, s_expr.Expression): result = self.compile_expr_field( schema, context, field=field, value=result) elif isinstance(result, s_expr.ExpressionDict): compiled = {} for k, v in result.items(): if not v.is_compiled(): v = self.compile_expr_field( schema, context, field, v) compiled[k] = v result = compiled sav = self.set_attribute_value( field_name, result, inherited=inherited) if isinstance(sav, sd.AlterObjectProperty): schema = self.scls.set_field_value( schema, field_name, result) else: # If this isn't a simple AlterObjectProperty, postpone 
# its application to _after_ _update_inherited_fields # so that the inherited_fields computation is correct, # as each non-trivial AlterSpecialObjectField operation # updates inherited_fields. deferred_complex_ops.append(sav) self._update_inherited_fields( schema, context, inherited_fields_update) if self.has_attribute_value("inherited_fields"): schema = self.scls.set_field_value( schema, "inherited_fields", self.get_attribute_value("inherited_fields"), ) # In some cases, self will be applied later if apply: for op in deferred_complex_ops: schema = op.apply(schema, context) return schema def get_inherited_ref_layout( self, schema: s_schema.Schema, context: sd.CommandContext, refdict: so.RefDict ) -> dict[ sn.QualName, tuple[ type[ s_referencing.CreateReferencedInheritingObject[ s_referencing.ReferencedInheritingObject ] ], qlast.ObjectDDL, list[s_referencing.ReferencedInheritingObject], ], ]: from . import referencing as s_referencing attr = refdict.attr bases = self.scls.get_bases(schema) refs: dict[ sn.QualName, tuple[ type[ s_referencing.CreateReferencedInheritingObject[ s_referencing.ReferencedInheritingObject ] ], qlast.ObjectDDL, list[s_referencing.ReferencedInheritingObject], ], ] = {} ancestors = set(self.scls.get_ancestors(schema).objects(schema)) for base in bases.objects(schema) + (self.scls,): base_refs: dict[ sn.Name, s_referencing.ReferencedInheritingObject, ] = dict(base.get_field_value(schema, attr).items(schema)) # Pointers can reference each other if they are computed, # and if they are processed in the wrong order, # recompiling expressions in inherit_fields can break, so # we need to sort them by cross refs. # Since inherit_fields doesn't recompile expressions # in transient derivations, we skip the sorting there. 
if not context.transient_derivation: rev_refs = {v: k for k, v in base_refs.items()} base_refs = { rev_refs[v]: v for v in sd.sort_by_cross_refs(schema, base_refs.values()) } # HACK: Because of issue #5661, we previously did not always # properly discover dependencies on __type__ in computeds. # This was fixed, but it may persist in existing databases. # Currently, expr refs are not compared when diffing schemas, # so a schema repair can't fix this. Thus, in addition to # actually fixing the bug, we hack around it by forcing # __type__ to sort to the front. # TODO: Drop this after cherry-picking. if (tname := sn.UnqualName('__type__')) in base_refs: base_refs[tname] = base_refs.pop(tname) for k, v in reversed(base_refs.items()): if not v.should_propagate(schema): continue if base == self.scls and not v.get_owned(schema): continue mcls = type(v) create_cmd = sd.get_object_command_class_or_die( sd.CreateObject, mcls) assert issubclass( create_cmd, s_referencing.CreateReferencedInheritingObject, ) astnode = create_cmd.as_inherited_ref_ast( schema, context, k, v) fqname = create_cmd._classname_from_ast( schema, astnode, context) if fqname not in refs: refs[fqname] = (create_cmd, astnode, []) objs = refs[fqname][2] if base != self.scls: objs.append(v) elif not objs: # If we are looking at refs in the base object # itself, look at the bases of the ref. Any bases # that we haven't seen already while looking in # our object bases must be refs into objects # that have been dropped from our bases. # # To find which bases to keep, we traverse the # base graph looking for objects with referrers in # our new ancestor set. 
work = list(reversed(v.get_bases(schema).objects(schema))) while work: vbase = work.pop() subj = vbase.get_referrer(schema) if vbase in objs: continue elif subj is None or subj in ancestors: objs.append(vbase) else: work.extend( reversed( vbase.get_bases(schema).objects(schema))) return refs def get_no_longer_inherited_ref_layout( self, schema: s_schema.Schema, context: sd.CommandContext, refdict: so.RefDict, present_refs: AbstractSet[sn.QualName], ) -> dict[sn.Name, type[sd.ObjectCommand[so.Object]]]: from . import referencing as s_referencing local_refs = self.scls.get_field_value(schema, refdict.attr) dropped_refs: dict[sn.Name, type[sd.ObjectCommand[so.Object]]] = {} for k, v in local_refs.items(schema): if not v.get_owned(schema): mcls = type(v) create_cmd = sd.get_object_command_class_or_die( sd.CreateObject, mcls) assert issubclass( create_cmd, s_referencing.CreateReferencedObject, ) astnode = create_cmd.as_inherited_ref_ast( schema, context, k, v) fqname = create_cmd._classname_from_ast( schema, astnode, context) if fqname not in present_refs: delete_cmd = sd.get_object_command_class_or_die( sd.DeleteObject, mcls) dropped_refs[fqname] = delete_cmd return dropped_refs def _fixup_inheritance_refdicts( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: # HACK?: Derived object types and pointers are created # with inheritance_refdicts={'pointers'}, and typically # don't get persisted. However, for globals and aliases, # they *do* get persisted, and will be altered if a parent # is modified. Make sure those alters are also executed # with a restricted inheritance_refdicts or else whether # things like constraints are created on derived views # will be ordering dependent. # TODO: Clean this up--maybe make it driven explicitly by # is_derived, always? 
if self.scls.get_is_derived(schema): context.current().inheritance_refdicts = {'pointers'} context.current().inheritance_merge = True def _recompute_inheritance( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: from . import ordering scls = self.scls mcls = type(scls) orig_rec = context.current().enable_recursion context.current().enable_recursion = False new_ancestors = so.ObjectList[InheritingObjectT].create( schema, so.compute_ancestors(schema, scls), ) self.set_attribute_value( 'ancestors', new_ancestors, orig_value=scls.get_ancestors(schema), ) schema = scls.set_field_value(schema, 'ancestors', new_ancestors) bases = scls.get_bases(schema).objects(schema) schema = self.inherit_fields(schema, context, bases) deleted_refs = {} for refdict in mcls.get_refdicts(): if _needs_refdict(refdict, context): schema, deleted = self._reinherit_classref_dict( schema, context, refdict) deleted_refs.update(deleted) # Finalize the deletes. We need to linearize them, since they might # have dependencies between them. 
root = sd.DeltaRoot() for fqname, delete_cmd_cls in deleted_refs.items(): root.add(delete_cmd_cls(classname=fqname)) root = ordering.linearize_delta(root, schema, schema) schema = root.apply(schema, context) self.update(root.get_subcommands()) context.current().enable_recursion = orig_rec return schema def _reinherit_classref_dict( self: InheritingObjectCommand[InheritingObjectT], schema: s_schema.Schema, context: sd.CommandContext, refdict: so.RefDict, ) -> tuple[s_schema.Schema, dict[sn.Name, type[sd.ObjectCommand[so.Object]]]]: from edb.schema import referencing as s_referencing scls = self.scls refs = self.get_inherited_ref_layout(schema, context, refdict) refnames = set(refs) obj_op: InheritingObjectCommand[InheritingObjectT] if isinstance(self, sd.AlterObjectFragment): obj_op = cast(InheritingObjectCommand[InheritingObjectT], self.get_parent_op(context)) else: obj_op = self for refalter in obj_op.get_subcommands(metaclass=refdict.ref_cls): if refalter.get_attribute_value('owned'): assert isinstance(refalter, sd.QualifiedObjectCommand) refnames.add(refalter.classname) deleted_refs = self.get_no_longer_inherited_ref_layout( schema, context, refdict, refnames) group = sd.CommandGroup() for create_cmd, astnode, bases in refs.values(): cmd = create_cmd.as_inherited_ref_cmd( schema=schema, context=context, astnode=astnode, bases=bases, referrer=scls, ) obj = schema.get(cmd.classname, default=None) if obj is None: cmd.set_attribute_value(refdict.backref_attr, scls) group.add(cmd) schema = cmd.apply(schema, context) else: assert isinstance(obj, s_referencing.ReferencedInheritingObject) existing_bases = obj.get_implicit_bases(schema) schema, cmd2 = self._rebase_ref( schema, context, obj, tuple(existing_bases), tuple(bases)) group.add(cmd2) self.add(group) return schema, deleted_refs def _rebase_ref_cmd( self, schema: s_schema.Schema, context: sd.CommandContext, scls: s_referencing.ReferencedInheritingObject, old_bases: Sequence[so.InheritingObject], new_bases: 
Sequence[so.InheritingObject], ) -> tuple[sd.Command, Optional[sd.Command]]: from . import referencing as s_referencing old_base_names = [b.get_name(schema) for b in old_bases] new_base_names = [b.get_name(schema) for b in new_bases] removed, added = delta_bases( old_base_names, new_base_names, t=type(scls), ) rebase = sd.get_object_command_class( RebaseInheritingObject, type(scls)) alter_cmd_root, alter_cmd, _ = ( scls.init_delta_branch(schema, context, sd.AlterObject)) assert isinstance(alter_cmd, AlterInheritingObject) new_bases_coll = so.ObjectList.create(schema, new_bases) schema = scls.set_field_value(schema, 'bases', new_bases_coll) ancestors = so.compute_ancestors(schema, scls) ancestors_coll = so.ObjectList[ s_referencing.ReferencedInheritingObject].create(schema, ancestors) rebase_cmd: Optional[sd.Command] = None if rebase is not None: rebase_cmd = rebase( classname=scls.get_name(schema), removed_bases=removed, added_bases=added, ) rebase_cmd.set_attribute_value( 'bases', new_bases_coll, ) rebase_cmd.set_attribute_value( 'ancestors', ancestors_coll, ) alter_cmd.add(rebase_cmd) alter_cmd.set_attribute_value( 'bases', new_bases_coll, ) alter_cmd.set_attribute_value( 'ancestors', ancestors_coll, ) return alter_cmd_root, rebase_cmd def _rebase_ref( self, schema: s_schema.Schema, context: sd.CommandContext, scls: s_referencing.ReferencedInheritingObject, old_bases: Sequence[so.InheritingObject], new_bases: Sequence[so.InheritingObject], ) -> tuple[s_schema.Schema, sd.Command]: alter_cmd_root, _ = self._rebase_ref_cmd( schema, context, scls, old_bases, new_bases) schema = alter_cmd_root.apply(schema, context) return schema, alter_cmd_root @classmethod def _classbases_from_ast( cls, schema: s_schema.Schema, astnode: qlast.ObjectDDL, context: sd.CommandContext, ) -> list[so.ObjectShell[InheritingObjectT]]: modaliases = context.modaliases base_refs = [] for b in getattr(astnode, 'bases', None) or []: obj = utils.ast_to_object_shell( b, modaliases=modaliases, schema=schema, 
metaclass=cls.get_schema_metaclass(), ) base_refs.append(obj) classname = cls._classname_from_ast(schema, astnode, context) mcls = cls.get_schema_metaclass() if not base_refs and classname not in mcls.get_root_classes(): default_base = mcls.get_default_base_name() if default_base is not None and classname != default_base: base_refs.append( utils.ast_objref_to_object_shell( utils.name_to_ast_ref(default_base), metaclass=cls.get_schema_metaclass(), schema=schema, modaliases=modaliases, ) ) return base_refs def get_ast_attr_for_field( self, field: str, astnode: type[qlast.DDLOperation], ) -> Optional[str]: if ( field in {'abstract'} and issubclass(astnode, qlast.CreateObject) ): return field else: return super().get_ast_attr_for_field(field, astnode) BaseDeltaItem_T = tuple[ list[so.ObjectShell[so.InheritingObjectT]], str | tuple[str, so.ObjectShell[so.InheritingObjectT]], ] BaseDelta_T = tuple[ tuple[so.ObjectShell[so.InheritingObjectT], ...], tuple[BaseDeltaItem_T[so.InheritingObjectT], ...], ] def delta_bases[InheritingObjectT: so.InheritingObject]( old_bases: Iterable[sn.Name], new_bases: Iterable[sn.Name], t: type[InheritingObjectT], ) -> BaseDelta_T[InheritingObjectT]: dropped = frozenset(old_bases) - frozenset(new_bases) removed_bases = [so.ObjectShell(name=b, schemaclass=t) for b in dropped] common_bases = [b for b in old_bases if b not in dropped] added_bases: list[BaseDeltaItem_T[InheritingObjectT]] = [] j = 0 added_set = set() added_base_refs: list[so.ObjectShell[InheritingObjectT]] = [] if common_bases: for base in new_bases: if common_bases[j] == base: # Found common base, insert the accumulated # list of new bases and continue if added_base_refs: ref = so.ObjectShell(name=common_bases[j], schemaclass=t) added_bases.append((added_base_refs, ('BEFORE', ref))) added_base_refs = [] j += 1 if j >= len(common_bases): break else: continue # Base has been inserted at position j added_base_refs.append(so.ObjectShell(name=base, schemaclass=t)) added_set.add(base) 
# Finally, add all remaining bases to the end of the list tail_bases = added_base_refs + [ so.ObjectShell(name=b, schemaclass=t) for b in new_bases if b not in added_set and b not in common_bases ] if tail_bases: added_bases.append((tail_bases, 'LAST')) return tuple(removed_bases), tuple(added_bases) class AlterInherit[InheritingObjectT: so.InheritingObject](sd.Command): astnode = qlast.AlterAddInherit, qlast.AlterDropInherit # We temporarily record information about inheritance alterations # here, before converting these into Rebases in AlterObject. The # goal here is to encode the information in the subcommand stream, # so the positioning is maintained. added_bases = struct.Field(list[tuple[ list[so.ObjectShell[InheritingObjectT]], Optional[str | tuple[str, so.ObjectShell[InheritingObjectT]]], ]]) dropped_bases = struct.Field(list[so.ObjectShell[InheritingObjectT]]) @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astcmd: qlast.DDLOperation, context: sd.CommandContext, ) -> Any: added_bases = [] dropped_bases: list[so.ObjectShell[InheritingObjectT]] = [] parent_op = context.current().op assert isinstance(parent_op, sd.ObjectCommand) parent_mcls = parent_op.get_schema_metaclass() if isinstance(astcmd, qlast.AlterDropInherit): dropped_bases.extend( utils.ast_to_object_shell( b, metaclass=parent_mcls, modaliases=context.modaliases, schema=schema, ) for b in astcmd.bases ) elif isinstance(astcmd, qlast.AlterAddInherit): bases = [ utils.ast_to_object_shell( b, metaclass=parent_mcls, modaliases=context.modaliases, schema=schema, ) for b in astcmd.bases ] pos_node = astcmd.position pos: Optional[ str | tuple[str, so.ObjectShell[InheritingObjectT]] ] if pos_node is not None: if pos_node.ref is not None: ref = so.ObjectShell( name=utils.ast_ref_to_name(pos_node.ref), schemaclass=parent_mcls, ) pos = (pos_node.position, ref) else: pos = pos_node.position else: pos = None added_bases.append((bases, pos)) # AlterInheritingObject will turn sequences of 
AlterInherit # into proper RebaseWhatever commands. return AlterInherit( added_bases=added_bases, dropped_bases=dropped_bases) class CreateInheritingObject[InheritingObjectT: so.InheritingObject]( InheritingObjectCommand[InheritingObjectT], sd.CreateObject[InheritingObjectT], ): def canonicalize_attributes( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().canonicalize_attributes(schema, context) bases_coll = self.get_resolved_attribute_value( 'bases', schema=schema, context=context) bases = () if bases_coll is None else bases_coll.objects(schema) ancestors = so.compute_lineage(schema, bases, self.get_verbosename()) ancestors_coll = so.ObjectList[InheritingObjectT].create( schema, ancestors) self.set_attribute_value('ancestors', ancestors_coll.as_shell(schema)) if context.mark_derived: self.set_attribute_value('is_derived', True) return schema def _update_inherited_fields( self, schema: s_schema.Schema, context: sd.CommandContext, update: Mapping[str, bool], ) -> None: inherited_fields = {n for n, v in update.items() if v} if inherited_fields: self.set_attribute_value( 'inherited_fields', frozenset(inherited_fields)) def _create_begin( self, schema: s_schema.Schema, context: sd.CommandContext ) -> s_schema.Schema: schema = super()._create_begin(schema, context) if not context.canonical: if context.inheritance_merge is None or context.inheritance_merge: bases_coll = self.get_resolved_attribute_value( 'bases', schema=schema, context=context) if bases_coll is not None: bases = bases_coll.objects(schema) else: bases = () schema = self.inherit_fields(schema, context, bases) return schema def _create_innards( self, schema: s_schema.Schema, context: sd.CommandContext ) -> s_schema.Schema: if not context.canonical: cmd = sd.CommandGroup() mcls = self.get_schema_metaclass() for refdict in mcls.get_refdicts(): if _needs_refdict(refdict, context): cmd.add(self.inherit_classref_dict( schema, context, refdict)) self.prepend(cmd) 
return super()._create_innards(schema, context) @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(astnode, qlast.ObjectDDL) bases = cls._classbases_from_ast(schema, astnode, context) spans = [b.span for b in bases if b.span is not None] if spans: span = edb_span.merge_spans(spans) else: span = None cmd.set_attribute_value( 'bases', so.ObjectCollectionShell(bases, collection_type=so.ObjectList), span=span, ) return cmd def _apply_field_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, op: sd.AlterObjectProperty, ) -> None: if op.property == 'bases': explicit_bases = self.get_explicit_bases( schema, context, op.new_value) if explicit_bases: if isinstance(node, qlast.CreateObject): if isinstance(node, qlast.BasedOn): node.bases = [ qlast.TypeName(maintype=utils.name_to_ast_ref(b)) for b in explicit_bases ] else: node.commands.append( qlast.AlterAddInherit( bases=[ qlast.TypeName( maintype=utils.name_to_ast_ref(b), ) for b in explicit_bases ], ) ) else: if isinstance(node, qlast.CreateObject): if isinstance(node, qlast.BasedOn): node.bases = [] else: super()._apply_field_ast(schema, context, node, op) def get_explicit_bases( self, schema: s_schema.Schema, context: sd.CommandContext, bases: Any, ) -> list[sn.Name]: mcls = self.get_schema_metaclass() default_base = mcls.get_default_base_name() base_names: list[sn.Name] if isinstance(bases, so.ObjectCollectionShell): base_names = [] for b in bases.items: assert b.name is not None base_names.append(b.name) else: assert isinstance(bases, so.ObjectList) base_names = list(bases.names(schema)) # Filter out implicit bases explicit_bases = [ b for b in base_names if ( b != default_base and ( not isinstance(b, sn.QualName) or sn.shortname_from_fullname(b) == b ) ) ] return explicit_bases def inherit_classref_dict( self, 
schema: s_schema.Schema, context: sd.CommandContext, refdict: so.RefDict, ) -> sd.CommandGroup: scls = self.scls refs = self.get_inherited_ref_layout(schema, context, refdict) group = sd.CommandGroup() for create_cmd, astnode, bases in refs.values(): cmd = create_cmd.as_inherited_ref_cmd( schema=schema, context=context, astnode=astnode, bases=bases, referrer=scls, ) cmd.set_attribute_value(refdict.backref_attr, scls) group.add(cmd) return group class AlterInheritingObjectOrFragment[InheritingObjectT: so.InheritingObject]( InheritingObjectCommand[InheritingObjectT], sd.AlterObjectOrFragment[InheritingObjectT], ): def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._alter_begin(schema, context) scls = self.scls if not context.canonical: self._fixup_inheritance_refdicts(schema, context) props = self.enumerate_attributes() if props: bases = scls.get_bases(schema).objects(schema) schema = self.inherit_fields( schema, context, bases, fields=props, ) if context.enable_recursion: self._propagate_field_alter(schema, context, scls, props) return schema def _propagate_field_alter( self, schema: s_schema.Schema, context: sd.CommandContext, scls: so.InheritingObject, props: tuple[str, ...], ) -> None: if _has_implicit_propagation(context): return descendant_names = [ d.get_name(schema) for d in scls.ordered_descendants(schema) ] for descendant_name in descendant_names: descendant = schema.get( descendant_name, type=so.InheritingObject, default=None ) assert descendant, '.inherit_fields caused a drop of a descendant?' 
d_root_cmd, d_alter_cmd, ctx_stack = descendant.init_delta_branch( schema, context, sd.AlterObject) d_bases = descendant.get_bases(schema).objects(schema) # Copy any special updates over if isinstance(self, sd.AlterSpecialObjectField): d_alter_cmd.add(self.clone(d_alter_cmd.classname)) with ctx_stack(): d_alter_cmd.set_annotation('implicit_propagation', True) assert isinstance(d_alter_cmd, InheritingObjectCommand) schema = d_alter_cmd.inherit_fields( schema, context, d_bases, fields=props, apply=False ) self.add_caused(d_root_cmd) def _update_inherited_fields( self, schema: s_schema.Schema, context: sd.CommandContext, update: Mapping[str, bool], ) -> None: cur_inh_fields = self.scls.get_inherited_fields(schema) inh_fields = set(cur_inh_fields) for fn, inherited in update.items(): if inherited: inh_fields.add(fn) else: inh_fields.discard(fn) if cur_inh_fields != inh_fields: if inh_fields: self.set_attribute_value( 'inherited_fields', frozenset(inh_fields), orig_value=cur_inh_fields, ) else: self.set_attribute_value( 'inherited_fields', None, orig_value=cur_inh_fields, ) class AlterInheritingObject[InheritingObjectT: so.InheritingObject]( AlterInheritingObjectOrFragment[InheritingObjectT], sd.AlterObject[InheritingObjectT], ): @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(cmd, AlterInheritingObject) assert isinstance(astnode, qlast.ObjectDDL) # Collect sequences of AlterInherit commands and transform them # into real RebaseWhatever commands. 
added_bases = [] dropped_bases = [] subcmds = cmd.get_subcommands() for i, sub in enumerate(subcmds): if not isinstance(sub, AlterInherit): continue dropped_bases.extend(sub.dropped_bases) added_bases.extend(sub.added_bases) if ( i + 1 < len(subcmds) and isinstance(subcmds[i + 1], AlterInherit) ): cmd.discard(sub) continue # The next command is not an AlterInherit, so it's time to # combine what we've seen and turn it into a rebase. parent_class = cmd.get_schema_metaclass() rebase_class = sd.get_object_command_class_or_die( RebaseInheritingObject, parent_class) cmd.replace( sub, rebase_class( classname=cmd.classname, removed_bases=tuple(dropped_bases), added_bases=tuple(added_bases) ) ) added_bases.clear() dropped_bases.clear() # XXX: I am not totally sure when this will come up? if getattr(astnode, 'bases', None): bases = cls._classbases_from_ast(schema, astnode, context) if bases is not None: _, added = delta_bases( [], [b.get_name(schema) for b in bases], t=cmd.get_schema_metaclass(), ) rebase = sd.get_object_command_class_or_die( RebaseInheritingObject, cmd.get_schema_metaclass()) rebase_cmd = rebase( classname=cmd.classname, removed_bases=tuple(), added_bases=added, ) cmd.add(rebase_cmd) return cmd class AlterInheritingObjectFragment[T: so.InheritingObject]( AlterInheritingObjectOrFragment[T], sd.AlterObjectFragment[T], ): pass class RenameInheritingObject[T: so.InheritingObject]( AlterInheritingObjectFragment[T], sd.RenameObject[T], ): pass class DeleteInheritingObject[T: so.InheritingObject]( InheritingObjectCommand[T], sd.DeleteObject[T], ): pass class RebaseInheritingObject[InheritingObjectT: so.InheritingObject]( AlterInheritingObjectFragment[InheritingObjectT], ): _delta_action = 'rebase' removed_bases = struct.Field(tuple) # type: ignore added_bases = struct.Field(tuple) # type: ignore EXTRA_INHERITED_FIELDS: set[str] = set() def __repr__(self) -> str: return '<%s.%s "%s">' % (self.__class__.__module__, self.__class__.__name__, self.classname) def 
get_verb(self) -> str: # FIXME: We just say 'alter' because it is currently somewhat # inconsistent whether an object rebase on its own will get # placed in its own alter command or whether it will share one # with all the associated rebases of pointers. Ideally we'd # say 'alter base types of', but with the current machinery it # would still usually say 'alter', so just always do that. return 'alter' def _alter_finalize( self, schema: s_schema.Schema, context: sd.CommandContext ) -> s_schema.Schema: schema = super()._alter_finalize(schema, context) if not context.canonical: schema = self._recompute_inheritance(schema, context) if ( context.enable_recursion and not _has_implicit_propagation(context) ): for descendant in self.scls.ordered_descendants(schema): d_root_cmd, d_alter_cmd, ctx_stack = ( descendant.init_delta_branch( schema, context, sd.AlterObject)) assert isinstance(d_alter_cmd, InheritingObjectCommand) d_alter_cmd.set_annotation('implicit_propagation', True) with ctx_stack(): d_alter_cmd._fixup_inheritance_refdicts( schema, context) schema = d_alter_cmd._recompute_inheritance( schema, context) self.add_caused(d_root_cmd) return schema def compute_inherited_fields( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> dict[str, bool]: result = super().compute_inherited_fields(schema, context) # When things like indexes and constraints that use # ddl_identity to define their identity are inherited, the # child should inherit all of those fields, even if the object # is owned in the child. # Make this happen when rebasing. 
mcls = self.get_schema_metaclass() new_bases = self.get_attribute_value('bases').objects(schema) inherit = new_bases and not new_bases[0].get_abstract(schema) fields = { field.name for field in mcls.get_fields().values() if field.ddl_identity and field.inheritable } fields.update(self.EXTRA_INHERITED_FIELDS) for field in fields: if ( inherit and field not in result and bool( new_bases[0].get_explicit_field_value(schema, field, None) ) ): result[field] = True return result def canonicalize_attributes( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().canonicalize_attributes(schema, context) orig_bases = self.scls.get_bases(schema) new_bases = self._compute_new_bases(schema, context, orig_bases) self.set_attribute_value( 'bases', so.ObjectList[InheritingObjectT].create(schema, new_bases), orig_value=orig_bases, ) return schema def _compute_new_bases( self, schema: s_schema.Schema, context: sd.CommandContext, orig_bases: so.ObjectList[InheritingObjectT], ) -> list[InheritingObjectT]: mcls = self.get_schema_metaclass() default_base_name = mcls.get_default_base_name() ori_bases = list(orig_bases.objects(schema)) if default_base_name: default_base: Optional[InheritingObjectT] = self.get_object( schema, context, name=default_base_name) if ori_bases == [default_base]: ori_bases = [] else: default_base = None removed_bases = {b.name for b in self.removed_bases} bases = [ b for b in ori_bases if b.get_name(schema) not in removed_bases ] existing_bases = { b.get_name(schema) for b in bases } index = {b.get_name(schema): i for i, b in enumerate(bases)} for new_bases, pos in self.added_bases: if isinstance(pos, tuple): pos, ref = pos if not pos or pos == 'LAST': idx = len(bases) elif pos == 'FIRST': idx = 0 else: idx = index[ref.name] bases[idx:idx] = [ self.get_object( schema, context, name=b.name, span=b.span) for b in new_bases if b.name not in existing_bases ] index = {b.get_name(schema): i for i, b in enumerate(bases)} if not 
bases and default_base: bases = [default_base] return bases def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: assert parent_node is not None dropped = self._get_bases_for_ast(schema, context, self.removed_bases) if dropped: parent_node.commands.append( qlast.AlterDropInherit( bases=[ cast(qlast.TypeName, utils.typeref_to_ast(schema, b)) for b in dropped ], ) ) for bases, pos in self.added_bases: bases = self._get_bases_for_ast(schema, context, bases) if not bases: continue if isinstance(pos, tuple): typ = utils.typeref_to_ast(schema, pos[1]) assert isinstance(typ, qlast.TypeName) assert isinstance(typ.maintype, qlast.ObjectRef) pos_node = qlast.Position( position=pos[0], ref=typ.maintype, ) else: pos_node = qlast.Position(position=pos) parent_node.commands.append( qlast.AlterAddInherit( bases=[ cast(qlast.TypeName, utils.typeref_to_ast(schema, b)) for b in bases ], position=pos_node, ) ) return None def _get_bases_for_ast( self, schema: s_schema.Schema, context: sd.CommandContext, bases: tuple[so.ObjectShell[InheritingObjectT], ...], ) -> tuple[so.ObjectShell[InheritingObjectT], ...]: mcls = self.get_schema_metaclass() roots = set(mcls.get_root_classes()) return tuple(b for b in bases if b.name not in roots) def _needs_refdict(refdict: so.RefDict, context: sd.CommandContext) -> bool: inheritance_refdicts = context.inheritance_refdicts return ( inheritance_refdicts is None or refdict.attr in inheritance_refdicts ) and (context.inheritance_merge is None or context.inheritance_merge) def _has_implicit_propagation(context: sd.CommandContext) -> bool: for ctx in reversed(context.stack): if ( isinstance(ctx.op, sd.ObjectCommand) and ctx.op.get_annotation('implicit_propagation') ): return True return False ================================================ FILE: edb/schema/links.py ================================================ # # This source file is part of the 
EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any, Optional, TYPE_CHECKING from edb.edgeql import ast as qlast from edb.edgeql import qltypes from edb import errors from . import constraints from . import delta as sd from . import inheriting from . import properties from . import name as sn from . import objects as so from . import pointers from . import referencing from . import rewrites as s_rewrites from . import sources as s_sources from . import types as s_types from . import unknown_pointers from . import utils from . import expr as s_expr if TYPE_CHECKING: from . import objtypes as s_objtypes from . 
import schema as s_schema LinkTargetDeleteAction = qltypes.LinkTargetDeleteAction LinkSourceDeleteAction = qltypes.LinkSourceDeleteAction def merge_actions( target: so.InheritingObject, sources: list[so.Object], field_name: str, *, ignore_local: bool = False, schema: s_schema.Schema, ) -> Any: if not ignore_local: ours = target.get_explicit_local_field_value(schema, field_name, None) else: ours = None if ours is None: current = None current_from = None for source in sources: theirs = source.get_explicit_field_value(schema, field_name, None) if theirs is not None: if current is None: current = theirs current_from = source elif current != theirs: target_source = target.get_source(schema) current_from_source = current_from.get_source(schema) source_source = source.get_source(schema) tgt_repr = ( f'{target_source.get_displayname(schema)}.' f'{target.get_displayname(schema)}' ) cf_repr = ( f'{current_from_source.get_displayname(schema)}.' f'{current_from.get_displayname(schema)}' ) other_repr = ( f'{source_source.get_displayname(schema)}.' 
f'{source.get_displayname(schema)}' ) # Name the action after the field being merged: this # merge function serves both on_target_delete and # on_source_delete. action = field_name.replace('_', ' ') raise errors.SchemaError( f'cannot implicitly resolve the ' f'`{action}` action for ' f'{tgt_repr!r}: it is defined as {current} in ' f'{cf_repr!r} and as {theirs} in {other_repr!r}; ' f'to resolve, declare `{action}` ' f'explicitly on {tgt_repr!r}' ) return current else: return ours class Link( s_sources.Source, pointers.Pointer, qlkind=qltypes.SchemaObjectClass.LINK, data_safe=False, ): on_target_delete = so.SchemaField( LinkTargetDeleteAction, default=LinkTargetDeleteAction.Restrict, coerce=True, compcoef=0.9, merge_fn=merge_actions) on_source_delete = so.SchemaField( LinkSourceDeleteAction, default=LinkSourceDeleteAction.Allow, coerce=True, compcoef=0.9, merge_fn=merge_actions) def get_target(self, schema: s_schema.Schema) -> s_objtypes.ObjectType: return self.get_field_value( # type: ignore[no-any-return] schema, 'target') def is_link_property(self, schema: s_schema.Schema) -> bool: return False def has_user_defined_properties(self, schema: s_schema.Schema) -> bool: return bool([p for p in self.get_pointers(schema).objects(schema) if not p.is_special_pointer(schema) and not p.is_pure_computable(schema)]) def get_source( self, schema: s_schema.Schema ) -> Optional[s_objtypes.ObjectType]: return self.get_field_value( # type: ignore[no-any-return] schema, 'source') def get_source_type(self, schema: s_schema.Schema) -> s_objtypes.ObjectType: source = self.get_source(schema) assert source return source def compare( self, other: so.Object, *, our_schema: s_schema.Schema, their_schema: s_schema.Schema, context: so.ComparisonContext, ) -> float: if not isinstance(other, Link): if isinstance(other, pointers.Pointer): return 0.0 else: raise NotImplementedError() return super().compare( other, our_schema=our_schema, their_schema=their_schema, context=context) def set_target( self, schema: s_schema.Schema, target: s_types.Type, ) -> s_schema.Schema: schema = super().set_target(schema, target) tgt_prop = 
self.maybe_get_ptr(schema, sn.UnqualName('target')) if tgt_prop: schema = tgt_prop.set_target(schema, target) return schema @classmethod def get_root_classes(cls) -> tuple[sn.QualName, ...]: return ( sn.QualName(module='std', name='link'), sn.QualName(module='schema', name='__type__'), ) @classmethod def get_default_base_name(cls) -> sn.QualName: return sn.QualName('std', 'link') class LinkSourceCommandContext[Source_T: s_sources.Source]( s_sources.SourceCommandContext[Source_T] ): pass class LinkSourceCommand[Source_T: s_sources.Source]( inheriting.InheritingObjectCommand[Source_T] ): pass class LinkCommandContext( pointers.PointerCommandContext, constraints.ConsistencySubjectCommandContext, properties.PropertySourceContext[Link], unknown_pointers.UnknownPointerSourceContext[Link], s_sources.SourceCommandContext[Link], s_rewrites.RewriteSubjectCommandContext, ): pass class LinkCommand( properties.PropertySourceCommand[Link], pointers.PointerCommand[Link], context_class=LinkCommandContext, referrer_context_class=LinkSourceCommandContext, ): def _append_subcmd_ast( self, schema: s_schema.Schema, node: qlast.DDLOperation, subcmd: sd.Command, context: sd.CommandContext, ) -> None: if ( isinstance(subcmd, pointers.PointerCommand) and subcmd.classname != self.classname ): pname = sn.shortname_from_fullname(subcmd.classname) if pname.name in {'source', 'target'}: return super()._append_subcmd_ast(schema, node, subcmd, context) def validate_object( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: """Check that link definition is sound.""" super().validate_object(schema, context) scls = self.scls assert isinstance(scls, Link) if not scls.get_owned(schema): return target = scls.get_target(schema) assert target is not None if not target.is_object_type(): span = self.get_attribute_span('target') if isinstance(target, s_types.Array): # Custom error message for link -> array<...> link_dn = scls.get_displayname(schema) el_dn = 
target.get_subtypes(schema)[0].get_displayname(schema) hint = f"did you mean 'multi link {link_dn} -> {el_dn}'?" else: hint = None raise errors.InvalidLinkTargetError( f'invalid link target type, expected object type, got ' f'{target.get_verbosename(schema)}', span=span, hint=hint, ) if target.is_free_object_type(schema): span = self.get_attribute_span('target') raise errors.InvalidLinkTargetError( f'{target.get_verbosename(schema)} is not a valid link target', span=span, ) if ( not scls.is_pure_computable(schema) and not scls.get_from_alias(schema) and target.is_view(schema) ): span = self.get_attribute_span('target') raise errors.InvalidLinkTargetError( f'invalid link type: {target.get_displayname(schema)!r}' f' is an expression alias, not a proper object type', span=span, ) if ( scls.get_required(schema) and scls.get_on_target_delete(schema) == qltypes.LinkTargetDeleteAction.DeferredRestrict ): raise errors.InvalidLinkTargetError( 'required links may not use `on target delete ' 'deferred restrict`', span=self.span, ) def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: node = super()._get_ast(schema, context, parent_node=parent_node) # __type__ link is special, and while it exists on every object # it does not have a defined default in the schema (and therefore # it isn't marked as required.) We intervene here to mark all # __type__ links required when rendering for SDL/TEXT. 
if context.declarative and node is not None: assert isinstance(node, (qlast.CreateConcreteLink, qlast.CreateLink)) if node.name.name == '__type__': assert isinstance(node, qlast.CreateConcretePointer) node.is_required = True return node def _reinherit_classref_dict( self, schema: s_schema.Schema, context: sd.CommandContext, refdict: so.RefDict, ) -> tuple[s_schema.Schema, dict[sn.Name, type[sd.ObjectCommand[so.Object]]]]: if self.scls.get_computable(schema) and refdict.attr != 'pointers': # If the link is a computable, the inheritance would only # happen in the case of aliasing, and in that case we only # need to inherit the link properties and nothing else. return schema, {} return super()._reinherit_classref_dict(schema, context, refdict) class CreateLink( pointers.CreatePointer[Link], LinkCommand, ): astnode = [qlast.CreateConcreteLink, qlast.CreateLink] referenced_astnode = qlast.CreateConcreteLink @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: cmd = super()._cmd_tree_from_ast(schema, astnode, context) if isinstance(astnode, qlast.CreateConcreteLink): assert isinstance(cmd, pointers.PointerCommand) cmd._process_create_or_alter_ast(schema, astnode, context) assert isinstance(cmd, sd.Command) return cmd def get_ast_attr_for_field( self, field: str, astnode: type[qlast.DDLOperation], ) -> Optional[str]: if ( field == 'required' and issubclass(astnode, qlast.CreateConcreteLink) ): return 'is_required' elif ( field == 'cardinality' and issubclass(astnode, qlast.CreateConcreteLink) ): return 'cardinality' else: return super().get_ast_attr_for_field(field, astnode) def _apply_field_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, op: sd.AlterObjectProperty, ) -> None: objtype = self.get_referrer_context(context) if op.property == 'target' and objtype: # Due to how SDL is processed the underlying AST may be an # AlterConcreteLink, which 
requires different handling. if isinstance(node, qlast.CreateConcreteLink): if not node.target: expr: Optional[s_expr.Expression] = ( self.get_attribute_value('expr') ) if expr is not None: node.target = expr.parse() else: t = op.new_value assert isinstance(t, (so.Object, so.ObjectShell)) node.target = utils.typeref_to_ast(schema, t) else: old_type = pointers.merge_target( self.scls, list(self.scls.get_bases(schema).objects(schema)), 'target', ignore_local=True, schema=schema, ) assert isinstance(op.new_value, (so.Object, so.ObjectShell)) new_type = ( op.new_value.resolve(schema) if isinstance(op.new_value, so.ObjectShell) else op.new_value) assert isinstance(new_type, s_types.Type) new_type_ast = utils.typeref_to_ast(schema, op.new_value) cast_expr = None # If the type isn't assignment castable, generate a # USING with a nonsense cast. It shouldn't matter, # since there should be no data to cast, but the DDL side # of things doesn't know that since the command is split up. if old_type and not old_type.assignment_castable_to( new_type, schema): cast_expr = qlast.TypeCast( type=new_type_ast, expr=qlast.Set(elements=[]), ) node.commands.append( qlast.SetPointerType( value=new_type_ast, cast_expr=cast_expr, ) ) elif op.property == 'on_target_delete': node.commands.append(qlast.OnTargetDelete(cascade=op.new_value)) elif op.property == 'on_source_delete': node.commands.append(qlast.OnSourceDelete(cascade=op.new_value)) else: super()._apply_field_ast(schema, context, node, op) def inherit_classref_dict( self, schema: s_schema.Schema, context: sd.CommandContext, refdict: so.RefDict, ) -> sd.CommandGroup: if self.scls.get_computable(schema) and refdict.attr != 'pointers': # If the link is a computable, the inheritance would only # happen in the case of aliasing, and in that case we only # need to inherit the link properties and nothing else. 
return sd.CommandGroup() cmd = super().inherit_classref_dict(schema, context, refdict) if refdict.attr != 'pointers': return cmd parent_ctx = self.get_referrer_context(context) if parent_ctx is None: return cmd # Skip source and target when compiling stuff that won't ever # go into a real schema. if context.slim_links: return cmd base_prop_name = sn.QualName('std', 'source') s_name = sn.get_specialized_name( sn.QualName('__', 'source'), str(self.classname)) src_prop_name = sn.QualName( name=s_name, module=self.classname.module) src_prop = properties.CreateProperty( classname=src_prop_name, is_strong_ref=True, ) src_prop.set_attribute_value('name', src_prop_name) src_prop.set_attribute_value( 'bases', so.ObjectList.create(schema, [schema.get(base_prop_name)]), ) src_prop.set_attribute_value( 'source', self.scls, ) src_prop.set_attribute_value( 'target', parent_ctx.op.scls, ) src_prop.set_attribute_value('required', True) src_prop.set_attribute_value('readonly', True) src_prop.set_attribute_value('owned', True) src_prop.set_attribute_value('from_alias', self.scls.get_from_alias(schema)) src_prop.set_attribute_value('cardinality', qltypes.SchemaCardinality.One) cmd.prepend(src_prop) base_prop_name = sn.QualName('std', 'target') s_name = sn.get_specialized_name( sn.QualName('__', 'target'), str(self.classname)) tgt_prop_name = sn.QualName( name=s_name, module=self.classname.module) tgt_prop = properties.CreateProperty( classname=tgt_prop_name, is_strong_ref=True, ) tgt_prop.set_attribute_value('name', tgt_prop_name) tgt_prop.set_attribute_value( 'bases', so.ObjectList.create(schema, [schema.get(base_prop_name)]), ) tgt_prop.set_attribute_value( 'source', self.scls, ) tgt_prop.set_attribute_value( 'target', self.get_attribute_value('target'), ) tgt_prop.set_attribute_value('required', False) tgt_prop.set_attribute_value('readonly', True) tgt_prop.set_attribute_value('owned', True) tgt_prop.set_attribute_value('from_alias', self.scls.get_from_alias(schema)) 
tgt_prop.set_attribute_value('cardinality', qltypes.SchemaCardinality.One) cmd.prepend(tgt_prop) return cmd class RenameLink( LinkCommand, referencing.RenameReferencedInheritingObject[Link], ): pass class RebaseLink( LinkCommand, referencing.RebaseReferencedInheritingObject[Link], ): pass class SetLinkType( pointers.SetPointerType[Link], referrer_context_class=LinkSourceCommandContext, field='target', ): def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._alter_begin(schema, context) scls = self.scls new_target = scls.get_target(schema) if not context.canonical: # We need to update the target link prop as well tgt_prop = scls.maybe_get_ptr(schema, sn.UnqualName('target')) if tgt_prop: tgt_prop_alter = tgt_prop.init_delta_command( schema, sd.AlterObject) tgt_prop_alter.set_attribute_value('target', new_target) self.add(tgt_prop_alter) return schema class AlterLinkUpperCardinality( pointers.AlterPointerUpperCardinality[Link], referrer_context_class=LinkSourceCommandContext, field='cardinality', ): pass class AlterLinkLowerCardinality( pointers.AlterPointerLowerCardinality[Link], referrer_context_class=LinkSourceCommandContext, field='required', ): pass class AlterLinkOwned( referencing.AlterOwned[Link], pointers.PointerCommandOrFragment[Link], referrer_context_class=LinkSourceCommandContext, field='owned', ): pass class SetTargetDeletePolicy(sd.Command): astnode = qlast.OnTargetDelete @classmethod def _cmd_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.AlterObjectProperty: return sd.AlterObjectProperty( property='on_target_delete' ) @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: assert isinstance(astnode, qlast.OnTargetDelete) cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(cmd, sd.AlterObjectProperty) cmd.new_value = 
astnode.cascade return cmd class SetSourceDeletePolicy(sd.Command): astnode = qlast.OnSourceDelete @classmethod def _cmd_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.AlterObjectProperty: return sd.AlterObjectProperty( property='on_source_delete' ) @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: assert isinstance(astnode, qlast.OnSourceDelete) cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(cmd, sd.AlterObjectProperty) cmd.new_value = astnode.cascade return cmd class AlterLink( LinkCommand, pointers.AlterPointer[Link], ): astnode = [qlast.AlterConcreteLink, qlast.AlterLink] referenced_astnode = qlast.AlterConcreteLink @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> AlterLink: cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(cmd, AlterLink) if isinstance(astnode, qlast.CreateConcreteLink): cmd._process_create_or_alter_ast(schema, astnode, context) else: cmd._process_alter_ast(schema, astnode, context) return cmd def _apply_field_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, op: sd.AlterObjectProperty, ) -> None: if op.property == 'target': if op.new_value: assert isinstance(op.new_value, so.ObjectShell) node.commands.append( qlast.SetPointerType( value=utils.typeref_to_ast(schema, op.new_value), ), ) elif op.property == 'computable': if not op.new_value: node.commands.append( qlast.SetField( name='expr', value=None, special_syntax=True, ), ) elif op.property == 'on_target_delete': node.commands.append(qlast.OnTargetDelete(cascade=op.new_value)) elif op.property == 'on_source_delete': node.commands.append(qlast.OnSourceDelete(cascade=op.new_value)) else: super()._apply_field_ast(schema, context, node, op) class DeleteLink( LinkCommand, 
pointers.DeletePointer[Link], ): astnode = [qlast.DropConcreteLink, qlast.DropLink] referenced_astnode = qlast.DropConcreteLink def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: if self.get_orig_attribute_value('from_alias'): # This is an alias type, appropriate DDL would be generated # from the corresponding Alter/DeleteAlias node. return None else: return super()._get_ast(schema, context, parent_node=parent_node) ================================================ FILE: edb/schema/migrations.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Implementation of MIGRATION objects.""" from __future__ import annotations from typing import Optional, TYPE_CHECKING from edb import errors from edb.edgeql import ast as qlast from edb.edgeql import codegen as qlcodegen from edb.edgeql import qltypes from edb.edgeql import parser as qlparser import edb._edgeql_parser as ql_parser from . import delta as sd from . import name as sn from . import objects as so from . import utils as s_utils if TYPE_CHECKING: from . 
import schema as s_schema class Migration( so.Object, qlkind=qltypes.SchemaObjectClass.MIGRATION, data_safe=False, ): parents = so.SchemaField( so.ObjectList["Migration"], default=so.DEFAULT_CONSTRUCTOR, coerce=True, ) message = so.SchemaField( str, default=None, allow_ddl_set=True, ) generated_by = so.SchemaField( str, default=None, allow_ddl_set=True, ) script = so.SchemaField( str, ) sdl = so.SchemaField( str, ) class MigrationCommandContext(sd.ObjectCommandContext[Migration]): pass class MigrationCommand( sd.ObjectCommand[Migration], context_class=MigrationCommandContext, ): pass class CreateMigration(MigrationCommand, sd.CreateObject[Migration]): astnode = qlast.CreateMigration @classmethod def _cmd_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> CreateMigration: assert isinstance(astnode, qlast.CreateMigration) if astnode.name is not None: specified_name = astnode.name.name else: specified_name = None parent_migration = schema.get_last_migration() parent: Optional[so.ObjectShell[Migration]] if parent_migration is not None: parent = parent_migration.as_shell(schema) parent_name = str(parent.name) else: parent = None parent_name = 'initial' if astnode.parent is not None: parent_name = astnode.parent.name hasher = ql_parser.Hasher.start_migration(parent_name) if astnode.body.text is not None: # This is an explicitly specified CREATE MIGRATION ddl_text = astnode.body.text elif astnode.body.commands: # An implicit CREATE MIGRATION produced by START MIGRATION ddl_text = ';\n'.join( qlcodegen.generate_source(stmt, uppercase=True) for stmt in [*astnode.commands, *astnode.body.commands] ) + ';' else: ddl_text = '' hasher.add_source(ddl_text) name = hasher.make_migration_id() sdl_text: Optional[str] = astnode.target_sdl if specified_name is not None and name != specified_name: raise errors.SchemaDefinitionError( f'specified migration name does not match the name derived ' f'from the migration contents: 
{specified_name!r}, expected ' f'{name!r}', span=astnode.name.span, ) if specified_name is not None and schema.has_migration(specified_name): # Note: a duplicate migration is impossible without # `specified_name`, because the new migration will be based on # the new parent (and you can't specify a parent without a name). raise errors.DuplicateMigrationError( f'migration {name!r} is already applied', span=astnode.name.span, ) if astnode.parent is not None: if parent_migration is None: if astnode.parent.name.lower() != 'initial': raise errors.SchemaDefinitionError( 'specified migration parent does not exist', span=astnode.parent.span, ) else: astnode_parent = s_utils.ast_objref_to_object_shell( astnode.parent, metaclass=Migration, schema=schema, modaliases={}, ) actual_parent_name = parent_migration.get_name(schema) if astnode_parent.name != actual_parent_name: raise errors.SchemaDefinitionError( f'specified migration parent is not the most recent ' f'migration, expected {str(actual_parent_name)!r}', span=astnode.parent.span, ) cmd = cls(classname=sn.UnqualName(name)) cmd.set_attribute_value('script', ddl_text) cmd.set_attribute_value('sdl', sdl_text) cmd.set_attribute_value('builtin', False) cmd.set_attribute_value('internal', False) if parent is not None: cmd.set_attribute_value('parents', [parent]) return cmd @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> CreateMigration: assert isinstance(astnode, qlast.CreateMigration) cmd = super()._cmd_tree_from_ast(schema, astnode, context) if astnode.body.commands and not astnode.metadata_only: for subastnode in astnode.body.commands: subcmd = sd.compile_ddl(schema, subastnode, context=context) if subcmd is not None: cmd.add(subcmd) assert isinstance(cmd, CreateMigration) return cmd def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: from . 
import ddl as s_ddl new_schema = super().apply(schema, context) if ( context.store_migration_sdl and not self.get_attribute_value('sdl') ): # If target sdl was not known in advance, compute it now. new_sdl: str = s_ddl.sdl_text_from_schema(new_schema) new_schema = self.scls.set_field_value(new_schema, 'sdl', new_sdl) self.set_attribute_value('sdl', new_sdl) return new_schema def apply_subcommands( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: assert not self.get_prerequisites() and not self.get_caused() # Renames shouldn't persist between commands in a migration script. context.renames.clear() for op in self.get_subcommands( include_prerequisites=False, include_caused=False, ): if not isinstance(op, sd.AlterObjectProperty): schema = op.apply(schema, context=context) context.renames.clear() return schema def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: node = super()._get_ast(schema, context, parent_node=parent_node) assert isinstance(node, qlast.CreateMigration) node.metadata_only = True return node def _apply_field_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, op: sd.AlterObjectProperty, ) -> None: assert isinstance(node, qlast.CreateMigration) if op.property == 'script': block, _ = qlparser.parse_migration_body_block(op.new_value) node.body = qlast.NestedQLBlock( commands=block.commands, text=op.new_value, ) elif op.property == 'parents': if op.new_value and (items := op.new_value.items): assert len(items) == 1 parent = next(iter(items)) node.parent = s_utils.name_to_ast_ref(parent.get_name(schema)) else: super()._apply_field_ast(schema, context, node, op) class AlterMigration(MigrationCommand, sd.AlterObject[Migration]): astnode = qlast.AlterMigration class DeleteMigration(MigrationCommand, sd.DeleteObject[Migration]): astnode = qlast.DropMigration def 
get_ordered_migrations( schema: s_schema.Schema, ) -> list[Migration]: '''Get all the migrations, in order. It would be nice if our toposort could do this for us, but toposort is implemented recursively, and it would be a pain to change that. ''' output = [] mig = schema.get_last_migration() while mig: output.append(mig) parents = mig.get_parents(schema).objects(schema) assert len(parents) <= 1, "only one parent supported currently" mig = parents[0] if parents else None output.reverse() return output ================================================ FILE: edb/schema/modules.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from edb import errors from edb.edgeql import ast as qlast from edb.edgeql import qltypes from . import annos as s_anno from . import delta as sd from . import name as sn from . import objects as so from . import schema as s_schema RESERVED_MODULE_NAMES = { 'super', } DEFAULT_MODULE_ALIAS = 'default' class Module( s_anno.AnnotationSubject, so.Object, # Help reflection figure out the right db MRO qlkind=qltypes.SchemaObjectClass.MODULE, data_safe=False, ): # N.B: Modules are not "qualified" objects, even though they can # be nested (because they might *not* be nested) and we arrange # for their names to always be represented with an UnqualName. 
pass class ModuleCommandContext(sd.ObjectCommandContext[Module]): pass class ModuleCommand( sd.ObjectCommand[Module], context_class=ModuleCommandContext, ): def _validate_legal_command( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: super()._validate_legal_command(schema, context) last = str(self.classname) first = last enclosing = None if '::' in str(self.classname): first, _, _ = str(self.classname).partition('::') enclosing, _, last = str(self.classname).rpartition('::') if not schema.has_module(enclosing): raise errors.UnknownModuleError( f'module {enclosing!r} is not in this schema') if last in RESERVED_MODULE_NAMES: raise errors.SchemaDefinitionError( f"module {last!r} is a reserved module name") if ( not context.stdmode and not context.testmode and sn.UnqualName(first) in s_schema.STD_MODULES ): raise errors.SchemaDefinitionError( f'cannot {self._delta_action} {self.get_verbosename()}: ' f'module {first} is read-only', span=self.span) class CreateModule(ModuleCommand, sd.CreateObject[Module]): astnode = qlast.CreateModule class AlterModule(ModuleCommand, sd.AlterObject[Module]): astnode = qlast.AlterModule class RenameModule(ModuleCommand, sd.RenameObject[Module]): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: raise errors.SchemaError( f'renaming modules is not supported', span=self.span, ) class DeleteModule(ModuleCommand, sd.DeleteObject[Module]): astnode = qlast.DropModule def _validate_legal_command( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: super()._validate_legal_command(schema, context) # For now, we disallow deleting non-empty modules. # Modules aren't actually stored with any direct linkage # to the objects in them, so explicitly search for objects # in the module (excluding the module itself). 
has_objects = bool(any(schema.get_objects( included_modules=[self.classname], excluded_items=[self.classname], ))) if has_objects: vn = self.scls.get_verbosename(schema) raise errors.SchemaError( f'cannot drop {vn} because it is not empty' ) ================================================ FILE: edb/schema/name.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any, TypeVar, NamedTuple, TYPE_CHECKING import abc import functools import re from edb import errors from edb.common import markup NameT = TypeVar("NameT", bound="Name") QualNameT = TypeVar("QualNameT", bound="QualName") UnqualNameT = TypeVar("UnqualNameT", bound="UnqualName") # Unfortunately, there is no way to convince mypy that QualName # and UnqualName are implementations of the Name ABC: # NamedTuple doesn't support multiple inheritance, and ABCMeta.register # is not supported either. And so, we must resort to stubbing. if TYPE_CHECKING: class Name: __match_args__ = ('name',) name: str @classmethod def from_string(cls: type[NameT], name: str) -> NameT: ... def get_local_name(self) -> UnqualName: ... def get_root_module_name(self) -> UnqualName: ... def __lt__(self, other: Any) -> bool: ... def __le__(self, other: Any) -> bool: ... def __gt__(self, other: Any) -> bool: ... 
def __ge__(self, other: Any) -> bool: ... def __str__(self) -> str: ... def __repr__(self) -> str: ... def __hash__(self) -> int: ... class QualName(Name): __match_args__ = ('module', 'name') module: str name: str @classmethod def from_string( cls: type[QualNameT], name: str, ) -> QualNameT: ... def __init__(self, module: str, name: str) -> None: ... def get_local_name(self) -> UnqualName: ... def get_module_name(self) -> Name: ... class UnqualName(Name): __slots__ = ('name',) name: str @classmethod def from_string( cls: type[UnqualNameT], name: str, ) -> UnqualNameT: ... def __init__(self, name: str) -> None: ... def get_local_name(self) -> UnqualName: ... else: class Name(abc.ABC): # noqa: B024 pass class QualName(NamedTuple): module: str name: str @classmethod def from_string( cls: type[QualNameT], name: str, ) -> QualNameT: module, _, nqname = name.rpartition('::') if not module: err = ( f'improperly formed name {name!r}: ' f'module is not specified' ) raise errors.InvalidReferenceError(err) return cls( module=module, name=nqname, ) def get_local_name(self) -> UnqualName: return UnqualName(self.name) def get_module_name(self) -> Name: return UnqualName(self.module) def get_root_module_name(self) -> UnqualName: return UnqualName(self.module.partition('::')[0]) def __str__(self) -> str: return f'{self.module}::{self.name}' def __repr__(self) -> str: return f'<QualName {self}>' class UnqualName(NamedTuple): name: str @classmethod def from_string( cls: type[UnqualNameT], name: str, ) -> UnqualNameT: return cls(name) def get_local_name(self) -> UnqualName: return self def get_root_module_name(self) -> UnqualName: return UnqualName(self.name.partition('::')[0]) def __str__(self) -> str: return self.name def __repr__(self) -> str: return f'<UnqualName {self}>' Name.register(QualName) Name.register(UnqualName) def is_qualified(name: str) -> bool: return '::' in name def name_from_string(name: str) -> Name: if is_qualified(name): return QualName.from_string(name) else: return UnqualName.from_string(name)
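# Illustrative sketch only; it is NOT part of the original module.
# It mirrors, with plain strings and tuples, how the helpers above appear
# to split names: QualName.from_string() cuts at the *last* '::' (so nested
# module paths stay in the module part), while get_root_module_name() cuts
# at the *first* '::'. The `_sketch_*` names are hypothetical and exist only
# for this example; ValueError stands in for errors.InvalidReferenceError.
def _sketch_split_qualified(name: str) -> tuple[str, str]:
    # rpartition() returns ('', '', name) when the separator is absent,
    # which is how a missing module is detected.
    module, _, local = name.rpartition('::')
    if not module:
        raise ValueError(
            f'improperly formed name {name!r}: module is not specified'
        )
    return module, local


def _sketch_root_module(name: str) -> str:
    # partition() cuts at the first '::'; an unqualified name is returned
    # unchanged.
    return name.partition('::')[0]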
def mangle_name(name: str) -> str: return ( name .replace('|', '||') .replace('&', '&&') .replace('::', '|') .replace('@', '&') ) mangle_re_1 = re.compile(r'(?<![|])\|(?![|])') mangle_re_2 = re.compile(r'(?<![&])&(?![&])') def unmangle_name(name: str) -> str: name = mangle_re_1.sub('::', name) name = mangle_re_2.sub('@', name) return name.replace('||', '|').replace('&&', '&') @functools.lru_cache(10240) def shortname_from_fullname(fullname: Name) -> Name: name = fullname.name parts = name.split('@', 1) if len(parts) == 2: return name_from_string(unmangle_name(parts[0])) else: return fullname unmangle_re_1 = re.compile(r'\|+') def recursively_unmangle_shortname(name: str) -> str: # Any number of pipes becomes a single ::. return unmangle_re_1.sub('::', name) @functools.lru_cache(4096) def quals_from_fullname(fullname: QualName) -> list[str]: _, _, mangled_quals = fullname.name.partition('@') return ( [unmangle_name(p) for p in mangled_quals.split('@')] if mangled_quals else [] ) def get_specialized_name(basename: Name, *qualifiers: str) -> str: mangled_quals = '@'.join(mangle_name(qual) for qual in qualifiers if qual) return f'{mangle_name(str(basename))}@{mangled_quals}' def is_fullname(name: str) -> bool: return is_qualified(name) and '@' in name def compat_get_specialized_name(basename: str, *qualifiers: str) -> str: mangled_quals = '@'.join( compat_mangle_name(qual) for qual in qualifiers if qual ) return f'{compat_mangle_name(basename)}@@{mangled_quals}' def compat_mangle_name(name: str) -> str: return name.replace('::', '|') def compat_name_remangle(name: str) -> Name: if is_fullname(name): qname = QualName.from_string(name) sn = shortname_from_fullname(qname) quals = list(quals_from_fullname(qname)) if quals and is_fullname(quals[0]): quals[0] = str(compat_name_remangle(quals[0])) compat_sn = compat_get_specialized_name(str(sn), *quals) return QualName(name=compat_sn, module=qname.module) else: return name_from_string(name) @markup.serializer.no_ref_detect @markup.serializer.serializer.register(Name) def _serialize_to_markup(obj: Name, *, ctx:
markup.Context) -> markup.Markup: return markup.elements.lang.Object( id=id(obj), class_module=type(obj).__module__, classname=type(obj).__name__, repr=str(obj)) ================================================ FILE: edb/schema/objects.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Callable, ClassVar, Final, Generic, Optional, Protocol, TypeVar, Iterable, Iterator, Mapping, Collection, NamedTuple, cast, Self, TYPE_CHECKING, ) import builtins import collections import collections.abc import copy import enum import re import uuid from edb import errors from edb.edgeql import qltypes from edb.common.typeutils import not_none from edb.common import checked from edb.common import lru from edb.common import markup from edb.common import ordered from edb.common import parametric from edb.common import parsing from edb.common import struct from edb.common import topological from edb.common import uuidgen from . import abc as s_abc from . import name as sn from . 
import _types if TYPE_CHECKING: from edb.schema import delta as sd from edb.schema import schema as s_schema CovT = TypeVar("CovT", covariant=True) class MergeFunction(Protocol): def __call__( self, # not actually part of the signature target: InheritingObject, sources: Iterable[Object], field_name: str, *, ignore_local: bool = False, schema: s_schema.Schema, ) -> Any: ... class CollectionFactory(Collection[CovT], Protocol): """An unknown collection that can be instantiated from an iterable.""" def __init__( self, from_iter: Optional[Iterable[CovT]] = None ) -> None: ... class NoDefaultT(enum.Enum): """Used as a sentinel indicating that a named argument wasn't passed. Trick from https://github.com/python/mypy/issues/7642. """ NoDefault = 0 NoDefault: Final = NoDefaultT.NoDefault class DefaultConstructorT(enum.Enum): DefaultConstructor = 0 DEFAULT_CONSTRUCTOR: Final = DefaultConstructorT.DefaultConstructor ObjectContainer_T = TypeVar('ObjectContainer_T', bound='ObjectContainer') Object_T = TypeVar("Object_T", bound="Object") Object_T_co = TypeVar("Object_T_co", bound="Object", covariant=True) ObjectCollection_T = TypeVar( "ObjectCollection_T", bound="ObjectCollection[Object]", ) HashCriterion = type["Object"] | tuple[str, Any] TYPE_ID_NAMESPACE = uuidgen.UUID('00e50276-2502-11e7-97f2-27fe51238dbd') class ReflectionMethod(enum.Enum): """Annotation on schema classes telling how to reflect in metaschema.""" #: Straight 1:1 reflection (the default) REGULAR = enum.auto() #: Object type for schema class is elided and its properties #: are reflected as link properties. This is used for certain #: Referenced classes, like AnnotationValue. AS_LINK = enum.auto() #: No metaschema reflection at all. 
NONE = enum.auto() def default_field_merge( target: InheritingObject, sources: Iterable[Object], field_name: str, *, ignore_local: bool = False, schema: s_schema.Schema, ) -> Any: """The default `MergeFunction`.""" if not ignore_local: ours = target.get_explicit_local_field_value(schema, field_name, None) if ours is not None: return ours for source in sources: theirs = source.get_explicit_field_value(schema, field_name, None) if theirs is not None: return theirs return None def get_known_type_id( typename: str | sn.Name, default: uuid.UUID | NoDefaultT = NoDefault ) -> uuid.UUID: if isinstance(typename, str): typename = sn.name_from_string(typename) try: return _types.TYPE_IDS[typename] except KeyError: pass if default is NoDefault: raise errors.SchemaError( f'failed to lookup named type id for {typename!r}') return default class DeltaGuidance(NamedTuple): banned_creations: frozenset[tuple[type[Object], sn.Name]] = frozenset() banned_deletions: frozenset[tuple[type[Object], sn.Name]] = frozenset() banned_alters: frozenset[ tuple[type[Object], tuple[sn.Name, sn.Name]] ] = frozenset() class DescribeVisibilityFlags(enum.IntFlag): #: Show the field if it is set explicitly, i.e. not inherited or computed. SHOW_IF_EXPLICIT = 1 << 0 #: Show the field if it is inherited or computed. SHOW_IF_DERIVED = 1 << 1 #: Show if the field value matches the default. 
SHOW_IF_DEFAULT = 1 << 2 class DescribeVisibilityPolicy(enum.IntEnum): SHOW_IF_EXPLICIT = ( DescribeVisibilityFlags.SHOW_IF_EXPLICIT ) SHOW_IF_EXPLICIT_OR_DERIVED = ( DescribeVisibilityFlags.SHOW_IF_EXPLICIT | DescribeVisibilityFlags.SHOW_IF_DERIVED | DescribeVisibilityFlags.SHOW_IF_DEFAULT ) SHOW_IF_EXPLICIT_OR_DERIVED_NOT_DEFAULT = ( DescribeVisibilityFlags.SHOW_IF_EXPLICIT | DescribeVisibilityFlags.SHOW_IF_DERIVED ) class ComparisonContext: renames: dict[tuple[type[Object], sn.Name], sd.RenameObject[Object]] deletions: dict[tuple[type[Object], sn.Name], sd.DeleteObject[Object]] guidance: Optional[DeltaGuidance] parent_ops: list[sd.ObjectCommand[Any]] def __init__( self, *, generate_prompts: bool = False, descriptive_mode: bool = False, guidance: Optional[DeltaGuidance] = None, ) -> None: self.generate_prompts = generate_prompts self.descriptive_mode = descriptive_mode self.guidance = guidance self.renames = {} self.deletions = {} self.placeholder_ctr: dict[str, int] = collections.Counter() self.parent_ops = [] def is_deleting(self, schema: s_schema.Schema, obj: Object) -> bool: return (type(obj), obj.get_name(schema)) in self.deletions def record_rename( self, op: sd.RenameObject[Object], ) -> None: self.renames[op.get_schema_metaclass(), op.classname] = op def is_renaming(self, schema: s_schema.Schema, obj: Object) -> bool: return (type(obj), obj.get_name(schema)) in self.renames def get_obj_name(self, schema: s_schema.Schema, obj: Object) -> sn.Name: obj_name = obj.get_name(schema) rename_op = self.renames.get((type(obj), obj_name)) if rename_op is not None: return rename_op.new_name else: return obj_name def get_placeholder(self, prefix: str) -> str: ctr = self.placeholder_ctr[prefix] self.placeholder_ctr[prefix] += 1 if ctr == 0: return f'{prefix}' else: return f'{prefix}_{ctr}' # derived from ProtoField for validation class Field[T](struct.ProtoField): __slots__ = ( 'name', 'sname', 'type', 'type_is_generic_self', 'coerce', 'compcoef', 'inheritable', 
'simpledelta', 'ephemeral', 'allow_ddl_set', 'ddl_identity', 'aux_cmd_data', 'special_ddl_syntax', 'describe_visibility', 'weak_ref', 'merge_fn', 'reflection_method', 'reflection_proxy', 'is_reducible', 'patch_level', 'obj_names_as_string', ) #: Name of the field on the target class; assigned by ObjectMeta name: str #: The name to use when reflecting the field into the schema. #: The same as name by default, but can be overridden. sname: str #: The type of the value stored in the field type: type[T] #: Specifies if *type* is a generic type of the host object #: this field is defined on. type_is_generic_self: bool #: Whether the field is allowed to automatically coerce #: the input value to the declared type of the field. coerce: bool #: The diffing coefficient to use when comparing field #: values in objects from 0 to 1. compcoef: Optional[float] #: Whether the field value can be inherited. inheritable: bool #: Whether the field uses the generic AlterObjectProperty #: delta op, or a custom delta command. simpledelta: bool #: If true, the value of the field is not persisted in the #: database. ephemeral: bool #: Whether the field can be set directly using the `SET` #: command in DDL. allow_ddl_set: bool #: Whether the field is used to identify the object #: in DDL operations and schema reflection when object #: name is insufficient. ddl_identity: bool #: Whether the value of this field should be included in the #: aux_object_data for delta commands of objects containing the field. aux_cmd_data: bool #: Whether this field is set using special DDL syntax or a generic #: SET command. special_ddl_syntax: bool #: Determines when this field is shown in #: DESCRIBE AS TEXT [VERBOSE]. describe_visibility: DescribeVisibilityPolicy #: Used for fields holding references to objects. If True, #: the reference is considered "weak", i.e. not essential for #: object definition. The schema and delta linearization #: rely on this to break object reference cycles.
weak_ref: bool #: A callable used to merge the value of the field from #: multiple objects. Most often used by inheritance. merge_fn: MergeFunction #: Defines how the field is reflected into the backend schema storage. reflection_method: ReflectionMethod #: In cases when the value of the field cannot be reflected as a #: direct link (for example, if the value is a non-distinct set), #: this specifies a (ProxyType, linkname) pair of a proxy object type #: and the name of the link within that proxy type. reflection_proxy: Optional[tuple[str, str]] #: Which edgeql+schema patch for the current major version this #: field was introduced in. Ensures that the data tuples always #: get extended strictly at the end and filters out the field when #: applying earlier patches. patch_level: int #: Interpret any assigned object names as strings. obj_names_as_string: bool def __init__( self, type_: builtins.type[T], *, type_is_generic_self: bool = False, coerce: bool = False, compcoef: Optional[float] = None, inheritable: bool = True, simpledelta: bool = True, merge_fn: MergeFunction = default_field_merge, ephemeral: bool = False, weak_ref: bool = False, allow_ddl_set: bool = False, describe_visibility: DescribeVisibilityPolicy = ( DescribeVisibilityPolicy.SHOW_IF_EXPLICIT), ddl_identity: bool = False, aux_cmd_data: bool = False, special_ddl_syntax: bool = False, reflection_method: ReflectionMethod = ReflectionMethod.REGULAR, reflection_proxy: Optional[tuple[str, str]] = None, name: Optional[str] = None, reflection_name: Optional[str] = None, patch_level: int = -1, obj_names_as_string: bool = False, **kwargs: Any, ) -> None: """Schema item core attribute definition.
""" if not isinstance(type_, type): raise ValueError(f'{type_!r} is not a type') self.type = type_ self.type_is_generic_self = type_is_generic_self self.coerce = coerce self.allow_ddl_set = allow_ddl_set self.ddl_identity = ddl_identity self.aux_cmd_data = aux_cmd_data self.special_ddl_syntax = special_ddl_syntax self.describe_visibility = describe_visibility self.compcoef = compcoef self.inheritable = inheritable self.simpledelta = simpledelta self.weak_ref = weak_ref self.reflection_method = reflection_method self.reflection_proxy = reflection_proxy self.is_reducible = issubclass(type_, s_abc.Reducible) self.patch_level = patch_level self.obj_names_as_string = obj_names_as_string if name is not None: self.name = name if reflection_name is not None: self.sname = reflection_name if ( merge_fn is default_field_merge and callable( type_merge_fn := getattr(self.type, 'merge_values', None) ) ): self.merge_fn = type_merge_fn else: self.merge_fn = merge_fn self.ephemeral = ephemeral def coerce_value( self, schema: s_schema.Schema, value: Any, ) -> Optional[T]: ftype = self.type if value is None or isinstance(value, ftype): return value if not self.coerce: raise TypeError( f'{self.name} field: expected {ftype} but got {value!r}') if issubclass(ftype, (checked.CheckedList, checked.CheckedSet, checked.FrozenCheckedList, checked.FrozenCheckedSet)): casted_list = [] # Mypy complains about ambiguity and generics in class vars here, # although the generic in SingleParameter is clearly a type. valtype = ftype.type # type: ignore # When creating a checked collection field, we may receive either # collections or single items. # If the value is a collection, cast each item separately. If the # value is a single item, cast it directly. 
if ( isinstance(value, Collection) and not isinstance(value, (str, bytes, bytearray)) ): for v in value: if v is not None and not isinstance(v, valtype): v = valtype(v) casted_list.append(v) else: casted_list.append(valtype(value)) value = casted_list elif issubclass(ftype, checked.CheckedDict): casted_dict = {} for k, v in value.items(): if k is not None and not isinstance(k, ftype.keytype): k = ftype.keytype(k) if v is not None and not isinstance(v, ftype.valuetype): v = ftype.valuetype(v) casted_dict[k] = v value = casted_dict elif issubclass(ftype, ObjectCollection): # Type ignore below because mypy narrowed ftype to # Type[ObjectCollection] and lost track that it's actually # Type[T] return ftype.create(schema, value) # type: ignore elif issubclass(ftype, sn.QualName): return ftype.from_string(value) # type: ignore try: # Type ignore below because Mypy doesn't trust we can instantiate # the type using the value. We don't trust that either but this # is why there's the try-except block. return ftype(value) # type: ignore except Exception: raise TypeError( f'cannot coerce {self.name!r} value {value!r} to {ftype}') @property def required(self) -> bool: return True @property def is_schema_field(self) -> bool: return False def get_default(self) -> Any: raise ValueError(f'field {self.name!r} is required and has no default') def __get__( self, instance: Optional[Object], owner: builtins.type[Object], ) -> Optional[T]: if instance is not None: return None else: raise AttributeError( f"type object {owner.__name__!r} " f"has no attribute {self.name!r}" ) def __repr__(self) -> str: return ( f'<{type(self).__name__} name={self.name!r} ' f'type={self.type} {id(self):#x}>' ) class SchemaField[Type_T: type](Field[Type_T]): __slots__ = ('default', 'hashable', 'allow_ddl_set', 'allow_interpolation', 'index', 'get_default_specialized') #: The default value to use for the field. default: Any #: Whether the field participates in object hash. 
hashable: bool #: Whether it's possible to set the field in DDL. allow_ddl_set: bool #: Whether, when setting the field in DDL, we allow \(expr) #: string interpolation (and transform it into {field} #: style interpolation). allow_interpolation: bool #: Field index within the object data tuple index: int #: Specialized get_default function get_default_specialized: Callable[[], Any] def __init__( self, type: Type_T, *, default: Any = NoDefault, hashable: bool = True, allow_ddl_set: bool = False, allow_interpolation: bool = False, **kwargs: Any, ) -> None: super().__init__(type, **kwargs) self.default = default self.hashable = hashable self.allow_ddl_set = allow_ddl_set self.allow_interpolation = allow_interpolation self.index = -1 # Use this instead of get_default if you can; get_default has to # be a method because it comes from the parent self.get_default_specialized = self._make_get_default() @property def required(self) -> bool: return self.default is NoDefault @property def is_schema_field(self) -> bool: return True def get_default(self) -> Any: return self.get_default_specialized() def _make_get_default(self) -> Callable[[], Any]: if self.default is NoDefault: def _get_error() -> Any: raise ValueError( f'field {self.name!r} is required and has no default') return _get_error elif self.default is DEFAULT_CONSTRUCTOR: # ObjectCollection might not be defined yet when we first need # to call this, so we hack it around a bit. 
if getattr(self.type, 'is_object_collection', False): return self.type.create_empty # type: ignore else: return self.type else: def _get_simple(value: Any=self.default) -> Any: return value return _get_simple def __get__( self, instance: Optional[Object], owner: type[Object], ) -> Optional[Type_T]: if instance is not None: raise FieldValueNotFoundError(self.name) else: raise AttributeError( f"type object {owner.__name__!r} " f"has no attribute {self.name!r}" ) class RefDict(struct.RTStruct): attr = struct.Field( str, frozen=True) backref_attr = struct.Field( str, default='subject', frozen=True) requires_explicit_overloaded = struct.Field( bool, default=False, frozen=True) ref_cls: type[Object] = struct.Field( type, frozen=True) class ObjectContainer(s_abc.Reducible): @classmethod def schema_refs_from_data( cls, data: Any, ) -> frozenset[uuid.UUID]: raise NotImplementedError class ObjectMeta(type): _all_types: ClassVar[dict[str, type[Object]]] = {} _schema_types: ClassVar[set[ObjectMeta]] = set() _ql_map: ClassVar[dict[qltypes.SchemaObjectClass, ObjectMeta]] = {} _refdicts_to: ClassVar[ dict[ObjectMeta, list[tuple[RefDict, ObjectMeta]]] ] = {} # Instance fields (i.e. class fields on types built with ObjectMeta) _displayname: str _fields: dict[str, Field[Any]] _schema_fields: dict[str, SchemaField[Any]] _hashable_fields: set[Field[Any]] # if f.is_schema_field and f.hashable _sorted_fields: collections.OrderedDict[str, Field[Any]] #: Fields that contain references to objects either directly or #: indirectly. 
_objref_fields: frozenset[SchemaField[Any]] _reducible_fields: frozenset[SchemaField[Any]] _aux_cmd_data_fields: frozenset[SchemaField[Any]] # if f.aux_cmd_data _refdicts: collections.OrderedDict[str, RefDict] _refdicts_by_refclass: dict[type, RefDict] _refdicts_by_field: dict[str, RefDict] # key is rd.attr _ql_class: Optional[qltypes.SchemaObjectClass] _reflection_method: ReflectionMethod _reflection_link: Optional[str] #: Indicates that ALL changes to this object class are safe from the #: standpoint of persistent data. In other words, changes to the #: object are fully reversible without possible data loss. _data_safe: bool #: Which edgeql+schema patch for the current major version this #: object was introduced in. Ensures that the data tuples always #: get extended strictly at the end and filters out the field when #: applying earlier patches. _patch_level: int #: Whether the type should be abstract in EdgeDB schema. #: This only applies if the type wasn't specified in schema.edgeql. 
_abstract: Optional[bool] def __new__( mcls, name: str, bases: tuple[type, ...], clsdict: dict[str, Any], *, qlkind: Optional[qltypes.SchemaObjectClass] = None, reflection: ReflectionMethod = ReflectionMethod.REGULAR, reflection_link: Optional[str] = None, data_safe: bool = False, abstract: Optional[bool] = None, patch_level: int = -1, **kwargs: Any, ) -> ObjectMeta: refdicts: collections.OrderedDict[str, RefDict] fields = {} myfields = {} refdicts = collections.OrderedDict() mydicts = {} if name in mcls._all_types: raise TypeError( f'duplicate name for schema class: {name}, already defined' f' as {mcls._all_types[name]!r}' ) for k, field in tuple(clsdict.items()): if isinstance(field, RefDict): mydicts[k] = field continue if not isinstance(field, struct.ProtoField): continue if not isinstance(field, Field): raise TypeError( f'cannot create {name} class: schema.objects.Field ' f'expected, got {type(field)}') field.name = k if not hasattr(field, 'sname'): field.sname = k myfields[k] = field del clsdict[k] try: cls = super().__new__(mcls, name, bases, clsdict, **kwargs) except TypeError as ex: raise TypeError( f'Object metaclass has failed to create class {name}: {ex}') for parent in reversed(cls.__mro__): if parent is cls: fields.update(myfields) refdicts.update(mydicts) elif isinstance(parent, ObjectMeta): fields.update({ fn: copy.copy(f) for fn, f in parent.get_ownfields().items() }) refdicts.update({ k: d.copy() for k, d in parent.get_own_refdicts().items() }) cls._displayname = re.sub( r'([a-z])([A-Z])', r'\1 \2', cls.__name__ ).lower() cls._data_safe = data_safe cls._abstract = abstract cls._patch_level = patch_level cls._fields = fields cls._schema_fields = { fn: f for fn, f in sorted(fields.items(), key=lambda f: f[1].patch_level) if isinstance(f, SchemaField) } cls._hashable_fields = { f for f in cls._schema_fields.values() if f.hashable } cls._aux_cmd_data_fields = frozenset( f for f in cls._schema_fields.values() if f.aux_cmd_data ) cls._sorted_fields = 
collections.OrderedDict( sorted(fields.items(), key=lambda e: e[0])) cls._objref_fields = frozenset( f for f in cls._schema_fields.values() if issubclass(f.type, ObjectContainer) ) cls._reducible_fields = frozenset( f for f in cls._schema_fields.values() if issubclass(f.type, s_abc.Reducible) ) fa = '{}.{}_fields'.format(cls.__module__, cls.__name__) setattr(cls, fa, myfields) for findex, field in enumerate(cls._schema_fields.values()): field.index = findex getter_name = f'get_{field.name}' if getter_name in clsdict: # The getter was defined explicitly, move on. continue ftype = field.type # The field getters are hot code as they're essentially # attribute access, so be mindful about what you are adding # into the callables below. if issubclass(ftype, s_abc.Reducible): def reducible_getter( self: Any, schema: s_schema.Schema, *, _fn: str = field.name, _fi: int = findex, _sr: Callable[[Any], s_abc.Reducible] = ( ftype.schema_restore ), _fd: Callable[[], Any] = field.get_default, ) -> Any: v = schema.get_field_raw(self, _fi) if v is not None: return _sr(v) else: try: return _fd() except ValueError: pass raise FieldValueNotFoundError( f'{self!r} object has no value ' f'for field {_fn!r}' ) setattr(cls, getter_name, reducible_getter) elif ( field.default is not NoDefault and field.default is not DEFAULT_CONSTRUCTOR ): def regular_default_getter( self: Any, schema: s_schema.Schema, *, _fi: int = findex, _fd: Any = field.default, ) -> Any: v = schema.get_field_raw(self, _fi) if v is not None: return v else: return _fd setattr(cls, getter_name, regular_default_getter) else: def regular_getter( self: Any, schema: s_schema.Schema, *, _fn: str = field.name, _fi: int = findex, _fd: Callable[[], Any] = field.get_default, ) -> Any: v = schema.get_field_raw(self, _fi) if v is not None: return v else: try: return _fd() except ValueError: pass raise FieldValueNotFoundError( f'{self!r} object has no value ' f'for field {_fn!r}' ) setattr(cls, getter_name, regular_getter) 
non_schema_fields = {field.name for field in fields.values() if not field.is_schema_field} if non_schema_fields == {'id'} and len(fields) > 1: mcls._schema_types.add(cls) if qlkind is not None: mcls._ql_map[qlkind] = cls cls._refdicts_by_refclass = {} for dct in refdicts.values(): if dct.attr not in cls._fields: raise RuntimeError( f'object {name} has no refdict field {dct.attr}') if cls._fields[dct.attr].inheritable: raise RuntimeError( f'{name}.{dct.attr} field must not be inheritable') if not cls._fields[dct.attr].ephemeral: raise RuntimeError( f'{name}.{dct.attr} field must be ephemeral') if not cls._fields[dct.attr].coerce: raise RuntimeError( f'{name}.{dct.attr} field must be coerced') other_dct = cls._refdicts_by_refclass.get(dct.ref_cls) if other_dct is not None: raise TypeError( 'multiple reference dicts for {!r} in ' '{!r}: {!r} and {!r}'.format(dct.ref_cls, cls, dct.attr, other_dct.attr)) cls._refdicts_by_refclass[dct.ref_cls] = dct try: refdicts_to = mcls._refdicts_to[dct.ref_cls] except KeyError: refdicts_to = mcls._refdicts_to[dct.ref_cls] = [] refdicts_to.append((dct, cls)) # Refdicts need to be reversed here to respect the __mro__, # as we have iterated over it in reverse above. 
cls._refdicts = collections.OrderedDict(reversed(refdicts.items())) cls._refdicts_by_field = {rd.attr: rd for rd in cls._refdicts.values()} setattr(cls, '{}.{}_refdicts'.format(cls.__module__, cls.__name__), mydicts) for f in myfields.values(): if (issubclass(f.type, parametric.ParametricType) and not f.type.is_fully_resolved()): f.type.resolve_types({cls.__name__: cls}) cls._ql_class = qlkind cls._reflection_method = reflection if reflection is ReflectionMethod.AS_LINK: if reflection_link is None: raise TypeError( 'reflection AS_LINK requires reflection_link to be passed' ' also' ) cls._reflection_link = reflection_link mcls._all_types[name] = cast(type['Object'], cls) return cls def get_object_reference_fields(cls) -> frozenset[SchemaField[Any]]: return cls._objref_fields def get_reducible_fields(cls) -> frozenset[SchemaField[Any]]: return cls._reducible_fields def get_aux_cmd_data_fields(cls) -> frozenset[SchemaField[Any]]: return cls._aux_cmd_data_fields def has_field(cls, name: str) -> bool: return name in cls._fields def get_field(cls, name: str) -> Field[Any]: field = cls._fields.get(name) if field is None: raise LookupError( f'schema class {cls.__name__!r} has no field {name!r}' ) return field def get_fields(cls, sorted: bool = False) -> Mapping[str, Field[Any]]: return cls._sorted_fields if sorted else cls._fields def get_schema_field(cls, name: str) -> SchemaField[Any]: field = cls._schema_fields.get(name) if field is None: raise LookupError( f'schema class {cls.__name__!r} has no schema field {name!r}' ) return field def get_schema_fields(cls) -> Mapping[str, SchemaField[Any]]: return cls._schema_fields def get_ownfields(cls) -> Mapping[str, Field[Any]]: return getattr( # type: ignore cls, f'{cls.__module__}.{cls.__name__}_fields', ) def get_own_refdicts(cls) -> Mapping[str, RefDict]: return getattr( # type: ignore cls, f'{cls.__module__}.{cls.__name__}_refdicts', ) def get_refdicts(cls) -> Iterator[RefDict]: return iter(cls._refdicts.values()) def 
has_refdict(cls, name: str) -> bool: return name in cls._refdicts_by_field def get_refdict(cls, name: str) -> RefDict: refdict = cls._refdicts_by_field.get(name) if refdict is None: raise LookupError( f'schema class {cls.__name__!r} has no refdict {name!r}' ) return refdict def get_refdict_for_class(cls, refcls: type) -> RefDict: for rcls in refcls.__mro__: try: return cls._refdicts_by_refclass[rcls] except KeyError: pass else: raise KeyError(f'{cls} has no refdict for {refcls}') def get_referring_classes(cls) -> frozenset[tuple[RefDict, ObjectMeta]]: try: refdicts_to = type(cls)._refdicts_to[cls] except KeyError: return frozenset() else: return frozenset(refdicts_to) @property def is_schema_object(cls) -> bool: return cls in ObjectMeta._schema_types @classmethod def get_schema_metaclasses(mcls) -> Iterator[type[Object]]: return iter(mcls._all_types.values()) @classmethod def get_schema_class(mcls, name: str) -> type[Object]: return mcls._all_types[name] @classmethod def maybe_get_schema_class(mcls, name: str) -> Optional[type[Object]]: return mcls._all_types.get(name) @classmethod def get_schema_metaclass_for_ql_class( mcls, qlkind: qltypes.SchemaObjectClass ) -> type[Object]: cls = mcls._ql_map.get(qlkind) if cls is None: raise LookupError(f'no schema metaclass for {qlkind}') return cast(type[Object], cls) def get_ql_class(cls) -> Optional[qltypes.SchemaObjectClass]: return cls._ql_class def get_ql_class_or_die(cls) -> qltypes.SchemaObjectClass: if cls._ql_class is not None: return cls._ql_class else: raise LookupError(f'{cls} has no edgeql class string assigned') def get_reflection_method(cls) -> ReflectionMethod: return cls._reflection_method def get_reflection_link(cls) -> Optional[str]: return cls._reflection_link class FieldValueNotFoundError(Exception): pass class Object(ObjectContainer, metaclass=ObjectMeta): """Base schema item class.""" __slots__ = ('id',) is_global_object = False # Unique ID for this schema item. 
id = Field( uuid.UUID, inheritable=False, simpledelta=False, allow_ddl_set=True, ) internal = SchemaField( bool, inheritable=False, ) # Span of the source text that contained the definition of this object. # This field is ephemeral, which means it is not serialized and saved # persistently. This is OK because we only need it for the language server. span = SchemaField( parsing.Span, default=None, compcoef=None, hashable=False, ephemeral=True, ) name = SchemaField( sn.Name, inheritable=False, compcoef=0.670, ) builtin = SchemaField( bool, default=False, compcoef=0.01, inheritable=False, ) # Fields that have been computed by the system as opposed to # set explicitly or inherited. computed_fields = SchemaField( checked.FrozenCheckedSet[str], default=DEFAULT_CONSTRUCTOR, coerce=True, inheritable=False, compcoef=0.999, ) _fields: dict[str, SchemaField[Any]] def schema_reduce(self) -> tuple[str, uuid.UUID]: return type(self).__name__, self.id @staticmethod @lru.per_job_lru_cache(maxsize=10240) def raw_schema_restore( sclass_name: str, obj_id: uuid.UUID, ) -> Object: sclass = ObjectMeta.get_schema_class(sclass_name) return sclass(_private_id=obj_id) @staticmethod def schema_restore( data: tuple[str, uuid.UUID], ) -> Object: sclass_name, obj_id = data return Object.raw_schema_restore(sclass_name, obj_id) @classmethod def schema_refs_from_data( cls, data: tuple[str, uuid.UUID], ) -> frozenset[uuid.UUID]: return frozenset((data[1],)) def get_id(self, schema: s_schema.Schema) -> uuid.UUID: return self.id @classmethod def get_schema_class_displayname(cls) -> str: return cls._displayname @classmethod def get_shortname_static(cls, name: sn.Name) -> sn.Name: return name @classmethod def get_local_name_static(cls, name: sn.Name) -> sn.UnqualName: return cls.get_shortname_static(name).get_local_name() @classmethod def get_displayname_static(cls, name: sn.Name) -> str: return str(cls.get_shortname_static(name)) @classmethod def get_verbosename_static( cls, name: sn.Name, *, parent: Optional[str]
= None, ) -> str: clsname = cls.get_schema_class_displayname() dname = cls.get_displayname_static(name) if parent is not None: return f"{clsname} '{dname}' of {parent}" else: return f"{clsname} '{dname}'" @classmethod def is_abstract(cls) -> bool: """Return True if this type does NOT represent a concrete schema class. """ return cls.get_ql_class() is None def get_shortname(self, schema: s_schema.Schema) -> sn.Name: return type(self).get_shortname_static(self.get_name(schema)) def get_local_name(self, schema: s_schema.Schema) -> sn.UnqualName: return type(self).get_local_name_static(self.get_name(schema)) def get_displayname(self, schema: s_schema.Schema) -> str: return type(self).get_displayname_static(self.get_name(schema)) def get_verbosename( self, schema: s_schema.Schema, *, with_parent: bool = False ) -> str: clsname = self.get_schema_class_displayname() dname = self.get_displayname(schema) return f"{clsname} '{dname}'" def __init__(self, *, _private_id: uuid.UUID) -> None: self.id = _private_id def __eq__(self, other: Any) -> bool: try: return self.id == other.id # type: ignore except AttributeError: return NotImplemented def __hash__(self) -> int: return hash(self.id) @classmethod def _prepare_id( cls, schema: s_schema.Schema, stable_ids: bool, data: dict[str, Any], ) -> uuid.UUID: name = data.get('name') assert isinstance(name, (str, sn.Name)) try: return get_known_type_id(name) except errors.SchemaError: if stable_ids: # When compiling the standard library, we generate # stable ids based on the internal name and the type's # name. This keeps std schemas compatible across # minor versions at least. 
return uuidgen.uuid5( TYPE_ID_NAMESPACE, f'{name}-{cls.__name__}') else: return uuidgen.uuid1mc() @classmethod def _create_from_id(cls: type[Self], id: uuid.UUID) -> Self: assert id is not None return cls(_private_id=id) @classmethod def create_in_schema[Schema_T: s_schema.Schema]( cls: type[Self], schema: Schema_T, stable_ids: bool = False, *, id: Optional[uuid.UUID] = None, **data: Any, ) -> tuple[Schema_T, Self]: if not cls.is_schema_object: raise TypeError(f'{cls.__name__} type cannot be created in schema') if not data.get('name'): raise RuntimeError(f'cannot create {cls} without a name') all_fields = cls.get_schema_fields() obj_data = [None] * len(all_fields) for field_name, value in data.items(): field = cls.get_schema_field(field_name) value = field.coerce_value(schema, value) obj_data[field.index] = value if id is None: id = cls._prepare_id(schema, stable_ids, data) scls = cls._create_from_id(id) schema = schema.add(id, cls, tuple(obj_data)) return schema, scls # XXX sadly, in the methods below, statically we don't know any better than # "Any" since providing the field name as a `str` is the equivalent of # getattr() on a regular class. 
def get_field_value( self, schema: s_schema.Schema, field_name: str, ) -> Any: field = type(self).get_field(field_name) if isinstance(field, SchemaField): val = schema.get_field_raw(self, field.index) if val is not None: if field.is_reducible: return field.type.schema_restore(val) else: return val else: try: return field.get_default() except ValueError: pass else: try: return object.__getattribute__(self, field_name) except AttributeError: pass raise FieldValueNotFoundError( f'{self!r} object has no value for field {field_name!r}') def get_explicit_field_value( self, schema: s_schema.Schema, field_name: str, default: Any = NoDefault, ) -> Any: field = type(self).get_field(field_name) if isinstance(field, SchemaField): val = schema.get_field_raw(self, field.index) if val is not None: if field.is_reducible: return field.type.schema_restore(val) else: return val elif default is not NoDefault: return default else: try: return object.__getattribute__(self, field_name) except AttributeError: if default is not NoDefault: return default raise FieldValueNotFoundError( f'{self!r} object has no value for field {field_name!r}') def set_field_value( self, schema: s_schema.Schema, name: str, value: Any, ) -> s_schema.Schema: field = type(self)._fields[name] assert field.is_schema_field if value is None: return schema.unset_field(self, name) else: value = field.coerce_value(schema, value) return schema.set_field(self, name, value) def update( self, schema: s_schema.Schema, updates: dict[str, Any] ) -> s_schema.Schema: fields = type(self)._fields updates = updates.copy() for field_name in updates: field = fields[field_name] assert field.is_schema_field new_val = updates[field_name] if new_val is not None: new_val = field.coerce_value(schema, new_val) updates[field_name] = new_val return schema.update_obj(self, updates) def hash_criteria( self: Self, schema: s_schema.Schema ) -> frozenset[HashCriterion]: cls = type(self) sig: list[type[Self] | tuple[str, Any]] = [cls] for f in 
cls._hashable_fields: fn = f.name val = self.get_explicit_field_value(schema, fn, default=None) if val is None: continue elif isinstance(val, collections.abc.MutableSequence): # Turn the list into something hashable so it can be # put in a set. val = tuple(val) elif isinstance(val, collections.abc.MutableMapping): # Turn the dict into something hashable so it can be # put in a set. val = tuple((k, v) for k, v in val.items()) sig.append((fn, val)) return frozenset(sig) def compare( self, other: Object, *, our_schema: s_schema.Schema, their_schema: s_schema.Schema, context: ComparisonContext, ) -> float: if (not isinstance(other, self.__class__) and not isinstance(self, other.__class__)): raise NotImplementedError( f'class {self.__class__.__name__!r} and ' f'class {other.__class__.__name__!r} are not comparable' ) cls = type(self) similarity = 1.0 fields = cls.get_fields(sorted=True) for field in fields.values(): if field.compcoef is None: continue fcoef = cls.compare_obj_field_value( field, self, other, our_schema=our_schema, their_schema=their_schema, context=context, ) similarity *= fcoef return similarity def is_blocking_ref( self, schema: s_schema.Schema, reference: Object ) -> bool: return True def is_parent_ref( self, schema: s_schema.Schema, reference: Object, ) -> bool: """Return True if *reference* is a structural ancestor of self.""" return False def is_generated(self, schema: s_schema.Schema) -> bool: return False @classmethod def compare_field_value[T]( cls, field: Field[type[T]], our_value: T, their_value: T, *, our_schema: s_schema.Schema, their_schema: s_schema.Schema, context: ComparisonContext, ) -> float: if ( our_value is not None and their_value is not None and type(our_value) is type(their_value) ): comparator = getattr(type(our_value), 'compare_values', None) else: comparator = getattr(field.type, 'compare_values', None) assert field.compcoef is not None if callable(comparator): result = comparator( our_value, their_value, context=context, 
our_schema=our_schema, their_schema=their_schema, compcoef=field.compcoef, ) assert isinstance(result, (float, int)) return result if our_value != their_value: return field.compcoef else: return 1.0 @classmethod def compare_obj_field_value[T]( cls: type[Self], field: Field[type[T]], ours: Self, theirs: Self, *, our_schema: s_schema.Schema, their_schema: s_schema.Schema, context: ComparisonContext, explicit: bool = False, ) -> float: fname = field.name # If a field is not inheritable (and thus cannot be affected # by other objects) and the value is missing, it is exactly # equivalent to that field having the default value instead, # so we should use the default for comparisons. This means # that we perform the comparison as if explicit = False. # # E.g. 'owned' being None and False is semantically # identical and should not be considered a change. if (isinstance(field, SchemaField) and not field.inheritable): explicit = False if explicit: our_value = ours.get_explicit_field_value( our_schema, fname, None) their_value = theirs.get_explicit_field_value( their_schema, fname, None) else: our_value = ours.get_field_value(our_schema, fname) their_value = theirs.get_field_value(their_schema, fname) similarity = cls.compare_field_value( field, our_value, their_value, our_schema=our_schema, their_schema=their_schema, context=context, ) # Check to see if this field's computed status has changed. our_cfs = ours.get_computed_fields(our_schema) their_cfs = theirs.get_computed_fields(their_schema) fname = field.name if (fname in our_cfs) != (fname in their_cfs): # The change in computed status decreases the similarity. similarity *= 0.95 return similarity @classmethod def compare_values( cls: type[Self], ours: Optional[Object_T], theirs: Optional[Object_T], *, our_schema: s_schema.Schema, their_schema: s_schema.Schema, context: ComparisonContext, compcoef: float, ) -> float: """Compare two values and return a coefficient of similarity. 
This is a common callback that is used when we do schema comparisons. *ours* and *theirs* are instances of this class, and *our_schema* and *their_schema* are the corresponding schemas in which the values are defined. *compcoef* is whatever was specified for the field. The method returns a coefficient of similarity of the values, from ``0`` to ``1``. """ similarity = 1.0 if ours is not None and theirs is not None: if type(ours) is not type(theirs): similarity /= 1.4 else: our_name = context.get_obj_name(our_schema, ours) their_name = theirs.get_name(their_schema) if our_name != their_name: similarity /= 1.2 else: # If the new and old versions share a reference to # an object that is being deleted, then we must # delete this object as well. if (type(ours), our_name) in context.deletions: return 0.0 elif ours is not None or theirs is not None: # one is None but not both similarity /= 1.2 if similarity < 1.0: return compcoef else: return 1.0 def refresh_classref( self, schema: s_schema.Schema, collection: str, ) -> s_schema.Schema: refdict = type(self).get_refdict(collection) attr = refdict.attr colltype = type(self).get_field(attr).type coll = self.get_explicit_field_value(schema, attr, None) if coll is not None: all_coll = colltype.create(schema, coll.objects(schema)) schema = self.set_field_value(schema, attr, all_coll) return schema def add_classref( self, schema: s_schema.Schema, collection: str, obj: Object, replace: bool = False, ) -> s_schema.Schema: refdict = type(self).get_refdict(collection) attr = refdict.attr colltype = type(self).get_field(attr).type coll = self.get_explicit_field_value(schema, attr, None) if coll is not None: schema, all_coll = coll.update(schema, [obj]) else: all_coll = colltype.create(schema, [obj]) schema = self.set_field_value(schema, attr, all_coll) return schema def field_is_computed( self, schema: s_schema.Schema, field_name: str, ) -> bool: return field_name in self.get_computed_fields(schema) def field_is_inherited( self, 
schema: s_schema.Schema, field_name: str, ) -> bool: return False def del_classref( self, schema: s_schema.Schema, collection: str, key: str, ) -> s_schema.Schema: refdict = type(self).get_refdict(collection) attr = refdict.attr coll = self.get_field_value(schema, attr) if coll and coll.has(schema, key): schema, coll = coll.delete(schema, [key]) schema = self.set_field_value(schema, attr, coll) return schema def as_shell( self: Self, schema: s_schema.Schema, ) -> ObjectShell[Self]: return ObjectShell( name=self.get_name(schema), displayname=self.get_displayname(schema), schemaclass=type(self), ) def get_ddl_identity( self, schema: s_schema.Schema, ) -> Optional[dict[str, Any]]: ddl_id_fields = [ fn for fn, f in type(self).get_fields().items() if f.ddl_identity ] ddl_identity: Optional[dict[str, Any]] if ddl_id_fields: ddl_identity = {} for fn in ddl_id_fields: v = self.get_field_value(schema, fn) if v is not None: ddl_identity[fn] = v else: ddl_identity = None return ddl_identity def init_delta_command[ ObjectCommand_T: sd.ObjectCommand[Object] ]( self, schema: s_schema.Schema, cmdtype: type[ObjectCommand_T], *, classname: Optional[sn.Name] = None, **kwargs: Any, ) -> ObjectCommand_T: from . import delta as sd cls = type(self) cmd = sd.get_object_delta_command( objtype=cls, cmdtype=cmdtype, schema=schema, name=classname or self.get_name(schema), ddl_identity=self.get_ddl_identity(schema), **kwargs, ) cmd.scls = self self.record_cmd_object_aux_data(schema, cmd) return cmd def record_cmd_object_aux_data( self: Self, schema: s_schema.Schema, cmd: sd.ObjectCommand[Any], ) -> None: for field in type(self).get_aux_cmd_data_fields(): cmd.set_object_aux_data( field.name, self.get_field_value(schema, field.name), ) def init_parent_delta_branch( self: Self, schema: s_schema.Schema, context: sd.CommandContext, *, referrer: Optional[Object] = None, ) -> tuple[sd.CommandGroup, sd.Command, sd.ContextStack]: """Prepare a parent portion of a command tree for this object. 
This returns a tuple containing: - the root (as a ``CommandGroup``) of a nested ``AlterObject`` tree with nodes for each enclosing referrer object; - direct reference to the innermost command in the above tree (may be root if there are no referring objects); - a ``ContextStack`` instance representing the nested CommandContext corresponding to the returned command tree. """ from . import delta as sd root = sd.CommandGroup() return root, root, sd.ContextStack(()) def init_delta_branch[ObjectCommand_T: sd.ObjectCommand[Object]]( self, schema: s_schema.Schema, context: sd.CommandContext, cmdtype: type[ObjectCommand_T], *, classname: Optional[sn.Name] = None, referrer: Optional[Object] = None, possible_parent: Optional[sd.ObjectCommand[Object]] = None, **kwargs: Any, ) -> tuple[sd.Command, ObjectCommand_T, sd.ContextStack]: """Make a command subtree for this object. This returns a tuple containing: - the root (as a ``CommandGroup``) of a nested ``AlterObject`` tree with nodes for each enclosing referrer object and an instance of *cmdtype* as the innermost command; - direct reference to the innermost command in the above tree; - a ``ContextStack`` instance representing the nested CommandContext corresponding to the returned command tree. """ root_cmd: sd.Command root_cmd, parent_cmd, ctx_stack = self.init_parent_delta_branch( schema=schema, context=context, referrer=referrer, ) self_cmd = self.init_delta_command( schema, cmdtype=cmdtype, classname=classname, **kwargs, ) from . import delta as sd # possible_parent allows the caller to tell us what *they* are, # so we can reuse that Alter if we can. The big advantage here is # that it saves needing to do validate_object on the intermediate # objects. 
if ( isinstance(possible_parent, sd.AlterObject) and isinstance(parent_cmd, sd.ObjectCommand) and possible_parent.classname == parent_cmd.classname ): root_cmd = parent_cmd = self_cmd else: parent_cmd.add(self_cmd) ctx_stack.push(self_cmd.new_context(schema, context, self)) return root_cmd, self_cmd, ctx_stack def as_create_delta( self: Self, schema: s_schema.Schema, context: ComparisonContext, ) -> sd.CreateObject[Self]: from . import delta as sd cls = type(self) delta = self.init_delta_command( schema, sd.CreateObject, canonical=True, ) if context.generate_prompts: delta.set_annotation('orig_cmdclass', type(delta)) ff = cls.get_fields(sorted=True).items() fields = {fn: f for fn, f in ff if f.simpledelta and not f.ephemeral} for fn, f in fields.items(): value = self.get_explicit_field_value(schema, fn, None) if ( value is None and context.descriptive_mode and ( f.describe_visibility & DescribeVisibilityFlags.SHOW_IF_DERIVED ) ): value = self.get_field_value(schema, fn) value_from_default = True else: value_from_default = False if f.aux_cmd_data: delta.set_object_aux_data(fn, value) if value is not None: v: Any if issubclass(f.type, ObjectContainer): v = value.as_shell(schema) else: v = value self.record_field_create_delta( schema, delta, context=context, fname=fn, value=v, from_default=value_from_default, ) for refdict in cls.get_refdicts(): refcoll: ObjectCollection[Object] = ( self.get_field_value(schema, refdict.attr)) sorted_refcoll = sorted( refcoll.objects(schema), key=lambda o: o.get_name(schema), ) for ref in sorted_refcoll: delta.add(ref.as_create_delta(schema, context)) return delta def as_alter_delta( self: Self, other: Self, *, self_schema: s_schema.Schema, other_schema: s_schema.Schema, confidence: float, context: ComparisonContext, ) -> sd.ObjectCommand[Self]: from . 
import delta as sd cls = type(self) delta = self.init_delta_command( self_schema, sd.AlterObject, canonical=True, ) delta.set_annotation('confidence', confidence) if context.generate_prompts: other_name = other.get_name(other_schema) if self.get_name(self_schema) != other_name: delta.set_annotation('new_name', other_name) delta.set_annotation('orig_cmdclass', type(delta)) ff = cls.get_fields(sorted=True).items() fields = {fn: f for fn, f in ff if f.simpledelta and not f.ephemeral} for fn, f in fields.items(): oldattr_v = self.get_explicit_field_value(self_schema, fn, None) newattr_v = other.get_explicit_field_value(other_schema, fn, None) if f.aux_cmd_data: delta.set_object_aux_data(fn, newattr_v) old_v: Any new_v: Any if issubclass(f.type, ObjectContainer): if oldattr_v is not None: old_v = oldattr_v.as_shell(self_schema) else: old_v = None if newattr_v is not None: new_v = newattr_v.as_shell(other_schema) else: new_v = None else: old_v = oldattr_v new_v = newattr_v if f.compcoef is not None: fcoef = cls.compare_obj_field_value( f, self, other, our_schema=self_schema, their_schema=other_schema, context=context, explicit=True, ) if fcoef != 1.0: other.record_field_alter_delta( other_schema, delta, context, fname=fn, value=new_v, orig_value=old_v, orig_schema=self_schema, orig_object=self, confidence=confidence, ) for refdict in cls.get_refdicts(): oldcoll: ObjectCollection[Object] = ( self.get_field_value(self_schema, refdict.attr)) oldcoll_idx = sorted( oldcoll.objects(self_schema), key=lambda o: o.get_name(self_schema) ) newcoll: ObjectCollection[Object] = ( other.get_field_value(other_schema, refdict.attr)) newcoll_idx = sorted( newcoll.objects(other_schema), key=lambda o: o.get_name(other_schema), ) context.parent_ops.append(delta) delta.add( sd.delta_objects( oldcoll_idx, newcoll_idx, sclass=refdict.ref_cls, parent_confidence=confidence, context=context, old_schema=self_schema, new_schema=other_schema, ), ) context.parent_ops.pop() return delta def 
as_delete_delta( self: Self, *, schema: s_schema.Schema, context: ComparisonContext, ) -> sd.ObjectCommand[Self]: from . import delta as sd cls = type(self) delta = self.init_delta_command( schema, sd.DeleteObject, canonical=True, ) if context.generate_prompts: delta.set_annotation('orig_cmdclass', type(delta)) context.deletions[type(self), delta.classname] = delta ff = cls.get_fields(sorted=True).items() fields = {fn: f for fn, f in ff if f.simpledelta and not f.ephemeral} for fn, f in fields.items(): value = self.get_explicit_field_value(schema, fn, None) if f.aux_cmd_data: delta.set_object_aux_data(fn, value) if value is not None: if issubclass(f.type, ObjectContainer): v = value.as_shell(schema) else: v = value self.record_field_delete_delta( schema, delta, context, fn, orig_value=v, ) for refdict in cls.get_refdicts(): refcoll = self.get_field_value(schema, refdict.attr) for ref in refcoll.objects(schema): delta.add(ref.as_delete_delta(schema=schema, context=context)) return delta def record_simple_field_delta( self: Self, schema: s_schema.Schema, delta: sd.ObjectCommand[Self], context: ComparisonContext, *, fname: str, value: Any, orig_value: Any, orig_schema: Optional[s_schema.Schema], orig_object: Optional[Self], from_default: bool = False, ) -> None: computed_fields = self.get_computed_fields(schema) is_computed = fname in computed_fields if orig_schema is not None and orig_object is not None: orig_computed_fields = ( orig_object.get_computed_fields(orig_schema)) orig_is_computed = fname in orig_computed_fields else: orig_is_computed = is_computed cmd = delta.set_attribute_value( fname, value, orig_value=orig_value, computed=is_computed, orig_computed=orig_is_computed, from_default=from_default, ) context.parent_ops.append(delta) cmd.record_diff_annotations( schema=schema, orig_schema=orig_schema, context=context, object=self, orig_object=orig_object, ) context.parent_ops.pop() def record_field_create_delta( self: Self, schema: s_schema.Schema, delta: 
sd.ObjectCommand[Self], context: ComparisonContext, *, fname: str, value: Any, from_default: bool, ) -> None: self.record_simple_field_delta( schema, delta, context, fname=fname, value=value, orig_value=None, orig_schema=None, orig_object=None, from_default=from_default, ) def record_field_alter_delta( self: Self, schema: s_schema.Schema, delta: sd.ObjectCommand[Self], context: ComparisonContext, *, fname: str, value: Any, orig_value: Any, orig_schema: s_schema.Schema, orig_object: Self, confidence: float, ) -> None: from . import delta as sd if fname == 'name': rename_op = orig_object.init_delta_command( orig_schema, sd.RenameObject, new_name=value, ) rename_op.set_annotation('confidence', confidence) self.record_simple_field_delta( schema, rename_op, context, fname=fname, value=value, orig_value=orig_value, orig_schema=orig_schema, orig_object=orig_object, ) delta.add(rename_op) context.record_rename(rename_op) else: self.record_simple_field_delta( schema, delta, context, fname=fname, value=value, orig_value=orig_value, orig_schema=orig_schema, orig_object=orig_object, ) def record_field_delete_delta( self: Self, schema: s_schema.Schema, delta: sd.ObjectCommand[Self], context: ComparisonContext, fname: str, orig_value: Any, ) -> None: self.record_simple_field_delta( schema, delta, context, fname=fname, value=None, orig_value=orig_value, orig_schema=None, orig_object=None, ) def dump(self, schema: s_schema.Schema) -> str: return ( f'<{type(self).__name__} name={self.get_name(schema)!r} ' f'at {id(self):#x}>' ) def __repr__(self) -> str: return f'<{type(self).__name__} {self.id} at {id(self):#x}>' class InternalObject(Object): """A schema object that is used by the system internally. Instances of InternalObject should not appear in schema dumps. """ @classmethod def is_abstract(cls) -> bool: """Return True if this type does NOT represent a concrete schema class.
""" return cls is InternalObject class QualifiedObject(Object): name = SchemaField( # ignore below because Mypy doesn't understand fields which are not # inheritable. sn.QualName, # type: ignore inheritable=False, compcoef=0.670, ) @classmethod def get_shortname_static(cls, name: sn.Name) -> sn.QualName: result = sn.shortname_from_fullname(name) if not isinstance(result, sn.QualName): assert isinstance(name, sn.QualName) result = sn.QualName(module=name.module, name=result.name) return result def get_shortname(self, schema: s_schema.Schema) -> sn.QualName: return type(self).get_shortname_static(self.get_name(schema)) QualifiedObject_T = TypeVar('QualifiedObject_T', bound='QualifiedObject') class ObjectFragment(QualifiedObject): """A part of another object that cannot exist independently.""" class GlobalObject(Object): is_global_object = True GlobalObject_T = TypeVar('GlobalObject_T', bound='GlobalObject') class ExternalObject(GlobalObject): """An object that is not tracked in a schema, but some external state.""" pass ExternalObject_T = TypeVar('ExternalObject_T', bound='ExternalObject') class DerivableObject(QualifiedObject): def derive_name( self, schema: s_schema.Schema, source: QualifiedObject, *qualifiers: str, derived_name_base: Optional[sn.Name] = None, module: Optional[str] = None, ) -> sn.QualName: source_name = source.get_name(schema) if module is None: module = source_name.module qualifiers = (str(source_name),) + qualifiers return derive_name( schema, *qualifiers, module=module, parent=self, derived_name_base=derived_name_base, ) def is_non_concrete(self, schema: s_schema.Schema) -> bool: return self.get_shortname(schema) == self.get_name(schema) def get_derived_name_base(self, schema: s_schema.Schema) -> sn.Name: return self.get_shortname(schema) def get_derived_name( self, schema: s_schema.Schema, source: QualifiedObject, *qualifiers: str, mark_derived: bool = False, derived_name_base: Optional[sn.Name] = None, module: Optional[str] = None, ) -> 
sn.QualName: return self.derive_name( schema, source, *qualifiers, derived_name_base=derived_name_base, module=module) class Shell: """ Shells mimic objects, but are not part of the schema. They are constructed from AST and are used to hold object data before it is committed to the schema. """ def resolve(self, schema: s_schema.Schema) -> Any: raise NotImplementedError class ObjectShell(Shell, Generic[Object_T_co]): # noqa: UP046 def __init__( self, *, name: sn.Name, schemaclass: type[Object_T_co], displayname: Optional[str] = None, origname: Optional[sn.Name] = None, span: Optional[parsing.Span] = None, ) -> None: self.name = name self.origname = origname self.displayname = displayname self.schemaclass = schemaclass self.span = span def get_id(self, schema: s_schema.Schema) -> uuid.UUID: return self.resolve(schema).get_id(schema) def resolve(self, schema: s_schema.Schema) -> Object_T_co: if self.name is None: raise TypeError( 'cannot resolve anonymous ObjectShell' ) if isinstance(self.name, sn.QualName): return schema.get( self.name, type=self.schemaclass, span=self.span, ) else: return schema.get_global(self.schemaclass, self.name) def get_refname(self, schema: s_schema.Schema) -> sn.Name: if self.origname is not None: return self.origname else: # XXX: change get_displayname to return Name return sn.name_from_string(self.get_displayname(schema)) def get_name(self, schema: s_schema.Schema) -> sn.Name: # this function is needed for polymorphism of Object and ObjectShell return self.name def get_displayname(self, schema: s_schema.Schema) -> str: return self.displayname or str(self.name) def get_schema_class_displayname(self) -> str: return self.schemaclass.get_schema_class_displayname() def __repr__(self) -> str: if self.schemaclass is not None: dn = self.schemaclass.__name__ else: dn = 'Object' n = self.name or '' return f'<{type(self).__name__} {dn}({n!r}) at 0x{id(self):x}>' class ObjectCollectionDuplicateNameError(Exception): pass # A set of scalars that should be
reflected as a multi prop, not as an # array. class MultiPropSet[T]( checked.FrozenCheckedSet[T], ): pass class ObjectCollection[Object_T: "Object"]( ObjectContainer, parametric.SingleParametricType[Object_T], ): __slots__ = ('_ids',) is_object_collection = True # Even though Object_T would be a correct annotation below, # we want the type to default to base `Object` for cases # when a TypeVar is passed as Object_T. This is a hack, # of course, because, ideally we'd want to at least default # to the bounds or constraints of the TypeVar, or, even better, # pass the actual type at the call site, but there seems to be # no easy solution to do that. type: ClassVar[type[Object]] = Object # type: ignore _registry: ClassVar[dict[str, builtins.type[ObjectCollection[Object]]]] = {} _container: ClassVar[builtins.type[CollectionFactory[Any]]] def __init_subclass__( cls, *, container: Optional[builtins.type[CollectionFactory[Any]]] = None, ) -> None: super().__init_subclass__() if container is not None: cls._container = container if not cls.is_anon_parametrized(): name = cls.__name__ if name in cls._registry: raise TypeError( f'duplicate name for schema collection class: {name}, ' f'already defined as {cls._registry[name]!r}' ) else: cls._registry[name] = cls @classmethod def get_subclass(cls, name: str) -> builtins.type[ObjectCollection[Object]]: return cls._registry[name] def __init__( self, _ids: Collection[uuid.UUID], *, _private_init: bool, ) -> None: if not self.is_fully_resolved(): raise TypeError( f"{type(self)!r} unresolved type parameters" ) self._ids = _ids def __len__(self) -> int: return len(self._ids) def __eq__(self, other: Any) -> bool: if not isinstance(other, type(self)): return NotImplemented return self._ids == other._ids def __hash__(self) -> int: return hash(self._ids) def schema_reduce( self, ) -> tuple[ str, Optional[tuple[builtins.type, ...]
| builtins.type], tuple[uuid.UUID, ...], tuple[tuple[str, Any], ...], ]: cls = type(self) _, (typeargs, ids, attrs) = self.__reduce__() if cls.is_anon_parametrized(): clsname = cls.__bases__[0].__name__ else: clsname = cls.__name__ return (clsname, typeargs, ids, tuple(attrs.items())) @staticmethod @lru.per_job_lru_cache(maxsize=10240) def schema_restore( data: tuple[ str, Optional[tuple[builtins.type, ...] | builtins.type], tuple[uuid.UUID, ...], tuple[tuple[str, Any], ...], ], ) -> ObjectCollection[Object]: clsname, typeargs, ids, attrs = data scoll_class = ObjectCollection.get_subclass(clsname) return scoll_class.__restore__(typeargs, ids, dict(attrs)) @classmethod def schema_refs_from_data( cls, data: tuple[ str, Optional[tuple[builtins.type, ...] | builtins.type], tuple[uuid.UUID, ...], tuple[tuple[str, Any], ...], ], ) -> frozenset[uuid.UUID]: return frozenset(data[2]) def __reduce__(self) -> tuple[ Callable[..., ObjectCollection[Any]], tuple[ Optional[tuple[builtins.type, ...] | builtins.type], tuple[uuid.UUID, ...], dict[str, Any], ], ]: assert type(self).is_fully_resolved(), \ f'{type(self)} parameters are not resolved' cls: type[ObjectCollection[Object_T]] = self.__class__ types: Optional[tuple[type, ...]] = self.orig_args if types is None or not cls.is_anon_parametrized(): typeargs = None else: typeargs = types[0] if len(types) == 1 else types attrs = {k: getattr(self, k) for k in self.__slots__ if k != '_ids'} return ( cls.__restore__, (typeargs, tuple(self._ids), attrs) ) @classmethod def __restore__( cls, typeargs: Optional[tuple[builtins.type, ...] 
| builtins.type], ids: tuple[uuid.UUID, ...], attrs: dict[str, Any], ) -> ObjectCollection[Object_T]: if typeargs is None or cls.is_anon_parametrized(): obj = cls(_ids=ids, **attrs, _private_init=True) else: obj = cls[typeargs]( # type: ignore _ids=ids, **attrs, _private_init=True) return obj def dump(self, schema: s_schema.Schema) -> str: return ( f'<{type(self).__name__} objects=' f'[{", ".join(o.dump(schema) for o in self.objects(schema))}] ' f'at {id(self):#x}>' ) @classmethod def create( cls: builtins.type[ObjectCollection[Object_T]], schema: s_schema.Schema, data: Collection[Object_T] | ObjectCollection[Object_T], **kwargs: Any, ) -> ObjectCollection[Object_T]: ids: list[uuid.UUID] = [] if isinstance(data, ObjectCollection): ids.extend(data._ids) elif data: for v in data: ids.append(cls._validate_value(schema, v)) container: Collection[uuid.UUID] = cls._container(ids) return cls(container, **kwargs, _private_init=True) @classmethod def create_empty(cls) -> ObjectCollection[Object_T]: return cls(cls._container(), _private_init=True) @classmethod def _validate_value(cls, schema: s_schema.Schema, v: Object) -> uuid.UUID: if not isinstance(v, cls.type): raise TypeError( f'invalid input data for ObjectIndexByShortname: ' f'expected {cls.type} values, got {type(v)}') if v.id is not None: return v.id else: raise TypeError(f'object {v!r} has no ID!') def ids(self) -> tuple[uuid.UUID, ...]: return tuple(self._ids) def names(self, schema: s_schema.Schema) -> Collection[sn.Name]: result = [] for item_id in self._ids: obj = schema.get_by_id(item_id) result.append(obj.get_name(schema)) return type(self)._container(result) def objects(self, schema: s_schema.Schema) -> tuple[Object_T, ...]: # Calling tuple on a list produced by a comprehension instead # of on a generator comprehension is tragically a slight # performance improvement, and this is a hot path. 
return tuple([ schema.get_by_id(iid) for iid in self._ids # type: ignore ]) def _object_keys( self, schema: s_schema.Schema ) -> set[tuple[builtins.type[Object], sn.Name]]: return {(type(x), x.get_name(schema)) for x in (self.objects(schema))} @classmethod def compare_values( cls, ours: ObjectCollection[Object_T], theirs: ObjectCollection[Object_T], *, our_schema: s_schema.Schema, their_schema: s_schema.Schema, context: ComparisonContext, compcoef: float, ) -> float: # If the new and old versions share a reference to an object # that is being deleted, then we must delete this object as well. our_keys = set(ours._object_keys(our_schema) if ours else {}) their_keys = set(theirs._object_keys(their_schema) if theirs else {}) if (our_keys & their_keys) & context.deletions.keys(): return 0.0 if ours is not None: our_names = cls._container( context.get_obj_name(our_schema, obj) for obj in ours.objects(our_schema) ) else: our_names = cls._container() if theirs is not None: their_names = theirs.names(their_schema) else: their_names = cls._container() if our_names != their_names: return compcoef else: return 1.0 def as_shell( self, schema: s_schema.Schema, ) -> ObjectCollectionShell[Object_T]: return ObjectCollectionShell[Object_T]( items=[o.as_shell(schema) for o in self.objects(schema)], collection_type=type(self), ) class ObjectCollectionShell[Object_T: "Object"](Shell): def __init__( self, items: Iterable[ObjectShell[Object_T]], collection_type: type[ObjectCollection[Object_T]], ) -> None: self.items = items self.collection_type = collection_type def __iter__(self) -> Iterator[ObjectShell[Object_T]]: return iter(self.items) def __bool__(self) -> bool: return bool(self.items) def resolve(self, schema: s_schema.Schema) -> ObjectCollection[Object_T]: return self.collection_type.create( schema, [s.resolve(schema) for s in self.items], ) def __repr__(self) -> str: tn = self.__class__.__name__ cn = self.collection_type.__name__ items = ', '.join(str(e.name) or '' for e in 
self.items) return f'<{tn} {cn}({items}) at 0x{id(self):x}>' class ObjectIndexBase[Key_T, Object_T: Object]( ObjectCollection[Object_T], container=tuple, ): # The keys here are derived, but caching them is a big performance # win (-14% when it was implemented). # # N.B: Because the keys are cached, if the value of the key is # changed, a new collection must be created! # Currently, ObjectIndexBase is used only for refdicts, and so # this update is done in RenameReferencedInheritingObject._alter_begin(). __slots__ = ('_ids', '_keys') _key: Callable[["s_schema.Schema", Object_T], Key_T] def __init_subclass__( cls, *, key: Optional[Callable[["s_schema.Schema", Object_T], Key_T]] = None, ) -> None: super().__init_subclass__() if key is not None: cls._key = key @classmethod def get_key_for(cls, schema: s_schema.Schema, obj: Object) -> Key_T: return cls._key(schema, obj) # type: ignore # mypy bug? @classmethod def get_key_for_name( cls, schema: s_schema.Schema, name: sn.Name, ) -> Key_T: raise NotImplementedError @classmethod def create( cls: type[ObjectIndexBase[Key_T, Object_T]], schema: s_schema.Schema, data: Collection[Object_T] | ObjectCollection[Object_T], **kwargs: Any, ) -> ObjectIndexBase[Key_T, Object_T]: if isinstance(data, ObjectIndexBase): keys = data._keys elif isinstance(data, ObjectCollection): _k = cls._key keys = tuple([_k(schema, x) for x in data.objects(schema)]) else: _k = cls._key keys = tuple([_k(schema, x) for x in data]) coll = cast( ObjectIndexBase[Key_T, Object_T], super().create(schema, data, _keys=keys, **kwargs) ) coll._check_duplicates(schema) return coll def __init__( self, _ids: Collection[uuid.UUID], _keys: Optional[tuple[Key_T, ...]] = None, *, _private_init: bool, ) -> None: super().__init__(_ids, _private_init=_private_init) self._keys = _keys def _check_duplicates(self, schema: s_schema.Schema) -> None: uniq = set(self.keys(schema)) if len(uniq) != len(self._ids): counts = collections.Counter(self.keys(schema)) duplicates = [v for v, 
count in counts.items() if count > 1] raise ObjectCollectionDuplicateNameError( 'object index contains duplicate key(s): ' + ', '.join(repr(d) for d in duplicates)) @classmethod def compare_values( cls, ours: ObjectCollection[Object_T], theirs: ObjectCollection[Object_T], *, our_schema: s_schema.Schema, their_schema: s_schema.Schema, context: ComparisonContext, compcoef: float, ) -> float: if not ours and not theirs: basecoef = 1.0 elif not ours or not theirs: basecoef = 0.2 else: assert isinstance(ours, ObjectIndexBase) assert isinstance(theirs, ObjectIndexBase) similarity: list[float] = [] for k, v in ours.items(our_schema): try: theirsv = theirs.get(their_schema, k) except KeyError: # key only in ours similarity.append(0.2) else: similarity.append( v.compare(theirsv, our_schema=our_schema, their_schema=their_schema, context=context)) diff = ( set(theirs.keys(their_schema)) - set(ours.keys(our_schema)) ) similarity.extend(0.2 for k in diff) basecoef = sum(similarity) / len(similarity) return basecoef + (1 - basecoef) * compcoef def add( self, schema: s_schema.Schema, item: Object_T ) -> tuple[s_schema.Schema, Self]: """Return a copy of this collection containing the given item. If the item is already present in the collection, an ``ObjectCollectionDuplicateNameError`` is raised. 
""" key = type(self)._key(schema, item) if self.has(schema, key): raise ObjectCollectionDuplicateNameError( f'object index already contains the {key!r} key') return self.update(schema, [item]) def update( self, schema: s_schema.Schema, reps: Iterable[Object_T] ) -> tuple[s_schema.Schema, Self]: items = dict(self.items(schema)) keyfunc = type(self)._key for obj in reps: items[keyfunc(schema, obj)] = obj return ( schema, cast(Self, type(self).create(schema, items.values())), ) def delete( self, schema: s_schema.Schema, names: Iterable[Key_T], ) -> tuple[s_schema.Schema, Self]: items = dict(self.items(schema)) for name in names: items.pop(name) return ( schema, cast(Self, type(self).create(schema, items.values())), ) def items( self, schema: s_schema.Schema, ) -> tuple[tuple[Key_T, Object_T], ...]: result = [] for key, item_id in zip(self.keys(schema), self._ids): obj = schema.get_by_id(item_id) result.append((key, obj)) return tuple(result) # type: ignore def keys(self, schema: s_schema.Schema) -> tuple[Key_T, ...]: # To support existing pickled schemas that don't have _keys in them, # lazily compute them if they are missing. # FUTURE: Can drop in 5.0. if self._keys is None: _k = type(self)._key self._keys = tuple([_k(schema, x) for x in self.objects(schema)]) return self._keys def has(self, schema: s_schema.Schema, name: Key_T) -> bool: return name in self.keys(schema) def get( self, schema: s_schema.Schema, name: Key_T, default: Optional[Object_T | NoDefaultT] = NoDefault, ) -> Optional[Object_T]: # TODO: Should we store an actual dict? 
for key, item_id in zip(self.keys(schema), self._ids): if name == key: return schema.get_by_id(item_id) # type: ignore if default is NoDefault: raise KeyError(name) else: return default def _fullname_object_key(schema: s_schema.Schema, o: Object) -> sn.Name: return o.get_name(schema) class ObjectIndexByFullname( ObjectIndexBase[sn.Name, Object_T], key=_fullname_object_key, ): @classmethod def get_key_for_name( cls, schema: s_schema.Schema, name: sn.Name, ) -> sn.Name: return name def _shortname_object_key(schema: s_schema.Schema, o: Object) -> sn.Name: return o.get_shortname(schema) class ObjectIndexByShortname( ObjectIndexBase[sn.Name, Object_T], key=_shortname_object_key, ): @classmethod def get_key_for_name( cls, schema: s_schema.Schema, name: sn.Name, ) -> sn.Name: return sn.shortname_from_fullname(name) def _unqualified_object_key( schema: s_schema.Schema, o: QualifiedObject, ) -> sn.UnqualName: return sn.UnqualName(o.get_shortname(schema).name) class ObjectIndexByUnqualifiedName( ObjectIndexBase[sn.UnqualName, QualifiedObject_T], key=_unqualified_object_key, ): @classmethod def get_key_for_name( cls, schema: s_schema.Schema, name: sn.Name, ) -> sn.UnqualName: return sn.UnqualName(sn.shortname_from_fullname(name).name) class ObjectDict[Key_T, Object_T: Object]( ObjectCollection[Object_T], container=tuple, ): __slots__ = ('_ids', '_keys') # Breaking the Liskov Substitution Principle @classmethod def create( # type: ignore cls, schema: s_schema.Schema, data: Mapping[Key_T, Object_T] | ObjectDict[Key_T, Object_T], **kwargs: Any, ) -> ObjectDict[Key_T, Object_T]: if isinstance(data, ObjectDict): return super().create( schema, data, _keys=data._keys, ) # type: ignore else: return super().create( schema, data.values(), _keys=tuple(data.keys()), ) # type: ignore @classmethod def create_empty(cls) -> ObjectDict[Key_T, Object_T]: return cls(cls._container(), _private_init=True, _keys=tuple()) @classmethod def compare_values( cls, ours: ObjectCollection[Object_T], 
theirs: ObjectCollection[Object_T], *, our_schema: s_schema.Schema, their_schema: s_schema.Schema, context: ComparisonContext, compcoef: float, ) -> float: assert isinstance(ours, ObjectDict) assert isinstance(theirs, ObjectDict) if ours.keys(our_schema) != theirs.keys(their_schema): return compcoef return super().compare_values( ours, theirs, our_schema=our_schema, their_schema=their_schema, context=context, compcoef=compcoef) def __init__( self, _ids: Collection[uuid.UUID], _keys: tuple[Key_T, ...], *, _private_init: bool, ) -> None: super().__init__(_ids, _private_init=_private_init) self._keys = _keys def __eq__(self, other: Any) -> bool: if not isinstance(other, type(self)): return NotImplemented return self._ids == other._ids and self._keys == other._keys def __hash__(self) -> int: return hash((self._ids, self._keys)) def dump(self, schema: s_schema.Schema) -> str: objs = ", ".join(f"{self._keys[i]}: {o.dump(schema)}" for i, o in enumerate(self.objects(schema))) return f'<{type(self).__name__} objects={objs} at {id(self):#x}>' def __repr__(self) -> str: items = [f"{self._keys[i]}: {id}" for i, id in enumerate(self._ids)] return f'{{{", ".join(items)}}}' def keys(self, schema: s_schema.Schema) -> tuple[Key_T, ...]: return self._keys def values(self, schema: s_schema.Schema) -> tuple[Object_T, ...]: return self.objects(schema) def items( self, schema: s_schema.Schema, ) -> tuple[tuple[Key_T, Object_T], ...]: return tuple(zip(self._keys, self.objects(schema))) def as_shell( self, schema: s_schema.Schema, ) -> ObjectDictShell[Key_T, Object_T]: return ObjectDictShell( items={k: o.as_shell(schema) for k, o in self.items(schema)}, collection_type=type(self), ) class ObjectDictShell[Key_T, Object_T: "Object"]( ObjectCollectionShell[Object_T], ): items: Mapping[Any, ObjectShell[Object_T]] collection_type: type[ObjectDict[Key_T, Object_T]] def __init__( self, items: Mapping[Any, ObjectShell[Object_T]], collection_type: type[ObjectDict[Key_T, Object_T]], ) -> None: 
self.items = items self.collection_type = collection_type def __repr__(self) -> str: tn = self.__class__.__name__ cn = self.collection_type.__name__ items = ', '.join(f'{k}: {v.name}' for k, v in self.items.items()) return f'<{tn} {cn}({items}) at 0x{id(self):x}>' def resolve(self, schema: s_schema.Schema) -> ObjectDict[Key_T, Object_T]: return self.collection_type.create( schema, {k: s.resolve(schema) for k, s in self.items.items()}, ) class ObjectSet[Object_T: Object]( ObjectCollection[Object_T], container=frozenset, ): def __repr__(self) -> str: return f'{{{", ".join(str(id) for id in self._ids)}}}' @classmethod def merge_values( cls: type[ObjectSet[Object_T]], target: Object, sources: Iterable[Object], field_name: str, *, ignore_local: bool = False, schema: s_schema.Schema, ) -> ObjectSet[Object_T]: if not ignore_local: result = target.get_explicit_field_value(schema, field_name, None) else: result = None for source in sources: if source.__class__.get_field(field_name) is None: continue theirs = source.get_explicit_field_value(schema, field_name, None) if theirs: if result is None: result = theirs else: result._ids |= theirs._ids return result # type: ignore class ObjectList[Object_T: Object]( ObjectCollection[Object_T], container=tuple, ): def __repr__(self) -> str: return f'ObjectList([{", ".join(str(id) for id in self._ids)}])' def first(self, schema: s_schema.Schema, default: Any = NoDefault) -> Any: # The `Any` return type is so that using methods on Object subclasses # doesn't cause Mypy to complain. try: return next(iter(self.objects(schema))) except StopIteration: pass if default is NoDefault: raise IndexError('ObjectList is empty') else: return default # Unfortunately, mypy does not support self generics over types with # typevars, so we have to resort to method redefinition. 
@classmethod def create( cls, schema: s_schema.Schema, data: Iterable[Object_T] | ObjectCollection[Object_T], **kwargs: Any, ) -> ObjectList[Object_T]: return super().create(schema, data, **kwargs) # type: ignore class SubclassableObject(Object): abstract = SchemaField( bool, default=False, inheritable=False, special_ddl_syntax=True, compcoef=0.909, ) def _issubclass( self, schema: s_schema.Schema, parent: SubclassableObject ) -> bool: return parent == self def issubclass( self, schema: s_schema.Schema, parent: SubclassableObject | tuple[SubclassableObject, ...], ) -> bool: from . import types as s_types if isinstance(parent, tuple): return any(self.issubclass(schema, p) for p in parent) if ( isinstance(parent, s_types.Type) and parent.is_anyobject(schema) and isinstance(self, s_types.Type) and self.is_object_type() ): return True if isinstance(parent, s_types.Type) and parent.is_any(schema): return True return self._issubclass(schema, parent) InheritingObjectT = TypeVar('InheritingObjectT', bound='InheritingObject') class InheritingObject(SubclassableObject): bases = SchemaField( ObjectList['InheritingObject'], type_is_generic_self=True, default=DEFAULT_CONSTRUCTOR, coerce=True, inheritable=False, compcoef=0.900, ) ancestors = SchemaField( ObjectList['InheritingObject'], type_is_generic_self=True, default=DEFAULT_CONSTRUCTOR, coerce=True, inheritable=False, compcoef=0.999, ) # Fields that have been inherited as opposed to set explicitly. 
inherited_fields = SchemaField( checked.FrozenCheckedSet[str], default=DEFAULT_CONSTRUCTOR, coerce=True, inheritable=False, compcoef=0.999, ) is_derived = SchemaField( bool, default=False, compcoef=0.909) def inheritable_fields(self) -> Iterable[str]: for fn, f in self.__class__.get_fields().items(): if f.inheritable and not f.ephemeral: yield fn @classmethod def get_default_base_name(cls) -> Optional[sn.Name]: return None # Redefining bases and ancestors accessors to make them generic def get_bases( self: InheritingObjectT, schema: s_schema.Schema, ) -> ObjectList[InheritingObjectT]: return self.get_field_value(schema, 'bases') # type: ignore def get_ancestors( self: InheritingObjectT, schema: s_schema.Schema, ) -> ObjectList[InheritingObjectT]: return self.get_field_value(schema, 'ancestors') # type: ignore def get_base_names(self, schema: s_schema.Schema) -> Collection[sn.Name]: return self.get_bases(schema).names(schema) def maybe_get_topmost_concrete_base( self: InheritingObjectT, schema: s_schema.Schema ) -> Optional[InheritingObjectT]: """Get the topmost non-abstract base, or None if there is none.""" lineage = self.get_ancestors(schema).objects(schema) for ancestor in reversed(lineage): if not ancestor.get_abstract(schema): return ancestor if not self.get_abstract(schema): return self return None def get_topmost_concrete_base( self: InheritingObjectT, schema: s_schema.Schema ) -> InheritingObjectT: """Get the topmost non-abstract base; raise SchemaError if there is none.""" base = self.maybe_get_topmost_concrete_base(schema) if not base: raise errors.SchemaError( f'{self.get_verbosename(schema)} has no non-abstract ancestors' ) return base def get_base_for_cast(self, schema: s_schema.Schema) -> Object: return self.get_topmost_concrete_base(schema) @classmethod def get_root_classes(cls) -> tuple[sn.QualName, ...]: return tuple() def _issubclass( self, schema: s_schema.Schema, parent: SubclassableObject, ) -> bool: if parent == self: return True lineage = self.get_ancestors(schema).objects(schema) return parent in 
lineage def descendants( self: InheritingObjectT, schema: s_schema.Schema ) -> frozenset[InheritingObjectT]: return schema.get_descendants(self) def ordered_descendants( self: InheritingObjectT, schema: s_schema.Schema ) -> list[InheritingObjectT]: """Return class descendants in ancestral order.""" graph = {} for descendant in self.descendants(schema): graph[descendant] = topological.DepGraphEntry( item=descendant, deps=ordered.OrderedSet( descendant.get_bases(schema).objects(schema), ), extra=False, ) return list(topological.sort(graph, allow_unresolved=True)) def children( self: InheritingObjectT, schema: s_schema.Schema, ) -> frozenset[InheritingObjectT]: return schema.get_children(self) def field_is_inherited( self, schema: s_schema.Schema, field_name: str, ) -> bool: inherited_fields = self.get_inherited_fields(schema) return field_name in inherited_fields def get_explicit_local_field_value( self, schema: s_schema.Schema, field_name: str, default: Any = NoDefault, ) -> Any: inherited_fields = self.get_inherited_fields(schema) if field_name not in inherited_fields: return self.get_explicit_field_value(schema, field_name, default) elif default is not NoDefault: return default else: raise FieldValueNotFoundError( f'{self!r} object has no non-inherited value for ' f'field {field_name!r}' ) def allow_ref_propagation( self, schema: s_schema.Schema, context: sd.CommandContext, refdict: RefDict, ) -> bool: return True def as_alter_delta( self: InheritingObjectT, other: InheritingObjectT, *, self_schema: s_schema.Schema, other_schema: s_schema.Schema, confidence: float, context: ComparisonContext, ) -> sd.ObjectCommand[InheritingObjectT]: from . import delta as sd from . 
import inheriting as s_inh delta = super().as_alter_delta( other, self_schema=self_schema, other_schema=other_schema, confidence=confidence, context=context, ) rebase = sd.get_object_command_class( s_inh.RebaseInheritingObject, type(self)) old_base_names = tuple( context.get_obj_name(self_schema, base) for base in self.get_bases(self_schema).objects(self_schema) ) new_base_names = other.get_bases(other_schema).names(other_schema) if old_base_names != new_base_names and rebase is not None: removed, added = s_inh.delta_bases( old_base_names, new_base_names, t=type(self), ) rebase_cmd = rebase( classname=other.get_name(other_schema), removed_bases=removed, added_bases=added, ) rebase_cmd.set_attribute_value( 'bases', other.get_bases(other_schema).as_shell(other_schema), ) rebase_cmd.set_attribute_value( 'ancestors', other.get_ancestors(other_schema).as_shell(other_schema), ) # Trim these from the base alter since they are redundant # and clog up debug output. delta.discard(not_none(delta._get_attribute_set_cmd('bases'))) # ancestors might not be in the delta, if it didn't change if anc := delta._get_attribute_set_cmd('ancestors'): delta.discard(anc) delta.add(rebase_cmd) return delta def record_simple_field_delta( self: InheritingObjectT, schema: s_schema.Schema, delta: sd.ObjectCommand[InheritingObjectT], context: ComparisonContext, *, fname: str, value: Any, orig_value: Any, orig_schema: Optional[s_schema.Schema], orig_object: Optional[InheritingObjectT], from_default: bool = False, ) -> None: inherited_fields = self.get_inherited_fields(schema) is_inherited = fname in inherited_fields if orig_schema is not None and orig_object is not None: orig_inherited_fields = ( orig_object.get_inherited_fields(orig_schema)) orig_is_inherited = fname in orig_inherited_fields else: orig_is_inherited = is_inherited computed_fields = self.get_computed_fields(schema) is_computed = fname in computed_fields if orig_schema is not None and orig_object is not None: orig_computed_fields = 
( orig_object.get_computed_fields(orig_schema)) orig_is_computed = fname in orig_computed_fields else: orig_is_computed = is_computed cmd = delta.set_attribute_value( fname, value=value, orig_value=orig_value, inherited=is_inherited, orig_inherited=orig_is_inherited, computed=is_computed, orig_computed=orig_is_computed, from_default=from_default, ) context.parent_ops.append(delta) cmd.record_diff_annotations( schema=schema, orig_schema=orig_schema, context=context, object=self, orig_object=orig_object, ) context.parent_ops.pop() def get_field_create_delta( self: InheritingObjectT, schema: s_schema.Schema, delta: sd.ObjectCommand[InheritingObjectT], fname: str, value: Any, ) -> None: inherited_fields = self.get_inherited_fields(schema) delta.set_attribute_value( fname, value=value, inherited=fname in inherited_fields, ) def get_field_alter_delta( self: Self, old_schema: s_schema.Schema, new_schema: s_schema.Schema, delta: sd.ObjectCommand[InheritingObjectT], fname: str, value: Any, orig_value: Any, ) -> None: inherited_fields = self.get_inherited_fields(new_schema) delta.set_attribute_value( fname, value, orig_value=orig_value, inherited=fname in inherited_fields, ) def get_field_delete_delta( self: Self, schema: s_schema.Schema, delta: sd.ObjectCommand[InheritingObjectT], fname: str, orig_value: Any, ) -> None: inherited_fields = self.get_inherited_fields(schema) delta.set_attribute_value( fname, value=None, orig_value=orig_value, inherited=fname in inherited_fields, ) @classmethod def compare_obj_field_value[T]( cls: type[Self], field: Field[type[T]], ours: Self, theirs: Self, *, our_schema: s_schema.Schema, their_schema: s_schema.Schema, context: ComparisonContext, explicit: bool = False, ) -> float: similarity = super().compare_obj_field_value( field, ours, theirs, our_schema=our_schema, their_schema=their_schema, context=context, explicit=explicit, ) # Check to see if this field's inherited status has changed. # If so, this is definitely a change. 
our_ifs = ours.get_inherited_fields(our_schema) their_ifs = theirs.get_inherited_fields(their_schema) fname = field.name if (fname in our_ifs) != (fname in their_ifs): # The change in inherited status decreases the similarity. similarity *= 0.95 return similarity DerivableInheritingObjectT = TypeVar( 'DerivableInheritingObjectT', bound='DerivableInheritingObject', ) class DerivableInheritingObject(DerivableObject, InheritingObject): def get_nearest_non_derived_parent( self: DerivableInheritingObjectT, schema: s_schema.Schema, ) -> DerivableInheritingObjectT: obj = self while obj.get_is_derived(schema): obj = obj.get_bases(schema).first(schema) return obj def get_nearest_generic_parent( self: DerivableInheritingObjectT, schema: s_schema.Schema, ) -> DerivableInheritingObjectT: obj = self while not obj.is_non_concrete(schema): obj = obj.get_bases(schema).first(schema) return obj @markup.serializer.serializer.register(Object) @markup.serializer.serializer.register(ObjectCollection) def _serialize_to_markup(o: Object, *, ctx: markup.Context) -> markup.Markup: if 'schema' not in ctx.kwargs: orepr = repr(o) else: orepr = o.dump(ctx.kwargs['schema']) return markup.elements.lang.Object( id=id(o), class_module=type(o).__module__, classname=type(o).__name__, repr=orepr, ) def _merge_lineage[InheritingObjectT: 'InheritingObject']( lineage: Iterable[list[InheritingObjectT]], subject_name: str, ) -> list[InheritingObjectT]: result: list[Any] = [] while True: nonempty = [line for line in lineage if line] if not nonempty: return result for line in nonempty: candidate = line[0] tails = [m for m in nonempty if candidate in m[1:]] if not tails: break else: raise errors.SchemaError( f"could not find consistent ancestor order for {subject_name}" ) result.append(candidate) for line in nonempty: if line[0] == candidate: del line[0] def _compute_lineage[InheritingObjectT: 'InheritingObject']( schema: s_schema.Schema, obj: InheritingObjectT, subject_name: str, ) -> 
list[InheritingObjectT]: bases = tuple(obj.get_bases(schema).objects(schema)) lineage = [[obj]] for base in bases: lineage.append(_compute_lineage(schema, base, subject_name)) lineage.append(list(bases)) return _merge_lineage(lineage, subject_name) def compute_lineage[InheritingObjectT: 'InheritingObject']( schema: s_schema.Schema, bases: Iterable[InheritingObjectT], subject_name: str, ) -> list[InheritingObjectT]: lineage = [] for base in bases: lineage.append(_compute_lineage(schema, base, subject_name)) lineage.append(list(bases)) try: return _merge_lineage(lineage, subject_name) except errors.SchemaError as e: sbases = ', '.join(str(base.get_name(schema)) for base in bases) details = f'type has specified bases: {sbases}' e.set_hint_and_details(hint=e.hint, details=details) raise def compute_ancestors[InheritingObjectT: 'InheritingObject']( schema: s_schema.Schema, obj: InheritingObjectT, ) -> list[InheritingObjectT]: return compute_lineage( schema, obj.get_bases(schema).objects(schema), obj.get_verbosename(schema), ) def derive_name( schema: s_schema.Schema, *qualifiers: str, module: str, parent: Optional[DerivableObject] = None, derived_name_base: Optional[sn.Name] = None, ) -> sn.QualName: if derived_name_base is None: assert parent is not None derived_name_base = parent.get_derived_name_base(schema) name = sn.get_specialized_name(derived_name_base, *qualifiers) return sn.QualName(name=name, module=module) ================================================ FILE: edb/schema/objtypes.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Optional, Iterable, cast import collections from edb import errors from edb.edgeql import ast as qlast from edb.edgeql import qltypes from . import annos as s_anno from . import constraints from . import delta as sd from . import functions as s_func from . import inheriting from . import links from . import properties from . import name as sn from . import objects as so from . import pointers from . import policies from . import schema as s_schema from . import sources from . import triggers from . import types as s_types from . import unknown_pointers from . import utils class ObjectTypeRefMixin(so.Object): # We stick access policies and triggers in their own class as a # hack, to allow us to ensure that access_policies comes later in # the refdicts list than pointers does, so that pointers are # always created before access policies when creating an inherited # type. 
access_policies_refs = so.RefDict( attr='access_policies', requires_explicit_overloaded=True, backref_attr='subject', ref_cls=policies.AccessPolicy) access_policies = so.SchemaField( so.ObjectIndexByUnqualifiedName[policies.AccessPolicy], inheritable=False, ephemeral=True, coerce=True, compcoef=0.857, default=so.DEFAULT_CONSTRUCTOR) triggers_refs = so.RefDict( attr='triggers', requires_explicit_overloaded=True, backref_attr='subject', ref_cls=triggers.Trigger) triggers = so.SchemaField( so.ObjectIndexByUnqualifiedName[triggers.Trigger], inheritable=False, ephemeral=True, coerce=True, compcoef=0.857, default=so.DEFAULT_CONSTRUCTOR) class ObjectType( sources.Source, constraints.ConsistencySubject, s_types.InheritingType, so.InheritingObject, # Help reflection figure out the right db MRO s_types.Type, # Help reflection figure out the right db MRO s_anno.AnnotationSubject, # Help reflection figure out the right db MRO ObjectTypeRefMixin, qlkind=qltypes.SchemaObjectClass.TYPE, data_safe=False, ): union_of = so.SchemaField( so.ObjectSet["ObjectType"], default=so.DEFAULT_CONSTRUCTOR, coerce=True, type_is_generic_self=True, compcoef=0.0, ) intersection_of = so.SchemaField( so.ObjectSet["ObjectType"], default=so.DEFAULT_CONSTRUCTOR, coerce=True, type_is_generic_self=True, ) is_opaque_union = so.SchemaField( bool, default=False, ) def is_object_type(self) -> bool: return True def is_free_object_type(self, schema: s_schema.Schema) -> bool: if self.get_name(schema) == sn.QualName('std', 'FreeObject'): return True FreeObject = schema.get( 'std::FreeObject', type=ObjectType, default=None) if FreeObject is None: # Possible in bootstrap before FreeObject is declared return False else: return self.issubclass(schema, FreeObject) def is_fake_object_type(self, schema: s_schema.Schema) -> bool: return self.is_free_object_type(schema) def is_material_object_type(self, schema: s_schema.Schema) -> bool: return not ( self.is_fake_object_type(schema) or self.is_compound_type(schema) or 
self.is_view(schema) ) def is_union_type(self, schema: s_schema.Schema) -> bool: return bool(self.get_union_of(schema)) def is_intersection_type(self, schema: s_schema.Schema) -> bool: return bool(self.get_intersection_of(schema)) def is_compound_type(self, schema: s_schema.Schema) -> bool: return self.is_union_type(schema) or self.is_intersection_type(schema) def get_displayname(self, schema: s_schema.Schema) -> str: if self.is_view(schema) and not self.get_alias_is_persistent(schema): schema, mtype = self.material_type(schema) else: mtype = self union_of = mtype.get_union_of(schema) if union_of: if self.get_is_opaque_union(schema): std_obj = schema.get('std::BaseObject', type=ObjectType) return std_obj.get_displayname(schema) else: comp_dns = sorted( (c.get_displayname(schema) for c in union_of.objects(schema))) return '(' + ' | '.join(comp_dns) + ')' else: intersection_of = mtype.get_intersection_of(schema) if intersection_of: comp_dns = sorted( (c.get_displayname(schema) for c in intersection_of.objects(schema))) # Elide BaseObject from display, because `& BaseObject` # is a nop. 
return '(' + ' & '.join( dn for dn in comp_dns if dn != 'std::BaseObject' ) + ')' elif mtype == self: return super().get_displayname(schema) else: return mtype.get_displayname(schema) def getrptrs( self, schema: s_schema.Schema, name: str, *, sources: Iterable[so.Object] = () ) -> set[links.Link]: if sn.is_qualified(name): raise ValueError( 'references to concrete pointers must not be qualified') ptrs: set[links.Link] = set() ancestor_objects = self.get_ancestors(schema).objects(schema) for obj in (self,) + ancestor_objects: ptrs.update( lnk for lnk in schema.get_referrers( obj, scls_type=links.Link, field_name='target') if ( lnk.get_shortname(schema).name == name and lnk.get_source_type(schema).is_material_object_type( schema) # Only grab the "base" pointers and all( b.is_non_concrete(schema) for b in lnk.get_bases(schema).objects(schema) ) and (not sources or lnk.get_source_type(schema) in sources) ) ) for intersection in self.get_intersection_of(schema).objects(schema): ptrs.update(intersection.getrptrs(schema, name, sources=sources)) unions = schema.get_referrers( self, scls_type=ObjectType, field_name='union_of') for union in unions: ptrs.update(union.getrptrs(schema, name, sources=sources)) return ptrs def get_relevant_triggers( self, kind: qltypes.TriggerKind, schema: s_schema.Schema ) -> list[triggers.Trigger]: return [ t for t in self.get_triggers(schema).objects(schema) if kind in t.get_kinds(schema) ] def implicitly_castable_to( self, other: s_types.Type, schema: s_schema.Schema ) -> bool: return self.issubclass(schema, other) def find_common_implicitly_castable_type( self, other: s_types.Type, schema: s_schema.Schema, ) -> tuple[s_schema.Schema, Optional[ObjectType]]: if not isinstance(other, ObjectType): return schema, None nearest_common_ancestors = utils.get_class_nearest_common_ancestors( schema, [self, other] ) # We arbitrarily select the first nearest common ancestor nearest_common_ancestor = ( nearest_common_ancestors[0] if 
nearest_common_ancestors else None) if nearest_common_ancestor is not None: assert isinstance(nearest_common_ancestor, ObjectType) return ( schema, nearest_common_ancestor, ) @classmethod def get_root_classes(cls) -> tuple[sn.QualName, ...]: return ( sn.QualName(module='std', name='BaseObject'), sn.QualName(module='std', name='Object'), sn.QualName(module='std', name='FreeObject'), ) @classmethod def get_default_base_name(cls) -> sn.QualName: return sn.QualName(module='std', name='Object') def _issubclass( self, schema: s_schema.Schema, parent: so.SubclassableObject ) -> bool: if self == parent: return True if ( (my_union := self.get_union_of(schema)) and not self.get_is_opaque_union(schema) ): # A union is considered a subclass of a type, if # ALL its components are subclasses of that type. return all( t._issubclass(schema, parent) for t in my_union.objects(schema) ) if my_intersection := self.get_intersection_of(schema): # An intersection is considered a subclass of a type, if # ANY of its components are subclasses of that type. return any( t._issubclass(schema, parent) for t in my_intersection.objects(schema) ) lineage = self.get_ancestors(schema).objects(schema) if parent in lineage: return True elif isinstance(parent, ObjectType): if ( (parent_union := parent.get_union_of(schema)) and not parent.get_is_opaque_union(schema) ): # A type is considered a subclass of a union type, # if it is a subclass of ANY of the union components. return ( parent.get_is_opaque_union(schema) or any( self._issubclass(schema, t) for t in parent_union.objects(schema) ) ) if parent_intersection := parent.get_intersection_of(schema): # A type is considered a subclass of an intersection type, # if it is a subclass of ALL of the intersection components. 
return all( self._issubclass(schema, t) for t in parent_intersection.objects(schema) ) return False def allow_ref_propagation( self, schema: s_schema.Schema, context: sd.CommandContext, refdict: so.RefDict, ) -> bool: return not self.is_view(schema) or refdict.attr == 'pointers' def as_type_delete_if_unused( self, schema: s_schema.Schema, ) -> Optional[sd.DeleteObject[ObjectType]]: if not self._is_deletable(schema): return None # References to aliases can only occur inside other aliases, # so when they go, we need to delete the reference also. # Compound types also need to be deleted when their last # referrer goes. if ( self.is_view(schema) and self.get_alias_is_persistent(schema) ) or self.is_compound_type(schema): return self.init_delta_command( schema, sd.DeleteObject, if_unused=True, if_exists=True, ) else: return None def _test_polymorphic( self, schema: s_schema.Schema, other: s_types.Type ) -> bool: if other.is_anyobject(schema): return True return False def get_or_create_union_type( schema: s_schema.Schema, components: Iterable[ObjectType], *, transient: bool = False, opaque: bool = False, module: Optional[str] = None, ) -> tuple[s_schema.Schema, ObjectType, bool]: name = s_types.get_union_type_name( (c.get_name(schema) for c in components), opaque=opaque, module=module, ) objtype = schema.get(name, default=None, type=ObjectType) created = objtype is None if objtype is None: components = list(components) std_object = schema.get('std::BaseObject', type=ObjectType) schema, objtype = std_object.derive_subtype( schema, name=name, attrs=dict( union_of=so.ObjectSet.create(schema, components), is_opaque_union=opaque, abstract=True, ), transient=transient, ) if not opaque: schema = sources.populate_pointer_set_for_source_union( schema, cast(list[sources.Source], components), objtype, modname=module, ) return schema, objtype, created def get_or_create_intersection_type( schema: s_schema.Schema, components: Iterable[ObjectType], *, module: Optional[str] = None, 
transient: bool = False, ) -> tuple[s_schema.Schema, ObjectType, bool]: name = s_types.get_intersection_type_name( (c.get_name(schema) for c in components), module=module, ) objtype = schema.get(name, default=None, type=ObjectType) created = objtype is None if objtype is None: components = list(components) std_object = schema.get('std::BaseObject', type=ObjectType) schema, objtype = std_object.derive_subtype( schema, name=name, attrs=dict( intersection_of=so.ObjectSet.create(schema, components), abstract=True, ), transient=transient, ) ptrs_dict = collections.defaultdict(list) for component in components: for pn, ptr in component.get_pointers(schema).items(schema): ptrs_dict[pn].append(ptr) intersection_pointers = {} for pn, ptrs in ptrs_dict.items(): if len(ptrs) > 1: # The pointer is present in more than one component. schema, ptr = pointers.get_or_create_intersection_pointer( schema, ptrname=pn, source=objtype, components=set(ptrs), transient=transient, ) else: ptr = ptrs[0] intersection_pointers[pn] = ptr for pn, ptr in intersection_pointers.items(): if objtype.maybe_get_ptr(schema, pn) is None: schema = objtype.add_pointer(schema, ptr) assert isinstance(objtype, ObjectType) return schema, objtype, created class ObjectTypeCommandContext( links.LinkSourceCommandContext[ObjectType], properties.PropertySourceContext[ObjectType], unknown_pointers.UnknownPointerSourceContext[ObjectType], policies.AccessPolicySourceCommandContext[ObjectType], triggers.TriggerSourceCommandContext[ObjectType], sd.ObjectCommandContext[ObjectType], constraints.ConsistencySubjectCommandContext, s_anno.AnnotationSubjectCommandContext, ): pass class ObjectTypeCommand( s_types.InheritingTypeCommand[ObjectType], constraints.ConsistencySubjectCommand[ObjectType], sources.SourceCommand[ObjectType], links.LinkSourceCommand[ObjectType], context_class=ObjectTypeCommandContext, ): def validate_object( self, schema: s_schema.Schema, context: sd.CommandContext ) -> None: if ( not context.stdmode and 
not context.testmode and self.scls.is_material_object_type(schema) ): for base in self.scls.get_bases(schema).objects(schema): name = base.get_name(schema) if ( sn.UnqualName(name.module) in s_schema.STD_MODULES and name not in ( sn.QualName('std', 'BaseObject'), sn.QualName('std', 'Object'), ) ): raise errors.SchemaDefinitionError( f"cannot extend system type '{name}'", span=self.span, ) # Internal consistency check: our stdlib and extension types # shouldn't extend std::Object, which is reserved for user # types. if ( self.scls.is_material_object_type(schema) and self.classname.get_root_module_name() in s_schema.STD_MODULES ): for base in self.scls.get_bases(schema).objects(schema): name = base.get_name(schema) if name == sn.QualName('std', 'Object'): raise errors.SchemaDefinitionError( f"standard lib/extension type '{self.classname}' " f"cannot extend std::Object", hint="try BaseObject", ) class CreateObjectType( ObjectTypeCommand, s_types.CreateInheritingType[ObjectType], ): astnode = qlast.CreateObjectType def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: if (self.get_attribute_value('expr_type') and not self.get_attribute_value('expr')): # This is a nested view type, e.g. # __FooAlias_bar produced by FooAlias := (SELECT Foo { bar: ... }) # and should obviously not appear as a top-level definition.
return None else: return super()._get_ast(schema, context, parent_node=parent_node) def _get_ast_node( self, schema: s_schema.Schema, context: sd.CommandContext ) -> type[qlast.DDLOperation]: if self.get_attribute_value('expr_type'): return qlast.CreateAlias else: return super()._get_ast_node(schema, context) def _create_finalize( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: if ( not context.canonical and self.scls.is_material_object_type(schema) ): # Propagate changes to any functions that depend on # ancestor types in order to recompute the inheritance # situation. schema = self._propagate_if_expr_refs( schema, context, action='creating an object type', include_ancestors=True, filter=s_func.Function, ) return super()._create_finalize(schema, context) class RenameObjectType( ObjectTypeCommand, s_types.RenameInheritingType[ObjectType], ): pass class RebaseObjectType( ObjectTypeCommand, s_types.RebaseInheritingType[ObjectType], ): pass class AlterObjectType( ObjectTypeCommand, s_types.AlterType[ObjectType], inheriting.AlterInheritingObject[ObjectType], ): astnode = qlast.AlterObjectType def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._alter_begin(schema, context) if ( not context.canonical and bool(self.get_subcommands(type=policies.AccessPolicyCommand)) ): from . import functions # If we have any policy commands, we need to propagate to update # functions. We also need to propagate to anything that updates # an ancestor. # # Note that the ancestor search does not generate # quadratically many updates in the case that this change # was propagated from an ancestor, since the # _propagate_if_expr_refs call in the ancestor temporarily # eliminates the ref! 
schema = self._propagate_if_expr_refs( schema, context, action=self.get_friendly_description(schema=schema), include_ancestors=True, filter=functions.Function, ) return schema def _alter_finalize( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: if not context.canonical: # If this type is contained in any unions, we need to # update them with any additions or alterations made to # this type. (Deletions are already handled in DeletePointer.) unions = schema.get_referrers( self.scls, scls_type=ObjectType, field_name='union_of') orig_disable = context.disable_dep_verification for union in unions: if union.get_is_opaque_union(schema): continue delete = union.init_delta_command(schema, sd.DeleteObject) context.disable_dep_verification = True delete.apply(schema, context) context.disable_dep_verification = orig_disable # We run the delete to populate the tree, but then instead # of actually deleting the object, we just remove the names. # This is because the pointers in the types we are looking # at might themselves reference the union, so we need # them in the schema to produce the correct as_alter_delta. nschema = _delete_to_delist(delete, schema) nschema, nunion, _ = utils.ensure_union_type( nschema, types=union.get_union_of(schema).objects(schema), opaque=union.get_is_opaque_union(schema), module=union.get_name(schema).module, ) assert isinstance(nunion, ObjectType) diff = union.as_alter_delta( other=nunion, self_schema=schema, other_schema=nschema, confidence=1.0, context=so.ComparisonContext(), ) schema = diff.apply(schema, context) self.add(diff) return super()._alter_finalize(schema, context) def _delete_to_delist( delete: sd.DeleteObject[so.Object], schema: s_schema.Schema, ) -> s_schema.Schema: """Delist all of the objects mentioned in a delete tree. This removes their names from the schema but preserves the actual objects. 
""" schema = schema.delist(delete.classname) for sub in delete.get_subcommands(type=sd.DeleteObject): schema = _delete_to_delist(sub, schema) return schema class DeleteObjectType( ObjectTypeCommand, s_types.DeleteType[ObjectType], inheriting.DeleteInheritingObject[ObjectType], ): astnode = qlast.DropObjectType def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: if self.get_orig_attribute_value('expr_type'): # This is an alias type, appropriate DDL would be generated # from the corresponding DeleteAlias node. return None else: return super()._get_ast(schema, context, parent_node=parent_node) def _delete_finalize( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: if ( not context.canonical and self.scls.is_material_object_type(schema) ): # Propagate changes to any functions that depend on # ancestor types in order to recompute the inheritance # situation. schema = self._propagate_if_expr_refs( schema, context, action='deleting an object type', include_self=False, include_ancestors=True, filter=s_func.Function, ) return super()._delete_finalize(schema, context) ================================================ FILE: edb/schema/operators.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any, Optional, Mapping from edb import errors from edb.common import checked from edb.edgeql import ast as qlast from edb.edgeql import qltypes as ft from . import delta as sd from . import functions as s_func from . import name as sn from . import objects as so from . import schema as s_schema from . import utils class Operator( s_func.CallableObject, s_func.VolatilitySubject, qlkind=ft.SchemaObjectClass.OPERATOR, data_safe=True, ): operator_kind = so.SchemaField( ft.OperatorKind, coerce=True, compcoef=0.4) language = so.SchemaField( qlast.Language, default=None, compcoef=0.4, coerce=True) from_operator = so.SchemaField( checked.CheckedList[str], coerce=True, default=None, compcoef=0.4) from_function = so.SchemaField( checked.CheckedList[str], coerce=True, default=None, compcoef=0.4) from_expr = so.SchemaField( bool, default=False, compcoef=0.4) force_return_cast = so.SchemaField( bool, default=False, compcoef=0.9) code = so.SchemaField( str, default=None, compcoef=0.4) # An unused dummy field. We have this here to make it easier to # test the *removal* of internal schema fields during in-place # upgrades. _dummy_field = so.SchemaField( str, default=None) # If this is a derivative operator, *derivative_of* would # contain the name of the origin operator. # For example, the `std::IN` operator has `std::=` # as its origin. 
derivative_of = so.SchemaField( sn.QualName, coerce=True, default=None, compcoef=0.4) commutator = so.SchemaField( sn.QualName, coerce=True, default=None, compcoef=0.99) negator = so.SchemaField( sn.QualName, coerce=True, default=None, compcoef=0.99) recursive = so.SchemaField( bool, default=False, compcoef=0.4) def get_display_signature(self, schema: s_schema.Schema) -> str: params = [ p.get_type(schema).get_displayname(schema) for p in self.get_params(schema).objects(schema) ] name = self.get_shortname(schema).name kind = self.get_operator_kind(schema) if kind is ft.OperatorKind.Infix: return f'{params[0]} {name} {params[1]}' elif kind is ft.OperatorKind.Postfix: return f'{params[0]} {name}' elif kind is ft.OperatorKind.Prefix: return f'{name} {params[0]}' elif kind is ft.OperatorKind.Ternary: return f'{name} ({", ".join(params)})' else: raise ValueError('unexpected operator kind') def get_verbosename( self, schema: s_schema.Schema, *, with_parent: bool = False ) -> str: return f'operator "{self.get_display_signature(schema)}"' class OperatorCommandContext(s_func.CallableCommandContext): pass class OperatorCommand( s_func.CallableCommand[Operator], context_class=OperatorCommandContext, ): def get_ast_attr_for_field( self, field: str, astnode: type[qlast.DDLOperation], ) -> Optional[str]: if field == 'abstract': return field elif field == 'operator_kind': return 'kind' else: return super().get_ast_attr_for_field(field, astnode) @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: if not context.stdmode and not context.testmode: raise errors.UnsupportedFeatureError( 'user-defined operators are not supported', span=astnode.span ) return super()._cmd_tree_from_ast(schema, astnode, context) @classmethod def _classname_from_ast( cls, schema: s_schema.Schema, astnode: qlast.ObjectDDL, context: sd.CommandContext, ) -> sn.QualName: assert isinstance(astnode, qlast.OperatorCommand) 
assert isinstance(astnode, qlast.ObjectDDL) name = super()._classname_from_ast(schema, astnode, context) params = cls._get_param_desc_from_ast( schema, context.modaliases, astnode) fqname = cls.get_schema_metaclass().get_fqname( schema, name, params, astnode.kind) assert isinstance(fqname, sn.QualName) return fqname class CreateOperator( s_func.CreateCallableObject[Operator], OperatorCommand, ): astnode = qlast.CreateOperator def _create_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: fullname = self.classname shortname = sn.shortname_from_fullname(fullname) schema, cp = self._get_param_desc_from_delta(schema, context, self) signature = f'{shortname}({", ".join(p.as_str(schema) for p in cp)})' func = schema.get(fullname, None) if func: raise errors.InvalidOperatorDefinitionError( f'cannot create the `{signature}` operator: ' f'an operator with the same signature ' f'is already defined', span=self.span) schema = super()._create_begin(schema, context) params: s_func.FuncParameterList = self.scls.get_params(schema) fullname = self.scls.get_name(schema) shortname = sn.shortname_from_fullname(fullname) return_typemod = self.scls.get_return_typemod(schema) assert isinstance(self.scls, Operator) recursive = self.scls.get_recursive(schema) derivative_of = self.scls.get_derivative_of(schema) # an operator must have operands if len(params) == 0: raise errors.InvalidOperatorDefinitionError( f'cannot create the `{signature}` operator: ' f'an operator must have operands', span=self.span) # We'll need to make sure that there's no mix of recursive and # non-recursive operators being overloaded. 
all_arrays = all_tuples = all_ranges = True for param in params.objects(schema): ptype = param.get_type(schema) all_arrays = all_arrays and ptype.is_array() all_tuples = all_tuples and ptype.is_tuple(schema) all_ranges = all_ranges and (ptype.is_range() or ptype.is_multirange()) # It's illegal to declare an operator as recursive unless all # of its operands are the same basic kind of collection: arrays, # tuples, or ranges (including multiranges). if recursive and not any([all_arrays, all_tuples, all_ranges]): raise errors.InvalidOperatorDefinitionError( f'cannot create the `{signature}` operator: ' f'operands of a recursive operator must be ' f'all arrays, all tuples, or all ranges', span=self.span) for oper in lookup_operators(shortname, (), schema=schema): if oper == self.scls: continue oper_return_typemod = oper.get_return_typemod(schema) if oper_return_typemod != return_typemod: raise errors.DuplicateOperatorDefinitionError( f'cannot create the `{signature}` ' f'operator: overloading another operator with different ' f'return type {oper_return_typemod.to_edgeql()} ' f'{oper.get_return_type(schema).name}', span=self.span) oper_derivative_of = oper.get_derivative_of(schema) if oper_derivative_of: raise errors.DuplicateOperatorDefinitionError( f'cannot create the `{signature}` ' f'operator: there exists a derivative operator of the ' f'same name', span=self.span) elif derivative_of: raise errors.DuplicateOperatorDefinitionError( f'cannot create `{signature}` ' f'as a derivative operator: there already exists an ' f'operator of the same name', span=self.span) # Check if there is a recursive/non-recursive operator # overloading.
oper_recursive = oper.get_recursive(schema) if recursive != oper_recursive: oper_signature = oper.get_display_signature(schema) oper_all_arrays = oper_all_tuples = oper_all_ranges = True for param in oper.get_params(schema).objects(schema): ptype = param.get_type(schema) oper_all_arrays = oper_all_arrays and ptype.is_array() oper_all_tuples = ( oper_all_tuples and ptype.is_tuple(schema) ) oper_all_ranges = oper_all_ranges and ( ptype.is_range() or ptype.is_multirange() ) if (all_arrays == oper_all_arrays and all_tuples == oper_all_tuples and all_ranges == oper_all_ranges): new_rec = 'recursive' if recursive else 'non-recursive' oper_rec = \ 'recursive' if oper_recursive else 'non-recursive' raise errors.InvalidOperatorDefinitionError( f'cannot create the {new_rec} `{signature}` operator: ' f'overloading a {oper_rec} operator ' f'`{oper_signature}` with a {new_rec} one ' f'is not allowed', span=self.span) return schema @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: assert isinstance(astnode, qlast.CreateOperator) cmd = super()._cmd_tree_from_ast(schema, astnode, context) cmd.set_attribute_value( 'operator_kind', astnode.kind, ) if astnode.code is not None: cmd.set_attribute_value( 'language', astnode.code.language, ) if astnode.code.from_operator is not None: cmd.set_attribute_value( 'from_operator', astnode.code.from_operator, ) if astnode.code.from_function is not None: cmd.set_attribute_value( 'from_function', astnode.code.from_function, ) if astnode.code.code is not None: # TODO: Make operators from code strict when we can? 
cmd.set_attribute_value( 'impl_is_strict', False ) cmd.set_attribute_value( 'code', astnode.code.code, ) if astnode.code.from_expr is not None: cmd.set_attribute_value( 'from_expr', astnode.code.from_expr, ) return cmd def _apply_field_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, op: sd.AlterObjectProperty, ) -> None: assert isinstance(node, qlast.CreateOperator) new_value: Any = op.new_value if op.property == 'return_type': node.returning = utils.typeref_to_ast(schema, new_value) elif op.property == 'return_typemod': node.returning_typemod = new_value elif op.property == 'code': if node.code is None: node.code = qlast.OperatorCode() node.code.code = new_value elif op.property == 'language': if node.code is None: node.code = qlast.OperatorCode() node.code.language = new_value elif op.property == 'from_function' and new_value: if node.code is None: node.code = qlast.OperatorCode() node.code.from_function = new_value elif op.property == 'from_expr' and new_value: if node.code is None: node.code = qlast.OperatorCode() node.code.from_expr = new_value elif op.property == 'from_operator' and new_value: if node.code is None: node.code = qlast.OperatorCode() node.code.from_operator = tuple(new_value) else: super()._apply_field_ast(schema, context, node, op) class RenameOperator(sd.RenameObject[Operator], OperatorCommand): pass class AlterOperator(s_func.AlterCallableObject[Operator], OperatorCommand): astnode = qlast.AlterOperator class DeleteOperator(s_func.DeleteCallableObject[Operator], OperatorCommand): astnode = qlast.DropOperator def lookup_operators( name: sn.Name | str, default: tuple[Operator, ...] | so.NoDefaultT = so.NoDefault, *, module_aliases: Optional[Mapping[Optional[str], str]] = None, schema: s_schema.Schema, ) -> tuple[Operator, ...]: funcs: tuple[Operator, ...] 
| so.NoDefaultT = s_schema.lookup( schema, name, getter=s_schema._get_operators, module_aliases=module_aliases, default=default, ) if funcs is not so.NoDefault: return funcs else: return s_schema.Schema.raise_bad_reference( name=name, module_aliases=module_aliases, type=Operator, ) ================================================ FILE: edb/schema/ordering.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Optional, Sequence, NamedTuple, TYPE_CHECKING, ) import collections from edb import errors from edb.common import ordered from edb.common import topological from . import delta as sd from . import expraliases as s_expraliases from . import functions as s_func from . import indexes as s_indexes from . import inheriting from . import name as sn from . import objects as so from . import objtypes as s_objtypes from . import pointers as s_pointers from . import constraints as s_constraints from . import referencing from . import types as s_types if TYPE_CHECKING: from . 
import schema as s_schema class DepGraphEntryExtra(NamedTuple): implicit_ancestors: list[sn.Name] DepGraphKey = tuple[str, str] DepGraphEntry = topological.DepGraphEntry[ DepGraphKey, tuple[sd.Command, ...], DepGraphEntryExtra, ] DepGraph = dict[DepGraphKey, DepGraphEntry] def linearize_delta( delta: sd.DeltaRoot, old_schema: Optional[s_schema.Schema], new_schema: s_schema.Schema, ) -> sd.DeltaRoot: """Reorder the *delta* tree in-place to satisfy command dependency order. Args: delta: Input delta command tree. old_schema: Schema used to resolve original object state. new_schema: Schema used to resolve final schema state. Returns: Input delta tree reordered according to topological ordering of commands. """ # We take the scatter-sort-gather approach here, where the original # tree is broken up into linear branches, which are then sorted # and reassembled back into a tree. # A map of commands to root->command paths through the tree. # Nodes are duplicated so the interior nodes of the path are # distinct. 
opmap: dict[sd.Command, list[sd.Command]] = {} strongrefs: dict[sn.Name, sn.Name] = {} for op in _get_sorted_subcommands(delta): _break_down(opmap, strongrefs, [delta, op]) depgraph: DepGraph = {} renames: dict[sn.Name, sn.Name] = {} renames_r: dict[sn.Name, sn.Name] = {} deletions: set[sn.Name] = set() for op in opmap: if isinstance(op, sd.RenameObject): renames[op.classname] = op.new_name renames_r[op.new_name] = op.classname elif isinstance(op, sd.DeleteObject): deletions.add(op.classname) for op, opbranch in opmap.items(): if isinstance(op, sd.AlterObject) and not op.get_subcommands(): continue _trace_op(op, opbranch, depgraph, renames, renames_r, strongrefs, old_schema, new_schema) depgraph = dict(filter(lambda i: i[1].item != (), depgraph.items())) everything = set(depgraph) for item in depgraph.values(): item.deps &= everything item.weak_deps &= everything try: sortedlist = [i[1] for i in topological.sort_ex(depgraph)] except topological.CycleError as ex: cycle = [depgraph[k].item for k in (ex.item,) + ex.path + (ex.item,)] messages = [ ' ' + nodes[-1].get_friendly_description(parent_op=nodes[-2]) for nodes in cycle ] raise errors.SchemaDefinitionError( 'cannot produce migration because of a dependency cycle:\n' + ' depends on\n'.join(messages) ) from None reconstructed = reconstruct_tree(sortedlist, depgraph) delta.replace_all(reconstructed.get_subcommands()) return delta def reconstruct_tree( sortedlist: list[DepGraphEntry], depgraph: DepGraph, ) -> sd.DeltaRoot: result = sd.DeltaRoot() # Child to parent mapping. parents: dict[sd.Command, sd.Command] = {} # A mapping of commands to their dependencies. dependencies: dict[sd.Command, set[sd.Command]] = ( collections.defaultdict(set)) # Current address of command within a tree in the form of # a tuple of indexes where each index represents relative # position within the tree rank. 
offsets: dict[sd.Command, tuple[int, ...]] = {} # Object commands indexed by command type and object name and # implicitness, where each entry represents the latest seen # command of the type for a particular object. Implicit commands # are included, but can only be attached to by other implicit # commands. opindex: dict[ tuple[type[sd.ObjectCommand[so.Object]], sn.Name, bool], sd.ObjectCommand[so.Object] ] = {} def ok_to_attach_to( op_to_attach: sd.Command, op_to_attach_to: sd.ObjectCommand[so.Object], only_if_confident: bool = False, ) -> bool: """Determine if a given command can be attached to another. Returns True, if *op_to_attach* can be attached to *op_to_attach_to* without violating the dependency order. """ if only_if_confident and isinstance(op_to_attach, sd.ObjectCommand): # Avoid reattaching the subcommand if confidence is below 100%, # so that granular prompts can be generated. confidence = op_to_attach.get_annotation('confidence') if confidence is not None and confidence < 1.0: return False tgt_offset = offsets[op_to_attach_to] tgt_offset_len = len(tgt_offset) deps = dependencies[op_to_attach] return all(offsets[dep][:tgt_offset_len] <= tgt_offset for dep in deps) def attach( opbranch: tuple[sd.Command, ...], new_parent: sd.Command, slice_start: int = 1, as_implicit: bool = False, ) -> None: """Attach a portion of a given command branch to another parent. Args: opbranch: Command branch to attach to *new_parent*. new_parent: Command node to attach the specified portion of *opbranch* to. slice_start: Offset into *opbranch* that determines which commands get attached. as_implicit: If True, the command branch is considered to be implicit, i.e. it is recorded in the command index as implicit, so that only other implicit commands can attach to it.
""" parent = opbranch[slice_start] op = opbranch[-1] offset_within_parent = new_parent.get_nonattr_subcommand_count() if not isinstance(new_parent, sd.DeltaRoot): parent_offset = offsets[new_parent] + (offset_within_parent,) else: parent_offset = (offset_within_parent,) old_parent = parents[parent] old_parent.discard(parent) new_parent.add_caused(parent) parents[parent] = new_parent for i in range(slice_start, len(opbranch)): op = opbranch[i] if isinstance(op, sd.ObjectCommand): ancestor_key = (type(op), op.classname, as_implicit) opindex[ancestor_key] = op if op in offsets: op_offset = offsets[op][slice_start:] else: op_offset = (0,) * (i - slice_start) offsets[op] = parent_offset + op_offset def maybe_replace_preceding( op: sd.Command, ) -> bool: """Possibly merge and replace an earlier command with *op*. If *op* is a DELETE command, or an ALTER command that has no subcommands, and there is an earlier ALTER command operating on the same object as *op*, merge that command into *op* and replace it with *op*. Returns: True if merge and replace happened, False otherwise. """ if not ( isinstance(op, sd.DeleteObject) or ( isinstance(op, sd.AlterObject) and op.get_nonattr_subcommand_count() == 0 ) ): return False alter_cmd_cls = sd.get_object_command_class( sd.AlterObject, op.get_schema_metaclass()) if alter_cmd_cls is None: # ALTER isn't even defined for this object class, bail. return False alter_key = ((alter_cmd_cls), op.classname, False) alter_op = opindex.get(alter_key) if alter_op is None: # No preceding ALTER, bail. 
return False if ( not ok_to_attach_to(op, alter_op) or ( isinstance(parents[op], sd.DeltaRoot) != isinstance(parents[alter_op], sd.DeltaRoot) ) or bool(alter_op.get_subcommands(type=sd.RenameObject)) ): return False for alter_sub in reversed(alter_op.get_prerequisites()): op.prepend_prerequisite(alter_sub) parents[alter_sub] = op for alter_sub in reversed( alter_op.get_subcommands(include_prerequisites=False) ): op.prepend(alter_sub) parents[alter_sub] = op attached_root = parents[alter_op] attached_root.replace(alter_op, op) opindex[alter_key] = op opindex[type(op), op.classname, False] = op offsets[op] = offsets[alter_op] parents[op] = attached_root return True def maybe_attach_to_preceding( opbranch: tuple[sd.Command, ...], parent_candidates: list[sn.Name], allowed_op_types: list[type[sd.ObjectCommand[so.Object]]], as_implicit: bool = False, slice_start: int = 1, ) -> bool: """Find a parent and attach a given portion of command branch to it. Args: opbranch: Command branch to consider. parent_candidates: A list of parent object names to consider when looking for a parent command. allowed_op_types: A list of command types to consider when looking for a parent command. as_implicit: If True, the command branch is considered to be implicit, i.e. it is recorded in the command index as implicit, so that only other implicit commands can attach to it. slice_start: Offset into *opbranch* that determines which commands get attached. """ for candidate in parent_candidates: for op_type in allowed_op_types: parent_op = opindex.get((op_type, candidate, False)) # implicit ops are allowed to attach to other implicit # ops. (Since we want them to chain properly in # inheritance order.) if parent_op is None and as_implicit: parent_op = opindex.get((op_type, candidate, True)) if ( parent_op is not None and ok_to_attach_to( op, parent_op, only_if_confident=not as_implicit, ) ): attach( opbranch, parent_op, as_implicit=as_implicit, slice_start=slice_start, ) return True return False # First, build parents and dependencies maps.
for info in sortedlist: opbranch = info.item op = opbranch[-1] for j, pop in enumerate(opbranch[1:]): parents[pop] = opbranch[j] for dep in info.deps: dep_item = depgraph[dep] dep_stack = dep_item.item dep_op = dep_stack[-1] dependencies[op].add(dep_op) for info in sortedlist: opbranch = info.item op = opbranch[-1] # Elide empty ALTER statements from output. if isinstance(op, sd.AlterObject) and not op.get_subcommands(): continue # If applicable, replace a preceding command with this op. if maybe_replace_preceding(op): continue if ( isinstance(op, sd.ObjectCommand) and not isinstance(op, sd.CreateObject) and info.extra is not None and info.extra.implicit_ancestors ): # This command is deemed to be an implicit effect of another # command, such as when alteration is propagated through the # inheritance chain. If so, find a command that operates on # a parent object and attach this branch to it. allowed_ops = [type(op)] if isinstance(op, sd.DeleteObject): allowed_ops.append(op.get_other_command_class(sd.DeleteObject)) if maybe_attach_to_preceding( opbranch, info.extra.implicit_ancestors, allowed_ops, as_implicit=True, ): continue # Walking the branch toward root, see if there's a matching # branch prefix we could attach to. for depth, ancestor_op in enumerate(reversed(opbranch[1:-1])): assert isinstance(ancestor_op, sd.ObjectCommand) allowed_ops = [] create_cmd_t = ancestor_op.get_other_command_class(sd.CreateObject) if type(ancestor_op) is not create_cmd_t: allowed_ops.append(create_cmd_t) allowed_ops.append(type(ancestor_op)) if maybe_attach_to_preceding( opbranch, [ancestor_op.classname], allowed_ops, slice_start=len(opbranch) - (depth + 1), ): break else: # No branches to attach to, so attach to root. 
attach(opbranch, result) return result def _command_key(cmd: sd.Command) -> Any: if isinstance(cmd, sd.ObjectCommand): return (cmd.get_schema_metaclass().__name__, cmd.classname) elif isinstance(cmd, sd.AlterObjectProperty): return ('.field', cmd.property) else: return ('_generic', type(cmd).__name__) def _get_sorted_subcommands(cmd: sd.Command) -> list[sd.Command]: subcommands = list(cmd.get_subcommands()) subcommands.sort(key=_command_key) return subcommands def _break_down( opmap: dict[sd.Command, list[sd.Command]], strongrefs: dict[sn.Name, sn.Name], opbranch: list[sd.Command], ) -> None: if len(opbranch) > 2: new_opbranch = _extract_op(opbranch) else: new_opbranch = opbranch op = new_opbranch[-1] breakable_commands = ( referencing.ReferencedObjectCommand, sd.RenameObject, inheriting.RebaseInheritingObject, ) for sub_op in _get_sorted_subcommands(op): if ( isinstance(sub_op, sd.AlterObjectProperty) and not isinstance(op, sd.DeleteObject) ): assert isinstance(op, sd.ObjectCommand) mcls = op.get_schema_metaclass() field = mcls.get_field(sub_op.property) # Break a possible reference cycle # (i.e. Type.rptr <-> Pointer.target) if ( field.weak_ref or ( isinstance(op, sd.AlterObject) and issubclass(field.type, so.Object) ) ): _break_down(opmap, strongrefs, new_opbranch + [sub_op]) elif ( isinstance(sub_op, sd.AlterSpecialObjectField) and not isinstance( sub_op, ( referencing.AlterOwned, s_pointers.SetPointerType, ) ) ): pass elif ( isinstance(sub_op, referencing.ReferencedObjectCommandBase) and sub_op.is_strong_ref ): assert isinstance(op, sd.ObjectCommand) strongrefs[sub_op.classname] = op.classname elif isinstance(sub_op, breakable_commands): _break_down(opmap, strongrefs, new_opbranch + [sub_op]) # For SET TYPE and friends, we need to make sure that an alter # (with children) makes it into the opmap so it is processed. 
if ( isinstance(op, sd.AlterSpecialObjectField) and not isinstance(op, referencing.AlterOwned) ): opmap[new_opbranch[-2]] = new_opbranch[:-1] opmap[op] = new_opbranch def _trace_op( op: sd.Command, opbranch: list[sd.Command], depgraph: DepGraph, renames: dict[sn.Name, sn.Name], renames_r: dict[sn.Name, sn.Name], strongrefs: dict[sn.Name, sn.Name], old_schema: Optional[s_schema.Schema], new_schema: s_schema.Schema, ) -> None: def get_deps(key: DepGraphKey) -> DepGraphEntry: try: item = depgraph[key] except KeyError: item = depgraph[key] = DepGraphEntry( item=(), deps=ordered.OrderedSet(), weak_deps=ordered.OrderedSet(), ) return item def record_field_deps( op: sd.AlterObjectProperty, parent_op: sd.ObjectCommand[so.Object], ) -> str: nvn = None if isinstance(op.new_value, (so.Object, so.ObjectShell)): obj = op.new_value nvn = obj.get_name(new_schema) if nvn is not None: deps.add(('create', str(nvn))) deps.add(('alter', str(nvn))) if nvn in renames_r: deps.add(('rename', str(renames_r[nvn]))) if isinstance(obj, so.ObjectShell): obj = obj.resolve(new_schema) # For SET TYPE, we want to finish any rebasing into the # target type before we change the type. 
if isinstance(obj, so.InheritingObject): for desc in obj.descendants(new_schema): deps.add(('rebase', str(desc.get_name(new_schema)))) graph_key = f'{parent_op.classname}%%{op.property}' deps.add(('create', str(parent_op.classname))) deps.add(('alter', str(parent_op.classname))) if isinstance(op.old_value, (so.Object, so.ObjectShell)): assert old_schema is not None ovn = op.old_value.get_name(old_schema) if ovn != nvn: ov_item = get_deps(('delete', str(ovn))) ov_item.deps.add((tag, graph_key)) return graph_key def write_dep_matrix( dependent: str, dependent_tags: tuple[str, ...], dependency: str, dependency_tags: tuple[str, ...], *, as_weak: bool = False, ) -> None: for dependent_tag in dependent_tags: item = get_deps((dependent_tag, dependent)) for dependency_tag in dependency_tags: if as_weak: item.weak_deps.add((dependency_tag, dependency)) else: item.deps.add((dependency_tag, dependency)) def write_ref_deps( ref: so.Object, obj: so.Object, this_name_str: str, ) -> None: ref_name = ref.get_name(new_schema) if ref_name in renames_r: ref_name = renames_r[ref_name] ref_name_str = str(ref_name) if ((isinstance(ref, referencing.ReferencedObject) and ref.get_referrer(new_schema) == obj) or (isinstance(obj, referencing.ReferencedObject) and obj.get_referrer(new_schema) == ref)): # Mostly ignore refs generated by refdict backref, but # make create/alter depend on renames of the backref. # This makes sure that a rename is done before the innards are # modified. DDL doesn't actually require this but some of the # internals for producing the DDL do (since otherwise we can # generate references to the renamed type in our delta before # it is renamed). 
if tag in ('create', 'alter'): deps.add(('rename', ref_name_str)) return write_dep_matrix( dependent=ref_name_str, dependent_tags=('create', 'alter', 'rebase'), dependency=this_name_str, dependency_tags=('create', 'alter', 'rename'), ) item = get_deps(('rename', ref_name_str)) item.deps.add(('create', this_name_str)) item.deps.add(('alter', this_name_str)) item.deps.add(('rename', this_name_str)) if isinstance(ref, s_pointers.Pointer): # The current item is a type referred to by # a link or property in another type. Set the referring # type and its descendants as weak dependents of the current # item to reduce the number of unnecessary ALTERs in the # final delta, especially ones that might result in SET TYPE # commands being generated. ref_src = ref.get_source(new_schema) if isinstance(ref_src, s_pointers.Pointer): ref_src_src = ref_src.get_source(new_schema) if ref_src_src is not None: ref_src = ref_src_src if ref_src is not None: for desc in ref_src.descendants(new_schema) | {ref_src}: desc_name = str(desc.get_name(new_schema)) write_dep_matrix( dependent=desc_name, dependent_tags=('create', 'alter'), dependency=this_name_str, dependency_tags=('create', 'alter', 'rename'), as_weak=True, ) deps: ordered.OrderedSet[tuple[str, str]] = ordered.OrderedSet() graph_key: str implicit_ancestors: list[sn.Name] = [] if isinstance(op, sd.CreateObject): tag = 'create' elif isinstance(op, sd.AlterObject): tag = 'alter' elif isinstance(op, sd.RenameObject): tag = 'rename' elif isinstance(op, inheriting.RebaseInheritingObject): tag = 'rebase' elif isinstance(op, sd.DeleteObject): tag = 'delete' elif isinstance(op, referencing.AlterOwned): if op.get_attribute_value('owned'): tag = 'setowned' else: tag = 'dropowned' elif isinstance(op, (sd.AlterObjectProperty, sd.AlterSpecialObjectField)): tag = 'field' else: raise RuntimeError( f'unexpected delta command type at top level: {op!r}' ) if isinstance(op, (sd.DeleteObject, referencing.AlterOwned)): assert old_schema is not None try: 
obj = get_object(old_schema, op) except errors.InvalidReferenceError: if isinstance(op, sd.DeleteObject) and op.if_exists: # If this is conditional deletion and the object isn't there, # then don't bother with analysis, since this command wouldn't # get executed. return else: raise refs = _get_referrers(old_schema, obj, strongrefs) for ref in refs: ref_name_str = str(ref.get_name(old_schema)) if ( ( isinstance(obj, referencing.ReferencedObject) and obj.get_referrer(old_schema) == ref ) ): # If the referrer is enclosing the object # (i.e. the reference is a refdict reference), # we sort the enclosed operation first. ref_item = get_deps(('delete', ref_name_str)) ref_item.deps.add((tag, str(op.classname))) elif ( isinstance(ref, referencing.ReferencedInheritingObject) and ( op.classname in { b.get_name(old_schema) for b in ref.get_implicit_ancestors(old_schema) } ) and ( not isinstance(ref, s_pointers.Pointer) or not ref.get_from_alias(old_schema) ) ): # If the ref is an implicit descendant (i.e. an inherited ref), # we also sort it _after_ the parent, because we'll pull # it as a child of the parent op at the time of tree # reassembly. ref_item = get_deps(('delete', ref_name_str)) ref_item.deps.add((tag, str(op.classname))) elif ( isinstance(ref, referencing.ReferencedObject) and ref.get_referrer(old_schema) == obj ): # Skip refdict.backref_attr to avoid dependency cycles. continue else: # Otherwise, things must be deleted _after_ their referrers # have been deleted or altered. deps.add(('delete', ref_name_str)) # (except for aliases, which in the collection case # specifically need the old target deleted before the # new one is created) if not isinstance(ref, s_expraliases.Alias): deps.add(('alter', ref_name_str)) if type(ref) is type(obj): deps.add(('rebase', ref_name_str)) # The deletion of any implicit ancestors needs to come after # the deletion of any referrers also. 
if isinstance(obj, referencing.ReferencedInheritingObject): for ancestor in obj.get_implicit_ancestors(old_schema): ancestor_name = ancestor.get_name(old_schema) anc_item = get_deps(('delete', str(ancestor_name))) anc_item.deps.add(('delete', ref_name_str)) if isinstance(obj, referencing.ReferencedObject): if tag == 'delete': # If the object is being deleted and then recreated # via inheritance, that deletion needs to come before # an ancestor gets created (since that will cause our # recreation.) try: new_obj = get_object(new_schema, op) except errors.InvalidReferenceError: new_obj = None if isinstance(new_obj, referencing.ReferencedInheritingObject): for ancestor in new_obj.get_implicit_ancestors(new_schema): rep_item = get_deps( ('create', str(ancestor.get_name(new_schema)))) rep_item.deps.add((tag, str(op.classname))) referrer = obj.get_referrer(old_schema) if referrer is not None: assert isinstance(referrer, so.QualifiedObject) referrer_name: sn.Name = referrer.get_name(old_schema) if referrer_name in renames_r: referrer_name = renames_r[referrer_name] # A drop needs to come *before* drop owned on the referrer # which will do it itself. if tag == 'delete': ref_item = get_deps(('dropowned', str(referrer_name))) ref_item.deps.add(('delete', str(op.classname))) # For SET OWNED, we need any rebase of the enclosing # object to come *after*, because otherwise obj could # get dropped before the SET OWNED takes effect. # DROP, also. 
if tag in ('setowned', 'delete'): ref_item = get_deps(('rebase', str(referrer_name))) ref_item.deps.add((tag, str(op.classname))) else: deps.add(('rebase', str(referrer_name))) if ( isinstance(obj, referencing.ReferencedInheritingObject) and ( not isinstance(obj, s_pointers.Pointer) or not obj.get_from_alias(old_schema) ) ): for ancestor in obj.get_implicit_ancestors(old_schema): ancestor_name = ancestor.get_name(old_schema) implicit_ancestors.append(ancestor_name) if isinstance(op, referencing.AlterOwned): anc_item = get_deps(('delete', str(ancestor_name))) anc_item.deps.add((tag, str(op.classname))) if tag == 'setowned': # SET OWNED must come before ancestor rebases too anc_item = get_deps(('rebase', str(ancestor_name))) anc_item.deps.add(('setowned', str(op.classname))) if tag == 'dropowned': deps.add(('alter', str(op.classname))) graph_key = str(op.classname) elif isinstance(op, sd.AlterObjectProperty): parent_op = opbranch[-2] assert isinstance(parent_op, sd.ObjectCommand) graph_key = record_field_deps(op, parent_op) elif isinstance(op, sd.AlterSpecialObjectField): parent_op = opbranch[-2] assert isinstance(parent_op, sd.ObjectCommand) field_op = op._get_attribute_set_cmd(op._field) assert field_op is not None graph_key = record_field_deps(field_op, parent_op) elif isinstance(op, sd.ObjectCommand): # If the object was renamed, use the new name, else use regular. name = renames.get(op.classname, op.classname) obj = get_object(new_schema, op, name) this_name_str = str(op.classname) if tag == 'rename': # On renames, we want to delete any references before we # do the rename. This is because for functions and # constraints we implicitly rename the object when # something it references is renamed, and this implicit # rename can interfere with a CREATE/DELETE pair. So we # make sure to put the DELETE before the RENAME of a # referenced object. (An improvement would be to elide a # CREATE/DELETE pair when it could be implicitly handled # by a rename). 
assert old_schema old_obj = get_object(old_schema, op, op.classname) for ref in _get_referrers(old_schema, old_obj, strongrefs): deps.add(('delete', str(ref.get_name(old_schema)))) refs = _get_referrers(new_schema, obj, strongrefs) for ref in refs: write_ref_deps(ref, obj, this_name_str) if tag == 'create': # In a delete/create cycle, deletion must obviously # happen first. deps.add(('delete', this_name_str)) # Renaming also deps.add(('rename', this_name_str)) if isinstance(obj, s_func.Function) and old_schema is not None: old_funcs = old_schema._get_by_shortname( s_func.Function, sn.shortname_from_fullname(op.classname), ) or () for old_func in old_funcs: deps.add(('delete', str(old_func.get_name(old_schema)))) # Some index types only allow one per object type. Make # sure we drop the old one before creating the new. if ( isinstance(obj, s_indexes.Index) and s_indexes.is_exclusive_object_scope_index(new_schema, obj) and old_schema is not None and (subject := obj.get_subject(new_schema)) and (old_subject := old_schema.get( subject.get_name(new_schema), type=s_objtypes.ObjectType, default=None )) and (eff_index := s_indexes.get_effective_object_index( old_schema, old_subject, obj.get_root(new_schema).get_name(new_schema), )[0]) ): deps.add(('delete', str(eff_index.get_name(old_schema)))) if tag == 'alter': # Alteration must happen after creation, if any. 
deps.add(('create', this_name_str)) deps.add(('rename', this_name_str)) deps.add(('rebase', this_name_str)) if isinstance(obj, referencing.ReferencedObject): referrer = obj.get_referrer(new_schema) if referrer is not None: assert isinstance(referrer, so.QualifiedObject) referrer_name = referrer.get_name(new_schema) if referrer_name in renames_r: referrer_name = renames_r[referrer_name] ref_name_str = str(referrer_name) deps.add(('create', ref_name_str)) if op.ast_ignore_ownership() or tag == 'rename': ref_item = get_deps(('rebase', ref_name_str)) ref_item.deps.add((tag, this_name_str)) else: deps.add(('rebase', ref_name_str)) # Addition and removal of constraints can cause # changes to the cardinality of expressions that refer # to them. Add the appropriate dependencies in. if ( isinstance(obj, s_constraints.Constraint) and isinstance(referrer, s_pointers.Pointer) ): refs = _get_referrers(new_schema, referrer, strongrefs) for ref in refs: write_ref_deps(ref, referrer, this_name_str) if ( isinstance(obj, referencing.ReferencedInheritingObject) # Changes to owned objects can't necessarily be merged # in with parents, so we make sure not to. 
and not obj.get_owned(new_schema) ): implicit_ancestors = [ b.get_name(new_schema) for b in obj.get_implicit_ancestors(new_schema) ] if not isinstance(op, sd.CreateObject): assert old_schema is not None name = renames_r.get(op.classname, op.classname) old_obj = get_object(old_schema, op, name) assert isinstance( old_obj, referencing.ReferencedInheritingObject, ) implicit_ancestors += [ b.get_name(old_schema) for b in old_obj.get_implicit_ancestors(old_schema) ] graph_key = this_name_str else: raise AssertionError(f'unexpected op type: {op!r}') item = get_deps((tag, graph_key)) item.item = tuple(opbranch) item.deps |= deps item.extra = DepGraphEntryExtra( implicit_ancestors=[renames_r.get(a, a) for a in implicit_ancestors], ) def get_object( schema: s_schema.Schema, op: sd.ObjectCommand[so.Object], name: Optional[sn.Name] = None, ) -> so.Object: metaclass = op.get_schema_metaclass() if name is None: name = op.classname if issubclass(metaclass, s_types.Collection): if isinstance(name, sn.QualName): return schema.get(name) else: return schema.get_global(metaclass, name) elif not issubclass(metaclass, so.QualifiedObject): obj = schema.get_global(metaclass, name) assert isinstance(obj, so.Object) return obj else: return schema.get(name) def _get_referrers( schema: s_schema.Schema, obj: so.Object, strongrefs: dict[sn.Name, sn.Name], ) -> list[so.Object]: refs = schema.get_referrers(obj) result: set[so.Object] = set() for ref in refs: if not ref.is_blocking_ref(schema, obj): continue referrer: so.Object | None = None parent_ref = strongrefs.get(ref.get_name(schema)) if parent_ref is not None: referrer = schema.get(parent_ref, default=None) if not referrer or obj == referrer: referrer = ref result.add(referrer) return list(sorted( result, key=lambda o: (type(o).__name__, o.get_name(schema)), )) def _extract_op(stack: Sequence[sd.Command]) -> list[sd.Command]: parent_op = stack[0] new_stack = [parent_op] for stack_op in stack[1:-1]: assert isinstance(stack_op, 
            sd.ObjectCommand)
        alter_class = stack_op.get_other_command_class(sd.AlterObject)

        alter_delta = alter_class(
            classname=stack_op.classname,
            ddl_identity=stack_op.ddl_identity,
            aux_object_data=stack_op.aux_object_data,
            annotations=stack_op.annotations,
            canonical=stack_op.canonical,
            orig_cmd_type=type(stack_op),
        )
        parent_op.add(alter_delta)
        parent_op = alter_delta
        new_stack.append(parent_op)

    stack[-2].discard(stack[-1])
    parent_op.add(stack[-1])
    new_stack.append(stack[-1])

    return new_stack

================================================
FILE: edb/schema/permissions.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

from edb.edgeql import ast as qlast
from edb.edgeql import qltypes

from . import annos as s_anno
from . import delta as sd
from . import objects as so


class Permission(
    so.QualifiedObject,
    s_anno.AnnotationSubject,
    qlkind=qltypes.SchemaObjectClass.PERMISSION,
    data_safe=True,
):
    pass


class PermissionCommandContext(
    sd.ObjectCommandContext[Permission],
    s_anno.AnnotationSubjectCommandContext,
):
    pass


class PermissionCommand(
    sd.QualifiedObjectCommand[Permission],
    s_anno.AnnotationSubjectCommand[Permission],
    context_class=PermissionCommandContext,
):
    pass


class CreatePermission(
    PermissionCommand,
    sd.CreateObject[Permission],
):
    astnode = qlast.CreatePermission


class AlterPermission(
    PermissionCommand,
    sd.AlterObject[Permission],
):
    astnode = qlast.AlterPermission


class DeletePermission(
    PermissionCommand,
    sd.DeleteObject[Permission],
):
    astnode = qlast.DropPermission


class RenamePermission(
    PermissionCommand,
    sd.RenameObject[Permission],
):
    pass

================================================
FILE: edb/schema/pointers.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations
from typing import (
    Any,
    cast,
    Iterable,
    Optional,
    Self,
    Sequence,
    TYPE_CHECKING,
)

import abc
import collections.abc
import enum
import json
import operator
import dataclasses

from edb import errors

from edb.common import enum as s_enum
from edb.common import struct
from edb.common import parsing
from edb.common import ast
from edb.common.typeutils import not_none

from edb.edgeql import ast as qlast
from edb.edgeql import compiler as qlcompiler
from edb.edgeql import qltypes
from edb.edgeql import quote as qlquote

from . import annos as s_anno
from . import constraints
from . import delta as sd
from . import expr as s_expr
from . import expraliases as s_expraliases
from . import futures as s_futures
from . import inheriting
from . import name as sn
from . import objects as so
from . import referencing
from . import rewrites as s_rewrites
from . import schema as s_schema
from . import types as s_types
from . import scalars as s_scalars
from . import utils


if TYPE_CHECKING:
    from . import objtypes as s_objtypes
    from . import sources as s_sources
    from edb.ir import ast as irast


class PointerDirection(s_enum.StrEnum):
    Outbound = '>'
    Inbound = '<'


class LineageStatus(enum.Enum):
    VALID = 0
    MULTIPLE_COMPUTABLES = 1
    MIXED = 2


def merge_cardinality(
    ptr: Pointer,
    bases: list[Pointer],
    field_name: str,
    *,
    ignore_local: bool,
    schema: s_schema.Schema,
) -> Any:
    current: Optional[qltypes.SchemaCardinality] = None
    current_from = None

    if not ignore_local:
        current = ptr.get_explicit_field_value(schema, field_name, None)
        if current is not None:
            current_from = ptr

    for base in bases:
        # ignore abstract pointers
        if base.is_non_concrete(schema):
            continue

        nextval: Optional[qltypes.SchemaCardinality] = (
            base.get_field_value(schema, field_name))
        if nextval is None:
            continue

        if current is None:
            current = nextval
            current_from = base
        elif not current.is_known() and nextval is not None:
            current = nextval
            current_from = base
        elif current is not nextval:
            tgt_repr = ptr.get_verbosename(schema, with_parent=True)
            assert current_from is not None
            cf_repr = current_from.get_verbosename(schema, with_parent=True)
            other_repr = base.get_verbosename(schema, with_parent=True)

            if current.is_known():
                current_qual = f'defined as {current.as_ptr_qual()!r}'
            else:
                current_qual = 'unknown'

            if nextval.is_known():
                nextval_qual = f'defined as {nextval.as_ptr_qual()!r}'
            else:
                nextval_qual = 'unknown'

            raise errors.SchemaDefinitionError(
                f'cannot redefine the cardinality of '
                f'{tgt_repr}: it is {current_qual} in {cf_repr} and '
                f'is {nextval_qual} in {other_repr}.'
            )

    return current


def merge_readonly(
    target: Pointer,
    sources: list[Pointer],
    field_name: str,
    *,
    ignore_local: bool,
    schema: s_schema.Schema,
) -> Any:

    current = None
    current_from = None

    # The target field value is only relevant if it is explicit,
    # otherwise it should be based on the inherited value.
    if not ignore_local:
        current = target.get_explicit_field_value(schema, field_name, None)
        if current is not None:
            current_from = target

    for source in list(sources):
        # ignore abstract pointers
        if source.is_non_concrete(schema):
            continue

        # We want the field value including the default, not just
        # explicit value.
        nextval = source.get_field_value(schema, field_name)
        if nextval is not None:
            if current is None:
                current = nextval
                current_from = source
            elif current is not nextval:
                assert current_from is not None

                tgt_repr = target.get_verbosename(
                    schema, with_parent=True)
                cf_repr = current_from.get_verbosename(
                    schema, with_parent=True)
                other_repr = source.get_verbosename(
                    schema, with_parent=True)

                raise errors.SchemaDefinitionError(
                    f'cannot redefine the readonly flag of '
                    f'{tgt_repr}: it is defined '
                    f'as {current} in {cf_repr} and '
                    f'as {nextval} in {other_repr}.'
                )

    return current


def merge_required(
    ptr: Pointer,
    bases: list[Pointer],
    field_name: str,
    *,
    ignore_local: bool = False,
    schema: s_schema.Schema,
) -> Optional[bool]:
    """Merge function for the REQUIRED qualifier on links and properties."""

    local_required = ptr.get_explicit_local_field_value(
        schema, field_name, None)

    if ignore_local or local_required is None:
        # No explicit local declaration, so True if any of the bases
        # have it as required, and False otherwise.
        return utils.merge_reduce(
            ptr,
            bases,
            field_name=field_name,
            ignore_local=ignore_local,
            schema=schema,
            f=operator.or_,
            type=bool,
        )
    elif local_required:
        # If set locally and True, just use that.
        assert isinstance(local_required, bool)
        return local_required
    else:
        # Explicitly set locally as False, check if any of the bases
        # are REQUIRED, and if so, raise.
        for base in bases:
            base_required = base.get_field_value(schema, field_name)
            if base_required:
                ptr_repr = ptr.get_verbosename(schema, with_parent=True)
                base_repr = base.get_verbosename(schema, with_parent=True)
                raise errors.SchemaDefinitionError(
                    f'cannot make {ptr_repr} optional: its parent {base_repr} '
                    f'is defined as required'
                )

        return False


def merge_target(
    ptr: Pointer,
    bases: list[Pointer],
    field_name: str,
    *,
    ignore_local: bool = False,
    schema: s_schema.Schema,
) -> Optional[s_types.Type]:

    target = None
    current_source = None

    for base in bases:
        base_target = base.get_target(schema)
        if base_target is None:
            continue

        if target is None:
            target = base_target
            current_source = base.get_source(schema)
        else:
            assert current_source is not None
            source = base.get_source(schema)
            assert source is not None
            schema, target = _merge_types(
                schema,
                ptr,
                target,
                base_target,
                t1_source=current_source,
                t2_source=source,
                allow_contravariant=True,
            )

    if not ignore_local:
        local_target = ptr.get_target(schema)
        if target is None:
            target = local_target
        elif local_target is not None:
            assert current_source is not None
            schema, target = _merge_types(
                schema,
                ptr,
                target,
                local_target,
                t1_source=current_source,
                t2_source=None,
            )

    return target


def _merge_types(
    schema: s_schema.Schema,
    ptr: Pointer,
    t1: s_types.Type,
    t2: s_types.Type,
    *,
    t1_source: so.Object,
    t2_source: Optional[so.Object],
    allow_contravariant: bool = False,
) -> tuple[s_schema.Schema, Optional[s_types.Type]]:
    if t1 == t2:
        return schema, t1

    # When two pointers are merged, check target compatibility
    # and return a target that satisfies both specified targets.
    elif isinstance(t1, s_scalars.ScalarType) != isinstance(
        t2, s_scalars.ScalarType
    ):
        # Mixing a property with a link.
        vnp = ptr.get_verbosename(schema, with_parent=True)
        vn = ptr.get_verbosename(schema)
        t1_vn = t1.get_verbosename(schema)
        t2_vn = t2.get_verbosename(schema)
        t1_cls = 'property' if isinstance(t1, s_scalars.ScalarType) else 'link'
        t2_cls = 'property' if isinstance(t2, s_scalars.ScalarType) else 'link'

        t1_source_vn = t1_source.get_verbosename(schema, with_parent=True)
        if t2_source is None:
            raise errors.SchemaError(
                f'cannot redefine {vnp} as {t2_vn}',
                details=(
                    f'{vn} is defined as a {t1_cls} to {t1_vn} in'
                    f' parent {t1_source_vn}'
                ),
            )
        else:
            t2_source_vn = t2_source.get_verbosename(schema, with_parent=True)
            raise errors.SchemaError(
                f'inherited {vnp} has a type conflict',
                details=(
                    f'{vn} is defined as a {t1_cls} to {t1_vn} in'
                    f' parent {t1_source_vn} and as {t2_cls} in'
                    f' parent {t2_source_vn}'
                ),
            )

    else:
        assert isinstance(t1, so.SubclassableObject)
        assert isinstance(t2, so.SubclassableObject)

        if t2.issubclass(schema, t1):
            # The new target is a subclass of the current target, so
            # it is a more specific requirement.
            current_target = t2
        elif allow_contravariant and t1.issubclass(schema, t2):
            current_target = t1
        else:
            # The new target is not a subclass of the previously seen
            # targets, which creates an unresolvable target requirement
            # conflict.
            vnp = ptr.get_verbosename(schema, with_parent=True)
            vn = ptr.get_verbosename(schema)
            t1_vn = t1.get_verbosename(schema)
            t2_vn = t2.get_verbosename(schema)

            t1_source_vn = t1_source.get_verbosename(schema, with_parent=True)
            if t2_source is None:
                raise errors.SchemaError(
                    f'cannot redefine {vnp} as {t2_vn}',
                    details=(
                        f'{vn} is defined as {t1_vn} in'
                        f' parent {t1_source_vn}'
                    ),
                )
            else:
                t2_source_vn = t2_source.get_verbosename(
                    schema, with_parent=True)
                raise errors.SchemaError(
                    f'inherited {vnp} has a type conflict',
                    details=(
                        f'{vn} is defined as {t1_vn} in'
                        f' parent {t1_source_vn} and as {t2_vn} in'
                        f' parent {t2_source_vn}'
                    ),
                )

        return schema, current_target


def get_root_source(
    obj: Optional[so.Object], schema: s_schema.Schema
) -> Optional[so.Object]:
    while isinstance(obj, Pointer):
        obj = obj.get_source(schema)

    return obj


def is_view_source(
    source: Optional[so.Object], schema: s_schema.Schema
) -> bool:
    source = get_root_source(source, schema)
    return isinstance(source, s_types.Type) and source.is_view(schema)


def _get_target_name_in_diff(
    *,
    schema: s_schema.Schema,
    orig_schema: Optional[s_schema.Schema],
    object: Optional[so.Object],
    orig_object: Optional[so.Object],
) -> sn.Name:
    """Compute the target type name for a fill/conv expr

    Called from record_diff_annotations to produce annotation
    information for migrations.
    The trickiness here is that this information is generated
    when producing the diff, where we have somewhat limited
    information.
    """
    # Prefer getting the target type from the original object instead
    # of the new one, for a cheesy reason: if we change both
    # required/cardinality and target type, we do the cardinality
    # change before the cast, for reasons of alphabetical order.
    if isinstance(orig_object, Pointer):
        assert orig_schema
        target = orig_object.get_target(orig_schema)
        return not_none(target).get_name(orig_schema)
    else:
        assert isinstance(object, Pointer)
        target = object.get_target(schema)
        return not_none(target).get_name(schema)


class Pointer(
    referencing.NamedReferencedInheritingObject,
    constraints.ConsistencySubject,
    s_anno.AnnotationSubject,
):

    source = so.SchemaField(
        so.InheritingObject,
        default=None, compcoef=None,
        inheritable=False)

    target = so.SchemaField(
        s_types.Type,
        merge_fn=merge_target,
        default=None,
        compcoef=0.85,
        special_ddl_syntax=True,
    )

    required = so.SchemaField(
        bool,
        default=False,
        compcoef=0.909,
        special_ddl_syntax=True,
        describe_visibility=(
            so.DescribeVisibilityPolicy.SHOW_IF_EXPLICIT_OR_DERIVED
        ),
        merge_fn=merge_required,
    )

    readonly = so.SchemaField(
        bool,
        allow_ddl_set=True,
        describe_visibility=(
            so.DescribeVisibilityPolicy.SHOW_IF_EXPLICIT_OR_DERIVED_NOT_DEFAULT
        ),
        default=False,
        compcoef=0.909,
        merge_fn=merge_readonly,
    )

    splat_strategy = so.SchemaField(
        qltypes.SplatStrategy,
        allow_ddl_set=True,
        describe_visibility=(
            so.DescribeVisibilityPolicy.SHOW_IF_EXPLICIT_OR_DERIVED_NOT_DEFAULT
        ),
        coerce=True,
        default=qltypes.SplatStrategy.Default,
        compcoef=0.909,
    )

    secret = so.SchemaField(
        bool,
        default=False,
        compcoef=0.909,
    )

    protected = so.SchemaField(
        bool,
        default=False,
        compcoef=0.909,
    )

    linkful = so.SchemaField(
        bool,
        default=False,
        compcoef=0.99,
        inheritable=False,
    )

    # For non-derived pointers this is strongly correlated with
    # "expr" below.  Derived pointers might have "computable" set,
    # but expr=None.
    computable = so.SchemaField(
        bool,
        default=False,
        compcoef=0.99,
    )

    # True, if this pointer is defined in an Alias.
    from_alias = so.SchemaField(
        bool,
        default=None,
        compcoef=0.99,
        # This value needs to be recorded in the delta commands
        # to signal that we don't want to render this command in DDL.
        aux_cmd_data=True,
    )

    # Is this pointer a "definition site" of some kind or just a
    # trivial inheritor. Used to determine whether to use this pointer
    # or a parent when computing path ids.
    defined_here = so.SchemaField(
        bool, inheritable=False, ephemeral=True, default=False)

    # Computable pointers have this set to an expression
    # defining them.
    expr = so.SchemaField(
        s_expr.Expression,
        default=None,
        coerce=True,
        compcoef=0.909,
        special_ddl_syntax=True,
    )

    default = so.SchemaField(
        s_expr.Expression,
        allow_ddl_set=True,
        describe_visibility=(
            so.DescribeVisibilityPolicy.SHOW_IF_EXPLICIT_OR_DERIVED
        ),
        default=None,
        coerce=True,
        compcoef=0.909,
    )

    cardinality = so.SchemaField(
        qltypes.SchemaCardinality,
        default=qltypes.SchemaCardinality.One,
        compcoef=0.833,
        coerce=True,
        special_ddl_syntax=True,
        describe_visibility=(
            so.DescribeVisibilityPolicy.SHOW_IF_EXPLICIT_OR_DERIVED
        ),
        merge_fn=merge_cardinality,
    )

    union_of = so.SchemaField(
        so.ObjectSet['Pointer'],
        default=None,
        coerce=True,
        type_is_generic_self=True,
    )

    intersection_of = so.SchemaField(
        so.ObjectSet['Pointer'],
        default=None,
        coerce=True,
        type_is_generic_self=True,
    )

    computed_link_alias_is_backward = so.SchemaField(
        bool,
        default=None,
        compcoef=0.99,
    )

    computed_link_alias = so.SchemaField(
        so.Object,
        default=None,
        compcoef=0.99,
    )

    rewrites_refs = so.RefDict(
        attr="rewrites",
        requires_explicit_overloaded=True,
        backref_attr="subject",
        ref_cls=s_rewrites.Rewrite,
    )

    rewrites = so.SchemaField(
        so.ObjectIndexByUnqualifiedName[s_rewrites.Rewrite],
        inheritable=False,
        ephemeral=True,
        coerce=True,
        compcoef=0.857,
        default=so.DEFAULT_CONSTRUCTOR,
    )

    def is_tuple_indirection(self) -> bool:
        return False

    def is_type_intersection(self) -> bool:
        return False

    def is_generated(self, schema: s_schema.Schema) -> bool:
        return bool(self.get_from_alias(schema))

    def get_subject(self, schema: s_schema.Schema) -> Optional[so.Object]:
        # Required by ReferencedObject
        return self.get_source(schema)

    @classmethod
    def get_displayname_static(cls, name: sn.Name) -> str:
        sn = cls.get_shortname_static(name)
        if sn.module == '__':
            return sn.name
        else:
            return
str(sn) def get_verbosename( self, schema: s_schema.Schema, *, with_parent: bool=False, ) -> str: vn = super().get_verbosename(schema) if self.is_non_concrete(schema): return f'abstract {vn}' else: if with_parent: source = self.get_source(schema) assert source is not None pvn = source.get_verbosename( schema, with_parent=True) return f'{vn} of {pvn}' else: return vn def is_scalar(self) -> bool: return False def material_type( self, schema: s_schema.Schema, ) -> tuple[s_schema.Schema, Pointer]: non_derived_parent = self.get_nearest_non_derived_parent(schema) source = non_derived_parent.get_source(schema) if source is None: return schema, self else: return schema, non_derived_parent def get_nearest_defined(self, schema: s_schema.Schema) -> Pointer: """ Find the pointer definition site. For view pointers, find the place where the pointer is "really" defined, that is, either its schema definition site or where it last had an expression defining it. """ ptrcls = self while ( ptrcls.get_is_derived(schema) and not ptrcls.get_defined_here(schema) # schema-defined computeds don't have the ephemeral defined_here # set, but they do have expr set, so we check that also.
and not ptrcls.get_expr(schema) and (bases := ptrcls.get_bases(schema).objects(schema)) and len(bases) == 1 and bases[0].get_source(schema) ): ptrcls = bases[0] return ptrcls def get_near_endpoint( self, schema: s_schema.Schema, direction: PointerDirection, ) -> Optional[so.Object]: if direction == PointerDirection.Outbound: return self.get_source(schema) else: return self.get_target(schema) def get_far_endpoint( self, schema: s_schema.Schema, direction: PointerDirection, ) -> Optional[so.Object]: if direction == PointerDirection.Outbound: return self.get_target(schema) else: return self.get_source(schema) def set_target( self, schema: s_schema.Schema, target: s_types.Type, ) -> s_schema.Schema: return self.set_field_value(schema, 'target', target) def get_derived( self: Self, schema: s_schema.Schema, source: s_sources.Source, target: s_types.Type, *, derived_name_base: Optional[sn.Name] = None, **kwargs: Any ) -> tuple[s_schema.Schema, Self]: fqname = self.derive_name( schema, source, derived_name_base=derived_name_base) ptr = schema.get(fqname, default=None) if ptr is None: fqname = self.derive_name( schema, source, str(target.get_name(schema)), derived_name_base=derived_name_base, ) ptr = schema.get(fqname, default=None) if ptr is None: schema, ptr = self.derive_ref( schema, source, target=target, derived_name_base=derived_name_base, **kwargs) return schema, ptr # type: ignore def get_derived_name_base( self, schema: s_schema.Schema, ) -> sn.QualName: shortname = self.get_shortname(schema) return sn.QualName(module='__', name=shortname.name) def derive_ref( self, schema: s_schema.Schema, referrer: so.QualifiedObject, *qualifiers: str, target: Optional[s_types.Type] = None, mark_derived: bool = False, attrs: Optional[dict[str, Any]] = None, dctx: Optional[sd.CommandContext] = None, **kwargs: Any, ) -> tuple[s_schema.Schema, Pointer]: if target is None: if attrs and 'target' in attrs: target = attrs['target'] else: target = self.get_target(schema) if attrs is 
None: attrs = {} attrs['source'] = referrer attrs['target'] = target return super().derive_ref( schema, referrer, mark_derived=mark_derived, dctx=dctx, attrs=attrs, **kwargs) def is_pure_computable(self, schema: s_schema.Schema) -> bool: return bool(self.get_expr(schema)) or bool(self.get_computable(schema)) def is_id_pointer(self, schema: s_schema.Schema) -> bool: local_name = self.get_local_name(schema) if local_name.name != 'id': return False from edb.schema import sources as s_sources std_base = schema.get('std::BaseObject', type=s_sources.Source) std_id = std_base.getptr(schema, sn.UnqualName('id')) assert isinstance(std_id, so.SubclassableObject) return self.issubclass(schema, std_id) def is_link_source_property(self, schema: s_schema.Schema) -> bool: std_source = schema.get('std::source', type=so.SubclassableObject) return self.issubclass(schema, std_source) def is_link_target_property(self, schema: s_schema.Schema) -> bool: std_target = schema.get('std::target', type=so.SubclassableObject) return self.issubclass(schema, std_target) def is_endpoint_pointer(self, schema: s_schema.Schema) -> bool: std_source = schema.get('std::source', type=so.SubclassableObject) std_target = schema.get('std::target', type=so.SubclassableObject) return self.issubclass(schema, (std_source, std_target)) def is_special_pointer(self, schema: s_schema.Schema) -> bool: return self.get_shortname(schema).name in { 'source', 'target', 'id' } and (self.is_id_pointer(schema) or self.is_endpoint_pointer(schema)) @classmethod def is_property(cls) -> bool: # Property overloads return False def is_link_property(self, schema: s_schema.Schema) -> bool: raise NotImplementedError def is_dumpable(self, schema: s_schema.Schema) -> bool: return ( not self.is_pure_computable(schema) and not self.get_shortname(schema).name == '__type__' ) def is_non_concrete(self, schema: s_schema.Schema) -> bool: return self.get_source(schema) is None def get_referrer(self, schema: s_schema.Schema) -> 
Optional[so.Object]: return self.get_source(schema) def get_exclusive_constraints( self, schema: s_schema.Schema ) -> Sequence[constraints.Constraint]: if self.is_non_concrete(schema): raise ValueError(f'{self!r} is not a concrete pointer') exclusive = schema.get('std::exclusive', type=constraints.Constraint) ptr = self.get_nearest_non_derived_parent(schema) constrs = [] for constr in ptr.get_constraints(schema).objects(schema): if ( constr.issubclass(schema, exclusive) and not constr.get_subjectexpr(schema) and not constr.get_delegated(schema) ): assert not constr.get_except_expr(schema) constrs.append(constr) return constrs def is_exclusive(self, schema: s_schema.Schema) -> bool: return bool(self.get_exclusive_constraints(schema)) def singular( self, schema: s_schema.Schema, direction: PointerDirection = PointerDirection.Outbound, ) -> bool: # Determine the cardinality of a given endpoint set. if direction == PointerDirection.Outbound: cardinality = self.get_cardinality(schema) if cardinality is None or not cardinality.is_known(): vn = self.get_verbosename(schema, with_parent=True) raise AssertionError(f'cardinality of {vn} is unknown') return cardinality.is_single() else: return self.is_exclusive(schema) def get_implicit_bases(self, schema: s_schema.Schema) -> list[Pointer]: bases = super().get_implicit_bases(schema) # True implicit bases for pointers will have the same name my_name = self.get_shortname(schema) return [ b for b in bases if b.get_shortname(schema) == my_name ] def get_implicit_ancestors(self, schema: s_schema.Schema) -> list[Pointer]: ancestors = super().get_implicit_ancestors(schema) # True implicit ancestors for pointers will have the same name my_name = self.get_shortname(schema) return [ b for b in ancestors if b.get_shortname(schema) == my_name ] def has_user_defined_properties(self, schema: s_schema.Schema) -> bool: return False def allow_ref_propagation( self, schema: s_schema.Schema, context: sd.CommandContext, refdict: so.RefDict, ) -> 
bool: object_type = self.get_source(schema) if isinstance(object_type, s_types.Type): return ( not object_type.is_view(schema) or refdict.attr == 'pointers') else: return True def get_schema_reflection_default( self, schema: s_schema.Schema, ) -> Optional[str]: """Return the default expression if this is a reflection of a schema class field and the field has a defined default value. """ ptr = self.get_nearest_non_derived_parent(schema) src = ptr.get_source(schema) if src is None: # This is an abstract pointer return None ptr_name = ptr.get_name(schema) if ptr_name.module not in {'schema', 'sys', 'cfg'}: # This isn't a reflection type return None if isinstance(src, Pointer): # This is a link property tgt = src.get_target(schema) assert tgt is not None schema_objtype = tgt else: assert isinstance(src, s_types.Type) schema_objtype = src assert isinstance(schema_objtype, so.QualifiedObject) src_name = schema_objtype.get_name(schema) mcls = so.ObjectMeta.maybe_get_schema_class(src_name.name) if mcls is None: # This schema class is not (publicly) reflected. return None fname = ptr.get_shortname(schema).name if not mcls.has_field(fname): # This pointer is not a schema field. return None field = mcls.get_field(fname) if not isinstance(field, so.SchemaField): # Not a schema field, no default possible. return None f_default = field.default if ( f_default is None or f_default is so.NoDefault ): # No explicit default value. 
return None tgt = ptr.get_target(schema) assert tgt is not None if f_default is so.DEFAULT_CONSTRUCTOR: if ( issubclass( field.type, (collections.abc.Set, collections.abc.Sequence), ) and not issubclass(field.type, (str, bytes)) ): return f'<{tgt.get_displayname(schema)}>[]' else: return None default = qlquote.quote_literal(json.dumps(f_default)) if tgt.is_enum(schema): return f'<{tgt.get_displayname(schema)}>to_json({default})' else: return f'<{tgt.get_displayname(schema)}>to_json({default})' def as_create_delta( self, schema: s_schema.Schema, context: so.ComparisonContext, ) -> sd.CreateObject[Pointer]: delta = super().as_create_delta(schema, context) # When we are creating a new required property on an existing type, # we need to generate an AlterPointerLowerCardinality so that we can # attach a USING to it. if ( context.parent_ops and isinstance(context.parent_ops[-1], sd.AlterObject) and self.get_required(schema) and not self.get_default(schema) and not self.get_computable(schema) and not self.is_link_property(schema) and (required := delta._get_attribute_set_cmd('required')) ): special = sd.get_special_field_alter_handler( 'required', type(self)) assert special top_op = special(classname=delta.classname) delta.replace(required, top_op) top_op.add(required) context.parent_ops.append(delta) top_op.record_diff_annotations( schema=schema, orig_schema=None, object=self, orig_object=None, context=context, ) context.parent_ops.pop() return delta def get_local_rewrite( self, schema: s_schema.Schema, kind: qltypes.RewriteKind ) -> Optional[s_rewrites.Rewrite]: rewrites = self.get_rewrites(schema) if rewrites: for rewrite in rewrites.objects(schema): if rewrite.get_kind(schema) == kind: return rewrite return None def get_rewrite( self, schema: s_schema.Schema, kind: qltypes.RewriteKind ) -> Optional[s_rewrites.Rewrite]: if rw := self.get_local_rewrite(schema, kind): return rw for anc in self.get_ancestors(schema).objects(schema): if rw := anc.get_local_rewrite(schema,
kind): return rw return None class PseudoPointer(abc.ABC): # An abstract base class for pointer-like objects, i.e. # pseudo-links used by the compiler to represent things like # tuple and type intersection. def is_tuple_indirection(self) -> bool: return False def is_type_intersection(self) -> bool: return False def get_bases(self, schema: s_schema.Schema) -> so.ObjectList[Pointer]: return so.ObjectList.create(schema, []) def get_ancestors(self, schema: s_schema.Schema) -> so.ObjectList[Pointer]: return so.ObjectList.create(schema, []) @abc.abstractmethod def get_name(self, schema: s_schema.Schema) -> sn.QualName: raise NotImplementedError def get_shortname(self, schema: s_schema.Schema) -> sn.QualName: return self.get_name(schema) def get_displayname(self, schema: s_schema.Schema) -> str: return str(self.get_name(schema)) def has_user_defined_properties(self, schema: s_schema.Schema) -> bool: return False def get_required(self, schema: s_schema.Schema) -> bool: return True @abc.abstractmethod def get_cardinality( self, schema: s_schema.Schema ) -> qltypes.SchemaCardinality: raise NotImplementedError def get_path_id_name(self, schema: s_schema.Schema) -> sn.QualName: return self.get_name(schema) def get_is_derived(self, schema: s_schema.Schema) -> bool: return False def get_owned(self, schema: s_schema.Schema) -> bool: return True def get_union_of( self, schema: s_schema.Schema, ) -> None: return None def get_intersection_of( self, schema: s_schema.Schema, ) -> None: return None def get_default( self, schema: s_schema.Schema, ) -> Optional[s_expr.Expression]: return None def get_expr(self, schema: s_schema.Schema) -> Optional[s_expr.Expression]: return None @abc.abstractmethod def get_source(self, schema: s_schema.Schema) -> so.Object: raise NotImplementedError @abc.abstractmethod def get_target(self, schema: s_schema.Schema) -> s_types.Type: raise NotImplementedError def get_near_endpoint( self, schema: s_schema.Schema, direction: PointerDirection, ) -> so.Object: 
if direction is PointerDirection.Outbound: return self.get_source(schema) else: raise AssertionError( f'inbound direction is not valid for {type(self)}' ) def get_far_endpoint( self, schema: s_schema.Schema, direction: PointerDirection, ) -> so.Object: if direction is PointerDirection.Outbound: return self.get_target(schema) else: raise AssertionError( f'inbound direction is not valid for {type(self)}' ) def is_link_property(self, schema: s_schema.Schema) -> bool: return False def is_non_concrete(self, schema: s_schema.Schema) -> bool: return False @abc.abstractmethod def singular( self, schema: s_schema.Schema, direction: PointerDirection = PointerDirection.Outbound, ) -> bool: raise NotImplementedError def material_type( self, schema: s_schema.Schema, ) -> tuple[s_schema.Schema, PseudoPointer]: return schema, self def is_pure_computable(self, schema: s_schema.Schema) -> bool: return False def is_exclusive(self, schema: s_schema.Schema) -> bool: return False def get_schema_reflection_default( self, schema: s_schema.Schema, ) -> Optional[str]: return None PointerLike = Pointer | PseudoPointer @dataclasses.dataclass(repr=False, eq=False) class ComputableRef: """A shell for a computed target type.""" expr: qlast.Expr specified_type: Optional[s_types.TypeShell[s_types.Type]] = ( dataclasses.field(default=None) ) class PointerCommandContext( sd.ObjectCommandContext[Pointer], s_anno.AnnotationSubjectCommandContext, s_rewrites.RewriteSubjectCommandContext, ): pass class PointerCommandOrFragment[Pointer_T: Pointer]( referencing.ReferencedObjectCommandBase[Pointer_T] ): def is_property_command(self) -> bool: return self.get_schema_metaclass().is_property() def canonicalize_attributes( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().canonicalize_attributes(schema, context) target_ref = self.get_local_attribute_value('target') inf_target_ref: Optional[s_types.TypeShell[s_types.Type]] # When cardinality/required is altered, we 
need to force a # reconsideration of expr if it exists in order to check # it against the new specifier or compute them on a # RESET. This is kind of unfortunate. if ( isinstance(self, sd.AlterObject) and ( ( self.has_attribute_value('cardinality') and not self.is_attribute_inherited('cardinality') ) or ( self.has_attribute_value('required') and not self.is_attribute_inherited('required') ) ) and not self.has_attribute_value('expr') and (expr := self.scls.get_expr(schema)) is not None ): self.set_attribute_value( 'expr', s_expr.Expression.not_compiled(expr) ) if isinstance(target_ref, ComputableRef): schema, inf_target_ref = self._parse_computable( target_ref.expr, schema, context) elif (expr := self.get_local_attribute_value('expr')) is not None: assert isinstance(expr, s_expr.Expression) schema = s_types.materialize_type_in_attribute( schema, context, self, 'target') schema, inf_target_ref = self._parse_computable( expr.parse(), schema, context) else: inf_target_ref = None if ( isinstance(self, sd.CreateObject) and not self.is_property_command() ): self.set_attribute_value('linkful', True) if inf_target_ref is not None: if inf_target_ref.has_intersection(): raise errors.UnsupportedFeatureError( ( f'unsupported type intersection in schema ' f'{inf_target_ref.get_name(schema).name}' ), hint=( f'Type intersections are currently ' f'unsupported as valid link targets.' ), span=self.span, ) span = self.get_attribute_span('target') self.set_attribute_value( 'target', inf_target_ref, span=span, computed=True, ) schema = s_types.materialize_type_in_attribute( schema, context, self, 'target') expr = self.get_local_attribute_value('expr') if expr is not None: # There is an expression, therefore it is a computable. 
self.set_attribute_value('computable', True) return schema def _parse_computable( self, expr: qlast.Expr, schema: s_schema.Schema, context: sd.CommandContext, ) -> tuple[ s_schema.Schema, s_types.TypeShell[s_types.Type], ]: from edb.ir import ast as irast from edb.ir import typeutils as irtyputils from edb.ir import utils as irutils # "source" attribute is set automatically as a refdict back-attr parent_ctx = self.get_referrer_context(context) assert parent_ctx is not None source_name = context.get_referrer_name(parent_ctx) assert isinstance(source_name, sn.QualName) source = schema.get(source_name) parent_vname = source.get_verbosename(schema) ptr_name = self.get_verbosename(parent=parent_vname) expression = self.compile_expr_field( schema, context, field=Pointer.get_field('expr'), value=s_expr.Expression.from_ast(expr, schema, context.modaliases), ) target = expression.irast.stype target_shell = target.as_shell(expression.irast.schema) if ( isinstance(target_shell, s_types.UnionTypeShell) and target_shell.opaque ): target = schema.get('std::BaseObject', type=s_types.Type) target_shell = target.as_shell(schema) orig_expr = expression.irast.expr if isinstance(orig_expr, irast.Set): orig_expr = irutils.unwrap_set(orig_expr) result_expr = orig_expr if isinstance(result_expr, irast.Set): if isinstance(result_expr.expr, irast.Pointer): result_expr, _ = irutils.collapse_type_intersection( result_expr) if self.is_property_command(): self.set_attribute_value('linkful', irutils.is_linkful(orig_expr)) # Process a computable pointer which potentially could be an # aliased link that should inherit link properties. 
computed_link_alias = None computed_link_alias_is_backward = None if ( isinstance(result_expr, irast.Set) and isinstance(result_expr.expr, irast.Pointer) and (expr_rptr := result_expr.expr) and expr_rptr.direction is PointerDirection.Outbound and not isinstance(expr_rptr.source.expr, irast.Pointer) and isinstance(expr_rptr.ptrref, irast.PointerRef) and schema.has_object(expr_rptr.ptrref.id) ): new_schema, aliased_ptr = irtyputils.ptrcls_from_ptrref( expr_rptr.ptrref, schema=schema ) # Only pointers coming from the same source as the # alias should be "inherited" (in order to preserve # link props). Random paths coming from other sources # get treated the same as any other arbitrary expression # in a computable. if ( aliased_ptr.get_source(new_schema) == source and isinstance(aliased_ptr, self.get_schema_metaclass()) ): schema = new_schema computed_link_alias = aliased_ptr computed_link_alias_is_backward = False # Do similar logic, but in reverse, to see if the computed pointer # is a computed backlink that we need to keep track of.
if ( computed_link_alias is None and isinstance(orig_expr, irast.Set) and isinstance(orig_expr.expr, irast.Pointer) and isinstance( orig_expr.expr.ptrref, irast.TypeIntersectionPointerRef) and len(orig_expr.expr.ptrref.rptr_specialization) == 1 and expr_rptr and expr_rptr.direction is not PointerDirection.Outbound ): ptrref = list(orig_expr.expr.ptrref.rptr_specialization)[0] new_schema, aliased_ptr = irtyputils.ptrcls_from_ptrref( ptrref, schema=schema ) if ( aliased_ptr.get_target(new_schema) == source and not ptrref.out_source.is_opaque_union and isinstance(aliased_ptr, self.get_schema_metaclass()) ): computed_link_alias_is_backward = True computed_link_alias = aliased_ptr schema = new_schema self.set_attribute_value('computed_link_alias', computed_link_alias) self.set_attribute_value( 'computed_link_alias_is_backward', computed_link_alias_is_backward) self.set_attribute_value('expr', expression) required, card = expression.irast.cardinality.to_schema_value() # Disallow referring to aliases from computed pointers. # We will support this eventually but it is pretty broken now # and best to consistently give an understandable error. 
for schema_ref in expression.irast.schema_refs: if isinstance(schema_ref, s_expraliases.Alias): span = self.get_attribute_span('target') an = schema_ref.get_verbosename(expression.irast.schema) raise errors.UnsupportedFeatureError( f'referring to {an} from computed {ptr_name} ' f'is unsupported', span=span, ) if ( not isinstance(source, Pointer) and not source.is_view(schema) # type: ignore and target.is_view(expression.irast.schema) ): raise errors.UnsupportedFeatureError( f'including a shape on schema-defined computed links ' f'is not yet supported', span=self.span, ) spec_target: Optional[ s_types.TypeShell[s_types.Type] | s_types.Type | ComputableRef ] = ( self.get_specified_attribute_value('target', schema, context)) spec_required: Optional[bool] = ( self.get_specified_attribute_value('required', schema, context)) spec_card: Optional[qltypes.SchemaCardinality] = ( self.get_specified_attribute_value('cardinality', schema, context)) if ( spec_target is not None and ( not isinstance(spec_target, ComputableRef) or (spec_target := spec_target.specified_type) is not None ) ): if isinstance(spec_target, s_types.TypeShell): spec_target_type = spec_target.resolve(schema) else: spec_target_type = spec_target mschema, inferred_target_type = target.material_type( expression.irast.schema) if spec_target_type != inferred_target_type: span = self.get_attribute_span('target') raise errors.SchemaDefinitionError( f'the type inferred from the expression ' f'of the computed {ptr_name} ' f'is {inferred_target_type.get_verbosename(mschema)}, ' f'which does not match the explicitly specified ' f'{spec_target_type.get_verbosename(schema)}', span=span ) if spec_required and not required: span = self.get_attribute_span('target') raise errors.SchemaDefinitionError( f'possibly an empty set returned by an ' f'expression for the computed ' f'{ptr_name} ' f"explicitly declared as 'required'", span=span ) if ( spec_card is qltypes.SchemaCardinality.One and card is not 
qltypes.SchemaCardinality.One ): span = self.get_attribute_span('target') raise errors.SchemaDefinitionError( f'possibly more than one element returned by an ' f'expression for the computed ' f'{ptr_name} ' f"explicitly declared as 'single'", span=span ) if spec_card is None: self.set_attribute_value('cardinality', card, computed=True) if spec_required is None: self.set_attribute_value('required', required, computed=True) if ( not is_view_source(source, schema) and expression.irast.volatility.is_volatile() ): span = self.get_attribute_span('target') raise errors.SchemaDefinitionError( f'volatile functions are not permitted in schema-defined ' f'computed expressions', span=span ) self.set_attribute_value('computable', True) return schema, target_shell def _compile_expr( self, schema: s_schema.Schema, context: sd.CommandContext, expr: s_expr.Expression, *, in_ddl_context_name: Optional[str] = None, track_schema_ref_exprs: bool = False, singleton_result_expected: bool = False, target_as_singleton: bool = False, expr_description: Optional[str] = None, no_query_rewrites: bool = False, make_globals_empty: bool = False, span: Optional[parsing.Span] = None, detached: bool = False, should_set_path_prefix_anchor: bool = True ) -> s_expr.CompiledExpression: singletons: list[s_types.Type | Pointer] = [] parent_ctx = self.get_referrer_context_or_die(context) source = parent_ctx.op.get_object(schema, context) if ( isinstance(source, Pointer) and not source.get_source(schema) ): # If the source is an abstract link, we need to # make up an object and graft the link onto it, # because the compiler really does not know what # to make of a link without a source or target. 
from edb.schema import objtypes as s_objtypes base_obj = schema.get( s_objtypes.ObjectType.get_default_base_name(), type=s_objtypes.ObjectType ) schema, view = base_obj.derive_subtype( schema, name=sn.QualName("__derived__", "FakeAbstractLinkBase"), mark_derived=True, transient=True, ) schema, source = source.derive_ref( schema, view, target=view, mark_derived=True, transient=True, ) assert isinstance(source, (s_types.Type, Pointer)) singletons = [source] if target_as_singleton: src = self.scls.get_source(schema) if isinstance(src, Pointer): # linkprop singletons.append(src) else: singletons.append(self.scls) with errors.ensure_span(span or expr.parse().span): options = qlcompiler.CompilerOptions( modaliases=context.modaliases, schema_object_context=self.get_schema_metaclass(), anchors={'__source__': source}, path_prefix_anchor=( '__source__' if should_set_path_prefix_anchor else None), singletons=singletons, apply_query_rewrites=( not context.stdmode and not no_query_rewrites ), make_globals_empty=make_globals_empty, track_schema_ref_exprs=track_schema_ref_exprs, in_ddl_context_name=in_ddl_context_name, ) compiled = expr.compiled( schema=schema, options=options, detached=detached, context=context, ) if singleton_result_expected and compiled.cardinality.is_multi(): if expr_description is None: expr_description = 'an expression' raise errors.SchemaError( f'possibly more than one element returned by ' f'{expr_description}, while a singleton is expected' ) return compiled def compile_expr_field( self, schema: s_schema.Schema, context: sd.CommandContext, field: so.Field[Any], value: s_expr.Expression, track_schema_ref_exprs: bool=False, ) -> s_expr.CompiledExpression: if field.name in {'default', 'expr'}: if field.name == 'expr': parent_ctx = self.get_referrer_context_or_die(context) source = parent_ctx.op.get_object(schema, context) parent_vname = source.get_verbosename(schema) ptr_name = self.get_verbosename(parent=parent_vname) in_ddl_context_name = f'computed 
{ptr_name}' detached = False else: in_ddl_context_name = None detached = True # If we are in a link property's default field # do not set path prefix anchor, because link properties # cannot have defaults that reference the object being inserted should_set_path_prefix_anchor = True if field.name == 'default': # We are checking if the parent context is a pointer # (i.e. a link or a property). # If so, do not set the path prefix anchor. parent_ctx = self.get_referrer_context_or_die(context) source = parent_ctx.op.get_object(schema, context) if isinstance(source, Pointer): should_set_path_prefix_anchor = False return self._compile_expr( schema, context, value, in_ddl_context_name=in_ddl_context_name, track_schema_ref_exprs=track_schema_ref_exprs, detached=detached, should_set_path_prefix_anchor=should_set_path_prefix_anchor, ) else: return super().compile_expr_field( schema, context, field, value, track_schema_ref_exprs) def get_dummy_expr_field_value( self, schema: s_schema.Schema, context: sd.CommandContext, field: so.Field[Any], value: Any, ) -> Optional[s_expr.Expression]: if field.name == 'expr': return None elif field.name == 'default': return None else: raise NotImplementedError(f'unhandled field {field.name!r}') class PointerCommand[Pointer_T: Pointer]( referencing.NamedReferencedInheritingObjectCommand[Pointer_T], constraints.ConsistencySubjectCommand[Pointer_T], s_anno.AnnotationSubjectCommand[Pointer_T], PointerCommandOrFragment[Pointer_T], ): def _validate_computables( self, schema: s_schema.Schema, context: sd.CommandContext ) -> None: scls = self.scls if scls.get_from_alias(schema): return is_computable = scls.is_pure_computable(schema) is_owned = scls.get_owned(schema) if is_computable: if any( b.is_non_concrete(schema) and str(b.get_name(schema)) not in ( 'std::link', 'std::property') for b in scls.get_bases(schema).objects(schema) ): raise errors.SchemaDefinitionError( f'it is illegal for the computed ' f'{scls.get_verbosename(schema, 
with_parent=True)} ' f'to extend an abstract ' f'{scls.get_schema_class_displayname()}', span=self.span, ) # Get the non-generic, explicitly declared ancestors as the # limitations on computables apply to explicitly declared # pointers, not just a long chain of inherited ones. # # Because this is potentially nested inside a command to # delete a property, some ancestors may not be present in the # schema anymore, so we will only consider the ones that still # are (which should still be valid). lineage: list[Pointer_T] = [] for iid in scls.get_ancestors(schema)._ids: try: p = cast(Pointer_T, schema.get_by_id(iid)) if not p.is_non_concrete(schema) and p.get_owned(schema): lineage.append(p) except errors.InvalidReferenceError: pass if is_owned: # If the current pointer is explicitly declared, add it at # the head (index 0) of the lineage. lineage.insert(0, scls) status = self._validate_lineage(schema, lineage) if status is LineageStatus.VALID: return if is_computable and is_owned: # Overloading with a computable raise errors.SchemaDefinitionError( f'it is illegal for the computed ' f'{scls.get_verbosename(schema, with_parent=True)} ' f'to overload an existing ' f'{scls.get_schema_class_displayname()}', span=self.span, ) else: if status is LineageStatus.MIXED: raise errors.SchemaDefinitionError( f'it is illegal for the ' f'{scls.get_verbosename(schema, with_parent=True)} ' f'to extend both a computed and a non-computed ' f'{scls.get_schema_class_displayname()}', span=self.span, ) elif status is LineageStatus.MULTIPLE_COMPUTABLES: raise errors.SchemaDefinitionError( f'it is illegal for the ' f'{scls.get_verbosename(schema, with_parent=True)} ' f'to extend more than one computed ' f'{scls.get_schema_class_displayname()}', span=self.span, ) def _validate_lineage( self, schema: s_schema.Schema, lineage: list[Pointer_T], ) -> LineageStatus: if len(lineage) <= 1: # Having at most 1 item in the lineage is always valid.
return LineageStatus.VALID head, *rest = lineage if not head.is_pure_computable(schema): # The rest of the lineage must all be regular if any(b.is_pure_computable(schema) for b in rest): return LineageStatus.MIXED else: return LineageStatus.VALID else: # We have a computable with some non-empty lineage. Which # could be valid only if this is some aliasing followed by # regular pointers only. prev_shortname = head.get_shortname(schema) prev_is_comp = True for b in rest: cur_is_comp = b.is_pure_computable(schema) cur_shortname = b.get_shortname(schema) if prev_is_comp: # Computables cannot overload, but they can alias # other pointers, however aliases cannot have # matching shortnames. if cur_shortname == prev_shortname: # Names match, so this is illegal. if cur_is_comp: return LineageStatus.MULTIPLE_COMPUTABLES else: return LineageStatus.MIXED else: # Only regular pointers expected from here on. if cur_is_comp: return LineageStatus.MULTIPLE_COMPUTABLES prev_shortname = cur_shortname prev_is_comp = cur_is_comp # Did not find anything wrong with the computable lineage. 
return LineageStatus.VALID def validate_object( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: """Check that pointer definition is sound.""" from edb.ir import ast as irast referrer_ctx = self.get_referrer_context(context) if referrer_ctx is None: return self._validate_computables(schema, context) scls: Pointer = self.scls if not scls.get_owned(schema): return default_expr: Optional[s_expr.Expression] = scls.get_default(schema) if default_expr is not None: if not default_expr.irast: default_expr = self._compile_expr( schema, context, default_expr, detached=True, ) assert default_expr.irast if scls.is_id_pointer(schema): self._check_id_default( schema, context, default_expr.irast.expr) span = self.get_attribute_span('default') ir = default_expr.irast default_schema = ir.schema default_type = ir.stype assert default_type is not None ptr_target = scls.get_target(schema) assert ptr_target is not None if ( default_type.is_view(default_schema) # Using an alias/global always creates a new subtype view, # but we want to allow those here, so check whether there # is a shape more directly. 
and not ( len(shape := ir.view_shapes.get(default_type, [])) == 1 and shape[0].is_id_pointer(default_schema) ) ): raise errors.SchemaDefinitionError( f'default expression may not include a shape', span=span, ) if not default_type.assignment_castable_to( ptr_target, default_schema): raise errors.SchemaDefinitionError( f'default expression is of invalid type: ' f'{default_type.get_displayname(default_schema)}, ' f'expected {ptr_target.get_displayname(schema)}', span=span, ) # "required" status of defaults should not be enforced # because it's impossible to actually guarantee that any # SELECT involving a path is non-empty ptr_cardinality = scls.get_cardinality(schema) _default_required, default_cardinality = \ default_expr.irast.cardinality.to_schema_value() if (ptr_cardinality is qltypes.SchemaCardinality.One and default_cardinality != ptr_cardinality): raise errors.SchemaDefinitionError( f'possibly more than one element returned by ' f'the default expression for ' f'{scls.get_verbosename(schema)} declared as ' f"'single'", span=span, ) # prevent references to local links, only properties pointers = ast.find_children(default_expr.irast, irast.Pointer) scls_source = scls.get_source(schema) assert scls_source for pointer in pointers: if pointer.source.typeref.id != scls_source.id: continue if not isinstance(pointer.ptrref, irast.PointerRef): continue s_pointer = schema.get_by_id(pointer.ptrref.id, type=Pointer) card = s_pointer.get_cardinality(schema) if s_pointer.is_property() and card.is_multi(): raise errors.SchemaDefinitionError( f"default expression cannot refer to multi properties " "of inserted object", span=span, hint="this is a temporary implementation restriction", ) if not s_pointer.is_property(): raise errors.SchemaDefinitionError( f"default expression cannot refer to links " "of inserted object", span=span, hint='this is a temporary implementation restriction' ) if ( self.scls.get_rewrite(schema, qltypes.RewriteKind.Update) or 
            self.scls.get_rewrite(schema, qltypes.RewriteKind.Insert)
        ):
            if self.scls.get_cardinality(schema).is_multi():
                raise errors.SchemaDefinitionError(
                    f"cannot specify a rewrite for "
                    f"{scls.get_verbosename(schema, with_parent=True)} "
                    f"because it is multi",
                    span=self.span,
                    hint='this is a temporary implementation restriction'
                )

            if self.scls.has_user_defined_properties(schema):
                raise errors.SchemaDefinitionError(
                    f"cannot specify a rewrite for "
                    f"{scls.get_verbosename(schema, with_parent=True)} "
                    f"because it has link properties",
                    span=self.span,
                    hint='this is a temporary implementation restriction'
                )

    def _check_id_default(
        self,
        schema: s_schema.Schema,
        context: sd.CommandContext,
        expr: irast.Base,
    ) -> None:
        """If default is being set on id, check it against an allowlist"""
        from edb.ir import ast as irast
        from edb.ir import utils as irutils

        # If we add more, we probably want a better mechanism
        ID_ALLOWLIST = (
            'std::uuid_generate_v1mc',
            'std::uuid_generate_v4',
        )

        while (
            isinstance(expr, irast.Set)
            and expr.expr
            and irutils.is_trivial_select(expr.expr)
        ):
            expr = expr.expr.result

        if not (
            isinstance(expr, irast.Set)
            and isinstance(expr.expr, irast.FunctionCall)
            and str(expr.expr.func_shortname) in ID_ALLOWLIST
        ):
            span = self.get_attribute_span('default')
            options = ', '.join(ID_ALLOWLIST)
            raise errors.SchemaDefinitionError(
                "invalid default value for 'id' property",
                hint=f'default must be a call to one of: {options}',
                span=span,
            )

    @classmethod
    def _cmd_tree_from_ast(
        cls,
        schema: s_schema.Schema,
        astnode: qlast.DDLOperation,
        context: sd.CommandContext,
    ) -> sd.Command:
        cmd = super()._cmd_tree_from_ast(schema, astnode, context)
        assert isinstance(cmd, PointerCommand)

        referrer_ctx = cls.get_referrer_context(context)
        if referrer_ctx is not None:
            if getattr(astnode, 'declared_overloaded', False):
                cmd.set_attribute_value('declared_overloaded', True)
        else:
            # This is an abstract property/link
            if cmd.get_attribute_value('default') is not None:
                typ =
cls.get_schema_metaclass().get_schema_class_displayname() raise errors.SchemaDefinitionError( f"'default' is not a valid field for an abstract {typ}", span=astnode.span) return cmd def _process_create_or_alter_ast( self, schema: s_schema.Schema, astnode: qlast.CreateConcretePointer, context: sd.CommandContext, ) -> None: """Handle the CREATE {PROPERTY|LINK} AST node. This may be called in the context of either Create or Alter. """ from edb.schema import sources as s_sources if astnode.is_required is not None: self.set_attribute_value( 'required', astnode.is_required, span=astnode.span, ) if astnode.cardinality is not None: if isinstance(self, sd.CreateObject): self.set_attribute_value( 'cardinality', astnode.cardinality, span=astnode.span, ) else: handler = sd.get_special_field_alter_handler_for_context( 'cardinality', context) assert handler is not None set_field = qlast.SetField( name='cardinality', value=qlast.Constant.string( str(astnode.cardinality), ), special_syntax=True, span=astnode.span, ) apc = handler._cmd_tree_from_ast(schema, set_field, context) self.add(apc) parent_ctx = self.get_referrer_context_or_die(context) source_name = context.get_referrer_name(parent_ctx) self.set_attribute_value( 'source', so.ObjectShell(name=source_name, schemaclass=s_sources.Source), ) target_ref: None | s_types.TypeShell[s_types.Type] | ComputableRef if astnode.target: if isinstance(astnode.target, qlast.TypeExpr): target_ref = utils.ast_to_type_shell( astnode.target, metaclass=s_types.Type, modaliases=context.modaliases, module=source_name.module, schema=schema, ) else: # computable qlcompiler.normalize( astnode.target, schema=schema, modaliases=context.modaliases ) target_ref = ComputableRef(astnode.target) else: # Target is inherited. 
target_ref = None if isinstance(self, sd.CreateObject): assert astnode.target is not None self.set_attribute_value( 'target', target_ref, span=astnode.target.span, ) elif target_ref is not None: assert astnode.target is not None self.set_attribute_value( 'target', target_ref, span=astnode.target.span, ) def _process_alter_ast( self, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> None: """Handle the ALTER {PROPERTY|LINK} AST node.""" expr_cmd = qlast.get_ddl_field_command(astnode, 'expr') if expr_cmd is not None: expr = expr_cmd.value if expr is not None: assert isinstance(expr, qlast.Expr) qlcompiler.normalize( expr, schema=schema, modaliases=context.modaliases ) target_ref = ComputableRef( expr, specified_type=self.get_attribute_value('target'), ) self.set_attribute_value( 'target', target_ref, span=expr.span, ) self.discard_attribute('expr') class CreatePointer[Pointer_T: Pointer]( referencing.CreateReferencedInheritingObject[Pointer_T], PointerCommand[Pointer_T], ): def ast_ignore_ownership(self) -> bool: # If we have a SET REQUIRED with a fill_expr, we need to force # this operation to appear in the AST in a useful position, # even if it normally would be skipped. 
subs = list(self.get_subcommands(type=AlterPointerLowerCardinality)) return len(subs) == 1 and bool(subs[0].fill_expr) @classmethod def as_inherited_ref_cmd( cls, *, schema: s_schema.Schema, context: sd.CommandContext, astnode: qlast.ObjectDDL, bases: list[Pointer_T], referrer: so.Object, ) -> sd.ObjectCommand[Pointer_T]: cmd = super().as_inherited_ref_cmd( schema=schema, context=context, astnode=astnode, bases=bases, referrer=referrer, ) if ( ( isinstance(referrer, s_types.Type) and referrer.is_view(schema) ) or ( isinstance(referrer, Pointer) and referrer.get_from_alias(schema) ) ): cmd.set_attribute_value('from_alias', True) cmd.set_object_aux_data('from_alias', True) return cmd class AlterPointer[Pointer_T: Pointer]( referencing.AlterReferencedInheritingObject[Pointer_T], PointerCommand[Pointer_T], ): def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._alter_begin(schema, context) if not context.canonical and ( self.get_attribute_value('expr') is not None or self.get_orig_attribute_value('expr') is not None or bool(self.get_subcommands(type=constraints.ConstraintCommand)) or ( self.get_attribute_value('default') is not None and self.scls.is_link_property(schema) ) ): extras: dict[so.Object, list[str]] = {} if ( self.get_attribute_value('expr') is not None or self.get_orig_attribute_value('expr') is not None ): for constr in ( self.scls.get_constraints(schema).objects(schema) ): extras[constr] = ['finalexpr'] # If the expression gets changed, we need to propagate # this change to other expressions referring to this one, # in case there are any cycles caused by this change. # # Also, if constraints are modified, that can affect # cardinality of other expressions using backlinks. # # Also when setting a default on a link property, since # access policies need to be prevented from accessing them. # (Ugh.) 
# # FIXME: sometimes this can cause a constraint to get # altered because we've created another constraint, which # could change inference schema = self._propagate_if_expr_refs( schema, context, action=self.get_friendly_description(schema=schema), extra_refs=extras, ) return schema @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> referencing.AlterReferencedInheritingObject[Any]: cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(cmd, PointerCommand) if isinstance(astnode, qlast.CreateConcreteLink): cmd._process_create_or_alter_ast(schema, astnode, context) else: expr_cmd = qlast.get_ddl_field_command(astnode, 'expr') if expr_cmd is not None: expr = expr_cmd.value if expr is None: # `RESET EXPRESSION` detected aop = sd.AlterObjectProperty( property='expr', new_value=None, span=astnode.span, ) cmd.add(aop) assert isinstance(cmd, referencing.AlterReferencedInheritingObject) return cmd def canonicalize_attributes( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().canonicalize_attributes(schema, context) # Handle `RESET EXPRESSION` here if ( self.has_attribute_value('expr') and not self.is_attribute_inherited('expr') and self.get_attribute_value('expr') is None ): self.set_attribute_value('linkful', not self.is_property_command()) old_expr = self.get_orig_attribute_value('expr') pointer = schema.get(self.classname, type=Pointer) if old_expr is None: # Get the old value from the schema if the old_expr # attribute isn't set. old_expr = pointer.get_expr(schema) if old_expr is not None: # If the expression was explicitly set to None, # that means that `RESET EXPRESSION` was executed # and this is no longer a computable. 
self.set_attribute_value('computable', None) computed_fields = pointer.get_computed_fields(schema) if ( 'required' in computed_fields and not self.has_attribute_value('required') ): self.set_attribute_value('required', None) if ( 'cardinality' in computed_fields and not self.has_attribute_value('cardinality') ): self.set_attribute_value('cardinality', None) self.set_attribute_value( 'computed_link_alias_is_backward', None) self.set_attribute_value('computed_link_alias', None) # Clear the placeholder value for 'expr'. self.set_attribute_value('expr', None) return schema def canonicalize_alter_from_external_ref( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: # if the delta involves re-setting a computable # expression, then we also need to change the type to the # new expression type expr = self.get_attribute_value('expr') if expr is None: # This shouldn't happen, but asserting here doesn't seem quite # right either. return assert isinstance(expr, s_expr.Expression) pointer = schema.get(self.classname, type=Pointer) source = cast(s_types.Type, pointer.get_source(schema)) expression = expr.compiled( schema=schema, options=qlcompiler.CompilerOptions( modaliases=context.modaliases, anchors={'__source__': source}, path_prefix_anchor='__source__', singletons=frozenset([source]), apply_query_rewrites=not context.stdmode, ), context=context, ) target = expression.irast.stype self.set_attribute_value( 'target', target, inherited=pointer.field_is_inherited(schema, 'target'), computed=pointer.field_is_computed(schema, 'target'), ) def is_data_safe(self) -> bool: # HACK: expr ought to be managed by AlterSpecialObjectField # the way that target/required/cardinality are. 
        return super().is_data_safe() and not (
            self.get_attribute_value('expr') is not None
            and self.get_orig_attribute_value('expr') is None
        )


class DeletePointer[Pointer_T: Pointer](
    referencing.DeleteReferencedInheritingObject[Pointer_T],
    PointerCommand[Pointer_T],
):

    def _delete_begin(
        self,
        schema: s_schema.Schema,
        context: sd.CommandContext,
    ) -> s_schema.Schema:
        schema = super()._delete_begin(schema, context)
        if (
            not context.canonical
            and (target := self.scls.get_target(schema)) is not None
            and not self.scls.is_endpoint_pointer(schema)
            and (
                del_cmd := target.as_type_delete_if_unused(schema)
            ) is not None
        ):
            self.add_caused(del_cmd)

        if not context.canonical:
            # We need to do a propagate here, too, since there could
            # be backrefs to this pointer that technically reference
            # us but will be fine if it is deleted.
            schema = self._propagate_if_expr_refs(
                schema,
                context,
                action=self.get_friendly_description(schema=schema),
            )

        return schema

    def _canonicalize(
        self,
        schema: s_schema.Schema,
        context: sd.CommandContext,
        scls: Pointer_T,
    ) -> list[sd.Command]:
        commands = super()._canonicalize(schema, context, scls)

        # Any union type that references this field needs to have it
        # deleted.
        unions = schema.get_referrers(
            self.scls, scls_type=Pointer, field_name='union_of')
        for union in unions:
            group, op, _ = union.init_delta_branch(
                schema, context, sd.DeleteObject)
            op.update(op._canonicalize(schema, context, union))
            commands.append(group)

        return commands


class SetPointerType[Pointer_T: Pointer](
    referencing.ReferencedInheritingObjectCommand[Pointer_T],
    inheriting.AlterInheritingObjectFragment[Pointer_T],
    sd.AlterSpecialObjectField[Pointer_T],
    PointerCommandOrFragment[Pointer_T],
):

    cast_expr = struct.Field(s_expr.Expression, default=None)

    def get_verb(self) -> str:
        return 'alter the type of'

    def is_data_safe(self) -> bool:
        # A computed target means this must be an inferred computed
        # property, so it is data safe.
        return self.is_attribute_computed('target')

    def record_diff_annotations(
        self,
        *,
        schema: s_schema.Schema,
        orig_schema: Optional[s_schema.Schema],
        context: so.ComparisonContext,
        object: Optional[so.Object],
        orig_object: Optional[so.Object],
    ) -> None:
        super().record_diff_annotations(
            schema=schema,
            orig_schema=orig_schema,
            context=context,
            orig_object=orig_object,
            object=object,
        )

        if orig_schema is None:
            return

        if not context.generate_prompts:
            return

        old_type_shell = self.get_orig_attribute_value('target')
        new_type_shell = self.get_attribute_value('target')

        assert isinstance(old_type_shell, s_types.TypeShell)
        assert isinstance(new_type_shell, s_types.TypeShell)

        old_type: Optional[s_types.Type] = None

        try:
            old_type = old_type_shell.resolve(schema)
        except errors.InvalidReferenceError:
            # The original type does not exist in the new schema,
            # which means one of two things:
            # 1) the original type is a collection, in which case we can
            #    attempt to temporarily recreate it in the new schema to
            #    check castability;
            # 2) the original type is not a collection, and was removed
            #    in the new schema; there is no way for us to infer
            #    castability and we assume a cast expression is needed.
            if isinstance(old_type_shell, s_types.CollectionTypeShell):
                try:
                    create = old_type_shell.as_create_delta(schema)
                    schema = sd.apply(create, schema=schema)
                except errors.InvalidReferenceError:
                    # A removed type is part of the collection,
                    # can't do anything about that.
pass else: old_type = old_type_shell.resolve(schema) new_type = new_type_shell.resolve(schema) assert len(context.parent_ops) > 1 ptr_op = context.parent_ops[-1] src_op = context.parent_ops[-2] is_computable = bool(ptr_op.get_attribute_value('expr')) needs_cast = ( old_type is None or self._needs_cast_expr( schema=schema, ptr_op=ptr_op, src_op=src_op, old_type=old_type, new_type=new_type, is_computable=is_computable, ) ) if needs_cast: placeholder_name = context.get_placeholder('cast_expr') desc = self.get_friendly_description(schema=schema) prompt = f'Please specify a conversion expression to {desc}' self.set_annotation('required_input', dict( placeholder=placeholder_name, prompt=prompt, old_type=str(old_type.get_name(schema)) if old_type else None, old_type_is_object=old_type and old_type.is_object_type(), new_type=str(new_type.get_name(schema)), new_type_is_object=new_type.is_object_type(), pointer_name=self.get_displayname(), )) self.cast_expr = s_expr.Expression.from_ast( qlast.Placeholder(name=placeholder_name), schema, ) def _is_endpoint_property(self) -> bool: mcls = self.get_schema_metaclass() shortname = mcls.get_shortname_static(self.classname) quals = sn.quals_from_fullname(self.classname) if not quals: return False else: source = quals[0] return ( sn.is_fullname(source) and str(shortname) in {'__::source', '__::target'} ) def _needs_cast_expr( self, *, schema: s_schema.Schema, ptr_op: sd.ObjectCommand[so.Object], src_op: sd.ObjectCommand[so.Object], old_type: s_types.Type, new_type: s_types.Type, is_computable: bool, ) -> bool: return ( not old_type.assignment_castable_to(new_type, schema) and not is_computable and not ptr_op.maybe_get_object_aux_data('from_alias') and self.cast_expr is None and not self._is_endpoint_property() and not ( ptr_op.get_attribute_value('declared_overloaded') or isinstance(src_op, sd.CreateObject) ) ) def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: from edb.ir import utils 
as irutils orig_schema = schema orig_rec = context.current().enable_recursion context.current().enable_recursion = False schema = super()._alter_begin(schema, context) context.current().enable_recursion = orig_rec scls = self.scls vn = scls.get_verbosename(schema, with_parent=True) orig_target = scls.get_target(orig_schema) new_target = scls.get_target(schema) if new_target is None: # This will happen if `RESET TYPE` was called # on a non-inherited type. raise errors.SchemaError( f'cannot RESET TYPE of {vn} because it is not inherited', span=self.span, ) if not context.canonical and orig_target != new_target: assert orig_target is not None assert new_target is not None ptr_op = self.get_parent_op(context) src_op = self.get_referrer_context_or_die(context).op if self._needs_cast_expr( schema=schema, ptr_op=ptr_op, src_op=src_op, old_type=orig_target, new_type=new_target, is_computable=self.scls.is_pure_computable(schema), ): vn = scls.get_verbosename(schema, with_parent=True) ot = orig_target.get_verbosename(schema) nt = new_target.get_verbosename(schema) raise errors.SchemaError( f'{vn} cannot be cast automatically from ' f'{ot} to {nt}', hint=( 'You might need to specify a conversion ' 'expression in a USING clause' ), span=self.span, ) if self.cast_expr is not None: vn = scls.get_verbosename(schema, with_parent=True) self.cast_expr = self._compile_expr( schema=orig_schema, context=context, expr=self.cast_expr, target_as_singleton=True, singleton_result_expected=True, no_query_rewrites=True, expr_description=( f'the USING clause for the alteration of {vn}' ), ) using_type = self.cast_expr.stype if not using_type.assignment_castable_to( new_target, self.cast_expr.schema, ): ot = using_type.get_verbosename(self.cast_expr.schema) nt = new_target.get_verbosename(schema) raise errors.SchemaError( f'result of USING clause for the alteration of ' f'{vn} cannot be cast automatically from ' f'{ot} to {nt} ', hint='You might need to add an explicit cast.', span=self.span, ) 
if using_type.is_view(self.cast_expr.schema): raise errors.SchemaError( f'result of USING clause for the alteration of ' f'{vn} may not include a shape', span=self.span, ) if irutils.contains_dml(self.cast_expr.ir_statement): raise errors.SchemaError( f'USING clause for the alteration of type of {vn} ' f'cannot include mutating statements', span=self.span, ) schema = self._propagate_if_expr_refs( schema, context, action=self.get_friendly_description(schema=schema), ) if orig_target is not None and scls.is_property(): if cleanup_op := orig_target.as_type_delete_if_unused(schema): parent_op = self.get_parent_op(context) parent_op.add_caused(cleanup_op) if not context.canonical: if context.enable_recursion: self._propagate_ref_field_alter_in_inheritance( schema, context, field_name='target', ) return schema @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(cmd, SetPointerType) if ( isinstance(astnode, qlast.SetPointerType) and astnode.cast_expr is not None ): cmd.cast_expr = s_expr.Expression.from_ast( astnode.cast_expr, schema, context.modaliases, context.localnames, ) return cmd def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: set_field = super()._get_ast(schema, context, parent_node=parent_node) if set_field is None or self.is_attribute_computed('target'): return None else: assert isinstance(set_field, qlast.SetField) assert not isinstance(set_field.value, qlast.Expr) return qlast.SetPointerType( value=set_field.value, cast_expr=( self.cast_expr.parse() if self.cast_expr is not None else None ) ) class AlterPointerUpperCardinality[Pointer_T: Pointer]( referencing.ReferencedInheritingObjectCommand[Pointer_T], inheriting.AlterInheritingObjectFragment[Pointer_T], 
    sd.AlterSpecialObjectField[Pointer_T],
    PointerCommandOrFragment[Pointer_T],
):
    """Handler for the "cardinality" field changes."""

    conv_expr = struct.Field(s_expr.Expression, default=None)

    def get_friendly_description(
        self,
        *,
        parent_op: Optional[sd.Command] = None,
        schema: Optional[s_schema.Schema] = None,
        object: Any = None,
        object_desc: Optional[str] = None,
    ) -> str:
        object_desc = self.get_friendly_object_name_for_description(
            parent_op=parent_op,
            schema=schema,
            object=object,
            object_desc=object_desc,
        )
        new_card = self.get_attribute_value('cardinality')
        if new_card is None:
            # RESET CARDINALITY (to default)
            new_card = qltypes.SchemaCardinality.One
        return (
            f"convert {object_desc} to"
            f" {new_card.as_ptr_qual()!r} cardinality"
        )

    def is_data_safe(self) -> bool:
        # A computed cardinality means this must be an inferred computed
        # property, so it is data safe.
        if self.is_attribute_computed('cardinality'):
            return True

        old_val = self.get_orig_attribute_value('cardinality')
        new_val = self.get_attribute_value('cardinality')
        if (
            old_val is qltypes.SchemaCardinality.Many
            and new_val is qltypes.SchemaCardinality.One
        ):
            return False
        else:
            return True

    def _alter_begin(
        self,
        schema: s_schema.Schema,
        context: sd.CommandContext,
    ) -> s_schema.Schema:
        orig_schema = schema
        schema = super()._alter_begin(schema, context)
        scls = self.scls
        orig_card = scls.get_cardinality(orig_schema)
        new_card = scls.get_cardinality(schema)
        is_computed = 'cardinality' in scls.get_computed_fields(schema)
        if orig_card == new_card or is_computed:
            # The actual value hasn't changed, nothing to do here.
return schema if not context.canonical: vn = scls.get_verbosename(schema, with_parent=True) desc = self.get_friendly_description(schema=schema) ptr_op = self.get_parent_op(context) src_op = self.get_referrer_context_or_die(context).op if self._needs_conv_expr( schema=schema, ptr_op=ptr_op, src_op=src_op, ): vn = scls.get_verbosename(schema, with_parent=True) raise errors.SchemaError( f'cannot automatically {desc}', hint=( 'You need to specify a conversion ' 'expression in a USING clause' ), span=self.span, ) if self.conv_expr is not None: self.conv_expr = self._compile_expr( schema=orig_schema, context=context, expr=self.conv_expr, target_as_singleton=False, singleton_result_expected=True, no_query_rewrites=True, expr_description=( f'the USING clause for the alteration of {vn}' ), ) using_type = self.conv_expr.stype ptr_type = scls.get_target(schema) assert ptr_type is not None if not using_type.assignment_castable_to( ptr_type, self.conv_expr.schema, ): ot = using_type.get_verbosename(self.conv_expr.schema) nt = ptr_type.get_verbosename(schema) raise errors.SchemaError( f'result of USING clause for the alteration of ' f'{vn} cannot be cast automatically from ' f'{ot} to {nt} ', hint='You might need to add an explicit cast.', span=self.span, ) if using_type.is_view(self.conv_expr.schema): raise errors.SchemaError( f'result of USING clause for the alteration of ' f'{vn} may not include a shape', span=self.span, ) schema = self._propagate_if_expr_refs(schema, context, action=desc) self._propagate_ref_field_alter_in_inheritance( schema, context, field_name='cardinality', ) return schema def record_diff_annotations( self, *, schema: s_schema.Schema, orig_schema: Optional[s_schema.Schema], context: so.ComparisonContext, object: Optional[so.Object], orig_object: Optional[so.Object], ) -> None: super().record_diff_annotations( schema=schema, orig_schema=orig_schema, context=context, orig_object=orig_object, object=object, ) if orig_schema is None: return if not 
context.generate_prompts: return assert len(context.parent_ops) > 1 ptr_op = context.parent_ops[-1] src_op = context.parent_ops[-2] needs_conv_expr = self._needs_conv_expr( schema=schema, ptr_op=ptr_op, src_op=src_op, ) if needs_conv_expr: placeholder_name = context.get_placeholder('conv_expr') desc = self.get_friendly_description( schema=schema, parent_op=src_op) prompt = ( f'Please specify an expression in order to {desc}' ) type_name = _get_target_name_in_diff( schema=schema, orig_schema=orig_schema, object=object, orig_object=orig_object, ) self.set_annotation('required_input', dict( placeholder=placeholder_name, prompt=prompt, type=str(type_name), pointer_name=self.get_displayname(), )) self.conv_expr = s_expr.Expression.from_ast( qlast.Placeholder(name=placeholder_name), schema, ) def _needs_conv_expr( self, *, schema: s_schema.Schema, ptr_op: sd.ObjectCommand[so.Object], src_op: sd.ObjectCommand[so.Object], ) -> bool: old_card = ( self.get_orig_attribute_value('cardinality') or qltypes.SchemaCardinality.One ) new_card = ( self.get_attribute_value('cardinality') or qltypes.SchemaCardinality.One ) return ( old_card is qltypes.SchemaCardinality.Many and new_card is qltypes.SchemaCardinality.One and not self.is_attribute_computed('cardinality') and not self.is_attribute_inherited('cardinality') and not ptr_op.maybe_get_object_aux_data('from_alias') and self.conv_expr is None and not ( ptr_op.get_attribute_value('expr') or ptr_op.get_orig_attribute_value('expr') ) and not ( ptr_op.get_attribute_value('declared_overloaded') or isinstance(src_op, sd.CreateObject) ) ) @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(cmd, AlterPointerUpperCardinality) if ( isinstance(astnode, qlast.SetPointerCardinality) and astnode.conv_expr is not None ): cmd.conv_expr = s_expr.Expression.from_ast( 
astnode.conv_expr, schema, context.modaliases, context.localnames, ) return cmd def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: set_field = super()._get_ast(schema, context, parent_node=parent_node) if set_field is None: return None else: assert isinstance(set_field, qlast.SetField) return qlast.SetPointerCardinality( value=set_field.value, conv_expr=( self.conv_expr.parse() if self.conv_expr is not None else None ) ) class AlterPointerLowerCardinality[Pointer_T: Pointer]( referencing.ReferencedInheritingObjectCommand[Pointer_T], inheriting.AlterInheritingObjectFragment[Pointer_T], sd.AlterSpecialObjectField[Pointer_T], PointerCommandOrFragment[Pointer_T], ): """Handler for the "required" field changes.""" fill_expr = struct.Field(s_expr.Expression, default=None) def get_friendly_description( self, *, parent_op: Optional[sd.Command] = None, schema: Optional[s_schema.Schema] = None, object: Any = None, object_desc: Optional[str] = None, ) -> str: object_desc = self.get_friendly_object_name_for_description( parent_op=parent_op, schema=schema, object=object, object_desc=object_desc, ) required = self.get_attribute_value('required') return f"make {object_desc} {'required' if required else 'optional'}" def is_data_safe(self) -> bool: return True def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: from edb.ir import utils as irutils orig_schema = schema schema = super()._alter_begin(schema, context) scls = self.scls orig_required = scls.get_required(orig_schema) new_required = scls.get_required(schema) new_card = scls.get_cardinality(schema) is_computed = 'required' in scls.get_computed_fields(schema) if orig_required == new_required or is_computed: # The actual value hasn't changed, nothing to do here. 
return schema if not context.canonical: vn = scls.get_verbosename(schema, with_parent=True) if self.fill_expr is not None: self.fill_expr = self._compile_expr( schema=orig_schema, context=context, expr=self.fill_expr, target_as_singleton=True, singleton_result_expected=new_card.is_single(), no_query_rewrites=True, expr_description=( f'the USING clause for the alteration of {vn}' ), ) using_type = self.fill_expr.stype ptr_type = scls.get_target(schema) assert ptr_type is not None if not using_type.assignment_castable_to( ptr_type, self.fill_expr.schema, ): ot = using_type.get_verbosename(self.fill_expr.schema) nt = ptr_type.get_verbosename(schema) raise errors.SchemaError( f'result of USING clause for the alteration of ' f'{vn} cannot be cast automatically from ' f'{ot} to {nt} ', hint='You might need to add an explicit cast.', span=self.span, ) if using_type.is_view(self.fill_expr.schema): raise errors.SchemaError( f'result of USING clause for the alteration of ' f'{vn} may not include a shape', span=self.span, ) if ( self.scls.is_link_property(schema) and irutils.contains_dml(self.fill_expr.ir_statement) ): raise errors.UnsupportedFeatureError( f'USING clause for the alteration of optionality of ' f'{vn} cannot include mutating statements, ' 'because it is a link property', span=self.span, ) schema = self._propagate_if_expr_refs( schema, context, action=( f'make {vn} {"required" if new_required else "optional"}' ), ) return schema def record_diff_annotations( self, *, schema: s_schema.Schema, orig_schema: Optional[s_schema.Schema], context: so.ComparisonContext, object: Optional[so.Object], orig_object: Optional[so.Object], ) -> None: super().record_diff_annotations( schema=schema, orig_schema=orig_schema, context=context, orig_object=orig_object, object=object, ) if not context.generate_prompts: return if len(context.parent_ops) <= 1: return ptr_op = context.parent_ops[-1] src_op = context.parent_ops[-2] needs_fill_expr = self._needs_fill_expr( schema=schema, 
ptr_op=ptr_op, src_op=src_op, ) if needs_fill_expr: placeholder_name = context.get_placeholder('fill_expr') desc = self.get_friendly_description( schema=schema, parent_op=src_op) prompt = ( f'Please specify an expression to populate existing objects ' f'in order to {desc}' ) type_name = _get_target_name_in_diff( schema=schema, orig_schema=orig_schema, object=object, orig_object=orig_object, ) self.set_annotation('required_input', dict( placeholder=placeholder_name, prompt=prompt, type=str(type_name), pointer_name=self.get_displayname(), )) self.fill_expr = s_expr.Expression.from_ast( qlast.Placeholder(name=placeholder_name), schema, ) def _needs_fill_expr( self, *, schema: s_schema.Schema, ptr_op: sd.ObjectCommand[so.Object], src_op: sd.ObjectCommand[so.Object], ) -> bool: old_required = self.get_orig_attribute_value('required') or False new_required = self.get_attribute_value('required') or False return ( not old_required and new_required and not self.is_attribute_computed('required') and not ptr_op.maybe_get_object_aux_data('from_alias') and self.fill_expr is None and not ( ptr_op.get_attribute_value('declared_overloaded') or isinstance(src_op, sd.CreateObject) ) ) @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(cmd, AlterPointerLowerCardinality) if ( isinstance(astnode, qlast.SetPointerOptionality) and astnode.fill_expr is not None ): cmd.fill_expr = s_expr.Expression.from_ast( astnode.fill_expr, schema, context.modaliases, context.localnames, ) return cmd def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: set_field = super()._get_ast(schema, context, parent_node=parent_node) if set_field is None and not self.fill_expr: return None else: if set_field is not None: assert 
isinstance(set_field, qlast.SetField) value = set_field.value else: req = self.get_attribute_value('required') value = (utils.const_ast_from_python(req) if req is not None else None) return qlast.SetPointerOptionality( value=value, fill_expr=( self.fill_expr.parse() if self.fill_expr is not None else None ) ) def get_or_create_union_pointer( schema: s_schema.Schema, ptrname: sn.UnqualName, source: s_sources.Source, direction: PointerDirection, components: Iterable[Pointer], *, transient: bool = False, opaque: bool = False, modname: Optional[str] = None, ) -> tuple[s_schema.Schema, Pointer]: from . import sources as s_sources components = list(components) if len(components) == 1 and direction is PointerDirection.Outbound: return schema, components[0] # We want to transform all the computables in the list of the # components to their respective owned computables. This is to # ensure that mixing multiple inherited copies of the same # computable is actually allowed. comp_set = set() for c in components: if c.is_pure_computable(schema): comp_set.add(_get_nearest_owned(schema, c)) else: comp_set.add(c) components = list(comp_set) if ( any(p.is_pure_computable(schema) for p in components) and len(components) > 1 and ptrname.name not in ('__tname__', '__tid__') ): p = components[0] raise errors.SchemaError( f'it is illegal to create a type union that causes ' f'a computed {p.get_verbosename(schema)} to mix ' f'with other versions of the same {p.get_verbosename(schema)}', ) if len(components) == 1 and direction is PointerDirection.Outbound: return schema, components[0] far_endpoints = [ p.get_far_endpoint(schema, direction) for p in components ] targets: Sequence[s_types.Type] = [ p for p in far_endpoints if isinstance(p, s_types.Type) ] targets = utils.simplify_union_types(schema, targets) target: s_types.Type schema, target, _ = utils.ensure_union_type( schema, targets, opaque=opaque, module=modname, transient=transient) cardinality = qltypes.SchemaCardinality.One for 
component in components: if component.get_cardinality(schema) is qltypes.SchemaCardinality.Many: cardinality = qltypes.SchemaCardinality.Many break required = all(component.get_required(schema) for component in components) metacls = type(components[0]) default_base_name = metacls.get_default_base_name() assert default_base_name is not None genptr = schema.get(default_base_name, type=Pointer) if direction is PointerDirection.Inbound: # type ignore below, because the types "Type" and "Source" # could only be swapped by their common ancestor so.Object, # and here we are considering them both as more specific objects source, target = target, source # type: ignore schema, result = genptr.get_derived( schema, source, target, derived_name_base=sn.QualName(module='__', name=ptrname.name), attrs={ 'union_of': so.ObjectSet.create(schema, components), 'cardinality': cardinality, 'required': required, }, transient=transient, ) if isinstance(result, s_sources.Source) and not opaque: # cast below, because in this case the list of Pointer # is also a list of Source (links.Link) schema = s_sources.populate_pointer_set_for_source_union( schema, cast(list[s_sources.Source], components), result, modname=modname, ) return schema, result def _get_nearest_owned( schema: s_schema.Schema, pointer: Pointer, ) -> Pointer: if pointer.get_owned(schema): return pointer for p in pointer.get_ancestors(schema).objects(schema): if p.get_owned(schema): return p return pointer def get_or_create_intersection_pointer( schema: s_schema.Schema, ptrname: sn.UnqualName, source: s_objtypes.ObjectType, components: Iterable[Pointer], *, modname: Optional[str] = None, transient: bool = False, ) -> tuple[s_schema.Schema, Pointer]: components = list(components) if len(components) == 1: return schema, components[0] required = any(component.get_required(schema) for component in components) targets: Sequence[s_types.Type] targets = list(filter(None, [p.get_target(schema) for p in components])) targets = 
utils.simplify_intersection_types(schema, targets) schema, target, _ = utils.ensure_intersection_type( schema, targets, module=modname) cardinality = qltypes.SchemaCardinality.Many for component in components: if component.get_cardinality(schema) is qltypes.SchemaCardinality.One: cardinality = qltypes.SchemaCardinality.One break metacls = type(components[0]) default_base_name = metacls.get_default_base_name() assert default_base_name is not None genptr = schema.get(default_base_name, type=Pointer) schema, result = genptr.get_derived( schema, source, target, derived_name_base=sn.QualName(module='__', name=ptrname.name), attrs={ 'intersection_of': so.ObjectSet.create(schema, components), 'cardinality': cardinality, 'required': required, }, transient=transient, ) # We want to transform all the computables in the list of the # components to their respective owned computables. This is to # ensure that mixing multiple inherited copies of the same # computable is actually allowed. comp_set = set() for c in components: if c.is_pure_computable(schema): comp_set.add(_get_nearest_owned(schema, c)) else: comp_set.add(c) components = list(comp_set) if ( any(p.is_pure_computable(schema) for p in components) and len(components) > 1 and ptrname.name not in ('__tname__', '__tid__') ): p = components[0] raise errors.SchemaError( f'it is illegal to create a type intersection that causes ' f'a computed {p.get_verbosename(schema)} to mix ' f'with other versions of the same {p.get_verbosename(schema)}', ) if len({p.get_cardinality(schema) for p in components}) > 1: p = components[0] raise errors.SchemaError( f'it is illegal to create a type intersection that causes ' f'a {p.get_verbosename(schema)} to mix ' f'with other versions of {p.get_verbosename(schema)} ' f'which have a different cardinality', ) return schema, result @s_futures.register_handler('no_linkful_computed_splats') def toggle_no_linkful_computed_splats( cmd: s_futures.FutureBehaviorCommand, schema: s_schema.Schema, 
context: sd.CommandContext, on: bool, ) -> tuple[s_schema.Schema, sd.Command]: # Nothing to do because splats can't appear in functions group = sd.CommandGroup() return schema, group ================================================ FILE: edb/schema/policies.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any, Optional, TYPE_CHECKING from edb import errors from edb.edgeql import ast as qlast from edb.edgeql import compiler as qlcompiler from edb.edgeql import qltypes from . import annos as s_anno from . import delta as sd from . import expr as s_expr from . import futures as s_futures from . import name as sn from . import objects as so from . import properties as s_props from . import referencing from . import schema as s_schema from . import sources as s_sources from . import types as s_types if TYPE_CHECKING: from . 
import objtypes as s_objtypes class AccessPolicy( referencing.NamedReferencedInheritingObject, so.InheritingObject, # Help reflection figure out the right db MRO s_anno.AnnotationSubject, qlkind=qltypes.SchemaObjectClass.ACCESS_POLICY, data_safe=True, ): condition = so.SchemaField( s_expr.Expression, default=None, coerce=True, compcoef=0.909, special_ddl_syntax=True, ) expr = so.SchemaField( s_expr.Expression, default=None, compcoef=0.909, special_ddl_syntax=True, ) action = so.SchemaField( qltypes.AccessPolicyAction, coerce=True, compcoef=0.85, special_ddl_syntax=True, ) access_kinds = so.SchemaField( so.MultiPropSet[qltypes.AccessKind], coerce=True, compcoef=0.85, special_ddl_syntax=True, ) subject = so.SchemaField( so.InheritingObject, compcoef=None, inheritable=False) errmessage = so.SchemaField( str, default=None, compcoef=0.971, allow_ddl_set=True ) # We don't support SET/DROP OWNED on policies, so we set its # compcoef to 0.0 owned = so.SchemaField( bool, default=False, inheritable=False, compcoef=0.0, reflection_method=so.ReflectionMethod.AS_LINK, special_ddl_syntax=True, ) def get_expr_refs(self, schema: s_schema.Schema) -> list[so.Object]: objs: list[so.Object] = [] if (condition := self.get_condition(schema)) and condition.refs: objs.extend(condition.refs.objects(schema)) if (expr := self.get_expr(schema)) and expr.refs: objs.extend(expr.refs.objects(schema)) return objs def get_subject(self, schema: s_schema.Schema) -> s_objtypes.ObjectType: subj: s_objtypes.ObjectType = self.get_field_value(schema, 'subject') return subj def get_original_subject( self, schema: s_schema.Schema ) -> s_objtypes.ObjectType: ancs = (self,) + self.get_ancestors(schema).objects(schema) return ancs[-1].get_subject(schema) class AccessPolicyCommandContext( sd.ObjectCommandContext[AccessPolicy], s_anno.AnnotationSubjectCommandContext, ): pass class AccessPolicySourceCommandContext[Source_T: s_sources.Source]( s_sources.SourceCommandContext[Source_T] ): pass class
AccessPolicyCommand( referencing.NamedReferencedInheritingObjectCommand[AccessPolicy], s_anno.AnnotationSubjectCommand[AccessPolicy], context_class=AccessPolicyCommandContext, referrer_context_class=AccessPolicySourceCommandContext, ): def canonicalize_attributes( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().canonicalize_attributes(schema, context) parent_ctx = self.get_referrer_context_or_die(context) source = parent_ctx.op.scls pol_name = self.get_verbosename(parent=source.get_verbosename(schema)) for field in ('expr', 'condition'): if (expr := self.get_local_attribute_value(field)) is None: continue vname = 'when' if field == 'condition' else 'using' expression = self.compile_expr_field( schema, context, field=AccessPolicy.get_field(field), value=expr, ) span = self.get_attribute_span(field) if expression.irast.cardinality.can_be_zero(): raise errors.SchemaDefinitionError( f'possibly an empty set returned by {vname} ' f'expression for the {pol_name} ', span=span ) if expression.irast.cardinality.is_multi(): raise errors.SchemaDefinitionError( f'possibly more than one element returned by {vname} ' f'expression for the {pol_name} ', span=span ) if expression.irast.volatility.is_volatile(): raise errors.SchemaDefinitionError( f'{pol_name} has a volatile {vname} expression, ' f'which is not allowed', span=span ) target = schema.get(sn.QualName('std', 'bool'), type=s_types.Type) expr_type = expression.irast.stype if not expr_type.issubclass(expression.irast.schema, target): span = self.get_attribute_span(field) raise errors.SchemaDefinitionError( f'{vname} expression for {pol_name} is of invalid type: ' f'{expr_type.get_displayname(expression.irast.schema)}, ' f'expected {target.get_displayname(schema)}', span=span, ) return schema def compile_expr_field( self, schema: s_schema.Schema, context: sd.CommandContext, field: so.Field[Any], value: s_expr.Expression, track_schema_ref_exprs: bool=False, ) ->
s_expr.CompiledExpression: if field.name in {'expr', 'condition'}: parent_ctx = self.get_referrer_context_or_die(context) source = parent_ctx.op.get_object(schema, context) parent_vname = source.get_verbosename(schema) pol_name = self.get_verbosename(parent=parent_vname) in_ddl_context_name = pol_name assert isinstance(source, s_types.Type) return type(value).compiled( value, schema=schema, options=qlcompiler.CompilerOptions( modaliases=context.modaliases, schema_object_context=self.get_schema_metaclass(), anchors={'__subject__': source}, path_prefix_anchor='__subject__', singletons=frozenset({source}), apply_query_rewrites=not context.stdmode, apply_user_access_policies=False, track_schema_ref_exprs=track_schema_ref_exprs, in_ddl_context_name=in_ddl_context_name, detached=True, ), context=context, ) else: return super().compile_expr_field( schema, context, field, value, track_schema_ref_exprs) def get_dummy_expr_field_value( self, schema: s_schema.Schema, context: sd.CommandContext, field: so.Field[Any], value: Any, ) -> Optional[s_expr.Expression]: if field.name in {'expr', 'condition'}: return s_expr.Expression(text='false') else: raise NotImplementedError(f'unhandled field {field.name!r}') def validate_object( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: subject = self.scls.get_subject(schema) for obj in self.scls.get_expr_refs(schema): if isinstance(obj, s_props.Property): # Disable access of link properties with default # values from all the post-check DML changes. This is # because linkprop default values currently come from # the postgres side, so we don't have access to them # before actually doing the link table inserts. # TODO: Fix this. 
if ( obj.get_source(schema) and obj.is_link_property(schema) and obj.get_default(schema) and any( kind.is_data_check() for kind in self.scls.get_access_kinds(schema) ) ): pol_name = self.get_verbosename( parent=subject.get_verbosename(schema)) obj_name = obj.get_verbosename(schema, with_parent=True) raise errors.UnsupportedFeatureError( f'insert and update write access policies may not ' f'refer to link properties with default values: ' f'{pol_name} refers to {obj_name}', span=self.span, ) class CreateAccessPolicy( AccessPolicyCommand, referencing.CreateReferencedInheritingObject[AccessPolicy], ): referenced_astnode = astnode = qlast.CreateAccessPolicy def get_ast_attr_for_field( self, field: str, astnode: type[qlast.DDLOperation], ) -> Optional[str]: if ( field in ('expr', 'condition', 'action', 'access_kinds') and issubclass(astnode, qlast.CreateAccessPolicy) ): return field else: return super().get_ast_attr_for_field(field, astnode) @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(astnode, qlast.CreateAccessPolicy) assert isinstance(cmd, AccessPolicyCommand) if astnode.condition is not None: cmd.set_attribute_value( 'condition', s_expr.Expression.from_ast( astnode.condition, schema, context.modaliases, context.localnames, ), span=astnode.condition.span, ) if astnode.expr: cmd.set_attribute_value( 'expr', s_expr.Expression.from_ast( astnode.expr, schema, context.modaliases, context.localnames, ), span=astnode.expr.span, ) cmd.set_attribute_value('action', astnode.action) cmd.set_attribute_value('access_kinds', astnode.access_kinds) return cmd class RenameAccessPolicy( AccessPolicyCommand, referencing.RenameReferencedInheritingObject[AccessPolicy], ): pass class RebaseAccessPolicy( AccessPolicyCommand, referencing.RebaseReferencedInheritingObject[AccessPolicy], ): pass class AlterAccessPolicy( 
AccessPolicyCommand, referencing.AlterReferencedInheritingObject[AccessPolicy], ): referenced_astnode = astnode = qlast.AlterAccessPolicy def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._alter_begin(schema, context) # If either action or access_kinds appears, make sure the # other one does as well, so that _apply_field_ast has # a canonical setup to work with. if ( self.has_attribute_value('action') and not self.has_attribute_value('access_kinds') ): self.set_attribute_value( 'access_kinds', self.scls.get_access_kinds(schema)) elif ( self.has_attribute_value('access_kinds') and not self.has_attribute_value('action') ): self.set_attribute_value('action', self.scls.get_action(schema)) # TODO: We may wish to support this in the future but it will # take some thought. if ( self.get_attribute_value('owned') and not self.get_orig_attribute_value('owned') ): raise errors.SchemaDefinitionError( f'cannot alter the definition of inherited access policy ' f'{self.scls.get_displayname(schema)}', span=self.span ) return schema def _apply_field_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, op: sd.AlterObjectProperty, ) -> None: if op.property == 'action': pass elif op.property == 'access_kinds': node.commands.append( qlast.SetAccessPerms( action=self.get_attribute_value('action'), access_kinds=op.new_value, ) ) else: super()._apply_field_ast(schema, context, node, op) # This is kind of a hack: we never actually instantiate this class, we # just use its _cmd_tree_from_ast to produce a command group with two # property sets. 
class AlterAccessPolicyPerms( referencing.ReferencedInheritingObjectCommand[AccessPolicy], referrer_context_class=AccessPolicyCommandContext, ): referenced_astnode = astnode = qlast.SetAccessPerms @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: assert isinstance(astnode, qlast.SetAccessPerms) cmd = sd.CommandGroup() cmd.add( sd.AlterObjectProperty( property='action', new_value=astnode.action, span=astnode.span, ) ) cmd.add( sd.AlterObjectProperty( property='access_kinds', new_value=astnode.access_kinds, span=astnode.span, ) ) return cmd class DeleteAccessPolicy( AccessPolicyCommand, referencing.DeleteReferencedInheritingObject[AccessPolicy], ): referenced_astnode = astnode = qlast.DropAccessPolicy @s_futures.register_handler('nonrecursive_access_policies') def toggle_nonrecursive_access_policies( cmd: s_futures.FutureBehaviorCommand, schema: s_schema.Schema, context: sd.CommandContext, on: bool, ) -> tuple[s_schema.Schema, sd.Command]: # Nothing to do anymore group = sd.CommandGroup() return schema, group ================================================ FILE: edb/schema/properties.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import Any, Optional, TYPE_CHECKING from edb.edgeql import ast as qlast from edb.edgeql import qltypes from edb import errors from . import constraints from . import delta as sd from . import inheriting from . import name as sn from . import objects as so from . import pointers from . import referencing from . import rewrites as s_rewrites from . import sources as s_sources from . import types as s_types from . import utils from . import expr as s_expr if TYPE_CHECKING: from . import schema as s_schema class Property( pointers.Pointer, qlkind=qltypes.SchemaObjectClass.PROPERTY, data_safe=False, ): def derive_ref( self, schema: s_schema.Schema, referrer: so.QualifiedObject, *qualifiers: str, target: Optional[s_types.Type] = None, attrs: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> tuple[s_schema.Schema, Property]: from . import links as s_links if target is None: target = self.get_target(schema) schema, ptr = super().derive_ref( schema, referrer, target=target, attrs=attrs, **kwargs) ptr_sn = str(ptr.get_shortname(schema)) if ptr_sn == 'std::source': assert isinstance(referrer, s_links.Link) schema = ptr.set_field_value( schema, 'target', referrer.get_source(schema)) elif ptr_sn == 'std::target': schema = ptr.set_field_value( schema, 'target', referrer.get_field_value(schema, 'target')) assert isinstance(ptr, Property) return schema, ptr def compare( self, other: so.Object, *, our_schema: s_schema.Schema, their_schema: s_schema.Schema, context: so.ComparisonContext, ) -> float: if not isinstance(other, Property): if isinstance(other, pointers.Pointer): return 0.0 else: raise NotImplementedError similarity = super().compare( other, our_schema=our_schema, their_schema=their_schema, context=context) if ( not self.is_non_concrete(our_schema) and not other.is_non_concrete(their_schema) and self.issubclass( our_schema, our_schema.get('std::source', type=Property) ) and other.issubclass( their_schema, 
their_schema.get('std::source', type=Property) ) ): # Make std::source link property ignore differences in its target. # This is consistent with skipping the comparison on Pointer.source # in general. field = self.__class__.get_field('target') target_coef = field.type.compare_values( self.get_target(our_schema), other.get_target(their_schema), our_schema=our_schema, their_schema=their_schema, context=context, compcoef=field.compcoef) if target_coef < 1: similarity *= target_coef return similarity def should_propagate(self, schema: s_schema.Schema) -> bool: # @source and @target link props don't propagate to children # because we create new properties with distinct types. return not self.is_endpoint_pointer(schema) @classmethod def is_property(cls, schema: Optional[s_schema.Schema]=None) -> bool: return True def has_user_defined_properties(self, schema: s_schema.Schema) -> bool: return False def is_link_property(self, schema: s_schema.Schema) -> bool: source = self.get_source(schema) if source is None: return False return isinstance(source, pointers.Pointer) def allow_ref_propagation( self, schema: s_schema.Schema, context: sd.CommandContext, refdict: so.RefDict, ) -> bool: source = self.get_source(schema) if isinstance(source, pointers.Pointer): if source.is_non_concrete(schema): return True else: source = source.get_source(schema) assert isinstance(source, s_types.Type) return not source.is_view(schema) else: assert isinstance(source, s_types.Type) return not source.is_view(schema) @classmethod def get_root_classes(cls) -> tuple[sn.QualName, ...]: return ( sn.QualName(module='std', name='property'), ) @classmethod def get_default_base_name(self) -> sn.QualName: return sn.QualName('std', 'property') def is_blocking_ref( self, schema: s_schema.Schema, reference: so.Object, ) -> bool: return not self.is_endpoint_pointer(schema) def init_delta_command[ ObjectCommand_T: sd.ObjectCommand[so.Object] ]( self, schema: s_schema.Schema, cmdtype: type[ObjectCommand_T], *, 
classname: Optional[sn.Name] = None, **kwargs: Any, ) -> ObjectCommand_T: delta = super().init_delta_command( schema=schema, cmdtype=cmdtype, classname=classname, **kwargs, ) assert isinstance(delta, referencing.ReferencedObjectCommandBase) delta.is_strong_ref = self.is_special_pointer(schema) return delta # type: ignore class PropertySourceContext[Source_T: s_sources.Source]( s_sources.SourceCommandContext[Source_T] ): pass class PropertySourceCommand[Source_T: s_sources.Source]( inheriting.InheritingObjectCommand[Source_T], ): pass class PropertyCommandContext( pointers.PointerCommandContext, constraints.ConsistencySubjectCommandContext, s_rewrites.RewriteCommandContext, ): pass class PropertyCommand( pointers.PointerCommand[Property], context_class=PropertyCommandContext, referrer_context_class=PropertySourceContext, ): def validate_object( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: """Check that property definition is sound.""" super().validate_object(schema, context) scls = self.scls if not scls.get_owned(schema): return if scls.is_special_pointer(schema): return if ( scls.is_link_property(schema) and not scls.is_pure_computable(schema) ): # link properties cannot be multi if (self.get_attribute_value('cardinality') is qltypes.SchemaCardinality.Many): raise errors.InvalidPropertyDefinitionError( "multi properties aren't supported for links", span=self.span, ) target_type = scls.get_target(schema) if target_type is None: raise TypeError(f'missing target type in scls {scls}') if target_type.is_polymorphic(schema): span = self.get_attribute_span('target') raise errors.InvalidPropertyTargetError( f'invalid property type: ' f'{target_type.get_verbosename(schema)} ' f'is a generic type', span=span, ) if (target_type.is_object_type() or (isinstance(target_type, s_types.Collection) and target_type.contains_object(schema))): span = self.get_attribute_span('target') raise errors.InvalidPropertyTargetError( f'invalid property type: expected a 
scalar type, ' f'or a scalar collection, got ' f'{target_type.get_verbosename(schema)}', span=span, ) def _check_field_errors(self, node: qlast.DDLOperation) -> None: for sub in node.commands: # do not allow link properties on properties if isinstance(sub, qlast.CreateConcretePointer): raise errors.InvalidDefinitionError( 'cannot place a link property on a property', span=node.span, hint=( 'Link properties can only be placed on links, whose ' 'target types are object types.' ), ) # do not allow on source/target delete policies on properties if isinstance(sub, (qlast.OnSourceDelete, qlast.OnTargetDelete)): raise errors.InvalidDefinitionError( 'cannot place a deletion policy on a property', span=node.span, hint=( 'Deletion policies can only be placed on links, whose ' 'target types are object types.' ), ) class CreateProperty( PropertyCommand, pointers.CreatePointer[Property], ): astnode = [qlast.CreateConcreteProperty, qlast.CreateProperty] referenced_astnode = qlast.CreateConcreteProperty @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: cmd = super()._cmd_tree_from_ast(schema, astnode, context) if isinstance(astnode, qlast.CreateConcreteProperty): assert isinstance(cmd, PropertyCommand) cmd._process_create_or_alter_ast(schema, astnode, context) cmd._check_field_errors(astnode) return cmd def get_ast_attr_for_field( self, field: str, astnode: type[qlast.DDLOperation], ) -> Optional[str]: if ( field == 'required' and issubclass(astnode, qlast.CreateConcreteProperty) ): return 'is_required' elif ( field == 'cardinality' and issubclass(astnode, qlast.CreateConcreteProperty) ): return 'cardinality' else: return super().get_ast_attr_for_field(field, astnode) def _apply_field_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, op: sd.AlterObjectProperty, ) -> None: link = context.get(PropertySourceContext) if op.property == 'target' and link: if
isinstance(node, qlast.CreateConcreteProperty): expr: Optional[s_expr.Expression] = ( self.get_attribute_value('expr') ) if expr is not None: node.target = expr.parse() else: ref = op.new_value assert isinstance(ref, (so.Object, so.ObjectShell)) node.target = utils.typeref_to_ast(schema, ref) else: ref = op.new_value assert isinstance(ref, (so.Object, so.ObjectShell)) node.commands.append( qlast.SetPointerType( value=utils.typeref_to_ast(schema, ref) ) ) else: super()._apply_field_ast(schema, context, node, op) class RenameProperty( PropertyCommand, referencing.RenameReferencedInheritingObject[Property], ): pass class RebaseProperty( PropertyCommand, referencing.RebaseReferencedInheritingObject[Property], ): pass class SetPropertyType( pointers.SetPointerType[Property], referrer_context_class=PropertySourceContext, field='target', ): pass class AlterPropertyUpperCardinality( pointers.AlterPointerUpperCardinality[Property], referrer_context_class=PropertySourceContext, field='cardinality', ): pass class AlterPropertyLowerCardinality( pointers.AlterPointerLowerCardinality[Property], referrer_context_class=PropertySourceContext, field='required', ): pass class AlterPropertyOwned( referencing.AlterOwned[Property], pointers.PointerCommandOrFragment[Property], referrer_context_class=PropertySourceContext, field='owned', ): pass class AlterProperty( PropertyCommand, pointers.AlterPointer[Property], ): astnode = [qlast.AlterConcreteProperty, qlast.AlterProperty] referenced_astnode = qlast.AlterConcreteProperty @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> AlterProperty: cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(cmd, AlterProperty) if isinstance(astnode, qlast.CreateConcreteProperty): cmd._process_create_or_alter_ast(schema, astnode, context) else: cmd._process_alter_ast(schema, astnode, context) cmd._check_field_errors(astnode) return cmd def 
_apply_field_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, op: sd.AlterObjectProperty, ) -> None: if op.property == 'target': if op.new_value: assert isinstance(op.new_value, so.ObjectShell) node.commands.append( qlast.SetPointerType( value=utils.typeref_to_ast(schema, op.new_value), ), ) else: super()._apply_field_ast(schema, context, node, op) def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: if self.maybe_get_object_aux_data('from_alias'): # This is an alias type, appropriate DDL would be generated # from the corresponding Alter/DeleteAlias node. return None else: return super()._get_ast(schema, context, parent_node=parent_node) class DeleteProperty( PropertyCommand, pointers.DeletePointer[Property], ): astnode = [qlast.DropConcreteProperty, qlast.DropProperty] referenced_astnode = qlast.DropConcreteProperty def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: if self.maybe_get_object_aux_data('from_alias'): # This is an alias type, appropriate DDL would be generated # from the corresponding Alter/DeleteAlias node. return None else: return super()._get_ast(schema, context, parent_node=parent_node) ================================================ FILE: edb/schema/pseudo.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Optional, TypeVar, TYPE_CHECKING from edb import errors from edb.common import parsing from edb.edgeql import ast as qlast from edb.edgeql import qltypes from . import delta as sd from . import name as sn from . import objects as so from . import types as s_types if TYPE_CHECKING: from . import schema as s_schema PseudoType_T = TypeVar("PseudoType_T", bound="PseudoType") class PseudoType( so.InheritingObject, s_types.Type, qlkind=qltypes.SchemaObjectClass.PSEUDO_TYPE, ): @classmethod def get( cls, schema: s_schema.Schema, name: str | sn.Name, ) -> PseudoType: return schema.get_global(PseudoType, name) def as_shell(self, schema: s_schema.Schema) -> PseudoTypeShell: return PseudoTypeShell(name=self.get_name(schema)) def get_bases( self, schema: s_schema.Schema, ) -> so.ObjectList[PseudoType]: return so.ObjectList[PseudoType].create_empty() # type: ignore def get_ancestors( self, schema: s_schema.Schema, ) -> so.ObjectList[PseudoType]: return so.ObjectList[PseudoType].create_empty() # type: ignore def get_abstract(self, schema: s_schema.Schema) -> bool: return True def is_polymorphic(self, schema: s_schema.Schema) -> bool: return True def material_type( self, schema: s_schema.Schema, ) -> tuple[s_schema.Schema, PseudoType]: return schema, self def is_any(self, schema: s_schema.Schema) -> bool: return str(self.get_name(schema)) == 'anytype' def is_anytuple(self, schema: s_schema.Schema) -> bool: return str(self.get_name(schema)) == 'anytuple' def is_anyobject(self, schema: s_schema.Schema) -> bool: return 
str(self.get_name(schema)) == 'anyobject' def is_tuple(self, schema: s_schema.Schema) -> bool: return self.is_anytuple(schema) def implicitly_castable_to( self, other: s_types.Type, schema: s_schema.Schema ) -> bool: return self == other def find_common_implicitly_castable_type( self, other: s_types.Type, schema: s_schema.Schema, ) -> tuple[s_schema.Schema, Optional[PseudoType]]: if self == other: return schema, self else: return schema, None def get_common_parent_type_distance( self, other: s_types.Type, schema: s_schema.Schema ) -> int: if self == other: return 0 else: return s_types.MAX_TYPE_DISTANCE def _test_polymorphic( self, schema: s_schema.Schema, other: s_types.Type ) -> bool: return self == other def _to_nonpolymorphic( self, schema: s_schema.Schema, concrete_type: s_types.Type ) -> tuple[s_schema.Schema, s_types.Type]: return schema, concrete_type def _resolve_polymorphic( self, schema: s_schema.Schema, concrete_type: s_types.Type ) -> Optional[s_types.Type]: if self.is_any(schema): return concrete_type if self.is_anyobject(schema): if ( not concrete_type.is_object_type() or concrete_type.is_polymorphic(schema) ): return None else: return concrete_type elif self.is_anytuple(schema): if (not concrete_type.is_tuple(schema) or concrete_type.is_polymorphic(schema)): return None else: return concrete_type else: raise ValueError( f'unexpected pseudo type: {self.get_name(schema)}') class PseudoTypeShell(s_types.TypeShell[PseudoType]): def __init__( self, *, name: sn.Name, span: Optional[parsing.Span] = None, ) -> None: super().__init__( name=name, schemaclass=PseudoType, span=span ) def is_polymorphic(self, schema: s_schema.Schema) -> bool: return True def resolve(self, schema: s_schema.Schema) -> PseudoType: return PseudoType.get(schema, self.name) class PseudoTypeCommandContext(sd.ObjectCommandContext[PseudoType]): pass class PseudoTypeCommand( s_types.TypeCommand[PseudoType], context_class=PseudoTypeCommandContext, ): pass class 
CreatePseudoType(PseudoTypeCommand, sd.CreateObject[PseudoType]): astnode = qlast.CreatePseudoType @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: if not context.stdmode and not context.testmode: raise errors.UnsupportedFeatureError( 'user-defined pseudo types are not supported', span=astnode.span ) return super()._cmd_tree_from_ast(schema, astnode, context) ================================================ FILE: edb/schema/referencing.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Callable, ClassVar, Optional, TypeVar, AbstractSet, Iterable, cast, ) import hashlib from edb import errors from edb.common import struct from edb.edgeql import ast as qlast from . import delta as sd from . import inheriting from . import objects as so from . import schema as s_schema from . import name as sn from . import utils ReferencedT = TypeVar('ReferencedT', bound='ReferencedObject') ReferencedInheritingObjectT = TypeVar('ReferencedInheritingObjectT', bound='ReferencedInheritingObject') # Q: There are no ReferencedObjects that aren't ReferencedInheritingObject; # should we merge them? 
class ReferencedObject(so.DerivableObject): #: True if the object has an explicit definition and is not #: purely inherited. owned = so.SchemaField( bool, default=False, inheritable=False, compcoef=0.909, reflection_method=so.ReflectionMethod.AS_LINK, special_ddl_syntax=True, ) @classmethod def get_verbosename_static( cls, name: sn.Name, *, parent: Optional[str] = None, ) -> str: clsname = cls.get_schema_class_displayname() dname = cls.get_displayname_static(name) sn = cls.get_shortname_static(name) if sn == name: clsname = f'abstract {clsname}' if parent is not None: return f"{clsname} '{dname}' of {parent}" else: return f"{clsname} '{dname}'" def get_subject(self, schema: s_schema.Schema) -> Optional[so.Object]: # NB: classes that inherit ReferencedObject define a `get_subject` # method dynamically, with `subject = SchemaField` raise NotImplementedError def get_referrer(self, schema: s_schema.Schema) -> Optional[so.Object]: return self.get_subject(schema) def get_verbosename( self, schema: s_schema.Schema, *, with_parent: bool = False, ) -> str: vn = super().get_verbosename(schema) if with_parent: return self.add_parent_name(vn, schema) return vn def add_parent_name( self, base_name: str, schema: s_schema.Schema, ) -> str: subject = self.get_subject(schema) if subject is not None: pn = subject.get_verbosename(schema, with_parent=True) return f'{base_name} of {pn}' return base_name def init_parent_delta_branch( self: ReferencedT, schema: s_schema.Schema, context: sd.CommandContext, *, referrer: Optional[so.Object] = None, ) -> tuple[ sd.CommandGroup, sd.Command, sd.ContextStack, ]: root, parent, ctx_stack = super().init_parent_delta_branch( schema, context, referrer=referrer) if referrer is None: referrer = self.get_referrer(schema) if referrer is None: return root, parent, ctx_stack obj: Optional[so.Object] = referrer object_stack: list[so.Object] = [referrer] while obj is not None: if isinstance(obj, ReferencedObject): obj = obj.get_referrer(schema) if obj is 
not None: object_stack.append(obj) else: obj = None cmd: sd.Command = parent for obj in reversed(object_stack): alter_cmd = obj.init_delta_command(schema, sd.AlterObject) ctx_stack.push(alter_cmd.new_context(schema, context, obj)) cmd.add(alter_cmd) cmd = alter_cmd return root, cmd, ctx_stack def is_parent_ref( self, schema: s_schema.Schema, reference: so.Object, ) -> bool: """Return True if *reference* is a structural ancestor of self""" obj = self.get_referrer(schema) while obj is not None: if obj == reference: return True elif isinstance(obj, ReferencedObject): obj = obj.get_referrer(schema) else: break return False class ReferencedInheritingObject( so.DerivableInheritingObject, ReferencedObject, ): # Indicates that the object has been declared as # explicitly inherited. declared_overloaded = so.SchemaField( bool, default=False, compcoef=None, inheritable=False, ephemeral=True, ) def should_propagate(self, schema: s_schema.Schema) -> bool: """Whether this object should be propagated to subtypes of the owner""" return True def get_implicit_bases( self: ReferencedInheritingObjectT, schema: s_schema.Schema, ) -> list[ReferencedInheritingObjectT]: return [ b for b in self.get_bases(schema).objects(schema) if not b.is_non_concrete(schema) ] def get_implicit_ancestors( self: ReferencedInheritingObjectT, schema: s_schema.Schema, ) -> list[ReferencedInheritingObjectT]: return [ b for b in self.get_ancestors(schema).objects(schema) if not b.is_non_concrete(schema) ] def get_name_impacting_ancestors( self: ReferencedInheritingObjectT, schema: s_schema.Schema, ) -> list[ReferencedInheritingObjectT]: """Return ancestors that have an impact on the name of this object. For most types this is the same as implicit ancestors. (For constraints it is not.) 
""" return self.get_implicit_ancestors(schema) def is_endpoint_pointer(self, schema: s_schema.Schema) -> bool: # overloaded by Pointer return False def as_delete_delta( self: ReferencedInheritingObjectT, *, schema: s_schema.Schema, context: so.ComparisonContext, ) -> sd.ObjectCommand[ReferencedInheritingObjectT]: del_op = super().as_delete_delta(schema=schema, context=context) if ( self.get_owned(schema) and not self.is_generated(schema) and any( context.is_deleting(schema, ancestor) for ancestor in self.get_implicit_ancestors(schema) ) ): owned_op = self.init_delta_command(schema, AlterOwned) owned_op.set_attribute_value('owned', False, orig_value=True) del_op.add(owned_op) return del_op def record_field_alter_delta( self: ReferencedInheritingObjectT, schema: s_schema.Schema, delta: sd.ObjectCommand[ReferencedInheritingObjectT], context: so.ComparisonContext, *, fname: str, value: Any, orig_value: Any, orig_schema: s_schema.Schema, orig_object: ReferencedInheritingObjectT, confidence: float, ) -> None: super().record_field_alter_delta( schema, delta, context, fname=fname, value=value, orig_value=orig_value, orig_schema=orig_schema, orig_object=orig_object, confidence=confidence, ) if fname == 'name': if any( context.is_renaming(orig_schema, ancestor) for ancestor in orig_object.get_name_impacting_ancestors( orig_schema) ): renames = delta.get_subcommands(type=sd.RenameObject) assert len(renames) == 1 rename = renames[0] rename.set_annotation('implicit_propagation', True) def derive_ref( self: ReferencedInheritingObjectT, schema: s_schema.Schema, referrer: so.QualifiedObject, *qualifiers: str, mark_derived: bool = False, attrs: Optional[dict[str, Any]] = None, dctx: Optional[sd.CommandContext] = None, derived_name_base: Optional[sn.Name] = None, inheritance_merge: bool = True, inheritance_refdicts: Optional[AbstractSet[str]] = None, transient: bool = False, preserve_endpoint_ptrs: bool = False, name: Optional[sn.QualName] = None, **kwargs: Any, ) -> 
tuple[s_schema.Schema, ReferencedInheritingObjectT]: if name is None: derived_name = self.get_derived_name( schema, referrer, *qualifiers, mark_derived=mark_derived, derived_name_base=derived_name_base, ) else: derived_name = name if self.get_name(schema) == derived_name: raise errors.SchemaError( f'cannot derive {self!r}({derived_name}) from itself') derived_attrs: dict[str, object] = {} if attrs is not None: derived_attrs.update(attrs) derived_attrs['name'] = derived_name derived_attrs['bases'] = so.ObjectList.create(schema, [self]) mcls = type(self) referrer_class = type(referrer) refdict = referrer_class.get_refdict_for_class(mcls) reftype = referrer_class.get_field(refdict.attr).type refname = reftype.get_key_for_name(schema, derived_name) refcoll = referrer.get_field_value(schema, refdict.attr) existing = refcoll.get(schema, refname, default=None) cmdcls = sd.AlterObject if existing is not None else sd.CreateObject cmd: sd.ObjectCommand[ReferencedInheritingObjectT] = ( sd.get_object_delta_command( # type: ignore[type-var, assignment] objtype=type(self), cmdtype=cmdcls, schema=schema, name=derived_name, ) ) for k, v in derived_attrs.items(): cmd.set_attribute_value(k, v) if existing is not None: new_bases = derived_attrs['bases'] old_bases = existing.get_bases(schema) if new_bases != old_bases: assert isinstance(new_bases, so.ObjectList) removed_bases, added_bases = inheriting.delta_bases( [b.get_name(schema) for b in old_bases.objects(schema)], [b.get_name(schema) for b in new_bases.objects(schema)], t=type(self), ) rebase_cmd = sd.get_object_delta_command( objtype=type(self), cmdtype=inheriting.RebaseInheritingObject, schema=schema, name=derived_name, added_bases=added_bases, removed_bases=removed_bases, ) cmd.add(rebase_cmd) context = sd.CommandContext(modaliases={}, schema=schema) delta, parent_cmd, _ = self.init_parent_delta_branch( schema, context, referrer=referrer) root = sd.DeltaRoot() root.add(delta) with context(sd.DeltaRootContext(schema=schema, 
op=root)): if not inheritance_merge: context.current().inheritance_merge = False if inheritance_refdicts is not None: context.current().inheritance_refdicts = ( inheritance_refdicts) if mark_derived: context.current().mark_derived = True if transient: context.current().transient_derivation = True if not preserve_endpoint_ptrs: context.current().slim_links = True parent_cmd.add(cmd) schema = delta.apply(schema, context) derived = schema.get(derived_name, type=type(self)) return schema, derived class ReferencedObjectCommandBase(sd.QualifiedObjectCommand[ReferencedT]): _referrer_context_class: ClassVar[Optional[ type[sd.ObjectCommandContext[so.Object]] ]] = None #: Whether the referenced command represents a "strong" reference, #: i.e. the one that must not be broken out of the enclosing parent #: command when doing dependency reorderings. is_strong_ref = struct.Field(bool, default=False) def __init_subclass__( cls, *, referrer_context_class: Optional[ type[sd.ObjectCommandContext[so.Object]] ] = None, **kwargs: Any, ) -> None: super().__init_subclass__(**kwargs) if referrer_context_class is not None: cls._referrer_context_class = referrer_context_class @classmethod def get_referrer_context_class( cls, ) -> type[sd.ObjectCommandContext[so.Object]]: if cls._referrer_context_class is None: raise TypeError( f'referrer_context_class is not defined for {cls}') return cls._referrer_context_class @classmethod def get_referrer_context( cls, context: sd.CommandContext, ) -> Optional[sd.ObjectCommandContext[so.Object]]: """Get the context of the command for the referring object, if any. E.g. for a `create/alter/etc concrete link` command this would be the context of the `create/alter/etc type` command. 
""" ctxcls = cls.get_referrer_context_class() return context.get(ctxcls) @classmethod def get_referrer_context_or_die( cls, context: sd.CommandContext, ) -> sd.ObjectCommandContext[so.Object]: ctx = cls.get_referrer_context(context) if ctx is None: raise RuntimeError(f'no referrer context for {cls}') return ctx def get_top_referrer_op( self, context: sd.CommandContext, ) -> Optional[sd.ObjectCommand[so.Object]]: op: sd.ObjectCommand[so.Object] = self # type: ignore while True: if not isinstance(op, ReferencedObjectCommandBase): break ctx = op.get_referrer_context(context) if ctx is None: break op = ctx.op return op class ReferencedObjectCommand(ReferencedObjectCommandBase[ReferencedT]): @classmethod def _classname_from_ast_and_referrer( cls, schema: s_schema.Schema, referrer_name: sn.QualName, astnode: qlast.ObjectDDL, context: sd.CommandContext ) -> sn.QualName: base_ref = utils.ast_to_object_shell( astnode.name, modaliases=context.modaliases, schema=schema, metaclass=cls.get_schema_metaclass(), ) base_name = sn.shortname_from_fullname(base_ref.name) quals = cls._classname_quals_from_ast( schema, astnode, base_name, referrer_name, context) pnn = sn.get_specialized_name(base_name, str(referrer_name), *quals) return sn.QualName(name=pnn, module=referrer_name.module) @classmethod def _classname_from_ast( cls, schema: s_schema.Schema, astnode: qlast.ObjectDDL, context: sd.CommandContext, ) -> sn.QualName: parent_ctx = cls.get_referrer_context(context) if parent_ctx is not None: assert isinstance(parent_ctx.op, sd.QualifiedObjectCommand) referrer_name = context.get_referrer_name(parent_ctx) name = cls._classname_from_ast_and_referrer( schema, referrer_name, astnode, context ) else: name = super()._classname_from_ast(schema, astnode, context) assert isinstance(name, sn.QualName) return name @classmethod def _classname_from_name( cls, name: sn.QualName, referrer_name: sn.QualName, ) -> sn.QualName: base_name = sn.shortname_from_fullname(name) quals = 
cls._classname_quals_from_name(name) pnn = sn.get_specialized_name(base_name, str(referrer_name), *quals) return sn.QualName(name=pnn, module=referrer_name.module) @classmethod def _classname_quals_from_ast( cls, schema: s_schema.Schema, astnode: qlast.ObjectDDL, base_name: sn.Name, referrer_name: sn.QualName, context: sd.CommandContext, ) -> tuple[str, ...]: return () @classmethod def _classname_quals_from_name( cls, name: sn.QualName, ) -> tuple[str, ...]: return () @classmethod def _name_qual_from_exprs( cls, schema: s_schema.Schema, exprs: Iterable[str] ) -> str: m = hashlib.sha1() for expr in exprs: m.update(expr.encode()) return m.hexdigest() def _get_ast_node( self, schema: s_schema.Schema, context: sd.CommandContext ) -> type[qlast.DDLOperation]: subject_ctx = self.get_referrer_context(context) ref_astnode: Optional[type[qlast.DDLOperation]] = ( getattr(self, 'referenced_astnode', None)) if subject_ctx is not None and ref_astnode is not None: return ref_astnode else: if isinstance(self.astnode, (list, tuple)): return self.astnode[1] else: return self.astnode class CreateReferencedObject( ReferencedObjectCommand[ReferencedT], sd.CreateObject[ReferencedT], ): referenced_astnode: ClassVar[type[qlast.ObjectDDL]] @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: cmd = super()._cmd_tree_from_ast(schema, astnode, context) if isinstance(astnode, cls.referenced_astnode): objcls = cls.get_schema_metaclass() referrer_ctx = cls.get_referrer_context_or_die(context) referrer_class = referrer_ctx.op.get_schema_metaclass() referrer_name = context.get_referrer_name(referrer_ctx) refdict = referrer_class.get_refdict_for_class(objcls) cmd.set_attribute_value( refdict.backref_attr, so.ObjectShell( name=referrer_name, schemaclass=referrer_class, ), ) cmd.set_attribute_value('owned', True) if getattr(astnode, 'abstract', None): cmd.set_attribute_value('abstract', True) return cmd def 
_get_ast_node( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> type[qlast.DDLOperation]: # Render CREATE as ALTER in DDL if the created referenced object is # implicitly inherited from parents. scls = self.get_object(schema, context) assert isinstance(scls, ReferencedInheritingObject) implicit_bases = scls.get_implicit_bases(schema) if ( implicit_bases and not context.declarative and not self.ast_ignore_ownership() ): alter = scls.init_delta_command(schema, sd.AlterObject) return alter._get_ast_node(schema, context) else: return super()._get_ast_node(schema, context) @classmethod def as_inherited_ref_cmd( cls, *, schema: s_schema.Schema, context: sd.CommandContext, astnode: qlast.ObjectDDL, bases: list[ReferencedT], referrer: so.Object, ) -> sd.ObjectCommand[ReferencedT]: cmd = cls(classname=cls._classname_from_ast(schema, astnode, context)) cmd.set_attribute_value('name', cmd.classname) cmd.set_attribute_value( 'bases', so.ObjectList.create(schema, bases).as_shell(schema)) return cmd @classmethod def as_inherited_ref_ast( cls, schema: s_schema.Schema, context: sd.CommandContext, refname: sn.Name, parent: ReferencedObject, ) -> qlast.ObjectDDL: # N.B: If this is overloaded, then as_inherited_ref_cmd # probably needs to be overloaded too. # In particular, any fields that are inherited=False in the schema # but actually need to be inherited (like arguments for constraints # and indexes) likely should be handled there. 
nref = cls.get_inherited_ref_name(schema, context, parent, refname) astnode_cls = cls.referenced_astnode astnode = astnode_cls(name=nref) assert isinstance(astnode, qlast.ObjectDDL) return astnode @classmethod def get_inherited_ref_name( cls, schema: s_schema.Schema, context: sd.CommandContext, parent: ReferencedObject, refname: sn.Name, ) -> qlast.ObjectRef: ref = utils.name_to_ast_ref(refname) if ref.module is None: ref.module = parent.get_shortname(schema).module return ref def _create_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._create_begin(schema, context) referrer_ctx = self.get_referrer_context(context) if referrer_ctx is not None: referrer = referrer_ctx.scls referrer_cls = type(referrer) mcls = type(self.scls) refdict = referrer_cls.get_refdict_for_class(mcls) schema = referrer.add_classref(schema, refdict.attr, self.scls) return schema class DeleteReferencedObjectCommand( ReferencedObjectCommand[ReferencedT], sd.DeleteObject[ReferencedT], ): def _delete_innards( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._delete_innards(schema, context) referrer_ctx = self.get_referrer_context(context) if referrer_ctx is not None: referrer = referrer_ctx.scls schema = self._delete_ref(schema, context, referrer) return schema def _delete_ref( self, schema: s_schema.Schema, context: sd.CommandContext, referrer: so.Object, ) -> s_schema.Schema: scls = self.scls referrer_class = type(referrer) mcls = type(scls) refdict = referrer_class.get_refdict_for_class(mcls) reftype = referrer_class.get_field(refdict.attr).type refname = reftype.get_key_for(schema, self.scls) return referrer.del_classref(schema, refdict.attr, refname) class ReferencedInheritingObjectCommand( ReferencedObjectCommand[ReferencedInheritingObjectT], inheriting.InheritingObjectCommand[ReferencedInheritingObjectT], ): def _get_implicit_ref_bases( self, schema: s_schema.Schema, context: 
sd.CommandContext, referrer: so.InheritingObject, referrer_field: str, fq_name: sn.QualName, ) -> list[ReferencedInheritingObjectT]: ref_field_type = type(referrer).get_field(referrer_field).type assert isinstance(referrer, so.QualifiedObject) child_referrer_bases = referrer.get_bases(schema).objects(schema) implicit_bases = [] for ref_base in child_referrer_bases: fq_name_in_child = self._classname_from_name( fq_name, ref_base.get_name(schema)) refname = ref_field_type.get_key_for_name(schema, fq_name_in_child) parent_coll = ref_base.get_field_value(schema, referrer_field) parent_item = parent_coll.get(schema, refname, default=None) if ( parent_item is not None and parent_item.should_propagate(schema) and not context.is_deleting(parent_item) ): implicit_bases.append(parent_item) return implicit_bases def get_ref_implicit_base_delta( self, schema: s_schema.Schema, context: sd.CommandContext, refcls: ReferencedInheritingObjectT, implicit_bases: list[ReferencedInheritingObjectT], ) -> inheriting.BaseDelta_T[ReferencedInheritingObjectT]: child_bases = refcls.get_bases(schema).objects(schema) default_base = refcls.get_default_base_name() explicit_bases = [ b for b in child_bases if b.is_non_concrete(schema) and b.get_name(schema) != default_base ] new_bases = implicit_bases + explicit_bases return inheriting.delta_bases( [b.get_name(schema) for b in child_bases], [b.get_name(schema) for b in new_bases], t=type(refcls), ) def _validate( self, schema: s_schema.Schema, context: sd.CommandContext ) -> None: scls = self.scls implicit_bases = [ b for b in scls.get_bases(schema).objects(schema) if not b.is_non_concrete(schema) ] referrer_ctx = self.get_referrer_context_or_die(context) objcls = self.get_schema_metaclass() referrer_class = referrer_ctx.op.get_schema_metaclass() refdict = referrer_class.get_refdict_for_class(objcls) if context.declarative and scls.get_owned(schema): if (implicit_bases and refdict.requires_explicit_overloaded and not 
self.get_attribute_value('declared_overloaded')): ancestry = [] for obj in implicit_bases: bref = obj.get_referrer(schema) assert bref is not None ancestry.append(bref) alist = ", ".join( str(a.get_shortname(schema)) for a in ancestry ) raise errors.SchemaDefinitionError( f'{self.scls.get_verbosename(schema, with_parent=True)} ' f'must be declared using the `overloaded` keyword because ' f'it is defined in the following ancestor(s): {alist}', span=self.span, ) elif (not implicit_bases and self.get_attribute_value('declared_overloaded')): raise errors.SchemaDefinitionError( f'{self.scls.get_verbosename(schema, with_parent=True)}: ' f'cannot be declared `overloaded` as there are no ' f'ancestors defining it.', span=self.span, ) def get_implicit_bases( self, schema: s_schema.Schema, context: sd.CommandContext, bases: Any, ) -> list[sn.QualName]: mcls = self.get_schema_metaclass() default_base = mcls.get_default_base_name() if isinstance(bases, so.ObjectCollectionShell): base_names = [b.get_name(schema) for b in bases.items] elif isinstance(bases, so.ObjectList): base_names = list(bases.names(schema)) else: # assume regular iterable of shells base_names = [b.get_name(schema) for b in bases] # Filter out explicit bases implicit_bases = [ b for b in base_names if ( b != default_base and isinstance(b, sn.QualName) and sn.shortname_from_fullname(b) != b ) ] return implicit_bases def _propagate_ref_op( self, schema: s_schema.Schema, context: sd.CommandContext, scls: ReferencedInheritingObject, cb: Callable[[sd.ObjectCommand[so.Object], sn.Name], None] ) -> None: if inheriting._has_implicit_propagation(context): return referrer_ctx = self.get_referrer_context(context) if referrer_ctx: referrer = referrer_ctx.scls referrer_class = type(referrer) mcls = type(scls) refdict = referrer_class.get_refdict_for_class(mcls) reftype = referrer_class.get_field(refdict.attr).type refname = reftype.get_key_for(schema, self.scls) else: refname = self.scls.get_name(schema) for descendant in 
scls.ordered_descendants(schema): d_alter_root, d_alter_cmd, ctx_stack = ( descendant.init_delta_branch(schema, context, sd.AlterObject)) d_alter_cmd.set_annotation('implicit_propagation', True) with ctx_stack(): cb(d_alter_cmd, refname) self.add_caused(d_alter_root) def _propagate_ref_field_alter_in_inheritance( self, schema: s_schema.Schema, context: sd.CommandContext, field_name: str, require_inheritance_consistency: bool = True, ) -> None: """Validate and propagate a field alteration to children. This method also performs consistency checks against base objects to ensure that the new value matches that of the parents. """ scls = self.scls currently_altered = context.change_log[type(scls), field_name] currently_altered.add(scls) if require_inheritance_consistency: implicit_bases = scls.get_implicit_bases(schema) non_altered_bases = [] value = scls.get_field_value(schema, field_name) for base in { x for x in implicit_bases if x not in currently_altered}: base_value = base.get_field_value(schema, field_name) if isinstance(value, so.SubclassableObject): if not value.issubclass(schema, base_value): non_altered_bases.append(base) else: if value != base_value: non_altered_bases.append(base) # This object is inherited from one or more ancestors that # are not altered in the same op, and this is an error. 
if non_altered_bases: bases_str = ', '.join( b.get_verbosename(schema, with_parent=True) for b in non_altered_bases ) vn = scls.get_verbosename(schema, with_parent=True) desc = self.get_friendly_description( schema=schema, object_desc=f'inherited {vn}', ) raise errors.SchemaDefinitionError( f'cannot {desc}', details=( f'{vn} is inherited from ' f'{bases_str}' ), span=self.span, ) value = self.get_attribute_value(field_name) def _propagate( alter_cmd: sd.ObjectCommand[so.Object], refname: sn.Name, ) -> None: assert isinstance(alter_cmd, sd.QualifiedObjectCommand) s_t: sd.ObjectCommand[ReferencedInheritingObjectT] if isinstance(self, sd.AlterSpecialObjectField): s_t = self.clone(alter_cmd.classname) else: s_t = type(self)(classname=alter_cmd.classname) orig_value = scls.get_explicit_field_value( schema, field_name, default=None) s_t.set_attribute_value( field_name, value, orig_value=orig_value, inherited=True, ) alter_cmd.add(s_t) self._propagate_ref_op(schema, context, scls, cb=_propagate) def _drop_owned_refs( self, schema: s_schema.Schema, context: sd.CommandContext, refdict: so.RefDict, ) -> s_schema.Schema: scls = self.scls refs = scls.get_field_value(schema, refdict.attr) ref: ReferencedInheritingObject for ref in refs.objects(schema): inherited = ref.get_implicit_bases(schema) if inherited and ref.get_owned(schema): alter = ref.init_delta_command(schema, sd.AlterObject) alter.set_attribute_value('owned', False, orig_value=True) schema = alter.apply(schema, context) self.add(alter) elif ( # drop things that aren't owned and aren't inherited not inherited # endpoint pointers are special because they aren't marked as # inherited even though they basically are and not ref.is_endpoint_pointer(schema) ): drop_ref = ref.init_delta_command(schema, sd.DeleteObject) self.add(drop_ref) return schema class CreateReferencedInheritingObject( CreateReferencedObject[ReferencedInheritingObjectT], inheriting.CreateInheritingObject[ReferencedInheritingObjectT], 
ReferencedInheritingObjectCommand[ReferencedInheritingObjectT], ): def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: refctx = type(self).get_referrer_context(context) if refctx is not None: if self.get_attribute_value('from_alias'): return None elif ( not self.get_attribute_value('owned') and not self.ast_ignore_ownership() ): if context.descriptive_mode: astnode = super()._get_ast( schema, context, parent_node=parent_node, ) assert astnode is not None inherited_from = [ sn.quals_from_fullname(b)[0] for b in self.get_implicit_bases( schema, context, self.get_attribute_value('bases'), ) ] astnode.system_comment = ( f'inherited from {", ".join(inherited_from)}' ) return astnode else: return None else: astnode = super()._get_ast( schema, context, parent_node=parent_node) if context.declarative: scls = self.get_object(schema, context) assert isinstance(scls, ReferencedInheritingObject) implicit_bases = scls.get_implicit_bases(schema) objcls = self.get_schema_metaclass() referrer_class = refctx.op.get_schema_metaclass() refdict = referrer_class.get_refdict_for_class(objcls) if refdict.requires_explicit_overloaded and implicit_bases: assert isinstance(astnode, qlast.CreateConcretePointer) astnode.declared_overloaded = True return astnode else: return super()._get_ast(schema, context, parent_node=parent_node) def _create_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: referrer_ctx = self.get_referrer_context(context) implicit_bases = None if referrer_ctx is not None and not context.canonical: objcls = self.get_schema_metaclass() referrer = referrer_ctx.scls if isinstance(referrer, so.InheritingObject): referrer_class = referrer_ctx.op.get_schema_metaclass() refdict = referrer_class.get_refdict_for_class(objcls) implicit_bases = self._get_implicit_ref_bases( schema, context, referrer, refdict.attr, self.classname) if 
implicit_bases: bases = self.get_attribute_value('bases') if bases: res_bases = cast( list[ReferencedInheritingObjectT], self.resolve_obj_collection(bases, schema)) bases = so.ObjectList.create( schema, implicit_bases + [ b for b in res_bases if b not in implicit_bases ], ) else: bases = so.ObjectList.create( schema, implicit_bases, ) self.set_attribute_value('bases', bases.as_shell(schema)) if referrer.get_is_derived(schema): self.set_attribute_value('is_derived', True) return super()._create_begin(schema, context) def _create_innards( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: if ( not context.canonical and context.enable_recursion and (referrer_ctx := self.get_referrer_context(context)) and isinstance(referrer := referrer_ctx.scls, so.InheritingObject) and self.scls.should_propagate(schema) ): # Propagate the creation of a new ref to # descendants of our referrer. self._propagate_ref_creation(schema, context, referrer) return super()._create_innards(schema, context) def _create_finalize( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._create_finalize(schema, context) if not context.canonical: referrer_ctx = self.get_referrer_context(context) if referrer_ctx is not None: self._validate(schema, context) return schema def _propagate_ref_creation( self, schema: s_schema.Schema, context: sd.CommandContext, referrer: so.InheritingObject, ) -> None: get_cmd = sd.get_object_command_class_or_die mcls = type(self.scls) referrer_cls = type(referrer) ref_create_cmd = get_cmd(sd.CreateObject, mcls) ref_alter_cmd = get_cmd(sd.AlterObject, mcls) ref_rebase_cmd = get_cmd(inheriting.RebaseInheritingObject, mcls) assert issubclass(ref_create_cmd, CreateReferencedInheritingObject) assert issubclass(ref_rebase_cmd, RebaseReferencedInheritingObject) refdict = referrer_cls.get_refdict_for_class(mcls) parent_fq_refname = self.scls.get_name(schema) for child in referrer.children(schema): if not 
child.allow_ref_propagation(schema, context, refdict): continue alter_root, alter, ctx_stack = child.init_delta_branch( schema, context, sd.AlterObject) with ctx_stack(): # This is needed to get the correct inherited name which will # either be created or rebased. ref_field_type = type(child).get_field(refdict.attr).type refname = ref_field_type.get_key_for_name( schema, parent_fq_refname) astnode = ref_create_cmd.as_inherited_ref_ast( schema, context, refname, self.scls) fq_name = self._classname_from_ast(schema, astnode, context) # We cannot check for ref existence in this child at this # time, because it might get created in a sibling branch # of the delta tree. Instead, generate a command group # containing Alter(if_exists) and Create(if_not_exists) # to postpone that check until the application time. ref_create = ref_create_cmd.as_inherited_ref_cmd( schema=schema, context=context, astnode=astnode, bases=[self.scls], referrer=child, ) assert isinstance(ref_create, sd.CreateObject) ref_create.if_not_exists = True # Copy any special updates over for special in self.get_subcommands( type=sd.AlterSpecialObjectField): ref_create.add(special.clone(ref_create.classname)) ref_create.set_attribute_value(refdict.backref_attr, child) if child.get_is_derived(schema): # All references in a derived object must # also be marked as derived, to be consistent # with derive_subtype(). 
ref_create.set_attribute_value('is_derived', True) ref_alter = ref_alter_cmd(classname=fq_name, if_exists=True) ref_alter.add(ref_rebase_cmd( classname=fq_name, implicit=True, added_bases=(), removed_bases=(), )) alter.add(ref_alter) alter.add(ref_create) self.add_caused(alter_root) class AlterReferencedInheritingObject( ReferencedInheritingObjectCommand[ReferencedInheritingObjectT], inheriting.AlterInheritingObject[ReferencedInheritingObjectT], ): def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: if self.get_attribute_value('from_alias'): return None else: return super()._get_ast(schema, context, parent_node=parent_node) @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: cmd = super()._cmd_tree_from_ast(schema, astnode, context) refctx = cls.get_referrer_context(context) # When a referenced object is altered it becomes "owned" # by the referrer, _except_ when either an explicit # SET OWNED/DROP OWNED subcommand is present, or # _all_ subcommands are `RESET` subcommands. 
if ( refctx is not None and qlast.get_ddl_field_command(astnode, 'owned') is None and ( not cmd.get_subcommands() or not all( ( isinstance(scmd, sd.AlterObjectProperty) and scmd.new_value is None ) for scmd in cmd.get_subcommands() ) ) ): cmd.set_attribute_value('owned', True) assert isinstance(cmd, AlterReferencedInheritingObject) return cmd def _alter_finalize( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._alter_finalize(schema, context) scls = self.scls was_owned = scls.get_owned(context.current().original_schema) now_owned = scls.get_owned(schema) if not was_owned and now_owned: self._validate(schema, context) return schema class RebaseReferencedInheritingObject( ReferencedInheritingObjectCommand[ReferencedInheritingObjectT], inheriting.RebaseInheritingObject[ReferencedInheritingObjectT], ): implicit = struct.Field(bool, default=False) def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: if not context.canonical and self.implicit: mcls = self.get_schema_metaclass() refctx = self.get_referrer_context_or_die(context) referrer = refctx.scls assert isinstance(referrer, so.InheritingObject) refdict = type(referrer).get_refdict_for_class(mcls) implicit_bases = self._get_implicit_ref_bases( schema, context, referrer=referrer, referrer_field=refdict.attr, fq_name=self.classname, ) scls = self.get_object(schema, context) removed_bases, added_bases = self.get_ref_implicit_base_delta( schema, context, scls, implicit_bases=implicit_bases, ) self.added_bases = added_bases self.removed_bases = removed_bases return super().apply(schema, context) def _get_bases_for_ast( self, schema: s_schema.Schema, context: sd.CommandContext, bases: tuple[so.ObjectShell[ReferencedInheritingObjectT], ...], ) -> tuple[so.ObjectShell[ReferencedInheritingObjectT], ...]: bases = super()._get_bases_for_ast(schema, context, bases) implicit_bases = set(self.get_implicit_bases(schema, context, bases)) return 
tuple(b for b in bases if b.name not in implicit_bases) class RenameReferencedInheritingObject( ReferencedInheritingObjectCommand[ReferencedInheritingObjectT], inheriting.RenameInheritingObject[ReferencedInheritingObjectT], ): def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: orig_schema = schema schema = super()._alter_begin(schema, context) scls = self.scls referrer_ctx = self.get_referrer_context(context) if referrer_ctx: mcls = self.get_schema_metaclass() referrer_class = referrer_ctx.op.get_schema_metaclass() refdict = referrer_class.get_refdict_for_class(mcls) reftype = referrer_class.get_field(refdict.attr).type # Force a refresh of the refdict, since the rename may # have invalidated its cache of names. referrer = referrer_ctx.scls schema = referrer.refresh_classref(schema, refdict.attr) if not context.canonical and not scls.is_non_concrete(schema): assert referrer_ctx orig_ref_fqname = scls.get_name(orig_schema) orig_ref_lname = reftype.get_key_for_name(schema, orig_ref_fqname) new_ref_fqname = scls.get_name(schema) new_ref_lname = reftype.get_key_for_name(schema, new_ref_fqname) # Distinguish between actual local name change and fully-qualified # name change due to structural parent rename. if orig_ref_lname != new_ref_lname: implicit_bases = scls.get_implicit_bases(orig_schema) non_renamed_bases = { x for x in implicit_bases if x not in context.renamed_objs} # This object is inherited from one or more ancestors that # are not renamed in the same op, and this is an error. 
if non_renamed_bases: bases_str = ', '.join( b.get_verbosename(schema, with_parent=True) for b in non_renamed_bases ) verb = 'are' if len(non_renamed_bases) > 1 else 'is' vn = scls.get_verbosename(orig_schema, with_parent=True) raise errors.SchemaDefinitionError( f'cannot rename inherited {vn}', details=( f'{vn} is inherited from ' f'{bases_str}, which {verb} not being renamed' ), span=self.span, ) self._propagate_ref_rename(schema, context, scls) return schema def _propagate_ref_rename( self, schema: s_schema.Schema, context: sd.CommandContext, scls: ReferencedInheritingObject ) -> None: rename_cmdcls = sd.get_object_command_class_or_die( sd.RenameObject, type(scls)) def _ref_rename(alter_cmd: sd.Command, refname: sn.Name) -> None: astnode = rename_cmdcls.astnode( # type: ignore new_name=utils.name_to_ast_ref(refname), ) rename_cmd = rename_cmdcls._rename_cmd_from_ast( schema, astnode, context) alter_cmd.add(rename_cmd) self._propagate_ref_op(schema, context, scls, cb=_ref_rename) def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: if self.get_annotation('implicit_propagation'): return None else: return super()._get_ast(schema, context, parent_node=parent_node) class DeleteReferencedInheritingObject( DeleteReferencedObjectCommand[ReferencedInheritingObjectT], inheriting.DeleteInheritingObject[ReferencedInheritingObjectT], ReferencedInheritingObjectCommand[ReferencedInheritingObjectT], ): def _delete_innards( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: if ( not context.canonical and (referrer_ctx := self.get_referrer_context(context)) and isinstance(referrer := referrer_ctx.scls, so.InheritingObject) and self.scls.should_propagate(schema) ): self._propagate_ref_deletion(schema, context, referrer) return super()._delete_innards(schema, context) def _propagate_ref_deletion( self, schema: s_schema.Schema, context: 
sd.CommandContext, referrer: so.InheritingObject, ) -> None: scls = self.scls self_name = scls.get_name(schema) referrer_class = type(referrer) mcls = type(scls) refdict = referrer_class.get_refdict_for_class(mcls) reftype = referrer_class.get_field(refdict.attr).type if ( not context.in_deletion(offset=1) and not context.disable_dep_verification ): implicit_bases = set(self._get_implicit_ref_bases( schema, context, referrer, refdict.attr, self_name)) if implicit_bases: # Cannot remove inherited objects. vn = scls.get_verbosename(schema, with_parent=True) parents = [ b.get_field_value(schema, refdict.backref_attr) for b in implicit_bases ] pnames = '\n- '.join( p.get_verbosename(schema, with_parent=True) for p in parents ) raise errors.SchemaError( f'cannot drop inherited {vn}', span=self.span, details=f'{vn} is inherited from:\n- {pnames}' ) # Sort the children by reverse inheritance order amongst them. # So if we are T and have children A and B and A <: B, we want to # process A first, since we need to rebase it away from T, and then # dropping A will also drop B. 
for child in reversed( sd.sort_by_inheritance(schema, referrer.children(schema)) ): assert isinstance(child, so.QualifiedObject) child_coll = child.get_field_value(schema, refdict.attr) fq_refname_in_child = self._classname_from_name( self_name, child.get_name(schema), ) child_refname = reftype.get_key_for_name( schema, fq_refname_in_child) existing = child_coll.get(schema, child_refname, None) if existing is not None: alter_root, alter_leaf, ctx_stack = ( existing.init_parent_delta_branch( schema, context, referrer=child)) with ctx_stack(): cmd = self._propagate_child_ref_deletion( schema, context, refdict, child, existing) alter_leaf.add(cmd) self.add_caused(alter_root) def _propagate_child_ref_deletion( self, schema: s_schema.Schema, context: sd.CommandContext, refdict: so.RefDict, child: so.InheritingObject, child_ref: ReferencedInheritingObjectT, ) -> sd.Command: name = child_ref.get_name(schema) implicit_bases = self._get_implicit_ref_bases( schema, context, child, refdict.attr, name) cmd: sd.Command if child_ref.get_owned(schema) or implicit_bases: # Child is either defined locally or is inherited # from another parent, so we need to do a rebase. removed_bases, added_bases = self.get_ref_implicit_base_delta( schema, context, child_ref, implicit_bases) rebase_cmd = child_ref.init_delta_command( schema, inheriting.RebaseInheritingObject, added_bases=added_bases, removed_bases=removed_bases, ) cmd = child_ref.init_delta_command(schema, sd.AlterObject) cmd.add(rebase_cmd) else: # The ref in child should no longer exist. 
cmd = child_ref.init_delta_command(schema, sd.DeleteObject) return cmd def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: refctx = type(self).get_referrer_context(context) if ( refctx is not None and not self.get_orig_attribute_value('owned') ): return None else: return super()._get_ast(schema, context, parent_node=parent_node) class AlterOwned( ReferencedInheritingObjectCommand[ReferencedInheritingObjectT], inheriting.AlterInheritingObjectFragment[ReferencedInheritingObjectT], sd.AlterSpecialObjectField[ReferencedInheritingObjectT], ): _delta_action = 'alterowned' def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: orig_schema = schema schema = super()._alter_begin(schema, context) scls = self.scls orig_owned = scls.get_owned(orig_schema) owned = scls.get_owned(schema) if ( orig_owned != owned and not owned and not context.canonical ): implicit_bases = scls.get_implicit_bases(schema) if not implicit_bases: # ref isn't actually inherited, so cannot be un-owned vn = scls.get_verbosename(schema, with_parent=True) sn = type(scls).get_schema_class_displayname().upper() raise errors.InvalidDefinitionError( f'cannot drop owned {vn}, as it is not inherited, ' f'use DROP {sn} instead', span=self.span, ) # DROP OWNED requires special handling: the object in question # must revert all modifications made on top of inherited attributes. bases = scls.get_bases(schema).objects(schema) schema = self.inherit_fields( schema, context, bases, ignore_local=True, ) for refdict in type(scls).get_refdicts(): schema = self._drop_owned_refs(schema, context, refdict) return schema class NamedReferencedInheritingObject( ReferencedInheritingObject, ): """A referenced inheriting object that has an explicit local name.
That is, things like pointers, access policies, and triggers, which are referenced by another object but have an explicitly specified name. """ @classmethod def get_displayname_static(cls, name: sn.Name) -> str: sn = cls.get_shortname_static(name) if sn.module == '__': return sn.name else: return str(sn) def get_derived_name_base( self, schema: s_schema.Schema, ) -> sn.QualName: shortname = self.get_shortname(schema) return sn.QualName(module='__', name=shortname.name) class NamedReferencedInheritingObjectCommand( ReferencedInheritingObjectCommand[ReferencedInheritingObjectT], ): # XXX: Do we want different namespaces for different kinds of objects? @classmethod def _classname_from_ast( cls, schema: s_schema.Schema, astnode: qlast.ObjectDDL, context: sd.CommandContext, ) -> sn.QualName: referrer_ctx = cls.get_referrer_context(context) if referrer_ctx is not None: referrer_name = context.get_referrer_name(referrer_ctx) shortname = sn.QualName(module='__', name=astnode.name.name) name = sn.QualName( module=referrer_name.module, name=sn.get_specialized_name(shortname, str(referrer_name)), ) else: name = super()._classname_from_ast(schema, astnode, context) return name def _deparse_name( self, schema: s_schema.Schema, context: sd.CommandContext, name: sn.Name, ) -> qlast.ObjectRef: ref = super()._deparse_name(schema, context, name) referrer_ctx = self.get_referrer_context(context) if referrer_ctx is None: return ref else: ref.module = '' return ref ================================================ FILE: edb/schema/reflection/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from .reader import parse_schema, SchemaClassLayout from .structure import generate_structure from .structure import SchemaTypeLayout, SchemaReflectionParts from .writer import generate_metadata_write_edgeql __all__ = ( 'generate_structure', 'generate_metadata_write_edgeql', 'parse_schema', 'SchemaTypeLayout', 'SchemaClassLayout', 'SchemaReflectionParts' ) ================================================ FILE: edb/schema/reflection/reader.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2020-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import Any, Callable import collections import functools import json import uuid import immutables from edb.common import checked from edb.common import verutils from edb.common import uuidgen from edb.schema import abc as s_abc from edb.schema import expr as s_expr from edb.schema import functions as s_func from edb.schema import name as s_name from edb.schema import objects as s_obj from edb.schema import operators as s_oper from edb.schema import schema as s_schema from edb.schema import version as s_ver from . import structure as sr_struct SchemaClassLayout = dict[type[s_obj.Object], sr_struct.SchemaTypeLayout] def parse_schema( base_schema: s_schema.Schema, data: str | bytes, schema_class_layout: SchemaClassLayout, ) -> s_schema.FlatSchema: """Parse JSON-encoded schema objects and populate the schema with them. Args: base_schema: A schema instance to use as a starting point. data: JSON-encoded schema object data as returned by an introspection query. schema_class_layout: A mapping describing schema class layout in the reflection, as returned from :func:`schema.reflection.structure.generate_structure`. Returns: A schema instance including objects encoded in the provided JSON sequence.
""" id_to_type = {} id_to_data = {} name_to_id = {} shortname_to_id = collections.defaultdict(set) globalname_to_id = {} dict_of_dicts: Callable[ [], dict[tuple[type[s_obj.Object], str], dict[uuid.UUID, None]], ] = functools.partial(collections.defaultdict, dict) refs_to: dict[ uuid.UUID, dict[tuple[type[s_obj.Object], str], dict[uuid.UUID, None]] ] = collections.defaultdict(dict_of_dicts) objects: dict[uuid.UUID, tuple[s_obj.Object, dict[str, Any]]] = {} objid: uuid.UUID for entry in json.loads(data): _, _, clsname = entry['_tname'].rpartition('::') mcls = s_obj.ObjectMeta.maybe_get_schema_class(clsname) if mcls is None: raise ValueError( f'unexpected type in schema reflection: {clsname}') objid = uuidgen.UUID(entry['id']) objects[objid] = (mcls._create_from_id(objid), entry) refdict_updates = {} for objid, (obj, entry) in objects.items(): mcls = type(obj) name = s_name.name_from_string(entry['name__internal']) layout = schema_class_layout[mcls] if ( base_schema.has_object(objid) and not isinstance(obj, s_ver.BaseSchemaVersion) ): continue if isinstance(obj, s_obj.QualifiedObject): name_to_id[name] = objid else: name = s_name.UnqualName(str(name)) globalname_to_id[mcls, name] = objid if isinstance(obj, (s_func.Function, s_oper.Operator)): shortname = mcls.get_shortname_static(name) shortname_to_id[mcls, shortname].add(objid) id_to_type[objid] = type(obj).__name__ all_fields = mcls.get_schema_fields() objdata: list[Any] = [None] * len(all_fields) val: Any refid: uuid.UUID for k, v in entry.items(): desc = layout.get(k) if desc is None: continue fn = desc.fieldname field = all_fields.get(fn) if field is None: continue findex = field.index if desc.storage is not None: if v is None: pass elif desc.storage.ptrkind == 'link': refid = uuidgen.UUID(v['id']) newobj = objects.get(refid) if newobj is not None: val = newobj[0] else: val = base_schema.get_by_id(refid) objdata[findex] = val.schema_reduce() refs_to[val.id][mcls, fn][objid] = None elif desc.storage.ptrkind == 
'multi link': ftype = mcls.get_field(fn).type if issubclass(ftype, s_obj.ObjectDict): refids = ftype._container( uuidgen.UUID(e['value']) for e in v) refkeys = tuple(e['name'] for e in v) val = ftype(refids, refkeys, _private_init=True) else: refids = ftype._container( uuidgen.UUID(e['id']) for e in v) val = ftype(refids, _private_init=True) objdata[findex] = val.schema_reduce() for refid in refids: refs_to[refid][mcls, fn][objid] = None elif desc.storage.shadow_ptrkind: val = entry[f'{k}__internal'] ftype = mcls.get_field(fn).type if val is not None and type(val) is not ftype: if issubclass(ftype, s_expr.Expression): val = _parse_expression(val, objid, k) for refid in val.refs.ids(): refs_to[refid][mcls, fn][objid] = None elif issubclass(ftype, s_expr.ExpressionList): exprs = [] for e_dict in val: e = _parse_expression(e_dict, objid, k) assert e.refs is not None for refid in e.refs.ids(): refs_to[refid][mcls, fn][objid] = None exprs.append(e) val = ftype(exprs) elif issubclass(ftype, s_expr.ExpressionDict): expr_dict = dict() for e_dict in val: e = _parse_expression( e_dict['expr'], objid, k) assert e.refs is not None for refid in e.refs.ids(): refs_to[refid][mcls, fn][objid] = None expr_dict[e_dict['name']] = e val = ftype(expr_dict) elif issubclass(ftype, s_obj.Object): val = val.id elif issubclass(ftype, s_name.Name): if isinstance(obj, s_obj.QualifiedObject): val = s_name.name_from_string(val) else: val = s_name.UnqualName(val) else: val = ftype(val) if issubclass(ftype, s_abc.Reducible): val = val.schema_reduce() objdata[findex] = val else: ftype = mcls.get_field(fn).type if type(v) is not ftype: if issubclass(ftype, verutils.Version): objdata[findex] = _parse_version(v) elif issubclass(ftype, s_name.Name): objdata[findex] = s_name.name_from_string(v) elif ( issubclass(ftype, checked.ParametricContainer) and ftype.types and len(ftype.types) == 1 ): # Coerce the elements in a parametric container # type. # XXX: Or should we do it in the container? 
subtyp = ftype.types[0] objdata[findex] = ftype( subtyp(x) for x in v) # type: ignore else: objdata[findex] = ftype(v) else: objdata[findex] = v elif desc.is_refdict: ftype = mcls.get_field(fn).type refids = ftype._container(uuidgen.UUID(e['id']) for e in v) for refid in refids: refs_to[refid][mcls, fn][objid] = None val = ftype(refids, _private_init=True) objdata[findex] = val.schema_reduce() if desc.properties: for e_dict in v: refdict_updates[uuidgen.UUID(e_dict['id'])] = { p: pv for p in desc.properties if (pv := e_dict[f'@{p}']) is not None } id_to_data[objid] = tuple(objdata) for objid, updates in refdict_updates.items(): if updates: sclass = s_obj.ObjectMeta.get_schema_class(id_to_type[objid]) updated_data = list(id_to_data[objid]) for fn, v in updates.items(): field = sclass.get_schema_field(fn) updated_data[field.index] = v id_to_data[objid] = tuple(updated_data) refs_to_im = {} for referred_id, ref_data in refs_to.items(): refs_to_im[referred_id] = immutables.Map(( (k, immutables.Map(r)) for k, r in ref_data.items() )) return s_schema.FlatSchema()._replace( id_to_type=immutables.Map(id_to_type), id_to_data=immutables.Map(id_to_data), name_to_id=immutables.Map(name_to_id), shortname_to_id=immutables.Map({ (k, frozenset(v)) for k, v in shortname_to_id.items() }), globalname_to_id=immutables.Map(globalname_to_id), refs_to=immutables.Map(refs_to_im), ) def _parse_expression( val: dict[str, Any], id: uuid.UUID, field: str ) -> s_expr.Expression: refids = frozenset( uuidgen.UUID(r) for r in val['refs'] ) expr = s_expr.Expression( text=val['text'], refs=s_obj.ObjectSet( refids, _private_init=True, ), ) expr.set_origin(id, field) return expr def _parse_version(val: dict[str, Any]) -> verutils.Version: return verutils.Version( major=val['major'], minor=val['minor'], stage=getattr(verutils.VersionStage, val['stage'].upper()), stage_no=val['stage_no'], local=tuple(val['local']), ) ================================================ FILE: 
edb/schema/reflection/structure.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any, Optional, Sequence, NamedTuple import collections import uuid from edb.common import adapter from edb.common import checked from edb.common import enum from edb.common import verutils from edb.edgeql import qltypes from edb.schema import ddl as s_ddl from edb.schema import delta as sd from edb.schema import expr as s_expr from edb.schema import inheriting as s_inh from edb.schema import links as s_links from edb.schema import name as sn from edb.schema import objects as s_obj from edb.schema import objtypes as s_objtypes from edb.schema import schema as s_schema from edb.schema import types as s_types class FieldType(enum.StrEnum): """Field type tag for fields requiring special handling.""" #: An Expression field. EXPR = 'EXPR' #: An ExpressionList field. EXPR_LIST = 'EXPR_LIST' #: An ExpressionDict field. EXPR_DICT = 'EXPR_DICT' #: An ObjectDict field. OBJ_DICT = 'OBJ_DICT' #: All other field types. OTHER = 'OTHER' class FieldStorage(NamedTuple): """Schema object field storage descriptor.""" #: Field type specifying special handling, if necessary. 
fieldtype: FieldType #: Pointer kind (property or link) and cardinality (single or multi) ptrkind: str #: Fully-qualified pointer target type. ptrtype: str #: Shadow pointer kind, if any. shadow_ptrkind: Optional[str] = None #: Shadow pointer type, if any. shadow_ptrtype: Optional[str] = None class SchemaFieldDesc(NamedTuple): """Schema object field descriptor.""" type: s_types.Type cardinality: qltypes.SchemaCardinality properties: dict[str, tuple[s_types.Type, FieldType]] fieldname: str schema_fieldname: str is_ordered: bool = False reflection_proxy: Optional[tuple[str, str]] = None storage: Optional[FieldStorage] = None is_refdict: bool = False # N.B: Indexed by schema_fieldname SchemaTypeLayout = dict[str, SchemaFieldDesc] class SchemaReflectionParts(NamedTuple): intro_schema_delta: sd.Command class_layout: dict[type[s_obj.Object], SchemaTypeLayout] local_intro_parts: list[str] global_intro_parts: list[str] def _run_ddl( ddl_text: str, *, schema: s_schema.Schema, delta: sd.Command, ) -> s_schema.Schema: schema, cmd = s_ddl.apply_ddl_script_ex( ddl_text, schema=schema, stdmode=True, internal_schema_mode=True, ) delta.update(cmd.get_subcommands()) return schema def _classify_object_field(field: s_obj.Field[Any]) -> FieldStorage: """Determine FieldStorage for a given schema class field.""" ftype = field.type shadow_ptr_kind = None shadow_ptr_type = None fieldtype = FieldType.OTHER is_array = is_multiprop = False if issubclass(ftype, s_obj.MultiPropSet): is_multiprop = True ftype = ftype.type elif ( issubclass( ftype, (checked.CheckedList, checked.FrozenCheckedList, checked.CheckedSet, checked.FrozenCheckedSet)) and not issubclass(ftype, s_expr.ExpressionList) ): is_array = True ftype = ftype.type # type: ignore if issubclass(ftype, s_obj.ObjectCollection): ptr_kind = 'multi link' ptr_type = 'schema::Object' if issubclass(ftype, s_obj.ObjectDict): fieldtype = FieldType.OBJ_DICT elif issubclass(ftype, s_obj.Object): ptr_kind = 'link' ptr_type = 
f'schema::{ftype.__name__}' elif issubclass(ftype, s_expr.Expression): shadow_ptr_kind = 'property' shadow_ptr_type = 'tuple<text: str, refs: array<uuid>>' ptr_kind = 'property' ptr_type = 'str' fieldtype = FieldType.EXPR elif issubclass(ftype, s_expr.ExpressionList): shadow_ptr_kind = 'property' shadow_ptr_type = ( 'array<tuple<text: str, refs: array<uuid>>>' ) ptr_kind = 'property' ptr_type = 'array<str>' fieldtype = FieldType.EXPR_LIST elif issubclass(ftype, s_expr.ExpressionDict): shadow_ptr_kind = 'property' shadow_ptr_type = '''array<tuple< name: str, expr: tuple<text: str, refs: array<uuid>> >>''' ptr_kind = 'property' ptr_type = 'array<tuple<name: str, expr: str>>' fieldtype = FieldType.EXPR_DICT elif issubclass(ftype, collections.abc.Mapping): ptr_kind = 'property' ptr_type = 'json' elif issubclass(ftype, (str, sn.Name)): ptr_kind = 'property' ptr_type = 'str' if field.name == 'name': # TODO: consider shadow-reflecting names as tuples shadow_ptr_kind = 'property' shadow_ptr_type = 'str' elif issubclass(ftype, bool): ptr_kind = 'property' ptr_type = 'bool' elif issubclass(ftype, int): ptr_kind = 'property' ptr_type = 'int64' elif issubclass(ftype, uuid.UUID): ptr_kind = 'property' ptr_type = 'uuid' elif issubclass(ftype, verutils.Version): ptr_kind = 'property' ptr_type = ''' tuple< major: std::int64, minor: std::int64, stage: sys::VersionStage, stage_no: std::int64, local: array<std::str>, > ''' else: raise RuntimeError( f'no metaschema reflection for field {field.name} of type {ftype}' ) if is_multiprop: ptr_kind = 'multi property' if is_array: ptr_type = f'array<{ptr_type}>' return FieldStorage( fieldtype=fieldtype, ptrkind=ptr_kind, ptrtype=ptr_type, shadow_ptrkind=shadow_ptr_kind, shadow_ptrtype=shadow_ptr_type, ) def get_schema_name_for_pycls(py_cls: type[s_obj.Object]) -> sn.Name: py_cls_name = py_cls.__name__ if issubclass(py_cls, s_obj.GlobalObject): # Global objects, like Role and Database live in the sys:: module return sn.QualName(module='sys', name=py_cls_name) else: return sn.QualName(module='schema', name=py_cls_name) def get_default_base_for_pycls(py_cls: type[s_obj.Object]) -> sn.Name: if
issubclass(py_cls, s_obj.GlobalObject): # Global objects, like Role and Database, live in the sys:: module return sn.QualName(module='sys', name='SystemObject') else: return sn.QualName(module='schema', name='Object') def generate_structure( schema: s_schema.Schema, *, make_funcs: bool=True, patch_level: int=2**30, ) -> SchemaReflectionParts: """Generate schema reflection structure from Python schema classes. If specified, patch_level is the "patch level" of the patch currently being applied during a minor version upgrade. All schema objects with higher patch levels will be skipped, to avoid adding things created by later patches prematurely. Returns: A quadruple (as a SchemaReflectionParts instance) containing: - Delta, which, when applied to stdlib, yields an enhanced version of the `schema` module that contains all types and properties, not just those that are publicly exposed for introspection. - A mapping, containing type layout description for all schema classes. - A sequence of EdgeQL queries necessary to introspect a database schema. - A sequence of EdgeQL queries necessary to introspect global objects, such as roles and databases.
""" delta = sd.DeltaRoot() classlayout: dict[ type[s_obj.Object], SchemaTypeLayout, ] = {} ordered_link = schema.get('schema::ordered', type=s_links.Link) if make_funcs: schema = _run_ddl( ''' CREATE FUNCTION sys::_get_pg_type_for_edgedb_type( typeid: std::uuid, kind: std::str, elemid: OPTIONAL std::uuid, sql_type: OPTIONAL std::str, ) -> std::int64 { USING SQL FUNCTION 'edgedb.get_pg_type_for_edgedb_type'; SET volatility := 'STABLE'; SET impl_is_strict := false; }; CREATE FUNCTION sys::_expr_from_json( data: json ) -> OPTIONAL tuple> { USING SQL $$ SELECT "data"->>'text' AS text, coalesce(r.refs, ARRAY[]::uuid[]) AS refs FROM (SELECT array_agg(v::uuid) AS refs FROM jsonb_array_elements_text("data"->'refs') AS v ) AS r WHERE jsonb_typeof("data") != 'null' $$; SET volatility := 'IMMUTABLE'; }; # A strictly-internal get config function that bypasses # the redaction of secrets in the public-facing one. CREATE FUNCTION cfg::_get_config_json_internal( NAMED ONLY sources: OPTIONAL array = {}, NAMED ONLY max_source: OPTIONAL std::str = {} ) -> std::json { USING SQL $$ SELECT coalesce(jsonb_object_agg(cfg.name, cfg), '{}'::jsonb) FROM edgedb_VER._read_sys_config( sources::edgedb._sys_config_source_t[], max_source::edgedb._sys_config_source_t ) AS cfg $$; }; ''', schema=schema, delta=delta, ) py_classes = [] for py_cls in s_obj.ObjectMeta.get_schema_metaclasses(): if isinstance(py_cls, adapter.Adapter): continue if py_cls is s_obj.GlobalObject: continue if py_cls._patch_level > patch_level: continue py_classes.append(py_cls) read_sets: dict[type[s_obj.Object], list[str]] = {} for py_cls in py_classes: rschema_name = get_schema_name_for_pycls(py_cls) schema_objtype = schema.get( rschema_name, type=s_objtypes.ObjectType, default=None, ) bases = [] for base in py_cls.__bases__: if base in py_classes: bases.append(get_schema_name_for_pycls(base)) default_base = get_default_base_for_pycls(py_cls) if not bases and rschema_name != default_base: bases.append(default_base) 
reflection = py_cls.get_reflection_method() is_simple_wrapper = issubclass(py_cls, s_types.CollectionExprAlias) if schema_objtype is None: as_abstract = ( reflection is s_obj.ReflectionMethod.REGULAR and not is_simple_wrapper and ( py_cls is s_obj.InternalObject or not issubclass(py_cls, s_obj.InternalObject) ) and py_cls._abstract is not False ) schema = _run_ddl( f''' CREATE {'ABSTRACT' if as_abstract else ''} TYPE {rschema_name} EXTENDING {', '.join(str(b) for b in bases)}; ''', schema=schema, delta=delta, ) schema_objtype = schema.get( rschema_name, type=s_objtypes.ObjectType) else: ex_bases = schema_objtype.get_bases(schema).names(schema) _, added_bases = s_inh.delta_bases( ex_bases, bases, t=type(schema_objtype), ) if added_bases: for subset, position in added_bases: # XXX: Don't generate changes for just moving around the # order of types when the mismatch between python and # the schema, since it doesn't work anyway and causes mass # grief when trying to patch the schema. subset = [x for x in subset if x.name not in ex_bases] if not subset: continue if isinstance(position, tuple): position_clause = ( f'{position[0]} {position[1].name}' ) else: position_clause = position bases_expr = ', '.join(str(t.name) for t in subset) stmt = f''' ALTER TYPE {rschema_name} {{ EXTENDING {bases_expr} {position_clause} }} ''' schema = _run_ddl( stmt, schema=schema, delta=delta, ) if reflection is s_obj.ReflectionMethod.NONE: continue referrers = py_cls.get_referring_classes() if reflection is s_obj.ReflectionMethod.AS_LINK: if not referrers: raise RuntimeError( f'schema class {py_cls.__name__} is declared with AS_LINK ' f'reflection method but is not referenced in any RefDict' ) is_concrete = not schema_objtype.get_abstract(schema) if ( is_concrete and not is_simple_wrapper and any( not b.get_abstract(schema) for b in schema_objtype.get_ancestors(schema).objects(schema) ) ): raise RuntimeError( f'non-abstract {schema_objtype.get_verbosename(schema)} has ' f'non-abstract 
ancestors' ) read_shape = read_sets[py_cls] = [] if is_concrete: read_shape.append( '_tname := .__type__[IS schema::ObjectType].name' ) classlayout[py_cls] = {} ownfields = py_cls.get_ownfields() for fn, field in py_cls.get_fields().items(): if field.patch_level > patch_level: continue sfn = field.sname if ( field.ephemeral or ( field.reflection_method is not s_obj.ReflectionMethod.REGULAR ) ): continue storage = _classify_object_field(field) ptr = schema_objtype.maybe_get_ptr(schema, sn.UnqualName(sfn)) if fn in ownfields: qual = "REQUIRED" if field.required else "OPTIONAL" otd = " { ON TARGET DELETE ALLOW }" if field.weak_ref else "" if ptr is None: schema = _run_ddl( f''' ALTER TYPE {rschema_name} {{ CREATE {qual} {storage.ptrkind} {sfn} -> {storage.ptrtype} {otd}; }} ''', schema=schema, delta=delta, ) ptr = schema_objtype.getptr(schema, sn.UnqualName(fn)) if storage.shadow_ptrkind is not None: pn = f'{sfn}__internal' internal_ptr = schema_objtype.maybe_get_ptr( schema, sn.UnqualName(pn)) if internal_ptr is None: ptrkind = storage.shadow_ptrkind ptrtype = storage.shadow_ptrtype schema = _run_ddl( f''' ALTER TYPE {rschema_name} {{ CREATE {qual} {ptrkind} {pn} -> {ptrtype}; }} ''', schema=schema, delta=delta, ) else: assert ptr is not None if is_concrete: read_ptr = sfn if field.type_is_generic_self: read_ptr = f'{read_ptr}[IS {rschema_name}]' if field.reflection_proxy: _proxy_type, proxy_link = field.reflection_proxy read_ptr = ( f'{read_ptr}: {{name, value := .{proxy_link}.id}}' ) if ptr.issubclass(schema, ordered_link): read_ptr = f'{read_ptr} ORDER BY @index' read_shape.append(read_ptr) if storage.shadow_ptrkind is not None: read_shape.append(f'{sfn}__internal') if field.reflection_proxy: proxy_type_name, proxy_link_name = field.reflection_proxy proxy_obj = schema.get( proxy_type_name, type=s_objtypes.ObjectType) proxy_link_obj = proxy_obj.getptr( schema, sn.UnqualName(proxy_link_name)) tgt = proxy_link_obj.get_target(schema) else: tgt = ptr.get_target(schema) 
assert tgt is not None cardinality = ptr.get_cardinality(schema) assert cardinality is not None classlayout[py_cls][sfn] = SchemaFieldDesc( fieldname=fn, schema_fieldname=sfn, type=tgt, cardinality=cardinality, properties={}, storage=storage, is_ordered=ptr.issubclass(schema, ordered_link), reflection_proxy=field.reflection_proxy, ) # Second pass: deal with RefDicts, which are reflected as links. for py_cls in py_classes: rschema_name = get_schema_name_for_pycls(py_cls) schema_cls = schema.get(rschema_name, type=s_objtypes.ObjectType) for refdict in py_cls.get_own_refdicts().values(): if py_cls.get_field(refdict.attr).patch_level > patch_level: continue ref_ptr = schema_cls.maybe_get_ptr( schema, sn.UnqualName(refdict.attr)) ref_cls = refdict.ref_cls assert issubclass(ref_cls, s_obj.Object) shadow_ref_ptr = None reflect_as_link = ( ref_cls.get_reflection_method() is s_obj.ReflectionMethod.AS_LINK ) if reflect_as_link: reflection_link = ref_cls.get_reflection_link() assert reflection_link is not None target_field = ref_cls.get_field(reflection_link) target_cls = target_field.type shadow_pn = f'{refdict.attr}__internal' shadow_ref_ptr = schema_cls.maybe_get_ptr( schema, sn.UnqualName(shadow_pn)) if reflect_as_link and not shadow_ref_ptr: schema = _run_ddl( f''' ALTER TYPE {rschema_name} {{ CREATE OPTIONAL MULTI LINK {shadow_pn} EXTENDING schema::reference -> {get_schema_name_for_pycls(ref_cls)} {{ ON TARGET DELETE ALLOW; }}; }} ''', schema=schema, delta=delta, ) shadow_ref_ptr = schema_cls.getptr( schema, sn.UnqualName(shadow_pn)) else: target_cls = ref_cls if ref_ptr is None: ptr_type = get_schema_name_for_pycls(target_cls) schema = _run_ddl( f''' ALTER TYPE {rschema_name} {{ CREATE OPTIONAL MULTI LINK {refdict.attr} EXTENDING schema::reference -> {ptr_type} {{ ON TARGET DELETE ALLOW; }}; }} ''', schema=schema, delta=delta, ) ref_ptr = schema_cls.getptr( schema, sn.UnqualName(refdict.attr)) assert isinstance(ref_ptr, s_links.Link) if py_cls not in classlayout: 
classlayout[py_cls] = {} # First, fields declared to be reflected as link properties. props = _get_reflected_link_props( ref_ptr=ref_ptr, target_cls=ref_cls, schema=schema, ) if reflect_as_link: # Then, because it's a passthrough reflection, all scalar # fields of the proxy object. fields_as_props = [ f for f in ref_cls.get_ownfields().values() if ( not f.ephemeral and ( f.reflection_method is not s_obj.ReflectionMethod.AS_LINK ) and f.name != refdict.backref_attr and f.name != ref_cls.get_reflection_link() ) ] extra_props = _classify_scalar_object_fields(fields_as_props) for field, storage in {**props, **extra_props}.items(): sfn = field.sname prop_ptr = ref_ptr.maybe_get_ptr(schema, sn.UnqualName(sfn)) if prop_ptr is None: pty = storage.ptrtype schema = _run_ddl( f''' ALTER TYPE {rschema_name} {{ ALTER LINK {refdict.attr} {{ CREATE OPTIONAL PROPERTY {sfn} -> {pty}; }} }} ''', schema=schema, delta=delta, ) if shadow_ref_ptr is not None: assert isinstance(shadow_ref_ptr, s_links.Link) shadow_pn = shadow_ref_ptr.get_shortname(schema).name for field, storage in props.items(): sfn = field.sname prop_ptr = shadow_ref_ptr.maybe_get_ptr( schema, sn.UnqualName(sfn)) if prop_ptr is None: pty = storage.ptrtype schema = _run_ddl( f''' ALTER TYPE {rschema_name} {{ ALTER LINK {shadow_pn} {{ CREATE OPTIONAL PROPERTY {sfn} -> {pty}; }} }} ''', schema=schema, delta=delta, ) for py_cls in py_classes: rschema_name = get_schema_name_for_pycls(py_cls) schema_cls = schema.get(rschema_name, type=s_objtypes.ObjectType) is_concrete = not schema_cls.get_abstract(schema) read_shape = read_sets[py_cls] for refdict in py_cls.get_refdicts(): if py_cls.get_field(refdict.attr).patch_level > patch_level: continue if py_cls not in classlayout: classlayout[py_cls] = {} ref_ptr = schema_cls.getptr( schema, sn.UnqualName(refdict.attr), type=s_links.Link) assert ref_ptr tgt = ref_ptr.get_target(schema) assert tgt is not None cardinality = ref_ptr.get_cardinality(schema) assert cardinality is not None 
classlayout[py_cls][refdict.attr] = SchemaFieldDesc( fieldname=refdict.attr, schema_fieldname=refdict.attr, type=tgt, cardinality=cardinality, properties={}, is_ordered=ref_ptr.issubclass(schema, ordered_link), reflection_proxy=None, is_refdict=True, ) target_cls = refdict.ref_cls props = _get_reflected_link_props( ref_ptr=ref_ptr, target_cls=target_cls, schema=schema, ) reflect_as_link = ( target_cls.get_reflection_method() is s_obj.ReflectionMethod.AS_LINK ) prop_layout = {} extra_prop_layout = {} for field, storage in props.items(): prop_ptr = ref_ptr.getptr(schema, sn.UnqualName(field.sname)) prop_tgt = prop_ptr.get_target(schema) assert prop_tgt is not None prop_layout[field.name] = (prop_tgt, storage.fieldtype) if reflect_as_link: # Then, because it's a passthrough reflection, all scalar # fields of the proxy object. fields_as_props = [ f for f in target_cls.get_ownfields().values() if ( not f.ephemeral and ( f.reflection_method is not s_obj.ReflectionMethod.AS_LINK ) and f.name != refdict.backref_attr and f.name != target_cls.get_reflection_link() ) ] extra_props = _classify_scalar_object_fields(fields_as_props) for field, storage in extra_props.items(): prop_ptr = ref_ptr.getptr( schema, sn.UnqualName(field.sname)) prop_tgt = prop_ptr.get_target(schema) assert prop_tgt is not None extra_prop_layout[field.name] = ( prop_tgt, storage.fieldtype) else: extra_prop_layout = {} classlayout[py_cls][refdict.attr].properties.update({ **prop_layout, **extra_prop_layout, }) if reflect_as_link: shadow_tgt = schema.get( get_schema_name_for_pycls(ref_cls), type=s_objtypes.ObjectType, ) iname = f'{refdict.attr}__internal' classlayout[py_cls][iname] = ( SchemaFieldDesc( fieldname=refdict.attr, schema_fieldname=iname, type=shadow_tgt, cardinality=qltypes.SchemaCardinality.Many, properties=prop_layout, is_refdict=True, ) ) if is_concrete: read_ptr = refdict.attr prop_shape_els = [] if reflect_as_link: read_ptr = f'{read_ptr}__internal' ref_ptr = schema_cls.getptr( schema, 
sn.UnqualName(f'{refdict.attr}__internal'), ) for field in props: sfn = field.sname prop_shape_els.append(f'@{sfn}') if prop_shape_els: prop_shape = ',\n'.join(prop_shape_els) read_ptr = f'{read_ptr}: {{id, {prop_shape}}}' if ref_ptr.issubclass(schema, ordered_link): read_ptr = f'{read_ptr} ORDER BY @index' read_shape.append(read_ptr) local_parts = [] global_parts = [] for py_cls, shape_els in read_sets.items(): if ( not shape_els # The CollectionExprAlias family needs to be excluded # because TupleExprAlias and ArrayExprAlias inherit from # concrete classes and so are picked up from those. or issubclass(py_cls, s_types.CollectionExprAlias) ): continue rschema_name = get_schema_name_for_pycls(py_cls) shape = ',\n'.join(shape_els) qry = f''' SELECT {rschema_name} {{ {shape} }} ''' if not issubclass(py_cls, (s_types.Collection, s_obj.GlobalObject)): qry += ' FILTER NOT .builtin' if issubclass(py_cls, s_obj.GlobalObject): global_parts.append(qry) else: local_parts.append(qry) delta.canonical = True return SchemaReflectionParts( intro_schema_delta=delta, class_layout=classlayout, local_intro_parts=local_parts, global_intro_parts=global_parts, ) def _get_reflected_link_props( *, ref_ptr: s_links.Link, target_cls: type[s_obj.Object], schema: s_schema.Schema, ) -> dict[s_obj.Field[Any], FieldStorage]: fields = [ f for f in target_cls.get_fields().values() if ( not f.ephemeral and ( f.reflection_method is s_obj.ReflectionMethod.AS_LINK ) ) ] return _classify_scalar_object_fields(fields) def _classify_scalar_object_fields( fields: Sequence[s_obj.Field[Any]], ) -> dict[s_obj.Field[Any], FieldStorage]: props = {} for field in fields: fn = field.name storage = _classify_object_field(field) if storage.ptrkind != 'property' and fn != 'id': continue props[field] = storage return props ================================================ FILE: edb/schema/reflection/writer.py ================================================ # # This source file is part of the EdgeDB open source 
project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Schema reflection helpers.""" from __future__ import annotations from typing import ( Any, Callable, Optional, Collection, cast, ) import functools import json import numbers import textwrap from edb.edgeql import qltypes from edb.schema import constraints as s_constr from edb.schema import delta as sd from edb.schema import extensions as s_ext from edb.schema import objects as so from edb.schema import objtypes as s_objtypes from edb.schema import referencing as s_ref from edb.schema import scalars as s_scalars from edb.schema import schema as s_schema from edb.schema import types as s_types from edb.schema.reflection import structure as sr_struct @functools.singledispatch def generate_metadata_write_edgeql( cmd: sd.Command, *, classlayout: dict[type[so.Object], sr_struct.SchemaTypeLayout], schema: s_schema.Schema, context: sd.CommandContext, blocks: list[tuple[str, dict[str, Any]]], internal_schema_mode: bool, stdmode: bool, ) -> None: _hoist_if_unused_deletes(cmd) return write_meta( cmd, classlayout=classlayout, schema=schema, context=context, blocks=blocks, internal_schema_mode=internal_schema_mode, stdmode=stdmode) def _hoist_if_unused_deletes( cmd: sd.Command, target: Optional[sd.DeleteObject[so.Object]] = None, ) -> None: """Hoist up if_unused deletes higher in the tree. 
if_unused deletes for things like union and collection types need to be done *after* the referring object that triggered the deletion is deleted. There is special handling in the write_meta case for DeleteObject to support this, but we need to also handle the case where the delete of the union/collection happens down in a nested delete of a child. Work around this by hoisting up if_unused to the outermost enclosing delete. (We can't just hoist to the actual toplevel, because that might move the command after something that needs to go *after*, like a delete of one of the union components.) Don't hoist the if_unused all the way *outside* an extension. We want the effects of deleting an extension to be contained in the DeleteExtension command. If there are union/collection types used outside this extension, they won't be deleted. If the union/collection types are used only by this extension, there is a chance that they also rely on the types *from* the extension. This means that it will be impossible to delete the base types if we defer deleting the union/collection types until all extension content is removed. FIXME: Could we instead *generate* the deletions at the outermost point? """ new_target = target if ( not new_target and isinstance(cmd, sd.DeleteObject) and not isinstance(cmd, s_ext.DeleteExtension) ): new_target = cmd for sub in cmd.get_subcommands(): if ( isinstance(sub, sd.DeleteObject) and target and sub.if_unused ): cmd.discard(sub) target.add_caused(sub) else: _hoist_if_unused_deletes(sub, new_target) @functools.singledispatch def write_meta( cmd: sd.Command, *, classlayout: dict[type[so.Object], sr_struct.SchemaTypeLayout], schema: s_schema.Schema, context: sd.CommandContext, blocks: list[tuple[str, dict[str, Any]]], internal_schema_mode: bool, stdmode: bool, ) -> None: """Generate EdgeQL statements populating schema metadata. Args: cmd: Delta command tree for which EdgeQL DML must be generated. 
classlayout: Schema class layout as returned from :func:`schema.reflection.structure.generate_structure`. schema: A schema instance. context: Delta context corresponding to *cmd*. blocks: A list where a sequence of (edgeql, args) tuples will be appended to. internal_schema_mode: If True, *cmd* represents internal `schema` modifications. stdmode: If True, *cmd* represents a standard library bootstrap DDL. """ raise NotImplementedError(f"cannot handle {cmd!r}") def _descend( cmd: sd.Command, *, classlayout: dict[type[so.Object], sr_struct.SchemaTypeLayout], schema: s_schema.Schema, context: sd.CommandContext, blocks: list[tuple[str, dict[str, Any]]], internal_schema_mode: bool, stdmode: bool, prerequisites: bool = False, cmd_filter: Optional[Callable[[sd.Command], bool]] = None, ) -> None: if prerequisites: commands = cmd.get_prerequisites() else: commands = cmd.get_subcommands(include_prerequisites=False) if cmd_filter: commands = tuple(filter(cmd_filter, commands)) def _write_subcommands(commands: Collection[sd.Command]) -> None: for subcmd in commands: if not isinstance(subcmd, sd.AlterObjectProperty): write_meta( subcmd, classlayout=classlayout, schema=schema, context=context, blocks=blocks, internal_schema_mode=internal_schema_mode, stdmode=stdmode ) ctxcls = cmd.get_context_class() if ctxcls is not None: if ( issubclass(ctxcls, sd.ObjectCommandContext) and isinstance(cmd, sd.ObjectCommand) ): objctxcls = cast( type[sd.ObjectCommandContext[so.Object]], ctxcls, ) ctx = objctxcls(schema=schema, op=cmd, scls=sd._dummy_object) else: # I could not find a way to convince mypy here. 
ctx = ctxcls(schema=schema, op=cmd) # type: ignore with context(ctx): _write_subcommands(commands) else: _write_subcommands(commands) @write_meta.register def write_meta_delta_root( cmd: sd.DeltaRoot, *, classlayout: dict[type[so.Object], sr_struct.SchemaTypeLayout], schema: s_schema.Schema, context: sd.CommandContext, blocks: list[tuple[str, dict[str, Any]]], internal_schema_mode: bool, stdmode: bool, ) -> None: _descend( cmd, classlayout=classlayout, schema=schema, context=context, blocks=blocks, internal_schema_mode=internal_schema_mode, stdmode=stdmode, ) def _build_object_mutation_shape( cmd: sd.ObjectCommand[so.Object], *, classlayout: dict[type[so.Object], sr_struct.SchemaTypeLayout], lprop_fields: Optional[ dict[str, tuple[s_types.Type, sr_struct.FieldType]] ] = None, lprops_only: bool = False, internal_schema_mode: bool, stdmode: bool, var_prefix: str = '', schema: s_schema.Schema, context: sd.CommandContext, ) -> tuple[str, dict[str, Any]]: props = cmd.get_resolved_attributes(schema, context) mcls = cmd.get_schema_metaclass() layout = classlayout[mcls] if lprop_fields is None: lprop_fields = {} # XXX: This is a hack around the fact that _update_lprops works by # removing all the links and recreating them. Since that will lose # data in situations where not every lprop attribute is specified, # merge AlterOwned props up into the enclosing command. (This avoids # trouble with annotations, which is the main place where we have # multiple interesting lprops at once.) 
if isinstance(cmd, s_ref.AlterOwned): return '', {} for sub in cmd.get_subcommands(type=s_ref.AlterOwned): props.update(sub.get_resolved_attributes(schema, context)) assignments = [] variables: dict[str, str] = {} if isinstance(cmd, sd.CreateObject): empties = { v.fieldname: None for f, v in layout.items() if ( f != 'backend_id' and v.storage is not None and v.storage.ptrkind != 'link' and v.storage.ptrkind != 'multi link' ) } all_props = {**empties, **props} else: all_props = props for n, v in sorted(all_props.items(), key=lambda i: i[0]): ns = mcls.get_field(n).sname lprop_target = lprop_fields.get(n) if lprop_target is not None: target, ftype = lprop_target cardinality = qltypes.SchemaCardinality.One is_ordered = False reflection_proxy = None elif lprops_only: continue else: layout_entry = layout.get(ns) if layout_entry is None: # The field is ephemeral, skip it. continue else: target = layout_entry.type cardinality = layout_entry.cardinality is_ordered = layout_entry.is_ordered reflection_proxy = layout_entry.reflection_proxy assert layout_entry.storage is not None ftype = layout_entry.storage.fieldtype target_value: Any var_n = f'__{var_prefix}{n}' if ( issubclass(mcls, s_constr.Constraint) and n == 'params' and isinstance(cmd, s_ref.ReferencedObjectCommand) and cmd.get_referrer_context(context) is not None ): # Constraint args are represented as a `@value` link property # on the `params` link. # TODO: replace this hack by a generic implementation of # an ObjectKeyDict collection that allow associating objects # with arbitrary values (a transposed ObjectDict). 
target_expr = f"""assert_distinct(( FOR v IN {{ enumerate(json_array_unpack(${var_n})) }} UNION ( SELECT {target.get_name(schema)} {{ @index := v.0, @value := v.1[1], }} FILTER .id = v.1[0] ) ))""" args = props.get('args', []) target_value = [] if v is not None: for i, param in enumerate(v.objects(schema)): if i == 0: # skip the implicit __subject__ parameter arg_expr = '' else: try: arg = args[i - 1] except IndexError: arg_expr = '' else: pkind = param.get_kind(schema) if pkind is qltypes.ParameterKind.VariadicParam: rest = [arg.text for arg in args[i - 1:]] arg_expr = f'[{",".join(rest)}]' else: arg_expr = arg.text target_value.append((str(param.id), arg_expr)) elif n == 'name': target_expr = f'${var_n}' assignments.append(f'{ns}__internal := ${var_n}__internal') if v is not None: target_value = mcls.get_displayname_static(v) variables[f'{var_n}__internal'] = json.dumps(str(v)) else: target_value = None variables[f'{var_n}__internal'] = json.dumps(None) elif isinstance(target, s_objtypes.ObjectType): if cardinality is qltypes.SchemaCardinality.Many: if ftype is sr_struct.FieldType.OBJ_DICT: target_expr, target_value = _reflect_object_dict_value( schema=schema, value=v, is_ordered=is_ordered, value_var_name=var_n, target=target, reflection_proxy=reflection_proxy, ) elif is_ordered: target_expr = f'''( FOR v IN {{ enumerate(assert_distinct( json_array_unpack(${var_n}) )) }} UNION ( SELECT (DETACHED {target.get_name(schema)}) {{ @index := v.0, }} FILTER .id = v.1 ) )''' if v is not None: target_value = [str(i) for i in v.ids()] else: target_value = [] else: target_expr = f'''( SELECT (DETACHED {target.get_name(schema)}) FILTER .id IN json_array_unpack(${var_n}) )''' if v is not None: target_value = [str(i) for i in v.ids()] else: target_value = [] else: target_expr = f'''( SELECT (DETACHED {target.get_name(schema)}) FILTER .id = ${var_n} )''' if v is not None: target_value = str(v.id) else: target_value = None elif ftype is sr_struct.FieldType.EXPR: target_expr = 
                f'${var_n}'
            if v is not None:
                target_value = v.text
            else:
                target_value = None

            shadow_target_expr = (
                f'sys::_expr_from_json(${var_n}_expr)'
            )

            assignments.append(f'{ns}__internal := {shadow_target_expr}')
            if v is not None:
                ids = [str(i) for i in v.refs.ids()]
                variables[f'{var_n}_expr'] = json.dumps(
                    {'text': v.text, 'refs': ids}
                )
            else:
                variables[f'{var_n}_expr'] = json.dumps(None)

        elif ftype is sr_struct.FieldType.EXPR_LIST:
            target_expr = f'''
                array_agg(json_array_unpack(${var_n})["text"])
            '''
            if v is not None:
                target_value = [
                    {
                        'text': ex.text,
                        'refs': (
                            [str(i) for i in ex.refs.ids()]
                            if ex.refs else []
                        )
                    }
                    for ex in v
                ]
            else:
                target_value = []

            shadow_target_expr = f'''
                (SELECT
                    array_agg(
                        sys::_expr_from_json(
                            json_array_unpack(${var_n})
                        )
                    )
                )
            '''

            assignments.append(f'{ns}__internal := {shadow_target_expr}')

        elif ftype is sr_struct.FieldType.EXPR_DICT:
            target_expr = f'''
                (
                    WITH
                        orig_json := json_array_unpack(${var_n})
                    SELECT
                        array_agg((
                            for orig_json in orig_json union (
                                name := orig_json['name'],
                                expr := orig_json['expr']['text'],
                            )
                        ))
                )
            '''
            if v is not None:
                target_value = [
                    {
                        'name': key,
                        'expr': {
                            'text': ex.text,
                            'refs': (
                                [str(i) for i in ex.refs.ids()]
                                if ex.refs else []
                            )
                        }
                    }
                    for key, ex in v.items()
                ]
            else:
                target_value = []

            shadow_target_expr = f'''
                (
                    WITH
                        orig_json := json_array_unpack(${var_n})
                    SELECT
                        array_agg((
                            for orig_json in orig_json union (
                                name := orig_json['name'],
                                expr := sys::_expr_from_json(
                                    orig_json['expr']
                                )
                            )
                        ))
                )
            '''

            assignments.append(f'{ns}__internal := {shadow_target_expr}')

        elif isinstance(target, s_types.Array):
            eltype = target.get_element_type(schema)
            target_expr = f'''
                array_agg(<{eltype.get_name(schema)}>
                    json_array_unpack(${var_n}))
                IF json_typeof(${var_n}) != 'null'
                ELSE <array<{eltype.get_name(schema)}>>{{}}
            '''
            if v is not None:
                target_value = list(v)
            else:
                target_value = None

        else:
            target_expr = f'${var_n}'
            if cardinality and cardinality.is_multi():
                target_expr = f'json_array_unpack({target_expr})'
            if target.is_enum(schema):
                target_expr =
f'{target_expr}' target_expr = f'<{target.get_name(schema)}>{target_expr}' if v is not None and cardinality.is_multi(): target_value = list(v) elif v is None or isinstance(v, numbers.Number): target_value = v else: target_value = str(v) if lprop_target is not None: assignments.append(f'@{ns} := {target_expr}') else: assignments.append(f'{ns} := {target_expr}') variables[var_n] = json.dumps(target_value) object_actually_exists = schema.has_object(cmd.scls.id) if ( isinstance(cmd, sd.CreateObject) and object_actually_exists and issubclass(mcls, (s_scalars.ScalarType, s_types.Collection)) and not issubclass(mcls, s_types.CollectionExprAlias) and not cmd.get_attribute_value('abstract') and not cmd.get_attribute_value('transient') and not cmd.has_attribute_value('backend_id') ): kind = f'"schema::{mcls.__name__}"' if issubclass(mcls, (s_types.Array, s_types.Range, s_types.MultiRange)): assignments.append( f'backend_id := sys::_get_pg_type_for_edgedb_type(' f'$__{var_prefix}id, ' f'{kind}, ' f'$__{var_prefix}element_type, ' f'$__{var_prefix}sql_type2), ' ) else: assignments.append( f'backend_id := sys::_get_pg_type_for_edgedb_type(' f'$__{var_prefix}id, {kind}, {{}}, ' f'$__{var_prefix}sql_type2), ' ) sql_type = None if isinstance(cmd.scls, s_scalars.ScalarType): sql_type, _ = cmd.scls.resolve_sql_type_scheme(schema) variables[f'__{var_prefix}id'] = json.dumps( str(cmd.get_attribute_value('id')) ) variables[f'__{var_prefix}sql_type2'] = json.dumps(sql_type) shape = ',\n'.join(assignments) return shape, variables def _reflect_object_dict_value( *, schema: s_schema.Schema, value: Optional[so.ObjectDict[str, so.Object]], is_ordered: bool, value_var_name: str, target: s_types.Type, reflection_proxy: Optional[tuple[str, str]], ) -> tuple[str, Any]: if reflection_proxy is not None: # Non-unique ObjectDict, reflecting via a proxy object proxy_type, proxy_link = reflection_proxy if is_ordered: target_expr = f'''( FOR v IN {{ enumerate( json_array_unpack(${value_var_name}) ) }} 
UNION ( INSERT {proxy_type} {{ {proxy_link} := ( SELECT (DETACHED {target.get_name(schema)}) FILTER .id = v.1[1] ), name := v.1[0], @index := v.0, }} ) )''' else: target_expr = f'''( FOR v IN {{ json_array_unpack(${value_var_name}) }} UNION ( INSERT {proxy_type} {{ {proxy_link} := ( SELECT (DETACHED {target.get_name(schema)}) FILTER .id = v[1] ), name := v[0], }} ) )''' else: if is_ordered: target_expr = f'''( FOR v IN {{ enumerate( json_array_unpack(${value_var_name}) ) }} UNION ( SELECT (DETACHED {target.get_name(schema)}) {{ name := v.1[0], @index := v.0, }} FILTER .id = v.1[1] ) )''' else: target_expr = f'''( FOR v IN {{ json_array_unpack(${value_var_name}) }} UNION ( SELECT (DETACHED {target.get_name(schema)}) {{ @key := v[0], }} FILTER .id = v[1] ) )''' if value is None: target_value = [] else: target_value = [(n, str(i.id)) for n, i in value.items(schema)] return target_expr, target_value # type ignore below because mypy's wishes of generic parametrization # clash with the expectations of singledispatch receiving an actual type. 
@write_meta.register def write_meta_create_object( cmd: sd.CreateObject, # type: ignore *, classlayout: dict[type[so.Object], sr_struct.SchemaTypeLayout], schema: s_schema.Schema, context: sd.CommandContext, blocks: list[tuple[str, dict[str, Any]]], internal_schema_mode: bool, stdmode: bool, ) -> None: _descend( cmd, classlayout=classlayout, schema=schema, context=context, blocks=blocks, prerequisites=True, internal_schema_mode=internal_schema_mode, stdmode=stdmode, ) mcls = cmd.maybe_get_schema_metaclass() if mcls is not None and not issubclass(mcls, so.GlobalObject): if isinstance(cmd, s_ref.ReferencedObjectCommand): refctx = cmd.get_referrer_context(context) else: refctx = None if refctx is None: shape, variables = _build_object_mutation_shape( cmd, classlayout=classlayout, internal_schema_mode=internal_schema_mode, stdmode=stdmode, schema=schema, context=context, ) insert_query = f''' INSERT schema::{mcls.__name__} {{ {shape} }} ''' blocks.append((insert_query, variables)) else: refop = refctx.op refcls = refop.get_schema_metaclass() refdict = refcls.get_refdict_for_class(mcls) layout = classlayout[refcls][refdict.attr] lprops = layout.properties reflect_as_link = ( mcls.get_reflection_method() is so.ReflectionMethod.AS_LINK ) shape, variables = _build_object_mutation_shape( cmd, classlayout=classlayout, lprop_fields=lprops, lprops_only=reflect_as_link, internal_schema_mode=internal_schema_mode, stdmode=stdmode, schema=schema, context=context, ) assignments = [] if reflect_as_link: target_link = mcls.get_reflection_link() assert target_link is not None target_field = mcls.get_field(target_link) target = cmd.get_attribute_value(target_link) append_query = f''' SELECT DETACHED schema::{target_field.type.__name__} {{ {shape} }} FILTER .name__internal = $__{target_link} ''' variables[f'__{target_link}'] = ( json.dumps(str(target.get_name(schema))) ) shadow_clslayout = classlayout[refcls] shadow_link_layout = ( shadow_clslayout[f'{refdict.attr}__internal']) 
shadow_shape, shadow_variables = _build_object_mutation_shape( cmd, classlayout=classlayout, internal_schema_mode=internal_schema_mode, lprop_fields=shadow_link_layout.properties, stdmode=stdmode, var_prefix='shadow_', schema=schema, context=context, ) variables.update(shadow_variables) shadow_append_query = f''' INSERT schema::{mcls.__name__} {{ {shadow_shape} }} ''' assignments.append(f''' {refdict.attr}__internal += ( {shadow_append_query} ) ''') else: append_query = f''' INSERT schema::{mcls.__name__} {{ {shape} }} ''' assignments.append(f''' {refdict.attr} += ( {append_query} ) ''') update_shape = ',\n'.join(assignments) parent_update_query = f''' UPDATE schema::{refcls.__name__} FILTER .name__internal = $__parent_classname SET {{ {update_shape} }} ''' ref_name = context.get_referrer_name(refctx) variables['__parent_classname'] = json.dumps(str(ref_name)) blocks.append((parent_update_query, variables)) _descend( cmd, classlayout=classlayout, schema=schema, context=context, blocks=blocks, internal_schema_mode=internal_schema_mode, stdmode=stdmode, ) @write_meta.register def write_meta_alter_object( cmd: sd.ObjectCommand, # type: ignore *, classlayout: dict[type[so.Object], sr_struct.SchemaTypeLayout], schema: s_schema.Schema, context: sd.CommandContext, blocks: list[tuple[str, dict[str, Any]]], internal_schema_mode: bool, stdmode: bool, ) -> None: _descend( cmd, classlayout=classlayout, schema=schema, context=context, blocks=blocks, prerequisites=True, internal_schema_mode=internal_schema_mode, stdmode=stdmode, ) mcls = cmd.maybe_get_schema_metaclass() if mcls is not None and not issubclass(mcls, so.GlobalObject): shape, variables = _build_object_mutation_shape( cmd, classlayout=classlayout, internal_schema_mode=internal_schema_mode, stdmode=stdmode, schema=schema, context=context, ) if shape: query = f''' UPDATE schema::{mcls.__name__} FILTER .name__internal = $__classname SET {{ {shape} }}; ''' variables['__classname'] = json.dumps(str(cmd.classname)) 
blocks.append((query, variables)) if isinstance(cmd, s_ref.ReferencedObjectCommand): refctx = cmd.get_referrer_context(context) if refctx is not None: _update_lprops( cmd, classlayout=classlayout, schema=schema, blocks=blocks, context=context, internal_schema_mode=internal_schema_mode, stdmode=stdmode, ) _descend( cmd, classlayout=classlayout, schema=schema, context=context, blocks=blocks, internal_schema_mode=internal_schema_mode, stdmode=stdmode, ) def _update_lprops( cmd: s_ref.ReferencedObjectCommand, # type: ignore *, classlayout: dict[type[so.Object], sr_struct.SchemaTypeLayout], schema: s_schema.Schema, blocks: list[tuple[str, dict[str, Any]]], context: sd.CommandContext, internal_schema_mode: bool, stdmode: bool, ) -> None: mcls = cmd.get_schema_metaclass() refctx = cmd.get_referrer_context_or_die(context) refop = refctx.op refcls = refop.get_schema_metaclass() refdict = refcls.get_refdict_for_class(mcls) layout = classlayout[refcls][refdict.attr] lprops = layout.properties if not lprops: return reflect_as_link = ( mcls.get_reflection_method() is so.ReflectionMethod.AS_LINK ) # N.B: For reflect_as_link AlterObjects, we depend on all of the # relevant fields having been populated in the command, which is # done by _populate_link_reflection_fields. 
    if reflect_as_link:
        target_link = mcls.get_reflection_link()
        assert target_link is not None
        target_field = mcls.get_field(target_link)
        target_obj = cmd.get_ddl_identity(target_link)
        if target_obj is None:
            raise AssertionError(
                f'cannot find link target in ddl_identity of a command for '
                f'schema class reflected as link: {cmd!r}'
            )
        target_clsname = target_field.type.__name__
    else:
        referrer_cls = refop.get_schema_metaclass()
        target_field = referrer_cls.get_field(refdict.attr)
        if issubclass(target_field.type, so.ObjectCollection):
            target_type = target_field.type.type
        else:
            target_type = target_field.type
        target_clsname = target_type.__name__
        target_link = refdict.attr
        target_obj = cmd.scls

    shape, append_variables = _build_object_mutation_shape(
        cmd,
        classlayout=classlayout,
        lprop_fields=lprops,
        lprops_only=True,
        internal_schema_mode=internal_schema_mode,
        stdmode=stdmode,
        schema=schema,
        context=context,
    )

    if shape:
        parent_variables = {}
        parent_variables[f'__{target_link}'] = json.dumps(str(target_obj.id))
        ref_name = context.get_referrer_name(refctx)
        parent_variables['__parent_classname'] = json.dumps(str(ref_name))

        # XXX: we have to do a -= followed by a += because
        # support for filtered nested link property updates
        # is currently broken.
        # This is fragile! If not all of the lprops are specified,
        # we will drop them.
assignments = [] assignments.append(textwrap.dedent( f'''\ {refdict.attr} -= ( SELECT DETACHED (schema::{target_clsname}) FILTER .id = $__{target_link} )''' )) if reflect_as_link: parent_variables[f'__{target_link}_shadow'] = ( json.dumps(str(cmd.classname))) assignments.append(textwrap.dedent( f'''\ {refdict.attr}__internal -= ( SELECT DETACHED (schema::{mcls.__name__}) FILTER .name__internal = $__{target_link}_shadow )''' )) update_shape = textwrap.indent( '\n' + ',\n'.join(assignments), ' ' * 4) parent_update_query = textwrap.dedent(f'''\ UPDATE schema::{refcls.__name__} FILTER .name__internal = $__parent_classname SET {{{update_shape} }} ''') blocks.append((parent_update_query, parent_variables)) assignments = [] shape = textwrap.indent(f'\n{shape}', ' ' * 5) assignments.append(textwrap.dedent( f'''\ {refdict.attr} += ( SELECT DETACHED schema::{target_clsname} {{{shape} }} FILTER .id = $__{target_link} )''' )) if reflect_as_link: shadow_clslayout = classlayout[refcls] shadow_link_layout = shadow_clslayout[f'{refdict.attr}__internal'] shadow_shape, shadow_variables = _build_object_mutation_shape( cmd, classlayout=classlayout, internal_schema_mode=internal_schema_mode, lprop_fields=shadow_link_layout.properties, lprops_only=True, stdmode=stdmode, var_prefix='shadow_', schema=schema, context=context, ) shadow_shape = textwrap.indent(f'\n{shadow_shape}', ' ' * 6) assignments.append(textwrap.dedent( f'''\ {refdict.attr}__internal += ( SELECT DETACHED schema::{mcls.__name__} {{{shadow_shape} }} FILTER .name__internal = $__{target_link}_shadow )''' )) parent_variables.update(shadow_variables) update_shape = textwrap.indent( '\n' + ',\n'.join(assignments), ' ' * 4) parent_update_query = textwrap.dedent(f''' UPDATE schema::{refcls.__name__} FILTER .name__internal = $__parent_classname SET {{{update_shape} }} ''') parent_variables.update(append_variables) blocks.append((parent_update_query, parent_variables)) @write_meta.register def write_meta_delete_object( cmd: 
sd.DeleteObject, # type: ignore *, classlayout: dict[type[so.Object], sr_struct.SchemaTypeLayout], schema: s_schema.Schema, context: sd.CommandContext, blocks: list[tuple[str, dict[str, Any]]], internal_schema_mode: bool, stdmode: bool, ) -> None: _descend( cmd, classlayout=classlayout, schema=schema, context=context, blocks=blocks, prerequisites=True, internal_schema_mode=internal_schema_mode, stdmode=stdmode, ) defer_filter = ( lambda cmd: isinstance(cmd, sd.DeleteObject) and cmd.if_unused ) _descend( cmd, classlayout=classlayout, schema=schema, context=context, blocks=blocks, internal_schema_mode=internal_schema_mode, stdmode=stdmode, cmd_filter=lambda cmd: not defer_filter(cmd), ) mcls = cmd.maybe_get_schema_metaclass() if mcls is not None and not issubclass(mcls, so.GlobalObject): if isinstance(cmd, s_ref.ReferencedObjectCommand): refctx = cmd.get_referrer_context(context) else: refctx = None if ( refctx is not None and mcls.get_reflection_method() is so.ReflectionMethod.AS_LINK ): refop = refctx.op refcls = refop.get_schema_metaclass() refdict = refcls.get_refdict_for_class(mcls) target_link = mcls.get_reflection_link() assert target_link is not None target_field = mcls.get_field(target_link) target = cmd.get_orig_attribute_value(target_link) parent_variables = {} # N.B: In some cases, like repair, where the delta came # directly from diffing, we might have an ObjectShell # instead of an Object, and so no .id. # In that case, just deal with it, and use name instead. # (XXX: We can't always use name because the target object might # be deleted in some cases?) # # An alternate approach would be to try to always force resolve # the fields in these cases, but the straightforward approaches # seemed like they'd hit more cases than we wanted. 
if isinstance(target, so.ObjectShell): parent_variables[f'__{target_link}'] = ( json.dumps(str(target.get_name(schema))) ) parent_update_query = f''' UPDATE schema::{refcls.__name__} FILTER .name__internal = $__parent_classname SET {{ {refdict.attr} -= ( SELECT DETACHED (schema::{target_field.type.__name__}) FILTER .name__internal = $__{target_link} ) }} ''' else: parent_variables[f'__{target_link}'] = ( json.dumps(str(target.id)) ) parent_update_query = f''' UPDATE schema::{refcls.__name__} FILTER .name__internal = $__parent_classname SET {{ {refdict.attr} -= ( SELECT DETACHED (schema::{target_field.type.__name__}) FILTER .id = $__{target_link} ) }} ''' ref_name = context.get_referrer_name(refctx) parent_variables['__parent_classname'] = ( json.dumps(str(ref_name)) ) blocks.append((parent_update_query, parent_variables)) # We need to delete any links created via reflection_proxy layout = classlayout[mcls] proxy_links = [ link for link, layout_entry in layout.items() if layout_entry.reflection_proxy ] to_delete = ['D'] + [f'D.{link}' for link in proxy_links] operations = [f'(DELETE {x})' for x in to_delete] query = f''' WITH D := (SELECT schema::{mcls.__name__} FILTER .name__internal = $__classname), SELECT {{{", ".join(operations)}}}; ''' variables = {'__classname': json.dumps(str(cmd.classname))} blocks.append((query, variables)) _descend( cmd, classlayout=classlayout, schema=schema, context=context, blocks=blocks, internal_schema_mode=internal_schema_mode, stdmode=stdmode, cmd_filter=defer_filter, ) @write_meta.register def write_meta_rename_object( cmd: sd.RenameObject, # type: ignore *, classlayout: dict[type[so.Object], sr_struct.SchemaTypeLayout], schema: s_schema.Schema, context: sd.CommandContext, blocks: list[tuple[str, dict[str, Any]]], internal_schema_mode: bool, stdmode: bool, ) -> None: # Delegate to the more general function, and then record the rename. 
write_meta_alter_object( cmd, classlayout=classlayout, schema=schema, context=context, blocks=blocks, internal_schema_mode=internal_schema_mode, stdmode=stdmode, ) context.early_renames[cmd.classname] = cmd.new_name @write_meta.register def write_meta_nop( cmd: sd.Nop, *, classlayout: dict[type[so.Object], sr_struct.SchemaTypeLayout], schema: s_schema.Schema, context: sd.CommandContext, blocks: list[tuple[str, dict[str, Any]]], internal_schema_mode: bool, stdmode: bool, ) -> None: pass @write_meta.register def write_meta_query( cmd: sd.Query, *, classlayout: dict[type[so.Object], sr_struct.SchemaTypeLayout], schema: s_schema.Schema, context: sd.CommandContext, blocks: list[tuple[str, dict[str, Any]]], internal_schema_mode: bool, stdmode: bool, ) -> None: pass ================================================ FILE: edb/schema/rewrites.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any, Optional, cast, TYPE_CHECKING from edb import errors from edb.edgeql import ast as qlast from edb.edgeql import compiler as qlcompiler from edb.edgeql import qltypes from . import annos as s_anno from . import delta as sd from . import expr as s_expr from . import name as sn from . import inheriting as s_inheriting from . import objects as so from . 
import referencing from . import schema as s_schema from . import types as s_types if TYPE_CHECKING: from . import pointers as s_pointers class Rewrite( referencing.NamedReferencedInheritingObject, so.InheritingObject, # Help reflection figure out the right db MRO s_anno.AnnotationSubject, qlkind=qltypes.SchemaObjectClass.REWRITE, data_safe=True, ): kind = so.SchemaField( qltypes.RewriteKind, coerce=True, compcoef=0.0, special_ddl_syntax=True, ) # 0.0 because we don't support ALTER yet expr = so.SchemaField( s_expr.Expression, compcoef=0.0, special_ddl_syntax=True, ) subject = so.SchemaField( so.InheritingObject, compcoef=None, inheritable=False ) def should_propagate(self, schema: s_schema.Schema) -> bool: # Rewrites should override rewrites on properties of an extended object # type. But overriding *objects* would be hard, so we just disable # inheritance for rewrites, and do lookups into parent object types # when retrieving them. return False def get_ptr_target(self, schema: s_schema.Schema) -> s_types.Type: pointer: s_pointers.Pointer = cast( 's_pointers.Pointer', self.get_subject(schema)) ptr_target = pointer.get_target(schema) assert ptr_target return ptr_target class RewriteCommandContext( sd.ObjectCommandContext[Rewrite], s_anno.AnnotationSubjectCommandContext, ): pass class RewriteSubjectCommandContext: pass class RewriteSubjectCommand( s_inheriting.InheritingObjectCommand[so.InheritingObjectT], ): pass class RewriteCommand( referencing.NamedReferencedInheritingObjectCommand[Rewrite], s_anno.AnnotationSubjectCommand[Rewrite], context_class=RewriteCommandContext, referrer_context_class=RewriteSubjectCommandContext, ): def canonicalize_attributes( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().canonicalize_attributes(schema, context) for field in ('expr',): if (expr := self.get_local_attribute_value(field)) is None: continue self.compile_expr_field( schema, context, field=Rewrite.get_field(field), 
value=expr, ) return schema def _get_kind( self, schema: s_schema.Schema, ) -> qltypes.RewriteKind: return self.get_attribute_value('kind') or self.scls.get_kind(schema) def compile_expr_field( self, schema: s_schema.Schema, context: sd.CommandContext, field: so.Field[Any], value: s_expr.Expression, track_schema_ref_exprs: bool = False, ) -> s_expr.CompiledExpression: if field.name == 'expr': from edb.common import ast from edb.ir import ast as irast from edb.ir import pathid from . import pointers as s_pointers from . import objtypes as s_objtypes from . import links as s_links parent_ctx = self.get_referrer_context_or_die(context) pointer = parent_ctx.op.scls assert isinstance(pointer, s_pointers.Pointer) source = pointer.get_source(schema) if isinstance(source, s_objtypes.ObjectType): subject = source elif isinstance(source, s_links.Link): subject = source.get_target(schema) assert subject span = self.get_attribute_span('expr') raise errors.SchemaDefinitionError( 'rewrites on link properties are not supported', span=span, ) else: raise NotImplementedError('unsupported rewrite source') # XXX: in_ddl_context_name is disabled for now because # it causes the compiler to reject DML; we might actually # want it for something, though, so we might need to # improve that restriction. 
            # parent_vname = source.get_verbosename(schema)
            # pol_name = self.get_verbosename(parent=parent_vname)
            # in_ddl_context_name = pol_name

            kind = self._get_kind(schema)

            anchors: dict[str, s_types.Type | pathid.PathId] = {}

            # __subject__
            anchors["__subject__"] = pathid.PathId.from_type(
                schema,
                subject,
                typename=sn.QualName(module="__derived__", name="__subject__"),
                env=None,
            )

            # __specified__
            bool_type = schema.get("std::bool", type=s_types.Type)
            schema, specified_type = s_types.Tuple.create(
                schema,
                named=True,
                element_types={
                    pn.name: bool_type
                    for pn in subject.get_pointers(schema).keys(schema)
                },
            )
            anchors['__specified__'] = specified_type

            # __old__
            if qltypes.RewriteKind.Update == kind:
                anchors['__old__'] = pathid.PathId.from_type(
                    schema,
                    subject,
                    typename=sn.QualName(module='__derived__', name='__old__'),
                    env=None,
                )

            singletons = frozenset(anchors.values())

            # If the `__specified__` anchor is used, create references to the
            # matching pointers.
            #
            # These references are necessary in order to compute the dependency
            # and ordering of Rewrite commands when producing DDL.
            #
            # Consider creating Type T with two properties, A and B, such that
            # A has a Rewrite containing `__specified__.B`.
            #
            # Without the references, the DDL may look like:
            # - Create Type T
            #   - Create Property A
            #     - Create Rewrite using __specified__.B
            #   - Create Property B
            #
            # This will cause an issue when compiling the Rewrite. At that
            # point, the schema will not know about B and so the tuple will not
            # have element `.B`.
            #
            # The reference will cause the reordering of commands and the DDL
            # may instead look like:
            # - Create Type T
            #   - Create Property A
            #   - Create Property B
            #   - Alter Property A
            #     - Create Rewrite using __specified__.B
            #
            # With Create Rewrite ordered after Property B, the tuple for
            # `__specified__` will correctly have element `.B`.
def find_extra_refs(ir_expr: irast.Set) -> set[so.Object]: def find_specified(node: irast.TupleIndirectionPointer) -> bool: return node.source.anchor == '__specified__' ref_ptr_names: set[str] = set() for tuple_node in ast.find_children( ir_expr, irast.TupleIndirectionPointer, test_func=find_specified, ): ref_ptr_names.add(tuple_node.ptrref.name.name) ref_ptrs: set[so.Object] = set( pointer for pointer in subject.get_pointers(schema).objects(schema) if pointer.get_shortname(schema).name in ref_ptr_names ) return ref_ptrs return type(value).compiled( value, schema=schema, options=qlcompiler.CompilerOptions( modaliases=context.modaliases, schema_object_context=self.get_schema_metaclass(), path_prefix_anchor="__subject__", anchors=anchors, singletons=singletons, apply_query_rewrites=not context.stdmode, track_schema_ref_exprs=track_schema_ref_exprs, # in_ddl_context_name=in_ddl_context_name, detached=True, ), find_extra_refs=find_extra_refs, context=context, ) else: return super().compile_expr_field( schema, context, field, value, track_schema_ref_exprs ) def get_dummy_expr_field_value( self, schema: s_schema.Schema, context: sd.CommandContext, field: so.Field[Any], value: Any, ) -> Optional[s_expr.Expression]: if field.name == 'expr': return s_types.type_dummy_expr( self.scls.get_ptr_target(schema), schema) else: raise NotImplementedError(f'unhandled field {field.name!r}') def validate_object( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: expr: s_expr.Expression = self.scls.get_expr(schema) if not expr.irast: expr = self.compile_expr_field( schema, context, Rewrite.get_field('expr'), expr ) assert expr.irast ir = expr.irast compiled_schema = ir.schema typ: s_types.Type = ir.stype if ( typ.is_view(compiled_schema) # Using an alias/global always creates a new subtype view, # but we want to allow those here, so check whether there # is a shape more directly. 
and not ( len(shape := ir.view_shapes.get(typ, [])) == 1 and shape[0].is_id_pointer(compiled_schema) ) ): span = self.get_attribute_span('expr') raise errors.SchemaDefinitionError( f'rewrite expression may not include a shape', span=span, ) ptr_target = self.scls.get_ptr_target(compiled_schema) if not typ.assignment_castable_to(ptr_target, compiled_schema): span = self.get_attribute_span('expr') raise errors.SchemaDefinitionError( f'rewrite expression is of invalid type: ' f'{typ.get_displayname(compiled_schema)}, ' f'expected {ptr_target.get_displayname(compiled_schema)}', span=span, ) @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: """ Converts a single `qlast.RewriteCommand` into multiple `schema.RewriteCommand`s, one for each kind. """ group = sd.CommandGroup() assert isinstance(astnode, qlast.RewriteCommand) for kind in astnode.kinds: # use kind for the name newnode = astnode.replace( name=qlast.ObjectRef(module='__', name=str(kind)), kinds=kind, ) cmd = super()._cmd_tree_from_ast(schema, newnode, context) assert isinstance(cmd, RewriteCommand) cmd.set_attribute_value('kind', kind) group.add(cmd) return group class CreateRewrite( RewriteCommand, referencing.CreateReferencedInheritingObject[Rewrite], ): referenced_astnode = astnode = qlast.CreateRewrite def get_ast_attr_for_field( self, field: str, astnode: type[qlast.DDLOperation], ) -> Optional[str]: if field in ('kind', 'expr') and issubclass( astnode, qlast.CreateRewrite ): return field else: return super().get_ast_attr_for_field(field, astnode) @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: group = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(group, sd.CommandGroup) assert isinstance(astnode, qlast.CreateRewrite) for cmd in group.ops: assert isinstance(cmd, CreateRewrite) cmd.set_attribute_value( 
'expr', s_expr.Expression.from_ast( astnode.expr, schema, context.modaliases, context.localnames, ), span=astnode.expr.span, ) return group def _apply_field_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, op: sd.AlterObjectProperty, ) -> None: if op.property == 'kind': assert isinstance(node, qlast.CreateRewrite) node.kinds = [self.get_attribute_value('kind')] else: super()._apply_field_ast(schema, context, node, op) class RebaseRewrite( RewriteCommand, referencing.RebaseReferencedInheritingObject[Rewrite], ): pass class RenameRewrite( RewriteCommand, referencing.RenameReferencedInheritingObject[Rewrite], ): pass class AlterRewrite( RewriteCommand, referencing.AlterReferencedInheritingObject[Rewrite], ): referenced_astnode = astnode = qlast.AlterRewrite def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._alter_begin(schema, context) # TODO: We may wish to support this in the future but it will # take some thought. if self.get_attribute_value( 'owned' ) and not self.get_orig_attribute_value('owned'): raise errors.SchemaDefinitionError( f'cannot alter the definition of inherited rewrite ' f'{self.scls.get_displayname(schema)}', span=self.span, ) return schema class DeleteRewrite( RewriteCommand, referencing.DeleteReferencedInheritingObject[Rewrite], ): referenced_astnode = astnode = qlast.DropRewrite def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: node = super()._get_ast(schema, context, parent_node=parent_node) assert isinstance(node, qlast.DropRewrite) skind = sn.shortname_from_fullname(self.classname).name node.kinds = [qltypes.RewriteKind(skind)] return node ================================================ FILE: edb/schema/roles.py ================================================ # # This source file is part of the EdgeDB open source project.
# # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Optional, overload, TYPE_CHECKING from edgedb import scram from edb import errors from edb.edgeql import ast as qlast from edb.edgeql import qltypes from edb.schema import defines as s_def from . import annos as s_anno from . import delta as sd from . import inheriting from . import name as sn from . import objects as so from . import utils if TYPE_CHECKING: from edb.schema import schema as s_schema class Role( so.GlobalObject, so.InheritingObject, s_anno.AnnotationSubject, qlkind=qltypes.SchemaObjectClass.ROLE, data_safe=True, ): superuser = so.SchemaField( bool, default=False, inheritable=False) password = so.SchemaField( str, default=None, allow_ddl_set=True, inheritable=False) password_hash = so.SchemaField( str, default=None, allow_ddl_set=True, ephemeral=True, inheritable=False) permissions = so.SchemaField( so.MultiPropSet[str], default=None, coerce=True, allow_ddl_set=True, obj_names_as_string=True, inheritable=False, ) branches = so.SchemaField( so.MultiPropSet[str], # default=so.MultiPropSet[str]('*'), # default=('*',), coerce=True, allow_ddl_set=True, inheritable=False, ) apply_access_policies_pg_default = so.SchemaField( bool, default=None, allow_ddl_set=True, inheritable=True, ) class RoleCommandContext( sd.ObjectCommandContext[Role], s_anno.AnnotationSubjectCommandContext): pass class RoleCommand( 
sd.GlobalObjectCommand[Role], inheriting.InheritingObjectCommand[Role], s_anno.AnnotationSubjectCommand[Role], context_class=RoleCommandContext, ): @classmethod def _process_role_body( cls, cmd: sd.Command, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> None: password = cmd.get_attribute_value('password') if password is not None: if cmd.get_attribute_value('password_hash') is not None: raise errors.EdgeQLSyntaxError( 'cannot specify both `password` and `password_hash` in' ' the same statement', span=astnode.span, ) salted_password = scram.build_verifier(password) cmd.set_attribute_value('password', salted_password) password_hash = cmd.get_attribute_value('password_hash') if password_hash is not None: try: scram.parse_verifier(password_hash) except ValueError as e: raise errors.InvalidValueError( e.args[0], span=astnode.span) cmd.set_attribute_value('password', password_hash) @classmethod def _classbases_from_ast( cls, schema: s_schema.Schema, astnode: qlast.ObjectDDL, context: sd.CommandContext, ) -> list[so.ObjectShell[Role]]: result = [] for b in getattr(astnode, 'bases', None) or []: result.append(utils.ast_objref_to_object_shell( b.maintype, metaclass=Role, schema=schema, modaliases=context.modaliases, )) return result def _validate_name( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: name = self.get_attribute_value('name') if len(str(name)) > s_def.MAX_NAME_LENGTH: span = self.get_attribute_span('name') raise errors.SchemaDefinitionError( f'Role names longer than {s_def.MAX_NAME_LENGTH} ' f'characters are not supported', span=span, ) def _validate_permissions( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: if ( self.has_attribute_value('permissions') and (permissions := self.get_attribute_value('permissions')) ): if 'sys::perm::superuser' in permissions: span = self.get_attribute_span('permissions') raise errors.SchemaDefinitionError( f'Permission "sys::perm::superuser" ' 
f'cannot be explicitly granted.', span=span, ) class CreateRole(RoleCommand, inheriting.CreateInheritingObject[Role]): astnode = qlast.CreateRole @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: assert isinstance(astnode, qlast.CreateRole) cmd = super()._cmd_tree_from_ast(schema, astnode, context) cmd.set_attribute_value('superuser', astnode.superuser) cls._process_role_body(cmd, schema, astnode, context) if not cmd.has_attribute_value('branches'): cmd.set_attribute_value('branches', frozenset(['*'])) return cmd def get_ast_attr_for_field( self, field: str, astnode: type[qlast.DDLOperation], ) -> Optional[str]: if ( field == 'superuser' and issubclass(astnode, qlast.CreateRole) ): return 'superuser' else: return super().get_ast_attr_for_field(field, astnode) def validate_create( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: super().validate_create(schema, context) self._validate_name(schema, context) self._validate_permissions(schema, context) class RebaseRole(RoleCommand, inheriting.RebaseInheritingObject[Role]): pass class RenameRole(RoleCommand, sd.RenameObject[Role]): pass class AlterRole(RoleCommand, inheriting.AlterInheritingObject[Role]): astnode = qlast.AlterRole @overload def get_object( self, schema: s_schema.Schema, context: sd.CommandContext, *, name: Optional[sn.Name] = None, default: Role | so.NoDefaultT = so.NoDefault, span: Optional[qlast.Span] = None, ) -> Role: ... @overload def get_object( self, schema: s_schema.Schema, context: sd.CommandContext, *, name: Optional[sn.Name] = None, default: None = None, span: Optional[qlast.Span] = None, ) -> Optional[Role]: ... 
def get_object( self, schema: s_schema.Schema, context: sd.CommandContext, *, name: Optional[sn.Name] = None, default: Role | so.NoDefaultT | None = so.NoDefault, span: Optional[qlast.Span] = None, ) -> Optional[Role]: # On an ALTER ROLE edgedb, if 'edgedb' doesn't exist, fall # back to 'admin'. This mirrors what we do for login and # avoids breaking setup scripts. if name is None and str(self.classname) == 'edgedb': try: return super().get_object( schema, context, span=span, ) except errors.InvalidReferenceError: name = sn.UnqualName('admin') return super().get_object( schema, context, name=name, default=default, span=span, ) @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: cmd = super()._cmd_tree_from_ast(schema, astnode, context) cls._process_role_body(cmd, schema, astnode, context) return cmd def validate_alter( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: super().validate_alter(schema, context) self._validate_name(schema, context) self._validate_permissions(schema, context) class DeleteRole(RoleCommand, inheriting.DeleteInheritingObject[Role]): astnode = qlast.DropRole def _validate_legal_command( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: super()._validate_legal_command(schema, context) if self.classname.name == s_def.EDGEDB_SUPERUSER: raise errors.ExecutionError( f"role {self.classname.name!r} cannot be dropped" ) ================================================ FILE: edb/schema/scalars.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Optional, Iterable, Sequence, cast from edb import errors from edb.common import checked from edb.edgeql import ast as qlast from edb.edgeql import qltypes from edb.common.typeutils import downcast from . import annos as s_anno from . import casts as s_casts from . import constraints from . import delta as sd from . import expr as s_expr from . import inheriting from . import name as s_name from . import objects as so from . import schema as s_schema from . import types as s_types from . import utils as s_utils class ScalarType( s_types.InheritingType, constraints.ConsistencySubject, qlkind=qltypes.SchemaObjectClass.SCALAR_TYPE, data_safe=True, ): default = so.SchemaField( s_expr.Expression, default=None, coerce=True, compcoef=0.909, ) enum_values = so.SchemaField( checked.FrozenCheckedList[str], default=None, coerce=True, compcoef=0.8, ) sql_type = so.SchemaField( str, default=None, inheritable=False, compcoef=0.0) # A type scheme for supporting type mods in scalar types. # If present, describes what the sql_type of children scalars # should be, such as 'varchar({__arg_0__})'. sql_type_scheme = so.SchemaField( str, default=None, inheritable=False, compcoef=0.0) # The number of parameters that the type takes. Currently all parameters # must be integer literals. # This is an internal API and might change. num_params = so.SchemaField( int, default=None, inheritable=False, compcoef=0.0, ) # Arguments to fill in a parent type's parameterized type scheme. 
arg_values = so.SchemaField( checked.FrozenCheckedList[str], default=None, inheritable=False, coerce=True, compcoef=0.0, ) custom_sql_serialization = so.SchemaField( str, default=None, inheritable=False, compcoef=0.0) def is_scalar(self) -> bool: return True def is_concrete_enum(self, schema: s_schema.Schema) -> bool: return any( str(base.get_name(schema)) == 'std::anyenum' for base in self.get_bases(schema).objects(schema) ) def is_base_type( self, schema: s_schema.Schema, ) -> bool: """Returns true if the type has only abstract bases""" bases: Sequence[s_types.Type] = self.get_bases(schema).objects(schema) return all(b.get_abstract(schema) for b in bases) def is_enum(self, schema: s_schema.Schema) -> bool: return bool(self.get_enum_values(schema)) def is_sequence(self, schema: s_schema.Schema) -> bool: seq = schema.get('std::sequence', type=ScalarType) return self.issubclass(schema, seq) def is_polymorphic(self, schema: s_schema.Schema) -> bool: return self.get_abstract(schema) def is_json(self, schema: s_schema.Schema) -> bool: return self.issubclass( schema, schema.get(s_name.QualName('std', 'json'), type=ScalarType), ) def can_accept_constraints(self, schema: s_schema.Schema) -> bool: return not self.is_enum(schema) def _resolve_polymorphic( self, schema: s_schema.Schema, concrete_type: s_types.Type, ) -> Optional[s_types.Type]: if (self.is_polymorphic(schema) and concrete_type.is_scalar() and not concrete_type.is_polymorphic(schema)): return concrete_type return None def _to_nonpolymorphic( self, schema: s_schema.Schema, concrete_type: s_types.Type, ) -> tuple[s_schema.Schema, s_types.Type]: if (not concrete_type.is_polymorphic(schema) and concrete_type.issubclass(schema, self)): return schema, concrete_type raise TypeError( f'cannot interpret {concrete_type.get_name(schema)} ' f'as {self.get_name(schema)}') def _test_polymorphic( self, schema: s_schema.Schema, other: s_types.Type, ) -> bool: if other.is_any(schema): return True else: return
self.issubclass(schema, other) def assignment_castable_to( self, other: s_types.Type, schema: s_schema.Schema, ) -> bool: if not isinstance(other, ScalarType): return False if self.is_polymorphic(schema) or other.is_polymorphic(schema): return False left = self.get_base_for_cast(schema) right = other.get_base_for_cast(schema) assert isinstance(left, s_types.Type) assert isinstance(right, s_types.Type) return s_casts.is_assignment_castable(schema, left, right) def implicitly_castable_to( self, other: s_types.Type, schema: s_schema.Schema, ) -> bool: if not isinstance(other, ScalarType): return False if self.is_polymorphic(schema) or other.is_polymorphic(schema): return False left = self.get_topmost_concrete_base(schema) right = other.get_topmost_concrete_base(schema) assert isinstance(left, s_types.Type) assert isinstance(right, s_types.Type) return s_casts.is_implicitly_castable(schema, left, right) def castable_to( self, other: s_types.Type, schema: s_schema.Schema, ) -> bool: """Determine if any cast exists between self and *other*.""" if not isinstance(other, ScalarType): return False if self.is_polymorphic(schema) or other.is_polymorphic(schema): return False left = self.get_topmost_concrete_base(schema) right = other.get_topmost_concrete_base(schema) assert isinstance(left, s_types.Type) assert isinstance(right, s_types.Type) return s_casts.is_castable(schema, left, right) def get_implicit_cast_distance( self, other: s_types.Type, schema: s_schema.Schema, ) -> int: if not isinstance(other, ScalarType): return -1 if self.is_polymorphic(schema) or other.is_polymorphic(schema): return -1 left = self.get_topmost_concrete_base(schema) right = other.get_topmost_concrete_base(schema) return s_casts.get_implicit_cast_distance(schema, left, right) def find_common_implicitly_castable_type( self, other: s_types.Type, schema: s_schema.Schema, ) -> tuple[s_schema.Schema, Optional[ScalarType]]: if not isinstance(other, ScalarType): return schema, None if 
self.is_polymorphic(schema) and other.is_polymorphic(schema): return schema, self left = self.get_topmost_concrete_base(schema) right = other.get_topmost_concrete_base(schema) if left == right: return schema, left else: return ( schema, cast( Optional[ScalarType], s_casts.find_common_castable_type(schema, left, right), ) ) def get_base_for_cast(self, schema: s_schema.Schema) -> so.Object: if self.is_enum(schema): # all enums have to use std::anyenum as base type for casts return schema.get('std::anyenum') else: return super().get_base_for_cast(schema) def get_verbosename( self, schema: s_schema.Schema, *, with_parent: bool = False ) -> str: if self.is_enum(schema): clsname = 'enumerated type' else: clsname = self.get_schema_class_displayname() dname = self.get_displayname(schema) return f"{clsname} '{dname}'" def resolve_sql_type_scheme( self, schema: s_schema.Schema, ) -> tuple[Optional[str], Optional[str]]: if sql := self.get_sql_type(schema): return sql, None if self.get_arg_values(schema) is None: return None, None bases = self.get_bases(schema).objects(schema) if len(bases) != 1: return None, None if scheme := bases[0].get_sql_type_scheme(schema): base_sql_type = bases[0].get_sql_type(schema) assert base_sql_type is not None return base_sql_type, scheme return None, None def resolve_sql_type( self, schema: s_schema.Schema, ) -> Optional[str]: type, scheme = self.resolve_sql_type_scheme(schema) if scheme: return constraints.interpolate_error_text( scheme, { f'__arg_{i}__': v for i, v in enumerate(self.get_arg_values(schema) or ()) }, ) else: return type def as_alter_delta( self, other: ScalarType, *, self_schema: s_schema.Schema, other_schema: s_schema.Schema, confidence: float, context: so.ComparisonContext, ) -> sd.ObjectCommand[ScalarType]: alter = super().as_alter_delta( other, self_schema=self_schema, other_schema=other_schema, confidence=confidence, context=context, ) # If this is an enum and enum_values changed, we need to # generate a rebase. 
old_enum_values = self.get_enum_values(self_schema) enum_values = alter.get_local_attribute_value('enum_values') if old_enum_values and enum_values: assert isinstance(alter.classname, s_name.QualName) rebase = RebaseScalarType( classname=alter.classname, removed_bases=(), added_bases=( ([AnonymousEnumTypeShell(elements=enum_values)], ''), ), ) alter.add(rebase) # Changing enum_values is the responsibility of the rebase command. # Either it's in the one we synthesized above, or the rebase is doomed # to throw. When we run the DDL directly, the ALTER will not have # enum_values set, so discard here for symmetry. alter.discard_attribute('enum_values') return alter class AnonymousEnumTypeShell(s_types.TypeShell[ScalarType]): elements: Sequence[str] def __init__( self, *, name: Optional[s_name.Name] = None, elements: Iterable[str], ) -> None: name = name or s_name.QualName(module='std', name='anyenum') super().__init__(name=name, schemaclass=ScalarType) self.elements = list(elements) def resolve(self, schema: s_schema.Schema) -> ScalarType: raise errors.InvalidPropertyDefinitionError( 'this type cannot be anonymous', details=( 'you may want define this enum first:\n\n' ' scalar type MyEnum extending enum<...>;' ), ) class ScalarTypeCommandContext(sd.ObjectCommandContext[ScalarType], s_anno.AnnotationSubjectCommandContext, constraints.ConsistencySubjectCommandContext): pass class ScalarTypeCommand( s_types.InheritingTypeCommand[ScalarType], constraints.ConsistencySubjectCommand[ScalarType], s_anno.AnnotationSubjectCommand[ScalarType], context_class=ScalarTypeCommandContext, ): def validate_object( self, schema: s_schema.Schema, context: sd.CommandContext ) -> None: if ( self.scls.resolve_sql_type_scheme(schema)[0] ): if len(self.scls.get_constraints(schema)): raise errors.SchemaError( f'parameterized scalar types may not have constraints', span=self.span, ) if args := self.scls.get_arg_values(schema): base = self.scls.get_bases(schema).objects(schema)[0] num_params =
base.get_num_params(schema) if not num_params: raise errors.SchemaDefinitionError( f'base type {base.get_name(schema)} does not ' f'accept parameters', span=self.span, ) if num_params != len(args): raise errors.SchemaDefinitionError( f'incorrect number of arguments provided to base type ' f'{base.get_name(schema)}: expected {num_params} ' f'but got {len(args)}', span=self.span, ) def validate_scalar_ancestors( self, ancestors: Sequence[so.SubclassableObject], schema: s_schema.Schema, context: sd.CommandContext, ) -> None: real_concrete_ancestors = { ancestor for ancestor in ancestors if not ancestor.get_abstract(schema) } # Filter out anything that has a subclass relation with # every other concrete ancestor. This lets us allow chains # of concrete scalar types while prohibiting diamonds (for # example if X <: A, B <: int64 where A, B are concrete). # (If we wanted to allow diamonds, we could instead filter out # anything that has concrete bases.) concrete_ancestors = { c1 for c1 in real_concrete_ancestors if not all(c1 == c2 or c1.issubclass(schema, c2) or c2.issubclass(schema, c1) for c2 in real_concrete_ancestors) } if len(concrete_ancestors) > 1: raise errors.SchemaError( f'scalar type may not have more than ' f'one concrete base type', span=self.span, ) abstract = self.get_attribute_value('abstract') enum = self.get_attribute_value('enum_values') if ( len(real_concrete_ancestors) < 1 and not context.stdmode and not abstract and not enum and not self.get_attribute_value('sql_type') ): if not ancestors: hint = ( f'\nFor example: scalar type {self.classname.name} ' f'extending str' ) else: hint = 'Bases were specified but no concrete bases were found' raise errors.SchemaError( f'scalar type must have a concrete base type', span=self.span, hint=hint, ) def validate_scalar_bases( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: bases = self.get_resolved_attribute_value( 'bases', schema=schema, context=context) if bases is not None: ancestors = 
[] for base in bases.objects(schema): ancestors.append(base) ancestors.extend(base.get_ancestors(schema).objects(schema)) self.validate_scalar_ancestors(ancestors, schema, context) class CreateScalarType( ScalarTypeCommand, s_types.CreateInheritingType[ScalarType], ): astnode = qlast.CreateScalarType @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: cmd = super()._cmd_tree_from_ast( schema, astnode.replace(bases=[]), context) if isinstance(cmd, sd.CommandGroup): for subcmd in cmd.get_subcommands(): if isinstance(subcmd, cls): create_cmd: sd.Command = subcmd break else: raise errors.InternalServerError( 'scalar alias definition did not return CreateScalarType' ) else: create_cmd = cmd if isinstance(astnode, qlast.CreateScalarType): bases = [ s_utils.ast_to_type_shell( b, metaclass=ScalarType, modaliases=context.modaliases, schema=schema, allow_generalized_bases=True, ) for b in (astnode.bases or []) ] is_enum = any( isinstance(br, AnonymousEnumTypeShell) for br in bases) for ab, b in zip(astnode.bases, bases): if isinstance(b, s_types.CollectionTypeShell): raise errors.SchemaError( f'scalar type may not have a collection base type', span=ab.span, ) # We don't support FINAL, but old dumps and migrations specify # it on enum CREATE SCALAR TYPEs, so we need to permit it in those # cases. if not is_enum and astnode.final: raise errors.UnsupportedFeatureError( f'FINAL is not supported', span=astnode.span, ) if is_enum: # This is an enumerated type. 
if len(bases) > 1: assert isinstance(astnode, qlast.BasedOn) raise errors.SchemaError( f'invalid scalar type definition, enumeration must be' f' the only supertype specified', span=astnode.bases[0].span, ) if create_cmd.has_attribute_value('default'): raise errors.UnsupportedFeatureError( f'enumerated types do not support defaults', span=( create_cmd.get_attribute_span('default') ), ) shell = bases[0] assert isinstance(shell, AnonymousEnumTypeShell) if len(set(shell.elements)) != len(shell.elements): raise errors.SchemaDefinitionError( f'enums cannot contain duplicate values', span=astnode.bases[0].span, ) create_cmd.set_attribute_value('enum_values', shell.elements) create_cmd.set_attribute_value( 'bases', so.ObjectCollectionShell( [s_utils.ast_objref_to_object_shell( s_utils.name_to_ast_ref( s_name.QualName('std', 'anyenum'), ), schema=schema, metaclass=ScalarType, modaliases={}, )], collection_type=so.ObjectList, ) ) else: if any(b.extra_args for b in bases): if len(bases) > 1: raise errors.SchemaDefinitionError( 'scalars with parameterized bases may ' 'only have one', span=astnode.bases[0].span, ) base = bases[0] args = [] for x in (base.extra_args or ()): if ( not isinstance(x, qlast.TypeExprLiteral) or not isinstance(x.val, qlast.Constant) or x.val.kind != qlast.ConstantKind.INTEGER ): raise errors.SchemaDefinitionError( 'invalid scalar type argument', span=x.span, ) args.append(x.val.value) cmd.set_attribute_value('arg_values', args) cmd.set_attribute_value( 'bases', so.ObjectCollectionShell( bases, collection_type=so.ObjectList ), ) return cmd def _create_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._create_begin(schema, context) if ( not context.canonical and not self.scls.get_abstract(schema) and not self.scls.get_transient(schema) ): # Create an array type for this scalar eagerly. 
# We mostly do this so that we know the `backend_id` # of the array type when running translation of SQL # involving arrays of scalars. schema2, arr_t = s_types.Array.from_subtypes(schema, [self.scls]) self.add_caused(arr_t.as_shell(schema2).as_create_delta(schema2)) return schema def validate_create( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: super().validate_create(schema, context) self.validate_scalar_bases(schema, context) def _get_ast_node( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> type[qlast.DDLOperation]: if self.get_attribute_value('expr'): return qlast.CreateAlias else: return super()._get_ast_node(schema, context) def _apply_field_ast( self, schema: s_schema.Schema, context: sd.CommandContext, node: qlast.DDLOperation, op: sd.AlterObjectProperty, ) -> None: if op.property == 'default': if op.new_value: assert isinstance(op.new_value, list) op.new_value = op.new_value[0] super()._apply_field_ast(schema, context, node, op) elif op.property == 'bases': enum_values = self.get_local_attribute_value('enum_values') if enum_values: assert isinstance(node, qlast.BasedOn) node.bases = [ qlast.TypeName( maintype=qlast.ObjectRef(name='enum'), subtypes=[ qlast.TypeName(maintype=qlast.ObjectRef(name=v)) for v in enum_values ] ) ] else: super()._apply_field_ast(schema, context, node, op) if arg_values := self.get_local_attribute_value('arg_values'): frags = [ s_expr.Expression(text=x).parse() for x in arg_values] assert isinstance(node, qlast.BasedOn) node.bases[0].subtypes = [ qlast.TypeExprLiteral( val=downcast(qlast.Constant, frag) ) for frag in frags ] else: super()._apply_field_ast(schema, context, node, op) class RenameScalarType( ScalarTypeCommand, s_types.RenameInheritingType[ScalarType], ): pass class RebaseScalarType( ScalarTypeCommand, inheriting.RebaseInheritingObject[ScalarType], ): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: scls = self.get_object(schema, 
context) self.scls = scls assert isinstance(scls, ScalarType) if self.scls.is_concrete_enum(schema): if self.removed_bases and not self.added_bases: raise errors.SchemaError( f'cannot DROP EXTENDING enum') if self.added_bases: first_bases = self.added_bases[0] new_bases, pos = first_bases if len(self.added_bases) > 1 or len(new_bases) > 1: dn = self.scls.get_displayname(schema) raise errors.SchemaError( f'enum {dn} may not have multiple supertypes') new_base = new_bases[0] if isinstance(new_base, AnonymousEnumTypeShell): new_name = _prettyprint_enum(new_base.elements) else: if isinstance(new_base, so.ObjectShell): new_base = new_base.resolve(schema) assert isinstance(new_base, s_types.Type) new_name = new_base.get_verbosename(schema) if self.removed_bases and not scls.is_view(schema): # enum to enum rebases come without removed_bases assert not new_base.is_enum(schema) raise errors.SchemaError( f'cannot change the base of enum type ' f'{scls.get_displayname(schema)} to {new_name}') if pos: raise errors.SchemaError( f'cannot add supertype {new_name} ' f'to enum type {scls.get_displayname(schema)}') assert isinstance(new_base, AnonymousEnumTypeShell) schema = self._validate_enum_change( scls, new_base.elements, schema) schema = super().apply(schema, context) self.validate_scalar_bases(schema, context) else: old_concrete = self.scls.maybe_get_topmost_concrete_base(schema) for b in [b for bs, _ in self.added_bases for b in bs]: if isinstance(b, s_types.CollectionTypeShell): raise errors.SchemaError( f'scalar type may not have a collection base type', span=self.span, ) schema = super().apply(schema, context) self.validate_scalar_bases(schema, context) new_concrete = self.scls.maybe_get_topmost_concrete_base(schema) if old_concrete != new_concrete and not scls.is_view(schema): old_name = (old_concrete.get_displayname(schema) if old_concrete else 'None') if self.scls.is_concrete_enum(schema): values = self.scls.get_enum_values(schema) assert values is not None new_name = 
_prettyprint_enum(values) elif new_concrete: new_name = new_concrete.get_displayname(schema) else: new_name = 'None' raise errors.SchemaError( f'cannot change concrete base of scalar type ' f'{scls.get_displayname(schema)} from ' f'{old_name} to {new_name}') return schema def validate_scalar_bases( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: super().validate_scalar_bases(schema, context) bases = self.get_resolved_attribute_value( 'bases', schema=schema, context=context) if bases: obj = self.scls # For each descendant, compute its new ancestors and check # that they are valid for a scalar type. new_schema = obj.set_field_value(schema, 'bases', bases) for desc in obj.descendants(schema): ancestors = so.compute_ancestors(new_schema, desc) self.validate_scalar_ancestors(ancestors, schema, context) def _validate_enum_change( self, stype: s_types.Type, new_labels: Sequence[str], schema: s_schema.Schema, ) -> s_schema.Schema: new_set = set(new_labels) if len(new_set) != len(new_labels): raise errors.SchemaError( f'enums cannot contain duplicate values') self.set_attribute_value('enum_values', new_labels) schema = stype.set_field_value(schema, 'enum_values', new_labels) return schema def _prettyprint_enum(elements: Iterable[str]) -> str: return f"enum<{', '.join(elements)}>" class AlterScalarType( ScalarTypeCommand, s_types.AlterType[ScalarType], inheriting.AlterInheritingObject[ScalarType], ): astnode = qlast.AlterScalarType class DeleteScalarType( ScalarTypeCommand, s_types.DeleteInheritingType[ScalarType], ): astnode = qlast.DropScalarType def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: if self.get_orig_attribute_value('expr_type'): # This is an alias type, appropriate DDL would be generated # from the corresponding DeleteAlias node. 
return None else: return super()._get_ast(schema, context, parent_node=parent_node) def _delete_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: if not context.canonical: schema2, arr_typ = s_types.Array.from_subtypes(schema, [self.scls]) arr_op = arr_typ.init_delta_command( schema2, sd.DeleteObject, if_exists=True, ) self.add_prerequisite(arr_op) return super()._delete_begin(schema, context) ================================================ FILE: edb/schema/schema.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Callable, Iterable, Iterator, Mapping, NoReturn, Optional, overload, Self, TYPE_CHECKING, ) import abc import collections import itertools import immutables as immu from edb import errors from edb.common import adapter from edb.common import english from edb.common import lru from . import casts as s_casts from . import functions as s_func from . import migrations as s_migrations from . import modules as s_mod from . import name as sn from . import objects as so from . import operators as s_oper from . import pseudo as s_pseudo from . 
import types as s_types if TYPE_CHECKING: import uuid from edb.common import parsing Refs_T = immu.Map[ uuid.UUID, immu.Map[ tuple[type[so.Object], str], immu.Map[uuid.UUID, None], ], ] EXT_MODULE = sn.UnqualName('ext') STD_MODULES = ( sn.UnqualName('std'), sn.UnqualName('schema'), sn.UnqualName('std::math'), sn.UnqualName('sys'), sn.UnqualName('sys::perm'), sn.UnqualName('cfg'), sn.UnqualName('cfg::perm'), sn.UnqualName('std::cal'), sn.UnqualName('std::net'), sn.UnqualName('std::net::http'), sn.UnqualName('std::net::perm'), sn.UnqualName('std::pg'), sn.UnqualName('std::_test'), sn.UnqualName('std::fts'), sn.UnqualName('std::lang'), sn.UnqualName('std::lang::go'), sn.UnqualName('std::lang::js'), sn.UnqualName('std::lang::py'), sn.UnqualName('std::lang::rs'), EXT_MODULE, sn.UnqualName('std::enc'), ) SPECIAL_MODULES = ( sn.UnqualName('__derived__'), sn.UnqualName('__ext_casts__'), sn.UnqualName('__ext_index_matches__'), ) # Specifies the order of processing of files and directories in lib/ STD_SOURCES = ( sn.UnqualName('std'), sn.UnqualName('schema'), sn.UnqualName('math'), sn.UnqualName('sys'), sn.UnqualName('cfg'), sn.UnqualName('cal'), sn.UnqualName('ext'), sn.UnqualName('enc'), sn.UnqualName('pg'), sn.UnqualName('fts'), sn.UnqualName('net'), ) TESTMODE_SOURCES = ( sn.UnqualName('_testmode'), ) # Deep optimization: avoid lookups into so.Object _raw_schema_restore = so.Object.raw_schema_restore class Schema(abc.ABC): ''' Data store for objects and their data. Objects have: - a class (also called mcls for meta class or scls for schema class), - an id (of type UUID), - data (a tuple of python values), - name (of type sn.Name). Objects can be retrieved by: - id, - name (fully qualified), - shortname (only for function and operator class), - references. 
''' @abc.abstractmethod def _get_by_name( self, name: sn.Name, ) -> Optional[so.Object]: raise NotImplementedError @abc.abstractmethod def _get_by_shortname[T: s_func.Function | s_oper.Operator]( self, mcls: type[T], shortname: sn.Name, ) -> Optional[tuple[T, ...]]: raise NotImplementedError @abc.abstractmethod def _get_by_globalname[T: so.Object]( self, mcls: type[T], name: sn.Name, ) -> Optional[T]: raise NotImplementedError @abc.abstractmethod def add( self, id: uuid.UUID, sclass: type[so.Object], data: tuple[Any, ...], ) -> Self: raise NotImplementedError def discard(self: Self, obj: so.Object) -> Self: if self.has_object(obj.id): return self.delete(obj) else: return self @abc.abstractmethod def delete(self: Self, obj: so.Object) -> Self: raise NotImplementedError @abc.abstractmethod def delist(self: Self, name: sn.Name) -> Self: raise NotImplementedError @abc.abstractmethod def update_obj( self: Self, obj: so.Object, updates: Mapping[str, Any], ) -> Self: raise NotImplementedError @abc.abstractmethod def get_data_raw( self, obj: so.Object, ) -> Optional[tuple[Any, ...]]: raise NotImplementedError @abc.abstractmethod def get_field_raw( self, obj: so.Object, field_index: int, ) -> Optional[Any]: raise NotImplementedError @abc.abstractmethod def set_field( self: Self, obj: so.Object, field: str, value: Any, ) -> Self: raise NotImplementedError @abc.abstractmethod def unset_field( self: Self, obj: so.Object, field: str, ) -> Self: raise NotImplementedError @abc.abstractmethod def has_object(self, object_id: uuid.UUID) -> bool: raise NotImplementedError def has_module(self, name: str) -> bool: return self.get_global(s_mod.Module, name, None) is not None def has_migration(self, name: str) -> bool: return self.get_global(s_migrations.Migration, name, None) is not None @overload def get_by_id( self, obj_id: uuid.UUID, default: so.Object | so.NoDefaultT = so.NoDefault, *, type: None = None, ) -> so.Object: ... 
@overload def get_by_id( self, obj_id: uuid.UUID, default: so.Object_T | so.NoDefaultT = so.NoDefault, *, type: Optional[type[so.Object_T]] = None, ) -> so.Object_T: ... @overload def get_by_id( self, obj_id: uuid.UUID, default: None = None, *, type: Optional[type[so.Object_T]] = None, ) -> Optional[so.Object_T]: ... def get_by_id( self, obj_id: uuid.UUID, default: so.Object_T | so.NoDefaultT | None = so.NoDefault, *, type: Optional[type[so.Object_T]] = None, ) -> Optional[so.Object_T]: return self._get_by_id(obj_id, default, type=type) @abc.abstractmethod def _get_by_id( self, obj_id: uuid.UUID, default: so.Object_T | so.NoDefaultT | None = so.NoDefault, *, type: Optional[type[so.Object_T]] = None, ) -> Optional[so.Object_T]: raise NotImplementedError @overload def get_by_name[T: so.Object]( self, name: sn.Name | str, default: T | so.NoDefaultT = so.NoDefault, type: Optional[type[T]] = None, span: Optional[parsing.Span] = None ) -> T: ... @overload def get_by_name[T: so.Object]( self, name: sn.Name | str, default: None = None, type: Optional[type[T]] = None, span: Optional[parsing.Span] = None ) -> Optional[T]: ... 
def get_by_name[T: so.Object]( self, name: sn.Name | str, default: T | so.NoDefaultT | None = so.NoDefault, type: Optional[type[T]] = None, span: Optional[parsing.Span] = None ) -> Optional[T]: """Retrieve object by name (not global name or short name)""" if isinstance(name, str): name = sn.QualName.from_string(name) obj = self._get_by_name(name) if obj is not None: if type is not None: if not isinstance(obj, type): Schema.raise_wrong_type(name, obj.__class__, type, span) return obj # type: ignore elif default is not so.NoDefault: return default else: Schema.raise_bad_reference(name, type=type) def get_by_shortname[T: s_func.Function | s_oper.Operator]( self, mcls: type[T], shortname: str | sn.Name, span: Optional[parsing.Span] = None ) -> tuple[T, ...]: """Retrieve object by shortname""" if isinstance(shortname, str): shortname = sn.QualName.from_string(shortname) objs = self._get_by_shortname(mcls, shortname) if objs is not None: return objs else: Schema.raise_bad_reference(shortname, type=mcls) # TODO: rename to get_by_globalname @overload def get_global[T: so.Object]( self, mcls: type[T], name: str | sn.Name, default: T | so.NoDefaultT = so.NoDefault, ) -> T: ... # TODO: rename to get_by_globalname @overload def get_global[T: so.Object]( self, mcls: type[T], name: str | sn.Name, default: None = None, ) -> Optional[T]: ... 
# TODO: rename to get_by_globalname def get_global[T: so.Object]( self, mcls: type[T], name: str | sn.Name, default: T | so.NoDefaultT | None = so.NoDefault, ) -> Optional[T]: if isinstance(name, str): name = sn.UnqualName(name) obj = self._get_by_globalname(mcls, name) if obj is not None: return obj elif default is not so.NoDefault: return default else: Schema.raise_bad_reference(name, type=mcls) @overload def get( self, name: str | sn.Name, default: so.Object | so.NoDefaultT = so.NoDefault, *, module_aliases: Optional[Mapping[Optional[str], str]] = None, condition: Optional[Callable[[so.Object], bool]] = None, label: Optional[str] = None, span: Optional[parsing.Span] = None, ) -> so.Object: ... @overload def get( self, name: str | sn.Name, default: None, *, module_aliases: Optional[Mapping[Optional[str], str]] = None, condition: Optional[Callable[[so.Object], bool]] = None, label: Optional[str] = None, span: Optional[parsing.Span] = None, ) -> Optional[so.Object]: ... @overload def get[T: so.Object]( self, name: str | sn.Name, default: T | so.NoDefaultT = so.NoDefault, *, module_aliases: Optional[Mapping[Optional[str], str]] = None, type: type[T], condition: Optional[Callable[[so.Object], bool]] = None, label: Optional[str] = None, span: Optional[parsing.Span] = None, ) -> T: ... @overload def get[T: so.Object]( self, name: str | sn.Name, default: None, *, module_aliases: Optional[Mapping[Optional[str], str]] = None, type: type[T], condition: Optional[Callable[[so.Object], bool]] = None, label: Optional[str] = None, span: Optional[parsing.Span] = None, ) -> Optional[T]: ... @overload def get( self, name: str | sn.Name, default: so.Object | so.NoDefaultT | None = so.NoDefault, *, module_aliases: Optional[Mapping[Optional[str], str]] = None, type: Optional[type[so.Object]] = None, condition: Optional[Callable[[so.Object], bool]] = None, label: Optional[str] = None, span: Optional[parsing.Span] = None, ) -> Optional[so.Object]: ... 
def get( self, name: str | sn.Name, default: so.Object | so.NoDefaultT | None = so.NoDefault, *, module_aliases: Optional[Mapping[Optional[str], str]] = None, type: Optional[type[so.Object]] = None, condition: Optional[Callable[[so.Object], bool]] = None, label: Optional[str] = None, span: Optional[parsing.Span] = None, ) -> Optional[so.Object]: def getter(schema: Schema, name: sn.Name) -> Optional[so.Object]: obj = schema._get_by_name(name) if obj is not None and condition is not None: if not condition(obj): obj = None return obj obj = lookup( self, name, getter=getter, default=default, module_aliases=module_aliases, ) if obj is not so.NoDefault: # We do our own type check, instead of using get_by_id's, so # we can produce a user-facing error message. if obj and type is not None and not isinstance(obj, type): Schema.raise_wrong_type(name, obj.__class__, type, span) return obj else: Schema.raise_bad_reference( name=name, label=label, module_aliases=module_aliases, span=span, type=type, ) @abc.abstractmethod def _get_object_ids(self) -> Iterable[uuid.UUID]: raise NotImplementedError @abc.abstractmethod def _get_global_name_ids( self ) -> Iterable[tuple[type[so.Object], uuid.UUID]]: raise NotImplementedError def get_children( self, scls: so.Object_T, ) -> frozenset[so.Object_T]: # Ideally get_referrers needs to be made generic via # an overload on scls_type, but mypy crashes on that. 
return self.get_referrers( scls, scls_type=type(scls), field_name='bases', ) def get_descendants( self, scls: so.Object_T, ) -> frozenset[so.Object_T]: return self.get_referrers( scls, scls_type=type(scls), field_name='ancestors') def get_objects[Object_T: so.Object]( self, *, exclude_stdlib: bool = False, exclude_global: bool = False, exclude_extensions: bool = False, exclude_internal: bool = True, included_modules: Optional[Iterable[sn.Name]] = None, excluded_modules: Optional[Iterable[sn.Name]] = None, included_items: Optional[Iterable[sn.Name]] = None, excluded_items: Optional[Iterable[sn.Name]] = None, type: Optional[type[Object_T]] = None, extra_filters: Iterable[Callable[[Schema, Object_T], bool]] = (), ) -> SchemaIterator[Object_T]: return SchemaIterator[Object_T]( self, self._get_object_ids(), exclude_global=exclude_global, exclude_stdlib=exclude_stdlib, exclude_extensions=exclude_extensions, exclude_internal=exclude_internal, included_modules=included_modules, excluded_modules=excluded_modules, included_items=included_items, excluded_items=excluded_items, type=type, extra_filters=extra_filters, ) def get_modules(self) -> tuple[s_mod.Module, ...]: modules = [] for mcls, id in self._get_global_name_ids(): if mcls is s_mod.Module: modules.append(mcls(_private_id=id)) return tuple(modules) # type: ignore def get_last_migration(self) -> Optional[s_migrations.Migration]: return _get_last_migration(self) def get_casts_to_type( self, to_type: s_types.Type, *, implicit: bool = False, assignment: bool = False, ) -> frozenset[s_casts.Cast]: return self._get_casts( to_type, disposition='to_type', implicit=implicit, assignment=assignment, ) def get_casts_from_type( self, from_type: s_types.Type, *, implicit: bool = False, assignment: bool = False, ) -> frozenset[s_casts.Cast]: return self._get_casts( from_type, disposition='from_type', implicit=implicit, assignment=assignment, ) @lru.lru_method_cache() def _get_casts( self, stype: s_types.Type, *, disposition: str, 
implicit: bool = False, assignment: bool = False, ) -> frozenset[s_casts.Cast]: all_casts = self.get_referrers( stype, scls_type=s_casts.Cast, field_name=disposition ) casts = [] for castobj in all_casts: if implicit and not castobj.get_allow_implicit(self): continue if assignment and not castobj.get_allow_assignment(self): continue casts.append(castobj) return frozenset(casts) @overload def get_referrers[T: so.Object]( self, scls: so.Object, *, scls_type: type[T], field_name: Optional[str] = None, ) -> frozenset[T]: ... @overload def get_referrers( self, scls: so.Object, *, scls_type: None = None, field_name: Optional[str] = None, ) -> frozenset[so.Object]: ... @abc.abstractmethod def get_referrers( self, scls: so.Object, *, scls_type: Optional[type[so.Object]] = None, field_name: Optional[str] = None, ) -> frozenset[so.Object]: raise NotImplementedError @abc.abstractmethod def get_referrers_ex( self, scls: so.Object, *, scls_type: Optional[type[so.Object_T]] = None, ) -> dict[ tuple[type[so.Object_T], str], frozenset[so.Object_T], ]: raise NotImplementedError @staticmethod def raise_wrong_type( name: str | sn.Name, actual_type: type[so.Object_T], expected_type: type[so.Object_T], span: Optional[parsing.Span], ) -> NoReturn: refname = str(name) actual_type_name = actual_type.get_schema_class_displayname() expected_type_name = expected_type.get_schema_class_displayname() raise errors.InvalidReferenceError( f'{refname!r} exists, but is {english.add_a(actual_type_name)}, ' f'not {english.add_a(expected_type_name)}', span=span, ) @staticmethod def raise_bad_reference( name: str | sn.Name, *, label: Optional[str] = None, module_aliases: Optional[Mapping[Optional[str], str]] = None, span: Optional[parsing.Span] = None, type: Optional[type[so.Object]] = None, ) -> NoReturn: refname = str(name) if label is None: if type is not None: label = type.get_schema_class_displayname() else: label = 'schema item' if type is not None: if issubclass(type, so.QualifiedObject): if not 
sn.is_qualified(refname): if module_aliases is not None: default_module = module_aliases.get(None) if default_module is not None: refname = type.get_displayname_static( sn.QualName(default_module, refname), ) else: refname = type.get_displayname_static( sn.QualName.from_string(refname)) else: refname = type.get_displayname_static( sn.UnqualName.from_string(refname)) raise errors.InvalidReferenceError( f'{label} {refname!r} does not exist', span=span, ) class FlatSchema(Schema): _id_to_data: immu.Map[uuid.UUID, tuple[Any, ...]] _id_to_type: immu.Map[uuid.UUID, str] _name_to_id: immu.Map[sn.Name, uuid.UUID] _shortname_to_id: immu.Map[ tuple[type[so.Object], sn.Name], frozenset[uuid.UUID], ] _globalname_to_id: immu.Map[ tuple[type[so.Object], sn.Name], uuid.UUID, ] _refs_to: Refs_T _generation: int def __init__(self) -> None: self._id_to_data = immu.Map() self._id_to_type = immu.Map() self._shortname_to_id = immu.Map() self._name_to_id = immu.Map() self._globalname_to_id = immu.Map() self._refs_to = immu.Map() self._generation = 0 def _get_object_ids(self) -> Iterable[uuid.UUID]: return self._id_to_type.keys() def _get_global_name_ids( self ) -> Iterable[tuple[type[so.Object], uuid.UUID]]: return ( (mcls, id) for (mcls, _name), id in self._globalname_to_id.items() ) def _replace( self, *, id_to_data: Optional[immu.Map[uuid.UUID, tuple[Any, ...]]] = None, id_to_type: Optional[immu.Map[uuid.UUID, str]] = None, name_to_id: Optional[immu.Map[sn.Name, uuid.UUID]] = None, shortname_to_id: Optional[ immu.Map[ tuple[type[so.Object], sn.Name], frozenset[uuid.UUID] ] ] = None, globalname_to_id: Optional[ immu.Map[tuple[type[so.Object], sn.Name], uuid.UUID] ] = None, refs_to: Optional[Refs_T] = None, ) -> FlatSchema: new = FlatSchema.__new__(FlatSchema) if id_to_data is None: new._id_to_data = self._id_to_data else: new._id_to_data = id_to_data if id_to_type is None: new._id_to_type = self._id_to_type else: new._id_to_type = id_to_type if name_to_id is None: new._name_to_id = 
self._name_to_id else: new._name_to_id = name_to_id if shortname_to_id is None: new._shortname_to_id = self._shortname_to_id else: new._shortname_to_id = shortname_to_id if globalname_to_id is None: new._globalname_to_id = self._globalname_to_id else: new._globalname_to_id = globalname_to_id if refs_to is None: new._refs_to = self._refs_to else: new._refs_to = refs_to new._generation = self._generation + 1 return new def _update_obj_name( self, obj_id: uuid.UUID, sclass: type[so.Object], old_name: Optional[sn.Name], new_name: Optional[sn.Name], ) -> tuple[ immu.Map[sn.Name, uuid.UUID], immu.Map[tuple[type[so.Object], sn.Name], frozenset[uuid.UUID]], immu.Map[tuple[type[so.Object], sn.Name], uuid.UUID], ]: name_to_id = self._name_to_id shortname_to_id = self._shortname_to_id globalname_to_id = self._globalname_to_id is_global = not issubclass(sclass, so.QualifiedObject) has_sn_cache = issubclass(sclass, (s_func.Function, s_oper.Operator)) if old_name is not None: if is_global: globalname_to_id = globalname_to_id.delete((sclass, old_name)) else: name_to_id = name_to_id.delete(old_name) if has_sn_cache: old_shortname = sn.shortname_from_fullname(old_name) sn_key = (sclass, old_shortname) new_ids = shortname_to_id[sn_key] - {obj_id} if new_ids: shortname_to_id = shortname_to_id.set(sn_key, new_ids) else: shortname_to_id = shortname_to_id.delete(sn_key) if new_name is not None: if is_global: key = (sclass, new_name) if key in globalname_to_id: other_obj = self.get_by_id( globalname_to_id[key], type=so.Object) vn = other_obj.get_verbosename(self, with_parent=True) raise errors.SchemaError( f'{vn} already exists') globalname_to_id = globalname_to_id.set(key, obj_id) else: assert isinstance(new_name, sn.QualName) if ( not self.has_module(new_name.module) and new_name.get_module_name() not in SPECIAL_MODULES ): raise errors.UnknownModuleError( f'module {new_name.module!r} is not in this schema') if new_name in name_to_id: other_obj = self.get_by_id( name_to_id[new_name], 
type=so.Object) vn = other_obj.get_verbosename(self, with_parent=True) raise errors.SchemaError( f'{vn} already exists') name_to_id = name_to_id.set(new_name, obj_id) if has_sn_cache: new_shortname = sn.shortname_from_fullname(new_name) sn_key = (sclass, new_shortname) try: ids = shortname_to_id[sn_key] except KeyError: ids = frozenset() shortname_to_id = shortname_to_id.set(sn_key, ids | {obj_id}) return name_to_id, shortname_to_id, globalname_to_id def update_obj( self, obj: so.Object, updates: Mapping[str, Any], ) -> FlatSchema: if not updates: return self obj_id = obj.id sclass = type(obj) all_fields = sclass.get_schema_fields() object_ref_fields = sclass.get_object_reference_fields() reducible_fields = sclass.get_reducible_fields() try: data = list(self._id_to_data[obj_id]) except KeyError: data = [None] * len(all_fields) name_to_id = None shortname_to_id = None globalname_to_id = None orig_refs = {} new_refs = {} for fieldname, value in updates.items(): field = all_fields[fieldname] findex = field.index if fieldname == 'name': name_to_id, shortname_to_id, globalname_to_id = ( self._update_obj_name( obj_id, sclass, data[findex], value ) ) if value is None: if field in reducible_fields and field in object_ref_fields: orig_value = data[findex] if orig_value is not None: orig_refs[fieldname] = ( field.type.schema_refs_from_data(orig_value)) else: if field in reducible_fields: value = value.schema_reduce() if field in object_ref_fields: new_refs[fieldname] = ( field.type.schema_refs_from_data(value)) orig_value = data[findex] if orig_value is not None: orig_refs[fieldname] = ( field.type.schema_refs_from_data(orig_value)) data[findex] = value id_to_data = self._id_to_data.set(obj_id, tuple(data)) refs_to = self._update_refs_to(obj_id, sclass, orig_refs, new_refs) return self._replace(name_to_id=name_to_id, shortname_to_id=shortname_to_id, globalname_to_id=globalname_to_id, id_to_data=id_to_data, refs_to=refs_to) def get_data_raw( self, obj: so.Object, ) -> 
Optional[tuple[Any, ...]]: return self._id_to_data.get(obj.id) def get_field_raw( self, obj: so.Object, field_index: int ) -> Optional[Any]: data = self._id_to_data.get(obj.id) assert data, ( f'cannot get item data: item {str(obj.id)!r} ' f'is not present in the schema {self!r}' ) return data[field_index] def set_field( self, obj: so.Object, fieldname: str, value: Any, ) -> FlatSchema: obj_id = obj.id try: data = self._id_to_data[obj_id] except KeyError: err = (f'cannot set {fieldname!r} value: item {str(obj_id)!r} ' f'is not present in the schema {self!r}') raise errors.SchemaError(err) from None sclass = so.ObjectMeta.get_schema_class(self._id_to_type[obj_id]) field = sclass.get_schema_field(fieldname) findex = field.index is_object_ref = field in sclass.get_object_reference_fields() if field in sclass.get_reducible_fields(): value = value.schema_reduce() name_to_id = None shortname_to_id = None globalname_to_id = None if fieldname == 'name': old_name = data[findex] name_to_id, shortname_to_id, globalname_to_id = ( self._update_obj_name(obj_id, sclass, old_name, value) ) data_list = list(data) data_list[findex] = value new_data = tuple(data_list) id_to_data = self._id_to_data.set(obj_id, new_data) if not is_object_ref: refs_to = None else: orig_value = data[findex] if orig_value is not None: orig_refs = { fieldname: field.type.schema_refs_from_data(orig_value), } else: orig_refs = {} new_refs = {fieldname: field.type.schema_refs_from_data(value)} refs_to = self._update_refs_to(obj_id, sclass, orig_refs, new_refs) return self._replace( name_to_id=name_to_id, shortname_to_id=shortname_to_id, globalname_to_id=globalname_to_id, id_to_data=id_to_data, refs_to=refs_to, ) def unset_field( self, obj: so.Object, fieldname: str, ) -> FlatSchema: obj_id = obj.id try: data = self._id_to_data[obj.id] except KeyError: return self sclass = so.ObjectMeta.get_schema_class(self._id_to_type[obj.id]) field = sclass.get_schema_field(fieldname) findex = field.index name_to_id = None 
shortname_to_id = None globalname_to_id = None orig_value = data[findex] if orig_value is None: return self if fieldname == 'name': name_to_id, shortname_to_id, globalname_to_id = ( self._update_obj_name( obj_id, sclass, orig_value, None ) ) data_list = list(data) data_list[findex] = None new_data = tuple(data_list) id_to_data = self._id_to_data.set(obj_id, new_data) is_object_ref = field in sclass.get_object_reference_fields() if not is_object_ref: refs_to = None else: orig_refs = { fieldname: field.type.schema_refs_from_data(orig_value), } refs_to = self._update_refs_to(obj_id, sclass, orig_refs, None) return self._replace( name_to_id=name_to_id, shortname_to_id=shortname_to_id, globalname_to_id=globalname_to_id, id_to_data=id_to_data, refs_to=refs_to, ) def _update_refs_to( self, object_id: uuid.UUID, sclass: type[so.Object], orig_refs: Optional[Mapping[str, frozenset[uuid.UUID]]], new_refs: Optional[Mapping[str, frozenset[uuid.UUID]]], ) -> Refs_T: objfields = sclass.get_object_reference_fields() if not objfields: return self._refs_to with self._refs_to.mutate() as mm: for field in objfields: if not new_refs: ids = None else: ids = new_refs.get(field.name) if not orig_refs: orig_ids = None else: orig_ids = orig_refs.get(field.name) if not ids and not orig_ids: continue old_ids: Optional[frozenset[uuid.UUID]] new_ids: Optional[frozenset[uuid.UUID]] key = (sclass, field.name) if ids and orig_ids: new_ids = ids - orig_ids old_ids = orig_ids - ids elif ids: new_ids = ids old_ids = None else: new_ids = None old_ids = orig_ids if new_ids: for ref_id in new_ids: try: refs = mm[ref_id] except KeyError: mm[ref_id] = immu.Map(( (key, immu.Map(((object_id, None),))), )) else: try: field_refs = refs[key] except KeyError: field_refs = immu.Map(((object_id, None),)) else: field_refs = field_refs.set(object_id, None) mm[ref_id] = refs.set(key, field_refs) if old_ids: for ref_id in old_ids: refs = mm[ref_id] field_refs = refs[key].delete(object_id) if not field_refs: 
mm[ref_id] = refs.delete(key) else: mm[ref_id] = refs.set(key, field_refs) result = mm.finish() return result def add( self, id: uuid.UUID, sclass: type[so.Object], data: tuple[Any, ...], ) -> FlatSchema: reducible_fields = sclass.get_reducible_fields() if reducible_fields: data_list = list(data) for field in reducible_fields: val = data[field.index] if val is not None: data_list[field.index] = val.schema_reduce() data = tuple(data_list) name_field = sclass.get_schema_field('name') name = data[name_field.index] if name in self._name_to_id: other_obj = self.get_by_id( self._name_to_id[name], type=so.Object) vn = other_obj.get_verbosename(self, with_parent=True) raise errors.SchemaError(f'{vn} already exists') if id in self._id_to_data: raise errors.SchemaError( f'{sclass.__name__} ({str(id)!r}) is already present ' f'in the schema {self!r}') object_ref_fields = sclass.get_object_reference_fields() if not object_ref_fields: refs_to = None else: new_refs = {} for field in object_ref_fields: ref_data = data[field.index] if ref_data is not None: ref = field.type.schema_refs_from_data(ref_data) new_refs[field.name] = ref refs_to = self._update_refs_to(id, sclass, None, new_refs) name_to_id, shortname_to_id, globalname_to_id = self._update_obj_name( id, sclass, None, name) updates = dict( id_to_data=self._id_to_data.set(id, data), id_to_type=self._id_to_type.set(id, sclass.__name__), name_to_id=name_to_id, shortname_to_id=shortname_to_id, globalname_to_id=globalname_to_id, refs_to=refs_to, ) if ( issubclass(sclass, so.QualifiedObject) and not self.has_module(name.module) and name.get_module_name() not in SPECIAL_MODULES ): raise errors.UnknownModuleError( f'module {name.module!r} is not in this schema') return self._replace(**updates) # type: ignore def delist(self, name: sn.Name) -> FlatSchema: name_to_id = self._name_to_id.delete(name) return self._replace( name_to_id=name_to_id, shortname_to_id=self._shortname_to_id, globalname_to_id=self._globalname_to_id, ) def 
delete(self, obj: so.Object) -> FlatSchema: data = self._id_to_data.get(obj.id) if data is None: raise errors.InvalidReferenceError( f'cannot delete {obj!r}: not in this schema') sclass = type(obj) name_field = sclass.get_schema_field('name') name = data[name_field.index] updates = {} name_to_id, shortname_to_id, globalname_to_id = self._update_obj_name( obj.id, sclass, name, None) object_ref_fields = sclass.get_object_reference_fields() if not object_ref_fields: refs_to = None else: values = self._id_to_data[obj.id] orig_refs = {} for field in object_ref_fields: ref = values[field.index] if ref is not None: ref = field.type.schema_refs_from_data(ref) orig_refs[field.name] = ref refs_to = self._update_refs_to(obj.id, sclass, orig_refs, None) updates.update(dict( name_to_id=name_to_id, shortname_to_id=shortname_to_id, globalname_to_id=globalname_to_id, id_to_data=self._id_to_data.delete(obj.id), id_to_type=self._id_to_type.delete(obj.id), refs_to=refs_to, )) return self._replace(**updates) # type: ignore def get_referrers( self, scls: so.Object, *, scls_type: Optional[type[so.Object_T]] = None, field_name: Optional[str] = None, ) -> frozenset[so.Object_T]: return self._get_referrers( scls, scls_type=scls_type, field_name=field_name ) @lru.lru_method_cache() def _get_referrers( self, scls: so.Object, *, scls_type: Optional[type[so.Object_T]] = None, field_name: Optional[str] = None, ) -> frozenset[so.Object_T]: try: refs = self._refs_to[scls.id] except KeyError: return frozenset() else: referrers: set[so.Object] = set() if scls_type is not None: if field_name is not None: for (st, fn), ids in refs.items(): if issubclass(st, scls_type) and fn == field_name: referrers.update( self.get_by_id(objid) for objid in ids) else: for (st, _), ids in refs.items(): if issubclass(st, scls_type): referrers.update( self.get_by_id(objid) for objid in ids) elif field_name is not None: for (_, fn), ids in refs.items(): if fn == field_name: referrers.update( self.get_by_id(objid) for 
objid in ids) else: refids = itertools.chain.from_iterable(refs.values()) referrers.update(self.get_by_id(objid) for objid in refids) return frozenset(referrers) # type: ignore @lru.lru_method_cache() def get_referrers_ex( self, scls: so.Object, *, scls_type: Optional[type[so.Object_T]] = None, ) -> dict[ tuple[type[so.Object_T], str], frozenset[so.Object_T], ]: try: refs = self._refs_to[scls.id] except KeyError: return {} else: result = {} if scls_type is not None: for (st, fn), ids in refs.items(): if issubclass(st, scls_type): result[st, fn] = frozenset( self.get_by_id(objid) for objid in ids) else: for (st, fn), ids in refs.items(): result[st, fn] = frozenset( # type: ignore self.get_by_id(objid) for objid in ids) return result # type: ignore def _get_by_id( self, obj_id: uuid.UUID, default: so.Object_T | so.NoDefaultT | None = so.NoDefault, *, type: Optional[type[so.Object_T]] = None, ) -> Optional[so.Object_T]: try: sclass_name = self._id_to_type[obj_id] except KeyError: if default is so.NoDefault: raise LookupError( f'reference to a non-existent schema item {obj_id}' f' in schema {self!r}' ) from None else: return default else: obj = _raw_schema_restore(sclass_name, obj_id) if type is not None and not isinstance(obj, type): raise TypeError( f'schema object {obj_id!r} exists, but is a ' f'{obj.__class__.get_schema_class_displayname()!r}, ' f'not a {type.get_schema_class_displayname()!r}' ) # Avoid the overhead of cast(Object_T) below return obj # type: ignore # Important micro-optimization if not TYPE_CHECKING: get_by_id = _get_by_id def _get_by_globalname[T: so.Object]( self, mcls: type[T], name: sn.Name, ) -> Optional[T]: if isinstance(name, str): name = sn.UnqualName(name) obj_id = self._globalname_to_id.get((mcls, name)) if obj_id is None: return None return _raw_schema_restore(mcls.__name__, obj_id) # type: ignore def _get_by_shortname[T: s_func.Function | s_oper.Operator]( self, mcls: type[T], shortname: sn.Name, ) -> Optional[tuple[T, ...]]: obj_ids = 
self._shortname_to_id.get((mcls, shortname)) if obj_ids is None: return None return tuple( _raw_schema_restore(mcls.__name__, i) # type: ignore for i in obj_ids ) def _get_by_name( self, name: sn.Name, ) -> Optional[so.Object]: obj_id = self._name_to_id.get(name) if obj_id is None: return None return self.get_by_id(obj_id) def has_object(self, object_id: uuid.UUID) -> bool: return object_id in self._id_to_type def get_objects( self, *, exclude_stdlib: bool = False, exclude_global: bool = False, exclude_extensions: bool = False, exclude_internal: bool = True, included_modules: Optional[Iterable[sn.Name]] = None, excluded_modules: Optional[Iterable[sn.Name]] = None, included_items: Optional[Iterable[sn.Name]] = None, excluded_items: Optional[Iterable[sn.Name]] = None, type: Optional[type[so.Object_T]] = None, extra_filters: Iterable[Callable[[Schema, so.Object_T], bool]] = (), ) -> SchemaIterator[so.Object_T]: return SchemaIterator[so.Object_T]( self, self._id_to_type, exclude_stdlib=exclude_stdlib, exclude_global=exclude_global, exclude_extensions=exclude_extensions, exclude_internal=exclude_internal, included_modules=included_modules, excluded_modules=excluded_modules, included_items=included_items, excluded_items=excluded_items, type=type, extra_filters=extra_filters, ) def __repr__(self) -> str: return ( f'<{type(self).__name__} gen:{self._generation} at {id(self):#x}>') def lookup[T]( schema: Schema, name: sn.Name | str, *, getter: Callable[[Schema, sn.QualName], T | None], default: T | so.NoDefaultT = so.NoDefault, module_aliases: Optional[Mapping[Optional[str], str]], ) -> T | so.NoDefaultT: """ Find something in the schema with a given name. 
This function mostly mirrors edgeql.tracer.resolve_name except: - When searching in std, disallow some modules (often the base modules) - If no result found, return default """ if isinstance(name, str): name = sn.name_from_string(name) obj_name = name.name module = name.module if isinstance(name, sn.QualName) else None orig_module = module if module == '__std__': fqname = sn.QualName('std', obj_name) result = getter(schema, fqname) if result is not None: return result else: return default # Apply module aliases module = apply_module_aliases( module, module_aliases ) # Check if something matches the name if module is not None: fqname = sn.QualName(module, obj_name) result = getter(schema, fqname) if result is not None: return result # For unqualified names, fallback to std::{obj_name} if orig_module is None: fqname = sn.QualName('std', obj_name) result = getter(schema, fqname) if result is not None: return result # For qualified names, fallback to std::{module}::{obj_name} # This is allowed only when there is no top-level module with the same name. if module and not schema.has_module(module.split('::')[0]): fqname = sn.QualName(f'std::{module}', obj_name) result = getter(schema, fqname) if result is not None: return result return default def apply_module_aliases( module: str | None, module_aliases: Optional[Mapping[Optional[str], str]], ) -> str | None: if module_aliases is not None: # Apply modalias first: Optional[str] if module: first, sep, rest = module.partition('::') else: first, sep, rest = module, '', '' fq_module = module_aliases.get(first) if fq_module is not None: module = fq_module + sep + rest return module EMPTY_SCHEMA: Schema = FlatSchema() def upgrade_schema(schema: Schema) -> Schema: """Repair a schema object serialized by an older patch version When an edgeql+schema patch adds fields to schema types, old serialized schemas will be broken, since their tuples are missing the fields. 
In this situation, we run through all the data tuples and fill them out. The upgraded version will then be cached. """ if isinstance(schema, ChainedSchema): return ChainedSchema( base_schema=upgrade_schema(schema._base_schema), top_schema=upgrade_schema(schema._top_schema), global_schema=upgrade_schema(schema._global_schema), ) assert isinstance(schema, FlatSchema) cls_fields = {} for py_cls in so.ObjectMeta.get_schema_metaclasses(): if isinstance(py_cls, adapter.Adapter): continue fields = py_cls._schema_fields.values() cls_fields[py_cls] = sorted(fields, key=lambda f: f.index) id_to_data = schema._id_to_data fixes = {} for id, typ_name in schema._id_to_type.items(): data = id_to_data[id] obj = so.Object.schema_restore((typ_name, id)) typ = type(obj) tfields = cls_fields[typ] exp_len = len(tfields) if len(data) < exp_len: ldata = list(data) for _ in range(len(ldata), exp_len): ldata.append(None) fixes[id] = tuple(ldata) return schema._replace(id_to_data=id_to_data.update(fixes)) class SchemaIterator[Object_T: so.Object]: def __init__( self, schema: Schema, object_ids: Iterable[uuid.UUID], *, exclude_stdlib: bool = False, exclude_global: bool = False, exclude_extensions: bool = False, exclude_internal: bool = True, included_modules: Optional[Iterable[sn.Name]], excluded_modules: Optional[Iterable[sn.Name]], included_items: Optional[Iterable[sn.Name]] = None, excluded_items: Optional[Iterable[sn.Name]] = None, type: Optional[type[Object_T]] = None, extra_filters: Iterable[Callable[[Schema, Object_T], bool]] = (), ) -> None: filters = [] if type is not None: t = type filters.append(lambda schema, obj: isinstance(obj, t)) if included_modules: modules = frozenset(included_modules) filters.append( lambda schema, obj: isinstance(obj, so.QualifiedObject) and obj.get_name(schema).get_module_name() in modules) if excluded_modules or exclude_stdlib: excmod: set[sn.Name] = set() if excluded_modules: excmod.update(excluded_modules) if exclude_stdlib: excmod.update(STD_MODULES) 
filters.append( lambda schema, obj: ( not isinstance(obj, so.QualifiedObject) or obj.get_name(schema).get_module_name() not in excmod ) ) if included_items: objs = frozenset(included_items) filters.append( lambda schema, obj: obj.get_name(schema) in objs) if excluded_items: objs = frozenset(excluded_items) filters.append( lambda schema, obj: obj.get_name(schema) not in objs) if exclude_stdlib: filters.append( lambda schema, obj: not isinstance(obj, s_pseudo.PseudoType) ) if exclude_extensions: filters.append( lambda schema, obj: obj.get_name(schema).get_root_module_name() != EXT_MODULE ) if exclude_global: filters.append( lambda schema, obj: not isinstance(obj, so.GlobalObject) ) if exclude_internal: filters.append( lambda schema, obj: not isinstance(obj, so.InternalObject) ) # Extra filters are last, because they might depend on type. filters.extend(extra_filters) self._filters = filters self._schema = schema self._object_ids = object_ids def __iter__(self) -> Iterator[Object_T]: filters = self._filters schema = self._schema get_by_id = schema.get_by_id for obj_id in self._object_ids: obj = get_by_id(obj_id) if all(f(self._schema, obj) for f in filters): yield obj # type: ignore class ChainedSchema(Schema): __slots__ = ('_base_schema', '_top_schema', '_global_schema') def __init__( self, base_schema: Schema, top_schema: Schema, global_schema: Schema ) -> None: self._base_schema = base_schema self._top_schema = top_schema self._global_schema = global_schema def _get_object_ids(self) -> Iterable[uuid.UUID]: return itertools.chain( self._base_schema._get_object_ids(), self._top_schema._get_object_ids(), self._global_schema._get_object_ids(), ) def _get_global_name_ids( self ) -> Iterable[tuple[type[so.Object], uuid.UUID]]: return itertools.chain( self._base_schema._get_global_name_ids(), self._top_schema._get_global_name_ids(), self._global_schema._get_global_name_ids(), ) def get_top_schema(self) -> Schema: return self._top_schema def get_global_schema(self) -> 
Schema: return self._global_schema def add( self, id: uuid.UUID, sclass: type[so.Object], data: tuple[Any, ...], ) -> ChainedSchema: if issubclass(sclass, so.GlobalObject): return ChainedSchema( self._base_schema, self._top_schema, self._global_schema.add(id, sclass, data), ) else: return ChainedSchema( self._base_schema, self._top_schema.add(id, sclass, data), self._global_schema, ) def delete(self, obj: so.Object) -> ChainedSchema: if isinstance(obj, so.GlobalObject): return ChainedSchema( self._base_schema, self._top_schema, self._global_schema.delete(obj), ) else: return ChainedSchema( self._base_schema, self._top_schema.delete(obj), self._global_schema, ) def delist( self, name: sn.Name, ) -> ChainedSchema: return ChainedSchema( self._base_schema, self._top_schema.delist(name), self._global_schema, ) def update_obj( self, obj: so.Object, updates: Mapping[str, Any], ) -> ChainedSchema: if isinstance(obj, so.GlobalObject): return ChainedSchema( self._base_schema, self._top_schema, self._global_schema.update_obj(obj, updates), ) else: obj_id = obj.id base_obj = self._base_schema.get_by_id(obj_id, default=None) if ( base_obj is not None and not self._top_schema.has_object(obj_id) ): top_schema = self._top_schema.add( obj_id, type(base_obj), self._base_schema.get_data_raw(base_obj), ) else: top_schema = self._top_schema return ChainedSchema( self._base_schema, top_schema.update_obj(obj, updates), self._global_schema, ) def get_data_raw( self, obj: so.Object, ) -> Optional[tuple[Any, ...]]: data = self._top_schema.get_data_raw(obj) if data is not None: return data data = self._base_schema.get_data_raw(obj) if data is not None: return data return self._global_schema.get_data_raw(obj) def get_field_raw( self, obj: so.Object, field_index: int, ) -> Optional[Any]: if self._top_schema.has_object(obj.id): return self._top_schema.get_field_raw(obj, field_index) if self._base_schema.has_object(obj.id): return self._base_schema.get_field_raw(obj, field_index) if 
self._global_schema.has_object(obj.id): return self._global_schema.get_field_raw(obj, field_index) raise AssertionError( f'cannot get item data: item {str(obj.id)!r} ' f'is not present in the schema {self!r}' ) def set_field( self, obj: so.Object, fieldname: str, value: Any, ) -> ChainedSchema: if isinstance(obj, so.GlobalObject): return ChainedSchema( self._base_schema, self._top_schema, self._global_schema.set_field(obj, fieldname, value), ) else: return ChainedSchema( self._base_schema, self._top_schema.set_field(obj, fieldname, value), self._global_schema, ) def unset_field( self, obj: so.Object, field: str, ) -> ChainedSchema: if isinstance(obj, so.GlobalObject): return ChainedSchema( self._base_schema, self._top_schema, self._global_schema.unset_field(obj, field), ) else: return ChainedSchema( self._base_schema, self._top_schema.unset_field(obj, field), self._global_schema, ) def get_referrers( self, scls: so.Object, *, scls_type: Optional[type[so.Object_T]] = None, field_name: Optional[str] = None, ) -> frozenset[so.Object_T]: return ( self._base_schema.get_referrers( # type: ignore [return-value] scls, scls_type=scls_type, field_name=field_name, ) | self._top_schema.get_referrers( scls, scls_type=scls_type, field_name=field_name, ) | self._global_schema.get_referrers( # type: ignore [operator] scls, scls_type=scls_type, field_name=field_name, ) ) def get_referrers_ex( self, scls: so.Object, *, scls_type: Optional[type[so.Object_T]] = None, ) -> dict[ tuple[type[so.Object_T], str], frozenset[so.Object_T], ]: base = self._base_schema.get_referrers_ex(scls, scls_type=scls_type) top = self._top_schema.get_referrers_ex(scls, scls_type=scls_type) gl = self._global_schema.get_referrers_ex(scls, scls_type=scls_type) return { k: ( base.get(k, frozenset()) | top.get(k, frozenset()) | gl.get(k, frozenset()) ) for k in itertools.chain(base, top) } def _get_by_id( self, obj_id: uuid.UUID, default: so.Object_T | so.NoDefaultT | None = so.NoDefault, *, type: 
Optional[type[so.Object_T]] = None, ) -> Optional[so.Object_T]: obj = self._top_schema.get_by_id(obj_id, type=type, default=None) if obj is None: obj = self._base_schema.get_by_id( obj_id, default=None, type=type) if obj is None: obj = self._global_schema.get_by_id( obj_id, default=default, type=type) return obj # Important micro-optimization if not TYPE_CHECKING: get_by_id = _get_by_id def _get_by_globalname[T: so.Object]( self, mcls: type[T], name: sn.Name, ) -> Optional[T]: if issubclass(mcls, so.GlobalObject): if o := self._global_schema._get_by_globalname( mcls, name ): return o # type: ignore if obj := self._top_schema._get_by_globalname(mcls, name): return obj return self._base_schema._get_by_globalname(mcls, name) def _get_by_shortname[T: s_func.Function | s_oper.Operator]( self, mcls: type[T], shortname: sn.Name, ) -> Optional[tuple[T, ...]]: objs = self._base_schema._get_by_shortname(mcls, shortname) if objs is not None: return objs return self._top_schema._get_by_shortname(mcls, shortname) def _get_by_name( self, name: sn.Name, ) -> Optional[so.Object]: objs = self._base_schema._get_by_name(name) if objs is not None: return objs return self._top_schema._get_by_name(name) def has_object(self, object_id: uuid.UUID) -> bool: return ( self._base_schema.has_object(object_id) or self._top_schema.has_object(object_id) or self._global_schema.has_object(object_id) ) @lru.per_job_lru_cache() def _get_operators( schema: Schema, name: sn.Name, ) -> tuple[s_oper.Operator, ...] 
| None: return schema._get_by_shortname(s_oper.Operator, name) @lru.per_job_lru_cache() def _get_last_migration( schema: Schema, ) -> Optional[s_migrations.Migration]: migrations: list[s_migrations.Migration] = [ mcls(_private_id=id) # type: ignore for mcls, id in schema._get_global_name_ids() if mcls is s_migrations.Migration ] if not migrations: return None migration_map = collections.defaultdict(list) root = None for m in migrations: parents = m.get_parents(schema).objects(schema) if not parents: if root is not None: raise errors.InternalServerError( 'multiple migration roots found') root = m for parent in parents: migration_map[parent].append(m) if root is None: raise errors.InternalServerError('cannot find migration root') latest = root while children := migration_map[latest]: if len(children) > 1: raise errors.InternalServerError( 'nonlinear migration history detected') latest = children[0] return latest ================================================ FILE: edb/schema/sources.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Optional, Iterable, Sequence, overload, TYPE_CHECKING, ) from edb import errors from . import delta as sd from . import indexes from . import name as sn from . import objects as so from . 
import pointers as s_pointers if TYPE_CHECKING: from . import links from . import schema as s_schema class SourceCommandContext[Source_T: Source]( sd.ObjectCommandContext[Source_T], indexes.IndexSourceCommandContext, ): # context mixin pass class SourceCommand[Source_T: Source]( indexes.IndexSourceCommand[Source_T] ): pass class Source( so.QualifiedObject, indexes.IndexableSubject, so.Object, # Help reflection figure out the right db MRO ): pointers_refs = so.RefDict( attr='pointers', requires_explicit_overloaded=True, backref_attr='source', ref_cls=s_pointers.Pointer) pointers = so.SchemaField( so.ObjectIndexByUnqualifiedName[s_pointers.Pointer], inheritable=False, ephemeral=True, coerce=True, compcoef=0.857, default=so.DEFAULT_CONSTRUCTOR) @overload def maybe_get_ptr[Pointer_T: s_pointers.Pointer]( self, schema: s_schema.Schema, name: sn.UnqualName, *, type: type[Pointer_T], ) -> Optional[Pointer_T]: ... @overload def maybe_get_ptr( self, schema: s_schema.Schema, name: sn.UnqualName, *, type: Optional[type[s_pointers.Pointer]] = None, ) -> Optional[s_pointers.Pointer]: ... def maybe_get_ptr( self, schema: s_schema.Schema, name: sn.UnqualName, *, type: Optional[type[s_pointers.Pointer]] = None, ) -> Optional[s_pointers.Pointer]: ptr = self.get_pointers(schema).get(schema, name, None) if ptr is not None and type is not None and not isinstance(ptr, type): raise AssertionError( f'{self.get_verbosename(schema)} has the' f' {str(name)!r} pointer, but it is not a' f' {type.get_schema_class_displayname()}' ) return ptr @overload def getptr[Pointer_T: s_pointers.Pointer]( self, schema: s_schema.Schema, name: sn.UnqualName, *, type: type[Pointer_T], ) -> Pointer_T: ... @overload def getptr( self, schema: s_schema.Schema, name: sn.UnqualName, *, type: Optional[type[s_pointers.Pointer]] = None, ) -> s_pointers.Pointer: ...
def getptr( self, schema: s_schema.Schema, name: sn.UnqualName, *, type: Optional[type[s_pointers.Pointer]] = None, ) -> s_pointers.Pointer: ptr = self.maybe_get_ptr(schema, name, type=type) if ptr is None: raise AssertionError( f'{self.get_verbosename(schema)} has no' f' link or property {str(name)!r}' ) return ptr def getrptrs( self, schema: s_schema.Schema, name: str, *, sources: Iterable[so.Object] = () ) -> set[links.Link]: return set() def add_pointer( self, schema: s_schema.Schema, pointer: s_pointers.Pointer, *, replace: bool = False ) -> s_schema.Schema: schema = self.add_classref( schema, 'pointers', pointer, replace=replace) return schema def get_addon_columns( self, schema: s_schema.Schema ) -> Sequence[tuple[str, str, tuple[str, str]]]: """ Returns a list of columns that are present in the backing table of this source, apart from the columns for pointers. """ res = [] from edb.common import debug if not debug.flags.zombodb: fts_index, _ = indexes.get_effective_object_index( schema, self, sn.QualName("std::fts", "index") ) if fts_index: res.append( ( '__fts_document__', '__fts_document__', ( 'pg_catalog', 'tsvector', ), ) ) ext_ai_index, _ = indexes.get_effective_object_index( schema, self, sn.QualName("ext::ai", "index") ) if ext_ai_index: idx_id = indexes.get_ai_index_id(schema, ext_ai_index) dimensions = ext_ai_index.must_get_json_annotation( schema, sn.QualName( "ext::ai", "embedding_dimensions"), int, ) res.append( ( f'__ext_ai_{idx_id}_embedding__', f'__ext_ai_{idx_id}_embedding__', ( 'edgedb', f'vector({dimensions})', ), ) ) return res def populate_pointer_set_for_source_union( schema: s_schema.Schema, components: list[Source], union: Source, *, modname: Optional[str] = None, ) -> s_schema.Schema: if modname is None: modname = '__derived__' union_pointers = {} for pn, ptr in components[0].get_pointers(schema).items(schema): ptrs = [ptr] for component in components[1:]: other_ptr = component.get_pointers(schema).get( schema, pn, None) if other_ptr 
is None: break ptrs.append(other_ptr) if len(ptrs) == len(components): # The pointer is present in all components. if len(ptrs) == 1: ptr = ptrs[0] else: try: schema, ptr = s_pointers.get_or_create_union_pointer( schema, ptrname=pn, source=union, direction=s_pointers.PointerDirection.Outbound, components=set(ptrs), modname=modname, ) except errors.SchemaError as e: # ptrs may have different verbose names # ensure the same one is always chosen vn = sorted(p.get_verbosename(schema) for p in ptrs)[0] e.args = ( (f'with {vn} {e.args[0]}',) + e.args[1:] ) raise e union_pointers[pn] = ptr if union_pointers: for pn, ptr in union_pointers.items(): if union.maybe_get_ptr(schema, pn) is None: schema = union.add_pointer(schema, ptr) return schema ================================================ FILE: edb/schema/std.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import pathlib from edb import lib as stdlib from edb import errors from edb.common import uuidgen from edb import schema from edb.schema import delta as sd from edb.schema import version as s_ver from edb.edgeql import compiler as qlcompiler from edb.edgeql import parser as qlparser from . import ddl as s_ddl from . import name as sn from . 
import schema as s_schema SCHEMA_ROOT = pathlib.Path(schema.__path__[0]) LIB_ROOT = pathlib.Path(stdlib.__path__[0]) QL_COMPILER_ROOT = pathlib.Path(qlcompiler.__path__[0]) QL_PARSER_ROOT = pathlib.Path(qlparser.__path__[0]) CACHE_SRC_DIRS = ( (SCHEMA_ROOT, '.py'), (QL_COMPILER_ROOT, '.py'), (QL_PARSER_ROOT, '.py'), (LIB_ROOT, '.edgeql'), ) def get_std_module_text(modname: sn.Name) -> str: module_eql = '' module_path = LIB_ROOT / str(modname) module_files = [] if module_path.is_dir(): for entry in module_path.iterdir(): if entry.is_file() and entry.suffix == '.edgeql': module_files.append(entry) else: module_path = module_path.with_suffix('.edgeql') if not module_path.exists(): raise errors.SchemaError(f'std module not found: {modname}') module_files.append(module_path) module_files.sort(key=lambda p: p.name) for module_file in module_files: with open(module_file) as f: module_eql += '\n' + f.read() return module_eql def load_std_module( schema: s_schema.Schema, modname: sn.Name, ) -> s_schema.Schema: return s_ddl.apply_ddl_script( get_std_module_text(modname), schema=schema, modaliases={}, stdmode=True, ) BASE_VERSION = uuidgen.UUID('013d1e23-51ce-11ee-a29d-e1f01853d332') GLOBAL_BASE_VERSION = uuidgen.UUID('013d235b-51ce-11ee-be76-bf15d10edfe5') def make_schema_version( schema: s_schema.Schema, ) -> tuple[s_schema.Schema, s_ver.CreateSchemaVersion]: context = sd.CommandContext(stdmode=True) sv = sn.UnqualName('__schema_version__') schema_version = s_ver.CreateSchemaVersion(classname=sv) schema_version.set_attribute_value('name', sv) schema_version.set_attribute_value('version', BASE_VERSION) schema_version.set_attribute_value('internal', True) schema = sd.apply(schema_version, schema=schema, context=context) return schema, schema_version def make_global_schema_version( schema: s_schema.Schema, ) -> tuple[s_schema.Schema, s_ver.CreateGlobalSchemaVersion]: context = sd.CommandContext(stdmode=True) sv = sn.UnqualName('__global_schema_version__') schema_version = 
s_ver.CreateGlobalSchemaVersion(classname=sv) schema_version.set_attribute_value('name', sv) schema_version.set_attribute_value('version', GLOBAL_BASE_VERSION) schema_version.set_attribute_value('internal', True) schema = sd.apply(schema_version, schema=schema, context=context) return schema, schema_version ================================================ FILE: edb/schema/triggers.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any, Optional, AbstractSet, TYPE_CHECKING from edb import errors from edb.edgeql import ast as qlast from edb.edgeql import compiler as qlcompiler from edb.edgeql import qltypes from . import annos as s_anno from . import delta as sd from . import expr as s_expr from . import name as sn from . import objects as so from . import referencing from . import schema as s_schema from . import sources as s_sources from . import types as s_types if TYPE_CHECKING: from . 
import objtypes as s_objtypes class Trigger( referencing.NamedReferencedInheritingObject, so.InheritingObject, # Help reflection figure out the right db MRO qlkind=qltypes.SchemaObjectClass.TRIGGER, data_safe=True, ): # XXX: compcoef is zero since we don't have syntax yet timing = so.SchemaField( qltypes.TriggerTiming, coerce=True, compcoef=0.0, special_ddl_syntax=True, ) kinds = so.SchemaField( so.MultiPropSet[qltypes.TriggerKind], coerce=True, compcoef=0.0, special_ddl_syntax=True, ) scope = so.SchemaField( qltypes.TriggerScope, coerce=True, compcoef=0.0, special_ddl_syntax=True, ) expr = so.SchemaField( s_expr.Expression, compcoef=0.909, special_ddl_syntax=True, ) condition = so.SchemaField( s_expr.Expression, default=None, coerce=True, compcoef=0.909, special_ddl_syntax=True, ) subject = so.SchemaField( so.InheritingObject, compcoef=None, inheritable=False) # We don't support SET/DROP OWNED owned on triggers so we set its # compcoef to 0.0 owned = so.SchemaField( bool, default=False, inheritable=False, compcoef=0.0, reflection_method=so.ReflectionMethod.AS_LINK, special_ddl_syntax=True, ) def get_subject(self, schema: s_schema.Schema) -> s_objtypes.ObjectType: subj: s_objtypes.ObjectType = self.get_field_value(schema, 'subject') return subj class TriggerCommandContext( sd.ObjectCommandContext[Trigger], s_anno.AnnotationSubjectCommandContext, ): pass class TriggerSourceCommandContext[Source_T: s_sources.Source]( s_sources.SourceCommandContext[Source_T] ): pass class TriggerCommand( referencing.NamedReferencedInheritingObjectCommand[Trigger], s_anno.AnnotationSubjectCommand[Trigger], context_class=TriggerCommandContext, referrer_context_class=TriggerSourceCommandContext, ): def canonicalize_attributes( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super().canonicalize_attributes(schema, context) parent_ctx = self.get_referrer_context_or_die(context) source = parent_ctx.op.scls trig_name = 
self.get_verbosename(parent=source.get_verbosename(schema)) for field in ('expr', 'condition'): if (expr := self.get_local_attribute_value(field)) is None: continue vname = 'when' if field == 'condition' else 'using' expression = self.compile_expr_field( schema, context, field=Trigger.get_field(field), value=expr, ) if field == 'condition': target = schema.get( sn.QualName('std', 'bool'), type=s_types.Type) expr_type = expression.irast.stype if not expr_type.issubclass(expression.irast.schema, target): span = self.get_attribute_span(field) raise errors.SchemaDefinitionError( f'{vname} expression for {trig_name} is of invalid ' f'type: ' f'{expr_type.get_displayname(expression.irast.schema)}' f', expected {target.get_displayname(schema)}', span=span, ) if expression.irast.dml_exprs: raise errors.SchemaDefinitionError( 'data-modifying statements are not allowed in trigger ' 'when clauses', span=expression.irast.dml_exprs[0].span, ) return schema def _get_scope( self, schema: s_schema.Schema, ) -> qltypes.TriggerScope: return self.get_attribute_value('scope') or self.scls.get_scope(schema) def _get_kinds( self, schema: s_schema.Schema, ) -> AbstractSet[qltypes.TriggerKind]: return self.get_attribute_value('kinds') or self.scls.get_kinds(schema) def compile_expr_field( self, schema: s_schema.Schema, context: sd.CommandContext, field: so.Field[Any], value: s_expr.Expression, track_schema_ref_exprs: bool=False, ) -> s_expr.CompiledExpression: if field.name in {'expr', 'condition'}: from edb.ir import pathid parent_ctx = self.get_referrer_context_or_die(context) source = parent_ctx.op.get_object(schema, context) assert isinstance(source, s_types.Type) # XXX: in_ddl_context_name is disabled for now because # it causes the compiler to reject DML; we might actually # want it for something, though, so we might need to # improve that restriction. 
# parent_vname = source.get_verbosename(schema) # pol_name = self.get_verbosename(parent=parent_vname) # in_ddl_context_name = pol_name scope = self._get_scope(schema) kinds = self._get_kinds(schema) anchors: dict[str, pathid.PathId] = {} if qltypes.TriggerKind.Insert not in kinds: anchors['__old__'] = pathid.PathId.from_type( schema, source, typename=sn.QualName(module='__derived__', name='__old__'), env=None, ) if qltypes.TriggerKind.Delete not in kinds: anchors['__new__'] = pathid.PathId.from_type( schema, source, typename=sn.QualName(module='__derived__', name='__new__'), env=None, ) singletons = ( frozenset(anchors.values()) if scope == qltypes.TriggerScope.Each else frozenset() ) assert isinstance(source, s_types.Type) try: return type(value).compiled( value, schema=schema, options=qlcompiler.CompilerOptions( modaliases=context.modaliases, schema_object_context=self.get_schema_metaclass(), anchors=anchors, singletons=singletons, apply_query_rewrites=not context.stdmode, track_schema_ref_exprs=track_schema_ref_exprs, # in_ddl_context_name=in_ddl_context_name, detached=True, trigger_type=source, trigger_kinds=kinds, ), context=context, ) except errors.QueryError as e: if not e.has_span(): e.set_span( self.get_attribute_span(field.name) ) raise else: return super().compile_expr_field( schema, context, field, value, track_schema_ref_exprs) def get_dummy_expr_field_value( self, schema: s_schema.Schema, context: sd.CommandContext, field: so.Field[Any], value: Any, ) -> Optional[s_expr.Expression]: if field.name in {'expr', 'condition'}: return s_expr.Expression(text='false') else: raise NotImplementedError(f'unhandled field {field.name!r}') def validate_object( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: # XXX: verify we don't have the same bug as access policies # where linkprop defaults are broken. 
# (I think we won't need to, since we'll operate after # the *real* operations) pass class CreateTrigger( TriggerCommand, referencing.CreateReferencedInheritingObject[Trigger], ): referenced_astnode = astnode = qlast.CreateTrigger def get_ast_attr_for_field( self, field: str, astnode: type[qlast.DDLOperation], ) -> Optional[str]: if ( field in ('timing', 'condition', 'kinds', 'scope', 'expr') and issubclass(astnode, qlast.CreateTrigger) ): return field else: return super().get_ast_attr_for_field(field, astnode) @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: cmd = super()._cmd_tree_from_ast(schema, astnode, context) assert isinstance(astnode, qlast.CreateTrigger) assert isinstance(cmd, TriggerCommand) if astnode.expr: cmd.set_attribute_value( 'expr', s_expr.Expression.from_ast( astnode.expr, schema, context.modaliases, context.localnames, ), span=astnode.expr.span, ) if astnode.condition is not None: cmd.set_attribute_value( 'condition', s_expr.Expression.from_ast( astnode.condition, schema, context.modaliases, context.localnames, ), span=astnode.condition.span, ) cmd.set_attribute_value('timing', astnode.timing) cmd.set_attribute_value('kinds', astnode.kinds) cmd.set_attribute_value('scope', astnode.scope) return cmd class RenameTrigger( TriggerCommand, referencing.RenameReferencedInheritingObject[Trigger], ): pass class RebaseTrigger( TriggerCommand, referencing.RebaseReferencedInheritingObject[Trigger], ): pass class AlterTrigger( TriggerCommand, referencing.AlterReferencedInheritingObject[Trigger], ): referenced_astnode = astnode = qlast.AlterTrigger def _alter_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._alter_begin(schema, context) # TODO: We may wish to support this in the future but it will # take some thought. 
if ( self.get_attribute_value('owned') and not self.get_orig_attribute_value('owned') ): raise errors.SchemaDefinitionError( f'cannot alter the definition of inherited trigger ' f'{self.scls.get_displayname(schema)}', span=self.span ) return schema class DeleteTrigger( TriggerCommand, referencing.DeleteReferencedInheritingObject[Trigger], ): referenced_astnode = astnode = qlast.DropTrigger ================================================ FILE: edb/schema/types.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import collections import collections.abc import enum import typing from typing import Self import uuid from edb import errors from edb.common import checked from edb.edgeql import ast as qlast from edb.edgeql import qltypes from edb.edgeql import compiler as qlcompiler from . import annos as s_anno from . import casts as s_casts from . import delta as sd from . import expr as s_expr from . import inheriting from . import name as s_name from . import objects as so from . import schema as s_schema from . 
import utils if typing.TYPE_CHECKING: from typing import Any, Iterable, Iterator, Mapping, Optional from typing import AbstractSet, Sequence, Callable from edb.common import parsing MAX_TYPE_DISTANCE = 1_000_000_000 class ExprType(enum.IntEnum): """Enumeration to identify the type of an expression in aliases.""" Select = enum.auto() Insert = enum.auto() Update = enum.auto() Delete = enum.auto() Group = enum.auto() def is_update(self) -> bool: return self == ExprType.Update def is_insert(self) -> bool: return self == ExprType.Insert def is_mutation(self) -> bool: return self != ExprType.Select and self != ExprType.Group TypeT_co = typing.TypeVar('TypeT_co', bound='Type', covariant=True) InheritingTypeT = typing.TypeVar('InheritingTypeT', bound='InheritingType') CollectionTypeT = typing.TypeVar('CollectionTypeT', bound='Collection') CollectionTypeT_co = typing.TypeVar( 'CollectionTypeT_co', bound='Collection', covariant=True) CollectionExprAliasT = typing.TypeVar( 'CollectionExprAliasT', bound='CollectionExprAlias' ) class Type( so.SubclassableObject, s_anno.AnnotationSubject, ): """A schema item that is a valid *type*.""" # If this type is an alias, expr will contain an expression that # defines it. expr = so.SchemaField( s_expr.Expression, default=None, coerce=True, compcoef=0.909) # For a type representing an expression alias, this would contain the # expression type. Non-alias types have None here. expr_type = so.SchemaField( ExprType, default=None, compcoef=0.909) # True for views. This should always match the value of # `bool(expr_type)`, but can be exported in the introspection # schema without revealing weird internals. from_alias = so.SchemaField( bool, default=False, # cannot alter type from being produced by an alias into an actual type compcoef=0.0, ) # True when from a global. The purpose of this is to ensure that # the types from globals and aliases can't be migrated between # each other. 
from_global = so.SchemaField( bool, default=False, compcoef=0.2) # True for aliases defined by CREATE ALIAS, false for local # aliases in queries. alias_is_persistent = so.SchemaField( bool, default=False, compcoef=None) # If this type is a view defined by a nested shape expression, # and the nested shape contains references to link properties, # rptr will contain the inbound pointer class. rptr = so.SchemaField( so.Object, weak_ref=True, default=None, compcoef=0.909) # The OID by which the backend refers to the type. backend_id = so.SchemaField( int, default=None, inheritable=False) # True for types that cannot be persistently stored. # See std::fts::document for an example. transient = so.SchemaField(bool, default=False) def compare( self, other: so.Object, *, our_schema: s_schema.Schema, their_schema: s_schema.Schema, context: so.ComparisonContext, ) -> float: # We need to be able to compare objects and scalars, in some places if ( isinstance(other, Type) and not isinstance(other, self.__class__) and not isinstance(self, other.__class__) ): return 0.0 return super().compare( other, our_schema=our_schema, their_schema=their_schema, context=context) def is_blocking_ref( self, schema: s_schema.Schema, reference: so.Object ) -> bool: return reference != self.get_rptr(schema) def derive_subtype( self: Self, schema: s_schema.Schema, *, name: s_name.QualName, mark_derived: bool = False, attrs: Optional[Mapping[str, Any]] = None, inheritance_merge: bool = True, transient: bool = False, preserve_endpoint_ptrs: bool = False, inheritance_refdicts: Optional[AbstractSet[str]] = None, stdmode: bool = False, **kwargs: Any, ) -> tuple[s_schema.Schema, Self]: if self.get_name(schema) == name: raise errors.SchemaError( f'cannot derive {self!r}({name}) from itself') derived_attrs: dict[str, object] = {} if attrs is not None: derived_attrs.update(attrs) derived_attrs['name'] = name derived_attrs['bases'] = so.ObjectList.create(schema, [self]) derived_attrs['from_alias'] = 
bool(derived_attrs.get('expr_type')) cmd = sd.get_object_delta_command( objtype=type(self), cmdtype=sd.CreateObject, schema=schema, name=name, ) for k, v in derived_attrs.items(): cmd.set_attribute_value(k, v) context = sd.CommandContext( modaliases={}, schema=schema, stdmode=stdmode, ) delta = sd.DeltaRoot() with context(sd.DeltaRootContext(schema=schema, op=delta)): if not inheritance_merge: context.current().inheritance_merge = False if inheritance_refdicts is not None: context.current().inheritance_refdicts = inheritance_refdicts if mark_derived: context.current().mark_derived = True if transient: context.current().transient_derivation = True if not preserve_endpoint_ptrs: context.current().slim_links = True delta.add(cmd) schema = delta.apply(schema, context) derived = typing.cast(Self, schema.get(name)) return schema, derived def is_object_type(self) -> bool: return False def is_free_object_type(self, schema: s_schema.Schema) -> bool: return False def is_union_type(self, schema: s_schema.Schema) -> bool: return False def is_intersection_type(self, schema: s_schema.Schema) -> bool: return False def is_compound_type(self, schema: s_schema.Schema) -> bool: return False def is_polymorphic(self, schema: s_schema.Schema) -> bool: return False def is_any(self, schema: s_schema.Schema) -> bool: return False def is_anytuple(self, schema: s_schema.Schema) -> bool: return False def is_anyobject(self, schema: s_schema.Schema) -> bool: return False def is_scalar(self) -> bool: return False def is_collection(self) -> bool: return False def is_array(self) -> bool: return False def is_json(self, schema: s_schema.Schema) -> bool: return False def is_tuple(self, schema: s_schema.Schema) -> bool: return False def is_range(self) -> bool: return False def is_multirange(self) -> bool: return False def is_enum(self, schema: s_schema.Schema) -> bool: return False def is_sequence(self, schema: s_schema.Schema) -> bool: return False def is_array_of_arrays(self, schema: 
s_schema.Schema) -> bool: return False def is_array_of_tuples(self, schema: s_schema.Schema) -> bool: return False def find_predicate( self, pred: Callable[[Type], bool], schema: s_schema.Schema, ) -> Optional[Type]: if pred(self): return self else: return None def contains_predicate( self, pred: Callable[[Type], bool], schema: s_schema.Schema, ) -> bool: return bool(self.find_predicate(pred, schema)) def find_generic(self, schema: s_schema.Schema) -> Optional[Type]: return self.find_predicate( lambda x: x.is_any(schema) or x.is_anyobject(schema), schema ) def contains_object(self, schema: s_schema.Schema) -> bool: return self.contains_predicate(lambda x: x.is_object_type(), schema) def contains_json(self, schema: s_schema.Schema) -> bool: return self.contains_predicate(lambda x: x.is_json(schema), schema) def find_array(self, schema: s_schema.Schema) -> Optional[Type]: return self.find_predicate(lambda x: x.is_array(), schema) def contains_array_of_array(self, schema: s_schema.Schema) -> bool: return self.contains_predicate( lambda x: x.is_array_of_arrays(schema), schema) def contains_array_of_tuples(self, schema: s_schema.Schema) -> bool: return self.contains_predicate( lambda x: x.is_array_of_tuples(schema), schema) def test_polymorphic(self, schema: s_schema.Schema, poly: Type) -> bool: """Check if this type can be matched by a polymorphic type. Examples: - `array`.test_polymorphic(`array`) -> True - `array`.test_polymorphic(`array`) -> True - `array`.test_polymorphic(`anyscalar`) -> False - `float32`.test_polymorphic(`anyint`) -> False - `int32`.test_polymorphic(`anyint`) -> True """ if not poly.is_polymorphic(schema): raise TypeError('expected a polymorphic type as a second argument') if poly.is_any(schema): return True if poly.is_anyobject(schema) and self.is_object_type(): return True return self._test_polymorphic(schema, poly) def resolve_polymorphic( self, schema: s_schema.Schema, other: Type ) -> Optional[Type]: """Resolve the polymorphic type component. 
Examples: - `array`.resolve_polymorphic(`array`) -> `int` - `array`.resolve_polymorphic(`tuple`) -> None """ if not self.is_polymorphic(schema): return None return self._resolve_polymorphic(schema, other) def to_nonpolymorphic( self: Self, schema: s_schema.Schema, concrete_type: Type ) -> tuple[s_schema.Schema, Type]: """Produce a non-polymorphic version of self. Example: `array`.to_nonpolymorphic(`int`) -> `array` `tuple`.to_nonpolymorphic(`str`) -> `tuple` """ if not self.is_polymorphic(schema): raise TypeError('non-polymorphic type') return self._to_nonpolymorphic(schema, concrete_type) def _test_polymorphic(self, schema: s_schema.Schema, other: Type) -> bool: return False def _resolve_polymorphic( self, schema: s_schema.Schema, concrete_type: Type, ) -> Optional[Type]: raise NotImplementedError( f'{type(self)} does not support resolve_polymorphic()') def _to_nonpolymorphic( self: Self, schema: s_schema.Schema, concrete_type: Type, ) -> tuple[s_schema.Schema, Type]: raise NotImplementedError( f'{type(self)} does not support to_nonpolymorphic()') def is_view(self, schema: s_schema.Schema) -> bool: return self.get_from_alias(schema) def castable_to( self, other: Type, schema: s_schema.Schema, ) -> bool: if self.implicitly_castable_to(other, schema): return True elif self.assignment_castable_to(other, schema): return True else: return False def assignment_castable_to( self, other: Type, schema: s_schema.Schema ) -> bool: return self.implicitly_castable_to(other, schema) def implicitly_castable_to( self, other: Type, schema: s_schema.Schema, ) -> bool: return False def get_implicit_cast_distance( self, other: Type, schema: s_schema.Schema ) -> int: return -1 def find_common_implicitly_castable_type( self, other: Type, schema: s_schema.Schema, ) -> tuple[s_schema.Schema, Optional[Type]]: return schema, None def get_union_of( self: Self, schema: s_schema.Schema, ) -> Optional[so.ObjectSet[Self]]: return None def get_is_opaque_union(self, schema: s_schema.Schema) ->
bool: return False def get_intersection_of( self: Self, schema: s_schema.Schema, ) -> Optional[so.ObjectSet[Self]]: return None def material_type( self: Self, schema: s_schema.Schema ) -> tuple[s_schema.Schema, Self]: return schema, self def peel_view(self, schema: s_schema.Schema) -> Type: return self def get_common_parent_type_distance( self, other: Type, schema: s_schema.Schema, ) -> int: raise NotImplementedError def allow_ref_propagation( self, schema: s_schema.Schema, context: sd.CommandContext, refdict: so.RefDict, ) -> bool: return not self.is_view(schema) def as_shell( self: Self, schema: s_schema.Schema, ) -> TypeShell[Self]: name = typing.cast(s_name.QualName, self.get_name(schema)) if ( (union_of := self.get_union_of(schema)) and not self.is_view(schema) ): assert isinstance(self, so.QualifiedObject) return UnionTypeShell( module=name.module, components=[ o.as_shell(schema) for o in union_of.objects(schema) ], opaque=self.get_is_opaque_union(schema), schemaclass=type(self), ) elif ( (intersection_of := self.get_intersection_of(schema)) and not self.is_view(schema) ): assert isinstance(self, so.QualifiedObject) return IntersectionTypeShell( module=name.module, components=[ o.as_shell(schema) for o in intersection_of.objects(schema) ], schemaclass=type(self), ) else: return TypeShell( name=name, schemaclass=type(self), ) def record_cmd_object_aux_data( self, schema: s_schema.Schema, cmd: sd.ObjectCommand[Type], ) -> None: super().record_cmd_object_aux_data(schema, cmd) if self.is_compound_type(schema): cmd.set_object_aux_data('is_compound_type', True) def as_type_delete_if_unused( self: Self, schema: s_schema.Schema, ) -> Optional[sd.DeleteObject[Self]]: """If this type is owned by other objects, delete it if unused. For types that get created behind the scenes as part of another object, such as collection types and union types, this should generate an appropriate deletion. Otherwise, it should return None.
""" return None def _is_deletable( self, schema: s_schema.Schema, ) -> bool: # this type was already deleted by some other op # (probably alias types cleanup) return schema.get_by_id(self.id, default=None) is not None class QualifiedType(so.QualifiedObject, Type): pass class InheritingType(so.DerivableInheritingObject, QualifiedType): def material_type[ InheritingTypeT: InheritingType, Schema_T: s_schema.Schema ]( self: InheritingTypeT, schema: Schema_T, ) -> tuple[Schema_T, InheritingTypeT]: return schema, self.get_nearest_non_derived_parent(schema) def peel_view(self, schema: s_schema.Schema) -> Type: # When self is a view, this returns the class the view # is derived from (which may be another view). If no # parent class is available, returns self. if self.is_view(schema): return typing.cast(Type, self.get_bases(schema).first(schema)) else: return self def get_common_parent_type_distance( self, other: Type, schema: s_schema.Schema, ) -> int: if other.is_any(schema) or self.is_any(schema): return MAX_TYPE_DISTANCE if not isinstance(other, type(self)): return -1 if self == other: return 0 ancestors = utils.get_class_nearest_common_ancestors( schema, [self, other]) if not ancestors: return -1 elif self in ancestors: return 0 else: all_ancestors = list(self.get_ancestors(schema).objects(schema)) return min( all_ancestors.index(ancestor) + 1 for ancestor in ancestors) class TypeShell(so.ObjectShell[TypeT_co]): schemaclass: type[TypeT_co] extra_args: tuple[qlast.Expr | qlast.TypeExpr, ...] 
| None def __init__( self, *, name: s_name.Name, origname: Optional[s_name.Name] = None, displayname: Optional[str] = None, expr: Optional[str] = None, schemaclass: type[TypeT_co], span: Optional[parsing.Span] = None, extra_args: tuple[qlast.Expr] | None = None, ) -> None: super().__init__( name=name, origname=origname, displayname=displayname, schemaclass=schemaclass, span=span, ) self.expr = expr self.extra_args = extra_args def is_polymorphic(self, schema: s_schema.Schema) -> bool: return self.resolve(schema).is_polymorphic(schema) def as_create_delta( self, schema: s_schema.Schema, *, view_name: Optional[s_name.QualName] = None, attrs: Optional[dict[str, Any]] = None, ) -> sd.Command: raise NotImplementedError('unsupported typeshell') def has_intersection(self) -> bool: return False class TypeExprShell(TypeShell[TypeT_co]): components: tuple[TypeShell[TypeT_co], ...] def __init__( self, *, name: s_name.Name, components: Iterable[TypeShell[TypeT_co]], schemaclass: type[TypeT_co], span: Optional[parsing.Span] = None, ) -> None: super().__init__( name=name, schemaclass=schemaclass, span=span, ) self.components = tuple(components) def resolve_components( self, schema: s_schema.Schema, ) -> tuple[TypeT_co, ...]: return tuple(c.resolve(schema) for c in self.components) def get_components( self, schema: s_schema.Schema, ) -> tuple[TypeShell[TypeT_co], ...]: return self.components def has_intersection(self) -> bool: return any( c.has_intersection() for c in self.components ) class UnionTypeShell(TypeExprShell[TypeT_co]): def __init__( self, *, module: str, components: Iterable[TypeShell[TypeT_co]], opaque: bool = False, schemaclass: type[TypeT_co], span: Optional[parsing.Span] = None, ) -> None: name = get_union_type_name( (c.name for c in components), opaque=opaque, module=module, ) super().__init__( name=name, components=components, schemaclass=schemaclass, span=span, ) self.opaque = opaque def as_create_delta( self, schema: s_schema.Schema, *, view_name: 
Optional[s_name.QualName] = None, attrs: Optional[dict[str, Any]] = None, ) -> sd.Command: assert isinstance(self.name, s_name.QualName) cmd = CreateUnionType(classname=self.name) for component in self.components: if isinstance(component, TypeExprShell): cmd.add_prerequisite(component.as_create_delta(schema)) cmd.set_attribute_value('name', self.name) cmd.set_attribute_value('components', tuple(self.components)) cmd.set_attribute_value('is_opaque_union', self.opaque) cmd.set_attribute_value('span', self.span) return cmd def __repr__(self) -> str: dn = 'UnionType' comps = ' | '.join(repr(c) for c in self.components) return f'<{type(self).__name__} {dn}({comps}) at 0x{id(self):x}>' class AlterType[TypeT: Type](sd.AlterObject[TypeT]): def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: if hasattr(self, 'scls') and self.scls.get_from_alias(schema): # This is a nested view type, e.g. # __FooAlias_bar produced by FooAlias := (SELECT Foo { bar: ... }), # and should not appear as a top-level definition. return None else: return super()._get_ast(schema, context, parent_node=parent_node) class RenameType[TypeT: Type](sd.RenameObject[TypeT]): def _canonicalize( self, schema: s_schema.Schema, context: sd.CommandContext, scls: TypeT, ) -> None: super()._canonicalize(schema, context, scls) # Now, see if there are any compound or collection types using # this type as a component. We must rename them, as they derive # their names from the names of their component types. # We must be careful about the order in which we consider the # referrers, because they may reference each other as well, and # so we must proceed with renames starting from the simplest type.
referrers = collections.defaultdict(set) referrer_map = schema.get_referrers_ex(scls, scls_type=Type) for (_, field_name), objs in referrer_map.items(): for obj in objs: referrers[obj].add(field_name) ref_order = sd.sort_by_cross_refs(schema, referrers) for ref_type in ref_order: field_names = referrers[ref_type] for field_name in field_names: if field_name == 'union_of' or field_name == 'intersection_of': orig_ref_type_name = ref_type.get_name(schema) assert isinstance(orig_ref_type_name, s_name.QualName) components = ref_type.get_field_value( schema, field_name) assert components is not None component_names = set(components.names(schema)) component_names.discard(self.classname) component_names.add(self.new_name) if field_name == 'union_of': new_ref_type_name = get_union_type_name( component_names, module=orig_ref_type_name.module, opaque=ref_type.get_is_opaque_union(schema), ) else: new_ref_type_name = get_intersection_type_name( component_names, module=orig_ref_type_name.module, ) self.add(self.init_rename_branch( ref_type, new_ref_type_name, schema=schema, context=context, )) elif ( isinstance(ref_type, Tuple) and field_name == 'element_types' ): subtypes = { k: st.get_name(schema) for k, st in ( ref_type.get_element_types(schema).items(schema) ) } new_tup_type_name = Tuple.generate_name( subtypes, named=ref_type.is_named(schema), ) self.add(self.init_rename_branch( ref_type, new_tup_type_name, schema=schema, context=context, )) elif ( isinstance(ref_type, Array) and field_name == 'element_type' ): new_arr_type_name = Array.generate_name( ref_type.get_element_type(schema).get_name(schema) ) self.add(self.init_rename_branch( ref_type, new_arr_type_name, schema=schema, context=context, )) def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: if ( self.maybe_get_object_aux_data('is_compound_type') or self.scls.is_view(schema) ): return None else: return 
super()._get_ast(schema, context, parent_node=parent_node) class DeleteType[TypeT: Type](sd.DeleteObject[TypeT]): def _get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: if self.maybe_get_object_aux_data('is_compound_type'): return None else: return super()._get_ast(schema, context, parent_node=parent_node) class RenameInheritingType( RenameType[InheritingTypeT], inheriting.RenameInheritingObject[InheritingTypeT], ): pass class DeleteInheritingType( DeleteType[InheritingTypeT], inheriting.DeleteInheritingObject[InheritingTypeT], ): pass class CompoundTypeCommandContext(sd.ObjectCommandContext[InheritingType]): pass class CompoundTypeCommand( sd.QualifiedObjectCommand[InheritingType], context_class=CompoundTypeCommandContext, ): pass class CreateUnionType(sd.CreateObject[InheritingType], CompoundTypeCommand): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: from edb.schema import types as s_types for cmd in self.get_prerequisites(): schema = cmd.apply(schema, context) if not context.canonical: components: Sequence[s_types.Type] = [ c.resolve(schema) for c in self.get_attribute_value('components') ] try: new_schema, union_type, created = utils.ensure_union_type( schema, components, opaque=self.get_attribute_value('is_opaque_union') or False, module=self.classname.module, ) except errors.SchemaError as e: union_name = ( '(' + ' | '.join(sorted( c.get_displayname(schema) for c in components )) + ')' ) e.args = ( (f'cannot create union {union_name} {e.args[0]}',) + e.args[1:] ) e.set_span(self.get_attribute_value('span')) raise e if created: delta = union_type.as_create_delta( schema=new_schema, context=so.ComparisonContext(), ) self.add(delta) for cmd in self.get_subcommands(include_prerequisites=False): schema = cmd.apply(schema, context) return schema class CreateIntersectionType( sd.CreateObject[InheritingType], 
CompoundTypeCommand ): def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: from edb.schema import types as s_types for cmd in self.get_prerequisites(): schema = cmd.apply(schema, context) if not context.canonical: components: Sequence[s_types.Type] = [ c.resolve(schema) for c in self.get_attribute_value('components') ] try: new_schema, intersection_type, created = ( utils.ensure_intersection_type( schema, components, module=self.classname.module, ) ) except errors.SchemaError as e: # Intersections are spelled with '&' ('|' denotes a union). intersection_name = ( '(' + ' & '.join(sorted( c.get_displayname(schema) for c in components )) + ')' ) e.args = ( ( f'cannot create intersection ' f'{intersection_name} {e.args[0]}', ) + e.args[1:] ) e.set_span(self.get_attribute_value('span')) raise e if created: delta = intersection_type.as_create_delta( schema=new_schema, context=so.ComparisonContext(), ) self.add(delta) for cmd in self.get_subcommands(include_prerequisites=False): schema = cmd.apply(schema, context) return schema class IntersectionTypeShell(TypeExprShell[TypeT_co]): def __init__( self, *, module: str, components: Iterable[TypeShell[TypeT_co]], schemaclass: type[TypeT_co], span: parsing.Span | None = None, ) -> None: name = get_intersection_type_name( (c.name for c in components), module=module, ) super().__init__( name=name, components=components, schemaclass=schemaclass, span=span ) def as_create_delta( self, schema: s_schema.Schema, *, view_name: Optional[s_name.QualName] = None, attrs: Optional[dict[str, Any]] = None, ) -> sd.Command: assert isinstance(self.name, s_name.QualName) cmd = CreateIntersectionType(classname=self.name) for component in self.components: if isinstance(component, TypeExprShell): cmd.add_prerequisite(component.as_create_delta(schema)) cmd.set_attribute_value('name', self.name) cmd.set_attribute_value('components', tuple(self.components)) cmd.set_attribute_value('span', self.span) return cmd def has_intersection(self) -> bool: return True
_collection_impls: dict[str, type[Collection]] = {} class Collection(Type): _schema_name: typing.ClassVar[typing.Optional[str]] = None #: True for collection types that are stored in schema persistently is_persistent = so.SchemaField( bool, default=False, compcoef=None, ) def __init_subclass__( cls, *, schema_name: typing.Optional[str] = None, ) -> None: super().__init_subclass__() if schema_name is not None: if existing := _collection_impls.get(schema_name): raise TypeError( f"{schema_name} is already implemented by {existing}") _collection_impls[schema_name] = cls cls._schema_name = schema_name def as_create_delta( self: CollectionTypeT, schema: s_schema.Schema, context: so.ComparisonContext, ) -> sd.CreateObject[CollectionTypeT]: delta = super().as_create_delta(schema=schema, context=context) assert isinstance(delta, sd.CreateObject) if not isinstance(self, CollectionExprAlias): delta.if_not_exists = True return delta def as_delete_delta( self: CollectionTypeT, *, schema: s_schema.Schema, context: so.ComparisonContext, ) -> sd.ObjectCommand[CollectionTypeT]: delta = super().as_delete_delta(schema=schema, context=context) assert isinstance(delta, sd.DeleteObject) if not isinstance(self, CollectionExprAlias): delta.if_exists = True if not ( isinstance(self, Array) and self.get_element_type(schema).is_scalar() ): # Arrays of scalars are special, because we create them # implicitly and overload reference checks to never # delete them unless the scalar is also deleted. delta.if_unused = True return delta @classmethod def get_displayname_static(cls, name: s_name.Name) -> str: if isinstance(name, s_name.QualName): # FIXME: Globals and alias names do mangling but *don't* # duplicate the module name, which most places do. 
return str(name).split('@', 1)[0] else: return s_name.recursively_unmangle_shortname(str(name)) @classmethod def get_schema_name(cls) -> str: if cls._schema_name is None: raise TypeError( f"{cls.get_schema_class_displayname()} is not " f"a concrete collection type" ) return cls._schema_name def get_generated_name(self, schema: s_schema.Schema) -> s_name.UnqualName: """Return collection type name generated from element types. Unlike get_name(), which might return a custom name, this will always return a name derived from the names of the collection element type(s). """ raise NotImplementedError def is_polymorphic(self, schema: s_schema.Schema) -> bool: return any(st.is_polymorphic(schema) for st in self.get_subtypes(schema)) def find_predicate( self, pred: Callable[[Type], bool], schema: s_schema.Schema, ) -> Optional[Type]: if pred(self): return self for st in self.get_subtypes(schema): res = st.find_predicate(pred, schema) if res is not None: return res return None def is_collection(self) -> bool: return True def get_common_parent_type_distance( self, other: Type, schema: s_schema.Schema ) -> int: if other.is_any(schema): return 1 if other.__class__ is not self.__class__: return -1 other = typing.cast(Collection, other) other_types = other.get_subtypes(schema) my_types = self.get_subtypes(schema) type_dist = 0 for ot, my in zip(other_types, my_types): el_dist = my.get_common_parent_type_distance(ot, schema) if el_dist < 0: return -1 else: type_dist += el_dist return type_dist def _issubclass( self, schema: s_schema.Schema, parent: so.SubclassableObject ) -> bool: if isinstance(parent, Type) and parent.is_any(schema): return True if isinstance(parent, Type) and parent.is_anyobject(schema): if isinstance(self, Type) and self.is_object_type(): return True if parent.__class__ is not self.__class__: return False # The cast below should not be necessary but Mypy does not believe # that a.__class__ == b.__class__ is enough. 
parent_types = typing.cast(Collection, parent).get_subtypes(schema) my_types = self.get_subtypes(schema) for pt, my in zip(parent_types, my_types): if not pt.is_any(schema) and not my.issubclass(schema, pt): return False return True def issubclass( self, schema: s_schema.Schema, parent: so.SubclassableObject | tuple[so.SubclassableObject, ...], ) -> bool: if isinstance(parent, tuple): return any(self.issubclass(schema, p) for p in parent) if isinstance(parent, Type) and parent.is_any(schema): return True return self._issubclass(schema, parent) def get_subtypes(self, schema: s_schema.Schema) -> tuple[Type, ...]: raise NotImplementedError def get_typemods(self, schema: s_schema.Schema) -> Any: return () @classmethod def get_class(cls, schema_name: str) -> type[Collection]: coll_type = _collection_impls.get(schema_name) if coll_type: return coll_type else: raise errors.SchemaError( 'unknown collection type: {!r}'.format(schema_name)) @classmethod def from_subtypes( cls, schema: s_schema.Schema, subtypes: Any, typemods: Any = None, ) -> tuple[s_schema.Schema, Collection]: raise NotImplementedError def __repr__(self) -> str: return ( f'<{self.__class__.__name__} ' f'{self.id} at 0x{id(self):x}>' ) def dump(self, schema: s_schema.Schema) -> str: return repr(self) # We define this specifically to override children @classmethod def get_schema_class_displayname(cls) -> str: return 'collection' def as_type_delete_if_unused( self: CollectionTypeT, schema: s_schema.Schema, ) -> Optional[sd.DeleteObject[CollectionTypeT]]: if not self._is_deletable(schema): return None return self.init_delta_command( schema, sd.DeleteObject, if_unused=True, if_exists=True, ) Dimensions = checked.FrozenCheckedList[int] Array_T = typing.TypeVar("Array_T", bound="Array") Array_T_co = typing.TypeVar("Array_T_co", bound="Array", covariant=True) class CollectionTypeShell(TypeShell[CollectionTypeT_co]): def get_subtypes( self, schema: s_schema.Schema, ) -> tuple[TypeShell[Type], ...]: raise 
NotImplementedError def is_polymorphic(self, schema: s_schema.Schema) -> bool: return any( st.is_polymorphic(schema) for st in self.get_subtypes(schema) ) class CollectionExprAlias(QualifiedType, Collection): @classmethod def get_schema_class_displayname(cls) -> str: return 'expression alias' @classmethod def get_underlying_schema_class(cls) -> type[Collection]: """Return the concrete collection class for this ExprAlias class.""" raise NotImplementedError def as_underlying_type_delete_if_unused( self, schema: s_schema.Schema, ) -> sd.DeleteObject[Type]: """Return a conditional deletion command for the underlying type object """ return sd.get_object_delta_command( objtype=type(self).get_underlying_schema_class(), cmdtype=sd.DeleteObject, schema=schema, name=self.get_generated_name(schema), if_unused=True, if_exists=True, ) def as_type_delete_if_unused( self: CollectionExprAliasT, schema: s_schema.Schema, ) -> Optional[sd.DeleteObject[CollectionExprAliasT]]: if not self._is_deletable(schema): return None cmd = self.init_delta_command(schema, sd.DeleteObject, if_exists=True) cmd.add_prerequisite(self.as_underlying_type_delete_if_unused(schema)) return cmd class Array( Collection, qlkind=qltypes.SchemaObjectClass.ARRAY_TYPE, schema_name='array', ): element_type = so.SchemaField( Type, # We want a low compcoef so that array types are *never* altered. compcoef=0, ) dimensions = so.SchemaField( Dimensions, coerce=True, # We want a low compcoef so that array types are *never* altered. 
compcoef=0, ) @classmethod def generate_name( cls, element_name: s_name.Name, ) -> s_name.UnqualName: return s_name.UnqualName( f'array<{s_name.mangle_name(str(element_name))}>', ) @classmethod def create( cls: type[Array_T], schema: s_schema.Schema, *, name: Optional[s_name.Name] = None, id: Optional[uuid.UUID] = None, dimensions: Sequence[int] = (), element_type: Any, **kwargs: Any, ) -> tuple[s_schema.Schema, Array_T]: if not dimensions: dimensions = [-1] if dimensions != [-1]: raise errors.UnsupportedFeatureError( f'multi-dimensional arrays are not supported') if name is None: name = cls.generate_name(element_type.get_name(schema)) if isinstance(name, s_name.QualName): result = schema.get(name, type=cls, default=None) else: result = schema.get_global(cls, name, default=None) if result is None: schema, result = super().create_in_schema( schema, id=id, name=name, element_type=element_type, dimensions=dimensions, **kwargs, ) # Compute material type so that we can retrieve it safely later schema, _ = result.material_type(schema) return schema, result def get_generated_name(self, schema: s_schema.Schema) -> s_name.UnqualName: return type(self).generate_name( self.get_element_type(schema).get_name(schema), ) def is_array_of_arrays(self, schema: s_schema.Schema) -> bool: return self.get_element_type(schema).is_array() def is_array_of_tuples(self, schema: s_schema.Schema) -> bool: return self.get_element_type(schema).is_tuple(schema) def get_displayname(self, schema: s_schema.Schema) -> str: return ( f'array<{self.get_element_type(schema).get_displayname(schema)}>') def is_array(self) -> bool: return True def derive_subtype( self, schema: s_schema.Schema, *, name: s_name.QualName, attrs: Optional[Mapping[str, Any]] = None, **kwargs: Any, ) -> tuple[s_schema.Schema, ArrayExprAlias]: assert not kwargs return ArrayExprAlias.from_subtypes( schema, [self.get_element_type(schema)], self.get_typemods(schema), name=name, **(attrs or {}), ) def get_subtypes(self, schema: 
s_schema.Schema) -> tuple[Type, ...]: return (self.get_element_type(schema),) def get_typemods(self, schema: s_schema.Schema) -> tuple[Any, ...]: return (self.get_dimensions(schema),) def implicitly_castable_to( self, other: Type, schema: s_schema.Schema ) -> bool: if not isinstance(other, Array): return False return self.get_element_type(schema).implicitly_castable_to( other.get_element_type(schema), schema) def get_implicit_cast_distance( self, other: Type, schema: s_schema.Schema ) -> int: if not isinstance(other, Array): return -1 return self.get_element_type(schema).get_implicit_cast_distance( other.get_element_type(schema), schema) def assignment_castable_to( self, other: Type, schema: s_schema.Schema, ) -> bool: if not isinstance(other, Array): from . import scalars as s_scalars if not isinstance(other, s_scalars.ScalarType): return False if other.is_polymorphic(schema): return False right = other.get_base_for_cast(schema) assert isinstance(right, Type) return s_casts.is_assignment_castable(schema, self, right) return self.get_element_type(schema).assignment_castable_to( other.get_element_type(schema), schema) def castable_to( self, other: Type, schema: s_schema.Schema, ) -> bool: if not isinstance(other, Array): from . 
import scalars as s_scalars if not isinstance(other, s_scalars.ScalarType): return False if other.is_polymorphic(schema): return False right = other.get_base_for_cast(schema) assert isinstance(right, Type) return s_casts.is_castable(schema, self, right) return self.get_element_type(schema).castable_to( other.get_element_type(schema), schema) def find_common_implicitly_castable_type( self, other: Type, schema: s_schema.Schema, ) -> tuple[s_schema.Schema, Optional[Array]]: if not isinstance(other, Array): return schema, None if self == other: return schema, self my_el = self.get_element_type(schema) schema, subtype = my_el.find_common_implicitly_castable_type( other.get_element_type(schema), schema) if subtype is None: return schema, None return Array.from_subtypes(schema, [subtype]) def _resolve_polymorphic( self, schema: s_schema.Schema, concrete_type: Type, ) -> Optional[Type]: if not isinstance(concrete_type, Array): return None return self.get_element_type(schema).resolve_polymorphic( schema, concrete_type.get_element_type(schema)) def _to_nonpolymorphic( self, schema: s_schema.Schema, concrete_type: Type, ) -> tuple[s_schema.Schema, Array]: st = self.get_subtypes(schema=schema)[0] # TODO: maybe we should have a generic nested polymorphic algo?
if isinstance(st, (Range, MultiRange)): schema, newst = st.to_nonpolymorphic(schema, concrete_type) else: newst = concrete_type return Array.from_subtypes(schema, (newst,)) def _test_polymorphic(self, schema: s_schema.Schema, other: Type) -> bool: if other.is_any(schema): return True if not isinstance(other, Array): return False return self.get_element_type(schema).test_polymorphic( schema, other.get_element_type(schema)) @classmethod def from_subtypes( cls: type[Array_T], schema: s_schema.Schema, subtypes: Sequence[Type], typemods: Any = None, *, name: Optional[s_name.QualName] = None, **kwargs: Any, ) -> tuple[s_schema.Schema, Array_T]: if len(subtypes) != 1: raise errors.SchemaError( f'unexpected number of subtypes, expecting 1: {subtypes!r}') stype = subtypes[0] # One-dimensional unbounded array. dimensions = [-1] schema, ty = cls.create( schema, element_type=stype, dimensions=dimensions, name=name, **kwargs, ) return schema, ty @classmethod def create_shell( cls: type[Self], schema: s_schema.Schema, *, subtypes: Sequence[TypeShell[Type]], typemods: Any = None, name: Optional[s_name.Name] = None, expr: Optional[str] = None, ) -> ArrayTypeShell[Self]: if not typemods: typemods = ([-1],) st = next(iter(subtypes)) return ArrayTypeShell( subtype=st, typemods=typemods, name=name, expr=expr, schemaclass=cls, ) def as_shell( self: Self, schema: s_schema.Schema, ) -> ArrayTypeShell[Self]: expr = self.get_expr(schema) expr_text = expr.text if expr is not None else None return type(self).create_shell( schema, subtypes=[st.as_shell(schema) for st in self.get_subtypes(schema)], typemods=self.get_typemods(schema), name=self.get_name(schema), expr=expr_text, ) def material_type( self, schema: s_schema.Schema, ) -> tuple[s_schema.Schema, Array]: # We need to resolve material types based on the subtype recursively. 
st = self.get_element_type(schema) schema, stm = st.material_type(schema) if stm != st or isinstance(self, ArrayExprAlias): return Array.from_subtypes( schema, [stm], typemods=self.get_typemods(schema), ) else: return (schema, self) class ArrayTypeShell(CollectionTypeShell[Array_T_co]): schemaclass: type[Array_T_co] def __init__( self, *, name: Optional[s_name.Name], expr: Optional[str] = None, subtype: TypeShell[Type], typemods: tuple[typing.Any, ...], schemaclass: type[Array_T_co], ) -> None: if name is None: name = schemaclass.generate_name(subtype.name) super().__init__(name=name, schemaclass=schemaclass, expr=expr) self.subtype = subtype self.typemods = typemods def get_subtypes( self, schema: s_schema.Schema, ) -> tuple[TypeShell[Type], ...]: return (self.subtype,) def get_displayname(self, schema: s_schema.Schema) -> str: return f'array<{self.subtype.get_displayname(schema)}>' def as_create_delta( self, schema: s_schema.Schema, *, view_name: Optional[s_name.QualName] = None, attrs: Optional[dict[str, Any]] = None, ) -> sd.CommandGroup: ca: CreateArray | CreateArrayExprAlias cmd = sd.CommandGroup() if view_name is None: ca = CreateArray( classname=self.get_name(schema), if_not_exists=True, ) else: ca = CreateArrayExprAlias( classname=view_name, ) el = self.subtype if isinstance(el, CollectionTypeShell): cmd.add(el.as_create_delta(schema)) ca.set_attribute_value('name', ca.classname) ca.set_attribute_value('element_type', el) ca.set_attribute_value('is_persistent', True) ca.set_attribute_value('abstract', self.is_polymorphic(schema)) ca.set_attribute_value('dimensions', self.typemods[0]) if attrs: for k, v in attrs.items(): ca.set_attribute_value(k, v) cmd.add(ca) return cmd class ArrayExprAlias( CollectionExprAlias, Array, qlkind=qltypes.SchemaObjectClass.ALIAS, ): # N.B: Don't add any SchemaFields to this class, they won't be # reflected properly (since this inherits from the concrete Array). 
@classmethod def get_underlying_schema_class(cls) -> type[Collection]: return Array Tuple_T = typing.TypeVar('Tuple_T', bound='Tuple') Tuple_T_co = typing.TypeVar('Tuple_T_co', bound='Tuple', covariant=True) class Tuple( Collection, qlkind=qltypes.SchemaObjectClass.TUPLE_TYPE, schema_name='tuple', ): named = so.SchemaField( bool, # We want a low compcoef so that tuples are *never* altered. compcoef=0.01, ) element_types = so.SchemaField( so.ObjectDict[str, Type], coerce=True, # We want a low compcoef so that tuples are *never* altered. compcoef=0.01, # Tuple element types cannot be represented by a direct link, # because the element types may be duplicate, so we need a # proxy object. reflection_proxy=('schema::TupleElement', 'type'), ) @classmethod def generate_name( cls, element_names: Mapping[str, s_name.Name], named: bool = False, ) -> s_name.UnqualName: if named: st_names = ', '.join( f'{n}:{st}' for n, st in element_names.items() ) else: st_names = ', '.join(str(st) for st in element_names.values()) return s_name.UnqualName(f'tuple<{s_name.mangle_name(st_names)}>') @classmethod def create( cls: type[Tuple_T], schema: s_schema.Schema, *, name: Optional[s_name.Name] = None, id: Optional[uuid.UUID] = None, element_types: Mapping[str, Type], named: bool = False, **kwargs: Any, ) -> tuple[s_schema.Schema, Tuple_T]: el_types = so.ObjectDict[str, Type].create(schema, element_types) if name is None: name = cls.generate_name( {n: el.get_name(schema) for n, el in element_types.items()}, named, ) if isinstance(name, s_name.QualName): result = schema.get(name, type=cls, default=None) else: result = schema.get_global(cls, name, default=None) if result is None: schema, result = super().create_in_schema( schema, id=id, name=name, named=named, element_types=el_types, **kwargs, ) # Compute material type so that we can retrieve it safely later schema, _ = result.material_type(schema) return schema, result def get_generated_name(self, schema: s_schema.Schema) -> 
s_name.UnqualName: els = {n: st.get_name(schema) for n, st in self.iter_subtypes(schema)} return type(self).generate_name(els, self.is_named(schema)) def get_displayname(self, schema: s_schema.Schema) -> str: if self.is_named(schema): st_names = ', '.join( f'{name}: {st.get_displayname(schema)}' for name, st in self.get_element_types(schema).items(schema) ) else: st_names = ', '.join(st.get_displayname(schema) for st in self.get_subtypes(schema)) return f'tuple<{st_names}>' def is_tuple(self, schema: s_schema.Schema) -> bool: return True def is_named(self, schema: s_schema.Schema) -> bool: return self.get_named(schema) def get_element_names(self, schema: s_schema.Schema) -> Sequence[str]: return tuple(self.get_element_types(schema).keys(schema)) def iter_subtypes( self, schema: s_schema.Schema ) -> Iterator[tuple[str, Type]]: yield from self.get_element_types(schema).items(schema) def get_subtypes(self, schema: s_schema.Schema) -> tuple[Type, ...]: return self.get_element_types(schema).values(schema) def normalize_index(self, schema: s_schema.Schema, field: str) -> str: if self.is_named(schema) and field.isdecimal(): idx = int(field) el_names = self.get_element_names(schema) if idx >= 0 and idx < len(el_names): return el_names[idx] else: raise errors.InvalidReferenceError( f'{field} is not a member of ' f'{self.get_displayname(schema)}') return field def index_of(self, schema: s_schema.Schema, field: str) -> int: if field.isdecimal(): idx = int(field) el_names = self.get_element_names(schema) if idx >= 0 and idx < len(el_names): if self.is_named(schema): return el_names.index(field) else: return idx elif self.is_named(schema): el_names = self.get_element_names(schema) try: return el_names.index(field) except ValueError: pass raise errors.InvalidReferenceError( f'{field} is not a member of {self.get_displayname(schema)}') def get_subtype(self, schema: s_schema.Schema, field: str) -> Type: # index can be a name or a position if field.isdecimal(): idx = int(field) 
subtypes_l = list(self.get_subtypes(schema)) if idx >= 0 and idx < len(subtypes_l): return subtypes_l[idx] elif self.is_named(schema): subtypes_d = dict(self.iter_subtypes(schema)) if field in subtypes_d: return subtypes_d[field] raise errors.InvalidReferenceError( f'{field} is not a member of {self.get_displayname(schema)}') def derive_subtype( self, schema: s_schema.Schema, *, name: s_name.QualName, attrs: Optional[Mapping[str, Any]] = None, **kwargs: Any, ) -> tuple[s_schema.Schema, TupleExprAlias]: assert not kwargs return TupleExprAlias.from_subtypes( schema, dict(self.iter_subtypes(schema)), self.get_typemods(schema), name=name, **(attrs or {}), ) @classmethod def from_subtypes( cls: type[Tuple_T], schema: s_schema.Schema, subtypes: Iterable[Type] | Mapping[str, Type], typemods: Any = None, *, name: Optional[s_name.QualName] = None, **kwargs: Any, ) -> tuple[s_schema.Schema, Tuple_T]: named = False if typemods is not None: named = typemods.get('named', False) types: Mapping[str, Type] if isinstance(subtypes, collections.abc.Mapping): types = subtypes else: types = {str(i): type for i, type in enumerate(subtypes)} schema, ty = cls.create( schema, element_types=types, named=named, name=name, **kwargs) return schema, ty @classmethod def create_shell( cls: type[Tuple_T], schema: s_schema.Schema, *, subtypes: Mapping[str, TypeShell[Type]], typemods: Any = None, name: Optional[s_name.Name] = None, ) -> TupleTypeShell[Tuple_T]: return TupleTypeShell( subtypes=subtypes, typemods=typemods, name=name, schemaclass=cls, ) def as_shell( self: Self, schema: s_schema.Schema, ) -> TupleTypeShell[Self]: stshells: dict[str, TypeShell[Type]] = {} for n, st in self.iter_subtypes(schema): stshells[n] = st.as_shell(schema) return type(self).create_shell( schema, subtypes=stshells, typemods=self.get_typemods(schema), name=self.get_name(schema), ) def implicitly_castable_to( self, other: Type, schema: s_schema.Schema, ) -> bool: if not isinstance(other, Tuple): return False 
self_subtypes = self.get_subtypes(schema) other_subtypes = other.get_subtypes(schema) if len(self_subtypes) != len(other_subtypes): return False if ( self.is_named(schema) and other.is_named(schema) and (self.get_element_names(schema) != other.get_element_names(schema)) ): return False for st, ot in zip(self_subtypes, other_subtypes): if not st.implicitly_castable_to(ot, schema): return False return True def get_implicit_cast_distance( self, other: Type, schema: s_schema.Schema, ) -> int: if not isinstance(other, Tuple): return -1 self_subtypes = self.get_subtypes(schema) other_subtypes = other.get_subtypes(schema) if len(self_subtypes) != len(other_subtypes): return -1 if ( self.is_named(schema) and other.is_named(schema) and (self.get_element_names(schema) != other.get_element_names(schema)) ): return -1 total_dist = 0 for st, ot in zip(self_subtypes, other_subtypes): dist = st.get_implicit_cast_distance(ot, schema) if dist < 0: return -1 total_dist += dist return total_dist def assignment_castable_to( self, other: Type, schema: s_schema.Schema, ) -> bool: if not isinstance(other, Tuple): return False self_subtypes = self.get_subtypes(schema) other_subtypes = other.get_subtypes(schema) if len(self_subtypes) != len(other_subtypes): return False if ( self.is_named(schema) and other.is_named(schema) and (self.get_element_names(schema) != other.get_element_names(schema)) ): return False for st, ot in zip(self_subtypes, other_subtypes): if not st.assignment_castable_to(ot, schema): return False return True def castable_to( self, other: Type, schema: s_schema.Schema, ) -> bool: if not isinstance(other, Tuple): return False self_subtypes = self.get_subtypes(schema) other_subtypes = other.get_subtypes(schema) if len(self_subtypes) != len(other_subtypes): return False for st, ot in zip(self_subtypes, other_subtypes): if not st.castable_to(ot, schema): return False return True def find_common_implicitly_castable_type( self, other: Type, schema: s_schema.Schema, ) -> 
tuple[s_schema.Schema, Optional[Tuple]]: if not isinstance(other, Tuple): return schema, None if self == other: return schema, self subs = self.get_subtypes(schema) other_subs = other.get_subtypes(schema) if len(subs) != len(other_subs): return schema, None new_types: list[Type] = [] for st, ot in zip(subs, other_subs): schema, nt = st.find_common_implicitly_castable_type(ot, schema) if nt is None: return schema, None new_types.append(nt) if self.is_named(schema) and other.is_named(schema): my_names = self.get_element_names(schema) other_names = other.get_element_names(schema) if my_names == other_names: return Tuple.from_subtypes( schema, dict(zip(my_names, new_types)), {"named": True} ) return Tuple.from_subtypes(schema, new_types) def get_typemods(self, schema: s_schema.Schema) -> dict[str, bool]: return {'named': self.is_named(schema)} def _resolve_polymorphic( self, schema: s_schema.Schema, concrete_type: Type, ) -> Optional[Type]: if not isinstance(concrete_type, Tuple): return None self_subtypes = self.get_subtypes(schema) other_subtypes = concrete_type.get_subtypes(schema) if len(self_subtypes) != len(other_subtypes): return None for source, target in zip(self_subtypes, other_subtypes): if source.is_polymorphic(schema): return source.resolve_polymorphic(schema, target) return None def _to_nonpolymorphic( self: Self, schema: s_schema.Schema, concrete_type: Type, ) -> tuple[s_schema.Schema, Self]: new_types: list[Type] = [] for st in self.get_subtypes(schema): if st.is_polymorphic(schema): schema, nst = st.to_nonpolymorphic(schema, concrete_type) else: nst = st new_types.append(nst) if self.is_named(schema): return type(self).from_subtypes( schema, dict(zip(self.get_element_names(schema), new_types)), {"named": True}, ) return type(self).from_subtypes(schema, new_types) def _test_polymorphic(self, schema: s_schema.Schema, other: Type) -> bool: if other.is_any(schema) or other.is_anytuple(schema): return True if not isinstance(other, Tuple): return False 
self_subtypes = self.get_subtypes(schema) other_subtypes = other.get_subtypes(schema) if len(self_subtypes) != len(other_subtypes): return False return all(st.test_polymorphic(schema, ot) for st, ot in zip(self_subtypes, other_subtypes)) def material_type( self, schema: s_schema.Schema, ) -> tuple[s_schema.Schema, Tuple]: # We need to resolve material types of all the subtypes recursively. new_material_type = False subtypes = {} for st_name, st in self.iter_subtypes(schema): schema, stm = st.material_type(schema) if stm != st: new_material_type = True subtypes[st_name] = stm if new_material_type or isinstance(self, TupleExprAlias): return Tuple.from_subtypes( schema, subtypes, typemods=self.get_typemods(schema)) else: return schema, self class TupleTypeShell(CollectionTypeShell[Tuple_T_co]): schemaclass: type[Tuple_T_co] def __init__( self, *, name: Optional[s_name.Name], subtypes: Mapping[str, TypeShell[Type]], typemods: Any = None, schemaclass: type[Tuple_T_co], ) -> None: if name is None: named = typemods is not None and typemods.get('named', False) name = schemaclass.generate_name( {n: st.name for n, st in subtypes.items()}, named, ) super().__init__(name=name, schemaclass=schemaclass) self.subtypes = subtypes self.typemods = typemods def get_displayname(self, schema: s_schema.Schema) -> str: st_names = ', '.join(st.get_displayname(schema) for st in self.get_subtypes(schema)) return f'tuple<{st_names}>' def get_subtypes( self, schema: s_schema.Schema, ) -> tuple[TypeShell[Type], ...]: return tuple(self.subtypes.values()) def iter_subtypes( self, schema: s_schema.Schema, ) -> Iterator[tuple[str, TypeShell[Type]]]: return iter(self.subtypes.items()) def is_named(self) -> bool: return self.typemods is not None and self.typemods.get('named', False) def as_create_delta( self, schema: s_schema.Schema, *, view_name: Optional[s_name.QualName] = None, attrs: Optional[dict[str, Any]] = None, ) -> CreateTuple | CreateTupleExprAlias: ct: CreateTuple | CreateTupleExprAlias 
plain_tuple = self._as_plain_create_delta(schema) if view_name is None: ct = plain_tuple else: ct = CreateTupleExprAlias(classname=view_name) self._populate_create_delta(schema, ct, attrs=attrs) for el in self.subtypes.values(): if isinstance(el, CollectionTypeShell): ct.add_prerequisite(el.as_create_delta(schema)) if view_name is not None: ct.add_prerequisite(plain_tuple) return ct def _as_plain_create_delta( self, schema: s_schema.Schema, ) -> CreateTuple: name = self.schemaclass.generate_name( {n: st.get_name(schema) for n, st in self.subtypes.items()}, self.is_named(), ) ct = CreateTuple(classname=name, if_not_exists=True) self._populate_create_delta(schema, ct) return ct def _populate_create_delta( self, schema: s_schema.Schema, ct: CreateTuple | CreateTupleExprAlias, *, attrs: Optional[dict[str, Any]] = None, ) -> None: named = self.is_named() ct.set_attribute_value('name', ct.classname) ct.set_attribute_value('named', named) ct.set_attribute_value('abstract', self.is_polymorphic(schema)) ct.set_attribute_value('is_persistent', True) ct.set_attribute_value('element_types', self.subtypes) if attrs: for k, v in attrs.items(): ct.set_attribute_value(k, v) class TupleExprAlias( CollectionExprAlias, Tuple, qlkind=qltypes.SchemaObjectClass.ALIAS, ): # N.B: Don't add any SchemaFields to this class, they won't be # reflected properly (since this inherits from the concrete Tuple). @classmethod def get_underlying_schema_class(cls) -> type[Collection]: return Tuple Range_T = typing.TypeVar('Range_T', bound='Range') Range_T_co = typing.TypeVar('Range_T_co', bound='Range', covariant=True) class Range( Collection, qlkind=qltypes.SchemaObjectClass.RANGE_TYPE, schema_name='range', ): element_type = so.SchemaField( Type, # We want a low compcoef so that range types are *never* altered. 
compcoef=0, ) @classmethod def generate_name( cls, element_name: s_name.Name, ) -> s_name.UnqualName: return s_name.UnqualName( f'range<{s_name.mangle_name(str(element_name))}>', ) @classmethod def create( cls: type[Range_T], schema: s_schema.Schema, *, name: Optional[s_name.Name] = None, id: Optional[uuid.UUID] = None, element_type: Any, **kwargs: Any, ) -> tuple[s_schema.Schema, Range_T]: if name is None: name = cls.generate_name(element_type.get_name(schema)) if isinstance(name, s_name.QualName): result = schema.get(name, type=cls, default=None) else: result = schema.get_global(cls, name, default=None) if result is None: schema, result = super().create_in_schema( schema, id=id, name=name, element_type=element_type, **kwargs, ) return schema, result def get_generated_name(self, schema: s_schema.Schema) -> s_name.UnqualName: return type(self).generate_name( self.get_element_type(schema).get_name(schema), ) def get_displayname(self, schema: s_schema.Schema) -> str: return ( f'range<{self.get_element_type(schema).get_displayname(schema)}>') def is_range(self) -> bool: return True def derive_subtype( self, schema: s_schema.Schema, *, name: s_name.QualName, attrs: Optional[Mapping[str, Any]] = None, **kwargs: Any, ) -> tuple[s_schema.Schema, RangeExprAlias]: assert not kwargs return RangeExprAlias.from_subtypes( schema, [self.get_element_type(schema)], self.get_typemods(schema), name=name, **(attrs or {}), ) def get_subtypes(self, schema: s_schema.Schema) -> tuple[Type, ...]: return (self.get_element_type(schema),) def implicitly_castable_to( self, other: Type, schema: s_schema.Schema ) -> bool: if not isinstance(other, (Range, MultiRange)): return False my_el = self.get_element_type(schema) other_el = other.get_element_type(schema) if isinstance(other, MultiRange): # Only valid implicit cast to multirange is the one that preserves # the element type. 
return my_el.issubclass(schema, other_el) return my_el.implicitly_castable_to(other_el, schema) def get_implicit_cast_distance( self, other: Type, schema: s_schema.Schema ) -> int: if not isinstance(other, (Range, MultiRange)): return -1 extra = 1 if isinstance(other, MultiRange) else 0 return self.get_element_type(schema).get_implicit_cast_distance( other.get_element_type(schema), schema) + extra def assignment_castable_to( self, other: Type, schema: s_schema.Schema, ) -> bool: if not isinstance(other, (Range, MultiRange)): return False return self.get_element_type(schema).assignment_castable_to( other.get_element_type(schema), schema) def castable_to( self, other: Type, schema: s_schema.Schema, ) -> bool: if not isinstance(other, (Range, MultiRange)): return False return self.get_element_type(schema).castable_to( other.get_element_type(schema), schema) def find_common_implicitly_castable_type( self, other: Type, schema: s_schema.Schema, ) -> tuple[s_schema.Schema, Optional[RangeLike]]: if not isinstance(other, (Range, MultiRange)): return schema, None if self == other: return schema, self my_el = self.get_element_type(schema) other_el = other.get_element_type(schema) if ( isinstance(other, MultiRange) and not my_el.issubclass(schema, other_el) ): # Only valid implicit cast to multirange is the one that preserves # the element type. return schema, None schema, subtype = my_el.find_common_implicitly_castable_type( other_el, schema) if subtype is None: return schema, None # Implicitly castable target is based on the `other` subtype because # it may be Range or MultiRange. We also need to account for # CollectionExprAlias. if isinstance(other, CollectionExprAlias): other_t = other.get_underlying_schema_class() # Keeps mypy happy, even though these have to be exactly one of # those two types and not merely subclasses. 
assert issubclass(other_t, (Range, MultiRange)) else: other_t = type(other) # mypy is not happy even if I try issubclass or a cast for the result # of get_underlying_schema_class, so I'm casting the return here # return typing.cast(typing.Tuple[s_schema.Schema, RangeLike], # other_t.from_subtypes(schema, [subtype])) return other_t.from_subtypes(schema, [subtype]) def _resolve_polymorphic( self, schema: s_schema.Schema, concrete_type: Type, ) -> Optional[Type]: if not isinstance(concrete_type, Range): return None return self.get_element_type(schema).resolve_polymorphic( schema, concrete_type.get_element_type(schema)) def _to_nonpolymorphic( self, schema: s_schema.Schema, concrete_type: Type, ) -> tuple[s_schema.Schema, Range]: return Range.from_subtypes(schema, (concrete_type,)) def _test_polymorphic(self, schema: s_schema.Schema, other: Type) -> bool: if other.is_any(schema): return True if not isinstance(other, (Range, MultiRange)): return False return self.get_element_type(schema).test_polymorphic( schema, other.get_element_type(schema)) @classmethod def from_subtypes( cls: type[Range_T], schema: s_schema.Schema, subtypes: Sequence[Type], typemods: Any = None, *, name: Optional[s_name.QualName] = None, **kwargs: Any, ) -> tuple[s_schema.Schema, Range_T]: if len(subtypes) != 1: raise errors.SchemaError( f'unexpected number of subtypes, expecting 1: {subtypes!r}') stype = subtypes[0] anypoint = schema.get('std::anypoint', type=Type) if not stype.issubclass(schema, anypoint): raise errors.UnsupportedFeatureError( f'unsupported range subtype: {stype.get_displayname(schema)}' ) return cls.create( schema, element_type=stype, name=name, **kwargs, ) @classmethod def create_shell( cls: type[Range_T], schema: s_schema.Schema, *, subtypes: Sequence[TypeShell[Type]], typemods: Any = None, name: Optional[s_name.Name] = None, ) -> RangeTypeShell[Range_T]: st = next(iter(subtypes)) return RangeTypeShell( subtype=st, typemods=typemods, name=name, schemaclass=cls, ) def as_shell( 
self: Self, schema: s_schema.Schema, ) -> RangeTypeShell[Self]: return type(self).create_shell( schema, subtypes=[st.as_shell(schema) for st in self.get_subtypes(schema)], typemods=self.get_typemods(schema), name=self.get_name(schema), ) def material_type( self, schema: s_schema.Schema, ) -> tuple[s_schema.Schema, Range]: # We need to resolve material types based on the subtype recursively. st = self.get_element_type(schema) schema, stm = st.material_type(schema) if stm != st or isinstance(self, RangeExprAlias): return Range.from_subtypes( schema, [stm], typemods=self.get_typemods(schema), ) else: return (schema, self) class RangeTypeShell(CollectionTypeShell[Range_T_co]): schemaclass: type[Range_T_co] def __init__( self, *, name: Optional[s_name.Name], subtype: TypeShell[Type], typemods: tuple[typing.Any, ...], schemaclass: type[Range_T_co], ) -> None: if name is None: name = schemaclass.generate_name(subtype.name) super().__init__(name=name, schemaclass=schemaclass) self.subtype = subtype self.typemods = typemods def get_subtypes( self, schema: s_schema.Schema, ) -> tuple[TypeShell[Type], ...]: return (self.subtype,) def get_displayname(self, schema: s_schema.Schema) -> str: return f'range<{self.subtype.get_displayname(schema)}>' def as_create_delta( self, schema: s_schema.Schema, *, view_name: Optional[s_name.QualName] = None, attrs: Optional[dict[str, Any]] = None, ) -> sd.CommandGroup: ca: CreateRange | CreateRangeExprAlias cmd = sd.CommandGroup() if view_name is None: ca = CreateRange( classname=self.get_name(schema), if_not_exists=True, ) else: ca = CreateRangeExprAlias( classname=view_name, ) el = self.subtype if isinstance(el, CollectionTypeShell): cmd.add(el.as_create_delta(schema)) ca.set_attribute_value('name', ca.classname) ca.set_attribute_value('element_type', el) ca.set_attribute_value('is_persistent', True) ca.set_attribute_value('abstract', self.is_polymorphic(schema)) if attrs: for k, v in attrs.items(): ca.set_attribute_value(k, v) cmd.add(ca) 
return cmd class RangeExprAlias( CollectionExprAlias, Range, qlkind=qltypes.SchemaObjectClass.ALIAS, ): # N.B: Don't add any SchemaFields to this class, they won't be # reflected properly (since this inherits from the concrete Range). @classmethod def get_underlying_schema_class(cls) -> type[Collection]: return Range MultiRange_T = typing.TypeVar('MultiRange_T', bound='MultiRange') MultiRange_T_co = typing.TypeVar( 'MultiRange_T_co', bound='MultiRange', covariant=True) class MultiRange( Collection, qlkind=qltypes.SchemaObjectClass.MULTIRANGE_TYPE, schema_name='multirange', ): element_type = so.SchemaField( Type, # We want a low compcoef so that multirange types are *never* altered. compcoef=0, ) @classmethod def generate_name( cls, element_name: s_name.Name, ) -> s_name.UnqualName: return s_name.UnqualName( f'multirange<{s_name.mangle_name(str(element_name))}>', ) @classmethod def create( cls: type[MultiRange_T], schema: s_schema.Schema, *, name: Optional[s_name.Name] = None, id: Optional[uuid.UUID] = None, element_type: Any, **kwargs: Any, ) -> tuple[s_schema.Schema, MultiRange_T]: if name is None: name = cls.generate_name(element_type.get_name(schema)) if isinstance(name, s_name.QualName): result = schema.get(name, type=cls, default=None) else: result = schema.get_global(cls, name, default=None) if result is None: schema, result = super().create_in_schema( schema, id=id, name=name, element_type=element_type, **kwargs, ) return schema, result def get_generated_name(self, schema: s_schema.Schema) -> s_name.UnqualName: return type(self).generate_name( self.get_element_type(schema).get_name(schema), ) def get_displayname(self, schema: s_schema.Schema) -> str: return f'''multirange<{self.get_element_type(schema) .get_displayname(schema)}>''' def is_multirange(self) -> bool: return True def derive_subtype( self, schema: s_schema.Schema, *, name: s_name.QualName, attrs: Optional[Mapping[str, Any]] = None, **kwargs: Any, ) -> tuple[s_schema.Schema, MultiRangeExprAlias]: 
assert not kwargs return MultiRangeExprAlias.from_subtypes( schema, [self.get_element_type(schema)], self.get_typemods(schema), name=name, **(attrs or {}), ) def get_subtypes(self, schema: s_schema.Schema) -> tuple[Type, ...]: return (self.get_element_type(schema),) def implicitly_castable_to( self, other: Type, schema: s_schema.Schema ) -> bool: if not isinstance(other, MultiRange): return False return self.get_element_type(schema).implicitly_castable_to( other.get_element_type(schema), schema) def get_implicit_cast_distance( self, other: Type, schema: s_schema.Schema ) -> int: if not isinstance(other, MultiRange): return -1 return self.get_element_type(schema).get_implicit_cast_distance( other.get_element_type(schema), schema) def assignment_castable_to( self, other: Type, schema: s_schema.Schema, ) -> bool: if not isinstance(other, MultiRange): return False return self.get_element_type(schema).assignment_castable_to( other.get_element_type(schema), schema) def castable_to( self, other: Type, schema: s_schema.Schema, ) -> bool: if not isinstance(other, MultiRange): return False return self.get_element_type(schema).castable_to( other.get_element_type(schema), schema) def find_common_implicitly_castable_type( self: MultiRange, other: Type, schema: s_schema.Schema, ) -> tuple[s_schema.Schema, Optional[MultiRange]]: if not isinstance(other, MultiRange): return schema, None if self == other: return schema, self my_el = self.get_element_type(schema) schema, subtype = my_el.find_common_implicitly_castable_type( other.get_element_type(schema), schema) if subtype is None: return schema, None return MultiRange.from_subtypes(schema, [subtype]) def _resolve_polymorphic( self, schema: s_schema.Schema, concrete_type: Type, ) -> Optional[Type]: # polymorphic multiranges can resolve using concrete multiranges and # ranges because ranges are implicitly castable into multiranges. 
if not isinstance(concrete_type, (Range, MultiRange)): return None return self.get_element_type(schema).resolve_polymorphic( schema, concrete_type.get_element_type(schema)) def _to_nonpolymorphic( self, schema: s_schema.Schema, concrete_type: Type, ) -> tuple[s_schema.Schema, MultiRange]: return MultiRange.from_subtypes(schema, (concrete_type,)) def _test_polymorphic(self, schema: s_schema.Schema, other: Type) -> bool: if other.is_any(schema): return True if not isinstance(other, MultiRange): return False return self.get_element_type(schema).test_polymorphic( schema, other.get_element_type(schema)) @classmethod def from_subtypes( cls: type[MultiRange_T], schema: s_schema.Schema, subtypes: Sequence[Type], typemods: Any = None, *, name: Optional[s_name.QualName] = None, **kwargs: Any, ) -> tuple[s_schema.Schema, MultiRange_T]: if len(subtypes) != 1: raise errors.SchemaError( f'unexpected number of subtypes, expecting 1: {subtypes!r}') stype = subtypes[0] anypoint = schema.get('std::anypoint', type=Type) if not stype.issubclass(schema, anypoint): raise errors.UnsupportedFeatureError( f'unsupported range subtype: {stype.get_displayname(schema)}' ) return cls.create( schema, element_type=stype, name=name, **kwargs, ) @classmethod def create_shell( cls: type[MultiRange_T], schema: s_schema.Schema, *, subtypes: Sequence[TypeShell[Type]], typemods: Any = None, name: Optional[s_name.Name] = None, ) -> MultiRangeTypeShell[MultiRange_T]: st = next(iter(subtypes)) if name is None: name = cls.generate_name( st.get_name(schema), ) return MultiRangeTypeShell( subtype=st, typemods=typemods, name=name, schemaclass=cls, ) def as_shell( self: Self, schema: s_schema.Schema, ) -> MultiRangeTypeShell[Self]: return type(self).create_shell( schema, subtypes=[st.as_shell(schema) for st in self.get_subtypes(schema)], typemods=self.get_typemods(schema), name=self.get_name(schema), ) def material_type( self, schema: s_schema.Schema, ) -> tuple[s_schema.Schema, MultiRange]: # We need to 
resolve material types based on the subtype recursively. st = self.get_element_type(schema) schema, stm = st.material_type(schema) if stm != st or isinstance(self, MultiRangeExprAlias): return MultiRange.from_subtypes( schema, [stm], typemods=self.get_typemods(schema), ) else: return (schema, self) class MultiRangeTypeShell(CollectionTypeShell[MultiRange_T_co]): schemaclass: type[MultiRange_T_co] def __init__( self, *, name: s_name.Name, subtype: TypeShell[Type], typemods: tuple[typing.Any, ...], schemaclass: type[MultiRange_T_co], ) -> None: super().__init__(name=name, schemaclass=schemaclass) self.subtype = subtype self.typemods = typemods def get_subtypes( self, schema: s_schema.Schema, ) -> tuple[TypeShell[Type], ...]: return (self.subtype,) def get_displayname(self, schema: s_schema.Schema) -> str: return f'multirange<{self.subtype.get_displayname(schema)}>' def as_create_delta( self, schema: s_schema.Schema, *, view_name: Optional[s_name.QualName] = None, attrs: Optional[dict[str, Any]] = None, ) -> sd.CommandGroup: ca: CreateMultiRange | CreateMultiRangeExprAlias cmd = sd.CommandGroup() if view_name is None: ca = CreateMultiRange( classname=self.get_name(schema), if_not_exists=True, ) else: ca = CreateMultiRangeExprAlias( classname=view_name, ) el = self.subtype if isinstance(el, CollectionTypeShell): cmd.add(el.as_create_delta(schema)) ca.set_attribute_value('name', ca.classname) ca.set_attribute_value('element_type', el) ca.set_attribute_value('is_persistent', True) ca.set_attribute_value('abstract', self.is_polymorphic(schema)) if attrs: for k, v in attrs.items(): ca.set_attribute_value(k, v) cmd.add(ca) return cmd class MultiRangeExprAlias( CollectionExprAlias, MultiRange, qlkind=qltypes.SchemaObjectClass.ALIAS, ): # N.B: Don't add any SchemaFields to this class, they won't be # reflected properly (since this inherits from the concrete MultiRange). 
@classmethod def get_underlying_schema_class(cls) -> type[Collection]: return MultiRange RangeLike = Range | MultiRange def get_union_type_name( component_names: typing.Iterable[s_name.Name], *, opaque: bool = False, module: typing.Optional[str] = None, ) -> s_name.QualName: sorted_name_list = sorted( str(name).replace('::', ':') for name in component_names) if opaque: nqname = f"(opaque: {' | '.join(sorted_name_list)})" else: nqname = f"({' | '.join(sorted_name_list)})" return s_name.QualName(name=nqname, module=module or '__derived__') def get_intersection_type_name( component_names: typing.Iterable[s_name.Name], *, module: typing.Optional[str] = None, ) -> s_name.QualName: sorted_name_list = sorted( str(name).replace('::', ':') for name in component_names) nqname = f"({' & '.join(sorted_name_list)})" return s_name.QualName(name=nqname, module=module or '__derived__') def ensure_schema_type_expr_type( schema: s_schema.Schema, type_shell: TypeExprShell[Type], parent_cmd: sd.Command, *, span: typing.Optional[parsing.Span] = None, context: sd.CommandContext, ) -> Optional[sd.Command]: name = type_shell.get_name(schema) texpr_type = schema.get(name, default=None, type=Type) cmd = None if texpr_type is None: cmd = type_shell.as_create_delta(schema) if cmd is not None: parent_cmd.add_prerequisite(cmd) return cmd def type_dummy_expr( typ: Type, schema: s_schema.Schema, ) -> Optional[s_expr.Expression]: if isinstance(typ, so.DerivableInheritingObject): typ = typ.get_nearest_non_derived_parent(schema) q = qlast.FunctionCall( func=('__std__', 'assert_exists'), args=[ qlast.TypeCast( type=utils.typeref_to_ast(schema, typ), expr=qlast.Set(elements=[]), ) ], ) return s_expr.Expression.from_ast(q, schema) class TypeCommand[TypeT: Type](sd.ObjectCommand[TypeT]): @classmethod def _get_alias_expr(cls, astnode: qlast.CreateAlias) -> qlast.Expr: expr = qlast.get_ddl_field_value(astnode, 'expr') if expr is None: raise errors.InvalidAliasDefinitionError( f'missing required view 
expression', span=astnode.span ) assert isinstance(expr, qlast.Expr) return expr def get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: if self.get_attribute_value('expr'): return None elif ( (union_of := self.get_attribute_value('union_of')) is not None and union_of.items ): return None elif ( (int_of := self.get_attribute_value('intersection_of')) is not None and int_of.items ): return None else: return super().get_ast(schema, context, parent_node=parent_node) def compile_expr_field( self, schema: s_schema.Schema, context: sd.CommandContext, field: so.Field[Any], value: s_expr.Expression, track_schema_ref_exprs: bool = False, ) -> s_expr.CompiledExpression: # XXX: This seems like pointless duplication of work from # globals/aliases... why is expr even here? # (... because we export it in the introspection schema) assert field.name == 'expr' return value.compiled( schema=schema, options=qlcompiler.CompilerOptions( schema_object_context=self.get_schema_metaclass(), modaliases=context.modaliases, in_ddl_context_name='type definition', track_schema_ref_exprs=track_schema_ref_exprs, ), context=context, ) def get_dummy_expr_field_value( self, schema: s_schema.Schema, context: sd.CommandContext, field: so.Field[Any], value: Any, ) -> Optional[s_expr.Expression]: if field.name == 'expr': return type_dummy_expr(self.scls, schema) else: raise NotImplementedError(f'unhandled field {field.name!r}') def _create_begin( self, schema: s_schema.Schema, context: sd.CommandContext ) -> s_schema.Schema: schema = super()._create_begin(schema, context) assert isinstance(self.scls, Type) if not self.scls.is_view(schema): delta_root = context.top().op assert isinstance(delta_root, sd.DeltaRoot) delta_root.new_types.add(self.scls.id) return schema class InheritingTypeCommand( sd.QualifiedObjectCommand[InheritingTypeT], TypeCommand[InheritingTypeT],
inheriting.InheritingObjectCommand[InheritingTypeT], ): def _validate_bases( self, schema: s_schema.Schema, context: sd.CommandContext, bases: so.ObjectList[InheritingTypeT], shells: Mapping[s_name.QualName, TypeShell[InheritingTypeT]], is_derived: bool, ) -> None: for base in bases.objects(schema): if base.find_generic(schema) is not None or ( base.is_free_object_type(schema) and not is_derived ): base_type_name = base.get_displayname(schema) shell = shells.get(base.get_name(schema)) raise errors.SchemaError( f"{base_type_name!r} cannot be a parent type", span=shell.span if shell is not None else None, ) class CreateInheritingType( InheritingTypeCommand[InheritingTypeT], inheriting.CreateInheritingObject[InheritingTypeT], ): def validate_create( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: super().validate_create(schema, context) shells = self.get_attribute_value('bases') if isinstance(shells, so.ObjectList): # XXX: fix set_attribute_value shell hygiene shells = shells.as_shell(schema) shell_map = {s.get_name(schema): s for s in shells} bases = self.get_resolved_attribute_value( 'bases', schema=schema, context=context, ) self._validate_bases( schema, context, bases, shell_map, is_derived=self.get_attribute_value('is_derived') or False, ) class RebaseInheritingType( InheritingTypeCommand[InheritingTypeT], inheriting.RebaseInheritingObject[InheritingTypeT], ): def validate_alter( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> None: super().validate_alter(schema, context) shell_map = {} for base_shells, _ in self.added_bases: shell_map.update({s.get_name(schema): s for s in base_shells}) bases = self.get_resolved_attribute_value( 'bases', schema=schema, context=context, ) self._validate_bases( schema, context, bases, shell_map, is_derived=self.scls.get_is_derived(schema), ) class CollectionTypeCommandContext(sd.ObjectCommandContext[Collection]): pass class CollectionTypeCommand(TypeCommand[CollectionTypeT], 
context_class=CollectionTypeCommandContext): def get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: # CollectionTypeCommand cannot have its own AST because it is a # side-effect of some other command. return None class CollectionExprAliasCommand( sd.QualifiedObjectCommand[CollectionExprAliasT], TypeCommand[CollectionExprAliasT], context_class=CollectionTypeCommandContext, ): def get_ast( self, schema: s_schema.Schema, context: sd.CommandContext, *, parent_node: Optional[qlast.DDLOperation] = None, ) -> Optional[qlast.DDLOperation]: # CollectionExprAliasCommand cannot have its own AST because it is a # side-effect of some other command. return None class CreateCollectionType( CollectionTypeCommand[CollectionTypeT], sd.CreateObject[CollectionTypeT], ): def canonicalize_attributes( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: # Even if we create a collection while setting up internal # things, we don't mark it internal, since something visible # might use it later. self.set_attribute_value('internal', False) return super().canonicalize_attributes(schema, context) def validate_object( self, schema: s_schema.Schema, context: sd.CommandContext ) -> None: super().validate_object(schema, context) if isinstance(self.scls, (Range, MultiRange)): from .
import scalars as s_scalars from edb.pgsql import types as pgtypes st = self.scls.get_subtypes(schema)[0] # general rule of what's supported supported = ( isinstance(st, s_scalars.ScalarType) and st.is_base_type(schema) ) if supported: # actually test that it's supported try: pgtypes.pg_type_from_object(schema, self.scls) except Exception: supported = False if not supported: raise errors.UnsupportedFeatureError( f'unsupported range subtype: {st.get_displayname(schema)}' ) class AlterCollectionType( CollectionTypeCommand[CollectionTypeT], AlterType[CollectionTypeT], sd.AlterObject[CollectionTypeT], ): pass class RenameCollectionType( CollectionTypeCommand[CollectionTypeT], RenameType[CollectionTypeT], ): pass class DeleteCollectionType( CollectionTypeCommand[CollectionTypeT], sd.DeleteObject[CollectionTypeT], ): def _delete_begin( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: schema = super()._delete_begin(schema, context) if not context.canonical: for el in self.scls.get_subtypes(schema): if op := el.as_type_delete_if_unused(schema): self.add_caused(op) return schema class CreateCollectionExprAlias( CollectionExprAliasCommand[CollectionExprAliasT], sd.CreateObject[CollectionExprAliasT], ): pass class DeleteCollectionExprAlias( CollectionExprAliasCommand[CollectionExprAliasT], DeleteCollectionType[CollectionExprAliasT], ): def _canonicalize( self, schema: s_schema.Schema, context: sd.CommandContext, scls: CollectionExprAliasT, ) -> list[sd.Command]: ops = super()._canonicalize(schema, context, scls) ops.append(scls.as_underlying_type_delete_if_unused(schema)) return ops class CreateTuple(CreateCollectionType[Tuple]): pass class AlterTuple(AlterCollectionType[Tuple]): pass class RenameTuple(RenameCollectionType[Tuple]): pass class CreateTupleExprAlias(CreateCollectionExprAlias[TupleExprAlias]): def _get_ast_node( self, schema: s_schema.Schema, context: sd.CommandContext ) -> type[qlast.CreateAlias]: # Can't just use class-level 
astnode because that creates a # duplicate in ast -> command mapping. return qlast.CreateAlias class RenameTupleExprAlias( CollectionExprAliasCommand[TupleExprAlias], sd.RenameObject[TupleExprAlias], ): pass class AlterTupleExprAlias( CollectionExprAliasCommand[TupleExprAlias], sd.AlterObject[TupleExprAlias], ): pass class CreateArray(CreateCollectionType[Array]): pass class AlterArray(AlterCollectionType[Array]): pass class RenameArray(RenameCollectionType[Array]): pass class CreateArrayExprAlias(CreateCollectionExprAlias[ArrayExprAlias]): def _get_ast_node( self, schema: s_schema.Schema, context: sd.CommandContext ) -> type[qlast.CreateAlias]: # Can't just use class-level astnode because that creates a # duplicate in ast -> command mapping. return qlast.CreateAlias class RenameArrayExprAlias( CollectionExprAliasCommand[ArrayExprAlias], sd.RenameObject[ArrayExprAlias], ): pass class AlterArrayExprAlias( CollectionExprAliasCommand[ArrayExprAlias], sd.AlterObject[ArrayExprAlias], ): pass class CreateRange(CreateCollectionType[Range]): pass class AlterRange(AlterCollectionType[Range]): pass class RenameRange(RenameCollectionType[Range]): pass class CreateRangeExprAlias(CreateCollectionExprAlias[RangeExprAlias]): def _get_ast_node( self, schema: s_schema.Schema, context: sd.CommandContext ) -> type[qlast.CreateAlias]: # Can't just use class-level astnode because that creates a # duplicate in ast -> command mapping. 
return qlast.CreateAlias class RenameRangeExprAlias( CollectionExprAliasCommand[RangeExprAlias], sd.RenameObject[RangeExprAlias], ): pass class AlterRangeExprAlias( CollectionExprAliasCommand[RangeExprAlias], sd.AlterObject[RangeExprAlias], ): pass class CreateMultiRange(CreateCollectionType[MultiRange]): pass class AlterMultiRange(AlterCollectionType[MultiRange]): pass class RenameMultiRange(RenameCollectionType[MultiRange]): pass class CreateMultiRangeExprAlias(CreateCollectionExprAlias[MultiRangeExprAlias]): def _get_ast_node( self, schema: s_schema.Schema, context: sd.CommandContext ) -> type[qlast.CreateAlias]: # Can't just use class-level astnode because that creates a # duplicate in ast -> command mapping. return qlast.CreateAlias class RenameMultiRangeExprAlias( CollectionExprAliasCommand[MultiRangeExprAlias], sd.RenameObject[MultiRangeExprAlias], ): pass class AlterMultiRangeExprAlias( CollectionExprAliasCommand[MultiRangeExprAlias], sd.AlterObject[MultiRangeExprAlias], ): pass class DeleteTuple(DeleteCollectionType[Tuple]): pass class DeleteTupleExprAlias(DeleteCollectionExprAlias[TupleExprAlias]): pass class DeleteArray(DeleteCollectionType[Array]): # Prevent array types from getting deleted unless the element # type is being deleted too. 
def _has_outside_references( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> bool: if super()._has_outside_references(schema, context): return True el_type = self.scls.get_element_type(schema) if el_type.is_scalar() and not context.is_deleting(el_type): return True return False class DeleteArrayExprAlias(DeleteCollectionExprAlias[ArrayExprAlias]): pass class DeleteRange(DeleteCollectionType[Range]): pass class DeleteRangeExprAlias(DeleteCollectionExprAlias[RangeExprAlias]): pass class DeleteMultiRange(DeleteCollectionType[MultiRange]): pass class DeleteMultiRangeExprAlias(DeleteCollectionExprAlias[MultiRangeExprAlias]): pass def materialize_type_in_attribute( schema: s_schema.Schema, context: sd.CommandContext, cmd: sd.Command, attrname: str, ) -> s_schema.Schema: assert isinstance(cmd, sd.ObjectCommand) type_ref = cmd.get_local_attribute_value(attrname) if type_ref is None: return schema span = cmd.get_attribute_span('target') if isinstance(type_ref, TypeExprShell): cc_cmd = ensure_schema_type_expr_type( schema, type_ref, parent_cmd=cmd, span=span, context=context, ) if cc_cmd is not None: schema = cc_cmd.apply(schema, context) if isinstance(type_ref, CollectionTypeShell): # If the current command is a fragment, we want the collection # creation to live in the parent operation, in order for the # logic to skip it if the object already exists to work. 
op = (cmd.get_parent_op(context) if isinstance(cmd, sd.AlterObjectFragment) else cmd) make_coll = type_ref.as_create_delta(schema) op.add_prerequisite(make_coll) schema = make_coll.apply(schema, context) if isinstance(type_ref, TypeShell): try: type_ref.resolve(schema) except errors.InvalidReferenceError as e: refname = type_ref.get_refname(schema) if refname is not None: utils.enrich_schema_lookup_error( e, refname, modaliases=context.modaliases, schema=schema, item_type=Type, span=span, ) raise except errors.InvalidPropertyDefinitionError as e: e.set_span(span) raise elif not isinstance(type_ref, Type): raise AssertionError( f'unexpected value in type attribute {attrname!r} of ' f'{cmd.get_verbosename()}: {type_ref!r}' ) return schema def is_type_compatible( type_a: Type, type_b: Type, *, schema: s_schema.Schema, ) -> bool: """Check whether two types have compatible SQL representations. EdgeQL implicit casts need to be turned into explicit casts in some places, since the semantics differ from SQL's. 
""" schema, material_type_a = type_a.material_type(schema) schema, material_type_b = type_b.material_type(schema) def labels_compatible(t_a: Type, t_b: Type) -> bool: if t_a == t_b: return True if isinstance(t_a, Tuple) and isinstance(t_b, Tuple): if t_a.get_is_persistent(schema) and t_b.get_is_persistent(schema): return False # For tuples, we also (recursively) check that the element # names match return all( name_a == name_b and labels_compatible(st_a, st_b) for (name_a, st_a), (name_b, st_b) in zip(t_a.iter_subtypes(schema), t_b.iter_subtypes(schema)) ) elif isinstance(t_a, Array) and isinstance(t_b, Array): t_as = t_a.get_element_type(schema) t_bs = t_b.get_element_type(schema) return ( not isinstance(t_as, Tuple) and labels_compatible(t_as, t_bs) ) elif isinstance(t_a, Range) and isinstance(t_b, Range): t_as = t_a.get_element_type(schema) t_bs = t_b.get_element_type(schema) return labels_compatible(t_as, t_bs) elif isinstance(t_a, MultiRange) and isinstance(t_b, MultiRange): t_as = t_a.get_element_type(schema) t_bs = t_b.get_element_type(schema) return labels_compatible(t_as, t_bs) else: return True return ( material_type_b.issubclass(schema, material_type_a) and labels_compatible(material_type_a, material_type_b) ) ================================================ FILE: edb/schema/unknown_pointers.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # """Machinery for handling pointers with an unspecified kind. Most of the DDL/delta machinery really requires that we know whether we are operating on a link or a property, but our SDL syntax allows omitting the specifier. Because the pointer might be computed, it's not possible to resolve this ahead of time, so we build just enough machinery for compiling unknown pointer operations to make ddl.apply_sdl work. """ from __future__ import annotations from edb.common import struct from edb.edgeql import ast as qlast from edb.edgeql import parser as qlparser from . import delta as sd from . import objects as so from . import properties as s_props from . import pointers from . import sources as s_sources from . import schema as s_schema class UnknownPointerSourceContext[Source_T: s_sources.Source]( s_sources.SourceCommandContext[Source_T] ): pass class UnknownPointerCommand( pointers.PointerCommand[pointers.Pointer], context_class=pointers.PointerCommandContext, referrer_context_class=UnknownPointerSourceContext, ): _schema_metaclass = pointers.Pointer def _propagate_ref_creation( self, schema: s_schema.Schema, context: sd.CommandContext, referrer: so.InheritingObject, ) -> None: pass class CreateUnknownPointer( UnknownPointerCommand, pointers.CreatePointer[pointers.Pointer], ): astnode = qlast.CreateConcreteUnknownPointer referenced_astnode = qlast.CreateConcreteUnknownPointer # We stash the original AST node here, so we can reuse it in apply # after we've figured out the type. 
node = struct.Field(qlast.CreateConcreteUnknownPointer, default=None) @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> sd.Command: assert isinstance(astnode, qlast.CreateConcreteUnknownPointer) # We don't need any of the subcommands in order to figure out # the kind, and we avoid needing to get the contexts right if # we skip them. fakenode = astnode.replace(commands=[]) cmd = super()._cmd_tree_from_ast(schema, fakenode, context) assert isinstance(cmd, CreateUnknownPointer) cmd._process_create_or_alter_ast(schema, fakenode, context) if context.modaliases: astnode = astnode.replace() qlparser.append_module_aliases(astnode, context.modaliases) cmd.node = astnode return cmd def apply( self, schema: s_schema.Schema, context: sd.CommandContext, ) -> s_schema.Schema: # We don't know what the real type of this pointer is, so this # is a two step process: # 1. Apply it using purely generic Pointer code. This doesn't produce # a fully legitimate result, but will resolve the target. # 2. Check whether the target is an object, and construct a new # create AST node specialized to pointer or link. Then compile # that to a delta tree and apply it. nschema = super().apply(schema, context) source = self.scls.get_source(nschema) target = self.scls.get_target(nschema) assert source and target astnode = self.node assert astnode astcls = ( qlast.CreateConcreteLink # It's a link if the target is an object and so is the source. # If the source isn't, it's a link property, which will fail. 
if target.is_object_type() and not isinstance(source, pointers.Pointer) else qlast.CreateConcreteProperty ) astnode = astnode.replace(__class__=astcls) ncmd = sd.compile_ddl(schema, astnode, context=context) assert isinstance(ncmd, pointers.CreatePointer) rschema = ncmd.apply(schema, context) return rschema class AlterUnknownPointer( UnknownPointerCommand, pointers.AlterPointer[pointers.Pointer], ): astnode = qlast.AlterConcreteUnknownPointer referenced_astnode = qlast.AlterConcreteUnknownPointer @classmethod def _cmd_tree_from_ast( cls, schema: s_schema.Schema, astnode: qlast.DDLOperation, context: sd.CommandContext, ) -> pointers.AlterPointer[pointers.Pointer]: # For alters that get run as part of apply_sdl, the relevant # object should exist in the schema when _cmd_tree_from_ast is # called, so we can resolve whether it is a link or a property # right away and never need to return an AlterUnknownPointer # object. # We don't need any of the subcommands in order to figure out # the kind, and we avoid needing to get the contexts right if # we skip them. fakenode = astnode.replace(commands=[]) cmd = super()._cmd_tree_from_ast(schema, fakenode, context) obj = cmd.get_object(schema, context) source = obj.get_source(schema) is_prop = ( isinstance(obj, s_props.Property) or isinstance(source, pointers.Pointer) ) astcls = ( qlast.AlterConcreteProperty if is_prop else qlast.AlterConcreteLink ) if isinstance(astnode, qlast.AlterObject) else ( qlast.CreateConcreteProperty if is_prop else qlast.CreateConcreteLink ) astnode = astnode.replace(__class__=astcls) assert isinstance(astnode, qlast.DDLCommand) qlparser.append_module_aliases(astnode, context.modaliases) res = sd.compile_ddl(schema, astnode, context=context) assert isinstance(res, pointers.AlterPointer) return res ================================================ FILE: edb/schema/utils.py ================================================ # # This source file is part of the EdgeDB open source project. 
# # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Callable, Optional, TypeVar, Iterable, Mapping, Sequence, cast, TYPE_CHECKING, ) import collections import decimal import itertools from edb import errors from edb.common import levenshtein from edb.edgeql import ast as qlast from edb.ir import statypes from . import name as sn from . import objects as so from . import expr as s_expr if TYPE_CHECKING: from . import objtypes as s_objtypes from . import schema as s_schema from . 
import types as s_types from edb.common import parsing T = TypeVar('T') def name_to_ast_ref(name: sn.Name) -> qlast.ObjectRef: if isinstance(name, sn.QualName): return qlast.ObjectRef( module=name.module, name=name.name, ) else: return qlast.ObjectRef( name=name.name, ) def ast_ref_to_name(ref: qlast.ObjectRef) -> sn.Name: if ref.module: return sn.QualName(name=ref.name, module=ref.module) else: return sn.UnqualName(name=ref.name) def ast_ref_to_unqualname(ref: qlast.ObjectRef) -> sn.UnqualName: if ref.module: raise errors.InternalServerError( f'unexpected fully-qualified name: {ast_ref_to_name(ref)}', span=ref.span, ) else: return sn.UnqualName(name=ref.name) def resolve_name( lname: sn.Name, *, metaclass: Optional[type[so.Object]] = None, span: Optional[parsing.Span] = None, modaliases: Mapping[Optional[str], str], schema: s_schema.Schema, ) -> sn.Name: obj = schema.get( lname, type=metaclass, module_aliases=modaliases, default=None, span=span, ) if obj is not None: name = obj.get_name(schema) elif isinstance(lname, sn.QualName): name = sn.QualName( module=modaliases.get(lname.module, lname.module), name=lname.name, ) elif metaclass is not None and issubclass(metaclass, so.QualifiedObject): actual_module = modaliases.get(None) if actual_module is None: raise errors.InvalidReferenceError( 'unqualified name and no default module alias set') name = sn.QualName(module=actual_module, name=lname.name) else: # Do not assume the name is fully-qualified unless asked # explicitly. 
name = lname return name def ast_objref_to_object_shell( ref: qlast.ObjectRef, *, metaclass: type[so.Object_T], modaliases: Mapping[Optional[str], str], schema: s_schema.Schema, ) -> so.ObjectShell[so.Object_T]: lname = ast_ref_to_name(ref) name = resolve_name( lname, metaclass=metaclass, modaliases=modaliases, schema=schema, span=ref.span, ) return so.ObjectShell( name=name, origname=lname, schemaclass=metaclass, span=ref.span, ) def ast_objref_to_type_shell[TypeT: s_types.Type]( ref: qlast.ObjectRef, *, metaclass: type[TypeT], modaliases: Mapping[Optional[str], str], schema: s_schema.Schema, ) -> s_types.TypeShell[TypeT]: from . import types as s_types if metaclass is not s_types.Type: mcls = metaclass else: mcls = s_types.QualifiedType # type: ignore lname = ast_ref_to_name(ref) name = resolve_name( lname, metaclass=mcls, modaliases=modaliases, schema=schema, span=ref.span, ) return s_types.TypeShell( name=name, origname=lname, schemaclass=mcls, span=ref.span, ) def ast_to_type_shell( node: qlast.TypeExpr, *, metaclass: type[s_types.TypeT_co], module: Optional[str] = None, modaliases: Mapping[Optional[str], str], schema: s_schema.Schema, allow_generalized_bases: bool = False, ) -> s_types.TypeShell[s_types.TypeT_co]: if isinstance(node, qlast.TypeOp): return type_op_ast_to_type_shell( node, metaclass=metaclass, module=module, modaliases=modaliases, schema=schema, ) assert isinstance(node, qlast.TypeName) if (node.subtypes is not None and isinstance(node.maintype, qlast.ObjectRef) and node.maintype.name == 'enum'): from . import scalars as s_scalars from edb.pgsql import common as pg_common assert node.subtypes elements: list[str] = [] element_spans: list[Optional[parsing.Span]] = [] if isinstance(node.subtypes[0], qlast.TypeExprLiteral): # handling enums as literals # eg. 
enum<'A','B','C'> for subtype_expr_literal in cast( list[qlast.TypeExprLiteral], node.subtypes ): elements.append(subtype_expr_literal.val.value) element_spans.append(subtype_expr_literal.val.span) else: # handling enums as typenames # eg. enum<A,B,C> for subtype_type_name in cast( list[qlast.TypeName], node.subtypes ): if ( not isinstance(subtype_type_name, qlast.TypeName) or not isinstance( subtype_type_name.maintype, qlast.ObjectRef ) ): raise errors.EdgeQLSyntaxError( 'enums do not support mapped values', span=subtype_type_name.span, ) elements.append(subtype_type_name.maintype.name) element_spans.append(subtype_type_name.maintype.span) for element, element_span in zip(elements, element_spans): if len(element) > pg_common.MAX_ENUM_LABEL_LENGTH: raise errors.SchemaDefinitionError( f'enum labels cannot exceed ' f'{pg_common.MAX_ENUM_LABEL_LENGTH} characters', span=element_span, ) return s_scalars.AnonymousEnumTypeShell( # type: ignore elements=elements ) elif node.subtypes is not None: from .
import types as s_types assert isinstance(node.maintype, qlast.ObjectRef) coll = None try: coll = s_types.Collection.get_class(node.maintype.name) except errors.SchemaError: if not allow_generalized_bases: raise subtypes_list: list[s_types.TypeShell[s_types.Type]] = [] if coll is None: assert allow_generalized_bases res = ast_objref_to_type_shell( node.maintype, modaliases=modaliases, metaclass=metaclass, schema=schema, ) res.extra_args = tuple(node.subtypes) return res elif issubclass(coll, s_types.Tuple): # Note: if we used abc Tuple here, then we would need anyway # to assert it is an instance of s_types.Tuple to make mypy happy # (rightly so, because later we use from_subtypes method) subtypes: dict[str, s_types.TypeShell[s_types.Type]] = {} # tuple declaration must either be named or unnamed, but not both names = set() named = None unnamed = None for si, st in enumerate(node.subtypes): if st.name: named = True type_name = st.name if type_name in names: raise errors.SchemaError( f"named tuple has duplicate field '{type_name}'", span=st.span) names.add(type_name) else: unnamed = True type_name = str(si) if named is not None and unnamed is not None: raise errors.EdgeQLSyntaxError( f'mixing named and unnamed tuple declaration ' f'is not supported', span=node.subtypes[0].span, ) subtypes[type_name] = ast_to_type_shell( cast(qlast.TypeName, st), modaliases=modaliases, metaclass=metaclass, schema=schema, ) try: return coll.create_shell( # type: ignore schema, subtypes=subtypes, typemods={'named': bool(named)}, ) except errors.SchemaError as e: # all errors raised inside are pertaining to subtypes, so # the context should point to the first subtype e.set_span(node.subtypes[0].span) raise e elif issubclass(coll, s_types.Array): for st in node.subtypes: subtypes_list.append( ast_to_type_shell( cast(qlast.TypeName, st), modaliases=modaliases, metaclass=metaclass, schema=schema, ) ) if len(subtypes_list) != 1: raise errors.SchemaError( f'unexpected number of subtypes,' f' 
expecting 1, got {len(subtypes_list)}', span=node.span, ) if isinstance(subtypes_list[0], s_types.ArrayTypeShell): raise errors.UnsupportedFeatureError( 'nested arrays are not supported', span=node.subtypes[0].span, ) try: return coll.create_shell( # type: ignore schema, subtypes=subtypes_list, ) except errors.SchemaError as e: e.set_span(node.span) raise e elif issubclass(coll, (s_types.Range, s_types.MultiRange)): for st in node.subtypes: subtypes_list.append( ast_to_type_shell( cast(qlast.TypeName, st), modaliases=modaliases, metaclass=metaclass, schema=schema, ) ) if len(subtypes_list) != 1: raise errors.SchemaError( f'unexpected number of subtypes,' f' expecting 1, got {len(subtypes_list)}', span=node.span, ) # FIXME: need to check that subtypes are only anypoint try: return coll.create_shell( # type: ignore schema, subtypes=subtypes_list, ) except errors.SchemaError as e: e.set_span(node.span) raise e elif isinstance(node.maintype, qlast.PseudoObjectRef): from . import pseudo as s_pseudo return s_pseudo.PseudoTypeShell( name=sn.UnqualName(node.maintype.name), span=node.maintype.span, ) # type: ignore assert isinstance(node.maintype, qlast.ObjectRef) return ast_objref_to_type_shell( node.maintype, modaliases=modaliases, metaclass=metaclass, schema=schema, ) def type_op_ast_to_type_shell[TypeT: s_types.Type]( node: qlast.TypeOp, *, metaclass: type[TypeT], module: Optional[str] = None, modaliases: Mapping[Optional[str], str], schema: s_schema.Schema, ) -> s_types.TypeExprShell[TypeT]: from . 
import types as s_types if node.op not in [qlast.TypeOpName.OR, qlast.TypeOpName.AND]: raise errors.UnsupportedFeatureError( f'unsupported type expression operator: {node.op}', span=node.span, ) if module is None: module = modaliases.get(None) if module is None: raise errors.InternalServerError( 'cannot determine module for derived compound type', span=node.span, ) left = ast_to_type_shell( node.left, metaclass=metaclass, module=module, modaliases=modaliases, schema=schema, ) right = ast_to_type_shell( node.right, metaclass=metaclass, module=module, modaliases=modaliases, schema=schema, ) CompositeTypeShell = ( s_types.UnionTypeShell if node.op == qlast.TypeOpName.OR else s_types.IntersectionTypeShell ) # Doubled check for s_types.TypeExprShell to reassure mypy if ( isinstance(left, CompositeTypeShell) and isinstance(left, s_types.TypeExprShell) ): if ( isinstance(right, CompositeTypeShell) and isinstance(right, s_types.TypeExprShell) ): return CompositeTypeShell( components=left.components + right.components, module=module, schemaclass=metaclass, span=node.span, ) else: return CompositeTypeShell( components=left.components + (right,), module=module, schemaclass=metaclass, span=node.span, ) else: if ( isinstance(right, CompositeTypeShell) and isinstance(right, s_types.TypeExprShell) ): return CompositeTypeShell( components=(left,) + right.components, schemaclass=metaclass, module=module, span=node.span, ) else: return CompositeTypeShell( components=(left, right), module=module, schemaclass=metaclass, span=node.span, ) def ast_to_object_shell( node: qlast.ObjectRef | qlast.TypeName, *, metaclass: type[so.Object_T], module: Optional[str] = None, modaliases: Mapping[Optional[str], str], schema: s_schema.Schema, ) -> so.ObjectShell[so.Object_T]: from . 
import types as s_types if isinstance(node, qlast.TypeName): if issubclass(metaclass, s_types.Type): return ast_to_type_shell( # type: ignore node, metaclass=metaclass, module=module, modaliases=modaliases, schema=schema, ) else: objref = node.maintype if node.subtypes: raise AssertionError( 'must pass s_types.Type subclass as type when ' 'creating a type shell from type AST' ) assert isinstance(objref, qlast.ObjectRef) return ast_objref_to_object_shell( objref, modaliases=modaliases, metaclass=metaclass, schema=schema, ) else: return ast_objref_to_object_shell( node, modaliases=modaliases, metaclass=metaclass, schema=schema, ) def typeref_to_ast( schema: s_schema.Schema, ref: so.Object_T | so.ObjectShell[so.Object_T], *, _name: Optional[str] = None, disambiguate_std: bool=False, ) -> qlast.TypeExpr: from . import types as s_types if isinstance(ref, so.ObjectShell): return type_shell_to_ast(schema, ref) else: t = ref result: qlast.TypeExpr if isinstance(t, s_types.Type) and t.is_any(schema): result = qlast.TypeName( name=_name, maintype=qlast.PseudoObjectRef(name='anytype') ) elif isinstance(t, s_types.Type) and t.is_anytuple(schema): result = qlast.TypeName( name=_name, maintype=qlast.PseudoObjectRef(name='anytuple') ) elif isinstance(t, s_types.Type) and t.is_anyobject(schema): result = qlast.TypeName( name=_name, maintype=qlast.PseudoObjectRef(name='anyobject') ) elif isinstance(t, s_types.Tuple) and t.is_named(schema): result = qlast.TypeName( name=_name, maintype=qlast.ObjectRef( name=t.get_schema_name() ), subtypes=[ typeref_to_ast(schema, st, _name=sn, disambiguate_std=disambiguate_std) for sn, st in t.iter_subtypes(schema) ] ) elif isinstance(t, (s_types.Array, s_types.Tuple, s_types.Range, s_types.MultiRange)): # Here the concrete type Array is used because t.get_schema_name() # is used, which is not defined for more generic collections and abcs result = qlast.TypeName( name=_name, maintype=qlast.ObjectRef( name=t.get_schema_name() ), subtypes=[ 
typeref_to_ast(schema, st, disambiguate_std=disambiguate_std) for st in t.get_subtypes(schema) ] ) elif ( isinstance(t, s_types.Type) and (t.is_union_type(schema) or t.is_intersection_type(schema)) ): object_set = ( t.get_union_of(schema) if t.is_union_type(schema) else t.get_intersection_of(schema) ) assert object_set is not None component_objects = tuple(object_set.objects(schema)) result = typeref_to_ast( schema, component_objects[0], disambiguate_std=disambiguate_std ) for component_object in component_objects[1:]: result = qlast.TypeOp( left=result, op=( qlast.TypeOpName.OR if t.is_union_type(schema) else qlast.TypeOpName.AND ), right=typeref_to_ast( schema, component_object, disambiguate_std=disambiguate_std ), ) elif isinstance(t, so.QualifiedObject): t_name = t.get_name(schema) module = t_name.module if disambiguate_std and module == 'std': # If the type is defined in 'std::', replace the module to # '__std__' to handle cases where 'std' name is aliased to # another module. module = '__std__' result = qlast.TypeName( name=_name, maintype=qlast.ObjectRef( module=module, name=t_name.name ) ) else: raise NotImplementedError(f'cannot represent {t!r} as a shell') return result def shell_to_ast( schema: s_schema.Schema, t: so.ObjectShell[so.Object], *, _name: Optional[str] = None, ) -> qlast.TypeExpr | qlast.Expr: from . import types as s_types if isinstance(t, s_types.TypeShell): return type_shell_to_ast(schema, t, _name=_name) elif isinstance(t, so.ObjectShell): name = t.name if isinstance(name, sn.QualName): qlref = qlast.ObjectRef( module=name.module, name=name.name, ) else: qlref = qlast.ObjectRef( module='', name=name.name, ) return qlast.Path(steps=[qlref]) else: raise NotImplementedError(f'cannot represent {t!r} as a type shell') def type_shell_to_ast( schema: s_schema.Schema, t: so.ObjectShell[so.Object], *, _name: Optional[str] = None, ) -> qlast.TypeExpr: from . import pseudo as s_pseudo from . import types as s_types from . 
import scalars as s_scalars result: qlast.TypeExpr qlref: qlast.BaseObjectRef if isinstance(t, s_pseudo.PseudoTypeShell): if t.name.name not in {'anytype', 'anytuple', 'anyobject'}: raise AssertionError(f'unexpected pseudo type shell: {t.name!r}') result = qlast.TypeName( name=_name, maintype=qlast.PseudoObjectRef(name=t.name.name) ) elif isinstance(t, s_types.TupleTypeShell): if t.is_named(): result = qlast.TypeName( name=_name, maintype=qlast.ObjectRef( name='tuple', ), subtypes=[ type_shell_to_ast(schema, st, _name=sn) for sn, st in t.iter_subtypes(schema) ] ) else: result = qlast.TypeName( name=_name, maintype=qlast.ObjectRef( name='tuple', ), subtypes=[ type_shell_to_ast(schema, st) for st in t.get_subtypes(schema) ] ) elif isinstance(t, s_types.ArrayTypeShell): result = qlast.TypeName( name=_name, maintype=qlast.ObjectRef( name='array', ), subtypes=[ type_shell_to_ast(schema, st) for st in t.get_subtypes(schema) ] ) elif isinstance(t, s_types.RangeTypeShell): result = qlast.TypeName( name=_name, maintype=qlast.ObjectRef( name='range', ), subtypes=[ type_shell_to_ast(schema, st) for st in t.get_subtypes(schema) ] ) elif isinstance(t, s_types.MultiRangeTypeShell): result = qlast.TypeName( name=_name, maintype=qlast.ObjectRef( name='multirange', ), subtypes=[ type_shell_to_ast(schema, st) for st in t.get_subtypes(schema) ] ) elif isinstance(t, s_types.UnionTypeShell): components = t.get_components(schema) result = typeref_to_ast(schema, components[0]) for component in components[1:]: result = qlast.TypeOp( left=result, op=qlast.TypeOpName.OR, right=typeref_to_ast(schema, component), ) elif isinstance(t, s_types.IntersectionTypeShell): components = t.get_components(schema) result = typeref_to_ast(schema, components[0]) for component in components[1:]: result = qlast.TypeOp( left=result, op=qlast.TypeOpName.AND, right=typeref_to_ast(schema, component), ) elif isinstance(t, s_scalars.AnonymousEnumTypeShell): result = qlast.TypeName( name=_name, 
maintype=qlast.ObjectRef( name='enum', ), subtypes=[ qlast.TypeName(maintype=qlast.ObjectRef(name=x)) for x in t.elements ] ) elif isinstance(t, so.ObjectShell): name = t.name if isinstance(name, sn.QualName): qlref = qlast.ObjectRef( module=name.module, name=name.name, ) else: qlref = qlast.ObjectRef( module='', name=name.name, ) result = qlast.TypeName( name=_name, maintype=qlref, ) else: raise NotImplementedError(f'cannot represent {t!r} as a type shell') return result def is_nontrivial_container(value: Any) -> Optional[Iterable[Any]]: trivial_classes = (str, bytes, bytearray, memoryview) if (isinstance(value, collections.abc.Iterable) and not isinstance(value, trivial_classes)): return value else: return None def get_class_nearest_common_ancestors( schema: s_schema.Schema, classes: Iterable[so.InheritingObjectT] ) -> list[so.InheritingObjectT]: # First, find the intersection of parents classes = list(classes) first = [classes[0]] first.extend(classes[0].get_ancestors(schema).objects(schema)) common = set(first).intersection( *[set(c.get_ancestors(schema).objects(schema)) | {c} for c in classes[1:]]) common_list = sorted(common, key=lambda i: first.index(i)) nearests: list[so.InheritingObjectT] = [] # Then find the common ancestors that don't have any subclasses that # are also nearest common ancestors. 
for anc in common_list: if not any(x.issubclass(schema, anc) for x in nearests): nearests.append(anc) return nearests def minimize_class_set_by_most_generic( schema: s_schema.Schema, classes: Iterable[so.InheritingObjectT] ) -> list[so.InheritingObjectT]: """Minimize the given set of objects by filtering out all subclasses.""" classes = list(classes) mros = [set(p.get_ancestors(schema).objects(schema)) for p in classes] count = len(classes) smap = itertools.starmap # Return only those entries that do not have other entries in their mro result = [ scls for i, scls in enumerate(classes) if not any(smap(set.__contains__, ((mros[i], classes[j]) for j in range(count) if j != i))) ] return result def minimize_class_set_by_least_generic( schema: s_schema.Schema, classes: Iterable[so.InheritingObjectT] ) -> list[so.InheritingObjectT]: """Minimize the given set of objects by filtering out all superclasses.""" classes = list(classes) mros = [set(p.get_ancestors(schema).objects(schema)) | {p} for p in classes] count = len(classes) smap = itertools.starmap # Return only those entries that are not present in other entries' mro result = [ scls for i, scls in enumerate(classes) if not any(smap(set.__contains__, ((mros[j], classes[i]) for j in range(count) if j != i))) ] return result def merge_reduce( target: so.InheritingObjectT, sources: Iterable[so.InheritingObjectT], field_name: str, *, ignore_local: bool, schema: s_schema.Schema, f: Callable[[T, T], T], type: type[T], ) -> Optional[T]: values: list[tuple[T, str]] = [] if not ignore_local: ours = target.get_explicit_local_field_value(schema, field_name, None) if ours is not None: vn = target.get_verbosename(schema, with_parent=True) values.append((ours, vn)) for source in sources: theirs = source.get_explicit_field_value(schema, field_name, None) if theirs is not None: vn = source.get_verbosename(schema, with_parent=True) values.append((theirs, vn)) if values: val = values[0][0] desc = values[0][1] cdn = 
target.get_schema_class_displayname() for other_val, other_desc in values[1:]: try: val = f(val, other_val) except Exception: raise errors.SchemaDefinitionError( f'invalid {cdn} definition: {field_name} is defined ' f'as {val} in {desc}, but is defined as {other_val} ' f'in {other_desc}, which is incompatible' ) return val else: return None def get_nq_name(schema: s_schema.Schema, item: so.Object) -> str: shortname = item.get_shortname(schema) if isinstance(shortname, sn.QualName): return shortname.name else: return str(shortname) def find_item_suggestions( name: sn.Name, modaliases: Mapping[Optional[str], str], schema: s_schema.Schema, *, item_type: Optional[so.ObjectMeta] = None, condition: Optional[Callable[[so.Object], bool]] = None, ) -> Iterable[tuple[so.Object, str]]: from . import functions as s_func from . import properties as s_prop from . import links as s_link from . import modules as s_mod orig_modname = name.module if isinstance(name, sn.QualName) else None suggestions: list[so.Object] = [] if modname := modaliases.get(orig_modname, None): if schema.get_global(s_mod.Module, modname, None): suggestions.extend( schema.get_objects( included_modules=[sn.UnqualName(modname)], ), ) modname = f'std::{modname}' if schema.get_global(s_mod.Module, modname, None): suggestions.extend( schema.get_objects( included_modules=[sn.UnqualName(modname)], ), ) if orig_modname: if schema.get_global(s_mod.Module, orig_modname, None): suggestions.extend( schema.get_objects( included_modules=[sn.UnqualName(orig_modname)], ), ) modname = f'std::{orig_modname}' if schema.get_global(s_mod.Module, modname, None): suggestions.extend( schema.get_objects( included_modules=[sn.UnqualName(modname)], ), ) else: suggestions.extend( schema.get_objects( included_modules=[sn.UnqualName("std")], ), ) filters = [] # links and properties are suggested by find_fields_suggestions filters.append( lambda s: not isinstance(s, s_prop.Property) and not isinstance(s, s_link.Link) ) if condition is 
not None: filters.append(condition) if item_type is not None: it = item_type filters.append(lambda s: isinstance(s, it)) else: # When schema class is not specified, only suggest generic objects. filters.append(lambda s: not sn.is_fullname(str(s.get_name(schema)))) filters.append(lambda s: not isinstance(s, s_func.CallableObject)) # Never suggest object fragments. filters.append(lambda s: not isinstance(s, so.ObjectFragment)) filtered = filter(lambda s: all(f(s) for f in filters), suggestions) # Add display names cur_module_name = modaliases.get(None) def get_display_name(suggestion: so.Object) -> str: if isinstance(suggestion, so.QualifiedObject): mod = suggestion.get_name(schema).module if mod == "std" or mod == cur_module_name: return get_nq_name(schema, suggestion) return suggestion.get_displayname(schema) return ((s, get_display_name(s)) for s in filtered) def find_pointer_suggestions( schema: s_schema.Schema, item_type: Optional[so.ObjectMeta], parent: Optional[so.Object], ) -> Iterable[tuple[so.Object, str]]: """ Suggest pointers (properties or links) of the parent object type. If the expected item type is not a pointer, suggestions use the .name notation. """ from . import pointers as s_pointers from . import sources as s_sources if not isinstance(parent, s_sources.Source): return () pointers_with_names = parent.get_pointers(schema).items(schema) pointers = (pointer for _, pointer in pointers_with_names) suggestions = ((s, s.get_displayname(schema)) for s in pointers) if item_type is not s_pointers.Pointer: # Prefix with . suggestions = ((s, "." + n) for s, n in suggestions) return suggestions def pick_closest_suggestions( name: sn.Name, schema: s_schema.Schema, suggestions: Iterable[tuple[so.Object, str]], limit: int, ) -> list[tuple[so.Object, str]]: local_name = name.name # Compute Levenshtein distance for each suggestion.
with_distance: list[tuple[so.Object, str, int]] = [ (s, name, levenshtein.distance(local_name, get_nq_name(schema, s))) for s, name in suggestions ] # Filter out suggestions that are too dissimilar. max_distance = 3 closest = list(filter(lambda s: s[2] < max_distance, with_distance)) # Sort by proximity, then by whether the suggestion starts with # the source string, then by suggestion name. closest.sort( key=lambda s: ( s[2], not get_nq_name(schema, s[0]).startswith(local_name), s[1], ) ) return [(s[0], s[1]) for s in closest[:limit]] def enrich_schema_lookup_error( error: errors.EdgeDBError, item_name: sn.Name, modaliases: Mapping[Optional[str], str], schema: s_schema.Schema, *, item_type: Optional[so.ObjectMeta] = None, suggestion_limit: int = 3, condition: Optional[Callable[[so.Object], bool]] = None, span: Optional[parsing.Span] = None, pointer_parent: Optional[so.Object] = None, hint_text: str = 'did you mean' ) -> None: all_suggestions = itertools.chain( find_item_suggestions( item_name, modaliases, schema, item_type=item_type, condition=condition, ), find_pointer_suggestions(schema, item_type, pointer_parent), ) suggestions = pick_closest_suggestions( item_name, schema, all_suggestions, suggestion_limit ) if suggestions: names = [name for _, name in suggestions] if len(names) > 1: hint = f'{hint_text} one of these: {", ".join(names)}?' else: hint = f'{hint_text} {names[0]!r}?'
error.set_hint_and_details(hint=hint) if span is not None: error.set_span(span) def ensure_union_type( schema: s_schema.Schema, types: Sequence[s_types.Type], *, opaque: bool = False, module: Optional[str] = None, transient: bool = False, ) -> tuple[s_schema.Schema, s_types.Type, bool]: from edb.schema import objtypes as s_objtypes if len(types) == 1 and not opaque: return schema, next(iter(types)), False seen_scalars = False seen_objtypes = False created = False for t in types: if isinstance(t, s_objtypes.ObjectType): if seen_scalars: raise _union_error(schema, types) seen_objtypes = True else: if seen_objtypes: raise _union_error(schema, types) seen_scalars = True if seen_scalars: uniontype: s_types.Type = types[0] for t1 in types[1:]: schema, common_type = ( uniontype.find_common_implicitly_castable_type(t1, schema) ) if common_type is None: raise _union_error(schema, types) else: uniontype = common_type else: objtypes = cast( Sequence[s_objtypes.ObjectType], types, ) schema, uniontype, created = s_objtypes.get_or_create_union_type( schema, components=objtypes, opaque=opaque, module=module, transient=transient, ) return schema, uniontype, created def simplify_union_types( schema: s_schema.Schema, types: Sequence[s_types.Type], ) -> Sequence[s_types.Type]: """Minimize the types used to create a union of types. Any union types are unwrapped. Then, any unnecessary subclasses are removed. 
""" from edb.schema import types as s_types components: set[s_types.Type] = set() for t in types: union_of = t.get_union_of(schema) if union_of: components.update(union_of.objects(schema)) else: components.add(t) if all(isinstance(c, s_types.InheritingType) for c in components): return list(minimize_class_set_by_most_generic( schema, cast(set[s_types.InheritingType], components), )) else: return list(components) def simplify_union_types_preserve_derived( schema: s_schema.Schema, types: Sequence[s_types.Type], ) -> Sequence[s_types.Type]: """Minimize the types used to create a union of types. Any union types are unwrapped. Then, any unnecessary subclasses are removed. Derived types are always preserved for 'std::UNION', 'std::IF', and 'std::??'. """ from edb.schema import types as s_types components: set[s_types.Type] = set() for t in types: union_of = t.get_union_of(schema) if union_of: components.update(union_of.objects(schema)) else: components.add(t) derived = set( t for t in components if ( isinstance(t, s_types.InheritingType) and t.get_is_derived(schema) ) ) nonderived: Sequence[s_types.Type] = [ t for t in components if t not in derived ] nonderived = minimize_class_set_by_most_generic( schema, cast(set[s_types.InheritingType], nonderived), ) return list(nonderived) + list(derived) def get_non_overlapping_union( schema: s_schema.Schema, objects: Iterable[so.InheritingObjectT], ) -> tuple[frozenset[so.InheritingObjectT], bool]: all_objects: set[so.InheritingObjectT] = set(objects) non_unique_count = 0 for obj in objects: descendants = obj.descendants(schema) non_unique_count += len(descendants) + 1 all_objects.update(descendants) if non_unique_count == len(all_objects): # The input object set is already non-overlapping return frozenset(objects), False else: return frozenset(all_objects), True def get_type_expr_non_overlapping_union( type: s_types.Type, schema: s_schema.Schema, ) -> tuple[frozenset[s_types.Type], bool]: """Get a non-overlapping set of the 
type's descendants""" from edb.schema import types as s_types expanded_types = expand_type_expr_descendants(type, schema) # filter out subclasses expanded_types = { type for type in expanded_types if not any( type is not other and type.issubclass(schema, other) for other in expanded_types ) } non_overlapping, union_is_exhaustive = get_non_overlapping_union( schema, cast(set[so.InheritingObject], expanded_types) ) return cast(frozenset[s_types.Type], non_overlapping), union_is_exhaustive def expand_type_expr_descendants( type: s_types.Type, schema: s_schema.Schema, *, expand_opaque_union: bool = True, ) -> set[s_types.Type]: """Expand types and type expressions to get descendants""" from edb.schema import types as s_types if sub_union := type.get_union_of(schema): # Expanding a union # Get the union of the component descendants return set.union(*( expand_type_expr_descendants( component, schema, ) for component in sub_union.objects(schema) )) elif sub_intersection := type.get_intersection_of(schema): # Expanding an intersection # Get the intersection of component descendants return set.intersection(*( expand_type_expr_descendants( component, schema ) for component in sub_intersection.objects(schema) )) elif type.is_view(schema): # When expanding a view, simply unpeel the view. return expand_type_expr_descendants( type.peel_view(schema), schema ) # Return simple type and all its descendants. # Some types (e.g. BaseObject) have non-simple descendants, filter them out. 
return {type} | { c for c in cast( set[s_types.Type], set(cast(so.InheritingObject, type).descendants(schema)) ) if ( not c.is_union_type(schema) and not c.is_intersection_type(schema) and not c.is_view(schema) ) } def _union_error( schema: s_schema.Schema, components: Iterable[s_types.Type] ) -> errors.SchemaError: names = ', '.join(sorted(c.get_displayname(schema) for c in components)) return errors.SchemaError(f'using incompatible types {names}') def ensure_intersection_type( schema: s_schema.Schema, types: Sequence[s_types.Type], *, transient: bool = False, module: Optional[str] = None, ) -> tuple[s_schema.Schema, s_types.Type, bool]: from edb.schema import objtypes as s_objtypes if len(types) == 1: return schema, next(iter(types)), False seen_scalars = False seen_objtypes = False for t in types: if t.is_object_type(): if seen_scalars: raise _intersection_error(schema, types) seen_objtypes = True else: if seen_objtypes: raise _intersection_error(schema, types) seen_scalars = True if seen_scalars: # Non-related scalars and collections cannot form intersection types. raise _intersection_error(schema, types) else: return s_objtypes.get_or_create_intersection_type( schema, components=cast(Iterable[s_objtypes.ObjectType], types), module=module, transient=transient, ) def simplify_intersection_types( schema: s_schema.Schema, types: Sequence[s_types.Type], ) -> Sequence[s_types.Type]: """Minimize the types used to create an intersection of types. Any intersection types are unwrapped. Then, any unnecessary superclasses are removed. 
""" from edb.schema import types as s_types components: set[s_types.Type] = set() for t in types: intersection_of = t.get_intersection_of(schema) if intersection_of: components.update(intersection_of.objects(schema)) else: components.add(t) if all(isinstance(c, s_types.InheritingType) for c in components): return minimize_class_set_by_least_generic( schema, cast(set[s_types.InheritingType], components), ) else: return list(components) def _intersection_error( schema: s_schema.Schema, components: Iterable[s_types.Type] ) -> errors.SchemaError: names = ', '.join(sorted(c.get_displayname(schema) for c in components)) return errors.SchemaError(f'cannot create an intersection of {names}') MAX_INT64 = 2 ** 63 - 1 MIN_INT64 = -2 ** 63 def const_ast_from_python(val: Any, with_secrets: bool=False) -> qlast.Expr: if isinstance(val, str): return qlast.Constant.string(val) elif isinstance(val, bool): return qlast.Constant.boolean(val) elif isinstance(val, int): if MIN_INT64 <= val <= MAX_INT64: return qlast.Constant.integer(val) else: raise ValueError(f'int64 value out of range: {val}') elif isinstance(val, decimal.Decimal): return qlast.Constant(value=f'{val}n', kind=qlast.ConstantKind.DECIMAL) elif isinstance(val, float): return qlast.Constant(value=str(val), kind=qlast.ConstantKind.FLOAT) elif isinstance(val, bytes): return qlast.BytesConstant(value=val) elif isinstance(val, statypes.Duration): return qlast.TypeCast( type=qlast.TypeName( maintype=qlast.ObjectRef(module='__std__', name='duration'), ), expr=qlast.Constant.string(value=val.to_iso8601()), ) elif isinstance(val, statypes.EnumScalarType): qltype = val.get_edgeql_type() return qlast.TypeCast( type=qlast.TypeName( maintype=qlast.ObjectRef( module=qltype.module, name=qltype.name), ), expr=qlast.Constant.string(value=val.to_str()), ) elif isinstance(val, statypes.CompositeType): return qlast.InsertQuery( subject=name_to_ast_ref(sn.name_from_string(val._tspec.name)), shape=[ qlast.ShapeElement( 
expr=qlast.Path(steps=[qlast.Ptr(name=ptr)]), compexpr=const_ast_from_python( getattr(val, ptr), with_secrets=with_secrets ), ) for ptr, typ in val._tspec.fields.items() if not (typ.secret and not with_secrets) and not typ.protected ], ) elif isinstance(val, (set, frozenset)): return qlast.Set(elements=[ const_ast_from_python(x, with_secrets=with_secrets) for x in val ]) elif val is None: return qlast.Set(elements=[]) else: raise ValueError(f'unexpected constant type: {type(val)!r}') def get_config_type_shape( schema: s_schema.Schema, stype: s_objtypes.ObjectType, path: list[qlast.PathElement], ) -> list[qlast.ShapeElement]: from . import objtypes as s_objtypes shape = [ qlast.ShapeElement( expr=qlast.Path(steps=[qlast.Ptr(name='_tname')], ), compexpr=qlast.Path( steps=path + [ qlast.Ptr(name='__type__'), qlast.Ptr(name='name'), ], ), ), ] seen: set[str] = set() stypes = [stype] + list(stype.ordered_descendants(schema)) for t in stypes: t_name = t.get_name(schema) for unqual_pn, p in t.get_pointers(schema).items(schema): pn = str(unqual_pn) if pn in ('id', '__type__') or pn in seen: continue elem_path: list[qlast.PathElement] = [] if t != stype: elem_path.append( qlast.TypeIntersection( type=qlast.TypeName( maintype=qlast.ObjectRef( module=t_name.module, name=t_name.name, ), ), ), ) elem_path.append(qlast.Ptr(name=pn)) ptype = p.get_target(schema) assert ptype is not None if str(ptype.get_name(schema)) == 'cfg::AbstractConfig': continue if isinstance(ptype, s_objtypes.ObjectType): subshape = get_config_type_shape( schema, ptype, path + elem_path) else: subshape = [] shape.append( qlast.ShapeElement( expr=qlast.Path(steps=elem_path), elements=subshape, ), ) seen.add(pn) return shape def type_shell_multi_substitute( mapping: dict[sn.Name, s_types.TypeShell[s_types.TypeT_co]], typ: s_types.TypeShell[s_types.TypeT_co], schema: s_schema.Schema, ) -> s_types.TypeShell[s_types.TypeT_co]: for name, new in mapping.items(): typ = type_shell_substitute(name, new, typ, schema) 
return typ def type_shell_substitute( name: sn.Name, new: s_types.TypeShell[s_types.TypeT_co], typ: s_types.TypeShell[s_types.TypeT_co], schema: s_schema.Schema, ) -> s_types.TypeShell[s_types.TypeT_co]: from . import types as s_types # arguably this would be better done with a method on the types if typ.name == name: return new if isinstance(typ, s_types.UnionTypeShell): assert isinstance(typ.name, sn.QualName) return s_types.UnionTypeShell( module=typ.name.module, schemaclass=typ.schemaclass, opaque=typ.opaque, components=[ type_shell_substitute(name, new, c, schema) for c in typ.components ], ) elif isinstance(typ, s_types.IntersectionTypeShell): assert isinstance(typ.name, sn.QualName) return s_types.IntersectionTypeShell( module=typ.name.module, schemaclass=typ.schemaclass, components=[ type_shell_substitute(name, new, c, schema) for c in typ.components ], ) elif isinstance(typ, s_types.ArrayTypeShell): return s_types.ArrayTypeShell( name=None, expr=typ.expr, typemods=typ.typemods, schemaclass=typ.schemaclass, subtype=type_shell_substitute(name, new, typ.subtype, schema), ) elif isinstance(typ, s_types.TupleTypeShell): return s_types.TupleTypeShell( name=None, typemods=typ.typemods, schemaclass=typ.schemaclass, subtypes={ k: type_shell_substitute(name, new, v, schema) for k, v in typ.subtypes.items() }, ) elif isinstance(typ, s_types.RangeTypeShell): return s_types.RangeTypeShell( name=None, typemods=typ.typemods, schemaclass=typ.schemaclass, subtype=type_shell_substitute(name, new, typ.subtype, schema), ) else: return typ def try_compile_irast_to_sql_tree( compiled_expr: s_expr.CompiledExpression, span: Optional[parsing.Span] ) -> None: # compile the expression to sql to preempt errors downstream from edb.pgsql import compiler as pg_compiler try: pg_compiler.compile_ir_to_sql_tree( compiled_expr.irast, output_format=pg_compiler.OutputFormat.NATIVE, singleton_mode=True, ) except errors.EdgeDBError as exception: exception.set_span(span) raise exception except: 
raise def str_interpolation_to_old_style(interp: qlast.StrInterp) -> str: r"""Convert a \(name) string interpolation to {name} style for schema use. The {name} style of string interpolation is used (with special handling) for errmessage in constraints, and we want to also support \(name) style. We (somewhat unfortunately) implement this by converting the \(name) style *into* the older {name} style. It would be somewhat nicer to store the string using \(name) style, but doing that loses the ability to prevent interpolation by doing \\(name). (Since the lexed version is stored in that case, we wouldn't be able to distinguish between the cases.) """ res = interp.prefix for frag in interp.interpolations: match frag.expr: case qlast.Path( partial=False, steps=[ qlast.Anchor(name=name) | qlast.ObjectRef(name=name, module=None) ] ): res += '{' + name + '}' case _: raise errors.SchemaDefinitionError( "only variables are allowed in simple schema " "interpolations", span=frag.expr.span, ) res += frag.suffix return res ================================================ FILE: edb/schema/version.py ================================================ # This source file is part of the EdgeDB open source project. # # Copyright 2021-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import uuid from . import delta as sd from . 
import objects as so class BaseSchemaVersion(so.Object): version = so.SchemaField(uuid.UUID) class SchemaVersion(BaseSchemaVersion, so.InternalObject): pass class SchemaVersionCommandContext(sd.ObjectCommandContext[SchemaVersion]): pass class SchemaVersionCommand( sd.ObjectCommand[SchemaVersion], context_class=SchemaVersionCommandContext, ): pass class CreateSchemaVersion( SchemaVersionCommand, sd.CreateObject[SchemaVersion], ): pass class AlterSchemaVersion( SchemaVersionCommand, sd.AlterObject[SchemaVersion], ): pass class GlobalSchemaVersion( BaseSchemaVersion, so.InternalObject, so.GlobalObject ): pass class GlobalSchemaVersionCommandContext( sd.ObjectCommandContext[GlobalSchemaVersion], ): pass class GlobalSchemaVersionCommand( sd.ObjectCommand[GlobalSchemaVersion], context_class=GlobalSchemaVersionCommandContext, ): pass class CreateGlobalSchemaVersion( GlobalSchemaVersionCommand, sd.CreateObject[GlobalSchemaVersion], ): pass class AlterGlobalSchemaVersion( GlobalSchemaVersionCommand, sd.AlterObject[GlobalSchemaVersion], ): pass ================================================ FILE: edb/server/.gitignore ================================================ *.c *.html ================================================ FILE: edb/server/__init__.py ================================================ ## # Copyright (c) 2008-present MagicStack Inc. # All rights reserved. # # See LICENSE for details. ## from __future__ import annotations ================================================ FILE: edb/server/_rust_native/Cargo.toml ================================================ [package] name = "rust_native" version = "0.1.0" license = "MIT/Apache-2.0" authors = ["MagicStack Inc. 
"] edition = "2021" [lints] workspace = true [features] python_extension = ["pyo3/extension-module", "pyo3/serde"] [dependencies] pyo3 = { workspace = true } pyo3_util.workspace = true conn_pool = { workspace = true, features = [ "python_extension" ] } pgrust = { workspace = true, features = [ "python_extension" ] } gel-http = { workspace = true, features = [ "python_extension" ] } gel-jwt = { workspace = true, features = [ "python_extension" ] } [lib] crate-type = ["lib", "cdylib"] path = "src/lib.rs" ================================================ FILE: edb/server/_rust_native/src/lib.rs ================================================ use pyo3::{ pymodule, types::{PyAnyMethods, PyModule, PyModuleMethods}, Bound, PyResult, Python, }; use pyo3_util::logging::{get_python_logger_level, initialize_logging_in_thread}; const MODULE_PREFIX: &str = "edb.server._rust_native"; fn add_child_module( py: Python, parent: &Bound<PyModule>, name: &str, init_fn: fn(Python, &Bound<PyModule>) -> PyResult<()>, ) -> PyResult<()> { let full_name = format!("{MODULE_PREFIX}.{name}"); let child_module = PyModule::new(py, &full_name)?; init_fn(py, &child_module)?; parent.add(name, &child_module)?; // Add the child module to the sys.modules dictionary so it can be imported // by name. 
let sys_modules = py.import("sys")?.getattr("modules")?; sys_modules.set_item(full_name, child_module)?; Ok(()) } #[pymodule] fn _rust_native(py: Python, m: &Bound<PyModule>) -> PyResult<()> { // Initialize any logging in this thread to route to "edb.server" let level = get_python_logger_level(py, "edb.server")?; initialize_logging_in_thread("edb.server", level); add_child_module(py, m, "_conn_pool", conn_pool::python::_conn_pool)?; add_child_module(py, m, "_pg_rust", pgrust::python::_pg_rust)?; add_child_module(py, m, "_http", gel_http::python::_gel_http)?; add_child_module(py, m, "_jwt", gel_jwt::python::_jwt)?; Ok(()) } ================================================ FILE: edb/server/args.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Callable, Optional, ItemsView, Mapping, NamedTuple, NoReturn, Sequence, ) import logging import os import pathlib import re import warnings import tempfile import click import psutil from edb import buildmeta from edb.common import devmode from edb.common import enum from edb.common import typeutils from edb.schema import defines as schema_defines from edb.pgsql import params as pgsql_params from . 
import defines MIB = 1024 * 1024 RAM_MIB_PER_CONN = 100 TLS_CERT_FILE_NAME = "edbtlscert.pem" TLS_KEY_FILE_NAME = "edbprivkey.pem" JWS_KEY_FILE_NAME = "edbjwskeys.pem" logger = logging.getLogger('edb.server') class InvalidUsageError(Exception): def __init__(self, msg: str, exit_code: int = 2) -> None: super().__init__(msg, exit_code) def abort(msg: str, *, exit_code: int = 2) -> NoReturn: raise InvalidUsageError(msg, exit_code) class StartupScript(NamedTuple): text: str database: str user: str class ServerSecurityMode(enum.StrEnum): Strict = "strict" InsecureDevMode = "insecure_dev_mode" class ServerEndpointSecurityMode(enum.StrEnum): Tls = "tls" Optional = "optional" class ServerTlsCertMode(enum.StrEnum): RequireFile = "require_file" SelfSigned = "generate_self_signed" class JOSEKeyMode(enum.StrEnum): RequireFile = "require_file" Generate = "generate" class ReadinessState(enum.StrEnum): Default = "default" """Default state: serving normally""" NotReady = "not_ready" """/server/status/ready returns an error, but clients can still connect.""" ReadOnly = "read_only" """Only read-only queries are allowed.""" Offline = "offline" """Any existing connections are gracefully terminated and no new connections are accepted.""" Blocked = "blocked" """Any existing connections are gracefully terminated and all new connections are accepted but are immediately terminated with a ServerBlockedError.""" class ServerAuthMethod(enum.StrEnum): Auto = "auto" Trust = "Trust" Scram = "SCRAM" JWT = "JWT" Password = "Password" mTLS = "mTLS" class ServerConnTransport(enum.StrEnum): HTTP = "HTTP" TCP = "TCP" TCP_PG = "TCP_PG" SIMPLE_HTTP = "SIMPLE_HTTP" HTTP_METRICS = "HTTP_METRICS" HTTP_HEALTH = "HTTP_HEALTH" class ReloadTrigger(enum.StrEnum): """ Configure what triggers the reload of the following config files: 1. TLS certificate and key (server config) 2. JWS key (server config) 3. Multi-tenant config file (server config) 4. Readiness state (server or tenant config) 5. 
JWT sub allowlist and revocation list (server or tenant config) 6. The TOML config file (server or tenant config) """ Default = "default" """By default, reload on both SIGHUP and fsevent.""" Never = "never" """Disable the reload function.""" Signal = "signal" """Only reload on SIGHUP.""" FileSystemEvent = "fsevent" """Watch the files for changes and reload when it happens.""" class NetWorkerMode(enum.StrEnum): Default = "default" Disabled = "disabled" class ServerAuthMethods: def __init__( self, methods: Mapping[ServerConnTransport, list[ServerAuthMethod]], ) -> None: self._methods = dict(methods) def get(self, transport: ServerConnTransport) -> list[ServerAuthMethod]: return self._methods[transport] def items(self) -> ItemsView[ServerConnTransport, list[ServerAuthMethod]]: return self._methods.items() def __str__(self): return ','.join( f'{t.lower()}:{'/'.join(m.lower() for m in mm)}' for t, mm in self._methods.items() ) DEFAULT_AUTH_METHODS = ServerAuthMethods({ ServerConnTransport.TCP: [ServerAuthMethod.Scram], ServerConnTransport.TCP_PG: [ServerAuthMethod.Scram], ServerConnTransport.HTTP: [ServerAuthMethod.JWT], ServerConnTransport.SIMPLE_HTTP: [ ServerAuthMethod.Password, ServerAuthMethod.JWT], ServerConnTransport.HTTP_METRICS: [ServerAuthMethod.Auto], ServerConnTransport.HTTP_HEALTH: [ServerAuthMethod.Auto], }) class BackendCapabilitySets(NamedTuple): must_be_present: list[pgsql_params.BackendCapabilities] must_be_absent: list[pgsql_params.BackendCapabilities] class CompilerPoolMode(enum.StrEnum): Default = "default" Fixed = "fixed" OnDemand = "on_demand" Remote = "remote" MultiTenant = "fixed_multi_tenant" def __init__(self, name): self.pool_class = None def assign_implementation(self, cls): # decorator function to link this enum with the actual implementation self.pool_class = cls return cls class ServerConfig(NamedTuple): data_dir: pathlib.Path backend_dsn: str backend_adaptive_ha: bool tenant_id: Optional[str] ignore_other_tenants: bool 
multitenant_config_file: Optional[pathlib.Path] log_level: str log_to: str bootstrap_only: bool inplace_upgrade_prepare: Optional[pathlib.Path] inplace_upgrade_finalize: bool inplace_upgrade_rollback: bool bootstrap_command: str bootstrap_command_file: pathlib.Path default_branch: Optional[str] default_database: Optional[str] default_database_user: Optional[str] devmode: bool testmode: bool bind_addresses: list[str] port: int background: bool pidfile_dir: pathlib.Path daemon_user: str daemon_group: str runstate_dir: pathlib.Path extensions_dir: tuple[pathlib.Path, ...] max_backend_connections: Optional[int] compiler_pool_size: int compiler_worker_branch_limit: int compiler_pool_mode: CompilerPoolMode compiler_pool_addr: tuple[str, int] compiler_pool_tenant_cache_size: int compiler_worker_max_rss: Optional[int] echo_runtime_info: bool emit_server_status: str temp_dir: bool auto_shutdown_after: float readiness_state_file: Optional[pathlib.Path] disable_dynamic_system_config: bool reload_config_files: ReloadTrigger net_worker_mode: NetWorkerMode config_file: Optional[pathlib.Path] startup_script: Optional[StartupScript] status_sinks: list[Callable[[str], None]] tls_cert_file: pathlib.Path tls_key_file: pathlib.Path tls_cert_mode: ServerTlsCertMode tls_client_ca_file: Optional[pathlib.Path] jws_key_file: pathlib.Path jose_key_mode: JOSEKeyMode jwt_sub_allowlist_file: Optional[pathlib.Path] jwt_revocation_list_file: Optional[pathlib.Path] default_auth_method: ServerAuthMethods security: ServerSecurityMode binary_endpoint_security: ServerEndpointSecurityMode http_endpoint_security: ServerEndpointSecurityMode instance_name: str backend_capability_sets: BackendCapabilitySets admin_ui: bool cors_always_allowed_origins: Optional[str] class PathPath(click.Path): name = 'path' def convert(self, value, param, ctx): return pathlib.Path(super().convert(value, param, ctx)).absolute() class PortType(click.ParamType): name = 'port' def convert(self, value, param, ctx): if value == 
'auto': return 0 try: return int(value, 10) except TypeError: self.fail( "expected string for int() conversion, got " f"{value!r} of type {type(value).__name__}", param, ctx, ) except ValueError: self.fail(f"{value!r} is not a valid integer", param, ctx) class BackendCapabilitySet(click.ParamType): name = 'capability' def __init__(self): self.choices = { cap.name: cap for cap in pgsql_params.BackendCapabilities if cap.name != 'NONE' } def get_metavar(self, param): return " ".join(f'[[~]{cap}]' for cap in self.choices) def convert(self, value, param, ctx): must_be_present = [] must_be_absent = [] visited = set() for cap_str in value.split(): try: if cap_str.startswith("~"): cap = self.choices[cap_str[1:].upper()] must_be_absent.append(cap) else: cap = self.choices[cap_str.upper()] must_be_present.append(cap) if cap in visited: self.fail(f"duplicate capability: {cap_str}", param, ctx) else: visited.add(cap) except KeyError: self.fail( f"invalid capability: {cap_str}. " f"(choose from {', '.join(self.choices)})", param, ctx, ) return BackendCapabilitySets( must_be_present=must_be_present, must_be_absent=must_be_absent, ) class CompilerPoolModeChoice(click.Choice): def __init__(self): super().__init__( list(sorted( set(CompilerPoolMode.__members__.values()) - {CompilerPoolMode.Remote} )), ) def convert(self, value, param, ctx): if value == "remote": return CompilerPoolMode.Remote else: return super().convert(value, param, ctx) def _get_runstate_dir_default() -> str: runstate_dir: Optional[str] try: runstate_dir = buildmeta.get_build_metadata_value("RUNSTATE_DIR") except buildmeta.MetadataError: runstate_dir = None if runstate_dir is None: runstate_dir = '' return runstate_dir def _validate_max_backend_connections(ctx, param, value): if value is not None and value < defines.BACKEND_CONNECTIONS_MIN: raise click.BadParameter( f'the minimum number of backend connections ' f'is {defines.BACKEND_CONNECTIONS_MIN}') return value def compute_default_max_backend_connections() -> 
int: total_mem = psutil.virtual_memory().total total_mem_mb = total_mem // MIB if total_mem_mb <= 1024: return defines.BACKEND_CONNECTIONS_MIN else: return max( total_mem_mb // RAM_MIB_PER_CONN, defines.BACKEND_CONNECTIONS_MIN, ) def adjust_testmode_max_connections(max_conns): # Some test cases will start a second EdgeDB server (default # max_backend_connections=5), so we should reserve some backend # connections for that. This is ideally calculated upon the edb test -j # option, but that also depends on the total available memory. We are # hard-coding 15 reserved connections here for simplicity. return max(1, max_conns // 2, max_conns - 30) def _validate_compiler_pool_size(ctx, param, value): if value is not None and value < defines.BACKEND_COMPILER_POOL_SIZE_MIN: raise click.BadParameter( f'the minimum value for the compiler pool size option ' f'is {defines.BACKEND_COMPILER_POOL_SIZE_MIN}') return value def _validate_compiler_pool_host_port(ctx, param, value): if value is None: return None address = value.split(":", 1) if len(address) == 1: return address[0], defines.EDGEDB_REMOTE_COMPILER_PORT else: try: return address[0], int(address[1]) except ValueError: raise click.BadParameter(f'port must be int: {address[1]}') def compute_default_compiler_pool_size() -> int: total_mem = psutil.virtual_memory().total total_mem_mb = total_mem // MIB if total_mem_mb <= 1024: return defines.BACKEND_COMPILER_POOL_SIZE_MIN else: return max( psutil.cpu_count(logical=False) or 0, defines.BACKEND_COMPILER_POOL_SIZE_MIN, ) def _validate_tenant_id(ctx, param, value): if value is not None: if len(value) > schema_defines.MAX_TENANT_ID_LENGTH: raise click.BadParameter( f'cannot be longer than' f' {schema_defines.MAX_TENANT_ID_LENGTH} characters') if not value.isalnum() or not value.isascii(): raise click.BadParameter( f'contains invalid characters') return value def _status_sink_file(path: str) -> Callable[[str], None]: def _writer(status: str) -> None: try: with open(path, 'a') as f: 
print(status, file=f, flush=True) except OSError as e: logger.warning( f'could not write server status to {path!r}: {e.strerror}') except Exception as e: logger.warning( f'could not write server status to {path!r}: {e}') return _writer def _status_sink_fd(fileno: int) -> Callable[[str], None]: def _writer(status: str) -> None: try: with open(fileno, mode='a', closefd=False) as f: print(status, file=f, flush=True) except OSError as e: logger.warning( f'could not write server status to fd://{fileno!r}: ' f'{e.strerror}') except Exception as e: logger.warning( f'could not write server status to fd://{fileno!r}: {e}') return _writer def _validate_default_auth_method( ctx: click.Context, param: click.Option | click.Parameter, value: Any, ) -> ServerAuthMethods | None: if value is None: return None methods = dict(DEFAULT_AUTH_METHODS.items()) names = {v.lower(): v for v in ServerAuthMethod.__members__.values()} method = names.get(value.lower()) if method is not None: # Single auth method value. # # HTTP does not support SCRAM, but for backward compatibility # if SCRAM is passed explicitly, default HTTP to JWT. if method in {ServerAuthMethod.Auto, ServerAuthMethod.Scram}: pass else: for t in methods: # HTTP_METRICS and HTTP_HEALTH support only mTLS, but for # backward compatibility, default them to `auto` if unsupported # method is passed explicitly. if t in ( ServerConnTransport.HTTP_METRICS, ServerConnTransport.HTTP_HEALTH, ): if method not in ( ServerAuthMethod.Trust, ServerAuthMethod.mTLS, ): continue methods[t] = [method] elif "," not in value and ":" not in value: raise click.BadParameter( f"invalid authentication method: {value}, " f"supported values are: {', '.join(names)}" ) else: # Per-transport configuration. 
transport_specs = value.split(",") transport_names = { v.lower(): v for v in ServerConnTransport.__members__.values() } for transport_spec in transport_specs: transport_spec = transport_spec.strip() transport_name, _, method_names = transport_spec.partition(':') if not method_names: raise click.BadParameter( "format is <transport>:<method>[/method...][,...]") transport = transport_names.get(transport_name.lower()) if not transport: raise click.BadParameter( f"invalid connection transport: {transport_name}, " f"supported values are: {', '.join(transport_names)}" ) transport_methods = [] for method_name in method_names.split('/'): method = names.get(method_name) if not method: raise click.BadParameter( f"invalid authentication method: {method_name}, " f"supported values are: {', '.join(names)}" ) transport_methods.append(method) methods[transport] = transport_methods return ServerAuthMethods(methods) def oxford_comma(els: Sequence[str]) -> str: '''Who gives a fuck?''' assert els if len(els) == 1: return els[0] elif len(els) == 2: return f'{els[0]} and {els[1]}' else: return f'{", ".join(els[:-1])}, and {els[-1]}' class EnvvarResolver(click.Option): def resolve_envvar_value(self, ctx: click.Context): if self.envvar is None: return None if not isinstance(self.envvar, str): raise click.BadParameter( "expected a single envvar value but got multiple") file_var = f'{self.envvar}_FILE' alt_var = f'{self.envvar}_ENV' old_envvar = self.envvar.replace('GEL_', 'EDGEDB_') old_file_var = f'{old_envvar}_FILE' old_alt_var = f'{old_envvar}_ENV' vars_set = [] for var, old_var in [ (self.envvar, old_envvar), (file_var, old_file_var), (alt_var, old_alt_var), ]: if var in os.environ and old_var in os.environ: print( f"Warning: both {var} and {old_var} are specified. " f"{var} will take precedence." 
) if var in os.environ: vars_set.append(var) elif old_var in os.environ: vars_set.append(old_var) if len(vars_set) > 1: amt = "both" if len(vars_set) == 2 else "all" raise click.BadParameter( f'{oxford_comma(vars_set)} are exclusive, ' f'but {amt} are set.' ) var_val = os.environ.get(self.envvar) or os.environ.get(old_envvar) alt_var_val = os.environ.get(alt_var) or os.environ.get(old_alt_var) file_var_val = os.environ.get(file_var) or os.environ.get(old_file_var) if alt_var_val: var_val = os.environ.get(alt_var_val) if var_val: return var_val if file_var_val: try: with open(file_var_val, 'rt') as f: return f.read() except Exception as e: raise click.BadParameter( f'could not read the file specified by ' f'{file_var} ({file_var_val!r})') from e return None server_options = typeutils.chain_decorators([ click.option( '-D', '--data-dir', type=PathPath(), envvar="GEL_SERVER_DATADIR", cls=EnvvarResolver, help='database cluster directory'), click.option( '--postgres-dsn', type=str, hidden=True, help='[DEPRECATED] DSN of a remote Postgres cluster, if using one'), click.option( '--backend-dsn', type=str, envvar="GEL_SERVER_BACKEND_DSN", cls=EnvvarResolver, help='DSN of a remote backend cluster, if using one. ' 'Also supports HA clusters, for example: stolon+consul+http://' 'localhost:8500/test_cluster'), click.option( '--enable-backend-adaptive-ha', 'backend_adaptive_ha', is_flag=True, help='If backend adaptive HA is enabled, the Gel server will ' 'monitor the health of the backend cluster and shutdown all ' 'backend connections if threshold is reached, until reconnected ' 'again using the same DSN (HA should have updated the DNS ' 'value). Default is disabled.'), click.option( '--tenant-id', type=str, callback=_validate_tenant_id, envvar="GEL_SERVER_TENANT_ID", cls=EnvvarResolver, help='Specifies the tenant ID of this server when hosting' ' multiple Gel instances on one Postgres cluster.' 
' Must be an alphanumeric ASCII string, maximum' f' {schema_defines.MAX_TENANT_ID_LENGTH} characters long.', ), click.option( '--ignore-other-tenants', is_flag=True, help='If set, the server will ignore the presence of another tenant ' 'in the database instance in single-tenant mode instead of ' 'exiting with a catalog incompatibility error.' ), click.option( '--multitenant-config-file', type=PathPath(), metavar="PATH", envvar="GEL_SERVER_MULTITENANT_CONFIG_FILE", cls=EnvvarResolver, hidden=True, help='Start the server in multi-tenant mode, with reloadable tenants ' 'configured in the given file. Each tenant must have a unique ' 'SNI name as the key to route the traffic correctly, as well as ' 'a dedicated backend DSN to host the tenant data. See edb/server/' 'multitenant.py for config file format. All tenants share the ' 'same compiler pool, thus the same stdlib. So if any of the ' 'backends contains test-mode schema, the server should be ' 'started with --testmode to handle them properly.', ), click.option( '-l', '--log-level', envvar="GEL_SERVER_LOG_LEVEL", cls=EnvvarResolver, default='i', type=click.Choice( ['debug', 'd', 'info', 'i', 'warn', 'w', 'error', 'e', 'silent', 's'], case_sensitive=False, ), help=( 'Logging level. 
Possible values: (d)ebug, (i)nfo, (w)arn, ' '(e)rror, (s)ilent' )), click.option( '--log-to', help=('send logs to DEST, where DEST can be a file name, "syslog", ' 'or "stderr"'), type=str, metavar='DEST', default='stderr'), click.option( '--bootstrap', is_flag=True, hidden=True, help='[DEPRECATED] bootstrap the database cluster and exit'), click.option( '--bootstrap-only', is_flag=True, envvar="GEL_SERVER_BOOTSTRAP_ONLY", cls=EnvvarResolver, help='bootstrap the database cluster and exit'), click.option( '--inplace-upgrade-prepare', type=PathPath(), envvar="GEL_SERVER_INPLACE_UPGRADE_PREPARE", cls=EnvvarResolver, help='try to do an in-place upgrade with the specified dump file'), click.option( '--inplace-upgrade-rollback', type=bool, is_flag=True, envvar="GEL_SERVER_INPLACE_UPGRADE_ROLLBACK", cls=EnvvarResolver, help='rollback a prepared upgrade'), click.option( '--inplace-upgrade-finalize', type=bool, is_flag=True, envvar="GEL_SERVER_INPLACE_UPGRADE_FINALIZE", cls=EnvvarResolver, help='finalize an in-place upgrade'), click.option( '--default-branch', type=str, help='the name of the default branch to create'), click.option( '--default-database', type=str, hidden=True, help='[DEPRECATED] the name of the default database to create'), click.option( '--default-database-user', type=str, hidden=True, help='[DEPRECATED] the name of the default database owner'), click.option( '--bootstrap-command', metavar="QUERIES", envvar="GEL_SERVER_BOOTSTRAP_COMMAND", cls=EnvvarResolver, help='run the commands when initializing the database. ' 'Queries are executed by default user within default ' 'database. May be used with or without `--bootstrap-only`.'), click.option( '--bootstrap-command-file', type=PathPath(), metavar="PATH", help='run the script when initializing the database. ' 'Script run by default user within default database. 
' 'May be used with or without `--bootstrap-only`.'), click.option( '--bootstrap-script', type=PathPath(), help='[DEPRECATED] use --bootstrap-command-file instead.'), click.option( '--devmode/--no-devmode', help='enable or disable the development mode', default=None), click.option( '--testmode/--no-testmode', help='enable or disable the test mode', default=False), click.option( '-I', '--bind-address', type=str, multiple=True, envvar="GEL_SERVER_BIND_ADDRESS", cls=EnvvarResolver, help='IP addresses to listen on, specify multiple times for more than ' 'one address to listen on'), click.option( '-P', '--port', type=PortType(), default=None, envvar="GEL_SERVER_PORT", cls=EnvvarResolver, help='port to listen on'), click.option( '-b', '--background', is_flag=True, help='daemonize'), click.option( '--pidfile-dir', type=PathPath(), default=None, help='path to PID file directory, defaults to --runstate-dir'), click.option( '--daemon-user', type=int), click.option( '--daemon-group', type=int), click.option( '--runstate-dir', type=PathPath(), default=None, envvar="GEL_SERVER_RUNSTATE_DIR", cls=EnvvarResolver, help=f'directory where UNIX sockets and other temporary ' f'runtime files will be placed ({_get_runstate_dir_default()} ' f'by default)'), click.option( '--extensions-dir', type=PathPath(), default=(), multiple=True, envvar="GEL_SERVER_EXTENSIONS_DIR", cls=EnvvarResolver, help=f'directory where third-party extension packages are loaded from'), click.option( '--max-backend-connections', type=int, metavar='NUM', envvar="GEL_SERVER_MAX_BACKEND_CONNECTIONS", cls=EnvvarResolver, help=f'The maximum NUM of connections this Gel instance could make ' f'to the backend PostgreSQL cluster. 
If not set, Gel will ' f'detect and calculate the NUM: RAM/100MiB=' f'{compute_default_max_backend_connections()} for local ' f'Postgres or pg_settings.max_connections for remote Postgres, ' f'minus the NUM of --reserved-pg-connections.', callback=_validate_max_backend_connections), click.option( '--compiler-pool-size', type=int, metavar='NUM', envvar="GEL_SERVER_COMPILER_POOL_SIZE", cls=EnvvarResolver, callback=_validate_compiler_pool_size, help='Size of the compiler pool. When --compiler-pool-mode=fixed or ' 'fixed_multi_tenant, it is the NUM of compiler worker processes, ' f"defaults to {compute_default_compiler_pool_size()} (you'll see " '1 extra template process); for on_demand, it is the maximum NUM ' 'of workers the pool could scale up to, with the same default; ' 'for remote, it is the maximum NUM of concurrent requests to the ' 'remote compiler server, defaults to 2.' ), click.option( '--compiler-worker-branch-limit', type=int, metavar='NUM', default=5, envvar="GEL_SERVER_COMPILER_WORKER_BRANCH_LIMIT", cls=EnvvarResolver, help='The maximum NUM of branches each compiler worker could cache up ' 'to, default is 5. If the worker serves multiple tenants (as in ' '--compiler-pool-mode=fixed_multi_tenant or remote), this tenant ' 'on that worker will be able to cache up to NUM branches.' ), click.option( '--compiler-pool-mode', type=CompilerPoolModeChoice(), default=CompilerPoolMode.Default.value, envvar="GEL_SERVER_COMPILER_POOL_MODE", cls=EnvvarResolver, help='Choose a mode for the compiler pool to scale. "fixed" means the ' 'pool will not scale and sticks to --compiler-pool-size, while ' '"on_demand" means the pool will maintain at least 1 worker and ' 'automatically scale up (to --compiler-pool-size workers ) and ' 'down to the demand. 
Defaults to "fixed" in production mode and ' '"on_demand" in development mode.', ), click.option( '--compiler-pool-addr', hidden=True, callback=_validate_compiler_pool_host_port, envvar="GEL_SERVER_COMPILER_POOL_ADDR", cls=EnvvarResolver, help=f'Specify the host[:port] of the compiler pool to connect to, ' f'only used if --compiler-pool-mode=remote. Default host is ' f'localhost, port is {defines.EDGEDB_REMOTE_COMPILER_PORT}', ), click.option( "--compiler-pool-tenant-cache-size", hidden=True, type=int, default=20, envvar="GEL_SERVER_COMPILER_POOL_TENANT_CACHE_SIZE", cls=EnvvarResolver, help="Maximum number of tenants for which each compiler worker can " "cache their schemas, " "only used when --compiler-pool-mode=fixed_multi_tenant" ), click.option( '--echo-runtime-info', type=bool, default=False, is_flag=True, help='[DEPRECATED, use --emit-server-status] ' 'echo runtime info to stdout; the format is JSON, prefixed by ' '"EDGEDB_SERVER_DATA:", terminated by a newline'), click.option( '--emit-server-status', type=str, default=None, metavar='DEST', multiple=True, help='Instruct the server to emit changes in status to DEST, ' 'where DEST is a URI specifying a file (file://), ' 'or a file descriptor (fd://).
If the URI scheme ' 'is not specified, file:// is assumed.'), click.option( '--temp-dir', type=bool, default=False, is_flag=True, help='create a temporary database cluster directory ' 'that will be automatically purged on server shutdown'), click.option( '--auto-shutdown', type=bool, default=False, is_flag=True, hidden=True, help='shutdown the server after the last ' + 'connection is closed'), click.option( '--auto-shutdown-after', type=float, default=-1.0, metavar='N', help='shutdown the server if no client connections were made in the ' 'last N seconds, if N = 0, shut down after the last client has ' 'disconnected, N < 0 (default) means no auto shutdown'), click.option( '--tls-cert-file', type=PathPath(), envvar="GEL_SERVER_TLS_CERT_FILE", cls=EnvvarResolver, help='Specifies a path to a file containing a server TLS certificate ' 'in PEM format, as well as possibly any number of CA ' 'certificates needed to establish the certificate ' 'authenticity. If the file does not exist and the ' '--tls-cert-mode option is set to "generate_self_signed", a ' 'self-signed certificate will be automatically created in ' 'the specified path.'), click.option( '--tls-key-file', type=PathPath(), envvar="GEL_SERVER_TLS_KEY_FILE", cls=EnvvarResolver, help='Specifies a path to a file containing the private key in PEM ' 'format. If the file does not exist and the --tls-cert-mode ' 'option is set to "generate_self_signed", the private key will ' 'be automatically created in the specified path.'), click.option( '--tls-cert-mode', envvar="GEL_SERVER_TLS_CERT_MODE", cls=EnvvarResolver, type=click.Choice( ['default'] + list(ServerTlsCertMode.__members__.values()), case_sensitive=True, ), default='default', help='Specifies what to do when the TLS certificate and key are ' 'either not specified or are missing. When set to ' '"require_file", the TLS certificate and key must be specified ' 'in the --tls-cert-file and --tls-key-file options and both must ' 'exist. 
When set to "generate_self_signed" a new self-signed ' 'certificate and private key will be generated and placed in the ' 'path specified by --tls-cert-file/--tls-key-file, if those are ' 'set, otherwise the generated certificate and key are stored as ' f'`{TLS_CERT_FILE_NAME}` and `{TLS_KEY_FILE_NAME}` in the data ' 'directory, or, if the server is running with --backend-dsn, ' 'in a subdirectory of --runstate-dir.\n\nThe default is ' '"require_file" when the --security option is set to "strict", ' 'and "generate_self_signed" when the --security option is set to ' '"insecure_dev_mode"'), click.option( '--tls-client-ca-file', type=PathPath(), envvar='EDGEDB_SERVER_TLS_CLIENT_CA_FILE', cls=EnvvarResolver, help='Specifies a path to a file containing a TLS CA certificate to ' 'verify client certificates on demand. When set, the default ' 'authentication method of HTTP_METRICS(/metrics) and HTTP_HEALTH' '(/server/*) will also become "mTLS", unless explicitly set in ' '--default-auth-method. Note, the protection of such HTTP ' 'endpoints is only complete if --http-endpoint-security is also ' 'set to `tls`, or they are still accessible in plaintext HTTP.' ), click.option( '--generate-self-signed-cert', type=bool, default=False, is_flag=True, help='DEPRECATED.\n\n' 'Use --tls-cert-mode=generate_self_signed instead.'), click.option( '--binary-endpoint-security', envvar="GEL_SERVER_BINARY_ENDPOINT_SECURITY", cls=EnvvarResolver, type=click.Choice( ['default', 'tls', 'optional'], case_sensitive=True, ), default='default', help='Specifies the security mode of server binary endpoint. ' 'When set to `optional`, non-TLS connections are allowed. ' 'The default is `tls`.', ), click.option( '--http-endpoint-security', envvar="GEL_SERVER_HTTP_ENDPOINT_SECURITY", cls=EnvvarResolver, type=click.Choice( ['default', 'tls', 'optional'], case_sensitive=True, ), default='default', help='Specifies the security mode of server HTTP endpoint. 
' 'When set to `optional`, non-TLS connections are allowed. ' 'The default is `tls`.', ), click.option( '--security', envvar="GEL_SERVER_SECURITY", cls=EnvvarResolver, type=click.Choice( ['default', 'strict', 'insecure_dev_mode'], case_sensitive=True, ), default='default', help=( 'When set to `insecure_dev_mode`, sets the default ' 'authentication method to `Trust`, enables non-TLS ' 'client HTTP connections, and implies ' '`--tls-cert-mode=generate_self_signed`. The default is `strict`.' ), ), click.option( '--jws-key-file', type=PathPath(), envvar="GEL_SERVER_JWS_KEY_FILE", cls=EnvvarResolver, hidden=True, help='Specifies a path to a file containing a public key in PEM ' 'or JSON JWK format used to verify JWT signatures. The file may ' 'also contain a private key to sign JWT tokens for ' 'SCRAM-over-HTTP.'), click.option( '--jwe-key-file', type=PathPath(), hidden=True, help='Deprecated: no longer in use.'), click.option( '--jose-key-mode', envvar="GEL_SERVER_JOSE_KEY_MODE", cls=EnvvarResolver, type=click.Choice( ['default'] + list(JOSEKeyMode.__members__.values()), case_sensitive=True, ), hidden=True, default='default', help='Specifies what to do when the JOSE keys are either not ' 'specified or are missing. When set to "require_file", the JOSE ' 'keys must be specified in the --jws-key-file and the file must ' 'exist. 
When set to "generate", a new key pair will be ' 'generated and placed in the path specified by --jws-key-file, ' 'if it is set; otherwise the generated key pair is stored ' f'as `{JWS_KEY_FILE_NAME}` in the data directory, or, if the ' 'server is running with --backend-dsn, in a subdirectory of ' '--runstate-dir.\n\nThe default is "require_file" when the ' '--security option is set to "strict", and "generate" when the ' '--security option is set to "insecure_dev_mode"'), click.option( '--jwt-sub-allowlist-file', type=PathPath(), envvar="GEL_SERVER_JWT_SUB_ALLOWLIST_FILE", cls=EnvvarResolver, hidden=True, help='A file where the server can obtain a list of all JWT subjects ' 'that are allowed to access this instance. ' 'The file must contain one JWT "sub" claim value per line. ' 'Applies only to the JWT authentication method.' ), click.option( '--jwt-revocation-list-file', type=PathPath(), envvar="GEL_SERVER_JWT_REVOCATION_LIST_FILE", cls=EnvvarResolver, hidden=True, help='A file where the server can obtain a list of all JWT ids ' 'that have been revoked and may no longer access this instance. ' 'The file must contain one JWT "jti" claim value per line. ' 'Applies only to the JWT authentication method.' ), click.option( "--default-auth-method", envvar="GEL_SERVER_DEFAULT_AUTH_METHOD", cls=EnvvarResolver, callback=_validate_default_auth_method, type=str, help=( "The default authentication method to use when none is " "explicitly configured. Defaults to 'auto', which means " "the SCRAM authentication method for TCP connections and " "the JWT authentication method for HTTP-tunneled connections." ), ), click.option( "--readiness-state-file", envvar="GEL_SERVER_READINESS_STATE_FILE", cls=EnvvarResolver, type=PathPath(), help=( "Path to a file containing the value for server readiness state. " "When it contains 'not_ready' (without quotes), the server will " "refuse connections and the '/server/status/ready' check will " "return a 503 status.
Every other value, including absence of the " "file, indicates that the server is in the 'ready' state and " "can serve connections. The file can be modified when the " "server is running." ), ), click.option( '--instance-name', envvar="GEL_SERVER_INSTANCE_NAME", cls=EnvvarResolver, type=str, default=None, hidden=True, help='Server instance name.'), click.option( '--backend-capabilities', envvar="GEL_SERVER_BACKEND_CAPABILITIES", cls=EnvvarResolver, type=BackendCapabilitySet(), help="A space-separated set of backend capabilities, which are " "required to be present, or absent if prefixed with ~. Gel " "will only start if the actual backend capabilities match the " "specified set. However if the backend was never bootstrapped, " "the capabilities prefixed with ~ will be *disabled permanently* " "in Gel as if the backend never had them." ), click.option( '--version', is_flag=True, help='Show the version and exit.'), click.option( '--admin-ui', envvar="GEL_SERVER_ADMIN_UI", cls=EnvvarResolver, type=click.Choice( ['default', 'enabled', 'disabled'], case_sensitive=True, ), default='default', help='Enable admin UI.'), click.option( '--cors-always-allowed-origins', envvar="GEL_SERVER_CORS_ALWAYS_ALLOWED_ORIGINS", cls=EnvvarResolver, hidden=True, help='A comma-separated list of origins to always allow CORS requests ' 'from regardless of the `cors_allow_origins` config. The `*` ' 'character can be used as a wildcard. Intended for use by cloud ' 'to always allow the cloud UI to make requests to the instance.' ), click.option( '--disable-dynamic-system-config', is_flag=True, envvar="GEL_SERVER_DISABLE_DYNAMIC_SYSTEM_CONFIG", cls=EnvvarResolver, help="Disable dynamic configuration of system config values", ), click.option( "--reload-config-files", envvar="GEL_SERVER_RELOAD_CONFIG_FILES", cls=EnvvarResolver, type=click.Choice( list(ReloadTrigger.__members__.values()), case_sensitive=True ), hidden=True, default='default', help='Specifies when to reload the config files.
See the docstring of ' 'ReloadTrigger for more information.', ), click.option( "--net-worker-mode", envvar="GEL_SERVER_NET_WORKER_MODE", cls=EnvvarResolver, type=click.Choice( list(NetWorkerMode.__members__.values()), case_sensitive=True ), hidden=True, default='default', help='Controls how the std::net workers work.', ), click.option( "--config-file", type=PathPath(), metavar="PATH", envvar="GEL_SERVER_CONFIG_FILE", cls=EnvvarResolver, help='Path to a TOML file to configure the server.', hidden=True, ), click.option( '--compiler-worker-max-rss', type=int, envvar="GEL_SERVER_COMPILER_WORKER_MAX_RSS", cls=EnvvarResolver, help='Maximum allowed RSS (in bytes) per compiler worker process. Any ' 'worker exceeding this limit will be terminated and recreated. ' 'Each worker is free from this limit in its first 20-30 hours ' 'after spawn to avoid infinite restarts or a thundering herd.', ), ]) compiler_options = typeutils.chain_decorators([ click.option( "--pool-size", type=int, envvar="GEL_COMPILER_POOL_SIZE", cls=EnvvarResolver, callback=_validate_compiler_pool_size, default=compute_default_compiler_pool_size(), help=f"Number of compiler worker processes. Defaults to " f"{compute_default_compiler_pool_size()}.", ), click.option( "--client-schema-cache-size", type=int, envvar="GEL_COMPILER_POOL_TENANT_CACHE_SIZE", cls=EnvvarResolver, default=20, help="Maximum number of clients for which each worker can cache their " "schemas, The compiler server is not affected by this setting, " "it keeps pickled copies of schemas from all active clients " "(each capped by --compiler-worker-branch-limit of the client)." ), click.option( '-I', '--listen-addresses', type=str, multiple=True, envvar="GEL_COMPILER_BIND_ADDRESS", cls=EnvvarResolver, default=('localhost',), help='IP addresses to listen on, specify multiple times for more than ' 'one address to listen on. 
Default: localhost', ), click.option( '-P', '--listen-port', type=PortType(), envvar="GEL_COMPILER_SERVER_PORT", cls=EnvvarResolver, help=f'Port to listen on. ' f'Default: {defines.EDGEDB_REMOTE_COMPILER_PORT}', ), click.option( '--runstate-dir', type=PathPath(), default=None, envvar="GEL_COMPILER_RUNSTATE_DIR", cls=EnvvarResolver, help="Directory to store UNIX domain socket file for IPC, a temporary " "directory will be used if not specified.", ), click.option( '--metrics-port', type=PortType(), envvar="GEL_COMPILER_METRICS_PORT", cls=EnvvarResolver, help=f'Port to listen on for metrics HTTP API.', ), click.option( '--worker-max-rss', type=int, envvar="GEL_COMPILER_WORKER_MAX_RSS", cls=EnvvarResolver, help='Maximum allowed RSS (in bytes) per worker process. Any worker ' 'exceeding this limit will be terminated and recreated. ' 'Each worker is free from this limit in its first 20-30 hours ' 'after spawn to avoid infinite restarts or a thundering herd.', ), ]) def parse_args(**kwargs: Any): kwargs['bind_addresses'] = kwargs.pop('bind_address') if kwargs['echo_runtime_info']: warnings.warn( "The `--echo-runtime-info` option is deprecated, use " "`--emit-server-status` instead.", DeprecationWarning, stacklevel=2, ) if kwargs['bootstrap']: warnings.warn( "Option `--bootstrap` is deprecated, use `--bootstrap-only`", DeprecationWarning, stacklevel=2, ) kwargs['bootstrap_only'] = True kwargs.pop('bootstrap', False) if kwargs['default_database_user']: if kwargs['default_database_user'] == 'edgedb': warnings.warn( "Option `--default-database-user` is deprecated." " Role `edgedb` is always created and" " no role named after unix user is created any more.", DeprecationWarning, stacklevel=2, ) else: warnings.warn( "Option `--default-database-user` is deprecated." " Please create the role explicitly.", DeprecationWarning, stacklevel=2, ) if kwargs['default_database']: if kwargs['default_database'] == 'edgedb': warnings.warn( "Option `--default-database` is deprecated." 
" Database `edgedb` is always created and" " no database named after unix user is created any more.", DeprecationWarning, stacklevel=2, ) else: warnings.warn( "Option `--default-database` is deprecated." " Please create the database explicitly.", DeprecationWarning, stacklevel=2, ) if kwargs['auto_shutdown']: warnings.warn( "The `--auto-shutdown` option is deprecated, use " "`--auto-shutdown-after` instead.", DeprecationWarning, stacklevel=2, ) if kwargs['auto_shutdown_after'] < 0: kwargs['auto_shutdown_after'] = 0 del kwargs['auto_shutdown'] if kwargs['postgres_dsn']: warnings.warn( "The `--postgres-dsn` option is deprecated, use " "`--backend-dsn` instead.", DeprecationWarning, stacklevel=2, ) if not kwargs['backend_dsn']: kwargs['backend_dsn'] = kwargs['postgres_dsn'] del kwargs['postgres_dsn'] if kwargs['generate_self_signed_cert']: warnings.warn( "The `--generate-self-signed-cert` option is deprecated, use " "`--tls-cert-mode=generate_self_signed` instead.", DeprecationWarning, stacklevel=2, ) if kwargs['tls_cert_mode'] == 'default': kwargs['tls_cert_mode'] = 'generate_self_signed' del kwargs['generate_self_signed_cert'] if os.environ.get('EDGEDB_SERVER_ALLOW_INSECURE_BINARY_CLIENTS') == "1": if kwargs['binary_endpoint_security'] == "tls": abort( "The value of deprecated " "EDGEDB_SERVER_ALLOW_INSECURE_BINARY_CLIENTS environment " "variable disagrees with --binary-endpoint-security" ) else: if kwargs['binary_endpoint_security'] == "default": warnings.warn( "EDGEDB_SERVER_ALLOW_INSECURE_BINARY_CLIENTS is " "deprecated. 
Use EDGEDB_SERVER_BINARY_ENDPOINT_SECURITY " "instead.", DeprecationWarning, stacklevel=2, ) kwargs['binary_endpoint_security'] = 'optional' if os.environ.get('EDGEDB_SERVER_ALLOW_INSECURE_HTTP_CLIENTS') == "1": if kwargs['http_endpoint_security'] == "tls": abort( "The value of deprecated " "EDGEDB_SERVER_ALLOW_INSECURE_HTTP_CLIENTS environment " "variable disagrees with --http-endpoint-security" ) else: if kwargs['http_endpoint_security'] == "default": warnings.warn( "EDGEDB_SERVER_ALLOW_INSECURE_HTTP_CLIENTS is " "deprecated. Use EDGEDB_SERVER_HTTP_ENDPOINT_SECURITY " "instead.", DeprecationWarning, stacklevel=2, ) kwargs['http_endpoint_security'] = 'optional' if kwargs['security'] == 'default': if devmode.is_in_dev_mode(): kwargs['security'] = 'insecure_dev_mode' else: kwargs['security'] = 'strict' if kwargs['security'] == 'insecure_dev_mode': if kwargs['http_endpoint_security'] == 'default': kwargs['http_endpoint_security'] = 'optional' if not kwargs['default_auth_method']: kwargs['default_auth_method'] = ServerAuthMethods({ t: [ServerAuthMethod.Trust] for t in ServerConnTransport.__members__.values() }) if kwargs['tls_cert_mode'] == 'default': kwargs['tls_cert_mode'] = 'generate_self_signed' elif not kwargs['default_auth_method']: kwargs['default_auth_method'] = DEFAULT_AUTH_METHODS transport_methods = dict(kwargs['default_auth_method'].items()) for transport in ServerConnTransport.__members__.values(): methods = transport_methods[transport] if ServerAuthMethod.Auto in methods: pos = methods.index(ServerAuthMethod.Auto) if transport in ( ServerConnTransport.HTTP_METRICS, ServerConnTransport.HTTP_HEALTH, ): if kwargs['tls_client_ca_file'] is None: method = ServerAuthMethod.Trust else: method = ServerAuthMethod.mTLS methods[pos] = method else: methods = ( methods[:pos] + DEFAULT_AUTH_METHODS.get(transport) + methods[pos + 1:] ) transport_methods[transport] = methods elif transport in ( ServerConnTransport.HTTP_METRICS, ServerConnTransport.HTTP_HEALTH, ): if
ServerAuthMethod.mTLS in methods: if kwargs['tls_client_ca_file'] is None: abort('--tls-client-ca-file is required ' 'for mTLS authentication') if not all( m is ServerAuthMethod.Trust or m is ServerAuthMethod.mTLS for m in methods ): abort( f'--default-auth-method of {transport} can only be one ' f'of: {ServerAuthMethod.Trust}, {ServerAuthMethod.mTLS} ' f'or {ServerAuthMethod.Auto}' ) kwargs['default_auth_method'] = ServerAuthMethods(transport_methods) if kwargs['binary_endpoint_security'] == 'default': kwargs['binary_endpoint_security'] = 'tls' if kwargs['http_endpoint_security'] == 'default': kwargs['http_endpoint_security'] = 'tls' if kwargs['tls_cert_mode'] == 'default': kwargs['tls_cert_mode'] = 'require_file' if kwargs['jose_key_mode'] == 'default': kwargs['jose_key_mode'] = 'generate' kwargs['security'] = ServerSecurityMode(kwargs['security']) kwargs['binary_endpoint_security'] = ServerEndpointSecurityMode( kwargs['binary_endpoint_security']) kwargs['http_endpoint_security'] = ServerEndpointSecurityMode( kwargs['http_endpoint_security']) kwargs['tls_cert_mode'] = ServerTlsCertMode(kwargs['tls_cert_mode']) kwargs['jose_key_mode'] = JOSEKeyMode(kwargs['jose_key_mode']) if kwargs['compiler_pool_mode'] == 'default': if kwargs['multitenant_config_file']: kwargs['compiler_pool_mode'] = 'fixed_multi_tenant' elif devmode.is_in_dev_mode(): kwargs['compiler_pool_mode'] = 'on_demand' else: kwargs['compiler_pool_mode'] = 'fixed' kwargs['compiler_pool_mode'] = CompilerPoolMode( kwargs['compiler_pool_mode'] ) if kwargs['compiler_pool_size'] is None: if kwargs['compiler_pool_mode'] == CompilerPoolMode.Remote: # this reflects to a local semaphore to control concurrency, # 2 means this is a small EdgeDB instance that could only issue # at max 2 concurrent compile requests at a time. 
kwargs['compiler_pool_size'] = 2 else: kwargs['compiler_pool_size'] = compute_default_compiler_pool_size() if kwargs['compiler_pool_mode'] == CompilerPoolMode.Remote: if kwargs['compiler_pool_addr'] is None: kwargs['compiler_pool_addr'] = ( "localhost", defines.EDGEDB_REMOTE_COMPILER_PORT ) if kwargs['compiler_worker_max_rss'] is not None: abort('cannot set --compiler-worker-max-rss when using ' '--compiler-pool-mode=remote') elif kwargs['compiler_pool_addr'] is not None: abort('--compiler-pool-addr is only meaningful ' 'under --compiler-pool-mode=remote') if kwargs['temp_dir']: if kwargs['data_dir']: abort('--temp-dir is incompatible with --data-dir/-D') if kwargs['runstate_dir']: abort('--temp-dir is incompatible with --runstate-dir') if kwargs['backend_dsn']: abort('--temp-dir is incompatible with --backend-dsn') if kwargs['multitenant_config_file']: abort('--temp-dir is incompatible with --multitenant-config-file') kwargs['data_dir'] = kwargs['runstate_dir'] = pathlib.Path( tempfile.mkdtemp()) else: if not kwargs['data_dir']: if kwargs['backend_dsn'] or kwargs['multitenant_config_file']: pass elif devmode.is_in_dev_mode(): data_dir = devmode.get_dev_mode_data_dir() if not data_dir.parent.exists(): data_dir.parent.mkdir(exist_ok=True, parents=True) kwargs["data_dir"] = data_dir else: abort('Please specify the instance data directory ' 'using the -D argument or the address of a remote ' 'backend cluster using the --backend-dsn argument') elif kwargs['backend_dsn']: abort('The -D and --backend-dsn options are mutually exclusive.') elif kwargs['multitenant_config_file']: abort('The -D and --multitenant-config-file options ' 'are mutually exclusive.') if kwargs['tls_key_file'] and not kwargs['tls_cert_file']: abort('When --tls-key-file is set, --tls-cert-file must also be set.') if kwargs['tls_cert_file'] and not kwargs['tls_key_file']: abort('When --tls-cert-file is set, --tls-key-file must also be set.') self_signing = kwargs['tls_cert_mode'] is 
ServerTlsCertMode.SelfSigned if not kwargs['tls_cert_file']: if kwargs['data_dir']: tls_cert_file = kwargs['data_dir'] / TLS_CERT_FILE_NAME tls_key_file = kwargs['data_dir'] / TLS_KEY_FILE_NAME elif self_signing: tls_cert_file = pathlib.Path('') / TLS_CERT_FILE_NAME tls_key_file = pathlib.Path('') / TLS_KEY_FILE_NAME else: abort( "no TLS certificate specified and certificate auto-generation" " has not been requested; see help for --tls-cert-mode", exit_code=10, ) kwargs['tls_cert_file'] = tls_cert_file kwargs['tls_key_file'] = tls_key_file if not kwargs['bootstrap_only'] and not self_signing: if not kwargs['tls_cert_file'].exists(): abort( f"TLS certificate file \"{kwargs['tls_cert_file']}\"" " does not exist and certificate auto-generation has not been" " requested; see help for --tls-cert-mode", exit_code=10, ) if ( kwargs['tls_cert_file'].exists() and not kwargs['tls_cert_file'].is_file() ): abort( f"TLS certificate file \"{kwargs['tls_cert_file']}\"" " is not a regular file" ) if ( kwargs['tls_key_file'].exists() and not kwargs['tls_key_file'].is_file() ): abort( f"TLS private key file \"{kwargs['tls_key_file']}\"" " is not a regular file" ) generate_jose = kwargs['jose_key_mode'] is JOSEKeyMode.Generate if not kwargs['jws_key_file']: if kwargs['data_dir']: jws_key_file = kwargs['data_dir'] / JWS_KEY_FILE_NAME elif generate_jose: jws_key_file = pathlib.Path('') / JWS_KEY_FILE_NAME else: abort( "no JWS key specified and JOSE keys auto-generation" " has not been requested; see help for --jose-key-mode", exit_code=11, ) kwargs['jws_key_file'] = jws_key_file del kwargs['jwe_key_file'] if not kwargs['bootstrap_only'] and not generate_jose: if not kwargs['jws_key_file'].exists(): abort( f"JWS key file \"{kwargs['jws_key_file']}\" does not exist" ) if ( kwargs['jws_key_file'].exists() and not kwargs['jws_key_file'].is_file() ): abort( f"JWT key file \"{kwargs['jws_key_file']}\"" " is not a regular file" ) if kwargs['log_level']: kwargs['log_level'] = 
kwargs['log_level'].lower()[0] if kwargs['bootstrap_script']: if not kwargs['bootstrap_command_file']: warnings.warn( "The `--bootstrap-script` option is deprecated, use " "`--bootstrap-command-file` instead.", DeprecationWarning, stacklevel=2, ) kwargs['bootstrap_command_file'] = kwargs['bootstrap_script'] else: warnings.warn( "Both `--bootstrap-command-file` and `--bootstrap-script` " "were specified, but are mutually exclusive. " "Ignoring the deprecated `--bootstrap-script` option.", DeprecationWarning, stacklevel=2, ) del kwargs['bootstrap_script'] if kwargs['multitenant_config_file']: for name in ( "tenant_id", "backend_dsn", "backend_adaptive_ha", "bootstrap_only", "inplace_upgrade", "bootstrap_command", "bootstrap_command_file", "instance_name", "max_backend_connections", "readiness_state_file", "jwt_sub_allowlist_file", "jwt_revocation_list_file", "config_file", ): if kwargs.get(name): opt = "--" + name.replace("_", "-") abort(f"The {opt} and --multitenant-config-file options " f"are mutually exclusive.") if kwargs['compiler_pool_mode'] is not CompilerPoolMode.MultiTenant: abort("must use --compiler-pool-mode=fixed_multi_tenant " "in multi-tenant mode") bootstrap_script_text: Optional[str] if kwargs['bootstrap_command_file']: with open(kwargs['bootstrap_command_file']) as f: bootstrap_script_text = f.read() elif kwargs['bootstrap_command']: bootstrap_script_text = kwargs['bootstrap_command'] else: bootstrap_script_text = None if bootstrap_script_text is None: startup_script = None else: startup_script = StartupScript( text=bootstrap_script_text, database=( kwargs['default_branch'] or kwargs['default_database'] or defines.EDGEDB_SUPERUSER_DB ), user=( kwargs['default_database_user'] or defines.EDGEDB_SUPERUSER ), ) status_sinks = [] if status_sink_addrs := kwargs['emit_server_status']: for status_sink_addr in status_sink_addrs: if status_sink_addr.startswith('file://'): status_sink = _status_sink_file( status_sink_addr[len('file://'):]) elif 
status_sink_addr.startswith('fd://'): fileno_str = status_sink_addr[len('fd://'):] try: fileno = int(fileno_str) except ValueError: abort( f'invalid file descriptor number in ' f'--emit-server-status: {fileno_str!r}' ) status_sink = _status_sink_fd(fileno) elif m := re.match(r'^(\w+)://', status_sink_addr): abort( f'unsupported destination scheme in --emit-server-status: ' f'{m.group(1)}' ) else: # Assume it's a file. status_sink = _status_sink_file(status_sink_addr) status_sinks.append(status_sink) kwargs['backend_capability_sets'] = ( kwargs.pop('backend_capabilities') or BackendCapabilitySets([], []) ) if kwargs['admin_ui'] == 'default': if devmode.is_in_dev_mode(): kwargs['admin_ui'] = 'enabled' else: kwargs['admin_ui'] = 'disabled' kwargs['admin_ui'] = kwargs['admin_ui'] == 'enabled' if not kwargs['instance_name']: if devmode.is_in_dev_mode(): kwargs['instance_name'] = '_localdev' else: kwargs['instance_name'] = '_unknown' kwargs['reload_config_files'] = ReloadTrigger( kwargs['reload_config_files'] ) kwargs['net_worker_mode'] = NetWorkerMode(kwargs['net_worker_mode']) for disallowed, replacement in ( ( 'EDGEDB_SERVER_CONFIG_cfg::listen_addresses', 'GEL_SERVER_BIND_ADDRESS', ), ( 'EDGEDB_SERVER_CONFIG_cfg::listen_port', 'GEL_SERVER_PORT', ), ( 'GEL_SERVER_CONFIG_cfg::listen_addresses', 'GEL_SERVER_BIND_ADDRESS', ), ( 'GEL_SERVER_CONFIG_cfg::listen_port', 'GEL_SERVER_PORT', ), ): if disallowed in os.environ: abort(f"{disallowed} is disallowed; use {replacement} instead") return ServerConfig( startup_script=startup_script, status_sinks=status_sinks, **kwargs, ) ================================================ FILE: edb/server/auth.py ================================================ import datetime import pathlib from typing import TYPE_CHECKING, Iterable, Optional, Any if TYPE_CHECKING: class SigningCtx: def __init__(self) -> None: ... def set_issuer(self, issuer: str) -> None: ... def set_audience(self, audience: str) -> None: ... 
def set_expiry(self, expiry: int) -> None: ... def set_not_before(self, not_before: int) -> None: ... class ValidationCtx: def __init__(self) -> None: ... def allow( self, claim: str, values: list[str] | Iterable[str], ) -> None: ... def deny( self, claim: str, values: list[str] | Iterable[str], ) -> None: ... def require(self, claim: str) -> None: ... def reject(self, claim: str) -> None: ... def ignore(self, claim: str) -> None: ... def require_expiry(self) -> None: ... def ignore_expiry(self) -> None: ... class JWKSet: @staticmethod def from_hs256_key(key: bytes) -> "JWKSet": ... def __init__(self) -> None: ... def generate(self, *, kid: Optional[str], kty: str) -> None: ... def add(self, **kwargs: Any) -> None: ... def load(self, keys: str) -> int: ... def load_json(self, keys: str) -> int: ... def export_pem(self, *, private_keys: bool=True) -> bytes: ... def export_json(self, *, private_keys: bool=True) -> bytes: ... def can_sign(self) -> bool: ... def can_validate(self) -> bool: ... def has_public_keys(self) -> bool: ... def has_private_keys(self) -> bool: ... def has_symmetric_keys(self) -> bool: ... def sign( self, claims: dict[str, Any], *, ctx: Optional[SigningCtx] = None ) -> str: ... def validate( self, token: str, *, ctx: Optional[ValidationCtx] = None ) -> dict[str, Any]: ... @property def default_signing_context(self) -> SigningCtx: ... @property def default_validation_context(self) -> ValidationCtx: ... class JWKSetCache: def __init__(self, expiry_seconds: int) -> None: ... # Returns a tuple of (is_fresh, registry) def get(self, key: str) -> tuple[bool, Optional[JWKSet]]: ... def set(self, key: str, registry: JWKSet) -> None: ... def generate_gel_token( registry: JWKSet, *, instances: Optional[list[str] | Iterable[str]] = None, roles: Optional[list[str] | Iterable[str]] = None, databases: Optional[list[str] | Iterable[str]] = None, **kwargs: Any, ) -> str: ... 
def validate_gel_token( registry: JWKSet, token: str, user: str, dbname: str, instance_name: str, ) -> str | None: ... else: from edb.server._rust_native._jwt import ( JWKSet, JWKSetCache, generate_gel_token, validate_gel_token, SigningCtx, ValidationCtx # noqa ) def load_secret_key(key_file: pathlib.Path) -> JWKSet: try: with open(key_file, 'rb') as kf: jws_key = JWKSet() jws_key.load(kf.read().decode('ascii')) except Exception as e: raise SecretKeyReadError(f"cannot load JWS key {key_file}: {e}") from e if not jws_key.can_validate(): raise SecretKeyReadError( f"the cluster JWS key file {key_file} does not " f"contain a valid key for token validation (RSA, EC or " f"HMAC-SHA256)") # TODO: We should also add a default issuer and add that to the allow-list. # Default to one day expiry for tokens -- we will probably tighten this up jws_key.default_signing_context.set_expiry(86400) # 60 second leeway for not before jws_key.default_signing_context.set_not_before(60) return jws_key def generate_jwk(keys_file: pathlib.Path) -> None: key = JWKSet() # kid is yyyymmdd kid = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d") key.generate(kid=kid, kty='ES256') if keys_file.name.endswith(".pem"): with keys_file.open("wb") as f: f.write(key.export_pem()) elif keys_file.name.endswith(".json"): with keys_file.open("wb") as f: f.write(key.export_json()) else: raise ValueError(f"Unsupported key file extension {keys_file.suffix}. " "Use .pem or .json extension when generating a key.") keys_file.chmod(0o600) class SecretKeyReadError(Exception): pass ================================================ FILE: edb/server/bootstrap.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Callable, Optional, Iterable, Mapping, Awaitable, NamedTuple, TYPE_CHECKING, cast, ) import dataclasses import enum import json import logging import os import pathlib import pickle import re import struct import textwrap from edb import buildmeta from edb import errors from edb import edgeql from edb.ir import typeutils as irtyputils from edb.edgeql import ast as qlast from edb.edgeql import codegen as qlcodegen from edb.edgeql import qltypes from edb.common import debug from edb.common import devmode from edb.common import retryloop from edb.common import uuidgen from edb.schema import ddl as s_ddl from edb.schema import delta as sd from edb.schema import extensions as s_exts from edb.schema import functions as s_func from edb.schema import modules as s_mod from edb.schema import name as sn from edb.schema import objects as s_obj from edb.schema import properties as s_props from edb.schema import reflection as s_refl from edb.schema import schema as s_schema from edb.schema import std as s_std from edb.schema import types as s_types from edb.schema import utils as s_utils from edb.schema import version as s_ver from edb.server import args as edbargs from edb.server import config from edb.server import compiler as edbcompiler from edb.server.compiler import dbstate from edb.server import defines as edbdef from edb.server import pgcluster from edb.server import pgcon from edb.pgsql import common as pg_common from edb.pgsql import dbops from edb.pgsql import delta as delta_cmds from edb.pgsql import 
metaschema from edb.pgsql import params from edb.pgsql import patches from edb.pgsql import trampoline from edb.pgsql.common import quote_ident as qi from edgedb import scram if TYPE_CHECKING: import uuid logger = logging.getLogger('edb.server') STDLIB_CACHE_FILE_NAME = 'backend-stdlib.pickle' class ClusterMode(enum.IntEnum): pristine = 0 regular = 1 single_role = 2 single_database = 3 # A simple connection proxy that reconnects and retries queries # on connection errors. Helps defeat flaky connections and/or # flaky Postgres servers (Digital Ocean managed instances are # one example that has a weird setup that crashes a helper # process when we bootstrap, breaking other connections). class PGConnectionProxy: def __init__( self, cluster: pgcluster.BaseCluster, *, source_description: str, dbname: Optional[str] = None, log_listener: Optional[Callable[[str, str], None]] = None, ): self._conn: Optional[pgcon.PGConnection] = None self._cluster = cluster self._dbname = dbname self._log_listener = log_listener or _pg_log_listener self._source_description = source_description async def connect(self) -> None: if self._conn is not None: self._conn.terminate() if self._dbname: self._conn = await self._cluster.connect( source_description=self._source_description, database=self._dbname ) else: self._conn = await self._cluster.connect( source_description=self._source_description ) if self._log_listener is not None: self._conn.add_log_listener(self._log_listener) def _on_retry(self, exc: Optional[BaseException]) -> None: logger.warning( f'Retrying bootstrap SQL query due to connection error: ' f'{type(exc)}({exc})', ) self.terminate() async def _retry_conn_errors[T]( self, task: Callable[[], Awaitable[T]], ) -> T: rloop = retryloop.RetryLoop( backoff=retryloop.exp_backoff(), timeout=5.0, ignore=( ConnectionError, pgcon.BackendConnectionError, ), retry_cb=self._on_retry, ) async for iteration in rloop: async with iteration: if self._conn is None: await self.connect() result = 
await task() return result async def sql_execute(self, sql: bytes) -> None: async def _task() -> None: assert self._conn is not None await self._conn.sql_execute(sql) return await self._retry_conn_errors(_task) async def sql_fetch( self, sql: bytes, *, args: tuple[bytes, ...] | list[bytes] = (), ) -> list[tuple[bytes, ...]]: async def _task() -> list[tuple[bytes, ...]]: assert self._conn is not None return await self._conn.sql_fetch(sql, args=args) return await self._retry_conn_errors(_task) async def sql_fetch_val( self, sql: bytes, *, args: tuple[bytes, ...] | list[bytes] = (), ) -> bytes: async def _task() -> bytes: assert self._conn is not None return await self._conn.sql_fetch_val(sql, args=args) return await self._retry_conn_errors(_task) async def sql_fetch_col( self, sql: bytes, *, args: tuple[bytes, ...] | list[bytes] = (), ) -> list[bytes]: async def _task() -> list[bytes]: assert self._conn is not None return await self._conn.sql_fetch_col(sql, args=args) return await self._retry_conn_errors(_task) def terminate(self) -> None: if self._conn is not None: self._conn.terminate() self._conn = None @dataclasses.dataclass class BootstrapContext: cluster: pgcluster.BaseCluster conn: PGConnectionProxy | pgcon.PGConnection args: edbargs.ServerConfig mode: Optional[ClusterMode] = None async def _execute(conn, query): return await metaschema.execute_sql_script(conn, query) async def _execute_block(conn, block: dbops.SQLBlock) -> None: if not block.is_transactional(): stmts = block.get_statements() else: stmts = [block.to_string()] for stmt in stmts: await _execute(conn, stmt) def _execute_edgeql_ddl[Schema_T: s_schema.Schema]( schema: Schema_T, ddltext: str, stdmode: bool = True, ) -> Schema_T: context = sd.CommandContext(stdmode=stdmode) for ddl_cmd in edgeql.parse_block(ddltext): assert isinstance(ddl_cmd, qlast.DDLCommand) delta_command = s_ddl.delta_from_ddl( ddl_cmd, modaliases={}, schema=schema, stdmode=stdmode) schema = delta_command.apply(schema, context) # 
type: ignore return schema async def _ensure_edgedb_supergroup( ctx: BootstrapContext, role_name: str, *, member_of: Iterable[str] = (), members: Iterable[str] = (), ) -> None: member_of = set(member_of) backend_params = ctx.cluster.get_runtime_params() superuser_role = backend_params.instance_params.base_superuser if superuser_role: # If the cluster is exposing an explicit superuser role, # become a member of that instead of creating a superuser # role directly. member_of.add(superuser_role) pg_role_name = ctx.cluster.get_role_name(role_name) role = dbops.Role( name=pg_role_name, superuser=backend_params.has_superuser_access, allow_login=False, allow_createdb=True, allow_createrole=True, membership=member_of, members=members, ) create_role = dbops.CreateRole( role, neg_conditions=[dbops.RoleExists(pg_role_name)], ) block = dbops.PLTopBlock() create_role.generate(block) await _execute_block(ctx.conn, block) async def _ensure_edgedb_role( ctx: BootstrapContext, role_name: str, *, superuser: bool = False, builtin: bool = False, objid: Optional[uuid.UUID] = None, ) -> uuid.UUID: member_of = set() if superuser: member_of.add(edbdef.EDGEDB_SUPERGROUP) if objid is None: objid = uuidgen.uuid1mc() members = set() login_role = ctx.cluster.get_connection_params().user assert login_role is not None sup_role = ctx.cluster.get_role_name(edbdef.EDGEDB_SUPERUSER) if login_role != sup_role: members.add(login_role) backend_params = ctx.cluster.get_runtime_params() pg_role_name = ctx.cluster.get_role_name(role_name) role = dbops.Role( name=pg_role_name, superuser=superuser and backend_params.has_superuser_access, allow_login=True, allow_createdb=True, allow_createrole=True, membership=[ctx.cluster.get_role_name(m) for m in member_of], members=members, metadata=dict( id=str(objid), name=role_name, tenant_id=backend_params.tenant_id, builtin=builtin, branches=['*'], ), ) create_role = dbops.CreateRole( role, neg_conditions=[dbops.RoleExists(pg_role_name)], ) block = dbops.PLTopBlock() 
create_role.generate(block) await _execute_block(ctx.conn, block) return objid async def _get_cluster_mode(ctx: BootstrapContext) -> ClusterMode: backend_params = ctx.cluster.get_runtime_params() tenant_id = backend_params.tenant_id # First, check for the existence of EDGEDB_SUPERGROUP, the role that is # usually created at the beginning of bootstrap. is_default_tenant = tenant_id == buildmeta.get_default_tenant_id() ignore_others = is_default_tenant and ctx.args.ignore_other_tenants if is_default_tenant: result = await ctx.conn.sql_fetch_col( b""" SELECT r.rolname FROM pg_catalog.pg_roles AS r WHERE r.rolname LIKE ('%' || $1) """, args=[ edbdef.EDGEDB_SUPERGROUP.encode("utf-8"), ], ) else: result = await ctx.conn.sql_fetch_col( b""" SELECT r.rolname FROM pg_catalog.pg_roles AS r WHERE r.rolname = $1 """, args=[ ctx.cluster.get_role_name( edbdef.EDGEDB_SUPERGROUP).encode("utf-8"), ], ) if result: if not ignore_others: # Either our tenant slot is occupied, or there is # a default tenant present. return ClusterMode.regular # We were explicitly asked to ignore the other default tenant, # so check specifically if our tenant slot is occupied and ignore # the others. # This mode is used for in-place upgrade. for rolname in result: other_tenant_id = rolname[: -(len(edbdef.EDGEDB_SUPERGROUP) + 1)] if other_tenant_id == tenant_id.encode("utf-8"): return ClusterMode.regular # Then, check if the current database was bootstrapped in single-db mode. has_instdata = await ctx.conn.sql_fetch_val( trampoline.fixup_query(''' SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname = 'edgedbinstdata_VER' AND tablename = 'instdata' ''').encode('utf-8'), ) if has_instdata: return ClusterMode.single_database # Finally, check for a single-role-bootstrapped instance by trying to find # the Gel System DB with the assumption that we are not running in # single-db mode. If not found, this is a pristine backend cluster. 
if is_default_tenant: result = await ctx.conn.sql_fetch_col( b''' SELECT datname FROM pg_database WHERE datname LIKE '%' || $1 ''', args=( edbdef.EDGEDB_SYSTEM_DB.encode("utf-8"), ), ) else: result = await ctx.conn.sql_fetch_col( b''' SELECT datname FROM pg_database WHERE datname = $1 ''', args=( ctx.cluster.get_db_name( edbdef.EDGEDB_SYSTEM_DB).encode("utf-8"), ), ) if result: if not ignore_others: # Either our tenant slot is occupied, or there is # a default tenant present. return ClusterMode.single_role # We were explicitly asked to ignore the other default tenant, # so check specifically if our tenant slot is occupied and ignore # the others. # This mode is used for in-place upgrade. for dbname in result: other_tenant_id = dbname[: -(len(edbdef.EDGEDB_SYSTEM_DB) + 1)] if other_tenant_id == tenant_id.encode("utf-8"): return ClusterMode.single_role return ClusterMode.pristine async def _create_edgedb_template_database( ctx: BootstrapContext, ) -> uuid.UUID: backend_params = ctx.cluster.get_runtime_params() have_c_utf8 = backend_params.has_c_utf8_locale logger.info('Creating template database...') block = dbops.SQLBlock() dbid = uuidgen.uuid1mc() db = dbops.Database( ctx.cluster.get_db_name(edbdef.EDGEDB_TEMPLATE_DB), owner=ctx.cluster.get_role_name(edbdef.EDGEDB_SUPERUSER), is_template=True, lc_collate='C', lc_ctype='C.UTF-8' if have_c_utf8 else 'en_US.UTF-8', encoding='UTF8', metadata=dict( id=str(dbid), tenant_id=backend_params.tenant_id, name=edbdef.EDGEDB_TEMPLATE_DB, builtin=True, ), ) dbops.CreateDatabase(db, template='template0').generate(block) await _execute_block(ctx.conn, block) return dbid async def _store_static_bin_cache_conn( conn: metaschema.PGConnection, key: str, data: bytes, ) -> None: text = trampoline.fixup_query(f"""\ INSERT INTO edgedbinstdata_VER.instdata (key, bin) VALUES( {pg_common.quote_literal(key)}, {pg_common.quote_bytea_literal(data)} ) """) await _execute(conn, text) async def _store_static_bin_cache( ctx: BootstrapContext, key: 
str, data: bytes, ) -> None: await _store_static_bin_cache_conn(ctx.conn, key, data) async def _store_static_text_cache( ctx: BootstrapContext, key: str, data: str, ) -> None: text = trampoline.fixup_query(f"""\ INSERT INTO edgedbinstdata_VER.instdata (key, text) VALUES( {pg_common.quote_literal(key)}, {pg_common.quote_literal(data)}::text ) """) await _execute(ctx.conn, text) async def _store_static_json_cache( ctx: BootstrapContext, key: str, data: str, ) -> None: text = trampoline.fixup_query(f"""\ INSERT INTO edgedbinstdata_VER.instdata (key, json) VALUES( {pg_common.quote_literal(key)}, {pg_common.quote_literal(data)}::jsonb ) """) await _execute(ctx.conn, text) def _process_delta_params[Schema_T: s_schema.Schema]( delta: sd.Command, schema: Schema_T, params: params.BackendRuntimeParams, stdmode: bool=True, **kwargs, ) -> tuple[ Schema_T, delta_cmds.MetaCommand, delta_cmds.CreateTrampolines, ]: """Adapt and process the delta command.""" if debug.flags.delta_plan: debug.header('Delta Plan') debug.dump(delta, schema=schema) context = sd.CommandContext(stdmode=True) if not delta.canonical: # Canonicalize sd.apply(delta, schema=schema) delta_pg: delta_cmds.MetaCommand = delta_cmds.CommandMeta.adapt(delta) # type: ignore context = sd.CommandContext( stdmode=stdmode, backend_runtime_params=params, **kwargs, ) schema = sd.apply(delta_pg, schema=schema, context=context) if debug.flags.delta_pgsql_plan: debug.header('PgSQL Delta Plan') debug.dump(delta_pg, schema=schema) if isinstance(delta_pg, delta_cmds.DeltaRoot): out = delta_pg.create_trampolines else: out = delta_cmds.CreateTrampolines() return schema, delta_pg, out def _process_delta[Schema_T: s_schema.Schema]( ctx: BootstrapContext, delta: sd.Command, schema: Schema_T, ) -> tuple[ Schema_T, delta_cmds.MetaCommand, delta_cmds.CreateTrampolines, ]: """Adapt and process the delta command.""" return _process_delta_params( delta, schema, ctx.cluster.get_runtime_params() ) def compile_bootstrap_script( compiler: 
edbcompiler.Compiler, schema: s_schema.Schema, eql: str, *, bootstrap_mode: bool = True, expected_cardinality_one: bool = False, output_format: edbcompiler.OutputFormat = edbcompiler.OutputFormat.JSON, ) -> tuple[s_schema.Schema, str]: ctx = edbcompiler.new_compiler_context( compiler_state=compiler.state, user_schema=schema, expected_cardinality_one=expected_cardinality_one, json_parameters=True, output_format=output_format, bootstrap_mode=bootstrap_mode, log_ddl_as_migrations=False, ) return edbcompiler.compile_edgeql_script(ctx, eql) def compile_single_query( eql: str, compilerctx: edbcompiler.CompileContext, ) -> str: ql_source = edgeql.Source.from_string(eql) units = edbcompiler.compile(ctx=compilerctx, source=ql_source).units assert len(units) == 1 return units[0].sql.decode() def _get_all_subcommands( cmd: sd.Command, type: Optional[type[sd.Command]] = None ) -> list[sd.Command]: cmds = [] def go(cmd): if not type or isinstance(cmd, type): cmds.append(cmd) for sub in cmd.get_subcommands(): go(sub) go(cmd) return cmds def _get_schema_object_ids( delta: sd.Command, ) -> Mapping[tuple[sn.Name, Optional[str]], uuid.UUID]: schema_object_ids = {} for cmd in _get_all_subcommands(delta, sd.CreateObject): assert isinstance(cmd, sd.CreateObject) mcls = cmd.get_schema_metaclass() if issubclass(mcls, s_obj.QualifiedObject): qlclass = None else: qlclass = mcls.get_ql_class_or_die() id = cmd.get_attribute_value('id') schema_object_ids[cmd.classname, qlclass] = id # backend_name in callables is a lot *like* an id, in that it gets # randomly generated and needs to match between things. 
if isinstance(cmd, s_func.CreateCallableObject): backend_name = cmd.get_attribute_value('backend_name') if backend_name: schema_object_ids[ cmd.classname, f'{qlclass}-backend_name'] = backend_name return schema_object_ids def prepare_repair_patch( stdschema: s_schema.Schema, reflschema: s_schema.Schema, userschema: s_schema.Schema, globalschema: s_schema.Schema, schema_class_layout: s_refl.SchemaClassLayout, backend_params: params.BackendRuntimeParams, ) -> str: compiler = edbcompiler.new_compiler( std_schema=stdschema, reflection_schema=reflschema, schema_class_layout=schema_class_layout ) compilerctx = edbcompiler.new_compiler_context( compiler_state=compiler.state, global_schema=globalschema, user_schema=userschema, ) res = edbcompiler.repair_schema(compilerctx) if not res: return "" sql, _ = res return sql.decode('utf-8') PatchEntry = tuple[tuple[str, ...], tuple[str, ...], dict[str, Any]] async def get_existing_view_columns( conn: pgcon.PGConnection | PGConnectionProxy, ) -> dict[str, list[str]]: # Find all the config views (they are pg_classes where # there is also a table with the same name but "_dummy" # at the end) and collect all their columns in order. schema = pg_common.versioned_schema("edgedbstd") return json.loads(await conn.sql_fetch_val(f'''\ select json_object_agg(v.relname, ( select json_agg(a.attname order by a.attnum) from pg_catalog.pg_attribute as a where v.oid = a.attrelid )) from pg_catalog.pg_class as v inner join pg_catalog.pg_tables as t on v.relname || '_dummy' = t.tablename -- Filter for just our namespace! inner join pg_catalog.pg_namespace as ns on v.relnamespace = ns.oid where ns.nspname = '{schema}' OR ns.nspname = 'edgedbpub' '''.encode('utf-8'))) async def gather_patch_info( num: int, kind: str, patch: str, conn: pgcon.PGConnection | PGConnectionProxy, ) -> Optional[dict[str, list[str]]]: """Fetch info for a patch that needs to use the connection. 
Currently, the only thing we need is, for config updates, the order that columns appear in the config views in SQL. We need this because we need to preserve that order when we update the view. """ if '+config' in kind: return await get_existing_view_columns(conn) else: return None def prepare_patch( num: int, kind: str, patch: str, schema: s_schema.Schema, reflschema: s_schema.Schema, schema_class_layout: s_refl.SchemaClassLayout, backend_params: params.BackendRuntimeParams, patch_info: Optional[dict[str, list[str]]], user_schema: Optional[s_schema.Schema]=None, global_schema: Optional[s_schema.Schema]=None, *, dbname: Optional[str]=None, ) -> PatchEntry: val = f'{pg_common.quote_literal(json.dumps(num + 1))}::jsonb' # TODO: This is an INSERT because 2.0 shipped without num_patches. # We can just make this an UPDATE for 3.0 update = trampoline.fixup_query(f"""\ INSERT INTO edgedbinstdata_VER.instdata (key, json) VALUES('num_patches', {val}) ON CONFLICT (key) DO UPDATE SET json = {val}; """) existing_view_columns = patch_info if '+testmode' in kind: if schema.get('cfg::TestSessionConfig', default=None): kind = kind.replace('+testmode', '') else: return (update,), (), {} # Pure SQL patches are simple if kind == 'sql': return (patch, update), (), {} # metaschema-sql: just recreate a function from metaschema if kind == 'metaschema-sql': func = getattr(metaschema, patch) create = dbops.CreateFunction(func(), or_replace=True) block = dbops.PLTopBlock() create.generate(block) return (block.to_string(), update), (), {} if kind.startswith('repair'): assert not patch if not user_schema: return (update,), (), dict(is_user_update=True) assert global_schema if kind.startswith('repair+user_ext'): # Only run a userext update if the extension we are trying to # update is installed. 
extension_name = kind.split('|')[-1] extension = user_schema.get_global( s_exts.Extension, extension_name, default=None) if not extension: return (update,), (), {} # TODO: Implement the last-repair-only optimization? try: logger.info("repairing database '%s'", dbname) sql = prepare_repair_patch( schema, reflschema, user_schema, global_schema, schema_class_layout, backend_params ) except errors.EdgeDBError as e: if isinstance(e, errors.InternalServerError): raise raise errors.SchemaError( f'Could not repair schema inconsistencies in ' f'database branch "{dbname}". Probably the schema is ' f'no longer valid due to a bug fix.\n' f'Downgrade to the last working version, fix ' f'the schema issue, and try again.' ) from e return (update, sql), (), {} # EdgeQL and reflection schema patches need to be compiled. current_block = dbops.PLTopBlock() preblock = current_block.add_block() subblock = current_block.add_block() std_plans = [] updates: dict[str, Any] = {} global_schema_update = kind == 'ext-pkg' sys_update_only = global_schema_update or kind.endswith('+globalonly') if kind == 'ext-pkg': # N.B: We process this without actually having the global # schema present, so we don't check whether it already # exists. The backend code will overwrite an older version's # JSON in the global metadata if it was already present. 
patch = s_std.get_std_module_text(sn.UnqualName(f'ext/{patch}')) if ( kind == 'edgeql' or kind == 'ext-pkg' or kind.startswith('edgeql+schema') ): assert '+user_ext' not in kind for ddl_cmd in edgeql.parse_block(patch): if not isinstance(ddl_cmd, qlast.DDLCommand): assert isinstance(ddl_cmd, qlast.Query) ddl_cmd = qlast.DDLQuery(query=ddl_cmd) # First apply it to the regular schema, just so we can update # stdschema delta_command = s_ddl.delta_from_ddl( ddl_cmd, modaliases={}, schema=schema, stdmode=True) schema, _, _ = _process_delta_params( delta_command, schema, backend_params) # We need to extract all ids of new objects created when # applying it to the regular schema, so that we can make sure # to use the same ids in the reflschema. schema_object_ids = _get_schema_object_ids(delta_command) # Then apply it to the reflschema, which we will use to drive # the actual table updating. delta_command = s_ddl.delta_from_ddl( ddl_cmd, modaliases={}, schema=reflschema, schema_object_ids=schema_object_ids, stdmode=True) reflschema, plan, tplan = _process_delta_params( delta_command, reflschema, backend_params) std_plans.append(delta_command) plan.generate(subblock) tplan.generate(subblock) metadata_user_schema = reflschema elif kind.startswith('edgeql+user') or kind.startswith('edgeql+user_ext'): assert '+schema' not in kind # There isn't anything to do on the system database for # user updates. if user_schema is None: return (update,), (), dict(is_user_update=True) if kind.startswith('edgeql+user_ext'): # Only run a userext update if the extension we are trying to # update is installed. 
extension_name = kind.split('|')[-1] extension = user_schema.get_global( s_exts.Extension, extension_name, default=None) if not extension: return (update,), (), {} assert global_schema cschema = s_schema.ChainedSchema( schema, user_schema, global_schema, ) for ddl_cmd in edgeql.parse_block(patch): if not isinstance(ddl_cmd, qlast.DDLCommand): assert isinstance(ddl_cmd, qlast.Query) ddl_cmd = qlast.DDLQuery(query=ddl_cmd) delta_command = s_ddl.delta_from_ddl( ddl_cmd, modaliases={}, schema=cschema, stdmode=False, testmode=True, ) # Prune any AlterSchemaVersion commands, because they # won't work, since we defer all the # compile_schema_storage_in_delta calls to the end. for sub in delta_command.get_subcommands( type=s_ver.AlterSchemaVersion ): delta_command.discard(sub) cschema, plan, tplan = _process_delta_params( delta_command, cschema, backend_params) std_plans.append(delta_command) plan.generate(subblock) tplan.generate(subblock) if '+config' in kind: views = metaschema.get_config_views(cschema, existing_view_columns) views.generate(subblock) metadata_user_schema = cschema.get_top_schema() elif kind == 'sql-introspection': support_view_commands = dbops.CommandGroup() support_view_commands.add_commands( metaschema._generate_sql_information_schema( backend_params.instance_params.version ) ) support_view_commands.generate(subblock) metaschema.generate_drop_views(list(support_view_commands), preblock) metadata_user_schema = reflschema else: raise AssertionError(f'unknown patch type {kind}') if kind.startswith('edgeql+schema'): # If we are modifying the schema layout, we need to rerun # generate_structure to collect schema changes not reflected # in the public schema and to discover the new introspection # query. 
reflection = s_refl.generate_structure( reflschema, make_funcs=False, patch_level=patches.get_patch_level(num), ) reflschema, plan, tplan = _process_delta_params( reflection.intro_schema_delta, reflschema, backend_params) plan.generate(subblock) tplan.generate(subblock) compiler = edbcompiler.new_compiler( std_schema=schema, reflection_schema=reflschema, schema_class_layout=schema_class_layout ) local_intro_sql, global_intro_sql = compile_intro_queries_stdlib( compiler=compiler, user_schema=reflschema, reflection=reflection, ) updates.update(dict( classlayout=reflection.class_layout, local_intro_query=local_intro_sql.encode('utf-8'), global_intro_query=global_intro_sql.encode('utf-8'), )) # This part is wildly hinky # We need to delete all the support views and recreate them at the end support_view_commands = dbops.CommandGroup() support_view_commands.add_commands([ dbops.CreateView(view) for view in metaschema._generate_schema_alias_views( reflschema, sn.UnqualName('schema') ) + metaschema._generate_schema_alias_views( reflschema, sn.UnqualName('sys') ) ]) support_view_commands.add_commands( metaschema._generate_sql_information_schema( backend_params.instance_params.version ) ) wrapper_views = metaschema._get_wrapper_views() support_view_commands.add_commands(list(wrapper_views)) trampolines = metaschema.trampoline_command(wrapper_views) metaschema.generate_drop_views( tuple(support_view_commands) + tuple(trampolines), preblock, ) # Now add the trampolines to support_view_commands support_view_commands.add_commands([t.make() for t in trampolines]) # We want to limit how much unconditional work we do, so only recreate # extension views if requested. 
if '+exts' in kind: for extview in metaschema._generate_extension_views(reflschema): support_view_commands.add_command( dbops.CreateView(extview, or_replace=True)) # Though we always update the instdata for the config system, # because it is currently the most convenient way to make sure # all the versioned fields get updated. config_spec = config.load_spec_from_schema(schema) # Similarly, only do config system updates if requested. if '+config' in kind: support_view_commands.add_command( metaschema.get_config_views(schema, existing_view_columns)) support_view_commands.add_command( metaschema._get_regenerated_config_support_functions( config_spec ) ) ( sysqueries, report_configs_typedesc_1_0, report_configs_typedesc_2_0, ) = compile_sys_queries( reflschema, compiler, config_spec, ) updates.update(dict( sysqueries=json.dumps(sysqueries).encode('utf-8'), report_configs_typedesc_1_0=report_configs_typedesc_1_0, report_configs_typedesc_2_0=report_configs_typedesc_2_0, configspec=config.spec_to_json(config_spec).encode('utf-8'), )) support_view_commands.generate(subblock) compiler = edbcompiler.new_compiler( std_schema=schema, reflection_schema=reflschema, schema_class_layout=schema_class_layout ) compilerctx = edbcompiler.new_compiler_context( compiler_state=compiler.state, user_schema=metadata_user_schema, bootstrap_mode=user_schema is None, ) for std_plan in std_plans: edbcompiler.compile_schema_storage_in_delta( ctx=compilerctx, delta=std_plan, block=subblock, ) patch = current_block.to_string() if debug.flags.delta_execute: debug.header('Patch Script') debug.dump_code(patch, lexer='sql') if not global_schema_update: updates.update(dict( std_and_reflection_schema=(schema, reflschema), )) bins = ( 'std_and_reflection_schema', 'global_schema', 'classlayout', 'report_configs_typedesc_1_0', 'report_configs_typedesc_2_0', ) rawbin = ( 'report_configs_typedesc_1_0', 'report_configs_typedesc_2_0', ) jsons = ( 'sysqueries', 'configspec', ) # This is unversioned because it 
is consumed by a function in metaschema. # (And only by a function in metaschema.) unversioned = ( 'configspec', ) # Just for the system database, we need to update the cached pickle # of everything. version_key = patches.get_version_key(num + 1) sys_updates: tuple[str, ...] = () spatches: tuple[str, ...] = (patch,) for k, v in updates.items(): key = f"'{k}{version_key}'" if k not in unversioned else f"'{k}'" if k in bins: if k not in rawbin: v = pickle.dumps(v, protocol=pickle.HIGHEST_PROTOCOL) val = f'{pg_common.quote_bytea_literal(v)}' sys_updates += (trampoline.fixup_query(f''' INSERT INTO edgedbinstdata_VER.instdata (key, bin) VALUES({key}, {val}) ON CONFLICT (key) DO UPDATE SET bin = {val}; '''),) else: typ, col = ('jsonb', 'json') if k in jsons else ('text', 'text') val = f'{pg_common.quote_literal(v.decode("utf-8"))}::{typ}' sys_updates += (trampoline.fixup_query(f''' INSERT INTO edgedbinstdata_VER.instdata (key, {col}) VALUES({key}, {val}) ON CONFLICT (key) DO UPDATE SET {col} = {val}; '''),) if k in unversioned: spatches += (sys_updates[-1],) # If we're updating the global schema (for extension packages, # perhaps), only run the script once, on the system connection. # Since the state is global, we only should update it once. regular_updates: tuple[str, ...] if sys_update_only: regular_updates = (update,) sys_updates = (patch,) + sys_updates else: regular_updates = spatches + (update,) # FIXME: This is a hack to make the is_user_update cases # work (by ensuring we can always read their current state), # but this is actually a pretty dumb approach and we can do # better. 
        regular_updates += sys_updates

    return regular_updates, sys_updates, updates


async def create_branch(
    cluster: pgcluster.BaseCluster,
    schema: s_schema.Schema,
    conn: metaschema.PGConnection,
    src_dbname: str,
    tgt_dbname: str,
    mode: str,
    backend_id_fixup_sql: bytes,
) -> None:
    """Create a new database (branch) based on an existing one."""

    # Dump the edgedbpub schema that holds user data and any
    # extensions. Also dump edgedbext, which can unfortunately
    # include some tables/views for the AI extension. (And some
    # extensions, which get created with IF NOT EXISTS, so that is
    # fine.)
    schema_dump = await cluster.dump_database(
        src_dbname,
        include_schemas=('edgedbpub', 'edgedbext'),
        include_extensions=('*',),
        schema_only=True,
    )

    # Tuple types are always kept in edgedbpub, but some already
    # exist from the std schema, so we need to skip those. We also
    # need to skip recreating the schema. This requires doing some
    # annoying postprocessing.
    to_skip = [
        str(obj.id) for obj in schema.get_objects(type=s_types.Tuple)
    ]
    old_lines = schema_dump.decode('utf-8').split('\n')
    new_lines = []
    skipping = False
    for line in old_lines:
        if line == ');' and skipping:
            skipping = False
            continue
        elif line.startswith('CREATE SCHEMA'):
            continue
        elif line.startswith('CREATE TYPE'):
            if any(skip in line for skip in to_skip):
                skipping = True
        elif line == 'SET transaction_timeout = 0;':
            continue

        if skipping:
            continue
        new_lines.append(line)
    s_schema_dump = '\n'.join(new_lines)
    await conn.sql_execute(s_schema_dump.encode('utf-8'))

    # Copy database config variables over directly
    copy_cfg_query = f'''
        select edgedb._copy_database_configs(
            {pg_common.quote_literal(src_dbname)})
    '''.encode('utf-8')
    await conn.sql_execute(copy_cfg_query)

    # HACK: Empty out all schema multi property tables. This is
    # because the original template has the stdschema in it, and so we
    # use --on-conflict-do-nothing to avoid conflicts since the dump
    # will have that in it too.
    # That works, except for multi properties
    # where it won't conflict, and modules, which might have a different
    # 'default' module on each side. (Since it isn't in the stdschema,
    # and could have an old id persisted from an in-place upgrade.)
    to_delete: set[s_obj.Object] = {
        prop
        for prop in schema.get_objects(type=s_props.Property)
        if prop.get_cardinality(schema).is_multi()
        and prop.get_name(schema).module not in irtyputils.VIEW_MODULES
    }
    to_delete.add(schema.get('schema::Module'))
    for target in to_delete:
        name = pg_common.get_backend_name(schema, target, catenate=True)
        await conn.sql_execute(f'delete from {name}'.encode('utf-8'))

    await conn.sql_execute(trampoline.fixup_query(f'''
        delete from edgedbinstdata_VER.instdata
        where key = 'configspec_ext'
    ''').encode('utf-8'))

    # Do the dump/restore for the data. We always need to copy over
    # edgedbstd, since it has the reflected schema. We copy over
    # edgedbpub when it is a data branch.
    data_arg = ['--table=edgedbpub.*'] if mode == qlast.BranchType.DATA else []
    dump_args = [
        '--data-only',
        '--table=edgedbstd.*',
        f'--table={pg_common.versioned_schema("edgedbstd")}.*',
        '--table=edgedb._db_config',
        f'--table={pg_common.versioned_schema("edgedbinstdata")}.instdata',
        *data_arg,
        '--disable-triggers',
        # We need to use --inserts so that we can use --on-conflict-do-nothing.
        # (See above, in discussion of the HACK.)
        '--inserts',
        '--rows-per-insert=100',
        '--on-conflict-do-nothing',
    ]
    await cluster._copy_database(
        src_dbname,
        tgt_dbname,
        dump_args,
        [],
    )

    # Restore the search_path as the dump might have altered it.
    await conn.sql_execute(
        b"SELECT pg_catalog.set_config('search_path', 'edgedb', false)")

    # Fixup the backend ids in the schema to match what is actually in pg.
    await conn.sql_execute(backend_id_fixup_sql)


class StdlibBits(NamedTuple):

    #: User-visible std.
    stdschema: s_schema.Schema
    #: Shadow extended schema for reflection.
    reflschema: s_schema.Schema
    #: Standard portion of the global schema
    global_schema: s_schema.Schema
    #: SQL text of the procedure to initialize `std` in Postgres.
    sqltext: str
    #: SQL text of the procedure to create all `std` scalars for inplace
    #: upgrades
    inplace_upgrade_scalar_text: str
    #: SQL text of the procedure to recreate all extension packages
    #: for inplace upgrades.
    inplace_upgrade_extension_packages_text: str
    #: Descriptors of all the needed trampolines
    trampolines: list[trampoline.Trampoline]
    #: A set of ids of all types in std.
    types: set[uuid.UUID]
    #: Schema class reflection layout.
    classlayout: dict[type[s_obj.Object], s_refl.SchemaTypeLayout]
    #: Schema introspection SQL query.
    local_intro_query: str
    #: Global object introspection SQL query.
    global_intro_query: str
    #: Number of patches already baked into the stdlib.
    num_patches: int

    # TODO: All of sysqueries ought to go here, right?
    # It would speed up instance creation a bit at little cost.
    # (Oh, maybe testmode screws this idea up?)
def _make_stdlib( ctx: BootstrapContext, testmode: bool, global_ids: Mapping[str, uuid.UUID], ) -> StdlibBits: schema: s_schema.ChainedSchema = s_schema.ChainedSchema( s_schema.EMPTY_SCHEMA, s_schema.EMPTY_SCHEMA, s_schema.EMPTY_SCHEMA, ) for special_mod in s_schema.SPECIAL_MODULES: schema, _ = s_mod.Module.create_in_schema( schema, name=special_mod, stable_ids=True, ) current_block = dbops.PLTopBlock() trampolines = [] std_texts = [] for modname in s_schema.STD_SOURCES: std_texts.append(s_std.get_std_module_text(modname)) if testmode: for modname in s_schema.TESTMODE_SOURCES: std_texts.append(s_std.get_std_module_text(modname)) ddl_text = '\n'.join(std_texts) types: set[uuid.UUID] = set() std_plans: list[sd.Command] = [] specials = [] def _collect_special(cmd): if isinstance( cmd, (dbops.CreateEnum, delta_cmds.CreateExtensionPackage), ): specials.append(cmd) elif isinstance(cmd, dbops.CommandGroup): for sub in cmd.commands: _collect_special(sub) elif isinstance(cmd, delta_cmds.MetaCommand): for sub in cmd.pgops: _collect_special(sub) for ddl_cmd in edgeql.parse_block(ddl_text): assert isinstance(ddl_cmd, qlast.DDLCommand) delta_command = s_ddl.delta_from_ddl( ddl_cmd, modaliases={}, schema=schema, stdmode=True) # Apply and adapt delta, build native delta plan, which # will also update the schema. 
        schema, plan, tplan = _process_delta(ctx, delta_command, schema)
        assert isinstance(plan, delta_cmds.DeltaRoot)
        std_plans.append(delta_command)
        _collect_special(plan)

        types.update(plan.new_types)
        plan.generate(current_block)
        trampolines.extend(tplan.trampolines)

    _, schema_version = s_std.make_schema_version(schema)
    schema, plan, tplan = _process_delta(ctx, schema_version, schema)
    std_plans.append(schema_version)
    plan.generate(current_block)
    trampolines.extend(tplan.trampolines)

    stdglobals = '\n'.join([
        f'''CREATE SUPERUSER ROLE {edbdef.EDGEDB_SUPERUSER} {{
            SET id := '{global_ids[edbdef.EDGEDB_SUPERUSER]}'
        }};''',
    ])

    schema = _execute_edgeql_ddl(schema, stdglobals)

    _, global_schema_version = s_std.make_global_schema_version(schema)
    schema, plan, tplan = _process_delta(ctx, global_schema_version, schema)
    std_plans.append(global_schema_version)
    plan.generate(current_block)
    trampolines.extend(tplan.trampolines)

    reflection = s_refl.generate_structure(schema)
    reflschema, reflplan, treflplan = _process_delta(
        ctx, reflection.intro_schema_delta, schema)

    # Any collection types that made it into reflschema need to get
    # pulled back into the stdschema, or else they will be in
    # an inconsistent state.
for obj in reflschema.get_objects(type=s_types.Collection): if not schema.has_object(obj.id): delta = sd.DeltaRoot() delta.add(obj.as_shell(reflschema).as_create_delta(reflschema)) schema = cast( s_schema.ChainedSchema, delta.apply(schema, sd.CommandContext(stdmode=True)) ) assert isinstance(schema, s_schema.ChainedSchema) assert current_block is not None reflplan.generate(current_block) trampolines.extend(treflplan.trampolines) subblock = current_block.add_block() compiler = edbcompiler.new_compiler( std_schema=schema.get_top_schema(), reflection_schema=reflschema.get_top_schema(), schema_class_layout=reflection.class_layout, # type: ignore ) backend_runtime_params = ctx.cluster.get_runtime_params() compilerctx = edbcompiler.new_compiler_context( compiler_state=compiler.state, user_schema=reflschema.get_top_schema(), global_schema=schema.get_global_schema(), bootstrap_mode=True, backend_runtime_params=backend_runtime_params, ) for std_plan in std_plans: edbcompiler.compile_schema_storage_in_delta( ctx=compilerctx, delta=std_plan, block=subblock, ) compilerctx = edbcompiler.new_compiler_context( compiler_state=compiler.state, user_schema=reflschema.get_top_schema(), global_schema=schema.get_global_schema(), bootstrap_mode=True, internal_schema_mode=True, backend_runtime_params=backend_runtime_params, ) edbcompiler.compile_schema_storage_in_delta( ctx=compilerctx, delta=reflection.intro_schema_delta, block=subblock, ) sqltext = current_block.to_string() local_intro_sql, global_intro_sql = compile_intro_queries_stdlib( compiler=compiler, user_schema=reflschema.get_top_schema(), global_schema=schema.get_global_schema(), reflection=reflection, ) # Sigh, we need to be able to create all std scalar types that # might get added. # # TODO: Also collect tuple types, and generalize enum handling to # be able to *add* enum fields. (Which will involve some # pl/pgsql??) 
scalar_block = dbops.PLTopBlock() extension_package_block = dbops.PLTopBlock() for cmd in specials: if isinstance(cmd, dbops.CreateEnum): ncmd = dbops.CreateEnum( dbops.Enum(cmd.name, cmd.values), neg_conditions=[dbops.EnumExists(cmd.name)], ) ncmd.generate(scalar_block) elif isinstance(cmd, delta_cmds.CreateExtensionPackage): cmd.generate(extension_package_block) # Got it! return StdlibBits( stdschema=schema.get_top_schema(), reflschema=reflschema.get_top_schema(), global_schema=schema.get_global_schema(), sqltext=sqltext, inplace_upgrade_scalar_text=scalar_block.to_string(), inplace_upgrade_extension_packages_text=( extension_package_block.to_string()), trampolines=trampolines, types=types, classlayout=reflection.class_layout, local_intro_query=local_intro_sql, global_intro_query=global_intro_sql, num_patches=len(patches.PATCHES), ) async def _amend_stdlib( ctx: BootstrapContext, ddl_text: str, stdlib: StdlibBits, ) -> tuple[StdlibBits, str, list[trampoline.Trampoline]]: schema = s_schema.ChainedSchema( s_schema.EMPTY_SCHEMA, stdlib.stdschema, stdlib.global_schema, ) reflschema = stdlib.reflschema topblock = dbops.PLTopBlock() plans = [] trampolines = [] context = sd.CommandContext(stdmode=True) for ddl_cmd in edgeql.parse_block(ddl_text): assert isinstance(ddl_cmd, qlast.DDLCommand) delta_command = s_ddl.delta_from_ddl( ddl_cmd, modaliases={}, schema=schema, stdmode=True) # Apply and adapt delta, build native delta plan, which # will also update the schema. 
schema, plan, tplan = _process_delta(ctx, delta_command, schema) reflschema = delta_command.apply(reflschema, context) plan.generate(topblock) plans.append(plan) trampolines.extend(tplan.trampolines) compiler = edbcompiler.new_compiler( std_schema=schema.get_top_schema(), reflection_schema=reflschema, schema_class_layout=stdlib.classlayout, # type: ignore ) compilerctx = edbcompiler.new_compiler_context( compiler_state=compiler.state, user_schema=schema.get_top_schema(), global_schema=schema.get_global_schema(), ) for plan in plans: edbcompiler.compile_schema_storage_in_delta( ctx=compilerctx, delta=plan, block=topblock, ) sqltext = topblock.to_string() return stdlib._replace( stdschema=schema.get_top_schema(), global_schema=schema.get_global_schema(), reflschema=reflschema, ), sqltext, trampolines def compile_intro_queries_stdlib( *, compiler: edbcompiler.Compiler, user_schema: s_schema.Schema, global_schema: s_schema.Schema=s_schema.EMPTY_SCHEMA, reflection: s_refl.SchemaReflectionParts, ) -> tuple[str, str]: compilerctx = edbcompiler.new_compiler_context( compiler_state=compiler.state, user_schema=user_schema, global_schema=global_schema, schema_reflection_mode=True, output_format=edbcompiler.OutputFormat.JSON_ELEMENTS, ) # The introspection query bits are returned in chunks # because it's a large UNION and we currently generate SQL # that is much harder for Postgres to plan as opposed to a # straight flat UNION. 
    sql_intro_local_parts = []
    sql_intro_global_parts = []
    for intropart in reflection.local_intro_parts:
        sql_intro_local_parts.append(
            compile_single_query(
                intropart,
                compilerctx=compilerctx,
            ),
        )
    for intropart in reflection.global_intro_parts:
        sql_intro_global_parts.append(
            compile_single_query(
                intropart,
                compilerctx=compilerctx,
            ),
        )
    local_intro_sql = ' UNION ALL '.join(
        f'({x})' for x in sql_intro_local_parts)
    local_intro_sql = f'''
        WITH intro(c) AS ({local_intro_sql})
        SELECT json_agg(intro.c) FROM intro
    '''

    global_intro_sql = ' UNION ALL '.join(
        f'({x})' for x in sql_intro_global_parts)
    global_intro_sql = f'''
        WITH intro(c) AS ({global_intro_sql})
        SELECT json_agg(intro.c) FROM intro
    '''

    return local_intro_sql, global_intro_sql


def _calculate_src_hash() -> bytes:
    return buildmeta.hash_dirs(
        buildmeta.get_cache_src_dirs(),
        extra_files=[
            __file__,
            pathlib.Path(__file__).parent.parent / 'buildmeta.py',
        ],
    )


def _get_cache_dir() -> pathlib.Path | None:
    if specified_cache_dir := os.environ.get('_EDGEDB_WRITE_DATA_CACHE_TO'):
        return pathlib.Path(specified_cache_dir)
    else:
        return None


def read_data_cache(
    file_name: str,
    pickled: bool,
    *,
    src_hash: bytes | None = None,
    cache_dir: pathlib.Path | None = None,
) -> Any:
    if src_hash is None:
        src_hash = _calculate_src_hash()
    if cache_dir is None:
        cache_dir = _get_cache_dir()
    return buildmeta.read_data_cache(
        src_hash, file_name, source_dir=cache_dir, pickled=pickled)


def cleanup_tpldbdump(tpldbdump: bytes) -> bytes:
    # Excluding the "edgedbext" schema above apparently
    # doesn't apply to extensions created in that schema,
    # so we have to resort to commenting out extension
    # statements in the dump.
    tpldbdump = re.sub(
        rb'^(CREATE|COMMENT ON) EXTENSION.*$',
        rb'-- \g<0>',
        tpldbdump,
        flags=re.MULTILINE,
    )

    # PostgreSQL 14 emits multirange_type_name in RANGE definitions;
    # elide these to preserve compatibility with earlier servers.
tpldbdump = re.sub( rb',\s*multirange_type_name\s*=[^,\n]+', rb'', tpldbdump, flags=re.MULTILINE, ) # PostgreSQL 17 adds a transaction_timeout config setting that # didn't exist on earlier versions. tpldbdump = re.sub( rb'^SET transaction_timeout = 0;$', rb'', tpldbdump, flags=re.MULTILINE, ) tpldbdump = re.sub( rb'^CREATE SCHEMA ', rb'CREATE SCHEMA IF NOT EXISTS ', tpldbdump, flags=re.MULTILINE, ) return tpldbdump async def _init_stdlib( ctx: BootstrapContext, testmode: bool, global_ids: Mapping[str, uuid.UUID], ) -> tuple[ StdlibBits, config.Spec, edbcompiler.Compiler, ]: in_dev_mode = devmode.is_in_dev_mode() conn = ctx.conn cluster = ctx.cluster args = ctx.args tpldbdump_cache = 'backend-tpldbdump.sql' src_hash = _calculate_src_hash() cache_dir = _get_cache_dir() stdlib: Optional[StdlibBits] = read_data_cache( STDLIB_CACHE_FILE_NAME, pickled=True, src_hash=src_hash, cache_dir=cache_dir, ) tpldbdump_package = read_data_cache( tpldbdump_cache, pickled=True, src_hash=src_hash, cache_dir=cache_dir, ) tpldbdump, tpldbdump_inplace = None, None if tpldbdump_package: tpldbdump, tpldbdump_inplace = tpldbdump_package else: assert not args.inplace_upgrade_prepare, ( "Gel must have a valid bootstrap cache to use inplace upgrade" ) stdlib_was_none = stdlib is None if stdlib is None: logger.info('Compiling the standard library...') stdlib = _make_stdlib( ctx, in_dev_mode or testmode, global_ids) config_spec = config.load_spec_from_schema(stdlib.stdschema) # If we recompiled the stdlib or need to generate a tpldbdump, we # need to generate bootstrap commands and trampolines, and update # the stdlib's trampolines if we compiled it. 
bootstrap_commands = None if stdlib_was_none or tpldbdump is None: bootstrap_commands, bootstrap_trampolines = ( metaschema.get_bootstrap_commands(config_spec) ) if stdlib_was_none: stdlib = stdlib._replace( trampolines=bootstrap_trampolines + stdlib.trampolines ) trampolines = [] # We need to set this up early, since later code depends on the # backend_instance_params of the instdata table. But it also # obviously can't go into the tpldbdump, since it is dynamic. trampolines.extend(await metaschema.generate_instdata_table( conn, )) await _populate_misc_instance_data(ctx) backend_params = cluster.get_runtime_params() if not args.inplace_upgrade_prepare: logger.info('Creating the necessary PostgreSQL extensions...') await metaschema.create_pg_extensions(conn, backend_params) trampolines.extend(stdlib.trampolines) eff_tpldbdump = ( tpldbdump_inplace if args.inplace_upgrade_prepare else tpldbdump) if eff_tpldbdump is None: logger.info('Populating internal SQL structures...') assert bootstrap_commands is not None block = dbops.PLTopBlock() fixed_bootstrap_commands = metaschema.get_fixed_bootstrap_commands() fixed_bootstrap_commands.generate(block) bootstrap_commands.generate(block) await _execute_block(conn, block) logger.info('Executing the standard library...') await _execute(conn, stdlib.sqltext) if in_dev_mode or cache_dir: tpl_db_name = edbdef.EDGEDB_TEMPLATE_DB tpl_pg_db_name = cluster.get_db_name(tpl_db_name) tpldbdump = await cluster.dump_database( tpl_pg_db_name, exclude_schemas=[ pg_common.versioned_schema('edgedbinstdata'), 'edgedbext', backend_params.instance_params.ext_schema, ], dump_object_owners=False, ) tpldbdump = cleanup_tpldbdump(tpldbdump) # The instance metadata doesn't go in the dump, so collect # it ourselves. 
global_metadata = await conn.sql_fetch_val( trampoline.fixup_query( "SELECT edgedb_VER.get_database_metadata($1)::json" ).encode("utf-8"), args=[tpl_db_name.encode("utf-8")], ) global_metadata = json.loads(global_metadata) pl_block = dbops.PLTopBlock() set_metadata_text = dbops.SetMetadata( dbops.DatabaseWithTenant(name=tpl_db_name), global_metadata, ).code_with_block(pl_block) set_single_db_metadata_text = dbops.SetSingleDBMetadata( edbdef.EDGEDB_TEMPLATE_DB, global_metadata ).code_with_block(pl_block) pl_block.add_command(textwrap.dedent(trampoline.fixup_query(f"""\ IF (edgedb_VER.get_backend_capabilities() & {int(params.BackendCapabilities.CREATE_DATABASE)}) != 0 THEN {textwrap.indent(set_metadata_text, ' ')} ELSE {textwrap.indent(set_single_db_metadata_text, ' ')} END IF """))) text = pl_block.to_string() tpldbdump += b'\n' + text.encode('utf-8') tpldbdump_inplace = await cluster.dump_database( tpl_pg_db_name, include_schemas=[ pg_common.versioned_schema('edgedb'), pg_common.versioned_schema('edgedbstd'), pg_common.versioned_schema('edgedbsql'), ], dump_object_owners=False, ) tpldbdump_inplace = ( stdlib.inplace_upgrade_scalar_text.encode('utf-8') + cleanup_tpldbdump(tpldbdump_inplace) ) buildmeta.write_data_cache( (tpldbdump, tpldbdump_inplace), src_hash, tpldbdump_cache, pickled=True, target_dir=cache_dir, ) buildmeta.write_data_cache( stdlib, src_hash, STDLIB_CACHE_FILE_NAME, target_dir=cache_dir, ) else: logger.info('Initializing the standard library...') await _execute(conn, eff_tpldbdump.decode('utf-8')) # Restore the search_path as the dump might have altered it. await conn.sql_execute( b"SELECT pg_catalog.set_config('search_path', 'edgedb', false)") if not in_dev_mode and testmode: # Running tests on a production build. 
for modname in s_schema.TESTMODE_SOURCES: stdlib, testmode_sql, new_trampolines = await _amend_stdlib( ctx, s_std.get_std_module_text(modname), stdlib, ) await conn.sql_execute(testmode_sql.encode("utf-8")) trampolines.extend(new_trampolines) # _testmode includes extra config settings, so make sure # those are picked up... config_spec = config.load_spec_from_schema(stdlib.stdschema) # ...and that config functions dependent on it are regenerated await metaschema.regenerate_config_support_functions(conn, config_spec) logger.info('Finalizing database setup...') # Make sure that schema backend_id properties are in sync with # the database. # XXX: is ScalarType sufficient here? compiler = edbcompiler.new_compiler( std_schema=stdlib.stdschema, reflection_schema=stdlib.reflschema, schema_class_layout=stdlib.classlayout, global_intro_query=stdlib.global_intro_query, local_intro_query=stdlib.local_intro_query, ) _, sql = compile_bootstrap_script( compiler, stdlib.reflschema, ''' SELECT schema::ScalarType { id, backend_id, } FILTER .builtin AND NOT (.abstract ?? False); ''', expected_cardinality_one=False, ) schema = stdlib.stdschema typemap = await conn.sql_fetch_val(sql.encode("utf-8")) for entry in json.loads(typemap): t = schema.get_by_id(uuidgen.UUID(entry['id'])) schema = t.set_field_value( schema, 'backend_id', entry['backend_id']) # Patch functions referring to extensions, because # some backends require extensions to be hosted in # hardcoded schemas (e.g. 
    # Heroku)
    await metaschema.patch_pg_extensions(conn, backend_params)

    stdlib = stdlib._replace(stdschema=schema)
    version_key = patches.get_version_key(stdlib.num_patches)

    # stdschema and reflschema are combined in one pickle to preserve sharing
    await _store_static_bin_cache(
        ctx,
        f'std_and_reflection_schema{version_key}',
        pickle.dumps(
            (schema, stdlib.reflschema),
            protocol=pickle.HIGHEST_PROTOCOL,
        ),
    )

    await _store_static_bin_cache(
        ctx,
        f'global_schema{version_key}',
        pickle.dumps(stdlib.global_schema, protocol=pickle.HIGHEST_PROTOCOL),
    )

    await _store_static_bin_cache(
        ctx,
        f'classlayout{version_key}',
        pickle.dumps(stdlib.classlayout, protocol=pickle.HIGHEST_PROTOCOL),
    )

    await _store_static_text_cache(
        ctx,
        f'local_intro_query{version_key}',
        stdlib.local_intro_query,
    )

    await _store_static_text_cache(
        ctx,
        f'global_intro_query{version_key}',
        stdlib.global_intro_query,
    )

    trampolines.extend(await metaschema.generate_support_views(
        conn, stdlib.reflschema, cluster.get_runtime_params()
    ))
    trampolines.extend(
        await metaschema.generate_support_functions(conn, stdlib.reflschema)
    )

    compiler = edbcompiler.new_compiler(
        std_schema=schema,
        reflection_schema=stdlib.reflschema,
        schema_class_layout=stdlib.classlayout,
        global_intro_query=stdlib.global_intro_query,
        local_intro_query=stdlib.local_intro_query,
    )

    trampolines.extend(
        await metaschema.generate_more_support_functions(
            conn, compiler, stdlib.reflschema, testmode
        )
    )

    await _store_static_json_cache(
        ctx,
        'configspec',
        config.spec_to_json(config_spec),
    )

    await _store_static_json_cache(
        ctx,
        'configspec_ext',
        json.dumps({}),
    )

    # Create all the trampolines
    tramps = dbops.CommandGroup()
    tramps.add_commands([t.make() for t in trampolines])
    block = dbops.PLTopBlock()
    tramps.generate(block)

    if args.inplace_upgrade_prepare:
        trampoline_text = block.to_string()
        await _store_static_text_cache(
            ctx,
            'trampoline_pivot_query',
            trampoline_text,
        )
        await _store_static_text_cache(
            ctx,
            'global_schema_update_query',
stdlib.inplace_upgrade_extension_packages_text, ) else: await _execute_block(conn, block) return stdlib, config_spec, compiler async def _init_defaults(schema, compiler, conn): script = ''' CREATE MODULE default; ''' schema, sql = compile_bootstrap_script( compiler, schema, script, bootstrap_mode=False ) await _execute(conn, sql) return schema async def _configure( ctx: BootstrapContext, config_spec: config.Spec, schema: s_schema.Schema, compiler: edbcompiler.Compiler, ) -> None: settings: Mapping[str, config.SettingValue] = {} config_json = config.to_json(config_spec, settings, include_source=False) block = dbops.PLTopBlock() metadata = {'sysconfig': json.loads(config_json)} if ctx.cluster.get_runtime_params().has_create_database: dbops.UpdateMetadata( dbops.DatabaseWithTenant(name=edbdef.EDGEDB_SYSTEM_DB), metadata, ).generate(block) else: dbops.UpdateSingleDBMetadata( edbdef.EDGEDB_SYSTEM_DB, metadata, ).generate(block) await _execute_block(ctx.conn, block) backend_params = ctx.cluster.get_runtime_params() for setname in config_spec: setting = config_spec[setname] if ( setting.backend_setting and setting.default is not None and ( # Do not attempt to run CONFIGURE INSTANCE on # backends that don't support it. # TODO: this should be replaced by instance-wide # emulation at backend connection time. 
backend_params.has_configfile_access ) ): script = qlcodegen.generate_source( qlast.ConfigSet( name=qlast.ObjectRef(name=setting.name), scope=qltypes.ConfigScope.INSTANCE, expr=s_utils.const_ast_from_python(setting.default), ) ) schema, sql = compile_bootstrap_script(compiler, schema, script) await _execute(ctx.conn, sql) def compile_sys_queries( schema: s_schema.Schema, compiler: edbcompiler.Compiler, config_spec: config.Spec, ) -> tuple[dict[str, str], bytes, bytes]: queries = {} _, sql = compile_bootstrap_script( compiler, schema, 'SELECT cfg::_get_config_json_internal()', expected_cardinality_one=True, ) queries['config'] = sql _, sql = compile_bootstrap_script( compiler, schema, "SELECT cfg::_get_config_json_internal(sources := ['database'])", expected_cardinality_one=True, ) queries['dbconfig'] = sql _, sql = compile_bootstrap_script( compiler, schema, """ SELECT cfg::_get_config_json_internal(max_source := 'system override') """, expected_cardinality_one=True, ) queries['sysconfig'] = sql _, sql = compile_bootstrap_script( compiler, schema, """ SELECT cfg::_get_config_json_internal(max_source := 'postgres client') """, expected_cardinality_one=True, ) queries['sysconfig_default'] = sql _, sql = compile_bootstrap_script( compiler, schema, f"""SELECT ( SELECT sys::Branch FILTER .name != "{edbdef.EDGEDB_TEMPLATE_DB}" ).name""", expected_cardinality_one=False, ) queries['listdbs'] = sql role_query = ''' SELECT sys::Role { name, superuser, password, branches, all_permissions, apply_access_policies_pg_default, }; ''' _, sql = compile_bootstrap_script( compiler, schema, role_query, expected_cardinality_one=False, ) queries['roles'] = sql tids_query = ''' SELECT schema::ScalarType { id, backend_id, } FILTER .id IN json_array_unpack($ids); ''' _, sql = compile_bootstrap_script( compiler, schema, tids_query, expected_cardinality_one=False, ) queries['backend_tids'] = sql # When we restore a database from a dump, OIDs for non-system # Postgres types might get skewed as 
they are not part of the dump. # A good example of that is `std::bigint` which is implemented as # a custom domain type. The OIDs are stored under # `schema::Object.backend_id` property and are injected into # array query arguments. # # The code below re-syncs backend_id properties of Gel builtin # types with the actual OIDs in the DB. backend_id_fixup_edgeql = ''' UPDATE schema::ScalarType FILTER NOT (.abstract ?? False) AND NOT (.transient ?? False) SET { backend_id := sys::_get_pg_type_for_edgedb_type( .id, .__type__.name, {}, [is schema::ScalarType].sql_type ?? ( select [is schema::ScalarType] .bases[is schema::ScalarType] limit 1 ).sql_type, ) }; UPDATE schema::Tuple FILTER NOT (.abstract ?? False) AND NOT (.transient ?? False) SET { backend_id := sys::_get_pg_type_for_edgedb_type( .id, .__type__.name, {}, [is schema::ScalarType].sql_type ?? ( select [is schema::ScalarType] .bases[is schema::ScalarType] limit 1 ).sql_type, ) }; UPDATE {schema::Range, schema::MultiRange} FILTER NOT (.abstract ?? False) AND NOT (.transient ?? False) SET { backend_id := sys::_get_pg_type_for_edgedb_type( .id, .__type__.name, .element_type.id, {}, ) }; UPDATE schema::Array FILTER NOT (.abstract ?? False) AND NOT (.transient ?? 
False) AND NOT .element_type IS schema::Array SET { backend_id := sys::_get_pg_type_for_edgedb_type( .id, .__type__.name, .element_type.id, {}, ) }; ''' _, sql = compile_bootstrap_script( compiler, schema, backend_id_fixup_edgeql, ) queries['backend_id_fixup'] = sql report_settings: list[str] = [] for setname in config_spec: setting = config_spec[setname] if setting.report: report_settings.append(setname) report_configs_query = f''' SELECT assert_single(cfg::Config {{ {', '.join(report_settings)} }}); ''' units = edbcompiler.compile( ctx=edbcompiler.new_compiler_context( compiler_state=compiler.state, user_schema=schema, expected_cardinality_one=True, json_parameters=False, output_format=edbcompiler.OutputFormat.BINARY, bootstrap_mode=True, ), source=edgeql.Source.from_string(report_configs_query), ).units assert len(units) == 1 report_configs_typedesc_2_0 = units[0].out_type_id + units[0].out_type_data queries['report_configs'] = units[0].sql.decode() units = edbcompiler.compile( ctx=edbcompiler.new_compiler_context( compiler_state=compiler.state, user_schema=schema, expected_cardinality_one=True, json_parameters=False, output_format=edbcompiler.OutputFormat.BINARY, bootstrap_mode=True, protocol_version=(1, 0), ), source=edgeql.Source.from_string(report_configs_query), ).units assert len(units) == 1 report_configs_typedesc_1_0 = units[0].out_type_id + units[0].out_type_data return ( queries, report_configs_typedesc_1_0, report_configs_typedesc_2_0, ) async def _populate_misc_instance_data( ctx: BootstrapContext, ) -> dict[str, Any]: mock_auth_nonce = scram.generate_nonce() json_instance_data = { 'version': dict(buildmeta.get_version_dict()), 'catver': edbdef.EDGEDB_CATALOG_VERSION, 'mock_auth_nonce': mock_auth_nonce, } await _store_static_json_cache( ctx, 'instancedata', json.dumps(json_instance_data), ) backend_params = ctx.cluster.get_runtime_params() instance_params = backend_params.instance_params await _store_static_json_cache( ctx, 'backend_instance_params', 
json.dumps(instance_params._asdict()), ) if not backend_params.has_create_role: json_single_role_metadata = { 'id': str(uuidgen.uuid1mc()), 'name': edbdef.EDGEDB_SUPERUSER, 'tenant_id': backend_params.tenant_id, 'builtin': False, } await _store_static_json_cache( ctx, 'single_role_metadata', json.dumps(json_single_role_metadata), ) if not backend_params.has_create_database: await _store_static_json_cache( ctx, f'{edbdef.EDGEDB_TEMPLATE_DB}metadata', json.dumps({}), ) await _store_static_json_cache( ctx, f'{edbdef.EDGEDB_SYSTEM_DB}metadata', json.dumps({}), ) await _store_static_json_cache( ctx, 'sql_default_fe_settings', json.dumps( [ {"name": key, "value": pg_common.setting_to_sql(key, val)} for key, val in dbstate.DEFAULT_SQL_FE_SETTINGS.items() ] ) ) return json_instance_data async def _create_edgedb_database( ctx: BootstrapContext, database: str, owner: str, *, builtin: bool = False, objid: Optional[uuid.UUID] = None, ) -> uuid.UUID: logger.info(f'Creating database: {database}') block = dbops.SQLBlock() if objid is None: objid = uuidgen.uuid1mc() instance_params = ctx.cluster.get_runtime_params().instance_params db = dbops.Database( ctx.cluster.get_db_name(database), owner=ctx.cluster.get_role_name(owner), metadata=dict( id=str(objid), tenant_id=instance_params.tenant_id, name=database, builtin=builtin, ), ) tpl_db = ctx.cluster.get_db_name(edbdef.EDGEDB_TEMPLATE_DB) dbops.CreateDatabase(db, template=tpl_db).generate(block) # Background tasks on some hosted providers like DO seem to sometimes make # their own connections to the template DB, so do a retry loop on it. 
rloop = retryloop.RetryLoop( backoff=retryloop.exp_backoff(), timeout=10.0, ignore=pgcon.errors.BackendError, ) async for iteration in rloop: async with iteration: await _execute_block(ctx.conn, block) return objid async def _set_edgedb_database_metadata( ctx: BootstrapContext, database: str, *, objid: Optional[uuid.UUID] = None, ) -> uuid.UUID: logger.info(f'Configuring database: {database}') block = dbops.SQLBlock() if objid is None: objid = uuidgen.uuid1mc() instance_params = ctx.cluster.get_runtime_params().instance_params db = dbops.Database(ctx.cluster.get_db_name(database)) metadata = dict( id=str(objid), tenant_id=instance_params.tenant_id, name=database, builtin=False, ) dbops.SetMetadata(db, metadata).generate(block) await _execute_block(ctx.conn, block) return objid def _pg_log_listener(severity, message): if severity == 'WARNING': level = logging.WARNING else: level = logging.DEBUG logger.log(level, message) async def _get_instance_data( conn: metaschema.PGConnection, *, versioned: bool=True, ) -> dict[str, Any]: schema = 'edgedbinstdata_VER' if versioned else 'edgedbinstdata' data = await conn.sql_fetch_val( trampoline.fixup_query(f""" SELECT json::json FROM {schema}.instdata WHERE key = 'instancedata' """).encode('utf-8'), ) return json.loads(data) async def _check_catalog_compatibility( ctx: BootstrapContext, ) -> PGConnectionProxy: tenant_id = ctx.cluster.get_runtime_params().tenant_id if ctx.mode == ClusterMode.single_database: sys_db = await ctx.conn.sql_fetch_val( trampoline.fixup_query(""" SELECT current_database() FROM edgedbinstdata_VER.instdata WHERE key = $1 AND json->>'tenant_id' = $2 """).encode('utf-8'), args=[ f"{edbdef.EDGEDB_TEMPLATE_DB}metadata".encode("utf-8"), tenant_id.encode("utf-8"), ], ) else: is_default_tenant = tenant_id == buildmeta.get_default_tenant_id() if is_default_tenant: sys_db = await ctx.conn.sql_fetch_val( b""" SELECT datname FROM pg_database WHERE datname LIKE '%' || $1 ORDER BY datname = $1, datname DESC LIMIT 1 
""", args=[ edbdef.EDGEDB_SYSTEM_DB.encode("utf-8"), ], ) else: sys_db = await ctx.conn.sql_fetch_val( b""" SELECT datname FROM pg_database WHERE datname = $1 """, args=[ ctx.cluster.get_db_name( edbdef.EDGEDB_SYSTEM_DB).encode("utf-8"), ], ) if not sys_db: raise errors.ConfigurationError( 'database instance is corrupt', details=( f'The database instance does not appear to have been fully ' f'initialized or has been corrupted.' ) ) conn = PGConnectionProxy( ctx.cluster, source_description="_check_catalog_compatibility", dbname=sys_db.decode("utf-8") ) try: # versioned=False so we can properly fail on version/catalog mismatches. instancedata = await _get_instance_data(conn, versioned=False) datadir_version = instancedata.get('version') if datadir_version: datadir_major = datadir_version.get('major') expected_ver = buildmeta.get_version() datadir_catver = instancedata.get('catver') expected_catver = edbdef.EDGEDB_CATALOG_VERSION status = dict( data_catalog_version=datadir_catver, expected_catalog_version=expected_catver, ) if datadir_major != expected_ver.major: for status_sink in ctx.args.status_sinks: status_sink(f'INCOMPATIBLE={json.dumps(status)}') raise errors.ConfigurationError( 'database instance incompatible with this version of Gel', details=( f'The database instance was initialized with ' f'Gel version {datadir_major}, ' f'which is incompatible with this version ' f'{expected_ver.major}' ), hint=( f'You need to either recreate the instance and upgrade ' f'using dump/restore, or do an inplace upgrade.' 
) ) if datadir_catver != expected_catver: for status_sink in ctx.args.status_sinks: status_sink(f'INCOMPATIBLE={json.dumps(status)}') raise errors.ConfigurationError( 'database instance incompatible with this version of Gel', details=( f'The database instance was initialized with ' f'Gel format version {datadir_catver}, ' f'but this version of the server expects ' f'format version {expected_catver}' ), hint=( f'You need to either recreate the instance and upgrade ' f'using dump/restore, or do an inplace upgrade.' ) ) except Exception: conn.terminate() raise return conn def _check_capabilities(ctx: BootstrapContext) -> None: caps = ctx.cluster.get_runtime_params().instance_params.capabilities for cap in ctx.args.backend_capability_sets.must_be_present: if not caps & cap: raise errors.ConfigurationError( f"the backend doesn't have necessary capability: " f"{cap.name}" ) for cap in ctx.args.backend_capability_sets.must_be_absent: if caps & cap: raise errors.ConfigurationError( f"the backend was already bootstrapped with capability: " f"{cap.name}" ) async def _pg_ensure_database_not_connected( conn: metaschema.PGConnection, dbname: str, ) -> None: conns = await conn.sql_fetch_col( b""" SELECT pid FROM pg_stat_activity WHERE datname = $1 """, args=[dbname.encode("utf-8")], ) if conns: raise errors.ExecutionError( f'database {dbname!r} is being accessed by other users') async def _start(ctx: BootstrapContext) -> edbcompiler.Compiler: conn = await _check_catalog_compatibility(ctx) try: caps = await conn.sql_fetch_val( b"SELECT edgedb.get_backend_capabilities()") ctx.cluster.overwrite_capabilities(struct.Struct('!Q').unpack(caps)[0]) _check_capabilities(ctx) return await edbcompiler.new_compiler_from_pg(conn) finally: conn.terminate() async def _bootstrap_edgedb_super_roles(ctx: BootstrapContext) -> uuid.UUID: await _ensure_edgedb_supergroup( ctx, edbdef.EDGEDB_SUPERGROUP, ) superuser_uid = await _ensure_edgedb_role( ctx, edbdef.EDGEDB_SUPERUSER, superuser=True, 
builtin=True, ) superuser = ctx.cluster.get_role_name(edbdef.EDGEDB_SUPERUSER) await _execute(ctx.conn, f'SET ROLE {qi(superuser)}') return superuser_uid async def _bootstrap( ctx: BootstrapContext, no_template: bool=False, ) -> tuple[ StdlibBits, edbcompiler.Compiler ]: args = ctx.args cluster = ctx.cluster backend_params = cluster.get_runtime_params() if backend_params.instance_params.version < edbdef.MIN_POSTGRES_VERSION: min_ver = '.'.join(str(v) for v in edbdef.MIN_POSTGRES_VERSION) raise errors.ConfigurationError( 'unsupported backend', details=( f'Gel requires PostgreSQL version {min_ver} or later, ' f'while the specified backend reports itself as ' f'{backend_params.instance_params.version.string}.' ) ) _check_capabilities(ctx) if backend_params.has_create_role: superuser_uid = await _bootstrap_edgedb_super_roles(ctx) else: superuser_uid = uuidgen.uuid1mc() using_template = backend_params.has_create_database and not no_template if using_template: if not args.inplace_upgrade_prepare: new_template_db_id = await _create_edgedb_template_database(ctx) # XXX: THIS IS WRONG, RIGHT? else: new_template_db_id = uuidgen.uuid1mc() tpl_db = cluster.get_db_name(edbdef.EDGEDB_TEMPLATE_DB) conn = PGConnectionProxy( cluster, source_description="_bootstrap", dbname=tpl_db ) tpl_ctx = dataclasses.replace(ctx, conn=conn) else: new_template_db_id = uuidgen.uuid1mc() tpl_ctx = ctx in_dev_mode = devmode.is_in_dev_mode() # Prevent multiple Gel tenants from trying to bootstrap # on the same cluster in devmode, as that is both a waste of resources # and might result in a broken stdlib cache. if in_dev_mode: await tpl_ctx.conn.sql_execute(b"SELECT pg_advisory_lock(3987734529)") try: # Some of the views need access to the _edgecon_state table and the # _dml_dummy table, so set them up. 
tmp_table_query = ( pgcon.SETUP_TEMP_TABLE_SCRIPT + pgcon.SETUP_DML_DUMMY_TABLE_SCRIPT ) await _execute(tpl_ctx.conn, tmp_table_query) stdlib, config_spec, compiler = await _init_stdlib( tpl_ctx, testmode=args.testmode, global_ids={ edbdef.EDGEDB_SUPERUSER: superuser_uid, edbdef.EDGEDB_TEMPLATE_DB: new_template_db_id, } ) ( sysqueries, report_configs_typedesc_1_0, report_configs_typedesc_2_0, ) = compile_sys_queries( stdlib.reflschema, compiler, config_spec, ) # Update schema backend_ids to match the reality after await tpl_ctx.conn.sql_execute( sysqueries['backend_id_fixup'].encode('utf-8') ) version_key = patches.get_version_key(stdlib.num_patches) await _store_static_json_cache( tpl_ctx, f'sysqueries{version_key}', json.dumps(sysqueries), ) await _store_static_bin_cache( tpl_ctx, f'report_configs_typedesc_1_0{version_key}', report_configs_typedesc_1_0, ) await _store_static_bin_cache( tpl_ctx, f'report_configs_typedesc_2_0{version_key}', report_configs_typedesc_2_0, ) await _store_static_json_cache( tpl_ctx, 'num_patches', json.dumps(stdlib.num_patches), ) default_branch = args.default_branch or edbdef.EDGEDB_SUPERUSER_DB await _store_static_text_cache( tpl_ctx, 'default_branch', default_branch, ) schema = s_schema.EMPTY_SCHEMA if not no_template: schema = await _init_defaults(schema, compiler, tpl_ctx.conn) # Run analyze on the template database, so that new dbs start # with up-to-date statistics. await tpl_ctx.conn.sql_execute(b"ANALYZE") finally: if in_dev_mode: await tpl_ctx.conn.sql_execute( b"SELECT pg_advisory_unlock(3987734529)", ) if using_template: # Close the connection to the template database # and wait until it goes away on the server side # so that we can safely use the template for new # databases. # # The timeout is set weirdly high, because we were getting # frequent timeouts in macos-x86_64 release builds when it # was set to 10s. 
conn.terminate() rloop = retryloop.RetryLoop( backoff=retryloop.exp_backoff(), timeout=60.0, ignore=errors.ExecutionError, ) async for iteration in rloop: async with iteration: await _pg_ensure_database_not_connected(ctx.conn, tpl_db) if args.inplace_upgrade_prepare: pass elif backend_params.has_create_database: await _create_edgedb_database( ctx, edbdef.EDGEDB_SYSTEM_DB, edbdef.EDGEDB_SUPERUSER, builtin=True, ) sys_conn = PGConnectionProxy( cluster, source_description="_bootstrap", dbname=cluster.get_db_name(edbdef.EDGEDB_SYSTEM_DB), ) try: await _configure( dataclasses.replace(ctx, conn=sys_conn), config_spec=config_spec, schema=schema, compiler=compiler, ) finally: sys_conn.terminate() else: await _configure( ctx, config_spec=config_spec, schema=schema, compiler=compiler, ) if args.inplace_upgrade_prepare: pass elif backend_params.has_create_database: await _create_edgedb_database( ctx, default_branch, args.default_database_user or edbdef.EDGEDB_SUPERUSER, ) else: await _set_edgedb_database_metadata( ctx, default_branch, ) if ( backend_params.has_create_role and args.default_database_user and args.default_database_user != edbdef.EDGEDB_SUPERUSER ): await _ensure_edgedb_role( ctx, args.default_database_user, superuser=True, ) def_role = ctx.cluster.get_role_name(args.default_database_user) await _execute(ctx.conn, f"SET ROLE {qi(def_role)}") if ( backend_params.has_create_database and args.default_database and args.default_database != default_branch ): await _create_edgedb_database( ctx, args.default_database, args.default_database_user or edbdef.EDGEDB_SUPERUSER, ) return stdlib, compiler async def ensure_bootstrapped( cluster: pgcluster.BaseCluster, args: edbargs.ServerConfig, ) -> tuple[bool, edbcompiler.Compiler]: """Bootstraps Gel instance if it hasn't been bootstrapped already. Returns True if bootstrap happened and False if the instance was already bootstrapped, along with the bootstrap compiler state. 
""" pgconn = PGConnectionProxy( cluster, source_description="ensure_bootstrapped" ) ctx = BootstrapContext(cluster=cluster, conn=pgconn, args=args) try: mode = await _get_cluster_mode(ctx) ctx = dataclasses.replace(ctx, mode=mode) if mode == ClusterMode.pristine: _, compiler = await _bootstrap(ctx) return True, compiler else: compiler = await _start(ctx) return False, compiler finally: pgconn.terminate() ================================================ FILE: edb/server/cache/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2018-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from .stmt_cache import StatementsCache __all__ = ('StatementsCache',) ================================================ FILE: edb/server/cache/stmt_cache.pxd ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2018-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # cdef class StatementsCache: cdef: object _dict int _maxsize object _dict_move_to_end object _dict_get cpdef get(self, key, default) cpdef needs_cleanup(self) cpdef cleanup_one(self) cpdef resize(self, int maxsize) ================================================ FILE: edb/server/cache/stmt_cache.pyx ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2018-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import collections cdef object _LRU_MARKER = object() cdef class StatementsCache: # We use an OrderedDict for LRU implementation. Operations: # # * We use a simple `__setitem__` to push a new entry: # `entries[key] = new_entry` # That will push `new_entry` to the *end* of the entries dict. # # * When we have a cache hit, we call # `entries.move_to_end(key, last=True)` # to move the entry to the *end* of the entries dict. 
# # * When we need to remove entries to maintain `maxsize`, we call # `entries.popitem(last=False)` # to remove an entry from the *beginning* of the entries dict. # # So new entries and hits are always promoted to the end of the # entries dict, whereas the unused ones will group at the # beginning of it. def __init__(self, *, maxsize): self.resize(maxsize) self._dict = collections.OrderedDict() self._dict_move_to_end = self._dict.move_to_end self._dict_get = self._dict.get cpdef get(self, key, default): o = self._dict_get(key, _LRU_MARKER) if o is _LRU_MARKER: return default self._dict_move_to_end(key) # last=True return o cpdef needs_cleanup(self): return len(self._dict) > self._maxsize cpdef cleanup_one(self): return self._dict.popitem(last=False) cpdef resize(self, int maxsize): if maxsize <= 0: raise ValueError( f'maxsize is expected to be greater than 0, got {maxsize}') self._maxsize = maxsize def items(self): return self._dict.items() def clear(self): self._dict.clear() def pop(self, key, default=_LRU_MARKER): if default is _LRU_MARKER: return self._dict.pop(key) else: return self._dict.pop(key, default) def __getitem__(self, key): o = self._dict[key] self._dict_move_to_end(key) # last=True return o def __setitem__(self, key, o): if key in self._dict: self._dict[key] = o self._dict_move_to_end(key) # last=True else: self._dict[key] = o def __delitem__(self, key): del self._dict[key] def __contains__(self, key): return key in self._dict def __len__(self): return len(self._dict) def __iter__(self): return iter(self._dict) ================================================ FILE: edb/server/compiler/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2018-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from .compiler import Compiler, CompilerState from .compiler import CompileContext, CompilerDatabaseState from .compiler import compile_edgeql_script from .compiler import new_compiler, new_compiler_from_pg, new_compiler_context from .compiler import compile, compile_schema_storage_in_delta from .compiler import maybe_force_database_error from .dbstate import QueryUnit, QueryUnitGroup from .enums import Capability, Cardinality from .enums import InputFormat, OutputFormat, InputLanguage from .explain import analyze_explain_output from .ddl import repair_schema from .rpc import CompilationRequest __all__ = ( 'Cardinality', 'CompilationRequest', 'Compiler', 'CompilerState', 'CompileContext', 'CompilerDatabaseState', 'QueryUnit', 'QueryUnitGroup', 'Capability', 'InputFormat', 'InputLanguage', 'OutputFormat', 'analyze_explain_output', 'compile_edgeql_script', 'maybe_force_database_error', 'new_compiler', 'new_compiler_from_pg', 'new_compiler_context', 'compile', 'compile_schema_storage_in_delta', 'repair_schema', ) ================================================ FILE: edb/server/compiler/compiler.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Optional, AbstractSet, Iterable, Mapping, MutableMapping, Sequence, NamedTuple, cast, TYPE_CHECKING, ) import dataclasses import functools import json import hashlib import pickle import textwrap import time import uuid import immutables from edb import buildmeta from edb import errors from edb.common.typeutils import not_none from edb.server import config from edb.server import defines from edb.server import instdata from edb import edgeql from edb.common import debug from edb import graphql from edb.common import turbo_uuid from edb.common import verutils from edb.common import uuidgen from edb.edgeql import ast as qlast from edb.edgeql import codegen as qlcodegen from edb.edgeql import compiler as qlcompiler from edb.edgeql import qltypes from edb.ir import staeval as ireval from edb.ir import statypes from edb.ir import ast as irast from edb.schema import ddl as s_ddl from edb.schema import delta as s_delta from edb.schema import extensions as s_ext from edb.schema import functions as s_func from edb.schema import links as s_links from edb.schema import properties as s_props from edb.schema import modules as s_mod from edb.schema import name as s_name from edb.schema import objects as s_obj from edb.schema import objtypes as s_objtypes from edb.schema import pointers as s_pointers from edb.schema import reflection as s_refl from edb.schema import roles as s_role from edb.schema import schema as s_schema from edb.schema import types as s_types from edb.schema import version as s_ver from edb.pgsql 
import ast as pgast from edb.pgsql import compiler as pg_compiler from edb.pgsql import codegen as pg_codegen from edb.pgsql import common as pg_common from edb.pgsql import debug as pg_debug from edb.pgsql import dbops as pg_dbops from edb.pgsql import params as pg_params from edb.pgsql import parser as pg_parser from edb.pgsql import patches as pg_patches from edb.pgsql import types as pg_types from edb.pgsql import delta as pg_delta from . import config as config_compiler from . import dbstate from . import enums from . import explain from . import sertypes from . import status from . import ddl from . import rpc from . import sql if TYPE_CHECKING: from edb.pgsql import metaschema SQLDescriptors = list[ tuple[ tuple[bytes, bytes, list[dbstate.Param], int], tuple[bytes, bytes] ] ] EMPTY_MAP: immutables.Map[Any, Any] = immutables.Map() @dataclasses.dataclass(frozen=True) class CompilerDatabaseState: user_schema: s_schema.Schema global_schema: s_schema.Schema cached_reflection: immutables.Map[str, tuple[str, ...]] @dataclasses.dataclass(frozen=True, kw_only=True) class CompileContext: compiler_state: CompilerState state: dbstate.CompilerConnectionState output_format: enums.OutputFormat expected_cardinality_one: bool protocol_version: defines.ProtocolVersion expect_rollback: bool = False json_parameters: bool = False schema_reflection_mode: bool = False force_testmode: bool = False implicit_limit: int = 0 inline_typeids: bool = False inline_typenames: bool = False inline_objectids: bool = True schema_object_ids: Optional[ Mapping[tuple[s_name.Name, Optional[str]], uuid.UUID]] = None source: Optional[edgeql.Source | graphql.Source | pg_parser.Source] = None backend_runtime_params: pg_params.BackendRuntimeParams = dataclasses.field( default_factory=pg_params.get_default_runtime_params ) compat_ver: Optional[verutils.Version] = None bootstrap_mode: bool = False internal_schema_mode: bool = False log_ddl_as_migrations: bool = True dump_restore_mode: bool = False 
notebook: bool = False branch_name: Optional[str] = None role_name: Optional[str] = None cache_key: Optional[uuid.UUID] = None def get_cache_mode(self) -> config.QueryCacheMode: return config.QueryCacheMode.effective( _get_config_val(self, 'query_cache_mode') ) def _assert_not_in_migration_block(self, ql: qlast.Base) -> None: """Check that a START MIGRATION block is *not* active.""" current_tx = self.state.current_tx() mstate = current_tx.get_migration_state() if mstate is not None: stmt = status.get_status(ql).decode() raise errors.QueryError( f'cannot execute {stmt} in a migration block', span=ql.span, ) def _assert_in_migration_block( self, ql: qlast.Base ) -> dbstate.MigrationState: """Check that a START MIGRATION block *is* active.""" current_tx = self.state.current_tx() mstate = current_tx.get_migration_state() if mstate is None: stmt = status.get_status(ql).decode() raise errors.QueryError( f'cannot execute {stmt} outside of a migration block', span=ql.span, ) return mstate def _assert_not_in_migration_rewrite_block(self, ql: qlast.Base) -> None: """Check that a START MIGRATION REWRITE block is *not* active.""" current_tx = self.state.current_tx() mstate = current_tx.get_migration_rewrite_state() if mstate is not None: stmt = status.get_status(ql).decode() raise errors.QueryError( f'cannot execute {stmt} in a migration rewrite block', span=ql.span, ) def _assert_in_migration_rewrite_block( self, ql: qlast.Base ) -> dbstate.MigrationRewriteState: """Check that a START MIGRATION REWRITE block *is* active.""" current_tx = self.state.current_tx() mstate = current_tx.get_migration_rewrite_state() if mstate is None: stmt = status.get_status(ql).decode() raise errors.QueryError( f'cannot execute {stmt} outside of a migration rewrite block', span=ql.span, ) return mstate def is_testmode(self) -> bool: return ( self.force_testmode or _get_config_val(self, '__internal_testmode') ) DEFAULT_MODULE_ALIASES_MAP: immutables.Map[Optional[str], str] = ( immutables.Map({None: 
s_mod.DEFAULT_MODULE_ALIAS})) def compile_edgeql_script( ctx: CompileContext, eql: str, ) -> tuple[s_schema.Schema, str]: sql = _compile_ql_script(ctx, eql) new_schema = ctx.state.current_tx().get_schema( ctx.compiler_state.std_schema) assert isinstance(new_schema, s_schema.ChainedSchema) return new_schema.get_top_schema(), sql def new_compiler( std_schema: s_schema.Schema, reflection_schema: s_schema.Schema, schema_class_layout: s_refl.SchemaClassLayout, *, backend_runtime_params: Optional[pg_params.BackendRuntimeParams] = None, local_intro_query: Optional[str] = None, global_intro_query: Optional[str] = None, config_spec: Optional[config.Spec] = None, ) -> Compiler: """Create and return a compiler instance.""" if not backend_runtime_params: backend_runtime_params = pg_params.get_default_runtime_params() if not config_spec: config_spec = config.load_spec_from_schema(std_schema) return Compiler(CompilerState( std_schema=std_schema, refl_schema=reflection_schema, schema_class_layout=schema_class_layout, backend_runtime_params=backend_runtime_params, config_spec=config_spec, local_intro_query=local_intro_query, global_intro_query=global_intro_query, )) async def new_compiler_from_pg(con: metaschema.PGConnection) -> Compiler: num_patches = await get_patch_count(con) std_schema, reflection_schema = await load_std_and_reflection_schema( con, num_patches) return new_compiler( std_schema=std_schema, reflection_schema=reflection_schema, schema_class_layout=await load_schema_class_layout( con, num_patches ), local_intro_query=await load_schema_intro_query( con, num_patches, 'local_intro_query' ), global_intro_query=await load_schema_intro_query( con, num_patches, 'global_intro_query' ), config_spec=None, ) def new_compiler_context( *, compiler_state: CompilerState, user_schema: s_schema.Schema, global_schema: s_schema.Schema=s_schema.EMPTY_SCHEMA, modaliases: Optional[Mapping[Optional[str], str]] = None, expected_cardinality_one: bool = False, json_parameters: bool = False, 
schema_reflection_mode: bool = False, output_format: enums.OutputFormat = enums.OutputFormat.BINARY, bootstrap_mode: bool = False, internal_schema_mode: bool = False, force_testmode: bool = False, protocol_version: defines.ProtocolVersion = defines.CURRENT_PROTOCOL, backend_runtime_params: Optional[pg_params.BackendRuntimeParams] = None, log_ddl_as_migrations: bool = True, ) -> CompileContext: """Create and return an ad-hoc compiler context.""" state = dbstate.CompilerConnectionState( user_schema=user_schema, global_schema=global_schema, modaliases=immutables.Map(modaliases) if modaliases else EMPTY_MAP, session_config=EMPTY_MAP, database_config=EMPTY_MAP, system_config=EMPTY_MAP, cached_reflection=EMPTY_MAP, ) ctx = CompileContext( compiler_state=compiler_state, state=state, output_format=output_format, expected_cardinality_one=expected_cardinality_one, json_parameters=json_parameters, schema_reflection_mode=schema_reflection_mode, bootstrap_mode=bootstrap_mode, internal_schema_mode=internal_schema_mode, force_testmode=force_testmode, protocol_version=protocol_version, backend_runtime_params=( backend_runtime_params or pg_params.get_default_runtime_params() ), log_ddl_as_migrations=log_ddl_as_migrations, ) return ctx async def get_patch_count(backend_conn: metaschema.PGConnection) -> int: """Get the number of applied patches.""" num_patches = await instdata.get_instdata( backend_conn, 'num_patches', 'json') res: int = json.loads(num_patches) if num_patches else 0 return res async def load_std_and_reflection_schema( backend_conn: metaschema.PGConnection, patches: int, ) -> tuple[s_schema.Schema, s_schema.Schema]: vkey = pg_patches.get_version_key(patches) # stdschema and reflschema are combined in one pickle to preserve sharing. 
key = f"std_and_reflection_schema{vkey}" data = await instdata.get_instdata(backend_conn, key, 'bin') try: std_schema: s_schema.Schema refl_schema: s_schema.Schema std_schema, refl_schema = pickle.loads(data) if vkey != pg_patches.get_version_key(len(pg_patches.PATCHES)): std_schema = s_schema.upgrade_schema(std_schema) refl_schema = s_schema.upgrade_schema(refl_schema) return (std_schema, refl_schema) except Exception as e: raise RuntimeError( 'could not load std schema pickle') from e async def load_schema_intro_query( backend_conn: metaschema.PGConnection, patches: int, kind: str, ) -> str: kind += pg_patches.get_version_key(patches) return ( await instdata.get_instdata(backend_conn, kind, 'text') ).decode('utf-8') async def load_schema_class_layout( backend_conn: metaschema.PGConnection, patches: int, ) -> s_refl.SchemaClassLayout: key = f'classlayout{pg_patches.get_version_key(patches)}' data = await instdata.get_instdata(backend_conn, key, 'bin') try: return cast(s_refl.SchemaClassLayout, pickle.loads(data)) except Exception as e: raise RuntimeError( 'could not load schema class layout pickle') from e @dataclasses.dataclass(frozen=True, kw_only=True) class CompilerState: std_schema: s_schema.Schema refl_schema: s_schema.Schema schema_class_layout: s_refl.SchemaClassLayout backend_runtime_params: pg_params.BackendRuntimeParams config_spec: config.Spec local_intro_query: Optional[str] global_intro_query: Optional[str] @functools.cached_property def state_serializer_factory(self) -> sertypes.StateSerializerFactory: # TODO: This factory will probably need to become per-db once # config spec differs between databases. See also #5836. 
return sertypes.StateSerializerFactory( self.std_schema, self.config_spec ) @functools.cached_property def compilation_config_serializer( self ) -> sertypes.CompilationConfigSerializer: return ( self.state_serializer_factory.make_compilation_config_serializer() ) class Compiler: state: CompilerState def __init__(self, state: CompilerState): self.state = state @staticmethod def _try_compile_rollback( eql: edgeql.Source | bytes ) -> tuple[dbstate.QueryUnitGroup, int]: source: str | edgeql.Source if isinstance(eql, edgeql.Source): source = eql else: source = eql.decode() statements = edgeql.parse_block(source) stmt = statements[0] unit = None if isinstance(stmt, qlast.RollbackTransaction): sql = b'ROLLBACK;' unit = dbstate.QueryUnit( status=b'ROLLBACK', sql=sql, tx_rollback=True, cacheable=False) elif isinstance(stmt, qlast.RollbackToSavepoint): sql = f'ROLLBACK TO {pg_common.quote_ident(stmt.name)};'.encode() unit = dbstate.QueryUnit( status=b'ROLLBACK TO SAVEPOINT', sql=sql, tx_savepoint_rollback=True, sp_name=stmt.name, cacheable=False) if unit is not None: rv = dbstate.QueryUnitGroup() rv.append(unit) return rv, len(statements) - 1 raise errors.TransactionError( 'expected a ROLLBACK or ROLLBACK TO SAVEPOINT command' ) # pragma: no cover def compile_notebook( self, user_schema: s_schema.Schema, global_schema: s_schema.Schema, reflection_cache: immutables.Map[str, tuple[str, ...]], database_config: immutables.Map[str, config.SettingValue], system_config: immutables.Map[str, config.SettingValue], queries: list[str], protocol_version: defines.ProtocolVersion, implicit_limit: int = 0, ) -> list[ tuple[ bool, dbstate.QueryUnit | tuple[str, str, dict[int, str]] ] ]: state = dbstate.CompilerConnectionState( user_schema=user_schema, global_schema=global_schema, modaliases=DEFAULT_MODULE_ALIASES_MAP, session_config=EMPTY_MAP, database_config=database_config, system_config=system_config, cached_reflection=reflection_cache, ) state.start_tx() result: list[ tuple[ bool, 
dbstate.QueryUnit | tuple[str, str, dict[int, str]] ] ] = [] for query in queries: try: source = edgeql.Source.from_string(query) ctx = CompileContext( compiler_state=self.state, state=state, output_format=enums.OutputFormat.BINARY, expected_cardinality_one=False, implicit_limit=implicit_limit, inline_typeids=False, inline_typenames=True, json_parameters=False, source=source, protocol_version=protocol_version, notebook=True, ) result.append( (False, compile(ctx=ctx, source=source)[0])) except Exception as ex: fields = {} typename = 'Error' if (isinstance(ex, errors.EdgeDBError) and type(ex) is not errors.EdgeDBError): fields = ex._attrs typename = type(ex).__name__ result.append( (True, (typename, str(ex), fields))) break return result def compile_sql( self, user_schema: s_schema.Schema, global_schema: s_schema.Schema, reflection_cache: immutables.Map[str, tuple[str, ...]], database_config: immutables.Map[str, config.SettingValue], system_config: immutables.Map[str, config.SettingValue], source: pg_parser.Source, tx_state: dbstate.SQLTransactionState, prepared_stmt_map: Mapping[str, str], current_database: str, current_user: str, ) -> list[dbstate.SQLQueryUnit]: state = dbstate.CompilerConnectionState( user_schema=user_schema, global_schema=global_schema, modaliases=DEFAULT_MODULE_ALIASES_MAP, session_config=EMPTY_MAP, database_config=database_config, system_config=system_config, cached_reflection=reflection_cache, ) schema = state.current_tx().get_schema(self.state.std_schema) setting = database_config.get('allow_user_specified_id', None) allow_user_specified_id = None if setting and setting.value: allow_user_specified_id = sql.is_setting_truthy(setting.value) setting = database_config.get('apply_access_policies_pg', None) apply_access_policies_pg = None if setting is not None: apply_access_policies_pg = sql.is_setting_truthy(setting.value) return sql.compile_sql( source, schema=schema, tx_state=tx_state, prepared_stmt_map=prepared_stmt_map, 
current_database=current_database, allow_user_specified_id=allow_user_specified_id, apply_access_policies=apply_access_policies_pg, disambiguate_column_names=False, backend_runtime_params=self.state.backend_runtime_params, protocol_version=defines.POSTGRES_PROTOCOL, )[0] def compile_serialized_request( self, user_schema: s_schema.Schema, global_schema: s_schema.Schema, reflection_cache: immutables.Map[str, tuple[str, ...]], database_config: Optional[immutables.Map[str, config.SettingValue]], system_config: Optional[immutables.Map[str, config.SettingValue]], serialized_request: bytes, original_query: str, ) -> tuple[ dbstate.QueryUnitGroup | SQLDescriptors, Optional[dbstate.CompilerConnectionState] ]: request = rpc.CompilationRequest.deserialize( serialized_request, original_query, self.state.compilation_config_serializer, ) return self.compile( user_schema=user_schema, global_schema=global_schema, reflection_cache=reflection_cache, database_config=database_config, system_config=system_config, request=request, ) def compile( self, *, user_schema: s_schema.Schema, global_schema: s_schema.Schema, reflection_cache: immutables.Map[str, tuple[str, ...]], database_config: Optional[immutables.Map[str, config.SettingValue]], system_config: Optional[immutables.Map[str, config.SettingValue]], request: rpc.CompilationRequest, ) -> tuple[dbstate.QueryUnitGroup | SQLDescriptors, Optional[dbstate.CompilerConnectionState]]: if request.input_language is enums.InputLanguage.SQL_PARAMS: assert isinstance(request.source, rpc.SQLParamsSource) return ( self.compile_sql_descriptors( user_schema, global_schema, request.protocol_version, request.source.types_in_out, ), # state is None -- we know we're not # in a transaction and compilation of params # couldn't have started it. 
None, ) sess_config = request.session_config if sess_config is None: sess_config = EMPTY_MAP if database_config is None: database_config = EMPTY_MAP if system_config is None: system_config = EMPTY_MAP sess_modaliases = request.modaliases if sess_modaliases is None: sess_modaliases = DEFAULT_MODULE_ALIASES_MAP state = dbstate.CompilerConnectionState( user_schema=user_schema, global_schema=global_schema, modaliases=sess_modaliases, session_config=sess_config, database_config=database_config, system_config=system_config, cached_reflection=reflection_cache, ) ctx = CompileContext( compiler_state=self.state, state=state, output_format=request.output_format, expected_cardinality_one=request.expect_one, implicit_limit=request.implicit_limit, inline_typeids=request.inline_typeids, inline_typenames=request.inline_typenames, inline_objectids=request.inline_objectids, json_parameters=request.input_format is enums.InputFormat.JSON, source=request.source, protocol_version=request.protocol_version, role_name=request.role_name, branch_name=request.branch_name, cache_key=request.get_cache_key(), ) match request.input_language: case enums.InputLanguage.EDGEQL: assert isinstance(request.source, edgeql.Source) unit_group = compile(ctx=ctx, source=request.source) case enums.InputLanguage.GRAPHQL: assert isinstance(request.source, graphql.Source) unit_group = compile_graphql( ctx=ctx, source=request.source, variables=request.key_params, ) case enums.InputLanguage.SQL: assert isinstance(request.source, pg_parser.Source) unit_group = compile_sql_as_unit_group( ctx=ctx, source=request.source) case _: raise NotImplementedError( f"unsupported input language: {request.input_language}") tx_started = False for unit in unit_group: if unit.tx_id: tx_started = True break if tx_started: return unit_group, ctx.state else: return unit_group, None def compile_serialized_request_in_tx( self, state: dbstate.CompilerConnectionState, txid: int, serialized_request: bytes, original_query: str, 
expect_rollback: bool = False, ) -> tuple[ dbstate.QueryUnitGroup | SQLDescriptors, Optional[dbstate.CompilerConnectionState] ]: request = rpc.CompilationRequest.deserialize( serialized_request, original_query, self.state.compilation_config_serializer, ) return self.compile_in_tx( state=state, txid=txid, request=request, expect_rollback=expect_rollback, ) def compile_in_tx( self, *, state: dbstate.CompilerConnectionState, txid: int, request: rpc.CompilationRequest, expect_rollback: bool = False, ) -> tuple[ dbstate.QueryUnitGroup | SQLDescriptors, Optional[dbstate.CompilerConnectionState] ]: if request.input_language is enums.InputLanguage.SQL_PARAMS: tx = state.current_tx() assert isinstance(request.source, rpc.SQLParamsSource) return ( self.compile_sql_descriptors( tx.get_user_schema(), tx.get_global_schema(), request.protocol_version, request.source.types_in_out, ), # state is the same. state, ) # Apply session differences if any if ( request.modaliases is not None and state.current_tx().get_modaliases() != request.modaliases ): state.current_tx().update_modaliases(request.modaliases) if ( (session_config := request.session_config) is not None and state.current_tx().get_session_config() != session_config ): state.current_tx().update_session_config(session_config) if ( expect_rollback and state.current_tx().id != txid and not state.can_sync_to_savepoint(txid) ): # This is a special case when COMMIT MIGRATION fails, the compiler # doesn't have the right transaction state, so we just roll back. 
assert isinstance(request.source, edgeql.Source) return self._try_compile_rollback(request.source)[0], state else: state.sync_tx(txid) ctx = CompileContext( compiler_state=self.state, state=state, output_format=request.output_format, expected_cardinality_one=request.expect_one, implicit_limit=request.implicit_limit, inline_typeids=request.inline_typeids, inline_typenames=request.inline_typenames, inline_objectids=request.inline_objectids, source=request.source, protocol_version=request.protocol_version, json_parameters=request.input_format is enums.InputFormat.JSON, expect_rollback=expect_rollback, cache_key=request.get_cache_key(), ) match request.input_language: case enums.InputLanguage.EDGEQL: assert isinstance(request.source, edgeql.Source) unit_group = compile(ctx=ctx, source=request.source) case enums.InputLanguage.GRAPHQL: assert isinstance(request.source, graphql.Source) unit_group = compile_graphql( ctx=ctx, source=request.source, variables=request.key_params, ) case enums.InputLanguage.SQL: assert isinstance(request.source, pg_parser.Source) unit_group = compile_sql_as_unit_group( ctx=ctx, source=request.source) case _: raise NotImplementedError( f"unsupported input language: {request.input_language}") return unit_group, ctx.state def compile_sql_descriptors( self, user_schema: s_schema.Schema, global_schema: s_schema.Schema, protocol_version: defines.ProtocolVersion, types_in_out: list[tuple[list[str], list[tuple[str, str]]]], ) -> SQLDescriptors: schema = s_schema.ChainedSchema( self.state.std_schema, user_schema, global_schema, ) result = [] for in_out in types_in_out: assert isinstance(in_out, tuple) and len(in_out) == 2 t_in = [] params = [] for idx, id in enumerate(in_out[0]): param_name = str(idx + 1) param_type = schema.get_by_id(turbo_uuid.UUID(id)) assert isinstance(param_type, s_types.Type) param_required = False # SQL arguments can always be NULL if isinstance(param_type, s_types.Array): array_type_id = param_type.get_element_type(schema).id 
else: array_type_id = None t_in.append( ( param_name, param_type, param_required, ) ) params.append( dbstate.Param( name=param_name, required=param_required, array_type_id=array_type_id, outer_idx=None, # no script support for SQL sub_params=None, # no tuple args support for SQL typename=str(param_type.get_name(schema)), ) ) input_desc, input_desc_id = sertypes.describe_params( schema=schema, params=t_in, protocol_version=protocol_version, ) t_out = { name: cast(s_types.Type, schema.get_by_id(turbo_uuid.UUID(id))) for name, id in in_out[1] } assert all(isinstance(t, s_types.Type) for t in t_out.values()) output_desc, output_desc_id = sertypes.describe_sql_result( schema=schema, row=t_out, protocol_version=protocol_version, ) result.append(( (input_desc, input_desc_id.bytes, params, len(params)), (output_desc, output_desc_id.bytes) )) return result def interpret_backend_error( self, user_schema: bytes, global_schema: bytes, error_fields: dict[str, str], from_graphql: bool, ) -> errors.EdgeDBError: from . 
import errormech schema = s_schema.ChainedSchema( self.state.std_schema, pickle.loads(user_schema), pickle.loads(global_schema), ) rv: errors.EdgeDBError = errormech.interpret_backend_error( schema, error_fields, from_graphql=from_graphql ) return rv def parse_json_schema( self, schema_json: bytes, base_schema: s_schema.Schema | None, ) -> s_schema.Schema: if base_schema is None: base_schema = self.state.std_schema else: base_schema = s_schema.ChainedSchema( self.state.std_schema, s_schema.EMPTY_SCHEMA, base_schema, ) return s_refl.parse_schema( base_schema=base_schema, data=schema_json, schema_class_layout=self.state.schema_class_layout, ) def parse_db_config( self, db_config_json: bytes, user_schema: s_schema.Schema ) -> immutables.Map[str, config.SettingValue]: spec = config.ChainedSpec( self.state.config_spec, config.load_ext_spec_from_schema( user_schema, self.state.std_schema, ), ) return config.from_json(spec, db_config_json) def parse_global_schema(self, global_schema_json: bytes) -> bytes: global_schema = self.parse_json_schema(global_schema_json, None) return pickle.dumps(global_schema, -1) def parse_user_schema_db_config( self, user_schema_json: bytes, db_config_json: bytes, global_schema_pickle: bytes, ) -> dbstate.ParsedDatabase: global_schema = pickle.loads(global_schema_pickle) user_schema = self.parse_json_schema(user_schema_json, global_schema) db_config = self.parse_db_config(db_config_json, user_schema) ext_config_settings = config.load_ext_settings_from_schema( s_schema.ChainedSchema( self.state.std_schema, user_schema, s_schema.EMPTY_SCHEMA, ) ) state_serializer = self.state.state_serializer_factory.make( user_schema, global_schema, defines.CURRENT_PROTOCOL, ) return dbstate.ParsedDatabase( user_schema_pickle=pickle.dumps(user_schema, -1), schema_version=_get_schema_version(user_schema), database_config=db_config, ext_config_settings=ext_config_settings, protocol_version=defines.CURRENT_PROTOCOL, state_serializer=state_serializer, 
feature_used_metrics=ddl.produce_feature_used_metrics( self.state, user_schema ), ) def make_state_serializer( self, protocol_version: defines.ProtocolVersion, user_schema_pickle: bytes, global_schema_pickle: bytes, ) -> sertypes.StateSerializer: user_schema = pickle.loads(user_schema_pickle) global_schema = pickle.loads(global_schema_pickle) return self.state.state_serializer_factory.make( user_schema, global_schema, protocol_version, ) def make_compilation_config_serializer( self, ) -> sertypes.CompilationConfigSerializer: return self.state.compilation_config_serializer def describe_database_dump( self, user_schema_json: bytes, global_schema_json: bytes, db_config_json: bytes, protocol_version: defines.ProtocolVersion, with_secrets: bool, ) -> DumpDescriptor: global_schema = self.parse_json_schema(global_schema_json, None) user_schema = self.parse_json_schema(user_schema_json, global_schema) database_config = self.parse_db_config(db_config_json, user_schema) schema = s_schema.ChainedSchema( self.state.std_schema, user_schema, global_schema ) sys_config_ddl = config.to_edgeql( self.state.config_spec, database_config, with_secrets=with_secrets, ) # We need to put extension DDL configs *after* we have # reloaded the schema user_config_ddl = config.to_edgeql( config.load_ext_spec_from_schema( user_schema, self.state.std_schema), database_config, with_secrets=with_secrets, ) schema_ddl = s_ddl.ddl_text_from_schema( schema, include_migrations=True) ids, sequences = get_obj_ids(schema) raw_ids = [(name, cls, id.bytes) for name, cls, id in ids] objtypes = schema.get_objects( type=s_objtypes.ObjectType, exclude_stdlib=True, ) descriptors = [] cfg_object = schema.get('cfg::ConfigObject', type=s_objtypes.ObjectType) for objtype in objtypes: if objtype.is_union_type(schema) or objtype.is_view(schema): continue if objtype.issubclass(schema, cfg_object): continue descriptors.extend(_describe_object(schema, objtype, protocol_version)) dynamic_ddl = [] if sequences: seq_ids = ', 
'.join( pg_common.quote_literal(str(seq_id)) for seq_id in sequences ) dynamic_ddl.append( f'SELECT edgedb._dump_sequences(ARRAY[{seq_ids}]::uuid[])' ) return DumpDescriptor( schema_ddl='\n'.join([sys_config_ddl, schema_ddl, user_config_ddl]), schema_dynamic_ddl=tuple(dynamic_ddl), schema_ids=raw_ids, blocks=descriptors, ) def _reprocess_restore_config( self, stmts: Iterable[qlast.Base], ) -> list[qlast.Base]: '''Do any rewrites to the restore script needed. This is intended to patch over certain backwards incompatible changes to config. We try not to do that too much, but when we do, dumps still need to work. ''' new_stmts = [] smtp_config = {} for stmt in stmts: # ext::auth::SMTPConfig got removed and moved into a cfg # object, so intercept those and rewrite them. if ( isinstance(stmt, qlast.ConfigSet) and stmt.name.module == 'ext::auth::SMTPConfig' ): smtp_config[stmt.name.name] = stmt.expr else: new_stmts.append(stmt) if smtp_config: # Do the rewrite of SMTPConfig smtp_config['name'] = qlast.Constant.string('_default') new_stmts.append( qlast.ConfigInsert( scope=qltypes.ConfigScope.DATABASE, name=qlast.ObjectRef( module='cfg', name='SMTPProviderConfig' ), shape=[ qlast.ShapeElement( expr=qlast.Path(steps=[qlast.Ptr(name=name)]), compexpr=expr, ) for name, expr in smtp_config.items() ], ) ) new_stmts.append( qlast.ConfigSet( scope=qltypes.ConfigScope.DATABASE, name=qlast.ObjectRef( name='current_email_provider_name' ), expr=qlast.Constant.string('_default'), ) ) return new_stmts def describe_database_restore( self, user_schema_pickle: bytes, global_schema_pickle: bytes, dump_server_ver_str: str, dump_catalog_version: Optional[int], schema_ddl: bytes, schema_ids: list[tuple[str, str, bytes]], blocks: list[tuple[bytes, bytes]], # type_id, typespec protocol_version: defines.ProtocolVersion, ) -> RestoreDescriptor: schema_object_ids = { ( s_name.name_from_string(name), qltype if qltype else None ): uuidgen.from_bytes(objid) for name, qltype, objid in schema_ids } 
dump_server_ver = verutils.parse_version(dump_server_ver_str) # catalog_version didn't exist until late in the 3.0 cycle, # but we can just treat that as being version 0 dump_catalog_version = dump_catalog_version or 0 state = dbstate.CompilerConnectionState( user_schema=pickle.loads(user_schema_pickle), global_schema=pickle.loads(global_schema_pickle), modaliases=DEFAULT_MODULE_ALIASES_MAP, session_config=EMPTY_MAP, database_config=EMPTY_MAP, system_config=EMPTY_MAP, cached_reflection=EMPTY_MAP, ) ctx = CompileContext( compiler_state=self.state, state=state, output_format=enums.OutputFormat.BINARY, expected_cardinality_one=False, compat_ver=dump_server_ver, schema_object_ids=schema_object_ids, log_ddl_as_migrations=False, protocol_version=protocol_version, dump_restore_mode=True, ) ctx.state.start_tx() dump_with_extraneous_computables = ( dump_server_ver < (1, 0, verutils.VersionStage.ALPHA, 8) ) dump_with_ptr_item_id = dump_with_extraneous_computables allow_dml_in_functions = ( dump_server_ver < (1, 0, verutils.VersionStage.BETA, 1) ) # This change came late in the 3.0 dev cycle, and with it we # switched to using catalog versions for this, so that nightly # dumps might work. 
dump_with_dunder_type = ( dump_catalog_version < 2023_02_16_00_00 ) schema_ddl_text = schema_ddl.decode('utf-8') if allow_dml_in_functions: schema_ddl_text = ( 'CONFIGURE CURRENT DATABASE ' 'SET allow_dml_in_functions := true;\n' + schema_ddl_text ) ddl_source = edgeql.Source.from_string(schema_ddl_text) # The state serializer generated below is somehow inappropriate, # so it's simply ignored here and the I/O process will do it on its own commands = edgeql.parse_block(ddl_source) statements = self._reprocess_restore_config(commands) units = _try_compile_ast( ctx=ctx, source=ddl_source, statements=statements ).units _check_force_database_error(ctx, scope='restore') schema = ctx.state.current_tx().get_schema( ctx.compiler_state.std_schema) # The AI extension needs to run some code before restoring data. # TODO: Generalize this mechanism. if schema.get_global(s_ext.Extension, 'ai', default=None): from edb.pgsql import delta_ext_ai ddl_source = edgeql.Source.from_string( delta_ext_ai.get_ext_ai_pre_restore_script(schema)) units += compile(ctx=ctx, source=ddl_source).units restore_blocks = [] tables = [] repopulate_units = [] for schema_object_id_bytes, typedesc in blocks: schema_object_id = uuidgen.from_bytes(schema_object_id_bytes) obj = schema.get_by_id(schema_object_id) desc = sertypes.parse(typedesc, protocol_version) elided_col_set = set() mending_desc: list[Optional[DataMendingDescriptor]] = [] if isinstance(obj, s_props.Property): assert isinstance(desc, sertypes.NamedTupleDesc) desc_ptrs = list(desc.fields.keys()) cols = { 'source': 'source', 'target': 'target', } mending_desc.append(None) mending_desc.append(_get_ptr_mending_desc(schema, obj)) if dump_with_ptr_item_id: elided_col_set.add('ptr_item_id') mending_desc.append(None) elif isinstance(obj, s_links.Link): assert isinstance(desc, sertypes.NamedTupleDesc) desc_ptrs = list(desc.fields.keys()) cols = {} ptrs = dict(obj.get_pointers(schema).items(schema)) for ptr_name in desc_ptrs: if dump_with_ptr_item_id 
and ptr_name == 'ptr_item_id': elided_col_set.add(ptr_name) cols[ptr_name] = ptr_name mending_desc.append(None) else: ptr = ptrs[s_name.UnqualName(ptr_name)] if ( dump_with_extraneous_computables and ptr.is_pure_computable(schema) ): elided_col_set.add(ptr_name) mending_desc.append(None) if not ptr.is_dumpable(schema): continue stor_info = pg_types.get_pointer_storage_info( ptr, schema=schema, source=obj, link_bias=True, ) cols[ptr_name] = stor_info.column_name mending_desc.append( _get_ptr_mending_desc(schema, ptr)) elif isinstance(obj, s_objtypes.ObjectType): assert isinstance(desc, sertypes.ShapeDesc) desc_ptrs = list(desc.fields.keys()) cols = {} ptrs = dict(obj.get_pointers(schema).items(schema)) addons = { name: (col, type) for name, col, type in obj.get_addon_columns(schema) } for ptr_name in desc_ptrs: # If the pointer was one of our "addon columns" # (fts and ai shadow index columns), restore it # directly. # # N.B: This will need to become more sophisticated # if (when) we change the naming of any of our # addons. 
if ptr_name in addons: col, _type = addons[ptr_name] cols[ptr_name] = col mending_desc.append(None) continue ptr = ptrs[s_name.UnqualName(ptr_name)] if ( dump_with_extraneous_computables and ptr.is_pure_computable(schema) ) or ( dump_with_dunder_type and ptr_name == '__type__' ): elided_col_set.add(ptr_name) mending_desc.append(None) if not ptr.is_dumpable(schema): continue stor_info = pg_types.get_pointer_storage_info( ptr, schema=schema, source=obj, ) if stor_info.table_type == 'ObjectType': ptr_name = ptr.get_shortname(schema).name cols[ptr_name] = stor_info.column_name mending_desc.append( _get_ptr_mending_desc(schema, ptr)) cmd = pg_delta.get_reindex_sql(obj, desc, schema) if cmd: repopulate_units.append(cmd) else: raise AssertionError( f'unexpected object type in restore ' f'type descriptor: {obj!r}' ) _check_dump_layout( frozenset(desc_ptrs), frozenset(cols), elided_col_set, label=obj.get_verbosename(schema, with_parent=True), ) table_name = pg_common.get_backend_name( schema, obj, catenate=True) elided_cols = tuple(i for i, pn in enumerate(desc_ptrs) if pn in elided_col_set) col_list = ( pg_common.quote_ident(cols[pn]) for pn in desc_ptrs if pn not in elided_col_set ) stmt = ( f'COPY {table_name} ' f'({", ".join(col_list)})' f'FROM STDIN WITH (FORMAT binary, FREEZE true)' ).encode() restore_blocks.append( RestoreBlockDescriptor( schema_object_id=schema_object_id, sql_copy_stmt=stmt, compat_elided_cols=elided_cols, data_mending_desc=tuple(mending_desc), ) ) tables.append(table_name) return RestoreDescriptor( units=units, blocks=restore_blocks, tables=tables, repopulate_units=repopulate_units, ) def analyze_explain_output( self, query_asts_pickled: bytes, data: list[list[bytes]], ) -> bytes: return explain.analyze_explain_output( query_asts_pickled, data, self.state.std_schema) def validate_schema_equivalence( self, schema_a: bytes, schema_b: bytes, global_schema: bytes, conn_state_pickle: Optional[bytes], ) -> None: if conn_state_pickle: conn_state = 
pickle.loads(conn_state_pickle) if ( conn_state and ( conn_state.current_tx().get_migration_state() or conn_state.current_tx().get_migration_rewrite_state() ) ): return ddl.validate_schema_equivalence( self.state, pickle.loads(schema_a), pickle.loads(schema_b), pickle.loads(global_schema), ) def compile_structured_config( self, objects: Mapping[str, config_compiler.ConfigObject], source: str | None = None, allow_nested: bool = False, ) -> dict[str, immutables.Map[str, config.SettingValue]]: # XXX: only config in the stdlib is supported currently, so the only # key allowed in objects is "cfg::Config". API for future compatibility if list(objects) != ["cfg::Config"]: difference = set(objects) - {"cfg::Config"} raise NotImplementedError( f"unsupported config: {', '.join(difference)}" ) return config_compiler.compile_structured_config( objects, spec=self.state.config_spec, schema=self.state.std_schema, source=source, allow_nested=allow_nested, ) def compile_schema_storage_in_delta( ctx: CompileContext, delta: s_delta.Command, block: pg_dbops.SQLBlock, context: Optional[s_delta.CommandContext] = None, ) -> None: current_tx = ctx.state.current_tx() schema = current_tx.get_schema(ctx.compiler_state.std_schema) funcblock = block.add_block() cmdblock = block.add_block() meta_blocks: list[tuple[str, dict[str, Any]]] = [] # Use a provided context if one was passed in, which lets us # use the cached values for resolved properties. (Which is # important, since if there were renames we won't necessarily # be able to resolve them just using the new schema.) 
if not context: context = s_delta.CommandContext() else: context.renames.clear() context.early_renames.clear() s_refl.generate_metadata_write_edgeql( delta, classlayout=ctx.compiler_state.schema_class_layout, schema=schema, context=context, blocks=meta_blocks, internal_schema_mode=ctx.internal_schema_mode, stdmode=ctx.bootstrap_mode, ) cache = current_tx.get_cached_reflection() with cache.mutate() as cache_mm: for eql, args in meta_blocks: eql_hash = hashlib.sha1(eql.encode()).hexdigest() fname = (pg_common.versioned_schema('edgedb'), f'__rh_{eql_hash}') if eql_hash in cache_mm: argnames = cache_mm[eql_hash] else: sql, argmap = _compile_schema_storage_stmt(ctx, eql) argnames = tuple(arg.name for arg in argmap) func = pg_dbops.Function( name=fname, args=[(argname, 'json') for argname in argnames], returns='json', text=sql, ) # We drop first instead of using or_replace, in case # something about the arguments changed. df = pg_dbops.DropFunction( name=func.name, args=func.args or (), # Use a condition instead of if_exists to reduce annoying # debug spew from postgres. conditions=[pg_dbops.FunctionExists( name=func.name, args=func.args or (), )], ) df.generate(funcblock) cf = pg_dbops.CreateFunction(func) cf.generate(funcblock) cache_mm[eql_hash] = argnames argvals = [] for argname in argnames: argvals.append(pg_common.quote_literal(args[argname])) cmdblock.add_command(textwrap.dedent(f'''\ PERFORM {pg_common.qname(*fname)}({", ".join(argvals)}); ''')) ctx.state.current_tx().update_cached_reflection(cache_mm.finish()) def _compile_schema_storage_stmt( ctx: CompileContext, eql: str, output_format: enums.OutputFormat = enums.OutputFormat.JSON, ) -> tuple[str, Sequence[dbstate.Param]]: schema = ctx.state.current_tx().get_schema(ctx.compiler_state.std_schema) try: # Switch to the shadow introspection/reflection schema. ctx.state.current_tx().update_schema( # Trick dbstate to set the effective schema # to refl_schema. 
s_schema.ChainedSchema( ctx.compiler_state.std_schema, ctx.compiler_state.refl_schema, s_schema.EMPTY_SCHEMA ) ) newctx = CompileContext( compiler_state=ctx.compiler_state, state=ctx.state, json_parameters=True, schema_reflection_mode=True, output_format=output_format, expected_cardinality_one=False, bootstrap_mode=ctx.bootstrap_mode, protocol_version=ctx.protocol_version, backend_runtime_params=ctx.backend_runtime_params, ) source = edgeql.Source.from_string(eql) unit_group = compile(ctx=newctx, source=source) sql_stmts = [] for u in unit_group: stmt = u.sql.strip() if not stmt.endswith(b';'): stmt += b';' sql_stmts.append(stmt) if len(sql_stmts) > 1: raise errors.InternalServerError( 'compilation of schema update statement' ' yielded more than one SQL statement' ) sql = sql_stmts[0].strip(b';').decode() argmap: Optional[Sequence[dbstate.Param]] = unit_group[0].in_type_args if argmap is None: argmap = () return sql, argmap finally: # Restore the regular schema. ctx.state.current_tx().update_schema(schema) def _get_schema_version(user_schema: s_schema.Schema) -> uuid.UUID: ver = user_schema.get_global(s_ver.SchemaVersion, "__schema_version__") return ver.get_version(user_schema) def _compile_ql_script( ctx: CompileContext, eql: str, ) -> str: source = edgeql.Source.from_string(eql) unit_group = compile(ctx=ctx, source=source) sql_stmts = [] for u in unit_group: stmt = u.sql.strip() if not stmt.endswith(b';'): stmt += b';' sql_stmts.append(stmt) return b'\n'.join(sql_stmts).decode() def _get_compile_options( ctx: CompileContext, *, is_explain: bool = False, no_implicit_fields: bool = False, ) -> qlcompiler.CompilerOptions: can_have_implicit_fields = not no_implicit_fields and ( ctx.output_format is enums.OutputFormat.BINARY ) return qlcompiler.CompilerOptions( modaliases=ctx.state.current_tx().get_modaliases(), implicit_tid_in_shapes=( can_have_implicit_fields and ctx.inline_typeids ), implicit_tname_in_shapes=( can_have_implicit_fields and ctx.inline_typenames ), 
implicit_id_in_shapes=( can_have_implicit_fields and ctx.inline_objectids ), json_parameters=ctx.json_parameters, implicit_limit=ctx.implicit_limit, bootstrap_mode=ctx.bootstrap_mode, dump_restore_mode=ctx.dump_restore_mode, apply_query_rewrites=( not ctx.bootstrap_mode and not ctx.schema_reflection_mode and not ctx.dump_restore_mode # HMMM and not bool( _get_config_val(ctx, '__internal_no_apply_query_rewrites')) ), apply_user_access_policies=_get_config_val( ctx, 'apply_access_policies'), allow_user_specified_id=_get_config_val( ctx, 'allow_user_specified_id') or ctx.schema_reflection_mode, is_explain=is_explain, testmode=ctx.is_testmode(), schema_reflection_mode=( ctx.schema_reflection_mode or _get_config_val(ctx, '__internal_query_reflschema') ), ) # Types and default values for EXPLAIN parameters EXPLAIN_PARAMS = dict( buffers=('std::bool', False), execute=('std::bool', True), ) def _compile_ql_explain( ctx: CompileContext, ql: qlast.ExplainStmt, *, script_info: Optional[irast.ScriptInfo] = None, ) -> dbstate.BaseQuery: args = {k: v for k, (_, v) in EXPLAIN_PARAMS.items()} current_tx = ctx.state.current_tx() schema = current_tx.get_schema(ctx.compiler_state.std_schema) # Evaluate and typecheck arguments if ql.args: for el in ql.args.elements: name = el.name.name if name not in EXPLAIN_PARAMS: raise errors.QueryError( f"unknown ANALYZE argument '{name}'", span=el.span, ) arg_ir = qlcompiler.compile_ast_to_ir( el.val, schema=schema, options=qlcompiler.CompilerOptions( modaliases=current_tx.get_modaliases(), ), ) exp_typ = schema.get(EXPLAIN_PARAMS[name][0], type=s_types.Type) if not arg_ir.stype.issubclass(schema, exp_typ): raise errors.QueryError( f"incorrect type for ANALYZE argument '{name}': " f"expected '{exp_typ.get_name(schema)}', " f"got '{arg_ir.stype.get_name(schema)}'", span=el.span, ) args[name] = ireval.evaluate_to_python_val(arg_ir.expr, schema) analyze = 'ANALYZE true, ' if args['execute'] else '' buffers = 'BUFFERS, ' if args['buffers'] else '' 
exp_command = f'EXPLAIN ({analyze}{buffers}FORMAT JSON, VERBOSE true)' ctx = dataclasses.replace( ctx, inline_typeids=False, inline_typenames=False, implicit_limit=0, output_format=enums.OutputFormat.BINARY, ) config_vals = _get_compilation_config_vals(ctx) modaliases = ctx.state.current_tx().get_modaliases() explain_data = (config_vals, args, modaliases) query = _compile_ql_query( ctx, ql.query, script_info=script_info, explain_data=explain_data, cacheable=False) if isinstance(query, dbstate.NullQuery): raise errors.QueryError( f"cannot ANALYZE inside of a migration", span=ql.span, ) assert query.sql out_type_data, out_type_id = sertypes.describe( schema, schema.get("std::str", type=s_types.Type), protocol_version=ctx.protocol_version, ) sql_bytes = exp_command.encode('utf-8') + query.sql sql_hash = _hash_sql( sql_bytes, mode=str(ctx.output_format).encode(), intype=query.in_type_id, outtype=out_type_id.bytes) return dataclasses.replace( query, is_explain=True, run_and_rollback=args['execute'], cacheable=False, sql=sql_bytes, sql_hash=sql_hash, cardinality=enums.Cardinality.ONE, out_type_data=out_type_data, out_type_id=out_type_id.bytes, ) def _compile_ql_administer( ctx: CompileContext, ql: qlast.AdministerStmt, *, script_info: Optional[irast.ScriptInfo] = None, ) -> dbstate.BaseQuery: if ql.expr.func == 'statistics_update': res = ddl.administer_statistics_update(ctx, ql) elif ql.expr.func == 'schema_repair': res = ddl.administer_repair_schema(ctx, ql) elif ql.expr.func == 'reindex': res = ddl.administer_reindex(ctx, ql) elif ql.expr.func == 'vacuum': res = ddl.administer_vacuum(ctx, ql) elif ql.expr.func == 'prepare_upgrade': res = ddl.administer_prepare_upgrade(ctx, ql) elif ql.expr.func == '_remove_pointless_triggers': res = ddl.administer_remove_pointless_triggers(ctx, ql) elif ql.expr.func == 'concurrent_index_build': res = ddl.administer_concurrent_index_build(ctx, ql) elif ql.expr.func == 'fixup_backend_upgrade': res = 
ddl.administer_fixup_backend_upgrade(ctx, ql) else: raise errors.QueryError( 'Unknown ADMINISTER function', span=ql.expr.span, ) if debug.flags.delta_execute or debug.flags.delta_execute_ddl: debug.header('ADMINISTER script') debug.dump_code(res.sql, lexer='sql') return res def _compile_ql_query( ctx: CompileContext, ql: qlast.Query | qlast.Command, *, script_info: Optional[irast.ScriptInfo] = None, source: Optional[edgeql.Source] = None, cacheable: bool = True, migration_block_query: bool = False, explain_data: object = None, ) -> dbstate.Query | dbstate.NullQuery: is_explain = explain_data is not None current_tx = ctx.state.current_tx() sql_info: dict[str, Any] = {} if ( not ctx.bootstrap_mode and ctx.backend_runtime_params.has_stat_statements and not ctx.schema_reflection_mode ): spec = ctx.compiler_state.config_spec cconfig = config.to_json_obj( spec, { **current_tx.get_system_config(), **current_tx.get_database_config(), **current_tx.get_session_config(), }, setting_filter=lambda v: v.name in spec and spec[v.name].affects_compilation, include_source=False, ) extras: dict[str, Any] = { 'cc': dict(sorted(cconfig.items())), # compilation_config 'pv': ctx.protocol_version, # protocol_version 'of': ctx.output_format, # output_format 'e1': ctx.expected_cardinality_one, # expect_one 'il': ctx.implicit_limit, # implicit_limit 'ii': ctx.inline_typeids, # inline_typeids 'in': ctx.inline_typenames, # inline_typenames 'io': ctx.inline_objectids, # inline_objectids } modaliases = dict(current_tx.get_modaliases()) # dn: default_namespace extras['dn'] = modaliases.pop(None, defines.DEFAULT_MODULE_ALIAS) if modaliases: # na: namespace_aliases extras['na'] = dict(sorted(modaliases.items())) sql_info.update({ 'query': qlcodegen.generate_source(ql), 'type': defines.QueryType.EdgeQL, 'extras': json.dumps(extras), }) id_hash = hashlib.blake2b(digest_size=16) id_hash.update( json.dumps(sql_info).encode(defines.EDGEDB_ENCODING) ) sql_info['id'] = 
str(uuidgen.from_bytes(id_hash.digest())) base_schema = ( ctx.compiler_state.std_schema if not _get_config_val(ctx, '__internal_query_reflschema') else ctx.compiler_state.refl_schema ) schema = current_tx.get_schema(base_schema) options = _get_compile_options(ctx, is_explain=is_explain) ir = qlcompiler.compile_ast_to_ir( ql, schema=schema, script_info=script_info, options=options, ) result_cardinality = enums.cardinality_from_ir_value(ir.cardinality) # This low-hanging-fruit is temporary; persistent cache should cover all # cacheable cases properly in future changes. use_persistent_cache = ( cacheable and not ctx.bootstrap_mode and script_info is None and ctx.cache_key is not None ) cache_mode = ctx.get_cache_mode() sql_res = pg_compiler.compile_ir_to_sql_tree( ir, expected_cardinality_one=ctx.expected_cardinality_one, output_format=_convert_format(ctx.output_format), json_parameters=options.json_parameters, backend_runtime_params=ctx.backend_runtime_params, is_explain=options.is_explain, cache_as_function=(use_persistent_cache and cache_mode is config.QueryCacheMode.PgFunc), versioned_stdlib=True, ) sql_text = pg_codegen.generate_source(sql_res.ast) func_call_sql = None pg_debug.dump_ast_and_query(sql_res.ast, ir) if use_persistent_cache and cache_mode is config.QueryCacheMode.PgFunc: cache_sql, func_call_ast = _build_cache_function(ctx, ir, sql_res) func_call_sql = pg_codegen.generate_source(func_call_ast) elif ( use_persistent_cache and cache_mode is config.QueryCacheMode.RegInline ): cache_sql = (b"", b"") else: cache_sql = None if ( (mstate := current_tx.get_migration_state()) and not migration_block_query ): if isinstance(ql, qlast.Query): mstate = mstate._replace( accepted_cmds=( mstate.accepted_cmds + (qlast.DDLQuery(query=ql),) ) ) current_tx.update_migration_state(mstate) return dbstate.NullQuery() # If requested, embed the EdgeQL text in the SQL. 
if debug.flags.edgeql_text_in_sql and source: sql_info['edgeql'] = source.text() if sql_info: sql_info_prefix = '-- ' + json.dumps(sql_info) + '\n' else: sql_info_prefix = '' globals = None permissions = None json_permissions = None if ir.globals: globals = [ (str(glob.global_name), glob.has_present_arg) for glob in ir.globals if not glob.is_permission ] permissions = [ str(glob.global_name) for glob in ir.globals if glob.is_permission ] if options.json_parameters: # In JSON parameters mode, keep only the synthetic globals, # and report the permissions as needing to be injected into # the JSON. if globals: globals = [g for g in globals if g[0].startswith('__::')] json_permissions, permissions = permissions, [] required_permissions = None if ir.required_permissions: required_permissions = [ str(perm.get_name(schema)) for perm in ir.required_permissions ] out_type_id: uuid.UUID if ctx.output_format is enums.OutputFormat.NONE: out_type_id = sertypes.NULL_TYPE_ID out_type_data = sertypes.NULL_TYPE_DESC result_cardinality = enums.Cardinality.NO_RESULT elif ctx.output_format is enums.OutputFormat.BINARY: out_type_data, out_type_id = sertypes.describe( ir.schema, ir.stype, ir.view_shapes, ir.view_shapes_metadata, inline_typenames=ctx.inline_typenames, protocol_version=ctx.protocol_version) else: out_type_data, out_type_id = sertypes.describe( ir.schema, ir.schema.get("std::str", type=s_types.Type), protocol_version=ctx.protocol_version, ) in_type_args, in_type_data, in_type_id = describe_params( ctx, ir, sql_res.argmap, script_info ) server_param_conversions: Optional[ list[dbstate.ServerParamConversion] ] = None if isinstance(ir, irast.Statement) and ir.server_param_conversions: # The irast.ServerParamConversion we get from the ql compiler contains # either a script_param_index or a constant value. # # A script_param_index can refer to either an actual query param or a # constant that was normalized out of a query. # # A constant value is used as is by the server. 
server_param_conversions = [ dbstate.ServerParamConversion( param_name=p.param_name, conversion_name=p.conversion_name, additional_info=p.additional_info, script_param_index=p.script_param_index, constant_value=p.constant_value ) for p in ir.server_param_conversions ] sql_hash = _hash_sql( sql_text.encode(defines.EDGEDB_ENCODING), mode=str(ctx.output_format).encode(), intype=in_type_id.bytes, outtype=out_type_id.bytes) cache_func_call = None if func_call_sql is not None: func_call_sql_hash = _hash_sql( func_call_sql.encode(defines.EDGEDB_ENCODING), mode=str(ctx.output_format).encode(), intype=in_type_id.bytes, outtype=out_type_id.bytes, ) cache_func_call = ( (sql_info_prefix + func_call_sql).encode(defines.EDGEDB_ENCODING), func_call_sql_hash, ) if is_explain: if isinstance(ir.schema, s_schema.ChainedSchema): # Strip the std schema out ir.schema = s_schema.ChainedSchema( top_schema=ir.schema._top_schema, global_schema=ir.schema._global_schema, base_schema=s_schema.EMPTY_SCHEMA, ) query_asts = pickle.dumps((ql, ir, sql_res.ast, explain_data)) else: query_asts = None return dbstate.Query( sql=(sql_info_prefix + sql_text).encode(defines.EDGEDB_ENCODING), sql_hash=sql_hash, cache_sql=cache_sql, cache_func_call=cache_func_call, cardinality=result_cardinality, globals=globals, permissions=permissions, required_permissions=required_permissions, json_permissions=json_permissions, in_type_id=in_type_id.bytes, in_type_data=in_type_data, in_type_args=in_type_args, out_type_id=out_type_id.bytes, out_type_data=out_type_data, server_param_conversions=server_param_conversions, cacheable=cacheable, has_dml=bool(ir.dml_exprs), query_asts=query_asts, warnings=ir.warnings, unsafe_isolation_dangers=ir.unsafe_isolation_dangers, ) def _build_cache_function( ctx: CompileContext, ir: irast.Statement, sql_res: pg_compiler.CompileResult, ) -> tuple[tuple[bytes, bytes], pgast.Base]: sql_ast = sql_res.ast assert ctx.cache_key is not None key = ctx.cache_key.hex returns_record = False 
set_returning = True return_type: tuple[str, ...] = ("unknown",) match ctx.output_format: case enums.OutputFormat.NONE: # CONFIGURE commands are never cached; other queries are actually # wrapped with a count() call in top_output_as_value(), so set the # return_type to reflect that fact. This was set to `void`, leading # to issues that certain exceptions are not raised as expected when # wrapped with a function returning (setof) void - reproducible # with test_edgeql_casts_json_12() and EDGEDB_TEST_REPEATS=1. return_type = ("int",) if ir.stype.is_object_type() or ir.stype.is_tuple(ir.schema): returns_record = True case enums.OutputFormat.BINARY: if ir.stype.is_object_type(): return_type = ("record",) returns_record = True else: return_type = pg_types.pg_type_from_ir_typeref( ir.expr.typeref.base_type or ir.expr.typeref, serialized=True, ) if ir.stype.is_tuple(ir.schema): returns_record = return_type == ('record',) case enums.OutputFormat.JSON: return_type = ("json",) case enums.OutputFormat.JSON_ELEMENTS: return_type = ("json",) if returns_record: assert isinstance(sql_ast, pgast.ReturningQuery) sql_ast.target_list.append( pgast.ResTarget( name="sentinel", val=pgast.BooleanConstant(val=True), ), ) # XXX: we need to put the version in the key fname = (pg_common.versioned_schema("edgedb"), f"__qh_{key}") func = pg_dbops.Function( name=fname, args=[(None, arg) for arg in sql_res.cached_params or []], returns=return_type, set_returning=set_returning, text=pg_codegen.generate_source(sql_ast), ) if not ir.dml_exprs: func.volatility = "stable" cf = pg_dbops.SQLBlock() pg_dbops.CreateFunction(func).generate(cf) df = pg_dbops.SQLBlock() pg_dbops.DropFunction( name=func.name, args=func.args or (), # Use a condition instead of if_exists to reduce annoying # debug spew from postgres. 
conditions=[pg_dbops.FunctionExists( name=func.name, args=func.args or (), )], ).generate(df) func_call = pgast.FuncCall( name=fname, args=[ pgast.TypeCast( arg=pgast.ParamRef(number=i), type_name=pgast.TypeName(name=arg), ) for i, arg in enumerate(sql_res.cached_params or [], 1) ], coldeflist=[], ) if returns_record: func_call.coldeflist.extend( [ pgast.ColumnDef( name="result", typename=pgast.TypeName(name=("record",)), ), pgast.ColumnDef( name="sentinel", typename=pgast.TypeName(name=("bool",)), ), ] ) sql_ast = pgast.SelectStmt( target_list=[ pgast.ResTarget(val=pgast.ColumnRef(name=("result",))), ], from_clause=[ pgast.RangeFunction( functions=[func_call], is_rowsfrom=True, ), ], ) else: sql_ast = pgast.SelectStmt( target_list=[pgast.ResTarget(val=func_call)], ) cache_sql = ( cf.to_string().encode(defines.EDGEDB_ENCODING), df.to_string().encode(defines.EDGEDB_ENCODING), ) return cache_sql, sql_ast def describe_params( ctx: CompileContext, ir: irast.Statement | irast.ConfigCommand, argmap: dict[str, pgast.Param], script_info: Optional[irast.ScriptInfo], ) -> tuple[Optional[list[dbstate.Param]], bytes, uuid.UUID]: in_type_args = None params: list[tuple[str, s_types.Type, bool]] = [] assert ir.schema if ir.params: params, in_type_args = _extract_params( ir.params, argmap=argmap, script_info=script_info, schema=ir.schema, ctx=ctx, ) in_type_data, in_type_id = sertypes.describe_params( schema=ir.schema, params=params, protocol_version=ctx.protocol_version, ) return in_type_args, in_type_data, in_type_id def _compile_ql_transaction( ctx: CompileContext, ql: qlast.Transaction ) -> dbstate.TxControlQuery: cacheable = True modaliases = None final_user_schema: Optional[s_schema.Schema] = None final_cached_reflection = None final_global_schema: Optional[s_schema.Schema] = None sp_name = None sp_id = None iso = None if ctx.expect_rollback and not isinstance( ql, (qlast.RollbackTransaction, qlast.RollbackToSavepoint) ): raise errors.TransactionError( 'expected a ROLLBACK 
or ROLLBACK TO SAVEPOINT command' ) if isinstance(ql, qlast.StartTransaction): ctx._assert_not_in_migration_block(ql) ctx.state.start_tx() # Compute the effective isolation level iso_config: statypes.TransactionIsolation = _get_config_val( ctx, "default_transaction_isolation" ) default_iso = iso_config.to_qltypes() iso = ql.isolation if iso is None: iso = default_iso # Compute the effective access mode access = ql.access if access is None: access_mode: statypes.TransactionAccessMode = _get_config_val( ctx, "default_transaction_access_mode" ) access = access_mode.to_qltypes() sqls = f'START TRANSACTION ISOLATION LEVEL {iso.value} {access.value}' if ql.deferrable is not None: sqls += f' {ql.deferrable.value}' sqls += ';' sql = sqls.encode() action = dbstate.TxAction.START cacheable = False elif isinstance(ql, qlast.CommitTransaction): ctx._assert_not_in_migration_block(ql) cur_tx = ctx.state.current_tx() final_user_schema = cur_tx.get_user_schema_if_updated() final_cached_reflection = cur_tx.get_cached_reflection_if_updated() final_global_schema = cur_tx.get_global_schema_if_updated() new_state = ctx.state.commit_tx() modaliases = new_state.modaliases sql = b'COMMIT' cacheable = False action = dbstate.TxAction.COMMIT elif isinstance(ql, qlast.RollbackTransaction): new_state = ctx.state.rollback_tx() modaliases = new_state.modaliases sql = b'ROLLBACK' cacheable = False action = dbstate.TxAction.ROLLBACK elif isinstance(ql, qlast.DeclareSavepoint): tx = ctx.state.current_tx() sp_id = tx.declare_savepoint(ql.name) pgname = pg_common.quote_ident(ql.name) sql = f'SAVEPOINT {pgname}'.encode() cacheable = False action = dbstate.TxAction.DECLARE_SAVEPOINT sp_name = ql.name elif isinstance(ql, qlast.ReleaseSavepoint): ctx.state.current_tx().release_savepoint(ql.name) pgname = pg_common.quote_ident(ql.name) sql = f'RELEASE SAVEPOINT {pgname}'.encode() action = dbstate.TxAction.RELEASE_SAVEPOINT elif isinstance(ql, qlast.RollbackToSavepoint): tx = ctx.state.current_tx() 
new_state = tx.rollback_to_savepoint(ql.name) modaliases = new_state.modaliases pgname = pg_common.quote_ident(ql.name) sql = f'ROLLBACK TO SAVEPOINT {pgname};'.encode() cacheable = False action = dbstate.TxAction.ROLLBACK_TO_SAVEPOINT sp_name = ql.name else: # pragma: no cover raise ValueError(f'expected a transaction AST node, got {ql!r}') return dbstate.TxControlQuery( sql=sql, action=action, cacheable=cacheable, modaliases=modaliases, user_schema=final_user_schema, cached_reflection=final_cached_reflection, global_schema=final_global_schema, sp_name=sp_name, sp_id=sp_id, isolation_level=iso, feature_used_metrics=( ddl.produce_feature_used_metrics( ctx.compiler_state, final_user_schema ) if final_user_schema else None ), ) def _compile_ql_sess_state( ctx: CompileContext, ql: qlast.SessionCommand ) -> dbstate.SessionStateQuery: current_tx = ctx.state.current_tx() schema = current_tx.get_schema(ctx.compiler_state.std_schema) aliases = ctx.state.current_tx().get_modaliases() if isinstance(ql, qlast.SessionSetAliasDecl): try: schema.get_global(s_mod.Module, ql.decl.module) except errors.InvalidReferenceError: raise errors.UnknownModuleError( f'module {ql.decl.module!r} does not exist' ) from None aliases = aliases.set(ql.decl.alias, ql.decl.module) elif isinstance(ql, qlast.SessionResetModule): aliases = aliases.set(None, s_mod.DEFAULT_MODULE_ALIAS) elif isinstance(ql, qlast.SessionResetAllAliases): aliases = DEFAULT_MODULE_ALIASES_MAP elif isinstance(ql, qlast.SessionResetAliasDecl): aliases = aliases.delete(ql.alias) else: # pragma: no cover raise errors.InternalServerError( f'unsupported SET command type {type(ql)!r}') ctx.state.current_tx().update_modaliases(aliases) return dbstate.SessionStateQuery() def _get_config_spec( ctx: CompileContext, config_op: config.Operation ) -> config.Spec: config_spec = ctx.compiler_state.config_spec if config_op.setting_name not in config_spec: # We don't typically bother tracking the user config spec in # the compiler workers 
(to avoid needing to bother with # transmitting, caching, or computing it). If we hit a config # op that needs it, load the spec. config_spec = config.ChainedSpec( config_spec, config.load_ext_spec_from_schema( ctx.state.current_tx().get_user_schema(), ctx.compiler_state.std_schema, ), ) return config_spec def _inject_config_cache_clear(sql_ast: pgast.Base) -> pgast.Base: """Inject a call to clear the config cache into a config op. The trickiness here is that we can't just do the delete in a statement before the config op, since RESET config ops query the views and so might populate the cache, and we can't do it in a statement directly after (unless we rework the server), since then the query won't return anything. So we instead fiddle around with the query to inject a call. """ assert isinstance(sql_ast, pgast.Query) ctes = sql_ast.ctes or [] sql_ast.ctes = None ctes.append(pgast.CommonTableExpr( name="_conv_rel", query=sql_ast, )) clear_qry = pgast.SelectStmt( target_list=[ pgast.ResTarget( name="_dummy", val=pgast.FuncCall( name=('edgedb', '_clear_sys_config_cache'), args=[], ), ), ], ) ctes.append(pgast.CommonTableExpr( name="_clear_cache", query=clear_qry, materialized=True, )) force_qry = pgast.UpdateStmt( targets=[pgast.UpdateTarget( name='flag', val=pgast.BooleanConstant(val=True) )], relation=pgast.RelRangeVar(relation=pgast.Relation( name='_dml_dummy')), where_clause=pgast.Expr( name="=", lexpr=pgast.ColumnRef(name=["id"]), rexpr=pgast.SelectStmt( from_clause=[pgast.RelRangeVar(relation=ctes[-1])], target_list=[ pgast.ResTarget( val=pgast.FuncCall( name=('count',), args=[pgast.Star()]), ) ], ), ) ) if ( not isinstance(sql_ast, pgast.DMLQuery) or sql_ast.returning_list ): ctes.append(pgast.CommonTableExpr( name="_force_clear", query=force_qry, materialized=True, )) sql_ast = pgast.SelectStmt( target_list=[ pgast.ResTarget(val=pgast.ColumnRef( name=["_conv_rel", pgast.Star()])), ], ctes=ctes, from_clause=[ pgast.RelRangeVar(relation=ctes[-3]), ], ) else: 
sql_ast = force_qry force_qry.ctes = ctes return sql_ast def _compile_ql_config_op( ctx: CompileContext, ql: qlast.ConfigOp ) -> dbstate.SessionStateQuery: current_tx = ctx.state.current_tx() schema = current_tx.get_schema(ctx.compiler_state.std_schema) session_config = current_tx.get_session_config() database_config = current_tx.get_database_config() if ql.scope is not qltypes.ConfigScope.SESSION: ctx._assert_not_in_migration_block(ql) if ( ql.scope is qltypes.ConfigScope.INSTANCE and not current_tx.is_implicit() ): raise errors.QueryError( 'CONFIGURE INSTANCE cannot be executed in a transaction block') options = _get_compile_options(ctx, no_implicit_fields=True) options.in_server_config_op = True ir = qlcompiler.compile_ast_to_ir( ql, schema=schema, options=options, ) globals = None if ir.globals: globals = [ (str(glob.global_name), glob.has_present_arg) for glob in ir.globals if not glob.is_permission ] if isinstance(ir, irast.Statement): cfg_ir = ir.expr.expr else: cfg_ir = ir is_backend_setting = bool(getattr(cfg_ir, 'backend_setting', None)) requires_restart = bool(getattr(cfg_ir, 'requires_restart', False)) is_system_config = bool(getattr(cfg_ir, 'is_system_config', False)) sql_res = pg_compiler.compile_ir_to_sql_tree( ir, backend_runtime_params=ctx.backend_runtime_params, ) sql_ast = sql_res.ast if not ctx.bootstrap_mode and ql.scope in ( qltypes.ConfigScope.DATABASE, qltypes.ConfigScope.SESSION, ): sql_ast = _inject_config_cache_clear(sql_ast) pretty = bool( debug.flags.edgeql_compile or debug.flags.edgeql_compile_sql_text) sql_text = pg_codegen.generate_source( sql_ast, pretty=pretty, ) if pretty: debug.dump_code(sql_text, lexer='sql') sql = sql_text.encode() in_type_args, in_type_data, in_type_id = describe_params( ctx, ir, sql_res.argmap, None ) if ql.scope is qltypes.ConfigScope.SESSION: config_op = ireval.evaluate_to_config_op(ir, schema=schema) session_config = config_op.apply( _get_config_spec(ctx, config_op), session_config, ) 
current_tx.update_session_config(session_config) elif ql.scope is qltypes.ConfigScope.DATABASE: try: config_op = ireval.evaluate_to_config_op(ir, schema=schema) except ireval.UnsupportedExpressionError: # This is a complex config object operation, the # op will be produced by the compiler as json. config_op = None else: database_config = config_op.apply( _get_config_spec(ctx, config_op), database_config, ) current_tx.update_database_config(database_config) elif ql.scope in ( qltypes.ConfigScope.INSTANCE, qltypes.ConfigScope.GLOBAL): try: config_op = ireval.evaluate_to_config_op(ir, schema=schema) except ireval.UnsupportedExpressionError: # This is a complex config object operation, the # op will be produced by the compiler as json. config_op = None else: raise AssertionError(f'unexpected configuration scope: {ql.scope}') return dbstate.SessionStateQuery( sql=sql, is_backend_setting=is_backend_setting, is_system_config=is_system_config, config_scope=ql.scope, requires_restart=requires_restart, config_op=config_op, globals=globals, in_type_args=in_type_args, in_type_data=in_type_data, in_type_id=in_type_id.bytes, ) def _compile_dispatch_ql( ctx: CompileContext, ql: qlast.Base, source: Optional[edgeql.Source] = None, *, in_script: bool=False, script_info: Optional[irast.ScriptInfo] = None, ) -> tuple[dbstate.BaseQuery, enums.Capability]: if isinstance(ql, qlast.MigrationCommand): query = ddl.compile_dispatch_ql_migration( ctx, ql, in_script=in_script ) if isinstance(query, dbstate.MigrationControlQuery): capability = enums.Capability.DDL if query.tx_action: capability |= enums.Capability.TRANSACTION return query, capability elif isinstance(query, dbstate.DDLQuery): return query, enums.Capability.DDL else: # DESCRIBE CURRENT MIGRATION return query, enums.Capability(0) elif isinstance(ql, qlast.DDLCommand): query = ddl.compile_and_apply_ddl_stmt(ctx, ql, source=source) capability = enums.Capability.DDL if isinstance(ql, (qlast.GlobalObjectCommand)): capability |= 
enums.Capability.GLOBAL_DDL return (query, capability) elif isinstance(ql, qlast.Transaction): return ( _compile_ql_transaction(ctx, ql), enums.Capability.TRANSACTION, ) elif isinstance(ql, qlast.SessionCommand): return ( _compile_ql_sess_state(ctx, ql), enums.Capability.SESSION_CONFIG, ) elif isinstance(ql, qlast.ConfigOp): if ql.scope is qltypes.ConfigScope.SESSION: capability = enums.Capability.SESSION_CONFIG elif ql.scope is qltypes.ConfigScope.GLOBAL: # We want the notebook protocol to be able to SET # GLOBAL but not CONFIGURE SESSION, but they are # merged in the capabilities header. Splitting them # out introduces compatibility headaches, so for now # we keep them merged and hack around it for the notebook. if ctx.notebook: capability = enums.Capability(0) else: capability = enums.Capability.SESSION_CONFIG elif ql.scope is qltypes.ConfigScope.DATABASE: capability = ( enums.Capability.PERSISTENT_CONFIG | enums.Capability.BRANCH_CONFIG ) else: capability = ( enums.Capability.PERSISTENT_CONFIG | enums.Capability.INSTANCE_CONFIG ) return ( _compile_ql_config_op(ctx, ql), capability, ) elif isinstance(ql, qlast.ExplainStmt): query = _compile_ql_explain(ctx, ql, script_info=script_info) caps = enums.Capability.ANALYZE if ( isinstance(query, (dbstate.Query, dbstate.SimpleQuery)) and query.has_dml ): caps |= enums.Capability.MODIFICATIONS return (query, caps) elif isinstance(ql, qlast.AdministerStmt): query = _compile_ql_administer(ctx, ql, script_info=script_info) caps = enums.Capability.ADMINISTER return (query, caps) else: assert isinstance(ql, (qlast.Query, qlast.Command)) query = _compile_ql_query( ctx, ql, source=source, script_info=script_info) caps = enums.Capability(0) if isinstance(ql, qlast.DescribeStmt): caps |= enums.Capability.DESCRIBE if ( isinstance(query, (dbstate.Query, dbstate.SimpleQuery)) and query.has_dml ): caps |= enums.Capability.MODIFICATIONS return (query, caps) def compile_graphql( *, ctx: CompileContext, source: graphql.Source, 
variables: Optional[Mapping[str, object]], ) -> dbstate.QueryUnitGroup: current_tx = ctx.state.current_tx() gql_op = graphql.compile_graphql( ctx.compiler_state.std_schema, current_tx.get_user_schema(), current_tx.get_global_schema(), current_tx.get_database_config(), current_tx.get_system_config(), source.text(), tokens=source.tokens(), substitutions=source.substitutions(), extracted_variables=source.variables(), variables=variables, native_input=True, ) eql_source = edgeql.Source.from_string( edgeql.generate_source(gql_op.edgeql_ast, pretty=True), ) qug = compile(ctx=ctx, source=eql_source) if gql_op.cache_deps_vars: qug.graphql_key_variables = sorted(gql_op.cache_deps_vars) # No warnings in graphql, yet for qu in qug: qu.warnings = () qug.warnings = None return qug def compile( *, ctx: CompileContext, source: edgeql.Source, ) -> dbstate.QueryUnitGroup: current_tx = ctx.state.current_tx() if current_tx.get_migration_state() is not None: original = edgeql.Source.from_string(source.text()) ctx = dataclasses.replace( ctx, source=original, implicit_limit=0, ) return _try_compile(ctx=ctx, source=original) try: return _try_compile(ctx=ctx, source=source) except errors.EdgeQLSyntaxError as original_err: if isinstance(source, edgeql.NormalizedSource): # try non-normalized source try: original = edgeql.Source.from_string(source.text()) ctx = dataclasses.replace(ctx, source=original) _try_compile(ctx=ctx, source=original) except errors.EdgeQLSyntaxError as denormalized_err: raise denormalized_err except Exception: raise original_err else: raise AssertionError( "Normalized query is broken while original is valid") else: raise original_err def compile_sql_as_unit_group( *, ctx: CompileContext, source: edgeql.Source, ) -> dbstate.QueryUnitGroup: setting = _get_config_val(ctx, 'allow_user_specified_id') allow_user_specified_id = None if setting: allow_user_specified_id = sql.is_setting_truthy(setting) # Note that unlike SQL over PostgreSQL protocol we use # the general access 
policy toggle, not the SQL-specific one. apply_access_policies = None setting = _get_config_val(ctx, 'apply_access_policies') if setting is not None: apply_access_policies = sql.is_setting_truthy(setting) tx_state = ctx.state.current_tx() schema = tx_state.get_schema(ctx.compiler_state.std_schema) settings = dbstate.DEFAULT_SQL_FE_SETTINGS sql_tx_state = dbstate.SQLTransactionState( in_tx=not tx_state.is_implicit(), settings=settings, in_tx_settings=settings, in_tx_local_settings=settings, savepoints=[ (not_none(tx.name), settings, settings) for tx in tx_state._savepoints.values() ], ) sql_units, force_non_normalized = sql.compile_sql( source, schema=schema, tx_state=sql_tx_state, prepared_stmt_map={}, current_database=ctx.branch_name or "", allow_user_specified_id=allow_user_specified_id, apply_access_policies=apply_access_policies, include_edgeql_io_format_alternative=True, allow_prepared_statements=False, disambiguate_column_names=True, backend_runtime_params=ctx.backend_runtime_params, protocol_version=ctx.protocol_version, implicit_limit=ctx.implicit_limit, ) qug = dbstate.QueryUnitGroup( cardinality=sql_units[-1].cardinality, cacheable=True, force_non_normalized=force_non_normalized, ) for sql_unit in sql_units: if sql_unit.eql_format_query is not None: value_sql = sql_unit.eql_format_query.encode("utf-8") intro_sql = sql_unit.query.encode("utf-8") else: value_sql = sql_unit.query.encode("utf-8") intro_sql = None if isinstance(sql_unit.command_complete_tag, dbstate.TagPlain): status = sql_unit.command_complete_tag.tag elif isinstance( sql_unit.command_complete_tag, (dbstate.TagCountMessages, dbstate.TagUnpackRow), ): status = sql_unit.command_complete_tag.prefix.encode("utf-8") elif sql_unit.command_complete_tag is None: status = b"SELECT" # XXX else: raise AssertionError( f"unexpected SQLQueryUnit.command_complete_tag type: " f"{sql_unit.command_complete_tag}" ) globals = [] permissions = [] for sp in sql_unit.params or (): if not isinstance(sp, 
dbstate.SQLParamGlobal): continue if not sp.is_permission: globals.append((str(sp.global_name), False)) else: permissions.append(str(sp.global_name)) unit = dbstate.QueryUnit( sql=value_sql, introspection_sql=intro_sql, status=status, cardinality=( enums.Cardinality.NO_RESULT if ctx.output_format is enums.OutputFormat.NONE else sql_unit.cardinality ), capabilities=sql_unit.capabilities, globals=globals, permissions=permissions, output_format=( enums.OutputFormat.NONE if ( ctx.output_format is enums.OutputFormat.NONE or sql_unit.cardinality is enums.Cardinality.NO_RESULT ) else enums.OutputFormat.BINARY ), source_map=sql_unit.source_map, sql_prefix_len=sql_unit.prefix_len, ) match sql_unit.tx_action: case dbstate.TxAction.START: ctx.state.start_tx() tx_state = ctx.state.current_tx() unit.tx_id = tx_state.id case dbstate.TxAction.COMMIT: ctx.state.commit_tx() unit.tx_commit = True case dbstate.TxAction.ROLLBACK: ctx.state.rollback_tx() unit.tx_rollback = True case dbstate.TxAction.DECLARE_SAVEPOINT: assert sql_unit.sp_name is not None unit.tx_savepoint_declare = True unit.sp_id = tx_state.declare_savepoint(sql_unit.sp_name) unit.sp_name = sql_unit.sp_name case dbstate.TxAction.ROLLBACK_TO_SAVEPOINT: assert sql_unit.sp_name is not None tx_state.rollback_to_savepoint(sql_unit.sp_name) unit.tx_savepoint_rollback = True unit.sp_name = sql_unit.sp_name case dbstate.TxAction.RELEASE_SAVEPOINT: assert sql_unit.sp_name is not None tx_state.release_savepoint(sql_unit.sp_name) unit.sp_name = sql_unit.sp_name case None: unit.cacheable = sql_unit.cacheable case _: raise AssertionError( f"unexpected SQLQueryUnit.tx_action: {sql_unit.tx_action}" ) qug.append(unit) return qug def _try_compile( *, ctx: CompileContext, source: edgeql.Source, ) -> dbstate.QueryUnitGroup: if ctx.is_testmode(): # This is a bad but simple way to emulate a slow compilation for tests. 
# Ideally, we should have a testmode function that is hooked to sleep # as `simple_special_case`, or wait for a notification from the test. sentinel = "# EDGEDB_TEST_COMPILER_SLEEP = " text = source.text() if text.startswith(sentinel): time.sleep(float(text[len(sentinel):text.index("\n")])) statements = edgeql.parse_block(source) return _try_compile_ast(statements=statements, source=source, ctx=ctx) def _try_compile_ast( *, ctx: CompileContext, statements: Sequence[qlast.Base], source: edgeql.Source, ) -> dbstate.QueryUnitGroup: if ctx.is_testmode(): # This is a bad but simple way to emulate a slow compilation for tests. # Ideally, we should have a testmode function that is hooked to sleep # as `simple_special_case`, or wait for a notification from the test. sentinel = "# EDGEDB_TEST_COMPILER_SLEEP = " text = source.text() if text.startswith(sentinel): time.sleep(float(text[len(sentinel):text.index("\n")])) statements_len = len(statements) if not len(statements): # pragma: no cover raise errors.ProtocolError('nothing to compile') rv = dbstate.QueryUnitGroup() is_script = statements_len > 1 script_info = None if is_script: if ctx.expect_rollback: # We are in a failed transaction expecting a rollback, while a # script cannot be a rollback raise errors.TransactionError( 'expected a ROLLBACK or ROLLBACK TO SAVEPOINT command' ) script_info = qlcompiler.preprocess_script( statements, schema=ctx.state.current_tx().get_schema( ctx.compiler_state.std_schema), options=_get_compile_options(ctx) ) non_trailing_ctx = dataclasses.replace( ctx, output_format=enums.OutputFormat.NONE) final_user_schema: Optional[s_schema.Schema] = None for i, stmt in enumerate(statements): is_trailing_stmt = i == statements_len - 1 stmt_ctx = ctx if is_trailing_stmt else non_trailing_ctx _check_force_database_error(stmt_ctx, stmt) comp, capabilities = _compile_dispatch_ql( stmt_ctx, stmt, source=source if not is_script else None, script_info=script_info, in_script=is_script, ) unit, user_schema = 
_make_query_unit( ctx=ctx, stmt_ctx=stmt_ctx, stmt=stmt, is_script=is_script, is_trailing_stmt=is_trailing_stmt, comp=comp, capabilities=capabilities, ) rv.append(unit) if user_schema is not None: final_user_schema = user_schema if script_info: if ctx.state.current_tx().is_implicit(): if ctx.state.current_tx().get_migration_state(): raise errors.QueryError( "Cannot leave an incomplete migration in scripts" ) if ctx.state.current_tx().get_migration_rewrite_state(): raise errors.QueryError( "Cannot leave an incomplete migration rewrite " "in scripts" ) params, in_type_args = _extract_params( list(script_info.params.values()), argmap=None, script_info=None, schema=script_info.schema, ctx=ctx) in_type_data, in_type_id = sertypes.describe_params( schema=script_info.schema, params=params, protocol_version=ctx.protocol_version, ) rv.in_type_id = in_type_id.bytes rv.in_type_args = in_type_args rv.in_type_data = in_type_data if final_user_schema is not None: rv.state_serializer = ctx.compiler_state.state_serializer_factory.make( final_user_schema, ctx.state.current_tx().get_global_schema(), ctx.protocol_version, ) # Sanity checks for unit in rv: # pragma: no cover na_cardinality = ( unit.cardinality is enums.Cardinality.NO_RESULT ) if unit.cacheable and ( unit.config_ops or unit.modaliases or unit.user_schema or unit.cached_reflection ): raise errors.InternalServerError( f'QueryUnit {unit!r} is cacheable but has config/aliases') if not na_cardinality and ( unit.tx_commit or unit.tx_rollback or unit.tx_savepoint_rollback or unit.out_type_id is sertypes.NULL_TYPE_ID or unit.system_config or unit.config_ops or unit.modaliases or unit.has_set or unit.has_ddl or not unit.sql_hash): raise errors.InternalServerError( f'unit has invalid "cardinality": {unit!r}') multi_card = rv.cardinality in ( enums.Cardinality.MANY, enums.Cardinality.AT_LEAST_ONE, ) if multi_card and ctx.expected_cardinality_one: raise errors.ResultCardinalityMismatchError( f'the query has cardinality 
{unit.cardinality.name} ' f'which does not match the expected cardinality ONE') return rv def _make_query_unit( *, ctx: CompileContext, stmt_ctx: CompileContext, stmt: qlast.Base, is_script: bool, is_trailing_stmt: bool, comp: dbstate.BaseQuery, capabilities: enums.Capability, ) -> tuple[dbstate.QueryUnit, Optional[s_schema.Schema]]: # Initialize user_schema_version with the version this query is # going to be compiled upon. This can be overwritten later by DDLs. try: schema_version = _get_schema_version( stmt_ctx.state.current_tx().get_user_schema() ) except errors.InvalidReferenceError: schema_version = None unit = dbstate.QueryUnit( sql=b"", status=status.get_status(stmt), cardinality=enums.Cardinality.NO_RESULT, capabilities=capabilities, output_format=stmt_ctx.output_format, cache_key=ctx.cache_key, user_schema_version=schema_version, warnings=comp.warnings, unsafe_isolation_dangers=comp.unsafe_isolation_dangers, ) if not comp.is_transactional: if is_script: raise errors.QueryError( f'cannot execute {status.get_status(stmt).decode()} ' f'with other commands in one block', span=stmt.span, ) if not ctx.state.current_tx().is_implicit(): raise errors.QueryError( f'cannot execute {status.get_status(stmt).decode()} ' f'in a transaction', span=stmt.span, ) unit.is_transactional = False final_user_schema: Optional[s_schema.Schema] = None if isinstance(comp, dbstate.Query): unit.sql = comp.sql unit.cache_sql = comp.cache_sql unit.cache_func_call = comp.cache_func_call unit.globals = comp.globals unit.permissions = comp.permissions unit.json_permissions = comp.json_permissions unit.required_permissions = comp.required_permissions unit.in_type_args = comp.in_type_args unit.sql_hash = comp.sql_hash unit.out_type_data = comp.out_type_data unit.out_type_id = comp.out_type_id unit.in_type_data = comp.in_type_data unit.in_type_id = comp.in_type_id unit.server_param_conversions = comp.server_param_conversions unit.cacheable = comp.cacheable if comp.is_explain: unit.is_explain 
= True unit.query_asts = comp.query_asts if comp.run_and_rollback: unit.run_and_rollback = True if is_trailing_stmt: unit.cardinality = comp.cardinality elif isinstance(comp, dbstate.SimpleQuery): unit.sql = comp.sql unit.in_type_args = comp.in_type_args elif isinstance(comp, dbstate.DDLQuery): unit.sql = comp.sql unit.db_op_trailer = comp.db_op_trailer unit.create_db = comp.create_db unit.drop_db = comp.drop_db unit.drop_db_reset_connections = comp.drop_db_reset_connections unit.create_db_template = comp.create_db_template unit.create_db_mode = comp.create_db_mode unit.ddl_stmt_id = comp.ddl_stmt_id unit.early_non_tx_sql = comp.early_non_tx_sql if not ctx.dump_restore_mode: if comp.user_schema is not None: final_user_schema = comp.user_schema unit.user_schema = pickle.dumps(comp.user_schema, -1) unit.user_schema_version = ( _get_schema_version(comp.user_schema) ) unit.extensions, unit.ext_config_settings = ( _extract_extensions(ctx, comp.user_schema) ) unit.feature_used_metrics = comp.feature_used_metrics if comp.cached_reflection is not None: unit.cached_reflection = \ pickle.dumps(comp.cached_reflection, -1) if comp.global_schema is not None: unit.global_schema = pickle.dumps(comp.global_schema, -1) unit.roles = _extract_roles(comp.global_schema) unit.config_ops.extend(comp.config_ops) elif isinstance(comp, dbstate.TxControlQuery): if is_script: raise errors.QueryError( "Explicit transaction control commands cannot be executed " "in an implicit transaction block" ) unit.sql = comp.sql unit.cacheable = comp.cacheable unit.tx_isolation_level = comp.isolation_level if not ctx.dump_restore_mode: if comp.user_schema is not None: final_user_schema = comp.user_schema unit.user_schema = pickle.dumps(comp.user_schema, -1) unit.user_schema_version = ( _get_schema_version(comp.user_schema) ) unit.extensions, unit.ext_config_settings = ( _extract_extensions(ctx, comp.user_schema) ) unit.feature_used_metrics = comp.feature_used_metrics if comp.cached_reflection is not None: 
unit.cached_reflection = \ pickle.dumps(comp.cached_reflection, -1) if comp.global_schema is not None: unit.global_schema = pickle.dumps(comp.global_schema, -1) unit.roles = _extract_roles(comp.global_schema) if comp.modaliases is not None: unit.modaliases = comp.modaliases if comp.action == dbstate.TxAction.START: if unit.tx_id is not None: raise errors.InternalServerError( 'already in transaction') unit.tx_id = ctx.state.current_tx().id elif comp.action == dbstate.TxAction.COMMIT: unit.tx_commit = True elif comp.action == dbstate.TxAction.ROLLBACK: unit.tx_rollback = True elif comp.action is dbstate.TxAction.ROLLBACK_TO_SAVEPOINT: unit.tx_savepoint_rollback = True unit.sp_name = comp.sp_name elif comp.action is dbstate.TxAction.DECLARE_SAVEPOINT: unit.tx_savepoint_declare = True unit.sp_name = comp.sp_name unit.sp_id = comp.sp_id elif isinstance(comp, dbstate.MigrationControlQuery): unit.sql = comp.sql unit.cacheable = comp.cacheable if not ctx.dump_restore_mode: if comp.user_schema is not None: final_user_schema = comp.user_schema unit.user_schema = pickle.dumps(comp.user_schema, -1) unit.user_schema_version = ( _get_schema_version(comp.user_schema) ) unit.extensions, unit.ext_config_settings = ( _extract_extensions(ctx, comp.user_schema) ) if comp.cached_reflection is not None: unit.cached_reflection = \ pickle.dumps(comp.cached_reflection, -1) unit.ddl_stmt_id = comp.ddl_stmt_id if comp.modaliases is not None: unit.modaliases = comp.modaliases if comp.tx_action == dbstate.TxAction.START: if unit.tx_id is not None: raise errors.InternalServerError( 'already in transaction') unit.tx_id = ctx.state.current_tx().id elif comp.tx_action == dbstate.TxAction.COMMIT: unit.tx_commit = True unit.append_tx_op = True elif comp.tx_action == dbstate.TxAction.ROLLBACK: unit.tx_rollback = True unit.append_tx_op = True elif comp.action == dbstate.MigrationAction.ABORT: unit.tx_abort_migration = True elif isinstance(comp, dbstate.SessionStateQuery): unit.sql = comp.sql 
unit.globals = comp.globals if comp.config_scope is qltypes.ConfigScope.INSTANCE: if not ctx.state.current_tx().is_implicit() or is_script: raise errors.QueryError( 'CONFIGURE INSTANCE cannot be executed in a ' 'transaction block') unit.system_config = True elif comp.config_scope is qltypes.ConfigScope.GLOBAL: unit.needs_readback = True elif comp.config_scope is qltypes.ConfigScope.DATABASE: unit.database_config = True unit.needs_readback = True if comp.is_backend_setting: unit.backend_config = True if comp.requires_restart: unit.config_requires_restart = True if comp.is_system_config: unit.is_system_config = True unit.modaliases = ctx.state.current_tx().get_modaliases() if comp.config_op is not None: unit.config_ops.append(comp.config_op) if comp.in_type_args: unit.in_type_args = comp.in_type_args if comp.in_type_data: unit.in_type_data = comp.in_type_data if comp.in_type_id: unit.in_type_id = comp.in_type_id unit.has_set = True unit.output_format = enums.OutputFormat.NONE elif isinstance(comp, dbstate.MaintenanceQuery): unit.sql = comp.sql elif isinstance(comp, dbstate.NullQuery): pass else: # pragma: no cover raise errors.InternalServerError('unknown compile state') if unit.in_type_args: unit.in_type_args_real_count = sum( len(p.sub_params[0]) if p.sub_params else 1 for p in unit.in_type_args ) if unit.warnings: for warning in unit.warnings: warning.__traceback__ = None if unit.unsafe_isolation_dangers: for warning in unit.unsafe_isolation_dangers: warning.__traceback__ = None return unit, final_user_schema def _extract_params( params: list[irast.Param], *, schema: s_schema.Schema, argmap: Optional[dict[str, pgast.Param]], script_info: Optional[irast.ScriptInfo], ctx: CompileContext, ) -> tuple[list[tuple[str, s_types.Type, bool]], list[dbstate.Param]]: first_param = next(iter(params)) if params else None has_named_params = first_param and not first_param.name.isdecimal() if (src := ctx.source) is not None: first_extra = src.first_extra() else: first_extra = 
None all_params = script_info.params.values() if script_info else params total_params = len([p for p in all_params if not p.is_sub_param]) user_params = first_extra if first_extra is not None else total_params if script_info is not None: outer_mapping = {n: i for i, n in enumerate(script_info.params)} # Count however many of *our* arguments are user_params user_params = sum( outer_mapping[n.name] < user_params for n in params if not n.is_sub_param) else: outer_mapping = None oparams: list[Optional[tuple[str, s_obj.Object, bool]]] = ( [None] * user_params) in_type_args: list[Optional[dbstate.Param]] = [None] * user_params for idx, param in enumerate(params): if param.is_sub_param: continue if argmap is not None: sql_param = argmap[param.name] idx = sql_param.logical_index - 1 if idx >= user_params: continue if ctx.json_parameters: schema_type = schema.get('std::json') else: schema_type = param.schema_type array_tid = None if isinstance(schema_type, s_types.Array): el_type = schema_type.get_element_type(schema) array_tid = el_type.id # NB: We'll need to turn this off for script args if ( not script_info and not has_named_params and str(idx) != param.name ): raise RuntimeError( 'positional argument name disagrees ' 'with its actual position') oparams[idx] = ( param.name, schema_type, param.required, ) if param.sub_params: array_tids: list[Optional[uuid.UUID]] = [] for p in param.sub_params.params: if isinstance(p.schema_type, s_types.Array): el_type = p.schema_type.get_element_type(schema) array_tids.append(el_type.id) else: array_tids.append(None) sub_params = ( array_tids, param.sub_params.trans_type.flatten()) else: sub_params = None in_type_args[idx] = dbstate.Param( name=param.name, required=param.required, array_type_id=array_tid, outer_idx=outer_mapping[param.name] if outer_mapping else None, sub_params=sub_params, typename=str(schema_type.get_name(schema)), ) return oparams, in_type_args # type: ignore[return-value] def get_obj_ids( schema: s_schema.Schema, *, 
include_extras: bool=False, ) -> tuple[list[tuple[str, str, uuid.UUID]], list[uuid.UUID]]: all_objects: Iterable[s_obj.Object] = schema.get_objects( exclude_stdlib=True, exclude_global=True, ) ids = [] sequences = [] for obj in all_objects: if isinstance(obj, s_obj.QualifiedObject): ql_class = '' else: ql_class = str(type(obj).get_ql_class_or_die()) name = str(obj.get_name(schema)) ids.append(( name, ql_class, obj.id, )) if isinstance(obj, s_types.Type) and obj.is_sequence(schema): sequences.append(obj.id) if include_extras and isinstance(obj, s_func.Function): backend_name = obj.get_backend_name(schema) if backend_name: ids.append(( name, f'{ql_class or None}-backend_name', backend_name, )) return ids, sequences def _describe_object( schema: s_schema.Schema, source: s_obj.Object, protocol_version: defines.ProtocolVersion, ) -> list[DumpBlockDescriptor]: cols = [] shape = [] ptrdesc: list[DumpBlockDescriptor] = [] if isinstance(source, s_props.Property): schema, prop_tuple = s_types.Tuple.from_subtypes( schema, { 'source': schema.get('std::uuid', type=s_types.Type), 'target': not_none(source.get_target(schema)), }, {'named': True}, ) type_data, type_id = sertypes.describe( schema, prop_tuple, follow_links=False, protocol_version=protocol_version, ) cols.extend([ 'source', 'target', ]) elif isinstance(source, s_links.Link): props = {} for ptr in source.get_pointers(schema).objects(schema): if not ptr.is_dumpable(schema): continue stor_info = pg_types.get_pointer_storage_info( ptr, schema=schema, source=source, link_bias=True, ) cols.append(stor_info.column_name) props[ptr.get_shortname(schema).name] = not_none( ptr.get_target(schema)) schema, link_tuple = s_types.Tuple.from_subtypes( schema, props, {'named': True}, ) type_data, type_id = sertypes.describe( schema, link_tuple, follow_links=False, protocol_version=protocol_version, ) else: assert isinstance(source, s_objtypes.ObjectType) for ptr in source.get_pointers(schema).objects(schema): if not 
ptr.is_dumpable(schema): continue stor_info = pg_types.get_pointer_storage_info( ptr, schema=schema, source=source, ) if stor_info.table_type == 'ObjectType': cols.append(stor_info.column_name) shape.append(ptr) link_stor_info = pg_types.get_pointer_storage_info( ptr, schema=schema, source=source, link_bias=True, ) if link_stor_info is not None: ptrdesc.extend(_describe_object(schema, ptr, protocol_version)) # For any addon columns (currently fts and ai shadow index # columns), generate a fake pointer to put in the descriptor # and include them in the dump. nschema = schema for (name, col, _type) in source.get_addon_columns(schema): nschema, fake_ptr = _add_fake_property(source, name, nschema) cols.append(col) shape.append(fake_ptr) type_data, type_id = sertypes.describe( nschema, source, view_shapes={source: shape}, follow_links=False, protocol_version=protocol_version, ) table_name = pg_common.get_backend_name( schema, source, catenate=True ) stmt = ( f'COPY {table_name} ' f'({", ".join(pg_common.quote_ident(c) for c in cols)}) ' f'TO STDOUT WITH BINARY' ).encode() return [DumpBlockDescriptor( schema_object_id=source.id, schema_object_class=type(source).get_ql_class_or_die(), schema_deps=tuple(p.schema_object_id for p in ptrdesc), type_desc_id=type_id, type_desc=type_data, sql_copy_stmt=stmt, )] + ptrdesc def _check_dump_layout( dump_els: AbstractSet[str], schema_els: AbstractSet[str], elided_els: AbstractSet[str], label: str, ) -> None: extra_els = dump_els - (schema_els | elided_els) if extra_els: raise RuntimeError( f'dump data tuple of {label} has extraneous elements: ' f'{", ".join(extra_els)}' ) missing_els = schema_els - dump_els if missing_els: raise RuntimeError( f'dump data tuple of {label} has missing elements: ' f'{", ".join(missing_els)}' ) def _get_ptr_mending_desc( schema: s_schema.Schema, ptr: s_pointers.Pointer, ) -> Optional[DataMendingDescriptor]: ptr_type = ptr.get_target(schema) if isinstance(ptr_type, (s_types.Array, s_types.Tuple)): return 
_get_data_mending_desc(schema, ptr_type) else: return None def _get_data_mending_desc( schema: s_schema.Schema, typ: s_types.Type, ) -> Optional[DataMendingDescriptor]: if isinstance(typ, (s_types.Tuple, s_types.Array)): elements = tuple( _get_data_mending_desc(schema, element) for element in typ.get_subtypes(schema) ) else: elements = tuple() if pg_types.type_has_stable_oid(typ): return None else: return DataMendingDescriptor( schema_type_id=typ.id, schema_object_class=type(typ).get_ql_class_or_die(), elements=elements, needs_mending=bool( isinstance(typ, (s_types.Tuple, s_types.Array)) and any(elements) ) ) def _add_fake_property( source: s_objtypes.ObjectType, name: str, schema: s_schema.Schema, ) -> tuple[s_schema.Schema, s_props.Property]: base = schema.get( s_name.QualName('std', 'property'), type=s_props.Property, ) derived_name = s_obj.derive_name( schema, str(source.get_name(schema)), module='__derived__', derived_name_base=s_name.UnqualName(name), parent=base, ) return base.derive_ref( schema, source, name=derived_name, target=schema.get('std::bytes', type=s_types.Type), ) def maybe_force_database_error( val: Optional[str], *, scope: str, ) -> None: # Check the string directly for 'false' to skip deserialization if val is None or val == 'false': return try: err = json.loads(val) if not err: return scopes = err.get('_scopes', ['query']) if scope not in scopes: return versions = err.get('_versions') if versions and buildmeta.get_version_string() not in versions: return errcls = errors.EdgeDBError.get_error_class_from_name(err['type']) if context := err.get('context'): filename = context.get('filename') position = tuple( context.get(k) for k in ('line', 'col', 'start', 'end') ) else: filename = None position = None errval = errcls( msg=err.get('message'), hint=err.get('hint'), details=err.get('details'), filename=filename, position=position, ) except Exception: raise errors.ConfigurationError( "invalid 'force_database_error' value") raise errval def
_check_force_database_error( ctx: CompileContext, ql: Optional[qlast.Base]=None, *, scope: str='query', ) -> None: if isinstance(ql, qlast.ConfigOp): return val = _get_config_val(ctx, 'force_database_error') if isinstance(ql, qlast.DDLCommand): maybe_force_database_error(val, scope='ddl') maybe_force_database_error(val, scope=scope) def _get_config_val( ctx: CompileContext, name: str, ) -> Any: current_tx = ctx.state.current_tx() return config.lookup( name, current_tx.get_session_config(), current_tx.get_database_config(), current_tx.get_system_config(), spec=ctx.compiler_state.config_spec, allow_unrecognized=True, ) def _get_compilation_config_vals(ctx: CompileContext) -> Any: assert ctx.compiler_state.config_spec is not None return { k: _get_config_val(ctx, k) for k in ctx.compiler_state.config_spec if ctx.compiler_state.config_spec[k].affects_compilation } _OUTPUT_FORMAT_MAP = { enums.OutputFormat.BINARY: pg_compiler.OutputFormat.NATIVE, enums.OutputFormat.JSON: pg_compiler.OutputFormat.JSON, enums.OutputFormat.JSON_ELEMENTS: pg_compiler.OutputFormat.JSON_ELEMENTS, enums.OutputFormat.NONE: pg_compiler.OutputFormat.NONE, } def _convert_format(inp: enums.OutputFormat) -> pg_compiler.OutputFormat: try: return _OUTPUT_FORMAT_MAP[inp] except KeyError: raise RuntimeError(f"Output format {inp!r} is not supported") def _hash_sql(sql: bytes, **kwargs: bytes) -> bytes: h = hashlib.sha1(sql) for param, val in kwargs.items(): h.update(param.encode('latin1')) h.update(val) return h.hexdigest().encode('latin1') def _extract_extensions( ctx: CompileContext, user_schema: s_schema.Schema ) -> tuple[set[str], list[config.Setting]]: # XXX: Do we need to return None if extensions/config_spec didn't change? 
names = { ext.get_name(user_schema).name for ext in user_schema.get_objects(type=s_ext.Extension) } if names: schema = s_schema.ChainedSchema( ctx.compiler_state.std_schema, user_schema, s_schema.EMPTY_SCHEMA ) settings = config.load_ext_settings_from_schema(schema) else: settings = [] return names, settings def _extract_roles( global_schema: s_schema.Schema, ) -> immutables.Map[str, immutables.Map[str, Any]]: extracted_roles = {} schema_roles = global_schema.get_objects(type=s_role.Role) for role in schema_roles: role_name = str(role.get_name(global_schema)) extracted_roles[role_name] = dict( name=role_name, superuser=role.get_superuser(global_schema), password=role.get_password(global_schema), branches=list(sorted(role.get_branches(global_schema))), apply_access_policies_pg_default=( role.get_apply_access_policies_pg_default(global_schema) ), ) # To populate all_permissions, combine the permissions of each role # and its ancestors. role_memberships: MutableMapping[s_role.Role, list[s_role.Role]] = {} role_permissions: MutableMapping[s_role.Role, Sequence[str]] = {} for role in schema_roles: role_memberships[role] = list( role.get_ancestors(global_schema).objects(global_schema) ) role_permissions[role] = list(sorted( role.get_permissions(global_schema) or () )) for role in schema_roles: role_name = str(role.get_name(global_schema)) extracted_roles[role_name]['all_permissions'] = tuple(set( p for m in [role] + role_memberships.get(role, []) for p in role_permissions[m] )) # Convert everything into immutable maps return immutables.Map({ name: immutables.Map(role) for name, role in extracted_roles.items() }) class DumpDescriptor(NamedTuple): schema_ddl: str schema_dynamic_ddl: tuple[str, ...] schema_ids: list[tuple[str, str, bytes]] blocks: Sequence[DumpBlockDescriptor] class DumpBlockDescriptor(NamedTuple): schema_object_id: uuid.UUID schema_object_class: qltypes.SchemaObjectClass schema_deps: tuple[uuid.UUID, ...] 
type_desc_id: uuid.UUID type_desc: bytes sql_copy_stmt: bytes class RestoreDescriptor(NamedTuple): units: Sequence[dbstate.QueryUnit] blocks: Sequence[RestoreBlockDescriptor] tables: Sequence[str] repopulate_units: Sequence[str] class DataMendingDescriptor(NamedTuple): #: The identifier of the EdgeDB type schema_type_id: uuid.UUID #: The kind of a type we are dealing with schema_object_class: qltypes.SchemaObjectClass #: If type is a collection, mending descriptors of element types elements: tuple[Optional[DataMendingDescriptor], ...] = tuple() #: Whether a datum represented by this descriptor will need mending needs_mending: bool = False class RestoreBlockDescriptor(NamedTuple): #: The identifier of the schema object this data is for. schema_object_id: uuid.UUID #: The COPY SQL statement for this block. sql_copy_stmt: bytes #: For compatibility with old dumps, a list of column indexes #: that should be ignored in the COPY stream. compat_elided_cols: tuple[int, ...] #: If the tuple requires mending of unstable Postgres OIDs in data, #: this will contain the recursive descriptor on which parts of #: each datum need mending. data_mending_desc: tuple[Optional[DataMendingDescriptor], ...] ================================================ FILE: edb/server/compiler/config.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2024-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Iterable, Mapping, Sequence import dataclasses import datetime import functools import immutables from edb import errors from edb.common import typeutils from edb.edgeql import ast as qlast from edb.edgeql import parser as qlparser from edb.edgeql import qltypes from edb.edgeql import compiler as qlcompiler from edb.ir import ast as irast from edb.ir import staeval as ireval from edb.server import config from edb.schema import name as sn from edb.schema import objtypes as s_objtypes from edb.schema import pointers as s_pointers from edb.schema import schema as s_schema from edb.schema import types as s_types from edb.schema import utils as s_utils ConfigInput = ( str | int | float | bool | datetime.datetime | datetime.date | datetime.time | Sequence["ConfigInput"] | Mapping[str, "ConfigInput"] | None ) ConfigObject = Mapping[str, ConfigInput] @dataclasses.dataclass class Context: schema: s_schema.Schema obj_type: s_objtypes.ObjectType qual_name: str options: qlcompiler.CompilerOptions def get_ptr(self, name: str) -> s_pointers.Pointer: un = sn.UnqualName(name) schema = self.schema ty = self.obj_type ancestors = ty.get_ancestors(schema).objects(schema) for t in (ty,) + ancestors: if (rv := t.maybe_get_ptr(schema, un)) is not None: return rv raise errors.ConfigurationError( f"{ty.get_shortname(schema)!s} does not have field: {name!r}" ) def get_full_name(self, ptr: s_pointers.Pointer) -> str: return f"{self.qual_name}::{ptr.get_local_name(self.schema)}" def is_multi(self, ptr: s_pointers.Pointer) -> bool: return ptr.get_cardinality(self.schema).is_multi() def get_type[TypeT: s_types.Type]( self, ptr: s_pointers.Pointer, *, type: type[TypeT] ) -> TypeT: rv = ptr.get_target(self.schema) if not isinstance(rv, type): raise TypeError(f"{ptr!r}.target is not {type!r}") return rv def get_ref(self, ptr:
s_pointers.Pointer) -> qlast.ObjectRef: ty = self.get_type(ptr, type=s_types.QualifiedType) ty_name = ty.get_shortname(self.schema) return qlast.ObjectRef(name=ty_name.name, module=ty_name.module) def cast( self, expr: qlast.Expr, *, ptr: s_pointers.Pointer ) -> qlast.TypeCast: return qlast.TypeCast( expr=expr, type=qlast.TypeName(maintype=self.get_ref(ptr)), ) @functools.singledispatch def compile_input_to_ast( value: ConfigInput, *, ptr: s_pointers.Pointer, ctx: Context ) -> qlast.Expr: raise errors.ConfigurationError( f"unsupported input type {type(value)!r} for {ctx.get_full_name(ptr)}" ) @compile_input_to_ast.register def compile_input_str( value: str, *, ptr: s_pointers.Pointer, ctx: Context ) -> qlast.Expr: if value.startswith("{{") and value.endswith("}}"): return qlparser.parse_fragment(value[2:-2]) ty = ctx.get_type(ptr, type=s_types.QualifiedType) if ty.is_enum(ctx.schema): ty_name = ty.get_shortname(ctx.schema) return qlast.Path( steps=[ qlast.ObjectRef(name=ty_name.name, module=ty_name.module), qlast.Ptr(name=value), ] ) else: return ctx.cast(qlast.Constant.string(value), ptr=ptr) @compile_input_to_ast.register def compile_input_scalar( value: int | float | bool, *, ptr: s_pointers.Pointer, ctx: Context ) -> qlast.Expr: return ctx.cast(s_utils.const_ast_from_python(value), ptr=ptr) @compile_input_to_ast.register(dict) @compile_input_to_ast.register(immutables.Map) def compile_input_mapping( value: Mapping[str, ConfigInput], *, ptr: s_pointers.Pointer, ctx: Context, ) -> qlast.Expr: if "_tname" in value: tname = value["_tname"] if not isinstance(tname, str): raise errors.ConfigurationError( f"type of `_tname` must be str, got: {type(tname)!r}" ) obj_type = ctx.schema.get(tname, type=s_objtypes.ObjectType) else: try: obj_type = ctx.get_type(ptr, type=s_objtypes.ObjectType) except TypeError: raise errors.ConfigurationError( f"unsupported input type {type(value)!r} " f"for {ctx.get_full_name(ptr)}" ) obj_name = obj_type.get_shortname(ctx.schema) new_ctx = 
Context( schema=ctx.schema, obj_type=obj_type, qual_name=ctx.get_full_name(ptr), options=ctx.options, ) return qlast.InsertQuery( subject=qlast.ObjectRef(name=obj_name.name, module=obj_name.module), shape=list(compile_dict_to_shape(value, ctx=new_ctx).values()), ) def compile_dict_to_shape( values: Mapping[str, ConfigInput], *, ctx: Context ) -> dict[str, qlast.ShapeElement]: rv = {} for name, value in values.items(): if name == "_tname": continue ptr = ctx.get_ptr(name) expr: qlast.Expr if ctx.is_multi(ptr) and not isinstance(value, str): if not typeutils.is_container(value) or isinstance(value, Mapping): raise errors.ConfigurationError( f"{ctx.get_full_name(ptr)} must be a sequence, " f"got type: {type(value)!r}" ) assert isinstance(value, Iterable) expr = qlast.Set( elements=[ compile_input_to_ast(v, ptr=ptr, ctx=ctx) for v in value ] ) else: expr = compile_input_to_ast(value, ptr=ptr, ctx=ctx) rv[name] = qlast.ShapeElement( expr=qlast.Path(steps=[qlast.Ptr(name=name)]), compexpr=expr ) return rv def compile_ast_to_operation( obj_name: str, field_name: str, expr: qlast.Expr, *, schema: s_schema.Schema, options: qlcompiler.CompilerOptions, allow_nested: bool = True, ) -> config.Operation: cmd: qlast.ConfigOp if isinstance(expr, qlast.InsertQuery): if not allow_nested: raise errors.ConfigurationError( "nested config object is not allowed" ) cmd = qlast.ConfigInsert( name=expr.subject, scope=qltypes.ConfigScope.INSTANCE, shape=expr.shape, ) else: field_name_ref = qlast.ObjectRef(name=field_name) if obj_name != "cfg::Config": field_name_ref.module = obj_name cmd = qlast.ConfigSet( name=field_name_ref, scope=qltypes.ConfigScope.INSTANCE, expr=expr, ) ir = qlcompiler.compile_ast_to_ir(cmd, schema=schema, options=options) if ( isinstance(ir, irast.ConfigSet) or isinstance(ir, irast.Statement) and isinstance((ir := ir.expr.expr), irast.ConfigInsert) ): return ireval.evaluate_to_config_op(ir, schema=schema) raise errors.InternalServerError(f"unrecognized IR: 
{type(ir)!r}") def compile_structured_config( objects: Mapping[str, ConfigObject], *, spec: config.Spec, schema: s_schema.Schema, source: str | None = None, allow_nested: bool = True, ) -> dict[str, immutables.Map[str, config.SettingValue]]: options = qlcompiler.CompilerOptions( modaliases={None: "cfg"}, in_server_config_op=True, ) rv = {} for obj_name, input_values in objects.items(): storage: immutables.Map[str, config.SettingValue] = immutables.Map() ctx = Context( schema=schema, obj_type=schema.get(obj_name, type=s_objtypes.ObjectType), qual_name=obj_name, options=options, ) shape = compile_dict_to_shape(input_values, ctx=ctx) for field_name, shape_el in shape.items(): if isinstance(shape_el.compexpr, qlast.Set): elements = shape_el.compexpr.elements if not elements: continue if isinstance(elements[0], qlast.InsertQuery): for ast in shape_el.compexpr.elements: op = compile_ast_to_operation( obj_name, field_name, ast, schema=schema, options=options, allow_nested=allow_nested, ) storage = op.apply(spec, storage, source=source) continue assert shape_el.compexpr is not None op = compile_ast_to_operation( obj_name, field_name, shape_el.compexpr, schema=schema, options=options, allow_nested=allow_nested, ) storage = op.apply(spec, storage, source=source) rv[obj_name] = storage return rv ================================================ FILE: edb/server/compiler/dbstate.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Optional, Iterator, NamedTuple, Self, cast, ) import dataclasses import enum import io import pickle import time import uuid import immutables from edb import errors from edb.edgeql import ast as qlast from edb.edgeql import qltypes from edb.schema import delta as s_delta from edb.schema import migrations as s_migrations from edb.schema import objects as s_obj from edb.schema import schema as s_schema from edb.schema import name as s_name from edb.server import config from edb.server import defines from edb.pgsql import codegen as pgcodegen from . import enums from . import sertypes class TxAction(enum.IntEnum): START = 1 COMMIT = 2 ROLLBACK = 3 DECLARE_SAVEPOINT = 4 RELEASE_SAVEPOINT = 5 ROLLBACK_TO_SAVEPOINT = 6 class MigrationAction(enum.IntEnum): START = 1 POPULATE = 2 DESCRIBE = 3 ABORT = 4 COMMIT = 5 REJECT_PROPOSED = 6 @dataclasses.dataclass(frozen=True, kw_only=True) class BaseQuery: sql: bytes is_transactional: bool = True has_dml: bool = False cache_sql: Optional[tuple[bytes, bytes]] = dataclasses.field( kw_only=True, default=None ) # (persist, evict) cache_func_call: Optional[tuple[bytes, bytes]] = dataclasses.field( kw_only=True, default=None ) warnings: tuple[errors.EdgeDBError, ...] = dataclasses.field( kw_only=True, default=() ) unsafe_isolation_dangers: tuple[errors.UnsafeIsolationLevelError, ...] 
= ( dataclasses.field(kw_only=True, default=()) ) @dataclasses.dataclass(frozen=True, kw_only=True) class NullQuery(BaseQuery): sql: bytes = b"" @dataclasses.dataclass(frozen=True, kw_only=True) class ServerParamConversion: param_name: str conversion_name: str additional_info: tuple[str, ...] # If the parameter is a query parameter, track its bind_args index. script_param_index: Optional[int] = None # If the parameter is a constant value, pass it directly to the server. constant_value: Optional[Any] = None @dataclasses.dataclass(frozen=True, kw_only=True) class Query(BaseQuery): sql_hash: bytes cardinality: enums.Cardinality out_type_data: bytes out_type_id: bytes in_type_data: bytes in_type_id: bytes in_type_args: Optional[list[Param]] = None globals: Optional[list[tuple[str, bool]]] = None permissions: Optional[list[str]] = None json_permissions: Optional[list[str]] = None required_permissions: Optional[list[str]] = None server_param_conversions: Optional[list[ServerParamConversion]] = None cacheable: bool = True is_explain: bool = False query_asts: Any = None run_and_rollback: bool = False @dataclasses.dataclass(frozen=True, kw_only=True) class SimpleQuery(BaseQuery): # XXX: Temporary hack, since SimpleQuery will die in_type_args: Optional[list[Param]] = None @dataclasses.dataclass(frozen=True, kw_only=True) class SessionStateQuery(BaseQuery): sql: bytes = b"" config_scope: Optional[qltypes.ConfigScope] = None is_backend_setting: bool = False requires_restart: bool = False is_system_config: bool = False config_op: Optional[config.Operation] = None is_transactional: bool = True globals: Optional[list[tuple[str, bool]]] = None in_type_data: Optional[bytes] = None in_type_id: Optional[bytes] = None in_type_args: Optional[list[Param]] = None @dataclasses.dataclass(frozen=True, kw_only=True) class DDLQuery(BaseQuery): user_schema: Optional[s_schema.Schema] feature_used_metrics: Optional[dict[str, float]] global_schema: Optional[s_schema.Schema] = None
cached_reflection: Any = None is_transactional: bool = True create_db: Optional[str] = None drop_db: Optional[str] = None drop_db_reset_connections: bool = False create_db_template: Optional[str] = None create_db_mode: Optional[qlast.BranchType] = None db_op_trailer: tuple[bytes, ...] = () ddl_stmt_id: Optional[str] = None config_ops: list[config.Operation] = dataclasses.field(default_factory=list) early_non_tx_sql: Optional[tuple[bytes, ...]] = None @dataclasses.dataclass(frozen=True, kw_only=True) class TxControlQuery(BaseQuery): action: TxAction cacheable: bool modaliases: Optional[immutables.Map[Optional[str], str]] isolation_level: Optional[qltypes.TransactionIsolationLevel] = None user_schema: Optional[s_schema.Schema] = None global_schema: Optional[s_schema.Schema] = None cached_reflection: Any = None feature_used_metrics: Optional[dict[str, float]] = None sp_name: Optional[str] = None sp_id: Optional[int] = None @dataclasses.dataclass(frozen=True, kw_only=True) class MigrationControlQuery(BaseQuery): action: MigrationAction tx_action: Optional[TxAction] cacheable: bool modaliases: Optional[immutables.Map[Optional[str], str]] user_schema: Optional[s_schema.Schema] = None cached_reflection: Any = None ddl_stmt_id: Optional[str] = None @dataclasses.dataclass(frozen=True, kw_only=True) class MaintenanceQuery(BaseQuery): pass @dataclasses.dataclass(frozen=True) class Param: name: str required: bool array_type_id: Optional[uuid.UUID] outer_idx: Optional[int] sub_params: Optional[tuple[list[Optional[uuid.UUID]], tuple[Any, ...]]] typename: str ############################# @dataclasses.dataclass(kw_only=True) class QueryUnit: sql: bytes introspection_sql: Optional[bytes] = None # Status-line for the compiled command; returned to front-end # in a CommandComplete protocol message if the command is # executed successfully. When a QueryUnit contains multiple # EdgeQL queries, the status reflects the last query in the unit. 
status: bytes cache_key: Optional[uuid.UUID] = None cache_sql: Optional[tuple[bytes, bytes]] = None # (persist, evict) cache_func_call: Optional[tuple[bytes, bytes]] = None # (sql, hash) # Output format of this query unit output_format: enums.OutputFormat = enums.OutputFormat.NONE # Set only for units that contain queries that can be cached # as prepared statements in Postgres. sql_hash: bytes = b"" # True if all statements in *sql* can be executed inside a transaction. # If False, they will be executed separately. is_transactional: bool = True # SQL to run *before* the main command, non-transactionally early_non_tx_sql: Optional[tuple[bytes, ...]] = None # Capabilities used in this query capabilities: enums.Capability = enums.Capability(0) # True if this unit contains SET commands. has_set: bool = False # If tx_id is set, it means that the unit # starts a new transaction. tx_id: Optional[int] = None # If this is the start of the transaction, the isolation level of it. tx_isolation_level: Optional[qltypes.TransactionIsolationLevel] = None # True if this unit is a single 'COMMIT' command. # 'COMMIT' is always compiled to a separate QueryUnit. tx_commit: bool = False # True if this unit is a single 'ROLLBACK' command. # 'ROLLBACK' is always compiled to a separate QueryUnit. tx_rollback: bool = False # True if this unit is a single 'ROLLBACK TO SAVEPOINT' command. # 'ROLLBACK TO SAVEPOINT' is always compiled to a separate QueryUnit. tx_savepoint_rollback: bool = False tx_savepoint_declare: bool = False # True if this unit is an `ABORT MIGRATION` command within a transaction, # which means abort_migration and tx_rollback cannot both be True tx_abort_migration: bool = False # For SAVEPOINT commands, the name and sp_id sp_name: Optional[str] = None sp_id: Optional[int] = None # True if it is safe to cache this unit. cacheable: bool = False # If non-None, contains a name of the DB that is about to be # created/deleted. 
If it's the former, the IO process needs to # introspect the new db. If it's the latter, the server should # close all inactive unused pooled connections to it. create_db: Optional[str] = None drop_db: Optional[str] = None drop_db_reset_connections: bool = False # If non-None, contains a name of the DB that will be used as # a template database to create the database. The server should # close all inactive unused pooled connections to the template db. create_db_template: Optional[str] = None create_db_mode: Optional[str] = None # If a branch command needs extra SQL commands to be performed, # those would end up here. db_op_trailer: tuple[bytes, ...] = () # If non-None, the DDL statement will emit data packets marked # with the indicated ID. ddl_stmt_id: Optional[str] = None # Cardinality of the result set. Set to NO_RESULT if the # unit represents multiple queries compiled as one script. cardinality: enums.Cardinality = enums.Cardinality.NO_RESULT out_type_data: bytes = sertypes.NULL_TYPE_DESC out_type_id: bytes = sertypes.NULL_TYPE_ID.bytes in_type_data: bytes = sertypes.NULL_TYPE_DESC in_type_id: bytes = sertypes.NULL_TYPE_ID.bytes in_type_args: Optional[list[Param]] = None in_type_args_real_count: int = 0 globals: Optional[list[tuple[str, bool]]] = None permissions: Optional[list[str]] = None json_permissions: Optional[list[str]] = None required_permissions: Optional[list[str]] = None server_param_conversions: Optional[list[ServerParamConversion]] = None warnings: tuple[errors.EdgeDBError, ...] = () unsafe_isolation_dangers: tuple[errors.UnsafeIsolationLevelError, ...] = () # Set only when this unit contains a CONFIGURE INSTANCE command. system_config: bool = False # Set only when this unit contains a CONFIGURE DATABASE command. database_config: bool = False # Set only when this unit contains an operation that needs to have # its results read back in the middle of the script. 
# (SET GLOBAL, CONFIGURE DATABASE) needs_readback: bool = False # Whether any configuration change requires a server restart config_requires_restart: bool = False # Set only when this unit contains a CONFIGURE command which # alters a backend configuration setting. backend_config: bool = False # Set only when this unit contains a CONFIGURE command which # alters a system configuration setting. is_system_config: bool = False config_ops: list[config.Operation] = dataclasses.field(default_factory=list) modaliases: Optional[immutables.Map[Optional[str], str]] = None # If present, represents the future schema state after # the command is run. The schema is pickled. user_schema: Optional[bytes] = None # If present, represents updated metrics about feature use induced # by the new user_schema. feature_used_metrics: Optional[dict[str, float]] = None # Unlike user_schema, user_schema_version usually exists, pointing to the # latest user schema, which is self.user_schema if changed, or the user # schema this QueryUnit was compiled upon. user_schema_version: uuid.UUID | None = None cached_reflection: Optional[bytes] = None extensions: Optional[set[str]] = None ext_config_settings: Optional[list[config.Setting]] = None # If present, represents the future global schema state # after the command is run. The schema is pickled. global_schema: Optional[bytes] = None roles: immutables.Map[str, immutables.Map[str, Any]] | None = None is_explain: bool = False query_asts: Any = None run_and_rollback: bool = False append_tx_op: bool = False # Translation source map. source_map: Optional[pgcodegen.SourceMap] = None # For SQL queries, the length of the query prefix applied # after translation. 
sql_prefix_len: int = 0 @property def has_ddl(self) -> bool: return bool(self.capabilities & enums.Capability.DDL) @property def tx_control(self) -> bool: return ( bool(self.tx_id) or self.tx_rollback or self.tx_commit or self.tx_savepoint_declare or self.tx_savepoint_rollback ) def serialize(self) -> bytes: rv = io.BytesIO() rv.write(b"\x01") # 1 byte of version number pickle.dump(self, rv, -1) return rv.getvalue() @classmethod def deserialize(cls, data: bytes) -> Self: buf = memoryview(data) match buf[0]: case 0x00 | 0x01: return pickle.loads(buf[1:]) # type: ignore[no-any-return] raise ValueError(f"Bad version number: {buf[0]}") def maybe_use_func_cache(self) -> None: if self.cache_func_call is not None: sql, sql_hash = self.cache_func_call self.sql = sql self.sql_hash = sql_hash @dataclasses.dataclass class QueryUnitGroup: # All capabilities used by any query units in this group capabilities: enums.Capability = enums.Capability(0) # True if it is safe to cache this unit. cacheable: bool = True # True if any query unit has transaction control commands, like COMMIT, # ROLLBACK, START TRANSACTION or SAVEPOINT-related commands tx_control: bool = False # Cardinality of the result set. Set to NO_RESULT if the # unit group is not expected or desired to return data. 
cardinality: enums.Cardinality = enums.Cardinality.NO_RESULT out_type_data: bytes = sertypes.NULL_TYPE_DESC out_type_id: bytes = sertypes.NULL_TYPE_ID.bytes in_type_data: bytes = sertypes.NULL_TYPE_DESC in_type_id: bytes = sertypes.NULL_TYPE_ID.bytes in_type_args: Optional[list[Param]] = None in_type_args_real_count: int = 0 globals: Optional[list[tuple[str, bool]]] = None permissions: Optional[list[str]] = None json_permissions: Optional[list[str]] = None required_permissions: Optional[list[str]] = None server_param_conversions: Optional[list[ServerParamConversion]] = None unit_converted_param_indexes: Optional[dict[int, list[int]]] = None warnings: Optional[list[errors.EdgeDBError]] = None unsafe_isolation_dangers: ( Optional[list[errors.UnsafeIsolationLevelError]] ) = None # Cacheable QueryUnit is serialized in the compiler, so that the I/O server # doesn't need to serialize it again for persistence. _units: list[QueryUnit | bytes] = dataclasses.field(default_factory=list) # This is an I/O server-only cache for unpacked QueryUnits _unpacked_units: list[QueryUnit] | None = None state_serializer: Optional[sertypes.StateSerializer] = None cache_state: int = 0 tx_seq_id: int = 0 force_non_normalized: bool = False graphql_key_variables: Optional[list[str]] = None @property def units(self) -> list[QueryUnit]: if self._unpacked_units is None: self._unpacked_units = [ QueryUnit.deserialize(unit) if isinstance(unit, bytes) else unit for unit in self._units ] return self._unpacked_units def __iter__(self) -> Iterator[QueryUnit]: return iter(self.units) def __len__(self) -> int: return len(self._units) def __getitem__(self, item: int) -> QueryUnit: return self.units[item] def maybe_get_serialized(self, item: int) -> bytes | None: unit = self._units[item] if isinstance(unit, bytes): return unit return None def append( self, query_unit: QueryUnit, serialize: bool = True, ) -> None: self.capabilities |= query_unit.capabilities if not query_unit.cacheable: self.cacheable = 
False if query_unit.tx_control: self.tx_control = True self.cardinality = query_unit.cardinality self.out_type_data = query_unit.out_type_data self.out_type_id = query_unit.out_type_id self.in_type_data = query_unit.in_type_data self.in_type_id = query_unit.in_type_id self.in_type_args = query_unit.in_type_args self.in_type_args_real_count = query_unit.in_type_args_real_count if query_unit.globals is not None: if self.globals is None: self.globals = [] self.globals.extend(query_unit.globals) if query_unit.permissions is not None: if self.permissions is None: self.permissions = [] self.permissions.extend(query_unit.permissions) if query_unit.json_permissions is not None: if self.json_permissions is None: self.json_permissions = [] self.json_permissions.extend(query_unit.json_permissions) if query_unit.required_permissions is not None: if self.required_permissions is None: self.required_permissions = [] for perm in query_unit.required_permissions: if perm not in self.required_permissions: self.required_permissions.append(perm) if query_unit.server_param_conversions is not None: if self.server_param_conversions is None: self.server_param_conversions = [] if self.unit_converted_param_indexes is None: self.unit_converted_param_indexes = {} # De-duplicate param conversions and store information about which # units access which converted params. # If two units request the same conversion on the same parameter, # we should assume the conversion is stable and only do it once. 
unit_index = len(self._units) converted_param_indexes: list[int] = [] for spc in query_unit.server_param_conversions: if spc in self.server_param_conversions: converted_param_indexes.append( self.server_param_conversions.index(spc) ) else: converted_param_indexes.append( len(self.server_param_conversions) ) self.server_param_conversions.append(spc) self.unit_converted_param_indexes[unit_index] = ( converted_param_indexes ) if query_unit.warnings is not None: if self.warnings is None: self.warnings = [] self.warnings.extend(query_unit.warnings) if query_unit.unsafe_isolation_dangers is not None: if self.unsafe_isolation_dangers is None: self.unsafe_isolation_dangers = [] self.unsafe_isolation_dangers.extend( query_unit.unsafe_isolation_dangers) if not serialize or query_unit.cache_sql is None: self._units.append(query_unit) else: self._units.append(query_unit.serialize()) @dataclasses.dataclass(frozen=True, kw_only=True) class PreparedStmtOpData: """Common prepared statement metadata""" stmt_name: str """Original statement name as passed by the frontend""" be_stmt_name: bytes = b"" """Computed statement name as passed to the backend""" @dataclasses.dataclass(frozen=True, kw_only=True) class PrepareData(PreparedStmtOpData): """PREPARE statement data""" query: str """Translated query string""" source_map: Optional[pgcodegen.SourceMap] = None """Translation source map""" @dataclasses.dataclass(frozen=True, kw_only=True) class ExecuteData(PreparedStmtOpData): """EXECUTE statement data""" pass @dataclasses.dataclass(frozen=True, kw_only=True) class DeallocateData(PreparedStmtOpData): """DEALLOCATE statement data""" pass @dataclasses.dataclass(kw_only=True) class SQLQueryUnit: query: str = dataclasses.field(repr=False) """Translated query text.""" prefix_len: int = 0 source_map: Optional[pgcodegen.SourceMap] = None """Translation source map.""" eql_format_query: Optional[str] = dataclasses.field( repr=False, default=None) """Translated query text returning data in 
single-column format.""" orig_query: str = dataclasses.field(repr=False) """Original query text before translation.""" # True if it is safe to cache this unit. cacheable: bool = True cardinality: enums.Cardinality = enums.Cardinality.NO_RESULT capabilities: enums.Capability = enums.Capability.NONE fe_settings: SQLSettings """Frontend-only settings effective during translation of this unit.""" tx_action: Optional[TxAction] = None tx_chain: bool = False sp_name: Optional[str] = None prepare: Optional[PrepareData] = None execute: Optional[ExecuteData] = None deallocate: Optional[DeallocateData] = None set_vars: Optional[dict[Optional[str], Optional[SQLSetting]]] = None is_local: bool = False stmt_name: bytes = b"" """Computed prepared statement name for this query.""" frontend_only: bool = False """Whether the query is completely emulated outside of backend and so the response should be synthesized also.""" command_complete_tag: Optional[CommandCompleteTag] = None """When set, CommandComplete for this query will be overridden. This is useful, for example, for setting the tag of DML statements, which return the number of modified rows.""" params: Optional[list[SQLParam]] = None class CommandCompleteTag: """Dictates the tag of CommandComplete message that concludes this query.""" @dataclasses.dataclass(kw_only=True) class TagPlain(CommandCompleteTag): """Set the tag verbatim""" tag: bytes @dataclasses.dataclass(kw_only=True) class TagCountMessages(CommandCompleteTag): """Count DataRow messages in the response and set the tag to f'{prefix} {count_of_messages}'.""" prefix: str @dataclasses.dataclass(kw_only=True) class TagUnpackRow(CommandCompleteTag): """Intercept a single DataRow message with a single column which represents the number of modified rows. Sets the CommandComplete tag to f'{prefix} {modified_rows}'.""" prefix: str class SQLParam: # Internal query param. Represents params in the compiled SQL, so the params # that are sent to PostgreSQL. 
# True for params that are actually used in the compiled query. used: bool = False @dataclasses.dataclass(kw_only=True, eq=False, slots=True, repr=False) class SQLParamExternal(SQLParam): # An internal query param whose value is provided by an external param. # So a user-visible param. # External params share the index with internal params pass @dataclasses.dataclass(kw_only=True, eq=False, slots=True, repr=False) class SQLParamExtractedConst(SQLParam): # An internal query param whose value is a constant that this param has # replaced during query normalization. type_oid: int @dataclasses.dataclass(kw_only=True, eq=False, slots=True, repr=False) class SQLParamGlobal(SQLParam): # An internal query param whose value is provided by a global. global_name: s_name.QualName pg_type: tuple[str, ...] is_permission: bool internal_index: int @dataclasses.dataclass class ParsedDatabase: user_schema_pickle: bytes schema_version: uuid.UUID database_config: immutables.Map[str, config.SettingValue] ext_config_settings: list[config.Setting] feature_used_metrics: dict[str, float] protocol_version: defines.ProtocolVersion state_serializer: sertypes.StateSerializer SQLSetting = tuple[str | int | float, ...] 
SQLSettings = immutables.Map[Optional[str], Optional[SQLSetting]] DEFAULT_SQL_SETTINGS: SQLSettings = immutables.Map() DEFAULT_SQL_FE_SETTINGS: SQLSettings = immutables.Map( { "search_path": ("public",), "server_version": cast(SQLSetting, (defines.PGEXT_POSTGRES_VERSION,)), "server_version_num": cast( SQLSetting, (defines.PGEXT_POSTGRES_VERSION_NUM,) ), } ) @dataclasses.dataclass class SQLTransactionState: in_tx: bool settings: SQLSettings in_tx_settings: Optional[SQLSettings] in_tx_local_settings: Optional[SQLSettings] savepoints: list[tuple[str, SQLSettings, SQLSettings]] def current_fe_settings(self) -> SQLSettings: if self.in_tx: return self.in_tx_local_settings or DEFAULT_SQL_FE_SETTINGS else: return self.settings or DEFAULT_SQL_FE_SETTINGS def get(self, name: str) -> Optional[SQLSetting]: if self.in_tx: # For easier access, in_tx_local_settings is always a superset of # in_tx_settings; in_tx_settings only keeps track of non-local # settings, so that the local settings don't go across tx bounds assert self.in_tx_local_settings return self.in_tx_local_settings[name] else: return self.settings[name] def apply(self, query_unit: SQLQueryUnit) -> None: if query_unit.tx_action == TxAction.COMMIT: self.in_tx = False self.settings = self.in_tx_settings # type: ignore self.in_tx_settings = None self.in_tx_local_settings = None self.savepoints.clear() elif query_unit.tx_action == TxAction.ROLLBACK: self.in_tx = False self.in_tx_settings = None self.in_tx_local_settings = None self.savepoints.clear() elif query_unit.tx_action == TxAction.DECLARE_SAVEPOINT: assert query_unit.sp_name is not None assert self.in_tx_settings is not None assert self.in_tx_local_settings is not None self.savepoints.append( ( query_unit.sp_name, self.in_tx_settings, self.in_tx_local_settings, ) ) elif query_unit.tx_action == TxAction.ROLLBACK_TO_SAVEPOINT: while self.savepoints: sp_name, settings, local_settings = self.savepoints[-1] if query_unit.sp_name == sp_name: self.in_tx_settings = 
settings self.in_tx_local_settings = local_settings break else: self.savepoints.pop() else: raise errors.TransactionError( f'savepoint "{query_unit.sp_name}" does not exist' ) if not self.in_tx: # Always start an implicit transaction here, because in the # compiler, multiple apply() calls only happen in simple query, # and any query would start an implicit transaction. For example, # we need to support a single ROLLBACK without a matching BEGIN # rolling back an implicit transaction. self.in_tx = True self.in_tx_settings = self.settings self.in_tx_local_settings = self.settings if query_unit.frontend_only and query_unit.set_vars: for name, value in query_unit.set_vars.items(): self.set(name, value, query_unit.is_local) def set( self, name: Optional[str], value: Optional[SQLSetting], is_local: bool ) -> None: def _set(attr_name: str) -> None: settings = getattr(self, attr_name) if value is None: if name in settings: settings = settings.delete(name) else: settings = settings.set(name, value) setattr(self, attr_name, settings) if self.in_tx: _set("in_tx_local_settings") if not is_local: _set("in_tx_settings") elif not is_local: _set("settings") ############################# class ProposedMigrationStep(NamedTuple): statements: tuple[str, ...] confidence: float prompt: str prompt_id: str data_safe: bool required_user_input: tuple[dict[str, str], ...] # This isn't part of the output data, but is used to figure out # what to prohibit when something is rejected. 
operation_key: s_delta.CommandKey def to_json(self) -> dict[str, Any]: return { "statements": [{"text": stmt} for stmt in self.statements], "confidence": self.confidence, "prompt": self.prompt, "prompt_id": self.prompt_id, "data_safe": self.data_safe, "required_user_input": list(self.required_user_input), } class MigrationState(NamedTuple): parent_migration: Optional[s_migrations.Migration] initial_schema: s_schema.Schema initial_savepoint: Optional[str] target_schema: s_schema.Schema guidance: s_obj.DeltaGuidance accepted_cmds: tuple[qlast.Base, ...] last_proposed: Optional[tuple[ProposedMigrationStep, ...]] class MigrationRewriteState(NamedTuple): initial_savepoint: Optional[str] target_schema: s_schema.Schema accepted_migrations: tuple[qlast.CreateMigration, ...] class TransactionState(NamedTuple): id: int name: Optional[str] local_user_schema: s_schema.Schema | None global_schema: s_schema.Schema modaliases: immutables.Map[Optional[str], str] session_config: immutables.Map[str, config.SettingValue] database_config: immutables.Map[str, config.SettingValue] system_config: immutables.Map[str, config.SettingValue] cached_reflection: immutables.Map[str, tuple[str, ...]] tx: Transaction migration_state: Optional[MigrationState] = None migration_rewrite_state: Optional[MigrationRewriteState] = None @property def user_schema(self) -> s_schema.Schema: if self.local_user_schema is None: return self.tx.root_user_schema else: return self.local_user_schema class Transaction: # Fields that affect the state key are listed here. The key is used # to determine if we can reuse a previously-pickled state, so remember # to update get_state_key() below when adding new fields affecting the # state key. 
See also edb/server/compiler_pool/worker.py _id: int _savepoints: dict[int, TransactionState] _current: TransactionState # backref to the owning state object _constate: CompilerConnectionState def __init__( self, constate: CompilerConnectionState, *, user_schema: s_schema.Schema, global_schema: s_schema.Schema, modaliases: immutables.Map[Optional[str], str], session_config: immutables.Map[str, config.SettingValue], database_config: immutables.Map[str, config.SettingValue], system_config: immutables.Map[str, config.SettingValue], cached_reflection: immutables.Map[str, tuple[str, ...]], implicit: bool = True, ) -> None: assert not isinstance(user_schema, s_schema.ChainedSchema) self._constate = constate self._id = constate._new_txid() self._implicit = implicit self._current = TransactionState( id=self._id, name=None, local_user_schema=( None if user_schema is self.root_user_schema else user_schema ), global_schema=global_schema, modaliases=modaliases, session_config=session_config, database_config=database_config, system_config=system_config, cached_reflection=cached_reflection, tx=self, ) self._state0 = self._current self._savepoints = {} def get_state_key(self) -> tuple[int, tuple[int, ...], TransactionState]: return ( self._id, tuple(self._savepoints.keys()), self._current, # TransactionState is immutable ) @property def id(self) -> int: return self._id @property def root_user_schema(self) -> s_schema.Schema: return self._constate.root_user_schema def is_implicit(self) -> bool: return self._implicit def make_explicit(self) -> None: if self._implicit: self._implicit = False else: raise errors.TransactionError("already in explicit transaction") def declare_savepoint(self, name: str) -> int: if self.is_implicit(): raise errors.TransactionError( "savepoints can only be used in transaction blocks" ) return self._declare_savepoint(name) def start_migration(self) -> str: name = str(uuid.uuid4()) self._declare_savepoint(name) return name def _declare_savepoint(self, name: 
str) -> int: sp_id = self._constate._new_txid() sp_state = self._current._replace(id=sp_id, name=name) self._savepoints[sp_id] = sp_state self._constate._savepoints_log[sp_id] = sp_state return sp_id def rollback_to_savepoint(self, name: str) -> TransactionState: if self.is_implicit(): raise errors.TransactionError( "savepoints can only be used in transaction blocks" ) return self._rollback_to_savepoint(name) def abort_migration(self, name: str) -> None: self._rollback_to_savepoint(name) def _rollback_to_savepoint(self, name: str) -> TransactionState: sp_ids_to_erase = [] for sp in reversed(self._savepoints.values()): if sp.name == name: self._current = sp break sp_ids_to_erase.append(sp.id) else: raise errors.TransactionError(f"there is no {name!r} savepoint") for sp_id in sp_ids_to_erase: self._savepoints.pop(sp_id) return sp def release_savepoint(self, name: str) -> None: if self.is_implicit(): raise errors.TransactionError( "savepoints can only be used in transaction blocks" ) self._release_savepoint(name) def commit_migration(self, name: str) -> None: self._release_savepoint(name) def _release_savepoint(self, name: str) -> None: sp_ids_to_erase = [] for sp in reversed(self._savepoints.values()): sp_ids_to_erase.append(sp.id) if sp.name == name: break else: raise errors.TransactionError(f"there is no {name!r} savepoint") for sp_id in sp_ids_to_erase: self._savepoints.pop(sp_id) def get_schema(self, std_schema: s_schema.Schema) -> s_schema.Schema: return s_schema.ChainedSchema( std_schema, self._current.user_schema, self._current.global_schema, ) def get_user_schema(self) -> s_schema.Schema: return self._current.user_schema def get_user_schema_if_updated(self) -> Optional[s_schema.Schema]: if self._current.user_schema is self._state0.user_schema: return None else: return self._current.user_schema def get_global_schema(self) -> s_schema.Schema: return self._current.global_schema def get_global_schema_if_updated(self) -> Optional[s_schema.Schema]: if 
self._current.global_schema is self._state0.global_schema: return None else: return self._current.global_schema def get_modaliases(self) -> immutables.Map[Optional[str], str]: return self._current.modaliases def get_session_config(self) -> immutables.Map[str, config.SettingValue]: return self._current.session_config def get_database_config(self) -> immutables.Map[str, config.SettingValue]: return self._current.database_config def get_system_config(self) -> immutables.Map[str, config.SettingValue]: return self._current.system_config def get_cached_reflection_if_updated( self, ) -> Optional[immutables.Map[str, tuple[str, ...]]]: if self._current.cached_reflection == self._state0.cached_reflection: return None else: return self._current.cached_reflection def get_cached_reflection(self) -> immutables.Map[str, tuple[str, ...]]: return self._current.cached_reflection def get_migration_state(self) -> Optional[MigrationState]: return self._current.migration_state def get_migration_rewrite_state(self) -> Optional[MigrationRewriteState]: return self._current.migration_rewrite_state def update_schema(self, new_schema: s_schema.Schema) -> None: assert isinstance(new_schema, s_schema.ChainedSchema) user_schema = new_schema.get_top_schema() assert isinstance(user_schema, s_schema.Schema) global_schema = new_schema.get_global_schema() assert isinstance(global_schema, s_schema.Schema) self._current = self._current._replace( local_user_schema=user_schema, global_schema=global_schema, ) def update_modaliases( self, new_modaliases: immutables.Map[Optional[str], str] ) -> None: self._current = self._current._replace(modaliases=new_modaliases) def update_session_config( self, new_config: immutables.Map[str, config.SettingValue] ) -> None: self._current = self._current._replace(session_config=new_config) def update_database_config( self, new_config: immutables.Map[str, config.SettingValue] ) -> None: self._current = self._current._replace(database_config=new_config) def 
update_cached_reflection( self, new: immutables.Map[str, tuple[str, ...]], ) -> None: self._current = self._current._replace(cached_reflection=new) def update_migration_state(self, mstate: Optional[MigrationState]) -> None: self._current = self._current._replace(migration_state=mstate) def update_migration_rewrite_state( self, mrstate: Optional[MigrationRewriteState] ) -> None: self._current = self._current._replace(migration_rewrite_state=mrstate) CStateStateType = tuple[dict[int, TransactionState], Transaction, int] class CompilerConnectionState: __slots__ = ("_savepoints_log", "_current_tx", "_tx_count", "_user_schema") # Fields that affect the state key are listed here. The key is used # to determine if we can reuse a previously-pickled state, so remember # to update get_state_key() below when adding new fields affecting the # state key. See also edb/server/compiler_pool/worker.py _tx_count: int _savepoints_log: dict[int, TransactionState] _current_tx: Transaction _user_schema: Optional[s_schema.Schema] def __init__( self, *, user_schema: s_schema.Schema, global_schema: s_schema.Schema, modaliases: immutables.Map[Optional[str], str], session_config: immutables.Map[str, config.SettingValue], database_config: immutables.Map[str, config.SettingValue], system_config: immutables.Map[str, config.SettingValue], cached_reflection: immutables.Map[str, tuple[str, ...]], ): self._user_schema = user_schema self._tx_count = time.monotonic_ns() self._init_current_tx( user_schema=user_schema, global_schema=global_schema, modaliases=modaliases, session_config=session_config, database_config=database_config, system_config=system_config, cached_reflection=cached_reflection, ) self._savepoints_log = {} def get_state_key(self) -> tuple[tuple[int, ...], int, tuple[Any, ...]]: # This would be much more efficient if CompilerConnectionState # and TransactionState objects were immutable. 
But they are not, # so we have to compute the key from their current contents. return ( tuple(self._savepoints_log.keys()), self._tx_count, self._current_tx.get_state_key(), ) def __getstate__(self) -> CStateStateType: return self._savepoints_log, self._current_tx, self._tx_count def __setstate__(self, state: CStateStateType) -> None: self._savepoints_log, self._current_tx, self._tx_count = state self._user_schema = None @property def root_user_schema(self) -> s_schema.Schema: assert self._user_schema is not None return self._user_schema def set_root_user_schema(self, user_schema: s_schema.Schema) -> None: self._user_schema = user_schema def _new_txid(self) -> int: self._tx_count += 1 return self._tx_count def _init_current_tx( self, *, user_schema: s_schema.Schema, global_schema: s_schema.Schema, modaliases: immutables.Map[Optional[str], str], session_config: immutables.Map[str, config.SettingValue], database_config: immutables.Map[str, config.SettingValue], system_config: immutables.Map[str, config.SettingValue], cached_reflection: immutables.Map[str, tuple[str, ...]], ) -> None: self._current_tx = Transaction( self, user_schema=user_schema, global_schema=global_schema, modaliases=modaliases, session_config=session_config, database_config=database_config, system_config=system_config, cached_reflection=cached_reflection, ) def can_sync_to_savepoint(self, spid: int) -> bool: return spid in self._savepoints_log def sync_to_savepoint(self, spid: int) -> None: """Synchronize the compiler state with the current DB state.""" if not self.can_sync_to_savepoint(spid): raise RuntimeError(f"failed to lookup savepoint with id={spid}") sp = self._savepoints_log[spid] self._current_tx = sp.tx self._current_tx._current = sp self._current_tx._id = spid # Cleanup all savepoints declared after the one we rolled back to # in the transaction we have now set as current. 
for id in tuple(self._current_tx._savepoints): if id > spid: self._current_tx._savepoints.pop(id) # Cleanup all savepoints declared after the one we rolled back to # in the global savepoints log. for id in tuple(self._savepoints_log): if id > spid: self._savepoints_log.pop(id) def current_tx(self) -> Transaction: return self._current_tx def start_tx(self) -> None: if self._current_tx.is_implicit(): self._current_tx.make_explicit() else: raise errors.TransactionError("already in transaction") def rollback_tx(self) -> TransactionState: # Note that we might not be in a transaction as we allow # ROLLBACKs outside of transaction blocks (just like Postgres). prior_state = self._current_tx._state0 self._init_current_tx( user_schema=prior_state.user_schema, global_schema=prior_state.global_schema, modaliases=prior_state.modaliases, session_config=prior_state.session_config, database_config=prior_state.database_config, system_config=prior_state.system_config, cached_reflection=prior_state.cached_reflection, ) return prior_state def commit_tx(self) -> TransactionState: if self._current_tx.is_implicit(): raise errors.TransactionError("cannot commit: not in transaction") latest_state = self._current_tx._current self._init_current_tx( user_schema=latest_state.user_schema, global_schema=latest_state.global_schema, modaliases=latest_state.modaliases, session_config=latest_state.session_config, database_config=latest_state.database_config, system_config=latest_state.system_config, cached_reflection=latest_state.cached_reflection, ) return latest_state def sync_tx(self, txid: int) -> None: if self._current_tx.id == txid: return if self.can_sync_to_savepoint(txid): self.sync_to_savepoint(txid) return raise errors.InternalServerError( f"failed to lookup transaction or savepoint with id={txid}" ) # pragma: no cover ================================================ FILE: edb/server/compiler/ddl.py ================================================ # # This source file is part of the 
EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any, Optional import dataclasses import json import textwrap import uuid from edb import errors from edb import edgeql from edb.common import debug from edb.common import ast from edb.common import uuidgen from edb.edgeql import ast as qlast from edb.edgeql import codegen as qlcodegen from edb.edgeql import compiler as qlcompiler from edb.edgeql import qltypes from edb.edgeql import quote as qlquote from edb.schema import annos as s_annos from edb.schema import constraints as s_constraints from edb.schema import database as s_db from edb.schema import ddl as s_ddl from edb.schema import delta as s_delta from edb.schema import expraliases as s_expraliases from edb.schema import futures as s_futures from edb.schema import functions as s_func from edb.schema import globals as s_globals from edb.schema import indexes as s_indexes from edb.schema import links as s_links from edb.schema import migrations as s_migrations from edb.schema import objects as s_obj from edb.schema import objtypes as s_objtypes from edb.schema import policies as s_policies from edb.schema import pointers as s_pointers from edb.schema import properties as s_properties from edb.schema import rewrites as s_rewrites from edb.schema import scalars as s_scalars from edb.schema import schema as s_schema from edb.schema 
import triggers as s_triggers from edb.schema import utils as s_utils from edb.schema import version as s_ver from edb.pgsql import common as pg_common from edb.pgsql import delta as pg_delta from edb.pgsql import dbops as pg_dbops from . import dbstate from . import compiler NIL_QUERY = b"SELECT LIMIT 0" def compile_and_apply_ddl_stmt( ctx: compiler.CompileContext, stmt: qlast.DDLCommand, source: Optional[edgeql.Source] = None, ) -> dbstate.DDLQuery: query, _ = _compile_and_apply_ddl_stmt(ctx, stmt, source) return query def _compile_and_apply_ddl_stmt( ctx: compiler.CompileContext, stmt: qlast.DDLCommand, source: Optional[edgeql.Source] = None, ) -> tuple[dbstate.DDLQuery, Optional[pg_dbops.SQLBlock]]: if isinstance(stmt, qlast.GlobalObjectCommand): ctx._assert_not_in_migration_block(stmt) current_tx = ctx.state.current_tx() schema = current_tx.get_schema(ctx.compiler_state.std_schema) mstate = current_tx.get_migration_state() if ( mstate is None and not ctx.bootstrap_mode and ctx.log_ddl_as_migrations and not isinstance( stmt, ( qlast.CreateMigration, qlast.GlobalObjectCommand, qlast.DropMigration, ), ) ): allow_bare_ddl = compiler._get_config_val(ctx, 'allow_bare_ddl') if allow_bare_ddl != "AlwaysAllow": raise errors.QueryError( "bare DDL statements are not allowed on this database branch", hint="Use the migration commands instead.", details=( f"The `allow_bare_ddl` configuration variable " f"is set to {str(allow_bare_ddl)!r}. The " f"`edgedb migrate` command normally sets this " f"to avoid accidental schema changes outside of " f"the migration flow." 
), span=stmt.span, ) cm = qlast.CreateMigration( # type: ignore body=qlast.NestedQLBlock( commands=[stmt], ), commands=[ qlast.SetField( name='generated_by', value=qlast.Path( steps=[ qlast.ObjectRef( name='MigrationGeneratedBy', module='schema' ), qlast.Ptr(name='DDLStatement'), ] ), ) ], ) return _compile_and_apply_ddl_stmt(ctx, cm) assert isinstance(stmt, qlast.DDLCommand) new_schema, delta = s_ddl.delta_and_schema_from_ddl( stmt, schema=schema, modaliases=current_tx.get_modaliases(), **_get_delta_context_args(ctx), ) if debug.flags.delta_plan: debug.header('Canonical Delta Plan') debug.dump(delta, schema=schema) if mstate := current_tx.get_migration_state(): mstate = mstate._replace( accepted_cmds=mstate.accepted_cmds + (stmt,), ) last_proposed = mstate.last_proposed if last_proposed: if last_proposed[0].required_user_input or last_proposed[ 0 ].prompt_id.startswith("Rename"): # Cannot auto-apply the proposed DDL # if user input is required. # Also skip auto-applying for renames, since # renames often force a bunch of rethinking. mstate = mstate._replace(last_proposed=None) else: proposed_stmts = last_proposed[0].statements ddl_script = '\n'.join(proposed_stmts) if source and source.text() == ddl_script: # The client has confirmed the proposed migration step, # advance the proposed script. mstate = mstate._replace( last_proposed=last_proposed[1:], ) else: # The client replied with a statement that does not # match what was proposed, reset the proposed script # to force script regeneration on next DESCRIBE. 
mstate = mstate._replace(last_proposed=None) current_tx.update_migration_state(mstate) current_tx.update_schema(new_schema) query = dbstate.DDLQuery( sql=NIL_QUERY, user_schema=current_tx.get_user_schema(), is_transactional=True, warnings=tuple(delta.warnings), feature_used_metrics=None, ) return query, None store_migration_sdl = compiler._get_config_val(ctx, 'store_migration_sdl') if ( isinstance(stmt, qlast.CreateMigration) and store_migration_sdl == 'AlwaysStore' ): stmt.target_sdl = s_ddl.sdl_text_from_schema(new_schema) # If we are in a migration rewrite, we also don't actually # apply the DDL, just record it. (The DDL also needs to be a # CreateMigration.) if mrstate := current_tx.get_migration_rewrite_state(): if not isinstance(stmt, qlast.CreateMigration): # This will always fail, and gives us the error we need ctx._assert_not_in_migration_rewrite_block(stmt) # Tell this to the type checker raise AssertionError() mrstate = mrstate._replace( accepted_migrations=(mrstate.accepted_migrations + (stmt,)) ) current_tx.update_migration_rewrite_state(mrstate) current_tx.update_schema(new_schema) query = dbstate.DDLQuery( sql=NIL_QUERY, user_schema=current_tx.get_user_schema(), is_transactional=True, warnings=tuple(delta.warnings), feature_used_metrics=None, ) return query, None # Apply and adapt delta, build native delta plan, which # will also update the schema. block, new_types, config_ops = _process_delta(ctx, delta) ddl_stmt_id: Optional[str] = None is_transactional = block.is_transactional() if not is_transactional: if not isinstance(stmt, qlast.DatabaseCommand): raise AssertionError( f"unexpected non-transaction DDL command type: {stmt}") sql_stmts = block.get_statements() sql = sql_stmts[0].encode("utf-8") db_op_trailer = tuple(stmt.encode("utf-8") for stmt in sql_stmts[1:]) else: if new_types: # Inject a query returning backend OIDs for the newly # created types. 
ddl_stmt_id = str(uuidgen.uuid1mc()) new_type_ids = [ f'{pg_common.quote_literal(tid)}::uuid' for tid in new_types ] # Return newly-added type id mapping via the indirect # return channel (see PGConnection.last_indirect_return) new_types_sql = textwrap.dedent(f"""\ PERFORM edgedb.indirect_return( json_build_object( 'ddl_stmt_id', {pg_common.quote_literal(ddl_stmt_id)}, 'new_types', (SELECT json_object_agg( "id"::text, json_build_array("backend_id", "name") ) FROM edgedb_VER."_SchemaType" WHERE "id" = any(ARRAY[ {', '.join(new_type_ids)} ]) ) )::text )""" ) block.add_command(pg_dbops.Query(text=new_types_sql).code()) sql = block.to_string().encode('utf-8') db_op_trailer = () create_db = None drop_db = None drop_db_reset_connections = False create_db_template = None create_db_mode = None if isinstance(stmt, qlast.DropDatabase): drop_db = stmt.name.name drop_db_reset_connections = stmt.force elif isinstance(stmt, qlast.CreateDatabase): create_db = stmt.name.name create_db_template = stmt.template.name if stmt.template else None create_db_mode = stmt.branch_type elif isinstance(stmt, qlast.AlterDatabase): for cmd in stmt.commands: if isinstance(cmd, qlast.Rename): drop_db = stmt.name.name create_db = cmd.new_name.name drop_db_reset_connections = stmt.force if debug.flags.delta_execute_ddl: debug.header('Delta Script (DDL Only)') # The schema updates are always the last statement, so grab # everything but the last one. code = '\n\n'.join(block.get_statements()[:-1]) debug.dump_code(code, lexer='sql') if debug.flags.delta_execute: debug.header('Delta Script') debug.dump_code(sql + b"\n".join(db_op_trailer), lexer='sql') new_user_schema = current_tx.get_user_schema_if_updated() query = dbstate.DDLQuery( sql=sql, is_transactional=is_transactional, create_db=create_db, drop_db=drop_db, drop_db_reset_connections=drop_db_reset_connections, create_db_template=create_db_template, create_db_mode=create_db_mode, db_op_trailer=db_op_trailer, ddl_stmt_id=ddl_stmt_id, user_schema=new_user_schema, 
cached_reflection=current_tx.get_cached_reflection_if_updated(), global_schema=current_tx.get_global_schema_if_updated(), config_ops=config_ops, warnings=tuple(delta.warnings), feature_used_metrics=( produce_feature_used_metrics(ctx.compiler_state, new_user_schema) if new_user_schema else None ), ) return query, block def _new_delta_context( ctx: compiler.CompileContext, args: Any = None ) -> s_delta.CommandContext: return s_delta.CommandContext( backend_runtime_params=ctx.compiler_state.backend_runtime_params, internal_schema_mode=ctx.internal_schema_mode, **(_get_delta_context_args(ctx) if args is None else args), ) def _get_delta_context_args(ctx: compiler.CompileContext) -> dict[str, Any]: """Get the args needed for delta_and_schema_from_ddl""" return dict( stdmode=ctx.bootstrap_mode, testmode=ctx.is_testmode(), store_migration_sdl=( compiler._get_config_val(ctx, 'store_migration_sdl') ) == 'AlwaysStore', schema_object_ids=ctx.schema_object_ids, compat_ver=ctx.compat_ver, ) def _process_delta( ctx: compiler.CompileContext, delta: s_delta.DeltaRoot, context_args: Any = None, ) -> tuple[pg_dbops.SQLBlock, frozenset[str], Any]: """Adapt and process the delta command.""" current_tx = ctx.state.current_tx() schema = current_tx.get_schema(ctx.compiler_state.std_schema) pgdelta = pg_delta.CommandMeta.adapt(delta) assert isinstance(pgdelta, pg_delta.DeltaRoot) context = _new_delta_context(ctx, context_args) schema = pgdelta.apply(schema, context) current_tx.update_schema(schema) if debug.flags.delta_pgsql_plan: debug.header('PgSQL Delta Plan') debug.dump(pgdelta, schema=schema) db_cmd = any( isinstance(c, s_db.BranchCommand) for c in pgdelta.get_subcommands() ) if db_cmd: block = pg_dbops.SQLBlock() new_types: frozenset[str] = frozenset() else: block = pg_dbops.PLTopBlock() new_types = frozenset(str(tid) for tid in pgdelta.new_types) # Generate SQL DDL for the delta. 
pgdelta.generate(block) # type: ignore # XXX: We would prefer for there to not be trampolines ever after bootstrap pgdelta.create_trampolines.generate(block) # type: ignore # Generate schema storage SQL (DML into schema storage tables). subblock = block.add_block() compiler.compile_schema_storage_in_delta( ctx, pgdelta, subblock, context=context ) # Performance hack; we really want trivial migration commands # (that only mutate the migration log) to not trigger a pg_catalog # view refresh, since many get issued as part of MIGRATION # REWRITEs. all_migration_tweaks = all( isinstance( cmd, (s_ver.AlterSchemaVersion, s_migrations.MigrationCommand) ) and not cmd.get_subcommands(type=s_delta.ObjectCommand) for cmd in delta.get_subcommands() ) if not ctx.bootstrap_mode and not all_migration_tweaks: from edb.pgsql import metaschema refresh = metaschema.generate_sql_information_schema_refresh( ctx.compiler_state.backend_runtime_params.instance_params.version ) refresh.generate(subblock) return block, new_types, pgdelta.config_ops def compile_dispatch_ql_migration( ctx: compiler.CompileContext, ql: qlast.MigrationCommand, *, in_script: bool, ) -> dbstate.BaseQuery: if ctx.expect_rollback and not isinstance( ql, (qlast.AbortMigration, qlast.AbortMigrationRewrite) ): # Only allow ABORT MIGRATION to pass when expecting a rollback if ctx.state.current_tx().get_migration_state() is None: raise errors.TransactionError( 'expected a ROLLBACK or ROLLBACK TO SAVEPOINT command' ) else: raise errors.TransactionError( 'expected a ROLLBACK or ABORT MIGRATION command' ) match ql: case qlast.CreateMigration(): ctx._assert_not_in_migration_block(ql) return compile_and_apply_ddl_stmt(ctx, ql) case qlast.StartMigration(): return _start_migration(ctx, ql, in_script) case qlast.PopulateMigration(): return _populate_migration(ctx, ql) case qlast.DescribeCurrentMigration(): return _describe_current_migration(ctx, ql) case qlast.AlterCurrentMigrationRejectProposed(): return 
_alter_current_migration_reject_proposed(ctx, ql) case qlast.CommitMigration(): return _commit_migration(ctx, ql) case qlast.AbortMigration(): return _abort_migration(ctx, ql) case qlast.DropMigration(): ctx._assert_not_in_migration_block(ql) return compile_and_apply_ddl_stmt(ctx, ql) case qlast.StartMigrationRewrite(): return _start_migration_rewrite(ctx, ql, in_script) case qlast.CommitMigrationRewrite(): return _commit_migration_rewrite(ctx, ql) case qlast.AbortMigrationRewrite(): return _abort_migration_rewrite(ctx, ql) case qlast.ResetSchema(): return _reset_schema(ctx, ql) case _: raise AssertionError(f'unexpected migration command: {ql}') def _start_migration( ctx: compiler.CompileContext, ql: qlast.StartMigration, in_script: bool, ) -> dbstate.BaseQuery: ctx._assert_not_in_migration_block(ql) current_tx = ctx.state.current_tx() schema = current_tx.get_schema(ctx.compiler_state.std_schema) if current_tx.is_implicit() and not in_script: savepoint_name = None tx_cmd = qlast.StartTransaction() tx_query = compiler._compile_ql_transaction(ctx, tx_cmd) query = dbstate.MigrationControlQuery( sql=tx_query.sql, action=dbstate.MigrationAction.START, tx_action=tx_query.action, cacheable=False, modaliases=None, ) else: savepoint_name = current_tx.start_migration() query = dbstate.MigrationControlQuery( sql=NIL_QUERY, action=dbstate.MigrationAction.START, tx_action=None, cacheable=False, modaliases=None, ) if isinstance(ql.target, qlast.CommittedSchema): mrstate = ctx._assert_in_migration_rewrite_block(ql) target_schema = mrstate.target_schema else: assert ctx.compiler_state.std_schema is not None base_schema = s_schema.ChainedSchema( ctx.compiler_state.std_schema, s_schema.EMPTY_SCHEMA, current_tx.get_global_schema(), ) target_schema, warnings = s_ddl.apply_sdl( ql.target, base_schema=base_schema, testmode=ctx.is_testmode(), ) if not ( s_futures.future_enabled(target_schema, 'simple_scoping') or s_futures.future_enabled(target_schema, 'warn_old_scoping') ): warnings += 
( errors.DeprecatedScopingError( f"\nSchema does not have 'using future simple_scoping'.\n" f"Non-simple_scoping will be removed in Gel 8.0.\n" f"See https://docs.geldata.com/reference/edgeql/" f"path_resolution\n" ), ) query = dataclasses.replace(query, warnings=tuple(warnings)) current_tx.update_migration_state( dbstate.MigrationState( parent_migration=schema.get_last_migration(), initial_schema=schema, initial_savepoint=savepoint_name, guidance=s_obj.DeltaGuidance(), target_schema=target_schema, accepted_cmds=tuple(), last_proposed=None, ), ) return query def _populate_migration( ctx: compiler.CompileContext, ql: qlast.PopulateMigration, ) -> dbstate.BaseQuery: mstate = ctx._assert_in_migration_block(ql) current_tx = ctx.state.current_tx() schema = current_tx.get_schema(ctx.compiler_state.std_schema) diff = s_ddl.delta_schemas( schema, mstate.target_schema, guidance=mstate.guidance, ) if debug.flags.delta_plan: debug.header('Populate Migration Diff') debug.dump(diff, schema=schema) new_ddl: tuple[qlast.DDLCommand, ...] = tuple( s_ddl.ddlast_from_delta( # type: ignore schema, mstate.target_schema, diff, testmode=ctx.is_testmode(), ), ) all_ddl = mstate.accepted_cmds + new_ddl mstate = mstate._replace( accepted_cmds=all_ddl, last_proposed=None, ) if debug.flags.delta_plan: debug.header('Populate Migration DDL AST') text = [] for cmd in new_ddl: debug.dump(cmd) text.append(qlcodegen.generate_source(cmd, pretty=True)) debug.header('Populate Migration DDL Text') debug.dump_code(';\n'.join(text) + ';') current_tx.update_migration_state(mstate) delta_context = _new_delta_context(ctx) # We want to make *certain* that the DDL we generate # produces the correct schema when applied, so we reload # the diff from the AST instead of just relying on the # delta tree. We do this check because it is *very # important* that we not emit DDL that moves the schema # into the wrong state. 
# # The actual check for whether the schema matches is done # by DESCRIBE CURRENT MIGRATION AS JSON, to populate the # 'complete' flag. if debug.flags.delta_plan: debug.header('Populate Migration Applied Diff') for cmd in new_ddl: reloaded_diff = s_ddl.delta_from_ddl( cmd, schema=schema, modaliases=current_tx.get_modaliases(), **_get_delta_context_args(ctx), ) schema = reloaded_diff.apply(schema, delta_context) if debug.flags.delta_plan: debug.dump(reloaded_diff, schema=schema) current_tx.update_schema(schema) return dbstate.MigrationControlQuery( sql=NIL_QUERY, tx_action=None, action=dbstate.MigrationAction.POPULATE, cacheable=False, modaliases=None, ) def _describe_current_migration( ctx: compiler.CompileContext, ql: qlast.DescribeCurrentMigration, ) -> dbstate.BaseQuery: mstate = ctx._assert_in_migration_block(ql) current_tx = ctx.state.current_tx() schema = current_tx.get_schema(ctx.compiler_state.std_schema) if ql.language is qltypes.DescribeLanguage.DDL: text = [] for stmt in mstate.accepted_cmds: # Generate uppercase DDL commands for backwards # compatibility with older migration text. text.append( qlcodegen.generate_source(stmt, pretty=True, uppercase=True) ) if text: description = ';\n'.join(text) + ';' else: description = '' desc_ql = edgeql.parse_query( f'SELECT {qlquote.quote_literal(description)}') return compiler._compile_ql_query( ctx, desc_ql, cacheable=False, migration_block_query=True, ) if ql.language is qltypes.DescribeLanguage.JSON: confirmed = [] for stmt in mstate.accepted_cmds: confirmed.append( # Add a terminating semicolon to match # "proposed", which is created by # s_ddl.statements_from_delta. # # Also generate uppercase DDL commands for # backwards compatibility with older migration # text. 
qlcodegen.generate_source(stmt, pretty=True, uppercase=True) + ';', ) if mstate.last_proposed is None: guided_diff = s_ddl.delta_schemas( schema, mstate.target_schema, generate_prompts=True, guidance=mstate.guidance, ) if debug.flags.delta_plan: debug.header('DESCRIBE CURRENT MIGRATION AS JSON delta') debug.dump(guided_diff) proposed_ddl = s_ddl.statements_from_delta( schema, mstate.target_schema, guided_diff, uppercase=True ) proposed_steps = [] if proposed_ddl: for ddl_text, ddl_ast, top_op in proposed_ddl: assert isinstance(top_op, s_delta.ObjectCommand) # get_ast has a lot of logic for figuring # out when an op is implicit in a parent # op. get_user_prompt does not have any of # that sort of logic, which makes it # susceptible to producing overly broad # messages. To avoid duplicating that sort # of logic, we recreate the delta from the # AST, and extract a user prompt from # *that*. # This is stupid, and it is slow. top_op2 = s_ddl.cmd_from_ddl( ddl_ast, schema=schema, modaliases=current_tx.get_modaliases(), ) assert isinstance(top_op2, s_delta.ObjectCommand) prompt_key2, prompt_text = top_op2.get_user_prompt() # Similarly, some placeholders may not have made # it into the actual query, so filter them out. used_placeholders = { p.name for p in ast.find_children(ddl_ast, qlast.Placeholder) } required_user_input = tuple( inp for inp in top_op.get_required_user_input() if inp['placeholder'] in used_placeholders ) # The prompt_id still needs to come from # the original op, though, since # orig_cmd_class is lost in ddl. 
prompt_key, _ = top_op.get_user_prompt() prompt_id = s_delta.get_object_command_id(prompt_key) confidence = top_op.get_annotation('confidence') assert confidence is not None step = dbstate.ProposedMigrationStep( statements=(ddl_text,), confidence=confidence, prompt=prompt_text, prompt_id=prompt_id, data_safe=top_op.is_data_safe(), required_user_input=required_user_input, operation_key=prompt_key2, ) proposed_steps.append(step) proposed_desc = proposed_steps[0].to_json() else: proposed_desc = None mstate = mstate._replace( last_proposed=tuple(proposed_steps), ) current_tx.update_migration_state(mstate) else: if mstate.last_proposed: proposed_desc = mstate.last_proposed[0].to_json() else: proposed_desc = None extra = {} complete = False if proposed_desc is None: diff = s_ddl.delta_schemas(schema, mstate.target_schema) complete = not bool(diff.get_subcommands()) if debug.flags.delta_plan and not complete: debug.header('DESCRIBE CURRENT MIGRATION AS JSON mismatch') debug.dump(diff) if not complete: extra['debug_diff'] = debug.dumps(diff) desc = ( json.dumps( { 'parent': ( str(mstate.parent_migration.get_name(schema)) if mstate.parent_migration is not None else 'initial' ), 'complete': complete, 'confirmed': confirmed, 'proposed': proposed_desc, **extra, } ) ) desc_ql = edgeql.parse_query( f'SELECT to_json({qlquote.quote_literal(desc)})' ) return compiler._compile_ql_query( ctx, desc_ql, cacheable=False, migration_block_query=True, ) raise AssertionError( f'DESCRIBE CURRENT MIGRATION AS {ql.language}' f' is not implemented' ) def _alter_current_migration_reject_proposed( ctx: compiler.CompileContext, ql: qlast.AlterCurrentMigrationRejectProposed, ) -> dbstate.BaseQuery: mstate = ctx._assert_in_migration_block(ql) current_tx = ctx.state.current_tx() if not mstate.last_proposed: # XXX: Or should we compute what the proposal would be? 
new_guidance = mstate.guidance else: last = mstate.last_proposed[0] cmdclass_name, mcls, classname, new_name = last.operation_key if new_name is None: new_name = classname if cmdclass_name.startswith('Create'): new_guidance = mstate.guidance._replace( banned_creations=mstate.guidance.banned_creations | { (mcls, classname), } ) elif cmdclass_name.startswith('Delete'): new_guidance = mstate.guidance._replace( banned_deletions=mstate.guidance.banned_deletions | { (mcls, classname), } ) else: new_guidance = mstate.guidance._replace( banned_alters=mstate.guidance.banned_alters | { (mcls, (classname, new_name)), } ) mstate = mstate._replace( guidance=new_guidance, last_proposed=None, ) current_tx.update_migration_state(mstate) return dbstate.MigrationControlQuery( sql=NIL_QUERY, tx_action=None, action=dbstate.MigrationAction.REJECT_PROPOSED, cacheable=False, modaliases=None, ) def _commit_migration( ctx: compiler.CompileContext, ql: qlast.CommitMigration, ) -> dbstate.BaseQuery: mstate = ctx._assert_in_migration_block(ql) current_tx = ctx.state.current_tx() schema = current_tx.get_schema(ctx.compiler_state.std_schema) diff = s_ddl.delta_schemas(schema, mstate.target_schema) if list(diff.get_subcommands()): raise errors.QueryError( 'cannot commit incomplete migration', hint=( 'Please finish the migration by specifying the' ' remaining DDL operations or run POPULATE MIGRATION' ' to let the system populate the outstanding DDL' ' automatically.' 
), span=ql.span, ) if debug.flags.delta_plan: debug.header('Commit Migration DDL AST') text = [] for cmd in mstate.accepted_cmds: debug.dump(cmd) text.append(qlcodegen.generate_source(cmd, pretty=True)) debug.header('Commit Migration DDL Text') debug.dump_code(';\n'.join(text) + ';') last_migration = schema.get_last_migration() if last_migration: last_migration_ref = s_utils.name_to_ast_ref( last_migration.get_name(schema), ) else: last_migration_ref = None target_sdl: Optional[str] = None store_migration_sdl = compiler._get_config_val(ctx, 'store_migration_sdl') if store_migration_sdl == 'AlwaysStore': target_sdl = s_ddl.sdl_text_from_schema(schema) create_migration = qlast.CreateMigration( # type: ignore body=qlast.NestedQLBlock( commands=mstate.accepted_cmds # type: ignore ), parent=last_migration_ref, target_sdl=target_sdl, ) current_tx.update_schema(mstate.initial_schema) current_tx.update_migration_state(None) # If we are in a migration rewrite, don't actually apply # the change, just record it. 
if mrstate := current_tx.get_migration_rewrite_state(): current_tx.update_schema(mstate.target_schema) mrstate = mrstate._replace( accepted_migrations=( mrstate.accepted_migrations + (create_migration,) ) ) current_tx.update_migration_rewrite_state(mrstate) return dbstate.MigrationControlQuery( sql=NIL_QUERY, action=dbstate.MigrationAction.COMMIT, tx_action=None, cacheable=False, modaliases=None, ) current_tx.update_schema(mstate.initial_schema) current_tx.update_migration_state(None) ddl_query = compile_and_apply_ddl_stmt( ctx, create_migration, ) if mstate.initial_savepoint: current_tx.commit_migration(mstate.initial_savepoint) tx_action = None else: tx_action = dbstate.TxAction.COMMIT return dbstate.MigrationControlQuery( sql=ddl_query.sql, ddl_stmt_id=ddl_query.ddl_stmt_id, action=dbstate.MigrationAction.COMMIT, tx_action=tx_action, cacheable=False, modaliases=None, user_schema=ctx.state.current_tx().get_user_schema(), cached_reflection=(current_tx.get_cached_reflection_if_updated()), ) def _abort_migration( ctx: compiler.CompileContext, ql: qlast.AbortMigration, ) -> dbstate.BaseQuery: mstate = ctx._assert_in_migration_block(ql) current_tx = ctx.state.current_tx() if mstate.initial_savepoint: current_tx.abort_migration(mstate.initial_savepoint) sql = NIL_QUERY tx_action = None else: tx_cmd = qlast.RollbackTransaction() tx_query = compiler._compile_ql_transaction(ctx, tx_cmd) sql = tx_query.sql tx_action = tx_query.action current_tx.update_migration_state(None) return dbstate.MigrationControlQuery( sql=sql, action=dbstate.MigrationAction.ABORT, tx_action=tx_action, cacheable=False, modaliases=None, ) def _start_migration_rewrite( ctx: compiler.CompileContext, ql: qlast.StartMigrationRewrite, in_script: bool, ) -> dbstate.BaseQuery: ctx._assert_not_in_migration_block(ql) ctx._assert_not_in_migration_rewrite_block(ql) current_tx = ctx.state.current_tx() schema = current_tx.get_schema(ctx.compiler_state.std_schema) # Start a transaction if we aren't in one already 
if current_tx.is_implicit() and not in_script: savepoint_name = None tx_cmd = qlast.StartTransaction() tx_query = compiler._compile_ql_transaction(ctx, tx_cmd) query = dbstate.MigrationControlQuery( sql=tx_query.sql, action=dbstate.MigrationAction.START, tx_action=tx_query.action, cacheable=False, modaliases=None, ) else: savepoint_name = current_tx.start_migration() query = dbstate.MigrationControlQuery( sql=NIL_QUERY, action=dbstate.MigrationAction.START, tx_action=None, cacheable=False, modaliases=None, ) # Start from an empty schema except for `module default` base_schema = s_schema.ChainedSchema( ctx.compiler_state.std_schema, s_schema.EMPTY_SCHEMA, current_tx.get_global_schema(), ) new_base_schema, _ = s_ddl.apply_sdl( qlast.Schema( declarations=[ qlast.ModuleDeclaration( name=qlast.ObjectRef(name='default'), declarations=[], ) ] ), base_schema=base_schema, ) # Set our current schema to be the empty one current_tx.update_schema(new_base_schema) current_tx.update_migration_rewrite_state( dbstate.MigrationRewriteState( target_schema=schema, initial_savepoint=savepoint_name, accepted_migrations=tuple(), ), ) return query def _commit_migration_rewrite( ctx: compiler.CompileContext, ql: qlast.CommitMigrationRewrite, ) -> dbstate.BaseQuery: ctx._assert_not_in_migration_block(ql) mrstate = ctx._assert_in_migration_rewrite_block(ql) current_tx = ctx.state.current_tx() schema = current_tx.get_schema(ctx.compiler_state.std_schema) diff = s_ddl.delta_schemas(schema, mrstate.target_schema) if list(diff.get_subcommands()): if debug.flags.delta_plan: debug.header("COMMIT MIGRATION REWRITE mismatch") diff.dump() raise errors.QueryError( 'cannot commit migration rewrite: schema resulting ' 'from rewrite does not match committed schema', span=ql.span, ) schema = mrstate.target_schema current_tx.update_schema(schema) current_tx.update_migration_rewrite_state(None) cmds: list[qlast.DDLCommand] = [] # Now we find all the migrations... 
migrations = s_migrations.get_ordered_migrations(schema) for mig in reversed(migrations): cmds.append( qlast.DropMigration( name=qlast.ObjectRef(name=mig.get_name(schema).name) ) ) for acc_cmd in mrstate.accepted_migrations: acc_cmd.metadata_only = True cmds.append(acc_cmd) if debug.flags.delta_plan: debug.header('COMMIT MIGRATION REWRITE DDL text') for cm in cmds: cm.dump_edgeql() block = pg_dbops.PLTopBlock() for cmd in cmds: _, ddl_block = _compile_and_apply_ddl_stmt(ctx, cmd) assert isinstance(ddl_block, pg_dbops.PLBlock) # We know nothing serious can be in that query # except for the SQL, so it's fine to just discard # it all. for stmt in ddl_block.get_statements(): block.add_command(stmt) if mrstate.initial_savepoint: current_tx.commit_migration(mrstate.initial_savepoint) tx_action = None else: tx_action = dbstate.TxAction.COMMIT return dbstate.MigrationControlQuery( sql=block.to_string().encode("utf-8"), action=dbstate.MigrationAction.COMMIT, tx_action=tx_action, cacheable=False, modaliases=None, user_schema=ctx.state.current_tx().get_user_schema(), cached_reflection=(current_tx.get_cached_reflection_if_updated()), ) def _abort_migration_rewrite( ctx: compiler.CompileContext, ql: qlast.AbortMigrationRewrite, ) -> dbstate.BaseQuery: mrstate = ctx._assert_in_migration_rewrite_block(ql) current_tx = ctx.state.current_tx() if mrstate.initial_savepoint: current_tx.abort_migration(mrstate.initial_savepoint) sql = NIL_QUERY tx_action = None else: tx_cmd = qlast.RollbackTransaction() tx_query = compiler._compile_ql_transaction(ctx, tx_cmd) sql = tx_query.sql tx_action = tx_query.action current_tx.update_migration_state(None) current_tx.update_migration_rewrite_state(None) query = dbstate.MigrationControlQuery( sql=sql, action=dbstate.MigrationAction.ABORT, tx_action=tx_action, cacheable=False, modaliases=None, ) return query def _reset_schema( ctx: compiler.CompileContext, ql: qlast.ResetSchema, ) -> dbstate.BaseQuery: ctx._assert_not_in_migration_block(ql) 
ctx._assert_not_in_migration_rewrite_block(ql) if ql.target.name != 'initial': raise errors.QueryError( f'Unknown schema version "{ql.target.name}". ' 'Currently, only revision supported is "initial"', span=ql.target.span, ) current_tx = ctx.state.current_tx() schema = current_tx.get_schema(ctx.compiler_state.std_schema) empty_schema = s_schema.ChainedSchema( ctx.compiler_state.std_schema, s_schema.EMPTY_SCHEMA, current_tx.get_global_schema(), ) empty_schema, _ = s_ddl.apply_sdl( # type: ignore qlast.Schema( declarations=[ qlast.ModuleDeclaration( name=qlast.ObjectRef(name='default'), declarations=[], ) ] ), base_schema=empty_schema, ) # diff and create migration that drops all objects diff = s_ddl.delta_schemas(schema, empty_schema) new_ddl: tuple[qlast.DDLCommand, ...] = tuple( s_ddl.ddlast_from_delta(schema, empty_schema, diff), # type: ignore ) create_mig = qlast.CreateMigration( # type: ignore body=qlast.NestedQLBlock(commands=tuple(new_ddl)), # type: ignore ) ddl_query, ddl_block = _compile_and_apply_ddl_stmt(ctx, create_mig) assert ddl_block is not None # delete all migrations schema = current_tx.get_schema(ctx.compiler_state.std_schema) migrations = s_delta.sort_by_cross_refs( schema, schema.get_objects(type=s_migrations.Migration), ) for mig in migrations: drop_mig = qlast.DropMigration( name=qlast.ObjectRef(name=mig.get_name(schema).name), ) _, mig_block = _compile_and_apply_ddl_stmt(ctx, drop_mig) assert isinstance(mig_block, pg_dbops.PLBlock) for stmt in mig_block.get_statements(): ddl_block.add_command(stmt) return dbstate.MigrationControlQuery( sql=ddl_block.to_string().encode("utf-8"), ddl_stmt_id=ddl_query.ddl_stmt_id, action=dbstate.MigrationAction.COMMIT, tx_action=None, cacheable=False, modaliases=None, user_schema=current_tx.get_user_schema(), cached_reflection=(current_tx.get_cached_reflection_if_updated()), ) _FEATURE_NAMES: dict[type[s_obj.Object], str] = { s_annos.AnnotationValue: 'annotation', s_policies.AccessPolicy: 'policy', 
s_triggers.Trigger: 'trigger', s_rewrites.Rewrite: 'rewrite', s_globals.Global: 'global', s_expraliases.Alias: 'alias', s_func.Function: 'function', s_indexes.Index: 'index', s_scalars.ScalarType: 'scalar', s_migrations.Migration: 'migration', } def produce_feature_used_metrics( compiler_state: compiler.CompilerState, user_schema: s_schema.Schema, ) -> dict[str, float]: schema = s_schema.ChainedSchema( compiler_state.std_schema, user_schema, # Skipping global schema is a little dodgy but not that bad s_schema.EMPTY_SCHEMA, ) features: dict[str, float] = {} def _track(key: str) -> None: features[key] = features.get(key, 0) + 1 # TODO(perf): Should we optimize peeking into the innards directly # so we can skip creating the proxies? for obj in user_schema.get_objects( type=s_obj.Object, exclude_extensions=True, ): typ = type(obj) if (key := _FEATURE_NAMES.get(typ)): _track(key) if isinstance(obj, s_globals.Global) and obj.get_expr(user_schema): _track('computed_global') elif ( isinstance(obj, s_properties.Property) ): if obj.get_expr(user_schema): _track('computed_property') elif obj.get_cardinality(schema).is_multi(): _track('multi_property') if ( obj.is_link_property(schema) and not obj.is_special_pointer(schema) ): _track('link_property') elif ( isinstance(obj, s_links.Link) and obj.get_expr(user_schema) ): _track('computed_link') elif ( isinstance(obj, s_indexes.Index) and s_indexes.is_fts_index(schema, obj) ): _track('fts') elif ( isinstance(obj, s_constraints.Constraint) and not ( (subject := obj.get_subject(schema)) and isinstance(subject, s_properties.Property) and subject.is_special_pointer(schema) ) ): _track('constraint') exclusive_constr = schema.get( 'std::exclusive', type=s_constraints.Constraint ) if not obj.issubclass(schema, exclusive_constr): _track('constraint_expr') elif ( isinstance(obj, s_objtypes.ObjectType) and len(obj.get_bases(schema).objects(schema)) > 1 ): _track('multiple_inheritance') elif ( isinstance(obj, s_objtypes.ObjectType) and 
obj.is_material_object_type(schema) ): _track('object_type') elif ( isinstance(obj, s_scalars.ScalarType) and obj.is_enum(schema) ): _track('enum') return features def repair_schema( ctx: compiler.CompileContext, ) -> Optional[tuple[bytes, Any]]: """Repair inconsistencies in the schema caused by bug fixes Works by comparing the actual current schema to the schema we get from reloading the DDL description of the schema and then directly applying the diff. """ current_tx = ctx.state.current_tx() schema = current_tx.get_schema(ctx.compiler_state.std_schema) empty_schema = s_schema.ChainedSchema( ctx.compiler_state.std_schema, s_schema.EMPTY_SCHEMA, current_tx.get_global_schema(), ) context_args = _get_delta_context_args(ctx) context_args.update(dict( testmode=True, )) text = s_ddl.ddl_text_from_schema(schema) reloaded_schema, _ = s_ddl.apply_ddl_script_ex( text, schema=empty_schema, **context_args, ) delta = s_ddl.delta_schemas( schema, reloaded_schema, ) mismatch = bool(delta.get_subcommands()) if not mismatch: return None if debug.flags.delta_plan: debug.header('Repair Delta') debug.dump(delta) if not delta.is_data_safe(): raise AssertionError( 'Repair script for version upgrade is not data safe' ) # Update the schema version also context = _new_delta_context(ctx, context_args) ver = schema.get_global( s_ver.SchemaVersion, '__schema_version__') reloaded_schema = ver.set_field_value( reloaded_schema, 'version', ver.get_version(schema)) ver_cmd = ver.init_delta_command(schema, s_delta.AlterObject) ver_cmd.set_attribute_value('version', uuidgen.uuid1mc()) reloaded_schema = ver_cmd.apply(reloaded_schema, context) delta.add(ver_cmd) # Apply and adapt delta, build native delta plan, which # will also update the schema. 
block, new_types, config_ops = _process_delta(ctx, delta, context_args) is_transactional = block.is_transactional() assert not new_types assert is_transactional sql = block.to_string().encode('utf-8') if debug.flags.delta_execute: debug.header('Repair Delta Script') debug.dump_code(sql, lexer='sql') return sql, config_ops def administer_repair_schema( ctx: compiler.CompileContext, ql: qlast.AdministerStmt, ) -> dbstate.BaseQuery: if ql.expr.args or ql.expr.kwargs: raise errors.QueryError( 'repair_schema() does not take arguments', span=ql.expr.span, ) current_tx = ctx.state.current_tx() res = repair_schema(ctx) if not res: return dbstate.MaintenanceQuery(sql=b"") sql, config_ops = res return dbstate.DDLQuery( sql=sql, user_schema=current_tx.get_user_schema_if_updated(), global_schema=current_tx.get_global_schema_if_updated(), config_ops=config_ops, feature_used_metrics=None, ) def administer_fixup_backend_upgrade( ctx: compiler.CompileContext, ql: qlast.AdministerStmt, ) -> dbstate.BaseQuery: if ql.expr.args or ql.expr.kwargs: raise errors.QueryError( 'fixup_backend_upgrade() does not take arguments', span=ql.expr.span, ) from edb.pgsql import metaschema block = pg_dbops.PLTopBlock() cmds = metaschema._generate_sql_information_schema( ctx.compiler_state.backend_runtime_params.instance_params.version ) metaschema.generate_drop_views(cmds, block) cmd_group = pg_dbops.CommandGroup() cmd_group.add_commands(cmds) cmd_group.generate(block) assert block.is_transactional() return dbstate.DDLQuery( sql=block.to_string().encode('utf-8'), user_schema=None, feature_used_metrics=None, ) def remove_pointless_triggers( schema: s_schema.Schema, ) -> pg_dbops.CommandGroup: from edb.pgsql import schemamech constraints = schema.get_objects( exclude_stdlib=True, type=s_constraints.Constraint, ) cmds = pg_dbops.CommandGroup() for constraint in constraints: if not pg_delta.ConstraintCommand.constraint_is_effective( schema, constraint ): continue subject = constraint.get_subject(schema) 
bconstr = schemamech.compile_constraint( subject, constraint, schema, None ) # Q: we could also use update_trigger_ops, which would # generate more useless code but avoid the need for an extra # code path? cmds.add_command(bconstr.fixup_trigger_ops()) return cmds def administer_remove_pointless_triggers( ctx: compiler.CompileContext, ql: qlast.AdministerStmt, ) -> dbstate.BaseQuery: if ql.expr.args or ql.expr.kwargs: raise errors.QueryError( '_remove_pointless_triggers() does not take arguments', span=ql.expr.span, ) if not ctx.is_testmode(): raise errors.QueryError( '_remove_pointless_triggers() is for testmode only', span=ql.expr.span, ) current_tx = ctx.state.current_tx() schema = current_tx.get_schema(ctx.compiler_state.std_schema) block = pg_dbops.PLTopBlock() remove_pointless_triggers(schema).generate(block) src = block.to_string() if debug.flags.delta_execute_ddl or debug.flags.delta_execute: debug.header('remove_pointless_triggers') debug.dump_code(src, lexer='sql') return dbstate.DDLQuery( sql=src.encode('utf-8'), user_schema=ctx.state.current_tx().get_user_schema(), feature_used_metrics=None, ) def administer_reindex( ctx: compiler.CompileContext, ql: qlast.AdministerStmt, ) -> dbstate.BaseQuery: from edb.ir import ast as irast from edb.ir import typeutils as irtypeutils from edb.ir import utils as irutils from edb.schema import objtypes as s_objtypes from edb.schema import constraints as s_constraints from edb.schema import indexes as s_indexes if len(ql.expr.args) != 1 or ql.expr.kwargs: raise errors.QueryError( 'reindex() takes exactly one positional argument', span=ql.expr.span, ) arg = ql.expr.args[0] match arg: case qlast.Path( steps=[qlast.ObjectRef()], partial=False, ): ptr = False case qlast.Path( steps=[qlast.ObjectRef(), qlast.Ptr()], partial=False, ): ptr = True case _: raise errors.QueryError( 'argument to reindex() must be an object type', span=arg.span, ) current_tx = ctx.state.current_tx() schema =
current_tx.get_schema(ctx.compiler_state.std_schema) modaliases = current_tx.get_modaliases() ir: irast.Statement = qlcompiler.compile_ast_to_ir( arg, schema=schema, options=qlcompiler.CompilerOptions( modaliases=modaliases ), ) expr = irutils.unwrap_set(ir.expr) if ptr: if ( not expr.expr or not isinstance(expr.expr, irast.Pointer) ): raise errors.QueryError( 'invalid pointer argument to reindex()', span=arg.span, ) rptr = expr.expr source = rptr.source else: rptr = None source = expr schema, obj = irtypeutils.ir_typeref_to_type(schema, source.typeref) if ( not isinstance(obj, s_objtypes.ObjectType) or not obj.is_material_object_type(schema) ): raise errors.QueryError( 'argument to reindex() must be a regular object type', span=arg.span, ) tables: set[s_pointers.Pointer | s_objtypes.ObjectType] = set() pindexes: set[ s_constraints.Constraint | s_indexes.Index | s_pointers.Pointer ] = set() commands = [] if not rptr: # On a type, we just reindex the type and its descendants tables.update({obj} | { desc for desc in obj.descendants(schema) if desc.is_material_object_type(schema) }) else: # On a pointer, we reindex any indexes and constraints, as well as # any link indexes (which might be table indexes on a link table) if not isinstance(rptr.ptrref, irast.PointerRef): raise errors.QueryError( 'invalid pointer argument to reindex()', span=arg.span, ) schema, ptrcls = irtypeutils.ptrcls_from_ptrref( rptr.ptrref, schema=schema) indexes = set(schema.get_referrers(ptrcls, scls_type=s_indexes.Index)) exclusive = schema.get('std::exclusive', type=s_constraints.Constraint) constrs = { c for c in schema.get_referrers(ptrcls, scls_type=s_constraints.Constraint) if c.issubclass(schema, exclusive) } pindexes.update(indexes | constrs) pindexes.update({ desc for pindex in pindexes for desc in pindex.descendants(schema) }) # For links, collect any single link indexes and any link table indexes if not ptrcls.is_property(): ptrclses = {ptrcls} | { desc for desc in 
ptrcls.descendants(schema) if isinstance( (src := desc.get_source(schema)), s_objtypes.ObjectType) and src.is_material_object_type(schema) } card = ptrcls.get_cardinality(schema) if card.is_single(): pindexes.update(ptrclses) if card.is_multi() or ptrcls.has_user_defined_properties(schema): tables.update(ptrclses) commands = [ f'REINDEX TABLE ' f'{pg_common.get_backend_name(schema, table)};' for table in tables ] + [ f'REINDEX INDEX ' f'{pg_common.get_backend_name(schema, pindex, aspect="index")};' for pindex in pindexes ] block = pg_dbops.PLTopBlock() for command in commands: block.add_command(command) return dbstate.MaintenanceQuery(sql=block.to_string().encode("utf-8")) def _identify_administer_tables_and_cols( ctx: compiler.CompileContext, call: qlast.FunctionCall, ) -> list[str]: from edb.ir import ast as irast from edb.ir import typeutils as irtypeutils from edb.schema import objtypes as s_objtypes args: list[tuple[irast.Pointer | None, s_objtypes.ObjectType]] = [] current_tx = ctx.state.current_tx() schema = current_tx.get_schema(ctx.compiler_state.std_schema) modaliases = current_tx.get_modaliases() for arg in call.args: match arg: case qlast.Path( steps=[qlast.ObjectRef()], partial=False, ): ptr = False case qlast.Path( steps=[qlast.ObjectRef(), qlast.Ptr()], partial=False, ): ptr = True case _: raise errors.QueryError( 'argument to vacuum() must be an object type ' 'or a link or property reference', span=arg.span, ) ir: irast.Statement = qlcompiler.compile_ast_to_ir( arg, schema=schema, options=qlcompiler.CompilerOptions( modaliases=modaliases ), ) expr = ir.expr if ptr: if ( not expr.expr or not isinstance(expr.expr, irast.SelectStmt) or not isinstance(expr.expr.result.expr, irast.Pointer) ): raise errors.QueryError( 'invalid pointer argument to vacuum()', span=arg.span, ) rptr = expr.expr.result.expr source = rptr.source else: rptr = None source = expr schema, obj = irtypeutils.ir_typeref_to_type(schema, source.typeref) if ( not isinstance(obj, 
s_objtypes.ObjectType) or not obj.is_material_object_type(schema) ): raise errors.QueryError( 'argument to vacuum() must be an object type ' 'or a link or property reference', span=arg.span, ) args.append((rptr, obj)) tables: set[s_pointers.Pointer | s_objtypes.ObjectType] = set() for arg, (rptr, obj) in zip(call.args, args): if not rptr: # On a type, we just vacuum the type and its descendants tables.update({obj} | { desc for desc in obj.descendants(schema) if desc.is_material_object_type(schema) }) else: # On a pointer, we must go over the pointer and its descendants # so that we may retrieve any link tables if necessary. if not isinstance(rptr.ptrref, irast.PointerRef): raise errors.QueryError( 'invalid pointer argument to vacuum()', span=arg.span, ) schema, ptrcls = irtypeutils.ptrcls_from_ptrref( rptr.ptrref, schema=schema) card = ptrcls.get_cardinality(schema) if not ( card.is_multi() or ptrcls.has_user_defined_properties(schema) ): vn = ptrcls.get_verbosename(schema, with_parent=True) if ptrcls.is_property(): raise errors.QueryError( f'{vn} is not a valid argument to vacuum() ' f'because it is not a multi property', span=arg.span, ) else: raise errors.QueryError( f'{vn} is not a valid argument to vacuum() ' f'because it is neither a multi link nor ' f'does it have link properties', span=arg.span, ) ptrclses = {ptrcls} | { desc for desc in ptrcls.descendants(schema) if isinstance( (src := desc.get_source(schema)), s_objtypes.ObjectType) and src.is_material_object_type(schema) } tables.update(ptrclses) return [ pg_common.get_backend_name(schema, table) for table in tables ] def administer_vacuum( ctx: compiler.CompileContext, ql: qlast.AdministerStmt, ) -> dbstate.BaseQuery: # check that the kwargs are valid kwargs: dict[str, str] = {} for name, val in ql.expr.kwargs.items(): if name not in ('statistics_update', 'full'): raise errors.QueryError( f'unrecognized keyword argument {name!r} for vacuum()', span=val.span, ) elif ( not isinstance(val, qlast.Constant)
or val.kind != qlast.ConstantKind.BOOLEAN ): raise errors.QueryError( f'argument {name!r} for vacuum() must be a boolean literal', span=val.span, ) kwargs[name] = val.value option_map = { "statistics_update": "ANALYZE", "full": "FULL", } command = "VACUUM" options = ",".join( f"{option_map[k.lower()]} {v.upper()}" for k, v in kwargs.items() ) if options: command += f" ({options})" command += " " + ", ".join( _identify_administer_tables_and_cols(ctx, ql.expr), ) return dbstate.MaintenanceQuery( sql=command.encode('utf-8'), is_transactional=False, ) def administer_statistics_update( ctx: compiler.CompileContext, ql: qlast.AdministerStmt, ) -> dbstate.BaseQuery: for name, val in ql.expr.kwargs.items(): raise errors.QueryError( f'unrecognized keyword argument {name!r} for statistics_update()', span=val.span, ) command = "ANALYZE " + ", ".join( _identify_administer_tables_and_cols(ctx, ql.expr), ) return dbstate.MaintenanceQuery( sql=command.encode('utf-8'), is_transactional=True, ) def administer_prepare_upgrade( ctx: compiler.CompileContext, ql: qlast.AdministerStmt, ) -> dbstate.BaseQuery: user_schema = ctx.state.current_tx().get_user_schema() global_schema = ctx.state.current_tx().get_global_schema() schema = s_schema.ChainedSchema( ctx.compiler_state.std_schema, user_schema, global_schema ) schema_ddl = s_ddl.ddl_text_from_schema( schema, include_migrations=True) ids, _ = compiler.get_obj_ids(schema, include_extras=True) json_ids = [(name, cls, str(id)) for name, cls, id in ids] obj = dict( ddl=schema_ddl, ids=json_ids ) desc_ql = edgeql.parse_query( f'SELECT to_json({qlquote.quote_literal(json.dumps(obj))})' ) return compiler._compile_ql_query( ctx, desc_ql, cacheable=False, migration_block_query=True, ) def _get_index( ctx: compiler.CompileContext, ql: qlast.AdministerStmt, ) -> tuple[s_indexes.Index, s_schema.Schema]: if len(ql.expr.args) != 1 or ql.expr.kwargs: raise errors.QueryError( f'{ql.expr.func}() takes exactly one positional argument', 
span=ql.expr.span, ) # This is janky, and we shouldn't do it. arg = ql.expr.args[0] match arg: case qlast.TypeCast( type=qlast.TypeName( maintype=qlast.ObjectRef( name='uuid', module='std' | None, ), subtypes=None, ), expr=qlast.Constant( kind=qlast.ConstantKind.STRING, value=id_string, ) ): pass case _: raise errors.QueryError( f'argument to {ql.expr.func}() must be a uuid literal', span=arg.span, ) try: id = uuid.UUID(id_string) except ValueError: raise errors.QueryError("Invalid index id") user_schema = ctx.state.current_tx().get_user_schema() global_schema = ctx.state.current_tx().get_global_schema() schema = s_schema.ChainedSchema( ctx.compiler_state.std_schema, user_schema, global_schema ) index = schema.get_by_id(id, type=s_indexes.Index) return index, schema def administer_concurrent_index_build( ctx: compiler.CompileContext, ql: qlast.AdministerStmt, ) -> dbstate.BaseQuery: index, schema = _get_index(ctx, ql) if not index.get_build_concurrently(schema): raise errors.QueryError("Index was not created concurrently") if index.get_active(schema): raise errors.QueryError("Index is already active") delta_context = _new_delta_context(ctx) create_index = pg_delta.CreateIndex.create_index( index, schema, delta_context ) block = pg_dbops.SQLBlock() block.set_non_transactional() create_index.generate(block) # HACK: Separate out the real index command and the comments # (where do the comments even get done??) 
assert isinstance(block.commands[0], pg_dbops.PLBlock) statements = block.commands[0].get_statements() index_command, comments = statements # Update the schema::Index to set `active = true` block = pg_dbops.PLTopBlock() context = s_delta.CommandContext() delta_root = s_delta.DeltaRoot() root, alter_index, _ = index.init_delta_branch( schema, context, s_delta.AlterObject ) alter_index.set_attribute_value('active', True) delta_root.add(root) nschema = delta_root.apply(schema, context) # Construct the command compiler.compile_schema_storage_in_delta(ctx, delta_root, block, context) block.add_command(comments) sql = block.to_string().encode('utf-8') if debug.flags.delta_execute_ddl: debug.header('ADMINISTER concurrent_index_build(...)') debug.dump_code(index_command, lexer='sql') debug.dump_code(sql, lexer='sql') assert isinstance(nschema, s_schema.ChainedSchema) return dbstate.DDLQuery( early_non_tx_sql=(index_command.encode('utf-8'),), sql=sql, is_transactional=False, feature_used_metrics=None, user_schema=nschema.get_top_schema(), ) def validate_schema_equivalence( state: compiler.CompilerState, schema_a: s_schema.Schema, schema_b: s_schema.Schema, global_schema: s_schema.Schema, ) -> None: schema_a_full = s_schema.ChainedSchema( state.std_schema, schema_a, global_schema, ) schema_b_full = s_schema.ChainedSchema( state.std_schema, schema_b, global_schema, ) diff = s_ddl.delta_schemas(schema_a_full, schema_b_full) complete = not bool(diff.get_subcommands()) if not complete: if debug.flags.delta_plan: debug.header('COMPARE SCHEMAS MISMATCH') debug.dump(diff) raise AssertionError( f'schemas did not match after introspection:\n{debug.dumps(diff)}' ) ================================================ FILE: edb/server/compiler/enums.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Callable, TypeVar, TYPE_CHECKING import enum from edb.common import enum as strenum from edb.edgeql import qltypes as ir from edb.protocol.enums import Cardinality as Cardinality if TYPE_CHECKING: Error_T = TypeVar('Error_T') TypeTag = ir.TypeTag class Capability(enum.IntFlag): # Capability flags that are part of the protocol. # Can be picked up with the PROTO_CAPS mask. MODIFICATIONS = 1 << 0 # noqa SESSION_CONFIG = 1 << 1 # noqa TRANSACTION = 1 << 2 # noqa DDL = 1 << 3 # noqa PERSISTENT_CONFIG = 1 << 4 # noqa # Internal only capability flags. 
GLOBAL_DDL = 1 << 57 # noqa SQL_SESSION_CONFIG= 1 << 58 # noqa BRANCH_CONFIG = 1 << 59 # noqa INSTANCE_CONFIG = 1 << 60 # noqa DESCRIBE = 1 << 61 # noqa ANALYZE = 1 << 62 # noqa ADMINISTER = 1 << 63 # noqa PROTO_CAPS = (1 << 32) - 1 # noqa ALL = (1 << 64) - 1 # noqa WRITE = (MODIFICATIONS | DDL | PERSISTENT_CONFIG) # noqa NONE = 0 # noqa def make_error( self, allowed: Capability, error_constructor: Callable[[str], Error_T], reason: str, ) -> Error_T: for item in Capability: if item & allowed: continue if self & item: return error_constructor( f"cannot execute {CAPABILITY_TITLES[item]}: {reason}") raise AssertionError( f"extra capability not found in" f" {self} allowed {allowed}" ) CAPABILITY_TITLES = { Capability.MODIFICATIONS: 'data modification queries', Capability.SESSION_CONFIG: 'session configuration queries', Capability.TRANSACTION: 'transaction control commands', Capability.DDL: 'DDL commands', Capability.PERSISTENT_CONFIG: 'configuration commands', Capability.ADMINISTER: 'ADMINISTER commands', Capability.DESCRIBE: 'DESCRIBE commands', Capability.ANALYZE: 'ANALYZE commands', Capability.INSTANCE_CONFIG: 'instance configuration commands', Capability.BRANCH_CONFIG: 'database branch configuration commands', Capability.SQL_SESSION_CONFIG: 'sql session configuration commands', Capability.GLOBAL_DDL: 'instance-wide DDL commands', } class OutputFormat(strenum.StrEnum): BINARY = 'BINARY' JSON = 'JSON' JSON_ELEMENTS = 'JSON_ELEMENTS' NONE = 'NONE' class InputFormat(strenum.StrEnum): BINARY = 'BINARY' JSON = 'JSON' class InputLanguage(strenum.StrEnum): EDGEQL = 'EDGEQL' SQL = 'SQL' SQL_PARAMS = 'SQL_PARAMS' GRAPHQL = 'GRAPHQL' def cardinality_from_ir_value(card: ir.Cardinality) -> Cardinality: if card is ir.Cardinality.AT_MOST_ONE: return Cardinality.AT_MOST_ONE elif card is ir.Cardinality.ONE: return Cardinality.ONE elif card is ir.Cardinality.MANY: return Cardinality.MANY elif card is ir.Cardinality.AT_LEAST_ONE: return Cardinality.AT_LEAST_ONE else: raise 
ValueError( f"Cardinality.from_ir_value() got an invalid input: {card}" ) ================================================ FILE: edb/server/compiler/errormech.py ================================================ # mypy: allow-untyped-defs, allow-incomplete-defs # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any, Optional, NamedTuple import json import re from edb import errors from edb.common import value_dispatch from edb.common import uuidgen from edb.graphql import types as gql_types from edb.pgsql.parser import exceptions as parser_errors from edb.schema import name as sn from edb.schema import objtypes as s_objtypes from edb.schema import pointers as s_pointers from edb.schema import schema as s_schema from edb.schema import constraints as s_constraints from edb.pgsql import common from edb.pgsql import types from edb.server.pgcon import errors as pgerrors class SchemaRequired: '''A sentinel used to signal that a particular error requires a schema.''' # Error codes that always require the schema to be resolved. There are # other error codes that only require the schema under certain # circumstances. 
SCHEMA_CODES = frozenset({ pgerrors.ERROR_INVALID_TEXT_REPRESENTATION, pgerrors.ERROR_NUMERIC_VALUE_OUT_OF_RANGE, pgerrors.ERROR_INVALID_DATETIME_FORMAT, pgerrors.ERROR_DATETIME_FIELD_OVERFLOW, }) class ErrorDetails(NamedTuple): message: str detail: Optional[str] = None detail_json: Optional[dict[str, Any]] = None code: Optional[str] = None schema_name: Optional[str] = None table_name: Optional[str] = None column_name: Optional[str] = None constraint_name: Optional[str] = None errcls: Optional[type[errors.EdgeDBError]] = None constraint_errors = frozenset({ pgerrors.ERROR_INTEGRITY_CONSTRAINT_VIOLATION, pgerrors.ERROR_RESTRICT_VIOLATION, pgerrors.ERROR_NOT_NULL_VIOLATION, pgerrors.ERROR_FOREIGN_KEY_VIOLATION, pgerrors.ERROR_UNIQUE_VIOLATION, pgerrors.ERROR_CHECK_VIOLATION, pgerrors.ERROR_EXCLUSION_VIOLATION, }) branch_errors = { pgerrors.ERROR_INVALID_CATALOG_NAME: errors.UnknownDatabaseError, pgerrors.ERROR_DUPLICATE_DATABASE: errors.DuplicateDatabaseDefinitionError, } directly_mappable: dict[str, type | tuple[type, str]] = { pgerrors.ERROR_DIVISION_BY_ZERO: errors.DivisionByZeroError, pgerrors.ERROR_INTERVAL_FIELD_OVERFLOW: errors.NumericOutOfRangeError, pgerrors.ERROR_READ_ONLY_SQL_TRANSACTION: ( errors.TransactionError, "Modifications not allowed in a read-only transaction" ), pgerrors.ERROR_SERIALIZATION_FAILURE: errors.TransactionSerializationError, pgerrors.ERROR_DEADLOCK_DETECTED: errors.TransactionDeadlockError, pgerrors.ERROR_OBJECT_IN_USE: errors.ExecutionError, pgerrors.ERROR_IDLE_IN_TRANSACTION_TIMEOUT: errors.IdleTransactionTimeoutError, pgerrors.ERROR_QUERY_CANCELLED: errors.QueryTimeoutError, pgerrors.ERROR_INVALID_ROW_COUNT_IN_LIMIT_CLAUSE: errors.InvalidValueError, pgerrors.ERROR_INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE: ( errors.InvalidValueError), pgerrors.ERROR_INVALID_REGULAR_EXPRESSION: errors.InvalidValueError, pgerrors.ERROR_INVALID_LOGARITHM_ARGUMENT: errors.InvalidValueError, pgerrors.ERROR_INVALID_POWER_ARGUMENT: 
errors.InvalidValueError, pgerrors.ERROR_INSUFFICIENT_PRIVILEGE: errors.AccessPolicyError, pgerrors.ERROR_PROGRAM_LIMIT_EXCEEDED: errors.InvalidValueError, pgerrors.ERROR_DATA_EXCEPTION: errors.InvalidValueError, pgerrors.ERROR_CHARACTER_NOT_IN_REPERTOIRE: errors.InvalidValueError, } constraint_res = { 'cardinality': re.compile(r'^.*".*_cardinality_idx".*$'), 'link_target': re.compile(r'^.*link target constraint$'), 'constraint': re.compile(r'^.*;schemaconstr(?:#\d+)?".*$'), 'idconstraint': re.compile(r'^.*".*_pkey".*$'), 'newconstraint': re.compile(r'^.*violate the new constraint.*$'), 'id': re.compile(r'^.*"(?:\w+)_data_pkey".*$'), 'link_target_del': re.compile(r'^.*link target policy$'), 'scalar': re.compile( r'^value for domain ([\w\.]+) violates check constraint "(.+)"' ), } range_constraints = frozenset({ 'timestamptz_t_check', 'timestamp_t_check', 'date_t_check', }) pgtype_re = re.compile( '|'.join(fr'\b{key}\b' for key in types.base_type_name_map_r)) enum_re = re.compile( r'(?P<p>enum) (?P<v>edgedb([\w-]+)."(?P<id>[\w-]+)_domain")')
cache_function_re = re.compile( r'^function edgedb_.*\.__qh_.* does not exist$') type_in_access_policy_re = re.compile(r'(\w+|`.+?`)::(\w+|`.+?`)') def gql_translate_pgtype_inner(schema, msg): """Try to replace any internal pg type name with a GraphQL type name""" # Mapping base types def base_type_map(name: str) -> str: result = gql_types.EDB_TO_GQL_SCALARS_MAP.get( str(types.base_type_name_map_r.get(name)) ) if result is None: return name else: return result.name translated = pgtype_re.sub( lambda r: base_type_map(r.group(0)), msg, ) if translated != msg: return translated def replace(r): type_id = uuidgen.UUID(r.group('id')) stype = schema.get_by_id(type_id, None) if stype: gql_name = gql_types.GQLCoreSchema.get_gql_name( stype.get_name(schema)) return f'{r.group("p")} {gql_name!r}' else: return f'{r.group("p")} {r.group("v")}' translated = enum_re.sub(replace, msg) return translated def gql_replace_type_names_in_text(msg): return type_in_access_policy_re.sub( lambda m: gql_types.GQLCoreSchema.get_gql_name( sn.QualName.from_string(m.group(0))), msg, ) def eql_translate_pgtype_inner(schema, msg): """Try to replace any internal pg type name with an edgedb type name""" translated = pgtype_re.sub( lambda r: str(types.base_type_name_map_r.get(r.group(0), r.group(0))), msg, ) if translated != msg: return translated def replace(r): type_id = uuidgen.UUID(r.group('id')) stype = schema.get_by_id(type_id, None) if stype: return f'{r.group("p")} {stype.get_displayname(schema)!r}' else: return f'{r.group("p")} {r.group("v")}' translated = enum_re.sub(replace, msg) return translated def translate_pgtype(schema, msg, from_graphql=False): """Try to translate a message that might refer to internal pg types. We *want* to replace internal pg type names with edgedb names, but only when they actually refer to types.
The messages aren't really structured well enough to support this properly, so we approximate it by only doing the replacement *before* the first colon in the message, so if a user does `"bigint"`, and we get the message 'invalid input syntax for type bigint: "bigint"', we do the right thing. """ leading, *rest = msg.split(':') if from_graphql: leading_translated = gql_translate_pgtype_inner(schema, leading) else: leading_translated = eql_translate_pgtype_inner(schema, leading) return ':'.join([leading_translated, *rest]) def get_error_details(fields): # See https://www.postgresql.org/docs/current/protocol-error-fields.html # for the full list of PostgreSQL error message fields. message = fields.get('M') detail = fields.get('D') detail_json = None if detail and detail.startswith('{'): detail_json = json.loads(detail) detail = None if detail_json: errcode = detail_json.get('code') if errcode: try: errcls = type(errors.EdgeDBError).get_error_class_from_code( errcode) except LookupError: pass else: return ErrorDetails( errcls=errcls, message=message, detail_json=detail_json) code = fields['C'] schema_name = fields.get('s') table_name = fields.get('t') column_name = fields.get('c') constraint_name = fields.get('n') return ErrorDetails( message=message, detail=detail, detail_json=detail_json, code=code, schema_name=schema_name, table_name=table_name, column_name=column_name, constraint_name=constraint_name ) def get_generic_exception_from_err_details(err_details): err = None if err_details.errcls is not None: err = err_details.errcls(err_details.message) if err_details.errcls is not errors.InternalServerError: err.set_linecol( err_details.detail_json.get('line', -1), err_details.detail_json.get('column', -1)) return err ######################################################################### # Static errors interpretation ######################################################################### def static_interpret_backend_error(fields, from_graphql=False): err_details 
= get_error_details(fields) # handle some generic errors if possible err = get_generic_exception_from_err_details(err_details) if err is not None: return err return static_interpret_by_code( err_details.code, err_details, from_graphql=from_graphql) @value_dispatch.value_dispatch def static_interpret_by_code( _code: str, err_details: ErrorDetails, from_graphql: bool = False, ): return errors.InternalServerError(err_details.message) @static_interpret_by_code.register_for_all(branch_errors.keys()) def _static_interpret_branch_errors( code: str, err_details: ErrorDetails, from_graphql: bool = False, ): errcls = branch_errors[code] msg = err_details.message.replace('database', 'branch', 1) return errcls(msg) @static_interpret_by_code.register_for_all(directly_mappable.keys()) def _static_interpret_directly_mappable( code: str, err_details: ErrorDetails, from_graphql: bool = False, ): mapped = directly_mappable[code] if isinstance(mapped, type): errcls = mapped err_message = err_details.message else: errcls, err_message = mapped if from_graphql: msg = gql_replace_type_names_in_text(err_message) else: msg = err_message return errcls(msg) @static_interpret_by_code.register_for_all(constraint_errors) def _static_interpret_constraint_errors( code: str, err_details: ErrorDetails, from_graphql: bool = False, ): if code == pgerrors.ERROR_NOT_NULL_VIOLATION: if err_details.table_name or err_details.column_name: return SchemaRequired else: return errors.InternalServerError(err_details.message) for errtype, ere in constraint_res.items(): m = ere.match(err_details.message) if m: error_type = errtype break else: return errors.InternalServerError(err_details.message) if error_type == 'cardinality': return errors.CardinalityViolationError('cardinality violation') elif error_type == 'link_target': if err_details.detail_json: srcname = err_details.detail_json.get('source') ptrname = err_details.detail_json.get('pointer') target = err_details.detail_json.get('target') expected = 
err_details.detail_json.get('expected') if srcname and ptrname: srcname = sn.QualName.from_string(srcname) ptrname = sn.QualName.from_string(ptrname) lname = '{}.{}'.format(srcname, ptrname.name) else: lname = '' msg = ( f'invalid target for link {lname!r}: {target!r} ' f'(expecting {expected!r})' ) else: msg = 'invalid target for link' return errors.UnknownLinkError(msg) elif error_type == 'link_target_del': if from_graphql: msg = gql_replace_type_names_in_text(err_details.message) else: msg = err_details.message return errors.ConstraintViolationError( msg, details=err_details.detail) elif error_type == 'constraint': if err_details.constraint_name is None: return errors.InternalServerError(err_details.message) constraint_id, _, _ = err_details.constraint_name.rpartition(';') try: uuidgen.UUID(constraint_id) except ValueError: return errors.InternalServerError(err_details.message) return SchemaRequired elif error_type == 'idconstraint': if err_details.constraint_name is None: return errors.InternalServerError(err_details.message) constraint_id, _, _ = err_details.constraint_name.rpartition('_') try: uuidgen.UUID(constraint_id) except ValueError: return errors.InternalServerError(err_details.message) return SchemaRequired elif error_type == 'newconstraint': # We can reconstruct what went wrong from the schema_name, # table_name, and column_name. But we don't expect # constraint_name to be present (because the constraint is # not yet present in the schema?). 
if (err_details.schema_name and err_details.table_name and err_details.column_name): return SchemaRequired else: return errors.InternalServerError(err_details.message) elif error_type == 'scalar': return SchemaRequired elif error_type == 'id': return errors.ConstraintViolationError( 'unique link constraint violation') @static_interpret_by_code.register_for_all(SCHEMA_CODES) def _static_interpret_schema_errors( code: str, err_details: ErrorDetails, from_graphql: bool = False, ): if code == pgerrors.ERROR_INVALID_DATETIME_FORMAT: hint = None if err_details.detail_json: hint = err_details.detail_json.get('hint') if err_details.message.startswith('missing required time zone'): return errors.InvalidValueError(err_details.message, hint=hint) elif err_details.message.startswith('unexpected time zone'): return errors.InvalidValueError(err_details.message, hint=hint) return SchemaRequired @static_interpret_by_code.register(pgerrors.ERROR_UNDEFINED_FUNCTION) def _static_interpret_undefined_function( _code: str, err_details: ErrorDetails, from_graphql: bool = False, ): if cache_function_re.match(err_details.message): return errors.QueryCacheInvalidationError( 'Cache invalidation caused query failure; retry the query' ) return errors.InternalServerError(err_details.message) @static_interpret_by_code.register(pgerrors.ERROR_INVALID_PARAMETER_VALUE) def _static_interpret_invalid_param_value( _code: str, err_details: ErrorDetails, from_graphql: bool = False, ): error_message_context = '' if err_details.detail_json: error_message_context = ( err_details.detail_json.get('error_message_context', '') ) return errors.InvalidValueError( error_message_context + err_details.message, details=err_details.detail if err_details.detail else None, ) @static_interpret_by_code.register(pgerrors.ERROR_WRONG_OBJECT_TYPE) def _static_interpret_wrong_object_type( _code: str, err_details: ErrorDetails, from_graphql: bool = False, ): if err_details.column_name: return SchemaRequired hint = None 
error_message_context = '' if err_details.detail_json: hint = err_details.detail_json.get('hint') error_message_context = ( err_details.detail_json.get('error_message_context', '') ) return errors.InvalidValueError( error_message_context + err_details.message, details=err_details.detail if err_details.detail else None, hint=hint, ) @static_interpret_by_code.register(pgerrors.ERROR_CARDINALITY_VIOLATION) def _static_interpret_cardinality_violation( _code: str, err_details: ErrorDetails, from_graphql: bool = False, ): if (err_details.constraint_name == 'std::assert_single' or err_details.constraint_name == 'std::assert_exists'): return errors.CardinalityViolationError(err_details.message) elif err_details.constraint_name == 'std::assert_distinct': return errors.ConstraintViolationError(err_details.message) elif err_details.constraint_name == 'std::assert': return errors.QueryAssertionError(err_details.message) elif err_details.constraint_name == 'set abstract': return errors.ConstraintViolationError(err_details.message) return errors.InternalServerError(err_details.message) @static_interpret_by_code.register(pgerrors.ERROR_FEATURE_NOT_SUPPORTED) def _static_interpret_feature_not_supported( _code: str, err_details: ErrorDetails, from_graphql: bool = False, ): return errors.UnsupportedBackendFeatureError(err_details.message) ######################################################################### # Errors interpretation that requires a schema ######################################################################### def interpret_backend_error(schema, fields, from_graphql=False): # all generic errors are static and have been handled by this point err_details = get_error_details(fields) hint = None if err_details.detail_json: hint = err_details.detail_json.get('hint') return interpret_by_code(err_details.code, schema, err_details, hint, from_graphql=from_graphql) @value_dispatch.value_dispatch def interpret_by_code(code, schema, err_details, hint, from_graphql=False): 
return errors.InternalServerError(err_details.message) @interpret_by_code.register_for_all(constraint_errors) def _interpret_constraint_errors( code: str, schema: s_schema.Schema, err_details: ErrorDetails, hint: Optional[str], from_graphql: bool = False, ): details = None if code == pgerrors.ERROR_NOT_NULL_VIOLATION: colname = err_details.column_name if colname: if colname.startswith('??'): ptr_id, *_ = colname[2:].partition('_') else: ptr_id = colname if ptr_id == 'id': assert err_details.table_name obj_type: s_objtypes.ObjectType = schema.get_by_id( uuidgen.UUID(err_details.table_name), type=s_objtypes.ObjectType, ) pointer = obj_type.getptr(schema, sn.UnqualName('id')) else: pointer = common.get_object_from_backend_name( schema, s_pointers.Pointer, ptr_id ) pname = pointer.get_verbosename(schema, with_parent=True) else: pname = None if pname is not None: if err_details.detail_json: object_id = err_details.detail_json.get('object_id') if object_id is not None: details = f'Failing object id is {str(object_id)!r}.' 
if from_graphql: pname = gql_replace_type_names_in_text(pname) return errors.MissingRequiredError( f'missing value for required {pname}', details=details, hint=hint, ) else: return errors.InternalServerError(err_details.message) error_type = None match = None for errtype, ere in constraint_res.items(): m = ere.match(err_details.message) if m: error_type = errtype match = m break # no need for an else clause since that case would have been handled by # the static version if error_type == 'constraint' or error_type == 'idconstraint': assert err_details.constraint_name # similarly, if we're here it's because we have a constraint_id if error_type == 'constraint': constraint_id_s, _, _ = err_details.constraint_name.rpartition(';') assert err_details.constraint_name constraint_id = uuidgen.UUID(constraint_id_s) constraint = schema.get_by_id( constraint_id, type=s_constraints.Constraint ) else: # Primary key violations give us the table name, so # look through that for the constraint obj_id, _, _ = err_details.constraint_name.rpartition('_') obj_type = schema.get_by_id( uuidgen.UUID(obj_id), type=s_objtypes.ObjectType ) obj_ptr = obj_type.getptr(schema, sn.UnqualName('id')) constraint = obj_ptr.get_exclusive_constraints(schema)[0] # msg is for the "end user" and should not mention pointers or object # types; it is also affected by setting `errmessage` in the user schema. msg = constraint.format_error_message(schema) # details is for the "developer" and must explain what's going on # under the hood. It contains verbose descriptions of the objects involved.
subject = constraint.get_subject(schema) subject_description = subject.get_verbosename(schema, with_parent=True) constraint_description = constraint.get_verbosename(schema) details = f'violated {constraint_description} on {subject_description}' if from_graphql: msg = gql_replace_type_names_in_text(msg) details = gql_replace_type_names_in_text(details) return errors.ConstraintViolationError(msg, details=details) elif error_type == 'newconstraint': # If we're here, it means that we already validated that # schema_name, table_name and column_name all exist. # # NOTE: this should never occur in GraphQL mode. tabname = (err_details.schema_name, err_details.table_name) source = common.get_object_from_backend_name( schema, s_objtypes.ObjectType, tabname) source_name = source.get_displayname(schema) pointer = common.get_object_from_backend_name( schema, s_pointers.Pointer, err_details.column_name) pointer_name = pointer.get_shortname(schema).name return errors.ConstraintViolationError( f'Existing {source_name}.{pointer_name} ' f'values violate the new constraint') elif error_type == 'scalar': assert match domain_name = match.group(1) # NOTE: We don't attempt to change the name of the scalar type even if # the error comes from GraphQL because we map custom scalars onto # their underlying base types. 
stype_name = types.base_type_name_map_r.get(domain_name) if stype_name: if match.group(2) in range_constraints: msg = f'{str(stype_name)!r} value out of range' else: msg = f'invalid value for scalar type {str(stype_name)!r}' else: msg = translate_pgtype(schema, err_details.message) return errors.InvalidValueError(msg) return errors.InternalServerError(err_details.message) @interpret_by_code.register(pgerrors.ERROR_INVALID_TEXT_REPRESENTATION) def _interpret_invalid_text_repr( code: str, schema: s_schema.Schema, err_details: ErrorDetails, hint: Optional[str], from_graphql: bool = False, ): return errors.InvalidValueError( translate_pgtype(schema, err_details.message, from_graphql=from_graphql) ) @interpret_by_code.register(pgerrors.ERROR_NUMERIC_VALUE_OUT_OF_RANGE) def _interpret_numeric_out_of_range( code: str, schema: s_schema.Schema, err_details: ErrorDetails, hint: Optional[str], from_graphql: bool = False, ): return errors.NumericOutOfRangeError( translate_pgtype(schema, err_details.message, from_graphql=from_graphql) ) @interpret_by_code.register(pgerrors.ERROR_INVALID_DATETIME_FORMAT) @interpret_by_code.register(pgerrors.ERROR_DATETIME_FIELD_OVERFLOW) def _interpret_invalid_datetime( code: str, schema: s_schema.Schema, err_details: ErrorDetails, hint: Optional[str], from_graphql: bool = False, ): return errors.InvalidValueError( translate_pgtype(schema, err_details.message, from_graphql=from_graphql), hint=hint, ) @interpret_by_code.register(pgerrors.ERROR_WRONG_OBJECT_TYPE) def _interpret_wrong_object_type( code: str, schema: s_schema.Schema, err_details: ErrorDetails, hint: Optional[str], from_graphql: bool = False, ): # NOTE: this should never occur in GraphQL mode due to schema/query # validation. 
if ( err_details.message == 'covariance error' and err_details.column_name is not None and err_details.table_name is not None ): ptr = schema.get_by_id(uuidgen.UUID(err_details.column_name)) wrong_obj = schema.get_by_id(uuidgen.UUID(err_details.table_name)) assert isinstance(ptr, (s_pointers.Pointer, s_pointers.PseudoPointer)) target = ptr.get_target(schema) assert target is not None vn = ptr.get_verbosename(schema, with_parent=True) return errors.InvalidLinkTargetError( f"invalid target for {vn}: '{wrong_obj.get_name(schema)}'" f" (expecting '{target.get_name(schema)}')" ) return errors.InternalServerError(err_details.message) def static_interpret_psql_parse_error( exc: parser_errors.PSqlParseError ) -> errors.EdgeDBError: res: errors.EdgeDBError if isinstance(exc, parser_errors.PSqlSyntaxError): res = errors.EdgeQLSyntaxError(str(exc)) res.set_position(exc.cursor_pos - 1, None) res.compute_line_col(exc.query_source) elif isinstance(exc, parser_errors.PSqlUnsupportedError): res = errors.UnsupportedFeatureError(str(exc)) if exc.location is not None: res.set_position(exc.location, None) else: res = errors.InternalServerError(str(exc)) return res ================================================ FILE: edb/server/compiler/explain/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2023-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import Optional import dataclasses import json import logging import pickle import immutables from edb import buildmeta from edb.common import debug from edb.edgeql import ast as qlast from edb.ir import ast as irast from edb.pgsql import ast as pgast from edb.schema import schema as s_schema from . import coarse_grained from . import fine_grained from . import ir_analyze from . import pg_tree from . import to_json log = logging.getLogger(__name__) # "affects_compilation" config vals that we don't actually want to report out. # This turns out to be a majority of them OMITTED_CONFIG_VALS = { "allow_dml_in_functions", "allow_bare_ddl", "force_database_error", } @dataclasses.dataclass class Arguments(to_json.ToJson): execute: bool buffers: bool @dataclasses.dataclass(frozen=True) class AnalyzeContext: schema: s_schema.Schema modaliases: immutables.Map[Optional[str], str] reverse_mod_aliases: dict[str, Optional[str]] def analyze_explain_output( query_asts_pickled: bytes, data: list[list[bytes]], std_schema: s_schema.Schema, ) -> bytes: if debug.flags.edgeql_explain: debug.header('Explain') ql: qlast.Base ir: irast.Statement pg: pgast.Base ql, ir, pg, explain_data = pickle.loads(query_asts_pickled) config_vals, args, modaliases = explain_data args = Arguments(**args) schema = ir.schema # We omit the std schema when serializing, so put it back if isinstance(schema, s_schema.ChainedSchema): schema = s_schema.ChainedSchema( top_schema=schema._top_schema, global_schema=schema._global_schema, base_schema=std_schema ) assert len(data) == 1 and len(data[0]) == 1 plan = json.loads(data[0][0]) assert len(plan) == 1 plan = debug_tree = plan[0]['Plan'] info = None fg_tree = None cg_tree = None try: ctx = AnalyzeContext( schema=schema, modaliases=modaliases, # This has last alias wins strategy. Do we need reverse? 
reverse_mod_aliases={v: k for k, v in modaliases.items()}, ) info = ir_analyze.analyze_queries(ql, ir, pg, ctx) debug_tree = pg_tree.Plan.from_json(plan, ctx) fg_tree, index = fine_grained.build(debug_tree, info, args) if debug.flags.edgeql_explain: debug.dump(fg_tree) debug.dump(info) cg_tree = coarse_grained.build(fg_tree, info, index) except Exception as e: log.exception("Error building explain model", exc_info=e) config_vals = { k: v for k, v in config_vals.items() if k not in OMITTED_CONFIG_VALS } globals_used = sorted([ str(k) for k in ir.globals if not k.is_permission ]) if info: buffers = info.buffers elif ql.span: buffers = [ql.span.buffer] else: buffers = [] # should never happen output = { 'config_vals': config_vals, 'globals_used': globals_used, 'module_aliases': dict(modaliases), 'arguments': args, 'version': buildmeta.get_version_string(), 'buffers': buffers, 'debug_info': { 'full_plan': debug_tree, 'analysis_info': info, }, 'fine_grained': fg_tree, 'coarse_grained': cg_tree, } return json.dumps(output, default=to_json.json_hook).encode('utf-8') ================================================ FILE: edb/server/compiler/explain/casefold.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2023-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import re # This matches a space, a minus, or an empty string that comes before a capital # letter (and not at the start of the string). # It is used to replace that word boundary with an underscore. # It handles cases like these: # * `Foo Bar` -- title case -- matches space # * `FooBar` -- CamelCase -- matches empty string before `Bar` # * `Some-word` -- words with dash -- matches dash word_boundary_re = re.compile(r'(? str: # Note: this only covers the cases we have, not all possible cases of # case conversion return word_boundary_re.sub('_', name).lower() def to_camel_case(name: str) -> str: # Note: this only covers the cases we have, not all possible cases of # case conversion return word_boundary_re.sub('', name) ================================================ FILE: edb/server/compiler/explain/coarse_grained.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2023-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
# from __future__ import annotations from typing import Optional, Iterator import dataclasses import enum import uuid from edb.server.compiler.explain import to_json from edb.server.compiler.explain import ir_analyze from edb.server.compiler.explain import pg_tree from edb.server.compiler.explain import fine_grained COST_KEYS = frozenset(( 'plan_rows', 'plan_width', 'self_cost', 'total_cost', 'startup_cost', )) class _Index: by_id: dict[int, _PlanInfo] by_alias: dict[str, _PlanInfo] def __init__(self, plan: fine_grained.Plan, idx: fine_grained.Index): by_id = {} ancestors: list[fine_grained.Plan] = [] def index(node: fine_grained.Plan) -> None: pinfo = _PlanInfo( plan=node, ancestors=list(reversed(ancestors)), ) by_id[id(node)] = pinfo ancestors.append(node) try: for sub in node.subplans: index(sub) finally: ancestors.pop() index(plan) self.by_id = by_id self.by_alias = {a: by_id[id(p)] for a, p in idx.by_alias.items()} @dataclasses.dataclass class _PlanInfo: plan: fine_grained.Plan ancestors: list[fine_grained.Plan] shape_mark: Optional[str] = None @property def id(self) -> uuid.UUID: return self.plan.pipeline[-1].plan_id def self_and_ancestors(self, index: _Index) -> Iterator[_PlanInfo]: yield self for node in self.ancestors: yield index.by_id[id(node)] @dataclasses.dataclass class Node(to_json.ToJson, pg_tree.CostMixin): plan_id: uuid.UUID relations: frozenset[str] contexts: Optional[list[ir_analyze.ContextDesc]] children: list[Child] # Note: clients should consider this open-ended list class ChildKind(enum.Enum): POINTER = "pointer" # TODO(tailhook) property/link ? 
FILTER = "filter" @dataclasses.dataclass class Child(to_json.ToJson): kind: ChildKind name: Optional[str] # currently set only for POINTER node: Node def _scan_relations( path: str, plan: fine_grained.Plan, index: _Index ) -> Iterator[pg_tree.Relation]: info = index.by_id[id(plan)] if info.shape_mark == path or info.shape_mark is None: for stage in plan.pipeline: if relation := getattr(stage, 'relation_name', None): yield relation for node in plan.subplans: yield from _scan_relations(path, node, index) def _build_shape( path: str, plan: fine_grained.Plan, shape: ir_analyze.ShapeInfo, contexts: Optional[list[ir_analyze.ContextDesc]], index: _Index, ) -> Node: # The coarse-grained tree is built like this: # # 1. Scanning the IR, we find all the shapes and mark the aliases that # belong to them or their pointers (done in the ir_analyze module). # 2. For each shape and property, we try to find the node of the fine-grained # tree that represents it (by using the alias and walking up). # 3. We output a tree containing only the nodes marked in step (2). _shape_mark(path, shape, index) pointers = {} for name, pointer in shape.pointers.items(): subpath = f"{path}.{name}" if ( pointer.main_alias is not None and (c_info := index.by_alias.get(pointer.main_alias)) is not None ): info = c_info else: for alias in pointer.aliases: if c_info := index.by_alias.get(alias): info = c_info break else: continue start = info last_context = info.plan.contexts for plan_info in info.self_and_ancestors(index): mark = plan_info.shape_mark if mark is not None and mark != subpath: break start = plan_info if start.plan.contexts: last_context = start.plan.contexts pointers[name] = _build_shape( f"{path}.{name}", start.plan, pointer, last_context, index, ) relations = frozenset(_scan_relations(path, plan, index)) # sometimes the context can be in an inner node; hoist it if ( not contexts and (main_alias := shape.main_alias) and (main_info := index.by_alias.get(main_alias)) ): alias = main_alias contexts = main_info.plan.contexts
top = plan.pipeline[0] return Node( plan_id=plan.pipeline[0].plan_id, relations=relations, children=[Child(kind=ChildKind.POINTER, name=name, node=node) for name, node in pointers.items()], contexts=contexts, # cost vars startup_cost=top.startup_cost, total_cost=top.total_cost, plan_rows=top.plan_rows, plan_width=top.plan_width, actual_startup_time=top.actual_startup_time, actual_total_time=top.actual_total_time, actual_rows=top.actual_rows, actual_loops=top.actual_loops, shared_hit_blocks=top.shared_hit_blocks, shared_read_blocks=top.shared_read_blocks, shared_dirtied_blocks=top.shared_dirtied_blocks, shared_written_blocks=top.shared_written_blocks, local_hit_blocks=top.local_hit_blocks, local_read_blocks=top.local_read_blocks, local_dirtied_blocks=top.local_dirtied_blocks, local_written_blocks=top.local_written_blocks, temp_read_blocks=top.temp_read_blocks, temp_written_blocks=top.temp_written_blocks, ) def _shape_mark(path: str, shape: ir_analyze.ShapeInfo, index: _Index) -> None: path_prefix = path + "." for alias in shape.all_aliases: info = index.by_alias.get(alias) if not info: continue for plan_info in info.self_and_ancestors(index): if plan_info.shape_mark: break plan_info.shape_mark = path for name, _subshape in shape.pointers.items(): subpath = f"{path}.{name}" for alias in shape.all_aliases: info = index.by_alias.get(alias) if not info: continue for plan_info in info.self_and_ancestors(index): cur_mark = plan_info.shape_mark if cur_mark is None: plan_info.shape_mark = subpath elif cur_mark == path: break elif cur_mark == subpath: break elif cur_mark.startswith(path_prefix): # Two pointers met together, this means it's a # branching point. 
We need to clean up all pointers # from the ancestors now (just continue the loop and it'll # do the job) plan_info.shape_mark = path def build( plan: fine_grained.Plan, info: ir_analyze.AnalysisInfo, index: fine_grained.Index ) -> Node: idx = _Index(plan, index) return _build_shape('🌳', plan, info.shape_tree, plan.contexts, idx) ================================================ FILE: edb/server/compiler/explain/fine_grained.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2023-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
# from __future__ import annotations from typing import Any, Optional, Iterable import uuid import dataclasses from edb.server.compiler.explain import to_json from edb.server.compiler.explain import pg_tree from edb.server.compiler.explain import ir_analyze from edb.server.compiler import explain PropValue = str | int | float | list[str | int | float] @dataclasses.dataclass class Prop(to_json.ToJson): title: str value: PropValue type: Optional[pg_tree.PropType] important: bool @property def attribute_name(self) -> str: return self.title class Properties(to_json.ToJson): def __init__(self, props: Iterable[Prop]): self._props = {p.attribute_name: p for p in props} def to_json(self) -> Any: return list(self._props.values()) def __repr__(self) -> str: return repr({k: v.value for k, v in self._props.items()}) @dataclasses.dataclass(kw_only=True) class Stage(to_json.ToJson, pg_tree.CostMixin): plan_type: str plan_id: uuid.UUID properties: Properties def __getattr__(self, name: str) -> PropValue: try: return self.properties._props[name].value except KeyError: raise AttributeError(name) from None @dataclasses.dataclass class Plan(to_json.ToJson): contexts: Optional[list[ir_analyze.ContextDesc]] pipeline: list[Stage] subplans: list[Plan] alias: Optional[str] = None @dataclasses.dataclass class Index: by_id: dict[uuid.UUID, Plan] by_alias: dict[str, Plan] def context_diff( left: Optional[list[ir_analyze.ContextDesc]], right: Optional[list[ir_analyze.ContextDesc]], ) -> list[ir_analyze.ContextDesc]: if not left: return [] if not right: return left result = [ctx for ctx in left if ctx not in right] return result def context_intersect( left: Optional[list[ir_analyze.ContextDesc]], right: Optional[list[ir_analyze.ContextDesc]], ) -> list[ir_analyze.ContextDesc]: if not left: return [] if not right: return [] return [ctx for ctx in left if ctx in right] def context_optimize( items: Optional[list[ir_analyze.ContextDesc]], ) -> Optional[list[ir_analyze.ContextDesc]]: if not items: 
return None # We assume that contexts are ordered: # 1. Within a single location (alias): from the most specific to the broadest # 2. Locations that belong to a single buffer or alias are consecutive # # Postgres marks the most specific thing (mostly a table scan) by alias. # But since we try to hoist contexts to the nearest node having no context, # that usually matches the broadest context. This is just a heuristic, though. # # So we only keep the last context from each group (alias/buffer) by # squashing contexts that are nested inside each other result: list[ir_analyze.ContextDesc] = [] for ctx in reversed(items): for maybe_parent in result: if ctx.is_subcontext_of(maybe_parent): break else: result.append(ctx) result.reverse() return result class TreeBuilder: alias_info: dict[str, ir_analyze.AliasInfo] by_id: dict[uuid.UUID, Plan] by_alias: dict[str, Plan] def __init__(self, info: ir_analyze.AnalysisInfo): self.alias_info = info.alias_info self.by_alias = {} self.by_id = {} def build(self, plan: pg_tree.Plan, args: explain.Arguments) -> Plan: # For the fine-grained tree (this one will be displayed in \verbose mode or # whatever we name it) we do three things: # 1. Remove cheap scalar Result nodes. In my examples, they are: # a variable in a LIMIT clause, or scalar expressions, like string # concatenation. We ensure that eliminated nodes are less than 1 # percent of the parent node's cost/time. # 2. Squash nested nodes having one child into a pipeline list. This # should allow a less nested presentation of the tree. # # 3. For contexts: # a) Hoist them through the tree of one-child nodes # b) If the contexts of all children are equal, we move the context to a # higher level # c) If the contexts of children are partly equal, we move the equal # contexts to the parent, removing them from the children # d) Eliminate overlapping contexts after that # 3c works for things like x := count(.a) + count(.b).
There are two # nodes, one starting from .a and one from .b and both of them have # contexts up to the whole expression starting from x :=. pipeline = [] aliases = set() pipeline.append(self._make_stage(plan)) alias = getattr(plan, 'alias', None) if alias: aliases.add(alias) plans = _filter_plans(plan, args) while len(plans) == 1 and not alias: node = plans[0] pipeline.append(self._make_stage(node)) plans = _filter_plans(node, args) alias = getattr(node, 'alias', None) if alias: aliases.add(alias) subplans = [self.build(subplan, args) for subplan in plans] alias_info = self.alias_info.get(alias) if alias else None contexts = alias_info.contexts if alias_info else None if not contexts and subplans and (contexts := subplans[0].contexts): # hoist contexts that are common in child branches for ch_plan in subplans[1:]: if inner_contexts := ch_plan.contexts: contexts = context_intersect(contexts, inner_contexts) if contexts: # some contexts are hoisted for (sub, node) in zip(subplans, plans): sub.contexts = context_diff(sub.contexts, contexts) if ( not sub.contexts and (subalias := getattr(node, 'alias', None)) ): aliases.add(subalias) # optimize after hoisting for sub in subplans: sub.contexts = context_optimize(sub.contexts) result = Plan( contexts=contexts, pipeline=pipeline, subplans=subplans, ) for stage in pipeline: self.by_id[stage.plan_id] = result # Note: this overwrites children with this alias by this node # when contexts are hoisted, which is a good thing for alias in aliases: self.by_alias[alias] = result return result def _get_contexts( self, plan: pg_tree.Plan, ) -> Optional[list[ir_analyze.ContextDesc]]: if not (alias := getattr(plan, 'alias', None)): return None if not (ainfo := self.alias_info.get(alias)): return None return ainfo.contexts def _make_stage(self, plan: pg_tree.Plan) -> Stage: properties = [] for name, prop in plan.get_props().items(): if (value := getattr(plan, name, None)) is not None: properties.append(Prop( title=name, value=value, 
type=prop.enum_type, important=prop.important, )) return Stage( plan_type=type(plan).__name__, plan_id=plan.plan_id, properties=Properties(properties), # cost vars startup_cost=plan.startup_cost, total_cost=plan.total_cost, plan_rows=plan.plan_rows, plan_width=plan.plan_width, actual_startup_time=plan.actual_startup_time, actual_total_time=plan.actual_total_time, actual_rows=plan.actual_rows, actual_loops=plan.actual_loops, shared_hit_blocks=plan.shared_hit_blocks, shared_read_blocks=plan.shared_read_blocks, shared_dirtied_blocks=plan.shared_dirtied_blocks, shared_written_blocks=plan.shared_written_blocks, local_hit_blocks=plan.local_hit_blocks, local_read_blocks=plan.local_read_blocks, local_dirtied_blocks=plan.local_dirtied_blocks, local_written_blocks=plan.local_written_blocks, temp_read_blocks=plan.temp_read_blocks, temp_written_blocks=plan.temp_written_blocks, ) def _filter_plans( node: pg_tree.Plan, args: explain.Arguments ) -> list[pg_tree.Plan]: min_cost = node.total_cost * 0.01 # TODO(tailhook) maybe we should scan inner plans to figure out that # there are no inner contexts in the children plans = [ p for p in node.plans if not isinstance(p, pg_tree.Result) or p.total_cost > min_cost or p.plan_rows > 1 ] return plans def build( plan: pg_tree.Plan, info: ir_analyze.AnalysisInfo, args: explain.Arguments, ) -> tuple[Plan, Index]: tree = TreeBuilder(info) result = tree.build(plan, args) result.contexts = context_optimize(result.contexts) index = Index(by_id=tree.by_id, by_alias=tree.by_alias) return result, index ================================================ FILE: edb/server/compiler/explain/ir_analyze.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2023-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any, Optional, Iterator, cast import dataclasses from edb.common import ast from edb.common import debug from edb.edgeql import ast as qlast from edb.ir import ast as irast from edb.pgsql import ast as pgast from edb.pgsql.compiler import astutils from edb.server.compiler import explain from edb.server.compiler.explain import to_json @dataclasses.dataclass(eq=True, frozen=True) class ContextDesc(to_json.ToJson): start: int end: int buffer_idx: int text: str def is_subcontext_of(self, other: ContextDesc) -> bool: return ( self.buffer_idx == other.buffer_idx and self.start >= other.start and self.end <= other.end ) @dataclasses.dataclass class AliasInfo(to_json.ToJson): contexts: list[ContextDesc] @dataclasses.dataclass class ShapeInfo(to_json.ToJson): aliases: set[str] pointers: dict[str, ShapeInfo] main_alias: Optional[str] = None @property def all_aliases(self) -> Iterator[str]: if self.main_alias: yield self.main_alias yield from self.aliases @dataclasses.dataclass class AnalysisInfo(to_json.ToJson): alias_info: dict[str, AliasInfo] buffers: list[str] shape_tree: ShapeInfo class VisitShapes(ast.NodeVisitor): ir_node_to_alias: dict[irast.Set, str] = {} skip_hidden = True extra_skips = frozenset(('shape', 'source', 'target')) def __init__(self, ir_node_to_alias: dict[irast.Set, str], **kwargs: Any): self.ir_node_to_alias = ir_node_to_alias self.current_shape = ShapeInfo(aliases=set(), pointers={}) super().__init__(**kwargs) def visit_Set(self, node: irast.Set) -> Any: alias = self.ir_node_to_alias.get(node) if 
not alias: return self.generic_visit(node) if not node.shape: self.current_shape.aliases.add(alias) return self.generic_visit(node) parent_shape = self.current_shape parent_shape.main_alias = alias parent_shape.aliases.discard(alias) for (item, _oper) in node.shape: if not (rptr_name := item.path_id.rptr_name()): continue name = str(rptr_name.name) self.current_shape = self.current_shape.pointers.setdefault( name, ShapeInfo(aliases=set(), pointers={}), ) try: self.generic_visit(item) finally: self.current_shape = parent_shape # Simple scalar expressions have the same alias for some reason # so we have to discard them for sub in parent_shape.pointers.values(): sub.aliases.discard(parent_shape.main_alias) sub.aliases.difference_update(parent_shape.aliases) return self.generic_visit(node) # this skips node.shape # Do a bunch of analysis of the queries. Currently we produce more # info than we actually consume, since we are still in a somewhat # exploratory phase. def analyze_queries( ql: qlast.Base, ir: irast.Statement, pg: pgast.Base, ctx: explain.AnalyzeContext, ) -> AnalysisInfo: debug_spew = debug.flags.edgeql_explain assert ql.span contexts = {(ql.span.buffer, ql.span.filename): 0} def get_context(node: irast.Set) -> ContextDesc: assert node.span, node span = node.span key = span.buffer, span.filename if (idx := contexts.get(key)) is None: idx = len(contexts) contexts[key] = idx text = span.buffer[span.start:span.end] return ContextDesc( start=span.start, end=span.end, buffer_idx=idx, text=text, ) rvars = ast.find_children(pg, pgast.BaseRangeVar) queries = ast.find_children(pg, pgast.Query) # Map subqueries back to their rvars subq_to_rvar: dict[pgast.Query, pgast.RangeSubselect] = {} for rvar in rvars: if isinstance(rvar, pgast.RangeSubselect): assert rvar.subquery not in subq_to_rvar for subq in astutils.each_query_in_set(rvar.subquery): subq_to_rvar[subq] = rvar # Find all *references* to an rvar in path_rvar_maps # Maps rvars to the queries that join them 
reverse_path_rvar_map: dict[ pgast.BaseRangeVar, list[pgast.Query], ] = {} for qry in queries: qrvars = [] if isinstance(qry, (pgast.SelectStmt, pgast.UpdateStmt)): qrvars.extend(qry.from_clause) if isinstance(qry, pgast.DeleteStmt): qrvars.extend(qry.using_clause) for orvar in qrvars: for rvar in astutils.each_base_rvar(orvar): reverse_path_rvar_map.setdefault(rvar, []).append(qry) # Map aliases to rvars and then to path ids aliases = { rvar.alias.aliasname: rvar for rvar in rvars if rvar.alias.aliasname } alias_contexts: dict[str, list[ContextDesc]] = {} ir_node_to_alias: dict[irast.Set, str] = {} # Try to produce good contexts # KEY FACT: We often duplicate code for with bindings. This means # we want to expose that through the contexts we include. for alias, rvar in aliases.items(): # Run up the tree looking both for contexts to associate with # and the next node in the tree to go up to asets = [] while True: ns = cast(list[irast.Set], rvar.ir_origins or []) if len(ns) >= 1 and ns[0].span: if ns[0] not in asets: asets.append(ns[0]) for node in ns: ir_node_to_alias[node] = alias break # Find the enclosing sources = reverse_path_rvar_map.get(rvar, ()) if debug_spew: print(f'SOURCES for {alias} 1/{len(ns)}', sources) if sources: source = sources[0] if source not in subq_to_rvar: break else: break rvar = subq_to_rvar[source] spans = [get_context(x) for x in asets if x.span] if debug_spew: print(alias, asets) for x in asets: debug.dump(x.span) # Using the first set of contexts found alias_contexts.setdefault(alias, spans) alias_info = { alias: AliasInfo( contexts=alias_contexts.pop(alias, []), ) for alias in aliases } visitor = VisitShapes(ir_node_to_alias=ir_node_to_alias) visitor.visit(ir) shape_tree = visitor.current_shape return AnalysisInfo( alias_info=alias_info, buffers=[text for text, _filename in contexts.keys()], shape_tree=shape_tree, ) ================================================ FILE: edb/server/compiler/explain/pg_tree.py 
================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2023-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # The types here are modeled closely after postgres explain output. # See postgres/src/backend/commands/explain.c # from __future__ import annotations from typing import ( Annotated, Any, ClassVar, Optional, TypeVar, Union, Sequence, get_args, get_origin, get_type_hints, NewType, Text, ) import dataclasses import enum import re import uuid from edb.common import ast from edb.schema import constraints as s_constr from edb.schema import indexes as s_indexes from edb.schema import name as sn from edb.schema import objects as so from edb.schema import pointers as s_pointers from edb.server.compiler import explain from edb.server.compiler.explain import casefold from edb.server.compiler.explain import to_json uuid_core = '[a-f0-9]{8}-?[a-f0-9]{4}-?[a-f0-9]{4}-?[a-f0-9]{4}-?[a-f0-9]{12}' uuid_re = re.compile( rf'(\.?"?({uuid_core})"?)', re.I ) T = TypeVar('T') FromJsonT = TypeVar('FromJsonT', bound='FromJson') class FromJson(ast.AST, to_json.ToJson): @classmethod def from_json( cls: type[FromJsonT], json: dict[str, Any], ctx: explain.AnalyzeContext, ) -> FromJsonT: annotations = get_type_hints(cls) result = cls() for name, value in json.items(): name = casefold.to_snake_case(name) if not (prop := annotations.get(name)): # extra values are okay setattr(result, name, 
value) continue if get_origin(prop) is Annotated: prop = get_args(prop)[0] if get_origin(prop) is Union: # actually an option prop = get_args(prop)[0] if value is None: setattr(result, name, value) continue if prop is Index: setattr(result, name, _translate_index(value, ctx)) elif prop is Relation: setattr(result, name, _translate_relation(value, ctx)) elif get_origin(prop) is list: inner = get_args(prop)[0] if type(inner) is type and issubclass(inner, FromJson): setattr(result, name, [inner.from_json(v, ctx) for v in value]) else: setattr(result, name, value) elif type(prop) is type and issubclass(prop, FromJson): setattr(result, name, prop.from_json(value, ctx)) else: setattr(result, name, value) # lists are always there for convenience for name, prop in annotations.items(): name = casefold.to_snake_case(name) if ( get_origin(prop) is list and getattr(result, name, None) is None ): setattr(result, name, []) return result def to_json(self) -> Any: dic = super().to_json() dic['node_type'] = self.__class__.__name__ return dic def _obj_to_name( sobj: so.Object, ctx: explain.AnalyzeContext, dotted: bool=False, ) -> str: if isinstance(sobj, s_pointers.Pointer): # If a pointer is on the RHS of a dot, just use # the short name. But otherwise, grab the source # and link it up s = str(sobj.get_shortname(ctx.schema).name) if sobj.is_link_property(ctx.schema): s = f'@{s}' if not dotted and (src := sobj.get_source(ctx.schema)): src_name = _translate_name( src.get_name(ctx.schema), ctx.reverse_mod_aliases, ) s = f'{src_name}.{s}' elif isinstance(sobj, s_constr.Constraint): s = sobj.get_verbosename(ctx.schema, with_parent=True) elif isinstance(sobj, s_indexes.Index): s = sobj.get_verbosename(ctx.schema, with_parent=True) if expr := sobj.get_expr(ctx.schema): s += f' on ({expr.text})' else: s = _translate_name( sobj.get_name(ctx.schema), ctx.reverse_mod_aliases, ) if dotted: s = '.' 
+ s return s def _translate_index(name: str, ctx: explain.AnalyzeContext) -> Index: # Try to replace all ids with textual names had_index = False for (full, m) in uuid_re.findall(name): uid = uuid.UUID(m) sobj = ctx.schema.get_by_id(uid, default=None) if sobj: had_index |= isinstance(sobj, s_indexes.Index) dotted = full[0] == '.' s = _obj_to_name(sobj, ctx, dotted=dotted) name = uuid_re.sub(s, name, count=1) name = name.replace('_source_target_key', ' forward link index') name = name.replace(';schemaconstr', '') name = name.replace('_target_key', ' backward link index') # If the index name is from an actual index or constraint, # the `_index` part of the name is just total noise, but if it # is from a link, it might be slightly informative if had_index: name = name.replace('_index', '') else: name = name.replace('_index', ' index') return Index(name) def _translate_relation(name: str, ctx: explain.AnalyzeContext) -> Relation: try: id = uuid.UUID(name) except ValueError: # For introspection queries the tables are named pg_* return Relation(name) return Relation(_obj_to_name(ctx.schema.get_by_id(id), ctx)) def _translate_name( name: sn.Name, reverse_mod_aliases: dict[str, Optional[str]], ) -> str: if not isinstance(name, sn.QualName): return str(name) if name.module in reverse_mod_aliases: module = reverse_mod_aliases[name.module] if module is None: return name.name else: return f"{module}::{name.name}" else: module = name.module suffix = f"::{name.name}" while True: # looking for the longest prefix first try: prefix, submodule = module.rsplit("::", 1) except ValueError: return str(name) suffix = f"::{submodule}{suffix}" # Note: we don't strip default alias here so only absolute paths # are generated if aliased := reverse_mod_aliases.get(prefix): return aliased + suffix module = prefix # Legend: # * show, shown -- something visible in the text explain # * if xxx -- means some condition when parameter exists, option to explain # * str values also often have a list 
of options in the comment # (we do not use enums, because no exhaustivity guarantee) # * `kB` is unit for these values # * no key with list == empty list Expr = NewType('Expr', str) Kbytes = NewType('Kbytes', int) Millis = NewType('Millis', float) Index = NewType('Index', str) Relation = NewType('Relation', str) class PropType(enum.Enum): KBYTES = "kB" MILLIS = "ms" EXPR = "expr" INDEX = "index" RELATION = "relation" TEXT = "text" INT = "int" FLOAT = "float" LIST_KBYTES = "list:kB" LIST_MILLIS = "list:ms" LIST_EXPR = "list:expr" LIST_INDEX = "list:index" LIST_RELATION = "list:relation" LIST_TEXT = "list:text" LIST_INT = "list:int" LIST_FLOAT = "list:float" TYPES = { Kbytes: PropType.KBYTES, Millis: PropType.MILLIS, Expr: PropType.EXPR, Index: PropType.INDEX, Relation: PropType.RELATION, str: PropType.TEXT, int: PropType.INT, float: PropType.FLOAT, } class Important: __slots__ = () important = Important() @dataclasses.dataclass class PropInfo: type: type[object] enum_type: PropType important: bool class JitOptions(FromJson): # show all inlining: bool expressions: bool optimization: bool deforming: bool class JitTiming(FromJson): generation: float # ms inlining: float # ms optimization: float # ms emission: float # ms total: float # ms class JitInfo(FromJson): functions: int options: JitOptions timing: JitTiming class Worker(FromJson): worker_number: int actual_startup_time: Optional[float] # if timing actual_total_time: Optional[float] # if timing actual_rows: Optional[float] actual_loops: Optional[float] jit: Optional[JitInfo] # if bunch of options PlanT = TypeVar('PlanT', bound='Plan') @dataclasses.dataclass(kw_only=True) class CostMixin: # if cost startup_cost: float total_cost: float plan_rows: float plan_width: int # if analyze (zeros if never executed) actual_startup_time: Optional[float] = None # if timing actual_total_time: Optional[float] = None # if timing actual_rows: Optional[float] = None actual_loops: Optional[float] = None # if buffers 
shared_hit_blocks: Optional[int] = None shared_read_blocks: Optional[int] = None shared_dirtied_blocks: Optional[int] = None shared_written_blocks: Optional[int] = None local_hit_blocks: Optional[int] = None local_read_blocks: Optional[int] = None local_dirtied_blocks: Optional[int] = None local_written_blocks: Optional[int] = None temp_read_blocks: Optional[int] = None temp_written_blocks: Optional[int] = None class Plan(FromJson, CostMixin): # TODO(tailhook) output is lost somewhere node_type: str plan_id: uuid.UUID parent_relationship: Optional[str] subplan_name: Optional[str] # shown parallel_aware: bool # true always shown as a prefix of node name async_capable: bool # true always shown as a prefix of node name workers: Sequence[Worker] # shown if non-empty plans: list[Plan] __subclasses: ClassVar[dict[str, type[Plan]]] = dict() def __init_subclass__(cls, **kwargs: Any): super().__init_subclass__(**kwargs) cls.__subclasses[cls.__name__] = cls @classmethod def from_json( cls, json: dict[str, Any], ctx: explain.AnalyzeContext, ) -> Plan: copy = json.copy() copy['plan_id'] = uuid.uuid4() node_type = casefold.to_camel_case(copy.pop("Node Type")) subclass = cls.__subclasses.get(node_type, cls) return super(Plan, subclass).from_json(copy, ctx) @classmethod def get_props(cls) -> dict[str, PropInfo]: result = {} for name, prop in get_type_hints(cls, include_extras=True).items(): if name in CostMixin.__annotations__: # these are stored in the node itself continue if get_origin(prop) is Annotated: imp = important in get_args(prop) prop = get_args(prop)[0] else: imp = False try: if get_origin(prop) is list: enum_type = PropType["LIST_" + TYPES[prop].name] elif get_origin(prop) is Union: # optional enum_type = TYPES[get_args(prop)[0]] else: enum_type = TYPES[prop] except KeyError: # Unknown types are skipped, they are probably # nested structures, we don't support yet, and plan_id continue result[name] = PropInfo( type=prop, enum_type=enum_type, important=imp, ) return 
result # Base types class BaseScan(Plan): schema: Optional[str] # if verbose # It should have been required, but in ModifyTable it's optional, so # we try to make it compatible. We don't rely on it being required in # the code anyways. alias: Optional[str] class RelationScan(BaseScan): # It should have been required, but in ModifyTable it's optional, so # we try to make it compatible. We don't rely on it being required in # the code anyways. relation_name: Annotated[Optional[Relation], important] class FilterScan(FromJson): # mixin filter: Expr rows_removed_by_filter: Annotated[Optional[float], important] # Specific types class Result(Plan, FilterScan): one_time_filter: Expr class ProjectSet(Plan): pass class TargetTable(FromJson): schema: Optional[str] # if verbose alias: Optional[str] relation_name: Annotated[Optional[Relation], important] cte_name: Optional[str] tuplestore_name: Optional[str] tablefunction_name: Optional[str] function_name: Optional[str] # Also pluggable explain FDW class ModifyTable(RelationScan, TargetTable): operation: str # title target_tables: list[TargetTable] # show, if mult otherwise inherited props # Looks like conflict is only possible for single table, but # it's not clear # # if conflict conflict_resolution: Optional[str] # NOTHING, UPDATE conflict_arbiter_indexes: list[str] conflict_filter: Expr rows_removed_by_conflict_filter: float tuples_inserted: Optional[float] # if analyze conflicting_tuples: Optional[float] # if analyze class Append(Plan): pass class MergeAppend(Plan): sort_key: Annotated[list[Expr], important] # show presorted_key: Annotated[list[Expr], important] class RecursiveUnion(Plan): pass class BitmapAnd(Plan): pass class BitmapOr(Plan): pass class UniqueJoin(Plan, FilterScan): inner_unique: bool # Inner, Left, Full, Right, Semi, Anti, show join_type: Annotated[str, important] join_filter: Expr rows_removed_by_join_filter: Optional[float] class NestedLoop(UniqueJoin): pass class MergeJoin(UniqueJoin): merge_cond: 
Expr class HashJoin(UniqueJoin): hash_cond: Expr class SeqScan(RelationScan, FilterScan): pass class SampleScan(RelationScan, FilterScan): sampling_method: Annotated[Text, important] # show sampling_parameters: list[str] repeatable_seed: Annotated[Optional[str], important] # show class Gather(Plan, FilterScan): workers_planned: int workers_launched: Optional[int] # analyze params_evaluated: Optional[list[str]] single_copy: bool class GatherMerge(Plan, FilterScan): workers_planned: int workers_launched: Optional[int] # analyze params_evaluated: Optional[list[str]] class IndexScan(RelationScan, FilterScan): # Backwards, Forward, NoMovement, show: opt Backward scan_direction: Annotated[str, important] index_name: Annotated[Index, important] # show index_cond: Expr rows_removed_by_index_recheck: Annotated[Optional[float], important] order_by: Expr class IndexOnlyScan(IndexScan): heap_fetches: Optional[float] # if analyze class BitmapIndexScan(Plan): index_name: Annotated[Index, important] # show index_cond: Expr class BitmapHeapScan(RelationScan, FilterScan): recheck_cond: Expr rows_removed_by_index_recheck: Optional[float] exact_heap_blocks: Optional[int] # if analyze, show lossy_heap_blocks: Optional[int] # if analyze, show class TidScan(RelationScan, FilterScan): tid_cond: Expr class TidRangeScan(RelationScan, FilterScan): tid_cond: Expr class SubqueryScan(Plan, FilterScan): pass class FunctionScan(BaseScan, FilterScan): function_name: str function_call: Expr # if verbose class TableFunctionScan(BaseScan, FilterScan): table_function_name: str # always == 'xmltable' table_function_call: Expr # if verbose class ValuesScan(BaseScan, FilterScan): pass class CTEScan(BaseScan, FilterScan): cte_name: str class NamedTuplestoreScan(BaseScan, FilterScan): tuplestore_name: str class WorkTableScan(BaseScan, FilterScan): cte_name: str class ForeignScan(RelationScan, FilterScan): operation: Annotated[Optional[str], important] # show: title class CustomScan(RelationScan, 
FilterScan): custom_plan_provider: Optional[str] # extra info that is gathered via a custom function :shrug: class Materialize(Plan): pass class MemoizeWorker(Worker): # show if analyze && cache_misses > 0 (probably if enabled) cache_hits: int cache_misses: int cache_evictions: int cache_overflows: int peak_memory_usage: int # kB class Memoize(Plan): cache_key: str # show cache_mode: str # {binary, logical}, show # show if analyze && cache_misses > 0 (probably if enabled) cache_hits: int cache_misses: int cache_evictions: int cache_overflows: int peak_memory_usage: int # kB workers: Sequence[MemoizeWorker] class SortWorker(Worker): sort_method: str # show sort_space_used: int # show, kB sort_space_type: str # show class Sort(Plan): sort_key: Annotated[list[Expr], important] # show presorted_key: list[Expr] sort_method: Annotated[str, important] # show # * still in progress # * top-N heapsort # * quicksort # * external sort # * external merge sort_space_used: Annotated[Kbytes, important] # show, kB sort_space_type: Annotated[Text, important] # Disk, Memory, show workers: Sequence[SortWorker] # overrides class SortSpaceInfo(FromJson): average_sort_space_used: Annotated[Kbytes, important] # kB, show peak_sort_space_used: Annotated[Kbytes, important] # kB, show class SortGroupsInfo(FromJson): group_count: int sort_methods_used: list[str] # see Sort.sort_method sort_space_memory: SortSpaceInfo # show non-zero sort_space_disk: SortSpaceInfo # show non-zero class IncrementalSortWorker(Worker): full_sort_groups: Optional[SortGroupsInfo] # show pre_sorted_groups: Optional[SortGroupsInfo] # show class IncrementalSort(Plan): sort_key: Annotated[list[Expr], important] # show presorted_key: list[Expr] full_sort_groups: Optional[SortGroupsInfo] # show pre_sorted_groups: Optional[SortGroupsInfo] # show workers: Sequence[SortWorker] # overrides class Group(Plan, FilterScan): pass class Aggregate(Plan, FilterScan): strategy: Annotated[str, important] # show: title partial_mode: 
Annotated[str, important] # Partial, Finalize, Simple class WindowAgg(Plan): pass class Unique(Plan): pass class SetOp(Plan): strategy: str # Sorted, Hashed, show: title: SetOp, HashSetOp command: str # Intersect, Intersect All, Except, ExceptAll, show class LockRows(Plan): pass class Limit(Plan): pass class Hash(Plan): hash_buckets: Annotated[int, important] # show original_hash_buckets: int # show if differs hash_batches: Annotated[int, important] # show original_hash_batches: int # show if differs peak_memory_usage: Annotated[Kbytes, important] # kB # show ================================================ FILE: edb/server/compiler/explain/to_json.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2023-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from typing import Any import enum import uuid from edb.ir import statypes class ToJson: def to_json(self) -> Any: return {k: v for k, v in self.__dict__.items() if v is not None} def json_hook(value: Any) -> Any: if isinstance(value, ToJson): return value.to_json() elif isinstance(value, uuid.UUID): return str(value) elif isinstance(value, enum.Enum): return value.value elif isinstance(value, (frozenset, set)): return list(value) elif isinstance(value, statypes.ScalarType): return value.to_json() raise TypeError(f"Cannot serialize {value!r}") ================================================ FILE: edb/server/compiler/rpc.pxd ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2024-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# cimport cython cdef char serialize_output_format(val) cdef deserialize_output_format(char mode) cdef char serialize_input_language(val) cdef deserialize_input_language(char mode) @cython.final cdef class SQLParamsSource: cdef: object _cached_key object _serialized readonly object types_in_out @cython.final cdef class CompilationRequest: cdef: object serializer readonly object source readonly object protocol_version readonly object input_language readonly object output_format readonly object input_format readonly bint expect_one readonly int implicit_limit readonly bint inline_typeids readonly bint inline_typenames readonly bint inline_objectids readonly str role_name readonly str branch_name readonly object modaliases readonly object session_config object database_config object system_config object schema_version readonly object key_params bytes serialized_cache object cache_key cdef _serialize(self) @cython.final cdef class CompilationRequestIdHandle: cdef: object cache_key ================================================ FILE: edb/server/compiler/rpc.pyi ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2024-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import typing import uuid import immutables from edb import edgeql from edb.server import defines, config from edb.server.compiler import sertypes, enums from edb import graphql from edb.pgsql import parser as pgparser class SQLParamsSource: types_in_out: list[tuple[list[str], list[tuple[str, str]]]] def cache_key(self) -> bytes: ... def serialize(self) -> bytes: ... @staticmethod def deserialize(data: bytes) -> SQLParamsSource: ... def text(self) -> str: ... class CompilationRequest: source: edgeql.Source | graphql.Source | pgparser.Source | SQLParamsSource protocol_version: defines.ProtocolVersion input_language: enums.InputLanguage output_format: enums.OutputFormat input_format: enums.InputFormat expect_one: bool implicit_limit: int inline_typeids: bool inline_typenames: bool inline_objectids: bool role_name: str branch_name: str modaliases: immutables.Map[str | None, str] | None session_config: immutables.Map[str, config.SettingValue] | None key_params: typing.Mapping[str, object] | None = None def __init__( self, *, source: edgeql.Source | graphql.Source | pgparser.Source, protocol_version: defines.ProtocolVersion, schema_version: uuid.UUID, compilation_config_serializer: sertypes.CompilationConfigSerializer, input_language: enums.InputLanguage = enums.InputLanguage.EDGEQL, output_format: enums.OutputFormat = enums.OutputFormat.BINARY, input_format: enums.InputFormat = enums.InputFormat.BINARY, expect_one: bool = False, implicit_limit: int = 0, inline_typeids: bool = False, inline_typenames: bool = False, inline_objectids: bool = True, modaliases: typing.Mapping[str | None, str] | None = None, session_config: typing.Mapping[str, config.SettingValue] | None = None, database_config: typing.Mapping[str, config.SettingValue] | None = None, system_config: typing.Mapping[str, config.SettingValue] | None = None, role_name: str = defines.EDGEDB_SUPERUSER, branch_name: str = defines.EDGEDB_SUPERUSER_DB, key_params: typing.Mapping[str, object] | None = None, ): ... 
def set_modaliases( self, value: typing.Mapping[str | None, str] | None ) -> CompilationRequest: ... def set_session_config( self, value: typing.Mapping[str, config.SettingValue] | None ) -> CompilationRequest: ... def set_database_config( self, value: typing.Mapping[str, config.SettingValue] | None ) -> CompilationRequest: ... def set_system_config( self, value: typing.Mapping[str, config.SettingValue] | None ) -> CompilationRequest: ... def set_schema_version(self, version: uuid.UUID) -> CompilationRequest: ... def serialize(self) -> bytes: ... @classmethod def deserialize( cls, data: bytes, query_text: str, compilation_config_serializer: sertypes.CompilationConfigSerializer, ) -> CompilationRequest: ... def get_cache_key(self) -> uuid.UUID: ... ================================================ FILE: edb/server/compiler/rpc.pyx ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2024-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from typing import ( Mapping, ) import json import pickle import hashlib import uuid cimport cython import immutables from edb import edgeql, errors from edb.common import uuidgen from edb.edgeql import qltypes from edb.edgeql import tokenizer from edb.graphql import tokenizer as gql_tokenizer from edb.server import config, defines from edb.server.pgproto.pgproto cimport WriteBuffer, ReadBuffer from edb.pgsql import parser as pgparser from . import enums, sertypes cdef object OUT_FMT_BINARY = enums.OutputFormat.BINARY cdef object OUT_FMT_JSON = enums.OutputFormat.JSON cdef object OUT_FMT_JSON_ELEMENTS = enums.OutputFormat.JSON_ELEMENTS cdef object OUT_FMT_NONE = enums.OutputFormat.NONE cdef object IN_FMT_BINARY = enums.InputFormat.BINARY cdef object IN_FMT_JSON = enums.InputFormat.JSON cdef object IN_LANG_EDGEQL = enums.InputLanguage.EDGEQL cdef object IN_LANG_SQL = enums.InputLanguage.SQL cdef object IN_LANG_SQL_PARAMS = enums.InputLanguage.SQL_PARAMS cdef object IN_LANG_GRAPHQL = enums.InputLanguage.GRAPHQL cdef char MASK_JSON_PARAMETERS = 1 << 0 cdef char MASK_EXPECT_ONE = 1 << 1 cdef char MASK_INLINE_TYPEIDS = 1 << 2 cdef char MASK_INLINE_TYPENAMES = 1 << 3 cdef char MASK_INLINE_OBJECTIDS = 1 << 4 cdef char serialize_output_format(val): if val is OUT_FMT_BINARY: return b'b' elif val is OUT_FMT_JSON: return b'j' elif val is OUT_FMT_JSON_ELEMENTS: return b'J' elif val is OUT_FMT_NONE: return b'n' else: raise AssertionError("unreachable") cdef deserialize_output_format(char mode): if mode == b'b': return OUT_FMT_BINARY elif mode == b'j': return OUT_FMT_JSON elif mode == b'J': return OUT_FMT_JSON_ELEMENTS elif mode == b'n': return OUT_FMT_NONE else: raise errors.BinaryProtocolError( f'unknown output format {mode.to_bytes(1, "big")!r}') cdef char serialize_input_language(val): if val is IN_LANG_EDGEQL: return b'E' elif val is IN_LANG_SQL: return b'S' elif val is IN_LANG_SQL_PARAMS: return b'P' elif val is IN_LANG_GRAPHQL: return b'G' else: raise 
errors.BinaryProtocolError(f'unknown input language {val!r}') cdef deserialize_input_language(char lang): if lang == b'E': return IN_LANG_EDGEQL elif lang == b'S': return IN_LANG_SQL elif lang == b'P': return IN_LANG_SQL_PARAMS elif lang == b'G': return IN_LANG_GRAPHQL else: raise errors.BinaryProtocolError( f'unknown input language {lang.to_bytes(1, "big")!r}') @cython.final cdef class SQLParamsSource: def __init__( self, types_in_out: list[tuple[list[str], list[tuple[str, str]]]] ): self.types_in_out = types_in_out self._cached_key = None self._serialized = None def cache_key(self): if self._cached_key is not None: return self._cached_key if self._serialized is None: self.serialize() self._cached_key = hashlib.blake2b(self._serialized).digest() return self._cached_key def text(self): return '' def serialize(self): if self._serialized is not None: return self._serialized self._serialized = pickle.dumps(self.types_in_out, -1) return self._serialized @staticmethod def deserialize(data: bytes): types_in_out = pickle.loads(data) return SQLParamsSource(types_in_out) @cython.final cdef class CompilationRequest: def __cinit__( self, *, source: edgeql.Source, protocol_version: defines.ProtocolVersion, schema_version: uuid.UUID, compilation_config_serializer: sertypes.CompilationConfigSerializer, input_language: enums.InputLanguage = enums.InputLanguage.EDGEQL, output_format: enums.OutputFormat = OUT_FMT_BINARY, input_format: enums.InputFormat = IN_FMT_BINARY, expect_one: bint = False, implicit_limit: int = 0, inline_typeids: bint = False, inline_typenames: bint = False, inline_objectids: bint = True, modaliases: Mapping[str | None, str] | None = None, session_config: Mapping[str, config.SettingValue] | None = None, database_config: Mapping[str, config.SettingValue] | None = None, system_config: Mapping[str, config.SettingValue] | None = None, role_name: str = defines.EDGEDB_SUPERUSER, branch_name: str = defines.EDGEDB_SUPERUSER_DB, key_params: Mapping[str, object] | None 
= None, ): self.serializer = compilation_config_serializer self.source = source self.protocol_version = protocol_version self.input_language = input_language self.output_format = output_format self.input_format = input_format self.expect_one = expect_one self.implicit_limit = implicit_limit self.inline_typeids = inline_typeids self.inline_typenames = inline_typenames self.inline_objectids = inline_objectids self.schema_version = schema_version self.modaliases = modaliases self.session_config = session_config self.database_config = database_config self.system_config = system_config self.role_name = role_name self.branch_name = branch_name self.key_params = key_params self.serialized_cache = None self.cache_key = None def __copy__(self): cdef CompilationRequest rv rv = CompilationRequest( source=self.source, protocol_version=self.protocol_version, schema_version=self.schema_version, compilation_config_serializer=self.serializer, input_language=self.input_language, output_format=self.output_format, input_format=self.input_format, expect_one=self.expect_one, implicit_limit=self.implicit_limit, inline_typeids=self.inline_typeids, inline_typenames=self.inline_typenames, inline_objectids=self.inline_objectids, modaliases=self.modaliases, session_config=self.session_config, database_config=self.database_config, system_config=self.system_config, role_name=self.role_name, branch_name=self.branch_name, key_params=self.key_params, ) rv.serialized_cache = self.serialized_cache rv.cache_key = self.cache_key return rv def set_modaliases(self, value) -> CompilationRequest: self.modaliases = value self.serialized_cache = None self.cache_key = None return self def set_session_config(self, value) -> CompilationRequest: self.session_config = value self.serialized_cache = None self.cache_key = None return self def set_database_config(self, value) -> CompilationRequest: self.database_config = value self.serialized_cache = None self.cache_key = None return self def 
set_system_config(self, value) -> CompilationRequest: self.system_config = value self.serialized_cache = None self.cache_key = None return self def set_schema_version(self, version: uuid.UUID) -> CompilationRequest: self.schema_version = version self.serialized_cache = None self.cache_key = None return self def set_key_params(self, key_params) -> CompilationRequest: self.key_params = key_params self.serialized_cache = None self.cache_key = None return self @classmethod def deserialize( cls, data: bytes, query_text: str, compilation_config_serializer: sertypes.CompilationConfigSerializer, ) -> CompilationRequest: return _deserialize_comp_req( data, query_text, compilation_config_serializer) def serialize(self) -> bytes: if self.serialized_cache is None: self._serialize() return self.serialized_cache def get_cache_key(self) -> uuid.UUID: if self.cache_key is None: self._serialize() return self.cache_key cdef _serialize(self): cdef WriteBuffer buf hash_obj, buf = _serialize_comp_req(self) cache_key = hash_obj.digest() buf.write_bytes(cache_key) self.cache_key = uuidgen.from_bytes(cache_key) self.serialized_cache = bytes(buf) def __hash__(self): return hash(self.get_cache_key()) def __eq__(self, rhs) -> bool: cdef: CompilationRequest other if not isinstance(rhs, CompilationRequest): return NotImplemented other = rhs return ( self.source.cache_key() == other.source.cache_key() and self.protocol_version == other.protocol_version and self.input_language == other.input_language and self.output_format == other.output_format and self.input_format == other.input_format and self.expect_one == other.expect_one and self.implicit_limit == other.implicit_limit and self.inline_typeids == other.inline_typeids and self.inline_typenames == other.inline_typenames and self.inline_objectids == other.inline_objectids and self.role_name == other.role_name and self.branch_name == other.branch_name and self.key_params == other.key_params ) # A handle class to allow deleting cache entries 
just by their cache_key. @cython.final cdef class CompilationRequestIdHandle: def __cinit__( self, cache_key: uuid.UUID ): self.cache_key = cache_key def __hash__(self): return hash(self.cache_key) def __eq__(self, rhs) -> bool: ty = type(rhs) if ty is CompilationRequestIdHandle: return self.cache_key == rhs.cache_key elif ty is CompilationRequest: return self.cache_key == rhs.get_cache_key() else: return NotImplemented cdef CompilationRequest _deserialize_comp_req( data: bytes, query_text: str, compilation_config_serializer: sertypes.CompilationConfigSerializer, ): cdef: ReadBuffer buf = ReadBuffer.new_message_parser(data) CompilationRequest req if data[0] == 1: req = _deserialize_comp_req_v1( buf, query_text, compilation_config_serializer) else: raise errors.UnsupportedProtocolVersionError( f"unsupported compile cache: version {data[0]}" ) # Cache key is always trailing regardless of the version. req.cache_key = uuidgen.from_bytes(buf.read_bytes(16)) req.serialized_cache = data return req cdef _deserialize_comp_req_v1( buf: ReadBuffer, query_text: str, compilation_config_serializer: sertypes.CompilationConfigSerializer, ): # Format: # # * 1 byte of version (1) # * 1 byte of bit flags: # * json_parameters # * expect_one # * inline_typeids # * inline_typenames # * inline_objectids # * protocol_version (major: int16, minor: int16) # * 1 byte output_format (the same as in the binary protocol) # * implicit_limit: int64 # * Module aliases: # * length: int32 (negative means modaliases is None) # * For each alias pair: # * 1 byte, 0 if the name is None # * else, C-String as the name # * C-String as the alias # * Key parameter values: int32-length-prefixed JSON object, zero-length if None # * Session config type descriptor # * 16 bytes type ID # * int32-length-prefixed serialized type descriptor # * Session config: int32-length-prefixed serialized data # * Serialized Source or NormalizedSource without the original query # string # * The schema version ID.
# * 1 byte input language (the same as in the binary protocol) # * role_name as a UTF-8 encoded string # * branch_name as a UTF-8 encoded string cdef char flags assert buf.read_byte() == 1 # version flags = buf.read_byte() if flags & MASK_JSON_PARAMETERS > 0: input_format = IN_FMT_JSON else: input_format = IN_FMT_BINARY expect_one = flags & MASK_EXPECT_ONE > 0 inline_typeids = flags & MASK_INLINE_TYPEIDS > 0 inline_typenames = flags & MASK_INLINE_TYPENAMES > 0 inline_objectids = flags & MASK_INLINE_OBJECTIDS > 0 protocol_version = buf.read_int16(), buf.read_int16() output_format = deserialize_output_format(buf.read_byte()) implicit_limit = buf.read_int64() size = buf.read_int32() if size >= 0: modaliases = [] for _ in range(size): if buf.read_byte(): k = buf.read_null_str().decode("utf-8") else: k = None v = buf.read_null_str().decode("utf-8") modaliases.append((k, v)) modaliases = immutables.Map(modaliases) else: modaliases = None key_params_str = buf.read_len_prefixed_utf8() if key_params_str: key_params = immutables.Map(json.loads(key_params_str)) else: key_params = None serializer = compilation_config_serializer type_id = uuidgen.from_bytes(buf.read_bytes(16)) if type_id == serializer.type_id: buf.read_len_prefixed_bytes() else: serializer = sertypes.CompilationConfigSerializer( type_id, buf.read_len_prefixed_bytes(), defines.CURRENT_PROTOCOL ) data = buf.read_len_prefixed_bytes() if data: session_config = immutables.Map( ( k, config.SettingValue( name=k, value=v, source='session', scope=qltypes.ConfigScope.SESSION, ) ) for k, v in serializer.decode(data).items() ) else: session_config = None serialized_source = buf.read_len_prefixed_bytes() schema_version = uuidgen.from_bytes(buf.read_bytes(16)) input_language = deserialize_input_language(buf.read_byte()) role_name = buf.read_len_prefixed_utf8() branch_name = buf.read_len_prefixed_utf8() if input_language is enums.InputLanguage.EDGEQL: source = tokenizer.deserialize(serialized_source, query_text) elif 
input_language is enums.InputLanguage.SQL: source = pgparser.deserialize(serialized_source) elif input_language is enums.InputLanguage.SQL_PARAMS: source = SQLParamsSource.deserialize(serialized_source) elif input_language is enums.InputLanguage.GRAPHQL: source = gql_tokenizer.deserialize(serialized_source, query_text) else: raise AssertionError( f"unexpected source language in serialized " f"CompilationRequest: {input_language}" ) req = CompilationRequest( source=source, protocol_version=protocol_version, schema_version=schema_version, compilation_config_serializer=serializer, input_language=input_language, output_format=output_format, input_format=input_format, expect_one=expect_one, implicit_limit=implicit_limit, inline_typeids=inline_typeids, inline_typenames=inline_typenames, inline_objectids=inline_objectids, modaliases=modaliases, session_config=session_config, role_name=role_name, branch_name=branch_name, key_params=key_params, ) return req cdef _serialize_comp_req(req: CompilationRequest): # Please see _deserialize_comp_req_v1 for the format doc cdef: char version = 1, flags WriteBuffer out = WriteBuffer.new() out.write_byte(version) flags = ( (MASK_JSON_PARAMETERS if req.input_format is IN_FMT_JSON else 0) | (MASK_EXPECT_ONE if req.expect_one else 0) | (MASK_INLINE_TYPEIDS if req.inline_typeids else 0) | (MASK_INLINE_TYPENAMES if req.inline_typenames else 0) | (MASK_INLINE_OBJECTIDS if req.inline_objectids else 0) ) out.write_byte(flags) out.write_int16(req.protocol_version[0]) out.write_int16(req.protocol_version[1]) out.write_byte(serialize_output_format(req.output_format)) out.write_int64(req.implicit_limit) if req.modaliases is None: out.write_int32(-1) else: out.write_int32(len(req.modaliases)) for k, v in sorted( req.modaliases.items(), key=lambda i: (0, i[0]) if i[0] is None else (1, i[0]) ): if k is None: out.write_byte(0) else: out.write_byte(1) out.write_str(k, "utf-8") out.write_str(v, "utf-8") if req.key_params is None: key_params_str = b'' 
else: key_params_str = json.dumps(req.key_params).encode('utf-8') out.write_len_prefixed_bytes(key_params_str) type_id, desc = req.serializer.describe() out.write_bytes(type_id.bytes) out.write_len_prefixed_bytes(desc) hash_obj = hashlib.blake2b(memoryview(out), digest_size=16) hash_obj.update(req.source.cache_key()) if req.session_config is None: session_config = b"" else: session_config = req.serializer.encode_configs( req.session_config ) out.write_len_prefixed_bytes(session_config) # Build config that affects compilation: session -> database -> system. # This is only used for calculating cache_key, while session # config itself is separately stored above in the serialized format. serialized_comp_config = req.serializer.encode_configs( req.system_config, req.database_config, req.session_config ) hash_obj.update(serialized_comp_config) # Must set_schema_version() before serializing compilation request assert req.schema_version is not None hash_obj.update(req.schema_version.bytes) out.write_len_prefixed_bytes(req.source.serialize()) out.write_bytes(req.schema_version.bytes) out.write_byte(serialize_input_language(req.input_language)) hash_obj.update(req.input_language.value.encode("utf-8")) role_name = req.role_name.encode("utf-8") out.write_len_prefixed_bytes(role_name) hash_obj.update(role_name) branch_name = req.branch_name.encode("utf-8") out.write_len_prefixed_bytes(branch_name) hash_obj.update(branch_name) return hash_obj, out ================================================ FILE: edb/server/compiler/sertypes.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Callable, ClassVar, Literal, Optional, Iterable, Mapping, Sequence, cast, overload, ) import collections.abc import dataclasses import enum import functools import io import struct import uuid import immutables from edb import errors from edb.common import binwrapper from edb.common import lru from edb.common import uuidgen from edb.common import value_dispatch from edb.protocol import enums as p_enums from edb.server import config from edb.server import defines as edbdef from edb.edgeql import qltypes from edb.schema import name as s_name from edb.schema import globals as s_globals from edb.schema import links as s_links from edb.schema import objects as s_obj from edb.schema import objtypes as s_objtypes from edb.schema import pointers as s_pointers from edb.schema import scalars as s_scalars from edb.schema import schema as s_schema from edb.schema import types as s_types from edb.ir import ast as irast from edb.ir import statypes from . 
import enums _int32_packer = cast(Callable[[int], bytes], struct.Struct('!l').pack) _uint32_packer = cast(Callable[[int], bytes], struct.Struct('!L').pack) _uint16_packer = cast(Callable[[int], bytes], struct.Struct('!H').pack) _uint8_packer = cast(Callable[[int], bytes], struct.Struct('!B').pack) _int64_struct = struct.Struct('!q') _float32_struct = struct.Struct('!f') EMPTY_TUPLE_ID = s_obj.get_known_type_id('empty-tuple') EMPTY_TUPLE_DESC = b'\x04' + EMPTY_TUPLE_ID.bytes + b'\x00\x00' UUID_TYPE_ID = s_obj.get_known_type_id('std::uuid') STR_TYPE_ID = s_obj.get_known_type_id('std::str') NULL_TYPE_ID = uuidgen.UUID(b'\x00' * 16) NULL_TYPE_DESC = b'' class DescriptorTag(bytes, enum.Enum): SET = b'\x00' SHAPE = b'\x01' BASE_SCALAR = b'\x02' SCALAR = b'\x03' TUPLE = b'\x04' NAMEDTUPLE = b'\x05' ARRAY = b'\x06' ENUM = b'\x07' INPUT_SHAPE = b'\x08' RANGE = b'\x09' OBJECT = b'\x0a' COMPOUND = b'\x0b' MULTIRANGE = b'\x0c' SQL_ROW = b'\x0d' ANNO_TYPENAME = b'\xff' class ShapePointerFlags(enum.IntFlag): IS_IMPLICIT = enum.auto() IS_LINKPROP = enum.auto() IS_LINK = enum.auto() class CompoundOp(enum.IntEnum): UNION = 1 << 0 INTERSECTION = 1 << 1 EMPTY_BYTEARRAY = bytearray() def _encode_str(data: str) -> bytes: return data.encode('utf-8') def _decode_str(data: bytes) -> str: return data.decode('utf-8') def _encode_bool(data: bool) -> bytes: return b'\x01' if data else b'\x00' def _decode_bool(data: bytes) -> bool: return bool(data[0]) def _encode_int64(data: int) -> bytes: return _int64_struct.pack(data) def _decode_int64(data: bytes) -> int: return _int64_struct.unpack(data)[0] # type: ignore [no-any-return] def _encode_float32(data: float) -> bytes: return _float32_struct.pack(data) def _decode_float32(data: bytes) -> float: return _float32_struct.unpack(data)[0] # type: ignore [no-any-return] def _string_packer(s: str) -> bytes: s_bytes = s.encode('utf-8') return _uint32_packer(len(s_bytes)) + s_bytes def _name_packer(n: s_name.Name) -> bytes: return _string_packer(str(n)) 
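# Illustrative sketch only, not part of the original module: the string
# packer defined above frames a string as a big-endian uint32 byte length
# followed by the UTF-8 payload. The two helpers below are hypothetical
# names introduced purely to make that framing concrete; they depend only
# on the standard library. `_example_unpack_string` is an assumed inverse
# and does not exist in the original source.
import struct as _example_struct


def _example_pack_string(s: str) -> bytes:
    # Same layout as the string packer above: '!L' length prefix, then
    # the UTF-8 encoded bytes. The length counts bytes, not characters.
    payload = s.encode('utf-8')
    return _example_struct.pack('!L', len(payload)) + payload


def _example_unpack_string(data: bytes, offset: int = 0) -> tuple[str, int]:
    # Read the 4-byte length, then decode exactly that many bytes.
    # Returns the decoded string and the offset just past it, so several
    # length-prefixed fields can be read back to back.
    (size,) = _example_struct.unpack_from('!L', data, offset)
    start = offset + 4
    end = start + size
    if end > len(data):
        raise ValueError('truncated length-prefixed string')
    return data[start:end].decode('utf-8'), end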
def _bool_packer(b: bool) -> bytes: return b'\x01' if b else b'\x00' def cardinality_from_ptr( ptr: s_pointers.Pointer | s_globals.Global, schema: s_schema.Schema, ) -> enums.Cardinality: required = ptr.get_required(schema) schema_card = ptr.get_cardinality(schema) ir_card = qltypes.Cardinality.from_schema_value(required, schema_card) return enums.cardinality_from_ir_value(ir_card) InputShapeElement = tuple[str, s_types.Type, enums.Cardinality] InputShapeMap = Mapping[s_types.Type, Iterable[InputShapeElement]] ViewShapeMap = Mapping[s_obj.Object, list[s_pointers.Pointer]] ViewShapeMetadataMap = Mapping[s_types.Type, irast.ViewShapeMetadata] class Context: def __init__( self, *, schema: s_schema.Schema, protocol_version: edbdef.ProtocolVersion, view_shapes: ViewShapeMap = immutables.Map(), view_shapes_metadata: ViewShapeMetadataMap = immutables.Map(), follow_links: bool = True, inline_typenames: bool = False, name_filter: str = "", ) -> None: self.schema = schema self.view_shapes = view_shapes self.view_shapes_metadata = view_shapes_metadata self.protocol_version = protocol_version self.follow_links = follow_links self.inline_typenames = inline_typenames self.name_filter = name_filter self.buffer: list[bytes] = [] self.anno_buffer: list[bytes] = [] self.uuid_to_pos: dict[uuid.UUID, int] = {} def derive(self) -> Context: ctx = type(self)( schema=self.schema, protocol_version=self.protocol_version, view_shapes=self.view_shapes, view_shapes_metadata=self.view_shapes_metadata, follow_links=self.follow_links, inline_typenames=self.inline_typenames, name_filter=self.name_filter, ) ctx.buffer = self.buffer.copy() ctx.anno_buffer = self.anno_buffer.copy() ctx.uuid_to_pos = self.uuid_to_pos.copy() return ctx def _get_collection_type_id( coll_type: str, subtypes: list[uuid.UUID], element_names: list[str] | None = None, ) -> uuid.UUID: if coll_type == 'tuple' and not subtypes: return s_obj.get_known_type_id('empty-tuple') string_id = f'{coll_type}\x00{":".join(map(str, 
subtypes))}' if element_names: string_id += f'\x00{":".join(element_names)}' return uuidgen.uuid5(s_obj.TYPE_ID_NAMESPACE, string_id) def _get_object_shape_id( coll_type: str, subtypes: list[uuid.UUID], element_names: Optional[list[str]] = None, cardinalities: Optional[list[enums.Cardinality]] = None, *, links_props: Optional[list[bool]] = None, links: Optional[list[bool]] = None, has_implicit_fields: bool = False, ) -> uuid.UUID: parts = [coll_type] parts.append(":".join(map(str, subtypes))) if element_names: parts.append(":".join(element_names)) if cardinalities: parts.append(":".join(chr(c._value_) for c in cardinalities)) string_id = "\x00".join(parts) string_id += f'{has_implicit_fields!r};{links_props!r};{links!r}' return uuidgen.uuid5(s_obj.TYPE_ID_NAMESPACE, string_id) def _get_set_type_id(basetype_id: uuid.UUID) -> uuid.UUID: return uuidgen.uuid5( s_obj.TYPE_ID_NAMESPACE, 'set-of::' + str(basetype_id)) def _register_type_id( type_id: uuid.UUID, ctx: Context, ) -> uuid.UUID: if type_id not in ctx.uuid_to_pos: ctx.uuid_to_pos[type_id] = len(ctx.uuid_to_pos) return type_id def _describe_set( t: s_types.Type, *, ctx: Context, ) -> uuid.UUID: type_id = _describe_type(t, ctx=ctx) set_id = _get_set_type_id(type_id) if set_id in ctx.uuid_to_pos: return set_id buf = [] # .tag buf.append(DescriptorTag.SET._value_) # .id buf.append(set_id.bytes) # .type buf.append(_type_ref_id_packer(type_id, ctx=ctx)) return _finish_typedesc(set_id, buf, ctx=ctx) # The encoding format is documented in edb/api/types.txt. 
@functools.singledispatch def _describe_type(t: s_types.Type, *, ctx: Context) -> uuid.UUID: raise errors.InternalServerError( f'cannot describe type {t.get_name(ctx.schema)}') def _type_ref_packer(t: s_types.Type, *, ctx: Context) -> bytes: """Return typedesc representation of a type reference.""" return _type_ref_id_packer(_describe_type(t, ctx=ctx), ctx=ctx) def _type_ref_id_packer(type_id: uuid.UUID, *, ctx: Context) -> bytes: """Return typedesc representation of a type reference by type id.""" return _uint16_packer(ctx.uuid_to_pos[type_id]) def _type_ref_seq_packer(ts: Sequence[s_types.Type], *, ctx: Context) -> bytes: """Return typedesc representation of a sequence of type references.""" result = _uint16_packer(len(ts)) for t in ts: result += _type_ref_packer(t, ctx=ctx) return result def _type_ref_id_seq_packer(ts: Sequence[uuid.UUID], *, ctx: Context) -> bytes: """Return typedesc representation of a sequence of type id references.""" result = _uint16_packer(len(ts)) for t in ts: result += _type_ref_id_packer(t, ctx=ctx) return result def _finish_typedesc( type_id: uuid.UUID, buf: list[bytes], *, ctx: Context, ) -> uuid.UUID: desc = b''.join(buf) if ctx.protocol_version >= (2, 0): ctx.buffer.append(_uint32_packer(len(desc))) ctx.buffer.append(desc) return _register_type_id(type_id, ctx=ctx) # Tuple -> TupleTypeDescriptor @_describe_type.register def _describe_tuple(t: s_types.Tuple, *, ctx: Context) -> uuid.UUID: subtypes = [ _describe_type(st, ctx=ctx) for st in t.get_subtypes(ctx.schema) ] is_named = t.is_named(ctx.schema) if is_named: element_names = list(t.get_element_names(ctx.schema)) assert len(element_names) == len(subtypes) tag = DescriptorTag.NAMEDTUPLE else: element_names = None tag = DescriptorTag.TUPLE type_id = _get_collection_type_id( t.get_schema_name(), subtypes, element_names) if type_id in ctx.uuid_to_pos: return type_id buf = [] # .tag buf.append(tag._value_) # .id buf.append(type_id.bytes) if ctx.protocol_version >= (2, 0): # .name 
buf.append(_name_packer(t.get_name(ctx.schema))) # .schema_defined buf.append(_bool_packer(t.get_is_persistent(ctx.schema))) # .ancestors buf.append(_type_ref_seq_packer([], ctx=ctx)) # .element_count buf.append(_uint16_packer(len(subtypes))) if element_names is not None: # .elements for el_name, el_type_id in zip(element_names, subtypes): # TupleElement.name buf.append(_string_packer(el_name)) # TupleElement.type buf.append(_type_ref_id_packer(el_type_id, ctx=ctx)) else: # .element_types for el_type_id in subtypes: buf.append(_type_ref_id_packer(el_type_id, ctx=ctx)) return _finish_typedesc(type_id, buf, ctx=ctx) # Array -> ArrayTypeDescriptor @_describe_type.register def _describe_array(t: s_types.Array, *, ctx: Context) -> uuid.UUID: subtypes = [ _describe_type(st, ctx=ctx) for st in t.get_subtypes(ctx.schema) ] assert len(subtypes) == 1 type_id = _get_collection_type_id(t.get_schema_name(), subtypes) if type_id in ctx.uuid_to_pos: return type_id buf = [] # .tag buf.append(DescriptorTag.ARRAY._value_) # .id buf.append(type_id.bytes) if ctx.protocol_version >= (2, 0): # .name buf.append(_name_packer(t.get_name(ctx.schema))) # .schema_defined buf.append(_bool_packer(t.get_is_persistent(ctx.schema))) # .ancestors buf.append(_type_ref_seq_packer([], ctx=ctx)) # .type buf.append(_type_ref_id_packer(subtypes[0], ctx=ctx)) # .dimension_count (currently always 1) buf.append(_uint16_packer(1)) # .dimensions (currently always unbounded) buf.append(_int32_packer(-1)) return _finish_typedesc(type_id, buf, ctx=ctx) # Range -> RangeTypeDescriptor @_describe_type.register def _describe_range(t: s_types.Range, *, ctx: Context) -> uuid.UUID: subtypes = [ _describe_type(st, ctx=ctx) for st in t.get_subtypes(ctx.schema) ] assert len(subtypes) == 1 type_id = _get_collection_type_id(t.get_schema_name(), subtypes) if type_id in ctx.uuid_to_pos: return type_id buf = [] # .tag buf.append(DescriptorTag.RANGE._value_) # .id buf.append(type_id.bytes) if ctx.protocol_version >= (2, 0): # 
.name buf.append(_name_packer(t.get_name(ctx.schema))) # .schema_defined buf.append(_bool_packer(t.get_is_persistent(ctx.schema))) # .ancestors buf.append(_type_ref_seq_packer([], ctx=ctx)) # .type buf.append(_type_ref_id_packer(subtypes[0], ctx=ctx)) return _finish_typedesc(type_id, buf, ctx=ctx) # MultiRange -> MultiRangeTypeDescriptor @_describe_type.register def _describe_multirange(t: s_types.MultiRange, *, ctx: Context) -> uuid.UUID: subtypes = [ _describe_type(st, ctx=ctx) for st in t.get_subtypes(ctx.schema) ] assert len(subtypes) == 1 type_id = _get_collection_type_id(t.get_schema_name(), subtypes) if type_id in ctx.uuid_to_pos: return type_id buf = [] # .tag buf.append(DescriptorTag.MULTIRANGE._value_) # .id buf.append(type_id.bytes) if ctx.protocol_version >= (2, 0): # .name buf.append(_name_packer(t.get_name(ctx.schema))) # .schema_defined buf.append(_bool_packer(t.get_is_persistent(ctx.schema))) # .ancestors buf.append(_type_ref_seq_packer([], ctx=ctx)) # .type buf.append(_type_ref_id_packer(subtypes[0], ctx=ctx)) return _finish_typedesc(type_id, buf, ctx=ctx) # ObjectType (representing a shape) -> ObjectShapeDescriptor @_describe_type.register def _describe_object_shape( t: s_objtypes.ObjectType, *, ctx: Context, ) -> uuid.UUID: ctx.schema, mt = t.material_type(ctx.schema) base_type_name = str(mt.get_name(ctx.schema)) subtypes = [] element_names = [] link_props = [] links = [] cardinalities: list[enums.Cardinality] = [] sources = [] metadata = ctx.view_shapes_metadata.get(t) implicit_id = metadata is not None and metadata.has_implicit_id for ptr in ctx.view_shapes.get(t, ()): name = ptr.get_shortname(ctx.schema).name if not name.startswith(ctx.name_filter): continue name = name.removeprefix(ctx.name_filter) if ptr.singular(ctx.schema): if isinstance(ptr, s_links.Link) and not ctx.follow_links: uuid_t = ctx.schema.get('std::uuid', type=s_scalars.ScalarType) subtype_id = _describe_type(uuid_t, ctx=ctx) else: tgt = ptr.get_target(ctx.schema) assert tgt 
is not None subtype_id = _describe_type(tgt, ctx=ctx) else: if isinstance(ptr, s_links.Link) and not ctx.follow_links: raise errors.InternalServerError( 'cannot describe multi links when follow_links=False' ) else: tgt = ptr.get_target(ctx.schema) assert tgt is not None subtype_id = _describe_set(tgt, ctx=ctx) subtypes.append(subtype_id) element_names.append(name) link_props.append(False) links.append(not ptr.is_property()) cardinalities.append(cardinality_from_ptr(ptr, ctx.schema)) ctx.schema, material_ptr = ptr.material_type(ctx.schema) ptr_source = material_ptr.get_source(ctx.schema) assert isinstance(ptr_source, s_objtypes.ObjectType) ctx.schema, ptr_source = ptr_source.material_type(ctx.schema) assert ptr_source is not None sources.append(ptr_source) t_rptr = t.get_rptr(ctx.schema) if t_rptr is not None and (rptr_ptrs := ctx.view_shapes.get(t_rptr)): # There are link properties in the mix for ptr in rptr_ptrs: tgt = ptr.get_target(ctx.schema) assert tgt is not None if ptr.singular(ctx.schema): subtype_id = _describe_type(tgt, ctx=ctx) else: subtype_id = _describe_set(tgt, ctx=ctx) subtypes.append(subtype_id) element_names.append(ptr.get_shortname(ctx.schema).name) link_props.append(True) links.append(False) cardinalities.append(cardinality_from_ptr(ptr, ctx.schema)) # XXX: link properties do not support polymorphism currently sources.append(mt) assert len(subtypes) == len(element_names) type_id = _get_object_shape_id( base_type_name, subtypes, element_names, cardinalities, links_props=link_props, links=links, has_implicit_fields=implicit_id, ) if type_id in ctx.uuid_to_pos: return type_id is_free_object_type = t.is_free_object_type(ctx.schema) buf = [] # .tag buf.append(DescriptorTag.SHAPE._value_) # .id buf.append(type_id.bytes) if ctx.protocol_version >= (2, 0): # .ephemeral_free_shape buf.append(_bool_packer(is_free_object_type)) # .type if is_free_object_type: buf.append(_uint16_packer(0)) else: obj_type_id = _describe_object_type(mt, ctx=ctx) 
buf.append(_type_ref_id_packer(obj_type_id, ctx=ctx)) # .element_count buf.append(_uint16_packer(len(subtypes))) # .elements for el_name, el_type_id, el_lp, el_l, el_c, el_src in ( zip(element_names, subtypes, link_props, links, cardinalities, sources) ): flags = 0 if el_lp: flags |= ShapePointerFlags.IS_LINKPROP if (implicit_id and el_name == 'id') or el_name == '__tid__': if el_type_id != UUID_TYPE_ID: raise errors.InternalServerError( f"{el_name!r} is expected to be a 'std::uuid' singleton") flags |= ShapePointerFlags.IS_IMPLICIT elif el_name == '__tname__': if el_type_id != STR_TYPE_ID: raise errors.InternalServerError( f"{el_name!r} is expected to be a 'std::str' singleton") flags |= ShapePointerFlags.IS_IMPLICIT if el_l: flags |= ShapePointerFlags.IS_LINK # ShapeElement.flags buf.append(_uint32_packer(flags)) # ShapeElement.cardinality buf.append(_uint8_packer(el_c.value)) # ShapeElement.name buf.append(_string_packer(el_name)) # ShapeElement.type buf.append(_type_ref_id_packer(el_type_id, ctx=ctx)) if ctx.protocol_version >= (2, 0): # .source_type if not is_free_object_type: src_type_id = _describe_object_type(el_src, ctx=ctx) buf.append(_type_ref_id_packer(src_type_id, ctx=ctx)) else: buf.append(_uint16_packer(0)) return _finish_typedesc(type_id, buf, ctx=ctx) def _describe_object_type( t: s_objtypes.ObjectType, *, ctx: Context, ) -> uuid.UUID: if t.is_compound_type(ctx.schema): return _describe_compound_object_type(t, ctx=ctx) else: return _describe_regular_object_type(t, ctx=ctx) # ObjectType (regular) -> ObjectTypeDescriptor def _describe_regular_object_type( t: s_objtypes.ObjectType, *, ctx: Context, ) -> uuid.UUID: if ctx.protocol_version < (2, 0): raise AssertionError( f"cannot describe material object type {t.get_name(ctx.schema)!r} " f"in protocol < 2.0" ) buf = [] type_id = t.id if type_id in ctx.uuid_to_pos: # already described return type_id # .tag buf.append(DescriptorTag.OBJECT._value_) # .id buf.append(type_id.bytes) # .name 
buf.append(_name_packer(t.get_name(ctx.schema))) # .schema_defined buf.append(_bool_packer(True)) return _finish_typedesc(type_id, buf, ctx=ctx) # ObjectType (compound) -> CompoundTypeDescriptor def _describe_compound_object_type( t: s_objtypes.ObjectType, *, ctx: Context, ) -> uuid.UUID: if ctx.protocol_version < (2, 0): raise AssertionError( f"cannot describe compound object type {t.get_name(ctx.schema)!r} " "in protocol < 2.0" ) buf = [] type_id = t.id if type_id in ctx.uuid_to_pos: # already described return type_id components = t.get_union_of(ctx.schema).objects(ctx.schema) if components: op = CompoundOp.UNION else: components = t.get_intersection_of(ctx.schema).objects(ctx.schema) if not components: raise AssertionError( f"{t.get_name(ctx.schema)} is not a compound type") op = CompoundOp.INTERSECTION # .tag buf.append(DescriptorTag.COMPOUND._value_) # .id buf.append(type_id.bytes) # .name buf.append(_name_packer(t.get_name(ctx.schema))) # .schema_defined buf.append(_bool_packer(False)) # .op buf.append(_uint8_packer(op)) # .components buf.append(_type_ref_id_seq_packer( [_describe_object_type(c, ctx=ctx) for c in components], ctx=ctx, )) return _finish_typedesc(type_id, buf, ctx=ctx) @_describe_type.register def _describe_scalar_type( t: s_scalars.ScalarType, *, ctx: Context, ) -> uuid.UUID: ctx.schema, smt = t.material_type(ctx.schema) type_id = smt.id if type_id in ctx.uuid_to_pos: # already described return type_id if smt.is_enum(ctx.schema): return _describe_enum(smt, ctx=ctx) else: return _describe_regular_scalar(smt, ctx=ctx) # ScalarType (regular) -> [Base]ScalarTypeDescriptor def _describe_regular_scalar( t: s_scalars.ScalarType, *, ctx: Context, ) -> uuid.UUID: buf = [] fundamental_type = t.get_topmost_concrete_base(ctx.schema) type_id = t.id type_is_fundamental = t == fundamental_type if ctx.protocol_version >= (2, 0): # .tag buf.append(DescriptorTag.SCALAR._value_) # .id buf.append(type_id.bytes) # .name 
buf.append(_name_packer(t.get_name(ctx.schema))) # .schema_defined buf.append(_bool_packer(True)) # .ancestors_count # .ancestors if type_is_fundamental: buf.append(_uint16_packer(0)) else: ancestors = [] for ancestor in t.get_ancestors(ctx.schema).objects(ctx.schema): ancestors.append(ancestor) if ancestor == fundamental_type: break buf.append(_type_ref_seq_packer(ancestors, ctx=ctx)) else: if type_is_fundamental: # .tag buf.append(DescriptorTag.BASE_SCALAR._value_) # .id buf.append(type_id.bytes) else: # .tag buf.append(DescriptorTag.SCALAR._value_) # .id buf.append(type_id.bytes) # .base_type_pos buf.append(_type_ref_packer(fundamental_type, ctx=ctx)) if ctx.inline_typenames: _add_annotation(t, ctx=ctx) return _finish_typedesc(type_id, buf, ctx=ctx) # ScalarType (enum) -> EnumTypeDescriptor def _describe_enum( enum: s_scalars.ScalarType, *, ctx: Context, ) -> uuid.UUID: buf = [] enum_values = enum.get_enum_values(ctx.schema) assert enum_values is not None type_id = enum.id # .tag buf.append(DescriptorTag.ENUM._value_) # .id buf.append(type_id.bytes) if ctx.protocol_version >= (2, 0): # .name buf.append(_name_packer(enum.get_name(ctx.schema))) # .schema_defined buf.append(_bool_packer(True)) # .ancestors ancestors = [] topmost = enum.get_topmost_concrete_base(ctx.schema) if enum != topmost: for ancestor in enum.get_ancestors(ctx.schema).objects(ctx.schema): ancestors.append(ancestor) if ancestor == topmost: break buf.append(_type_ref_seq_packer(ancestors, ctx=ctx)) # .member_count buf.append(_uint16_packer(len(enum_values))) # .members for enum_val in enum_values: buf.append(_string_packer(enum_val)) if ctx.protocol_version < (2, 0) and ctx.inline_typenames: _add_annotation(enum, ctx=ctx) return _finish_typedesc(type_id, buf, ctx=ctx) @overload def describe_input_shape( t: s_types.Type, input_shapes: InputShapeMap, *, prepare_state: Literal[False], ctx: Context, ) -> uuid.UUID: ... 
@overload def describe_input_shape( t: s_types.Type, input_shapes: InputShapeMap, *, ctx: Context, ) -> uuid.UUID: ... @overload def describe_input_shape( t: s_types.Type, input_shapes: InputShapeMap, *, prepare_state: Literal[True], ctx: Context, ) -> None: ... def describe_input_shape( t: s_types.Type, input_shapes: InputShapeMap, *, prepare_state: bool = False, ctx: Context, ) -> Optional[uuid.UUID]: if t in input_shapes: element_names = [] subtypes = [] cardinalities = [] for name, subtype, cardinality in input_shapes[t]: if ( cardinality == enums.Cardinality.MANY or cardinality == enums.Cardinality.AT_LEAST_ONE ): subtype_id = _describe_set(subtype, ctx=ctx) else: subtype_id = describe_input_shape( subtype, input_shapes, ctx=ctx) element_names.append(name) subtypes.append(subtype_id) cardinalities.append(cardinality) assert len(subtypes) == len(element_names) if prepare_state: return None ctx.schema, mt = t.material_type(ctx.schema) base_type_name = str(mt.get_name(ctx.schema)) type_id = _get_object_shape_id( base_type_name, subtypes, element_names, cardinalities) if type_id in ctx.uuid_to_pos: return type_id buf = [] # .tag buf.append(DescriptorTag.INPUT_SHAPE._value_) # .id buf.append(type_id.bytes) # .element_count buf.append(_uint16_packer(len(subtypes))) # .elements for el_name, el_type_id, el_c in ( zip(element_names, subtypes, cardinalities) ): # ShapeElement.flags buf.append(_uint32_packer(0)) # ShapeElement.cardinality buf.append(_uint8_packer(el_c.value)) # ShapeElement.name buf.append(_string_packer(el_name)) # ShapeElement.type buf.append(_type_ref_id_packer(el_type_id, ctx=ctx)) return _finish_typedesc(type_id, buf, ctx=ctx) else: return _describe_type(t, ctx=ctx) def _add_annotation(t: s_types.Type, *, ctx: Context) -> None: buf = [] # .tag buf.append(DescriptorTag.ANNO_TYPENAME._value_) # .id buf.append(t.id.bytes) # .annotation buf.append(_string_packer(t.get_displayname(ctx.schema))) desc = b''.join(buf) if ctx.protocol_version >= (2, 0): 
ctx.anno_buffer.append(_uint32_packer(len(desc))) ctx.anno_buffer.append(desc) def describe_params( *, schema: s_schema.Schema, params: list[tuple[str, s_types.Type, bool]], protocol_version: edbdef.ProtocolVersion, ) -> tuple[bytes, uuid.UUID]: if not params: return NULL_TYPE_DESC, NULL_TYPE_ID ctx = Context( schema=schema, protocol_version=protocol_version, ) params_buf = [] subtypes = [] element_names = [] cardinalities = [] for param_name, param_type, param_req in params: param_type_id = _describe_type(param_type, ctx=ctx) # ShapeElement.flags params_buf.append(_uint32_packer(0)) # ShapeElement.cardinality card = ( p_enums.Cardinality.ONE if param_req else p_enums.Cardinality.AT_MOST_ONE ) cardinalities.append(card) params_buf.append(_uint8_packer(card._value_)) # ShapeElement.name params_buf.append(_string_packer(param_name)) element_names.append(param_name) # ShapeElement.type params_buf.append(_type_ref_id_packer(param_type_id, ctx=ctx)) subtypes.append(param_type_id) if protocol_version >= (2, 0): # ShapeElement.source_type params_buf.append(_uint16_packer(0)) params_id = _get_object_shape_id( "std::FreeObject", subtypes, element_names, cardinalities) params_shape = [ DescriptorTag.SHAPE._value_, params_id.bytes, ] if protocol_version >= (2, 0): # .ephemeral_free_shape params_shape.append(_bool_packer(True)) # .type params_shape.append(_uint16_packer(0)) params_shape.extend([ _uint16_packer(len(params)), *params_buf, ]) _finish_typedesc(params_id, params_shape, ctx=ctx) full_params = b''.join([ *ctx.buffer, *ctx.anno_buffer, ]) return full_params, params_id def describe_sql_result( *, schema: s_schema.Schema, row: dict[str, s_types.Type], protocol_version: edbdef.ProtocolVersion, ) -> tuple[bytes, uuid.UUID]: ctx = Context( schema=schema, protocol_version=protocol_version, ) params_buf = [] subtypes = [] element_names = [] for rel_name, rel_t in row.items(): rel_type_id = _describe_type(rel_t, ctx=ctx) # SQLRecordElement.name 
params_buf.append(_string_packer(rel_name)) element_names.append(rel_name) # SQLRecordElement.type params_buf.append(_type_ref_id_packer(rel_type_id, ctx=ctx)) subtypes.append(rel_type_id) rec_id = _get_object_shape_id("SQLRow", subtypes, element_names) record_body_bytes = [ DescriptorTag.SQL_ROW._value_, rec_id.bytes, ] record_body_bytes.extend([ _uint16_packer(len(row)), *params_buf, ]) _finish_typedesc(rec_id, record_body_bytes, ctx=ctx) record = b''.join([ *ctx.buffer, *ctx.anno_buffer, ]) return record, rec_id def describe( schema: s_schema.Schema, typ: s_types.Type, view_shapes: ViewShapeMap = immutables.Map(), view_shapes_metadata: ViewShapeMetadataMap = immutables.Map(), *, protocol_version: edbdef.ProtocolVersion, follow_links: bool = True, inline_typenames: bool = False, name_filter: str = "", ) -> tuple[bytes, uuid.UUID]: ctx = Context( schema=schema, view_shapes=view_shapes, view_shapes_metadata=view_shapes_metadata, protocol_version=protocol_version, follow_links=follow_links, inline_typenames=inline_typenames, name_filter=name_filter, ) type_id = _describe_type(typ, ctx=ctx) out = b''.join(ctx.buffer) + b''.join(ctx.anno_buffer) return out, type_id # # Type descriptor parsing # class ParseContext: def __init__( self, protocol_version: edbdef.ProtocolVersion, ) -> None: self.protocol_version = protocol_version self.codecs_list: list[TypeDesc] = [] def parse( typedesc: bytes, protocol_version: edbdef.ProtocolVersion, ) -> TypeDesc: """Unmarshal a byte stream with one or more type descriptors.""" ctx = ParseContext(protocol_version) buf = io.BytesIO(typedesc) wrapped = binwrapper.BinWrapper(buf) while buf.tell() < len(typedesc): _parse(wrapped, ctx=ctx) if not ctx.codecs_list: raise errors.InternalServerError('could not parse type descriptor') return ctx.codecs_list[-1] def _parse(desc: binwrapper.BinWrapper, ctx: ParseContext) -> None: """Unmarshal the next type descriptor from the byte stream.""" if ctx.protocol_version >= (2, 0): # .length 
desc.read_bytes(4) t = desc.read_bytes(1) try: tag = DescriptorTag(t) except ValueError: if (t[0] >= 0x80 and t[0] <= 0xff): # Ignore all type annotations. _parse_string(desc) return else: raise NotImplementedError( f'no codec implementation for Gel data kind {hex(t[0])}') else: ctx.codecs_list.append(_parse_descriptor(tag, desc, ctx=ctx)) # # Parsing helpers # def _parse_type_id(desc: binwrapper.BinWrapper) -> uuid.UUID: return uuidgen.from_bytes(desc.read_bytes(16)) def _parse_bool(desc: binwrapper.BinWrapper) -> bool: return bool(desc.read_bytes(1)[0]) def _parse_string(desc: binwrapper.BinWrapper) -> str: b = desc.read_len32_prefixed_bytes() try: return b.decode("utf-8") except UnicodeDecodeError as e: raise errors.InternalServerError( f"malformed type descriptor: invalid UTF-8 string " f"at stream position {desc.tell()}") from e def _parse_strings(desc: binwrapper.BinWrapper) -> list[str]: num = desc.read_ui16() return [_parse_string(desc) for _ in range(num)] def _parse_type_ref( desc: binwrapper.BinWrapper, *, ctx: ParseContext, ) -> TypeDesc: offset = desc.read_ui16() try: return ctx.codecs_list[offset] except IndexError: raise errors.InternalServerError( f"malformed type descriptor: dangling type reference: {offset} " f"at stream position {desc.tell()}") from None def _parse_type_refs( desc: binwrapper.BinWrapper, *, ctx: ParseContext, ) -> list[TypeDesc]: els = desc.read_ui16() return [_parse_type_ref(desc, ctx=ctx) for _ in range(els)] # # Parsing dispatch.
# @value_dispatch.value_dispatch def _parse_descriptor( tag: DescriptorTag, desc: binwrapper.BinWrapper, ctx: ParseContext, ) -> TypeDesc: raise AssertionError( f'no codec implementation for Gel data kind {tag._name_}') @_parse_descriptor.register(DescriptorTag.SET) def _parse_set_descriptor( _tag: DescriptorTag, desc: binwrapper.BinWrapper, ctx: ParseContext, ) -> SetDesc: # .id tid = _parse_type_id(desc) # .type subtype = _parse_type_ref(desc, ctx=ctx) return SetDesc(tid=tid, subtype=subtype) @_parse_descriptor.register(DescriptorTag.OBJECT) def _parse_object_descriptor( _tag: DescriptorTag, desc: binwrapper.BinWrapper, ctx: ParseContext, ) -> ObjectDesc: if ctx.protocol_version < (2, 0): raise errors.ProtocolError( "unexpected ObjectTypeDescriptor in protocol " f"{ctx.protocol_version[0]}.{ctx.protocol_version[1]}") # .id tid = _parse_type_id(desc) # .name name = _parse_string(desc) # .schema_defined schema_defined = _parse_bool(desc) return ObjectDesc(tid=tid, name=name, schema_defined=schema_defined) @_parse_descriptor.register(DescriptorTag.COMPOUND) def _parse_compound_descriptor( _tag: DescriptorTag, desc: binwrapper.BinWrapper, ctx: ParseContext, ) -> CompoundDesc: if ctx.protocol_version < (2, 0): raise errors.ProtocolError( "unexpected CompoundTypeDescriptor in protocol " f"{ctx.protocol_version[0]}.{ctx.protocol_version[1]}") # .id tid = _parse_type_id(desc) # .name name = _parse_string(desc) # .schema_defined schema_defined = _parse_bool(desc) # .op op_byte = desc.read_ui8() try: op = CompoundOp(op_byte) except ValueError: raise errors.ProtocolError( f"unexpected op in CompoundTypeDescriptor: {hex(op_byte)}" ) # .components components = _parse_type_refs(desc, ctx=ctx) return CompoundDesc( tid=tid, name=name, schema_defined=schema_defined, op=op, components=components, ) @_parse_descriptor.register(DescriptorTag.SHAPE) def _parse_shape_descriptor( _tag: DescriptorTag, desc: binwrapper.BinWrapper, ctx: ParseContext, ) -> ShapeDesc: # .id tid = 
_parse_type_id(desc) objtype = None if ctx.protocol_version >= (2, 0): # .ephemeral_free_shape ephemeral_free_shape = _parse_bool(desc) if ephemeral_free_shape: desc.read_ui16() else: objtype = _parse_type_ref(desc, ctx=ctx) # .element_count els = desc.read_ui16() # .elements fields = {} flags = {} cardinalities = {} sources = {} for _ in range(els): # ShapeElement.flags flag = desc.read_ui32() # ShapeElement.cardinality cardinality = enums.Cardinality(desc.read_bytes(1)[0]) # ShapeElement.name name = _parse_string(desc) # ShapeElement.type subtype = _parse_type_ref(desc, ctx=ctx) if ctx.protocol_version >= (2, 0): # ShapeElement.source_type sources[name] = _parse_type_ref(desc, ctx=ctx) fields[name] = subtype flags[name] = flag if cardinality: cardinalities[name] = cardinality return ShapeDesc( tid=tid, type=objtype, flags=flags, fields=fields, cardinalities=cardinalities, sources=sources, ) @_parse_descriptor.register(DescriptorTag.INPUT_SHAPE) def _parse_input_shape_descriptor( _tag: DescriptorTag, desc: binwrapper.BinWrapper, ctx: ParseContext, ) -> InputShapeDesc: # .id tid = _parse_type_id(desc) # .element_count els = desc.read_ui16() # .elements input_fields = {} flags = {} cardinalities = {} fields_list = [] for idx in range(els): # ShapeElement.flags flag = desc.read_ui32() # ShapeElement.cardinality cardinality = enums.Cardinality(desc.read_bytes(1)[0]) # ShapeElement.name name = _parse_string(desc) # ShapeElement.type subtype = _parse_type_ref(desc, ctx=ctx) fields_list.append((name, subtype)) input_fields[name] = idx, subtype flags[name] = flag if cardinality: cardinalities[name] = cardinality return InputShapeDesc( fields_list=fields_list, tid=tid, flags=flags, fields=input_fields, cardinalities=cardinalities, ) @_parse_descriptor.register(DescriptorTag.BASE_SCALAR) def _parse_base_scalar_descriptor( _tag: DescriptorTag, desc: binwrapper.BinWrapper, ctx: ParseContext, ) -> BaseScalarDesc: if ctx.protocol_version >= (2, 0): raise errors.ProtocolError( 
"unexpected BaseScalarDescriptor in protocol " f"{ctx.protocol_version[0]}.{ctx.protocol_version[1]}") # .id tid = _parse_type_id(desc) return BaseScalarDesc(tid=tid) @_parse_descriptor.register(DescriptorTag.SCALAR) def _parse_scalar_descriptor( _tag: DescriptorTag, desc: binwrapper.BinWrapper, ctx: ParseContext, ) -> ScalarDesc: # .id tid = _parse_type_id(desc) if ctx.protocol_version >= (2, 0): # .name name = _parse_string(desc) # .schema_defined schema_defined = _parse_bool(desc) # .ancestors ancestors = _parse_type_refs(desc, ctx=ctx) if ancestors: fundamental_type = ancestors[-1] else: fundamental_type = None else: name = None schema_defined = None fundamental_type = _parse_type_ref(desc, ctx=ctx) ancestors = None return ScalarDesc( tid=tid, name=name, schema_defined=schema_defined, fundamental_type=fundamental_type, ancestors=ancestors, ) @_parse_descriptor.register(DescriptorTag.TUPLE) def _parse_tuple_descriptor( _tag: DescriptorTag, desc: binwrapper.BinWrapper, ctx: ParseContext, ) -> TupleDesc: # .id tid = _parse_type_id(desc) if ctx.protocol_version >= (2, 0): # .name name = _parse_string(desc) # .schema_defined schema_defined = _parse_bool(desc) # .ancestors ancestors = _parse_type_refs(desc, ctx=ctx) else: name = None schema_defined = None ancestors = None # .element_count # .elements tuple_fields = _parse_type_refs(desc, ctx=ctx) return TupleDesc( tid=tid, name=name, schema_defined=schema_defined, ancestors=ancestors, fields=tuple_fields, ) @_parse_descriptor.register(DescriptorTag.NAMEDTUPLE) def _parse_namedtuple_descriptor( _tag: DescriptorTag, desc: binwrapper.BinWrapper, ctx: ParseContext, ) -> NamedTupleDesc: # .id tid = _parse_type_id(desc) if ctx.protocol_version >= (2, 0): # .name name = _parse_string(desc) # .schema_defined schema_defined = _parse_bool(desc) # .ancestors ancestors = _parse_type_refs(desc, ctx=ctx) else: name = None schema_defined = None ancestors = None # .element_count els = desc.read_ui16() fields = {} for _ in 
range(els): # TupleElement.name el_name = _parse_string(desc) # TupleElement.type fields[el_name] = _parse_type_ref(desc, ctx=ctx) return NamedTupleDesc( tid=tid, name=name, schema_defined=schema_defined, ancestors=ancestors, fields=fields, ) @_parse_descriptor.register(DescriptorTag.ENUM) def _parse_enum_descriptor( _tag: DescriptorTag, desc: binwrapper.BinWrapper, ctx: ParseContext, ) -> EnumDesc: # .id tid = _parse_type_id(desc) if ctx.protocol_version >= (2, 0): # .name name = _parse_string(desc) # .schema_defined schema_defined = _parse_bool(desc) # .ancestors ancestors = _parse_type_refs(desc, ctx=ctx) else: name = None schema_defined = None ancestors = None # .member_count # .members names = _parse_strings(desc) return EnumDesc( tid=tid, name=name, schema_defined=schema_defined, ancestors=ancestors, names=names, ) @_parse_descriptor.register(DescriptorTag.ARRAY) def _parse_array_descriptor( _tag: DescriptorTag, desc: binwrapper.BinWrapper, ctx: ParseContext, ) -> ArrayDesc: # .id tid = _parse_type_id(desc) if ctx.protocol_version >= (2, 0): # .name name = _parse_string(desc) # .schema_defined schema_defined = _parse_bool(desc) # .ancestors ancestors = _parse_type_refs(desc, ctx=ctx) else: name = None schema_defined = None ancestors = None # .type subtype = _parse_type_ref(desc, ctx=ctx) # .dimension_count els = desc.read_ui16() if els != 1: raise NotImplementedError( 'cannot handle arrays with more than one dimension') # .dimensions dim_len = desc.read_i32() if dim_len != -1: raise NotImplementedError( 'cannot handle arrays with non-infinite dimensions') return ArrayDesc( tid=tid, name=name, schema_defined=schema_defined, ancestors=ancestors, dim_len=dim_len, subtype=subtype, ) @_parse_descriptor.register(DescriptorTag.RANGE) def _parse_range_descriptor( _tag: DescriptorTag, desc: binwrapper.BinWrapper, ctx: ParseContext, ) -> RangeDesc: # .id tid = _parse_type_id(desc) if ctx.protocol_version >= (2, 0): # .name name = _parse_string(desc) # .schema_defined 
schema_defined = _parse_bool(desc) # .ancestors ancestors = _parse_type_refs(desc, ctx=ctx) else: name = None schema_defined = None ancestors = None # .type subtype = _parse_type_ref(desc, ctx=ctx) return RangeDesc( tid=tid, name=name, schema_defined=schema_defined, ancestors=ancestors, inner=subtype, ) @_parse_descriptor.register(DescriptorTag.MULTIRANGE) def _parse_multirange_descriptor( _tag: DescriptorTag, desc: binwrapper.BinWrapper, ctx: ParseContext, ) -> MultiRangeDesc: # .id tid = _parse_type_id(desc) if ctx.protocol_version >= (2, 0): # .name name = _parse_string(desc) # .schema_defined schema_defined = _parse_bool(desc) # .ancestors ancestors = _parse_type_refs(desc, ctx=ctx) else: name = None schema_defined = None ancestors = None # .type subtype = _parse_type_ref(desc, ctx=ctx) return MultiRangeDesc( tid=tid, name=name, schema_defined=schema_defined, ancestors=ancestors, inner=subtype, ) def _make_global_rep(typ: s_types.Type, ctx: Context) -> object: if isinstance(typ, s_types.Tuple): subtyps = typ.get_subtypes(ctx.schema) return ( int(enums.TypeTag.TUPLE), tuple(subtyp.id for subtyp in subtyps), tuple(_make_global_rep(subtyp, ctx) for subtyp in subtyps), ) elif isinstance(typ, s_types.Array): subtyp = typ.get_element_type(ctx.schema) return ( int(enums.TypeTag.ARRAY), subtyp.id, _make_global_rep(subtyp, ctx)) else: return None class StateSerializerFactory: def __init__(self, std_schema: s_schema.Schema, config_spec: config.Spec): """ { module := 'default', aliases := [ ('alias', 'module::target'), ... ], config := cfg::Config { session_idle_transaction_timeout: '0:05:00', query_execution_timeout: '0:00:00', allow_bare_ddl: AlwaysAllow, apply_access_policies: true, }, globals := { key := value, ... 
}, } """ schema = std_schema str_type = schema.get('std::str', type=s_scalars.ScalarType) free_obj = schema.get('std::FreeObject', type=s_objtypes.ObjectType) schema, self._state_type = derive_alias(schema, free_obj, 'state_type') # aliases := { ('alias1', 'mod::type'), ... } schema, alias_tuple = s_types.Tuple.from_subtypes( schema, [str_type, str_type]) schema, aliases_array = s_types.Array.from_subtypes( schema, [alias_tuple]) schema, self.globals_type = derive_alias( schema, free_obj, 'state_globals') # config := cfg::Config { session_cfg1, session_cfg2, ... } schema, config_type = derive_alias( schema, free_obj, 'state_config' ) config_shape = self._make_config_shape(config_spec, schema) # Build type descriptors and codecs for compiler RPC # comp_config := cfg::Config { comp_cfg1, comp_cfg2, ... } schema, self._comp_config_type = derive_alias( schema, free_obj, 'comp_config' ) self._comp_config_shape: tuple[InputShapeElement, ...] = ( self._make_config_shape( config_spec, schema, lambda setting: setting.affects_compilation, ) ) self._input_shapes: immutables.Map[ s_types.Type, tuple[InputShapeElement, ...], ] = immutables.Map([ (config_type, config_shape), (self._state_type, ( ("module", str_type, enums.Cardinality.AT_MOST_ONE), ("aliases", aliases_array, enums.Cardinality.AT_MOST_ONE), ("config", config_type, enums.Cardinality.AT_MOST_ONE), )) ]) self.config_type = config_type self._schema = schema self._contexts: dict[edbdef.ProtocolVersion, Context] = {} @staticmethod def _make_config_shape( config_spec: config.Spec, schema: s_schema.Schema, matches: Callable[[Any], bool] = lambda setting: not setting.system, ) -> tuple[InputShapeElement, ...]: config_shape: list[InputShapeElement] = [] for setting in config_spec.values(): if matches(setting): setting_type_name = setting.schema_type_name setting_type = schema.get(setting_type_name, type=s_types.Type) config_shape.append( ( setting.name, setting_type, enums.Cardinality.MANY if setting.set_of else 
enums.Cardinality.AT_MOST_ONE, ) ) return tuple(sorted(config_shape)) def make( self, user_schema: s_schema.Schema, global_schema: s_schema.Schema, protocol_version: edbdef.ProtocolVersion, ) -> StateSerializer: ctx = self._contexts.get(protocol_version) if ctx is None: ctx = Context( schema=self._schema, protocol_version=protocol_version, ) self._contexts[protocol_version] = ctx describe_input_shape( self._state_type, self._input_shapes, prepare_state=True, ctx=ctx, ) ctx = ctx.derive() ctx.schema = s_schema.ChainedSchema( self._schema, user_schema, global_schema) # Update the config shape with any extension configs ext_config_spec = config.load_ext_spec_from_schema( user_schema, self._schema) new_config = self._make_config_shape(ext_config_spec, ctx.schema) full_config = self._input_shapes[self.config_type] + new_config globals_shape = [] global_reps = {} # Only look at user_schema for globals, since system defined ones # are only set by the system. for g in user_schema.get_objects(type=s_globals.Global): if g.is_computable(ctx.schema): continue name = str(g.get_name(ctx.schema)) s_type = g.get_target(ctx.schema) if isinstance(s_type, (s_types.Array, s_types.Tuple)): global_reps[name] = _make_global_rep(s_type, ctx) cardinality = cardinality_from_ptr(g, ctx.schema) globals_shape.append((name, s_type, cardinality)) type_id = describe_input_shape( self._state_type, self._input_shapes.update({ self.globals_type: tuple(sorted(globals_shape)), self.config_type: full_config, self._state_type: self._input_shapes[self._state_type] + ( ( "globals", self.globals_type, enums.Cardinality.AT_MOST_ONE, ), ) }), ctx=ctx, ) type_data = b''.join(ctx.buffer) return StateSerializer( type_id, type_data, global_reps, protocol_version ) def make_compilation_config_serializer(self) -> CompilationConfigSerializer: ctx = Context( schema=self._schema, protocol_version=edbdef.CURRENT_PROTOCOL, ) type_id = describe_input_shape( self._comp_config_type, {self._comp_config_type: 
self._comp_config_shape}, ctx=ctx ) type_data = b''.join(ctx.buffer) return CompilationConfigSerializer( type_id, type_data, edbdef.CURRENT_PROTOCOL ) class InputShapeSerializer: def __init__( self, type_id: uuid.UUID, type_data: bytes, protocol_version: edbdef.ProtocolVersion, ) -> None: self._type_id = type_id self._type_data = type_data self._protocol_version = protocol_version @functools.cached_property def _codec(self) -> InputShapeDesc: codec = parse(self._type_data, self._protocol_version) assert isinstance(codec, InputShapeDesc) return codec @property def type_id(self) -> uuid.UUID: return self._type_id def describe(self) -> tuple[uuid.UUID, bytes]: return self._type_id, self._type_data def encode(self, state: Any) -> bytes: return self._codec.encode(state) def decode(self, state: bytes) -> Any: return self._codec.decode(state) class StateSerializer(InputShapeSerializer): def __init__( self, type_id: uuid.UUID, type_data: bytes, global_reps: dict[str, object], protocol_version: edbdef.ProtocolVersion, ) -> None: super().__init__(type_id, type_data, protocol_version) self._global_reps = global_reps @functools.cached_property def _codec(self) -> InputShapeDesc: codec = super()._codec # Global values are used directly in Postgres, so we don't need to # encode/decode them in the I/O server. This feature isn't worth a # separate switch in the type desc, so we just hack it in here.
_, globals_type_desc = codec.fields['globals'] assert isinstance(globals_type_desc, InputShapeDesc) globals_type_desc.__dict__['data_raw'] = True return codec def get_global_type_rep( self, global_name: str, ) -> Optional[object]: return self._global_reps.get(global_name) class CompilationConfigSerializer(InputShapeSerializer): @lru.lru_method_cache(64) def encode_configs( self, *configs: immutables.Map[str, config.SettingValue] | None ) -> bytes: state: dict[str, Any] = {} for conf in configs: if conf is not None: state.update((k, v.value) for k, v in conf.items()) return self.encode(state) def derive_alias( schema: s_schema.Schema, parent: s_objtypes.ObjectType, qualifier: str, ) -> tuple[s_schema.Schema, s_types.InheritingType]: return parent.derive_subtype( schema, name=s_obj.derive_name( schema, qualifier, module='__derived__', parent=parent, ), mark_derived=True, inheritance_refdicts={'pointers'}, attrs={'expr_type': s_types.ExprType.Select}, ) @dataclasses.dataclass(frozen=True, kw_only=True) class TypeDesc: tid: uuid.UUID def encode(self, data: Any) -> bytes: raise NotImplementedError def decode(self, data: bytes) -> Any: raise NotImplementedError @dataclasses.dataclass(frozen=True, kw_only=True) class SequenceDesc(TypeDesc): subtype: TypeDesc impl: ClassVar[type[s_obj.CollectionFactory[Any]]] def encode(self, data: collections.abc.Collection[Any]) -> bytes: if not data: return b''.join(( _uint32_packer(0), _uint32_packer(0), _uint32_packer(0), )) bufs = [ _uint32_packer(1), _uint32_packer(0), _uint32_packer(0), _uint32_packer(len(data)), _uint32_packer(1), ] for item in data: if item is None: bufs.append(_int32_packer(-1)) else: item_bytes = self.subtype.encode(item) bufs.append(_uint32_packer(len(item_bytes))) bufs.append(item_bytes) return b''.join(bufs) def decode(self, data: bytes) -> collections.abc.Collection[Any]: buf = io.BytesIO(data) wrapped = binwrapper.BinWrapper(buf) ndims = wrapped.read_ui32() if ndims == 0: return self.impl() assert ndims == 
1 wrapped.read_ui32() wrapped.read_ui32() data_len = wrapped.read_ui32() assert wrapped.read_ui32() == 1 return self.impl( self.subtype.decode(wrapped.read_len32_prefixed_bytes()) for _ in range(data_len) ) @dataclasses.dataclass(frozen=True, kw_only=True) class SetDesc(SequenceDesc): impl = frozenset @dataclasses.dataclass(frozen=True, kw_only=True) class ShapeDesc(TypeDesc): type: Optional[TypeDesc] fields: dict[str, TypeDesc] flags: dict[str, int] cardinalities: dict[str, enums.Cardinality] sources: dict[str, TypeDesc] @dataclasses.dataclass(frozen=True, kw_only=True) class SchemaTypeDesc(TypeDesc): name: Optional[str] = None schema_defined: Optional[bool] = None @dataclasses.dataclass(frozen=True, kw_only=True) class ObjectDesc(SchemaTypeDesc): pass @dataclasses.dataclass(frozen=True, kw_only=True) class CompoundDesc(SchemaTypeDesc): op: CompoundOp components: list[TypeDesc] @dataclasses.dataclass(frozen=True, kw_only=True) class BaseScalarDesc(SchemaTypeDesc): codecs: ClassVar[dict[ uuid.UUID, tuple[Callable[[Any], bytes], Callable[[bytes], Any]] ]] = { s_obj.get_known_type_id('std::duration'): ( statypes.Duration.encode, statypes.Duration.decode, ), s_obj.get_known_type_id('std::str'): ( _encode_str, _decode_str, ), s_obj.get_known_type_id('std::bool'): ( _encode_bool, _decode_bool, ), s_obj.get_known_type_id('std::int64'): ( _encode_int64, _decode_int64, ), s_obj.get_known_type_id('std::float32'): ( _encode_float32, _decode_float32, ), } def encode(self, data: Any) -> bytes: if codecs := self.codecs.get(self.tid): return codecs[0](data) raise NotImplementedError def decode(self, data: bytes) -> Any: if codecs := self.codecs.get(self.tid): return codecs[1](data) raise NotImplementedError @dataclasses.dataclass(frozen=True, kw_only=True) class ScalarDesc(BaseScalarDesc): fundamental_type: Optional[TypeDesc] ancestors: Optional[list[TypeDesc]] @dataclasses.dataclass(frozen=True, kw_only=True) class NamedTupleDesc(SchemaTypeDesc): fields: dict[str, TypeDesc] 
ancestors: Optional[list[TypeDesc]] @dataclasses.dataclass(frozen=True, kw_only=True) class TupleDesc(SchemaTypeDesc): fields: list[TypeDesc] ancestors: Optional[list[TypeDesc]] def encode(self, data: collections.abc.Sequence[Any]) -> bytes: bufs = [_uint32_packer(len(self.fields))] for idx, desc in enumerate(self.fields): bufs.append(_uint32_packer(0)) item = desc.encode(data[idx]) bufs.append(_uint32_packer(len(item))) bufs.append(item) return b''.join(bufs) def decode(self, data: bytes) -> tuple[Any, ...]: buf = io.BytesIO(data) wrapped = binwrapper.BinWrapper(buf) assert wrapped.read_ui32() == len(self.fields) rv = [] for desc in self.fields: wrapped.read_ui32() rv.append(desc.decode(wrapped.read_len32_prefixed_bytes())) return tuple(rv) @dataclasses.dataclass(frozen=True, kw_only=True) class EnumDesc(SchemaTypeDesc): names: list[str] ancestors: Optional[list[TypeDesc]] @functools.cached_property def _decoder(self) -> Callable[[bytes], Any]: assert self.name is not None pytype = statypes.maybe_get_python_type_for_scalar_type_name(self.name) if pytype is not None and issubclass(pytype, statypes.ScalarType): return pytype.decode else: return _decode_str @functools.cached_property def _encoder(self) -> Callable[[Any], bytes]: assert self.name is not None pytype = statypes.maybe_get_python_type_for_scalar_type_name(self.name) if pytype is not None and issubclass(pytype, statypes.ScalarType): return pytype.encode else: return _encode_str def encode(self, data: Any) -> bytes: return self._encoder(data) def decode(self, data: bytes) -> Any: return self._decoder(data) @dataclasses.dataclass(frozen=True, kw_only=True) class ArrayDesc(SequenceDesc, SchemaTypeDesc): ancestors: Optional[list[TypeDesc]] dim_len: int impl = list @dataclasses.dataclass(frozen=True, kw_only=True) class RangeDesc(SchemaTypeDesc): ancestors: Optional[list[TypeDesc]] inner: TypeDesc @dataclasses.dataclass(frozen=True, kw_only=True) class MultiRangeDesc(SchemaTypeDesc): ancestors: 
Optional[list[TypeDesc]] inner: TypeDesc @dataclasses.dataclass(frozen=True, kw_only=True) class InputShapeDesc(TypeDesc): fields: dict[str, tuple[int, TypeDesc]] fields_list: list[tuple[str, TypeDesc]] flags: dict[str, int] cardinalities: dict[str, enums.Cardinality] data_raw: bool = False def encode(self, data: Mapping[str, Any]) -> bytes: bufs = [b''] count = 0 for key, desc_tuple in self.fields.items(): if key not in data: continue value = data[key] idx, desc = desc_tuple bufs.append(_uint32_packer(idx)) if value is None: bufs.append(_int32_packer(-1)) else: if not self.data_raw: value = desc.encode(value) bufs.append(_uint32_packer(len(value))) bufs.append(value) count += 1 bufs[0] = _uint32_packer(count) return b''.join(bufs) def decode(self, data: bytes) -> dict[str, Any]: rv = {} buf = io.BytesIO(data) wrapped = binwrapper.BinWrapper(buf) for _ in range(wrapped.read_ui32()): idx = wrapped.read_ui32() name, desc = self.fields_list[idx] item_data = wrapped.read_nullable_len32_prefixed_bytes() if item_data is None: cardinality = self.cardinalities.get(name) if cardinality == enums.Cardinality.ONE: raise errors.CardinalityViolationError( f"State '{name}' expects exactly 1 value, 0 given" ) elif cardinality == enums.Cardinality.AT_LEAST_ONE: raise errors.CardinalityViolationError( f"State '{name}' expects at least 1 value, 0 given" ) if self.data_raw or item_data is None: rv[name] = item_data else: rv[name] = desc.decode(item_data) return rv ================================================ FILE: edb/server/compiler/sql.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Mapping, Sequence, TYPE_CHECKING, Optional, Any import dataclasses import functools import hashlib import immutables import json from edb import errors from edb.common import ast from edb.common import uuidgen from edb.common import debug from edb.server import defines from edb.schema import schema as s_schema from edb.pgsql import ast as pgast from edb.pgsql import common as pg_common from edb.pgsql import codegen as pg_codegen from edb.pgsql import params as pg_params from edb.pgsql import parser as pg_parser from . import dbstate from . import enums if TYPE_CHECKING: from edb.pgsql import resolver as pg_resolver # Frontend-only settings. Maps each setting name to its mutability flag. FE_SETTINGS_MUTABLE: immutables.Map[str, bool] = immutables.Map( { 'search_path': True, 'allow_user_specified_id': True, 'apply_access_policies_pg': True, 'server_version': False, 'server_version_num': False, } ) class DisableNormalization(BaseException): # An exception that indicates that the compiler cannot work with this query # because the constants have been extracted and replaced with parameters. # When raised, the query will be recompiled without normalization.
pass def compile_sql( source: pg_parser.Source, *, schema: s_schema.Schema, tx_state: dbstate.SQLTransactionState, prepared_stmt_map: Mapping[str, str], current_database: str, allow_user_specified_id: Optional[bool], apply_access_policies: Optional[bool], include_edgeql_io_format_alternative: bool = False, allow_prepared_statements: bool = True, disambiguate_column_names: bool, backend_runtime_params: pg_params.BackendRuntimeParams, protocol_version: defines.ProtocolVersion, implicit_limit: Optional[int] = None, ) -> tuple[list[dbstate.SQLQueryUnit], bool]: def _try( q: str, normalized_params: list[int] ) -> list[dbstate.SQLQueryUnit]: return _compile_sql( q, orig_query_str=source.original_text(), schema=schema, tx_state=tx_state, prepared_stmt_map=prepared_stmt_map, current_database=current_database, allow_user_specified_id=allow_user_specified_id, apply_access_policies=apply_access_policies, include_edgeql_io_format_alternative=( include_edgeql_io_format_alternative), allow_prepared_statements=allow_prepared_statements, disambiguate_column_names=disambiguate_column_names, backend_runtime_params=backend_runtime_params, protocol_version=protocol_version, normalized_params=normalized_params, implicit_limit=implicit_limit, ) normalized_params = list(source.extra_type_oids()) try: try: return _try(source.text(), normalized_params), False except DisableNormalization: # compiler requested non-normalized query (it needs it for static # evaluation) try: if isinstance(source, pg_parser.NormalizedSource): units = _try(source.original_text(), []) # Unit isn't cacheable, since the key is the # extracted version. # TODO: Can we tell the server to cache using non-extracted? 
for unit in units: unit.cacheable = False return units, True except DisableNormalization: pass raise AssertionError( "compiler is requesting query normalization to be disabled, " "but it already is disabled" ) except errors.EdgeDBError as original_err: if isinstance(source, pg_parser.NormalizedSource): # try non-normalized source try: _try(source.original_text(), []) except errors.EdgeDBError as denormalized_err: raise denormalized_err except Exception: raise original_err else: raise AssertionError( "Normalized query is broken while original is valid") else: raise original_err def _build_constant_extraction_map( src: pgast.Base, out: pgast.Base, ) -> pg_codegen.BaseSourceMap: """Traverse two ASTs in parallel and build a source map between them. The ASTs should *mostly* line up. When they don't, that is considered a leaf. This is used to translate SQL spans reported on a normalized query to ones that make sense on the pre-normalization version. Note that we only use this map for errors reported during the "parse" phase, so we don't need to worry about it being reused with different constants. """ tdata = pg_codegen.BaseSourceMap( source_start=src.span.start if src.span else 0, # HACK: I don't know why, but this - 1 helps a lot.
output_start=out.span.start - 1 if out.span else 0, ) if type(src) is not type(out): return tdata children = tdata.children for (k1, v1), (k2, v2) in zip(ast.iter_fields(src), ast.iter_fields(out)): assert k1 == k2 if isinstance(v1, pgast.Base) and isinstance(v2, pgast.Base): children.append(_build_constant_extraction_map(v1, v2)) elif ( isinstance(v1, (tuple, list)) and isinstance(v2, (tuple, list)) ): for v1e, v2e in zip(v1, v2): if isinstance(v1e, pgast.Base) and isinstance(v2e, pgast.Base): children.append(_build_constant_extraction_map(v1e, v2e)) elif ( isinstance(v1, dict) and isinstance(v2, dict) ): for k, v1e in v1.items(): v2e = v2.get(k) if isinstance(v1e, pgast.Base) and isinstance(v2e, pgast.Base): children.append(_build_constant_extraction_map(v1e, v2e)) children.sort(key=lambda k: k.output_start) return tdata def _compile_sql( query_str: str, *, orig_query_str: Optional[str] = None, schema: s_schema.Schema, tx_state: dbstate.SQLTransactionState, prepared_stmt_map: Mapping[str, str], current_database: str, allow_user_specified_id: Optional[bool], apply_access_policies: Optional[bool], include_edgeql_io_format_alternative: bool = False, allow_prepared_statements: bool = True, disambiguate_column_names: bool, backend_runtime_params: pg_params.BackendRuntimeParams, protocol_version: defines.ProtocolVersion, normalized_params: list[int], implicit_limit: Optional[int] = None, ) -> list[dbstate.SQLQueryUnit]: opts = ResolverOptionsPartial( query_str=query_str, current_database=current_database, allow_user_specified_id=allow_user_specified_id, apply_access_policies=apply_access_policies, include_edgeql_io_format_alternative=( include_edgeql_io_format_alternative ), disambiguate_column_names=disambiguate_column_names, normalized_params=normalized_params, implicit_limit=implicit_limit, ) # orig_stmts are the statements prior to constant extraction stmts = pg_parser.parse(query_str, propagate_spans=True) if orig_query_str and orig_query_str != query_str: 
orig_stmts = pg_parser.parse(orig_query_str, propagate_spans=True) else: orig_stmts = stmts sql_units = [] for stmt, orig_stmt in zip(stmts, orig_stmts): orig_text = pg_codegen.generate_source(stmt) fe_settings = tx_state.current_fe_settings() track_stats = False extract_data = _build_constant_extraction_map(orig_stmt, stmt) unit = dbstate.SQLQueryUnit( orig_query=pg_codegen.generate_source(orig_stmt), fe_settings=fe_settings, # by default, the query is sent to PostgreSQL unchanged query=orig_text, ) if isinstance(stmt, (pgast.VariableSetStmt, pgast.VariableResetStmt)): if protocol_version != defines.POSTGRES_PROTOCOL: from edb.pgsql import resolver as pg_resolver pg_resolver.dispatch._raise_unsupported(stmt) value: Optional[dbstate.SQLSetting] if isinstance(stmt, pgast.VariableSetStmt): value = pg_arg_list_to_python(stmt.args) else: value = None fe_only = stmt.name and ( # GOTCHA: setting is frontend-only regardless of its mutability stmt.name in FE_SETTINGS_MUTABLE or stmt.name.startswith('global ') ) if fe_only: assert stmt.name if not FE_SETTINGS_MUTABLE.get(stmt.name, True): raise errors.QueryError( f'parameter "{stmt.name}" cannot be changed', pgext_code='55P02', # cant_change_runtime_param ) unit.set_vars = {stmt.name: value} unit.frontend_only = True unit.command_complete_tag = dbstate.TagPlain( tag=( b"SET" if isinstance(stmt, pgast.VariableSetStmt) else b"RESET" ) ) elif stmt.scope == pgast.OptionsScope.SESSION: unit.set_vars = {stmt.name: value} unit.is_local = stmt.scope == pgast.OptionsScope.TRANSACTION if not unit.is_local: unit.capabilities |= enums.Capability.SESSION_CONFIG unit.capabilities |= enums.Capability.SQL_SESSION_CONFIG elif isinstance(stmt, pgast.VariableShowStmt): if protocol_version != defines.POSTGRES_PROTOCOL: from edb.pgsql import resolver as pg_resolver pg_resolver.dispatch._raise_unsupported(stmt) stmt = _compile_show_command(stmt) unit.query = pg_codegen.generate_source(stmt) unit.command_complete_tag = 
dbstate.TagPlain(tag=b"SHOW") elif isinstance(stmt, pgast.SetTransactionStmt): if protocol_version != defines.POSTGRES_PROTOCOL: from edb.pgsql import resolver as pg_resolver pg_resolver.dispatch._raise_unsupported(stmt) if stmt.scope == pgast.OptionsScope.SESSION: unit.set_vars = { f"default_{name}": ( ( value.val if isinstance(value, pgast.StringConstant) else pg_codegen.generate_source(value) ), ) for name, value in stmt.options.options.items() } unit.capabilities |= enums.Capability.SQL_SESSION_CONFIG elif isinstance(stmt, (pgast.BeginStmt, pgast.StartStmt)): unit.tx_action = dbstate.TxAction.START unit.command_complete_tag = dbstate.TagPlain( tag=( b"START TRANSACTION" if isinstance(stmt, pgast.StartStmt) else b"BEGIN" ) ) elif isinstance(stmt, pgast.CommitStmt): unit.tx_action = dbstate.TxAction.COMMIT unit.tx_chain = stmt.chain or False unit.command_complete_tag = dbstate.TagPlain(tag=b"COMMIT") elif isinstance(stmt, pgast.RollbackStmt): unit.tx_action = dbstate.TxAction.ROLLBACK unit.tx_chain = stmt.chain or False unit.command_complete_tag = dbstate.TagPlain(tag=b"ROLLBACK") elif isinstance(stmt, pgast.SavepointStmt): unit.tx_action = dbstate.TxAction.DECLARE_SAVEPOINT unit.sp_name = stmt.savepoint_name unit.command_complete_tag = dbstate.TagPlain(tag=b"SAVEPOINT") elif isinstance(stmt, pgast.ReleaseStmt): unit.tx_action = dbstate.TxAction.RELEASE_SAVEPOINT unit.sp_name = stmt.savepoint_name unit.command_complete_tag = dbstate.TagPlain(tag=b"RELEASE") elif isinstance(stmt, pgast.RollbackToStmt): unit.tx_action = dbstate.TxAction.ROLLBACK_TO_SAVEPOINT unit.sp_name = stmt.savepoint_name unit.command_complete_tag = dbstate.TagPlain(tag=b"ROLLBACK") elif isinstance(stmt, pgast.TwoPhaseTransactionStmt): raise NotImplementedError( "two-phase transactions are not supported" ) elif isinstance(stmt, pgast.PrepareStmt): if not allow_prepared_statements: raise errors.UnsupportedFeatureError( "SQL prepared statements are not supported" ) if not isinstance(stmt.query, 
(pgast.Query, pgast.CopyStmt)): from edb.pgsql import resolver as pg_resolver pg_resolver.dispatch._raise_unsupported(stmt.query) # Translate the underlying query. stmt_resolved, stmt_source, _ = resolve_query( stmt.query, schema, tx_state, opts ) if stmt.argtypes: param_types = [] for pt in stmt.argtypes: param_types.append(pg_codegen.generate_source(pt)) param_text = f"({', '.join(param_types)})" else: param_text = "" sql_trailer = f"{param_text} AS ({stmt_source.text})" mangled_stmt_name = compute_stmt_name( f"PREPARE {pg_common.quote_ident(stmt.name)}{sql_trailer}", tx_state, ) sql_text = ( f"PREPARE {pg_common.quote_ident(mangled_stmt_name)}" f"{sql_trailer}" ) unit.query = sql_text unit.prepare = dbstate.PrepareData( stmt_name=stmt.name, be_stmt_name=mangled_stmt_name.encode("utf-8"), query=stmt_source.text, source_map=stmt_source.source_map, ) unit.command_complete_tag = dbstate.TagPlain(tag=b"PREPARE") unit.capabilities |= stmt_resolved.capabilities track_stats = True elif isinstance(stmt, pgast.ExecuteStmt): if not allow_prepared_statements: raise errors.UnsupportedFeatureError( "SQL prepared statements are not supported" ) orig_name = stmt.name mangled_name = prepared_stmt_map.get(orig_name) if not mangled_name: raise errors.QueryError( f"prepared statement \"{orig_name}\" does " f"not exist", pgext_code='26000', # invalid_sql_statement_name ) stmt.name = mangled_name unit.query = pg_codegen.generate_source(stmt) unit.execute = dbstate.ExecuteData( stmt_name=orig_name, be_stmt_name=mangled_name.encode("utf-8"), ) unit.cardinality = enums.Cardinality.MANY track_stats = True elif isinstance(stmt, pgast.DeallocateStmt): if not allow_prepared_statements: raise errors.UnsupportedFeatureError( "SQL prepared statements are not supported" ) orig_name = stmt.name mangled_name = prepared_stmt_map.get(orig_name) if not mangled_name: raise errors.QueryError( f"prepared statement \"{orig_name}\" does " f"not exist", pgext_code='26000', # invalid_sql_statement_name ) 
stmt.name = mangled_name unit.query = pg_codegen.generate_source(stmt) unit.deallocate = dbstate.DeallocateData( stmt_name=orig_name, be_stmt_name=mangled_name.encode("utf-8"), ) unit.command_complete_tag = dbstate.TagPlain(tag=b"DEALLOCATE") elif isinstance(stmt, pgast.LockStmt): if stmt.mode not in ('ACCESS SHARE', 'ROW SHARE', 'SHARE'): raise NotImplementedError("exclusive lock is not supported") # just ignore unit.query = "DO $$ BEGIN END $$;" elif isinstance(stmt, (pgast.Query, pgast.CopyStmt)): if ( protocol_version != defines.POSTGRES_PROTOCOL and isinstance(stmt, pgast.CopyStmt) ): from edb.pgsql import resolver as pg_resolver pg_resolver.dispatch._raise_unsupported(stmt) stmt_resolved, stmt_source, edgeql_fmt_src = resolve_query( stmt, schema, tx_state, opts ) unit.query = stmt_source.text unit.source_map = stmt_source.source_map if stmt_source.source_map: unit.source_map = ( pg_codegen.ChainedSourceMap([ stmt_source.source_map, extract_data, ]) ) if edgeql_fmt_src is not None: unit.eql_format_query = edgeql_fmt_src.text # We don't do anything with the translation data for # this query, since postgres typically doesn't report # out error positions that didn't get reported during # the "parse" phase. 
unit.command_complete_tag = stmt_resolved.command_complete_tag unit.params = stmt_resolved.params if isinstance(stmt, pgast.DMLQuery) and not stmt.returning_list: unit.cardinality = enums.Cardinality.NO_RESULT else: unit.cardinality = enums.Cardinality.MANY unit.capabilities |= stmt_resolved.capabilities track_stats = True else: from edb.pgsql import resolver as pg_resolver pg_resolver.dispatch._raise_unsupported(stmt) unit.stmt_name = compute_stmt_name(unit.query, tx_state).encode("utf-8") sql_info: dict[str, Any] = {} if track_stats and backend_runtime_params.has_stat_statements: cconfig: dict[str, dbstate.SQLSetting] = { k: v for k, v in fe_settings.items() if k is not None and v is not None and k in FE_SETTINGS_MUTABLE } cconfig.pop('server_version', None) cconfig.pop('server_version_num', None) if allow_user_specified_id is not None: cconfig.setdefault( 'allow_user_specified_id', ('true' if allow_user_specified_id else 'false',), ) if apply_access_policies is not None: cconfig.setdefault( 'apply_access_policies', ('true' if apply_access_policies else 'false',), ) search_path = parse_search_path(cconfig.pop("search_path", ("",))) cconfig = dict(sorted((k, v) for k, v in cconfig.items())) extras = { 'cc': cconfig, # compilation_config 'pv': protocol_version, # protocol_version 'dn': ', '.join(search_path), # default_namespace } sql_info['query'] = orig_text sql_info['type'] = defines.QueryType.SQL sql_info['extras'] = json.dumps(extras) id_hash = hashlib.blake2b(digest_size=16) id_hash.update( json.dumps(sql_info).encode(defines.EDGEDB_ENCODING) ) sql_info['id'] = str(uuidgen.from_bytes(id_hash.digest())) if debug.flags.sql_text_in_sql: sql_info['sql'] = orig_query_str or query_str if sql_info: prefix = ''.join([ '-- ', json.dumps(sql_info), '\n', ]) unit.prefix_len = len(prefix) unit.query = prefix + unit.query if unit.eql_format_query is not None: unit.eql_format_query = prefix + unit.eql_format_query if unit.tx_action is not None: unit.capabilities |=
enums.Capability.TRANSACTION tx_state.apply(unit) sql_units.append(unit) if not sql_units: # Cluvio will try to execute an empty query sql_units.append( dbstate.SQLQueryUnit( orig_query='', query='', fe_settings=tx_state.current_fe_settings(), ) ) return sql_units def _compile_show_command(stmt: pgast.VariableShowStmt) -> pgast.Query: pg_settings_for_show = pgast.RelRangeVar( relation=pgast.Relation( schemaname=pg_common.versioned_schema('edgedbsql'), name='pg_settings_for_show', ) ) if stmt.name.lower() == 'all': return pgast.SelectStmt( target_list=[ pgast.ResTarget( val=pgast.ColumnRef(name=[pgast.Star()]) ) ], from_clause=[pg_settings_for_show], ) elif stmt.name.startswith("global "): # SELECT coalesce( # (SELECT value # FROM _edgecon_state # WHERE name = 'global ..' AND type = 'L' # ), # (SELECT value # FROM _edgecon_state # WHERE name = 'global ..' AND type = 'S' # ) # ) #>> '{}' AS "global .." sublinks: list[pgast.Base] = [ pgast.SubLink( operator=None, expr=pgast.SelectStmt( target_list=[ pgast.ResTarget( val=pgast.ColumnRef(name=['value']), ), ], from_clause=[ pgast.RelRangeVar( relation=pgast.Relation( name='_edgecon_state' ), ), ], where_clause=pgast.Expr( name='AND', lexpr=pgast.Expr( name='=', lexpr=pgast.ColumnRef(name=['name']), rexpr=pgast.StringConstant(val=stmt.name), ), rexpr=pgast.Expr( name='=', lexpr=pgast.ColumnRef(name=['type']), rexpr=pgast.StringConstant(val=typ), ), ), ) ) for typ in ["L", "S"] ] return pgast.SelectStmt( target_list=[ pgast.ResTarget( name=stmt.name, val=pgast.Expr( name='#>>', lexpr=pgast.CoalesceExpr(args=sublinks), rexpr=pgast.StringConstant(val='{}'), ), ) ] ) else: return pgast.SelectStmt( target_list=[ pgast.ResTarget( name=stmt.name, val=pgast.ColumnRef(name=['setting']), ), ], from_clause=[pg_settings_for_show], where_clause=pgast.Expr( name='=', lexpr=pgast.ColumnRef(name=['name']), rexpr=pgast.StringConstant(val=stmt.name), ) ) @dataclasses.dataclass(kw_only=True, eq=False, repr=False) class 
ResolverOptionsPartial: current_database: str query_str: str allow_user_specified_id: Optional[bool] apply_access_policies: Optional[bool] include_edgeql_io_format_alternative: Optional[bool] disambiguate_column_names: bool normalized_params: list[int] implicit_limit: Optional[int] def resolve_query( stmt: pgast.Query | pgast.CopyStmt, schema: s_schema.Schema, tx_state: dbstate.SQLTransactionState, opts: ResolverOptionsPartial, ) -> tuple[ pg_resolver.ResolvedSQL, pg_codegen.SQLSource, Optional[pg_codegen.SQLSource], ]: from edb.pgsql import resolver as pg_resolver search_path: Sequence[str] = ("public",) try: setting = tx_state.get("search_path") except KeyError: setting = None if setting: search_path = parse_search_path(setting) allow_user_specified_id = lookup_bool_setting( tx_state, 'allow_user_specified_id' ) if allow_user_specified_id is None: allow_user_specified_id = opts.allow_user_specified_id if allow_user_specified_id is None: allow_user_specified_id = False apply_access_policies = lookup_bool_setting( tx_state, 'apply_access_policies_pg' ) if apply_access_policies is None: apply_access_policies = opts.apply_access_policies if apply_access_policies is None: apply_access_policies = True options = pg_resolver.Options( current_database=opts.current_database, current_query=opts.query_str, search_path=search_path, allow_user_specified_id=allow_user_specified_id, apply_access_policies=apply_access_policies, include_edgeql_io_format_alternative=( opts.include_edgeql_io_format_alternative ), disambiguate_column_names=opts.disambiguate_column_names, normalized_params=opts.normalized_params, implicit_limit=opts.implicit_limit, ) resolved = pg_resolver.resolve(stmt, schema, options) source = pg_codegen.generate(resolved.ast, with_source_map=True) if resolved.edgeql_output_format_ast is not None: edgeql_format_source = pg_codegen.generate( resolved.edgeql_output_format_ast, with_source_map=True, ) else: edgeql_format_source = None return resolved, source, edgeql_format_source
def lookup_bool_setting( tx_state: dbstate.SQLTransactionState, name: str ) -> Optional[bool]: try: setting = tx_state.get(name) except KeyError: setting = None if setting and setting[0]: return is_setting_truthy(setting[0]) return None def is_setting_truthy(value: str | int | float) -> bool | None: if isinstance(value, int): return value != 0 if isinstance(value, str): value = value.lower() if value == 'o': # ambiguous return None truthy_values = ('on', 'true', 'yes', '1') if any(t.startswith(value) for t in truthy_values): return True falsy_values = ('off', 'false', 'no', '0') if any(t.startswith(value) for t in falsy_values): return False return None def compute_stmt_name(text: str, tx_state: dbstate.SQLTransactionState) -> str: stmt_hash = hashlib.sha1(text.encode("utf-8")) for setting_name in sorted(FE_SETTINGS_MUTABLE): try: setting_value = tx_state.get(setting_name) except KeyError: pass else: stmt_hash.update(f"{setting_name}:{setting_value}".encode("utf-8")) return f"edb{stmt_hash.hexdigest()}" @functools.cache def parse_search_path(search_path_str: list[str | int | float]) -> list[str]: return [part for part in search_path_str if isinstance(part, str)] def pg_arg_list_to_python(expr: pgast.ArgsList) -> dbstate.SQLSetting: return tuple(pg_const_to_python(a) for a in expr.args) def pg_const_to_python(expr: pgast.BaseExpr) -> str | int | float: "Converts a pg const expression into a Python value" if isinstance(expr, pgast.StringConstant): return expr.val if isinstance(expr, pgast.NumericConstant): try: return int(expr.val) except ValueError: return float(expr.val) raise NotImplementedError() ================================================ FILE: edb/server/compiler/status.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2019-present MagicStack Inc. and the EdgeDB authors.
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import functools from edb.edgeql import ast as qlast from edb.edgeql import qltypes @functools.singledispatch def get_status(ql: qlast.Base) -> bytes: raise NotImplementedError( f'cannot get status for the {type(ql).__name__!r} AST node' ) @get_status.register(qlast.CreateObject) def _ddl_create(ql: qlast.CreateObject) -> bytes: return f'CREATE {get_schema_class(ql)}'.encode() @get_status.register(qlast.AlterObject) def _ddl_alter(ql: qlast.AlterObject) -> bytes: return f'ALTER {get_schema_class(ql)}'.encode() @get_status.register(qlast.DropObject) def _ddl_drop(ql: qlast.DropObject) -> bytes: return f'DROP {get_schema_class(ql)}'.encode() def get_schema_class(ql: qlast.ObjectDDL) -> qltypes.SchemaObjectClass: osc = qltypes.SchemaObjectClass match ql: case qlast.DatabaseCommand(flavor='BRANCH'): return osc.BRANCH case qlast.DatabaseCommand(flavor='DATABASE'): return osc.DATABASE case qlast.FutureCommand(): return osc.FUTURE case qlast.ModuleCommand(): return osc.MODULE case qlast.RoleCommand(): return osc.ROLE case qlast.PropertyCommand(): return osc.PROPERTY case qlast.ObjectTypeCommand(): return osc.TYPE case qlast.AliasCommand(): return osc.ALIAS case qlast.GlobalCommand(): return osc.GLOBAL case qlast.PermissionCommand(): return osc.PERMISSION case qlast.LinkCommand(): return osc.LINK case qlast.IndexCommand(): return osc.INDEX case qlast.IndexMatchCommand(): return osc.INDEX_MATCH case
qlast.TriggerCommand(): return osc.TRIGGER case qlast.RewriteCommand(): return osc.REWRITE case qlast.FunctionCommand(): return osc.FUNCTION case qlast.OperatorCommand(): return osc.OPERATOR case qlast.CastCommand(): return osc.CAST case qlast.MigrationCommand(): return osc.MIGRATION case qlast.ExtensionPackageCommand(): return osc.EXTENSION_PACKAGE case qlast.ExtensionPackageMigrationCommand(): return osc.EXTENSION_PACKAGE_MIGRATION case qlast.ExtensionCommand(): return osc.EXTENSION case qlast.ExtensionCommand(): return osc.EXTENSION case qlast.AnnotationCommand(): return osc.ANNOTATION case qlast.PseudoTypeCommand(): return osc.PSEUDO_TYPE case qlast.ScalarTypeCommand(): return osc.SCALAR_TYPE case qlast.ConstraintCommand(): return osc.CONSTRAINT case qlast.AccessPolicyCommand(): # Why is this duplicate here? return osc.ACCESS_POLICY case _: raise AssertionError('unimplemented') @get_status.register(qlast.StartMigration) def _ddl_migr_start(ql: qlast.Base) -> bytes: return b'START MIGRATION' @get_status.register(qlast.CreateMigration) def _ddl_migr_create(ql: qlast.Base) -> bytes: return b'CREATE MIGRATION' @get_status.register(qlast.CommitMigration) def _ddl_migr_commit(ql: qlast.Base) -> bytes: return b'COMMIT MIGRATION' @get_status.register(qlast.DropMigration) def _ddl_migr_drop(ql: qlast.Base) -> bytes: return b'DROP MIGRATION' @get_status.register(qlast.AlterMigration) def _ddl_migr_alter(ql: qlast.Base) -> bytes: return b'ALTER MIGRATION' @get_status.register(qlast.AbortMigration) def _ddl_migr_abort(ql: qlast.Base) -> bytes: return b'ABORT MIGRATION' @get_status.register(qlast.PopulateMigration) def _ddl_migr_populate(ql: qlast.Base) -> bytes: return b'POPULATE MIGRATION' @get_status.register(qlast.DescribeCurrentMigration) def _ddl_migr_describe_current(ql: qlast.Base) -> bytes: return b'DESCRIBE CURRENT MIGRATION' @get_status.register(qlast.AlterCurrentMigrationRejectProposed) def _ddl_migr_alter_current(ql: qlast.Base) -> bytes: return b'ALTER CURRENT 
MIGRATION' @get_status.register(qlast.StartMigrationRewrite) def _ddl_migr_rw_start(ql: qlast.Base) -> bytes: return b'START MIGRATION REWRITE' @get_status.register(qlast.CommitMigrationRewrite) def _ddl_migr_rw_commit(ql: qlast.Base) -> bytes: return b'COMMIT MIGRATION REWRITE' @get_status.register(qlast.AbortMigrationRewrite) def _ddl_migr_rw_abort(ql: qlast.Base) -> bytes: return b'ABORT MIGRATION REWRITE' @get_status.register(qlast.ResetSchema) def _ddl_migr_reset_schema(ql: qlast.Base) -> bytes: return b'RESET SCHEMA' @get_status.register(qlast.SelectQuery) @get_status.register(qlast.GroupQuery) @get_status.register(qlast.InternalGroupQuery) @get_status.register(qlast.ForQuery) def _select(ql: qlast.Base) -> bytes: return b'SELECT' @get_status.register(qlast.InsertQuery) def _insert(ql: qlast.Base) -> bytes: return b'INSERT' @get_status.register(qlast.UpdateQuery) def _update(ql: qlast.Base) -> bytes: return b'UPDATE' @get_status.register(qlast.DeleteQuery) def _delete(ql: qlast.Base) -> bytes: return b'DELETE' @get_status.register(qlast.StartTransaction) def _tx_start(ql: qlast.Base) -> bytes: return b'START TRANSACTION' @get_status.register(qlast.CommitTransaction) def _tx_commit(ql: qlast.Base) -> bytes: return b'COMMIT TRANSACTION' @get_status.register(qlast.RollbackTransaction) def _tx_rollback(ql: qlast.Base) -> bytes: return b'ROLLBACK TRANSACTION' @get_status.register(qlast.DeclareSavepoint) def _tx_sp_declare(ql: qlast.Base) -> bytes: return b'DECLARE SAVEPOINT' @get_status.register(qlast.RollbackToSavepoint) def _tx_sp_rollback(ql: qlast.Base) -> bytes: return b'ROLLBACK TO SAVEPOINT' @get_status.register(qlast.ReleaseSavepoint) def _tx_sp_release(ql: qlast.Base) -> bytes: return b'RELEASE SAVEPOINT' @get_status.register(qlast.SessionSetAliasDecl) def _sess_set_alias(ql: qlast.Base) -> bytes: return b'SET ALIAS' @get_status.register(qlast.SessionResetAliasDecl) @get_status.register(qlast.SessionResetModule) 
@get_status.register(qlast.SessionResetAllAliases) def _sess_reset_alias(ql: qlast.Base) -> bytes: return b'RESET ALIAS' @get_status.register(qlast.ConfigOp) def _sess_set_config(ql: qlast.ConfigOp) -> bytes: if ql.scope == qltypes.ConfigScope.GLOBAL: if isinstance(ql, qlast.ConfigSet): return b'SET GLOBAL' else: return b'RESET GLOBAL' else: return f'CONFIGURE {ql.scope}'.encode('ascii') @get_status.register(qlast.DescribeStmt) def _describe(ql: qlast.Base) -> bytes: return f'DESCRIBE'.encode() @get_status.register(qlast.Rename) def _rename(ql: qlast.Base) -> bytes: return f'RENAME'.encode() @get_status.register(qlast.ExplainStmt) def _explain(ql: qlast.Base) -> bytes: return b'ANALYZE QUERY' @get_status.register(qlast.AdministerStmt) def _administer(ql: qlast.Base) -> bytes: return b'ADMINISTER' @get_status.register(qlast.DDLQuery) def _query(ql: qlast.DDLQuery) -> bytes: return get_status(ql.query) ================================================ FILE: edb/server/compiler_pool/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from .pool import create_compiler_pool, AbstractPool __all__ = ('create_compiler_pool', 'AbstractPool') ================================================ FILE: edb/server/compiler_pool/amsg.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Callable, cast, Generator, Optional import asyncio import os import socket import struct OnPidCallback = Callable[["HubProtocol", asyncio.Transport, int, int], None] OnConnectionLostCallback = Callable[[Optional[int]], None] _uint64_unpacker = struct.Struct('!Q').unpack _uint64_packer = struct.Struct('!Q').pack class MessageStream: """Data stream that yields messages.""" _buffer: bytes _curmsg_len: int def __init__(self) -> None: self._buffer = b'' self._curmsg_len = -1 def feed_data(self, data: bytes) -> Generator[bytes, None, None]: # TODO: rewrite to avoid buffer copies. 
self._buffer += data while self._buffer: if self._curmsg_len == -1: if len(self._buffer) >= 8: self._curmsg_len = _uint64_unpacker(self._buffer[:8])[0] self._buffer = self._buffer[8:] else: return if self._curmsg_len > 0 and len(self._buffer) >= self._curmsg_len: msg = self._buffer[:self._curmsg_len] self._buffer = self._buffer[self._curmsg_len:] self._curmsg_len = -1 yield msg else: return class HubProtocol(asyncio.Protocol): """The Protocol used on the hub side connecting to workers.""" _loop: asyncio.AbstractEventLoop _transport: Optional[asyncio.Transport] _closed: bool _stream: MessageStream _resp_waiters: dict[int, asyncio.Future[memoryview]] _on_pid: OnPidCallback _on_connection_lost: OnConnectionLostCallback _pid: Optional[int] def __init__( self, *, loop: asyncio.AbstractEventLoop, on_pid: OnPidCallback, on_connection_lost: OnConnectionLostCallback, ) -> None: self._loop = loop self._transport = None self._closed = False self._stream = MessageStream() self._resp_waiters = {} self._on_pid = on_pid self._on_connection_lost = on_connection_lost self._pid = None def connection_made(self, tr: asyncio.BaseTransport) -> None: self._transport = cast(asyncio.Transport, tr) def send( self, req_id: int, waiter: asyncio.Future[memoryview], payload: bytes, ) -> None: if req_id in self._resp_waiters: raise RuntimeError('FramedProtocol: duplicate request ID') assert self._transport is not None self._resp_waiters[req_id] = waiter self._transport.writelines( (_uint64_packer(len(payload) + 8), _uint64_packer(req_id), payload) ) def process_message(self, msg: bytes) -> None: msgview = memoryview(msg) req_id = _uint64_unpacker(msgview[:8])[0] waiter = self._resp_waiters.pop(req_id, None) if waiter is None: # This could have happened if the previous request got cancelled. 
return if not waiter.done(): waiter.set_result(msgview[8:]) def data_received(self, data: bytes) -> None: if self._pid is None: assert self._transport is not None pid_data = data[:8] version = _uint64_unpacker(data[8:16])[0] data = data[16:] self._pid = _uint64_unpacker(pid_data)[0] self._on_pid(self, self._transport, self._pid, version) for msg in self._stream.feed_data(data): self.process_message(msg) def connection_lost(self, exc: Optional[Exception]) -> None: self._closed = True if self._resp_waiters: if exc is not None: for waiter in self._resp_waiters.values(): waiter.set_exception(exc) else: for waiter in self._resp_waiters.values(): waiter.set_exception(ConnectionError( 'lost connection to the worker during a call')) self._resp_waiters = {} self._on_connection_lost(self._pid) class HubConnection: """An abstraction of the hub connections to the workers.""" _transport: asyncio.Transport _protocol: HubProtocol _loop: asyncio.AbstractEventLoop _req_id_cnt: int _version: int _aborted: bool def __init__( self, transport: asyncio.Transport, protocol: HubProtocol, loop: asyncio.AbstractEventLoop, version: int, ) -> None: self._transport = transport self._protocol = protocol self._loop = loop self._req_id_cnt = 0 self._version = version self._aborted = False def is_closed(self) -> bool: return self._protocol._closed async def request(self, data: bytes) -> memoryview: self._req_id_cnt += 1 req_id = self._req_id_cnt waiter = self._loop.create_future() self._protocol.send(req_id, waiter, data) return await waiter def abort(self) -> None: self._aborted = True self._transport.abort() class WorkerConnection: """Connection object used by the worker's process.""" _sock: Optional[socket.socket] _stream: MessageStream def __init__(self, sockname: str, version: int) -> None: self._sock = socket.socket(socket.AF_UNIX) self._sock.connect(sockname) self._sock.sendall( _uint64_packer(os.getpid()) + _uint64_packer(version) ) self._stream = MessageStream() def _on_message(self, msg: 
bytes) -> tuple[int, memoryview]: msgview = memoryview(msg) req_id = _uint64_unpacker(msgview[:8])[0] return req_id, msgview[8:] def reply(self, req_id: int, payload: bytes) -> None: assert self._sock is not None self._sock.sendall( b"".join( ( _uint64_packer(len(payload) + 8), _uint64_packer(req_id), payload, ) ) ) def iter_request(self) -> Generator[tuple[int, memoryview], None, None]: while True: data = b'' if self._sock is None else self._sock.recv(4096) if not data: # EOF received - abort self.abort() return yield from map(self._on_message, self._stream.feed_data(data)) def abort(self) -> None: if self._sock is not None: self._sock.close() self._sock = None class ServerProtocol: def worker_connected(self, pid: int, version: int) -> None: pass def worker_disconnected(self, pid: int) -> None: pass class Server: _sockname: str _loop: asyncio.AbstractEventLoop _srv: Optional[asyncio.AbstractServer] _pids: dict[int, HubConnection] _proto: ServerProtocol def __init__( self, sockname: str, loop: asyncio.AbstractEventLoop, server_protocol: ServerProtocol, ) -> None: self._sockname = sockname self._loop = loop self._srv = None self._pids = {} self._proto = server_protocol def _on_pid_connected( self, proto: HubProtocol, tr: asyncio.Transport, pid: int, version: int, ) -> None: assert pid not in self._pids self._pids[pid] = HubConnection(tr, proto, self._loop, version) self._proto.worker_connected(pid, version) def _on_pid_disconnected(self, pid: Optional[int]) -> None: if not pid: return if pid in self._pids: self._pids.pop(pid) self._proto.worker_disconnected(pid) def _proto_factory(self) -> HubProtocol: return HubProtocol( loop=self._loop, on_pid=self._on_pid_connected, on_connection_lost=self._on_pid_disconnected, ) def get_by_pid(self, pid: int) -> HubConnection: return self._pids[pid] async def start(self) -> None: self._srv = await self._loop.create_unix_server( self._proto_factory, path=self._sockname) async def stop(self) -> None: if self._srv is None: return 
self._srv.close() for con in self._pids.values(): con.abort() await self._srv.wait_closed() def kill_outdated_worker(self, current_version: int) -> None: for conn in self._pids.values(): if conn._version < current_version and not conn._aborted: conn.abort() break ================================================ FILE: edb/server/compiler_pool/multitenant_worker.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2022-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any, Callable, Optional, NamedTuple, Sequence import pickle import immutables from edb import edgeql from edb import graphql from edb.common import debug from edb.common import uuidgen from edb.pgsql import params as pgparams from edb.schema import schema as s_schema from edb.server import compiler from edb.server import config from edb.server import defines from . import state from . 
import worker_proc INITED: bool = False clients: immutables.Map[int, ClientSchema] = immutables.Map() BACKEND_RUNTIME_PARAMS: pgparams.BackendRuntimeParams = ( pgparams.get_default_runtime_params() ) COMPILER: compiler.Compiler LAST_STATE: Optional[compiler.dbstate.CompilerConnectionState] = None STD_SCHEMA: s_schema.Schema class ClientSchema(NamedTuple): dbs: state.DatabasesState global_schema: s_schema.Schema instance_config: immutables.Map[str, config.SettingValue] def __init_worker__( init_args_pickled: bytes, ) -> None: global INITED global BACKEND_RUNTIME_PARAMS global COMPILER global STD_SCHEMA ( backend_runtime_params, std_schema, refl_schema, schema_class_layout, ) = pickle.loads(init_args_pickled) INITED = True BACKEND_RUNTIME_PARAMS = backend_runtime_params STD_SCHEMA = std_schema COMPILER = compiler.new_compiler( std_schema, refl_schema, schema_class_layout, backend_runtime_params=backend_runtime_params, config_spec=None, ) def __sync__(client_id, pickled_schema, invalidation) -> None: global clients for cid in invalidation: try: clients = clients.delete(cid) except KeyError: pass try: client_schema: ClientSchema = clients.get(client_id) # type: ignore if pickled_schema: if client_schema is None: dbs = { dbname: state.DatabaseState( dbname, ( None if pickled_state.user_schema is None else pickle.loads(pickled_state.user_schema) ), pickle.loads(pickled_state.reflection_cache), pickle.loads(pickled_state.database_config), ) for dbname, pickled_state in pickled_schema.dbs.items() } if debug.flags.server: print(client_id, "FULL SYNC: ", list(dbs)) client_schema = ClientSchema( immutables.Map(dbs), pickle.loads(pickled_schema.global_schema), pickle.loads(pickled_schema.instance_config), ) clients = clients.set(client_id, client_schema) else: updates = {} dbs = client_schema.dbs if pickled_schema.dbs is not None: for dbname, pickled_state in pickled_schema.dbs.items(): db_state = dbs.get(dbname) if db_state is None: assert pickled_state.user_schema is not 
None assert pickled_state.reflection_cache is not None assert pickled_state.database_config is not None db_state = state.DatabaseState( dbname, pickle.loads(pickled_state.user_schema), pickle.loads(pickled_state.reflection_cache), pickle.loads(pickled_state.database_config), ) if debug.flags.server: print(client_id, "DIFF SYNC ADD: ", dbname) dbs = dbs.set(dbname, db_state) else: db_updates = {} if pickled_state.user_schema is not None: db_updates["user_schema"] = pickle.loads( pickled_state.user_schema ) if pickled_state.reflection_cache is not None: db_updates["reflection_cache"] = pickle.loads( pickled_state.reflection_cache ) if pickled_state.database_config is not None: db_updates["database_config"] = pickle.loads( pickled_state.database_config ) if db_updates: if debug.flags.server: print( client_id, "DIFF SYNC UPDATE: ", dbname ) val = dbs.get(dbname) dbs = dbs.set( dbname, val._replace(**db_updates), # type: ignore ) if pickled_schema.dropped_dbs is not None: for dbname in pickled_schema.dropped_dbs: if debug.flags.server: print(client_id, "DIFF SYNC DROP: ", dbname) dbs = dbs.delete(dbname) if dbs is not client_schema.dbs: updates["dbs"] = dbs if pickled_schema.global_schema is not None: updates["global_schema"] = pickle.loads( pickled_schema.global_schema ) if pickled_schema.instance_config is not None: updates["instance_config"] = pickle.loads( pickled_schema.instance_config ) if updates: client_schema = client_schema._replace( **updates # type: ignore ) clients = clients.set(client_id, client_schema) else: assert client_schema is not None except Exception as ex: raise state.FailedStateSync( f"failed to sync worker state: {type(ex).__name__}({ex})" ) from ex def compile( client_id: int, dbname: str, *compile_args: Any, **compile_kwargs: Any, ): client_schema = clients[client_id] db = client_schema.dbs[dbname] units, cstate = COMPILER.compile_serialized_request( db.user_schema, client_schema.global_schema, db.reflection_cache, db.database_config, 
client_schema.instance_config, *compile_args, **compile_kwargs, ) pickled_state = None if cstate is not None: global LAST_STATE LAST_STATE = cstate pickled_state = pickle.dumps(cstate, -1) return units, pickled_state def compile_in_tx( _, client_id: Optional[int], dbname: Optional[str], user_schema: Optional[bytes], cstate, *args, **kwargs, ): global LAST_STATE if cstate == state.REUSE_LAST_STATE_MARKER: assert LAST_STATE is not None cstate = LAST_STATE else: cstate = pickle.loads(cstate) if client_id is None: assert user_schema is not None cstate.set_root_user_schema(pickle.loads(user_schema)) else: assert dbname is not None client_schema = clients[client_id] db = client_schema.dbs[dbname] cstate.set_root_user_schema(db.user_schema) units, cstate = COMPILER.compile_serialized_request_in_tx( cstate, *args, **kwargs) LAST_STATE = cstate return units, pickle.dumps(cstate, -1) def compile_notebook( client_id: int, dbname: str, *compile_args: Any, **compile_kwargs: Any, ): global clients client_schema = clients[client_id] db = client_schema.dbs[dbname] return COMPILER.compile_notebook( db.user_schema, client_schema.global_schema, db.reflection_cache, db.database_config, client_schema.instance_config, *compile_args, **compile_kwargs, ) def compile_graphql( client_id: int, dbname: str, *compile_args: Any, **compile_kwargs: Any, ): global clients client_schema = clients[client_id] db = client_schema.dbs[dbname] gql_op = graphql.compile_graphql( STD_SCHEMA, db.user_schema, client_schema.global_schema, db.database_config, client_schema.instance_config, *compile_args, **compile_kwargs ) source = edgeql.Source.from_string( edgeql.generate_source(gql_op.edgeql_ast, pretty=True), ) cfg_ser = COMPILER.state.compilation_config_serializer request = compiler.CompilationRequest( source=source, protocol_version=defines.CURRENT_PROTOCOL, schema_version=uuidgen.uuid4(), compilation_config_serializer=cfg_ser, output_format=compiler.OutputFormat.JSON, 
input_format=compiler.InputFormat.JSON, expect_one=True, implicit_limit=0, inline_typeids=False, inline_typenames=False, inline_objectids=False, modaliases=None, session_config=None, ) unit_group, _ = COMPILER.compile( user_schema=db.user_schema, global_schema=client_schema.global_schema, reflection_cache=db.reflection_cache, database_config=db.database_config, system_config=client_schema.instance_config, request=request, ) return unit_group, gql_op def compile_sql( client_id: int, dbname: str, *compile_args: Any, **compile_kwargs: Any, ): client_schema = clients[client_id] db = client_schema.dbs[dbname] return COMPILER.compile_sql( db.user_schema, client_schema.global_schema, db.reflection_cache, db.database_config, client_schema.instance_config, *compile_args, **compile_kwargs, ) def call_for_client( client_id: int, pickled_schema: Optional[bytes], invalidation: Sequence[int], msg: Optional[bytes], *args: Any, ) -> Any: __sync__(client_id, pickled_schema, invalidation) if msg is None: methname, dbname, *compile_args = args else: assert args == () methname, args = pickle.loads(msg) ( dbname, # These are pass-thru arguments from Gel server, they are already # utilized in the compiler server and forwarded to us through # "pickled_schema" argument, so we don't need them here. evicted_dbs, user_schema, reflection_cache, global_schema, database_config, system_config, *compile_args, ) = args if methname == "compile": meth = compile elif methname == "compile_notebook": meth = compile_notebook elif methname == "compile_graphql": meth = compile_graphql elif methname == "compile_sql": meth = compile_sql else: raise NotImplementedError( f"call_for_client() is not implemented for {methname!r} method. 
" ) return meth(client_id, dbname, *compile_args) def get_handler(methname: str) -> Callable[..., Any]: meth: Callable[..., Any] if methname == "__init_worker__": meth = __init_worker__ else: if not INITED: raise RuntimeError( "call on uninitialized compiler worker" ) if methname == "call_for_client": meth = call_for_client elif methname == "compile_in_tx": meth = compile_in_tx else: meth = getattr(COMPILER, methname) return meth if __name__ == "__main__": try: worker_proc.main(get_handler) except KeyboardInterrupt: pass ================================================ FILE: edb/server/compiler_pool/pool.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import ( Any, Callable, cast, Hashable, Mapping, NamedTuple, Optional, TYPE_CHECKING, ) import asyncio import collections import dataclasses import functools import hmac import logging import os import os.path import pickle import random import signal import subprocess import sys import time import immutables import psutil from edb.common import debug from edb.common import lru from edb.pgsql import params as pgparams from edb.schema import reflection as s_refl from edb.schema import schema as s_schema from edb.server import args as srvargs from edb.server import config from edb.server import dbview from edb.server import defines from edb.server import metrics from . import amsg from . import queue from . import state if TYPE_CHECKING: from edb import errors from edb import graphql from edb.server.compiler import compiler from edb.server.compiler import config as config_compiler from edb.server.compiler import dbstate from edb.server.compiler import sertypes SyncStateCallback = Callable[[], None] SyncFinalizer = Callable[[], None] Config = immutables.Map[str, config.SettingValue] InitArgs = tuple[ pgparams.BackendRuntimeParams, s_schema.Schema, s_schema.Schema, s_refl.SchemaClassLayout, bytes, Config, ] MultiTenantInitArgs = tuple[ pgparams.BackendRuntimeParams, s_schema.Schema, s_schema.Schema, s_refl.SchemaClassLayout, ] RemoteInitArgsPickle = tuple[bytes, bytes, bytes, bytes] PreArgs = tuple[Any, ...] PROCESS_INITIAL_RESPONSE_TIMEOUT: float = 60.0 KILL_TIMEOUT: float = 10.0 HEALTH_CHECK_MIN_INTERVAL: float = float( os.getenv("GEL_COMPILER_HEALTH_CHECK_MIN_INTERVAL", 10) ) HEALTH_CHECK_TIMEOUT: float = float( os.getenv("GEL_COMPILER_HEALTH_CHECK_TIMEOUT", 10) ) ADAPTIVE_SCALE_UP_WAIT_TIME: float = 3.0 ADAPTIVE_SCALE_DOWN_WAIT_TIME: float = 60.0 WORKER_PKG: str = __name__.rpartition('.')[0] + '.' 
DEFAULT_CLIENT: str = 'default' HIGH_RSS_GRACE_PERIOD: tuple[int, int] = (20 * 3600, 30 * 3600) CURRENT_COMPILER_PROTOCOL = 2 logger = logging.getLogger("edb.server") log_metrics = logging.getLogger("edb.server.metrics") # Inherit sys.path so that import system can find worker class # in unittests. _ENV = os.environ.copy() _ENV['PYTHONPATH'] = ':'.join(sys.path) @functools.lru_cache(maxsize=4) def _pickle_memoized(obj: Any) -> bytes: return pickle.dumps(obj, -1) class BaseWorker: _dbs: collections.OrderedDict[str, state.PickledDatabaseState] _backend_runtime_params: pgparams.BackendRuntimeParams _std_schema: s_schema.Schema _refl_schema: s_schema.Schema _schema_class_layout: s_refl.SchemaClassLayout _global_schema_pickle: bytes _system_config: Config _last_pickled_state: Optional[bytes] _con: Optional[amsg.HubConnection] _last_used: float _closed: bool def __init__( self, backend_runtime_params: pgparams.BackendRuntimeParams, std_schema: s_schema.Schema, refl_schema: s_schema.Schema, schema_class_layout: s_refl.SchemaClassLayout, global_schema_pickle: bytes, system_config: Config, ) -> None: self._dbs = collections.OrderedDict() self._backend_runtime_params = backend_runtime_params self._std_schema = std_schema self._refl_schema = refl_schema self._schema_class_layout = schema_class_layout self._global_schema_pickle = global_schema_pickle self._system_config = system_config self._last_pickled_state = None self._con = None self._last_used = time.monotonic() self._closed = False def get_db(self, name: str) -> Optional[state.PickledDatabaseState]: rv = self._dbs.get(name) if rv is not None: self._dbs.move_to_end(name, last=False) return rv def set_db(self, name: str, db: state.PickledDatabaseState) -> None: self._dbs[name] = db self._dbs.move_to_end(name, last=False) def prepare_evict_db(self, keep: int) -> list[str]: return list(self._dbs.keys())[keep:] def evict_db(self, name: str) -> Optional[state.PickledDatabaseState]: return self._dbs.pop(name, None) async def 
call( self, method_name: str, *args: Any, sync_state: Optional[SyncStateCallback] = None, ) -> Any: assert not self._closed assert self._con is not None if self._con.is_closed(): raise RuntimeError( 'the connection to the compiler worker process is ' 'unexpectedly closed') data = await self._request(method_name, args) status, *result = pickle.loads(data) self._last_used = time.monotonic() if status == 0: if sync_state is not None: sync_state() return result[0] elif status == 1: exc, tb = result if (sync_state is not None and not isinstance(exc, state.FailedStateSync)): sync_state() exc.__formatted_error__ = tb raise exc else: exc = RuntimeError( 'could not serialize result in worker subprocess') exc.__formatted_error__ = result[0] raise exc async def _request( self, method_name: str, args: tuple[Any, ...], ) -> memoryview: assert self._con is not None msg = pickle.dumps((method_name, args)) return await self._con.request(msg) class Worker(BaseWorker): _pid: int _proc: psutil.Process _manager: BaseLocalPool _server: amsg.Server _allow_high_rss_until: float def __init__( self, manager: BaseLocalPool, server: amsg.Server, pid: int, *args: Any, ) -> None: super().__init__(*args) self._pid = pid self._proc = psutil.Process(pid) self._manager = manager self._server = server grace_period = random.SystemRandom().randint(*HIGH_RSS_GRACE_PERIOD) self._allow_high_rss_until = time.monotonic() + grace_period async def _attach(self, init_args_pickled: bytes) -> None: self._manager._stats_spawned += 1 self._con = self._server.get_by_pid(self._pid) await self.call( '__init_worker__', init_args_pickled, ) def set_db(self, name: str, db: state.PickledDatabaseState) -> None: pid = str(self._pid) old_size: Optional[int] = None if (old_db := self._dbs.get(name)) is not None: old_size = old_db.get_estimated_size() super().set_db(name, db) metrics.compiler_process_schema_size.inc( db.get_estimated_size() - (old_size or 0), pid, DEFAULT_CLIENT ) if old_size is None: action = 'cache-add' 
metrics.compiler_process_branches.set( len(self._dbs), pid, DEFAULT_CLIENT ) else: action = 'cache-update' metrics.compiler_process_branch_actions.inc( 1, pid, DEFAULT_CLIENT, action ) def evict_db(self, name: str) -> Optional[state.PickledDatabaseState]: pid = str(self._pid) db = self._dbs.get(name) super().evict_db(name) if db is not None: metrics.compiler_process_schema_size.dec( db.get_estimated_size(), pid, DEFAULT_CLIENT ) metrics.compiler_process_branch_actions.inc( 1, pid, DEFAULT_CLIENT, 'cache-evict' ) return db def get_pid(self) -> int: return self._pid def get_rss(self) -> int: return self._proc.memory_info().rss def maybe_close_for_high_rss(self, max_rss: int) -> bool: if time.monotonic() > self._allow_high_rss_until: rss = self.get_rss() if rss > max_rss: logger.info( 'worker process with PID %d exceeds high RSS limit ' '(%d > %d), killing now', self._pid, rss, max_rss, ) self.close() return True return False def close(self) -> None: if self._closed: return self._closed = True metrics.compiler_process_kills.inc() self._manager._stats_killed += 1 self._manager._workers.pop(self._pid, None) self._manager._report_worker(self, action="kill") try: os.kill(self._pid, signal.SIGTERM) except ProcessLookupError: pass class AbstractPool[ BaseWorker_T: BaseWorker, InitArgs_T, InitArgsPickle_T, ]: _loop: asyncio.AbstractEventLoop _worker_branch_limit: int _backend_runtime_params: pgparams.BackendRuntimeParams _std_schema: s_schema.Schema _refl_schema: s_schema.Schema _schema_class_layout: s_refl.SchemaClassLayout _dbindex: Optional[dbview.DatabaseIndex] = None _last_active_time: float def __init__( self, *, loop: asyncio.AbstractEventLoop, worker_branch_limit: int, **kwargs: Any, ) -> None: self._loop = loop self._worker_branch_limit = worker_branch_limit self._init(kwargs) def _init(self, kwargs: dict[str, Any]) -> None: self._backend_runtime_params = kwargs["backend_runtime_params"] self._std_schema = kwargs["std_schema"] self._refl_schema = 
kwargs["refl_schema"] self._schema_class_layout = kwargs["schema_class_layout"] self._dbindex = kwargs.get("dbindex") self._last_active_time = 0 def _get_init_args(self) -> tuple[InitArgs_T, InitArgsPickle_T]: assert self._dbindex is not None return self._make_cached_init_args( *self._dbindex.get_cached_compiler_args() ) def _make_cached_init_args( self, global_schema_pickle: bytes, system_config: Config, ) -> tuple[InitArgs_T, InitArgsPickle_T]: raise NotImplementedError def _make_init_args( self, global_schema_pickle: bytes, system_config: Config, ) -> InitArgs: return ( self._backend_runtime_params, self._std_schema, self._refl_schema, self._schema_class_layout, global_schema_pickle, system_config, ) async def start(self) -> None: raise NotImplementedError async def stop(self) -> None: raise NotImplementedError def get_template_pid(self) -> Optional[int]: return None async def _compute_compile_preargs( self, method_name: str, worker: BaseWorker_T, dbname: str, user_schema_pickle: bytes, global_schema_pickle: bytes, reflection_cache: state.ReflectionCache, database_config: Config, system_config: Config, ) -> tuple[PreArgs, Optional[SyncStateCallback], SyncFinalizer]: def sync_worker_state_cb( *, worker: BaseWorker_T, dbname: str, user_schema_pickle: Optional[bytes] = None, global_schema_pickle: Optional[bytes] = None, reflection_cache: Optional[state.ReflectionCache] = None, database_config: Optional[Config] = None, system_config: Optional[Config] = None, evicted_dbs: Optional[list[str]] = None, ): if evicted_dbs is not None: for name in evicted_dbs: worker.evict_db(name) worker_db = worker.get_db(dbname) if worker_db is None: assert user_schema_pickle is not None assert reflection_cache is not None assert global_schema_pickle is not None assert database_config is not None assert system_config is not None worker.set_db( dbname, state.PickledDatabaseState( user_schema_pickle=user_schema_pickle, reflection_cache=reflection_cache, database_config=database_config, ), 
) worker._global_schema_pickle = global_schema_pickle worker._system_config = system_config else: if ( user_schema_pickle is not None or reflection_cache is not None or database_config is not None ): worker.set_db( dbname, state.PickledDatabaseState( user_schema_pickle=( user_schema_pickle or worker_db.user_schema_pickle ), reflection_cache=( worker_db.reflection_cache if reflection_cache is None else reflection_cache ), database_config=( worker_db.database_config if database_config is None else database_config ), ), ) if global_schema_pickle is not None: worker._global_schema_pickle = global_schema_pickle if system_config is not None: worker._system_config = system_config worker_db = worker.get_db(dbname) preargs: list[Any] = [method_name, dbname] to_update: dict[str, Any] = {} branch_cache_hit = True if worker_db is None: branch_cache_hit = False evicted_dbs = worker.prepare_evict_db( self._worker_branch_limit - 1 ) preargs.extend([ evicted_dbs, user_schema_pickle, _pickle_memoized(reflection_cache), global_schema_pickle, _pickle_memoized(database_config), _pickle_memoized(system_config), ]) to_update = { 'evicted_dbs': evicted_dbs, 'user_schema_pickle': user_schema_pickle, 'reflection_cache': reflection_cache, 'global_schema_pickle': global_schema_pickle, 'database_config': database_config, 'system_config': system_config, } else: preargs.append([]) # evicted_dbs if worker_db.user_schema_pickle is not user_schema_pickle: branch_cache_hit = False preargs.append(user_schema_pickle) to_update['user_schema_pickle'] = user_schema_pickle else: preargs.append(None) if worker_db.reflection_cache is not reflection_cache: branch_cache_hit = False preargs.append(_pickle_memoized(reflection_cache)) to_update['reflection_cache'] = reflection_cache else: preargs.append(None) if worker._global_schema_pickle is not global_schema_pickle: preargs.append(global_schema_pickle) to_update['global_schema_pickle'] = global_schema_pickle else: preargs.append(None) if 
worker_db.database_config is not database_config: branch_cache_hit = False preargs.append(_pickle_memoized(database_config)) to_update['database_config'] = database_config else: preargs.append(None) if worker._system_config is not system_config: preargs.append(_pickle_memoized(system_config)) to_update['system_config'] = system_config else: preargs.append(None) self._report_branch_request(worker, branch_cache_hit) if to_update: callback = functools.partial( sync_worker_state_cb, worker=worker, dbname=dbname, **to_update ) else: callback = None return tuple(preargs), callback, lambda: None def _report_branch_request( self, worker: BaseWorker_T, cache_hit: bool, client: str = DEFAULT_CLIENT, ) -> None: pass async def _acquire_worker( self, *, condition: Optional[queue.AcquireCondition[BaseWorker_T]] = None, weighter: Optional[queue.Weighter[BaseWorker_T]] = None, **compiler_args: Any, ) -> BaseWorker_T: raise NotImplementedError def _release_worker( self, worker: BaseWorker_T, *, put_in_front: bool = True, ) -> None: raise NotImplementedError async def compile( self, dbname: str, user_schema_pickle: bytes, global_schema_pickle: bytes, reflection_cache: state.ReflectionCache, database_config: Config, system_config: Config, *compile_args: Any, **compiler_args: Any, ) -> tuple[dbstate.QueryUnitGroup, bytes, int]: fini = lambda: None worker = await self._acquire_worker(**compiler_args) try: preargs, sync_state, fini = await self._compute_compile_preargs( "compile", worker, dbname, user_schema_pickle, global_schema_pickle, reflection_cache, database_config, system_config, ) result = await worker.call( *preargs, *compile_args, sync_state=sync_state ) worker._last_pickled_state = result[1] if len(result) == 2: return *result, 0 else: return result finally: fini() self._release_worker(worker) async def compile_in_tx( self, dbname: str, user_schema_pickle: bytes, txid: int, pickled_state: bytes, state_id: int, *compile_args: Any, **compiler_args: Any, ) -> 
tuple[dbstate.QueryUnitGroup, bytes, int]:
        # When we compile a query, the compiler returns a tuple:
        # a QueryUnit and, if the compiler is in a transaction, its state.
        # The state contains the information about all savepoints and
        # transient schema changes, so the next time we compile a query
        # in this transaction, the state must be passed to whichever
        # compiler worker compiles it.
        #
        # The compile state can be quite heavy and contain multiple versions
        # of schema, configs, and other session-related data. So the compiler
        # worker pickles it before sending it to the IO process, and the
        # IO process never needs to unpickle it.
        #
        # There's one crucial optimization we do here, though. We try to
        # find the compiler process that we used before, which already has
        # this state unpickled. If we can find it, the compiler process
        # won't have to waste time unpickling the state.
        #
        # We use "is" in `w._last_pickled_state is pickled_state` deliberately,
        # because `pickled_state` is saved on the Worker instance and
        # stored in edgecon; we never modify it, so `is` is sufficient and
        # is faster than `==`.
        worker = await self._acquire_worker(
            condition=lambda w: (w._last_pickled_state is pickled_state),
            compiler_args=compiler_args,
        )

        dbname_arg = None
        user_schema_pickle_arg = None
        if worker._last_pickled_state is pickled_state:
            # Since we know that this particular worker already has the
            # state, we don't want to waste resources transferring the
            # state over the network. So we replace the state with a marker
            # that the compiler process will recognize.
            pickled_state = state.REUSE_LAST_STATE_MARKER
        else:
            worker_db = worker.get_db(dbname)
            if (
                worker_db is not None
                and worker_db.user_schema_pickle is user_schema_pickle
            ):
                dbname_arg = dbname
            else:
                user_schema_pickle_arg = user_schema_pickle

        try:
            units, new_pickled_state = await worker.call(
                'compile_in_tx',
                dbname_arg,
                user_schema_pickle_arg,
                pickled_state,
                txid,
                *compile_args
            )
            worker._last_pickled_state = new_pickled_state
            return units, new_pickled_state, 0

        finally:
            # Put the worker at the end of the queue so that the chance
            # of reusing it later (and thus the chance of the
            # `w._last_pickled_state is pickled_state` check returning
            # `True`) is higher.
            self._release_worker(worker, put_in_front=False)

    async def compile_notebook(
        self,
        dbname: str,
        user_schema_pickle: bytes,
        global_schema_pickle: bytes,
        reflection_cache: state.ReflectionCache,
        database_config: Config,
        system_config: Config,
        *compile_args: Any,
        **compiler_args: Any,
    ) -> list[
        tuple[
            bool,
            dbstate.QueryUnit | tuple[str, str, dict[int, str]]
        ]
    ]:
        fini = lambda: None
        worker = await self._acquire_worker(**compiler_args)
        try:
            preargs, sync_state, fini = await self._compute_compile_preargs(
                "compile_notebook",
                worker,
                dbname,
                user_schema_pickle,
                global_schema_pickle,
                reflection_cache,
                database_config,
                system_config,
            )
            return await worker.call(
                *preargs,
                *compile_args,
                sync_state=sync_state
            )
        finally:
            fini()
            self._release_worker(worker)

    async def compile_graphql(
        self,
        dbname: str,
        user_schema_pickle: bytes,
        global_schema_pickle: bytes,
        reflection_cache: state.ReflectionCache,
        database_config: Config,
        system_config: Config,
        *compile_args: Any,
        **compiler_args: Any,
    ) -> graphql.TranspiledOperation:
        fini = lambda: None
        worker = await self._acquire_worker(**compiler_args)
        try:
            preargs, sync_state, fini = await self._compute_compile_preargs(
                "compile_graphql",
                worker,
                dbname,
                user_schema_pickle,
                global_schema_pickle,
                reflection_cache,
                database_config,
                system_config,
            )
            return await worker.call(
                *preargs,
*compile_args, sync_state=sync_state ) finally: fini() self._release_worker(worker) async def compile_sql( self, dbname: str, user_schema_pickle: bytes, global_schema_pickle: bytes, reflection_cache: state.ReflectionCache, database_config: Config, system_config: Config, *compile_args: Any, **compiler_args: Any, ) -> list[dbstate.SQLQueryUnit]: fini = lambda: None worker = await self._acquire_worker(**compiler_args) try: preargs, sync_state, fini = await self._compute_compile_preargs( "compile_sql", worker, dbname, user_schema_pickle, global_schema_pickle, reflection_cache, database_config, system_config, ) return await worker.call( *preargs, *compile_args, sync_state=sync_state ) finally: fini() self._release_worker(worker) # We use a helper function instead of just fully generating the # functions in order to make the backtraces a little better. async def _simple_call(self, name: str, *args: Any, **kwargs: Any) -> Any: worker = await self._acquire_worker() try: return await worker.call( name, *args, **kwargs ) finally: self._release_worker(worker) async def interpret_backend_error( self, user_schema: bytes, global_schema: bytes, error_fields: dict[str, str], from_graphql: bool, ) -> errors.EdgeDBError: return await self._simple_call( 'interpret_backend_error', user_schema, global_schema, error_fields, from_graphql, ) async def parse_global_schema(self, global_schema_json: bytes) -> bytes: return await self._simple_call( 'parse_global_schema', global_schema_json ) async def parse_user_schema_db_config( self, user_schema_json: bytes, db_config_json: bytes, global_schema_pickle: bytes, ) -> dbstate.ParsedDatabase: return await self._simple_call( 'parse_user_schema_db_config', user_schema_json, db_config_json, global_schema_pickle, ) async def make_state_serializer( self, protocol_version: defines.ProtocolVersion, user_schema_pickle: bytes, global_schema_pickle: bytes, ) -> sertypes.StateSerializer: return await self._simple_call( 'make_state_serializer', 
            protocol_version,
            user_schema_pickle,
            global_schema_pickle,
        )

    async def make_compilation_config_serializer(
        self
    ) -> sertypes.CompilationConfigSerializer:
        return await self._simple_call('make_compilation_config_serializer')

    async def describe_database_dump(
        self,
        user_schema_json: bytes,
        global_schema_json: bytes,
        db_config_json: bytes,
        protocol_version: defines.ProtocolVersion,
        with_secrets: bool,
    ) -> compiler.DumpDescriptor:
        return await self._simple_call(
            'describe_database_dump',
            user_schema_json,
            global_schema_json,
            db_config_json,
            protocol_version,
            with_secrets,
        )

    async def describe_database_restore(
        self,
        user_schema_pickle: bytes,
        global_schema_pickle: bytes,
        dump_server_ver_str: str,
        dump_catalog_version: Optional[int],
        schema_ddl: bytes,
        schema_ids: list[tuple[str, str, bytes]],
        blocks: list[tuple[bytes, bytes]],  # type_id, typespec
        protocol_version: defines.ProtocolVersion,
    ) -> compiler.RestoreDescriptor:
        return await self._simple_call(
            'describe_database_restore',
            user_schema_pickle,
            global_schema_pickle,
            dump_server_ver_str,
            dump_catalog_version,
            schema_ddl,
            schema_ids,
            blocks,
            protocol_version,
        )

    async def analyze_explain_output(
        self,
        query_asts_pickled: bytes,
        data: list[list[bytes]],
    ) -> bytes:
        return await self._simple_call(
            'analyze_explain_output', query_asts_pickled, data
        )

    async def validate_schema_equivalence(
        self,
        schema_a: bytes,
        schema_b: bytes,
        global_schema: bytes,
        conn_state_pickle: Optional[bytes],
    ) -> None:
        return await self._simple_call(
            'validate_schema_equivalence',
            schema_a,
            schema_b,
            global_schema,
            conn_state_pickle,
        )

    async def compile_structured_config(
        self,
        objects: Mapping[str, config_compiler.ConfigObject],
        source: str | None = None,
        allow_nested: bool = False,
    ) -> dict[str, Config]:
        return await self._simple_call(
            'compile_structured_config', objects, source, allow_nested
        )

    def get_debug_info(self) -> dict[str, Any]:
        return {}

    def get_size_hint(self) -> int:
        raise NotImplementedError

    def refresh_metrics(self) -> None:
        pass

    def _maybe_update_last_active_time(self) -> None:
        if sys.exc_info()[0] is None:
            self._last_active_time = time.monotonic()

    async def health_check(self) -> bool:
        elapsed = time.monotonic() - self._last_active_time
        if elapsed > HEALTH_CHECK_MIN_INTERVAL:
            async with asyncio.timeout(HEALTH_CHECK_TIMEOUT):
                await self.make_compilation_config_serializer()
            self._maybe_update_last_active_time()

        return True


class BaseLocalPool[Worker_T: Worker, InitArgs_T](
    AbstractPool[Worker_T, InitArgs_T, bytes],
    amsg.ServerProtocol,
    asyncio.SubprocessProtocol,
):

    _worker_class: type[Worker_T]
    _worker_mod: str = "worker"
    _workers_queue: queue.WorkerQueue[Worker_T]
    _workers: dict[int, Worker_T]

    _poolsock_name: str
    _pool_size: int
    _worker_max_rss: Optional[int]
    _server: Optional[amsg.Server]
    _ready_evt: asyncio.Event
    _running: Optional[bool]

    _stats_spawned: int
    _stats_killed: int

    def __init__(
        self,
        *,
        runstate_dir: str,
        pool_size: int,
        worker_max_rss: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        self._poolsock_name = os.path.join(runstate_dir, 'ipc')
        assert len(self._poolsock_name) <= (
            defines.MAX_RUNSTATE_DIR_PATH
            + defines.MAX_UNIX_SOCKET_PATH_LENGTH
            + 1
        ), "pool IPC socket length exceeds maximum allowed"

        assert pool_size >= 1
        self._pool_size = pool_size
        self._worker_max_rss = worker_max_rss
        self._workers = {}

        self._server = amsg.Server(self._poolsock_name, self._loop, self)
        self._ready_evt = asyncio.Event()

        self._running = None

        self._stats_spawned = 0
        self._stats_killed = 0

    def _report_branch_request(
        self, worker: Worker_T, cache_hit: bool, client: str = DEFAULT_CLIENT
    ) -> None:
        pid = str(worker.get_pid())
        metrics.compiler_process_branch_actions.inc(
            1, pid, client, 'request'
        )
        if cache_hit:
            metrics.compiler_process_branch_actions.inc(
                1, pid, client, 'cache-hit'
            )

    def is_running(self) -> bool:
        return bool(self._running)

    async def _attach_worker(self, pid: int) -> Optional[Worker_T]:
        if not self._running:
            return None
        assert self._server is not None
        logger.debug("Sending init args to worker with PID %s.", pid)
        init_args, init_args_pickled = self._get_init_args()
        worker = self._worker_class(  # type: ignore
            self,
            self._server,
            pid,
            *init_args,
        )
        await worker._attach(init_args_pickled)
        self._report_worker(worker)

        self._workers[pid] = worker
        self._workers_queue.release(worker)
        self._worker_attached()

        logger.debug("started compiler worker process (PID %s)", pid)
        if (
            not self._ready_evt.is_set()
            and len(self._workers) == self._pool_size
        ):
            logger.info(
                f"started {self._pool_size} compiler worker "
                f"process{'es' if self._pool_size > 1 else ''}",
            )
            self._ready_evt.set()

        return worker

    def _worker_attached(self) -> None:
        pass

    def worker_connected(self, pid: int, version: int) -> None:
        logger.debug("Worker with PID %s connected.", pid)
        self._loop.create_task(self._attach_worker(pid))
        metrics.compiler_process_spawns.inc()
        metrics.current_compiler_processes.inc()

    def worker_disconnected(self, pid: int) -> None:
        logger.debug("Worker with PID %s disconnected.", pid)
        self._workers.pop(pid, None)
        metrics.current_compiler_processes.dec()
        expect = str(pid)

        def pid_filter(pid_str: str, *remaining_tags) -> bool:
            return pid_str == expect

        metrics.compiler_process_memory.clear(pid_filter)
        metrics.compiler_process_schema_size.clear(pid_filter)
        metrics.compiler_process_branches.clear(pid_filter)
        metrics.compiler_process_branch_actions.clear(pid_filter)
        metrics.compiler_process_client_actions.clear(pid_filter)

    async def start(self) -> None:
        if self._running is not None:
            raise RuntimeError(
                'the compiler pool has already been started once')

        assert self._server is not None
        self._workers_queue = queue.WorkerQueue(self._loop)

        await self._server.start()
        self._running = True

        await self._start()

        await self._wait_ready()

    async def _wait_ready(self) -> None:
        await asyncio.wait_for(
            self._ready_evt.wait(),
            PROCESS_INITIAL_RESPONSE_TIMEOUT
        )

    async def _create_compiler_process(
        self, numproc: Optional[int] = None, version: int = 0
    ) -> asyncio.SubprocessTransport:
        # Create a new compiler process. When numproc is None, a single
        # standalone compiler worker process is started; if numproc is an int,
        # a compiler template process will be created, which will then fork
        # itself into `numproc` actual worker processes and run as a supervisor

        env = _ENV
        if debug.flags.server:
            env = {'EDGEDB_DEBUG_SERVER': '1', **_ENV}

        cmdline = [sys.executable]
        if sys.flags.isolated:
            cmdline.append('-I')

        cmdline.extend([
            '-m', WORKER_PKG + self._worker_mod,
            '--sockname', self._poolsock_name,
            '--version-serial', str(version),
        ])
        if numproc:
            cmdline.extend([
                '--numproc', str(numproc),
            ])

        transport, _ = await self._loop.subprocess_exec(
            lambda: self,
            *cmdline,
            env=env,
            stdin=subprocess.DEVNULL,
            stdout=None,
            stderr=None,
        )
        return transport

    async def _start(self) -> None:
        raise NotImplementedError

    async def stop(self) -> None:
        if not self._running:
            return
        self._running = False

        assert self._server is not None
        await self._server.stop()
        self._server = None

        self._workers_queue = queue.WorkerQueue(self._loop)
        self._workers.clear()

        await self._stop()

    async def _stop(self) -> None:
        raise NotImplementedError

    def _report_worker(
        self, worker: Worker_T, *, action: str = "spawn"
    ) -> None:
        action = action.capitalize()
        if not action.endswith("e"):
            action += "e"
        action += "d"
        log_metrics.info(
            "%s a compiler worker with PID %d; pool=%d;"
            + " spawned=%d; killed=%d",
            action,
            worker.get_pid(),
            len(self._workers),
            self._stats_spawned,
            self._stats_killed,
        )

    async def _acquire_worker(
        self,
        *,
        condition: Optional[queue.AcquireCondition[Worker_T]] = None,
        weighter: Optional[queue.Weighter[Worker_T]] = None,
        **compiler_args: Any,
    ) -> Worker_T:
        start_time = time.monotonic()
        try:
            while (
                worker := await self._workers_queue.acquire(
                    condition=condition, weighter=weighter
                )
            ).get_pid() not in self._workers:
                # The worker was disconnected; skip to the next one.
                pass
        except TimeoutError:
            metrics.compiler_pool_queue_errors.inc(1.0, "timeout")
            raise
        except Exception:
            metrics.compiler_pool_queue_errors.inc(1.0, "ise")
            raise
        else:
            metrics.compiler_pool_wait_time.observe(
                time.monotonic() - start_time
            )
            return worker

    def _release_worker(
        self,
        worker: Worker_T,
        *,
        put_in_front: bool = True,
    ) -> None:
        # Skip disconnected workers
        if worker.get_pid() in self._workers:
            if self._worker_max_rss is not None:
                if worker.maybe_close_for_high_rss(self._worker_max_rss):
                    return
            self._workers_queue.release(worker, put_in_front=put_in_front)
        self._maybe_update_last_active_time()

    def get_debug_info(self) -> dict[str, Any]:
        return dict(
            worker_pids=list(self._workers.keys()),
            template_pid=self.get_template_pid(),
        )

    def refresh_metrics(self) -> None:
        for w in self._workers.values():
            metrics.compiler_process_memory.set(w.get_rss(), str(w.get_pid()))

    async def health_check(self) -> bool:
        if not (
            self._running
            and self._ready_evt.is_set()
            and len(self._workers) > 0
        ):
            return False

        return await super().health_check()


class FixedPoolImpl[Worker_T: Worker, InitArgs_T](
    BaseLocalPool[Worker_T, InitArgs_T],
):
    _template_transport: Optional[asyncio.SubprocessTransport]
    _template_proc_scheduled: bool
    _template_proc_version: int

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self._template_transport = None
        self._template_proc_scheduled = False
        self._template_proc_version = 0

    def _worker_attached(self) -> None:
        if not self._running:
            return
        assert self._server is not None
        if len(self._workers) > self._pool_size:
            self._server.kill_outdated_worker(self._template_proc_version)

    def worker_connected(self, pid: int, version: int) -> None:
        if not self._running:
            return
        assert self._server is not None
        if version < self._template_proc_version:
            logger.debug(
                "Outdated worker with PID %s connected; discard now.", pid
            )
            self._server.get_by_pid(pid).abort()
            metrics.compiler_process_spawns.inc()
        else:
            super().worker_connected(pid, version)

    def process_exited(self) -> None:
        # Template process exited
        self._template_transport = None
        if self._running:
            logger.error("Template compiler process exited; recreating now.")
            self._schedule_template_proc(0)

    def get_template_pid(self) -> Optional[int]:
        if self._template_transport is None:
            return None
        else:
            return self._template_transport.get_pid()

    async def _start(self) -> None:
        await self._create_template_proc(retry=False)

    async def _create_template_proc(self, retry: bool = True) -> None:
        self._template_proc_scheduled = False
        if not self._running:
            return
        self._template_proc_version += 1
        try:
            # Create the template process, which will then fork() into numproc
            # child processes and manage them, so that we don't have to manage
            # the actual compiler worker processes in the main process.
            self._template_transport = await self._create_compiler_process(
                numproc=self._pool_size,
                version=self._template_proc_version,
            )
        except Exception:
            if retry:
                if self._running:
                    t = defines.BACKEND_COMPILER_TEMPLATE_PROC_RESTART_INTERVAL
                    logger.exception(
                        f"Unexpected error occurred creating template compiler"
                        f" process; retry in {t} second{'s' if t > 1 else ''}."
                    )
                    self._schedule_template_proc(t)
            else:
                raise

    def _schedule_template_proc(self, sleep: float) -> None:
        if self._template_proc_scheduled:
            return
        self._template_proc_scheduled = True
        self._loop.call_later(
            sleep, self._loop.create_task, self._create_template_proc()
        )

    async def _stop(self) -> None:
        trans, self._template_transport = self._template_transport, None
        if trans is not None:
            trans.terminate()
            await trans._wait()  # type: ignore
            trans.close()

    def get_size_hint(self) -> int:
        return self._pool_size


@srvargs.CompilerPoolMode.Fixed.assign_implementation
class FixedPool(FixedPoolImpl[Worker, InitArgs]):
    _worker_class = Worker

    @lru.lru_method_cache(1)
    def _make_cached_init_args(
        self,
        global_schema_pickle: bytes,
        system_config: Config,
    ) -> tuple[InitArgs, bytes]:
        init_args = self._make_init_args(
            global_schema_pickle, system_config
        )
        pickled_args = pickle.dumps(init_args, -1)
        return init_args, pickled_args


@srvargs.CompilerPoolMode.OnDemand.assign_implementation
class SimpleAdaptivePool(BaseLocalPool[Worker, InitArgs]):
    _worker_class = Worker
    _worker_transports: dict[int, asyncio.SubprocessTransport]
    _expected_num_workers: int
    _scale_down_handle: Optional[asyncio.Handle]
    _max_num_workers: int
    _cleanups: dict[int, asyncio.Future]

    def __init__(self, *, pool_size: int, **kwargs: Any) -> None:
        super().__init__(pool_size=1, **kwargs)
        self._worker_transports = {}
        self._expected_num_workers = 0
        self._scale_down_handle = None
        self._max_num_workers = pool_size
        self._cleanups = {}

    @lru.lru_method_cache(1)
    def _make_cached_init_args(
        self,
        global_schema_pickle: bytes,
        system_config: Config,
    ) -> tuple[InitArgs, bytes]:
        init_args = self._make_init_args(
            global_schema_pickle, system_config
        )
        pickled_args = pickle.dumps(init_args, -1)
        return init_args, pickled_args

    async def _start(self) -> None:
        async with asyncio.TaskGroup() as g:
            for _i in range(self._pool_size):
                g.create_task(self._create_worker())

    async def _stop(self) -> None:
        self._expected_num_workers = 0
        transports, self._worker_transports = self._worker_transports, {}
        for transport in transports.values():
            await transport._wait()  # type: ignore
            transport.close()
        for cleanup in list(self._cleanups.values()):
            await cleanup

    async def _acquire_worker(
        self,
        *,
        condition: Optional[queue.AcquireCondition[Worker]] = None,
        weighter: Optional[queue.Weighter[Worker]] = None,
        **compiler_args: Any,
    ) -> Worker:
        scale_up_handle = None
        if (
            self._running
            and self._workers_queue.qsize() == 0
            and (
                len(self._workers)
                == self._expected_num_workers
                < self._max_num_workers
            )
        ):
            scale_up_handle = self._loop.call_later(
                ADAPTIVE_SCALE_UP_WAIT_TIME, self._maybe_scale_up
            )
        if self._scale_down_handle is not None:
            self._scale_down_handle.cancel()
            self._scale_down_handle = None
        try:
            return await super()._acquire_worker(
                condition=condition, weighter=weighter, **compiler_args
            )
        finally:
            if scale_up_handle is not None:
                scale_up_handle.cancel()

    def _release_worker(
        self,
        worker: Worker,
        *,
        put_in_front: bool = True,
    ) -> None:
        if self._scale_down_handle is not None:
            self._scale_down_handle.cancel()
            self._scale_down_handle = None
        super()._release_worker(worker, put_in_front=put_in_front)
        if (
            self._running
            and self._workers_queue.count_waiters() == 0
            and len(self._workers) > self._pool_size
        ):
            self._scale_down_handle = self._loop.call_later(
                ADAPTIVE_SCALE_DOWN_WAIT_TIME,
                self._scale_down,
            )

    async def _wait_on_dying(
        self,
        pid: int,
        trans: asyncio.SubprocessTransport,
    ) -> None:
        await trans._wait()  # type: ignore
        self._cleanups.pop(pid)

    def worker_disconnected(self, pid: int) -> None:
        num_workers_before = len(self._workers)
        super().worker_disconnected(pid)
        trans = self._worker_transports.pop(pid, None)
        if trans:
            trans.close()
            # amsg.Server notifies us when the *pipe* to the worker closes,
            # so we need to fire off a task to make sure that we wait for
            # the worker to exit, in order to avoid a warning.
            self._cleanups[pid] = (
                self._loop.create_task(self._wait_on_dying(pid, trans)))
        if not self._running:
            return
        if len(self._workers) < self._pool_size:
            # The auto-scaler will not scale down below the pool_size, so we
            # should restart the unexpectedly-exited worker process.
            logger.warning(
                "Compiler worker process[%d] exited unexpectedly; "
                "start a new one now.", pid
            )
            self._loop.create_task(self._create_worker())
            self._expected_num_workers = len(self._workers)
        elif num_workers_before == self._expected_num_workers:
            # This is likely the case when a worker died unexpectedly, and we
            # don't want to restart the worker because the auto-scaler will
            # start a new one again if necessary.
            self._expected_num_workers = len(self._workers)

    def process_exited(self) -> None:
        if self._running:
            for pid, transport in list(self._worker_transports.items()):
                if transport.is_closing():
                    self._worker_transports.pop(pid, None)

    async def _create_worker(self) -> None:
        # Creates a single compiler worker process.
        transport = await self._create_compiler_process()
        self._worker_transports[transport.get_pid()] = transport
        self._expected_num_workers += 1

    def _maybe_scale_up(self) -> None:
        if not self._running:
            return
        logger.info(
            "A compile request has waited for more than %d seconds, "
            "spawn a new compiler worker process now.",
            ADAPTIVE_SCALE_UP_WAIT_TIME,
        )
        self._loop.create_task(self._create_worker())

    def _scale_down(self) -> None:
        self._scale_down_handle = None
        if not self._running or len(self._workers) <= self._pool_size:
            return
        logger.info(
            "The compiler pool is not used in %d seconds, scaling down to %d.",
            ADAPTIVE_SCALE_DOWN_WAIT_TIME, self._pool_size,
        )
        self._expected_num_workers = self._pool_size
        for worker in sorted(
            self._workers.values(), key=lambda w: w._last_used
        )[:-self._pool_size]:
            worker.close()

    def get_size_hint(self) -> int:
        return self._max_num_workers


class RemoteWorker(BaseWorker):
    _con: amsg.HubConnection
    _secret: bytes

    def __init__(
        self,
        con: amsg.HubConnection,
        secret: bytes,
        *args: Any,
    ) -> None:
        super().__init__(*args)
        self._con = con
        self._secret = secret

    def close(self) -> None:
        if self._closed:
            return
        self._closed = True
        self._con.abort()

    async def _request(
        self,
        method_name: str,
        args: tuple[Any, ...],
    ) -> memoryview:
        msg = pickle.dumps((method_name, args))
        digest = hmac.digest(self._secret, msg, "sha256")
        return await self._con.request(digest + msg)


@srvargs.CompilerPoolMode.Remote.assign_implementation
class RemotePool(AbstractPool[RemoteWorker, InitArgs, RemoteInitArgsPickle]):
    _pool_addr: tuple[str, int]
    _worker: Optional[asyncio.Future[RemoteWorker]]
    _sync_lock: asyncio.Lock
    _semaphore: asyncio.BoundedSemaphore
    _pool_size: int
    _secret: bytes

    def __init__(
        self,
        *,
        address: tuple[str, int],
        pool_size: int,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self._pool_addr = address
        self._worker = None
        self._sync_lock = asyncio.Lock()
        self._semaphore = asyncio.BoundedSemaphore(pool_size)
        self._pool_size = pool_size
        secret = os.environ.get("_EDGEDB_SERVER_COMPILER_POOL_SECRET")
        if not secret:
            raise AssertionError(
                "_EDGEDB_SERVER_COMPILER_POOL_SECRET environment variable "
                "is not set"
            )
        self._secret = secret.encode()

    async def start(self, *, retry: bool = False) -> None:
        if self._worker is None:
            self._worker = self._loop.create_future()
        try:
            def on_pid(*args: Any) -> None:
                self._loop.create_task(self._connection_made(retry, *args))

            await self._loop.create_connection(
                lambda: amsg.HubProtocol(
                    loop=self._loop,
                    on_pid=on_pid,
                    on_connection_lost=self._connection_lost,
                ),
                *self._pool_addr,
            )
        except Exception:
            if not retry:
                raise
            if self._worker is not None:
                self._loop.call_later(1, lambda: self._loop.create_task(
                    self.start(retry=True)
                ))
        else:
            if not retry:
                await self._worker

    async def stop(self) -> None:
        if self._worker is not None:
            worker, self._worker = self._worker, None
            if worker.done():
                (await worker).close()

    @lru.lru_method_cache(1)
    def _make_cached_init_args(
        self,
        global_schema_pickle: bytes,
        system_config: Config,
    ) -> tuple[InitArgs, RemoteInitArgsPickle]:
        init_args = self._make_init_args(
            global_schema_pickle,
            system_config,
        )
        std_args = (
            self._std_schema, self._refl_schema, self._schema_class_layout
        )
        client_args = (self._backend_runtime_params,)
        return init_args, (
            pickle.dumps(std_args, -1),
            pickle.dumps(client_args, -1),
            global_schema_pickle,
            pickle.dumps(system_config, -1),
        )

    async def _connection_made(
        self,
        retry: bool,
        protocol: amsg.HubProtocol,
        transport: asyncio.Transport,
        _pid: int,
        version: int,
    ) -> None:
        if self._worker is None:
            return
        compiler_protocol = CURRENT_COMPILER_PROTOCOL
        try:
            init_args, init_args_pickled = self._get_init_args()
            worker = RemoteWorker(
                amsg.HubConnection(transport, protocol, self._loop, version),
                self._secret,
                *init_args,
            )
            await worker.call(
                '__init_server__',
                compiler_protocol,
                defines.EDGEDB_CATALOG_VERSION,
                init_args_pickled,
            )
        except state.IncompatibleClient as ex:
            transport.abort()
            if self._worker is not None:
                self._worker.set_exception(ex)
                self._worker = None
        except BaseException as ex:
            transport.abort()
            if self._worker is not None:
                if retry:
                    await self.start(retry=True)
                else:
                    self._worker.set_exception(ex)
                    self._worker = None
        else:
            if self._worker is not None:
                self._worker.set_result(worker)

    def _connection_lost(self, _pid: Optional[int]) -> None:
        if self._worker is not None:
            self._worker = self._loop.create_future()
            self._loop.create_task(self.start(retry=True))

    async def _acquire_worker(self, **compiler_args: Any) -> RemoteWorker:
        start_time = time.monotonic()
        try:
            await self._semaphore.acquire()
            assert self._worker is not None
            rv = await self._worker
        except TimeoutError:
            metrics.compiler_pool_queue_errors.inc(1.0, "timeout")
            raise
        except Exception:
            metrics.compiler_pool_queue_errors.inc(1.0, "ise")
            raise
        else:
            metrics.compiler_pool_wait_time.observe(
                time.monotonic() - start_time
            )
            return rv

    def _release_worker(
        self,
        worker: RemoteWorker,
        *,
        put_in_front: bool = True,
    ) -> None:
        self._semaphore.release()
        self._maybe_update_last_active_time()

    async def compile_in_tx(
        self,
        dbname: str,
        user_schema_pickle: bytes,
        txid: int,
        pickled_state: bytes,
        state_id: int,
        *compile_args: Any,
        **compiler_args: Any,
    ) -> tuple[dbstate.QueryUnitGroup, bytes, int]:
        worker = await self._acquire_worker()
        try:
            return await worker.call(
                'compile_in_tx',
                state_id,
                None,  # client_id
                None,  # dbname
                None,  # user_schema_pickle
                state.REUSE_LAST_STATE_MARKER,
                txid,
                *compile_args
            )
        except state.StateNotFound:
            return await worker.call(
                'compile_in_tx',
                0,  # state_id
                None,  # client_id
                None,  # dbname
                user_schema_pickle,
                pickled_state,
                txid,
                *compile_args
            )
        finally:
            self._release_worker(worker)

    async def _compute_compile_preargs(
        self, *args: Any
    ) -> tuple[PreArgs, Optional[SyncStateCallback], SyncFinalizer]:
        # State sync with the compiler server is serialized with _sync_lock,
        # also blocking any other compile requests that may sync state, so as
        # to avoid inconsistency. Meanwhile, we'd like to avoid locking when
        # sync is not needed (callback is None), so we have 2 fast paths here:
        #
        #   1. When _sync_lock is not held AND sync is not needed here, or
        #   2. after acquiring _sync_lock, we found that sync is not needed.
        #
        # In such cases, we avoid locking or release the lock immediately, so
        # that concurrent compile requests can proceed in parallel.
        preargs: PreArgs = ()
        callback: Optional[SyncStateCallback] = lambda: None
        fini = lambda: None

        if not self._sync_lock.locked():
            # Case 1: check if we need to sync state.
            (
                preargs, callback, fini
            ) = await super()._compute_compile_preargs(*args)

        if callback is not None:
            # Check again with the lock acquired
            del preargs, callback
            await self._sync_lock.acquire()
            (
                preargs, callback, fini
            ) = await super()._compute_compile_preargs(*args)
            if callback:
                # State sync is only considered done when we received a
                # successful response from the compiler server, when we
                # update the local state in the worker in the `callback`
                # function. We should usually release the lock after the
                # `callback`, but we must also release it if anything
                # failed along the way.
                fini = lambda: self._sync_lock.release()
            else:
                # Case 2: no state sync needed, release the lock immediately.
                self._sync_lock.release()

        return preargs, callback, fini

    def get_debug_info(self) -> dict[str, Any]:
        return dict(
            address="{}:{}".format(*self._pool_addr),
            size=self._semaphore._bound_value,  # type: ignore
            free=self._semaphore._value,  # type: ignore
        )

    def get_size_hint(self) -> int:
        return self._pool_size

    async def health_check(self) -> bool:
        if self._worker is None or not self._worker.done():
            return False

        return await super().health_check()


@dataclasses.dataclass
class TenantSchema:
    client_id: int
    dbs: collections.OrderedDict[str, state.PickledDatabaseState]
    global_schema_pickle: bytes
    system_config: Config

    def get_db(self, name: str) -> Optional[state.PickledDatabaseState]:
        rv = self.dbs.get(name)
        if rv is not None:
            self.dbs.move_to_end(name, last=False)
        return rv

    def set_db(self, name: str, db: state.PickledDatabaseState) -> None:
        self.dbs[name] = db
        self.dbs.move_to_end(name, last=False)

    def prepare_evict_db(self, keep: int) -> list[str]:
        return list(self.dbs.keys())[keep:]

    def evict_db(self, name: str) -> None:
        self.dbs.pop(name, None)

    def get_estimated_size(self) -> int:
        return sum(db.get_estimated_size() for db in self.dbs.values())


class PickledState(NamedTuple):
    user_schema: Optional[bytes]
    reflection_cache: Optional[bytes]
    database_config: Optional[bytes]


class PickledSchema(NamedTuple):
    dbs: Optional[immutables.Map[str, PickledState]] = None
    global_schema: Optional[bytes] = None
    instance_config: Optional[bytes] = None
    dropped_dbs: tuple = ()


class BaseMultiTenantWorker[
    TenantStore_T, BaseLocalPool_T: BaseLocalPool
](Worker):
    _manager: BaseLocalPool_T
    _cache: collections.OrderedDict[int, TenantStore_T]
    _invalidated_clients: list[int]
    _last_used_by_client: dict[int, float]
    _client_names: dict[int, str]

    def __init__(
        self,
        manager: BaseLocalPool_T,
        server: amsg.Server,
        pid: int,
        backend_runtime_params: pgparams.BackendRuntimeParams,
        std_schema: s_schema.Schema,
        refl_schema: s_schema.Schema,
        schema_class_layout: s_refl.SchemaClassLayout,
    ):
        super().__init__(
            manager,
            server,
            pid,
            backend_runtime_params,
            std_schema,
            refl_schema,
            schema_class_layout,
            None,
            None,
        )
        self._cache = collections.OrderedDict()
        self._invalidated_clients = []
        self._last_used_by_client = {}
        self._client_names = {}
        self._init()

    def _init(self) -> None:
        pass

    def get_tenant_schema(
        self, client_id: int, *, touch: bool = True
    ) -> Optional[TenantStore_T]:
        rv = self._cache.get(client_id)
        if rv is not None and touch:
            self._cache.move_to_end(client_id, last=False)
        return rv

    def set_tenant_schema(
        self, client_id: int, tenant_schema: TenantStore_T
    ) -> None:
        self._cache[client_id] = tenant_schema
        self._cache.move_to_end(client_id, last=False)
        self._last_used_by_client[client_id] = time.monotonic()

    def cache_size(self) -> int:
        return len(self._cache) - len(self._invalidated_clients)

    def last_used(self, client_id: int) -> float:
        return self._last_used_by_client.get(client_id, 0)

    def flush_invalidation(self) -> list[int]:
        evicted = 0
        pid_str = str(self.get_pid())
        evicted_names = set()
        client_ids, self._invalidated_clients = self._invalidated_clients, []
        for client_id in client_ids:
            if self._cache.pop(client_id, None) is not None:
                evicted += 1
            self._last_used_by_client.pop(client_id, None)
            client_name = self._client_names.pop(client_id, None)
            if client_name is not None:
                evicted_names.add(client_name)
        if evicted:
            metrics.compiler_process_client_actions.inc(
                evicted, pid_str, 'cache-evict'
            )
        if evicted_names:
            def tag_filter(pid: str, client: str, *remaining_tags) -> bool:
                return pid == pid_str and client in evicted_names

            metrics.compiler_process_schema_size.clear(tag_filter)
            metrics.compiler_process_branches.clear(tag_filter)
            metrics.compiler_process_branch_actions.clear(tag_filter)
        return client_ids

    def set_client_name(self, client_id: int, name: Optional[str]) -> None:
        if client_id not in self._client_names:
            self._client_names[client_id] = name or f'unknown-{client_id}'

    def get_client_name(self, client_id: int) -> str:
        return self._client_names.get(client_id) or f'unknown-{client_id}'


class MultiTenantWorker(
    BaseMultiTenantWorker[TenantSchema, "MultiTenantPool"]
):
    current_client_id: Optional[int]

    def _init(self) -> None:
        self.current_client_id = None

    def invalidate(self, client_id: int) -> None:
        if client_id in self._cache:
            self._invalidated_clients.append(client_id)

    def maybe_invalidate_last(self) -> None:
        if self.cache_size() == self._manager.cache_size:
            client_id = next(reversed(self._cache))
            self._invalidated_clients.append(client_id)

    def get_invalidation(self) -> list[int]:
        return self._invalidated_clients[:]


@srvargs.CompilerPoolMode.MultiTenant.assign_implementation
class MultiTenantPool(FixedPoolImpl[MultiTenantWorker, MultiTenantInitArgs]):
    _worker_class = MultiTenantWorker
    _worker_mod = "multitenant_worker"

    def __init__(self, *, cache_size: int, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._cache_size = cache_size

    @property
    def cache_size(self) -> int:
        return self._cache_size

    def drop_tenant(self, client_id: int) -> None:
        for worker in self._workers.values():
            worker.invalidate(client_id)

    @lru.method_cache
    def _get_init_args(self) -> tuple[MultiTenantInitArgs, bytes]:
        init_args = (
            self._backend_runtime_params,
            self._std_schema,
            self._refl_schema,
            self._schema_class_layout,
        )
        return init_args, pickle.dumps(init_args, -1)

    def _weighter(
        self,
        client_id: int,
        worker: MultiTenantWorker,
    ) -> queue.Comparable:
        tenant_schema = worker.get_tenant_schema(client_id, touch=False)
        return (
            bool(tenant_schema),
            worker.last_used(client_id)
            if tenant_schema
            else self._cache_size - worker.cache_size(),
        )

    async def _acquire_worker(
        self,
        *,
        condition: Optional[queue.AcquireCondition[MultiTenantWorker]] = None,
        weighter: Optional[queue.Weighter[MultiTenantWorker]] = None,
        **compiler_args: Any
    ) -> MultiTenantWorker:
        client_id: Optional[int] = compiler_args.get("client_id")
        if weighter is None and client_id is not None:
            weighter = functools.partial(self._weighter, client_id)
        rv = await super()._acquire_worker(
            condition=condition, weighter=weighter, **compiler_args
        )
        rv.current_client_id = client_id
        if client_id is not None:
            rv.set_client_name(client_id, compiler_args.get("client_name"))
        return rv

    def _release_worker(
        self,
        worker: MultiTenantWorker,
        *,
        put_in_front: bool = True,
    ) -> None:
        worker.current_client_id = None
        super()._release_worker(worker, put_in_front=put_in_front)

    async def _compute_compile_preargs(
        self,
        method_name: str,
        worker: MultiTenantWorker,
        dbname: str,
        user_schema_pickle: bytes,
        global_schema_pickle: bytes,
        reflection_cache: state.ReflectionCache,
        database_config: Config,
        system_config: Config,
    ) -> tuple[PreArgs, Optional[SyncStateCallback], SyncFinalizer]:
        assert isinstance(worker, MultiTenantWorker)

        def sync_worker_state_cb(
            *,
            worker: MultiTenantWorker,
            client_id: int,
            client_name: str,
            dbname: str,
            evicted_dbs: list[str],
            user_schema_pickle: Optional[bytes] = None,
            global_schema_pickle: Optional[bytes] = None,
            reflection_cache: Optional[state.ReflectionCache] = None,
            database_config: Optional[Config] = None,
            instance_config: Optional[Config] = None,
        ) -> None:
            pid = str(worker.get_pid())
            tenant_schema = worker.get_tenant_schema(client_id)
            if tenant_schema is None:
                assert user_schema_pickle is not None
                assert reflection_cache is not None
                assert global_schema_pickle is not None
                assert database_config is not None
                assert instance_config is not None
                assert len(evicted_dbs) == 0

                tenant_schema = TenantSchema(
                    client_id,
                    collections.OrderedDict(
                        {
                            dbname: state.PickledDatabaseState(
                                user_schema_pickle,
                                reflection_cache,
                                database_config,
                            ),
                        }
                    ),
                    global_schema_pickle,
                    instance_config,
                )
                worker.set_tenant_schema(client_id, tenant_schema)
                metrics.compiler_process_branch_actions.inc(
                    1, pid, client_name, 'cache-add'
                )
                metrics.compiler_process_client_actions.inc(
                    1, pid, 'cache-add'
                )
            else:
                for name in evicted_dbs:
                    tenant_schema.evict_db(name)
                if evicted_dbs:
                    metrics.compiler_process_branch_actions.inc(
                        len(evicted_dbs), pid, client_name, 'cache-evict'
                    )

                worker_db = tenant_schema.get_db(dbname)
                if worker_db is None:
                    assert user_schema_pickle is not None
                    assert reflection_cache is not None
                    assert database_config is not None

                    tenant_schema.set_db(
                        dbname,
                        state.PickledDatabaseState(
                            user_schema_pickle=user_schema_pickle,
                            reflection_cache=reflection_cache,
                            database_config=database_config,
                        ),
                    )
                    metrics.compiler_process_branch_actions.inc(
                        1, pid, client_name, 'cache-add'
                    )
                    metrics.compiler_process_client_actions.inc(
                        1, pid, 'cache-update'
                    )

                elif (
                    user_schema_pickle is not None
                    or reflection_cache is not None
                    or database_config is not None
                ):
                    tenant_schema.set_db(
                        dbname,
                        state.PickledDatabaseState(
                            user_schema_pickle=(
                                user_schema_pickle
                                or worker_db.user_schema_pickle
                            ),
                            reflection_cache=(
                                reflection_cache or worker_db.reflection_cache
                            ),
                            database_config=(
                                database_config or worker_db.database_config
                            ),
                        )
                    )
                    metrics.compiler_process_branch_actions.inc(
                        1, pid, client_name, 'cache-update'
                    )
                    metrics.compiler_process_client_actions.inc(
                        1, pid, 'cache-update'
                    )

                if global_schema_pickle is not None:
                    tenant_schema.global_schema_pickle = global_schema_pickle
                if instance_config is not None:
                    tenant_schema.system_config = instance_config
            worker.flush_invalidation()
            metrics.compiler_process_schema_size.set(
                tenant_schema.get_estimated_size(), pid, client_name
            )
            metrics.compiler_process_branches.set(
                len(tenant_schema.dbs), pid, client_name
            )

        client_id = worker.current_client_id
        assert client_id is not None
        client_name = worker.get_client_name(client_id)
        tenant_schema = worker.get_tenant_schema(client_id, touch=False)
        to_update: dict[str, Hashable]
        evicted_dbs = []
        branch_cache_hit = True
        if tenant_schema is None:
            branch_cache_hit = False
            # make room for the new client in this worker
            worker.maybe_invalidate_last()
            to_update = {
                "user_schema_pickle": user_schema_pickle,
                "reflection_cache": reflection_cache,
                "global_schema_pickle": global_schema_pickle,
                "database_config": database_config,
                "instance_config": system_config,
            }
        else:
            worker_db = tenant_schema.get_db(dbname)
            if worker_db is None:
                branch_cache_hit = False
                evicted_dbs = tenant_schema.prepare_evict_db(
                    self._worker_branch_limit - 1
                )
                to_update = {
                    "user_schema_pickle": user_schema_pickle,
                    "reflection_cache": reflection_cache,
                    "database_config": database_config,
                }
            else:
                to_update = {}
                if worker_db.user_schema_pickle is not user_schema_pickle:
                    branch_cache_hit = False
                    to_update["user_schema_pickle"] = user_schema_pickle
                if worker_db.reflection_cache is not reflection_cache:
                    branch_cache_hit = False
                    to_update["reflection_cache"] = reflection_cache
                if worker_db.database_config is not database_config:
                    branch_cache_hit = False
                    to_update["database_config"] = database_config
            if (
                tenant_schema.global_schema_pickle
                is not global_schema_pickle
            ):
                to_update["global_schema_pickle"] = global_schema_pickle
            if tenant_schema.system_config is not system_config:
                to_update["instance_config"] = system_config

        self._report_branch_request(worker, branch_cache_hit, client_name)

        if to_update:
            pickled = {
                k.removesuffix("_pickle"): v
                if k.endswith("_pickle")
                else _pickle_memoized(v)
                for k, v in to_update.items()
            }
            if any(f in pickled for f in PickledState._fields):
                db_state = PickledState(
                    **{
                        f: pickled.pop(f, None)
                        for f in PickledState._fields
                    }  # type: ignore
                )
                pickled["dbs"] = immutables.Map([(dbname, db_state)])
            pickled_schema = PickledSchema(
                dropped_dbs=tuple(evicted_dbs),
                **pickled  # type: ignore
            )
            callback = functools.partial(
                sync_worker_state_cb,
                worker=worker,
                client_id=client_id,
                client_name=client_name,
                dbname=dbname,
                evicted_dbs=evicted_dbs,
                **to_update,  # type: ignore
            )
        else:
            pickled_schema = None
            callback = None
            metrics.compiler_process_client_actions.inc(
                1, str(worker.get_pid()), 'cache-hit'
            )

return ( "call_for_client", client_id, pickled_schema, worker.get_invalidation(), None, # forwarded msg is only used in remote compiler server method_name, dbname, ), callback, lambda: None async def compile_in_tx( self, dbname: str, user_schema_pickle: bytes, txid: int, pickled_state: bytes, state_id: int, *compile_args: Any, **compiler_args: Any, ) -> tuple[dbstate.QueryUnitGroup, bytes, int]: client_id: int = compiler_args["client_id"] # Prefer a worker we used last time in the transaction (condition), or # (weighter) one with the user schema at tx start so that we can pass # over only the pickled state. Then prefer the least-recently used one # if many workers passed any check in the weighter, or the most vacant. def weighter(w: MultiTenantWorker) -> queue.Comparable: if ts := w.get_tenant_schema(client_id, touch=False): # Don't use ts.get_db() here to avoid confusing the LRU queue if db := ts.dbs.get(dbname): return ( True, db.user_schema_pickle is user_schema_pickle, w.last_used(client_id), ) else: return True, False, w.last_used(client_id) else: return False, False, self._cache_size - w.cache_size() worker = await self._acquire_worker( condition=lambda w: (w._last_pickled_state is pickled_state), weighter=cast(queue.Weighter, weighter), **compiler_args, ) # Avoid sending information that we know the worker already has. dbname_arg = None client_id_arg = None user_schema_pickle_arg = None if worker._last_pickled_state is pickled_state: pickled_state = state.REUSE_LAST_STATE_MARKER else: tenant_schema = worker.get_tenant_schema(client_id) if tenant_schema is None: # Just pass state + root user schema if this is a new client in # the worker; we don't want to initialize the client as we # don't have enough information to do so. 
user_schema_pickle_arg = user_schema_pickle else: # Don't use ts.get_db() here to avoid confusing the LRU queue worker_db = tenant_schema.dbs.get(dbname) if worker_db is None: # The worker has the client but not the database user_schema_pickle_arg = user_schema_pickle elif worker_db.user_schema_pickle is user_schema_pickle: # Avoid sending the root user schema because the worker has # it - just send client_id + dbname to reference it, as # well as the state of course. dbname_arg = dbname client_id_arg = client_id # Touch dbname to bump it in the LRU queue tenant_schema.get_db(dbname) else: # The worker has a different root user schema user_schema_pickle_arg = user_schema_pickle try: units, new_pickled_state = await worker.call( 'compile_in_tx', # multitenant_worker is also used in MultiSchemaPool for remote # compilers where the first argument "state_id" is used to find # worker without passing the pickled state. Here in multi- # tenant mode, we already have the pickled state, so "state_id" # is not used. Just prepend a fake ID to comply with the API. 
0, # state_id client_id_arg, dbname_arg, user_schema_pickle_arg, pickled_state, txid, *compile_args ) worker._last_pickled_state = new_pickled_state return units, new_pickled_state, 0 finally: self._release_worker(worker, put_in_front=False) async def create_compiler_pool[AbstractPool_T: AbstractPool]( *, runstate_dir: str, pool_size: int, worker_branch_limit: int, backend_runtime_params: pgparams.BackendRuntimeParams, std_schema: s_schema.Schema, refl_schema: s_schema.Schema, schema_class_layout: s_refl.SchemaClassLayout, pool_class: type[AbstractPool_T], **kwargs: Any, ) -> AbstractPool_T: assert issubclass(pool_class, AbstractPool) loop = asyncio.get_running_loop() pool = pool_class( loop=loop, pool_size=pool_size, worker_branch_limit=worker_branch_limit, runstate_dir=runstate_dir, backend_runtime_params=backend_runtime_params, std_schema=std_schema, refl_schema=refl_schema, schema_class_layout=schema_class_layout, **kwargs, ) await pool.start() return pool ================================================ FILE: edb/server/compiler_pool/queue.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2020-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations import asyncio import collections import typing W2 = typing.TypeVar('W2', contravariant=True) class AcquireCondition(typing.Protocol[W2]): def __call__(self, worker: W2) -> bool: ... class Comparable(typing.Protocol): def __gt__(self, other: typing.Self) -> bool: ... class Weighter(typing.Protocol[W2]): def __call__(self, worker: W2) -> Comparable: ... class WorkerQueue[W]: loop: asyncio.AbstractEventLoop _waiters: collections.deque[asyncio.Future[None]] _queue: collections.deque[W] def __init__( self, loop: asyncio.AbstractEventLoop, ) -> None: self._loop = loop self._waiters = collections.deque() self._queue = collections.deque() async def acquire( self, *, condition: typing.Optional[AcquireCondition[W]] = None, weighter: typing.Optional[Weighter[W]] = None, ) -> W: # There can be a race between a waiter scheduled to wake up # and a worker being stolen (due to quota being enforced, # for example). In which case the waiter might get finally # woken up with an empty queue -- hence we use a `while` loop here. attempts = 0 while not self._queue: waiter = self._loop.create_future() attempts += 1 if attempts > 1: # If the waiter was woken up only to discover that # it needs to wait again, we don't want it to lose # its place in the waiters queue. self._waiters.appendleft(waiter) else: # On the first attempt the waiter goes to the end # of the waiters queue. self._waiters.append(waiter) try: await waiter except Exception: if not waiter.done(): waiter.cancel() try: self._waiters.remove(waiter) except ValueError: # The waiter could be removed from self._waiters # by a previous release() call. pass if self._queue and not waiter.cancelled(): # We were woken up by release(), but can't take # the call. Wake up the next in line. 
self._wakeup_next_waiter() raise if len(self._queue) > 1: if condition is not None: for w in self._queue: if condition(w): self._queue.remove(w) return w if weighter is not None: rv = self._queue[0] weight = weighter(rv) it = iter(self._queue) next(it) # skip the first for w in it: new_weight = weighter(w) if new_weight > weight: weight = new_weight rv = w self._queue.remove(rv) return rv return self._queue.popleft() def release(self, worker: W, *, put_in_front: bool=True) -> None: if put_in_front: self._queue.appendleft(worker) else: self._queue.append(worker) self._wakeup_next_waiter() def qsize(self) -> int: return len(self._queue) def count_waiters(self) -> int: return len(self._waiters) def _wakeup_next_waiter(self) -> None: while self._waiters: waiter = self._waiters.popleft() if not waiter.done(): waiter.set_result(None) break ================================================ FILE: edb/server/compiler_pool/server.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2022-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import Any, Callable, cast, NamedTuple, Optional, Sequence import asyncio import hmac import functools import logging import os import pathlib import pickle import secrets import tempfile import traceback import click import httptools import immutables from edb.common import debug from edb.common import lru from edb.common import markup from edb.server import metrics from edb.server import args as srvargs from edb.server import defines from edb.server import logsetup from . import amsg from . import pool as pool_mod from . import queue from . import worker_proc from . import state as state_mod _client_id_seq = 0 _tx_state_id_seq = 0 logger = logging.getLogger("edb.server") def next_tx_state_id(): global _tx_state_id_seq _tx_state_id_seq = (_tx_state_id_seq + 1) % (2 ** 63 - 1) return _tx_state_id_seq class PickledState(NamedTuple): user_schema: Optional[bytes] reflection_cache: Optional[bytes] database_config: Optional[bytes] def diff(self, other: PickledState): # Compare this state with the other state, generate a new state with # fields from this state which are different in the other state, while # the identical fields are left None, so that we can send the minimum # diff to the worker to update the changed fields only. 
user_schema = reflection_cache = database_config = None if self.user_schema is not other.user_schema: user_schema = self.user_schema if self.reflection_cache is not other.reflection_cache: reflection_cache = self.reflection_cache if self.database_config is not other.database_config: database_config = self.database_config return PickledState(user_schema, reflection_cache, database_config) def get_estimated_size(self) -> int: rv = 0 if self.user_schema is not None: rv += len(self.user_schema) if self.reflection_cache is not None: rv += len(self.reflection_cache) if self.database_config is not None: rv += len(self.database_config) return rv class ClientSchema(NamedTuple): dbs: immutables.Map[str, PickledState] global_schema: Optional[bytes] instance_config: Optional[bytes] dropped_dbs: tuple def diff(self, other: ClientSchema) -> tuple[ClientSchema, int, int]: # Compare this schema with the other schema, generate a new schema with # fields from this schema which are different in the other schema, # while the identical fields are left None, so that we can send the # minimum diff to the worker to update the changed fields only. # NOTE: this is a deep diff that compares all children fields. 
dropped_dbs = tuple( dbname for dbname in other.dbs if dbname not in self.dbs ) added = 0 updated = 0 dbs: immutables.Map[str, PickledState] = immutables.Map() for dbname, state in self.dbs.items(): other_state = other.dbs.get(dbname) if other_state is None: dbs = dbs.set(dbname, state) added += 1 elif state is not other_state: dbs = dbs.set(dbname, state.diff(other_state)) updated += 1 global_schema = instance_config = None if self.global_schema is not other.global_schema: global_schema = self.global_schema if self.instance_config is not other.instance_config: instance_config = self.instance_config return ( ClientSchema(dbs, global_schema, instance_config, dropped_dbs), added, updated, ) def get_estimated_size(self) -> int: return sum(db.get_estimated_size() for db in self.dbs.values()) class Worker(pool_mod.BaseMultiTenantWorker[ClientSchema, "MultiSchemaPool"]): _last_pickled_state_id: int def _init(self) -> None: self._last_pickled_state_id = 0 def invalidate(self, client_id: int) -> None: if self._cache.pop(client_id, None): self._invalidated_clients.append(client_id) self._last_used_by_client.pop(client_id, None) def invalidate_last(self, cache_size: int) -> None: if len(self._cache) == cache_size: client_id, _ = self._cache.popitem(last=True) self._invalidated_clients.append(client_id) self._last_used_by_client.pop(client_id, None) async def call( self, method_name: str, *args: Any, sync_state: Optional[pool_mod.SyncStateCallback] = None, msg: Optional[bytes] = None, ) -> Any: assert not self._closed assert self._con is not None if self._con.is_closed(): raise RuntimeError( "the connection to the compiler worker process is " "unexpectedly closed" ) if msg is None: msg = pickle.dumps((method_name, args)) return await self._con.request(msg) class MultiSchemaPool( pool_mod.FixedPoolImpl[Worker, pool_mod.MultiTenantInitArgs] ): _worker_class = Worker _worker_mod = "multitenant_worker" _workers: dict[int, Worker] # type: ignore _catalog_version: Optional[int] 
_inited: asyncio.Event _clients: dict[int, ClientSchema] _client_names: dict[int, str] _secret: bytes def __init__(self, cache_size, *, secret, **kwargs): super().__init__(**kwargs) self._catalog_version = None self._inited = asyncio.Event() self._cache_size = cache_size self._clients = {} self._client_names = {} self._secret = secret def _init(self, kwargs: dict[str, Any]) -> None: # this is deferred to _init_server() pass @lru.method_cache def _get_init_args(self) -> tuple[pool_mod.MultiTenantInitArgs, bytes]: init_args = ( self._backend_runtime_params, self._std_schema, self._refl_schema, self._schema_class_layout, ) return init_args, pickle.dumps(init_args, -1) async def _attach_worker(self, pid: int) -> Optional[Worker]: if not self._running: return None if not self._inited.is_set(): await self._inited.wait() return await super()._attach_worker(pid) async def _wait_ready(self) -> None: pass async def _init_server( self, client_id: int, client_name: str, compiler_protocol: int, catalog_version: int, init_args_pickled: pool_mod.RemoteInitArgsPickle, ) -> None: if compiler_protocol > pool_mod.CURRENT_COMPILER_PROTOCOL: raise state_mod.IncompatibleClient("compiler_protocol") ( std_args_pickled, client_args_pickled, global_schema_pickle, system_config_pickled, ) = init_args_pickled backend_runtime_params, = pickle.loads(client_args_pickled) if self._inited.is_set(): logger.debug("New client %d connected.", client_id) assert self._catalog_version is not None if self._catalog_version != catalog_version: raise state_mod.IncompatibleClient("catalog_version") if self._backend_runtime_params != backend_runtime_params: raise state_mod.IncompatibleClient("backend_runtime_params") else: ( self._std_schema, self._refl_schema, self._schema_class_layout, ) = pickle.loads(std_args_pickled) self._backend_runtime_params = backend_runtime_params assert self._catalog_version is None self._catalog_version = catalog_version self._inited.set() logger.info( "New client %d connected, 
compiler server initialized.", client_id, ) self._clients[client_id] = ClientSchema( dbs=immutables.Map(), global_schema=global_schema_pickle, instance_config=system_config_pickled, dropped_dbs=(), ) self._client_names[client_id] = client_name def _sync( self, *, client_id: int, dbname: str, evicted_dbs: list[str], user_schema: Optional[bytes], reflection_cache: Optional[bytes], global_schema: Optional[bytes], database_config: Optional[bytes], system_config: Optional[bytes], ) -> bool: """Sync the client state in the compiler server. The client state is carried over with the compile(), compile_sql(), compile_notebook(), compile_graphql() calls. Returns True if the client state changed, False otherwise. """ # EdgeDB instance syncs the schema with the compiler server client = self._clients[client_id] client_updates: dict[str, Any] = {} dbs = client.dbs.mutate() dbs_changed = False if evicted_dbs: for name in evicted_dbs: if dbs.pop(name, None) is not None: dbs_changed = True db = dbs.get(dbname) if db is None: assert user_schema is not None assert reflection_cache is not None assert database_config is not None dbs[dbname] = PickledState( user_schema, reflection_cache, database_config ) dbs_changed = True else: updates = {} if user_schema is not None: updates["user_schema"] = user_schema if reflection_cache is not None: updates["reflection_cache"] = reflection_cache if database_config is not None: updates["database_config"] = database_config if updates: db = db._replace(**updates) dbs[dbname] = db dbs_changed = True if global_schema is not None: client_updates["global_schema"] = global_schema if system_config is not None: client_updates["instance_config"] = system_config if dbs_changed: client_updates["dbs"] = dbs.finish() if client_updates: self._clients[client_id] = client._replace(**client_updates) return True else: return False def _weighter(self, client_id: int, worker: Worker) -> queue.Comparable: client_schema = worker.get_tenant_schema(client_id, touch=False) 
return ( bool(client_schema), worker.last_used(client_id) if client_schema else self._cache_size - worker.cache_size(), ) async def _call_for_client( self, *, client_id: int, method_name: str, dbname: str, evicted_dbs: list[str], user_schema: Optional[bytes], reflection_cache: Optional[bytes], global_schema: Optional[bytes], database_config: Optional[bytes], system_config: Optional[bytes], args: tuple[Any, ...], msg: memoryview, ) -> Any: try: updated = self._sync( client_id=client_id, dbname=dbname, evicted_dbs=evicted_dbs, user_schema=user_schema, reflection_cache=reflection_cache, global_schema=global_schema, database_config=database_config, system_config=system_config, ) except Exception as ex: raise state_mod.FailedStateSync( f"failed to sync compiler server state: " f"{type(ex).__name__}({ex})" ) from ex worker = await self._acquire_worker( weighter=functools.partial(self._weighter, client_id) ) try: pid = str(worker.get_pid()) client_schema = self._clients[client_id] client_name = self._client_names[client_id] worker.set_client_name(client_id, client_name) diff: Optional[ClientSchema] = client_schema cache = worker.get_tenant_schema(client_id) extra_args: tuple[Any, ...] 
= () metrics.compiler_process_branch_actions.inc( 1, pid, client_name, 'request' ) if cache is client_schema: # client schema is already in sync, don't send again diff = None msg_arg = bytes(msg) metrics.compiler_process_client_actions.inc( 1, pid, 'cache-hit' ) metrics.compiler_process_branch_actions.inc( 1, pid, client_name, 'cache-hit' ) else: metrics.compiler_process_schema_size.set( client_schema.get_estimated_size(), pid, client_name ) metrics.compiler_process_branches.set( len(client_schema.dbs), pid, client_name ) if cache is None: # make room for the new client in this worker worker.invalidate_last(self._cache_size) metrics.compiler_process_branch_actions.inc( len(client_schema.dbs), pid, client_name, 'cache-add' ) metrics.compiler_process_client_actions.inc( 1, pid, 'cache-add' ) else: # only send the difference in user schema diff, dbs_added, dbs_updated = client_schema.diff(cache) if dbname not in diff.dbs: metrics.compiler_process_branch_actions.inc( 1, pid, client_name, 'cache-hit' ) if dbs_added: metrics.compiler_process_branch_actions.inc( dbs_added, pid, client_name, 'cache-add' ) if dbs_updated: metrics.compiler_process_branch_actions.inc( dbs_updated, pid, client_name, 'cache-update' ) if diff.dropped_dbs: metrics.compiler_process_branch_actions.inc( len(diff.dropped_dbs), pid, client_name, 'cache-evict', ) metrics.compiler_process_client_actions.inc( 1, pid, 'cache-update' ) if updated: # re-pickle the request if user schema changed msg_arg = None extra_args = (method_name, dbname, *args) else: msg_arg = bytes(msg) invalidation = worker.flush_invalidation() resp = await worker.call( "call_for_client", client_id, diff, invalidation, msg_arg, *extra_args, ) status, *data = pickle.loads(resp) if status == 0: worker.set_tenant_schema(client_id, client_schema) if method_name == "compile": _units, new_pickled_state = data[0] if new_pickled_state: sid = next_tx_state_id() worker._last_pickled_state_id = sid resp = pickle.dumps((0, (*data[0], sid)), -1) 
elif status == 1: exc, _tb = data if not isinstance(exc, state_mod.FailedStateSync): worker.set_tenant_schema(client_id, client_schema) else: exc = RuntimeError( "could not serialize result in worker subprocess" ) exc.__formatted_error__ = data[0] raise exc return resp finally: self._release_worker(worker) async def compile_in_tx_( self, state_id: int, client_id: Optional[int], dbname: Optional[str], user_schema_pickle: Optional[bytes], pickled_state: bytes, txid: int, *compile_args: Any, msg: bytes, ): if pickled_state == state_mod.REUSE_LAST_STATE_MARKER: worker = await self._acquire_worker( condition=lambda w: (w._last_pickled_state_id == state_id) ) if worker._last_pickled_state_id != state_id: self._release_worker(worker) raise state_mod.StateNotFound() else: worker = await self._acquire_worker() try: resp = await worker.call( "compile_in_tx", state_id, client_id, dbname, user_schema_pickle, pickled_state, txid, *compile_args, msg=msg, ) status, *data = pickle.loads(resp) if status == 0: state_id = worker._last_pickled_state_id = next_tx_state_id() resp = pickle.dumps((0, (*data[0], state_id)), -1) return resp finally: self._release_worker(worker, put_in_front=False) async def _request(self, method_name: str, msg: bytes) -> Any: worker = await self._acquire_worker() try: return await worker.call(method_name, msg=msg) finally: self._release_worker(worker) async def handle_client_call( self, protocol: CompilerServerProtocol, req_id: int, msg: memoryview, ) -> None: client_id = protocol.client_id digest = msg[:32] msg = msg[32:] try: expected_digest = hmac.digest(self._secret, msg, "sha256") if not hmac.compare_digest(digest, expected_digest): raise AssertionError("message signature verification failed") method_name, args = pickle.loads(msg) if method_name != "__init_server__": await self._ready_evt.wait() if method_name == "__init_server__": await self._init_server(client_id, protocol.client_name, *args) pickled = pickle.dumps((0, None), -1) elif method_name in 
{ "compile", "compile_notebook", "compile_graphql", "compile_sql", }: ( dbname, evicted_dbs, user_schema, reflection_cache, global_schema, database_config, system_config, *args, ) = args pickled = await self._call_for_client( client_id=client_id, method_name=method_name, dbname=dbname, evicted_dbs=evicted_dbs, user_schema=user_schema, reflection_cache=reflection_cache, global_schema=global_schema, database_config=database_config, system_config=system_config, args=args, msg=msg, ) elif method_name == "compile_in_tx": pickled = await self.compile_in_tx_(*args, msg=bytes(msg)) else: pickled = await self._request(method_name, bytes(msg)) except Exception as ex: worker_proc.prepare_exception(ex) if debug.flags.server and not isinstance( ex, state_mod.StateNotFound ): markup.dump(ex) data = (1, ex, traceback.format_exc()) try: pickled = pickle.dumps(data, -1) except Exception as ex: ex_tb = traceback.format_exc() ex_str = f"{ex}:\n\n{ex_tb}" pickled = pickle.dumps((2, ex_str), -1) protocol.reply(req_id, pickled) def client_disconnected(self, client_id: int) -> None: logger.debug("Client %d disconnected, invalidating cache.", client_id) self._clients.pop(client_id, None) self._client_names.pop(client_id, None) for worker in self._workers.values(): worker.invalidate(client_id) class CompilerServerProtocol(asyncio.Protocol): _pool: MultiSchemaPool _loop: asyncio.AbstractEventLoop _stream: amsg.MessageStream _transport: Optional[asyncio.Transport] _client_id: int _client_name: str def __init__( self, pool: MultiSchemaPool, loop: asyncio.AbstractEventLoop, ) -> None: global _client_id_seq self._pool = pool self._loop = loop self._stream = amsg.MessageStream() self._transport = None self._client_id = _client_id_seq = _client_id_seq + 1 self._client_name = 'unknown' def connection_made(self, transport: asyncio.BaseTransport) -> None: self._transport = cast(asyncio.Transport, transport) self._transport.write( amsg._uint64_packer(os.getpid()) + amsg._uint64_packer(0) ) peername = 
transport.get_extra_info('peername') try: self._client_name = '%s:%d' % peername except Exception: self._client_name = str(peername) def connection_lost(self, exc: Optional[Exception]) -> None: self._transport = None self._pool.client_disconnected(self._client_id) def data_received(self, data: bytes) -> None: for msg in self._stream.feed_data(data): msgview = memoryview(msg) req_id = amsg._uint64_unpacker(msgview[:8])[0] self._loop.create_task( self._pool.handle_client_call(self, req_id, msgview[8:]) ) @property def client_id(self) -> int: return self._client_id @property def client_name(self) -> str: return self._client_name def reply(self, req_id: int, resp: bytes) -> None: if self._transport is None: return self._transport.write( b"".join( ( amsg._uint64_packer(len(resp) + 8), amsg._uint64_packer(req_id), resp, ) ) ) class MetricsProtocol(asyncio.Protocol): _pool: MultiSchemaPool transport: Optional[asyncio.Transport] parser: httptools.HttpRequestParser url: Optional[bytes] def __init__(self, pool: MultiSchemaPool) -> None: self._pool = pool self.transport = None self.parser = httptools.HttpRequestParser(self) self.url = None def connection_made(self, transport: asyncio.BaseTransport) -> None: self.transport = cast(asyncio.Transport, transport) def data_received(self, data: bytes) -> None: try: self.parser.feed_data(data) except Exception as ex: logger.exception(ex) def on_url(self, url: bytes) -> None: self.url = url def on_message_complete(self) -> None: match self.parser.get_method().upper(), self.url: case b"GET", b"/ready": self.respond("200 OK", "OK") case b"GET", b"/metrics": self._pool.refresh_metrics() self.respond( "200 OK", metrics.registry.generate(), "Content-Type: text/plain; version=0.0.4; charset=utf-8", ) case _: self.respond("404 Not Found", "Not Found") def respond( self, status: str, content: str, *extra_headers: str, encoding: str = "utf-8", ) -> None: content_bytes = content.encode(encoding) response = [ 
f"HTTP/{self.parser.get_http_version()} {status}", f"Content-Length: {len(content_bytes)}", *extra_headers, "", "", ] assert self.transport is not None self.transport.write("\r\n".join(response).encode("ascii")) self.transport.write(content_bytes) if not self.parser.should_keep_alive(): self.transport.close() async def server_main( listen_addresses: Sequence[str], listen_port: Optional[int], pool_size: int, client_schema_cache_size: int, runstate_dir: Optional[str | pathlib.Path], metrics_port: Optional[int], worker_max_rss: Optional[int], ): logsetup.setup_logging('i', 'stderr') if listen_port is None: listen_port = defines.EDGEDB_REMOTE_COMPILER_PORT if runstate_dir is None: temp_runstate_dir = tempfile.TemporaryDirectory(prefix='edbcompiler-') runstate_dir = temp_runstate_dir.name logger.debug("Using temporary runstate dir: %s", runstate_dir) else: temp_runstate_dir = None runstate_dir = str(runstate_dir) secret = os.environ.get("_EDGEDB_SERVER_COMPILER_POOL_SECRET") if not secret: logger.warning( "_EDGEDB_SERVER_COMPILER_POOL_SECRET is not set, " f"compilation requests will fail") secret = secrets.token_urlsafe() try: loop = asyncio.get_running_loop() pool = MultiSchemaPool( loop=loop, runstate_dir=runstate_dir, pool_size=pool_size, worker_branch_limit=0, # compiler server doesn't use this limit cache_size=client_schema_cache_size, secret=secret.encode(), worker_max_rss=worker_max_rss, ) await pool.start() try: async with asyncio.TaskGroup() as tg: tg.create_task( _run_server( loop, listen_addresses, listen_port, lambda: CompilerServerProtocol(pool, loop), "compile", ) ) if metrics_port: tg.create_task( _run_server( loop, listen_addresses, metrics_port, lambda: MetricsProtocol(pool), "metrics", ) ) finally: await pool.stop() finally: if temp_runstate_dir is not None: temp_runstate_dir.cleanup() async def _run_server( loop: asyncio.AbstractEventLoop, listen_addresses: Sequence[str], listen_port: int, protocol: Callable[[], asyncio.Protocol], purpose: str, ) -> 
None: server = await loop.create_server( protocol, listen_addresses, listen_port, start_serving=False, ) if len(listen_addresses) == 1: logger.info( "Listening for %s on %s:%s", purpose, listen_addresses[0], listen_port, ) else: logger.info( "Listening for %s on [%s]:%s", purpose, ",".join(listen_addresses), listen_port, ) try: await server.serve_forever() finally: server.close() await server.wait_closed() @click.command() @srvargs.compiler_options def main(**kwargs: Any) -> None: asyncio.run(server_main(**kwargs)) if __name__ == "__main__": main() ================================================ FILE: edb/server/compiler_pool/state.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2020-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import typing import immutables from edb.schema import schema from edb.server import config ReflectionCache = immutables.Map[str, tuple[str, ...]] class DatabaseState(typing.NamedTuple): name: str user_schema: schema.Schema reflection_cache: ReflectionCache database_config: immutables.Map[str, config.SettingValue] DatabasesState = immutables.Map[str, DatabaseState] class PickledDatabaseState(typing.NamedTuple): user_schema_pickle: bytes reflection_cache: ReflectionCache database_config: immutables.Map[str, config.SettingValue] def get_estimated_size(self) -> int: return ( len(self.user_schema_pickle) + len(self.reflection_cache) * 128 + len(self.database_config) * 128 ) class FailedStateSync(Exception): pass class StateNotFound(Exception): pass class IncompatibleClient(Exception): pass REUSE_LAST_STATE_MARKER = b'REUSE_LAST_STATE_MARKER' ================================================ FILE: edb/server/compiler_pool/worker.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import Any, Mapping, Optional import pickle import immutables from edb import edgeql from edb import graphql from edb.common import uuidgen from edb.pgsql import params as pgparams from edb.schema import schema as s_schema from edb.server import compiler from edb.server import config from edb.server import defines from . import state from . import worker_proc INITED: bool = False DBS: state.DatabasesState = immutables.Map() BACKEND_RUNTIME_PARAMS: pgparams.BackendRuntimeParams = \ pgparams.get_default_runtime_params() COMPILER: compiler.Compiler LAST_STATE: Optional[compiler.dbstate.CompilerConnectionState] = None LAST_STATE_PICKLE: Optional[bytes] = None STD_SCHEMA: s_schema.Schema GLOBAL_SCHEMA: s_schema.Schema INSTANCE_CONFIG: immutables.Map[str, config.SettingValue] def __init_worker__( init_args_pickled: bytes, ) -> None: global INITED global BACKEND_RUNTIME_PARAMS global COMPILER global STD_SCHEMA global GLOBAL_SCHEMA global INSTANCE_CONFIG ( backend_runtime_params, std_schema, refl_schema, schema_class_layout, global_schema_pickle, system_config, ) = pickle.loads(init_args_pickled) INITED = True BACKEND_RUNTIME_PARAMS = backend_runtime_params STD_SCHEMA = std_schema GLOBAL_SCHEMA = pickle.loads(global_schema_pickle) INSTANCE_CONFIG = system_config COMPILER = compiler.new_compiler( std_schema, refl_schema, schema_class_layout, backend_runtime_params=BACKEND_RUNTIME_PARAMS, config_spec=None, ) def __sync__( dbname: str, evicted_dbs: list[str], user_schema: Optional[bytes], reflection_cache: Optional[bytes], global_schema: Optional[bytes], database_config: Optional[bytes], system_config: Optional[bytes], ) -> state.DatabaseState: global DBS global GLOBAL_SCHEMA global INSTANCE_CONFIG try: if evicted_dbs: dbs = DBS.mutate() for name in evicted_dbs: dbs.pop(name, None) DBS = dbs.finish() db = DBS.get(dbname) if db is None: assert user_schema is not None assert reflection_cache is not None assert database_config is 
not None user_schema_unpacked = pickle.loads(user_schema) reflection_cache_unpacked = pickle.loads(reflection_cache) database_config_unpacked = pickle.loads(database_config) db = state.DatabaseState( dbname, user_schema_unpacked, reflection_cache_unpacked, database_config_unpacked, ) DBS = DBS.set(dbname, db) else: updates = {} if user_schema is not None: updates['user_schema'] = pickle.loads(user_schema) if reflection_cache is not None: updates['reflection_cache'] = pickle.loads(reflection_cache) if database_config is not None: updates['database_config'] = pickle.loads(database_config) if updates: db = db._replace(**updates) DBS = DBS.set(dbname, db) if global_schema is not None: GLOBAL_SCHEMA = pickle.loads(global_schema) if system_config is not None: INSTANCE_CONFIG = pickle.loads(system_config) except Exception as ex: raise state.FailedStateSync( f'failed to sync worker state: {type(ex).__name__}({ex})') from ex return db def compile( dbname: str, evicted_dbs: list[str], user_schema: Optional[bytes], reflection_cache: Optional[bytes], global_schema: Optional[bytes], database_config: Optional[bytes], system_config: Optional[bytes], *compile_args: Any, **compile_kwargs: Any, ): db = __sync__( dbname, evicted_dbs, user_schema, reflection_cache, global_schema, database_config, system_config, ) units, cstate = COMPILER.compile_serialized_request( db.user_schema, GLOBAL_SCHEMA, db.reflection_cache, db.database_config, INSTANCE_CONFIG, *compile_args, **compile_kwargs ) global LAST_STATE, LAST_STATE_PICKLE LAST_STATE = cstate LAST_STATE_PICKLE = None if cstate is not None: LAST_STATE_PICKLE = pickle.dumps(cstate, -1) return units, LAST_STATE_PICKLE def compile_in_tx( dbname: Optional[str], user_schema: Optional[bytes], cstate, *args, **kwargs ): global LAST_STATE, LAST_STATE_PICKLE prev_last_state_key = None if cstate == state.REUSE_LAST_STATE_MARKER: assert LAST_STATE is not None cstate = LAST_STATE prev_last_state_key = cstate.get_state_key() else: cstate = 
pickle.loads(cstate) LAST_STATE_PICKLE = None if dbname is None: assert user_schema is not None cstate.set_root_user_schema(pickle.loads(user_schema)) else: cstate.set_root_user_schema(DBS[dbname].user_schema) units, cstate = COMPILER.compile_serialized_request_in_tx( cstate, *args, **kwargs) LAST_STATE = cstate # We don't want to re-pickle the transaction state for every new # query in a transaction when the query doesn't actually change that # state, i.e. when it doesn't run DDL or configure new session aliases, # configs, or globals. if (prev_last_state_key is None or LAST_STATE_PICKLE is None or prev_last_state_key != cstate.get_state_key() ): LAST_STATE_PICKLE = pickle.dumps(cstate, -1) return units, LAST_STATE_PICKLE def compile_notebook( dbname: str, evicted_dbs: list[str], user_schema: Optional[bytes], reflection_cache: Optional[bytes], global_schema: Optional[bytes], database_config: Optional[bytes], system_config: Optional[bytes], *compile_args: Any, **compile_kwargs: Any, ): db = __sync__( dbname, evicted_dbs, user_schema, reflection_cache, global_schema, database_config, system_config, ) return COMPILER.compile_notebook( db.user_schema, GLOBAL_SCHEMA, db.reflection_cache, db.database_config, INSTANCE_CONFIG, *compile_args, **compile_kwargs ) def compile_graphql( dbname: str, evicted_dbs: list[str], user_schema: Optional[bytes], reflection_cache: Optional[bytes], global_schema: Optional[bytes], database_config: Optional[bytes], system_config: Optional[bytes], session_config: Mapping[str, Any], *compile_args: Any, **compile_kwargs: Any, ) -> tuple[compiler.QueryUnitGroup, graphql.TranspiledOperation]: db = __sync__( dbname, evicted_dbs, user_schema, reflection_cache, global_schema, database_config, system_config, ) gql_op = graphql.compile_graphql( STD_SCHEMA, db.user_schema, GLOBAL_SCHEMA, db.database_config, INSTANCE_CONFIG, *compile_args, **compile_kwargs ) source = edgeql.Source.from_string( edgeql.generate_source(gql_op.edgeql_ast,
pretty=True), ) cfg_ser = COMPILER.state.compilation_config_serializer request = compiler.CompilationRequest( source=source, protocol_version=defines.CURRENT_PROTOCOL, schema_version=uuidgen.uuid4(), compilation_config_serializer=cfg_ser, output_format=compiler.OutputFormat.JSON, input_format=compiler.InputFormat.JSON, expect_one=True, implicit_limit=0, inline_typeids=False, inline_typenames=False, inline_objectids=False, modaliases=None, session_config=session_config, ) unit_group, _ = COMPILER.compile( user_schema=db.user_schema, global_schema=GLOBAL_SCHEMA, reflection_cache=db.reflection_cache, database_config=db.database_config, system_config=INSTANCE_CONFIG, request=request, ) return unit_group, gql_op # type: ignore[return-value] def compile_sql( dbname: str, evicted_dbs: list[str], user_schema: Optional[bytes], reflection_cache: Optional[bytes], global_schema: Optional[bytes], database_config: Optional[bytes], system_config: Optional[bytes], *compile_args: Any, **compile_kwargs: Any, ): db = __sync__( dbname, evicted_dbs, user_schema, reflection_cache, global_schema, database_config, system_config, ) return COMPILER.compile_sql( db.user_schema, GLOBAL_SCHEMA, db.reflection_cache, db.database_config, INSTANCE_CONFIG, *compile_args, **compile_kwargs ) def get_handler(methname): if methname == "__init_worker__": meth = __init_worker__ else: if not INITED: raise RuntimeError( "call on uninitialized compiler worker" ) if methname == "compile": meth = compile elif methname == "compile_in_tx": meth = compile_in_tx elif methname == "compile_notebook": meth = compile_notebook elif methname == "compile_graphql": meth = compile_graphql elif methname == "compile_sql": meth = compile_sql else: meth = getattr(COMPILER, methname) return meth if __name__ == "__main__": try: worker_proc.main(get_handler) except KeyboardInterrupt: pass ================================================ FILE: edb/server/compiler_pool/worker_proc.py 
================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2022-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import argparse import gc import os import pickle import signal import sys import time import traceback from edb.common import debug from edb.common import devmode from edb.common import markup from edb.common import lru from edb.edgeql import parser as ql_parser from . import amsg # "created continuously" means the interval between two consecutive spawns # is less than NUM_SPAWNS_RESET_INTERVAL seconds. NUM_SPAWNS_RESET_INTERVAL = 1 def worker(sockname, version_serial, get_handler): con = amsg.WorkerConnection(sockname, version_serial) try: for req_id, req in con.iter_request(): try: methname, args = pickle.loads(req) meth = get_handler(methname) except Exception as ex: prepare_exception(ex) if debug.flags.server: markup.dump(ex) data = (1, ex, traceback.format_exc()) else: try: res = meth(*args) data = (0, res) except Exception as ex: prepare_exception(ex) if debug.flags.server: markup.dump(ex) data = (1, ex, traceback.format_exc()) try: pickled = pickle.dumps(data, -1) except Exception as ex: ex_tb = traceback.format_exc() ex_str = f"{ex}:\n\n{ex_tb}" pickled = pickle.dumps((2, ex_str), -1) con.reply(req_id, pickled) # Now that we have responded, clear the compiler LRU # caches to avoid hanging onto heavy objects like schemas. 
lru.clear_lru_caches() finally: con.abort() def run_worker(sockname, version_serial, get_handler): with devmode.CoverageConfig.enable_coverage_if_requested(): worker(sockname, version_serial, get_handler) def prepare_exception(ex): clear_exception_frames(ex) if ex.__traceback__ is not None: ex.__traceback__ = ex.__traceback__.tb_next def clear_exception_frames(er): def _clear_exception_frames(er, visited): if er in visited: return er visited.add(er) traceback.clear_frames(er.__traceback__) if er.__cause__ is not None: er.__cause__ = _clear_exception_frames(er.__cause__, visited) if er.__context__ is not None: er.__context__ = _clear_exception_frames(er.__context__, visited) return er visited = set() _clear_exception_frames(er, visited) def listen_for_debugger(): if debug.flags.pydebug_listen: import debugpy debugpy.listen(38781) def main(get_handler): parser = argparse.ArgumentParser() parser.add_argument("--sockname") parser.add_argument("--numproc") parser.add_argument("--version-serial", type=int) args = parser.parse_args() sys.setrecursionlimit(2000) ql_parser.preload_spec() gc.freeze() listen_for_debugger() if args.numproc is None: # Run a single worker process run_worker(args.sockname, args.version_serial, get_handler) return numproc = int(args.numproc) assert numproc >= 1 # Abort the template process if more than `max_worker_spawns` # new workers are created continuously - it probably means the # worker cannot start correctly. 
max_worker_spawns = numproc * 2 children = set() continuous_num_spawns = 0 for _ in range(int(args.numproc)): # spawn initial workers if pid := os.fork(): # main process children.add(pid) continuous_num_spawns += 1 else: # child process break else: # main process - redirect SIGTERM to SystemExit and wait for children signal.signal(signal.SIGTERM, lambda *_: exit(os.EX_OK)) last_spawn_timestamp = time.monotonic() try: while children: pid, status = os.wait() children.remove(pid) ec = os.waitstatus_to_exitcode(status) if ec > 0 or -ec not in {0, signal.SIGINT}: # restart the child process if killed or ending abnormally, # unless we tried too many times continuously now = time.monotonic() if now - last_spawn_timestamp > NUM_SPAWNS_RESET_INTERVAL: continuous_num_spawns = 0 last_spawn_timestamp = now continuous_num_spawns += 1 if continuous_num_spawns > max_worker_spawns: # GOTCHA: we shouldn't return here because we need the # exception handler below to clean up the workers exit(os.EX_UNAVAILABLE) if pid := os.fork(): # main process children.add(pid) else: # child process break else: # main process - all children ended normally return except BaseException as e: # includes SystemExit and KeyboardInterrupt # main process - kill and wait for the remaining workers to exit try: signal.signal(signal.SIGTERM, signal.SIG_DFL) for pid in children: try: os.kill(pid, signal.SIGTERM) except OSError: pass try: while children: pid, status = os.wait() children.discard(pid) except OSError: pass finally: raise e # child process - clear the SIGTERM handler for potential Rust impl signal.signal(signal.SIGTERM, signal.SIG_DFL) run_worker(args.sockname, args.version_serial, get_handler) ================================================ FILE: edb/server/config/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2019-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any, Mapping, TypedDict import enum import immutables from edb import errors from edb.edgeql.qltypes import ConfigScope from .ops import OpCode, Operation, SettingValue from .ops import ( spec_to_json, to_json_obj, to_json, from_json, set_value, to_edgeql ) from .ops import value_from_json, value_to_json_value from .spec import ( Spec, FlatSpec, ChainedSpec, Setting, load_spec_from_schema, load_ext_spec_from_schema, load_ext_settings_from_schema, ) from .types import ConfigType, CompositeConfigType from .types import QueryCacheMode __all__ = ( 'lookup', 'Spec', 'FlatSpec', 'ChainedSpec', 'Setting', 'SettingValue', 'spec_to_json', 'to_json_obj', 'to_json', 'to_edgeql', 'from_json', 'set_value', 'value_from_json', 'value_to_json_value', 'ConfigScope', 'OpCode', 'Operation', 'ConfigType', 'CompositeConfigType', 'load_spec_from_schema', 'load_ext_spec_from_schema', 'load_ext_settings_from_schema', 'get_compilation_config', 'QueryCacheMode', 'ConState', 'ConStateType', ) # See edb/server/pgcon/connect.py for documentation of the types class ConStateType(enum.StrEnum): session_config = "C" backend_session_config = "B" command_line_argument = "A" environment_variable = "E" config_file = "F" class ConState(TypedDict): name: str value: Any type: ConStateType def lookup( name: str, *configs: Mapping[str, SettingValue], spec: Spec, allow_unrecognized: bool = False, ) -> Any: assert 
len(configs) > 0 try: setting = spec[name] except (KeyError, TypeError): if allow_unrecognized: return None else: raise errors.ConfigurationError( f'unrecognized configuration parameter {name!r}') for c in configs: try: setting_value = c[name] except KeyError: pass else: return setting_value.value else: return setting.default def get_compilation_config( config: Mapping[str, SettingValue], *, spec: Spec, ) -> immutables.Map[str, SettingValue]: return immutables.Map(( (k, v) for k, v in config.items() if k in spec if spec[k].affects_compilation )) def _serialize_val(v: object) -> object: if isinstance(v, frozenset): return [_serialize_val(x) for x in v] elif isinstance(v, CompositeConfigType): return v.to_json_value(redacted=True) else: return v def debug_serialize_config( cfg: Mapping[str, SettingValue], ) -> Any: return { name: {'redacted': True} if value.secret else _serialize_val(value.value) for name, value in cfg.items() } ================================================ FILE: edb/server/config/ops.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2019-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations import base64 import json from typing import ( Any, Callable, Optional, Iterable, Mapping, Collection, NamedTuple, TYPE_CHECKING, TypeGuard, ) import immutables from edb import errors from edb.common import enum from edb.common import typeutils from edb.ir import statypes from edb.edgeql import codegen as qlcodegen from edb.edgeql import qltypes from edb.schema import objects as s_obj from edb.schema import utils as s_utils from . import spec from . import types MAX_CONFIG_SET_SIZE = 128 class OpCode(enum.StrEnum): CONFIG_ADD = 'ADD' CONFIG_REM = 'REM' CONFIG_SET = 'SET' CONFIG_RESET = 'RESET' class SettingValue(NamedTuple): name: str value: Any source: str scope: qltypes.ConfigScope # We track this just so that we can redact secret values in our # debug endpoints. secret: bool = False if TYPE_CHECKING: SettingsMap = immutables.Map[str, SettingValue] def _issubclass[T_type: type]( typ: type | types.ConfigTypeSpec, parent: T_type ) -> TypeGuard[T_type]: return isinstance(typ, type) and issubclass(typ, parent) def coerce_single_value(setting: spec.Setting, value: Any) -> Any: if isinstance(setting.type, type) and isinstance(value, setting.type): return value elif (isinstance(value, str) and _issubclass(setting.type, statypes.Duration)): return statypes.Duration(value) elif (isinstance(value, (str, int)) and _issubclass(setting.type, statypes.ConfigMemory)): return statypes.ConfigMemory(value) elif (isinstance(value, str) and _issubclass(setting.type, statypes.EnumScalarType)): return setting.type(value) else: raise errors.ConfigurationError( f'invalid value type for the {setting.name!r} setting') def _check_object_set_uniqueness( setting: spec.Setting, objs: Iterable[types.CompositeConfigType] ) -> frozenset[types.CompositeConfigType]: """Check the unique constraints for an object set""" new_values = set() exclusive_keys: dict[tuple[str, str], Any] = {} for new_value in objs: tspec = new_value._tspec for name in tspec.fields: if 
(val := getattr(new_value, name, None)) is None: continue if (site := tspec.get_field_unique_site(name)): key = (site.name, name) current = exclusive_keys.setdefault(key, set()) if val in current: raise errors.ConstraintViolationError( f'{setting.type.__name__}.{name} ' f'violates exclusivity constraint' ) current.add(val) if new_value in new_values: raise errors.ConstraintViolationError( f'{setting.type.__name__} has no unique values' ) new_values.add(new_value) if len(new_values) > MAX_CONFIG_SET_SIZE: raise errors.ConfigurationError( f'invalid value for the ' f'{setting.name!r} setting: set is too large') return frozenset(new_values) def coerce_object_set( spec: spec.Spec, setting: spec.Setting, values: Any ) -> Any: assert isinstance(setting.type, types.ConfigTypeSpec) if not setting.set_of and len(values) > 1: raise errors.ConstraintViolationError( f'cannot have multiple values for single setting {setting.name!r}' ) return _check_object_set_uniqueness( setting, ( types.CompositeConfigType.from_pyvalue( jv, spec=spec, tspec=setting.type) for jv in values ), ) class Operation(NamedTuple): opcode: OpCode scope: qltypes.ConfigScope setting_name: str value: str | int | bool | Collection[str | int | bool | None] | None def get_setting(self, spec: spec.Spec) -> spec.Setting: try: return spec[self.setting_name] except KeyError: raise errors.ConfigurationError( f'unknown setting {self.setting_name!r}') from None def coerce_value( self, spec: spec.Spec, setting: spec.Setting, *, allow_missing: bool = False, ): if isinstance(setting.type, types.ConfigTypeSpec): try: if self.opcode is OpCode.CONFIG_SET: return coerce_object_set(spec, setting, self.value) else: return types.CompositeConfigType.from_pyvalue( self.value, spec=spec, tspec=setting.type, allow_missing=allow_missing, ) except (ValueError, TypeError): raise errors.ConfigurationError( f'invalid value type for the {setting.name!r} setting') elif setting.set_of: if self.value is None and allow_missing: return None 
elif not typeutils.is_container(self.value): raise errors.ConfigurationError( f'invalid value type for the ' f'{setting.name!r} setting') else: val = frozenset( coerce_single_value(setting, v) for v in self.value) # type: ignore if len(val) > MAX_CONFIG_SET_SIZE: raise errors.ConfigurationError( f'invalid value for the ' f'{setting.name!r} setting: set is too large') return val else: try: return coerce_single_value(setting, self.value) except errors.ConfigurationError: if self.value is None and allow_missing: return None else: raise def coerce_global_value( self, *, allow_missing: bool = False ) -> Optional[bytes]: if allow_missing and self.value is None: return None else: assert isinstance(self.value, str) b = base64.b64decode(self.value) # Input comes prefixed with length; if the length is -1, # the value has explicitly been set to {}. return b[4:] if b[:4] != b'\xff\xff\xff\xff' else None def apply( self, spec: spec.Spec, storage: SettingsMap, *, source: str | None = None, ) -> SettingsMap: allow_missing = ( self.opcode is OpCode.CONFIG_REM or self.opcode is OpCode.CONFIG_RESET ) if self.scope != qltypes.ConfigScope.GLOBAL: setting = self.get_setting(spec) value = self.coerce_value( spec, setting, allow_missing=allow_missing) else: setting = None value = self.coerce_global_value(allow_missing=allow_missing) if self.opcode is OpCode.CONFIG_SET: storage = self._set_value(storage, value, source=source) elif self.opcode is OpCode.CONFIG_RESET: try: storage = storage.delete(self.setting_name) except KeyError: pass elif self.opcode is OpCode.CONFIG_ADD: assert setting if not isinstance(setting.type, types.ConfigTypeSpec): raise errors.InternalServerError( f'unexpected CONFIGURE SET += on a primitive ' f'configuration parameter: {self.setting_name}' ) exist_setting = storage.get(self.setting_name) if exist_setting is not None: exist_value = exist_setting.value else: exist_value = setting.default new_value = _check_object_set_uniqueness( setting, list(exist_value) + 
[value]) storage = self._set_value(storage, new_value, source=source) elif self.opcode is OpCode.CONFIG_REM: assert setting if not isinstance(setting.type, types.ConfigTypeSpec): raise errors.InternalServerError( f'unexpected CONFIGURE SET -= on a primitive ' f'configuration parameter: {self.setting_name}' ) exist_setting = storage.get(self.setting_name) if exist_setting is not None: exist_value = exist_setting.value else: exist_value = setting.default new_value = exist_value - {value} storage = self._set_value(storage, new_value, source=source) return storage def _set_value( self, storage: SettingsMap, value: Any, *, source: str | None = None, ) -> SettingsMap: if source is None: if self.scope is qltypes.ConfigScope.INSTANCE: source = 'system override' elif self.scope is qltypes.ConfigScope.DATABASE: source = 'database' elif self.scope is qltypes.ConfigScope.SESSION: source = 'session' elif self.scope is qltypes.ConfigScope.GLOBAL: source = 'global' else: raise AssertionError(f'unexpected config scope: {self.scope}') return set_value( storage, self.setting_name, value, source=source, scope=self.scope, ) @classmethod def from_json(cls, json_value: str) -> Operation: op_str, scope_str, name, value = json.loads(json_value) return Operation( opcode=OpCode(op_str), scope=qltypes.ConfigScope(scope_str), setting_name=name, value=value, ) def spec_to_json(spec: spec.Spec): dct = {} for setting in spec.values(): if _issubclass(setting.type, str): typeid = s_obj.get_known_type_id('std::str') elif _issubclass(setting.type, bool): typeid = s_obj.get_known_type_id('std::bool') elif _issubclass(setting.type, int): typeid = s_obj.get_known_type_id('std::int64') elif _issubclass(setting.type, float): typeid = s_obj.get_known_type_id('std::float32') elif _issubclass(setting.type, types.ConfigType): typeid = setting.type.get_edgeql_typeid() elif _issubclass(setting.type, statypes.Duration): typeid = s_obj.get_known_type_id('std::duration') elif _issubclass(setting.type, 
statypes.ConfigMemory): typeid = s_obj.get_known_type_id('cfg::memory') elif _issubclass(setting.type, statypes.EnumScalarType): typeid = setting.type.get_edgeql_typeid() elif isinstance(setting.type, types.ConfigTypeSpec): typeid = types.CompositeConfigType.get_edgeql_typeid() else: raise RuntimeError( f'cannot serialize type for config setting {setting.name}') typemod = qltypes.TypeModifier.SingletonType if setting.set_of: typemod = qltypes.TypeModifier.SetOfType dct[setting.name] = { 'default': value_to_json_value(setting, setting.default), 'internal': setting.internal, 'system': setting.system, 'typeid': str(typeid), 'typemod': str(typemod), 'backend_setting': setting.backend_setting, 'report': setting.report, } return json.dumps(dct) def value_to_json_value(setting: spec.Setting, value: Any): if setting.set_of: if isinstance(setting.type, types.ConfigTypeSpec): return [v.to_json_value() for v in value] else: return list(value) else: if isinstance(setting.type, types.ConfigTypeSpec): # We always store objects as list at the top-level, even # if they are single, because it simplifies things in the # config handling SQL. 
return [value.to_json_value()] if value is not None else [] elif (_issubclass(setting.type, statypes.ScalarType) and value is not None): return value.to_json() else: return value def value_from_json_value(spec: spec.Spec, setting: spec.Setting, value: Any): if setting.set_of: if isinstance(setting.type, types.ConfigTypeSpec): return frozenset( types.CompositeConfigType.from_pyvalue( v, spec=spec, tspec=setting.type, ) for v in value ) else: return frozenset(value) else: if isinstance(setting.type, types.ConfigTypeSpec): if not value: return None if len(value) > 1: raise errors.ConfigurationError( f'multiple entries for single object {setting.name}' ) return types.CompositeConfigType.from_pyvalue( value[0], spec=spec, tspec=setting.type, ) elif _issubclass(setting.type, statypes.Duration): return statypes.Duration.from_iso8601(value) elif _issubclass(setting.type, statypes.ConfigMemory): return statypes.ConfigMemory(value) elif _issubclass(setting.type, statypes.EnumScalarType): return setting.type(value) else: return value def value_from_json(spec, setting, value: str): return value_from_json_value(spec, setting, json.loads(value)) def value_to_edgeql_const( type: type | types.ConfigTypeSpec, value: Any, with_secrets: bool, ) -> str: ql = s_utils.const_ast_from_python(value, with_secrets=with_secrets) return qlcodegen.generate_source(ql) def to_json_obj( spec: spec.Spec, storage: Mapping[str, SettingValue], *, setting_filter: Optional[Callable[[SettingValue], bool]] = None, include_source: bool = True, ) -> dict[str, Any]: dct = {} for name, value in storage.items(): if setting_filter is None or setting_filter(value): setting = spec[name] val = value_to_json_value(setting, value.value) if include_source: dct[name] = { 'name': name, 'source': value.source, 'scope': str(value.scope), 'value': val, } else: dct[name] = val return dct def to_json( spec: spec.Spec, storage: Mapping[str, SettingValue], *, setting_filter: Optional[Callable[[SettingValue], bool]] = None, 
include_source: bool = True, ) -> str: dct = to_json_obj( spec, storage, setting_filter=setting_filter, include_source=include_source, ) return json.dumps(dct) def from_json(spec: spec.Spec, js: str | bytes) -> SettingsMap: base: SettingsMap = immutables.Map() with base.mutate() as mm: dct = json.loads(js) if not isinstance(dct, dict): raise errors.ConfigurationError( 'invalid JSON: top-level dict was expected') for key, value in dct.items(): setting = spec.get(key) if setting is None: # If the setting isn't in the spec, that's probably because # we've downgraded minor versions. Don't worry about it. continue mm[key] = SettingValue( name=key, value=value_from_json_value(spec, setting, value['value']), source=value['source'], scope=qltypes.ConfigScope(value['scope']), secret=setting.secret, ) return mm.finish() def to_edgeql( spec: spec.Spec, storage: Mapping[str, SettingValue], with_secrets: bool, ) -> str: stmts = [] for name, value in storage.items(): if name not in spec: continue setting = spec[name] if setting.secret and not with_secrets: continue if setting.protected: continue if isinstance(setting.type, types.ConfigTypeSpec): values = value.value if setting.set_of else [value.value] for x in values: # We look at the specific type of the object because # a subtype could have a secret that the parent doesn't. 
if x._tspec.has_secret and not with_secrets: continue val = value_to_edgeql_const( setting.type, x, with_secrets=with_secrets ) stmt = f'CONFIGURE {value.scope.to_edgeql()}\n{val};' stmts.append(stmt) else: val = value_to_edgeql_const( setting.type, value.value, with_secrets=with_secrets ) stmt = f'CONFIGURE {value.scope.to_edgeql()} SET {name} := {val};' stmts.append(stmt) return '\n'.join(stmts) def set_value( storage: SettingsMap, name: str, value: Any, source: str, scope: qltypes.ConfigScope, ) -> SettingsMap: secret = name in storage and storage[name].secret return storage.set( name, SettingValue(name=name, value=value, source=source, scope=scope, secret=secret), ) ================================================ FILE: edb/server/config/spec.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2019-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from abc import abstractmethod import collections.abc import dataclasses import json from typing import Any, Optional, Iterator, Sequence from edb.edgeql import compiler as qlcompiler from edb.ir import staeval from edb.ir import statypes from edb.schema import links as s_links from edb.schema import name as sn from edb.schema import objtypes as s_objtypes from edb.schema import scalars as s_scalars from edb.schema import schema as s_schema from edb.common.typeutils import downcast from . import types SETTING_TYPES = { str, int, bool, float, } @dataclasses.dataclass(frozen=True, eq=True) class Setting: name: str type: type | types.ConfigTypeSpec default: Any schema_type_name: Optional[sn.Name] = None set_of: bool = False system: bool = False internal: bool = False requires_restart: bool = False backend_setting: Optional[str] = None report: bool = False affects_compilation: bool = False enum_values: Optional[Sequence[str]] = None required: bool = True secret: bool = False protected: bool = False session_restricted: bool = False session_permission: Optional[str] = None def __post_init__(self) -> None: if ( self.type not in SETTING_TYPES and not isinstance(self.type, types.ConfigTypeSpec) and not ( isinstance(self.type, type) and issubclass(self.type, statypes.ScalarType) ) ): raise ValueError( f'invalid config setting {self.name!r}: ' f'type is expected to be either one of ' f'{{str, int, bool, float}} ' f'or an edb.server.config.types.ConfigType ' f'or edb.ir.statypes.ScalarType subclass') if self.set_of: if not isinstance(self.default, frozenset): raise ValueError( f'invalid config setting {self.name!r}: "SET OF" settings ' f'must have frozenset() as a default value, got ' f'{self.default!r}') if self.default: # SET OF settings shouldn't have non-empty defaults, # as otherwise there are multiple semantic ambiguities: # * Can a user add a new element to the set? # * What happens if a user discards all elements from the set?
# Does the set become non-empty because the default would # propagate? # * etc. raise ValueError( f'invalid config setting {self.name!r}: "SET OF" settings ' f'should not have defaults') else: if ( not self.backend_setting and isinstance(self.type, type) and ( (self.default and not isinstance(self.default, self.type)) or (self.default is None and self.required) ) ): raise ValueError( f'invalid config setting {self.name!r}: ' f'the default {self.default!r} ' f'is not instance of {self.type}') if self.report and not self.system: raise ValueError('only instance settings can be reported') class Spec(collections.abc.Mapping): @abstractmethod def get_type_by_name(self, name: str) -> types.ConfigTypeSpec: raise NotImplementedError @abstractmethod def __iter__(self) -> Iterator[str]: raise NotImplementedError @abstractmethod def __getitem__(self, name: str) -> Setting: raise NotImplementedError @abstractmethod def __contains__(self, name: object) -> bool: raise NotImplementedError @abstractmethod def __len__(self) -> int: raise NotImplementedError class FlatSpec(Spec): def __init__(self, *settings: Setting): self._settings = tuple(settings) self._by_name = {s.name: s for s in self._settings} self._types_by_name: dict[str, types.ConfigTypeSpec] = {} for s in self._settings: if isinstance(s.type, types.ConfigTypeSpec): self._register_type(s.type) def _register_type(self, t: types.ConfigTypeSpec) -> None: self._types_by_name[t.name] = t for subclass in t.children: self._register_type(downcast(types.ConfigTypeSpec, subclass)) for field in t.fields.values(): f_type = field.type if isinstance(f_type, types.ConfigTypeSpec): self._register_type(f_type) def get_type_by_name(self, name: str) -> types.ConfigTypeSpec: return self._types_by_name[name] def __iter__(self) -> Iterator[str]: return iter(self._by_name) def __getitem__(self, name: str) -> Setting: return self._by_name[name] def __contains__(self, name: object) -> bool: return name in self._by_name def __len__(self) -> int: 
return len(self._settings) class ChainedSpec(Spec): def __init__(self, base: Spec, top: Spec): self._base = base self._top = top def get_type_by_name(self, name: str) -> types.ConfigTypeSpec: try: return self._top.get_type_by_name(name) except KeyError: return self._base.get_type_by_name(name) def __iter__(self) -> Iterator[str]: yield from self._top yield from self._base def __getitem__(self, name: str) -> Setting: if name in self._top: return self._top[name] else: return self._base[name] def __contains__(self, name: object) -> bool: return name in self._top or name in self._base def __len__(self) -> int: return len(self._top) + len(self._base) def load_spec_from_schema( schema: s_schema.Schema, only_exts: bool=False, validate: bool=True, ) -> Spec: settings = [] if not only_exts: cfg = schema.get('cfg::Config', type=s_objtypes.ObjectType) settings.extend(_load_spec_from_type(schema, cfg)) settings.extend(load_ext_settings_from_schema(schema)) # Make sure there aren't any dangling ConfigObject children if validate: cfg_object = schema.get('cfg::ConfigObject', type=s_objtypes.ObjectType) for child in cfg_object.children(schema): if not schema.get_referrers( child, scls_type=s_links.Link, field_name='target' ): raise RuntimeError( f'cfg::ConfigObject child {child.get_name(schema)} has no ' f'links pointing at it (did you mean cfg::ExtensionConfig?)' ) return FlatSpec(*settings) def load_ext_settings_from_schema(schema: s_schema.Schema) -> list[Setting]: settings = [] ext_cfg = schema.get('cfg::ExtensionConfig', type=s_objtypes.ObjectType) for ecfg in ext_cfg.descendants(schema): if not ecfg.get_abstract(schema): settings.extend(_load_spec_from_type(schema, ecfg)) return settings def load_ext_spec_from_schema( user_schema: s_schema.Schema, std_schema: s_schema.Schema, ) -> Spec: schema = s_schema.ChainedSchema( std_schema, user_schema, s_schema.EMPTY_SCHEMA ) return load_spec_from_schema(schema, only_exts=True) def _load_spec_from_type(
schema: s_schema.Schema, cfg: s_objtypes.ObjectType ) -> list[Setting]: settings = [] cfg_name = str(cfg.get_name(schema)) is_root = cfg_name == 'cfg::Config' for ptr_name, p in cfg.get_pointers(schema).items(schema): pn = str(ptr_name) if pn in ('id', '__type__') or p.get_computable(schema): continue ptype = p.get_target(schema) assert ptype # Skip backlinks to the base object. They will get plenty of # special treatment. if str(ptype.get_name(schema)) == 'cfg::AbstractConfig': continue pytype: type | types.ConfigTypeSpec if isinstance(ptype, s_objtypes.ObjectType): pytype = staeval.object_type_to_spec( ptype, schema, spec_class=types.ConfigTypeSpec, ) elif isinstance(ptype, s_scalars.ScalarType): pytype = staeval.scalar_type_to_python_type(ptype, schema) else: raise RuntimeError(f"unsupported config value type: {ptype}") attributes = {} for a, v in p.get_annotations(schema).items(schema): if isinstance(a, sn.QualName) and a.module == 'cfg': try: jv = json.loads(v.get_value(schema)) except json.JSONDecodeError: raise RuntimeError( f'Config annotation {a} on {p.get_name(schema)} ' f'is not valid json' ) attributes[a] = jv ptr_card = p.get_cardinality(schema) set_of = ptr_card.is_multi() backend_setting = attributes.get( sn.QualName('cfg', 'backend_setting'), None) required = p.get_required(schema) deflt_expr = p.get_default(schema) if deflt_expr is not None: deflt = qlcompiler.evaluate_to_python_val( deflt_expr.text, schema=schema) if set_of and not isinstance(deflt, frozenset): deflt = frozenset((deflt,)) else: if set_of: deflt = frozenset() elif backend_setting is None and required: raise RuntimeError(f'cfg::Config.{pn} has no default') else: deflt = None if not is_root: pn = f'{cfg_name}::{pn}' session_cfg_permissions = attributes.get( sn.QualName('cfg', 'session_cfg_permissions'), None ) session_restricted = session_cfg_permissions != '*' setting = Setting( pn, type=pytype, schema_type_name=ptype.get_name(schema), set_of=set_of,
internal=attributes.get(sn.QualName('cfg', 'internal'), False), system=attributes.get(sn.QualName('cfg', 'system'), False), requires_restart=attributes.get( sn.QualName('cfg', 'requires_restart'), False), backend_setting=backend_setting, report=attributes.get( sn.QualName('cfg', 'report'), None), affects_compilation=attributes.get( sn.QualName('cfg', 'affects_compilation'), False), default=deflt, enum_values=( ptype.get_enum_values(schema) if isinstance(ptype, s_scalars.ScalarType) else None ), required=required, secret=p.get_secret(schema), protected=p.get_protected(schema), session_restricted=session_restricted, session_permission=( session_cfg_permissions if session_restricted else None ), ) settings.append(setting) return settings ================================================ FILE: edb/server/config/types.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2019-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any, TYPE_CHECKING, TypeGuard import enum import platform from edb import errors from edb.common import typeutils from edb.common import typing_inspect from edb.schema import objects as s_obj from edb.ir import statypes if TYPE_CHECKING: from . 
import spec def _issubclass[T_type: type]( typ: type | statypes.CompositeTypeSpec, parent: T_type ) -> TypeGuard[T_type]: return isinstance(typ, type) and issubclass(typ, parent) class ConfigTypeSpec(statypes.CompositeTypeSpec): def __call__(self, **kwargs) -> CompositeConfigType: return CompositeConfigType(self, **kwargs) def from_pyvalue( self, v, *, spec, allow_missing=False ) -> CompositeConfigType: return CompositeConfigType.from_pyvalue( v, tspec=self, spec=spec, allow_missing=allow_missing ) class ConfigType: @classmethod def from_pyvalue(cls, v, *, tspec, spec, allow_missing=False): """Subclasses override this to allow creation from Python scalars.""" raise NotImplementedError @classmethod def from_json_value(cls, v, *, tspec, spec): raise NotImplementedError def to_json_value(self): raise NotImplementedError @classmethod def get_edgeql_typeid(cls): raise NotImplementedError class CompositeConfigType(ConfigType, statypes.CompositeType): _compare_keys: tuple[str, ...] def __init__(self, tspec: statypes.CompositeTypeSpec, **kwargs) -> None: object.__setattr__(self, '_tspec', tspec) for f in tspec.fields.values(): if f.name in kwargs: object.__setattr__(self, f.name, kwargs[f.name]) elif f.default is not statypes.MISSING: object.__setattr__(self, f.name, f.default) object.__setattr__(self, '_compare_keys', tuple( f.name for f in tspec.fields.values() if f.unique )) def __setattr__(self, k, v) -> None: raise TypeError(f"{self._tspec.name} is immutable") def __eq__(self, rhs: Any) -> bool: if ( not isinstance(rhs, CompositeConfigType) or self._tspec != rhs._tspec ): return NotImplemented compare_keys = self._compare_keys return ( tuple(getattr(self, k) for k in compare_keys) == tuple(getattr(rhs, k) for k in compare_keys) ) def __hash__(self) -> int: return hash(tuple(getattr(self, k) for k in self._compare_keys)) def __repr__(self) -> str: body = ', '.join( f'{f.name}={getattr(self, f.name)!r}' for f in self._tspec.fields.values() if hasattr(self, f.name) ) 
return f'{self._tspec.name}({body})' @classmethod def from_pyvalue( cls, data, *, tspec: statypes.CompositeTypeSpec, spec: spec.Spec, allow_missing=False, ) -> CompositeConfigType: if allow_missing and data is None: return None # type: ignore if not isinstance(data, dict): raise cls._err(tspec, f'expected a dict value, got {type(data)!r}') data = dict(data) tname = data.pop('_tname', None) if tname is not None: tspec = spec.get_type_by_name(tname) assert tspec fields = tspec.fields items = {} inv_keys = [] for fieldname, value in data.items(): field = fields.get(fieldname) if field is None: if value is None: # This may happen when data is produced by # a polymorphic config query. pass else: inv_keys.append(fieldname) continue f_type = field.type if value is None: # Config queries return empty pointer values as None. continue if typing_inspect.is_generic_type(f_type): container = typing_inspect.get_origin(f_type) if container not in (frozenset, list): raise RuntimeError( f'invalid type annotation on ' f'{tspec.name}.{fieldname}: ' f'{f_type!r} is not supported') eltype = typing_inspect.get_args(f_type, evaluate=True)[0] if isinstance(value, eltype): value = container((value,)) elif (typeutils.is_container(value) and all(isinstance(v, eltype) for v in value)): value = container(value) else: raise cls._err( tspec, f'invalid {fieldname!r} field value: expecting ' f'{eltype.__name__} or a list thereof, but got ' f'{type(value).__name__}' ) elif (isinstance(f_type, ConfigTypeSpec) and isinstance(value, dict)): tname = value.get('_tname', None) if tname is not None: actual_f_type = spec.get_type_by_name(tname) else: actual_f_type = f_type value['_tname'] = f_type.name value = cls.from_pyvalue(value, tspec=actual_f_type, spec=spec) elif ( _issubclass(f_type, statypes.Duration) and isinstance(value, str) ): value = statypes.Duration.from_iso8601(value) elif ( _issubclass(f_type, statypes.ConfigMemory) and isinstance(value, str | int) ): value = statypes.ConfigMemory(value) 
elif ( _issubclass(f_type, statypes.EnumScalarType) and isinstance(value, str) ): value = f_type(value) elif not isinstance(f_type, type) or not isinstance(value, f_type): raise cls._err( tspec, f'invalid {fieldname!r} field value: expecting ' f'{f_type.__name__}, but got {type(value).__name__}' ) items[fieldname] = value if inv_keys: sinv_keys = ', '.join(repr(r) for r in inv_keys) raise cls._err(tspec, f'unknown fields: {sinv_keys}') for fieldname, field in fields.items(): if fieldname not in items and field.default is statypes.MISSING: if allow_missing: items[fieldname] = None else: raise cls._err( tspec, f'missing required field: {fieldname!r}' ) try: return cls(tspec, **items) except (TypeError, ValueError) as ex: raise cls._err(tspec, str(ex)) @classmethod def get_edgeql_typeid(cls): return s_obj.get_known_type_id('std::json') @classmethod def from_json_value(cls, s, *, tspec: statypes.CompositeTypeSpec, spec): return cls.from_pyvalue(s, tspec=tspec, spec=spec) def to_json_value(self, redacted: bool = False): dct = {} dct['_tname'] = self._tspec.name for f in self._tspec.fields.values(): f_type = f.type value = getattr(self, f.name) if redacted and f.secret and value is not None: value = {'redacted': True} elif (isinstance(f_type, statypes.CompositeTypeSpec) and value is not None): value = value.to_json_value(redacted=redacted) elif typing_inspect.is_generic_type(f_type): value = list(value) if value is not None else [] elif (_issubclass(f_type, statypes.ScalarType) and value is not None): value = value.to_json() dct[f.name] = value return dct @classmethod def _err( cls, tspec: statypes.CompositeTypeSpec, msg: str ) -> errors.ConfigurationError: return errors.ConfigurationError( f'invalid {tspec.name.lower()!r} value: {msg}') class QueryCacheMode(enum.StrEnum): InMemory = "InMemory" RegInline = "RegInline" PgFunc = "PgFunc" Default = "Default" @classmethod def effective(cls, value: str | None) -> QueryCacheMode: if value is None: rv = cls.Default else: rv = 
cls(value) if rv is QueryCacheMode.Default: # Persistent cache disabled for now by default on arm64 linux # because of observed problems in CI test runs. if platform.system() == 'Linux' and platform.machine() == 'arm64': rv = QueryCacheMode.InMemory else: rv = QueryCacheMode.PgFunc return rv ================================================ FILE: edb/server/connpool/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2020-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import os from .pool import Pool as Pool1Impl, _NaivePool # NoQA from .pool2 import Pool as Pool2Impl # During the transition period we allow for the pool to be swapped out. The # current default is to use the old pool, however this will be switched to use # the new pool once we've fully implemented all required features. if os.environ.get("EDGEDB_USE_NEW_CONNPOOL", "") == "1": Pool = Pool2Impl Pool2 = Pool1Impl else: # The two pools have the same effective type shape Pool = Pool1Impl # type: ignore Pool2 = Pool2Impl # type: ignore __all__ = ('Pool', 'Pool2') ================================================ FILE: edb/server/connpool/config.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2020-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import logging MIN_CONN_TIME_THRESHOLD = 0.01 MIN_QUERY_TIME_THRESHOLD = 0.001 MIN_LOG_TIME_THRESHOLD = 1 MIN_IDLE_TIME_BEFORE_GC = 120 CONNECT_FAILURE_RETRIES = 3 STATS_COLLECT_INTERVAL = 0.1 logger = logging.getLogger("edb.server") ================================================ FILE: edb/server/connpool/pool.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2020-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import typing import asyncio import collections import dataclasses import time from . import rolavg from . 
import config from .config import logger CP1 = typing.TypeVar('CP1', covariant=True) CP2 = typing.TypeVar('CP2', contravariant=True) class Connector(typing.Protocol[CP1]): def __call__(self, dbname: str) -> typing.Awaitable[CP1]: pass class Disconnector(typing.Protocol[CP2]): def __call__(self, conn: CP2) -> typing.Awaitable[None]: pass class StatsCollector(typing.Protocol): def __call__(self, stats: Snapshot) -> None: pass @dataclasses.dataclass class BlockSnapshot: dbname: str nwaiters_avg: int nconns: int npending: int nwaiters: int quota: int @dataclasses.dataclass class SnapshotLog: timestamp: float event: str dbname: str value: int @dataclasses.dataclass class Snapshot: timestamp: float capacity: int blocks: list[BlockSnapshot] log: list[SnapshotLog] failed_connects: int failed_disconnects: int successful_connects: int successful_disconnects: int @dataclasses.dataclass class ConnectionState: in_use_since: float = 0 in_use: bool = False in_stack_since: float = 0 class Block[C]: # A Block holds a number of connections to the same backend database. # A Pool consists of one or more blocks; blocks are the basic unit of # connection pool algorithm, while the pool itself also takes care of # balancing resources between Blocks (because all the blocks share the same # PostgreSQL `max_connections` limit), based on realtime statistics the # block collected and populated. # # Instead of the regular round-robin queue, EdgeDB adopted an LIFO stack # (conn_stack) for the connections - the most recently used connections are # always yielded first. This allows us to run "garbage collection" or # "connection stealing" to recycle the unused connections from the bottom # of the stack (the least recently used ones), so that other blocks could # reuse the spared resource. # # Block is coroutine-safe. Multiple tasks acquiring connections will be put # in a waiters' queue (conn_waiters), if the demand cannot be fulfilled # immediately without blocking/awaiting. 
When connections are ready in the # stack, the next task in the queue will be woken up to continue. loop: asyncio.AbstractEventLoop dbname: str conns: dict[C, ConnectionState] quota: int pending_conns: int last_connect_timestamp: float conn_acquired_num: int conn_waiters_num: int conn_waiters: collections.deque[asyncio.Future[None]] conn_stack: collections.deque[C] connect_failures_num: int querytime_avg: rolavg.RollingAverage nwaiters_avg: rolavg.RollingAverage suppressed: bool _cached_calibrated_demand: float _is_log_batching: bool _last_log_timestamp: float _log_events: dict[str, int] def __init__( self, dbname: str, loop: asyncio.AbstractEventLoop, ) -> None: self.dbname = dbname self.conns = {} self.quota = 1 self.pending_conns = 0 self.last_connect_timestamp = 0 self.loop = loop self.conn_acquired_num = 0 self.conn_waiters_num = 0 self.conn_waiters = collections.deque() self.conn_stack = collections.deque() self.connect_failures_num = 0 self.querytime_avg = rolavg.RollingAverage(history_size=20) self.nwaiters_avg = rolavg.RollingAverage(history_size=3) self.suppressed = False self._is_log_batching = False self._last_log_timestamp = 0 self._log_events = {} def count_conns(self) -> int: # The total number of connections in this block, including: # - Future connections that are still pending in connecting # - Idle connections in the stack # - Acquired connections (not in the stack) return len(self.conns) + self.pending_conns def count_waiters(self) -> int: # The number of tasks that are blocked in acquire() return self.conn_waiters_num def count_queued_conns(self) -> int: # Number of connections in the stack/pool return len(self.conn_stack) def count_pending_conns(self) -> int: # Number of future connections that are still pending in connecting return self.pending_conns def count_conns_over_quota(self) -> int: # How many connections over the quota return max(self.count_conns() - self.quota, 0) def count_approx_available_conns(self) -> int: # It's approximate 
because there might be a race when a connection # is being returned to the pool but not yet acquired by a waiter, # in which case the number isn't going to be accurate. return max( self.count_conns() - self.conn_acquired_num - self.conn_waiters_num, 0 ) def inc_acquire_counter(self) -> None: self.conn_acquired_num += 1 def dec_acquire_counter(self) -> None: self.conn_acquired_num -= 1 def try_steal( self, only_older_than: typing.Optional[float] = None ) -> typing.Optional[C]: # Try to take one unused connection from the block without blocking. # If only_older_than is provided, only the connection that was put in # the stack before the given timestamp is returned. None will be # returned if we cannot find such connection in this block. if not self.conn_stack: return None if only_older_than is not None: # We only need to check the bottom of the stack - higher items in # the stack only have larger timestamps oldest_conn = self.conn_stack[0] if self.conns[oldest_conn].in_stack_since > only_older_than: return None return self.conn_stack.popleft() async def try_acquire(self, *, attempts: int = 1) -> typing.Optional[C]: self.conn_waiters_num += 1 try: # Skip the waiters' queue if we can grab a connection from the # stack immediately - this is not completely fair, but it's # extremely hard to always take the shortcut and starve the queue # without blocking the main loop, so we are fine here. (This is # also how asyncio.Queue is implemented.) if not self.conn_stack: waiter = self.loop.create_future() if attempts > 1: # If the waiter was woken up only to discover that # it needs to wait again, we don't want it to lose # its place in the waiters queue. self.conn_waiters.appendleft(waiter) else: # On the first attempt the waiter goes to the end # of the waiters queue. 
self.conn_waiters.append(waiter) try: await waiter except Exception: if not waiter.done(): waiter.cancel() try: self.conn_waiters.remove(waiter) except ValueError: # The waiter could be removed from self.conn_waiters # by a previous release() call. pass if self.conn_stack and not waiter.cancelled(): # We were woken up by release(), but can't take # the call. Wake up the next in line. self._wakeup_next_waiter() raise # There can be a race between a waiter scheduled to wake up # and a connection being stolen (due to quota being enforced, # for example), in which case the waiter might finally get # woken up with an empty queue -- hence the 'try'. # acquire will put a while loop around this # Yield the most recently used connection from the top of the stack if self.conn_stack: return self.conn_stack.pop() else: return None finally: self.conn_waiters_num -= 1 async def acquire(self) -> C: attempts = 1 while (c := await self.try_acquire(attempts=attempts)) is None: attempts += 1 return c def release(self, conn: C) -> None: # Put the connection (back) to the top of the stack, self.conn_stack.append(conn) # refresh the timestamp, self.conns[conn].in_stack_since = time.monotonic() # and call the queue.
self._wakeup_next_waiter() def abort_waiters(self, e: Exception) -> None: # Propagate the given exception to all tasks that are waiting in # acquire() - this usually means the underlying connect() is failing while self.conn_waiters: waiter = self.conn_waiters.popleft() if not waiter.done(): waiter.set_exception(e) def _wakeup_next_waiter(self) -> None: while self.conn_waiters: waiter = self.conn_waiters.popleft() if not waiter.done(): waiter.set_result(None) break def log_connection(self, event: str, timestamp: float = 0) -> None: if not timestamp: timestamp = time.monotonic() # Add to the backlog if we're in batching, regardless of the time if self._is_log_batching: self._log_events[event] = self._log_events.setdefault(event, 0) + 1 # Time check only if we're not in batching elif timestamp - self._last_log_timestamp > \ config.MIN_LOG_TIME_THRESHOLD: logger.info( "Connection %s to backend database: %s", event, self.dbname ) self._last_log_timestamp = timestamp # Start batching if logging is too frequent, add timer only once here else: self._is_log_batching = True self._log_events = {event: 1} self.loop.call_later( config.MIN_LOG_TIME_THRESHOLD, self._log_batched_conns, ) def _log_batched_conns(self) -> None: logger.info( "Backend connections to database %s: %s " "in at least the last %.1f seconds.", self.dbname, ', '.join( f'{num} were {event}' for event, num in self._log_events.items() ), config.MIN_LOG_TIME_THRESHOLD, ) self._is_log_batching = False self._last_log_timestamp = time.monotonic() class BasePool[C]: _connect_cb: Connector[C] _disconnect_cb: Disconnector[C] _stats_cb: typing.Optional[StatsCollector] _max_capacity: int # total number of connections allowed in the pool _cur_capacity: int # counter of all connections (with pending) in the pool _loop: typing.Optional[asyncio.AbstractEventLoop] _current_snapshot: typing.Optional[Snapshot] _blocks: collections.OrderedDict[str, Block[C]] # Mapping from dbname to the Block instances, also used as a queue in a 
# starving situation when the blocks are fed with connections in a round- # robin fashion, see also Pool._tick(). _is_starving: bool # Indicates if any block is starving for connections, this usually means # the number of active blocks is greater than the pool max capacity. _failed_connects: int _failed_disconnects: int _successful_connects: int _successful_disconnects: int _conntime_avg: rolavg.RollingAverage def __init__( self, *, connect: Connector[C], disconnect: Disconnector[C], max_capacity: int, stats_collector: typing.Optional[StatsCollector]=None, ) -> None: self._connect_cb = connect self._disconnect_cb = disconnect self._stats_cb = stats_collector self._max_capacity = max_capacity self._cur_capacity = 0 self._loop = None self._current_snapshot = None self._blocks = collections.OrderedDict() self._is_starving = False self._running = True self._failed_connects = 0 self._failed_disconnects = 0 self._successful_connects = 0 self._successful_disconnects = 0 self._conntime_avg = rolavg.RollingAverage(history_size=10) async def close(self) -> None: self._running = False @property def max_capacity(self) -> int: return self._max_capacity @property def current_capacity(self) -> int: return self._cur_capacity @property def failed_connects(self) -> int: return self._failed_connects @property def failed_disconnects(self) -> int: return self._failed_disconnects @property def active_conns(self) -> int: return self.current_capacity - self._get_pending_conns() def _get_pending_conns(self) -> int: return sum( block.count_pending_conns() for block in self._blocks.values() ) def _get_loop(self) -> asyncio.AbstractEventLoop: if self._loop is None: self._loop = asyncio.get_running_loop() return self._loop def _build_snapshot(self, *, now: float) -> Snapshot: bstats: list[BlockSnapshot] = [] for block in self._blocks.values(): bstats.append( BlockSnapshot( dbname=block.dbname, nwaiters_avg=round(block.nwaiters_avg.avg()), nconns=len(block.conns), 
npending=block.count_pending_conns(), nwaiters=block.count_waiters(), quota=block.quota, ) ) bstats.sort(key=lambda b: b.dbname) return Snapshot( timestamp=now, blocks=bstats, capacity=self._cur_capacity, log=[], failed_connects=self._failed_connects, failed_disconnects=self._failed_disconnects, successful_connects=self._successful_connects, successful_disconnects=self._successful_disconnects, ) def _capture_snapshot(self, *, now: float) -> None: if self._stats_cb is None: return None assert self._current_snapshot is None self._current_snapshot = self._build_snapshot(now=now) def _report_snapshot(self) -> None: if self._stats_cb is None: return assert self._current_snapshot is not None self._stats_cb(self._current_snapshot) self._current_snapshot = None def _log_to_snapshot( self, *, dbname: str, event: str, value: int=0, now: float=0, ) -> None: if self._stats_cb is None: return if now == 0: now = time.monotonic() assert self._current_snapshot is not None self._current_snapshot.log.append( SnapshotLog( timestamp=now, dbname=dbname, event=event, value=value, ) ) def _new_block(self, dbname: str) -> Block[C]: assert dbname not in self._blocks block: Block[C] = Block(dbname, self._get_loop()) self._blocks[dbname] = block block.quota = 1 if self._is_starving: self._blocks.move_to_end(dbname, last=False) return block def _drop_block(self, block: Block[C]) -> None: assert not block.count_waiters() assert not block.count_conns() assert not block.quota self._blocks.pop(block.dbname) def _get_block(self, dbname: str) -> Block[C]: block = self._blocks.get(dbname) if block is None: block = self._new_block(dbname) return block async def _connect( self, block: Block[C], started_at: float, event: str ) -> None: logger.debug( "Establishing new connection to backend database: %s", block.dbname ) try: conn = await self._connect_cb(block.dbname) except Exception as e: self._failed_connects += 1 self._cur_capacity -= 1 logger.error( "Failed to establish a new connection to backend 
database: %s", block.dbname, exc_info=self._running, ) block.connect_failures_num += 1 if getattr(e, 'fields', {}).get('C') == '3D000': # 3D000 - INVALID CATALOG NAME, database does not exist # Skip retry and propagate the error immediately if block.connect_failures_num <= config.CONNECT_FAILURE_RETRIES: block.connect_failures_num = ( config.CONNECT_FAILURE_RETRIES + 1) if block.connect_failures_num > config.CONNECT_FAILURE_RETRIES: # Abort all waiters on this block and propagate the error, as # we don't have a mapping between waiters and _connect() tasks block.abort_waiters(e) else: # We must retry immediately here (without sleeping), or _tick() # will jump in and schedule more retries than what we expected. self._schedule_new_conn(block, event) return else: # reset the failure counter if we got the connection back block.connect_failures_num = 0 finally: ended_at = time.monotonic() self._conntime_avg.add(ended_at - started_at) block.pending_conns -= 1 self._successful_connects += 1 block.conns[conn] = ConnectionState() block.last_connect_timestamp = ended_at # Release the connection to block waiters. 
block.release(conn) block.log_connection(event, ended_at) async def _disconnect(self, conn: C, block: Block[C]) -> None: logger.debug( "Discarding a connection to backend database: %s", block.dbname ) try: await self._disconnect_cb(conn) except Exception: self._failed_disconnects += 1 raise else: self._successful_disconnects += 1 finally: self._cur_capacity -= 1 async def _transfer( self, from_block: Block[C], from_conn: C, to_block: Block[C], started_at: float, ) -> None: self._log_to_snapshot(dbname=from_block.dbname, event='transfer-from') await self._disconnect(from_conn, from_block) from_block.log_connection('transferred out') self._cur_capacity += 1 await self._connect(to_block, started_at, 'transferred in') def _schedule_transfer( self, from_block: Block[C], from_conn: C, to_block: Block[C], ) -> None: started_at = time.monotonic() assert not from_block.conns[from_conn].in_use from_block.conns.pop(from_conn) to_block.pending_conns += 1 if self._is_starving: self._blocks.move_to_end(to_block.dbname, last=True) self._blocks.move_to_end(from_block.dbname, last=True) self._get_loop().create_task( self._transfer(from_block, from_conn, to_block, started_at)) def _schedule_new_conn( self, block: Block[C], event: str = 'established' ) -> None: started_at = time.monotonic() self._cur_capacity += 1 block.pending_conns += 1 if self._is_starving: self._blocks.move_to_end(block.dbname, last=True) self._log_to_snapshot( dbname=block.dbname, event='connect', value=block.count_conns()) self._get_loop().create_task(self._connect(block, started_at, event)) def _schedule_discard(self, block: Block[C], conn: C) -> None: self._get_loop().create_task(self._discard_conn(block, conn)) async def _discard_conn(self, block: Block[C], conn: C) -> None: assert not block.conns[conn].in_use block.conns.pop(conn) self._log_to_snapshot( dbname=block.dbname, event='disconnect', value=block.count_conns()) await self._disconnect(conn, block) block.log_connection("discarded") class 
Pool[C](BasePool[C]): # The backend database connection pool implementation in EdgeDB, managing # connections to multiple databases of a single PostgreSQL cluster, # optimized for quality of service (QoS) so that connection acquisitions # and distribution are automatically balanced in a relatively fair way. # Connections to the same database are managed in a Block (see above). # # Conceptually, the Pool has 4 runtime modes (separately optimized): # Mode A: managing connections to only one database # Mode B: multiple databases, below max capacity # Mode C: reached max capacity, some tasks are waiting for connections # Mode D: some blocks are starving with zero connection # # Mode A is close to a regular connection pool - new connections are only # created when there are not enough spare ones in the pool, and used # connections are released back to the pool, cached for next acquisition # (unless being idle for too long and GC will recycle them). As a # simplified mode, there is usually a shortcut to return early for Mode A # in the same code base shared with other modes. # # Mode B is simply an extension of Mode A for multiple databases. Each # block in Mode B acts just like Mode A, with minimal difference like less # aggressive connection creation. Different blocks could freely create new # connections when needed, racing with each other organically by the demand # for Postgres connections. # # Mode C is when things get complicated. Without being able to create more # connections, pending connection requests can only be satisfied by either # a released connection from the same block, or the pool as the arbiter has # to "transfer" a connection from another block. This is achieved by # rebalancing the pool based on calculated per-block quotas recalibrated # in periodic "ticks" (see _tick()). # # In extreme cases, the number of blocks may go beyond the max capacity. 
# This is Mode D when even each block takes only at most one connection, # there are still some starved blocks that have no connections at all. # Mode D reuses the framework of Mode C but runs separate logic in a # different if-else branch. In short, the pool reallocates the limited # total number of connections to different blocks in a round-robin fashion. _new_blocks_waitlist: collections.OrderedDict[Block[C], bool] _blocks_over_quota: list[Block[C]] _nacquires: int _htick: typing.Optional[asyncio.Handle] _to_drop: list[Block[C]] _gc_interval: float # minimum seconds between GC runs _gc_requests: int # number of GC requests def __init__( self, *, connect: Connector[C], disconnect: Disconnector[C], max_capacity: int, stats_collector: typing.Optional[StatsCollector]=None, min_idle_time_before_gc: float = config.MIN_IDLE_TIME_BEFORE_GC, ) -> None: super().__init__( connect=connect, disconnect=disconnect, stats_collector=stats_collector, max_capacity=max_capacity, ) self._new_blocks_waitlist = collections.OrderedDict() self._blocks_over_quota = [] self._nacquires = 0 self._htick = None self._first_tick = True self._to_drop = [] self._gc_interval = min_idle_time_before_gc self._gc_requests = 0 def _maybe_schedule_tick(self) -> None: if self._first_tick: self._first_tick = False self._capture_snapshot(now=time.monotonic()) # Only schedule a tick under Mode C/D, and schedule at most one tick # at a time. if not self._nacquires or self._htick is not None: return self._htick = self._get_loop().call_later( max(self._conntime_avg.avg(), config.MIN_CONN_TIME_THRESHOLD), self._tick ) def _tick(self) -> None: self._htick = None if self._nacquires: # Schedule the next tick if we're still in Mode C/D. self._maybe_schedule_tick() now = time.monotonic() self._report_snapshot() self._capture_snapshot(now=now) # If we're managing connections to only one PostgreSQL DB (Mode A), # bail out early. 
Just give the one and only block we have the max # possible quota (which is needed only for logging purposes.) nblocks = len(self._blocks) if nblocks <= 1: self._is_starving = False if nblocks: first_block = next(iter(self._blocks.values())) first_block.quota = self._max_capacity first_block.nwaiters_avg.add(first_block.count_waiters()) return # Go over all the blocks and calculate: # - "nwaiters" - number of connection acquisitions # (including pending and acquired, per block and total) # - First round of per-block quota ( := nwaiters ) # - Calibrated demand (per block and total) # - If any block is starving / Mode D need_conns_at_least = 0 total_nwaiters = 0 total_calibrated_demand: float = 0 min_demand = float('inf') self._to_drop.clear() for block in self._blocks.values(): nwaiters = block.count_waiters() + block.conn_acquired_num block.quota = nwaiters # will likely be overwritten below total_nwaiters += nwaiters block.nwaiters_avg.add(nwaiters) nwaiters_avg = block.nwaiters_avg.avg() if nwaiters_avg and not block.suppressed: # GOTCHA: this is a counter of blocks that need at least 1 # connection. If this number is greater than _max_capacity, # some block will be starving with zero connection. need_conns_at_least += 1 else: if not block.count_conns(): self._to_drop.append(block) continue demand = ( max(nwaiters_avg, nwaiters) * max(block.querytime_avg.avg(), config.MIN_QUERY_TIME_THRESHOLD) ) total_calibrated_demand += demand block._cached_calibrated_demand = demand if min_demand > demand: min_demand = demand was_starving = self._is_starving self._is_starving = need_conns_at_least >= self._max_capacity if self._to_drop: for block in self._to_drop: self._drop_block(block) if not total_nwaiters: # No connection acquisition, nothing to do here. return if total_nwaiters < self._max_capacity: # The total demand for connections is lower than our max capacity, # we could bail out early. 
if self._cur_capacity >= self._max_capacity: # GOTCHA: this is still Mode C, because the total_nwaiters # number doesn't include the unused connections in the stacks # if any. Therefore, the rebalance here is necessary to shrink # those blocks and transfer the connection quota to the # starving ones (or they will block). We could simply depend on # the already-set quota based on nwaiters, and skip the regular # Mode C quota calculation below. self._maybe_rebalance() else: # If we still have space for more connections (Mode B), don't # actively rebalance the pool just yet - rebalance will kick in # when the max capacity is hit; or we'll depend on the garbage # collection to shrink the over-quota blocks. pass return if self._is_starving: # Mode D: recalculate the per-block quota. for block in tuple(self._blocks.values()): nconns = block.count_conns() if nconns == 1: if ( now - block.last_connect_timestamp < max(self._conntime_avg.avg(), config.MIN_CONN_TIME_THRESHOLD) ): # let it keep its connection block.quota = 1 else: block.quota = 0 self._blocks.move_to_end(block.dbname, last=True) elif nconns > 1: block.quota = 0 self._blocks.move_to_end(block.dbname, last=True) else: block.quota = 1 self._blocks.move_to_end(block.dbname, last=True) if block.quota: self._log_to_snapshot( dbname=block.dbname, event='set-quota', value=block.quota) else: self._log_to_snapshot( dbname=block.dbname, event='reset-quota') if not was_starving and self._new_blocks_waitlist: # Mode D assumes all connections are already in use or to be # used, depending on their `release()` to schedule transfers. # When just entering Mode D, there can be a special case when # no further `release()` will be called because all acquired # connections were returned to the pool before `_tick()` got a # chance to set `self._is_starving`, while some other blocks # are literally starving to death (blocked forever). 
# # This branch handles this particular case, by stealing # connections from the idle blocks and try to free them into # the starving blocks. for block in list(self._blocks.values()): while self._should_free_conn(block): if (conn := block.try_steal()) is None: # no more from this block break elif not self._maybe_free_into_starving_blocks( block, conn ): # put back the last stolen connection if we # don't need to steal anymore self._release_unused(block, conn) return else: # Mode C: distribute the total connections by calibrated demand # setting the per-block quota, then trigger rebalance. capacity_left = self._max_capacity if min_demand / total_calibrated_demand * self._max_capacity < 1: for block in self._blocks.values(): demand = block._cached_calibrated_demand if not demand: block.quota = 0 self._log_to_snapshot( dbname=block.dbname, event='reset-quota') k = (self._max_capacity * demand) / total_calibrated_demand if 0 < k <= 1: block.quota = 1 self._log_to_snapshot( dbname=block.dbname, event='set-quota', value=block.quota) capacity_left -= 1 assert capacity_left > 0 acc: float = 0 for block in self._blocks.values(): demand = block._cached_calibrated_demand if not demand: continue old_acc = acc acc += ( (capacity_left * demand) / total_calibrated_demand ) block.quota = round(acc) - round(old_acc) self._log_to_snapshot( dbname=block.dbname, event='set-quota', value=block.quota) self._maybe_rebalance() def _maybe_rebalance(self) -> None: if self._is_starving: return self._blocks_over_quota.clear() for block in self._blocks.values(): nconns = block.count_conns() quota = block.quota if nconns > quota: self._try_shrink_block(block) if block.count_conns() > quota: # If the block is still over quota, add it to a list so # that other blocks could steal connections from it self._blocks_over_quota.append(block) elif nconns < quota: while ( block.count_conns() < quota and self._cur_capacity < self._max_capacity ): self._schedule_new_conn(block) if self._blocks_over_quota: 
self._blocks_over_quota.sort( key=lambda b: b.count_conns_over_quota(), reverse=True ) def _should_free_conn(self, from_block: Block[C]) -> bool: # First, if we only manage one connection to one PostgreSQL DB -- # we don't need to bother with rebalancing the pool. So we bail out. if len(self._blocks) <= 1: return False from_block_size = from_block.count_conns() # Second, we bail out if: # # * the pool isn't starving, i.e. we have a fewer number of # different DB connections than the max number of connections # allowed; # # AND # # * the `from_block` block has fewer connections than its quota. if not self._is_starving and from_block_size <= from_block.quota: return False # Third, we bail out if: # # * the pool is starving; # # AND YET # # * the `from_block` block has only one connection; # # AND # # * the block has active waiters in its queue; # # AND # # * the block has been holding its last and only connection for # less time than the average time it spends on connecting to # PostgreSQL. 
if ( self._is_starving and from_block_size == 1 and from_block.count_waiters() and (time.monotonic() - from_block.last_connect_timestamp) < max(self._conntime_avg.avg(), config.MIN_CONN_TIME_THRESHOLD) ): return False return True def _maybe_free_into_starving_blocks( self, from_block: Block[C], conn: C, ) -> bool: label, to_block = self._find_most_starving_block() if to_block is None or to_block is from_block: return False assert label is not None self._schedule_transfer(from_block, conn, to_block) self._log_to_snapshot( dbname=to_block.dbname, event=label, value=1, ) return True def _try_shrink_block(self, block: Block[C]) -> None: while ( block.count_conns_over_quota() and self._should_free_conn(block) ): if (conn := block.try_steal()) is not None: _, to_block = self._find_most_starving_block() if to_block is not None: self._schedule_transfer(block, conn, to_block) else: self._schedule_discard(block, conn) else: break def _try_steal_conn(self, for_block: Block[C]) -> bool: if not self._blocks_over_quota: return False for block in self._blocks_over_quota: if block is for_block or not self._should_free_conn(block): continue if (conn := block.try_steal()) is not None: self._log_to_snapshot( dbname=block.dbname, event='conn-stolen') self._schedule_transfer(block, conn, for_block) return True return False def _find_most_starving_block( self, ) -> tuple[typing.Optional[str], typing.Optional[Block[C]]]: to_block = None # Find if there are any newly created blocks waiting for their # first connection. while self._new_blocks_waitlist: block, _ = self._new_blocks_waitlist.popitem(last=False) if block.count_conns() or not block.count_waiters(): # This block is already initialized. Skip it. # This branch shouldn't happen. continue to_block = block break if to_block is not None: return 'first-conn', to_block # Find if there are blocks without a single connection. # Find the one that is starving the most. 
max_need = 0 for block in self._blocks.values(): block_size = block.count_conns() block_demand = block.count_waiters() if block_size or not block_demand or block.suppressed: continue if block_demand > max_need: max_need = block_demand to_block = block if to_block is not None: return 'revive-conn', to_block # Find all blocks that are under quota and award the most # starving one. max_need = 0 for block in self._blocks.values(): block_size = block.count_conns() block_quota = block.quota if block_quota > block_size and not block.suppressed: need = block_quota - block_size if need > max_need: max_need = need to_block = block if to_block: return 'redist-conn', to_block return None, None async def _acquire(self, dbname: str) -> C: block = self._get_block(dbname) block.suppressed = False room_for_new_conns = self._cur_capacity < self._max_capacity block_nconns = block.count_conns() if room_for_new_conns: # First, schedule new connections if needed. if len(self._blocks) == 1: # Managing connections to only one DB and can open more # connections. Or this is before the first tick. if block.count_queued_conns() <= 1: # Only keep at most 1 spare connection in the ready queue. # When concurrent tasks are racing for the spare # connections in the same loop iteration, early requesters # will retrieve the spare connections immediately without # context switch (block.acquire() will not "block" in # await). Therefore, we will create just enough new # connections for the number of late requesters plus one. self._schedule_new_conn(block) elif ( not block_nconns or block_nconns < block.quota or not block.count_approx_available_conns() ): # Block has no connections at all, or not enough connections. self._schedule_new_conn(block) return await block.acquire() if not block_nconns: # This is a block without any connections. # Request one of the next released connections to be # reallocated for this block. 
if not self._try_steal_conn(block): self._new_blocks_waitlist[block] = True return await block.acquire() if block_nconns < block.quota: # Let's see if we can steal a connection from some block # that's over quota and open a new one. self._try_steal_conn(block) return await block.acquire() return await block.acquire() def _run_gc(self) -> None: loop = self._get_loop() if self._is_starving: # Bail out early if any block is starving, try GC later loop.call_later(self._gc_interval, self._run_gc) return if self._gc_requests > 1: # Schedule to run one more GC for requests before this run self._gc_requests = 1 loop.call_later(self._gc_interval, self._run_gc) else: # We will take care of the only GC request and pause GC self._gc_requests = 0 # Make sure the unused connections stay in the pool for at least one # GC interval. So theoretically unused connections are usually GC-ed # within 1-2 GC intervals. only_older_than = time.monotonic() - self._gc_interval for block in self._blocks.values(): while (conn := block.try_steal(only_older_than)) is not None: self._schedule_discard(block, conn) async def acquire(self, dbname: str) -> C: self._nacquires += 1 self._maybe_schedule_tick() try: conn = await self._acquire(dbname) finally: self._nacquires -= 1 block = self._blocks[dbname] assert not block.conns[conn].in_use block.inc_acquire_counter() block.conns[conn].in_use = True block.conns[conn].in_use_since = time.monotonic() return conn def release(self, dbname: str, conn: C, *, discard: bool = False) -> None: try: block = self._blocks[dbname] except KeyError: raise RuntimeError( f'cannot release connection {conn!r}: {dbname!r} database ' f'is not known to the pool' ) from None try: conn_state = block.conns[conn] except KeyError: raise RuntimeError( f'cannot release connection {conn!r}: the connection does not ' f'belong to the pool' ) from None if not conn_state.in_use: raise RuntimeError( f'cannot release connection {conn!r}: the connection was ' f'never acquired from the 
pool' ) from None block.dec_acquire_counter() block.querytime_avg.add(time.monotonic() - conn_state.in_use_since) conn_state.in_use = False conn_state.in_use_since = 0 self._maybe_schedule_tick() if not ( self._should_free_conn(block) and self._maybe_free_into_starving_blocks(block, conn) ): if discard: # Concurrent `acquire()` may be waiting to reuse the released # connection here - as we should discard this one, let's just # schedule a new one in the same block. self._schedule_discard(block, conn) self._schedule_new_conn(block) else: self._release_unused(block, conn) def _release_unused(self, block: Block[C], conn: C) -> None: block.release(conn) # Only request for GC if the connection is released unused self._gc_requests += 1 if self._gc_requests == 1: # Only schedule GC for the very first request - following # requests will be grouped into the next GC self._get_loop().call_later(self._gc_interval, self._run_gc) async def prune_inactive_connections(self, dbname: str) -> None: try: block = self._blocks[dbname] except KeyError: return None # Mark the block as suppressed, so that nothing will be # transferred to it. It will be unsuppressed if anything # actually tries to connect. # TODO: Is it possible to safely drop the block? block.suppressed = True conns = [] while (conn := block.try_steal()) is not None: conns.append(conn) while not block.count_waiters() and block.pending_conns: # try_acquire, because it can get stolen if c := await block.try_acquire(): conns.append(c) if conns: await asyncio.gather( *(self._discard_conn(block, conn) for conn in conns), return_exceptions=True ) async def prune_all_connections(self) -> None: # Brutally close all connections. This is used by HA failover. 
coros = [] for block in self._blocks.values(): block.conn_stack.clear() for conn in block.conns: coros.append(self._disconnect(conn, block)) block.conns.clear() self._log_to_snapshot( dbname=block.dbname, event='disconnect', value=0) await asyncio.gather(*coros, return_exceptions=True) # We don't have to worry about pending_conns here - # Tenant._pg_connect() will honor the failover and raise an error. def iterate_connections(self) -> typing.Iterator[C]: for block in self._blocks.values(): for conn in block.conns: yield conn class _NaivePool[C](BasePool[C]): """Implements a rather naive and flawed balancing algorithm. Should only be used for testing purposes. """ _conns: dict[str, set[C]] _last_tick: float def __init__( self, connect: Connector[C], disconnect: Disconnector[C], max_capacity: int, stats_collector: typing.Optional[StatsCollector]=None, min_idle_time_before_gc: float = config.MIN_IDLE_TIME_BEFORE_GC, ) -> None: super().__init__( connect=connect, disconnect=disconnect, stats_collector=stats_collector, max_capacity=max_capacity, ) self._conns = {} self._last_tick = 0 def _maybe_tick(self) -> None: now = time.monotonic() if self._last_tick == 0: # First time `_tick()` is run. self._capture_snapshot(now=now) self._last_tick = now return if now - self._last_tick < 0.1: # Not enough time passed since the last tick. return self._last_tick = now self._report_snapshot() self._capture_snapshot(now=now) async def _steal_conn(self, for_block: Block[C]) -> None: # A simplified connection stealing implementation. # First, tries to steal one from the blocks queue unconditionally. for block in self._blocks.values(): if block is for_block: continue if (conn := block.try_steal()) is not None: self._log_to_snapshot( dbname=block.dbname, event='conn-stolen') self._schedule_transfer(block, conn, for_block) self._blocks.move_to_end(block.dbname, last=True) return # If all the blocks are busy, simply wait in the queue to get one.
for block in self._blocks.values(): if block is for_block: continue if block.count_conns(): conn = await block.acquire() self._log_to_snapshot( dbname=block.dbname, event='conn-stolen') self._schedule_transfer(block, conn, for_block) self._blocks.move_to_end(block.dbname, last=True) return async def acquire(self, dbname: str) -> C: self._maybe_tick() block = self._get_block(dbname) if self._cur_capacity < self._max_capacity: self._schedule_new_conn(block) elif not block.count_conns(): # As a new block, steal one connection from other blocks if the # max capacity is reached. We cannot depend on the transfer logic # in `release()`, because it would hang if no other block releases. await self._steal_conn(block) return await block.acquire() def release(self, dbname: str, conn: C) -> None: self._maybe_tick() this_block = self._get_block(dbname) if this_block.count_conns() < this_block.count_waiters(): this_block.release(conn) return max_need = 0 to_block = None for block in self._blocks.values(): block_size = block.count_conns() block_demand = block.count_waiters() if not block_size and block_demand: need = block_demand * 1000 elif block_size < block_demand: need = block_demand - block_size else: continue if need > max_need: max_need = block_demand to_block = block if to_block is this_block or to_block is None: this_block.release(conn) return self._schedule_transfer(this_block, conn, to_block) self._log_to_snapshot( dbname=to_block.dbname, event='free', value=1, ) ================================================ FILE: edb/server/connpool/pool2.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2020-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import edb.server._rust_native._conn_pool as _rust import asyncio import time import typing import dataclasses import pickle from . import config from .config import logger from edb.server import rust_async_channel CP1 = typing.TypeVar('CP1', covariant=True) CP2 = typing.TypeVar('CP2', contravariant=True) class Connector(typing.Protocol[CP1]): def __call__(self, dbname: str) -> typing.Awaitable[CP1]: pass class Disconnector(typing.Protocol[CP2]): def __call__(self, conn: CP2) -> typing.Awaitable[None]: pass @dataclasses.dataclass class BlockSnapshot: dbname: str nwaiters_avg: int nconns: int npending: int nwaiters: int quota: int @dataclasses.dataclass class SnapshotLog: timestamp: float event: str dbname: str value: int @dataclasses.dataclass class Snapshot: timestamp: float capacity: int blocks: list[BlockSnapshot] log: list[SnapshotLog] failed_connects: int failed_disconnects: int successful_connects: int successful_disconnects: int class StatsCollector(typing.Protocol): def __call__(self, stats: Snapshot) -> None: pass # Connections must be hashable because we use them to reverse-lookup # an internal ID. 
class Pool[C: typing.Hashable]: _pool: _rust.ConnPool _next_conn_id: int _failed_connects: int _failed_disconnects: int _successful_connects: int _successful_disconnects: int _cur_capacity: int _max_capacity: int _task: typing.Optional[asyncio.Task[None]] _acquires: dict[int, asyncio.Future[int]] _prunes: dict[int, asyncio.Future[None]] _conns: dict[int, C] _errors: dict[int, BaseException] _conns_held: dict[C, int] _loop: asyncio.AbstractEventLoop _counts: typing.Any _stats_collector: typing.Optional[StatsCollector] def __init__( self, *, connect: Connector[C], disconnect: Disconnector[C], max_capacity: int, stats_collector: typing.Optional[StatsCollector] = None, min_idle_time_before_gc: float = config.MIN_IDLE_TIME_BEFORE_GC, ) -> None: # Re-load the logger if it's been mocked for testing global logger logger = config.logger logger.info( f'Creating a connection pool with max_capacity={max_capacity}' ) self._connect = connect self._disconnect = disconnect self._pool = _rust.ConnPool( max_capacity, min_idle_time_before_gc, config.STATS_COLLECT_INTERVAL ) self._max_capacity = max_capacity self._cur_capacity = 0 self._next_conn_id = 0 self._acquires = {} self._conns = {} self._errors = {} self._conns_held = {} self._prunes = {} self._loop = asyncio.get_running_loop() self._channel = rust_async_channel.RustAsyncChannel( self._pool._channel, self._process_message, ) self._task = self._loop.create_task(self._boot(self._channel)) self._failed_connects = 0 self._failed_disconnects = 0 self._successful_connects = 0 self._successful_disconnects = 0 self._counts = None self._stats_collector = stats_collector if stats_collector: stats_collector(self._build_snapshot(now=time.monotonic())) pass def __del__(self) -> None: if self._task: self._task.cancel() self._task = None async def close(self) -> None: if self._task: # Cancel the currently-executing futures for acq in self._acquires.values(): acq.set_exception(asyncio.CancelledError()) for prune in self._prunes.values(): 
prune.set_exception(asyncio.CancelledError()) logger.info("Closing connection pool...") task = self._task self._task = None task.cancel() try: await task except asyncio.exceptions.CancelledError: pass self._pool = None logger.info("Closed connection pool") async def _boot( self, channel: rust_async_channel.RustAsyncChannel, ) -> None: logger.info("Python-side connection pool booted") try: await channel.run() finally: channel.close() def _try_read(self) -> None: if self._channel: self._channel.read_hint() def _process_message(self, msg: typing.Any) -> None: # If we're closing, don't dispatch any operations if not self._task: return if msg[0] == 0: if f := self._acquires.pop(msg[1], None): f.set_result(msg[2]) else: logger.warning(f"Duplicate result for acquire {msg[1]}") elif msg[0] == 1: self._loop.create_task(self._perform_connect(msg[1], msg[2])) elif msg[0] == 2: self._loop.create_task(self._perform_disconnect(msg[1])) elif msg[0] == 3: self._loop.create_task(self._perform_reconnect(msg[1], msg[2])) elif msg[0] == 4: self._loop.create_task(self._perform_prune(msg[1])) elif msg[0] == 5: # Note that we might end up with duplicated messages at shutdown error = self._errors.pop(msg[2], None) if error: if f := self._acquires.pop(msg[1], None): f.set_exception(error) else: logger.warning(f"Duplicate exception for acquire {msg[1]}") elif msg[0] == 6: # Pickled metrics self._counts = pickle.loads(msg[1]) if self._stats_collector: self._stats_collector( self._build_snapshot(now=time.monotonic()) ) else: logger.critical(f'Unexpected message: {msg}') async def _perform_connect(self, id: int, db: str) -> None: self._cur_capacity += 1 try: self._conns[id] = await self._connect(db) self._successful_connects += 1 if self._pool: self._pool._completed(id) except Exception as e: self._errors[id] = e if self._pool: self._pool._failed(id, e) async def _perform_disconnect(self, id: int) -> None: try: conn = self._conns.pop(id) await self._disconnect(conn) 
self._successful_disconnects += 1 self._cur_capacity -= 1 if self._pool: self._pool._completed(id) except Exception as e: self._cur_capacity -= 1 if self._pool: self._pool._failed(id, e) async def _perform_reconnect(self, id: int, db: str) -> None: try: # Note that we cannot hold this connection here as there is an # implicit expectation that the connection will GC after disconnect # but before reconnect. conn = self._conns.pop(id) await self._disconnect(conn) self._successful_disconnects += 1 try: self._conns[id] = await self._connect(db) self._successful_connects += 1 if self._pool: self._pool._completed(id) except Exception as e: self._errors[id] = e if self._pool: self._pool._failed(id, e) except Exception as e: del self._conns[id] self._cur_capacity -= 1 if self._pool: self._pool._failed(id, e) async def _perform_prune(self, id: int) -> None: self._prunes[id].set_result(None) async def acquire(self, dbname: str) -> C: """Acquire a connection from the database. This connection must be released.""" if not self._task: raise asyncio.CancelledError() for i in range(config.CONNECT_FAILURE_RETRIES + 1): id = self._next_conn_id self._next_conn_id += 1 acquire: asyncio.Future[int] = asyncio.Future() self._acquires[id] = acquire self._pool._acquire(id, dbname) self._try_read() # This may throw! 
try: conn = await acquire c = self._conns[conn] self._conns_held[c] = id return c except Exception as e: # 3D000 - INVALID CATALOG NAME, database does not exist # Skip retry and propagate the error immediately if getattr(e, 'fields', {}).get('C') == '3D000': raise # Allow the final exception to escape if i == config.CONNECT_FAILURE_RETRIES: logger.exception( 'Failed to acquire connection, will not ' f'retry {dbname} ({self._cur_capacity} ' 'active)' ) raise logger.exception( 'Failed to acquire connection, will retry: ' f'{dbname} ({self._cur_capacity} active)' ) raise AssertionError("Unreachable end of loop") def release(self, dbname: str, conn: C, discard: bool = False) -> None: """Releases a connection back into the pool, discarding or returning it in the background.""" id = self._conns_held.pop(conn) if discard: self._pool._discard(id) else: self._pool._release(id) self._try_read() async def prune_inactive_connections(self, dbname: str) -> None: if not self._task: raise asyncio.CancelledError() id = self._next_conn_id self._next_conn_id += 1 self._prunes[id] = asyncio.Future() self._pool._prune(id, dbname) await self._prunes[id] del self._prunes[id] async def prune_all_connections(self) -> None: # Brutally close all connections. This is used by HA failover.
coros = [] for conn in self._conns.values(): coros.append(self._disconnect(conn)) await asyncio.gather(*coros, return_exceptions=True) @property def active_conns(self) -> int: return len(self._conns_held) def iterate_connections(self) -> typing.Iterator[C]: for conn in self._conns.values(): yield conn def _build_snapshot(self, *, now: float) -> Snapshot: blocks: list[BlockSnapshot] = [] if self._counts: block_stats = self._counts['blocks'] for dbname, stats in block_stats.items(): v = stats['value'] block_snapshot = BlockSnapshot( dbname=dbname, nconns=v[_rust.METRIC_ACTIVE], nwaiters_avg=v[_rust.METRIC_WAITING], npending=v[_rust.METRIC_CONNECTING] + v[_rust.METRIC_RECONNECTING], nwaiters=v[_rust.METRIC_WAITING], quota=stats['target'], ) blocks.append(block_snapshot) pass return Snapshot( timestamp=now, blocks=blocks, capacity=self._cur_capacity, log=[], failed_connects=self._failed_connects, failed_disconnects=self._failed_disconnects, successful_connects=self._successful_connects, successful_disconnects=self._successful_disconnects, ) @property def max_capacity(self) -> int: return self._max_capacity @property def current_capacity(self) -> int: return self._cur_capacity @property def failed_connects(self) -> int: return self._failed_connects @property def failed_disconnects(self) -> int: return self._failed_disconnects ================================================ FILE: edb/server/connpool/rolavg.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2020-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


class RollingAverage:

    __slots__ = ('_hist_size', '_hist', '_pos', '_cached_avg')

    _hist_size: int
    _pos: int
    _hist: list[float]
    _cached_avg: float

    def __init__(self, *, history_size: int):
        self._hist = [0] * history_size
        self._pos = 0
        self._hist_size = history_size
        self._cached_avg = 0

    def add(self, n: float) -> None:
        self._hist[self._pos % self._hist_size] = n
        self._pos += 1
        self._cached_avg = 0

    def avg(self) -> float:
        if self._cached_avg:
            return self._cached_avg

        self._cached_avg = (
            sum(self._hist) / max(min(self._pos, self._hist_size), 1)
        )

        return self._cached_avg

================================================
FILE: edb/server/consul.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright EdgeDB Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# from __future__ import annotations from typing import Optional import asyncio import json import logging import random import urllib.parse import httptools from edb.common import asyncwatcher logger = logging.getLogger("edb.server.consul") class ConsulKVProtocol(asyncwatcher.AsyncWatcherProtocol): def __init__( self, watcher: asyncwatcher.AsyncWatcher, consul_host: str, key: str, ) -> None: assert not key.startswith("/"), "absolute path rewrites Consul KV URL" super().__init__(watcher) self._host = consul_host self._key = key self._watcher = watcher self._parser = httptools.HttpResponseParser(self) self._last_modify_index: Optional[str] = None self._buffers: list[bytes] = [] def data_received(self, data: bytes) -> None: self._parser.feed_data(data) def on_status(self, status: bytes) -> None: status_code = self._parser.get_status_code() if status_code != 200: logger.debug( "Consul is returning non-200 responses: %s %r", status_code, status, ) if self._transport is not None: self._transport.close() def on_body(self, body: bytes) -> None: self._buffers.append(body) def on_message_complete(self) -> None: try: code = self._parser.get_status_code() if code == 200: self._watcher.incr_metrics_counter("watch-update") payload = json.loads(b"".join(self._buffers))[0] last_modify_index = payload["ModifyIndex"] self._watcher.on_update(payload["Value"]) if self._last_modify_index == last_modify_index: self._watcher.incr_metrics_counter("watch-timeout") self._last_modify_index = None else: self._last_modify_index = last_modify_index else: self._watcher.incr_metrics_counter(f"watch-err-{code}") self.request() finally: self._buffers.clear() def request(self) -> None: delay = self._watcher.consume_tokens(1) if delay > 0: asyncio.get_running_loop().call_later( delay + random.random() * 0.1, self.request ) return uri = urllib.parse.urljoin("/v1/kv/", self._key) if self._last_modify_index is not None: uri += f"?index={self._last_modify_index}" if self._transport is None or 
self._transport.is_closing(): logger.error("cannot perform Consul request: connection is closed") return self._transport.write( f"GET {uri} HTTP/1.1\r\n" f"Host: {self._host}\r\n" f"\r\n".encode() ) def close(self) -> None: if self._transport is not None and not self._transport.is_closing(): self._transport.close() self._transport = None ================================================ FILE: edb/server/daemon/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2012-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from .pidfile import PidFile # NOQA from .exceptions import DaemonError # NOQA from .daemon import DaemonContext # NOQA ================================================ FILE: edb/server/daemon/daemon.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2012-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Implementation of PEP 3143.""" from __future__ import annotations from typing import Optional import atexit import io import os import signal from . import lib, pidfile as pidfile_module from .exceptions import DaemonError class DaemonContext: def __init__( self, *, pidfile: Optional[os.PathLike] = None, files_preserve: Optional[list] = None, working_directory: str = '/', umask: int = 0o022, uid: Optional[int] = None, gid: Optional[int] = None, detach_process: Optional[bool] = None, prevent_core: bool = True, stdin: Optional[io.FileIO] = None, stdout: Optional[io.FileIO] = None, stderr: Optional[io.FileIO] = None, signal_map: Optional[dict] = None ): self.pidfile = os.fspath(pidfile) if pidfile is not None else None self.files_preserve = files_preserve self.working_directory = working_directory self.umask = umask self.prevent_core = prevent_core self.signal_map = signal_map if stdin is not None and not isinstance(stdin, str): lib.validate_stream(stdin, stream_name='stdin') self.stdin = stdin if stdout is not None and not isinstance(stdout, str): lib.validate_stream(stdout, stream_name='stdout') self.stdout = stdout if stderr is not None and not isinstance(stderr, str): lib.validate_stream(stderr, stream_name='stderr') self.stderr = stderr self.uid = uid self.gid = gid if detach_process is None: self.detach_process = lib.is_detach_process_context_required() else: self.detach_process = detach_process self._is_open = False self._close_stdin = self._close_stdout = self._close_stderr = None self._stdin_name = self._stdout_name = self._stderr_name = None 
self._pidfile = None is_open = property(lambda self: self._is_open) def open(self): if self._is_open: return self._init_pidfile() if self.prevent_core: lib.prevent_core_dump() lib.change_umask(self.umask) lib.change_working_directory(self.working_directory) # Test that we can write to log files/output right after # chdir call self._test_sys_streams() if self.uid is not None: lib.change_process_uid(self.uid) if self.gid is not None: lib.change_process_gid(self.gid) if self.detach_process: lib.detach_process_context() self._setup_signals() if self._pidfile is not None: self._pidfile.acquire() self._close_all_open_files() self._open_sys_streams() self._is_open = True atexit.register(self.close) def close(self): if not self._is_open: return atexit.unregister(self.close) if self._pidfile is not None: self._pidfile.release() self._pidfile = None self._close_sys_streams() self._is_open = False def _close_sys_streams(self): if self._close_stdin: self._close_stdin.close() self._close_stdin = None self.stdin = None if self._close_stdout: self._close_stdout.close() self._close_stdout = None self.stdout = None if self._close_stderr: self._close_stderr.close() self._close_stderr = None self.stderr = None def _test_sys_streams(self): stderr = self.stderr or self._stderr_name if isinstance(stderr, str): open(stderr, 'at').close() stdout = self.stdout or self._stdout_name if isinstance(stdout, str): open(stdout, 'at').close() def _open_sys_streams(self): stdin = self.stdin or self._stdin_name if isinstance(stdin, str): self._stdin_name = stdin self._close_stdin = stdin = open(stdin, 'rt') else: self._stdin_name = getattr(stdin, 'name', None) lib.redirect_stream('stdin', stdin) stderr = self.stderr or self._stderr_name if isinstance(stderr, str): self._stderr_name = stderr self._close_stderr = stderr = open(stderr, 'at') else: self._stderr_name = getattr(stderr, 'name', None) lib.redirect_stream('stderr', stderr) stdout = self.stdout or self._stdout_name if isinstance(stdout, str): 
self._stdout_name = stdout self._close_stdout = stdout = open(stdout, 'at') else: self._stdout_name = getattr(stdout, 'name', None) lib.redirect_stream('stdout', stdout) def signal_reopen_sys_streams(self, signal_number, stack_frame): self._close_sys_streams() self._open_sys_streams() def signal_terminate(self, signal_number, stack_frame): raise SystemExit('Termination on signal {}'.format(signal_number)) def __enter__(self): self.open() return self def __exit__(self, exc_type, exc, exc_tb): self.close() def _close_all_open_files(self): excl = set() if self.files_preserve: excl.update(self.files_preserve) if self.stderr and not isinstance(self.stderr, str): excl.add(self.stderr.fileno()) if self.stdin and not isinstance(self.stdin, str): excl.add(self.stdin.fileno()) if self.stdout and not isinstance(self.stdout, str): excl.add(self.stdout.fileno()) if self._pidfile is not None: pidfile = self._pidfile.fileno() if pidfile is not None: excl.add(pidfile) lib.close_all_open_files(excl) def _setup_signals(self): signal_map = { 'SIGTSTP': None, 'SIGTTIN': None, 'SIGTTOU': None, 'SIGTERM': 'signal_terminate', 'SIGHUP': 'signal_reopen_sys_streams' } if self.signal_map: signal_map.update(self.signal_map) for name, handler in signal_map.items(): if isinstance(name, str): try: num = getattr(signal, name) except AttributeError: raise DaemonError('Invalid signal name {!r}'.format(name)) elif isinstance(name, int): if name < 1 or name >= signal.NSIG: raise DaemonError( 'Invalid signal number {!r}'.format(name)) num = name else: raise DaemonError( 'Invalid signal {!r}, str or int expected'.format(name)) if handler is None: signal.signal(num, signal.SIG_IGN) elif isinstance(handler, str): try: handler = getattr(self, handler) except AttributeError: raise DaemonError( 'Invalid signal {!r} handler name {!r}'.format( name, handler)) signal.signal(num, handler) else: if not callable(handler): raise DaemonError( 'Expected callable signal {!r} handler: {!r}'.format( name, handler))
signal.signal(num, handler) def _init_pidfile(self): if self.pidfile is None: return if isinstance(self.pidfile, str): self._pidfile = pidfile_module.PidFile(self.pidfile) else: if isinstance(self.pidfile, pidfile_module.PidFile): if self.pidfile.locked: raise DaemonError( 'Pidfile object is already locked; ' 'unable to initialize daemon context') self._pidfile = self.pidfile else: raise DaemonError( 'Invalid pidfile, str or PidFile expected, got {!r}'. format(self.pidfile)) ================================================ FILE: edb/server/daemon/exceptions.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2012-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from edb import errors class DaemonError(errors.InternalServerError): pass ================================================ FILE: edb/server/daemon/lib.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2012-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Optional import errno import io import os import fcntl import logging import resource import stat import socket import sys from .exceptions import DaemonError logger = logging.getLogger('edb.server.daemon') def is_process_running(pid: int): """Check if there is a running process with `pid`.""" try: os.kill(pid, 0) return True except OSError as ex: if ex.errno == errno.ESRCH: return False else: raise def lock_file(fileno: int): """Lock file. Returns ``True`` if succeeded, ``False`` otherwise.""" try: # Try to lock file exclusively and in non-blocking fashion fcntl.flock(fileno, fcntl.LOCK_EX | fcntl.LOCK_NB) except IOError: return False else: return True def make_readonly(path: str): """Make a file read-only.""" assert os.path.isfile(path) os.chmod(path, stat.S_IROTH | stat.S_IRUSR | stat.S_IRGRP) def change_working_directory(path: str): """Change the working directory for this process.""" try: os.chdir(path) except OSError as ex: raise DaemonError( 'Unable to change working directory to {!r}'.format(path)) from ex def change_process_gid(gid: int): """Change the GID of this process. Requires appropriate OS privileges for this process. """ try: os.setgid(gid) except OSError as ex: raise DaemonError( 'Unable to change the owning GID to {!r}'.format(gid)) from ex def change_process_uid(uid: int): """Change the UID of this process. Requires appropriate OS privileges for this process. 
""" try: os.setuid(uid) except OSError as ex: raise DaemonError( 'Unable to change the owning UID to {!r}'.format(uid)) from ex def change_umask(mask: int): """Change process umask.""" try: os.umask(mask) except (OSError, OverflowError) as ex: raise DaemonError('Unable to set process umask to {:#o}'.format( mask)) from ex def prevent_core_dump(): """Prevent this process from generating a core dump.""" core_resource = resource.RLIMIT_CORE try: resource.getrlimit(core_resource) except ValueError as ex: raise DaemonError( 'Unable to limit core dump size: ' 'system does not support RLIMIT_CORE resource limit') from ex # Set hard & soft limits to 0, i.e. no core dump at all resource.setrlimit(core_resource, (0, 0)) def detach_process_context(): """Detach process context. Does it in three steps: 1. Fork and exit the parent process. This detaches us from the shell, and since the child will have a new PID but will inherit the process group ID from the parent, the new process will not be a group leader. 2. Call 'setsid' to create a new session. This makes the process the session leader of a new session and the group leader of a new process group, with no controlling terminal. 3. Fork and exit the parent again. This guarantees that the daemon is not a session leader, which prevents it from acquiring a controlling terminal. Reference: “Advanced Programming in the Unix Environment”, section 13.3, by W. Richard Stevens.
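Example: a minimal sketch of observing the effect of the three steps, assuming a POSIX system. ``probe_double_fork`` is a hypothetical helper and not part of this module. Unlike the real function, the calling process is kept alive and waits for the intermediate child, so that the final process can report back over a pipe:

```python
import os


def probe_double_fork() -> tuple[int, int]:
    # Hypothetical helper: returns (pid, session id) of the final process.
    read_fd, write_fd = os.pipe()
    child = os.fork()                  # step 1: first fork
    if child == 0:
        try:
            os.close(read_fd)
            os.setsid()                # step 2: become a session leader
            if os.fork() > 0:          # step 3: second fork
                os._exit(0)            # the session leader exits
            # Final process: report its own PID and its session ID.
            report = "{} {}".format(os.getpid(), os.getsid(0))
            os.write(write_fd, report.encode())
        finally:
            # Never fall back into the caller's code in a forked process.
            os._exit(0)
    os.close(write_fd)
    os.waitpid(child, 0)               # reap the intermediate child
    with os.fdopen(read_fd, "rb") as f:
        pid, sid = (int(part) for part in f.read().split())
    return pid, sid


pid, sid = probe_double_fork()
print(pid != sid)
```

The reported session ID is the PID of the intermediate child, so the final process is a member of the new session but not its leader, which is the property step 3 is meant to guarantee.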
""" def fork_and_exit_parent(error_message): try: if os.fork() > 0: # Don't need to call 'sys.exit', as we don't want to # run any python interpreter clean-up handlers os._exit(0) except OSError as ex: raise DaemonError( '{}: [{}] {}'.format(error_message, ex.errno, ex.strerror)) from ex fork_and_exit_parent(error_message='Failed the first fork') os.setsid() fork_and_exit_parent(error_message='Failed the second fork') def is_process_started_by_init(): """Determine if the current process is started by 'init'.""" # The 'init' process has its PID set to 1. return os.getppid() == 1 def is_socket(fd): """Determine if the file descriptor is a socket.""" file_socket = socket.fromfd(fd, socket.AF_INET, socket.SOCK_RAW) try: file_socket.getsockopt(socket.SOL_SOCKET, socket.SO_TYPE) except socket.error as ex: return ex.args[0] != errno.ENOTSOCK else: return True def is_process_started_by_superserver(): """Determine if the current process is started by the superserver.""" # The internet superserver creates a network socket, and # attaches it to the standard streams of the child process. try: fileno = sys.__stdin__.fileno() except Exception: return False else: return is_socket(fileno) def is_detach_process_context_required(): """Determine whether detaching process context is required. Returns ``False`` if: - Process was started by `init`; or - Process was started by `inetd`. Returns ``True`` otherwise.
""" return not is_process_started_by_init( ) and not is_process_started_by_superserver() def get_max_fileno(default: int=2048): """Return the maximum number of open file descriptors.""" limit = resource.getrlimit(resource.RLIMIT_NOFILE)[1] if limit == resource.RLIM_INFINITY: return default return limit def try_close_fileno(fileno: int): """Try to close fileno.""" try: os.close(fileno) except OSError as ex: if ex.errno != errno.EBADF: raise DaemonError( 'Failed to close file descriptor {}'.format(fileno)) def close_all_open_files(exclude: Optional[set] = None): """Close all open file descriptors.""" maxfd = get_max_fileno() if exclude: for fd in reversed(range(maxfd)): if fd not in exclude: try_close_fileno(fd) else: for fd in reversed(range(maxfd)): try_close_fileno(fd) def redirect_stream(stream_name: str, target_stream: io.FileIO): """Redirect a system stream to the specified file. If ``target_stream`` is None - redirect to devnull. """ if target_stream is None: target_fd = os.open(os.devnull, os.O_RDWR) else: target_fd = target_stream.fileno() system_stream = getattr(sys, stream_name) os.dup2(target_fd, system_stream.fileno()) setattr(sys, '__{}__'.format(stream_name), system_stream) def validate_stream(stream, *, stream_name): """Check if `stream` is an open io.IOBase instance.""" if not isinstance(stream, io.IOBase): raise DaemonError( 'Invalid {} stream object, an instance of io.IOBase is expected'. format(stream_name)) if stream.closed: raise DaemonError('Stream {} is already closed'.format(stream_name)) ================================================ FILE: edb/server/daemon/pidfile.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2012-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import os import errno from .exceptions import DaemonError from . import lib class PidFile: def __init__(self, path, *, pid=None, data=None): self._path = path self._data = data self._pid = pid self._file = None self._locked = False locked = property(lambda self: self._locked) def _prepare_file_content(self): buf = '' if self._pid is None: buf += str(os.getpid()) else: buf += str(self._pid) if self._data: buf += '\n\n{}'.format(self._data) return buf def acquire(self): if self.locked: # No point in allowing re-entrance raise DaemonError('pid file is already acquired') path = self._path pidfile_dir = os.path.dirname(path) if not os.path.isdir(pidfile_dir): raise DaemonError( f"cannot create pid file: {pidfile_dir} " f"does not exist or is not a directory" ) if os.path.exists(path): if self.is_locked(path): raise DaemonError( 'pid file {!r} exists and belongs to a ' 'running process'.format(path)) os.unlink(path) self._file = open(path, 'wt') fileno = self._file.fileno() if not lib.lock_file(fileno): raise DaemonError('pid file {!r} already locked'.format(path)) self._file.write(self._prepare_file_content()) self._file.flush() lib.make_readonly(path) self._locked = True def fileno(self): if self._locked: return self._file.fileno() def release(self): if not self.locked: raise DaemonError('pid file is already released') if self._file: self._file.close() self._file = None if os.path.exists(self._path): os.remove(self._path) self._locked = False def __enter__(self): self.acquire() return self def __exit__(self, exc_type, exc, exc_tb): 
self.release() @classmethod def is_locked(cls, path): if os.path.exists(path): # If pid file already exists - check if it belongs to a # running process. If not - it should be safe to remove it. try: with open(path, 'rt') as f: pid = int(f.readline()) if lib.is_process_running(pid): return True except OSError as er: if er.errno == errno.ENOENT: # ENOENT - No such file or directory # Race - file did exist when we checked if it exists, but # got deleted before 'with open' was executed return False raise return False @classmethod def read(cls, path): with open(path, 'rt') as f: pid = int(f.readline()) f.readline() data = f.read() return pid, (data or None) ================================================ FILE: edb/server/dbview/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2019-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from .dbview import DatabaseIndex, Database, DatabaseConnectionView __all__ = ('DatabaseIndex', 'Database', 'DatabaseConnectionView') ================================================ FILE: edb/server/dbview/dbview.pxd ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2018-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # cimport cython from libc.stdint cimport uint64_t from edb.server.cache cimport stmt_cache cdef DEFAULT_STATE cpdef enum SideEffects: SchemaChanges = 1 << 0 DatabaseConfigChanges = 1 << 1 InstanceConfigChanges = 1 << 2 GlobalSchemaChanges = 1 << 3 DatabaseChanges = 1 << 4 @cython.final cdef class CompiledQuery: cdef public object query_unit_group cdef public object first_extra # Optional[int] cdef public object extra_counts cdef public object extra_blobs cdef public bint extra_formatted_as_text cdef public object extra_type_oids cdef public object request cdef public object recompiled_cache cdef public bint use_pending_func_cache cdef public object tag cdef bytes make_query_prefix(self) cdef class DatabaseIndex: cdef: dict _dbs object _server object _tenant object _sys_config object _comp_sys_config object _std_schema object _global_schema_pickle object _default_sysconfig object _sys_config_spec object _cached_compiler_args cdef invalidate_caches(self) cdef inline set_current_branches(self) cdef class Database: cdef: stmt_cache.StatementsCache _eql_to_compiled object _cache_locks object _sql_to_compiled DatabaseIndex _index object _views object _introspection_lock object _state_serializers readonly object user_config_spec object _cache_worker_task object _cache_queue object _cache_notify_task object _cache_notify_queue uint64_t _tx_seq object _active_tx_list object _func_cache_gt_tx_seq readonly str name readonly object 
schema_version readonly object dbver readonly object db_config readonly bytes user_schema_pickle readonly object reflection_cache readonly object backend_ids readonly object backend_oid_to_id readonly object extensions readonly object _feature_used_metrics readonly int dml_queries_executed cdef _invalidate_caches(self) cdef _cache_compiled_query(self, key, compiled) cdef _new_view(self, query_cache, protocol_version, role_name) cdef _remove_view(self, view) cdef _observe_auth_ext_config(self) cdef _set_backend_ids(self, types) cdef _update_backend_ids(self, new_types) cdef _set_extensions( self, extensions, ) cdef _set_feature_used_metrics(self, feature_used_metrics) cdef _set_and_signal_new_user_schema( self, new_schema_pickle, schema_version, extensions, ext_config_settings, feature_used_metrics, reflection_cache=?, backend_ids=?, db_config=?, start_stop_extensions=?, ) cpdef start_stop_extensions(self) cdef get_state_serializer(self, protocol_version) cpdef set_state_serializer(self, protocol_version, serializer) cdef inline uint64_t tx_seq_begin_tx(self) cdef inline tx_seq_end_tx(self, uint64_t seq) cdef class DatabaseConnectionView: cdef: Database _db bint _query_cache_enabled object _protocol_version str _role_name public bint is_transient # transient dbviews won't cause an immediate error in # ensure_database_not_connected(..., close_frontend_conns=False), # which is usually called from `DROP BRANCH` or `CREATE ... FROM`. # Although, transient dbviews users should guarantee the transient use # of pgcons, because _pg_ensure_database_not_connected() may still time # out `DROP BRANCH` if the transient pgcon is not released soon enough. 
# State properties object _config object _in_tx_config object _globals object _in_tx_globals object _modaliases object _in_tx_modaliases object _state_serializer object _in_tx_state_serializer object _command_state_serializer tuple _session_state_db_cache tuple _session_state_cache object _txid object _in_tx_db_config object _in_tx_savepoints object _in_tx_root_user_schema_pickle object _in_tx_user_schema_pickle object _in_tx_user_schema_version object _in_tx_user_config_spec object _in_tx_global_schema_pickle object _in_tx_new_types int _in_tx_dbver bint _in_tx uint64_t _in_tx_capabilities bint _in_tx_with_sysconfig bint _in_tx_with_dbconfig bint _in_tx_with_set bint _tx_error uint64_t _in_tx_seq object _in_tx_isolation_level uint64_t _capability_mask object _last_comp_state int _last_comp_state_id dict _sys_globals object __weakref__ cdef _reset_tx_state(self) cdef inline _check_in_tx_error(self, query_unit_group) cdef clear_tx_error(self) cdef rollback_tx_to_savepoint(self, name) cdef declare_savepoint(self, name, spid) cdef recover_aliases_and_config(self, modaliases, config, globals) cdef abort_tx(self) cpdef in_tx(self) cpdef in_tx_error(self) cdef cache_compiled_query(self, object key, object query_unit_group) cdef lookup_compiled_query(self, object key) cdef as_compiled(self, query_req, query_unit_group, bint use_metrics=?) 
cdef tx_error(self) cdef start(self, query_unit) cdef start_tx(self) cdef _apply_in_tx(self, query_unit) cdef start_implicit(self, query_unit) cdef on_error(self) cdef on_success(self, query_unit, new_types) cdef commit_implicit_tx( self, user_schema, extensions, ext_config_settings, global_schema, roles, cached_reflection, feature_used_metrics, ) cdef get_user_config_spec(self) cpdef get_config_spec(self) cpdef get_session_config(self) cdef set_session_config(self, new_conf) cpdef get_globals(self) cpdef set_globals(self, new_globals) cpdef get_global_value(self, k) cdef get_state_serializer(self) cdef set_state_serializer(self, new_serializer) cpdef get_database_config(self) cdef set_database_config(self, new_conf) cdef get_system_config(self) cpdef get_compilation_system_config(self) cdef config_lookup(self, name) cdef set_modaliases(self, new_aliases) cpdef get_modaliases(self) cdef bytes serialize_state(self) cdef bint is_state_desc_changed(self) cdef describe_state(self) cdef encode_state(self) cdef check_session_config_perms(self, keys) cdef decode_state(self, type_id, data) cdef decode_json_session_config(self, json_session_config) cdef bint needs_commit_after_state_sync(self) cdef check_capabilities( self, query_unit, allowed_capabilities, error_constructor, reason, unsafe_isolation_dangers, ) ================================================ FILE: edb/server/dbview/dbview.pyi ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from typing import ( Any, Awaitable, Callable, Hashable, Iterator, Mapping, Optional, TypeAlias, ) import uuid import immutables from edb.schema import schema as s_schema from edb.server import config from edb.server import pgcon from edb.server import server from edb.server import tenant from edb.server.compiler import dbstate from edb.server.compiler import sertypes Config: TypeAlias = Mapping[str, config.SettingValue] class CompiledQuery: query_unit_group: dbstate.QueryUnitGroup class Database: name: str dbver: int db_config: Config extensions: set[str] user_config_spec: config.Spec dml_queries_executed: int @property def server(self) -> server.Server: ... @property def tenant(self) -> tenant.Tenant: ... def stop(self) -> None: ... async def monitor( self, worker: Callable[[], Awaitable[None]], name: str, ) -> None: ... async def cache_worker(self) -> None: ... async def cache_notifier(self) -> None: ... def start_stop_extensions(self) -> None: ... def cache_compiled_sql( self, key: Hashable, compiled: list[dbstate.SQLQueryUnit], schema_version: uuid.UUID, ) -> None: ... def lookup_compiled_sql( self, key: Hashable, ) -> Optional[list[dbstate.SQLQueryUnit]]: ... def set_state_serializer( self, protocol_version: tuple[int, int], serializer: sertypes.StateSerializer, ) -> None: pass def hydrate_cache(self, query_cache: list[tuple[bytes, ...]]) -> None: ... def invalidate_cache_entries(self, to_invalidate: list[uuid.UUID]) -> None: ... def clear_query_cache(self) -> None: ... def iter_views(self) -> Iterator[DatabaseConnectionView]: ... 
def get_query_cache_size(self) -> int: ... async def introspection(self) -> None: ... def lookup_config(self, name: str) -> Any: ... def is_introspected(self) -> bool: ... class DatabaseConnectionView: def in_tx(self) -> bool: ... def in_tx_error(self) -> bool: ... def get_session_config(self) -> Config: ... def get_modaliases(self) -> Mapping[str | None, str]: ... class DatabaseIndex: def __init__( self, tenant: tenant.Tenant, *, std_schema: s_schema.Schema, global_schema_pickle: bytes, sys_config: Config, default_sysconfig: Config, sys_config_spec: config.Spec, ) -> None: ... def count_connections(self, dbname: str) -> int: ... def get_sys_config(self) -> Config: ... def get_compilation_system_config(self) -> Config: ... def update_sys_config(self, sys_config: Config) -> None: ... def has_db(self, dbname: str) -> bool: ... def get_db(self, dbname) -> Database: ... def maybe_get_db(self, dbname) -> Optional[Database]: ... def get_global_schema_pickle(self) -> bytes: ... def update_global_schema(self, global_schema_pickle: bytes) -> None: ... def register_db( self, dbname: str, *, user_schema_pickle: Optional[bytes], schema_version: Optional[uuid.UUID], db_config: Optional[Config], reflection_cache: Optional[Mapping[str, tuple[str, ...]]], backend_ids: Optional[Mapping[str, tuple[int, str]]], extensions: Optional[set[str]], ext_config_settings: Optional[list[config.Setting]], early: bool = False, feature_used_metrics: Optional[Mapping[str, float]] = ..., ) -> Database: ... def unregister_db(self, dbname: str) -> None: ... def iter_dbs(self) -> Iterator[Database]: ... async def apply_system_config_op( self, conn: pgcon.PGConnection, op: config.Operation, ) -> None: ... def new_view( self, dbname: str, *, query_cache: bool, protocol_version: tuple[int, int], role_name: str, ) -> DatabaseConnectionView: ... def remove_view( self, view: DatabaseConnectionView, ) -> None: ... def invalidate_caches(self) -> None: ... 
def get_cached_compiler_args( self, ) -> tuple[ bytes, immutables.Map[str, config.SettingValue], ]: ... def lookup_config(self, name: str) -> Any: ... ================================================ FILE: edb/server/dbview/dbview.pyx ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from typing import ( Optional, Sequence ) import asyncio import base64 import copy import json import logging import os.path import pickle import struct import time import typing import uuid import weakref import immutables from edb import errors from edb.common import debug, lru, uuidgen, asyncutil, span from edb import edgeql from edb.edgeql import qltypes from edb.schema import schema as s_schema from edb.schema import name as s_name from edb.server import compiler, defines, config, metrics, pgcon from edb.server.compiler import dbstate, enums, sertypes from edb.server.protocol import execute from edb.pgsql import dbops from edb.server.pgcon import errors as pgerror from edb.server.protocol import ai_ext cimport cython from edb.server.compiler cimport rpc from edb.server.protocol.args_ser cimport ( recode_global, ) __all__ = ( 'DatabaseIndex', 'DatabaseConnectionView', 'SideEffects', 'Database' ) cdef uint64_t PROTO_CAPS = enums.Capability.PROTO_CAPS cdef DEFAULT_MODALIASES = immutables.Map({None: 
defines.DEFAULT_MODULE_ALIAS}) cdef DEFAULT_CONFIG = immutables.Map() cdef DEFAULT_GLOBALS = immutables.Map() cdef DEFAULT_STATE = json.dumps([]).encode('utf-8') cdef INT32_PACKER = struct.Struct('!l').pack cdef int VER_COUNTER = 0 cdef DICTDEFAULT = (None, None) cdef object logger = logging.getLogger('edb.server') cdef uint64_t DML_CAPABILITIES = compiler.Capability.MODIFICATIONS cdef uint64_t DDL_CAPABILITIES = compiler.Capability.DDL DEF TEXT_OID = 25 # Mapping from oids of PostgreSQL types into corresponding EdgeQL type. # Needed only for pg types that do not exist in EdgeQL, such as pg_catalog.name cdef TYPES_SQL_ONLY = immutables.Map({ 18: "00000000-0000-0000-0000-000000000101", # char -> str 19: "00000000-0000-0000-0000-000000000101", # pgcatalog.name -> str 24: "00000000-0000-0000-0000-000000000101", # regproc -> str 26: "00000000-0000-0000-0000-000000000104", # oid -> int32 28: "00000000-0000-0000-0000-000000000104", # xid -> int32 29: "00000000-0000-0000-0000-000000000104", # cid -> int32 194: "00000000-0000-0000-0000-000000000101", # pg_node_tree -> str }) cdef next_dbver(): global VER_COUNTER VER_COUNTER += 1 return VER_COUNTER cdef enum CacheState: Pending = 0, Present, Evicted @cython.final cdef class CompiledQuery: def __init__( self, query_unit_group: dbstate.QueryUnitGroup, first_extra: Optional[int]=None, extra_counts=(), extra_blobs=(), extra_formatted_as_text: bool = False, extra_type_oids: Sequence[int] = (), request=None, recompiled_cache=None, use_pending_func_cache=False, ): self.query_unit_group = query_unit_group self.first_extra = first_extra self.extra_counts = extra_counts self.extra_blobs = extra_blobs self.extra_formatted_as_text = extra_formatted_as_text self.extra_type_oids = tuple(extra_type_oids) self.request = request self.recompiled_cache = recompiled_cache self.use_pending_func_cache = use_pending_func_cache cdef bytes make_query_prefix(self): data = {} if self.tag: data['tag'] = self.tag # maintenance reminder: please also 
update _amend_typedesc_in_sql() if data: data_bytes = json.dumps(data).encode(defines.EDGEDB_ENCODING) return b''.join([b'-- ', data_bytes, b'\n']) else: return b'' cdef class Database: # Global LRU cache of compiled queries _eql_to_compiled: stmt_cache.StatementsCache[uuid.UUID, dbstate.QueryUnitGroup] def __init__( self, DatabaseIndex index, str name, *, bytes user_schema_pickle, object schema_version, object db_config, object reflection_cache, object backend_ids, object extensions, object ext_config_settings, object feature_used_metrics, ): self.name = name self.schema_version = schema_version self.dbver = next_dbver() self._index = index self._views = weakref.WeakSet() self._state_serializers = {} self._introspection_lock = asyncio.Lock() self._eql_to_compiled = stmt_cache.StatementsCache( maxsize=self.lookup_config('query_cache_size') ) self._cache_locks = {} self._sql_to_compiled = lru.LRUMapping( maxsize=self.lookup_config('query_cache_size') ) # Tracks the active transactions and their creation sequence. The # sequence ID only ever increases. ID 0 is reserved as a nonexistent ID. self._tx_seq = 0 # most recently used transaction sequence ID self._active_tx_list = {} # name it "list" to emphasize the order # Also an ordered dict of func cache present in DB but still using # inline SQL due to active transactions.
self._func_cache_gt_tx_seq = {} self.db_config = db_config self.user_schema_pickle = user_schema_pickle if ext_config_settings is not None: self.user_config_spec = config.FlatSpec(*ext_config_settings) self.reflection_cache = reflection_cache self._set_backend_ids(backend_ids) self.extensions = set() self._set_extensions(extensions) self._observe_auth_ext_config() self._feature_used_metrics = {} self._set_feature_used_metrics(feature_used_metrics) self._cache_worker_task = self._cache_queue = None self._cache_notify_task = self._cache_notify_queue = None self._cache_queue = asyncio.Queue() self._cache_worker_task = asyncio.create_task( self.monitor(self.cache_worker, 'cache_worker')) # Queue of (key: str, is_add: bool) pairs. is_add signals # whether it is an addition or deletion. self._cache_notify_queue = asyncio.Queue() self._cache_notify_task = asyncio.create_task( self.monitor(self.cache_notifier, 'cache_notifier')) self.dml_queries_executed = 0 @property def server(self): return self._index._server @property def tenant(self): return self._index._tenant def stop(self): if self._cache_worker_task: self._cache_worker_task.cancel() self._cache_worker_task = None if self._cache_notify_task: self._cache_notify_task.cancel() self._cache_notify_task = None self._set_extensions(set()) self._set_feature_used_metrics({}) self.start_stop_extensions() async def monitor(self, worker, name): while True: try: await worker() except Exception as ex: debug.dump(ex) metrics.background_errors.inc( 1.0, self.tenant._instance_name, name ) # Give things time to recover, since the likely # failure mode here is a failover or some such. 
await asyncio.sleep(0.1) async def cache_worker(self): while True: # First, handle any evictions keys = [] while self._eql_to_compiled.needs_cleanup(): query_req, unit_group = self._eql_to_compiled.cleanup_one() if len(unit_group) == 1 and unit_group.cache_state == CacheState.Present: keys.append(query_req.get_cache_key()) self._func_cache_gt_tx_seq.pop(query_req, None) unit_group.cache_state = CacheState.Evicted if keys: await self.tenant.evict_query_cache(self.name, keys) for key in keys: self._cache_notify_queue.put_nowait( (str(key), False) ) # Now, populate the cache # Empty the queue, for batching reasons. # N.B: This empty/get_nowait loop is safe because this is # an asyncio Queue. If it were threaded, it would be racy. ops = [await self._cache_queue.get()] while not self._cache_queue.empty(): ops.append(self._cache_queue.get_nowait()) # Filter ops for only what we need ops = [ (query_req, units) for query_req, units in ops if len(units) == 1 and units[0].cache_sql and units.cache_state == CacheState.Pending ] if not ops: continue g = execute.build_cache_persistence_units(ops) async with self.tenant.with_pgcon(self.name) as conn: try: await g.execute(conn, self) except Exception: logger.warning("Failed to persist function cache", exc_info=True) continue for query_req, units in ops: units.cache_state = CacheState.Present if self._active_tx_list: # Any active tx would delay the time we flip to func cache units.tx_seq_id = self._tx_seq self._func_cache_gt_tx_seq[query_req] = units else: units[0].maybe_use_func_cache() self._cache_notify_queue.put_nowait( (str(units[0].cache_key), True) ) cdef inline uint64_t tx_seq_begin_tx(self): self._tx_seq += 1 self._active_tx_list[self._tx_seq] = True return self._tx_seq cdef inline tx_seq_end_tx(self, uint64_t seq): # Remove the ending transaction from the active list if not self._active_tx_list.pop(seq, False): return # Stop early if we don't have func cache to activate if not self._func_cache_gt_tx_seq: return if
self._active_tx_list: # Grab the seq ID of the oldest active transaction active_tx = next(iter(self._active_tx_list.keys())) else: # If all tx ended, we should just activate all pending func cache for units in self._func_cache_gt_tx_seq.values(): units[0].maybe_use_func_cache() self._func_cache_gt_tx_seq.clear() return # Or else, keep activating func cache until the oldest active tx drops = [] for query_req, units in self._func_cache_gt_tx_seq.items(): if units.tx_seq_id < active_tx: units[0].maybe_use_func_cache() drops.append(query_req) else: break for query_req in drops: self._func_cache_gt_tx_seq.pop(query_req) async def cache_notifier(self): await asyncutil.debounce( lambda: self._cache_notify_queue.get(), lambda keys: self.tenant.signal_sysevent( 'query-cache-changes', dbname=self.name, to_add=[k for k, b in keys if b], to_invalidate=[k for k, b in keys if not b], ), max_wait=1.0, delay_amt=0.2, # 100 keys will take up about 4000 bytes, which # fits in the 8000 allowed in events. max_batch_size=100, ) cdef _set_extensions(self, extensions): # Update metrics about extension use tname = self.tenant.get_instance_name() for ext in self.extensions: if ext not in extensions: metrics.extension_used.dec(1, tname, ext) for ext in extensions: if ext not in self.extensions: metrics.extension_used.inc(1, tname, ext) self.extensions = extensions cdef _set_feature_used_metrics(self, feature_used_metrics): # Update metrics about feature use # # We store the old feature use metrics so that we can # incrementally update them after DDL without needing to look # at the other database branches if feature_used_metrics is None: return tname = self.tenant.get_instance_name() keys = self._feature_used_metrics.keys() | feature_used_metrics.keys() for key in keys: # Update the count of how many times the feature is used metrics.feature_used.inc( feature_used_metrics.get(key, 0.0) - self._feature_used_metrics.get(key, 0.0), tname, key, ) # Update the count of branches using the feature 
at all metrics.feature_used_branches.inc( (feature_used_metrics.get(key, 0.0) > 0) - (self._feature_used_metrics.get(key, 0.0) > 0), tname, key, ) self._feature_used_metrics = feature_used_metrics cdef _set_and_signal_new_user_schema( self, new_schema_pickle, schema_version, extensions, ext_config_settings, feature_used_metrics, reflection_cache=None, backend_ids=None, db_config=None, start_stop_extensions=True, ): if new_schema_pickle is None: raise AssertionError('new_schema is not supposed to be None') self.schema_version = schema_version self.dbver = next_dbver() self.user_schema_pickle = new_schema_pickle self._set_extensions(extensions) self.user_config_spec = config.FlatSpec(*ext_config_settings) self._set_feature_used_metrics(feature_used_metrics) if backend_ids is not None: self._set_backend_ids(backend_ids) if reflection_cache is not None: self.reflection_cache = reflection_cache if db_config is not None: self.db_config = db_config self._observe_auth_ext_config() self._invalidate_caches() if start_stop_extensions: self.start_stop_extensions() cpdef start_stop_extensions(self): if "ai" in self.extensions: ai_ext.start_extension(self.tenant, self.name) else: ai_ext.stop_extension(self.tenant, self.name) cdef _observe_auth_ext_config(self): key = "ext::auth::AuthConfig::providers" if ( self.db_config is not None and self.user_config_spec is not None and key in self.user_config_spec ): providers = config.lookup( key, self.db_config, spec=self.user_config_spec, ) metrics.auth_providers.set( len(providers), self.tenant.get_instance_name(), self.name, ) cdef _set_backend_ids(self, types): self.backend_ids = {} self.backend_oid_to_id = dict(TYPES_SQL_ONLY) if types is not None: self._update_backend_ids(types) cdef _update_backend_ids(self, new_types): self.backend_ids.update(new_types) self.backend_oid_to_id.update({ v[0]: k for k, v in new_types.items() if v[0] is not None }) cdef _invalidate_caches(self): self._sql_to_compiled.clear()
self._index.invalidate_caches() cdef _cache_compiled_query(self, key, compiled: dbstate.QueryUnitGroup): # `dbver` must be the schema version `compiled` was compiled upon assert compiled.cacheable if key in self._eql_to_compiled: # We already have a cached query for the current user schema return self._eql_to_compiled[key] = compiled if self._cache_queue is not None: self._cache_queue.put_nowait((key, compiled)) def cache_compiled_sql(self, key, compiled: list[str], schema_version): existing, ver = self._sql_to_compiled.get(key, DICTDEFAULT) if existing is not None and ver == self.schema_version: # We already have a cached query for a more recent DB version. return if not all(unit.cacheable for unit in compiled): return # Store the matching schema version, see also the comments at origin self._sql_to_compiled[key] = compiled, schema_version def lookup_compiled_sql(self, key): rv, cached_ver = self._sql_to_compiled.get(key, DICTDEFAULT) if rv is not None and cached_ver != self.schema_version: rv = None return rv cdef _new_view(self, query_cache, protocol_version, role_name): view = DatabaseConnectionView( self, query_cache=query_cache, protocol_version=protocol_version, role_name=role_name, ) self._views.add(view) return view cdef _remove_view(self, view): self._views.remove(view) cdef get_state_serializer(self, protocol_version): return self._state_serializers.get(protocol_version) cpdef set_state_serializer(self, protocol_version, serializer): old_serializer = self._state_serializers.get(protocol_version) if ( old_serializer is None or old_serializer.type_id != serializer.type_id ): # also invalidate other protocol versions self._state_serializers = {protocol_version: serializer} return serializer else: return old_serializer def hydrate_cache(self, query_cache): warning_count = 0 for _, in_data, out_data in query_cache: try: query_req = rpc.CompilationRequest.deserialize( in_data, "", self.server.compilation_config_serializer, ) if query_req not in 
self._eql_to_compiled: unit = dbstate.QueryUnit.deserialize(out_data) group = dbstate.QueryUnitGroup() group.append(unit, serialize=False) group.cache_state = CacheState.Present if self._active_tx_list: # Any active transaction would delay the time we flip # to function cache group.tx_seq_id = self._tx_seq self._func_cache_gt_tx_seq[query_req] = group else: group[0].maybe_use_func_cache() self._eql_to_compiled[query_req] = group except Exception as e: if warning_count < 0: warning_count -= 1 elif warning_count < 10: logger.warning("skipping incompatible cache item: %s", e) warning_count += 1 else: logger.warning( "too many incompatible cache items, " "skipping the following warnings" ) warning_count = -warning_count - 1 if warning_count < 0: logger.warning( "skipped %d incompatible cache items", -warning_count ) def invalidate_cache_entry_object(self, obj): self._eql_to_compiled.pop(obj, None) def invalidate_cache_entries(self, to_invalidate): for key in to_invalidate: handle = rpc.CompilationRequestIdHandle(key) self._eql_to_compiled.pop(handle, None) def clear_query_cache(self): self._eql_to_compiled.clear() def iter_views(self): yield from self._views def get_query_cache_size(self): return len(self._eql_to_compiled) + len(self._sql_to_compiled) async def introspection(self): if self.user_schema_pickle is None: async with self._introspection_lock: if self.user_schema_pickle is None: await self.tenant.introspect_db(self.name) def is_introspected(self): return self.user_schema_pickle is not None def lookup_config(self, name: str): spec = self._index._sys_config_spec if self.user_config_spec is not None: spec = config.ChainedSpec(spec, self.user_config_spec) return config.lookup( name, self.db_config or DEFAULT_CONFIG, self._index._sys_config, spec=spec, ) cdef class DatabaseConnectionView: def __init__( self, db: Database, *, query_cache, protocol_version, role_name: str ): self._db = db self._query_cache_enabled = query_cache self._protocol_version = 
protocol_version self._modaliases = DEFAULT_MODALIASES self._config = DEFAULT_CONFIG self._globals = DEFAULT_GLOBALS self._session_state_db_cache = None self._session_state_cache = None self._state_serializer = None self._role_name = role_name # N.B: If we add anything that is not a string or list of string, we'll # need to adjust get_global_value to encode differently. self._sys_globals = { 'sys::current_role': self._role_name, 'sys::current_permissions': list(self.get_permissions()[1]) } if db.name == defines.EDGEDB_SYSTEM_DB: # Make system database read-only. self._capability_mask = ( compiler.Capability.ALL & ~compiler.Capability.DDL & ~compiler.Capability.MODIFICATIONS ) else: self._capability_mask = compiler.Capability.ALL self._last_comp_state = None self._last_comp_state_id = 0 self._in_tx_seq = 0 self._reset_tx_state() def __del__(self): # In any case if _reset_tx_state() is not called, remove self from # ACTIVE_TX_LIST to be safe self._db.tx_seq_end_tx(self._in_tx_seq) cdef _reset_tx_state(self): self._db.tx_seq_end_tx(self._in_tx_seq) self._in_tx_seq = 0 self._txid = None self._in_tx = False self._in_tx_config = None self._in_tx_globals = None self._in_tx_db_config = None self._in_tx_modaliases = None self._in_tx_savepoints = [] self._in_tx_capabilities = 0 self._in_tx_with_sysconfig = False self._in_tx_with_dbconfig = False self._in_tx_with_set = False self._in_tx_root_user_schema_pickle = None self._in_tx_user_schema_pickle = None self._in_tx_user_schema_version = None self._in_tx_global_schema_pickle = None self._in_tx_new_types = {} self._in_tx_user_config_spec = None self._in_tx_state_serializer = None self._tx_error = False self._in_tx_dbver = 0 self._in_tx_isolation_level = None cdef clear_tx_error(self): self._tx_error = False cdef rollback_tx_to_savepoint(self, name): self._tx_error = False # See also CompilerConnectionState.rollback_to_savepoint(). 
while self._in_tx_savepoints: if self._in_tx_savepoints[-1][0] == name: break else: self._in_tx_savepoints.pop() else: raise RuntimeError( f'savepoint {name} not found') _, spid, ( modaliases, config, globals, state_serializer ) = self._in_tx_savepoints[-1] self._txid = spid self.set_modaliases(modaliases) self.set_session_config(config) self.set_globals(globals) self.set_state_serializer(state_serializer) cdef declare_savepoint(self, name, spid): state = ( self.get_modaliases(), self.get_session_config(), self.get_globals(), self.get_state_serializer(), ) self._in_tx_savepoints.append((name, spid, state)) cdef recover_aliases_and_config(self, modaliases, config, globals): assert not self._in_tx self.set_modaliases(modaliases) self.set_session_config(config) self.set_globals(globals) cdef abort_tx(self): if not self.in_tx(): raise errors.InternalServerError('abort_tx(): not in transaction') self._reset_tx_state() cpdef get_session_config(self): if self._in_tx: return self._in_tx_config else: return self._config cpdef get_globals(self): if self._in_tx: return self._in_tx_globals else: return self._globals cpdef get_global_value(self, k): if k in self._sys_globals: # N.B: Currently only str and list[str] sys_global = self._sys_globals[k] encoded: bytes if isinstance(sys_global, str): encoded = sys_global.encode('utf-8') elif isinstance(sys_global, list): encoded = b'' encoded += b'\x00\x00\x00\x01' # ndims encoded += b'\x00\x00\x00\x00' # flags encoded += TEXT_OID.to_bytes(4, 'big') # array_tid encoded += len(sys_global).to_bytes(4, 'big') # count encoded += b'\x00\x00\x00\x01' # bound for elem in sys_global: elem_encoded = elem.encode('utf-8') encoded += len(elem_encoded).to_bytes(4, 'big') encoded += elem_encoded else: raise NotImplementedError return encoded, True else: entry = self.get_globals().get(k) if entry: return entry.value, True else: return None, False cdef get_state_serializer(self): if self._in_tx: return self._in_tx_state_serializer else: if 
self._state_serializer is None: self._state_serializer = self._db.get_state_serializer( self._protocol_version ) return self._state_serializer cdef set_state_serializer(self, new_serializer): if self._in_tx: if ( self._in_tx_state_serializer is None or self._in_tx_state_serializer.type_id != new_serializer.type_id ): self._in_tx_state_serializer = new_serializer else: # Use the same object as the database to avoid duplicate cache self._state_serializer = self._db.set_state_serializer( self._protocol_version, new_serializer ) cdef get_user_config_spec(self): if self._in_tx: return self._in_tx_user_config_spec else: return self._db.user_config_spec cpdef get_config_spec(self): return config.ChainedSpec( self._db._index._sys_config_spec, self.get_user_config_spec(), ) cdef set_session_config(self, new_conf): if self._in_tx: self._in_tx_config = new_conf else: self._config = new_conf cpdef set_globals(self, new_globals): if self._in_tx: self._in_tx_globals = new_globals else: self._globals = new_globals cpdef get_database_config(self): if self._in_tx: return self._in_tx_db_config else: return self._db.db_config cdef set_database_config(self, new_conf): if self._in_tx: self._in_tx_db_config = new_conf else: # N.B: If we *aren't* in a transaction, we rely on calling # process_side_effects() promptly to introspect the new # state. # (We do it this way to avoid potential races between # multiple connections setting db configs.)
pass cdef get_system_config(self): return self._db._index.get_sys_config() cpdef get_compilation_system_config(self): return self._db._index.get_compilation_system_config() cdef set_modaliases(self, new_aliases): if self._in_tx: self._in_tx_modaliases = new_aliases else: self._modaliases = new_aliases cpdef get_modaliases(self): if self._in_tx: return self._in_tx_modaliases else: return self._modaliases def get_user_schema_pickle(self): if self._in_tx: return self._in_tx_user_schema_pickle else: return self._db.user_schema_pickle def get_global_schema_pickle(self): if self._in_tx: return self._in_tx_global_schema_pickle else: return self._db._index._global_schema_pickle def resolve_backend_type_id(self, type_id): type_id = str(type_id) if self._in_tx: try: tinfo = self._in_tx_new_types[type_id] except KeyError: pass else: return int(tinfo[0]) tinfo = self._db.backend_ids.get(type_id) if tinfo is None: raise RuntimeError( f'cannot resolve backend OID for type {type_id}') return int(tinfo[0]) cdef bytes serialize_state(self): cdef list state if self._in_tx: raise errors.InternalServerError( 'no need to serialize state while in transaction') dbver = self._db.dbver if self._session_state_db_cache is not None: if self._session_state_db_cache[0] == (self._config, dbver): return self._session_state_db_cache[1] state = [] if self._config and self._config != DEFAULT_CONFIG: settings = self.get_config_spec() for sval in self._config.values(): setting = settings[sval.name] kind = 'B' if setting.backend_setting else 'C' jval = config.value_to_json_value(setting, sval.value) state.append({"name": sval.name, "value": jval, "type": kind}) # Include the database version in the state so that we are forced # to clear the config cache on dbver changes. 
state.append( {"name": '__dbver__', "value": dbver, "type": 'C'}) spec = json.dumps(state).encode('utf-8') self._session_state_db_cache = ((self._config, dbver), spec) return spec cdef bint is_state_desc_changed(self): # We may have executed a query, or COMMIT/ROLLBACK - just use the # serializer we preserved before. NOTE: the schema might have been # concurrently changed from other sessions, we should not reload # serializer from self._db here so that our state can be serialized # properly, and the Execute stays atomic. serializer = self.get_state_serializer() if self._command_state_serializer is not None: # If the resulting descriptor is the same as the input, return None if serializer.type_id == self._command_state_serializer.type_id: if self._in_tx: # There's a case when DDL was executed but the state schema # wasn't affected, so it's enough to keep just one copy. self._in_tx_state_serializer = ( self._command_state_serializer ) return False # Update with the new serializer for upcoming encoding self._command_state_serializer = serializer return True cdef describe_state(self): return self.get_state_serializer().describe() cdef encode_state(self): modaliases = self.get_modaliases() session_config = self.get_session_config() globals_ = self.get_globals() if self._session_state_cache is None: if ( session_config == DEFAULT_CONFIG and modaliases == DEFAULT_MODALIASES and globals_ == DEFAULT_GLOBALS ): return sertypes.NULL_TYPE_ID, b"" serializer = self._command_state_serializer self._command_state_serializer = None if not self.in_tx(): # After encode_state(), self._state_serializer is no longer used if # not in a transaction. 
So it should be cleared self._state_serializer = None if self._session_state_cache is not None: if ( modaliases, session_config, globals_, serializer.type_id.bytes ) == self._session_state_cache[:4]: return sertypes.NULL_TYPE_ID, b"" self._session_state_cache = None state = {} try: if modaliases[None] != defines.DEFAULT_MODULE_ALIAS: state['module'] = modaliases[None] except KeyError: pass else: modaliases = modaliases.delete(None) if modaliases: state['aliases'] = list(modaliases.items()) if session_config: state['config'] = {k: v.value for k, v in session_config.items()} if globals_: state['globals'] = {k: v.value for k, v in globals_.items()} return serializer.type_id, serializer.encode(state) cdef check_session_config_perms(self, keys): is_superuser, permissions = self.get_permissions() if not is_superuser: settings = self.get_config_spec() for k in keys: setting = settings[k] if setting.session_restricted and not ( setting.session_permission and setting.session_permission in permissions ): raise errors.DisabledCapabilityError( f'role {self._role_name} does not have permission to ' f'configure session config variable {k}' ) cdef decode_state(self, type_id, data): serializer = self.get_state_serializer() self._command_state_serializer = serializer if type_id == sertypes.NULL_TYPE_ID.bytes: self.set_modaliases(DEFAULT_MODALIASES) self.set_session_config(DEFAULT_CONFIG) self.set_globals(DEFAULT_GLOBALS) self._session_state_cache = None return if type_id != serializer.type_id.bytes: self._command_state_serializer = None raise errors.StateMismatchError( "Cannot decode state: type mismatch" ) if self._session_state_cache is not None: if type_id == self._session_state_cache[3]: if data == self._session_state_cache[4]: return state = serializer.decode(data) aliases = dict(state.get('aliases', [])) aliases[None] = state.get('module', defines.DEFAULT_MODULE_ALIAS) aliases = immutables.Map(aliases) config_obj = state.get('config', {}) 
self.check_session_config_perms(config_obj) session_config = immutables.Map({ k: config.SettingValue( name=k, value=v, source='session', scope=qltypes.ConfigScope.SESSION, ) for k, v in config_obj.items() }) globals_ = immutables.Map({ k: config.SettingValue( name=k, value=recode_global(self, v, serializer.get_global_type_rep(k)), source='global', scope=qltypes.ConfigScope.GLOBAL, ) for k, v in state.get('globals', {}).items() }) self.set_modaliases(aliases) self.set_session_config(session_config) self.set_globals(globals_) self._session_state_cache = ( aliases, session_config, globals_, type_id, data ) cdef decode_json_session_config(self, json_session_config): if not json_session_config: return settings = self.get_config_spec() self.check_session_config_perms(json_session_config) session_config = self.get_session_config() for k, v in json_session_config.items(): op = config.Operation( config.OpCode.CONFIG_SET, qltypes.ConfigScope.SESSION, k, v, ) session_config = op.apply(settings, session_config) self.set_session_config(session_config) cdef bint needs_commit_after_state_sync(self): return any( tx_conf in self._config for tx_conf in [ "default_transaction_isolation", "default_transaction_deferrable", # default_transaction_access_mode is not yet a backend config ] ) property txid: def __get__(self): return self._txid property dbname: def __get__(self): return self._db.name property reflection_cache: def __get__(self): return self._db.reflection_cache property dbver: def __get__(self): if self._in_tx and self._in_tx_dbver: return self._in_tx_dbver return self._db.dbver property schema_version: def __get__(self): if self._in_tx and self._in_tx_user_schema_version: return self._in_tx_user_schema_version return self._db.schema_version @property def server(self): return self._db._index._server @property def tenant(self): return self._db._index._tenant def get_permissions(self) -> tuple[bool, Sequence[str]]: if role_desc := self.tenant.get_roles().get(self._role_name): 
return ( bool(role_desc.get('superuser')), (role_desc.get('all_permissions') or ()) ) return False, () def get_role_capability(self) -> enums.Capability: if capability := self.tenant.get_role_capabilities().get( self._role_name ): return capability return enums.Capability.NONE cpdef in_tx(self): return self._in_tx cpdef in_tx_error(self): return self._tx_error cdef cache_compiled_query(self, object key, object query_unit_group): assert query_unit_group.cacheable if self._tx_error or self._in_tx_capabilities & DDL_CAPABILITIES: return self._db._cache_compiled_query(key, query_unit_group) cdef lookup_compiled_query(self, object key): if ( self._tx_error or not self._query_cache_enabled or self._in_tx_capabilities & DDL_CAPABILITIES ): return None return self._db._eql_to_compiled.get(key, None) cdef tx_error(self): if self._in_tx: self._tx_error = True cdef start(self, query_unit): if self._tx_error: self.raise_in_tx_error() if query_unit.tx_id is not None: self._txid = query_unit.tx_id self.start_tx() if self._in_tx: self._apply_in_tx(query_unit) cdef start_tx(self): state_serializer = self.get_state_serializer() self._in_tx = True self._in_tx_config = self._config self._in_tx_globals = self._globals self._in_tx_db_config = self._db.db_config self._in_tx_modaliases = self._modaliases self._in_tx_root_user_schema_pickle = self._db.user_schema_pickle self._in_tx_user_schema_pickle = self._db.user_schema_pickle self._in_tx_user_schema_version = self._db.schema_version self._in_tx_global_schema_pickle = \ self._db._index._global_schema_pickle self._in_tx_user_config_spec = self._db.user_config_spec self._in_tx_state_serializer = state_serializer self._in_tx_dbver = self._db.dbver self._in_tx_seq = self._db.tx_seq_begin_tx() cdef _apply_in_tx(self, query_unit): self._in_tx_capabilities |= query_unit.capabilities if query_unit.system_config: self._in_tx_with_sysconfig = True if query_unit.database_config: self._in_tx_with_dbconfig = True if query_unit.has_set: 
self._in_tx_with_set = True if query_unit.user_schema is not None: self._in_tx_dbver = next_dbver() self._in_tx_user_schema_pickle = query_unit.user_schema self._in_tx_user_schema_version = query_unit.user_schema_version self._in_tx_user_config_spec = config.FlatSpec( *query_unit.ext_config_settings ) if query_unit.global_schema is not None: self._in_tx_global_schema_pickle = query_unit.global_schema if query_unit.tx_isolation_level: self._in_tx_isolation_level = query_unit.tx_isolation_level cdef start_implicit(self, query_unit): if self._tx_error: self.raise_in_tx_error() if not self._in_tx: self.start_tx() self._apply_in_tx(query_unit) cdef on_error(self): self.tx_error() cdef on_success(self, query_unit, new_types): side_effects = 0 if not self._in_tx: if query_unit.capabilities & DML_CAPABILITIES: self._db.dml_queries_executed += 1 if new_types: self._db._update_backend_ids(new_types) if query_unit.user_schema is not None: self._db._set_and_signal_new_user_schema( query_unit.user_schema, query_unit.user_schema_version, query_unit.extensions, query_unit.ext_config_settings, query_unit.feature_used_metrics, pickle.loads(query_unit.cached_reflection) if query_unit.cached_reflection is not None else None, ) side_effects |= SideEffects.SchemaChanges if query_unit.system_config: side_effects |= SideEffects.InstanceConfigChanges if query_unit.database_config: side_effects |= SideEffects.DatabaseConfigChanges if query_unit.create_db: side_effects |= SideEffects.DatabaseChanges if query_unit.drop_db: side_effects |= SideEffects.DatabaseChanges if query_unit.global_schema is not None: side_effects |= SideEffects.GlobalSchemaChanges self._db._index.update_global_schema(query_unit.global_schema) self._db.tenant.set_roles(query_unit.roles) else: if new_types: self._in_tx_new_types.update(new_types) if query_unit.modaliases is not None: self.set_modaliases(query_unit.modaliases) if query_unit.tx_commit: if not self._in_tx: # This shouldn't happen because compiler has # 
checks around that. raise errors.InternalServerError( '"commit" outside of a transaction') self._config = self._in_tx_config self._modaliases = self._in_tx_modaliases self._globals = self._in_tx_globals if self._in_tx_capabilities & DML_CAPABILITIES: self._db.dml_queries_executed += 1 if self._in_tx_new_types: self._db._update_backend_ids(self._in_tx_new_types) if query_unit.user_schema is not None: self._db._set_and_signal_new_user_schema( query_unit.user_schema, query_unit.user_schema_version, query_unit.extensions, query_unit.ext_config_settings, query_unit.feature_used_metrics, # XXX? does this get set? pickle.loads(query_unit.cached_reflection) if query_unit.cached_reflection is not None else None, ) side_effects |= SideEffects.SchemaChanges if self._in_tx_with_sysconfig: side_effects |= SideEffects.InstanceConfigChanges if self._in_tx_with_dbconfig: side_effects |= SideEffects.DatabaseConfigChanges if query_unit.global_schema is not None: side_effects |= SideEffects.GlobalSchemaChanges self._db._index.update_global_schema(query_unit.global_schema) self._db.tenant.set_roles(query_unit.roles) self._reset_tx_state() elif query_unit.tx_rollback: # Note that we might not be in a transaction as we allow # ROLLBACKs outside of transaction blocks (just like Postgres). # TODO: That said, we should send a *warning* when a ROLLBACK # is executed outside of a tx. 
self._reset_tx_state() return side_effects cdef commit_implicit_tx( self, user_schema, extensions, ext_config_settings, global_schema, roles, cached_reflection, feature_used_metrics, ): assert self._in_tx side_effects = 0 self._config = self._in_tx_config self._modaliases = self._in_tx_modaliases self._globals = self._in_tx_globals if self._in_tx_new_types: self._db._update_backend_ids(self._in_tx_new_types) if user_schema is not None: self._db._set_and_signal_new_user_schema( user_schema, self._in_tx_user_schema_version, extensions, ext_config_settings, feature_used_metrics, pickle.loads(cached_reflection) if cached_reflection is not None else None ) side_effects |= SideEffects.SchemaChanges if self._in_tx_with_sysconfig: side_effects |= SideEffects.InstanceConfigChanges if self._in_tx_with_dbconfig: side_effects |= SideEffects.DatabaseConfigChanges if global_schema is not None: side_effects |= SideEffects.GlobalSchemaChanges self._db._index.update_global_schema(global_schema) self._db.tenant.set_roles(roles) self._reset_tx_state() return side_effects cdef config_lookup(self, name): return self.server.config_lookup( name, self.get_session_config(), self.get_database_config(), self.get_system_config(), ) async def recompile_cached_queries( self, user_schema, schema_version, send_log_message: typing.Callable[[int, str], None] | None = None, ): compiler_pool = self.server.get_compiler_pool() compile_concurrency = max(1, compiler_pool.get_size_hint() // 2) concurrency_control = asyncio.Semaphore(compile_concurrency) rv = [] recompile_timeout = self.config_lookup( "auto_rebuild_query_cache_timeout", ) loop = asyncio.get_running_loop() t0 = loop.time() if recompile_timeout is not None: stop_time = t0 + recompile_timeout.to_microseconds() / 1e6 else: stop_time = None async def recompile_request(query_req: rpc.CompilationRequest): async with concurrency_control: try: if stop_time is not None and loop.time() > stop_time: return database_config = self.get_database_config() 
system_config = self.get_compilation_system_config() query_req = copy.copy(query_req) query_req.set_schema_version(schema_version) query_req.set_database_config(database_config) query_req.set_system_config(system_config) async with asyncio.timeout_at(stop_time): unit_group, _, _ = await compiler_pool.compile( self.dbname, user_schema, self.get_global_schema_pickle(), self.reflection_cache, database_config, system_config, query_req.serialize(), "", client_id=self.tenant.client_id, client_name=self.tenant.get_instance_name(), ) except Exception: # ignore cache entry that cannot be recompiled pass else: rv.append((query_req, unit_group)) async with asyncio.TaskGroup() as g: req: rpc.CompilationRequest cnt = 0 # Reversed so that we compile more recently used first. for req, grp in reversed(self._db._eql_to_compiled.items()): if ( len(grp) == 1 # Only recompile queries from the *latest* version, # to avoid quadratic slowdown problems. and req.schema_version == self.schema_version # SQL queries require _amend_typedesc_in_sql() with a # backend connection, which is not available here. 
and req.input_language != enums.InputLanguage.SQL ): cnt += 1 g.create_task(recompile_request(req)) if send_log_message: send_log_message( errors.MigrationStatusMessage.get_code(), f'Recompiling {cnt} cached queries' ) return rv async def apply_config_ops(self, conn, ops): settings = self.get_config_spec() for op in ops: if op.scope is config.ConfigScope.INSTANCE: assert conn is not None await self._db._index.apply_system_config_op(conn, op) elif op.scope is config.ConfigScope.DATABASE: self.set_database_config( op.apply(settings, self.get_database_config()), ) elif op.scope is config.ConfigScope.SESSION: self.check_session_config_perms([op.setting_name]) self.set_session_config( op.apply(settings, self.get_session_config()), ) elif op.scope is config.ConfigScope.GLOBAL: self.set_globals( op.apply(settings, self.get_globals()), ) @staticmethod def raise_in_tx_error(): raise errors.TransactionError( 'current transaction is aborted, ' 'commands ignored until end of transaction block' ) from None async def parse( self, query_req: rpc.CompilationRequest, cached_globally: bint = False, use_metrics: bint = True, allow_capabilities: uint64_t = compiler.Capability.ALL, pgcon: pgcon.PGConnection | None = None, tag: str | None = None, send_log_message: typing.Callable[[int, str], None] | None = None, ) -> CompiledQuery: query_unit_group = None if self._query_cache_enabled: if cached_globally: # WARNING: only set cached_globally to True when the query is # strictly referring to only shared stable objects in user # schema or anything from std schema, for example: # YES: select ext::auth::UIConfig { ... } # NO: select default::User { ... 
} query_unit_group = ( self.server.system_compile_cache.get(query_req) ) else: query_unit_group = self.lookup_compiled_query(query_req) # Fast-path to skip all the locks if it's a cache HIT if query_unit_group is not None: return self.as_compiled( query_req, query_unit_group, use_metrics) lock = None schema_version = self.schema_version # Lock on the query compilation to avoid other coroutines running # the same compile and waste computational resources if cached_globally: lock_table = self.server.system_compile_cache_locks else: lock_table = self._db._cache_locks while True: # We need a loop here because schema_version is a part of the key, # there could be a DDL while we're waiting for the lock. lock = lock_table.get(query_req) if lock is None: lock = asyncio.Lock() lock_table[query_req] = lock await lock.acquire() if self.schema_version == schema_version: break else: lock.release() if not lock._waiters: del lock_table[query_req] schema_version = self.schema_version # Updating the schema_version will make query_req a new key query_req.set_schema_version(schema_version) try: # Check the cache again with the lock acquired if self._query_cache_enabled: if cached_globally: query_unit_group = ( self.server.system_compile_cache.get(query_req) ) else: query_unit_group = self.lookup_compiled_query(query_req) if query_unit_group is not None: return self.as_compiled( query_req, query_unit_group, use_metrics) try: query_unit_group = await self._compile(query_req) except (errors.EdgeQLSyntaxError, errors.InternalServerError): raise except errors.EdgeDBError: if self.in_tx_error(): # Because we are in an error state it's more reasonable # to fail with TransactionError("commands ignored") # rather than with a potentially more cryptic error. # An exception from this rule are syntax errors and # ISEs, because these could arise while the user is # trying to properly rollback this failed transaction. 
self.raise_in_tx_error() else: raise self.check_capabilities( query_unit_group, allow_capabilities, errors.DisabledCapabilityError, "disabled by the client", # In parse, we don't raise any errors based on # unsafe_isolation_dangers. We do report them to the # client in an annotation, though. unsafe_isolation_dangers=None, ) self._check_in_tx_error(query_unit_group) if query_req.input_language is enums.InputLanguage.SQL: if len(query_unit_group) > 1: raise errors.UnsupportedFeatureError( "multi-statement SQL scripts are not supported yet" ) if pgcon is None: raise errors.InternalServerError( "a valid backend connection is required to fully " "compile a query in SQL mode", ) await self._amend_typedesc_in_sql( query_req, query_unit_group, pgcon, tag, ) if self._query_cache_enabled and query_unit_group.cacheable: if cached_globally: self.server.system_compile_cache[query_req] = ( query_unit_group ) else: self.cache_compiled_query(query_req, query_unit_group) finally: if lock is not None: lock.release() if not lock._waiters: del lock_table[query_req] recompiled_cache = None if ( not self.in_tx() or len(query_unit_group) > 0 and query_unit_group[0].tx_commit ): # Recompile all cached queries if: # * Issued a DDL or committing a tx with DDL (recompilation # before in-tx DDL needs to fix _in_tx_capabilities caching 1st) # * Config.auto_rebuild_query_cache is turned on # # Ideally we should compute the proper user_schema, database_config # and system_config for recompilation from server/compiler.py with # proper handling of config values. For now we just use the values # in the current dbview and not support certain marginal cases. 
            user_schema = None
            user_schema_version = None
            for unit in query_unit_group:
                if unit.tx_rollback:
                    break
                if unit.user_schema:
                    user_schema = unit.user_schema
                    user_schema_version = unit.user_schema_version
            if user_schema and not self.config_lookup(
                "auto_rebuild_query_cache",
            ):
                user_schema = None
            if user_schema:
                recompiled_cache = await self.recompile_cached_queries(
                    user_schema,
                    user_schema_version,
                    send_log_message=send_log_message,
                )

        if use_metrics:
            if query_req.input_language is enums.InputLanguage.EDGEQL:
                metrics.edgeql_query_compilations.inc(
                    1.0, self.tenant.get_instance_name(), 'compiler'
                )
            else:
                metrics.sql_compilations.inc(
                    1.0, self.tenant.get_instance_name()
                )

        source = query_req.source
        if query_unit_group.force_non_normalized:
            source = source.denormalized()

        return CompiledQuery(
            query_unit_group=query_unit_group,
            first_extra=source.first_extra(),
            extra_counts=source.extra_counts(),
            extra_blobs=source.extra_blobs(),
            extra_formatted_as_text=source.extra_formatted_as_text(),
            extra_type_oids=source.extra_type_oids(),
            request=query_req,
            recompiled_cache=recompiled_cache,
        )

    async def _amend_typedesc_in_sql(
        self,
        query_req: rpc.CompilationRequest,
        qug: dbstate.QueryUnitGroup,
        pgcon: pgcon.PGConnection,
        tag: str | None,
    ) -> None:
        # The SQL QueryUnitGroup as initially returned from the compiler
        # is missing the input/output type descriptors because we currently
        # don't run static SQL type inference. To mend that we ask Postgres
        # to infer the result types (as an OID tuple) and then use
        # our OID -> scalar type mapping to construct an EdgeQL free shape with
        # corresponding properties which we then send to the compiler to
        # compute the type descriptors.
to_describe = [] desc_map = {} source = query_req.source if qug.force_non_normalized: source = source.denormalized() first_extra = source.first_extra() num_injected_params = 0 if qug.globals is not None: num_injected_params += len(qug.globals) if qug.permissions is not None: num_injected_params += len(qug.permissions) if first_extra is not None: extra_type_oids = source.extra_type_oids() all_type_oids = [0] * first_extra + extra_type_oids num_injected_params += len(extra_type_oids) else: all_type_oids = [] for i, query_unit in enumerate(qug): intro_sql = query_unit.introspection_sql if intro_sql is None: intro_sql = query_unit.sql if tag is not None: # maintenance reminder: please also update make_query_prefix() tag_json = json.dumps({"tag": tag}) intro_sql = b''.join([ b'-- ', tag_json.encode(defines.EDGEDB_ENCODING), b'\n', intro_sql, ]) try: param_desc, result_desc = await pgcon.sql_describe( intro_sql, all_type_oids) except pgerror.BackendError as ex: ex._from_sql = True if 'P' in ex.fields: ex.fields['P'] = str( int(ex.fields['P']) - query_unit.sql_prefix_len ) if query_unit.source_map: ex._source_map = query_unit.source_map raise result_types = [] for col, toid in result_desc: edb_type_id = self._db.backend_oid_to_id.get(toid) if edb_type_id is None: raise errors.UnsupportedFeatureError( f"unsupported SQL type in column \"{col}\" " f"with type OID {toid}" ) result_types.append((col, edb_type_id)) params = [] if num_injected_params: param_desc = param_desc[:-num_injected_params] for pi, toid in enumerate(param_desc): edb_type_id = self._db.backend_oid_to_id.get(toid) if edb_type_id is None: raise errors.UnsupportedFeatureError( f"unsupported type in SQL parameter ${pi} " f"with type OID {toid}" ) params.append(edb_type_id) to_describe.append((params, result_types)) desc_map[len(to_describe) - 1] = i if to_describe: desc_qug = await self._compile_sql_descriptors( query_req, to_describe) for i, desc_qu in enumerate(desc_qug): qu_i = desc_map[i] if 
query_req.output_format is not enums.OutputFormat.NONE: qug[qu_i].out_type_data = desc_qu[1][0] qug[qu_i].out_type_id = desc_qu[1][1] qug[qu_i].in_type_data = desc_qu[0][0] qug[qu_i].in_type_id = desc_qu[0][1] qug[qu_i].in_type_args = desc_qu[0][2] qug[qu_i].in_type_args_real_count = desc_qu[0][3] # XXX We don't support SQL scripts just yet, so for now # we can just copy the last QU's descriptors and # apply them to the whole group (IOW a group is really # a group of ONE now.) # In near future we'll need to properly implement arg # remap. if query_req.output_format is not enums.OutputFormat.NONE: qug.out_type_data = desc_qug[-1][1][0] qug.out_type_id = desc_qug[-1][1][1] qug.in_type_data = desc_qug[-1][0][0] qug.in_type_id = desc_qug[-1][0][1] qug.in_type_args = desc_qug[-1][0][2] qug.in_type_args_real_count = desc_qug[-1][0][3] cdef inline _check_in_tx_error(self, query_unit_group): if self.in_tx_error(): # The current transaction is aborted, so we must fail # all commands except ROLLBACK or ROLLBACK TO SAVEPOINT. 
first = query_unit_group[0] if ( not ( first.tx_rollback or first.tx_savepoint_rollback or first.tx_abort_migration ) or len(query_unit_group) > 1 ): self.raise_in_tx_error() cdef as_compiled(self, query_req, query_unit_group, bint use_metrics=True): cdef use_pending_func_cache = False if query_unit_group.cache_state == 1: use_pending_func_cache = ( not self._in_tx_seq or self._in_tx_seq > query_unit_group.tx_seq_id ) self._check_in_tx_error(query_unit_group) if use_metrics: metrics.edgeql_query_compilations.inc( 1.0, self.tenant.get_instance_name(), 'cache' ) source = query_req.source return CompiledQuery( query_unit_group=query_unit_group, first_extra=source.first_extra(), extra_counts=source.extra_counts(), extra_blobs=source.extra_blobs(), extra_formatted_as_text=source.extra_formatted_as_text(), extra_type_oids=source.extra_type_oids(), use_pending_func_cache=use_pending_func_cache, ) async def _compile( self, query_req: rpc.CompilationRequest, ) -> dbstate.QueryUnitGroup: compiler_pool = self._db._index._server.get_compiler_pool() started_at = time.monotonic() try: if self.in_tx(): result = await compiler_pool.compile_in_tx( self.dbname, self._in_tx_root_user_schema_pickle, self.txid, self._last_comp_state, self._last_comp_state_id, query_req.serialize(), query_req.source.text(), self.in_tx_error(), client_id=self.tenant.client_id, client_name=self.tenant.get_instance_name(), ) else: result = await compiler_pool.compile( self.dbname, self.get_user_schema_pickle(), self.get_global_schema_pickle(), self.reflection_cache, self.get_database_config(), self.get_compilation_system_config(), query_req.serialize(), query_req.source.text(), client_id=self.tenant.client_id, client_name=self.tenant.get_instance_name(), ) finally: metrics.edgeql_query_compilation_duration.observe( time.monotonic() - started_at, self.tenant.get_instance_name(), ) metrics.query_compilation_duration.observe( time.monotonic() - started_at, self.tenant.get_instance_name(), "edgeql", ) 
unit_group, self._last_comp_state, self._last_comp_state_id = result return unit_group async def _compile_sql_descriptors( self, query_req: rpc.CompilationRequest, types_in_out: defines.ProtocolVersion, ) -> dbstate.QueryUnitGroup: compiler_pool = self._db._index._server.get_compiler_pool() cfg_ser = self.server.compilation_config_serializer req = rpc.CompilationRequest( source=rpc.SQLParamsSource(types_in_out), protocol_version=query_req.protocol_version, schema_version=query_req.schema_version, input_language=enums.InputLanguage.SQL_PARAMS, compilation_config_serializer=cfg_ser, ) return await self._compile(req) cdef check_capabilities( self, query_unit, allowed_capabilities, error_constructor, reason, unsafe_isolation_dangers, ): query_capabilities = query_unit.capabilities if query_capabilities & ~self._capability_mask: # _capability_mask is currently only used for system database raise query_capabilities.make_error( self._capability_mask, errors.UnsupportedCapabilityError, "system database is read-only", ) if (query_capabilities & PROTO_CAPS) & ~allowed_capabilities: raise query_capabilities.make_error( allowed_capabilities, error_constructor, reason, ) role_capability = self.get_role_capability() if query_capabilities & ~role_capability: raise query_capabilities.make_error( role_capability, error_constructor, f"role {self._role_name} does not have permission", ) if self.tenant.is_readonly(): if query_capabilities & enums.Capability.WRITE: readiness_reason = self.tenant.get_readiness_reason() msg = "the server is currently in read-only mode" if readiness_reason: msg = f"{msg}: {readiness_reason}" raise query_capabilities.make_error( ~enums.Capability.WRITE, errors.DisabledCapabilityError, msg, ) if query_unit.required_permissions: is_superuser, permissions = self.get_permissions() if not is_superuser: for perm in query_unit.required_permissions: if perm not in permissions: missing = sorted( set(query_unit.required_permissions) - set(permissions) ) plural = 's' 
if len(missing) > 1 else '' raise errors.DisabledCapabilityError( f'role {self._role_name} does not have required ' f'permission{plural}: {", ".join(missing)}' ) has_write = query_capabilities & enums.Capability.WRITE if has_write and unsafe_isolation_dangers: isolation = None # Sigh! We have two different isolation level enumerations! if self.in_tx(): isolation = self._in_tx_isolation_level else: isolation = self.config_lookup( "default_transaction_isolation" ) if isolation and isolation.to_str() == "RepeatableRead": isolation = ( qltypes.TransactionIsolationLevel.REPEATABLE_READ) else: isolation = qltypes.TransactionIsolationLevel.SERIALIZABLE not_serializable = ( isolation != qltypes.TransactionIsolationLevel.SERIALIZABLE ) if not_serializable: body = '\n'.join( ' - ' + str(e) for e in unsafe_isolation_dangers ) raise errors.UnsafeIsolationLevelError( f"Can not execute query with transaction isolation level " f"{isolation} because: \n{body}", ) if not self.in_tx() and has_write: access_mode = self.config_lookup("default_transaction_access_mode") if access_mode and access_mode.to_str() == "ReadOnly": raise query_capabilities.make_error( ~enums.Capability.WRITE, errors.TransactionError, "default_transaction_access_mode is set to ReadOnly", ) async def reload_state_serializer(self): # This should only happen once when a different protocol version is # used after schema change, or non-current version of protocol is used # for the first time after database introspection. Because such cases # are rare, we'd rather do it lazily here than enumerating all protocol # versions making several serializers in every schema change. 
compiler_pool = self._db._index._server.get_compiler_pool() state_serializer = await compiler_pool.make_state_serializer( self._protocol_version, self.get_user_schema_pickle(), self.get_global_schema_pickle(), ) self.set_state_serializer(state_serializer) cdef class DatabaseIndex: def __init__( self, tenant, *, std_schema, global_schema_pickle, sys_config, default_sysconfig, # system config without system override sys_config_spec, ): self._dbs = {} self._server = tenant.server self._tenant = tenant self._std_schema = std_schema self._global_schema_pickle = global_schema_pickle self._default_sysconfig = default_sysconfig self._sys_config_spec = sys_config_spec self.update_sys_config(sys_config) self._cached_compiler_args = None def count_connections(self, dbname: str): try: db = self._dbs[dbname] except KeyError: return 0 return sum(1 for dbv in (db)._views if not dbv.is_transient) def get_sys_config(self): return self._sys_config def get_compilation_system_config(self): return self._comp_sys_config def update_sys_config(self, sys_config): cdef Database db for db in self._dbs.values(): db.dbver = next_dbver() with self._default_sysconfig.mutate() as mm: mm.update(sys_config) sys_config = mm.finish() self._sys_config = sys_config self._comp_sys_config = config.get_compilation_config( sys_config, spec=self._sys_config_spec) self.invalidate_caches() def has_db(self, dbname): return dbname in self._dbs def get_db(self, dbname): try: return self._dbs[dbname] except KeyError: raise errors.UnknownDatabaseError( f'database branch {dbname!r} does not exist') def maybe_get_db(self, dbname): return self._dbs.get(dbname) def get_global_schema_pickle(self): return self._global_schema_pickle def update_global_schema(self, global_schema_pickle): self._global_schema_pickle = global_schema_pickle self.invalidate_caches() def register_db( self, dbname, *, user_schema_pickle, schema_version, db_config, reflection_cache, backend_ids, extensions, ext_config_settings, early=False, 
feature_used_metrics=None, ): cdef Database db db = self._dbs.get(dbname) if db is not None: db._set_and_signal_new_user_schema( user_schema_pickle, schema_version, extensions, ext_config_settings, feature_used_metrics, reflection_cache, backend_ids, db_config, not early, ) else: db = Database( self, dbname, user_schema_pickle=user_schema_pickle, schema_version=schema_version, db_config=db_config, reflection_cache=reflection_cache, backend_ids=backend_ids, extensions=extensions, ext_config_settings=ext_config_settings, feature_used_metrics=feature_used_metrics, ) self._dbs[dbname] = db if not early: db.start_stop_extensions() self.set_current_branches() return db def unregister_db(self, dbname): db = self._dbs.pop(dbname) db.stop() self.set_current_branches() cdef inline set_current_branches(self): metrics.current_branches.set( sum( dbname != defines.EDGEDB_SYSTEM_DB for dbname in self._dbs ), self._tenant.get_instance_name(), ) metrics.current_introspected_branches.set( sum( dbname != defines.EDGEDB_SYSTEM_DB and db.user_schema_pickle is not None for dbname, db in self._dbs.items() ), self._tenant.get_instance_name(), ) def iter_dbs(self): return iter(self._dbs.values()) async def _save_system_overrides(self, conn, spec): data = config.to_json( spec, self._sys_config, setting_filter=lambda v: v.source == 'system override', include_source=False, ) block = dbops.PLTopBlock() metadata = {'sysconfig': json.loads(data)} if self._tenant.get_backend_runtime_params().has_create_database: dbops.UpdateMetadata( dbops.Database( name=self._tenant.get_pg_dbname(defines.EDGEDB_SYSTEM_DB), ), metadata, ).generate(block) else: dbops.UpdateSingleDBMetadata( defines.EDGEDB_SYSTEM_DB, metadata ).generate(block) await conn.sql_execute(block.to_string().encode()) async def apply_system_config_op(self, conn, op): spec = self._sys_config_spec op_value = op.get_setting(spec) if op.opcode is not None: allow_missing = ( op.opcode is config.OpCode.CONFIG_REM or op.opcode is 
config.OpCode.CONFIG_RESET ) op_value = op.coerce_value( spec, op_value, allow_missing=allow_missing) # _save_system_overrides *must* happen before # the callbacks below, because certain config changes # may cause the backend connection to drop. self.update_sys_config( op.apply(spec, self._sys_config) ) await self._save_system_overrides(conn, spec) if op.opcode is config.OpCode.CONFIG_ADD: await self._server._on_system_config_add(op.setting_name, op_value) elif op.opcode is config.OpCode.CONFIG_REM: await self._server._on_system_config_rem(op.setting_name, op_value) elif op.opcode is config.OpCode.CONFIG_SET: await self._server._on_system_config_set(op.setting_name, op_value) elif op.opcode is config.OpCode.CONFIG_RESET: await self._server._on_system_config_reset(op.setting_name) else: raise errors.UnsupportedFeatureError( f'unsupported config operation: {op.opcode}') if op.opcode is config.OpCode.CONFIG_ADD: await self._server._after_system_config_add( op.setting_name, op_value) elif op.opcode is config.OpCode.CONFIG_REM: await self._server._after_system_config_rem( op.setting_name, op_value) elif op.opcode is config.OpCode.CONFIG_SET: await self._server._after_system_config_set( op.setting_name, op_value) elif op.opcode is config.OpCode.CONFIG_RESET: await self._server._after_system_config_reset( op.setting_name) def new_view( self, dbname: str, *, query_cache: bool, protocol_version, role_name: str, ): db = self.get_db(dbname) return (db)._new_view( query_cache, protocol_version, role_name ) def remove_view(self, view: DatabaseConnectionView): db = self.get_db(view.dbname) return (db)._remove_view(view) cdef invalidate_caches(self): self._cached_compiler_args = None def get_cached_compiler_args(self): if self._cached_compiler_args is None: self._cached_compiler_args = ( self._global_schema_pickle, self._comp_sys_config ) return self._cached_compiler_args def lookup_config(self, name: str): return config.lookup( name, self._sys_config, spec=self._sys_config_spec, 
        )


================================================
FILE: edb/server/defines.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations
from typing import TypeAlias

import enum

from edb import buildmeta
from edb.common import enum as s_enum
from edb.schema import defines as s_def

EDGEDB_PORT = 5656
EDGEDB_REMOTE_COMPILER_PORT = 5660
EDGEDB_SUPERGROUP = 'edgedb_supergroup'
EDGEDB_SUPERUSER = s_def.EDGEDB_SUPERUSER
EDGEDB_OLD_SUPERUSER = s_def.EDGEDB_OLD_SUPERUSER
EDGEDB_TEMPLATE_DB = s_def.EDGEDB_TEMPLATE_DB
EDGEDB_OLD_DEFAULT_DB = 'edgedb'
EDGEDB_SUPERUSER_DB = 'main'
EDGEDB_SYSTEM_DB = s_def.EDGEDB_SYSTEM_DB
EDGEDB_ENCODING = 'utf-8'
EDGEDB_VISIBLE_METADATA_PREFIX = r'Gel metadata follows, do not modify.\n'
EDGEDB_SPECIAL_DBS = s_def.EDGEDB_SPECIAL_DBS

EDGEDB_CATALOG_VERSION = buildmeta.EDGEDB_CATALOG_VERSION
MIN_POSTGRES_VERSION = (14, 0)

# Resource limit on open FDs for the server process.
# By default, at least on macOS, the max number of open FDs
# is 256, which is low and can cause 'edb test' to hang.
# We try to bump the rlimit on server start if permitted.
EDGEDB_MIN_RLIMIT_NOFILE = 2048

BACKEND_CONNECTIONS_MIN = 4
BACKEND_COMPILER_POOL_SIZE_MIN = 1

# The time in seconds to wait before restarting the template compiler process
# after it exits unexpectedly.
BACKEND_COMPILER_TEMPLATE_PROC_RESTART_INTERVAL = 1

_MAX_QUERIES_CACHE_SYSTEM = 1000

_QUERY_ROLLING_AVG_LEN = 10
_QUERIES_ROLLING_AVG_LEN = 300

DEFAULT_MODULE_ALIAS = 'default'

# The maximum length of a Unix socket relative to runstate dir.
# 21 is the length of the longest socket we might use, which
# is the admin socket (.s.EDGEDB.admin.xxxxx).
MAX_UNIX_SOCKET_PATH_LENGTH = 21

# 104 is the maximum Unix socket path length on BSD/Darwin, whereas
# Linux is constrained to 108.
MAX_RUNSTATE_DIR_PATH = 104 - MAX_UNIX_SOCKET_PATH_LENGTH - 1

HTTP_PORT_QUERY_CACHE_SIZE = 1000

# The time in seconds the Gel server shall wait between retries to connect
# to the system database after the connection was broken during runtime.
SYSTEM_DB_RECONNECT_INTERVAL = 1

ProtocolVersion: TypeAlias = tuple[int, int]

MIN_PROTOCOL: ProtocolVersion = (1, 0)
CURRENT_PROTOCOL: ProtocolVersion = (3, 0)

# Emulated PG binary protocol
POSTGRES_PROTOCOL: ProtocolVersion = (-3, 0)

MIN_SUGGESTED_CLIENT_POOL_SIZE = 10
MAX_SUGGESTED_CLIENT_POOL_SIZE = 100

_TLS_CERT_RELOAD_MAX_RETRIES = 5
_TLS_CERT_RELOAD_EXP_INTERVAL = 0.1

PGEXT_POSTGRES_VERSION = 13.9
PGEXT_POSTGRES_VERSION_NUM = 130009

# The time in seconds the Gel server will wait for a tenant to be gracefully
# shut down when removed from a multi-tenant host.
MULTITENANT_TENANT_DESTROY_TIMEOUT = 30


class TxIsolationLevel(s_enum.StrEnum):
    RepeatableRead = 'REPEATABLE READ'
    Serializable = 'SERIALIZABLE'


# Mapping to the backend `edb_stat_statements.stmt_type` values,
# as well as `sys::QueryType` in edb/lib/sys.edgeql
class QueryType(enum.IntEnum):
    EdgeQL = 1
    SQL = 2


================================================
FILE: edb/server/ha/__init__.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2021-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


================================================
FILE: edb/server/ha/adaptive.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2021-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from typing import Optional

import asyncio
import enum
import logging
import os

from edb.server import metrics

from . import base

UNHEALTHY_MIN_TIME = int(os.getenv(
    'EDGEDB_SERVER_BACKEND_ADAPTIVE_HA_UNHEALTHY_MIN_TIME', 30
))
UNEXPECTED_DISCONNECTS_THRESHOLD = int(os.getenv(
    'EDGEDB_SERVER_BACKEND_ADAPTIVE_HA_DISCONNECT_PERCENT', 60
)) / 100

logger = logging.getLogger("edb.pgcluster")


class State(enum.Enum):
    HEALTHY = 1
    UNHEALTHY = 2
    FAILOVER = 3


class AdaptiveHASupport:
    # Adaptive HA support is used to detect HA backends that do not actively
    # send clear failover signals to EdgeDB.
    # It can be enabled through command
    # line argument --enable-backend-adaptive-ha.
    #
    # This class evaluates the events on the backend connection pool into 3
    # states representing the status of the backend:
    #
    #   * Healthy - all is good
    #   * Unhealthy - a staging state before failover
    #   * Failover - backend failover is in process
    #
    # When entering Unhealthy state, we will start to count events for a
    # threshold; when reached, we'll switch to Failover state - that means we
    # will actively disconnect all backend connections and wait for sys_pgcon
    # to reconnect. In any of the 3 states, client connections will not be
    # dropped. Whether the clients could issue queries is irrelevant to the 3
    # states - `BackendUnavailableError` or `BackendInFailoverError` is only
    # raised if the sys_pgcon is broken. But even with that said,
    # `BackendUnavailableError` is only seen in Unhealthy (not always), and
    # Failover always means `BackendInFailoverError` for any queries.
    #
    # Rules of state switches:
    #
    # Unhealthy -> Healthy
    #   * Successfully connected to a non-hot-standby backend.
    #   * Data received from any pgcon (not implemented).
    #
    # Unhealthy -> Failover
    #   * More than 60% (UNEXPECTED_DISCONNECTS_THRESHOLD) of existing pgcons
    #     are "unexpectedly disconnected" (number of existing pgcons is
    #     captured at the moment we change to Unhealthy state, and maintained
    #     on "expected disconnects" too).
    #   * (and) In Unhealthy state for more than UNHEALTHY_MIN_TIME seconds.
    #   * (and) sys_pgcon is down.
    #   * (or) Postgres shutdown/hot-standby notification received.
    #
    # Healthy -> Unhealthy
    #   * Any unexpected disconnect.
    #   * (or) Failed to connect due to ConnectionError (not implemented).
    #   * (or) Last active time is greater than 10 seconds (depends on the
    #     sys_pgcon idle-poll interval) (not implemented).
    #
    # Healthy -> Failover
    #   * Postgres shutdown/hot-standby notification received.
    #
    # Failover -> Healthy
    #   * Successfully connected to a non-hot-standby backend.
    #   * (and) sys_pgcon is healthy.
_state: State _unhealthy_timer_handle: Optional[asyncio.TimerHandle] def __init__(self, cluster_protocol: base.ClusterProtocol, tag: str): self._cluster_protocol = cluster_protocol self._state = State.UNHEALTHY self._pgcon_count = 0 self._unexpected_disconnects = 0 self._unhealthy_timer_handle = None self._sys_pgcon_healthy = False self._tag = tag def incr_metrics_counter(self, event: str, value: float = 1.0) -> None: metrics.ha_events_total.inc(value, f"adaptive://{self._tag}", event) def set_state_failover(self, *, call_on_switch_over=True): self._state = State.FAILOVER self._reset() if call_on_switch_over: logger.critical("adaptive: HA failover detected") self.incr_metrics_counter("failover") self._cluster_protocol.on_switch_over() def on_pgcon_broken(self, is_sys_pgcon: bool): if is_sys_pgcon: self._sys_pgcon_healthy = False if self._state == State.HEALTHY: self.incr_metrics_counter("unhealthy") self._state = State.UNHEALTHY self._unexpected_disconnects = 1 self._unhealthy_timer_handle = ( asyncio.get_running_loop().call_later( UNHEALTHY_MIN_TIME, self._maybe_failover ) ) self._pgcon_count = max( self._cluster_protocol.get_active_pgcon_num(), 0 ) + 1 logger.warning( "adaptive: Backend HA cluster is unhealthy. 
" "Captured number of pgcons: %d", self._pgcon_count, ) elif self._state == State.UNHEALTHY: self._unexpected_disconnects += 1 if self._unhealthy_timer_handle is None: self._maybe_failover() def on_pgcon_lost(self): if self._state == State.UNHEALTHY: self._pgcon_count = max(1, self._pgcon_count - 1) logger.debug( "on_pgcon_lost: decreasing captured pgcon count to: %d", self._pgcon_count, ) if self._unhealthy_timer_handle is None: self._maybe_failover() def on_pgcon_made(self, is_sys_pgcon: bool): if is_sys_pgcon: self._sys_pgcon_healthy = True if self._state == State.UNHEALTHY: self.incr_metrics_counter("healthy") self._state = State.HEALTHY logger.info("adaptive: Backend HA cluster is healthy") self._reset() elif self._state == State.FAILOVER: if self._sys_pgcon_healthy: self.incr_metrics_counter("healthy") self._state = State.HEALTHY logger.info( "adaptive: Backend HA cluster has recovered from failover" ) def _reset(self): self._pgcon_count = 0 self._unexpected_disconnects = 0 if self._unhealthy_timer_handle is not None: self._unhealthy_timer_handle.cancel() self._unhealthy_timer_handle = None def _maybe_failover(self): logger.debug( "_maybe_failover: unexpected disconnects: %d", self._unexpected_disconnects, ) self._unhealthy_timer_handle = None if ( self._unexpected_disconnects / self._pgcon_count >= UNEXPECTED_DISCONNECTS_THRESHOLD ) and not self._sys_pgcon_healthy: self.set_state_failover() ================================================ FILE: edb/server/ha/base.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2021-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Callable, Optional import urllib.parse from edb.common import asyncwatcher from edb.server import metrics class ClusterProtocol: def on_switch_over(self): pass def get_active_pgcon_num(self) -> int: raise NotImplementedError() class HABackend(asyncwatcher.AsyncWatcher): def __init__(self) -> None: super().__init__() self._failover_cb: Optional[Callable[[], None]] = None async def get_cluster_consensus(self) -> tuple[str, int]: raise NotImplementedError def get_master_addr(self) -> Optional[tuple[str, int]]: raise NotImplementedError def set_failover_callback(self, cb: Optional[Callable[[], None]]) -> None: self._failover_cb = cb @property def dsn(self) -> str: raise NotImplementedError def incr_metrics_counter(self, event: str, value: float = 1.0) -> None: metrics.ha_events_total.inc(value, self.dsn, event) def get_backend(parsed_dsn: urllib.parse.ParseResult) -> Optional[HABackend]: backend, _, sub_scheme = parsed_dsn.scheme.partition("+") if backend == "stolon": from . import stolon return stolon.get_backend(sub_scheme, parsed_dsn) return None ================================================ FILE: edb/server/ha/stolon.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2021-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any, Optional import asyncio import base64 import functools import json import logging import os import ssl import urllib.parse from edb.common import asyncwatcher from edb.common import token_bucket from edb.server import consul from . import base logger = logging.getLogger("edb.server") class StolonBackend(base.HABackend): _master_addr: Optional[tuple[str, int]] def __init__(self) -> None: super().__init__() self._master_addr = None async def get_cluster_consensus(self) -> tuple[str, int]: if self._master_addr is None: started_by_us = await self.start_watching() try: assert self._waiter is None self._waiter = asyncio.get_running_loop().create_future() await self._waiter finally: if started_by_us: self.stop_watching() await self.wait_stopped_watching() assert self._master_addr return self._master_addr def get_master_addr(self) -> Optional[tuple[str, int]]: return self._master_addr def _on_update(self, payload: bytes) -> None: try: data = json.loads(base64.b64decode(payload)) except (TypeError, ValueError): logger.exception("could not decode Stolon cluster data") return # Successful Consul response, reset retry backoff self._retry_attempt = 0 cluster_status = data.get("cluster", {}).get("status", {}) master_db = cluster_status.get("master") cluster_phase = cluster_status.get("phase") if cluster_phase != "normal": logger.debug("Stolon cluster phase: %r", cluster_phase) if not master_db: return master_status = ( data.get("dbs", {}).get(master_db, {}).get("status", {}) ) master_healthy =
master_status.get("healthy") if not master_healthy: logger.warning("Stolon reports unhealthy master Postgres.") return master_host = master_status.get("listenAddress") master_port = master_status.get("port") if not master_host or not master_port: return master_addr = master_host, int(master_port) if master_addr != self._master_addr: if self._master_addr is None: logger.info("Discovered master Postgres at %r", master_addr) self._master_addr = master_addr else: logger.critical( "Switching over the master Postgres from %r to %r", self._master_addr, master_addr, ) self._master_addr = master_addr if self._failover_cb is not None: self.incr_metrics_counter("failover") self._failover_cb() if self._waiter is not None: if not self._waiter.done(): self._waiter.set_result(None) self._waiter = None class StolonConsulBackend(StolonBackend): def __init__( self, cluster_name: str, *, host: str = "127.0.0.1", port: int = 8500, ssl: Optional[ssl.SSLContext] = None, ) -> None: super().__init__() self._cluster_name = cluster_name self._host = host self._port = port self._ssl = ssl # With the defaults below, we can issue up to 10 consecutive requests # immediately after each response without delay, and then we're capped # to 0.1 request (token) per second, or 1 request per 10 seconds.
cap = float(os.environ.get("EDGEDB_SERVER_CONSUL_TOKEN_CAPACITY", 10)) rate = float(os.environ.get("EDGEDB_SERVER_CONSUL_TOKEN_RATE", 0.1)) self._token_bucket = token_bucket.TokenBucket(cap, rate) async def _start_watching(self) -> asyncwatcher.AsyncWatcherProtocol: _, pr = await asyncio.get_running_loop().create_connection( functools.partial( consul.ConsulKVProtocol, self, self._host, f"stolon/cluster/{self._cluster_name}/clusterdata", ), self._host, self._port, ssl=self._ssl, ) return pr # type: ignore [return-value] @functools.cached_property def dsn(self) -> str: proto = "http" if self._ssl is None else "https" return ( f"stolon+consul+{proto}://" f"{self._host}:{self._port}/{self._cluster_name}" ) def consume_tokens(self, tokens: int) -> float: return self._token_bucket.consume(tokens) def get_backend( sub_scheme: str, parsed_dsn: urllib.parse.ParseResult ) -> StolonBackend: name = parsed_dsn.path.lstrip("/") if not name: raise ValueError("Stolon requires cluster name in the URI as path.") cls = None storage, _, wire_protocol = sub_scheme.partition("+") if storage == "consul": cls = StolonConsulBackend if not cls: raise ValueError(f"{parsed_dsn.scheme} is not supported") if wire_protocol not in {"", "http", "https"}: raise ValueError(f"Wire protocol {wire_protocol} is not supported") args: dict[str, Any] = {} if parsed_dsn.hostname: args["host"] = parsed_dsn.hostname if parsed_dsn.port: args["port"] = parsed_dsn.port if wire_protocol == "https": args["ssl"] = ssl.create_default_context() return cls(name, **args) ================================================ FILE: edb/server/http.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2024-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Mapping, Optional, Self, Callable, ) import asyncio import dataclasses import logging import json as json_lib import urllib.parse import time from http import HTTPStatus as HTTPStatus from edb.server._rust_native._http import Http from . import rust_async_channel logger = logging.getLogger("edb.server") HeaderType = Optional[list[tuple[str, str]] | dict[str, str]] @dataclasses.dataclass(frozen=True) class HttpStat: response_time_ms: int error_code: int response_body_size: int response_content_type: str request_body_size: int request_content_type: str method: str streaming: bool StatCallback = Callable[[HttpStat], None] class HttpClient: def __init__( self, limit: int, user_agent: str = "EdgeDB", stat_callback: Optional[StatCallback] = None, ): self._task = None self._client = None self._limit = limit self._skip_reads = 0 self._loop: Optional[asyncio.AbstractEventLoop] = ( asyncio.get_running_loop() ) self._streaming: dict[int, asyncio.Queue[Any]] = {} self._next_id = 0 self._requests: dict[int, asyncio.Future] = {} self._user_agent = user_agent self._stat_callback = stat_callback def __del__(self) -> None: if not self.closed(): logger.error(f"HttpClient {id(self)} was not closed") def close(self) -> None: if not self.closed(): if self._task is not None: self._task.cancel() self._task = None self._loop = None self._client = None def closed(self) -> bool: return self._task is None and self._loop is None def _ensure_task(self): if self.closed(): raise Exception("HttpClient was closed") if
self._task is None: self._client = Http(self._limit) self._task = self._loop.create_task(self._boot(self._client)) def _ensure_client(self): if self._client is None: raise Exception("HttpClient was closed") return self._client def _safe_close(self, id): if self._client is not None: self._client._close(id) def _safe_ack(self, id): if self._client is not None: self._client._ack_sse(id) def _update_limit(self, limit: int): if self._client is not None and limit != self._limit: self._limit = limit self._client._update_limit(limit) def _process_headers(self, headers: HeaderType) -> list[tuple[str, str]]: if headers is None: return [] if isinstance(headers, Mapping): return [(k, v) for k, v in headers.items()] if isinstance(headers, list): return headers raise ValueError(f"Invalid headers type: {type(headers)}") def _process_content( self, headers: list[tuple[str, str]], data: bytes | str | dict[str, str] | None = None, json: Any | None = None, ) -> bytes: if json is not None: data = json_lib.dumps(json).encode('utf-8') headers.append(('Content-Type', 'application/json')) elif isinstance(data, str): data = data.encode('utf-8') elif isinstance(data, dict): data = urllib.parse.urlencode(data).encode('utf-8') headers.append( ('Content-Type', 'application/x-www-form-urlencoded') ) elif data is None: data = bytes() elif isinstance(data, bytes): pass else: raise ValueError(f"Invalid content type: {type(data)}") return data def _process_path(self, path: str) -> str: return path def with_context( self, *, base_url: Optional[str] = None, headers: HeaderType = None, url_munger: Optional[Callable[[str], str]] = None, ) -> Self: """Create an HttpClient with common optional base URL and headers that will be applied to all requests.""" return HttpClientContext( http_client=self, base_url=base_url, headers=headers, url_munger=url_munger, ) # type: ignore async def request( self, *, method: str, path: str, headers: HeaderType = None, data: bytes | str | dict[str, str] | None = None, 
json: Any | None = None, cache: bool = False, ) -> tuple[int, bytearray, dict[str, str]]: self._ensure_task() path = self._process_path(path) headers_list = self._process_headers(headers) headers_list.append(("User-Agent", self._user_agent)) data = self._process_content(headers_list, data, json) id = self._next_id self._next_id += 1 self._requests[id] = asyncio.Future() start_time = time.monotonic() try: self._ensure_client()._request( id, path, method, data, headers_list, cache ) resp = await self._requests[id] if self._stat_callback: status_code, body, headers = resp self._stat_callback( HttpStat( response_time_ms=int( (time.monotonic() - start_time) * 1000 ), error_code=status_code, response_body_size=len(body), response_content_type=dict(headers_list).get( 'content-type', '' ), request_body_size=len(data), request_content_type=dict(headers_list).get( 'content-type', '' ), method=method, streaming=False, ) ) return resp finally: del self._requests[id] async def get( self, path: str, *, headers: HeaderType = None, cache: bool = False, ) -> Response: result = await self.request( method="GET", path=path, data=None, headers=headers, cache=cache ) return Response.from_tuple(result) async def post( self, path: str, *, headers: HeaderType = None, data: bytes | str | dict[str, str] | None = None, json: Any | None = None, ) -> Response: result = await self.request( method="POST", path=path, data=data, json=json, headers=headers ) return Response.from_tuple(result) async def stream_sse( self, path: str, *, method: str = "POST", headers: HeaderType = None, data: bytes | str | dict[str, str] | None = None, json: Any | None = None, ) -> Response | ResponseSSE: self._ensure_task() path = self._process_path(path) headers_list = self._process_headers(headers) headers_list.append(("User-Agent", self._user_agent)) data = self._process_content(headers_list, data, json) id = self._next_id self._next_id += 1 self._requests[id] = asyncio.Future() start_time = time.monotonic() try: 
self._ensure_client()._request_sse( id, path, method, data, headers_list ) resp = await self._requests[id] if self._stat_callback: if id in self._streaming: status_code, _ = resp body = b'' else: assert len(resp) >= 2 status_code, body = resp[0:2] self._stat_callback( HttpStat( response_time_ms=int( (time.monotonic() - start_time) * 1000 ), error_code=status_code, response_body_size=len(body), response_content_type=dict(headers_list).get( 'content-type', '' ), request_body_size=len(data), request_content_type=dict(headers_list).get( 'content-type', '' ), method=method, streaming=id in self._streaming, ) ) if id in self._streaming: # Valid to call multiple times cancel = lambda: self._safe_close(id) # Acknowledge SSE message (for backpressure) ack = lambda: self._safe_ack(id) return ResponseSSE.from_tuple( resp, self._streaming[id], cancel, ack ) return Response.from_tuple(resp) finally: del self._requests[id] async def _boot(self, client) -> None: logger.info(f"HTTP client initialized, user_agent={self._user_agent}") try: channel = rust_async_channel.RustAsyncChannel( client._channel, self._process_message ) try: await channel.run() finally: channel.close() except Exception: logger.error("Error in HTTP client", exc_info=True) raise def _process_message(self, msg: tuple[Any, ...]) -> None: try: msg_type, id, data = msg if msg_type == 0: # Error if id in self._requests: self._requests[id].set_exception(Exception(data)) if id in self._streaming: self._streaming[id].put_nowait(None) del self._streaming[id] elif msg_type == 1: # Response if id in self._requests: self._requests[id].set_result(data) elif msg_type == 2: # SSEStart if id in self._requests: self._streaming[id] = asyncio.Queue() self._requests[id].set_result(data) elif msg_type == 3: # SSEEvent if id in self._streaming: self._streaming[id].put_nowait(data) elif msg_type == 4: # SSEEnd if id in self._streaming: self._streaming[id].put_nowait(None) del self._streaming[id] except Exception as e:
logger.error(f"Error processing message: {e}", exc_info=True) raise async def __aenter__(self) -> Self: return self async def __aexit__(self, *args) -> None: # type: ignore self.close() class HttpClientContext(HttpClient): def __init__( self, http_client: HttpClient, url_munger: Callable[[str], str] | None = None, headers: HeaderType = None, base_url: str | None = None, ): self.url_munger = url_munger self.http_client = http_client self.base_url = base_url self.headers = super()._process_headers(headers) # HttpClientContext does not need to be closed def __del__(self): pass def closed(self): return super().closed() def close(self): pass async def __aenter__(self) -> Self: return self async def __aexit__(self, *args) -> None: # type: ignore pass def _process_headers(self, headers): headers = super()._process_headers(headers) headers += self.headers return headers def _process_path(self, path): path = super()._process_path(path) if self.base_url is not None: path = self.base_url + path if self.url_munger is not None: path = self.url_munger(path) return path async def request( self, *, method, path, headers=None, data=None, json=None, cache=False, ): path = self._process_path(path) headers = self._process_headers(headers) return await self.http_client.request( method=method, path=path, headers=headers, data=data, json=json, cache=cache, ) async def stream_sse( self, path, *, method="POST", headers=None, data=None, json=None ): path = self._process_path(path) headers = self._process_headers(headers) return await self.http_client.stream_sse( path, method=method, headers=headers, data=data, json=json ) class CaseInsensitiveDict(dict): def __init__(self, data: Optional[list[tuple[str, str]]] = None): super().__init__() if data: for k, v in data: self[k.lower()] = v def __setitem__(self, key: str, value: str): super().__setitem__(key.lower(), value) def __getitem__(self, key: str): return super().__getitem__(key.lower()) def get(self, key: str, default=None): return 
super().get(key.lower(), default) def update(self, *args, **kwargs: str) -> None: if args: data = args[0] if isinstance(data, Mapping): for key, value in data.items(): self[key] = value else: for key, value in data: self[key] = value for key, value in kwargs.items(): self[key] = value @dataclasses.dataclass(frozen=True) class Response: status_code: int body: bytearray headers: CaseInsensitiveDict is_streaming: bool = False @classmethod def from_tuple(cls, data: tuple[int, bytearray, dict[str, str]]): status_code, body, headers_list = data headers = CaseInsensitiveDict([(k, v) for k, v in headers_list.items()]) return cls(status_code, body, headers) def json(self): return json_lib.loads(self.body.decode('utf-8')) def bytes(self): return bytes(self.body) @property def text(self) -> str: return self.body.decode('utf-8') async def __aenter__(self) -> Self: return self async def __aexit__(self, exc_type, exc_value, traceback): pass @dataclasses.dataclass(frozen=True) class ResponseSSE: status_code: int headers: CaseInsensitiveDict _stream: asyncio.Queue = dataclasses.field(repr=False) _cancel: Callable[[], None] = dataclasses.field(repr=False) _ack: Callable[[], None] = dataclasses.field(repr=False) _closed: list[bool] = dataclasses.field(default_factory=lambda: [False]) is_streaming: bool = True @classmethod def from_tuple( cls, data: tuple[int, dict[str, str]], stream: asyncio.Queue, cancel: Callable[[], None], ack: Callable[[], None], ): status_code, headers = data headers = CaseInsensitiveDict([(k, v) for k, v in headers.items()]) return cls(status_code, headers, stream, cancel, ack) @dataclasses.dataclass(frozen=True) class SSEEvent: event: str data: str id: Optional[str] = None def json(self): return json_lib.loads(self.data) def close(self): if not self.closed(): self._closed[0] = True self._cancel() def closed(self) -> bool: return self._closed[0] def __del__(self): if not self.closed(): logger.error(f"ResponseSSE {id(self)} was not closed") def __aiter__(self): 
return self async def __anext__(self): if self.closed(): raise StopAsyncIteration next = await self._stream.get() try: if next is None: self.close() raise StopAsyncIteration id, data, event = next return self.SSEEvent(event, data, id) finally: self._ack() async def __aenter__(self) -> Self: return self async def __aexit__(self, exc_type, exc_value, traceback): self.close() ================================================ FILE: edb/server/inplace_upgrade.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import ( Any, Iterator, Optional, Sequence, ) import dataclasses import json import logging import os from edb import edgeql from edb import errors from edb.edgeql import ast as qlast from edb.common import debug from edb.common import uuidgen from edb.schema import ddl as s_ddl from edb.schema import delta as sd from edb.schema import extensions as s_exts from edb.schema import functions as s_func from edb.schema import links as s_links from edb.schema import name as sn from edb.schema import objtypes as s_objtypes from edb.schema import reflection as s_refl from edb.schema import scalars as s_scalars from edb.schema import schema as s_schema from edb.schema import types as s_types from edb.schema import version as s_ver from edb.server import args as edbargs from edb.server import bootstrap from edb.server import config from edb.server import compiler as edbcompiler from edb.server.compiler import ddl as compiler_ddl from edb.server import defines as edbdef from edb.server import instdata from edb.server import pgcluster from edb.server import pgcon from edb.pgsql import common as pg_common from edb.pgsql import dbops from edb.pgsql import patches as pg_patches from edb.pgsql import metaschema from edb.pgsql import trampoline logger = logging.getLogger('edb.server') PGCon = bootstrap.PGConnectionProxy | pgcon.PGConnection def _is_stdlib_target( t: s_objtypes.ObjectType, schema: s_schema.Schema, ) -> bool: if intersection := t.get_intersection_of(schema): return any((_is_stdlib_target(it, schema) for it in intersection.objects(schema))) elif union := t.get_union_of(schema): return any((_is_stdlib_target(ut, schema) for ut in union.objects(schema))) name = t.get_name(schema) if name == sn.QualName('std', 'Object'): return False return t.get_name(schema).get_module_name() in s_schema.STD_MODULES def _compile_inheritance_schema_fixup( ctx: bootstrap.BootstrapContext, current_block: dbops.PLBlock, schema: 
s_schema.ChainedSchema, keys: dict[str, Any], ) -> None: """Compile schema-specific fixups for added stdlib types.""" backend_params = ctx.cluster.get_runtime_params() # Recompile functions that reference stdlib types (like # std::BaseObject or schema::Object), since new subtypes may have # been added. to_recompile = schema._top_schema.get_objects(type=s_func.Function) for func in to_recompile: if func.get_name(schema).get_root_module_name() == s_schema.EXT_MODULE: continue # If none of the types referenced in the function are standard # library types, we don't need to recompile. if not ( (expr := func.get_nativecode(schema)) and expr.refs and any( isinstance(dep, s_objtypes.ObjectType) and _is_stdlib_target(dep, schema) for dep in expr.refs.objects(schema) ) ): continue alter_func = func.init_delta_command( schema, sd.AlterObject ) alter_func.set_attribute_value( 'nativecode', func.get_nativecode(schema) ) alter_func.canonical = True # N.B: We are ignoring the schema changes, since we aren't # updating the schema version. _, plan, _ = bootstrap._process_delta_params( sd.DeltaRoot.from_commands(alter_func), schema, backend_params, stdmode=False, **keys, ) plan.generate(current_block) # Regenerate on_target_delete triggers for any links targeting a # stdlib type. links = schema._top_schema.get_objects(type=s_links.Link) for link in links: if link.get_name(schema).get_root_module_name() == s_schema.EXT_MODULE: continue source = link.get_source(schema) if ( not source or not source.is_material_object_type(schema) or link.get_computable(schema) or link.get_shortname(schema).name == '__type__' or not _is_stdlib_target(link.get_target(schema), schema) ): continue pol = link.get_on_target_delete(schema) # HACK: Set the policy in a temporary in-memory schema to be # something else, so that we can set it back to the real value # and pgdelta will generate code for it. 
fake_pol = ( s_links.LinkTargetDeleteAction.Allow if pol == s_links.LinkTargetDeleteAction.Restrict else s_links.LinkTargetDeleteAction.Restrict ) fake_schema = link.set_field_value(schema, 'on_target_delete', fake_pol) alter_delta, alter_link, _ = link.init_delta_branch( schema, sd.CommandContext(), sd.AlterObject ) alter_link.set_attribute_value('on_target_delete', pol) # N.B: We are ignoring the schema changes, since we aren't # updating the schema version. _, plan, _ = bootstrap._process_delta_params( sd.DeltaRoot.from_commands(alter_delta), fake_schema, backend_params, stdmode=False, **keys, ) plan.generate(current_block) def _compile_schema_fixup( ctx: bootstrap.BootstrapContext, schema: s_schema.ChainedSchema, keys: dict[str, Any], ) -> dbops.PLBlock: """Compile any schema-specific fixes that need to be applied.""" current_block = dbops.PLTopBlock() _compile_inheritance_schema_fixup(ctx, current_block, schema, keys) # Remove pointless triggers that existed before 6.8 compiler_ddl.remove_pointless_triggers(schema).generate(current_block) return current_block async def _collect_6x_upgrade_patches( ctx: bootstrap.BootstrapContext, schema: s_schema.Schema, ) -> tuple[list[qlast.Command], bool, bool]: from edb.pgsql import patches_6x cmds: list[qlast.Command] = [] # If that table doesn't exist, we aren't upgrading from 6.x, so # don't worry. (Which means, at this point, a 7.x -> dev/nightly # upgrade.) try: res = await ctx.conn.sql_fetch_val( f""" SELECT json::json FROM edgedbinstdata_v6_2f20b3fed0.instdata WHERE key = 'num_patches' """.encode('utf-8'), ) except pgcon.BackendError: return [], False, False needs_config = False jnum = json.loads(res) for kind, patch in patches_6x.PATCHES[jnum:]: if not kind.startswith('edgeql+user_ext'): continue # Only run a userext update if the extension we are trying to # update is installed. 
extension_name = kind.split('|')[-1] extension = schema.get_global( s_exts.Extension, extension_name, default=None) if not extension: continue if '+config' in kind: needs_config = True for ddl_cmd in edgeql.parse_block(patch): if not isinstance(ddl_cmd, qlast.DDLCommand): assert isinstance(ddl_cmd, qlast.Query) ddl_cmd = qlast.DDLQuery(query=ddl_cmd) cmds.append(ddl_cmd) # 6.2 introduced a change to EmbeddingModel (the addition of a new # default value) that requires a repair to bring the user schema back # in sync, since ai index annotations get copied into objects. needs_repair = bool( schema.get_global(s_exts.Extension, 'ai', default=None) ) return cmds, needs_repair, needs_config def _subcommands_preorder(cmd: sd.Command) -> Iterator[sd.Command]: yield cmd for sub in cmd.get_subcommands(): yield from _subcommands_preorder(sub) async def _apply_ddl_schema_storage( ddl_cmd: qlast.Command, ctx: bootstrap.BootstrapContext, backend_params, keys: dict[str, Any], compilerctx: edbcompiler.CompileContext, schema: s_schema.Schema, schema_object_ids: dict[tuple[sn.Name, Any], uuidgen.UUID], fake_backend_ids: bool = False, ) -> tuple[s_schema.ChainedSchema, str]: # Applies only the schema-storage part of the DDL; the real DDL is # not executed here, but it is returned as a fixup script. current_block = dbops.PLTopBlock() if debug.flags.sdl_loading: ddl_cmd.dump_edgeql() assert isinstance(ddl_cmd, qlast.DDLCommand) delta_command = s_ddl.delta_from_ddl( ddl_cmd, modaliases={}, schema=schema, schema_object_ids=schema_object_ids, **keys, ) # Prune any AlterSchemaVersion commands, because they won't work, # since all the compile_schema_storage_in_delta commands run right # away, while if we run a fixup block, it is during finalize.
sub: sd.Command for sub in delta_command.get_subcommands( type=s_ver.AlterSchemaVersion ): delta_command.discard(sub) # This hack is quite frustrating: since the actual changes to the # pg schema (for extensions) happen *after* the reflection schema # updates (during finalization), backend_ids for new scalars # aren't ready, so we force reflection to *not* try computing them # now, and we'll get them later. if fake_backend_ids: for sub in _subcommands_preorder(delta_command): if not isinstance(sub, sd.CreateObject): continue mcls = sub.get_schema_metaclass() if ( issubclass(mcls, (s_scalars.ScalarType, s_types.Collection)) and not issubclass(mcls, s_types.CollectionExprAlias) ): sub.set_attribute_value('backend_id', 0) schema, plan, tplan = bootstrap._process_delta_params( delta_command, schema, backend_params, stdmode=False, **keys, ) compilerctx.state.current_tx().update_schema(schema) fixup_block = dbops.PLTopBlock() plan.generate(fixup_block) # # ??? # tplan.generate(current_block) fixup_ddl = fixup_block.to_string() context = sd.CommandContext(**keys) edbcompiler.compile_schema_storage_in_delta( ctx=compilerctx, delta=plan, block=current_block, context=context, ) # TODO: Should we batch them all up? 
patch = current_block.to_string() if debug.flags.delta_execute: debug.header('Patch Script') debug.dump_code(patch, lexer='sql') await ctx.conn.sql_execute(patch.encode('utf-8')) assert isinstance(schema, s_schema.ChainedSchema) return schema, fixup_ddl async def _upgrade_one( ctx: bootstrap.BootstrapContext, state: edbcompiler.CompilerState, global_schema: s_schema.Schema, std_global_schema: s_schema.Schema, upgrade_data: Optional[Any], ) -> None: if not upgrade_data: return backend_params = ctx.cluster.get_runtime_params() assert backend_params.has_create_database ddl = upgrade_data['ddl'] # ids: schema_object_ids = { ( sn.name_from_string(name), qltype if qltype else None ): uuidgen.UUID(objid) for name, qltype, objid in upgrade_data['ids'] } # Load the schemas schema = s_schema.ChainedSchema( state.std_schema, s_schema.EMPTY_SCHEMA, global_schema, ) compilerctx = edbcompiler.new_compiler_context( compiler_state=state, user_schema=schema.get_top_schema(), bootstrap_mode=False, # MAYBE? ) keys: dict[str, Any] = dict( testmode=True, ) # Apply the DDL, but usually *only* execute the schema storage part!! # For the actual core DDL, only do schema storage. # Sometimes we have to run extension patch code, and then we # need to run the actual code. 
for ddl_cmd in edgeql.parse_block(ddl): schema, _ = await _apply_ddl_schema_storage( ddl_cmd, ctx, backend_params, keys, compilerctx, schema, schema_object_ids ) schema_fixup = '' upgrade_patches, needs_repair, needs_config = ( await _collect_6x_upgrade_patches(ctx, schema) ) for ddl_cmd in upgrade_patches: schema, fixup = await _apply_ddl_schema_storage( ddl_cmd, ctx, backend_params, keys, compilerctx, schema, # Empty schema_object_ids because these aren't actually user # objects anymore, and because reusing the # schema_object_ids led to a subtle issue: # when altering a computed Global in a patch, the underlying # ObjectType got deleted and replaced with a new one with an # identical id, which caused the link policy to spuriously # fail! schema_object_ids={}, fake_backend_ids=True, ) schema_fixup += fixup if upgrade_patches: version_key = pg_patches.get_version_key(len(pg_patches.PATCHES)) sysqueries = json.loads(await instdata.get_instdata( ctx.conn, f'sysqueries{version_key}', 'json')) schema_fixup += sysqueries['backend_id_fixup'] if needs_config: existing_view_columns = await bootstrap.get_existing_view_columns( ctx.conn) cfg_block = dbops.PLTopBlock() metaschema.get_config_views(schema, existing_view_columns).generate( cfg_block) schema_fixup += cfg_block.to_string() # If we need to do a schema repair... do it if needs_repair: # We want to do the repair against the *new* global schema # (which is std_global_schema), but there might be non-std # extensions installed, so we chain it with the original # global_schema to get it to work. 
confused_global_schema = s_schema.ChainedSchema( global_schema, std_global_schema, s_schema.EMPTY_SCHEMA, ) repair = bootstrap.prepare_repair_patch( state.std_schema, state.refl_schema, schema.get_top_schema(), confused_global_schema, state.schema_class_layout, backend_params, ) schema_fixup += repair # Refresh the pg_catalog materialized views current_block = dbops.PLTopBlock() refresh = metaschema.generate_sql_information_schema_refresh( backend_params.instance_params.version ) refresh.generate(current_block) patch = current_block.to_string() await ctx.conn.sql_execute(patch.encode('utf-8')) new_local_spec = config.load_spec_from_schema( schema, only_exts=True, # suppress validation because we might be in an intermediate state validate=False, ) spec_json = config.spec_to_json(new_local_spec) await ctx.conn.sql_execute(trampoline.fixup_query(f'''\ UPDATE edgedbinstdata_VER.instdata SET json = {pg_common.quote_literal(spec_json)} WHERE key = 'configspec_ext'; ''').encode('utf-8')) # Compile the fixup script for the schema and stash it away schema_fixup += _compile_schema_fixup(ctx, schema, keys).to_string() await bootstrap._store_static_text_cache( ctx, f'schema_fixup_query', schema_fixup, ) DEP_CHECK_QUERY = r''' with -- Fetch all the object types we care about. all_objs AS ( select objs.oid, ns.nspname as nspname, objs.name, objs.typ from ( select oid as oid, relname as name, (case when relkind = 'v' then 'view' else 'table' end) as typ, relnamespace as namespace from pg_catalog.pg_class union all select oid as oid, typname as name, 'type' as typ, typnamespace as namespace from pg_catalog.pg_type union all select oid as oid, proname as name, 'function' as typ, pronamespace as namespace from pg_catalog.pg_proc ) as objs inner join pg_catalog.pg_namespace ns on objs.namespace = ns.oid ), -- Fetch pg_depend along with some special handling of internal deps. 
cdeps AS ( select dep.objid, dep.refobjid, dep.deptype from pg_catalog.pg_depend dep union -- if there is an incoming 'i' dep to an obj A from B, treat all -- other outgoing deps from B as outgoing from A. We do this because -- the actual query in a view is stored in a pg_rewrite that *depends on* -- the view. (Seems backward.) select i.refobjid, c.refobjid, c.deptype from pg_catalog.pg_depend i inner join pg_catalog.pg_depend c on i.objid = c.objid where i.refobjid != c.refobjid and i.deptype = 'i' ) -- Get any dependencies from outside our namespaces into them. select src.typ, src.nspname, src.name, tgt.typ, tgt.nspname, tgt.name from all_objs src inner join cdeps dep on src.oid = dep.objid inner join all_objs tgt on tgt.oid = dep.refobjid where true and NOT src.nspname = ANY ({namespaces}) and tgt.nspname = ANY ({namespaces}) and dep.deptype != 'i' ''' async def _delete_schemas( conn: PGCon, to_delete: Sequence[str] ) -> None: # To add a bit more safety, check whether there are any # dependencies on the modules we want to delete from outside those # modules since the only way to delete non-empty schemas in # postgres is CASCADE. namespaces = ( f'ARRAY[{", ".join(pg_common.quote_literal(k) for k in to_delete)}]' ) qry = DEP_CHECK_QUERY.format(namespaces=namespaces) existing_deps = await conn.sql_fetch(qry.encode('utf-8')) if existing_deps: # All of the fields are text, so decode them all sdeps = [ tuple(x.decode('utf-8') for x in row) for row in existing_deps ] messages = [ f'{st} {pg_common.qname(ss, sn)} depends on ' f'{tt} {pg_common.qname(ts, tn)}\n' for st, ss, sn, tt, ts, tn in sdeps ] raise AssertionError( 'Dependencies to old schemas still exist: \n%s' % ''.join(messages) ) # It is *really* dumb the way that CASCADE works in postgres. 
await conn.sql_execute(f""" drop schema {', '.join(to_delete)} cascade """.encode('utf-8')) async def _get_namespaces( conn: PGCon, ) -> list[str]: return json.loads(await conn.sql_fetch_val(""" select json_agg(nspname) from pg_namespace where nspname like 'edgedb%\\_v%' """.encode('utf-8'))) async def _finalize_one( ctx: bootstrap.BootstrapContext, database: str, ) -> None: conn = ctx.conn # If the upgrade is already finalized, skip it. This lets us be # resilient to crashes during the finalization process, which may # leave some databases upgraded but not all. if (await instdata.get_instdata(conn, 'upgrade_finalized', 'text')) == b'1': logger.info("Database upgrade already finalized") return trampoline_query = await instdata.get_instdata( conn, 'trampoline_pivot_query', 'text') fixup_query = await instdata.get_instdata( conn, 'schema_fixup_query', 'text') await conn.sql_execute(trampoline_query) if fixup_query: await conn.sql_execute(fixup_query) # For the template database (which is upgraded *last*, after all # others have succeeded), run the commands to update the global # schema. (To populate the extension packages.) if database == edbdef.EDGEDB_TEMPLATE_DB: global_schema_update_query = await instdata.get_instdata( conn, 'global_schema_update_query', 'text') await conn.sql_execute(global_schema_update_query) namespaces = await _get_namespaces(ctx.conn) cur_suffix = pg_common.versioned_schema("") to_delete = [x for x in namespaces if not x.endswith(cur_suffix)] await _delete_schemas(conn, to_delete) await bootstrap._store_static_text_cache( ctx, 'upgrade_finalized', '1', ) async def _get_databases( ctx: bootstrap.BootstrapContext, ) -> list[str]: cluster = ctx.cluster tpl_db = cluster.get_db_name(edbdef.EDGEDB_TEMPLATE_DB) conn = await cluster.connect( source_description="inplace upgrade", database=tpl_db ) # FIXME: Use the sys query instead? 
try: databases = json.loads(await conn.sql_fetch_val( trampoline.fixup_query(""" SELECT json_agg(name) FROM edgedb_VER."_SysBranch"; """).encode('utf-8'), )) finally: conn.terminate() # DEBUG VELOCITY HACK: You can add a failing database to EARLY # when trying to upgrade the whole suite. # # Note: We put template last, since when deleting, we need it to # stay around so we can query all branches. EARLY: tuple[str, ...] = () databases.sort( key=lambda k: (k == edbdef.EDGEDB_TEMPLATE_DB, k not in EARLY, k) ) return databases async def _get_global_schema( ctx: bootstrap.BootstrapContext, state: edbcompiler.CompilerState, ) -> s_schema.Schema: cluster = ctx.cluster tpl_db = cluster.get_db_name(edbdef.EDGEDB_TEMPLATE_DB) conn = await cluster.connect( source_description="inplace upgrade", database=tpl_db ) # FIXME: Use the sys query instead? assert state.global_intro_query try: json_data = await conn.sql_fetch_val( state.global_intro_query.encode('utf-8')) finally: conn.terminate() return s_refl.parse_schema( base_schema=state.std_schema, data=json_data, schema_class_layout=state.schema_class_layout, ) async def _rollback_one( ctx: bootstrap.BootstrapContext, ) -> None: conn = ctx.conn namespaces = await _get_namespaces(conn) if pg_common.versioned_schema("edgedb") not in namespaces: logger.info(f"Database already rolled back or not prepared; skipping") return if (await instdata.get_instdata(conn, 'upgrade_finalized', 'text')) == b'1': logger.info(f"Database upgrade already finalized") raise errors.ConfigurationError( f"attempting to rollback database that has already begun " f"finalization: retry finalize instead" ) cur_suffix = pg_common.versioned_schema("") to_delete = [x for x in namespaces if x.endswith(cur_suffix)] await _delete_schemas(conn, to_delete) async def _rollback_all( ctx: bootstrap.BootstrapContext, ) -> None: cluster = ctx.cluster databases = await _get_databases(ctx) for database in databases: if database == os.environ.get( 
'EDGEDB_UPGRADE_ROLLBACK_ERROR_INJECTION' ): raise AssertionError(f'failure injected on {database}') conn = bootstrap.PGConnectionProxy( cluster, source_description='inplace upgrade: rollback all', dbname=cluster.get_db_name(database), ) try: subctx = dataclasses.replace(ctx, conn=conn) logger.info(f"Rolling back preparation of database '{database}'") await _rollback_one(ctx=subctx) finally: conn.terminate() async def _upgrade_all( ctx: bootstrap.BootstrapContext, ) -> None: cluster = ctx.cluster stdlib, compiler = (await bootstrap._bootstrap(ctx)) state = compiler.state databases = await _get_databases(ctx) std_global_schema = stdlib.global_schema global_schema = await _get_global_schema(ctx, state) assert ctx.args.inplace_upgrade_prepare with open(ctx.args.inplace_upgrade_prepare) as f: upgrade_data = json.load(f) for database in databases: if database == edbdef.EDGEDB_TEMPLATE_DB: continue conn = bootstrap.PGConnectionProxy( cluster, source_description="inplace upgrade: upgrade all", dbname=cluster.get_db_name(database) ) try: subctx = dataclasses.replace(ctx, conn=conn) logger.info(f"Upgrading database '{database}'") await bootstrap._bootstrap(ctx=subctx, no_template=True) logger.info(f"Populating schema tables for '{database}'") await _upgrade_one( ctx=subctx, state=state, global_schema=global_schema, std_global_schema=std_global_schema, upgrade_data=upgrade_data.get(database), ) finally: conn.terminate() async def _finalize_all( ctx: bootstrap.BootstrapContext, ) -> None: cluster = ctx.cluster databases = await _get_databases(ctx) async def go( message: str, finish_message: Optional[str], final_command: bytes, inject_failure_on: Optional[str]=None, ) -> None: for database in databases: conn = await cluster.connect( source_description="inplace upgrade: finish", database=cluster.get_db_name(database) ) tmp_table_query = ( pgcon.SETUP_TEMP_TABLE_SCRIPT + pgcon.SETUP_DML_DUMMY_TABLE_SCRIPT ) await conn.sql_execute(tmp_table_query.encode('utf-8')) try: subctx = 
dataclasses.replace(ctx, conn=conn) logger.info(f"{message} database '{database}'") await conn.sql_execute(b'START TRANSACTION') # DEBUG HOOK: Inject a failure if specified if database == inject_failure_on: raise AssertionError(f'failure injected on {database}') await _finalize_one(subctx, database) await conn.sql_execute(final_command) if finish_message: logger.info(f"{finish_message} database '{database}'") finally: conn.terminate() inject_failure = os.environ.get('EDGEDB_UPGRADE_FINALIZE_ERROR_INJECTION') # Test all of the pivots in transactions we rollback, to make sure # that they work. This ensures that if there is a bug in the pivot # scripts on some database, we fail before any irreversible # changes are made to any database. # # *Then*, apply them all for real. They may fail # when applying for real, but that should be due to a crash or # some such, and so the user should be able to retry. # # We wanted to apply them all inside transactions and then commit # the transactions, but that requires holding open potentially too # many connections. await go("Testing pivot of", None, b'ROLLBACK') await go("Pivoting", "Finished pivoting", b'COMMIT', inject_failure) async def inplace_upgrade( cluster: pgcluster.BaseCluster, args: edbargs.ServerConfig, ) -> None: """Perform some or all of the inplace upgrade operations""" pgconn = bootstrap.PGConnectionProxy( cluster, source_description="inplace_upgrade" ) ctx = bootstrap.BootstrapContext(cluster=cluster, conn=pgconn, args=args) try: # XXX: Do we need to do this? 
mode = await bootstrap._get_cluster_mode(ctx) ctx = dataclasses.replace(ctx, mode=mode) if args.inplace_upgrade_rollback: await _rollback_all(ctx) if args.inplace_upgrade_prepare: await _upgrade_all(ctx) if args.inplace_upgrade_finalize: await _finalize_all(ctx) finally: pgconn.terminate() ================================================ FILE: edb/server/instdata.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, TYPE_CHECKING, ) from edb.pgsql import common as pg_common if TYPE_CHECKING: from edb.pgsql import metaschema async def get_instdata( backend_conn: metaschema.PGConnection, key: str, field: str, versioned: bool = True, ) -> bytes | Any: if field == 'json': field = 'json::json' if versioned: schema = pg_common.versioned_schema('edgedbinstdata') else: schema = 'edgedbinstdata' return await backend_conn.sql_fetch_val( f""" SELECT {field} FROM {schema}.instdata WHERE key = $1 """.encode('utf-8'), args=[key.encode("utf-8")], ) ================================================ FILE: edb/server/logsetup.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2011-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import contextlib import copy import io import logging import logging.handlers import sys import warnings from edb.common import debug from edb.common import term LOG_LEVELS = { 'S': 'SILENT', 'D': 'DEBUG', 'I': 'INFO', 'E': 'ERROR', 'W': 'WARN', 'WARN': 'WARN', 'ERROR': 'ERROR', 'CRITICAL': 'CRITICAL', 'INFO': 'INFO', 'DEBUG': 'DEBUG', 'SILENT': 'SILENT' } class Dark16: critical = term.Style16(color='white', bgcolor='red', bold=True) error = term.Style16(color='white', bgcolor='red') default = term.Style16(color='white', bgcolor='blue') pid = date = term.Style16(color='black', bold=True) name = term.Style16(color='black', bold=True) message = term.Style16() class Dark256: critical = term.Style256(color='#c6c6c6', bgcolor='#870000', bold=True) error = term.Style256(color='#c6c6c6', bgcolor='#870000') warning = term.Style256(color='#c6c6c6', bgcolor='#5f00d7') info = term.Style256(color='#c6c6c6', bgcolor='#005f00') default = term.Style256(color='#c6c6c6', bgcolor='#000087') pid = date = term.Style256(color='#626262', bold=True) name = term.Style256(color='#A2A2A2') message = term.Style16() class EdgeDBLogFormatter(logging.Formatter): default_time_format = '%Y-%m-%dT%H:%M:%S' default_msec_format = '%s.%03d' def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.__styles = None self._colorize = term.use_colors() if self._colorize: self._init_styles() def _init_styles(self): if 
not self.__styles: if term.max_colors() >= 255: self.__styles = Dark256() else: self.__styles = Dark16() def formatTime(self, record, datefmt=None): time = super().formatTime(record, datefmt=datefmt) if self._colorize: time = self.__styles.date.apply(time) return time def formatException(self, ei): sio = io.StringIO() with contextlib.redirect_stdout(sio): sys.excepthook(*ei) s = sio.getvalue() sio.close() if s[-1:] == "\n": s = s[:-1] return s def format(self, record): if self._colorize: record = copy.copy(record) level = record.levelname level_style = getattr(self.__styles, level.lower(), self.__styles.default) record.levelname = level_style.apply(level) record.process = self.__styles.pid.apply(str(record.process)) record.message = self.__styles.message.apply(record.getMessage()) record.name = self.__styles.name.apply(record.name) return super().format(record) class EdgeDBLogHandler(logging.StreamHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) fmt = EdgeDBLogFormatter( '{levelname} {process} {tenant} {asctime} {name}: {message}', style='{') self.setFormatter(fmt) IGNORE_DEPRECATIONS_IN = { 'graphql', 'promise', } def setup_logging(log_level, log_destination): log_level = log_level.upper() try: log_level = LOG_LEVELS[log_level] except KeyError: raise RuntimeError('Invalid logging level {!r}'.format(log_level)) if log_level == 'SILENT': logger = logging.getLogger() logger.disabled = True logger.setLevel(logging.CRITICAL) return if log_destination == 'syslog': fmt = logging.Formatter( '{processName}[{process}]: {tenant}: {name}: {message}', style='{') handler = logging.handlers.SysLogHandler( '/dev/log', facility=logging.handlers.SysLogHandler.LOG_DAEMON) handler.setFormatter(fmt) elif log_destination == 'stderr': handler = EdgeDBLogHandler() else: fmt = logging.Formatter( '{levelname} {process} {tenant} {asctime} {name}: {message}', style='{') handler = logging.FileHandler(log_destination) handler.setFormatter(fmt) log_level = 
logging._checkLevel(log_level) logger = logging.getLogger() logger.setLevel(log_level) logger.addHandler(handler) # Channel warnings into logging system logging.captureWarnings(True) # Show DeprecationWarnings by default ... warnings.simplefilter('default', category=DeprecationWarning) # ... except for some third-party modules. for ignored_module in IGNORE_DEPRECATIONS_IN: warnings.filterwarnings('ignore', category=DeprecationWarning, module=ignored_module) if not debug.flags.log_metrics: log_metrics = logging.getLogger('edb.server.metrics') log_metrics.setLevel(logging.ERROR) ================================================ FILE: edb/server/main.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import ( Any, Optional, Iterator, Mapping, NoReturn, TYPE_CHECKING, ) from edb.common.log import early_setup # ruff: noqa: E402 early_setup() import asyncio import contextlib import json import logging import os import os.path import pathlib import immutables import resource import signal import sys import tempfile import click import setproctitle import uvloop from edb import buildmeta from edb import errors from edb.ir import statypes from edb.common import exceptions from edb.common import devmode from edb.common import signalctl from edb.common import debug from . import config from . import args as srvargs from . import compiler as edbcompiler from . import daemon from . import defines from . import logsetup from . import pgcluster from . import service_manager if TYPE_CHECKING: from . import server from edb.server import bootstrap else: # Import server lazily to make sure that most of imports happen # under coverage (if we're testing with it). Otherwise # coverage will fail to detect that "import edb..." lines # actually were run. 
server = None logger = logging.getLogger('edb.server') _server_initialized = False def abort(msg, *args, exit_code=1) -> NoReturn: logger.critical(msg, *args) sys.exit(exit_code) @contextlib.contextmanager def _ensure_runstate_dir( default_runstate_dir: Optional[pathlib.Path], specified_runstate_dir: Optional[pathlib.Path] ) -> Iterator[pathlib.Path]: temp_runstate_dir = None if specified_runstate_dir is None: if default_runstate_dir is None: temp_runstate_dir = tempfile.TemporaryDirectory(prefix='edbrun-') runstate_parent = pathlib.Path(temp_runstate_dir.name) else: runstate_parent = default_runstate_dir try: runstate_dir = buildmeta.get_runstate_path(runstate_parent) except buildmeta.MetadataError: abort( f'cannot determine the runstate directory location; ' f'please use --runstate-dir to specify the correct location') else: runstate_dir = specified_runstate_dir runstate_dir = pathlib.Path(runstate_dir) if not runstate_dir.exists(): try: runstate_dir.mkdir(parents=True) except PermissionError as ex: abort( f'cannot create the runstate directory: ' f'{ex!s}; please use --runstate-dir to specify ' f'the correct location') if not os.path.isdir(runstate_dir): abort(f'{str(runstate_dir)!r} is not a directory; please use ' f'--runstate-dir to specify the correct location') try: yield runstate_dir finally: if temp_runstate_dir is not None: temp_runstate_dir.cleanup() @contextlib.contextmanager def _internal_state_dir( runstate_dir: pathlib.Path, args: srvargs.ServerConfig ) -> Iterator[tuple[pathlib.Path, srvargs.ServerConfig]]: try: with tempfile.TemporaryDirectory(prefix="", dir=runstate_dir) as td: if ( args.tls_cert_file and '' in str(args.tls_cert_file) ): args = args._replace( tls_cert_file=pathlib.Path( str(args.tls_cert_file).replace( '', td) ), tls_key_file=pathlib.Path( str(args.tls_key_file).replace( '', td) ) ) if ( args.jws_key_file and '' in str(args.jws_key_file) ): args = args._replace( jws_key_file=pathlib.Path( str(args.jws_key_file).replace( '', td) 
), ) tdp = pathlib.Path(td) yield tdp, args except PermissionError as ex: abort(f'cannot write to the runstate directory: ' f'{ex!s}; please fix the permissions or use ' f'--runstate-dir to specify the correct location') async def _init_cluster( cluster: pgcluster.BaseCluster, args: srvargs.ServerConfig ) -> tuple[bool, edbcompiler.Compiler]: from edb.server import bootstrap new_instance = await bootstrap.ensure_bootstrapped(cluster, args) global _server_initialized _server_initialized = True return new_instance def _init_parsers(): # Initialize parsers that are used in the server process. from edb.edgeql import parser as ql_parser ql_parser.preload_spec() async def _run_server( cluster: pgcluster.BaseCluster, args: srvargs.ServerConfig, runstate_dir: pathlib.Path, internal_runstate_dir: pathlib.Path, *, do_setproctitle: bool, new_instance: bool, compiler: edbcompiler.Compiler, init_con_data: list[config.ConState], ): sockets = service_manager.get_activation_listen_sockets() if sockets: logger.info("detected service manager socket activation") with signalctl.SignalController(signal.SIGINT, signal.SIGTERM) as sc: from . 
import tenant as edbtenant # max_backend_connections should've been calculated already by now assert args.max_backend_connections is not None tenant = edbtenant.Tenant( cluster, instance_name=args.instance_name, max_backend_connections=args.max_backend_connections, backend_adaptive_ha=args.backend_adaptive_ha, extensions_dir=args.extensions_dir, ) tenant.set_init_con_data(init_con_data) tenant.set_reloadable_files( readiness_state_file=args.readiness_state_file, jwt_sub_allowlist_file=args.jwt_sub_allowlist_file, jwt_revocation_list_file=args.jwt_revocation_list_file, config_file=args.config_file, ) ss = server.Server( runstate_dir=runstate_dir, internal_runstate_dir=internal_runstate_dir, compiler_pool_size=args.compiler_pool_size, compiler_worker_branch_limit=args.compiler_worker_branch_limit, compiler_pool_mode=args.compiler_pool_mode, compiler_pool_addr=args.compiler_pool_addr, compiler_worker_max_rss=args.compiler_worker_max_rss, nethosts=args.bind_addresses, netport=args.port, listen_sockets=tuple(s for ss in sockets.values() for s in ss), auto_shutdown_after=args.auto_shutdown_after, echo_runtime_info=args.echo_runtime_info, status_sinks=args.status_sinks, startup_script=args.startup_script, binary_endpoint_security=args.binary_endpoint_security, http_endpoint_security=args.http_endpoint_security, default_auth_method=args.default_auth_method, testmode=args.testmode, daemonized=args.background, pidfile_dir=args.pidfile_dir, new_instance=new_instance, admin_ui=args.admin_ui, cors_always_allowed_origins=args.cors_always_allowed_origins, disable_dynamic_system_config=args.disable_dynamic_system_config, compiler_state=compiler.state, tenant=tenant, use_monitor_fs=args.reload_config_files in [ srvargs.ReloadTrigger.Default, srvargs.ReloadTrigger.FileSystemEvent, ], net_worker_mode=args.net_worker_mode, ) magic_smtp = os.getenv('EDGEDB_MAGIC_SMTP_CONFIG') if magic_smtp: await tenant.load_sidechannel_configs( json.loads(magic_smtp), compiler=compiler ) if 
args.config_file: await tenant.load_config_file(compiler) # This coroutine runs as long as the server, # and compiler(.state) is *heavy*, so make sure we don't # keep a reference to it. del compiler await sc.wait_for(ss.init()) ( tls_cert_newly_generated, jws_keys_newly_generated ) = await ss.maybe_generate_pki(args, ss) if args.bootstrap_only: if args.startup_script and new_instance: await sc.wait_for(ss.run_startup_script_and_exit()) return ss.init_tls( args.tls_cert_file, args.tls_key_file, tls_cert_newly_generated, args.tls_client_ca_file, ) ss.init_jwcrypto(args.jws_key_file, jws_keys_newly_generated) ss.start_watching_files() def load_configuration(_signum): if args.reload_config_files not in [ srvargs.ReloadTrigger.Default, srvargs.ReloadTrigger.Signal, ]: logger.info( "SIGHUP received, but reload on signal is disabled" ) return logger.info("reloading configuration") try: if args.readiness_state_file: tenant.reload_readiness_state() ss.reload_tls( args.tls_cert_file, args.tls_key_file, args.tls_client_ca_file, ) ss.load_jwcrypto(args.jws_key_file) tenant.reload_config_file.schedule() except Exception: logger.critical( "Unexpected error occurred during reload configuration; " "shutting down.", exc_info=True, ) ss.request_shutdown() try: await sc.wait_for(ss.start()) if do_setproctitle: setproctitle.setproctitle( f"edgedb-server-{ss.get_listen_port()}" ) # Notify systemd that we've started up. 
service_manager.sd_notify('READY=1') with signalctl.SignalController(signal.SIGHUP) as reload_ctl: reload_ctl.add_handler( load_configuration, signals=(signal.SIGHUP,) ) try: await sc.wait_for(ss.serve_forever()) except signalctl.SignalError as e: logger.info('Received signal: %s.', e.signo) finally: service_manager.sd_notify('STOPPING=1') logger.info('Shutting down.') await sc.wait_for(ss.stop()) async def _get_local_pgcluster( args: srvargs.ServerConfig, runstate_dir: pathlib.Path, tenant_id: str, ) -> tuple[pgcluster.Cluster, srvargs.ServerConfig]: pg_max_connections = args.max_backend_connections if not pg_max_connections: max_conns = srvargs.compute_default_max_backend_connections() pg_max_connections = max_conns if args.testmode: max_conns = srvargs.adjust_testmode_max_connections(max_conns) logger.info(f'Configuring Postgres max_connections=' f'{pg_max_connections} under test mode.') args = args._replace(max_backend_connections=max_conns) logger.info(f'Using {max_conns} max backend connections based on ' f'total memory.') cluster = await pgcluster.get_local_pg_cluster( args.data_dir, runstate_dir=runstate_dir, # Plus two below to account for system connections. 
max_connections=pg_max_connections + 2, tenant_id=tenant_id, log_level=args.log_level, ) cluster.update_connection_params( user='postgres', database='template1', server_settings={ "application_name": f'edgedb_instance_{args.instance_name}', } ) return cluster, args async def _get_remote_pgcluster( args: srvargs.ServerConfig, tenant_id: str, ) -> tuple[pgcluster.RemoteCluster, srvargs.ServerConfig]: cluster = await pgcluster.get_remote_pg_cluster( args.backend_dsn, tenant_id=tenant_id, specified_capabilities=args.backend_capability_sets, ) instance_params = cluster.get_runtime_params().instance_params max_conns = ( instance_params.max_connections - instance_params.reserved_connections) if not args.max_backend_connections: logger.info(f'Detected {max_conns} backend connections available.') if args.testmode: max_conns = srvargs.adjust_testmode_max_connections(max_conns) logger.info(f'Using max_backend_connections={max_conns} ' f'under test mode.') args = args._replace(max_backend_connections=max_conns) elif args.max_backend_connections > max_conns: abort(f'--max-backend-connections is too large for this backend; ' f'detected maximum available NUM: {max_conns}') cluster.update_connection_params(server_settings={ 'application_name': f'edgedb_instance_{args.instance_name}' }) return cluster, args def _patch_stdlib_testmode( stdlib: bootstrap.StdlibBits ) -> bootstrap.StdlibBits: from edb import edgeql from edb.pgsql import delta as delta_cmds from edb.pgsql import params as pg_params from edb.edgeql import ast as qlast from edb.schema import ddl as s_ddl from edb.schema import delta as sd from edb.schema import schema as s_schema from edb.schema import std as s_std schema: s_schema.Schema = s_schema.ChainedSchema( s_schema.EMPTY_SCHEMA, stdlib.stdschema, stdlib.global_schema, ) reflschema = stdlib.reflschema ctx = sd.CommandContext( stdmode=True, backend_runtime_params=pg_params.get_default_runtime_params(), ) for modname in s_schema.TESTMODE_SOURCES: ddl_text = 
s_std.get_std_module_text(modname) for ddl_cmd in edgeql.parse_block(ddl_text): assert isinstance(ddl_cmd, qlast.DDLCommand) delta = s_ddl.delta_from_ddl( ddl_cmd, modaliases={}, schema=schema, stdmode=True ) if not delta.canonical: sd.apply(delta, schema=schema) delta = delta_cmds.CommandMeta.adapt(delta) schema = sd.apply(delta, schema=schema, context=ctx) reflschema = delta.apply(reflschema, ctx) assert isinstance(schema, s_schema.ChainedSchema) return stdlib._replace( stdschema=schema.get_top_schema(), global_schema=schema.get_global_schema(), reflschema=reflschema, ) async def run_server( args: srvargs.ServerConfig, *, do_setproctitle: bool=False, runstate_dir: pathlib.Path, ) -> None: from . import server as server_mod global server server = server_mod logsetup.setup_logging(args.log_level, args.log_to) logger.info(f"starting Gel server {buildmeta.get_version_line()}") if args.multitenant_config_file: logger.info("configured as a multitenant instance") else: logger.info(f'instance name: {args.instance_name!r}') if devmode.is_in_dev_mode(): logger.info(f'development mode active') if fd_str := os.environ.get("EDGEDB_SERVER_EXTERNAL_LOCK_FD"): try: fd = int(fd_str) except ValueError: logger.info("Invalid EDGEDB_SERVER_EXTERNAL_LOCK_FD") else: os.set_inheritable(fd, False) if fd_str := os.environ.get("GEL_SERVER_EXTERNAL_LOCK_FD"): try: fd = int(fd_str) except ValueError: logger.info("Invalid GEL_SERVER_EXTERNAL_LOCK_FD") else: os.set_inheritable(fd, False) logger.debug( f"defaulting to the '{args.default_auth_method}' authentication method" ) if debug.flags.pydebug_listen: import debugpy debugpy.listen(38782) _init_parsers() pg_cluster_init_by_us = False if args.tenant_id is None: tenant_id = buildmeta.get_default_tenant_id() else: tenant_id = f'C{args.tenant_id}' cluster: pgcluster.Cluster | pgcluster.RemoteCluster runstate_dir_str = str(runstate_dir) runstate_dir_str_len = len( runstate_dir_str.encode( sys.getfilesystemencoding(), 
errors=sys.getfilesystemencodeerrors(), ), ) if runstate_dir_str_len > defines.MAX_RUNSTATE_DIR_PATH: abort( f'the length of the specified path for server run state ' f'exceeds the maximum of {defines.MAX_RUNSTATE_DIR_PATH} ' f'bytes: {runstate_dir_str!r} ({runstate_dir_str_len} bytes)', exit_code=11, ) if args.multitenant_config_file: from edb.schema import reflection as s_refl from . import bootstrap from . import multitenant try: stdlib: bootstrap.StdlibBits | None stdlib = bootstrap.read_data_cache( bootstrap.STDLIB_CACHE_FILE_NAME, pickled=True ) if stdlib is None: abort( "Cannot run multi-tenant server " "without pre-compiled standard library" ) if args.testmode: # In multitenant mode, the server/compiler is started without a # backend and will be connected to many backends. That means we # cannot load the stdlib from any particular backend; instead, the # pre-compiled stdlib is always in use. This means that we need # to explicitly enable --testmode when starting a multitenant server # in order to handle backends with test-mode schema properly. try: stdlib = _patch_stdlib_testmode(stdlib) except errors.SchemaError: # The pre-compiled standard library already has test-mode # schema; ignore the patching error. 
pass compiler = edbcompiler.new_compiler( stdlib.stdschema, stdlib.reflschema, stdlib.classlayout, config_spec=None, ) reflection = s_refl.generate_structure( stdlib.reflschema, make_funcs=False, ) ( local_intro_sql, global_intro_sql ) = bootstrap.compile_intro_queries_stdlib( compiler=compiler, user_schema=stdlib.reflschema, reflection=reflection, ) del reflection compiler_state = edbcompiler.CompilerState( std_schema=compiler.state.std_schema, refl_schema=compiler.state.refl_schema, schema_class_layout=stdlib.classlayout, backend_runtime_params=( compiler.state.backend_runtime_params ), config_spec=compiler.state.config_spec, local_intro_query=local_intro_sql, global_intro_query=global_intro_sql, ) del local_intro_sql, global_intro_sql ( sys_queries, report_configs_typedesc_1_0, report_configs_typedesc_2_0, ) = bootstrap.compile_sys_queries( stdlib.reflschema, compiler, compiler_state.config_spec, ) sys_config, backend_settings, init_con_data = ( initialize_static_cfg( args, is_remote_cluster=True, compiler=compiler, ) ) del compiler if backend_settings: abort( 'Static backend settings for remote backend are ' 'not supported' ) with _internal_state_dir(runstate_dir, args) as ( int_runstate_dir, args, ): return await multitenant.run_server( args, sys_config=sys_config, sys_queries={ key: sql.encode("utf-8") for key, sql in sys_queries.items() }, report_config_typedesc={ (1, 0): report_configs_typedesc_1_0, (2, 0): report_configs_typedesc_2_0, }, runstate_dir=runstate_dir, internal_runstate_dir=int_runstate_dir, do_setproctitle=do_setproctitle, compiler_state=compiler_state, init_con_data=init_con_data, ) except server.StartupError as e: abort(str(e)) try: if args.data_dir: cluster, args = await _get_local_pgcluster( args, runstate_dir, tenant_id) elif args.backend_dsn: cluster, args = await _get_remote_pgcluster(args, tenant_id) else: # This should have been checked by main() already, # but be extra careful. 
abort('neither the data directory nor the remote Postgres DSN ' 'have been specified') except pgcluster.ClusterError as e: abort(str(e)) try: pg_cluster_init_by_us = await cluster.ensure_initialized() cluster_status = await cluster.get_status() logger.debug("postgres cluster status: %s", cluster_status) if isinstance(cluster, pgcluster.Cluster): is_local_cluster = True if cluster_status == 'running': # Refuse to start local instance on an occupied datadir, # as it's very likely that Postgres was orphaned by an # earlier unclean exit of EdgeDB. main_pid = cluster.get_main_pid() or '' abort( f'a PostgreSQL instance (PID {main_pid}) is already ' f'running in data directory "{args.data_dir}", please ' f'stop it to proceed' ) elif cluster_status == 'stopped': await cluster.start() else: abort('could not initialize data directory "%s"', args.data_dir) else: # We expect the remote cluster to be running is_local_cluster = False if cluster_status != "running": abort('specified PostgreSQL instance is not running') logger.info("postgres cluster is running") if ( args.inplace_upgrade_prepare or args.inplace_upgrade_finalize or args.inplace_upgrade_rollback ): from . 
import inplace_upgrade await inplace_upgrade.inplace_upgrade(cluster, args) return new_instance, compiler = await _init_cluster(cluster, args) _, backend_settings, init_con_data = initialize_static_cfg( args, is_remote_cluster=not is_local_cluster, compiler=compiler, ) if is_local_cluster: if new_instance or backend_settings: logger.info('Restarting server to reload configuration...') await cluster.stop() await cluster.start(server_settings=backend_settings) elif backend_settings: abort( 'Static backend settings for remote backend are not supported' ) del backend_settings if ( not args.bootstrap_only or args.bootstrap_command_file or args.bootstrap_command or ( args.tls_cert_mode is srvargs.ServerTlsCertMode.SelfSigned ) or ( args.jose_key_mode is srvargs.JOSEKeyMode.Generate ) ): instance_name = args.instance_name database = pgcluster.get_database_backend_name( defines.EDGEDB_TEMPLATE_DB, tenant_id=tenant_id, ) if args.data_dir else None server_settings = { 'application_name': f'edgedb_instance_{instance_name}', 'edgedb.instance_name': instance_name, 'edgedb.server_version': buildmeta.get_version_json(), } if database: cluster.update_connection_params( database=database, server_settings=server_settings ) else: cluster.update_connection_params( server_settings=server_settings ) with _internal_state_dir(runstate_dir, args) as ( int_runstate_dir, args, ): await _run_server( cluster, args, runstate_dir, int_runstate_dir, do_setproctitle=do_setproctitle, new_instance=new_instance, compiler=compiler, init_con_data=init_con_data, ) except server.StartupError as e: abort(str(e)) except BaseException: if pg_cluster_init_by_us and not _server_initialized: logger.warning( 'server bootstrap did not complete successfully, ' 'removing the data directory') if await cluster.get_status() == 'running': await cluster.stop() cluster.destroy() raise finally: if args.temp_dir: if await cluster.get_status() == 'running': await cluster.stop() cluster.destroy() elif await 
cluster.get_status() == 'running': await cluster.stop() def bump_rlimit_nofile() -> None: try: fno_soft, fno_hard = resource.getrlimit(resource.RLIMIT_NOFILE) except resource.error: logger.warning('could not read RLIMIT_NOFILE') else: if fno_soft < defines.EDGEDB_MIN_RLIMIT_NOFILE: try: resource.setrlimit( resource.RLIMIT_NOFILE, (min(defines.EDGEDB_MIN_RLIMIT_NOFILE, fno_hard), fno_hard)) except resource.error: logger.warning('could not set RLIMIT_NOFILE') def server_main(**kwargs: Any) -> None: exceptions.install_excepthook() bump_rlimit_nofile() asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) if kwargs['devmode'] is not None: devmode.enable_dev_mode(kwargs['devmode']) try: server_args = srvargs.parse_args(**kwargs) except srvargs.InvalidUsageError as e: abort(e.args[0], exit_code=e.args[1]) if server_args.data_dir: default_runstate_dir = server_args.data_dir else: default_runstate_dir = None specified_runstate_dir: Optional[pathlib.Path] if server_args.runstate_dir: specified_runstate_dir = server_args.runstate_dir elif server_args.bootstrap_only: # When bootstrapping a new EdgeDB instance it is often necessary # to avoid using the main runstate dir due to lack of permissions, # possibility of conflict with another running instance, etc. # The --bootstrap mode also often runs unattended, e.g. # as a post-install hook during package installation. 
specified_runstate_dir = default_runstate_dir else: specified_runstate_dir = None runstate_dir_mgr = _ensure_runstate_dir( default_runstate_dir, specified_runstate_dir, ) with runstate_dir_mgr as runstate_dir: if server_args.background: daemon_opts: dict[str, Any] = {'detach_process': True} if server_args.daemon_user: daemon_opts['uid'] = server_args.daemon_user if server_args.daemon_group: daemon_opts['gid'] = server_args.daemon_group with daemon.DaemonContext(**daemon_opts): asyncio.run(run_server( server_args, runstate_dir=runstate_dir, )) else: with devmode.CoverageConfig.enable_coverage_if_requested(): asyncio.run(run_server( server_args, runstate_dir=runstate_dir, )) @click.group( 'Gel Server', invoke_without_command=True, context_settings=dict(help_option_names=['-h', '--help']) ) @srvargs.server_options @click.pass_context def main(ctx, version=False, **kwargs): if kwargs.get('testmode') and 'GEL_TEST_CATALOG_VERSION' in os.environ: buildmeta.EDGEDB_CATALOG_VERSION = int( os.environ['GEL_TEST_CATALOG_VERSION'] ) elif kwargs.get('testmode') and 'EDGEDB_TEST_CATALOG_VERSION' in os.environ: buildmeta.EDGEDB_CATALOG_VERSION = int( os.environ['EDGEDB_TEST_CATALOG_VERSION'] ) if version: print(f"gel-server, version {buildmeta.get_version()}") sys.exit(0) if ctx.invoked_subcommand is None: server_main(**kwargs) @main.command(hidden=True) @srvargs.compiler_options def compiler(**kwargs): from edb.server.compiler_pool import server as compiler_server asyncio.run(compiler_server.server_main(**kwargs)) def main_dev(): devmode.enable_dev_mode() main() def initialize_static_cfg( args: srvargs.ServerConfig, is_remote_cluster: bool, compiler: edbcompiler.Compiler, ) -> tuple[ Mapping[str, config.SettingValue], dict[str, str], list[config.ConState] ]: result = {} init_con_script_data: list[config.ConState] = [] backend_settings = {} config_spec = compiler.state.config_spec sources = { config.ConStateType.command_line_argument: "command line argument", 
config.ConStateType.environment_variable: "environment variable", } def add_config_values(obj: dict[str, Any], source: config.ConStateType): settings = compiler.compile_structured_config( {"cfg::Config": obj}, source=sources[source] )["cfg::Config"] for name, value in settings.items(): setting = config_spec[name] if is_remote_cluster: if setting.backend_setting and setting.requires_restart: if source == config.ConStateType.command_line_argument: where = "on command line" else: where = "as an environment variable" raise server.StartupError( f"Can't set config {name!r} {where} when using " f"a remote Postgres cluster" ) init_con_script_data.append({ "name": name, "value": config.value_to_json_value(setting, value.value), "type": source, }) result[name] = value if setting.backend_setting: backend_val = value.value if isinstance(backend_val, statypes.ScalarType): backend_val = backend_val.to_backend_str() backend_settings[setting.backend_setting] = str(backend_val) values: dict[str, Any] = {} translate_env = { "EDGEDB_SERVER_BIND_ADDRESS": "listen_addresses", "EDGEDB_SERVER_PORT": "listen_port", "GEL_SERVER_BIND_ADDRESS": "listen_addresses", "GEL_SERVER_PORT": "listen_port", } for name, value in os.environ.items(): if cfg := translate_env.get(name): values[cfg] = value else: cfg = name.removeprefix("EDGEDB_SERVER_CONFIG_cfg::") if cfg != name: values[cfg] = value else: cfg = name.removeprefix("GEL_SERVER_CONFIG_cfg::") if cfg != name: values[cfg] = value if values: add_config_values(values, config.ConStateType.environment_variable) values = {} if args.bind_addresses: values["listen_addresses"] = args.bind_addresses if args.port: values["listen_port"] = args.port if values: add_config_values(values, config.ConStateType.command_line_argument) return immutables.Map(result), backend_settings, init_con_script_data if __name__ == '__main__': main() ================================================ FILE: edb/server/metrics.py ================================================ # 
# This source file is part of the EdgeDB open source project. # # Copyright 2021-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import os import sys from edb.common import prometheus as prom registry = prom.Registry(prefix='edgedb_server') COUNT_BUCKETS = prom.per_order_buckets( 1, 10000, entries_per_order=2, ) BYTES_BUCKETS = prom.per_order_buckets( 32, 2**20, entries_per_order=1, base=2, ) compiler_process_spawns = registry.new_counter( 'compiler_process_spawns_total', 'Total number of compiler processes spawned.' ) compiler_process_kills = registry.new_counter( 'compiler_process_kills_total', 'Total number of compiler processes killed.', ) current_compiler_processes = registry.new_gauge( 'compiler_processes_current', 'Current number of active compiler processes.' 
) compiler_process_memory = registry.new_labeled_gauge( 'compiler_process_memory_bytes', 'Current memory usage of compiler processes in bytes.', labels=('pid',), ) compiler_process_schema_size = registry.new_labeled_gauge( 'compiler_process_schema_size_bytes', 'Current size of compiler process schema cache in bytes.', labels=('pid', 'client'), ) compiler_process_branches = registry.new_labeled_gauge( 'compiler_process_branches', 'Total number of branches cached in each compiler process.', labels=('pid', 'client'), ) compiler_process_branch_actions = registry.new_labeled_counter( 'compiler_process_branch_actions_total', 'Number of different branch actions happened in each ' 'compiler process.', labels=('pid', 'client', 'action'), ) compiler_process_client_actions = registry.new_labeled_counter( 'compiler_process_client_actions_total', 'Number of different client actions happened in each ' 'compiler process.', labels=('pid', 'action'), ) compiler_pool_wait_time = registry.new_histogram( 'compiler_pool_wait_time', 'Time it takes to acquire a compiler process.', unit=prom.Unit.SECONDS, ) compiler_pool_queue_errors = registry.new_labeled_counter( 'compiler_pool_queue_errors_total', 'Number of compiler pool errors in queue.', labels=('type',), ) current_branches = registry.new_labeled_gauge( 'branches_current', 'Current number of branches.', labels=('tenant',), ) current_introspected_branches = registry.new_labeled_gauge( 'branches_introspected_current', 'Current number of branches whose schemas are introspected.', labels=('tenant',), ) total_backend_connections = registry.new_labeled_counter( 'backend_connections_total', 'Total number of backend connections established.', labels=('tenant',), ) current_backend_connections = registry.new_labeled_gauge( 'backend_connections_current', 'Current number of active backend connections.', labels=('tenant',), ) backend_connection_establishment_errors = registry.new_labeled_counter( 'backend_connection_establishment_errors_total', 
'Number of times the server could not establish a backend connection.', labels=('tenant',), ) backend_connection_establishment_latency = registry.new_labeled_histogram( 'backend_connection_establishment_latency', 'Time it takes to establish a backend connection.', unit=prom.Unit.SECONDS, labels=('tenant',), ) backend_connection_aborted = registry.new_labeled_counter( 'backend_connections_aborted_total', 'Number of aborted backend connections.', labels=('tenant', 'pgcode') ) backend_query_duration = registry.new_labeled_histogram( 'backend_query_duration', 'Time it takes to run a query on a backend connection.', unit=prom.Unit.SECONDS, labels=('tenant',), ) total_client_connections = registry.new_labeled_counter( 'client_connections_total', 'Total number of clients.', labels=('tenant',), ) current_client_connections = registry.new_labeled_gauge( 'client_connections_current', 'Current number of active clients.', labels=('tenant',), ) idle_client_connections = registry.new_labeled_counter( 'client_connections_idle_total', 'Total number of forcefully closed idle client connections.', labels=('tenant',), ) client_connection_duration = registry.new_labeled_histogram( 'client_connection_duration', 'Time a client connection is open.', unit=prom.Unit.SECONDS, labels=('tenant', 'interface'), ) edgeql_query_compilations = registry.new_labeled_counter( 'edgeql_query_compilations_total', 'Number of compiled/cached queries or scripts.', labels=('tenant', 'path') ) edgeql_query_compilation_duration = registry.new_labeled_histogram( 'edgeql_query_compilation_duration', 'Time it takes to compile an EdgeQL query or script.', unit=prom.Unit.SECONDS, labels=('tenant',), ) graphql_query_compilations = registry.new_labeled_counter( 'graphql_query_compilations_total', 'Number of compiled/cached GraphQL queries.', labels=('tenant', 'path') ) query_compilation_duration = registry.new_labeled_histogram( 'query_compilation_duration', 'Time it takes to compile a query or script.', 
unit=prom.Unit.SECONDS, labels=('tenant', 'interface'), ) sql_queries = registry.new_labeled_counter( 'sql_queries_total', 'Number of SQL queries.', labels=('tenant',) ) sql_compilations = registry.new_labeled_counter( 'sql_compilations_total', 'Number of SQL compilations.', labels=('tenant',) ) queries_per_connection = registry.new_labeled_histogram( 'queries_per_connection', 'Number of queries per connection.', buckets=COUNT_BUCKETS, labels=('tenant', 'interface'), ) query_size = registry.new_labeled_histogram( 'query_size', 'The size of a query.', unit=prom.Unit.BYTES, buckets=BYTES_BUCKETS, labels=('tenant', 'interface'), ) background_errors = registry.new_labeled_counter( 'background_errors_total', 'Number of unhandled errors in background server routines.', labels=('tenant', 'source') ) transaction_serialization_errors = registry.new_labeled_counter( 'transaction_serialization_errors_total', 'Number of transaction serialization errors.', labels=('tenant',) ) connection_errors = registry.new_labeled_counter( 'connection_errors_total', 'Number of network connection errors.', labels=('tenant',) ) ha_events_total = registry.new_labeled_counter( "ha_events_total", "Number of each high-availability watch event.", labels=("dsn", "event"), ) auth_api_calls = registry.new_labeled_counter( "auth_api_calls_total", "Number of API calls to the Auth extension.", labels=("tenant",), ) auth_ui_renders = registry.new_labeled_counter( "auth_ui_renders_total", "Number of UI pages rendered by the Auth extension.", labels=("tenant",), ) auth_providers = registry.new_labeled_gauge( 'auth_providers', 'Number of Auth providers configured.', labels=('tenant', 'branch'), ) extension_used = registry.new_labeled_gauge( 'extension_used_branch_count_current', 'How many branches an extension is used by.', labels=('tenant', 'extension'), ) feature_used_branches = registry.new_labeled_gauge( 'feature_used_branch_count_current', 'How many branches a schema feature is used by.', 
labels=('tenant', 'feature'), ) feature_used = registry.new_labeled_gauge( 'feature_used_num_count_current', 'How many times a schema feature is used.', labels=('tenant', 'feature'), ) auth_successful_logins = registry.new_labeled_counter( "auth_successful_logins_total", "Number of successful logins in the Auth extension.", labels=("tenant",), ) auth_provider_jwkset_fetch_success = registry.new_labeled_counter( "auth_provider_jwkset_fetch_success_total", "Number of successful Auth extension JWK Set fetches.", labels=("provider",), ) auth_provider_jwkset_fetch_errors = registry.new_labeled_counter( "auth_provider_jwkset_fetch_errors_total", "Number of failed Auth extension JWK Set fetches.", labels=("provider",), ) auth_provider_token_validation_success = registry.new_labeled_counter( "auth_provider_token_validation_success_total", "Number of successful Auth extension provider token validations.", labels=("provider",), ) auth_provider_token_validation_errors = registry.new_labeled_counter( "auth_provider_token_validation_errors_total", "Number of failed Auth extension provider token validations.", labels=("provider",), ) otc_initiated_total = registry.new_labeled_counter( "otc_initiated_total", "Number of one-time codes initiated.", labels=("tenant",), ) otc_verified_total = registry.new_labeled_counter( "otc_verified_total", "Number of one-time codes successfully verified.", labels=("tenant",), ) otc_failed_total = registry.new_labeled_counter( "otc_failed_total", "Number of one-time code verification failures.", labels=("tenant", "reason"), ) mt_tenants_total = registry.new_gauge( 'mt_tenants_current', 'Total number of currently-registered tenants.', ) mt_config_reloads = registry.new_counter( 'mt_config_reloads_total', 'Total number of the main multi-tenant config file reloads.', ) mt_config_reload_errors = registry.new_counter( 'mt_config_reload_errors_total', 'Total number of the main multi-tenant config file reload errors.', ) mt_tenant_add_total = 
registry.new_labeled_counter( 'mt_tenant_add_total', 'Total number of new tenants the server attempted to add.', labels=("tenant",), ) mt_tenant_add_errors = registry.new_labeled_counter( 'mt_tenant_add_errors_total', 'Total number of tenants the server failed to add.', labels=("tenant",), ) mt_tenant_remove_total = registry.new_labeled_counter( 'mt_tenant_remove_total', 'Total number of tenants the server attempted to remove.', labels=("tenant",), ) mt_tenant_remove_errors = registry.new_labeled_counter( 'mt_tenant_remove_errors_total', 'Total number of tenants the server failed to remove.', labels=("tenant",), ) mt_tenant_reload_total = registry.new_labeled_counter( 'mt_tenant_reload_total', 'Total number of tenants the server attempted to reload.', labels=("tenant",), ) mt_tenant_reload_errors = registry.new_labeled_counter( 'mt_tenant_reload_errors_total', 'Total number of tenants the server failed to reload.', labels=("tenant",), ) if os.name == 'posix' and (sys.platform == 'linux' or sys.platform == 'darwin'): open_fds = registry.new_gauge( 'open_fds', 'Number of open file descriptors.', ) max_open_fds = registry.new_gauge( 'max_open_fds', 'Maximum number of open file descriptors.', ) # Implement a function that monitors the number of open file descriptors # and updates the metrics accordingly. This will be replaced with a more # efficient implementation in Rust at a later date. def monitor_open_fds_linux(): import time while True: max_open_fds.set(os.sysconf('SC_OPEN_MAX')) # To get the current number of open files, stat /proc/self/fd/ # and get the size. If zero, count the number of entries in the # directory. # # This is supported in modern Linux kernels. 
# https://github.com/torvalds/linux/commit/f1f1f2569901ec5b9d425f2e91c09a0e320768f3 try: st = os.stat('/proc/self/fd/') if st.st_size == 0: open_fds.set(len(os.listdir('/proc/self/fd/'))) else: open_fds.set(st.st_size) except Exception: open_fds.set(-1) time.sleep(30) def monitor_open_fds_macos(): import time while True: max_open_fds.set(os.sysconf('SC_OPEN_MAX')) # Iterate the contents of /dev/fd to list all entries. # We assume that MacOS isn't going to be running a large installation # of EdgeDB on a single machine. try: open_fds.set(len(os.listdir('/dev/fd'))) except Exception: open_fds.set(-1) time.sleep(30) def start_monitoring_open_fds(): import threading # Supported only on Linux and macOS. if os.name == 'posix': if sys.platform == 'darwin': threading.Thread( target=monitor_open_fds_macos, name='open_fds_monitor', daemon=True ).start() elif sys.platform == 'linux': threading.Thread( target=monitor_open_fds_linux, name='open_fds_monitor', daemon=True ).start() start_monitoring_open_fds() ================================================ FILE: edb/server/multitenant.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2023-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import Any, Iterator, Mapping, MutableMapping, Sequence, TypedDict import asyncio import collections import json import logging import pathlib import signal import sys import weakref import setproctitle from edb import buildmeta from edb import errors from edb.common import retryloop from edb.common import signalctl from edb.common.log import current_tenant from edb.pgsql import params as pgparams from edb.server import compiler as edbcompiler from edb.server import metrics from . import args as srvargs from . import config from . import defines from . import pgcluster from . import server from . import tenant as edbtenant from .compiler_pool import pool as compiler_pool logger = logging.getLogger("edb.server") TenantConfig = TypedDict( "TenantConfig", { "instance-name": str, "backend-dsn": str, "max-backend-connections": int, "tenant-id": str, "backend-adaptive-ha": bool, "jwt-sub-allowlist-file": str, "jwt-revocation-list-file": str, "readiness-state-file": str, "admin": bool, "config-file": str, }, ) class MultiTenantServer(server.BaseServer): _config_file: pathlib.Path _sys_config: Mapping[str, config.SettingValue] _init_con_data: list[config.ConState] _tenants_by_sslobj: MutableMapping _tenants_conf: dict[str, dict[str, str]] _last_tenants_conf: dict[str, dict[str, str]] _tenants_lock: MutableMapping[str, asyncio.Lock] _tenants_serial: dict[str, int] _tenants: dict[str, edbtenant.Tenant] _admin_tenant: edbtenant.Tenant | None _task_group: asyncio.TaskGroup | None _task_serial: int def __init__( self, config_file: pathlib.Path, *, compiler_pool_tenant_cache_size: int, sys_config: Mapping[str, config.SettingValue], init_con_data: list[config.ConState], sys_queries: Mapping[str, bytes], report_config_typedesc: dict[defines.ProtocolVersion, bytes], **kwargs, ): super().__init__(**kwargs) self._config_file = config_file self._sys_config = sys_config self._init_con_data = init_con_data 
self._compiler_pool_tenant_cache_size = compiler_pool_tenant_cache_size self._tenants_by_sslobj = weakref.WeakKeyDictionary() self._tenants_conf = {} self._last_tenants_conf = {} self._tenants_lock = collections.defaultdict(asyncio.Lock) self._tenants_serial = {} self._tenants = {} self._admin_tenant = None self._task_group = asyncio.TaskGroup() self._task_serial = 0 self._sys_queries = sys_queries self._report_config_typedesc = report_config_typedesc def _get_sys_config(self) -> Mapping[str, config.SettingValue]: return self._sys_config def _sni_callback(self, sslobj, server_name, _sslctx): if server_name is None: self._tenants_by_sslobj[sslobj] = edbtenant.host_tenant elif tenant := self._tenants.get(server_name): self._tenants_by_sslobj[sslobj] = tenant def get_default_tenant(self) -> edbtenant.Tenant: raise errors.UnknownTenantError( "No such tenant configured.", hint="Please try again later, or " "double check the SNI/server name in TLS connection", ) def retrieve_tenant(self, sslobj) -> edbtenant.Tenant | None: return self._tenants_by_sslobj.pop(sslobj, None) def iter_tenants(self) -> Iterator[edbtenant.Tenant]: return iter(self._tenants.values()) async def _before_start_servers(self) -> None: assert self._task_group is not None await self._task_group.__aenter__() fs = self.reload_tenants() def reload_config_file(): logger.info("Reloading multi-tenant config file.") self.reload_tenants() self.monitor_fs(self._config_file, reload_config_file) if fs: await asyncio.wait(fs) def _get_status(self) -> dict[str, Any]: status = super()._get_status() tenants = {} for server_name, tenant in self._tenants.items(): tenants[server_name] = { "tenant_id": tenant.tenant_id, } status["tenants"] = tenants return status def _get_backend_runtime_params(self) -> pgparams.BackendRuntimeParams: return pgparams.get_default_runtime_params() async def stop(self): await super().stop() if self._task_group is not None: await self._task_group.__aexit__(*sys.exc_info()) try: for tenant in 
self._tenants.values(): tenant.stop() for tenant in self._tenants.values(): await tenant.wait_stopped() metrics.mt_tenants_total.dec() finally: for tenant in self._tenants.values(): tenant.terminate_sys_pgcon() def reload_tenants(self) -> Sequence[asyncio.Future]: metrics.mt_config_reloads.inc() try: with self._config_file.open() as cf: conf = json.load(cf) self._last_tenants_conf = self._tenants_conf rv = [] for sni, tenant_conf in conf.items(): if sni not in self._tenants_conf: rv.append( self._create_task(self._add_tenant, sni, tenant_conf) ) for sni in self._tenants_conf: if sni in conf: rv.append( self._create_task(self._reload_tenant, sni, conf[sni]) ) else: rv.append(self._create_task(self._remove_tenant, sni)) self._tenants_conf = conf return rv except Exception: metrics.mt_config_reload_errors.inc() raise def _create_task(self, method, *args) -> asyncio.Task: self._task_serial += 1 assert self._task_group is not None return self._task_group.create_task(method(self._task_serial, *args)) async def _create_tenant(self, conf: TenantConfig) -> edbtenant.Tenant: cluster = await pgcluster.get_remote_pg_cluster( conf["backend-dsn"], tenant_id=conf.get("tenant-id") ) instance_params = cluster.get_runtime_params().instance_params max_conns = ( instance_params.max_connections - instance_params.reserved_connections ) if "max-backend-connections" not in conf: logger.info(f"Detected {max_conns} backend connections available.") if self._testmode: max_conns = srvargs.adjust_testmode_max_connections(max_conns) logger.info( f"Using max_backend_connections={max_conns} " f"under test mode." 
) elif conf["max-backend-connections"] > max_conns: raise server.StartupError( f"--max-backend-connections is too large for this backend; " f"detected maximum available NUM: {max_conns}" ) else: max_conns = conf["max-backend-connections"] cluster.update_connection_params(server_settings={ "application_name": f'edgedb_instance_{conf["instance-name"]}', "edgedb.instance_name": conf["instance-name"], "edgedb.server_version": buildmeta.get_version_json(), }) if self._jws_key is None: raise server.StartupError( "No secret key" ) tenant = edbtenant.Tenant( cluster, instance_name=conf["instance-name"], max_backend_connections=max_conns, backend_adaptive_ha=conf.get("backend-adaptive-ha", False), ) tenant.set_init_con_data(self._init_con_data) config_file = conf.get("config-file") tenant.set_reloadable_files( readiness_state_file=conf.get("readiness-state-file"), jwt_sub_allowlist_file=conf.get("jwt-sub-allowlist-file"), jwt_revocation_list_file=conf.get("jwt-revocation-list-file"), config_file=config_file, ) tenant.set_server(self) tenant.load_jwcrypto(self._jws_key) if config_file: await tenant.load_config_file(self.get_compiler_pool()) try: await tenant.init_sys_pgcon() await tenant.init(compat_check=True) tenant.start_watching_files() await tenant.start_accepting_new_tasks() tenant.start_running() if conf.get("admin", False): # There can be only one "admin" tenant, the behavior of setting # multiple tenants with `"admin": true` is undefined. 
self._admin_tenant = tenant return tenant except Exception: await self._destroy_tenant(tenant) raise def _get_admin_tenant(self) -> edbtenant.Tenant: if self._admin_tenant is None: return super()._get_admin_tenant() else: return self._admin_tenant async def _destroy_tenant(self, tenant: edbtenant.Tenant): try: if tenant.is_online(): tenant.set_readiness_state( srvargs.ReadinessState.Offline, "tenant is removed" ) tenant.stop_accepting_connections() tenant.stop() try: await asyncio.wait_for( tenant.wait_stopped(), defines.MULTITENANT_TENANT_DESTROY_TIMEOUT, ) except asyncio.TimeoutError: logger.warning( "Tenant removal is taking too long; " "brutally shutdown the tenant now" ) assert isinstance( self._compiler_pool, compiler_pool.MultiTenantPool ) self._compiler_pool.drop_tenant(tenant.client_id) finally: tenant.terminate_sys_pgcon() async def _add_tenant(self, serial: int, sni: str, conf: TenantConfig): def _warn(e): logger.warning( "Failed to add Tenant %s, retrying. Reason: %s", sni, e ) async def _add_tenant(): current_tenant.set(conf["instance-name"]) metrics.mt_tenant_add_total.inc(1.0, current_tenant.get()) rloop = retryloop.RetryLoop( backoff=retryloop.exp_backoff(), timeout=300, ignore=Exception, retry_cb=_warn, ) async for iteration in rloop: async with iteration: async with self._tenants_lock[sni]: if serial > self._tenants_serial.get(sni, 0): if sni not in self._tenants: tenant = await self._create_tenant(conf) self._tenants[sni] = tenant metrics.mt_tenants_total.inc() logger.info("Added Tenant %s", sni) self._tenants_serial[sni] = serial try: with signalctl.SignalController( signal.SIGINT, signal.SIGTERM ) as sc: await sc.wait_for(_add_tenant()) except signalctl.SignalError: pass except Exception: logger.critical("Failed to add Tenant %s", sni, exc_info=True) async with self._tenants_lock[sni]: if serial > self._tenants_serial.get(sni, 0): self._tenants_conf.pop(sni, None) metrics.mt_tenant_add_errors.inc(1.0, conf["instance-name"]) async def 
_remove_tenant(self, serial: int, sni: str): tenant = None try: async with self._tenants_lock[sni]: if serial > self._tenants_serial.get(sni, 0): if sni in self._tenants: tenant = self._tenants.pop(sni) metrics.mt_tenant_remove_total.inc( 1.0, tenant.get_instance_name() ) current_tenant.set(tenant.get_instance_name()) await self._destroy_tenant(tenant) metrics.mt_tenants_total.dec() logger.info("Removed Tenant %s", sni) self._tenants_serial[sni] = serial except Exception: logger.critical("Failed to remove Tenant %s", sni, exc_info=True) metrics.mt_tenant_remove_errors.inc( 1.0, tenant.get_instance_name() if tenant else 'unknown' ) async def _reload_tenant(self, serial: int, sni: str, conf: TenantConfig): tenant = None try: async with self._tenants_lock[sni]: if serial > self._tenants_serial.get(sni, 0): if tenant := self._tenants.get(sni): metrics.mt_tenant_reload_total.inc( 1.0, tenant.get_instance_name() ) current_tenant.set(tenant.get_instance_name()) orig = self._last_tenants_conf.get(sni, {}) diff = set(orig.keys()) - set(conf) for k, v in conf.items(): if orig.get(k) != v: diff.add(k) diff -= { "readiness-state-file", "jwt-sub-allowlist-file", "jwt-revocation-list-file", "config-file", } if diff: logger.warning( "The following config of tenant %s changed, " "but reloading them is not yet supported: %s", sni, ", ".join(diff), ) if not tenant.set_reloadable_files( readiness_state_file=conf.get( "readiness-state-file"), jwt_sub_allowlist_file=conf.get( "jwt-sub-allowlist-file"), jwt_revocation_list_file=conf.get( "jwt-revocation-list-file"), config_file=conf.get("config-file"), ): # none of the reloadable values was modified return tenant.reload() logger.info("Reloaded Tenant %s", sni) # GOTCHA: reloading tenant doesn't increase the tenant # serial because a reload shouldn't prevent a concurrent # removing of the tenant. 
except Exception: logger.critical("Failed to reload Tenant %s", sni, exc_info=True) metrics.mt_tenant_reload_errors.inc( 1.0, tenant.get_instance_name() if tenant else 'unknown' ) def get_debug_info(self): parent = super().get_debug_info() parent["tenants"] = { name: tenant.get_debug_info() for name, tenant in self._tenants.items() } return parent def _get_compiler_args(self) -> dict[str, Any]: args = super()._get_compiler_args() args["cache_size"] = self._compiler_pool_tenant_cache_size return args async def run_server( args: srvargs.ServerConfig, *, sys_config: Mapping[str, config.SettingValue], init_con_data: list[config.ConState], sys_queries: Mapping[str, bytes], report_config_typedesc: dict[defines.ProtocolVersion, bytes], runstate_dir: pathlib.Path, internal_runstate_dir: pathlib.Path, do_setproctitle: bool, compiler_state: edbcompiler.CompilerState, ): multitenant_config_file = args.multitenant_config_file assert multitenant_config_file is not None with signalctl.SignalController(signal.SIGINT, signal.SIGTERM) as sc: ss = MultiTenantServer( multitenant_config_file, sys_config=sys_config, init_con_data=init_con_data, sys_queries=sys_queries, report_config_typedesc=report_config_typedesc, runstate_dir=runstate_dir, internal_runstate_dir=internal_runstate_dir, nethosts=args.bind_addresses, netport=args.port, listen_sockets=(), auto_shutdown_after=args.auto_shutdown_after, echo_runtime_info=args.echo_runtime_info, status_sinks=args.status_sinks, binary_endpoint_security=args.binary_endpoint_security, http_endpoint_security=args.http_endpoint_security, default_auth_method=args.default_auth_method, testmode=args.testmode, admin_ui=args.admin_ui, cors_always_allowed_origins=args.cors_always_allowed_origins, disable_dynamic_system_config=args.disable_dynamic_system_config, compiler_pool_size=args.compiler_pool_size, compiler_worker_branch_limit=args.compiler_worker_branch_limit, compiler_pool_mode=srvargs.CompilerPoolMode.MultiTenant, 
compiler_pool_addr=args.compiler_pool_addr, compiler_pool_tenant_cache_size=( args.compiler_pool_tenant_cache_size ), compiler_worker_max_rss=args.compiler_worker_max_rss, compiler_state=compiler_state, use_monitor_fs=args.reload_config_files in [ srvargs.ReloadTrigger.Default, srvargs.ReloadTrigger.FileSystemEvent, ], ) # This coroutine runs as long as the server, # and compiler_state is *heavy*, so make sure we don't # keep a reference to it. del compiler_state await sc.wait_for(ss.init()) ( tls_cert_newly_generated, jws_keys_newly_generated ) = await ss.maybe_generate_pki(args, ss) ss.init_tls( args.tls_cert_file, args.tls_key_file, tls_cert_newly_generated, args.tls_client_ca_file, ) ss.init_jwcrypto(args.jws_key_file, jws_keys_newly_generated) ss.start_watching_files() def load_configuration(_signum): if args.reload_config_files not in [ srvargs.ReloadTrigger.Default, srvargs.ReloadTrigger.Signal, ]: logger.info( "SIGHUP received, but reload on signal is disabled" ) return logger.info("reloading configuration") try: ss.reload_tls( args.tls_cert_file, args.tls_key_file, args.tls_client_ca_file, ) ss.load_jwcrypto(args.jws_key_file) ss.reload_tenants() except Exception: logger.critical( "Unexpected error occurred during reload configuration; " "shutting down.", exc_info=True, ) ss.request_shutdown() try: await sc.wait_for(ss.start()) if do_setproctitle: setproctitle.setproctitle( f"edgedb-server-{ss.get_listen_port()}" ) with signalctl.SignalController(signal.SIGHUP) as reload_ctl: reload_ctl.add_handler( load_configuration, signals=(signal.SIGHUP,) ) try: await sc.wait_for(ss.serve_forever()) except signalctl.SignalError as e: logger.info("Received signal: %s.", e.signo) finally: logger.info("Shutting down.") await sc.wait_for(ss.stop()) ================================================ FILE: edb/server/net_worker.py ================================================ # # This source file is part of the EdgeDB open source project. 
# # Copyright 2024-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import dataclasses import json import typing import asyncio import logging import base64 from edb.ir import statypes from edb.server import defines from edb.server.protocol import execute from edb.server.http import HttpClient from edb.common import retryloop from . import dbview if typing.TYPE_CHECKING: from edb.server import server as edbserver from edb.server import tenant as edbtenant logger = logging.getLogger("edb.server.net_worker") POLLING_INTERVAL = statypes.Duration(microseconds=10 * 1_000_000) # 10 seconds # TODO: Make this configurable via server config NET_HTTP_REQUEST_TTL = statypes.Duration( microseconds=3600 * 1_000_000 ) # 1 hour @dataclasses.dataclass class TenantState: # For each database, we track the last (db.dml_queries_executed, # db.dbver) that we saw. We track dbver so that we can detect # a database being dropped and recreated. database_counts: dict[str, tuple[int, int]] http_client: typing.Any async def _http_task(tenant: edbtenant.Tenant, state: TenantState) -> None: http_max_connections = tenant._server.config_lookup( 'http_max_connections', tenant.get_sys_config() ) http_client = state.http_client http_client._update_limit(http_max_connections) seen_counts = {} try: # TODO: I think this TaskGroup approach might not be the right # approach here. 
It is fragile to failures and means that slow # queries can cause things to wait on them. async with (asyncio.TaskGroup() as g,): for db in list(tenant.iter_dbs()): if db.name == defines.EDGEDB_SYSTEM_DB: # Don't run the net_worker for the system database continue if not tenant.is_database_connectable(db.name): # Don't run the net_worker if the database is not # connectable, e.g. being dropped continue cur_seen = state.database_counts.get(db.name, (-1, -1)) # We only do the polling for net requests on branches # that have seen DML since our last execute. # # TODO: It would be even better if we only ran when # there were queries that actually touched # ScheduledRequest, but I'm still musing over how to # thread the data around in a way that isn't just a # total hack. if cur_seen == (db.dml_queries_executed, db.dbver): seen_counts[db.name] = cur_seen continue new_key = db.dml_queries_executed + 1, db.dbver try: json_bytes = await execute.parse_execute_json( db, """ with PENDING_REQUESTS := ( select std::net::http::ScheduledRequest filter .state = std::net::RequestState.Pending ), UPDATED := ( update PENDING_REQUESTS set { state := std::net::RequestState.InProgress, updated_at := datetime_of_statement(), } ), select UPDATED { id, method, url, body, headers, } """, cached_globally=True, tx_isolation=defines.TxIsolationLevel.RepeatableRead, query_tag='gel/net', ) seen_counts[db.name] = new_key except Exception as ex: # If the query fails (because the database branch # has been racily deleted, maybe), ignore it and keep # going. We don't want the failure to bubble up # and cause tasks in the task group to die.
logger.debug( "HTTP net_worker query failed " "(instance: %s, branch: %s)", tenant.get_instance_name(), db, exc_info=ex, ) continue pending_requests: list[dict] = json.loads(json_bytes) for pending_request in pending_requests: request = ScheduledRequest(**pending_request) g.create_task(handle_request(http_client, db, request)) except Exception as ex: logger.debug( "HTTP send failed (instance: %s)", tenant.get_instance_name(), exc_info=ex, ) state.database_counts = seen_counts def create_http(tenant: edbtenant.Tenant): return TenantState( http_client=tenant.get_http_client(originator="std::net"), database_counts={}, ) async def http(server: edbserver.BaseServer) -> None: tenant_http = dict() while True: tenant_set = set() try: tasks = [] for tenant in server.iter_tenants(): if tenant.accept_new_tasks: tenant_set.add(tenant) if tenant not in tenant_http: tenant_http[tenant] = create_http(tenant) tasks.append( tenant.create_task( _http_task(tenant, tenant_http[tenant]), interruptable=True, ) ) # Remove unused tenant_http entries for tenant in list(tenant_http.keys()): if tenant not in tenant_set: del tenant_http[tenant] if tasks: await asyncio.wait(tasks) except Exception as ex: logger.debug("HTTP worker failed", exc_info=ex) finally: await asyncio.sleep( POLLING_INTERVAL.to_microseconds() / 1_000_000.0 ) @dataclasses.dataclass class ScheduledRequest: id: str method: str url: str body: typing.Optional[bytes] headers: typing.Optional[list[dict]] def __post_init__(self): if self.body is not None: self.body = base64.b64decode(self.body).decode('utf-8').encode() async def handle_request( client: HttpClient, db: dbview.Database, request: ScheduledRequest ) -> None: response_status = None response_body = None response_headers = None failure = None try: headers = ( [(header["name"], header["value"]) for header in request.headers] if request.headers else None ) response = await client.request( method=request.method, path=request.url, data=request.body, headers=headers, ) 
response_status, response_bytes, response_hdict = response response_body = bytes(response_bytes) response_headers = list(response_hdict.items()) request_state = 'Completed' except Exception as ex: request_state = 'Failed' failure = { 'kind': 'NetworkError', 'message': str(ex), } def _warn(e): logger.warning( "Failed to update std::net::http record, retrying. Reason: %s", e ) async def _update_request(): rloop = retryloop.RetryLoop( backoff=retryloop.exp_backoff(), timeout=300, ignore=(Exception,), retry_cb=_warn, ) async for iteration in rloop: async with iteration: await execute.parse_execute_json( db, """ with nh as module std::net::http, net as module std::net, state := <net::RequestState>$state, failure := < optional tuple< kind: net::RequestFailureKind, message: str > >to_json(<str>$failure), response_status := <optional int64>$response_status, response_body := <optional bytes>$response_body, response_headers := <optional array<tuple<str, str>>>$response_headers, response := ( if state = net::RequestState.Completed then ( insert nh::Response { created_at := datetime_of_statement(), status := assert_exists(response_status), body := response_body, headers := response_headers, } ) else ({}) ), update nh::ScheduledRequest filter .id = <uuid>$id set { state := state, response := response, failure := failure, updated_at := datetime_of_statement(), }; """, variables={ 'id': request.id, 'state': request_state, 'response_status': response_status, 'response_body': response_body, 'response_headers': response_headers, 'failure': json.dumps(failure), }, cached_globally=True, tx_isolation=defines.TxIsolationLevel.RepeatableRead, query_tag='gel/net', ) await _update_request() async def _delete_requests( db: dbview.Database, expires_in: statypes.Duration ) -> None: def _warn(e): logger.warning( "Failed to delete std::net::http::ScheduledRequest, retrying."
" Reason: %s", e, ) rloop = retryloop.RetryLoop( backoff=retryloop.exp_backoff(), timeout=300, ignore=(Exception,), retry_cb=_warn, ) async for iteration in rloop: async with iteration: if not db.tenant.is_database_connectable(db.name): # Don't run the net_worker if the database is not # connectable, e.g. being dropped continue result_json = await execute.parse_execute_json( db, """ with requests := ( select std::net::http::ScheduledRequest filter .state != std::net::RequestState.Pending and (datetime_of_statement() - .updated_at) > <duration>$expires_in ) select count((delete requests)); """, variables={"expires_in": expires_in.to_backend_str()}, cached_globally=True, tx_isolation=defines.TxIsolationLevel.RepeatableRead, query_tag='gel/net', ) result: list[int] = json.loads(result_json) if result[0] > 0: logger.debug(f"Deleted {result[0]} requests") else: logger.debug("No requests to delete") async def _gc(tenant: edbtenant.Tenant, expires_in: statypes.Duration) -> None: try: async with asyncio.TaskGroup() as g: for db in tenant.iter_dbs(): if db.name == defines.EDGEDB_SYSTEM_DB: continue g.create_task(_delete_requests(db, expires_in)) except Exception as ex: logger.debug( "GC of std::net::http::ScheduledRequest failed (instance: %s)", tenant.get_instance_name(), exc_info=ex, ) async def gc(server: edbserver.BaseServer) -> None: while True: tasks = [ tenant.create_task( _gc(tenant, NET_HTTP_REQUEST_TTL), interruptable=True ) for tenant in server.iter_tenants() if tenant.accept_new_tasks ] try: if tasks: await asyncio.wait(tasks) except Exception as ex: logger.debug( "GC of std::net::http::ScheduledRequest failed", exc_info=ex ) finally: await asyncio.sleep( NET_HTTP_REQUEST_TTL.to_microseconds() / 1_000_000.0 ) ================================================ FILE: edb/server/pgcluster.py ================================================ # Copyright (C) 2016-present MagicStack Inc. and the EdgeDB authors.
# Copyright (C) 2016-present the asyncpg authors and contributors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """PostgreSQL cluster management.""" from __future__ import annotations from typing import ( Any, Callable, Optional, Iterable, Mapping, Sequence, Coroutine, Unpack, cast, TYPE_CHECKING, ) import asyncio import copy import hashlib import json import logging import os import os.path import pathlib import re import shlex import shutil import signal import struct import textwrap import urllib.parse from edb import buildmeta from edb import errors from edb.common import lru from edb.common import supervisor from edb.common import uuidgen from edb.server import args as srvargs from edb.server import defines from edb.server import pgconnparams from edb.server.ha import base as ha_base from edb.pgsql import common as pgcommon from edb.pgsql import params as pgparams if TYPE_CHECKING: from edb.server import pgcon logger = logging.getLogger('edb.pgcluster') pg_dump_logger = logging.getLogger('pg_dump') pg_restore_logger = logging.getLogger('pg_restore') pg_ctl_logger = logging.getLogger('pg_ctl') pg_config_logger = logging.getLogger('pg_config') initdb_logger = logging.getLogger('initdb') postgres_logger = logging.getLogger('postgres') get_database_backend_name = pgcommon.get_database_backend_name get_role_backend_name = pgcommon.get_role_backend_name EDGEDB_SERVER_SETTINGS = { 'client_encoding': 'utf-8', # DO NOT raise client_min_messages above NOTICE level # because server 
indirect block return machinery relies # on NoticeResponse as the data channel. 'client_min_messages': 'NOTICE', 'search_path': 'edgedb', 'timezone': 'UTC', 'intervalstyle': 'iso_8601', 'jit': 'off', 'default_transaction_isolation': 'serializable', } class ClusterError(Exception): pass class PostgresPidFileNotReadyError(Exception): """Raised on an attempt to read non-existent or bad Postgres PID file""" class BaseCluster: def __init__( self, *, instance_params: Optional[pgparams.BackendInstanceParams] = None, ) -> None: self._connection_addr: Optional[tuple[str, int]] = None self._connection_params: pgconnparams.ConnectionParams = \ pgconnparams.ConnectionParams(server_settings=EDGEDB_SERVER_SETTINGS) self._pg_config_data: dict[str, str] = {} self._pg_bin_dir: Optional[pathlib.Path] = None if instance_params is None: self._instance_params = ( pgparams.get_default_runtime_params().instance_params) else: self._instance_params = instance_params def get_db_name(self, db_name: str) -> str: if ( not self._instance_params.capabilities & pgparams.BackendCapabilities.CREATE_DATABASE ): assert ( db_name == defines.EDGEDB_SUPERUSER_DB ), f"db_name={db_name} is not allowed" rv = self.get_connection_params().database assert rv is not None return rv return get_database_backend_name( db_name, tenant_id=self._instance_params.tenant_id, ) def get_role_name(self, role_name: str) -> str: if ( not self._instance_params.capabilities & pgparams.BackendCapabilities.CREATE_ROLE ): assert ( role_name == defines.EDGEDB_SUPERUSER ), f"role_name={role_name} is not allowed" rv = self.get_connection_params().user assert rv is not None return rv return get_database_backend_name( role_name, tenant_id=self._instance_params.tenant_id, ) async def start( self, wait: int = 60, *, server_settings: Optional[Mapping[str, str]] = None, **opts: Any, ) -> None: raise NotImplementedError async def stop(self, wait: int = 60) -> None: raise NotImplementedError def destroy(self) -> None: raise 
NotImplementedError async def connect(self, *, source_description: str, apply_init_script: bool = False, **kwargs: Unpack[pgconnparams.CreateParamsKwargs] ) -> pgcon.PGConnection: """Connect to this cluster, with optional overriding parameters. If overriding parameters are specified, they are applied to a copy of the connection parameters before the connection takes place.""" from edb.server import pgcon connection = copy.copy(self.get_connection_params()) addr = self._get_connection_addr() assert addr is not None connection.update(hosts=[addr]) connection.update(**kwargs) conn = await pgcon.pg_connect( connection, source_description=source_description, backend_params=self.get_runtime_params(), apply_init_script=apply_init_script, ) return conn async def start_watching( self, failover_cb: Optional[Callable[[], None]] = None ) -> None: pass def stop_watching(self) -> None: pass def get_runtime_params(self) -> pgparams.BackendRuntimeParams: params = self.get_connection_params() login_role: Optional[str] = params.user sup_role = self.get_role_name(defines.EDGEDB_SUPERUSER) return pgparams.BackendRuntimeParams( instance_params=self._instance_params, session_authorization_role=( None if login_role == sup_role else login_role ), ) def overwrite_capabilities( self, caps: pgparams.BackendCapabilities ) -> None: self._instance_params = self._instance_params._replace( capabilities=caps ) def update_connection_params( self, **kwargs: Unpack[pgconnparams.CreateParamsKwargs], ) -> None: self._connection_params.update(**kwargs) def get_pgaddr(self) -> pgconnparams.ConnectionParams: assert self._connection_params is not None addr = self._get_connection_addr() assert addr is not None params = copy.copy(self._connection_params) params.update(hosts=[addr]) return params def get_connection_params( self, ) -> pgconnparams.ConnectionParams: assert self._connection_params is not None return self._connection_params def _get_connection_addr(self) -> Optional[tuple[str, int]]: return 
self._connection_addr def is_managed(self) -> bool: raise NotImplementedError async def get_status(self) -> str: raise NotImplementedError def _dump_restore_conn_args( self, dbname: str, ) -> tuple[list[str], dict[str, str]]: params = copy.copy(self.get_connection_params()) addr = self._get_connection_addr() assert addr is not None params.update(database=dbname, hosts=[addr]) args = [ f'--dbname={params.database}', f'--host={params.host}', f'--port={params.port}', f'--username={params.user}', ] env = os.environ.copy() if params.password: env['PGPASSWORD'] = params.password return args, env async def dump_database( self, dbname: str, *, exclude_schemas: Iterable[str] = (), exclude_tables: Iterable[str] = (), include_schemas: Iterable[str] = (), include_tables: Iterable[str] = (), include_extensions: Iterable[str] = (), schema_only: bool = False, dump_object_owners: bool = True, create_database: bool = False, ) -> bytes: status = await self.get_status() if status != 'running': raise ClusterError('cannot dump: cluster is not running') if self._pg_bin_dir is None: await self.lookup_postgres() pg_dump = self._find_pg_binary('pg_dump') conn_args, env = self._dump_restore_conn_args(dbname) args = [ pg_dump, '--inserts', *conn_args, ] if not dump_object_owners: args.append('--no-owner') if schema_only: args.append('--schema-only') if create_database: args.append('--create') configs = [ ('exclude-schema', exclude_schemas), ('exclude-table', exclude_tables), ('schema', include_schemas), ('table', include_tables), ('extension', include_extensions), ] for flag, vals in configs: for val in vals: args.append(f'--{flag}={val}') stdout_lines, _, _ = await _run_logged_subprocess( args, logger=pg_dump_logger, log_stdout=False, env=env, ) return b'\n'.join(stdout_lines) async def _copy_database( self, src_dbname: str, tgt_dbname: str, src_args: list[str], tgt_args: list[str], ) -> None: status = await self.get_status() if status != 'running': raise ClusterError('cannot dump: cluster 
is not running') if self._pg_bin_dir is None: await self.lookup_postgres() pg_dump = self._find_pg_binary('pg_dump') # We actually just use psql to restore, because it is more # tolerant of version differences. # TODO: Maybe use pg_restore when we know we match the backend version? pg_restore = self._find_pg_binary('psql') src_conn_args, src_env = self._dump_restore_conn_args(src_dbname) tgt_conn_args, _tgt_env = self._dump_restore_conn_args(tgt_dbname) dump_args = [ pg_dump, '--verbose', *src_conn_args, *src_args ] restore_args = [ pg_restore, *tgt_conn_args, *tgt_args ] rpipe, wpipe = os.pipe() wpipef = os.fdopen(wpipe, "wb") try: # N.B: uvloop will waitpid() on the child process even if we don't # actually await on it due to a later error. dump_p, dump_out_r, dump_err_r = await _start_logged_subprocess( dump_args, logger=pg_dump_logger, override_stdout=wpipef, log_stdout=False, capture_stdout=False, capture_stderr=False, env=src_env, ) res_p, res_out_r, res_err_r = await _start_logged_subprocess( restore_args, logger=pg_restore_logger, stdin=rpipe, capture_stdout=False, capture_stderr=False, log_stdout=True, log_stderr=True, env=src_env, ) finally: wpipef.close() os.close(rpipe) dump_exit_code, _, _, restore_exit_code, _, _ = await asyncio.gather( dump_p.wait(), dump_out_r, dump_err_r, res_p.wait(), res_out_r, res_err_r, ) if dump_exit_code != 0 and dump_exit_code != -signal.SIGPIPE: raise errors.ExecutionError( f'branch failed: {dump_args[0]} exited with status ' f'{dump_exit_code}' ) if restore_exit_code != 0: raise errors.ExecutionError( f'branch failed: ' f'{restore_args[0]} exited with status {restore_exit_code}' ) def _find_pg_binary(self, binary: str) -> str: assert self._pg_bin_dir is not None bpath = self._pg_bin_dir / binary if not bpath.is_file(): raise ClusterError( 'could not find {} executable: '.format(binary) + '{!r} does not exist or is not a file'.format(bpath)) return str(bpath) def _subprocess_error( self, name: str, exitcode: int, stderr: 
Optional[bytes], ) -> ClusterError: if stderr: return ClusterError( f'{name} exited with status {exitcode}:\n' + textwrap.indent(stderr.decode(), ' ' * 4), ) else: return ClusterError( f'{name} exited with status {exitcode}', ) async def lookup_postgres(self) -> None: self._pg_bin_dir = await get_pg_bin_dir() def get_client_id(self) -> int: return 0 class Cluster(BaseCluster): def __init__( self, data_dir: pathlib.Path, *, runstate_dir: Optional[pathlib.Path] = None, instance_params: Optional[pgparams.BackendInstanceParams] = None, log_level: str = 'i', ): super().__init__(instance_params=instance_params) self._data_dir = data_dir self._runstate_dir = ( runstate_dir if runstate_dir is not None else data_dir) self._daemon_pid: Optional[int] = None self._daemon_process: Optional[asyncio.subprocess.Process] = None self._daemon_supervisor: Optional[supervisor.Supervisor] = None self._log_level = log_level def is_managed(self) -> bool: return True def get_data_dir(self) -> pathlib.Path: return self._data_dir def get_main_pid(self) -> Optional[int]: return self._daemon_pid async def get_status(self) -> str: stdout_lines, stderr_lines, exit_code = ( await _run_logged_text_subprocess( [self._pg_ctl, 'status', '-D', str(self._data_dir)], logger=pg_ctl_logger, check=False, ) ) if ( exit_code == 4 or not os.path.exists(self._data_dir) or not os.listdir(self._data_dir) ): return 'not-initialized' elif exit_code == 3: return 'stopped' elif exit_code == 0: output = '\n'.join(stdout_lines) r = re.match(r'.*PID\s?:\s+(\d+).*', output) if not r: raise ClusterError( f'could not parse pg_ctl status output: {output}') self._daemon_pid = int(r.group(1)) if self._connection_addr is None: self._connection_addr = self._connection_addr_from_pidfile() return 'running' else: stderr_text = '\n'.join(stderr_lines) raise ClusterError( f'`pg_ctl status` exited with status {exit_code}:\n' + textwrap.indent(stderr_text, ' ' * 4), ) async def ensure_initialized(self, **settings: Any) -> bool: 
cluster_status = await self.get_status() if cluster_status == 'not-initialized': logger.info( 'Initializing database cluster in %s', self._data_dir) have_c_utf8 = self.get_runtime_params().has_c_utf8_locale await self.init( username='postgres', locale='C.UTF-8' if have_c_utf8 else 'en_US.UTF-8', lc_collate='C', encoding='UTF8', ) self.reset_hba() self.add_hba_entry( type='local', database='all', user='postgres', auth_method='trust' ) self.add_hba_entry( type='local', database='replication', user='postgres', auth_method='trust' ) return True else: return False async def init(self, **settings: str) -> None: """Initialize cluster.""" if await self.get_status() != 'not-initialized': raise ClusterError( 'cluster in {!r} has already been initialized'.format( self._data_dir)) if settings: settings_args = ['--{}={}'.format(k.replace('_', '-'), v) for k, v in settings.items()] extra_args = ['-o'] + [' '.join(settings_args)] else: extra_args = [] await _run_logged_subprocess( [self._pg_ctl, 'init', '-D', str(self._data_dir)] + extra_args, logger=initdb_logger, ) async def start( self, wait: int = 60, *, server_settings: Optional[Mapping[str, str]] = None, **opts: str, ) -> None: """Start the cluster.""" status = await self.get_status() if status == 'running': return elif status == 'not-initialized': raise ClusterError( 'cluster in {!r} has not been initialized'.format( self._data_dir)) extra_args = ['--{}={}'.format(k, v) for k, v in opts.items()] start_settings = { 'listen_addresses': '', # we use Unix sockets 'unix_socket_permissions': '0700', 'unix_socket_directories': str(self._runstate_dir), # here we are not setting superuser_reserved_connections because # we're using superuser only now (so all connections available), # and we don't support reserving connections for now 'max_connections': str(self._instance_params.max_connections), # From Postgres docs: # # You might need to raise this value if you have queries that # touch many different tables in a single 
transaction, e.g., # query of a parent table with many children. # # EdgeDB queries might touch _lots_ of tables, especially in deep # inheritance hierarchies. This is especially important in low # `max_connections` scenarios. 'max_locks_per_transaction': 1024, 'max_pred_locks_per_transaction': 1024, "shared_preload_libraries": ",".join( [ "edb_stat_statements", ] ), "edb_stat_statements.track_planning": "true", # Required for pg_basebackup --incremental to work "summarize_wal": "on", } if os.getenv('EDGEDB_DEBUG_PGSERVER'): start_settings['log_min_messages'] = 'info' start_settings['log_statement'] = 'all' else: log_level_map = { 'd': 'INFO', 'i': 'WARNING', # NOTICE in Postgres is quite noisy 'w': 'WARNING', 'e': 'ERROR', 's': 'PANIC', } start_settings['log_min_messages'] = log_level_map[self._log_level] start_settings['log_statement'] = 'none' start_settings['log_line_prefix'] = '' if server_settings: start_settings.update(server_settings) ssl_key = start_settings.get('ssl_key_file') if ssl_key: # Make sure server certificate key file has correct permissions. 
keyfile = os.path.join(self._data_dir, 'srvkey.pem') assert isinstance(ssl_key, str) shutil.copy(ssl_key, keyfile) os.chmod(keyfile, 0o600) start_settings['ssl_key_file'] = keyfile for k, v in start_settings.items(): extra_args.extend(['-c', '{}={}'.format(k, v)]) self._daemon_process, *loggers = await _start_logged_subprocess( [self._postgres, '-D', str(self._data_dir), *extra_args], capture_stdout=False, capture_stderr=False, logger=postgres_logger, log_processor=postgres_log_processor, ) self._daemon_pid = self._daemon_process.pid sup = await supervisor.Supervisor.create(name="postgres loggers") for logger_coro in loggers: sup.create_task(logger_coro) self._daemon_supervisor = sup await self._test_connection(timeout=wait) async def reload(self) -> None: """Reload server configuration.""" status = await self.get_status() if status != 'running': raise ClusterError('cannot reload: cluster is not running') await _run_logged_subprocess( [self._pg_ctl, 'reload', '-D', str(self._data_dir)], logger=pg_ctl_logger, ) async def stop(self, wait: int = 60) -> None: await _run_logged_subprocess( [ self._pg_ctl, 'stop', '-D', str(self._data_dir), '-t', str(wait), '-m', 'fast' ], logger=pg_ctl_logger, ) if ( self._daemon_process is not None and self._daemon_process.returncode is None ): self._daemon_process.terminate() await asyncio.wait_for(self._daemon_process.wait(), timeout=wait) if self._daemon_supervisor is not None: await self._daemon_supervisor.cancel() self._daemon_supervisor = None def destroy(self) -> None: shutil.rmtree(self._data_dir) def reset_hba(self) -> None: """Remove all records from pg_hba.conf.""" pg_hba = os.path.join(self._data_dir, 'pg_hba.conf') try: with open(pg_hba, 'w'): pass except IOError as e: raise ClusterError( 'cannot modify HBA records: {}'.format(e)) from e def add_hba_entry( self, *, type: str = 'host', database: str, user: str, address: Optional[str] = None, auth_method: str, auth_options: Optional[Mapping[str, Any]] = None, ) -> None: 
"""Add a record to pg_hba.conf.""" if type not in {'local', 'host', 'hostssl', 'hostnossl'}: raise ValueError('invalid HBA record type: {!r}'.format(type)) pg_hba = os.path.join(self._data_dir, 'pg_hba.conf') record = '{} {} {}'.format(type, database, user) if type != 'local': if address is None: raise ValueError( '{!r} entry requires a valid address'.format(type)) else: record += ' {}'.format(address) record += ' {}'.format(auth_method) if auth_options is not None: record += ' ' + ' '.join( '{}={}'.format(k, v) for k, v in auth_options.items()) try: with open(pg_hba, 'a') as f: print(record, file=f) except IOError as e: raise ClusterError( 'cannot modify HBA records: {}'.format(e)) from e async def trust_local_connections(self) -> None: self.reset_hba() self.add_hba_entry(type='local', database='all', user='all', auth_method='trust') self.add_hba_entry(type='host', address='127.0.0.1/32', database='all', user='all', auth_method='trust') self.add_hba_entry(type='host', address='::1/128', database='all', user='all', auth_method='trust') status = await self.get_status() if status == 'running': await self.reload() async def lookup_postgres(self) -> None: await super().lookup_postgres() self._pg_ctl = self._find_pg_binary('pg_ctl') self._postgres = self._find_pg_binary('postgres') def _get_connection_addr(self) -> tuple[str, int]: if self._connection_addr is None: self._connection_addr = self._connection_addr_from_pidfile() return self._connection_addr def _connection_addr_from_pidfile(self) -> tuple[str, int]: pidfile = os.path.join(self._data_dir, 'postmaster.pid') try: with open(pidfile, 'rt') as f: piddata = f.read() except FileNotFoundError: raise PostgresPidFileNotReadyError lines = piddata.splitlines() if len(lines) < 6: # A complete postgres pidfile is at least 6 lines raise PostgresPidFileNotReadyError pmpid = int(lines[0]) if self._daemon_pid and pmpid != self._daemon_pid: # This might be an old pidfile left from previous postgres # daemon run. 
raise PostgresPidFileNotReadyError portnum = int(lines[3]) sockdir = lines[4] hostaddr = lines[5] if sockdir: if sockdir[0] != '/': # Relative sockdir sockdir = os.path.normpath( os.path.join(self._data_dir, sockdir)) host_str = sockdir elif hostaddr: host_str = hostaddr else: raise PostgresPidFileNotReadyError if host_str == '*': host_str = 'localhost' elif host_str == '0.0.0.0': host_str = '127.0.0.1' elif host_str == '::': host_str = '::1' return (host_str, portnum) async def _test_connection(self, timeout: int = 60) -> str: from edb.server import pgcon self._connection_addr = None connected = False params = pgconnparams.ConnectionParams( user="postgres", database="postgres") for n in range(timeout + 9): # pg usually comes up pretty quickly, but not so quickly # that we don't hit the wait case. Make our first several # waits pretty short, to shave almost a second off the # happy case. sleep_time = 1.0 if n >= 10 else 0.1 try: conn_addr = self._get_connection_addr() except PostgresPidFileNotReadyError: try: assert self._daemon_process is not None code = await asyncio.wait_for( self._daemon_process.wait(), sleep_time ) except asyncio.TimeoutError: # means that the postgres process is still alive pass else: # the postgres process has exited prematurely raise ClusterError(f"The backend exited with {code}") continue try: params.update(hosts=[conn_addr]) con = await asyncio.wait_for( pgcon.pg_connect( params, source_description=f"{self.__class__}._test_connection", backend_params=self.get_runtime_params(), apply_init_script=False, ), timeout=5, ) except ( OSError, asyncio.TimeoutError, pgcon.BackendConnectionError, ) as e: if n % 10 == 0 and 0 < n < timeout + 9 - 1: logger.error("cannot connect to the backend cluster:" " %s, retrying...", e) await asyncio.sleep(sleep_time) continue except pgcon.BackendError: # Any error other than ServerNotReadyError or # ConnectionError is interpreted to indicate the server is # up.
break else: connected = True con.terminate() break if connected: return 'running' else: return 'not-initialized' class RemoteCluster(BaseCluster): def __init__( self, *, connection_addr: tuple[str, int], connection_params: pgconnparams.ConnectionParams, instance_params: Optional[pgparams.BackendInstanceParams] = None, ha_backend: Optional[ha_base.HABackend] = None, ): super().__init__(instance_params=instance_params) self._connection_params = connection_params self._connection_params.update( server_settings=EDGEDB_SERVER_SETTINGS ) self._connection_addr = connection_addr self._ha_backend = ha_backend def _get_connection_addr(self) -> Optional[tuple[str, int]]: if self._ha_backend is not None: return self._ha_backend.get_master_addr() return self._connection_addr async def ensure_initialized(self, **settings: Any) -> bool: return False def is_managed(self) -> bool: return False async def get_status(self) -> str: return 'running' def init(self, **settings: str) -> Optional[str]: pass async def start( self, wait: int = 60, *, server_settings: Optional[Mapping[str, str]] = None, **opts: Any, ) -> None: pass async def stop(self, wait: int = 60) -> None: pass def destroy(self) -> None: pass def reset_hba(self) -> None: raise ClusterError('cannot modify HBA records of unmanaged cluster') def add_hba_entry( self, *, type: str = 'host', database: str, user: str, address: Optional[str] = None, auth_method: str, auth_options: Optional[Mapping[str, Any]] = None, ) -> None: raise ClusterError('cannot modify HBA records of unmanaged cluster') async def start_watching( self, failover_cb: Optional[Callable[[], None]] = None ) -> None: if self._ha_backend is not None: self._ha_backend.set_failover_callback(failover_cb) await self._ha_backend.start_watching() def stop_watching(self) -> None: if self._ha_backend is not None: self._ha_backend.stop_watching() @lru.method_cache def get_client_id(self) -> int: tenant_id = self._instance_params.tenant_id if self._ha_backend is not None: 
backend_dsn = self._ha_backend.dsn else: assert self._connection_addr is not None assert self._connection_params is not None host, port = self._connection_addr database = self._connection_params.database backend_dsn = f"postgres://{host}:{port}/{database}" data = f"{backend_dsn}|{tenant_id}".encode("utf-8") digest = hashlib.blake2b(data, digest_size=8).digest() rv: int = struct.unpack("q", digest)[0] return rv async def get_pg_bin_dir() -> pathlib.Path: pg_config_data = await get_pg_config() pg_bin_dir = pg_config_data.get('bindir') if not pg_bin_dir: raise ClusterError( 'pg_config output did not provide the BINDIR value') return pathlib.Path(pg_bin_dir) async def get_pg_config() -> dict[str, str]: stdout_lines, _, _ = await _run_logged_text_subprocess( [str(buildmeta.get_pg_config_path())], logger=pg_config_logger, ) config = {} for line in stdout_lines: k, eq, v = line.partition('=') if eq: config[k.strip().lower()] = v.strip() return config async def get_local_pg_cluster( data_dir: pathlib.Path, *, runstate_dir: Optional[pathlib.Path] = None, max_connections: Optional[int] = None, tenant_id: Optional[str] = None, log_level: Optional[str] = None, ) -> Cluster: if log_level is None: log_level = 'i' if tenant_id is None: tenant_id = buildmeta.get_default_tenant_id() instance_params = None if max_connections is not None: instance_params = pgparams.get_default_runtime_params( max_connections=max_connections, tenant_id=tenant_id, ).instance_params cluster = Cluster( data_dir=data_dir, runstate_dir=runstate_dir, instance_params=instance_params, log_level=log_level, ) await cluster.lookup_postgres() return cluster async def get_remote_pg_cluster( dsn: str, *, tenant_id: Optional[str] = None, specified_capabilities: Optional[srvargs.BackendCapabilitySets] = None, ) -> RemoteCluster: from edb.server import pgcon parsed = urllib.parse.urlparse(dsn) ha_backend = None if parsed.scheme not in {'postgresql', 'postgres'}: ha_backend = ha_base.get_backend(parsed) if ha_backend 
is None: raise ValueError( 'invalid DSN: scheme is expected to be "postgresql", ' '"postgres" or one of the supported HA backends, ' 'got {!r}'.format(parsed.scheme)) addr = await ha_backend.get_cluster_consensus() dsn = 'postgresql://{}:{}'.format(*addr) if parsed.query: # Allow passing through Postgres connection parameters from the HA # backend DSN as "pg" prefixed query strings. For example, an HA # backend DSN with `?pgpassword=123` will result in an actual backend # DSN with `?password=123`. They have higher priority than the `PG` # prefixed environment variables like `PGPASSWORD`. pq = urllib.parse.parse_qs(parsed.query, strict_parsing=True) query = {} for k, v in pq.items(): if k.startswith("pg") and k not in ["pghost", "pgport"]: if isinstance(v, list): val = v[-1] else: val = cast(str, v) query[k[2:]] = val if query: dsn += f"?{urllib.parse.urlencode(query)}" if tenant_id is None: t_id = buildmeta.get_default_tenant_id() else: t_id = tenant_id async def _get_cluster_type( conn: pgcon.PGConnection, ) -> tuple[type[RemoteCluster], Optional[str]]: managed_clouds = { 'rds_superuser': RemoteCluster, # Amazon RDS 'cloudsqlsuperuser': RemoteCluster, # GCP Cloud SQL 'azure_pg_admin': RemoteCluster, # Azure Postgres } managed_cloud_super = await conn.sql_fetch_val( b""" SELECT rolname FROM pg_roles WHERE rolname IN (SELECT json_array_elements_text($1::json)) LIMIT 1 """, args=[json.dumps(list(managed_clouds)).encode("utf-8")], ) if managed_cloud_super is not None: rolname = managed_cloud_super.decode("utf-8") return managed_clouds[rolname], rolname else: return RemoteCluster, None async def _detect_capabilities( conn: pgcon.PGConnection, ) -> pgparams.BackendCapabilities: from edb.server import pgcon from edb.server.pgcon import errors caps = pgparams.BackendCapabilities.NONE try: cur_cluster_name = await conn.sql_fetch_val( b""" SELECT setting FROM pg_file_settings WHERE setting = 'cluster_name' AND sourcefile = (( SELECT setting FROM pg_settings WHERE name =
'data_directory' ) || '/postgresql.auto.conf') """, ) except pgcon.BackendPrivilegeError: configfile_access = False else: try: await conn.sql_execute(b""" ALTER SYSTEM SET cluster_name = 'edgedb-test' """) except pgcon.BackendPrivilegeError: configfile_access = False except pgcon.BackendError as e: # Stolon keeper symlinks postgresql.auto.conf to /dev/null # making ALTER SYSTEM fail with InternalServerError, # see https://github.com/sorintlab/stolon/pull/343 if 'could not fsync file "postgresql.auto.conf"' in e.args[0]: configfile_access = False else: raise else: configfile_access = True if cur_cluster_name: cn = pgcommon.quote_literal( cur_cluster_name.decode("utf-8")) await conn.sql_execute( f""" ALTER SYSTEM SET cluster_name = {cn} """.encode("utf-8"), ) else: await conn.sql_execute( b""" ALTER SYSTEM SET cluster_name = DEFAULT """, ) if configfile_access: caps |= pgparams.BackendCapabilities.CONFIGFILE_ACCESS await conn.sql_execute(b"START TRANSACTION") rname = str(uuidgen.uuid1mc()) try: await conn.sql_execute( f"CREATE ROLE {pgcommon.quote_ident(rname)} WITH SUPERUSER" .encode("utf-8"), ) except pgcon.BackendPrivilegeError: can_make_superusers = False except pgcon.BackendError as e: if e.code_is( errors.ERROR_INTERNAL_ERROR ) and "not in permitted superuser list" in str(e): # DigitalOcean raises a custom error: # XX000: Role ... 
not in permitted superuser list can_make_superusers = False else: raise else: can_make_superusers = True finally: await conn.sql_execute(b"ROLLBACK") if can_make_superusers: caps |= pgparams.BackendCapabilities.SUPERUSER_ACCESS coll = await conn.sql_fetch_val(b""" SELECT collname FROM pg_collation WHERE lower(replace(collname, '-', '')) = 'c.utf8' LIMIT 1; """) if coll is not None: caps |= pgparams.BackendCapabilities.C_UTF8_LOCALE roles = json.loads(await conn.sql_fetch_val( b""" SELECT json_build_object( 'rolcreaterole', rolcreaterole, 'rolcreatedb', rolcreatedb ) FROM pg_roles WHERE rolname = (SELECT current_user); """, )) if roles['rolcreaterole']: caps |= pgparams.BackendCapabilities.CREATE_ROLE if roles['rolcreatedb']: caps |= pgparams.BackendCapabilities.CREATE_DATABASE stats_ver = await conn.sql_fetch_val(b""" SELECT default_version FROM pg_available_extensions WHERE name = 'edb_stat_statements'; """) if stats_ver in (b"1.0",): caps |= pgparams.BackendCapabilities.STAT_STATEMENTS return caps async def _get_pg_settings( conn: pgcon.PGConnection, name: str, ) -> str: return await conn.sql_fetch_val( # type: ignore b"SELECT setting FROM pg_settings WHERE name = $1", args=[name.encode("utf-8")], ) async def _get_reserved_connections( conn: pgcon.PGConnection, ) -> int: rv = int( await _get_pg_settings(conn, 'superuser_reserved_connections') ) for name in [ 'rds.rds_superuser_reserved_connections', ]: value = await _get_pg_settings(conn, name) if value: rv += int(value) return rv probe_connection = pgconnparams.ConnectionParams(dsn=dsn) conn = await pgcon.pg_connect( probe_connection, source_description="remote cluster probe", backend_params=pgparams.get_default_runtime_params(), apply_init_script=False ) params = conn.connection addr = conn.addr try: data = json.loads(await conn.sql_fetch_val( b""" SELECT json_build_object( 'user', current_user, 'dbname', current_database(), 'connlimit', ( select rolconnlimit from pg_roles where rolname = current_user ) )""", 
)) params.update( user=data["user"], database=data["dbname"] ) cluster_type, superuser_name = await _get_cluster_type(conn) max_connections = data["connlimit"] # pg_settings values arrive as text; convert to int so the # min() comparison against the integer rolconnlimit is valid. pg_max_connections = int( await _get_pg_settings(conn, 'max_connections') ) if max_connections == -1 or not isinstance(max_connections, int): max_connections = pg_max_connections else: max_connections = min(max_connections, pg_max_connections) capabilities = await _detect_capabilities(conn) if ( specified_capabilities is not None and specified_capabilities.must_be_absent ): disabled = [] for cap in specified_capabilities.must_be_absent: if capabilities & cap: capabilities &= ~cap disabled.append(cap) if disabled: logger.info( f"the following backend capabilities are explicitly " f"disabled by server command line: " f"{', '.join(str(cap.name) for cap in disabled)}" ) if t_id != buildmeta.get_default_tenant_id(): # GOTCHA: This tenant_id check cannot protect us from running # multiple EdgeDB servers using the default tenant_id with # different catalog versions on the same backend. However, that # would fail during bootstrap in single-role/database mode. if not capabilities & pgparams.BackendCapabilities.CREATE_ROLE: raise ClusterError( "The remote backend doesn't support CREATE ROLE; " "multi-tenancy is disabled." ) if not capabilities & pgparams.BackendCapabilities.CREATE_DATABASE: raise ClusterError( "The remote backend doesn't support CREATE DATABASE; " "multi-tenancy is disabled." ) pg_ver_string = conn.get_server_parameter_status("server_version") if pg_ver_string is None: raise ClusterError( "remote server did not report its version " "in ParameterStatus") if capabilities & pgparams.BackendCapabilities.CREATE_DATABASE: # If we can create databases, assume we're free to create # extensions in them as well.
ext_schema = "edgedbext" existing_exts = {} else: ext_schema = (await conn.sql_fetch_val( b''' SELECT COALESCE( (SELECT schema_name FROM information_schema.schemata WHERE schema_name = 'heroku_ext'), 'edgedbext') ''', )).decode("utf-8") existing_exts_data = await conn.sql_fetch( b""" SELECT extname, nspname FROM pg_extension INNER JOIN pg_namespace ON (pg_extension.extnamespace = pg_namespace.oid) """ ) existing_exts = { r[0].decode("utf-8"): r[1].decode("utf-8") for r in existing_exts_data } instance_params = pgparams.BackendInstanceParams( capabilities=capabilities, version=buildmeta.parse_pg_version(pg_ver_string), base_superuser=superuser_name, max_connections=int(max_connections), reserved_connections=await _get_reserved_connections(conn), tenant_id=t_id, ext_schema=ext_schema, existing_exts=existing_exts, ) finally: conn.terminate() return cluster_type( connection_addr=addr, connection_params=params, instance_params=instance_params, ha_backend=ha_backend, ) async def _run_logged_text_subprocess( args: Sequence[str], logger: logging.Logger, level: int = logging.DEBUG, check: bool = True, log_stdout: bool = True, timeout: Optional[float] = None, **kwargs: Any, ) -> tuple[list[str], list[str], int]: stdout_lines, stderr_lines, exit_code = await _run_logged_subprocess( args, logger=logger, level=level, check=check, log_stdout=log_stdout, timeout=timeout, **kwargs, ) return ( [line.decode() for line in stdout_lines], [line.decode() for line in stderr_lines], exit_code, ) async def _run_logged_subprocess( args: Sequence[str], logger: logging.Logger, level: int = logging.DEBUG, check: bool = True, log_stdout: bool = True, log_stderr: bool = True, capture_stdout: bool = True, capture_stderr: bool = True, timeout: Optional[float] = None, stdin: Any = asyncio.subprocess.PIPE, **kwargs: Any, ) -> tuple[list[bytes], list[bytes], int]: process, stdout_reader, stderr_reader = await _start_logged_subprocess( args, logger=logger, level=level, log_stdout=log_stdout, 
log_stderr=log_stderr, capture_stdout=capture_stdout, capture_stderr=capture_stderr, stdin=stdin, **kwargs, ) if isinstance(stdin, int) and stdin >= 0: os.close(stdin) exit_code, stdout_lines, stderr_lines = await asyncio.wait_for( asyncio.gather(process.wait(), stdout_reader, stderr_reader), timeout=timeout, ) if exit_code != 0 and check: stderr_text = b'\n'.join(stderr_lines).decode() raise ClusterError( f'{args[0]} exited with status {exit_code}:\n' + textwrap.indent(stderr_text, ' ' * 4), ) else: return stdout_lines, stderr_lines, exit_code async def _start_logged_subprocess( args: Sequence[str], *, logger: logging.Logger, level: int = logging.DEBUG, override_stdout: Any = None, override_stderr: Any = None, log_stdout: bool = True, log_stderr: bool = True, capture_stdout: bool = True, capture_stderr: bool = True, stdin: Any = asyncio.subprocess.PIPE, log_processor: Optional[Callable[[str], tuple[str, int]]] = None, **kwargs: Any, ) -> tuple[ asyncio.subprocess.Process, Coroutine[Any, Any, list[bytes]], Coroutine[Any, Any, list[bytes]], ]: logger.log( level, f'running `{" ".join(shlex.quote(arg) for arg in args)}`' ) process = await asyncio.create_subprocess_exec( *args, stdin=stdin, stdout=( override_stdout if override_stdout else asyncio.subprocess.PIPE if log_stdout or capture_stdout else asyncio.subprocess.DEVNULL ), stderr=( override_stderr if override_stderr else asyncio.subprocess.PIPE if log_stderr or capture_stderr else asyncio.subprocess.DEVNULL ), limit=2 ** 20, # 1 MiB **kwargs, ) if log_stderr or capture_stderr: assert override_stderr is None assert process.stderr is not None stderr_reader = _capture_and_log_subprocess_output( process.pid, process.stderr, logger, level, log_processor, capture_output=capture_stderr, log_output=log_stderr, ) else: stderr_reader = _dummy() if log_stdout or capture_stdout: assert override_stdout is None assert process.stdout is not None stdout_reader = _capture_and_log_subprocess_output( process.pid, process.stdout, 
logger, level, log_processor, capture_output=capture_stdout, log_output=log_stdout, ) else: stdout_reader = _dummy() return process, stdout_reader, stderr_reader async def _capture_and_log_subprocess_output( pid: int, stream: asyncio.StreamReader, logger: logging.Logger, level: int, log_processor: Optional[Callable[[str], tuple[str, int]]] = None, *, capture_output: bool, log_output: bool, ) -> list[bytes]: lines = [] while not stream.at_eof(): line = await _safe_readline(stream) if line or not stream.at_eof(): line = line.rstrip(b'\n') if capture_output: lines.append(line) if log_output: log_line = line.decode() if log_processor is not None: log_line, level = log_processor(log_line) logger.log(level, log_line, extra={"process": pid}) return lines async def _safe_readline(stream: asyncio.StreamReader) -> bytes: try: line = await stream.readline() except ValueError: line = b"" return line async def _dummy() -> list[bytes]: return [] postgres_to_python_level_map = { "DEBUG5": logging.DEBUG, "DEBUG4": logging.DEBUG, "DEBUG3": logging.DEBUG, "DEBUG2": logging.DEBUG, "DEBUG1": logging.DEBUG, "INFO": logging.INFO, "NOTICE": logging.INFO, "LOG": logging.INFO, "WARNING": logging.WARNING, "ERROR": logging.ERROR, "FATAL": logging.CRITICAL, "PANIC": logging.CRITICAL, } postgres_log_re = re.compile(r'^(\w+):\s*(.*)$') postgres_specific_msg_level_map = { "terminating connection due to administrator command": logging.INFO, "the database system is shutting down": logging.INFO, } def postgres_log_processor(msg: str) -> tuple[str, int]: if m := postgres_log_re.match(msg): postgres_level = m.group(1) msg = m.group(2) level = postgres_specific_msg_level_map.get( msg, postgres_to_python_level_map.get(postgres_level, logging.INFO), ) else: level = logging.INFO return msg, level ================================================ FILE: edb/server/pgcon/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. 
# # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from .errors import ( BackendError, BackendConnectionError, BackendPrivilegeError, BackendCatalogNameError, ) from .pgcon import ( PGConnection, ) from .connect import ( pg_connect, SETUP_TEMP_TABLE_SCRIPT, SETUP_CONFIG_CACHE_SCRIPT, SETUP_DML_DUMMY_TABLE_SCRIPT, RESET_STATIC_CFG_SCRIPT, ) __all__ = ( 'pg_connect', 'PGConnection', 'BackendError', 'BackendConnectionError', 'BackendPrivilegeError', 'BackendCatalogNameError', 'SETUP_TEMP_TABLE_SCRIPT', 'SETUP_CONFIG_CACHE_SCRIPT', 'SETUP_DML_DUMMY_TABLE_SCRIPT', 'RESET_STATIC_CFG_SCRIPT' ) ================================================ FILE: edb/server/pgcon/connect.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations import logging import textwrap from edb.pgsql.common import quote_ident as pg_qi from edb.pgsql.common import versioned_schema from edb.pgsql import params as pg_params from edb.server import pgcon from . import errors as pgerror from . import rust_transport logger = logging.getLogger('edb.server') INIT_CON_SCRIPT: bytes | None = None # The '_edgecon_state' table is used to store information about # the current session. The `type` column is one character, with one # of the following values: # # * 'C': a session-level config setting # # * 'B': a session-level config setting that's implemented by setting # a corresponding Postgres config setting. # * 'A': an instance-level config setting from command-line arguments # * 'E': an instance-level config setting from environment variable # * 'F': an instance/tenant-level config setting from the TOML config file # # * 'S': a session-level pg-ext config setting (frontend-only) # * 'L': a transaction-level pg-ext config setting (frontend-only) # * 'P': a session-level pg-ext config setting that's implemented by setting # a corresponding backend config setting (not stored in _edgecon_state) # # Please also update ConStateType in edb/server/config/__init__.py if changed. SETUP_TEMP_TABLE_SCRIPT = ''' CREATE TEMPORARY TABLE _edgecon_state ( name text NOT NULL, value jsonb NOT NULL, type text NOT NULL CHECK( type = 'C' OR type = 'B' OR type = 'A' OR type = 'E' OR type = 'F' OR type = 'S' OR type = 'L'), UNIQUE(name, type) ); '''.strip() SETUP_CONFIG_CACHE_SCRIPT = ''' CREATE TEMPORARY TABLE _config_cache ( source edgedb._sys_config_source_t, value edgedb._sys_config_val_t NOT NULL ); '''.strip() # An empty dummy table used when we need to emit no-op DML. # # This is used by scan_check_ctes in the pgsql compiler to # force the evaluation of error checking.
SETUP_DML_DUMMY_TABLE_SCRIPT = ''' CREATE TEMPORARY TABLE _dml_dummy ( id int8, flag bool, unique(id) ); INSERT INTO _dml_dummy VALUES (0, false); '''.strip() RESET_STATIC_CFG_SCRIPT: bytes = b''' WITH x1 AS ( DELETE FROM _config_cache ) DELETE FROM _edgecon_state WHERE type = 'A' OR type = 'E' OR type = 'F'; ''' def _build_init_con_script(*, check_pg_is_in_recovery: bool) -> bytes: if check_pg_is_in_recovery: pg_is_in_recovery = (''' SELECT CASE WHEN pg_is_in_recovery() THEN edgedb.raise( NULL::bigint, 'read_only_sql_transaction', msg => 'cannot use a hot standby' ) END; ''').strip() else: pg_is_in_recovery = '' edgedb_schema = versioned_schema('edgedb') return textwrap.dedent(f''' {pg_is_in_recovery} {SETUP_TEMP_TABLE_SCRIPT} {SETUP_CONFIG_CACHE_SCRIPT} {SETUP_DML_DUMMY_TABLE_SCRIPT} CREATE CONSTRAINT TRIGGER _edgecon_state_local_reset AFTER INSERT ON _edgecon_state DEFERRABLE INITIALLY DEFERRED FOR EACH ROW EXECUTE FUNCTION {edgedb_schema}._clear_fe_local_sql_settings(); PREPARE _clear_state AS WITH x1 AS ( DELETE FROM _config_cache ) DELETE FROM _edgecon_state WHERE type = 'C' OR type = 'B' or type = 'S'; PREPARE _apply_state(jsonb) AS INSERT INTO _edgecon_state(name, value, type) SELECT (CASE WHEN e->'type' = '"B"'::jsonb THEN edgedb._apply_session_config(e->>'name', e->'value') ELSE e->>'name' END) AS name, e->'value' AS value, e->>'type' AS type FROM jsonb_array_elements($1::jsonb) AS e; PREPARE _reset_session_config AS SELECT edgedb._reset_session_config(); PREPARE _apply_sql_state(jsonb) AS WITH be AS ( SELECT pg_catalog.set_config( e->>'name', e->>'value', false ) AS value FROM jsonb_array_elements($1::jsonb) AS e WHERE e->'type' = '"P"'::jsonb ), fe AS ( INSERT INTO _edgecon_state(name, value, type) SELECT e->>'name' AS name, e->'value' AS value, e->>'type' AS type FROM jsonb_array_elements($1::jsonb) AS e WHERE e->'type' = '"S"'::jsonb RETURNING 1 ) SELECT 1 FROM be, fe; PREPARE _set_sql_state(text, text, text) AS INSERT INTO _edgecon_state(name, type, 
value) VALUES ($1, $2, pg_catalog.to_jsonb($3)) ON CONFLICT (name, type) DO UPDATE SET value = pg_catalog.to_jsonb($3); PREPARE _reset_sql_state(text, text) AS DELETE FROM _edgecon_state WHERE name = $1 AND type = $2; PREPARE _reset_sql_state_all AS DELETE FROM _edgecon_state WHERE type = 'S' OR type = 'L'; ''').strip().encode('utf-8') async def pg_connect( dsn_or_connection: str | rust_transport.ConnectionParams, *, backend_params: pg_params.BackendRuntimeParams, source_description: str, apply_init_script: bool = True, ) -> pgcon.PGConnection: global INIT_CON_SCRIPT if isinstance(dsn_or_connection, str): connection = rust_transport.ConnectionParams(dsn=dsn_or_connection) else: connection = dsn_or_connection # Note that we intentionally differ from the libpq connection behaviour # here: if we fail to connect with SSL enabled, we DO NOT retry with SSL # disabled. pgrawcon, pgconn = await rust_transport.create_postgres_connection( connection, lambda: pgcon.PGConnection(dbname=connection.database), source_description=source_description, ) connection = pgrawcon.connection pgconn.connection = pgrawcon.connection pgconn.parameter_status = pgrawcon.state.parameters cancellation_key = pgrawcon.state.cancellation_key if cancellation_key: pgconn.backend_pid = cancellation_key[0] pgconn.backend_secret = cancellation_key[1] pgconn.is_ssl = pgrawcon.state.ssl pgconn.addr = pgrawcon.addr if ( backend_params.has_create_role and backend_params.session_authorization_role ): sup_role = backend_params.session_authorization_role if connection.user != sup_role: # We used to use SET SESSION AUTHORIZATION here, there're some # security differences over SET ROLE, but as we don't allow # accessing Postgres directly through EdgeDB, SET ROLE is mostly # fine here. 
(Also hosted backends like Postgres on DigitalOcean # support only SET ROLE) await pgconn.sql_execute(f'SET ROLE {pg_qi(sup_role)}'.encode()) if 'in_hot_standby' in pgconn.parameter_status: # in_hot_standby is always present in Postgres 14 and above if pgconn.parameter_status['in_hot_standby'] == 'on': # Abort if we're connecting to a hot standby pgconn.terminate() raise pgerror.BackendError( fields=dict( M="cannot use a hot standby", C=pgerror.ERROR_READ_ONLY_SQL_TRANSACTION, ) ) if apply_init_script: if INIT_CON_SCRIPT is None: INIT_CON_SCRIPT = _build_init_con_script( # On lower versions of Postgres we use pg_is_in_recovery() to # check if it is a hot standby, and error out if it is. check_pg_is_in_recovery=( 'in_hot_standby' not in pgconn.parameter_status ), ) try: try: await pgconn.sql_execute(INIT_CON_SCRIPT) except pgcon.BackendError: from edb.pgsql import dbops, metaschema # ClearFELocalSQLSettingsFunction is needed by the # INIT_CON_SCRIPT, so we cannot simply patch it up # in the regular edb/pgsql/patches.py block = dbops.PLTopBlock() func = metaschema.ClearFELocalSQLSettingsFunction() dbops.CreateFunction(func, or_replace=True).generate(block) await pgconn.sql_execute(block.to_string().encode('utf-8')) await pgconn.sql_execute(INIT_CON_SCRIPT) except Exception: logger.exception( f"Failed to run init script for {pgconn.connection.to_dsn()}" ) await pgconn.close() raise return pgconn ================================================ FILE: edb/server/pgcon/cpythonx.pxd ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2020-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # cdef extern from "Python.h": void* PyMem_Calloc(size_t nelem, size_t elsize) ================================================ FILE: edb/server/pgcon/errors.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2008-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations ERROR_FEATURE_NOT_SUPPORTED = '0A000' ERROR_CARDINALITY_VIOLATION = '21000' # Class 22 — Data Exception ERROR_DATA_EXCEPTION = '22000' ERROR_NUMERIC_VALUE_OUT_OF_RANGE = '22003' ERROR_INVALID_DATETIME_FORMAT = '22007' ERROR_DATETIME_FIELD_OVERFLOW = '22008' ERROR_DIVISION_BY_ZERO = '22012' ERROR_INTERVAL_FIELD_OVERFLOW = '22015' ERROR_CHARACTER_NOT_IN_REPERTOIRE = '22021' ERROR_INVALID_PARAMETER_VALUE = '22023' ERROR_INVALID_TEXT_REPRESENTATION = '22P02' ERROR_INVALID_REGULAR_EXPRESSION = '2201B' ERROR_INVALID_LOGARITHM_ARGUMENT = '2201E' ERROR_INVALID_POWER_ARGUMENT = '2201F' ERROR_INVALID_ROW_COUNT_IN_LIMIT_CLAUSE = '2201W' ERROR_INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE = '2201X' # Class 23 — Integrity Constraint Violation ERROR_INTEGRITY_CONSTRAINT_VIOLATION = '23000' ERROR_RESTRICT_VIOLATION = '23001' ERROR_NOT_NULL_VIOLATION = '23502' ERROR_FOREIGN_KEY_VIOLATION = '23503' ERROR_UNIQUE_VIOLATION = '23505' ERROR_CHECK_VIOLATION = '23514' ERROR_EXCLUSION_VIOLATION = '23P01' # Class 25 - Invalid Transaction State ERRCODE_IN_FAILED_SQL_TRANSACTION = '25P02' ERROR_IDLE_IN_TRANSACTION_TIMEOUT = '25P03' ERROR_READ_ONLY_SQL_TRANSACTION = '25006' ERROR_INVALID_SQL_STATEMENT_NAME = '26000' # Class 28 - Invalid Authorization Specification ERROR_INVALID_AUTHORIZATION_SPECIFICATION = '28000' ERROR_INVALID_PASSWORD = '28P01' ERROR_INVALID_CATALOG_NAME = '3D000' ERROR_INVALID_CURSOR_NAME = '34000' ERROR_SERIALIZATION_FAILURE = '40001' ERROR_DEADLOCK_DETECTED = '40P01' # Class 42 - Syntax Error or Access Rule Violation ERROR_WRONG_OBJECT_TYPE = '42809' ERROR_INSUFFICIENT_PRIVILEGE = '42501' ERROR_UNDEFINED_COLUMN = '42703' ERROR_UNDEFINED_FUNCTION = '42883' ERROR_UNDEFINED_TABLE = '42P01' ERROR_UNDEFINED_PARAMETER = '42P02' ERROR_DUPLICATE_DATABASE = '42P04' ERROR_SYNTAX_ERROR = '42601' ERROR_DUPLICATE_CURSOR = '42P03' ERROR_DUPLICATE_PREPARED_STATEMENT = '42P05' ERROR_INVALID_COLUMN_REFERENCE = '42P10' ERROR_PROGRAM_LIMIT_EXCEEDED = 
'54000' ERROR_OBJECT_IN_USE = '55006' ERROR_QUERY_CANCELLED = '57014' ERROR_CANNOT_CONNECT_NOW = '57P03' # Class 08 - Connection Exception ERROR_CONNECTION_CLIENT_CANNOT_CONNECT = '08001' ERROR_CONNECTION_DOES_NOT_EXIST = '08003' ERROR_CONNECTION_REJECTION = '08004' ERROR_CONNECTION_FAILURE = '08006' ERROR_PROTOCOL_VIOLATION = '08P01' ERROR_INTERNAL_ERROR = 'XX000' CONNECTION_ERROR_CODES = [ ERROR_CANNOT_CONNECT_NOW, ERROR_CONNECTION_CLIENT_CANNOT_CONNECT, ERROR_CONNECTION_DOES_NOT_EXIST, ERROR_CONNECTION_REJECTION, ERROR_CONNECTION_FAILURE, ] class BackendError(Exception): def __init__(self, *, fields: dict[str, str]) -> None: msg = fields.get('M', f'error code {fields["C"]}') self.fields = fields super().__init__(msg) def code_is(self, code: str) -> bool: return self.fields["C"] == code def get_field(self, field: str) -> str | None: return self.fields.get(field) def get_error_class(fields: dict[str, str]) -> type[BackendError]: return error_class_map.get(fields["C"], BackendError) class BackendQueryCancelledError(BackendError): pass class BackendConnectionError(BackendError): pass class BackendPrivilegeError(BackendError): pass class BackendCatalogNameError(BackendError): pass error_class_map = { ERROR_CANNOT_CONNECT_NOW: BackendConnectionError, ERROR_CONNECTION_CLIENT_CANNOT_CONNECT: BackendConnectionError, ERROR_CONNECTION_DOES_NOT_EXIST: BackendConnectionError, ERROR_CONNECTION_REJECTION: BackendConnectionError, ERROR_CONNECTION_FAILURE: BackendConnectionError, ERROR_INSUFFICIENT_PRIVILEGE: BackendPrivilegeError, ERROR_QUERY_CANCELLED: BackendQueryCancelledError, ERROR_INVALID_CATALOG_NAME: BackendCatalogNameError, } def _build_fields(code, message, severity="ERROR", detail=None, hint=None): fields = { "S": severity, "V": severity, "C": code, "M": message, } if detail is not None: fields["D"] = detail if hint is not None: fields["H"] = hint return fields def new( code, message, severity="ERROR", detail=None, hint=None, **extra_fields ): fields = 
_build_fields(code, message, severity, detail, hint) fields.update(extra_fields) return get_error_class(fields)(fields=fields) class FeatureNotSupported(BackendError): def __init__(self, message="feature not supported", **kwargs): super().__init__( fields=_build_fields(ERROR_FEATURE_NOT_SUPPORTED, message, **kwargs) ) class ProtocolViolation(BackendError): def __init__(self, message="protocol violation", **kwargs): super().__init__( fields=_build_fields(ERROR_PROTOCOL_VIOLATION, message, **kwargs) ) class CannotConnectNowError(BackendError): def __init__(self, message="cannot connect now", **kwargs): super().__init__( fields=_build_fields(ERROR_CANNOT_CONNECT_NOW, message, **kwargs) ) class InvalidAuthSpec(BackendError): def __init__(self, message="invalid authorization specification", **kwargs): super().__init__( fields=_build_fields( ERROR_INVALID_AUTHORIZATION_SPECIFICATION, message, **kwargs ) ) ================================================ FILE: edb/server/pgcon/pgcon.pxd ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# include "./pgcon_sql.pxd" cimport cython cimport cpython from libc.stdint cimport int8_t, uint8_t, int16_t, uint16_t, \ int32_t, uint32_t, int64_t, uint64_t from edb.server.pgproto.pgproto cimport ( WriteBuffer, ReadBuffer, FRBuffer, ) from edb.server.dbview cimport dbview from edb.server.pgproto.debug cimport PG_DEBUG from edb.server.cache cimport stmt_cache cdef enum PGTransactionStatus: PQTRANS_IDLE = 0 # connection idle PQTRANS_ACTIVE = 1 # command in progress PQTRANS_INTRANS = 2 # idle, within transaction block PQTRANS_INERROR = 3 # idle, within failed transaction PQTRANS_UNKNOWN = 4 # cannot determine status cdef enum PGAuthenticationState: PGAUTH_SUCCESSFUL = 0 PGAUTH_REQUIRED_KERBEROS = 2 PGAUTH_REQUIRED_PASSWORD = 3 PGAUTH_REQUIRED_PASSWORDMD5 = 5 PGAUTH_REQUIRED_SCMCRED = 6 PGAUTH_REQUIRED_GSS = 7 PGAUTH_REQUIRED_GSS_CONTINUE = 8 PGAUTH_REQUIRED_SSPI = 9 PGAUTH_REQUIRED_SASL = 10 PGAUTH_SASL_CONTINUE = 11 PGAUTH_SASL_FINAL = 12 @cython.final cdef class PGConnection: cdef: ReadBuffer buffer object loop str dbname object transport object msg_waiter readonly bint connected object connected_fut int32_t waiting_for_sync PGTransactionStatus xact_status public int32_t backend_pid public int32_t backend_secret public object parameter_status readonly object aborted_with_error stmt_cache.StatementsCache prep_stmts list last_parse_prep_stmts list log_listeners bint debug public object connection public object addr object server object tenant bint is_system_db bint close_requested readonly bint idle object cancel_fut bint _is_ssl public object pinned_by object last_state bint state_reset_needs_commit public object last_init_con_data str last_indirect_return PGSQLConnection _sql cdef before_command(self) cdef write(self, buf) cdef write_sync(self, WriteBuffer outbuf) cdef parse_error_message(self) cdef char parse_sync_message(self) cdef parse_parameter_status_message(self) cdef parse_notification(self) cdef fallthrough(self) cdef fallthrough_idle(self) cdef bint 
before_prepare( self, bytes stmt_name, int dbver, WriteBuffer outbuf) cdef write_sync(self, WriteBuffer outbuf) cdef send_sync(self) cdef make_clean_stmt_message(self, bytes stmt_name) cdef send_query_unit_group( self, object query_unit_group, bint sync, object bind_datas, bytes state, ssize_t start, ssize_t end, int dbver, object parse_array, object query_prefix, bint needs_commit_state, ) cdef _rewrite_copy_data( self, WriteBuffer wbuf, char *data, ssize_t data_len, ssize_t ncols, tuple elide_cols, dict type_id_map, tuple data_mending_desc, ) cdef _mend_copy_datum( self, WriteBuffer wbuf, FRBuffer *rbuf, object mending_desc, dict type_id_map, ) cdef inline str get_tenant_label(self) cpdef set_stmt_cache_size(self, int maxsize) ================================================ FILE: edb/server/pgcon/pgcon.pyi ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Callable, Optional, ) import asyncio from edb.server import defines as edbdef from edb.server import pgconnparams class BackendError(Exception): def get_field(self, field: str) -> str | None: ... def code_is(self, code: str) -> bool: ... class BackendConnectionError(BackendError): ... class BackendPrivilegeError(BackendError): ... class BackendCatalogNameError(BackendError): ... 
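# Illustrative sketch, not part of the stub: how a SQLSTATE code carried in
# the "C" field can select an exception subclass, in the spirit of
# errors.error_class_map and get_error_class() in errors.py. Every name
# prefixed with Sketch/SKETCH/sketch_ is hypothetical and exists only for
# this example; the two SQLSTATE values are copied from errors.py.

```python
from typing import Optional

SKETCH_QUERY_CANCELLED = "57014"          # ERROR_QUERY_CANCELLED in errors.py
SKETCH_INSUFFICIENT_PRIVILEGE = "42501"   # ERROR_INSUFFICIENT_PRIVILEGE


class SketchBackendError(Exception):
    def __init__(self, *, fields: dict) -> None:
        # "M" holds the human-readable message, "C" the SQLSTATE code.
        self.fields = fields
        super().__init__(fields.get("M", f'error code {fields["C"]}'))

    def code_is(self, code: str) -> bool:
        return self.fields["C"] == code

    def get_field(self, field: str) -> Optional[str]:
        return self.fields.get(field)


class SketchQueryCancelledError(SketchBackendError):
    pass


class SketchPrivilegeError(SketchBackendError):
    pass


# Codes without an entry fall back to the base class.
SKETCH_CLASS_MAP = {
    SKETCH_QUERY_CANCELLED: SketchQueryCancelledError,
    SKETCH_INSUFFICIENT_PRIVILEGE: SketchPrivilegeError,
}


def sketch_error_from_fields(fields: dict) -> SketchBackendError:
    cls = SKETCH_CLASS_MAP.get(fields["C"], SketchBackendError)
    return cls(fields=fields)
```

# Callers can then catch the specific subclass, or catch the base class and
# branch on code_is() when no dedicated subclass exists for a code.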
class PGConnection(asyncio.Protocol): idle: bool backend_pid: int connection: pgconnparams.ConnectionParams addr: tuple[str, int] parameter_status: dict[str, str] backend_secret: int is_ssl: bool last_init_con_data: object def __init__(self, dbname): ... async def close(self): ... async def sql_execute( self, sql: bytes | tuple[bytes, ...], *, tx_isolation: edbdef.TxIsolationLevel | None = None, ) -> None: ... async def sql_fetch( self, sql: bytes, *, args: tuple[bytes, ...] | list[bytes] = (), use_prep_stmt: bool = False, state: Optional[bytes] = None, tx_isolation: edbdef.TxIsolationLevel | None = None, ) -> list[tuple[bytes, ...]]: ... async def sql_fetch_val( self, sql: bytes, *, args: tuple[bytes, ...] | list[bytes] = (), use_prep_stmt: bool = False, state: Optional[bytes] = None, tx_isolation: edbdef.TxIsolationLevel | None = None, ) -> bytes: ... async def sql_fetch_col( self, sql: bytes, *, args: tuple[bytes, ...] | list[bytes] = (), use_prep_stmt: bool = False, state: Optional[bytes] = None, tx_isolation: edbdef.TxIsolationLevel | None = None, ) -> list[bytes]: ... async def sql_describe( self, sql: bytes, param_type_oids: list[int] | None = None, ) -> tuple[list[int], list[tuple[str, int]]]: ... def terminate(self) -> None: ... def add_log_listener(self, cb: Callable[[str, str], None]) -> None: ... def get_server_parameter_status(self, parameter: str) -> Optional[str]: ... def set_stmt_cache_size(self, size: int) -> None: ... def set_server(self, server: object) -> None: ... async def signal_sysevent(self, event: str, *, dbname: str) -> None: ... def abort(self) -> None: ... def is_healthy(self) -> bool: ... async def listen_for_sysevent(self) -> None: ... def mark_as_system_db(self) -> None: ... def set_tenant(self, tenant: Any) -> None: ... def is_cancelling(self) -> bool: ... def start_pg_cancellation(self) -> None: ... def finish_pg_cancellation(self) -> None: ... 
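# Illustrative sketch, not part of the stub: the guard behind
# is_cancelling() / start_pg_cancellation() / finish_pg_cancellation(),
# reduced to a standalone class. SketchCancellationGuard and sketch_demo are
# hypothetical names. Where the real connection clears its future is not
# shown in this excerpt, so the sketch leaves it set after finish().

```python
import asyncio


class SketchCancellationGuard:
    def __init__(self, loop):
        self._loop = loop
        self._cancel_fut = None

    def is_cancelling(self):
        return self._cancel_fut is not None

    def start(self):
        # Only one cancellation may be in flight at a time.
        if self._cancel_fut is not None:
            raise RuntimeError('another cancellation is in progress')
        self._cancel_fut = self._loop.create_future()

    def finish(self):
        # Wake up whoever awaits the cancellation future.
        assert self._cancel_fut is not None
        self._cancel_fut.set_result(True)


async def sketch_demo():
    guard = SketchCancellationGuard(asyncio.get_running_loop())
    guard.start()
    try:
        guard.start()
    except RuntimeError:
        second_rejected = True
    else:
        second_rejected = False
    fut = guard._cancel_fut
    guard.finish()
    return second_rejected, await fut
```

# Running asyncio.run(sketch_demo()) exercises both the rejection of an
# overlapping cancellation and the completion of the pending future.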
SETUP_TEMP_TABLE_SCRIPT: str SETUP_CONFIG_CACHE_SCRIPT: str SETUP_DML_DUMMY_TABLE_SCRIPT: str ================================================ FILE: edb/server/pgcon/pgcon.pyx ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from typing import ( Any, Callable, Dict, Optional, ) import asyncio import contextlib import decimal import codecs import hashlib import json import logging import os.path import sys import struct import textwrap import time cimport cython cimport cpython from . 
cimport cpythonx from libc.stdint cimport int8_t, uint8_t, int16_t, uint16_t, \ int32_t, uint32_t, int64_t, uint64_t, \ UINT32_MAX from edb import errors from edb.edgeql import qltypes from edb.pgsql import common as pgcommon from edb.pgsql.common import quote_literal as pg_ql from edb.pgsql import codegen as pg_codegen from edb.server.pgproto cimport hton from edb.server.pgproto cimport pgproto from edb.server.pgproto.pgproto cimport ( WriteBuffer, ReadBuffer, FRBuffer, frb_init, frb_read, frb_get_len, frb_slice_from, ) from edb.server import compiler from edb.server.compiler import dbstate from edb.server import defines from edb.server.cache cimport stmt_cache from edb.server.dbview cimport dbview from edb.server.protocol cimport args_ser from edb.server.protocol cimport pg_ext from edb.server import metrics from edb.server.protocol cimport frontend from edb.common import debug from . import errors as pgerror DEF DATA_BUFFER_SIZE = 100_000 DEF PREP_STMTS_CACHE = 100 DEF COPY_SIGNATURE = b"PGCOPY\n\377\r\n\0" cdef object CARD_NO_RESULT = compiler.Cardinality.NO_RESULT cdef object FMT_NONE = compiler.OutputFormat.NONE cdef dict POSTGRES_SHUTDOWN_ERR_CODES = { '57P01': 'admin_shutdown', '57P02': 'crash_shutdown', } cdef object EMPTY_SQL_STATE = b"{}" cdef WriteBuffer NO_ARGS = args_ser.combine_raw_args() cdef object logger = logging.getLogger('edb.server') include "./pgcon_sql.pyx" @cython.final cdef class EdegDBCodecContext(pgproto.CodecContext): cdef: object _codec def __cinit__(self): self._codec = codecs.lookup('utf-8') cpdef get_text_codec(self): return self._codec cdef is_encoding_utf8(self): return True @cython.final cdef class PGConnection: def __init__(self, dbname): self.buffer = ReadBuffer() self.loop = asyncio.get_running_loop() self.dbname = dbname self.connection = None self.transport = None self.msg_waiter = None self.prep_stmts = stmt_cache.StatementsCache(maxsize=PREP_STMTS_CACHE) self.connected_fut = self.loop.create_future() self.connected = False 
self.waiting_for_sync = 0 self.xact_status = PQTRANS_UNKNOWN self.backend_pid = -1 self.backend_secret = -1 self.parameter_status = dict() self.last_parse_prep_stmts = [] self.debug = debug.flags.server_proto self.last_indirect_return = None self.log_listeners = [] self.server = None self.tenant = None self.is_system_db = False self.close_requested = False self.pinned_by = None self.idle = True self.cancel_fut = None self._is_ssl = False # Set to the error the connection has been aborted with # by the backend. self.aborted_with_error = None # Session State Management # ------------------------ # Because backend sessions are not pinned to frontend sessions # (EdgeQL, SQL, etc.) outside of transactions, we need to sync the # backend state with the frontend state before executing queries. # # For performance reasons, we try to avoid syncing the state by # remembering the last state we've synced (last_state), and we prefer # a backend connection with the same state as the frontend. # # Syncing the state is done by resetting the session state as a whole, # followed by applying the new state, so that we don't have to track # individual config resets. Again for performance reasons, the state # sync is usually applied in the same implicit transaction as the # actual query in order to avoid extra round trips. # # There are exceptions, though, when we need to sync the state in a # separate transaction by inserting a SYNC message before the actual # query. This happens either because the query itself is a START # TRANSACTION / non-transactional command (plus a few other cases, see # _parse_execute() below), or because the state change affects new # transaction creation, like changing `default_transaction_isolation` # or its siblings (see the `needs_commit_state` parameters). In such # cases, we remember the `last_state` immediately after we receive the # ReadyForQuery message caused by the SYNC above, provided that no # errors happened during state sync.
Otherwise, we only remember # `last_state` after the implicit transaction ends successfully, when # we're sure the state is synced permanently. # # The actual queries may also change the session state. Regardless of # how we synced state previously, we always remember the `last_state` # after successful executions (also after transactions without errors, # implicit or explicit). # # Finally, resetting an existing session state that was positive in # `needs_commit_state` also requires a commit, because the new state # may not have `needs_commit_state`. To achieve this, we remember the # previous `needs_commit_state` in `state_reset_needs_commit` and # always insert a SYNC in the next state sync if it's True. Also, if # the actual queries modified those `default_transaction_*` settings, # we also need to set `state_reset_needs_commit` to True for the next # state sync(reset). See `needs_commit_after_state_sync()` functions # in dbview classes (EdgeQL and SQL). self.last_state = dbview.DEFAULT_STATE self.state_reset_needs_commit = False self._sql = PGSQLConnection(self) cpdef set_stmt_cache_size(self, int maxsize): self.prep_stmts.resize(maxsize) @property def is_ssl(self): return self._is_ssl @is_ssl.setter def is_ssl(self, value): self._is_ssl = value def debug_print(self, *args): print( '::PGCONN::', hex(id(self)), f'pgpid: {self.backend_pid}', *args, file=sys.stderr, ) def in_tx(self): return ( self.xact_status == PQTRANS_INTRANS or self.xact_status == PQTRANS_INERROR ) def is_cancelling(self): return self.cancel_fut is not None def start_pg_cancellation(self): if self.cancel_fut is not None: raise RuntimeError('another cancellation is in progress') self.cancel_fut = self.loop.create_future() def finish_pg_cancellation(self): assert self.cancel_fut is not None self.cancel_fut.set_result(True) def get_server_parameter_status(self, parameter: str) -> Optional[str]: return self.parameter_status.get(parameter) def abort(self): if not self.transport: return 
self.close_requested = True self.transport.abort() self.transport = None self.connected = False self.prep_stmts.clear() def terminate(self): if not self.transport: return self.close_requested = True self.write(WriteBuffer.new_message(b'X').end_message()) self.transport.close() self.transport = None self.connected = False self.prep_stmts.clear() if self.msg_waiter and not self.msg_waiter.done(): self.msg_waiter.set_exception(ConnectionAbortedError()) self.msg_waiter = None async def close(self): self.terminate() def set_tenant(self, tenant): self.tenant = tenant self.server = tenant.server def mark_as_system_db(self): if self.tenant.get_backend_runtime_params().has_create_database: assert defines.EDGEDB_SYSTEM_DB in self.dbname self.is_system_db = True def add_log_listener(self, cb): self.log_listeners.append(cb) async def listen_for_sysevent(self): try: if self.tenant.get_backend_runtime_params().has_create_database: assert defines.EDGEDB_SYSTEM_DB in self.dbname await self.sql_execute(b'LISTEN __edgedb_sysevent__;') except Exception: try: self.abort() finally: raise async def signal_sysevent(self, event, **kwargs): if self.tenant.get_backend_runtime_params().has_create_database: assert defines.EDGEDB_SYSTEM_DB in self.dbname event = json.dumps({ 'event': event, 'server_id': self.server._server_id, 'args': kwargs, }) query = f""" SELECT pg_notify( '__edgedb_sysevent__', {pg_ql(event)} ) """.encode() await self.sql_execute(query) async def sync(self): if self.waiting_for_sync: raise RuntimeError('a "sync" has already been requested') self.before_command() try: self.waiting_for_sync += 1 self.write(_SYNC_MESSAGE) while True: if not self.buffer.take_message(): await self.wait_for_message() mtype = self.buffer.get_message_type() if mtype == b'Z': self.parse_sync_message() return else: self.fallthrough() finally: await self.after_command() async def wait_for_sync(self): error = None try: while True: if not self.buffer.take_message(): await self.wait_for_message() mtype 
= self.buffer.get_message_type() if mtype == b'Z': return self.parse_sync_message() elif mtype == b'E': # ErrorResponse er_cls, fields = self.parse_error_message() error = er_cls(fields=fields) else: if not self.parse_notification(): if PG_DEBUG or self.debug: self.debug_print(f'PGCon.wait_for_sync: discarding ' f'{chr(mtype)!r} message') self.buffer.discard_message() finally: if error is not None: # Postgres might send an ErrorResponse if, e.g., # an implicit transaction fails to commit due to # serialization conflicts. raise error cdef inline str get_tenant_label(self): if self.tenant is None: return "system" else: return self.tenant.get_instance_name() cdef bint before_prepare( self, bytes stmt_name, int dbver, WriteBuffer outbuf, ): cdef bint parse = 1 while self.prep_stmts.needs_cleanup(): stmt_name_to_clean, _ = self.prep_stmts.cleanup_one() if self.debug: self.debug_print(f"discarding ps {stmt_name_to_clean!r}") outbuf.write_buffer( self.make_clean_stmt_message(stmt_name_to_clean)) if stmt_name in self.prep_stmts: if self.prep_stmts[stmt_name] == dbver: parse = 0 else: if self.debug: self.debug_print(f"discarding ps {stmt_name!r}") outbuf.write_buffer( self.make_clean_stmt_message(stmt_name)) del self.prep_stmts[stmt_name] return parse cdef write_sync(self, WriteBuffer outbuf): outbuf.write_bytes(_SYNC_MESSAGE) self.waiting_for_sync += 1 cdef send_sync(self): self.write(_SYNC_MESSAGE) self.waiting_for_sync += 1 def _build_apply_state_req(self, bytes serstate, WriteBuffer out): cdef: WriteBuffer buf if self.debug: self.debug_print("Syncing state: ", serstate) buf = WriteBuffer.new_message(b'B') buf.write_bytestring(b'') # portal name buf.write_bytestring(b'_clear_state') # statement name buf.write_int16(0) # number of format codes buf.write_int16(0) # number of parameters buf.write_int16(0) # number of result columns out.write_buffer(buf.end_message()) buf = WriteBuffer.new_message(b'E') buf.write_bytestring(b'') # portal name buf.write_int32(0) # limit: 0 - 
return all rows out.write_buffer(buf.end_message()) buf = WriteBuffer.new_message(b'B') buf.write_bytestring(b'') # portal name buf.write_bytestring(b'_reset_session_config') # statement name buf.write_int16(0) # number of format codes buf.write_int16(0) # number of parameters buf.write_int16(0) # number of result columns out.write_buffer(buf.end_message()) buf = WriteBuffer.new_message(b'E') buf.write_bytestring(b'') # portal name buf.write_int32(0) # limit: 0 - return all rows out.write_buffer(buf.end_message()) if serstate is not None: buf = WriteBuffer.new_message(b'B') buf.write_bytestring(b'') # portal name buf.write_bytestring(b'_apply_state') # statement name buf.write_int16(1) # number of format codes buf.write_int16(1) # binary buf.write_int16(1) # number of parameters buf.write_int32(len(serstate) + 1) buf.write_byte(1) # jsonb format version buf.write_bytes(serstate) buf.write_int16(0) # number of result columns out.write_buffer(buf.end_message()) buf = WriteBuffer.new_message(b'E') buf.write_bytestring(b'') # portal name buf.write_int32(0) # limit: 0 - return all rows out.write_buffer(buf.end_message()) def _build_apply_sql_state_req(self, bytes state, WriteBuffer out): cdef: WriteBuffer buf buf = WriteBuffer.new_message(b'B') buf.write_bytestring(b'') # portal name buf.write_bytestring(b'_clear_state') # statement name buf.write_int16(0) # number of format codes buf.write_int16(0) # number of parameters buf.write_int16(0) # number of result columns out.write_buffer(buf.end_message()) buf = WriteBuffer.new_message(b'E') buf.write_bytestring(b'') # portal name buf.write_int32(0) # limit: 0 - return all rows out.write_buffer(buf.end_message()) buf = WriteBuffer.new_message(b'B') buf.write_bytestring(b'') # portal name buf.write_bytestring(b'_reset_session_config') # statement name buf.write_int16(0) # number of format codes buf.write_int16(0) # number of parameters buf.write_int16(0) # number of result columns out.write_buffer(buf.end_message()) buf = 
WriteBuffer.new_message(b'E') buf.write_bytestring(b'') # portal name buf.write_int32(0) # limit: 0 - return all rows out.write_buffer(buf.end_message()) if state != EMPTY_SQL_STATE: buf = WriteBuffer.new_message(b'B') buf.write_bytestring(b'') # portal name buf.write_bytestring(b'_apply_sql_state') # statement name buf.write_int16(1) # number of format codes buf.write_int16(1) # binary buf.write_int16(1) # number of parameters buf.write_int32(len(state) + 1) buf.write_byte(1) # jsonb format version buf.write_bytes(state) buf.write_int16(0) # number of result columns out.write_buffer(buf.end_message()) buf = WriteBuffer.new_message(b'E') buf.write_bytestring(b'') # portal name buf.write_int32(0) # limit: 0 - return all rows out.write_buffer(buf.end_message()) async def _parse_apply_state_resp(self, int expected_completed): cdef: int num_completed = 0 while True: if not self.buffer.take_message(): await self.wait_for_message() mtype = self.buffer.get_message_type() if mtype == b'2' or mtype == b'D': # BindComplete or Data self.buffer.discard_message() elif mtype == b'E': er_cls, er_fields = self.parse_error_message() raise er_cls(fields=er_fields) elif mtype == b'C': self.buffer.discard_message() num_completed += 1 if num_completed == expected_completed: return else: self.fallthrough() @contextlib.asynccontextmanager async def parse_execute_script_context(self): self.before_command() started_at = time.monotonic() try: try: yield finally: while self.waiting_for_sync: await self.wait_for_sync() finally: metrics.backend_query_duration.observe( time.monotonic() - started_at, self.get_tenant_label() ) await self.after_command() cdef send_query_unit_group( self, object query_unit_group, bint sync, object bind_datas, bytes state, ssize_t start, ssize_t end, int dbver, object parse_array, object query_prefix, bint needs_commit_state, ): # parse_array is an array of booleans for output with the same size as # the query_unit_group, indicating if each unit is freshly parsed 
cdef: WriteBuffer out WriteBuffer buf WriteBuffer bind_data bytes stmt_name ssize_t idx = start bytes sql tuple sqls out = WriteBuffer.new() parsed = set() if state is not None and start == 0: self._build_apply_state_req(state, out) # N.B: Condition here needs to match that in wait_for_state_resp if needs_commit_state or self.state_reset_needs_commit: self.write_sync(out) # Build the parse_array first, closing statements if needed before # actually executing any command that may fail, in order to ensure # self.prep_stmts is always in sync with the actual open statements for query_unit in query_unit_group.units[start:end]: if query_unit.system_config: raise RuntimeError( "CONFIGURE INSTANCE command is not allowed in scripts" ) stmt_name = query_unit.sql_hash if stmt_name: # The same EdgeQL query may show up twice in the same script. # We just need to know and skip if we've already parsed the # same query within current send batch, because self.prep_stmts # will be updated before the next batch, with maybe a different # dbver after DDL. 
if stmt_name not in parsed and self.before_prepare( stmt_name, dbver, out ): parse_array[idx] = True parsed.add(stmt_name) idx += 1 idx = start for query_unit, bind_data in zip( query_unit_group.units[start:end], bind_datas): stmt_name = query_unit.sql_hash sql = query_unit.sql if query_prefix: sql = query_prefix + sql if stmt_name: if parse_array[idx]: buf = WriteBuffer.new_message(b'P') buf.write_bytestring(stmt_name) buf.write_bytestring(sql) buf.write_int16(0) out.write_buffer(buf.end_message()) metrics.query_size.observe( len(sql), self.get_tenant_label(), 'compiled', ) buf = WriteBuffer.new_message(b'B') buf.write_bytestring(b'') # portal name buf.write_bytestring(stmt_name) buf.write_buffer(bind_data) out.write_buffer(buf.end_message()) buf = WriteBuffer.new_message(b'E') buf.write_bytestring(b'') # portal name buf.write_int32(0) # limit: 0 - return all rows out.write_buffer(buf.end_message()) else: buf = WriteBuffer.new_message(b'P') buf.write_bytestring(b'') # statement name buf.write_bytestring(sql) buf.write_int16(0) out.write_buffer(buf.end_message()) metrics.query_size.observe( len(sql), self.get_tenant_label(), 'compiled' ) buf = WriteBuffer.new_message(b'B') buf.write_bytestring(b'') # portal name buf.write_bytestring(b'') # statement name buf.write_buffer(bind_data) out.write_buffer(buf.end_message()) buf = WriteBuffer.new_message(b'E') buf.write_bytestring(b'') # portal name buf.write_int32(0) # limit: 0 - return all rows out.write_buffer(buf.end_message()) idx += 1 if sync: self.write_sync(out) else: out.write_bytes(FLUSH_MESSAGE) self.write(out) async def force_error(self): self.before_command() # Send a bogus parse that will cause an error to be generated out = WriteBuffer.new() buf = WriteBuffer.new_message(b'P') buf.write_bytestring(b'') buf.write_bytestring(b'') buf.write_int16(0) # Then do a sync to get everything executed and lined back up out.write_buffer(buf.end_message()) self.write_sync(out) self.write(out) try: await 
self.wait_for_sync() except pgerror.BackendError as e: pass else: raise RuntimeError("Didn't get expected error!") finally: await self.after_command() async def wait_for_state_resp( self, bytes state, bint state_sync, bint needs_commit_state ): # N.B: Condition here needs to match that in send_query_unit_group if state_sync or self.state_reset_needs_commit: try: await self._parse_apply_state_resp(2 if state is None else 3) finally: await self.wait_for_sync() self.last_state = state self.state_reset_needs_commit = needs_commit_state else: await self._parse_apply_state_resp(2 if state is None else 3) async def wait_for_command( self, object query_unit, bint parse, int dbver, *, bint ignore_data, frontend.AbstractFrontendConnection fe_conn = None, ): cdef WriteBuffer buf = None result = None while True: if not self.buffer.take_message(): await self.wait_for_message() mtype = self.buffer.get_message_type() try: if mtype == b'D': # DataRow if ignore_data: self.buffer.discard_message() elif fe_conn is None: ncol = self.buffer.read_int16() row = [] for i in range(ncol): coll = self.buffer.read_int32() if coll == -1: row.append(None) else: row.append(self.buffer.read_bytes(coll)) if result is None: result = [] result.append(row) else: if buf is None: buf = WriteBuffer.new() self.buffer.redirect_messages(buf, b'D', 0) if buf.len() >= DATA_BUFFER_SIZE: fe_conn.write(buf) buf = None elif mtype == b'C': ## result # CommandComplete self.buffer.discard_message() if buf is not None: fe_conn.write(buf) buf = None return result elif mtype == b'1': # ParseComplete self.buffer.discard_message() if parse: self.prep_stmts[query_unit.sql_hash] = dbver elif mtype == b'E': ## result # ErrorResponse er_cls, er_fields = self.parse_error_message() raise er_cls(fields=er_fields) elif mtype == b'n': # NoData self.buffer.discard_message() elif mtype == b's': ## result # PortalSuspended self.buffer.discard_message() return result elif mtype == b'2': # BindComplete self.buffer.discard_message() 
elif mtype == b'3': # CloseComplete self.buffer.discard_message() elif mtype == b'I': ## result # EmptyQueryResponse self.buffer.discard_message() else: self.fallthrough() finally: self.buffer.finish_message() async def _describe( self, query: bytes, param_type_oids: Optional[list[int]], ): cdef: WriteBuffer out out = WriteBuffer.new() buf = WriteBuffer.new_message(b"P") # Parse buf.write_bytestring(b"") buf.write_bytestring(query) if param_type_oids: buf.write_int16(len(param_type_oids)) for oid in param_type_oids: buf.write_int32(oid) else: buf.write_int16(0) out.write_buffer(buf.end_message()) buf = WriteBuffer.new_message(b"D") # Describe buf.write_byte(b"S") buf.write_bytestring(b"") out.write_buffer(buf.end_message()) out.write_bytes(FLUSH_MESSAGE) self.write(out) param_desc = None result_desc = None try: buf = None while True: if not self.buffer.take_message(): await self.wait_for_message() mtype = self.buffer.get_message_type() try: if mtype == b'1': # ParseComplete self.buffer.discard_message() elif mtype == b't': # ParameterDescription param_desc = self._decode_param_desc(self.buffer) elif mtype == b'T': # RowDescription result_desc = self._decode_row_desc(self.buffer) break elif mtype == b'n': # NoData self.buffer.discard_message() param_desc = [] result_desc = [] break elif mtype == b'E': ## result # ErrorResponse er_cls, er_fields = self.parse_error_message() raise er_cls(fields=er_fields) else: self.fallthrough() finally: self.buffer.finish_message() except Exception: self.send_sync() await self.wait_for_sync() raise if param_desc is None: raise RuntimeError( "did not receive ParameterDescription from backend " "in response to Describe" ) if result_desc is None: raise RuntimeError( "did not receive RowDescription from backend " "in response to Describe" ) return param_desc, result_desc def _decode_param_desc(self, buf: ReadBuffer): cdef: int16_t nparams uint32_t p_oid list result = [] nparams = buf.read_int16() for _ in range(nparams): p_oid = 
buf.read_int32() result.append(p_oid) return result def _decode_row_desc(self, buf: ReadBuffer): cdef: int16_t nfields bytes f_name uint32_t f_table_oid int16_t f_column_num uint32_t f_dt_oid int16_t f_dt_size int32_t f_dt_mod int16_t f_format list result nfields = buf.read_int16() result = [] for _ in range(nfields): f_name = buf.read_null_str() f_table_oid = buf.read_int32() f_column_num = buf.read_int16() f_dt_oid = buf.read_int32() f_dt_size = buf.read_int16() f_dt_mod = buf.read_int32() f_format = buf.read_int16() result.append((f_name.decode("utf-8"), f_dt_oid)) return result async def sql_describe( self, query: bytes, param_type_oids: Optional[list[int]] = None, ) -> tuple[list[int], list[tuple[str, int]]]: self.before_command() started_at = time.monotonic() try: return await self._describe(query, param_type_oids) finally: await self.after_command() async def _parse_execute( self, query, frontend.AbstractFrontendConnection fe_conn, WriteBuffer bind_data, bint use_prep_stmt, bytes state, int dbver, bint use_pending_func_cache, tx_isolation, list param_data_types, bytes query_prefix, bint needs_commit_state, ): cdef: WriteBuffer out WriteBuffer buf bytes stmt_name bytes sql tuple sqls bytes prologue_sql bytes epilogue_sql int32_t dat_len bint parse = 1 bint state_sync = 0 bint has_result = query.cardinality is not CARD_NO_RESULT bint discard_result = ( fe_conn is not None and query.output_format == FMT_NONE) uint64_t msgs_num uint64_t msgs_executed = 0 uint64_t i out = WriteBuffer.new() if state is not None: self._build_apply_state_req(state, out) if ( query.tx_id or not query.is_transactional or query.run_and_rollback or tx_isolation is not None or needs_commit_state or self.state_reset_needs_commit ): # This query has START TRANSACTION or non-transactional command # like CREATE DATABASE in it. # Restoring state must be performed in a separate # implicit transaction (otherwise START TRANSACTION DEFERRABLE # or CREATE DATABASE (since PG 14.7) would fail). 
# Hence - inject a SYNC after a state restore step. state_sync = 1 self.write_sync(out) if query.run_and_rollback or tx_isolation is not None: if self.in_tx(): sp_name = f'_edb_{time.monotonic_ns()}' prologue_sql = f'SAVEPOINT {sp_name}'.encode('utf-8') else: sp_name = None prologue_sql = b'START TRANSACTION' if tx_isolation is not None: prologue_sql += ( f' ISOLATION LEVEL {tx_isolation._value_}' .encode('utf-8') ) buf = WriteBuffer.new_message(b'P') buf.write_bytestring(b'') buf.write_bytestring(prologue_sql) buf.write_int16(0) out.write_buffer(buf.end_message()) buf = WriteBuffer.new_message(b'B') buf.write_bytestring(b'') # portal name buf.write_bytestring(b'') # statement name buf.write_int16(0) # number of format codes buf.write_int16(0) # number of parameters buf.write_int16(0) # number of result columns out.write_buffer(buf.end_message()) buf = WriteBuffer.new_message(b'E') buf.write_bytestring(b'') # portal name buf.write_int32(0) # limit: 0 - return all rows out.write_buffer(buf.end_message()) # Insert a SYNC as a boundary of the parsing logic later self.write_sync(out) if use_pending_func_cache and query.cache_func_call: sql, stmt_name = query.cache_func_call sqls = (query_prefix + sql,) else: sqls = (query_prefix + query.sql,) + query.db_op_trailer stmt_name = query.sql_hash msgs_num = (len(sqls)) if use_prep_stmt: parse = self.before_prepare(stmt_name, dbver, out) else: stmt_name = b'' if parse: if len(self.last_parse_prep_stmts): for stmt_name_to_clean in self.last_parse_prep_stmts: out.write_buffer( self.make_clean_stmt_message(stmt_name_to_clean)) self.last_parse_prep_stmts.clear() if stmt_name == b'' and msgs_num > 1: i = 0 for sql in sqls: pname = b'__p%d__' % i self.last_parse_prep_stmts.append(pname) buf = WriteBuffer.new_message(b'P') buf.write_bytestring(pname) buf.write_bytestring(sql) buf.write_int16(0) out.write_buffer(buf.end_message()) i += 1 metrics.query_size.observe( len(sql), self.get_tenant_label(), 'compiled' ) else: if len(sqls) != 
1: raise errors.InternalServerError( 'cannot PARSE more than one SQL query ' 'in non-anonymous mode') msgs_num = 1 buf = WriteBuffer.new_message(b'P') buf.write_bytestring(stmt_name) buf.write_bytestring(sqls[0]) if param_data_types: buf.write_int16(len(param_data_types)) for oid in param_data_types: buf.write_int32(oid) else: buf.write_int16(0) out.write_buffer(buf.end_message()) metrics.query_size.observe( len(sqls[0]), self.get_tenant_label(), 'compiled' ) assert bind_data is not None if stmt_name == b'' and msgs_num > 1: for s in self.last_parse_prep_stmts: buf = WriteBuffer.new_message(b'B') buf.write_bytestring(b'') # portal name buf.write_bytestring(s) # statement name buf.write_buffer(bind_data) out.write_buffer(buf.end_message()) buf = WriteBuffer.new_message(b'E') buf.write_bytestring(b'') # portal name buf.write_int32(0) # limit: 0 - return all rows out.write_buffer(buf.end_message()) else: buf = WriteBuffer.new_message(b'B') buf.write_bytestring(b'') # portal name buf.write_bytestring(stmt_name) # statement name buf.write_buffer(bind_data) out.write_buffer(buf.end_message()) buf = WriteBuffer.new_message(b'E') buf.write_bytestring(b'') # portal name buf.write_int32(0) # limit: 0 - return all rows out.write_buffer(buf.end_message()) if query.run_and_rollback or tx_isolation is not None: if query.run_and_rollback: if sp_name: sql = f'ROLLBACK TO SAVEPOINT {sp_name}'.encode('utf-8') else: sql = b'ROLLBACK' else: sql = b'COMMIT' buf = WriteBuffer.new_message(b'P') buf.write_bytestring(b'') buf.write_bytestring(sql) buf.write_int16(0) out.write_buffer(buf.end_message()) buf = WriteBuffer.new_message(b'B') buf.write_bytestring(b'') # portal name buf.write_bytestring(b'') # statement name buf.write_int16(0) # number of format codes buf.write_int16(0) # number of parameters buf.write_int16(0) # number of result columns out.write_buffer(buf.end_message()) buf = WriteBuffer.new_message(b'E') buf.write_bytestring(b'') # portal name buf.write_int32(0) # limit: 0 - 
return all rows out.write_buffer(buf.end_message()) elif query.append_tx_op: if query.tx_commit: sql = b'COMMIT' elif query.tx_rollback: sql = b'ROLLBACK' else: raise errors.InternalServerError( "QueryUnit.append_tx_op is set but none of the " "Query.tx_ properties are" ) buf = WriteBuffer.new_message(b'P') buf.write_bytestring(b'') buf.write_bytestring(sql) buf.write_int16(0) out.write_buffer(buf.end_message()) buf = WriteBuffer.new_message(b'B') buf.write_bytestring(b'') # portal name buf.write_bytestring(b'') # statement name buf.write_int16(0) # number of format codes buf.write_int16(0) # number of parameters buf.write_int16(0) # number of result columns out.write_buffer(buf.end_message()) buf = WriteBuffer.new_message(b'E') buf.write_bytestring(b'') # portal name buf.write_int32(0) # limit: 0 - return all rows out.write_buffer(buf.end_message()) self.write_sync(out) self.write(out) result = None try: if state is not None: await self.wait_for_state_resp( state, state_sync, needs_commit_state) if query.run_and_rollback or tx_isolation is not None: await self.wait_for_sync() buf = None while True: if not self.buffer.take_message(): await self.wait_for_message() mtype = self.buffer.get_message_type() try: if mtype == b'D': # DataRow if discard_result: self.buffer.discard_message() continue if not has_result and fe_conn is not None: raise errors.InternalServerError( f'query that was inferred to have ' f'no data returned received a DATA package; ' f'query: {sqls}') if fe_conn is None: ncol = self.buffer.read_int16() row = [] for i in range(ncol): dat_len = self.buffer.read_int32() if dat_len == -1: row.append(None) else: row.append( self.buffer.read_bytes(dat_len)) if result is None: result = [] result.append(row) else: if buf is None: buf = WriteBuffer.new() self.buffer.redirect_messages(buf, b'D', 0) if buf.len() >= DATA_BUFFER_SIZE: fe_conn.write(buf) buf = None elif mtype == b'C': ## result # CommandComplete self.buffer.discard_message() if buf is not None: 
fe_conn.write(buf) buf = None msgs_executed += 1 if msgs_executed == msgs_num: break elif mtype == b'1' and parse: # ParseComplete self.buffer.discard_message() self.prep_stmts[stmt_name] = dbver elif mtype == b'E': ## result # ErrorResponse er_cls, er_fields = self.parse_error_message() raise er_cls(fields=er_fields) elif mtype == b'n': # NoData self.buffer.discard_message() elif mtype == b's': ## result # PortalSuspended self.buffer.discard_message() break elif mtype == b'2': # BindComplete self.buffer.discard_message() elif mtype == b'I': ## result # EmptyQueryResponse self.buffer.discard_message() break elif mtype == b'3': # CloseComplete self.buffer.discard_message() else: self.fallthrough() finally: self.buffer.finish_message() finally: await self.wait_for_sync() return result async def parse_execute( self, *, query, WriteBuffer bind_data = NO_ARGS, list param_data_types = None, frontend.AbstractFrontendConnection fe_conn = None, bint use_prep_stmt = False, bytes state = None, int dbver = 0, bint use_pending_func_cache = 0, tx_isolation = None, query_prefix = None, bint needs_commit_state = False, ): self.before_command() started_at = time.monotonic() try: return await self._parse_execute( query, fe_conn, bind_data, use_prep_stmt, state, dbver, use_pending_func_cache, tx_isolation, param_data_types, query_prefix or b'', needs_commit_state, ) finally: metrics.backend_query_duration.observe( time.monotonic() - started_at, self.get_tenant_label() ) await self.after_command() async def sql_fetch( self, sql: bytes, *, args: tuple[bytes, ...] 
        | list[bytes] = (),
        use_prep_stmt: bool = False,
        state: Optional[bytes] = None,
        tx_isolation: defines.TxIsolationLevel | None = None,
    ) -> list[tuple[bytes, ...]]:
        if use_prep_stmt:
            sql_digest = hashlib.sha1()
            sql_digest.update(sql)
            sql_hash = sql_digest.hexdigest().encode('latin1')
        else:
            sql_hash = None

        query = compiler.QueryUnit(
            sql=sql,
            sql_hash=sql_hash,
            status=b"",
        )

        return await self.parse_execute(
            query=query,
            bind_data=args_ser.combine_raw_args(args),
            use_prep_stmt=use_prep_stmt,
            state=state,
            tx_isolation=tx_isolation,
        )

    async def sql_fetch_val(
        self,
        sql: bytes,
        *,
        args: tuple[bytes, ...] | list[bytes] = (),
        use_prep_stmt: bool = False,
        state: Optional[bytes] = None,
        tx_isolation: defines.TxIsolationLevel | None = None,
    ) -> bytes:
        data = await self.sql_fetch(
            sql,
            args=args,
            use_prep_stmt=use_prep_stmt,
            state=state,
            tx_isolation=tx_isolation,
        )
        if data is None or len(data) == 0:
            return None
        elif len(data) > 1:
            raise RuntimeError(
                f"received too many rows for sql_fetch_val({sql!r})")
        row = data[0]
        if len(row) != 1:
            raise RuntimeError(
                f"received too many columns for sql_fetch_val({sql!r})")
        return row[0]

    async def sql_fetch_col(
        self,
        sql: bytes,
        *,
        args: tuple[bytes, ...]
        | list[bytes] = (),
        use_prep_stmt: bool = False,
        state: Optional[bytes] = None,
        tx_isolation: defines.TxIsolationLevel | None = None,
    ) -> list[bytes]:
        data = await self.sql_fetch(
            sql,
            args=args,
            use_prep_stmt=use_prep_stmt,
            state=state,
            tx_isolation=tx_isolation,
        )
        if not data:
            return []
        else:
            if len(data[0]) != 1:
                raise RuntimeError(
                    f"received too many columns for sql_fetch_col({sql!r})")
            return [row[0] for row in data]

    async def _sql_execute(self, bytes sql):
        cdef:
            WriteBuffer out
            WriteBuffer buf

        out = WriteBuffer.new()

        buf = WriteBuffer.new_message(b'Q')
        buf.write_bytestring(sql)
        out.write_buffer(buf.end_message())
        self.waiting_for_sync += 1

        self.write(out)

        exc = None
        result = None

        while True:
            if not self.buffer.take_message():
                await self.wait_for_message()
            mtype = self.buffer.get_message_type()

            try:
                if mtype == b'D':
                    self.buffer.discard_message()

                elif mtype == b'T':
                    # RowDescription
                    self.buffer.discard_message()

                elif mtype == b'C':
                    # CommandComplete
                    self.buffer.discard_message()

                elif mtype == b'E':
                    # ErrorResponse
                    exc = self.parse_error_message()

                elif mtype == b'I':
                    # EmptyQueryResponse
                    self.buffer.discard_message()

                elif mtype == b'Z':
                    self.parse_sync_message()
                    break

                else:
                    self.fallthrough()

            finally:
                self.buffer.finish_message()

        if exc is not None:
            raise exc[0](fields=exc[1])
        else:
            return result

    async def sql_execute(self, sql: bytes | tuple[bytes, ...]) -> None:
        self.before_command()
        started_at = time.monotonic()

        if isinstance(sql, tuple):
            sql_string = b";\n".join(sql)
        else:
            sql_string = sql

        try:
            return await self._sql_execute(sql_string)
        finally:
            metrics.backend_query_duration.observe(
                time.monotonic() - started_at, self.get_tenant_label()
            )
            await self.after_command()

    async def sql_apply_state(
        self,
        dbv: pg_ext.ConnectionView,
    ):
        self.before_command()
        try:
            state = dbv.serialize_state()
            if state is not None:
                buf = WriteBuffer.new()
                self._build_apply_sql_state_req(state, buf)
                self.write_sync(buf)
                self.write(buf)

                await self._parse_apply_state_resp(
                    2 if
state != EMPTY_SQL_STATE else 1 ) await self.wait_for_sync() self.last_state = state self.state_reset_needs_commit = ( dbv.needs_commit_after_state_sync()) finally: await self.after_command() async def sql_extended_query( self, actions, fe_conn: frontend.AbstractFrontendConnection, dbver: int, dbv: pg_ext.ConnectionView, ) -> tuple[bool, bool]: self.before_command() try: state = self._sql._write_sql_extended_query(actions, dbver, dbv) if state is not None: await self._parse_apply_state_resp( 2 if state != EMPTY_SQL_STATE else 1 ) await self.wait_for_sync() self.last_state = state self.state_reset_needs_commit = ( dbv.needs_commit_after_state_sync()) try: return await self._sql._parse_sql_extended_query( actions, fe_conn, dbver, dbv, ) finally: if not dbv.in_tx(): self.last_state = dbv.serialize_state() self.state_reset_needs_commit = ( dbv.needs_commit_after_state_sync()) finally: await self.after_command() def _write_error_position( self, msg_buf: WriteBuffer, query: bytes, pos_bytes: bytes, source_map: Optional[pg_codegen.SourceMap], offset: int = 0, ): if source_map: pos = int(pos_bytes.decode('utf8')) if offset > 0 or pos + offset > 0: pos += offset pos = source_map.translate(pos) # pg uses 1-based indexes pos += 1 pos_bytes = str(pos).encode('utf8') msg_buf.write_byte(b'P') # Position else: msg_buf.write_byte(b'q') # Internal query msg_buf.write_bytestring(query) msg_buf.write_byte(b'p') # Internal position msg_buf.write_bytestring(pos_bytes) def load_last_ddl_return(self, object query_unit): if query_unit.ddl_stmt_id: data = self.last_indirect_return if data: ret = json.loads(data) if ret['ddl_stmt_id'] != query_unit.ddl_stmt_id: raise RuntimeError( 'unrecognized data notice after a DDL command: ' 'data_stmt_id do not match: expected ' f'{query_unit.ddl_stmt_id!r}, got ' f'{ret["ddl_stmt_id"]!r}' ) return ret else: raise RuntimeError( 'missing the required data notice after a DDL command' ) async def _dump(self, block, output_queue, fragment_suggested_size): 
cdef: WriteBuffer buf WriteBuffer qbuf WriteBuffer out qbuf = WriteBuffer.new_message(b'Q') qbuf.write_bytestring(block.sql_copy_stmt) qbuf.end_message() self.write(qbuf) self.waiting_for_sync += 1 er = None out = None i = 0 while True: if not self.buffer.take_message(): await self.wait_for_message() mtype = self.buffer.get_message_type() if mtype == b'H': # CopyOutResponse self.buffer.discard_message() elif mtype == b'd': # CopyData if out is None: out = WriteBuffer.new() if i == 0: # The first COPY IN message is prefixed with # `COPY_SIGNATURE` -- strip it. first = self.buffer.consume_message() if first[:len(COPY_SIGNATURE)] != COPY_SIGNATURE: raise RuntimeError('invalid COPY IN message') buf = WriteBuffer.new_message(b'd') buf.write_bytes(first[len(COPY_SIGNATURE) + 8:]) buf.end_message() out.write_buffer(buf) if out._length >= fragment_suggested_size: await output_queue.put((block, i, out)) i += 1 out = None if (not self.buffer.take_message() or self.buffer.get_message_type() != b'd'): continue self.buffer.redirect_messages( out, b'd', fragment_suggested_size) if out._length >= fragment_suggested_size: self.transport.pause_reading() await output_queue.put((block, i, out)) self.transport.resume_reading() i += 1 out = None elif mtype == b'c': # CopyDone self.buffer.discard_message() elif mtype == b'C': # CommandComplete if out is not None: await output_queue.put((block, i, out)) self.buffer.discard_message() elif mtype == b'E': er = self.parse_error_message() elif mtype == b'Z': self.parse_sync_message() break else: self.fallthrough() if er is not None: raise er[0](fields=er[1]) async def dump(self, input_queue, output_queue, fragment_suggested_size): self.before_command() try: while True: try: block = input_queue.pop() except IndexError: await output_queue.put(None) return await self._dump(block, output_queue, fragment_suggested_size) finally: # In case we errored while the transport was suspended. 
self.transport.resume_reading() await self.after_command() async def _restore(self, restore_block, bytes data, dict type_map): cdef: WriteBuffer buf WriteBuffer qbuf WriteBuffer out char* cbuf ssize_t clen ssize_t ncols qbuf = WriteBuffer.new_message(b'Q') qbuf.write_bytestring(restore_block.sql_copy_stmt) qbuf.end_message() self.write(qbuf) self.waiting_for_sync += 1 er = None while True: if not self.buffer.take_message(): await self.wait_for_message() mtype = self.buffer.get_message_type() if mtype == b'G': # CopyInResponse self.buffer.read_byte() ncols = self.buffer.read_int16() self.buffer.discard_message() break elif mtype == b'E': er = self.parse_error_message() elif mtype == b'Z': self.parse_sync_message() break else: self.fallthrough() if er is not None: raise er[0](fields=er[1]) buf = WriteBuffer.new() cpython.PyBytes_AsStringAndSize(data, &cbuf, &clen) if ( restore_block.compat_elided_cols or any(desc for desc in restore_block.data_mending_desc) ): self._rewrite_copy_data( buf, cbuf, clen, ncols, restore_block.data_mending_desc, type_map, restore_block.compat_elided_cols, ) else: if cbuf[0] != b'd': raise RuntimeError('unexpected dump data message structure') ln = hton.unpack_int32(cbuf + 1) buf.write_byte(b'd') buf.write_int32(ln + len(COPY_SIGNATURE) + 8) buf.write_bytes(COPY_SIGNATURE) buf.write_int32(0) buf.write_int32(0) buf.write_cstr(cbuf + 5, clen - 5) self.write(buf) qbuf = WriteBuffer.new_message(b'c') qbuf.end_message() self.write(qbuf) while True: if not self.buffer.take_message(): await self.wait_for_message() mtype = self.buffer.get_message_type() if mtype == b'C': # CommandComplete self.buffer.discard_message() elif mtype == b'E': er = self.parse_error_message() elif mtype == b'Z': self.parse_sync_message() break if er is not None: raise er[0](fields=er[1]) cdef _rewrite_copy_data( self, WriteBuffer wbuf, char* data, ssize_t data_len, ssize_t ncols, tuple data_mending_desc, dict type_id_map, tuple elided_cols, ): """Rewrite the binary COPY 
stream.""" cdef: FRBuffer rbuf FRBuffer datum_buf ssize_t i ssize_t real_ncols int8_t *elide int8_t elided int32_t datum_len char copy_msg_byte int16_t copy_msg_ncols const char *datum bint first = True bint received_eof = False real_ncols = ncols + len(elided_cols) frb_init(&rbuf, data, data_len) elide = cpythonx.PyMem_Calloc( real_ncols, sizeof(int8_t)) try: for col in elided_cols: elide[col] = 1 mbuf = WriteBuffer.new() while frb_get_len(&rbuf): if received_eof: raise RuntimeError('received CopyData after EOF') mbuf.start_message(b'd') copy_msg_byte = frb_read(&rbuf, 1)[0] if copy_msg_byte != b'd': raise RuntimeError( 'unexpected dump data message structure') frb_read(&rbuf, 4) if first: mbuf.write_bytes(COPY_SIGNATURE) mbuf.write_int32(0) mbuf.write_int32(0) first = False copy_msg_ncols = hton.unpack_int16(frb_read(&rbuf, 2)) if copy_msg_ncols == -1: # BINARY COPY EOF marker mbuf.write_int16(copy_msg_ncols) received_eof = True mbuf.end_message() wbuf.write_buffer(mbuf) mbuf.reset() continue else: mbuf.write_int16(ncols) # Tuple data for i in range(real_ncols): datum_len = hton.unpack_int32(frb_read(&rbuf, 4)) elided = elide[i] if not elided: mbuf.write_int32(datum_len) if datum_len != -1: datum = frb_read(&rbuf, datum_len) if not elided: datum_mending_desc = data_mending_desc[i] if ( datum_mending_desc is not None and datum_mending_desc.needs_mending ): frb_init(&datum_buf, datum, datum_len) self._mend_copy_datum( mbuf, &datum_buf, datum_mending_desc, type_id_map, ) else: mbuf.write_cstr(datum, datum_len) mbuf.end_message() wbuf.write_buffer(mbuf) mbuf.reset() finally: cpython.PyMem_Free(elide) cdef _mend_copy_datum( self, WriteBuffer wbuf, FRBuffer *rbuf, object mending_desc, dict type_id_map, ): cdef: ssize_t remainder int32_t ndims int32_t i int32_t nelems int32_t dim const char *buf FRBuffer elem_buf int32_t elem_len object elem_mending_desc kind = mending_desc.schema_object_class if kind is qltypes.SchemaObjectClass.ARRAY_TYPE: # Dimensions and flags buf = 
frb_read(rbuf, 8) ndims = hton.unpack_int32(buf) wbuf.write_cstr(buf, 8) elem_mending_desc = mending_desc.elements[0] # Discard the original element OID. frb_read(rbuf, 4) # Write the correct element OID. elem_type_id = elem_mending_desc.schema_type_id elem_type_oid = type_id_map[elem_type_id] wbuf.write_int32(elem_type_oid) if ndims == 0: # Empty array return if ndims != 1: raise ValueError( 'unexpected non-single dimension array' ) if mending_desc.needs_mending: # dim and lbound buf = frb_read(rbuf, 8) nelems = hton.unpack_int32(buf) wbuf.write_cstr(buf, 8) for i in range(nelems): elem_len = hton.unpack_int32(frb_read(rbuf, 4)) wbuf.write_int32(elem_len) frb_slice_from(&elem_buf, rbuf, elem_len) self._mend_copy_datum( wbuf, &elem_buf, mending_desc.elements[0], type_id_map, ) elif kind is qltypes.SchemaObjectClass.TUPLE_TYPE: nelems = hton.unpack_int32(frb_read(rbuf, 4)) wbuf.write_int32(nelems) for i in range(nelems): elem_mending_desc = mending_desc.elements[i] if elem_mending_desc is not None: # Discard the original element OID. frb_read(rbuf, 4) # Write the correct element OID. 
elem_type_id = elem_mending_desc.schema_type_id elem_type_oid = type_id_map[elem_type_id] wbuf.write_int32(elem_type_oid) elem_len = hton.unpack_int32(frb_read(rbuf, 4)) wbuf.write_int32(elem_len) if elem_len != -1: frb_slice_from(&elem_buf, rbuf, elem_len) if elem_mending_desc.needs_mending: self._mend_copy_datum( wbuf, &elem_buf, elem_mending_desc, type_id_map, ) else: wbuf.write_frbuf(&elem_buf) else: buf = frb_read(rbuf, 8) wbuf.write_cstr(buf, 8) elem_len = hton.unpack_int32(buf + 4) if elem_len != -1: wbuf.write_cstr(frb_read(rbuf, elem_len), elem_len) wbuf.write_frbuf(rbuf) async def restore(self, restore_block, bytes data, dict type_map): self.before_command() try: await self._restore(restore_block, data, type_map) finally: await self.after_command() def is_healthy(self): return ( self.connected and self.idle and self.cancel_fut is None and not self.waiting_for_sync and not self.in_tx() ) cdef before_command(self): if not self.connected: raise RuntimeError( 'pgcon: cannot issue new command: not connected') if self.waiting_for_sync: raise RuntimeError( 'pgcon: cannot issue new command; waiting for sync') if not self.idle: raise RuntimeError( 'pgcon: cannot issue new command; ' 'another command is in progress') if self.cancel_fut is not None: raise RuntimeError( 'pgcon: cannot start new command while cancelling the ' 'previous one') self.idle = False self.last_indirect_return = None async def after_command(self): if self.idle: raise RuntimeError('pgcon: idle while running a command') if self.cancel_fut is not None: await self.cancel_fut self.cancel_fut = None self.idle = True # If we were cancelling a command in Postgres there can be a # race between us calling `pg_cancel_backend()` and us receiving # the results of the successfully executed command. If this # happens, we might get the *next command* cancelled. To minimize # the chance of that we do another SYNC. 
await self.sync() else: self.idle = True cdef write(self, buf): self.transport.write(memoryview(buf)) cdef fallthrough(self): if self.parse_notification(): return cdef: char mtype = self.buffer.get_message_type() # Process a sync, or else the state machine might hang # forever... but still fail! if mtype == b'Z': self.parse_sync_message() raise RuntimeError( f'unexpected message type {chr(mtype)!r}') cdef fallthrough_idle(self): cdef char mtype while self.buffer.take_message(): if self.parse_notification(): continue mtype = self.buffer.get_message_type() if mtype != b'E': # ErrorResponse raise RuntimeError( f'unexpected message type {chr(mtype)!r} ' f'in IDLE state') # We have an error message sent to us by the backend. # It is not safe to assume that the connection # is alive. We assume that it's dead and should be # marked as "closed". try: er_cls, fields = self.parse_error_message() self.aborted_with_error = er_cls(fields=fields) pgcode = fields['C'] metrics.backend_connection_aborted.inc( 1.0, self.get_tenant_label(), pgcode ) if pgcode in POSTGRES_SHUTDOWN_ERR_CODES: pgreason = POSTGRES_SHUTDOWN_ERR_CODES[pgcode] pgmsg = fields.get('M', pgreason) logger.debug( 'backend connection aborted with a shutdown ' 'error code %r(%s): %s', pgcode, pgreason, pgmsg ) if self.is_system_db: self.tenant.set_pg_unavailable_msg(pgmsg) self.tenant.on_sys_pgcon_failover_signal() else: pgmsg = fields.get('M', '') logger.debug( 'backend connection aborted with an ' 'error code %r: %s', pgcode, pgmsg ) finally: self.abort() cdef parse_notification(self): cdef: char mtype = self.buffer.get_message_type() if mtype == b'S': # ParameterStatus name, value = self.parse_parameter_status_message() if self.is_system_db: self.tenant.on_sys_pgcon_parameter_status_updated(name, value) self.parameter_status[name] = value return True elif mtype == b'A': # NotificationResponse self.buffer.read_int32() # discard pid channel = self.buffer.read_null_str().decode() payload = 
self.buffer.read_null_str().decode() self.buffer.finish_message() if not self.is_system_db: # The server is still initializing, or we're getting # notification from a non-system-db connection. return True if channel == '__edgedb_sysevent__': event_data = json.loads(payload) event = event_data.get('event') server_id = event_data.get('server_id') if server_id == self.server._server_id: # We should only react to notifications sent # by other edgedb servers. Reacting to events # generated by this server must be implemented # at a different layer. return True logger.debug("received system event: %s", event) event_payload = event_data.get('args') if event == 'schema-changes': dbname = event_payload['dbname'] self.tenant.on_remote_ddl(dbname) elif event == 'database-config-changes': dbname = event_payload['dbname'] self.tenant.on_remote_database_config_change(dbname) elif event == 'system-config-changes': self.tenant.on_remote_system_config_change() elif event == 'global-schema-changes': self.tenant.on_global_schema_change() elif event == 'database-changes': self.tenant.on_remote_database_changes() elif event == 'ensure-database-not-used': dbname = event_payload['dbname'] self.tenant.on_remote_database_quarantine(dbname) elif event == 'query-cache-changes': dbname = event_payload['dbname'] to_add = event_payload.get('to_add') to_invalidate = event_payload.get('to_invalidate') self.tenant.on_remote_query_cache_change( dbname, to_add=to_add, to_invalidate=to_invalidate ) else: raise AssertionError(f'unexpected system event: {event!r}') return True elif mtype == b'N': # NoticeResponse _, fields = self.parse_error_message() severity = fields.get('V') message = fields.get('M') detail = fields.get('D') if ( severity == "NOTICE" and message.startswith("edb:notice:indirect_return") ): self.last_indirect_return = detail elif self.log_listeners: for listener in self.log_listeners: self.loop.call_soon(listener, severity, message) return True return False cdef 
parse_error_message(self): cdef: char code str value dict fields = {} object err_cls while True: code = self.buffer.read_byte() if code == 0: break value = self.buffer.read_null_str().decode() fields[chr(code)] = value self.buffer.finish_message() err_cls = pgerror.get_error_class(fields) if self.debug: self.debug_print('ERROR', err_cls.__name__, fields) return err_cls, fields cdef char parse_sync_message(self): cdef char status if not self.waiting_for_sync: raise RuntimeError('unexpected sync') self.waiting_for_sync -= 1 assert self.buffer.get_message_type() == b'Z' status = self.buffer.read_byte() if status == b'I': self.xact_status = PQTRANS_IDLE elif status == b'T': self.xact_status = PQTRANS_INTRANS elif status == b'E': self.xact_status = PQTRANS_INERROR else: self.xact_status = PQTRANS_UNKNOWN if self.debug: self.debug_print('SYNC MSG', self.xact_status, chr(status)) self.buffer.finish_message() return status cdef parse_parameter_status_message(self): cdef: str name str value assert self.buffer.get_message_type() == b'S' name = self.buffer.read_null_str().decode() value = self.buffer.read_null_str().decode() self.buffer.finish_message() if self.debug: self.debug_print('PARAMETER STATUS MSG', name, value) return name, value cdef make_clean_stmt_message(self, bytes stmt_name): cdef WriteBuffer buf buf = WriteBuffer.new_message(b'C') buf.write_byte(b'S') buf.write_bytestring(stmt_name) return buf.end_message() async def wait_for_message(self): if self.buffer.take_message(): return if self.transport is None: raise ConnectionAbortedError() self.msg_waiter = self.loop.create_future() await self.msg_waiter def connection_made(self, transport): if self.transport is not None: raise RuntimeError('connection_made: invalid connection status') self.transport = transport self.connected = True self.connected_fut.set_result(True) self.connected_fut = None def connection_lost(self, exc): # Mark the connection as disconnected, so that `self.is_healthy()` # surely returns False 
        # for this connection.
        self.connected = False

        self.transport = None

        if self.pinned_by is not None:
            pinned_by = self.pinned_by
            self.pinned_by = None
            pinned_by.on_aborted_pgcon(self)

        if self.is_system_db:
            self.tenant.on_sys_pgcon_connection_lost(exc)
        elif self.tenant is not None:
            if not self.close_requested:
                self.tenant.on_pgcon_broken()
            else:
                self.tenant.on_pgcon_lost()

        if self.connected_fut is not None and not self.connected_fut.done():
            self.connected_fut.set_exception(ConnectionAbortedError())
            return

        if self.msg_waiter is not None and not self.msg_waiter.done():
            self.msg_waiter.set_exception(ConnectionAbortedError())
            self.msg_waiter = None

    def pause_writing(self):
        pass

    def resume_writing(self):
        pass

    def data_received(self, data):
        self.buffer.feed_data(data)

        if self.connected and self.idle:
            assert self.msg_waiter is None
            self.fallthrough_idle()

        elif (self.msg_waiter is not None and
                self.buffer.take_message() and
                not self.msg_waiter.cancelled()):
            self.msg_waiter.set_result(True)
            self.msg_waiter = None

    def eof_received(self):
        pass


# Underscored name for _SYNC_MESSAGE because it should always be emitted
# using write_sync(), which properly counts them
cdef bytes _SYNC_MESSAGE = bytes(WriteBuffer.new_message(b'S').end_message())
cdef bytes FLUSH_MESSAGE = bytes(WriteBuffer.new_message(b'H').end_message())

cdef EdegDBCodecContext DEFAULT_CODEC_CONTEXT = EdegDBCodecContext()

cdef inline int16_t read_int16(data: bytes):
    return int.from_bytes(data[0:2], "big", signed=True)

cdef inline int32_t read_int32(data: bytes):
    return int.from_bytes(data[0:4], "big", signed=True)


================================================
FILE: edb/server/pgcon/pgcon_sql.pxd
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from edb.server.pgproto.pgproto cimport (
    WriteBuffer,
    ReadBuffer,
    FRBuffer,
)


cdef enum PGAction:
    START_IMPLICIT_TX = 0
    PARSE = 1
    BIND = 2
    DESCRIBE_STMT = 3
    DESCRIBE_STMT_ROWS = 4
    DESCRIBE_PORTAL = 5
    EXECUTE = 6
    CLOSE_STMT = 7
    CLOSE_PORTAL = 8
    FLUSH = 9
    SYNC = 10


cdef class PGMessage:
    cdef:
        PGAction action
        bytes stmt_name
        bytes portal_name
        str orig_portal_name
        object args
        object query_unit
        bint frontend_only
        bint valid
        bint injected

        object orig_query
        object fe_settings

    cdef inline bint is_frontend_only(self)
    cdef inline bint is_valid(self)
    cdef inline bint is_injected(self)


cdef class PGSQLConnection:
    cdef:
        PGConnection con

    cdef _rewrite_sql_error_response(self, PGMessage action, WriteBuffer buf)


================================================
FILE: edb/server/pgcon/pgcon_sql.pyx
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


cdef class PGMessage:
    def __init__(
        self,
        PGAction action,
        bytes stmt_name=None,
        str portal_name=None,
        args=None,
        query_unit=None,
        fe_settings=None,
        injected=False,
        bytes force_portal_name=None,
    ):
        self.action = action
        self.stmt_name = stmt_name
        self.orig_portal_name = portal_name
        if force_portal_name is not None:
            self.portal_name = force_portal_name
        elif portal_name:
            self.portal_name = b'u' + portal_name.encode("utf-8")
        else:
            self.portal_name = b''
        self.args = args
        self.query_unit = query_unit

        self.fe_settings = fe_settings
        self.valid = True
        self.injected = injected
        if self.query_unit is not None:
            self.frontend_only = self.query_unit.frontend_only
        else:
            self.frontend_only = False

    cdef inline bint is_frontend_only(self):
        return self.frontend_only

    def invalidate(self):
        self.valid = False

    cdef inline bint is_valid(self):
        return self.valid

    cdef inline bint is_injected(self):
        return self.injected

    def as_injected(self) -> PGMessage:
        return PGMessage(
            action=self.action,
            stmt_name=self.stmt_name,
            portal_name=self.orig_portal_name,
            args=self.args,
            query_unit=self.query_unit,
            fe_settings=self.fe_settings,
            injected=True,
        )

    def __repr__(self):
        rv = []
        if self.action == PGAction.START_IMPLICIT_TX:
            rv.append("START_IMPLICIT_TX")
        elif self.action == PGAction.PARSE:
            rv.append("PARSE")
        elif self.action == PGAction.BIND:
            rv.append("BIND")
        elif self.action == PGAction.DESCRIBE_STMT:
            rv.append("DESCRIBE_STMT")
        elif self.action == PGAction.DESCRIBE_STMT_ROWS:
            rv.append("DESCRIBE_STMT_ROWS")
        elif self.action == PGAction.DESCRIBE_PORTAL:
            rv.append("DESCRIBE_PORTAL")
        elif self.action == PGAction.EXECUTE:
            rv.append("EXECUTE")
        elif self.action == PGAction.CLOSE_STMT:
            rv.append("CLOSE_STMT")
        elif self.action == PGAction.CLOSE_PORTAL:
            rv.append("CLOSE_PORTAL")
        elif self.action == PGAction.FLUSH:
            rv.append("FLUSH")
        elif self.action == PGAction.SYNC:
            rv.append("SYNC")
        if self.stmt_name is not None:
            rv.append(f"stmt_name={self.stmt_name}")
        if self.orig_portal_name is not None:
rv.append(f"portal_name={self.orig_portal_name!r}") if self.args is not None: rv.append(f"args={self.args}") rv.append(f"frontend_only={self.is_frontend_only()}") rv.append(f"injected={self.is_injected()}") if self.query_unit is not None: rv.append(f"query_unit={self.query_unit}") if len(rv) > 1: rv.insert(1, ":") return " ".join(rv) cdef class PGSQLConnection: def __init__(self, con): self.con = con def _write_sql_extended_query( self, actions, dbver: int, dbv: pg_ext.ConnectionView, ) -> bytes: cdef: WriteBuffer buf, msg_buf PGMessage action bint be_parse buf = WriteBuffer.new() state = None if not dbv.in_tx(): state = dbv.serialize_state() self.con._build_apply_sql_state_req(state, buf) # We need to close the implicit transaction with a SYNC here # because the next command may be e.g. "BEGIN DEFERRABLE". self.con.write_sync(buf) prepared = set() for action in actions: if action.is_frontend_only(): continue be_parse = True if action.action == PGAction.PARSE: sql_text, data = action.args[:2] if action.stmt_name in prepared: action.frontend_only = True else: if action.stmt_name: be_parse = self.con.before_prepare( action.stmt_name, dbver, buf ) if not be_parse: if self.con.debug: self.con.debug_print( 'Parse cache hit', action.stmt_name, sql_text) action.frontend_only = True if not action.is_frontend_only(): prepared.add(action.stmt_name) msg_buf = WriteBuffer.new_message(b'P') msg_buf.write_bytestring(action.stmt_name) msg_buf.write_bytestring(sql_text) msg_buf.write_bytes(data) buf.write_buffer(msg_buf.end_message()) metrics.query_size.observe( len(sql_text), self.con.get_tenant_label(), 'compiled' ) if self.con.debug: self.con.debug_print( 'Parse', action.stmt_name, sql_text, data ) elif action.action == PGAction.BIND: if action.query_unit is not None and action.query_unit.prepare: be_stmt_name = action.query_unit.prepare.be_stmt_name if be_stmt_name in prepared: action.frontend_only = True else: if be_stmt_name: be_parse = self.con.before_prepare( be_stmt_name, 
dbver, buf ) if not be_parse: if self.con.debug: self.con.debug_print( 'Parse cache hit', be_stmt_name) action.frontend_only = True prepared.add(be_stmt_name) if action.is_frontend_only(): pass elif action.query_unit is not None and isinstance( action.query_unit.command_complete_tag, dbstate.TagUnpackRow ): # in this case we are intercepting the only result row so # we want to set its encoding to be binary msg_buf = WriteBuffer.new_message(b'B') msg_buf.write_bytestring(action.portal_name) msg_buf.write_bytestring(action.stmt_name) # skim over param format codes param_formats = read_int16(action.args[0:2]) offset = 2 + param_formats * 2 # skim over param values params = read_int16(action.args[offset:offset+2]) offset += 2 for p in range(params): size = read_int32(action.args[offset:offset+4]) if size == -1: # special case: NULL size = 0 offset += 4 + size msg_buf.write_bytes(action.args[0:offset]) # set the result formats msg_buf.write_int16(1) # number of columns msg_buf.write_int16(1) # binary encoding buf.write_buffer(msg_buf.end_message()) else: msg_buf = WriteBuffer.new_message(b'B') msg_buf.write_bytestring(action.portal_name) msg_buf.write_bytestring(action.stmt_name) msg_buf.write_bytes(action.args) buf.write_buffer(msg_buf.end_message()) elif ( action.action in (PGAction.DESCRIBE_STMT, PGAction.DESCRIBE_STMT_ROWS) ): msg_buf = WriteBuffer.new_message(b'D') msg_buf.write_byte(b'S') msg_buf.write_bytestring(action.stmt_name) buf.write_buffer(msg_buf.end_message()) elif action.action == PGAction.DESCRIBE_PORTAL: msg_buf = WriteBuffer.new_message(b'D') msg_buf.write_byte(b'P') msg_buf.write_bytestring(action.portal_name) buf.write_buffer(msg_buf.end_message()) elif action.action == PGAction.EXECUTE: if action.query_unit is not None and action.query_unit.prepare: be_stmt_name = action.query_unit.prepare.be_stmt_name if be_stmt_name in prepared: action.frontend_only = True else: if be_stmt_name: be_parse = self.con.before_prepare( be_stmt_name, dbver, buf ) if 
not be_parse: if self.con.debug: self.con.debug_print( 'Parse cache hit', be_stmt_name) action.frontend_only = True prepared.add(be_stmt_name) if ( action.query_unit is not None and action.query_unit.deallocate is not None and self.con.before_prepare( action.query_unit.deallocate.be_stmt_name, dbver, buf ) ): # This prepared statement does not actually exist # on this connection, so there's nothing to DEALLOCATE. action.frontend_only = True if action.is_frontend_only(): pass elif action.query_unit is not None and isinstance( action.query_unit.command_complete_tag, (dbstate.TagCountMessages, dbstate.TagUnpackRow), ): # when executing TagUnpackRow, don't pass the limit through msg_buf = WriteBuffer.new_message(b'E') msg_buf.write_bytestring(action.portal_name) msg_buf.write_int32(0) buf.write_buffer(msg_buf.end_message()) else: # base case msg_buf = WriteBuffer.new_message(b'E') msg_buf.write_bytestring(action.portal_name) msg_buf.write_int32(action.args) buf.write_buffer(msg_buf.end_message()) elif action.action == PGAction.CLOSE_PORTAL: if action.query_unit is not None and action.query_unit.prepare: be_stmt_name = action.query_unit.prepare.be_stmt_name if be_stmt_name in prepared: action.frontend_only = True if not action.is_frontend_only(): msg_buf = WriteBuffer.new_message(b'C') msg_buf.write_byte(b'P') msg_buf.write_bytestring(action.portal_name) buf.write_buffer(msg_buf.end_message()) elif action.action == PGAction.CLOSE_STMT: if action.query_unit is not None and action.query_unit.prepare: be_stmt_name = action.query_unit.prepare.be_stmt_name if be_stmt_name in prepared: action.frontend_only = True if not action.is_frontend_only(): msg_buf = WriteBuffer.new_message(b'C') msg_buf.write_byte(b'S') msg_buf.write_bytestring(action.stmt_name) buf.write_buffer(msg_buf.end_message()) elif action.action == PGAction.FLUSH: msg_buf = WriteBuffer.new_message(b'H') buf.write_buffer(msg_buf.end_message()) elif action.action == PGAction.SYNC: self.con.write_sync(buf) if 
action.action not in (PGAction.SYNC, PGAction.FLUSH): # Make sure _parse_sql_extended_query() completes by sending a FLUSH # to the backend, but we won't flush the results to the frontend # because it's not requested. msg_buf = WriteBuffer.new_message(b'H') buf.write_buffer(msg_buf.end_message()) self.con.write(buf) return state async def _parse_sql_extended_query( self, actions, fe_conn: frontend.AbstractFrontendConnection, dbver: int, dbv: pg_ext.ConnectionView, ) -> tuple[bool, bool]: cdef: WriteBuffer buf, msg_buf PGMessage action bint ignore_till_sync = False int32_t row_count buf = WriteBuffer.new() rv = True for action in actions: if self.con.debug: self.con.debug_print( 'ACTION', action, 'ignore_till_sync=', ignore_till_sync ) if ignore_till_sync and action.action != PGAction.SYNC: continue elif action.action == PGAction.FLUSH: if buf.len() > 0: fe_conn.write(buf) fe_conn.flush() buf = WriteBuffer.new() continue elif action.action == PGAction.START_IMPLICIT_TX: dbv.start_implicit() continue elif action.is_frontend_only(): if action.action == PGAction.PARSE: if not action.is_injected(): msg_buf = WriteBuffer.new_message(b'1') buf.write_buffer(msg_buf.end_message()) elif action.action == PGAction.BIND: dbv.create_portal( action.orig_portal_name, action.query_unit ) if not action.is_injected(): msg_buf = WriteBuffer.new_message(b'2') # BindComplete buf.write_buffer(msg_buf.end_message()) elif action.action == PGAction.DESCRIBE_STMT: # ParameterDescription if not action.is_injected(): msg_buf = WriteBuffer.new_message(b't') msg_buf.write_int16(0) # number of parameters buf.write_buffer(msg_buf.end_message()) elif action.action == PGAction.EXECUTE: if action.query_unit.set_vars is not None: assert len(action.query_unit.set_vars) == 1 # CommandComplete msg_buf = WriteBuffer.new_message(b'C') if next( iter(action.query_unit.set_vars.values()) ) is None: msg_buf.write_bytestring(b'RESET') else: msg_buf.write_bytestring(b'SET') buf.write_buffer(msg_buf.end_message())
elif not action.is_injected(): # NoData msg_buf = WriteBuffer.new_message(b'n') buf.write_buffer(msg_buf.end_message()) # CommandComplete msg_buf = WriteBuffer.new_message(b'C') assert isinstance( action.query_unit.command_complete_tag, dbstate.TagPlain, ), "emulated SQL unit has no command_tag" plain = action.query_unit.command_complete_tag msg_buf.write_bytestring(plain.tag) buf.write_buffer(msg_buf.end_message()) dbv.on_success(action.query_unit) fe_conn.on_success(action.query_unit) elif action.action == PGAction.CLOSE_PORTAL: dbv.close_portal_if_exists(action.orig_portal_name) if not action.is_injected(): msg_buf = WriteBuffer.new_message(b'3') # CloseComplete buf.write_buffer(msg_buf.end_message()) elif action.action == PGAction.CLOSE_STMT: if not action.is_injected(): msg_buf = WriteBuffer.new_message(b'3') # CloseComplete buf.write_buffer(msg_buf.end_message()) if ( action.action == PGAction.DESCRIBE_STMT or action.action == PGAction.DESCRIBE_PORTAL ): if action.query_unit.set_vars is not None: msg_buf = WriteBuffer.new_message(b'n') # NoData buf.write_buffer(msg_buf.end_message()) continue row_count = 0 while True: if not self.con.buffer.take_message(): if buf.len() > 0: fe_conn.write(buf) fe_conn.flush() buf = WriteBuffer.new() await self.con.wait_for_message() mtype = self.con.buffer.get_message_type() if self.con.debug: self.con.debug_print(f'recv backend message: {chr(mtype)!r}') if ignore_till_sync: self.con.debug_print("ignoring until SYNC") if ignore_till_sync and mtype != b'Z': self.con.buffer.discard_message() continue if ( mtype == b'3' and action.action != PGAction.CLOSE_PORTAL and action.action != PGAction.CLOSE_STMT ): # before_prepare() initiates LRU cleanup for # prepared statements, so CloseComplete may # appear here. 
self.con.buffer.discard_message() continue # ParseComplete if mtype == b'1' and action.action == PGAction.PARSE: self.con.buffer.finish_message() if self.con.debug: self.con.debug_print('PARSE COMPLETE MSG') if action.stmt_name: self.con.prep_stmts[action.stmt_name] = dbver if not action.is_injected(): msg_buf = WriteBuffer.new_message(mtype) buf.write_buffer(msg_buf.end_message()) break # BindComplete elif mtype == b'2' and action.action == PGAction.BIND: self.con.buffer.finish_message() if self.con.debug: self.con.debug_print('BIND COMPLETE MSG') if action.query_unit is not None: dbv.create_portal( action.orig_portal_name, action.query_unit ) if not action.is_injected(): msg_buf = WriteBuffer.new_message(mtype) buf.write_buffer(msg_buf.end_message()) break elif ( # RowDescription or NoData mtype == b'T' or mtype == b'n' ) and ( action.action == PGAction.DESCRIBE_STMT or action.action == PGAction.DESCRIBE_STMT_ROWS or action.action == PGAction.DESCRIBE_PORTAL ): data = self.con.buffer.consume_message() if self.con.debug: self.con.debug_print('END OF DESCRIBE', mtype) if ( mtype == b'T' and action.query_unit is not None and isinstance( action.query_unit.command_complete_tag, dbstate.TagUnpackRow, ) ): # TagUnpackRow converts RowDescription into NoData msg_buf = WriteBuffer.new_message(b'n') buf.write_buffer(msg_buf.end_message()) elif not action.is_injected() and not ( mtype == b'n' and action.action == PGAction.DESCRIBE_STMT_ROWS ): msg_buf = WriteBuffer.new_message(mtype) msg_buf.write_bytes(data) buf.write_buffer(msg_buf.end_message()) break elif ( mtype == b't' # ParameterDescription and action.action == PGAction.DESCRIBE_STMT_ROWS ): self.con.buffer.consume_message() elif ( mtype == b't' # ParameterDescription ): # remap parameter descriptions # The "external" parameters (that are visible to the user) # don't include the internal params for globals and # extracted constants. # This chunk of code remaps the descriptions of internal # params into external ones. 
self.con.buffer.read_int16() # count_internal data_internal = self.con.buffer.consume_message() msg_buf = WriteBuffer.new_message(b't') external_params: int64_t = 0 if ( action.query_unit is not None and action.query_unit.params ): for index, param in enumerate(action.query_unit.params): if not isinstance(param, dbstate.SQLParamExternal): break external_params = index + 1 msg_buf.write_int16(external_params) msg_buf.write_bytes(data_internal[0:external_params * 4]) buf.write_buffer(msg_buf.end_message()) elif ( mtype == b'T' # RowDescription and action.action == PGAction.EXECUTE and action.query_unit is not None and isinstance( action.query_unit.command_complete_tag, dbstate.TagUnpackRow, ) ): data = self.con.buffer.consume_message() # tell the frontend connection that there is NoData # because we intercept and unpack the DataRow. msg_buf = WriteBuffer.new_message(b'n') buf.write_buffer(msg_buf.end_message()) elif ( mtype == b'D' # DataRow and action.action == PGAction.EXECUTE and action.query_unit is not None and isinstance( action.query_unit.command_complete_tag, dbstate.TagUnpackRow, ) ): # unpack a single row with a single column data = self.con.buffer.consume_message() field_size = read_int32(data[2:6]) val_bytes = data[6:6 + field_size] row_count = int.from_bytes(val_bytes, "big", signed=True) elif ( # CommandComplete, EmptyQueryResponse, PortalSuspended mtype == b'C' or mtype == b'I' or mtype == b's' ) and action.action == PGAction.EXECUTE: data = self.con.buffer.consume_message() if self.con.debug: self.con.debug_print('END OF EXECUTE', mtype) if action.query_unit is not None: fe_conn.on_success(action.query_unit) dbv.on_success(action.query_unit) if action.query_unit.prepare is not None: be_stmt_name = action.query_unit.prepare.be_stmt_name if be_stmt_name: if self.con.debug: self.con.debug_print( f"remembering ps {be_stmt_name}, " f"dbver {dbver}" ) self.con.prep_stmts[be_stmt_name] = dbver if ( not action.is_injected() and action.query_unit is not None 
and action.query_unit.command_complete_tag ): tag = action.query_unit.command_complete_tag msg_buf = WriteBuffer.new_message(mtype) if isinstance(tag, dbstate.TagPlain): msg_buf.write_bytestring(tag.tag) elif isinstance(tag, (dbstate.TagCountMessages, dbstate.TagUnpackRow)): msg_buf.write_bytes(bytes(tag.prefix, "utf-8")) # This should return the number of modified rows by # the top-level query, but we are returning the # count of rows in the response. These two will # always match because our compiled DML will always # have a top-level SELECT with the same number of rows # as the DML stmt somewhere in the CTEs. msg_buf.write_str(str(row_count), "utf-8") buf.write_buffer(msg_buf.end_message()) elif not action.is_injected(): msg_buf = WriteBuffer.new_message(mtype) msg_buf.write_bytes(data) buf.write_buffer(msg_buf.end_message()) break # CloseComplete elif mtype == b'3' and action.action == PGAction.CLOSE_PORTAL: self.con.buffer.finish_message() if self.con.debug: self.con.debug_print('CLOSE COMPLETE MSG (PORTAL)') dbv.close_portal_if_exists(action.orig_portal_name) if not action.is_injected(): msg_buf = WriteBuffer.new_message(mtype) buf.write_buffer(msg_buf.end_message()) break elif mtype == b'3' and action.action == PGAction.CLOSE_STMT: self.con.buffer.finish_message() if self.con.debug: self.con.debug_print('CLOSE COMPLETE MSG (STATEMENT)') if not action.is_injected(): msg_buf = WriteBuffer.new_message(mtype) buf.write_buffer(msg_buf.end_message()) break elif mtype == b'E': # ErrorResponse rv = False if self.con.debug: self.con.debug_print('ERROR RESPONSE MSG') if action.query_unit is not None: fe_conn.on_error(action.query_unit) dbv.on_error() self._rewrite_sql_error_response(action, buf) fe_conn.write(buf) fe_conn.flush() buf = WriteBuffer.new() ignore_till_sync = True break elif mtype == b'Z': # ReadyForQuery ignore_till_sync = False dbv.end_implicit() status = self.con.parse_sync_message() msg_buf = WriteBuffer.new_message(b'Z') msg_buf.write_byte(status)
buf.write_buffer(msg_buf.end_message()) fe_conn.write(buf) fe_conn.flush() return True, True else: if not action.is_injected(): if self.con.debug: self.con.debug_print('REDIRECT OTHER MSG', mtype) messages_redirected = self.con.buffer.redirect_messages( buf, mtype, 0 ) # DataRow if mtype == b'D': row_count += messages_redirected else: logger.warning( f"discarding unexpected backend message: " f"{chr(mtype)!r}" ) self.con.buffer.discard_message() if buf.len() > 0: fe_conn.write(buf) return rv, False cdef _rewrite_sql_error_response(self, PGMessage action, WriteBuffer buf): cdef WriteBuffer msg_buf if action.action == PGAction.PARSE: msg_buf = WriteBuffer.new_message(b'E') while True: field_type = self.con.buffer.read_byte() if field_type == b'P': # Position if action.query_unit is None: source_map = None offset = 0 else: qu = action.query_unit source_map = qu.source_map offset = -qu.prefix_len self.con._write_error_position( msg_buf, action.args[0], self.con.buffer.read_null_str(), source_map, offset, ) continue else: msg_buf.write_byte(field_type) if field_type == b'\0': break msg_buf.write_bytestring( self.con.buffer.read_null_str() ) self.con.buffer.finish_message() buf.write_buffer(msg_buf.end_message()) elif action.action in ( PGAction.BIND, PGAction.EXECUTE, PGAction.DESCRIBE_PORTAL, PGAction.CLOSE_PORTAL, ): portal_name = action.orig_portal_name msg_buf = WriteBuffer.new_message(b'E') message = None while True: field_type = self.con.buffer.read_byte() if field_type == b'C': # Code msg_buf.write_byte(field_type) code = self.con.buffer.read_null_str() msg_buf.write_bytestring(code) if code == b'34000': message = f'cursor "{portal_name}" does not exist' elif code == b'42P03': message = f'cursor "{portal_name}" already exists' elif field_type == b'M' and message: msg_buf.write_byte(field_type) msg_buf.write_bytestring( message.encode('utf-8') ) elif field_type == b'P': if action.query_unit is not None: qu = action.query_unit query_text = qu.query.encode("utf-8") 
if qu.prepare is not None: offset = -55 source_map = qu.prepare.source_map else: offset = 0 source_map = qu.source_map offset -= qu.prefix_len else: query_text = b"" source_map = None offset = 0 self.con._write_error_position( msg_buf, query_text, self.con.buffer.read_null_str(), source_map, offset, ) else: msg_buf.write_byte(field_type) if field_type == b'\0': break msg_buf.write_bytestring( self.con.buffer.read_null_str() ) self.con.buffer.finish_message() buf.write_buffer(msg_buf.end_message()) else: data = self.con.buffer.consume_message() msg_buf = WriteBuffer.new_message(b'E') msg_buf.write_bytes(data) buf.write_buffer(msg_buf.end_message()) ================================================ FILE: edb/server/pgcon/rust_transport.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2024-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ This module implements a Rust-based transport for PostgreSQL connections. The PGRawConn class provides a high-level interface for establishing and managing PostgreSQL connections using a Rust-implemented state machine. It handles the complexities of connection establishment, including SSL negotiation and authentication, while presenting a simple asyncio-like transport interface to the caller. 
""" from __future__ import annotations from typing import Optional, Protocol, Callable, Any import asyncio import ssl as ssl_module import socket import warnings from enum import Enum, auto from dataclasses import dataclass from edb.server._rust_native import _pg_rust as pgrust from edb.server.pgconnparams import ( ConnectionParams, SSLMode, get_pg_home_directory, ) from . import errors as pgerror TCP_KEEPIDLE = 24 TCP_KEEPINTVL = 2 TCP_KEEPCNT = 3 DEFAULT_CONNECT_TIMEOUT = 60 class ConnectionStateType(Enum): CONNECTING = 0 SSL_CONNECTING = auto() AUTHENTICATING = auto() SYNCHRONIZING = auto() READY = auto() class Authentication(Enum): NONE = 0 TRUST = auto() PASSWORD = auto() MD5 = auto() SCRAM_SHA256 = auto() @dataclass class PGState: parameters: dict[str, str] cancellation_key: Optional[tuple[int, int]] auth: Optional[Authentication] server_error: Optional[list[tuple[str, str]]] ssl: bool class ConnectionStateUpdate(Protocol): def send(self, message: memoryview) -> None: ... def upgrade(self) -> None: ... def parameter(self, name: str, value: str) -> None: ... def cancellation_key(self, pid: int, key: int) -> None: ... def state_changed(self, state: int) -> None: ... def auth(self, auth: int) -> None: ... 
StateChangeCallback = Callable[[ConnectionStateType], None] def _parse_tls_version(tls_version: str) -> ssl_module.TLSVersion: if tls_version.startswith('SSL'): raise ValueError(f"Unsupported TLS version: {tls_version}") try: return ssl_module.TLSVersion[tls_version.replace('.', '_')] except KeyError: raise ValueError(f"No such TLS version: {tls_version}") def _create_ssl(ssl_config: dict[str, Any]): sslmode = SSLMode.parse(ssl_config['sslmode']) ssl = ssl_module.SSLContext(ssl_module.PROTOCOL_TLS_CLIENT) ssl.check_hostname = sslmode >= SSLMode.verify_full if sslmode < SSLMode.require: ssl.verify_mode = ssl_module.CERT_NONE else: if ssl_config['sslrootcert']: ssl.load_verify_locations(ssl_config['sslrootcert']) ssl.verify_mode = ssl_module.CERT_REQUIRED else: if sslmode == SSLMode.require: ssl.verify_mode = ssl_module.CERT_NONE if ssl_config['sslcrl']: ssl.load_verify_locations(ssl_config['sslcrl']) ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN if ssl_config['sslkey'] and ssl_config['sslcert']: ssl.load_cert_chain( ssl_config['sslcert'], ssl_config['sslkey'], ssl_config['sslpassword'] or '', ) if ssl_config['ssl_max_protocol_version']: ssl.maximum_version = _parse_tls_version( ssl_config['ssl_max_protocol_version'] ) if ssl_config['ssl_min_protocol_version']: ssl.minimum_version = _parse_tls_version( ssl_config['ssl_min_protocol_version'] ) # OpenSSL 1.1.1 keylog file if hasattr(ssl, 'keylog_filename'): if ssl_config['keylog_filename']: ssl.keylog_filename = ssl_config['keylog_filename'] return ssl class PGConnectionProtocol(asyncio.Protocol): """A protocol that manages the initial connection and authentication process for PostgreSQL. This protocol acts as an intermediary between the raw socket connection and the user's protocol. 
""" def __init__( self, hostname: Optional[str], state: pgrust.PyConnectionState, pg_state: PGState, complete_callback: Callable[ [asyncio.BaseTransport], tuple[PGRawConn, asyncio.Protocol] ], ): self.state = state self.pg_state = pg_state self.ready_future: asyncio.Future = asyncio.Future() self.ready_future.add_done_callback(self._cleanup) self._complete_callback = complete_callback self._host = hostname self._transport: Optional[asyncio.Transport] = None def _cleanup(self, _fut: asyncio.Future) -> None: # IMPORTANT: break Python/Rust ref cycle self.state.update = None self.state = None def data_received(self, data: bytes): if self.ready_future.done(): return try: self.state.drive_message(memoryview(data)) if self.state.is_ready(): assert self._transport is not None self.ready_future.set_result( self._complete_callback(self._transport) ) except Exception as e: if not self.ready_future.done(): self.ready_future.set_exception(ConnectionError(e)) def connection_lost(self, exc): if self.ready_future.done(): return if exc: self.ready_future.set_exception(exc) else: ex = pgerror.new( pgerror.ERROR_CONNECTION_FAILURE, "Unexpected connection error", ) ex.__cause__ = exc self.ready_future.set_exception(ex) # This may be called multiple times if we upgrade to SSL def connection_made(self, transport) -> None: try: if self._transport is None: # Initial connection self._transport = transport self.state.update = self self.state.drive_initial() else: # Upgrade path self._transport = transport except Exception: pass return super().connection_made(transport) def send(self, message: memoryview) -> None: assert self._transport is not None self._transport.write(bytes(message)) def upgrade(self) -> None: asyncio.create_task(self._upgrade_to_ssl()) async def _upgrade_to_ssl(self): transport = self._transport assert transport is not None try: ssl_context = _create_ssl(self.state.config) loop = asyncio.get_running_loop() new_transport = await loop.start_tls( transport, self, 
ssl_context, server_side=False, ssl_handshake_timeout=None, server_hostname=self._host, ) self._transport = new_transport self.state.drive_ssl_ready() self.pg_state.ssl = True except Exception as e: if not self.ready_future.done(): self.ready_future.set_exception(e) transport.abort() def parameter(self, name: str, value: str) -> None: self.pg_state.parameters[name] = value def cancellation_key(self, pid: int, key: int) -> None: self.pg_state.cancellation_key = (pid, key) def state_changed(self, _: int) -> None: pass def auth(self, auth: int) -> None: self.pg_state.auth = Authentication(auth) def server_error(self, error: list[tuple[str, str]]) -> None: if not self.ready_future.done(): self.ready_future.set_exception( pgerror.BackendConnectionError(fields=dict(error)) ) class PGRawConn(asyncio.Transport): def __init__( self, source_description: Optional[str], connection: ConnectionParams, raw_transport: asyncio.Transport, pg_state: PGState, addr: tuple[str, int], ): super().__init__() self.source_description = source_description self.connection = connection self.raw_transport = raw_transport self._pg_state = pg_state self.addr = addr @property def state(self): return self._pg_state def write(self, data: bytes | bytearray | memoryview): self.raw_transport.write(data) def close(self): self.raw_transport.close() def is_closing(self): return self.raw_transport.is_closing() def get_extra_info(self, name: str, default=None): return self.raw_transport.get_extra_info(name, default) def pause_reading(self): self.raw_transport.pause_reading() def resume_reading(self): self.raw_transport.resume_reading() def is_reading(self): return self.raw_transport.is_reading() def abort(self): self.raw_transport.abort() def __repr__(self): params = ', '.join( f"{k}={v}" for k, v in self._pg_state.parameters.items() ) auth_str = ( f", auth={self._pg_state.auth.name}" if self._pg_state.auth else "" ) source_str = ( f", source={self.source_description}" if self.source_description else "" ) 
raw_repr = repr(self.raw_transport) dsn = self.connection._params return ( f"" ) def __del__(self): if not self.raw_transport.is_closing(): warnings.warn( f"unclosed connection {repr(self)}", ResourceWarning, stacklevel=2, ) self.raw_transport.abort() async def _create_connection_to( protocol_factory: Callable[[Optional[str], str, int], PGConnectionProtocol], address_family: str, host: str | bytes, hostname: str, port: int, ) -> tuple[asyncio.Transport, PGConnectionProtocol]: if address_family == "unix": t, protocol = await asyncio.get_running_loop().create_unix_connection( lambda: protocol_factory(None, hostname, port), path=host # type: ignore ) return (t, protocol) else: t, protocol = await asyncio.get_running_loop().create_connection( lambda: protocol_factory(hostname, hostname, port), str(host), port ) _set_tcp_keepalive(t) return (t, protocol) async def _create_connection( protocol_factory: Callable[[Optional[str], str, int], PGConnectionProtocol], connect_timeout: Optional[int], host_candidates: list[tuple[str, str | bytes, str, int]], ) -> tuple[asyncio.Transport, PGConnectionProtocol]: e = None for address_family, host, hostname, port in host_candidates: try: async with asyncio.timeout( connect_timeout if connect_timeout else DEFAULT_CONNECT_TIMEOUT ): return await _create_connection_to( protocol_factory, address_family, host, hostname, port ) except asyncio.TimeoutError as ex: raise pgerror.new( pgerror.ERROR_CONNECTION_FAILURE, "timed out connecting to backend", ) from ex except Exception as ex: e = ex continue raise ConnectionError( f"Failed to connect to any of the provided hosts: {host_candidates}" ) from e def _set_tcp_keepalive(transport): # TCP keepalive was initially added here for special cases where idle # connections are dropped silently on GitHub Action running test suite # against AWS RDS. 
We are keeping the TCP keepalive for generic # Postgres connections as the kernel overhead is considered low, and # in certain cases it does save us some reconnection time. # # In case of high-availability Postgres, TCP keepalive is necessary to # disconnect from a failing master node, if no other failover information # is available. sock = transport.get_extra_info('socket') sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) # TCP_KEEPIDLE: the time (in seconds) the connection needs to remain idle # before TCP starts sending keepalive probes. This is socket.TCP_KEEPIDLE # on Linux, and socket.TCP_KEEPALIVE on macOS from Python 3.10. if hasattr(socket, 'TCP_KEEPIDLE'): sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, TCP_KEEPIDLE) if hasattr(socket, 'TCP_KEEPALIVE'): sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPALIVE, TCP_KEEPIDLE) # TCP_KEEPINTVL: The time (in seconds) between individual keepalive probes. if hasattr(socket, 'TCP_KEEPINTVL'): sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, TCP_KEEPINTVL) # TCP_KEEPCNT: The maximum number of keepalive probes TCP should send # before dropping the connection. if hasattr(socket, 'TCP_KEEPCNT'): sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, TCP_KEEPCNT) def complete_connection_callback( host, port, source_description, state, protocol_factory, pg_state ) -> Callable[[asyncio.BaseTransport], tuple[PGRawConn, asyncio.Protocol]]: def complete_connection(upgraded_transport): conn = PGRawConn( source_description, ConnectionParams._create(state.config), upgraded_transport, pg_state, (host, port), ) # We've successfully upgraded the protocol at this point, and the remote # PG server is sitting in the idle state, waiting for us to send the # next message. We transition to the user protocol here, synthesizing # a `connection_made` event. 
user_protocol = protocol_factory() upgraded_transport.set_protocol(user_protocol) # Notify the user protocol of successful connection user_protocol.connection_made(conn) return conn, user_protocol return complete_connection async def create_postgres_connection[P: asyncio.Protocol]( dsn: str | ConnectionParams, protocol_factory: Callable[[], P], *, source_description: Optional[str] = None, ) -> tuple[PGRawConn, P]: """ Open a PostgreSQL connection to the address specified by the DSN or ConnectionParams, creating the user protocol from the protocol_factory. This method establishes the connection asynchronously. When successful, it returns a (PGRawConn, protocol) pair. """ if isinstance(dsn, str): dsn = ConnectionParams(dsn=dsn) connect_timeout = dsn.connect_timeout try: state = pgrust.PyConnectionState( dsn._params, "postgres", get_pg_home_directory() ) except Exception as e: raise ValueError(e) pg_state = PGState( parameters={}, cancellation_key=None, auth=None, server_error=None, ssl=False, ) # The PGConnectionProtocol will drive the PyConnectionState from network # bytes it receives, as well as driving the connection from the messages # from PyConnectionState. connect_protocol_factory = ( lambda hostname, host, port: PGConnectionProtocol( hostname, state, pg_state, complete_connection_callback( host, port, source_description, state, protocol_factory, pg_state, ), ) ) # Create a transport to the backend based off the host candidates. host_candidates = await asyncio.get_running_loop().run_in_executor( executor=None, func=lambda: state.config.host_candidates ) _, protocol = await _create_connection( connect_protocol_factory, connect_timeout, host_candidates, ) conn, user_protocol = await protocol.ready_future return conn, user_protocol ================================================ FILE: edb/server/pgconnparams.py ================================================ # Copyright (C) 2016-present MagicStack Inc. and the EdgeDB authors. 
# Copyright (C) 2016-present the asyncpg authors and contributors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import TypedDict, NotRequired, Optional, Unpack, Self, Any import enum import pathlib import platform from edb.server._rust_native._pg_rust import PyConnectionParams _system = platform.uname().system if _system == 'Windows': import ctypes.wintypes CSIDL_APPDATA = 0x001A def get_pg_home_directory() -> Optional[str]: # We cannot simply use expanduser() as that returns the user's # home directory, whereas Postgres stores its config in # %AppData% on Windows. 
buf = ctypes.create_unicode_buffer(ctypes.wintypes.MAX_PATH) r = ctypes.windll.shell32.SHGetFolderPathW( # type: ignore 0, CSIDL_APPDATA, 0, 0, buf ) if r: return None else: return str(pathlib.Path(buf.value) / 'postgresql') else: def get_pg_home_directory() -> Optional[str]: try: return str(pathlib.Path.home()) except RuntimeError: # This can happen if the home directory is not set return None class SSLMode(enum.IntEnum): disable = 0 allow = 1 prefer = 2 require = 3 verify_ca = 4 verify_full = 5 @classmethod def parse(cls, sslmode: str) -> Self: value: Self = getattr(cls, sslmode.replace('-', '_')) assert value is not None, f"Invalid SSL mode: {sslmode}" return value class CreateParamsKwargs(TypedDict, total=False): dsn: NotRequired[str] hosts: NotRequired[Optional[list[tuple[str, int]]]] host: NotRequired[Optional[str]] user: NotRequired[Optional[str]] password: NotRequired[Optional[str]] database: NotRequired[Optional[str]] server_settings: NotRequired[Optional[dict[str, str]]] sslmode: NotRequired[Optional[SSLMode]] sslrootcert: NotRequired[Optional[str]] connect_timeout: NotRequired[Optional[int]] class ConnectionParams: """ A Python representation of the Rust connection parameters that are passed back during connection/parse. This class encapsulates the connection parameters used for establishing a connection to a PostgreSQL database. 
""" _params: PyConnectionParams def __init__(self, **kwargs: Unpack[CreateParamsKwargs]) -> None: dsn = kwargs.pop("dsn", None) if dsn: self._params = PyConnectionParams(dsn) else: self._params = PyConnectionParams(None) self.update(**kwargs) @classmethod def _create( cls, params: dict[str, Any], ) -> Self: instance = super().__new__(cls) instance._params = params return instance def update(self, **kwargs: Unpack[CreateParamsKwargs]) -> None: if dsn := kwargs.pop('dsn', None): params = PyConnectionParams(dsn) for k, v in params.to_dict().items(): self._params[k] = v if server_settings := kwargs.pop("server_settings", None): for k2, v2 in server_settings.items(): self._params.update_server_settings(k2, v2) if host_specs := kwargs.pop("hosts", None): hosts, ports = zip(*host_specs) self._params['host'] = ','.join(hosts) self._params['port'] = ','.join(map(str, ports)) if (ssl_mode := kwargs.pop("sslmode", None)) is not None: mode: SSLMode = ssl_mode self._params["sslmode"] = mode.name if connect_timeout := kwargs.pop("connect_timeout", None): self._params["connect_timeout"] = str(connect_timeout) for k, v in kwargs.items(): if k == "database": k = "dbname" self._params[k] = v def clear_server_settings(self) -> None: self._params.clear_server_settings() def resolve(self) -> Self: return self._create( self._params.resolve("", get_pg_home_directory()), ) def __copy__(self) -> Self: return self._create(self._params.clone()) @property def hosts(self) -> Optional[list[tuple[dict[str, Any], int]]]: return self._params['hosts'] # type: ignore @property def host(self) -> Optional[str]: return self._params['host'] # type: ignore @property def port(self) -> Optional[int]: return self._params['port'] # type: ignore @property def user(self) -> Optional[str]: return self._params['user'] # type: ignore @property def password(self) -> Optional[str]: return self._params['password'] # type: ignore @property def database(self) -> Optional[str]: return self._params['dbname'] # type: 
ignore @property def connect_timeout(self) -> Optional[int]: connect_timeout = self._params['connect_timeout'] return int(connect_timeout) if connect_timeout else None @property def sslmode(self) -> Optional[SSLMode]: sslmode = self._params['sslmode'] return SSLMode.parse(sslmode) if sslmode is not None else None def to_dsn(self) -> str: dsn: str = self._params.to_dsn() return dsn @property def __dict__(self) -> dict[str, Any]: to_dict: dict[str, str] = self._params.to_dict() database = to_dict.pop('dbname', None) if database: to_dict['database'] = database return to_dict @__dict__.setter def __dict__(self, value: dict[str, Any]) -> None: new_params = self._params.__class__() try: for k, v in value.items(): new_params[k] = v self._params = new_params except Exception: raise ValueError("Failed to update __dict__") def __repr__(self) -> Any: return self._params.__repr__() ================================================ FILE: edb/server/protocol/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from . 
import protocol HttpProtocol = protocol.HttpProtocol __all__ = ('HttpProtocol',) ================================================ FILE: edb/server/protocol/ai_ext.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from dataclasses import dataclass, field from typing import ( cast, Any, AsyncIterator, ClassVar, Literal, NoReturn, Optional, Sequence, TYPE_CHECKING, ) import abc import asyncio import contextlib import contextvars import itertools import json import logging import uuid import tiktoken from mistral_common.tokens.tokenizers import mistral as mistral_tokenizer from edb import errors from edb.common import asyncutil from edb.common import debug from edb.common import enum as s_enum from edb.common import markup from edb.common import uuidgen from edb.common.typeutils import not_none from edb.server import compiler, http from edb.server import defines as edbdef from edb.server.compiler import sertypes from edb.server.protocol import execute from edb.server.protocol import request_scheduler as rs if TYPE_CHECKING: from edb.server import dbview from edb.server import tenant as srv_tenant from edb.server import pgcon from edb.server.protocol import protocol logger = logging.getLogger("edb.server.ai_ext") keepalive_token = "ai-index-builder" class AIExtError(Exception): 
http_status: ClassVar[http.HTTPStatus] = ( http.HTTPStatus.INTERNAL_SERVER_ERROR) def __init__( self, *args: object, json: Optional[dict[str, Any]] = None, ) -> None: super().__init__(*args) self._json = json def get_http_status(self) -> http.HTTPStatus: return self.__class__.http_status def json(self) -> dict[str, Any]: if self._json is not None: return self._json else: return { "message": str(self.args[0]), "type": self.__class__.__name__, } class AIProviderError(AIExtError): pass class ConfigurationError(AIExtError): pass class InternalError(AIExtError): pass class BadRequestError(AIExtError): http_status = http.HTTPStatus.BAD_REQUEST class ApiStyle(s_enum.StrEnum): OpenAI = 'OpenAI' Anthropic = 'Anthropic' Ollama = 'Ollama' class Tokenizer(abc.ABC): @abc.abstractmethod def encode(self, text: str) -> list[int]: """Encode text into tokens.""" raise NotImplementedError @abc.abstractmethod def encode_padding(self) -> int: """How many special characters are added to encodings?""" raise NotImplementedError @abc.abstractmethod def decode(self, tokens: list[int]) -> str: """Decode tokens into text.""" raise NotImplementedError def shorten_to_token_length( self, text: str, token_length: int ) -> tuple[str, int]: """Truncate text to a maximum token length.""" encoded = self.encode(text) if len(encoded) > token_length: encoded = encoded[:token_length] return self.decode(encoded), len(encoded) class OpenAITokenizer(Tokenizer): _instances: dict[str, OpenAITokenizer] = {} encoding: Any @classmethod def for_model(cls, model_name: str) -> OpenAITokenizer: if model_name in cls._instances: return cls._instances[model_name] tokenizer = OpenAITokenizer() tokenizer.encoding = tiktoken.encoding_for_model(model_name) cls._instances[model_name] = tokenizer return tokenizer def encode(self, text: str) -> list[int]: return cast(list[int], self.encoding.encode(text)) def encode_padding(self) -> int: return 0 def decode(self, tokens: list[int]) -> str: return cast(str, 
self.encoding.decode(tokens)) class MistralTokenizer(Tokenizer): _instances: dict[str, MistralTokenizer] = {} tokenizer: Any @classmethod def for_model(cls, model_name: str) -> MistralTokenizer: if model_name in cls._instances: return cls._instances[model_name] assert model_name == 'mistral-embed' tokenizer = MistralTokenizer() tokenizer.tokenizer = mistral_tokenizer.MistralTokenizer.v1() cls._instances[model_name] = tokenizer return tokenizer def encode(self, text: str) -> list[int]: # V1 tokenizer wraps input text with control tokens [INST] [/INST]. # # While these count towards the overall token limit, how special tokens # are applied to embedding requests is not documented. For now, directly # pass the text into the inner tokenizer. tokenized = self.tokenizer.instruct_tokenizer.tokenizer.encode( text, bos=False, eos=False ) return cast(list[int], tokenized) def encode_padding(self) -> int: # V1 tokenizer wraps input text with control tokens [INST] [/INST]. # # This is only 2 tokens, and testing shows that mistral-embed does add # two tokens to embeddings inputs. However, this is not documented, so # add some extra leeway in case things change. # # Note, other models may use significantly more control tokens. return 16 def decode(self, tokens: list[int]) -> str: return cast(str, self.tokenizer.decode(tokens)) class OllamaTokenizer(Tokenizer): """ Simply counts the number of characters. A tokenizer API is in progress, but unlikely to be released soon.
""" _instances: dict[str, OllamaTokenizer] = {} @classmethod def for_model(cls, model_name: str) -> OllamaTokenizer: if model_name in cls._instances: return cls._instances[model_name] tokenizer = OllamaTokenizer() cls._instances[model_name] = tokenizer return tokenizer def encode(self, text: str) -> list[int]: return [ord(c) for c in text] def encode_padding(self) -> int: return 0 def decode(self, tokens: list[int]) -> str: return ''.join(chr(c) for c in tokens) class TestTokenizer(Tokenizer): _instances: dict[str, TestTokenizer] = {} @classmethod def for_model(cls, model_name: str) -> TestTokenizer: if model_name in cls._instances: return cls._instances[model_name] tokenizer = TestTokenizer() cls._instances[model_name] = tokenizer return tokenizer def encode(self, text: str) -> list[int]: return [ord(c) for c in text] def encode_padding(self) -> int: return 0 def decode(self, tokens: list[int]) -> str: return ''.join(chr(c) for c in tokens) def get_model_tokenizer( provider_name: str, model_name: str, ) -> Optional[Tokenizer]: """Get the tokenizer for a given provider and model""" if provider_name == 'builtin::openai': return OpenAITokenizer.for_model(model_name) elif provider_name == 'builtin::mistral': return MistralTokenizer.for_model(model_name) if provider_name == 'builtin::ollama': return OllamaTokenizer.for_model(model_name) elif provider_name == 'custom::test': return TestTokenizer.for_model(model_name) else: return None @dataclass(frozen=True) class ProviderConfig: name: str display_name: str api_url: str client_id: str secret: str api_style: ApiStyle def get_embeddings_from_result( self, embeddings_result: bytes ) -> list[list[float]]: """Decode and extract the embeddings from an embeddings request.""" decoded_result = json.loads( embeddings_result.decode("utf-8") ) if self.api_style == ApiStyle.Ollama: return cast( list[list[float]], decoded_result["embeddings"], ) else: return cast( list[list[float]], [ entry_result["embedding"] for entry_result in 
decoded_result["data"] ], ) @dataclass(frozen=True, kw_only=True) class BaseModel: name: str provider: str name_annotation: ClassVar[str] = "ext::ai::model_name" provider_annotation: ClassVar[str] = "ext::ai::model_provider" @dataclass(frozen=True, kw_only=True) class EmbeddingModel (BaseModel): max_input_tokens: int max_batch_tokens: int max_batch_size: int | None max_output_dimensions: int supports_shortening: bool gel_type: ClassVar[str] = "ext::ai::EmbeddingModel" max_model_input_tokens_annotation: ClassVar[str] = ( "ext::ai::embedding_model_max_input_tokens" ) max_batch_tokens_annotation: ClassVar[str] = ( "ext::ai::embedding_model_max_batch_tokens" ) max_batch_size_annotation: ClassVar[str] = ( "ext::ai::embedding_model_max_batch_size" ) max_output_dimensions_annotation: ClassVar[str] = ( "ext::ai::embedding_model_max_output_dimensions" ) supports_shortening_annotation: ClassVar[str] = ( "ext::ai::embedding_model_supports_shortening" ) @dataclass(frozen=True, kw_only=True) class TextGenerationModel (BaseModel): gel_type: ClassVar[str] = "ext::ai::TextGenerationModel" def start_extension( tenant: srv_tenant.Tenant, dbname: str, ) -> None: task_name = _get_builder_task_name(dbname) task = tenant.get_task(task_name) if task is None: logger.info(f"starting AI extension tasks on database {dbname!r}") tenant.create_task( _ext_ai_index_builder_controller_loop(tenant, dbname), interruptable=True, name=task_name, ) def stop_extension( tenant: srv_tenant.Tenant, dbname: str, ) -> None: task_name = _get_builder_task_name(dbname) task = tenant.get_task(task_name) if task is not None: logger.info(f"stopping AI extension tasks on database {dbname!r}") task.cancel() def _get_builder_task_name(dbname: str) -> str: return f"ext::ai::index builder on database {dbname!r}" _task_name = contextvars.ContextVar( "ext_ai_index_builder_task_name", default="-") async def _ext_ai_index_builder_controller_loop( tenant: srv_tenant.Tenant, dbname: str, ) -> None: task_name = 
_get_builder_task_name(dbname) _task_name.set(task_name) logger.info(f"started {task_name}") db = tenant.get_db(dbname=dbname) holding_lock = False await db.introspection() naptime_cfg = db.lookup_config("ext::ai::Config::indexer_naptime") naptime = naptime_cfg.to_microseconds() / 1000000 provider_schedulers: dict[str, ProviderScheduler] = {} try: while tenant.accept_new_tasks: if not db.tenant.is_database_connectable(dbname): # Don't do work if the database is not connectable, # e.g. being dropped await asyncio.sleep(naptime) continue models = [] sleep_timer: rs.Timer = rs.Timer(None, False) try: async with tenant.with_pgcon(dbname) as pgconn: models = await _ext_ai_fetch_active_models(pgconn) if models: if not holding_lock: holding_lock = await _ext_ai_lock(tenant, pgconn) if holding_lock: provider_contexts = _prepare_provider_contexts( db, pgconn, tenant.get_http_client(originator="ai/index"), models, provider_schedulers, naptime, ) try: sleep_timer = ( await _ext_ai_index_builder_work( provider_schedulers, provider_contexts, ) ) finally: if not sleep_timer.is_ready_and_urgent(): await asyncutil.deferred_shield( _ext_ai_unlock(tenant)) tenant.server.remove_keepalive_token( ( keepalive_token, tenant.get_instance_name(), dbname, ) ) holding_lock = False except Exception: logger.error( f"caught error in {task_name}", exc_info=tenant.accept_new_tasks, ) if not tenant.accept_new_tasks: break if not sleep_timer.is_ready_and_urgent(): delay = sleep_timer.remaining_time(naptime) if delay == naptime: logger.debug( f"{task_name}: " f"No work. Napping for {naptime:.2f} seconds." 
) await asyncio.sleep(delay) finally: logger.info(f"stopped {task_name}") async def _ext_ai_fetch_active_models( pgconn: pgcon.PGConnection, ) -> list[tuple[int, str, str]]: models = await pgconn.sql_fetch( b""" SELECT id, name, provider FROM edgedbext.ai_active_embedding_models """, ) result = [] if models: for model in models: result.append(( int.from_bytes(model[0], byteorder="big", signed=True), model[1].decode("utf-8"), model[2].decode("utf-8"), )) return result # The _ext_ai_lock() is a long-term lock held in the system pgcon. It is used # in the index builder job above guarding multiple alternating database pgcons # and outgoing HTTP requests (free up pgcons while waiting for a response from # external services), so that different Gel tenants on the same backend # run this job exclusively. # # The following implementation is also safe to be used by multiple tasks within # the same tenant (though at the time of writing, there is only one such task # per tenant). To achieve this, we added an extra query on pg_locks to check if # it's already held by another task, because advisory locks allow reentrancy # from the same session (the same sys_pgcon). And to avoid race conditions, # we use another advisory lock over the 2 queries of check-and-lock in the # local session. This also means that one must call _ext_ai_lock() rather than # taking either of the 2 locks individually, to avoid misuse. # # If you are editing the magic numbers here: make sure they fit in a Postgres # Oid type (uint32), or you'll need to change the `classid` query below.
_EXT_AI_ADVISORY_LOCK = b"3987734540" _EXT_AI_ADVISORY_LOCK_LOCK = b"3987734541" async def _ext_ai_lock( tenant: srv_tenant.Tenant, pgconn: pgcon.PGConnection, ) -> bool: # We use transaction-level advisory locks to ensure releasing await pgconn.sql_execute(b"START TRANSACTION") try: b = await pgconn.sql_fetch_val( b"SELECT pg_try_advisory_xact_lock(" + _EXT_AI_ADVISORY_LOCK_LOCK + b")" ) if b == b'\x01': lock_free = await pgconn.sql_fetch_val( b''' SELECT NOT EXISTS ( SELECT 1 FROM pg_locks WHERE locktype = 'advisory' AND classid = 0 AND objid = \ ''' + _EXT_AI_ADVISORY_LOCK + b')') if lock_free == b'\x01': async with tenant.use_sys_pgcon() as syscon: # The long-term holding lock must be on session-level b = await syscon.sql_fetch_val( b"SELECT pg_try_advisory_lock(" + _EXT_AI_ADVISORY_LOCK + b")" ) return b == b'\x01' finally: await pgconn.sql_execute(b"ROLLBACK") return False async def _ext_ai_unlock( tenant: srv_tenant.Tenant, ) -> None: async with tenant.use_sys_pgcon() as syscon: await syscon.sql_fetch_val( b"SELECT pg_advisory_unlock(" + _EXT_AI_ADVISORY_LOCK + b")") def _prepare_provider_contexts( db: dbview.Database, pgconn: pgcon.PGConnection, http_client: http.HttpClient, models: list[tuple[int, str, str]], provider_schedulers: dict[str, ProviderScheduler], naptime: float, ) -> dict[str, ProviderContext]: models_by_provider: dict[str, list[str]] = {} for entry in models: model_name = entry[1] provider_name = entry[2] try: models_by_provider[provider_name].append(model_name) except KeyError: m = models_by_provider[provider_name] = [] m.append(model_name) # Drop any extra providers, they were probably deleted. 
unused_provider_names = { provider_name for provider_name in provider_schedulers.keys() if provider_name not in models_by_provider } for unused_provider_name in unused_provider_names: provider_schedulers.pop(unused_provider_name, None) # Create contexts provider_contexts = {} for provider_name, provider_models in models_by_provider.items(): if provider_name not in provider_schedulers: # Create new schedulers if necessary provider_schedulers[provider_name] = ProviderScheduler( service=rs.Service( limits={'requests': None, 'tokens': None}, ), provider_name=provider_name, ) provider_scheduler = provider_schedulers[provider_name] if not provider_scheduler.timer.is_ready(): continue provider_contexts[provider_name] = ProviderContext( naptime=naptime, db=db, pgconn=pgconn, http_client=http_client, provider_models=provider_models, ) return provider_contexts async def _ext_ai_index_builder_work( provider_schedulers: dict[str, ProviderScheduler], provider_contexts: dict[str, ProviderContext], ) -> rs.Timer: async with asyncio.TaskGroup() as g: for provider_name, provider_scheduler in provider_schedulers.items(): if provider_name not in provider_contexts: continue provider_context = provider_contexts[provider_name] g.create_task(provider_scheduler.process(provider_context)) sleep_timer = rs.Timer.combine( provider_scheduler.timer for provider_scheduler in provider_schedulers.values() ) if sleep_timer is not None: return sleep_timer else: # Return any non-urgent timer return rs.Timer(None, False) @dataclass(frozen=True) class EmbeddingsData: embeddings: bytes @dataclass class ProviderContext(rs.Context): db: dbview.Database pgconn: pgcon.PGConnection http_client: http.HttpClient provider_models: list[str] @dataclass class ProviderScheduler(rs.Scheduler[EmbeddingsData]): provider_name: str = '' # If a text is too long for a model, it will be excluded from embeddings # to prevent pointlessly wasting requests and tokens. 
# An embedding index may have its `truncate_to_max` flag switched. If the # flag is on, previously excluded inputs will be truncated and processed. model_excluded_ids: dict[str, list[str]] = field(default_factory=dict) async def get_params( self, context: rs.Context, ) -> Optional[Sequence[EmbeddingsParams]]: assert isinstance(context, ProviderContext) rv = await _generate_embeddings_params( context.db, context.pgconn, context.http_client, self.provider_name, context.provider_models, self.model_excluded_ids, tokens_rate_limit=( self.service.limits['tokens'].total if self.service.limits['tokens'] is not None else None ), ) if rv: context.db.server.add_keepalive_token( ( keepalive_token, context.db.tenant.get_instance_name(), context.db.name, ) ) return rv def finalize(self, execution_report: rs.ExecutionReport) -> None: task_name = _task_name.get() for message in execution_report.known_error_messages: logger.error( f"{task_name}: " f"Could not generate embeddings for {self.provider_name} " f"due to an internal error: {message}" ) @dataclass(frozen=True, kw_only=True) class EmbeddingsParams(rs.Params[EmbeddingsData]): pgconn: pgcon.PGConnection http_client: http.HttpClient provider: ProviderConfig model_name: str inputs: list[tuple[PendingEmbedding, str]] token_count: int shortening: Optional[int] user: Optional[str] def costs(self) -> dict[str, int]: return { 'requests': 1, 'tokens': self.token_count, } def create_request(self) -> EmbeddingsRequest: return EmbeddingsRequest(self) class EmbeddingsRequest(rs.Request[EmbeddingsData]): async def run(self) -> Optional[rs.Result[EmbeddingsData]]: task_name = _task_name.get() try: assert isinstance(self.params, EmbeddingsParams) result = await _generate_embeddings( self.params.provider, self.params.model_name, [input[1] for input in self.params.inputs], self.params.shortening, self.params.user, self.params.http_client, ) result.pgconn = self.params.pgconn result.pending_entries = [ input[0] for input in self.params.inputs 
] return result except AIExtError as e: logger.error(f"{task_name}: {e}") return None except Exception as e: logger.error( f"{task_name}: could not generate embeddings " f"due to an internal error: {e}" ) return None class EmbeddingsResult(rs.Result[EmbeddingsData]): provider_cfg: ProviderConfig pgconn: Optional[Any] = None pending_entries: Optional[list[PendingEmbedding]] = None async def finalize(self) -> None: if isinstance(self.data, rs.Error): return if self.pgconn is None or self.pending_entries is None: return # Entries must line up with the embeddings data: # - `_generate_embeddings` produces embeddings data matching # the order of its inputs # # Entries must be grouped by target rel: # - `_generate_embeddings_params` sorts inputs by target rel before groups = itertools.groupby( self.pending_entries, key=lambda e: (e.target_rel, e.target_attr), ) offset = 0 for (rel, attr), items in groups: ids = [item.id for item in items] await _update_embeddings_in_db( self.pgconn, self.provider_cfg, rel, attr, ids, self.data.embeddings, offset, ) offset += len(ids) async def _generate_embeddings_params( db: dbview.Database, pgconn: pgcon.PGConnection, http_client: http.HttpClient, provider_name: str, provider_models: list[str], model_excluded_ids: dict[str, list[str]], *, tokens_rate_limit: Optional[int | Literal['unlimited']], ) -> Optional[list[EmbeddingsParams]]: task_name = _task_name.get() try: provider_cfg = _get_provider_config( db=db, provider_name=provider_name) except LookupError as e: logger.error(f"{task_name}: {e}") return None embedding_models = await _get_embedding_models( db, provider_models ) model_pending_entries: dict[str, list[PendingEmbedding]] = {} for model_name in provider_models: logger.debug( f"{task_name} considering {model_name!r} " f"indexes via {provider_name!r}" ) pending_entries = await _get_pending_embeddings( pgconn, model_name, model_excluded_ids ) if not pending_entries: continue logger.debug( f"{task_name} found
{len(pending_entries)} entries to index" ) try: model_list = model_pending_entries[model_name] except KeyError: model_list = model_pending_entries[model_name] = [] model_list.extend(pending_entries) embeddings_params: list[EmbeddingsParams] = [] for model_name, pending_entries in model_pending_entries.items(): embedding_model = embedding_models[model_name] groups = itertools.groupby( pending_entries, key=lambda e: e.target_dims_shortening ) for shortening, part_iter in groups: part = list(part_iter) part_texts = [(p.text, p.truncate_to_max) for p in part] batches, excluded_indexes = batch_texts( part_texts, get_model_tokenizer(provider_name, model_name), max_input_tokens=embedding_model.max_input_tokens, max_batch_tokens=embedding_model.max_batch_tokens, max_batch_size=embedding_model.max_batch_size, ) if excluded_indexes: if model_name not in model_excluded_ids: model_excluded_ids[model_name] = [] model_excluded_ids[model_name].extend( part[excluded_index].id.hex for excluded_index in excluded_indexes ) for batch in batches: inputs = [ (part[entry.input_index], entry.input_text) for entry in batch.entries ] # Sort the batches by target_rel. This groups embeddings # for each table together. 
# This is necessary for `EmbeddingsResult.finalize()` inputs.sort(key=lambda e: e[0].target_rel) embeddings_params.append(EmbeddingsParams( pgconn=pgconn, provider=provider_cfg, model_name=model_name, inputs=inputs, token_count=batch.token_count, shortening=shortening, user=None, http_client=http_client, )) return embeddings_params @dataclass(frozen=True, kw_only=True) class TextBatchEntry: input_index: int input_text: str @dataclass(frozen=True, kw_only=True) class TextBatch: entries: list[TextBatchEntry] token_count: int def batch_texts( texts: list[tuple[str, bool]], tokenizer: Optional[Tokenizer], *, max_input_tokens: int, max_batch_tokens: int, max_batch_size: int | None, ) -> tuple[list[TextBatch], list[int]]: """Given a list of texts and whether each can be truncated, produce a list of valid texts to batch. Additionally, returns a list of indexes of texts which are too long and should be excluded from future embeddings requests. """ excluded_indexes: list[int] = [] if tokenizer: input_indexes: list[int] = [] input_texts: list[str] = [] for index, (text, allowed_to_truncate) in enumerate(texts): ensured = _ensure_text_token_length( text, allowed_to_truncate, tokenizer, max_input_tokens ) if ensured is None: # If the text is too long, mark it as excluded and # skip. 
excluded_indexes.append(index) continue input_indexes.append(index) input_texts.append(ensured) # Group the valid texts into batches based on token count batched_inputs = _batch_embeddings_inputs( tokenizer, input_texts, max_batch_tokens, max_batch_size ) # Gather results batches = [ TextBatch( entries=[ TextBatchEntry( input_index=input_indexes[index], input_text=input_texts[index], ) for index in batch_input_indexes ], token_count=token_count ) for batch_input_indexes, token_count in batched_inputs ] elif max_batch_size: batch_count = (len(texts) - 1) // max_batch_size + 1 batches = [ TextBatch( entries=[ TextBatchEntry( input_index=index, input_text=texts[index][0], ) for index in range( batch_index * max_batch_size, min((batch_index + 1) * max_batch_size, len(texts)) ) ], token_count=0, ) for batch_index in range(batch_count) ] else: batches = [ TextBatch( entries=[ TextBatchEntry( input_index=index, input_text=texts[index][0], ) for index in range(len(texts)) ], token_count=0, ) ] return (batches, excluded_indexes) def _ensure_text_token_length( text: str, allowed_to_truncate: bool, tokenizer: Tokenizer, max_token_count: int, ) -> Optional[str]: """Ensure text does not exceed the allowed token length. If the text is ok, return the text unchanged. If the text is too long and `allowed_to_truncate` is true, return the text truncated. Otherwise return None. 
""" truncate_length = max_token_count - tokenizer.encode_padding() if allowed_to_truncate: text, current_token_count = ( tokenizer.shorten_to_token_length(text, truncate_length) ) else: current_token_count = len(tokenizer.encode(text)) if current_token_count > truncate_length: return None return text @dataclass(frozen=True, kw_only=True) class PendingEmbedding: id: uuid.UUID text: str target_rel: str target_attr: str target_dims_shortening: Optional[int] truncate_to_max: bool async def _get_pending_embeddings( pgconn: pgcon.PGConnection, model_name: str, model_excluded_ids: dict[str, list[str]], ) -> list[PendingEmbedding]: task_name = _task_name.get() where_clause = "" if ( model_name in model_excluded_ids and (excluded_ids := model_excluded_ids[model_name]) ): # Only exclude long text if it won't be auto-truncated. logger.debug( f"{task_name} skipping {len(excluded_ids)} indexes " f"for {model_name!r}" ) where_clause = (f""" WHERE q."id" not in ({','.join( "'" + excluded_id + "'" for excluded_id in excluded_ids )}) OR q."truncate_to_max" """) entries = await pgconn.sql_fetch( f""" SELECT * FROM ( SELECT "id", "text", "target_rel", "target_attr", "target_dims_shortening", "truncate_to_max" FROM edgedbext."ai_pending_embeddings_{model_name}" LIMIT 500 ) AS q {where_clause} ORDER BY q."target_dims_shortening" """.encode(), tx_isolation=edbdef.TxIsolationLevel.RepeatableRead, ) if not entries: return [] result = [] for entry in entries: result.append(PendingEmbedding( id=uuidgen.from_bytes(entry[0]), text=entry[1].decode("utf-8"), target_rel=entry[2].decode(), target_attr=entry[3].decode(), target_dims_shortening=( int.from_bytes( entry[4], byteorder="big", signed=False, ) if entry[4] is not None else None ), truncate_to_max=bool.from_bytes(entry[5]), )) return result def _batch_embeddings_inputs( tokenizer: Tokenizer, inputs: list[str], max_batch_tokens: int, max_batch_size: int | None, ) -> list[tuple[list[int], int]]: """Create batches of embeddings inputs. 
Returns batches, each of which is a tuple of: - Indexes of input strings grouped to avoid exceeding max_batch_tokens - The batch's token count """ # Get token counts input_token_counts = [ len(tokenizer.encode(input)) for input in inputs ] # Get indexes of inputs, sorted from shortest to longest by token count # Use the text itself as a tie breaker for consistency unbatched_input_indexes = list(range(len(inputs))) unbatched_input_indexes.sort( key=lambda index: (input_token_counts[index], inputs[index]), reverse=False, ) def unbatched_token_count(unbatched_index: int) -> int: return input_token_counts[unbatched_input_indexes[unbatched_index]] # Remove any inputs that are larger than the maximum while ( unbatched_input_indexes and unbatched_token_count(-1) > max_batch_tokens ): unbatched_input_indexes.pop() batches: list[tuple[list[int], int]] = [] while unbatched_input_indexes: # Start with the largest available input batch_input_indexes = [unbatched_input_indexes[-1]] batch_token_count = unbatched_token_count(-1) unbatched_input_indexes.pop() if batch_token_count < max_batch_tokens: # Then add the smallest available inputs as long as the # max batch token and input counts aren't exceeded unbatched_index = 0 while ( unbatched_index < len(unbatched_input_indexes) and ( max_batch_size is None or len(batch_input_indexes) < max_batch_size ) ): if ( batch_token_count + unbatched_token_count(unbatched_index) <= max_batch_tokens ): batch_input_indexes.append( unbatched_input_indexes[unbatched_index] ) batch_token_count += unbatched_token_count(unbatched_index) unbatched_input_indexes.pop(unbatched_index) else: unbatched_index += 1 batches.append((batch_input_indexes, batch_token_count)) return batches async def _update_embeddings_in_db( pgconn: pgcon.PGConnection, provider_cfg: ProviderConfig, rel: str, attr: str, ids: list[uuid.UUID], embeddings: bytes, offset: int, ) -> int: id_array = '{' + ', '.join(f'"{id.hex}"' for id in ids) + '}' entries = await
pgconn.sql_fetch_val( f""" WITH upd AS ( UPDATE {rel} AS target SET {attr} = ( (embeddings.data)::text::edgedb.vector) FROM ( SELECT row_number() over () AS n, j.data FROM (SELECT data FROM json_array_elements($1::json) AS data OFFSET $3::text::int ) AS j ) AS embeddings, unnest($2::text::text[]) WITH ORDINALITY AS ids(id, n) WHERE embeddings."n" = ids."n" AND target."id" = ids."id"::uuid RETURNING target."id" ) SELECT count(*)::text FROM upd """.encode(), args=( str(provider_cfg.get_embeddings_from_result(embeddings)).encode(), id_array.encode(), str(offset).encode(), ), tx_isolation=edbdef.TxIsolationLevel.RepeatableRead, ) return int(entries.decode()) async def _generate_embeddings( provider: ProviderConfig, model_name: str, inputs: list[str], shortening: Optional[int], user: Optional[str], http_client: http.HttpClient, ) -> EmbeddingsResult: task_name = _task_name.get() count = len(inputs) suf = "s" if count > 1 else "" logger.debug( f"{task_name} generating embeddings via {model_name!r} " f"of {provider.name!r} for {len(inputs)} object{suf}" ) if provider.api_style == ApiStyle.OpenAI: result = await _generate_openai_embeddings( provider, model_name, inputs, shortening, user, http_client ) elif provider.api_style == ApiStyle.Ollama: result = await _generate_ollama_embeddings( provider, model_name, inputs, shortening, http_client ) else: raise RuntimeError( f"unsupported model provider API style: {provider.api_style}, " f"provider: {provider.name}" ) result.provider_cfg = provider return result async def _generate_openai_embeddings( provider: ProviderConfig, model_name: str, inputs: list[str], shortening: Optional[int], user: Optional[str], http_client: http.HttpClient, ) -> EmbeddingsResult: headers = { "Authorization": f"Bearer {provider.secret}", } if provider.name == "builtin::openai" and provider.client_id: headers["OpenAI-Organization"] = provider.client_id client = http_client.with_context( headers=headers, base_url=provider.api_url, ) params: dict[str, 
Any] = { "model": model_name, "encoding_format": "float", "input": inputs, } if shortening is not None: params["dimensions"] = shortening if user is not None: params["user"] = user result = await client.post( "/embeddings", json=params, ) error = None if result.status_code >= 400: error = rs.Error( message=( f"API call to generate embeddings failed with status " f"{result.status_code}: {result.text}" ), retry=( # If the request fails with 429 - too many requests, it can be # retried result.status_code == 429 ), ) return EmbeddingsResult( data=(error if error else EmbeddingsData(result.bytes())), limits=_read_openai_limits(result), ) def _read_openai_header_field( result: Any, field_names: list[str], ) -> Optional[int]: # Return the value of the first requested field available try: for field_name in field_names: if field_name in result.headers: header_value = result.headers[field_name] return int(header_value) if header_value is not None else None except (ValueError, TypeError): pass return None def _read_openai_limits( result: Any, ) -> dict[str, rs.Limits]: request_limit = _read_openai_header_field( result, [ 'x-ratelimit-limit-project-requests', 'x-ratelimit-limit-requests', ], ) request_remaining = _read_openai_header_field( result, [ 'x-ratelimit-remaining-project-requests', 'x-ratelimit-remaining-requests', ], ) token_limit = _read_openai_header_field( result, [ 'x-ratelimit-limit-project-tokens', 'x-ratelimit-limit-tokens', ], ) token_remaining = _read_openai_header_field( result, [ 'x-ratelimit-remaining-project-tokens', 'x-ratelimit-remaining-tokens', ], ) return { 'requests': rs.Limits( total=request_limit, remaining=request_remaining, ), 'tokens': rs.Limits( total=token_limit, remaining=token_remaining, ), } async def _generate_ollama_embeddings( provider: ProviderConfig, model_name: str, inputs: list[str], shortening: Optional[int], http_client: http.HttpClient, ) -> EmbeddingsResult: headers: dict[str, str] = {} client = http_client.with_context( 
headers=headers, base_url=provider.api_url, ) params: dict[str, Any] = { "model": model_name, "input": inputs, } if shortening is not None: params["dimensions"] = shortening result = await client.post( "/embed", json=params, ) error = None if result.status_code >= 400: error = rs.Error( message=( f"API call to generate embeddings failed with status " f"{result.status_code}: {result.text}" ), retry=( # If the request fails with 429 - too many requests, it can be # retried result.status_code == 429 ), ) return EmbeddingsResult( data=(error if error else EmbeddingsData(result.bytes())), limits=_ollama_limits(), ) def _ollama_limits() -> dict[str, rs.Limits]: return { 'requests': rs.Limits( total='unlimited', remaining=None, ), 'tokens': rs.Limits( total='unlimited', remaining=None, ), } async def _start_chat( *, protocol: protocol.HttpProtocol, request: protocol.HttpRequest, response: protocol.HttpResponse, provider: ProviderConfig, http_client: http.HttpClient, model_name: str, messages: list[dict[str, Any]], stream: bool, temperature: Optional[float], top_p: Optional[float], max_tokens: Optional[int], seed: Optional[int], safe_prompt: Optional[bool], top_k: Optional[int], logit_bias: Optional[dict[int, int]], logprobs: Optional[bool], user: Optional[str], tools: Optional[list[dict[str, Any]]], ) -> None: if provider.api_style == ApiStyle.OpenAI: await _start_openai_chat( protocol=protocol, request=request, response=response, provider=provider, http_client=http_client, model_name=model_name, messages=messages, stream=stream, temperature=temperature, top_p=top_p, max_tokens=max_tokens, seed=seed, safe_prompt=safe_prompt, logit_bias=logit_bias, logprobs=logprobs, user=user, tools=tools, ) elif provider.api_style == ApiStyle.Anthropic: await _start_anthropic_chat( protocol=protocol, request=request, response=response, provider=provider, http_client=http_client, model_name=model_name, messages=messages, stream=stream, temperature=temperature, top_p=top_p, top_k=top_k, 
tools=tools, max_tokens=max_tokens, ) elif provider.api_style == ApiStyle.Ollama: await _start_ollama_chat( protocol=protocol, request=request, response=response, provider=provider, http_client=http_client, model_name=model_name, messages=messages, stream=stream, temperature=temperature, top_p=top_p, top_k=top_k, tools=tools, ) else: raise RuntimeError( f"unsupported model provider API style: {provider.api_style}, " f"provider: {provider.name}" ) @contextlib.asynccontextmanager async def aconnect_sse( client: http.HttpClient, method: str, url: str, **kwargs: Any, ) -> AsyncIterator[http.ResponseSSE]: headers = kwargs.pop("headers", {}) headers["Accept"] = "text/event-stream" headers["Cache-Control"] = "no-store" stm = await client.stream_sse( method=method, path=url, headers=headers, **kwargs ) if isinstance(stm, http.Response): raise AIProviderError( f"API call to generate chat completions failed with status " f"{stm.status_code}: {stm.text}" ) async with stm as response: if response.status_code >= 400: # Unlikely that we have a streaming response with a non-200 result raise AIProviderError( f"API call to generate chat completions failed with status " f"{response.status_code}" ) yield response async def _start_openai_like_chat( *, protocol: protocol.HttpProtocol, request: protocol.HttpRequest, response: protocol.HttpResponse, provider_name: str, client: http.HttpClient, model_name: str, messages: list[dict[str, Any]], stream: bool, temperature: Optional[float], top_p: Optional[float], max_tokens: Optional[int], seed: Optional[int], safe_prompt: Optional[bool], logit_bias: Optional[dict[int, int]], logprobs: Optional[bool], user: Optional[str], tools: Optional[list[dict[str, Any]]], ) -> None: isOpenAI = provider_name == "builtin::openai" params: dict[str, Any] = { "model": model_name, "messages": messages, } if temperature is not None: params["temperature"] = temperature if top_p is not None: params["top_p"] = top_p if tools is not None: params["tools"] = tools if 
isOpenAI and logit_bias is not None: params["logit_bias"] = logit_bias if isOpenAI and logprobs is not None: params["logprobs"] = logprobs if isOpenAI and user is not None: params["user"] = user if not isOpenAI and safe_prompt is not None: params["safe_prompt"] = safe_prompt if max_tokens is not None: if isOpenAI: params["max_completion_tokens"] = max_tokens else: params["max_tokens"] = max_tokens if seed is not None: if isOpenAI: params["seed"] = seed else: params["random_seed"] = seed if stream: async with aconnect_sse( client, method="POST", url="/chat/completions", json={ **params, "stream": True, } ) as event_source: # we need tool_index and finish_reason to correctly # send 'content_block_stop' chunk for tool call messages tool_index = 0 finish_reason = "unknown" async for sse in event_source: if not response.sent: response.status = http.HTTPStatus.OK response.content_type = b'text/event-stream' response.close_connection = False response.custom_headers["Cache-Control"] = "no-cache" protocol.write(request, response) if sse.event != "message": continue if sse.data == "[DONE]": # mistral doesn't send finish_reason for tool calls if finish_reason == "unknown": event = ( b'event: content_block_stop\n' + b'data: {"type": "content_block_stop",' + b'"index": ' + str(tool_index).encode() + b'}\n\n' ) protocol.write_raw(event) event = ( b'event: message_stop\n' + b'data: {"type": "message_stop"}\n\n' ) protocol.write_raw(event) break message = sse.json() if message.get("object") == "chat.completion.chunk": data = message.get("choices")[0] delta = data.get("delta") role = delta.get("role") tool_calls = delta.get("tool_calls") if role: event_data = json.dumps({ "type": "message_start", "message": { "id": message["id"], "role": role, "model": message["model"], "usage": message.get("usage") }, }).encode("utf-8") event = ( b'event: message_start\n' + b'data: ' + event_data + b'\n\n' ) protocol.write_raw(event) event_data = json.dumps({ "type": "content_block_start", 
"index": 0, "content_block": { "type": "text", "text": "" } }).encode("utf-8") event = ( b'event: content_block_start\n' + b'data: ' + event_data + b'\n\n' ) protocol.write_raw(event) # if there's only one openai tool call it shows up here if tool_calls: for tool_call in tool_calls: tool_index = tool_call["index"] event_data = json.dumps({ "type": "content_block_start", "index": tool_call["index"] + 1, "content_block": { "id": tool_call["id"], "type": "tool_use", "name": tool_call["function"]["name"], "args": tool_call["function"]["arguments"], }, }).encode("utf-8") event = ( b'event: content_block_start\n' + b'data:' + event_data + b'\n\n' ) protocol.write_raw(event) # if there are few openai tool calls, they show up here # mistral tool calls always show up here elif tool_calls: # OpenAI provides index, Mistral doesn't for index, tool_call in enumerate(tool_calls): currentIndex = tool_call.get("index") or index if tool_call.get("type") == "function" or \ "id" in tool_call: if currentIndex > 0: tool_index = currentIndex # send the stop chunk for the previous tool event = ( b'event: content_block_stop\n' + b'data: { \ "type": "content_block_stop",' + b'"index": ' + str(currentIndex).encode() + b'}\n\n' ) protocol.write_raw(event) event_data = json.dumps({ "type": "content_block_start", "index": currentIndex + 1, "content_block": { "id": tool_call.get("id"), "type": "tool_use", "name": tool_call["function"]["name"], "args": tool_call["function"]["arguments"], }, }).encode("utf-8") event = ( b'event: content_block_start\n' + b'data:' + event_data + b'\n\n' ) protocol.write_raw(event) else: event_data = json.dumps({ "type": "content_block_delta", "index": currentIndex + 1, "delta": { "type": "tool_call_delta", "args": tool_call["function"]["arguments"], }, }).encode("utf-8") event = ( b'event: content_block_delta\n' + b'data:' + event_data + b'\n\n' ) protocol.write_raw(event) elif finish_reason := data.get("finish_reason"): index = ( tool_index + 1 if finish_reason == 
"tool_calls" else 0 ) event = ( b'event: content_block_stop\n' + b'data: {"type": "content_block_stop",' + b'"index": ' + str(index).encode() + b'}\n\n' ) protocol.write_raw(event) event_data = json.dumps({ "type": "message_delta", "delta": { "stop_reason": finish_reason, }, "usage": message.get("usage") }).encode("utf-8") event = ( b'event: message_delta\n' + b'data: ' + event_data + b'\n\n' ) protocol.write_raw(event) else: event_data = json.dumps({ "type": "content_block_delta", "index": 0, "delta": { "type": "text_delta", "text": delta.get("content"), }, "logprobs": data.get("logprobs"), }).encode("utf-8") event = ( b'event: content_block_delta\n' + b'data:' + event_data + b'\n\n' ) protocol.write_raw(event) protocol.close() else: result = await client.post( "/chat/completions", json={ **params } ) if result.status_code >= 400: raise AIProviderError( f"API call to generate chat completions failed with status " f"{result.status_code}: {result.text}" ) response.status = http.HTTPStatus.OK result_data = result.json() choice = result_data["choices"][0] tool_calls = choice["message"].get("tool_calls") tool_calls_formatted = [ { "id": tool_call["id"], "type": tool_call["type"], "name": tool_call["function"]["name"], "args": json.loads(tool_call["function"]["arguments"]), } for tool_call in tool_calls or [] ] body = { "id": result_data["id"], "model": result_data["model"], "text": choice["message"]["content"], "finish_reason": choice.get("finish_reason"), "usage": result_data.get("usage"), "logprobs": choice.get("logprobs"), "tool_calls": tool_calls_formatted, } response.content_type = b'application/json' response.body = json.dumps(body).encode("utf-8") async def _start_openai_chat( *, protocol: protocol.HttpProtocol, request: protocol.HttpRequest, response: protocol.HttpResponse, provider: ProviderConfig, http_client: http.HttpClient, model_name: str, messages: list[dict[str, Any]], stream: bool, temperature: Optional[float], top_p: Optional[float], max_tokens: 
Optional[int], seed: Optional[int], safe_prompt: Optional[bool], logit_bias: Optional[dict[int, int]], logprobs: Optional[bool], user: Optional[str], tools: Optional[list[dict[str, Any]]], ) -> None: headers = { "Authorization": f"Bearer {provider.secret}", } if provider.name == "builtin::openai" and provider.client_id: headers["OpenAI-Organization"] = provider.client_id client = http_client.with_context( base_url=provider.api_url, headers=headers, ) await _start_openai_like_chat( protocol=protocol, request=request, response=response, provider_name=provider.name, client=client, model_name=model_name, messages=messages, stream=stream, temperature=temperature, top_p=top_p, max_tokens=max_tokens, seed=seed, safe_prompt=safe_prompt, logit_bias=logit_bias, logprobs=logprobs, user=user, tools=tools, ) # Anthropic differs from OpenAI and Mistral as there's no tool chunk: # tool_call(tool_use) is part of the assistant chunk, and # tool_result is part of the user chunk. async def _start_anthropic_chat( *, protocol: protocol.HttpProtocol, request: protocol.HttpRequest, response: protocol.HttpResponse, provider: ProviderConfig, http_client: http.HttpClient, model_name: str, messages: list[dict[str, Any]], stream: bool, temperature: Optional[float], top_p: Optional[float], top_k: Optional[int], tools: Optional[list[dict[str, Any]]], max_tokens: Optional[int], ) -> None: headers = { "x-api-key": f"{provider.secret}", } if provider.name == "builtin::anthropic": headers["anthropic-version"] = "2023-06-01" headers["anthropic-beta"] = "messages-2023-12-15" client = http_client.with_context( headers={ "anthropic-version": "2023-06-01", "anthropic-beta": "messages-2023-12-15", "x-api-key": f"{provider.secret}", }, base_url=provider.api_url, ) anthropic_messages = [] system_prompt_parts = [] for message in messages: if message["role"] == "system": system_prompt_parts.append(message["content"]) elif message["role"] == "assistant" and "tool_calls" in message: # Anthropic fails when an 
assistant chunk has multiple tool calls # and is followed by several tool_result chunks (or a user chunk # with multiple tool_results). Each assistant chunk should have # only 1 tool_use, followed by 1 tool_result chunk. for tool_call in message["tool_calls"]: msg = { "role": "assistant", "content": [ { "id": tool_call["id"], "type": "tool_use", "name": tool_call["function"]["name"], "input": json.loads( tool_call["function"]["arguments"]), } ], } anthropic_messages.append(msg) # Check if message is a tool result elif message["role"] == "tool": tool_result = { "role": "user", "content": [ { "type": "tool_result", "tool_use_id": message["tool_call_id"], "content": message["content"] } ], } anthropic_messages.append(tool_result) else: anthropic_messages.append(message) system_prompt = "\n".join(system_prompt_parts) # Each tool_use chunk must be followed by a matching tool_result chunk reordered_messages = [] # Separate tool_result messages by tool_use_id for faster access tool_result_map = { item["content"][0]["tool_use_id"]: item for item in anthropic_messages if item["role"] == "user" and isinstance(item["content"][0], dict) and item["content"][0]["type"] == "tool_result" } for message in anthropic_messages: if message["role"] == "assistant": reordered_messages.append(message) if isinstance(message["content"], list): for item in message["content"]: if item["type"] == "tool_use": # find the matching user tool_result message tool_use_id = item["id"] if tool_use_id in tool_result_map: reordered_messages.append( tool_result_map[tool_use_id]) # append user message that is not tool_result elif not (message["role"] == "user" and isinstance(message["content"][0], dict) and message["content"][0]["type"] == "tool_result"): reordered_messages.append(message) params = { "model": model_name, "messages": reordered_messages, "system": system_prompt, **({"temperature": temperature} if temperature is not None else {}), **({"top_p": top_p} if top_p is not None else {}), 
**{"max_tokens": max_tokens if max_tokens is not None else 4096}, **({"top_k": top_k} if top_k is not None else {}), **({"tools": tools} if tools is not None else {}), } if stream: async with aconnect_sse( client, method="POST", url="/messages", json={ **params, "stream": True, } ) as event_source: tool_index = 0 async for sse in event_source: if not response.sent: response.status = http.HTTPStatus.OK response.content_type = b'text/event-stream' response.close_connection = False response.custom_headers["Cache-Control"] = "no-cache" protocol.write(request, response) if sse.event == "message_start": message = sse.json()["message"] for k in tuple(message): if k not in {"id", "type", "role", "model", "usage"}: del message[k] message["usage"] = { "prompt_tokens": message["usage"]["input_tokens"], "completion_tokens": message["usage"]["output_tokens"] } message_data = json.dumps(message).encode("utf-8") event = ( b'event: message_start\n' + b'data: {"type": "message_start",' + b'"message":' + message_data + b'}\n\n' ) protocol.write_raw(event) elif sse.event == "content_block_start": sse_data = json.loads(sse.data) protocol.write_raw( b'event: content_block_start\n' + b'data: ' + json.dumps(sse_data).encode("utf-8") + b'\n\n' ) # we don't send content_block_stop event when text # chunk ends, should be okay since we don't consume # this event on the client side data = sse.json() if ( "content_block" in data and data["content_block"].get("type") == "tool_use" ): currentIndex = data["index"] if currentIndex > 0: tool_index = currentIndex event_data = json.dumps({ "type": "content_block_stop", "index": currentIndex - 1}) protocol.write_raw( b'event: content_block_stop\n' + b'data: ' + event_data.encode("utf-8") + b'\n\n' ) elif sse.event == "content_block_delta": message = sse.json() # it is always dict irl but test is failing delta = message.get("delta") if delta and delta.get("type") == "input_json_delta": delta["type"] = "tool_call_delta" if delta and "partial_json" in 
delta: delta["args"] = delta.pop("partial_json") event_data = json.dumps(message) event = ( b'event: content_block_delta\n' + b'data: ' + event_data.encode("utf-8") + b'\n\n' ) protocol.write_raw(event) elif sse.event == "message_delta": message = sse.json() if message["delta"]["stop_reason"] == "tool_use": event = ( b'event: content_block_stop\n' + b'data: {"type": "content_block_stop",' + b'"index": ' + str(tool_index).encode("utf-8") + b'}\n\n' ) protocol.write_raw(event) event_data = json.dumps({ "type": "message_delta", "delta": message["delta"], "usage": {"completion_tokens": message["usage"]["output_tokens"]} }) event = ( b'event: message_delta\n' + b'data: ' + event_data.encode("utf-8") + b'\n\n' ) protocol.write_raw(event) elif sse.event == "message_stop": event = ( b'event: message_stop\n' + b'data: {"type": "message_stop"}\n\n' ) protocol.write_raw(event) # needed because stream doesn't close itself protocol.close() protocol.close() else: result = await client.post( "/messages", json={ **params } ) if result.status_code >= 400: raise AIProviderError( f"API call to generate chat completions failed with status " f"{result.status_code}: {result.text}" ) response.status = http.HTTPStatus.OK response.content_type = b'application/json' result_data = result.json() tool_calls = [ item for item in result_data["content"] if item.get("type") == "tool_use" ] tool_calls_formatted = [ { "id": tool_call["id"], "type": "function", "name": tool_call["name"], "args": tool_call["input"], } for tool_call in tool_calls ] body = { "id": result_data["id"], "model": result_data["model"], "text": next((item["text"] for item in result_data["content"] if item.get("type") == "text"), ""), "finish_reason": result_data["stop_reason"], "usage": { "prompt_tokens": result_data["usage"]["input_tokens"], "completion_tokens": result_data["usage"]["output_tokens"] }, "tool_calls": tool_calls_formatted, } response.body = json.dumps( body ).encode("utf-8") async def _start_ollama_chat( *, 
protocol: protocol.HttpProtocol, request: protocol.HttpRequest, response: protocol.HttpResponse, provider: ProviderConfig, http_client: http.HttpClient, model_name: str, messages: list[dict[str, Any]], stream: bool, temperature: Optional[float], top_p: Optional[float], top_k: Optional[int], tools: Optional[list[dict[str, Any]]], ) -> None: # The default API doesn't produce a SSE stream. Use the experimental # OpenAI-like API if we need a stream. base_url = provider.api_url if stream: base_url = '/'.join( base_url.split('/')[:-1] + ['v1'] ) # Generate params differently for stream and non-stream modes since they # use different APIs. options = { **({"temperature": temperature} if temperature is not None else {}), **({"top_p": top_p} if top_p is not None else {}), **({"top_k": top_k} if top_k is not None else {}), } if stream: params = { "model": model_name, "messages": messages, "options": options, } # Only include tools in streaming params if no tool messages are # provided. if tools is not None and not any( message["role"] == "tool" for message in messages ): params["tools"] = tools else: converted_messages = [] for message in messages: if message["role"] == "user": # Ollama can't handle content block messages. # Unpack them into separate messages. if isinstance(message["content"], str): converted_messages.append(message) else: # array of content blocks for block in message["content"]: if block["type"] != "text": raise TypeError( f"Unsupported content type: '{block["type"]}'. " f"For non-text content, use streamed mode." ) converted_messages.append({ "role": message["role"], "content": block["text"], }) elif message["role"] == "assistant": # Gel http API packs arguments into a string, but ollama # requires plain json. 
converted_messages.append({ "role": message["role"], "content": message["content"], "tool_calls": [ { "id": tool_call["id"], "function": { "name": tool_call["function"]["name"], "arguments": json.loads( tool_call["function"]["arguments"] ), } } for tool_call in message["tool_calls"] ] }) else: converted_messages.append(message) params = { "model": model_name, "messages": converted_messages, "options": options, **({"tools": tools} if tools is not None else {}), } client = http_client.with_context( headers={}, base_url=base_url, ) if stream: async with aconnect_sse( client, method="POST", url="/chat/completions", json={ **params, "stream": True, } ) as event_source: # we need tool_index and finish_reason to correctly # send 'content_block_stop' chunk for tool call messages tool_index = 0 finish_reason: Optional[str] = "unknown" started = False async for sse in event_source: if not response.sent: response.status = http.HTTPStatus.OK response.content_type = b'text/event-stream' response.close_connection = False response.custom_headers["Cache-Control"] = "no-cache" protocol.write(request, response) if sse.event != "message": continue if sse.data == "[DONE]": # mistral doesn't send finish_reason for tool calls if finish_reason == "unknown": event = ( b'event: content_block_stop\n' + b'data: {"type": "content_block_stop",' + b'"index": ' + str(tool_index).encode() + b'}\n\n' ) protocol.write_raw(event) event = ( b'event: message_stop\n' + b'data: {"type": "message_stop"}\n\n' ) protocol.write_raw(event) break message = sse.json() if message.get("object") == "chat.completion.chunk": choices = message.get("choices") data = choices[0] if choices else {} delta = data.get("delta") or {} role = delta.get("role") tool_calls = delta.get("tool_calls") # Unlike OpenAI, Ollama includes the role in every event. # Just create a new start event. 
if not started: event_data = json.dumps({ "type": "message_start", "message": { "id": message["id"], "role": role, "model": message["model"], "usage": message.get("usage") }, }).encode("utf-8") event = ( b'event: message_start\n' + b'data: ' + event_data + b'\n\n' ) protocol.write_raw(event) event_data = json.dumps({ "type": "content_block_start", "index": 0, "content_block": { "type": "text", "text": "" } }).encode("utf-8") event = ( b'event: content_block_start\n' + b'data: ' + event_data + b'\n\n' ) protocol.write_raw(event) started = True if tool_calls: for index, tool_call in enumerate(tool_calls): currentIndex = tool_call.get("index") or index if tool_call.get("type") == "function" or \ "id" in tool_call: if currentIndex > 0: tool_index = currentIndex # send the stop chunk for the previous tool event = ( b'event: content_block_stop\n' + b'data: { \ "type": "content_block_stop",' + b'"index": ' + str(currentIndex).encode() + b'}\n\n' ) protocol.write_raw(event) event_data = json.dumps({ "type": "content_block_start", "index": currentIndex + 1, "content_block": { "id": tool_call.get("id"), "type": "tool_use", "name": tool_call["function"]["name"], "args": tool_call["function"]["arguments"], }, }).encode("utf-8") event = ( b'event: content_block_start\n' + b'data:' + event_data + b'\n\n' ) protocol.write_raw(event) else: event_data = json.dumps({ "type": "content_block_delta", "index": currentIndex + 1, "delta": { "type": "tool_call_delta", "args": tool_call["function"]["arguments"], }, }).encode("utf-8") event = ( b'event: content_block_delta\n' + b'data:' + event_data + b'\n\n' ) protocol.write_raw(event) elif finish_reason := data.get("finish_reason"): index = ( tool_index + 1 if finish_reason == "tool_calls" else 0 ) event = ( b'event: content_block_stop\n' + b'data: {"type": "content_block_stop",' + b'"index": ' + str(index).encode() + b'}\n\n' ) protocol.write_raw(event) event_data = json.dumps({ "type": "message_delta", "delta": { "stop_reason": 
finish_reason, }, "usage": message.get("usage") }).encode("utf-8") event = ( b'event: message_delta\n' + b'data: ' + event_data + b'\n\n' ) protocol.write_raw(event) else: event_data = json.dumps({ "type": "content_block_delta", "index": 0, "delta": { "type": "text_delta", "text": delta.get("content"), }, "logprobs": data.get("logprobs"), }).encode("utf-8") event = ( b'event: content_block_delta\n' + b'data:' + event_data + b'\n\n' ) protocol.write_raw(event) protocol.close() else: result = await client.post( "/chat", json={ **params, "stream": False } ) if result.status_code >= 400: raise AIProviderError( f"API call to generate chat completions failed with status " f"{result.status_code}: {result.text}" ) response.status = http.HTTPStatus.OK response.content_type = b'application/json' result_data = result.json() tool_calls = result_data["message"].get("tool_calls") tool_calls_formatted = [ { # Ollama does not provide tool call ids for non-stream. # Use enumerate to generate a placeholder. "id": f"call_{tool_call_id}", # Ollama doesn't provide the tool call type but it should # always be 'function'. 
"type": "function", "name": tool_call["function"]["name"], "args": tool_call["function"]["arguments"], } for tool_call_id, tool_call in enumerate(tool_calls or []) ] body = { "model": result_data["model"], "text": result_data["message"]["content"], "finish_reason": result_data["done_reason"], "usage": { "prompt_tokens": result_data["prompt_eval_count"], "completion_tokens": result_data["eval_count"] }, "tool_calls": tool_calls_formatted, } # Ollama has no documented tools response response.body = json.dumps( body ).encode("utf-8") # # HTTP API # async def handle_request( protocol: protocol.HttpProtocol, request: protocol.HttpRequest, response: protocol.HttpResponse, db: dbview.Database, role_name: str, args: list[str], tenant: srv_tenant.Tenant, ) -> None: if len(args) != 1 or args[0] not in {"rag", "embeddings"}: response.body = b'Unknown path' response.status = http.HTTPStatus.NOT_FOUND response.close_connection = True return if request.method != b"POST": response.body = b"Invalid request method" response.status = http.HTTPStatus.METHOD_NOT_ALLOWED response.close_connection = True return if request.content_type != b"application/json": response.body = b"Expected application/json input" response.status = http.HTTPStatus.BAD_REQUEST response.close_connection = True return await db.introspection() try: if args[0] == "rag": await _handle_rag_request( protocol, request, response, db, role_name, tenant ) elif args[0] == "embeddings": await _handle_embeddings_request( request, response, db, role_name, tenant ) else: response.body = b'Unknown path' response.status = http.HTTPStatus.NOT_FOUND response.close_connection = True return except Exception as ex: if not isinstance(ex, AIExtError): ex = InternalError(str(ex)) if not isinstance(ex, BadRequestError): logger.error(f"error while handling a /{args[0]} request: {ex}") response.status = ex.get_http_status() response.content_type = b'application/json' response.body = json.dumps(ex.json()).encode("utf-8") 
response.close_connection = True return async def _handle_rag_request( protocol: protocol.HttpProtocol, request: protocol.HttpRequest, response: protocol.HttpResponse, db: dbview.Database, role_name: str, tenant: srv_tenant.Tenant, ) -> None: try: http_client = tenant.get_http_client(originator="ai/rag") body = json.loads(request.body) if not isinstance(body, dict): raise TypeError( 'the body of the request must be a JSON object') context = body.get('context') if context is None: raise TypeError( 'missing required "context" object in request') if not isinstance(context, dict): raise TypeError( '"context" value in request is not a valid JSON object') ctx_query = context.get("query") ctx_variables = context.get("variables") ctx_globals = context.get("globals") ctx_max_obj_count = context.get("max_object_count") if not ctx_query: raise TypeError( 'missing required "query" in request "context" object') if ctx_variables is not None and not isinstance(ctx_variables, dict): raise TypeError('"variables" must be a JSON object') if ctx_globals is not None and not isinstance(ctx_globals, dict): raise TypeError('"globals" must be a JSON object') model = cast(str, body.get('model')) if not model: raise TypeError( 'missing required "model" in request') query = body.get('query') if not query: raise TypeError( 'missing required "query" in request') stream = body.get('stream') if stream is None: stream = False elif not isinstance(stream, bool): raise TypeError('"stream" must be a boolean') if ctx_max_obj_count is None: ctx_max_obj_count = 5 elif not isinstance(ctx_max_obj_count, int) or ctx_max_obj_count <= 0: raise TypeError( '"context.max_object_count" must be a positive integer') prompt_id = None prompt_name = None custom_prompt = None custom_prompt_messages: list[dict[str, Any]] = [] prompt = body.get("prompt") if prompt is None: prompt_name = "builtin::rag-default" else: if not isinstance(prompt, dict): raise TypeError( '"prompt" value in request must be a JSON object') 
prompt_name = prompt.get("name") prompt_id = prompt.get("id") custom_prompt = prompt.get("custom") if prompt_name and prompt_id: raise TypeError( "prompt.id and prompt.name are mutually exclusive" ) if custom_prompt: if not isinstance(custom_prompt, list): raise TypeError( ( "prompt.custom must be a list, where each element " "is one of the following types:\n" "{ role: 'system', content: str },\n" "{ role: 'user', content: [{ type: 'text', " "text: str }] },\n" "{ role: 'assistant', content: str, " "optional tool_calls: [{id: str, type: 'function'," " function: { name: str, arguments: str }}] },\n" "{ role: 'tool', content: str, tool_call_id: str }" ) ) for entry in custom_prompt: if not isinstance(entry, dict) or not entry.get("role"): raise TypeError( ( "each prompt.custom entry must be a " "dictionary of one of the following types:\n" "{ role: 'system', content: str },\n" "{ role: 'user', content: [{ type: 'text', " "text: str }] },\n" "{ role: 'assistant', content: str, " "optional tool_calls: [{id: str, " "type: 'function', function: { " "name: str, arguments: str }}] },\n" "{ role: 'tool', content: str, " "tool_call_id: str }" ) ) entry_role = entry.get('role') if entry_role == 'system': if not isinstance(entry.get("content"), str): raise TypeError( "System message content has to be string." 
) elif entry_role == 'user': if not isinstance(entry.get("content"), list): raise TypeError( ( "User message content has to be a list of " "{ type: 'text', text: str }" ) ) for content_entry in entry["content"]: if content_entry.get( "type" ) != "text" or not isinstance( content_entry.get("text"), str ): raise TypeError( ( "Element of user message content has to " "be of type { type: 'text', text: str }" ) ) elif entry_role == 'assistant': if not isinstance(entry.get("content"), str): raise TypeError( "Assistant message content has to be string" ) tool_calls = entry.get("tool_calls") if tool_calls: if not isinstance(tool_calls, list): raise TypeError( ( "Assistant tool calls must be " "a list of:\n" "{id: str, type: 'function', function:" " {name: str, arguments: str }}" ) ) for call in tool_calls: if ( not isinstance(call, dict) or not isinstance(call.get("id"), str) or call.get("type") != "function" or not isinstance( call.get("function"), dict ) or not isinstance( call["function"].get("name"), str ) or not isinstance( call["function"].get("arguments"), str, ) ): raise TypeError( ( "A tool call must be of type:\n" "{id: str, type: 'function', " "function: { name: str, " "arguments: str }}" ) ) elif entry_role == 'tool': if not isinstance(entry.get("content"), str): raise TypeError( "Tool message content has to be string." ) if not isinstance(entry.get("tool_call_id"), str): raise TypeError( "Tool message tool_call_id has to be string." ) else: raise TypeError( ( "Message role must match one of these: " "system, user, assistant, tool."
) ) custom_prompt_messages.append(entry) except Exception as ex: raise BadRequestError(ex.args[0]) provider_name: str model_name: str try_builtin = False if ':' in model: parts = model.split(':') if len(parts) > 2: raise BadRequestError( f"Invalid model uri, ':' used more than once: {model}" ) provider_name = parts[0] model_name = parts[1] try_builtin = True else: provider_name = await _get_model_provider( db, base_model_type=TextGenerationModel.gel_type, model_name=model, ) model_name = model chat_provider = _get_provider_config( db, provider_name, try_builtin=try_builtin ) vector_provider, vector_query = await _generate_embeddings_for_type( db, http_client, ctx_query, content=query, role_name=role_name, ) ctx_query = f""" WITH __query := >std::to_json($input), search := ext::ai::search(({ctx_query}), __query), context := ( for s in search union ( (ext::ai::to_context(s.object), s.distance) ) ) SELECT ( SELECT context ORDER BY .1 ASC EMPTY LAST LIMIT $limit ).0 """ if ctx_variables is None: ctx_variables = {} ctx_variables["input"] = str( vector_provider.get_embeddings_from_result(vector_query)[0] ) ctx_variables["limit"] = ctx_max_obj_count context = await _edgeql_query_json( db=db, query=ctx_query, variables=ctx_variables, globals_=ctx_globals, role_name=role_name, ) if len(context) == 0: raise BadRequestError( 'query did not match any data in specified context', ) prompt_query = """ SELECT ext::ai::ChatPrompt { messages: { participant_role, content, }, } FILTER """ if prompt_id or prompt_name: prompt_variables = {} if prompt_name: prompt_query += ".name = $prompt_name" prompt_variables["prompt_name"] = prompt_name elif prompt_id: prompt_query += ".id = $prompt_id" prompt_variables["prompt_id"] = prompt_id prompts = await _edgeql_query_json( db=db, role_name=role_name, query=prompt_query, variables=prompt_variables, ) if len(prompts) == 0: raise BadRequestError("could not find the specified chat prompt") prompt = prompts[0] else: prompt = { "messages": [], } 
prompt_messages: list[dict[str, Any]] = [] for message in prompt["messages"]: if message["participant_role"] == "User": content = message["content"].format( context="\n".join(context), query=query, ) elif message["participant_role"] == "System": content = message["content"].format( context="\n".join(context), ) else: content = message["content"] role = message["participant_role"].lower() prompt_messages.append(dict(role=role, content=content)) # Don't append the user query message here at the end, because Mistral and # Anthropic don't work if the user message comes after the tool messages messages = prompt_messages + custom_prompt_messages await _start_chat( protocol=protocol, request=request, response=response, provider=chat_provider, http_client=http_client, model_name=model_name, messages=messages, stream=stream, temperature=body.get("temperature"), top_p=body.get("top_p"), max_tokens=body.get("max_tokens"), seed=body.get("seed"), safe_prompt=body.get("safe_prompt"), top_k=body.get("top_k"), logit_bias=body.get("logit_bias"), logprobs=body.get("logprobs"), user=body.get("user"), tools=body.get("tools"), ) async def _handle_embeddings_request( request: protocol.HttpRequest, response: protocol.HttpResponse, db: dbview.Database, role_name: str, tenant: srv_tenant.Tenant, ) -> None: try: body = json.loads(request.body) if not isinstance(body, dict): raise TypeError( 'the body of the request must be a JSON object') inputs = body.get("inputs") input = body.get("input") if inputs is not None and input is not None: raise TypeError( "You cannot provide both 'inputs' and 'input'. " "Please provide 'inputs'; 'input' has been deprecated."
) if input is not None: logger.warning("'input' is deprecated, use 'inputs' instead") inputs = input if not inputs: raise TypeError( 'missing or empty required "inputs" value in request' ) model_name = body.get("model") if not model_name: raise TypeError( 'missing or empty required "model" value in request') shortening = body.get("dimensions") user = body.get("user") except Exception as ex: raise BadRequestError(str(ex)) from None provider_name = await _get_model_provider( db, base_model_type=EmbeddingModel.gel_type, model_name=model_name, ) if provider_name is None: # Error return provider = _get_provider_config(db, provider_name) if not isinstance(inputs, list): inputs = [inputs] result = await _generate_embeddings( provider, model_name, inputs, shortening, user, http_client=tenant.get_http_client(originator="ai/embeddings"), ) if isinstance(result.data, rs.Error): raise AIProviderError(result.data.message) result_data: bytes if provider.api_style == ApiStyle.Ollama: # Ollama produces embeddings differently. # Repackage it to look like OpenAI. 
decoded_result = json.loads( result.data.embeddings.decode("utf-8") ) embeddings = cast(list[list[float]], decoded_result["embeddings"]) prompt_eval_count = cast(int, decoded_result['prompt_eval_count']) result_data = json.dumps({ "object": "list", "data": [ { "object": "embedding", "index": index, "embedding": embedding } for index, embedding in enumerate(embeddings) ], "model": model_name, "usage": { "prompt_tokens": prompt_eval_count, "total_tokens": prompt_eval_count } }).encode() else: result_data = result.data.embeddings response.status = http.HTTPStatus.OK response.content_type = b'application/json' response.body = result_data async def _edgeql_query_json( *, db: dbview.Database, query: str, role_name: str | None, variables: Optional[dict[str, Any]] = None, globals_: Optional[dict[str, Any]] = None, ) -> list[Any]: try: result = await execute.parse_execute_json( db, query, variables=variables or {}, globals_=globals_, query_tag='gel/ai', ) content = json.loads(result) except Exception as ex: try: await _db_error(db, ex) except Exception as iex: raise iex from None else: return cast(list[Any], content) async def _db_error( db: dbview.Database, ex: Exception, *, errcls: Optional[type[AIExtError]] = None, context: Optional[str] = None, ) -> NoReturn: if debug.flags.server: markup.dump(ex) iex = await execute.interpret_error(ex, db) if context: msg = f'{context}: {iex}' else: msg = str(iex) err_dct = { 'message': msg, 'type': str(type(iex).__name__), 'code': iex.get_code(), } if errcls is None: if isinstance(iex, errors.QueryError): errcls = BadRequestError else: errcls = InternalError raise errcls(json=err_dct) from iex def _get_provider_config( db: dbview.Database, provider_name: str, try_builtin: bool = False, ) -> ProviderConfig: """Try to return a provider config with a matching name. Otherwise, raise an error. Checks if there is a builtin provider with a matching name. eg. "openai" -> ProviderConfig(name="builtin::openai", ...) 
""" cfg = db.lookup_config("ext::ai::Config::providers") def _create_provider_config(db_cfg: Any) -> ProviderConfig: cfg = cast(ProviderConfig, db_cfg) return ProviderConfig( name=cfg.name, display_name=cfg.display_name, api_url=cfg.api_url, client_id=cfg.client_id, secret=cfg.secret, api_style=cfg.api_style, ) # try builtin prefix builtin_prefix = "builtin::" if try_builtin and not provider_name.startswith(builtin_prefix): builtin_name = builtin_prefix + provider_name for provider in cfg: if provider.name == builtin_name: return _create_provider_config(provider) # try unmodified name for provider in cfg: if provider.name == provider_name: return _create_provider_config(provider) raise ConfigurationError( f"provider {provider_name!r} has not been configured" ) async def _get_embedding_models( db: dbview.Database, model_names: list[str], ) -> dict[str, EmbeddingModel]: model_annotations = await _get_model_annotations( db, base_model_type=EmbeddingModel.gel_type, model_names=model_names, annotation_names=[ EmbeddingModel.provider_annotation, EmbeddingModel.max_model_input_tokens_annotation, EmbeddingModel.max_batch_tokens_annotation, EmbeddingModel.max_batch_size_annotation, EmbeddingModel.max_output_dimensions_annotation, EmbeddingModel.supports_shortening_annotation, ], ) def _get_ann( model: str, anns: dict[str, str | None], name: str, ) -> str: val = anns.get(name) if val is None: raise InternalError(f"Could not read annotation '{name}'") if val == "": raise InternalError( f"Model '{model}' is missing value for annotation '{name}'" ) return val def _get_bool_ann( model: str, anns: dict[str, str | None], name: str, ) -> bool: val = _get_ann(model, anns, name) # bool() is True for any non-empty string (including "false") and # never raises ValueError, so compare the annotation text instead. lowered = val.strip().lower() if lowered == "true": return True if lowered == "false": return False raise InternalError( f"Model '{model}' annotation '{name}' " f"has non boolean value {val}" ) def _get_int_ann( model: str, anns: dict[str, str | None], name: str, ) -> int: val = _get_ann(model, anns, name) try: return int(val) except ValueError: raise InternalError(
f"Model '{model}' annotation '{name}' " f"has non integer value {val}" ) def _get_int_or_none_ann( model: str, anns: dict[str, str | None], name: str, ) -> int | None: val = _get_ann(model, anns, name) if val == "": return None try: return int(val) except ValueError: raise InternalError( f"Model '{model}' annotation '{name}' " f"has non integer value {val}" ) result: dict[str, EmbeddingModel] = {} for model, anns in model_annotations.items(): result[model] = EmbeddingModel( name=model, provider=_get_ann(model, anns, EmbeddingModel.provider_annotation), max_input_tokens=_get_int_ann( model, anns, EmbeddingModel.max_model_input_tokens_annotation ), max_batch_tokens=_get_int_ann( model, anns, EmbeddingModel.max_batch_tokens_annotation ), max_batch_size=_get_int_or_none_ann( model, anns, EmbeddingModel.max_batch_size_annotation ), max_output_dimensions=_get_int_ann( model, anns, EmbeddingModel.max_output_dimensions_annotation ), supports_shortening=_get_bool_ann( model, anns, EmbeddingModel.supports_shortening_annotation ), ) return result async def _get_model_annotations( db: dbview.Database, base_model_type: str, model_names: list[str], annotation_names: list[str], ) -> dict[str, dict[str, str | None]]: models = await _edgeql_query_json( db=db, role_name=None, query=""" WITH Parent := ( SELECT schema::ObjectType FILTER .name = $base_model_type ), Models := Parent.>$annotation_names) ), } FILTER .model_name in array_unpack(>$model_names) """, variables={ "base_model_type": base_model_type, "model_names": model_names, "annotation_names": annotation_names, }, ) if len(models) == 0: raise BadRequestError("invalid model name") result: dict[str, dict[str, str | None]] = {} for model in models: model_name = model['model_name'] if model_name in result: raise InternalError(f"models with duplicate name: {model_name}") model_anns = { ann_name: ann_value for ann_name, ann_value in model['values'] } result[model_name] = { ann_name: model_anns.get(ann_name) for ann_name in 
annotation_names } return result async def _get_model_provider( db: dbview.Database, base_model_type: str, model_name: str, ) -> str: model_annotations = await _get_model_annotations( db, base_model_type, [model_name], [BaseModel.provider_annotation] ) return not_none( model_annotations[model_name][BaseModel.provider_annotation] ) async def _generate_embeddings_for_type( db: dbview.Database, http_client: http.HttpClient, type_query: str, content: str, role_name: str, ) -> tuple[ProviderConfig, bytes]: type_desc = await execute.describe( db, f"SELECT ({type_query})", allow_capabilities=compiler.Capability.NONE, query_tag='gel/ai', role_name=role_name, ) if ( not isinstance(type_desc, sertypes.ShapeDesc) or not isinstance(type_desc.type, sertypes.ObjectDesc) ): raise errors.InvalidReferenceError( 'context query does not return an ' 'object type indexed with an `ext::ai::index`' ) return await generate_embeddings_for_text( db, http_client, type_desc.type.tid, content ) async def generate_embeddings_for_text( db: dbview.Database, http_client: http.HttpClient, type_id: str | uuid.UUID, content: str, ) -> tuple[ProviderConfig, bytes]: index = await get_ai_index_for_type(db, type_id) provider = _get_provider_config(db=db, provider_name=index.provider) if ( index.index_embedding_dimensions < index.model_embedding_dimensions ): shortening = index.index_embedding_dimensions else: shortening = None result = await _generate_embeddings( provider, index.model, [content], shortening, None, http_client, ) if isinstance(result.data, rs.Error): raise AIProviderError(result.data.message) return provider, result.data.embeddings @dataclass(frozen=True, kw_only=True) class TextEmbeddingsResult: success: Optional[list[list[float]]] = None too_long: Optional[list[int]] = None async def generate_embeddings_for_texts( db: dbview.Database, http_client: http.HttpClient, inputs: list[tuple[str | uuid.UUID, str]], ) -> TextEmbeddingsResult: """Generate embeddings for strings to search for ai 
indexed objects. Each input string may have a different object. The object is specified by the object type id as either a string or uuid. Produces embeddings for the input strings by: - grouping string by their index model and shortening - batching those groups - then doing embeddings requests in batches Input strings are truncated if allowed by their index. If any string is too long and truncating is not allowed, a "too_long" result is returned. If all embeddings requests are successful, the embeddings are returned as a "success" result in the same order as the inputs. """ # Gather information about the indexes and embeddings # For each type, we will need: # - model name # - max input tokens # - max batch tokens # - provider config # - allowed to truncate # - shortening, if any type_ai_indexes: dict[str, AIIndex] = {} for type_id, _ in inputs: type_id = str(type_id) if type_id not in type_ai_indexes: type_ai_indexes[type_id] = await get_ai_index_for_type(db, type_id) model_providers = { ai_index.model: ai_index.provider for ai_index in type_ai_indexes.values() } embedding_models = await _get_embedding_models( db, list(model_providers.keys()) ) provider_configs = { provider: _get_provider_config(db=db, provider_name=provider) for provider in set(model_providers.values()) } # Group the inputs by model and shortening group_input_indexes: dict[tuple[str, Optional[int]], list[int]] = {} for input_index, (type_id, _) in enumerate(inputs): ai_index = type_ai_indexes[str(type_id)] model_name = ai_index.model shortening = ( ai_index.index_embedding_dimensions if ( ai_index.index_embedding_dimensions < ai_index.model_embedding_dimensions ) else None ) group_key = (model_name, shortening) if group_key not in group_input_indexes: group_input_indexes[group_key] = [] group_input_indexes[group_key].append(input_index) # Batch each group separately group_batch_texts_and_indexes: dict[ tuple[str, Optional[int]], list[tuple[ # texts, truncated if needed list[str], # the associated 
input index list[int], ]] ] = {} too_long: list[int] = [] for group_key, input_indexes in group_input_indexes.items(): model_name, shortening = group_key provider = model_providers[model_name] embedding_model = embedding_models[model_name] tokenizer = get_model_tokenizer(provider, model_name) texts = [ ( inputs[input_index][1], type_ai_indexes[str(inputs[input_index][0])].truncate_to_max, ) for input_index in input_indexes ] text_batches, excluded_indexes = batch_texts( texts, tokenizer, max_input_tokens=embedding_model.max_input_tokens, max_batch_tokens=embedding_model.max_batch_tokens, max_batch_size=embedding_model.max_batch_size, ) if excluded_indexes or too_long: # If any input is too long, collect all inputs that are too long # and return them as a failure too_long.extend( input_indexes[excluded_index] for excluded_index in excluded_indexes ) continue group_batch_texts_and_indexes[group_key] = [] for text_batch in text_batches: batched_texts: list[str] = [] batched_input_indexes: list[int] = [] for entry in text_batch.entries: batched_texts.append(entry.input_text) batched_input_indexes.append( input_indexes[entry.input_index] ) group_batch_texts_and_indexes[group_key].append( (batched_texts, batched_input_indexes) ) if too_long: return TextEmbeddingsResult(too_long=too_long) # Do the embeddings # We have been tracking the input indexes of the batch texts this whole # time. 
Use these indexes to fill in a result embeddings list embeddings: list[Optional[list[float]]] = [None] * len(inputs) for group_key, batched_texts_and_indexes in ( group_batch_texts_and_indexes.items() ): model_name, shortening = group_key provider = model_providers[model_name] provider_config = provider_configs[provider] for batched_texts, batched_input_indexes in batched_texts_and_indexes: embeddings_result = await _generate_embeddings( provider_config, model_name, batched_texts, shortening, None, http_client, ) if isinstance(embeddings_result.data, rs.Error): raise AIProviderError(embeddings_result.data.message) result_entries = provider_config.get_embeddings_from_result( embeddings_result.data.embeddings ) for entry_index, result_entry in enumerate(result_entries): input_index = batched_input_indexes[entry_index] embeddings[input_index] = result_entry assert all(e is not None for e in embeddings) return TextEmbeddingsResult( success=cast(list[list[float]], embeddings), ) @dataclass(frozen=True, kw_only=True) class AIIndex: model: str provider: str model_embedding_dimensions: int index_embedding_dimensions: int truncate_to_max: bool async def get_ai_index_for_type( db: dbview.Database, type_id: str | uuid.UUID, ) -> AIIndex: try: indexes = await _edgeql_query_json( db=db, query=""" WITH ObjectType := ( SELECT schema::ObjectType FILTER .id = $type_id ), SELECT ObjectType.indexes { model := ( SELECT (FOR a IN .annotations SELECT (a@value, a.name)) FILTER .1 = "ext::ai::model_name" LIMIT 1 ).0, provider := ( SELECT (FOR a IN .annotations SELECT (a@value, a.name)) FILTER .1 = "ext::ai::model_provider" LIMIT 1 ).0, model_embedding_dimensions := ( SELECT (FOR a IN .annotations SELECT (a@value, a.name)) FILTER .1 = "ext::ai::embedding_model_max_output_dimensions" LIMIT 1 ).0, index_embedding_dimensions := ( SELECT (FOR a IN .annotations SELECT (a@value, a.name)) FILTER .1 = "ext::ai::embedding_dimensions" LIMIT 1 ).0, truncate_to_max := any(( for kwarg in 
array_unpack(.kwargs) select ( kwarg.name = 'truncate_to_max' and str_lower(kwarg.expr) = 'true' ) )) } FILTER .ancestors.name = 'ext::ai::index' """, variables={"type_id": str(type_id)}, role_name=None, ) if len(indexes) == 0: raise errors.InvalidReferenceError( 'context query does not return an ' 'object type indexed with an `ext::ai::index`' ) elif len(indexes) > 1: raise errors.InvalidReferenceError( 'context query returns an object ' 'indexed with multiple `ext::ai::index` indexes' ) except Exception as ex: await _db_error(db, ex, context="context.query") index = indexes[0] return AIIndex( model=index["model"], provider=index["provider"], model_embedding_dimensions=index["model_embedding_dimensions"], index_embedding_dimensions=index["index_embedding_dimensions"], truncate_to_max=index["truncate_to_max"], ) ================================================ FILE: edb/server/protocol/args_ser.pxd ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# cimport cython from edb.server.dbview cimport dbview from edb.server.pgproto.pgproto cimport WriteBuffer cdef WriteBuffer recode_bind_args( dbview.DatabaseConnectionView dbv, dbview.CompiledQuery compiled, bytes bind_args, list converted_args, list positions = ?, list data_types = ?, ) cdef recode_bind_args_for_script( dbview.DatabaseConnectionView dbv, dbview.CompiledQuery compiled, bytes bind_args, object converted_args, ssize_t start, ssize_t end, ) cdef bytes recode_global( dbview.DatabaseConnectionView dbv, bytes glob, object glob_descriptor, ) cdef WriteBuffer combine_raw_args(object args = ?) @cython.final cdef class ParamConversion: cdef: str param_name str conversion_name tuple additional_info bytes encoded_data object constant_value cdef list[ParamConversion] get_param_conversions( dbview.DatabaseConnectionView dbv, list server_param_conversions, bytes bind_args, list[bytes] extra_blobs, ) cdef dict[int, bytes] get_args_data_for_indexes( bytes bind_args, list[bytes] extra_blobs, list[int] target_indexes, ) cdef class ConvertedArg: cdef: int bind_format_code @cython.final cdef class ConvertedArgStr(ConvertedArg): cdef: str data @staticmethod cdef ConvertedArgStr new(str data) @cython.final cdef class ConvertedArgFloat64(ConvertedArg): cdef: float data @staticmethod cdef ConvertedArgFloat64 new(float data) @cython.final cdef class ConvertedArgListFloat32(ConvertedArg): cdef: list data @staticmethod cdef ConvertedArgListFloat32 new(list data) ================================================ FILE: edb/server/protocol/args_ser.pyx ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2019-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # cimport cython cimport cpython from libc.stdint cimport int8_t, uint8_t, int16_t, uint16_t, \ int32_t, uint32_t, int64_t, uint64_t from edb import errors from edb.server.compiler import sertypes from edb.server.compiler import enums from edb.server.compiler import dbstate from edb.server.dbview cimport dbview from edb.server.pgproto cimport hton from edb.server.pgproto.pgproto cimport ( WriteBuffer, FRBuffer, frb_init, frb_read, frb_get_len, frb_slice_from, ) cdef uint32_t SCALAR_TAG = int(enums.TypeTag.SCALAR) cdef uint32_t TUPLE_TAG = int(enums.TypeTag.TUPLE) cdef uint32_t ARRAY_TAG = int(enums.TypeTag.ARRAY) cdef recode_bind_args_for_script( dbview.DatabaseConnectionView dbv, dbview.CompiledQuery compiled, bytes bind_args, object converted_args, ssize_t start, ssize_t end, ): cdef: WriteBuffer bind_data ssize_t i ssize_t oidx ssize_t iidx unit_group = compiled.query_unit_group # TODO: just do the simple thing if it is only one! 
positions = [] recoded_buf = recode_bind_args( dbv, compiled, bind_args, None, positions ) # TODO: something with less copies recoded = bytes(memoryview(recoded_buf)) bind_array = [] for i in range(start, end): query_unit = unit_group[i] bind_data = WriteBuffer.new() bind_data.write_int32(0x00010001) num_args = query_unit.in_type_args_real_count num_args += _count_globals(query_unit) if compiled.first_extra is not None: num_args += compiled.extra_counts[i] if query_unit.server_param_conversions is not None: num_args += len(query_unit.server_param_conversions) bind_data.write_int16(num_args) if query_unit.in_type_args: for iidx, arg in enumerate(query_unit.in_type_args): oidx = arg.outer_idx if arg.outer_idx is not None else iidx barg = recoded[positions[oidx]:positions[oidx+1]] bind_data.write_bytes(barg) if compiled.first_extra is not None: bind_data.write_bytes(compiled.extra_blobs[i]) _inject_globals(dbv, query_unit, bind_data) if converted_args and i in converted_args: for arg in converted_args[i]: assert isinstance(arg, ConvertedArg) arg.encode(bind_data) bind_data.write_int32(0x00010001) bind_array.append(bind_data) return bind_array cdef WriteBuffer recode_bind_args( dbview.DatabaseConnectionView dbv, dbview.CompiledQuery compiled, bytes bind_args, list converted_args, # XXX do something better?!? 
list positions = None, list data_types = None, ): cdef: FRBuffer in_buf FRBuffer peek_buf WriteBuffer out_buf = WriteBuffer.new() int32_t recv_args int32_t decl_args ssize_t in_len ssize_t i int32_t array_tid const char *data bint live = positions is None assert cpython.PyBytes_CheckExact(bind_args) frb_init( &in_buf, cpython.PyBytes_AS_STRING(bind_args), cpython.Py_SIZE(bind_args)) # number of elements in the tuple # for empty tuple it's okay to send zero-length arguments qug = compiled.query_unit_group is_null_type = qug.in_type_id == sertypes.NULL_TYPE_ID.bytes if frb_get_len(&in_buf) == 0: if not is_null_type: raise errors.InputDataError( f"insufficient data for type-id {qug.in_type_id}") recv_args = 0 else: if is_null_type: raise errors.InputDataError( "absence of query arguments must be encoded with a " "'zero' type " "(id: 00000000-0000-0000-0000-000000000000, " "encoded with zero bytes)") recv_args = hton.unpack_int32(frb_read(&in_buf, 4)) decl_args = len(qug.in_type_args or ()) if recv_args != decl_args: raise errors.InputDataError( f"invalid argument count, " f"expected: {decl_args}, got: {recv_args}") num_args = qug.in_type_args_real_count if compiled.first_extra is not None: assert recv_args == compiled.first_extra, \ f"argument count mismatch {recv_args} != {compiled.first_extra}" num_args += compiled.extra_counts[0] num_globals = _count_globals(qug) num_args += num_globals if converted_args is not None: num_args += len(converted_args) if live: if not compiled.extra_formatted_as_text: # all parameter values are in binary out_buf.write_int32(0x00010001) elif not recv_args and not num_globals: # all parameter values are in text (i.e extracted SQL constants) out_buf.write_int16(0x0000) else: # got a mix of binary and text, spell them out explicitly out_buf.write_int16(num_args) # explicit args are in binary for _ in range(recv_args): out_buf.write_int16(0x0001) # and extracted SQL constants are in text if compiled.extra_counts: for _ in 
range(compiled.extra_counts[0]): out_buf.write_int16(0x0000) # and injected globals are binary again for _ in range(num_globals): out_buf.write_int16(0x0001) # and converted args depend on the conversion if converted_args: for arg in converted_args: out_buf.write_int16(arg.bind_format_code) out_buf.write_int16(num_args) if data_types is not None and compiled.extra_type_oids: data_types.extend([0] * recv_args) data_types.extend(compiled.extra_type_oids) data_types.extend([0] * num_globals) if qug.in_type_args: for param in qug.in_type_args: if positions is not None: positions.append(out_buf._length) frb_read(&in_buf, 4) # reserved # Some of the logic paths below are cleaner if # the length is still present in the input buf, so we just # *peek* at the length here and consume it later. peek_buf = in_buf in_len = hton.unpack_int32(frb_read(&peek_buf, 4)) if in_len < 0: # This means argument value is NULL if param.required: raise errors.QueryError( f"parameter ${param.name} is required") # If the param has encoded tuples, we need to decode them # and reencode them as arrays of scalars. if param.sub_params: tids, trans_typ = param.sub_params _decode_tuple_args( dbv, &in_buf, out_buf, in_len, tids, trans_typ) continue frb_read(&in_buf, 4) out_buf.write_int32(in_len) if in_len > 0: if param.array_type_id is not None: array_tid = dbv.resolve_backend_type_id( param.array_type_id) recode_array(dbv, &in_buf, out_buf, in_len, array_tid, None) else: data = frb_read(&in_buf, in_len) out_buf.write_cstr(data, in_len) if positions is not None: positions.append(out_buf._length) if live: if compiled.first_extra is not None: out_buf.write_bytes(compiled.extra_blobs[0]) # Inject any global variables into the argument stream.
_inject_globals(dbv, qug, out_buf) if converted_args: for arg in converted_args: assert isinstance(arg, ConvertedArg) arg.encode(out_buf) # All columns are in binary format out_buf.write_int32(0x00010001) if frb_get_len(&in_buf): raise errors.InputDataError('unexpected trailing data in buffer') return out_buf cdef bytes recode_global( dbv: dbview.DatabaseConnectionView, glob: bytes, glob_descriptor: object, ): cdef: WriteBuffer out_buf FRBuffer in_buf if glob_descriptor is None: return glob out_buf = WriteBuffer.new() assert cpython.PyBytes_CheckExact(glob) frb_init( &in_buf, cpython.PyBytes_AS_STRING(glob), cpython.Py_SIZE(glob)) _recode_global(dbv, &in_buf, out_buf, in_buf.len, glob_descriptor) if frb_get_len(&in_buf): raise errors.InputDataError('unexpected trailing data in buffer') return bytes(memoryview(out_buf)) cdef _recode_global( dbv: dbview.DatabaseConnectionView, FRBuffer* in_buf, out_buf: WriteBuffer, in_len: ssize_t, glob_descriptor: object, ): if glob_descriptor is None: data = frb_read(in_buf, in_len) out_buf.write_cstr(data, in_len) elif glob_descriptor[0] == TUPLE_TAG: _, el_tids, el_infos = glob_descriptor recode_global_tuple(dbv, in_buf, out_buf, in_len, el_tids, el_infos) elif glob_descriptor[0] == ARRAY_TAG: _, el_tid, tuple_info = glob_descriptor btid = dbv.resolve_backend_type_id(el_tid) recode_array(dbv, in_buf, out_buf, in_len, btid, tuple_info) cdef recode_global_tuple( dbv: dbview.DatabaseConnectionView, FRBuffer* in_buf, out_buf: WriteBuffer, in_len: ssize_t, el_tids: tuple, el_infos: tuple, ): """ Tuples in globals need to have NULLs checked and oids injected, like arrays do. Annoyingly this is a *totally separate* code path than tuple query parameters go through. This is because global tuples actually can get passed as postgres composite types, since they are declared in the schema. 
""" cdef: WriteBuffer buf ssize_t cnt ssize_t idx ssize_t num ssize_t tag FRBuffer sub_buf frb_slice_from(&sub_buf, in_buf, in_len) cnt = hton.unpack_int32(frb_read(&sub_buf, 4)) out_buf.write_int32(cnt) num = len(el_tids) if cnt != num: raise errors.InputDataError( f"tuple length mismatch: {cnt} vs {num}") for idx in range(num): frb_read(&sub_buf, 4) el_btid = dbv.resolve_backend_type_id(el_tids[idx]) out_buf.write_int32(el_btid) in_len = hton.unpack_int32(frb_read(&sub_buf, 4)) if in_len < 0: raise errors.InputDataError("invalid NULL inside type") out_buf.write_int32(in_len) _recode_global(dbv, &sub_buf, out_buf, in_len, el_infos[idx]) if frb_get_len(&sub_buf): raise errors.InputDataError('unexpected trailing data in buffer') cdef recode_array( dbv: dbview.DatabaseConnectionView, FRBuffer* in_buf, out_buf: WriteBuffer, in_len: ssize_t, array_tid: int32_t, tuple_info: object, ): # For a standalone array, we still need to inject oids and reject # NULL elements. cdef: ssize_t cnt ssize_t idx ssize_t num ssize_t tag FRBuffer sub_buf frb_slice_from(&sub_buf, in_buf, in_len) ndims = hton.unpack_int32(frb_read(&sub_buf, 4)) # ndims if ndims != 1 and ndims != 0: raise errors.InputDataError("unsupported array dimensions") out_buf.write_int32(ndims) data = frb_read(&sub_buf, 8) # flags + reserved (oid) out_buf.write_cstr(data, 4) # just write flags out_buf.write_int32(array_tid) if ndims != 0: cnt = hton.unpack_int32(frb_read(&sub_buf, 4)) out_buf.write_int32(cnt) val = hton.unpack_int32(frb_read(&sub_buf, 4)) # bound if val != 1: raise errors.InputDataError("unsupported array bound") out_buf.write_int32(val) # We have to actually scan the array to make sure it # doesn't have any NULLs in it. 
for idx in range(cnt): in_len = hton.unpack_int32(frb_read(&sub_buf, 4)) if in_len < 0: raise errors.InputDataError("invalid NULL inside type") out_buf.write_int32(in_len) if tuple_info is None: data = frb_read(&sub_buf, in_len) out_buf.write_cstr(data, in_len) else: _recode_global(dbv, &sub_buf, out_buf, in_len, tuple_info) if frb_get_len(&sub_buf): raise errors.InputDataError('unexpected trailing data in buffer') cdef _decode_tuple_args_core( FRBuffer* in_buf, out_bufs: tuple[WriteBuffer], counts: list[int], acounts: list[int], trans_typ: tuple, in_array: bool, ): # Recurse over the types and the input data, collecting the # arguments into the various out_bufs. See # edb.edgeql.compiler.tuple_args for more discussion. cdef: ssize_t in_len WriteBuffer buf ssize_t cnt ssize_t idx ssize_t num ssize_t tag int32_t val FRBuffer sub_buf tag = trans_typ[0] idx = trans_typ[1] in_len = hton.unpack_int32(frb_read(in_buf, 4)) buf = out_bufs[idx] if in_len < 0: raise errors.InputDataError("invalid NULL inside type") frb_slice_from(&sub_buf, in_buf, in_len) if tag == SCALAR_TAG: buf.write_int32(in_len) data = frb_read(&sub_buf, in_len) buf.write_cstr(data, in_len) if in_array: counts[idx] += 1 elif tag == TUPLE_TAG: cnt = hton.unpack_int32(frb_read(&sub_buf, 4)) num = len(trans_typ) - 2 if cnt != num: raise errors.InputDataError( f"tuple length mismatch: {cnt} vs {num}") for idx in range(num): typ = trans_typ[idx + 2] frb_read(&sub_buf, 4) _decode_tuple_args_core( &sub_buf, out_bufs, counts, acounts, typ, in_array) elif tag == ARRAY_TAG: val = hton.unpack_int32(frb_read(&sub_buf, 4)) # ndims if val != 1 and val != 0: raise errors.InputDataError("unsupported array dimensions") frb_read(&sub_buf, 4) # flags frb_read(&sub_buf, 4) # reserved if val == 0: cnt = 0 else: cnt = hton.unpack_int32(frb_read(&sub_buf, 4)) val = hton.unpack_int32(frb_read(&sub_buf, 4)) # bound if val != 1: raise errors.InputDataError("unsupported array bound") # For nested arrays, we need to produce an 
array containing # the start/end indexes in the flattened array. if in_array: # If this is the first element, put in the 0 if acounts[idx] == -1: counts[idx] += 1 acounts[idx] = 0 buf.write_int32(4) buf.write_int32(0) counts[idx] += 1 acounts[idx] += cnt buf.write_int32(4) buf.write_int32(acounts[idx]) styp = trans_typ[2] for _ in range(cnt): _decode_tuple_args_core( &sub_buf, out_bufs, counts, acounts, styp, True) if frb_get_len(&sub_buf): raise errors.InputDataError('unexpected trailing data in buffer') cdef WriteBuffer _decode_tuple_args( dbv: dbview.DatabaseConnectionView, FRBuffer* in_buf, out_buf: WriteBuffer, in_len: ssize_t, tids: list, trans_typ: object, ): # PERF: Can we use real arrays, instead of Python lists? cdef: const char *data list buffers list counts list acounts WriteBuffer buf # N.B: We have peeked at in_len, but the size is still in the buffer, for # more convenient processing by _decode_tuple_args_core if in_len < 0: # For a NULL argument, fill out *every* one of our args with NULL for _ in tids: out_buf.write_int32(in_len) # We only peeked at in_len before, so consume it now frb_read(in_buf, 4) return buffers = [] counts = [] acounts = [] for maybe_tid in tids: buf = WriteBuffer.new() counts.append(0 if maybe_tid else -1) acounts.append(-1) buffers.append(buf) _decode_tuple_args_core( in_buf, tuple(buffers), counts, acounts, trans_typ, False) # zip up all of the buffers we have collected # PERF: or should we just index?
for maybe_tid, count, buf in zip(tids, counts, buffers): if maybe_tid: ndims = 1 out_buf.write_int32(12 + 8 * ndims + buf.len()) # ndimensions + flags array_tid = dbv.resolve_backend_type_id(maybe_tid) out_buf.write_int32(1) out_buf.write_int32(0) out_buf.write_int32(array_tid) out_buf.write_int32(count) out_buf.write_int32(1) out_buf.write_buffer(buf) cdef _inject_globals( dbv: dbview.DatabaseConnectionView, query_unit_or_group: object, out_buf: WriteBuffer, ): if globals := query_unit_or_group.globals: for (name, has_present_arg) in globals: val, is_present = dbv.get_global_value(name) if val is not None: out_buf.write_int32(len(val)) out_buf.write_bytes(val) else: out_buf.write_int32(-1) if has_present_arg: out_buf.write_int32(1) present = b'\x01' if is_present else b'\x00' out_buf.write_bytes(present) if permissions := query_unit_or_group.permissions: superuser, available_permissions = dbv.get_permissions() for permission in permissions: out_buf.write_int32(1) out_buf.write_byte( superuser or permission in available_permissions ) cdef uint64_t _count_globals( query_unit: object, ): cdef: uint64_t num_args num_args = 0 if query_unit.globals: num_args += len(query_unit.globals) for _, has_present_arg in query_unit.globals: if has_present_arg: num_args += 1 if query_unit.permissions: num_args += len(query_unit.permissions) return num_args cdef WriteBuffer combine_raw_args( args: tuple[bytes, ...] 
| list[bytes] = (), ): cdef: int arg_len WriteBuffer bind_data = WriteBuffer.new() if len(args) > 32767: raise AssertionError( 'the number of query arguments cannot exceed 32767') bind_data.write_int32(0x00010001) bind_data.write_int16( len(args)) for arg in args: if arg is None: bind_data.write_int32(-1) else: arg_len = len(arg) if arg_len > 0x7fffffff: raise ValueError("argument too long") bind_data.write_int32( arg_len) bind_data.write_bytes(arg) bind_data.write_int32(0x00010001) return bind_data @cython.final cdef class ParamConversion: def __init__( self, *, param_name, conversion_name, additional_info, encoded_data, constant_value, ): self.param_name = param_name self.conversion_name = conversion_name self.additional_info = additional_info self.encoded_data = encoded_data self.constant_value = constant_value def get_param_name(self): return self.param_name def get_conversion_name(self): return self.conversion_name def get_additional_info(self): return self.additional_info def get_encoded_data(self): return self.encoded_data def get_constant_value(self): return self.constant_value def param_as_int(self) -> int: return self._decode_int() if self.constant_value is None else self.constant_value def param_as_str(self) -> str: return self._decode_str() if self.constant_value is None else self.constant_value def param_as_array_of_str(self) -> list[str]: return self._decode_array_of_str() if self.constant_value is None else self.constant_value def _decode_int(self) -> int: return int.from_bytes(self.encoded_data) def _decode_str(self) -> str: return self.encoded_data.decode("utf-8") def _decode_array_of_str(self) -> list[str]: # See gel-python for more details on array encoding texts = [] text_count = int.from_bytes(self.encoded_data[12:16]) data = self.encoded_data[20:] for _ in range(text_count): text_length = int.from_bytes(data[:4]) data = data[4:] texts.append(data[:(text_length)].decode("utf-8")) data = data[text_length:] return texts cdef list[ParamConversion] 
get_param_conversions( dbview.DatabaseConnectionView dbv, list server_param_conversions, bytes bind_args, list[bytes] extra_blobs, ): # Get encoded data from bind args and extra blobs bind_args_datas: dict[int, bytes] = get_args_data_for_indexes( bind_args, extra_blobs, [ param_conversion.script_param_index for param_conversion in server_param_conversions if param_conversion.script_param_index is not None ], ) # Construct the ParamConversions result: list[ParamConversion] = [] for param_conversion in server_param_conversions: assert isinstance(param_conversion, dbstate.ServerParamConversion) param_name = param_conversion.param_name if ( param_conversion.script_param_index is not None and param_conversion.constant_value is not None ): raise RuntimeError( f"Parameter '{param_name}' has both " f"a constant and a query arg value" ) elif param_conversion.script_param_index is not None: # using data from the bind args result.append(ParamConversion( param_name=param_name, conversion_name=param_conversion.conversion_name, additional_info=param_conversion.additional_info, encoded_data=bind_args_datas[ param_conversion.script_param_index ], constant_value=None, )) elif param_conversion.constant_value is not None: # using a constant from the query result.append(ParamConversion( param_name=param_name, conversion_name=param_conversion.conversion_name, additional_info=param_conversion.additional_info, encoded_data=None, constant_value=param_conversion.constant_value, )) else: raise RuntimeError( f"Parameter '{param_name}' has no value" ) return result cdef dict[int, bytes] get_args_data_for_indexes( bytes bind_args, list[bytes] extra_blobs, list[int] target_indexes, ): """Extract bytes from the bind args and extra blobs by reading the length of each variable and skipping forward by that amount. When reaching the end of a blob, continue reading data from the next blob. 
""" cdef: FRBuffer in_buf ssize_t in_len const char *data_str all_blobs = [bind_args, *extra_blobs] curr_blob_index = 0 # The first blob is the bind_args, which is has additional data which should # be skipped when extracting the arg data. args_needs_recoding = True def setup_blob_buffer(): nonlocal curr_blob_index nonlocal args_needs_recoding if curr_blob_index >= len(all_blobs): raise RuntimeError('insufficient args data') blob = all_blobs[curr_blob_index] assert cpython.PyBytes_CheckExact(blob) frb_init( &in_buf, cpython.PyBytes_AS_STRING(blob), cpython.Py_SIZE(blob) ) args_needs_recoding = curr_blob_index == 0 if args_needs_recoding: # Skip prefixed argument count if frb_get_len(&in_buf) == 0: pass else: frb_read(&in_buf, 4) setup_blob_buffer() curr_arg_index = 0 target_indexes.sort() result: dict[int, bytes] = {} for target_index in target_indexes: # Read up to the end of the target variable for arg_index in range(curr_arg_index, target_index + 1): if frb_get_len(&in_buf) == 0: # We've reached the end of the previous blob. # Set up the next one and keep scanning. 
curr_blob_index += 1 setup_blob_buffer() if args_needs_recoding: # Skip reserved frb_read(&in_buf, 4) # reserved in_len = hton.unpack_int32(frb_read(&in_buf, 4)) data_str = frb_read(&in_buf, in_len) if arg_index == target_index: # Store the target variable data data = cpython.PyBytes_FromStringAndSize(data_str, in_len) result[target_index] = data curr_arg_index = target_index + 1 return result # After param conversions, we need to re-encode the converted # arg before putting it into the recoded bind args cdef class ConvertedArg: def encode(self, buffer: WriteBuffer): raise NotImplementedError cdef class ConvertedArgStr(ConvertedArg): @staticmethod cdef ConvertedArgStr new(str data): cdef ConvertedArgStr result result = ConvertedArgStr.__new__(ConvertedArgStr) result.bind_format_code = 0x0000 result.data = data return result def encode(self, buffer: WriteBuffer): encoded = self.data.encode() buffer.write_int32(len(encoded)) buffer.write_bytes(encoded) cdef class ConvertedArgFloat64(ConvertedArg): @staticmethod cdef ConvertedArgFloat64 new(float data): cdef ConvertedArgFloat64 result result = ConvertedArgFloat64.__new__(ConvertedArgFloat64) result.bind_format_code = 0x0001 result.data = data return result def encode(self, buffer: WriteBuffer): buffer.write_int32(8) # elem size buffer.write_double(self.data) cdef class ConvertedArgListFloat32(ConvertedArg): @staticmethod cdef ConvertedArgListFloat32 new(list data): cdef ConvertedArgListFloat32 result result = ConvertedArgListFloat32.__new__(ConvertedArgListFloat32) result.bind_format_code = 0x0001 result.data = data return result def encode(self, buffer: WriteBuffer): elem_count = len(self.data) buffer.write_int32(12 + 8 + elem_count * 8) # buffer length buffer.write_int32(1) # number of dimensions buffer.write_int32(0) # flags buffer.write_int32(700) # array_tid for "real" buffer.write_int32(elem_count) # count buffer.write_int32(1) # bound for elem in self.data: buffer.write_int32(4) # elem size 
buffer.write_float(elem) ================================================ FILE: edb/server/protocol/auth/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2022-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import TYPE_CHECKING import http import json from edb import errors from edb.common import debug from edb.common import markup from . 
import scram if TYPE_CHECKING: from edb.server import tenant as edbtenant from edb.server.protocol import protocol async def handle_request( request: protocol.HttpRequest, response: protocol.HttpResponse, path_parts: list[str], tenant: edbtenant.Tenant, ) -> None: try: if path_parts == ["token"]: if not request.authorization: response.status = http.HTTPStatus.UNAUTHORIZED response.custom_headers["WWW-Authenticate"] = "SCRAM-SHA-256" return scheme, _, auth_str = request.authorization.decode( "ascii" ).partition(" ") if scheme.lower().startswith("scram"): scram.handle_request(scheme, auth_str, response, tenant) else: response.body = b"Unsupported authentication scheme" response.status = http.HTTPStatus.UNAUTHORIZED response.custom_headers["WWW-Authenticate"] = "SCRAM-SHA-256" response.close_connection = True else: response.body = b"Unknown path" response.status = http.HTTPStatus.NOT_FOUND response.close_connection = True except errors.EdgeDBError as ex: if debug.flags.server: markup.dump(ex) _response_error( response, http.HTTPStatus.INTERNAL_SERVER_ERROR, str(ex), type(ex) ) except Exception as ex: if debug.flags.server: markup.dump(ex) # XXX Fix this when LSP "location" objects are implemented ex_type = errors.InternalServerError _response_error( response, http.HTTPStatus.INTERNAL_SERVER_ERROR, str(ex), ex_type ) def _response_error( response: protocol.HttpResponse, status: http.HTTPStatus, message: str, ex_type: type[errors.EdgeDBError], ) -> None: err_dct = { "message": message, "type": str(ex_type.__name__), "code": ex_type.get_code(), } response.body = json.dumps({"error": err_dct}).encode() response.status = status response.close_connection = True ================================================ FILE: edb/server/protocol/auth/scram.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2022-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import NamedTuple, Optional, TYPE_CHECKING import base64 import collections import hashlib import http import os import time from edgedb import scram from edb.common import debug from edb.common import markup from edb.server import auth if TYPE_CHECKING: from edb.server import tenant as edbtenant from edb.server.protocol import protocol SESSION_TIMEOUT: float = 30 SESSION_HIGH_WATER_MARK: float = SESSION_TIMEOUT * 10 class Session(NamedTuple): time: float client_nonce: str server_nonce: str client_first_bare: bytes cb_flag: bool server_first: bytes verifier: scram.SCRAMVerifier mock_auth: bool username: str sessions: collections.OrderedDict[str, Session] = collections.OrderedDict() def handle_request( scheme: str, auth_str: str, response: protocol.HttpResponse, tenant: edbtenant.Tenant, ) -> None: server = tenant.server if scheme != "SCRAM-SHA-256": response.body = ( b"Client selected an invalid SASL authentication mechanism" ) response.status = http.HTTPStatus.UNAUTHORIZED response.custom_headers["WWW-Authenticate"] = "SCRAM-SHA-256" return data = None sid = None try: for kv_str in auth_str.split(): key, _, value = kv_str.rstrip(",").partition("=") if key == "data": data = base64.b64decode(value.strip('"')).strip() elif key == "sid": sid = value.strip('"') if data is None: raise ValueError("Malformed SCRAM message: data is missing") except Exception as ex: if debug.flags.server: 
markup.dump(ex) response.body = str(ex).encode("ascii") response.status = http.HTTPStatus.BAD_REQUEST response.close_connection = True return jws = server.get_jws_key() if jws is None or not jws.has_private_keys(): response.body = b"Server doesn't support HTTP SCRAM authentication" response.status = http.HTTPStatus.FORBIDDEN response.close_connection = True return if sid is None: try: bare_offset: int cb_flag: bool authzid: Optional[bytes] username_bytes: bytes client_nonce: str ( bare_offset, cb_flag, authzid, username_bytes, client_nonce, ) = scram.parse_client_first_message(data) except ValueError as ex: if debug.flags.server: markup.dump(ex) response.body = f"Bad client first message: {ex!s}".encode("ascii") response.status = http.HTTPStatus.BAD_REQUEST response.close_connection = True return username = username_bytes.decode("utf-8") client_first_bare = data[bare_offset:] if isinstance(cb_flag, str): response.body = ( b"Malformed SCRAM message: " b"The client selected SCRAM-SHA-256 without " b"channel binding, but the SCRAM message " b"includes channel binding data." 
) response.status = http.HTTPStatus.BAD_REQUEST response.close_connection = True return if authzid: response.body = ( b"Client uses SASL authorization identity, " b"which is not supported" ) response.status = http.HTTPStatus.BAD_REQUEST response.close_connection = True return try: verifier, mock_auth = get_scram_verifier(username, tenant) except ValueError as ex: if debug.flags.server: markup.dump(ex) response.body = b"Authentication failed" response.status = http.HTTPStatus.UNAUTHORIZED response.custom_headers["WWW-Authenticate"] = "SCRAM-SHA-256" return server_nonce: str = scram.generate_nonce() server_first: bytes = scram.build_server_first_message( server_nonce, client_nonce, verifier.salt, verifier.iterations ).encode("utf-8") if len(sessions) > SESSION_HIGH_WATER_MARK: while sessions: key, session = sessions.popitem(last=False) if session.time + SESSION_TIMEOUT > time.monotonic(): sessions[key] = session sessions.move_to_end(key, last=False) break sid = ( base64.urlsafe_b64encode(os.urandom(16)) .decode("ascii") .rstrip("=") ) assert sid not in sessions sessions[sid] = Session( time.monotonic(), client_nonce, server_nonce, client_first_bare, cb_flag, server_first, verifier, mock_auth, username, ) server_first_str = base64.b64encode(server_first).decode("ascii") response.status = http.HTTPStatus.UNAUTHORIZED response.custom_headers[ "WWW-Authenticate" ] = f"SCRAM-SHA-256 sid={sid}, data={server_first_str}" else: # Use a default so an unknown sid yields None instead of KeyError session = sessions.pop(sid, None) if session is None: response.body = b"Bad session ID" response.status = http.HTTPStatus.UNAUTHORIZED response.custom_headers["WWW-Authenticate"] = "SCRAM-SHA-256" return ( ts, client_nonce, server_nonce, client_first_bare, cb_flag, server_first, verifier, mock_auth, username, ) = session if ts + SESSION_TIMEOUT < time.monotonic(): response.body = b"Session timed out" response.status = http.HTTPStatus.UNAUTHORIZED response.custom_headers["WWW-Authenticate"] = "SCRAM-SHA-256" return try: ( cb_data, client_proof, proof_len, ) =
scram.parse_client_final_message( data, client_nonce, server_nonce ) except ValueError as ex: if debug.flags.server: markup.dump(ex) response.body = f"Bad client final message: {ex!s}".encode("ascii") response.status = http.HTTPStatus.BAD_REQUEST response.close_connection = True return client_final_without_proof = data[:-proof_len] cb_data_ok = (cb_flag is False and cb_data == b"biws") or ( cb_flag is True and cb_data == b"eSws" ) if not cb_data_ok: response.body = ( b"Malformed SCRAM message: " b"Unexpected SCRAM channel-binding attribute " b"in client-final-message." ) response.status = http.HTTPStatus.BAD_REQUEST response.close_connection = True return if ( not scram.verify_client_proof( client_first_bare, server_first, client_final_without_proof, verifier.stored_key, client_proof, ) or mock_auth ): response.body = b"Authentication failed" response.status = http.HTTPStatus.UNAUTHORIZED response.custom_headers["WWW-Authenticate"] = "SCRAM-SHA-256" return server_final = base64.b64encode( scram.build_server_final_message( client_first_bare, server_first, client_final_without_proof, verifier.server_key, ).encode("utf-8") ).decode("ascii") try: response.body = auth.generate_gel_token( jws, roles=[username], ).encode("ascii") except ValueError as ex: if debug.flags.server: markup.dump(ex) response.body = b"Authentication failed" response.status = http.HTTPStatus.UNAUTHORIZED response.custom_headers["WWW-Authenticate"] = "SCRAM-SHA-256" return response.custom_headers[ "Authentication-Info" ] = f"sid={sid}, data={server_final}" def get_scram_verifier( user: str, tenant: edbtenant.Tenant, ) -> tuple[scram.SCRAMVerifier, bool]: roles = tenant.get_roles() rolerec = roles.get(user) if rolerec is not None: verifier_string = rolerec["password"] if verifier_string is not None: verifier = scram.parse_verifier(verifier_string) is_mock = False return verifier, is_mock # To avoid revealing the validity of the submitted user name, # generate a mock verifier using a salt derived 
from the # received user name and the cluster mock auth nonce. # The same approach is taken by Postgres. nonce = tenant.get_instance_data("mock_auth_nonce") salt = hashlib.sha256(nonce.encode() + user.encode()).digest() verifier = scram.SCRAMVerifier( mechanism="SCRAM-SHA-256", iterations=scram.DEFAULT_ITERATIONS, salt=salt[: scram.DEFAULT_SALT_LENGTH], stored_key=b"", server_key=b"", ) is_mock = True return verifier, is_mock ================================================ FILE: edb/server/protocol/auth_ext/__init__.py ================================================ from . import http __all__ = ('http', ) ================================================ FILE: edb/server/protocol/auth_ext/_static/interactions.js ================================================ document.addEventListener("DOMContentLoaded", () => { /** @type {HTMLElement | null} */ const sliderContainer = /** @type {HTMLElement | null} */ ( document.getElementById("slider-container") ); if (!sliderContainer) { return; } /** @type {HTMLElement | null} */ const tabsContainer = /** @type {HTMLElement | null} */ ( document.getElementById("email-provider-tabs") ); if (tabsContainer) { /** @type {HTMLElement[]} */ const tabButtons = /** @type {HTMLElement[]} */ ( Array.from(tabsContainer.children) ); for (let i = 0; i < tabButtons.length; i++) { const tab = tabButtons[i]; tab.addEventListener("click", () => { activateTab(i); }); /** @param {KeyboardEvent} e */ tab.addEventListener("keydown", (e) => { switch (e.key) { case "ArrowLeft": e.preventDefault(); focusTab((i - 1 + tabButtons.length) % tabButtons.length); break; case "ArrowRight": e.preventDefault(); focusTab((i + 1) % tabButtons.length); break; case "Home": e.preventDefault(); focusTab(0); break; case "End": e.preventDefault(); focusTab(tabButtons.length - 1); break; case "Enter": case " ": e.preventDefault(); activateTab(i); break; } }); } /** @param {number} index */ function focusTab(index) { for (let j = 0; j < tabButtons.length; j++) { const 
t = tabButtons[j]; const isActive = j === index; t.setAttribute("tabindex", isActive ? "0" : "-1"); } tabButtons[index].focus(); } /** @param {number} index */ function activateTab(index) { const tabChildren = /** @type {HTMLCollection} */ ( /** @type {HTMLElement} */ (tabsContainer).children ); setActiveClass(tabChildren, index); syncAriaState(tabChildren, /** @type {HTMLElement} */ (sliderContainer), index); moveSliderToIndex(/** @type {HTMLElement} */ (sliderContainer), index); } } else { /** @type {HTMLFormElement | null} */ const form = /** @type {HTMLFormElement | null} */ ( document.getElementById("email-factor") ); if (!form) { return; } let mainFormAction = form.action; /** @type {HTMLInputElement[]} */ const hiddenInputs = /** @type {HTMLInputElement[]} */ ( Array.from(form.querySelectorAll("input[type=hidden]")) ).filter((input) => !!input.dataset.secondaryValue); /** @type {string[]} */ const hiddenInputValues = hiddenInputs.map((input) => input.value); if (!sliderContainer.children[0].classList.contains("active")) { form.action = /** @type {any} */ (form.dataset).secondaryAction; setInputValues(hiddenInputs); } /** @type {HTMLElement | null} */ const showBtn = /** @type {HTMLElement | null} */ ( document.getElementById("show-password-form") ); showBtn?.setAttribute("aria-controls", "panel-password"); showBtn?.setAttribute("aria-expanded", "false"); document .getElementById("password-email") ?.setAttribute("aria-describedby", "show-password-form"); document .getElementById("hide-password-form") ?.setAttribute("aria-controls", "panel-password"); showBtn?.addEventListener("click", () => { moveSliderToIndex(/** @type {HTMLElement} */ (sliderContainer), 1); form.action = /** @type {any} */ (form.dataset).secondaryAction; setInputValues(hiddenInputs); document.getElementById("password")?.focus({ preventScroll: true }); showBtn.setAttribute("aria-expanded", "true"); }); document.getElementById("hide-password-form")?.addEventListener("click", () => { 
moveSliderToIndex(/** @type {HTMLElement} */ (sliderContainer), 0); form.action = mainFormAction; setInputValues(hiddenInputs, hiddenInputValues); showBtn?.setAttribute("aria-expanded", "false"); }); } }); document.addEventListener("DOMContentLoaded", () => { /** @type {HTMLAnchorElement | null} */ const forgotLink = /** @type {HTMLAnchorElement | null} */ ( document.getElementById("forgot-password-link") ); // Find the email input near the forgot link; fall back to a known id. /** @type {HTMLInputElement | null} */ const emailInput = /** @type {HTMLInputElement | null} */ ( forgotLink?.closest("form")?.querySelector('input[name="email"]') ) || /** @type {HTMLInputElement | null} */ ( document.getElementById("password-email") ); if (forgotLink && emailInput) { const href = forgotLink.href; emailInput.addEventListener("input", (e) => { const target = /** @type {HTMLInputElement} */ (e.target); forgotLink.href = `${href}&email=${encodeURIComponent(target.value)}`; }); forgotLink.href = `${href}&email=${encodeURIComponent(emailInput.value)}`; } /** @type {NodeListOf<HTMLInputElement>} */ const emailInputs = /** @type {NodeListOf<HTMLInputElement>} */ ( document.querySelectorAll("input[name=email]") ); for (const input of emailInputs) { input.addEventListener("input", (e) => { const target = /** @type {HTMLInputElement} */ (e.target); const val = target.value; for (const _input of emailInputs) { if (_input !== input) { _input.value = val; } } }); } }); /** * @param {HTMLInputElement[]} inputs * @param {string[]} [values] */ function setInputValues(inputs, values) { for (let i = 0; i < inputs.length; i++) { const input = inputs[i]; const secondary = /** @type {any} */ (input.dataset).secondaryValue; input.value = values ?
values[i] : secondary || ""; } } let firstInteraction = true; /** * @param {HTMLElement} sliderContainer * @param {number} index */ function moveSliderToIndex(sliderContainer, index) { if (firstInteraction) { firstInteraction = false; // Fix the height of the main form card wrapper so the layout doesn't shift // when tabs are clicked const containerWrapper = document.getElementById("container-wrapper"); if (containerWrapper) { containerWrapper.style.height = containerWrapper.getElementsByClassName("container")[0].clientHeight + "px"; } // Set the height for the first time as transition from 'auto' doesn't work sliderContainer.style.height = `${ sliderContainer.getElementsByClassName("active")[0].scrollHeight }px`; } setActiveClass(sliderContainer.children, index); sliderContainer.style.transform = `translateX(${-100 * index}%)`; sliderContainer.style.height = `${sliderContainer.children[index].scrollHeight}px`; } /** * @param {HTMLCollection} tabButtons * @param {HTMLElement} sliderContainer * @param {number} index */ function syncAriaState(tabButtons, sliderContainer, index) { for (let i = 0; i < tabButtons.length; i++) { const tab = /** @type {HTMLElement} */ (tabButtons[i]); const isActive = i === index; tab.setAttribute("aria-selected", isActive ? "true" : "false"); tab.setAttribute("tabindex", isActive ? 
"0" : "-1"); } const panels = /** @type {HTMLCollectionOf} */ (sliderContainer.children); for (let i = 0; i < panels.length; i++) { const isActive = i === index; if (isActive) { panels[i].removeAttribute("hidden"); panels[i].setAttribute("aria-hidden", "false"); } else { panels[i].setAttribute("hidden", ""); panels[i].setAttribute("aria-hidden", "true"); } } } /** * @param {HTMLCollection} items * @param {number} index */ function setActiveClass(items, index) { for (let i = 0; i < items.length; i++) { if (i === index) { items[i].classList.add("active"); } else { items[i].classList.remove("active"); } } } ================================================ FILE: edb/server/protocol/auth_ext/_static/styles.css ================================================ @font-face { font-family: "Roboto Flex"; font-style: normal; font-display: swap; font-weight: 100 1000; src: url(roboto-flex-latin-wght-normal.woff2) format("woff2-variations"); unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+0304, U+0308, U+0329, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; } body { background: #f3f4f6; margin: 0; padding: 0; min-height: 100vh; height: max-content; display: grid; grid-template-rows: minmax(120px, 1fr) auto minmax(120px, 1fr); justify-content: center; justify-items: center; font-family: "Roboto Flex", sans-serif; -webkit-font-smoothing: antialiased; -moz-osx-font-smoothing: grayscale; } .brand-logo { margin-bottom: 16px; margin-top: 32px; align-self: end; } .brand-logo img { max-width: 300px; max-height: 100px; } .container-wrapper { grid-row: 2; } .container { background: #fff; padding: 24px; padding-bottom: 16px; width: 326px; border-radius: 16px; box-shadow: 0px 2px 2px rgba(3, 7, 18, 0.02), 0px 7px 7px rgba(3, 7, 18, 0.03), 0px 16px 16px rgba(3, 7, 18, 0.05); display: flex; flex-direction: column; overflow: hidden; } .container form { display: contents; } .container h1 { margin: 0; color: #495057; 
font-size: 22px; font-style: normal; font-weight: 550; margin-bottom: 20px; } .container h1 span { opacity: 0.7; } .container input { border-radius: 8px; border: 1px solid #dee2e6; background: #f8f9fa; line-height: 40px; padding: 0 14px; color: #495057; font-family: inherit; font-size: 16px; font-weight: 400; outline: none; margin-bottom: 16px; } .container input:focus-visible { outline: 3px solid var(--accent-focus-color); } .container label { color: #495057; font-size: 16px; font-weight: 450; line-height: 18px; margin-bottom: 8px; } .container button { display: grid; align-items: center; grid-template-columns: 1fr auto 1fr; padding: 0 12px; height: 46px; border-radius: 8px; background: var(--accent-bg-color); border: none; color: var(--accent-bg-text-color); font-family: inherit; font-size: 17px; font-weight: 550; cursor: pointer; margin: 8px 0; } .container button span { grid-column: 2; margin: 0 12px; } .container button svg { margin-left: 8px; justify-self: end; } .container button:hover { background: var(--accent-bg-hover-color); } .container button:focus-visible { outline: 3px solid var(--accent-focus-color); outline-offset: 2px; } .container button:disabled { opacity: 0.6; pointer-events: none; } .container button.secondary { background: none; border: 1px solid #ced4da; color: #6c757d; font-weight: 500; } .container button.secondary svg { color: #adb5bd; } .container button.secondary:hover { background: #f5f6f8; } .container button.icon-only { display: flex; width: 46px; padding: 0; justify-content: center; margin-right: 12px; flex-shrink: 0; } .container button.icon-only svg { margin-left: 0; transform: scaleX(-1); } .button-group { display: flex; } .button-group button:not(.icon-only) { flex-grow: 1; } .slider-container { width: calc(100% + 48px); display: flex; align-items: start; margin: 0 -24px; transition: transform 0.3s, height 0.3s; } .slider-section { width: calc(100% - 48px); margin: 0 24px; flex-shrink: 0; display: flex; flex-direction: column; 
height: 0; visibility: hidden; opacity: 0; transition: opacity 0.15s 0s linear, visibility 0s 0.3s linear; } .slider-section > *, .slider-section > form > * { flex-shrink: 0; } .slider-section.active { height: auto; visibility: visible; opacity: 1; transition-delay: 0s; } .tabs { display: flex; justify-content: center; gap: 12px; margin-bottom: 20px; } .tab { position: relative; display: flex; height: 38px; align-items: center; padding: 0 12px; color: #6c757d; font-size: 15px; font-weight: 550; cursor: pointer; } .tab svg { position: absolute; bottom: -1px; left: 0; width: 100%; fill: var(--accent-text-color); opacity: 0; transition: opacity 0.3s; } .tab.active { color: #495057; } .tab.active svg { opacity: 1; } a { outline: 0; text-decoration: none; } a:focus-visible { text-decoration: underline solid 2px var(--accent-focus-color); text-underline-offset: 4px; } .field-header { display: flex; justify-content: space-between; } .field-note { color: #97a1ab; font-size: 14px; font-weight: 400; } a.field-note:hover { color: var(--accent-text-color); } .oauth-buttons { display: flex; flex-direction: column; margin-bottom: 8px; gap: 16px; } .oauth-buttons a { display: flex; align-items: center; justify-content: start; height: 46px; border-radius: 8px; padding: 0 12px; border: 1px solid #dee2e6; text-decoration: none; color: #495057; font-size: 16px; font-weight: 450; } .oauth-buttons a:hover { background: #f5f6f8; } .oauth-buttons a:focus-visible { outline: 3px solid var(--accent-focus-color); } .oauth-buttons a span { margin-left: 12px; } .oauth-buttons a img { width: 32px; height: 32px; object-fit: contain; } .oauth-buttons.collapsed { flex-direction: row; flex-wrap: wrap; } .oauth-buttons.collapsed a { padding: 0; width: 46px; justify-content: center; flex-shrink: 0; } .oauth-buttons.collapsed a span { display: none; } .divider { display: flex; align-items: center; color: #6c757d; font-size: 16px; font-weight: 450; line-height: 19px; margin-top: 12px; margin-bottom: 
16px; } .divider span { margin: 0 16px; } .divider:before, .divider:after { content: ""; height: 0; border-bottom: 1px solid #dee2e6; flex-grow: 1; } .bottom-note { color: #6c757d; font-size: 16px; font-weight: 400; line-height: 19px; margin-top: 4px; } .bottom-note a { color: var(--accent-text-color); } .error-message, .success-message { display: flex; padding: 10px 12px; align-items: center; gap: 12px; border-radius: 8px; border: 1px solid #f9827b; background: #fee6e5; color: #eb4b42; font-size: 14px; font-weight: 400; line-height: 19px; margin-bottom: 16px; } .error-message svg, .success-message svg { flex-shrink: 0; } .error-message a, .success-message a { color: var(--accent-text-color); } .error-message b, .success-message b { font-weight: 600; } .success-message { color: #1f8aed; border-color: #1f8aed; background: #e4f1fc; } .no-webauthn-error { border: 1px solid #f9827b; background: #fee6e5; color: #eb4b42; margin: 8px 0; border-radius: 8px; padding: 10px 12px; } @media (prefers-color-scheme: dark) { body { background: #191c1f; color: #dee2e6; } .container { background: #2a2f34; } .container h1 { color: #dee2e6; } .container button.secondary { border-color: #495057; color: #ced4da; } .container button.secondary svg { color: #6c757d; } .container button.secondary:hover { background: #363c42; } .container input { border-color: #495057; background: #31373d; color: #dee2e6; } .container input:focus-visible { outline-color: var(--accent-focus-dark-color); } .container label { color: #dee2e6; } a:focus-visible { text-decoration-color: var(--accent-focus-dark-color); } .field-note { color: #adb5bd; } a.field-note:hover { color: var(--accent-text-dark-color); } .oauth-buttons a { border-color: #495057; color: #dee2e6; } .oauth-buttons a:hover { background: #363c42; } .oauth-buttons a:focus-visible { outline-color: var(--accent-focus-dark-color); } .divider { color: #6c757d; } .divider:before, .divider:after { border-bottom-color: #495057; } .tab { color: #adb5bd; } 
.tab.active { color: #dee2e6; } .bottom-note { color: #ced4da; } .bottom-note a { color: var(--accent-text-dark-color); } .error-message a, .success-message a { color: var(--accent-text-dark-color); } .error-message { background: #423336; border-color: #a1433d; } .success-message { background: #293a4a; } .no-webauthn-error { background: #423336; border-color: #a1433d; } } ================================================ FILE: edb/server/protocol/auth_ext/_static/utils.js ================================================ /** * Register a submit handler for the WebAuthn email form once the DOM is ready * @param {(form: HTMLFormElement) => void} handler * @returns void */ export function addWebAuthnSubmitHandler(handler) { document.addEventListener("DOMContentLoaded", () => { if (!window.PublicKeyCredential) { console.error("WebAuthn is not supported in this browser."); for (const button of [ document.getElementById("webauthn-signin"), document.getElementById("webauthn-signup"), ]) { if (button) { const newEl = document.createElement("div"); newEl.classList.add("no-webauthn-error"); newEl.appendChild( document.createTextNode( `Your browser does not support the WebAuthn API. 
` + `Use another login method, or upgrade your browser.` ) ); button.parentNode.replaceChild(newEl, button); } } return; } const emailFactorForm = document.getElementById("email-factor"); if (emailFactorForm === null) { return; } emailFactorForm.addEventListener("submit", (event) => { if (new URL(emailFactorForm.action).pathname == location.pathname) { event.preventDefault(); handler(emailFactorForm); } }); }); } /** * Decode a base64url encoded string * @param {string} base64UrlString * @returns Uint8Array */ export function decodeBase64Url(base64UrlString) { return Uint8Array.from( atob(base64UrlString.replace(/-/g, "+").replace(/_/g, "/")), (c) => c.charCodeAt(0) ); } /** * Encode a Uint8Array to a base64url encoded string * @param {Uint8Array} bytes * @returns string */ export function encodeBase64Url(bytes) { return btoa(String.fromCharCode(...bytes)) .replace(/\+/g, "-") .replace(/\//g, "_") .replace(/=/g, ""); } /** * Parse an HTTP Response object. Allows passing in custom handlers for * different status codes and error.type values * * @param {Response} response * @param {Function[]=} handlers */ export async function parseResponseAsJSON(response, handlers = []) { const bodyText = await response.text(); if (!response.ok) { let error; try { error = JSON.parse(bodyText)?.error; } catch (e) { throw new Error( `Failed to parse body as JSON. Status: ${response.status} ${response.statusText}. Body: ${bodyText}` ); } for (const handler of handlers) { handler(response, error); } throw new Error( `Response was not OK. Status: ${response.status} ${response.statusText}. 
Body: ${bodyText}` ); } return JSON.parse(bodyText); } ================================================ FILE: edb/server/protocol/auth_ext/_static/webauthn-authenticate.js ================================================ import { addWebAuthnSubmitHandler, decodeBase64Url, encodeBase64Url, parseResponseAsJSON, } from "./utils.js"; addWebAuthnSubmitHandler(onAuthenticateSubmit); let authenticating = false; /** * Handle the form submission for WebAuthn authentication * @param {HTMLFormElement} form * @returns void */ async function onAuthenticateSubmit(form) { if (authenticating) { return; } authenticating = true; const signinButton = document.getElementById("webauthn-signin"); signinButton.disabled = true; const formData = new FormData(form); const email = formData.get("email"); const provider = "builtin::local_webauthn"; const challenge = formData.get("challenge"); const redirectOnFailure = formData.get("redirect_on_failure"); const redirectTo = formData.get("redirect_to"); try { const missingFields = Object.entries({ email, challenge, redirectTo, }).filter(([k, v]) => !v); if (missingFields.length > 0) { throw new Error( "Missing required parameters: " + missingFields.map(([k]) => k).join(", ") ); } const response = await authenticate({ email, provider, challenge, }); const redirectUrl = new URL(redirectTo); if ("code" in response) { redirectUrl.searchParams.append("code", response.code); } else if ("verification_email_sent_at" in response) { redirectUrl.searchParams.append( "verification_email_sent_at", response.verification_email_sent_at ); } window.location.href = redirectUrl.href; } catch (error) { console.error("Failed to authenticate WebAuthn credentials:", error); const url = new URL(redirectOnFailure ?? 
redirectTo); url.searchParams.append("error", error.message); window.location.href = url.href; } finally { authenticating = false; signinButton.disabled = false; } } const WEBAUTHN_OPTIONS_URL = new URL( "../webauthn/authenticate/options", window.location ); const WEBAUTHN_AUTHENTICATE_URL = new URL( "../webauthn/authenticate", window.location ); /** * Authenticate an existing WebAuthn credential for the given email address * @param {Object} props - The properties for authentication * @param {string} props.email - Email address to authenticate * @param {string} props.provider - WebAuthn provider * @param {string} props.challenge - PKCE challenge * @returns {Promise} - The server response */ export async function authenticate({ email, provider, challenge }) { // Check if WebAuthn is supported if (!window.PublicKeyCredential) { console.error("WebAuthn is not supported in this browser."); return; } // Fetch WebAuthn options from the server const options = await getAuthenticateOptions(email); // Get the existing credentials assertion const assertion = await navigator.credentials.get({ publicKey: { ...options, challenge: decodeBase64Url(options.challenge), allowCredentials: options.allowCredentials.map((credential) => ({ ...credential, id: decodeBase64Url(credential.id), })), }, }); // Verify the assertion on the server return await authenticateAssertion({ email, assertion, provider, challenge, }); } /** * Fetch WebAuthn options from the server * @param {string} email - Email address to authenticate * @returns {Promise} */ async function getAuthenticateOptions(email) { const url = new URL(WEBAUTHN_OPTIONS_URL); url.searchParams.set("email", email); const optionsResponse = await fetch(url, { method: "GET", }); return parseResponseAsJSON(optionsResponse, [ (response, error) => { if (response.status === 400 && error?.type === "InvalidData") { throw new Error(error?.message ?? 
"Email is invalid"); } if (!response.ok) { console.error( "Failed to fetch WebAuthn options:", optionsResponse.statusText ); console.error(error); throw new Error("Failed to fetch WebAuthn options"); } }, ]); } /** * Authenticate the credentials on the server * @param {Object} props * @param {string} props.email * @param {Object} props.assertion * @param {string} props.provider * @param {string} props.challenge * @returns {Promise} */ async function authenticateAssertion(props) { // Assertion includes raw bytes, so need to be encoded as base64url // for transmission const encodedAssertion = { type: props.assertion.type, id: props.assertion.id, authenticatorAttachment: props.assertion.authenticatorAttachment, clientExtensionResults: props.assertion.getClientExtensionResults(), rawId: encodeBase64Url(new Uint8Array(props.assertion.rawId)), response: { authenticatorData: encodeBase64Url( new Uint8Array(props.assertion.response.authenticatorData) ), clientDataJSON: encodeBase64Url( new Uint8Array(props.assertion.response.clientDataJSON) ), signature: encodeBase64Url( new Uint8Array(props.assertion.response.signature) ), userHandle: props.assertion.response.userHandle ? encodeBase64Url(new Uint8Array(props.assertion.response.userHandle)) : null, }, }; const authenticateResponse = await fetch(WEBAUTHN_AUTHENTICATE_URL, { method: "POST", headers: { "Content-Type": "application/json", }, body: JSON.stringify({ email: props.email, assertion: encodedAssertion, provider: props.provider, challenge: props.challenge, }), }); return await parseResponseAsJSON(authenticateResponse, [ (response, error) => { if (response.status === 401 && error?.type === "VerificationRequired") { console.error( "User's email is not verified", response.statusText, JSON.stringify(error) ); throw new Error( "Please verify your email before attempting to sign in." 
); } }, (response, error) => { console.error( "Failed to authenticate WebAuthn credentials:", response.statusText, JSON.stringify(error) ); throw new Error("Failed to authenticate WebAuthn credentials"); }, ]); } ================================================ FILE: edb/server/protocol/auth_ext/_static/webauthn-register.js ================================================ import { addWebAuthnSubmitHandler, decodeBase64Url, encodeBase64Url, } from "./utils.js"; addWebAuthnSubmitHandler(onRegisterSubmit); let registering = false; /** * Handle the form submission for WebAuthn registration * @param {HTMLFormElement} form * @returns void */ export async function onRegisterSubmit(form) { if (registering) { return; } registering = true; const registerButton = document.getElementById("webauthn-signup"); registerButton.disabled = true; const formData = new FormData(form); const email = formData.get("email"); const provider = "builtin::local_webauthn"; const challenge = formData.get("challenge"); const redirectOnFailure = formData.get("redirect_on_failure"); const redirectTo = formData.get("redirect_to"); const verifyUrl = formData.get("verify_url"); try { const missingFields = Object.entries({ email, provider, challenge, redirectTo, verifyUrl, }).filter(([k, v]) => !v); if (missingFields.length > 0) { throw new Error( "Missing required parameters: " + missingFields.map(([k]) => k).join(", ") ); } const response = await register({ email, provider, challenge, verifyUrl, }); const redirectUrl = new URL(redirectTo); redirectUrl.searchParams.append("isSignUp", "true"); if ("code" in response) { redirectUrl.searchParams.append("code", response.code); } else if ("verification_email_sent_at" in response) { redirectUrl.searchParams.append( "verification_email_sent_at", response.verification_email_sent_at ); } if ("email" in response) { redirectUrl.searchParams.append("email", response.email); } window.location.href = redirectUrl.href; } catch (error) { console.error("Failed to 
register WebAuthn credentials:", error); const url = new URL(redirectOnFailure ?? redirectTo); url.searchParams.append("error", error.message); window.location.href = url.href; } finally { registering = false; registerButton.disabled = false; } } const WEBAUTHN_OPTIONS_URL = new URL( "../webauthn/register/options", window.location ); const WEBAUTHN_REGISTER_URL = new URL("../webauthn/register", window.location); /** * Register a new WebAuthn credential for the given email address * @param {Object} props - The properties for registration * @param {string} props.email - Email address to register * @param {string} props.provider - WebAuthn provider * @param {string} props.challenge - PKCE challenge * @param {string} props.verifyUrl - URL to verify email after registration * @returns {Promise} - The server response */ export async function register({ email, provider, challenge, verifyUrl }) { // Check if WebAuthn is supported if (!window.PublicKeyCredential) { console.error("WebAuthn is not supported in this browser."); return; } // Fetch WebAuthn options from the server const options = await getCreateOptions(email); // Register the new credential const credentials = await navigator.credentials.create({ publicKey: { ...options, challenge: decodeBase64Url(options.challenge), user: { ...options.user, id: decodeBase64Url(options.user.id), }, }, }); // Register the credentials on the server return await registerCredentials({ email, credentials, provider, challenge, verifyUrl, }); } /** * Fetch WebAuthn options from the server * @param {string} email - Email address to register * @returns {Promise} */ async function getCreateOptions(email) { const url = new URL(WEBAUTHN_OPTIONS_URL); url.searchParams.set("email", email); const optionsResponse = await fetch(url, { method: "GET", }); if (!optionsResponse.ok) { console.error( "Failed to fetch WebAuthn options:", optionsResponse.statusText ); console.error(await optionsResponse.text()); throw new Error("Failed to fetch WebAuthn 
options"); } try { return await optionsResponse.json(); } catch (e) { console.error("Failed to parse WebAuthn options:", e); throw new Error("Failed to parse WebAuthn options"); } } /** * Register the credentials on the server * @param {Object} props * @param {string} props.email * @param {Object} props.credentials * @param {string} props.provider * @param {string} props.challenge * @param {string} props.verifyUrl * @returns {Promise} */ async function registerCredentials(props) { // Credentials include raw bytes, so need to be encoded as base64url // for transmission const encodedCredentials = { type: props.credentials.type, authenticatorAttachment: props.credentials.authenticatorAttachment, clientExtensionResults: props.credentials.getClientExtensionResults(), id: props.credentials.id, rawId: encodeBase64Url(new Uint8Array(props.credentials.rawId)), response: { attestationObject: encodeBase64Url( new Uint8Array(props.credentials.response.attestationObject) ), clientDataJSON: encodeBase64Url( new Uint8Array(props.credentials.response.clientDataJSON) ), }, }; const registerResponse = await fetch(WEBAUTHN_REGISTER_URL, { method: "POST", headers: { "Content-Type": "application/json", }, body: JSON.stringify({ email: props.email, credentials: encodedCredentials, provider: props.provider, challenge: props.challenge, verify_url: props.verifyUrl, }), }); if (!registerResponse.ok) { console.error( "Failed to register WebAuthn credentials:", registerResponse.statusText ); console.error(await registerResponse.text()); throw new Error("Failed to register WebAuthn credentials"); } try { return await registerResponse.json(); } catch (e) { console.error("Failed to parse WebAuthn registration result:", e); throw new Error("Failed to parse WebAuthn registration result"); } } ================================================ FILE: edb/server/protocol/auth_ext/apple.py ================================================ # # This source file is part of the EdgeDB open source project. 
# # Copyright 2023-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from typing import Any import uuid import urllib.parse from . import base class AppleProvider(base.OpenIDConnectProvider): def __init__(self, *args: Any, **kwargs: Any): super().__init__( "apple", "https://appleid.apple.com", *args, **kwargs, ) async def get_code_url( self, state: str, redirect_uri: str, additional_scope: str ) -> str: oidc_config = await self._get_oidc_config() params = { "client_id": self.client_id, # Non-standard "name" scope "scope": f"openid email name {additional_scope}", "state": state, "redirect_uri": redirect_uri, "nonce": str(uuid.uuid4()), "response_type": "code id_token", "response_mode": "form_post", } encoded = urllib.parse.urlencode(params) return f"{oidc_config.authorization_endpoint}?{encoded}" ================================================ FILE: edb/server/protocol/auth_ext/azure.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2023-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from typing import Any from . import base class AzureProvider(base.OpenIDConnectProvider): def __init__(self, *args: Any, **kwargs: Any): super().__init__( "azure", "https://login.microsoftonline.com/common/v2.0", *args, **kwargs, ) ================================================ FILE: edb/server/protocol/auth_ext/base.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import uuid import urllib.parse import enum import logging from typing import Any, Callable from datetime import datetime from . 
import data, errors from edb.server.http import HttpClient from edb.server import auth as jwt_auth from edb.server.protocol.auth_ext import util as auth_util from edb.server import metrics logger = logging.getLogger("edb.server.ext.auth") class BaseProvider: def __init__( self, name: str, issuer_url: str, client_id: str, client_secret: str, *, additional_scope: str | None, http_factory: Callable[..., HttpClient], ): self.name = name self.issuer_url = issuer_url self.client_id = client_id self.client_secret = client_secret self.http_factory = http_factory self.additional_scope = additional_scope async def get_code_url( self, state: str, redirect_uri: str, additional_scope: str ) -> str: raise NotImplementedError async def exchange_code( self, code: str, redirect_uri: str ) -> data.OAuthAccessTokenResponse: raise NotImplementedError async def fetch_user_info( self, token_response: data.OAuthAccessTokenResponse ) -> data.UserInfo: raise NotImplementedError def _maybe_isoformat_to_timestamp(self, value: str | None) -> float | None: return datetime.fromisoformat(value).timestamp() if value else None class ContentType(enum.StrEnum): JSON = "application/json" FORM_ENCODED = "application/x-www-form-urlencoded" class OpenIDConnectProvider(BaseProvider): def __init__( self, name: str, issuer_url: str, *args: Any, **kwargs: Any, ): super().__init__(name, issuer_url, *args, **kwargs) async def get_code_url( self, state: str, redirect_uri: str, additional_scope: str ) -> str: oidc_config = await self._get_oidc_config() params = { "client_id": self.client_id, "scope": f"openid profile email {additional_scope}", "state": state, "redirect_uri": redirect_uri, "nonce": str(uuid.uuid4()), "response_type": "code", } encoded = urllib.parse.urlencode(params) return f"{oidc_config.authorization_endpoint}?{encoded}" async def exchange_code( self, code: str, redirect_uri: str ) -> data.OpenIDConnectAccessTokenResponse: oidc_config = await self._get_oidc_config() token_endpoint = 
urllib.parse.urlparse(oidc_config.token_endpoint) async with self.http_factory( base_url=f"{token_endpoint.scheme}://{token_endpoint.netloc}" ) as client: request_body = { "grant_type": "authorization_code", "code": code, "client_id": self.client_id, "client_secret": self.client_secret, "redirect_uri": redirect_uri, } headers = {"Accept": ContentType.JSON.value} resp = await client.post( token_endpoint.path, data=request_body, headers=headers, ) if resp.status_code >= 400: raise errors.OAuthProviderFailure( f"Failed to exchange code: {resp.text}" ) content_type = resp.headers.get('Content-Type') or '' if content_type.startswith(str(ContentType.JSON)): response_body = resp.json() else: response_body = { k: v[0] if len(v) == 1 else v for k, v in urllib.parse.parse_qs(resp.text).items() } return data.OpenIDConnectAccessTokenResponse(**response_body) async def fetch_user_info( self, token_response: data.OAuthAccessTokenResponse ) -> data.UserInfo: if not isinstance( token_response, data.OpenIDConnectAccessTokenResponse ): raise TypeError( "token_response must be of type " "OpenIDConnectAccessTokenResponse" ) id_token = token_response.id_token # Retrieve JWK Set, potentially from the cache oidc_config = await self._get_oidc_config() try: async def fetcher(url: str) -> jwt_auth.JWKSet: jwks_uri = urllib.parse.urlparse(url) async with self.http_factory( base_url=f"{jwks_uri.scheme}://{jwks_uri.netloc}" ) as client: r = await client.get(jwks_uri.path, cache=True) jwk_set = jwt_auth.JWKSet() jwk_set.load_json(r.text) jwk_set.default_validation_context.allow( "aud", [self.client_id] ) jwk_set.default_validation_context.require_expiry() metrics.auth_provider_jwkset_fetch_success.inc( 1.0, self.name ) return jwk_set jwk_set = await auth_util.get_remote_jwtset( oidc_config.jwks_uri, fetcher ) except Exception as e: metrics.auth_provider_jwkset_fetch_errors.inc(1.0, self.name) logger.exception( f"Failed to fetch JWK Set from provider {oidc_config.jwks_uri}" ) raise 
errors.MisconfiguredProvider( f"Failed to fetch JWK Set from provider {oidc_config.jwks_uri}" ) from e # Load the token as a JWT object and verify it directly. This will # validate the audience and expiry. try: payload = jwk_set.validate(id_token) except Exception as e: metrics.auth_provider_token_validation_errors.inc(1.0, self.name) raise errors.MisconfiguredProvider( "Failed to parse ID token with provider keyset" ) from e metrics.auth_provider_token_validation_success.inc(1.0, self.name) return data.UserInfo( sub=str(payload["sub"]), name=payload.get("name"), email=payload.get("email"), picture=payload.get("picture"), source_id_token=id_token, ) async def _get_oidc_config(self) -> data.OpenIDConfig: client = self.http_factory(base_url=self.issuer_url) response = await client.get( '/.well-known/openid-configuration', cache=True ) config = response.json() return data.OpenIDConfig(**config) ================================================ FILE: edb/server/protocol/auth_ext/config.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2023-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from typing import Literal, Optional from dataclasses import dataclass import urllib.parse from edb.ir import statypes VerificationMethod = Literal['Link', 'Code'] class UIConfig: app_name: Optional[str] logo_url: Optional[str] dark_logo_url: Optional[str] brand_color: Optional[str] redirect_to: str redirect_to_on_signup: Optional[str] @dataclass class AppDetailsConfig: app_name: Optional[str] logo_url: Optional[str] dark_logo_url: Optional[str] brand_color: Optional[str] @dataclass class ProviderConfig: name: str @dataclass class OAuthProviderConfig(ProviderConfig): display_name: str client_id: str secret: str additional_scope: Optional[str] issuer_url: Optional[str] logo_url: Optional[str] @dataclass class DiscordOAuthProviderConfig(OAuthProviderConfig): prompt: str class WebAuthnProviderConfig(ProviderConfig): relying_party_origin: str require_verification: bool verification_method: VerificationMethod @dataclass class WebAuthnProvider: name: str relying_party_origin: str require_verification: bool verification_method: VerificationMethod def __init__( self, name: str, relying_party_origin: str, require_verification: bool, verification_method: VerificationMethod, ): self.name = name self.relying_party_origin = relying_party_origin self.require_verification = require_verification self.verification_method = verification_method parsed_url = urllib.parse.urlparse(self.relying_party_origin) if parsed_url.hostname is None: raise ValueError( "Invalid relying_party_origin, hostname cannot be None" ) self.relying_party_id = parsed_url.hostname @dataclass class EmailPasswordProviderConfig(ProviderConfig): name: Literal["builtin::local_emailpassword"] require_verification: bool verification_method: VerificationMethod @dataclass class MagicLinkProviderConfig(ProviderConfig): name: Literal["builtin::local_magic_link"] token_time_to_live: statypes.Duration verification_method: VerificationMethod auto_signup: bool @dataclass class WebhookConfig: events: list[str] url: str 
signing_secret_key: Optional[str] ================================================ FILE: edb/server/protocol/auth_ext/data.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import dataclasses import datetime import base64 from typing import Any, Optional @dataclasses.dataclass class UserInfo: """ OpenID Connect compatible user info. 
See: https://openid.net/specs/openid-connect-core-1_0.html """ sub: str name: Optional[str] = None given_name: Optional[str] = None family_name: Optional[str] = None middle_name: Optional[str] = None nickname: Optional[str] = None preferred_username: Optional[str] = None profile: Optional[str] = None picture: Optional[str] = None website: Optional[str] = None email: Optional[str] = None email_verified: Optional[bool] = None gender: Optional[str] = None birthdate: Optional[str] = None zoneinfo: Optional[str] = None locale: Optional[str] = None phone_number: Optional[str] = None phone_number_verified: Optional[bool] = None address: Optional[dict[str, str]] = None updated_at: Optional[float] = None source_id_token: Optional[str] = None def __str__(self) -> str: return self.sub def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"sub={self.sub!r} " f"name={self.name!r} " f"email={self.email!r} " f"preferred_username={self.preferred_username!r})" ) @dataclasses.dataclass class Identity: id: str subject: str issuer: str created_at: datetime.datetime modified_at: datetime.datetime def __str__(self) -> str: return self.id @dataclasses.dataclass class LocalIdentity(Identity): pass @dataclasses.dataclass class OpenIDConfig: """ OpenID Connect configuration. Only includes fields actually in use. See: - https://openid.net/specs/openid-connect-discovery-1_0.html - https://accounts.google.com/.well-known/openid-configuration """ issuer: str authorization_endpoint: str token_endpoint: str jwks_uri: str def __init__(self, **kwargs: Any): for field in dataclasses.fields(self): setattr(self, field.name, kwargs.get(field.name)) def __str__(self) -> str: return self.issuer @dataclasses.dataclass(repr=False) class OAuthAccessTokenResponse: """ Access Token Response. 
https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.4 """ access_token: str token_type: str expires_in: int refresh_token: str | None def __init__(self, **kwargs: Any): for field in dataclasses.fields(self): if field.name in kwargs: setattr(self, field.name, kwargs.pop(field.name)) else: setattr(self, field.name, None) self._extra_fields = kwargs @dataclasses.dataclass(repr=False) class OpenIDConnectAccessTokenResponse(OAuthAccessTokenResponse): """ OpenID Connect Access Token Response. https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse """ id_token: str def __init__(self, **kwargs: Any): super().__init__(**kwargs) @dataclasses.dataclass class EmailFactor: id: str created_at: datetime.datetime modified_at: datetime.datetime identity: LocalIdentity email: str verified_at: Optional[datetime.datetime] def __init__( self, *, id: str, created_at: datetime.datetime, modified_at: datetime.datetime, identity: LocalIdentity, email: str, verified_at: Optional[datetime.datetime], ): self.id = id self.created_at = created_at self.modified_at = modified_at self.identity = ( LocalIdentity(**identity) if isinstance(identity, dict) else identity ) self.email = email self.verified_at = verified_at @dataclasses.dataclass class WebAuthnFactor(EmailFactor): user_handle: bytes credential_id: bytes public_key: bytes def __init__( self, *, id: str, created_at: datetime.datetime, modified_at: datetime.datetime, identity: LocalIdentity, email: str, verified_at: Optional[datetime.datetime], user_handle: bytes, credential_id: bytes, public_key: bytes, ): self.id = id self.created_at = created_at self.modified_at = modified_at self.identity = ( LocalIdentity(**identity) if isinstance(identity, dict) else identity ) self.email = email self.verified_at = verified_at self.user_handle = base64.b64decode(user_handle) self.credential_id = base64.b64decode(credential_id) self.public_key = base64.b64decode(public_key) @dataclasses.dataclass class 
WebAuthnAuthenticationChallenge: id: str created_at: datetime.datetime modified_at: datetime.datetime challenge: bytes factors: list[WebAuthnFactor] def __init__( self, *, id: str, created_at: datetime.datetime, modified_at: datetime.datetime, challenge: bytes, factors: list[WebAuthnFactor], ): self.id = id self.created_at = created_at self.modified_at = modified_at self.challenge = base64.b64decode(challenge) self.factors = [ WebAuthnFactor(**factor) if isinstance(factor, dict) else factor for factor in factors ] ================================================ FILE: edb/server/protocol/auth_ext/discord.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2024-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from typing import Any import urllib.parse import functools from . 
import base, data, errors class DiscordProvider(base.BaseProvider): def __init__( self, prompt: str, *args: Any, **kwargs: Any ): super().__init__("discord", "https://discord.com", *args, **kwargs) self.auth_domain = self.issuer_url self.api_domain = f"{self.issuer_url}/api/v10" self.auth_client = functools.partial( self.http_factory, base_url=self.auth_domain ) self.api_client = functools.partial( self.http_factory, base_url=self.api_domain ) self.prompt = prompt async def get_code_url( self, state: str, redirect_uri: str, additional_scope: str ) -> str: params = { "client_id": self.client_id, "scope": f"email identify {additional_scope}", "state": state, "redirect_uri": redirect_uri, "response_type": "code", "prompt": self.prompt, } encoded = urllib.parse.urlencode(params) return f"{self.auth_domain}/oauth2/authorize?{encoded}" async def exchange_code( self, code: str, redirect_uri: str ) -> data.OAuthAccessTokenResponse: async with self.auth_client() as client: resp = await client.post( "/api/oauth2/token", data={ "grant_type": "authorization_code", "code": code, "client_id": self.client_id, "client_secret": self.client_secret, "redirect_uri": redirect_uri, }, headers={ "accept": "application/json", }, ) if resp.status_code >= 400: raise errors.OAuthProviderFailure( f"Failed to exchange code: {resp.text}" ) json = resp.json() return data.OAuthAccessTokenResponse(**json) async def fetch_user_info( self, token_response: data.OAuthAccessTokenResponse ) -> data.UserInfo: async with self.api_client() as client: resp = await client.get( "/users/@me", headers={ "Authorization": f"Bearer {token_response.access_token}", "Accept": "application/json", "Cache-Control": "no-store", }, ) payload = resp.json() return data.UserInfo( sub=str(payload["id"]), preferred_username=payload.get("username"), name=payload.get("global_name"), email=payload.get("email"), picture=payload.get("avatar"), ) ================================================ FILE: 
edb/server/protocol/auth_ext/email.py ================================================ import asyncio import urllib.parse import random import logging from email.message import EmailMessage from typing import Any, Coroutine from edb.server import tenant, smtp from edb import errors from . import util, ui logger = logging.getLogger("edb.server.ext.auth") async def send_password_reset_email( db: Any, tenant: tenant.Tenant, to_addr: str, reset_url: str, test_mode: bool, ) -> None: app_details_config = util.get_app_details_config(db) if app_details_config is None: email_args = {} else: email_args = dict( app_name=app_details_config.app_name, logo_url=app_details_config.logo_url, dark_logo_url=app_details_config.dark_logo_url, brand_color=app_details_config.brand_color, ) msg = ui.render_password_reset_email( to_addr=to_addr, reset_url=reset_url, **email_args, ) await _maybe_send_message(msg, tenant, db, test_mode) async def send_verification_email( db: Any, tenant: tenant.Tenant, to_addr: str, verify_url: str, verification_token: str, provider: str, test_mode: bool, ) -> None: app_details_config = util.get_app_details_config(db) verification_token_params = urllib.parse.urlencode( { "verification_token": verification_token, "provider": provider, "email": to_addr, } ) verify_url = f"{verify_url}?{verification_token_params}" if app_details_config is None: email_args = {} else: email_args = dict( app_name=app_details_config.app_name, logo_url=app_details_config.logo_url, dark_logo_url=app_details_config.dark_logo_url, brand_color=app_details_config.brand_color, ) msg = ui.render_verification_email( to_addr=to_addr, verify_url=verify_url, **email_args, ) await _maybe_send_message(msg, tenant, db, test_mode) async def send_magic_link_email( db: Any, tenant: tenant.Tenant, to_addr: str, link: str, test_mode: bool, ) -> None: app_details_config = util.get_app_details_config(db) if app_details_config is None: email_args = {} else: email_args = dict( 
app_name=app_details_config.app_name, logo_url=app_details_config.logo_url, dark_logo_url=app_details_config.dark_logo_url, brand_color=app_details_config.brand_color, ) msg = ui.render_magic_link_email( to_addr=to_addr, link=link, **email_args, ) await _maybe_send_message(msg, tenant, db, test_mode) async def send_one_time_code_email( db: Any, tenant: tenant.Tenant, to_addr: str, code: str, test_mode: bool, ) -> None: app_details_config = util.get_app_details_config(db) if app_details_config is None: email_args = {} else: email_args = dict( app_name=app_details_config.app_name, logo_url=app_details_config.logo_url, dark_logo_url=app_details_config.dark_logo_url, brand_color=app_details_config.brand_color, ) msg = ui.render_one_time_code_email( to_addr=to_addr, code=code, **email_args, ) await _maybe_send_message(msg, tenant, db, test_mode) async def send_password_reset_code_email( db: Any, tenant: tenant.Tenant, to_addr: str, code: str, test_mode: bool, ) -> None: """Send a password reset email with a one-time code.""" app_details_config = util.get_app_details_config(db) if app_details_config is None: email_args = {} else: email_args = dict( app_name=app_details_config.app_name, logo_url=app_details_config.logo_url, dark_logo_url=app_details_config.dark_logo_url, brand_color=app_details_config.brand_color, ) msg = ui.render_password_reset_code_email( to_addr=to_addr, code=code, **email_args, ) await _maybe_send_message(msg, tenant, db, test_mode) async def send_fake_email(tenant: tenant.Tenant) -> None: async def noop_coroutine() -> None: pass coro = noop_coroutine() await _protected_send(coro, tenant) async def _maybe_send_message( msg: EmailMessage, tenant: tenant.Tenant, db: Any, test_mode: bool, ) -> None: try: smtp_provider = smtp.SMTP(db) except errors.ConfigurationError as e: logger.debug( "ConfigurationError while instantiating SMTP provider, " f"sending fake email instead: {e}" ) smtp_provider = None if smtp_provider is None: coro = 
send_fake_email(tenant) else: coro = smtp_provider.send( msg, test_mode=test_mode, ) await _protected_send(coro, tenant) async def _protected_send( coro: Coroutine[Any, Any, None], tenant: tenant.Tenant ) -> None: task = tenant.create_task(coro, interruptable=True) # Prevent timing attack await asyncio.sleep(random.random() * 0.5) # Expose e.g. configuration errors if task.done(): await task ================================================ FILE: edb/server/protocol/auth_ext/email_password.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2023-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import argon2 import json import hashlib import base64 import dataclasses from typing import Any, Optional from edb.errors import ConstraintViolationError from . 
import errors, util, data, local ph = argon2.PasswordHasher() @dataclasses.dataclass class EmailPasswordProviderConfig: name: str require_verification: bool verification_method: str class Client(local.Client): def __init__(self, db: Any): super().__init__(db) self.config = self._get_provider_config("builtin::local_emailpassword") async def register(self, input: dict[str, Any]) -> data.EmailFactor: match (input.get("email"), input.get("password")): case (str(e), str(p)): email = e password = p case _: raise errors.InvalidData( "Missing 'email' or 'password' in data" ) try: r = await util.json_query( db=self.db, query="""\ with email := $email, password_hash := $password_hash, identity := (insert ext::auth::LocalIdentity { issuer := "local", subject := "", }), factor := (insert ext::auth::EmailPasswordFactor { password_hash := password_hash, email := email, identity := identity, }), select factor { id, email, verified_at, created_at, modified_at, identity: { * }, };""", variables={ "email": email, "password_hash": ph.hash(password), }, ) except ConstraintViolationError: raise errors.UserAlreadyRegistered() result_json = json.loads(r.decode()) assert len(result_json) == 1 return data.EmailFactor(**result_json[0]) async def authenticate( self, email: str, password: str ) -> data.LocalIdentity: r = await util.json_query( db=self.db, query="""\ with email := $email, select ext::auth::EmailPasswordFactor { password_hash, identity: { * } } filter .email = email;""", variables={ "email": email, }, ) password_credential_dicts = json.loads(r.decode()) if len(password_credential_dicts) != 1: raise errors.NoIdentityFound() password_credential_dict = password_credential_dicts[0] password_hash = password_credential_dict["password_hash"] try: ph.verify(password_hash, password) except argon2.exceptions.VerifyMismatchError: raise errors.NoIdentityFound() local_identity = data.LocalIdentity( **password_credential_dict["identity"] ) if ph.check_needs_rehash(password_hash): new_hash = 
ph.hash(password) await util.json_query( db=self.db, query="""\ with email := $email, new_hash := $new_hash, update ext::auth::EmailPasswordFactor filter .email = email set { password_hash := new_hash };""", variables={ "email": email, "new_hash": new_hash, }, ) return local_identity async def get_email_factor_and_secret( self, email: str, ) -> tuple[data.EmailFactor, str]: r = await util.json_query( db=self.db, query=""" with email := $email, select ext::auth::EmailPasswordFactor { ** } filter .email = email""", variables={ "email": email, }, ) result_json = json.loads(r.decode()) if len(result_json) != 1: raise errors.NoIdentityFound() password_cred = result_json[0] secret = base64.b64encode( hashlib.sha256(password_cred['password_hash'].encode()).digest() ).decode() email_factor = data.EmailFactor( **{k: v for k, v in password_cred.items() if k != "password_hash"} ) return (email_factor, secret) async def validate_reset_secret( self, identity_id: str, secret: str, ) -> Optional[data.LocalIdentity]: r = await util.json_query( db=self.db, query="""\ with identity_id := $identity_id, select ext::auth::EmailPasswordFactor { password_hash, identity: { * } } filter .identity.id = identity_id;""", variables={ "identity_id": identity_id, }, ) result_json = json.loads(r.decode()) if len(result_json) != 1: raise errors.NoIdentityFound() password_cred = result_json[0] local_identity = data.LocalIdentity(**password_cred["identity"]) current_secret = base64.b64encode( hashlib.sha256(password_cred['password_hash'].encode()).digest() ).decode() return local_identity if secret == current_secret else None async def update_password( self, identity_id: str, secret: str, password: str ) -> data.LocalIdentity: local_identity = await self.validate_reset_secret(identity_id, secret) if local_identity is None: raise errors.InvalidData("Invalid 'reset_token'") # TODO: check if race between validating secret and updating password # is a problem await util.json_query( db=self.db, 
query="""\ with identity_id := $identity_id, new_hash := $new_hash, update ext::auth::EmailPasswordFactor filter .identity.id = identity_id set { password_hash := new_hash, verified_at := .verified_at ?? datetime_current() };""", variables={ 'identity_id': identity_id, 'new_hash': ph.hash(password), }, ) return local_identity def _get_provider_config( self, provider_name: str ) -> EmailPasswordProviderConfig: provider_client_config = util.get_config( self.db, "ext::auth::AuthConfig::providers", frozenset ) for cfg in provider_client_config: if cfg.name == provider_name: return EmailPasswordProviderConfig( name=cfg.name, require_verification=cfg.require_verification, verification_method=cfg.verification_method, ) raise errors.MissingConfiguration( provider_name, "Provider is not configured" ) ================================================ FILE: edb/server/protocol/auth_ext/errors.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2022-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
# class AuthExtError(Exception): """Base class for all exceptions raised by the auth extension.""" pass class NotFound(AuthExtError): """Required resource could not be found.""" def __init__(self, description: str): self.description = description def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"description={self.description!r}" ")" ) def __str__(self) -> str: return self.description class MissingConfiguration(AuthExtError): """Required configuration is missing.""" def __init__(self, key: str, description: str): self.key = key self.description = description def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"key={self.key!r} " f"description={self.description!r}" ")" ) def __str__(self) -> str: return f"{self.description}: {self.key}" class InvalidData(AuthExtError): """Data received from the client is invalid.""" def __init__(self, description: str): self.description = description def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"description={self.description!r}" ")" ) def __str__(self) -> str: return self.description class MisconfiguredProvider(AuthExtError): """Data received from the auth provider is invalid.""" def __init__(self, description: str): self.description = description def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"description={self.description!r}" ")" ) def __str__(self) -> str: return self.description class NoIdentityFound(AuthExtError): """Could not find a matching identity.""" def __init__( self, description: str = ( "Could not find an Identity matching the provided credentials" ), ): self.description = description def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"description={self.description!r}" ")" ) def __str__(self) -> str: return self.description class UserAlreadyRegistered(AuthExtError): """Attempt to register an already registered handle.""" def __init__( self, description: str = ("This user has already been registered"), ): self.description = 
description def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"description={self.description!r}" ")" ) def __str__(self) -> str: return self.description class OAuthProviderFailure(AuthExtError): """OAuth Provider returned a non-success for some part of the flow""" def __init__( self, description: str, ): self.description = description def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"description={self.description!r}" ")" ) def __str__(self) -> str: return self.description class VerificationTokenExpired(AuthExtError): """Email verification token has expired""" def __init__( self, description: str = "Email verification token has expired", ): self.description = description def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"description={self.description!r}" ")" ) def __str__(self) -> str: return self.description class VerificationRequired(AuthExtError): """Email verification is required""" def __init__( self, description: str = "Email verification is required", ): self.description = description def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"description={self.description!r}" ")" ) def __str__(self) -> str: return self.description class PKCECreationFailed(AuthExtError): """Failed to create a valid PKCEChallenge object""" def __init__( self, description: str = "Failed to create a valid PKCEChallenge object" ): self.description = description def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"description={self.description!r}" ")" ) def __str__(self) -> str: return self.description class PKCEVerificationFailed(AuthExtError): """Verifier and challenge do not match""" def __init__( self, description: str = "Verifier and challenge do not match" ): self.description = description def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"description={self.description!r}" ")" ) def __str__(self) -> str: return self.description class WebAuthnAuthenticationFailed(AuthExtError): 
"""WebAuthn authentication failed""" def __init__(self, description: str = "WebAuthn authentication failed"): self.description = description def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"description={self.description!r}" ")" ) def __str__(self) -> str: return self.description class WebAuthnRegistrationFailed(AuthExtError): """WebAuthn registration failed""" def __init__(self, description: str = "WebAuthn registration failed"): self.description = description def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"description={self.description!r}" ")" ) def __str__(self) -> str: return self.description class OTCVerificationError(AuthExtError): """Base class for one-time code verification failures.""" def __init__(self, description: str): self.description = description def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"description={self.description!r}" ")" ) def __str__(self) -> str: return self.description class OTCRateLimited(OTCVerificationError): """Maximum verification attempts exceeded for the factor.""" def __init__( self, description: str = "Maximum verification attempts exceeded" ): super().__init__(description) class OTCInvalidCode(OTCVerificationError): """The provided one-time code is invalid or not found.""" def __init__(self, description: str = "Invalid code"): super().__init__(description) class OTCExpired(OTCVerificationError): """The one-time code has expired.""" def __init__(self, description: str = "Code has expired"): super().__init__(description) class OTCVerificationFailed(OTCVerificationError): """General verification failure that doesn't fall into other categories.""" def __init__(self, description: str = "Verification failed"): super().__init__(description) ================================================ FILE: edb/server/protocol/auth_ext/github.py ================================================ # # This source file is part of the EdgeDB open source project. 
# # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from typing import Any import urllib.parse import functools from . import base, data, errors class GitHubProvider(base.BaseProvider): def __init__(self, *args: Any, **kwargs: Any): super().__init__("github", "https://github.com", *args, **kwargs) self.auth_domain = self.issuer_url self.api_domain = "https://api.github.com" self.auth_client = functools.partial( self.http_factory, base_url=self.auth_domain ) self.api_client = functools.partial( self.http_factory, base_url=self.api_domain ) async def get_code_url( self, state: str, redirect_uri: str, additional_scope: str ) -> str: params = { "client_id": self.client_id, "scope": f"read:user user:email {additional_scope}", "state": state, "redirect_uri": redirect_uri, } encoded = urllib.parse.urlencode(params) return f"{self.auth_domain}/login/oauth/authorize?{encoded}" async def exchange_code( self, code: str, redirect_uri: str ) -> data.OAuthAccessTokenResponse: async with self.auth_client() as client: resp = await client.post( "/login/oauth/access_token", json={ "grant_type": "authorization_code", "code": code, "client_id": self.client_id, "client_secret": self.client_secret, "redirect_uri": redirect_uri, }, headers={ "accept": "application/json", }, ) if resp.status_code >= 400: raise errors.OAuthProviderFailure( f"Failed to exchange code: {resp.text}" ) json = resp.json() return 
data.OAuthAccessTokenResponse(**json) async def fetch_user_info( self, token_response: data.OAuthAccessTokenResponse ) -> data.UserInfo: async with self.api_client() as client: resp = await client.get( "/user", headers={ "Authorization": f"Bearer {token_response.access_token}", "Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28", "Cache-Control": "no-store", }, ) payload = resp.json() return data.UserInfo( sub=str(payload["id"]), preferred_username=payload.get("login"), name=payload.get("name"), email=payload.get("email"), picture=payload.get("avatar_url"), updated_at=self._maybe_isoformat_to_timestamp( payload.get("updated_at") ), ) ================================================ FILE: edb/server/protocol/auth_ext/google.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2023-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from typing import Any from . import base class GoogleProvider(base.OpenIDConnectProvider): def __init__(self, *args: Any, **kwargs: Any): super().__init__( "google", "https://accounts.google.com", *args, **kwargs ) ================================================ FILE: edb/server/protocol/auth_ext/http.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2022-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import datetime import http import http.cookies import json import logging import urllib.parse import base64 import hashlib import os import mimetypes import uuid import dataclasses from typing import ( Any, Optional, cast, TYPE_CHECKING, Callable, ) import aiosmtplib from edb import errors as edb_errors from edb.common import debug from edb.common import markup from edb.server import tenant as edbtenant, metrics from edb.server.config.types import CompositeConfigType from edb.ir import statypes from . import ( email_password, oauth, errors, util, pkce, ui, config, email as auth_emails, webauthn, magic_link, webhook, jwt, otc, local, ) from .data import EmailFactor if TYPE_CHECKING: from edb.server.protocol import protocol logger = logging.getLogger('edb.server.ext.auth') class Router: test_url: Optional[str] def __init__( self, *, db: edbtenant.dbview.Database, base_path: str, tenant: edbtenant.Tenant, ): self.db = db self.base_path = base_path self.tenant = tenant self.test_mode = tenant.server.in_test_mode() self.signing_key = jwt.SigningKey( lambda: util.get_config( self.db, "ext::auth::AuthConfig::auth_signing_key" ), self.base_path, ) def _get_url_munger( self, request: protocol.HttpRequest ) -> Callable[[str], str] | None: """ Returns a callable that can be used to modify the base URL when making requests to the OAuth provider. 
This is used to redirect requests to the test OAuth provider when running in test mode. """ if not self.test_mode: return None test_url = ( request.params[b'oauth-test-server'].decode() if (request.params and b'oauth-test-server' in request.params) else None ) if test_url: return lambda path: f"{test_url}{urllib.parse.quote(path)}" return None async def handle_request( self, request: protocol.HttpRequest, response: protocol.HttpResponse, args: list[str], ) -> None: if self.db.db_config is None: await self.db.introspection() self.test_url = ( request.params[b'oauth-test-server'].decode() if ( self.test_mode and request.params and b'oauth-test-server' in request.params ) else None ) logger.info( f"Handling incoming HTTP request: /ext/auth/{'/'.join(args)}" ) try: match args: # PKCE token exchange route case ("token",): await self.handle_token(request, response) # OAuth routes case ("authorize",): await self.handle_authorize(request, response) case ("callback",): await self.handle_callback(request, response) # Email/password routes case ("register",): await self.handle_register(request, response) case ("authenticate",): await self.handle_authenticate(request, response) case ('send-reset-email',): await self.handle_send_reset_email(request, response) case ('reset-password',): await self.handle_reset_password(request, response) # Magic link routes case ('magic-link', 'register'): await self.handle_magic_link_register(request, response) case ('magic-link', 'email'): await self.handle_magic_link_email(request, response) case ('magic-link', 'authenticate'): await self.handle_magic_link_authenticate(request, response) # WebAuthn routes case ('webauthn', 'register'): await self.handle_webauthn_register(request, response) case ('webauthn', 'register', 'options'): await self.handle_webauthn_register_options( request, response ) case ('webauthn', 'authenticate'): await self.handle_webauthn_authenticate(request, response) case ('webauthn', 'authenticate', 'options'): await 
self.handle_webauthn_authenticate_options( request, response ) # Email verification routes case ("verify",): await self.handle_verify(request, response) case ("resend-verification-email",): await self.handle_resend_verification_email( request, response ) # UI routes case ('ui', 'signin'): await self.handle_ui_signin(request, response) case ('ui', 'signup'): await self.handle_ui_signup(request, response) case ('ui', 'forgot-password'): await self.handle_ui_forgot_password(request, response) case ('ui', 'reset-password'): await self.handle_ui_reset_password(request, response) case ("ui", "verify"): await self.handle_ui_verify(request, response) case ("ui", "resend-verification"): await self.handle_ui_resend_verification(request, response) case ("ui", "magic-link-sent"): await self.handle_ui_magic_link_sent(request, response) case ('ui', '_static', filename): filepath = os.path.join( os.path.dirname(__file__), '_static', filename ) try: with open(filepath, 'rb') as f: response.status = http.HTTPStatus.OK response.content_type = ( mimetypes.guess_type(filename)[0] or 'application/octet-stream' ).encode() response.body = f.read() except FileNotFoundError: response.status = http.HTTPStatus.NOT_FOUND case _: raise errors.NotFound("Unknown auth endpoint") # User-facing errors except errors.NotFound as ex: _fail_with_error( response=response, status=http.HTTPStatus.NOT_FOUND, ex=ex, ) except errors.InvalidData as ex: _fail_with_error( response=response, status=http.HTTPStatus.BAD_REQUEST, ex=ex, ) except errors.PKCEVerificationFailed as ex: _fail_with_error( response=response, status=http.HTTPStatus.FORBIDDEN, ex=ex, ) except errors.NoIdentityFound as ex: _fail_with_error( response=response, status=http.HTTPStatus.FORBIDDEN, ex=ex, ) except errors.UserAlreadyRegistered as ex: _fail_with_error( response=response, status=http.HTTPStatus.CONFLICT, ex=ex, ) except errors.VerificationRequired as ex: _fail_with_error( response=response, status=http.HTTPStatus.UNAUTHORIZED, ex=ex, 
) # Server errors except errors.MissingConfiguration as ex: _fail_with_error( response=response, status=http.HTTPStatus.INTERNAL_SERVER_ERROR, ex=ex, ) except errors.WebAuthnRegistrationFailed as ex: _fail_with_error( response=response, status=http.HTTPStatus.BAD_REQUEST, ex=ex, exc_info=True, ) except errors.WebAuthnAuthenticationFailed as ex: _fail_with_error( response=response, status=http.HTTPStatus.UNAUTHORIZED, ex=ex, exc_info=True, ) except Exception as ex: if debug.flags.server: markup.dump(ex) _fail_with_error( response=response, status=http.HTTPStatus.INTERNAL_SERVER_ERROR, ex=edb_errors.InternalServerError(str(ex)), exc_info=True, ) async def handle_authorize( self, request: protocol.HttpRequest, response: protocol.HttpResponse, ) -> None: query = urllib.parse.parse_qs( request.url.query.decode("ascii") if request.url.query else "" ) provider_name = _get_search_param(query, "provider") allowed_redirect_to = self._make_allowed_url( _get_search_param(query, "redirect_to") ) allowed_redirect_to_on_signup = self._maybe_make_allowed_url( _maybe_get_search_param(query, "redirect_to_on_signup") ) allowed_callback_url = self._maybe_make_allowed_url( _maybe_get_search_param(query, "callback_url") ) challenge = _get_search_param( query, "challenge", fallback_keys=["code_challenge"] ) oauth_client = oauth.Client( db=self.db, provider_name=provider_name, url_munger=self._get_url_munger(request), http_client=self.tenant.get_http_client(originator="auth"), ) await pkce.create(self.db, challenge) redirect_uri = ( allowed_callback_url.url if allowed_callback_url else self._get_callback_url() ) authorize_url = await oauth_client.get_authorize_url( redirect_uri=redirect_uri, state=jwt.OAuthStateToken( provider=provider_name, redirect_to=allowed_redirect_to.url, redirect_to_on_signup=( allowed_redirect_to_on_signup.url if allowed_redirect_to_on_signup else None ), challenge=challenge, redirect_uri=redirect_uri, ).sign(self.signing_key), ) # n.b. 
Explicitly allow authorization URL to be outside of allowed # URLs because it is a trusted URL from the identity provider. self._do_redirect(response, AllowedUrl(authorize_url)) async def handle_callback( self, request: protocol.HttpRequest, response: protocol.HttpResponse, ) -> None: if request.method == b"POST" and ( request.content_type == b"application/x-www-form-urlencoded" ): form_data = urllib.parse.parse_qs(request.body.decode()) state = _maybe_get_form_field(form_data, "state") code = _maybe_get_form_field(form_data, "code") error = _maybe_get_form_field(form_data, "error") error_description = _maybe_get_form_field( form_data, "error_description" ) elif request.url.query is not None: query = urllib.parse.parse_qs( request.url.query.decode("ascii") if request.url.query else "" ) state = _maybe_get_search_param(query, "state") code = _maybe_get_search_param(query, "code") error = _maybe_get_search_param(query, "error") error_description = _maybe_get_search_param( query, "error_description" ) else: raise errors.OAuthProviderFailure( "Provider did not respond with expected data" ) if state is None: raise errors.InvalidData( "Provider did not include the 'state' parameter in callback" ) if error is not None: try: claims = jwt.OAuthStateToken.verify(state, self.signing_key) redirect_to = claims.redirect_to except Exception: raise errors.InvalidData("Invalid state token") params = { "error": error, } error_str = error if error_description is not None: params["error_description"] = error_description error_str += f": {error_description}" logger.debug(f"OAuth provider returned an error: {error_str}") return self._try_redirect( response, util.join_url_params(redirect_to, params), ) if code is None: raise errors.InvalidData( "Provider did not include the 'code' parameter in callback" ) try: claims = jwt.OAuthStateToken.verify(state, self.signing_key) provider_name = claims.provider allowed_redirect_to = self._make_allowed_url(claims.redirect_to) 
allowed_redirect_to_on_signup = self._maybe_make_allowed_url( claims.redirect_to_on_signup ) challenge = claims.challenge redirect_uri = claims.redirect_uri except Exception: raise errors.InvalidData("Invalid state token") oauth_client = oauth.Client( db=self.db, provider_name=provider_name, url_munger=self._get_url_munger(request), http_client=self.tenant.get_http_client(originator="auth"), ) ( identity, new_identity, auth_token, refresh_token, id_token, ) = await oauth_client.handle_callback(code, redirect_uri) if new_identity: await self._maybe_send_webhook( webhook.IdentityCreated( event_id=str(uuid.uuid4()), timestamp=datetime.datetime.now(datetime.timezone.utc), identity_id=identity.id, ) ) pkce_code = await pkce.link_identity_challenge( self.db, identity.id, challenge ) if auth_token or refresh_token: await pkce.add_provider_tokens( self.db, id=pkce_code, auth_token=auth_token, refresh_token=refresh_token, id_token=id_token, ) new_url = ( (allowed_redirect_to_on_signup or allowed_redirect_to) if new_identity else allowed_redirect_to ).map( lambda u: util.join_url_params( u, {"code": pkce_code, "provider": provider_name} ) ) logger.info( "OAuth callback successful: " f"identity_id={identity.id}, new_identity={new_identity}" ) self._do_redirect(response, new_url) async def handle_token( self, request: protocol.HttpRequest, response: protocol.HttpResponse, ) -> None: query = urllib.parse.parse_qs( request.url.query.decode("ascii") if request.url.query else "" ) code = _get_search_param(query, "code") verifier = _get_search_param( query, "verifier", fallback_keys=["code_verifier"] ) verifier_size = len(verifier) if verifier_size < 43: raise errors.InvalidData( "Verifier must be at least 43 characters long" ) if verifier_size > 128: raise errors.InvalidData( "Verifier must be at most 128 characters long" ) try: pkce_object = await pkce.get_by_id(self.db, code) except Exception: raise errors.NoIdentityFound("Could not find a matching PKCE code") if
pkce_object.identity_id is None: raise errors.InvalidData("Code is not associated with an Identity") hashed_verifier = hashlib.sha256(verifier.encode()).digest() base64_url_encoded_verifier = base64.urlsafe_b64encode( hashed_verifier ).rstrip(b'=') if base64_url_encoded_verifier.decode() == pkce_object.challenge: await pkce.delete(self.db, code) identity_id = pkce_object.identity_id await self._maybe_send_webhook( webhook.IdentityAuthenticated( event_id=str(uuid.uuid4()), timestamp=datetime.datetime.now(datetime.timezone.utc), identity_id=identity_id, ) ) auth_expiration_time = util.get_config( self.db, "ext::auth::AuthConfig::token_time_to_live", statypes.Duration, ) session_token = jwt.SessionToken( subject=identity_id, ).sign( self.signing_key, expires_in=auth_expiration_time.to_timedelta(), ) metrics.auth_successful_logins.inc( 1.0, self.tenant.get_instance_name() ) logger.info(f"Token exchange successful: identity_id={identity_id}") response.status = http.HTTPStatus.OK response.content_type = b"application/json" response.body = json.dumps( { "auth_token": session_token, "identity_id": identity_id, "provider_token": pkce_object.auth_token, "provider_refresh_token": pkce_object.refresh_token, "provider_id_token": pkce_object.id_token, } ).encode() else: raise errors.PKCEVerificationFailed async def handle_register( self, request: protocol.HttpRequest, response: protocol.HttpResponse, ) -> None: data = self._get_data_from_request(request) allowed_redirect_to = self._maybe_make_allowed_url( cast(Optional[str], data.get("redirect_to")) ) maybe_challenge = cast(Optional[str], data.get("challenge")) register_provider_name = cast(Optional[str], data.get("provider")) if register_provider_name is None: raise errors.InvalidData('Missing "provider" in register request') email_password_client = email_password.Client(db=self.db) require_verification = email_password_client.config.require_verification if not require_verification and maybe_challenge is None: raise 
errors.InvalidData('Missing "challenge" in register request') pkce_code: Optional[str] = None try: email_factor = await email_password_client.register(data) identity = email_factor.identity verify_url = data.get("verify_url", f"{self.base_path}/ui/verify") verification_token = self._make_verification_token( identity_id=identity.id, verify_url=verify_url, maybe_challenge=maybe_challenge, maybe_redirect_to=( allowed_redirect_to.url if allowed_redirect_to else None ), ) await self._maybe_send_webhook( webhook.IdentityCreated( event_id=str(uuid.uuid4()), timestamp=datetime.datetime.now(datetime.timezone.utc), identity_id=identity.id, ) ) await self._maybe_send_webhook( webhook.EmailFactorCreated( event_id=str(uuid.uuid4()), timestamp=datetime.datetime.now(datetime.timezone.utc), identity_id=identity.id, email_factor_id=email_factor.id, ) ) await self._maybe_send_webhook( webhook.EmailVerificationRequested( event_id=str(uuid.uuid4()), timestamp=datetime.datetime.now(datetime.timezone.utc), identity_id=identity.id, email_factor_id=email_factor.id, verification_token=verification_token, ) ) if require_verification: response_dict = { "identity_id": identity.id, "email": email_factor.email, "verification_email_sent_at": datetime.datetime.now( datetime.timezone.utc ).isoformat(), } else: # Checked at the beginning of the route handler assert maybe_challenge is not None await pkce.create(self.db, maybe_challenge) pkce_code = await pkce.link_identity_challenge( self.db, identity.id, maybe_challenge ) response_dict = { "code": pkce_code, "email": email_factor.email, "provider": register_provider_name, } await self._send_verification_email( provider=register_provider_name, verification_token=verification_token, to_addr=data["email"], verify_url=verify_url, ) logger.info( f"Identity created: identity_id={identity.id}, " f"pkce_id={pkce_code!r}" ) if allowed_redirect_to is not None: self._do_redirect( response, allowed_redirect_to.map( lambda u: util.join_url_params(u, 
response_dict) ), ) else: response.status = http.HTTPStatus.CREATED response.content_type = b"application/json" response.body = json.dumps(response_dict).encode() except Exception as ex: redirect_on_failure = data.get( "redirect_on_failure", data.get("redirect_to") ) if redirect_on_failure is not None: error_message = str(ex) email = data.get("email", "") logger.error( f"Error creating identity: error={error_message}, " f"email={email}" ) error_redirect_url = util.join_url_params( redirect_on_failure, { "error": error_message, "email": email, }, ) return self._try_redirect(response, error_redirect_url) else: raise ex async def handle_authenticate( self, request: protocol.HttpRequest, response: protocol.HttpResponse, ) -> None: data = self._get_data_from_request(request) _check_keyset(data, {"provider", "challenge", "email", "password"}) challenge = data["challenge"] email = data["email"] password = data["password"] await pkce.create(self.db, challenge) allowed_redirect_to = self._maybe_make_allowed_url( cast(Optional[str], data.get("redirect_to")) ) email_password_client = email_password.Client(db=self.db) try: local_identity = await email_password_client.authenticate( email, password ) verified_at = ( await email_password_client.get_verified_by_identity_id( identity_id=local_identity.id ) ) if ( email_password_client.config.require_verification and verified_at is None ): raise errors.VerificationRequired() pkce_code = await pkce.link_identity_challenge( self.db, local_identity.id, challenge ) response_dict = {"code": pkce_code} logger.info( f"Authentication successful: identity_id={local_identity.id}, " f"pkce_id={pkce_code}" ) if allowed_redirect_to: self._do_redirect( response, allowed_redirect_to.map( lambda u: util.join_url_params(u, response_dict) ), ) else: response.status = http.HTTPStatus.OK response.content_type = b"application/json" response.body = json.dumps(response_dict).encode() except Exception as ex: redirect_on_failure = data.get( 
"redirect_on_failure", data.get("redirect_to") ) if redirect_on_failure is not None: error_message = str(ex) email = data.get("email", "") logger.error( f"Error authenticating: error={error_message}, " f"email={email}" ) error_redirect_url = util.join_url_params( redirect_on_failure, { "error": error_message, "email": email, }, ) return self._try_redirect(response, error_redirect_url) else: raise ex async def handle_verify( self, request: protocol.HttpRequest, response: protocol.HttpResponse, ) -> None: data = self._get_data_from_request(request) verification_token = data.get("verification_token") email = data.get("email") code = data.get("code") provider = data.get("provider") if not provider: raise errors.InvalidData('Missing "provider" in verify request') if verification_token: try: token = jwt.VerificationToken.verify( verification_token, self.signing_key, ) email_factor = await self._try_verify_email( provider=provider, identity_id=token.subject, ) await self._maybe_send_webhook( webhook.EmailVerified( event_id=str(uuid.uuid4()), timestamp=datetime.datetime.now(datetime.timezone.utc), identity_id=token.subject, email_factor_id=email_factor.id, ) ) except errors.VerificationTokenExpired: response.status = http.HTTPStatus.FORBIDDEN response.content_type = b"application/json" error_message = "The verification token is older than 24 hours" logger.error(f"Verification token expired: {error_message}") response.body = json.dumps({"message": error_message}).encode() return logger.info( f"Email verified via token: identity_id={token.subject}, " f"email_factor_id={email_factor.id}, " f"email={email_factor.email}" ) identity_id = token.subject challenge = token.maybe_challenge redirect_to = token.maybe_redirect_to elif email and code: _check_keyset( data, { "email", "code", "provider", }, ) email_client: local.Client if provider == "builtin::local_emailpassword": email_client = email_password.Client(db=self.db) elif provider == "builtin::local_webauthn": email_client = 
webauthn.Client(db=self.db) else: raise errors.InvalidData(f"Unsupported provider: {provider}") try: maybe_email_factor = ( await email_client.get_email_factor_by_email(email) ) if maybe_email_factor is None: raise errors.NoIdentityFound("Invalid email") email_factor = maybe_email_factor otc_id = await otc.verify(self.db, str(email_factor.id), code) await self._handle_otc_verified( identity_id=str(email_factor.identity.id), email_factor_id=str(email_factor.id), otc_id=str(otc_id), ) await self._try_verify_email( provider=provider, identity_id=email_factor.identity.id, ) await self._maybe_send_webhook( webhook.EmailVerified( event_id=str(uuid.uuid4()), timestamp=datetime.datetime.now(datetime.timezone.utc), identity_id=email_factor.identity.id, email_factor_id=email_factor.id, ) ) logger.info( f"Email verified via OTC: " f"identity_id={email_factor.identity.id}, " f"email_factor_id={email_factor.id}, " f"email={email}" ) identity_id = email_factor.identity.id challenge = data.get("challenge") redirect_to = data.get("redirect_to") except Exception as ex: self._handle_otc_failed(ex) response.status = http.HTTPStatus.BAD_REQUEST response.content_type = b"application/json" response.body = json.dumps( {"error": str(ex), "error_code": "verification_failed"} ).encode() return else: raise errors.InvalidData( 'Must provide either "verification_token" (Link mode) ' 'or "email" + "code" (OTC mode)' ) logger.info(f"Challenge: {challenge}, Redirect to: {redirect_to}") match (challenge, redirect_to): case (str(), str()): await pkce.create(self.db, challenge) code = await pkce.link_identity_challenge( self.db, identity_id, challenge ) return self._try_redirect( response, util.join_url_params(redirect_to, {"code": code}), ) case (str(), None): await pkce.create(self.db, challenge) code = await pkce.link_identity_challenge( self.db, identity_id, challenge ) response.status = http.HTTPStatus.OK response.content_type = b"application/json" response.body = json.dumps({"code": 
code}).encode() return case (None, str()): return self._try_redirect(response, redirect_to) case (None, None): response.status = http.HTTPStatus.NO_CONTENT return async def handle_resend_verification_email( self, request: protocol.HttpRequest, response: protocol.HttpResponse, ) -> None: request_data = self._get_data_from_request(request) _check_keyset(request_data, {"provider"}) provider_name = request_data["provider"] local_client: email_password.Client | webauthn.Client match provider_name: case "builtin::local_emailpassword": local_client = email_password.Client(db=self.db) case "builtin::local_webauthn": local_client = webauthn.Client(db=self.db) case _: raise errors.InvalidData( f"Unsupported provider: {request_data['provider']}" ) verify_url = request_data.get( "verify_url", f"{self.base_path}/ui/verify" ) email_factor: Optional[EmailFactor] = None if "verification_token" in request_data: token = jwt.VerificationToken.verify( request_data["verification_token"], self.signing_key, skip_expiration_check=True, ) identity_id = token.subject verify_url = token.verify_url maybe_challenge = token.maybe_challenge maybe_redirect_to = token.maybe_redirect_to email_factor = await local_client.get_email_factor_by_identity_id( identity_id ) else: maybe_challenge = request_data.get( "challenge", request_data.get("code_challenge") ) maybe_redirect_to = request_data.get("redirect_to") if maybe_redirect_to and not self._is_url_allowed( maybe_redirect_to ): raise errors.InvalidData( "Redirect URL does not match any allowed URLs.", ) match local_client: case webauthn.Client(): _check_keyset(request_data, {"credential_id"}) credential_id = base64.b64decode( request_data["credential_id"] ) email_factor = ( await local_client.get_email_factor_by_credential_id( credential_id ) ) case email_password.Client(): _check_keyset(request_data, {"email"}) email_factor = await local_client.get_email_factor_by_email( request_data["email"] ) if email_factor is None: match local_client: case 
webauthn.Client(): logger.debug( f"Failed to find email factor for resend verification " f"email: provider={provider_name}, " f"webauthn_credential_id={request_data.get('credential_id')}" ) case email_password.Client(): logger.debug( f"Failed to find email factor for resend verification " f"email: provider={provider_name}, " f"email={request_data.get('email')}" ) await auth_emails.send_fake_email(self.tenant) else: logger.info( f"Resending verification email: provider={provider_name}, " f"identity_id={email_factor.identity.id}, " f"email_factor_id={email_factor.id}, " f"email={email_factor.email}" ) verification_token = self._make_verification_token( identity_id=email_factor.identity.id, verify_url=verify_url, maybe_challenge=maybe_challenge, maybe_redirect_to=maybe_redirect_to, ) await self._maybe_send_webhook( webhook.EmailVerificationRequested( event_id=str(uuid.uuid4()), timestamp=datetime.datetime.now(datetime.timezone.utc), identity_id=email_factor.identity.id, email_factor_id=email_factor.id, verification_token=verification_token, ) ) await self._send_verification_email( provider=request_data["provider"], verification_token=verification_token, to_addr=email_factor.email, verify_url=verify_url, ) response.status = http.HTTPStatus.OK async def handle_send_reset_email( self, request: protocol.HttpRequest, response: protocol.HttpResponse, ) -> None: data = self._get_data_from_request(request) _check_keyset(data, {"provider", "email", "reset_url", "challenge"}) email = data["email"] email_password_client = email_password.Client(db=self.db) if not self._is_url_allowed(data["reset_url"]): raise errors.InvalidData( "Redirect URL does not match any allowed URLs.", ) allowed_redirect_to = self._maybe_make_allowed_url( data.get("redirect_to") ) try: try: ( email_factor, secret, ) = await email_password_client.get_email_factor_and_secret( email ) identity_id = email_factor.identity.id if email_password_client.config.verification_method == "Code": code, otc_id = await 
otc.create( self.db, str(email_factor.id), datetime.timedelta(minutes=10), ) await auth_emails.send_password_reset_code_email( db=self.db, tenant=self.tenant, to_addr=email, code=code, test_mode=self.test_mode, ) await self._handle_otc_initiated( identity_id=identity_id, email_factor_id=str(email_factor.id), otc_id=str(otc_id), one_time_code=code, ) logger.info( "Sent OTC password reset email: " f"email={email}, otc_id={otc_id}" ) else: new_reset_token = jwt.ResetToken( subject=identity_id, secret=secret, challenge=data["challenge"], ).sign(self.signing_key) reset_token_params = {"reset_token": new_reset_token} reset_url = util.join_url_params( data['reset_url'], reset_token_params ) await self._maybe_send_webhook( webhook.PasswordResetRequested( event_id=str(uuid.uuid4()), timestamp=datetime.datetime.now( datetime.timezone.utc ), identity_id=identity_id, reset_token=new_reset_token, email_factor_id=email_factor.id, ) ) await auth_emails.send_password_reset_email( db=self.db, tenant=self.tenant, to_addr=email, reset_url=reset_url, test_mode=self.test_mode, ) except errors.NoIdentityFound: logger.debug( f"Failed to find identity for send reset email: " f"email={email}" ) await auth_emails.send_fake_email(self.tenant) return_data = { "email_sent": email, } if allowed_redirect_to: return self._do_redirect( response, allowed_redirect_to.map( lambda u: util.join_url_params(u, return_data) ), ) else: response.status = http.HTTPStatus.OK response.content_type = b"application/json" response.body = json.dumps(return_data).encode() except aiosmtplib.SMTPException as ex: if not debug.flags.server: logger.warning("Failed to send emails via SMTP", exc_info=True) raise edb_errors.InternalServerError( "Failed to send the email, please try again later." 
) from ex except Exception as ex: redirect_on_failure = data.get( "redirect_on_failure", data.get("redirect_to") ) if redirect_on_failure is not None: error_message = str(ex) logger.error( f"Error sending reset email: error={error_message}, " f"email={email}" ) redirect_url = util.join_url_params( redirect_on_failure, { "error": error_message, "email": email, }, ) return self._try_redirect( response, redirect_url, ) else: raise ex async def handle_reset_password( self, request: protocol.HttpRequest, response: protocol.HttpResponse, ) -> None: data = self._get_data_from_request(request) try: _check_keyset(data, {"password", "provider"}) password = data['password'] reset_token = data.get('reset_token') email = data.get('email') code = data.get('code') allowed_redirect_to = self._maybe_make_allowed_url( data.get("redirect_to") ) email_password_client = email_password.Client(db=self.db) if reset_token: token = jwt.ResetToken.verify( reset_token, self.signing_key, ) await email_password_client.update_password( token.subject, token.secret, password ) await pkce.create(self.db, token.challenge) code = await pkce.link_identity_challenge( self.db, token.subject, token.challenge ) response_dict = {"code": code} logger.info( "Reset password via token: " f"identity_id={token.subject}, pkce_id={code}" ) elif email and code: try: ( email_factor, secret, ) = await email_password_client.get_email_factor_and_secret( email ) otc_id = await otc.verify( self.db, str(email_factor.id), code ) logger.info( "OTC verified for password reset: " f"otc_id={otc_id}, email={email}" ) await self._handle_otc_verified( identity_id=email_factor.identity.id, email_factor_id=str(email_factor.id), otc_id=str(otc_id), ) except Exception as ex: self._handle_otc_failed(ex) raise try: await email_password_client.update_password( email_factor.identity.id, secret, password ) except Exception as ex: raise errors.InvalidData( f"Failed to reset password: {str(ex)}" ) challenge = data.get('challenge') if 
challenge: await pkce.create(self.db, challenge) auth_code = await pkce.link_identity_challenge( self.db, email_factor.identity.id, challenge ) response_dict = {"code": auth_code} logger.info( "Reset password via OTC: " f"identity_id={email_factor.identity.id}, email={email}" ) else: response_dict = {"status": "password_reset"} else: raise errors.InvalidData( 'Must provide either "reset_token" (Token mode) ' 'or "email" + "code" (OTC mode)' ) if allowed_redirect_to: return self._do_redirect( response, allowed_redirect_to.map( lambda u: util.join_url_params(u, response_dict) ), ) else: response.status = http.HTTPStatus.OK response.content_type = b"application/json" response.body = json.dumps(response_dict).encode() except Exception as ex: redirect_on_failure = data.get( "redirect_on_failure", data.get("redirect_to") ) if redirect_on_failure is not None: error_message = str(ex) logger.error( f"Error resetting password: error={error_message}, " f"reset_token={reset_token}, email={email}" ) error_params = { "error": error_message, } if reset_token: error_params["reset_token"] = reset_token if email: error_params["email"] = email redirect_url = util.join_url_params( redirect_on_failure, error_params, ) return self._try_redirect(response, redirect_url) else: raise ex async def handle_magic_link_register( self, request: protocol.HttpRequest, response: protocol.HttpResponse, ) -> None: data = self._get_data_from_request(request) email: str | None = None try: _check_keyset(data, {"email"}) email = cast(str, data["email"]) allowed_redirect_to = self._maybe_make_allowed_url( data.get("redirect_to") ) magic_link_client = magic_link.Client( db=self.db, issuer=self.base_path, tenant=self.tenant, test_mode=self.test_mode, signing_key=self.signing_key, ) if not _accepts_json(request) and not allowed_redirect_to: raise errors.InvalidData( "Request must accept JSON or provide a redirect URL." 
) email_factor = await magic_link_client.register( email=email, ) await self._maybe_send_webhook( webhook.IdentityCreated( event_id=str(uuid.uuid4()), timestamp=datetime.datetime.now(datetime.timezone.utc), identity_id=email_factor.identity.id, ) ) await self._maybe_send_webhook( webhook.EmailFactorCreated( event_id=str(uuid.uuid4()), timestamp=datetime.datetime.now(datetime.timezone.utc), identity_id=email_factor.identity.id, email_factor_id=email_factor.id, ) ) if magic_link_client.provider.verification_method == "Code": code, otc_id = await otc.create( self.db, str(email_factor.id), datetime.timedelta(minutes=10), ) await auth_emails.send_one_time_code_email( db=self.db, tenant=self.tenant, to_addr=email, code=code, test_mode=self.test_mode, ) await self._handle_otc_initiated( identity_id=str(email_factor.identity.id), email_factor_id=str(email_factor.id), otc_id=str(otc_id), one_time_code=code, ) logger.info( "Sent OTC email: " f"identity_id={email_factor.identity.id}, " f"email={email}, otc_id={otc_id}" ) return_data = { "code": "true", "signup": "true", "identity_id": str(email_factor.identity.id), "email": email, } else: _check_keyset( data, {"challenge", "callback_url", "redirect_on_failure"} ) challenge = data["challenge"] callback_url = data["callback_url"] if not self._is_url_allowed(callback_url): raise errors.InvalidData( "Callback URL does not match any allowed URLs.", ) allowed_redirect_on_failure = self._make_allowed_url( data["redirect_on_failure"] ) allowed_link_url = self._maybe_make_allowed_url( data.get("link_url") ) link_url = ( allowed_link_url.url if allowed_link_url else f"{self.base_path}/magic-link/authenticate" ) magic_link_token = magic_link_client.make_magic_link_token( identity_id=email_factor.identity.id, callback_url=callback_url, challenge=challenge, ) await self._maybe_send_webhook( webhook.MagicLinkRequested( event_id=str(uuid.uuid4()), timestamp=datetime.datetime.now(datetime.timezone.utc), identity_id=email_factor.identity.id, 
email_factor_id=email_factor.id, magic_link_token=magic_link_token, magic_link_url=link_url, ) ) logger.info( "Sending magic link: " f"identity_id={email_factor.identity.id}, email={email}" ) await magic_link_client.send_magic_link( email=email, link_url=link_url, redirect_on_failure=allowed_redirect_on_failure.url, token=magic_link_token, ) return_data = { "email_sent": email, } if _accepts_json(request): response.status = http.HTTPStatus.OK response.content_type = b"application/json" response.body = json.dumps(return_data).encode() elif allowed_redirect_to: return self._do_redirect( response, allowed_redirect_to.map( lambda u: util.join_url_params(u, return_data) ), ) else: # This should not happen since we check earlier for this case # but this seems safer than a cast raise errors.InvalidData( "Request must accept JSON or provide a redirect URL." ) except Exception as ex: if _accepts_json(request): raise ex redirect_on_failure = data.get( "redirect_on_failure", data.get("redirect_to") ) error_message = str(ex) email = email or "" logger.error( f"Error sending magic link email: error={error_message}, " f"email={email}" ) if redirect_on_failure is None: raise ex error_redirect_url = util.join_url_params( redirect_on_failure, { "error": error_message, "email": email, }, ) self._try_redirect(response, error_redirect_url) async def handle_magic_link_email( self, request: protocol.HttpRequest, response: protocol.HttpResponse, ) -> None: data = self._get_data_from_request(request) email: str | None = None return_data: dict[str, Any] = {} try: _check_keyset(data, {"email"}) email = cast(str, data["email"]) allowed_redirect_to = self._maybe_make_allowed_url( data.get("redirect_to") ) magic_link_client = magic_link.Client( db=self.db, issuer=self.base_path, tenant=self.tenant, test_mode=self.test_mode, signing_key=self.signing_key, ) email_factor = await magic_link_client.get_email_factor_by_email( email ) is_signup = False if email_factor is None: if 
magic_link_client.provider.auto_signup: # Auto-signup is enabled, create a new user is_signup = True email_factor = await magic_link_client.register( email=email, ) await self._maybe_send_webhook( webhook.IdentityCreated( event_id=str(uuid.uuid4()), timestamp=datetime.datetime.now(datetime.timezone.utc), identity_id=email_factor.identity.id, ) ) await self._maybe_send_webhook( webhook.EmailFactorCreated( event_id=str(uuid.uuid4()), timestamp=datetime.datetime.now(datetime.timezone.utc), identity_id=email_factor.identity.id, email_factor_id=email_factor.id, ) ) else: logger.error( "Cannot send magic link email: no email factor found " f"for email={email}" ) await auth_emails.send_fake_email(self.tenant) if magic_link_client.provider.verification_method == "Code": return_data = { "code": "true", "email": email, } else: return_data = { "email_sent": email, } if email_factor is not None: identity_id = email_factor.identity.id if magic_link_client.provider.verification_method == "Code": code, otc_id = await otc.create( self.db, str(email_factor.id), datetime.timedelta(minutes=10), ) await auth_emails.send_one_time_code_email( db=self.db, tenant=self.tenant, to_addr=email, code=code, test_mode=self.test_mode, ) await self._handle_otc_initiated( identity_id=str(identity_id), email_factor_id=str(email_factor.id), otc_id=str(otc_id), one_time_code=code, ) logger.info( f"Sent OTC email: identity_id={identity_id}, " f"email={email}, otc_id={otc_id}" ) return_data = { "code": "true", "email": email, } if is_signup: return_data["identity_id"] = str(identity_id) return_data["signup"] = "true" else: _check_keyset( data, {"challenge", "callback_url", "redirect_on_failure"}, ) challenge = data["challenge"] callback_url = data["callback_url"] if not self._is_url_allowed(callback_url): raise errors.InvalidData( "callback_url does not match any allowed URLs.", ) redirect_on_failure = data["redirect_on_failure"] if not self._is_url_allowed(redirect_on_failure): raise 
errors.InvalidData( "redirect_on_failure" " does not match any allowed URLs.", ) allowed_link_url = self._maybe_make_allowed_url( data.get("link_url") ) link_url = ( allowed_link_url.url if allowed_link_url else f"{self.base_path}/magic-link/authenticate" ) magic_link_token = magic_link_client.make_magic_link_token( identity_id=identity_id, callback_url=callback_url, challenge=challenge, ) await self._maybe_send_webhook( webhook.MagicLinkRequested( event_id=str(uuid.uuid4()), timestamp=datetime.datetime.now( datetime.timezone.utc ), identity_id=identity_id, email_factor_id=email_factor.id, magic_link_token=magic_link_token, magic_link_url=link_url, ) ) await magic_link_client.send_magic_link( email=email, token=magic_link_token, link_url=link_url, redirect_on_failure=redirect_on_failure, ) logger.info( "Sent magic link email: " f"identity_id={identity_id}, email={email}" ) return_data = { "email_sent": email, } if is_signup: return_data["signup"] = "true" if allowed_redirect_to: return self._do_redirect( response, allowed_redirect_to.map( lambda u: util.join_url_params(u, return_data) ), ) else: response.status = http.HTTPStatus.OK response.content_type = b"application/json" response.body = json.dumps(return_data).encode() except Exception as ex: if _accepts_json(request): raise ex redirect_on_failure = data.get( "redirect_on_failure", data.get("redirect_to") ) error_message = str(ex) email = email or "" logger.error( f"Error sending magic link email: error={error_message}, " f"email={email}" ) if redirect_on_failure is None: raise ex error_redirect_url = util.join_url_params( redirect_on_failure, { "error": error_message, "email": email, }, ) self._try_redirect(response, error_redirect_url) async def handle_magic_link_authenticate( self, request: protocol.HttpRequest, response: protocol.HttpResponse, ) -> None: query = urllib.parse.parse_qs( request.url.query.decode("ascii") if request.url.query else "" ) token_str = _maybe_get_search_param(query, "token") if 
token_str: try: token = jwt.MagicLinkToken.verify(token_str, self.signing_key) await pkce.create(self.db, token.challenge) code = await pkce.link_identity_challenge( self.db, token.subject, token.challenge ) local_client = magic_link.Client( db=self.db, tenant=self.tenant, test_mode=self.test_mode, issuer=self.base_path, signing_key=self.signing_key, ) await local_client.verify_email( token.subject, datetime.datetime.now(datetime.timezone.utc) ) return self._try_redirect( response, util.join_url_params(token.callback_url, {"code": code}), ) except Exception as ex: redirect_on_failure = _maybe_get_search_param( query, "redirect_on_failure" ) if redirect_on_failure is None: raise ex else: error_message = str(ex) logger.error( "Error authenticating magic link: " f"error={error_message}, token={token_str}" ) redirect_url = util.join_url_params( redirect_on_failure, { "error": error_message, }, ) return self._try_redirect(response, redirect_url) else: try: data = self._get_data_from_request(request) _check_keyset( data, {"email", "code", "challenge"} ) email = data["email"] code_str = data["code"] challenge = data["challenge"] maybe_callback_url = cast( Optional[str], data.get("callback_url") ) if ( maybe_callback_url and not self._is_url_allowed(maybe_callback_url) ): raise errors.InvalidData( "Callback URL does not match any allowed URLs.", ) magic_link_client = magic_link.Client( db=self.db, tenant=self.tenant, test_mode=self.test_mode, issuer=self.base_path, signing_key=self.signing_key, ) email_factor = ( await magic_link_client.get_email_factor_by_email(email) ) if email_factor is None: raise errors.NoIdentityFound("Invalid email") try: otc_id = await otc.verify( self.db, str(email_factor.id), code_str ) await self._handle_otc_verified( identity_id=str(email_factor.identity.id), email_factor_id=str(email_factor.id), otc_id=str(otc_id), ) except Exception as ex: self._handle_otc_failed(ex) raise await pkce.create(self.db, challenge) auth_code = await 
pkce.link_identity_challenge( self.db, email_factor.identity.id, challenge ) await magic_link_client.verify_email( email_factor.identity.id, datetime.datetime.now(datetime.timezone.utc), ) response_dict = {"code": auth_code} if maybe_callback_url: return self._try_redirect( response, util.join_url_params(maybe_callback_url, response_dict), ) else: response.status = http.HTTPStatus.OK response.content_type = b"application/json" response.body = json.dumps(response_dict).encode() except Exception as ex: redirect_on_failure = _maybe_get_search_param( query, "redirect_on_failure" ) if redirect_on_failure is None: response.status = http.HTTPStatus.BAD_REQUEST response.content_type = b"application/json" response.body = json.dumps( {"error": str(ex), "error_code": "verification_failed"} ).encode() return else: error_message = str(ex) logger.error( f"Error authenticating OTC: error={error_message}, " f"email={_maybe_get_search_param(query, 'email')}" ) redirect_url = util.join_url_params( redirect_on_failure, { "error": error_message, }, ) return self._try_redirect(response, redirect_url) async def handle_webauthn_register_options( self, request: protocol.HttpRequest, response: protocol.HttpResponse, ) -> None: query = urllib.parse.parse_qs( request.url.query.decode("ascii") if request.url.query else "" ) email = _get_search_param(query, "email") webauthn_client = webauthn.Client(self.db) try: ( user_handle, registration_options, ) = await webauthn_client.create_registration_options_for_email( email=email, ) except Exception as e: raise errors.WebAuthnRegistrationFailed( "Failed to create registration options" ) from e response.status = http.HTTPStatus.OK response.content_type = b"application/json" _set_cookie( response, "edgedb-webauthn-registration-user-handle", user_handle, path="/", ) response.body = registration_options async def handle_webauthn_register( self, request: protocol.HttpRequest, response: protocol.HttpResponse, ) -> None: data = 
self._get_data_from_request(request) _check_keyset( data, {"provider", "challenge", "email", "credentials", "verify_url"}, ) webauthn_client = webauthn.Client(self.db) provider_name: str = data["provider"] email: str = data["email"] verify_url: str = data["verify_url"] credentials: str = data["credentials"] pkce_challenge: str = data["challenge"] user_handle_cookie = request.cookies.get( "edgedb-webauthn-registration-user-handle" ) user_handle_base64url: Optional[str] = ( user_handle_cookie.value if user_handle_cookie else data.get("user_handle") ) if user_handle_base64url is None: raise errors.InvalidData( "Missing user_handle from cookie or request body" ) try: user_handle = base64.urlsafe_b64decode( f"{user_handle_base64url}===" ) except Exception as e: raise errors.InvalidData("Failed to decode user_handle") from e require_verification = webauthn_client.provider.require_verification pkce_code: Optional[str] = None try: email_factor = await webauthn_client.register( credentials=credentials, email=email, user_handle=user_handle, ) except Exception as e: raise errors.WebAuthnRegistrationFailed( "Failed to register WebAuthn" ) from e identity_id = email_factor.identity.id await self._maybe_send_webhook( webhook.IdentityCreated( event_id=str(uuid.uuid4()), timestamp=datetime.datetime.now(datetime.timezone.utc), identity_id=identity_id, ) ) await self._maybe_send_webhook( webhook.EmailFactorCreated( event_id=str(uuid.uuid4()), timestamp=datetime.datetime.now(datetime.timezone.utc), identity_id=identity_id, email_factor_id=email_factor.id, ) ) verification_token = self._make_verification_token( identity_id=identity_id, verify_url=verify_url, maybe_challenge=pkce_challenge, maybe_redirect_to=None, ) await self._maybe_send_webhook( webhook.EmailVerificationRequested( event_id=str(uuid.uuid4()), timestamp=datetime.datetime.now(datetime.timezone.utc), identity_id=identity_id, email_factor_id=email_factor.id, verification_token=verification_token, ) ) await 
self._send_verification_email( provider=provider_name, verification_token=verification_token, to_addr=email_factor.email, verify_url=verify_url, ) if not require_verification: await pkce.create(self.db, pkce_challenge) pkce_code = await pkce.link_identity_challenge( self.db, identity_id, pkce_challenge ) _set_cookie( response, "edgedb-webauthn-registration-user-handle", "", path="/", ) response.status = http.HTTPStatus.CREATED response.content_type = b"application/json" if require_verification: now_iso8601 = datetime.datetime.now( datetime.timezone.utc ).isoformat() response.body = json.dumps( { "identity_id": identity_id, "email": email_factor.email, "verification_email_sent_at": now_iso8601, } ).encode() logger.info( f"Sent verification email: identity_id={identity_id}, " f"email={email}" ) else: if pkce_code is None: raise errors.PKCECreationFailed response.body = json.dumps( { "code": pkce_code, "provider": provider_name, "email": email_factor.email, } ).encode() logger.info( f"WebAuthn registration successful: identity_id={identity_id}, " f"email={email}, " f"pkce_id={pkce_code}" ) async def handle_webauthn_authenticate_options( self, request: protocol.HttpRequest, response: protocol.HttpResponse, ) -> None: query = urllib.parse.parse_qs( request.url.query.decode("ascii") if request.url.query else "" ) email = _get_search_param(query, "email") webauthn_provider = self._get_webauthn_provider() if webauthn_provider is None: raise errors.MissingConfiguration( "ext::auth::AuthConfig::providers", "WebAuthn provider is not configured", ) webauthn_client = webauthn.Client(self.db) try: ( _, registration_options, ) = await webauthn_client.create_authentication_options_for_email( email=email, webauthn_provider=webauthn_provider ) except Exception as e: raise errors.WebAuthnAuthenticationFailed( "Failed to create authentication options" ) from e response.status = http.HTTPStatus.OK response.content_type = b"application/json" response.body = registration_options async 
def handle_webauthn_authenticate( self, request: protocol.HttpRequest, response: protocol.HttpResponse, ) -> None: data = self._get_data_from_request(request) _check_keyset( data, {"challenge", "email", "assertion"}, ) webauthn_client = webauthn.Client(self.db) email: str = data["email"] assertion: str = data["assertion"] pkce_challenge: str = data["challenge"] try: identity = await webauthn_client.authenticate( assertion=assertion, email=email, ) except Exception as e: raise errors.WebAuthnAuthenticationFailed( "Failed to authenticate WebAuthn" ) from e require_verification = webauthn_client.provider.require_verification if require_verification: email_is_verified = await webauthn_client.is_email_verified( email, assertion ) if not email_is_verified: raise errors.VerificationRequired() await pkce.create(self.db, pkce_challenge) code = await pkce.link_identity_challenge( self.db, identity.id, pkce_challenge ) logger.info( f"WebAuthn authentication successful: identity_id={identity.id}, " f"email={email}, " f"pkce_id={code}" ) response.status = http.HTTPStatus.OK response.content_type = b"application/json" response.body = json.dumps( { "code": code, } ).encode() async def handle_ui_signin( self, request: protocol.HttpRequest, response: protocol.HttpResponse, ) -> None: ui_config = self._get_ui_config() if ui_config is None: response.status = http.HTTPStatus.NOT_FOUND response.body = b'Auth UI not enabled' else: providers = util.maybe_get_config( self.db, "ext::auth::AuthConfig::providers", frozenset, ) if providers is None or len(providers) == 0: raise errors.MissingConfiguration( 'ext::auth::AuthConfig::providers', 'No providers are configured', ) app_details_config = self._get_app_details_config() query = urllib.parse.parse_qs( request.url.query.decode("ascii") if request.url.query else "" ) maybe_challenge = _get_pkce_challenge( response=response, cookies=request.cookies, query_dict=query, ) if maybe_challenge is None: raise errors.InvalidData( 'Missing 
"challenge" in register request' ) response.status = http.HTTPStatus.OK response.content_type = b'text/html' response.body = ui.render_signin_page( base_path=self.base_path, providers=providers, redirect_to=ui_config.redirect_to, redirect_to_on_signup=ui_config.redirect_to_on_signup, error_message=_maybe_get_search_param(query, 'error'), email=_maybe_get_search_param(query, 'email'), challenge=maybe_challenge, selected_tab=_maybe_get_search_param(query, 'selected_tab'), app_name=app_details_config.app_name, logo_url=app_details_config.logo_url, dark_logo_url=app_details_config.dark_logo_url, brand_color=app_details_config.brand_color, ) async def handle_ui_signup( self, request: protocol.HttpRequest, response: protocol.HttpResponse, ) -> None: ui_config = self._get_ui_config() if ui_config is None: response.status = http.HTTPStatus.NOT_FOUND response.body = b'Auth UI not enabled' else: providers = util.maybe_get_config( self.db, "ext::auth::AuthConfig::providers", frozenset, ) if providers is None or len(providers) == 0: raise errors.MissingConfiguration( 'ext::auth::AuthConfig::providers', 'No providers are configured', ) query = urllib.parse.parse_qs( request.url.query.decode("ascii") if request.url.query else "" ) maybe_challenge = _get_pkce_challenge( response=response, cookies=request.cookies, query_dict=query, ) if maybe_challenge is None: raise errors.InvalidData( 'Missing "challenge" in register request' ) app_details_config = self._get_app_details_config() response.status = http.HTTPStatus.OK response.content_type = b'text/html' response.body = ui.render_signup_page( base_path=self.base_path, providers=providers, redirect_to=ui_config.redirect_to, redirect_to_on_signup=ui_config.redirect_to_on_signup, error_message=_maybe_get_search_param(query, 'error'), email=_maybe_get_search_param(query, 'email'), challenge=maybe_challenge, selected_tab=_maybe_get_search_param(query, 'selected_tab'), app_name=app_details_config.app_name, 
logo_url=app_details_config.logo_url, dark_logo_url=app_details_config.dark_logo_url, brand_color=app_details_config.brand_color, ) async def handle_ui_forgot_password( self, request: protocol.HttpRequest, response: protocol.HttpResponse, ) -> None: ui_config = self._get_ui_config() password_provider = ( self._get_password_provider() if ui_config is not None else None ) if ui_config is None or password_provider is None: response.status = http.HTTPStatus.NOT_FOUND response.body = ( b'Password provider not configured' if ui_config else b'Auth UI not enabled' ) return query = urllib.parse.parse_qs( request.url.query.decode("ascii") if request.url.query else "" ) challenge = _get_search_param( query, "challenge", fallback_keys=["code_challenge"] ) app_details_config = self._get_app_details_config() redirect_on_failure = ( f"{self.base_path}/ui/forgot-password?challenge={challenge}" ) reset_url = f"{self.base_path}/ui/reset-password" if password_provider.verification_method == "Code": redirect_to = ( f"{self.base_path}/ui/reset-password?" 
f"code=true&challenge={challenge}" ) else: redirect_to = ( f"{self.base_path}/ui/forgot-password?challenge={challenge}" ) response.status = http.HTTPStatus.OK response.content_type = b'text/html' response.body = ui.render_forgot_password_page( redirect_to=redirect_to, redirect_on_failure=redirect_on_failure, reset_url=reset_url, provider_name=password_provider.name, error_message=_maybe_get_search_param(query, 'error'), email=_maybe_get_search_param(query, 'email'), email_sent=_maybe_get_search_param(query, 'email_sent'), challenge=challenge, app_name=app_details_config.app_name, logo_url=app_details_config.logo_url, dark_logo_url=app_details_config.dark_logo_url, brand_color=app_details_config.brand_color, ) async def handle_ui_reset_password( self, request: protocol.HttpRequest, response: protocol.HttpResponse, ) -> None: ui_config = self._get_ui_config() password_provider = ( self._get_password_provider() if ui_config is not None else None ) if ui_config is None or password_provider is None: raise errors.NotFound( 'Password provider not configured' if ui_config else 'Auth UI not enabled' ) query = urllib.parse.parse_qs( request.url.query.decode("ascii") if request.url.query else "" ) challenge = _get_pkce_challenge( response=response, cookies=request.cookies, query_dict=query, ) if challenge is None: raise errors.InvalidData( 'Missing "challenge" in reset password request' ) error_message = _maybe_get_search_param(query, "error") if password_provider.verification_method == "Code": maybe_email = _maybe_get_search_param( query, "email", fallback_keys=["email_sent"] ) if maybe_email is None: raise errors.InvalidData('Missing "email" for reset code flow') app_details_config = self._get_app_details_config() response.status = http.HTTPStatus.OK response.content_type = b'text/html' response.body = ui.render_reset_password_page( base_path=self.base_path, provider_name=password_provider.name, is_valid=True, # Code flow is always valid to show the form 
redirect_to=ui_config.redirect_to, challenge=challenge, reset_token=None, error_message=error_message, is_code_flow=True, email=maybe_email, app_name=app_details_config.app_name, logo_url=app_details_config.logo_url, dark_logo_url=app_details_config.dark_logo_url, brand_color=app_details_config.brand_color, ) return reset_token: str | None = None try: reset_token = _get_search_param(query, 'reset_token') token = jwt.ResetToken.verify(reset_token, self.signing_key) email_password_client = email_password.Client( db=self.db, ) is_valid = ( await email_password_client.validate_reset_secret( token.subject, token.secret ) is not None ) except Exception: is_valid = False app_details_config = self._get_app_details_config() response.status = http.HTTPStatus.OK response.content_type = b'text/html' response.body = ui.render_reset_password_page( base_path=self.base_path, provider_name=password_provider.name, is_valid=is_valid, redirect_to=ui_config.redirect_to, reset_token=reset_token, challenge=challenge, error_message=error_message, is_code_flow=False, email=None, app_name=app_details_config.app_name, logo_url=app_details_config.logo_url, dark_logo_url=app_details_config.dark_logo_url, brand_color=app_details_config.brand_color, ) async def handle_ui_verify( self, request: protocol.HttpRequest, response: protocol.HttpResponse, ) -> None: error_messages: list[str] = [] is_valid = True is_code_method: bool = False ui_config = self._get_ui_config() if ui_config is None: response.status = http.HTTPStatus.NOT_FOUND response.body = b'Auth UI not enabled' return query = urllib.parse.parse_qs( request.url.query.decode("ascii") if request.url.query else "" ) maybe_provider_name = _maybe_get_search_param(query, "provider") provider: ( config.WebAuthnProvider | config.EmailPasswordProviderConfig | None ) = None # Decide flow by provider config match maybe_provider_name: case "builtin::local_emailpassword": provider = self._get_password_provider() if provider is None: raise errors.MissingConfiguration(
"ext::auth::AuthConfig::providers", "EmailPassword provider is not configured", ) is_code_method = provider.verification_method == "Code" case "builtin::local_webauthn": provider = self._get_webauthn_provider() if provider is None: raise errors.MissingConfiguration( "ext::auth::AuthConfig::providers", "WebAuthn provider is not configured", ) is_code_method = provider.verification_method == "Code" case _: raise errors.InvalidData( f"Unknown provider: {maybe_provider_name}" ) if is_code_method: email = _get_search_param(query, "email") error_message = _maybe_get_search_param(query, "error") challenge = _get_pkce_challenge( cookies=request.cookies, response=response, query_dict=query, ) if challenge is None: raise errors.InvalidData( 'Missing "challenge" in email verification request' ) app_details_config = self._get_app_details_config() response.status = http.HTTPStatus.OK response.content_type = b'text/html' response.body = ui.render_email_verification_page_code_flow( challenge=challenge, email=email, provider=maybe_provider_name, base_path=self.base_path, callback_url=( ui_config.redirect_to_on_signup or ui_config.redirect_to ), error_message=error_message, app_name=app_details_config.app_name, logo_url=app_details_config.logo_url, dark_logo_url=app_details_config.dark_logo_url, brand_color=app_details_config.brand_color, ) return maybe_pkce_code: str | None = None redirect_to = ui_config.redirect_to_on_signup or ui_config.redirect_to maybe_verification_token = _maybe_get_search_param( query, "verification_token" ) match (maybe_provider_name, maybe_verification_token): case (None, None): error_messages.append( "Missing provider and email verification token." 
) is_valid = False case (None, _): error_messages.append("Missing provider.") is_valid = False case (_, None): error_messages.append("Missing email verification token.") is_valid = False case (str(provider_name), str(verification_token)): try: token = jwt.VerificationToken.verify( verification_token, self.signing_key, ) await self._try_verify_email( provider=provider_name, identity_id=token.subject, ) match token.maybe_challenge: case str(ch): await pkce.create(self.db, ch) maybe_pkce_code = ( await pkce.link_identity_challenge( self.db, token.subject, ch, ) ) case _: maybe_pkce_code = None redirect_to = token.maybe_redirect_to or redirect_to redirect_to = ( util.join_url_params( redirect_to, { "code": maybe_pkce_code, }, ) if maybe_pkce_code else redirect_to ) except errors.VerificationTokenExpired: app_details_config = self._get_app_details_config() response.status = http.HTTPStatus.OK response.content_type = b"text/html" response.body = ui.render_email_verification_expired_page( verification_token=verification_token, app_name=app_details_config.app_name, logo_url=app_details_config.logo_url, dark_logo_url=app_details_config.dark_logo_url, brand_color=app_details_config.brand_color, ) return except Exception as ex: error_messages.append(repr(ex)) is_valid = False # Only redirect back if verification succeeds if is_valid: return self._try_redirect(response, redirect_to) app_details_config = self._get_app_details_config() response.status = http.HTTPStatus.OK response.content_type = b'text/html' response.body = ui.render_email_verification_page_link_flow( verification_token=maybe_verification_token, is_valid=is_valid, error_messages=error_messages, app_name=app_details_config.app_name, logo_url=app_details_config.logo_url, dark_logo_url=app_details_config.dark_logo_url, brand_color=app_details_config.brand_color, ) async def handle_ui_resend_verification( self, request: protocol.HttpRequest, response: protocol.HttpResponse, ) -> None: query = urllib.parse.parse_qs( 
request.url.query.decode("ascii") if request.url.query else "" ) ui_config = self._get_ui_config() password_provider = ( self._get_password_provider() if ui_config is not None else None ) is_valid = True if password_provider is None: response.status = http.HTTPStatus.NOT_FOUND response.body = b'Password provider not configured' return try: _check_keyset(query, {"verification_token"}) verification_token = query["verification_token"][0] token = jwt.VerificationToken.verify( verification_token, self.signing_key, skip_expiration_check=True, ) email_password_client = email_password.Client(self.db) email_factor = ( await email_password_client.get_email_factor_by_identity_id( token.subject ) ) if email_factor is None: raise errors.NoIdentityFound( "Could not find email for provided identity" ) verify_url = f"{self.base_path}/ui/verify" verification_token = self._make_verification_token( identity_id=token.subject, verify_url=verify_url, maybe_challenge=token.maybe_challenge, maybe_redirect_to=token.maybe_redirect_to, ) await self._send_verification_email( provider=password_provider.name, verification_token=verification_token, to_addr=email_factor.email, verify_url=verify_url, ) except Exception: is_valid = False app_details_config = self._get_app_details_config() response.status = http.HTTPStatus.OK response.content_type = b"text/html" response.body = ui.render_resend_verification_done_page( is_valid=is_valid, verification_token=_maybe_get_search_param( query, "verification_token" ), app_name=app_details_config.app_name, logo_url=app_details_config.logo_url, dark_logo_url=app_details_config.dark_logo_url, brand_color=app_details_config.brand_color, ) async def handle_ui_magic_link_sent( self, request: protocol.HttpRequest, response: protocol.HttpResponse, ) -> None: ui_config = self._get_ui_config() if ui_config is None: response.status = http.HTTPStatus.NOT_FOUND response.body = b'Auth UI not enabled' return app_details = self._get_app_details_config() query = 
urllib.parse.parse_qs( request.url.query.decode("ascii") if request.url.query else "" ) # Use magic link provider to decide display mode magic_link_provider = None providers = util.maybe_get_config( self.db, "ext::auth::AuthConfig::providers", frozenset, ) if providers is not None: for p in providers: if getattr(p, 'name', None) == 'builtin::local_magic_link': magic_link_provider = p break is_code_method = ( getattr(magic_link_provider, 'verification_method', 'Link') == 'Code' ) if is_code_method: is_signup = _maybe_get_search_param(query, "signup") == "true" email = _get_search_param(query, "email") challenge = _get_pkce_challenge( cookies=request.cookies, response=response, query_dict=query ) if challenge is None: response.status = http.HTTPStatus.BAD_REQUEST response.body = b'Missing challenge from cookie or query params' return error_message = _maybe_get_search_param(query, "error") callback_url = ui_config.redirect_to if is_signup and ui_config.redirect_to_on_signup: callback_url = ui_config.redirect_to_on_signup content = ui.render_magic_link_sent_page_code_flow( app_name=app_details.app_name, logo_url=app_details.logo_url, dark_logo_url=app_details.dark_logo_url, brand_color=app_details.brand_color, email=email, challenge=challenge, callback_url=callback_url, error_message=error_message, ) else: content = ui.render_magic_link_sent_page_link_flow( app_name=app_details.app_name, logo_url=app_details.logo_url, dark_logo_url=app_details.dark_logo_url, brand_color=app_details.brand_color, ) response.status = http.HTTPStatus.OK response.content_type = b"text/html" response.body = content def _get_webhook_config(self) -> list[config.WebhookConfig]: raw_webhook_configs = util.get_config( self.db, "ext::auth::AuthConfig::webhooks", frozenset, ) return [ config.WebhookConfig( events=raw_config.events, url=raw_config.url, signing_secret_key=raw_config.signing_secret_key, ) for raw_config in raw_webhook_configs ] async def _maybe_send_webhook(self, event: webhook.Event) 
-> None: webhook_configs = self._get_webhook_config() for webhook_config in webhook_configs: if event.event_type in webhook_config.events: request_id = await webhook.send( db=self.db, url=webhook_config.url, secret=webhook_config.signing_secret_key, event=event, ) logger.info( f"Sent webhook request {request_id} " f"to {webhook_config.url} for event {event!r}" ) async def _handle_otc_initiated( self, identity_id: str, email_factor_id: str, otc_id: str, one_time_code: str, ) -> None: metrics.otc_initiated_total.inc(1.0, self.tenant.get_instance_name()) await self._maybe_send_webhook( webhook.OneTimeCodeRequested( event_id=str(uuid.uuid4()), timestamp=datetime.datetime.now(datetime.timezone.utc), identity_id=identity_id, email_factor_id=email_factor_id, otc_id=str(otc_id), one_time_code=one_time_code, ) ) logger.info( f"OTC initiated: identity_id={identity_id}, otc_id={otc_id}" ) async def _handle_otc_verified( self, identity_id: str, email_factor_id: str, otc_id: str ) -> None: metrics.otc_verified_total.inc(1.0, self.tenant.get_instance_name()) await self._maybe_send_webhook( webhook.OneTimeCodeVerified( event_id=str(uuid.uuid4()), timestamp=datetime.datetime.now(datetime.timezone.utc), identity_id=identity_id, email_factor_id=email_factor_id, otc_id=str(otc_id), ) ) logger.info(f"OTC verified: identity_id={identity_id}, otc_id={otc_id}") def _handle_otc_failed(self, ex: Exception) -> None: match ex: case errors.OTCRateLimited(): failure_reason = "rate_limited" case errors.OTCInvalidCode(): failure_reason = "invalid_code" case errors.OTCExpired(): failure_reason = "expired" case errors.OTCVerificationFailed(): failure_reason = "verification_failed" case _: failure_reason = "unknown" metrics.otc_failed_total.inc( 1.0, self.tenant.get_instance_name(), failure_reason ) logger.info(f"OTC verification failed: reason={failure_reason}") def _get_callback_url(self) -> str: return f"{self.base_path}/callback" def _get_data_from_request( self, request: protocol.HttpRequest, 
) -> dict[Any, Any]: content_type = request.content_type match content_type: case b"application/x-www-form-urlencoded": return { k: v[0] for k, v in urllib.parse.parse_qs( request.body.decode('ascii') ).items() } case b"application/json": data = json.loads(request.body) if not isinstance(data, dict): raise errors.InvalidData( "Invalid json data, expected an object" ) return data case _: raise errors.InvalidData( f"Unsupported Content-Type: {content_type!r}" ) def _get_ui_config(self) -> config.UIConfig: return cast( config.UIConfig, util.maybe_get_config( self.db, "ext::auth::AuthConfig::ui", CompositeConfigType ), ) def _get_app_details_config(self) -> config.AppDetailsConfig: return util.get_app_details_config(self.db) def _get_password_provider( self, ) -> Optional[config.EmailPasswordProviderConfig]: providers = cast( list[config.ProviderConfig], util.get_config( self.db, "ext::auth::AuthConfig::providers", frozenset, ), ) password_providers = [ cast(config.EmailPasswordProviderConfig, p) for p in providers if (p.name == 'builtin::local_emailpassword') ] return password_providers[0] if len(password_providers) == 1 else None def _get_webauthn_provider(self) -> config.WebAuthnProvider | None: providers = cast( list[config.ProviderConfig], util.get_config( self.db, "ext::auth::AuthConfig::providers", frozenset, ), ) webauthn_providers = cast( list[config.WebAuthnProviderConfig], [p for p in providers if (p.name == 'builtin::local_webauthn')], ) if len(webauthn_providers) == 1: provider = webauthn_providers[0] return config.WebAuthnProvider( name=provider.name, relying_party_origin=provider.relying_party_origin, require_verification=provider.require_verification, verification_method=provider.verification_method, ) else: return None def _make_verification_token( self, identity_id: str, verify_url: str, maybe_challenge: str | None, maybe_redirect_to: str | None, *, maybe_provider: str | None = None, ) -> str: if not self._is_url_allowed(verify_url): raise
errors.InvalidData( "Verify URL does not match any allowed URLs.", ) return jwt.VerificationToken( subject=identity_id, verify_url=verify_url, maybe_challenge=maybe_challenge, maybe_redirect_to=maybe_redirect_to, ).sign(self.signing_key) async def _send_verification_email( self, *, verification_token: str, verify_url: str, provider: str, to_addr: str, ) -> None: client: email_password.Client | webauthn.Client | None = None match provider: case "builtin::local_emailpassword": client = email_password.Client(db=self.db) case "builtin::local_webauthn": client = webauthn.Client(self.db) if client is not None: if client.config.verification_method == "Code": email_factor = await client.get_email_factor_by_email(to_addr) if email_factor is not None: code, otc_id = await otc.create( self.db, str(email_factor.id), datetime.timedelta(minutes=10), ) await auth_emails.send_one_time_code_email( db=self.db, tenant=self.tenant, to_addr=to_addr, code=code, test_mode=self.test_mode, ) await self._handle_otc_initiated( identity_id=email_factor.identity.id, email_factor_id=str(email_factor.id), otc_id=str(otc_id), one_time_code=code, ) logger.info( "Sent OTC verification email: " f"email={to_addr}, otc_id={otc_id}" ) return await auth_emails.send_verification_email( db=self.db, tenant=self.tenant, to_addr=to_addr, verification_token=verification_token, provider=provider, verify_url=verify_url, test_mode=self.test_mode, ) async def _try_verify_email( self, provider: str, identity_id: str ) -> EmailFactor: current_time = datetime.datetime.now(datetime.timezone.utc) client: email_password.Client | webauthn.Client match provider: case "builtin::local_emailpassword": client = email_password.Client(db=self.db) case "builtin::local_webauthn": client = webauthn.Client(self.db) case _: raise errors.InvalidData( f"Unknown provider: {provider}", ) updated = await client.verify_email(identity_id, current_time) if updated is None: raise errors.NoIdentityFound( "Could not verify email for identity" 
f" {identity_id}. This email address may not exist" " in our system, or it might already be verified." ) return updated def _is_url_allowed(self, url: str) -> bool: allowed_urls = util.get_config( self.db, "ext::auth::AuthConfig::allowed_redirect_urls", frozenset, ) allowed_urls = cast(frozenset[str], allowed_urls).union( {self.base_path} ) ui_config = self._get_ui_config() if ui_config: allowed_urls = allowed_urls.union({ui_config.redirect_to}) if ui_config.redirect_to_on_signup: allowed_urls = allowed_urls.union( {ui_config.redirect_to_on_signup} ) lower_url = url.lower() for allowed_url in allowed_urls: lower_allowed_url = allowed_url.lower() if lower_url.startswith(lower_allowed_url): return True parsed_allowed_url = urllib.parse.urlparse(lower_allowed_url) allowed_domain = parsed_allowed_url.netloc allowed_path = parsed_allowed_url.path parsed_lower_url = urllib.parse.urlparse(lower_url) lower_domain = parsed_lower_url.netloc lower_path = parsed_lower_url.path if ( lower_domain == allowed_domain or lower_domain.endswith('.' 
+ allowed_domain) ) and lower_path.startswith(allowed_path): return True return False def _do_redirect( self, response: protocol.HttpResponse, allowed_url: AllowedUrl ) -> None: response.status = http.HTTPStatus.FOUND response.custom_headers["Location"] = allowed_url.url def _try_redirect(self, response: protocol.HttpResponse, url: str) -> None: allowed_url = self._make_allowed_url(url) self._do_redirect(response, allowed_url) def _make_allowed_url(self, url: str) -> AllowedUrl: if not self._is_url_allowed(url): raise errors.InvalidData( "Redirect URL does not match any allowed URLs.", ) return AllowedUrl(url) def _maybe_make_allowed_url( self, url: Optional[str] ) -> Optional[AllowedUrl]: return self._make_allowed_url(url) if url else None @dataclasses.dataclass class AllowedUrl: url: str def map(self, f: Callable[[str], str]) -> "AllowedUrl": return AllowedUrl(f(self.url)) def _fail_with_error( *, response: protocol.HttpResponse, status: http.HTTPStatus, ex: Exception, exc_info: bool = False, ) -> None: err_dct = { "message": str(ex), "type": str(ex.__class__.__name__), } logger.error( f"Failed to handle HTTP request: {err_dct!r}", exc_info=exc_info ) response.body = json.dumps({"error": err_dct}).encode() response.status = status def _maybe_get_search_param( query_dict: dict[str, list[str]], key: str, *, fallback_keys: Optional[list[str]] = None, ) -> str | None: params = query_dict.get(key) if params is None and fallback_keys is not None: for fallback_key in fallback_keys: params = query_dict.get(fallback_key) if params is not None: break return params[0] if params else None def _get_search_param( query_dict: dict[str, list[str]], key: str, *, fallback_keys: Optional[list[str]] = None, ) -> str: val = _maybe_get_search_param(query_dict, key, fallback_keys=fallback_keys) if val is None: raise errors.InvalidData(f"Missing query parameter: {key}") return val def _maybe_get_form_field( form_dict: dict[str, list[str]], key: str ) -> str | None: maybe_val = 
form_dict.get(key) if maybe_val is None: return None return maybe_val[0] def _get_pkce_challenge( *, response: protocol.HttpResponse, cookies: http.cookies.SimpleCookie, query_dict: dict[str, list[str]], ) -> str | None: cookie_name = 'edgedb-pkce-challenge' challenge: str | None = _maybe_get_search_param(query_dict, 'challenge') if challenge is not None: logger.info( f"PKCE challenge found in query param 'challenge': {challenge!r}" ) else: challenge = _maybe_get_search_param(query_dict, "code_challenge") if challenge is not None: logger.info( "PKCE challenge found in query param 'code_challenge':" f" {challenge!r}" ) if challenge is not None: _set_cookie(response, cookie_name, challenge) else: if cookie_name in cookies: challenge = cookies[cookie_name].value logger.info(f"PKCE challenge found in cookie: {challenge!r}") else: logger.info("No PKCE challenge found in query params or cookies.") logger.info(f"Query params: {query_dict}") logger.info(f"Cookies: {cookies}") return challenge def _set_cookie( response: protocol.HttpResponse, name: str, value: str, *, http_only: bool = True, secure: bool = True, same_site: str = "Strict", path: Optional[str] = None, ) -> None: val: http.cookies.Morsel[str] = http.cookies.SimpleCookie({name: value})[ name ] val["httponly"] = http_only val["secure"] = secure val["samesite"] = same_site if path is not None: val["path"] = path response.custom_headers["Set-Cookie"] = val.OutputString() def _check_keyset(candidate: dict[str, Any], keyset: set[str]) -> None: missing_fields = [field for field in keyset if field not in candidate] if missing_fields: raise errors.InvalidData( f"Missing required fields: {', '.join(missing_fields)}" ) def _accepts_json(request: protocol.HttpRequest) -> bool: return request.accept == b"application/json" ================================================ FILE: edb/server/protocol/auth_ext/jwt.py ================================================ # # This source file is part of the 
EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import dataclasses import datetime from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand from cryptography.hazmat.backends import default_backend from typing import Any, Callable from edb.server import auth as jwt_auth from edb.ir import statypes as statypes from . import errors VALIDATION_TOKEN_DEFAULT_EXPIRATION = datetime.timedelta(seconds=24 * 60 * 60) RESET_TOKEN_DEFAULT_EXPIRATION = datetime.timedelta(minutes=10) OAUTH_STATE_TOKEN_DEFAULT_EXPIRATION = datetime.timedelta(minutes=5) class SigningKey: subkeys: dict[str | None, jwt_auth.JWKSet] def __init__( self, key_fetch: Callable[[], str], issuer: str, *, is_key_for_testing: bool = False, ): self.key = "" self.key_fetch = key_fetch self.issuer = issuer self.subkeys = {} self.__is_key_for_testing = is_key_for_testing def subkey(self, info: str | None = None) -> jwt_auth.JWKSet: # Clear keycache if the key has changed current_key = self.key_fetch() if current_key != self.key: self.key = current_key self.subkeys = {} if info in self.subkeys: return self.subkeys[info] if info is None: key = jwt_auth.JWKSet.from_hs256_key(self.key.encode()) else: key = jwt_auth.JWKSet.from_hs256_key(derive_key_raw(self.key, info)) key.default_validation_context.require_expiry() key.default_validation_context.allow("iss", [self.issuer]) 
self.subkeys[info] = key return key def sign( self, info: str | None, claims: dict[str, str | None], *, ctx: jwt_auth.SigningCtx, ) -> str: # Remove any None values from the claims claims = {k: v for k, v in claims.items() if v is not None} if self.__is_key_for_testing: claims["__test__"] = str(info) return self.subkey(info).sign(claims, ctx=ctx) def validate( self, token: str, info: str | None = None, skip_expiration_check: bool = False, ) -> dict[str, Any]: key = self.subkey(info) try: ctx = None if skip_expiration_check: ctx = jwt_auth.ValidationCtx() ctx.ignore_expiry() ctx.allow("iss", [self.issuer]) return key.validate(token, ctx=ctx) except Exception as e: raise errors.InvalidData(f"Invalid token: {e}") from e def verify_str(cls: Any, claims: dict[str, Any], key: str) -> str: value = claims.get(key, None) if isinstance(value, str): return value raise errors.InvalidData(f"Invalid '{cls.__name__}'") def verify_str_opt(cls: Any, claims: dict[str, Any], key: str) -> str | None: value = claims.get(key, None) if isinstance(value, str): return value if value is None: return None raise errors.InvalidData(f"Invalid '{cls.__name__}'") @dataclasses.dataclass class MagicLinkToken: """ A token that can be used to verify a magic link sent to a user via email. Expiration is controlled by the provider parameter `token_time_to_live`. 
""" subject: str callback_url: str challenge: str def sign( self, signing_key: SigningKey, expires_in: datetime.timedelta ) -> str: signing_ctx = jwt_auth.SigningCtx() signing_ctx.set_expiry(int(expires_in.total_seconds())) signing_ctx.set_not_before(30) signing_ctx.set_issuer(signing_key.issuer) return signing_key.sign( "magic_link", { "sub": self.subject, "callback_url": self.callback_url, "challenge": self.challenge, }, ctx=signing_ctx, ) @classmethod def verify(cls, token: str, signing_key: SigningKey) -> 'MagicLinkToken': claims = signing_key.validate(token, "magic_link") identity_id = verify_str(cls, claims, 'sub') challenge = verify_str(cls, claims, 'challenge') callback_url = verify_str(cls, claims, 'callback_url') return MagicLinkToken( subject=identity_id, callback_url=callback_url, challenge=challenge, ) @dataclasses.dataclass class ResetToken: """ A token that can be used to verify a password reset request. """ subject: str secret: str challenge: str def sign( self, signing_key: SigningKey, expires_in: datetime.timedelta = RESET_TOKEN_DEFAULT_EXPIRATION, ) -> str: signing_ctx = jwt_auth.SigningCtx() signing_ctx.set_expiry(int(expires_in.total_seconds())) signing_ctx.set_not_before(30) signing_ctx.set_issuer(signing_key.issuer) return signing_key.sign( "reset", { "sub": self.subject, "secret": self.secret, "challenge": self.challenge, }, ctx=signing_ctx, ) @classmethod def verify(cls, token: str, signing_key: SigningKey) -> 'ResetToken': claims = signing_key.validate(token, "reset") return ResetToken( subject=verify_str(cls, claims, 'sub'), secret=verify_str(cls, claims, 'secret'), challenge=verify_str(cls, claims, 'challenge'), ) @dataclasses.dataclass class VerificationToken: """ A token that can be used to verify a user's email address. Note that we allow expired tokens to trigger a resend of the verification email, but not to verify the email address. 
""" subject: str verify_url: str maybe_challenge: str | None maybe_redirect_to: str | None def sign( self, signing_key: SigningKey, expires_in: datetime.timedelta = VALIDATION_TOKEN_DEFAULT_EXPIRATION, ) -> str: signing_ctx = jwt_auth.SigningCtx() signing_ctx.set_expiry(int(expires_in.total_seconds())) signing_ctx.set_not_before(30) signing_ctx.set_issuer(signing_key.issuer) return signing_key.sign( "verification", { "sub": self.subject, "verify_url": self.verify_url, "challenge": self.maybe_challenge, "redirect_to": self.maybe_redirect_to, }, ctx=signing_ctx, ) @classmethod def verify( cls, token: str, signing_key: SigningKey, skip_expiration_check: bool = False, ) -> 'VerificationToken': claims = signing_key.validate( token, "verification", skip_expiration_check=skip_expiration_check, ) return VerificationToken( subject=verify_str(cls, claims, 'sub'), verify_url=verify_str(cls, claims, 'verify_url'), maybe_challenge=verify_str_opt(cls, claims, 'challenge'), maybe_redirect_to=verify_str_opt(cls, claims, 'redirect_to'), ) @dataclasses.dataclass class SessionToken: """ The token representing an auth session for a user. Expiration is controlled by the database parameter `ext::auth::AuthConfig::token_time_to_live`. """ subject: str def sign( self, signing_key: SigningKey, expires_in: datetime.timedelta, ) -> str: signing_ctx = jwt_auth.SigningCtx() signing_ctx.set_expiry(int(expires_in.total_seconds())) signing_ctx.set_not_before(30) signing_ctx.set_issuer(signing_key.issuer) return signing_key.sign( None, { "sub": self.subject, }, ctx=signing_ctx, ) @classmethod def verify(cls, token: str, signing_key: SigningKey) -> 'SessionToken': claims = signing_key.validate(token, None) return SessionToken( subject=verify_str(cls, claims, 'sub'), ) @dataclasses.dataclass class OAuthStateToken: """ The token representing an OAuth state passed to the identity provider. 
It allows the auth extension server to reference data from the original authorize request, such as the provider, application redirect URLs, PKCE challenge, and OAuth callback URL. """ provider: str redirect_to: str challenge: str redirect_uri: str redirect_to_on_signup: str | None = None def sign( self, signing_key: SigningKey, expires_in: datetime.timedelta = OAUTH_STATE_TOKEN_DEFAULT_EXPIRATION, ) -> str: signing_ctx = jwt_auth.SigningCtx() signing_ctx.set_expiry(int(expires_in.total_seconds())) signing_ctx.set_not_before(30) signing_ctx.set_issuer(signing_key.issuer) return signing_key.sign( "state", { "provider": self.provider, "redirect_to": self.redirect_to, "redirect_to_on_signup": self.redirect_to_on_signup, "challenge": self.challenge, "redirect_uri": self.redirect_uri, }, ctx=signing_ctx, ) @classmethod def verify(cls, token: str, signing_key: SigningKey) -> 'OAuthStateToken': claims = signing_key.validate(token, "state") return OAuthStateToken( provider=verify_str(cls, claims, 'provider'), redirect_to=verify_str(cls, claims, 'redirect_to'), redirect_to_on_signup=verify_str_opt( cls, claims, 'redirect_to_on_signup' ), challenge=verify_str(cls, claims, 'challenge'), redirect_uri=verify_str(cls, claims, 'redirect_uri'), ) def derive_key_raw(key: str, info: str) -> bytes: """Derive a new key from the given symmetric key using HKDF.""" input_key_material = key.encode() backend = default_backend() hkdf = HKDFExpand( algorithm=hashes.SHA256(), length=32, info=info.encode("utf-8"), backend=backend, ) new_key_bytes = hkdf.derive(input_key_material) return new_key_bytes ================================================ FILE: edb/server/protocol/auth_ext/local.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2023-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import datetime import json from typing import Any, cast from . import data, util class Client: def __init__(self, db: Any): self.db = db async def verify_email( self, identity_id: str, verified_at: datetime.datetime ) -> data.EmailFactor | None: result_bytes = await util.json_query( db=self.db, query="""\ with LOCAL_IDENTITY := $identity_id, verified_at := $verified_at, UPDATED := ( update ext::auth::EmailFactor filter .identity = LOCAL_IDENTITY and not exists .verified_at ?? 
false set { verified_at := verified_at } ), select UPDATED {**}; """, variables={ "identity_id": identity_id, "verified_at": verified_at.isoformat(), }, ) result_json = json.loads(result_bytes.decode()) if len(result_json) == 0: return None factor = result_json[0] return data.EmailFactor(**factor) async def get_email_factor_by_identity_id( self, identity_id: str ) -> data.EmailFactor | None: r = await util.json_query( self.db, """ select ext::auth::EmailFactor { ** } filter .identity.id = $identity_id; """, variables={"identity_id": identity_id}, ) result_json = json.loads(r.decode()) if len(result_json) == 0: return None assert len(result_json) == 1 factor = result_json[0] return data.EmailFactor(**factor) async def get_verified_by_identity_id(self, identity_id: str) -> str | None: r = await util.json_query( self.db, """ select ext::auth::EmailFactor { verified_at, } filter .identity.id = $identity_id; """, variables={"identity_id": identity_id}, ) result_json = json.loads(r.decode()) if len(result_json) == 0: return None assert len(result_json) == 1 return cast(str, result_json[0]["verified_at"]) async def get_identity_id_by_email( self, email: str, *, factor_type: str = 'EmailFactor' ) -> str | None: r = await util.json_query( self.db, f""" with email := $email, identity := ( select ext::auth::LocalIdentity filter . data.EmailFactor | None: r = await util.json_query( self.db, """ select ext::auth::EmailFactor { ** } filter .email = $email; """, variables={"email": email}, ) result_json = json.loads(r.decode()) if len(result_json) == 0: return None assert len(result_json) == 1 factor = result_json[0] return data.EmailFactor(**factor) ================================================ FILE: edb/server/protocol/auth_ext/magic_link.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2024-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import logging import aiosmtplib import json from typing import Any, cast from edb import errors as edb_errors from edb.common import debug from . import config, data, errors, jwt, util, local, email as auth_emails logger = logging.getLogger('edb.server') class Client(local.Client): def __init__( self, db: Any, tenant: Any, test_mode: bool, issuer: str, signing_key: jwt.SigningKey, ): super().__init__(db) self.tenant = tenant self.test_mode = test_mode self.issuer = issuer self.provider = self._get_provider() self.signing_key = signing_key def _get_provider(self) -> config.MagicLinkProviderConfig: provider_name = "builtin::local_magic_link" provider_client_config = cast( list[config.ProviderConfig], util.get_config( self.db, "ext::auth::AuthConfig::providers", frozenset ), ) for cfg in provider_client_config: if cfg.name == provider_name: cfg = cast(config.MagicLinkProviderConfig, cfg) return config.MagicLinkProviderConfig( name=cfg.name, token_time_to_live=cfg.token_time_to_live, verification_method=cfg.verification_method, auto_signup=cfg.auto_signup, ) raise errors.MissingConfiguration( provider_name, "Provider is not configured" ) async def register(self, email: str) -> data.EmailFactor: try: result = await util.json_query( self.db, """ with email := $email, identity := (insert ext::auth::LocalIdentity { issuer := "local", subject := "", }), email_factor := (insert ext::auth::MagicLinkFactor { email := email, identity := 
identity, }) select email_factor { ** };""", variables={ "email": email, }, ) except edb_errors.ConstraintViolationError: raise errors.UserAlreadyRegistered() result_json = json.loads(result.decode()) assert len(result_json) == 1 factor_dict = result_json[0] return data.EmailFactor(**factor_dict) def make_magic_link_token( self, *, identity_id: str, callback_url: str, challenge: str, ) -> str: return jwt.MagicLinkToken( subject=identity_id, callback_url=callback_url, challenge=challenge, ).sign( self.signing_key, expires_in=self.provider.token_time_to_live.to_timedelta(), ) async def send_magic_link( self, *, email: str, link_url: str, token: str, redirect_on_failure: str, ) -> None: link = util.join_url_params( link_url, { "token": token, "redirect_on_failure": redirect_on_failure, }, ) try: await auth_emails.send_magic_link_email( db=self.db, tenant=self.tenant, to_addr=email, link=link, test_mode=self.test_mode, ) except aiosmtplib.SMTPException as ex: if not debug.flags.server: logger.warning( "Failed to send magic link via SMTP", exc_info=True ) raise edb_errors.InternalServerError( "Failed to send magic link email, please try again later." ) from ex ================================================ FILE: edb/server/protocol/auth_ext/oauth.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import json from typing import cast, Any, Callable from edb.server.http import HttpClient from . import github, google, azure, apple, discord, slack from . import config, errors, util, data, base class Client: provider: base.BaseProvider def __init__( self, *, db: Any, provider_name: str, http_client: HttpClient, url_munger: Callable[[str], str] | None = None, ): self.db = db http_factory = lambda *args, **kwargs: http_client.with_context( *args, url_munger=url_munger, **kwargs ) provider_config = self._get_provider_config(provider_name) provider_args: tuple[str, str] | tuple[str, str, str, str] = ( provider_config.client_id, provider_config.secret, ) provider_kwargs = { "http_factory": http_factory, "additional_scope": provider_config.additional_scope, } match (provider_name, provider_config.issuer_url): case ("builtin::oauth_github", _): self.provider = github.GitHubProvider( *provider_args, **provider_kwargs, ) case ("builtin::oauth_google", _): self.provider = google.GoogleProvider( *provider_args, **provider_kwargs, ) case ("builtin::oauth_azure", _): self.provider = azure.AzureProvider( *provider_args, **provider_kwargs, ) case ("builtin::oauth_apple", _): self.provider = apple.AppleProvider( *provider_args, **provider_kwargs, ) case ("builtin::oauth_discord", _) if isinstance( provider_config, config.DiscordOAuthProviderConfig ): self.provider = discord.DiscordProvider( provider_config.prompt, *provider_args, **provider_kwargs, ) case ("builtin::oauth_slack", _): self.provider = slack.SlackProvider( *provider_args, **provider_kwargs, ) case (provider_name, str(issuer_url)): self.provider = base.OpenIDConnectProvider( provider_name, issuer_url, *provider_args, **provider_kwargs, ) case _: raise errors.InvalidData(f"Invalid provider: {provider_name}") async def get_authorize_url(self, state: str, redirect_uri: str) -> str: return await self.provider.get_code_url( state=state, redirect_uri=redirect_uri, additional_scope=self.provider.additional_scope or "", ) 
async def handle_callback( self, code: str, redirect_uri: str ) -> tuple[data.Identity, bool, str | None, str | None, str | None]: response = await self.provider.exchange_code(code, redirect_uri) user_info = await self.provider.fetch_user_info(response) auth_token = response.access_token refresh_token = response.refresh_token source_id_token = user_info.source_id_token return ( *(await self._handle_identity(user_info)), auth_token, refresh_token, source_id_token, ) async def _handle_identity( self, user_info: data.UserInfo ) -> tuple[data.Identity, bool]: """Update or create an identity""" r = await util.json_query( db=self.db, query="""\ with iss := $issuer_url, sub := $subject, identity := ( insert ext::auth::Identity { issuer := iss, subject := sub, } unless conflict on ((.issuer, .subject)) else ext::auth::Identity ) select { identity := (select identity {*}), new := (identity not in ext::auth::Identity) };""", variables={ "issuer_url": self.provider.issuer_url, "subject": user_info.sub, }, ) result_json = json.loads(r.decode()) assert len(result_json) == 1 return ( data.Identity(**result_json[0]['identity']), result_json[0]['new'], ) def _get_provider_config( self, provider_name: str ) -> config.OAuthProviderConfig | config.DiscordOAuthProviderConfig: provider_client_config = util.get_config( self.db, "ext::auth::AuthConfig::providers", frozenset ) for cfg in provider_client_config: if cfg.name == provider_name: cfg = cast(config.OAuthProviderConfig, cfg) if provider_name == "builtin::oauth_discord": return config.DiscordOAuthProviderConfig( name=cfg.name, display_name=cfg.display_name, client_id=cfg.client_id, secret=cfg.secret, additional_scope=getattr(cfg, "additional_scope", None), issuer_url=getattr(cfg, "issuer_url", None), logo_url=getattr(cfg, "logo_url", None), prompt=getattr(cfg, "prompt", "consent"), ) return config.OAuthProviderConfig( name=cfg.name, display_name=cfg.display_name, client_id=cfg.client_id, secret=cfg.secret, 
additional_scope=getattr(cfg, "additional_scope", None), issuer_url=getattr(cfg, "issuer_url", None), logo_url=getattr(cfg, "logo_url", None), ) raise errors.MissingConfiguration( provider_name, "Provider is not configured" ) ================================================ FILE: edb/server/protocol/auth_ext/otc.py ================================================ # # This source file is part of the Gel open source project. # # Copyright 2025-present MagicStack Inc. and the Gel authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import hashlib import json import logging import secrets import uuid import datetime from typing import Any from edb.server import defines from . import errors, util logger = logging.getLogger(__name__) MAX_ATTEMPTS = 5 def generate_code() -> str: return f"{secrets.randbelow(1000000):06d}" def hash_code(code: str) -> bytes: return hashlib.sha256(code.encode('utf-8')).digest() async def create( db: Any, factor_id: str, ttl: datetime.timedelta ) -> tuple[str, uuid.UUID]: """ Create a new OneTimeCode object in the database. 
Args: db: Database connection factor_id: The ID of the factor to associate with this code ttl: Time to live for the code Returns: Tuple of (code, otc_id) where code is the plain text code and otc_id is the database ID """ code = generate_code() code_hash = hash_code(code) expires_at = datetime.datetime.now(datetime.timezone.utc) + ttl r = await util.json_query( db=db, query="""\ with ONE_TIME_CODE := (insert ext::auth::OneTimeCode { factor := $factor_id, code_hash := $code_hash, expires_at := $expires_at, }) select ONE_TIME_CODE { id };""", variables={ "factor_id": factor_id, "code_hash": code_hash, "expires_at": expires_at.isoformat(), }, ) result_json = json.loads(r.decode()) if len(result_json) != 1: raise errors.InvalidData("Failed to create OneTimeCode") return code, uuid.UUID(result_json[0]["id"]) async def verify(db: Any, factor_id: str, code: str) -> str: """ Verify a one-time code for a given factor. This function performs the core verification operations in a single atomic query: - Check rate limits - Find and validate the code - Delete the code if valid - Record the authentication attempt Expired codes are then cleaned up in a separate, best-effort query whose failure is ignored. Args: db: Database connection factor_id: The ID of the factor code: The code to verify Returns: The OneTimeCode ID if verification succeeds Raises: OTCRateLimited: If maximum verification attempts exceeded OTCInvalidCode: If the code is invalid or not found OTCExpired: If the code has expired OTCVerificationFailed: If verification fails for other reasons """ code_hash = hash_code(code) # N.B: I (sully) don't want to make this RepeatableRead because I # worry that it would allow bypassing MAX_ATTEMPTS by performing # many requests "at the same time". We need to analyze this before # we change it. 
r = await util.json_query( db=db, query="""\ with factor_id := $factor_id, FACTOR := (select ext::auth::Factor filter .id = factor_id), code_hash := $code_hash, MAX_ATTEMPTS := $max_attempts, now := datetime_current(), window_start := now - '10 minutes', # Check rate limits failed_attempts := ( select count( select ext::auth::AuthenticationAttempt filter .factor = FACTOR and .attempt_type = ext::auth::AuthenticationAttemptType.OneTimeCode and .successful = false and .created_at > window_start ) ), # Find the OTC otc := ( select ext::auth::OneTimeCode filter .factor = FACTOR and .code_hash = code_hash limit 1 ), is_rate_limited := failed_attempts >= MAX_ATTEMPTS, is_code_found := exists otc, is_code_expired := (otc.expires_at < now) ?? false, is_code_valid := ( is_code_found and not is_code_expired and not is_rate_limited ), # Delete OTC if valid (side effect) deleted_otc := ( delete ext::auth::OneTimeCode filter .id = otc.id and is_code_valid ), # Record attempt (side effect) recorded_attempt := ( if (exists FACTOR) then ( insert ext::auth::AuthenticationAttempt { factor := FACTOR, attempt_type := ext::auth::AuthenticationAttemptType.OneTimeCode, successful := is_code_valid, } ) else {} ), select { failed_attempts := failed_attempts, success := is_code_valid, rate_limited := is_rate_limited, code_found := is_code_found, code_expired := is_code_expired, otc_id := otc.id, };""", variables={ "factor_id": factor_id, "code_hash": code_hash, "max_attempts": MAX_ATTEMPTS, }, ) result_json = json.loads(r.decode()) result = result_json[0] # Run an OTC GC. We don't really mind if it fails due to # serialization problems. 
try: await util.json_query( db=db, query="""\ with now := datetime_current(), # Cleanup expired codes select count( delete ext::auth::OneTimeCode filter .expires_at < now ) """, tx_isolation=defines.TxIsolationLevel.RepeatableRead, ) except Exception: pass if result["rate_limited"]: raise errors.OTCRateLimited() elif not result["code_found"]: raise errors.OTCInvalidCode() elif result["code_expired"]: raise errors.OTCExpired() elif not result["success"]: raise errors.OTCVerificationFailed() return str(result["otc_id"]) async def cleanup_old_attempts(db: Any, retention_hours: int = 24) -> int: """ Remove authentication attempts older than the retention window. This is intended for scheduled maintenance jobs to prevent unbounded growth of the authentication attempts table. Args: db: Database connection retention_hours: Hours to retain attempts (defaults to 24) Returns: Number of old attempts that were deleted """ r = await util.json_query( db=db, query="""\ with cutoff_time := datetime_current() - $retention_duration, old_attempts := ( delete ext::auth::AuthenticationAttempt filter .created_at < cutoff_time ) select count(old_attempts);""", variables={ "retention_duration": f"{retention_hours} hours", }, ) result_json = json.loads(r.decode()) return result_json[0] if result_json else 0 ================================================ FILE: edb/server/protocol/auth_ext/pkce.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2023-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import typing import asyncio import json import logging import dataclasses from edb.ir import statypes from edb.server import defines from . import util, errors if typing.TYPE_CHECKING: from edb.server import server as edbserver from edb.server import tenant as edbtenant logger = logging.getLogger("edb.server") VALIDITY = statypes.Duration.from_microseconds(10 * 60_000_000) # 10 minutes @dataclasses.dataclass(repr=False) class PKCEChallenge: """ Object that represents the ext::auth::PKCEChallenge type """ id: str challenge: str auth_token: str | None refresh_token: str | None id_token: str | None identity_id: str | None async def create(db: edbtenant.dbview.Database, challenge: str) -> None: await util.json_query( db, """ insert ext::auth::PKCEChallenge { challenge := $challenge, } unless conflict on .challenge else (select ext::auth::PKCEChallenge) """, variables={ "challenge": challenge, }, ) async def link_identity_challenge( db: edbtenant.dbview.Database, identity_id: str, challenge: str, ) -> str: r = await util.json_query( db, """ update ext::auth::PKCEChallenge filter .challenge = $challenge set { identity := $identity_id } """, variables={ "challenge": challenge, "identity_id": identity_id, }, ) result_json = json.loads(r.decode()) if len(result_json) != 1: raise errors.PKCEVerificationFailed( f"No linked PKCE session found for challenge '{challenge}'" ) return typing.cast(str, result_json[0]["id"]) async def add_provider_tokens( db: edbtenant.dbview.Database, id: str, auth_token: str | None, refresh_token: str | None, 
id_token: str | None, ) -> str: r = await util.json_query( db, """ update ext::auth::PKCEChallenge filter .id = $id set { auth_token := $auth_token, refresh_token := $refresh_token, id_token := $id_token, } """, variables={ "id": id, "auth_token": auth_token, "refresh_token": refresh_token, "id_token": id_token, }, ) result_json = json.loads(r.decode()) if len(result_json) != 1: raise errors.PKCEVerificationFailed( f"No PKCE session found with id '{id}'" ) return typing.cast(str, result_json[0]["id"]) async def get_by_id(db: edbtenant.dbview.Database, id: str) -> PKCEChallenge: r = await util.json_query( db, """ select ext::auth::PKCEChallenge { id, challenge, auth_token, refresh_token, id_token, identity_id := .identity.id } filter .id = $id and (datetime_current() - .created_at) < $validity; """, variables={"id": id, "validity": VALIDITY.to_backend_str()}, ) result_json = json.loads(r.decode()) if len(result_json) != 1: raise errors.PKCEVerificationFailed( f"No current PKCE session found with id '{id}'" ) return PKCEChallenge(**result_json[0]) async def delete(db: edbtenant.dbview.Database, id: str) -> None: r = await util.json_query( db, """ delete ext::auth::PKCEChallenge filter .id = $id """, variables={"id": id}, ) result_json = json.loads(r.decode()) if len(result_json) != 1: raise errors.PKCEVerificationFailed( f"No PKCE session found with id '{id}'" ) async def _delete_challenge(db: edbtenant.dbview.Database) -> None: if not db.tenant.is_database_connectable(db.name): # Don't run gc if the database is not connectable, e.g. 
being dropped return await util.json_query( db, """ delete ext::auth::PKCEChallenge filter (datetime_of_statement() - .created_at) > $validity """, variables={"validity": VALIDITY.to_backend_str()}, tx_isolation=defines.TxIsolationLevel.RepeatableRead, ) async def _gc(tenant: edbtenant.Tenant) -> None: try: async with asyncio.TaskGroup() as g: for db in tenant.iter_dbs(): if "auth" in db.extensions: g.create_task(_delete_challenge(db)) except Exception as ex: logger.debug( "GC of ext::auth::PKCEChallenge failed (instance: %s)", tenant.get_instance_name(), exc_info=ex, ) async def gc(server: edbserver.BaseServer) -> None: while True: try: tasks = [ tenant.create_task(_gc(tenant), interruptable=True) for tenant in server.iter_tenants() if tenant.accept_new_tasks ] if tasks: await asyncio.wait(tasks) except Exception as ex: logger.debug("GC of ext::auth::PKCEChallenge failed", exc_info=ex) finally: await asyncio.sleep(VALIDITY.to_microseconds() / 1_000_000.0) ================================================ FILE: edb/server/protocol/auth_ext/slack.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2024-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from typing import Any from . 
import base class SlackProvider(base.OpenIDConnectProvider): def __init__(self, *args: Any, **kwargs: Any): super().__init__( "slack", "https://slack.com", *args, **kwargs, ) ================================================ FILE: edb/server/protocol/auth_ext/ui/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import cast, Optional import html import email.message from edb.server.protocol.auth_ext import config as auth_config from . 
import components as render def render_signin_page( *, base_path: str, providers: frozenset[auth_config.ProviderConfig], error_message: Optional[str] = None, email: Optional[str] = None, challenge: str, selected_tab: Optional[str] = None, # config redirect_to: str, redirect_to_on_signup: Optional[str] = None, app_name: Optional[str] = None, logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = None, ) -> bytes: password_provider = None webauthn_provider = None magic_link_provider = None oauth_providers = [] for p in providers: if p.name == 'builtin::local_emailpassword': password_provider = cast(auth_config.EmailPasswordProviderConfig, p) elif p.name == 'builtin::local_webauthn': webauthn_provider = cast(auth_config.WebAuthnProviderConfig, p) elif p.name == 'builtin::local_magic_link': magic_link_provider = cast(auth_config.MagicLinkProviderConfig, p) elif p.name.startswith('builtin::oauth_') or hasattr(p, "issuer_url"): oauth_providers.append(cast(auth_config.OAuthProviderConfig, p)) email_factor_form = render_email_factor_form( challenge=challenge, email=email, selected_tab=selected_tab, single_form_fields=f''' { render.hidden_input( name='redirect_to', value=( redirect_to if webauthn_provider else (base_path + '/ui/magic-link-sent') ), secondary_value=redirect_to, ) } { render.hidden_input( name='redirect_on_failure', value=f'{base_path}/ui/signin', secondary_value=f'{base_path}/ui/signin?selected_tab=password', ) } { render.hidden_input( name='provider', value=magic_link_provider.name if magic_link_provider else '', secondary_value=( password_provider.name if password_provider else '' ), ) } { render.hidden_input(name='callback_url', value=redirect_to) if magic_link_provider else '' } ''', password_form=( render.render_password_form( challenge=challenge, email=email, redirect_to=redirect_to, base_path=base_path, provider_name=password_provider.name, ) if password_provider else None ), webauthn_form=( 
render.render_webauthn_form( challenge=challenge, email=email, redirect_to=redirect_to, base_path=base_path, provider_name=webauthn_provider.name, ) if webauthn_provider else None ), magic_link_form=( render.render_magic_link_form( challenge=challenge, email=email, callback_url=( redirect_to if magic_link_provider.verification_method == "Link" else None ), base_path=base_path, provider_name=magic_link_provider.name, verification_method=magic_link_provider.verification_method, ) if magic_link_provider else None ), magic_link_verification_method=( magic_link_provider.verification_method if magic_link_provider else 'Link' ), ) if email_factor_form: email_factor_form += render.bottom_note( "Don't have an account?", link='Sign up', href='signup' ) oauth_buttons = render.oauth_buttons( oauth_providers=oauth_providers, label_prefix=('Sign in with' if email_factor_form else 'Continue with'), challenge=challenge, redirect_to=redirect_to, redirect_to_on_signup=redirect_to_on_signup, collapsed=email_factor_form is not None and len(oauth_providers) >= 3, ) return render.base_page( title=f'Sign in{f" to {app_name}" if app_name else ""}', logo_url=logo_url, dark_logo_url=dark_logo_url, brand_color=brand_color, cleanup_search_params=['error', 'email', 'selected_tab'], content=f''' {render.title('Sign in', app_name=app_name)} {render.error_message(error_message)} {oauth_buttons} { render.divider if email_factor_form and len(oauth_providers) > 0 else '' } {email_factor_form or ''} {render.script('webauthn-authenticate') if webauthn_provider else ''} ''', ) def render_email_factor_form( *, selected_tab: Optional[str] = None, single_form_fields: str = '', password_form: Optional[str], webauthn_form: Optional[str], magic_link_form: Optional[str], magic_link_verification_method: str = "Link", # used only for slider mode challenge: Optional[str] = None, email: Optional[str] = None, ) -> Optional[str]: match (password_form, webauthn_form, magic_link_form): case (None, None, None): return 
None case (_, None, None): return password_form case (None, _, None): return webauthn_form case (None, None, _): return magic_link_form magic_link_tab_label = render.get_magic_link_tab_label( magic_link_verification_method ) magic_link_button_text = render.get_magic_link_button_text( magic_link_verification_method ) # Determine whether to render tabs (multiple distinct forms) or the # single-form slider (quick factor + password) UI. # Slider is shown only when there is exactly one of webauthn/magic-link # available AND a password form, since it relies on a shared email input. has_password = password_form is not None has_webauthn = webauthn_form is not None has_magic_link = magic_link_form is not None has_single_quick_factor = has_webauthn ^ has_magic_link should_render_slider = ( has_password and has_single_quick_factor and challenge is not None ) if not should_render_slider or (has_webauthn and has_magic_link): tabs = [ ( ('Passkey', webauthn_form, selected_tab == 'webauthn') if webauthn_form else None ), ( ('Password', password_form, selected_tab == 'password') if password_form else None ), ( ( magic_link_tab_label, magic_link_form, selected_tab == 'magic_link', ) if magic_link_form else None ), ] selected_tabs = [t[2] for t in tabs if t is not None] selected_index = ( selected_tabs.index(True) if True in selected_tabs else 0 ) labels = [t[0] for t in tabs if t is not None] sections = [t[1] for t in tabs if t is not None] return render.tabs_buttons( labels, selected_index ) + render.tabs_content(sections, selected_index, labels) # Build slider content for the single-form flow. 
base_email_factor_form = render.render_base_email_form( id="email", challenge=challenge or "", email=email ) password_input = render.render_password_input( challenge=challenge or "", should_show_forgot_password=True ) slider_content = [ f''' { render.button("Sign In", id="webauthn-signin") if webauthn_form else render.button(magic_link_button_text, id="magic-link-signin") } { render.button( "Sign in with password", id="show-password-form", secondary=True, type="button", ) } ''', f''' {password_input}
{ render.button( None, id="hide-password-form", secondary=True, type="button" ) } {render.button("Sign in with password", id="password-signin")}
''', ] return f"""
{single_form_fields} {base_email_factor_form} { render.tabs_content( slider_content, selected_tab=(1 if selected_tab == 'password' else 0), ) }
""" def render_signup_page( *, base_path: str, providers: frozenset[auth_config.ProviderConfig], error_message: Optional[str] = None, email: Optional[str] = None, challenge: str, selected_tab: Optional[str] = None, # config redirect_to: str, redirect_to_on_signup: Optional[str] = None, app_name: Optional[str] = None, logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = None, ) -> bytes: password_provider = None webauthn_provider = None magic_link_provider = None oauth_providers = [] for p in providers: if p.name == 'builtin::local_emailpassword': password_provider = cast(auth_config.EmailPasswordProviderConfig, p) elif p.name == 'builtin::local_webauthn': webauthn_provider = cast(auth_config.WebAuthnProviderConfig, p) elif p.name == 'builtin::local_magic_link': magic_link_provider = cast(auth_config.MagicLinkProviderConfig, p) elif p.name.startswith('builtin::oauth_') or hasattr(p, "issuer_url"): oauth_providers.append(cast(auth_config.OAuthProviderConfig, p)) email_factor_form = render_email_factor_form( selected_tab=selected_tab, password_form=( render.render_password_signup_form( challenge=challenge, email=email, redirect_to=render.get_email_password_signup_redirect_url( password_provider.verification_method, base_path, redirect_to_on_signup or redirect_to, ), base_path=base_path, provider_name=password_provider.name, ) if password_provider else None ), webauthn_form=( render.render_webauthn_signup_form( challenge=challenge, email=email, redirect_to=render.get_webauthn_signup_redirect_url( webauthn_provider.verification_method, base_path, redirect_to_on_signup or redirect_to, ), base_path=base_path, provider_name=webauthn_provider.name, ) if webauthn_provider else None ), magic_link_form=( render.render_magic_link_signup_form( challenge=challenge, email=email, callback_url=redirect_to_on_signup or redirect_to, base_path=base_path, provider_name=magic_link_provider.name, 
verification_method=magic_link_provider.verification_method, ) if magic_link_provider else None ), magic_link_verification_method=( magic_link_provider.verification_method if magic_link_provider else 'Link' ), ) if email_factor_form: email_factor_form += render.bottom_note( 'Already have an account?', link='Sign in', href='signin' ) oauth_buttons = render.oauth_buttons( oauth_providers=oauth_providers, label_prefix=('Sign up with' if email_factor_form else 'Continue with'), challenge=challenge, redirect_to=redirect_to, redirect_to_on_signup=redirect_to_on_signup, collapsed=email_factor_form is not None and len(oauth_providers) >= 3, ) return render.base_page( title=f'Sign up{f" to {app_name}" if app_name else ""}', logo_url=logo_url, dark_logo_url=dark_logo_url, brand_color=brand_color, cleanup_search_params=['error', 'email', 'selected_tab'], content=f''' {render.title('Sign up', app_name=app_name)} {render.error_message(error_message)} {oauth_buttons} { render.divider if email_factor_form and len(oauth_providers) > 0 else '' } {email_factor_form or ''} {render.script('webauthn-register') if webauthn_provider else ''} ''', ) def render_forgot_password_page( *, redirect_to: str, redirect_on_failure: str, reset_url: str, provider_name: str, challenge: str, error_message: Optional[str] = None, email: Optional[str] = None, email_sent: Optional[str] = None, # config app_name: Optional[str] = None, logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = None, ) -> bytes: if email_sent is not None: content = render.success_message( f'Password reset email has been sent to {email_sent}' ) else: content = f''' {render.error_message(error_message)}
{render.button('Send Reset Email')}
''' return render.base_page( title=f'Reset password{f" for {app_name}" if app_name else ""}', logo_url=logo_url, dark_logo_url=dark_logo_url, brand_color=brand_color, cleanup_search_params=['error', 'email', 'email_sent'], content=f''' {render.title('Reset password', join='for', app_name=app_name)} {content} {render.bottom_note("Back to", link="Sign In", href="signin")} ''', ) def render_reset_password_page( *, base_path: str, provider_name: str, is_valid: bool, redirect_to: str, challenge: str, reset_token: Optional[str] = None, error_message: Optional[str] = None, is_code_flow: bool = False, email: Optional[str] = None, # config app_name: Optional[str] = None, logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = None, ) -> bytes: if not is_valid: content = render.error_message( f'''Reset token is invalid, it may have expired. Try sending another reset email ''', False, ) elif is_code_flow and email: content = f''' {render.error_message(error_message)}

We've sent a 6-digit reset code to { html.escape(email) }

{ render.code_input_form( action="../reset-password", email=email, provider=provider_name, label="Enter reset code", button_text="Reset Password", additional_fields=f''' ''', ) } ''' else: content = f''' {render.error_message(error_message)}
{render.button('Reset Password')}
''' return render.base_page( title=f'Reset password{f" for {app_name}" if app_name else ""}', logo_url=logo_url, dark_logo_url=dark_logo_url, brand_color=brand_color, cleanup_search_params=['error'], content=f''' {render.title('Reset password', join='for', app_name=app_name)} {content} ''', ) def render_email_verification_page_code_flow( *, email: str, provider: str, callback_url: str, base_path: str, challenge: str, error_message: Optional[str] = None, # config app_name: Optional[str] = None, logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = None, ) -> bytes: """Renders the email verification page for the code flow.""" content = f''' {render.error_message(error_message)}

We've sent a 6-digit verification code to { html.escape(email) }

{ render.code_input_form( action="../verify", email=email, provider=provider, label="Enter verification code", button_text="Verify Email", additional_fields=f''' ''' ) } ''' return render.base_page( title=f'Verify email{f" for {app_name}" if app_name else ""}', logo_url=logo_url, dark_logo_url=dark_logo_url, brand_color=brand_color, cleanup_search_params=['error'], content=f''' {render.title('Verify email', join='for', app_name=app_name)} {content} ''', ) def render_email_verification_page_link_flow( *, is_valid: bool, error_messages: list[str], verification_token: Optional[str] = None, # config app_name: Optional[str] = None, logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = None, ) -> bytes: resend_url = None if verification_token: verification_token = html.escape(verification_token) resend_url = ( f"resend-verification?verification_token={verification_token}" ) if not is_valid: messages = ''.join( [render.error_message(error) for error in error_messages] ) content = f''' {messages} { ( f'Try sending another verification ' 'email' ) if resend_url else '' } ''' else: content = ''' Email has been successfully verified. You may now sign in ''' return render.base_page( title=f'Verify email{f" for {app_name}" if app_name else ""}', logo_url=logo_url, dark_logo_url=dark_logo_url, brand_color=brand_color, cleanup_search_params=['error'], content=f''' {render.title('Verify email', join='for', app_name=app_name)} {content} ''', ) def render_email_verification_expired_page( verification_token: str, # config app_name: Optional[str] = None, logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = None, ) -> bytes: verification_token = html.escape(verification_token) content = render.error_message( f''' Your verification token has expired.
Click here to resend the verification email ''', False, ) return render.base_page( title=f'Verification expired{f" for {app_name}" if app_name else ""}', logo_url=logo_url, dark_logo_url=dark_logo_url, brand_color=brand_color, cleanup_search_params=['error'], content=f''' { render.title('Verification expired', join='for', app_name=app_name) } {content} ''', ) def render_resend_verification_done_page( *, is_valid: bool, verification_token: Optional[str] = None, # config app_name: Optional[str] = None, logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = None, ) -> bytes: if verification_token is None: content = render.error_message( f""" Missing verification token, please follow the link provided in the original email, or on the signin page. """, False, ) else: verification_token = html.escape(verification_token) if is_valid: content = f''' Your verification email has been resent. Please check your email. ''' else: content = f''' Unable to resend verification email. Please try again. ''' return render.base_page( title=( f'Email verification resent{f" for {app_name}" if app_name else ""}' ), logo_url=logo_url, dark_logo_url=dark_logo_url, brand_color=brand_color, cleanup_search_params=['error'], content=f''' { render.title( 'Email verification resent', join='for', app_name=app_name ) } {content} ''', ) def render_magic_link_sent_page_code_flow( *, email: str, challenge: str, callback_url: str, error_message: Optional[str] = None, app_name: Optional[str] = None, logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = None, ) -> bytes: content = f''' {render.error_message(error_message)}

We've sent a 6-digit sign-in code to { html.escape(email) }

{ render.code_input_form( action="../magic-link/authenticate", email=email, provider="builtin::local_magic_link", label="Enter sign-in code", button_text="Sign In", additional_fields=f''' ''', ) } ''' title = f'Sign in code sent{f" for {app_name}" if app_name else ""}' page_title = 'Sign in code sent' return render.base_page( title=title, logo_url=logo_url, dark_logo_url=dark_logo_url, brand_color=brand_color, cleanup_search_params=['error'], content=f''' {render.title(page_title, join='for', app_name=app_name)} {content} ''', ) def render_magic_link_sent_page_link_flow( *, app_name: Optional[str] = None, logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = None, ) -> bytes: content = render.success_message( "A sign in link has been sent to your email. Please check your " "email." ) title = f'Sign in link sent{f" for {app_name}" if app_name else ""}' page_title = 'Sign in link sent' return render.base_page( title=title, logo_url=logo_url, dark_logo_url=dark_logo_url, brand_color=brand_color, cleanup_search_params=['error'], content=f''' {render.title(page_title, join='for', app_name=app_name)} {content} ''', ) # emails def render_password_reset_email( *, to_addr: str, reset_url: str, app_name: Optional[str] = None, logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = render.DEFAULT_BRAND_COLOR, ) -> email.message.EmailMessage: brand_color = brand_color or render.DEFAULT_BRAND_COLOR msg = email.message.EmailMessage() msg["To"] = to_addr msg["Subject"] = "Reset password" plain_text_content = f""" Somebody requested a new password for the {app_name or ''} account associated with {to_addr}. Please paste the following URL into your browser address bar to reset your password: {reset_url} """ html_content = f"""
Somebody requested a new password for the {app_name or ''} account associated with {to_addr}.
No changes have been made to your account yet.
You can reset your password by clicking the button below:
Reset your password
In case the button didn't work, please paste the following URL into your browser address bar:

{reset_url}

If you did not request a new password, please let us know immediately by replying to this email.
""" # noqa: E501 msg["X-gel-password-reset-url"] = reset_url msg.set_content(plain_text_content, subtype="plain") msg.add_alternative( render.base_default_email( content=html_content, app_name=app_name, logo_url=logo_url, ), subtype="html", ) return msg def render_verification_email( *, to_addr: str, verify_url: str, app_name: Optional[str] = None, logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = render.DEFAULT_BRAND_COLOR, ) -> email.message.EmailMessage: brand_color = brand_color or render.DEFAULT_BRAND_COLOR msg = email.message.EmailMessage() msg["To"] = to_addr msg["Subject"] = ( f"Verify your email{f' for {app_name}' if app_name else ''}" ) plain_text_content = f""" Congratulations, you're registered{f' at {app_name}' if app_name else ''}! Please paste the following URL into your browser address bar to verify your email address: {verify_url} """ html_content = f"""
Congratulations, you're registered{f' at {app_name}' if app_name else ''}!
Please press the button below to verify your email address:
Verify email address
In case the button didn't work, please paste the following URL into your browser address bar:

{verify_url}

""" msg["X-gel-email-verify-url"] = verify_url msg.set_content(plain_text_content, subtype="plain") msg.add_alternative( render.base_default_email( content=html_content, app_name=app_name, logo_url=logo_url, ), subtype="html", ) return msg def render_magic_link_email( *, to_addr: str, link: str, app_name: Optional[str] = None, logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = render.DEFAULT_BRAND_COLOR, ) -> email.message.EmailMessage: brand_color = brand_color or render.DEFAULT_BRAND_COLOR msg = email.message.EmailMessage() msg["To"] = to_addr msg["Subject"] = f"Sign in{f' to {app_name}' if app_name else ''}" plain_text_content = f""" Please paste the following URL into your browser address bar to be signed into your account: {link} """ html_content = f"""
Click the button below to sign in{f' to {app_name}' if app_name else ''}:
Sign in
In case the button didn't work, please paste the following URL into your browser address bar:

{link}

""" # noqa: E501 msg["X-gel-magic-link"] = link msg.set_content(plain_text_content, subtype="plain") msg.add_alternative( render.base_default_email( content=html_content, app_name=app_name, logo_url=logo_url, ), subtype="html", ) return msg def render_one_time_code_email( *, to_addr: str, code: str, app_name: Optional[str] = None, logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = render.DEFAULT_BRAND_COLOR, ) -> email.message.EmailMessage: """Renders an email containing a one-time verification code.""" brand_color = brand_color or render.DEFAULT_BRAND_COLOR msg = email.message.EmailMessage() msg["To"] = to_addr msg["Subject"] = ( f"Your verification code{f' for {app_name}' if app_name else ''}" ) plain_text_content = f""" Your verification code{f' for {app_name}' if app_name else ''} is: {code} This code will expire in 10 minutes. """ html_content = f"""
Your verification code{f' for {app_name}' if app_name else ''} is:
{code}
For your security, this code will expire in 10 minutes.
""" # noqa: E501 msg["X-gel-email-verify-code"] = code msg.set_content(plain_text_content, subtype="plain") msg.add_alternative( render.base_default_email( content=html_content, app_name=app_name, logo_url=logo_url, ), subtype="html", ) return msg def render_password_reset_code_email( *, to_addr: str, code: str, app_name: Optional[str] = None, logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = render.DEFAULT_BRAND_COLOR, ) -> email.message.EmailMessage: """Renders an email containing a one-time code for password reset.""" brand_color = brand_color or render.DEFAULT_BRAND_COLOR msg = email.message.EmailMessage() msg["To"] = to_addr msg["Subject"] = ( f"Password reset code{f' for {app_name}' if app_name else ''}" ) plain_text_content = f""" Your password reset code{f' for {app_name}' if app_name else ''} is: {code} This code will expire in 10 minutes. If you didn't request a password reset, you can safely ignore this email. """ # noqa: E501 html_content = f"""
Your password reset code{f' for {app_name}' if app_name else ''} is:
{code}
For your security, this code will expire in 10 minutes. If you didn't request a password reset, you can safely ignore this email.
""" # noqa: E501 msg["X-gel-password-reset-code"] = code msg.set_content(plain_text_content, subtype="plain") msg.add_alternative( render.base_default_email( content=html_content, app_name=app_name, logo_url=logo_url, ), subtype="html", ) return msg ================================================ FILE: edb/server/protocol/auth_ext/ui/components.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Optional, TYPE_CHECKING import html import re import urllib.parse from . import util if TYPE_CHECKING: from edb.server.protocol.auth_ext import config as auth_config known_oauth_provider_names = [ 'builtin::oauth_github', 'builtin::oauth_google', 'builtin::oauth_apple', 'builtin::oauth_azure', 'builtin::oauth_discord', 'builtin::oauth_slack', ] DEFAULT_BRAND_COLOR = "1f8aed" def base_page( *, content: str, title: str, cleanup_search_params: list[str], logo_url: Optional[str] = None, dark_logo_url: Optional[str] = None, brand_color: Optional[str] = DEFAULT_BRAND_COLOR, ) -> bytes: logo = '' if logo_url: logo = '' cleanup_script = ( f'''''' if len(cleanup_search_params) > 0 else '' ) if ( brand_color is None or util.hex_color_regexp.fullmatch(brand_color) is None ): brand_color = DEFAULT_BRAND_COLOR return f''' {html.escape(title)} {cleanup_script} {logo}
{content}
'''.encode() def script(name: str) -> str: return f'' def title(title: str, *, app_name: Optional[str], join: str = 'to') -> str: if app_name is None: return f'''

{title}

''' return f'''

{title} {join} {html.escape(app_name)}

''' def oauth_buttons( *, redirect_to: str, challenge: str, redirect_to_on_signup: Optional[str], oauth_providers: list[auth_config.OAuthProviderConfig], label_prefix: str, collapsed: bool, ) -> str: if len(oauth_providers) == 0: return '' oauth_params = { 'redirect_to': redirect_to, 'challenge': challenge, } if redirect_to_on_signup: oauth_params['redirect_to_on_signup'] = redirect_to_on_signup buttons = '\n'.join( [ _oauth_button(p, oauth_params, label_prefix=label_prefix) for p in sorted(oauth_providers, key=lambda p: p.name) ] ) return f''' ''' def _oauth_button( provider: auth_config.OAuthProviderConfig, params: dict[str, str], *, label_prefix: str, ) -> str: href = '../authorize?' + urllib.parse.urlencode( {'provider': provider.name, **params} ) if ( provider.name.startswith('builtin::') and provider.name in known_oauth_provider_names ): img = f'''{provider.display_name} Icon''' elif provider.logo_url is not None: img = f'''{provider.display_name} Icon''' else: img = '' label = f'{label_prefix} {provider.display_name}' return f''' {img} {label} ''' def button( text: Optional[str], *, id: Optional[str] = None, secondary: Optional[bool] = False, type: Optional[str] = 'submit', ) -> str: classes = [] if secondary: classes.append('secondary') if text is None: classes.append('icon-only') attrs = f'type="{type}"' if id: attrs += f' id="{id}"' if len(classes): attrs += f' class="{" ".join(classes)}"' return f''' ''' divider = '''
or
''' def _slugify_label(label: str) -> str: slug = label.lower().strip() slug = re.sub(r"[^a-z0-9]+", "-", slug) slug = re.sub(r"(^-|-$)", "", slug) return slug or "section" def tabs_content( sections: list[str], selected_tab: int, labels: Optional[list[str]] = None ) -> str: content = '' for i, section in enumerate(sections): active = selected_tab == i aria_attrs = '' if labels is not None and i < len(labels): slug = _slugify_label(labels[i]) aria_attrs = ( ' role="tabpanel" ' f'id="panel-{slug}" aria-labelledby="tab-{slug}"' ) hidden_attrs = ' aria-hidden="true" hidden' if not active else '' else: hidden_attrs = '' if active else '' content += f'''
{section}
''' style = ( f'style="transform: translateX({-100 * selected_tab}%)"' if selected_tab > 0 else '' ) return f'''
{content}
''' _tab_underline = ''' ''' def tabs_buttons(labels: list[str], selected_tab: int) -> str: content = '' for i, label in enumerate(labels): active = selected_tab == i slug = _slugify_label(label) aria_selected = 'true' if active else 'false' tabindex = '0' if active else '-1' content += f''' ''' return f'''
{content}
''' def hidden_input( *, name: str, value: str, secondary_value: Optional[str] = None ) -> str: return f'''''' def bottom_note(message: str, *, link: str, href: str) -> str: return f"""
{message} {link}
""" def error_message(message: Optional[str], escape: bool = True) -> str: if message is None: return '' return f'''
{html.escape(message) if escape else message}
''' def success_message(message: str) -> str: return f'''
{message}
''' def code_input_form( *, action: str, email: str, provider: str, label: str = "Enter verification code", button_text: str = "Verify Code", additional_fields: str = "", ) -> str: """Renders a code input form with auto-formatting and mobile keyboard support.""" return f'''
{additional_fields} {button(button_text)}
''' def base_default_email( *, content: str, app_name: Optional[str], logo_url: Optional[str], ) -> str: logo_html = ( f"""

                                      {f'{app_name} logo' if app_name else ''}
""" # noqa: E501 if logo_url else "" ) return f"""
{logo_html}
{content}
""" # noqa: E501 # Form Component Helpers # ===================== def get_magic_link_tab_label(verification_method: str) -> str: return "Email Code" if verification_method == "Code" else "Email Link" def get_magic_link_button_text(verification_method: str) -> str: return ( "Email sign in code" if verification_method == "Code" else "Email sign in link" ) def get_email_password_signup_redirect_url( verification_method: str, base_path: str, fallback_redirect: str ) -> str: if verification_method == "Code": return f"{base_path}/ui/verify?provider=builtin::local_emailpassword" else: return fallback_redirect def get_webauthn_signup_redirect_url( verification_method: str, base_path: str, fallback_redirect: str ) -> str: if verification_method == "Code": return f"{base_path}/ui/verify?provider=builtin::local_webauthn" else: return fallback_redirect def get_password_reset_redirect_url( verification_method: str, base_path: str, challenge: str ) -> str: if verification_method == "Code": return f"{base_path}/ui/reset-password" else: return f"{base_path}/ui/forgot-password?challenge={challenge}" def get_send_button_text(verification_method: str) -> str: return "Send Code" if verification_method == "Code" else "Send Link" def get_verification_method_label(verification_method: str) -> str: return "Email Code" if verification_method == "Code" else "Email Link" def render_base_email_form( *, id: str, challenge: str, email: str | None = None ) -> str: return f""" """ def render_password_input( *, challenge: str, should_show_forgot_password: bool ) -> str: forgot_password_link = ( f""" Forgot password? """ if should_show_forgot_password else '' ) return f"""
{forgot_password_link}
""" def render_password_form( *, challenge: str, email: str | None = None, redirect_to: str, base_path: str, provider_name: str, ) -> str: return f"""
{render_base_email_form( id="password-email", challenge=challenge, email=email )} {render_password_input( challenge=challenge, should_show_forgot_password=True, )} {button("Sign In", id="password-signin")}
""" def render_webauthn_form( *, challenge: str, email: str | None = None, redirect_to: str, base_path: str, provider_name: str, ) -> str: """Render a complete WebAuthn authentication form.""" return f"""
{render_base_email_form( id="webauthn-email", challenge=challenge, email=email )} {button("Sign In", id="webauthn-signin")}
""" def render_magic_link_form( *, challenge: str, email: str | None = None, base_path: str, provider_name: str, callback_url: str | None = None, verification_method: str = "Link", ) -> str: button_text = get_magic_link_button_text(verification_method) callback_field = ( f''' ''' if verification_method == "Link" else "" ) return f"""
{callback_field} {render_base_email_form( id="magic-link-email", challenge=challenge, email=email )} {button(button_text, id="magic-link-signin")}
""" # Signup-specific form helpers # =========================== def render_password_signup_form( *, challenge: str, email: str | None = None, redirect_to: str, base_path: str, provider_name: str, ) -> str: return f"""
{render_base_email_form( id="password-email", challenge=challenge, email=email )} {render_password_input( challenge=challenge, should_show_forgot_password=False, )} {button("Sign Up", id="password-signup")}
""" def render_webauthn_signup_form( *, challenge: str, email: str | None = None, redirect_to: str, base_path: str, provider_name: str, ) -> str: """Render a complete WebAuthn signup form.""" return f"""
{render_base_email_form( id="webauthn-email", challenge=challenge, email=email )} {button("Sign Up", id="webauthn-signup")}
""" def render_magic_link_signup_form( *, challenge: str, email: str | None = None, base_path: str, provider_name: str, callback_url: str | None = None, verification_method: str = "Link", ) -> str: """Render a complete magic link/OTC signup form.""" tab_label = get_magic_link_tab_label(verification_method) callback_field = ( f""" """ if verification_method == "Link" else "" ) return f"""
{callback_field} {render_base_email_form( id="magic-link-email", challenge=challenge, email=email )} {button(f"Sign Up with {tab_label}", id="magic-link-signup")}
""" ================================================ FILE: edb/server/protocol/auth_ext/ui/util.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import re # Colour utils hex_color_regexp = re.compile(r'[0-9a-fA-F]{6}') def get_colour_vars(bg_hex: str) -> str: bg_rgb = hex_to_rgb(bg_hex) bg_hsl = rgb_to_hsl(*bg_rgb) luma = rgb_to_luma(*bg_rgb) luma_dark = luma < 0.6 text_color = hsl_to_rgb( bg_hsl[0], bg_hsl[1], min(90 if luma_dark else 35, bg_hsl[2]) ) dark_text_color = hsl_to_rgb(bg_hsl[0], bg_hsl[1], max(60, bg_hsl[2])) return f'''--accent-bg-color: #{bg_hex}; --accent-bg-text-color: #{rgb_to_hex( *hsl_to_rgb( bg_hsl[0], bg_hsl[1], 95 if luma_dark else max(10, min(25, luma * 100 - 60)) ) )}; --accent-bg-hover-color: #{rgb_to_hex( *hsl_to_rgb( bg_hsl[0], bg_hsl[1], bg_hsl[2] + (5 if luma_dark else -5) ) )}; --accent-text-color: #{rgb_to_hex(*text_color)}; --accent-text-dark-color: #{rgb_to_hex(*dark_text_color)}; --accent-focus-color: rgba({','.join( str(c) for c in text_color)},0.6); --accent-focus-dark-color: rgba({','.join( str(c) for c in dark_text_color)},0.6);''' def hex_to_rgb(hex: str) -> tuple[float, float, float]: return ( int(hex[0:2], base=16), int(hex[2:4], base=16), int(hex[4:6], base=16), ) def rgb_to_hex(r: float, g: float, b: float) -> str: return '%02x%02x%02x' % (int(r), int(g), 
int(b)) def rgb_to_luma(r: float, g: float, b: float) -> float: return (r * 0.299 + g * 0.587 + b * 0.114) / 255 def rgb_to_hsl(r: float, g: float, b: float) -> tuple[float, float, float]: r /= 255 g /= 255 b /= 255 l = max(r, g, b) s = l - min(r, g, b) h = ( ( ((g - b) / s) if l == r else (2 + (b - r) / s) if l == g else (4 + (r - g) / s) ) if s != 0 else 0 ) return ( 60 * h + 360 if 60 * h < 0 else 60 * h, 100 * ( (s / (2 * l - s) if l <= 0.5 else s / (2 - (2 * l - s))) if s != 0 else 0 ), (100 * (2 * l - s)) / 2, ) def hsl_to_rgb(h: float, s: float, l: float) -> tuple[float, float, float]: s /= 100 l /= 100 k = lambda n: (n + h / 30) % 12 a = s * min(l, 1 - l) f = lambda n: l - a * max(-1, min(k(n) - 3, min(9 - k(n), 1))) return ( round(255 * f(0)), round(255 * f(8)), round(255 * f(4)), ) ================================================ FILE: edb/server/protocol/auth_ext/util.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations import urllib.parse import html import logging import asyncio from typing import ( overload, Any, cast, Optional, TYPE_CHECKING, Callable, Awaitable, Mapping, ) import immutables from edb.common import retryloop from edb import errors as edb_errors from edb.server import config as edb_config, auth as jwt_auth from edb.server.config.types import CompositeConfigType from edb.server.protocol import execute from . import errors, config if TYPE_CHECKING: from edb.server import tenant as edbtenant from edb.server.dbview import dbview from edb.server import defines as edbdef logger = logging.getLogger('edb.server.ext.auth') # Cache JWKSets for 10 minutes jwtset_cache = jwt_auth.JWKSetCache(60 * 10) async def json_query_no_retry( db: dbview.Database, query: str, *, variables: Mapping[str, Any] = immutables.Map(), tx_isolation: edbdef.TxIsolationLevel | None = None, role_name: str | None = None, ) -> bytes: try: return await execute.parse_execute_json( db, query, variables=variables, tx_isolation=tx_isolation, role_name=role_name, cached_globally=True, query_tag='gel/auth', ) except Exception as e: raise (await execute.interpret_error(e, db)) from None async def json_query( db: dbview.Database, query: str, *, variables: Mapping[str, Any] = immutables.Map(), tx_isolation: edbdef.TxIsolationLevel | None = None, role_name: str | None = None, retry_timeout: float = 5.0, ) -> bytes: # TODO: Should we move the retry into a function in execute instead? 
rloop = retryloop.RetryLoop( backoff=retryloop.exp_backoff(), timeout=retry_timeout, ignore=(edb_errors.TransactionConflictError,), ) async for iteration in rloop: async with iteration: return await json_query_no_retry( db, query, variables=variables, tx_isolation=tx_isolation, role_name=role_name, ) raise AssertionError('retryloop is broken') def maybe_get_config_unchecked(db: edbtenant.dbview.Database, key: str) -> Any: return edb_config.lookup(key, db.db_config, spec=db.user_config_spec) @overload def maybe_get_config[T]( db: Any, key: str, expected_type: type[T] ) -> T | None: ... @overload def maybe_get_config(db: Any, key: str) -> str | None: ... def maybe_get_config( db: Any, key: str, expected_type: type[object] = str ) -> object: value = maybe_get_config_unchecked(db, key) if value is None: return None if not isinstance(value, expected_type): raise TypeError( f"Config value `{key}` must be {expected_type.__name__}, got " f"{type(value).__name__}" ) return value @overload def get_config[T](db: Any, key: str, expected_type: type[T]) -> T: ... @overload def get_config(db: Any, key: str) -> str: ... def get_config(db: Any, key: str, expected_type: type[object] = str) -> object: value = maybe_get_config(db, key, expected_type) if value is None: raise errors.MissingConfiguration( key=key, description="Missing configuration value", ) return value def get_config_unchecked(db: Any, key: str) -> Any: value = maybe_get_config_unchecked(db, key) if value is None: raise errors.MissingConfiguration( key=key, description="Missing configuration value", ) return value def get_config_typename(config_value: edb_config.SettingValue) -> str: return config_value._tspec.name # type: ignore def escape_and_truncate(input_str: str | None, max_len: int) -> str | None: if input_str is None: return None trunc = ( f"{input_str[:max_len]}..." 
if len(input_str) > max_len else input_str ) return html.escape(trunc) def get_app_details_config(db: Any) -> config.AppDetailsConfig: ui_config = cast( Optional[config.UIConfig], maybe_get_config(db, "ext::auth::AuthConfig::ui", CompositeConfigType), ) return config.AppDetailsConfig( app_name=escape_and_truncate( maybe_get_config(db, "ext::auth::AuthConfig::app_name") or (ui_config.app_name if ui_config else None), 100, ), logo_url=escape_and_truncate( maybe_get_config(db, "ext::auth::AuthConfig::logo_url") or (ui_config.logo_url if ui_config else None), 2000, ), dark_logo_url=escape_and_truncate( maybe_get_config(db, "ext::auth::AuthConfig::dark_logo_url") or (ui_config.dark_logo_url if ui_config else None), 2000, ), brand_color=escape_and_truncate( maybe_get_config(db, "ext::auth::AuthConfig::brand_color") or (ui_config.brand_color if ui_config else None), 8, ), ) def join_url_params(url: str, params: dict[str, str]) -> str: parsed_url = urllib.parse.urlparse(url) query_params = { **urllib.parse.parse_qs(parsed_url.query), **{key: [val] for key, val in params.items()}, } new_query_params = urllib.parse.urlencode(query_params, doseq=True) return parsed_url._replace(query=new_query_params).geturl() async def get_remote_jwtset( url: str, fetch_lambda: Callable[[str], Awaitable[jwt_auth.JWKSet]], ) -> jwt_auth.JWKSet: """ Get a JWKSet from the cache, or fetch it from the given URL if it's not in the cache. 
""" is_fresh, jwtset = jwtset_cache.get(url) match (is_fresh, jwtset): case (_, None): jwtset = await fetch_lambda(url) jwtset_cache.set(url, jwtset) case (True, jwtset): pass case _: # Run fetch in background to refresh cache async def refresh_cache(url: str) -> None: try: new_jwtset = await fetch_lambda(url) jwtset_cache.set(url, new_jwtset) except Exception: logger.exception( f"Failed to refresh JWKSet cache for {url}" ) asyncio.create_task(refresh_cache(url)) assert jwtset is not None return jwtset ================================================ FILE: edb/server/protocol/auth_ext/webauthn.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2024-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import dataclasses import base64 import json import webauthn from typing import Optional, TYPE_CHECKING from webauthn.helpers import ( parse_authentication_credential_json, structs as webauthn_structs, exceptions as webauthn_exceptions, ) from edb.errors import ConstraintViolationError from . 
import config, data, errors, util, local if TYPE_CHECKING: from edb.server import tenant as edbtenant @dataclasses.dataclass(repr=False) class WebAuthnRegistrationChallenge: """ Object that represents the ext::auth::WebAuthnRegistrationChallenge type """ id: str challenge: bytes user_handle: bytes email: str class Client(local.Client): def __init__(self, db: edbtenant.dbview.Database): self.db = db self.provider = self._get_provider() self.app_name = self._get_app_name() self.config = self._get_provider_config("builtin::local_webauthn") def _get_provider(self) -> config.WebAuthnProvider: provider_name = "builtin::local_webauthn" provider_client_config = util.get_config( self.db, "ext::auth::AuthConfig::providers", frozenset ) for cfg in provider_client_config: if cfg.name == provider_name: return config.WebAuthnProvider( name=cfg.name, relying_party_origin=cfg.relying_party_origin, require_verification=cfg.require_verification, verification_method=cfg.verification_method, ) raise errors.MissingConfiguration( provider_name, f"Provider is not configured" ) def _get_app_name(self) -> Optional[str]: app_config = util.get_app_details_config(self.db) return app_config.app_name async def create_registration_options_for_email( self, email: str, ) -> tuple[str, bytes]: maybe_user_handle = await self._maybe_get_existing_user_handle( email=email ) registration_options = webauthn.generate_registration_options( rp_id=self.provider.relying_party_id, rp_name=(self.app_name or self.provider.relying_party_origin), user_name=email, user_display_name=email, user_id=maybe_user_handle, ) await self._create_registration_challenge( email=email, challenge=registration_options.challenge, user_handle=registration_options.user.id, ) return ( base64.urlsafe_b64encode(registration_options.user.id).decode(), webauthn.options_to_json(registration_options).encode(), ) async def _maybe_get_existing_user_handle( self, email: str, ) -> Optional[bytes]: result = await util.json_query( self.db, """ 
with email := $email, factors := ( select ext::auth::WebAuthnFactor filter .email = email ), select assert_single((select distinct factors.user_handle));""", variables={ "email": email, }, ) result_json = json.loads(result.decode()) if len(result_json) == 0: return None else: return base64.b64decode(result_json[0]) async def _create_registration_challenge( self, email: str, challenge: bytes, user_handle: bytes, ) -> None: await util.json_query( self.db, """ with challenge := $challenge, user_handle := $user_handle, email := $email, insert ext::auth::WebAuthnRegistrationChallenge { challenge := challenge, user_handle := user_handle, email := email, }""", variables={ "challenge": challenge, "user_handle": user_handle, "email": email, }, ) async def register( self, credentials: str, email: str, user_handle: bytes, ) -> data.EmailFactor: registration_challenge = await self._get_registration_challenge( email=email, user_handle=user_handle, ) await self._delete_registration_challenges( email=email, user_handle=user_handle, ) registration_verification = webauthn.verify_registration_response( credential=credentials, expected_challenge=registration_challenge.challenge, expected_rp_id=self.provider.relying_party_id, expected_origin=self.provider.relying_party_origin, ) try: result = await util.json_query( self.db, """ with email := $email, user_handle := $user_handle, credential_id := $credential_id, public_key := $public_key, identity := (insert ext::auth::LocalIdentity { issuer := "local", subject := "", }), factor := (insert ext::auth::WebAuthnFactor { email := email, user_handle := user_handle, credential_id := credential_id, public_key := public_key, identity := identity, }), select factor { ** };""", variables={ "email": email, "user_handle": user_handle, "credential_id": registration_verification.credential_id, "public_key": ( registration_verification.credential_public_key ), }, ) except ConstraintViolationError: raise errors.UserAlreadyRegistered() result_json = 
json.loads(result.decode()) assert len(result_json) == 1 factor_dict = result_json[0] local_identity = data.LocalIdentity(**factor_dict.pop("identity")) return data.WebAuthnFactor(**factor_dict, identity=local_identity) async def _get_registration_challenge( self, email: str, user_handle: bytes, ) -> WebAuthnRegistrationChallenge: result = await util.json_query( self.db, """ with email := <str>$email, user_handle := <bytes>$user_handle, select ext::auth::WebAuthnRegistrationChallenge { id, challenge, user_handle, email, } filter .email = email and .user_handle = user_handle;""", variables={ "email": email, "user_handle": user_handle, }, ) result_json = json.loads(result.decode()) assert len(result_json) == 1 challenge_dict = result_json[0] return WebAuthnRegistrationChallenge( id=challenge_dict["id"], challenge=base64.b64decode(challenge_dict["challenge"]), user_handle=base64.b64decode(challenge_dict["user_handle"]), email=challenge_dict["email"], ) async def _delete_registration_challenges( self, email: str, user_handle: bytes, ) -> None: await util.json_query( self.db, """ with email := <str>$email, user_handle := <bytes>$user_handle, delete ext::auth::WebAuthnRegistrationChallenge filter .email = email and .user_handle = user_handle;""", variables={ "email": email, "user_handle": user_handle, }, ) async def create_authentication_options_for_email( self, *, webauthn_provider: config.WebAuthnProvider, email: str, ) -> tuple[str, bytes]: # Find credential IDs by email result = await util.json_query( self.db, """ select ext::auth::WebAuthnFactor { user_handle, credential_id, } filter .email = <str>$email;""", variables={ "email": email, }, ) result_json = json.loads(result.decode()) if len(result_json) == 0: raise errors.WebAuthnAuthenticationFailed( "No WebAuthn credentials found for this email." ) user_handles: set[str] = {x["user_handle"] for x in result_json} assert ( len(user_handles) == 1 ), "Found multiple WebAuthn user handles for the same email."
user_handle = base64.b64decode(result_json[0]["user_handle"]) credential_ids = [ webauthn_structs.PublicKeyCredentialDescriptor( base64.b64decode(x["credential_id"]) ) for x in result_json ] authentication_options = webauthn.generate_authentication_options( rp_id=webauthn_provider.relying_party_id, allow_credentials=credential_ids, ) await util.json_query( self.db, """ with challenge := <bytes>$challenge, user_handle := <bytes>$user_handle, email := <str>$email, factors := ( assert_exists(( select ext::auth::WebAuthnFactor filter .user_handle = user_handle and .email = email )) ) insert ext::auth::WebAuthnAuthenticationChallenge { challenge := challenge, factors := factors, } unless conflict on .factors else ( update ext::auth::WebAuthnAuthenticationChallenge set { challenge := challenge } );""", variables={ "challenge": authentication_options.challenge, "user_handle": user_handle, "email": email, }, ) return ( base64.urlsafe_b64encode(user_handle).decode(), webauthn.options_to_json(authentication_options).encode(), ) async def is_email_verified( self, email: str, assertion: str, ) -> bool: credential = parse_authentication_credential_json(assertion) result = await util.json_query( self.db, """ with email := <str>$email, credential_id := <bytes>$credential_id, factor := assert_single(( select ext::auth::WebAuthnFactor filter .email = email and .credential_id = credential_id )), select (factor.verified_at <= std::datetime_current()) ??
false;""", variables={ "email": email, "credential_id": credential.raw_id, }, ) result_json = json.loads(result.decode()) return bool(result_json[0]) async def _get_authentication_challenge( self, email: str, credential_id: bytes, ) -> data.WebAuthnAuthenticationChallenge: result = await util.json_query( self.db, """ with email := $email, credential_id := $credential_id, select ext::auth::WebAuthnAuthenticationChallenge { id, created_at, modified_at, challenge, factors: { id, created_at, modified_at, email, verified_at, user_handle, credential_id, public_key, identity: { created_at, modified_at, id, issuer, subject, } }, } filter .factors.email = email and .factors.credential_id = credential_id;""", variables={ "email": email, "credential_id": credential_id, }, ) result_json = json.loads(result.decode()) if len(result_json) == 0: raise errors.WebAuthnAuthenticationFailed( "Could not find a challenge. Please retry authentication." ) elif len(result_json) > 1: raise errors.WebAuthnAuthenticationFailed( "Multiple challenges found. Please retry authentication." 
) return data.WebAuthnAuthenticationChallenge(**result_json[0]) async def _delete_authentication_challenges( self, email: str, credential_id: bytes, ) -> None: await util.json_query( self.db, """ with email := $email, credential_id := $credential_id, delete ext::auth::WebAuthnAuthenticationChallenge filter .factors.email = email and .factors.credential_id = credential_id;""", variables={ "email": email, "credential_id": credential_id, }, ) async def authenticate( self, *, email: str, assertion: str, ) -> data.LocalIdentity: credential = parse_authentication_credential_json(assertion) authentication_challenge = await self._get_authentication_challenge( email=email, credential_id=credential.raw_id, ) await self._delete_authentication_challenges( email=email, credential_id=credential.raw_id, ) factor = next( ( f for f in authentication_challenge.factors if f.credential_id == credential.raw_id ), None, ) assert factor is not None, "Missing factor for the given credential." try: webauthn.verify_authentication_response( credential=credential, expected_challenge=authentication_challenge.challenge, credential_public_key=factor.public_key, credential_current_sign_count=0, expected_rp_id=self.provider.relying_party_id, expected_origin=self.provider.relying_party_origin, ) except webauthn_exceptions.InvalidAuthenticationResponse: raise errors.WebAuthnAuthenticationFailed( "Invalid authentication response. Please retry authentication." 
) return factor.identity async def get_email_factor_by_credential_id( self, credential_id: bytes, ) -> Optional[data.EmailFactor]: result = await util.json_query( self.db, """ with credential_id := $credential_id, select ext::auth::WebAuthnFactor { id, created_at, modified_at, email, verified_at, identity: {*}, } filter .credential_id = credential_id;""", variables={ "credential_id": credential_id, }, ) result_json = json.loads(result.decode()) if len(result_json) == 0: return None elif len(result_json) > 1: # This should never happen given the exclusive constraint raise errors.WebAuthnAuthenticationFailed( "Multiple WebAuthn factors found for the same credential ID." ) return data.EmailFactor(**result_json[0]) def _get_provider_config( self, provider_name: str ) -> config.WebAuthnProvider: provider_client_config = util.get_config( self.db, "ext::auth::AuthConfig::providers", frozenset ) for cfg in provider_client_config: if cfg.name == provider_name: return config.WebAuthnProvider( name=cfg.name, relying_party_origin=cfg.relying_party_origin, require_verification=cfg.require_verification, verification_method=cfg.verification_method, ) raise errors.MissingConfiguration( provider_name, f"Provider is not configured" ) ================================================ FILE: edb/server/protocol/auth_ext/webhook.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2024-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import dataclasses import typing import abc import datetime import json import hmac import hashlib import uuid from . import util if typing.TYPE_CHECKING: from edb.server import tenant as edbtenant @dataclasses.dataclass class Event(abc.ABC): event_type: str event_id: str timestamp: datetime.datetime def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"timestamp={self.timestamp!r}, " f"event_id={self.event_id!r}" ")" ) @dataclasses.dataclass class HasIdentity(abc.ABC): identity_id: str @dataclasses.dataclass class HasEmailFactor(abc.ABC): email_factor_id: str @dataclasses.dataclass class IdentityCreated(Event, HasIdentity): event_type: typing.Literal["IdentityCreated"] = dataclasses.field( default="IdentityCreated", init=False, ) def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"timestamp={self.timestamp}, " f"event_id={self.event_id}, " f"identity_id={self.identity_id}" ")" ) @dataclasses.dataclass class IdentityAuthenticated(Event, HasIdentity): event_type: typing.Literal["IdentityAuthenticated"] = dataclasses.field( default="IdentityAuthenticated", init=False, ) def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"timestamp={self.timestamp}, " f"event_id={self.event_id}, " f"identity_id={self.identity_id}" ")" ) @dataclasses.dataclass class EmailFactorCreated(Event, HasIdentity, HasEmailFactor): event_type: typing.Literal["EmailFactorCreated"] = dataclasses.field( default="EmailFactorCreated", init=False, ) def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"timestamp={self.timestamp}, " f"event_id={self.event_id}, " f"identity_id={self.identity_id}, " f"email_factor_id={self.email_factor_id}" ")" ) @dataclasses.dataclass class EmailVerificationRequested(Event, HasIdentity, HasEmailFactor): event_type: typing.Literal["EmailVerificationRequested"] = ( 
dataclasses.field( default="EmailVerificationRequested", init=False, ) ) verification_token: str def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"timestamp={self.timestamp}, " f"event_id={self.event_id}, " f"identity_id={self.identity_id}, " f"email_factor_id={self.email_factor_id}" ")" ) @dataclasses.dataclass class EmailVerified(Event, HasIdentity, HasEmailFactor): event_type: typing.Literal["EmailVerified"] = dataclasses.field( default="EmailVerified", init=False, ) def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"timestamp={self.timestamp}, " f"event_id={self.event_id}, " f"identity_id={self.identity_id}, " f"email_factor_id={self.email_factor_id}" ")" ) @dataclasses.dataclass class PasswordResetRequested(Event, HasIdentity, HasEmailFactor): event_type: typing.Literal["PasswordResetRequested"] = dataclasses.field( default="PasswordResetRequested", init=False, ) reset_token: str def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"timestamp={self.timestamp}, " f"event_id={self.event_id}, " f"identity_id={self.identity_id}, " f"email_factor_id={self.email_factor_id}" ")" ) @dataclasses.dataclass class MagicLinkRequested(Event, HasIdentity, HasEmailFactor): event_type: typing.Literal["MagicLinkRequested"] = dataclasses.field( default="MagicLinkRequested", init=False, ) magic_link_token: str magic_link_url: str def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"timestamp={self.timestamp}, " f"event_id={self.event_id}, " f"identity_id={self.identity_id}, " f"email_factor_id={self.email_factor_id}" ")" ) @dataclasses.dataclass class OneTimeCodeRequested(Event, HasIdentity, HasEmailFactor): event_type: typing.Literal["OneTimeCodeRequested"] = dataclasses.field( default="OneTimeCodeRequested", init=False, ) otc_id: str one_time_code: str def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"timestamp={self.timestamp}, " f"event_id={self.event_id}, " f"identity_id={self.identity_id}, " 
f"email_factor_id={self.email_factor_id}" ")" ) @dataclasses.dataclass class OneTimeCodeVerified(Event, HasIdentity, HasEmailFactor): event_type: typing.Literal["OneTimeCodeVerified"] = dataclasses.field( default="OneTimeCodeVerified", init=False, ) otc_id: str def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"timestamp={self.timestamp}, " f"event_id={self.event_id}, " f"identity_id={self.identity_id}, " f"email_factor_id={self.email_factor_id}" ")" ) class WebhookEncoder(json.JSONEncoder): def default(self, obj: typing.Any) -> typing.Any: if isinstance(obj, datetime.datetime): return obj.isoformat() if isinstance(obj, uuid.UUID): return str(obj) return super().default(obj) async def send( db: edbtenant.dbview.Database, url: str, secret: typing.Optional[str], event: Event, ) -> str: body = json.dumps( dataclasses.asdict(event), cls=WebhookEncoder ).encode() headers = [("Content-Type", "application/json")] if secret: signature = hmac.new( secret.encode(), body, hashlib.sha256 ).hexdigest() headers.append(("x-ext-auth-signature-sha256", signature)) result_json = await util.json_query( db, """ with nh as module std::net::http, net as module std::net, # n.b. workaround for bug in parse_execute_json url := <str>$url, headers := <array<tuple<str, str>>>$headers, body := <bytes>$body, REQUEST := ( nh::schedule_request( url, method := nh::Method.POST, headers := headers, body := body, ) ), select REQUEST; """, variables={ "url": url, "body": body, "headers": headers, }, ) result = json.loads(result_json) match result[0]["id"]: case str(id): return id case _: raise ValueError( "Expected single result with 'id' string property, got " f"{result[0]!r}" ) ================================================ FILE: edb/server/protocol/auth_helpers.pxd ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2021-present MagicStack Inc. and the EdgeDB authors.
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # cdef extract_token_from_auth_data(bytes auth_data) cdef auth_jwt(tenant, prefixed_token, str user, str dbname) cdef scram_get_verifier(tenant, str user) cdef parse_basic_auth(str auth_payload) cdef extract_http_user(scheme, auth_payload, params) cdef auth_basic(tenant, str username, str password) cdef auth_mtls(transport) cdef auth_mtls_with_user(transport, str username) ================================================ FILE: edb/server/protocol/auth_helpers.pyx ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2019-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# """Authentication code that is shared between HTTP and binary protocols""" from edgedb import scram import base64 import hashlib import json import logging from edb import errors from edb.server.auth import validate_gel_token cdef object logger = logging.getLogger('edb.server') cdef extract_token_from_auth_data(auth_data: bytes): header_value = auth_data.decode("ascii") scheme, _, payload = header_value.partition(" ") return scheme.lower(), payload.strip() cdef auth_jwt(tenant, prefixed_token: str | None, user: str, dbname: str): if not prefixed_token: raise errors.AuthenticationError( 'authentication failed: no authorization data provided') key = tenant.server.get_jws_key() if err := validate_gel_token(key, prefixed_token, user, dbname, tenant.get_instance_name()): raise errors.AuthenticationError(str(err)) # Ensure it's a valid role, but check after the JWT is validated role = tenant.get_roles().get(user) if role is None: raise errors.AuthenticationError('authentication failed') cdef scram_get_verifier(tenant, user: str): roles = tenant.get_roles() rolerec = roles.get(user) if rolerec is not None: verifier_string = rolerec['password'] if verifier_string is not None: try: verifier = scram.parse_verifier(verifier_string) except ValueError: raise errors.AuthenticationError( f'invalid SCRAM verifier for user {user!r}') from None is_mock = False return verifier, is_mock # To avoid revealing the validity of the submitted user name, # generate a mock verifier using a salt derived from the # received user name and the cluster mock auth nonce. # The same approach is taken by Postgres. 
nonce = tenant.get_instance_data('mock_auth_nonce') salt = hashlib.sha256(nonce.encode() + user.encode()).digest() verifier = scram.SCRAMVerifier( mechanism='SCRAM-SHA-256', iterations=scram.DEFAULT_ITERATIONS, salt=salt[:scram.DEFAULT_SALT_LENGTH], stored_key=b'', server_key=b'', ) is_mock = True return verifier, is_mock def scram_verify_password(password: str, verifier: object) -> bool: """Check the given password against a verifier. Returns True if the password is OK, False otherwise. """ # adapted from edgedb-python's scram.verify_password but made to # take a verifier object instead of a string bpassword = scram.saslprep(password).encode('utf-8') salted_password = scram.get_salted_password( bpassword, verifier.salt, verifier.iterations) computed_key = scram.get_server_key(salted_password) return verifier.server_key == computed_key cdef parse_basic_auth(auth_payload: str): try: decoded = base64.b64decode(auth_payload).decode('utf-8') except ValueError: raise errors.AuthenticationError( 'authentication failed: malformed authentication') from None username, colon, password = decoded.partition(':') if colon != ':': raise errors.AuthenticationError( 'authentication failed: malformed authentication') return username, password cdef extract_http_user(scheme, auth_payload, params): """Extract the username from an HTTP request. Raises an AuthenticationError if something is too malformed. Returns the username, along with the password, if appropriate. (To avoid needing to parse the packet twice.) 
""" if scheme == 'basic': return parse_basic_auth(auth_payload) else: # Respect X-EdgeDB-User if present, but otherwise default to 'edgedb' if params and b'user' in params: username = params[b'user'].decode('ascii') else: username = 'edgedb' return username, None cdef auth_basic(tenant, username: str, password: str): verifier, mock_auth = scram_get_verifier(tenant, username) if not scram_verify_password(password, verifier) or mock_auth: raise errors.AuthenticationError('authentication failed') cdef auth_mtls(transport): sslobj = transport.get_extra_info('ssl_object') if sslobj is None: raise errors.AuthenticationError( "mTLS authentication is not supported over plaintext transport") cert_data = sslobj.getpeercert() if not cert_data: # None or empty dict # If --tls-client-ca-file is specified, the SSLContext used here would # have done load_verify_locations() in `server/server.py`, and we will # have a valid client certificate (non-empty dict) now if one was # provided by the client and passed validation; empty dict otherwise. # `None` just means the peer didn't send a client certificate. raise errors.AuthenticationError( "valid client certificate required") return cert_data cdef auth_mtls_with_user(transport, str username): cert_data = auth_mtls(transport) try: for rdn in cert_data["subject"]: if rdn[0][0] == 'commonName': if rdn[0][1] == username: return except Exception as ex: raise errors.AuthenticationError( "bad client certificate") from ex raise errors.AuthenticationError( f"Common Name of client certificate doesn't match {username!r}", ) ================================================ FILE: edb/server/protocol/binary.pxd ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # cimport cython cimport cpython from libc.stdint cimport int8_t, uint8_t, int16_t, uint16_t, \ int32_t, uint32_t, int64_t, uint64_t from edb.server.pgproto.pgproto cimport ( WriteBuffer, ) from edb.server.dbview cimport dbview from edb.server.pgproto.debug cimport PG_DEBUG from edb.server.protocol cimport frontend cdef enum EdgeSeverity: EDGE_SEVERITY_DEBUG = 20 EDGE_SEVERITY_INFO = 40 EDGE_SEVERITY_NOTICE = 60 EDGE_SEVERITY_WARNING = 80 EDGE_SEVERITY_ERROR = 120 EDGE_SEVERITY_FATAL = 200 EDGE_SEVERITY_PANIC = 255 cdef enum EdgeConnectionStatus: EDGECON_NEW = 0 EDGECON_STARTED = 1 EDGECON_OK = 2 EDGECON_BAD = 3 cdef class EdgeConnection(frontend.FrontendConnection): cdef: EdgeConnectionStatus _con_status readonly dbview.DatabaseConnectionView _dbview object _startup_msg_waiter dbview.CompiledQuery _last_anon_compiled int64_t _last_anon_compiled_hash bint query_cache_enabled tuple protocol_version tuple max_protocol tuple min_protocol object last_state int last_state_id bint _in_dump_restore bytes _auth_data dict _conn_params cdef inline dbview.DatabaseConnectionView get_dbview(self) cdef parse_execute_request(self) cdef parse_cardinality(self, bytes card) cdef char render_cardinality(self, query_unit) except -1 cdef fallthrough(self) cdef sync_status(self) cdef WriteBuffer make_negotiate_protocol_version_msg( self, tuple target_proto ) cdef WriteBuffer make_command_data_description_msg( self, dbview.CompiledQuery query ) cdef WriteBuffer make_state_data_description_msg(self) cdef WriteBuffer make_command_complete_msg(self, capabilities, status) cdef inline 
ignore_headers(self) cdef dict parse_headers(self) cdef dict parse_annotations(self) cdef inline ignore_annotations(self) cdef get_checked_tag(self, dict annotations) cdef write_status(self, bytes name, bytes value) cdef write_edgedb_error(self, exc) cdef write_log(self, EdgeSeverity severity, uint32_t code, str message) @cython.final cdef class VirtualTransport: cdef: WriteBuffer buf bint closed object transport ================================================ FILE: edb/server/protocol/binary.pyx ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import asyncio import base64 import collections import contextlib import json import logging import time import statistics import traceback import sys cimport cython cimport cpython from typing import Dict, List, Optional, Sequence, Tuple from edb.server.protocol cimport cpythonx from libc.stdint cimport int8_t, uint8_t, int16_t, uint16_t, \ int32_t, uint32_t, int64_t, uint64_t, \ UINT32_MAX import immutables from edb import buildmeta from edb import edgeql from edb.edgeql import qltypes from edb.graphql import tokenizer as gql_tokenizer from edb.pgsql import parser as pgparser from edb.server.pgproto cimport hton from edb.server.pgproto.pgproto cimport ( WriteBuffer, ReadBuffer, FRBuffer, frb_init, frb_read, frb_read_all, frb_get_len, ) from edb.server.pgproto.pgproto import UUID as pg_UUID from edb.server.dbview cimport dbview from edb.server import config from edb.server import args as srvargs from edb.server import compiler from edb.server import defines as edbdef from edb.server.compiler import errormech from edb.server.compiler import enums from edb.server.compiler import sertypes from edb.server.compiler cimport rpc from edb.server.protocol cimport auth_helpers from edb.server.protocol import execute from edb.server.protocol cimport frontend from edb.server.pgcon cimport pgcon from edb.server.pgcon import errors as pgerror from edb.server import metrics from edb.schema import objects as s_obj from edb import errors from edb.errors import base as base_errors, EdgeQLSyntaxError from edb.common import debug from edb.protocol import messages from edb import _graphql_rewrite include "./consts.pxi" cdef bytes EMPTY_TUPLE_UUID = s_obj.get_known_type_id('empty-tuple').bytes cdef uint64_t PROTO_CAPS = enums.Capability.PROTO_CAPS cdef object CARD_NO_RESULT = compiler.Cardinality.NO_RESULT cdef object CARD_AT_MOST_ONE = compiler.Cardinality.AT_MOST_ONE cdef object CARD_MANY = compiler.Cardinality.MANY cdef object
FMT_NONE = compiler.OutputFormat.NONE cdef object FMT_BINARY = compiler.OutputFormat.BINARY cdef object LANG_EDGEQL = compiler.InputLanguage.EDGEQL cdef object LANG_SQL = compiler.InputLanguage.SQL cdef object LANG_GRAPHQL = compiler.InputLanguage.GRAPHQL cdef tuple DUMP_VER_MIN = (0, 7) cdef tuple DUMP_VER_MAX = edbdef.CURRENT_PROTOCOL cdef tuple MIN_PROTOCOL = edbdef.MIN_PROTOCOL cdef tuple CURRENT_PROTOCOL = edbdef.CURRENT_PROTOCOL cdef object logger = logging.getLogger('edb.server') cdef object log_metrics = logging.getLogger('edb.server.metrics') DEF QUERY_HEADER_DUMP_SECRETS = 0xFF10 def parse_catalog_version_header(value: bytes) -> uint64_t: if len(value) != 8: raise errors.BinaryProtocolError( f'catalog version value must be exactly 8 bytes (got {len(value)})' ) cdef uint64_t catver = hton.unpack_uint64(cpython.PyBytes_AS_STRING(value)) return catver cdef class EdgeConnection(frontend.FrontendConnection): interface = "edgeql" def __init__( self, server, tenant, *, auth_data: bytes, conn_params: dict[str, str] | None, protocol_version: edbdef.ProtocolVersion = CURRENT_PROTOCOL, **kwargs, ): super().__init__(server, tenant, **kwargs) self._con_status = EDGECON_NEW self._dbview = None self._last_anon_compiled = None self.query_cache_enabled = not (debug.flags.disable_qcache or debug.flags.edgeql_compile) self.protocol_version = protocol_version self.min_protocol = MIN_PROTOCOL self.max_protocol = CURRENT_PROTOCOL self._conn_params = conn_params self._in_dump_restore = False # Authentication data supplied by the transport (e.g. the content # of an HTTP Authorization header). self._auth_data = auth_data cdef is_in_tx(self): return self.get_dbview().in_tx() cdef inline dbview.DatabaseConnectionView get_dbview(self): if self._dbview is None: raise RuntimeError('Cannot access dbview while it is None') return self._dbview def debug_print(self, *args): if self._dbview is None: # This may happen before dbview is initialized, e.g. 
sending errors # to non-TLS clients due to mandatory TLS. print( '::EDGEPROTO::', f'id:{self._id}', f'in_tx:{0}', f'tx_error:{0}', *args, file=sys.stderr, ) else: print( '::EDGEPROTO::', f'id:{self._id}', f'in_tx:{int(self._dbview.in_tx())}', f'tx_error:{int(self._dbview.in_tx_error())}', *args, file=sys.stderr, ) def is_idle(self, expiry_time: float): # A connection is idle if it has been waiting for the next message # from the client for too long (even if it is in an open transaction!) return ( self._con_status != EDGECON_BAD and super().is_idle(expiry_time) and not self._in_dump_restore ) def is_alive(self): return self._con_status == EDGECON_STARTED and super().is_alive() def close_for_idling(self): try: self.write_edgedb_error( errors.IdleSessionTimeoutError( 'closing the connection due to idling') ) finally: self.close() # will flush cdef _after_idling(self): self.server.on_binary_client_after_idling(self) async def do_handshake(self): cdef: char mtype if self._transport_proto is srvargs.ServerConnTransport.HTTP: if self._conn_params is None: params = {} else: params = self._conn_params else: await self.wait_for_message(report_idling=True) mtype = self.buffer.get_message_type() if mtype != b'V': raise errors.BinaryProtocolError( f'unexpected initial message: "{chr(mtype)}", ' f'expected "V"') params = await self._do_handshake() if self._conn_params is not None: params = self._conn_params + params return params async def auth(self, params): cdef: WriteBuffer msg_buf WriteBuffer buf user = params.get('user') if not user: raise errors.BinaryProtocolError( f'missing required connection parameter in ClientHandshake ' f'message: "user"' ) user = self.tenant.resolve_user_name(user) database = params.get('database') branch = params.get('branch') if not database and not branch: raise errors.BinaryProtocolError( f'missing required connection parameter in ClientHandshake ' f'message: "branch" (or "database")' ) database = self.tenant.resolve_branch_name(database, branch)
logger.debug('received connection request by %s to database %s', user, database) await self._authenticate(user, database, params) logger.debug('successfully authenticated %s in database %s', user, database) if not self.tenant.is_database_connectable(database): raise errors.AccessError( f'database {database!r} does not accept connections' ) self.dbname = database self.username = user # In the tunneled HTTP endpoint, auth gets done after we have # set up a dbview, so we need to update it. if self._dbview: self._dbview._role_name = user await self._start_connection(database) if self._transport_proto is srvargs.ServerConnTransport.HTTP: return buf = WriteBuffer() msg_buf = WriteBuffer.new_message(b'R') msg_buf.write_int32(0) msg_buf.end_message() buf.write_buffer(msg_buf) msg_buf = WriteBuffer.new_message(b'K') # TODO: should send ID of this connection msg_buf.write_bytes(b'\x00' * 32) msg_buf.end_message() buf.write_buffer(msg_buf) if self.get_dbview().get_state_serializer() is None: await self.get_dbview().reload_state_serializer() buf.write_buffer(self.make_state_data_description_msg()) self.write(buf) # In dev mode we expose the backend postgres DSN if self.server.in_dev_mode(): params = self.tenant.get_pgaddr() params.update(database=self.tenant.get_pg_dbname( self.get_dbview().dbname )) params.clear_server_settings() self.write_status(b'pgdsn', params.to_dsn().encode()) self.write_status( b'suggested_pool_concurrency', str(self.tenant.suggested_client_pool_size).encode() ) self.write_status( b'system_config', self.tenant.get_report_config_data(self.protocol_version), ) self.write(self.sync_status()) self.flush() async def _do_handshake(self): cdef: uint16_t major uint16_t minor int i uint16_t reserved dict params = {} major = self.buffer.read_int16() minor = self.buffer.read_int16() self.protocol_version = major, minor nparams = self.buffer.read_int16() for i in range(nparams): k = self.buffer.read_len_prefixed_utf8() v = self.buffer.read_len_prefixed_utf8()
params[k] = v reserved = self.buffer.read_int16() if reserved != 0: raise errors.BinaryProtocolError( f'unexpected value in reserved field of ClientHandshake') self.buffer.finish_message() negotiate = False if self.protocol_version < self.min_protocol: target_proto = self.min_protocol negotiate = True elif self.protocol_version > self.max_protocol: target_proto = self.max_protocol negotiate = True else: target_proto = self.protocol_version if negotiate: self.write(self.make_negotiate_protocol_version_msg(target_proto)) self.flush() return params async def _start_connection(self, database: str) -> None: dbv = await self.tenant.new_dbview( dbname=database, query_cache=self.query_cache_enabled, protocol_version=self.protocol_version, role_name=self.username, ) assert type(dbv) is dbview.DatabaseConnectionView self._dbview = dbv self.dbname = database self._con_status = EDGECON_STARTED cdef stop_connection(self): self._con_status = EDGECON_BAD if self._dbview is not None: self.tenant.remove_dbview(self._dbview) self._dbview = None def _auth_jwt(self, user, database, params): # token in the HTTP header has higher priority than # the ClientHandshake message, under the scenario of # binary protocol over HTTP if self._auth_data: scheme, prefixed_token = auth_helpers.extract_token_from_auth_data( self._auth_data) if scheme != 'bearer': raise errors.AuthenticationError( 'authentication failed: unrecognized authentication scheme') else: prefixed_token = params.get('secret_key') return auth_helpers.auth_jwt( self.tenant, prefixed_token, user, database) cdef WriteBuffer _make_authentication_sasl_initial(self, list methods): cdef WriteBuffer msg_buf msg_buf = WriteBuffer.new_message(b'R') msg_buf.write_int32(10) # Number of auth methods followed by a series # of zero-terminated strings identifying each method, # sorted in the order of server preference. 
msg_buf.write_int32(len(methods)) for method in methods: msg_buf.write_len_prefixed_bytes(method) return msg_buf.end_message() cdef _expect_sasl_initial_response(self): mtype = self.buffer.get_message_type() if mtype != b'p': raise errors.BinaryProtocolError( f'expected SASL response, got message type {mtype}') selected_mech = self.buffer.read_len_prefixed_bytes() client_first = self.buffer.read_len_prefixed_bytes() self.buffer.finish_message() if not client_first: # The client didn't send the Client Initial Response # in SASLInitialResponse, this is an error. raise errors.BinaryProtocolError( f'client did not send the Client Initial Response ' f'data in SASLInitialResponse') return selected_mech, client_first cdef WriteBuffer _make_authentication_sasl_msg( self, bytes data, bint final ): cdef WriteBuffer msg_buf msg_buf = WriteBuffer.new_message(b'R') if final: msg_buf.write_int32(12) else: msg_buf.write_int32(11) msg_buf.write_len_prefixed_bytes(data) return msg_buf.end_message() cdef bytes _expect_sasl_response(self): mtype = self.buffer.get_message_type() if mtype != b'r': raise errors.BinaryProtocolError( f'expected SASL response, got message type {mtype}') client_final = self.buffer.read_len_prefixed_bytes() self.buffer.finish_message() return client_final async def _execute_script( self, compiled: object, bind_args: bytes, *, query_req: Optional[rpc.CompilationRequest] = None, ): cdef: pgcon.PGConnection conn dbview.DatabaseConnectionView dbv if self._cancelled: raise ConnectionAbortedError dbv = self.get_dbview() async with self.with_pgcon() as conn: await execute.execute_script( conn, dbv, compiled, bind_args, fe_conn=self, query_req=query_req, ) def _tokenize( self, eql: bytes, lang: enums.InputLanguage, ) -> edgeql.Source: text = eql.decode('utf-8') if lang is LANG_EDGEQL: if debug.flags.edgeql_disable_normalization: return edgeql.Source.from_string(text) else: return edgeql.NormalizedSource.from_string(text) elif lang is LANG_SQL: if 
debug.flags.edgeql_disable_normalization: return pgparser.Source.from_string(text) else: return pgparser.NormalizedSource.from_string(text) elif lang is LANG_GRAPHQL: if debug.flags.edgeql_disable_normalization: return gql_tokenizer.Source.from_string(text) else: try: return gql_tokenizer.NormalizedSource.from_string(text) except Exception: return gql_tokenizer.Source.from_string(text) else: raise errors.UnsupportedFeatureError( f"unsupported input language: {lang}") async def _suppress_tx_timeout(self): async with self.with_pgcon() as conn: await conn.sql_execute(b''' select pg_catalog.set_config( 'idle_in_transaction_session_timeout', '0', true) ''') async def _restore_tx_timeout(self, dbview.DatabaseConnectionView dbv): old_timeout = dbv.get_session_config().get( 'session_idle_transaction_timeout', ) timeout = ( 'NULL' if not old_timeout else repr(old_timeout.value.to_backend_str()) ) async with self.with_pgcon() as conn: await conn.sql_execute(f''' select pg_catalog.set_config( 'idle_in_transaction_session_timeout', {timeout}, true) '''.encode('utf-8')) async def _parse( self, rpc.CompilationRequest query_req, uint64_t allow_capabilities, tag=None, ) -> dbview.CompiledQuery: cdef dbview.DatabaseConnectionView dbv dbv = self.get_dbview() if self.debug: source = query_req.source text = source.text() self.debug_print('PARSE', text) self.debug_print( 'Cache key', source.cache_key(), f"protocol_version={query_req.protocol_version}", f"input_language={query_req.input_language}", f"output_format={query_req.output_format}", f"expect_one={query_req.expect_one}", f"implicit_limit={query_req.implicit_limit}", f"inline_typeids={query_req.inline_typeids}", f"inline_typenames={query_req.inline_typenames}", f"inline_objectids={query_req.inline_objectids}", f"allow_capabilities={allow_capabilities}", f"modaliases={dbv.get_modaliases()}", f"session_config={dbv.get_session_config()}", ) self.debug_print('Extra variables', source.variables(), 'after', source.first_extra())
query_unit_group = dbv.lookup_compiled_query(query_req) if query_unit_group is None: # If we have to do a compile within a transaction, suppress # the idle_in_transaction_session_timeout. suppress_timeout = dbv.in_tx() and not dbv.in_tx_error() if suppress_timeout: await self._suppress_tx_timeout() try: if query_req.input_language is LANG_SQL: async with self.with_pgcon() as pg_conn: return await dbv.parse( query_req, allow_capabilities=allow_capabilities, pgcon=pg_conn, tag=tag, ) else: return await dbv.parse( query_req, allow_capabilities=allow_capabilities, send_log_message=( lambda code, s: self.write_log( EdgeSeverity.EDGE_SEVERITY_NOTICE, code, s, ) ) ) finally: if suppress_timeout: try: await self._restore_tx_timeout(dbv) except pgerror.BackendError as ex: # dbv.parse() for LANG_SQL can send a SQL # query, which can put the transaction in a # bad state if it fails. If we fail because of # that, swallow it. if ( query_req.input_language is not LANG_SQL or not ex.code_is( pgerror.ERRCODE_IN_FAILED_SQL_TRANSACTION ) ): raise else: return dbv.as_compiled(query_req, query_unit_group) cdef parse_cardinality(self, bytes card): if card[0] == CARD_MANY.value: return CARD_MANY elif card[0] == CARD_AT_MOST_ONE.value: return CARD_AT_MOST_ONE else: try: card_name = compiler.Cardinality(card[0]).name except ValueError: raise errors.BinaryProtocolError( f'unknown expected cardinality "{repr(card)[2:-1]}"') else: raise errors.BinaryProtocolError( f'cardinality {card_name} cannot be requested') cdef char render_cardinality(self, query_unit_group) except -1: return query_unit_group.cardinality.value cdef dict parse_headers(self): cdef: dict attrs uint16_t num_fields uint16_t key bytes value attrs = {} num_fields = self.buffer.read_int16() while num_fields: key = self.buffer.read_int16() value = self.buffer.read_len_prefixed_bytes() attrs[key] = value num_fields -= 1 return attrs cdef inline ignore_headers(self): cdef: uint16_t num_fields num_fields = self.buffer.read_int16() 
while num_fields: self.buffer.read_int16() self.buffer.read_len_prefixed_bytes() num_fields -= 1 cdef dict parse_annotations(self): cdef: dict annos uint16_t num_annos str name, value annos = {} num_annos = self.buffer.read_int16() while num_annos: name = self.buffer.read_len_prefixed_utf8() value = self.buffer.read_len_prefixed_utf8() annos[name] = value num_annos -= 1 return annos cdef inline ignore_annotations(self): cdef: uint16_t num_annos num_annos = self.buffer.read_int16() while num_annos: self.buffer.read_len_prefixed_bytes() self.buffer.read_len_prefixed_bytes() num_annos -= 1 ############# cdef WriteBuffer make_negotiate_protocol_version_msg( self, tuple target_proto, ): cdef: WriteBuffer msg # NegotiateProtocolVersion msg = WriteBuffer.new_message(b'v') # Highest supported major version of the protocol. msg.write_int16(target_proto[0]) # Highest supported minor version of the protocol. msg.write_int16(target_proto[1]) # No extensions are currently supported. msg.write_int16(0) msg.end_message() return msg cdef WriteBuffer make_command_data_description_msg( self, dbview.CompiledQuery query ): cdef: WriteBuffer msg int16_t ann_count msg = WriteBuffer.new_message(b'T') ann_count = 0 if query.query_unit_group.warnings: ann_count += 1 if query.query_unit_group.unsafe_isolation_dangers: ann_count += 1 msg.write_int16(ann_count) if query.query_unit_group.warnings: warnings = json.dumps( [w.to_json() for w in query.query_unit_group.warnings] ).encode('utf-8') msg.write_len_prefixed_bytes(b'warnings') msg.write_len_prefixed_bytes(warnings) if query.query_unit_group.unsafe_isolation_dangers: dangers = json.dumps([ w.to_json() for w in query.query_unit_group.unsafe_isolation_dangers ]).encode('utf-8') msg.write_len_prefixed_bytes(b'unsafe_isolation_dangers') msg.write_len_prefixed_bytes(dangers) msg.write_int64( query.query_unit_group.capabilities & PROTO_CAPS ) msg.write_byte(self.render_cardinality(query.query_unit_group)) in_data = 
query.query_unit_group.in_type_data msg.write_bytes(query.query_unit_group.in_type_id) msg.write_len_prefixed_bytes(in_data) out_data = query.query_unit_group.out_type_data msg.write_bytes(query.query_unit_group.out_type_id) msg.write_len_prefixed_bytes(out_data) msg.end_message() return msg cdef WriteBuffer make_state_data_description_msg(self): cdef WriteBuffer msg type_id, type_data = self.get_dbview().describe_state() msg = WriteBuffer.new_message(b's') msg.write_bytes(type_id.bytes) msg.write_len_prefixed_bytes(type_data) msg.end_message() return msg cdef WriteBuffer make_command_complete_msg(self, capabilities, status): cdef: WriteBuffer msg state_tid, state_data = self.get_dbview().encode_state() msg = WriteBuffer.new_message(b'C') msg.write_int16(0) # no annotations msg.write_int64(capabilities & PROTO_CAPS) msg.write_len_prefixed_bytes(status) msg.write_bytes(state_tid.bytes) msg.write_len_prefixed_bytes(state_data) return msg.end_message() async def _execute_rollback(self, compiled: dbview.CompiledQuery): cdef: dbview.DatabaseConnectionView _dbview pgcon.PGConnection conn query_unit = compiled.query_unit_group[0] _dbview = self.get_dbview() if not ( query_unit.tx_savepoint_rollback or query_unit.tx_rollback or query_unit.tx_abort_migration ): _dbview.raise_in_tx_error() async with self.with_pgcon() as conn: if query_unit.sql: await conn.sql_execute(query_unit.sql) if query_unit.tx_abort_migration: _dbview.clear_tx_error() elif query_unit.tx_savepoint_rollback: _dbview.rollback_tx_to_savepoint(query_unit.sp_name) else: assert query_unit.tx_rollback _dbview.abort_tx() async def _execute( self, compiled: dbview.CompiledQuery, bind_args: bytes, use_prep_stmt: bint, *, query_req: Optional[rpc.CompilationRequest] = None, ): cdef: dbview.DatabaseConnectionView dbv pgcon.PGConnection conn dbv = self.get_dbview() async with self.with_pgcon() as conn: await execute.execute( conn, dbv, compiled, bind_args, fe_conn=self, use_prep_stmt=use_prep_stmt, 
query_req=query_req, ) query_unit = compiled.query_unit_group[0] if query_unit.config_requires_restart: self.write_log( EdgeSeverity.EDGE_SEVERITY_NOTICE, errors.LogMessage.get_code(), 'server restart is required for the configuration ' 'change to take effect') cdef parse_execute_request(self): cdef: uint64_t allow_capabilities = 0 uint64_t compilation_flags = 0 int64_t implicit_limit = 0 bint inline_typenames = False bint inline_typeids = False bint inline_objectids = False object cardinality object output_format bint expect_one = False bytes query dbview.DatabaseConnectionView _dbview allow_capabilities = PROTO_CAPS & self.buffer.read_int64() compilation_flags = self.buffer.read_int64() implicit_limit = self.buffer.read_int64() if implicit_limit < 0: raise errors.BinaryProtocolError( f'implicit limit cannot be negative' ) inline_typenames = ( compilation_flags & messages.CompilationFlag.INJECT_OUTPUT_TYPE_NAMES ) inline_typeids = ( compilation_flags & messages.CompilationFlag.INJECT_OUTPUT_TYPE_IDS ) inline_objectids = ( compilation_flags & messages.CompilationFlag.INJECT_OUTPUT_OBJECT_IDS ) if self.protocol_version >= (3, 0): lang = rpc.deserialize_input_language(self.buffer.read_byte()) else: lang = LANG_EDGEQL output_format = rpc.deserialize_output_format(self.buffer.read_byte()) if ( lang is LANG_SQL and output_format is not FMT_NONE and output_format is not FMT_BINARY ): raise errors.UnsupportedFeatureError( "non-binary output format is not supported with " "SQL as the input language" ) cardinality = self.parse_cardinality(self.buffer.read_byte()) expect_one = cardinality is CARD_AT_MOST_ONE if lang is LANG_SQL and cardinality is not CARD_MANY: raise errors.UnsupportedFeatureError( "output cardinality assertions are not supported with " "SQL as the input language" ) query = self.buffer.read_len_prefixed_bytes() if not query: raise errors.BinaryProtocolError('empty query') metrics.query_size.observe( len(query), self.get_tenant_label(), 'edgeql' ) _dbview = 
self.get_dbview() state_tid = self.buffer.read_bytes(16) state_data = self.buffer.read_len_prefixed_bytes() try: _dbview.decode_state(state_tid, state_data) except errors.StateMismatchError: self.write(self.make_state_data_description_msg()) raise cfg_ser = self.server.compilation_config_serializer rv = rpc.CompilationRequest( source=self._tokenize(query, lang), protocol_version=self.protocol_version, schema_version=_dbview.schema_version, compilation_config_serializer=cfg_ser, input_language=lang, output_format=output_format, expect_one=expect_one, implicit_limit=implicit_limit, inline_typeids=inline_typeids, inline_typenames=inline_typenames, inline_objectids=inline_objectids, modaliases=_dbview.get_modaliases(), session_config=_dbview.get_session_config(), database_config=_dbview.get_database_config(), system_config=_dbview.get_compilation_system_config(), role_name=self.username, branch_name=self.dbname, ) return rv, allow_capabilities cdef get_checked_tag(self, dict annotations): tag = annotations.get("tag") if not tag: return None if len(tag) > 128: raise errors.BinaryProtocolError( 'bad annotation: tag too long (> 128 bytes)') return tag async def parse(self): cdef: bytes eql rpc.CompilationRequest query_req dbview.DatabaseConnectionView _dbview WriteBuffer parse_complete WriteBuffer buf uint64_t allow_capabilities self._last_anon_compiled = None if self.protocol_version >= (3, 0): self.ignore_annotations() else: self.ignore_headers() _dbview = self.get_dbview() if _dbview.get_state_serializer() is None: await _dbview.reload_state_serializer() query_req, allow_capabilities = self.parse_execute_request() compiled = await self._parse(query_req, allow_capabilities) buf = self.make_command_data_description_msg(compiled) # Cache compilation result in anticipation that the client # will follow up with an Execute immediately. # # N.B.: we cannot rely on query cache because not all units # are cacheable. 
self._last_anon_compiled = compiled self._last_anon_compiled_hash = hash(query_req) self.write(buf) self.flush() async def execute(self): cdef: rpc.CompilationRequest query_req dbview.DatabaseConnectionView _dbview bytes in_tid bytes out_tid bytes args uint64_t allow_capabilities if self.protocol_version >= (3, 0): tag = self.get_checked_tag(self.parse_annotations()) else: self.ignore_headers() tag = None _dbview = self.get_dbview() if _dbview.get_state_serializer() is None: await _dbview.reload_state_serializer() query_req, allow_capabilities = self.parse_execute_request() in_tid = self.buffer.read_bytes(16) out_tid = self.buffer.read_bytes(16) args = self.buffer.read_len_prefixed_bytes() self.buffer.finish_message() compiled = None if ( self._last_anon_compiled is not None and hash(query_req) == self._last_anon_compiled_hash and in_tid == self._last_anon_compiled.query_unit_group.in_type_id and out_tid == self._last_anon_compiled.query_unit_group.out_type_id ): compiled = self._last_anon_compiled query_unit_group = compiled.query_unit_group else: query_unit_group = _dbview.lookup_compiled_query(query_req) if query_unit_group is None: if self.debug: self.debug_print('EXECUTE /CACHE MISS', query_req.source.text()) compiled = await self._parse(query_req, allow_capabilities, tag) query_unit_group = compiled.query_unit_group # If this is a graphql request, and the compilation of it # depends on reading the value of some variables (because they # are used in @include or as params to type introspection, for # example) we need to reflect those variables into the # query_req and look it up again, and then maybe compile again. # # What a pain! if query_unit_group.graphql_key_variables: key_vars = _extract_key_vars(query_unit_group, query_req, args) query_req = query_req.__copy__() query_req.set_key_params(key_vars) compiled = None query_unit_group = _dbview.lookup_compiled_query(query_req) # If we had to do a graphql_key_variables lookup, we might need # to compile again. 
if query_unit_group is None: if self.debug: self.debug_print( 'EXECUTE /CACHE MISS (graphql nonsense)', query_req.source.text(), ) compiled = await self._parse(query_req, allow_capabilities, tag) query_unit_group = compiled.query_unit_group else: if not compiled: compiled = _dbview.as_compiled(query_req, query_unit_group) compiled.tag = tag if self._cancelled: raise ConnectionAbortedError self._query_count += 1 # Clear the _last_anon_compiled so that the next Execute - if # identical - will always lookup in the cache and honor the # `cacheable` flag to compile the query again. self._last_anon_compiled = None _dbview.check_capabilities( query_unit_group, allow_capabilities, errors.DisabledCapabilityError, "disabled by the client", unsafe_isolation_dangers=query_unit_group.unsafe_isolation_dangers, ) if query_unit_group.in_type_id != in_tid: self.write(self.make_command_data_description_msg(compiled)) raise errors.ParameterTypeMismatchError( "specified parameter type(s) do not match the parameter " "types inferred from specified command(s)" ) if ( query_unit_group.out_type_id != out_tid or query_unit_group.warnings ): # The client has no up-to-date information about the output, # so provide one. 
self.write(self.make_command_data_description_msg(compiled)) if self.debug: self.debug_print('EXECUTE', query_req.source.text()) force_script = any(x.needs_readback for x in query_unit_group) if ( _dbview.in_tx_error() or query_unit_group[0].tx_savepoint_rollback or query_unit_group[0].tx_abort_migration ): assert len(query_unit_group) == 1 await self._execute_rollback(compiled) elif len(query_unit_group) > 1 or force_script: await self._execute_script(compiled, args, query_req=query_req) else: use_prep = ( len(query_unit_group) == 1 and bool(query_unit_group[0].sql_hash) ) await self._execute(compiled, args, use_prep, query_req=query_req) if self._cancelled: raise ConnectionAbortedError if _dbview.is_state_desc_changed(): self.write(self.make_state_data_description_msg()) self.write( self.make_command_complete_msg( compiled.query_unit_group.capabilities, compiled.query_unit_group[-1].status, ) ) self.flush() async def sync(self): self.buffer.consume_message() self.write(self.sync_status()) if self.debug: self.debug_print('SYNC') self.flush() def check_readiness(self): if self.tenant.is_blocked(): readiness_reason = self.tenant.get_readiness_reason() msg = "the server is not accepting requests" if readiness_reason: msg = f"{msg}: {readiness_reason}" raise errors.ServerBlockedError(msg) elif not self.tenant.is_online(): readiness_reason = self.tenant.get_readiness_reason() msg = "the server is going offline" if readiness_reason: msg = f"{msg}: {readiness_reason}" raise errors.ServerOfflineError(msg) async def authenticate(self): self.check_readiness() params = await self.do_handshake() await self.auth(params) self.server.on_binary_client_authed(self) async def main_step(self, char mtype): try: self.check_readiness() if mtype == b'O': await self.execute() elif mtype == b'P': await self.parse() elif mtype == b'S': await self.sync() elif mtype == b'X': self.close() return True elif mtype == b'>': await self.dump() elif mtype == b'<': # The restore protocol cannot send 
                # SYNC beforehand,
                # so if an error occurs the server should send an
                # ERROR message immediately.
                await self.restore()

            elif mtype == b'D':
                raise errors.BinaryProtocolError(
                    "Describe message (D) is not supported in "
                    "protocol versions greater than 0.13")

            elif mtype == b'E':
                raise errors.BinaryProtocolError(
                    "Legacy Execute message (E) is not supported in "
                    "protocol versions greater than 0.13")

            elif mtype == b'Q':
                raise errors.BinaryProtocolError(
                    "ExecuteScript message (Q) is not supported in "
                    "protocol versions greater than 0.13")

            else:
                self.fallthrough()

        except ConnectionError:
            raise

        except asyncio.CancelledError:
            raise

        except Exception as ex:
            if self._cancelled and \
                    isinstance(ex, pgerror.BackendQueryCancelledError):
                # If we are cancelling the protocol (meaning that the
                # client side of the connection has dropped and we
                # need to gracefully clean up and abort) we want to
                # propagate the BackendQueryCancelledError exception.
                #
                # If we're not cancelling, we'll treat it just like
                # any other error coming from Postgres (a query
                # might get cancelled for a variety of reasons).
                raise

            # The connection has been aborted; there's nothing
            # we can do except shutting this down.
            if self._con_status == EDGECON_BAD:
                return True

            self.get_dbview().tx_error()
            self.buffer.finish_message()

            ex = await self.interpret_error(ex)

            self.write_edgedb_error(ex)

            if isinstance(
                ex,
                (errors.ServerOfflineError, errors.ServerBlockedError),
            ):
                # This server is going into "offline" or "blocked" mode,
                # close the connection.
                self.write(self.sync_status())
                self.flush()
                self.close()
                return

            self.flush()

            # The connection was aborted while we were
            # interpreting the error (via compiler/errmech.py).
if self._con_status == EDGECON_BAD: return True await self.recover_from_error() else: self.buffer.finish_message() cdef _main_task_stopped_normally(self): self.write_log( EdgeSeverity.EDGE_SEVERITY_NOTICE, errors.LogMessage.get_code(), 'requested to stop; disconnecting now') async def recover_from_error(self): # Consume all messages until sync. while True: if not self.buffer.take_message(): await self.wait_for_message(report_idling=True) mtype = self.buffer.get_message_type() if mtype == b'S': await self.sync() return else: self.buffer.discard_message() cdef write_error(self, exc): self.write_edgedb_error(execute.interpret_simple_error(exc)) cdef write_edgedb_error(self, exc): cdef: WriteBuffer buf int16_t fields_len if self.debug and not isinstance(exc, errors.BackendUnavailableError): self.debug_print('EXCEPTION', type(exc).__name__, exc) from edb.common.markup import dump dump(exc) if debug.flags.server and not isinstance( exc, errors.BackendUnavailableError ): self.loop.call_exception_handler({ 'message': ( 'an error in edgedb protocol' ), 'exception': exc, 'protocol': self, 'transport': self._transport, }) fields = {} if isinstance(exc, errors.EdgeDBError): fields.update(exc._attrs) if isinstance(exc, errors.TransactionSerializationError): metrics.transaction_serialization_errors.inc( 1.0, self.get_tenant_label() ) try: formatted_error = exc.__formatted_error__ except AttributeError: try: formatted_error = ''.join( traceback.format_exception( type(exc), exc, exc.__traceback__, limit=50)) except Exception: formatted_error = 'could not serialize error traceback' fields[base_errors.FIELD_SERVER_TRACEBACK] = formatted_error buf = WriteBuffer.new_message(b'E') buf.write_byte(EdgeSeverity.EDGE_SEVERITY_ERROR) buf.write_int32(exc.get_code()) buf.write_len_prefixed_utf8(str(exc)) buf.write_int16(len(fields)) for k, v in fields.items(): buf.write_int16(k) buf.write_len_prefixed_utf8(str(v)) buf.end_message() self.write(buf) async def interpret_error(self, exc): dbv = 
self.get_dbview()
        return await execute.interpret_error(
            exc,
            dbv._db,
            global_schema_pickle=dbv.get_global_schema_pickle(),
            user_schema_pickle=dbv.get_user_schema_pickle(),
        )

    cdef write_status(self, bytes name, bytes value):
        cdef:
            WriteBuffer buf

        buf = WriteBuffer.new_message(b'S')
        buf.write_len_prefixed_bytes(name)
        buf.write_len_prefixed_bytes(value)
        buf.end_message()

        self.write(buf)

    cdef write_log(self, EdgeSeverity severity, uint32_t code, str message):
        cdef:
            WriteBuffer buf

        if severity >= EdgeSeverity.EDGE_SEVERITY_ERROR:
            # This is an assertion.
            raise errors.InternalServerError(
                'emitting a log message with severity=ERROR')

        buf = WriteBuffer.new_message(b'L')
        buf.write_byte(severity)
        buf.write_int32(code)
        buf.write_len_prefixed_utf8(message)
        buf.write_int16(0)  # number of annotations
        buf.end_message()

        self.write(buf)

    cdef sync_status(self):
        cdef:
            WriteBuffer buf
            dbview.DatabaseConnectionView _dbview

        buf = WriteBuffer.new_message(b'Z')
        buf.write_int16(0)  # no annotations

        # NOTE: EdgeDB and PostgreSQL current statuses can disagree.
        # For example, Postgres can be "PQTRANS_INTRANS" whereas EdgeDB
        # would be "PQTRANS_INERROR". This can happen because some
        # EdgeDB errors occur at the compile stage, without ever
        # reaching Postgres.
_dbview = self.get_dbview() if _dbview.in_tx_error(): buf.write_byte(b'E') elif _dbview.in_tx(): buf.write_byte(b'T') else: buf.write_byte(b'I') return buf.end_message() cdef fallthrough(self): cdef: char mtype = self.buffer.get_message_type() if mtype == b'H': # Flush self.buffer.discard_message() self.flush() elif mtype == b'X': # Terminate self.buffer.discard_message() self.close() else: raise errors.BinaryProtocolError( f'unexpected message type {chr(mtype)!r}') def connection_made(self, transport): if self._con_status != EDGECON_NEW: raise errors.BinaryProtocolError( 'invalid connection status while establishing the connection') super().connection_made(transport) cdef _main_task_created(self): self.server.on_binary_client_connected(self) def connection_lost(self, exc): self.server.on_binary_client_disconnected(self) super().connection_lost(exc) @contextlib.asynccontextmanager async def _with_dump_restore_pgcon(self): self._in_dump_restore = True try: async with self.with_pgcon() as conn: yield conn finally: self._in_dump_restore = False # If backpressure was being applied during the operation, release it. # `resume_reading` is idempotent. 
            self._transport.resume_reading()

    async def dump(self):
        cdef:
            WriteBuffer msg_buf
            dbview.DatabaseConnectionView _dbview
            uint64_t flags

        # Parse the "Dump" message
        if self.protocol_version >= (3, 0):
            self.ignore_annotations()
            flags = self.buffer.read_int64()
            include_secrets = flags & messages.DumpFlag.DUMP_SECRETS
        else:
            headers = self.parse_headers()
            include_secrets = headers.get(QUERY_HEADER_DUMP_SECRETS) == b'\x01'

        self.buffer.finish_message()

        _dbview = self.get_dbview()
        if _dbview.txid:
            raise errors.ProtocolError(
                'DUMP must not be executed while in transaction'
            )

        is_superuser, _ = _dbview.get_permissions()
        if not is_superuser:
            raise errors.DisabledCapabilityError(
                f'role {_dbview._role_name} does not have permission to '
                f'perform dump'
            )

        server = self.server
        compiler_pool = server.get_compiler_pool()

        dbname = _dbview.dbname
        async with self._with_dump_restore_pgcon() as pgcon:
            # To avoid having races, we want to:
            #
            # 1. start a transaction;
            #
            # 2. in the compiler process we connect to that transaction
            #    and re-introspect the schema in it.
            #
            # 3. have every dump worker operate on that same pg
            #    connection.
            #
            # This guarantees that every pg connection and the compiler work
            # with the same DB state.

            await pgcon.sql_execute(
                b'''START TRANSACTION
                        ISOLATION LEVEL SERIALIZABLE
                        READ ONLY
                        DEFERRABLE;

                    -- Disable transaction or query execution timeout
                    -- limits. Both clients and the server can be slow
                    -- during the dump/restore process.
SET LOCAL idle_in_transaction_session_timeout = 0; SET LOCAL statement_timeout = 0; ''', ) user_schema_json = await server.introspect_user_schema_json(pgcon) global_schema_json = ( await server.introspect_global_schema_json(pgcon) ) db_config_json = await server.introspect_db_config(pgcon) dump_protocol = self.max_protocol schema_ddl, schema_dynamic_ddl, schema_ids, blocks = ( await compiler_pool.describe_database_dump( user_schema_json, global_schema_json, db_config_json, dump_protocol, include_secrets, ) ) if schema_dynamic_ddl: for query in schema_dynamic_ddl: result = await pgcon.sql_fetch_val(query.encode('utf-8')) if result: schema_ddl += '\n' + result.decode('utf-8') msg_buf = WriteBuffer.new_message(b'@') # DumpHeader msg_buf.write_int16(4) # number of key-value pairs msg_buf.write_int16(DUMP_HEADER_BLOCK_TYPE) msg_buf.write_len_prefixed_bytes(DUMP_HEADER_BLOCK_TYPE_INFO) msg_buf.write_int16(DUMP_HEADER_SERVER_VER) msg_buf.write_len_prefixed_utf8(str(buildmeta.get_version())) msg_buf.write_int16(DUMP_HEADER_SERVER_CATALOG_VERSION) msg_buf.write_int32(8) msg_buf.write_int64(buildmeta.EDGEDB_CATALOG_VERSION) msg_buf.write_int16(DUMP_HEADER_SERVER_TIME) msg_buf.write_len_prefixed_utf8(str(int(time.time()))) msg_buf.write_int16(dump_protocol[0]) msg_buf.write_int16(dump_protocol[1]) msg_buf.write_len_prefixed_utf8(schema_ddl) msg_buf.write_int32(len(schema_ids)) for (tn, td, tid) in schema_ids: msg_buf.write_len_prefixed_utf8(tn) msg_buf.write_len_prefixed_utf8(td) assert len(tid) == 16 msg_buf.write_bytes(tid) # uuid msg_buf.write_int32(len(blocks)) for block in blocks: assert len(block.schema_object_id.bytes) == 16 msg_buf.write_bytes(block.schema_object_id.bytes) # uuid msg_buf.write_len_prefixed_bytes(block.type_desc) msg_buf.write_int16(len(block.schema_deps)) for depid in block.schema_deps: assert len(depid.bytes) == 16 msg_buf.write_bytes(depid.bytes) # uuid self._transport.write(memoryview(msg_buf.end_message())) self.flush() blocks_queue = 
collections.deque(blocks) output_queue = asyncio.Queue(maxsize=2) async with asyncio.TaskGroup() as g: g.create_task(pgcon.dump( blocks_queue, output_queue, DUMP_BLOCK_SIZE, )) nstops = 0 while True: if self._cancelled: raise ConnectionAbortedError out = await output_queue.get() if out is None: nstops += 1 if nstops == 1: # we only have one worker right now break else: block, block_num, data = out msg_buf = WriteBuffer.new_message(b'=') # DumpBlock msg_buf.write_int16(4) # number of key-value pairs msg_buf.write_int16(DUMP_HEADER_BLOCK_TYPE) msg_buf.write_len_prefixed_bytes( DUMP_HEADER_BLOCK_TYPE_DATA) msg_buf.write_int16(DUMP_HEADER_BLOCK_ID) msg_buf.write_len_prefixed_bytes( block.schema_object_id.bytes) msg_buf.write_int16(DUMP_HEADER_BLOCK_NUM) msg_buf.write_len_prefixed_bytes( str(block_num).encode()) msg_buf.write_int16(DUMP_HEADER_BLOCK_DATA) msg_buf.write_len_prefixed_buffer(data) self._transport.write(memoryview(msg_buf.end_message())) if self._write_waiter: await self._write_waiter await pgcon.sql_execute(b"ROLLBACK;") msg_buf = WriteBuffer.new_message(b'C') # CommandComplete msg_buf.write_int16(0) # no annotations msg_buf.write_int64(0) # capabilities msg_buf.write_len_prefixed_bytes(b'DUMP') msg_buf.write_bytes(sertypes.NULL_TYPE_ID.bytes) msg_buf.write_len_prefixed_bytes(b'') self.write(msg_buf.end_message()) self.flush() async def _execute_utility_stmt(self, eql: str, pgcon): cdef dbview.DatabaseConnectionView _dbview = self.get_dbview() cfg_ser = self.server.compilation_config_serializer query_req = rpc.CompilationRequest( source=edgeql.Source.from_string(eql), protocol_version=self.protocol_version, schema_version=_dbview.schema_version, compilation_config_serializer=cfg_ser, role_name=self.username, branch_name=self.dbname, ) compiled = await _dbview.parse(query_req) query_unit_group = compiled.query_unit_group assert len(query_unit_group) == 1 query_unit = query_unit_group[0] try: _dbview.start(query_unit) await pgcon.sql_execute(query_unit.sql) 
except Exception: _dbview.on_error() if ( query_unit.tx_commit and not pgcon.in_tx() and _dbview.in_tx() ): # The COMMIT command has failed. Our Postgres connection # isn't in a transaction anymore. Abort the transaction # in dbview. _dbview.abort_tx() raise else: _dbview.on_success(query_unit, {}) # _execute_utility_stmt is only used in restore(), where the state # serializer is not coming with the COMMIT command. However, we try # to keep the state serializer here anyways in case of future use if query_unit_group.state_serializer is not None: _dbview.set_state_serializer(query_unit_group.state_serializer) async def restore(self): cdef: WriteBuffer msg_buf char mtype dbview.DatabaseConnectionView _dbview _dbview = self.get_dbview() if _dbview.txid: raise errors.ProtocolError( 'RESTORE must not be executed while in transaction' ) is_superuser, _ = _dbview.get_permissions() if not is_superuser: raise errors.DisabledCapabilityError( f'role {_dbview._role_name} does not have permission to ' f'perform restore' ) if _dbview.get_state_serializer() is None: await _dbview.reload_state_serializer() # Parse the "Restore" message if self.buffer.read_int16() != 0: # number of attributes raise errors.BinaryProtocolError('unexpected attributes') self.buffer.read_int16() # discard -j level # Now parse the embedded "DumpHeader" message: server = self.server compiler_pool = server.get_compiler_pool() global_schema_pickle = _dbview.get_global_schema_pickle() user_schema_pickle = _dbview.get_user_schema_pickle() dump_server_ver_str = None cat_ver = None headers_num = self.buffer.read_int16() for _ in range(headers_num): hdrname = self.buffer.read_int16() hdrval = self.buffer.read_len_prefixed_bytes() if hdrname == DUMP_HEADER_SERVER_VER: dump_server_ver_str = hdrval.decode('utf-8') if hdrname == DUMP_HEADER_SERVER_CATALOG_VERSION: cat_ver = parse_catalog_version_header(hdrval) proto_major = self.buffer.read_int16() proto_minor = self.buffer.read_int16() proto = (proto_major, 
proto_minor) if proto > DUMP_VER_MAX or proto < DUMP_VER_MIN: raise errors.ProtocolError( f'unsupported dump version {proto_major}.{proto_minor}') schema_ddl = self.buffer.read_len_prefixed_bytes() ids_num = self.buffer.read_int32() schema_ids = [] for _ in range(ids_num): schema_ids.append(( self.buffer.read_len_prefixed_utf8(), self.buffer.read_len_prefixed_utf8(), self.buffer.read_bytes(16), )) block_num = self.buffer.read_int32() blocks = [] for _ in range(block_num): blocks.append(( self.buffer.read_bytes(16), self.buffer.read_len_prefixed_bytes(), )) # Ignore deps info for _ in range(self.buffer.read_int16()): self.buffer.read_bytes(16) self.buffer.finish_message() dbname = _dbview.dbname async with self._with_dump_restore_pgcon() as pgcon: _dbview.decode_state(sertypes.NULL_TYPE_ID.bytes, b'') await self._execute_utility_stmt( 'START TRANSACTION', pgcon, ) try: await pgcon.sql_execute( b''' -- Drop isolation level. SET TRANSACTION ISOLATION LEVEL READ COMMITTED; -- Disable transaction or query execution timeout -- limits. Both clients and the server can be slow -- during the dump/restore process. 
SET LOCAL idle_in_transaction_session_timeout = 0; SET LOCAL statement_timeout = 0; ''', ) schema_sql_units, restore_blocks, tables, repopulate_units = \ await compiler_pool.describe_database_restore( user_schema_pickle, global_schema_pickle, dump_server_ver_str, cat_ver, schema_ddl, schema_ids, blocks, proto, ) for query_unit in schema_sql_units: new_types = None _dbview.start(query_unit) try: if query_unit.config_ops: for op in query_unit.config_ops: if op.scope is config.ConfigScope.INSTANCE: raise errors.ProtocolError( 'CONFIGURE INSTANCE cannot be executed' ' in dump restore' ) if query_unit.sql: if query_unit.ddl_stmt_id: await pgcon.parse_execute(query=query_unit) ddl_ret = pgcon.load_last_ddl_return(query_unit) if ddl_ret and ddl_ret['new_types']: new_types = ddl_ret['new_types'] else: await pgcon.sql_execute(query_unit.sql) except Exception: _dbview.on_error() raise else: _dbview.on_success(query_unit, new_types) restore_blocks = { b.schema_object_id: b for b in restore_blocks } disable_trigger_q = '' enable_trigger_q = '' for table in tables: disable_trigger_q += ( f'ALTER TABLE {table} DISABLE TRIGGER ALL;' ) enable_trigger_q += ( f'ALTER TABLE {table} ENABLE TRIGGER ALL;' ) await pgcon.sql_execute(disable_trigger_q.encode()) # Send "RestoreReady" message msg = WriteBuffer.new_message(b'+') msg.write_int16(0) # no annotations msg.write_int16(1) # -j1 self.write(msg.end_message()) self.flush() while True: if not self.buffer.take_message(): # Don't report idling when restoring a dump. # This is an edge case and the client might be # legitimately slow. 
await self.wait_for_message(report_idling=False) mtype = self.buffer.get_message_type() if mtype == b'=': # RestoreBlock block_type = None block_id = None block_num = None block_data = None num_headers = self.buffer.read_int16() for _ in range(num_headers): header = self.buffer.read_int16() if header == DUMP_HEADER_BLOCK_TYPE: block_type = self.buffer.read_len_prefixed_bytes() elif header == DUMP_HEADER_BLOCK_ID: block_id = self.buffer.read_len_prefixed_bytes() block_id = pg_UUID(block_id) elif header == DUMP_HEADER_BLOCK_NUM: block_num = self.buffer.read_len_prefixed_bytes() elif header == DUMP_HEADER_BLOCK_DATA: block_data = self.buffer.read_len_prefixed_bytes() self.buffer.finish_message() if (block_type is None or block_id is None or block_num is None or block_data is None): raise errors.ProtocolError('incomplete data block') restore_block = restore_blocks[block_id] type_id_map = self._build_type_id_map_for_restore_mending( restore_block) self._transport.pause_reading() await pgcon.restore(restore_block, block_data, type_id_map) self._transport.resume_reading() elif mtype == b'.': # RestoreEof self.buffer.finish_message() break else: self.fallthrough() for repopulate_unit in repopulate_units: await pgcon.sql_execute(repopulate_unit.encode()) await pgcon.sql_execute(enable_trigger_q.encode()) except Exception: await pgcon.sql_execute(b'ROLLBACK') _dbview.abort_tx() raise else: await self._execute_utility_stmt('COMMIT', pgcon) execute.signal_side_effects(_dbview, dbview.SideEffects.SchemaChanges) await self.tenant.introspect_db(dbname) if _dbview.is_state_desc_changed(): self.write(self.make_state_data_description_msg()) state_tid, state_data = _dbview.encode_state() msg = WriteBuffer.new_message(b'C') # CommandComplete msg.write_int16(0) # no annotations msg.write_int64(0) # capabilities msg.write_len_prefixed_bytes(b'RESTORE') msg.write_bytes(state_tid.bytes) msg.write_len_prefixed_bytes(state_data) self.write(msg.end_message()) self.flush() def 
_build_type_id_map_for_restore_mending(self, restore_block): type_map = {} descriptor_stack = [] if not restore_block.data_mending_desc: return type_map descriptor_stack.append(restore_block.data_mending_desc) while descriptor_stack: desc_tuple = descriptor_stack.pop() for desc in desc_tuple: if desc is not None: type_map[desc.schema_type_id] = ( self.get_dbview().resolve_backend_type_id( desc.schema_type_id, ) ) descriptor_stack.append(desc.elements) return type_map @cython.final cdef class VirtualTransport: def __init__(self, transport): self.buf = WriteBuffer.new() self.closed = False self.transport = transport def write(self, data): self.buf.write_bytes(bytes(data)) def _get_data(self): return bytes(self.buf) def is_closing(self): return self.closed def close(self): self.closed = True def abort(self): self.closed = True def get_extra_info(self, name, default=None): return self.transport.get_extra_info(name, default) async def eval_buffer( server, tenant, database: str, data: bytes, conn_params: dict[str, str], protocol_version: edbdef.ProtocolVersion, auth_data: bytes, transport: srvargs.ServerConnTransport, tcp_transport: asyncio.Transport, ): cdef: VirtualTransport vtr EdgeConnection proto vtr = VirtualTransport(tcp_transport) proto = new_edge_connection( server, tenant, passive=True, auth_data=auth_data, transport=transport, conn_params=conn_params, protocol_version=protocol_version, ) proto.connection_made(vtr) if vtr.is_closing() or proto._main_task is None: raise RuntimeError( 'cannot process the request, the server is shutting down') # HACK: In the tunneled protocol we don't have the username when # we create the dbview, so put in an empty username. It will be # filled in once auth is called. 
proto.username = '' try: await proto._start_connection(database) proto.data_received(data) await proto._main_task except Exception as ex: proto.connection_lost(ex) else: proto.connection_lost(None) data = vtr._get_data() return data def new_edge_connection( server, tenant, *, external_auth: bool = False, passive: bool = False, transport: srvargs.ServerConnTransport = ( srvargs.ServerConnTransport.TCP), auth_data: bytes = b'', protocol_version: edbdef.ProtocolVersion = edbdef.CURRENT_PROTOCOL, conn_params: dict[str, str] | None = None, connection_made_at: float | None = None, ): return EdgeConnection( server, tenant, external_auth=external_auth, passive=passive, transport=transport, auth_data=auth_data, protocol_version=protocol_version, conn_params=conn_params, connection_made_at=connection_made_at, ) async def run_script( server, tenant, database: str, user: str, script: str, ) -> None: cdef: EdgeConnection conn dbview.CompiledQuery compiled dbview.DatabaseConnectionView _dbview conn = new_edge_connection(server, tenant) conn.username = user await conn._start_connection(database) try: _dbview = conn.get_dbview() cfg_ser = server.compilation_config_serializer compiled = await _dbview.parse( rpc.CompilationRequest( source=edgeql.Source.from_string(script), protocol_version=conn.protocol_version, schema_version=_dbview.schema_version, compilation_config_serializer=cfg_ser, output_format=FMT_NONE, role_name=user, branch_name=database, ), ) compiled.tag = "gel/startup-script" if len(compiled.query_unit_group) > 1: await conn._execute_script(compiled, b'') else: await conn._execute(compiled, b'', use_prep_stmt=0) except Exception as e: exc = await conn.interpret_error(e) if isinstance(exc, errors.EdgeDBError): raise exc from None else: raise exc finally: conn.close() cdef _extract_key_vars( qug: dbstate.QueryUnitGroup, query_req: rpc.CompilationRequest, args: bytes ): cdef: FRBuffer in_buf char *p int32_t recv_args int32_t decl_args ssize_t in_len frb_init( &in_buf, 
cpython.PyBytes_AS_STRING(args), cpython.Py_SIZE(args)) keys = qug.graphql_key_variables in_type_args = qug.in_type_args or () decl_args = len(in_type_args) if args: recv_args = hton.unpack_int32(frb_read(&in_buf, 4)) else: recv_args = 0 if recv_args != decl_args: raise errors.InputDataError( f"invalid argument count, " f"expected: {decl_args}, got: {recv_args}") vals = {} for param in in_type_args: frb_read(&in_buf, 4) # reserved needed = param.name in keys in_len = hton.unpack_int32(frb_read(&in_buf, 4)) if not needed: if in_len > 0: frb_read(&in_buf, in_len) continue if in_len < 0: val = None else: p = frb_read(&in_buf, in_len) # Very hacky and minimal decoding support. if param.typename == 'std::str': val = cpython.PyUnicode_DecodeUTF8(p, in_len, NULL) elif param.typename == 'std::bool': val = p[0] != 0 else: raise AssertionError( f'unsupported type for graphql introspection: ' f'{param.typename}' ) vals[param.name] = val # Extracted arguments come from the NormalizedSource. query_vars = query_req.source.variables() for name in keys: if name.startswith('__edb_arg_'): vals[name] = query_vars[name] return vals ================================================ FILE: edb/server/protocol/consts.pxi ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2019-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# DEF DUMP_BLOCK_SIZE = 1024 * 1024 * 10 DEF DUMP_HEADER_BLOCK_TYPE = 101 DEF DUMP_HEADER_BLOCK_TYPE_INFO = b'I' DEF DUMP_HEADER_BLOCK_TYPE_DATA = b'D' DEF DUMP_HEADER_SERVER_TIME = 102 DEF DUMP_HEADER_SERVER_VER = 103 DEF DUMP_HEADER_BLOCKS_INFO = 104 DEF DUMP_HEADER_SERVER_CATALOG_VERSION = 105 DEF DUMP_HEADER_BLOCK_ID = 110 DEF DUMP_HEADER_BLOCK_NUM = 111 DEF DUMP_HEADER_BLOCK_DATA = 112 ================================================ FILE: edb/server/protocol/cpythonx.pxd ================================================ # Copyright (C) 2016-present the asyncpg authors and contributors # # # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 cdef extern from "Python.h": object PyLong_FromUnicodeObject( object u, int base) ================================================ FILE: edb/server/protocol/edgeql_ext.pyx ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2019-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import decimal import http import json import urllib.parse import immutables from edb import errors from edb import edgeql from edb.server import defines as edbdef from edb.server.protocol import execute from edb.schema import schema as s_schema from edb.common import debug from edb.common import markup from edb.edgeql import qltypes from edb.server import compiler from edb.server import config from edb.server.compiler import enums from edb.server.dbview cimport dbview from edb.server.pgproto.pgproto cimport WriteBuffer async def handle_request( object request, object response, dbview.Database db, str role_name, list args, object tenant, ): if args != []: response.body = b'Unknown path' response.status = http.HTTPStatus.NOT_FOUND response.close_connection = True return variables = None globals_ = None query = None config = None try: if request.method == b'POST': if request.content_type and b'json' in request.content_type: body = json.loads(request.body, parse_float=decimal.Decimal) if not isinstance(body, dict): raise TypeError( 'the body of the request must be a JSON object') query = body.get('query') variables = body.get('variables') globals_ = body.get('globals') config = body.get('config') else: raise TypeError( 'unable to interpret EdgeQL POST request') elif request.method == b'GET': if request.url.query: url_query = request.url.query.decode('ascii') qs = urllib.parse.parse_qs(url_query) query = qs.get('query') if query is not None: query = query[0] variables = qs.get('variables') if variables is not None: try: variables = json.loads(variables[0]) except Exception: raise TypeError( '"variables" must be a JSON object') globals_ = qs.get('globals') if globals_ is not None: try: globals_ = json.loads(globals_[0]) except Exception: raise TypeError( '"globals" must be a JSON object') config = qs.get('config') if config is not None: try: config = json.loads(config[0]) except Exception: raise TypeError( '"config" must be a JSON object') else: raise 
TypeError('expected a GET or a POST request') if not query: raise TypeError('invalid EdgeQL request: query is missing') if variables is not None and not isinstance(variables, dict): raise TypeError('"variables" must be a JSON object') if globals_ is not None and not isinstance(globals_, dict): raise TypeError('"globals" must be a JSON object') if config is not None and not isinstance(config, dict): raise TypeError('"config" must be a JSON object') except Exception as ex: if debug.flags.server: markup.dump(ex) response.body = str(ex).encode() response.status = http.HTTPStatus.BAD_REQUEST response.close_connection = True return response.status = http.HTTPStatus.OK response.content_type = b'application/json' try: result = await execute.parse_execute_json( db, query, role_name=role_name, variables=variables or {}, globals_=globals_, session_config=config, ) except Exception as ex: if debug.flags.server: markup.dump(ex) ex = await execute.interpret_error(ex, db) response.body = json.dumps({'error': ex.to_json()}).encode() else: response.body = b'{"data":' + result + b'}' ================================================ FILE: edb/server/protocol/execute.pxd ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2024-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from edb.server.pgproto.pgproto cimport WriteBuffer cdef class ExecutionGroup: cdef: object group list bind_datas cdef append(self, object query_unit, WriteBuffer bind_data=?) ================================================ FILE: edb/server/protocol/execute.pyi ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2019-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from typing import ( Any, Mapping, Optional, ) import immutables from edb import errors from edb.server import compiler from edb.server import defines as edbdef from edb.server.compiler import sertypes from edb.server.dbview import dbview async def describe( db: dbview.Database, query: str, *, query_cache_enabled: Optional[bool] = None, allow_capabilities: compiler.Capability = ( compiler.Capability.MODIFICATIONS), query_tag: str | None = None, role_name: str, ) -> sertypes.TypeDesc: ... async def parse_execute_json( db: dbview.Database, query: str, *, variables: Mapping[str, Any] = immutables.Map(), globals_: Optional[Mapping[str, Any]] = None, output_format: compiler.OutputFormat = compiler.OutputFormat.JSON, query_cache_enabled: Optional[bool] = None, cached_globally: bool = False, use_metrics: bool = True, tx_isolation: edbdef.TxIsolationLevel | None = None, query_tag: str | None = None, role_name: str | None = None, ) -> bytes: ... 
async def interpret_error( exc: Exception, db: dbview.Database, *, global_schema_pickle: object=None, user_schema_pickle: object=None, from_graphql: bool=False, ) -> errors.EdgeDBError: ... ================================================ FILE: edb/server/protocol/execute.pyx ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2019-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from typing import ( Any, Mapping, Optional, ) from edgedb import scram import asyncio import base64 import decimal import hashlib import json import logging import immutables from edb import errors from edb.common import debug from edb import edgeql from edb.edgeql import qltypes from edb.pgsql.parser import exceptions as parser_errors from edb.server import compiler from edb.server import config from edb.server import defines as edbdef from edb.server import metrics from edb.server.compiler import dbstate from edb.server.compiler import errormech from edb.server.compiler cimport rpc from edb.server.compiler import sertypes from edb.server.dbview cimport dbview from edb.server.protocol cimport args_ser from edb.server.protocol cimport frontend from edb.server.protocol import ai_ext from edb.server.pgcon cimport pgcon from edb.server.pgcon import errors as pgerror cdef object logger = logging.getLogger('edb.server') cdef object FMT_NONE = compiler.OutputFormat.NONE cdef WriteBuffer 
NO_ARGS = args_ser.combine_raw_args() cdef class ExecutionGroup: def __cinit__(self): self.group = compiler.QueryUnitGroup() self.bind_datas = [] cdef append(self, object query_unit, WriteBuffer bind_data=NO_ARGS): self.group.append(query_unit, serialize=False) self.bind_datas.append(bind_data) async def execute( self, pgcon.PGConnection be_conn, object dbv, # can be DatabaseConnectionView or Database fe_conn: frontend.AbstractFrontendConnection = None, bytes state = None, bint needs_commit_state = False, ): cdef int dbver rv = None async with be_conn.parse_execute_script_context(): dbver = dbv.dbver parse_array = [False] * len(self.group) be_conn.send_query_unit_group( self.group, True, # sync self.bind_datas, state, 0, # start len(self.group), # end dbver, parse_array, None, # query_prefix needs_commit_state, ) if state is not None: await be_conn.wait_for_state_resp( state, state_sync=needs_commit_state, needs_commit_state=needs_commit_state, ) for i, unit in enumerate(self.group): ignore_data = unit.output_format == FMT_NONE rv = await be_conn.wait_for_command( unit, parse_array[i], dbver, ignore_data=ignore_data, fe_conn=None if ignore_data else fe_conn, ) return rv cpdef ExecutionGroup build_cache_persistence_units( pairs: list[tuple[rpc.CompilationRequest, compiler.QueryUnitGroup]], ExecutionGroup group = None, ): if group is None: group = ExecutionGroup() insert_sql = b''' INSERT INTO "edgedb"."_query_cache" ("key", "schema_version", "input", "output", "evict") VALUES ($1, $2, $3, $4, $5) ON CONFLICT (key) DO NOTHING ''' sql_hash = hashlib.sha1(insert_sql).hexdigest().encode('latin1') for request, units in pairs: # FIXME: this is temporary; drop this assertion when we support scripts assert len(units) == 1 query_unit = units[0] assert query_unit.cache_sql is not None persist, evict = query_unit.cache_sql serialized_result = units.maybe_get_serialized(0) assert serialized_result is not None if evict: group.append(compiler.QueryUnit(sql=evict, status=b'')) if 
persist: group.append(compiler.QueryUnit(sql=persist, status=b'')) group.append( compiler.QueryUnit(sql=insert_sql, sql_hash=sql_hash, status=b''), args_ser.combine_raw_args(( query_unit.cache_key.bytes, query_unit.user_schema_version.bytes, request.serialize(), serialized_result, evict, )), ) return group async def describe( db: dbview.Database, query: str, *, query_cache_enabled: Optional[bool] = None, allow_capabilities: compiler.Capability = compiler.Capability.MODIFICATIONS, query_tag: str | None = None, role_name: str, ) -> sertypes.TypeDesc: dbv = await _get_transient_dbv(db, role_name=role_name) _, compiled = await _parse( dbv, query, query_cache_enabled=query_cache_enabled, allow_capabilities=allow_capabilities, ) if query_tag: compiled.tag = query_tag try: desc = sertypes.parse( compiled.query_unit_group.out_type_data, edbdef.CURRENT_PROTOCOL, ) finally: db.tenant.remove_dbview(dbv) return desc async def _get_transient_dbv( db: dbview.Database, *, query_cache_enabled: Optional[bool] = None, role_name: str, ) -> dbview.DatabaseConnectionView: if query_cache_enabled is None: query_cache_enabled = not ( debug.flags.disable_qcache or debug.flags.edgeql_compile) tenant = db.tenant dbv = await tenant.new_dbview( dbname=db.name, query_cache=query_cache_enabled, protocol_version=edbdef.CURRENT_PROTOCOL, role_name=role_name, ) dbv.is_transient = True return dbv async def _parse( dbv: dbview.DatabaseConnectionView, query: str, *, input_format: compiler.InputFormat = compiler.InputFormat.BINARY, output_format: compiler.OutputFormat = compiler.OutputFormat.BINARY, allow_capabilities: compiler.Capability = compiler.Capability.MODIFICATIONS, use_metrics: bool = True, cached_globally: bool = False, query_cache_enabled: Optional[bool] = None, ) -> tuple[ rpc.CompilationRequest, dbview.CompiledQuery, ]: db = dbv._db tenant = db.tenant if use_metrics: metrics.query_size.observe( len(query.encode('utf-8')), tenant.get_instance_name(), 'edgeql' ) query_req = 
rpc.CompilationRequest( source=edgeql.Source.from_string(query), protocol_version=edbdef.CURRENT_PROTOCOL, schema_version=dbv.schema_version, compilation_config_serializer=db.server.compilation_config_serializer, input_format=input_format, output_format=output_format, session_config=dbv.get_session_config(), database_config=dbv.get_database_config(), system_config=dbv.get_compilation_system_config(), role_name=dbv._role_name, ) compiled = await dbv.parse( query_req, cached_globally=cached_globally, use_metrics=use_metrics, allow_capabilities=allow_capabilities, ) return query_req, compiled # TODO: can we merge execute and execute_script? async def execute( be_conn: pgcon.PGConnection, dbv: dbview.DatabaseConnectionView, compiled: dbview.CompiledQuery, bind_args: bytes, *, fe_conn: frontend.AbstractFrontendConnection = None, use_prep_stmt: bint = False, tx_isolation: edbdef.TxIsolationLevel | None = None, query_req: Optional[rpc.CompilationRequest] = None, ): cdef: bytes state = None, orig_state = None WriteBuffer bound_args_buf bint needs_commit_state = False query_unit = compiled.query_unit_group[0] if not dbv.in_tx(): orig_state = state = dbv.serialize_state() needs_commit_state = dbv.needs_commit_after_state_sync() new_types = None server = dbv.server tenant = dbv.tenant data = None try: if be_conn.last_state == state: # the current status in be_conn is in sync with dbview, skip the # state restoring state = None dbv.start(query_unit) if query_unit.create_db_template: await tenant.on_before_create_db_from_template( query_unit.create_db_template, dbv.dbname, query_unit.create_db_mode, ) if query_unit.drop_db: await tenant.on_before_drop_db( query_unit.drop_db, dbv.dbname, close_frontend_conns=query_unit.drop_db_reset_connections, ) if query_unit.early_non_tx_sql: # Sync state non transactionally await be_conn.sql_fetch(b'select 1', state=state) for sql in query_unit.early_non_tx_sql: await be_conn.sql_execute(sql) if query_unit.system_config: # 
execute_system_config() always sync state in a separate tx, # so we don't need to pass down the needs_commit_state here await execute_system_config(be_conn, dbv, query_unit, state) else: config_ops = query_unit.config_ops if query_unit.sql: if query_unit.user_schema: await be_conn.parse_execute( query=query_unit, state=state, needs_commit_state=needs_commit_state, ) if query_unit.ddl_stmt_id is not None: ddl_ret = be_conn.load_last_ddl_return(query_unit) if ddl_ret and ddl_ret['new_types']: new_types = ddl_ret['new_types'] else: converted_args: Optional[list[args_ser.ConvertedArg]] = None if query_unit.server_param_conversions: converted_args = (await _convert_parameters( dbv, compiled, query_unit.server_param_conversions, bind_args, )).get(0, None) data_types = [] bound_args_buf = args_ser.recode_bind_args( dbv, compiled, bind_args, converted_args, None, data_types, ) assert not (query_unit.database_config and query_unit.needs_readback), ( "needs_readback+database_config must use execute_script" ) read_data = ( query_unit.needs_readback or query_unit.is_explain) data = await be_conn.parse_execute( query=query_unit, fe_conn=fe_conn if not read_data else None, bind_data=bound_args_buf, param_data_types=data_types, use_prep_stmt=use_prep_stmt, state=state, needs_commit_state=needs_commit_state, dbver=dbv.dbver, use_pending_func_cache=compiled.use_pending_func_cache, tx_isolation=tx_isolation, query_prefix=compiled.make_query_prefix(), ) if query_unit.needs_readback and data: config_ops = [ config.Operation.from_json(r[0][1:]) for r in data ] if query_unit.is_explain: # Go back to the compiler pool to analyze # the explain output. 
compiler_pool = server.get_compiler_pool() r = await compiler_pool.analyze_explain_output( query_unit.query_asts, data ) buf = WriteBuffer.new_message(b'D') buf.write_int16(1) # 1 column buf.write_len_prefixed_bytes(r) fe_conn.write(buf.end_message()) if state is not None: # state is restored, clear orig_state so that we can # set be_conn.last_state correctly later orig_state = None if query_unit.tx_savepoint_rollback: dbv.rollback_tx_to_savepoint(query_unit.sp_name) if query_unit.tx_savepoint_declare: dbv.declare_savepoint( query_unit.sp_name, query_unit.sp_id) if query_unit.create_db_template: try: await tenant.on_after_create_db_from_template( query_unit.create_db, query_unit.create_db_template, query_unit.create_db_mode, ) except Exception: # Clean up the database if we failed to restore into it. # TODO: Is it worth having 'ready' flag that we set after # the database is fully set up, and use that to clean up # databases where a crash prevented doing this cleanup? db_name = f'{tenant.tenant_id}_{query_unit.create_db}' await be_conn.sql_execute( b'drop database "%s"' % db_name.encode('utf-8') ) raise if query_unit.create_db: await tenant.introspect_db(query_unit.create_db) if query_unit.drop_db: tenant.on_after_drop_db(query_unit.drop_db) if config_ops: await dbv.apply_config_ops(be_conn, config_ops) if query_unit.user_schema and debug.flags.delta_validate_reflection: global_schema = ( query_unit.global_schema or dbv.get_global_schema_pickle()) new_user_schema = await dbv.tenant._debug_introspect( be_conn, global_schema) compiler_pool = dbv.server.get_compiler_pool() await compiler_pool.validate_schema_equivalence( query_unit.user_schema, new_user_schema, global_schema, dbv._last_comp_state, ) query_unit.user_schema = new_user_schema except Exception as ex: if isinstance(ex, pgerror.BackendError): # If we made schema changes, include the new schema in the # exception so that it can be used when interpreting. 
if query_unit.user_schema: ex._user_schema = query_unit.user_schema # If we get an undefined function error, this is probably # because of a pgfunc cache invalidation race condition, # where another frontend dropped the function but we # haven't processed the message yet. We are going to # trigger a client retry (via errormech), but we also want # to invalidate the cache entry, in case we haven't # processed the message by the retry. if ( query_req and ex.code_is(pgerror.ERROR_UNDEFINED_FUNCTION) ): dbv._db.invalidate_cache_entry_object(query_req) if query_unit.source_map: ex._from_sql = True dbv.on_error() if query_unit.tx_commit and not be_conn.in_tx() and dbv.in_tx(): # The COMMIT command has failed. Our Postgres connection # isn't in a transaction anymore. Abort the transaction # in dbview. dbv.abort_tx() raise else: side_effects = dbv.on_success(query_unit, new_types) state_serializer = compiled.query_unit_group.state_serializer if state_serializer is not None: dbv.set_state_serializer(state_serializer) if side_effects: await process_side_effects(dbv, side_effects, be_conn) if not dbv.in_tx() and not query_unit.tx_rollback and query_unit.sql: state = dbv.serialize_state() if state is not orig_state: # In 3 cases the state is changed: # 1. The non-tx query changed the state # 2. The state is synced with dbview (orig_state is None) # 3. We came out from a transaction (orig_state is None) # Excluding two special cases when the state is NOT changed: # 1. An orphan ROLLBACK command without a pairing start tx # 2. There was no SQL, so the state can't have been synced.
be_conn.last_state = state be_conn.state_reset_needs_commit = ( dbv.needs_commit_after_state_sync()) if compiled.recompiled_cache: for req, qu_group in compiled.recompiled_cache: dbv.cache_compiled_query(req, qu_group) finally: if query_unit.drop_db: tenant.allow_database_connections(query_unit.drop_db) return data async def _convert_parameters( dbv: dbview.DatabaseConnectionView, compiled: dbview.CompiledQuery, server_param_conversions: list[dbstate.ServerParamConversion], bind_args: bytes, ) -> dict[int, list[args_ser.ConvertedArg]]: """ If there are server param conversions, compute them now so that they are injected into the recoded bind args later. """ param_conversions: list[args_ser.ParamConversion] = ( args_ser.get_param_conversions( dbv, server_param_conversions, bind_args, compiled.extra_blobs, ) ) # Cache converted args which may be used in multiple units converted_args_cache: list[Optional[args_ser.ConvertedArg]] = ( [None] * len(param_conversions) ) unit_group = compiled.query_unit_group # First check for conversions which should be done in batches ai_text_embedding_conversion_indexes: list[int] = [] ai_text_embedding_conversions: list[args_ser.ParamConversion] = [] for unit_index, converted_params_indexes in ( unit_group.unit_converted_param_indexes.items() ): for conversion_index in converted_params_indexes: conversion = param_conversions[conversion_index] conversion_name: str = conversion.get_conversion_name() if conversion_name == 'ai_text_embedding': ai_text_embedding_conversion_indexes.append(conversion_index) ai_text_embedding_conversions.append(conversion) # Compute batched conversions and store them in cache if ai_text_embedding_conversions: converted_args = ( await _batch_convert_ai_text_embedding( dbv, ai_text_embedding_conversions ) ) for conversion_index, converted_arg in zip( ai_text_embedding_conversion_indexes, converted_args, ): converted_args_cache[conversion_index] = converted_arg # Do the remaining conversions converted_args: 
dict[int, list[args_ser.ConvertedArg]] = {} for unit_index, converted_params_indexes in ( unit_group.unit_converted_param_indexes.items() ): unit_converted_args: list[args_ser.ConvertedArg] = [] for conversion_index in converted_params_indexes: # Check for a cached conversion arg if converted_arg := converted_args_cache[conversion_index]: unit_converted_args.append(converted_arg) continue # Do the conversion converted_arg = await _convert_parameter( param_conversions[conversion_index] ) unit_converted_args.append(converted_arg) converted_args_cache[conversion_index] = converted_arg if unit_converted_args: converted_args[unit_index] = unit_converted_args return converted_args async def _convert_parameter( conversion: args_ser.ParamConversion, ) -> args_ser.ConvertedArg: conversion_name = conversion.get_conversion_name() # We receive the encoded param data from the bind_args or extra blobs # and decode it manually. if ( conversion_name == 'cast_int64_to_str' or conversion_name == 'cast_int64_to_str_volatile' ): decoded_param_data = conversion.param_as_int() return args_ser.ConvertedArgStr.new( str(decoded_param_data) ) elif conversion_name == 'cast_int64_to_float64': decoded_param_data = conversion.param_as_int() return args_ser.ConvertedArgFloat64.new( float(decoded_param_data) ) elif conversion_name == 'join_str_array': decoded_param_data = conversion.param_as_array_of_str() separator = conversion.get_additional_info()[0] return args_ser.ConvertedArgStr.new( separator.join(decoded_param_data) ) elif conversion_name == 'ai_text_embedding': raise RuntimeError(f'conversion should be batched: {conversion_name}') else: raise errors.QueryError( f'unknown param conversion: {conversion_name}' ) async def _batch_convert_ai_text_embedding( dbv: dbview.DatabaseConnectionView, conversions: list[args_ser.ParamConversion], ) -> list[args_ser.ConvertedArg]: embeddings_inputs: list[tuple[str, str]] = [ ( conversion_data.get_additional_info()[0], conversion_data.param_as_str(),
) for conversion_data in conversions ] tenant = dbv.tenant db = tenant.maybe_get_db(dbname=dbv.dbname) assert db is not None embeddings_result = await ai_ext.generate_embeddings_for_texts( db, tenant.get_http_client(originator="ai/index"), embeddings_inputs, ) if embeddings_result.too_long: long_input = embeddings_inputs[embeddings_result.too_long[0]][1][:100] raise errors.QueryError( f'Search text exceeds maximum input token length: {long_input}...' ) if not embeddings_result.success: raise RuntimeError('failed to get embeddings') return [ args_ser.ConvertedArgListFloat32.new( embeddings ) for embeddings in embeddings_result.success ] async def execute_script( conn: pgcon.PGConnection, dbv: dbview.DatabaseConnectionView, compiled: dbview.CompiledQuery, bind_args: bytes, *, query_req: Optional[rpc.CompilationRequest] = None, fe_conn: Optional[frontend.AbstractFrontendConnection], ): cdef: bytes state = None, orig_state = None ssize_t sent = 0 bint in_tx, sync, no_sync object user_schema, extensions, ext_config_settings, cached_reflection object global_schema, roles WriteBuffer bind_data int dbver = dbv.dbver bint parse, needs_commit_state = False user_schema = extensions = ext_config_settings = cached_reflection = None feature_used_metrics = None global_schema = roles = None unit_group = compiled.query_unit_group query_prefix = compiled.make_query_prefix() query_unit = None sync = False no_sync = False in_tx = dbv.in_tx() if not in_tx: orig_state = state = dbv.serialize_state() needs_commit_state = dbv.needs_commit_after_state_sync() data = None try: if conn.last_state == state: # the current status in be_conn is in sync with dbview, skip the # state restoring state = None async with conn.parse_execute_script_context(): converted_args: Optional[dict[int, list[args_ser.ConvertedArg]]] = None if unit_group.server_param_conversions: converted_args = await _convert_parameters( dbv, compiled, unit_group.server_param_conversions, bind_args, ) parse_array = [False] * 
len(unit_group) for idx, query_unit in enumerate(unit_group): if fe_conn is not None and fe_conn.cancelled: raise ConnectionAbortedError assert not query_unit.is_explain # XXX: pull out? # We want to minimize the round trips we need to make, so # ideally we buffer up everything, send it once, and then issue # one SYNC. This gets messed up if there are commands where # we need to read back information, though, such as SET GLOBAL. # # Because of that, we look for the next command that # needs read back (probably there won't be one!), and # execute everything up to that point at once, # finished by a FLUSH. if idx >= sent: no_sync = False for n in range(idx, len(unit_group)): ng = unit_group[n] if ng.ddl_stmt_id or ng.needs_readback: sent = n + 1 if ng.needs_readback: no_sync = True break else: sent = len(unit_group) sync = sent == len(unit_group) and not no_sync bind_array = args_ser.recode_bind_args_for_script( dbv, compiled, bind_args, converted_args, idx, sent, ) dbver = dbv.dbver conn.send_query_unit_group( unit_group, sync, bind_array, state, idx, sent, dbver, parse_array, query_prefix, needs_commit_state, ) if idx == 0 and state is not None: await conn.wait_for_state_resp( state, state_sync=needs_commit_state, needs_commit_state=needs_commit_state, ) conn.state_reset_needs_commit = needs_commit_state # state is restored, clear orig_state so that we can # set conn.last_state correctly later orig_state = None new_types = None dbv.start_implicit(query_unit) config_ops = query_unit.config_ops if query_unit.user_schema: user_schema = query_unit.user_schema extensions = query_unit.extensions ext_config_settings = query_unit.ext_config_settings cached_reflection = query_unit.cached_reflection feature_used_metrics = query_unit.feature_used_metrics if query_unit.global_schema: global_schema = query_unit.global_schema roles = query_unit.roles if query_unit.sql: parse = parse_array[idx] fe_output = query_unit.output_format != FMT_NONE ignore_data = ( not fe_output and not 
query_unit.needs_readback ) data = await conn.wait_for_command( query_unit, parse, dbver, ignore_data=ignore_data, fe_conn=fe_conn if fe_output else None, ) if query_unit.ddl_stmt_id: ddl_ret = conn.load_last_ddl_return(query_unit) if ddl_ret and ddl_ret['new_types']: new_types = ddl_ret['new_types'] if query_unit.needs_readback and data: config_ops = [ config.Operation.from_json(r[0][1:]) for r in data ] if config_ops: await dbv.apply_config_ops(conn, config_ops) side_effects = dbv.on_success(query_unit, new_types) if side_effects: raise errors.InternalServerError( "Side-effects in implicit transaction!" ) # Need to sync before calling process_side_effects, which will # look at the database. Also, want to sync before we record success, # since sync could fail. if sent and not sync: sync = True await conn.sync() except Exception as e: dbv.on_error() if isinstance(e, pgerror.BackendError): # Include the new schema in the exception so that it can be # used when interpreting. e._user_schema = dbv.get_user_schema_pickle() # If we get an undefined function error, this is probably # because of a pgfunc cache invalidation race condition, # where another frontend dropped the function but we # haven't processed the message yet. We are going to # trigger a client retry (via errormech), but we also want # to invalidate the cache entry, in case we haven't # processed the message by the retry. if ( query_req and e.code_is(pgerror.ERROR_UNDEFINED_FUNCTION) ): dbv._db.invalidate_cache_entry_object(query_req) if query_unit and query_unit.source_map: e._from_sql = True if not in_tx and dbv.in_tx(): # Abort the implicit transaction dbv.abort_tx() # If something went wrong that is *not* on the backend side, force # an error to occur on the SQL side.
if not isinstance(e, pgerror.BackendError): await conn.force_error() raise else: updated_user_schema = False if user_schema and debug.flags.delta_validate_reflection: cur_global_schema = ( global_schema or dbv.get_global_schema_pickle()) new_user_schema = await dbv.tenant._debug_introspect( conn, cur_global_schema) compiler_pool = dbv.server.get_compiler_pool() await compiler_pool.validate_schema_equivalence( user_schema, new_user_schema, cur_global_schema, dbv._last_comp_state, ) user_schema = new_user_schema updated_user_schema = True if not in_tx: side_effects = dbv.commit_implicit_tx( user_schema, extensions, ext_config_settings, global_schema, roles, cached_reflection, feature_used_metrics, ) if side_effects: await process_side_effects(dbv, side_effects, conn) state = dbv.serialize_state() if state is not orig_state: conn.last_state = state conn.state_reset_needs_commit = ( dbv.needs_commit_after_state_sync()) elif updated_user_schema: dbv._in_tx_user_schema_pickle = user_schema if unit_group.state_serializer is not None: dbv.set_state_serializer(unit_group.state_serializer) finally: if sent and not sync: await conn.sync() return data async def execute_system_config( conn: pgcon.PGConnection, dbv: dbview.DatabaseConnectionView, query_unit: compiler.QueryUnit, state: bytes | None, ): if query_unit.is_system_config: dbv.server.before_alter_system_config() # Sync state await conn.sql_fetch(b'select 1', state=state) if query_unit.sql: data = await conn.sql_fetch_col(query_unit.sql) else: data = None if data: # Prefer encoded op produced by the SQL command. if data[0][0] != 0x01: raise errors.InternalServerError( f"unexpected JSONB version produced by SQL statement for " f"CONFIGURE INSTANCE: {data[0][0]}" ) config_ops = [config.Operation.from_json(r[1:]) for r in data] else: # Otherwise, fall back to statically evaluated op.
config_ops = query_unit.config_ops await dbv.apply_config_ops(conn, config_ops) await conn.sql_execute(b'delete from _config_cache') # If this is a backend configuration setting we also # need to make sure it has been loaded. if query_unit.backend_config: await conn.sql_execute(b'SELECT pg_reload_conf()') async def process_side_effects(dbv, side_effects, conn): signal_side_effects(dbv, side_effects) if side_effects & dbview.SideEffects.DatabaseConfigChanges: tenant = dbv.tenant await tenant.process_local_database_config_change(conn, dbv.dbname) def signal_side_effects(dbv, side_effects): tenant = dbv.tenant if not tenant.accept_new_tasks: return if side_effects & dbview.SideEffects.SchemaChanges: tenant.create_task( tenant.signal_sysevent( 'schema-changes', dbname=dbv.dbname, ), interruptable=False, ) if side_effects & dbview.SideEffects.GlobalSchemaChanges: tenant.create_task( tenant.signal_sysevent( 'global-schema-changes', ), interruptable=False, ) if side_effects & dbview.SideEffects.DatabaseConfigChanges: tenant.create_task( tenant.signal_sysevent( 'database-config-changes', dbname=dbv.dbname, ), interruptable=False, ) if side_effects & dbview.SideEffects.DatabaseChanges: tenant.create_task( tenant.signal_sysevent( 'database-changes', ), interruptable=False, ) if side_effects & dbview.SideEffects.InstanceConfigChanges: tenant.create_task( tenant.signal_sysevent( 'system-config-changes', ), interruptable=False, ) async def parse_execute_json( db: dbview.Database, query: str, *, variables: Mapping[str, Any] = immutables.Map(), globals_: Optional[Mapping[str, Any]] = None, session_config: Optional[Mapping[str, Any]] = None, output_format: compiler.OutputFormat = compiler.OutputFormat.JSON, query_cache_enabled: Optional[bool] = None, # WARNING: only set cached_globally to True when the query is # strictly referring to only shared stable objects in user schema # or anything from std schema, for example: # YES: select ext::auth::UIConfig { ... 
} # NO: select default::User { ... } cached_globally: bool = False, use_metrics: bool = True, tx_isolation: edbdef.TxIsolationLevel | None = None, query_tag: str | None = None, role_name: str | None = None, ) -> bytes: if role_name is None: role_name = edbdef.EDGEDB_SUPERUSER dbv: dbview.DatabaseConnectionView = await _get_transient_dbv( db, query_cache_enabled=query_cache_enabled, role_name=role_name, ) dbv.decode_json_session_config(session_config) query_req, compiled = await _parse( dbv, query, input_format=compiler.InputFormat.JSON, output_format=output_format, allow_capabilities=compiler.Capability.MODIFICATIONS, use_metrics=use_metrics, cached_globally=cached_globally, ) if query_tag: compiled.tag = query_tag tenant = db.tenant async with tenant.with_pgcon(db.name) as pgcon: try: return await execute_json( pgcon, dbv, compiled, variables=variables, globals_=globals_, tx_isolation=tx_isolation, query_req=query_req, ) finally: tenant.remove_dbview(dbv) async def execute_json( be_conn: pgcon.PGConnection, dbv: dbview.DatabaseConnectionView, compiled: dbview.CompiledQuery, variables: Mapping[str, Any] = immutables.Map(), globals_: Optional[Mapping[str, Any]] = None, *, fe_conn: Optional[frontend.AbstractFrontendConnection] = None, use_prep_stmt: bint = False, tx_isolation: edbdef.TxIsolationLevel | None = None, query_req: Optional[rpc.CompilationRequest] = None, ) -> bytes: if globals_ is None: globals_ = {} if compiled.query_unit_group.json_permissions: # Inject any required permissions into the globals json. superuser, available_permissions = dbv.get_permissions() for permission in compiled.query_unit_group.json_permissions: if permission in globals_: raise RuntimeError( f"Permission cannot be passed as globals: '{permission}'" ) globals_[permission] = ( superuser or permission in available_permissions ) # TODO: only when needed? in a less dodgy way?? 
for k, v in dbv._sys_globals.items(): if k in globals_: raise RuntimeError( f"System global '{k}' cannot be explicitly specified" ) globals_[k] = v dbv.set_globals(immutables.Map({ "__::__edb_json_globals__": config.SettingValue( name="__::__edb_json_globals__", value=_encode_json_value(globals_), source='global', scope=qltypes.ConfigScope.GLOBAL, ) })) qug = compiled.query_unit_group args = [] if qug.in_type_args: for param in qug.in_type_args: value = variables.get(param.name) args.append(value) bind_args = _encode_args(args) force_script = any(x.needs_readback for x in qug) if len(qug) > 1 or force_script: if tx_isolation is not None: raise errors.InternalServerError( "execute_script does not support " "modified transaction isolation" ) data = await execute_script( be_conn, dbv, compiled, bind_args, fe_conn=fe_conn, query_req=query_req, ) else: if tx_isolation is not None: if dbv.in_tx(): raise errors.InternalServerError( "cannot run statement with alternate transaction " "isolation: already in a transaction" ) query_unit = compiled.query_unit_group[0] if not query_unit.is_transactional: raise errors.InternalServerError( "cannot run statement with alternate transaction " "isolation: statement is not transactional" ) data = await execute( be_conn, dbv, compiled, bind_args, fe_conn=fe_conn, tx_isolation=tx_isolation, query_req=query_req, ) if fe_conn is None: if not data or len(data) > 1 or len(data[0]) != 1: raise errors.InternalServerError( f'received incorrect response data for a JSON query') return data[0][0] else: return None class DecimalEncoder(json.JSONEncoder): def encode(self, obj): if isinstance(obj, dict): return '{' + ', '.join( f'{self.encode(k)}: {self.encode(v)}' for (k, v) in obj.items() ) + '}' if isinstance(obj, list): return '[' + ', '.join(map(self.encode, obj)) + ']' if isinstance(obj, bytes): return self.encode(base64.b64encode(obj).decode()) if isinstance(obj, decimal.Decimal): return f'{obj:f}' return super().encode(obj) cdef bytes 
_encode_json_value(object val): jarg = json.dumps(val, cls=DecimalEncoder) return b'\x01' + jarg.encode('utf-8') cdef bytes _encode_args(list args): cdef: WriteBuffer out_buf = WriteBuffer.new() if args: out_buf.write_int32(len(args)) for arg in args: out_buf.write_int32(0) # reserved if arg is None: out_buf.write_int32(-1) else: jval = _encode_json_value(arg) out_buf.write_int32(len(jval)) out_buf.write_bytes(jval) return bytes(out_buf) cdef _check_for_ise(exc): # Unwrap ExceptionGroup that has only one Exception if isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1: exc = exc.exceptions[0] if not isinstance(exc, errors.EdgeDBError): # TODO(rename): change URL once we can nexc = errors.InternalServerError( f'{type(exc).__name__}: {exc}', hint=( f'This is most likely a bug in Gel. ' f'Please consider opening an issue ticket ' f'at https://github.com/edgedb/edgedb/issues/new' f'?template=bug_report.md' ), ).with_traceback(exc.__traceback__) formatted = getattr(exc, '__formatted_error__', None) if formatted: nexc.__formatted_error__ = formatted if isinstance(exc, BaseExceptionGroup): nexc.__cause__ = exc.with_traceback(None) exc = nexc return exc async def interpret_error( exc: Exception, db: dbview.Database, *, global_schema_pickle: object=None, user_schema_pickle: object=None, from_graphql: bool=False, ) -> Exception: if isinstance(exc, RecursionError): exc = errors.UnsupportedFeatureError( "The query caused the compiler " "stack to overflow. It is likely too deeply nested.", hint=( "If the query does not contain deep nesting, " "this may be a bug." ), ) elif isinstance(exc, pgerror.BackendError): try: from_sql = getattr(exc, '_from_sql', False) source_map = getattr(exc, '_source_map', None) fields = exc.fields static_exc = errormech.static_interpret_backend_error( fields, from_graphql=from_graphql ) # only use the backend if schema is required if static_exc is errormech.SchemaRequired: # Grab the schema from the exception first, if it is present. 
user_schema_pickle = ( getattr(exc, '_user_schema', None) or user_schema_pickle or db.user_schema_pickle ) global_schema_pickle = ( global_schema_pickle or db._index._global_schema_pickle ) compiler_pool = db._index._server.get_compiler_pool() exc = await compiler_pool.interpret_backend_error( user_schema_pickle, global_schema_pickle, fields, from_graphql, ) elif isinstance(static_exc, ( errors.DuplicateDatabaseDefinitionError, errors.UnknownDatabaseError)): tenant_id = db.tenant.tenant_id message = static_exc.args[0].replace(f'{tenant_id}_', '') exc = type(static_exc)(message) else: exc = static_exc if from_sql and isinstance(exc, errors.InternalServerError): exc = errors.ExecutionError(*exc.args) # Translate error position for SQL queries if we can if source_map and isinstance(exc, errors.EdgeDBError): if 'P' in fields: errors.EdgeDBError.set_position( exc, source_map.translate(int(fields['P'])), None, ) # Include hint/detail from SQL queries also, if we haven't # produced our own. if from_sql and isinstance(exc, errors.EdgeDBError): if 'H' in fields or 'D' in fields: hint = exc.hint or fields.get('H') details = exc.details or fields.get('D') # ... there is some sort of cython bug/"feature" # involving the type annotation above which causes # exc.set_hint_and_details to fail, so we copy it # to a new variable. 
exc2: object = exc exc2.set_hint_and_details(hint, details) except Exception as e: from edb.common import debug if debug.flags.server: debug.dump(e) exc = RuntimeError( 'unhandled error while calling interpret_backend_error(); ' 'run with EDGEDB_DEBUG_SERVER to debug.') elif isinstance(exc, parser_errors.PSqlParseError): exc = errormech.static_interpret_psql_parse_error(exc) return _check_for_ise(exc) def interpret_simple_error( exc: Exception, ) -> Exception: """Interpret a protocol error not associated with a query or schema""" if isinstance(exc, pgerror.BackendError): static_exc = errormech.static_interpret_backend_error(exc.fields) if static_exc is not errormech.SchemaRequired: exc = static_exc return _check_for_ise(exc) ================================================ FILE: edb/server/protocol/frontend.pxd ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
# from edb.server.dbview cimport dbview from edb.server.pgcon cimport pgcon from edb.server.pgproto.pgproto cimport ReadBuffer, WriteBuffer cdef class AbstractFrontendConnection: cdef write(self, WriteBuffer buf) cdef flush(self) cdef class FrontendConnection(AbstractFrontendConnection): cdef: str _id object server readonly object tenant object loop readonly str dbname str username dbview.Database database pgcon.PGConnection _pinned_pgcon bint _pinned_pgcon_in_tx int _get_pgcon_cc object _transport WriteBuffer _write_buf object _write_waiter object connection_made_at int _query_count ReadBuffer buffer object _msg_take_waiter object started_idling_at bint idling bint _passive_mode bint authed object _main_task bint _cancelled bint _stop_requested bint _pgcon_released_in_connection_lost bint debug object _transport_proto bint _external_auth cdef _after_idling(self) cdef _main_task_created(self) cdef _main_task_stopped_normally(self) cdef write_error(self, exc) cdef stop_connection(self) cdef abort_pinned_pgcon(self) cdef is_in_tx(self) cdef WriteBuffer _make_authentication_sasl_initial(self, list methods) cdef _expect_sasl_initial_response(self) cdef WriteBuffer _make_authentication_sasl_msg( self, bytes data, bint final) cdef bytes _expect_sasl_response(self) ================================================ FILE: edb/server/protocol/frontend.pyx ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # import asyncio import contextlib import logging import time from edgedb import scram from edb import errors from edb.common import debug from edb.server import defines from edb.server import args as srvargs, metrics from edb.server.pgcon import errors as pgerror from . cimport auth_helpers DEF FLUSH_BUFFER_AFTER = 100_000 cdef object logger = logging.getLogger('edb.server') cdef class AbstractFrontendConnection: cdef write(self, WriteBuffer buf): raise NotImplementedError cdef flush(self): raise NotImplementedError cdef class FrontendConnection(AbstractFrontendConnection): interface = "frontend" def __init__( self, server, tenant, *, passive: bool, transport: srvargs.ServerConnTransport, external_auth: bool, connection_made_at: float | None = None, ): self._id = server.on_binary_client_created() self.server = server self.tenant = tenant self.loop = server.get_loop() self.dbname = None self._pinned_pgcon = None self._pinned_pgcon_in_tx = False self._get_pgcon_cc = 0 self.connection_made_at = connection_made_at self._query_count = 0 self._transport = None self._write_buf = None self._write_waiter = None self.buffer = ReadBuffer() self._msg_take_waiter = None self.idling = False self.started_idling_at = 0.0 # In "passive" mode the protocol is instantiated to parse and execute # just what's in the buffer. It cannot "wait for message". This # is used to implement binary protocol over http+fetch. 
self._passive_mode = passive self.authed = False self._main_task = None self._cancelled = False self._stop_requested = False self._pgcon_released_in_connection_lost = False self.debug = debug.flags.server_proto self._transport_proto = transport self._external_auth = external_auth def get_id(self): return self._id cdef is_in_tx(self): return False # backend connection def __del__(self): # Should not ever happen, there's a strong ref to # every client connection until it hits connection_lost(). if self._pinned_pgcon is not None: # XXX/TODO: add test diagnostics for this and # fail all tests if this ever happens. self.abort_pinned_pgcon() async def get_pgcon(self) -> pgcon.PGConnection: if self._cancelled or self._pgcon_released_in_connection_lost: raise RuntimeError( 'cannot acquire a pgconn; the connection is closed') self._get_pgcon_cc += 1 try: if self._get_pgcon_cc > 1: raise RuntimeError('nested get_pgcon() calls are prohibited') if self.is_in_tx(): # In transaction. We must have a working pinned connection. if not self._pinned_pgcon_in_tx or self._pinned_pgcon is None: raise RuntimeError( 'get_pgcon(): in dbview transaction, ' 'but `_pinned_pgcon` is None') return self._pinned_pgcon if self._pinned_pgcon is not None: raise RuntimeError('there is already a pinned pgcon') conn = await self.tenant.acquire_pgcon(self.dbname) self._pinned_pgcon = conn conn.pinned_by = self return conn except Exception: self._get_pgcon_cc -= 1 raise def maybe_release_pgcon(self, pgcon.PGConnection conn): self._get_pgcon_cc -= 1 if self._get_pgcon_cc < 0: raise RuntimeError( 'maybe_release_pgcon() called more times than get_pgcon()') if self._pinned_pgcon is not conn: raise RuntimeError('mismatched released connection') if self.is_in_tx(): if self._cancelled: # There could be a situation where we cancel the protocol while # it's in a transaction. In which case we want to immediately # return the connection to the pool (where it would be # discarded and re-opened.) 
conn.pinned_by = None self._pinned_pgcon = None if not self._pgcon_released_in_connection_lost: self.tenant.release_pgcon( self.dbname, conn, discard=debug.flags.server_clobber_pg_conns, ) else: self._pinned_pgcon_in_tx = True else: conn.pinned_by = None self._pinned_pgcon_in_tx = False self._pinned_pgcon = None if not self._pgcon_released_in_connection_lost: self.tenant.release_pgcon( self.dbname, conn, discard=debug.flags.server_clobber_pg_conns, ) @contextlib.asynccontextmanager async def with_pgcon(self): con = await self.get_pgcon() try: yield con finally: self.maybe_release_pgcon(con) def on_aborted_pgcon(self, pgcon.PGConnection conn): try: self._pinned_pgcon = None if not self._pgcon_released_in_connection_lost: self.tenant.release_pgcon(self.dbname, conn, discard=True) if conn.aborted_with_error is not None: self.write_error(conn.aborted_with_error) finally: self.close() # will flush cdef abort_pinned_pgcon(self): if self._pinned_pgcon is not None: self._pinned_pgcon.pinned_by = None self._pinned_pgcon.abort() self.tenant.release_pgcon( self.dbname, self._pinned_pgcon, discard=True) self._pinned_pgcon = None # I/O write methods, implements AbstractFrontendConnection cdef write(self, WriteBuffer buf): # One rule for this method: don't write partial messages. if self._write_buf is not None: self._write_buf.write_buffer(buf) if self._write_buf.len() >= FLUSH_BUFFER_AFTER: self.flush() else: self._write_buf = buf cdef flush(self): if self._transport is None: # could be if the connection is lost and a coroutine # method is finalizing. 
raise ConnectionAbortedError if self._write_buf is not None and self._write_buf.len(): buf = self._write_buf self._write_buf = None self._transport.write(memoryview(buf)) def pause_writing(self): if self._write_waiter and not self._write_waiter.done(): return self._write_waiter = self.loop.create_future() def resume_writing(self): if not self._write_waiter or self._write_waiter.done(): return self._write_waiter.set_result(True) # I/O read methods def data_received(self, data): self.buffer.feed_data(data) if self._msg_take_waiter is not None and self.buffer.take_message(): self._msg_take_waiter.set_result(True) self._msg_take_waiter = None def eof_received(self): pass cdef _after_idling(self): # Hook for EdgeConnection pass async def wait_for_message(self, *, bint report_idling): if self.buffer.take_message(): return if self._passive_mode: raise RuntimeError('cannot wait for more messages in passive mode') if self._transport is None: # could be if the connection is lost and a coroutine # method is finalizing. raise ConnectionAbortedError self._msg_take_waiter = self.loop.create_future() if report_idling: self.idling = True self.started_idling_at = time.monotonic() try: await self._msg_take_waiter finally: self.idling = False self._after_idling() def is_idle(self, expiry_time: float): # A connection is idle if it has been waiting for the next message from # the client for too long (even if it is in an open transaction!)
return self.idling and self.started_idling_at < expiry_time # establishing a new connection cdef _main_task_created(self): pass cdef _main_task_stopped_normally(self): pass def get_tenant_label(self): if self.tenant is None: return "unknown" else: return self.tenant.get_instance_name() def connection_made(self, transport): if self.tenant is None: self._transport = transport self._main_task = self.loop.create_task(self.handshake()) self._main_task_created() elif self.tenant.is_accepting_connections(): self._transport = transport self._main_task = self.tenant.create_task( self.main(), interruptable=False ) self._main_task_created() else: transport.abort() async def handshake(self): try: await self._handshake() except Exception as ex: if self._transport is not None: # If there's no transport it means that the connection # was aborted, in which case we don't really care about # reporting the exception. self.write_error(ex) self.close() if not isinstance(ex, (errors.ProtocolError, errors.AuthenticationError)): self.loop.call_exception_handler({ 'message': ( f'unhandled error in {self.__class__.__name__} while ' 'accepting new connection' ), 'exception': ex, 'protocol': self, 'transport': self._transport, 'task': self._main_task, }) async def _handshake(self): if self.tenant is None: self.tenant = self.server.get_default_tenant() if self.tenant.is_accepting_connections(): self._main_task = self.tenant.create_task( self.main(), interruptable=False ) else: if self._transport is not None: self._transport.abort() # main skeleton async def main_step(self, char mtype): raise NotImplementedError cdef write_error(self, exc): raise NotImplementedError async def main(self): cdef char mtype try: await self.authenticate() except Exception as ex: if self._transport is not None: # If there's no transport it means that the connection # was aborted, in which case we don't really care about # reporting the exception. 
self.write_error(ex) self.close() if not isinstance(ex, (errors.ProtocolError, errors.AuthenticationError)): self.loop.call_exception_handler({ 'message': ( f'unhandled error in {self.__class__.__name__} while ' 'accepting new connection' ), 'exception': ex, 'protocol': self, 'transport': self._transport, 'task': self._main_task, }) return self.authed = True try: while True: if self._cancelled: self.abort() return if self._stop_requested: break if not self.buffer.take_message(): if self._passive_mode: # In "passive" mode we only parse what's in the buffer # and return. If there's any unparsed (incomplete) data # in the buffer it's an error. if self.buffer._length: raise RuntimeError( 'unparsed data in the read buffer') # Flush whatever data is in the internal buffer before # returning. self.flush() return await self.wait_for_message(report_idling=True) mtype = self.buffer.get_message_type() if await self.main_step(mtype): break except asyncio.CancelledError: # Happens when the connection is aborted, the backend is # being closed and propagates CancelledError to all # EdgeCon methods that await on, say, the compiler process. # We shouldn't have CancelledErrors otherwise, therefore, # in this situation we just silently exit. pass except ConnectionError: metrics.connection_errors.inc( 1.0, self.get_tenant_label(), ) except pgerror.BackendQueryCancelledError: pass except Exception as ex: # We can only be here if an exception occurred during # handling another exception, in which case, the only # sane option is to abort the connection. self.loop.call_exception_handler({ 'message': ( 'unhandled error in edgedb protocol while ' 'handling an error' ), 'exception': ex, 'protocol': self, 'transport': self._transport, 'task': self._main_task, }) finally: if self._stop_requested: self._main_task_stopped_normally() self.close() else: # Abort the connection. # It might have already been cleaned up, but abort() is # safe to be called on a closed connection. 
self.abort() # shutting down the connection cdef stop_connection(self): pass def close(self): self.abort_pinned_pgcon() self.stop_connection() if self._transport is not None: self.flush() self._transport.close() self._transport = None def abort(self): self.abort_pinned_pgcon() self.stop_connection() if self._transport is not None: self._transport.abort() self._transport = None def request_stop(self): # Actively stop a frontend connection - this is used by the server # when it's stopping. self._stop_requested = True if self._msg_take_waiter is not None: if not self._msg_take_waiter.done(): self._msg_take_waiter.cancel() @property def cancelled(self) -> bool: return self._cancelled def is_alive(self): return self._transport is not None and not self._cancelled def connection_lost(self, exc): # Let's talk about cancellation. # # 1. Since we need to synchronize the state between Postgres and # EdgeDB, we need to make sure we never do straight asyncio # cancellation while some operation in pgcon is in flight. # # Doing that can lead to the following few bad scenarios: # # * pgcon connection being wrecked by asyncio.CancelledError; # # * pgcon completing its operation and then, a rogue # CancelledError preventing us from applying the new state # to dbview/server config/etc. # # 2. It is safe to cancel `_msg_take_waiter` though. Cancelling it # would abort protocol parsing, but there's no global state that # needs syncing in protocol messages. # # 3. We can interrupt some operations like auth with a CancelledError. # Again, those operations don't mutate global state.
if self.connection_made_at is not None: tenant_label = self.get_tenant_label() metrics.client_connection_duration.observe( time.monotonic() - self.connection_made_at, tenant_label, self.interface, ) if self.authed: metrics.queries_per_connection.observe( self._query_count, tenant_label, self.interface ) if isinstance(exc, ConnectionError): metrics.connection_errors.inc(1.0, tenant_label) if (self._msg_take_waiter is not None and not self._msg_take_waiter.done()): # We're parsing the protocol. We can abort that. self._msg_take_waiter.cancel() if ( self._main_task is not None and not self._main_task.done() and not self._cancelled ): # The main connection handling task is up and running. # First, let's set a flag to signal that we should cancel soon; # after all the client has already disconnected. self._cancelled = True # Make sure nothing is blocked on flow control. # (Currently only dump uses this.) self.resume_writing() if not self.authed: # We must be still authenticating. We can abort that. self._main_task.cancel() else: if ( self._pinned_pgcon is not None and not self._pinned_pgcon.idle ): # Looks like we have a Postgres connection acquired and # it's actively running some command for us. To make # sure we're not leaving behind a heavy query, perform # an explicit Postgres cancellation because a mere # connection drop wouldn't necessarily abort the query # right away). Additionally, we must discard the connection # as we cannot be completely sure about its state. Postgres # cancellation is signal-based and is addressed to a whole # connection and not a concrete operation. The result is # that we might be racing with the currently running query # and if that completes before the cancellation signal # reaches the backend, we'll be setting a trap for the # _next_ query that is unlucky enough to pick up this # Postgres backend from the connection pool. 
# TODO(fantix): hold server shutdown to complete this task if self.tenant.accept_new_tasks: self.tenant.create_task( self.tenant.cancel_and_discard_pgcon( self._pinned_pgcon, self.dbname ), interruptable=False, ) # Prevent the main task from releasing the same connection # twice. This flag is for now only used in this case. self._pgcon_released_in_connection_lost = True # In all other cases, we can just wait until the `main()` # coroutine notices that `self._cancelled` was set. # It would be a mistake to cancel the main task here, as it # could be unpacking results from pgcon and applying them # to the global state. # # Ultimately, the main() coroutine will be aborted, eventually, # and will call `self.abort()` to shut all things down. else: # The `main()` coroutine isn't running, it means that the # connection is already pretty much dead. Nonetheless, call # abort() to make sure we've cleaned everything up properly. self.abort() # Authentication async def authenticate(self): raise NotImplementedError def _auth_jwt(self, user, database, params): raise NotImplementedError def _auth_trust(self, user): roles = self.tenant.get_roles() if user not in roles: raise errors.AuthenticationError('authentication failed') async def _authenticate(self, user, database, params): # The user has already been authenticated by other means # (such as the ability to write to a protected socket). 
if self._external_auth: authmethods = [ self.server.config_settings.get_type_by_name('cfg::Trust')() ] else: authmethods = await self.tenant.get_auth_methods( user, self._transport_proto) auth_errors = {} for authmethod in authmethods: authmethod_name = authmethod._tspec.name.split('::')[1] try: if authmethod_name == 'SCRAM': await self._auth_scram(user) elif authmethod_name == 'JWT': self._auth_jwt(user, database, params) elif authmethod_name == 'Trust': self._auth_trust(user) elif authmethod_name == 'Password': raise errors.AuthenticationError( 'authentication failed: ' 'Simple password authentication required but it is ' 'only supported for HTTP endpoints' ) elif authmethod_name == 'mTLS': auth_helpers.auth_mtls_with_user(self._transport, user) else: raise errors.InternalServerError( f'unimplemented auth method: {authmethod_name}') except errors.AuthenticationError as e: auth_errors[authmethod_name] = e else: break if len(auth_errors) == len(authmethods): if len(auth_errors) > 1: desc = "; ".join( f"{k}: {e.args[0]}" for k, e in auth_errors.items()) raise errors.AuthenticationError( f"all authentication methods failed: {desc}") else: raise next(iter(auth_errors.values())) role = self.tenant.get_roles().get(user) if not role: raise errors.AuthenticationError('authentication failed') branches = role['branches'] if ( '*' not in branches and database not in branches and database != defines.EDGEDB_SYSTEM_DB ): raise errors.AuthenticationError( f"authentication failed: user does not have permission for " f"database branch '{database}'" ) cdef WriteBuffer _make_authentication_sasl_initial(self, list methods): raise NotImplementedError cdef _expect_sasl_initial_response(self): raise NotImplementedError cdef WriteBuffer _make_authentication_sasl_msg( self, bytes data, bint final ): raise NotImplementedError cdef bytes _expect_sasl_response(self): raise NotImplementedError async def _auth_scram(self, user): cdef WriteBuffer msg_buf # Tell the client that we require SASL 
SCRAM auth. msg_buf = self._make_authentication_sasl_initial([b'SCRAM-SHA-256']) self.write(msg_buf) self.flush() selected_mech = None verifier = None mock_auth = False client_nonce = None cb_flag = None done = False while not done: if not self.buffer.take_message(): await self.wait_for_message(report_idling=True) mtype = self.buffer.get_message_type() if selected_mech is None: # Initial response. ( selected_mech, client_first ) = self._expect_sasl_initial_response() if selected_mech != b'SCRAM-SHA-256': raise errors.BinaryProtocolError( f'client selected an invalid SASL authentication ' f'mechanism') verifier, mock_auth = auth_helpers.scram_get_verifier( self.tenant, user) try: bare_offset, cb_flag, authzid, username, client_nonce = ( scram.parse_client_first_message(client_first)) except ValueError as e: raise errors.BinaryProtocolError(str(e)) client_first_bare = client_first[bare_offset:] if isinstance(cb_flag, str): raise errors.BinaryProtocolError( 'malformed SCRAM message', details='The client selected SCRAM-SHA-256 without ' 'channel binding, but the SCRAM message ' 'includes channel binding data.') if authzid: raise errors.UnsupportedFeatureError( 'client uses SASL authorization identity, ' 'which is not supported') server_nonce = scram.generate_nonce() server_first = scram.build_server_first_message( server_nonce, client_nonce, verifier.salt, verifier.iterations).encode('utf-8') msg_buf = self._make_authentication_sasl_msg(server_first, 0) self.write(msg_buf) self.flush() else: # client final message client_final = self._expect_sasl_response() try: cb_data, client_proof, proof_len = ( scram.parse_client_final_message( client_final, client_nonce, server_nonce)) except ValueError as e: raise errors.BinaryProtocolError(str(e)) from None client_final_without_proof = client_final[:-proof_len] cb_data_ok = ( (cb_flag is False and cb_data == b'biws') or (cb_flag is True and cb_data == b'eSws') ) if not cb_data_ok: raise errors.BinaryProtocolError( 'malformed 
SCRAM message', details='Unexpected SCRAM channel-binding attribute ' 'in client-final-message.') if not scram.verify_client_proof( client_first_bare, server_first, client_final_without_proof, verifier.stored_key, client_proof): raise errors.AuthenticationError( 'authentication failed') if mock_auth: # This user actually does not exist, so fail here. raise errors.AuthenticationError( 'authentication failed') server_final = scram.build_server_final_message( client_first_bare, server_first, client_final_without_proof, verifier.server_key, ).encode('utf-8') # AuthenticationSASLFinal msg_buf = self._make_authentication_sasl_msg(server_final, 1) self.write(msg_buf) self.flush() done = True ================================================ FILE: edb/server/protocol/metrics.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2021-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import TYPE_CHECKING import http from edb import errors from edb.server import metrics from edb.server import server from edb.common import debug from edb.common import markup if TYPE_CHECKING: from edb.server import tenant as edbtenant from edb.server.protocol import protocol async def handle_request( request: protocol.HttpRequest, response: protocol.HttpResponse, tenant: edbtenant.Tenant, ) -> None: try: if tenant is None or isinstance(tenant.server, server.Server): output = metrics.registry.generate() else: output = metrics.registry.generate( tenant=tenant.get_instance_name() ) response.status = http.HTTPStatus.OK response.content_type = b'text/plain; version=0.0.4; charset=utf-8' response.body = output.encode() response.close_connection = True except Exception as ex: if debug.flags.server: markup.dump(ex) # XXX Fix this when LSP "location" objects are implemented ex_type = errors.InternalServerError _response_error( response, http.HTTPStatus.INTERNAL_SERVER_ERROR, str(ex), ex_type ) def _response_error( response: protocol.HttpResponse, status: http.HTTPStatus, message: str, ex_type: type[errors.EdgeDBError], ) -> None: response.body = ( f'Unexpected error in /metrics.\n\n' f'{ex_type.__name__}: {message}' ).encode() response.status = status response.close_connection = True ================================================ FILE: edb/server/protocol/notebook_ext.pxd ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from edb.server.protocol cimport frontend from edb.server.pgproto.pgproto cimport WriteBuffer cdef class NotebookConnection(frontend.AbstractFrontendConnection): cdef: WriteBuffer buf cdef bytes _get_data(self) ================================================ FILE: edb/server/protocol/notebook_ext.pyx ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2019-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import base64 import http import json import urllib.parse import immutables from edb import errors from edb.server.pgcon import errors as pgerrors from edb.common import debug from edb.common import markup from edb.server import compiler from edb.server import defines as edbdef from edb.server.compiler import OutputFormat from edb.server.compiler import dbstate from edb.server.compiler import enums from edb.server.protocol import execute as p_execute from edb.server.dbview cimport dbview from edb.server.protocol cimport frontend from edb.server.pgproto.pgproto cimport ( WriteBuffer, ) include "./consts.pxi" cdef tuple CURRENT_PROTOCOL = edbdef.CURRENT_PROTOCOL ALLOWED_CAPABILITIES = ( enums.Capability.MODIFICATIONS | enums.Capability.DDL ) cdef handle_error( object request, object response, error ): if debug.flags.server: markup.dump(error) er_type = type(error) if not issubclass(er_type, errors.EdgeDBError): er_type = errors.InternalServerError response.body = json.dumps({ 'kind': 'error', 'error': { 'message': str(error), 'type': er_type.__name__, } }).encode() response.status = http.HTTPStatus.BAD_REQUEST response.close_connection = True async def handle_request( object request, object response, object db, str role_name, list args, object tenant, ): response.content_type = b'application/json' if args == ['status'] and request.method == b'GET': try: await heartbeat_check(db, tenant) except Exception as ex: return handle_error(request, response, ex) else: response.status = http.HTTPStatus.OK response.body = b'{"kind": "status", "status": "OK"}' return if args != []: ex = Exception('Unknown path') return handle_error(request, response, ex) queries = None try: if request.method == b'POST': body = json.loads(request.body) if not isinstance(body, dict): raise TypeError( 'the body of the request must be a JSON object') queries = body.get('queries') else: raise TypeError('expected a POST request') if not queries: raise TypeError( 'invalid notebook request: "queries"
is missing') except Exception as ex: return handle_error(request, response, ex) response.status = http.HTTPStatus.OK try: result = await execute(db, role_name, tenant, queries) except Exception as ex: return handle_error(request, response, ex) else: response.custom_headers['EdgeDB-Protocol-Version'] = \ f'{CURRENT_PROTOCOL[0]}.{CURRENT_PROTOCOL[1]}' response.custom_headers['Gel-Protocol-Version'] = \ f'{CURRENT_PROTOCOL[0]}.{CURRENT_PROTOCOL[1]}' response.body = b'{"kind": "results", "results":' + result + b'}' async def heartbeat_check(db, tenant): async with tenant.with_pgcon(db.name) as pgcon: await pgcon.sql_execute(b"SELECT 'OK';") cdef class NotebookConnection(frontend.AbstractFrontendConnection): def __cinit__(self): self.buf = WriteBuffer.new() cdef write(self, WriteBuffer data): self.buf.write_bytes(bytes(data)) cdef bytes _get_data(self): return bytes(self.buf) cdef flush(self): pass async def execute(db, role_name, tenant, queries: list): dbv: dbview.DatabaseConnectionView = await tenant.new_dbview( dbname=db.name, query_cache=False, protocol_version=edbdef.CURRENT_PROTOCOL, role_name=role_name, ) compiler_pool = tenant.server.get_compiler_pool() units = await compiler_pool.compile_notebook( dbv.dbname, dbv.get_user_schema_pickle(), dbv.get_global_schema_pickle(), dbv.reflection_cache, dbv.get_database_config(), dbv.get_compilation_system_config(), queries, CURRENT_PROTOCOL, 50, # implicit limit client_id=tenant.client_id, client_name=tenant.get_instance_name(), ) result = [] bind_data = None async with tenant.with_pgcon(db.name) as pgcon: try: await pgcon.sql_execute(b'START TRANSACTION;') dbv.start_tx() for is_error, unit_or_error in units: if is_error: result.append({ 'kind': 'error', 'error': unit_or_error, }) else: query_unit = unit_or_error query_unit_group = dbstate.QueryUnitGroup() query_unit_group.append(query_unit) dbv.check_capabilities( query_unit, ALLOWED_CAPABILITIES, errors.UnsupportedCapabilityError, "disallowed in notebook", 
query_unit_group.unsafe_isolation_dangers, ) try: if query_unit.in_type_args: raise errors.QueryError( 'cannot use query parameters in tutorial') fe_conn = NotebookConnection() compiled = dbview.CompiledQuery( query_unit_group=query_unit_group) await p_execute.execute( pgcon, dbv, compiled, b'', fe_conn=fe_conn, ) except Exception as ex: if debug.flags.server: markup.dump(ex) ex = await p_execute.interpret_error( ex, dbv._db, global_schema_pickle=dbv.get_global_schema_pickle(), user_schema_pickle=dbv.get_user_schema_pickle(), ) result.append({ 'kind': 'error', 'error': [type(ex).__name__, str(ex), {}], }) break else: result.append({ 'kind': 'data', 'data': ( base64.b64encode( query_unit.out_type_id).decode(), base64.b64encode( query_unit.out_type_data).decode(), base64.b64encode( fe_conn._get_data()).decode(), base64.b64encode( query_unit.status).decode(), ), }) finally: try: await pgcon.sql_execute(b'ROLLBACK;') finally: tenant.remove_dbview(dbv) return json.dumps(result).encode() ================================================ FILE: edb/server/protocol/pg_ext.pxd ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2022-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from edb.server.pgproto.pgproto cimport WriteBuffer from edb.server.protocol cimport frontend from edb.server.pgcon.pgcon cimport PGMessage cimport edb.pgsql.parser.parser as pg_parser cdef class PreparedStmt: cdef: PGMessage parse_action pg_parser.Source source cdef class ConnectionView: cdef: object _local_fe_defaults object _settings object _fe_settings bint _in_tx_explicit bint _in_tx_implicit object _in_tx_settings object _in_tx_fe_settings object _in_tx_fe_local_settings dict _in_tx_portals object _in_tx_new_portals object _in_tx_savepoints bint _tx_error tuple _session_state_db_cache cdef _init_user_configs(self, username, tenant) cpdef inline current_fe_settings(self) cdef inline fe_transaction_state(self) cpdef inline bint in_tx(self) cdef inline _reset_tx_state( self, bint chain_implicit, bint chain_explicit ) cdef bint needs_commit_after_state_sync(self) cpdef inline close_portal_if_exists(self, str name) cpdef inline close_portal(self, str name) cdef inline find_portal(self, str name) cdef inline portal_exists(self, str name) cdef class PgConnection(frontend.FrontendConnection): cdef: ConnectionView _dbview bytes secret dict prepared_stmts dict sql_prepared_stmts dict sql_prepared_stmts_map dict wrapping_prepared_stmts bint ignore_till_sync object sslctx object endpoint_security bint is_tls bint _disable_cache bint _disable_normalization cdef inline WriteBuffer ready_for_query(self) ================================================ FILE: edb/server/protocol/pg_ext.pyx ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2022-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import codecs import collections import contextlib import copy import encodings.aliases import logging import hashlib import json import os import sys import time import uuid from collections import deque from typing import Sequence cimport cython import immutables from libc.stdint cimport int32_t, int16_t, uint32_t from edb import errors from edb.common import debug from edb.common.log import current_tenant from edb.pgsql.common import setting_to_sql from edb.pgsql.parser import exceptions as parser_errors import edb.pgsql.parser.parser as pg_parser cimport edb.pgsql.parser.parser as pg_parser from edb.server import args as srvargs from edb.server import defines, metrics from edb.server import tenant as edbtenant from edb.server.compiler import dbstate, enums from edb.server.pgcon import errors as pgerror from edb.server.pgcon.pgcon cimport PGAction, PGMessage from edb.server.protocol cimport frontend DEFAULT_SETTINGS = dbstate.DEFAULT_SQL_SETTINGS DEFAULT_FE_SETTINGS = dbstate.DEFAULT_SQL_FE_SETTINGS cdef object logger = logging.getLogger('edb.server') cdef object DEFAULT_STATE = None encodings.aliases.aliases["sql_ascii"] = "ascii" class ExtendedQueryError(Exception): pass @contextlib.contextmanager def managed_error(): try: yield except Exception as e: raise ExtendedQueryError(e) @cython.final cdef class PreparedStmt: def __init__(self, PGMessage parse_action, pg_parser.Source source): self.parse_action = parse_action self.source = source @cython.final cdef class ConnectionView: def __init__(self): self._in_tx_explicit = False self._in_tx_implicit = False # 
Keep track of backend settings so that we can sync to use different # backend connections (pgcon) within the same frontend connection, # see serialize_state() below and its usages in pgcon.pyx. self._settings = DEFAULT_SETTINGS self._in_tx_settings = None # Frontend-only settings are defined by the high-level compiler, and # tracked only here, syncing between the compiler process, # see current_fe_settings(), fe_transaction_state() and usages below. self._local_fe_defaults = DEFAULT_FE_SETTINGS self._fe_settings = self._local_fe_defaults self._in_tx_fe_settings = None self._in_tx_fe_local_settings = None self._in_tx_portals = {} self._in_tx_new_portals = set() self._in_tx_savepoints = collections.deque() self._tx_error = False global DEFAULT_STATE if DEFAULT_STATE is None: DEFAULT_STATE = json.dumps( [ { "type": "P", "name": key, "value": setting_to_sql(key, val), } for key, val in DEFAULT_SETTINGS.items() ] + [ { "type": "S", "name": key, "value": setting_to_sql(key, val), } for key, val in DEFAULT_FE_SETTINGS.items() ] ).encode("utf-8") self._session_state_db_cache = ( DEFAULT_SETTINGS, DEFAULT_FE_SETTINGS, DEFAULT_STATE ) cdef _init_user_configs(self, username, tenant): assert self._fe_settings is DEFAULT_FE_SETTINGS assert self._in_tx_fe_local_settings is None role = tenant.get_roles()[username] apply_default = role['apply_access_policies_pg_default'] if apply_default is not None: self._local_fe_defaults = self._local_fe_defaults.set( 'apply_access_policies_pg', (str(apply_default).lower(),), ) self._fe_settings = self._local_fe_defaults cpdef inline current_fe_settings(self): if self.in_tx(): # For easier access, _in_tx_fe_local_settings is always a superset # of _in_tx_fe_settings; _in_tx_fe_settings only keeps track of # non-local settings, so that the local settings don't go across # transaction boundaries; this must be consistent with dbstate.py.
return self._in_tx_fe_local_settings or self._local_fe_defaults else: return self._fe_settings cdef inline fe_transaction_state(self): return dbstate.SQLTransactionState( in_tx=self.in_tx(), settings=self._fe_settings, in_tx_settings=self._in_tx_fe_settings, in_tx_local_settings=self._in_tx_fe_local_settings, savepoints=[sp[:3] for sp in self._in_tx_savepoints], ) cpdef inline bint in_tx(self): return self._in_tx_explicit or self._in_tx_implicit cdef inline _reset_tx_state( self, bint chain_implicit, bint chain_explicit ): # This method is part of ending a transaction. COMMIT must be handled # before calling this method. If either chain_* flag is set, a new # transaction will be opened immediately after clean-up. self._in_tx_implicit = chain_implicit self._in_tx_explicit = chain_explicit self._in_tx_settings = self._settings if self.in_tx() else None self._in_tx_fe_settings = self._fe_settings if self.in_tx() else None self._in_tx_fe_local_settings = self._in_tx_fe_settings self._in_tx_portals.clear() self._in_tx_new_portals.clear() self._in_tx_savepoints.clear() self._tx_error = False def start_implicit(self): if self._in_tx_implicit: raise RuntimeError("already in implicit transaction") else: if not self.in_tx(): self._in_tx_settings = self._settings self._in_tx_fe_settings = self._fe_settings self._in_tx_fe_local_settings = self._fe_settings self._in_tx_implicit = True def end_implicit(self): if not self._in_tx_implicit: raise RuntimeError("not in implicit transaction") if self._in_tx_explicit: # There is an explicit transaction, nothing to do other than # turning off the implicit flag so that we can start_implicit again self._in_tx_implicit = False else: # Commit or rollback the implicit transaction if not self._tx_error: self._settings = self._in_tx_settings self._fe_settings = self._in_tx_fe_settings self._reset_tx_state(False, False) def on_success(self, unit: dbstate.SQLQueryUnit): # Handle ROLLBACK first before self._tx_error if unit.tx_action ==
dbstate.TxAction.ROLLBACK: if not self._in_tx_explicit: # TODO: warn about "no tx" but still rollback implicit pass self._reset_tx_state(self._in_tx_implicit, unit.tx_chain) elif unit.tx_action == dbstate.TxAction.ROLLBACK_TO_SAVEPOINT: if not self._in_tx_explicit: if self._in_tx_implicit: self._tx_error = True raise errors.TransactionError( "ROLLBACK TO SAVEPOINT can only be used " "in transaction blocks" ) while self._in_tx_savepoints: ( sp_name, fe_settings, fe_local_settings, settings, new_portals, ) = self._in_tx_savepoints[-1] for name in new_portals: self._in_tx_portals.pop(name, None) if sp_name == unit.sp_name: new_portals.clear() self._in_tx_settings = settings self._in_tx_fe_settings = fe_settings self._in_tx_fe_local_settings = fe_local_settings self._in_tx_new_portals = new_portals break else: self._in_tx_savepoints.pop() else: self._tx_error = True raise errors.TransactionError( f'savepoint "{unit.sp_name}" does not exist' ) elif self._tx_error: raise errors.TransactionError( "current transaction is aborted, " "commands ignored until end of transaction block" ) elif unit.tx_action == dbstate.TxAction.START: if self._in_tx_explicit: # TODO: warning: there is already a transaction in progress pass else: if not self.in_tx(): self._in_tx_settings = self._settings self._in_tx_fe_settings = self._fe_settings self._in_tx_fe_local_settings = self._fe_settings self._in_tx_explicit = True elif unit.tx_action == dbstate.TxAction.COMMIT: if not self._in_tx_explicit: # TODO: warning: there is no transaction in progress # but we still commit implicit transactions if any pass if self.in_tx(): self._settings = self._in_tx_settings self._fe_settings = self._in_tx_fe_settings self._reset_tx_state(self._in_tx_implicit, unit.tx_chain) elif unit.tx_action == dbstate.TxAction.DECLARE_SAVEPOINT: if not self._in_tx_explicit: raise errors.TransactionError( "SAVEPOINT can only be used in transaction blocks" ) self._in_tx_new_portals = set() self._in_tx_savepoints.append(( 
unit.sp_name, self._in_tx_fe_settings, self._in_tx_fe_local_settings, self._in_tx_settings, self._in_tx_new_portals, )) elif unit.tx_action == dbstate.TxAction.RELEASE_SAVEPOINT: pass if unit.set_vars: # only session settings here if unit.set_vars == {None: None}: # RESET ALL if self.in_tx(): self._in_tx_settings = DEFAULT_SETTINGS self._in_tx_fe_settings = self._local_fe_defaults self._in_tx_fe_local_settings = self._local_fe_defaults else: self._settings = DEFAULT_SETTINGS self._fe_settings = self._local_fe_defaults else: if self.in_tx(): if unit.frontend_only: if not unit.is_local: settings = self._in_tx_fe_settings.mutate() for k, v in unit.set_vars.items(): if v is None: if k in self._local_fe_defaults: settings[k] = self._local_fe_defaults[k] else: settings.pop(k, None) else: settings[k] = v self._in_tx_fe_settings = settings.finish() settings = self._in_tx_fe_local_settings.mutate() else: settings = self._in_tx_settings.mutate() elif not unit.is_local: if unit.frontend_only: settings = self._fe_settings.mutate() else: settings = self._settings.mutate() else: return for k, v in unit.set_vars.items(): if v is None: if unit.frontend_only and k in self._local_fe_defaults: settings[k] = self._local_fe_defaults[k] else: settings.pop(k, None) else: settings[k] = v if self.in_tx(): if unit.frontend_only: self._in_tx_fe_local_settings = settings.finish() else: self._in_tx_settings = settings.finish() else: if unit.frontend_only: self._fe_settings = settings.finish() else: self._settings = settings.finish() def on_error(self): self._tx_error = True cpdef inline close_portal(self, str name): try: return self._in_tx_portals.pop(name) except KeyError: raise pgerror.new( pgerror.ERROR_INVALID_CURSOR_NAME, f"cursor \"{name}\" does not exist", ) from None cpdef inline close_portal_if_exists(self, str name): return self._in_tx_portals.pop(name, None) def create_portal(self, str name, query_unit): if not self.in_tx(): raise RuntimeError( "portals cannot be created outside a 
transaction" ) if name and name in self._in_tx_portals: raise pgerror.new( pgerror.ERROR_DUPLICATE_CURSOR, f"cursor \"{name}\" already exists", ) self._in_tx_portals[name] = query_unit cdef inline find_portal(self, str name): try: return self._in_tx_portals[name] except KeyError: raise pgerror.new( pgerror.ERROR_INVALID_CURSOR_NAME, f"cursor \"{name}\" does not exist", ) from None cdef inline portal_exists(self, str name): return name in self._in_tx_portals def serialize_state(self): if self.in_tx(): raise errors.InternalServerError( 'no need to serialize state while in transaction') if ( self._settings == DEFAULT_SETTINGS and self._fe_settings == DEFAULT_FE_SETTINGS ): return DEFAULT_STATE if self._session_state_db_cache is not None: if self._session_state_db_cache[:2] == ( self._settings, self._fe_settings ): return self._session_state_db_cache[-1] rv = json.dumps( [ {"type": "P", "name": key, "value": setting_to_sql(key, val)} for key, val in self._settings.items() ] + [ {"type": "S", "name": key, "value": setting_to_sql(key, val)} for key, val in self._fe_settings.items() ] ).encode("utf-8") self._session_state_db_cache = (self._settings, self._fe_settings, rv) return rv cdef bint needs_commit_after_state_sync(self): return any( tx_conf in self._settings for tx_conf in [ "default_transaction_isolation", "default_transaction_deferrable", "default_transaction_read_only", ] ) cdef class PgConnection(frontend.FrontendConnection): interface = "sql" def __init__(self, server, sslctx, endpoint_security, **kwargs): super().__init__(server, None, **kwargs) self._dbview = ConnectionView() self._id = str((int(self._id) % (2 ** 32))) self.prepared_stmts = {} # via extended query Parse self.sql_prepared_stmts = {} # via a PREPARE statement self.sql_prepared_stmts_map = {} # Tracks prepared statements of operations # on *other* prepared statements. 
self.wrapping_prepared_stmts = {} self.ignore_till_sync = False self.sslctx = sslctx self.endpoint_security = endpoint_security self.is_tls = False self._disable_cache = debug.flags.disable_qcache self._disable_normalization = debug.flags.edgeql_disable_normalization cdef _main_task_created(self): self.server.on_pgext_client_connected(self) # complete the client initial message with a mocked type self.buffer.feed_data(b'\xff') def connection_lost(self, exc): self.server.on_pgext_client_disconnected(self) super().connection_lost(exc) cdef is_in_tx(self): return self._dbview.in_tx() cdef write_error(self, exc): cdef WriteBuffer buf if self.debug and not isinstance(exc, errors.BackendUnavailableError): self.debug_print('EXCEPTION', type(exc).__name__, exc) from edb.common.markup import dump dump(exc) if debug.flags.server and not isinstance( exc, errors.BackendUnavailableError ): self.loop.call_exception_handler({ 'message': ( 'an error in edgedb protocol' ), 'exception': exc, 'protocol': self, 'transport': self._transport, }) message = str(exc) buf = WriteBuffer.new_message(b'E') if isinstance(exc, pgerror.BackendError): if exc.code_is(pgerror.ERROR_SERIALIZATION_FAILURE): metrics.transaction_serialization_errors.inc( 1.0, self.get_tenant_label() ) elif isinstance(exc, parser_errors.PSqlUnsupportedError): exc = pgerror.FeatureNotSupported(str(exc)) elif isinstance(exc, parser_errors.PSqlSyntaxError): exc = pgerror.new( pgerror.ERROR_SYNTAX_ERROR, str(exc), P=str(exc.cursor_pos), ) elif isinstance(exc, errors.AuthenticationError): exc = pgerror.InvalidAuthSpec(str(exc), severity="FATAL") elif isinstance(exc, errors.BinaryProtocolError): exc = pgerror.ProtocolViolation( str(exc), detail=exc.details, severity="FATAL" ) elif isinstance(exc, errors.UnsupportedFeatureError): args = {} if exc.line >= 0: args['L'] = str(exc.line) if exc.position >= 0: args['P'] = str(exc.position + 1) exc = pgerror.new( pgerror.ERROR_FEATURE_NOT_SUPPORTED, str(exc), **args, ) elif 
isinstance(exc, errors.EdgeDBError): args = dict(hint=exc.hint, detail=exc.details) if exc.line >= 0: args['L'] = str(exc.line) if exc.position >= 0: # pg uses 1 based indexes for showing errors. args['P'] = str(exc.position + 1) exc = pgerror.new( exc.pgext_code or pgerror.ERROR_INTERNAL_ERROR, str(exc), **args, ) if isinstance(exc, errors.TransactionSerializationError): metrics.transaction_serialization_errors.inc( 1.0, self.get_tenant_label() ) else: exc = pgerror.new( pgerror.ERROR_INTERNAL_ERROR, str(exc), severity="FATAL", ) for k, v in exc.fields.items(): buf.write_byte(ord(k)) buf.write_str(v, "utf-8") buf.write_byte(b'\0') self.write(buf.end_message()) async def _handshake(self): cdef int16_t proto_ver_major, proto_ver_minor for first in (True, False): if not self.buffer.take_message(): await self.wait_for_message(report_idling=True) proto_ver_major = self.buffer.read_int16() proto_ver_minor = self.buffer.read_int16() if proto_ver_major == 1234: if proto_ver_minor == 5678: # CancelRequest pid = str(self.buffer.read_int32()) secret = self.buffer.read_bytes(4) self.buffer.finish_message() if self.debug: self.debug_print("CancelRequest", pid, secret) self.server.cancel_pgext_connection(pid, secret) self.request_stop() break elif proto_ver_minor == 5679: # SSLRequest if self.debug: self.debug_print("SSLRequest") if not first: raise pgerror.ProtocolViolation( "found multiple SSLRequest", severity="FATAL" ) self.buffer.finish_message() if self._transport is None: raise ConnectionAbortedError if self.debug: self.debug_print("S for SSLRequest") self._transport.write(b'S') # complete the next client message with a mocked type self.buffer.feed_data(b'\xff') self._transport = await self.loop.start_tls( self._transport, self, self.sslctx, server_side=True, ) tenant = self.server.retrieve_tenant( self._transport.get_extra_info("ssl_object") ) if tenant is edbtenant.host_tenant: tenant = None self.tenant = tenant if self.tenant is not None: 
current_tenant.set(self.tenant.get_instance_name()) self.is_tls = True elif proto_ver_minor == 5680: # GSSENCRequest raise pgerror.FeatureNotSupported( "GSSENCRequest is not supported", severity="FATAL" ) else: raise pgerror.FeatureNotSupported(severity="FATAL") elif proto_ver_major == 3 and proto_ver_minor == 0: # StartupMessage with 3.0 protocol if self.debug: self.debug_print("StartupMessage") if ( not self.is_tls and self.endpoint_security == srvargs.ServerEndpointSecurityMode.Tls ): raise pgerror.InvalidAuthSpec( "TLS required due to server endpoint security", severity="FATAL", ) await super()._handshake() break else: raise pgerror.ProtocolViolation( "invalid protocol version", severity="FATAL" ) def cancel(self, secret): if ( self.secret == secret and self._pinned_pgcon is not None and not self._pinned_pgcon.idle and self.tenant.accept_new_tasks ): self.tenant.create_task( self.tenant.cancel_pgcon_operation(self._pinned_pgcon), interruptable=False, ) def debug_print(self, *args): print("::PGEXT::", f"id:{self._id}", *args, file=sys.stderr) cdef WriteBuffer _make_authentication_sasl_initial(self, list methods): cdef WriteBuffer msg_buf msg_buf = WriteBuffer.new_message(b'R') msg_buf.write_int32(10) for method in methods: msg_buf.write_bytestring(method) msg_buf.write_byte(b'\0') msg_buf.end_message() if self.debug: self.debug_print("AuthenticationSASL:", *methods) return msg_buf cdef _expect_sasl_initial_response(self): mtype = self.buffer.get_message_type() if mtype != b'p': raise pgerror.ProtocolViolation( f'expected SASL response, got message type {mtype}') selected_mech = self.buffer.read_null_str() try: client_first = self.buffer.read_len_prefixed_bytes() except BufferError: client_first = None self.buffer.finish_message() if self.debug: self.debug_print( "SASLInitialResponse:", selected_mech, len(client_first) if client_first else client_first, ) if not client_first: # The client didn't send the Client Initial Response # in SASLInitialResponse, this is 
an error. raise pgerror.ProtocolViolation( 'client did not send the Client Initial Response ' 'data in SASLInitialResponse') return selected_mech, client_first cdef WriteBuffer _make_authentication_sasl_msg( self, bytes data, bint final ): cdef WriteBuffer msg_buf msg_buf = WriteBuffer.new_message(b'R') if final: msg_buf.write_int32(12) else: msg_buf.write_int32(11) msg_buf.write_bytes(data) msg_buf.end_message() if self.debug: self.debug_print( "AuthenticationSASLFinal" if final else "AuthenticationSASLContinue", len(data), ) return msg_buf cdef bytes _expect_sasl_response(self): mtype = self.buffer.get_message_type() if mtype != b'p': raise pgerror.ProtocolViolation( f'expected SASL response, got message type {mtype}') client_final = self.buffer.consume_message() if self.debug: self.debug_print("SASLResponse", len(client_final)) return client_final def check_readiness(self): if self.tenant.is_blocked(): readiness_reason = self.tenant.get_readiness_reason() msg = "the server is not accepting requests" if readiness_reason: msg = f"{msg}: {readiness_reason}" raise pgerror.CannotConnectNowError(msg) elif not self.tenant.is_online(): readiness_reason = self.tenant.get_readiness_reason() msg = "the server is going offline" if readiness_reason: msg = f"{msg}: {readiness_reason}" raise pgerror.CannotConnectNowError(msg) async def authenticate(self): cdef: WriteBuffer msg_buf WriteBuffer buf self.check_readiness() params = {} while True: name = self.buffer.read_null_str() if not name: break value = self.buffer.read_null_str() params[name.decode("utf-8")] = value.decode("utf-8") if self.debug: self.debug_print("StartupMessage params:", params) if "user" not in params: raise pgerror.ProtocolViolation( "StartupMessage must have a \"user\"", severity="FATAL" ) self.buffer.finish_message() user = params["user"] database = params.get("database", user) if "client_encoding" in params: encoding = params["client_encoding"] client_encoding = 
encodings.normalize_encoding(encoding).upper() try: codecs.lookup(client_encoding) except LookupError: raise pgerror.new( pgerror.ERROR_INVALID_PARAMETER_VALUE, f'invalid value for parameter "client_encoding": "{encoding}"', ) self._dbview._settings = self._dbview._settings.set( "client_encoding", (client_encoding,) ) else: client_encoding = "UTF8" logger.debug('received pg connection request by %s to database %s', user, database) if database == '__default__': database = self.tenant.default_database elif ( database == defines.EDGEDB_OLD_DEFAULT_DB and self.tenant.maybe_get_db( dbname=defines.EDGEDB_OLD_DEFAULT_DB ) is None ): database = self.tenant.default_database user = self.tenant.resolve_user_name(user) await self._authenticate(user, database, params) logger.debug('successfully authenticated %s in database %s', user, database) if not self.tenant.is_database_connectable(database): raise pgerror.InvalidAuthSpec( f'database {database!r} does not accept connections', severity="FATAL", ) self.database = self.tenant.get_db(dbname=database) await self.database.introspection() self.dbname = database self.username = user self._dbview._init_user_configs(user, self.tenant) buf = WriteBuffer() msg_buf = WriteBuffer.new_message(b'R') msg_buf.write_int32(0) msg_buf.end_message() buf.write_buffer(msg_buf) if self.debug: self.debug_print("AuthenticationOk") self.secret = os.urandom(4) msg_buf = WriteBuffer.new_message(b'K') msg_buf.write_int32(int(self._id)) msg_buf.write_bytes(self.secret) msg_buf.end_message() buf.write_buffer(msg_buf) if self.debug: self.debug_print("BackendKeyData") async with self.with_pgcon() as conn: for name, value in conn.parameter_status.items(): msg_buf = WriteBuffer.new_message(b'S') msg_buf.write_str(name, "utf-8") if name == "client_encoding": value = client_encoding elif name == "server_version": value = str(defines.PGEXT_POSTGRES_VERSION) elif name == "session_authorization": value = user elif name == "application_name": value = 
self.tenant.get_instance_name() msg_buf.write_str(value, "utf-8") msg_buf.end_message() buf.write_buffer(msg_buf) if self.debug: self.debug_print(f"ParameterStatus: {name}={value}") self.write(buf) # Try to sync the settings, especially client_encoding. await conn.sql_apply_state(self._dbview) self.write(self.ready_for_query()) self.flush() cdef inline WriteBuffer ready_for_query(self): cdef WriteBuffer msg_buf self.ignore_till_sync = False msg_buf = WriteBuffer.new_message(b'Z') if self._dbview.in_tx(): if self._dbview._tx_error: msg_buf.write_byte(b'E') else: msg_buf.write_byte(b'T') else: msg_buf.write_byte(b'I') return msg_buf.end_message() def on_success(self, query_unit): cdef: PreparedStmt stmt if query_unit.deallocate is not None: stmt_name = query_unit.deallocate.stmt_name self.sql_prepared_stmts.pop(stmt_name, None) self.sql_prepared_stmts_map.pop(stmt_name, None) self.prepared_stmts.pop(stmt_name, None) # If any wrapping prepared statements referred to this # prepared statement, invalidate them. for wrapping_ps in self.wrapping_prepared_stmts.pop(stmt_name, []): stmt = self.prepared_stmts.get(wrapping_ps) if stmt is not None: stmt.parse_action.invalidate() def on_error(self, query_unit): cdef: PreparedStmt stmt if query_unit.prepare is not None: stmt_name = query_unit.prepare.stmt_name self.sql_prepared_stmts.pop(stmt_name, None) self.sql_prepared_stmts_map.pop(stmt_name, None) self.prepared_stmts.pop(stmt_name, None) # If any wrapping prepared statements referred to this # prepared statement, invalidate them. 
for wrapping_ps in self.wrapping_prepared_stmts.pop(stmt_name, []): stmt = self.prepared_stmts.get(wrapping_ps) if stmt is not None: stmt.parse_action.invalidate() async def main_step(self, char mtype): try: await self._main_step(mtype) except pgerror.BackendError as ex: self.write_error(ex) self.write(self.ready_for_query()) self.flush() self.request_stop() async def _main_step(self, char mtype): cdef: WriteBuffer buf ConnectionView dbv dbv = self._dbview self.check_readiness() if self.debug: self.debug_print("main_step", repr(chr(mtype))) if self.ignore_till_sync: self.debug_print("ignoring") if mtype == b'S': # Sync self.buffer.finish_message() if self.debug: self.debug_print("Sync") if dbv._in_tx_implicit: actions = [PGMessage(PGAction.SYNC)] async with self.with_pgcon() as conn: success, _ = await conn.sql_extended_query( actions, self, self.database.dbver, dbv) self.ignore_till_sync = not success else: self.ignore_till_sync = False self.write(self.ready_for_query()) self.flush() elif mtype == b'X': # Terminate self.buffer.finish_message() if self.debug: self.debug_print("Terminate") self.close() return True elif self.ignore_till_sync: self.buffer.discard_message() elif mtype == b'Q': # Query try: query = self.buffer.read_null_str() metrics.query_size.observe( len(query), self.get_tenant_label(), 'sql' ) query_str = query.decode("utf8") self.buffer.finish_message() if self.debug: self.debug_print("Query", query_str) actions = await self.simple_query(query_str) del query_str, query except Exception as ex: self.write_error(ex) self.write(self.ready_for_query()) self.flush() else: async with self.with_pgcon() as conn: try: _, rq_sent = await conn.sql_extended_query( actions, self, self.database.dbver, dbv, ) except Exception as ex: self.write_error(ex) self.write(self.ready_for_query()) else: if not rq_sent: self.write(self.ready_for_query()) self.flush() elif ( mtype == b'P' or mtype == b'B' or mtype == b'D' or mtype == b'E' or # One of Parse, Bind, Describe or 
Execute starts an extended query mtype == b'C' # or Close ): try: actions, exception = await self.extended_query() except ExtendedQueryError as ex: actions = () exception = ex else: async with self.with_pgcon() as conn: try: success, _ = await conn.sql_extended_query( actions, self, self.database.dbver, dbv) self.ignore_till_sync = not success except Exception as ex: self.write_error(ex) self.flush() self.ignore_till_sync = True if exception: self.write_error(exception.args[0]) self.flush() self.ignore_till_sync = True elif mtype == b'H': # Flush self.buffer.finish_message() if self.debug: self.debug_print("Flush") self.flush() else: if self.debug: self.debug_print( "MESSAGE", chr(mtype), self.buffer.consume_message() ) raise pgerror.FeatureNotSupported() async def simple_query(self, query_str: str) -> list[PGMessage]: cdef: PreparedStmt stmt actions = [] dbv = self._dbview if self._disable_normalization: source = pg_parser.Source.from_string(query_str) else: source = pg_parser.NormalizedSource.from_string(query_str) query_units = await self.compile(source, dbv) # TODO: currently, normalization does not work with multiple queries # so we must re-run the compilation with non-normalized query. # Ideally we could detect this before compilation. 
if len(query_units) > 1: source = pg_parser.Source.from_string(query_str) query_units = await self.compile(source, dbv) for qu in query_units: self.check_capabilities(qu) already_in_implicit_tx = dbv._in_tx_implicit metrics.sql_queries.inc( len(query_units), self.tenant.get_instance_name() ) self._query_count += len(query_units) if not already_in_implicit_tx: actions.append(PGMessage(PGAction.START_IMPLICIT_TX)) for qu in query_units: if qu.execute is not None: fe_settings = dbv.current_fe_settings() known_be_name = ( self.sql_prepared_stmts_map.get(qu.execute.stmt_name)) recompile = ( qu.fe_settings != fe_settings or qu.execute.be_stmt_name != known_be_name.encode("utf-8") ) actions.extend(await self._ensure_nested_ps_exists( dbv, qu, force_recompilation=recompile)) else: recompile = False if recompile: stmt, new_stmts = await self._parse_statement( stmt_name=None, query_str=qu.orig_query, parse_data=b"\x00\x00", dbv=dbv, force_recompilation=True, injected_action=True, ) else: stmt, new_stmts = await self._parse_unit( stmt_name=None, unit=qu, source=source, parse_data=b"\x00\x00", dbv=dbv, injected_action=True, ) parse_unit = stmt.parse_action.query_unit if parse_unit.set_vars: actions.extend(self._build_sql_settings_actions(parse_unit)) actions.append(stmt.parse_action) # 2 bytes: number of format codes (1) # 2 bytes: first format code (1) is binary # (this implies that all args are binary) # 2 bytes: number of arguments (0) # 2 bytes: number of result format codes (0) # (this implies that all results use the default text format) bind_data = b"\x00\x01\x00\x01\x00\x00\x00\x00" # remap arguments, which will inject globals bind_data = remap_arguments( bind_data, parse_unit.params, dbv.current_fe_settings(), source, self.get_permissions(), self.username, ) actions.append( PGMessage( PGAction.BIND, portal_name="", stmt_name=parse_unit.stmt_name, args=bind_data, query_unit=parse_unit, injected=True, ) ) actions.append( PGMessage( PGAction.DESCRIBE_STMT_ROWS, stmt_name=parse_unit.stmt_name, 
query_unit=parse_unit, ) ) actions.append( PGMessage( PGAction.EXECUTE, args=0, portal_name="", query_unit=parse_unit, injected=False, ) ) actions.append( PGMessage( PGAction.CLOSE_PORTAL, portal_name="", query_unit=parse_unit, injected=True, ) ) actions.append(PGMessage(PGAction.SYNC)) return actions async def extended_query(self): cdef: WriteBuffer buf int16_t i bytes data bint in_implicit PreparedStmt stmt ConnectionView dbv # Extended-query pre-plays on a deeply-cloned temporary dbview so as to # compose the actions list with correct states; the actual changes to # dbview are applied in pgcon.pyx when the actions are executed dbv = copy.deepcopy(self._dbview) actions = deque() fresh_stmts = set() in_implicit = self._dbview._in_tx_implicit # Here we will exhaust the buffer and queue up actions for the backend. # Any error in this step will be handled in the outer main_step(): # the error will be returned, and any remaining messages in the buffer will # be discarded until a Sync message is found (ignore_till_sync). # This also means no partial action is executed in the backend for now. 
while self.buffer.take_message(): if not in_implicit: actions.append(PGMessage(PGAction.START_IMPLICIT_TX)) in_implicit = True with managed_error(): dbv.start_implicit() mtype = self.buffer.get_message_type() if mtype == b'P': # Parse stmt_name = self.buffer.read_null_str().decode("utf8") query_bytes = self.buffer.read_null_str() query_str = query_bytes.decode("utf8") data = self.buffer.consume_message() if self.debug: self.debug_print("Parse", repr(stmt_name), query_str, data) metrics.query_size.observe( len(query_bytes), self.get_tenant_label(), 'sql' ) with managed_error(): if ( stmt_name and ( stmt_name in self.prepared_stmts or stmt_name in self.sql_prepared_stmts ) ): raise pgerror.new( pgerror.ERROR_DUPLICATE_PREPARED_STATEMENT, f"prepared statement \"{stmt_name}\" already " f"exists", ) stmt, new_stmts = await self._parse_statement( stmt_name, query_str, data, dbv ) if stmt.parse_action.query_unit.execute is not None: actions.extend( await self._ensure_nested_ps_exists( dbv, stmt.parse_action.query_unit, ) ) fresh_stmts.update(new_stmts) actions.append(stmt.parse_action) elif mtype == b'B': # Bind portal_name = self.buffer.read_null_str().decode("utf8") stmt_name = self.buffer.read_null_str().decode("utf8") data = self.buffer.consume_message() if self.debug: self.debug_print( "Bind", repr(portal_name), repr(stmt_name), data ) with managed_error(): stmt = await self._ensure_ps_locality( dbv, stmt_name, fresh_stmts, actions, ) try: params = stmt.parse_action.query_unit.params fe_settings = dbv.current_fe_settings() data = remap_arguments( data, params, fe_settings, stmt.source, self.get_permissions(), self.username, ) except Exception as e: # we return here instead of raising the exception # because we want to also return the previous actions return actions, ExtendedQueryError(e) actions.append( PGMessage( PGAction.BIND, stmt_name=stmt.parse_action.stmt_name, portal_name=portal_name, args=data, query_unit=stmt.parse_action.query_unit, ) ) 
dbv.create_portal(portal_name, stmt.parse_action.query_unit) elif mtype == b'D': # Describe kind = self.buffer.read_byte() name = self.buffer.read_null_str().decode("utf8") self.buffer.finish_message() if self.debug: self.debug_print("Describe", kind, repr(name)) with managed_error(): if kind == b'S': # prepared statement stmt = await self._ensure_ps_locality( dbv, name, fresh_stmts, actions, ) actions.append( PGMessage( PGAction.DESCRIBE_STMT, stmt_name=stmt.parse_action.stmt_name, query_unit=stmt.parse_action.query_unit, ) ) elif kind == b'P': # portal actions.append( PGMessage( PGAction.DESCRIBE_PORTAL, portal_name=name, query_unit=dbv.find_portal(name), ) ) else: raise pgerror.ProtocolViolation( "invalid Describe kind" ) elif mtype == b'E': # Execute portal_name = self.buffer.read_null_str().decode("utf8") max_rows = self.buffer.read_int32() self.buffer.finish_message() if self.debug: self.debug_print("Execute", repr(portal_name), max_rows) metrics.sql_queries.inc(1.0, self.tenant.get_instance_name()) self._query_count += 1 with managed_error(): unit = dbv.find_portal(portal_name) if unit.set_vars: actions.extend(self._build_sql_settings_actions(unit)) actions.append( PGMessage( PGAction.EXECUTE, portal_name=portal_name, args=max_rows, query_unit=unit, ) ) dbv.on_success(unit) elif mtype == b'C': # Close kind = self.buffer.read_byte() name_bytes = self.buffer.read_null_str() name = name_bytes.decode("utf8") self.buffer.finish_message() if self.debug: self.debug_print("Close", kind, repr(name)) with managed_error(): if kind == b'S': # prepared statement if name not in self.prepared_stmts: raise pgerror.new( pgerror.ERROR_INVALID_SQL_STATEMENT_NAME, f"prepared statement \"{name}\" does not " f"exist", ) # The prepared statement in the backend is managed by # the LRU cache in pgcon.pyx, we don't close it here fresh_stmts.discard(name) self.prepared_stmts.pop(name) actions.append( PGMessage( PGAction.CLOSE_STMT, stmt_name=name_bytes, ), ) elif kind == b'P': # 
portal actions.append( PGMessage( PGAction.CLOSE_PORTAL, portal_name=name, query_unit=dbv.close_portal(name), ), ) else: raise pgerror.ProtocolViolation("invalid Close kind") elif mtype == b'H': # Flush self.buffer.finish_message() if self.debug: self.debug_print("Flush") actions.append(PGMessage(PGAction.FLUSH)) elif mtype == b'S': # Sync in_implicit = False self.buffer.finish_message() if self.debug: self.debug_print("Sync") with managed_error(): actions.append(PGMessage(PGAction.SYNC)) dbv.end_implicit() break else: # Other messages would cut off the current extended_query() break if self.debug: self.debug_print("extended_query", actions) return actions, None def check_capabilities( self, query_unit, ): query_capabilities = query_unit.capabilities role_capability = self.get_role_capability() if query_capabilities & ~role_capability: raise query_capabilities.make_error( role_capability, errors.DisabledCapabilityError, f"role {self.username} does not have permission", ) def get_role_capability(self) -> enums.Capability: if capability := self.tenant.get_role_capabilities().get( self.username ): return capability return enums.Capability.NONE def get_permissions(self) -> tuple[bool, Sequence[str]]: if role_desc := self.tenant.get_roles().get(self.username): return ( bool(role_desc.get('superuser')), (role_desc.get('all_permissions') or ()) ) return False, () async def _ensure_ps_locality( self, dbv: ConnectionView, stmt_name: str, local_stmts: set[str], actions ) -> PreparedStmt: """Make sure given *stmt_name* is known by Postgres Frontend SQL connections do not normally own Postgres connections, so there is no affinity between them. Thus, whenever we receive a message operating on some prepared statement, we must ensure that this statement has been prepared in the currently active Postgres connection. We rely on pgcon LRU to actually make a decision on whether to issue the injected Parse messages. NB: this method mutates *local_stmts* and *actions*. 
""" cdef: PreparedStmt stmt stmt = self.prepared_stmts.get(stmt_name) if stmt is None: raise pgerror.new( pgerror.ERROR_INVALID_SQL_STATEMENT_NAME, f"prepared statement \"{stmt_name}\" does not " f"exist", ) if stmt_name not in local_stmts: # Non-local statement, so inject its Parse. fe_settings = dbv.current_fe_settings() qu = stmt.parse_action.query_unit assert qu is not None if stmt.parse_action.fe_settings != fe_settings: # Some of the statically compiler-evaluated # queries like `current_schema` depend on the # fe_settings, we need to re-compile if the # fe_settings have changed. stmt.parse_action.invalidate() if ( qu.execute is not None and ( qu.execute.be_stmt_name != self.sql_prepared_stmts_map.get( qu.execute.stmt_name).encode("utf-8") ) ): # Likewise, re-compile if this is an EXECUTE query # and the translated name of the prepared statement # has changed (e.g. due to it having been deallocated # and prepared with a different query). stmt.parse_action.invalidate() if not stmt.parse_action.is_valid(): parse_actions, new_stmts = await self._reparse( stmt_name, stmt.parse_action, dbv, ) local_stmts.update(new_stmts) actions.extend(parse_actions) stmt = self.prepared_stmts[stmt_name] else: actions.append(stmt.parse_action.as_injected()) local_stmts.add(stmt_name) return stmt async def _reparse( self, str stmt_name, PGMessage parse_action, ConnectionView dbv, ): cdef: PreparedStmt outer_stmt actions = [] qu = parse_action.query_unit assert qu is not None if self.debug: self.debug_print("reparsing", stmt_name, parse_action) if ( qu.prepare is not None or qu.execute is not None ): actions.extend( await self._ensure_nested_ps_exists( dbv, qu, force_recompilation=True, ), ) outer_stmt, new_stmts = await self._parse_statement( stmt_name, qu.orig_query, parse_action.args[2], dbv, force_recompilation=True, injected_action=True, ) actions.append(outer_stmt.parse_action) return actions, new_stmts async def _ensure_nested_ps_exists( self, dbv: ConnectionView, 
execute_unit: dbstate.SQLQueryUnit, force_recompilation: bool = False, ) -> list[PGMessage]: cdef: PreparedStmt sql_stmt exec_data = execute_unit.execute prep_qu = self.sql_prepared_stmts.pop(exec_data.stmt_name, None) actions = [] if prep_qu is None: raise pgerror.new( pgerror.ERROR_INVALID_SQL_STATEMENT_NAME, f"prepared statement " f"\"{exec_data.stmt_name}\" does not " f"exist", ) sql_stmt, _ = await self._parse_statement( prep_qu.stmt_name.decode("utf-8"), prep_qu.orig_query, b"\x00\x00", dbv, injected_action=True, force_recompilation=force_recompilation, ) actions.append(sql_stmt.parse_action) parse_stmt_name = sql_stmt.parse_action.stmt_name portal_name = parse_stmt_name.decode("utf-8") parse_query_unit = sql_stmt.parse_action.query_unit actions.append( PGMessage( PGAction.BIND, portal_name=portal_name, stmt_name=parse_stmt_name, args=b"\x00\x01\x00\x01\x00\x00\x00\x00", query_unit=parse_query_unit, injected=True, ) ) actions.append( PGMessage( PGAction.EXECUTE, args=0, portal_name=portal_name, query_unit=parse_query_unit, injected=True, ) ) actions.append( PGMessage( PGAction.CLOSE_PORTAL, portal_name=portal_name, query_unit=parse_query_unit, injected=True, ) ) return actions def _build_sql_settings_actions(self, qu): actions = [] if qu.set_vars == {None: None}: # RESET ALL actions.append( PGMessage( PGAction.BIND, force_portal_name=b"injected", stmt_name=b"_reset_sql_state_all", # 2 bytes: number of format codes (0) # 2 bytes: number of parameters (0) # 2 bytes: number of result format codes (0) args=b"\x00\x00\x00\x00\x00\x00", injected=True, ) ) actions.append( PGMessage( PGAction.EXECUTE, args=0, force_portal_name=b"injected", injected=True, ) ) actions.append( PGMessage( PGAction.CLOSE_PORTAL, force_portal_name=b"injected", injected=True, ) ) elif qu.frontend_only: for k, v in qu.set_vars.items(): buf = WriteBuffer.new() buf.write_int16(0) # number of format codes # number of parameters: if v is None: buf.write_int16(2) else: buf.write_int16(3) 
buf.write_len_prefixed_utf8(k) # 1st param: name buf.write_len_prefixed_utf8( # 2nd param: type "L" if qu.is_local else "S") if v is not None: buf.write_len_prefixed_utf8( # 3rd param: value setting_to_sql(k, v)) buf.write_int16(0) # number of result format codes if v is None: actions.append( PGMessage( PGAction.BIND, force_portal_name=b"injected", stmt_name=b"_reset_sql_state", args=bytes(buf), injected=True, ) ) else: actions.append( PGMessage( PGAction.BIND, force_portal_name=b"injected", stmt_name=b"_set_sql_state", args=bytes(buf), injected=True, ) ) actions.append( PGMessage( PGAction.EXECUTE, args=0, force_portal_name=b"injected", injected=True, ) ) actions.append( PGMessage( PGAction.CLOSE_PORTAL, force_portal_name=b"injected", injected=True, ) ) return actions async def _parse_statement( self, stmt_name: str | None, query_str: str, parse_data: bytes, dbv: ConnectionView, force_recompilation: bool = False, injected_action: bool = False, ) -> Tuple[PreparedStmt, set[str]]: """Generate a PARSE action for *query_str*. The *query_str* string must contain exactly one SQL statement. 
""" stmts = set() if self._disable_normalization: source = pg_parser.Source.from_string(query_str) else: source = pg_parser.NormalizedSource.from_string(query_str) query_units = await self.compile( source, dbv, ignore_cache=force_recompilation ) if len(query_units) > 1: raise pgerror.new( pgerror.ERROR_SYNTAX_ERROR, "cannot insert multiple commands into a prepared " "statement", ) for qu in query_units: self.check_capabilities(qu) return await self._parse_unit( stmt_name, query_units[0], source, parse_data, dbv, injected_action=injected_action, ) async def _parse_unit( self, stmt_name: str | None, unit: dbstate.SQLQueryUnit, source: pg_parser.Source, parse_data: bytes, dbv: ConnectionView, injected_action: bool = False, ) -> Tuple[PreparedStmt, set[str]]: stmts = set() fe_settings = dbv.current_fe_settings() nested_ps_name = None if unit.prepare is not None: # Statement-level PREPARE nested_ps_name = unit.prepare.stmt_name unit = self._validate_prepare_stmt(unit) stmts.add(nested_ps_name) self.sql_prepared_stmts[nested_ps_name] = unit self.sql_prepared_stmts_map[nested_ps_name] = ( unit.prepare.be_stmt_name.decode("utf-8")) elif unit.execute is not None: # Statement-level EXECUTE nested_ps_name = unit.execute.stmt_name unit = self._validate_execute_stmt(unit) elif unit.deallocate is not None: # Statement-level DEALLOCATE nested_ps_name = unit.deallocate.stmt_name unit = self._validate_deallocate_stmt(unit) remapped_parse_data = remap_parameters(parse_data, unit.params) action = PGMessage( PGAction.PARSE, stmt_name=unit.stmt_name, args=(unit.query.encode("utf-8"), remapped_parse_data, parse_data), query_unit=unit, fe_settings=fe_settings, injected=injected_action, ) if stmt_name is not None and nested_ps_name is not None: # This is a prepared statement of an operation on *another* # prepared statement, and so we must track this relationship # in case the nested prepared statement gets deallocated. 
try: self.wrapping_prepared_stmts[nested_ps_name].add(stmt_name) except KeyError: self.wrapping_prepared_stmts[nested_ps_name] = set([stmt_name]) stmt = PreparedStmt( parse_action=action, source=source, ) if stmt_name is not None: self.prepared_stmts[stmt_name] = stmt stmts.add(stmt_name) return stmt, stmts async def compile( self, source: pg_parser.Source, ConnectionView dbv, ignore_cache=False ) -> List[dbstate.SQLQueryUnit]: if self.debug: self.debug_print("Compile", source.text()) fe_settings = dbv.current_fe_settings() key = compute_cache_key(source, fe_settings) ignore_cache |= self._disable_cache result: List[dbstate.SQLQueryUnit] if not ignore_cache: result = self.database.lookup_compiled_sql(key) if result is not None: return result # Remember the schema version we are compiling on, so that we can # cache the result with the matching version. In case of concurrent # schema update, we're only storing an outdated cache entry, and # the next identical query could get recompiled on the new schema. 
schema_version = self.database.schema_version compiler_pool = self.server.get_compiler_pool() started_at = time.monotonic() try: result = await compiler_pool.compile_sql( self.dbname, self.database.user_schema_pickle, self.database._index._global_schema_pickle, self.database.reflection_cache, self.database.db_config, self.database._index.get_compilation_system_config(), source, dbv.fe_transaction_state(), self.sql_prepared_stmts_map, self.dbname, self.username, client_id=self.tenant.client_id, client_name=self.tenant.get_instance_name(), ) finally: metrics.query_compilation_duration.observe( time.monotonic() - started_at, self.tenant.get_instance_name(), "sql", ) self.database.cache_compiled_sql(key, result, schema_version) metrics.sql_compilations.inc( len(result), self.tenant.get_instance_name() ) if self.debug: self.debug_print("Compile result", result) return result def _validate_prepare_stmt(self, qu): assert qu.prepare is not None stmt_name = qu.prepare.stmt_name if ( stmt_name in self.prepared_stmts or stmt_name in self.sql_prepared_stmts ): raise pgerror.new( pgerror.ERROR_DUPLICATE_PREPARED_STATEMENT, f"prepared statement \"{stmt_name}\" " f"already exists", ) return qu def _validate_execute_stmt(self, qu): assert qu.execute is not None stmt_name = qu.execute.stmt_name sql_ps = self.sql_prepared_stmts.get(stmt_name) if sql_ps is None: raise pgerror.new( pgerror.ERROR_INVALID_SQL_STATEMENT_NAME, f"prepared statement \"{stmt_name}\" does " f"not exist", ) return qu def _validate_deallocate_stmt(self, qu): assert qu.deallocate is not None stmt_name = qu.deallocate.stmt_name sql_ps = self.sql_prepared_stmts.get(stmt_name) if sql_ps is None: raise pgerror.new( pgerror.ERROR_INVALID_SQL_STATEMENT_NAME, f"prepared statement \"{stmt_name}\" does " f"not exist", ) return qu def compute_cache_key( source: pg_parser.Source, fe_settings: dbstate.SQLSettings ) -> bytes: h = hashlib.blake2b(source.cache_key()) for key, value in fe_settings.items(): if 
key.startswith('global '): continue h.update(hash(value).to_bytes(8, signed=True)) return h.digest() cdef bytes remap_arguments( data: bytes, params: list[dbstate.SQLParam] | None, fe_settings: dbstate.SQLSettings, source: pg_parser.Source, permission_info: tuple[bool, Sequence[str]], username: str, # HACK ): cdef: int16_t param_format_count int32_t offset int32_t arg_offset_external int16_t param_count_external int32_t size is_superuser, permissions = permission_info # The "external" parameters (that are visible to the user) # don't include the internal params for globals and extracted constants. # So when we send external params to postgres, we remap them # to correct positions and add the globals. buf = WriteBuffer.new() # remap param format codes param_format_count = read_int16(data[0:2]) offset = 2 if params: buf.write_int16(len(params)) for i, param in enumerate(params): if isinstance(param, dbstate.SQLParamExternal): if param_format_count == 0: buf.write_int16(0) # text elif param_format_count == 1: buf.write_bytes( data[offset:offset+2] ) else: o = offset + i * 2 buf.write_bytes(data[o:o+2]) elif isinstance(param, dbstate.SQLParamExtractedConst): buf.write_int16(0) # text else: # this is for globals buf.write_int16(1) # binary else: buf.write_int16(0) offset += param_format_count * 2 # find positions of external args arg_count_external = read_int16(data[offset:offset+2]) offset += 2 arg_offset_external = offset for p in range(arg_count_external): size = read_int32(data[offset:offset+4]) if size == -1: # special case: NULL size = 0 size += 4 # for size which is int32 offset += size # write remapped args if params: buf.write_int16(len(params)) param_count_external = 0 for i, param in enumerate(params): if not isinstance(param, dbstate.SQLParamExternal): break param_count_external = i + 1 if param_count_external != arg_count_external: raise pgerror.new( pgerror.ERROR_PROTOCOL_VIOLATION, f'bind message supplies {arg_count_external} ' f'parameters, but prepared 
statement "" requires ' f'{param_count_external}', ) # write external args if arg_offset_external < offset: buf.write_bytes(data[arg_offset_external:offset]) # write non-external args extracted_consts = list(source.variables().values()) for (e, param) in enumerate(params[param_count_external:]): if isinstance(param, dbstate.SQLParamExtractedConst): buf.write_len_prefixed_bytes(extracted_consts[e]) elif isinstance(param, dbstate.SQLParamGlobal): name = param.global_name if param.is_permission: buf.write_int32(1) buf.write_byte(is_superuser or str(name) in permissions) elif name.module == 'sys' and name.name == 'current_role': write_arg(buf, param.pg_type, (username,)) else: setting_name = f'global {name.module}::{name.name}' values = fe_settings.get(setting_name, None) if values == None: buf.write_int32(-1) # NULL else: write_arg(buf, param.pg_type, values) else: buf.write_int16(0) # result format codes buf.write_bytes(data[offset:]) return bytes(buf) cdef bytes remap_parameters( data: bytes, params: list[dbstate.SQLParam] | None ): # Inject parameter type descriptions in parse message for parameters for # globals and extracted constants. 
if not params: return b"\x00\x00" buf = WriteBuffer.new() buf.write_int16(len(params)) # copy the params specified by user specified_ext = read_int16(data[0:2]) buf.write_bytes(data[2:2 + specified_ext*4]) for index, param in enumerate(params): # already written if index < specified_ext: continue if isinstance(param, dbstate.SQLParamExternal): buf.write_int32(0) # unspecified elif isinstance(param, dbstate.SQLParamExtractedConst): buf.write_int32(param.type_oid) elif isinstance(param, dbstate.SQLParamGlobal): buf.write_int32(0) # unspecified assert len(bytes(buf)) == 2 + 4 * len(params) return bytes(buf) cdef write_arg( buf: WriteBuffer, pg_type: tuple, values: dbstate.SQLSetting ): value = values[0] if pg_type == ('text',) and isinstance(value, str): val = str(value).encode('UTF-8') buf.write_len_prefixed_bytes(val) elif pg_type == ('uuid',) and isinstance(value, str): try: id = uuid.UUID(value) buf.write_len_prefixed_bytes(id.bytes) except ValueError: buf.write_int32(-1) # NULL elif pg_type == ('int8',) and isinstance(value, int): buf.write_int32(8) buf.write_int64(value) elif pg_type == ('int4',) and isinstance(value, int): buf.write_int32(4) buf.write_int32(value) elif pg_type == ('int2',) and isinstance(value, int): buf.write_int32(2) buf.write_int16(value) elif pg_type == ('bool',): is_truthy = is_setting_truthy(value) if is_truthy is None: buf.write_int32(-1) else: buf.write_int32(1) buf.write_byte(1 if is_truthy else 0) elif pg_type == ('float8',) and isinstance(value, float): buf.write_int32(8) buf.write_double(value) elif pg_type == ('float4',) and isinstance(value, float): buf.write_int32(4) buf.write_float(value) else: buf.write_int32(-1) # NULL raise RuntimeError( f"unimplemented glob type={pg_type}, value={type(value)}" ) def is_setting_truthy(value: str | int | float) -> bool | None: if isinstance(value, int): return value != 0 if isinstance(value, str): value = value.lower() if value == 'o': # ambiguous return None truthy_values = ('on', 'true', 
'yes', '1') if any(t.startswith(value) for t in truthy_values): return True falsy_values = ('off', 'false', 'no', '0') if any(t.startswith(value) for t in falsy_values): return False return None cdef inline int16_t read_int16(data: bytes): return int.from_bytes(data[0:2], "big", signed=True) cdef inline int32_t read_int32(data: bytes): return int.from_bytes(data[0:4], "big", signed=True) def new_pg_connection(server, sslctx, endpoint_security, connection_made_at): return PgConnection( server, sslctx, endpoint_security, passive=False, transport=srvargs.ServerConnTransport.TCP_PG, external_auth=False, connection_made_at=connection_made_at, ) ================================================ FILE: edb/server/protocol/protocol.pxd ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2021-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from edb.server.protocol cimport binary cdef class HttpRequest: cdef: public object url public bytes version public bint should_keep_alive public bytes content_type public bytes method public bytes accept public bytes body public bytes host public bytes origin public bytes authorization public object params public object forwarded public object cookies cdef class HttpResponse: cdef: public object status public bint close_connection public bytes content_type public dict custom_headers public bytes body public bint sent cdef class HttpProtocol: cdef public object server cdef: object loop object parser object transport object unprocessed object sslctx object sslctx_pgext bint in_response bint first_data_call bint external_auth bint respond_hsts bint is_tls object binary_endpoint_security object http_endpoint_security object tenant bint is_tenant_host object connection_made_at HttpRequest current_request cdef _not_found(self, HttpRequest request, HttpResponse response, str message = ?) cdef _bad_request(self, HttpRequest request, HttpResponse response, str message) cdef _unauthorized(self, HttpRequest request, HttpResponse response, str message) cdef _return_binary_error(self, binary.EdgeConnection proto) cdef _write(self, bytes req_version, bytes resp_status, bytes content_type, dict custom_headers, bytes body, bint close_connection) cpdef write(self, HttpRequest request, HttpResponse response) cdef unhandled_exception(self, bytes status, ex) cdef resume(self) cpdef close(self) cdef inline _schedule_handle_request(self, request) cdef inline _close_with_error(self, bytes status, bytes message) ================================================ FILE: edb/server/protocol/protocol.pyi ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2019-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import asyncio import http import http.cookies import httptools import ssl from edb.server import args as srvargs from edb.server import server class HttpRequest: url: httptools.URL version: bytes should_keep_alive: bool content_type: bytes method: bytes accept: bytes body: bytes host: bytes origin: bytes authorization: bytes params: dict[bytes, bytes] forwarded: dict[bytes, bytes] cookies: http.cookies.SimpleCookie class HttpResponse: status: http.HTTPStatus close_connection: bool content_type: bytes custom_headers: dict[str, str] body: bytes sent: bool class HttpProtocol(asyncio.Protocol): def __init__( self, server: server.BaseServer, sslctx: ssl.SSLContext, sslctx_pgext: ssl.SSLContext, *, external_auth: bool = False, binary_endpoint_security: srvargs.ServerEndpointSecurityMode, http_endpoint_security: srvargs.ServerEndpointSecurityMode, ) -> None: ... def write_raw(self, data: bytes) -> None: ... def write(self, request: HttpRequest, response: HttpResponse) -> None: ... def close(self) -> None: ... ================================================ FILE: edb/server/protocol/protocol.pyx ================================================ # This source file is part of the EdgeDB open source project. # # Copyright 2021-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # include "./consts.pxi" import asyncio import collections import http import http.cookies import re import ssl import time import urllib.parse import httptools from edb import errors from edb.common import debug from edb.common import markup from edb.common.log import current_tenant from edb.graphql import extension as graphql_ext from edb.server import args as srvargs from edb.server import config, metrics as srv_metrics from edb.server import tenant as edbtenant from edb.server.protocol cimport binary from edb.server.protocol import binary from edb.server.protocol import pg_ext from edb.server import defines as edbdef from edb.server.dbview cimport dbview # Without an explicit cimport of `pgproto.debug`, we # can't cimport `protocol.binary` for some reason. from edb.server.pgproto.debug cimport PG_DEBUG from . import auth from . cimport auth_helpers from . import edgeql_ext from . import metrics from . import server_info from . import notebook_ext from . import system_api from . import ui_ext from . import auth_ext from . import ai_ext HTTPStatus = http.HTTPStatus PROTO_MIME = ( f'application/x.edgedb.' 
f'v_{edbdef.CURRENT_PROTOCOL[0]}_{edbdef.CURRENT_PROTOCOL[1]}' f'.binary' ).encode() PROTO_MIME_RE = re.compile(br'application/x\.edgedb\.v_(\d+)_(\d+)\.binary') cdef class HttpRequest: def __cinit__(self): self.body = b'' self.authorization = b'' self.content_type = b'' self.forwarded = {} self.cookies = http.cookies.SimpleCookie() cdef class HttpResponse: def __cinit__(self): self.status = HTTPStatus.OK self.content_type = b'text/plain' self.custom_headers = {} self.body = b'' self.close_connection = False self.sent = False cdef class HttpProtocol: def __init__( self, server, sslctx, sslctx_pgext, *, external_auth: bool=False, binary_endpoint_security = None, http_endpoint_security = None, ): self.loop = server.get_loop() self.server = server self.tenant = None self.transport = None self.external_auth = external_auth self.sslctx = sslctx self.sslctx_pgext = sslctx_pgext self.parser = None self.current_request = None self.in_response = False self.unprocessed = None self.first_data_call = True self.binary_endpoint_security = binary_endpoint_security self.http_endpoint_security = http_endpoint_security self.respond_hsts = False # redirect non-TLS HTTP clients to TLS URL self.is_tls = False self.is_tenant_host = False def connection_made(self, transport): self.connection_made_at = time.monotonic() self.transport = transport def connection_lost(self, exc): srv_metrics.client_connection_duration.observe( time.monotonic() - self.connection_made_at, self.get_tenant_label(), "http", ) self.transport = None self.unprocessed = None self.server.maybe_auto_shutdown() def get_tenant_label(self): if self.tenant is None: return "unknown" else: return self.tenant.get_instance_name() def pause_writing(self): pass def resume_writing(self): pass def eof_received(self): pass def data_received(self, data): if self.first_data_call: self.first_data_call = False # Detect if the client is speaking TLS in the "first" data using # the SSL library. 
This is not the official handshake as we only # need to know "is_tls"; the first data is used again for the true # handshake if is_tls = True. This lets us respond later with a helpful # error message to non-TLS clients. is_tls = True try: outgoing = ssl.MemoryBIO() incoming = ssl.MemoryBIO() incoming.write(data) sslobj = self.sslctx.wrap_bio( incoming, outgoing, server_side=True ) sslobj.do_handshake() except ssl.SSLWantReadError: pass except ssl.SSLError: is_tls = False self.is_tls = is_tls if is_tls: # Most clients should arrive here to continue with TLS self.transport.pause_reading() self.loop.create_task(self._forward_first_data(data)) self.loop.create_task(self._start_tls()) return # When we're talking to a non-TLS client, keep using the # legacy magic byte check to choose the HTTP or binary protocol. if data[0:2] == b'V\x00': # This is, most likely, our binary protocol, # as its first message kind is `V`. # # Switch protocols now (for compatibility). if ( self.binary_endpoint_security is srvargs.ServerEndpointSecurityMode.Optional ): self._switch_to_binary_protocol(data) else: self._return_binary_error( self._switch_to_binary_protocol() ) return elif data[0:1] == b'\x00': # Postgres protocol, assuming the 1st message is less than 16MB pg_ext_conn = pg_ext.new_pg_connection( self.server, self.sslctx_pgext, self.binary_endpoint_security, connection_made_at=self.connection_made_at, ) self.transport.set_protocol(pg_ext_conn) pg_ext_conn.connection_made(self.transport) pg_ext_conn.data_received(data) return else: # HTTP.
self._init_http_parser() self.respond_hsts = ( self.http_endpoint_security is srvargs.ServerEndpointSecurityMode.Tls ) try: self.parser.feed_data(data) except Exception as ex: self.unhandled_exception(b'400 Bad Request', ex) def on_url(self, url: bytes): self.current_request.url = httptools.parse_url(url) def on_header(self, name: bytes, value: bytes): name = name.lower() if name == b'content-type': self.current_request.content_type = value elif name == b'host': self.current_request.host = value elif name == b'origin': self.current_request.origin = value elif name == b'accept': if self.current_request.accept: self.current_request.accept += b',' + value else: self.current_request.accept = value elif name == b'authorization': self.current_request.authorization = value elif name.startswith(b'x-edgedb-'): if self.current_request.params is None: self.current_request.params = {} param = name[len(b'x-edgedb-'):] self.current_request.params[param] = value elif name.startswith(b'x-gel-'): if self.current_request.params is None: self.current_request.params = {} param = name[len(b'x-gel-'):] self.current_request.params[param] = value elif name.startswith(b'x-forwarded-'): if self.current_request.forwarded is None: self.current_request.forwarded = {} forwarded_key = name[len(b'x-forwarded-'):] self.current_request.forwarded[forwarded_key] = value elif name == b'cookie': self.current_request.cookies.load(value.decode('ascii')) def on_body(self, body: bytes): self.current_request.body += body def on_message_begin(self): self.current_request = HttpRequest() def on_message_complete(self): self.transport.pause_reading() req = self.current_request req.version = self.parser.get_http_version().encode() req.should_keep_alive = self.parser.should_keep_alive() req.method = self.parser.get_method().upper() if self.in_response: # pipelining support if self.unprocessed is None: self.unprocessed = collections.deque() self.unprocessed.append(req) else: self.in_response = True 
self._schedule_handle_request(req) self.server._http_last_minute_requests += 1 cdef inline _schedule_handle_request(self, request): if self.tenant is None: self.loop.create_task(self._handle_request(request)) elif self.tenant.is_accepting_connections(): self.tenant.create_task( self._handle_request(request), interruptable=False ) else: self._close_with_error( b'503 Service Unavailable', b'The server is closing.', ) cpdef close(self): if self.transport is not None: self.transport.close() self.transport = None self.unprocessed = None cdef unhandled_exception(self, bytes status, ex): if debug.flags.server: markup.dump(ex) self._close_with_error( status, f'{type(ex).__name__}: {ex}'.encode(), ) cdef inline _close_with_error(self, bytes status, bytes message): self._write( b'1.0', status, b'text/plain', {}, message, True) self.close() cdef resume(self): if self.transport is None: return if self.unprocessed: req = self.unprocessed.popleft() self._schedule_handle_request(req) else: self.transport.resume_reading() cdef _write(self, bytes req_version, bytes resp_status, bytes content_type, dict custom_headers, bytes body, bint close_connection): if self.transport is None: return data = [ b'HTTP/', req_version, b' ', resp_status, b'\r\n', b'Content-Type: ', content_type, b'\r\n', ] if content_type != b"text/event-stream": data.extend( (b'Content-Length: ', f'{len(body)}'.encode(), b'\r\n'), ) for key, value in custom_headers.items(): data.append(f'{key}: {value}\r\n'.encode()) if close_connection: data.append(b'Connection: close\r\n') data.append(b'\r\n') if body: data.append(body) self.transport.write(b''.join(data)) cpdef write(self, HttpRequest request, HttpResponse response): assert type(response.status) is HTTPStatus self._write( request.version, f'{response.status.value} {response.status.phrase}'.encode(), response.content_type, response.custom_headers, response.body, response.close_connection or not request.should_keep_alive) response.sent = True def write_raw(self, 
bytes data): self.transport.write(data) def _switch_to_binary_protocol(self, data=None): binproto = binary.new_edge_connection( self.server, self.tenant, external_auth=self.external_auth, connection_made_at=self.connection_made_at, ) self.transport.set_protocol(binproto) binproto.connection_made(self.transport) if data: binproto.data_received(data) return binproto def _init_http_parser(self): self.parser = httptools.HttpRequestParser(self) self.current_request = HttpRequest() async def _forward_first_data(self, data): # As we stole the "first data", we need to manually send it back to # the SSLProtocol. The hack here is highly-coupled with uvloop impl. transport = self.transport # The TCP transport for i in range(3): await asyncio.sleep(0) ssl_proto = self.transport.get_protocol() if ssl_proto is not self: break else: raise RuntimeError("start_tls() hasn't run in 3 loop iterations") # Delay for one more iteration to make sure the first data is fed after # SSLProtocol.connection_made() is called. await asyncio.sleep(0) data_len = len(data) buf = ssl_proto.get_buffer(data_len) buf[:data_len] = data ssl_proto.buffer_updated(data_len) async def _start_tls(self): self.transport = await self.loop.start_tls( self.transport, self, self.sslctx, server_side=True ) sslobj = self.transport.get_extra_info('ssl_object') tenant = self.server.retrieve_tenant(sslobj) if tenant is edbtenant.host_tenant: tenant = None self.is_tenant_host = True self.tenant = tenant if self.tenant is not None: current_tenant.set(self.get_tenant_label()) if sslobj.selected_alpn_protocol() == 'edgedb-binary': self._switch_to_binary_protocol() else: # It's either HTTP as the negotiated protocol, or the negotiation # failed and we have no idea what ALPN the client has set. Here we # just start talking in HTTP, and let the client bindings decide if # this is an error based on the ALPN result. 
self._init_http_parser() cdef _return_binary_error(self, binary.EdgeConnection proto): proto.write_error(errors.BinaryProtocolError( 'TLS Required', details='The server requires Transport Layer Security (TLS)', hint='Upgrade the client to a newer version that supports TLS' )) proto.close() async def _handle_request(self, HttpRequest request): cdef: HttpResponse response = HttpResponse() if self.transport is None: return if self.respond_hsts: if request.host: path = request.url.path.lstrip(b'/') loc = b'https://' + request.host + b'/' + path self.transport.write( b'HTTP/1.1 301 Moved Permanently\r\n' b'Strict-Transport-Security: max-age=31536000\r\n' b'Location: ' + loc + b'\r\n' b'\r\n' ) else: msg = b'Request is missing a header: Host\r\n' self.transport.write( b'HTTP/1.1 400 Bad Request\r\n' b'Content-Length: ' + str(len(msg)).encode() + b'\r\n' b'\r\n' + msg ) self.close() return if self.is_tls: if ( self.http_endpoint_security is srvargs.ServerEndpointSecurityMode.Optional ): response.custom_headers['Strict-Transport-Security'] = \ 'max-age=0' elif ( self.http_endpoint_security is srvargs.ServerEndpointSecurityMode.Tls ): response.custom_headers['Strict-Transport-Security'] = \ 'max-age=31536000' else: raise AssertionError( f"unexpected http_endpoint_security " f"value: {self.http_endpoint_security}" ) try: await self.handle_request(request, response) except errors.AvailabilityError as ex: self._close_with_error( b"503 Service Unavailable", f'{type(ex).__name__}: {ex}'.encode(), ) return except Exception as ex: self.unhandled_exception(b"500 Internal Server Error", ex) return if not response.sent: self.write(request, response) self.in_response = False if response.close_connection or not request.should_keep_alive: self.close() else: self.resume() def check_readiness(self): if self.tenant.is_blocked(): readiness_reason = self.tenant.get_readiness_reason() msg = "the server is not accepting requests" if readiness_reason: msg = f"{msg}: {readiness_reason}" raise 
errors.ServerBlockedError(msg) elif not self.tenant.is_online(): readiness_reason = self.tenant.get_readiness_reason() msg = "the server is going offline" if readiness_reason: msg = f"{msg}: {readiness_reason}" raise errors.ServerOfflineError(msg) async def handle_request(self, HttpRequest request, HttpResponse response): request_url = get_request_url(request, self.is_tls) path = request_url.path.decode('ascii') path = path.strip('/') path_parts = path.split('/') path_parts_len = len(path_parts) route = path_parts[0] if self.tenant is None and route in ['db', 'auth', 'branch']: self.tenant = self.server.get_default_tenant() self.check_readiness() if self.tenant.is_accepting_connections(): return await self.tenant.create_task( self.handle_request(request, response), interruptable=False, ) else: return self._close_with_error( b'503 Service Unavailable', b'The server is closing.', ) if route in ['db', 'branch']: if path_parts_len < 2: return self._not_found(request, response) dbname = urllib.parse.unquote(path_parts[1]) dbname = self.tenant.resolve_branch_name( database=dbname if route == 'db' else None, branch=dbname if route == 'branch' else None, ) extname = path_parts[2] if path_parts_len > 2 else None # Binary proto tunnelled through HTTP if extname is None: if await self._handle_cors( request, response, dbname=dbname, allow_methods=['POST'], allow_headers=[ 'Authorization', 'X-EdgeDB-User', 'X-Gel-User' ], ): return if request.method == b'POST': if not request.content_type: return self._bad_request( request, response, message="missing or malformed Content-Type header", ) ver_m = PROTO_MIME_RE.match(request.content_type) if not ver_m: return self._bad_request( request, response, message="missing or malformed Content-Type header", ) proto_ver = ( int(ver_m.group(1).decode()), int(ver_m.group(2).decode()), ) if proto_ver < edbdef.MIN_PROTOCOL: return self._bad_request( request, response, message="requested protocol version is too old and " "no longer supported", ) 
if proto_ver > edbdef.CURRENT_PROTOCOL: proto_ver = edbdef.CURRENT_PROTOCOL params = request.params if params is None: conn_params = {} else: conn_params = { n.decode("utf-8"): v.decode("utf-8") for n, v in request.params.items() } conn_params["database"] = dbname response.body = await binary.eval_buffer( self.server, self.tenant, database=dbname, data=self.current_request.body, conn_params=conn_params, protocol_version=proto_ver, auth_data=self.current_request.authorization, transport=srvargs.ServerConnTransport.HTTP, tcp_transport=self.transport, ) response.status = http.HTTPStatus.OK response.content_type = ( f'application/x.edgedb.v_' f'{proto_ver[0]}_{proto_ver[1]}.binary' ).encode() response.close_connection = True else: if await self._handle_cors( request, response, dbname=dbname, allow_methods=['GET', 'POST'], allow_headers=[ 'Authorization', 'X-EdgeDB-User', 'X-Gel-User' ], expose_headers=( ['EdgeDB-Protocol-Version', 'Gel-Protocol-Version'] if extname == 'notebook' else ['WWW-Authenticate'] if extname != 'auth' else None ), allow_credentials=True ): return # Check if this is a request to a registered extension if extname == 'edgeql': extname = 'edgeql_http' if extname == 'ext': if path_parts_len < 4: return self._not_found(request, response) extname = path_parts[3] args = path_parts[4:] else: args = path_parts[3:] role_name = None if extname != 'auth': role_name = await self._check_http_auth( request, response, dbname ) if not role_name: return db = self.tenant.maybe_get_db(dbname=dbname) if db is None: return self._not_found(request, response) if extname not in db.extensions: return self._not_found(request, response) if extname == 'graphql': await graphql_ext.handle_request( request, response, db, role_name, args, self.tenant ) elif extname == 'notebook': await notebook_ext.handle_request( request, response, db, role_name, args, self.tenant ) elif extname == 'edgeql_http': await edgeql_ext.handle_request( request, response, db, role_name, args, 
self.tenant ) elif extname == 'ai': await ai_ext.handle_request( self, request, response, db, role_name, args, self.tenant ) elif extname == 'auth': netloc = ( f"{request_url.host.decode()}:{request_url.port}" if request_url.port else request_url.host.decode() ) ext_base_path = f"{request_url.schema.decode()}://" \ f"{netloc}/{route}/" \ f"{urllib.parse.quote(dbname)}/ext/auth" handler = auth_ext.http.Router( db=db, base_path=ext_base_path, tenant=self.tenant, ) await handler.handle_request(request, response, args) if args: if args[0] == 'ui': if not (len(args) > 1 and args[1] == "_static"): srv_metrics.auth_ui_renders.inc( 1.0, self.get_tenant_label() ) else: srv_metrics.auth_api_calls.inc( 1.0, self.get_tenant_label() ) else: return self._not_found(request, response) elif route == 'auth': if await self._handle_cors( request, response, allow_methods=['GET'], allow_headers=['Authorization'], expose_headers=['WWW-Authenticate', 'Authentication-Info'] ): return # Authentication request await auth.handle_request( request, response, path_parts[1:], self.tenant, ) elif route == 'server': if not await self._authenticate_for_default_conn_transport( request, response, srvargs.ServerConnTransport.HTTP_HEALTH, ): return # System API request await system_api.handle_request( request, response, path_parts[1:], self.server, self.tenant, is_tenant_host=self.is_tenant_host, ) elif path_parts == ['metrics'] and request.method == b'GET': if not await self._authenticate_for_default_conn_transport( request, response, srvargs.ServerConnTransport.HTTP_METRICS, ): return self.server.get_compiler_pool().refresh_metrics() # Quoting the Open Metrics spec: # Implementers MUST expose metrics in the OpenMetrics # text format in response to a simple HTTP GET request # to a documented URL for a given process or device. # This endpoint SHOULD be called "/metrics". 
await metrics.handle_request( request, response, self.tenant, ) elif (path_parts == ['server-info'] and request.method == b'GET' and (self.server.in_dev_mode() or self.server.in_test_mode()) ): await server_info.handle_request( request, response, self.server, ) elif path_parts[0] == 'ui': if not self.server.is_admin_ui_enabled(): return self._not_found( request, response, "Admin UI is not enabled on this EdgeDB instance. " "Run the server with --admin-ui=enabled " "(or EDGEDB_SERVER_ADMIN_UI=enabled) to enable." ) else: await ui_ext.handle_request( request, response, path_parts[1:], self.server, ) else: return self._not_found(request, response) cdef _not_found( self, HttpRequest request, HttpResponse response, str message = "Unknown path", ): response.body = message.encode("utf-8") response.status = http.HTTPStatus.NOT_FOUND response.close_connection = True cdef _bad_request( self, HttpRequest request, HttpResponse response, str message, ): response.body = message.encode("utf-8") response.status = http.HTTPStatus.BAD_REQUEST response.close_connection = True async def _handle_cors( self, HttpRequest request, HttpResponse response, *, str dbname = None, list allow_methods = None, list allow_headers = [], list expose_headers = None, bint allow_credentials = False ): db = self.tenant.maybe_get_db(dbname=dbname) if dbname else None config = None if db is not None: if db.db_config is None: await db.introspection() config = db.db_config.get('cors_allow_origins') if config is None: config = self.tenant.get_sys_config().get('cors_allow_origins') allowed_origins = config.value if config else None overrides = self.server.get_cors_always_allowed_origins() if allowed_origins is None and overrides == []: return False origin = request.origin.decode() if request.origin else None origin_allowed = origin is not None and ( any( override.match(origin) if isinstance(override, re.Pattern) else origin == override for override in overrides ) or (origin in allowed_origins or '*' in 
allowed_origins) ) if origin_allowed: response.custom_headers['Access-Control-Allow-Origin'] = origin if expose_headers is not None: response.custom_headers['Access-Control-Expose-Headers'] = ( ', '.join(expose_headers)) if request.method == b'OPTIONS': response.status = http.HTTPStatus.NO_CONTENT if origin_allowed: if allow_methods is not None: response.custom_headers['Access-Control-Allow-Methods'] = ( ', '.join(allow_methods)) response.custom_headers['Access-Control-Allow-Headers'] = ( ', '.join(['Content-Type'] + allow_headers)) if allow_credentials: response.custom_headers['Access-Control-Allow-Credentials'] = ( 'true') return True return False cdef _unauthorized( self, HttpRequest request, HttpResponse response, str message, ): response.body = message.encode("utf-8") response.status = http.HTTPStatus.UNAUTHORIZED response.close_connection = True async def _check_http_auth( self, HttpRequest request, HttpResponse response, str dbname, ): dbindex: dbview.DatabaseIndex = self.tenant._dbindex scheme = None try: # Extract the username from the relevant request headers scheme, auth_payload = auth_helpers.extract_token_from_auth_data( request.authorization) username, opt_password = auth_helpers.extract_http_user( scheme, auth_payload, request.params) username = self.tenant.resolve_user_name(username) # Fetch the configured auth methods authmethods = await self.tenant.get_auth_methods( username, srvargs.ServerConnTransport.SIMPLE_HTTP) auth_errors = {} for authmethod in authmethods: authmethod_name = authmethod._tspec.name.split('::')[1] try: # If the auth method and the provided auth information # match, try to resolve the authentication. 
if authmethod_name == 'JWT' and scheme == 'bearer': auth_helpers.auth_jwt( self.tenant, auth_payload, username, dbname) elif authmethod_name == 'Password' and scheme == 'basic': auth_helpers.auth_basic( self.tenant, username, opt_password) elif authmethod_name == 'Trust': pass elif authmethod_name == 'SCRAM': raise errors.AuthenticationError( 'authentication failed: ' 'SCRAM authentication required but not ' 'supported for HTTP' ) elif authmethod_name == 'mTLS': if ( self.http_endpoint_security is srvargs.ServerEndpointSecurityMode.Tls or self.is_tls ): auth_helpers.auth_mtls_with_user( self.transport, username) else: raise errors.AuthenticationError( 'authentication failed: wrong method used') except errors.AuthenticationError as e: auth_errors[authmethod_name] = e else: break if len(auth_errors) == len(authmethods): if len(auth_errors) > 1: desc = "; ".join( f"{k}: {e.args[0]}" for k, e in auth_errors.items()) raise errors.AuthenticationError( f"all authentication methods failed: {desc}") else: raise next(iter(auth_errors.values())) role = self.tenant.get_roles().get(username) if not role: raise errors.AuthenticationError('authentication failed') branches = role['branches'] if '*' not in branches and dbname not in branches: raise errors.AuthenticationError( f"authentication failed: user does not have permission for " f"database branch '{dbname}'" ) except Exception as ex: if debug.flags.server: markup.dump(ex) self._unauthorized(request, response, str(ex)) # If no scheme was specified, add a WWW-Authenticate header if scheme == '': response.custom_headers['WWW-Authenticate'] = ( 'Basic realm="edgedb", Bearer' ) return None return username async def _authenticate_for_default_conn_transport( self, HttpRequest request, HttpResponse response, transport: srvargs.ServerConnTransport, ): try: auth_methods = self.server.get_default_auth_methods(transport) auth_errors = {} for auth_method in auth_methods: authmethod_name = auth_method._tspec.name.split('::')[1] try: # If 
the auth method and the provided auth information # match, try to resolve the authentication. if authmethod_name == 'Trust': pass elif authmethod_name == 'mTLS': if ( self.http_endpoint_security is srvargs.ServerEndpointSecurityMode.Tls or self.is_tls ): auth_helpers.auth_mtls(self.transport) else: raise errors.AuthenticationError( 'authentication failed: wrong method used') except errors.AuthenticationError as e: auth_errors[authmethod_name] = e else: break if len(auth_errors) == len(auth_methods): if len(auth_errors) > 1: desc = "; ".join( f"{k}: {e.args[0]}" for k, e in auth_errors.items()) raise errors.AuthenticationError( f"all authentication methods failed: {desc}") else: raise next(iter(auth_errors.values())) except Exception as ex: if debug.flags.server: markup.dump(ex) self._unauthorized(request, response, str(ex)) return False return True def get_request_url(request, is_tls): request_url = request.url default_schema = b"https" if is_tls else b"http" if all( getattr(request_url, attr) is None for attr in ('schema', 'host', 'port') ): forwarded = request.forwarded if hasattr(request, 'forwarded') else {} schema = forwarded.get(b'proto', default_schema).decode() host_header = forwarded.get(b'host', request.host).decode() host, _, port = host_header.partition(':') path = request_url.path.decode() new_url = f"{schema}://"\ f"{host}{port and ':' + port}"\ f"{path}" request_url = httptools.parse_url(new_url.encode()) return request_url ================================================ FILE: edb/server/protocol/request_scheduler.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from dataclasses import dataclass, field from typing import ( Final, Iterable, Literal, Optional, Sequence, ) import abc import asyncio import copy import random @dataclass class Timer: """Represents a time after which an action should be taken. Examples: (None, True) = execute immediately (None, False) = execute any time (10, True) = execute immediately after 10s (10, False) = execute any time after 10s """ time: Optional[float] = None # Whether the action should be taken as soon as possible after the time. urgent: bool = True @staticmethod def create_delay(delay: Optional[float], urgent: bool) -> Timer: now = asyncio.get_running_loop().time() if delay is None: time = None else: time = now + delay return Timer(time, urgent) def is_ready(self) -> bool: now = asyncio.get_running_loop().time() return self.time is None or self.time <= now def is_ready_and_urgent(self) -> bool: return self.is_ready() and self.urgent def remaining_time(self, max_delay: float) -> float: """How long before this timer is ready in seconds.""" if self.urgent: if self.time is None: return 0 else: # 1ms extra, just in case now = asyncio.get_running_loop().time() delay = self.time - now + 0.001 return min(max(0, delay), max_delay) else: # If not urgent, wait as long as possible return max_delay @staticmethod def combine(timers: Iterable[Timer]) -> Optional[Timer]: """Combine the timers to determine when to take the next action.
If the timers are (1, False), (2, False), (3, True), it may be wasteful to act at times [1, 2, 3]. Instead, we would prefer to act only once, at time 3, since only the third action was urgent. """ for target_urgency in [True, False]: if any( timer.time is None and timer.urgent == target_urgency for timer in timers ): # An action should be taken right away. return Timer(None, target_urgency) urgent_times = [ timer.time for timer in timers if timer.time is not None and timer.urgent == target_urgency ] if len(urgent_times) > 0: # An action should be taken after some delay return Timer(min(urgent_times), target_urgency) # Nothing to do return None def _default_delay_time() -> Timer: return Timer() @dataclass class Scheduler[_T](abc.ABC): """A scheduler for requests to an asynchronous service. A Scheduler both generates requests and tracks when the service can be accessed. """ service: Service # The next time to process requests timer: Timer = field(default_factory=_default_delay_time) @abc.abstractmethod async def get_params( self, context: Context, ) -> Optional[Sequence[Params[_T]]]: """Get parameters for the requests to run.""" raise NotImplementedError async def process(self, context: Context) -> bool: if not self.timer.is_ready(): return False request_params: Optional[Sequence[Params[_T]]] try: request_params = await self.get_params(context) except Exception: request_params = None error_count = 0 deferred_costs: dict[str, int] = { limit_name: 0 for limit_name in self.service.limits } success_count = 0 if request_params is None: error_count = 1 elif len(request_params) > 0: try: execution_report = await execute_no_sleep( request_params, service=self.service, ) except Exception: execution_report = ExecutionReport(unknown_error_count=1) assert isinstance(execution_report, ExecutionReport) self.finalize(execution_report) # Cache limits for next time if execution_report.updated_limits is not None: for limit_name, service_limit in self.service.limits.items(): if 
limit_name not in execution_report.updated_limits: continue updated_limit = execution_report.updated_limits[limit_name] if service_limit is not None: service_limit.update_total(updated_limit) service_limit.delay_factor = updated_limit.delay_factor else: self.service.limits[limit_name] = updated_limit # Update counts error_count = ( len(execution_report.known_error_messages) + execution_report.unknown_error_count ) deferred_costs = execution_report.deferred_costs success_count = execution_report.success_count # Update when this service should be processed again self.timer = self.service.next_delay( success_count, deferred_costs, error_count, context.naptime ) return True @abc.abstractmethod def finalize(self, execution_report: ExecutionReport) -> None: """An optional final step after executing requests""" pass @dataclass class Context: """Information passed to a Scheduler to process requests.""" # If there is no work, the scheduler should take a nap. naptime: float @dataclass class ExecutionReport: """Information about the requests after they are complete""" success_count: int = 0 unknown_error_count: int = 0 known_error_messages: list[str] = field(default_factory=list) deferred_costs: dict[str, int] = field(default_factory=dict) # Some requests may report an update to the service's rate limits. updated_limits: dict[str, Limits] = field(default_factory=dict) @dataclass class Service: """Information on how to access a given service.""" # Information about the service's rate limits # Even if no Limit is available, the presence of a key indicates that a # limit is used at least sometimes.
limits: dict[str, Optional[Limits]] = field(default_factory=dict) # The maximum number of times to retry requests max_retry_count: Final[int] = 4 # Whether to jitter the delay time if a retry error is produced jitter: Final[bool] = True # Initial guess for the delay guess_delay: Final[float] = 1.0 # The upper bound for delays delay_max: Final[float] = 60.0 def next_delay( self, success_count: int, deferred_costs: dict[str, int], error_count: int, naptime: float ) -> Timer: """When the service should be processed again.""" if self.limits is not None: # Find the limit with the largest delay limit_delays: dict[str, Optional[float]] = {} for limit_names, service_limit in self.limits.items(): if service_limit is None: # If no information is available, assume no limits limit_delays[limit_names] = None else: base_delay = service_limit.base_delay( deferred_costs[limit_names], guess=self.guess_delay, ) if base_delay is None: limit_delays[limit_names] = None else: # If delay_factor is very high, it may take quite a long # time for it to return to 1. A maximum delay prevents # this service from never getting checked. limit_delays[limit_names] = ( min( base_delay * service_limit.delay_factor, self.delay_max, ) ) delay = _get_maximum_delay(limit_delays) else: # We have absolutely no information about the delay, assume naptime. delay = naptime if error_count > 0: # There was an error, wait before trying again. # Use the larger of delay or naptime. delay = max(delay, naptime) if delay is not None else naptime urgent = False elif any( deferred_cost > 0 for deferred_cost in deferred_costs.values() ): # There is some deferred work, apply the delay and run immediately. urgent = True elif success_count > 0: # Some work was done successfully. Run again to ensure no more work # needs to be done. delay = None urgent = True else: # No work left to do, wait before trying again. # Use the larger of delay or naptime.
delay = max(delay, naptime) if delay is not None else naptime urgent = False return Timer.create_delay(delay, urgent) @dataclass class Limits: """Information about a service's rate limits.""" # Total limit of a resource per minute for a service. total: Optional[int | Literal['unlimited']] = None # Remaining resources before the limit is hit. # It is assumed to be decreasing during a call to execute_no_sleep. # # This can be set by users before a call to Scheduler.process. # It will also be updated during execution if a response includes an updated # value. # # Finally, it is reset after requests are executed since we don't know when # the next call will be. remaining: Optional[int] = None # A delay factor to implement exponential backoff delay_factor: float = 1 def base_delay( self, request_cost: int, *, guess: float, ) -> Optional[float]: if self.total == 'unlimited': return None if self.remaining is not None and request_cost <= self.remaining: return None if self.total is not None: assert isinstance(self.total, int) return 60.0 / self.total * 1.1 # 10% buffer just in case # guess the delay return guess def update_total(self, latest: Limits) -> Limits: """Update total based on the latest information. The total will change rarely. Always take the latest value if it exists """ if latest.total is not None: self.total = latest.total return self def update_remaining(self, latest: Limits) -> Limits: """Update remaining based on the latest information. The remaining amount is assumed to be decreasing during a call to execute_no_sleep. """ if self.remaining is None: self.remaining = latest.remaining elif latest.remaining is not None: self.remaining = min(self.remaining, latest.remaining) if self.total == 'unlimited' and self.remaining: # If there is a remaining value, the total is not actually # unlimited.
self.total = None return self class Request[_T](abc.ABC): """Represents an async request""" params: Params[_T] _inner: asyncio.Task[Optional[Result[_T]]] def __init__(self, params: Params[_T]): self.params = params self._inner = asyncio.create_task(self.run()) @abc.abstractmethod async def run(self) -> Optional[Result[_T]]: """Run the request and return a result.""" raise NotImplementedError async def wait_result(self) -> None: """Wait for the request to complete.""" await self._inner def get_result(self) -> Optional[Result[_T]]: """Get the result of the request.""" result = self._inner.result() return result class Params[_T](abc.ABC): """The parameters of an async request. These are used to generate requests. A single Params instance may be used to generate multiple Request instances if it fails, but can be retried right away. """ @abc.abstractmethod def costs(self) -> dict[str, int]: """Expected cost to execute the request. Keys must match service rate limits.""" raise NotImplementedError @abc.abstractmethod def create_request(self) -> Request[_T]: """Create a request with the parameters.""" raise NotImplementedError @dataclass(frozen=True) class Result[_T](abc.ABC): """The result of an async request.""" data: _T | Error # Some services can return request limits along with their usual results. # Keys should be a subset of service limits. limits: dict[str, Limits] = field(default_factory=dict) @abc.abstractmethod async def finalize(self) -> None: """An optional finalize to be run sequentially.""" pass @dataclass(frozen=True) class Error: """Represents an error from an async request.""" message: str # If there was an error, it may be possible to retry the request # Eg. 
429 too many requests retry: bool async def execute_no_sleep[_T]( params: Sequence[Params[_T]], *, service: Service, ) -> ExecutionReport: """Attempt to execute as many requests as possible without sleeping.""" report = ExecutionReport() # Set up limits execute_limits: dict[str, Limits] = { limit_name: ( # If no other information is available, for the first attempt assume # there is no limit. Limits(total='unlimited') if service_limit is None else copy.copy(service_limit) ) for limit_name, service_limit in service.limits.items() } # If any requests fail and can be retried, retry them up to a maximum number # of times. retry_count: int = 0 # If the costs are larger than a total limit, set aside the excess to be # processed later. # This prevents wasting resources, and allows the delays to increase # specifically when an unexpected deferral happens. pending_request_indexes: list[int] excess_request_indexes: list[int] initial_pending_cost = { limit_name: 0 for limit_name in service.limits.keys() } for request_index in range(len(params)): for limit_name, cost in params[request_index].costs().items(): initial_pending_cost[limit_name] += cost if ( # If the pending cost exceeds a known limit, set aside some # requests. 
any( limit.total < initial_pending_cost[limit_name] for limit_name, limit in service.limits.items() if limit is not None and isinstance(limit.total, int) ) # Always include at least 1 request and request_index != 0 ): pending_request_indexes = ( list(range(request_index)) ) excess_request_indexes = ( list(range(request_index, len(params))) ) break else: # All inputs can be processed pending_request_indexes = list(range(len(params))) excess_request_indexes = [] while pending_request_indexes and retry_count < service.max_retry_count: # Find the highest delay required by any of the service's limits limit_base_delays = _get_limit_base_delays( params, execute_limits, pending_request_indexes, service.guess_delay ) base_delay = _get_maximum_delay(limit_base_delays) active_request_indexes: list[int] inactive_request_indexes: list[int] if base_delay is None: # Try to execute all requests. active_request_indexes = pending_request_indexes inactive_request_indexes = [] elif retry_count == 0: # If there is any delay, only execute one request. # This may update the remaining limit, allowing the remaining # requests to run. 
active_request_indexes = pending_request_indexes[:1] inactive_request_indexes = pending_request_indexes[1:] else: break results = await _execute_specified( params, active_request_indexes, ) # Check results retry_request_indexes: list[int] = [] for request_index in active_request_indexes: if request_index not in results: report.unknown_error_count += 1 continue result = results[request_index] if isinstance(result.data, Error): if result.data.retry: # requests can be retried retry_request_indexes.append(request_index) else: # error with message report.known_error_messages.append(result.data.message) else: report.success_count += 1 await result.finalize() if result.limits is not None: for limit_name, execute_limit in execute_limits.items(): if limit_name not in result.limits: continue result_limit = result.limits[limit_name] execute_limit.update_total(result_limit) execute_limit.update_remaining(result_limit) retry_count += 1 pending_request_indexes = ( retry_request_indexes + inactive_request_indexes ) # Determine which limits cause unexpected deferrals and require additional # delays. limit_base_delays = _get_limit_base_delays( params, execute_limits, pending_request_indexes, service.guess_delay ) expected_pending_cost = { limit_name for limit_name in service.limits.keys() if limit_base_delays[limit_name] is not None } if len(expected_pending_cost) == 0: # If requests were deferred, but no limit appears to be the cause, delay # them all just in case. expected_pending_cost = set(service.limits.keys()) # Update deferred costs and any resulting limits. 
report.deferred_costs = { limit_name: 0 for limit_name in service.limits } for limit_name in service.limits.keys(): unexpected_deferred_costs = sum( params[i].costs()[limit_name] for i in pending_request_indexes ) excess_deferred_costs = sum( params[i].costs()[limit_name] for i in excess_request_indexes ) report.deferred_costs[limit_name] = ( unexpected_deferred_costs + excess_deferred_costs ) if ( unexpected_deferred_costs != 0 # If the limit was not a cause of delays, don't increase the delay # factor. and limit_name in expected_pending_cost ): # If there are deferred requests, gradually increase the delay # factor execute_limits[limit_name].delay_factor *= ( 1 + random.random() if service.jitter else 2 ) elif ( len(report.known_error_messages) == 0 and report.unknown_error_count == 0 and excess_deferred_costs == 0 ): # If there are no errors, gradually decrease the delay factor over # time. execute_limits[limit_name].delay_factor = max( 0.95 * execute_limits[limit_name].delay_factor, 1, ) # We don't know when the service will be called again, so just clear the # remaining values for execute_limit in execute_limits.values(): execute_limit.remaining = None # Return the updated request limits report.updated_limits = execute_limits return report async def _execute_specified[_T]( params: Sequence[Params[_T]], indexes: Iterable[int], ) -> dict[int, Result[_T]]: # Send all requests at once. # We are confident that all requests can be handled right away. 
requests: dict[int, Request[_T]] = {} for request_index in indexes: requests[request_index] = params[request_index].create_request() results: dict[int, Result[_T]] = {} for request_index, request in requests.items(): await request.wait_result() result = request.get_result() if result is not None: results[request_index] = result return results def _get_limit_base_delays[_T]( params: Sequence[Params[_T]], limits: dict[str, Limits], request_indexes: Sequence[int], guess_delay: float, ) -> dict[str, Optional[float]]: base_delays = {} for limit_name, limit in limits.items(): pending_limit_cost = sum( params[request_index].costs()[limit_name] for request_index in request_indexes ) base_delays[limit_name] = (limit.base_delay( pending_limit_cost, guess=guess_delay, )) return base_delays def _get_maximum_delay( delays: dict[str, Optional[float]] ) -> Optional[float]: result: Optional[float] = None for delay in delays.values(): if result is None: result = delay elif delay is not None: result = max(result, delay) return result ================================================ FILE: edb/server/protocol/server_info.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2021-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import Any, TYPE_CHECKING import dataclasses import http import json import immutables from edb import errors from edb.ir import statypes from edb.common import debug from edb.common import markup if TYPE_CHECKING: from edb.server import server as edbserver from edb.server.protocol import protocol class ImmutableEncoder(json.JSONEncoder): def default(self, obj: Any) -> Any: if isinstance(obj, (set, frozenset)): return list(obj) if isinstance(obj, immutables.Map): return dict(obj.items()) if dataclasses.is_dataclass(obj) and not isinstance(obj, type): return dataclasses.asdict(obj) if isinstance(obj, statypes.ScalarType): return obj.to_json() if isinstance(obj, statypes.CompositeType): return obj.to_json_value() return super().default(obj) async def handle_request( request: protocol.HttpRequest, response: protocol.HttpResponse, server: edbserver.Server, ) -> None: try: output = ImmutableEncoder().encode(server.get_debug_info()) response.status = http.HTTPStatus.OK response.content_type = b'application/json' response.body = output.encode() response.close_connection = True except Exception as ex: if debug.flags.server: markup.dump(ex) # XXX Fix this when LSP "location" objects are implemented ex_type = errors.InternalServerError _response_error( response, http.HTTPStatus.INTERNAL_SERVER_ERROR, str(ex), ex_type ) def _response_error( response: protocol.HttpResponse, status: http.HTTPStatus, message: str, ex_type: type[errors.EdgeDBError], ) -> None: response.body = ( f'Unexpected error in /server-info.\n\n' f'{ex_type.__name__}: {message}' ).encode() response.status = status response.close_connection = True ================================================ FILE: edb/server/protocol/system_api.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2019-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import TYPE_CHECKING import asyncio import http import json from edb import errors from edb.common import debug from edb.common import markup if TYPE_CHECKING: from edb.server import tenant as edbtenant, server as edbserver from edb.server.protocol import protocol async def handle_request( request: protocol.HttpRequest, response: protocol.HttpResponse, path_parts: list[str], server: edbserver.BaseServer, tenant: edbtenant.Tenant, is_tenant_host: bool, ) -> None: try: if tenant is None: try: tenant = server.get_default_tenant() except Exception: # Multi-tenant server doesn't have default tenant pass if tenant is None and not is_tenant_host: _response( response, http.HTTPStatus.NOT_FOUND, b'"No such tenant configured"', True, ) elif path_parts == ['status', 'ready'] and request.method == b'GET': if tenant is None: await handle_compiler_query(server, response) else: await tenant.create_task( handle_readiness_query(request, response, tenant), interruptable=False, ) elif path_parts == ['status', 'alive'] and request.method == b'GET': if tenant is None: await handle_compiler_query(server, response) else: await tenant.create_task( handle_liveness_query(request, response, tenant), interruptable=False, ) else: _response( response, http.HTTPStatus.NOT_FOUND, b'"Unknown path"', True, ) except errors.BackendUnavailableError as ex: _response_error( response, http.HTTPStatus.SERVICE_UNAVAILABLE, 
str(ex), type(ex) ) except errors.EdgeDBError as ex: if debug.flags.server: markup.dump(ex) _response_error( response, http.HTTPStatus.INTERNAL_SERVER_ERROR, str(ex), type(ex) ) except Exception as ex: if debug.flags.server: markup.dump(ex) # XXX Fix this when LSP "location" objects are implemented ex_type = errors.InternalServerError _response_error( response, http.HTTPStatus.INTERNAL_SERVER_ERROR, str(ex), ex_type ) def _response_error( response: protocol.HttpResponse, status: http.HTTPStatus, message: str, ex_type: type[errors.EdgeDBError], ) -> None: err_dct = { 'message': message, 'type': str(ex_type.__name__), 'code': ex_type.get_code(), } _response(response, status, json.dumps({'error': err_dct}).encode(), True) def _response( response: protocol.HttpResponse, status: http.HTTPStatus, message: bytes, close_connection: bool, ) -> None: response.body = message response.status = status response.content_type = b'application/json' response.close_connection = close_connection def _response_ok(response: protocol.HttpResponse, message: bytes) -> None: _response(response, http.HTTPStatus.OK, message, False) async def _ping( response: protocol.HttpResponse, tenant: edbtenant.Tenant ) -> None: try: async with asyncio.TaskGroup() as tg: ping_backend = tg.create_task(tenant.ping_backend()) ping_compiler = tg.create_task( tenant.server.get_compiler_pool().health_check() ) except *TimeoutError: if isinstance(ping_backend.exception(), TimeoutError): who = "the backend" else: who = "the compiler pool" _response_error( response, http.HTTPStatus.SERVICE_UNAVAILABLE, f"{who} health check timed out", errors.AvailabilityError, ) else: if not ping_backend.result(): _response_error( response, http.HTTPStatus.SERVICE_UNAVAILABLE, "this server is not ready to accept connections", errors.BackendUnavailableError, ) elif not ping_compiler.result(): _response_error( response, http.HTTPStatus.SERVICE_UNAVAILABLE, "The compiler pool is not ready", errors.AvailabilityError, ) else: 
_response_ok(response, b'"OK"') async def handle_compiler_query( server: edbserver.BaseServer, response: protocol.HttpResponse, ) -> None: if await server.get_compiler_pool().health_check(): _response_ok(response, b'"OK"') else: _response_error( response, http.HTTPStatus.SERVICE_UNAVAILABLE, "The compiler pool is not ready", errors.AvailabilityError, ) async def handle_liveness_query( request: protocol.HttpRequest, response: protocol.HttpResponse, tenant: edbtenant.Tenant, ) -> None: await _ping(response, tenant) async def handle_readiness_query( request: protocol.HttpRequest, response: protocol.HttpResponse, tenant: edbtenant.Tenant, ) -> None: if not tenant.is_ready(): _response_error( response, http.HTTPStatus.SERVICE_UNAVAILABLE, "this server is not ready to accept connections", errors.AccessError, ) else: await _ping(response, tenant) ================================================ FILE: edb/server/protocol/ui_ext.pyx ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2019-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import base64 import http import json import urllib.parse import os import mimetypes import immutables from edb import buildmeta from edb import errors from edb.common import debug from edb.common import markup STATIC_FILES_DIR = str(buildmeta.get_shared_data_dir_path() / 'ui') static_files = dict() def cache_assets(): for dirpath, _, filenames in os.walk(STATIC_FILES_DIR): for filename in filenames: fullpath = os.path.join(dirpath, filename) mimetype = mimetypes.guess_type(filename)[0] if mimetype is None: mimetype = 'application/octet-stream' with open(fullpath, 'rb') as f: static_files[os.path.relpath(fullpath, STATIC_FILES_DIR)] = ( f.read(), mimetype.encode() ) async def handle_request( request, response, path_parts, server, ): try: if path_parts == []: path_parts = ['index.html'] data, content_type = static_files.get( os.path.join(*path_parts), static_files['index.html'] ) response.status = http.HTTPStatus.OK response.content_type = content_type response.body = data return except Exception as ex: return handle_error(request, response, ex) def handle_error( request, response, error ): if debug.flags.server: markup.dump(error) response.body = b'Internal Server Error' response.status = http.HTTPStatus.INTERNAL_SERVER_ERROR response.close_connection = True ================================================ FILE: edb/server/rust_async_channel.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2024-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # import asyncio import io import logging from typing import Protocol, Optional, Any, Callable logger = logging.getLogger("edb.server") MAX_BATCH_SIZE = 16 class RustPipeProtocol(Protocol): def _read(self) -> tuple[Any, ...]: ... def _try_read(self) -> Optional[tuple[Any, ...]]: ... def _close_pipe(self) -> None: ... _fd: int class RustAsyncChannel: _buffered_reader: io.BufferedReader _skip_reads: int _closed: asyncio.Event def __init__( self, pipe: RustPipeProtocol, callback: Callable[[tuple[Any, ...]], None], ) -> None: self._closed = asyncio.Event() fd = pipe._fd self._buffered_reader = io.BufferedReader( io.FileIO(fd), buffer_size=MAX_BATCH_SIZE ) self._fd = fd self._pipe = pipe self._callback = callback self._skip_reads = 0 def __del__(self): if not self._closed.is_set(): logger.error(f"RustAsyncChannel {id(self)} was not closed") async def run(self): loop = asyncio.get_running_loop() loop.add_reader(self._fd, self._channel_read) try: await self._closed.wait() finally: loop.remove_reader(self._fd) def close(self): if not self._closed.is_set(): self._pipe._close_pipe() self._buffered_reader.close() self._closed.set() def read_hint(self): while msg := self._pipe._try_read(): self._skip_reads += 1 self._callback(msg) def _channel_read(self) -> None: try: n = len(self._buffered_reader.read1(MAX_BATCH_SIZE)) if n == 0: return if self._skip_reads > n: self._skip_reads -= n return n -= self._skip_reads self._skip_reads = 0 for _ in range(n): msg = self._pipe._read() if msg is None: self.close() return self._callback(msg) except Exception: logger.error( f"Error reading from Rust async channel", exc_info=True ) self.close() ================================================ FILE: edb/server/server.py ================================================ # mypy: check-untyped-defs # # This source file is part of the EdgeDB open source project. 
# # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import re from typing import ( Any, Callable, Optional, Hashable, Iterator, Mapping, Sequence, TYPE_CHECKING, ) import asyncio import collections import ipaddress import itertools import json import logging import os import pathlib import pickle import socket import ssl import stat import time import uuid import immutables from edb import buildmeta from edb import errors from edb.common import devmode from edb.common import lru from edb.common import secretkey from edb.common import windowedsum from edb.common.log import current_tenant from edb.schema import reflection as s_refl from edb.schema import schema as s_schema from edb.server import auth from edb.server import args as srvargs from edb.server import cache from edb.server import config from edb.server import compiler_pool from edb.server import daemon from edb.server import defines from edb.server import instdata from edb.server import protocol from edb.server import net_worker from edb.server import tenant as edbtenant from edb.server.protocol import binary # type: ignore from edb.server.protocol import pg_ext # type: ignore from edb.server.protocol import ui_ext # type: ignore from edb.server.protocol.auth_ext import pkce from edb.server import metrics from edb.server import pgcon from edb.pgsql import patches as pg_patches from . 
import compiler as edbcompiler from .compiler import sertypes if TYPE_CHECKING: import asyncio.base_events from edb.pgsql import params as pgparams from . import bootstrap ADMIN_PLACEHOLDER = "" logger = logging.getLogger('edb.server') log_metrics = logging.getLogger('edb.server.metrics') class StartupError(Exception): pass class BaseServer: _sys_queries: Mapping[str, bytes] _local_intro_query: bytes _global_intro_query: bytes _report_config_typedesc: dict[defines.ProtocolVersion, bytes] _use_monitor_fs: bool _file_watch_handles: list[asyncio.Handle] _std_schema: s_schema.Schema _refl_schema: s_schema.Schema _schema_class_layout: s_refl.SchemaClassLayout _servers: Mapping[str, asyncio.AbstractServer] _testmode: bool # We maintain an OrderedDict of all active client connections. # We use an OrderedDict because it allows to move keys to either # end of the dict. That's used to keep all active client connections # grouped at the right end of the dict. The idea is that we can then # have a periodically run coroutine to GC all inactive connections. # This should be more economical than maintaining a TimerHandle for # every open connection. Also, this way, we can react to the # `session_idle_timeout` config setting changed mid-flight. 
_binary_conns: collections.OrderedDict[binary.EdgeConnection, bool] _pgext_conns: dict[str, pg_ext.PgConnection] _idle_gc_handler: asyncio.TimerHandle | None = None _stmt_cache_size: int | None = None _compiler_pool: compiler_pool.AbstractPool | None compilation_config_serializer: sertypes.CompilationConfigSerializer _http_request_logger: asyncio.Task | None _auth_gc: asyncio.Task | None _net_worker_http: asyncio.Task | None _net_worker_http_gc: asyncio.Task | None def __init__( self, *, runstate_dir: pathlib.Path, internal_runstate_dir: pathlib.Path, compiler_pool_size: int, compiler_worker_branch_limit, compiler_pool_mode: srvargs.CompilerPoolMode, compiler_pool_addr: tuple[str, int], nethosts: Sequence[str], netport: int, compiler_worker_max_rss: Optional[int] = None, listen_sockets: tuple[socket.socket, ...] = (), testmode: bool = False, daemonized: bool = False, pidfile_dir: Optional[pathlib.Path] = None, binary_endpoint_security: srvargs.ServerEndpointSecurityMode = ( srvargs.ServerEndpointSecurityMode.Tls), http_endpoint_security: srvargs.ServerEndpointSecurityMode = ( srvargs.ServerEndpointSecurityMode.Tls), auto_shutdown_after: float = -1, echo_runtime_info: bool = False, status_sinks: Sequence[Callable[[str], None]] = (), default_auth_method: srvargs.ServerAuthMethods = ( srvargs.DEFAULT_AUTH_METHODS), admin_ui: bool = False, cors_always_allowed_origins: Optional[str] = None, disable_dynamic_system_config: bool = False, compiler_state: edbcompiler.CompilerState, use_monitor_fs: bool = False, net_worker_mode: srvargs.NetWorkerMode = srvargs.NetWorkerMode.Default, ): self.__loop = asyncio.get_running_loop() self._use_monitor_fs = use_monitor_fs self._schema_class_layout = compiler_state.schema_class_layout self._config_settings = compiler_state.config_spec self._refl_schema = compiler_state.refl_schema self._std_schema = compiler_state.std_schema assert compiler_state.global_intro_query is not None self._global_intro_query = ( 
compiler_state.global_intro_query.encode("utf-8")) assert compiler_state.local_intro_query is not None self._local_intro_query = ( compiler_state.local_intro_query.encode("utf-8")) # Used to tag PG notifications to later disambiguate them. self._server_id = str(uuid.uuid4()) self._daemonized = daemonized self._pidfile_dir = pidfile_dir self._runstate_dir = runstate_dir self._internal_runstate_dir = internal_runstate_dir self._compiler_pool = None self._compiler_pool_size = compiler_pool_size self._compiler_worker_branch_limit = compiler_worker_branch_limit self._compiler_pool_mode = compiler_pool_mode self._compiler_pool_addr = compiler_pool_addr self._compiler_worker_max_rss = compiler_worker_max_rss self._system_compile_cache = lru.LRUMapping( maxsize=defines._MAX_QUERIES_CACHE_SYSTEM ) self._system_compile_cache_locks: dict[Any, Any] = {} self._listen_sockets = listen_sockets if listen_sockets: nethosts = tuple(s.getsockname()[0] for s in listen_sockets) netport = listen_sockets[0].getsockname()[1] self._listen_hosts = nethosts self._listen_port = netport # Shutdown the server after the last management # connection has disconnected # and there have been no new connections for n seconds self._auto_shutdown_after = auto_shutdown_after self._auto_shutdown_handler: Any = None self._keepalive_tokens: set = set() self._echo_runtime_info = echo_runtime_info self._status_sinks = status_sinks self._sys_queries = immutables.Map() self._devmode = devmode.is_in_dev_mode() self._testmode = testmode self._binary_proto_id_counter = 0 self._binary_conns = collections.OrderedDict() self._pgext_conns = {} self._servers = {} self._http_query_cache = cache.StatementsCache( maxsize=defines.HTTP_PORT_QUERY_CACHE_SIZE) self._http_last_minute_requests = windowedsum.WindowedSum() self._http_request_logger = None self._auth_gc = None self._net_worker_http = None self._net_worker_http_gc = None self._net_worker_mode = net_worker_mode self._stop_evt = asyncio.Event() self._tls_cert_file: 
str | Any = None self._tls_cert_newly_generated = False self._sslctx: ssl.SSLContext | Any = None self._sslctx_pgext: ssl.SSLContext | Any = None self._jws_key: auth.JWKSet | None = None self._jws_keys_newly_generated = False self._default_auth_method_spec = default_auth_method self._default_auth_methods = self._get_auth_method_types( default_auth_method) self._binary_endpoint_security = binary_endpoint_security self._http_endpoint_security = http_endpoint_security self._idle_gc_handler = None self._admin_ui = admin_ui self._cors_always_allowed_origins = [ re.compile( '^' + origin .replace('.', '\\.') .replace('*', '.*') + '$' ) if '*' in origin else origin for origin in cors_always_allowed_origins.split(',') ] if cors_always_allowed_origins else [] self._file_watch_handles = [] self._tls_certs_reload_retry_handle: Any | asyncio.TimerHandle = None self._disable_dynamic_system_config = disable_dynamic_system_config self._report_config_typedesc = {} def _get_auth_method_types( self, auth_methods_spec: srvargs.ServerAuthMethods, ) -> dict[srvargs.ServerConnTransport, list[config.CompositeConfigType]]: mapping = {} for transport, methods in auth_methods_spec.items(): result = [] for method in methods: auth_type = self.config_settings.get_type_by_name( f'cfg::{method.value}' ) result.append(auth_type()) mapping[transport] = result return mapping async def _request_stats_logger(self): last_seen = -1 while True: current = int(self._http_last_minute_requests) if current != last_seen: log_metrics.info( "HTTP requests in last minute: %d", current, ) last_seen = current await asyncio.sleep(30) def get_server_id(self): return self._server_id def get_listen_hosts(self): return self._listen_hosts def get_listen_port(self): return self._listen_port def get_loop(self): return self.__loop def in_dev_mode(self): return self._devmode def in_test_mode(self): return self._testmode def is_admin_ui_enabled(self): return self._admin_ui def get_cors_always_allowed_origins(self): return 
self._cors_always_allowed_origins def on_binary_client_created(self) -> str: self._binary_proto_id_counter += 1 return str(self._binary_proto_id_counter) def on_binary_client_connected(self, conn): self._binary_conns[conn] = True metrics.current_client_connections.inc( 1.0, conn.get_tenant_label() ) def on_binary_client_authed(self, conn): self._report_connections(event='opened') metrics.total_client_connections.inc( 1.0, conn.get_tenant_label() ) def on_binary_client_after_idling(self, conn): try: self._binary_conns.move_to_end(conn, last=True) except KeyError: # Shouldn't happen, but just in case some weird async twist # gets us here we don't want to crash the connection with # this error. metrics.background_errors.inc( 1.0, conn.get_tenant_label(), 'client_after_idling' ) def on_binary_client_disconnected(self, conn): self._binary_conns.pop(conn, None) self._report_connections(event="closed") metrics.current_client_connections.dec( 1.0, conn.get_tenant_label() ) self.maybe_auto_shutdown() def maybe_delay_auto_shutdown(self): if self._auto_shutdown_handler: self._auto_shutdown_handler.cancel() self._auto_shutdown_handler = None def maybe_auto_shutdown(self): if ( not self._binary_conns and not self._keepalive_tokens and self._auto_shutdown_after >= 0 and self._auto_shutdown_handler is None ): self._auto_shutdown_handler = self.__loop.call_later( self._auto_shutdown_after, self.request_auto_shutdown) def _report_connections(self, *, event: str) -> None: log_metrics.info( "%s a connection; open_count=%d", event, len(self._binary_conns), ) def add_keepalive_token(self, token: Hashable) -> None: self.maybe_delay_auto_shutdown() self._keepalive_tokens.add(token) def remove_keepalive_token(self, token: Hashable) -> None: self._keepalive_tokens.discard(token) self.maybe_auto_shutdown() def on_pgext_client_connected(self, conn): self._pgext_conns[conn.get_id()] = conn def on_pgext_client_disconnected(self, conn): self._pgext_conns.pop(conn.get_id(), None) 
self.maybe_auto_shutdown() def cancel_pgext_connection(self, pid, secret): conn = self._pgext_conns.get(pid) if conn is not None: conn.cancel(secret) def monitor_fs( self, file_path: str | pathlib.Path, cb: Callable[[], None], ) -> Callable[[], None]: if not self._use_monitor_fs: return lambda: None if isinstance(file_path, str): path = pathlib.Path(file_path) path_str = file_path else: path = file_path path_str = str(file_path) handle = None parent_dir = path.parent def watch_dir(file_modified, _event): nonlocal handle if parent_dir / os.fsdecode(file_modified) == path: try: new_handle = self.__loop._monitor_fs( # type: ignore path_str, callback) except FileNotFoundError: pass else: finalizer() handle = new_handle self._file_watch_handles.append(handle) cb() def callback(_file_modified, _event): nonlocal handle # First, cancel the existing watcher and call cb() regardless of # what event it is. This is because macOS issues RENAME while Linux # issues CHANGE, and we don't have enough knowledge about renaming. # The idea here is to re-watch the file path after every event, so # that even if the file is recreated, we still watch the right one. finalizer() try: cb() finally: try: # Then, see if we can directly re-watch the target path handle = self.__loop._monitor_fs( # type: ignore path_str, callback) except FileNotFoundError: # If not, watch the parent directory to wait for recreation handle = self.__loop._monitor_fs( # type: ignore str(parent_dir), watch_dir) self._file_watch_handles.append(handle) # ... we depend on an event loop internal _monitor_fs handle = self.__loop._monitor_fs(path_str, callback) # type: ignore def finalizer(): try: self._file_watch_handles.remove(handle) except ValueError: # The server may have cleared _file_watch_handles before the # tenants do, so we can skip the double cancel here. 
pass else: handle.cancel() self._file_watch_handles.append(handle) return finalizer def _get_sys_config(self) -> Mapping[str, config.SettingValue]: raise NotImplementedError def config_lookup( self, name: str, *configs: Mapping[str, config.SettingValue], ) -> Any: return config.lookup(name, *configs, spec=self._config_settings) @property def config_settings(self) -> config.Spec: return self._config_settings async def init(self): if self.is_admin_ui_enabled(): ui_ext.cache_assets() sys_config = self._get_sys_config() if not self._listen_hosts: self._listen_hosts = ( self.config_lookup('listen_addresses', sys_config) or ('localhost',) ) if self._listen_port is None: self._listen_port = ( self.config_lookup('listen_port', sys_config) or defines.EDGEDB_PORT ) self._stmt_cache_size = self.config_lookup( '_pg_prepared_statement_cache_size', sys_config ) self.reinit_idle_gc_collector() def reinit_idle_gc_collector(self) -> float: if self._auto_shutdown_after >= 0: return -1 if self._idle_gc_handler is not None: self._idle_gc_handler.cancel() self._idle_gc_handler = None session_idle_timeout = self.config_lookup( 'session_idle_timeout', self._get_sys_config()) timeout = session_idle_timeout.to_microseconds() timeout /= 1_000_000.0 # convert to seconds if timeout > 0: self._idle_gc_handler = self.__loop.call_later( timeout, self._idle_gc_collector) return timeout @property def stmt_cache_size(self) -> int | None: return self._stmt_cache_size @property def system_compile_cache(self): return self._system_compile_cache def request_stop_fe_conns(self, dbname: str) -> None: for conn in itertools.chain( self._binary_conns.keys(), self._pgext_conns.values() ): if conn.dbname == dbname: conn.request_stop() @property def system_compile_cache_locks(self): return self._system_compile_cache_locks def _idle_gc_collector(self): try: self._idle_gc_handler = None idle_timeout = self.reinit_idle_gc_collector() if idle_timeout <= 0: return now = time.monotonic() expiry_time = now - 
idle_timeout for conn in self._binary_conns: try: if conn.is_idle(expiry_time): label = conn.get_tenant_label() metrics.idle_client_connections.inc(1.0, label) current_tenant.set(label) conn.close_for_idling() elif conn.is_alive(): # We are sorting connections in # 'on_binary_client_after_idling' to specifically # enable this optimization. As soon as we find first # non-idle active connection we're guaranteed # to have traversed all of the potentially idling # connections. break except Exception: metrics.background_errors.inc( 1.0, conn.get_tenant_label(), 'close_for_idling' ) conn.abort() except Exception: metrics.background_errors.inc( 1.0, 'system', 'idle_clients_collector' ) raise def _get_backend_runtime_params(self) -> pgparams.BackendRuntimeParams: raise NotImplementedError def _get_compiler_args(self) -> dict[str, Any]: # Force Postgres version in BackendRuntimeParams to be the # minimal supported, because the compiler does not rely on # the version, and not pinning it would make the remote compiler # pool refuse connections from clients that have differing versions # of Postgres backing them. 
runtime_params = self._get_backend_runtime_params() min_ver = '.'.join(str(v) for v in defines.MIN_POSTGRES_VERSION) runtime_params = runtime_params._replace( instance_params=runtime_params.instance_params._replace( version=buildmeta.parse_pg_version(min_ver), ), ) args = dict( pool_size=self._compiler_pool_size, worker_branch_limit=self._compiler_worker_branch_limit, pool_class=self._compiler_pool_mode.pool_class, runstate_dir=self._internal_runstate_dir, backend_runtime_params=runtime_params, std_schema=self._std_schema, refl_schema=self._refl_schema, schema_class_layout=self._schema_class_layout, ) if self._compiler_pool_mode == srvargs.CompilerPoolMode.Remote: args['address'] = self._compiler_pool_addr else: if self._compiler_worker_max_rss is not None: args['worker_max_rss'] = self._compiler_worker_max_rss return args async def _destroy_compiler_pool(self): if self._compiler_pool is not None: await self._compiler_pool.stop() self._compiler_pool = None def get_compiler_pool(self): return self._compiler_pool async def introspect_global_schema_json( self, conn: pgcon.PGConnection ) -> bytes: return await conn.sql_fetch_val(self._global_intro_query) def _parse_global_schema( self, json_data: Any ) -> s_schema.Schema: return s_refl.parse_schema( base_schema=self._std_schema, data=json_data, schema_class_layout=self._schema_class_layout, ) async def introspect_global_schema( self, conn: pgcon.PGConnection ) -> s_schema.Schema: json_data = await self.introspect_global_schema_json(conn) return self._parse_global_schema(json_data) async def introspect_user_schema_json( self, conn: pgcon.PGConnection, ) -> bytes: return await conn.sql_fetch_val(self._local_intro_query) def _parse_user_schema( self, json_data: Any, global_schema: s_schema.Schema, ) -> s_schema.Schema: base_schema = s_schema.ChainedSchema( self._std_schema, s_schema.EMPTY_SCHEMA, global_schema, ) return s_refl.parse_schema( base_schema=base_schema, data=json_data, 
schema_class_layout=self._schema_class_layout, ) async def _introspect_user_schema( self, conn: pgcon.PGConnection, global_schema: s_schema.Schema, ) -> s_schema.Schema: json_data = await self.introspect_user_schema_json(conn) return self._parse_user_schema(json_data, global_schema) async def introspect_db_config(self, conn: pgcon.PGConnection) -> bytes: return await conn.sql_fetch_val(self.get_sys_query("dbconfig")) def _parse_db_config( self, db_config_json: bytes, user_schema: s_schema.Schema ) -> Mapping[str, config.SettingValue]: spec = config.ChainedSpec( self._config_settings, config.load_ext_spec_from_schema( user_schema, self.get_std_schema(), ), ) return config.from_json(spec, db_config_json) async def get_dbnames(self, syscon): dbs_query = self.get_sys_query('listdbs') json_data = await syscon.sql_fetch_val(dbs_query) return json.loads(json_data) async def _on_system_config_add(self, setting_name, value): # CONFIGURE INSTANCE INSERT ConfigObject; pass async def _on_system_config_rem(self, setting_name, value): # CONFIGURE INSTANCE RESET ConfigObject; pass async def _on_system_config_set(self, setting_name, value): # CONFIGURE INSTANCE SET setting_name := value; pass async def _on_system_config_reset(self, setting_name): # CONFIGURE INSTANCE RESET setting_name; pass def before_alter_system_config(self): if self._disable_dynamic_system_config: raise errors.ConfigurationError( "cannot change this configuration value in this instance" ) async def _after_system_config_add(self, setting_name, value): # CONFIGURE INSTANCE INSERT ConfigObject; pass async def _after_system_config_rem(self, setting_name, value): # CONFIGURE INSTANCE RESET ConfigObject; pass async def _after_system_config_set(self, setting_name, value): # CONFIGURE INSTANCE SET setting_name := value; pass async def _after_system_config_reset(self, setting_name): # CONFIGURE INSTANCE RESET setting_name; pass def _make_protocol(self): self.maybe_delay_auto_shutdown() return protocol.HttpProtocol( 
self, self._sslctx, self._sslctx_pgext, binary_endpoint_security=self._binary_endpoint_security, http_endpoint_security=self._http_endpoint_security, ) async def _start_server( self, host: str, port: int, sock: Optional[socket.socket] = None, ) -> Optional[asyncio.base_events.Server]: try: kwargs: dict[str, Any] if sock is not None: kwargs = {"sock": sock} else: kwargs = {"host": host, "port": port} return await self.__loop.create_server( self._make_protocol, **kwargs ) except Exception as e: logger.warning( f"could not create listen socket for '{host}:{port}': {e}" ) return None async def _start_admin_server( self, port: int, ) -> asyncio.base_events.Server: admin_unix_sock_path = os.path.join( self._runstate_dir, f'.s.GEL.admin.{port}') symlink = os.path.join( self._runstate_dir, f'.s.EDGEDB.admin.{port}') exists = False try: mode = os.lstat(symlink).st_mode if stat.S_ISSOCK(mode): os.unlink(symlink) else: exists = True except FileNotFoundError: pass if not exists: os.symlink(admin_unix_sock_path, symlink) assert len(admin_unix_sock_path) <= ( defines.MAX_RUNSTATE_DIR_PATH + defines.MAX_UNIX_SOCKET_PATH_LENGTH + 1 ), "admin Unix socket length exceeds maximum allowed" admin_unix_srv = await self.__loop.create_unix_server( lambda: binary.new_edge_connection( self, self._get_admin_tenant(), external_auth=True ), admin_unix_sock_path ) os.chmod(admin_unix_sock_path, stat.S_IRUSR | stat.S_IWUSR) logger.info('Serving admin on %s', admin_unix_sock_path) return admin_unix_srv def _get_admin_tenant(self) -> edbtenant.Tenant: return self.get_default_tenant() async def _start_servers( self, hosts: tuple[str, ...], port: int, *, admin: bool = True, sockets: tuple[socket.socket, ...] = (), ): servers = {} if port == 0: # Automatic port selection requires us to start servers # sequentially until we get a working bound socket to ensure # consistent port value across all requested listen addresses. 
try: for host in hosts: server = await self._start_server(host, port) if server is not None: if port == 0: port = server.sockets[0].getsockname()[1] servers[host] = server except Exception: await self._stop_servers(servers.values()) raise else: start_tasks = {} try: async with asyncio.TaskGroup() as g: if sockets: for host, sock in zip(hosts, sockets): start_tasks[host] = g.create_task( self._start_server(host, port, sock=sock) ) else: for host in hosts: start_tasks[host] = g.create_task( self._start_server(host, port) ) except Exception: await self._stop_servers([ fut.result() for fut in start_tasks.values() if ( fut.done() and fut.exception() is None and fut.result() is not None ) ]) raise servers.update({ host: srv for host, fut in start_tasks.items() if (srv := fut.result()) is not None }) # Fail if none of the servers can be started, except when the admin # server on a UNIX domain socket will be started. if not servers and (not admin or port == 0): raise StartupError("could not create any listen sockets") addrs = [] for tcp_srv in servers.values(): for s in tcp_srv.sockets: addrs.append(s.getsockname()) if len(addrs) > 1: if port: addr_str = f"{{{', '.join(addr[0] for addr in addrs)}}}:{port}" else: addr_str = f"""{{{', '.join( f'{addr[0]}:{addr[1]}' for addr in addrs)}}}""" elif addrs: addr_str = f'{addrs[0][0]}:{addrs[0][1]}' port = addrs[0][1] else: addr_str = None if addr_str: logger.info('Serving on %s', addr_str) if admin and port: try: admin_unix_srv = await self._start_admin_server(port) except Exception: await self._stop_servers(servers.values()) raise servers[ADMIN_PLACEHOLDER] = admin_unix_srv return servers, port, addrs def _sni_callback(self, sslobj, server_name, sslctx): # Match the given SNI for a pre-registered Tenant instance, # and temporarily store in memory indexed by sslobj for future # retrieval, see also retrieve_tenant() below. # # Used in multi-tenant server only. This method must not fail. 
pass def reload_tls(self, tls_cert_file, tls_key_file, client_ca_file): logger.info("loading TLS certificates") tls_password_needed = False if self._tls_certs_reload_retry_handle is not None: self._tls_certs_reload_retry_handle.cancel() self._tls_certs_reload_retry_handle = None def _tls_private_key_password(): nonlocal tls_password_needed tls_password_needed = True return ( os.environ.get('GEL_SERVER_TLS_PRIVATE_KEY_PASSWORD', '') or os.environ.get('EDGEDB_SERVER_TLS_PRIVATE_KEY_PASSWORD', '') ) sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) sslctx_pgext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) try: sslctx.load_cert_chain( tls_cert_file, tls_key_file, password=_tls_private_key_password, ) sslctx_pgext.load_cert_chain( tls_cert_file, tls_key_file, password=_tls_private_key_password, ) except ssl.SSLError as e: if e.library == "SSL" and e.errno == 9: # ERR_LIB_PEM if tls_password_needed: if _tls_private_key_password(): raise StartupError( "Cannot load TLS certificates - it's likely that " "the private key password is wrong." ) from e else: raise StartupError( "Cannot load TLS certificates - the private key " "file is likely protected by a password. Specify " "the password using environment variable: " "GEL_SERVER_TLS_PRIVATE_KEY_PASSWORD" ) from e elif tls_key_file is None: raise StartupError( "Cannot load TLS certificates - have you specified " "the private key file using the `--tls-key-file` " "command-line argument?" ) from e else: raise StartupError( "Cannot load TLS certificates - please double check " "if the specified certificate files are valid." ) elif e.library == "X509" and e.errno == 116: # X509 Error 116: X509_R_KEY_VALUES_MISMATCH raise StartupError( "Cannot load TLS certificates - the private key doesn't " "match the certificate." 
) raise StartupError(f"Cannot load TLS certificates - {e}") from e if client_ca_file is not None: try: sslctx.load_verify_locations(client_ca_file) sslctx_pgext.load_verify_locations(client_ca_file) except ssl.SSLError as e: raise StartupError( f"Cannot load client CA certificates - {e}") from e sslctx.verify_mode = ssl.CERT_OPTIONAL sslctx_pgext.verify_mode = ssl.CERT_OPTIONAL sslctx.set_alpn_protocols(['edgedb-binary', 'http/1.1']) sslctx.sni_callback = self._sni_callback sslctx_pgext.sni_callback = self._sni_callback self._sslctx = sslctx self._sslctx_pgext = sslctx_pgext def init_tls( self, tls_cert_file, tls_key_file, tls_cert_newly_generated, client_ca_file, ): assert self._sslctx is self._sslctx_pgext is None self.reload_tls(tls_cert_file, tls_key_file, client_ca_file) self._tls_cert_file = str(tls_cert_file) self._tls_cert_newly_generated = tls_cert_newly_generated def reload_tls(retry=0): try: self.reload_tls(tls_cert_file, tls_key_file, client_ca_file) except (StartupError, FileNotFoundError) as e: if retry > defines._TLS_CERT_RELOAD_MAX_RETRIES: logger.critical(str(e)) self.request_shutdown() else: delay = defines._TLS_CERT_RELOAD_EXP_INTERVAL * 2 ** retry logger.warning("%s; retrying in %.1f seconds.", e, delay) self._tls_certs_reload_retry_handle = ( self.__loop.call_later( delay, reload_tls, retry + 1, ) ) except Exception: logger.critical( "error while reloading TLS certificate and/or key, " "shutting down.", exc_info=True, ) self.request_shutdown() self.monitor_fs(tls_cert_file, reload_tls) if tls_cert_file != tls_key_file: self.monitor_fs(tls_key_file, reload_tls) if client_ca_file is not None: self.monitor_fs(client_ca_file, reload_tls) def start_watching_files(self): # TODO(fantix): include the monitor_fs() lines above pass def load_jwcrypto(self, jws_key_file: pathlib.Path) -> auth.JWKSet: try: jws_key = auth.load_secret_key(jws_key_file) self._jws_key = jws_key return jws_key except auth.SecretKeyReadError as e: raise StartupError(e.args[0]) 
from e def init_jwcrypto( self, jws_key_file: pathlib.Path, jws_keys_newly_generated: bool, ) -> None: self.load_jwcrypto(jws_key_file) self._jws_keys_newly_generated = jws_keys_newly_generated def get_jws_key(self) -> auth.JWKSet | None: return self._jws_key async def _stop_servers(self, servers): async with asyncio.TaskGroup() as g: for srv in servers: srv.close() g.create_task(srv.wait_closed()) async def _before_start_servers(self) -> None: pass async def _after_start_servers(self) -> None: pass async def start(self): self._stop_evt.clear() self._http_request_logger = self.__loop.create_task( self._request_stats_logger() ) pool = await compiler_pool.create_compiler_pool( **self._get_compiler_args() ) self.compilation_config_serializer = ( await pool.make_compilation_config_serializer() ) self._compiler_pool = pool await self._before_start_servers() self._servers, actual_port, listen_addrs = await self._start_servers( tuple((await _resolve_interfaces(self._listen_hosts))[0]), self._listen_port, sockets=self._listen_sockets, ) self._listen_hosts = [addr[0] for addr in listen_addrs] self._listen_port = actual_port if self._daemonized: pidfile_dir = self._pidfile_dir if pidfile_dir is None: pidfile_dir = self._runstate_dir pidfile_path = pidfile_dir / f".s.EDGEDB.{actual_port}.lock" pidfile = daemon.PidFile(pidfile_path) pidfile.acquire() await self._after_start_servers() self._auth_gc = self.__loop.create_task(pkce.gc(self)) if self._net_worker_mode is srvargs.NetWorkerMode.Default: self._net_worker_http = self.__loop.create_task( net_worker.http(self) ) self._net_worker_http_gc = self.__loop.create_task( net_worker.gc(self) ) if self._echo_runtime_info: ri = { "port": self._listen_port, "runstate_dir": str(self._runstate_dir), "tls_cert_file": self._tls_cert_file, } print(f'\nEDGEDB_SERVER_DATA:{json.dumps(ri)}\n', flush=True) status = self._get_status() status["listen_addrs"] = listen_addrs status_str = f'READY={json.dumps(status)}' for status_sink in 
self._status_sinks: status_sink(status_str) if self._auto_shutdown_after > 0: self._auto_shutdown_handler = self.__loop.call_later( self._auto_shutdown_after, self.request_auto_shutdown) def _get_status(self) -> dict[str, Any]: return { "port": self._listen_port, "socket_dir": str(self._runstate_dir), "main_pid": os.getpid(), "tls_cert_file": self._tls_cert_file, "tls_cert_newly_generated": self._tls_cert_newly_generated, "jws_keys_newly_generated": self._jws_keys_newly_generated, } def request_auto_shutdown(self): if self._auto_shutdown_after == 0: logger.info("shutting down server: all clients disconnected") else: logger.info( f"shutting down server: no clients connected in last" f" {self._auto_shutdown_after} seconds" ) self.request_shutdown() def request_shutdown(self): self._stop_evt.set() async def stop(self): if self._idle_gc_handler is not None: self._idle_gc_handler.cancel() self._idle_gc_handler = None if self._http_request_logger is not None: self._http_request_logger.cancel() if self._auth_gc is not None: self._auth_gc.cancel() if self._net_worker_http is not None: self._net_worker_http.cancel() if self._net_worker_http_gc is not None: self._net_worker_http_gc.cancel() for handle in self._file_watch_handles: handle.cancel() self._file_watch_handles.clear() await self._stop_servers(self._servers.values()) self._servers = {} # This should be done by tenant.stop(), but let's still do it again for conn in self._binary_conns: conn.request_stop() self._binary_conns.clear() for conn in self._pgext_conns.values(): conn.request_stop() self._pgext_conns.clear() def request_frontend_stop(self, tenant: edbtenant.Tenant): dropped = [] for conn in self._binary_conns: if conn.tenant is tenant: conn.request_stop() dropped.append(conn) for conn in dropped: self._binary_conns.pop(conn, None) dropped.clear() for conn in self._pgext_conns.values(): if conn.tenant is tenant: conn.request_stop() dropped.append(conn) for conn in dropped: # _pgext_conns is keyed by connection id, not by the connection. self._pgext_conns.pop(conn.get_id(), None)
async def serve_forever(self): await self._stop_evt.wait() def get_sys_query(self, key): return self._sys_queries[key] def get_debug_info(self): """Used to render the /server-info endpoint in dev/test modes. Some tests depend on the exact layout of the returned structure. """ return dict( params=dict( dev_mode=self._devmode, test_mode=self._testmode, default_auth_methods=str(self._default_auth_method_spec), listen_hosts=self._listen_hosts, listen_port=self._listen_port, ), instance_config=config.debug_serialize_config( self._get_sys_config()), compiler_pool=( self._compiler_pool.get_debug_info() if self._compiler_pool else None ), ) def get_report_config_typedesc( self, ) -> dict[defines.ProtocolVersion, bytes]: return self._report_config_typedesc def get_default_auth_methods( self, transport: srvargs.ServerConnTransport ) -> list[config.CompositeConfigType]: return self._default_auth_methods.get(transport, []) def get_std_schema(self) -> s_schema.Schema: return self._std_schema def retrieve_tenant(self, sslobj) -> edbtenant.Tenant | None: # After TLS handshake, the client connection would use this method to # retrieve the Tenant instance associated with the given SSLObject. # # This method must not fail. See also _sni_callback() above. return self.get_default_tenant() def get_default_tenant(self) -> edbtenant.Tenant: # The client connection must proceed on a Tenant instance. In cases: # 1. plain-text connection without TLS handshake # 2. TLS handshake didn't provide SNI # 3. SNI didn't match any Tenant (retrieve_tenant() returned None) # this method will be called for a "default" tenant to use. # # The caller must be ready to handle errors raised in this method, and # provide a decent error. 
raise NotImplementedError def iter_tenants(self) -> Iterator[edbtenant.Tenant]: raise NotImplementedError async def maybe_generate_pki( self, args: srvargs.ServerConfig, ss: BaseServer ) -> tuple[bool, bool]: tls_cert_newly_generated = False if args.tls_cert_mode is srvargs.ServerTlsCertMode.SelfSigned: assert args.tls_cert_file is not None if not args.tls_cert_file.exists(): assert args.tls_key_file is not None logger.info( f'generating self-signed TLS certificate ' f'in "{args.tls_cert_file}"' ) secretkey.generate_tls_cert( args.tls_cert_file, args.tls_key_file, ss.get_listen_hosts(), ) tls_cert_newly_generated = True jws_keys_newly_generated = False if args.jose_key_mode is srvargs.JOSEKeyMode.Generate: assert args.jws_key_file is not None if not args.jws_key_file.exists(): logger.info( f'generating JOSE key pair in "{args.jws_key_file}"' ) auth.generate_jwk(args.jws_key_file) jws_keys_newly_generated = True return tls_cert_newly_generated, jws_keys_newly_generated class Server(BaseServer): _tenant: edbtenant.Tenant _startup_script: srvargs.StartupScript | None _new_instance: bool def __init__( self, *, tenant: edbtenant.Tenant, startup_script: srvargs.StartupScript | None = None, new_instance: bool, **kwargs, ): super().__init__(**kwargs) self._tenant = tenant self._startup_script = startup_script self._new_instance = new_instance tenant.set_server(self) def _get_sys_config(self) -> Mapping[str, config.SettingValue]: return self._tenant.get_sys_config() async def init(self) -> None: logger.debug("starting server init") await self._tenant.init_sys_pgcon() await self._load_instance_data() await self._maybe_patch() await self._tenant.init() await super().init() def get_default_tenant(self) -> edbtenant.Tenant: return self._tenant def iter_tenants(self) -> Iterator[edbtenant.Tenant]: yield self._tenant async def _get_patch_log( self, conn: pgcon.PGConnection, idx: int ) -> Optional[bootstrap.PatchEntry]: # We need to maintain a log in the system database of # 
patches that have been applied. This is so that if a # patch creates a new object, and then we successfully # apply the patch to a user db but crash *before* applying # it to the system db, when we start up again and try # applying it to the system db, it is important that we # apply the same compiled version of the patch. If we # instead recompiled it, and it created new objects, those # objects might have a different id in the std schema and # in the actual user db. result = await instdata.get_instdata( conn, f'patch_log_{idx}', 'bin') if result: return pickle.loads(result) else: return None async def _prepare_patches( self, conn: pgcon.PGConnection ) -> dict[int, bootstrap.PatchEntry]: """Prepare all the patches""" num_patches = await self._tenant.get_patch_count(conn) if num_patches < len(pg_patches.PATCHES): logger.info("preparing patches for database upgrade") patches = {} patch_list = list(enumerate(pg_patches.PATCHES)) for num, (kind, patch) in patch_list[num_patches:]: from . import bootstrap # noqa: F402 idx = num_patches + num if not (entry := await self._get_patch_log(conn, idx)): patch_info = await bootstrap.gather_patch_info( num, kind, patch, conn ) entry = bootstrap.prepare_patch( num, kind, patch, self._std_schema, self._refl_schema, self._schema_class_layout, self._tenant.get_backend_runtime_params(), patch_info=patch_info, ) await bootstrap._store_static_bin_cache_conn( conn, f'patch_log_{idx}', pickle.dumps(entry)) patches[num] = entry _, _, updates = entry if 'std_and_reflection_schema' in updates: self._std_schema, self._refl_schema = updates[ 'std_and_reflection_schema'] # +config patches might modify config_spec, which requires # a reload of it from the schema.
if '+config' in kind: config_spec = config.load_spec_from_schema(self._std_schema) self._config_settings = config_spec if 'local_intro_query' in updates: self._local_intro_query = updates['local_intro_query'] if 'global_intro_query' in updates: self._global_intro_query = updates['global_intro_query'] if 'classlayout' in updates: self._schema_class_layout = updates['classlayout'] if 'sysqueries' in updates: queries = json.loads(updates['sysqueries']) self._sys_queries = immutables.Map( {k: q.encode() for k, q in queries.items()}) if 'report_configs_typedesc' in updates: self._report_config_typedesc = ( updates['report_configs_typedesc']) return patches async def _maybe_apply_patches( self, dbname: str, conn: pgcon.PGConnection, patches: dict[int, bootstrap.PatchEntry], sys: bool = False, ) -> None: """Apply any un-applied patches to the database.""" num_patches = await self._tenant.get_patch_count(conn) for num, (sql_b, syssql, keys) in patches.items(): if num_patches <= num: if sys: sql_b += syssql logger.info("applying patch %d to database '%s'", num, dbname) sql = tuple(x.encode('utf-8') for x in sql_b) # For certain things, we need to actually run it # against each user database. if keys.get('is_user_update'): from . import bootstrap kind, patch = pg_patches.PATCHES[num] patch_info = await bootstrap.gather_patch_info( num, kind, patch, conn ) # Reload the compiler state from this database in # particular, so we can compile from exactly the # right state. (Since self._std_schema and the like might # be further advanced.)
state = (await edbcompiler.new_compiler_from_pg(conn)).state assert state.global_intro_query and state.local_intro_query global_schema = self._parse_global_schema( await conn.sql_fetch_val( state.global_intro_query.encode('utf-8')), ) user_schema = self._parse_user_schema( await conn.sql_fetch_val( state.local_intro_query.encode('utf-8')), global_schema, ) entry = bootstrap.prepare_patch( num, kind, patch, state.std_schema, state.refl_schema, state.schema_class_layout, self._tenant.get_backend_runtime_params(), patch_info=patch_info, user_schema=user_schema, global_schema=global_schema, dbname=dbname, ) sql += tuple(x.encode('utf-8') for x in entry[0]) if sql: await conn.sql_execute(sql) logger.info( "finished applying patch %d to database '%s'", num, dbname) async def _maybe_patch_db( self, dbname: str, patches: dict[int, bootstrap.PatchEntry], sem: Any ) -> None: logger.info("applying patches to database '%s'", dbname) try: async with sem: async with self._tenant.direct_pgcon(dbname) as conn: await self._maybe_apply_patches(dbname, conn, patches) except Exception as e: if ( isinstance(e, errors.EdgeDBError) and not isinstance(e, errors.InternalServerError) ): raise raise errors.InternalServerError( f'Could not apply patches for minor version upgrade to ' f'database {dbname}' ) from e async def _maybe_patch(self) -> None: """Apply patches to all the databases""" async with self._tenant.use_sys_pgcon() as syscon: patches = await self._prepare_patches(syscon) if not patches: return dbnames = await self.get_dbnames(syscon) async with asyncio.TaskGroup() as g: # Cap the parallelism used when applying patches, to avoid # having huge numbers of in flight patches that make # little visible progress in the logs. sem = asyncio.Semaphore(16) # Patch all the databases for dbname in dbnames: if dbname != defines.EDGEDB_SYSTEM_DB: g.create_task( self._maybe_patch_db(dbname, patches, sem)) # Patch the template db, so that any newly created databases # will have the patches. 
g.create_task(self._maybe_patch_db( defines.EDGEDB_TEMPLATE_DB, patches, sem)) await self._tenant.ensure_database_not_connected( defines.EDGEDB_TEMPLATE_DB ) # Patch the system db last. The system db needs to go last so # that it only gets updated if all of the other databases have # been successfully patched. This is important, since we don't check # other databases for patches unless the system db is patched. # # Driving everything from the system db like this lets us # always use the correct schema when compiling patches. async with self._tenant.use_sys_pgcon() as syscon: await self._maybe_apply_patches( defines.EDGEDB_SYSTEM_DB, syscon, patches, sys=True) def _load_schema(self, result, version_key) -> s_schema.Schema: res = pickle.loads(result[2:]) if version_key != pg_patches.get_version_key(len(pg_patches.PATCHES)): res = s_schema.upgrade_schema(res) return res async def _load_instance_data(self): logger.info("loading instance data") async with self._tenant.use_sys_pgcon() as syscon: patch_count = await self._tenant.get_patch_count(syscon) version_key = pg_patches.get_version_key(patch_count) result = await instdata.get_instdata( syscon, f'sysqueries{version_key}', 'json') queries = json.loads(result) self._sys_queries = immutables.Map( {k: q.encode() for k, q in queries.items()}) self._report_config_typedesc[(1, 0)] = ( await instdata.get_instdata( syscon, f'report_configs_typedesc_1_0{version_key}', 'bin', ) ) self._report_config_typedesc[(2, 0)] = ( await instdata.get_instdata( syscon, f'report_configs_typedesc_2_0{version_key}', 'bin', ) ) def _reload_stmt_cache_size(self): size = self.config_lookup( '_pg_prepared_statement_cache_size', self._get_sys_config() ) self._stmt_cache_size = size self._tenant.set_stmt_cache_size(size) async def _restart_servers_new_addr(self, nethosts, netport): if not netport: raise RuntimeError('cannot restart without network port specified') nethosts, has_ipv4_wc, has_ipv6_wc = await _resolve_interfaces( nethosts
servers_to_stop = [] servers_to_stop_early = [] servers = {} if self._listen_port == netport: hosts_to_start = [ host for host in nethosts if host not in self._servers ] for host, srv in self._servers.items(): if host == ADMIN_PLACEHOLDER or host in nethosts: servers[host] = srv elif host in ['::', '0.0.0.0']: servers_to_stop_early.append(srv) else: if has_ipv4_wc: try: ipaddress.IPv4Address(host) except ValueError: pass else: servers_to_stop_early.append(srv) continue if has_ipv6_wc: try: ipaddress.IPv6Address(host) except ValueError: pass else: servers_to_stop_early.append(srv) continue servers_to_stop.append(srv) admin = False else: hosts_to_start = nethosts servers_to_stop = list(self._servers.values()) admin = True if servers_to_stop_early: await self._stop_servers_with_logging(servers_to_stop_early) if hosts_to_start: try: new_servers, *_ = await self._start_servers( tuple(hosts_to_start), netport, admin=admin, ) servers.update(new_servers) except StartupError: raise errors.ConfigurationError( 'Server updated its config but cannot serve on requested ' 'address/port, please see server log for more information.' 
) self._servers = servers self._listen_hosts = [ s.getsockname()[0] for host, tcp_srv in servers.items() if host != ADMIN_PLACEHOLDER for s in tcp_srv.sockets # type: ignore ] self._listen_port = netport await self._stop_servers_with_logging(servers_to_stop) async def _stop_servers_with_logging(self, servers_to_stop): addrs = [] unix_addr = None port = None for srv in servers_to_stop: for s in srv.sockets: addr = s.getsockname() if isinstance(addr, tuple): addrs.append(addr[:2]) if port is None: port = addr[1] elif port != addr[1]: port = 0 else: unix_addr = addr if len(addrs) > 1: if port: addr_str = f"{{{', '.join(addr[0] for addr in addrs)}}}:{port}" else: addr_str = f"{{{', '.join('%s:%d' % addr for addr in addrs)}}}" elif addrs: addr_str = "%s:%d" % addrs[0] else: addr_str = None if addr_str: logger.info('Stopping to serve on %s', addr_str) if unix_addr: logger.info('Stopping to serve admin on %s', unix_addr) await self._stop_servers(servers_to_stop) async def _on_system_config_set(self, setting_name, value): try: if setting_name == 'listen_addresses': await self._restart_servers_new_addr(value, self._listen_port) elif setting_name == 'listen_port': await self._restart_servers_new_addr(self._listen_hosts, value) elif setting_name == 'session_idle_timeout': self.reinit_idle_gc_collector() elif setting_name == '_pg_prepared_statement_cache_size': self._reload_stmt_cache_size() self._tenant.schedule_reported_config_if_needed(setting_name) except Exception: metrics.background_errors.inc( 1.0, self._tenant.get_instance_name(), 'on_system_config_set' ) raise async def _on_system_config_reset(self, setting_name): try: if setting_name == 'listen_addresses': cfg = self._get_sys_config() await self._restart_servers_new_addr( self.config_lookup('listen_addresses', cfg) or ('localhost',), self._listen_port, ) elif setting_name == 'listen_port': cfg = self._get_sys_config() await self._restart_servers_new_addr( self._listen_hosts, self.config_lookup('listen_port', cfg) or 
defines.EDGEDB_PORT, ) elif setting_name == 'session_idle_timeout': self.reinit_idle_gc_collector() elif setting_name == '_pg_prepared_statement_cache_size': self._reload_stmt_cache_size() self._tenant.schedule_reported_config_if_needed(setting_name) except Exception: metrics.background_errors.inc( 1.0, self._tenant.get_instance_name(), 'on_system_config_reset' ) raise async def _after_system_config_add(self, setting_name, value): try: if setting_name == 'auth': self._tenant.populate_sys_auth() except Exception: metrics.background_errors.inc( 1.0, self._tenant.get_instance_name(), 'after_system_config_add', ) raise async def _after_system_config_rem(self, setting_name, value): try: if setting_name == 'auth': self._tenant.populate_sys_auth() except Exception: metrics.background_errors.inc( 1.0, self._tenant.get_instance_name(), 'after_system_config_rem', ) raise async def run_startup_script_and_exit(self): """Run the script specified in *startup_script* and exit immediately""" if self._startup_script is None: raise AssertionError('startup script is not defined') pool = await compiler_pool.create_compiler_pool( **self._get_compiler_args() ) self.compilation_config_serializer = ( await pool.make_compilation_config_serializer() ) self._compiler_pool = pool try: await binary.run_script( server=self, tenant=self._tenant, database=self._startup_script.database, user=self._startup_script.user, script=self._startup_script.text, ) finally: await self._destroy_compiler_pool() async def _before_start_servers(self) -> None: await self._tenant.start_accepting_new_tasks() if self._startup_script and self._new_instance: await binary.run_script( server=self, tenant=self._tenant, database=self._startup_script.database, user=self._startup_script.user, script=self._startup_script.text, ) async def _after_start_servers(self) -> None: self._tenant.start_running() def _get_status(self) -> dict[str, Any]: status = super()._get_status() status["tenant_id"] = self._tenant.tenant_id return 
status def load_jwcrypto(self, jws_key_file: pathlib.Path) -> auth.JWKSet: jws_key = super().load_jwcrypto(jws_key_file) self._tenant.load_jwcrypto(jws_key) return jws_key def request_shutdown(self): self._tenant.stop_accepting_connections() super().request_shutdown() async def stop(self): try: self._tenant.stop() await super().stop() await self._tenant.wait_stopped() await self._destroy_compiler_pool() finally: self._tenant.terminate_sys_pgcon() def get_debug_info(self): parent = super().get_debug_info() child = self._tenant.get_debug_info() parent["params"].update(child["params"]) child["params"] = parent["params"] parent.update(child) return parent def _get_backend_runtime_params(self) -> pgparams.BackendRuntimeParams: return self._tenant.get_backend_runtime_params() def _get_compiler_args(self) -> dict[str, Any]: rv = super()._get_compiler_args() rv.update(self._tenant.get_compiler_args()) return rv def start_watching_files(self): super().start_watching_files() self._tenant.start_watching_files() def _cleanup_wildcard_addrs( hosts: Sequence[str], ) -> tuple[list[str], list[str], bool, bool]: """Filter out conflicting addresses in presence of INADDR_ANY wildcards. Attempting to bind to 0.0.0.0 (or ::) _and_ a non-wildcard address will usually result in EADDRINUSE. To avoid this, filter out all specific addresses if a wildcard is present in the *hosts* sequence. Returns a tuple: first element is the new list of hosts, second element is a list of rejected host addrs/names. 
""" ipv4_hosts = set() ipv6_hosts = set() named_hosts = set() ipv4_wc = ipaddress.ip_address('0.0.0.0') ipv6_wc = ipaddress.ip_address('::') for host in hosts: if host == "*": ipv4_hosts.add(ipv4_wc) ipv6_hosts.add(ipv6_wc) continue try: ip = ipaddress.IPv4Address(host) except ValueError: pass else: ipv4_hosts.add(ip) continue try: ip6 = ipaddress.IPv6Address(host) except ValueError: pass else: ipv6_hosts.add(ip6) continue named_hosts.add(host) if not ipv4_hosts and not ipv6_hosts: return (list(hosts), [], False, False) if ipv4_wc not in ipv4_hosts and ipv6_wc not in ipv6_hosts: return (list(hosts), [], False, False) if ipv4_wc in ipv4_hosts and ipv6_wc in ipv6_hosts: return ( ['0.0.0.0', '::'], [ str(a) for a in ((named_hosts | ipv4_hosts | ipv6_hosts) - {ipv4_wc, ipv6_wc}) ], True, True, ) if ipv4_wc in ipv4_hosts: return ( [str(a) for a in ({ipv4_wc} | ipv6_hosts)], [str(a) for a in ((named_hosts | ipv4_hosts) - {ipv4_wc})], True, False, ) if ipv6_wc in ipv6_hosts: return ( [str(a) for a in ({ipv6_wc} | ipv4_hosts)], [str(a) for a in ((named_hosts | ipv6_hosts) - {ipv6_wc})], False, True, ) raise AssertionError('unreachable') async def _resolve_host(host: str) -> list[str] | Exception: loop = asyncio.get_running_loop() try: addrinfo = await loop.getaddrinfo( None if host == '*' else host, 0, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM, flags=socket.AI_PASSIVE, ) except Exception as e: return e else: return [addr[4][0] for addr in addrinfo] async def _resolve_interfaces( hosts: Sequence[str], ) -> tuple[Sequence[str], bool, bool]: async with asyncio.TaskGroup() as g: resolve_tasks = { host: g.create_task(_resolve_host(host)) for host in hosts } addrs = [] for host, fut in resolve_tasks.items(): result = fut.result() if isinstance(result, Exception): logger.warning( f"could not translate host name {host!r} to address: {result}") else: addrs.extend(result) ( clean_addrs, rejected_addrs, has_ipv4_wc, has_ipv6_wc ) = _cleanup_wildcard_addrs(addrs) if 
rejected_addrs: logger.warning( "wildcard addresses found in listen_addresses; " + "discarding the other addresses: " + ", ".join(repr(h) for h in rejected_addrs) ) return clean_addrs, has_ipv4_wc, has_ipv6_wc ================================================ FILE: edb/server/service_manager.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import Optional import errno import logging import os import socket import sys SD_LISTEN_FDS_START = 3 logger = logging.getLogger('edb.server') def _stream_socket_from_fd(fd: int) -> Optional[socket.socket]: try: sock = socket.socket(fileno=fd) except OSError: logger.warning( f"activation file descriptor {fd} is not a socket" f", ignoring" ) return None if sock.family not in {socket.AF_INET, socket.AF_INET6}: logger.warning( f"activation file descriptor {fd} is not an AF_INET[6] socket" f", ignoring" ) return None if sock.type != socket.SOCK_STREAM: logger.warning( f"activation file descriptor {fd} is not a SOCK_STREAM " f"socket, ignoring" ) return None return sock def sd_notify(message: str) -> None: notify_socket = os.environ.get('NOTIFY_SOCKET') if not notify_socket: return if notify_socket[0] == '@': notify_socket = '\0' + notify_socket[1:] with socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) as sd_sock: try: sd_sock.connect(notify_socket) sd_sock.sendall(message.encode()) except Exception as e: logger.info('Could not send systemd notification: %s', e) def sd_get_activation_listen_sockets() -> dict[str, list[socket.socket]]: # Prevent socket activation variables from being inherited by # child processes (regardless of success below).
listen_pid = os.environ.pop("LISTEN_PID", "") listen_fds = os.environ.pop("LISTEN_FDS", "") listen_fdnames = os.environ.pop("LISTEN_FDNAMES", "") if not listen_pid or not listen_fds: return {} try: expected_pid = int(listen_pid) except ValueError: logger.warning( "the value of LISTEN_PID environment variable " "is not a valid integer, ignoring socket activation data" ) return {} if expected_pid != os.getpid(): logger.warning( "the value of LISTEN_PID does not match the PID of this " "process, ignoring socket activation data" ) return {} try: num_fds = int(listen_fds) except ValueError: logger.warning( "the value of LISTEN_FDS environment variable " "is not a valid integer, ignoring socket activation data" ) return {} fd_names = listen_fdnames.split(":") fd_range = range(SD_LISTEN_FDS_START, SD_LISTEN_FDS_START + num_fds) sockets: dict[str, list[socket.socket]] = {} for i, fd in enumerate(fd_range): os.set_inheritable(fd, False) try: name = fd_names[i] except IndexError: name = f"LISTEN_FD_{fd}" sock = _stream_socket_from_fd(fd) if sock is not None: sockets.setdefault(name, []).append(sock) return sockets if sys.platform == "darwin": import ctypes import ctypes.util syslib = ctypes.CDLL(ctypes.util.find_library('System')) syslib.launch_activate_socket.argtypes = [ # type: ignore[attr-defined] ctypes.c_char_p, ctypes.POINTER(ctypes.POINTER(ctypes.c_int)), ctypes.POINTER(ctypes.c_size_t), ] class LaunchActivateSocketError(Exception): def __init__(self, errno: int) -> None: self.errno = errno def _launch_activate_socket(name) -> list[int]: fds = ctypes.POINTER(ctypes.c_int)() num_fds = ctypes.c_size_t() result = syslib.launch_activate_socket( ctypes.c_char_p(name.encode("utf-8")), ctypes.byref(fds), ctypes.byref(num_fds), ) if result == 0: return [fds[i] for i in range(num_fds.value)] elif result == errno.ESRCH: # Not running under launchd return [] else: raise LaunchActivateSocketError(result) def launchd_get_activation_listen_sockets() -> ( dict[str,
list[socket.socket]] ): names = ["edgedb-server"] sockets: dict[str, list[socket.socket]] = {} for name in names: try: fds = _launch_activate_socket(name) except LaunchActivateSocketError as e: logger.warning( f"could not activate socket {name}: " f"launch_activate_socket() returned {e.errno}") continue for fd in fds: os.set_inheritable(fd, False) sock = _stream_socket_from_fd(fd) if sock is not None: sockets.setdefault(name, []).append(sock) return sockets else: def launchd_get_activation_listen_sockets() -> ( dict[str, list[socket.socket]] ): return {} def get_activation_listen_sockets() -> dict[str, list[socket.socket]]: if sys.platform == "darwin": sockets = launchd_get_activation_listen_sockets() else: sockets = sd_get_activation_listen_sockets() return sockets ================================================ FILE: edb/server/smtp.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2024-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import dataclasses import email.message import asyncio import logging import os import hashlib import pickle import aiosmtplib from typing import Optional from edb.common import retryloop from edb.ir import statypes from edb import errors from . 
import dbview _semaphore: asyncio.BoundedSemaphore | None = None logger = logging.getLogger('edb.server.smtp') @dataclasses.dataclass class SMTPProviderConfig: name: str sender: Optional[str] host: Optional[str] port: Optional[int] username: Optional[str] password: Optional[str] security: str validate_certs: bool timeout_per_email: statypes.Duration timeout_per_attempt: statypes.Duration class SMTP: def __init__(self, db: dbview.Database): current_provider = get_current_email_provider(db) self.sender = current_provider.sender or "noreply@example.com" default_port = ( 465 if current_provider.security == "TLS" else 587 if current_provider.security == "STARTTLS" else 25 ) use_tls: bool start_tls: bool | None match current_provider.security: case "PlainText": use_tls = False start_tls = False case "TLS": use_tls = True start_tls = False case "STARTTLS": use_tls = False start_tls = True case "STARTTLSOrPlainText": use_tls = False start_tls = None case _: raise NotImplementedError host = current_provider.host or "localhost" port = current_provider.port or default_port username = current_provider.username password = current_provider.password validate_certs = current_provider.validate_certs timeout_per_attempt = current_provider.timeout_per_attempt req_timeout = timeout_per_attempt.to_microseconds() / 1_000_000.0 self.timeout_per_email = ( current_provider.timeout_per_email.to_microseconds() / 1_000_000.0 ) self.client = aiosmtplib.SMTP( hostname=host, port=port, username=username, password=password, timeout=req_timeout, use_tls=use_tls, start_tls=start_tls, validate_certs=validate_certs, ) async def send( self, message: email.message.Message, *, test_mode: bool = False, ) -> None: global _semaphore if _semaphore is None: _semaphore = asyncio.BoundedSemaphore( int( os.environ.get( "EDGEDB_SERVER_AUTH_SMTP_CONCURRENCY", os.environ.get("EDGEDB_SERVER_SMTP_CONCURRENCY", 5), ) ) ) # n.b. 
When constructing EmailMessage objects, we don't set the "From" # header since that is configured in the SMTPProviderConfig. However, # the EmailMessage will have the correct "To" header. message["From"] = self.sender rloop = retryloop.RetryLoop( timeout=self.timeout_per_email, backoff=retryloop.exp_backoff(), ignore=( aiosmtplib.SMTPConnectError, aiosmtplib.SMTPHeloError, aiosmtplib.SMTPServerDisconnected, aiosmtplib.SMTPConnectTimeoutError, aiosmtplib.SMTPConnectResponseError, ), ) async for iteration in rloop: async with iteration: async with _semaphore: # Currently we are not reusing SMTP connections, but # ideally we should replace this with a pool of # connections, and drop idle connections after configured # time. if test_mode: self._send_test_mode_email(message) else: logger.info( "Sending SMTP message to " f"{self.client.hostname}:{self.client.port}" ) async with self.client: errors, response = await self.client.send_message( message ) if errors: logger.error( f"SMTP server returned errors: {errors}" ) else: logger.info( f"SMTP message sent successfully: {response}" ) def _send_test_mode_email(self, message: email.message.Message): sender = message["From"] recipients = message["To"] recipients_list: list[str] if isinstance(recipients, str): recipients_list = [recipients] elif recipients is None: recipients_list = [] else: recipients_list = list(recipients) hash_input = f"{sender}{','.join(recipients_list)}" file_name_hash = hashlib.sha256(hash_input.encode()).hexdigest() file_name = f"/tmp/edb-test-email-{file_name_hash}.pickle" test_file = os.environ.get( "EDGEDB_TEST_EMAIL_FILE", file_name, ) if os.path.exists(test_file): os.unlink(test_file) with open(test_file, "wb") as f: logger.info(f"Dumping SMTP message to {test_file}") args = dict( message=message, sender=sender, recipients=recipients, hostname=self.client.hostname, port=self.client.port, username=self.client._login_username, password=self.client._login_password, timeout=self.client.timeout,
use_tls=self.client.use_tls, start_tls=self.client._start_tls_on_connect, validate_certs=self.client.validate_certs, ) pickle.dump(args, f) def get_current_email_provider( db: dbview.Database, ) -> SMTPProviderConfig: current_provider_name = db.lookup_config("current_email_provider_name") if current_provider_name is None: raise errors.ConfigurationError("No email provider configured") found = None objs = ( list(db.lookup_config("email_providers")) + db.tenant._sidechannel_email_configs ) for obj in objs: if obj.name == current_provider_name: values = {} for field in dataclasses.fields(SMTPProviderConfig): key = field.name values[key] = getattr(obj, key) found = SMTPProviderConfig(**values) break if found is None: raise errors.ConfigurationError( f"No email provider named {current_provider_name!r}" ) return found ================================================ FILE: edb/server/tenant.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2023-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import ( Any, Callable, Iterator, Iterable, Mapping, Coroutine, AsyncGenerator, Optional, TypedDict, TYPE_CHECKING, ) import asyncio import contextlib import dataclasses import json import logging import os import pathlib import pickle import struct import sys import textwrap import time import tomllib import uuid import weakref import immutables from edb import buildmeta from edb import errors from edb.common import asyncutil from edb.common import lru from edb.common import retryloop from edb.common import verutils from edb.common.log import current_tenant from . import auth from . import args as srvargs from . import config from . import connpool from . import dbview from . import defines from . import instdata from . import metrics from . import pgcon from . import compiler as edbcompiler from . import pgconnparams from .ha import adaptive as adaptive_ha from .ha import base as ha_base from .http import HttpClient from .pgcon import errors as pgcon_errors from .compiler import enums as compiler_enums if TYPE_CHECKING: from edb.pgsql import params as pgparams from . import pgcluster from . import server as edbserver from . 
import compiler_pool as edbcompiler_pool logger = logging.getLogger("edb.server") HTTP_MAX_CONNECTIONS = 100 HEALTH_CHECK_MIN_INTERVAL: float = float( os.getenv("GEL_BACKEND_HEALTH_CHECK_MIN_INTERVAL", 10) ) HEALTH_CHECK_TIMEOUT: float = float( os.getenv("GEL_BACKEND_HEALTH_CHECK_TIMEOUT", 10) ) class RoleDescriptor(TypedDict): superuser: bool name: str password: str | None all_permissions: list[str] | None branches: list[str] apply_access_policies_pg_default: bool | None class Tenant(ha_base.ClusterProtocol): _server: edbserver.BaseServer _cluster: pgcluster.BaseCluster _tenant_id: str _instance_name: str _instance_data: Mapping[str, str] _dbindex: dbview.DatabaseIndex | None _initing: bool _running: bool _accepting_connections: bool _introspection_locks: weakref.WeakValueDictionary[str, asyncio.Lock] __loop: asyncio.AbstractEventLoop _task_group: asyncio.TaskGroup | None _tasks: set[asyncio.Task] _accept_new_tasks: bool _file_watch_finalizers: list[Callable[[], None]] __sys_pgcon: pgcon.PGConnection | None _sys_pgcon_waiter: asyncio.Lock _sys_pgcon_ready_evt: asyncio.Event _sys_pgcon_reconnect_evt: asyncio.Event _sys_pgcon_last_active_time: float _max_backend_connections: int _suggested_client_pool_size: int _pg_pool: connpool.Pool _pg_unavailable_msg: str | None _init_con_data: list[config.ConState] _init_con_sql: bytes | None _ha_master_serial: int _backend_adaptive_ha: adaptive_ha.AdaptiveHASupport | None _readiness_state_file: pathlib.Path | None _readiness: srvargs.ReadinessState _readiness_reason: str _config_file: pathlib.Path | None _extensions_dirs: tuple[pathlib.Path, ...] # A set of databases that should not accept new connections. _block_new_connections: set[str] _report_config_data: dict[defines.ProtocolVersion, bytes] _roles: Mapping[str, RoleDescriptor] _role_capabilities: Mapping[str, compiler_enums.Capability] _sys_auth: tuple[Any, ...] 
_jwt_sub_allowlist_file: pathlib.Path | None _jwt_sub_allowlist: frozenset[str] | None _jwt_revocation_list_file: pathlib.Path | None _jwt_revocation_list: frozenset[str] | None _http_client: HttpClient | None _sidechannel_email_configs: list[Any] def __init__( self, cluster: pgcluster.BaseCluster, *, instance_name: str, max_backend_connections: int, backend_adaptive_ha: bool = False, extensions_dir: tuple[pathlib.Path, ...] = (), ): self._cluster = cluster self._tenant_id = self.get_backend_runtime_params().tenant_id self._instance_name = instance_name self._instance_data = immutables.Map() self._initing = True self._running = False self._accepting_connections = False self._task_group = None self._tasks = set() self._named_tasks: dict[str, asyncio.Task] = dict() self._accept_new_tasks = False self._file_watch_finalizers = [] self._introspection_locks = weakref.WeakValueDictionary() self._sidechannel_email_configs = [] self._extensions_dirs = extensions_dir # Never use `self.__sys_pgcon` directly; get it via # `async with self.use_sys_pgcon()`. 
self.__sys_pgcon = None self._sys_pgcon_last_active_time = 0 # Increase-only counter to reject outdated attempts to connect self._ha_master_serial = 0 if backend_adaptive_ha: self._backend_adaptive_ha = adaptive_ha.AdaptiveHASupport( self, self._instance_name ) else: self._backend_adaptive_ha = None self._readiness_state_file = None self._readiness = srvargs.ReadinessState.Default self._readiness_reason = "" self._config_file = None self._max_backend_connections = max_backend_connections self._suggested_client_pool_size = max( min( max_backend_connections, defines.MAX_SUGGESTED_CLIENT_POOL_SIZE ), defines.MIN_SUGGESTED_CLIENT_POOL_SIZE, ) self._pg_pool = connpool.Pool( connect=self._pg_connect, disconnect=self._pg_disconnect, # 1 connection is reserved for the system DB max_capacity=max_backend_connections - 1, ) self._pg_unavailable_msg = None self._block_new_connections = set() self._report_config_data = {} self._init_con_data = [] self._init_con_sql = None # DB state will be initialized in init(). self._dbindex = None self._branch_sem = asyncio.Semaphore(value=1) self._roles = immutables.Map() self._role_capabilities = immutables.Map() self._sys_auth = tuple() self._jwt_sub_allowlist_file = None self._jwt_sub_allowlist = None self._jwt_revocation_list_file = None self._jwt_revocation_list = None self._http_client = None # If it isn't stored in instdata, it is the old default. 
self.default_database = defines.EDGEDB_OLD_DEFAULT_DB def set_reloadable_files( self, readiness_state_file: str | pathlib.Path | None = None, jwt_sub_allowlist_file: str | pathlib.Path | None = None, jwt_revocation_list_file: str | pathlib.Path | None = None, config_file: str | pathlib.Path | None = None, ) -> bool: rv = False if isinstance(readiness_state_file, str): readiness_state_file = pathlib.Path(readiness_state_file) if self._readiness_state_file != readiness_state_file: self._readiness_state_file = readiness_state_file rv = True if isinstance(jwt_sub_allowlist_file, str): jwt_sub_allowlist_file = pathlib.Path(jwt_sub_allowlist_file) if self._jwt_sub_allowlist_file != jwt_sub_allowlist_file: self._jwt_sub_allowlist_file = jwt_sub_allowlist_file rv = True if isinstance(jwt_revocation_list_file, str): jwt_revocation_list_file = pathlib.Path(jwt_revocation_list_file) if self._jwt_revocation_list_file != jwt_revocation_list_file: self._jwt_revocation_list_file = jwt_revocation_list_file rv = True if isinstance(config_file, str): config_file = pathlib.Path(config_file) if self._config_file != config_file: self._config_file = config_file rv = True return rv def set_server(self, server: edbserver.BaseServer) -> None: self._server = server self.__loop = server.get_loop() async def load_sidechannel_configs( self, value: Any, *, compiler: ( edbcompiler.Compiler | edbcompiler_pool.AbstractPool | None ) = None, ) -> None: if compiler is None: compiler = self._server.get_compiler_pool() objects = {"cfg::Config": {"email_providers": value}} if isinstance(compiler, edbcompiler.Compiler): result = compiler.compile_structured_config( objects, source="magic", allow_nested=True ) else: result = await compiler.compile_structured_config( objects, source="magic", allow_nested=True ) email_providers = result["cfg::Config"]["email_providers"] self._sidechannel_email_configs = list(email_providers.value) def get_http_client(self, *, originator: str) -> HttpClient: if 
self._http_client is None: http_max_connections = self._server.config_lookup( 'http_max_connections', self.get_sys_config() ) self._http_client = HttpClient( http_max_connections, user_agent=f"EdgeDB {buildmeta.get_version_string(short=True)}", stat_callback=lambda stat: logger.debug( f"HTTP stat: {originator} {stat}" ) ) return self._http_client def on_switch_over(self): # Bumping this serial counter will "cancel" all pending connections # to the old master. self._ha_master_serial += 1 if self._accept_new_tasks: self.create_task( self._pg_pool.prune_all_connections(), interruptable=True, ) if self.__sys_pgcon is None: # Assume a reconnect task is already running, now that we know the # new master is likely ready, let's just give the task a push. self._sys_pgcon_reconnect_evt.set() else: # Brutally close the sys_pgcon to the old master - this should # trigger a reconnect task. self.__sys_pgcon.abort() if self._backend_adaptive_ha is not None: # Switch to FAILOVER if adaptive HA is enabled self._backend_adaptive_ha.set_state_failover( call_on_switch_over=False ) def get_active_pgcon_num(self) -> int: return self._pg_pool.active_conns @property def client_id(self) -> int: return self._cluster.get_client_id() @property def server(self) -> edbserver.BaseServer: return self._server @property def tenant_id(self) -> str: return self._tenant_id @property def suggested_client_pool_size(self) -> int: return self._suggested_client_pool_size def get_pg_dbname(self, dbname: str) -> str: return self._cluster.get_db_name(dbname) def get_pgaddr(self) -> pgconnparams.ConnectionParams: return self._cluster.get_pgaddr() @lru.method_cache def get_backend_runtime_params(self) -> pgparams.BackendRuntimeParams: return self._cluster.get_runtime_params() def get_instance_name(self) -> str: return self._instance_name def get_instance_data(self, key: str) -> str: return self._instance_data[key] def is_online(self) -> bool: return self._readiness is not srvargs.ReadinessState.Offline def 
is_blocked(self) -> bool: return self._readiness is srvargs.ReadinessState.Blocked def is_ready(self) -> bool: return ( self._readiness is srvargs.ReadinessState.Default or self._readiness is srvargs.ReadinessState.ReadOnly ) def is_readonly(self) -> bool: return self._readiness is srvargs.ReadinessState.ReadOnly def get_readiness_reason(self) -> str: return self._readiness_reason def get_sys_config(self) -> Mapping[str, config.SettingValue]: assert self._dbindex is not None return self._dbindex.get_sys_config() def get_report_config_data( self, protocol_version: defines.ProtocolVersion, ) -> bytes: if protocol_version >= (2, 0): return self._report_config_data[(2, 0)] else: return self._report_config_data[(1, 0)] def get_global_schema_pickle(self) -> bytes: assert self._dbindex is not None return self._dbindex.get_global_schema_pickle() def get_db(self, *, dbname: str) -> dbview.Database: assert self._dbindex is not None return self._dbindex.get_db(dbname) def maybe_get_db(self, *, dbname: str) -> dbview.Database | None: assert self._dbindex is not None return self._dbindex.maybe_get_db(dbname) def is_accepting_connections(self) -> bool: return self._accepting_connections and self._accept_new_tasks def get_roles(self) -> Mapping[str, RoleDescriptor]: return self._roles def set_roles(self, roles: Mapping[str, RoleDescriptor]) -> None: self._roles = roles self._refresh_role_capabilities() def get_role_capabilities(self) -> Mapping[str, compiler_enums.Capability]: return self._role_capabilities def _refresh_role_capabilities(self) -> None: role_capabilities: dict[str, compiler_enums.Capability] = {} for name, role_desc in self._roles.items(): superuser = bool(role_desc.get('superuser')) available_permissions = (role_desc.get('all_permissions') or ()) if superuser: capability = compiler_enums.Capability.ALL else: capability = ( compiler_enums.Capability.TRANSACTION | compiler_enums.Capability.SESSION_CONFIG | compiler_enums.Capability.PERSISTENT_CONFIG ) # 
Non-superuser can be given capabilities via # the permissions if 'sys::perm::data_modification' in available_permissions: capability |= compiler_enums.Capability.MODIFICATIONS if 'sys::perm::ddl' in available_permissions: capability |= compiler_enums.Capability.DDL if 'sys::perm::branch_config' in available_permissions: capability |= compiler_enums.Capability.BRANCH_CONFIG if 'sys::perm::sql_session_config' in available_permissions: capability |= compiler_enums.Capability.SQL_SESSION_CONFIG if 'sys::perm::analyze' in available_permissions: capability |= compiler_enums.Capability.ANALYZE role_capabilities[name] = capability self._role_capabilities = immutables.Map(role_capabilities) async def _fetch_roles(self, syscon: pgcon.PGConnection) -> None: role_query = self._server.get_sys_query("roles") json_data = await syscon.sql_fetch_val(role_query, use_prep_stmt=True) roles = json.loads(json_data) self._roles = immutables.Map([(r["name"], r) for r in roles]) self._refresh_role_capabilities() async def init_sys_pgcon(self) -> None: self._sys_pgcon_waiter = asyncio.Lock() self.__sys_pgcon = await self._pg_connect( defines.EDGEDB_SYSTEM_DB, source_description="init_sys_pgcon", ) self._sys_pgcon_last_active_time = time.monotonic() self._sys_pgcon_ready_evt = asyncio.Event() self._sys_pgcon_reconnect_evt = asyncio.Event() async def get_patch_count(self, conn: pgcon.PGConnection) -> int: """Get the number of applied patches.""" num_patches = await instdata.get_instdata( conn, 'num_patches', 'json') res: int = json.loads(num_patches) if num_patches else 0 return res async def _check_metaschema_compatibility( self, con: pgcon.PGConnection ) -> None: from edb.pgsql import patches as pg_patches # Check catalog version result = await instdata.get_instdata( con, 'instancedata', 'json', versioned=False ) catver = json.loads(result).get('catver') if catver != defines.EDGEDB_CATALOG_VERSION: raise errors.ConfigurationError( 'database instance incompatible with this version of Gel', 
details=( f'The database instance was initialized with ' f'Gel format version {catver}, but this version ' f'of the server expects format version ' f'{defines.EDGEDB_CATALOG_VERSION}.' ), hint=( 'You need to either recreate the instance and upgrade ' 'using dump/restore, or do an inplace upgrade.' ) ) # Check patch count num_patches = await self.get_patch_count(con) if num_patches < len(pg_patches.PATCHES): raise errors.ConfigurationError( 'database instance incompatible with this version of Gel', details=f"expected {len(pg_patches.PATCHES)} patches, " f"but only {num_patches} applied", hint="if you are adding an old backend to a multi-tenant " "server, firstly run a new single-tenant server on " "that backend to apply the patches.", ) async def init(self, compat_check: bool = False) -> None: logger.debug("starting database introspection") async with self.use_sys_pgcon() as syscon: if compat_check: await self._check_metaschema_compatibility(syscon) result = await instdata.get_instdata( syscon, 'instancedata', 'json') self._instance_data = immutables.Map(json.loads(result)) await self._fetch_roles(syscon) if self._server.get_compiler_pool() is None: # Parse global schema in I/O process if this is done only once logger.debug("parsing global schema locally") global_schema_pickle = pickle.dumps( await self._server.introspect_global_schema(syscon), -1 ) data = None else: # Multi-tenant server defers the parsing into the compiler data = await self._server.introspect_global_schema_json(syscon) compiler_pool = self._server.get_compiler_pool() default_database = await instdata.get_instdata( syscon, 'default_branch', 'text') if default_database: self.default_database = default_database.decode('utf-8') if data is not None: logger.debug("parsing global schema") global_schema_pickle = ( await compiler_pool.parse_global_schema(data) ) logger.info("loading system config") sys_config = await self._load_sys_config() default_sysconfig = await 
self._load_sys_config("sysconfig_default") await self._load_reported_config() # To make in-place upgrade failures more testable, check # 'force_database_error' with a 'startup' scope. force_error = self._server.config_lookup( 'force_database_error', sys_config) edbcompiler.maybe_force_database_error(force_error, scope='startup') self._dbindex = dbview.DatabaseIndex( self, std_schema=self._server.get_std_schema(), global_schema_pickle=global_schema_pickle, sys_config=sys_config, default_sysconfig=default_sysconfig, sys_config_spec=self._server.config_settings, ) await self._introspect_dbs() await self.load_extension_packages(buildmeta.get_extension_dir_path()) # Allow user-specified too. for dir in self._extensions_dirs: await self.load_extension_packages(dir) # Now, once all DBs have been introspected, start listening on # any notifications about schema/roles/etc changes. assert self.__sys_pgcon is not None await self.__sys_pgcon.listen_for_sysevent() self.__sys_pgcon.mark_as_system_db() self._sys_pgcon_ready_evt.set() self.populate_sys_auth() self.reload_readiness_state() self._initing = False async def load_extension_packages(self, path: pathlib.Path) -> None: exts = [] if self._is_extension_package(path): exts.append(path) else: try: with os.scandir(path) as it: for entry in it: if ( entry.is_dir() and self._is_extension_package(entry) ): exts.append(pathlib.Path(entry)) except FileNotFoundError: pass if not exts: return async with self.use_sys_pgcon() as syscon: from edb.pgsql import trampoline ext_packages_json = await syscon.sql_fetch_val( trampoline.fixup_query(""" SELECT json_agg(o.c) FROM ( SELECT json_build_array(p.name, p.version) AS c FROM edgedb_VER."_SysExtensionPackage" AS p ) AS o; """).encode('utf-8') ) ext_packages = { (name, verutils.from_json(version)) for name, version in json.loads(ext_packages_json) } for ext in exts: await self._load_extension_package(ext, ext_packages) def _is_extension_package(self, path: pathlib.Path | os.DirEntry) -> 
bool: return (pathlib.Path(path) / 'MANIFEST.toml').exists() async def _load_extension_package( self, path: pathlib.Path, ext_packages: set[tuple[str, verutils.Version]], ) -> None: with open(path / 'MANIFEST.toml', 'rb') as m: manifest = tomllib.load(m) name = manifest['name'] version = verutils.parse_version(manifest['version']) if (name, version) in ext_packages: logger.info( f"Extension package '{manifest['name']}' {version} " f"already installed" ) return scripts = [] for file in manifest['files']: with open(path / file, 'rb') as f: scripts.append(f.read().decode('utf-8')) from edb.schema import schema as s_schema async with self.use_sys_pgcon() as syscon: global_schema = await self._server.introspect_global_schema(syscon) compiler = edbcompiler.new_compiler( std_schema=self._server._std_schema, reflection_schema=self._server._refl_schema, schema_class_layout=self._server._schema_class_layout, ) compilerctx = edbcompiler.new_compiler_context( compiler_state=compiler.state, global_schema=global_schema, user_schema=s_schema.EMPTY_SCHEMA, internal_schema_mode=True, # Extension installation only works if stdmode or testmode is # set. Force testmode to be set, since we don't want to set # stdmode, because we want any externally loaded extensions to # be marked as *not* builtin. 
force_testmode=True, ) script = '\n'.join(scripts) _, sql_script = edbcompiler.compile_edgeql_script(compilerctx, script) logger.info( f"Installing extension package '{manifest['name']}'") async with self.use_sys_pgcon() as syscon: await syscon.sql_execute(sql_script.encode('utf-8')) global_schema = await self._server.introspect_global_schema(syscon) assert self._dbindex self._dbindex.update_global_schema(pickle.dumps(global_schema)) def start_watching_files(self): if self._readiness_state_file is not None: def reload_state_file(): self.reload_readiness_state() self._file_watch_finalizers.append( self._server.monitor_fs( self._readiness_state_file, reload_state_file ) ) if self._jwt_sub_allowlist_file is not None: def reload_jwt_sub_allowlist_file(): self.load_jwt_sub_allowlist() self._file_watch_finalizers.append( self._server.monitor_fs( self._jwt_sub_allowlist_file, reload_jwt_sub_allowlist_file ) ) if self._jwt_revocation_list_file is not None: def reload_jwt_revocation_list_file(): self.load_jwt_revocation_list() self._file_watch_finalizers.append( self._server.monitor_fs( self._jwt_revocation_list_file, reload_jwt_revocation_list_file, ) ) if self._config_file is not None: def reload_config_file(): self.reload_config_file.schedule() self._file_watch_finalizers.append( self.server.monitor_fs(self._config_file, reload_config_file) ) async def start_accepting_new_tasks(self) -> None: assert self._task_group is None self._task_group = asyncio.TaskGroup() await self._task_group.__aenter__() self._accept_new_tasks = True await self._cluster.start_watching(self.on_switch_over) def start_running(self) -> None: self._running = True self._accepting_connections = True assert self._dbindex is not None for db in self._dbindex.iter_dbs(): db.start_stop_extensions() def stop_accepting_connections(self) -> None: self._accepting_connections = False @property def accept_new_tasks(self): return self._accept_new_tasks def is_db_ready(self, dbname: str) -> bool: if not 
self._accept_new_tasks:
            return False
        if (
            not (db := self.maybe_get_db(dbname=dbname))
            or not db.is_introspected()
        ):
            return False
        return True

    def create_task(
        self,
        coro: Coroutine,
        *,
        interruptable: bool,
        name: Optional[str] = None,
    ) -> asyncio.Task:
        # Interruptable tasks are regular asyncio tasks that may be interrupted
        # randomly in the middle when the event loop stops; while tasks with
        # interruptable=False are always awaited before the server stops, so
        # that e.g. all finally blocks get a chance to execute in those tasks.
        # Therefore, it is an error to create a task while the server is not
        # expecting one, so always couple the call with an additional check.
        if self._accept_new_tasks and self._task_group is not None:
            current_tenant.set(self.get_instance_name())
            if interruptable:
                rv = self.__loop.create_task(coro, name=name)
            else:
                rv = self._task_group.create_task(coro, name=name)

            # Keep a strong reference of the created Task
            if name is not None:
                if name in self._named_tasks:
                    raise RuntimeError(
                        f"task {name!r} already exists on this server")
                self._named_tasks[name] = rv
                rv.add_done_callback(
                    lambda task: self._named_tasks.pop(task.get_name(), None))
            else:
                self._tasks.add(rv)
                rv.add_done_callback(self._tasks.discard)
            return rv
        else:
            # Hint: add `if tenant.accept_new_tasks` before `.create_task()`
            raise RuntimeError("task cannot be created at this time")

    def get_task(self, name: str) -> Optional[asyncio.Task]:
        return self._named_tasks.get(name)

    def stop(self) -> None:
        self._running = False
        self._accept_new_tasks = False
        self._cluster.stop_watching()
        self._stop_watching_files()
        self._server.request_frontend_stop(self)

    def _stop_watching_files(self):
        while self._file_watch_finalizers:
            self._file_watch_finalizers.pop()()

    async def wait_stopped(self) -> None:
        if self._task_group is not None:
            tg = self._task_group
            self._task_group = None
            await tg.__aexit__(*sys.exc_info())
        await self._pg_pool.close()

    def terminate_sys_pgcon(self) -> None:
        if
self.__sys_pgcon is not None: self.__sys_pgcon.terminate() self.__sys_pgcon = None del self._sys_pgcon_waiter def set_init_con_data(self, data: list[config.ConState]) -> None: self._init_con_data = data self._init_con_sql = None if data: self._init_con_sql = self._make_init_con_sql(data) def _make_init_con_sql(self, data: list[config.ConState]) -> bytes: if not data: return b"" from edb.pgsql import common quoted_json = common.quote_literal(json.dumps(data)) return textwrap.dedent( f''' INSERT INTO _edgecon_state SELECT * FROM jsonb_to_recordset({quoted_json}::jsonb) AS cfg(name text, value jsonb, type text); ''' ).strip().encode() async def _pg_connect( self, dbname: str, source_description: str="pool connection" ) -> pgcon.PGConnection: ha_serial = self._ha_master_serial if self.get_backend_runtime_params().has_create_database: pg_dbname = self.get_pg_dbname(dbname) else: pg_dbname = self.get_pg_dbname(defines.EDGEDB_SUPERUSER_DB) started_at = time.monotonic() try: rv = await self._cluster.connect( source_description=source_description, database=pg_dbname, apply_init_script=True ) if self._server.stmt_cache_size is not None: rv.set_stmt_cache_size(self._server.stmt_cache_size) if self._init_con_sql: await rv.sql_execute(self._init_con_sql) rv.last_init_con_data = self._init_con_data except Exception: metrics.backend_connection_establishment_errors.inc( 1.0, self._instance_name ) raise finally: metrics.backend_connection_establishment_latency.observe( time.monotonic() - started_at, self._instance_name ) if ha_serial == self._ha_master_serial: rv.set_tenant(self) if self._backend_adaptive_ha is not None: self._backend_adaptive_ha.on_pgcon_made( dbname == defines.EDGEDB_SYSTEM_DB ) metrics.total_backend_connections.inc(1.0, self._instance_name) metrics.current_backend_connections.inc(1.0, self._instance_name) return rv else: rv.terminate() raise ConnectionError("connected to outdated Postgres master") async def _pg_disconnect(self, conn: pgcon.PGConnection) -> None: 
        metrics.current_backend_connections.dec(1.0, self._instance_name)
        conn.terminate()

    def get_introspection_lock(
        self,
        dbname: str,
    ) -> asyncio.Lock:
        lock = self._introspection_locks.get(dbname)
        if not lock:
            self._introspection_locks[dbname] = lock = asyncio.Lock()
        return lock

    @contextlib.asynccontextmanager
    async def direct_pgcon(
        self,
        dbname: str,
    ) -> AsyncGenerator[pgcon.PGConnection, None]:
        conn = None
        try:
            conn = await self._pg_connect(
                dbname, source_description="direct_pgcon"
            )
            yield conn
        finally:
            if conn is not None:
                await self._pg_disconnect(conn)

    @contextlib.asynccontextmanager
    async def use_sys_pgcon(self) -> AsyncGenerator[pgcon.PGConnection, None]:
        if not self._initing and not self._running:
            raise RuntimeError("Gel server is not running.")

        await self._sys_pgcon_waiter.acquire()

        if not self._initing and not self._running:
            self._sys_pgcon_waiter.release()
            raise RuntimeError("Gel server is not running.")

        if self.__sys_pgcon is None or not self.__sys_pgcon.is_healthy():
            conn, self.__sys_pgcon = self.__sys_pgcon, None
            if conn is not None:
                self._sys_pgcon_ready_evt.clear()
                conn.abort()
            # We depend on the reconnect on connection_lost() of __sys_pgcon
            await self._sys_pgcon_ready_evt.wait()
            if self.__sys_pgcon is None:
                self._sys_pgcon_waiter.release()
                raise RuntimeError("Cannot acquire pgcon to the system DB.")

        try:
            yield self.__sys_pgcon
        finally:
            if self.__sys_pgcon is not None and self.__sys_pgcon.is_healthy():
                self._sys_pgcon_last_active_time = time.monotonic()
            self._sys_pgcon_waiter.release()

    def set_stmt_cache_size(self, size: int) -> None:
        for conn in self._pg_pool.iterate_connections():
            conn.set_stmt_cache_size(size)

    def on_sys_pgcon_parameter_status_updated(
        self,
        name: str,
        value: str,
    ) -> None:
        try:
            if name == "in_hot_standby" and value == "on":
                # It is strong evidence of failover if the sys_pgcon receives
                # a notification that in_hot_standby is turned on.
self.on_sys_pgcon_failover_signal() except Exception: metrics.background_errors.inc( 1.0, self._instance_name, "on_sys_pgcon_parameter_status_updated" ) raise def on_sys_pgcon_failover_signal(self) -> None: if not self._running: return try: if self._backend_adaptive_ha is not None: # Switch to FAILOVER if adaptive HA is enabled self._backend_adaptive_ha.set_state_failover() elif getattr(self._cluster, "_ha_backend", None) is None: # If the server is not using an HA backend, nor has enabled the # adaptive HA monitoring, we still try to "switch over" by # disconnecting all pgcons if failover signal is received, # allowing reconnection to happen sooner. self.on_switch_over() # Else, the HA backend should take care of calling on_switch_over() except Exception: metrics.background_errors.inc( 1.0, self._instance_name, "on_sys_pgcon_failover_signal" ) raise def on_sys_pgcon_connection_lost(self, exc: Exception | None) -> None: try: if not self._running: # The tenant is shutting down, release all events so that # the waiters if any could continue and exit self._sys_pgcon_ready_evt.set() self._sys_pgcon_reconnect_evt.set() return logger.error( "Connection to the system database is " + ("closed." if exc is None else f"broken! Reason: {exc}") ) self.set_pg_unavailable_msg( "Connection is lost, please check server log for the reason." ) self.__sys_pgcon = None self._sys_pgcon_ready_evt.clear() if self._accept_new_tasks: self.create_task( self._reconnect_sys_pgcon(), interruptable=True ) self.on_pgcon_broken(True) except Exception: metrics.background_errors.inc( 1.0, self._instance_name, "on_sys_pgcon_connection_lost" ) raise async def _reconnect_sys_pgcon(self) -> None: try: conn = None while self._running: # Keep retrying as far as: # 1. This tenant is still running # 2. 
We still cannot connect to the Postgres cluster try: conn = await self._pg_connect( defines.EDGEDB_SYSTEM_DB, source_description="_reconnect_sys_pgcon" ) break except OSError: pass except pgcon_errors.BackendError as e: # Be quiet if the Postgres cluster is still starting up, # or the HA failover is still in progress. # TODO: ERROR_FEATURE_NOT_SUPPORTED should be removed # once PostgreSQL supports SERIALIZABLE in hot standbys if not ( e.code_is(pgcon_errors.ERROR_FEATURE_NOT_SUPPORTED) or e.code_is(pgcon_errors.ERROR_CANNOT_CONNECT_NOW) or e.code_is( pgcon_errors.ERROR_READ_ONLY_SQL_TRANSACTION ) ): logger.error("Failed connecting to the backend: %s", e) if self._running: logger.info("Waiting for the backend to recover") try: # Retry after INTERVAL seconds, unless the event is set # and we can retry immediately after the event. await asyncio.wait_for( self._sys_pgcon_reconnect_evt.wait(), defines.SYSTEM_DB_RECONNECT_INTERVAL, ) # But the event can only skip one INTERVAL. self._sys_pgcon_reconnect_evt.clear() except asyncio.TimeoutError: pass if not self._running: if conn is not None: conn.abort() return assert conn is not None logger.info("Successfully reconnected to the system database.") self.__sys_pgcon = conn self.__sys_pgcon.mark_as_system_db() self._sys_pgcon_last_active_time = time.monotonic() # This await is meant to be after mark_as_system_db() because we # need the pgcon to be able to trigger another reconnect if its # connection is lost during this await. 
await self.__sys_pgcon.listen_for_sysevent() self.set_pg_unavailable_msg(None) finally: self._sys_pgcon_ready_evt.set() def on_pgcon_broken(self, is_sys_pgcon: bool = False) -> None: try: if self._backend_adaptive_ha: self._backend_adaptive_ha.on_pgcon_broken(is_sys_pgcon) except Exception: metrics.background_errors.inc( 1.0, self._instance_name, "on_pgcon_broken" ) raise def on_pgcon_lost(self) -> None: try: if self._backend_adaptive_ha: self._backend_adaptive_ha.on_pgcon_lost() except Exception: metrics.background_errors.inc( 1.0, self._instance_name, "on_pgcon_lost") raise def set_pg_unavailable_msg(self, msg: str | None) -> None: if msg is None or self._pg_unavailable_msg is None: self._pg_unavailable_msg = msg @contextlib.asynccontextmanager async def with_pgcon( self, dbname: str, *, discard: bool=False ) -> AsyncGenerator[pgcon.PGConnection, None]: conn = await self.acquire_pgcon(dbname=dbname) try: yield conn finally: self.release_pgcon(dbname, conn, discard=discard) async def acquire_pgcon(self, dbname: str) -> pgcon.PGConnection: if self._pg_unavailable_msg is not None: raise errors.BackendUnavailableError( "Postgres is not available: " + self._pg_unavailable_msg ) for _ in range(self._pg_pool.max_capacity): conn = await self._pg_pool.acquire(dbname) if not conn.is_healthy(): logger.warning("acquired an unhealthy pgcon; discard now") elif conn.last_init_con_data is not self._init_con_data: try: await conn.sql_execute( pgcon.RESET_STATIC_CFG_SCRIPT + (self._init_con_sql or b'') ) except Exception as e: logger.warning( "failed to update pgcon; discard now: %s", e ) else: conn.last_init_con_data = self._init_con_data return conn else: return conn self._pg_pool.release(dbname, conn, discard=True) else: # This is unlikely to happen, but we defer to the caller to retry # when it does happen raise errors.BackendUnavailableError( "No healthy backend connection available at the moment, " "please try again." 
) def release_pgcon( self, dbname: str, conn: pgcon.PGConnection, *, discard: bool = False, ) -> None: if not conn.is_healthy(): if not discard: logger.warning("Released an unhealthy pgcon; discard now.") discard = True try: self._pg_pool.release(dbname, conn, discard=discard) except Exception: metrics.background_errors.inc( 1.0, self._instance_name, "release_pgcon" ) raise def allow_database_connections(self, dbname: str) -> None: self._block_new_connections.discard(dbname) def is_database_connectable(self, dbname: str) -> bool: return ( self._running and dbname != defines.EDGEDB_TEMPLATE_DB and dbname not in self._block_new_connections ) async def ensure_database_not_connected( self, dbname: str, close_frontend_conns: bool = False ) -> None: if self._dbindex and self._dbindex.count_connections(dbname): if close_frontend_conns: self._server.request_stop_fe_conns(dbname) else: # If there are open Gel connections to the `dbname` DB # just raise the error Postgres would have raised itself. raise errors.ExecutionError( f"database branch {dbname!r} is being accessed by " f"other users" ) self._block_new_connections.add(dbname) rloop = retryloop.RetryLoop( timeout=30.0, ignore=errors.ExecutionError, ) async for iteration in rloop: async with iteration: # Signal adjacent servers to prune their connections to this # database. await self.signal_sysevent( "ensure-database-not-used", dbname=dbname ) # Prune our inactive connections. (Do it in the loop # to help in the close_frontend_conns situation.) 
await self._pg_pool.prune_inactive_connections(dbname) await self._pg_ensure_database_not_connected(dbname) async def _pg_ensure_database_not_connected(self, dbname: str) -> None: async with self.use_sys_pgcon() as pgcon: conns = await pgcon.sql_fetch_col( b""" SELECT row_to_json(pg_stat_activity) FROM pg_stat_activity WHERE datname = $1 """, args=[self.get_pg_dbname(dbname).encode("utf-8")], ) if conns: debug_info = "" if self.server.in_dev_mode() or self.server.in_test_mode(): jconns = [json.loads(conn) for conn in conns] debug_info = ": " + json.dumps(jconns) raise errors.ExecutionError( f"database branch {dbname!r} is being accessed by " f"other users{debug_info}" ) @contextlib.asynccontextmanager async def _with_intro_pgcon( self, dbname: str ) -> AsyncGenerator[pgcon.PGConnection | None, None]: conn = None try: conn = await self.acquire_pgcon(dbname) yield conn except pgcon_errors.BackendError as e: if e.code_is(pgcon_errors.ERROR_INVALID_CATALOG_NAME): # database does not exist (anymore) logger.warning( "Detected concurrently-dropped database branch %s; " "skipping.", dbname, ) if self._dbindex is not None and self._dbindex.has_db(dbname): self._dbindex.unregister_db(dbname) yield None else: raise finally: if conn: self.release_pgcon(dbname, conn) async def _introspect_extensions( self, conn: pgcon.PGConnection ) -> set[str]: from edb.pgsql import trampoline extension_names_json = await conn.sql_fetch_val( trampoline.fixup_query(""" SELECT json_agg(name) FROM edgedb_VER."_SchemaExtension"; """).encode('utf-8'), ) if extension_names_json: extensions = set(json.loads(extension_names_json)) else: extensions = set() return extensions async def _debug_introspect( self, conn: pgcon.PGConnection, global_schema_pickle, ) -> Any: user_schema_json = ( await self._server.introspect_user_schema_json(conn) ) db_config_json = await self._server.introspect_db_config(conn) compiler_pool = self._server.get_compiler_pool() return (await 
compiler_pool.parse_user_schema_db_config(
            user_schema_json, db_config_json, global_schema_pickle,
        )).user_schema_pickle

    async def introspect_db(
        self,
        dbname: str,
        *,
        conn: Optional[pgcon.PGConnection]=None,
        reintrospection: bool=False,
    ) -> None:
        """Use this method to (re-)introspect a DB.

        If the DB is already registered in self._dbindex, its
        schema, config, etc. would simply be updated. If it's missing,
        an entry for it would be created.

        All remote notifications of remote events should use this method
        to refresh the state. Even if the remote event was a simple config
        change, a lot of other events could happen between a remote server
        sending it and us receiving it. E.g. a DB could have been dropped
        and recreated again. It's safer to refresh the entire state than
        refreshing individual components of it. Besides, DDL and
        database-level config modifications are supposed to be rare events.

        This supports passing in a connection to use as well, so that
        we can synchronously introspect on config changes without
        risking deadlock by acquiring two connections at once.
        """
        cm = (
            contextlib.nullcontext(conn) if conn
            else self._with_intro_pgcon(dbname)
        )
        async with cm as conn:
            if not conn:
                return

            # Acquire a per-db lock for doing the introspection, to avoid
            # race conditions where an older introspection might overwrite
            # a newer one.
async with self.get_introspection_lock(dbname): await self._introspect_db( dbname, conn=conn, reintrospection=reintrospection ) async def _introspect_db( self, dbname: str, conn: pgcon.PGConnection, reintrospection: bool, ) -> None: from edb.pgsql import trampoline logger.info("introspecting database '%s'", dbname) assert self._dbindex is not None if db := self._dbindex.maybe_get_db(dbname): cache_mode_val = db.lookup_config('query_cache_mode') else: cache_mode_val = self._dbindex.lookup_config('query_cache_mode') old_cache_mode = config.QueryCacheMode.effective(cache_mode_val) # Introspection user_schema_json = ( await self._server.introspect_user_schema_json(conn) ) reflection_cache_json = await conn.sql_fetch_val( trampoline.fixup_query(""" SELECT json_agg(o.c) FROM ( SELECT json_build_object( 'eql_hash', t.eql_hash, 'argnames', array_to_json(t.argnames) ) AS c FROM ROWS FROM(edgedb_VER._get_cached_reflection()) AS t(eql_hash text, argnames text[]) ) AS o; """).encode('utf-8'), ) reflection_cache = immutables.Map( { r["eql_hash"]: tuple(r["argnames"]) for r in json.loads(reflection_cache_json) } ) backend_ids_json = await conn.sql_fetch_val( trampoline.fixup_query(""" SELECT json_object_agg( "id"::text, json_build_array("backend_id", "name") )::text FROM edgedb_VER."_SchemaType" """).encode('utf-8'), ) backend_ids = json.loads(backend_ids_json) db_config_json = await self._server.introspect_db_config(conn) extensions = await self._introspect_extensions(conn) query_cache: list[tuple[bytes, ...]] | None = None if ( not reintrospection and old_cache_mode is not config.QueryCacheMode.InMemory ): query_cache = await self._load_query_cache(conn) # Analysis compiler_pool = self._server.get_compiler_pool() parsed_db = await compiler_pool.parse_user_schema_db_config( user_schema_json, db_config_json, self.get_global_schema_pickle() ) db = self._dbindex.register_db( dbname, user_schema_pickle=parsed_db.user_schema_pickle, schema_version=parsed_db.schema_version, 
            db_config=parsed_db.database_config,
            reflection_cache=reflection_cache,
            backend_ids=backend_ids,
            extensions=extensions,
            ext_config_settings=parsed_db.ext_config_settings,
            feature_used_metrics=parsed_db.feature_used_metrics,
        )
        db.set_state_serializer(
            parsed_db.protocol_version,
            parsed_db.state_serializer,
        )
        cache_mode = config.QueryCacheMode.effective(
            db.lookup_config('query_cache_mode')
        )
        if query_cache and cache_mode is not config.QueryCacheMode.InMemory:
            db.hydrate_cache(query_cache)
        elif old_cache_mode is not cache_mode:
            logger.info(
                "clearing query cache for database '%s'", dbname)
            await conn.sql_execute(
                b'SELECT edgedb._clear_query_cache()')
            assert self._dbindex
            self._dbindex.get_db(dbname).clear_query_cache()

    async def _early_introspect_db(self, dbname: str) -> None:
        """We need to always introspect the extensions for each database.

        Otherwise, we won't know to accept connections for graphql or
        http, for example, until a native connection is made.
        """
        current_tenant.set(self.get_instance_name())
        logger.info("introspecting extensions for database '%s'", dbname)

        async with self._with_intro_pgcon(dbname) as conn:
            if not conn:
                return

            assert self._dbindex is not None
            if not self._dbindex.has_db(dbname):
                extensions = await self._introspect_extensions(conn)
                # Re-check in case we have a concurrent introspection task.
                if not self._dbindex.has_db(dbname):
                    self._dbindex.register_db(
                        dbname,
                        user_schema_pickle=None,
                        schema_version=None,
                        db_config=None,
                        reflection_cache=None,
                        backend_ids=None,
                        extensions=extensions,
                        ext_config_settings=None,
                        early=True,
                    )

        # Early introspection runs *before* we start accepting tasks.
        # This means that if we are one of multiple frontends, and we
        # get an ensure-database-not-used message, we aren't able to
        # handle it. This can result in us hanging onto a connection
        # that another frontend wants to get rid of.
        #
        # We still want to use the pool, though, since it limits our
        # connections in the way we want.
# # Hack around this by pruning the connection ourself. await self._pg_pool.prune_inactive_connections(dbname) async def _introspect_dbs(self) -> None: async with self.use_sys_pgcon() as syscon: dbnames = await self._server.get_dbnames(syscon) async with asyncio.TaskGroup() as g: for dbname in dbnames: # There's a risk of the DB being dropped by another server # between us building the list of databases and loading # information about them. g.create_task(self._early_introspect_db(dbname)) async def _load_reported_config(self) -> None: async with self.use_sys_pgcon() as syscon: try: data = await syscon.sql_fetch_val( self._server.get_sys_query("report_configs"), use_prep_stmt=True, state=b'[]', # clear _config_cache ) for ( protocol_ver, typedesc, ) in self._server.get_report_config_typedesc().items(): self._report_config_data[protocol_ver] = ( struct.pack("!L", len(typedesc)) + typedesc + struct.pack("!L", len(data)) + data ) except Exception: metrics.background_errors.inc( 1.0, self._instance_name, "load_reported_config" ) raise async def _load_sys_config( self, query_name: str = "sysconfig", syscon: pgcon.PGConnection | None = None, ) -> Mapping[str, config.SettingValue]: query = self._server.get_sys_query(query_name) if syscon is None: async with self.use_sys_pgcon() as syscon: sys_config_json = await syscon.sql_fetch_val(query) else: sys_config_json = await syscon.sql_fetch_val(query) return config.from_json(self._server.config_settings, sys_config_json) async def _reintrospect_global_schema(self) -> None: if not self._initing and not self._running: logger.warning( "global-schema-changes event received during shutdown; " "ignoring." 
) return async with self.use_sys_pgcon() as syscon: data = await self._server.introspect_global_schema_json(syscon) await self._fetch_roles(syscon) compiler_pool = self._server.get_compiler_pool() global_schema_pickle = await compiler_pool.parse_global_schema(data) assert self._dbindex is not None self._dbindex.update_global_schema(global_schema_pickle) def populate_sys_auth(self) -> None: assert self._dbindex is not None cfg = self._dbindex.get_sys_config() auth = self._server.config_lookup("auth", cfg) or () self._sys_auth = tuple(sorted(auth, key=lambda a: a.priority)) def resolve_branch_name( self, database: str | None, branch: str | None ) -> str: default = self.default_database if branch == '__default__': return default elif branch is not None: return branch elif ( database == defines.EDGEDB_OLD_DEFAULT_DB and self.maybe_get_db(dbname=defines.EDGEDB_OLD_DEFAULT_DB) is None ): return default else: assert database is not None return database def resolve_user_name(self, user: str) -> str: if ( user == defines.EDGEDB_OLD_SUPERUSER and user not in self.get_roles() ): return defines.EDGEDB_SUPERUSER else: return user async def get_auth_methods( self, user: str, transport: srvargs.ServerConnTransport, ) -> list[config.CompositeConfigType]: authlist = self._sys_auth methods = [] if authlist: for auth in authlist: match = (user in auth.user or "*" in auth.user) and ( not auth.method.transports or transport in auth.method.transports ) if match: methods.append(auth.method) break if not methods: methods = self._server.get_default_auth_methods(transport) return methods async def new_dbview( self, *, dbname: str, query_cache: bool, protocol_version: defines.ProtocolVersion, role_name: str, ) -> dbview.DatabaseConnectionView: db = self.get_db(dbname=dbname) await db.introspection() assert self._dbindex is not None return self._dbindex.new_view( dbname, query_cache=query_cache, protocol_version=protocol_version, role_name=role_name, ) def remove_dbview(self, dbview_: 
dbview.DatabaseConnectionView) -> None: assert self._dbindex is not None return self._dbindex.remove_view(dbview_) def schedule_reported_config_if_needed(self, setting_name: str) -> None: setting = self._server.config_settings.get(setting_name) if setting and setting.report and self._accept_new_tasks: self.create_task(self._load_reported_config(), interruptable=True) def load_jwcrypto(self, jwk_key: auth.JWKSet) -> None: self._jws_key = jwk_key self.load_jwt_sub_allowlist() self.load_jwt_revocation_list() def load_jwt_sub_allowlist(self) -> None: if self._jwt_sub_allowlist_file is not None: logger.info( "(re-)loading JWT subject allowlist from " f'"{self._jwt_sub_allowlist_file}"' ) try: self._jwt_sub_allowlist = frozenset( self._jwt_sub_allowlist_file.read_text().splitlines(), ) if self._jws_key is not None: self._jws_key.default_validation_context.allow( "sub", self._jwt_sub_allowlist ) else: from . import server as edbserver raise edbserver.StartupError( "cannot load JWT sub allowlist: no secret key" ) except Exception as e: from . import server as edbserver raise edbserver.StartupError( f"cannot load JWT sub allowlist: {e}" ) from e def load_jwt_revocation_list(self) -> None: if self._jwt_revocation_list_file is not None: logger.info( "(re-)loading JWT revocation list from " f'"{self._jwt_revocation_list_file}"' ) try: self._jwt_revocation_list = frozenset( self._jwt_revocation_list_file.read_text().splitlines(), ) if self._jws_key is not None: self._jws_key.default_validation_context.deny( "jti", self._jwt_revocation_list ) else: from . import server as edbserver raise edbserver.StartupError( "cannot load JWT revocation list: no secret key" ) except Exception as e: from . 
import server as edbserver raise edbserver.StartupError( f"cannot load JWT revocation list: {e}" ) from e def reload_readiness_state(self) -> None: if self._readiness_state_file is None: return try: with self._readiness_state_file.open("rt") as rt: line = rt.readline().strip() try: state, _, reason = line.partition(":") self._readiness = srvargs.ReadinessState(state) self._readiness_reason = reason logger.info( "readiness state file changed, " "setting server readiness to %r%s", state, f" ({reason})" if reason else "", ) except ValueError: logger.warning( "invalid state in readiness state file (%r): %r, " "resetting server readiness to 'default'", self._readiness_state_file, state, ) self._readiness = srvargs.ReadinessState.Default except FileNotFoundError: logger.info( "readiness state file (%s) removed, resetting " "server readiness to 'default'", self._readiness_state_file, ) self._readiness = srvargs.ReadinessState.Default except Exception as e: logger.warning( "cannot read readiness state file (%s): %s, " "resetting server readiness to 'default'", self._readiness_state_file, e, ) self._readiness = srvargs.ReadinessState.Default self._accepting_connections = self.is_online() def set_readiness_state(self, state: srvargs.ReadinessState, reason: str): self._readiness = state self._readiness_reason = reason @asyncutil.exclusive_task async def reload_config_file(self): if self._config_file is None: return try: await self._reload_config_file() except Exception: logger.error("failed to reload config file", exc_info=True) metrics.background_errors.inc( 1.0, self._instance_name, "reload_config_file" ) async def load_config_file(self, compiler): logger.info("loading config file") # Read the TOML file with self._config_file.open('rb') as f: toml_data = tomllib.load(f) # Handle special case for `magic_smtp_config` magic_smtp_config = toml_data.pop("magic_smtp_config", None) if magic_smtp_config: await self.load_sidechannel_configs( magic_smtp_config, compiler=compiler ) # 
Parse TOML config file content into JSON if toml_data and toml_data.get("cfg::Config"): result = compiler.compile_structured_config( toml_data, "configuration file" ) if asyncio.iscoroutine(result): result = await result def setting_filter(value: config.SettingValue) -> bool: if self._server.config_settings[value.name].backend_setting: raise errors.ConfigurationError( f"backend config {value.name!r} cannot be set " f"via config file" ) return True json_obj = config.to_json_obj( self._server.config_settings, result["cfg::Config"], include_source=False, setting_filter=setting_filter, ) config_file_data = [ { "name": name, "value": value, "type": config.ConStateType.config_file, } for name, value in json_obj.items() ] else: config_file_data = [] # Update init_con_data and SQL self.set_init_con_data( [ cs for cs in self._init_con_data if cs["type"] != config.ConStateType.config_file ] + config_file_data ) async def _reload_config_file(self): # Load TOML config file compiler = self._server.get_compiler_pool() await self.load_config_file(compiler) # Update sys pgcon and reload system config async with self.use_sys_pgcon() as syscon: if syscon.last_init_con_data is not self._init_con_data: await syscon.sql_execute( pgcon.RESET_STATIC_CFG_SCRIPT + (self._init_con_sql or b'') ) syscon.last_init_con_data = self._init_con_data sys_config = await self._load_sys_config(syscon=syscon) # GOTCHA: no need to notify other EdgeDBs on the same backend about # such change to sysconfig, because static config is instance-local self._dbindex.update_sys_config(sys_config) def reload(self): # In multi-tenant mode, the file paths for the following states may be # unset in a reload, while it's impossible in a regular server. # Therefore, we are clearing the states here first, rather than doing # so in reload_readiness_state() or load_jwcrypto(). 
self._readiness = srvargs.ReadinessState.Default self._jwt_sub_allowlist = None self._jwt_revocation_list = None # Re-add the fs watchers in case the path changed self._stop_watching_files() self.reload_readiness_state() self.load_jwcrypto() self.reload_config_file.schedule() self.start_watching_files() async def on_before_drop_db( self, dbname: str, current_dbname: str, close_frontend_conns: bool = False, ) -> None: if current_dbname == dbname: raise errors.ExecutionError( f"cannot drop the currently open database branch {dbname!r}" ) await self.ensure_database_not_connected( dbname, close_frontend_conns=close_frontend_conns ) async def on_before_create_db_from_template( self, dbname: str, current_dbname: str, mode: str ) -> None: # Make sure the database exists. # TODO: Is it worth producing a nicer error message if it # fails on the backside? (Because of a race?) self.get_db(dbname=dbname) if mode == 'TEMPLATE': await self.ensure_database_not_connected(dbname) async def on_after_create_db_from_template( self, tgt_dbname: str, src_dbname: str, mode: str ) -> None: if mode == 'TEMPLATE': self.allow_database_connections(tgt_dbname) return logger.info('Starting copy from %s to %s', src_dbname, tgt_dbname) from edb.pgsql import common from . import bootstrap # noqa: F402 real_tgt_dbname = common.get_database_backend_name( tgt_dbname, tenant_id=self._tenant_id) real_src_dbname = common.get_database_backend_name( src_dbname, tenant_id=self._tenant_id) # HACK: Limit the maximum number of in-flight branch # creations. This is because branches use up to 3 concurrent # connections (one direct, two via pg_dump/pg_restore), and so # it can substantially blow our budget if many are in flight. # The right way to handle this issue would probably be to use # the connection pool to reserve the connections, but we would # need to carefully consider deadlock concerns if we want to # allow tasks to acquire multiple pool connections. 
async with self._branch_sem: async with self.direct_pgcon(tgt_dbname) as con: await bootstrap.create_branch( self._cluster, self._server._refl_schema, con, real_src_dbname, real_tgt_dbname, mode, self._server._sys_queries['backend_id_fixup'], ) logger.info('Finished copy from %s to %s', src_dbname, tgt_dbname) def on_after_drop_db(self, dbname: str) -> None: try: assert self._dbindex is not None if self._dbindex.has_db(dbname): self._dbindex.unregister_db(dbname) self._block_new_connections.discard(dbname) except Exception: metrics.background_errors.inc( 1.0, self._instance_name, "on_after_drop_db" ) raise async def ping_backend(self) -> bool: if not self._running: return False elapsed = time.monotonic() - self._sys_pgcon_last_active_time if elapsed > HEALTH_CHECK_MIN_INTERVAL: async with asyncio.timeout(HEALTH_CHECK_TIMEOUT): async with self.use_sys_pgcon() as syscon: await syscon.sql_fetch_val(b"select 'OK'") return True async def cancel_pgcon_operation(self, con: pgcon.PGConnection) -> bool: async with self.use_sys_pgcon() as syscon: if con.idle: # con could have received the query results while we # were acquiring a system connection to cancel it. return False if con.is_cancelling(): # Somehow the connection is already being cancelled and # we don't want to have two cancellations go in parallel. return False con.start_pg_cancellation() try: # Returns True if the `pid` exists and it was able to send it a # SIGINT. Will throw an exception if the privileges aren't # sufficient.
result = await syscon.sql_fetch_val( f"SELECT pg_cancel_backend({con.backend_pid});".encode(), ) finally: con.finish_pg_cancellation() return result == b"\x01" async def cancel_and_discard_pgcon( self, con: pgcon.PGConnection, dbname: str, ) -> None: try: if self._running: await self.cancel_pgcon_operation(con) finally: self.release_pgcon(dbname, con, discard=True) async def signal_sysevent(self, event: str, **kwargs) -> None: try: if not self._initing and not self._running: # This is very likely if we are doing # "run_startup_script_and_exit()", but is also possible if the # tenant was shut down with this coroutine as a background task # in flight. return async with self.use_sys_pgcon() as con: await con.signal_sysevent(event, **kwargs) except Exception: metrics.background_errors.inc( 1.0, self._instance_name, "signal_sysevent" ) raise def on_remote_database_quarantine(self, dbname: str) -> None: if not self._accept_new_tasks: return # Block new connections to the database. self._block_new_connections.add(dbname) async def task(): try: await self._pg_pool.prune_inactive_connections(dbname) except Exception: metrics.background_errors.inc( 1.0, self._instance_name, "remote_db_quarantine" ) raise self.create_task(task(), interruptable=True) def on_remote_ddl(self, dbname: str) -> None: if not self.is_db_ready(dbname): return # Triggered by a postgres notification event 'schema-changes' # on the __edgedb_sysevent__ channel async def task(): try: await self.introspect_db(dbname) except Exception: metrics.background_errors.inc( 1.0, self._instance_name, "on_remote_ddl" ) raise self.create_task(task(), interruptable=True) def on_remote_database_changes(self) -> None: if not self._accept_new_tasks: return # Triggered by a postgres notification event 'database-changes' # on the __edgedb_sysevent__ channel async def task(): async with self.use_sys_pgcon() as syscon: dbnames = set(await self._server.get_dbnames(syscon)) tg = asyncio.TaskGroup() async with tg as g: for dbname 
in dbnames: if not self._dbindex.has_db(dbname): g.create_task(self._early_introspect_db(dbname)) dropped = [] for db in self._dbindex.iter_dbs(): if db.name not in dbnames: dropped.append(db.name) for dbname in dropped: self.on_after_drop_db(dbname) self.create_task(task(), interruptable=True) def on_remote_database_config_change(self, dbname: str) -> None: if not self._accept_new_tasks: return # Triggered by a postgres notification event 'database-config-changes' # on the __edgedb_sysevent__ channel async def task(): try: await self.introspect_db(dbname, reintrospection=True) except Exception: metrics.background_errors.inc( 1.0, self._instance_name, "on_remote_database_config_change", ) raise self.create_task(task(), interruptable=True) async def process_local_database_config_change( self, conn: pgcon.PGConnection, dbname: str, ) -> None: # It's easier and safer to just do full re-introspection # of the DB and update all components of it. # TODO: Can we just do config? await self.introspect_db(dbname, conn=conn, reintrospection=True) def on_remote_system_config_change(self) -> None: if not self._accept_new_tasks: return # Triggered by a postgres notification event 'system-config-changes' # on the __edgedb_sysevent__ channel async def task(): try: cfg = await self._load_sys_config() self._dbindex.update_sys_config(cfg) self._server.reinit_idle_gc_collector() except Exception: metrics.background_errors.inc( 1.0, self._instance_name, "on_remote_system_config_change" ) raise self.create_task(task(), interruptable=True) def on_global_schema_change(self) -> None: if not self._accept_new_tasks: return async def task(): try: await self._reintrospect_global_schema() except Exception: metrics.background_errors.inc( 1.0, self._instance_name, "on_global_schema_change" ) raise self.create_task(task(), interruptable=True) async def _load_query_cache( self, conn: pgcon.PGConnection, keys: Optional[Iterable[uuid.UUID]] = None, ) -> list[tuple[bytes, ...]] | None: if keys is 
None: return await conn.sql_fetch( b''' SELECT "schema_version", "input", "output" FROM "edgedb"."_query_cache" ''', use_prep_stmt=True, ) else: # If keys were specified, just load those keys. # TODO: Or should we do something time based? return await conn.sql_fetch( b''' SELECT "schema_version", "input", "output" ROWS FROM json_array_elements($1) j(ikey) INNER JOIN "edgedb"."_query_cache" ON (to_jsonb(ARRAY[ikey])->>0)::uuid = key ''', args=(json.dumps(keys).encode('utf-8'),), use_prep_stmt=True, ) async def evict_query_cache( self, dbname: str, keys: Iterable[uuid.UUID], ) -> None: try: async with self._with_intro_pgcon(dbname) as conn: if not conn: return for key in keys: await conn.sql_fetch( b'SELECT "edgedb"."_evict_query_cache"($1)', args=(key.bytes,), use_prep_stmt=True, ) except Exception: logger.exception("error in evict_query_cache():") metrics.background_errors.inc( 1.0, self._instance_name, "evict_query_cache" ) def on_remote_query_cache_change( self, dbname: str, to_add: Optional[list[str]], to_invalidate: Optional[list[str]], ) -> None: if not self.is_db_ready(dbname): return if to_invalidate: if db := self.maybe_get_db(dbname=dbname): db.invalidate_cache_entries( [uuid.UUID(s) for s in to_invalidate] ) async def task(): try: async with self._with_intro_pgcon(dbname) as conn: if not conn: return query_cache = await self._load_query_cache( conn, keys=to_add ) if query_cache and (db := self.maybe_get_db(dbname=dbname)): db.hydrate_cache(query_cache) except Exception: logger.exception("error in on_remote_query_cache_change():") metrics.background_errors.inc( 1.0, self._instance_name, "on_remote_query_cache_change" ) raise # If neither to_add nor to_invalidate are specified, then we do # a full introspection. if to_add or not to_invalidate: self.create_task(task(), interruptable=True) def get_debug_info(self) -> dict[str, Any]: from . 
import smtp pgaddr = self.get_pgaddr() pgaddr.clear_server_settings() pgdict = pgaddr.__dict__ del pgdict['database'] pgaddr.__dict__ = pgdict obj = dict( params=dict( max_backend_connections=self._max_backend_connections, suggested_client_pool_size=self._suggested_client_pool_size, tenant_id=self._tenant_id, ), instance_config=config.debug_serialize_config( self.get_sys_config()), user_roles=self._roles, pg_addr=dict( server_settings=vars(self._cluster.get_connection_params()), dsn=pgaddr.to_dsn(), ), pg_pool=self._pg_pool._build_snapshot(now=time.monotonic()), ) dbs = {} if self._dbindex is not None: for db in self._dbindex.iter_dbs(): if db.name in defines.EDGEDB_SPECIAL_DBS: continue try: email_provider = dataclasses.asdict( smtp.get_current_email_provider(db) ) except errors.ConfigurationError: email_provider = None dbs[db.name] = dict( name=db.name, dbver=db.dbver, config=( None if db.db_config is None else config.debug_serialize_config(db.db_config) ), extensions=sorted(db.extensions), query_cache_size=db.get_query_cache_size(), connections=[ dict( in_tx=view.in_tx(), in_tx_error=view.in_tx_error(), config=config.debug_serialize_config( view.get_session_config()), module_aliases=view.get_modaliases(), ) for view in db.iter_views() ], current_email_provider=email_provider, ) obj["databases"] = dbs return obj def get_compiler_args(self) -> dict[str, Any]: assert self._dbindex is not None return {"dbindex": self._dbindex} def iter_dbs(self) -> Iterator[dbview.Database]: if self._dbindex is not None: yield from self._dbindex.iter_dbs() # sentinel Tenant object to indicate an empty SNI host_tenant = Tenant.__new__(Tenant) ================================================ FILE: edb/testbase/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2020-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ================================================ FILE: edb/testbase/asyncutils.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2019-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import unittest try: import async_solipsism except ImportError: async_solipsism = None # type: ignore def with_fake_event_loop(f): # async_solipsism creates an event loop with, among other things, # a totally fake clock which starts at 0.
def new(*args, **kwargs): if not async_solipsism: raise unittest.SkipTest('async_solipsism is missing') loop = async_solipsism.EventLoop() try: loop.run_until_complete(f(*args, **kwargs)) finally: loop.close() return new ================================================ FILE: edb/testbase/cluster.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations from typing import Any, Optional, Mapping, TYPE_CHECKING import asyncio import json import os import pathlib import socket import subprocess import sys import tempfile import time from edb import buildmeta from edb.common import devmode from edb.edgeql import quote from edb.server import auth from edb.server import args as edgedb_args from edb.server import defines as edgedb_defines from edb.server import pgcluster from edb.server import pgconnparams if TYPE_CHECKING: from edb.server import pgcon class ClusterError(Exception): pass class BaseCluster: def __init__( self, runstate_dir: pathlib.Path, *, port: int = edgedb_defines.EDGEDB_PORT, env: Optional[Mapping[str, str]] = None, testmode: bool = False, log_level: Optional[str] = None, security: Optional[ edgedb_args.ServerSecurityMode ] = None, http_endpoint_security: Optional[ edgedb_args.ServerEndpointSecurityMode ] = None, compiler_pool_mode: Optional[ edgedb_args.CompilerPoolMode ] = None, net_worker_mode: Optional[ edgedb_args.NetWorkerMode ] = None, ): self._edgedb_cmd = [sys.executable, '-I', '-m', 'edb.server.main'] if "EDGEDB_SERVER_MULTITENANT_CONFIG_FILE" not in os.environ: self._edgedb_cmd.append('--instance-name=localtest') self._edgedb_cmd.append('--tls-cert-mode=generate_self_signed') self._edgedb_cmd.append('--jose-key-mode=generate') if log_level: self._edgedb_cmd.extend(['--log-level', log_level]) compiler_addr = os.getenv('EDGEDB_TEST_REMOTE_COMPILER') if compiler_addr: compiler_pool_mode = edgedb_args.CompilerPoolMode.Remote self._edgedb_cmd.extend( [ '--compiler-pool-addr', compiler_addr, ] ) if devmode.is_in_dev_mode(): self._edgedb_cmd.append('--devmode') if testmode: self._edgedb_cmd.append('--testmode') if security: self._edgedb_cmd.extend(( '--security', str(security), )) if http_endpoint_security: self._edgedb_cmd.extend(( '--http-endpoint-security', str(http_endpoint_security), )) if compiler_pool_mode is not None: self._edgedb_cmd.extend(( 
'--compiler-pool-mode', str(compiler_pool_mode), )) if net_worker_mode is not None: self._edgedb_cmd.extend(( '--net-worker-mode', str(net_worker_mode), )) self._log_level = log_level self._runstate_dir = runstate_dir self._edgedb_cmd.extend(['--runstate-dir', str(runstate_dir)]) self._pg_cluster: Optional[pgcluster.BaseCluster] = None self._pg_connect_args: pgconnparams.CreateParamsKwargs = {} self._daemon_process: Optional[subprocess.Popen[str]] = None self._port = port self._effective_port = None self._tls_cert_file = None self._env = env async def _get_pg_cluster(self) -> pgcluster.BaseCluster: if self._pg_cluster is None: self._pg_cluster = await self._new_pg_cluster() return self._pg_cluster async def _new_pg_cluster(self) -> pgcluster.BaseCluster: raise NotImplementedError() async def get_status(self) -> str: pg_cluster = await self._get_pg_cluster() pg_status = await pg_cluster.get_status() initially_stopped = pg_status == 'stopped' if initially_stopped: await pg_cluster.start() elif pg_status == 'not-initialized': return 'not-initialized' conn = None try: conn = await pg_cluster.connect( source_description=f"{self.__class__.__name__}.get_status", **self._pg_connect_args, ) db_exists = await self._edgedb_template_exists(conn) finally: if conn is not None: conn.terminate() await asyncio.sleep(0) if initially_stopped: await pg_cluster.stop() if initially_stopped: return 'stopped' if db_exists else 'not-initialized,stopped' else: return 'running' if db_exists else 'not-initialized,running' def get_connect_args(self) -> dict[str, Any]: return { 'host': 'localhost', 'port': self._effective_port, 'tls_ca_file': self._tls_cert_file, } async def init( self, *, server_settings: Optional[Mapping[str, str]] = None, ) -> None: cluster_status = await self.get_status() if not cluster_status.startswith('not-initialized'): raise ClusterError('cluster has already been initialized') self._init() async def start( self, wait: int=60, *, port: Optional[int] = None, **settings: 
Any, ) -> None: if port is None: port = self._port if port == 0: cmd_port = 'auto' else: cmd_port = str(port) extra_args = ['--{}={}'.format(k.replace('_', '-'), v) for k, v in settings.items()] extra_args.append(f'--port={cmd_port}') status_r = status_w = None if port == 0: status_r, status_w = socket.socketpair() extra_args.append(f'--emit-server-status=fd://{status_w.fileno()}') env: Optional[dict[str, str]] if self._env: env = os.environ.copy() env.update(self._env) else: env = None self._daemon_process = subprocess.Popen( self._edgedb_cmd + extra_args, env=env, text=True, pass_fds=(status_w.fileno(),) if status_w is not None else (), ) if status_w is not None: status_w.close() try: await self._wait_for_server(timeout=wait, status_sock=status_r) except Exception: self.stop() raise def stop(self, wait: int = 60) -> None: if (self._daemon_process is not None and self._daemon_process.returncode is None): self._daemon_process.terminate() self._daemon_process.wait(wait) def destroy(self) -> None: if self._pg_cluster is not None: self._pg_cluster.destroy() def _init(self) -> None: env: Optional[dict[str, str]] if self._env: env = os.environ.copy() env.update(self._env) else: env = None init = subprocess.run( self._edgedb_cmd + ['--bootstrap-only'], stdout=sys.stdout, stderr=sys.stderr, env=env) if init.returncode != 0: raise ClusterError( f'edgedb-server --bootstrap-only failed with ' f'exit code {init.returncode}') async def _edgedb_template_exists( self, conn: pgcon.PGConnection, ) -> bool: return await conn.sql_fetch_val( b"SELECT True FROM pg_catalog.pg_database WHERE datname = $1", args=[edgedb_defines.EDGEDB_TEMPLATE_DB.encode("utf-8")], ) is not None async def _wait_for_server( self, timeout: float = 30.0, status_sock: Optional[socket.socket] = None, ) -> None: async def _read_server_status( stream: asyncio.StreamReader, ) -> dict[str, Any]: while True: line = await stream.readline() if not line: raise ClusterError("Gel server terminated") if 
line.startswith(b'READY='): break _, _, dataline = line.decode().partition('=') try: return json.loads(dataline) # type: ignore except Exception as e: raise ClusterError( f"Gel server returned invalid status line: " f"{dataline!r} ({e})" ) async def test() -> None: stat_reader, stat_writer = await asyncio.open_connection( sock=status_sock, ) try: data = await asyncio.wait_for( _read_server_status(stat_reader), timeout=timeout ) except asyncio.TimeoutError: raise ClusterError( f'Gel server did not initialize ' f'within {timeout} seconds' ) from None self._effective_port = data['port'] self._tls_cert_file = data['tls_cert_file'] stat_writer.close() left = timeout if status_sock is not None: started = time.monotonic() await test() left -= (time.monotonic() - started) if res := self._admin_query( "SELECT ();", f"{max(1, int(left))}s", check=False, ): raise ClusterError( f'could not connect to edgedb-server ' f'within {timeout} seconds (exit code = {res})') from None def _admin_query( self, query: str, wait_until_available: str = "0s", check: bool=True, ) -> int: args = [ "gel", "query", "--unix-path", str(os.path.abspath(self._runstate_dir)), "--port", str(self._effective_port), "--admin", "--user", edgedb_defines.EDGEDB_SUPERUSER, "--branch", edgedb_defines.EDGEDB_SUPERUSER_DB, "--wait-until-available", wait_until_available, query, ] res = subprocess.run( args=args, check=check, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, ) return res.returncode async def set_test_config(self) -> None: # Set session_idle_transaction_timeout to 5 minutes. 
self._admin_query(''' CONFIGURE INSTANCE SET session_idle_transaction_timeout := '5 minutes'; ''') # And disable session_idle_timeout self._admin_query(''' CONFIGURE INSTANCE SET session_idle_timeout := '0 seconds'; ''') async def set_superuser_password(self, password: str) -> None: self._admin_query(f''' ALTER ROLE {edgedb_defines.EDGEDB_SUPERUSER} SET password := {quote.quote_literal(password)} ''') async def trust_local_connections(self) -> None: self._admin_query(''' CONFIGURE INSTANCE INSERT Auth { priority := 0, method := (INSERT Trust), } ''') def has_create_database(self) -> bool: return True def has_create_role(self) -> bool: return True class Cluster(BaseCluster): def __init__( self, data_dir: pathlib.Path, *, pg_superuser: str = 'postgres', port: int = edgedb_defines.EDGEDB_PORT, runstate_dir: Optional[pathlib.Path] = None, env: Optional[Mapping[str, str]] = None, testmode: bool = False, log_level: Optional[str] = None, security: Optional[ edgedb_args.ServerSecurityMode ] = None, http_endpoint_security: Optional[ edgedb_args.ServerEndpointSecurityMode ] = None, compiler_pool_mode: Optional[ edgedb_args.CompilerPoolMode ] = None, ) -> None: self._data_dir = data_dir if runstate_dir is None: runstate_dir = buildmeta.get_runstate_path(self._data_dir) super().__init__( runstate_dir, port=port, env=env, testmode=testmode, log_level=log_level, security=security, http_endpoint_security=http_endpoint_security, compiler_pool_mode=compiler_pool_mode, ) self._edgedb_cmd.extend(['-D', str(self._data_dir)]) self._pg_connect_args['user'] = pg_superuser self._pg_connect_args['database'] = 'template1' self._jws_key: Optional[auth.JWKSet] = None async def _new_pg_cluster(self) -> pgcluster.Cluster: return await pgcluster.get_local_pg_cluster( self._data_dir, runstate_dir=self._runstate_dir, log_level=self._log_level, ) def get_data_dir(self) -> pathlib.Path: return self._data_dir async def init( self, *, server_settings: Optional[Mapping[str, str]] = None, ) -> None: 
cluster_status = await self.get_status() if not cluster_status.startswith('not-initialized'): raise ClusterError( 'cluster in {!r} has already been initialized'.format( self._data_dir)) self._init() class TempCluster(Cluster): def __init__( self, *, data_dir_suffix: Optional[str] = None, data_dir_prefix: Optional[str] = None, data_dir_parent: Optional[str] = None, env: Optional[Mapping[str, str]] = None, testmode: bool = False, log_level: Optional[str] = None, security: Optional[ edgedb_args.ServerSecurityMode ] = None, http_endpoint_security: Optional[ edgedb_args.ServerEndpointSecurityMode ] = None, compiler_pool_mode: Optional[ edgedb_args.CompilerPoolMode ] = None, ) -> None: tempdir = pathlib.Path( tempfile.mkdtemp( suffix=data_dir_suffix, prefix=data_dir_prefix, dir=data_dir_parent, ), ) super().__init__( data_dir=tempdir, runstate_dir=tempdir, env=env, testmode=testmode, log_level=log_level, security=security, http_endpoint_security=http_endpoint_security, compiler_pool_mode=compiler_pool_mode, ) class RunningCluster(BaseCluster): def __init__(self, **conn_args: Any) -> None: self.conn_args = conn_args def is_managed(self) -> bool: return False def ensure_initialized(self) -> bool: return False def get_connect_args(self) -> dict[str, Any]: return dict(self.conn_args) async def get_status(self) -> str: return 'running' async def init( self, *, server_settings: Optional[Mapping[str, str]] = None, ) -> None: pass async def start( self, wait: int=60, *, port: Optional[int] = None, **settings: Any, ) -> None: pass def stop(self, wait: int = 60) -> None: pass def destroy(self) -> None: pass def has_create_database(self) -> bool: return os.environ.get('EDGEDB_TEST_CASES_SET_UP') != 'inplace' def has_create_role(self) -> bool: return os.environ.get('EDGEDB_TEST_HAS_CREATE_ROLE') == 'True' class TempClusterWithRemotePg(BaseCluster): def __init__( self, backend_dsn: str, *, data_dir_suffix: Optional[str] = None, data_dir_prefix: Optional[str] = None, data_dir_parent: 
Optional[str] = None, env: Optional[Mapping[str, str]] = None, testmode: bool = False, log_level: Optional[str] = None, security: Optional[ edgedb_args.ServerSecurityMode ] = None, http_endpoint_security: Optional[ edgedb_args.ServerEndpointSecurityMode ] = None, compiler_pool_mode: Optional[ edgedb_args.CompilerPoolMode ] = None, ) -> None: runstate_dir = pathlib.Path( tempfile.mkdtemp( suffix=data_dir_suffix, prefix=data_dir_prefix, dir=data_dir_parent, ), ) self._backend_dsn = backend_dsn mt = "EDGEDB_SERVER_MULTITENANT_CONFIG_FILE" in os.environ if mt: compiler_pool_mode = edgedb_args.CompilerPoolMode.MultiTenant super().__init__( runstate_dir, env=env, testmode=testmode, log_level=log_level, security=security, http_endpoint_security=http_endpoint_security, compiler_pool_mode=compiler_pool_mode, ) if not mt: self._edgedb_cmd.extend(['--backend-dsn', backend_dsn]) async def _new_pg_cluster(self) -> pgcluster.BaseCluster: return await pgcluster.get_remote_pg_cluster(self._backend_dsn) def has_create_database(self) -> bool: if self._pg_cluster: return self._pg_cluster.get_runtime_params().has_create_database else: return super().has_create_database() def has_create_role(self) -> bool: if self._pg_cluster: return self._pg_cluster.get_runtime_params().has_create_role else: return super().has_create_role() ================================================ FILE: edb/testbase/connection.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """A specialized client API for Gel tests. Historically Gel tests relied on a very specific client API that is no longer supported by edgedb-python. Here we implement that API (for example, transactions can be nested and are non-retrying). """ from __future__ import annotations import typing import abc import asyncio import contextlib import enum import functools import random import socket import ssl import time from gel import abstract from gel import errors from gel import con_utils from gel import enums as edgedb_enums from gel import options from gel.protocol import protocol # type: ignore from edb.protocol import protocol as edb_protocol # type: ignore class TransactionState(enum.Enum): NEW = 0 STARTED = 1 COMMITTED = 2 ROLLEDBACK = 3 FAILED = 4 InputLanguage = protocol.InputLanguage class BaseTransaction(abc.ABC): ID_COUNTER = 0 def __init__(self, owner): self._connection = owner self._state = TransactionState.NEW self._managed = False self._nested = False type(self).ID_COUNTER += 1 self._id = f'raw_tx_{self.ID_COUNTER}' def is_active(self) -> bool: return self._state is TransactionState.STARTED def __check_state_base(self, opname): if self._state is TransactionState.COMMITTED: raise errors.InterfaceError( f'cannot {opname}; the transaction is already committed') if self._state is TransactionState.ROLLEDBACK: raise errors.InterfaceError( f'cannot {opname}; the transaction is already rolled back') if self._state is TransactionState.FAILED: raise errors.InterfaceError( f'cannot {opname}; the transaction is in error state') def __check_state(self, opname): if 
self._state is not TransactionState.STARTED: if self._state is TransactionState.NEW: raise errors.InterfaceError( f'cannot {opname}; the transaction is not yet started') self.__check_state_base(opname) def _make_start_query(self): self.__check_state_base('start') if self._state is TransactionState.STARTED: raise errors.InterfaceError( 'cannot start; the transaction is already started') qry = self._make_start_query_inner() if self._connection._top_xact is None: self._connection._top_xact = self return qry @abc.abstractmethod def _make_start_query_inner(self): ... def _make_commit_query(self): self.__check_state('commit') if self._connection._top_xact is self: self._connection._top_xact = None return 'COMMIT;' def _make_rollback_query(self): self.__check_state('rollback') if self._connection._top_xact is self: self._connection._top_xact = None if self._nested: query = f'ROLLBACK TO SAVEPOINT {self._id};' else: query = 'ROLLBACK;' return query async def start(self) -> None: query = self._make_start_query() try: await self._connection.execute(query) except BaseException: self._state = TransactionState.FAILED raise else: self._state = TransactionState.STARTED async def commit(self) -> None: if self._managed: raise errors.InterfaceError( 'cannot manually commit from within an `async with` block') await self._commit() async def _commit(self) -> None: query = self._make_commit_query() try: # Use _fetchall to ensure there is no retry performed. # The protocol level apparently thinks the transaction is # over if COMMIT fails, and since we use that to decide # whether to retry in query/execute, it would want to # retry a COMMIT. 
await self._connection._fetchall(query) except BaseException: self._state = TransactionState.FAILED raise else: self._state = TransactionState.COMMITTED async def rollback(self) -> None: if self._managed: raise errors.InterfaceError( 'cannot manually rollback from within an `async with` block') await self._rollback() async def _rollback(self) -> None: query = self._make_rollback_query() try: await self._connection.execute(query) except BaseException: self._state = TransactionState.FAILED raise else: self._state = TransactionState.ROLLEDBACK class RawTransaction(BaseTransaction): def _make_start_query_inner(self): con = self._connection if con._top_xact is not None: # Nested transaction block self._nested = True if self._nested: query = f'DECLARE SAVEPOINT {self._id};' else: query = 'START TRANSACTION;' return query def _make_commit_query(self): query = super()._make_commit_query() if self._nested: query = f'RELEASE SAVEPOINT {self._id};' return query def _make_rollback_query(self): query = super()._make_rollback_query() if self._nested: query = f'ROLLBACK TO SAVEPOINT {self._id};' return query async def __aenter__(self): if self._managed: raise errors.InterfaceError( 'cannot enter context: already in an `async with` block') self._managed = True await self.start() return self async def __aexit__(self, extype, ex, tb): try: if extype is not None: await self._rollback() else: await self._commit() finally: self._managed = False class _Executor(abstract.AsyncIOExecutor): # TODO: Remove this, once we land this in gel-python and update # our bindings. 
async def query_graphql_json( self, query, *args: typing.Any, **kwargs: typing.Any ) -> str: return await self._query( abstract.QueryContext( query=abstract.QueryWithArgs( query, # None, args, kwargs, input_language=ord('G'), ), cache=self._get_query_cache(), query_options=abstract._query_single_json_opts, retry_options=self._get_retry_options(), state=self._get_state(), # transaction_options=self._get_active_tx_options(), warning_handler=self._get_warning_handler(), annotations=self._get_annotations(), ) ) class Iteration(BaseTransaction, _Executor): def __init__(self, retry, connection, iteration): super().__init__(connection) self._options = retry._options.transaction_options self.__retry = retry self.__iteration = iteration self.__started = False async def __aenter__(self): if self._managed: raise errors.InterfaceError( 'cannot enter context: already in an `async with` block') self._managed = True return self async def __aexit__(self, extype, ex, tb): self._managed = False if not self.__started: return False try: if extype is None: await self._commit() else: await self._rollback() except errors.EdgeDBError as err: if ex is None: # On commit we cannot tell whether the commit failed before the # database received it, or succeeded but the network dropped # before we could receive the response. raise err # If we were going to roll back, look at the original error # to find out whether we want to retry, regardless of # the rollback error. # In this case we ignore the rollback error because the original error # is more important; e.g. for `CancelledError` it is important # to propagate it to cancel the whole task. # NOTE: the rollback error is always swallowed; should we use # on_log_message for it?
if ( extype is not None and issubclass(extype, errors.EdgeDBError) and ex.has_tag(errors.SHOULD_RETRY) ): return self.__retry._retry(ex) def _make_start_query_inner(self): return self._options.start_transaction_query() def _get_query_cache(self) -> abstract.QueryCache: return self._connection._query_cache async def _query(self, query_context: abstract.QueryContext): await self._ensure_transaction() return await self._connection.raw_query(query_context) async def _execute(self, query: abstract.ExecuteContext) -> None: await self._ensure_transaction() await self._connection._execute(query) async def _ensure_transaction(self): if not self._managed: raise errors.InterfaceError( "Only managed retriable transactions are supported. " "Use `async with transaction:`" ) if not self.__started: self.__started = True if self._connection.is_closed(): await self._connection.connect( single_attempt=self.__iteration != 0 ) await self.start() def _get_retry_options(self) -> options.RetryOptions: return options.RetryOptions.defaults() def _get_state(self) -> options.State: return self._connection._get_state() def _get_warning_handler(self) -> options.WarningHandler: return self._connection._get_warning_handler() def _get_annotations(self) -> dict[str, str]: return self._connection._get_annotations() class Retry: def __init__(self, connection, raw=False): self._connection = connection self._iteration = 0 self._done = False self._next_backoff = 0 self._options = connection._options def _retry(self, exc): self._last_exception = exc rule = self._options.retry_options.get_rule_for_exception(exc) if self._iteration >= rule.attempts: return False self._done = False self._next_backoff = rule.backoff(self._iteration) return True def __aiter__(self): return self async def __anext__(self): # Note: when changing this code consider also # updating Retry.__next__. 
if self._done: raise StopAsyncIteration if self._next_backoff: await asyncio.sleep(self._next_backoff) self._done = True iteration = Iteration(self, self._connection, self._iteration) self._iteration += 1 return iteration class Connection(options._OptionsMixin, _Executor): _top_xact: RawTransaction | None = None def __init__( self, connect_args, *, test_no_tls=False, server_hostname=None ): super().__init__() self._connect_args = connect_args self._protocol = None self._transport = None self._query_cache = abstract.QueryCache( codecs_registry=protocol.CodecsRegistry(), query_cache=protocol.LRUMapping(maxsize=1000), ) self._test_no_tls = test_no_tls self._params = None self._server_hostname = server_hostname self._log_listeners = set() self._capture_warnings = None def add_log_listener(self, callback): self._log_listeners.add(callback) def remove_log_listener(self, callback): self._log_listeners.discard(callback) def _get_retry_options(self) -> options.RetryOptions: return self._options.retry_options def _get_state(self): return self._options.state def _get_annotations(self) -> dict[str, str]: return self._options.annotations @contextlib.contextmanager def capture_warnings(self) -> typing.Iterator[list[errors.EdgeDBError]]: old = self._capture_warnings warnings: list[errors.EdgeDBError] = [] self._capture_warnings = warnings try: yield warnings finally: self._capture_warnings = old def _warning_handler(self, warnings, res): if self._capture_warnings is not None: self._capture_warnings.extend(warnings) return res else: raise warnings[0] def _get_warning_handler(self) -> options.WarningHandler: return self._warning_handler def _on_log_message(self, msg): if self._log_listeners: loop = asyncio.get_running_loop() for cb in self._log_listeners: loop.call_soon(cb, self, msg) def _shallow_clone(self): con = self.__class__.__new__(self.__class__) con._connect_args = self._connect_args con._protocol = self._protocol con._query_cache = self._query_cache con._test_no_tls = 
self._test_no_tls con._params = self._params con._server_hostname = self._server_hostname return con def _get_query_cache(self) -> abstract.QueryCache: return self._query_cache async def ensure_connected(self): if self.is_closed(): await self.connect() return self async def _query(self, query_context: abstract.QueryContext): await self.ensure_connected() return await self.raw_query(query_context) async def _retry_operation(self, func): i = 0 while True: i += 1 try: return await func() # Retry transaction conflict errors, up to a maximum of 10 # attempts. We don't do this if we are in a transaction, # since that *ought* to be done at the transaction level. # In practice, though, the test suite usually does it at the # test runner level. except errors.TransactionConflictError: if i >= 10 or self.is_in_transaction(): raise await asyncio.sleep( min((2 ** i) * 0.1, 10) + random.randrange(100) * 0.001 ) def _prohibit_state(self, state) -> None: # The testbase connection uses our own subclass of # gel-python's AsyncIOProtocol, # edb.protocol.protocol.Protocol, that overrides encode_state # to ignore any user specified state and to always just use # whatever the server suggests. # # It is probably possible to make CONFIGURE .../SET GLOBAL # play nicely with with_globals/etc, by decoding what the server # has sent and then overlaying the user configured stuff. # # Since it doesn't work, disable it. if state.as_dict(): raise AssertionError( f'test suite client cannot use with_XXX config methods; ' f'use SET ... 
in the protocol instead ' f'or use the stock python client ' f'(or go make it work; that would be nice too)\n' f'config was: {state.as_dict()}' ) async def _execute(self, script: abstract.ExecuteContext) -> None: await self.ensure_connected() self._prohibit_state(script.state) async def _inner(): ctx = script.lower(allow_capabilities=edgedb_enums.Capability.ALL) res = await self._protocol.execute(ctx) if ctx.warnings: script.warning_handler(ctx.warnings, res) await self._retry_operation(_inner) async def raw_query(self, query_context: abstract.QueryContext): self._prohibit_state(query_context.state) async def _inner(): ctx = query_context.lower( allow_capabilities=edgedb_enums.Capability.ALL) res = await self._protocol.query(ctx) if ctx.warnings: res = query_context.warning_handler(ctx.warnings, res) return res return await self._retry_operation(_inner) async def _fetchall_generic(self, ctx): await self.ensure_connected() res = await self._protocol.query(ctx) if ctx.warnings: res = self._get_warning_handler()(ctx.warnings, res) return res async def _fetchall( self, query: str, *args, __language__: protocol.InputLanguage = protocol.InputLanguage.EDGEQL, __limit__: int = 0, __typeids__: bool = False, __typenames__: bool = False, __allow_capabilities__: edgedb_enums.Capability = ( edgedb_enums.Capability.ALL), **kwargs, ): return await self._fetchall_generic( protocol.ExecuteContext( query=query, args=args, kwargs=kwargs, reg=self._query_cache.codecs_registry, qc=self._query_cache.query_cache, implicit_limit=__limit__, inline_typeids=__typeids__, inline_typenames=__typenames__, input_language=__language__, output_format=protocol.OutputFormat.BINARY, allow_capabilities=__allow_capabilities__, ) ) async def _fetchall_json( self, query: str, *args, __limit__: int = 0, **kwargs, ): return await self._fetchall_generic( protocol.ExecuteContext( query=query, args=args, kwargs=kwargs, reg=self._query_cache.codecs_registry, qc=self._query_cache.query_cache, 
implicit_limit=__limit__, inline_typenames=False, input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.JSON, ) ) async def _fetchall_json_elements(self, query: str, *args, **kwargs): return await self._fetchall_generic( protocol.ExecuteContext( query=query, args=args, kwargs=kwargs, reg=self._query_cache.codecs_registry, qc=self._query_cache.query_cache, input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.JSON_ELEMENTS, allow_capabilities=edgedb_enums.Capability.EXECUTE, # type: ignore ) ) def _clear_codecs_cache(self): self._query_cache.codecs_registry.clear_cache() def _get_last_status(self) -> typing.Optional[str]: if self._protocol is None: return None status = self._protocol.last_status if status is not None: status = status.decode() return status def _get_last_capabilities( self, ) -> typing.Optional[edgedb_enums.Capability]: if self._protocol is None: return None else: return self._protocol.last_capabilities def is_closed(self): return self._protocol is None or not self._protocol.connected async def connect(self, single_attempt=False): self._params, client_config = con_utils.parse_connect_arguments( **self._connect_args, tls_server_name=None, command_timeout=None, server_settings=None, ) start = time.monotonic() if single_attempt: max_time = 0 else: max_time = start + client_config.wait_until_available iteration = 1 while True: addr = self._params.address try: await asyncio.wait_for( self.connect_addr(), client_config.connect_timeout, ) except TimeoutError as e: if iteration > 1 and time.monotonic() >= max_time: raise errors.ClientConnectionTimeoutError( f"connecting to {addr} failed in" f" {client_config.connect_timeout} sec" ) from e except errors.ClientConnectionError as e: if ( not e.has_tag(errors.SHOULD_RECONNECT) or (iteration > 1 and time.monotonic() >= max_time) ): nice_err = e.__class__( con_utils.render_client_no_connection_error( e, addr, attempts=iteration, duration=time.monotonic() - 
start, )) raise nice_err from e.__cause__ else: return iteration += 1 await asyncio.sleep(0.01 + random.random() * 0.2) async def connect_addr(self): tr = None loop = asyncio.get_running_loop() addr = self._params.address protocol_factory = functools.partial( edb_protocol.Protocol, self._params, loop ) try: if isinstance(addr, str): # UNIX socket tr, pr = await loop.create_unix_connection( protocol_factory, addr ) elif self._test_no_tls: tr, pr = await loop.create_connection(protocol_factory, *addr) else: try: tr, pr = await loop.create_connection( protocol_factory, *addr, server_hostname=self._server_hostname, ssl=self._params.ssl_ctx, ) except ssl.CertificateError as e: raise con_utils.wrap_error(e) from e except ssl.SSLError as e: if e.reason == 'CERTIFICATE_VERIFY_FAILED': raise con_utils.wrap_error(e) from e tr, pr = await loop.create_connection( protocol_factory, *addr, ) else: con_utils.check_alpn_protocol( tr.get_extra_info('ssl_object') ) except socket.gaierror as e: # All name resolution errors are considered temporary raise errors.ClientConnectionFailedTemporarilyError(str(e)) from e except OSError as e: raise con_utils.wrap_error(e) from e except Exception: if tr is not None: tr.close() raise pr.set_connection(self) try: await pr.connect() except OSError as e: if tr is not None: tr.close() raise con_utils.wrap_error(e) from e except BaseException: if tr is not None: tr.close() raise self._protocol = pr self._transport = tr def retrying_transaction(self) -> Retry: return Retry(self) def transaction(self) -> RawTransaction: return RawTransaction(self) def is_in_transaction(self): return self._protocol.is_in_transaction() def get_settings(self) -> dict[str, typing.Any]: return self._protocol.get_settings() @property def dbname(self) -> str: return self._params.database def connected_addr(self): return self._params.address async def aclose(self): if not self.is_closed(): try: self._protocol.terminate() await self._protocol.wait_for_disconnect() except 
(Exception, asyncio.CancelledError): self.terminate() raise def terminate(self): if not self.is_closed(): self._protocol.abort() def get_transport(self): return self._transport async def async_connect_test_client( dsn: typing.Optional[str] = None, host: typing.Optional[str] = None, port: typing.Optional[int] = None, credentials: typing.Optional[str] = None, credentials_file: typing.Optional[str] = None, user: typing.Optional[str] = None, password: typing.Optional[str] = None, secret_key: typing.Optional[str] = None, branch: typing.Optional[str] = None, database: typing.Optional[str] = None, tls_ca: typing.Optional[str] = None, tls_ca_file: typing.Optional[str] = None, tls_security: typing.Optional[str] = None, test_no_tls: bool = False, wait_until_available: int = 30, timeout: int = 10, server_hostname: str | None = None, ) -> Connection: return await Connection( { "dsn": dsn, "host": host, "port": port, "credentials": credentials, "credentials_file": credentials_file, "user": user, "password": password, "secret_key": secret_key, "branch": branch, "database": database, "timeout": timeout, "tls_ca": tls_ca, "tls_ca_file": tls_ca_file, "tls_security": tls_security, "wait_until_available": wait_until_available, }, test_no_tls=test_no_tls, server_hostname=server_hostname, ).ensure_connected() ================================================ FILE: edb/testbase/experimental_interpreter.py ================================================ from __future__ import annotations from typing import ( Any, Optional, ) import unittest from edb.common import assert_data_shape from edb.tools.experimental_interpreter.new_interpreter import EdgeQLInterpreter bag = assert_data_shape.bag class ExperimentalInterpreterTestCase(unittest.TestCase): SCHEMA: Optional[str] = None SETUP: Optional[str] = None INTERPRETER_USE_SQLITE = False client: EdgeQLInterpreter initial_state: object @classmethod def setUpClass(cls): if cls.SCHEMA is not None: with open(cls.SCHEMA) as f: schema_content = 
f.read() else: schema_content = "" sqlite_filename = None if cls.INTERPRETER_USE_SQLITE: sqlite_filename = ":memory:" try: import sqlite3 except ModuleNotFoundError: raise unittest.SkipTest("sqlite is not installed") if sqlite3.sqlite_version_info < (3, 37): raise unittest.SkipTest("sqlite version is too old (need 3.37)") cls.client = EdgeQLInterpreter(schema_content, sqlite_filename) if cls.SETUP is not None: with open(cls.SETUP) as f: setup_content = f.read() cls.client.query_str(setup_content) cls.initial_state = cls.client.db.dump_state() def setUp(self): self.client.db.restore_state(self.initial_state) def execute(self, query: str, *, variables=None) -> Any: return self.client.run_single_str_get_json_with_cache( query, variables=variables) def execute_single(self, query: str, *, variables=None) -> Any: return self.client.query_single_json(query, variables=variables) def assert_query_result(self, query, exp_result_json, exp_result_binary=..., *, msg: Optional[str] = None, sort: Optional[bool] = None, variables=None, ): if (hasattr(self, "use_experimental_interpreter") and self.use_experimental_interpreter): result = self.client.run_single_str_get_json_with_cache( query, variables=variables) res = result if sort is not None: assert_data_shape.sort_results(res, sort) if exp_result_binary is not ...: assert_data_shape.assert_data_shape( res, exp_result_binary, self.fail, message=msg) else: assert_data_shape.assert_data_shape( res, exp_result_json, self.fail, message=msg) ================================================ FILE: edb/testbase/http.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2019-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import ( Any, Callable, Optional, ) import asyncio import http.server import json import threading import urllib.parse import urllib.request import dataclasses import time import random import gel from edb.errors import base as base_errors from edb.common import assert_data_shape from . import server bag = assert_data_shape.bag class BaseHttpTest(server.QueryTestCase): @classmethod async def _wait_for_db_config( cls, config_key, *, server=None, instance_config=False, value=None, is_reset=False, ): dbname = cls.get_database_name() # Wait for the database config changes to propagate to the # server by watching a debug endpoint async for tr in cls.try_until_succeeds( ignore=AssertionError, timeout=120, ): async with tr: with cls.http_con(server) as http_con: ( rdata, _headers, _status, ) = cls.http_con_request( http_con, prefix="", path="server-info", ) data = json.loads(rdata) if "databases" not in data: # multi-tenant instance - use the first tenant data = next(iter(data["tenants"].values())) if instance_config: config = data["instance_config"] else: config = data["databases"][dbname]["config"] if is_reset: if config_key in config: raise AssertionError("database config not ready") else: if config_key not in config: raise AssertionError("database config not ready") if value and config[config_key] != value: raise AssertionError("database config not ready") class BaseHttpExtensionTest(BaseHttpTest): @classmethod def get_extension_path(cls): raise NotImplementedError @classmethod def get_api_prefix(cls): extpath = 
cls.get_extension_path() dbname = cls.get_database_name() return f"/branch/{dbname}/{extpath}" class ExtAuthTestCase(BaseHttpExtensionTest): EXTENSIONS = ["pgcrypto", "auth"] @classmethod def get_extension_path(cls): return "ext/auth" def generate_pkce_pair(self) -> tuple[str, str]: """Generate a PKCE verifier and its corresponding challenge. Returns: (verifier, challenge): tuple of str """ import os import base64 import hashlib verifier = base64.urlsafe_b64encode(os.urandom(43)).rstrip(b'=') challenge = base64.urlsafe_b64encode( hashlib.sha256(verifier).digest() ).rstrip(b'=') return verifier.decode(), challenge.decode() class EdgeQLTestCase(BaseHttpExtensionTest): EXTENSIONS = ["edgeql_http"] @classmethod def get_extension_path(cls): return "edgeql" def edgeql_query( self, query, *, use_http_post=True, variables=None, globals=None, config=None, origin=None, user=None, password=None, ): req_data = {"query": query} if use_http_post: if variables is not None: req_data["variables"] = variables if globals is not None: req_data["globals"] = globals if config is not None: req_data["config"] = config req = urllib.request.Request(self.http_addr, method="POST") req.add_header("Content-Type", "application/json") req.add_header( "Authorization", self.make_auth_header(user, password) ) if origin: req.add_header("Origin", origin) response = urllib.request.urlopen( req, json.dumps(req_data).encode(), context=self.tls_context ) resp_data = json.loads(response.read()) else: if variables is not None: req_data["variables"] = json.dumps(variables) if globals is not None: req_data["globals"] = json.dumps(globals) if config is not None: req_data["config"] = json.dumps(config) req = urllib.request.Request( f"{self.http_addr}/?{urllib.parse.urlencode(req_data)}", ) req.add_header( "Authorization", self.make_auth_header(user, password) ) response = urllib.request.urlopen( req, context=self.tls_context, ) resp_data = json.loads(response.read()) if "data" in resp_data: return 
(resp_data["data"], response) err = resp_data["error"] ex_msg = err["message"].strip() ex_code = err["code"] raise gel.EdgeDBError._from_code(ex_code, ex_msg) def assert_edgeql_query_result( self, query, result, *, msg=None, sort=None, use_http_post=True, variables=None, globals=None, config=None, ): res, _ = self.edgeql_query( query, use_http_post=use_http_post, variables=variables, globals=globals, config=config, ) if sort is not None: # GQL will always have a single object returned. The data is # in the top-level fields, so that's what needs to be sorted. for r in res.values(): assert_data_shape.sort_results(r, sort) assert_data_shape.assert_data_shape(res, result, self.fail, message=msg) return res class GraphQLTestCase(BaseHttpExtensionTest): EXTENSIONS = ["graphql"] @classmethod def get_extension_path(cls): return "graphql" def graphql_query( self, query, *, operation_name=None, use_http_post=True, variables=None, globals=None, deprecated_globals=None, config=None, user=None, password=None, ): def inner(): return self._graphql_query( query, operation_name=operation_name, use_http_post=use_http_post, variables=variables, globals=globals, deprecated_globals=deprecated_globals, config=config, user=user, password=password, ) return self._retry_operation(inner) def _retry_operation(self, func): i = 0 while True: i += 1 try: return func() # Retry transaction conflict errors except gel.errors.TransactionConflictError: if i >= 10: raise time.sleep( min((2 ** i) * 0.1, 10) + random.randrange(100) * 0.001 ) def _graphql_query( self, query, *, operation_name=None, use_http_post=True, variables=None, globals=None, deprecated_globals=None, config=None, user=None, password=None, ): req_data = {"query": query} if operation_name is not None: req_data["operationName"] = operation_name if use_http_post: if variables is not None: req_data["variables"] = variables if globals is not None: if variables is None: req_data["variables"] = dict() req_data["variables"]["__globals__"] = 
globals if config is not None: if variables is None: req_data["variables"] = dict() req_data["variables"]["__config__"] = config # Support testing the old way of sending globals. if deprecated_globals is not None: req_data["globals"] = deprecated_globals req = urllib.request.Request(self.http_addr, method="POST") req.add_header("Content-Type", "application/json") req.add_header( "Authorization", self.make_auth_header(user, password) ) response = urllib.request.urlopen( req, json.dumps(req_data).encode(), context=self.tls_context ) resp_data = json.loads(response.read()) else: if globals is not None: if variables is None: variables = dict() variables["__globals__"] = globals if config is not None: if variables is None: variables = dict() variables["__config__"] = config # Support testing the old way of sending globals. if deprecated_globals is not None: req_data["globals"] = json.dumps(deprecated_globals) if variables is not None: req_data["variables"] = json.dumps(variables) req = urllib.request.Request( f"{self.http_addr}/?{urllib.parse.urlencode(req_data)}", ) req.add_header( "Authorization", self.make_auth_header(user, password) ) response = urllib.request.urlopen( req, context=self.tls_context, ) resp_data = json.loads(response.read()) if "data" in resp_data: return resp_data["data"] err = resp_data["errors"][0] typename, msg = err["message"].split(":", 1) msg = msg.strip() try: ex_type = getattr(gel, typename) except AttributeError: raise AssertionError( f"server returned an invalid exception typename: {typename!r}" f"\n Message: {msg}" ) ex = ex_type(msg) if "locations" in err: # XXX Fix this when LSP "location" objects are implemented ex._attrs[base_errors.FIELD_LINE_START] = str( err["locations"][0]["line"] ).encode() ex._attrs[base_errors.FIELD_COLUMN_START] = str( err["locations"][0]["column"] ).encode() raise ex async def _native_graphql_query( self, query, *, # Can/should we support operation_name somehow... 
variables=None, globals=None, ): # The graphql tests are all synchronous, and our gel # connections need to be async... so we spin up a new # connection and asyncio.run the coro. con = await self.connect() try: # Ahhhhhh. We don't support with_globals on testbase # connections, so.... if globals: glob_defs = { obj.name: obj for obj in await con.query(''' select schema::Global { name, required, tname := .target.name } ''') } for k, v in globals.items(): glob = glob_defs[k] # Why do we allow this for the HTTP proto?? # We don't for binary proto stuff. if v is None and glob.required: continue mod = 'required ' if glob.required else '' await con.execute( f'set global {k} := <{mod}{glob.tname}>$0', json.dumps(v), ) async with server.RollbackChanges(con): return json.loads(await con.query_graphql_json( query, **(variables or {}), )) finally: await con.aclose() def assert_graphql_query_result( self, query, result, *, msg=None, sort=None, operation_name=None, use_http_post=True, native_variables=None, variables=None, globals=None, deprecated_globals=None, config=None, ): # Try to use the native protocol first! if operation_name is None and config is None: try: res = asyncio.run(self._native_graphql_query( query, variables=( native_variables if native_variables is not None else variables ), globals=globals or deprecated_globals, )) if sort is not None: # GQL will always have a single object # returned. The data is in the top-level fields, # so that's what needs to be sorted. for r in res.values(): assert_data_shape.sort_results(r, sort) assert_data_shape.assert_data_shape( res, result, self.fail, message=msg) except gel.UnsupportedFeatureError as e: if 'Default variables are not supported' in str(e): # Whatever. 
pass else: raise res = self.graphql_query( query, operation_name=operation_name, use_http_post=use_http_post, variables=variables, globals=globals, deprecated_globals=deprecated_globals, config=config, ) if sort is not None: # GQL will always have a single object returned. The data is # in the top-level fields, so that's what needs to be sorted. for r in res.values(): assert_data_shape.sort_results(r, sort) assert_data_shape.assert_data_shape(res, result, self.fail, message=msg) return res class MockHttpServerHandler(http.server.BaseHTTPRequestHandler): def get_server_and_path(self) -> tuple[str, str]: server = f'http://{self.headers.get("Host")}' return server, self.path def do_GET(self): self.close_connection = False server, path = self.get_server_and_path() self.server.owner.handle_request("GET", server, path, self) def do_POST(self): self.close_connection = False server, path = self.get_server_and_path() self.server.owner.handle_request("POST", server, path, self) def log_message(self, *args): pass class MultiHostMockHttpServerHandler(MockHttpServerHandler): def get_server_and_path(self) -> tuple[str, str]: # Path looks like: # http://127.0.0.1:32881/https%3A//slack.com/.well-known/openid-configuration raw_url = urllib.parse.unquote(self.path.lstrip("/")) url = urllib.parse.urlparse(raw_url) return (f"{url.scheme}://{url.netloc}", url.path.lstrip("/")) ResponseType = tuple[str, int] | tuple[str, int, dict[str, str]] @dataclasses.dataclass class RequestDetails: headers: dict[str, str | Any] query_params: dict[str, list[str]] body: Optional[str] class MockHttpServer: def __init__( self, handler_type: type[MockHttpServerHandler] = MockHttpServerHandler, port: int = 0, ) -> None: self._port = port self.has_started = threading.Event() self.routes: dict[ tuple[str, str, str], ( ResponseType | Callable[ [MockHttpServerHandler, RequestDetails], ResponseType ] ), ] = {} self.requests: dict[tuple[str, str, str], list[RequestDetails]] = {} self.url: Optional[str] = None 
self.handler_type = handler_type def get_base_url(self) -> str: if self.url is None: raise RuntimeError("mock server is not running") return self.url def register_route_handler( self, method: str, server: str, path: str, ): def wrapper( handler: ( ResponseType | Callable[ [MockHttpServerHandler, RequestDetails], ResponseType ] ), ): self.routes[(method, server, path)] = handler return handler return wrapper def handle_request( self, method: str, server: str, path: str, handler: MockHttpServerHandler, ): # `handler` is documented here: # https://docs.python.org/3/library/http.server.html#http.server.BaseHTTPRequestHandler key = (method, server, path) if key not in self.requests: self.requests[key] = [] # Parse and save the request details parsed_path = urllib.parse.urlparse(path) headers = {k.lower(): v for k, v in dict(handler.headers).items()} query_params = urllib.parse.parse_qs(parsed_path.query) if "content-length" in headers: body = handler.rfile.read(int(headers["content-length"])).decode() else: body = None request_details = RequestDetails( headers=headers, query_params=query_params, body=body, ) self.requests[key].append(request_details) if key not in self.routes: error_message = ( f"No route handler for {key}\n\n" f"Available routes:\n{self.routes}" ) handler.send_error(404, message=error_message) return registered_handler = self.routes[key] if callable(registered_handler): try: handler_result = registered_handler(handler, request_details) if len(handler_result) == 2: response, status = handler_result additional_headers = None elif len(handler_result) == 3: response, status, additional_headers = handler_result except Exception: handler.send_error(500) raise else: if len(registered_handler) == 2: response, status = registered_handler additional_headers = None elif len(registered_handler) == 3: response, status, additional_headers = registered_handler accept_header = request_details.headers.get( "accept", "application/json" ) if ( 
accept_header.startswith("application/json") or ( accept_header.startswith("application/") and "vnd." in accept_header and "+json" in accept_header ) or accept_header == "*/*" ): content_type = "application/json" elif accept_header.startswith("application/x-www-form-urlencoded"): content_type = "application/x-www-form-urlencoded" else: handler.send_error( 415, f"Unsupported accept header: {accept_header}" ) return data = response.encode() handler.send_response(status) handler.send_header("Content-Type", content_type) handler.send_header("Content-Length", str(len(data))) if additional_headers is not None: for header, value in additional_headers.items(): handler.send_header(header, value) handler.end_headers() handler.wfile.write(data) def start(self): assert not hasattr(self, "_http_runner") self._http_runner = threading.Thread(target=self._http_worker) self._http_runner.start() self.has_started.wait() self.url = f"http://{self._address[0]}:{self._address[1]}/" def __enter__(self): self.start() return self def _http_worker(self): self._http_server = http.server.HTTPServer( ("localhost", self._port), self.handler_type ) self._http_server.owner = self self._address = self._http_server.server_address self.has_started.set() self._http_server.serve_forever(poll_interval=0.01) self._http_server.server_close() def stop(self): self._http_server.shutdown() if self._http_runner is not None: self._http_runner.join(timeout=60) if self._http_runner.is_alive(): raise RuntimeError("Mock HTTP server failed to stop") self._http_runner = None def __exit__(self, *exc): self.stop() self.url = None ================================================ FILE: edb/testbase/lang.py ================================================ # mypy: ignore-errors # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any, Optional import typing import functools import os import re import unittest from edb.common import span from edb.common import debug from edb.common import devmode from edb.common import markup from edb import buildmeta from edb import errors from edb import edgeql from edb.edgeql import ast as qlast from edb.edgeql import parser as qlparser from edb.edgeql.parser import grammar as qlgrammar from edb.edgeql import qltypes from edb.server import compiler as edbcompiler from edb.schema import ddl as s_ddl from edb.schema import delta as sd from edb.schema import migrations as s_migrations # noqa from edb.schema import reflection as s_refl from edb.schema import schema as s_schema from edb.schema import std as s_std from edb.schema import utils as s_utils from edb.schema import modules as s_mod def must_fail(exc_type, exc_msg_re=None, **kwargs): """A decorator to ensure that the test fails with a specific exception. If exc_msg_re is passed, assertRaisesRegex will be used to match the exception message. Example: @must_fail(EdgeQLSyntaxError, 'non-default argument follows', line=2, col=61) def test_edgeql_syntax_1(self): ... 
""" def wrap(func): args = (exc_type,) if exc_msg_re is not None: args += (exc_msg_re,) _set_spec(func, 'must_fail', (args, kwargs)) return func return wrap def _set_spec(func, name, attrs): try: spec = func.test_spec except AttributeError: spec = func.test_spec = {} assert name not in spec spec[name] = attrs class DocTestMeta(type(unittest.TestCase)): def __new__(mcls, name, bases, dct): for attr, meth in tuple(dct.items()): if attr.startswith('test_') and meth.__doc__: @functools.wraps(meth) def wrapper(self, meth=meth, doc=meth.__doc__): spec = getattr(meth, 'test_spec', {}) spec['test_name'] = meth.__name__ if doc: output = error = None source, _, output = doc.partition('\n% OK %') if not output: source, _, error = doc.partition('\n% ERROR %') if not error: output = None else: output = error else: source = output = None self._run_test(source=source, spec=spec, expected=output) dct[attr] = wrapper return super().__new__(mcls, name, bases, dct) class BaseDocTest(unittest.TestCase, metaclass=DocTestMeta): parser_debug_flag = '' re_filter: Optional[typing.Pattern[str]] = None def _run_test(self, *, source, spec=None, expected=None): if spec and 'must_fail' in spec: spec_args, spec_kwargs = spec['must_fail'] if len(spec_args) == 1: assertRaises = self.assertRaises else: assertRaises = self.assertRaisesRegex with assertRaises(*spec_args) as cm: return self.run_test(source=source, spec=spec, expected=expected) if cm.exception: exc = cm.exception for attr_name, expected_val in spec_kwargs.items(): val = getattr(exc, attr_name) if val != expected_val: raise AssertionError( f'must_fail: attribute {attr_name!r} is ' f'{val} (expected is {expected_val!r})') from exc else: return self.run_test(source=source, spec=spec, expected=expected) def run_test(self, *, source, spec, expected=None): raise NotImplementedError def assert_equal( self, expected, result, *, re_filter: Optional[str] = None, message: Optional[str] = None ) -> None: if re_filter is None: re_filter = 
self.re_filter if re_filter is not None: expected_stripped = re_filter.sub('', expected).lower() result_stripped = re_filter.sub('', result).lower() else: expected_stripped = expected.lower() result_stripped = result.lower() self.assertEqual( expected_stripped, result_stripped, (f'{message if message else ""}' + f'\nexpected:\n{expected}\nreturned:\n{result}') ) class BaseSyntaxTest(BaseDocTest): ast_to_source: Optional[Any] = None markup_dump_lexer: Optional[str] = None @classmethod def get_grammar_token(cls) -> type[qlgrammar.tokens.GrammarToken]: raise NotImplementedError def run_test(self, *, source, spec, expected=None): debug = bool(os.environ.get(self.parser_debug_flag)) if debug: markup.dump_code(source, lexer=self.markup_dump_lexer) inast = qlparser.parse(self.get_grammar_token(), source) if debug: markup.dump(inast) # make sure that the AST has context span.SpanValidator().visit(inast) processed_src = self.ast_to_source(inast) if debug: markup.dump_code(processed_src, lexer=self.markup_dump_lexer) expected_src = source if expected is None else expected self.assert_equal(expected_src, processed_src) _std_schema = None _refl_schema = None _schema_class_layout = None def _load_std_schema(): global _std_schema if _std_schema is None: std_dirs_hash = buildmeta.hash_dirs(s_std.CACHE_SRC_DIRS) schema = None if devmode.is_in_dev_mode(): schema = buildmeta.read_data_cache( std_dirs_hash, 'transient-stdschema.pickle') if schema is None: schema = s_schema.EMPTY_SCHEMA for modname in [*s_schema.STD_SOURCES, *s_schema.TESTMODE_SOURCES]: schema = s_std.load_std_module(schema, modname) schema, _ = s_std.make_schema_version(schema) schema, _ = s_std.make_global_schema_version(schema) if devmode.is_in_dev_mode(): buildmeta.write_data_cache( schema, std_dirs_hash, 'transient-stdschema.pickle') _std_schema = schema return _std_schema def _load_reflection_schema(): global _refl_schema global _schema_class_layout if _refl_schema is None: std_dirs_hash = 
buildmeta.hash_dirs(s_std.CACHE_SRC_DIRS) cache = None if devmode.is_in_dev_mode(): cache = buildmeta.read_data_cache( std_dirs_hash, 'transient-reflschema.pickle') if cache is not None: reflschema, classlayout = cache else: std_schema = _load_std_schema() reflection = s_refl.generate_structure(std_schema) classlayout = reflection.class_layout context = sd.CommandContext(stdmode=True) reflschema = reflection.intro_schema_delta.apply( std_schema, context) if devmode.is_in_dev_mode(): buildmeta.write_data_cache( (reflschema, classlayout), std_dirs_hash, 'transient-reflschema.pickle', ) _refl_schema = reflschema _schema_class_layout = classlayout return _refl_schema, _schema_class_layout def new_compiler(): std_schema = _load_std_schema() refl_schema, layout = _load_reflection_schema() return edbcompiler.new_compiler( std_schema=std_schema, reflection_schema=refl_schema, schema_class_layout=layout, ) class BaseSchemaTest(BaseDocTest): DEFAULT_MODULE = 'default' SCHEMA: Optional[str] = None schema: s_schema.Schema @classmethod def setUpClass(cls): script = cls.get_schema_script() if script is not None: cls.schema = cls.run_ddl(_load_std_schema(), script) else: cls.schema = _load_std_schema() @classmethod def run_ddl(cls, schema, ddl, default_module=s_mod.DEFAULT_MODULE_ALIAS): statements = edgeql.parse_block(ddl) current_schema = schema target_schema = None migration_schema = None migration_target = None migration_script = [] for stmt in statements: if isinstance(stmt, qlast.StartMigration): # START MIGRATION if target_schema is None: target_schema = _load_std_schema() migration_target, _ = s_ddl.apply_sdl( stmt.target, base_schema=target_schema, testmode=True, ) migration_schema = current_schema ddl_plan = None elif isinstance(stmt, qlast.PopulateMigration): # POPULATE MIGRATION if migration_target is None: raise errors.QueryError( 'unexpected POPULATE MIGRATION:' ' not currently in a migration block', span=stmt.span, ) migration_diff = s_ddl.delta_schemas( 
migration_schema, migration_target, ) if debug.flags.delta_plan: debug.header('Populate Migration Diff') debug.dump(migration_diff, schema=schema) new_ddl = s_ddl.ddlast_from_delta( migration_schema, migration_target, migration_diff, ) migration_script.extend(new_ddl) if debug.flags.delta_plan: debug.header('Populate Migration DDL AST') text = [] for cmd in new_ddl: debug.dump(cmd) text.append(edgeql.generate_source(cmd, pretty=True)) debug.header('Populate Migration DDL Text') debug.dump_code(';\n'.join(text) + ';') elif isinstance(stmt, qlast.DescribeCurrentMigration): # This is silly, and we don't bother doing all the work, # but try to catch when doing the JSON thing wouldn't work. if stmt.language is qltypes.DescribeLanguage.JSON: guided_diff = s_ddl.delta_schemas( migration_schema, migration_target, generate_prompts=True, ) s_ddl.statements_from_delta( schema, migration_target, guided_diff, ) elif isinstance(stmt, qlast.CommitMigration): if migration_target is None: raise errors.QueryError( 'unexpected COMMIT MIGRATION:' ' not currently in a migration block', span=stmt.span, ) last_migration = current_schema.get_last_migration() if last_migration: last_migration_ref = s_utils.name_to_ast_ref( last_migration.get_name(current_schema), ) else: last_migration_ref = None create_migration = qlast.CreateMigration( body=qlast.NestedQLBlock(commands=migration_script), parent=last_migration_ref, ) ddl_plan = s_ddl.delta_from_ddl( create_migration, schema=migration_schema, modaliases={None: default_module}, testmode=True, ) if debug.flags.delta_plan: debug.header('Delta Plan') debug.dump(ddl_plan, schema=schema) migration_schema = None migration_target = None migration_script = [] elif isinstance(stmt, qlast.DDLCommand): if migration_target is not None: migration_script.append(stmt) ddl_plan = None else: ddl_plan = s_ddl.delta_from_ddl( stmt, schema=current_schema, modaliases={None: default_module}, testmode=True, ) if debug.flags.delta_plan: debug.header('Delta Plan') 
debug.dump(ddl_plan, schema=schema) else: raise ValueError( f'unexpected {stmt!r} in compiler setup script') if ddl_plan is not None: context = sd.CommandContext() context.testmode = True current_schema = ddl_plan.apply(current_schema, context) return current_schema @classmethod def load_schema( cls, source: str, modname: Optional[str] = None ) -> s_schema.Schema: if not modname: modname = cls.DEFAULT_MODULE sdl_schema = qlparser.parse_sdl(f'module {modname} {{ {source} }}') schema = _load_std_schema() return s_ddl.apply_sdl( sdl_schema, base_schema=schema, )[0] @classmethod def get_schema_script(cls): script = '' # look at all SCHEMA entries and potentially create multiple modules schema = [] for name in dir(cls): val = getattr(cls, name) m = re.match(r'^SCHEMA(?:_(\w+))?', name) if m and val: module_name = (m.group(1) or 'default').lower().replace('_', '::') if '\n' in val: # Inline schema source module = val else: with open(val, 'r') as sf: module = sf.read() schema.append(f'\nmodule {module_name} {{ {module} }}') if schema: script += f'\nSTART MIGRATION' script += f' TO {{ {"".join(schema)} }};' script += f'\nPOPULATE MIGRATION;' script += f'\nCOMMIT MIGRATION;' return script.strip(' \n') class BaseSchemaLoadTest(BaseSchemaTest): def run_test(self, *, source, spec, expected=None): self.load_schema(source) class BaseEdgeQLCompilerTest(BaseSchemaTest): @classmethod def get_schema_script(cls): script = super().get_schema_script() if not script: raise ValueError( 'compiler test cases must define at least one ' 'schema in the SCHEMA[_MODNAME] class attribute.') return script ================================================ FILE: edb/testbase/proc.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2021-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import asyncio import socket import sys import unittest from edb.common import devmode from . import server exec(sys.argv[1], globals(), locals()) class ProcTest(server.TestCase): def notify_parent(self, mark): self.parent_writer.write(str(mark).encode() + b"\n") async def wait_for_parent(self, mark): self.assertEqual( (await self.parent_reader.readline()).strip(), str(mark).encode(), ) @classmethod def setUpClass(cls): super().setUpClass() async def _setup(): sock = socket.fromfd( int(sys.argv[3]), socket.AF_UNIX, socket.SOCK_STREAM ) cls.parent_reader, cls.parent_writer = ( await asyncio.open_connection(sock=sock) ) cls.loop.run_until_complete(_setup()) exec(sys.argv[2], globals(), locals()) def main(): cov_config = devmode.CoverageConfig.from_environ() if cov_config: cov = cov_config.new_coverage_object() cov.start() try: unittest.main(argv=sys.argv[:1], verbosity=2) finally: cov.stop() cov.save() else: unittest.main(argv=sys.argv[:1], verbosity=2) if __name__ == "__main__": main() ================================================ FILE: edb/testbase/protocol/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2020-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


================================================
FILE: edb/testbase/protocol/test.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2020-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

from edb.testbase import server
from edb.protocol import protocol  # type: ignore
from edb.protocol.protocol import Connection


class ProtocolTestCase(server.DatabaseTestCase):

    PARALLELISM_GRANULARITY = 'database'
    BASE_TEST_CLASS = True

    con: Connection

    def setUp(self):
        self.con = self.loop.run_until_complete(
            protocol.new_connection(
                **self.get_connect_args(database=self.get_database_name())
            )
        )

    def tearDown(self):
        try:
            self.loop.run_until_complete(
                self.con.aclose()
            )
        finally:
            self.con = None


================================================
FILE: edb/testbase/serutils.py
================================================
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2019-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from __future__ import annotations

import dataclasses
import datetime
import decimal
import functools
import uuid

import edgedb


@functools.singledispatch
def serialize(o):
    raise TypeError(f'cannot serialize type {type(o)}')


@serialize.register
def _tuple(o: edgedb.Tuple):
    return [serialize(el) for el in o]


@serialize.register
def _namedtuple(o: edgedb.NamedTuple):
    return {attr: serialize(getattr(o, attr)) for attr in dir(o)}


@serialize.register
def _object(o: edgedb.Object):
    # We iterate over dataclasses.fields(o) (instead of dir(o))
    # because it contains both regular pointers and link properties,
    # and is I think the only current way to extract the names of all
    # the link properties
    attrs = [field.name for field in dataclasses.fields(o)]
    return {attr: serialize(getattr(o, attr)) for attr in attrs}


@serialize.register(edgedb.Set)
@serialize.register(edgedb.Array)
def _set(o):
    return [serialize(el) for el in o]


@serialize.register(uuid.UUID)
def _stringify(o):
    return str(o)


@serialize.register(int)
@serialize.register(float)
@serialize.register(str)
@serialize.register(bytes)
@serialize.register(bool)
@serialize.register(type(None))
@serialize.register(decimal.Decimal)
@serialize.register(datetime.timedelta)
@serialize.register(edgedb.RelativeDuration)
@serialize.register(edgedb.DateDuration)
def _scalar(o):
    return o


@serialize.register
def _datetime(o: datetime.datetime):
    return o.isoformat()


@serialize.register
def _date(o: datetime.date):
    return o.isoformat()


@serialize.register
def _time(o: datetime.time):
    return o.isoformat()


@serialize.register
def _enum(o: edgedb.EnumValue):
    return str(o)


@serialize.register
def _record(o: edgedb.Record):
    return {k: serialize(v) for k, v in o.as_dict().items()}


@serialize.register
def _range(o: edgedb.Range):
    return {
        'lower': serialize(o.lower),
        'inc_lower': o.inc_lower,
        'upper': serialize(o.upper),
        'inc_upper': o.inc_upper,
        'empty': o.is_empty(),
    }


@serialize.register
def _multirange(o: edgedb.MultiRange):
    return [serialize(el) for el in o]


@serialize.register
def _cfg_memory(o: edgedb.ConfigMemory):
    return str(o)


================================================
FILE: edb/testbase/server.py
================================================
# mypy: ignore-errors

#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# from __future__ import annotations from typing import ( Any, Optional, Iterable, Literal, Sequence, NamedTuple, TYPE_CHECKING, ) import typing import asyncio import atexit import base64 import contextlib import dataclasses import functools import heapq import http import http.client import inspect import json import os import pathlib import random import re import secrets import shlex import socket import ssl import struct import subprocess import sys import tempfile import time import unittest import urllib import edgedb from edb.edgeql import quote as qlquote from edb.server import args as edgedb_args from edb.testbase import cluster as edgedb_cluster from edb.server import pgcluster from edb.server import defines as edgedb_defines from edb.server import auth from edb.server.pgconnparams import ConnectionParams from edb.common import assert_data_shape from edb.common import devmode from edb.common import debug from edb.common import retryloop from edb.common import secretkey from edb import buildmeta from edb import protocol from edb.protocol import protocol as test_protocol from edb.testbase import serutils from edb.testbase import connection as tconn if TYPE_CHECKING: import asyncpg DatabaseName = str SetupScript = str def _add_test(result, test): # test is a tuple of the same test method, optionally followed by its zREPEAT copies cls = type(test[0]) try: methods, repeat_methods = result[cls] except KeyError: # put zREPEAT tests in a separate list methods = [] repeat_methods = [] result[cls] = methods, repeat_methods methods.append(test[0]) if len(test) > 1: repeat_methods.extend(test[1:]) def _merge_results(result): # make sure all the zREPEAT tests come at the end return {k: v[0] + v[1] for k, v in result.items()} def _get_test_cases(tests): result = {} for test in tests: if isinstance(test, unittest.TestSuite): result.update(_get_test_cases(test._tests)) elif not getattr(test, '__unittest_skip__', False): _add_test(result, (test,)) return result def get_test_cases(tests): return 
_merge_results(_get_test_cases(tests)) bag = assert_data_shape.bag generate_jwk = auth.generate_jwk generate_tls_cert = secretkey.generate_tls_cert class CustomSNI_HTTPSConnection(http.client.HTTPSConnection): def __init__(self, *args, server_hostname=..., **kwargs): super().__init__(*args, **kwargs) self.server_hostname = server_hostname def connect(self): super(http.client.HTTPSConnection, self).connect() if self._tunnel_host: server_hostname = self._tunnel_host elif self.server_hostname is not ...: server_hostname = self.server_hostname else: server_hostname = self.host self.sock = self._context.wrap_socket(self.sock, server_hostname=server_hostname) def true_close(self): self.close() class StubbornHttpConnection(CustomSNI_HTTPSConnection): def close(self): # Don't actually close the connection. This allows us to # test keep-alive and "Connection: close" headers. pass def true_close(self): http.client.HTTPConnection.close(self) class TestCaseMeta(type(unittest.TestCase)): _database_names = set() @staticmethod def _iter_methods(bases, ns): for base in bases: for methname in dir(base): if not methname.startswith('test_'): continue meth = getattr(base, methname) if not inspect.iscoroutinefunction(meth): continue yield methname, meth for methname, meth in ns.items(): if not methname.startswith('test_'): continue if not inspect.iscoroutinefunction(meth): continue yield methname, meth @classmethod def wrap(mcls, meth, is_repeat=False): @functools.wraps(meth) def wrapper(self, *args, __meth__=meth, **kwargs): try_no = 1 if is_repeat and not getattr(self, 'TRANSACTION_ISOLATION', False): raise unittest.SkipTest() self.is_repeat = is_repeat while True: try: # There might be unobvious serializability # anomalies across the test suite, so, rather # than hunting them down every time, simply # retry the test. 
self.loop.run_until_complete( __meth__(self, *args, **kwargs)) except (edgedb.TransactionSerializationError, edgedb.TransactionDeadlockError): if ( try_no == 10 # Only do a retry loop when we have a transaction or not getattr(self, 'TRANSACTION_ISOLATION', False) ): raise else: self.loop.run_until_complete(self.xact.rollback()) self.loop.run_until_complete(asyncio.sleep( min((2 ** try_no) * 0.1, 10) + random.randrange(100) * 0.001 )) self.xact = self.con.transaction() self.loop.run_until_complete(self.xact.start()) try_no += 1 else: break return wrapper @classmethod def add_method(mcls, methname, ns, meth): ns[methname] = mcls.wrap(meth) # If EDGEDB_TEST_REPEATS is set, duplicate all the tests. # This is valuable because it should exercise the function # cache. if ( os.environ.get('EDGEDB_TEST_REPEATS', None) and methname.startswith('test_') ): new = methname.replace('test_', 'test_zREPEAT_', 1) ns[new] = mcls.wrap(meth, is_repeat=True) def __new__(mcls, name, bases, ns): for methname, meth in mcls._iter_methods(bases, ns.copy()): if methname in ns: del ns[methname] mcls.add_method(methname, ns, meth) cls = super().__new__(mcls, name, bases, ns) if not ns.get('BASE_TEST_CLASS') and hasattr(cls, 'get_database_name'): dbname = cls.get_database_name() if name in mcls._database_names: raise TypeError( f'{name} wants duplicate database name: {dbname}') mcls._database_names.add(name) return cls class TestCase(unittest.TestCase, metaclass=TestCaseMeta): is_repeat: bool = False @classmethod def setUpClass(cls): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) cls.loop = loop @classmethod def tearDownClass(cls): cls.loop.close() asyncio.set_event_loop(None) @classmethod def uses_server(cls) -> bool: return True def add_fail_notes(self, **kwargs): if getattr(self, 'fail_notes', None) is None: self.fail_notes = {} self.fail_notes.update(kwargs) @contextlib.contextmanager def annotate(self, **kwargs): # Annotate the test in case the nested block of code fails. 
try: yield except Exception: self.add_fail_notes(**kwargs) raise @contextlib.contextmanager def assertRaisesRegex(self, exception, regex, msg=None, **kwargs): with super().assertRaisesRegex(exception, regex, msg=msg): try: yield except BaseException as e: if isinstance(e, exception): for attr_name, expected_val in kwargs.items(): val = getattr(e, attr_name) if val != expected_val: raise self.failureException( f'{exception.__name__} context attribute ' f'{attr_name!r} is {val} (expected ' f'{expected_val!r})') from e raise @staticmethod def try_until_succeeds( *, ignore: type[Exception] | tuple[type[Exception]] | None = None, ignore_regexp: str | None = None, delay: float=0.5, timeout: float=5 ): """Retry a block of code a few times ignoring the specified errors. Example: async for tr in self.try_until_succeeds( ignore=edgedb.AuthenticationError): async with tr: await edgedb.connect(...) """ if ignore is None and ignore_regexp is None: raise ValueError('Expect at least one of ignore or ignore_regexp') return retryloop.RetryLoop( backoff=retryloop.const_backoff(delay), timeout=timeout, ignore=ignore, ignore_regexp=ignore_regexp, ) @staticmethod def try_until_fails( *, wait_for: type[Exception] | tuple[type[Exception]] | None = None, wait_for_regexp: str | None = None, delay: float=0.5, timeout: float=5 ): """Retry a block of code a few times until the specified error happens. Example: async for tr in self.try_until_fails( wait_for=edgedb.AuthenticationError): async with tr: await edgedb.connect(...) 
""" if wait_for is None and wait_for_regexp is None: raise ValueError( 'Expect at least one of wait_for or wait_for_regexp' ) return retryloop.RetryLoop( backoff=retryloop.const_backoff(delay), timeout=timeout, wait_for=wait_for, wait_for_regexp=wait_for_regexp, ) def addCleanup(self, func, *args, **kwargs): @functools.wraps(func) def cleanup(): res = func(*args, **kwargs) if inspect.isawaitable(res): self.loop.run_until_complete(res) super().addCleanup(cleanup) def __getstate__(self): # TestCases get pickled when run in in separate OS processes # via `edb test -jN`. If they reference any unpickleable objects, # the test engine crashes with no indication why and on what test. # That said, most of the TestCases' guts are not needed for the # test results renderer, so we only keep the essential attributes # here. outcome = self._outcome if outcome is not None and getattr(outcome, "errors", []): # We don't use `test._outcome` to render errors in # our renderers. outcome.errors = [] return { '_testMethodName': self._testMethodName, '_outcome': outcome, '_testMethodDoc': self._testMethodDoc, '_subtest': self._subtest, '_cleanups': [], '_type_equality_funcs': self._type_equality_funcs, 'fail_notes': getattr(self, 'fail_notes', None), } @contextlib.contextmanager def assertChange( self, measure_fn: typing.Callable[[], int | float], expected_change: int | float ): before = measure_fn() try: yield finally: after = measure_fn() change = after - before self.assertEqual(expected_change, change) class RollbackException(Exception): pass class RollbackChanges: def __init__(self, con): self._conn = con async def __aenter__(self): self._tx = self._conn.transaction() await self._tx.start() async def __aexit__(self, exc_type, exc, tb): await self._tx.rollback() class TestCaseWithHttpClient(TestCase): @classmethod def get_api_prefix(cls): return '' @classmethod @contextlib.contextmanager def http_con( cls, server, keep_alive=True, client_cert_file=None, client_key_file=None, **kwargs, 
): conn_args = server.get_connect_args() tls_context = ssl.create_default_context( ssl.Purpose.SERVER_AUTH, cafile=conn_args["tls_ca_file"], ) tls_context.check_hostname = False if any((client_cert_file, client_key_file)): tls_context.load_cert_chain(client_cert_file, client_key_file) if keep_alive: ConCls = StubbornHttpConnection else: ConCls = CustomSNI_HTTPSConnection con = ConCls( conn_args["host"], conn_args["port"], context=tls_context, **kwargs, ) con.connect() try: yield con finally: con.true_close() @classmethod def http_con_send_request( cls, http_con: http.client.HTTPConnection, params: Optional[dict[str, str]] = None, *, prefix: Optional[str] = None, headers: Optional[dict[str, str]] = None, method: str = "GET", body: bytes = b"", path: str = "", ): url = f'https://{http_con.host}:{http_con.port}' if prefix is None: prefix = cls.get_api_prefix() if prefix: url = f'{url}{prefix}' if path: url = f'{url}/{path}' if params is not None: url = f'{url}?{urllib.parse.urlencode(params)}' if headers is None: headers = {} http_con.request(method, url, body=body, headers=headers) @classmethod def http_con_read_response( cls, http_con: http.client.HTTPConnection, ) -> tuple[bytes, dict[str, str], int]: resp = http_con.getresponse() resp_body = resp.read() resp_headers = {k.lower(): v for k, v in resp.getheaders()} return resp_body, resp_headers, resp.status @classmethod def http_con_request( cls, http_con: http.client.HTTPConnection, params: Optional[dict[str, str]] = None, *, prefix: Optional[str] = None, headers: Optional[dict[str, str]] = None, method: str = "GET", body: bytes = b"", path: str = "", ) -> tuple[bytes, dict[str, str], int]: cls.http_con_send_request( http_con, params, prefix=prefix, headers=headers, method=method, body=body, path=path, ) return cls.http_con_read_response(http_con) @classmethod def http_con_json_request( cls, http_con: http.client.HTTPConnection, params: Optional[dict[str, str]] = None, *, prefix: Optional[str] = None, headers: 
Optional[dict[str, str]] = None, body: Any, path: str = "", ): response, headers, status = cls.http_con_request( http_con, params, method="POST", body=json.dumps(body).encode(), prefix=prefix, headers={ "Content-Type": "application/json", **(headers or {}), }, path=path, ) if status == http.HTTPStatus.OK: result = json.loads(response) else: result = None return result, headers, status @classmethod def http_con_binary_request( cls, http_con: http.client.HTTPConnection, query: str, proto_ver=edgedb_defines.CURRENT_PROTOCOL, bearer_token: Optional[str] = None, user: str = "edgedb", database: str = "main", ): proto_ver_str = f"v_{proto_ver[0]}_{proto_ver[1]}" mime_type = f"application/x.edgedb.{proto_ver_str}.binary" headers = {"Content-Type": mime_type, "X-EdgeDB-User": user} if bearer_token: headers["Authorization"] = f"Bearer {bearer_token}" content, headers, status = cls.http_con_request( http_con, method="POST", path=f"db/{database}", prefix="", body=protocol.Execute( annotations=[], allowed_capabilities=protocol.Capability.ALL, compilation_flags=protocol.CompilationFlag(0), implicit_limit=0, command_text=query, input_language=protocol.InputLanguage.EDGEQL, output_format=protocol.OutputFormat.JSON, expected_cardinality=protocol.Cardinality.AT_MOST_ONE, input_typedesc_id=b"\0" * 16, output_typedesc_id=b"\0" * 16, state_typedesc_id=b"\0" * 16, arguments=b"", state_data=b"", ).dump() + protocol.Sync().dump(), headers=headers, ) content = memoryview(content) uint32_unpack = struct.Struct("!L").unpack msgs = [] while content: mtype = content[0] (msize,) = uint32_unpack(content[1:5]) msg = protocol.ServerMessage.parse(mtype, content[5: msize + 1]) msgs.append(msg) content = content[msize + 1:] return msgs, headers, status _default_cluster = None async def init_cluster( data_dir=None, backend_dsn=None, *, cleanup_atexit=True, init_settings=None, security=edgedb_args.ServerSecurityMode.Strict, http_endpoint_security=edgedb_args.ServerEndpointSecurityMode.Optional, 
compiler_pool_mode=edgedb_args.CompilerPoolMode.Fixed, ) -> edgedb_cluster.BaseCluster: if data_dir is not None and backend_dsn is not None: raise ValueError( "data_dir and backend_dsn cannot be set at the same time") if init_settings is None: init_settings = {} log_level = 's' if not debug.flags.server else 'd' if backend_dsn: cluster = edgedb_cluster.TempClusterWithRemotePg( backend_dsn, testmode=True, log_level=log_level, data_dir_prefix='edb-test-', security=security, http_endpoint_security=http_endpoint_security, compiler_pool_mode=compiler_pool_mode, ) destroy = True elif data_dir is None: cluster = edgedb_cluster.TempCluster( testmode=True, log_level=log_level, data_dir_prefix='edb-test-', security=security, http_endpoint_security=http_endpoint_security, compiler_pool_mode=compiler_pool_mode, ) destroy = True else: cluster = edgedb_cluster.Cluster( testmode=True, data_dir=data_dir, log_level=log_level, security=security, http_endpoint_security=http_endpoint_security, compiler_pool_mode=compiler_pool_mode, ) destroy = False pg_cluster = await cluster._get_pg_cluster() if await pg_cluster.get_status() == 'not-initialized': await cluster.init(server_settings=init_settings) await cluster.start(port=0) await cluster.set_test_config() await cluster.set_superuser_password('test') if cleanup_atexit: atexit.register(_shutdown_cluster, cluster, destroy=destroy) return cluster def _start_cluster( *, loop: asyncio.AbstractEventLoop, cleanup_atexit=True, http_endpoint_security=None, ): global _default_cluster if _default_cluster is None: cluster_addr = os.environ.get('EDGEDB_TEST_CLUSTER_ADDR') if cluster_addr: conn_spec = json.loads(cluster_addr) _default_cluster = edgedb_cluster.RunningCluster(**conn_spec) else: # This branch is not usually used - `edb test` will call # init_cluster() separately and set EDGEDB_TEST_CLUSTER_ADDR data_dir = os.environ.get('EDGEDB_TEST_DATA_DIR') backend_dsn = os.environ.get('EDGEDB_TEST_BACKEND_DSN') _default_cluster = 
loop.run_until_complete( init_cluster( data_dir=data_dir, backend_dsn=backend_dsn, cleanup_atexit=cleanup_atexit, http_endpoint_security=http_endpoint_security, ) ) return _default_cluster def _shutdown_cluster(cluster, *, destroy=True): global _default_cluster _default_cluster = None if cluster is not None: cluster.stop() if destroy: cluster.destroy() def _fetch_metrics(host: str, port: int, sslctx=None) -> str: return _call_system_api( host, port, '/metrics', return_json=False, sslctx=sslctx ) def _fetch_server_info(host: str, port: int) -> dict[str, Any]: return _call_system_api(host, port, '/server-info') def _call_system_api( host: str, port: int, path: str, return_json=True, sslctx=None, **kwargs, ): if sslctx is None: con = http.client.HTTPConnection(host, port, **kwargs) else: con = CustomSNI_HTTPSConnection(host, port, context=sslctx, **kwargs) con.connect() try: con.request( 'GET', f'http://{host}:{port}{path}' ) resp = con.getresponse() if resp.status != 200: err = resp.read().decode() raise AssertionError( f'{path} returned non 200 HTTP status: {resp.status}\n\t{err}' ) rv = resp.read().decode() if return_json: rv = json.loads(rv) return rv finally: con.close() def parse_metrics(metrics: str) -> dict[str, float]: res = {} for line in metrics.splitlines(): if line.startswith('#') or ' ' not in line: continue key, _, val = line.partition(' ') res[key] = float(val) return res def _extract_background_errors(metrics: str) -> str | None: non_zero = [] for label, total in parse_metrics(metrics).items(): if label.startswith('edgedb_server_background_errors_total'): if total: non_zero.append( f'non-zero {label!r} metric: {total}' ) if non_zero: return '\n'.join(non_zero) else: return None async def drop_db(conn, dbname): await conn.execute(f'DROP BRANCH {dbname}') class ClusterTestCase(TestCaseWithHttpClient): BASE_TEST_CLASS = True backend_dsn: Optional[str] = None # Some tests may want to manage transactions manually, # or affect non-transactional state, in 
which case # TRANSACTION_ISOLATION must be set to False TRANSACTION_ISOLATION = True # By default, tests from the same testsuite may be run in parallel in # several test worker processes. However, certain cases might exhibit # pathological locking behavior, or be parallel-unsafe altogether, in # which case PARALLELISM_GRANULARITY must be set to 'database', 'suite', # or 'system'. The 'database' granularity signals that no two runners # may execute tests on the same database in parallel, although the tests # may still run on copies of the test database. The 'suite' granularity # means that only one test worker is allowed to execute tests from this # suite. Finally, the 'system' granularity means that the test suite # is not parallelizable at all and must run sequentially with respect # to *all other* suites with 'system' granularity. PARALLELISM_GRANULARITY = 'default' # Turns on "Gel developer" mode, which allows using restricted # syntax like USING SQL and similar. It also allows modifying the # standard library (e.g. declaring casts).
INTERNAL_TESTMODE = True # Turns off query cache recompilation on DDL ENABLE_RECOMPILATION = False # Setup and teardown commands that run per test PER_TEST_SETUP: Sequence[str] = () PER_TEST_TEARDOWN: Sequence[str] = () @classmethod def setUpClass(cls): super().setUpClass() cls.cluster = _start_cluster( loop=cls.loop, cleanup_atexit=True, http_endpoint_security=( edgedb_args.ServerEndpointSecurityMode.Optional), ) cls.has_create_database = cls.cluster.has_create_database() cls.has_create_role = cls.cluster.has_create_role() cls.is_superuser = cls.has_create_database and cls.has_create_role cls.backend_dsn = os.environ.get('EDGEDB_TEST_BACKEND_DSN') if getattr(cls, 'BACKEND_SUPERUSER', False): if not cls.is_superuser: raise unittest.SkipTest('skipped due to lack of superuser') @classmethod async def tearDownSingleDB(cls): await cls.con.execute("RESET SCHEMA TO initial;") @classmethod def fetch_metrics(cls) -> str: assert cls.cluster is not None conargs = cls.cluster.get_connect_args() host, port = conargs['host'], conargs['port'] ctx = ssl.create_default_context() ctx.load_verify_locations(conargs['tls_ca_file']) return _fetch_metrics(host, port, sslctx=ctx) @classmethod def get_connect_args( cls, *, cluster=None, database=None, user=None, password=None, secret_key=None, ): if password is None and secret_key is None: password = "test" if cluster is None: cluster = cls.cluster if database is None: database = edgedb_defines.EDGEDB_SUPERUSER_DB if user is None: user = edgedb_defines.EDGEDB_SUPERUSER conargs = cluster.get_connect_args().copy() conargs.update(dict(user=user, password=password, database=database, secret_key=secret_key)) return conargs @classmethod def make_auth_header(cls, user=None, password=None): # urllib *does* have actual support for basic auth but it is so much # more annoying than just doing it yourself... 
conargs = cls.get_connect_args(user=user, password=password) username = conargs.get('user') password = conargs.get('password') key = f'{username}:{password}'.encode('ascii') basic_header = f'Basic {base64.b64encode(key).decode("ascii")}' return basic_header @classmethod def get_parallelism_granularity(cls): if cls.PARALLELISM_GRANULARITY == 'default': if cls.TRANSACTION_ISOLATION: return 'default' else: return 'database' else: return cls.PARALLELISM_GRANULARITY @classmethod def uses_database_copies(cls): return ( os.environ.get('EDGEDB_TEST_PARALLEL') and cls.get_parallelism_granularity() == 'database' ) def ensure_no_background_server_errors(self): metrics = self.fetch_metrics() errors = _extract_background_errors(metrics) if errors: raise AssertionError( f'{self._testMethodName!r}:\n\n{errors}' ) @contextlib.asynccontextmanager async def assertRaisesRegexTx(self, exception, regex, msg=None, **kwargs): """A version of assertRaisesRegex with automatic transaction recovery """ with super().assertRaisesRegex(exception, regex, msg=msg, **kwargs): try: tx = self.con.transaction() await tx.start() yield finally: await tx.rollback() @classmethod @contextlib.contextmanager def http_con( cls, server=None, keep_alive=True, client_cert_file=None, client_key_file=None, **kwargs, ): if server is None: server = cls with super().http_con( server, keep_alive=keep_alive, client_cert_file=client_cert_file, client_key_file=client_key_file, **kwargs, ) as http_con: yield http_con @property def http_addr(self) -> str: conn_args = self.get_connect_args() url = f'https://{conn_args["host"]}:{conn_args["port"]}' prefix = self.get_api_prefix() if prefix: url = f'{url}{prefix}' return url @property def tls_context(self) -> ssl.SSLContext: conn_args = self.get_connect_args() tls_context = ssl.create_default_context( ssl.Purpose.SERVER_AUTH, cafile=conn_args["tls_ca_file"], ) tls_context.check_hostname = False return tls_context def ignore_warnings(warning_message=None): def w(f): async def 
wf(self, *args, **kwargs): with self.ignore_warnings(warning_message): return await f(self, *args, **kwargs) return wf return w class ConnectedTestCase(ClusterTestCase): BASE_TEST_CLASS = True NO_FACTOR = True WARN_FACTOR = False con: tconn.Connection @classmethod def setUpClass(cls): super().setUpClass() cls.loop.run_until_complete(cls.setup_and_connect()) @classmethod def tearDownClass(cls): try: cls.loop.run_until_complete(cls.teardown_and_disconnect()) finally: super().tearDownClass() @contextlib.contextmanager def ignore_warnings(self, warning_message=None): with self.con.capture_warnings() as warnings: yield if warning_message is not None: for warning in warnings: # If it doesn't match the re, send it back to the con. # It might get raised or it might get captured by an # enclosing call to capture_warnings/ignore_warnings. if not re.search(warning_message, str(warning)): self.con._get_warning_handler()([warning], None) @classmethod async def setup_and_connect(cls): cls.con = await cls.connect() @classmethod async def teardown_and_disconnect(cls): await cls.con.aclose() # Give event loop another iteration so that connection # transport has a chance to properly close. 
await asyncio.sleep(0) cls.con = None def setUp(self): if self.INTERNAL_TESTMODE: self.loop.run_until_complete( self.con.execute( 'CONFIGURE SESSION SET __internal_testmode := true;')) if not self.ENABLE_RECOMPILATION: self.loop.run_until_complete( self.con.execute( 'CONFIGURE SESSION SET auto_rebuild_query_cache := false;' ) ) if self.NO_FACTOR: self.loop.run_until_complete( self.con.execute( 'CONFIGURE SESSION SET simple_scoping := true;')) if self.WARN_FACTOR: self.loop.run_until_complete( self.con.execute( 'CONFIGURE SESSION SET warn_old_scoping := true;')) if self.TRANSACTION_ISOLATION: self.xact = self.con.transaction() self.loop.run_until_complete(self.xact.start()) for cmd in self.PER_TEST_SETUP: self.loop.run_until_complete(self.con.execute(cmd)) super().setUp() def tearDown(self): try: self.ensure_no_background_server_errors() for cmd in self.PER_TEST_TEARDOWN: self.loop.run_until_complete(self.con.execute(cmd)) finally: try: if self.TRANSACTION_ISOLATION: self.loop.run_until_complete(self.xact.rollback()) del self.xact if self.con.is_in_transaction(): self.loop.run_until_complete( self.con.query('ROLLBACK')) raise AssertionError( 'test connection is still in transaction ' '*after* the test') if not self.TRANSACTION_ISOLATION: self.loop.run_until_complete( self.con.execute('RESET ALIAS *;')) finally: super().tearDown() @classmethod async def connect( cls, *, cluster=None, database=None, user=None, password=None, secret_key=None, ) -> tconn.Connection: conargs = cls.get_connect_args( cluster=cluster, database=database, user=user, password=password, secret_key=secret_key, ) return await tconn.async_connect_test_client(**conargs) def repl(self): """Open interactive EdgeQL REPL right in the test. This is obviously only for debugging purposes. Just add `self.repl()` at any point in your test. 
""" conargs = self.get_connect_args() cmd = [ 'python', '-m', 'edb.cli', '--database', self.con.dbname, '--user', conargs['user'], '--tls-ca-file', conargs['tls_ca_file'], ] env = os.environ.copy() env['EDGEDB_HOST'] = conargs['host'] env['EDGEDB_PORT'] = str(conargs['port']) if password := conargs.get('password'): env['EDGEDB_PASSWORD'] = password if secret_key := conargs.get('secret_key'): env['EDGEDB_SECRET_KEY'] = secret_key proc = subprocess.Popen( cmd, stdin=sys.stdin, stdout=sys.stdout, env=env) while proc.returncode is None: try: proc.wait() except KeyboardInterrupt: pass def _run_and_rollback(self): return RollbackChanges(self.con) async def _run_and_rollback_retrying(self): @contextlib.asynccontextmanager async def cm(tx): try: async with tx: await tx._ensure_transaction() yield tx raise RollbackException except RollbackException: pass async for tx in self.con.retrying_transaction(): yield cm(tx) def assert_data_shape(self, data, shape, message=None, rel_tol=None, abs_tol=None): assert_data_shape.assert_data_shape( data, shape, self.fail, message=message, rel_tol=rel_tol, abs_tol=abs_tol, ) async def assert_query_result( self, query, exp_result_json, exp_result_binary=..., *, always_typenames=False, always_typeids=False, msg=None, sort=None, implicit_limit=0, variables=None, json_only=False, binary_only=False, rel_tol=None, abs_tol=None, language: Literal["sql", "edgeql"] = "edgeql", ): fetch_args = variables if isinstance(variables, tuple) else () fetch_kw = variables if isinstance(variables, dict) else {} if not binary_only and language != "sql": try: tx = self.con.transaction() await tx.start() try: res = await self.con._fetchall_json( query, *fetch_args, __limit__=implicit_limit, **fetch_kw) finally: await tx.rollback() res = json.loads(res) if sort is not None: assert_data_shape.sort_results(res, sort) assert_data_shape.assert_data_shape( res, exp_result_json, self.fail, message=msg, rel_tol=rel_tol, abs_tol=abs_tol, ) except Exception: 
self.add_fail_notes(serialization='json') if msg: self.add_fail_notes(msg=msg) raise if json_only: return if exp_result_binary is ...: # The expected result is the same exp_result_binary = exp_result_json typenames = random.choice([True, False]) or always_typenames typeids = random.choice([True, False]) or always_typeids try: res = await self.con._fetchall( query, *fetch_args, __typenames__=typenames, __typeids__=typeids, __limit__=implicit_limit, __language__=( tconn.InputLanguage.SQL if language == "sql" else tconn.InputLanguage.EDGEQL ), **fetch_kw ) res = serutils.serialize(res) if sort is not None: assert_data_shape.sort_results(res, sort) assert_data_shape.assert_data_shape( res, exp_result_binary, self.fail, message=msg, rel_tol=rel_tol, abs_tol=abs_tol, ) except Exception: self.add_fail_notes( serialization='binary', __typenames__=typenames, __typeids__=typeids) if msg: self.add_fail_notes(msg=msg) raise async def assert_sql_query_result( self, query, exp_result, *, implicit_limit=0, msg=None, sort=None, variables=None, rel_tol=None, abs_tol=None, apply_access_policies=True, ): if not apply_access_policies: ctx = self.without_access_policies() else: ctx = contextlib.nullcontext() async with ctx: await self.assert_query_result( query, exp_result, implicit_limit=implicit_limit, msg=msg, sort=sort, variables=variables, rel_tol=rel_tol, abs_tol=abs_tol, language="sql", ) async def assert_index_use(self, query, *args, index_type): def look(obj): if ( isinstance(obj, dict) and "IndexScan" in obj.get('plan_type', '') ): return any( prop['title'] == 'index_name' and index_type in prop['value'] for prop in obj.get('properties', []) ) if isinstance(obj, dict): return any([look(v) for v in obj.values()]) elif isinstance(obj, list): return any(look(v) for v in obj) else: return False plan = await self.con.query_json(f'analyze {query}', *args) if not look(json.loads(plan)): raise AssertionError(f"query did not use the {index_type!r} index") @classmethod def 
get_backend_sql_dsn(cls, dbname=None): settings = cls.con.get_settings() pgdsn = settings.get('pgdsn') if pgdsn is None: raise unittest.SkipTest('raw SQL test skipped: not in devmode') params = ConnectionParams(dsn=pgdsn.decode('utf8')) if dbname: params.update(database=dbname) params.clear_server_settings() return params.to_dsn() @classmethod async def get_backend_sql_connection(cls, dbname=None): """Get a raw connection to the underlying SQL server, if possible. This is useful when we want to do things like querying the pg_catalog of the underlying database. """ try: import asyncpg except ImportError: raise unittest.SkipTest( 'SQL test skipped: asyncpg not installed') pgdsn = cls.get_backend_sql_dsn(dbname=dbname) return await asyncpg.connect(pgdsn) @classmethod @contextlib.asynccontextmanager async def with_backend_sql_connection(cls, dbname=None): con = await cls.get_backend_sql_connection(dbname=dbname) try: yield con finally: await con.close() @contextlib.asynccontextmanager async def without_access_policies(self): await self.con.execute( 'CONFIGURE SESSION SET apply_access_policies := false' ) raised_an_exception = False try: yield except BaseException: raised_an_exception = True raise finally: if not (raised_an_exception and self.con.is_in_transaction()): await self.con.execute( 'CONFIGURE SESSION RESET apply_access_policies' ) @classmethod def get_sql_proto_dsn(cls, dbname=None): dbname = dbname or cls.con.dbname conargs = cls.get_connect_args() return ( f"postgres://{conargs['user']}:{conargs['password']}@" f"{conargs['host']}:{conargs['port']}/{dbname}?" f"sslrootcert={conargs['tls_ca_file']}" ) class DatabaseTestCase(ConnectedTestCase): SETUP: Optional[str | pathlib.Path | list[str] | list[pathlib.Path]] = None TEARDOWN: Optional[str] = None SCHEMA: Optional[str | pathlib.Path] = None DEFAULT_MODULE: str = 'default' EXTENSIONS: list[str] = [] BASE_TEST_CLASS = True con: Any # XXX: the real type?
@classmethod async def setup_and_connect(cls): dbname = cls.get_database_name() cls.con = None class_set_up = os.environ.get('EDGEDB_TEST_CASES_SET_UP', 'run') # Only open an extra admin connection if necessary. if class_set_up == 'run': script = f'CREATE DATABASE {dbname};' admin_conn = await cls.connect( database=edgedb_defines.EDGEDB_SUPERUSER_DB ) await admin_conn.execute(script) await admin_conn.aclose() elif class_set_up == 'inplace': dbname = edgedb_defines.EDGEDB_SUPERUSER_DB elif cls.uses_database_copies(): admin_conn = await cls.connect( database=edgedb_defines.EDGEDB_SUPERUSER_DB ) base_db_name, _, _ = dbname.rpartition('_') if cls.get_setup_script(): await admin_conn.execute(''' configure session set __internal_testmode := true; ''') create_command = ( f'CREATE TEMPLATE BRANCH {qlquote.quote_ident(dbname)}' f' FROM {qlquote.quote_ident(base_db_name)};' ) else: create_command = ( f'CREATE EMPTY BRANCH {qlquote.quote_ident(dbname)}') # The retry here allows the test to survive a concurrent testing # Gel server (e.g. 
async with tb.start_edgedb_server()) whose # introspection holds a lock on the base_db here async for tr in cls.try_until_succeeds( ignore=edgedb.ExecutionError, timeout=30, ): async with tr: await admin_conn.execute(create_command) await admin_conn.aclose() cls.con = await cls.connect(database=dbname) if class_set_up != 'skip': script = cls.get_setup_script() if script: with cls.con.ignore_warnings(): await cls.con.execute(script) @staticmethod def get_set_up(): return os.environ.get('EDGEDB_TEST_CASES_SET_UP', 'run') @classmethod async def teardown_and_disconnect(cls): script = '' class_set_up = cls.get_set_up() if cls.TEARDOWN and class_set_up != 'skip': script = cls.TEARDOWN.strip() try: if script: await cls.con.execute(script) if class_set_up == 'inplace': await cls.tearDownSingleDB() finally: await cls.con.aclose() if class_set_up == 'inplace': pass elif class_set_up == 'run' or cls.uses_database_copies(): dbname = qlquote.quote_ident(cls.get_database_name()) admin_conn = await cls.connect( database=edgedb_defines.EDGEDB_SUPERUSER_DB ) try: await drop_db(admin_conn, dbname) finally: await admin_conn.aclose() @classmethod def get_connect_args( cls, *, database=None, **kwargs, ): return super().get_connect_args( database=database or cls.get_database_name(), **kwargs, ) @classmethod def get_database_name(cls): if not getattr(cls, 'has_create_database', True): return edgedb_defines.EDGEDB_SUPERUSER_DB if cls.__name__.startswith('TestEdgeQL'): dbname = cls.__name__[len('TestEdgeQL'):] elif cls.__name__.startswith('Test'): dbname = cls.__name__[len('Test'):] else: dbname = cls.__name__ if cls.uses_database_copies(): return f'{dbname.lower()}_{os.getpid()}' else: return dbname.lower() @classmethod def get_api_prefix(cls): return f'/db/{cls.get_database_name()}' @classmethod def get_setup_script(cls): script = '' has_nontrivial_script = False # allow the setup script to also run in test mode and no recompilation if cls.INTERNAL_TESTMODE: script += '\nCONFIGURE 
SESSION SET __internal_testmode := true;' if not cls.ENABLE_RECOMPILATION: script += ( '\nCONFIGURE SESSION SET auto_rebuild_query_cache := false;' ) if getattr(cls, 'BACKEND_SUPERUSER', False): is_superuser = getattr(cls, 'is_superuser', True) if not is_superuser: raise unittest.SkipTest('skipped due to lack of superuser') schema = [] # Include the extensions before adding schemas. for ext in cls.EXTENSIONS: schema.append(f'using extension {ext};') # Look at all SCHEMA entries and potentially create multiple # modules, but always create the test module, if not `default`. if cls.DEFAULT_MODULE != 'default': schema.append(f'\nmodule {cls.DEFAULT_MODULE} {{}}') for name in dir(cls): m = re.match(r'^SCHEMA(?:_(\w+))?', name) if m: module_name = ( (m.group(1) or cls.DEFAULT_MODULE) .lower().replace('_', '::') ) schema_fn = getattr(cls, name) if schema_fn is not None: with open(schema_fn, 'r') as sf: module = sf.read() schema.append(f'\nmodule {module_name} {{ {module} }}') full_schema_fn = getattr(cls, 'FULL_SCHEMA', None) if full_schema_fn: with open(full_schema_fn, 'r') as sf: schema.append(sf.read()) if schema: has_nontrivial_script = True script += '\nSTART MIGRATION' script += f' TO {{ {"".join(schema)} }};' script += '\nPOPULATE MIGRATION;' script += '\nCOMMIT MIGRATION;' if cls.SETUP: if not isinstance(cls.SETUP, (list, tuple)): scripts = [cls.SETUP] else: scripts = cls.SETUP for scr in scripts: has_nontrivial_script = True is_path = ( isinstance(scr, pathlib.Path) or '\n' not in scr and os.path.exists(scr) ) if is_path: with open(scr, 'rt') as f: setup_text = f.read() else: assert isinstance(scr, str) setup_text = scr script += '\n' + setup_text # If the SETUP script did a SET MODULE, make sure it is cleared # (since in some modes we keep using the same connection) script += '\nRESET MODULE;' # restore the session settings changed at the top of the script if cls.INTERNAL_TESTMODE: script += '\nCONFIGURE SESSION SET __internal_testmode := false;' if not
cls.ENABLE_RECOMPILATION: script += '\nCONFIGURE SESSION RESET auto_rebuild_query_cache;' return script.strip(' \n') if has_nontrivial_script else '' async def migrate(self, migration, *, module: Optional[str] = 'default'): if module: migration = f""" module {module} {{ {migration} }} """ with self.ignore_warnings('Non-simple_scoping will be removed'): await self.con.execute(f""" START MIGRATION TO {{ {migration} }}; POPULATE MIGRATION; COMMIT MIGRATION; """) class Error: def __init__(self, cls, message, shape): self._message = message self._class = cls self._shape = shape @property def message(self): return self._message @property def cls(self): return self._class @property def shape(self): return self._shape class BaseQueryTestCase(DatabaseTestCase): BASE_TEST_CLASS = True class DDLTestCase(BaseQueryTestCase): # DDL test cases generally need to be serialized # to avoid deadlocks in parallel execution. PARALLELISM_GRANULARITY = 'database' BASE_TEST_CLASS = True class QueryTestCase(BaseQueryTestCase): BASE_TEST_CLASS = True class SQLQueryTestCase(BaseQueryTestCase): BASE_TEST_CLASS = True scon: asyncpg.Connection @classmethod def setUpClass(cls): try: import asyncpg # noqa: F401 except ImportError: raise unittest.SkipTest('SQL tests skipped: asyncpg not installed') super().setUpClass() cls.scon = cls.loop.run_until_complete( cls.create_sql_connection() ) @classmethod def create_sql_connection( cls, *, user: str = None, password: str = None, ) -> asyncio.Future[asyncpg.Connection]: import asyncpg conargs = cls.get_connect_args() tls_context = ssl.create_default_context( ssl.Purpose.SERVER_AUTH, cafile=conargs["tls_ca_file"], ) tls_context.check_hostname = False return asyncpg.connect( host=conargs['host'], port=conargs['port'], user=conargs['user'] if user is None else user, password=conargs['password'] if password is None else password, database=cls.con.dbname, ssl=tls_context, ) @classmethod def tearDownClass(cls): try: 
cls.loop.run_until_complete(cls.scon.close()) # Give the event loop another iteration so that the connection # transport has a chance to properly close. cls.loop.run_until_complete(asyncio.sleep(0)) cls.scon = None finally: super().tearDownClass() def setUp(self): if self.TRANSACTION_ISOLATION: self.stran = self.scon.transaction() self.loop.run_until_complete(self.stran.start()) super().setUp() def tearDown(self): try: if self.TRANSACTION_ISOLATION: self.loop.run_until_complete(self.stran.rollback()) self.loop.run_until_complete(self.scon.execute('RESET ALL')) finally: super().tearDown() async def squery_values(self, query, *args): res = await self.scon.fetch(query, *args) return [list(r.values()) for r in res] def assert_shape(self, res: Any, rows: int, columns: int | list[str]): """ Fail if the query result does not conform to the specified shape, defined in terms of: - number of rows, - number of columns (not checked if there are no rows), - column names. """ self.assertEqual(len(res), rows) if isinstance(columns, int): if rows > 0: self.assertEqual(len(res[0]), columns) elif isinstance(columns, list): self.assertListEqual(columns, list(res[0].keys())) class CLITestCaseMixin: def run_cli(self, *args, input: Optional[str] = None) -> None: conn_args = self.get_connect_args() self.run_cli_on_connection(conn_args, *args, input=input) @classmethod def run_cli_on_connection( cls, conn_args: dict[str, Any], *args, input: Optional[str] = None ) -> None: cmd_args = [ '--host', conn_args['host'], '--port', str(conn_args['port']), '--tls-ca-file', conn_args['tls_ca_file'] ] if conn_args.get('user'): cmd_args += ['--user', conn_args['user']] if conn_args.get('password'): cmd_args += ['--password-from-stdin'] if input is not None: input = f"{conn_args['password']}\n{input}" else: input = f"{conn_args['password']}\n" cmd_args += args cmd = ['gel'] + cmd_args try: subprocess.run( cmd, input=input.encode() if input else None, check=True, capture_output=True, ) except
subprocess.CalledProcessError as e: output = '\n'.join(getattr(out, 'decode', out.__str__)() for out in [e.output, e.stderr] if out) raise AssertionError( f'command {cmd} returned non-zero exit status {e.returncode}' f'\n{output}' ) from e class DumpCompatTestCaseMeta(TestCaseMeta): def __new__( mcls, name, bases, ns, *, dump_subdir=None, check_method=None, ): if not name.startswith('Test'): return super().__new__(mcls, name, bases, ns) if dump_subdir is None: raise TypeError( f'{name}: missing required "dump_subdir" class argument') if check_method is None: raise TypeError( f'{name}: missing required "check_method" class argument') mod = sys.modules[ns['__module__']] dumps_dir = pathlib.Path(mod.__file__).parent / 'dumps' / dump_subdir async def check_dump_restore_compat_single_db(self, *, dumpfn): dbname = edgedb_defines.EDGEDB_SUPERUSER_DB self.run_cli('-d', dbname, 'restore', str(dumpfn)) try: await check_method(self) finally: await self.tearDownSingleDB() async def check_dump_restore_compat(self, *, dumpfn: pathlib.Path): if not self.has_create_database: return await check_dump_restore_compat_single_db( self, dumpfn=dumpfn ) dbname = f"{type(self).__name__}_{dumpfn.stem}" qdbname = qlquote.quote_ident(dbname) await self.con.execute(f'CREATE DATABASE {qdbname}') try: self.run_cli('-d', dbname, 'restore', str(dumpfn)) con2 = await self.connect(database=dbname) except Exception: await drop_db(self.con, qdbname) raise oldcon = self.__class__.con self.__class__.con = con2 try: await check_method(self) finally: self.__class__.con = oldcon await con2.aclose() await drop_db(self.con, qdbname) for entry in dumps_dir.iterdir(): if not entry.is_file() or not entry.name.endswith(".dump"): continue mcls.add_method( f'test_{dump_subdir}_restore_compatibility_{entry.stem}', ns, functools.partial(check_dump_restore_compat, dumpfn=entry), ) return super().__new__(mcls, name, bases, ns) class DumpCompatTestCase( ConnectedTestCase, CLITestCaseMixin, 
metaclass=DumpCompatTestCaseMeta, ): BASE_TEST_CLASS = True TRANSACTION_ISOLATION = False class StableDumpTestCase(QueryTestCase, CLITestCaseMixin): BASE_TEST_CLASS = True STABLE_DUMP = True TRANSACTION_ISOLATION = False PARALLELISM_GRANULARITY = 'suite' async def check_dump_restore_single_db(self, check_method): with tempfile.TemporaryDirectory() as f: fname = os.path.join(f, 'dump') dbname = edgedb_defines.EDGEDB_SUPERUSER_DB await asyncio.to_thread(self.run_cli, '-d', dbname, 'dump', fname) await self.tearDownSingleDB() await asyncio.to_thread( self.run_cli, '-d', dbname, 'restore', fname ) # Cycle the connection to avoid state mismatches await self.con.aclose() self.con = await self.connect(database=dbname) await check_method(self) async def check_dump_restore( self, check_method, include_secrets: bool=False ): if not self.has_create_database: return await self.check_dump_restore_single_db(check_method) src_dbname = self.get_database_name() tgt_dbname = f'{src_dbname}_restored' q_tgt_dbname = qlquote.quote_ident(tgt_dbname) with tempfile.TemporaryDirectory() as f: fname = os.path.join(f, 'dump') extra = ['--include-secrets'] if include_secrets else [] await asyncio.to_thread( self.run_cli, '-d', src_dbname, 'dump', fname, *extra ) await self.con.execute(f'CREATE DATABASE {q_tgt_dbname}') try: await asyncio.to_thread( self.run_cli, '-d', tgt_dbname, 'restore', fname ) con2 = await self.connect(database=tgt_dbname) except Exception: await drop_db(self.con, q_tgt_dbname) raise oldcon = self.con self.__class__.con = con2 try: await check_method(self) finally: self.__class__.con = oldcon await con2.aclose() await drop_db(self.con, q_tgt_dbname) async def check_branching(self, include_data=False, *, check_method): if not self.has_create_database: self.skipTest("create branch is not supported by the backend") orig_branch = self.get_database_name() new_branch = f'new_{orig_branch}' # record the original schema orig_schema = await self.con.query_single('describe schema 
as sdl') # connect to a default branch so we can create a new branch branch_type = 'data' if include_data else 'schema' await self.con.execute( f'create {branch_type} branch {new_branch} ' f'from {orig_branch}' ) try: con2 = await self.connect(database=new_branch) except Exception: await drop_db(self.con, new_branch) raise oldcon = self.con self.__class__.con = con2 try: # We cannot compare the SDL text of the new branch schema to the # original because the order in which it renders all the # components is not guaranteed. Instead we will use migrations to # compare the new branch schema to the original. We expect there # to be no difference and therefore a new migration to the # original schema should have the "complete" status right away. with self.ignore_warnings(): await self.con.execute( f'start migration to {{ {orig_schema} }}' ) mig_status = json.loads( await self.con.query_single_json( 'describe current migration as json' ) ) self.assertTrue(mig_status.get('complete')) await self.con.execute('abort migration') # run the check_method on the copied branch if include_data: await check_method(self) else: await check_method(self, include_data=include_data) finally: self.__class__.con = oldcon await con2.aclose() await drop_db(self.con, new_branch) class StablePGDumpTestCase(BaseQueryTestCase): BASE_TEST_CLASS = True TRANSACTION_ISOLATION = False def run_pg_dump(self, *args, input: Optional[str] = None) -> None: conargs = self.get_connect_args() self.run_pg_dump_on_connection(conargs, *args, input=input) @classmethod def run_pg_dump_on_connection( cls, dsn: str, *args, input: Optional[str] = None ) -> None: cmd = [cls._pg_bin_dir / 'pg_dump', '--dbname', dsn] cmd += args try: subprocess.run( cmd, input=input.encode() if input else None, check=True, capture_output=True, ) except subprocess.CalledProcessError as e: output = '\n'.join(getattr(out, 'decode', out.__str__)() for out in [e.output, e.stderr] if out) raise AssertionError( f'command {cmd} returned non-zero 
exit status {e.returncode}' f'\n{output}' ) from e @classmethod def setUpClass(cls): try: import asyncpg except ImportError: raise unittest.SkipTest('SQL tests skipped: asyncpg not installed') if cls.get_set_up() == 'inplace': raise unittest.SkipTest('SQL dump tests skipped in single db mode') super().setUpClass() frontend_dsn = cls.get_sql_proto_dsn() src_dbname = cls.con.dbname tgt_dbname = f'restored_{src_dbname}' try: newdsn = cls.get_backend_sql_dsn(dbname=tgt_dbname) except Exception: super().tearDownClass() raise cls._pg_bin_dir = cls.loop.run_until_complete( pgcluster.get_pg_bin_dir()) cls.backend = cls.loop.run_until_complete( cls.get_backend_sql_connection()) # Run pg_dump to create the dump data for an existing Gel database. with tempfile.NamedTemporaryFile() as f: cls.run_pg_dump_on_connection(frontend_dsn, '-f', f.name) # Skip the restore part of the test if the database # backend is older than our pg_dump, since it won't work. pg_ver_str = cls.loop.run_until_complete( cls.backend.fetch('select version()') )[0][0] pg_ver = buildmeta.parse_pg_version(pg_ver_str) bundled_ver = buildmeta.get_pg_version() if pg_ver.major < bundled_ver.major: raise unittest.SkipTest('pg_dump newer than backend') # Create a new Postgres database to be used for dump tests. 
db_exists = cls.loop.run_until_complete( cls.backend.fetch(f''' SELECT oid FROM pg_database WHERE datname = {tgt_dbname!r} ''') ) if list(db_exists): cls.loop.run_until_complete( cls.backend.execute(f'drop database {tgt_dbname}') ) cls.loop.run_until_complete( cls.backend.execute(f'create database {tgt_dbname}') ) # Populate the new database using the dump cmd = [ cls._pg_bin_dir / 'psql', '-a', '--dbname', newdsn, '-f', f.name, '-v', 'ON_ERROR_STOP=on', ] try: subprocess.run( cmd, input=None, check=True, capture_output=True, ) except subprocess.CalledProcessError as e: output = '\n'.join(getattr(out, 'decode', out.__str__)() for out in [e.output, e.stderr] if out) raise AssertionError( f'command {cmd} returned non-zero exit status ' f'{e.returncode}\n{output}' ) from e # Connect to the newly created database. cls.scon = cls.loop.run_until_complete( asyncpg.connect(newdsn)) @classmethod def tearDownClass(cls): try: cls.loop.run_until_complete(cls.scon.close()) # Give event loop another iteration so that connection # transport has a chance to properly close. cls.loop.run_until_complete(asyncio.sleep(0)) cls.scon = None tgt_dbname = f'restored_{cls.con.dbname}' cls.loop.run_until_complete( cls.backend.execute(f'drop database {tgt_dbname}') ) cls.loop.run_until_complete(cls.backend.close()) cls.loop.run_until_complete(asyncio.sleep(0)) finally: super().tearDownClass() def assert_shape( self, sqlres: Iterable[Any], eqlres: Iterable[asyncpg.Record], ) -> None: """ Compare the shape of results produced by a SQL query and an EdgeQL query. """ assert_data_shape.assert_data_shape( list(sqlres), [dataclasses.asdict(r) for r in eqlres], self.fail, from_sql=True, ) def multi_prop_subquery(self, source: str, prop: str) -> str: "Produce a subquery fetching a multi prop as an array."
return ( f'(SELECT array_agg(target) FROM "{source}.{prop}"' f' WHERE source = "{source}".id) AS {prop}' ) def single_link_subquery( self, source: str, link: str, target: str, link_props: Optional[Iterable[str]] = None ) -> str: """Produce a subquery fetching a single link as a record. If no link properties are specified then the record will be a target type record. If the link properties are specified then the record will be a link record. """ if link_props: return ( f'(SELECT x FROM "{target}"' f' JOIN "{source}.{link}" x ON x.target = "{target}".id' f' WHERE x.source = "{source}".id) AS _{link}' ) else: return ( f'(SELECT "{target}" FROM "{target}"' f' WHERE "{target}".id = "{source}".{link}_id) AS {link}' ) def multi_link_subquery( self, source: str, link: str, target: str, link_props: Optional[Iterable[str]] = None ) -> str: """Produce a subquery fetching a multi link as an array of records. If no link properties are specified then the array of records will be made up of target types. If the link properties are specified then the array of records will be made up of link records.
""" if link_props: return ( f'(SELECT array_agg(x) FROM "{target}"' f' JOIN "{source}.{link}" x ON x.target = "{target}".id' f' WHERE x.source = "{source}".id) AS _{link}' ) else: return ( f'(SELECT array_agg("{target}") FROM "{target}"' f' JOIN "{source}.{link}" x ON x.target = "{target}".id' f' WHERE x.source = "{source}".id) AS {link}' ) def get_test_cases_setup( cases: Iterable[unittest.TestCase] ) -> list[tuple[unittest.TestCase, DatabaseName, SetupScript]]: result: list[tuple[unittest.TestCase, DatabaseName, SetupScript]] = [] for case in cases: if not hasattr(case, 'get_setup_script'): continue try: setup_script = case.get_setup_script() except unittest.SkipTest: continue dbname = case.get_database_name() result.append((case, dbname, setup_script)) return result def test_cases_use_server(cases: Iterable[unittest.TestCase]) -> bool: for case in cases: if not hasattr(case, 'uses_server'): continue if case.uses_server(): return True async def setup_test_cases( cases, conn, num_jobs, try_cached_db=False, skip_empty_databases=False, verbose=False, ): setup = get_test_cases_setup(cases) stats = [] if num_jobs == 1: # Special case for --jobs=1 for _case, dbname, setup_script in setup: if skip_empty_databases and not setup_script: continue await _setup_database( dbname, setup_script, conn, stats, try_cached_db) if verbose: print(f' -> {dbname}: OK', flush=True) else: async with asyncio.TaskGroup() as g: # Use a semaphore to limit the concurrency of bootstrap # tasks to the number of jobs (bootstrap is heavy, having # more tasks than `--jobs` won't necessarily make # things faster.) 
sem = asyncio.BoundedSemaphore(num_jobs) async def controller(coro, dbname, *args): async with sem: await coro(dbname, *args) if verbose: print(f' -> {dbname}: OK', flush=True) for _case, dbname, setup_script in setup: if skip_empty_databases and not setup_script: continue g.create_task(controller( _setup_database, dbname, setup_script, conn, stats, try_cached_db)) return stats async def _setup_database( dbname, setup_script, conn_args, stats, try_cached_db): start_time = time.monotonic() default_args = { 'user': edgedb_defines.EDGEDB_SUPERUSER, 'password': 'test', } default_args.update(conn_args) try: admin_conn = await tconn.async_connect_test_client( database=edgedb_defines.EDGEDB_SUPERUSER_DB, **default_args) except Exception as ex: raise RuntimeError( f'exception during creation of {dbname!r} test DB; ' f'could not connect to the {edgedb_defines.EDGEDB_SUPERUSER_DB} ' f'db; {type(ex).__name__}({ex})' ) from ex try: await admin_conn.execute( f'CREATE DATABASE {qlquote.quote_ident(dbname)};' ) except edgedb.DuplicateDatabaseDefinitionError: # Eh, that's fine # And, if we are trying to use a cache of the database, assume # the db is populated and return. 
if try_cached_db: elapsed = time.monotonic() - start_time stats.append( ('setup::' + dbname, {'running-time': elapsed, 'cached': True})) return except Exception as ex: raise RuntimeError( f'exception during creation of {dbname!r} test DB: ' f'{type(ex).__name__}({ex})' ) from ex finally: await admin_conn.aclose() dbconn = await tconn.async_connect_test_client( database=dbname, **default_args ) try: if setup_script: async for tx in dbconn.retrying_transaction(): async with tx: with dbconn.capture_warnings(): await dbconn.execute(setup_script) except Exception as ex: raise RuntimeError( f'exception during initialization of {dbname!r} test DB: ' f'{type(ex).__name__}({ex})' ) from ex finally: await dbconn.aclose() elapsed = time.monotonic() - start_time stats.append( ('setup::' + dbname, {'running-time': elapsed, 'cached': False})) _lock_cnt = 0 def gen_lock_key(): global _lock_cnt _lock_cnt += 1 return os.getpid() * 1000 + _lock_cnt class _EdgeDBServerData(NamedTuple): host: str port: int password: str server_data: Any tls_cert_file: str pid: int def get_connect_args(self, **kwargs) -> dict[str, str | int]: conn_args = dict( user='edgedb', password=self.password, host=self.host, port=self.port, tls_ca_file=self.tls_cert_file, ) conn_args.update(kwargs) return conn_args def fetch_metrics(self) -> str: ctx = ssl.create_default_context() ctx.load_verify_locations(self.tls_cert_file) return _fetch_metrics(self.host, self.port, sslctx=ctx) def fetch_server_info(self) -> dict[str, Any]: return _fetch_server_info(self.host, self.port) def call_system_api(self, path: str, **kwargs): args = dict(host=self.host, port=self.port, path=path) args.update(kwargs) return _call_system_api(**args) async def connect(self, **kwargs: Any) -> tconn.Connection: conn_args = self.get_connect_args(**kwargs) return await tconn.async_connect_test_client(**conn_args) async def connect_pg(self, **kwargs: Any) -> asyncpg.Connection: import asyncpg conn_args = self.get_connect_args(**kwargs) return 
await asyncpg.connect( host=conn_args['host'], port=conn_args['port'], user=conn_args['user'], password=conn_args['password'], ssl='require' ) async def connect_test_protocol(self, **kwargs): conn_args = self.get_connect_args(**kwargs) conn = await test_protocol.new_connection(**conn_args) await conn.connect() return conn class _EdgeDBServer: proc: Optional[asyncio.Process] def __init__( self, *, bind_addrs: tuple[str, ...] = ('localhost',), bootstrap_command: Optional[str], auto_shutdown_after: Optional[int], adjacent_to: Optional[tconn.Connection], max_allowed_connections: Optional[int], compiler_pool_size: int, compiler_pool_mode: Optional[edgedb_args.CompilerPoolMode] = None, debug: bool, backend_dsn: Optional[str] = None, data_dir: Optional[str] = None, runstate_dir: Optional[str] = None, reset_auth: Optional[bool] = None, tenant_id: Optional[str] = None, security: edgedb_args.ServerSecurityMode, default_auth_method: Optional[ edgedb_args.ServerAuthMethod | edgedb_args.ServerAuthMethods ] = None, binary_endpoint_security: Optional[ edgedb_args.ServerEndpointSecurityMode] = None, http_endpoint_security: Optional[ edgedb_args.ServerEndpointSecurityMode] = None, # see __aexit__ enable_backend_adaptive_ha: bool = False, ignore_other_tenants: bool = False, readiness_state_file: Optional[str] = None, tls_cert_file: Optional[os.PathLike] = None, tls_key_file: Optional[os.PathLike] = None, tls_cert_mode: edgedb_args.ServerTlsCertMode = ( edgedb_args.ServerTlsCertMode.SelfSigned), tls_client_ca_file: Optional[os.PathLike] = None, jws_key_file: Optional[os.PathLike] = None, jwt_sub_allowlist_file: Optional[os.PathLike] = None, jwt_revocation_list_file: Optional[os.PathLike] = None, multitenant_config: Optional[str] = None, config_file: Optional[os.PathLike] = None, default_branch: Optional[str] = None, env: Optional[dict[str, str]] = None, extra_args: Optional[list[str]] = None, net_worker_mode: Optional[str] = None, password: Optional[str] = None, ) -> None: 
self.bind_addrs = bind_addrs self.auto_shutdown_after = auto_shutdown_after self.bootstrap_command = bootstrap_command self.adjacent_to = adjacent_to self.max_allowed_connections = max_allowed_connections self.compiler_pool_size = compiler_pool_size self.compiler_pool_mode = compiler_pool_mode self.debug = debug self.backend_dsn = backend_dsn self.data_dir = data_dir self.runstate_dir = runstate_dir self.reset_auth = reset_auth self.tenant_id = tenant_id self.proc = None self.data = None self.security = security self.default_auth_method = default_auth_method self.binary_endpoint_security = binary_endpoint_security self.http_endpoint_security = http_endpoint_security self.enable_backend_adaptive_ha = enable_backend_adaptive_ha self.ignore_other_tenants = ignore_other_tenants self.readiness_state_file = readiness_state_file self.tls_cert_file = tls_cert_file self.tls_key_file = tls_key_file self.tls_cert_mode = tls_cert_mode self.tls_client_ca_file = tls_client_ca_file self.jws_key_file = jws_key_file self.jwt_sub_allowlist_file = jwt_sub_allowlist_file self.jwt_revocation_list_file = jwt_revocation_list_file self.multitenant_config = multitenant_config self.config_file = config_file self.default_branch = default_branch self.env = env self.extra_args = extra_args self.net_worker_mode = net_worker_mode self.password = password async def wait_for_server_readiness(self, stream: asyncio.StreamReader): while True: line = await stream.readline() if self.debug: print(line.decode()) if not line: raise RuntimeError("Gel server terminated") if line.startswith(b'READY='): break _, _, dataline = line.decode().partition('=') return json.loads(dataline) async def kill_process(self, proc: asyncio.Process): proc.terminate() try: await asyncio.wait_for(proc.wait(), timeout=60) except TimeoutError: proc.kill() async def _shutdown(self, exc: Optional[Exception] = None): if self.proc is None: return if self.proc.returncode is None: if self.auto_shutdown_after is not None and exc is 
None: try: await asyncio.wait_for(self.proc.wait(), timeout=60 * 5) except TimeoutError: self.proc.kill() raise AssertionError( 'server did not auto-shutdown in 5 minutes') else: await self.kill_process(self.proc) # asyncio, hello? # Workaround SubprocessProtocol.__del__ weirdly # complaining that loop is closed. self.proc._transport.close() self.proc = None async def __aenter__(self): status_r, status_w = socket.socketpair() cmd = [ sys.executable, '-I', '-m', 'edb.server.main', '--port', 'auto', '--testmode', '--emit-server-status', f'fd://{status_w.fileno()}', '--compiler-pool-size', str(self.compiler_pool_size), '--tls-cert-mode', str(self.tls_cert_mode), '--jose-key-mode', 'generate', ] if self.compiler_pool_mode is not None: cmd.extend(('--compiler-pool-mode', self.compiler_pool_mode.value)) for addr in self.bind_addrs: cmd.extend(('--bind-address', addr)) reset_auth = self.reset_auth cmd.extend(['--log-level', 'd' if self.debug else 's']) if self.max_allowed_connections is not None: cmd.extend([ '--max-backend-connections', str(self.max_allowed_connections), ]) if self.backend_dsn is not None: cmd.extend([ '--backend-dsn', self.backend_dsn, ]) elif self.adjacent_to is not None: settings = self.adjacent_to.get_settings() pgdsn = settings.get('pgdsn') if pgdsn is None: raise RuntimeError('test requires devmode to access pgdsn') cmd += [ '--backend-dsn', pgdsn.decode('utf-8') ] elif self.multitenant_config: cmd += ['--multitenant-config-file', self.multitenant_config] elif self.data_dir: cmd += ['--data-dir', self.data_dir] else: cmd += ['--temp-dir'] if reset_auth is None: reset_auth = True if not reset_auth: password = self.password bootstrap_command = '' else: password = secrets.token_urlsafe() bootstrap_command = f"""\ ALTER ROLE admin {{ SET password := '{password}'; }}; """ if self.bootstrap_command is not None: bootstrap_command += self.bootstrap_command if bootstrap_command: cmd += ['--bootstrap-command', bootstrap_command] if self.default_branch is not 
None: cmd += ['--default-branch', self.default_branch] if self.auto_shutdown_after is not None: cmd += ['--auto-shutdown-after', str(self.auto_shutdown_after)] if self.runstate_dir: cmd += ['--runstate-dir', self.runstate_dir] if self.tenant_id: cmd += ['--tenant-id', self.tenant_id] if self.security: cmd += ['--security', str(self.security)] if self.default_auth_method: cmd += ['--default-auth-method', str(self.default_auth_method)] if self.binary_endpoint_security: cmd += ['--binary-endpoint-security', str(self.binary_endpoint_security)] if self.http_endpoint_security: cmd += ['--http-endpoint-security', str(self.http_endpoint_security)] if self.enable_backend_adaptive_ha: cmd += ['--enable-backend-adaptive-ha'] if self.ignore_other_tenants: cmd += ['--ignore-other-tenants'] if self.tls_cert_file: cmd += ['--tls-cert-file', self.tls_cert_file] if self.tls_key_file: cmd += ['--tls-key-file', self.tls_key_file] if self.tls_client_ca_file: cmd += ['--tls-client-ca-file', str(self.tls_client_ca_file)] if self.readiness_state_file: cmd += ['--readiness-state-file', self.readiness_state_file] if self.jws_key_file: cmd += ['--jws-key-file', str(self.jws_key_file)] if self.jwt_sub_allowlist_file: cmd += ['--jwt-sub-allowlist-file', self.jwt_sub_allowlist_file] if self.jwt_revocation_list_file: cmd += ['--jwt-revocation-list-file', self.jwt_revocation_list_file] if self.config_file: cmd += ['--config-file', self.config_file] if not self.multitenant_config: cmd += ['--instance-name=localtest'] if self.net_worker_mode: cmd += ['--net-worker-mode', self.net_worker_mode] if self.extra_args: cmd.extend(self.extra_args) if self.debug: print( f'Starting Gel cluster with the following params:\n' f'{" ".join(shlex.quote(c) for c in cmd)}' ) env = os.environ.copy() if self.env: env.update(self.env) env.pop("EDGEDB_SERVER_MULTITENANT_CONFIG_FILE", None) stat_reader, stat_writer = await asyncio.open_connection(sock=status_r) self.proc: asyncio.Process = await 
asyncio.create_subprocess_exec( *cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, pass_fds=(status_w.fileno(),), ) status_task = asyncio.create_task( asyncio.wait_for( self.wait_for_server_readiness(stat_reader), timeout=240, ), ) output = b'' async def read_stdout(): nonlocal output # Tee the log temporarily to a tempfile that exists as long as the # test is running. This helps debug hanging tests. with tempfile.NamedTemporaryFile( mode='w+t', prefix='edgedb-test-log-') as temp_file: if self.debug: print(f"Logging to {temp_file.name}") while True: line = await self.proc.stdout.readline() if not line: break output += line temp_file.write(line.decode(errors='ignore')) if self.debug: print(line.decode(errors='ignore'), end='') stdout_task = asyncio.create_task(read_stdout()) try: _, pending = await asyncio.wait( [ status_task, asyncio.create_task(self.proc.wait()), ], return_when=asyncio.FIRST_COMPLETED, ) except (Exception, asyncio.CancelledError): try: await self._shutdown() finally: raise finally: stat_writer.close() status_w.close() if pending: for task in pending: if not task.done(): task.cancel() await asyncio.wait(pending, timeout=10) if self.proc.returncode is not None: await stdout_task raise edgedb_cluster.ClusterError(output.decode(errors='ignore')) else: assert status_task.done() data = status_task.result() return _EdgeDBServerData( host='localhost', port=data['port'], password=password, server_data=data, tls_cert_file=data['tls_cert_file'], pid=self.proc.pid, ) async def __aexit__(self, exc_type, exc, tb): try: if ( ( self.http_endpoint_security is edgedb_args.ServerEndpointSecurityMode.Optional ) and self.data is not None and self.auto_shutdown_after is None ): # It's a good idea to test most of the ad-hoc test clusters # for any errors in background tasks, as such tests usually # test the functionality that involves notifications and # other async events. 
metrics = _fetch_metrics('127.0.0.1', self.data['port']) errors = _extract_background_errors(metrics) if errors: raise AssertionError( 'server terminated with unexpected ' + 'background errors\n\n' + errors ) finally: await self._shutdown(exc) def start_edgedb_server( *, bind_addrs: tuple[str, ...] = ('localhost',), auto_shutdown_after: Optional[int]=None, bootstrap_command: Optional[str]=None, max_allowed_connections: Optional[int]=10, compiler_pool_size: int=2, compiler_pool_mode: Optional[edgedb_args.CompilerPoolMode] = None, adjacent_to: Optional[tconn.Connection]=None, debug: bool=debug.flags.server, backend_dsn: Optional[str] = None, runstate_dir: Optional[str] = None, data_dir: Optional[str] = None, reset_auth: Optional[bool] = None, tenant_id: Optional[str] = None, security: edgedb_args.ServerSecurityMode = ( edgedb_args.ServerSecurityMode.Strict), default_auth_method: Optional[ edgedb_args.ServerAuthMethod | edgedb_args.ServerAuthMethods ] = None, binary_endpoint_security: Optional[ edgedb_args.ServerEndpointSecurityMode] = None, http_endpoint_security: Optional[ edgedb_args.ServerEndpointSecurityMode] = None, enable_backend_adaptive_ha: bool = False, ignore_other_tenants: bool = False, readiness_state_file: Optional[str] = None, tls_cert_file: Optional[os.PathLike] = None, tls_key_file: Optional[os.PathLike] = None, tls_cert_mode: edgedb_args.ServerTlsCertMode = ( edgedb_args.ServerTlsCertMode.SelfSigned), tls_client_ca_file: Optional[os.PathLike] = None, jws_key_file: Optional[os.PathLike] = None, jwt_sub_allowlist_file: Optional[os.PathLike] = None, jwt_revocation_list_file: Optional[os.PathLike] = None, multitenant_config: Optional[str] = None, config_file: Optional[os.PathLike] = None, env: Optional[dict[str, str]] = None, extra_args: Optional[list[str]] = None, default_branch: Optional[str] = None, net_worker_mode: Optional[str] = None, force_new: bool = False, # True for ignoring multitenant config env ): if (not devmode.is_in_dev_mode() or
adjacent_to) and not runstate_dir: if backend_dsn or adjacent_to: import traceback # We don't want to implicitly "fix the issue" for the test author print('WARNING: starting a Gel server with the default ' 'runstate_dir; the test is likely to fail or hang. ' 'Consider specifying the runstate_dir parameter.') print('\n'.join(traceback.format_stack(limit=5))) password = None if mt_conf := os.environ.get("EDGEDB_SERVER_MULTITENANT_CONFIG_FILE"): if multitenant_config is None and max_allowed_connections == 10: if not any( ( adjacent_to, data_dir, backend_dsn, compiler_pool_mode, default_branch, force_new, ) ): multitenant_config = mt_conf max_allowed_connections = None password = 'test' # set in init_cluster() by test/runner.py params = locals() exclusives = [ name for name in [ "adjacent_to", "data_dir", "backend_dsn", "multitenant_config", ] if params[name] ] if len(exclusives) > 1: raise RuntimeError( " and ".join(exclusives) + " options are mutually exclusive" ) if not runstate_dir and data_dir: runstate_dir = data_dir return _EdgeDBServer( bind_addrs=bind_addrs, auto_shutdown_after=auto_shutdown_after, bootstrap_command=bootstrap_command, max_allowed_connections=max_allowed_connections, adjacent_to=adjacent_to, compiler_pool_size=compiler_pool_size, compiler_pool_mode=compiler_pool_mode, debug=debug, backend_dsn=backend_dsn, tenant_id=tenant_id, data_dir=data_dir, runstate_dir=runstate_dir, reset_auth=reset_auth, security=security, default_auth_method=default_auth_method, binary_endpoint_security=binary_endpoint_security, http_endpoint_security=http_endpoint_security, enable_backend_adaptive_ha=enable_backend_adaptive_ha, ignore_other_tenants=ignore_other_tenants, readiness_state_file=readiness_state_file, tls_cert_file=tls_cert_file, tls_key_file=tls_key_file, tls_cert_mode=tls_cert_mode, tls_client_ca_file=tls_client_ca_file, jws_key_file=jws_key_file, jwt_sub_allowlist_file=jwt_sub_allowlist_file, jwt_revocation_list_file=jwt_revocation_list_file, 
multitenant_config=multitenant_config, config_file=config_file, env=env, extra_args=extra_args, default_branch=default_branch, net_worker_mode=net_worker_mode, password=password, ) def get_cases_by_shard(cases, selected_shard, total_shards, verbosity, stats): if total_shards <= 1: return cases selected_shard -= 1 # starting from 0 new_test_est = 0.1 # default estimate if test is not found in stats new_setup_est = 1 # default estimate if setup is not found in stats # For logging total_tests = 0 selected_tests = 0 total_est = 0 selected_est = 0 # Priority queue of tests grouped by setup script ordered by estimated # running time of the groups. Order of tests within cases is preserved. tests_by_setup = [] # Priority queue of individual tests ordered by estimated running time. tests_with_est = [] # Prepare the source heaps setup_count = 0 for case, tests in cases.items(): # Extract zREPEAT tests and attach them to their first runs combined = {} for test in tests: test_name = str(test) orig_name = test_name.replace('test_zREPEAT', 'test') if orig_name == test_name: if test_name in combined: combined[test_name] = (test, *combined[test_name]) else: combined[test_name] = (test,) else: if orig_name in combined: combined[orig_name] = (*combined[orig_name], test) else: combined[orig_name] = (test,) setup_script_getter = getattr(case, 'get_setup_script', None) if setup_script_getter and combined: tests_per_setup = [] est_per_setup = setup_est = stats.get( 'setup::' + case.get_database_name(), (new_setup_est, 0), )[0] for test_name, test in combined.items(): total_tests += len(test) est = stats.get(test_name, (new_test_est, 0))[0] * len(test) est_per_setup += est tests_per_setup.append((est, test)) heapq.heappush( tests_by_setup, (-est_per_setup, setup_count, setup_est, tests_per_setup), ) setup_count += 1 total_est += est_per_setup else: for test_name, test in combined.items(): total_tests += len(test) est = stats.get(test_name, (new_test_est, 0))[0] * len(test) total_est += 
est heapq.heappush(tests_with_est, (-est, total_tests, test)) target_est = total_est / total_shards # target running time of one shard shards_est = [(0, shard, set()) for shard in range(total_shards)] cases = {} # output setup_to_alloc = set(range(setup_count)) # tracks first run of each setup # Assign per-setup tests first while tests_by_setup: remaining_est, setup_id, setup_est, tests = heapq.heappop( tests_by_setup, ) est_acc, current, setups = heapq.heappop(shards_est) # Add setup time if setup_id not in setups: setups.add(setup_id) est_acc += setup_est if current == selected_shard: selected_est += setup_est if setup_id in setup_to_alloc: setup_to_alloc.remove(setup_id) else: # This means one more setup for the overall test run target_est += setup_est / total_shards # Add as many tests from this group to the current shard as possible while tests: est, test = tests.pop(0) est_acc += est # est is a positive number remaining_est += est # remaining_est is a negative number if current == selected_shard: # Add the test to the result _add_test(cases, test) selected_tests += len(test) selected_est += est if est_acc >= target_est and -remaining_est > setup_est * 2: # If the current shard is full and the remaining tests would take # more than twice their setup time, push the tests back onto the # heap so that they can be assigned to another shard heapq.heappush( tests_by_setup, (remaining_est, setup_id, setup_est, tests), ) break heapq.heappush(shards_est, (est_acc, current, setups)) # Assign all non-setup tests, but leave the last shard for everything else setups = set() while tests_with_est and len(shards_est) > 1: est, _, test = heapq.heappop(tests_with_est) # est is negative est_acc, current, setups = heapq.heappop(shards_est) est_acc -= est if current == selected_shard: # Add the test to the result _add_test(cases, test) selected_tests += len(test) selected_est -= est if est_acc >= target_est: # The current shard is full if current == selected_shard: # End early if the selected shard
is full break else: # Only add the current shard back to the heap if it's not full heapq.heappush(shards_est, (est_acc, current, setups)) else: # Add all the remaining tests to the first remaining shard if any while shards_est: est_acc, current, setups = heapq.heappop(shards_est) if current == selected_shard: for est, _, test in tests_with_est: _add_test(cases, test) selected_tests += len(test) selected_est -= est break tests_with_est.clear() # should always be empty already here if verbosity >= 1: print(f'Running {selected_tests}/{total_tests} tests for shard ' f'#{selected_shard + 1} out of {total_shards} shards, ' f'estimate: {int(selected_est / 60)}m {int(selected_est % 60)}s' f' / {int(total_est / 60)}m {int(total_est % 60)}s, ' f'{len(setups)}/{setup_count} databases to setup.') return _merge_results(cases) def find_available_port() -> int: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: sock.bind(("localhost", 0)) return sock.getsockname()[1] def _needs_factoring(weakly): def decorator(f): async def g(self, *args, **kwargs): if self.NO_FACTOR and not weakly: with self.assertRaisesRegex(Exception, ''): await f(self, *args, **kwargs) elif self.WARN_FACTOR: with self.assertRaisesRegex( edgedb.InvalidReferenceError, 'attempting to factor out' ): await f(self, *args, **kwargs) else: await f(self, *args, **kwargs) return g return decorator @contextlib.asynccontextmanager async def temp_file_with(data: bytes): with tempfile.NamedTemporaryFile() as f: f.write(data) f.flush() yield f needs_factoring = _needs_factoring(weakly=False) needs_factoring_weakly = _needs_factoring(weakly=True) ================================================ FILE: edb/tools/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # DO NOT ADD ANYTHING TO THIS FILE: # This package contains tools like mypy plugins that # when loaded should not trigger any imports. ================================================ FILE: edb/tools/__main__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2016-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Stub to allow invoking `edb` as `python -m edb.tools`.""" from __future__ import annotations import sys from edb.tools import edb if __name__ == '__main__': sys.exit(edb.edbcommands(prog_name='edb')) ================================================ FILE: edb/tools/ast_inheritance_graph.py ================================================ # Generates an inheritance graph of Python classes. 
# # Usage: # $ edb ast-inheritance-graph | fdp -T svg -o ast-fdp.svg # # Requirements: # - graphviz import typing import dataclasses import enum import click from edb.edgeql import ast as qlast from edb.ir import ast as irast from edb.pgsql import ast as pgast from edb.tools.edb import edbcommands class ASTModule(str, enum.Enum): ql = "ql" ir = "ir" pg = "pg" @dataclasses.dataclass() class ASTClass: name: str typ: typing.Any bases: set[str] children: set[str] @edbcommands.command("ast-inheritance-graph") @click.argument('ast', type=click.Choice(ASTModule)) # type: ignore def main(ast: ASTModule) -> None: ast_mod: typing.Any if ast == ASTModule.ql: ast_mod = qlast elif ast == ASTModule.ir: ast_mod = irast elif ast == ASTModule.pg: ast_mod = pgast else: raise AssertionError() # discover all nodes ast_classes: dict[str, ASTClass] = {} for name, typ in ast_mod.__dict__.items(): if not isinstance(typ, type): continue if not issubclass(typ, ast_mod.Base) or name in { 'Base', 'ImmutableBase', }: continue if typ.__rust_ignore__: # type: ignore continue # re-run field collection to correctly handle forward-references typ = typ._collect_direct_fields() # type: ignore ast_classes[typ.__name__] = ASTClass( name=name, typ=typ, children=set(), bases=set(), ) for ast_class in ast_classes.values(): for base in ast_class.typ.__bases__: if base.__name__ not in ast_classes: continue ast_class.bases.add(base.__name__) ast_classes[base.__name__].children.add(ast_class.name) inheritance_graph(ast_classes) enum_graph(ast_classes) def inheritance_graph(ast_classes: dict[str, ASTClass]): print('digraph I {') for ast_class in ast_classes.values(): if ast_class.typ.__abstract_node__: print(f' {ast_class.name} [color = red];') for base in ast_class.bases: print(f' {ast_class.name} -> {base};') print('}') def enum_graph(ast_classes: dict[str, ASTClass]): print('digraph M {') def dfs(node, start): ast_class = ast_classes[node] if ast_class.typ.__abstract_node__: print(f' {node}_{start} [color = 
red];') for child in ast_class.children: print(f' {node}_{start} -> {child}_{start};') dfs(child, start) for ast_class in ast_classes.values(): if len(ast_class.bases) != 0 or len(ast_class.children) == 0: continue dfs(ast_class.name, ast_class.name) print('}') ================================================ FILE: edb/tools/cli.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2021-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import subprocess import sys import click from edb import cli as rustcli from edb.tools.edb import edbcommands @edbcommands.command('cli', add_help_option=False, context_settings=dict(ignore_unknown_options=True)) @click.argument('args', nargs=-1, type=click.UNPROCESSED) def cli(args: tuple[str, ...]): """Run edgedb CLI with `-H localhost`.""" args_list = _ensure_linked(args) if ( '--wait-until-available' not in args_list and not any('--wait-until-available=' in a for a in args_list) ): args_list += ['--wait-until-available', '60s'] sys.exit(rustcli.rustcli(args=[sys.argv[0], *args_list])) @edbcommands.command('ui', add_help_option=False, context_settings=dict(ignore_unknown_options=True)) @click.argument('args', nargs=-1, type=click.UNPROCESSED) def ui(args: tuple[str, ...]): """Run edgedb GUI with `-H localhost`.""" _ensure_linked(args) subprocess.check_call( [ sys.executable, "-I", "-m", "edb.cli", "ui", "--instance=_localdev", ], ) def _ensure_linked(args: tuple[str, ...]) -> list[str]: if ( '--host' not in args and not any(a.startswith('-H') for a in args) and not any(a.startswith('--host=') for a in args) and '--port' not in args and not any(a.startswith('-P') for a in args) and not any(a.startswith('--port=') for a in args) and '--instance' not in args and not any(a.startswith('-I') for a in args) and not any(a.startswith('--instance=') for a in args) ): subprocess.check_call([ sys.executable, "-I", "-m", "edb.cli", "instance", "link", "--host=localhost", "--non-interactive", "--trust-tls-cert", "--overwrite", "--quiet", "_localdev", ]) return list(args) + ['-I', '_localdev'] else: return list(args) ================================================ FILE: edb/tools/config.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2020-present MagicStack Inc. and the EdgeDB authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import click from edb import buildmeta from edb.common import devmode from edb.tools.edb import edbcommands @edbcommands.command("config") @click.option( "--make-include", is_flag=True, help='Print path to the include file for extension Makefiles', ) @click.option( "--pg-config", is_flag=True, help='Print path to bundled pg_config', ) def config(make_include: bool, pg_config: bool) -> None: '''Query certain parameters about an edgedb environment''' if make_include: share = buildmeta.get_extension_dir_path() base = share.parent.parent.parent # XXX: It should not be here. if not devmode.is_in_dev_mode(): base = base / 'share' mk = ( base / 'tests' / 'extension-testing' / 'exts.mk' ) print(mk) if pg_config: print(buildmeta.get_pg_config_path()) ================================================ FILE: edb/tools/dflags.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2019-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from edb.common import debug from edb.tools.edb import edbcommands @edbcommands.command('dflags') def gen_types(): """Print available debug flags.""" for flag in debug.flags: print(f'env EDGEDB_DEBUG_{flag.name.upper()}=1') print(f' {flag.doc}\n') ================================================ FILE: edb/tools/docs/__init__.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2018-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from docutils import nodes as d_nodes from docutils.parsers import rst as d_rst from sphinx import addnodes as s_nodes from sphinx import transforms as s_transforms from sphinx.domains.index import IndexDirective from . import edb from . import cli from . import eql from . import js from . import sdl from . import graphql from . import go from . 
import shared class ProhibitedNodeTransform(s_transforms.SphinxTransform): default_priority = 1 # before ReferencesResolver def apply(self): for bq in list(self.document.traverse(d_nodes.block_quote)): if not bq['classes'] or 'pull-quote' not in bq['classes']: raise shared.EdgeSphinxExtensionError( f'blockquote found: {bq.asdom().toxml()!r} in {bq.source};' f' Try using the "pull-quote" directive') else: bq.get('classes').remove('pull-quote') trs = list(self.document.traverse(d_nodes.title_reference)) if trs: raise shared.EdgeSphinxExtensionError( f'title reference (single backticks quote) found: ' f'{trs[0].asdom().toxml()!r} in {trs[0].source}; ' f'perhaps you wanted to use double backticks?') class VersionAdded(d_rst.Directive): has_content = True optional_arguments = 0 required_arguments = 1 def run(self): node = s_nodes.versionmodified() node['type'] = 'versionadded' node['version'] = self.arguments[0] self.state.nested_parse(self.content, self.content_offset, node) return [node] class VersionChanged(d_rst.Directive): has_content = True optional_arguments = 0 required_arguments = 1 def run(self): node = s_nodes.versionmodified() node['type'] = 'versionchanged' node['version'] = self.arguments[0] self.state.nested_parse(self.content, self.content_offset, node) return [node] class VersionedSection(d_rst.Directive): has_content = False optional_arguments = 0 required_arguments = 0 def run(self): node = d_nodes.container() node['versioned-section'] = True return [node] class VersionedReplaceRole: def __call__( self, role, rawtext, text, lineno, inliner, options=None, content=None ): nodes = [] if not text.startswith('_default:'): text = '_default:' + text for section in text.split(';'): parts = section.split(':', maxsplit=1) node = s_nodes.versionmodified() node['type'] = 'versionchanged' node['version'] = parts[0].strip() node += d_nodes.Text(parts[1].strip()) nodes.append(node) return nodes, [] class APIIndex(IndexDirective): def run(self): nodes = super().run() 
nodes[0]['api-index'] = True return nodes def setup(app): edb.setup_domain(app) cli.setup_domain(app) eql.setup_domain(app) js.setup_domain(app) sdl.setup_domain(app) graphql.setup_domain(app) go.setup_domain(app) app.add_directive('versionadded', VersionAdded, True) app.add_directive('versionchanged', VersionChanged, True) app.add_directive('code-block', shared.CodeBlock, True) app.add_directive('versioned-section', VersionedSection) app.add_directive('api-index', APIIndex) app.add_role('versionreplace', VersionedReplaceRole()) app.add_transform(ProhibitedNodeTransform) ================================================ FILE: edb/tools/docs/cli.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2018-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any from edb.tools.pygments.edgeql import EdgeQLLexer import pygments.lexers from sphinx import domains as s_domains from docutils.parsers.rst import directives as d_directives # type: ignore from . 
import shared class CLISynopsisDirective(shared.CodeBlock): has_content = True optional_arguments = 0 required_arguments = 0 option_spec: dict[str, Any] = { 'version-lt': d_directives.unchanged_required } def run(self): self.arguments = ['cli-synopsis'] return super().run() class CLIDomain(s_domains.Domain): name = "cli" label = "Command Line Interface" directives = { 'synopsis': CLISynopsisDirective, } def setup_domain(app): # Dummy lexers; the actual highlighting is implemented # in the edgedb.com website code. app.add_lexer("txt", pygments.lexers.TextLexer) app.add_lexer("bash", pygments.lexers.TextLexer) app.add_lexer("cli", EdgeQLLexer) app.add_lexer("cli-synopsis", EdgeQLLexer) app.add_role( 'cli:synopsis', shared.InlineCodeRole('cli-synopsis')) app.add_domain(CLIDomain) ================================================ FILE: edb/tools/docs/edb.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2018-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from __future__ import annotations import re from sphinx import domains as s_domains from docutils import nodes as d_nodes from docutils.parsers import rst as d_rst from docutils.parsers.rst import directives as d_directives # type: ignore from sphinx import transforms class EDBYoutubeEmbed(d_rst.Directive): has_content = True optional_arguments = 0 required_arguments = 1 def run(self): node = d_nodes.container() node['youtube-video-id'] = self.arguments[0] self.state.nested_parse(self.content, self.content_offset, node) return [node] class EDBCollapsed(d_rst.Directive): has_content = True optional_arguments = 0 required_arguments = 0 option_spec = { 'summary': d_directives.unchanged_required, } def run(self): node = d_nodes.container() node['collapsed_block'] = True if 'summary' in self.options: node['summary'] = self.options['summary'] self.state.nested_parse(self.content, self.content_offset, node) return [node] class EDBEnvironmentSwitcher(d_rst.Directive): has_content = False optional_arguments = 0 required_arguments = 0 def run(self): node = d_nodes.container() node['env-switcher'] = True return [node] class EDBSplitSection(d_rst.Directive): has_content = True optional_arguments = 0 required_arguments = 0 def run(self): node = d_nodes.container() node['split-section'] = True self.state.nested_parse(self.content, self.content_offset, node) split_indexes = [ index for index, child in enumerate(node.children) if isinstance(child, d_nodes.container) and child.get('split-point') ] if len(split_indexes) > 1: raise Exception( f'cannot have multiple edb:split-point\'s in edb:split-section' ) blocks = ( node.children[:split_indexes[0]] if node.children[split_indexes[0]].get('code-above') else node.children[split_indexes[0] + 1:] ) if len(split_indexes) == 1 else [node.children[-1]] if len(blocks) < 1: raise Exception( f'no content found at end of edb:split-section block, ' f'or before/after the edb:split-point in the edb:split-section' ) return [node] class 
EDBSplitPoint(d_rst.Directive): has_content = False optional_arguments = 1 required_arguments = 0 def run(self): node = d_nodes.container() node['split-point'] = True if len(self.arguments) > 0: if self.arguments[0] not in ['above', 'below']: raise Exception( f"expected edb:split-point arg to be 'above', 'below' " f"or empty (defaults to 'below')" ) if self.arguments[0] == 'above': node['code-above'] = True return [node] class GelDomain(s_domains.Domain): name = "edb" label = "Gel" directives = { 'collapsed': EDBCollapsed, 'youtube-embed': EDBYoutubeEmbed, 'env-switcher': EDBEnvironmentSwitcher, 'split-section': EDBSplitSection, 'split-point': EDBSplitPoint } class GelSubstitutionTransform(transforms.SphinxTransform): default_priority = 0 def apply(self): builder_name = "html" if hasattr(self.document.settings, 'env'): env = self.document.settings.env if env and hasattr(env, "app"): builder_name = env.app.builder.name # Traverse all substitution_reference nodes. for node in self.document.traverse(d_nodes.substitution_reference): nt = node.astext() if nt.lower() in { "gel", "gel's", "edgedb", "gelcmd", ".gel", "gel.toml", "gel-server", "geluri", "admin", "main", "branch", "branches" }: if builder_name in {"xml", "edge-xml"}: if nt == "gelcmd": sub = d_nodes.literal( 'gel', 'gel', **{ "edb-gelcmd": "true", "edb-gelcmd-top": "true", "edb-substitution": "true", } ) elif nt == "geluri": sub = d_nodes.literal( 'gel', 'gel://', **{ "edb-geluri": "true", "edb-geluri-top": "true", "edb-substitution": "true", } ) else: sub = d_nodes.inline( nt, nt, **{"edb-substitution": "true"} ) node.replace_self(sub) else: node.replace_self(d_nodes.Text(nt)) class GelCmdRole: def __call__( self, role, rawtext, text, lineno, inliner, options=None, content=None ): text = text.strip() text = re.sub(r'(\n\s*)+', " ", text) if text.startswith("edgedb"): fn = inliner.document.current_source raise Exception( f"{fn}:{lineno} - :gelcmd:`{text}` - can't start with 'edgedb'" ) if 
text.startswith("gel ") or text == "gel": fn = inliner.document.current_source raise Exception( f"{fn}:{lineno} - :gelcmd:`{text}` - can't start with 'gel'" ) text = f'gel {text}' node = d_nodes.literal(text, text) node["edb-gelcmd"] = "true" node["edb-gelcmd-top"] = "false" node["edb-substitution"] = "true" return [node], [] class GelUriRole: def __call__( self, role, rawtext, text, lineno, inliner, options=None, content=None ): if text.startswith("edgedb://"): fn = inliner.document.current_source raise Exception( f"{fn}:{lineno} - :geluri:`{text}`" f" - can't start with 'edgedb://'" ) if text.startswith("gel://"): fn = inliner.document.current_source raise Exception( f"{fn}:{lineno} - :geluri:`{text}` - can't start with 'gel://'" ) text = f'gel://{text}' node = d_nodes.literal(text, text) node["edb-geluri"] = "true" node["edb-geluri-top"] = "false" node["edb-substitution"] = "true" return [node], [] class DotGelRole: def __call__( self, role, rawtext, text, lineno, inliner, options=None, content=None ): if text.endswith(".gel") or text.endswith(".esdl"): fn = inliner.document.current_source raise Exception( f"{fn}:{lineno} - :dotgel:`{text}`" f" - can't end with '.esdl' or '.gel'" ) text = f'{text}.gel' node = d_nodes.literal(text, text) node["edb-dotgel"] = "true" node["edb-substitution"] = "true" return [node], [] class GelEnvRole: def __call__( self, role, rawtext, text, lineno, inliner, options=None, content=None ): if ( text.startswith("EDGEDB_") or text.startswith("GEL_") or text.startswith("_") ): fn = inliner.document.current_source raise Exception( f"{fn}:{lineno} - :gelenv:`{text}`" f" - can't start with 'EDGEDB_', 'GEL_', or '_'" ) text = f'GEL_{text}' node = d_nodes.literal(text, text) node["edb-gelenv"] = "true" node["edb-substitution"] = "true" return [node], [] def setup_domain(app): app.add_role('gelcmd', GelCmdRole()) app.add_role('geluri', GelUriRole()) app.add_role('dotgel', DotGelRole()) app.add_role('gelenv', GelEnvRole()) 
app.add_domain(GelDomain) app.add_transform(GelSubstitutionTransform) def setup(app): setup_domain(app) ================================================ FILE: edb/tools/docs/eql.py ================================================ # # This source file is part of the EdgeDB open source project. # # Copyright 2018-present MagicStack Inc. and the EdgeDB authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # r""" ===================================== :eql: domain for EdgeQL documentation ===================================== Functions --------- To declare a function use a ".. eql:function::" directive. A few things must be defined: * Full function signature with a fully qualified name must be specified. * ":param $name: description:" a short description of the $name parameter. $name must match the name of the parameter in the function's signature. If a parameter is anonymous, its number should be used instead (e.g. $1). * ":paramtype $name: type": for every :param: there must be a corresponding :paramtype: field. For example: ":paramtype $name: int64" declares that the type of the $name parameter is `int64`. If a parameter has more than one valid type, list them separated by "or": ":paramtype $name: int64 or str". * :return: and :returntype: are similar to :param: and :paramtype: but lack parameter names. They must be used to document the return value of the function. * A few paragraphs and code samples.
The first paragraph must be a single sentence no longer than 79 characters describing the function. Example: .. eql:function:: std::array_agg(SET OF any, $a: any) -> array :param $1: input set :paramtype $1: SET OF any :param $a: description of this param :paramtype $a: int64 or str :return: array made of input set elements :returntype: array Return the array made from all of the input set elements. The ordering of the input set will be preserved if specified. A function can be referenced from anywhere in the documentation by using a ":eql:func:" role. For instance: * ":eql:func:`array_agg`"; * ":eql:func:`std::array_agg`"; * or, "look at this :eql:func:`fancy function `". Operators --------- Use ".. eql:operator::" directive to declare an operator. Supported fields: * ":optype NAME: TYPE" -- operand type. The first argument of the directive must be a string in the following format: "OPERATOR_ID: OPERATOR SIGNATURE". For instance, for a "+" operator it would be "PLUS: A + B": .. eql:operator:: PLUS: A + B :optype A: int64 or str or bytes :optype B: any :resulttype: any Arithmetic addition. To reference an operator use the :eql:op: role along with OPERATOR_ID: ":eql:op:`plus`" or ":eql:op:`+ `". Operator ID is case-insensitive. Statements ---------- Use ":eql-statement:" field for sections that describe a statement. A :eql-haswith: field should be used if the statement supports a WITH block. Example: SELECT ------ :eql-statement: :eql-haswith: SELECT is used to select stuff. .. eql:synopsis:: [WITH [MODULE name]] SELECT expr FILTER expr .. eql:clause:: FILTER: A FILTER B :paramtype A: any :paramtype B: SET OF any :returntype: any FILTER should be used to filter stuff. More paragraphs describing intricacies of SELECT go here... More paragraphs describing intricacies of SELECT go here... More paragraphs describing intricacies of SELECT go here... Notes: * To reference a statement use the ":eql:stmt:" role. 
For instance: - :eql:stmt:`SELECT` - :eql:stmt:`my fav statement